Compare commits

...

11 commits

Author SHA1 Message Date
wulan17
fd17d1ec5d
pyrofork: Refactor test_mode
Some checks failed
Pyrofork / build (macos-latest, 3.10) (push) Has been cancelled
Pyrofork / build (macos-latest, 3.11) (push) Has been cancelled
Pyrofork / build (macos-latest, 3.12) (push) Has been cancelled
Pyrofork / build (macos-latest, 3.13) (push) Has been cancelled
Pyrofork / build (macos-latest, 3.9) (push) Has been cancelled
Pyrofork / build (ubuntu-latest, 3.10) (push) Has been cancelled
Pyrofork / build (ubuntu-latest, 3.11) (push) Has been cancelled
Pyrofork / build (ubuntu-latest, 3.12) (push) Has been cancelled
Pyrofork / build (ubuntu-latest, 3.13) (push) Has been cancelled
Pyrofork / build (ubuntu-latest, 3.9) (push) Has been cancelled
Signed-off-by: wulan17 <wulan17@komodos.id>
2025-06-10 21:37:41 +07:00
wulan17
6e9e1740b0
pyrofork: Retrive dc address and port from GetConfig
Signed-off-by: wulan17 <wulan17@komodos.id>
2025-06-10 21:35:46 +07:00
wulan17
375e97d963
Reapply "fix: handle connection closure and retry logic in session management"
This reverts commit 2c3fb1caa6.

Signed-off-by: wulan17 <wulan17@komodos.id>
2025-06-07 21:01:42 +07:00
KurimuzonAkuma
3aed1857d4
Fix 'Client' object has no attribute 'loop'
Signed-off-by: wulan17 <wulan17@komodos.id>
2025-06-07 21:01:42 +07:00
KurimuzonAkuma
3d71eba4e1
Fix attached to a different loop
Signed-off-by: wulan17 <wulan17@komodos.id>
2025-06-07 21:01:42 +07:00
KurimuzonAkuma
1a4c578380
Refactor loop
Signed-off-by: wulan17 <wulan17@komodos.id>
2025-06-07 21:01:41 +07:00
KurimuzonAkuma
e3a8a781d3
Attempt to fix handler is closed
Signed-off-by: wulan17 <wulan17@komodos.id>
2025-06-07 21:01:41 +07:00
S!R X
921b593285
Fix connection to proxy
Fix the proxy bug.

Signed-off-by: wulan17 <wulan17@komodos.id>
2025-06-07 21:01:40 +07:00
Hitalo M.
b8028541c9
fix(session): prevent task cancellation race condition in stop method
The fix properly cancels the recv_task and suppresses CancelledError when awaiting
it during session shutdown. This resolves the "read() called while another
coroutine is already waiting for incoming data" RuntimeError that occurred when
stopping sessions during reconnection attempts.

Signed-off-by: wulan17 <wulan17@komodos.id>
2025-06-07 20:52:34 +07:00
Hitalo M.
01e7717e52
refactor(session): replace recursion with loop and add backoff
This refactor replaces recursion with a loop in the session invoke logic. Additionally, a backoff mechanism has been introduced to prevent frequent restarts from crashing the bot.

Signed-off-by: wulan17 <wulan17@komodos.id>
2025-06-07 19:59:42 +07:00
wulan17
b79ffac690
pyrofork: disable publish workflows
Some checks failed
Build-docs / build (push) Has been cancelled
Pyrofork / build (macos-latest, 3.10) (push) Has been cancelled
Pyrofork / build (macos-latest, 3.11) (push) Has been cancelled
Pyrofork / build (macos-latest, 3.12) (push) Has been cancelled
Pyrofork / build (macos-latest, 3.13) (push) Has been cancelled
Pyrofork / build (macos-latest, 3.9) (push) Has been cancelled
Pyrofork / build (ubuntu-latest, 3.10) (push) Has been cancelled
Pyrofork / build (ubuntu-latest, 3.11) (push) Has been cancelled
Pyrofork / build (ubuntu-latest, 3.12) (push) Has been cancelled
Pyrofork / build (ubuntu-latest, 3.13) (push) Has been cancelled
Pyrofork / build (ubuntu-latest, 3.9) (push) Has been cancelled
Signed-off-by: wulan17 <wulan17@nusantararom.org>
2025-06-06 19:16:53 +07:00
23 changed files with 394 additions and 105 deletions

View file

@ -1,40 +0,0 @@
# This workflow will upload a Python Package using Twine when a release is created
# For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python#publishing-to-package-registries
# This workflow uses actions that are not certified by GitHub.
# They are provided by a third-party and are governed by
# separate terms of service, privacy policy, and support
# documentation.
name: Upload Python Package
on:
push:
tags:
- '*'
permissions:
contents: read
jobs:
deploy:
runs-on: ubuntu-latest
environment: release
permissions:
id-token: write
steps:
- uses: actions/checkout@v4
- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: '3.10'
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install -e '.[dev]'
- name: Build package
run: hatch build
- name: Publish package
uses: pypa/gh-action-pypi-publish@release/v1

View file

@ -33,7 +33,7 @@ from importlib import import_module
from io import StringIO, BytesIO
from mimetypes import MimeTypes
from pathlib import Path
from typing import Union, List, Optional, Callable, AsyncGenerator, Tuple
from typing import Union, List, Optional, Callable, AsyncGenerator, Type, Tuple
import pyrogram
from pyrogram import __version__, __license__
@ -55,12 +55,13 @@ from pyrogram.storage import FileStorage, MemoryStorage, Storage
from pyrogram.types import User
from pyrogram.utils import ainput
from .connection import Connection
from .connection.transport import TCPAbridged
from .connection.transport import TCP, TCPAbridged
from .dispatcher import Dispatcher
from .file_id import FileId, FileType, ThumbnailSource
from .mime_types import mime_types
from .parser import Parser
from .session.internals import MsgId
from .session.internals.data_center import DataCenter
log = logging.getLogger(__name__)
MONGO_AVAIL = False
@ -220,6 +221,9 @@ class Client(Methods):
client_platform (:obj:`~pyrogram.enums.ClientPlatform`, *optional*):
The platform where this client is running.
Defaults to 'other'
loop (:py:class:`asyncio.AbstractEventLoop`, *optional*):
Event loop.
"""
APP_VERSION = f"Pyrogram {__version__}"
@ -277,7 +281,9 @@ class Client(Methods):
max_concurrent_transmissions: int = MAX_CONCURRENT_TRANSMISSIONS,
client_platform: "enums.ClientPlatform" = enums.ClientPlatform.OTHER,
max_message_cache_size: int = MAX_CACHE_SIZE,
max_business_user_connection_cache_size: int = MAX_CACHE_SIZE
max_business_user_connection_cache_size: int = MAX_CACHE_SIZE,
protocol_factory: Type[TCP] = TCPAbridged,
loop: Optional[asyncio.AbstractEventLoop] = None
):
super().__init__()
@ -314,6 +320,8 @@ class Client(Methods):
self.max_message_cache_size = max_message_cache_size
self.max_message_cache_size = max_message_cache_size
self.max_business_user_connection_cache_size = max_business_user_connection_cache_size
self.test_addr = None
self.test_port = None
self.executor = ThreadPoolExecutor(self.workers, thread_name_prefix="Handler")
@ -371,7 +379,11 @@ class Client(Methods):
self.updates_watchdog_event = asyncio.Event()
self.last_update_time = datetime.now()
self.listeners = {listener_type: [] for listener_type in pyrogram.enums.ListenerTypes}
self.loop = asyncio.get_event_loop()
if isinstance(loop, asyncio.AbstractEventLoop):
self.loop = loop
else:
self.loop = asyncio.get_event_loop()
def __enter__(self):
return self.start()
@ -391,6 +403,25 @@ class Client(Methods):
except ConnectionError:
pass
def set_dc(self, addr: str, port: int = 80):
"""Set the data center address and port.
Parameters:
addr (``str``):
The data center address, e.g.: "149.154.167.40".
port (``int``, *optional*):
The data center port, e.g.: 443.
Defaults to 80.
"""
if not isinstance(addr, str):
raise TypeError("addr must be a string")
if not isinstance(port, int):
raise TypeError("port must be an integer")
self.test_addr = addr
self.test_port = port
async def updates_watchdog(self):
while True:
try:
@ -425,7 +456,7 @@ class Client(Methods):
if not self.phone_number:
while True:
print("Enter 'qrcode' if you want to login with qrcode.")
value = await ainput("Enter phone number or bot token: ")
value = await ainput("Enter phone number or bot token: ", loop=self.loop)
if not value:
continue
@ -434,7 +465,7 @@ class Client(Methods):
self.use_qrcode = True
break
confirm = (await ainput(f'Is "{value}" correct? (y/N): ')).lower()
confirm = (await ainput(f'Is "{value}" correct? (y/N): ', loop=self.loop)).lower()
if confirm == "y":
break
@ -466,7 +497,7 @@ class Client(Methods):
while True:
if not self.use_qrcode and not self.phone_code:
self.phone_code = await ainput("Enter confirmation code: ")
self.phone_code = await ainput("Enter confirmation code: ", loop=self.loop)
try:
if self.use_qrcode:
@ -483,18 +514,18 @@ class Client(Methods):
print("Password hint: {}".format(await self.get_password_hint()))
if not self.password:
self.password = await ainput("Enter password (empty to recover): ", hide=self.hide_password)
self.password = await ainput("Enter password (empty to recover): ", hide=self.hide_password, loop=self.loop)
try:
if not self.password:
confirm = await ainput("Confirm password recovery (y/n): ")
confirm = await ainput("Confirm password recovery (y/n): ", loop=self.loop)
if confirm == "y":
email_pattern = await self.send_recovery_code()
print(f"The recovery code has been sent to {email_pattern}")
while True:
recovery_code = await ainput("Enter recovery code: ")
recovery_code = await ainput("Enter recovery code: ", loop=self.loop)
try:
return await self.recover_password(recovery_code)
@ -521,6 +552,16 @@ class Client(Methods):
break
if isinstance(signed_in, User):
if not self.test_mode:
is_dc_default = await self.check_dc_default()
if is_dc_default:
log.info("Your session is using the default data center.")
log.info("Updating the data center options from GetConfig...")
await self.update_dc_option()
log.info("Data center updated successfully.")
log.info("Restarting the session to apply the changes...")
await self.session.stop()
await self.session.start()
return signed_in
def set_parse_mode(self, parse_mode: Optional["enums.ParseMode"]):
@ -842,19 +883,64 @@ class Client(Methods):
else:
while True:
try:
value = int(await ainput("Enter the api_id part of the API key: "))
value = int(await ainput("Enter the api_id part of the API key: ", loop=self.loop))
if value <= 0:
print("Invalid value")
continue
confirm = (await ainput(f'Is "{value}" correct? (y/N): ')).lower()
confirm = (await ainput(f'Is "{value}" correct? (y/N): ', loop=self.loop)).lower()
if confirm == "y":
await self.storage.api_id(value)
break
except Exception as e:
print(e)
# Needed for migration from storage v3 to v4
if self.in_memory or self.session_string:
await self.insert_default_dc_options()
if not await self.storage.get_dc_address(await self.storage.dc_id(), self.ipv6):
log.info("No DC address found, inserting default DC options...")
await self.insert_default_dc_options()
async def insert_default_dc_options(self):
for dc_id in range(1, 6):
for is_ipv6 in (False, True):
if dc_id in [2,4]:
for media in (False, True):
address, port = DataCenter(dc_id, False, is_ipv6, self.alt_port, media)
await self.storage.update_dc_address(
(dc_id, address, port, is_ipv6, media)
)
else:
address, port = DataCenter(dc_id, False, is_ipv6, False, False)
await self.storage.update_dc_address(
(dc_id, address, port, is_ipv6, False)
)
async def update_dc_option(self):
config = await self.invoke(raw.functions.help.GetConfig())
for option in config.dc_options:
await self.storage.update_dc_address(
(option.id, option.address, option.port, option.is_ipv6, option.is_media)
)
async def check_dc_default(
self,
dc_id: int,
is_ipv6: bool,
media: bool = False
) -> bool:
default_dc = DataCenter(
dc_id,
test_mode=False,
is_ipv6=is_ipv6,
alt_port=self.alt_port,
media=media
)
current_dc = await self.storage.get_dc_address(dc_id, is_ipv6, media)
if current_dc is not None and current_dc[0] == default_dc.address:
return True
def is_excluded(self, exclude, module):
for e in exclude:

View file

@ -22,7 +22,6 @@ import logging
from typing import Optional, Type
from .transport import TCP, TCPAbridged
from ..session.internals import DataCenter
log = logging.getLogger(__name__)
@ -34,26 +33,35 @@ class Connection:
self,
dc_id: int,
test_mode: bool,
server_ip: str,
server_port: int,
ipv6: bool,
alt_port: bool,
proxy: dict,
media: bool = False,
protocol_factory: Type[TCP] = TCPAbridged
protocol_factory: Type[TCP] = TCPAbridged,
loop: Optional[asyncio.AbstractEventLoop] = None
) -> None:
self.dc_id = dc_id
self.test_mode = test_mode
self.server_ip = server_ip
self.server_port = 5222 if alt_port else server_port
self.ipv6 = ipv6
self.alt_port = alt_port
self.proxy = proxy
self.media = media
self.protocol_factory = protocol_factory
self.address = DataCenter(dc_id, test_mode, ipv6, alt_port, media)
self.address = (server_ip, server_port)
self.protocol: Optional[TCP] = None
if isinstance(loop, asyncio.AbstractEventLoop):
self.loop = loop
else:
self.loop = asyncio.get_event_loop()
async def connect(self) -> None:
for _ in range(Connection.MAX_CONNECTION_ATTEMPTS):
self.protocol = self.protocol_factory(ipv6=self.ipv6, proxy=self.proxy)
self.protocol = self.protocol_factory(ipv6=self.ipv6, proxy=self.proxy, loop=self.loop)
try:
log.info("Connecting...")

View file

@ -21,6 +21,7 @@ import asyncio
import ipaddress
import logging
import socket
from concurrent.futures import ThreadPoolExecutor
from typing import Tuple, Dict, TypedDict, Optional
import socks
@ -45,7 +46,7 @@ class Proxy(TypedDict):
class TCP:
TIMEOUT = 10
def __init__(self, ipv6: bool, proxy: Proxy) -> None:
def __init__(self, ipv6: bool, proxy: Proxy, loop: Optional[asyncio.AbstractEventLoop] = None) -> None:
self.ipv6 = ipv6
self.proxy = proxy
@ -53,7 +54,18 @@ class TCP:
self.writer: Optional[asyncio.StreamWriter] = None
self.lock = asyncio.Lock()
self.loop = asyncio.get_event_loop()
if isinstance(loop, asyncio.AbstractEventLoop):
self.loop = loop
else:
self.loop = asyncio.get_event_loop()
self._closed = True
@property
def closed(self) -> bool:
return (
self._closed or self.writer is None or self.writer.is_closing() or self.reader is None
)
async def _connect_via_proxy(
self,
@ -91,10 +103,8 @@ class TCP:
)
sock.settimeout(TCP.TIMEOUT)
await self.loop.sock_connect(
sock=sock,
address=destination
)
with ThreadPoolExecutor() as executor:
await self.loop.run_in_executor(executor, sock.connect, destination)
sock.setblocking(False)
@ -123,11 +133,14 @@ class TCP:
async def connect(self, address: Tuple[str, int]) -> None:
try:
await asyncio.wait_for(self._connect(address), TCP.TIMEOUT)
self._closed = False
except asyncio.TimeoutError: # Re-raise as TimeoutError. asyncio.TimeoutError is deprecated in 3.11
self._closed = True
raise TimeoutError("Connection timed out")
async def close(self) -> None:
if self.writer is None:
self._closed = True
return None
try:
@ -135,20 +148,27 @@ class TCP:
await asyncio.wait_for(self.writer.wait_closed(), TCP.TIMEOUT)
except Exception as e:
log.info("Close exception: %s %s", type(e).__name__, e)
finally:
self.writer = None
self._closed = True
async def send(self, data: bytes) -> None:
if self.writer is None:
return None
async with self.lock:
if self.writer is None or self.writer.is_closing():
return None
try:
self.writer.write(data)
await self.writer.drain()
except Exception as e:
log.info("Send exception: %s %s", type(e).__name__, e)
self._closed = True
raise OSError(e)
async def recv(self, length: int = 0) -> Optional[bytes]:
if self._closed or self.reader is None:
return None
data = b""
while len(data) < length:
@ -158,11 +178,13 @@ class TCP:
TCP.TIMEOUT
)
except (OSError, asyncio.TimeoutError):
self._closed = True
return None
else:
if chunk:
data += chunk
else:
self._closed = True
return None
return data

View file

@ -17,6 +17,7 @@
# You should have received a copy of the GNU Lesser General Public License
# along with Pyrofork. If not, see <http://www.gnu.org/licenses/>.
import asyncio
import logging
from typing import Optional, Tuple
@ -26,8 +27,8 @@ log = logging.getLogger(__name__)
class TCPAbridged(TCP):
def __init__(self, ipv6: bool, proxy: Proxy) -> None:
super().__init__(ipv6, proxy)
def __init__(self, ipv6: bool, proxy: Proxy, loop: Optional[asyncio.AbstractEventLoop] = None) -> None:
super().__init__(ipv6, proxy, loop)
async def connect(self, address: Tuple[str, int]) -> None:
await super().connect(address)

View file

@ -17,6 +17,7 @@
# You should have received a copy of the GNU Lesser General Public License
# along with Pyrofork. If not, see <http://www.gnu.org/licenses/>.
import asyncio
import logging
import os
from typing import Optional, Tuple
@ -31,8 +32,8 @@ log = logging.getLogger(__name__)
class TCPAbridgedO(TCP):
RESERVED = (b"HEAD", b"POST", b"GET ", b"OPTI", b"\xee" * 4)
def __init__(self, ipv6: bool, proxy: Proxy) -> None:
super().__init__(ipv6, proxy)
def __init__(self, ipv6: bool, proxy: Proxy, loop: Optional[asyncio.AbstractEventLoop] = None) -> None:
super().__init__(ipv6, proxy, loop)
self.encrypt = None
self.decrypt = None

View file

@ -17,6 +17,7 @@
# You should have received a copy of the GNU Lesser General Public License
# along with Pyrofork. If not, see <http://www.gnu.org/licenses/>.
import asyncio
import logging
from binascii import crc32
from struct import pack, unpack
@ -28,8 +29,8 @@ log = logging.getLogger(__name__)
class TCPFull(TCP):
def __init__(self, ipv6: bool, proxy: Proxy) -> None:
super().__init__(ipv6, proxy)
def __init__(self, ipv6: bool, proxy: Proxy, loop: Optional[asyncio.AbstractEventLoop] = None) -> None:
super().__init__(ipv6, proxy, loop)
self.seq_no: Optional[int] = None

View file

@ -17,6 +17,7 @@
# You should have received a copy of the GNU Lesser General Public License
# along with Pyrofork. If not, see <http://www.gnu.org/licenses/>.
import asyncio
import logging
from struct import pack, unpack
from typing import Optional, Tuple
@ -27,8 +28,8 @@ log = logging.getLogger(__name__)
class TCPIntermediate(TCP):
def __init__(self, ipv6: bool, proxy: Proxy) -> None:
super().__init__(ipv6, proxy)
def __init__(self, ipv6: bool, proxy: Proxy, loop: Optional[asyncio.AbstractEventLoop] = None) -> None:
super().__init__(ipv6, proxy, loop)
async def connect(self, address: Tuple[str, int]) -> None:
await super().connect(address)

View file

@ -17,6 +17,7 @@
# You should have received a copy of the GNU Lesser General Public License
# along with Pyrofork. If not, see <http://www.gnu.org/licenses/>.
import asyncio
import logging
import os
from struct import pack, unpack
@ -31,8 +32,8 @@ log = logging.getLogger(__name__)
class TCPIntermediateO(TCP):
RESERVED = (b"HEAD", b"POST", b"GET ", b"OPTI", b"\xee" * 4)
def __init__(self, ipv6: bool, proxy: Proxy) -> None:
super().__init__(ipv6, proxy)
def __init__(self, ipv6: bool, proxy: Proxy, loop: Optional[asyncio.AbstractEventLoop] = None) -> None:
super().__init__(ipv6, proxy, loop)
self.encrypt = None
self.decrypt = None

View file

@ -95,7 +95,6 @@ class Dispatcher:
def __init__(self, client: "pyrogram.Client"):
self.client = client
self.loop = asyncio.get_event_loop()
self.handler_worker_tasks = []
self.locks_list = []
@ -271,7 +270,7 @@ class Dispatcher:
self.locks_list.append(asyncio.Lock())
self.handler_worker_tasks.append(
self.loop.create_task(self.handler_worker(self.locks_list[-1]))
self.client.loop.create_task(self.handler_worker(self.locks_list[-1]))
)
log.info("Started %s HandlerTasks", self.client.workers)
@ -311,7 +310,7 @@ class Dispatcher:
for lock in self.locks_list:
lock.release()
self.loop.create_task(fn())
self.client.loop.create_task(fn())
def remove_handler(self, handler, group: int):
async def fn():
@ -333,7 +332,7 @@ class Dispatcher:
for lock in self.locks_list:
lock.release()
self.loop.create_task(fn())
self.client.loop.create_task(fn())
async def handler_worker(self, lock: asyncio.Lock):
while True:
@ -404,7 +403,7 @@ class Dispatcher:
if inspect.iscoroutinefunction(handler.callback):
await handler.callback(self.client, *args)
else:
await self.loop.run_in_executor(
await self.client.loop.run_in_executor(
self.client.executor,
handler.callback,
self.client,

View file

@ -49,6 +49,6 @@ class Initialize:
await self.dispatcher.start()
self.updates_watchdog_task = asyncio.create_task(self.updates_watchdog())
self.updates_watchdog_task = self.loop.create_task(self.updates_watchdog())
self.is_initialized = True

View file

@ -61,6 +61,8 @@ class SendCode:
)
)
except (PhoneMigrate, NetworkMigrate) as e:
if not self.test_mode:
await self.update_dc_option()
# pylint: disable=access-member-before-definition
await self.session.stop()

View file

@ -58,6 +58,8 @@ class SignInBot:
)
)
except UserMigrate as e:
if not self.test_mode:
await self.update_dc_option()
# pylint: disable=access-member-before-definition
await self.session.stop()

View file

@ -80,6 +80,8 @@ class SignInQrcode:
return types.User._parse(self, r.authorization.user)
if isinstance(r, raw.types.auth.LoginTokenMigrateTo):
if not self.test_mode:
await self.update_dc_option()
# pylint: disable=access-member-before-definition
await self.session.stop()

View file

@ -185,4 +185,4 @@ class DownloadMedia:
if block:
return await downloader
else:
asyncio.get_event_loop().create_task(downloader)
self.loop.create_task(downloader)

View file

@ -71,8 +71,7 @@ class Run:
app.run(main())
"""
loop = asyncio.get_event_loop()
run = loop.run_until_complete
run = self.loop.run_until_complete
if coroutine is not None:
run(coroutine)

View file

@ -45,13 +45,16 @@ class Auth:
dc_id: int,
test_mode: bool
):
self.client = client
self.dc_id = dc_id
self.test_mode = test_mode
self.ipv6 = client.ipv6
self.alt_port = client.alt_port
self.proxy = client.proxy
self.storage = client.storage
self.connection_factory = client.connection_factory
self.protocol_factory = client.protocol_factory
self.loop = client.loop
self.connection: Optional[Connection] = None
@ -85,15 +88,25 @@ class Auth:
# The server may close the connection at any time, causing the auth key creation to fail.
# If that happens, just try again up to MAX_RETRIES times.
if not self.test_mode:
address, port = await self.storage.get_dc_address(self.dc_id, self.ipv6)
else:
address = self.client.test_addr
port = self.client.test_port
if address is None or port is None:
raise ValueError("Test address and port must be set for test mode.")
while True:
self.connection = self.connection_factory(
dc_id=self.dc_id,
test_mode=self.test_mode,
server_ip=address,
server_port=port,
ipv6=self.ipv6,
alt_port=self.alt_port,
proxy=self.proxy,
media=False,
protocol_factory=self.protocol_factory
protocol_factory=self.protocol_factory,
loop=self.loop
)
try:

View file

@ -19,8 +19,10 @@
import asyncio
import bisect
import contextlib
import logging
import os
from datetime import datetime, timedelta
from hashlib import sha1
from io import BytesIO
from typing import Optional
@ -56,6 +58,7 @@ class Session:
ACKS_THRESHOLD = 10
PING_INTERVAL = 5
STORED_MSG_IDS_MAX_SIZE = 500
RECONNECT_THRESHOLD = timedelta(seconds=10)
TRANSPORT_ERRORS = {
404: "auth key not found",
@ -101,24 +104,34 @@ class Session:
self.is_started = asyncio.Event()
self.loop = asyncio.get_event_loop()
self.last_reconnect_attempt = None
async def start(self):
if not self.test_mode:
address, port = await self.client.storage.get_dc_address(self.dc_id, self.client.ipv6, self.is_media)
else:
address = self.client.test_addr
port = self.client.test_port
if address is None or port is None:
raise ValueError("Test address and port must be set for test mode.")
while True:
self.connection = self.client.connection_factory(
dc_id=self.dc_id,
test_mode=self.test_mode,
server_ip=address,
server_port=port,
ipv6=self.client.ipv6,
alt_port=self.client.alt_port,
proxy=self.client.proxy,
media=self.is_media,
protocol_factory=self.client.protocol_factory
protocol_factory=self.client.protocol_factory,
loop=self.client.loop
)
try:
await self.connection.connect()
self.recv_task = self.loop.create_task(self.recv_worker())
self.recv_task = self.client.loop.create_task(self.recv_worker())
await self.send(raw.functions.Ping(ping_id=0), timeout=self.START_TIMEOUT)
@ -140,7 +153,7 @@ class Session:
timeout=self.START_TIMEOUT
)
self.ping_task = self.loop.create_task(self.ping_worker())
self.ping_task = self.client.loop.create_task(self.ping_worker())
log.info("Session initialized: Layer %s", layer)
log.info("Device: %s - %s", self.client.device_model, self.client.app_version)
@ -175,7 +188,9 @@ class Session:
await self.connection.close()
if self.recv_task:
await self.recv_task
self.recv_task.cancel()
with contextlib.suppress(asyncio.CancelledError):
await self.recv_task
if not self.is_media and callable(self.client.disconnect_handler):
try:
@ -186,12 +201,21 @@ class Session:
log.info("Session stopped")
async def restart(self):
now = datetime.now()
if (
self.last_reconnect_attempt
and now - self.last_reconnect_attempt < self.RECONNECT_THRESHOLD
):
log.info("Reconnecting too frequently, sleeping for a while")
await asyncio.sleep(5)
self.last_reconnect_attempt = now
await self.stop()
await self.start()
async def handle_packet(self, packet):
try:
data = await self.loop.run_in_executor(
data = await self.client.loop.run_in_executor(
pyrogram.crypto_executor,
mtproto.unpack,
BytesIO(packet),
@ -201,7 +225,7 @@ class Session:
)
except ValueError as e:
log.debug(e)
self.loop.create_task(self.restart())
self.client.loop.create_task(self.restart())
return
messages = (
@ -263,7 +287,7 @@ class Session:
msg_id = msg.body.msg_id
else:
if self.client is not None:
self.loop.create_task(self.client.handle_updates(msg.body))
self.client.loop.create_task(self.client.handle_updates(msg.body))
if msg_id in self.results:
self.results[msg_id].value = getattr(msg.body, "result", msg.body)
@ -297,7 +321,7 @@ class Session:
), False
)
except OSError:
self.loop.create_task(self.restart())
self.client.loop.create_task(self.restart())
break
except RPCError:
pass
@ -326,11 +350,11 @@ class Session:
)
if self.is_started.is_set():
self.loop.create_task(self.restart())
self.client.loop.create_task(self.restart())
break
self.loop.create_task(self.handle_packet(packet))
self.client.loop.create_task(self.handle_packet(packet))
log.info("NetworkTask stopped")
@ -343,7 +367,7 @@ class Session:
log.debug("Sent: %s", message)
payload = await self.loop.run_in_executor(
payload = await self.client.loop.run_in_executor(
pyrogram.crypto_executor,
mtproto.pack,
message,
@ -415,8 +439,21 @@ class Session:
query_name = ".".join(inner_query.QUALNAME.split(".")[1:])
while True:
while retries > 0:
try:
if (
self.connection is None
or self.connection.protocol is None
or getattr(self.connection.protocol, "closed", True)
):
log.warning(
"[%s] Connection is closed or not established. Attempting to reconnect...",
self.client.name,
)
await self.restart()
await asyncio.sleep(1)
continue
return await self.send(query, timeout=timeout)
except (FloodWait, FloodPremiumWait) as e:
amount = e.value
@ -429,15 +466,26 @@ class Session:
await asyncio.sleep(amount)
except (OSError, InternalServerError, ServiceUnavailable) as e:
retries -= 1
if retries == 0:
raise e from None
raise e
(log.warning if retries < 2 else log.info)(
'[%s] Retrying "%s" due to: %s',
Session.MAX_RETRIES - retries + 1,
Session.MAX_RETRIES - retries,
query_name, str(e) or repr(e)
)
if isinstance(e, OSError) and retries > 1:
try:
await self.restart()
except Exception as restart_error:
log.warning(
"[%s] Failed to restart session: %s",
self.client.name,
str(restart_error) or repr(restart_error),
)
await asyncio.sleep(0.5)
return await self.invoke(query, retries - 1, timeout)
raise TimeoutError("Exceeded maximum number of retries")

View file

@ -68,6 +68,12 @@ class FileStorage(SQLiteStorage):
version += 1
if version == 4:
with self.conn:
self.conn.executescript(self.UPDATE_DC_SCHEMA)
version += 1
self.version(version)
async def open(self):

View file

@ -77,6 +77,7 @@ class MongoStorage(Storage):
self._session = database['session']
self._usernames = database['usernames']
self._states = database['update_state']
self._dc_options = database['dc_options']
self._remove_peers = remove_peers
async def open(self):
@ -221,6 +222,52 @@ class MongoStorage(Storage):
return get_input_peer(r['_id'], r['access_hash'], r['type'])
async def update_dc_address(
self,
value: Tuple[int, str, int, bool, bool] = object
):
"""
Updates or inserts a data center address.
Parameters:
value (Tuple[int, str, int, bool]): A tuple containing:
- dc_id (int): Data center ID.
- address (str): Address of the data center.
- port (int): Port of the data center.
- is_ipv6 (bool): Whether the address is IPv6.
- is_media (bool): Whether it is a media data center.
"""
if value == object:
return
await self._dc_options.update_one(
{"$and": [
{'dc_id': value[0]},
{'is_ipv6': value[3]},
{'is_media': value[4]}
]},
{'$set': {'address': value[1], 'port': value[2]}},
upsert=True
)
async def get_dc_address(
self,
dc_id: int,
is_ipv6: bool,
media: bool = False
) -> Tuple[str, int]:
if dc_id in [1,3,5] and media:
media = False
if dc_id in [4,5] and test_mode:
test_mode = False
r = await self._dc_options.find_one(
{'dc_id': dc_id, 'is_ipv6': is_ipv6, 'is_media': media},
{'address': 1, 'port': 1}
)
if r is None:
return None
return r['address'], r['port']
async def _get(self):
attr = inspect.stack()[2].function
d = await self._session.find_one({'_id': 0}, {attr: 1})

View file

@ -97,6 +97,20 @@ END;
"""
UPDATE_DC_SCHEMA = """
CREATE TABLE dc_options
(
id INTEGER PRIMARY KEY AUTOINCREMENT,
dc_id INTEGER,
address TEXT,
port INTEGER,
is_ipv6 BOOLEAN,
is_media BOOLEAN,
UNIQUE(dc_id, is_ipv6, is_media)
);
"""
def get_input_peer(peer_id: int, access_hash: int, peer_type: str):
if peer_type in ["user", "bot"]:
return raw.types.InputPeerUser(
@ -121,6 +135,7 @@ def get_input_peer(peer_id: int, access_hash: int, peer_type: str):
class SQLiteStorage(Storage):
VERSION = 4
USERNAME_TTL = 8 * 60 * 60
UPDATE_DC_SCHEMA = globals().get("UPDATE_DC_SCHEMA", "")
def __init__(self, name: str):
super().__init__(name)
@ -131,6 +146,7 @@ class SQLiteStorage(Storage):
with self.conn:
self.conn.executescript(SCHEMA)
self.conn.executescript(UNAME_SCHEMA)
self.conn.executescript(self.UPDATE_DC_SCHEMA)
self.conn.execute(
"INSERT INTO version VALUES (?)",
@ -252,6 +268,60 @@ class SQLiteStorage(Storage):
return get_input_peer(*r)
async def update_dc_address(
self,
value: Tuple[int, str, int, bool, bool, bool] = object
):
"""
Updates or inserts a data center address.
Parameters:
value (Tuple[int, str, int, bool, bool, bool]): A tuple containing:
- dc_id (int): Data center ID.
- address (str): Address of the data center.
- port (int): Port of the data center.
- is_ipv6 (bool): Whether the address is IPv6.
- is_media (bool): Whether it is a media data center.
"""
if value == object:
return
with self.conn:
self.conn.execute(
"""
INSERT INTO dc_options (dc_id, address, port, is_ipv6, is_media)
VALUES (?, ?, ?, ?, ?)
ON CONFLICT(dc_id, is_ipv6, is_media)
DO UPDATE SET address=excluded.address, port=excluded.port
""",
value
)
async def get_dc_address(
self,
dc_id: int,
is_ipv6: bool,
media: bool = False
) -> Tuple[str, int]:
"""
Retrieves the address of a data center.
Parameters:
dc_id (int): Data center ID.
is_ipv6 (bool): Whether the address is IPv6.
media (bool): Whether it is a media data center.
Returns:
Tuple[str, int]: A tuple containing the address and port of the data center.
"""
if dc_id in [1,3,5] and media:
media = False
r = self.conn.execute(
"SELECT address, port FROM dc_options WHERE dc_id = ? AND is_ipv6 = ? AND is_media = ?",
(dc_id, is_ipv6, media)
).fetchone()
return r
def _get(self):
attr = inspect.stack()[2].function

View file

@ -76,6 +76,21 @@ class Storage:
async def get_peer_by_phone_number(self, phone_number: str):
raise NotImplementedError
async def update_dc_address(
self,
value: Tuple[int, str, int, bool, bool] = object
):
raise NotImplementedError
async def get_dc_address(
self,
dc_id: int,
is_ipv6: bool,
test_mode: bool = False,
media: bool = False
):
raise NotImplementedError
async def dc_id(self, value: int = object):
raise NotImplementedError

View file

@ -43,13 +43,18 @@ PyromodConfig = SimpleNamespace(
unallowed_click_alert=True,
unallowed_click_alert_text=("[pyromod] You're not expected to click this button."),
)
async def ainput(prompt: str = "", *, hide: bool = False):
async def ainput(prompt: str = "", *, hide: bool = False, loop: Optional[asyncio.AbstractEventLoop] = None):
"""Just like the built-in input, but async"""
if isinstance(loop, asyncio.AbstractEventLoop):
loop = loop
else:
loop = asyncio.get_event_loop()
with ThreadPoolExecutor(1) as executor:
func = functools.partial(getpass if hide else input, prompt)
return await asyncio.get_event_loop().run_in_executor(executor, func)
return await loop.run_in_executor(executor, func)
def get_input_media_from_file_id(