From fd17d1ec5d546dfda60090f918a48a57a7c2ef4a Mon Sep 17 00:00:00 2001 From: wulan17 Date: Sun, 8 Jun 2025 00:43:03 +0700 Subject: [PATCH] pyrofork: Refactor test_mode Signed-off-by: wulan17 --- pyrogram/client.py | 70 +++++++++++++------------ pyrogram/methods/auth/send_code.py | 3 +- pyrogram/methods/auth/sign_in_bot.py | 3 +- pyrogram/methods/auth/sign_in_qrcode.py | 3 +- pyrogram/session/auth.py | 9 +++- pyrogram/session/session.py | 8 ++- pyrogram/storage/mongo_storage.py | 9 ++-- pyrogram/storage/sqlite_storage.py | 18 +++---- 8 files changed, 67 insertions(+), 56 deletions(-) diff --git a/pyrogram/client.py b/pyrogram/client.py index 6d1da5ea..e49a5371 100644 --- a/pyrogram/client.py +++ b/pyrogram/client.py @@ -320,6 +320,8 @@ class Client(Methods): self.max_message_cache_size = max_message_cache_size self.max_message_cache_size = max_message_cache_size self.max_business_user_connection_cache_size = max_business_user_connection_cache_size + self.test_addr = None + self.test_port = None self.executor = ThreadPoolExecutor(self.workers, thread_name_prefix="Handler") @@ -401,6 +403,25 @@ class Client(Methods): except ConnectionError: pass + def set_dc(self, addr: str, port: int = 80): + """Set the data center address and port. + + Parameters: + addr (``str``): + The data center address, e.g.: "149.154.167.40". + + port (``int``, *optional*): + The data center port, e.g.: 443. + Defaults to 80. + """ + if not isinstance(addr, str): + raise TypeError("addr must be a string") + if not isinstance(port, int): + raise TypeError("port must be an integer") + + self.test_addr = addr + self.test_port = port + async def updates_watchdog(self): while True: try: @@ -531,15 +552,16 @@ class Client(Methods): break if isinstance(signed_in, User): - is_dc_default = await self.check_dc_default() - if is_dc_default: - log.info("Your session is using the default data center.") - log.info("Updating the data center options from GetConfig...") - await self.update_dc_option() - log.info("Data center updated successfully.") - log.info("Restarting the session to apply the changes...") - await self.session.stop() - await self.session.start() + if not self.test_mode: + is_dc_default = await self.check_dc_default() + if is_dc_default: + log.info("Your session is using the default data center.") + log.info("Updating the data center options from GetConfig...") + await self.update_dc_option() + log.info("Data center updated successfully.") + log.info("Restarting the session to apply the changes...") + await self.session.stop() + await self.session.start() return signed_in def set_parse_mode(self, parse_mode: Optional["enums.ParseMode"]): @@ -884,57 +906,39 @@ class Client(Methods): async def insert_default_dc_options(self): for dc_id in range(1, 6): for is_ipv6 in (False, True): - if dc_id == 2: + if dc_id in [2,4]: for media in (False, True): address, port = DataCenter(dc_id, False, is_ipv6, self.alt_port, media) await self.storage.update_dc_address( - (dc_id, address, port, is_ipv6, False, media) - ) - for test_mode in (False, True): - address, port = DataCenter(dc_id, test_mode, is_ipv6, self.alt_port, False) - await self.storage.update_dc_address( - (dc_id, address, port, is_ipv6, test_mode, False) - ) - elif dc_id == 4: - for media in (False, True): - address, port = DataCenter(dc_id, False, is_ipv6, False, media) - await self.storage.update_dc_address( - (dc_id, address, port, is_ipv6, False, media) - ) - elif dc_id in [1,3]: - for test_mode in (False, True): - address, port = DataCenter(dc_id, test_mode, is_ipv6, False, False) - await self.storage.update_dc_address( - (dc_id, address, port, is_ipv6, test_mode, False) + (dc_id, address, port, is_ipv6, media) ) else: address, port = DataCenter(dc_id, False, is_ipv6, False, False) await self.storage.update_dc_address( - (dc_id, address, port, is_ipv6, False, False) + (dc_id, address, port, is_ipv6, False) ) async def update_dc_option(self): config = await self.invoke(raw.functions.help.GetConfig()) for option in config.dc_options: await self.storage.update_dc_address( - (option.id, option.address, option.port, option.is_ipv6, option.is_media, option.is_test) + (option.id, option.address, option.port, option.is_ipv6, option.is_media) ) async def check_dc_default( self, dc_id: int, is_ipv6: bool, - test_mode: bool = False, media: bool = False ) -> bool: default_dc = DataCenter( dc_id, - test_mode=test_mode, + test_mode=False, is_ipv6=is_ipv6, alt_port=self.alt_port, media=media ) - current_dc = await self.storage.get_dc_address(dc_id, is_ipv6, test_mode, media) + current_dc = await self.storage.get_dc_address(dc_id, is_ipv6, media) if current_dc is not None and current_dc[0] == default_dc.address: return True diff --git a/pyrogram/methods/auth/send_code.py b/pyrogram/methods/auth/send_code.py index a0ba8114..accba1f0 100644 --- a/pyrogram/methods/auth/send_code.py +++ b/pyrogram/methods/auth/send_code.py @@ -61,7 +61,8 @@ class SendCode: ) ) except (PhoneMigrate, NetworkMigrate) as e: - await self.update_dc_option() + if not self.test_mode: + await self.update_dc_option() # pylint: disable=access-member-before-definition await self.session.stop() diff --git a/pyrogram/methods/auth/sign_in_bot.py b/pyrogram/methods/auth/sign_in_bot.py index a7ea7eee..e2b31b29 100644 --- a/pyrogram/methods/auth/sign_in_bot.py +++ b/pyrogram/methods/auth/sign_in_bot.py @@ -58,7 +58,8 @@ class SignInBot: ) ) except UserMigrate as e: - await self.update_dc_option() + if not self.test_mode: + await self.update_dc_option() # pylint: disable=access-member-before-definition await self.session.stop() diff --git a/pyrogram/methods/auth/sign_in_qrcode.py b/pyrogram/methods/auth/sign_in_qrcode.py index 48f8dcc6..6bd58ba8 100644 --- a/pyrogram/methods/auth/sign_in_qrcode.py +++ b/pyrogram/methods/auth/sign_in_qrcode.py @@ -80,7 +80,8 @@ class SignInQrcode: return types.User._parse(self, r.authorization.user) if isinstance(r, raw.types.auth.LoginTokenMigrateTo): - await self.update_dc_option() + if not self.test_mode: + await self.update_dc_option() # pylint: disable=access-member-before-definition await self.session.stop() diff --git a/pyrogram/session/auth.py b/pyrogram/session/auth.py index c376a3a1..28237cff 100644 --- a/pyrogram/session/auth.py +++ b/pyrogram/session/auth.py @@ -45,6 +45,7 @@ class Auth: dc_id: int, test_mode: bool ): + self.client = client self.dc_id = dc_id self.test_mode = test_mode self.ipv6 = client.ipv6 @@ -87,7 +88,13 @@ class Auth: # The server may close the connection at any time, causing the auth key creation to fail. # If that happens, just try again up to MAX_RETRIES times. - address, port = await self.storage.get_dc_address(self.dc_id, self.ipv6, self.test_mode) + if not self.test_mode: + address, port = await self.storage.get_dc_address(self.dc_id, self.ipv6) + else: + address = self.client.test_addr + port = self.client.test_port + if address is None or port is None: + raise ValueError("Test address and port must be set for test mode.") while True: self.connection = self.connection_factory( dc_id=self.dc_id, diff --git a/pyrogram/session/session.py b/pyrogram/session/session.py index e533f897..39f54133 100644 --- a/pyrogram/session/session.py +++ b/pyrogram/session/session.py @@ -107,7 +107,13 @@ class Session: self.last_reconnect_attempt = None async def start(self): - address, port = await self.client.storage.get_dc_address(self.dc_id, self.client.ipv6, self.test_mode, self.is_media) + if not self.test_mode: + address, port = await self.client.storage.get_dc_address(self.dc_id, self.client.ipv6, self.is_media) + else: + address = self.client.test_addr + port = self.client.test_port + if address is None or port is None: + raise ValueError("Test address and port must be set for test mode.") while True: self.connection = self.client.connection_factory( dc_id=self.dc_id, diff --git a/pyrogram/storage/mongo_storage.py b/pyrogram/storage/mongo_storage.py index f040ca67..ca0cfde0 100644 --- a/pyrogram/storage/mongo_storage.py +++ b/pyrogram/storage/mongo_storage.py @@ -230,12 +230,11 @@ class MongoStorage(Storage): Updates or inserts a data center address. Parameters: - value (Tuple[int, str, int, bool, bool]): A tuple containing: + value (Tuple[int, str, int, bool]): A tuple containing: - dc_id (int): Data center ID. - address (str): Address of the data center. - port (int): Port of the data center. - is_ipv6 (bool): Whether the address is IPv6. - - is_test (bool): Whether it is a test data center. - is_media (bool): Whether it is a media data center. """ if value == object: @@ -245,8 +244,7 @@ class MongoStorage(Storage): {"$and": [ {'dc_id': value[0]}, {'is_ipv6': value[3]}, - {'is_test': value[4]}, - {'is_media': value[5]} + {'is_media': value[4]} ]}, {'$set': {'address': value[1], 'port': value[2]}}, upsert=True @@ -256,7 +254,6 @@ class MongoStorage(Storage): self, dc_id: int, is_ipv6: bool, - test_mode: bool = False, media: bool = False ) -> Tuple[str, int]: if dc_id in [1,3,5] and media: @@ -264,7 +261,7 @@ class MongoStorage(Storage): if dc_id in [4,5] and test_mode: test_mode = False r = await self._dc_options.find_one( - {'dc_id': dc_id, 'is_ipv6': is_ipv6, 'is_test': test_mode, 'is_media': media}, + {'dc_id': dc_id, 'is_ipv6': is_ipv6, 'is_media': media}, {'address': 1, 'port': 1} ) if r is None: diff --git a/pyrogram/storage/sqlite_storage.py b/pyrogram/storage/sqlite_storage.py index 63a0a537..1d832d5a 100644 --- a/pyrogram/storage/sqlite_storage.py +++ b/pyrogram/storage/sqlite_storage.py @@ -105,9 +105,8 @@ CREATE TABLE dc_options address TEXT, port INTEGER, is_ipv6 BOOLEAN, - is_test BOOLEAN, is_media BOOLEAN, - UNIQUE(dc_id, is_ipv6, is_test, is_media) + UNIQUE(dc_id, is_ipv6, is_media) ); """ @@ -282,7 +281,6 @@ class SQLiteStorage(Storage): - address (str): Address of the data center. - port (int): Port of the data center. - is_ipv6 (bool): Whether the address is IPv6. - - is_test (bool): Whether it is a test data center. - is_media (bool): Whether it is a media data center. """ if value == object: @@ -290,9 +288,9 @@ class SQLiteStorage(Storage): with self.conn: self.conn.execute( """ - INSERT INTO dc_options (dc_id, address, port, is_ipv6, is_test, is_media) - VALUES (?, ?, ?, ?, ?, ?) - ON CONFLICT(dc_id, is_ipv6, is_test, is_media) + INSERT INTO dc_options (dc_id, address, port, is_ipv6, is_media) + VALUES (?, ?, ?, ?, ?) + ON CONFLICT(dc_id, is_ipv6, is_media) DO UPDATE SET address=excluded.address, port=excluded.port """, value @@ -302,7 +300,6 @@ class SQLiteStorage(Storage): self, dc_id: int, is_ipv6: bool, - test_mode: bool = False, media: bool = False ) -> Tuple[str, int]: """ @@ -311,7 +308,6 @@ class SQLiteStorage(Storage): Parameters: dc_id (int): Data center ID. is_ipv6 (bool): Whether the address is IPv6. - test_mode (bool): Whether it is a test data center. media (bool): Whether it is a media data center. Returns: @@ -319,11 +315,9 @@ class SQLiteStorage(Storage): """ if dc_id in [1,3,5] and media: media = False - if dc_id in [4,5] and test_mode: - test_mode = False r = self.conn.execute( - "SELECT address, port FROM dc_options WHERE dc_id = ? AND is_ipv6 = ? AND is_test = ? AND is_media = ?", - (dc_id, is_ipv6, test_mode, media) + "SELECT address, port FROM dc_options WHERE dc_id = ? AND is_ipv6 = ? AND is_media = ?", + (dc_id, is_ipv6, media) ).fetchone() return r