From 1a4c578380962e7afe5fc39e0a5d0ae73221e663 Mon Sep 17 00:00:00 2001 From: KurimuzonAkuma Date: Wed, 5 Mar 2025 09:17:55 +0300 Subject: [PATCH] Refactor loop Signed-off-by: wulan17 --- pyrogram/client.py | 36 ++++++++++++------- pyrogram/connection/connection.py | 13 +++++-- pyrogram/connection/transport/tcp/tcp.py | 11 ++++-- .../connection/transport/tcp/tcp_abridged.py | 5 +-- .../transport/tcp/tcp_abridged_o.py | 5 +-- pyrogram/connection/transport/tcp/tcp_full.py | 5 +-- .../transport/tcp/tcp_intermediate.py | 5 +-- .../transport/tcp/tcp_intermediate_o.py | 5 +-- pyrogram/dispatcher.py | 9 +++-- pyrogram/methods/messages/download_media.py | 2 +- pyrogram/session/auth.py | 4 ++- pyrogram/session/session.py | 23 ++++++------ pyrogram/utils.py | 14 ++++++-- 13 files changed, 89 insertions(+), 48 deletions(-) diff --git a/pyrogram/client.py b/pyrogram/client.py index 1e8fef35..1e53e8f6 100644 --- a/pyrogram/client.py +++ b/pyrogram/client.py @@ -33,7 +33,7 @@ from importlib import import_module from io import StringIO, BytesIO from mimetypes import MimeTypes from pathlib import Path -from typing import Union, List, Optional, Callable, AsyncGenerator, Tuple +from typing import Union, List, Optional, Callable, AsyncGenerator, Type, Tuple import pyrogram from pyrogram import __version__, __license__ @@ -55,7 +55,7 @@ from pyrogram.storage import FileStorage, MemoryStorage, Storage from pyrogram.types import User from pyrogram.utils import ainput from .connection import Connection -from .connection.transport import TCPAbridged +from .connection.transport import TCP, TCPAbridged from .dispatcher import Dispatcher from .file_id import FileId, FileType, ThumbnailSource from .mime_types import mime_types @@ -220,6 +220,9 @@ class Client(Methods): client_platform (:obj:`~pyrogram.enums.ClientPlatform`, *optional*): The platform where this client is running. Defaults to 'other' + + loop (:py:class:`asyncio.AbstractEventLoop`, *optional*): + Event loop. """ APP_VERSION = f"Pyrogram {__version__}" @@ -277,7 +280,9 @@ class Client(Methods): max_concurrent_transmissions: int = MAX_CONCURRENT_TRANSMISSIONS, client_platform: "enums.ClientPlatform" = enums.ClientPlatform.OTHER, max_message_cache_size: int = MAX_CACHE_SIZE, - max_business_user_connection_cache_size: int = MAX_CACHE_SIZE + max_business_user_connection_cache_size: int = MAX_CACHE_SIZE, + protocol_factory: Type[TCP] = TCPAbridged, + loop: Optional[asyncio.AbstractEventLoop] = None ): super().__init__() @@ -371,7 +376,14 @@ class Client(Methods): self.updates_watchdog_event = asyncio.Event() self.last_update_time = datetime.now() self.listeners = {listener_type: [] for listener_type in pyrogram.enums.ListenerTypes} - self.loop = asyncio.get_event_loop() + + if isinstance(loop, asyncio.AbstractEventLoop): + self.loop = loop + else: + try: + self.loop = asyncio.get_running_loop() + except RuntimeError: + self.loop = asyncio.new_event_loop() def __enter__(self): return self.start() @@ -425,7 +437,7 @@ class Client(Methods): if not self.phone_number: while True: print("Enter 'qrcode' if you want to login with qrcode.") - value = await ainput("Enter phone number or bot token: ") + value = await ainput("Enter phone number or bot token: ", loop=self.loop) if not value: continue @@ -434,7 +446,7 @@ class Client(Methods): self.use_qrcode = True break - confirm = (await ainput(f'Is "{value}" correct? (y/N): ')).lower() + confirm = (await ainput(f'Is "{value}" correct? (y/N): ', loop=self.loop)).lower() if confirm == "y": break @@ -466,7 +478,7 @@ class Client(Methods): while True: if not self.use_qrcode and not self.phone_code: - self.phone_code = await ainput("Enter confirmation code: ") + self.phone_code = await ainput("Enter confirmation code: ", loop=self.loop) try: if self.use_qrcode: @@ -483,18 +495,18 @@ class Client(Methods): print("Password hint: {}".format(await self.get_password_hint())) if not self.password: - self.password = await ainput("Enter password (empty to recover): ", hide=self.hide_password) + self.password = await ainput("Enter password (empty to recover): ", hide=self.hide_password, loop=self.loop) try: if not self.password: - confirm = await ainput("Confirm password recovery (y/n): ") + confirm = await ainput("Confirm password recovery (y/n): ", loop=self.loop) if confirm == "y": email_pattern = await self.send_recovery_code() print(f"The recovery code has been sent to {email_pattern}") while True: - recovery_code = await ainput("Enter recovery code: ") + recovery_code = await ainput("Enter recovery code: ", loop=self.loop) try: return await self.recover_password(recovery_code) @@ -842,13 +854,13 @@ class Client(Methods): else: while True: try: - value = int(await ainput("Enter the api_id part of the API key: ")) + value = int(await ainput("Enter the api_id part of the API key: ", loop=self.loop)) if value <= 0: print("Invalid value") continue - confirm = (await ainput(f'Is "{value}" correct? (y/N): ')).lower() + confirm = (await ainput(f'Is "{value}" correct? (y/N): ', loop=self.loop)).lower() if confirm == "y": await self.storage.api_id(value) diff --git a/pyrogram/connection/connection.py b/pyrogram/connection/connection.py index 1f100acc..a7f6acdb 100644 --- a/pyrogram/connection/connection.py +++ b/pyrogram/connection/connection.py @@ -38,7 +38,8 @@ class Connection: alt_port: bool, proxy: dict, media: bool = False, - protocol_factory: Type[TCP] = TCPAbridged + protocol_factory: Type[TCP] = TCPAbridged, + loop: Optional[asyncio.AbstractEventLoop] = None ) -> None: self.dc_id = dc_id self.test_mode = test_mode @@ -51,9 +52,17 @@ class Connection: self.address = DataCenter(dc_id, test_mode, ipv6, alt_port, media) self.protocol: Optional[TCP] = None + if isinstance(loop, asyncio.AbstractEventLoop): + self.loop = loop + else: + try: + self.loop = asyncio.get_running_loop() + except RuntimeError: + self.loop = asyncio.new_event_loop() + async def connect(self) -> None: for _ in range(Connection.MAX_CONNECTION_ATTEMPTS): - self.protocol = self.protocol_factory(ipv6=self.ipv6, proxy=self.proxy) + self.protocol = self.protocol_factory(ipv6=self.ipv6, proxy=self.proxy, loop=self.loop) try: log.info("Connecting...") diff --git a/pyrogram/connection/transport/tcp/tcp.py b/pyrogram/connection/transport/tcp/tcp.py index 3e403b33..fcfbe6ad 100644 --- a/pyrogram/connection/transport/tcp/tcp.py +++ b/pyrogram/connection/transport/tcp/tcp.py @@ -46,7 +46,7 @@ class Proxy(TypedDict): class TCP: TIMEOUT = 10 - def __init__(self, ipv6: bool, proxy: Proxy) -> None: + def __init__(self, ipv6: bool, proxy: Proxy, loop: Optional[asyncio.AbstractEventLoop] = None) -> None: self.ipv6 = ipv6 self.proxy = proxy @@ -54,7 +54,14 @@ class TCP: self.writer: Optional[asyncio.StreamWriter] = None self.lock = asyncio.Lock() - self.loop = asyncio.get_event_loop() + + if isinstance(loop, asyncio.AbstractEventLoop): + self.loop = loop + else: + try: + self.loop = asyncio.get_running_loop() + except RuntimeError: + self.loop = asyncio.new_event_loop() async def _connect_via_proxy( self, diff --git a/pyrogram/connection/transport/tcp/tcp_abridged.py b/pyrogram/connection/transport/tcp/tcp_abridged.py index f23d26fd..4381220d 100644 --- a/pyrogram/connection/transport/tcp/tcp_abridged.py +++ b/pyrogram/connection/transport/tcp/tcp_abridged.py @@ -17,6 +17,7 @@ # You should have received a copy of the GNU Lesser General Public License # along with Pyrofork. If not, see . +import asyncio import logging from typing import Optional, Tuple @@ -26,8 +27,8 @@ log = logging.getLogger(__name__) class TCPAbridged(TCP): - def __init__(self, ipv6: bool, proxy: Proxy) -> None: - super().__init__(ipv6, proxy) + def __init__(self, ipv6: bool, proxy: Proxy, loop: Optional[asyncio.AbstractEventLoop] = None) -> None: + super().__init__(ipv6, proxy, loop) async def connect(self, address: Tuple[str, int]) -> None: await super().connect(address) diff --git a/pyrogram/connection/transport/tcp/tcp_abridged_o.py b/pyrogram/connection/transport/tcp/tcp_abridged_o.py index 0b856728..3bd74725 100644 --- a/pyrogram/connection/transport/tcp/tcp_abridged_o.py +++ b/pyrogram/connection/transport/tcp/tcp_abridged_o.py @@ -17,6 +17,7 @@ # You should have received a copy of the GNU Lesser General Public License # along with Pyrofork. If not, see . +import asyncio import logging import os from typing import Optional, Tuple @@ -31,8 +32,8 @@ log = logging.getLogger(__name__) class TCPAbridgedO(TCP): RESERVED = (b"HEAD", b"POST", b"GET ", b"OPTI", b"\xee" * 4) - def __init__(self, ipv6: bool, proxy: Proxy) -> None: - super().__init__(ipv6, proxy) + def __init__(self, ipv6: bool, proxy: Proxy, loop: Optional[asyncio.AbstractEventLoop] = None) -> None: + super().__init__(ipv6, proxy, loop) self.encrypt = None self.decrypt = None diff --git a/pyrogram/connection/transport/tcp/tcp_full.py b/pyrogram/connection/transport/tcp/tcp_full.py index 894ca0e2..f59c1558 100644 --- a/pyrogram/connection/transport/tcp/tcp_full.py +++ b/pyrogram/connection/transport/tcp/tcp_full.py @@ -17,6 +17,7 @@ # You should have received a copy of the GNU Lesser General Public License # along with Pyrofork. If not, see . +import asyncio import logging from binascii import crc32 from struct import pack, unpack @@ -28,8 +29,8 @@ log = logging.getLogger(__name__) class TCPFull(TCP): - def __init__(self, ipv6: bool, proxy: Proxy) -> None: - super().__init__(ipv6, proxy) + def __init__(self, ipv6: bool, proxy: Proxy, loop: Optional[asyncio.AbstractEventLoop] = None) -> None: + super().__init__(ipv6, proxy, loop) self.seq_no: Optional[int] = None diff --git a/pyrogram/connection/transport/tcp/tcp_intermediate.py b/pyrogram/connection/transport/tcp/tcp_intermediate.py index ed1c1e57..aa3757e1 100644 --- a/pyrogram/connection/transport/tcp/tcp_intermediate.py +++ b/pyrogram/connection/transport/tcp/tcp_intermediate.py @@ -17,6 +17,7 @@ # You should have received a copy of the GNU Lesser General Public License # along with Pyrofork. If not, see . +import asyncio import logging from struct import pack, unpack from typing import Optional, Tuple @@ -27,8 +28,8 @@ log = logging.getLogger(__name__) class TCPIntermediate(TCP): - def __init__(self, ipv6: bool, proxy: Proxy) -> None: - super().__init__(ipv6, proxy) + def __init__(self, ipv6: bool, proxy: Proxy, loop: Optional[asyncio.AbstractEventLoop] = None) -> None: + super().__init__(ipv6, proxy, loop) async def connect(self, address: Tuple[str, int]) -> None: await super().connect(address) diff --git a/pyrogram/connection/transport/tcp/tcp_intermediate_o.py b/pyrogram/connection/transport/tcp/tcp_intermediate_o.py index 61131b33..8a186ef4 100644 --- a/pyrogram/connection/transport/tcp/tcp_intermediate_o.py +++ b/pyrogram/connection/transport/tcp/tcp_intermediate_o.py @@ -17,6 +17,7 @@ # You should have received a copy of the GNU Lesser General Public License # along with Pyrofork. If not, see . +import asyncio import logging import os from struct import pack, unpack @@ -31,8 +32,8 @@ log = logging.getLogger(__name__) class TCPIntermediateO(TCP): RESERVED = (b"HEAD", b"POST", b"GET ", b"OPTI", b"\xee" * 4) - def __init__(self, ipv6: bool, proxy: Proxy) -> None: - super().__init__(ipv6, proxy) + def __init__(self, ipv6: bool, proxy: Proxy, loop: Optional[asyncio.AbstractEventLoop] = None) -> None: + super().__init__(ipv6, proxy, loop) self.encrypt = None self.decrypt = None diff --git a/pyrogram/dispatcher.py b/pyrogram/dispatcher.py index c5f166d5..7d201d60 100644 --- a/pyrogram/dispatcher.py +++ b/pyrogram/dispatcher.py @@ -95,7 +95,6 @@ class Dispatcher: def __init__(self, client: "pyrogram.Client"): self.client = client - self.loop = asyncio.get_event_loop() self.handler_worker_tasks = [] self.locks_list = [] @@ -271,7 +270,7 @@ class Dispatcher: self.locks_list.append(asyncio.Lock()) self.handler_worker_tasks.append( - self.loop.create_task(self.handler_worker(self.locks_list[-1])) + self.client.loop.create_task(self.handler_worker(self.locks_list[-1])) ) log.info("Started %s HandlerTasks", self.client.workers) @@ -311,7 +310,7 @@ class Dispatcher: for lock in self.locks_list: lock.release() - self.loop.create_task(fn()) + self.client.loop.create_task(fn()) def remove_handler(self, handler, group: int): async def fn(): @@ -333,7 +332,7 @@ class Dispatcher: for lock in self.locks_list: lock.release() - self.loop.create_task(fn()) + self.client.loop.create_task(fn()) async def handler_worker(self, lock: asyncio.Lock): while True: @@ -404,7 +403,7 @@ class Dispatcher: if inspect.iscoroutinefunction(handler.callback): await handler.callback(self.client, *args) else: - await self.loop.run_in_executor( + await self.client.loop.run_in_executor( self.client.executor, handler.callback, self.client, diff --git a/pyrogram/methods/messages/download_media.py b/pyrogram/methods/messages/download_media.py index 6ba62526..2099e677 100644 --- a/pyrogram/methods/messages/download_media.py +++ b/pyrogram/methods/messages/download_media.py @@ -185,4 +185,4 @@ class DownloadMedia: if block: return await downloader else: - asyncio.get_event_loop().create_task(downloader) + self.loop.create_task(downloader) diff --git a/pyrogram/session/auth.py b/pyrogram/session/auth.py index 346b0f92..23b6e8d7 100644 --- a/pyrogram/session/auth.py +++ b/pyrogram/session/auth.py @@ -52,6 +52,7 @@ class Auth: self.proxy = client.proxy self.connection_factory = client.connection_factory self.protocol_factory = client.protocol_factory + self.loop = client.loop self.connection: Optional[Connection] = None @@ -93,7 +94,8 @@ class Auth: alt_port=self.alt_port, proxy=self.proxy, media=False, - protocol_factory=self.protocol_factory + protocol_factory=self.protocol_factory, + loop=self.loop ) try: diff --git a/pyrogram/session/session.py b/pyrogram/session/session.py index 79801086..3424b42e 100644 --- a/pyrogram/session/session.py +++ b/pyrogram/session/session.py @@ -104,8 +104,6 @@ class Session: self.is_started = asyncio.Event() - self.loop = asyncio.get_event_loop() - self.last_reconnect_attempt = None async def start(self): @@ -117,13 +115,14 @@ class Session: alt_port=self.client.alt_port, proxy=self.client.proxy, media=self.is_media, - protocol_factory=self.client.protocol_factory + protocol_factory=self.client.protocol_factory, + loop=self.client.loop ) try: await self.connection.connect() - self.recv_task = self.loop.create_task(self.recv_worker()) + self.recv_task = self.client.loop.create_task(self.recv_worker()) await self.send(raw.functions.Ping(ping_id=0), timeout=self.START_TIMEOUT) @@ -145,7 +144,7 @@ class Session: timeout=self.START_TIMEOUT ) - self.ping_task = self.loop.create_task(self.ping_worker()) + self.ping_task = self.client.loop.create_task(self.ping_worker()) log.info("Session initialized: Layer %s", layer) log.info("Device: %s - %s", self.client.device_model, self.client.app_version) @@ -207,7 +206,7 @@ class Session: async def handle_packet(self, packet): try: - data = await self.loop.run_in_executor( + data = await self.client.loop.run_in_executor( pyrogram.crypto_executor, mtproto.unpack, BytesIO(packet), @@ -217,7 +216,7 @@ class Session: ) except ValueError as e: log.debug(e) - self.loop.create_task(self.restart()) + self.client.loop.create_task(self.restart()) return messages = ( @@ -279,7 +278,7 @@ class Session: msg_id = msg.body.msg_id else: if self.client is not None: - self.loop.create_task(self.client.handle_updates(msg.body)) + self.client.loop.create_task(self.client.handle_updates(msg.body)) if msg_id in self.results: self.results[msg_id].value = getattr(msg.body, "result", msg.body) @@ -313,7 +312,7 @@ class Session: ), False ) except OSError: - self.loop.create_task(self.restart()) + self.client.loop.create_task(self.restart()) break except RPCError: pass @@ -342,11 +341,11 @@ class Session: ) if self.is_started.is_set(): - self.loop.create_task(self.restart()) + self.client.loop.create_task(self.restart()) break - self.loop.create_task(self.handle_packet(packet)) + self.client.loop.create_task(self.handle_packet(packet)) log.info("NetworkTask stopped") @@ -359,7 +358,7 @@ class Session: log.debug("Sent: %s", message) - payload = await self.loop.run_in_executor( + payload = await self.client.loop.run_in_executor( pyrogram.crypto_executor, mtproto.pack, message, diff --git a/pyrogram/utils.py b/pyrogram/utils.py index 4ad6c5c2..fca5dbb6 100644 --- a/pyrogram/utils.py +++ b/pyrogram/utils.py @@ -43,13 +43,21 @@ PyromodConfig = SimpleNamespace( unallowed_click_alert=True, unallowed_click_alert_text=("[pyromod] You're not expected to click this button."), ) - -async def ainput(prompt: str = "", *, hide: bool = False): + +async def ainput(prompt: str = "", *, hide: bool = False, loop: Optional[asyncio.AbstractEventLoop] = None): """Just like the built-in input, but async""" + if isinstance(loop, asyncio.AbstractEventLoop): + loop = loop + else: + try: + loop = asyncio.get_running_loop() + except RuntimeError: + loop = asyncio.new_event_loop() + with ThreadPoolExecutor(1) as executor: func = functools.partial(getpass if hide else input, prompt) - return await asyncio.get_event_loop().run_in_executor(executor, func) + return await loop.run_in_executor(executor, func) def get_input_media_from_file_id(