Refactor loop

Signed-off-by: wulan17 <wulan17@komodos.id>
This commit is contained in:
KurimuzonAkuma 2025-03-05 09:17:55 +03:00 committed by wulan17
parent e3a8a781d3
commit 1a4c578380
No known key found for this signature in database
GPG key ID: 737814D4B5FF0420
13 changed files with 89 additions and 48 deletions

View file

@ -33,7 +33,7 @@ from importlib import import_module
from io import StringIO, BytesIO
from mimetypes import MimeTypes
from pathlib import Path
from typing import Union, List, Optional, Callable, AsyncGenerator, Tuple
from typing import Union, List, Optional, Callable, AsyncGenerator, Type, Tuple
import pyrogram
from pyrogram import __version__, __license__
@ -55,7 +55,7 @@ from pyrogram.storage import FileStorage, MemoryStorage, Storage
from pyrogram.types import User
from pyrogram.utils import ainput
from .connection import Connection
from .connection.transport import TCPAbridged
from .connection.transport import TCP, TCPAbridged
from .dispatcher import Dispatcher
from .file_id import FileId, FileType, ThumbnailSource
from .mime_types import mime_types
@ -220,6 +220,9 @@ class Client(Methods):
client_platform (:obj:`~pyrogram.enums.ClientPlatform`, *optional*):
The platform where this client is running.
Defaults to 'other'
loop (:py:class:`asyncio.AbstractEventLoop`, *optional*):
Event loop.
"""
APP_VERSION = f"Pyrogram {__version__}"
@ -277,7 +280,9 @@ class Client(Methods):
max_concurrent_transmissions: int = MAX_CONCURRENT_TRANSMISSIONS,
client_platform: "enums.ClientPlatform" = enums.ClientPlatform.OTHER,
max_message_cache_size: int = MAX_CACHE_SIZE,
max_business_user_connection_cache_size: int = MAX_CACHE_SIZE
max_business_user_connection_cache_size: int = MAX_CACHE_SIZE,
protocol_factory: Type[TCP] = TCPAbridged,
loop: Optional[asyncio.AbstractEventLoop] = None
):
super().__init__()
@ -371,7 +376,14 @@ class Client(Methods):
self.updates_watchdog_event = asyncio.Event()
self.last_update_time = datetime.now()
self.listeners = {listener_type: [] for listener_type in pyrogram.enums.ListenerTypes}
self.loop = asyncio.get_event_loop()
if isinstance(loop, asyncio.AbstractEventLoop):
self.loop = loop
else:
try:
self.loop = asyncio.get_running_loop()
except RuntimeError:
self.loop = asyncio.new_event_loop()
def __enter__(self):
return self.start()
@ -425,7 +437,7 @@ class Client(Methods):
if not self.phone_number:
while True:
print("Enter 'qrcode' if you want to login with qrcode.")
value = await ainput("Enter phone number or bot token: ")
value = await ainput("Enter phone number or bot token: ", loop=self.loop)
if not value:
continue
@ -434,7 +446,7 @@ class Client(Methods):
self.use_qrcode = True
break
confirm = (await ainput(f'Is "{value}" correct? (y/N): ')).lower()
confirm = (await ainput(f'Is "{value}" correct? (y/N): ', loop=self.loop)).lower()
if confirm == "y":
break
@ -466,7 +478,7 @@ class Client(Methods):
while True:
if not self.use_qrcode and not self.phone_code:
self.phone_code = await ainput("Enter confirmation code: ")
self.phone_code = await ainput("Enter confirmation code: ", loop=self.loop)
try:
if self.use_qrcode:
@ -483,18 +495,18 @@ class Client(Methods):
print("Password hint: {}".format(await self.get_password_hint()))
if not self.password:
self.password = await ainput("Enter password (empty to recover): ", hide=self.hide_password)
self.password = await ainput("Enter password (empty to recover): ", hide=self.hide_password, loop=self.loop)
try:
if not self.password:
confirm = await ainput("Confirm password recovery (y/n): ")
confirm = await ainput("Confirm password recovery (y/n): ", loop=self.loop)
if confirm == "y":
email_pattern = await self.send_recovery_code()
print(f"The recovery code has been sent to {email_pattern}")
while True:
recovery_code = await ainput("Enter recovery code: ")
recovery_code = await ainput("Enter recovery code: ", loop=self.loop)
try:
return await self.recover_password(recovery_code)
@ -842,13 +854,13 @@ class Client(Methods):
else:
while True:
try:
value = int(await ainput("Enter the api_id part of the API key: "))
value = int(await ainput("Enter the api_id part of the API key: ", loop=self.loop))
if value <= 0:
print("Invalid value")
continue
confirm = (await ainput(f'Is "{value}" correct? (y/N): ')).lower()
confirm = (await ainput(f'Is "{value}" correct? (y/N): ', loop=self.loop)).lower()
if confirm == "y":
await self.storage.api_id(value)

View file

@ -38,7 +38,8 @@ class Connection:
alt_port: bool,
proxy: dict,
media: bool = False,
protocol_factory: Type[TCP] = TCPAbridged
protocol_factory: Type[TCP] = TCPAbridged,
loop: Optional[asyncio.AbstractEventLoop] = None
) -> None:
self.dc_id = dc_id
self.test_mode = test_mode
@ -51,9 +52,17 @@ class Connection:
self.address = DataCenter(dc_id, test_mode, ipv6, alt_port, media)
self.protocol: Optional[TCP] = None
if isinstance(loop, asyncio.AbstractEventLoop):
self.loop = loop
else:
try:
self.loop = asyncio.get_running_loop()
except RuntimeError:
self.loop = asyncio.new_event_loop()
async def connect(self) -> None:
for _ in range(Connection.MAX_CONNECTION_ATTEMPTS):
self.protocol = self.protocol_factory(ipv6=self.ipv6, proxy=self.proxy)
self.protocol = self.protocol_factory(ipv6=self.ipv6, proxy=self.proxy, loop=self.loop)
try:
log.info("Connecting...")

View file

@ -46,7 +46,7 @@ class Proxy(TypedDict):
class TCP:
TIMEOUT = 10
def __init__(self, ipv6: bool, proxy: Proxy) -> None:
def __init__(self, ipv6: bool, proxy: Proxy, loop: Optional[asyncio.AbstractEventLoop] = None) -> None:
self.ipv6 = ipv6
self.proxy = proxy
@ -54,7 +54,14 @@ class TCP:
self.writer: Optional[asyncio.StreamWriter] = None
self.lock = asyncio.Lock()
self.loop = asyncio.get_event_loop()
if isinstance(loop, asyncio.AbstractEventLoop):
self.loop = loop
else:
try:
self.loop = asyncio.get_running_loop()
except RuntimeError:
self.loop = asyncio.new_event_loop()
async def _connect_via_proxy(
self,

View file

@ -17,6 +17,7 @@
# You should have received a copy of the GNU Lesser General Public License
# along with Pyrofork. If not, see <http://www.gnu.org/licenses/>.
import asyncio
import logging
from typing import Optional, Tuple
@ -26,8 +27,8 @@ log = logging.getLogger(__name__)
class TCPAbridged(TCP):
def __init__(self, ipv6: bool, proxy: Proxy) -> None:
super().__init__(ipv6, proxy)
def __init__(self, ipv6: bool, proxy: Proxy, loop: Optional[asyncio.AbstractEventLoop] = None) -> None:
super().__init__(ipv6, proxy, loop)
async def connect(self, address: Tuple[str, int]) -> None:
await super().connect(address)

View file

@ -17,6 +17,7 @@
# You should have received a copy of the GNU Lesser General Public License
# along with Pyrofork. If not, see <http://www.gnu.org/licenses/>.
import asyncio
import logging
import os
from typing import Optional, Tuple
@ -31,8 +32,8 @@ log = logging.getLogger(__name__)
class TCPAbridgedO(TCP):
RESERVED = (b"HEAD", b"POST", b"GET ", b"OPTI", b"\xee" * 4)
def __init__(self, ipv6: bool, proxy: Proxy) -> None:
super().__init__(ipv6, proxy)
def __init__(self, ipv6: bool, proxy: Proxy, loop: Optional[asyncio.AbstractEventLoop] = None) -> None:
super().__init__(ipv6, proxy, loop)
self.encrypt = None
self.decrypt = None

View file

@ -17,6 +17,7 @@
# You should have received a copy of the GNU Lesser General Public License
# along with Pyrofork. If not, see <http://www.gnu.org/licenses/>.
import asyncio
import logging
from binascii import crc32
from struct import pack, unpack
@ -28,8 +29,8 @@ log = logging.getLogger(__name__)
class TCPFull(TCP):
def __init__(self, ipv6: bool, proxy: Proxy) -> None:
super().__init__(ipv6, proxy)
def __init__(self, ipv6: bool, proxy: Proxy, loop: Optional[asyncio.AbstractEventLoop] = None) -> None:
super().__init__(ipv6, proxy, loop)
self.seq_no: Optional[int] = None

View file

@ -17,6 +17,7 @@
# You should have received a copy of the GNU Lesser General Public License
# along with Pyrofork. If not, see <http://www.gnu.org/licenses/>.
import asyncio
import logging
from struct import pack, unpack
from typing import Optional, Tuple
@ -27,8 +28,8 @@ log = logging.getLogger(__name__)
class TCPIntermediate(TCP):
def __init__(self, ipv6: bool, proxy: Proxy) -> None:
super().__init__(ipv6, proxy)
def __init__(self, ipv6: bool, proxy: Proxy, loop: Optional[asyncio.AbstractEventLoop] = None) -> None:
super().__init__(ipv6, proxy, loop)
async def connect(self, address: Tuple[str, int]) -> None:
await super().connect(address)

View file

@ -17,6 +17,7 @@
# You should have received a copy of the GNU Lesser General Public License
# along with Pyrofork. If not, see <http://www.gnu.org/licenses/>.
import asyncio
import logging
import os
from struct import pack, unpack
@ -31,8 +32,8 @@ log = logging.getLogger(__name__)
class TCPIntermediateO(TCP):
RESERVED = (b"HEAD", b"POST", b"GET ", b"OPTI", b"\xee" * 4)
def __init__(self, ipv6: bool, proxy: Proxy) -> None:
super().__init__(ipv6, proxy)
def __init__(self, ipv6: bool, proxy: Proxy, loop: Optional[asyncio.AbstractEventLoop] = None) -> None:
super().__init__(ipv6, proxy, loop)
self.encrypt = None
self.decrypt = None

View file

@ -95,7 +95,6 @@ class Dispatcher:
def __init__(self, client: "pyrogram.Client"):
self.client = client
self.loop = asyncio.get_event_loop()
self.handler_worker_tasks = []
self.locks_list = []
@ -271,7 +270,7 @@ class Dispatcher:
self.locks_list.append(asyncio.Lock())
self.handler_worker_tasks.append(
self.loop.create_task(self.handler_worker(self.locks_list[-1]))
self.client.loop.create_task(self.handler_worker(self.locks_list[-1]))
)
log.info("Started %s HandlerTasks", self.client.workers)
@ -311,7 +310,7 @@ class Dispatcher:
for lock in self.locks_list:
lock.release()
self.loop.create_task(fn())
self.client.loop.create_task(fn())
def remove_handler(self, handler, group: int):
async def fn():
@ -333,7 +332,7 @@ class Dispatcher:
for lock in self.locks_list:
lock.release()
self.loop.create_task(fn())
self.client.loop.create_task(fn())
async def handler_worker(self, lock: asyncio.Lock):
while True:
@ -404,7 +403,7 @@ class Dispatcher:
if inspect.iscoroutinefunction(handler.callback):
await handler.callback(self.client, *args)
else:
await self.loop.run_in_executor(
await self.client.loop.run_in_executor(
self.client.executor,
handler.callback,
self.client,

View file

@ -185,4 +185,4 @@ class DownloadMedia:
if block:
return await downloader
else:
asyncio.get_event_loop().create_task(downloader)
self.loop.create_task(downloader)

View file

@ -52,6 +52,7 @@ class Auth:
self.proxy = client.proxy
self.connection_factory = client.connection_factory
self.protocol_factory = client.protocol_factory
self.loop = client.loop
self.connection: Optional[Connection] = None
@ -93,7 +94,8 @@ class Auth:
alt_port=self.alt_port,
proxy=self.proxy,
media=False,
protocol_factory=self.protocol_factory
protocol_factory=self.protocol_factory,
loop=self.loop
)
try:

View file

@ -104,8 +104,6 @@ class Session:
self.is_started = asyncio.Event()
self.loop = asyncio.get_event_loop()
self.last_reconnect_attempt = None
async def start(self):
@ -117,13 +115,14 @@ class Session:
alt_port=self.client.alt_port,
proxy=self.client.proxy,
media=self.is_media,
protocol_factory=self.client.protocol_factory
protocol_factory=self.client.protocol_factory,
loop=self.client.loop
)
try:
await self.connection.connect()
self.recv_task = self.loop.create_task(self.recv_worker())
self.recv_task = self.client.loop.create_task(self.recv_worker())
await self.send(raw.functions.Ping(ping_id=0), timeout=self.START_TIMEOUT)
@ -145,7 +144,7 @@ class Session:
timeout=self.START_TIMEOUT
)
self.ping_task = self.loop.create_task(self.ping_worker())
self.ping_task = self.client.loop.create_task(self.ping_worker())
log.info("Session initialized: Layer %s", layer)
log.info("Device: %s - %s", self.client.device_model, self.client.app_version)
@ -207,7 +206,7 @@ class Session:
async def handle_packet(self, packet):
try:
data = await self.loop.run_in_executor(
data = await self.client.loop.run_in_executor(
pyrogram.crypto_executor,
mtproto.unpack,
BytesIO(packet),
@ -217,7 +216,7 @@ class Session:
)
except ValueError as e:
log.debug(e)
self.loop.create_task(self.restart())
self.client.loop.create_task(self.restart())
return
messages = (
@ -279,7 +278,7 @@ class Session:
msg_id = msg.body.msg_id
else:
if self.client is not None:
self.loop.create_task(self.client.handle_updates(msg.body))
self.client.loop.create_task(self.client.handle_updates(msg.body))
if msg_id in self.results:
self.results[msg_id].value = getattr(msg.body, "result", msg.body)
@ -313,7 +312,7 @@ class Session:
), False
)
except OSError:
self.loop.create_task(self.restart())
self.client.loop.create_task(self.restart())
break
except RPCError:
pass
@ -342,11 +341,11 @@ class Session:
)
if self.is_started.is_set():
self.loop.create_task(self.restart())
self.client.loop.create_task(self.restart())
break
self.loop.create_task(self.handle_packet(packet))
self.client.loop.create_task(self.handle_packet(packet))
log.info("NetworkTask stopped")
@ -359,7 +358,7 @@ class Session:
log.debug("Sent: %s", message)
payload = await self.loop.run_in_executor(
payload = await self.client.loop.run_in_executor(
pyrogram.crypto_executor,
mtproto.pack,
message,

View file

@ -43,13 +43,21 @@ PyromodConfig = SimpleNamespace(
unallowed_click_alert=True,
unallowed_click_alert_text=("[pyromod] You're not expected to click this button."),
)
async def ainput(prompt: str = "", *, hide: bool = False):
async def ainput(prompt: str = "", *, hide: bool = False, loop: Optional[asyncio.AbstractEventLoop] = None):
"""Just like the built-in input, but async"""
if isinstance(loop, asyncio.AbstractEventLoop):
loop = loop
else:
try:
loop = asyncio.get_running_loop()
except RuntimeError:
loop = asyncio.new_event_loop()
with ThreadPoolExecutor(1) as executor:
func = functools.partial(getpass if hide else input, prompt)
return await asyncio.get_event_loop().run_in_executor(executor, func)
return await loop.run_in_executor(executor, func)
def get_input_media_from_file_id(