Reapply "fix: handle connection closure and retry logic in session management"

This reverts commit 2c3fb1caa6.

Signed-off-by: wulan17 <wulan17@komodos.id>
This commit is contained in:
wulan17 2025-06-07 20:58:36 +07:00
parent 3aed1857d4
commit 375e97d963
No known key found for this signature in database
GPG key ID: 737814D4B5FF0420
2 changed files with 40 additions and 0 deletions

View file

@ -59,6 +59,13 @@ class TCP:
self.loop = loop self.loop = loop
else: else:
self.loop = asyncio.get_event_loop() self.loop = asyncio.get_event_loop()
self._closed = True
@property
def closed(self) -> bool:
return (
self._closed or self.writer is None or self.writer.is_closing() or self.reader is None
)
async def _connect_via_proxy( async def _connect_via_proxy(
self, self,
@ -126,11 +133,14 @@ class TCP:
async def connect(self, address: Tuple[str, int]) -> None: async def connect(self, address: Tuple[str, int]) -> None:
try: try:
await asyncio.wait_for(self._connect(address), TCP.TIMEOUT) await asyncio.wait_for(self._connect(address), TCP.TIMEOUT)
self._closed = False
except asyncio.TimeoutError: # Re-raise as TimeoutError. asyncio.TimeoutError is deprecated in 3.11 except asyncio.TimeoutError: # Re-raise as TimeoutError. asyncio.TimeoutError is deprecated in 3.11
self._closed = True
raise TimeoutError("Connection timed out") raise TimeoutError("Connection timed out")
async def close(self) -> None: async def close(self) -> None:
if self.writer is None: if self.writer is None:
self._closed = True
return None return None
try: try:
@ -140,6 +150,7 @@ class TCP:
log.info("Close exception: %s %s", type(e).__name__, e) log.info("Close exception: %s %s", type(e).__name__, e)
finally: finally:
self.writer = None self.writer = None
self._closed = True
async def send(self, data: bytes) -> None: async def send(self, data: bytes) -> None:
async with self.lock: async with self.lock:
@ -151,9 +162,13 @@ class TCP:
await self.writer.drain() await self.writer.drain()
except Exception as e: except Exception as e:
log.info("Send exception: %s %s", type(e).__name__, e) log.info("Send exception: %s %s", type(e).__name__, e)
self._closed = True
raise OSError(e) raise OSError(e)
async def recv(self, length: int = 0) -> Optional[bytes]: async def recv(self, length: int = 0) -> Optional[bytes]:
if self._closed or self.reader is None:
return None
data = b"" data = b""
while len(data) < length: while len(data) < length:
@ -163,11 +178,13 @@ class TCP:
TCP.TIMEOUT TCP.TIMEOUT
) )
except (OSError, asyncio.TimeoutError): except (OSError, asyncio.TimeoutError):
self._closed = True
return None return None
else: else:
if chunk: if chunk:
data += chunk data += chunk
else: else:
self._closed = True
return None return None
return data return data

View file

@ -432,6 +432,19 @@ class Session:
while retries > 0: while retries > 0:
try: try:
if (
self.connection is None
or self.connection.protocol is None
or getattr(self.connection.protocol, "closed", True)
):
log.warning(
"[%s] Connection is closed or not established. Attempting to reconnect...",
self.client.name,
)
await self.restart()
await asyncio.sleep(1)
continue
return await self.send(query, timeout=timeout) return await self.send(query, timeout=timeout)
except (FloodWait, FloodPremiumWait) as e: except (FloodWait, FloodPremiumWait) as e:
amount = e.value amount = e.value
@ -454,6 +467,16 @@ class Session:
query_name, str(e) or repr(e) query_name, str(e) or repr(e)
) )
if isinstance(e, OSError) and retries > 1:
try:
await self.restart()
except Exception as restart_error:
log.warning(
"[%s] Failed to restart session: %s",
self.client.name,
str(restart_error) or repr(restart_error),
)
await asyncio.sleep(0.5) await asyncio.sleep(0.5)
raise TimeoutError("Exceeded maximum number of retries") raise TimeoutError("Exceeded maximum number of retries")