From ebe1f3433f9b5d11f8c4548a5d41d1205c91894c Mon Sep 17 00:00:00 2001 From: rpt0 Date: Wed, 6 Aug 2025 11:22:19 -0400 Subject: [PATCH 1/5] async tcp --- pyrogram/connection/transport/tcp/tcp.py | 437 ++++++++++++++++++++--- 1 file changed, 380 insertions(+), 57 deletions(-) diff --git a/pyrogram/connection/transport/tcp/tcp.py b/pyrogram/connection/transport/tcp/tcp.py index 1848ba35..d1ce181b 100644 --- a/pyrogram/connection/transport/tcp/tcp.py +++ b/pyrogram/connection/transport/tcp/tcp.py @@ -18,15 +18,14 @@ # along with Pyrofork. If not, see . import asyncio +import base64 import ipaddress -import logging import socket -from typing import Tuple, Dict, TypedDict, Optional +import struct +from typing import Tuple, Dict, TypedDict, Optional, Union import socks -log = logging.getLogger(__name__) - proxy_type_by_scheme: Dict[str, int] = { "SOCKS4": socks.SOCKS4, "SOCKS5": socks.SOCKS5, @@ -42,6 +41,271 @@ class Proxy(TypedDict): password: Optional[str] +class ProxyError(Exception): + """Base exception for proxy-related errors.""" + + +class AuthenticationError(ProxyError): + """Authentication failed.""" + + +class ConnectionError(ProxyError): + """Connection failed.""" + + +class SOCKS5Handler: + """Handles SOCKS5 proxy operations.""" + + @staticmethod + async def negotiate_auth( + writer: asyncio.StreamWriter, + reader: asyncio.StreamReader, + username: Optional[str], + password: Optional[str], + timeout: int, + ) -> int: + """Handle SOCKS5 authentication negotiation.""" + auth_methods = b"\x05\x02\x00\x02" if username and password else b"\x05\x01\x00" + + writer.write(auth_methods) + await writer.drain() + + response = await asyncio.wait_for(reader.read(2), timeout) + if len(response) != 2 or response[0] != 0x05: + raise ConnectionError("Invalid SOCKS5 response") + + return response[1] + + @staticmethod + async def authenticate( + writer: asyncio.StreamWriter, + reader: asyncio.StreamReader, + username: str, + password: str, + timeout: int, + ) -> None: + """Perform username/password authentication.""" + username_bytes = username.encode("utf-8") + password_bytes = password.encode("utf-8") + auth_request = ( + bytes([0x01, len(username_bytes)]) + + username_bytes + + bytes([len(password_bytes)]) + + password_bytes + ) + + writer.write(auth_request) + await writer.drain() + + auth_response = await asyncio.wait_for(reader.read(2), timeout) + if len(auth_response) != 2 or auth_response[1] != 0x00: + raise AuthenticationError("SOCKS5 authentication failed") + + @staticmethod + def build_connect_request(host: str, port: int) -> bytes: + """Build SOCKS5 connection request.""" + request = bytearray([0x05, 0x01, 0x00]) + + try: + ip = ipaddress.ip_address(host) + if isinstance(ip, ipaddress.IPv4Address): + request.append(0x01) + request.extend(ip.packed) + else: + request.append(0x04) + request.extend(ip.packed) + except ValueError: + host_bytes = host.encode("utf-8") + request.append(0x03) + request.append(len(host_bytes)) + request.extend(host_bytes) + + request.extend(struct.pack(">H", port)) + return bytes(request) + + @staticmethod + async def read_bound_address(reader: asyncio.StreamReader, addr_type: int) -> None: + """Read bound address from SOCKS5 response.""" + if addr_type == 0x01: + await reader.read(6) + elif addr_type == 0x03: + domain_len = (await reader.read(1))[0] + await reader.read(domain_len + 2) + elif addr_type == 0x04: + await reader.read(18) + + @classmethod + async def handshake( + cls, + writer: asyncio.StreamWriter, + reader: asyncio.StreamReader, + destination: Tuple[str, int], + *, + username: Optional[str] = None, + password: Optional[str] = None, + timeout: int = 10, + ) -> None: + """Perform complete SOCKS5 handshake.""" + host, port = destination + + # Authentication negotiation + selected_method = await cls.negotiate_auth( + writer, reader, username, password, timeout + ) + + # Handle authentication + if selected_method == 0x02: + if not username or not password: + raise ConnectionError("SOCKS5 server requires authentication") + await cls.authenticate(writer, reader, username, password, timeout) + elif selected_method != 0x00: + raise ConnectionError(f"Unsupported SOCKS5 auth method: {selected_method}") + + # Connection request + request = cls.build_connect_request(host, port) + writer.write(request) + await writer.drain() + + # Read connection response + conn_response = await asyncio.wait_for(reader.read(4), timeout) + if ( + len(conn_response) != 4 + or conn_response[0] != 0x05 + or conn_response[1] != 0x00 + ): + raise ConnectionError("SOCKS5 connection failed") + + # Read bound address + await cls.read_bound_address(reader, conn_response[3]) + + +class SOCKS4Handler: + """Handles SOCKS4 proxy operations.""" + + @staticmethod + def build_request(host: str, port: int, username: Optional[str] = None) -> bytes: + """Build SOCKS4 connection request.""" + try: + ip = ipaddress.IPv4Address(host) + ip_bytes = ip.packed + user_id = (username or "pyrogram").encode("utf-8") + request = ( + struct.pack(">BBH", 0x04, 0x01, port) + ip_bytes + user_id + b"\x00" + ) + except ValueError: + # SOCKS4A - use domain name + ip_bytes = b"\x00\x00\x00\x01" + user_id = (username or "pyrogram").encode("utf-8") + request = ( + struct.pack(">BBH", 0x04, 0x01, port) + + ip_bytes + + user_id + + b"\x00" + + host.encode("utf-8") + + b"\x00" + ) + + return request + + @classmethod + async def handshake( + cls, + writer: asyncio.StreamWriter, + reader: asyncio.StreamReader, + destination: Tuple[str, int], + *, + username: Optional[str] = None, + timeout: int = 10, + ) -> None: + """Perform SOCKS4 handshake.""" + host, port = destination + request = cls.build_request(host, port, username) + + writer.write(request) + await writer.drain() + + response = await asyncio.wait_for(reader.read(8), timeout) + if len(response) != 8 or response[0] != 0x00 or response[1] != 0x5A: + raise ConnectionError("SOCKS4 connection failed") + + +class HTTPProxyHandler: + """Handles HTTP proxy operations.""" + + @staticmethod + def build_request( + host: str, + port: int, + username: Optional[str] = None, + password: Optional[str] = None, + ) -> str: + """Build HTTP CONNECT request.""" + request = f"CONNECT {host}:{port} HTTP/1.1\r\nHost: {host}:{port}\r\n" + + if username and password: + credentials = base64.b64encode(f"{username}:{password}".encode()).decode() + request += f"Proxy-Authorization: Basic {credentials}\r\n" + + return request + "\r\n" + + @staticmethod + def sanitize_request(request: str) -> str: + """Sanitize HTTP request to prevent injection attacks.""" + if "\r\n\r\n" not in request: + raise ValueError("Invalid HTTP request format") + + lines = request.split("\r\n") + if not lines[0].startswith("CONNECT ") or "HTTP/1.1" not in lines[0]: + raise ValueError("Invalid CONNECT request") + + return request + + @staticmethod + async def read_response(reader: asyncio.StreamReader, timeout: int) -> list: + """Read HTTP proxy response.""" + response_lines = [] + while True: + line = await asyncio.wait_for(reader.readline(), timeout) + if not line: + raise ConnectionError("HTTP proxy connection closed") + + line = line.decode("utf-8", errors="ignore").strip() + response_lines.append(line) + + if not line: + break + + return response_lines + + @classmethod + async def handshake( + cls, + writer: asyncio.StreamWriter, + reader: asyncio.StreamReader, + destination: Tuple[str, int], + *, + username: Optional[str] = None, + password: Optional[str] = None, + timeout: int = 10, + ) -> None: + """Perform HTTP proxy handshake.""" + host, port = destination + request = cls.build_request(host, port, username, password) + sanitized_request = cls.sanitize_request(request) + + writer.write(sanitized_request.encode()) + await writer.drain() + + response_lines = await cls.read_response(reader, timeout) + + if not response_lines or not response_lines[0].startswith("HTTP/1."): + raise ConnectionError("Invalid HTTP proxy response") + + status_parts = response_lines[0].split(" ", 2) + if len(status_parts) < 2 or status_parts[1] != "200": + raise ConnectionError(f"HTTP proxy connection failed: {response_lines[0]}") + + class TCP: TIMEOUT = 10 @@ -53,90 +317,150 @@ class TCP: self.writer: Optional[asyncio.StreamWriter] = None self.lock = asyncio.Lock() - self.loop = asyncio.get_event_loop() - async def _connect_via_proxy( - self, - destination: Tuple[str, int] - ) -> None: - scheme = self.proxy.get("scheme") - if scheme is None: - raise ValueError("No scheme specified") + def _get_proxy_handler( + self, scheme: str + ) -> Union[SOCKS5Handler, SOCKS4Handler, HTTPProxyHandler]: + """Get appropriate proxy handler based on scheme.""" + handlers = { + "SOCKS5": SOCKS5Handler, + "SOCKS4": SOCKS4Handler, + "HTTP": HTTPProxyHandler, + } - proxy_type = proxy_type_by_scheme.get(scheme.upper()) - if proxy_type is None: + handler_class = handlers.get(scheme.upper()) + if not handler_class: raise ValueError(f"Unknown proxy type {scheme}") + return handler_class + + def _validate_proxy_config(self) -> None: + """Validate proxy configuration.""" + if not self.proxy.get("scheme"): + raise ValueError("No scheme specified") + + if self.proxy["scheme"].upper() not in proxy_type_by_scheme: + raise ValueError(f"Unknown proxy type {self.proxy['scheme']}") + + async def _establish_proxy_connection( + self, hostname: str, port: int, proxy_family: int + ) -> None: + """Establish connection to proxy server.""" + self.reader, self.writer = await asyncio.wait_for( + asyncio.open_connection(host=hostname, port=port, family=proxy_family), + timeout=self.TIMEOUT, + ) + + def _get_proxy_family(self, hostname: str) -> int: + """Determine address family for proxy connection.""" + try: + ip_address = ipaddress.ip_address(hostname) + return ( + socket.AF_INET6 + if isinstance(ip_address, ipaddress.IPv6Address) + else socket.AF_INET + ) + except ValueError: + return socket.AF_INET + + async def _perform_handshake( + self, + scheme: str, + destination: Tuple[str, int], + username: Optional[str], + password: Optional[str], + ) -> None: + """Perform proxy handshake based on scheme.""" + handler = self._get_proxy_handler(scheme) + + if scheme.upper() == "SOCKS5": + await handler.handshake( + self.writer, + self.reader, + destination, + username=username, + password=password, + timeout=self.TIMEOUT, + ) + elif scheme.upper() == "SOCKS4": + await handler.handshake( + self.writer, + self.reader, + destination, + username=username, + timeout=self.TIMEOUT, + ) + elif scheme.upper() == "HTTP": + await handler.handshake( + self.writer, + self.reader, + destination, + username=username, + password=password, + timeout=self.TIMEOUT, + ) + + async def _connect_via_proxy(self, destination: Tuple[str, int]) -> None: + """Connect through proxy server.""" + self._validate_proxy_config() + + scheme = self.proxy["scheme"] hostname = self.proxy.get("hostname") port = self.proxy.get("port") username = self.proxy.get("username") password = self.proxy.get("password") + proxy_family = self._get_proxy_family(hostname) + try: - ip_address = ipaddress.ip_address(hostname) - except ValueError: - is_proxy_ipv6 = False - else: - is_proxy_ipv6 = isinstance(ip_address, ipaddress.IPv6Address) + await self._establish_proxy_connection(hostname, port, proxy_family) + await self._perform_handshake(scheme, destination, username, password) + except (ConnectionError, AuthenticationError, ValueError): + if self.writer: + self.writer.close() + await self.writer.wait_closed() + raise + except Exception as e: + if self.writer: + self.writer.close() + await self.writer.wait_closed() + raise ConnectionError(f"Proxy connection failed: {e}") from e - proxy_family = socket.AF_INET6 if is_proxy_ipv6 else socket.AF_INET - sock = socks.socksocket(proxy_family) - - sock.set_proxy( - proxy_type=proxy_type, - addr=hostname, - port=port, - username=username, - password=password - ) - sock.settimeout(TCP.TIMEOUT) - - await self.loop.sock_connect( - sock=sock, - address=destination - ) - - sock.setblocking(False) - - self.reader, self.writer = await asyncio.open_connection( - sock=sock - ) - - async def _connect_via_direct( - self, - destination: Tuple[str, int] - ) -> None: + async def _connect_via_direct(self, destination: Tuple[str, int]) -> None: + """Connect directly to destination.""" host, port = destination family = socket.AF_INET6 if self.ipv6 else socket.AF_INET self.reader, self.writer = await asyncio.open_connection( - host=host, - port=port, - family=family + host=host, port=port, family=family ) async def _connect(self, destination: Tuple[str, int]) -> None: + """Establish connection (direct or via proxy).""" if self.proxy: await self._connect_via_proxy(destination) else: await self._connect_via_direct(destination) async def connect(self, address: Tuple[str, int]) -> None: + """Connect to the specified address.""" try: - await asyncio.wait_for(self._connect(address), TCP.TIMEOUT) - except asyncio.TimeoutError: # Re-raise as TimeoutError. asyncio.TimeoutError is deprecated in 3.11 + await asyncio.wait_for(self._connect(address), self.TIMEOUT) + except asyncio.TimeoutError: raise TimeoutError("Connection timed out") async def close(self) -> None: + """Close the connection.""" if self.writer is None: return None try: self.writer.close() - await asyncio.wait_for(self.writer.wait_closed(), TCP.TIMEOUT) - except Exception as e: - log.info("Close exception: %s %s", type(e).__name__, e) + await asyncio.wait_for(self.writer.wait_closed(), self.TIMEOUT) + except Exception: + pass async def send(self, data: bytes) -> None: + """Send data through the connection.""" if self.writer is None: return None @@ -145,17 +469,16 @@ class TCP: self.writer.write(data) await self.writer.drain() except Exception as e: - log.info("Send exception: %s %s", type(e).__name__, e) raise OSError(e) async def recv(self, length: int = 0) -> Optional[bytes]: + """Receive data from the connection.""" data = b"" while len(data) < length: try: chunk = await asyncio.wait_for( - self.reader.read(length - len(data)), - TCP.TIMEOUT + self.reader.read(length - len(data)), self.TIMEOUT ) except (OSError, asyncio.TimeoutError): return None From 7c58bb34af1a525923610908224ed3f6b6f99d85 Mon Sep 17 00:00:00 2001 From: rvck Date: Wed, 6 Aug 2025 22:41:37 +0700 Subject: [PATCH 2/5] Update tcp.py --- pyrogram/connection/transport/tcp/tcp.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/pyrogram/connection/transport/tcp/tcp.py b/pyrogram/connection/transport/tcp/tcp.py index d1ce181b..86fa6ae5 100644 --- a/pyrogram/connection/transport/tcp/tcp.py +++ b/pyrogram/connection/transport/tcp/tcp.py @@ -456,7 +456,7 @@ class TCP: try: self.writer.close() await asyncio.wait_for(self.writer.wait_closed(), self.TIMEOUT) - except Exception: + except (OSError, asyncio.TimeoutError): pass async def send(self, data: bytes) -> None: @@ -468,7 +468,7 @@ class TCP: try: self.writer.write(data) await self.writer.drain() - except Exception as e: + except (OSError, asyncio.TimeoutError) as e: raise OSError(e) async def recv(self, length: int = 0) -> Optional[bytes]: From f7e64c6cb46337fdd476000ed39aa716aa15f937 Mon Sep 17 00:00:00 2001 From: rvck Date: Wed, 6 Aug 2025 22:50:56 +0700 Subject: [PATCH 3/5] Update tcp.py --- pyrogram/connection/transport/tcp/tcp.py | 1 - 1 file changed, 1 deletion(-) diff --git a/pyrogram/connection/transport/tcp/tcp.py b/pyrogram/connection/transport/tcp/tcp.py index 86fa6ae5..3b7b0944 100644 --- a/pyrogram/connection/transport/tcp/tcp.py +++ b/pyrogram/connection/transport/tcp/tcp.py @@ -388,7 +388,6 @@ class TCP: self.reader, destination, username=username, - timeout=self.TIMEOUT, ) elif scheme.upper() == "HTTP": await handler.handshake( From f724347ddd2cc63fae6176d1e4ed283f5e16bccb Mon Sep 17 00:00:00 2001 From: rvck Date: Wed, 6 Aug 2025 22:57:42 +0700 Subject: [PATCH 4/5] Update tcp.py --- pyrogram/connection/transport/tcp/tcp.py | 6 ++---- 1 file changed, 2 insertions(+), 4 deletions(-) diff --git a/pyrogram/connection/transport/tcp/tcp.py b/pyrogram/connection/transport/tcp/tcp.py index 3b7b0944..380d6064 100644 --- a/pyrogram/connection/transport/tcp/tcp.py +++ b/pyrogram/connection/transport/tcp/tcp.py @@ -213,9 +213,7 @@ class SOCKS4Handler: writer: asyncio.StreamWriter, reader: asyncio.StreamReader, destination: Tuple[str, int], - *, - username: Optional[str] = None, - timeout: int = 10, + username: Optional[str] = None ) -> None: """Perform SOCKS4 handshake.""" host, port = destination @@ -224,7 +222,7 @@ class SOCKS4Handler: writer.write(request) await writer.drain() - response = await asyncio.wait_for(reader.read(8), timeout) + response = await asyncio.wait_for(reader.read(8), 10) if len(response) != 8 or response[0] != 0x00 or response[1] != 0x5A: raise ConnectionError("SOCKS4 connection failed") From d26e65cfe3a94f2836d2f10da28ddf19a2a33fee Mon Sep 17 00:00:00 2001 From: rvck Date: Wed, 6 Aug 2025 22:59:18 +0700 Subject: [PATCH 5/5] Update tcp.py