diff --git a/pyrogram/client.py b/pyrogram/client.py index 51a2c33e..ecaa5983 100644 --- a/pyrogram/client.py +++ b/pyrogram/client.py @@ -316,6 +316,8 @@ class Client(Methods): self.max_message_cache_size = max_message_cache_size self.max_message_cache_size = max_message_cache_size self.max_business_user_connection_cache_size = max_business_user_connection_cache_size + self.test_addr = None + self.test_port = None self.executor = ThreadPoolExecutor(self.workers, thread_name_prefix="Handler") @@ -393,6 +395,25 @@ class Client(Methods): except ConnectionError: pass + def set_dc(self, addr: str, port: int = 80): + """Set the data center address and port. + + Parameters: + addr (``str``): + The data center address, e.g.: "149.154.167.40". + + port (``int``, *optional*): + The data center port, e.g.: 443. + Defaults to 80. + """ + if not isinstance(addr, str): + raise TypeError("addr must be a string") + if not isinstance(port, int): + raise TypeError("port must be an integer") + + self.test_addr = addr + self.test_port = port + async def updates_watchdog(self): while True: try: @@ -523,15 +544,16 @@ class Client(Methods): break if isinstance(signed_in, User): - is_dc_default = await self.check_dc_default() - if is_dc_default: - log.info("Your session is using the default data center.") - log.info("Updating the data center options from GetConfig...") - await self.update_dc_option() - log.info("Data center updated successfully.") - log.info("Restarting the session to apply the changes...") - await self.session.stop() - await self.session.start() + if not self.test_mode: + is_dc_default = await self.check_dc_default() + if is_dc_default: + log.info("Your session is using the default data center.") + log.info("Updating the data center options from GetConfig...") + await self.update_dc_option() + log.info("Data center updated successfully.") + log.info("Restarting the session to apply the changes...") + await self.session.stop() + await self.session.start() return signed_in def set_parse_mode(self, parse_mode: Optional["enums.ParseMode"]): @@ -876,52 +898,35 @@ class Client(Methods): async def insert_default_dc_options(self): for dc_id in range(1, 6): for is_ipv6 in (False, True): - if dc_id == 2: + if dc_id in [2,4]: for media in (False, True): address, port = DataCenter(dc_id, False, is_ipv6, self.alt_port, media) await self.storage.update_dc_address( - (dc_id, address, port, is_ipv6, False, media, True) - ) - for test_mode in (False, True): - address, port = DataCenter(dc_id, test_mode, is_ipv6, self.alt_port, False) - await self.storage.update_dc_address( - (dc_id, address, port, is_ipv6, test_mode, False, True) - ) - elif dc_id == 4: - for media in (False, True): - address, port = DataCenter(dc_id, False, is_ipv6, False, media) - await self.storage.update_dc_address( - (dc_id, address, port, is_ipv6, False, media, True) - ) - elif dc_id in [1,3]: - for test_mode in (False, True): - address, port = DataCenter(dc_id, test_mode, is_ipv6, False, False) - await self.storage.update_dc_address( - (dc_id, address, port, is_ipv6, test_mode, False, True) + (dc_id, address, port, is_ipv6, media, True) ) else: address, port = DataCenter(dc_id, False, is_ipv6, False, False) await self.storage.update_dc_address( - (dc_id, address, port, is_ipv6, False, False, True) + (dc_id, address, port, is_ipv6, False, True) ) async def update_dc_option(self): config = await self.invoke(raw.functions.help.GetConfig()) for option in config.dc_options: await self.storage.update_dc_address( - (option.id, option.address, option.port, option.is_ipv6, option.is_media, option.is_test, False) + (option.id, option.address, option.port, option.is_ipv6, option.is_media, False) ) async def check_dc_default( self, dc_id: int, is_ipv6: bool, - test_mode: bool = False, media: bool = False ) -> bool: - current_dc = await self.storage.get_dc_address(dc_id, is_ipv6, test_mode, media) + current_dc = await self.storage.get_dc_address(dc_id, is_ipv6, media) if current_dc is not None and current_dc[2]: return True + return False def is_excluded(self, exclude, module): for e in exclude: diff --git a/pyrogram/methods/auth/send_code.py b/pyrogram/methods/auth/send_code.py index a0ba8114..accba1f0 100644 --- a/pyrogram/methods/auth/send_code.py +++ b/pyrogram/methods/auth/send_code.py @@ -61,7 +61,8 @@ class SendCode: ) ) except (PhoneMigrate, NetworkMigrate) as e: - await self.update_dc_option() + if not self.test_mode: + await self.update_dc_option() # pylint: disable=access-member-before-definition await self.session.stop() diff --git a/pyrogram/methods/auth/sign_in_bot.py b/pyrogram/methods/auth/sign_in_bot.py index a7ea7eee..e2b31b29 100644 --- a/pyrogram/methods/auth/sign_in_bot.py +++ b/pyrogram/methods/auth/sign_in_bot.py @@ -58,7 +58,8 @@ class SignInBot: ) ) except UserMigrate as e: - await self.update_dc_option() + if not self.test_mode: + await self.update_dc_option() # pylint: disable=access-member-before-definition await self.session.stop() diff --git a/pyrogram/methods/auth/sign_in_qrcode.py b/pyrogram/methods/auth/sign_in_qrcode.py index 48f8dcc6..6bd58ba8 100644 --- a/pyrogram/methods/auth/sign_in_qrcode.py +++ b/pyrogram/methods/auth/sign_in_qrcode.py @@ -80,7 +80,8 @@ class SignInQrcode: return types.User._parse(self, r.authorization.user) if isinstance(r, raw.types.auth.LoginTokenMigrateTo): - await self.update_dc_option() + if not self.test_mode: + await self.update_dc_option() # pylint: disable=access-member-before-definition await self.session.stop() diff --git a/pyrogram/session/auth.py b/pyrogram/session/auth.py index ea5b9ede..99ca737c 100644 --- a/pyrogram/session/auth.py +++ b/pyrogram/session/auth.py @@ -45,6 +45,7 @@ class Auth: dc_id: int, test_mode: bool ): + self.client = client self.dc_id = dc_id self.test_mode = test_mode self.ipv6 = client.ipv6 @@ -86,7 +87,13 @@ class Auth: # The server may close the connection at any time, causing the auth key creation to fail. # If that happens, just try again up to MAX_RETRIES times. - address, port, _ = await self.storage.get_dc_address(self.dc_id, self.ipv6, self.test_mode) + if not self.test_mode: + address, port, _ = await self.storage.get_dc_address(self.dc_id, self.ipv6) + else: + address = self.client.test_addr + port = self.client.test_port + if address is None or port is None: + raise ValueError("Test address and port must be set for test mode.") while True: self.connection = self.connection_factory( dc_id=self.dc_id, diff --git a/pyrogram/session/session.py b/pyrogram/session/session.py index 9a14077f..33094f0e 100644 --- a/pyrogram/session/session.py +++ b/pyrogram/session/session.py @@ -104,7 +104,13 @@ class Session: self.loop = asyncio.get_event_loop() async def start(self): - address, port, _ = await self.client.storage.get_dc_address(self.dc_id, self.client.ipv6, self.test_mode, self.is_media) + if not self.test_mode: + address, port, _ = await self.client.storage.get_dc_address(self.dc_id, self.client.ipv6, self.is_media) + else: + address = self.client.test_addr + port = self.client.test_port + if address is None or port is None: + raise ValueError("Test address and port must be set for test mode.") while True: self.connection = self.client.connection_factory( dc_id=self.dc_id, diff --git a/pyrogram/storage/mongo_storage.py b/pyrogram/storage/mongo_storage.py index 57a53bc5..cf955267 100644 --- a/pyrogram/storage/mongo_storage.py +++ b/pyrogram/storage/mongo_storage.py @@ -230,12 +230,11 @@ class MongoStorage(Storage): Updates or inserts a data center address. Parameters: - value (Tuple[int, str, int, bool, bool]): A tuple containing: + value (Tuple[int, str, int, bool]): A tuple containing: - dc_id (int): Data center ID. - address (str): Address of the data center. - port (int): Port of the data center. - is_ipv6 (bool): Whether the address is IPv6. - - is_test (bool): Whether it is a test data center. - is_media (bool): Whether it is a media data center. - is_default_ip (bool): Whether it is the dc IP address provided by library. """ @@ -246,10 +245,9 @@ class MongoStorage(Storage): {"$and": [ {'dc_id': value[0]}, {'is_ipv6': value[3]}, - {'is_test': value[4]}, - {'is_media': value[5]} + {'is_media': value[4]} ]}, - {'$set': {'address': value[1], 'port': value[2], 'is_default_ip': value[6]}}, + {'$set': {'address': value[1], 'port': value[2], 'is_default_ip': value[5]}}, upsert=True ) @@ -257,15 +255,12 @@ class MongoStorage(Storage): self, dc_id: int, is_ipv6: bool, - test_mode: bool = False, media: bool = False ) -> Tuple[str, int]: if dc_id in [1,3,5] and media: media = False - if dc_id in [4,5] and test_mode: - test_mode = False r = await self._dc_options.find_one( - {'dc_id': dc_id, 'is_ipv6': is_ipv6, 'is_test': test_mode, 'is_media': media}, + {'dc_id': dc_id, 'is_ipv6': is_ipv6, 'is_media': media}, {'address': 1, 'port': 1, 'is_default_ip': 1} ) if r is None: diff --git a/pyrogram/storage/sqlite_storage.py b/pyrogram/storage/sqlite_storage.py index ac339017..464e306a 100644 --- a/pyrogram/storage/sqlite_storage.py +++ b/pyrogram/storage/sqlite_storage.py @@ -105,10 +105,9 @@ CREATE TABLE dc_options address TEXT, port INTEGER, is_ipv6 BOOLEAN, - is_test BOOLEAN, is_media BOOLEAN, is_default_ip BOOLEAN, - UNIQUE(dc_id, is_ipv6, is_test, is_media) + UNIQUE(dc_id, is_ipv6, is_media) ); """ @@ -283,7 +282,6 @@ class SQLiteStorage(Storage): - address (str): Address of the data center. - port (int): Port of the data center. - is_ipv6 (bool): Whether the address is IPv6. - - is_test (bool): Whether it is a test data center. - is_media (bool): Whether it is a media data center. - is_default_ip (bool): Whether it is the dc IP address provided by library. """ @@ -292,9 +290,9 @@ class SQLiteStorage(Storage): with self.conn: self.conn.execute( """ - INSERT INTO dc_options (dc_id, address, port, is_ipv6, is_test, is_media, is_default_ip) - VALUES (?, ?, ?, ?, ?, ?, ?) - ON CONFLICT(dc_id, is_ipv6, is_test, is_media) + INSERT INTO dc_options (dc_id, address, port, is_ipv6, is_media, is_default_ip) + VALUES (?, ?, ?, ?, ?, ?) + ON CONFLICT(dc_id, is_ipv6, is_media) DO UPDATE SET address=excluded.address, port=excluded.port """, value @@ -304,7 +302,6 @@ class SQLiteStorage(Storage): self, dc_id: int, is_ipv6: bool, - test_mode: bool = False, media: bool = False ) -> Tuple[str, int]: """ @@ -313,7 +310,6 @@ class SQLiteStorage(Storage): Parameters: dc_id (int): Data center ID. is_ipv6 (bool): Whether the address is IPv6. - test_mode (bool): Whether it is a test data center. media (bool): Whether it is a media data center. Returns: @@ -321,11 +317,9 @@ class SQLiteStorage(Storage): """ if dc_id in [1,3,5] and media: media = False - if dc_id in [4,5] and test_mode: - test_mode = False r = self.conn.execute( - "SELECT address, port, is_default_ip FROM dc_options WHERE dc_id = ? AND is_ipv6 = ? AND is_test = ? AND is_media = ?", - (dc_id, is_ipv6, test_mode, media) + "SELECT address, port, is_default_ip FROM dc_options WHERE dc_id = ? AND is_ipv6 = ? AND is_media = ?", + (dc_id, is_ipv6, media) ).fetchone() return r