mirror of
https://github.com/Mayuri-Chan/pyrofork.git
synced 2026-01-03 14:04:51 +00:00
Refactor loop
Signed-off-by: wulan17 <wulan17@komodos.id>
This commit is contained in:
parent
e3a8a781d3
commit
1a4c578380
13 changed files with 89 additions and 48 deletions
|
|
@ -33,7 +33,7 @@ from importlib import import_module
|
||||||
from io import StringIO, BytesIO
|
from io import StringIO, BytesIO
|
||||||
from mimetypes import MimeTypes
|
from mimetypes import MimeTypes
|
||||||
from pathlib import Path
|
from pathlib import Path
|
||||||
from typing import Union, List, Optional, Callable, AsyncGenerator, Tuple
|
from typing import Union, List, Optional, Callable, AsyncGenerator, Type, Tuple
|
||||||
|
|
||||||
import pyrogram
|
import pyrogram
|
||||||
from pyrogram import __version__, __license__
|
from pyrogram import __version__, __license__
|
||||||
|
|
@ -55,7 +55,7 @@ from pyrogram.storage import FileStorage, MemoryStorage, Storage
|
||||||
from pyrogram.types import User
|
from pyrogram.types import User
|
||||||
from pyrogram.utils import ainput
|
from pyrogram.utils import ainput
|
||||||
from .connection import Connection
|
from .connection import Connection
|
||||||
from .connection.transport import TCPAbridged
|
from .connection.transport import TCP, TCPAbridged
|
||||||
from .dispatcher import Dispatcher
|
from .dispatcher import Dispatcher
|
||||||
from .file_id import FileId, FileType, ThumbnailSource
|
from .file_id import FileId, FileType, ThumbnailSource
|
||||||
from .mime_types import mime_types
|
from .mime_types import mime_types
|
||||||
|
|
@ -220,6 +220,9 @@ class Client(Methods):
|
||||||
client_platform (:obj:`~pyrogram.enums.ClientPlatform`, *optional*):
|
client_platform (:obj:`~pyrogram.enums.ClientPlatform`, *optional*):
|
||||||
The platform where this client is running.
|
The platform where this client is running.
|
||||||
Defaults to 'other'
|
Defaults to 'other'
|
||||||
|
|
||||||
|
loop (:py:class:`asyncio.AbstractEventLoop`, *optional*):
|
||||||
|
Event loop.
|
||||||
"""
|
"""
|
||||||
|
|
||||||
APP_VERSION = f"Pyrogram {__version__}"
|
APP_VERSION = f"Pyrogram {__version__}"
|
||||||
|
|
@ -277,7 +280,9 @@ class Client(Methods):
|
||||||
max_concurrent_transmissions: int = MAX_CONCURRENT_TRANSMISSIONS,
|
max_concurrent_transmissions: int = MAX_CONCURRENT_TRANSMISSIONS,
|
||||||
client_platform: "enums.ClientPlatform" = enums.ClientPlatform.OTHER,
|
client_platform: "enums.ClientPlatform" = enums.ClientPlatform.OTHER,
|
||||||
max_message_cache_size: int = MAX_CACHE_SIZE,
|
max_message_cache_size: int = MAX_CACHE_SIZE,
|
||||||
max_business_user_connection_cache_size: int = MAX_CACHE_SIZE
|
max_business_user_connection_cache_size: int = MAX_CACHE_SIZE,
|
||||||
|
protocol_factory: Type[TCP] = TCPAbridged,
|
||||||
|
loop: Optional[asyncio.AbstractEventLoop] = None
|
||||||
):
|
):
|
||||||
super().__init__()
|
super().__init__()
|
||||||
|
|
||||||
|
|
@ -371,7 +376,14 @@ class Client(Methods):
|
||||||
self.updates_watchdog_event = asyncio.Event()
|
self.updates_watchdog_event = asyncio.Event()
|
||||||
self.last_update_time = datetime.now()
|
self.last_update_time = datetime.now()
|
||||||
self.listeners = {listener_type: [] for listener_type in pyrogram.enums.ListenerTypes}
|
self.listeners = {listener_type: [] for listener_type in pyrogram.enums.ListenerTypes}
|
||||||
self.loop = asyncio.get_event_loop()
|
|
||||||
|
if isinstance(loop, asyncio.AbstractEventLoop):
|
||||||
|
self.loop = loop
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
self.loop = asyncio.get_running_loop()
|
||||||
|
except RuntimeError:
|
||||||
|
self.loop = asyncio.new_event_loop()
|
||||||
|
|
||||||
def __enter__(self):
|
def __enter__(self):
|
||||||
return self.start()
|
return self.start()
|
||||||
|
|
@ -425,7 +437,7 @@ class Client(Methods):
|
||||||
if not self.phone_number:
|
if not self.phone_number:
|
||||||
while True:
|
while True:
|
||||||
print("Enter 'qrcode' if you want to login with qrcode.")
|
print("Enter 'qrcode' if you want to login with qrcode.")
|
||||||
value = await ainput("Enter phone number or bot token: ")
|
value = await ainput("Enter phone number or bot token: ", loop=self.loop)
|
||||||
|
|
||||||
if not value:
|
if not value:
|
||||||
continue
|
continue
|
||||||
|
|
@ -434,7 +446,7 @@ class Client(Methods):
|
||||||
self.use_qrcode = True
|
self.use_qrcode = True
|
||||||
break
|
break
|
||||||
|
|
||||||
confirm = (await ainput(f'Is "{value}" correct? (y/N): ')).lower()
|
confirm = (await ainput(f'Is "{value}" correct? (y/N): ', loop=self.loop)).lower()
|
||||||
|
|
||||||
if confirm == "y":
|
if confirm == "y":
|
||||||
break
|
break
|
||||||
|
|
@ -466,7 +478,7 @@ class Client(Methods):
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
if not self.use_qrcode and not self.phone_code:
|
if not self.use_qrcode and not self.phone_code:
|
||||||
self.phone_code = await ainput("Enter confirmation code: ")
|
self.phone_code = await ainput("Enter confirmation code: ", loop=self.loop)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if self.use_qrcode:
|
if self.use_qrcode:
|
||||||
|
|
@ -483,18 +495,18 @@ class Client(Methods):
|
||||||
print("Password hint: {}".format(await self.get_password_hint()))
|
print("Password hint: {}".format(await self.get_password_hint()))
|
||||||
|
|
||||||
if not self.password:
|
if not self.password:
|
||||||
self.password = await ainput("Enter password (empty to recover): ", hide=self.hide_password)
|
self.password = await ainput("Enter password (empty to recover): ", hide=self.hide_password, loop=self.loop)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
if not self.password:
|
if not self.password:
|
||||||
confirm = await ainput("Confirm password recovery (y/n): ")
|
confirm = await ainput("Confirm password recovery (y/n): ", loop=self.loop)
|
||||||
|
|
||||||
if confirm == "y":
|
if confirm == "y":
|
||||||
email_pattern = await self.send_recovery_code()
|
email_pattern = await self.send_recovery_code()
|
||||||
print(f"The recovery code has been sent to {email_pattern}")
|
print(f"The recovery code has been sent to {email_pattern}")
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
recovery_code = await ainput("Enter recovery code: ")
|
recovery_code = await ainput("Enter recovery code: ", loop=self.loop)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
return await self.recover_password(recovery_code)
|
return await self.recover_password(recovery_code)
|
||||||
|
|
@ -842,13 +854,13 @@ class Client(Methods):
|
||||||
else:
|
else:
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
value = int(await ainput("Enter the api_id part of the API key: "))
|
value = int(await ainput("Enter the api_id part of the API key: ", loop=self.loop))
|
||||||
|
|
||||||
if value <= 0:
|
if value <= 0:
|
||||||
print("Invalid value")
|
print("Invalid value")
|
||||||
continue
|
continue
|
||||||
|
|
||||||
confirm = (await ainput(f'Is "{value}" correct? (y/N): ')).lower()
|
confirm = (await ainput(f'Is "{value}" correct? (y/N): ', loop=self.loop)).lower()
|
||||||
|
|
||||||
if confirm == "y":
|
if confirm == "y":
|
||||||
await self.storage.api_id(value)
|
await self.storage.api_id(value)
|
||||||
|
|
|
||||||
|
|
@ -38,7 +38,8 @@ class Connection:
|
||||||
alt_port: bool,
|
alt_port: bool,
|
||||||
proxy: dict,
|
proxy: dict,
|
||||||
media: bool = False,
|
media: bool = False,
|
||||||
protocol_factory: Type[TCP] = TCPAbridged
|
protocol_factory: Type[TCP] = TCPAbridged,
|
||||||
|
loop: Optional[asyncio.AbstractEventLoop] = None
|
||||||
) -> None:
|
) -> None:
|
||||||
self.dc_id = dc_id
|
self.dc_id = dc_id
|
||||||
self.test_mode = test_mode
|
self.test_mode = test_mode
|
||||||
|
|
@ -51,9 +52,17 @@ class Connection:
|
||||||
self.address = DataCenter(dc_id, test_mode, ipv6, alt_port, media)
|
self.address = DataCenter(dc_id, test_mode, ipv6, alt_port, media)
|
||||||
self.protocol: Optional[TCP] = None
|
self.protocol: Optional[TCP] = None
|
||||||
|
|
||||||
|
if isinstance(loop, asyncio.AbstractEventLoop):
|
||||||
|
self.loop = loop
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
self.loop = asyncio.get_running_loop()
|
||||||
|
except RuntimeError:
|
||||||
|
self.loop = asyncio.new_event_loop()
|
||||||
|
|
||||||
async def connect(self) -> None:
|
async def connect(self) -> None:
|
||||||
for _ in range(Connection.MAX_CONNECTION_ATTEMPTS):
|
for _ in range(Connection.MAX_CONNECTION_ATTEMPTS):
|
||||||
self.protocol = self.protocol_factory(ipv6=self.ipv6, proxy=self.proxy)
|
self.protocol = self.protocol_factory(ipv6=self.ipv6, proxy=self.proxy, loop=self.loop)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
log.info("Connecting...")
|
log.info("Connecting...")
|
||||||
|
|
|
||||||
|
|
@ -46,7 +46,7 @@ class Proxy(TypedDict):
|
||||||
class TCP:
|
class TCP:
|
||||||
TIMEOUT = 10
|
TIMEOUT = 10
|
||||||
|
|
||||||
def __init__(self, ipv6: bool, proxy: Proxy) -> None:
|
def __init__(self, ipv6: bool, proxy: Proxy, loop: Optional[asyncio.AbstractEventLoop] = None) -> None:
|
||||||
self.ipv6 = ipv6
|
self.ipv6 = ipv6
|
||||||
self.proxy = proxy
|
self.proxy = proxy
|
||||||
|
|
||||||
|
|
@ -54,7 +54,14 @@ class TCP:
|
||||||
self.writer: Optional[asyncio.StreamWriter] = None
|
self.writer: Optional[asyncio.StreamWriter] = None
|
||||||
|
|
||||||
self.lock = asyncio.Lock()
|
self.lock = asyncio.Lock()
|
||||||
self.loop = asyncio.get_event_loop()
|
|
||||||
|
if isinstance(loop, asyncio.AbstractEventLoop):
|
||||||
|
self.loop = loop
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
self.loop = asyncio.get_running_loop()
|
||||||
|
except RuntimeError:
|
||||||
|
self.loop = asyncio.new_event_loop()
|
||||||
|
|
||||||
async def _connect_via_proxy(
|
async def _connect_via_proxy(
|
||||||
self,
|
self,
|
||||||
|
|
|
||||||
|
|
@ -17,6 +17,7 @@
|
||||||
# You should have received a copy of the GNU Lesser General Public License
|
# You should have received a copy of the GNU Lesser General Public License
|
||||||
# along with Pyrofork. If not, see <http://www.gnu.org/licenses/>.
|
# along with Pyrofork. If not, see <http://www.gnu.org/licenses/>.
|
||||||
|
|
||||||
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
from typing import Optional, Tuple
|
from typing import Optional, Tuple
|
||||||
|
|
||||||
|
|
@ -26,8 +27,8 @@ log = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class TCPAbridged(TCP):
|
class TCPAbridged(TCP):
|
||||||
def __init__(self, ipv6: bool, proxy: Proxy) -> None:
|
def __init__(self, ipv6: bool, proxy: Proxy, loop: Optional[asyncio.AbstractEventLoop] = None) -> None:
|
||||||
super().__init__(ipv6, proxy)
|
super().__init__(ipv6, proxy, loop)
|
||||||
|
|
||||||
async def connect(self, address: Tuple[str, int]) -> None:
|
async def connect(self, address: Tuple[str, int]) -> None:
|
||||||
await super().connect(address)
|
await super().connect(address)
|
||||||
|
|
|
||||||
|
|
@ -17,6 +17,7 @@
|
||||||
# You should have received a copy of the GNU Lesser General Public License
|
# You should have received a copy of the GNU Lesser General Public License
|
||||||
# along with Pyrofork. If not, see <http://www.gnu.org/licenses/>.
|
# along with Pyrofork. If not, see <http://www.gnu.org/licenses/>.
|
||||||
|
|
||||||
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
from typing import Optional, Tuple
|
from typing import Optional, Tuple
|
||||||
|
|
@ -31,8 +32,8 @@ log = logging.getLogger(__name__)
|
||||||
class TCPAbridgedO(TCP):
|
class TCPAbridgedO(TCP):
|
||||||
RESERVED = (b"HEAD", b"POST", b"GET ", b"OPTI", b"\xee" * 4)
|
RESERVED = (b"HEAD", b"POST", b"GET ", b"OPTI", b"\xee" * 4)
|
||||||
|
|
||||||
def __init__(self, ipv6: bool, proxy: Proxy) -> None:
|
def __init__(self, ipv6: bool, proxy: Proxy, loop: Optional[asyncio.AbstractEventLoop] = None) -> None:
|
||||||
super().__init__(ipv6, proxy)
|
super().__init__(ipv6, proxy, loop)
|
||||||
|
|
||||||
self.encrypt = None
|
self.encrypt = None
|
||||||
self.decrypt = None
|
self.decrypt = None
|
||||||
|
|
|
||||||
|
|
@ -17,6 +17,7 @@
|
||||||
# You should have received a copy of the GNU Lesser General Public License
|
# You should have received a copy of the GNU Lesser General Public License
|
||||||
# along with Pyrofork. If not, see <http://www.gnu.org/licenses/>.
|
# along with Pyrofork. If not, see <http://www.gnu.org/licenses/>.
|
||||||
|
|
||||||
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
from binascii import crc32
|
from binascii import crc32
|
||||||
from struct import pack, unpack
|
from struct import pack, unpack
|
||||||
|
|
@ -28,8 +29,8 @@ log = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class TCPFull(TCP):
|
class TCPFull(TCP):
|
||||||
def __init__(self, ipv6: bool, proxy: Proxy) -> None:
|
def __init__(self, ipv6: bool, proxy: Proxy, loop: Optional[asyncio.AbstractEventLoop] = None) -> None:
|
||||||
super().__init__(ipv6, proxy)
|
super().__init__(ipv6, proxy, loop)
|
||||||
|
|
||||||
self.seq_no: Optional[int] = None
|
self.seq_no: Optional[int] = None
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -17,6 +17,7 @@
|
||||||
# You should have received a copy of the GNU Lesser General Public License
|
# You should have received a copy of the GNU Lesser General Public License
|
||||||
# along with Pyrofork. If not, see <http://www.gnu.org/licenses/>.
|
# along with Pyrofork. If not, see <http://www.gnu.org/licenses/>.
|
||||||
|
|
||||||
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
from struct import pack, unpack
|
from struct import pack, unpack
|
||||||
from typing import Optional, Tuple
|
from typing import Optional, Tuple
|
||||||
|
|
@ -27,8 +28,8 @@ log = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
class TCPIntermediate(TCP):
|
class TCPIntermediate(TCP):
|
||||||
def __init__(self, ipv6: bool, proxy: Proxy) -> None:
|
def __init__(self, ipv6: bool, proxy: Proxy, loop: Optional[asyncio.AbstractEventLoop] = None) -> None:
|
||||||
super().__init__(ipv6, proxy)
|
super().__init__(ipv6, proxy, loop)
|
||||||
|
|
||||||
async def connect(self, address: Tuple[str, int]) -> None:
|
async def connect(self, address: Tuple[str, int]) -> None:
|
||||||
await super().connect(address)
|
await super().connect(address)
|
||||||
|
|
|
||||||
|
|
@ -17,6 +17,7 @@
|
||||||
# You should have received a copy of the GNU Lesser General Public License
|
# You should have received a copy of the GNU Lesser General Public License
|
||||||
# along with Pyrofork. If not, see <http://www.gnu.org/licenses/>.
|
# along with Pyrofork. If not, see <http://www.gnu.org/licenses/>.
|
||||||
|
|
||||||
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
from struct import pack, unpack
|
from struct import pack, unpack
|
||||||
|
|
@ -31,8 +32,8 @@ log = logging.getLogger(__name__)
|
||||||
class TCPIntermediateO(TCP):
|
class TCPIntermediateO(TCP):
|
||||||
RESERVED = (b"HEAD", b"POST", b"GET ", b"OPTI", b"\xee" * 4)
|
RESERVED = (b"HEAD", b"POST", b"GET ", b"OPTI", b"\xee" * 4)
|
||||||
|
|
||||||
def __init__(self, ipv6: bool, proxy: Proxy) -> None:
|
def __init__(self, ipv6: bool, proxy: Proxy, loop: Optional[asyncio.AbstractEventLoop] = None) -> None:
|
||||||
super().__init__(ipv6, proxy)
|
super().__init__(ipv6, proxy, loop)
|
||||||
|
|
||||||
self.encrypt = None
|
self.encrypt = None
|
||||||
self.decrypt = None
|
self.decrypt = None
|
||||||
|
|
|
||||||
|
|
@ -95,7 +95,6 @@ class Dispatcher:
|
||||||
|
|
||||||
def __init__(self, client: "pyrogram.Client"):
|
def __init__(self, client: "pyrogram.Client"):
|
||||||
self.client = client
|
self.client = client
|
||||||
self.loop = asyncio.get_event_loop()
|
|
||||||
|
|
||||||
self.handler_worker_tasks = []
|
self.handler_worker_tasks = []
|
||||||
self.locks_list = []
|
self.locks_list = []
|
||||||
|
|
@ -271,7 +270,7 @@ class Dispatcher:
|
||||||
self.locks_list.append(asyncio.Lock())
|
self.locks_list.append(asyncio.Lock())
|
||||||
|
|
||||||
self.handler_worker_tasks.append(
|
self.handler_worker_tasks.append(
|
||||||
self.loop.create_task(self.handler_worker(self.locks_list[-1]))
|
self.client.loop.create_task(self.handler_worker(self.locks_list[-1]))
|
||||||
)
|
)
|
||||||
|
|
||||||
log.info("Started %s HandlerTasks", self.client.workers)
|
log.info("Started %s HandlerTasks", self.client.workers)
|
||||||
|
|
@ -311,7 +310,7 @@ class Dispatcher:
|
||||||
for lock in self.locks_list:
|
for lock in self.locks_list:
|
||||||
lock.release()
|
lock.release()
|
||||||
|
|
||||||
self.loop.create_task(fn())
|
self.client.loop.create_task(fn())
|
||||||
|
|
||||||
def remove_handler(self, handler, group: int):
|
def remove_handler(self, handler, group: int):
|
||||||
async def fn():
|
async def fn():
|
||||||
|
|
@ -333,7 +332,7 @@ class Dispatcher:
|
||||||
for lock in self.locks_list:
|
for lock in self.locks_list:
|
||||||
lock.release()
|
lock.release()
|
||||||
|
|
||||||
self.loop.create_task(fn())
|
self.client.loop.create_task(fn())
|
||||||
|
|
||||||
async def handler_worker(self, lock: asyncio.Lock):
|
async def handler_worker(self, lock: asyncio.Lock):
|
||||||
while True:
|
while True:
|
||||||
|
|
@ -404,7 +403,7 @@ class Dispatcher:
|
||||||
if inspect.iscoroutinefunction(handler.callback):
|
if inspect.iscoroutinefunction(handler.callback):
|
||||||
await handler.callback(self.client, *args)
|
await handler.callback(self.client, *args)
|
||||||
else:
|
else:
|
||||||
await self.loop.run_in_executor(
|
await self.client.loop.run_in_executor(
|
||||||
self.client.executor,
|
self.client.executor,
|
||||||
handler.callback,
|
handler.callback,
|
||||||
self.client,
|
self.client,
|
||||||
|
|
|
||||||
|
|
@ -185,4 +185,4 @@ class DownloadMedia:
|
||||||
if block:
|
if block:
|
||||||
return await downloader
|
return await downloader
|
||||||
else:
|
else:
|
||||||
asyncio.get_event_loop().create_task(downloader)
|
self.loop.create_task(downloader)
|
||||||
|
|
|
||||||
|
|
@ -52,6 +52,7 @@ class Auth:
|
||||||
self.proxy = client.proxy
|
self.proxy = client.proxy
|
||||||
self.connection_factory = client.connection_factory
|
self.connection_factory = client.connection_factory
|
||||||
self.protocol_factory = client.protocol_factory
|
self.protocol_factory = client.protocol_factory
|
||||||
|
self.loop = client.loop
|
||||||
|
|
||||||
self.connection: Optional[Connection] = None
|
self.connection: Optional[Connection] = None
|
||||||
|
|
||||||
|
|
@ -93,7 +94,8 @@ class Auth:
|
||||||
alt_port=self.alt_port,
|
alt_port=self.alt_port,
|
||||||
proxy=self.proxy,
|
proxy=self.proxy,
|
||||||
media=False,
|
media=False,
|
||||||
protocol_factory=self.protocol_factory
|
protocol_factory=self.protocol_factory,
|
||||||
|
loop=self.loop
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|
|
||||||
|
|
@ -104,8 +104,6 @@ class Session:
|
||||||
|
|
||||||
self.is_started = asyncio.Event()
|
self.is_started = asyncio.Event()
|
||||||
|
|
||||||
self.loop = asyncio.get_event_loop()
|
|
||||||
|
|
||||||
self.last_reconnect_attempt = None
|
self.last_reconnect_attempt = None
|
||||||
|
|
||||||
async def start(self):
|
async def start(self):
|
||||||
|
|
@ -117,13 +115,14 @@ class Session:
|
||||||
alt_port=self.client.alt_port,
|
alt_port=self.client.alt_port,
|
||||||
proxy=self.client.proxy,
|
proxy=self.client.proxy,
|
||||||
media=self.is_media,
|
media=self.is_media,
|
||||||
protocol_factory=self.client.protocol_factory
|
protocol_factory=self.client.protocol_factory,
|
||||||
|
loop=self.client.loop
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
await self.connection.connect()
|
await self.connection.connect()
|
||||||
|
|
||||||
self.recv_task = self.loop.create_task(self.recv_worker())
|
self.recv_task = self.client.loop.create_task(self.recv_worker())
|
||||||
|
|
||||||
await self.send(raw.functions.Ping(ping_id=0), timeout=self.START_TIMEOUT)
|
await self.send(raw.functions.Ping(ping_id=0), timeout=self.START_TIMEOUT)
|
||||||
|
|
||||||
|
|
@ -145,7 +144,7 @@ class Session:
|
||||||
timeout=self.START_TIMEOUT
|
timeout=self.START_TIMEOUT
|
||||||
)
|
)
|
||||||
|
|
||||||
self.ping_task = self.loop.create_task(self.ping_worker())
|
self.ping_task = self.client.loop.create_task(self.ping_worker())
|
||||||
|
|
||||||
log.info("Session initialized: Layer %s", layer)
|
log.info("Session initialized: Layer %s", layer)
|
||||||
log.info("Device: %s - %s", self.client.device_model, self.client.app_version)
|
log.info("Device: %s - %s", self.client.device_model, self.client.app_version)
|
||||||
|
|
@ -207,7 +206,7 @@ class Session:
|
||||||
|
|
||||||
async def handle_packet(self, packet):
|
async def handle_packet(self, packet):
|
||||||
try:
|
try:
|
||||||
data = await self.loop.run_in_executor(
|
data = await self.client.loop.run_in_executor(
|
||||||
pyrogram.crypto_executor,
|
pyrogram.crypto_executor,
|
||||||
mtproto.unpack,
|
mtproto.unpack,
|
||||||
BytesIO(packet),
|
BytesIO(packet),
|
||||||
|
|
@ -217,7 +216,7 @@ class Session:
|
||||||
)
|
)
|
||||||
except ValueError as e:
|
except ValueError as e:
|
||||||
log.debug(e)
|
log.debug(e)
|
||||||
self.loop.create_task(self.restart())
|
self.client.loop.create_task(self.restart())
|
||||||
return
|
return
|
||||||
|
|
||||||
messages = (
|
messages = (
|
||||||
|
|
@ -279,7 +278,7 @@ class Session:
|
||||||
msg_id = msg.body.msg_id
|
msg_id = msg.body.msg_id
|
||||||
else:
|
else:
|
||||||
if self.client is not None:
|
if self.client is not None:
|
||||||
self.loop.create_task(self.client.handle_updates(msg.body))
|
self.client.loop.create_task(self.client.handle_updates(msg.body))
|
||||||
|
|
||||||
if msg_id in self.results:
|
if msg_id in self.results:
|
||||||
self.results[msg_id].value = getattr(msg.body, "result", msg.body)
|
self.results[msg_id].value = getattr(msg.body, "result", msg.body)
|
||||||
|
|
@ -313,7 +312,7 @@ class Session:
|
||||||
), False
|
), False
|
||||||
)
|
)
|
||||||
except OSError:
|
except OSError:
|
||||||
self.loop.create_task(self.restart())
|
self.client.loop.create_task(self.restart())
|
||||||
break
|
break
|
||||||
except RPCError:
|
except RPCError:
|
||||||
pass
|
pass
|
||||||
|
|
@ -342,11 +341,11 @@ class Session:
|
||||||
)
|
)
|
||||||
|
|
||||||
if self.is_started.is_set():
|
if self.is_started.is_set():
|
||||||
self.loop.create_task(self.restart())
|
self.client.loop.create_task(self.restart())
|
||||||
|
|
||||||
break
|
break
|
||||||
|
|
||||||
self.loop.create_task(self.handle_packet(packet))
|
self.client.loop.create_task(self.handle_packet(packet))
|
||||||
|
|
||||||
log.info("NetworkTask stopped")
|
log.info("NetworkTask stopped")
|
||||||
|
|
||||||
|
|
@ -359,7 +358,7 @@ class Session:
|
||||||
|
|
||||||
log.debug("Sent: %s", message)
|
log.debug("Sent: %s", message)
|
||||||
|
|
||||||
payload = await self.loop.run_in_executor(
|
payload = await self.client.loop.run_in_executor(
|
||||||
pyrogram.crypto_executor,
|
pyrogram.crypto_executor,
|
||||||
mtproto.pack,
|
mtproto.pack,
|
||||||
message,
|
message,
|
||||||
|
|
|
||||||
|
|
@ -43,13 +43,21 @@ PyromodConfig = SimpleNamespace(
|
||||||
unallowed_click_alert=True,
|
unallowed_click_alert=True,
|
||||||
unallowed_click_alert_text=("[pyromod] You're not expected to click this button."),
|
unallowed_click_alert_text=("[pyromod] You're not expected to click this button."),
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
async def ainput(prompt: str = "", *, hide: bool = False):
|
|
||||||
|
async def ainput(prompt: str = "", *, hide: bool = False, loop: Optional[asyncio.AbstractEventLoop] = None):
|
||||||
"""Just like the built-in input, but async"""
|
"""Just like the built-in input, but async"""
|
||||||
|
if isinstance(loop, asyncio.AbstractEventLoop):
|
||||||
|
loop = loop
|
||||||
|
else:
|
||||||
|
try:
|
||||||
|
loop = asyncio.get_running_loop()
|
||||||
|
except RuntimeError:
|
||||||
|
loop = asyncio.new_event_loop()
|
||||||
|
|
||||||
with ThreadPoolExecutor(1) as executor:
|
with ThreadPoolExecutor(1) as executor:
|
||||||
func = functools.partial(getpass if hide else input, prompt)
|
func = functools.partial(getpass if hide else input, prompt)
|
||||||
return await asyncio.get_event_loop().run_in_executor(executor, func)
|
return await loop.run_in_executor(executor, func)
|
||||||
|
|
||||||
|
|
||||||
def get_input_media_from_file_id(
|
def get_input_media_from_file_id(
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue