Compare commits

...

11 commits

Author SHA1 Message Date
wulan17
fd17d1ec5d
pyrofork: Refactor test_mode
Some checks failed
Pyrofork / build (macos-latest, 3.10) (push) Has been cancelled
Pyrofork / build (macos-latest, 3.11) (push) Has been cancelled
Pyrofork / build (macos-latest, 3.12) (push) Has been cancelled
Pyrofork / build (macos-latest, 3.13) (push) Has been cancelled
Pyrofork / build (macos-latest, 3.9) (push) Has been cancelled
Pyrofork / build (ubuntu-latest, 3.10) (push) Has been cancelled
Pyrofork / build (ubuntu-latest, 3.11) (push) Has been cancelled
Pyrofork / build (ubuntu-latest, 3.12) (push) Has been cancelled
Pyrofork / build (ubuntu-latest, 3.13) (push) Has been cancelled
Pyrofork / build (ubuntu-latest, 3.9) (push) Has been cancelled
Signed-off-by: wulan17 <wulan17@komodos.id>
2025-06-10 21:37:41 +07:00
wulan17
6e9e1740b0
pyrofork: Retrive dc address and port from GetConfig
Signed-off-by: wulan17 <wulan17@komodos.id>
2025-06-10 21:35:46 +07:00
wulan17
375e97d963
Reapply "fix: handle connection closure and retry logic in session management"
This reverts commit 2c3fb1caa6.

Signed-off-by: wulan17 <wulan17@komodos.id>
2025-06-07 21:01:42 +07:00
KurimuzonAkuma
3aed1857d4
Fix 'Client' object has no attribute 'loop'
Signed-off-by: wulan17 <wulan17@komodos.id>
2025-06-07 21:01:42 +07:00
KurimuzonAkuma
3d71eba4e1
Fix attached to a different loop
Signed-off-by: wulan17 <wulan17@komodos.id>
2025-06-07 21:01:42 +07:00
KurimuzonAkuma
1a4c578380
Refactor loop
Signed-off-by: wulan17 <wulan17@komodos.id>
2025-06-07 21:01:41 +07:00
KurimuzonAkuma
e3a8a781d3
Attempt to fix handler is closed
Signed-off-by: wulan17 <wulan17@komodos.id>
2025-06-07 21:01:41 +07:00
S!R X
921b593285
Fix connection to proxy
Fix the proxy bug.

Signed-off-by: wulan17 <wulan17@komodos.id>
2025-06-07 21:01:40 +07:00
Hitalo M.
b8028541c9
fix(session): prevent task cancellation race condition in stop method
The fix properly cancels the recv_task and suppresses CancelledError when awaiting
it during session shutdown. This resolves the "read() called while another
coroutine is already waiting for incoming data" RuntimeError that occurred when
stopping sessions during reconnection attempts.

Signed-off-by: wulan17 <wulan17@komodos.id>
2025-06-07 20:52:34 +07:00
Hitalo M.
01e7717e52
refactor(session): replace recursion with loop and add backoff
This refactor replaces recursion with a loop in the session invoke logic. Additionally, a backoff mechanism has been introduced to prevent frequent restarts from crashing the bot.

Signed-off-by: wulan17 <wulan17@komodos.id>
2025-06-07 19:59:42 +07:00
wulan17
b79ffac690
pyrofork: disable publish workflows
Some checks failed
Build-docs / build (push) Has been cancelled
Pyrofork / build (macos-latest, 3.10) (push) Has been cancelled
Pyrofork / build (macos-latest, 3.11) (push) Has been cancelled
Pyrofork / build (macos-latest, 3.12) (push) Has been cancelled
Pyrofork / build (macos-latest, 3.13) (push) Has been cancelled
Pyrofork / build (macos-latest, 3.9) (push) Has been cancelled
Pyrofork / build (ubuntu-latest, 3.10) (push) Has been cancelled
Pyrofork / build (ubuntu-latest, 3.11) (push) Has been cancelled
Pyrofork / build (ubuntu-latest, 3.12) (push) Has been cancelled
Pyrofork / build (ubuntu-latest, 3.13) (push) Has been cancelled
Pyrofork / build (ubuntu-latest, 3.9) (push) Has been cancelled
Signed-off-by: wulan17 <wulan17@nusantararom.org>
2025-06-06 19:16:53 +07:00
23 changed files with 394 additions and 105 deletions

View file

@ -1,40 +0,0 @@
# This workflow will upload a Python Package using Twine when a release is created
# For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python#publishing-to-package-registries
# This workflow uses actions that are not certified by GitHub.
# They are provided by a third-party and are governed by
# separate terms of service, privacy policy, and support
# documentation.
name: Upload Python Package
on:
push:
tags:
- '*'
permissions:
contents: read
jobs:
deploy:
runs-on: ubuntu-latest
environment: release
permissions:
id-token: write
steps:
- uses: actions/checkout@v4
- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: '3.10'
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install -e '.[dev]'
- name: Build package
run: hatch build
- name: Publish package
uses: pypa/gh-action-pypi-publish@release/v1

View file

@ -33,7 +33,7 @@ from importlib import import_module
from io import StringIO, BytesIO from io import StringIO, BytesIO
from mimetypes import MimeTypes from mimetypes import MimeTypes
from pathlib import Path from pathlib import Path
from typing import Union, List, Optional, Callable, AsyncGenerator, Tuple from typing import Union, List, Optional, Callable, AsyncGenerator, Type, Tuple
import pyrogram import pyrogram
from pyrogram import __version__, __license__ from pyrogram import __version__, __license__
@ -55,12 +55,13 @@ from pyrogram.storage import FileStorage, MemoryStorage, Storage
from pyrogram.types import User from pyrogram.types import User
from pyrogram.utils import ainput from pyrogram.utils import ainput
from .connection import Connection from .connection import Connection
from .connection.transport import TCPAbridged from .connection.transport import TCP, TCPAbridged
from .dispatcher import Dispatcher from .dispatcher import Dispatcher
from .file_id import FileId, FileType, ThumbnailSource from .file_id import FileId, FileType, ThumbnailSource
from .mime_types import mime_types from .mime_types import mime_types
from .parser import Parser from .parser import Parser
from .session.internals import MsgId from .session.internals import MsgId
from .session.internals.data_center import DataCenter
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
MONGO_AVAIL = False MONGO_AVAIL = False
@ -220,6 +221,9 @@ class Client(Methods):
client_platform (:obj:`~pyrogram.enums.ClientPlatform`, *optional*): client_platform (:obj:`~pyrogram.enums.ClientPlatform`, *optional*):
The platform where this client is running. The platform where this client is running.
Defaults to 'other' Defaults to 'other'
loop (:py:class:`asyncio.AbstractEventLoop`, *optional*):
Event loop.
""" """
APP_VERSION = f"Pyrogram {__version__}" APP_VERSION = f"Pyrogram {__version__}"
@ -277,7 +281,9 @@ class Client(Methods):
max_concurrent_transmissions: int = MAX_CONCURRENT_TRANSMISSIONS, max_concurrent_transmissions: int = MAX_CONCURRENT_TRANSMISSIONS,
client_platform: "enums.ClientPlatform" = enums.ClientPlatform.OTHER, client_platform: "enums.ClientPlatform" = enums.ClientPlatform.OTHER,
max_message_cache_size: int = MAX_CACHE_SIZE, max_message_cache_size: int = MAX_CACHE_SIZE,
max_business_user_connection_cache_size: int = MAX_CACHE_SIZE max_business_user_connection_cache_size: int = MAX_CACHE_SIZE,
protocol_factory: Type[TCP] = TCPAbridged,
loop: Optional[asyncio.AbstractEventLoop] = None
): ):
super().__init__() super().__init__()
@ -314,6 +320,8 @@ class Client(Methods):
self.max_message_cache_size = max_message_cache_size self.max_message_cache_size = max_message_cache_size
self.max_message_cache_size = max_message_cache_size self.max_message_cache_size = max_message_cache_size
self.max_business_user_connection_cache_size = max_business_user_connection_cache_size self.max_business_user_connection_cache_size = max_business_user_connection_cache_size
self.test_addr = None
self.test_port = None
self.executor = ThreadPoolExecutor(self.workers, thread_name_prefix="Handler") self.executor = ThreadPoolExecutor(self.workers, thread_name_prefix="Handler")
@ -371,6 +379,10 @@ class Client(Methods):
self.updates_watchdog_event = asyncio.Event() self.updates_watchdog_event = asyncio.Event()
self.last_update_time = datetime.now() self.last_update_time = datetime.now()
self.listeners = {listener_type: [] for listener_type in pyrogram.enums.ListenerTypes} self.listeners = {listener_type: [] for listener_type in pyrogram.enums.ListenerTypes}
if isinstance(loop, asyncio.AbstractEventLoop):
self.loop = loop
else:
self.loop = asyncio.get_event_loop() self.loop = asyncio.get_event_loop()
def __enter__(self): def __enter__(self):
@ -391,6 +403,25 @@ class Client(Methods):
except ConnectionError: except ConnectionError:
pass pass
def set_dc(self, addr: str, port: int = 80):
"""Set the data center address and port.
Parameters:
addr (``str``):
The data center address, e.g.: "149.154.167.40".
port (``int``, *optional*):
The data center port, e.g.: 443.
Defaults to 80.
"""
if not isinstance(addr, str):
raise TypeError("addr must be a string")
if not isinstance(port, int):
raise TypeError("port must be an integer")
self.test_addr = addr
self.test_port = port
async def updates_watchdog(self): async def updates_watchdog(self):
while True: while True:
try: try:
@ -425,7 +456,7 @@ class Client(Methods):
if not self.phone_number: if not self.phone_number:
while True: while True:
print("Enter 'qrcode' if you want to login with qrcode.") print("Enter 'qrcode' if you want to login with qrcode.")
value = await ainput("Enter phone number or bot token: ") value = await ainput("Enter phone number or bot token: ", loop=self.loop)
if not value: if not value:
continue continue
@ -434,7 +465,7 @@ class Client(Methods):
self.use_qrcode = True self.use_qrcode = True
break break
confirm = (await ainput(f'Is "{value}" correct? (y/N): ')).lower() confirm = (await ainput(f'Is "{value}" correct? (y/N): ', loop=self.loop)).lower()
if confirm == "y": if confirm == "y":
break break
@ -466,7 +497,7 @@ class Client(Methods):
while True: while True:
if not self.use_qrcode and not self.phone_code: if not self.use_qrcode and not self.phone_code:
self.phone_code = await ainput("Enter confirmation code: ") self.phone_code = await ainput("Enter confirmation code: ", loop=self.loop)
try: try:
if self.use_qrcode: if self.use_qrcode:
@ -483,18 +514,18 @@ class Client(Methods):
print("Password hint: {}".format(await self.get_password_hint())) print("Password hint: {}".format(await self.get_password_hint()))
if not self.password: if not self.password:
self.password = await ainput("Enter password (empty to recover): ", hide=self.hide_password) self.password = await ainput("Enter password (empty to recover): ", hide=self.hide_password, loop=self.loop)
try: try:
if not self.password: if not self.password:
confirm = await ainput("Confirm password recovery (y/n): ") confirm = await ainput("Confirm password recovery (y/n): ", loop=self.loop)
if confirm == "y": if confirm == "y":
email_pattern = await self.send_recovery_code() email_pattern = await self.send_recovery_code()
print(f"The recovery code has been sent to {email_pattern}") print(f"The recovery code has been sent to {email_pattern}")
while True: while True:
recovery_code = await ainput("Enter recovery code: ") recovery_code = await ainput("Enter recovery code: ", loop=self.loop)
try: try:
return await self.recover_password(recovery_code) return await self.recover_password(recovery_code)
@ -521,6 +552,16 @@ class Client(Methods):
break break
if isinstance(signed_in, User): if isinstance(signed_in, User):
if not self.test_mode:
is_dc_default = await self.check_dc_default()
if is_dc_default:
log.info("Your session is using the default data center.")
log.info("Updating the data center options from GetConfig...")
await self.update_dc_option()
log.info("Data center updated successfully.")
log.info("Restarting the session to apply the changes...")
await self.session.stop()
await self.session.start()
return signed_in return signed_in
def set_parse_mode(self, parse_mode: Optional["enums.ParseMode"]): def set_parse_mode(self, parse_mode: Optional["enums.ParseMode"]):
@ -842,19 +883,64 @@ class Client(Methods):
else: else:
while True: while True:
try: try:
value = int(await ainput("Enter the api_id part of the API key: ")) value = int(await ainput("Enter the api_id part of the API key: ", loop=self.loop))
if value <= 0: if value <= 0:
print("Invalid value") print("Invalid value")
continue continue
confirm = (await ainput(f'Is "{value}" correct? (y/N): ')).lower() confirm = (await ainput(f'Is "{value}" correct? (y/N): ', loop=self.loop)).lower()
if confirm == "y": if confirm == "y":
await self.storage.api_id(value) await self.storage.api_id(value)
break break
except Exception as e: except Exception as e:
print(e) print(e)
# Needed for migration from storage v3 to v4
if self.in_memory or self.session_string:
await self.insert_default_dc_options()
if not await self.storage.get_dc_address(await self.storage.dc_id(), self.ipv6):
log.info("No DC address found, inserting default DC options...")
await self.insert_default_dc_options()
async def insert_default_dc_options(self):
for dc_id in range(1, 6):
for is_ipv6 in (False, True):
if dc_id in [2,4]:
for media in (False, True):
address, port = DataCenter(dc_id, False, is_ipv6, self.alt_port, media)
await self.storage.update_dc_address(
(dc_id, address, port, is_ipv6, media)
)
else:
address, port = DataCenter(dc_id, False, is_ipv6, False, False)
await self.storage.update_dc_address(
(dc_id, address, port, is_ipv6, False)
)
async def update_dc_option(self):
config = await self.invoke(raw.functions.help.GetConfig())
for option in config.dc_options:
await self.storage.update_dc_address(
(option.id, option.address, option.port, option.is_ipv6, option.is_media)
)
async def check_dc_default(
self,
dc_id: int,
is_ipv6: bool,
media: bool = False
) -> bool:
default_dc = DataCenter(
dc_id,
test_mode=False,
is_ipv6=is_ipv6,
alt_port=self.alt_port,
media=media
)
current_dc = await self.storage.get_dc_address(dc_id, is_ipv6, media)
if current_dc is not None and current_dc[0] == default_dc.address:
return True
def is_excluded(self, exclude, module): def is_excluded(self, exclude, module):
for e in exclude: for e in exclude:

View file

@ -22,7 +22,6 @@ import logging
from typing import Optional, Type from typing import Optional, Type
from .transport import TCP, TCPAbridged from .transport import TCP, TCPAbridged
from ..session.internals import DataCenter
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
@ -34,26 +33,35 @@ class Connection:
self, self,
dc_id: int, dc_id: int,
test_mode: bool, test_mode: bool,
server_ip: str,
server_port: int,
ipv6: bool, ipv6: bool,
alt_port: bool, alt_port: bool,
proxy: dict, proxy: dict,
media: bool = False, media: bool = False,
protocol_factory: Type[TCP] = TCPAbridged protocol_factory: Type[TCP] = TCPAbridged,
loop: Optional[asyncio.AbstractEventLoop] = None
) -> None: ) -> None:
self.dc_id = dc_id self.dc_id = dc_id
self.test_mode = test_mode self.test_mode = test_mode
self.server_ip = server_ip
self.server_port = 5222 if alt_port else server_port
self.ipv6 = ipv6 self.ipv6 = ipv6
self.alt_port = alt_port
self.proxy = proxy self.proxy = proxy
self.media = media self.media = media
self.protocol_factory = protocol_factory self.protocol_factory = protocol_factory
self.address = DataCenter(dc_id, test_mode, ipv6, alt_port, media) self.address = (server_ip, server_port)
self.protocol: Optional[TCP] = None self.protocol: Optional[TCP] = None
if isinstance(loop, asyncio.AbstractEventLoop):
self.loop = loop
else:
self.loop = asyncio.get_event_loop()
async def connect(self) -> None: async def connect(self) -> None:
for _ in range(Connection.MAX_CONNECTION_ATTEMPTS): for _ in range(Connection.MAX_CONNECTION_ATTEMPTS):
self.protocol = self.protocol_factory(ipv6=self.ipv6, proxy=self.proxy) self.protocol = self.protocol_factory(ipv6=self.ipv6, proxy=self.proxy, loop=self.loop)
try: try:
log.info("Connecting...") log.info("Connecting...")

View file

@ -21,6 +21,7 @@ import asyncio
import ipaddress import ipaddress
import logging import logging
import socket import socket
from concurrent.futures import ThreadPoolExecutor
from typing import Tuple, Dict, TypedDict, Optional from typing import Tuple, Dict, TypedDict, Optional
import socks import socks
@ -45,7 +46,7 @@ class Proxy(TypedDict):
class TCP: class TCP:
TIMEOUT = 10 TIMEOUT = 10
def __init__(self, ipv6: bool, proxy: Proxy) -> None: def __init__(self, ipv6: bool, proxy: Proxy, loop: Optional[asyncio.AbstractEventLoop] = None) -> None:
self.ipv6 = ipv6 self.ipv6 = ipv6
self.proxy = proxy self.proxy = proxy
@ -53,7 +54,18 @@ class TCP:
self.writer: Optional[asyncio.StreamWriter] = None self.writer: Optional[asyncio.StreamWriter] = None
self.lock = asyncio.Lock() self.lock = asyncio.Lock()
if isinstance(loop, asyncio.AbstractEventLoop):
self.loop = loop
else:
self.loop = asyncio.get_event_loop() self.loop = asyncio.get_event_loop()
self._closed = True
@property
def closed(self) -> bool:
return (
self._closed or self.writer is None or self.writer.is_closing() or self.reader is None
)
async def _connect_via_proxy( async def _connect_via_proxy(
self, self,
@ -91,10 +103,8 @@ class TCP:
) )
sock.settimeout(TCP.TIMEOUT) sock.settimeout(TCP.TIMEOUT)
await self.loop.sock_connect( with ThreadPoolExecutor() as executor:
sock=sock, await self.loop.run_in_executor(executor, sock.connect, destination)
address=destination
)
sock.setblocking(False) sock.setblocking(False)
@ -123,11 +133,14 @@ class TCP:
async def connect(self, address: Tuple[str, int]) -> None: async def connect(self, address: Tuple[str, int]) -> None:
try: try:
await asyncio.wait_for(self._connect(address), TCP.TIMEOUT) await asyncio.wait_for(self._connect(address), TCP.TIMEOUT)
self._closed = False
except asyncio.TimeoutError: # Re-raise as TimeoutError. asyncio.TimeoutError is deprecated in 3.11 except asyncio.TimeoutError: # Re-raise as TimeoutError. asyncio.TimeoutError is deprecated in 3.11
self._closed = True
raise TimeoutError("Connection timed out") raise TimeoutError("Connection timed out")
async def close(self) -> None: async def close(self) -> None:
if self.writer is None: if self.writer is None:
self._closed = True
return None return None
try: try:
@ -135,20 +148,27 @@ class TCP:
await asyncio.wait_for(self.writer.wait_closed(), TCP.TIMEOUT) await asyncio.wait_for(self.writer.wait_closed(), TCP.TIMEOUT)
except Exception as e: except Exception as e:
log.info("Close exception: %s %s", type(e).__name__, e) log.info("Close exception: %s %s", type(e).__name__, e)
finally:
self.writer = None
self._closed = True
async def send(self, data: bytes) -> None: async def send(self, data: bytes) -> None:
if self.writer is None: async with self.lock:
if self.writer is None or self.writer.is_closing():
return None return None
async with self.lock:
try: try:
self.writer.write(data) self.writer.write(data)
await self.writer.drain() await self.writer.drain()
except Exception as e: except Exception as e:
log.info("Send exception: %s %s", type(e).__name__, e) log.info("Send exception: %s %s", type(e).__name__, e)
self._closed = True
raise OSError(e) raise OSError(e)
async def recv(self, length: int = 0) -> Optional[bytes]: async def recv(self, length: int = 0) -> Optional[bytes]:
if self._closed or self.reader is None:
return None
data = b"" data = b""
while len(data) < length: while len(data) < length:
@ -158,11 +178,13 @@ class TCP:
TCP.TIMEOUT TCP.TIMEOUT
) )
except (OSError, asyncio.TimeoutError): except (OSError, asyncio.TimeoutError):
self._closed = True
return None return None
else: else:
if chunk: if chunk:
data += chunk data += chunk
else: else:
self._closed = True
return None return None
return data return data

View file

@ -17,6 +17,7 @@
# You should have received a copy of the GNU Lesser General Public License # You should have received a copy of the GNU Lesser General Public License
# along with Pyrofork. If not, see <http://www.gnu.org/licenses/>. # along with Pyrofork. If not, see <http://www.gnu.org/licenses/>.
import asyncio
import logging import logging
from typing import Optional, Tuple from typing import Optional, Tuple
@ -26,8 +27,8 @@ log = logging.getLogger(__name__)
class TCPAbridged(TCP): class TCPAbridged(TCP):
def __init__(self, ipv6: bool, proxy: Proxy) -> None: def __init__(self, ipv6: bool, proxy: Proxy, loop: Optional[asyncio.AbstractEventLoop] = None) -> None:
super().__init__(ipv6, proxy) super().__init__(ipv6, proxy, loop)
async def connect(self, address: Tuple[str, int]) -> None: async def connect(self, address: Tuple[str, int]) -> None:
await super().connect(address) await super().connect(address)

View file

@ -17,6 +17,7 @@
# You should have received a copy of the GNU Lesser General Public License # You should have received a copy of the GNU Lesser General Public License
# along with Pyrofork. If not, see <http://www.gnu.org/licenses/>. # along with Pyrofork. If not, see <http://www.gnu.org/licenses/>.
import asyncio
import logging import logging
import os import os
from typing import Optional, Tuple from typing import Optional, Tuple
@ -31,8 +32,8 @@ log = logging.getLogger(__name__)
class TCPAbridgedO(TCP): class TCPAbridgedO(TCP):
RESERVED = (b"HEAD", b"POST", b"GET ", b"OPTI", b"\xee" * 4) RESERVED = (b"HEAD", b"POST", b"GET ", b"OPTI", b"\xee" * 4)
def __init__(self, ipv6: bool, proxy: Proxy) -> None: def __init__(self, ipv6: bool, proxy: Proxy, loop: Optional[asyncio.AbstractEventLoop] = None) -> None:
super().__init__(ipv6, proxy) super().__init__(ipv6, proxy, loop)
self.encrypt = None self.encrypt = None
self.decrypt = None self.decrypt = None

View file

@ -17,6 +17,7 @@
# You should have received a copy of the GNU Lesser General Public License # You should have received a copy of the GNU Lesser General Public License
# along with Pyrofork. If not, see <http://www.gnu.org/licenses/>. # along with Pyrofork. If not, see <http://www.gnu.org/licenses/>.
import asyncio
import logging import logging
from binascii import crc32 from binascii import crc32
from struct import pack, unpack from struct import pack, unpack
@ -28,8 +29,8 @@ log = logging.getLogger(__name__)
class TCPFull(TCP): class TCPFull(TCP):
def __init__(self, ipv6: bool, proxy: Proxy) -> None: def __init__(self, ipv6: bool, proxy: Proxy, loop: Optional[asyncio.AbstractEventLoop] = None) -> None:
super().__init__(ipv6, proxy) super().__init__(ipv6, proxy, loop)
self.seq_no: Optional[int] = None self.seq_no: Optional[int] = None

View file

@ -17,6 +17,7 @@
# You should have received a copy of the GNU Lesser General Public License # You should have received a copy of the GNU Lesser General Public License
# along with Pyrofork. If not, see <http://www.gnu.org/licenses/>. # along with Pyrofork. If not, see <http://www.gnu.org/licenses/>.
import asyncio
import logging import logging
from struct import pack, unpack from struct import pack, unpack
from typing import Optional, Tuple from typing import Optional, Tuple
@ -27,8 +28,8 @@ log = logging.getLogger(__name__)
class TCPIntermediate(TCP): class TCPIntermediate(TCP):
def __init__(self, ipv6: bool, proxy: Proxy) -> None: def __init__(self, ipv6: bool, proxy: Proxy, loop: Optional[asyncio.AbstractEventLoop] = None) -> None:
super().__init__(ipv6, proxy) super().__init__(ipv6, proxy, loop)
async def connect(self, address: Tuple[str, int]) -> None: async def connect(self, address: Tuple[str, int]) -> None:
await super().connect(address) await super().connect(address)

View file

@ -17,6 +17,7 @@
# You should have received a copy of the GNU Lesser General Public License # You should have received a copy of the GNU Lesser General Public License
# along with Pyrofork. If not, see <http://www.gnu.org/licenses/>. # along with Pyrofork. If not, see <http://www.gnu.org/licenses/>.
import asyncio
import logging import logging
import os import os
from struct import pack, unpack from struct import pack, unpack
@ -31,8 +32,8 @@ log = logging.getLogger(__name__)
class TCPIntermediateO(TCP): class TCPIntermediateO(TCP):
RESERVED = (b"HEAD", b"POST", b"GET ", b"OPTI", b"\xee" * 4) RESERVED = (b"HEAD", b"POST", b"GET ", b"OPTI", b"\xee" * 4)
def __init__(self, ipv6: bool, proxy: Proxy) -> None: def __init__(self, ipv6: bool, proxy: Proxy, loop: Optional[asyncio.AbstractEventLoop] = None) -> None:
super().__init__(ipv6, proxy) super().__init__(ipv6, proxy, loop)
self.encrypt = None self.encrypt = None
self.decrypt = None self.decrypt = None

View file

@ -95,7 +95,6 @@ class Dispatcher:
def __init__(self, client: "pyrogram.Client"): def __init__(self, client: "pyrogram.Client"):
self.client = client self.client = client
self.loop = asyncio.get_event_loop()
self.handler_worker_tasks = [] self.handler_worker_tasks = []
self.locks_list = [] self.locks_list = []
@ -271,7 +270,7 @@ class Dispatcher:
self.locks_list.append(asyncio.Lock()) self.locks_list.append(asyncio.Lock())
self.handler_worker_tasks.append( self.handler_worker_tasks.append(
self.loop.create_task(self.handler_worker(self.locks_list[-1])) self.client.loop.create_task(self.handler_worker(self.locks_list[-1]))
) )
log.info("Started %s HandlerTasks", self.client.workers) log.info("Started %s HandlerTasks", self.client.workers)
@ -311,7 +310,7 @@ class Dispatcher:
for lock in self.locks_list: for lock in self.locks_list:
lock.release() lock.release()
self.loop.create_task(fn()) self.client.loop.create_task(fn())
def remove_handler(self, handler, group: int): def remove_handler(self, handler, group: int):
async def fn(): async def fn():
@ -333,7 +332,7 @@ class Dispatcher:
for lock in self.locks_list: for lock in self.locks_list:
lock.release() lock.release()
self.loop.create_task(fn()) self.client.loop.create_task(fn())
async def handler_worker(self, lock: asyncio.Lock): async def handler_worker(self, lock: asyncio.Lock):
while True: while True:
@ -404,7 +403,7 @@ class Dispatcher:
if inspect.iscoroutinefunction(handler.callback): if inspect.iscoroutinefunction(handler.callback):
await handler.callback(self.client, *args) await handler.callback(self.client, *args)
else: else:
await self.loop.run_in_executor( await self.client.loop.run_in_executor(
self.client.executor, self.client.executor,
handler.callback, handler.callback,
self.client, self.client,

View file

@ -49,6 +49,6 @@ class Initialize:
await self.dispatcher.start() await self.dispatcher.start()
self.updates_watchdog_task = asyncio.create_task(self.updates_watchdog()) self.updates_watchdog_task = self.loop.create_task(self.updates_watchdog())
self.is_initialized = True self.is_initialized = True

View file

@ -61,6 +61,8 @@ class SendCode:
) )
) )
except (PhoneMigrate, NetworkMigrate) as e: except (PhoneMigrate, NetworkMigrate) as e:
if not self.test_mode:
await self.update_dc_option()
# pylint: disable=access-member-before-definition # pylint: disable=access-member-before-definition
await self.session.stop() await self.session.stop()

View file

@ -58,6 +58,8 @@ class SignInBot:
) )
) )
except UserMigrate as e: except UserMigrate as e:
if not self.test_mode:
await self.update_dc_option()
# pylint: disable=access-member-before-definition # pylint: disable=access-member-before-definition
await self.session.stop() await self.session.stop()

View file

@ -80,6 +80,8 @@ class SignInQrcode:
return types.User._parse(self, r.authorization.user) return types.User._parse(self, r.authorization.user)
if isinstance(r, raw.types.auth.LoginTokenMigrateTo): if isinstance(r, raw.types.auth.LoginTokenMigrateTo):
if not self.test_mode:
await self.update_dc_option()
# pylint: disable=access-member-before-definition # pylint: disable=access-member-before-definition
await self.session.stop() await self.session.stop()

View file

@ -185,4 +185,4 @@ class DownloadMedia:
if block: if block:
return await downloader return await downloader
else: else:
asyncio.get_event_loop().create_task(downloader) self.loop.create_task(downloader)

View file

@ -71,8 +71,7 @@ class Run:
app.run(main()) app.run(main())
""" """
loop = asyncio.get_event_loop() run = self.loop.run_until_complete
run = loop.run_until_complete
if coroutine is not None: if coroutine is not None:
run(coroutine) run(coroutine)

View file

@ -45,13 +45,16 @@ class Auth:
dc_id: int, dc_id: int,
test_mode: bool test_mode: bool
): ):
self.client = client
self.dc_id = dc_id self.dc_id = dc_id
self.test_mode = test_mode self.test_mode = test_mode
self.ipv6 = client.ipv6 self.ipv6 = client.ipv6
self.alt_port = client.alt_port self.alt_port = client.alt_port
self.proxy = client.proxy self.proxy = client.proxy
self.storage = client.storage
self.connection_factory = client.connection_factory self.connection_factory = client.connection_factory
self.protocol_factory = client.protocol_factory self.protocol_factory = client.protocol_factory
self.loop = client.loop
self.connection: Optional[Connection] = None self.connection: Optional[Connection] = None
@ -85,15 +88,25 @@ class Auth:
# The server may close the connection at any time, causing the auth key creation to fail. # The server may close the connection at any time, causing the auth key creation to fail.
# If that happens, just try again up to MAX_RETRIES times. # If that happens, just try again up to MAX_RETRIES times.
if not self.test_mode:
address, port = await self.storage.get_dc_address(self.dc_id, self.ipv6)
else:
address = self.client.test_addr
port = self.client.test_port
if address is None or port is None:
raise ValueError("Test address and port must be set for test mode.")
while True: while True:
self.connection = self.connection_factory( self.connection = self.connection_factory(
dc_id=self.dc_id, dc_id=self.dc_id,
test_mode=self.test_mode, test_mode=self.test_mode,
server_ip=address,
server_port=port,
ipv6=self.ipv6, ipv6=self.ipv6,
alt_port=self.alt_port, alt_port=self.alt_port,
proxy=self.proxy, proxy=self.proxy,
media=False, media=False,
protocol_factory=self.protocol_factory protocol_factory=self.protocol_factory,
loop=self.loop
) )
try: try:

View file

@ -19,8 +19,10 @@
import asyncio import asyncio
import bisect import bisect
import contextlib
import logging import logging
import os import os
from datetime import datetime, timedelta
from hashlib import sha1 from hashlib import sha1
from io import BytesIO from io import BytesIO
from typing import Optional from typing import Optional
@ -56,6 +58,7 @@ class Session:
ACKS_THRESHOLD = 10 ACKS_THRESHOLD = 10
PING_INTERVAL = 5 PING_INTERVAL = 5
STORED_MSG_IDS_MAX_SIZE = 500 STORED_MSG_IDS_MAX_SIZE = 500
RECONNECT_THRESHOLD = timedelta(seconds=10)
TRANSPORT_ERRORS = { TRANSPORT_ERRORS = {
404: "auth key not found", 404: "auth key not found",
@ -101,24 +104,34 @@ class Session:
self.is_started = asyncio.Event() self.is_started = asyncio.Event()
self.loop = asyncio.get_event_loop() self.last_reconnect_attempt = None
async def start(self): async def start(self):
if not self.test_mode:
address, port = await self.client.storage.get_dc_address(self.dc_id, self.client.ipv6, self.is_media)
else:
address = self.client.test_addr
port = self.client.test_port
if address is None or port is None:
raise ValueError("Test address and port must be set for test mode.")
while True: while True:
self.connection = self.client.connection_factory( self.connection = self.client.connection_factory(
dc_id=self.dc_id, dc_id=self.dc_id,
test_mode=self.test_mode, test_mode=self.test_mode,
server_ip=address,
server_port=port,
ipv6=self.client.ipv6, ipv6=self.client.ipv6,
alt_port=self.client.alt_port, alt_port=self.client.alt_port,
proxy=self.client.proxy, proxy=self.client.proxy,
media=self.is_media, media=self.is_media,
protocol_factory=self.client.protocol_factory protocol_factory=self.client.protocol_factory,
loop=self.client.loop
) )
try: try:
await self.connection.connect() await self.connection.connect()
self.recv_task = self.loop.create_task(self.recv_worker()) self.recv_task = self.client.loop.create_task(self.recv_worker())
await self.send(raw.functions.Ping(ping_id=0), timeout=self.START_TIMEOUT) await self.send(raw.functions.Ping(ping_id=0), timeout=self.START_TIMEOUT)
@ -140,7 +153,7 @@ class Session:
timeout=self.START_TIMEOUT timeout=self.START_TIMEOUT
) )
self.ping_task = self.loop.create_task(self.ping_worker()) self.ping_task = self.client.loop.create_task(self.ping_worker())
log.info("Session initialized: Layer %s", layer) log.info("Session initialized: Layer %s", layer)
log.info("Device: %s - %s", self.client.device_model, self.client.app_version) log.info("Device: %s - %s", self.client.device_model, self.client.app_version)
@ -175,6 +188,8 @@ class Session:
await self.connection.close() await self.connection.close()
if self.recv_task: if self.recv_task:
self.recv_task.cancel()
with contextlib.suppress(asyncio.CancelledError):
await self.recv_task await self.recv_task
if not self.is_media and callable(self.client.disconnect_handler): if not self.is_media and callable(self.client.disconnect_handler):
@ -186,12 +201,21 @@ class Session:
log.info("Session stopped") log.info("Session stopped")
async def restart(self): async def restart(self):
now = datetime.now()
if (
self.last_reconnect_attempt
and now - self.last_reconnect_attempt < self.RECONNECT_THRESHOLD
):
log.info("Reconnecting too frequently, sleeping for a while")
await asyncio.sleep(5)
self.last_reconnect_attempt = now
await self.stop() await self.stop()
await self.start() await self.start()
async def handle_packet(self, packet): async def handle_packet(self, packet):
try: try:
data = await self.loop.run_in_executor( data = await self.client.loop.run_in_executor(
pyrogram.crypto_executor, pyrogram.crypto_executor,
mtproto.unpack, mtproto.unpack,
BytesIO(packet), BytesIO(packet),
@ -201,7 +225,7 @@ class Session:
) )
except ValueError as e: except ValueError as e:
log.debug(e) log.debug(e)
self.loop.create_task(self.restart()) self.client.loop.create_task(self.restart())
return return
messages = ( messages = (
@ -263,7 +287,7 @@ class Session:
msg_id = msg.body.msg_id msg_id = msg.body.msg_id
else: else:
if self.client is not None: if self.client is not None:
self.loop.create_task(self.client.handle_updates(msg.body)) self.client.loop.create_task(self.client.handle_updates(msg.body))
if msg_id in self.results: if msg_id in self.results:
self.results[msg_id].value = getattr(msg.body, "result", msg.body) self.results[msg_id].value = getattr(msg.body, "result", msg.body)
@ -297,7 +321,7 @@ class Session:
), False ), False
) )
except OSError: except OSError:
self.loop.create_task(self.restart()) self.client.loop.create_task(self.restart())
break break
except RPCError: except RPCError:
pass pass
@ -326,11 +350,11 @@ class Session:
) )
if self.is_started.is_set(): if self.is_started.is_set():
self.loop.create_task(self.restart()) self.client.loop.create_task(self.restart())
break break
self.loop.create_task(self.handle_packet(packet)) self.client.loop.create_task(self.handle_packet(packet))
log.info("NetworkTask stopped") log.info("NetworkTask stopped")
@ -343,7 +367,7 @@ class Session:
log.debug("Sent: %s", message) log.debug("Sent: %s", message)
payload = await self.loop.run_in_executor( payload = await self.client.loop.run_in_executor(
pyrogram.crypto_executor, pyrogram.crypto_executor,
mtproto.pack, mtproto.pack,
message, message,
@ -415,8 +439,21 @@ class Session:
query_name = ".".join(inner_query.QUALNAME.split(".")[1:]) query_name = ".".join(inner_query.QUALNAME.split(".")[1:])
while True: while retries > 0:
try: try:
if (
self.connection is None
or self.connection.protocol is None
or getattr(self.connection.protocol, "closed", True)
):
log.warning(
"[%s] Connection is closed or not established. Attempting to reconnect...",
self.client.name,
)
await self.restart()
await asyncio.sleep(1)
continue
return await self.send(query, timeout=timeout) return await self.send(query, timeout=timeout)
except (FloodWait, FloodPremiumWait) as e: except (FloodWait, FloodPremiumWait) as e:
amount = e.value amount = e.value
@ -429,15 +466,26 @@ class Session:
await asyncio.sleep(amount) await asyncio.sleep(amount)
except (OSError, InternalServerError, ServiceUnavailable) as e: except (OSError, InternalServerError, ServiceUnavailable) as e:
retries -= 1
if retries == 0: if retries == 0:
raise e from None raise e
(log.warning if retries < 2 else log.info)( (log.warning if retries < 2 else log.info)(
'[%s] Retrying "%s" due to: %s', '[%s] Retrying "%s" due to: %s',
Session.MAX_RETRIES - retries + 1, Session.MAX_RETRIES - retries,
query_name, str(e) or repr(e) query_name, str(e) or repr(e)
) )
if isinstance(e, OSError) and retries > 1:
try:
await self.restart()
except Exception as restart_error:
log.warning(
"[%s] Failed to restart session: %s",
self.client.name,
str(restart_error) or repr(restart_error),
)
await asyncio.sleep(0.5) await asyncio.sleep(0.5)
return await self.invoke(query, retries - 1, timeout) raise TimeoutError("Exceeded maximum number of retries")

View file

@ -68,6 +68,12 @@ class FileStorage(SQLiteStorage):
version += 1 version += 1
if version == 4:
with self.conn:
self.conn.executescript(self.UPDATE_DC_SCHEMA)
version += 1
self.version(version) self.version(version)
async def open(self): async def open(self):

View file

@ -77,6 +77,7 @@ class MongoStorage(Storage):
self._session = database['session'] self._session = database['session']
self._usernames = database['usernames'] self._usernames = database['usernames']
self._states = database['update_state'] self._states = database['update_state']
self._dc_options = database['dc_options']
self._remove_peers = remove_peers self._remove_peers = remove_peers
async def open(self): async def open(self):
@ -221,6 +222,52 @@ class MongoStorage(Storage):
return get_input_peer(r['_id'], r['access_hash'], r['type']) return get_input_peer(r['_id'], r['access_hash'], r['type'])
async def update_dc_address(
self,
value: Tuple[int, str, int, bool, bool] = object
):
"""
Updates or inserts a data center address.
Parameters:
value (Tuple[int, str, int, bool]): A tuple containing:
- dc_id (int): Data center ID.
- address (str): Address of the data center.
- port (int): Port of the data center.
- is_ipv6 (bool): Whether the address is IPv6.
- is_media (bool): Whether it is a media data center.
"""
if value == object:
return
await self._dc_options.update_one(
{"$and": [
{'dc_id': value[0]},
{'is_ipv6': value[3]},
{'is_media': value[4]}
]},
{'$set': {'address': value[1], 'port': value[2]}},
upsert=True
)
async def get_dc_address(
self,
dc_id: int,
is_ipv6: bool,
media: bool = False
) -> Tuple[str, int]:
if dc_id in [1,3,5] and media:
media = False
if dc_id in [4,5] and test_mode:
test_mode = False
r = await self._dc_options.find_one(
{'dc_id': dc_id, 'is_ipv6': is_ipv6, 'is_media': media},
{'address': 1, 'port': 1}
)
if r is None:
return None
return r['address'], r['port']
async def _get(self): async def _get(self):
attr = inspect.stack()[2].function attr = inspect.stack()[2].function
d = await self._session.find_one({'_id': 0}, {attr: 1}) d = await self._session.find_one({'_id': 0}, {attr: 1})

View file

@ -97,6 +97,20 @@ END;
""" """
UPDATE_DC_SCHEMA = """
CREATE TABLE dc_options
(
id INTEGER PRIMARY KEY AUTOINCREMENT,
dc_id INTEGER,
address TEXT,
port INTEGER,
is_ipv6 BOOLEAN,
is_media BOOLEAN,
UNIQUE(dc_id, is_ipv6, is_media)
);
"""
def get_input_peer(peer_id: int, access_hash: int, peer_type: str): def get_input_peer(peer_id: int, access_hash: int, peer_type: str):
if peer_type in ["user", "bot"]: if peer_type in ["user", "bot"]:
return raw.types.InputPeerUser( return raw.types.InputPeerUser(
@ -121,6 +135,7 @@ def get_input_peer(peer_id: int, access_hash: int, peer_type: str):
class SQLiteStorage(Storage): class SQLiteStorage(Storage):
VERSION = 4 VERSION = 4
USERNAME_TTL = 8 * 60 * 60 USERNAME_TTL = 8 * 60 * 60
UPDATE_DC_SCHEMA = globals().get("UPDATE_DC_SCHEMA", "")
def __init__(self, name: str): def __init__(self, name: str):
super().__init__(name) super().__init__(name)
@ -131,6 +146,7 @@ class SQLiteStorage(Storage):
with self.conn: with self.conn:
self.conn.executescript(SCHEMA) self.conn.executescript(SCHEMA)
self.conn.executescript(UNAME_SCHEMA) self.conn.executescript(UNAME_SCHEMA)
self.conn.executescript(self.UPDATE_DC_SCHEMA)
self.conn.execute( self.conn.execute(
"INSERT INTO version VALUES (?)", "INSERT INTO version VALUES (?)",
@ -252,6 +268,60 @@ class SQLiteStorage(Storage):
return get_input_peer(*r) return get_input_peer(*r)
async def update_dc_address(
self,
value: Tuple[int, str, int, bool, bool, bool] = object
):
"""
Updates or inserts a data center address.
Parameters:
value (Tuple[int, str, int, bool, bool, bool]): A tuple containing:
- dc_id (int): Data center ID.
- address (str): Address of the data center.
- port (int): Port of the data center.
- is_ipv6 (bool): Whether the address is IPv6.
- is_media (bool): Whether it is a media data center.
"""
if value == object:
return
with self.conn:
self.conn.execute(
"""
INSERT INTO dc_options (dc_id, address, port, is_ipv6, is_media)
VALUES (?, ?, ?, ?, ?)
ON CONFLICT(dc_id, is_ipv6, is_media)
DO UPDATE SET address=excluded.address, port=excluded.port
""",
value
)
async def get_dc_address(
self,
dc_id: int,
is_ipv6: bool,
media: bool = False
) -> Tuple[str, int]:
"""
Retrieves the address of a data center.
Parameters:
dc_id (int): Data center ID.
is_ipv6 (bool): Whether the address is IPv6.
media (bool): Whether it is a media data center.
Returns:
Tuple[str, int]: A tuple containing the address and port of the data center.
"""
if dc_id in [1,3,5] and media:
media = False
r = self.conn.execute(
"SELECT address, port FROM dc_options WHERE dc_id = ? AND is_ipv6 = ? AND is_media = ?",
(dc_id, is_ipv6, media)
).fetchone()
return r
def _get(self): def _get(self):
attr = inspect.stack()[2].function attr = inspect.stack()[2].function

View file

@ -76,6 +76,21 @@ class Storage:
async def get_peer_by_phone_number(self, phone_number: str): async def get_peer_by_phone_number(self, phone_number: str):
raise NotImplementedError raise NotImplementedError
async def update_dc_address(
self,
value: Tuple[int, str, int, bool, bool] = object
):
raise NotImplementedError
async def get_dc_address(
self,
dc_id: int,
is_ipv6: bool,
test_mode: bool = False,
media: bool = False
):
raise NotImplementedError
async def dc_id(self, value: int = object): async def dc_id(self, value: int = object):
raise NotImplementedError raise NotImplementedError

View file

@ -45,11 +45,16 @@ PyromodConfig = SimpleNamespace(
) )
async def ainput(prompt: str = "", *, hide: bool = False): async def ainput(prompt: str = "", *, hide: bool = False, loop: Optional[asyncio.AbstractEventLoop] = None):
"""Just like the built-in input, but async""" """Just like the built-in input, but async"""
if isinstance(loop, asyncio.AbstractEventLoop):
loop = loop
else:
loop = asyncio.get_event_loop()
with ThreadPoolExecutor(1) as executor: with ThreadPoolExecutor(1) as executor:
func = functools.partial(getpass if hide else input, prompt) func = functools.partial(getpass if hide else input, prompt)
return await asyncio.get_event_loop().run_in_executor(executor, func) return await loop.run_in_executor(executor, func)
def get_input_media_from_file_id( def get_input_media_from_file_id(