mirror of
https://github.com/Mayuri-Chan/pyrofork.git
synced 2025-12-29 12:04:51 +00:00
pyrofork: Initial Dynamic datacenters ip address support
Signed-off-by: wulan17 <wulan17@nusantararom.org>
This commit is contained in:
parent
c394a3ea3a
commit
5aae488747
16 changed files with 258 additions and 114 deletions
|
|
@ -716,10 +716,26 @@ class Client(Methods):
|
||||||
await self.storage.date(0)
|
await self.storage.date(0)
|
||||||
|
|
||||||
await self.storage.test_mode(self.test_mode)
|
await self.storage.test_mode(self.test_mode)
|
||||||
|
if self.test_mode:
|
||||||
|
await self.storage.server_address("149.154.167.40")
|
||||||
|
await self.storage.server_address_v6("2001:67c:4e8:f002::e")
|
||||||
|
await self.storage.server_port(80)
|
||||||
|
await self.storage.media_address("149.154.167.40")
|
||||||
|
await self.storage.media_address_v6("2001:67c:4e8:f002::e")
|
||||||
|
await self.storage.media_port(80)
|
||||||
|
else:
|
||||||
|
await self.storage.server_address("149.154.167.51")
|
||||||
|
await self.storage.server_address_v6("2001:67c:4e8:f002::a")
|
||||||
|
await self.storage.server_port(5222 if self.alt_port else 443)
|
||||||
|
await self.storage.media_address("149.154.167.151")
|
||||||
|
await self.storage.media_address_v6("2001:067c:04e8:f002:0000:0000:0000:000b")
|
||||||
|
await self.storage.media_port(5222 if self.alt_port else 443)
|
||||||
await self.storage.auth_key(
|
await self.storage.auth_key(
|
||||||
await Auth(
|
await Auth(
|
||||||
self, await self.storage.dc_id(),
|
self, await self.storage.dc_id(),
|
||||||
await self.storage.test_mode()
|
await self.storage.test_mode(),
|
||||||
|
await self.storage.server_address_v6() if self.ipv6 else await self.storage.server_address(),
|
||||||
|
await self.storage.server_port()
|
||||||
).create()
|
).create()
|
||||||
)
|
)
|
||||||
await self.storage.user_id(None)
|
await self.storage.user_id(None)
|
||||||
|
|
@ -745,6 +761,22 @@ class Client(Methods):
|
||||||
break
|
break
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
print(e)
|
print(e)
|
||||||
|
# Needed for migration from storage v4 to v5
|
||||||
|
if not await self.storage.server_address():
|
||||||
|
if self.test_mode:
|
||||||
|
await self.storage.server_address("149.154.167.40")
|
||||||
|
await self.storage.server_address_v6("2001:67c:4e8:f002::e")
|
||||||
|
await self.storage.server_port(80)
|
||||||
|
await self.storage.media_address("")
|
||||||
|
await self.storage.media_address_v6("")
|
||||||
|
await self.storage.media_port(80)
|
||||||
|
else:
|
||||||
|
await self.storage.server_address("149.154.167.51")
|
||||||
|
await self.storage.server_address_v6("2001:67c:4e8:f002::a")
|
||||||
|
await self.storage.server_port(5222 if self.alt_port else 443)
|
||||||
|
await self.storage.media_address("149.154.167.151")
|
||||||
|
await self.storage.media_address_v6("2001:067c:04e8:f002:0000:0000:0000:000b")
|
||||||
|
await self.storage.media_port(5222 if self.alt_port else 443)
|
||||||
|
|
||||||
def load_plugins(self):
|
def load_plugins(self):
|
||||||
if self.plugins:
|
if self.plugins:
|
||||||
|
|
@ -950,10 +982,18 @@ class Client(Methods):
|
||||||
|
|
||||||
session = Session(
|
session = Session(
|
||||||
self, dc_id,
|
self, dc_id,
|
||||||
await Auth(self, dc_id, await self.storage.test_mode()).create()
|
await Auth(
|
||||||
|
self,
|
||||||
|
dc_id,
|
||||||
|
await self.storage.test_mode(),
|
||||||
|
await self.storage.server_address_v6() if self.ipv6 else await self.storage.server_address(),
|
||||||
|
await self.storage.server_port()
|
||||||
|
).create()
|
||||||
if dc_id != await self.storage.dc_id()
|
if dc_id != await self.storage.dc_id()
|
||||||
else await self.storage.auth_key(),
|
else await self.storage.auth_key(),
|
||||||
await self.storage.test_mode(),
|
await self.storage.test_mode(),
|
||||||
|
await self.storage.media_address_v6() if self.ipv6 else await self.storage.media_address(),
|
||||||
|
await self.storage.media_port(),
|
||||||
is_media=True
|
is_media=True
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
@ -1021,8 +1061,18 @@ class Client(Methods):
|
||||||
|
|
||||||
elif isinstance(r, raw.types.upload.FileCdnRedirect):
|
elif isinstance(r, raw.types.upload.FileCdnRedirect):
|
||||||
cdn_session = Session(
|
cdn_session = Session(
|
||||||
self, r.dc_id, await Auth(self, r.dc_id, await self.storage.test_mode()).create(),
|
self, r.dc_id, await Auth(
|
||||||
await self.storage.test_mode(), is_media=True, is_cdn=True
|
self,
|
||||||
|
r.dc_id,
|
||||||
|
await self.storage.test_mode(),
|
||||||
|
await self.storage.server_address_v6() if self.ipv6 else await self.storage.server_address(),
|
||||||
|
await self.storage.server_port()
|
||||||
|
).create(),
|
||||||
|
await self.storage.test_mode(),
|
||||||
|
await self.storage.media_address_v6() if self.ipv6 else await self.storage.media_address(),
|
||||||
|
await self.storage.media_port(),
|
||||||
|
is_media=True,
|
||||||
|
is_cdn=True
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|
|
||||||
|
|
@ -22,7 +22,6 @@ import logging
|
||||||
from typing import Optional, Type
|
from typing import Optional, Type
|
||||||
|
|
||||||
from .transport import TCP, TCPAbridged
|
from .transport import TCP, TCPAbridged
|
||||||
from ..session.internals import DataCenter
|
|
||||||
|
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
@ -34,21 +33,23 @@ class Connection:
|
||||||
self,
|
self,
|
||||||
dc_id: int,
|
dc_id: int,
|
||||||
test_mode: bool,
|
test_mode: bool,
|
||||||
|
server_ip: str,
|
||||||
|
server_port: int,
|
||||||
ipv6: bool,
|
ipv6: bool,
|
||||||
alt_port: bool,
|
|
||||||
proxy: dict,
|
proxy: dict,
|
||||||
media: bool = False,
|
media: bool = False,
|
||||||
protocol_factory: Type[TCP] = TCPAbridged
|
protocol_factory: Type[TCP] = TCPAbridged
|
||||||
) -> None:
|
) -> None:
|
||||||
self.dc_id = dc_id
|
self.dc_id = dc_id
|
||||||
self.test_mode = test_mode
|
self.test_mode = test_mode
|
||||||
|
self.server_ip = server_ip
|
||||||
|
self.server_port = server_port
|
||||||
self.ipv6 = ipv6
|
self.ipv6 = ipv6
|
||||||
self.alt_port = alt_port
|
|
||||||
self.proxy = proxy
|
self.proxy = proxy
|
||||||
self.media = media
|
self.media = media
|
||||||
self.protocol_factory = protocol_factory
|
self.protocol_factory = protocol_factory
|
||||||
|
|
||||||
self.address = DataCenter(dc_id, test_mode, ipv6, alt_port, media)
|
self.address = (server_ip, server_port)
|
||||||
self.protocol: Optional[TCP] = None
|
self.protocol: Optional[TCP] = None
|
||||||
|
|
||||||
async def connect(self) -> None:
|
async def connect(self) -> None:
|
||||||
|
|
|
||||||
|
|
@ -142,7 +142,9 @@ class SaveFile:
|
||||||
md5_sum = md5() if not is_big and not is_missing_part else None
|
md5_sum = md5() if not is_big and not is_missing_part else None
|
||||||
session = Session(
|
session = Session(
|
||||||
self, await self.storage.dc_id(), await self.storage.auth_key(),
|
self, await self.storage.dc_id(), await self.storage.auth_key(),
|
||||||
await self.storage.test_mode(), is_media=True
|
await self.storage.test_mode(),
|
||||||
|
await self.storage.media_address_v6() if self.ipv6 else await self.storage.media_address(),
|
||||||
|
await self.storage.media_port(), is_media=True
|
||||||
)
|
)
|
||||||
workers = [self.loop.create_task(worker(session)) for _ in range(workers_count)]
|
workers = [self.loop.create_task(worker(session)) for _ in range(workers_count)]
|
||||||
queue = asyncio.Queue(1)
|
queue = asyncio.Queue(1)
|
||||||
|
|
|
||||||
|
|
@ -42,7 +42,9 @@ class Connect:
|
||||||
|
|
||||||
self.session = Session(
|
self.session = Session(
|
||||||
self, await self.storage.dc_id(),
|
self, await self.storage.dc_id(),
|
||||||
await self.storage.auth_key(), await self.storage.test_mode()
|
await self.storage.auth_key(), await self.storage.test_mode(),
|
||||||
|
await self.storage.server_address_v6() if self.ipv6 else await self.storage.server_address(),
|
||||||
|
await self.storage.server_port()
|
||||||
)
|
)
|
||||||
|
|
||||||
await self.session.start()
|
await self.session.start()
|
||||||
|
|
|
||||||
|
|
@ -61,18 +61,43 @@ class SendCode:
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
except (PhoneMigrate, NetworkMigrate) as e:
|
except (PhoneMigrate, NetworkMigrate) as e:
|
||||||
|
config = await self.invoke(raw.functions.help.GetConfig())
|
||||||
|
for option in config.dc_options:
|
||||||
|
if (option.id == e.value):
|
||||||
|
if option.media_only:
|
||||||
|
if option.ipv6:
|
||||||
|
await self.storage.media_address_v6(option.ip_address)
|
||||||
|
else:
|
||||||
|
await self.storage.media_address(option.ip_address)
|
||||||
|
if option.this_port_only:
|
||||||
|
await self.storage.media_port(option.port)
|
||||||
|
else:
|
||||||
|
if option.ipv6:
|
||||||
|
await self.storage.server_address_v6(option.ip_address)
|
||||||
|
else:
|
||||||
|
await self.storage.server_address(option.ip_address)
|
||||||
|
if option.this_port_only:
|
||||||
|
await self.storage.port(option.port)
|
||||||
|
if e not in [2,4] or self.storage.test_mode():
|
||||||
|
await self.storage.media_address(await self.storage.server_address())
|
||||||
|
await self.storage.media_address_v6(await self.storage.server_address_v6())
|
||||||
|
await self.storage.media_port(await self.storage.server_port())
|
||||||
await self.session.stop()
|
await self.session.stop()
|
||||||
|
|
||||||
await self.storage.dc_id(e.value)
|
await self.storage.dc_id(e.value)
|
||||||
await self.storage.auth_key(
|
await self.storage.auth_key(
|
||||||
await Auth(
|
await Auth(
|
||||||
self, await self.storage.dc_id(),
|
self, await self.storage.dc_id(),
|
||||||
await self.storage.test_mode()
|
await self.storage.test_mode(),
|
||||||
|
await self.storage.server_address_v6() if self.ipv6 else await self.storage.server_address(),
|
||||||
|
await self.storage.server_port()
|
||||||
).create()
|
).create()
|
||||||
)
|
)
|
||||||
self.session = Session(
|
self.session = Session(
|
||||||
self, await self.storage.dc_id(),
|
self, await self.storage.dc_id(),
|
||||||
await self.storage.auth_key(), await self.storage.test_mode()
|
await self.storage.auth_key(), await self.storage.test_mode(),
|
||||||
|
await self.storage.server_address_v6() if self.ipv6 else await self.storage.server_address(),
|
||||||
|
await self.storage.server_port()
|
||||||
)
|
)
|
||||||
|
|
||||||
await self.session.start()
|
await self.session.start()
|
||||||
|
|
|
||||||
|
|
@ -58,18 +58,43 @@ class SignInBot:
|
||||||
)
|
)
|
||||||
)
|
)
|
||||||
except UserMigrate as e:
|
except UserMigrate as e:
|
||||||
|
config = await self.invoke(raw.functions.help.GetConfig())
|
||||||
|
for option in config.dc_options:
|
||||||
|
if (option.id == e.value):
|
||||||
|
if option.media_only:
|
||||||
|
if option.ipv6:
|
||||||
|
await self.storage.media_address_v6(option.ip_address)
|
||||||
|
else:
|
||||||
|
await self.storage.media_address(option.ip_address)
|
||||||
|
if option.this_port_only:
|
||||||
|
await self.storage.media_port(option.port)
|
||||||
|
else:
|
||||||
|
if option.ipv6:
|
||||||
|
await self.storage.server_address_v6(option.ip_address)
|
||||||
|
else:
|
||||||
|
await self.storage.server_address(option.ip_address)
|
||||||
|
if option.this_port_only:
|
||||||
|
await self.storage.port(option.port)
|
||||||
|
if e not in [2,4] or self.storage.test_mode():
|
||||||
|
await self.storage.media_address(await self.storage.server_address())
|
||||||
|
await self.storage.media_address_v6(await self.storage.server_address_v6())
|
||||||
|
await self.storage.media_port(await self.storage.server_port())
|
||||||
await self.session.stop()
|
await self.session.stop()
|
||||||
|
|
||||||
await self.storage.dc_id(e.value)
|
await self.storage.dc_id(e.value)
|
||||||
await self.storage.auth_key(
|
await self.storage.auth_key(
|
||||||
await Auth(
|
await Auth(
|
||||||
self, await self.storage.dc_id(),
|
self, await self.storage.dc_id(),
|
||||||
await self.storage.test_mode()
|
await self.storage.test_mode(),
|
||||||
|
await self.storage.server_address_v6() if self.ipv6 else await self.storage.server_address(),
|
||||||
|
await self.storage.server_port()
|
||||||
).create()
|
).create()
|
||||||
)
|
)
|
||||||
self.session = Session(
|
self.session = Session(
|
||||||
self, await self.storage.dc_id(),
|
self, await self.storage.dc_id(),
|
||||||
await self.storage.auth_key(), await self.storage.test_mode()
|
await self.storage.auth_key(), await self.storage.test_mode(),
|
||||||
|
await self.storage.server_address_v6() if self.ipv6 else await self.storage.server_address(),
|
||||||
|
await self.storage.server_port()
|
||||||
)
|
)
|
||||||
|
|
||||||
await self.session.start()
|
await self.session.start()
|
||||||
|
|
|
||||||
|
|
@ -35,7 +35,9 @@ async def get_session(client: "pyrogram.Client", dc_id: int):
|
||||||
session = client.media_sessions[dc_id] = Session(
|
session = client.media_sessions[dc_id] = Session(
|
||||||
client, dc_id,
|
client, dc_id,
|
||||||
await Auth(client, dc_id, await client.storage.test_mode()).create(),
|
await Auth(client, dc_id, await client.storage.test_mode()).create(),
|
||||||
await client.storage.test_mode(), is_media=True
|
await client.storage.test_mode(),
|
||||||
|
await client.storage.media_address_v6() if client.ipv6 else await client.storage.media_address(),
|
||||||
|
await client.storage.media_port(), is_media=True
|
||||||
)
|
)
|
||||||
|
|
||||||
await session.start()
|
await session.start()
|
||||||
|
|
|
||||||
|
|
@ -43,12 +43,15 @@ class Auth:
|
||||||
self,
|
self,
|
||||||
client: "pyrogram.Client",
|
client: "pyrogram.Client",
|
||||||
dc_id: int,
|
dc_id: int,
|
||||||
test_mode: bool
|
test_mode: bool,
|
||||||
|
server_ip: str,
|
||||||
|
server_port: int
|
||||||
):
|
):
|
||||||
self.dc_id = dc_id
|
self.dc_id = dc_id
|
||||||
self.test_mode = test_mode
|
self.test_mode = test_mode
|
||||||
|
self.server_ip = server_ip
|
||||||
|
self.server_port = server_port
|
||||||
self.ipv6 = client.ipv6
|
self.ipv6 = client.ipv6
|
||||||
self.alt_port = client.alt_port
|
|
||||||
self.proxy = client.proxy
|
self.proxy = client.proxy
|
||||||
self.connection_factory = client.connection_factory
|
self.connection_factory = client.connection_factory
|
||||||
self.protocol_factory = client.protocol_factory
|
self.protocol_factory = client.protocol_factory
|
||||||
|
|
@ -86,14 +89,15 @@ class Auth:
|
||||||
# The server may close the connection at any time, causing the auth key creation to fail.
|
# The server may close the connection at any time, causing the auth key creation to fail.
|
||||||
# If that happens, just try again up to MAX_RETRIES times.
|
# If that happens, just try again up to MAX_RETRIES times.
|
||||||
while True:
|
while True:
|
||||||
self.connection = self.connection_factory(
|
self.connection = self.client.connection_factory(
|
||||||
dc_id=self.dc_id,
|
dc_id=self.dc_id,
|
||||||
test_mode=self.test_mode,
|
test_mode=self.test_mode,
|
||||||
ipv6=self.ipv6,
|
server_ip=self.server_ip,
|
||||||
alt_port=self.alt_port,
|
server_port=self.server_port,
|
||||||
proxy=self.proxy,
|
ipv6=self.client.ipv6,
|
||||||
media=False,
|
proxy=self.client.proxy,
|
||||||
protocol_factory=self.protocol_factory
|
media=self.is_media,
|
||||||
|
protocol_factory=self.client.protocol_factory
|
||||||
)
|
)
|
||||||
|
|
||||||
try:
|
try:
|
||||||
|
|
|
||||||
|
|
@ -17,6 +17,5 @@
|
||||||
# You should have received a copy of the GNU Lesser General Public License
|
# You should have received a copy of the GNU Lesser General Public License
|
||||||
# along with Pyrofork. If not, see <http://www.gnu.org/licenses/>.
|
# along with Pyrofork. If not, see <http://www.gnu.org/licenses/>.
|
||||||
|
|
||||||
from .data_center import DataCenter
|
|
||||||
from .msg_factory import MsgFactory
|
from .msg_factory import MsgFactory
|
||||||
from .msg_id import MsgId
|
from .msg_id import MsgId
|
||||||
|
|
|
||||||
|
|
@ -1,83 +0,0 @@
|
||||||
# Pyrofork - Telegram MTProto API Client Library for Python
|
|
||||||
# Copyright (C) 2017-present Dan <https://github.com/delivrance>
|
|
||||||
# Copyright (C) 2022-present Mayuri-Chan <https://github.com/Mayuri-Chan>
|
|
||||||
#
|
|
||||||
# This file is part of Pyrofork.
|
|
||||||
#
|
|
||||||
# Pyrofork is free software: you can redistribute it and/or modify
|
|
||||||
# it under the terms of the GNU Lesser General Public License as published
|
|
||||||
# by the Free Software Foundation, either version 3 of the License, or
|
|
||||||
# (at your option) any later version.
|
|
||||||
#
|
|
||||||
# Pyrofork is distributed in the hope that it will be useful,
|
|
||||||
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
|
||||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
|
||||||
# GNU Lesser General Public License for more details.
|
|
||||||
#
|
|
||||||
# You should have received a copy of the GNU Lesser General Public License
|
|
||||||
# along with Pyrofork. If not, see <http://www.gnu.org/licenses/>.
|
|
||||||
|
|
||||||
from typing import Tuple
|
|
||||||
|
|
||||||
|
|
||||||
class DataCenter:
|
|
||||||
TEST = {
|
|
||||||
1: "149.154.175.10",
|
|
||||||
2: "149.154.167.40",
|
|
||||||
3: "149.154.175.117",
|
|
||||||
}
|
|
||||||
|
|
||||||
PROD = {
|
|
||||||
1: "149.154.175.53",
|
|
||||||
2: "149.154.167.51",
|
|
||||||
3: "149.154.175.100",
|
|
||||||
4: "149.154.167.91",
|
|
||||||
5: "91.108.56.130",
|
|
||||||
203: "91.105.192.100"
|
|
||||||
}
|
|
||||||
|
|
||||||
PROD_MEDIA = {
|
|
||||||
2: "149.154.167.151",
|
|
||||||
4: "149.154.164.250"
|
|
||||||
}
|
|
||||||
|
|
||||||
TEST_IPV6 = {
|
|
||||||
1: "2001:b28:f23d:f001::e",
|
|
||||||
2: "2001:67c:4e8:f002::e",
|
|
||||||
3: "2001:b28:f23d:f003::e",
|
|
||||||
}
|
|
||||||
|
|
||||||
PROD_IPV6 = {
|
|
||||||
1: "2001:b28:f23d:f001::a",
|
|
||||||
2: "2001:67c:4e8:f002::a",
|
|
||||||
3: "2001:b28:f23d:f003::a",
|
|
||||||
4: "2001:67c:4e8:f004::a",
|
|
||||||
5: "2001:b28:f23f:f005::a",
|
|
||||||
203: "2a0a:f280:0203:000a:5000:0000:0000:0100"
|
|
||||||
}
|
|
||||||
|
|
||||||
PROD_IPV6_MEDIA = {
|
|
||||||
2: "2001:067c:04e8:f002:0000:0000:0000:000b",
|
|
||||||
4: "2001:067c:04e8:f004:0000:0000:0000:000b"
|
|
||||||
}
|
|
||||||
|
|
||||||
def __new__(cls, dc_id: int, test_mode: bool, ipv6: bool, alt_port: bool, media: bool) -> Tuple[str, int]:
|
|
||||||
if test_mode:
|
|
||||||
if ipv6:
|
|
||||||
ip = cls.TEST_IPV6[dc_id]
|
|
||||||
else:
|
|
||||||
ip = cls.TEST[dc_id]
|
|
||||||
|
|
||||||
return ip, 80
|
|
||||||
else:
|
|
||||||
if ipv6:
|
|
||||||
if media:
|
|
||||||
ip = cls.PROD_IPV6_MEDIA.get(dc_id, cls.PROD_IPV6[dc_id])
|
|
||||||
else:
|
|
||||||
ip = cls.PROD_IPV6[dc_id]
|
|
||||||
else:
|
|
||||||
if media:
|
|
||||||
ip = cls.PROD_MEDIA.get(dc_id, cls.PROD[dc_id])
|
|
||||||
else:
|
|
||||||
ip = cls.PROD[dc_id]
|
|
||||||
return ip, 5222 if alt_port else 443
|
|
||||||
|
|
@ -69,6 +69,8 @@ class Session:
|
||||||
dc_id: int,
|
dc_id: int,
|
||||||
auth_key: bytes,
|
auth_key: bytes,
|
||||||
test_mode: bool,
|
test_mode: bool,
|
||||||
|
server_ip: str,
|
||||||
|
server_port: int,
|
||||||
is_media: bool = False,
|
is_media: bool = False,
|
||||||
is_cdn: bool = False
|
is_cdn: bool = False
|
||||||
):
|
):
|
||||||
|
|
@ -78,6 +80,8 @@ class Session:
|
||||||
self.test_mode = test_mode
|
self.test_mode = test_mode
|
||||||
self.is_media = is_media
|
self.is_media = is_media
|
||||||
self.is_cdn = is_cdn
|
self.is_cdn = is_cdn
|
||||||
|
self.server_ip = server_ip
|
||||||
|
self.server_port = server_port
|
||||||
|
|
||||||
self.connection: Optional[Connection] = None
|
self.connection: Optional[Connection] = None
|
||||||
|
|
||||||
|
|
@ -108,8 +112,9 @@ class Session:
|
||||||
self.connection = self.client.connection_factory(
|
self.connection = self.client.connection_factory(
|
||||||
dc_id=self.dc_id,
|
dc_id=self.dc_id,
|
||||||
test_mode=self.test_mode,
|
test_mode=self.test_mode,
|
||||||
|
server_ip=self.server_ip,
|
||||||
|
server_port=self.server_port,
|
||||||
ipv6=self.client.ipv6,
|
ipv6=self.client.ipv6,
|
||||||
alt_port=self.client.alt_port,
|
|
||||||
proxy=self.client.proxy,
|
proxy=self.client.proxy,
|
||||||
media=self.is_media,
|
media=self.is_media,
|
||||||
protocol_factory=self.client.protocol_factory
|
protocol_factory=self.client.protocol_factory
|
||||||
|
|
|
||||||
|
|
@ -38,6 +38,16 @@ CREATE TABLE update_state
|
||||||
);
|
);
|
||||||
"""
|
"""
|
||||||
|
|
||||||
|
UPDATE_SESSION_SCHEMA = """
|
||||||
|
ALTER TABLE sessions
|
||||||
|
ADD server_address TEXT,
|
||||||
|
server_address_v6 TEXT,
|
||||||
|
server_port INTEGER,
|
||||||
|
media_address TEXT,
|
||||||
|
media_address_v6 TEXT,
|
||||||
|
media_port INTEGER;
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
class FileStorage(SQLiteStorage):
|
class FileStorage(SQLiteStorage):
|
||||||
FILE_EXTENSION = ".session"
|
FILE_EXTENSION = ".session"
|
||||||
|
|
@ -68,6 +78,12 @@ class FileStorage(SQLiteStorage):
|
||||||
|
|
||||||
version += 1
|
version += 1
|
||||||
|
|
||||||
|
if version == 4:
|
||||||
|
with self.conn:
|
||||||
|
self.conn.execute(UPDATE_SESSION_SCHEMA)
|
||||||
|
|
||||||
|
version += 1
|
||||||
|
|
||||||
self.version(version)
|
self.version(version)
|
||||||
|
|
||||||
async def open(self):
|
async def open(self):
|
||||||
|
|
|
||||||
|
|
@ -19,6 +19,7 @@
|
||||||
|
|
||||||
import base64
|
import base64
|
||||||
import logging
|
import logging
|
||||||
|
import re
|
||||||
import sqlite3
|
import sqlite3
|
||||||
import struct
|
import struct
|
||||||
|
|
||||||
|
|
@ -57,11 +58,19 @@ class MemoryStorage(SQLiteStorage):
|
||||||
log.warning("You are using an old session string format. Use export_session_string to update")
|
log.warning("You are using an old session string format. Use export_session_string to update")
|
||||||
return
|
return
|
||||||
|
|
||||||
dc_id, api_id, test_mode, auth_key, user_id, is_bot = struct.unpack(
|
dc_id, api_id, test_mode, auth_key, user_id, is_bot, server_address, server_address_v6, server_port, media_address, media_address_v6, media_port = struct.unpack(
|
||||||
self.SESSION_STRING_FORMAT,
|
self.SESSION_STRING_FORMAT,
|
||||||
base64.urlsafe_b64decode(self.session_string + "=" * (-len(self.session_string) % 4))
|
base64.urlsafe_b64decode(self.session_string + "=" * (-len(self.session_string) % 4))
|
||||||
)
|
)
|
||||||
|
|
||||||
|
# Remove leading zeros
|
||||||
|
server_address = re.sub(r'^[0]*', '', re.sub(r'\.[0]*', '.', server_address.decode()))
|
||||||
|
server_address_v6 = re.sub(r'^[0]*', '', re.sub(r'\:[0]*', ':', server_address_v6.decode()))
|
||||||
|
server_port = int(server_port.decode())
|
||||||
|
media_address = re.sub(r'^[0]*', '', re.sub(r'\.[0]*', '.', media_address.decode()))
|
||||||
|
media_address_v6 = re.sub(r'^[0]*', '', re.sub(r'\:[0]*', ':', media_address_v6.decode()))
|
||||||
|
media_port = int(media_port.decode())
|
||||||
|
|
||||||
await self.dc_id(dc_id)
|
await self.dc_id(dc_id)
|
||||||
await self.api_id(api_id)
|
await self.api_id(api_id)
|
||||||
await self.test_mode(test_mode)
|
await self.test_mode(test_mode)
|
||||||
|
|
@ -69,6 +78,12 @@ class MemoryStorage(SQLiteStorage):
|
||||||
await self.user_id(user_id)
|
await self.user_id(user_id)
|
||||||
await self.is_bot(is_bot)
|
await self.is_bot(is_bot)
|
||||||
await self.date(0)
|
await self.date(0)
|
||||||
|
await self.server_address(server_address)
|
||||||
|
await self.server_address_v6(server_address_v6)
|
||||||
|
await self.server_port(server_port)
|
||||||
|
await self.media_address(media_address)
|
||||||
|
await self.media_address_v6(media_address_v6)
|
||||||
|
await self.media_port(media_port)
|
||||||
|
|
||||||
async def delete(self):
|
async def delete(self):
|
||||||
pass
|
pass
|
||||||
|
|
|
||||||
|
|
@ -223,6 +223,8 @@ class MongoStorage(Storage):
|
||||||
d = await self._session.find_one({'_id': 0}, {attr: 1})
|
d = await self._session.find_one({'_id': 0}, {attr: 1})
|
||||||
if not d:
|
if not d:
|
||||||
return
|
return
|
||||||
|
if f"{attr}" not in d:
|
||||||
|
return
|
||||||
return d[attr]
|
return d[attr]
|
||||||
|
|
||||||
async def _set(self, value: Any):
|
async def _set(self, value: Any):
|
||||||
|
|
@ -252,3 +254,21 @@ class MongoStorage(Storage):
|
||||||
|
|
||||||
async def is_bot(self, value: bool = object):
|
async def is_bot(self, value: bool = object):
|
||||||
return await self._accessor(value)
|
return await self._accessor(value)
|
||||||
|
|
||||||
|
async def server_address(self, value: str = object):
|
||||||
|
return await self._accessor(value)
|
||||||
|
|
||||||
|
async def server_address_v6(self, value: str = object):
|
||||||
|
return await self._accessor(value)
|
||||||
|
|
||||||
|
async def server_port(self, value: int = object):
|
||||||
|
return await self._accessor(value)
|
||||||
|
|
||||||
|
async def media_address(self, value: str = object):
|
||||||
|
return await self._accessor(value)
|
||||||
|
|
||||||
|
async def media_address_v6(self, value: str = object):
|
||||||
|
return await self._accessor(value)
|
||||||
|
|
||||||
|
async def media_port(self, value: int = object):
|
||||||
|
return await self._accessor(value)
|
||||||
|
|
|
||||||
|
|
@ -36,7 +36,13 @@ CREATE TABLE sessions
|
||||||
auth_key BLOB,
|
auth_key BLOB,
|
||||||
date INTEGER NOT NULL,
|
date INTEGER NOT NULL,
|
||||||
user_id INTEGER,
|
user_id INTEGER,
|
||||||
is_bot INTEGER
|
is_bot INTEGER,
|
||||||
|
server_address TEXT,
|
||||||
|
server_address_v6 TEXT,
|
||||||
|
server_port INTEGER,
|
||||||
|
media_address TEXT,
|
||||||
|
media_address_v6 TEXT,
|
||||||
|
media_port INTEGER
|
||||||
);
|
);
|
||||||
|
|
||||||
CREATE TABLE peers
|
CREATE TABLE peers
|
||||||
|
|
@ -138,8 +144,8 @@ class SQLiteStorage(Storage):
|
||||||
)
|
)
|
||||||
|
|
||||||
self.conn.execute(
|
self.conn.execute(
|
||||||
"INSERT INTO sessions VALUES (?, ?, ?, ?, ?, ?, ?)",
|
"INSERT INTO sessions VALUES (?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?, ?)",
|
||||||
(2, None, None, None, 0, None, None)
|
(2, None, None, None, 0, None, None, None, None, 0, None, None, 0)
|
||||||
)
|
)
|
||||||
|
|
||||||
async def open(self):
|
async def open(self):
|
||||||
|
|
@ -286,6 +292,24 @@ class SQLiteStorage(Storage):
|
||||||
async def is_bot(self, value: bool = object):
|
async def is_bot(self, value: bool = object):
|
||||||
return self._accessor(value)
|
return self._accessor(value)
|
||||||
|
|
||||||
|
async def server_address(self, value: str = object):
|
||||||
|
return self._accessor(value)
|
||||||
|
|
||||||
|
async def server_address_v6(self, value: str = object):
|
||||||
|
return self._accessor(value)
|
||||||
|
|
||||||
|
async def server_port(self, value: int = object):
|
||||||
|
return self._accessor(value)
|
||||||
|
|
||||||
|
async def media_address(self, value: str = object):
|
||||||
|
return self._accessor(value)
|
||||||
|
|
||||||
|
async def media_address_v6(self, value: str = object):
|
||||||
|
return self._accessor(value)
|
||||||
|
|
||||||
|
async def media_port(self, value: int = object):
|
||||||
|
return self._accessor(value)
|
||||||
|
|
||||||
def version(self, value: int = object):
|
def version(self, value: int = object):
|
||||||
if value == object:
|
if value == object:
|
||||||
return self.conn.execute(
|
return self.conn.execute(
|
||||||
|
|
|
||||||
|
|
@ -29,7 +29,7 @@ class Storage:
|
||||||
SESSION_STRING_SIZE = 351
|
SESSION_STRING_SIZE = 351
|
||||||
SESSION_STRING_SIZE_64 = 356
|
SESSION_STRING_SIZE_64 = 356
|
||||||
|
|
||||||
SESSION_STRING_FORMAT = ">BI?256sQ?"
|
SESSION_STRING_FORMAT = ">BI?256sQ?15s39sB15s39sB"
|
||||||
|
|
||||||
def __init__(self, name: str):
|
def __init__(self, name: str):
|
||||||
self.name = name
|
self.name = name
|
||||||
|
|
@ -97,7 +97,38 @@ class Storage:
|
||||||
async def is_bot(self, value: bool = object):
|
async def is_bot(self, value: bool = object):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
async def server_address(self, value: str = object):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
async def server_address_v6(self, value: str = object):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
async def server_port(self, value: int = object):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
async def media_address(self, value: str = object):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
async def media_address_v6(self, value: str = object):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
|
async def media_port(self, value: int = object):
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
async def export_session_string(self):
|
async def export_session_string(self):
|
||||||
|
server_ip = await self.server_address()
|
||||||
|
server_ip_v6 = await self.server_address_v6()
|
||||||
|
port_server = await self.server_port()
|
||||||
|
media_ip = await self.media_address()
|
||||||
|
media_ip_v6 = await self.media_address_v6()
|
||||||
|
port_media = await self.media_port()
|
||||||
|
# Add leading zero to make fixed size
|
||||||
|
server_address = '.'.join(i.zfill(3) for i in server_ip.split('.'))
|
||||||
|
server_address_v6 = ':'.join(i.zfill(4) for i in server_ip_v6.split(':'))
|
||||||
|
server_port = f"{port_server}".zfill(5)
|
||||||
|
media_address = '.'.join(i.zfill(3) for i in media_ip.split('.'))
|
||||||
|
media_address_v6 = ':'.join(i.zfill(4) for i in media_ip_v6.split(':'))
|
||||||
|
media_port = f"{port_media}".zfill(5)
|
||||||
packed = struct.pack(
|
packed = struct.pack(
|
||||||
self.SESSION_STRING_FORMAT,
|
self.SESSION_STRING_FORMAT,
|
||||||
await self.dc_id(),
|
await self.dc_id(),
|
||||||
|
|
@ -105,7 +136,13 @@ class Storage:
|
||||||
await self.test_mode(),
|
await self.test_mode(),
|
||||||
await self.auth_key(),
|
await self.auth_key(),
|
||||||
await self.user_id(),
|
await self.user_id(),
|
||||||
await self.is_bot()
|
await self.is_bot(),
|
||||||
|
bytes(server_address.encode()),
|
||||||
|
bytes(server_address_v6.encode()),
|
||||||
|
bytes(server_port.encode()),
|
||||||
|
bytes(media_address.encode()),
|
||||||
|
bytes(media_address_v6.encode()),
|
||||||
|
bytes(media_port.encode())
|
||||||
)
|
)
|
||||||
|
|
||||||
return base64.urlsafe_b64encode(packed).decode().rstrip("=")
|
return base64.urlsafe_b64encode(packed).decode().rstrip("=")
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue