pyrofork: Refactor test_mode
Some checks failed
Pyrofork / build (macos-latest, 3.10) (push) Has been cancelled
Pyrofork / build (macos-latest, 3.11) (push) Has been cancelled
Pyrofork / build (macos-latest, 3.12) (push) Has been cancelled
Pyrofork / build (macos-latest, 3.13) (push) Has been cancelled
Pyrofork / build (macos-latest, 3.9) (push) Has been cancelled
Pyrofork / build (ubuntu-latest, 3.10) (push) Has been cancelled
Pyrofork / build (ubuntu-latest, 3.11) (push) Has been cancelled
Pyrofork / build (ubuntu-latest, 3.12) (push) Has been cancelled
Pyrofork / build (ubuntu-latest, 3.13) (push) Has been cancelled
Pyrofork / build (ubuntu-latest, 3.9) (push) Has been cancelled

Signed-off-by: wulan17 <wulan17@komodos.id>
This commit is contained in:
wulan17 2025-06-08 00:43:03 +07:00
parent 6e9e1740b0
commit fd17d1ec5d
No known key found for this signature in database
GPG key ID: 737814D4B5FF0420
8 changed files with 67 additions and 56 deletions

View file

@ -320,6 +320,8 @@ class Client(Methods):
self.max_message_cache_size = max_message_cache_size self.max_message_cache_size = max_message_cache_size
self.max_message_cache_size = max_message_cache_size self.max_message_cache_size = max_message_cache_size
self.max_business_user_connection_cache_size = max_business_user_connection_cache_size self.max_business_user_connection_cache_size = max_business_user_connection_cache_size
self.test_addr = None
self.test_port = None
self.executor = ThreadPoolExecutor(self.workers, thread_name_prefix="Handler") self.executor = ThreadPoolExecutor(self.workers, thread_name_prefix="Handler")
@ -401,6 +403,25 @@ class Client(Methods):
except ConnectionError: except ConnectionError:
pass pass
def set_dc(self, addr: str, port: int = 80):
"""Set the data center address and port.
Parameters:
addr (``str``):
The data center address, e.g.: "149.154.167.40".
port (``int``, *optional*):
The data center port, e.g.: 443.
Defaults to 80.
"""
if not isinstance(addr, str):
raise TypeError("addr must be a string")
if not isinstance(port, int):
raise TypeError("port must be an integer")
self.test_addr = addr
self.test_port = port
async def updates_watchdog(self): async def updates_watchdog(self):
while True: while True:
try: try:
@ -531,6 +552,7 @@ class Client(Methods):
break break
if isinstance(signed_in, User): if isinstance(signed_in, User):
if not self.test_mode:
is_dc_default = await self.check_dc_default() is_dc_default = await self.check_dc_default()
if is_dc_default: if is_dc_default:
log.info("Your session is using the default data center.") log.info("Your session is using the default data center.")
@ -884,57 +906,39 @@ class Client(Methods):
async def insert_default_dc_options(self): async def insert_default_dc_options(self):
for dc_id in range(1, 6): for dc_id in range(1, 6):
for is_ipv6 in (False, True): for is_ipv6 in (False, True):
if dc_id == 2: if dc_id in [2,4]:
for media in (False, True): for media in (False, True):
address, port = DataCenter(dc_id, False, is_ipv6, self.alt_port, media) address, port = DataCenter(dc_id, False, is_ipv6, self.alt_port, media)
await self.storage.update_dc_address( await self.storage.update_dc_address(
(dc_id, address, port, is_ipv6, False, media) (dc_id, address, port, is_ipv6, media)
)
for test_mode in (False, True):
address, port = DataCenter(dc_id, test_mode, is_ipv6, self.alt_port, False)
await self.storage.update_dc_address(
(dc_id, address, port, is_ipv6, test_mode, False)
)
elif dc_id == 4:
for media in (False, True):
address, port = DataCenter(dc_id, False, is_ipv6, False, media)
await self.storage.update_dc_address(
(dc_id, address, port, is_ipv6, False, media)
)
elif dc_id in [1,3]:
for test_mode in (False, True):
address, port = DataCenter(dc_id, test_mode, is_ipv6, False, False)
await self.storage.update_dc_address(
(dc_id, address, port, is_ipv6, test_mode, False)
) )
else: else:
address, port = DataCenter(dc_id, False, is_ipv6, False, False) address, port = DataCenter(dc_id, False, is_ipv6, False, False)
await self.storage.update_dc_address( await self.storage.update_dc_address(
(dc_id, address, port, is_ipv6, False, False) (dc_id, address, port, is_ipv6, False)
) )
async def update_dc_option(self): async def update_dc_option(self):
config = await self.invoke(raw.functions.help.GetConfig()) config = await self.invoke(raw.functions.help.GetConfig())
for option in config.dc_options: for option in config.dc_options:
await self.storage.update_dc_address( await self.storage.update_dc_address(
(option.id, option.address, option.port, option.is_ipv6, option.is_media, option.is_test) (option.id, option.address, option.port, option.is_ipv6, option.is_media)
) )
async def check_dc_default( async def check_dc_default(
self, self,
dc_id: int, dc_id: int,
is_ipv6: bool, is_ipv6: bool,
test_mode: bool = False,
media: bool = False media: bool = False
) -> bool: ) -> bool:
default_dc = DataCenter( default_dc = DataCenter(
dc_id, dc_id,
test_mode=test_mode, test_mode=False,
is_ipv6=is_ipv6, is_ipv6=is_ipv6,
alt_port=self.alt_port, alt_port=self.alt_port,
media=media media=media
) )
current_dc = await self.storage.get_dc_address(dc_id, is_ipv6, test_mode, media) current_dc = await self.storage.get_dc_address(dc_id, is_ipv6, media)
if current_dc is not None and current_dc[0] == default_dc.address: if current_dc is not None and current_dc[0] == default_dc.address:
return True return True

View file

@ -61,6 +61,7 @@ class SendCode:
) )
) )
except (PhoneMigrate, NetworkMigrate) as e: except (PhoneMigrate, NetworkMigrate) as e:
if not self.test_mode:
await self.update_dc_option() await self.update_dc_option()
# pylint: disable=access-member-before-definition # pylint: disable=access-member-before-definition
await self.session.stop() await self.session.stop()

View file

@ -58,6 +58,7 @@ class SignInBot:
) )
) )
except UserMigrate as e: except UserMigrate as e:
if not self.test_mode:
await self.update_dc_option() await self.update_dc_option()
# pylint: disable=access-member-before-definition # pylint: disable=access-member-before-definition
await self.session.stop() await self.session.stop()

View file

@ -80,6 +80,7 @@ class SignInQrcode:
return types.User._parse(self, r.authorization.user) return types.User._parse(self, r.authorization.user)
if isinstance(r, raw.types.auth.LoginTokenMigrateTo): if isinstance(r, raw.types.auth.LoginTokenMigrateTo):
if not self.test_mode:
await self.update_dc_option() await self.update_dc_option()
# pylint: disable=access-member-before-definition # pylint: disable=access-member-before-definition
await self.session.stop() await self.session.stop()

View file

@ -45,6 +45,7 @@ class Auth:
dc_id: int, dc_id: int,
test_mode: bool test_mode: bool
): ):
self.client = client
self.dc_id = dc_id self.dc_id = dc_id
self.test_mode = test_mode self.test_mode = test_mode
self.ipv6 = client.ipv6 self.ipv6 = client.ipv6
@ -87,7 +88,13 @@ class Auth:
# The server may close the connection at any time, causing the auth key creation to fail. # The server may close the connection at any time, causing the auth key creation to fail.
# If that happens, just try again up to MAX_RETRIES times. # If that happens, just try again up to MAX_RETRIES times.
address, port = await self.storage.get_dc_address(self.dc_id, self.ipv6, self.test_mode) if not self.test_mode:
address, port = await self.storage.get_dc_address(self.dc_id, self.ipv6)
else:
address = self.client.test_addr
port = self.client.test_port
if address is None or port is None:
raise ValueError("Test address and port must be set for test mode.")
while True: while True:
self.connection = self.connection_factory( self.connection = self.connection_factory(
dc_id=self.dc_id, dc_id=self.dc_id,

View file

@ -107,7 +107,13 @@ class Session:
self.last_reconnect_attempt = None self.last_reconnect_attempt = None
async def start(self): async def start(self):
address, port = await self.client.storage.get_dc_address(self.dc_id, self.client.ipv6, self.test_mode, self.is_media) if not self.test_mode:
address, port = await self.client.storage.get_dc_address(self.dc_id, self.client.ipv6, self.is_media)
else:
address = self.client.test_addr
port = self.client.test_port
if address is None or port is None:
raise ValueError("Test address and port must be set for test mode.")
while True: while True:
self.connection = self.client.connection_factory( self.connection = self.client.connection_factory(
dc_id=self.dc_id, dc_id=self.dc_id,

View file

@ -230,12 +230,11 @@ class MongoStorage(Storage):
Updates or inserts a data center address. Updates or inserts a data center address.
Parameters: Parameters:
value (Tuple[int, str, int, bool, bool]): A tuple containing: value (Tuple[int, str, int, bool]): A tuple containing:
- dc_id (int): Data center ID. - dc_id (int): Data center ID.
- address (str): Address of the data center. - address (str): Address of the data center.
- port (int): Port of the data center. - port (int): Port of the data center.
- is_ipv6 (bool): Whether the address is IPv6. - is_ipv6 (bool): Whether the address is IPv6.
- is_test (bool): Whether it is a test data center.
- is_media (bool): Whether it is a media data center. - is_media (bool): Whether it is a media data center.
""" """
if value == object: if value == object:
@ -245,8 +244,7 @@ class MongoStorage(Storage):
{"$and": [ {"$and": [
{'dc_id': value[0]}, {'dc_id': value[0]},
{'is_ipv6': value[3]}, {'is_ipv6': value[3]},
{'is_test': value[4]}, {'is_media': value[4]}
{'is_media': value[5]}
]}, ]},
{'$set': {'address': value[1], 'port': value[2]}}, {'$set': {'address': value[1], 'port': value[2]}},
upsert=True upsert=True
@ -256,7 +254,6 @@ class MongoStorage(Storage):
self, self,
dc_id: int, dc_id: int,
is_ipv6: bool, is_ipv6: bool,
test_mode: bool = False,
media: bool = False media: bool = False
) -> Tuple[str, int]: ) -> Tuple[str, int]:
if dc_id in [1,3,5] and media: if dc_id in [1,3,5] and media:
@ -264,7 +261,7 @@ class MongoStorage(Storage):
if dc_id in [4,5] and test_mode: if dc_id in [4,5] and test_mode:
test_mode = False test_mode = False
r = await self._dc_options.find_one( r = await self._dc_options.find_one(
{'dc_id': dc_id, 'is_ipv6': is_ipv6, 'is_test': test_mode, 'is_media': media}, {'dc_id': dc_id, 'is_ipv6': is_ipv6, 'is_media': media},
{'address': 1, 'port': 1} {'address': 1, 'port': 1}
) )
if r is None: if r is None:

View file

@ -105,9 +105,8 @@ CREATE TABLE dc_options
address TEXT, address TEXT,
port INTEGER, port INTEGER,
is_ipv6 BOOLEAN, is_ipv6 BOOLEAN,
is_test BOOLEAN,
is_media BOOLEAN, is_media BOOLEAN,
UNIQUE(dc_id, is_ipv6, is_test, is_media) UNIQUE(dc_id, is_ipv6, is_media)
); );
""" """
@ -282,7 +281,6 @@ class SQLiteStorage(Storage):
- address (str): Address of the data center. - address (str): Address of the data center.
- port (int): Port of the data center. - port (int): Port of the data center.
- is_ipv6 (bool): Whether the address is IPv6. - is_ipv6 (bool): Whether the address is IPv6.
- is_test (bool): Whether it is a test data center.
- is_media (bool): Whether it is a media data center. - is_media (bool): Whether it is a media data center.
""" """
if value == object: if value == object:
@ -290,9 +288,9 @@ class SQLiteStorage(Storage):
with self.conn: with self.conn:
self.conn.execute( self.conn.execute(
""" """
INSERT INTO dc_options (dc_id, address, port, is_ipv6, is_test, is_media) INSERT INTO dc_options (dc_id, address, port, is_ipv6, is_media)
VALUES (?, ?, ?, ?, ?, ?) VALUES (?, ?, ?, ?, ?)
ON CONFLICT(dc_id, is_ipv6, is_test, is_media) ON CONFLICT(dc_id, is_ipv6, is_media)
DO UPDATE SET address=excluded.address, port=excluded.port DO UPDATE SET address=excluded.address, port=excluded.port
""", """,
value value
@ -302,7 +300,6 @@ class SQLiteStorage(Storage):
self, self,
dc_id: int, dc_id: int,
is_ipv6: bool, is_ipv6: bool,
test_mode: bool = False,
media: bool = False media: bool = False
) -> Tuple[str, int]: ) -> Tuple[str, int]:
""" """
@ -311,7 +308,6 @@ class SQLiteStorage(Storage):
Parameters: Parameters:
dc_id (int): Data center ID. dc_id (int): Data center ID.
is_ipv6 (bool): Whether the address is IPv6. is_ipv6 (bool): Whether the address is IPv6.
test_mode (bool): Whether it is a test data center.
media (bool): Whether it is a media data center. media (bool): Whether it is a media data center.
Returns: Returns:
@ -319,11 +315,9 @@ class SQLiteStorage(Storage):
""" """
if dc_id in [1,3,5] and media: if dc_id in [1,3,5] and media:
media = False media = False
if dc_id in [4,5] and test_mode:
test_mode = False
r = self.conn.execute( r = self.conn.execute(
"SELECT address, port FROM dc_options WHERE dc_id = ? AND is_ipv6 = ? AND is_test = ? AND is_media = ?", "SELECT address, port FROM dc_options WHERE dc_id = ? AND is_ipv6 = ? AND is_media = ?",
(dc_id, is_ipv6, test_mode, media) (dc_id, is_ipv6, media)
).fetchone() ).fetchone()
return r return r