mirror of
https://github.com/Mayuri-Chan/pyrofork.git
synced 2025-12-29 12:04:51 +00:00
pyrofork: Retrive dc address and port from GetConfig
Signed-off-by: wulan17 <wulan17@komodos.id>
This commit is contained in:
parent
070afc0246
commit
2212b98b81
11 changed files with 231 additions and 3 deletions
|
|
@ -61,6 +61,7 @@ from .file_id import FileId, FileType, ThumbnailSource
|
||||||
from .mime_types import mime_types
|
from .mime_types import mime_types
|
||||||
from .parser import Parser
|
from .parser import Parser
|
||||||
from .session.internals import MsgId
|
from .session.internals import MsgId
|
||||||
|
from .session.internals.data_center import DataCenter
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
MONGO_AVAIL = False
|
MONGO_AVAIL = False
|
||||||
|
|
@ -522,6 +523,15 @@ class Client(Methods):
|
||||||
break
|
break
|
||||||
|
|
||||||
if isinstance(signed_in, User):
|
if isinstance(signed_in, User):
|
||||||
|
is_dc_default = await self.check_dc_default()
|
||||||
|
if is_dc_default:
|
||||||
|
log.info("Your session is using the default data center.")
|
||||||
|
log.info("Updating the data center options from GetConfig...")
|
||||||
|
await self.update_dc_option()
|
||||||
|
log.info("Data center updated successfully.")
|
||||||
|
log.info("Restarting the session to apply the changes...")
|
||||||
|
await self.session.stop()
|
||||||
|
await self.session.start()
|
||||||
return signed_in
|
return signed_in
|
||||||
|
|
||||||
def set_parse_mode(self, parse_mode: Optional["enums.ParseMode"]):
|
def set_parse_mode(self, parse_mode: Optional["enums.ParseMode"]):
|
||||||
|
|
@ -856,6 +866,62 @@ class Client(Methods):
|
||||||
break
|
break
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(e)
|
print(e)
|
||||||
|
# Needed for migration from storage v3 to v4
|
||||||
|
if self.in_memory or self.session_string:
|
||||||
|
await self.insert_default_dc_options()
|
||||||
|
if not await self.storage.get_dc_address(await self.storage.dc_id(), self.ipv6):
|
||||||
|
log.info("No DC address found, inserting default DC options...")
|
||||||
|
await self.insert_default_dc_options()
|
||||||
|
|
||||||
|
async def insert_default_dc_options(self):
|
||||||
|
for dc_id in range(1, 6):
|
||||||
|
for is_ipv6 in (False, True):
|
||||||
|
if dc_id == 2:
|
||||||
|
for media in (False, True):
|
||||||
|
address, port = DataCenter(dc_id, False, is_ipv6, self.alt_port, media)
|
||||||
|
await self.storage.update_dc_address(
|
||||||
|
(dc_id, address, port, is_ipv6, False, media, True)
|
||||||
|
)
|
||||||
|
for test_mode in (False, True):
|
||||||
|
address, port = DataCenter(dc_id, test_mode, is_ipv6, self.alt_port, False)
|
||||||
|
await self.storage.update_dc_address(
|
||||||
|
(dc_id, address, port, is_ipv6, test_mode, False, True)
|
||||||
|
)
|
||||||
|
elif dc_id == 4:
|
||||||
|
for media in (False, True):
|
||||||
|
address, port = DataCenter(dc_id, False, is_ipv6, False, media)
|
||||||
|
await self.storage.update_dc_address(
|
||||||
|
(dc_id, address, port, is_ipv6, False, media, True)
|
||||||
|
)
|
||||||
|
elif dc_id in [1,3]:
|
||||||
|
for test_mode in (False, True):
|
||||||
|
address, port = DataCenter(dc_id, test_mode, is_ipv6, False, False)
|
||||||
|
await self.storage.update_dc_address(
|
||||||
|
(dc_id, address, port, is_ipv6, test_mode, False, True)
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
address, port = DataCenter(dc_id, False, is_ipv6, False, False)
|
||||||
|
await self.storage.update_dc_address(
|
||||||
|
(dc_id, address, port, is_ipv6, False, False, True)
|
||||||
|
)
|
||||||
|
|
||||||
|
async def update_dc_option(self):
|
||||||
|
config = await self.invoke(raw.functions.help.GetConfig())
|
||||||
|
for option in config.dc_options:
|
||||||
|
await self.storage.update_dc_address(
|
||||||
|
(option.id, option.address, option.port, option.is_ipv6, option.is_media, option.is_test, False)
|
||||||
|
)
|
||||||
|
|
||||||
|
async def check_dc_default(
|
||||||
|
self,
|
||||||
|
dc_id: int,
|
||||||
|
is_ipv6: bool,
|
||||||
|
test_mode: bool = False,
|
||||||
|
media: bool = False
|
||||||
|
) -> bool:
|
||||||
|
current_dc = await self.storage.get_dc_address(dc_id, is_ipv6, test_mode, media)
|
||||||
|
if current_dc is not None and current_dc[2]:
|
||||||
|
return True
|
||||||
|
|
||||||
def is_excluded(self, exclude, module):
|
def is_excluded(self, exclude, module):
|
||||||
for e in exclude:
|
for e in exclude:
|
||||||
|
|
|
||||||
|
|
@ -22,7 +22,6 @@ import logging
|
||||||
from typing import Optional, Type
|
from typing import Optional, Type
|
||||||
|
|
||||||
from .transport import TCP, TCPAbridged
|
from .transport import TCP, TCPAbridged
|
||||||
from ..session.internals import DataCenter
|
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
@ -34,6 +33,8 @@ class Connection:
|
||||||
self,
|
self,
|
||||||
dc_id: int,
|
dc_id: int,
|
||||||
test_mode: bool,
|
test_mode: bool,
|
||||||
|
server_ip: str,
|
||||||
|
server_port: int,
|
||||||
ipv6: bool,
|
ipv6: bool,
|
||||||
alt_port: bool,
|
alt_port: bool,
|
||||||
proxy: dict,
|
proxy: dict,
|
||||||
|
|
@ -42,13 +43,14 @@ class Connection:
|
||||||
) -> None:
|
) -> None:
|
||||||
self.dc_id = dc_id
|
self.dc_id = dc_id
|
||||||
self.test_mode = test_mode
|
self.test_mode = test_mode
|
||||||
|
self.server_ip = server_ip
|
||||||
|
self.server_port = 5222 if alt_port else server_port
|
||||||
self.ipv6 = ipv6
|
self.ipv6 = ipv6
|
||||||
self.alt_port = alt_port
|
|
||||||
self.proxy = proxy
|
self.proxy = proxy
|
||||||
self.media = media
|
self.media = media
|
||||||
self.protocol_factory = protocol_factory
|
self.protocol_factory = protocol_factory
|
||||||
|
|
||||||
self.address = DataCenter(dc_id, test_mode, ipv6, alt_port, media)
|
self.address = (server_ip, server_port)
|
||||||
self.protocol: Optional[TCP] = None
|
self.protocol: Optional[TCP] = None
|
||||||
|
|
||||||
async def connect(self) -> None:
|
async def connect(self) -> None:
|
||||||
|
|
|
||||||
|
|
@ -61,6 +61,7 @@ class SendCode:
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
except (PhoneMigrate, NetworkMigrate) as e:
|
except (PhoneMigrate, NetworkMigrate) as e:
|
||||||
|
await self.update_dc_option()
|
||||||
# pylint: disable=access-member-before-definition
|
# pylint: disable=access-member-before-definition
|
||||||
await self.session.stop()
|
await self.session.stop()
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -58,6 +58,7 @@ class SignInBot:
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
except UserMigrate as e:
|
except UserMigrate as e:
|
||||||
|
await self.update_dc_option()
|
||||||
# pylint: disable=access-member-before-definition
|
# pylint: disable=access-member-before-definition
|
||||||
await self.session.stop()
|
await self.session.stop()
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -80,6 +80,7 @@ class SignInQrcode:
|
||||||
|
|
||||||
return types.User._parse(self, r.authorization.user)
|
return types.User._parse(self, r.authorization.user)
|
||||||
if isinstance(r, raw.types.auth.LoginTokenMigrateTo):
|
if isinstance(r, raw.types.auth.LoginTokenMigrateTo):
|
||||||
|
await self.update_dc_option()
|
||||||
# pylint: disable=access-member-before-definition
|
# pylint: disable=access-member-before-definition
|
||||||
await self.session.stop()
|
await self.session.stop()
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -50,6 +50,7 @@ class Auth:
|
||||||
self.ipv6 = client.ipv6
|
self.ipv6 = client.ipv6
|
||||||
self.alt_port = client.alt_port
|
self.alt_port = client.alt_port
|
||||||
self.proxy = client.proxy
|
self.proxy = client.proxy
|
||||||
|
self.storage = client.storage
|
||||||
self.connection_factory = client.connection_factory
|
self.connection_factory = client.connection_factory
|
||||||
self.protocol_factory = client.protocol_factory
|
self.protocol_factory = client.protocol_factory
|
||||||
|
|
||||||
|
|
@ -85,10 +86,13 @@ class Auth:
|
||||||
|
|
||||||
# The server may close the connection at any time, causing the auth key creation to fail.
|
# The server may close the connection at any time, causing the auth key creation to fail.
|
||||||
# If that happens, just try again up to MAX_RETRIES times.
|
# If that happens, just try again up to MAX_RETRIES times.
|
||||||
|
address, port, _ = await self.storage.get_dc_address(self.dc_id, self.ipv6, self.test_mode)
|
||||||
while True:
|
while True:
|
||||||
self.connection = self.connection_factory(
|
self.connection = self.connection_factory(
|
||||||
dc_id=self.dc_id,
|
dc_id=self.dc_id,
|
||||||
test_mode=self.test_mode,
|
test_mode=self.test_mode,
|
||||||
|
server_ip=address,
|
||||||
|
server_port=port,
|
||||||
ipv6=self.ipv6,
|
ipv6=self.ipv6,
|
||||||
alt_port=self.alt_port,
|
alt_port=self.alt_port,
|
||||||
proxy=self.proxy,
|
proxy=self.proxy,
|
||||||
|
|
|
||||||
|
|
@ -104,10 +104,13 @@ class Session:
|
||||||
self.loop = asyncio.get_event_loop()
|
self.loop = asyncio.get_event_loop()
|
||||||
|
|
||||||
async def start(self):
|
async def start(self):
|
||||||
|
address, port, _ = await self.client.storage.get_dc_address(self.dc_id, self.client.ipv6, self.test_mode, self.is_media)
|
||||||
while True:
|
while True:
|
||||||
self.connection = self.client.connection_factory(
|
self.connection = self.client.connection_factory(
|
||||||
dc_id=self.dc_id,
|
dc_id=self.dc_id,
|
||||||
test_mode=self.test_mode,
|
test_mode=self.test_mode,
|
||||||
|
server_ip=address,
|
||||||
|
server_port=port,
|
||||||
ipv6=self.client.ipv6,
|
ipv6=self.client.ipv6,
|
||||||
alt_port=self.client.alt_port,
|
alt_port=self.client.alt_port,
|
||||||
proxy=self.client.proxy,
|
proxy=self.client.proxy,
|
||||||
|
|
|
||||||
|
|
@ -68,6 +68,12 @@ class FileStorage(SQLiteStorage):
|
||||||
|
|
||||||
version += 1
|
version += 1
|
||||||
|
|
||||||
|
if version == 4:
|
||||||
|
with self.conn:
|
||||||
|
self.conn.executescript(self.UPDATE_DC_SCHEMA)
|
||||||
|
|
||||||
|
version += 1
|
||||||
|
|
||||||
self.version(version)
|
self.version(version)
|
||||||
|
|
||||||
async def open(self):
|
async def open(self):
|
||||||
|
|
|
||||||
|
|
@ -77,6 +77,7 @@ class MongoStorage(Storage):
|
||||||
self._session = database['session']
|
self._session = database['session']
|
||||||
self._usernames = database['usernames']
|
self._usernames = database['usernames']
|
||||||
self._states = database['update_state']
|
self._states = database['update_state']
|
||||||
|
self._dc_options = database['dc_options']
|
||||||
self._remove_peers = remove_peers
|
self._remove_peers = remove_peers
|
||||||
|
|
||||||
async def open(self):
|
async def open(self):
|
||||||
|
|
@ -221,6 +222,56 @@ class MongoStorage(Storage):
|
||||||
|
|
||||||
return get_input_peer(r['_id'], r['access_hash'], r['type'])
|
return get_input_peer(r['_id'], r['access_hash'], r['type'])
|
||||||
|
|
||||||
|
async def update_dc_address(
|
||||||
|
self,
|
||||||
|
value: Tuple[int, str, int, bool, bool, bool] = object
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Updates or inserts a data center address.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
value (Tuple[int, str, int, bool, bool]): A tuple containing:
|
||||||
|
- dc_id (int): Data center ID.
|
||||||
|
- address (str): Address of the data center.
|
||||||
|
- port (int): Port of the data center.
|
||||||
|
- is_ipv6 (bool): Whether the address is IPv6.
|
||||||
|
- is_test (bool): Whether it is a test data center.
|
||||||
|
- is_media (bool): Whether it is a media data center.
|
||||||
|
- is_default_ip (bool): Whether it is the dc IP address provided by library.
|
||||||
|
"""
|
||||||
|
if value == object:
|
||||||
|
return
|
||||||
|
|
||||||
|
await self._dc_options.update_one(
|
||||||
|
{"$and": [
|
||||||
|
{'dc_id': value[0]},
|
||||||
|
{'is_ipv6': value[3]},
|
||||||
|
{'is_test': value[4]},
|
||||||
|
{'is_media': value[5]}
|
||||||
|
]},
|
||||||
|
{'$set': {'address': value[1], 'port': value[2], 'is_default_ip': value[6]}},
|
||||||
|
upsert=True
|
||||||
|
)
|
||||||
|
|
||||||
|
async def get_dc_address(
|
||||||
|
self,
|
||||||
|
dc_id: int,
|
||||||
|
is_ipv6: bool,
|
||||||
|
test_mode: bool = False,
|
||||||
|
media: bool = False
|
||||||
|
) -> Tuple[str, int]:
|
||||||
|
if dc_id in [1,3,5] and media:
|
||||||
|
media = False
|
||||||
|
if dc_id in [4,5] and test_mode:
|
||||||
|
test_mode = False
|
||||||
|
r = await self._dc_options.find_one(
|
||||||
|
{'dc_id': dc_id, 'is_ipv6': is_ipv6, 'is_test': test_mode, 'is_media': media},
|
||||||
|
{'address': 1, 'port': 1, 'is_default_ip': 1}
|
||||||
|
)
|
||||||
|
if r is None:
|
||||||
|
return None
|
||||||
|
return r['address'], r['port'], r.get['is_default_ip']
|
||||||
|
|
||||||
async def _get(self):
|
async def _get(self):
|
||||||
attr = inspect.stack()[2].function
|
attr = inspect.stack()[2].function
|
||||||
d = await self._session.find_one({'_id': 0}, {attr: 1})
|
d = await self._session.find_one({'_id': 0}, {attr: 1})
|
||||||
|
|
|
||||||
|
|
@ -97,6 +97,22 @@ END;
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
|
||||||
|
UPDATE_DC_SCHEMA = """
|
||||||
|
CREATE TABLE dc_options
|
||||||
|
(
|
||||||
|
id INTEGER PRIMARY KEY AUTOINCREMENT,
|
||||||
|
dc_id INTEGER,
|
||||||
|
address TEXT,
|
||||||
|
port INTEGER,
|
||||||
|
is_ipv6 BOOLEAN,
|
||||||
|
is_test BOOLEAN,
|
||||||
|
is_media BOOLEAN,
|
||||||
|
is_default_ip BOOLEAN,
|
||||||
|
UNIQUE(dc_id, is_ipv6, is_test, is_media)
|
||||||
|
);
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
def get_input_peer(peer_id: int, access_hash: int, peer_type: str):
|
def get_input_peer(peer_id: int, access_hash: int, peer_type: str):
|
||||||
if peer_type in ["user", "bot"]:
|
if peer_type in ["user", "bot"]:
|
||||||
return raw.types.InputPeerUser(
|
return raw.types.InputPeerUser(
|
||||||
|
|
@ -121,6 +137,7 @@ def get_input_peer(peer_id: int, access_hash: int, peer_type: str):
|
||||||
class SQLiteStorage(Storage):
|
class SQLiteStorage(Storage):
|
||||||
VERSION = 4
|
VERSION = 4
|
||||||
USERNAME_TTL = 8 * 60 * 60
|
USERNAME_TTL = 8 * 60 * 60
|
||||||
|
UPDATE_DC_SCHEMA = globals().get("UPDATE_DC_SCHEMA", "")
|
||||||
|
|
||||||
def __init__(self, name: str):
|
def __init__(self, name: str):
|
||||||
super().__init__(name)
|
super().__init__(name)
|
||||||
|
|
@ -131,6 +148,7 @@ class SQLiteStorage(Storage):
|
||||||
with self.conn:
|
with self.conn:
|
||||||
self.conn.executescript(SCHEMA)
|
self.conn.executescript(SCHEMA)
|
||||||
self.conn.executescript(UNAME_SCHEMA)
|
self.conn.executescript(UNAME_SCHEMA)
|
||||||
|
self.conn.executescript(self.UPDATE_DC_SCHEMA)
|
||||||
|
|
||||||
self.conn.execute(
|
self.conn.execute(
|
||||||
"INSERT INTO version VALUES (?)",
|
"INSERT INTO version VALUES (?)",
|
||||||
|
|
@ -252,6 +270,66 @@ class SQLiteStorage(Storage):
|
||||||
|
|
||||||
return get_input_peer(*r)
|
return get_input_peer(*r)
|
||||||
|
|
||||||
|
async def update_dc_address(
|
||||||
|
self,
|
||||||
|
value: Tuple[int, str, int, bool, bool, bool, bool] = object
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Updates or inserts a data center address.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
value (Tuple[int, str, int, bool, bool, bool]): A tuple containing:
|
||||||
|
- dc_id (int): Data center ID.
|
||||||
|
- address (str): Address of the data center.
|
||||||
|
- port (int): Port of the data center.
|
||||||
|
- is_ipv6 (bool): Whether the address is IPv6.
|
||||||
|
- is_test (bool): Whether it is a test data center.
|
||||||
|
- is_media (bool): Whether it is a media data center.
|
||||||
|
- is_default_ip (bool): Whether it is the dc IP address provided by library.
|
||||||
|
"""
|
||||||
|
if value == object:
|
||||||
|
return
|
||||||
|
with self.conn:
|
||||||
|
self.conn.execute(
|
||||||
|
"""
|
||||||
|
INSERT INTO dc_options (dc_id, address, port, is_ipv6, is_test, is_media)
|
||||||
|
VALUES (?, ?, ?, ?, ?, ?, ?)
|
||||||
|
ON CONFLICT(dc_id, is_ipv6, is_test, is_media, is_default_ip)
|
||||||
|
DO UPDATE SET address=excluded.address, port=excluded.port
|
||||||
|
""",
|
||||||
|
value
|
||||||
|
)
|
||||||
|
|
||||||
|
async def get_dc_address(
|
||||||
|
self,
|
||||||
|
dc_id: int,
|
||||||
|
is_ipv6: bool,
|
||||||
|
test_mode: bool = False,
|
||||||
|
media: bool = False
|
||||||
|
) -> Tuple[str, int]:
|
||||||
|
"""
|
||||||
|
Retrieves the address of a data center.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
dc_id (int): Data center ID.
|
||||||
|
is_ipv6 (bool): Whether the address is IPv6.
|
||||||
|
test_mode (bool): Whether it is a test data center.
|
||||||
|
media (bool): Whether it is a media data center.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Tuple[str, int]: A tuple containing the address and port of the data center.
|
||||||
|
"""
|
||||||
|
if dc_id in [1,3,5] and media:
|
||||||
|
media = False
|
||||||
|
if dc_id in [4,5] and test_mode:
|
||||||
|
test_mode = False
|
||||||
|
r = self.conn.execute(
|
||||||
|
"SELECT address, port, is_default_ip FROM dc_options WHERE dc_id = ? AND is_ipv6 = ? AND is_test = ? AND is_media = ?",
|
||||||
|
(dc_id, is_ipv6, test_mode, media)
|
||||||
|
).fetchone()
|
||||||
|
|
||||||
|
return r
|
||||||
|
|
||||||
def _get(self):
|
def _get(self):
|
||||||
attr = inspect.stack()[2].function
|
attr = inspect.stack()[2].function
|
||||||
|
|
||||||
|
|
|
||||||
|
|
@ -76,6 +76,21 @@ class Storage:
|
||||||
async def get_peer_by_phone_number(self, phone_number: str):
|
async def get_peer_by_phone_number(self, phone_number: str):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
async def update_dc_address(
|
||||||
|
self,
|
||||||
|
value: Tuple[int, str, int, bool, bool] = object
|
||||||
|
):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
async def get_dc_address(
|
||||||
|
self,
|
||||||
|
dc_id: int,
|
||||||
|
is_ipv6: bool,
|
||||||
|
test_mode: bool = False,
|
||||||
|
media: bool = False
|
||||||
|
):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
async def dc_id(self, value: int = object):
|
async def dc_id(self, value: int = object):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue