From 5a345f99d1b7c3c5e7bc932bba0c27bff4bf8e35 Mon Sep 17 00:00:00 2001 From: "Hitalo M." Date: Thu, 17 Apr 2025 13:49:26 -0300 Subject: [PATCH] fix: handle connection closure and retry logic in session management Signed-off-by: wulan17 --- pyrogram/connection/transport/tcp/tcp.py | 23 +++++++++++++++++++++-- pyrogram/session/session.py | 23 +++++++++++++++++++++++ 2 files changed, 44 insertions(+), 2 deletions(-) diff --git a/pyrogram/connection/transport/tcp/tcp.py b/pyrogram/connection/transport/tcp/tcp.py index 1848ba35..4957c061 100644 --- a/pyrogram/connection/transport/tcp/tcp.py +++ b/pyrogram/connection/transport/tcp/tcp.py @@ -54,6 +54,14 @@ class TCP: self.lock = asyncio.Lock() self.loop = asyncio.get_event_loop() + self._closed = True + + @property + def closed(self) -> bool: + return ( + self._closed or self.writer is None or self.writer.is_closing() or self.reader is None + ) + async def _connect_via_proxy( self, @@ -123,11 +131,14 @@ class TCP: async def connect(self, address: Tuple[str, int]) -> None: try: await asyncio.wait_for(self._connect(address), TCP.TIMEOUT) + self._closed = False except asyncio.TimeoutError: # Re-raise as TimeoutError. asyncio.TimeoutError is deprecated in 3.11 + self._closed = True raise TimeoutError("Connection timed out") async def close(self) -> None: if self.writer is None: + self._closed = True return None try: @@ -135,10 +146,12 @@ class TCP: await asyncio.wait_for(self.writer.wait_closed(), TCP.TIMEOUT) except Exception as e: log.info("Close exception: %s %s", type(e).__name__, e) + finally: + self._closed = True async def send(self, data: bytes) -> None: - if self.writer is None: - return None + if self.writer is None or self._closed: + raise OSError("Connection is closed") async with self.lock: try: @@ -146,9 +159,13 @@ class TCP: await self.writer.drain() except Exception as e: log.info("Send exception: %s %s", type(e).__name__, e) + self._closed = True raise OSError(e) async def recv(self, length: int = 0) -> Optional[bytes]: + if self._closed or self.reader is None: + return None + data = b"" while len(data) < length: @@ -158,11 +175,13 @@ class TCP: TCP.TIMEOUT ) except (OSError, asyncio.TimeoutError): + self._closed = True return None else: if chunk: data += chunk else: + self._closed = True return None return data diff --git a/pyrogram/session/session.py b/pyrogram/session/session.py index c84a3cb4..57d16c95 100644 --- a/pyrogram/session/session.py +++ b/pyrogram/session/session.py @@ -417,6 +417,19 @@ class Session: while True: try: + if ( + self.connection is None + or self.connection.protocol is None + or getattr(self.connection.protocol, "closed", True) + ): + log.warning( + "[%s] Connection is closed or not established. Attempting to reconnect...", + self.client.name, + ) + await self.restart() + await asyncio.sleep(1) + continue + return await self.send(query, timeout=timeout) except (FloodWait, FloodPremiumWait) as e: amount = e.value @@ -438,6 +451,16 @@ class Session: query_name, str(e) or repr(e) ) + if isinstance(e, OSError) and retries > 1: + try: + await self.restart() + except Exception as restart_error: + log.warning( + "[%s] Failed to restart session: %s", + self.client.name, + str(restart_error) or repr(restart_error), + ) + await asyncio.sleep(0.5) return await self.invoke(query, retries - 1, timeout)