fix: handle connection closure and retry logic in session management
Some checks failed
Build-docs / build (push) Has been cancelled
Pyrofork / build (macos-latest, 3.10) (push) Has been cancelled
Pyrofork / build (macos-latest, 3.11) (push) Has been cancelled
Pyrofork / build (macos-latest, 3.12) (push) Has been cancelled
Pyrofork / build (macos-latest, 3.13) (push) Has been cancelled
Pyrofork / build (macos-latest, 3.9) (push) Has been cancelled
Pyrofork / build (ubuntu-latest, 3.10) (push) Has been cancelled
Pyrofork / build (ubuntu-latest, 3.11) (push) Has been cancelled
Pyrofork / build (ubuntu-latest, 3.12) (push) Has been cancelled
Pyrofork / build (ubuntu-latest, 3.13) (push) Has been cancelled
Pyrofork / build (ubuntu-latest, 3.9) (push) Has been cancelled

Signed-off-by: wulan17 <wulan17@komodos.id>
This commit is contained in:
Hitalo M. 2025-04-17 13:49:26 -03:00 committed by wulan17
parent ff05420048
commit 5a345f99d1
No known key found for this signature in database
GPG key ID: 737814D4B5FF0420
2 changed files with 44 additions and 2 deletions

View file

@ -54,6 +54,14 @@ class TCP:
self.lock = asyncio.Lock() self.lock = asyncio.Lock()
self.loop = asyncio.get_event_loop() self.loop = asyncio.get_event_loop()
self._closed = True
@property
def closed(self) -> bool:
return (
self._closed or self.writer is None or self.writer.is_closing() or self.reader is None
)
async def _connect_via_proxy( async def _connect_via_proxy(
self, self,
@ -123,11 +131,14 @@ class TCP:
async def connect(self, address: Tuple[str, int]) -> None: async def connect(self, address: Tuple[str, int]) -> None:
try: try:
await asyncio.wait_for(self._connect(address), TCP.TIMEOUT) await asyncio.wait_for(self._connect(address), TCP.TIMEOUT)
self._closed = False
except asyncio.TimeoutError: # Re-raise as TimeoutError. asyncio.TimeoutError is deprecated in 3.11 except asyncio.TimeoutError: # Re-raise as TimeoutError. asyncio.TimeoutError is deprecated in 3.11
self._closed = True
raise TimeoutError("Connection timed out") raise TimeoutError("Connection timed out")
async def close(self) -> None: async def close(self) -> None:
if self.writer is None: if self.writer is None:
self._closed = True
return None return None
try: try:
@ -135,10 +146,12 @@ class TCP:
await asyncio.wait_for(self.writer.wait_closed(), TCP.TIMEOUT) await asyncio.wait_for(self.writer.wait_closed(), TCP.TIMEOUT)
except Exception as e: except Exception as e:
log.info("Close exception: %s %s", type(e).__name__, e) log.info("Close exception: %s %s", type(e).__name__, e)
finally:
self._closed = True
async def send(self, data: bytes) -> None: async def send(self, data: bytes) -> None:
if self.writer is None: if self.writer is None or self._closed:
return None raise OSError("Connection is closed")
async with self.lock: async with self.lock:
try: try:
@ -146,9 +159,13 @@ class TCP:
await self.writer.drain() await self.writer.drain()
except Exception as e: except Exception as e:
log.info("Send exception: %s %s", type(e).__name__, e) log.info("Send exception: %s %s", type(e).__name__, e)
self._closed = True
raise OSError(e) raise OSError(e)
async def recv(self, length: int = 0) -> Optional[bytes]: async def recv(self, length: int = 0) -> Optional[bytes]:
if self._closed or self.reader is None:
return None
data = b"" data = b""
while len(data) < length: while len(data) < length:
@ -158,11 +175,13 @@ class TCP:
TCP.TIMEOUT TCP.TIMEOUT
) )
except (OSError, asyncio.TimeoutError): except (OSError, asyncio.TimeoutError):
self._closed = True
return None return None
else: else:
if chunk: if chunk:
data += chunk data += chunk
else: else:
self._closed = True
return None return None
return data return data

View file

@ -417,6 +417,19 @@ class Session:
while True: while True:
try: try:
if (
self.connection is None
or self.connection.protocol is None
or getattr(self.connection.protocol, "closed", True)
):
log.warning(
"[%s] Connection is closed or not established. Attempting to reconnect...",
self.client.name,
)
await self.restart()
await asyncio.sleep(1)
continue
return await self.send(query, timeout=timeout) return await self.send(query, timeout=timeout)
except (FloodWait, FloodPremiumWait) as e: except (FloodWait, FloodPremiumWait) as e:
amount = e.value amount = e.value
@ -438,6 +451,16 @@ class Session:
query_name, str(e) or repr(e) query_name, str(e) or repr(e)
) )
if isinstance(e, OSError) and retries > 1:
try:
await self.restart()
except Exception as restart_error:
log.warning(
"[%s] Failed to restart session: %s",
self.client.name,
str(restart_error) or repr(restart_error),
)
await asyncio.sleep(0.5) await asyncio.sleep(0.5)
return await self.invoke(query, retries - 1, timeout) return await self.invoke(query, retries - 1, timeout)