Implement non-blocking TCP connection (KurimuzonAkuma/pyrogram#71)

Signed-off-by: wulan17 <wulan17@nusantararom.org>
This commit is contained in:
Artem Ukolov 2024-06-17 08:07:52 +02:00 committed by wulan17
parent 3232a3f139
commit bb4ea00d4e
No known key found for this signature in database
GPG key ID: 318CD6CD3A6AC0A5
11 changed files with 180 additions and 102 deletions

View file

@ -33,7 +33,7 @@ from importlib import import_module
from io import StringIO, BytesIO
from mimetypes import MimeTypes
from pathlib import Path
from typing import Union, List, Optional, Callable, AsyncGenerator
from typing import Union, List, Optional, Callable, AsyncGenerator, Type
import pyrogram
from pyrogram import __version__, __license__
@ -59,6 +59,8 @@ else:
from pyrogram.storage import MongoStorage
from pyrogram.types import User, TermsOfService
from pyrogram.utils import ainput
from .connection import Connection
from .connection.transport import TCP, TCPAbridged
from .dispatcher import Dispatcher
from .file_id import FileId, FileType, ThumbnailSource
from .filters import Filter
@ -313,6 +315,9 @@ class Client(Methods):
else:
self.storage = FileStorage(self.name, self.workdir)
self.connection_factory = Connection
self.protocol_factory = TCPAbridged
self.dispatcher = Dispatcher(self)
self.rnd_id = MsgId

View file

@ -19,7 +19,7 @@
import asyncio
import logging
from typing import Optional
from typing import Optional, Type
from .transport import TCP, TCPAbridged
from ..session.internals import DataCenter
@ -30,20 +30,30 @@ log = logging.getLogger(__name__)
class Connection:
MAX_CONNECTION_ATTEMPTS = 3
def __init__(self, dc_id: int, test_mode: bool, ipv6: bool, alt_port: bool, proxy: dict, media: bool = False):
def __init__(
self,
dc_id: int,
test_mode: bool,
ipv6: bool,
alt_port: bool,
proxy: dict,
media: bool = False,
protocol_factory: Type[TCP] = TCPAbridged
) -> None:
self.dc_id = dc_id
self.test_mode = test_mode
self.ipv6 = ipv6
self.alt_port = alt_port
self.proxy = proxy
self.media = media
self.protocol_factory = protocol_factory
self.address = DataCenter(dc_id, test_mode, ipv6, alt_port, media)
self.protocol: TCP = None
self.protocol: Optional[TCP] = None
async def connect(self):
async def connect(self) -> None:
for i in range(Connection.MAX_CONNECTION_ATTEMPTS):
self.protocol = TCPAbridged(self.ipv6, self.proxy)
self.protocol = self.protocol_factory(ipv6=self.ipv6, proxy=self.proxy)
try:
log.info("Connecting...")
@ -63,11 +73,11 @@ class Connection:
log.warning("Connection failed! Trying again...")
raise ConnectionError
async def close(self):
async def close(self) -> None:
await self.protocol.close()
log.info("Disconnected")
async def send(self, data: bytes):
async def send(self, data: bytes) -> None:
await self.protocol.send(data)
async def recv(self) -> Optional[bytes]:

View file

@ -17,7 +17,7 @@
# You should have received a copy of the GNU Lesser General Public License
# along with Pyrofork. If not, see <http://www.gnu.org/licenses/>.
from .tcp import TCP
from .tcp import TCP, Proxy
from .tcp_abridged import TCPAbridged
from .tcp_abridged_o import TCPAbridgedO
from .tcp_full import TCPFull

View file

@ -22,89 +22,134 @@ import ipaddress
import logging
import socket
from concurrent.futures import ThreadPoolExecutor
from typing import Tuple, Dict, TypedDict, Optional
import socks
log = logging.getLogger(__name__)
proxy_type_by_scheme: Dict[str, int] = {
"SOCKS4": socks.SOCKS4,
"SOCKS5": socks.SOCKS5,
"HTTP": socks.HTTP,
}
class Proxy(TypedDict):
scheme: str
hostname: str
port: int
username: Optional[str]
password: Optional[str]
class TCP:
TIMEOUT = 10
def __init__(self, ipv6: bool, proxy: dict):
self.socket = None
def __init__(self, ipv6: bool, proxy: Proxy) -> None:
self.ipv6 = ipv6
self.proxy = proxy
self.reader = None
self.writer = None
self.reader: Optional[asyncio.StreamReader] = None
self.writer: Optional[asyncio.StreamWriter] = None
self.lock = asyncio.Lock()
self.loop = asyncio.get_event_loop()
self.proxy = proxy
async def _connect_via_proxy(
self,
destination: Tuple[str, int]
) -> None:
scheme = self.proxy.get("scheme")
if scheme is None:
raise ValueError("No scheme specified")
if proxy:
hostname = proxy.get("hostname")
proxy_type = proxy_type_by_scheme.get(scheme.upper())
if proxy_type is None:
raise ValueError(f"Unknown proxy type {scheme}")
hostname = self.proxy.get("hostname")
port = self.proxy.get("port")
username = self.proxy.get("username")
password = self.proxy.get("password")
try:
ip_address = ipaddress.ip_address(hostname)
except ValueError:
self.socket = socks.socksocket(socket.AF_INET)
is_proxy_ipv6 = False
else:
if isinstance(ip_address, ipaddress.IPv6Address):
self.socket = socks.socksocket(socket.AF_INET6)
else:
self.socket = socks.socksocket(socket.AF_INET)
is_proxy_ipv6 = isinstance(ip_address, ipaddress.IPv6Address)
self.socket.set_proxy(
proxy_type=getattr(socks, proxy.get("scheme").upper()),
proxy_family = socket.AF_INET6 if is_proxy_ipv6 else socket.AF_INET
sock = socks.socksocket(proxy_family)
sock.set_proxy(
proxy_type=proxy_type,
addr=hostname,
port=proxy.get("port", None),
username=proxy.get("username", None),
password=proxy.get("password", None)
port=port,
username=username,
password=password
)
sock.settimeout(TCP.TIMEOUT)
await self.loop.sock_connect(
sock=sock,
address=destination
)
self.socket.settimeout(TCP.TIMEOUT)
sock.setblocking(False)
log.info("Using proxy %s", hostname)
else:
self.socket = socket.socket(
socket.AF_INET6 if ipv6
else socket.AF_INET
self.reader, self.writer = await asyncio.open_connection(
sock=sock
)
self.socket.setblocking(False)
async def _connect_via_direct(
self,
destination: Tuple[str, int]
) -> None:
host, port = destination
family = socket.AF_INET6 if self.ipv6 else socket.AF_INET
self.reader, self.writer = await asyncio.open_connection(
host=host,
port=port,
family=family
)
async def connect(self, address: tuple):
async def _connect(self, destination: Tuple[str, int]) -> None:
if self.proxy:
with ThreadPoolExecutor(1) as executor:
await self.loop.run_in_executor(executor, self.socket.connect, address)
await self._connect_via_proxy(destination)
else:
await self._connect_via_direct(destination)
async def connect(self, address: Tuple[str, int]) -> None:
try:
await asyncio.wait_for(asyncio.get_event_loop().sock_connect(self.socket, address), TCP.TIMEOUT)
await asyncio.wait_for(self._connect(address), TCP.TIMEOUT)
except asyncio.TimeoutError: # Re-raise as TimeoutError. asyncio.TimeoutError is deprecated in 3.11
raise TimeoutError("Connection timed out")
self.reader, self.writer = await asyncio.open_connection(sock=self.socket)
async def close(self) -> None:
if self.writer is None:
return None
async def close(self):
try:
if self.writer is not None:
self.writer.close()
await asyncio.wait_for(self.writer.wait_closed(), TCP.TIMEOUT)
except Exception as e:
log.info("Close exception: %s %s", type(e).__name__, e)
async def send(self, data: bytes):
async def send(self, data: bytes) -> None:
if self.writer is None:
return None
async with self.lock:
try:
if self.writer is not None:
self.writer.write(data)
await self.writer.drain()
except Exception as e:
log.info("Send exception: %s %s", type(e).__name__, e)
raise OSError(e)
async def recv(self, length: int = 0):
async def recv(self, length: int = 0) -> Optional[bytes]:
data = b""
while len(data) < length:

View file

@ -18,22 +18,22 @@
# along with Pyrofork. If not, see <http://www.gnu.org/licenses/>.
import logging
from typing import Optional
from typing import Optional, Tuple
from .tcp import TCP
from .tcp import TCP, Proxy
log = logging.getLogger(__name__)
class TCPAbridged(TCP):
def __init__(self, ipv6: bool, proxy: dict):
def __init__(self, ipv6: bool, proxy: Proxy) -> None:
super().__init__(ipv6, proxy)
async def connect(self, address: tuple):
async def connect(self, address: Tuple[str, int]) -> None:
await super().connect(address)
await super().send(b"\xef")
async def send(self, data: bytes, *args):
async def send(self, data: bytes, *args) -> None:
length = len(data) // 4
await super().send(

View file

@ -19,11 +19,11 @@
import logging
import os
from typing import Optional
from typing import Optional, Tuple
import pyrogram
from pyrogram.crypto import aes
from .tcp import TCP
from .tcp import TCP, Proxy
log = logging.getLogger(__name__)
@ -31,13 +31,13 @@ log = logging.getLogger(__name__)
class TCPAbridgedO(TCP):
RESERVED = (b"HEAD", b"POST", b"GET ", b"OPTI", b"\xee" * 4)
def __init__(self, ipv6: bool, proxy: dict):
def __init__(self, ipv6: bool, proxy: Proxy) -> None:
super().__init__(ipv6, proxy)
self.encrypt = None
self.decrypt = None
async def connect(self, address: tuple):
async def connect(self, address: Tuple[str, int]) -> None:
await super().connect(address)
while True:
@ -56,7 +56,7 @@ class TCPAbridgedO(TCP):
await super().send(nonce)
async def send(self, data: bytes, *args):
async def send(self, data: bytes, *args) -> None:
length = len(data) // 4
data = (bytes([length]) if length <= 126 else b"\x7f" + length.to_bytes(3, "little")) + data
payload = await self.loop.run_in_executor(pyrogram.crypto_executor, aes.ctr256_encrypt, data, *self.encrypt)

View file

@ -20,24 +20,24 @@
import logging
from binascii import crc32
from struct import pack, unpack
from typing import Optional
from typing import Optional, Tuple
from .tcp import TCP
from .tcp import TCP, Proxy
log = logging.getLogger(__name__)
class TCPFull(TCP):
def __init__(self, ipv6: bool, proxy: dict):
def __init__(self, ipv6: bool, proxy: Proxy) -> None:
super().__init__(ipv6, proxy)
self.seq_no = None
self.seq_no: Optional[int] = None
async def connect(self, address: tuple):
async def connect(self, address: Tuple[str, int]) -> None:
await super().connect(address)
self.seq_no = 0
async def send(self, data: bytes, *args):
async def send(self, data: bytes, *args) -> None:
data = pack("<II", len(data) + 12, self.seq_no) + data
data += pack("<I", crc32(data))
self.seq_no += 1

View file

@ -19,22 +19,22 @@
import logging
from struct import pack, unpack
from typing import Optional
from typing import Optional, Tuple
from .tcp import TCP
from .tcp import TCP, Proxy
log = logging.getLogger(__name__)
class TCPIntermediate(TCP):
def __init__(self, ipv6: bool, proxy: dict):
def __init__(self, ipv6: bool, proxy: Proxy) -> None:
super().__init__(ipv6, proxy)
async def connect(self, address: tuple):
async def connect(self, address: Tuple[str, int]) -> None:
await super().connect(address)
await super().send(b"\xee" * 4)
async def send(self, data: bytes, *args):
async def send(self, data: bytes, *args) -> None:
await super().send(pack("<i", len(data)) + data)
async def recv(self, length: int = 0) -> Optional[bytes]:

View file

@ -20,10 +20,10 @@
import logging
import os
from struct import pack, unpack
from typing import Optional
from typing import Optional, Tuple
from pyrogram.crypto import aes
from .tcp import TCP
from .tcp import TCP, Proxy
log = logging.getLogger(__name__)
@ -31,13 +31,13 @@ log = logging.getLogger(__name__)
class TCPIntermediateO(TCP):
RESERVED = (b"HEAD", b"POST", b"GET ", b"OPTI", b"\xee" * 4)
def __init__(self, ipv6: bool, proxy: dict):
def __init__(self, ipv6: bool, proxy: Proxy) -> None:
super().__init__(ipv6, proxy)
self.encrypt = None
self.decrypt = None
async def connect(self, address: tuple):
async def connect(self, address: Tuple[str, int]) -> None:
await super().connect(address)
while True:
@ -56,7 +56,7 @@ class TCPIntermediateO(TCP):
await super().send(nonce)
async def send(self, data: bytes, *args):
async def send(self, data: bytes, *args) -> None:
await super().send(
aes.ctr256_encrypt(
pack("<i", len(data)) + data,

View file

@ -23,6 +23,7 @@ import time
from hashlib import sha1
from io import BytesIO
from os import urandom
from typing import Optional
import pyrogram
from pyrogram import raw
@ -38,14 +39,21 @@ log = logging.getLogger(__name__)
class Auth:
MAX_RETRIES = 5
def __init__(self, client: "pyrogram.Client", dc_id: int, test_mode: bool):
def __init__(
self,
client: "pyrogram.Client",
dc_id: int,
test_mode: bool
):
self.dc_id = dc_id
self.test_mode = test_mode
self.ipv6 = client.ipv6
self.alt_port = client.alt_port
self.proxy = client.proxy
self.connection_factory = client.connection_factory
self.protocol_factory = client.protocol_factory
self.connection = None
self.connection: Optional[Connection] = None
@staticmethod
def pack(data: TLObject) -> bytes:
@ -78,7 +86,15 @@ class Auth:
# The server may close the connection at any time, causing the auth key creation to fail.
# If that happens, just try again up to MAX_RETRIES times.
while True:
self.connection = Connection(self.dc_id, self.test_mode, self.ipv6, self.alt_port, self.proxy)
self.connection = self.connection_factory(
dc_id=self.dc_id,
test_mode=self.test_mode,
ipv6=self.ipv6,
alt_port=self.alt_port,
proxy=self.proxy,
media=False,
protocol_factory=self.protocol_factory
)
try:
log.info("Start creating a new auth key on DC%s", self.dc_id)

View file

@ -23,6 +23,7 @@ import logging
import os
from hashlib import sha1
from io import BytesIO
from typing import Optional
import pyrogram
from pyrogram import raw
@ -78,7 +79,7 @@ class Session:
self.is_media = is_media
self.is_cdn = is_cdn
self.connection = None
self.connection: Optional[Connection] = None
self.auth_key_id = sha1(auth_key).digest()[-8:]
@ -104,13 +105,14 @@ class Session:
async def start(self):
while True:
self.connection = Connection(
self.dc_id,
self.test_mode,
self.client.ipv6,
self.client.alt_port,
self.client.proxy,
self.is_media
self.connection = self.client.connection_factory(
dc_id=self.dc_id,
test_mode=self.test_mode,
ipv6=self.client.ipv6,
alt_port=self.client.alt_port,
proxy=self.client.proxy,
media=self.is_media,
protocol_factory=self.client.protocol_factory
)
try: