pyrofork: Retrive dc address and port from GetConfig

Signed-off-by: wulan17 <wulan17@komodos.id>
This commit is contained in:
wulan17 2025-06-07 22:22:39 +07:00
parent 6c7de705ce
commit 3115408b12
No known key found for this signature in database
GPG key ID: 737814D4B5FF0420
11 changed files with 231 additions and 3 deletions

View file

@ -61,6 +61,7 @@ from .file_id import FileId, FileType, ThumbnailSource
from .mime_types import mime_types
from .parser import Parser
from .session.internals import MsgId
from .session.internals.data_center import DataCenter
log = logging.getLogger(__name__)
MONGO_AVAIL = False
@ -522,6 +523,15 @@ class Client(Methods):
break
if isinstance(signed_in, User):
is_dc_default = await self.check_dc_default()
if is_dc_default:
log.info("Your session is using the default data center.")
log.info("Updating the data center options from GetConfig...")
await self.update_dc_option()
log.info("Data center updated successfully.")
log.info("Restarting the session to apply the changes...")
await self.session.stop()
await self.session.start()
return signed_in
def set_parse_mode(self, parse_mode: Optional["enums.ParseMode"]):
@ -856,6 +866,62 @@ class Client(Methods):
break
except Exception as e:
print(e)
# Needed for migration from storage v3 to v4
if self.in_memory or self.session_string:
await self.insert_default_dc_options()
if not await self.storage.get_dc_address(await self.storage.dc_id(), self.ipv6):
log.info("No DC address found, inserting default DC options...")
await self.insert_default_dc_options()
async def insert_default_dc_options(self):
for dc_id in range(1, 6):
for is_ipv6 in (False, True):
if dc_id == 2:
for media in (False, True):
address, port = DataCenter(dc_id, False, is_ipv6, self.alt_port, media)
await self.storage.update_dc_address(
(dc_id, address, port, is_ipv6, False, media, True)
)
for test_mode in (False, True):
address, port = DataCenter(dc_id, test_mode, is_ipv6, self.alt_port, False)
await self.storage.update_dc_address(
(dc_id, address, port, is_ipv6, test_mode, False, True)
)
elif dc_id == 4:
for media in (False, True):
address, port = DataCenter(dc_id, False, is_ipv6, False, media)
await self.storage.update_dc_address(
(dc_id, address, port, is_ipv6, False, media, True)
)
elif dc_id in [1,3]:
for test_mode in (False, True):
address, port = DataCenter(dc_id, test_mode, is_ipv6, False, False)
await self.storage.update_dc_address(
(dc_id, address, port, is_ipv6, test_mode, False, True)
)
else:
address, port = DataCenter(dc_id, False, is_ipv6, False, False)
await self.storage.update_dc_address(
(dc_id, address, port, is_ipv6, False, False, True)
)
async def update_dc_option(self):
config = await self.invoke(raw.functions.help.GetConfig())
for option in config.dc_options:
await self.storage.update_dc_address(
(option.id, option.address, option.port, option.is_ipv6, option.is_media, option.is_test, False)
)
async def check_dc_default(
self,
dc_id: int,
is_ipv6: bool,
test_mode: bool = False,
media: bool = False
) -> bool:
current_dc = await self.storage.get_dc_address(dc_id, is_ipv6, test_mode, media)
if current_dc is not None and current_dc[2]:
return True
def is_excluded(self, exclude, module):
for e in exclude:

View file

@ -22,7 +22,6 @@ import logging
from typing import Optional, Type
from .transport import TCP, TCPAbridged
from ..session.internals import DataCenter
log = logging.getLogger(__name__)
@ -34,6 +33,8 @@ class Connection:
self,
dc_id: int,
test_mode: bool,
server_ip: str,
server_port: int,
ipv6: bool,
alt_port: bool,
proxy: dict,
@ -42,13 +43,14 @@ class Connection:
) -> None:
self.dc_id = dc_id
self.test_mode = test_mode
self.server_ip = server_ip
self.server_port = 5222 if alt_port else server_port
self.ipv6 = ipv6
self.alt_port = alt_port
self.proxy = proxy
self.media = media
self.protocol_factory = protocol_factory
self.address = DataCenter(dc_id, test_mode, ipv6, alt_port, media)
self.address = (server_ip, server_port)
self.protocol: Optional[TCP] = None
async def connect(self) -> None:

View file

@ -61,6 +61,7 @@ class SendCode:
)
)
except (PhoneMigrate, NetworkMigrate) as e:
await self.update_dc_option()
# pylint: disable=access-member-before-definition
await self.session.stop()

View file

@ -58,6 +58,7 @@ class SignInBot:
)
)
except UserMigrate as e:
await self.update_dc_option()
# pylint: disable=access-member-before-definition
await self.session.stop()

View file

@ -80,6 +80,7 @@ class SignInQrcode:
return types.User._parse(self, r.authorization.user)
if isinstance(r, raw.types.auth.LoginTokenMigrateTo):
await self.update_dc_option()
# pylint: disable=access-member-before-definition
await self.session.stop()

View file

@ -50,6 +50,7 @@ class Auth:
self.ipv6 = client.ipv6
self.alt_port = client.alt_port
self.proxy = client.proxy
self.storage = client.storage
self.connection_factory = client.connection_factory
self.protocol_factory = client.protocol_factory
@ -85,10 +86,13 @@ class Auth:
# The server may close the connection at any time, causing the auth key creation to fail.
# If that happens, just try again up to MAX_RETRIES times.
address, port, _ = await self.storage.get_dc_address(self.dc_id, self.ipv6, self.test_mode)
while True:
self.connection = self.connection_factory(
dc_id=self.dc_id,
test_mode=self.test_mode,
server_ip=address,
server_port=port,
ipv6=self.ipv6,
alt_port=self.alt_port,
proxy=self.proxy,

View file

@ -104,10 +104,13 @@ class Session:
self.loop = asyncio.get_event_loop()
async def start(self):
address, port, _ = await self.client.storage.get_dc_address(self.dc_id, self.client.ipv6, self.test_mode, self.is_media)
while True:
self.connection = self.client.connection_factory(
dc_id=self.dc_id,
test_mode=self.test_mode,
server_ip=address,
server_port=port,
ipv6=self.client.ipv6,
alt_port=self.client.alt_port,
proxy=self.client.proxy,

View file

@ -68,6 +68,12 @@ class FileStorage(SQLiteStorage):
version += 1
if version == 4:
with self.conn:
self.conn.executescript(self.UPDATE_DC_SCHEMA)
version += 1
self.version(version)
async def open(self):

View file

@ -77,6 +77,7 @@ class MongoStorage(Storage):
self._session = database['session']
self._usernames = database['usernames']
self._states = database['update_state']
self._dc_options = database['dc_options']
self._remove_peers = remove_peers
async def open(self):
@ -221,6 +222,56 @@ class MongoStorage(Storage):
return get_input_peer(r['_id'], r['access_hash'], r['type'])
async def update_dc_address(
self,
value: Tuple[int, str, int, bool, bool, bool] = object
):
"""
Updates or inserts a data center address.
Parameters:
value (Tuple[int, str, int, bool, bool]): A tuple containing:
- dc_id (int): Data center ID.
- address (str): Address of the data center.
- port (int): Port of the data center.
- is_ipv6 (bool): Whether the address is IPv6.
- is_test (bool): Whether it is a test data center.
- is_media (bool): Whether it is a media data center.
- is_default_ip (bool): Whether it is the dc IP address provided by library.
"""
if value == object:
return
await self._dc_options.update_one(
{"$and": [
{'dc_id': value[0]},
{'is_ipv6': value[3]},
{'is_test': value[4]},
{'is_media': value[5]}
]},
{'$set': {'address': value[1], 'port': value[2], 'is_default_ip': value[6]}},
upsert=True
)
async def get_dc_address(
self,
dc_id: int,
is_ipv6: bool,
test_mode: bool = False,
media: bool = False
) -> Tuple[str, int]:
if dc_id in [1,3,5] and media:
media = False
if dc_id in [4,5] and test_mode:
test_mode = False
r = await self._dc_options.find_one(
{'dc_id': dc_id, 'is_ipv6': is_ipv6, 'is_test': test_mode, 'is_media': media},
{'address': 1, 'port': 1, 'is_default_ip': 1}
)
if r is None:
return None
return r['address'], r['port'], r['is_default_ip']
async def _get(self):
attr = inspect.stack()[2].function
d = await self._session.find_one({'_id': 0}, {attr: 1})

View file

@ -97,6 +97,22 @@ END;
"""
UPDATE_DC_SCHEMA = """
CREATE TABLE dc_options
(
id INTEGER PRIMARY KEY AUTOINCREMENT,
dc_id INTEGER,
address TEXT,
port INTEGER,
is_ipv6 BOOLEAN,
is_test BOOLEAN,
is_media BOOLEAN,
is_default_ip BOOLEAN,
UNIQUE(dc_id, is_ipv6, is_test, is_media)
);
"""
def get_input_peer(peer_id: int, access_hash: int, peer_type: str):
if peer_type in ["user", "bot"]:
return raw.types.InputPeerUser(
@ -121,6 +137,7 @@ def get_input_peer(peer_id: int, access_hash: int, peer_type: str):
class SQLiteStorage(Storage):
VERSION = 4
USERNAME_TTL = 8 * 60 * 60
UPDATE_DC_SCHEMA = globals().get("UPDATE_DC_SCHEMA", "")
def __init__(self, name: str):
super().__init__(name)
@ -131,6 +148,7 @@ class SQLiteStorage(Storage):
with self.conn:
self.conn.executescript(SCHEMA)
self.conn.executescript(UNAME_SCHEMA)
self.conn.executescript(self.UPDATE_DC_SCHEMA)
self.conn.execute(
"INSERT INTO version VALUES (?)",
@ -252,6 +270,66 @@ class SQLiteStorage(Storage):
return get_input_peer(*r)
async def update_dc_address(
self,
value: Tuple[int, str, int, bool, bool, bool, bool] = object
):
"""
Updates or inserts a data center address.
Parameters:
value (Tuple[int, str, int, bool, bool, bool]): A tuple containing:
- dc_id (int): Data center ID.
- address (str): Address of the data center.
- port (int): Port of the data center.
- is_ipv6 (bool): Whether the address is IPv6.
- is_test (bool): Whether it is a test data center.
- is_media (bool): Whether it is a media data center.
- is_default_ip (bool): Whether it is the dc IP address provided by library.
"""
if value == object:
return
with self.conn:
self.conn.execute(
"""
INSERT INTO dc_options (dc_id, address, port, is_ipv6, is_test, is_media, is_default_ip)
VALUES (?, ?, ?, ?, ?, ?, ?)
ON CONFLICT(dc_id, is_ipv6, is_test, is_media)
DO UPDATE SET address=excluded.address, port=excluded.port
""",
value
)
async def get_dc_address(
self,
dc_id: int,
is_ipv6: bool,
test_mode: bool = False,
media: bool = False
) -> Tuple[str, int]:
"""
Retrieves the address of a data center.
Parameters:
dc_id (int): Data center ID.
is_ipv6 (bool): Whether the address is IPv6.
test_mode (bool): Whether it is a test data center.
media (bool): Whether it is a media data center.
Returns:
Tuple[str, int]: A tuple containing the address and port of the data center.
"""
if dc_id in [1,3,5] and media:
media = False
if dc_id in [4,5] and test_mode:
test_mode = False
r = self.conn.execute(
"SELECT address, port, is_default_ip FROM dc_options WHERE dc_id = ? AND is_ipv6 = ? AND is_test = ? AND is_media = ?",
(dc_id, is_ipv6, test_mode, media)
).fetchone()
return r
def _get(self):
attr = inspect.stack()[2].function

View file

@ -76,6 +76,21 @@ class Storage:
async def get_peer_by_phone_number(self, phone_number: str):
raise NotImplementedError
async def update_dc_address(
self,
value: Tuple[int, str, int, bool, bool] = object
):
raise NotImplementedError
async def get_dc_address(
self,
dc_id: int,
is_ipv6: bool,
test_mode: bool = False,
media: bool = False
):
raise NotImplementedError
async def dc_id(self, value: int = object):
raise NotImplementedError