pyrofork: Initial Dynamic datacenters ip address support

Signed-off-by: wulan17 <wulan17@nusantararom.org>
This commit is contained in:
wulan17 2024-06-21 00:00:50 +07:00
parent c394a3ea3a
commit 5aae488747
No known key found for this signature in database
GPG key ID: 318CD6CD3A6AC0A5
16 changed files with 258 additions and 114 deletions

View file

@ -716,10 +716,26 @@ class Client(Methods):
await self.storage.date(0)
await self.storage.test_mode(self.test_mode)
if self.test_mode:
await self.storage.server_address("149.154.167.40")
await self.storage.server_address_v6("2001:67c:4e8:f002::e")
await self.storage.server_port(80)
await self.storage.media_address("149.154.167.40")
await self.storage.media_address_v6("2001:67c:4e8:f002::e")
await self.storage.media_port(80)
else:
await self.storage.server_address("149.154.167.51")
await self.storage.server_address_v6("2001:67c:4e8:f002::a")
await self.storage.server_port(5222 if self.alt_port else 443)
await self.storage.media_address("149.154.167.151")
await self.storage.media_address_v6("2001:067c:04e8:f002:0000:0000:0000:000b")
await self.storage.media_port(5222 if self.alt_port else 443)
await self.storage.auth_key(
await Auth(
self, await self.storage.dc_id(),
await self.storage.test_mode()
await self.storage.test_mode(),
await self.storage.server_address_v6() if self.ipv6 else await self.storage.server_address(),
await self.storage.server_port()
).create()
)
await self.storage.user_id(None)
@ -745,6 +761,22 @@ class Client(Methods):
break
except Exception as e:
print(e)
# Needed for migration from storage v4 to v5
if not await self.storage.server_address():
if self.test_mode:
await self.storage.server_address("149.154.167.40")
await self.storage.server_address_v6("2001:67c:4e8:f002::e")
await self.storage.server_port(80)
await self.storage.media_address("")
await self.storage.media_address_v6("")
await self.storage.media_port(80)
else:
await self.storage.server_address("149.154.167.51")
await self.storage.server_address_v6("2001:67c:4e8:f002::a")
await self.storage.server_port(5222 if self.alt_port else 443)
await self.storage.media_address("149.154.167.151")
await self.storage.media_address_v6("2001:067c:04e8:f002:0000:0000:0000:000b")
await self.storage.media_port(5222 if self.alt_port else 443)
def load_plugins(self):
if self.plugins:
@ -950,10 +982,18 @@ class Client(Methods):
session = Session(
self, dc_id,
await Auth(self, dc_id, await self.storage.test_mode()).create()
await Auth(
self,
dc_id,
await self.storage.test_mode(),
await self.storage.server_address_v6() if self.ipv6 else await self.storage.server_address(),
await self.storage.server_port()
).create()
if dc_id != await self.storage.dc_id()
else await self.storage.auth_key(),
await self.storage.test_mode(),
await self.storage.media_address_v6() if self.ipv6 else await self.storage.media_address(),
await self.storage.media_port(),
is_media=True
)
@ -1021,8 +1061,18 @@ class Client(Methods):
elif isinstance(r, raw.types.upload.FileCdnRedirect):
cdn_session = Session(
self, r.dc_id, await Auth(self, r.dc_id, await self.storage.test_mode()).create(),
await self.storage.test_mode(), is_media=True, is_cdn=True
self, r.dc_id, await Auth(
self,
r.dc_id,
await self.storage.test_mode(),
await self.storage.server_address_v6() if self.ipv6 else await self.storage.server_address(),
await self.storage.server_port()
).create(),
await self.storage.test_mode(),
await self.storage.media_address_v6() if self.ipv6 else await self.storage.media_address(),
await self.storage.media_port(),
is_media=True,
is_cdn=True
)
try:

View file

@ -22,7 +22,6 @@ import logging
from typing import Optional, Type
from .transport import TCP, TCPAbridged
from ..session.internals import DataCenter
log = logging.getLogger(__name__)
@ -34,21 +33,23 @@ class Connection:
self,
dc_id: int,
test_mode: bool,
server_ip: str,
server_port: int,
ipv6: bool,
alt_port: bool,
proxy: dict,
media: bool = False,
protocol_factory: Type[TCP] = TCPAbridged
) -> None:
self.dc_id = dc_id
self.test_mode = test_mode
self.server_ip = server_ip
self.server_port = server_port
self.ipv6 = ipv6
self.alt_port = alt_port
self.proxy = proxy
self.media = media
self.protocol_factory = protocol_factory
self.address = DataCenter(dc_id, test_mode, ipv6, alt_port, media)
self.address = (server_ip, server_port)
self.protocol: Optional[TCP] = None
async def connect(self) -> None:

View file

@ -142,7 +142,9 @@ class SaveFile:
md5_sum = md5() if not is_big and not is_missing_part else None
session = Session(
self, await self.storage.dc_id(), await self.storage.auth_key(),
await self.storage.test_mode(), is_media=True
await self.storage.test_mode(),
await self.storage.media_address_v6() if self.ipv6 else await self.storage.media_address(),
await self.storage.media_port(), is_media=True
)
workers = [self.loop.create_task(worker(session)) for _ in range(workers_count)]
queue = asyncio.Queue(1)

View file

@ -42,7 +42,9 @@ class Connect:
self.session = Session(
self, await self.storage.dc_id(),
await self.storage.auth_key(), await self.storage.test_mode()
await self.storage.auth_key(), await self.storage.test_mode(),
await self.storage.server_address_v6() if self.ipv6 else await self.storage.server_address(),
await self.storage.server_port()
)
await self.session.start()

View file

@ -61,18 +61,43 @@ class SendCode:
)
)
except (PhoneMigrate, NetworkMigrate) as e:
config = await self.invoke(raw.functions.help.GetConfig())
for option in config.dc_options:
if (option.id == e.value):
if option.media_only:
if option.ipv6:
await self.storage.media_address_v6(option.ip_address)
else:
await self.storage.media_address(option.ip_address)
if option.this_port_only:
await self.storage.media_port(option.port)
else:
if option.ipv6:
await self.storage.server_address_v6(option.ip_address)
else:
await self.storage.server_address(option.ip_address)
if option.this_port_only:
await self.storage.port(option.port)
if e not in [2,4] or self.storage.test_mode():
await self.storage.media_address(await self.storage.server_address())
await self.storage.media_address_v6(await self.storage.server_address_v6())
await self.storage.media_port(await self.storage.server_port())
await self.session.stop()
await self.storage.dc_id(e.value)
await self.storage.auth_key(
await Auth(
self, await self.storage.dc_id(),
await self.storage.test_mode()
await self.storage.test_mode(),
await self.storage.server_address_v6() if self.ipv6 else await self.storage.server_address(),
await self.storage.server_port()
).create()
)
self.session = Session(
self, await self.storage.dc_id(),
await self.storage.auth_key(), await self.storage.test_mode()
await self.storage.auth_key(), await self.storage.test_mode(),
await self.storage.server_address_v6() if self.ipv6 else await self.storage.server_address(),
await self.storage.server_port()
)
await self.session.start()

View file

@ -58,18 +58,43 @@ class SignInBot:
)
)
except UserMigrate as e:
config = await self.invoke(raw.functions.help.GetConfig())
for option in config.dc_options:
if (option.id == e.value):
if option.media_only:
if option.ipv6:
await self.storage.media_address_v6(option.ip_address)
else:
await self.storage.media_address(option.ip_address)
if option.this_port_only:
await self.storage.media_port(option.port)
else:
if option.ipv6:
await self.storage.server_address_v6(option.ip_address)
else:
await self.storage.server_address(option.ip_address)
if option.this_port_only:
await self.storage.port(option.port)
if e not in [2,4] or self.storage.test_mode():
await self.storage.media_address(await self.storage.server_address())
await self.storage.media_address_v6(await self.storage.server_address_v6())
await self.storage.media_port(await self.storage.server_port())
await self.session.stop()
await self.storage.dc_id(e.value)
await self.storage.auth_key(
await Auth(
self, await self.storage.dc_id(),
await self.storage.test_mode()
await self.storage.test_mode(),
await self.storage.server_address_v6() if self.ipv6 else await self.storage.server_address(),
await self.storage.server_port()
).create()
)
self.session = Session(
self, await self.storage.dc_id(),
await self.storage.auth_key(), await self.storage.test_mode()
await self.storage.auth_key(), await self.storage.test_mode(),
await self.storage.server_address_v6() if self.ipv6 else await self.storage.server_address(),
await self.storage.server_port()
)
await self.session.start()

View file

@ -35,7 +35,9 @@ async def get_session(client: "pyrogram.Client", dc_id: int):
session = client.media_sessions[dc_id] = Session(
client, dc_id,
await Auth(client, dc_id, await client.storage.test_mode()).create(),
await client.storage.test_mode(), is_media=True
await client.storage.test_mode(),
await client.storage.media_address_v6() if client.ipv6 else await client.storage.media_address(),
await client.storage.media_port(), is_media=True
)
await session.start()

View file

@ -43,12 +43,15 @@ class Auth:
self,
client: "pyrogram.Client",
dc_id: int,
test_mode: bool
test_mode: bool,
server_ip: str,
server_port: int
):
self.dc_id = dc_id
self.test_mode = test_mode
self.server_ip = server_ip
self.server_port = server_port
self.ipv6 = client.ipv6
self.alt_port = client.alt_port
self.proxy = client.proxy
self.connection_factory = client.connection_factory
self.protocol_factory = client.protocol_factory
@ -86,14 +89,15 @@ class Auth:
# The server may close the connection at any time, causing the auth key creation to fail.
# If that happens, just try again up to MAX_RETRIES times.
while True:
self.connection = self.connection_factory(
self.connection = self.client.connection_factory(
dc_id=self.dc_id,
test_mode=self.test_mode,
ipv6=self.ipv6,
alt_port=self.alt_port,
proxy=self.proxy,
media=False,
protocol_factory=self.protocol_factory
server_ip=self.server_ip,
server_port=self.server_port,
ipv6=self.client.ipv6,
proxy=self.client.proxy,
media=self.is_media,
protocol_factory=self.client.protocol_factory
)
try:

View file

@ -17,6 +17,5 @@
# You should have received a copy of the GNU Lesser General Public License
# along with Pyrofork. If not, see <http://www.gnu.org/licenses/>.
from .data_center import DataCenter
from .msg_factory import MsgFactory
from .msg_id import MsgId

View file

@ -1,83 +0,0 @@
# Pyrofork - Telegram MTProto API Client Library for Python
# Copyright (C) 2017-present Dan <https://github.com/delivrance>
# Copyright (C) 2022-present Mayuri-Chan <https://github.com/Mayuri-Chan>
#
# This file is part of Pyrofork.
#
# Pyrofork is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License as published
# by the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# Pyrofork is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with Pyrofork. If not, see <http://www.gnu.org/licenses/>.
from typing import Tuple
class DataCenter:
TEST = {
1: "149.154.175.10",
2: "149.154.167.40",
3: "149.154.175.117",
}
PROD = {
1: "149.154.175.53",
2: "149.154.167.51",
3: "149.154.175.100",
4: "149.154.167.91",
5: "91.108.56.130",
203: "91.105.192.100"
}
PROD_MEDIA = {
2: "149.154.167.151",
4: "149.154.164.250"
}
TEST_IPV6 = {
1: "2001:b28:f23d:f001::e",
2: "2001:67c:4e8:f002::e",
3: "2001:b28:f23d:f003::e",
}
PROD_IPV6 = {
1: "2001:b28:f23d:f001::a",
2: "2001:67c:4e8:f002::a",
3: "2001:b28:f23d:f003::a",
4: "2001:67c:4e8:f004::a",
5: "2001:b28:f23f:f005::a",
203: "2a0a:f280:0203:000a:5000:0000:0000:0100"
}
PROD_IPV6_MEDIA = {
2: "2001:067c:04e8:f002:0000:0000:0000:000b",
4: "2001:067c:04e8:f004:0000:0000:0000:000b"
}
def __new__(cls, dc_id: int, test_mode: bool, ipv6: bool, alt_port: bool, media: bool) -> Tuple[str, int]:
if test_mode:
if ipv6:
ip = cls.TEST_IPV6[dc_id]
else:
ip = cls.TEST[dc_id]
return ip, 80
else:
if ipv6:
if media:
ip = cls.PROD_IPV6_MEDIA.get(dc_id, cls.PROD_IPV6[dc_id])
else:
ip = cls.PROD_IPV6[dc_id]
else:
if media:
ip = cls.PROD_MEDIA.get(dc_id, cls.PROD[dc_id])
else:
ip = cls.PROD[dc_id]
return ip, 5222 if alt_port else 443

View file

@ -69,6 +69,8 @@ class Session:
dc_id: int,
auth_key: bytes,
test_mode: bool,
server_ip: str,
server_port: int,
is_media: bool = False,
is_cdn: bool = False
):
@ -78,6 +80,8 @@ class Session:
self.test_mode = test_mode
self.is_media = is_media
self.is_cdn = is_cdn
self.server_ip = server_ip
self.server_port = server_port
self.connection: Optional[Connection] = None
@ -108,8 +112,9 @@ class Session:
self.connection = self.client.connection_factory(
dc_id=self.dc_id,
test_mode=self.test_mode,
server_ip=self.server_ip,
server_port=self.server_port,
ipv6=self.client.ipv6,
alt_port=self.client.alt_port,
proxy=self.client.proxy,
media=self.is_media,
protocol_factory=self.client.protocol_factory

View file

@ -38,6 +38,16 @@ CREATE TABLE update_state
);
"""
UPDATE_SESSION_SCHEMA = """
ALTER TABLE sessions
ADD server_address TEXT,
server_address_v6 TEXT,
server_port INTEGER,
media_address TEXT,
media_address_v6 TEXT,
media_port INTEGER;
"""
class FileStorage(SQLiteStorage):
FILE_EXTENSION = ".session"
@ -68,6 +78,12 @@ class FileStorage(SQLiteStorage):
version += 1
if version == 4:
with self.conn:
self.conn.execute(UPDATE_SESSION_SCHEMA)
version += 1
self.version(version)
async def open(self):

View file

@ -19,6 +19,7 @@
import base64
import logging
import re
import sqlite3
import struct
@ -57,11 +58,19 @@ class MemoryStorage(SQLiteStorage):
log.warning("You are using an old session string format. Use export_session_string to update")
return
dc_id, api_id, test_mode, auth_key, user_id, is_bot = struct.unpack(
dc_id, api_id, test_mode, auth_key, user_id, is_bot, server_address, server_address_v6, server_port, media_address, media_address_v6, media_port = struct.unpack(
self.SESSION_STRING_FORMAT,
base64.urlsafe_b64decode(self.session_string + "=" * (-len(self.session_string) % 4))
)
# Remove leading zeros
server_address = re.sub(r'^[0]*', '', re.sub(r'\.[0]*', '.', server_address.decode()))
server_address_v6 = re.sub(r'^[0]*', '', re.sub(r'\:[0]*', ':', server_address_v6.decode()))
server_port = int(server_port.decode())
media_address = re.sub(r'^[0]*', '', re.sub(r'\.[0]*', '.', media_address.decode()))
media_address_v6 = re.sub(r'^[0]*', '', re.sub(r'\:[0]*', ':', media_address_v6.decode()))
media_port = int(media_port.decode())
await self.dc_id(dc_id)
await self.api_id(api_id)
await self.test_mode(test_mode)
@ -69,6 +78,12 @@ class MemoryStorage(SQLiteStorage):
await self.user_id(user_id)
await self.is_bot(is_bot)
await self.date(0)
await self.server_address(server_address)
await self.server_address_v6(server_address_v6)
await self.server_port(server_port)
await self.media_address(media_address)
await self.media_address_v6(media_address_v6)
await self.media_port(media_port)
async def delete(self):
pass

View file

@ -223,6 +223,8 @@ class MongoStorage(Storage):
d = await self._session.find_one({'_id': 0}, {attr: 1})
if not d:
return
if f"{attr}" not in d:
return
return d[attr]
async def _set(self, value: Any):
@ -252,3 +254,21 @@ class MongoStorage(Storage):
async def is_bot(self, value: bool = object):
return await self._accessor(value)
async def server_address(self, value: str = object):
return await self._accessor(value)
async def server_address_v6(self, value: str = object):
return await self._accessor(value)
async def server_port(self, value: int = object):
return await self._accessor(value)
async def media_address(self, value: str = object):
return await self._accessor(value)
async def media_address_v6(self, value: str = object):
return await self._accessor(value)
async def media_port(self, value: int = object):
return await self._accessor(value)

View file

@ -36,7 +36,13 @@ CREATE TABLE sessions
auth_key BLOB,
date INTEGER NOT NULL,
user_id INTEGER,
is_bot INTEGER
is_bot INTEGER,
server_address TEXT,
server_address_v6 TEXT,
server_port INTEGER,
media_address TEXT,
media_address_v6 TEXT,
media_port INTEGER
);
CREATE TABLE peers
@ -138,8 +144,8 @@ class SQLiteStorage(Storage):
)
self.conn.execute(
"INSERT INTO sessions VALUES (?, ?, ?, ?, ?, ?, ?)",
(2, None, None, None, 0, None, None)
"INSERT INTO sessions VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
(2, None, None, None, 0, None, None, None, None, 0, None, None, 0)
)
async def open(self):
@ -286,6 +292,24 @@ class SQLiteStorage(Storage):
async def is_bot(self, value: bool = object):
return self._accessor(value)
async def server_address(self, value: str = object):
return self._accessor(value)
async def server_address_v6(self, value: str = object):
return self._accessor(value)
async def server_port(self, value: int = object):
return self._accessor(value)
async def media_address(self, value: str = object):
return self._accessor(value)
async def media_address_v6(self, value: str = object):
return self._accessor(value)
async def media_port(self, value: int = object):
return self._accessor(value)
def version(self, value: int = object):
if value == object:
return self.conn.execute(

View file

@ -29,7 +29,7 @@ class Storage:
SESSION_STRING_SIZE = 351
SESSION_STRING_SIZE_64 = 356
SESSION_STRING_FORMAT = ">BI?256sQ?"
SESSION_STRING_FORMAT = ">BI?256sQ?15s39sB15s39sB"
def __init__(self, name: str):
self.name = name
@ -97,7 +97,38 @@ class Storage:
async def is_bot(self, value: bool = object):
raise NotImplementedError
async def server_address(self, value: str = object):
raise NotImplementedError
async def server_address_v6(self, value: str = object):
raise NotImplementedError
async def server_port(self, value: int = object):
raise NotImplementedError
async def media_address(self, value: str = object):
raise NotImplementedError
async def media_address_v6(self, value: str = object):
raise NotImplementedError
async def media_port(self, value: int = object):
raise NotImplementedError
async def export_session_string(self):
server_ip = await self.server_address()
server_ip_v6 = await self.server_address_v6()
port_server = await self.server_port()
media_ip = await self.media_address()
media_ip_v6 = await self.media_address_v6()
port_media = await self.media_port()
# Add leading zero to make fixed size
server_address = '.'.join(i.zfill(3) for i in server_ip.split('.'))
server_address_v6 = ':'.join(i.zfill(4) for i in server_ip_v6.split(':'))
server_port = f"{port_server}".zfill(5)
media_address = '.'.join(i.zfill(3) for i in media_ip.split('.'))
media_address_v6 = ':'.join(i.zfill(4) for i in media_ip_v6.split(':'))
media_port = f"{port_media}".zfill(5)
packed = struct.pack(
self.SESSION_STRING_FORMAT,
await self.dc_id(),
@ -105,7 +136,13 @@ class Storage:
await self.test_mode(),
await self.auth_key(),
await self.user_id(),
await self.is_bot()
await self.is_bot(),
bytes(server_address.encode()),
bytes(server_address_v6.encode()),
bytes(server_port.encode()),
bytes(media_address.encode()),
bytes(media_address_v6.encode()),
bytes(media_port.encode())
)
return base64.urlsafe_b64encode(packed).decode().rstrip("=")