mirror of
https://github.com/Mayuri-Chan/pyrofork.git
synced 2025-12-29 12:04:51 +00:00
Refactor loop
Signed-off-by: wulan17 <wulan17@komodos.id>
This commit is contained in:
parent
e3a8a781d3
commit
1a4c578380
13 changed files with 89 additions and 48 deletions
|
|
@ -33,7 +33,7 @@ from importlib import import_module
|
|||
from io import StringIO, BytesIO
|
||||
from mimetypes import MimeTypes
|
||||
from pathlib import Path
|
||||
from typing import Union, List, Optional, Callable, AsyncGenerator, Tuple
|
||||
from typing import Union, List, Optional, Callable, AsyncGenerator, Type, Tuple
|
||||
|
||||
import pyrogram
|
||||
from pyrogram import __version__, __license__
|
||||
|
|
@ -55,7 +55,7 @@ from pyrogram.storage import FileStorage, MemoryStorage, Storage
|
|||
from pyrogram.types import User
|
||||
from pyrogram.utils import ainput
|
||||
from .connection import Connection
|
||||
from .connection.transport import TCPAbridged
|
||||
from .connection.transport import TCP, TCPAbridged
|
||||
from .dispatcher import Dispatcher
|
||||
from .file_id import FileId, FileType, ThumbnailSource
|
||||
from .mime_types import mime_types
|
||||
|
|
@ -220,6 +220,9 @@ class Client(Methods):
|
|||
client_platform (:obj:`~pyrogram.enums.ClientPlatform`, *optional*):
|
||||
The platform where this client is running.
|
||||
Defaults to 'other'
|
||||
|
||||
loop (:py:class:`asyncio.AbstractEventLoop`, *optional*):
|
||||
Event loop.
|
||||
"""
|
||||
|
||||
APP_VERSION = f"Pyrogram {__version__}"
|
||||
|
|
@ -277,7 +280,9 @@ class Client(Methods):
|
|||
max_concurrent_transmissions: int = MAX_CONCURRENT_TRANSMISSIONS,
|
||||
client_platform: "enums.ClientPlatform" = enums.ClientPlatform.OTHER,
|
||||
max_message_cache_size: int = MAX_CACHE_SIZE,
|
||||
max_business_user_connection_cache_size: int = MAX_CACHE_SIZE
|
||||
max_business_user_connection_cache_size: int = MAX_CACHE_SIZE,
|
||||
protocol_factory: Type[TCP] = TCPAbridged,
|
||||
loop: Optional[asyncio.AbstractEventLoop] = None
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
|
|
@ -371,7 +376,14 @@ class Client(Methods):
|
|||
self.updates_watchdog_event = asyncio.Event()
|
||||
self.last_update_time = datetime.now()
|
||||
self.listeners = {listener_type: [] for listener_type in pyrogram.enums.ListenerTypes}
|
||||
self.loop = asyncio.get_event_loop()
|
||||
|
||||
if isinstance(loop, asyncio.AbstractEventLoop):
|
||||
self.loop = loop
|
||||
else:
|
||||
try:
|
||||
self.loop = asyncio.get_running_loop()
|
||||
except RuntimeError:
|
||||
self.loop = asyncio.new_event_loop()
|
||||
|
||||
def __enter__(self):
|
||||
return self.start()
|
||||
|
|
@ -425,7 +437,7 @@ class Client(Methods):
|
|||
if not self.phone_number:
|
||||
while True:
|
||||
print("Enter 'qrcode' if you want to login with qrcode.")
|
||||
value = await ainput("Enter phone number or bot token: ")
|
||||
value = await ainput("Enter phone number or bot token: ", loop=self.loop)
|
||||
|
||||
if not value:
|
||||
continue
|
||||
|
|
@ -434,7 +446,7 @@ class Client(Methods):
|
|||
self.use_qrcode = True
|
||||
break
|
||||
|
||||
confirm = (await ainput(f'Is "{value}" correct? (y/N): ')).lower()
|
||||
confirm = (await ainput(f'Is "{value}" correct? (y/N): ', loop=self.loop)).lower()
|
||||
|
||||
if confirm == "y":
|
||||
break
|
||||
|
|
@ -466,7 +478,7 @@ class Client(Methods):
|
|||
|
||||
while True:
|
||||
if not self.use_qrcode and not self.phone_code:
|
||||
self.phone_code = await ainput("Enter confirmation code: ")
|
||||
self.phone_code = await ainput("Enter confirmation code: ", loop=self.loop)
|
||||
|
||||
try:
|
||||
if self.use_qrcode:
|
||||
|
|
@ -483,18 +495,18 @@ class Client(Methods):
|
|||
print("Password hint: {}".format(await self.get_password_hint()))
|
||||
|
||||
if not self.password:
|
||||
self.password = await ainput("Enter password (empty to recover): ", hide=self.hide_password)
|
||||
self.password = await ainput("Enter password (empty to recover): ", hide=self.hide_password, loop=self.loop)
|
||||
|
||||
try:
|
||||
if not self.password:
|
||||
confirm = await ainput("Confirm password recovery (y/n): ")
|
||||
confirm = await ainput("Confirm password recovery (y/n): ", loop=self.loop)
|
||||
|
||||
if confirm == "y":
|
||||
email_pattern = await self.send_recovery_code()
|
||||
print(f"The recovery code has been sent to {email_pattern}")
|
||||
|
||||
while True:
|
||||
recovery_code = await ainput("Enter recovery code: ")
|
||||
recovery_code = await ainput("Enter recovery code: ", loop=self.loop)
|
||||
|
||||
try:
|
||||
return await self.recover_password(recovery_code)
|
||||
|
|
@ -842,13 +854,13 @@ class Client(Methods):
|
|||
else:
|
||||
while True:
|
||||
try:
|
||||
value = int(await ainput("Enter the api_id part of the API key: "))
|
||||
value = int(await ainput("Enter the api_id part of the API key: ", loop=self.loop))
|
||||
|
||||
if value <= 0:
|
||||
print("Invalid value")
|
||||
continue
|
||||
|
||||
confirm = (await ainput(f'Is "{value}" correct? (y/N): ')).lower()
|
||||
confirm = (await ainput(f'Is "{value}" correct? (y/N): ', loop=self.loop)).lower()
|
||||
|
||||
if confirm == "y":
|
||||
await self.storage.api_id(value)
|
||||
|
|
|
|||
|
|
@ -38,7 +38,8 @@ class Connection:
|
|||
alt_port: bool,
|
||||
proxy: dict,
|
||||
media: bool = False,
|
||||
protocol_factory: Type[TCP] = TCPAbridged
|
||||
protocol_factory: Type[TCP] = TCPAbridged,
|
||||
loop: Optional[asyncio.AbstractEventLoop] = None
|
||||
) -> None:
|
||||
self.dc_id = dc_id
|
||||
self.test_mode = test_mode
|
||||
|
|
@ -51,9 +52,17 @@ class Connection:
|
|||
self.address = DataCenter(dc_id, test_mode, ipv6, alt_port, media)
|
||||
self.protocol: Optional[TCP] = None
|
||||
|
||||
if isinstance(loop, asyncio.AbstractEventLoop):
|
||||
self.loop = loop
|
||||
else:
|
||||
try:
|
||||
self.loop = asyncio.get_running_loop()
|
||||
except RuntimeError:
|
||||
self.loop = asyncio.new_event_loop()
|
||||
|
||||
async def connect(self) -> None:
|
||||
for _ in range(Connection.MAX_CONNECTION_ATTEMPTS):
|
||||
self.protocol = self.protocol_factory(ipv6=self.ipv6, proxy=self.proxy)
|
||||
self.protocol = self.protocol_factory(ipv6=self.ipv6, proxy=self.proxy, loop=self.loop)
|
||||
|
||||
try:
|
||||
log.info("Connecting...")
|
||||
|
|
|
|||
|
|
@ -46,7 +46,7 @@ class Proxy(TypedDict):
|
|||
class TCP:
|
||||
TIMEOUT = 10
|
||||
|
||||
def __init__(self, ipv6: bool, proxy: Proxy) -> None:
|
||||
def __init__(self, ipv6: bool, proxy: Proxy, loop: Optional[asyncio.AbstractEventLoop] = None) -> None:
|
||||
self.ipv6 = ipv6
|
||||
self.proxy = proxy
|
||||
|
||||
|
|
@ -54,7 +54,14 @@ class TCP:
|
|||
self.writer: Optional[asyncio.StreamWriter] = None
|
||||
|
||||
self.lock = asyncio.Lock()
|
||||
self.loop = asyncio.get_event_loop()
|
||||
|
||||
if isinstance(loop, asyncio.AbstractEventLoop):
|
||||
self.loop = loop
|
||||
else:
|
||||
try:
|
||||
self.loop = asyncio.get_running_loop()
|
||||
except RuntimeError:
|
||||
self.loop = asyncio.new_event_loop()
|
||||
|
||||
async def _connect_via_proxy(
|
||||
self,
|
||||
|
|
|
|||
|
|
@ -17,6 +17,7 @@
|
|||
# You should have received a copy of the GNU Lesser General Public License
|
||||
# along with Pyrofork. If not, see <http://www.gnu.org/licenses/>.
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Optional, Tuple
|
||||
|
||||
|
|
@ -26,8 +27,8 @@ log = logging.getLogger(__name__)
|
|||
|
||||
|
||||
class TCPAbridged(TCP):
|
||||
def __init__(self, ipv6: bool, proxy: Proxy) -> None:
|
||||
super().__init__(ipv6, proxy)
|
||||
def __init__(self, ipv6: bool, proxy: Proxy, loop: Optional[asyncio.AbstractEventLoop] = None) -> None:
|
||||
super().__init__(ipv6, proxy, loop)
|
||||
|
||||
async def connect(self, address: Tuple[str, int]) -> None:
|
||||
await super().connect(address)
|
||||
|
|
|
|||
|
|
@ -17,6 +17,7 @@
|
|||
# You should have received a copy of the GNU Lesser General Public License
|
||||
# along with Pyrofork. If not, see <http://www.gnu.org/licenses/>.
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
from typing import Optional, Tuple
|
||||
|
|
@ -31,8 +32,8 @@ log = logging.getLogger(__name__)
|
|||
class TCPAbridgedO(TCP):
|
||||
RESERVED = (b"HEAD", b"POST", b"GET ", b"OPTI", b"\xee" * 4)
|
||||
|
||||
def __init__(self, ipv6: bool, proxy: Proxy) -> None:
|
||||
super().__init__(ipv6, proxy)
|
||||
def __init__(self, ipv6: bool, proxy: Proxy, loop: Optional[asyncio.AbstractEventLoop] = None) -> None:
|
||||
super().__init__(ipv6, proxy, loop)
|
||||
|
||||
self.encrypt = None
|
||||
self.decrypt = None
|
||||
|
|
|
|||
|
|
@ -17,6 +17,7 @@
|
|||
# You should have received a copy of the GNU Lesser General Public License
|
||||
# along with Pyrofork. If not, see <http://www.gnu.org/licenses/>.
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from binascii import crc32
|
||||
from struct import pack, unpack
|
||||
|
|
@ -28,8 +29,8 @@ log = logging.getLogger(__name__)
|
|||
|
||||
|
||||
class TCPFull(TCP):
|
||||
def __init__(self, ipv6: bool, proxy: Proxy) -> None:
|
||||
super().__init__(ipv6, proxy)
|
||||
def __init__(self, ipv6: bool, proxy: Proxy, loop: Optional[asyncio.AbstractEventLoop] = None) -> None:
|
||||
super().__init__(ipv6, proxy, loop)
|
||||
|
||||
self.seq_no: Optional[int] = None
|
||||
|
||||
|
|
|
|||
|
|
@ -17,6 +17,7 @@
|
|||
# You should have received a copy of the GNU Lesser General Public License
|
||||
# along with Pyrofork. If not, see <http://www.gnu.org/licenses/>.
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from struct import pack, unpack
|
||||
from typing import Optional, Tuple
|
||||
|
|
@ -27,8 +28,8 @@ log = logging.getLogger(__name__)
|
|||
|
||||
|
||||
class TCPIntermediate(TCP):
|
||||
def __init__(self, ipv6: bool, proxy: Proxy) -> None:
|
||||
super().__init__(ipv6, proxy)
|
||||
def __init__(self, ipv6: bool, proxy: Proxy, loop: Optional[asyncio.AbstractEventLoop] = None) -> None:
|
||||
super().__init__(ipv6, proxy, loop)
|
||||
|
||||
async def connect(self, address: Tuple[str, int]) -> None:
|
||||
await super().connect(address)
|
||||
|
|
|
|||
|
|
@ -17,6 +17,7 @@
|
|||
# You should have received a copy of the GNU Lesser General Public License
|
||||
# along with Pyrofork. If not, see <http://www.gnu.org/licenses/>.
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
from struct import pack, unpack
|
||||
|
|
@ -31,8 +32,8 @@ log = logging.getLogger(__name__)
|
|||
class TCPIntermediateO(TCP):
|
||||
RESERVED = (b"HEAD", b"POST", b"GET ", b"OPTI", b"\xee" * 4)
|
||||
|
||||
def __init__(self, ipv6: bool, proxy: Proxy) -> None:
|
||||
super().__init__(ipv6, proxy)
|
||||
def __init__(self, ipv6: bool, proxy: Proxy, loop: Optional[asyncio.AbstractEventLoop] = None) -> None:
|
||||
super().__init__(ipv6, proxy, loop)
|
||||
|
||||
self.encrypt = None
|
||||
self.decrypt = None
|
||||
|
|
|
|||
|
|
@ -95,7 +95,6 @@ class Dispatcher:
|
|||
|
||||
def __init__(self, client: "pyrogram.Client"):
|
||||
self.client = client
|
||||
self.loop = asyncio.get_event_loop()
|
||||
|
||||
self.handler_worker_tasks = []
|
||||
self.locks_list = []
|
||||
|
|
@ -271,7 +270,7 @@ class Dispatcher:
|
|||
self.locks_list.append(asyncio.Lock())
|
||||
|
||||
self.handler_worker_tasks.append(
|
||||
self.loop.create_task(self.handler_worker(self.locks_list[-1]))
|
||||
self.client.loop.create_task(self.handler_worker(self.locks_list[-1]))
|
||||
)
|
||||
|
||||
log.info("Started %s HandlerTasks", self.client.workers)
|
||||
|
|
@ -311,7 +310,7 @@ class Dispatcher:
|
|||
for lock in self.locks_list:
|
||||
lock.release()
|
||||
|
||||
self.loop.create_task(fn())
|
||||
self.client.loop.create_task(fn())
|
||||
|
||||
def remove_handler(self, handler, group: int):
|
||||
async def fn():
|
||||
|
|
@ -333,7 +332,7 @@ class Dispatcher:
|
|||
for lock in self.locks_list:
|
||||
lock.release()
|
||||
|
||||
self.loop.create_task(fn())
|
||||
self.client.loop.create_task(fn())
|
||||
|
||||
async def handler_worker(self, lock: asyncio.Lock):
|
||||
while True:
|
||||
|
|
@ -404,7 +403,7 @@ class Dispatcher:
|
|||
if inspect.iscoroutinefunction(handler.callback):
|
||||
await handler.callback(self.client, *args)
|
||||
else:
|
||||
await self.loop.run_in_executor(
|
||||
await self.client.loop.run_in_executor(
|
||||
self.client.executor,
|
||||
handler.callback,
|
||||
self.client,
|
||||
|
|
|
|||
|
|
@ -185,4 +185,4 @@ class DownloadMedia:
|
|||
if block:
|
||||
return await downloader
|
||||
else:
|
||||
asyncio.get_event_loop().create_task(downloader)
|
||||
self.loop.create_task(downloader)
|
||||
|
|
|
|||
|
|
@ -52,6 +52,7 @@ class Auth:
|
|||
self.proxy = client.proxy
|
||||
self.connection_factory = client.connection_factory
|
||||
self.protocol_factory = client.protocol_factory
|
||||
self.loop = client.loop
|
||||
|
||||
self.connection: Optional[Connection] = None
|
||||
|
||||
|
|
@ -93,7 +94,8 @@ class Auth:
|
|||
alt_port=self.alt_port,
|
||||
proxy=self.proxy,
|
||||
media=False,
|
||||
protocol_factory=self.protocol_factory
|
||||
protocol_factory=self.protocol_factory,
|
||||
loop=self.loop
|
||||
)
|
||||
|
||||
try:
|
||||
|
|
|
|||
|
|
@ -104,8 +104,6 @@ class Session:
|
|||
|
||||
self.is_started = asyncio.Event()
|
||||
|
||||
self.loop = asyncio.get_event_loop()
|
||||
|
||||
self.last_reconnect_attempt = None
|
||||
|
||||
async def start(self):
|
||||
|
|
@ -117,13 +115,14 @@ class Session:
|
|||
alt_port=self.client.alt_port,
|
||||
proxy=self.client.proxy,
|
||||
media=self.is_media,
|
||||
protocol_factory=self.client.protocol_factory
|
||||
protocol_factory=self.client.protocol_factory,
|
||||
loop=self.client.loop
|
||||
)
|
||||
|
||||
try:
|
||||
await self.connection.connect()
|
||||
|
||||
self.recv_task = self.loop.create_task(self.recv_worker())
|
||||
self.recv_task = self.client.loop.create_task(self.recv_worker())
|
||||
|
||||
await self.send(raw.functions.Ping(ping_id=0), timeout=self.START_TIMEOUT)
|
||||
|
||||
|
|
@ -145,7 +144,7 @@ class Session:
|
|||
timeout=self.START_TIMEOUT
|
||||
)
|
||||
|
||||
self.ping_task = self.loop.create_task(self.ping_worker())
|
||||
self.ping_task = self.client.loop.create_task(self.ping_worker())
|
||||
|
||||
log.info("Session initialized: Layer %s", layer)
|
||||
log.info("Device: %s - %s", self.client.device_model, self.client.app_version)
|
||||
|
|
@ -207,7 +206,7 @@ class Session:
|
|||
|
||||
async def handle_packet(self, packet):
|
||||
try:
|
||||
data = await self.loop.run_in_executor(
|
||||
data = await self.client.loop.run_in_executor(
|
||||
pyrogram.crypto_executor,
|
||||
mtproto.unpack,
|
||||
BytesIO(packet),
|
||||
|
|
@ -217,7 +216,7 @@ class Session:
|
|||
)
|
||||
except ValueError as e:
|
||||
log.debug(e)
|
||||
self.loop.create_task(self.restart())
|
||||
self.client.loop.create_task(self.restart())
|
||||
return
|
||||
|
||||
messages = (
|
||||
|
|
@ -279,7 +278,7 @@ class Session:
|
|||
msg_id = msg.body.msg_id
|
||||
else:
|
||||
if self.client is not None:
|
||||
self.loop.create_task(self.client.handle_updates(msg.body))
|
||||
self.client.loop.create_task(self.client.handle_updates(msg.body))
|
||||
|
||||
if msg_id in self.results:
|
||||
self.results[msg_id].value = getattr(msg.body, "result", msg.body)
|
||||
|
|
@ -313,7 +312,7 @@ class Session:
|
|||
), False
|
||||
)
|
||||
except OSError:
|
||||
self.loop.create_task(self.restart())
|
||||
self.client.loop.create_task(self.restart())
|
||||
break
|
||||
except RPCError:
|
||||
pass
|
||||
|
|
@ -342,11 +341,11 @@ class Session:
|
|||
)
|
||||
|
||||
if self.is_started.is_set():
|
||||
self.loop.create_task(self.restart())
|
||||
self.client.loop.create_task(self.restart())
|
||||
|
||||
break
|
||||
|
||||
self.loop.create_task(self.handle_packet(packet))
|
||||
self.client.loop.create_task(self.handle_packet(packet))
|
||||
|
||||
log.info("NetworkTask stopped")
|
||||
|
||||
|
|
@ -359,7 +358,7 @@ class Session:
|
|||
|
||||
log.debug("Sent: %s", message)
|
||||
|
||||
payload = await self.loop.run_in_executor(
|
||||
payload = await self.client.loop.run_in_executor(
|
||||
pyrogram.crypto_executor,
|
||||
mtproto.pack,
|
||||
message,
|
||||
|
|
|
|||
|
|
@ -43,13 +43,21 @@ PyromodConfig = SimpleNamespace(
|
|||
unallowed_click_alert=True,
|
||||
unallowed_click_alert_text=("[pyromod] You're not expected to click this button."),
|
||||
)
|
||||
|
||||
|
||||
async def ainput(prompt: str = "", *, hide: bool = False):
|
||||
|
||||
async def ainput(prompt: str = "", *, hide: bool = False, loop: Optional[asyncio.AbstractEventLoop] = None):
|
||||
"""Just like the built-in input, but async"""
|
||||
if isinstance(loop, asyncio.AbstractEventLoop):
|
||||
loop = loop
|
||||
else:
|
||||
try:
|
||||
loop = asyncio.get_running_loop()
|
||||
except RuntimeError:
|
||||
loop = asyncio.new_event_loop()
|
||||
|
||||
with ThreadPoolExecutor(1) as executor:
|
||||
func = functools.partial(getpass if hide else input, prompt)
|
||||
return await asyncio.get_event_loop().run_in_executor(executor, func)
|
||||
return await loop.run_in_executor(executor, func)
|
||||
|
||||
|
||||
def get_input_media_from_file_id(
|
||||
|
|
|
|||
Loading…
Reference in a new issue