mirror of
https://github.com/Mayuri-Chan/pyrofork.git
synced 2026-01-09 08:14:50 +00:00
Compare commits
2 commits
632921b4b2
...
a5cd3b92a6
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
a5cd3b92a6 | ||
|
|
ce356e02f5 |
3 changed files with 12 additions and 47 deletions
|
|
@ -54,14 +54,6 @@ class TCP:
|
|||
|
||||
self.lock = asyncio.Lock()
|
||||
self.loop = asyncio.get_event_loop()
|
||||
self._closed = True
|
||||
|
||||
@property
|
||||
def closed(self) -> bool:
|
||||
return (
|
||||
self._closed or self.writer is None or self.writer.is_closing() or self.reader is None
|
||||
)
|
||||
|
||||
|
||||
async def _connect_via_proxy(
|
||||
self,
|
||||
|
|
@ -131,14 +123,11 @@ class TCP:
|
|||
async def connect(self, address: Tuple[str, int]) -> None:
|
||||
try:
|
||||
await asyncio.wait_for(self._connect(address), TCP.TIMEOUT)
|
||||
self._closed = False
|
||||
except asyncio.TimeoutError: # Re-raise as TimeoutError. asyncio.TimeoutError is deprecated in 3.11
|
||||
self._closed = True
|
||||
raise TimeoutError("Connection timed out")
|
||||
|
||||
async def close(self) -> None:
|
||||
if self.writer is None:
|
||||
self._closed = True
|
||||
return None
|
||||
|
||||
try:
|
||||
|
|
@ -146,12 +135,10 @@ class TCP:
|
|||
await asyncio.wait_for(self.writer.wait_closed(), TCP.TIMEOUT)
|
||||
except Exception as e:
|
||||
log.info("Close exception: %s %s", type(e).__name__, e)
|
||||
finally:
|
||||
self._closed = True
|
||||
|
||||
async def send(self, data: bytes) -> None:
|
||||
if self.writer is None or self._closed:
|
||||
raise OSError("Connection is closed")
|
||||
if self.writer is None:
|
||||
return None
|
||||
|
||||
async with self.lock:
|
||||
try:
|
||||
|
|
@ -159,13 +146,9 @@ class TCP:
|
|||
await self.writer.drain()
|
||||
except Exception as e:
|
||||
log.info("Send exception: %s %s", type(e).__name__, e)
|
||||
self._closed = True
|
||||
raise OSError(e)
|
||||
|
||||
async def recv(self, length: int = 0) -> Optional[bytes]:
|
||||
if self._closed or self.reader is None:
|
||||
return None
|
||||
|
||||
data = b""
|
||||
|
||||
while len(data) < length:
|
||||
|
|
@ -175,13 +158,11 @@ class TCP:
|
|||
TCP.TIMEOUT
|
||||
)
|
||||
except (OSError, asyncio.TimeoutError):
|
||||
self._closed = True
|
||||
return None
|
||||
else:
|
||||
if chunk:
|
||||
data += chunk
|
||||
else:
|
||||
self._closed = True
|
||||
return None
|
||||
|
||||
return data
|
||||
|
|
|
|||
|
|
@ -24,7 +24,11 @@ from pyrogram.filters import Filter
|
|||
|
||||
|
||||
class OnError:
|
||||
def on_error(self=None, errors=None) -> Callable:
|
||||
def on_error(
|
||||
self=None,
|
||||
errors=None,
|
||||
group: int = 0,
|
||||
) -> Callable:
|
||||
"""Decorator for handling new errors.
|
||||
|
||||
This does the same thing as :meth:`~pyrogram.Client.add_handler` using the
|
||||
|
|
@ -34,16 +38,19 @@ class OnError:
|
|||
errors (:obj:`~Exception`, *optional*):
|
||||
Pass one or more errors to allow only a subset of errors to be passed
|
||||
in your function.
|
||||
|
||||
group (``int``, *optional*):
|
||||
The group identifier, defaults to 0.
|
||||
"""
|
||||
|
||||
def decorator(func: Callable) -> Callable:
|
||||
if isinstance(self, pyrogram.Client):
|
||||
self.add_handler(pyrogram.handlers.ErrorHandler(func, errors), 0)
|
||||
self.add_handler(pyrogram.handlers.ErrorHandler(func, errors), group)
|
||||
elif isinstance(self, Filter) or self is None:
|
||||
if not hasattr(func, "handlers"):
|
||||
func.handlers = []
|
||||
|
||||
func.handlers.append((pyrogram.handlers.ErrorHandler(func, self), 0))
|
||||
func.handlers.append((pyrogram.handlers.ErrorHandler(func, self), group))
|
||||
|
||||
return func
|
||||
|
||||
|
|
|
|||
|
|
@ -417,19 +417,6 @@ class Session:
|
|||
|
||||
while True:
|
||||
try:
|
||||
if (
|
||||
self.connection is None
|
||||
or self.connection.protocol is None
|
||||
or getattr(self.connection.protocol, "closed", True)
|
||||
):
|
||||
log.warning(
|
||||
"[%s] Connection is closed or not established. Attempting to reconnect...",
|
||||
self.client.name,
|
||||
)
|
||||
await self.restart()
|
||||
await asyncio.sleep(1)
|
||||
continue
|
||||
|
||||
return await self.send(query, timeout=timeout)
|
||||
except (FloodWait, FloodPremiumWait) as e:
|
||||
amount = e.value
|
||||
|
|
@ -451,16 +438,6 @@ class Session:
|
|||
query_name, str(e) or repr(e)
|
||||
)
|
||||
|
||||
if isinstance(e, OSError) and retries > 1:
|
||||
try:
|
||||
await self.restart()
|
||||
except Exception as restart_error:
|
||||
log.warning(
|
||||
"[%s] Failed to restart session: %s",
|
||||
self.client.name,
|
||||
str(restart_error) or repr(restart_error),
|
||||
)
|
||||
|
||||
await asyncio.sleep(0.5)
|
||||
|
||||
return await self.invoke(query, retries - 1, timeout)
|
||||
|
|
|
|||
Loading…
Reference in a new issue