From 5aae488747e4940e9f9eb81fd9d5cba4beeade96 Mon Sep 17 00:00:00 2001 From: wulan17 Date: Fri, 21 Jun 2024 00:00:50 +0700 Subject: [PATCH] pyrofork: Initial Dynamic datacenters ip address support Signed-off-by: wulan17 --- pyrogram/client.py | 58 +++++++++++++- pyrogram/connection/connection.py | 9 ++- pyrogram/methods/advanced/save_file.py | 4 +- pyrogram/methods/auth/connect.py | 4 +- pyrogram/methods/auth/send_code.py | 29 ++++++- pyrogram/methods/auth/sign_in_bot.py | 29 ++++++- pyrogram/methods/messages/inline_session.py | 4 +- pyrogram/session/auth.py | 20 +++-- pyrogram/session/internals/__init__.py | 1 - pyrogram/session/internals/data_center.py | 83 --------------------- pyrogram/session/session.py | 7 +- pyrogram/storage/file_storage.py | 16 ++++ pyrogram/storage/memory_storage.py | 17 ++++- pyrogram/storage/mongo_storage.py | 20 +++++ pyrogram/storage/sqlite_storage.py | 30 +++++++- pyrogram/storage/storage.py | 41 +++++++++- 16 files changed, 258 insertions(+), 114 deletions(-) delete mode 100644 pyrogram/session/internals/data_center.py diff --git a/pyrogram/client.py b/pyrogram/client.py index 761fd1b0..2342271d 100644 --- a/pyrogram/client.py +++ b/pyrogram/client.py @@ -716,10 +716,26 @@ class Client(Methods): await self.storage.date(0) await self.storage.test_mode(self.test_mode) + if self.test_mode: + await self.storage.server_address("149.154.167.40") + await self.storage.server_address_v6("2001:67c:4e8:f002::e") + await self.storage.server_port(80) + await self.storage.media_address("149.154.167.40") + await self.storage.media_address_v6("2001:67c:4e8:f002::e") + await self.storage.media_port(80) + else: + await self.storage.server_address("149.154.167.51") + await self.storage.server_address_v6("2001:67c:4e8:f002::a") + await self.storage.server_port(5222 if self.alt_port else 443) + await self.storage.media_address("149.154.167.151") + await self.storage.media_address_v6("2001:067c:04e8:f002:0000:0000:0000:000b") + await self.storage.media_port(5222 if self.alt_port else 443) await self.storage.auth_key( await Auth( self, await self.storage.dc_id(), - await self.storage.test_mode() + await self.storage.test_mode(), + await self.storage.server_address_v6() if self.ipv6 else await self.storage.server_address(), + await self.storage.server_port() ).create() ) await self.storage.user_id(None) @@ -745,6 +761,22 @@ class Client(Methods): break except Exception as e: print(e) + # Needed for migration from storage v4 to v5 + if not await self.storage.server_address(): + if self.test_mode: + await self.storage.server_address("149.154.167.40") + await self.storage.server_address_v6("2001:67c:4e8:f002::e") + await self.storage.server_port(80) + await self.storage.media_address("") + await self.storage.media_address_v6("") + await self.storage.media_port(80) + else: + await self.storage.server_address("149.154.167.51") + await self.storage.server_address_v6("2001:67c:4e8:f002::a") + await self.storage.server_port(5222 if self.alt_port else 443) + await self.storage.media_address("149.154.167.151") + await self.storage.media_address_v6("2001:067c:04e8:f002:0000:0000:0000:000b") + await self.storage.media_port(5222 if self.alt_port else 443) def load_plugins(self): if self.plugins: @@ -950,10 +982,18 @@ class Client(Methods): session = Session( self, dc_id, - await Auth(self, dc_id, await self.storage.test_mode()).create() + await Auth( + self, + dc_id, + await self.storage.test_mode(), + await self.storage.server_address_v6() if self.ipv6 else await self.storage.server_address(), + await self.storage.server_port() + ).create() if dc_id != await self.storage.dc_id() else await self.storage.auth_key(), await self.storage.test_mode(), + await self.storage.media_address_v6() if self.ipv6 else await self.storage.media_address(), + await self.storage.media_port(), is_media=True ) @@ -1021,8 +1061,18 @@ class Client(Methods): elif isinstance(r, raw.types.upload.FileCdnRedirect): cdn_session = Session( - self, r.dc_id, await Auth(self, r.dc_id, await self.storage.test_mode()).create(), - await self.storage.test_mode(), is_media=True, is_cdn=True + self, r.dc_id, await Auth( + self, + r.dc_id, + await self.storage.test_mode(), + await self.storage.server_address_v6() if self.ipv6 else await self.storage.server_address(), + await self.storage.server_port() + ).create(), + await self.storage.test_mode(), + await self.storage.media_address_v6() if self.ipv6 else await self.storage.media_address(), + await self.storage.media_port(), + is_media=True, + is_cdn=True ) try: diff --git a/pyrogram/connection/connection.py b/pyrogram/connection/connection.py index 53ac12bd..0620bb3a 100644 --- a/pyrogram/connection/connection.py +++ b/pyrogram/connection/connection.py @@ -22,7 +22,6 @@ import logging from typing import Optional, Type from .transport import TCP, TCPAbridged -from ..session.internals import DataCenter log = logging.getLogger(__name__) @@ -34,21 +33,23 @@ class Connection: self, dc_id: int, test_mode: bool, + server_ip: str, + server_port: int, ipv6: bool, - alt_port: bool, proxy: dict, media: bool = False, protocol_factory: Type[TCP] = TCPAbridged ) -> None: self.dc_id = dc_id self.test_mode = test_mode + self.server_ip = server_ip + self.server_port = server_port self.ipv6 = ipv6 - self.alt_port = alt_port self.proxy = proxy self.media = media self.protocol_factory = protocol_factory - self.address = DataCenter(dc_id, test_mode, ipv6, alt_port, media) + self.address = (server_ip, server_port) self.protocol: Optional[TCP] = None async def connect(self) -> None: diff --git a/pyrogram/methods/advanced/save_file.py b/pyrogram/methods/advanced/save_file.py index 7eae0469..2e65f82d 100644 --- a/pyrogram/methods/advanced/save_file.py +++ b/pyrogram/methods/advanced/save_file.py @@ -142,7 +142,9 @@ class SaveFile: md5_sum = md5() if not is_big and not is_missing_part else None session = Session( self, await self.storage.dc_id(), await self.storage.auth_key(), - await self.storage.test_mode(), is_media=True + await self.storage.test_mode(), + await self.storage.media_address_v6() if self.ipv6 else await self.storage.media_address(), + await self.storage.media_port(), is_media=True ) workers = [self.loop.create_task(worker(session)) for _ in range(workers_count)] queue = asyncio.Queue(1) diff --git a/pyrogram/methods/auth/connect.py b/pyrogram/methods/auth/connect.py index 0858a832..5bcd04c4 100644 --- a/pyrogram/methods/auth/connect.py +++ b/pyrogram/methods/auth/connect.py @@ -42,7 +42,9 @@ class Connect: self.session = Session( self, await self.storage.dc_id(), - await self.storage.auth_key(), await self.storage.test_mode() + await self.storage.auth_key(), await self.storage.test_mode(), + await self.storage.server_address_v6() if self.ipv6 else await self.storage.server_address(), + await self.storage.server_port() ) await self.session.start() diff --git a/pyrogram/methods/auth/send_code.py b/pyrogram/methods/auth/send_code.py index 9d198fbf..0af0ee70 100644 --- a/pyrogram/methods/auth/send_code.py +++ b/pyrogram/methods/auth/send_code.py @@ -61,18 +61,43 @@ class SendCode: ) ) except (PhoneMigrate, NetworkMigrate) as e: + config = await self.invoke(raw.functions.help.GetConfig()) + for option in config.dc_options: + if (option.id == e.value): + if option.media_only: + if option.ipv6: + await self.storage.media_address_v6(option.ip_address) + else: + await self.storage.media_address(option.ip_address) + if option.this_port_only: + await self.storage.media_port(option.port) + else: + if option.ipv6: + await self.storage.server_address_v6(option.ip_address) + else: + await self.storage.server_address(option.ip_address) + if option.this_port_only: + await self.storage.port(option.port) + if e not in [2,4] or self.storage.test_mode(): + await self.storage.media_address(await self.storage.server_address()) + await self.storage.media_address_v6(await self.storage.server_address_v6()) + await self.storage.media_port(await self.storage.server_port()) await self.session.stop() await self.storage.dc_id(e.value) await self.storage.auth_key( await Auth( self, await self.storage.dc_id(), - await self.storage.test_mode() + await self.storage.test_mode(), + await self.storage.server_address_v6() if self.ipv6 else await self.storage.server_address(), + await self.storage.server_port() ).create() ) self.session = Session( self, await self.storage.dc_id(), - await self.storage.auth_key(), await self.storage.test_mode() + await self.storage.auth_key(), await self.storage.test_mode(), + await self.storage.server_address_v6() if self.ipv6 else await self.storage.server_address(), + await self.storage.server_port() ) await self.session.start() diff --git a/pyrogram/methods/auth/sign_in_bot.py b/pyrogram/methods/auth/sign_in_bot.py index 59147daa..886c2d39 100644 --- a/pyrogram/methods/auth/sign_in_bot.py +++ b/pyrogram/methods/auth/sign_in_bot.py @@ -58,18 +58,43 @@ class SignInBot: ) ) except UserMigrate as e: + config = await self.invoke(raw.functions.help.GetConfig()) + for option in config.dc_options: + if (option.id == e.value): + if option.media_only: + if option.ipv6: + await self.storage.media_address_v6(option.ip_address) + else: + await self.storage.media_address(option.ip_address) + if option.this_port_only: + await self.storage.media_port(option.port) + else: + if option.ipv6: + await self.storage.server_address_v6(option.ip_address) + else: + await self.storage.server_address(option.ip_address) + if option.this_port_only: + await self.storage.port(option.port) + if e not in [2,4] or self.storage.test_mode(): + await self.storage.media_address(await self.storage.server_address()) + await self.storage.media_address_v6(await self.storage.server_address_v6()) + await self.storage.media_port(await self.storage.server_port()) await self.session.stop() await self.storage.dc_id(e.value) await self.storage.auth_key( await Auth( self, await self.storage.dc_id(), - await self.storage.test_mode() + await self.storage.test_mode(), + await self.storage.server_address_v6() if self.ipv6 else await self.storage.server_address(), + await self.storage.server_port() ).create() ) self.session = Session( self, await self.storage.dc_id(), - await self.storage.auth_key(), await self.storage.test_mode() + await self.storage.auth_key(), await self.storage.test_mode(), + await self.storage.server_address_v6() if self.ipv6 else await self.storage.server_address(), + await self.storage.server_port() ) await self.session.start() diff --git a/pyrogram/methods/messages/inline_session.py b/pyrogram/methods/messages/inline_session.py index ed6e9980..02315cb1 100644 --- a/pyrogram/methods/messages/inline_session.py +++ b/pyrogram/methods/messages/inline_session.py @@ -35,7 +35,9 @@ async def get_session(client: "pyrogram.Client", dc_id: int): session = client.media_sessions[dc_id] = Session( client, dc_id, await Auth(client, dc_id, await client.storage.test_mode()).create(), - await client.storage.test_mode(), is_media=True + await client.storage.test_mode(), + await client.storage.media_address_v6() if client.ipv6 else await client.storage.media_address(), + await client.storage.media_port(), is_media=True ) await session.start() diff --git a/pyrogram/session/auth.py b/pyrogram/session/auth.py index 346b0f92..dc4e2757 100644 --- a/pyrogram/session/auth.py +++ b/pyrogram/session/auth.py @@ -43,12 +43,15 @@ class Auth: self, client: "pyrogram.Client", dc_id: int, - test_mode: bool + test_mode: bool, + server_ip: str, + server_port: int ): self.dc_id = dc_id self.test_mode = test_mode + self.server_ip = server_ip + self.server_port = server_port self.ipv6 = client.ipv6 - self.alt_port = client.alt_port self.proxy = client.proxy self.connection_factory = client.connection_factory self.protocol_factory = client.protocol_factory @@ -86,14 +89,15 @@ class Auth: # The server may close the connection at any time, causing the auth key creation to fail. # If that happens, just try again up to MAX_RETRIES times. while True: - self.connection = self.connection_factory( + self.connection = self.client.connection_factory( dc_id=self.dc_id, test_mode=self.test_mode, - ipv6=self.ipv6, - alt_port=self.alt_port, - proxy=self.proxy, - media=False, - protocol_factory=self.protocol_factory + server_ip=self.server_ip, + server_port=self.server_port, + ipv6=self.client.ipv6, + proxy=self.client.proxy, + media=self.is_media, + protocol_factory=self.client.protocol_factory ) try: diff --git a/pyrogram/session/internals/__init__.py b/pyrogram/session/internals/__init__.py index 1754586e..af51a3e9 100644 --- a/pyrogram/session/internals/__init__.py +++ b/pyrogram/session/internals/__init__.py @@ -17,6 +17,5 @@ # You should have received a copy of the GNU Lesser General Public License # along with Pyrofork. If not, see . -from .data_center import DataCenter from .msg_factory import MsgFactory from .msg_id import MsgId diff --git a/pyrogram/session/internals/data_center.py b/pyrogram/session/internals/data_center.py deleted file mode 100644 index 43db631c..00000000 --- a/pyrogram/session/internals/data_center.py +++ /dev/null @@ -1,83 +0,0 @@ -# Pyrofork - Telegram MTProto API Client Library for Python -# Copyright (C) 2017-present Dan -# Copyright (C) 2022-present Mayuri-Chan -# -# This file is part of Pyrofork. -# -# Pyrofork is free software: you can redistribute it and/or modify -# it under the terms of the GNU Lesser General Public License as published -# by the Free Software Foundation, either version 3 of the License, or -# (at your option) any later version. -# -# Pyrofork is distributed in the hope that it will be useful, -# but WITHOUT ANY WARRANTY; without even the implied warranty of -# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the -# GNU Lesser General Public License for more details. -# -# You should have received a copy of the GNU Lesser General Public License -# along with Pyrofork. If not, see . - -from typing import Tuple - - -class DataCenter: - TEST = { - 1: "149.154.175.10", - 2: "149.154.167.40", - 3: "149.154.175.117", - } - - PROD = { - 1: "149.154.175.53", - 2: "149.154.167.51", - 3: "149.154.175.100", - 4: "149.154.167.91", - 5: "91.108.56.130", - 203: "91.105.192.100" - } - - PROD_MEDIA = { - 2: "149.154.167.151", - 4: "149.154.164.250" - } - - TEST_IPV6 = { - 1: "2001:b28:f23d:f001::e", - 2: "2001:67c:4e8:f002::e", - 3: "2001:b28:f23d:f003::e", - } - - PROD_IPV6 = { - 1: "2001:b28:f23d:f001::a", - 2: "2001:67c:4e8:f002::a", - 3: "2001:b28:f23d:f003::a", - 4: "2001:67c:4e8:f004::a", - 5: "2001:b28:f23f:f005::a", - 203: "2a0a:f280:0203:000a:5000:0000:0000:0100" - } - - PROD_IPV6_MEDIA = { - 2: "2001:067c:04e8:f002:0000:0000:0000:000b", - 4: "2001:067c:04e8:f004:0000:0000:0000:000b" - } - - def __new__(cls, dc_id: int, test_mode: bool, ipv6: bool, alt_port: bool, media: bool) -> Tuple[str, int]: - if test_mode: - if ipv6: - ip = cls.TEST_IPV6[dc_id] - else: - ip = cls.TEST[dc_id] - - return ip, 80 - else: - if ipv6: - if media: - ip = cls.PROD_IPV6_MEDIA.get(dc_id, cls.PROD_IPV6[dc_id]) - else: - ip = cls.PROD_IPV6[dc_id] - else: - if media: - ip = cls.PROD_MEDIA.get(dc_id, cls.PROD[dc_id]) - else: - ip = cls.PROD[dc_id] - return ip, 5222 if alt_port else 443 diff --git a/pyrogram/session/session.py b/pyrogram/session/session.py index 40625cfd..10051050 100644 --- a/pyrogram/session/session.py +++ b/pyrogram/session/session.py @@ -69,6 +69,8 @@ class Session: dc_id: int, auth_key: bytes, test_mode: bool, + server_ip: str, + server_port: int, is_media: bool = False, is_cdn: bool = False ): @@ -78,6 +80,8 @@ class Session: self.test_mode = test_mode self.is_media = is_media self.is_cdn = is_cdn + self.server_ip = server_ip + self.server_port = server_port self.connection: Optional[Connection] = None @@ -108,8 +112,9 @@ class Session: self.connection = self.client.connection_factory( dc_id=self.dc_id, test_mode=self.test_mode, + server_ip=self.server_ip, + server_port=self.server_port, ipv6=self.client.ipv6, - alt_port=self.client.alt_port, proxy=self.client.proxy, media=self.is_media, protocol_factory=self.client.protocol_factory diff --git a/pyrogram/storage/file_storage.py b/pyrogram/storage/file_storage.py index 031cb4ac..ad3f109d 100644 --- a/pyrogram/storage/file_storage.py +++ b/pyrogram/storage/file_storage.py @@ -38,6 +38,16 @@ CREATE TABLE update_state ); """ +UPDATE_SESSION_SCHEMA = """ +ALTER TABLE sessions +ADD server_address TEXT, +server_address_v6 TEXT, +server_port INTEGER, +media_address TEXT, +media_address_v6 TEXT, +media_port INTEGER; +""" + class FileStorage(SQLiteStorage): FILE_EXTENSION = ".session" @@ -68,6 +78,12 @@ class FileStorage(SQLiteStorage): version += 1 + if version == 4: + with self.conn: + self.conn.execute(UPDATE_SESSION_SCHEMA) + + version += 1 + self.version(version) async def open(self): diff --git a/pyrogram/storage/memory_storage.py b/pyrogram/storage/memory_storage.py index 2a3f1f20..1d965437 100644 --- a/pyrogram/storage/memory_storage.py +++ b/pyrogram/storage/memory_storage.py @@ -19,6 +19,7 @@ import base64 import logging +import re import sqlite3 import struct @@ -57,11 +58,19 @@ class MemoryStorage(SQLiteStorage): log.warning("You are using an old session string format. Use export_session_string to update") return - dc_id, api_id, test_mode, auth_key, user_id, is_bot = struct.unpack( + dc_id, api_id, test_mode, auth_key, user_id, is_bot, server_address, server_address_v6, server_port, media_address, media_address_v6, media_port = struct.unpack( self.SESSION_STRING_FORMAT, base64.urlsafe_b64decode(self.session_string + "=" * (-len(self.session_string) % 4)) ) + # Remove leading zeros + server_address = re.sub(r'^[0]*', '', re.sub(r'\.[0]*', '.', server_address.decode())) + server_address_v6 = re.sub(r'^[0]*', '', re.sub(r'\:[0]*', ':', server_address_v6.decode())) + server_port = int(server_port.decode()) + media_address = re.sub(r'^[0]*', '', re.sub(r'\.[0]*', '.', media_address.decode())) + media_address_v6 = re.sub(r'^[0]*', '', re.sub(r'\:[0]*', ':', media_address_v6.decode())) + media_port = int(media_port.decode()) + await self.dc_id(dc_id) await self.api_id(api_id) await self.test_mode(test_mode) @@ -69,6 +78,12 @@ class MemoryStorage(SQLiteStorage): await self.user_id(user_id) await self.is_bot(is_bot) await self.date(0) + await self.server_address(server_address) + await self.server_address_v6(server_address_v6) + await self.server_port(server_port) + await self.media_address(media_address) + await self.media_address_v6(media_address_v6) + await self.media_port(media_port) async def delete(self): pass diff --git a/pyrogram/storage/mongo_storage.py b/pyrogram/storage/mongo_storage.py index 60e45f63..e196286f 100644 --- a/pyrogram/storage/mongo_storage.py +++ b/pyrogram/storage/mongo_storage.py @@ -223,6 +223,8 @@ class MongoStorage(Storage): d = await self._session.find_one({'_id': 0}, {attr: 1}) if not d: return + if f"{attr}" not in d: + return return d[attr] async def _set(self, value: Any): @@ -252,3 +254,21 @@ class MongoStorage(Storage): async def is_bot(self, value: bool = object): return await self._accessor(value) + + async def server_address(self, value: str = object): + return await self._accessor(value) + + async def server_address_v6(self, value: str = object): + return await self._accessor(value) + + async def server_port(self, value: int = object): + return await self._accessor(value) + + async def media_address(self, value: str = object): + return await self._accessor(value) + + async def media_address_v6(self, value: str = object): + return await self._accessor(value) + + async def media_port(self, value: int = object): + return await self._accessor(value) diff --git a/pyrogram/storage/sqlite_storage.py b/pyrogram/storage/sqlite_storage.py index 162ef0fd..a7716b6d 100644 --- a/pyrogram/storage/sqlite_storage.py +++ b/pyrogram/storage/sqlite_storage.py @@ -36,7 +36,13 @@ CREATE TABLE sessions auth_key BLOB, date INTEGER NOT NULL, user_id INTEGER, - is_bot INTEGER + is_bot INTEGER, + server_address TEXT, + server_address_v6 TEXT, + server_port INTEGER, + media_address TEXT, + media_address_v6 TEXT, + media_port INTEGER ); CREATE TABLE peers @@ -138,8 +144,8 @@ class SQLiteStorage(Storage): ) self.conn.execute( - "INSERT INTO sessions VALUES (?, ?, ?, ?, ?, ?, ?)", - (2, None, None, None, 0, None, None) + "INSERT INTO sessions VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)", + (2, None, None, None, 0, None, None, None, None, 0, None, None, 0) ) async def open(self): @@ -286,6 +292,24 @@ class SQLiteStorage(Storage): async def is_bot(self, value: bool = object): return self._accessor(value) + async def server_address(self, value: str = object): + return self._accessor(value) + + async def server_address_v6(self, value: str = object): + return self._accessor(value) + + async def server_port(self, value: int = object): + return self._accessor(value) + + async def media_address(self, value: str = object): + return self._accessor(value) + + async def media_address_v6(self, value: str = object): + return self._accessor(value) + + async def media_port(self, value: int = object): + return self._accessor(value) + def version(self, value: int = object): if value == object: return self.conn.execute( diff --git a/pyrogram/storage/storage.py b/pyrogram/storage/storage.py index 7484076a..6cad7605 100644 --- a/pyrogram/storage/storage.py +++ b/pyrogram/storage/storage.py @@ -29,7 +29,7 @@ class Storage: SESSION_STRING_SIZE = 351 SESSION_STRING_SIZE_64 = 356 - SESSION_STRING_FORMAT = ">BI?256sQ?" + SESSION_STRING_FORMAT = ">BI?256sQ?15s39sB15s39sB" def __init__(self, name: str): self.name = name @@ -97,7 +97,38 @@ class Storage: async def is_bot(self, value: bool = object): raise NotImplementedError + async def server_address(self, value: str = object): + raise NotImplementedError + + async def server_address_v6(self, value: str = object): + raise NotImplementedError + + async def server_port(self, value: int = object): + raise NotImplementedError + + async def media_address(self, value: str = object): + raise NotImplementedError + + async def media_address_v6(self, value: str = object): + raise NotImplementedError + + async def media_port(self, value: int = object): + raise NotImplementedError + async def export_session_string(self): + server_ip = await self.server_address() + server_ip_v6 = await self.server_address_v6() + port_server = await self.server_port() + media_ip = await self.media_address() + media_ip_v6 = await self.media_address_v6() + port_media = await self.media_port() + # Add leading zero to make fixed size + server_address = '.'.join(i.zfill(3) for i in server_ip.split('.')) + server_address_v6 = ':'.join(i.zfill(4) for i in server_ip_v6.split(':')) + server_port = f"{port_server}".zfill(5) + media_address = '.'.join(i.zfill(3) for i in media_ip.split('.')) + media_address_v6 = ':'.join(i.zfill(4) for i in media_ip_v6.split(':')) + media_port = f"{port_media}".zfill(5) packed = struct.pack( self.SESSION_STRING_FORMAT, await self.dc_id(), @@ -105,7 +136,13 @@ class Storage: await self.test_mode(), await self.auth_key(), await self.user_id(), - await self.is_bot() + await self.is_bot(), + bytes(server_address.encode()), + bytes(server_address_v6.encode()), + bytes(server_port.encode()), + bytes(media_address.encode()), + bytes(media_address_v6.encode()), + bytes(media_port.encode()) ) return base64.urlsafe_b64encode(packed).decode().rstrip("=")