Refactor loop

Signed-off-by: wulan17 <wulan17@komodos.id>
This commit is contained in:
KurimuzonAkuma 2025-03-05 09:17:55 +03:00 committed by wulan17
parent e3a8a781d3
commit 1a4c578380
No known key found for this signature in database
GPG key ID: 737814D4B5FF0420
13 changed files with 89 additions and 48 deletions

View file

@ -33,7 +33,7 @@ from importlib import import_module
from io import StringIO, BytesIO from io import StringIO, BytesIO
from mimetypes import MimeTypes from mimetypes import MimeTypes
from pathlib import Path from pathlib import Path
from typing import Union, List, Optional, Callable, AsyncGenerator, Tuple from typing import Union, List, Optional, Callable, AsyncGenerator, Type, Tuple
import pyrogram import pyrogram
from pyrogram import __version__, __license__ from pyrogram import __version__, __license__
@ -55,7 +55,7 @@ from pyrogram.storage import FileStorage, MemoryStorage, Storage
from pyrogram.types import User from pyrogram.types import User
from pyrogram.utils import ainput from pyrogram.utils import ainput
from .connection import Connection from .connection import Connection
from .connection.transport import TCPAbridged from .connection.transport import TCP, TCPAbridged
from .dispatcher import Dispatcher from .dispatcher import Dispatcher
from .file_id import FileId, FileType, ThumbnailSource from .file_id import FileId, FileType, ThumbnailSource
from .mime_types import mime_types from .mime_types import mime_types
@ -220,6 +220,9 @@ class Client(Methods):
client_platform (:obj:`~pyrogram.enums.ClientPlatform`, *optional*): client_platform (:obj:`~pyrogram.enums.ClientPlatform`, *optional*):
The platform where this client is running. The platform where this client is running.
Defaults to 'other' Defaults to 'other'
loop (:py:class:`asyncio.AbstractEventLoop`, *optional*):
Event loop.
""" """
APP_VERSION = f"Pyrogram {__version__}" APP_VERSION = f"Pyrogram {__version__}"
@ -277,7 +280,9 @@ class Client(Methods):
max_concurrent_transmissions: int = MAX_CONCURRENT_TRANSMISSIONS, max_concurrent_transmissions: int = MAX_CONCURRENT_TRANSMISSIONS,
client_platform: "enums.ClientPlatform" = enums.ClientPlatform.OTHER, client_platform: "enums.ClientPlatform" = enums.ClientPlatform.OTHER,
max_message_cache_size: int = MAX_CACHE_SIZE, max_message_cache_size: int = MAX_CACHE_SIZE,
max_business_user_connection_cache_size: int = MAX_CACHE_SIZE max_business_user_connection_cache_size: int = MAX_CACHE_SIZE,
protocol_factory: Type[TCP] = TCPAbridged,
loop: Optional[asyncio.AbstractEventLoop] = None
): ):
super().__init__() super().__init__()
@ -371,7 +376,14 @@ class Client(Methods):
self.updates_watchdog_event = asyncio.Event() self.updates_watchdog_event = asyncio.Event()
self.last_update_time = datetime.now() self.last_update_time = datetime.now()
self.listeners = {listener_type: [] for listener_type in pyrogram.enums.ListenerTypes} self.listeners = {listener_type: [] for listener_type in pyrogram.enums.ListenerTypes}
self.loop = asyncio.get_event_loop()
if isinstance(loop, asyncio.AbstractEventLoop):
self.loop = loop
else:
try:
self.loop = asyncio.get_running_loop()
except RuntimeError:
self.loop = asyncio.new_event_loop()
def __enter__(self): def __enter__(self):
return self.start() return self.start()
@ -425,7 +437,7 @@ class Client(Methods):
if not self.phone_number: if not self.phone_number:
while True: while True:
print("Enter 'qrcode' if you want to login with qrcode.") print("Enter 'qrcode' if you want to login with qrcode.")
value = await ainput("Enter phone number or bot token: ") value = await ainput("Enter phone number or bot token: ", loop=self.loop)
if not value: if not value:
continue continue
@ -434,7 +446,7 @@ class Client(Methods):
self.use_qrcode = True self.use_qrcode = True
break break
confirm = (await ainput(f'Is "{value}" correct? (y/N): ')).lower() confirm = (await ainput(f'Is "{value}" correct? (y/N): ', loop=self.loop)).lower()
if confirm == "y": if confirm == "y":
break break
@ -466,7 +478,7 @@ class Client(Methods):
while True: while True:
if not self.use_qrcode and not self.phone_code: if not self.use_qrcode and not self.phone_code:
self.phone_code = await ainput("Enter confirmation code: ") self.phone_code = await ainput("Enter confirmation code: ", loop=self.loop)
try: try:
if self.use_qrcode: if self.use_qrcode:
@ -483,18 +495,18 @@ class Client(Methods):
print("Password hint: {}".format(await self.get_password_hint())) print("Password hint: {}".format(await self.get_password_hint()))
if not self.password: if not self.password:
self.password = await ainput("Enter password (empty to recover): ", hide=self.hide_password) self.password = await ainput("Enter password (empty to recover): ", hide=self.hide_password, loop=self.loop)
try: try:
if not self.password: if not self.password:
confirm = await ainput("Confirm password recovery (y/n): ") confirm = await ainput("Confirm password recovery (y/n): ", loop=self.loop)
if confirm == "y": if confirm == "y":
email_pattern = await self.send_recovery_code() email_pattern = await self.send_recovery_code()
print(f"The recovery code has been sent to {email_pattern}") print(f"The recovery code has been sent to {email_pattern}")
while True: while True:
recovery_code = await ainput("Enter recovery code: ") recovery_code = await ainput("Enter recovery code: ", loop=self.loop)
try: try:
return await self.recover_password(recovery_code) return await self.recover_password(recovery_code)
@ -842,13 +854,13 @@ class Client(Methods):
else: else:
while True: while True:
try: try:
value = int(await ainput("Enter the api_id part of the API key: ")) value = int(await ainput("Enter the api_id part of the API key: ", loop=self.loop))
if value <= 0: if value <= 0:
print("Invalid value") print("Invalid value")
continue continue
confirm = (await ainput(f'Is "{value}" correct? (y/N): ')).lower() confirm = (await ainput(f'Is "{value}" correct? (y/N): ', loop=self.loop)).lower()
if confirm == "y": if confirm == "y":
await self.storage.api_id(value) await self.storage.api_id(value)

View file

@ -38,7 +38,8 @@ class Connection:
alt_port: bool, alt_port: bool,
proxy: dict, proxy: dict,
media: bool = False, media: bool = False,
protocol_factory: Type[TCP] = TCPAbridged protocol_factory: Type[TCP] = TCPAbridged,
loop: Optional[asyncio.AbstractEventLoop] = None
) -> None: ) -> None:
self.dc_id = dc_id self.dc_id = dc_id
self.test_mode = test_mode self.test_mode = test_mode
@ -51,9 +52,17 @@ class Connection:
self.address = DataCenter(dc_id, test_mode, ipv6, alt_port, media) self.address = DataCenter(dc_id, test_mode, ipv6, alt_port, media)
self.protocol: Optional[TCP] = None self.protocol: Optional[TCP] = None
if isinstance(loop, asyncio.AbstractEventLoop):
self.loop = loop
else:
try:
self.loop = asyncio.get_running_loop()
except RuntimeError:
self.loop = asyncio.new_event_loop()
async def connect(self) -> None: async def connect(self) -> None:
for _ in range(Connection.MAX_CONNECTION_ATTEMPTS): for _ in range(Connection.MAX_CONNECTION_ATTEMPTS):
self.protocol = self.protocol_factory(ipv6=self.ipv6, proxy=self.proxy) self.protocol = self.protocol_factory(ipv6=self.ipv6, proxy=self.proxy, loop=self.loop)
try: try:
log.info("Connecting...") log.info("Connecting...")

View file

@ -46,7 +46,7 @@ class Proxy(TypedDict):
class TCP: class TCP:
TIMEOUT = 10 TIMEOUT = 10
def __init__(self, ipv6: bool, proxy: Proxy) -> None: def __init__(self, ipv6: bool, proxy: Proxy, loop: Optional[asyncio.AbstractEventLoop] = None) -> None:
self.ipv6 = ipv6 self.ipv6 = ipv6
self.proxy = proxy self.proxy = proxy
@ -54,7 +54,14 @@ class TCP:
self.writer: Optional[asyncio.StreamWriter] = None self.writer: Optional[asyncio.StreamWriter] = None
self.lock = asyncio.Lock() self.lock = asyncio.Lock()
self.loop = asyncio.get_event_loop()
if isinstance(loop, asyncio.AbstractEventLoop):
self.loop = loop
else:
try:
self.loop = asyncio.get_running_loop()
except RuntimeError:
self.loop = asyncio.new_event_loop()
async def _connect_via_proxy( async def _connect_via_proxy(
self, self,

View file

@ -17,6 +17,7 @@
# You should have received a copy of the GNU Lesser General Public License # You should have received a copy of the GNU Lesser General Public License
# along with Pyrofork. If not, see <http://www.gnu.org/licenses/>. # along with Pyrofork. If not, see <http://www.gnu.org/licenses/>.
import asyncio
import logging import logging
from typing import Optional, Tuple from typing import Optional, Tuple
@ -26,8 +27,8 @@ log = logging.getLogger(__name__)
class TCPAbridged(TCP): class TCPAbridged(TCP):
def __init__(self, ipv6: bool, proxy: Proxy) -> None: def __init__(self, ipv6: bool, proxy: Proxy, loop: Optional[asyncio.AbstractEventLoop] = None) -> None:
super().__init__(ipv6, proxy) super().__init__(ipv6, proxy, loop)
async def connect(self, address: Tuple[str, int]) -> None: async def connect(self, address: Tuple[str, int]) -> None:
await super().connect(address) await super().connect(address)

View file

@ -17,6 +17,7 @@
# You should have received a copy of the GNU Lesser General Public License # You should have received a copy of the GNU Lesser General Public License
# along with Pyrofork. If not, see <http://www.gnu.org/licenses/>. # along with Pyrofork. If not, see <http://www.gnu.org/licenses/>.
import asyncio
import logging import logging
import os import os
from typing import Optional, Tuple from typing import Optional, Tuple
@ -31,8 +32,8 @@ log = logging.getLogger(__name__)
class TCPAbridgedO(TCP): class TCPAbridgedO(TCP):
RESERVED = (b"HEAD", b"POST", b"GET ", b"OPTI", b"\xee" * 4) RESERVED = (b"HEAD", b"POST", b"GET ", b"OPTI", b"\xee" * 4)
def __init__(self, ipv6: bool, proxy: Proxy) -> None: def __init__(self, ipv6: bool, proxy: Proxy, loop: Optional[asyncio.AbstractEventLoop] = None) -> None:
super().__init__(ipv6, proxy) super().__init__(ipv6, proxy, loop)
self.encrypt = None self.encrypt = None
self.decrypt = None self.decrypt = None

View file

@ -17,6 +17,7 @@
# You should have received a copy of the GNU Lesser General Public License # You should have received a copy of the GNU Lesser General Public License
# along with Pyrofork. If not, see <http://www.gnu.org/licenses/>. # along with Pyrofork. If not, see <http://www.gnu.org/licenses/>.
import asyncio
import logging import logging
from binascii import crc32 from binascii import crc32
from struct import pack, unpack from struct import pack, unpack
@ -28,8 +29,8 @@ log = logging.getLogger(__name__)
class TCPFull(TCP): class TCPFull(TCP):
def __init__(self, ipv6: bool, proxy: Proxy) -> None: def __init__(self, ipv6: bool, proxy: Proxy, loop: Optional[asyncio.AbstractEventLoop] = None) -> None:
super().__init__(ipv6, proxy) super().__init__(ipv6, proxy, loop)
self.seq_no: Optional[int] = None self.seq_no: Optional[int] = None

View file

@ -17,6 +17,7 @@
# You should have received a copy of the GNU Lesser General Public License # You should have received a copy of the GNU Lesser General Public License
# along with Pyrofork. If not, see <http://www.gnu.org/licenses/>. # along with Pyrofork. If not, see <http://www.gnu.org/licenses/>.
import asyncio
import logging import logging
from struct import pack, unpack from struct import pack, unpack
from typing import Optional, Tuple from typing import Optional, Tuple
@ -27,8 +28,8 @@ log = logging.getLogger(__name__)
class TCPIntermediate(TCP): class TCPIntermediate(TCP):
def __init__(self, ipv6: bool, proxy: Proxy) -> None: def __init__(self, ipv6: bool, proxy: Proxy, loop: Optional[asyncio.AbstractEventLoop] = None) -> None:
super().__init__(ipv6, proxy) super().__init__(ipv6, proxy, loop)
async def connect(self, address: Tuple[str, int]) -> None: async def connect(self, address: Tuple[str, int]) -> None:
await super().connect(address) await super().connect(address)

View file

@ -17,6 +17,7 @@
# You should have received a copy of the GNU Lesser General Public License # You should have received a copy of the GNU Lesser General Public License
# along with Pyrofork. If not, see <http://www.gnu.org/licenses/>. # along with Pyrofork. If not, see <http://www.gnu.org/licenses/>.
import asyncio
import logging import logging
import os import os
from struct import pack, unpack from struct import pack, unpack
@ -31,8 +32,8 @@ log = logging.getLogger(__name__)
class TCPIntermediateO(TCP): class TCPIntermediateO(TCP):
RESERVED = (b"HEAD", b"POST", b"GET ", b"OPTI", b"\xee" * 4) RESERVED = (b"HEAD", b"POST", b"GET ", b"OPTI", b"\xee" * 4)
def __init__(self, ipv6: bool, proxy: Proxy) -> None: def __init__(self, ipv6: bool, proxy: Proxy, loop: Optional[asyncio.AbstractEventLoop] = None) -> None:
super().__init__(ipv6, proxy) super().__init__(ipv6, proxy, loop)
self.encrypt = None self.encrypt = None
self.decrypt = None self.decrypt = None

View file

@ -95,7 +95,6 @@ class Dispatcher:
def __init__(self, client: "pyrogram.Client"): def __init__(self, client: "pyrogram.Client"):
self.client = client self.client = client
self.loop = asyncio.get_event_loop()
self.handler_worker_tasks = [] self.handler_worker_tasks = []
self.locks_list = [] self.locks_list = []
@ -271,7 +270,7 @@ class Dispatcher:
self.locks_list.append(asyncio.Lock()) self.locks_list.append(asyncio.Lock())
self.handler_worker_tasks.append( self.handler_worker_tasks.append(
self.loop.create_task(self.handler_worker(self.locks_list[-1])) self.client.loop.create_task(self.handler_worker(self.locks_list[-1]))
) )
log.info("Started %s HandlerTasks", self.client.workers) log.info("Started %s HandlerTasks", self.client.workers)
@ -311,7 +310,7 @@ class Dispatcher:
for lock in self.locks_list: for lock in self.locks_list:
lock.release() lock.release()
self.loop.create_task(fn()) self.client.loop.create_task(fn())
def remove_handler(self, handler, group: int): def remove_handler(self, handler, group: int):
async def fn(): async def fn():
@ -333,7 +332,7 @@ class Dispatcher:
for lock in self.locks_list: for lock in self.locks_list:
lock.release() lock.release()
self.loop.create_task(fn()) self.client.loop.create_task(fn())
async def handler_worker(self, lock: asyncio.Lock): async def handler_worker(self, lock: asyncio.Lock):
while True: while True:
@ -404,7 +403,7 @@ class Dispatcher:
if inspect.iscoroutinefunction(handler.callback): if inspect.iscoroutinefunction(handler.callback):
await handler.callback(self.client, *args) await handler.callback(self.client, *args)
else: else:
await self.loop.run_in_executor( await self.client.loop.run_in_executor(
self.client.executor, self.client.executor,
handler.callback, handler.callback,
self.client, self.client,

View file

@ -185,4 +185,4 @@ class DownloadMedia:
if block: if block:
return await downloader return await downloader
else: else:
asyncio.get_event_loop().create_task(downloader) self.loop.create_task(downloader)

View file

@ -52,6 +52,7 @@ class Auth:
self.proxy = client.proxy self.proxy = client.proxy
self.connection_factory = client.connection_factory self.connection_factory = client.connection_factory
self.protocol_factory = client.protocol_factory self.protocol_factory = client.protocol_factory
self.loop = client.loop
self.connection: Optional[Connection] = None self.connection: Optional[Connection] = None
@ -93,7 +94,8 @@ class Auth:
alt_port=self.alt_port, alt_port=self.alt_port,
proxy=self.proxy, proxy=self.proxy,
media=False, media=False,
protocol_factory=self.protocol_factory protocol_factory=self.protocol_factory,
loop=self.loop
) )
try: try:

View file

@ -104,8 +104,6 @@ class Session:
self.is_started = asyncio.Event() self.is_started = asyncio.Event()
self.loop = asyncio.get_event_loop()
self.last_reconnect_attempt = None self.last_reconnect_attempt = None
async def start(self): async def start(self):
@ -117,13 +115,14 @@ class Session:
alt_port=self.client.alt_port, alt_port=self.client.alt_port,
proxy=self.client.proxy, proxy=self.client.proxy,
media=self.is_media, media=self.is_media,
protocol_factory=self.client.protocol_factory protocol_factory=self.client.protocol_factory,
loop=self.client.loop
) )
try: try:
await self.connection.connect() await self.connection.connect()
self.recv_task = self.loop.create_task(self.recv_worker()) self.recv_task = self.client.loop.create_task(self.recv_worker())
await self.send(raw.functions.Ping(ping_id=0), timeout=self.START_TIMEOUT) await self.send(raw.functions.Ping(ping_id=0), timeout=self.START_TIMEOUT)
@ -145,7 +144,7 @@ class Session:
timeout=self.START_TIMEOUT timeout=self.START_TIMEOUT
) )
self.ping_task = self.loop.create_task(self.ping_worker()) self.ping_task = self.client.loop.create_task(self.ping_worker())
log.info("Session initialized: Layer %s", layer) log.info("Session initialized: Layer %s", layer)
log.info("Device: %s - %s", self.client.device_model, self.client.app_version) log.info("Device: %s - %s", self.client.device_model, self.client.app_version)
@ -207,7 +206,7 @@ class Session:
async def handle_packet(self, packet): async def handle_packet(self, packet):
try: try:
data = await self.loop.run_in_executor( data = await self.client.loop.run_in_executor(
pyrogram.crypto_executor, pyrogram.crypto_executor,
mtproto.unpack, mtproto.unpack,
BytesIO(packet), BytesIO(packet),
@ -217,7 +216,7 @@ class Session:
) )
except ValueError as e: except ValueError as e:
log.debug(e) log.debug(e)
self.loop.create_task(self.restart()) self.client.loop.create_task(self.restart())
return return
messages = ( messages = (
@ -279,7 +278,7 @@ class Session:
msg_id = msg.body.msg_id msg_id = msg.body.msg_id
else: else:
if self.client is not None: if self.client is not None:
self.loop.create_task(self.client.handle_updates(msg.body)) self.client.loop.create_task(self.client.handle_updates(msg.body))
if msg_id in self.results: if msg_id in self.results:
self.results[msg_id].value = getattr(msg.body, "result", msg.body) self.results[msg_id].value = getattr(msg.body, "result", msg.body)
@ -313,7 +312,7 @@ class Session:
), False ), False
) )
except OSError: except OSError:
self.loop.create_task(self.restart()) self.client.loop.create_task(self.restart())
break break
except RPCError: except RPCError:
pass pass
@ -342,11 +341,11 @@ class Session:
) )
if self.is_started.is_set(): if self.is_started.is_set():
self.loop.create_task(self.restart()) self.client.loop.create_task(self.restart())
break break
self.loop.create_task(self.handle_packet(packet)) self.client.loop.create_task(self.handle_packet(packet))
log.info("NetworkTask stopped") log.info("NetworkTask stopped")
@ -359,7 +358,7 @@ class Session:
log.debug("Sent: %s", message) log.debug("Sent: %s", message)
payload = await self.loop.run_in_executor( payload = await self.client.loop.run_in_executor(
pyrogram.crypto_executor, pyrogram.crypto_executor,
mtproto.pack, mtproto.pack,
message, message,

View file

@ -43,13 +43,21 @@ PyromodConfig = SimpleNamespace(
unallowed_click_alert=True, unallowed_click_alert=True,
unallowed_click_alert_text=("[pyromod] You're not expected to click this button."), unallowed_click_alert_text=("[pyromod] You're not expected to click this button."),
) )
async def ainput(prompt: str = "", *, hide: bool = False):
async def ainput(prompt: str = "", *, hide: bool = False, loop: Optional[asyncio.AbstractEventLoop] = None):
"""Just like the built-in input, but async""" """Just like the built-in input, but async"""
if isinstance(loop, asyncio.AbstractEventLoop):
loop = loop
else:
try:
loop = asyncio.get_running_loop()
except RuntimeError:
loop = asyncio.new_event_loop()
with ThreadPoolExecutor(1) as executor: with ThreadPoolExecutor(1) as executor:
func = functools.partial(getpass if hide else input, prompt) func = functools.partial(getpass if hide else input, prompt)
return await asyncio.get_event_loop().run_in_executor(executor, func) return await loop.run_in_executor(executor, func)
def get_input_media_from_file_id( def get_input_media_from_file_id(