mirror of
https://github.com/Mayuri-Chan/pyrofork.git
synced 2025-12-29 12:04:51 +00:00
Compare commits
11 commits
f824e72416
...
fd17d1ec5d
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
fd17d1ec5d | ||
|
|
6e9e1740b0 | ||
|
|
375e97d963 | ||
|
|
3aed1857d4 | ||
|
|
3d71eba4e1 | ||
|
|
1a4c578380 | ||
|
|
e3a8a781d3 | ||
|
|
921b593285 | ||
|
|
b8028541c9 | ||
|
|
01e7717e52 | ||
|
|
b79ffac690 |
23 changed files with 394 additions and 105 deletions
40
.github/workflows/python-publish.yml
vendored
40
.github/workflows/python-publish.yml
vendored
|
|
@ -1,40 +0,0 @@
|
|||
# This workflow will upload a Python Package using Twine when a release is created
|
||||
# For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python#publishing-to-package-registries
|
||||
|
||||
# This workflow uses actions that are not certified by GitHub.
|
||||
# They are provided by a third-party and are governed by
|
||||
# separate terms of service, privacy policy, and support
|
||||
# documentation.
|
||||
|
||||
name: Upload Python Package
|
||||
|
||||
on:
|
||||
push:
|
||||
tags:
|
||||
- '*'
|
||||
|
||||
permissions:
|
||||
contents: read
|
||||
|
||||
jobs:
|
||||
deploy:
|
||||
|
||||
runs-on: ubuntu-latest
|
||||
environment: release
|
||||
permissions:
|
||||
id-token: write
|
||||
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- name: Set up Python
|
||||
uses: actions/setup-python@v5
|
||||
with:
|
||||
python-version: '3.10'
|
||||
- name: Install dependencies
|
||||
run: |
|
||||
python -m pip install --upgrade pip
|
||||
pip install -e '.[dev]'
|
||||
- name: Build package
|
||||
run: hatch build
|
||||
- name: Publish package
|
||||
uses: pypa/gh-action-pypi-publish@release/v1
|
||||
|
|
@ -33,7 +33,7 @@ from importlib import import_module
|
|||
from io import StringIO, BytesIO
|
||||
from mimetypes import MimeTypes
|
||||
from pathlib import Path
|
||||
from typing import Union, List, Optional, Callable, AsyncGenerator, Tuple
|
||||
from typing import Union, List, Optional, Callable, AsyncGenerator, Type, Tuple
|
||||
|
||||
import pyrogram
|
||||
from pyrogram import __version__, __license__
|
||||
|
|
@ -55,12 +55,13 @@ from pyrogram.storage import FileStorage, MemoryStorage, Storage
|
|||
from pyrogram.types import User
|
||||
from pyrogram.utils import ainput
|
||||
from .connection import Connection
|
||||
from .connection.transport import TCPAbridged
|
||||
from .connection.transport import TCP, TCPAbridged
|
||||
from .dispatcher import Dispatcher
|
||||
from .file_id import FileId, FileType, ThumbnailSource
|
||||
from .mime_types import mime_types
|
||||
from .parser import Parser
|
||||
from .session.internals import MsgId
|
||||
from .session.internals.data_center import DataCenter
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
MONGO_AVAIL = False
|
||||
|
|
@ -220,6 +221,9 @@ class Client(Methods):
|
|||
client_platform (:obj:`~pyrogram.enums.ClientPlatform`, *optional*):
|
||||
The platform where this client is running.
|
||||
Defaults to 'other'
|
||||
|
||||
loop (:py:class:`asyncio.AbstractEventLoop`, *optional*):
|
||||
Event loop.
|
||||
"""
|
||||
|
||||
APP_VERSION = f"Pyrogram {__version__}"
|
||||
|
|
@ -277,7 +281,9 @@ class Client(Methods):
|
|||
max_concurrent_transmissions: int = MAX_CONCURRENT_TRANSMISSIONS,
|
||||
client_platform: "enums.ClientPlatform" = enums.ClientPlatform.OTHER,
|
||||
max_message_cache_size: int = MAX_CACHE_SIZE,
|
||||
max_business_user_connection_cache_size: int = MAX_CACHE_SIZE
|
||||
max_business_user_connection_cache_size: int = MAX_CACHE_SIZE,
|
||||
protocol_factory: Type[TCP] = TCPAbridged,
|
||||
loop: Optional[asyncio.AbstractEventLoop] = None
|
||||
):
|
||||
super().__init__()
|
||||
|
||||
|
|
@ -314,6 +320,8 @@ class Client(Methods):
|
|||
self.max_message_cache_size = max_message_cache_size
|
||||
self.max_message_cache_size = max_message_cache_size
|
||||
self.max_business_user_connection_cache_size = max_business_user_connection_cache_size
|
||||
self.test_addr = None
|
||||
self.test_port = None
|
||||
|
||||
self.executor = ThreadPoolExecutor(self.workers, thread_name_prefix="Handler")
|
||||
|
||||
|
|
@ -371,7 +379,11 @@ class Client(Methods):
|
|||
self.updates_watchdog_event = asyncio.Event()
|
||||
self.last_update_time = datetime.now()
|
||||
self.listeners = {listener_type: [] for listener_type in pyrogram.enums.ListenerTypes}
|
||||
self.loop = asyncio.get_event_loop()
|
||||
|
||||
if isinstance(loop, asyncio.AbstractEventLoop):
|
||||
self.loop = loop
|
||||
else:
|
||||
self.loop = asyncio.get_event_loop()
|
||||
|
||||
def __enter__(self):
|
||||
return self.start()
|
||||
|
|
@ -391,6 +403,25 @@ class Client(Methods):
|
|||
except ConnectionError:
|
||||
pass
|
||||
|
||||
def set_dc(self, addr: str, port: int = 80):
|
||||
"""Set the data center address and port.
|
||||
|
||||
Parameters:
|
||||
addr (``str``):
|
||||
The data center address, e.g.: "149.154.167.40".
|
||||
|
||||
port (``int``, *optional*):
|
||||
The data center port, e.g.: 443.
|
||||
Defaults to 80.
|
||||
"""
|
||||
if not isinstance(addr, str):
|
||||
raise TypeError("addr must be a string")
|
||||
if not isinstance(port, int):
|
||||
raise TypeError("port must be an integer")
|
||||
|
||||
self.test_addr = addr
|
||||
self.test_port = port
|
||||
|
||||
async def updates_watchdog(self):
|
||||
while True:
|
||||
try:
|
||||
|
|
@ -425,7 +456,7 @@ class Client(Methods):
|
|||
if not self.phone_number:
|
||||
while True:
|
||||
print("Enter 'qrcode' if you want to login with qrcode.")
|
||||
value = await ainput("Enter phone number or bot token: ")
|
||||
value = await ainput("Enter phone number or bot token: ", loop=self.loop)
|
||||
|
||||
if not value:
|
||||
continue
|
||||
|
|
@ -434,7 +465,7 @@ class Client(Methods):
|
|||
self.use_qrcode = True
|
||||
break
|
||||
|
||||
confirm = (await ainput(f'Is "{value}" correct? (y/N): ')).lower()
|
||||
confirm = (await ainput(f'Is "{value}" correct? (y/N): ', loop=self.loop)).lower()
|
||||
|
||||
if confirm == "y":
|
||||
break
|
||||
|
|
@ -466,7 +497,7 @@ class Client(Methods):
|
|||
|
||||
while True:
|
||||
if not self.use_qrcode and not self.phone_code:
|
||||
self.phone_code = await ainput("Enter confirmation code: ")
|
||||
self.phone_code = await ainput("Enter confirmation code: ", loop=self.loop)
|
||||
|
||||
try:
|
||||
if self.use_qrcode:
|
||||
|
|
@ -483,18 +514,18 @@ class Client(Methods):
|
|||
print("Password hint: {}".format(await self.get_password_hint()))
|
||||
|
||||
if not self.password:
|
||||
self.password = await ainput("Enter password (empty to recover): ", hide=self.hide_password)
|
||||
self.password = await ainput("Enter password (empty to recover): ", hide=self.hide_password, loop=self.loop)
|
||||
|
||||
try:
|
||||
if not self.password:
|
||||
confirm = await ainput("Confirm password recovery (y/n): ")
|
||||
confirm = await ainput("Confirm password recovery (y/n): ", loop=self.loop)
|
||||
|
||||
if confirm == "y":
|
||||
email_pattern = await self.send_recovery_code()
|
||||
print(f"The recovery code has been sent to {email_pattern}")
|
||||
|
||||
while True:
|
||||
recovery_code = await ainput("Enter recovery code: ")
|
||||
recovery_code = await ainput("Enter recovery code: ", loop=self.loop)
|
||||
|
||||
try:
|
||||
return await self.recover_password(recovery_code)
|
||||
|
|
@ -521,6 +552,16 @@ class Client(Methods):
|
|||
break
|
||||
|
||||
if isinstance(signed_in, User):
|
||||
if not self.test_mode:
|
||||
is_dc_default = await self.check_dc_default()
|
||||
if is_dc_default:
|
||||
log.info("Your session is using the default data center.")
|
||||
log.info("Updating the data center options from GetConfig...")
|
||||
await self.update_dc_option()
|
||||
log.info("Data center updated successfully.")
|
||||
log.info("Restarting the session to apply the changes...")
|
||||
await self.session.stop()
|
||||
await self.session.start()
|
||||
return signed_in
|
||||
|
||||
def set_parse_mode(self, parse_mode: Optional["enums.ParseMode"]):
|
||||
|
|
@ -842,19 +883,64 @@ class Client(Methods):
|
|||
else:
|
||||
while True:
|
||||
try:
|
||||
value = int(await ainput("Enter the api_id part of the API key: "))
|
||||
value = int(await ainput("Enter the api_id part of the API key: ", loop=self.loop))
|
||||
|
||||
if value <= 0:
|
||||
print("Invalid value")
|
||||
continue
|
||||
|
||||
confirm = (await ainput(f'Is "{value}" correct? (y/N): ')).lower()
|
||||
confirm = (await ainput(f'Is "{value}" correct? (y/N): ', loop=self.loop)).lower()
|
||||
|
||||
if confirm == "y":
|
||||
await self.storage.api_id(value)
|
||||
break
|
||||
except Exception as e:
|
||||
print(e)
|
||||
# Needed for migration from storage v3 to v4
|
||||
if self.in_memory or self.session_string:
|
||||
await self.insert_default_dc_options()
|
||||
if not await self.storage.get_dc_address(await self.storage.dc_id(), self.ipv6):
|
||||
log.info("No DC address found, inserting default DC options...")
|
||||
await self.insert_default_dc_options()
|
||||
|
||||
async def insert_default_dc_options(self):
|
||||
for dc_id in range(1, 6):
|
||||
for is_ipv6 in (False, True):
|
||||
if dc_id in [2,4]:
|
||||
for media in (False, True):
|
||||
address, port = DataCenter(dc_id, False, is_ipv6, self.alt_port, media)
|
||||
await self.storage.update_dc_address(
|
||||
(dc_id, address, port, is_ipv6, media)
|
||||
)
|
||||
else:
|
||||
address, port = DataCenter(dc_id, False, is_ipv6, False, False)
|
||||
await self.storage.update_dc_address(
|
||||
(dc_id, address, port, is_ipv6, False)
|
||||
)
|
||||
|
||||
async def update_dc_option(self):
|
||||
config = await self.invoke(raw.functions.help.GetConfig())
|
||||
for option in config.dc_options:
|
||||
await self.storage.update_dc_address(
|
||||
(option.id, option.address, option.port, option.is_ipv6, option.is_media)
|
||||
)
|
||||
|
||||
async def check_dc_default(
|
||||
self,
|
||||
dc_id: int,
|
||||
is_ipv6: bool,
|
||||
media: bool = False
|
||||
) -> bool:
|
||||
default_dc = DataCenter(
|
||||
dc_id,
|
||||
test_mode=False,
|
||||
is_ipv6=is_ipv6,
|
||||
alt_port=self.alt_port,
|
||||
media=media
|
||||
)
|
||||
current_dc = await self.storage.get_dc_address(dc_id, is_ipv6, media)
|
||||
if current_dc is not None and current_dc[0] == default_dc.address:
|
||||
return True
|
||||
|
||||
def is_excluded(self, exclude, module):
|
||||
for e in exclude:
|
||||
|
|
|
|||
|
|
@ -22,7 +22,6 @@ import logging
|
|||
from typing import Optional, Type
|
||||
|
||||
from .transport import TCP, TCPAbridged
|
||||
from ..session.internals import DataCenter
|
||||
|
||||
log = logging.getLogger(__name__)
|
||||
|
||||
|
|
@ -34,26 +33,35 @@ class Connection:
|
|||
self,
|
||||
dc_id: int,
|
||||
test_mode: bool,
|
||||
server_ip: str,
|
||||
server_port: int,
|
||||
ipv6: bool,
|
||||
alt_port: bool,
|
||||
proxy: dict,
|
||||
media: bool = False,
|
||||
protocol_factory: Type[TCP] = TCPAbridged
|
||||
protocol_factory: Type[TCP] = TCPAbridged,
|
||||
loop: Optional[asyncio.AbstractEventLoop] = None
|
||||
) -> None:
|
||||
self.dc_id = dc_id
|
||||
self.test_mode = test_mode
|
||||
self.server_ip = server_ip
|
||||
self.server_port = 5222 if alt_port else server_port
|
||||
self.ipv6 = ipv6
|
||||
self.alt_port = alt_port
|
||||
self.proxy = proxy
|
||||
self.media = media
|
||||
self.protocol_factory = protocol_factory
|
||||
|
||||
self.address = DataCenter(dc_id, test_mode, ipv6, alt_port, media)
|
||||
self.address = (server_ip, server_port)
|
||||
self.protocol: Optional[TCP] = None
|
||||
|
||||
if isinstance(loop, asyncio.AbstractEventLoop):
|
||||
self.loop = loop
|
||||
else:
|
||||
self.loop = asyncio.get_event_loop()
|
||||
|
||||
async def connect(self) -> None:
|
||||
for _ in range(Connection.MAX_CONNECTION_ATTEMPTS):
|
||||
self.protocol = self.protocol_factory(ipv6=self.ipv6, proxy=self.proxy)
|
||||
self.protocol = self.protocol_factory(ipv6=self.ipv6, proxy=self.proxy, loop=self.loop)
|
||||
|
||||
try:
|
||||
log.info("Connecting...")
|
||||
|
|
|
|||
|
|
@ -21,6 +21,7 @@ import asyncio
|
|||
import ipaddress
|
||||
import logging
|
||||
import socket
|
||||
from concurrent.futures import ThreadPoolExecutor
|
||||
from typing import Tuple, Dict, TypedDict, Optional
|
||||
|
||||
import socks
|
||||
|
|
@ -45,7 +46,7 @@ class Proxy(TypedDict):
|
|||
class TCP:
|
||||
TIMEOUT = 10
|
||||
|
||||
def __init__(self, ipv6: bool, proxy: Proxy) -> None:
|
||||
def __init__(self, ipv6: bool, proxy: Proxy, loop: Optional[asyncio.AbstractEventLoop] = None) -> None:
|
||||
self.ipv6 = ipv6
|
||||
self.proxy = proxy
|
||||
|
||||
|
|
@ -53,7 +54,18 @@ class TCP:
|
|||
self.writer: Optional[asyncio.StreamWriter] = None
|
||||
|
||||
self.lock = asyncio.Lock()
|
||||
self.loop = asyncio.get_event_loop()
|
||||
|
||||
if isinstance(loop, asyncio.AbstractEventLoop):
|
||||
self.loop = loop
|
||||
else:
|
||||
self.loop = asyncio.get_event_loop()
|
||||
self._closed = True
|
||||
|
||||
@property
|
||||
def closed(self) -> bool:
|
||||
return (
|
||||
self._closed or self.writer is None or self.writer.is_closing() or self.reader is None
|
||||
)
|
||||
|
||||
async def _connect_via_proxy(
|
||||
self,
|
||||
|
|
@ -91,10 +103,8 @@ class TCP:
|
|||
)
|
||||
sock.settimeout(TCP.TIMEOUT)
|
||||
|
||||
await self.loop.sock_connect(
|
||||
sock=sock,
|
||||
address=destination
|
||||
)
|
||||
with ThreadPoolExecutor() as executor:
|
||||
await self.loop.run_in_executor(executor, sock.connect, destination)
|
||||
|
||||
sock.setblocking(False)
|
||||
|
||||
|
|
@ -123,11 +133,14 @@ class TCP:
|
|||
async def connect(self, address: Tuple[str, int]) -> None:
|
||||
try:
|
||||
await asyncio.wait_for(self._connect(address), TCP.TIMEOUT)
|
||||
self._closed = False
|
||||
except asyncio.TimeoutError: # Re-raise as TimeoutError. asyncio.TimeoutError is deprecated in 3.11
|
||||
self._closed = True
|
||||
raise TimeoutError("Connection timed out")
|
||||
|
||||
async def close(self) -> None:
|
||||
if self.writer is None:
|
||||
self._closed = True
|
||||
return None
|
||||
|
||||
try:
|
||||
|
|
@ -135,20 +148,27 @@ class TCP:
|
|||
await asyncio.wait_for(self.writer.wait_closed(), TCP.TIMEOUT)
|
||||
except Exception as e:
|
||||
log.info("Close exception: %s %s", type(e).__name__, e)
|
||||
finally:
|
||||
self.writer = None
|
||||
self._closed = True
|
||||
|
||||
async def send(self, data: bytes) -> None:
|
||||
if self.writer is None:
|
||||
return None
|
||||
|
||||
async with self.lock:
|
||||
if self.writer is None or self.writer.is_closing():
|
||||
return None
|
||||
|
||||
try:
|
||||
self.writer.write(data)
|
||||
await self.writer.drain()
|
||||
except Exception as e:
|
||||
log.info("Send exception: %s %s", type(e).__name__, e)
|
||||
self._closed = True
|
||||
raise OSError(e)
|
||||
|
||||
async def recv(self, length: int = 0) -> Optional[bytes]:
|
||||
if self._closed or self.reader is None:
|
||||
return None
|
||||
|
||||
data = b""
|
||||
|
||||
while len(data) < length:
|
||||
|
|
@ -158,11 +178,13 @@ class TCP:
|
|||
TCP.TIMEOUT
|
||||
)
|
||||
except (OSError, asyncio.TimeoutError):
|
||||
self._closed = True
|
||||
return None
|
||||
else:
|
||||
if chunk:
|
||||
data += chunk
|
||||
else:
|
||||
self._closed = True
|
||||
return None
|
||||
|
||||
return data
|
||||
|
|
|
|||
|
|
@ -17,6 +17,7 @@
|
|||
# You should have received a copy of the GNU Lesser General Public License
|
||||
# along with Pyrofork. If not, see <http://www.gnu.org/licenses/>.
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import Optional, Tuple
|
||||
|
||||
|
|
@ -26,8 +27,8 @@ log = logging.getLogger(__name__)
|
|||
|
||||
|
||||
class TCPAbridged(TCP):
|
||||
def __init__(self, ipv6: bool, proxy: Proxy) -> None:
|
||||
super().__init__(ipv6, proxy)
|
||||
def __init__(self, ipv6: bool, proxy: Proxy, loop: Optional[asyncio.AbstractEventLoop] = None) -> None:
|
||||
super().__init__(ipv6, proxy, loop)
|
||||
|
||||
async def connect(self, address: Tuple[str, int]) -> None:
|
||||
await super().connect(address)
|
||||
|
|
|
|||
|
|
@ -17,6 +17,7 @@
|
|||
# You should have received a copy of the GNU Lesser General Public License
|
||||
# along with Pyrofork. If not, see <http://www.gnu.org/licenses/>.
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
from typing import Optional, Tuple
|
||||
|
|
@ -31,8 +32,8 @@ log = logging.getLogger(__name__)
|
|||
class TCPAbridgedO(TCP):
|
||||
RESERVED = (b"HEAD", b"POST", b"GET ", b"OPTI", b"\xee" * 4)
|
||||
|
||||
def __init__(self, ipv6: bool, proxy: Proxy) -> None:
|
||||
super().__init__(ipv6, proxy)
|
||||
def __init__(self, ipv6: bool, proxy: Proxy, loop: Optional[asyncio.AbstractEventLoop] = None) -> None:
|
||||
super().__init__(ipv6, proxy, loop)
|
||||
|
||||
self.encrypt = None
|
||||
self.decrypt = None
|
||||
|
|
|
|||
|
|
@ -17,6 +17,7 @@
|
|||
# You should have received a copy of the GNU Lesser General Public License
|
||||
# along with Pyrofork. If not, see <http://www.gnu.org/licenses/>.
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from binascii import crc32
|
||||
from struct import pack, unpack
|
||||
|
|
@ -28,8 +29,8 @@ log = logging.getLogger(__name__)
|
|||
|
||||
|
||||
class TCPFull(TCP):
|
||||
def __init__(self, ipv6: bool, proxy: Proxy) -> None:
|
||||
super().__init__(ipv6, proxy)
|
||||
def __init__(self, ipv6: bool, proxy: Proxy, loop: Optional[asyncio.AbstractEventLoop] = None) -> None:
|
||||
super().__init__(ipv6, proxy, loop)
|
||||
|
||||
self.seq_no: Optional[int] = None
|
||||
|
||||
|
|
|
|||
|
|
@ -17,6 +17,7 @@
|
|||
# You should have received a copy of the GNU Lesser General Public License
|
||||
# along with Pyrofork. If not, see <http://www.gnu.org/licenses/>.
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
from struct import pack, unpack
|
||||
from typing import Optional, Tuple
|
||||
|
|
@ -27,8 +28,8 @@ log = logging.getLogger(__name__)
|
|||
|
||||
|
||||
class TCPIntermediate(TCP):
|
||||
def __init__(self, ipv6: bool, proxy: Proxy) -> None:
|
||||
super().__init__(ipv6, proxy)
|
||||
def __init__(self, ipv6: bool, proxy: Proxy, loop: Optional[asyncio.AbstractEventLoop] = None) -> None:
|
||||
super().__init__(ipv6, proxy, loop)
|
||||
|
||||
async def connect(self, address: Tuple[str, int]) -> None:
|
||||
await super().connect(address)
|
||||
|
|
|
|||
|
|
@ -17,6 +17,7 @@
|
|||
# You should have received a copy of the GNU Lesser General Public License
|
||||
# along with Pyrofork. If not, see <http://www.gnu.org/licenses/>.
|
||||
|
||||
import asyncio
|
||||
import logging
|
||||
import os
|
||||
from struct import pack, unpack
|
||||
|
|
@ -31,8 +32,8 @@ log = logging.getLogger(__name__)
|
|||
class TCPIntermediateO(TCP):
|
||||
RESERVED = (b"HEAD", b"POST", b"GET ", b"OPTI", b"\xee" * 4)
|
||||
|
||||
def __init__(self, ipv6: bool, proxy: Proxy) -> None:
|
||||
super().__init__(ipv6, proxy)
|
||||
def __init__(self, ipv6: bool, proxy: Proxy, loop: Optional[asyncio.AbstractEventLoop] = None) -> None:
|
||||
super().__init__(ipv6, proxy, loop)
|
||||
|
||||
self.encrypt = None
|
||||
self.decrypt = None
|
||||
|
|
|
|||
|
|
@ -95,7 +95,6 @@ class Dispatcher:
|
|||
|
||||
def __init__(self, client: "pyrogram.Client"):
|
||||
self.client = client
|
||||
self.loop = asyncio.get_event_loop()
|
||||
|
||||
self.handler_worker_tasks = []
|
||||
self.locks_list = []
|
||||
|
|
@ -271,7 +270,7 @@ class Dispatcher:
|
|||
self.locks_list.append(asyncio.Lock())
|
||||
|
||||
self.handler_worker_tasks.append(
|
||||
self.loop.create_task(self.handler_worker(self.locks_list[-1]))
|
||||
self.client.loop.create_task(self.handler_worker(self.locks_list[-1]))
|
||||
)
|
||||
|
||||
log.info("Started %s HandlerTasks", self.client.workers)
|
||||
|
|
@ -311,7 +310,7 @@ class Dispatcher:
|
|||
for lock in self.locks_list:
|
||||
lock.release()
|
||||
|
||||
self.loop.create_task(fn())
|
||||
self.client.loop.create_task(fn())
|
||||
|
||||
def remove_handler(self, handler, group: int):
|
||||
async def fn():
|
||||
|
|
@ -333,7 +332,7 @@ class Dispatcher:
|
|||
for lock in self.locks_list:
|
||||
lock.release()
|
||||
|
||||
self.loop.create_task(fn())
|
||||
self.client.loop.create_task(fn())
|
||||
|
||||
async def handler_worker(self, lock: asyncio.Lock):
|
||||
while True:
|
||||
|
|
@ -404,7 +403,7 @@ class Dispatcher:
|
|||
if inspect.iscoroutinefunction(handler.callback):
|
||||
await handler.callback(self.client, *args)
|
||||
else:
|
||||
await self.loop.run_in_executor(
|
||||
await self.client.loop.run_in_executor(
|
||||
self.client.executor,
|
||||
handler.callback,
|
||||
self.client,
|
||||
|
|
|
|||
|
|
@ -49,6 +49,6 @@ class Initialize:
|
|||
|
||||
await self.dispatcher.start()
|
||||
|
||||
self.updates_watchdog_task = asyncio.create_task(self.updates_watchdog())
|
||||
self.updates_watchdog_task = self.loop.create_task(self.updates_watchdog())
|
||||
|
||||
self.is_initialized = True
|
||||
|
|
|
|||
|
|
@ -61,6 +61,8 @@ class SendCode:
|
|||
)
|
||||
)
|
||||
except (PhoneMigrate, NetworkMigrate) as e:
|
||||
if not self.test_mode:
|
||||
await self.update_dc_option()
|
||||
# pylint: disable=access-member-before-definition
|
||||
await self.session.stop()
|
||||
|
||||
|
|
|
|||
|
|
@ -58,6 +58,8 @@ class SignInBot:
|
|||
)
|
||||
)
|
||||
except UserMigrate as e:
|
||||
if not self.test_mode:
|
||||
await self.update_dc_option()
|
||||
# pylint: disable=access-member-before-definition
|
||||
await self.session.stop()
|
||||
|
||||
|
|
|
|||
|
|
@ -80,6 +80,8 @@ class SignInQrcode:
|
|||
|
||||
return types.User._parse(self, r.authorization.user)
|
||||
if isinstance(r, raw.types.auth.LoginTokenMigrateTo):
|
||||
if not self.test_mode:
|
||||
await self.update_dc_option()
|
||||
# pylint: disable=access-member-before-definition
|
||||
await self.session.stop()
|
||||
|
||||
|
|
|
|||
|
|
@ -185,4 +185,4 @@ class DownloadMedia:
|
|||
if block:
|
||||
return await downloader
|
||||
else:
|
||||
asyncio.get_event_loop().create_task(downloader)
|
||||
self.loop.create_task(downloader)
|
||||
|
|
|
|||
|
|
@ -71,8 +71,7 @@ class Run:
|
|||
|
||||
app.run(main())
|
||||
"""
|
||||
loop = asyncio.get_event_loop()
|
||||
run = loop.run_until_complete
|
||||
run = self.loop.run_until_complete
|
||||
|
||||
if coroutine is not None:
|
||||
run(coroutine)
|
||||
|
|
|
|||
|
|
@ -45,13 +45,16 @@ class Auth:
|
|||
dc_id: int,
|
||||
test_mode: bool
|
||||
):
|
||||
self.client = client
|
||||
self.dc_id = dc_id
|
||||
self.test_mode = test_mode
|
||||
self.ipv6 = client.ipv6
|
||||
self.alt_port = client.alt_port
|
||||
self.proxy = client.proxy
|
||||
self.storage = client.storage
|
||||
self.connection_factory = client.connection_factory
|
||||
self.protocol_factory = client.protocol_factory
|
||||
self.loop = client.loop
|
||||
|
||||
self.connection: Optional[Connection] = None
|
||||
|
||||
|
|
@ -85,15 +88,25 @@ class Auth:
|
|||
|
||||
# The server may close the connection at any time, causing the auth key creation to fail.
|
||||
# If that happens, just try again up to MAX_RETRIES times.
|
||||
if not self.test_mode:
|
||||
address, port = await self.storage.get_dc_address(self.dc_id, self.ipv6)
|
||||
else:
|
||||
address = self.client.test_addr
|
||||
port = self.client.test_port
|
||||
if address is None or port is None:
|
||||
raise ValueError("Test address and port must be set for test mode.")
|
||||
while True:
|
||||
self.connection = self.connection_factory(
|
||||
dc_id=self.dc_id,
|
||||
test_mode=self.test_mode,
|
||||
server_ip=address,
|
||||
server_port=port,
|
||||
ipv6=self.ipv6,
|
||||
alt_port=self.alt_port,
|
||||
proxy=self.proxy,
|
||||
media=False,
|
||||
protocol_factory=self.protocol_factory
|
||||
protocol_factory=self.protocol_factory,
|
||||
loop=self.loop
|
||||
)
|
||||
|
||||
try:
|
||||
|
|
|
|||
|
|
@ -19,8 +19,10 @@
|
|||
|
||||
import asyncio
|
||||
import bisect
|
||||
import contextlib
|
||||
import logging
|
||||
import os
|
||||
from datetime import datetime, timedelta
|
||||
from hashlib import sha1
|
||||
from io import BytesIO
|
||||
from typing import Optional
|
||||
|
|
@ -56,6 +58,7 @@ class Session:
|
|||
ACKS_THRESHOLD = 10
|
||||
PING_INTERVAL = 5
|
||||
STORED_MSG_IDS_MAX_SIZE = 500
|
||||
RECONNECT_THRESHOLD = timedelta(seconds=10)
|
||||
|
||||
TRANSPORT_ERRORS = {
|
||||
404: "auth key not found",
|
||||
|
|
@ -101,24 +104,34 @@ class Session:
|
|||
|
||||
self.is_started = asyncio.Event()
|
||||
|
||||
self.loop = asyncio.get_event_loop()
|
||||
self.last_reconnect_attempt = None
|
||||
|
||||
async def start(self):
|
||||
if not self.test_mode:
|
||||
address, port = await self.client.storage.get_dc_address(self.dc_id, self.client.ipv6, self.is_media)
|
||||
else:
|
||||
address = self.client.test_addr
|
||||
port = self.client.test_port
|
||||
if address is None or port is None:
|
||||
raise ValueError("Test address and port must be set for test mode.")
|
||||
while True:
|
||||
self.connection = self.client.connection_factory(
|
||||
dc_id=self.dc_id,
|
||||
test_mode=self.test_mode,
|
||||
server_ip=address,
|
||||
server_port=port,
|
||||
ipv6=self.client.ipv6,
|
||||
alt_port=self.client.alt_port,
|
||||
proxy=self.client.proxy,
|
||||
media=self.is_media,
|
||||
protocol_factory=self.client.protocol_factory
|
||||
protocol_factory=self.client.protocol_factory,
|
||||
loop=self.client.loop
|
||||
)
|
||||
|
||||
try:
|
||||
await self.connection.connect()
|
||||
|
||||
self.recv_task = self.loop.create_task(self.recv_worker())
|
||||
self.recv_task = self.client.loop.create_task(self.recv_worker())
|
||||
|
||||
await self.send(raw.functions.Ping(ping_id=0), timeout=self.START_TIMEOUT)
|
||||
|
||||
|
|
@ -140,7 +153,7 @@ class Session:
|
|||
timeout=self.START_TIMEOUT
|
||||
)
|
||||
|
||||
self.ping_task = self.loop.create_task(self.ping_worker())
|
||||
self.ping_task = self.client.loop.create_task(self.ping_worker())
|
||||
|
||||
log.info("Session initialized: Layer %s", layer)
|
||||
log.info("Device: %s - %s", self.client.device_model, self.client.app_version)
|
||||
|
|
@ -175,7 +188,9 @@ class Session:
|
|||
await self.connection.close()
|
||||
|
||||
if self.recv_task:
|
||||
await self.recv_task
|
||||
self.recv_task.cancel()
|
||||
with contextlib.suppress(asyncio.CancelledError):
|
||||
await self.recv_task
|
||||
|
||||
if not self.is_media and callable(self.client.disconnect_handler):
|
||||
try:
|
||||
|
|
@ -186,12 +201,21 @@ class Session:
|
|||
log.info("Session stopped")
|
||||
|
||||
async def restart(self):
|
||||
now = datetime.now()
|
||||
if (
|
||||
self.last_reconnect_attempt
|
||||
and now - self.last_reconnect_attempt < self.RECONNECT_THRESHOLD
|
||||
):
|
||||
log.info("Reconnecting too frequently, sleeping for a while")
|
||||
await asyncio.sleep(5)
|
||||
|
||||
self.last_reconnect_attempt = now
|
||||
await self.stop()
|
||||
await self.start()
|
||||
|
||||
async def handle_packet(self, packet):
|
||||
try:
|
||||
data = await self.loop.run_in_executor(
|
||||
data = await self.client.loop.run_in_executor(
|
||||
pyrogram.crypto_executor,
|
||||
mtproto.unpack,
|
||||
BytesIO(packet),
|
||||
|
|
@ -201,7 +225,7 @@ class Session:
|
|||
)
|
||||
except ValueError as e:
|
||||
log.debug(e)
|
||||
self.loop.create_task(self.restart())
|
||||
self.client.loop.create_task(self.restart())
|
||||
return
|
||||
|
||||
messages = (
|
||||
|
|
@ -263,7 +287,7 @@ class Session:
|
|||
msg_id = msg.body.msg_id
|
||||
else:
|
||||
if self.client is not None:
|
||||
self.loop.create_task(self.client.handle_updates(msg.body))
|
||||
self.client.loop.create_task(self.client.handle_updates(msg.body))
|
||||
|
||||
if msg_id in self.results:
|
||||
self.results[msg_id].value = getattr(msg.body, "result", msg.body)
|
||||
|
|
@ -297,7 +321,7 @@ class Session:
|
|||
), False
|
||||
)
|
||||
except OSError:
|
||||
self.loop.create_task(self.restart())
|
||||
self.client.loop.create_task(self.restart())
|
||||
break
|
||||
except RPCError:
|
||||
pass
|
||||
|
|
@ -326,11 +350,11 @@ class Session:
|
|||
)
|
||||
|
||||
if self.is_started.is_set():
|
||||
self.loop.create_task(self.restart())
|
||||
self.client.loop.create_task(self.restart())
|
||||
|
||||
break
|
||||
|
||||
self.loop.create_task(self.handle_packet(packet))
|
||||
self.client.loop.create_task(self.handle_packet(packet))
|
||||
|
||||
log.info("NetworkTask stopped")
|
||||
|
||||
|
|
@ -343,7 +367,7 @@ class Session:
|
|||
|
||||
log.debug("Sent: %s", message)
|
||||
|
||||
payload = await self.loop.run_in_executor(
|
||||
payload = await self.client.loop.run_in_executor(
|
||||
pyrogram.crypto_executor,
|
||||
mtproto.pack,
|
||||
message,
|
||||
|
|
@ -415,8 +439,21 @@ class Session:
|
|||
|
||||
query_name = ".".join(inner_query.QUALNAME.split(".")[1:])
|
||||
|
||||
while True:
|
||||
while retries > 0:
|
||||
try:
|
||||
if (
|
||||
self.connection is None
|
||||
or self.connection.protocol is None
|
||||
or getattr(self.connection.protocol, "closed", True)
|
||||
):
|
||||
log.warning(
|
||||
"[%s] Connection is closed or not established. Attempting to reconnect...",
|
||||
self.client.name,
|
||||
)
|
||||
await self.restart()
|
||||
await asyncio.sleep(1)
|
||||
continue
|
||||
|
||||
return await self.send(query, timeout=timeout)
|
||||
except (FloodWait, FloodPremiumWait) as e:
|
||||
amount = e.value
|
||||
|
|
@ -429,15 +466,26 @@ class Session:
|
|||
|
||||
await asyncio.sleep(amount)
|
||||
except (OSError, InternalServerError, ServiceUnavailable) as e:
|
||||
retries -= 1
|
||||
if retries == 0:
|
||||
raise e from None
|
||||
raise e
|
||||
|
||||
(log.warning if retries < 2 else log.info)(
|
||||
'[%s] Retrying "%s" due to: %s',
|
||||
Session.MAX_RETRIES - retries + 1,
|
||||
Session.MAX_RETRIES - retries,
|
||||
query_name, str(e) or repr(e)
|
||||
)
|
||||
|
||||
if isinstance(e, OSError) and retries > 1:
|
||||
try:
|
||||
await self.restart()
|
||||
except Exception as restart_error:
|
||||
log.warning(
|
||||
"[%s] Failed to restart session: %s",
|
||||
self.client.name,
|
||||
str(restart_error) or repr(restart_error),
|
||||
)
|
||||
|
||||
await asyncio.sleep(0.5)
|
||||
|
||||
return await self.invoke(query, retries - 1, timeout)
|
||||
raise TimeoutError("Exceeded maximum number of retries")
|
||||
|
|
|
|||
|
|
@ -68,6 +68,12 @@ class FileStorage(SQLiteStorage):
|
|||
|
||||
version += 1
|
||||
|
||||
if version == 4:
|
||||
with self.conn:
|
||||
self.conn.executescript(self.UPDATE_DC_SCHEMA)
|
||||
|
||||
version += 1
|
||||
|
||||
self.version(version)
|
||||
|
||||
async def open(self):
|
||||
|
|
|
|||
|
|
@ -77,6 +77,7 @@ class MongoStorage(Storage):
|
|||
self._session = database['session']
|
||||
self._usernames = database['usernames']
|
||||
self._states = database['update_state']
|
||||
self._dc_options = database['dc_options']
|
||||
self._remove_peers = remove_peers
|
||||
|
||||
async def open(self):
|
||||
|
|
@ -221,6 +222,52 @@ class MongoStorage(Storage):
|
|||
|
||||
return get_input_peer(r['_id'], r['access_hash'], r['type'])
|
||||
|
||||
async def update_dc_address(
|
||||
self,
|
||||
value: Tuple[int, str, int, bool, bool] = object
|
||||
):
|
||||
"""
|
||||
Updates or inserts a data center address.
|
||||
|
||||
Parameters:
|
||||
value (Tuple[int, str, int, bool]): A tuple containing:
|
||||
- dc_id (int): Data center ID.
|
||||
- address (str): Address of the data center.
|
||||
- port (int): Port of the data center.
|
||||
- is_ipv6 (bool): Whether the address is IPv6.
|
||||
- is_media (bool): Whether it is a media data center.
|
||||
"""
|
||||
if value == object:
|
||||
return
|
||||
|
||||
await self._dc_options.update_one(
|
||||
{"$and": [
|
||||
{'dc_id': value[0]},
|
||||
{'is_ipv6': value[3]},
|
||||
{'is_media': value[4]}
|
||||
]},
|
||||
{'$set': {'address': value[1], 'port': value[2]}},
|
||||
upsert=True
|
||||
)
|
||||
|
||||
async def get_dc_address(
|
||||
self,
|
||||
dc_id: int,
|
||||
is_ipv6: bool,
|
||||
media: bool = False
|
||||
) -> Tuple[str, int]:
|
||||
if dc_id in [1,3,5] and media:
|
||||
media = False
|
||||
if dc_id in [4,5] and test_mode:
|
||||
test_mode = False
|
||||
r = await self._dc_options.find_one(
|
||||
{'dc_id': dc_id, 'is_ipv6': is_ipv6, 'is_media': media},
|
||||
{'address': 1, 'port': 1}
|
||||
)
|
||||
if r is None:
|
||||
return None
|
||||
return r['address'], r['port']
|
||||
|
||||
async def _get(self):
|
||||
attr = inspect.stack()[2].function
|
||||
d = await self._session.find_one({'_id': 0}, {attr: 1})
|
||||
|
|
|
|||
|
|
@ -97,6 +97,20 @@ END;
|
|||
"""
|
||||
|
||||
|
||||
UPDATE_DC_SCHEMA = """
|
||||
CREATE TABLE dc_options
|
||||
(
|
||||
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||
dc_id INTEGER,
|
||||
address TEXT,
|
||||
port INTEGER,
|
||||
is_ipv6 BOOLEAN,
|
||||
is_media BOOLEAN,
|
||||
UNIQUE(dc_id, is_ipv6, is_media)
|
||||
);
|
||||
"""
|
||||
|
||||
|
||||
def get_input_peer(peer_id: int, access_hash: int, peer_type: str):
|
||||
if peer_type in ["user", "bot"]:
|
||||
return raw.types.InputPeerUser(
|
||||
|
|
@ -121,6 +135,7 @@ def get_input_peer(peer_id: int, access_hash: int, peer_type: str):
|
|||
class SQLiteStorage(Storage):
|
||||
VERSION = 4
|
||||
USERNAME_TTL = 8 * 60 * 60
|
||||
UPDATE_DC_SCHEMA = globals().get("UPDATE_DC_SCHEMA", "")
|
||||
|
||||
def __init__(self, name: str):
|
||||
super().__init__(name)
|
||||
|
|
@ -131,6 +146,7 @@ class SQLiteStorage(Storage):
|
|||
with self.conn:
|
||||
self.conn.executescript(SCHEMA)
|
||||
self.conn.executescript(UNAME_SCHEMA)
|
||||
self.conn.executescript(self.UPDATE_DC_SCHEMA)
|
||||
|
||||
self.conn.execute(
|
||||
"INSERT INTO version VALUES (?)",
|
||||
|
|
@ -252,6 +268,60 @@ class SQLiteStorage(Storage):
|
|||
|
||||
return get_input_peer(*r)
|
||||
|
||||
async def update_dc_address(
|
||||
self,
|
||||
value: Tuple[int, str, int, bool, bool, bool] = object
|
||||
):
|
||||
"""
|
||||
Updates or inserts a data center address.
|
||||
|
||||
Parameters:
|
||||
value (Tuple[int, str, int, bool, bool, bool]): A tuple containing:
|
||||
- dc_id (int): Data center ID.
|
||||
- address (str): Address of the data center.
|
||||
- port (int): Port of the data center.
|
||||
- is_ipv6 (bool): Whether the address is IPv6.
|
||||
- is_media (bool): Whether it is a media data center.
|
||||
"""
|
||||
if value == object:
|
||||
return
|
||||
with self.conn:
|
||||
self.conn.execute(
|
||||
"""
|
||||
INSERT INTO dc_options (dc_id, address, port, is_ipv6, is_media)
|
||||
VALUES (?, ?, ?, ?, ?)
|
||||
ON CONFLICT(dc_id, is_ipv6, is_media)
|
||||
DO UPDATE SET address=excluded.address, port=excluded.port
|
||||
""",
|
||||
value
|
||||
)
|
||||
|
||||
async def get_dc_address(
|
||||
self,
|
||||
dc_id: int,
|
||||
is_ipv6: bool,
|
||||
media: bool = False
|
||||
) -> Tuple[str, int]:
|
||||
"""
|
||||
Retrieves the address of a data center.
|
||||
|
||||
Parameters:
|
||||
dc_id (int): Data center ID.
|
||||
is_ipv6 (bool): Whether the address is IPv6.
|
||||
media (bool): Whether it is a media data center.
|
||||
|
||||
Returns:
|
||||
Tuple[str, int]: A tuple containing the address and port of the data center.
|
||||
"""
|
||||
if dc_id in [1,3,5] and media:
|
||||
media = False
|
||||
r = self.conn.execute(
|
||||
"SELECT address, port FROM dc_options WHERE dc_id = ? AND is_ipv6 = ? AND is_media = ?",
|
||||
(dc_id, is_ipv6, media)
|
||||
).fetchone()
|
||||
|
||||
return r
|
||||
|
||||
def _get(self):
|
||||
attr = inspect.stack()[2].function
|
||||
|
||||
|
|
|
|||
|
|
@ -76,6 +76,21 @@ class Storage:
|
|||
async def get_peer_by_phone_number(self, phone_number: str):
|
||||
raise NotImplementedError
|
||||
|
||||
async def update_dc_address(
|
||||
self,
|
||||
value: Tuple[int, str, int, bool, bool] = object
|
||||
):
|
||||
raise NotImplementedError
|
||||
|
||||
async def get_dc_address(
|
||||
self,
|
||||
dc_id: int,
|
||||
is_ipv6: bool,
|
||||
test_mode: bool = False,
|
||||
media: bool = False
|
||||
):
|
||||
raise NotImplementedError
|
||||
|
||||
async def dc_id(self, value: int = object):
|
||||
raise NotImplementedError
|
||||
|
||||
|
|
|
|||
|
|
@ -43,13 +43,18 @@ PyromodConfig = SimpleNamespace(
|
|||
unallowed_click_alert=True,
|
||||
unallowed_click_alert_text=("[pyromod] You're not expected to click this button."),
|
||||
)
|
||||
|
||||
|
||||
async def ainput(prompt: str = "", *, hide: bool = False):
|
||||
|
||||
async def ainput(prompt: str = "", *, hide: bool = False, loop: Optional[asyncio.AbstractEventLoop] = None):
|
||||
"""Just like the built-in input, but async"""
|
||||
if isinstance(loop, asyncio.AbstractEventLoop):
|
||||
loop = loop
|
||||
else:
|
||||
loop = asyncio.get_event_loop()
|
||||
|
||||
with ThreadPoolExecutor(1) as executor:
|
||||
func = functools.partial(getpass if hide else input, prompt)
|
||||
return await asyncio.get_event_loop().run_in_executor(executor, func)
|
||||
return await loop.run_in_executor(executor, func)
|
||||
|
||||
|
||||
def get_input_media_from_file_id(
|
||||
|
|
|
|||
Loading…
Reference in a new issue