Compare commits

...

3 commits

Author SHA1 Message Date
wulan17
f824e72416
pyrofork: Refactor test_mode
Some checks failed
Pyrofork / build (macos-latest, 3.10) (push) Has been cancelled
Pyrofork / build (macos-latest, 3.11) (push) Has been cancelled
Pyrofork / build (macos-latest, 3.12) (push) Has been cancelled
Pyrofork / build (macos-latest, 3.13) (push) Has been cancelled
Pyrofork / build (macos-latest, 3.9) (push) Has been cancelled
Pyrofork / build (ubuntu-latest, 3.10) (push) Has been cancelled
Pyrofork / build (ubuntu-latest, 3.11) (push) Has been cancelled
Pyrofork / build (ubuntu-latest, 3.12) (push) Has been cancelled
Pyrofork / build (ubuntu-latest, 3.13) (push) Has been cancelled
Pyrofork / build (ubuntu-latest, 3.9) (push) Has been cancelled
Signed-off-by: wulan17 <wulan17@komodos.id>
2025-06-26 20:48:41 +07:00
wulan17
2212b98b81
pyrofork: Retrive dc address and port from GetConfig
Signed-off-by: wulan17 <wulan17@komodos.id>
2025-06-26 20:41:48 +07:00
wulan17
070afc0246
pyrofork: disable publish workflows
Signed-off-by: wulan17 <wulan17@nusantararom.org>
2025-06-26 20:21:43 +07:00
12 changed files with 241 additions and 43 deletions

View file

@ -1,40 +0,0 @@
# This workflow will upload a Python Package using Twine when a release is created
# For more information see: https://docs.github.com/en/actions/automating-builds-and-tests/building-and-testing-python#publishing-to-package-registries
# This workflow uses actions that are not certified by GitHub.
# They are provided by a third-party and are governed by
# separate terms of service, privacy policy, and support
# documentation.
name: Upload Python Package
on:
push:
tags:
- '*'
permissions:
contents: read
jobs:
deploy:
runs-on: ubuntu-latest
environment: release
permissions:
id-token: write
steps:
- uses: actions/checkout@v4
- name: Set up Python
uses: actions/setup-python@v5
with:
python-version: '3.10'
- name: Install dependencies
run: |
python -m pip install --upgrade pip
pip install -e '.[dev]'
- name: Build package
run: hatch build
- name: Publish package
uses: pypa/gh-action-pypi-publish@release/v1

View file

@ -61,6 +61,7 @@ from .file_id import FileId, FileType, ThumbnailSource
from .mime_types import mime_types
from .parser import Parser
from .session.internals import MsgId
from .session.internals.data_center import DataCenter
log = logging.getLogger(__name__)
MONGO_AVAIL = False
@ -315,6 +316,8 @@ class Client(Methods):
self.max_message_cache_size = max_message_cache_size
self.max_message_cache_size = max_message_cache_size
self.max_business_user_connection_cache_size = max_business_user_connection_cache_size
self.test_addr = None
self.test_port = None
self.executor = ThreadPoolExecutor(self.workers, thread_name_prefix="Handler")
@ -392,6 +395,25 @@ class Client(Methods):
except ConnectionError:
pass
def set_dc(self, addr: str, port: int = 80):
"""Set the data center address and port.
Parameters:
addr (``str``):
The data center address, e.g.: "149.154.167.40".
port (``int``, *optional*):
The data center port, e.g.: 443.
Defaults to 80.
"""
if not isinstance(addr, str):
raise TypeError("addr must be a string")
if not isinstance(port, int):
raise TypeError("port must be an integer")
self.test_addr = addr
self.test_port = port
async def updates_watchdog(self):
while True:
try:
@ -522,6 +544,16 @@ class Client(Methods):
break
if isinstance(signed_in, User):
if not self.test_mode:
is_dc_default = await self.check_dc_default()
if is_dc_default:
log.info("Your session is using the default data center.")
log.info("Updating the data center options from GetConfig...")
await self.update_dc_option()
log.info("Data center updated successfully.")
log.info("Restarting the session to apply the changes...")
await self.session.stop()
await self.session.start()
return signed_in
def set_parse_mode(self, parse_mode: Optional["enums.ParseMode"]):
@ -856,6 +888,45 @@ class Client(Methods):
break
except Exception as e:
print(e)
# Needed for migration from storage v3 to v4
if self.in_memory or self.session_string:
await self.insert_default_dc_options()
if not await self.storage.get_dc_address(await self.storage.dc_id(), self.ipv6):
log.info("No DC address found, inserting default DC options...")
await self.insert_default_dc_options()
async def insert_default_dc_options(self):
for dc_id in range(1, 6):
for is_ipv6 in (False, True):
if dc_id in [2,4]:
for media in (False, True):
address, port = DataCenter(dc_id, False, is_ipv6, self.alt_port, media)
await self.storage.update_dc_address(
(dc_id, address, port, is_ipv6, media, True)
)
else:
address, port = DataCenter(dc_id, False, is_ipv6, False, False)
await self.storage.update_dc_address(
(dc_id, address, port, is_ipv6, False, True)
)
async def update_dc_option(self):
config = await self.invoke(raw.functions.help.GetConfig())
for option in config.dc_options:
await self.storage.update_dc_address(
(option.id, option.address, option.port, option.is_ipv6, option.is_media, False)
)
async def check_dc_default(
self,
dc_id: int,
is_ipv6: bool,
media: bool = False
) -> bool:
current_dc = await self.storage.get_dc_address(dc_id, is_ipv6, media)
if current_dc is not None and current_dc[2]:
return True
return False
def is_excluded(self, exclude, module):
for e in exclude:

View file

@ -22,7 +22,6 @@ import logging
from typing import Optional, Type
from .transport import TCP, TCPAbridged
from ..session.internals import DataCenter
log = logging.getLogger(__name__)
@ -34,6 +33,8 @@ class Connection:
self,
dc_id: int,
test_mode: bool,
server_ip: str,
server_port: int,
ipv6: bool,
alt_port: bool,
proxy: dict,
@ -42,13 +43,14 @@ class Connection:
) -> None:
self.dc_id = dc_id
self.test_mode = test_mode
self.server_ip = server_ip
self.server_port = 5222 if alt_port else server_port
self.ipv6 = ipv6
self.alt_port = alt_port
self.proxy = proxy
self.media = media
self.protocol_factory = protocol_factory
self.address = DataCenter(dc_id, test_mode, ipv6, alt_port, media)
self.address = (server_ip, server_port)
self.protocol: Optional[TCP] = None
async def connect(self) -> None:

View file

@ -61,6 +61,8 @@ class SendCode:
)
)
except (PhoneMigrate, NetworkMigrate) as e:
if not self.test_mode:
await self.update_dc_option()
# pylint: disable=access-member-before-definition
await self.session.stop()

View file

@ -58,6 +58,8 @@ class SignInBot:
)
)
except UserMigrate as e:
if not self.test_mode:
await self.update_dc_option()
# pylint: disable=access-member-before-definition
await self.session.stop()

View file

@ -80,6 +80,8 @@ class SignInQrcode:
return types.User._parse(self, r.authorization.user)
if isinstance(r, raw.types.auth.LoginTokenMigrateTo):
if not self.test_mode:
await self.update_dc_option()
# pylint: disable=access-member-before-definition
await self.session.stop()

View file

@ -45,11 +45,13 @@ class Auth:
dc_id: int,
test_mode: bool
):
self.client = client
self.dc_id = dc_id
self.test_mode = test_mode
self.ipv6 = client.ipv6
self.alt_port = client.alt_port
self.proxy = client.proxy
self.storage = client.storage
self.connection_factory = client.connection_factory
self.protocol_factory = client.protocol_factory
@ -85,10 +87,19 @@ class Auth:
# The server may close the connection at any time, causing the auth key creation to fail.
# If that happens, just try again up to MAX_RETRIES times.
if not self.test_mode:
address, port, _ = await self.storage.get_dc_address(self.dc_id, self.ipv6)
else:
address = self.client.test_addr
port = self.client.test_port
if address is None or port is None:
raise ValueError("Test address and port must be set for test mode.")
while True:
self.connection = self.connection_factory(
dc_id=self.dc_id,
test_mode=self.test_mode,
server_ip=address,
server_port=port,
ipv6=self.ipv6,
alt_port=self.alt_port,
proxy=self.proxy,

View file

@ -104,10 +104,19 @@ class Session:
self.loop = asyncio.get_event_loop()
async def start(self):
if not self.test_mode:
address, port, _ = await self.client.storage.get_dc_address(self.dc_id, self.client.ipv6, self.is_media)
else:
address = self.client.test_addr
port = self.client.test_port
if address is None or port is None:
raise ValueError("Test address and port must be set for test mode.")
while True:
self.connection = self.client.connection_factory(
dc_id=self.dc_id,
test_mode=self.test_mode,
server_ip=address,
server_port=port,
ipv6=self.client.ipv6,
alt_port=self.client.alt_port,
proxy=self.client.proxy,

View file

@ -68,6 +68,12 @@ class FileStorage(SQLiteStorage):
version += 1
if version == 4:
with self.conn:
self.conn.executescript(self.UPDATE_DC_SCHEMA)
version += 1
self.version(version)
async def open(self):

View file

@ -77,6 +77,7 @@ class MongoStorage(Storage):
self._session = database['session']
self._usernames = database['usernames']
self._states = database['update_state']
self._dc_options = database['dc_options']
self._remove_peers = remove_peers
async def open(self):
@ -221,6 +222,51 @@ class MongoStorage(Storage):
return get_input_peer(r['_id'], r['access_hash'], r['type'])
async def update_dc_address(
self,
value: Tuple[int, str, int, bool, bool, bool] = object
):
"""
Updates or inserts a data center address.
Parameters:
value (Tuple[int, str, int, bool]): A tuple containing:
- dc_id (int): Data center ID.
- address (str): Address of the data center.
- port (int): Port of the data center.
- is_ipv6 (bool): Whether the address is IPv6.
- is_media (bool): Whether it is a media data center.
- is_default_ip (bool): Whether it is the dc IP address provided by library.
"""
if value == object:
return
await self._dc_options.update_one(
{"$and": [
{'dc_id': value[0]},
{'is_ipv6': value[3]},
{'is_media': value[4]}
]},
{'$set': {'address': value[1], 'port': value[2], 'is_default_ip': value[6]}},
upsert=True
)
async def get_dc_address(
self,
dc_id: int,
is_ipv6: bool,
media: bool = False
) -> Tuple[str, int]:
if dc_id in [1,3,5] and media:
media = False
r = await self._dc_options.find_one(
{'dc_id': dc_id, 'is_ipv6': is_ipv6, 'is_media': media},
{'address': 1, 'port': 1, 'is_default_ip': 1}
)
if r is None:
return None
return r['address'], r['port'], r.get['is_default_ip']
async def _get(self):
attr = inspect.stack()[2].function
d = await self._session.find_one({'_id': 0}, {attr: 1})

View file

@ -97,6 +97,21 @@ END;
"""
UPDATE_DC_SCHEMA = """
CREATE TABLE dc_options
(
id INTEGER PRIMARY KEY AUTOINCREMENT,
dc_id INTEGER,
address TEXT,
port INTEGER,
is_ipv6 BOOLEAN,
is_media BOOLEAN,
is_default_ip BOOLEAN,
UNIQUE(dc_id, is_ipv6, is_media)
);
"""
def get_input_peer(peer_id: int, access_hash: int, peer_type: str):
if peer_type in ["user", "bot"]:
return raw.types.InputPeerUser(
@ -121,6 +136,7 @@ def get_input_peer(peer_id: int, access_hash: int, peer_type: str):
class SQLiteStorage(Storage):
VERSION = 4
USERNAME_TTL = 8 * 60 * 60
UPDATE_DC_SCHEMA = globals().get("UPDATE_DC_SCHEMA", "")
def __init__(self, name: str):
super().__init__(name)
@ -131,6 +147,7 @@ class SQLiteStorage(Storage):
with self.conn:
self.conn.executescript(SCHEMA)
self.conn.executescript(UNAME_SCHEMA)
self.conn.executescript(self.UPDATE_DC_SCHEMA)
self.conn.execute(
"INSERT INTO version VALUES (?)",
@ -252,6 +269,61 @@ class SQLiteStorage(Storage):
return get_input_peer(*r)
async def update_dc_address(
self,
value: Tuple[int, str, int, bool, bool, bool, bool] = object
):
"""
Updates or inserts a data center address.
Parameters:
value (Tuple[int, str, int, bool, bool, bool]): A tuple containing:
- dc_id (int): Data center ID.
- address (str): Address of the data center.
- port (int): Port of the data center.
- is_ipv6 (bool): Whether the address is IPv6.
- is_media (bool): Whether it is a media data center.
- is_default_ip (bool): Whether it is the dc IP address provided by library.
"""
if value == object:
return
with self.conn:
self.conn.execute(
"""
INSERT INTO dc_options (dc_id, address, port, is_ipv6, is_media)
VALUES (?, ?, ?, ?, ?, ?, ?)
ON CONFLICT(dc_id, is_ipv6, is_media, is_default_ip)
DO UPDATE SET address=excluded.address, port=excluded.port
""",
value
)
async def get_dc_address(
self,
dc_id: int,
is_ipv6: bool,
media: bool = False
) -> Tuple[str, int]:
"""
Retrieves the address of a data center.
Parameters:
dc_id (int): Data center ID.
is_ipv6 (bool): Whether the address is IPv6.
media (bool): Whether it is a media data center.
Returns:
Tuple[str, int]: A tuple containing the address and port of the data center.
"""
if dc_id in [1,3,5] and media:
media = False
r = self.conn.execute(
"SELECT address, port, is_default_ip FROM dc_options WHERE dc_id = ? AND is_ipv6 = ? AND is_media = ?",
(dc_id, is_ipv6, media)
).fetchone()
return r
def _get(self):
attr = inspect.stack()[2].function

View file

@ -76,6 +76,21 @@ class Storage:
async def get_peer_by_phone_number(self, phone_number: str):
raise NotImplementedError
async def update_dc_address(
self,
value: Tuple[int, str, int, bool, bool] = object
):
raise NotImplementedError
async def get_dc_address(
self,
dc_id: int,
is_ipv6: bool,
test_mode: bool = False,
media: bool = False
):
raise NotImplementedError
async def dc_id(self, value: int = object):
raise NotImplementedError