diff --git a/pyrogram/client.py b/pyrogram/client.py index e2fabbdb..6d1da5ea 100644 --- a/pyrogram/client.py +++ b/pyrogram/client.py @@ -61,6 +61,7 @@ from .file_id import FileId, FileType, ThumbnailSource from .mime_types import mime_types from .parser import Parser from .session.internals import MsgId +from .session.internals.data_center import DataCenter log = logging.getLogger(__name__) MONGO_AVAIL = False @@ -530,6 +531,15 @@ class Client(Methods): break if isinstance(signed_in, User): + is_dc_default = await self.check_dc_default() + if is_dc_default: + log.info("Your session is using the default data center.") + log.info("Updating the data center options from GetConfig...") + await self.update_dc_option() + log.info("Data center updated successfully.") + log.info("Restarting the session to apply the changes...") + await self.session.stop() + await self.session.start() return signed_in def set_parse_mode(self, parse_mode: Optional["enums.ParseMode"]): @@ -864,6 +874,69 @@ class Client(Methods): break except Exception as e: print(e) + # Needed for migration from storage v3 to v4 + if self.in_memory or self.session_string: + await self.insert_default_dc_options() + if not await self.storage.get_dc_address(await self.storage.dc_id(), self.ipv6): + log.info("No DC address found, inserting default DC options...") + await self.insert_default_dc_options() + + async def insert_default_dc_options(self): + for dc_id in range(1, 6): + for is_ipv6 in (False, True): + if dc_id == 2: + for media in (False, True): + address, port = DataCenter(dc_id, False, is_ipv6, self.alt_port, media) + await self.storage.update_dc_address( + (dc_id, address, port, is_ipv6, False, media) + ) + for test_mode in (False, True): + address, port = DataCenter(dc_id, test_mode, is_ipv6, self.alt_port, False) + await self.storage.update_dc_address( + (dc_id, address, port, is_ipv6, test_mode, False) + ) + elif dc_id == 4: + for media in (False, True): + address, port = DataCenter(dc_id, False, is_ipv6, False, media) + await self.storage.update_dc_address( + (dc_id, address, port, is_ipv6, False, media) + ) + elif dc_id in [1,3]: + for test_mode in (False, True): + address, port = DataCenter(dc_id, test_mode, is_ipv6, False, False) + await self.storage.update_dc_address( + (dc_id, address, port, is_ipv6, test_mode, False) + ) + else: + address, port = DataCenter(dc_id, False, is_ipv6, False, False) + await self.storage.update_dc_address( + (dc_id, address, port, is_ipv6, False, False) + ) + + async def update_dc_option(self): + config = await self.invoke(raw.functions.help.GetConfig()) + for option in config.dc_options: + await self.storage.update_dc_address( + (option.id, option.address, option.port, option.is_ipv6, option.is_media, option.is_test) + ) + + async def check_dc_default( + self, + dc_id: int, + is_ipv6: bool, + test_mode: bool = False, + media: bool = False + ) -> bool: + default_dc = DataCenter( + dc_id, + test_mode=test_mode, + is_ipv6=is_ipv6, + alt_port=self.alt_port, + media=media + ) + current_dc = await self.storage.get_dc_address(dc_id, is_ipv6, test_mode, media) + if current_dc is not None and current_dc[0] == default_dc.address: + return True def is_excluded(self, exclude, module): for e in exclude: diff --git a/pyrogram/connection/connection.py b/pyrogram/connection/connection.py index 016907e4..4118f8d5 100644 --- a/pyrogram/connection/connection.py +++ b/pyrogram/connection/connection.py @@ -22,7 +22,6 @@ import logging from typing import Optional, Type from .transport import TCP, TCPAbridged -from ..session.internals import DataCenter log = logging.getLogger(__name__) @@ -34,6 +33,8 @@ class Connection: self, dc_id: int, test_mode: bool, + server_ip: str, + server_port: int, ipv6: bool, alt_port: bool, proxy: dict, @@ -43,13 +44,14 @@ class Connection: ) -> None: self.dc_id = dc_id self.test_mode = test_mode + self.server_ip = server_ip + self.server_port = 5222 if alt_port else server_port self.ipv6 = ipv6 - self.alt_port = alt_port self.proxy = proxy self.media = media self.protocol_factory = protocol_factory - self.address = DataCenter(dc_id, test_mode, ipv6, alt_port, media) + self.address = (server_ip, server_port) self.protocol: Optional[TCP] = None if isinstance(loop, asyncio.AbstractEventLoop): diff --git a/pyrogram/methods/auth/send_code.py b/pyrogram/methods/auth/send_code.py index ba4d3283..a0ba8114 100644 --- a/pyrogram/methods/auth/send_code.py +++ b/pyrogram/methods/auth/send_code.py @@ -61,6 +61,7 @@ class SendCode: ) ) except (PhoneMigrate, NetworkMigrate) as e: + await self.update_dc_option() # pylint: disable=access-member-before-definition await self.session.stop() diff --git a/pyrogram/methods/auth/sign_in_bot.py b/pyrogram/methods/auth/sign_in_bot.py index 231f6d9b..a7ea7eee 100644 --- a/pyrogram/methods/auth/sign_in_bot.py +++ b/pyrogram/methods/auth/sign_in_bot.py @@ -58,6 +58,7 @@ class SignInBot: ) ) except UserMigrate as e: + await self.update_dc_option() # pylint: disable=access-member-before-definition await self.session.stop() diff --git a/pyrogram/methods/auth/sign_in_qrcode.py b/pyrogram/methods/auth/sign_in_qrcode.py index 0b255788..48f8dcc6 100644 --- a/pyrogram/methods/auth/sign_in_qrcode.py +++ b/pyrogram/methods/auth/sign_in_qrcode.py @@ -80,6 +80,7 @@ class SignInQrcode: return types.User._parse(self, r.authorization.user) if isinstance(r, raw.types.auth.LoginTokenMigrateTo): + await self.update_dc_option() # pylint: disable=access-member-before-definition await self.session.stop() diff --git a/pyrogram/session/auth.py b/pyrogram/session/auth.py index 23b6e8d7..c376a3a1 100644 --- a/pyrogram/session/auth.py +++ b/pyrogram/session/auth.py @@ -50,6 +50,7 @@ class Auth: self.ipv6 = client.ipv6 self.alt_port = client.alt_port self.proxy = client.proxy + self.storage = client.storage self.connection_factory = client.connection_factory self.protocol_factory = client.protocol_factory self.loop = client.loop @@ -86,10 +87,13 @@ class Auth: # The server may close the connection at any time, causing the auth key creation to fail. # If that happens, just try again up to MAX_RETRIES times. + address, port = await self.storage.get_dc_address(self.dc_id, self.ipv6, self.test_mode) while True: self.connection = self.connection_factory( dc_id=self.dc_id, test_mode=self.test_mode, + server_ip=address, + server_port=port, ipv6=self.ipv6, alt_port=self.alt_port, proxy=self.proxy, diff --git a/pyrogram/session/session.py b/pyrogram/session/session.py index 098d0cd9..e533f897 100644 --- a/pyrogram/session/session.py +++ b/pyrogram/session/session.py @@ -107,10 +107,13 @@ class Session: self.last_reconnect_attempt = None async def start(self): + address, port = await self.client.storage.get_dc_address(self.dc_id, self.client.ipv6, self.test_mode, self.is_media) while True: self.connection = self.client.connection_factory( dc_id=self.dc_id, test_mode=self.test_mode, + server_ip=address, + server_port=port, ipv6=self.client.ipv6, alt_port=self.client.alt_port, proxy=self.client.proxy, diff --git a/pyrogram/storage/file_storage.py b/pyrogram/storage/file_storage.py index 031cb4ac..7afcf0a5 100644 --- a/pyrogram/storage/file_storage.py +++ b/pyrogram/storage/file_storage.py @@ -68,6 +68,12 @@ class FileStorage(SQLiteStorage): version += 1 + if version == 4: + with self.conn: + self.conn.executescript(self.UPDATE_DC_SCHEMA) + + version += 1 + self.version(version) async def open(self): diff --git a/pyrogram/storage/mongo_storage.py b/pyrogram/storage/mongo_storage.py index 8570d3d4..f040ca67 100644 --- a/pyrogram/storage/mongo_storage.py +++ b/pyrogram/storage/mongo_storage.py @@ -77,6 +77,7 @@ class MongoStorage(Storage): self._session = database['session'] self._usernames = database['usernames'] self._states = database['update_state'] + self._dc_options = database['dc_options'] self._remove_peers = remove_peers async def open(self): @@ -221,6 +222,55 @@ class MongoStorage(Storage): return get_input_peer(r['_id'], r['access_hash'], r['type']) + async def update_dc_address( + self, + value: Tuple[int, str, int, bool, bool] = object + ): + """ + Updates or inserts a data center address. + + Parameters: + value (Tuple[int, str, int, bool, bool]): A tuple containing: + - dc_id (int): Data center ID. + - address (str): Address of the data center. + - port (int): Port of the data center. + - is_ipv6 (bool): Whether the address is IPv6. + - is_test (bool): Whether it is a test data center. + - is_media (bool): Whether it is a media data center. + """ + if value == object: + return + + await self._dc_options.update_one( + {"$and": [ + {'dc_id': value[0]}, + {'is_ipv6': value[3]}, + {'is_test': value[4]}, + {'is_media': value[5]} + ]}, + {'$set': {'address': value[1], 'port': value[2]}}, + upsert=True + ) + + async def get_dc_address( + self, + dc_id: int, + is_ipv6: bool, + test_mode: bool = False, + media: bool = False + ) -> Tuple[str, int]: + if dc_id in [1,3,5] and media: + media = False + if dc_id in [4,5] and test_mode: + test_mode = False + r = await self._dc_options.find_one( + {'dc_id': dc_id, 'is_ipv6': is_ipv6, 'is_test': test_mode, 'is_media': media}, + {'address': 1, 'port': 1} + ) + if r is None: + return None + return r['address'], r['port'] + async def _get(self): attr = inspect.stack()[2].function d = await self._session.find_one({'_id': 0}, {attr: 1}) diff --git a/pyrogram/storage/sqlite_storage.py b/pyrogram/storage/sqlite_storage.py index a1542818..63a0a537 100644 --- a/pyrogram/storage/sqlite_storage.py +++ b/pyrogram/storage/sqlite_storage.py @@ -97,6 +97,21 @@ END; """ +UPDATE_DC_SCHEMA = """ +CREATE TABLE dc_options +( + id INTEGER PRIMARY KEY AUTOINCREMENT, + dc_id INTEGER, + address TEXT, + port INTEGER, + is_ipv6 BOOLEAN, + is_test BOOLEAN, + is_media BOOLEAN, + UNIQUE(dc_id, is_ipv6, is_test, is_media) +); +""" + + def get_input_peer(peer_id: int, access_hash: int, peer_type: str): if peer_type in ["user", "bot"]: return raw.types.InputPeerUser( @@ -121,6 +136,7 @@ def get_input_peer(peer_id: int, access_hash: int, peer_type: str): class SQLiteStorage(Storage): VERSION = 4 USERNAME_TTL = 8 * 60 * 60 + UPDATE_DC_SCHEMA = globals().get("UPDATE_DC_SCHEMA", "") def __init__(self, name: str): super().__init__(name) @@ -131,6 +147,7 @@ class SQLiteStorage(Storage): with self.conn: self.conn.executescript(SCHEMA) self.conn.executescript(UNAME_SCHEMA) + self.conn.executescript(self.UPDATE_DC_SCHEMA) self.conn.execute( "INSERT INTO version VALUES (?)", @@ -252,6 +269,65 @@ class SQLiteStorage(Storage): return get_input_peer(*r) + async def update_dc_address( + self, + value: Tuple[int, str, int, bool, bool, bool] = object + ): + """ + Updates or inserts a data center address. + + Parameters: + value (Tuple[int, str, int, bool, bool, bool]): A tuple containing: + - dc_id (int): Data center ID. + - address (str): Address of the data center. + - port (int): Port of the data center. + - is_ipv6 (bool): Whether the address is IPv6. + - is_test (bool): Whether it is a test data center. + - is_media (bool): Whether it is a media data center. + """ + if value == object: + return + with self.conn: + self.conn.execute( + """ + INSERT INTO dc_options (dc_id, address, port, is_ipv6, is_test, is_media) + VALUES (?, ?, ?, ?, ?, ?) + ON CONFLICT(dc_id, is_ipv6, is_test, is_media) + DO UPDATE SET address=excluded.address, port=excluded.port + """, + value + ) + + async def get_dc_address( + self, + dc_id: int, + is_ipv6: bool, + test_mode: bool = False, + media: bool = False + ) -> Tuple[str, int]: + """ + Retrieves the address of a data center. + + Parameters: + dc_id (int): Data center ID. + is_ipv6 (bool): Whether the address is IPv6. + test_mode (bool): Whether it is a test data center. + media (bool): Whether it is a media data center. + + Returns: + Tuple[str, int]: A tuple containing the address and port of the data center. + """ + if dc_id in [1,3,5] and media: + media = False + if dc_id in [4,5] and test_mode: + test_mode = False + r = self.conn.execute( + "SELECT address, port FROM dc_options WHERE dc_id = ? AND is_ipv6 = ? AND is_test = ? AND is_media = ?", + (dc_id, is_ipv6, test_mode, media) + ).fetchone() + + return r + def _get(self): attr = inspect.stack()[2].function diff --git a/pyrogram/storage/storage.py b/pyrogram/storage/storage.py index 7484076a..69a9e1c0 100644 --- a/pyrogram/storage/storage.py +++ b/pyrogram/storage/storage.py @@ -76,6 +76,21 @@ class Storage: async def get_peer_by_phone_number(self, phone_number: str): raise NotImplementedError + async def update_dc_address( + self, + value: Tuple[int, str, int, bool, bool] = object + ): + raise NotImplementedError + + async def get_dc_address( + self, + dc_id: int, + is_ipv6: bool, + test_mode: bool = False, + media: bool = False + ): + raise NotImplementedError + async def dc_id(self, value: int = object): raise NotImplementedError