diff --git a/pyrogram/client.py b/pyrogram/client.py
index 1e8fef35..1e53e8f6 100644
--- a/pyrogram/client.py
+++ b/pyrogram/client.py
@@ -33,7 +33,7 @@ from importlib import import_module
from io import StringIO, BytesIO
from mimetypes import MimeTypes
from pathlib import Path
-from typing import Union, List, Optional, Callable, AsyncGenerator, Tuple
+from typing import Union, List, Optional, Callable, AsyncGenerator, Type, Tuple
import pyrogram
from pyrogram import __version__, __license__
@@ -55,7 +55,7 @@ from pyrogram.storage import FileStorage, MemoryStorage, Storage
from pyrogram.types import User
from pyrogram.utils import ainput
from .connection import Connection
-from .connection.transport import TCPAbridged
+from .connection.transport import TCP, TCPAbridged
from .dispatcher import Dispatcher
from .file_id import FileId, FileType, ThumbnailSource
from .mime_types import mime_types
@@ -220,6 +220,9 @@ class Client(Methods):
client_platform (:obj:`~pyrogram.enums.ClientPlatform`, *optional*):
The platform where this client is running.
Defaults to 'other'
+
+ loop (:py:class:`asyncio.AbstractEventLoop`, *optional*):
+ Event loop.
"""
APP_VERSION = f"Pyrogram {__version__}"
@@ -277,7 +280,9 @@ class Client(Methods):
max_concurrent_transmissions: int = MAX_CONCURRENT_TRANSMISSIONS,
client_platform: "enums.ClientPlatform" = enums.ClientPlatform.OTHER,
max_message_cache_size: int = MAX_CACHE_SIZE,
- max_business_user_connection_cache_size: int = MAX_CACHE_SIZE
+ max_business_user_connection_cache_size: int = MAX_CACHE_SIZE,
+ protocol_factory: Type[TCP] = TCPAbridged,
+ loop: Optional[asyncio.AbstractEventLoop] = None
):
super().__init__()
@@ -371,7 +376,14 @@ class Client(Methods):
self.updates_watchdog_event = asyncio.Event()
self.last_update_time = datetime.now()
self.listeners = {listener_type: [] for listener_type in pyrogram.enums.ListenerTypes}
- self.loop = asyncio.get_event_loop()
+
+ if isinstance(loop, asyncio.AbstractEventLoop):
+ self.loop = loop
+ else:
+ try:
+ self.loop = asyncio.get_running_loop()
+ except RuntimeError:
+ self.loop = asyncio.new_event_loop()
def __enter__(self):
return self.start()
@@ -425,7 +437,7 @@ class Client(Methods):
if not self.phone_number:
while True:
print("Enter 'qrcode' if you want to login with qrcode.")
- value = await ainput("Enter phone number or bot token: ")
+ value = await ainput("Enter phone number or bot token: ", loop=self.loop)
if not value:
continue
@@ -434,7 +446,7 @@ class Client(Methods):
self.use_qrcode = True
break
- confirm = (await ainput(f'Is "{value}" correct? (y/N): ')).lower()
+ confirm = (await ainput(f'Is "{value}" correct? (y/N): ', loop=self.loop)).lower()
if confirm == "y":
break
@@ -466,7 +478,7 @@ class Client(Methods):
while True:
if not self.use_qrcode and not self.phone_code:
- self.phone_code = await ainput("Enter confirmation code: ")
+ self.phone_code = await ainput("Enter confirmation code: ", loop=self.loop)
try:
if self.use_qrcode:
@@ -483,18 +495,18 @@ class Client(Methods):
print("Password hint: {}".format(await self.get_password_hint()))
if not self.password:
- self.password = await ainput("Enter password (empty to recover): ", hide=self.hide_password)
+ self.password = await ainput("Enter password (empty to recover): ", hide=self.hide_password, loop=self.loop)
try:
if not self.password:
- confirm = await ainput("Confirm password recovery (y/n): ")
+ confirm = await ainput("Confirm password recovery (y/n): ", loop=self.loop)
if confirm == "y":
email_pattern = await self.send_recovery_code()
print(f"The recovery code has been sent to {email_pattern}")
while True:
- recovery_code = await ainput("Enter recovery code: ")
+ recovery_code = await ainput("Enter recovery code: ", loop=self.loop)
try:
return await self.recover_password(recovery_code)
@@ -842,13 +854,13 @@ class Client(Methods):
else:
while True:
try:
- value = int(await ainput("Enter the api_id part of the API key: "))
+ value = int(await ainput("Enter the api_id part of the API key: ", loop=self.loop))
if value <= 0:
print("Invalid value")
continue
- confirm = (await ainput(f'Is "{value}" correct? (y/N): ')).lower()
+ confirm = (await ainput(f'Is "{value}" correct? (y/N): ', loop=self.loop)).lower()
if confirm == "y":
await self.storage.api_id(value)
diff --git a/pyrogram/connection/connection.py b/pyrogram/connection/connection.py
index 1f100acc..a7f6acdb 100644
--- a/pyrogram/connection/connection.py
+++ b/pyrogram/connection/connection.py
@@ -38,7 +38,8 @@ class Connection:
alt_port: bool,
proxy: dict,
media: bool = False,
- protocol_factory: Type[TCP] = TCPAbridged
+ protocol_factory: Type[TCP] = TCPAbridged,
+ loop: Optional[asyncio.AbstractEventLoop] = None
) -> None:
self.dc_id = dc_id
self.test_mode = test_mode
@@ -51,9 +52,17 @@ class Connection:
self.address = DataCenter(dc_id, test_mode, ipv6, alt_port, media)
self.protocol: Optional[TCP] = None
+ if isinstance(loop, asyncio.AbstractEventLoop):
+ self.loop = loop
+ else:
+ try:
+ self.loop = asyncio.get_running_loop()
+ except RuntimeError:
+ self.loop = asyncio.new_event_loop()
+
async def connect(self) -> None:
for _ in range(Connection.MAX_CONNECTION_ATTEMPTS):
- self.protocol = self.protocol_factory(ipv6=self.ipv6, proxy=self.proxy)
+ self.protocol = self.protocol_factory(ipv6=self.ipv6, proxy=self.proxy, loop=self.loop)
try:
log.info("Connecting...")
diff --git a/pyrogram/connection/transport/tcp/tcp.py b/pyrogram/connection/transport/tcp/tcp.py
index 3e403b33..fcfbe6ad 100644
--- a/pyrogram/connection/transport/tcp/tcp.py
+++ b/pyrogram/connection/transport/tcp/tcp.py
@@ -46,7 +46,7 @@ class Proxy(TypedDict):
class TCP:
TIMEOUT = 10
- def __init__(self, ipv6: bool, proxy: Proxy) -> None:
+ def __init__(self, ipv6: bool, proxy: Proxy, loop: Optional[asyncio.AbstractEventLoop] = None) -> None:
self.ipv6 = ipv6
self.proxy = proxy
@@ -54,7 +54,14 @@ class TCP:
self.writer: Optional[asyncio.StreamWriter] = None
self.lock = asyncio.Lock()
- self.loop = asyncio.get_event_loop()
+
+ if isinstance(loop, asyncio.AbstractEventLoop):
+ self.loop = loop
+ else:
+ try:
+ self.loop = asyncio.get_running_loop()
+ except RuntimeError:
+ self.loop = asyncio.new_event_loop()
async def _connect_via_proxy(
self,
diff --git a/pyrogram/connection/transport/tcp/tcp_abridged.py b/pyrogram/connection/transport/tcp/tcp_abridged.py
index f23d26fd..4381220d 100644
--- a/pyrogram/connection/transport/tcp/tcp_abridged.py
+++ b/pyrogram/connection/transport/tcp/tcp_abridged.py
@@ -17,6 +17,7 @@
# You should have received a copy of the GNU Lesser General Public License
# along with Pyrofork. If not, see .
+import asyncio
import logging
from typing import Optional, Tuple
@@ -26,8 +27,8 @@ log = logging.getLogger(__name__)
class TCPAbridged(TCP):
- def __init__(self, ipv6: bool, proxy: Proxy) -> None:
- super().__init__(ipv6, proxy)
+ def __init__(self, ipv6: bool, proxy: Proxy, loop: Optional[asyncio.AbstractEventLoop] = None) -> None:
+ super().__init__(ipv6, proxy, loop)
async def connect(self, address: Tuple[str, int]) -> None:
await super().connect(address)
diff --git a/pyrogram/connection/transport/tcp/tcp_abridged_o.py b/pyrogram/connection/transport/tcp/tcp_abridged_o.py
index 0b856728..3bd74725 100644
--- a/pyrogram/connection/transport/tcp/tcp_abridged_o.py
+++ b/pyrogram/connection/transport/tcp/tcp_abridged_o.py
@@ -17,6 +17,7 @@
# You should have received a copy of the GNU Lesser General Public License
# along with Pyrofork. If not, see .
+import asyncio
import logging
import os
from typing import Optional, Tuple
@@ -31,8 +32,8 @@ log = logging.getLogger(__name__)
class TCPAbridgedO(TCP):
RESERVED = (b"HEAD", b"POST", b"GET ", b"OPTI", b"\xee" * 4)
- def __init__(self, ipv6: bool, proxy: Proxy) -> None:
- super().__init__(ipv6, proxy)
+ def __init__(self, ipv6: bool, proxy: Proxy, loop: Optional[asyncio.AbstractEventLoop] = None) -> None:
+ super().__init__(ipv6, proxy, loop)
self.encrypt = None
self.decrypt = None
diff --git a/pyrogram/connection/transport/tcp/tcp_full.py b/pyrogram/connection/transport/tcp/tcp_full.py
index 894ca0e2..f59c1558 100644
--- a/pyrogram/connection/transport/tcp/tcp_full.py
+++ b/pyrogram/connection/transport/tcp/tcp_full.py
@@ -17,6 +17,7 @@
# You should have received a copy of the GNU Lesser General Public License
# along with Pyrofork. If not, see .
+import asyncio
import logging
from binascii import crc32
from struct import pack, unpack
@@ -28,8 +29,8 @@ log = logging.getLogger(__name__)
class TCPFull(TCP):
- def __init__(self, ipv6: bool, proxy: Proxy) -> None:
- super().__init__(ipv6, proxy)
+ def __init__(self, ipv6: bool, proxy: Proxy, loop: Optional[asyncio.AbstractEventLoop] = None) -> None:
+ super().__init__(ipv6, proxy, loop)
self.seq_no: Optional[int] = None
diff --git a/pyrogram/connection/transport/tcp/tcp_intermediate.py b/pyrogram/connection/transport/tcp/tcp_intermediate.py
index ed1c1e57..aa3757e1 100644
--- a/pyrogram/connection/transport/tcp/tcp_intermediate.py
+++ b/pyrogram/connection/transport/tcp/tcp_intermediate.py
@@ -17,6 +17,7 @@
# You should have received a copy of the GNU Lesser General Public License
# along with Pyrofork. If not, see .
+import asyncio
import logging
from struct import pack, unpack
from typing import Optional, Tuple
@@ -27,8 +28,8 @@ log = logging.getLogger(__name__)
class TCPIntermediate(TCP):
- def __init__(self, ipv6: bool, proxy: Proxy) -> None:
- super().__init__(ipv6, proxy)
+ def __init__(self, ipv6: bool, proxy: Proxy, loop: Optional[asyncio.AbstractEventLoop] = None) -> None:
+ super().__init__(ipv6, proxy, loop)
async def connect(self, address: Tuple[str, int]) -> None:
await super().connect(address)
diff --git a/pyrogram/connection/transport/tcp/tcp_intermediate_o.py b/pyrogram/connection/transport/tcp/tcp_intermediate_o.py
index 61131b33..8a186ef4 100644
--- a/pyrogram/connection/transport/tcp/tcp_intermediate_o.py
+++ b/pyrogram/connection/transport/tcp/tcp_intermediate_o.py
@@ -17,6 +17,7 @@
# You should have received a copy of the GNU Lesser General Public License
# along with Pyrofork. If not, see .
+import asyncio
import logging
import os
from struct import pack, unpack
@@ -31,8 +32,8 @@ log = logging.getLogger(__name__)
class TCPIntermediateO(TCP):
RESERVED = (b"HEAD", b"POST", b"GET ", b"OPTI", b"\xee" * 4)
- def __init__(self, ipv6: bool, proxy: Proxy) -> None:
- super().__init__(ipv6, proxy)
+ def __init__(self, ipv6: bool, proxy: Proxy, loop: Optional[asyncio.AbstractEventLoop] = None) -> None:
+ super().__init__(ipv6, proxy, loop)
self.encrypt = None
self.decrypt = None
diff --git a/pyrogram/dispatcher.py b/pyrogram/dispatcher.py
index c5f166d5..7d201d60 100644
--- a/pyrogram/dispatcher.py
+++ b/pyrogram/dispatcher.py
@@ -95,7 +95,6 @@ class Dispatcher:
def __init__(self, client: "pyrogram.Client"):
self.client = client
- self.loop = asyncio.get_event_loop()
self.handler_worker_tasks = []
self.locks_list = []
@@ -271,7 +270,7 @@ class Dispatcher:
self.locks_list.append(asyncio.Lock())
self.handler_worker_tasks.append(
- self.loop.create_task(self.handler_worker(self.locks_list[-1]))
+ self.client.loop.create_task(self.handler_worker(self.locks_list[-1]))
)
log.info("Started %s HandlerTasks", self.client.workers)
@@ -311,7 +310,7 @@ class Dispatcher:
for lock in self.locks_list:
lock.release()
- self.loop.create_task(fn())
+ self.client.loop.create_task(fn())
def remove_handler(self, handler, group: int):
async def fn():
@@ -333,7 +332,7 @@ class Dispatcher:
for lock in self.locks_list:
lock.release()
- self.loop.create_task(fn())
+ self.client.loop.create_task(fn())
async def handler_worker(self, lock: asyncio.Lock):
while True:
@@ -404,7 +403,7 @@ class Dispatcher:
if inspect.iscoroutinefunction(handler.callback):
await handler.callback(self.client, *args)
else:
- await self.loop.run_in_executor(
+ await self.client.loop.run_in_executor(
self.client.executor,
handler.callback,
self.client,
diff --git a/pyrogram/methods/messages/download_media.py b/pyrogram/methods/messages/download_media.py
index 6ba62526..2099e677 100644
--- a/pyrogram/methods/messages/download_media.py
+++ b/pyrogram/methods/messages/download_media.py
@@ -185,4 +185,4 @@ class DownloadMedia:
if block:
return await downloader
else:
- asyncio.get_event_loop().create_task(downloader)
+ self.loop.create_task(downloader)
diff --git a/pyrogram/session/auth.py b/pyrogram/session/auth.py
index 346b0f92..23b6e8d7 100644
--- a/pyrogram/session/auth.py
+++ b/pyrogram/session/auth.py
@@ -52,6 +52,7 @@ class Auth:
self.proxy = client.proxy
self.connection_factory = client.connection_factory
self.protocol_factory = client.protocol_factory
+ self.loop = client.loop
self.connection: Optional[Connection] = None
@@ -93,7 +94,8 @@ class Auth:
alt_port=self.alt_port,
proxy=self.proxy,
media=False,
- protocol_factory=self.protocol_factory
+ protocol_factory=self.protocol_factory,
+ loop=self.loop
)
try:
diff --git a/pyrogram/session/session.py b/pyrogram/session/session.py
index 79801086..3424b42e 100644
--- a/pyrogram/session/session.py
+++ b/pyrogram/session/session.py
@@ -104,8 +104,6 @@ class Session:
self.is_started = asyncio.Event()
- self.loop = asyncio.get_event_loop()
-
self.last_reconnect_attempt = None
async def start(self):
@@ -117,13 +115,14 @@ class Session:
alt_port=self.client.alt_port,
proxy=self.client.proxy,
media=self.is_media,
- protocol_factory=self.client.protocol_factory
+ protocol_factory=self.client.protocol_factory,
+ loop=self.client.loop
)
try:
await self.connection.connect()
- self.recv_task = self.loop.create_task(self.recv_worker())
+ self.recv_task = self.client.loop.create_task(self.recv_worker())
await self.send(raw.functions.Ping(ping_id=0), timeout=self.START_TIMEOUT)
@@ -145,7 +144,7 @@ class Session:
timeout=self.START_TIMEOUT
)
- self.ping_task = self.loop.create_task(self.ping_worker())
+ self.ping_task = self.client.loop.create_task(self.ping_worker())
log.info("Session initialized: Layer %s", layer)
log.info("Device: %s - %s", self.client.device_model, self.client.app_version)
@@ -207,7 +206,7 @@ class Session:
async def handle_packet(self, packet):
try:
- data = await self.loop.run_in_executor(
+ data = await self.client.loop.run_in_executor(
pyrogram.crypto_executor,
mtproto.unpack,
BytesIO(packet),
@@ -217,7 +216,7 @@ class Session:
)
except ValueError as e:
log.debug(e)
- self.loop.create_task(self.restart())
+ self.client.loop.create_task(self.restart())
return
messages = (
@@ -279,7 +278,7 @@ class Session:
msg_id = msg.body.msg_id
else:
if self.client is not None:
- self.loop.create_task(self.client.handle_updates(msg.body))
+ self.client.loop.create_task(self.client.handle_updates(msg.body))
if msg_id in self.results:
self.results[msg_id].value = getattr(msg.body, "result", msg.body)
@@ -313,7 +312,7 @@ class Session:
), False
)
except OSError:
- self.loop.create_task(self.restart())
+ self.client.loop.create_task(self.restart())
break
except RPCError:
pass
@@ -342,11 +341,11 @@ class Session:
)
if self.is_started.is_set():
- self.loop.create_task(self.restart())
+ self.client.loop.create_task(self.restart())
break
- self.loop.create_task(self.handle_packet(packet))
+ self.client.loop.create_task(self.handle_packet(packet))
log.info("NetworkTask stopped")
@@ -359,7 +358,7 @@ class Session:
log.debug("Sent: %s", message)
- payload = await self.loop.run_in_executor(
+ payload = await self.client.loop.run_in_executor(
pyrogram.crypto_executor,
mtproto.pack,
message,
diff --git a/pyrogram/utils.py b/pyrogram/utils.py
index 4ad6c5c2..fca5dbb6 100644
--- a/pyrogram/utils.py
+++ b/pyrogram/utils.py
@@ -43,13 +43,21 @@ PyromodConfig = SimpleNamespace(
unallowed_click_alert=True,
unallowed_click_alert_text=("[pyromod] You're not expected to click this button."),
)
-
-async def ainput(prompt: str = "", *, hide: bool = False):
+
+async def ainput(prompt: str = "", *, hide: bool = False, loop: Optional[asyncio.AbstractEventLoop] = None):
"""Just like the built-in input, but async"""
+ if isinstance(loop, asyncio.AbstractEventLoop):
+ loop = loop
+ else:
+ try:
+ loop = asyncio.get_running_loop()
+ except RuntimeError:
+ loop = asyncio.new_event_loop()
+
with ThreadPoolExecutor(1) as executor:
func = functools.partial(getpass if hide else input, prompt)
- return await asyncio.get_event_loop().run_in_executor(executor, func)
+ return await loop.run_in_executor(executor, func)
def get_input_media_from_file_id(