From 9fad7749493db8d5087741b1783c4140e47f58bb Mon Sep 17 00:00:00 2001 From: root Date: Wed, 6 Aug 2025 07:44:20 -0400 Subject: [PATCH 1/4] fix tcp proxy --- pyrogram/connection/transport/tcp/tcp.py | 210 ++++++++++++++++++----- 1 file changed, 169 insertions(+), 41 deletions(-) diff --git a/pyrogram/connection/transport/tcp/tcp.py b/pyrogram/connection/transport/tcp/tcp.py index 1848ba35..d8d4decd 100644 --- a/pyrogram/connection/transport/tcp/tcp.py +++ b/pyrogram/connection/transport/tcp/tcp.py @@ -18,15 +18,14 @@ # along with Pyrofork. If not, see . import asyncio +import base64 import ipaddress -import logging import socket +import struct from typing import Tuple, Dict, TypedDict, Optional import socks -log = logging.getLogger(__name__) - proxy_type_by_scheme: Dict[str, int] = { "SOCKS4": socks.SOCKS4, "SOCKS5": socks.SOCKS5, @@ -53,12 +52,139 @@ class TCP: self.writer: Optional[asyncio.StreamWriter] = None self.lock = asyncio.Lock() - self.loop = asyncio.get_event_loop() - async def _connect_via_proxy( - self, - destination: Tuple[str, int] - ) -> None: + async def _socks5_handshake(self, writer: asyncio.StreamWriter, reader: asyncio.StreamReader, + destination: Tuple[str, int], username: Optional[str] = None, + password: Optional[str] = None) -> None: + # Authentication negotiation + if username and password: + auth_methods = b'\x05\x02\x00\x02' + else: + auth_methods = b'\x05\x01\x00' + + writer.write(auth_methods) + await writer.drain() + + response = await asyncio.wait_for(reader.read(2), self.TIMEOUT) + if len(response) != 2 or response[0] != 0x05: + raise ConnectionError("Invalid SOCKS5 response") + + selected_method = response[1] + + # Authentication if required + if selected_method == 0x02: + if not username or not password: + raise ConnectionError("SOCKS5 server requires authentication") + + username_bytes = username.encode('utf-8') + password_bytes = password.encode('utf-8') + auth_request = bytes([0x01, len(username_bytes)]) + username_bytes + bytes([len(password_bytes)]) + password_bytes + + writer.write(auth_request) + await writer.drain() + + auth_response = await asyncio.wait_for(reader.read(2), self.TIMEOUT) + if len(auth_response) != 2 or auth_response[1] != 0x00: + raise ConnectionError("SOCKS5 authentication failed") + + elif selected_method != 0x00: + raise ConnectionError(f"Unsupported SOCKS5 auth method: {selected_method}") + + # Connection request + host, port = destination + request = bytearray([0x05, 0x01, 0x00]) + + try: + ip = ipaddress.ip_address(host) + if isinstance(ip, ipaddress.IPv4Address): + request.append(0x01) + request.extend(ip.packed) + else: + request.append(0x04) + request.extend(ip.packed) + except ValueError: + host_bytes = host.encode('utf-8') + request.append(0x03) + request.append(len(host_bytes)) + request.extend(host_bytes) + + request.extend(struct.pack('>H', port)) + + writer.write(request) + await writer.drain() + + conn_response = await asyncio.wait_for(reader.read(4), self.TIMEOUT) + if len(conn_response) != 4 or conn_response[0] != 0x05 or conn_response[1] != 0x00: + raise ConnectionError("SOCKS5 connection failed") + + # Read bound address + addr_type = conn_response[3] + if addr_type == 0x01: + await reader.read(6) + elif addr_type == 0x03: + domain_len = (await reader.read(1))[0] + await reader.read(domain_len + 2) + elif addr_type == 0x04: + await reader.read(18) + + async def _socks4_handshake(self, writer: asyncio.StreamWriter, reader: asyncio.StreamReader, + destination: Tuple[str, int]) -> None: + host, port = destination + + try: + ip = ipaddress.IPv4Address(host) + ip_bytes = ip.packed + except ValueError: + ip_bytes = b'\x00\x00\x00\x01' + + request = struct.pack('>BBH', 0x04, 0x01, port) + ip_bytes + b'pyrogram\x00' + + if ip_bytes == b'\x00\x00\x00\x01': + request += host.encode('utf-8') + b'\x00' + + writer.write(request) + await writer.drain() + + response = await asyncio.wait_for(reader.read(8), self.TIMEOUT) + if len(response) != 8 or response[0] != 0x00 or response[1] != 0x5A: + raise ConnectionError("SOCKS4 connection failed") + + async def _http_proxy_handshake(self, writer: asyncio.StreamWriter, reader: asyncio.StreamReader, + destination: Tuple[str, int], username: Optional[str] = None, + password: Optional[str] = None) -> None: + host, port = destination + + request = f"CONNECT {host}:{port} HTTP/1.1\r\nHost: {host}:{port}\r\n" + + if username and password: + credentials = base64.b64encode(f"{username}:{password}".encode()).decode() + request += f"Proxy-Authorization: Basic {credentials}\r\n" + + request += "\r\n" + + writer.write(request.encode()) + await writer.drain() + + response_lines = [] + while True: + line = await asyncio.wait_for(reader.readline(), self.TIMEOUT) + if not line: + raise ConnectionError("HTTP proxy connection closed") + + line = line.decode('utf-8', errors='ignore').strip() + response_lines.append(line) + + if not line: + break + + if not response_lines or not response_lines[0].startswith('HTTP/1.'): + raise ConnectionError("Invalid HTTP proxy response") + + status_parts = response_lines[0].split(' ', 2) + if len(status_parts) < 2 or status_parts[1] != '200': + raise ConnectionError(f"HTTP proxy connection failed: {response_lines[0]}") + + async def _connect_via_proxy(self, destination: Tuple[str, int]) -> None: scheme = self.proxy.get("scheme") if scheme is None: raise ValueError("No scheme specified") @@ -80,32 +206,35 @@ class TCP: is_proxy_ipv6 = isinstance(ip_address, ipaddress.IPv6Address) proxy_family = socket.AF_INET6 if is_proxy_ipv6 else socket.AF_INET - sock = socks.socksocket(proxy_family) + + try: + # Connect to proxy server + self.reader, self.writer = await asyncio.wait_for( + asyncio.open_connection( + host=hostname, + port=port, + family=proxy_family + ), + timeout=self.TIMEOUT + ) + + # Perform proxy handshake + if proxy_type == socks.SOCKS5: + await self._socks5_handshake(self.writer, self.reader, destination, username, password) + elif proxy_type == socks.SOCKS4: + await self._socks4_handshake(self.writer, self.reader, destination) + elif proxy_type == socks.HTTP: + await self._http_proxy_handshake(self.writer, self.reader, destination, username, password) + else: + raise ValueError(f"Unsupported proxy type: {scheme}") + + except Exception: + if self.writer: + self.writer.close() + await self.writer.wait_closed() + raise - sock.set_proxy( - proxy_type=proxy_type, - addr=hostname, - port=port, - username=username, - password=password - ) - sock.settimeout(TCP.TIMEOUT) - - await self.loop.sock_connect( - sock=sock, - address=destination - ) - - sock.setblocking(False) - - self.reader, self.writer = await asyncio.open_connection( - sock=sock - ) - - async def _connect_via_direct( - self, - destination: Tuple[str, int] - ) -> None: + async def _connect_via_direct(self, destination: Tuple[str, int]) -> None: host, port = destination family = socket.AF_INET6 if self.ipv6 else socket.AF_INET self.reader, self.writer = await asyncio.open_connection( @@ -122,8 +251,8 @@ class TCP: async def connect(self, address: Tuple[str, int]) -> None: try: - await asyncio.wait_for(self._connect(address), TCP.TIMEOUT) - except asyncio.TimeoutError: # Re-raise as TimeoutError. asyncio.TimeoutError is deprecated in 3.11 + await asyncio.wait_for(self._connect(address), self.TIMEOUT) + except asyncio.TimeoutError: raise TimeoutError("Connection timed out") async def close(self) -> None: @@ -132,9 +261,9 @@ class TCP: try: self.writer.close() - await asyncio.wait_for(self.writer.wait_closed(), TCP.TIMEOUT) - except Exception as e: - log.info("Close exception: %s %s", type(e).__name__, e) + await asyncio.wait_for(self.writer.wait_closed(), self.TIMEOUT) + except Exception: + pass async def send(self, data: bytes) -> None: if self.writer is None: @@ -145,7 +274,6 @@ class TCP: self.writer.write(data) await self.writer.drain() except Exception as e: - log.info("Send exception: %s %s", type(e).__name__, e) raise OSError(e) async def recv(self, length: int = 0) -> Optional[bytes]: @@ -155,7 +283,7 @@ class TCP: try: chunk = await asyncio.wait_for( self.reader.read(length - len(data)), - TCP.TIMEOUT + self.TIMEOUT ) except (OSError, asyncio.TimeoutError): return None @@ -165,4 +293,4 @@ class TCP: else: return None - return data + return data \ No newline at end of file From 19eb456ad21b88a4700afeff3268941013e45077 Mon Sep 17 00:00:00 2001 From: root Date: Wed, 6 Aug 2025 10:31:54 -0400 Subject: [PATCH 2/4] fix --- pyrogram/connection/transport/tcp/tcp.py | 468 ++++++++++++++-------- pyrogram/connection/transport/tcp/test.py | 377 +++++++++++++++++ 2 files changed, 684 insertions(+), 161 deletions(-) create mode 100644 pyrogram/connection/transport/tcp/test.py diff --git a/pyrogram/connection/transport/tcp/tcp.py b/pyrogram/connection/transport/tcp/tcp.py index d8d4decd..ab8884a1 100644 --- a/pyrogram/connection/transport/tcp/tcp.py +++ b/pyrogram/connection/transport/tcp/tcp.py @@ -22,7 +22,7 @@ import base64 import ipaddress import socket import struct -from typing import Tuple, Dict, TypedDict, Optional +from typing import Tuple, Dict, TypedDict, Optional, Union import socks @@ -41,6 +41,254 @@ class Proxy(TypedDict): password: Optional[str] +class ProxyError(Exception): + """Base exception for proxy-related errors.""" + + pass + + +class AuthenticationError(ProxyError): + """Authentication failed.""" + + pass + + +class ConnectionError(ProxyError): + """Connection failed.""" + + pass + + +class SOCKS5Handler: + """Handles SOCKS5 proxy operations.""" + + @staticmethod + async def negotiate_auth( + writer: asyncio.StreamWriter, + reader: asyncio.StreamReader, + username: Optional[str], + password: Optional[str], + timeout: int, + ) -> None: + """Handle SOCKS5 authentication negotiation.""" + auth_methods = b"\x05\x02\x00\x02" if username and password else b"\x05\x01\x00" + + writer.write(auth_methods) + await writer.drain() + + response = await asyncio.wait_for(reader.read(2), timeout) + if len(response) != 2 or response[0] != 0x05: + raise ConnectionError("Invalid SOCKS5 response") + + return response[1] + + @staticmethod + async def authenticate( + writer: asyncio.StreamWriter, + reader: asyncio.StreamReader, + username: str, + password: str, + timeout: int, + ) -> None: + """Perform username/password authentication.""" + username_bytes = username.encode("utf-8") + password_bytes = password.encode("utf-8") + auth_request = ( + bytes([0x01, len(username_bytes)]) + + username_bytes + + bytes([len(password_bytes)]) + + password_bytes + ) + + writer.write(auth_request) + await writer.drain() + + auth_response = await asyncio.wait_for(reader.read(2), timeout) + if len(auth_response) != 2 or auth_response[1] != 0x00: + raise AuthenticationError("SOCKS5 authentication failed") + + @staticmethod + def build_connect_request(host: str, port: int) -> bytes: + """Build SOCKS5 connection request.""" + request = bytearray([0x05, 0x01, 0x00]) + + try: + ip = ipaddress.ip_address(host) + if isinstance(ip, ipaddress.IPv4Address): + request.append(0x01) + request.extend(ip.packed) + else: + request.append(0x04) + request.extend(ip.packed) + except ValueError: + host_bytes = host.encode("utf-8") + request.append(0x03) + request.append(len(host_bytes)) + request.extend(host_bytes) + + request.extend(struct.pack(">H", port)) + return bytes(request) + + @staticmethod + async def read_bound_address(reader: asyncio.StreamReader, addr_type: int) -> None: + """Read bound address from SOCKS5 response.""" + if addr_type == 0x01: + await reader.read(6) + elif addr_type == 0x03: + domain_len = (await reader.read(1))[0] + await reader.read(domain_len + 2) + elif addr_type == 0x04: + await reader.read(18) + + @classmethod + async def handshake( + cls, + writer: asyncio.StreamWriter, + reader: asyncio.StreamReader, + destination: Tuple[str, int], + username: Optional[str] = None, + password: Optional[str] = None, + timeout: int = 10, + ) -> None: + """Perform complete SOCKS5 handshake.""" + host, port = destination + + # Authentication negotiation + selected_method = await cls.negotiate_auth( + writer, reader, username, password, timeout + ) + + # Handle authentication + if selected_method == 0x02: + if not username or not password: + raise ConnectionError("SOCKS5 server requires authentication") + await cls.authenticate(writer, reader, username, password, timeout) + elif selected_method != 0x00: + raise ConnectionError(f"Unsupported SOCKS5 auth method: {selected_method}") + + # Connection request + request = cls.build_connect_request(host, port) + writer.write(request) + await writer.drain() + + # Read connection response + conn_response = await asyncio.wait_for(reader.read(4), timeout) + if ( + len(conn_response) != 4 + or conn_response[0] != 0x05 + or conn_response[1] != 0x00 + ): + raise ConnectionError("SOCKS5 connection failed") + + # Read bound address + await cls.read_bound_address(reader, conn_response[3]) + + +class SOCKS4Handler: + """Handles SOCKS4 proxy operations.""" + + @staticmethod + def build_request(host: str, port: int) -> bytes: + """Build SOCKS4 connection request.""" + try: + ip = ipaddress.IPv4Address(host) + ip_bytes = ip.packed + request = struct.pack(">BBH", 0x04, 0x01, port) + ip_bytes + b"pyrogram\x00" + except ValueError: + ip_bytes = b"\x00\x00\x00\x01" + request = ( + struct.pack(">BBH", 0x04, 0x01, port) + + ip_bytes + + b"pyrogram\x00" + + host.encode("utf-8") + + b"\x00" + ) + + return request + + @classmethod + async def handshake( + cls, + writer: asyncio.StreamWriter, + reader: asyncio.StreamReader, + destination: Tuple[str, int], + timeout: int = 10, + ) -> None: + """Perform SOCKS4 handshake.""" + host, port = destination + request = cls.build_request(host, port) + + writer.write(request) + await writer.drain() + + response = await asyncio.wait_for(reader.read(8), timeout) + if len(response) != 8 or response[0] != 0x00 or response[1] != 0x5A: + raise ConnectionError("SOCKS4 connection failed") + + +class HTTPProxyHandler: + """Handles HTTP proxy operations.""" + + @staticmethod + def build_request( + host: str, + port: int, + username: Optional[str] = None, + password: Optional[str] = None, + ) -> str: + """Build HTTP CONNECT request.""" + request = f"CONNECT {host}:{port} HTTP/1.1\r\nHost: {host}:{port}\r\n" + + if username and password: + credentials = base64.b64encode(f"{username}:{password}".encode()).decode() + request += f"Proxy-Authorization: Basic {credentials}\r\n" + + return request + "\r\n" + + @staticmethod + async def read_response(reader: asyncio.StreamReader, timeout: int) -> list: + """Read HTTP proxy response.""" + response_lines = [] + while True: + line = await asyncio.wait_for(reader.readline(), timeout) + if not line: + raise ConnectionError("HTTP proxy connection closed") + + line = line.decode("utf-8", errors="ignore").strip() + response_lines.append(line) + + if not line: + break + + return response_lines + + @classmethod + async def handshake( + cls, + writer: asyncio.StreamWriter, + reader: asyncio.StreamReader, + destination: Tuple[str, int], + username: Optional[str] = None, + password: Optional[str] = None, + timeout: int = 10, + ) -> None: + """Perform HTTP proxy handshake.""" + host, port = destination + request = cls.build_request(host, port, username, password) + + writer.write(request.encode()) + await writer.drain() + + response_lines = await cls.read_response(reader, timeout) + + if not response_lines or not response_lines[0].startswith("HTTP/1."): + raise ConnectionError("Invalid HTTP proxy response") + + status_parts = response_lines[0].split(" ", 2) + if len(status_parts) < 2 or status_parts[1] != "200": + raise ConnectionError(f"HTTP proxy connection failed: {response_lines[0]}") + + class TCP: TIMEOUT = 10 @@ -53,181 +301,76 @@ class TCP: self.lock = asyncio.Lock() - async def _socks5_handshake(self, writer: asyncio.StreamWriter, reader: asyncio.StreamReader, - destination: Tuple[str, int], username: Optional[str] = None, - password: Optional[str] = None) -> None: - # Authentication negotiation - if username and password: - auth_methods = b'\x05\x02\x00\x02' - else: - auth_methods = b'\x05\x01\x00' - - writer.write(auth_methods) - await writer.drain() - - response = await asyncio.wait_for(reader.read(2), self.TIMEOUT) - if len(response) != 2 or response[0] != 0x05: - raise ConnectionError("Invalid SOCKS5 response") - - selected_method = response[1] - - # Authentication if required - if selected_method == 0x02: - if not username or not password: - raise ConnectionError("SOCKS5 server requires authentication") - - username_bytes = username.encode('utf-8') - password_bytes = password.encode('utf-8') - auth_request = bytes([0x01, len(username_bytes)]) + username_bytes + bytes([len(password_bytes)]) + password_bytes - - writer.write(auth_request) - await writer.drain() - - auth_response = await asyncio.wait_for(reader.read(2), self.TIMEOUT) - if len(auth_response) != 2 or auth_response[1] != 0x00: - raise ConnectionError("SOCKS5 authentication failed") - - elif selected_method != 0x00: - raise ConnectionError(f"Unsupported SOCKS5 auth method: {selected_method}") - - # Connection request - host, port = destination - request = bytearray([0x05, 0x01, 0x00]) - - try: - ip = ipaddress.ip_address(host) - if isinstance(ip, ipaddress.IPv4Address): - request.append(0x01) - request.extend(ip.packed) - else: - request.append(0x04) - request.extend(ip.packed) - except ValueError: - host_bytes = host.encode('utf-8') - request.append(0x03) - request.append(len(host_bytes)) - request.extend(host_bytes) - - request.extend(struct.pack('>H', port)) - - writer.write(request) - await writer.drain() - - conn_response = await asyncio.wait_for(reader.read(4), self.TIMEOUT) - if len(conn_response) != 4 or conn_response[0] != 0x05 or conn_response[1] != 0x00: - raise ConnectionError("SOCKS5 connection failed") - - # Read bound address - addr_type = conn_response[3] - if addr_type == 0x01: - await reader.read(6) - elif addr_type == 0x03: - domain_len = (await reader.read(1))[0] - await reader.read(domain_len + 2) - elif addr_type == 0x04: - await reader.read(18) + def _get_proxy_handler( + self, scheme: str + ) -> Union[SOCKS5Handler, SOCKS4Handler, HTTPProxyHandler]: + """Get appropriate proxy handler based on scheme.""" + handlers = { + "SOCKS5": SOCKS5Handler, + "SOCKS4": SOCKS4Handler, + "HTTP": HTTPProxyHandler, + } - async def _socks4_handshake(self, writer: asyncio.StreamWriter, reader: asyncio.StreamReader, - destination: Tuple[str, int]) -> None: - host, port = destination - - try: - ip = ipaddress.IPv4Address(host) - ip_bytes = ip.packed - except ValueError: - ip_bytes = b'\x00\x00\x00\x01' - - request = struct.pack('>BBH', 0x04, 0x01, port) + ip_bytes + b'pyrogram\x00' - - if ip_bytes == b'\x00\x00\x00\x01': - request += host.encode('utf-8') + b'\x00' - - writer.write(request) - await writer.drain() - - response = await asyncio.wait_for(reader.read(8), self.TIMEOUT) - if len(response) != 8 or response[0] != 0x00 or response[1] != 0x5A: - raise ConnectionError("SOCKS4 connection failed") + handler_class = handlers.get(scheme.upper()) + if not handler_class: + raise ValueError(f"Unknown proxy type {scheme}") - async def _http_proxy_handshake(self, writer: asyncio.StreamWriter, reader: asyncio.StreamReader, - destination: Tuple[str, int], username: Optional[str] = None, - password: Optional[str] = None) -> None: - host, port = destination - - request = f"CONNECT {host}:{port} HTTP/1.1\r\nHost: {host}:{port}\r\n" - - if username and password: - credentials = base64.b64encode(f"{username}:{password}".encode()).decode() - request += f"Proxy-Authorization: Basic {credentials}\r\n" - - request += "\r\n" - - writer.write(request.encode()) - await writer.drain() - - response_lines = [] - while True: - line = await asyncio.wait_for(reader.readline(), self.TIMEOUT) - if not line: - raise ConnectionError("HTTP proxy connection closed") - - line = line.decode('utf-8', errors='ignore').strip() - response_lines.append(line) - - if not line: - break - - if not response_lines or not response_lines[0].startswith('HTTP/1.'): - raise ConnectionError("Invalid HTTP proxy response") - - status_parts = response_lines[0].split(' ', 2) - if len(status_parts) < 2 or status_parts[1] != '200': - raise ConnectionError(f"HTTP proxy connection failed: {response_lines[0]}") + return handler_class async def _connect_via_proxy(self, destination: Tuple[str, int]) -> None: + """Connect through proxy server.""" scheme = self.proxy.get("scheme") - if scheme is None: + if not scheme: raise ValueError("No scheme specified") - proxy_type = proxy_type_by_scheme.get(scheme.upper()) - if proxy_type is None: - raise ValueError(f"Unknown proxy type {scheme}") - hostname = self.proxy.get("hostname") port = self.proxy.get("port") username = self.proxy.get("username") password = self.proxy.get("password") + # Determine proxy address family try: ip_address = ipaddress.ip_address(hostname) + proxy_family = ( + socket.AF_INET6 + if isinstance(ip_address, ipaddress.IPv6Address) + else socket.AF_INET + ) except ValueError: - is_proxy_ipv6 = False - else: - is_proxy_ipv6 = isinstance(ip_address, ipaddress.IPv6Address) + proxy_family = socket.AF_INET - proxy_family = socket.AF_INET6 if is_proxy_ipv6 else socket.AF_INET - try: # Connect to proxy server self.reader, self.writer = await asyncio.wait_for( - asyncio.open_connection( - host=hostname, - port=port, - family=proxy_family - ), - timeout=self.TIMEOUT + asyncio.open_connection(host=hostname, port=port, family=proxy_family), + timeout=self.TIMEOUT, ) - + # Perform proxy handshake - if proxy_type == socks.SOCKS5: - await self._socks5_handshake(self.writer, self.reader, destination, username, password) - elif proxy_type == socks.SOCKS4: - await self._socks4_handshake(self.writer, self.reader, destination) - elif proxy_type == socks.HTTP: - await self._http_proxy_handshake(self.writer, self.reader, destination, username, password) - else: - raise ValueError(f"Unsupported proxy type: {scheme}") - + handler = self._get_proxy_handler(scheme) + if scheme.upper() == "SOCKS5": + await handler.handshake( + self.writer, + self.reader, + destination, + username, + password, + self.TIMEOUT, + ) + elif scheme.upper() == "SOCKS4": + await handler.handshake( + self.writer, self.reader, destination, self.TIMEOUT + ) + elif scheme.upper() == "HTTP": + await handler.handshake( + self.writer, + self.reader, + destination, + username, + password, + self.TIMEOUT, + ) + except Exception: if self.writer: self.writer.close() @@ -235,27 +378,29 @@ class TCP: raise async def _connect_via_direct(self, destination: Tuple[str, int]) -> None: + """Connect directly to destination.""" host, port = destination family = socket.AF_INET6 if self.ipv6 else socket.AF_INET self.reader, self.writer = await asyncio.open_connection( - host=host, - port=port, - family=family + host=host, port=port, family=family ) async def _connect(self, destination: Tuple[str, int]) -> None: + """Establish connection (direct or via proxy).""" if self.proxy: await self._connect_via_proxy(destination) else: await self._connect_via_direct(destination) async def connect(self, address: Tuple[str, int]) -> None: + """Connect to the specified address.""" try: await asyncio.wait_for(self._connect(address), self.TIMEOUT) except asyncio.TimeoutError: raise TimeoutError("Connection timed out") async def close(self) -> None: + """Close the connection.""" if self.writer is None: return None @@ -266,6 +411,7 @@ class TCP: pass async def send(self, data: bytes) -> None: + """Send data through the connection.""" if self.writer is None: return None @@ -277,13 +423,13 @@ class TCP: raise OSError(e) async def recv(self, length: int = 0) -> Optional[bytes]: + """Receive data from the connection.""" data = b"" while len(data) < length: try: chunk = await asyncio.wait_for( - self.reader.read(length - len(data)), - self.TIMEOUT + self.reader.read(length - len(data)), self.TIMEOUT ) except (OSError, asyncio.TimeoutError): return None @@ -293,4 +439,4 @@ class TCP: else: return None - return data \ No newline at end of file + return data diff --git a/pyrogram/connection/transport/tcp/test.py b/pyrogram/connection/transport/tcp/test.py new file mode 100644 index 00000000..d40efeb8 --- /dev/null +++ b/pyrogram/connection/transport/tcp/test.py @@ -0,0 +1,377 @@ +#!/usr/bin/env python3 +""" +TCP Implementation Diagnostic Tool for Pyrofork +Deep analysis to identify specific issues in TCP transport +""" + +import asyncio +import sys +import time +import traceback +import socket +import ssl +import os +from typing import Dict, Any, List, Tuple, Optional +import logging + +# Setup path for local imports +current_dir = os.path.dirname(os.path.abspath(__file__)) +root_dir = os.path.abspath(os.path.join(current_dir, '../../../../../')) +sys.path.insert(0, root_dir) + +# Import TCP implementation +try: + from pyrogram.connection.transport.tcp import TCP + print("Successfully imported TCP from local source") +except ImportError as e: + print(f"Failed to import TCP: {e}") + sys.exit(1) + +# Configure detailed logging +logging.basicConfig( + level=logging.DEBUG, + format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", + datefmt="%H:%M:%S" +) +logger = logging.getLogger(__name__) + +class TCPDiagnostics: + """Diagnostic tools for TCP implementation analysis""" + + def __init__(self): + self.telegram_servers = [ + ("149.154.175.50", 443, "DC1-Miami"), + ("149.154.167.51", 443, "DC2-Amsterdam"), + ("149.154.175.100", 443, "DC3-Miami"), + ("149.154.167.91", 443, "DC4-Amsterdam"), + ("91.108.56.130", 443, "DC5-Singapore"), + ] + + async def test_raw_socket_connection(self, host: str, port: int, name: str) -> Dict[str, Any]: + """Test raw socket connection without Pyrofork TCP wrapper""" + print(f"\n[RAW SOCKET] Testing {name} ({host}:{port})") + + start_time = time.time() + sock = None + + try: + # Create raw socket + sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) + sock.settimeout(10.0) + + # Test connection + sock.connect((host, port)) + connect_time = time.time() - start_time + + # Test basic send/receive + test_data = b'GET / HTTP/1.1\r\nHost: ' + host.encode() + b'\r\n\r\n' + sock.send(test_data) + + # Try to receive response + try: + response = sock.recv(1024) + received_data = len(response) > 0 + except socket.timeout: + received_data = False + + sock.close() + sock = None + + result = { + "success": True, + "connect_time": connect_time, + "received_data": received_data, + "error": None + } + + print(f"[RAW SOCKET] {name}: SUCCESS - {connect_time:.2f}s") + return result + + except Exception as e: + if sock: + sock.close() + + error_time = time.time() - start_time + print(f"[RAW SOCKET] {name}: FAILED - {str(e)} ({error_time:.2f}s)") + + return { + "success": False, + "connect_time": error_time, + "received_data": False, + "error": str(e) + } + + async def test_tcp_implementation(self, host: str, port: int, name: str) -> Dict[str, Any]: + """Test Pyrofork TCP implementation with detailed logging""" + print(f"\n[PYROFORK TCP] Testing {name} ({host}:{port})") + + start_time = time.time() + tcp = None + + try: + # Initialize TCP transport + print(f"[PYROFORK TCP] Initializing TCP transport...") + tcp = TCP(ipv6=False, proxy=None) + + print(f"[PYROFORK TCP] TCP object created: {type(tcp)}") + print(f"[PYROFORK TCP] TCP attributes: {dir(tcp)}") + + # Attempt connection + print(f"[PYROFORK TCP] Attempting connection...") + await asyncio.wait_for(tcp.connect((host, port)), timeout=10.0) + + connect_time = time.time() - start_time + print(f"[PYROFORK TCP] Connection established in {connect_time:.2f}s") + + # Test send operation + print(f"[PYROFORK TCP] Testing send operation...") + test_data = b'\x00\x00\x00\x00\x00\x00\x00\x00' + await tcp.send(test_data) + print(f"[PYROFORK TCP] Send operation successful") + + # Test receive operation + print(f"[PYROFORK TCP] Testing receive operation...") + try: + response = await asyncio.wait_for(tcp.recv(8), timeout=3.0) + received_data = len(response) > 0 if response else False + print(f"[PYROFORK TCP] Received {len(response) if response else 0} bytes") + except asyncio.TimeoutError: + received_data = False + print(f"[PYROFORK TCP] Receive timeout (normal for test data)") + + # Clean close + print(f"[PYROFORK TCP] Closing connection...") + await tcp.close() + tcp = None + print(f"[PYROFORK TCP] Connection closed successfully") + + result = { + "success": True, + "connect_time": connect_time, + "received_data": received_data, + "error": None + } + + print(f"[PYROFORK TCP] {name}: SUCCESS - {connect_time:.2f}s") + return result + + except Exception as e: + error_time = time.time() - start_time + print(f"[PYROFORK TCP] {name}: FAILED - {str(e)}") + print(f"[PYROFORK TCP] Exception details:") + traceback.print_exc() + + # Cleanup + if tcp: + try: + await tcp.close() + except: + pass + + return { + "success": False, + "connect_time": error_time, + "received_data": False, + "error": str(e) + } + + async def inspect_tcp_implementation(self): + """Inspect TCP class implementation details""" + print(f"\n[INSPECTION] Analyzing TCP implementation...") + + try: + tcp = TCP(ipv6=False, proxy=None) + + print(f"[INSPECTION] TCP class: {TCP}") + print(f"[INSPECTION] TCP module: {TCP.__module__}") + print(f"[INSPECTION] TCP file: {TCP.__module__.replace('.', os.sep)}.py") + + # Check methods + methods = [attr for attr in dir(tcp) if not attr.startswith('_')] + print(f"[INSPECTION] Available methods: {methods}") + + # Check for required methods + required_methods = ['connect', 'send', 'recv', 'close'] + for method in required_methods: + if hasattr(tcp, method): + method_obj = getattr(tcp, method) + print(f"[INSPECTION] {method}: {method_obj} (callable: {callable(method_obj)})") + else: + print(f"[INSPECTION] MISSING METHOD: {method}") + + # Check initialization parameters + print(f"[INSPECTION] TCP instance attributes:") + for attr in dir(tcp): + if not attr.startswith('__'): + try: + value = getattr(tcp, attr) + if not callable(value): + print(f"[INSPECTION] {attr}: {value}") + except: + print(f"[INSPECTION] {attr}: ") + + except Exception as e: + print(f"[INSPECTION] Error during inspection: {e}") + traceback.print_exc() + + async def test_concurrent_simple(self, host: str, port: int, count: int = 3) -> Dict[str, Any]: + """Simple concurrent connection test""" + print(f"\n[CONCURRENT] Testing {count} concurrent connections to {host}:{port}") + + async def single_connect(): + try: + tcp = TCP(ipv6=False, proxy=None) + await asyncio.wait_for(tcp.connect((host, port)), timeout=5.0) + await tcp.close() + return True + except Exception as e: + print(f"[CONCURRENT] Connection failed: {str(e)[:50]}") + return False + + start_time = time.time() + tasks = [single_connect() for _ in range(count)] + results = await asyncio.gather(*tasks, return_exceptions=True) + total_time = time.time() - start_time + + successful = sum(1 for r in results if r is True) + success_rate = (successful / count) * 100 + + print(f"[CONCURRENT] Results: {successful}/{count} successful ({success_rate:.1f}%) in {total_time:.2f}s") + + return { + "successful": successful, + "total": count, + "success_rate": success_rate, + "total_time": total_time + } + + async def run_diagnostic_suite(self): + """Run comprehensive diagnostic tests""" + print("="*80) + print("TCP IMPLEMENTATION DIAGNOSTIC SUITE") + print("="*80) + + # Phase 1: Inspect implementation + await self.inspect_tcp_implementation() + + # Phase 2: Raw socket tests (baseline) + print(f"\n" + "="*50) + print("PHASE 1: RAW SOCKET BASELINE TESTS") + print("="*50) + + raw_results = [] + for host, port, name in self.telegram_servers: + result = await self.test_raw_socket_connection(host, port, name) + raw_results.append((name, result)) + await asyncio.sleep(0.5) + + # Phase 3: Pyrofork TCP tests + print(f"\n" + "="*50) + print("PHASE 2: PYROFORK TCP IMPLEMENTATION TESTS") + print("="*50) + + tcp_results = [] + for host, port, name in self.telegram_servers: + result = await self.test_tcp_implementation(host, port, name) + tcp_results.append((name, result)) + await asyncio.sleep(1.0) + + # Phase 4: Concurrent test + print(f"\n" + "="*50) + print("PHASE 3: CONCURRENT CONNECTION TEST") + print("="*50) + + concurrent_result = await self.test_concurrent_simple( + self.telegram_servers[0][0], + self.telegram_servers[0][1] + ) + + # Analysis and recommendations + print(f"\n" + "="*50) + print("DIAGNOSTIC ANALYSIS") + print("="*50) + + raw_success = sum(1 for _, r in raw_results if r["success"]) + tcp_success = sum(1 for _, r in tcp_results if r["success"]) + + print(f"\nResults Summary:") + print(f" Raw Socket Success: {raw_success}/{len(raw_results)} ({raw_success/len(raw_results)*100:.1f}%)") + print(f" Pyrofork TCP Success: {tcp_success}/{len(tcp_results)} ({tcp_success/len(tcp_results)*100:.1f}%)") + print(f" Concurrent Success: {concurrent_result['success_rate']:.1f}%") + + print(f"\nDiagnostic Findings:") + + if raw_success < len(raw_results) * 0.8: + print(" - Network connectivity issues detected") + print(" - Check firewall, DNS resolution, or internet connection") + else: + print(" - Network connectivity is good (raw sockets work)") + + if tcp_success < raw_success: + print(" - Pyrofork TCP implementation has issues") + print(" - TCP wrapper is failing where raw sockets succeed") + print(" - Likely issues: async/await implementation, connection state management") + elif tcp_success == raw_success: + print(" - Pyrofork TCP implementation matches raw socket performance") + print(" - Implementation appears correct") + + if concurrent_result['success_rate'] > tcp_success / len(tcp_results) * 100: + print(" - Concurrent connections work better than sequential") + print(" - Possible timing or state management issue in sequential testing") + + # Specific recommendations + print(f"\nRecommendations:") + + if tcp_success < raw_success: + print(" 1. Review TCP.connect() implementation for async/await correctness") + print(" 2. Check socket state management and cleanup") + print(" 3. Verify timeout handling in async operations") + print(" 4. Test with simpler connection sequence") + + if raw_success < len(raw_results): + print(" 1. Check network connectivity to Telegram servers") + print(" 2. Verify DNS resolution") + print(" 3. Check for firewall blocking connections") + + return { + "raw_results": raw_results, + "tcp_results": tcp_results, + "concurrent_result": concurrent_result + } + + +async def main(): + """Main diagnostic execution""" + try: + diagnostics = TCPDiagnostics() + results = await diagnostics.run_diagnostic_suite() + + # Determine if ready for debugging or needs environment fixes + tcp_success_rate = sum(1 for _, r in results["tcp_results"] if r["success"]) / len(results["tcp_results"]) + raw_success_rate = sum(1 for _, r in results["raw_results"] if r["success"]) / len(results["raw_results"]) + + if raw_success_rate < 0.8: + print(f"\nStatus: Environment issues detected - fix network connectivity first") + sys.exit(2) + elif tcp_success_rate < 0.5: + print(f"\nStatus: TCP implementation issues detected - needs code review") + sys.exit(1) + else: + print(f"\nStatus: Implementation appears functional - investigate edge cases") + sys.exit(0) + + except KeyboardInterrupt: + print("\nDiagnostic interrupted by user") + sys.exit(130) + except Exception as e: + print(f"\nUnexpected error in diagnostic: {str(e)}") + traceback.print_exc() + sys.exit(1) + + +if __name__ == "__main__": + print("TCP Implementation Diagnostic Tool") + print("Deep analysis of Pyrofork TCP transport issues") + print("-" * 60) + + asyncio.run(main()) \ No newline at end of file From 996ac69fa0f049785a5e16bf2850205eb64225e5 Mon Sep 17 00:00:00 2001 From: root Date: Wed, 6 Aug 2025 10:32:42 -0400 Subject: [PATCH 3/4] fix --- pyrogram/connection/transport/tcp/test.py | 377 ---------------------- 1 file changed, 377 deletions(-) delete mode 100644 pyrogram/connection/transport/tcp/test.py diff --git a/pyrogram/connection/transport/tcp/test.py b/pyrogram/connection/transport/tcp/test.py deleted file mode 100644 index d40efeb8..00000000 --- a/pyrogram/connection/transport/tcp/test.py +++ /dev/null @@ -1,377 +0,0 @@ -#!/usr/bin/env python3 -""" -TCP Implementation Diagnostic Tool for Pyrofork -Deep analysis to identify specific issues in TCP transport -""" - -import asyncio -import sys -import time -import traceback -import socket -import ssl -import os -from typing import Dict, Any, List, Tuple, Optional -import logging - -# Setup path for local imports -current_dir = os.path.dirname(os.path.abspath(__file__)) -root_dir = os.path.abspath(os.path.join(current_dir, '../../../../../')) -sys.path.insert(0, root_dir) - -# Import TCP implementation -try: - from pyrogram.connection.transport.tcp import TCP - print("Successfully imported TCP from local source") -except ImportError as e: - print(f"Failed to import TCP: {e}") - sys.exit(1) - -# Configure detailed logging -logging.basicConfig( - level=logging.DEBUG, - format="%(asctime)s - %(name)s - %(levelname)s - %(message)s", - datefmt="%H:%M:%S" -) -logger = logging.getLogger(__name__) - -class TCPDiagnostics: - """Diagnostic tools for TCP implementation analysis""" - - def __init__(self): - self.telegram_servers = [ - ("149.154.175.50", 443, "DC1-Miami"), - ("149.154.167.51", 443, "DC2-Amsterdam"), - ("149.154.175.100", 443, "DC3-Miami"), - ("149.154.167.91", 443, "DC4-Amsterdam"), - ("91.108.56.130", 443, "DC5-Singapore"), - ] - - async def test_raw_socket_connection(self, host: str, port: int, name: str) -> Dict[str, Any]: - """Test raw socket connection without Pyrofork TCP wrapper""" - print(f"\n[RAW SOCKET] Testing {name} ({host}:{port})") - - start_time = time.time() - sock = None - - try: - # Create raw socket - sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM) - sock.settimeout(10.0) - - # Test connection - sock.connect((host, port)) - connect_time = time.time() - start_time - - # Test basic send/receive - test_data = b'GET / HTTP/1.1\r\nHost: ' + host.encode() + b'\r\n\r\n' - sock.send(test_data) - - # Try to receive response - try: - response = sock.recv(1024) - received_data = len(response) > 0 - except socket.timeout: - received_data = False - - sock.close() - sock = None - - result = { - "success": True, - "connect_time": connect_time, - "received_data": received_data, - "error": None - } - - print(f"[RAW SOCKET] {name}: SUCCESS - {connect_time:.2f}s") - return result - - except Exception as e: - if sock: - sock.close() - - error_time = time.time() - start_time - print(f"[RAW SOCKET] {name}: FAILED - {str(e)} ({error_time:.2f}s)") - - return { - "success": False, - "connect_time": error_time, - "received_data": False, - "error": str(e) - } - - async def test_tcp_implementation(self, host: str, port: int, name: str) -> Dict[str, Any]: - """Test Pyrofork TCP implementation with detailed logging""" - print(f"\n[PYROFORK TCP] Testing {name} ({host}:{port})") - - start_time = time.time() - tcp = None - - try: - # Initialize TCP transport - print(f"[PYROFORK TCP] Initializing TCP transport...") - tcp = TCP(ipv6=False, proxy=None) - - print(f"[PYROFORK TCP] TCP object created: {type(tcp)}") - print(f"[PYROFORK TCP] TCP attributes: {dir(tcp)}") - - # Attempt connection - print(f"[PYROFORK TCP] Attempting connection...") - await asyncio.wait_for(tcp.connect((host, port)), timeout=10.0) - - connect_time = time.time() - start_time - print(f"[PYROFORK TCP] Connection established in {connect_time:.2f}s") - - # Test send operation - print(f"[PYROFORK TCP] Testing send operation...") - test_data = b'\x00\x00\x00\x00\x00\x00\x00\x00' - await tcp.send(test_data) - print(f"[PYROFORK TCP] Send operation successful") - - # Test receive operation - print(f"[PYROFORK TCP] Testing receive operation...") - try: - response = await asyncio.wait_for(tcp.recv(8), timeout=3.0) - received_data = len(response) > 0 if response else False - print(f"[PYROFORK TCP] Received {len(response) if response else 0} bytes") - except asyncio.TimeoutError: - received_data = False - print(f"[PYROFORK TCP] Receive timeout (normal for test data)") - - # Clean close - print(f"[PYROFORK TCP] Closing connection...") - await tcp.close() - tcp = None - print(f"[PYROFORK TCP] Connection closed successfully") - - result = { - "success": True, - "connect_time": connect_time, - "received_data": received_data, - "error": None - } - - print(f"[PYROFORK TCP] {name}: SUCCESS - {connect_time:.2f}s") - return result - - except Exception as e: - error_time = time.time() - start_time - print(f"[PYROFORK TCP] {name}: FAILED - {str(e)}") - print(f"[PYROFORK TCP] Exception details:") - traceback.print_exc() - - # Cleanup - if tcp: - try: - await tcp.close() - except: - pass - - return { - "success": False, - "connect_time": error_time, - "received_data": False, - "error": str(e) - } - - async def inspect_tcp_implementation(self): - """Inspect TCP class implementation details""" - print(f"\n[INSPECTION] Analyzing TCP implementation...") - - try: - tcp = TCP(ipv6=False, proxy=None) - - print(f"[INSPECTION] TCP class: {TCP}") - print(f"[INSPECTION] TCP module: {TCP.__module__}") - print(f"[INSPECTION] TCP file: {TCP.__module__.replace('.', os.sep)}.py") - - # Check methods - methods = [attr for attr in dir(tcp) if not attr.startswith('_')] - print(f"[INSPECTION] Available methods: {methods}") - - # Check for required methods - required_methods = ['connect', 'send', 'recv', 'close'] - for method in required_methods: - if hasattr(tcp, method): - method_obj = getattr(tcp, method) - print(f"[INSPECTION] {method}: {method_obj} (callable: {callable(method_obj)})") - else: - print(f"[INSPECTION] MISSING METHOD: {method}") - - # Check initialization parameters - print(f"[INSPECTION] TCP instance attributes:") - for attr in dir(tcp): - if not attr.startswith('__'): - try: - value = getattr(tcp, attr) - if not callable(value): - print(f"[INSPECTION] {attr}: {value}") - except: - print(f"[INSPECTION] {attr}: ") - - except Exception as e: - print(f"[INSPECTION] Error during inspection: {e}") - traceback.print_exc() - - async def test_concurrent_simple(self, host: str, port: int, count: int = 3) -> Dict[str, Any]: - """Simple concurrent connection test""" - print(f"\n[CONCURRENT] Testing {count} concurrent connections to {host}:{port}") - - async def single_connect(): - try: - tcp = TCP(ipv6=False, proxy=None) - await asyncio.wait_for(tcp.connect((host, port)), timeout=5.0) - await tcp.close() - return True - except Exception as e: - print(f"[CONCURRENT] Connection failed: {str(e)[:50]}") - return False - - start_time = time.time() - tasks = [single_connect() for _ in range(count)] - results = await asyncio.gather(*tasks, return_exceptions=True) - total_time = time.time() - start_time - - successful = sum(1 for r in results if r is True) - success_rate = (successful / count) * 100 - - print(f"[CONCURRENT] Results: {successful}/{count} successful ({success_rate:.1f}%) in {total_time:.2f}s") - - return { - "successful": successful, - "total": count, - "success_rate": success_rate, - "total_time": total_time - } - - async def run_diagnostic_suite(self): - """Run comprehensive diagnostic tests""" - print("="*80) - print("TCP IMPLEMENTATION DIAGNOSTIC SUITE") - print("="*80) - - # Phase 1: Inspect implementation - await self.inspect_tcp_implementation() - - # Phase 2: Raw socket tests (baseline) - print(f"\n" + "="*50) - print("PHASE 1: RAW SOCKET BASELINE TESTS") - print("="*50) - - raw_results = [] - for host, port, name in self.telegram_servers: - result = await self.test_raw_socket_connection(host, port, name) - raw_results.append((name, result)) - await asyncio.sleep(0.5) - - # Phase 3: Pyrofork TCP tests - print(f"\n" + "="*50) - print("PHASE 2: PYROFORK TCP IMPLEMENTATION TESTS") - print("="*50) - - tcp_results = [] - for host, port, name in self.telegram_servers: - result = await self.test_tcp_implementation(host, port, name) - tcp_results.append((name, result)) - await asyncio.sleep(1.0) - - # Phase 4: Concurrent test - print(f"\n" + "="*50) - print("PHASE 3: CONCURRENT CONNECTION TEST") - print("="*50) - - concurrent_result = await self.test_concurrent_simple( - self.telegram_servers[0][0], - self.telegram_servers[0][1] - ) - - # Analysis and recommendations - print(f"\n" + "="*50) - print("DIAGNOSTIC ANALYSIS") - print("="*50) - - raw_success = sum(1 for _, r in raw_results if r["success"]) - tcp_success = sum(1 for _, r in tcp_results if r["success"]) - - print(f"\nResults Summary:") - print(f" Raw Socket Success: {raw_success}/{len(raw_results)} ({raw_success/len(raw_results)*100:.1f}%)") - print(f" Pyrofork TCP Success: {tcp_success}/{len(tcp_results)} ({tcp_success/len(tcp_results)*100:.1f}%)") - print(f" Concurrent Success: {concurrent_result['success_rate']:.1f}%") - - print(f"\nDiagnostic Findings:") - - if raw_success < len(raw_results) * 0.8: - print(" - Network connectivity issues detected") - print(" - Check firewall, DNS resolution, or internet connection") - else: - print(" - Network connectivity is good (raw sockets work)") - - if tcp_success < raw_success: - print(" - Pyrofork TCP implementation has issues") - print(" - TCP wrapper is failing where raw sockets succeed") - print(" - Likely issues: async/await implementation, connection state management") - elif tcp_success == raw_success: - print(" - Pyrofork TCP implementation matches raw socket performance") - print(" - Implementation appears correct") - - if concurrent_result['success_rate'] > tcp_success / len(tcp_results) * 100: - print(" - Concurrent connections work better than sequential") - print(" - Possible timing or state management issue in sequential testing") - - # Specific recommendations - print(f"\nRecommendations:") - - if tcp_success < raw_success: - print(" 1. Review TCP.connect() implementation for async/await correctness") - print(" 2. Check socket state management and cleanup") - print(" 3. Verify timeout handling in async operations") - print(" 4. Test with simpler connection sequence") - - if raw_success < len(raw_results): - print(" 1. Check network connectivity to Telegram servers") - print(" 2. Verify DNS resolution") - print(" 3. Check for firewall blocking connections") - - return { - "raw_results": raw_results, - "tcp_results": tcp_results, - "concurrent_result": concurrent_result - } - - -async def main(): - """Main diagnostic execution""" - try: - diagnostics = TCPDiagnostics() - results = await diagnostics.run_diagnostic_suite() - - # Determine if ready for debugging or needs environment fixes - tcp_success_rate = sum(1 for _, r in results["tcp_results"] if r["success"]) / len(results["tcp_results"]) - raw_success_rate = sum(1 for _, r in results["raw_results"] if r["success"]) / len(results["raw_results"]) - - if raw_success_rate < 0.8: - print(f"\nStatus: Environment issues detected - fix network connectivity first") - sys.exit(2) - elif tcp_success_rate < 0.5: - print(f"\nStatus: TCP implementation issues detected - needs code review") - sys.exit(1) - else: - print(f"\nStatus: Implementation appears functional - investigate edge cases") - sys.exit(0) - - except KeyboardInterrupt: - print("\nDiagnostic interrupted by user") - sys.exit(130) - except Exception as e: - print(f"\nUnexpected error in diagnostic: {str(e)}") - traceback.print_exc() - sys.exit(1) - - -if __name__ == "__main__": - print("TCP Implementation Diagnostic Tool") - print("Deep analysis of Pyrofork TCP transport issues") - print("-" * 60) - - asyncio.run(main()) \ No newline at end of file From 88fca87fa5f1e4012a0874f903cb132c5e6023b0 Mon Sep 17 00:00:00 2001 From: rpt0 Date: Wed, 6 Aug 2025 10:42:49 -0400 Subject: [PATCH 4/4] final --- pyrogram/connection/transport/tcp/tcp.py | 310 ++++++++++------------- 1 file changed, 136 insertions(+), 174 deletions(-) diff --git a/pyrogram/connection/transport/tcp/tcp.py b/pyrogram/connection/transport/tcp/tcp.py index ab8884a1..b10250aa 100644 --- a/pyrogram/connection/transport/tcp/tcp.py +++ b/pyrogram/connection/transport/tcp/tcp.py @@ -44,74 +44,54 @@ class Proxy(TypedDict): class ProxyError(Exception): """Base exception for proxy-related errors.""" - pass - class AuthenticationError(ProxyError): """Authentication failed.""" - pass - class ConnectionError(ProxyError): """Connection failed.""" - pass - class SOCKS5Handler: """Handles SOCKS5 proxy operations.""" - + @staticmethod - async def negotiate_auth( - writer: asyncio.StreamWriter, - reader: asyncio.StreamReader, - username: Optional[str], - password: Optional[str], - timeout: int, - ) -> None: + async def negotiate_auth(writer: asyncio.StreamWriter, reader: asyncio.StreamReader, + username: Optional[str], password: Optional[str], timeout: int) -> None: """Handle SOCKS5 authentication negotiation.""" - auth_methods = b"\x05\x02\x00\x02" if username and password else b"\x05\x01\x00" - + auth_methods = b'\x05\x02\x00\x02' if username and password else b'\x05\x01\x00' + writer.write(auth_methods) await writer.drain() - + response = await asyncio.wait_for(reader.read(2), timeout) if len(response) != 2 or response[0] != 0x05: raise ConnectionError("Invalid SOCKS5 response") - + return response[1] - + @staticmethod - async def authenticate( - writer: asyncio.StreamWriter, - reader: asyncio.StreamReader, - username: str, - password: str, - timeout: int, - ) -> None: + async def authenticate(writer: asyncio.StreamWriter, reader: asyncio.StreamReader, + username: str, password: str, timeout: int) -> None: """Perform username/password authentication.""" - username_bytes = username.encode("utf-8") - password_bytes = password.encode("utf-8") - auth_request = ( - bytes([0x01, len(username_bytes)]) - + username_bytes - + bytes([len(password_bytes)]) - + password_bytes - ) - + username_bytes = username.encode('utf-8') + password_bytes = password.encode('utf-8') + auth_request = (bytes([0x01, len(username_bytes)]) + username_bytes + + bytes([len(password_bytes)]) + password_bytes) + writer.write(auth_request) await writer.drain() - + auth_response = await asyncio.wait_for(reader.read(2), timeout) if len(auth_response) != 2 or auth_response[1] != 0x00: raise AuthenticationError("SOCKS5 authentication failed") - + @staticmethod def build_connect_request(host: str, port: int) -> bytes: """Build SOCKS5 connection request.""" request = bytearray([0x05, 0x01, 0x00]) - + try: ip = ipaddress.ip_address(host) if isinstance(ip, ipaddress.IPv4Address): @@ -121,14 +101,14 @@ class SOCKS5Handler: request.append(0x04) request.extend(ip.packed) except ValueError: - host_bytes = host.encode("utf-8") + host_bytes = host.encode('utf-8') request.append(0x03) request.append(len(host_bytes)) request.extend(host_bytes) - - request.extend(struct.pack(">H", port)) + + request.extend(struct.pack('>H', port)) return bytes(request) - + @staticmethod async def read_bound_address(reader: asyncio.StreamReader, addr_type: int) -> None: """Read bound address from SOCKS5 response.""" @@ -139,25 +119,17 @@ class SOCKS5Handler: await reader.read(domain_len + 2) elif addr_type == 0x04: await reader.read(18) - + @classmethod - async def handshake( - cls, - writer: asyncio.StreamWriter, - reader: asyncio.StreamReader, - destination: Tuple[str, int], - username: Optional[str] = None, - password: Optional[str] = None, - timeout: int = 10, - ) -> None: + async def handshake(cls, writer: asyncio.StreamWriter, reader: asyncio.StreamReader, + destination: Tuple[str, int], *, username: Optional[str] = None, + password: Optional[str] = None, timeout: int = 10) -> None: """Perform complete SOCKS5 handshake.""" host, port = destination - + # Authentication negotiation - selected_method = await cls.negotiate_auth( - writer, reader, username, password, timeout - ) - + selected_method = await cls.negotiate_auth(writer, reader, username, password, timeout) + # Handle authentication if selected_method == 0x02: if not username or not password: @@ -165,62 +137,50 @@ class SOCKS5Handler: await cls.authenticate(writer, reader, username, password, timeout) elif selected_method != 0x00: raise ConnectionError(f"Unsupported SOCKS5 auth method: {selected_method}") - + # Connection request request = cls.build_connect_request(host, port) writer.write(request) await writer.drain() - + # Read connection response conn_response = await asyncio.wait_for(reader.read(4), timeout) - if ( - len(conn_response) != 4 - or conn_response[0] != 0x05 - or conn_response[1] != 0x00 - ): + if len(conn_response) != 4 or conn_response[0] != 0x05 or conn_response[1] != 0x00: raise ConnectionError("SOCKS5 connection failed") - + # Read bound address await cls.read_bound_address(reader, conn_response[3]) class SOCKS4Handler: """Handles SOCKS4 proxy operations.""" - + @staticmethod def build_request(host: str, port: int) -> bytes: """Build SOCKS4 connection request.""" try: ip = ipaddress.IPv4Address(host) ip_bytes = ip.packed - request = struct.pack(">BBH", 0x04, 0x01, port) + ip_bytes + b"pyrogram\x00" + request = struct.pack('>BBH', 0x04, 0x01, port) + ip_bytes + b'pyrogram\x00' except ValueError: - ip_bytes = b"\x00\x00\x00\x01" - request = ( - struct.pack(">BBH", 0x04, 0x01, port) - + ip_bytes - + b"pyrogram\x00" - + host.encode("utf-8") - + b"\x00" - ) - - return request - + ip_bytes = b'\x00\x00\x00\x01' + request = (struct.pack('>BBH', 0x04, 0x01, port) + ip_bytes + + b'pyrogram\x00' + host.encode('utf-8') + b'\x00') + + @staticmethod + def build_request(host: str, port: int, username: Optional[str] = None, + password: Optional[str] = None) -> str: + @classmethod - async def handshake( - cls, - writer: asyncio.StreamWriter, - reader: asyncio.StreamReader, - destination: Tuple[str, int], - timeout: int = 10, - ) -> None: + async def handshake(cls, writer: asyncio.StreamWriter, reader: asyncio.StreamReader, + destination: Tuple[str, int], timeout: int = 10) -> None: """Perform SOCKS4 handshake.""" host, port = destination request = cls.build_request(host, port) - + writer.write(request) await writer.drain() - + response = await asyncio.wait_for(reader.read(8), timeout) if len(response) != 8 or response[0] != 0x00 or response[1] != 0x5A: raise ConnectionError("SOCKS4 connection failed") @@ -228,23 +188,27 @@ class SOCKS4Handler: class HTTPProxyHandler: """Handles HTTP proxy operations.""" - + @staticmethod - def build_request( - host: str, - port: int, - username: Optional[str] = None, - password: Optional[str] = None, - ) -> str: + def sanitize_request(request: str) -> str: + """Sanitize HTTP request to prevent injection attacks.""" + if '\r\n\r\n' not in request: + raise ValueError("Invalid HTTP request format") + + lines = request.split('\r\n') + if not lines[0].startswith('CONNECT ') or 'HTTP/1.1' not in lines[0]: + raise ValueError("Invalid CONNECT request") + + return request """Build HTTP CONNECT request.""" request = f"CONNECT {host}:{port} HTTP/1.1\r\nHost: {host}:{port}\r\n" - + if username and password: credentials = base64.b64encode(f"{username}:{password}".encode()).decode() request += f"Proxy-Authorization: Basic {credentials}\r\n" - + return request + "\r\n" - + @staticmethod async def read_response(reader: asyncio.StreamReader, timeout: int) -> list: """Read HTTP proxy response.""" @@ -253,39 +217,34 @@ class HTTPProxyHandler: line = await asyncio.wait_for(reader.readline(), timeout) if not line: raise ConnectionError("HTTP proxy connection closed") - - line = line.decode("utf-8", errors="ignore").strip() + + line = line.decode('utf-8', errors='ignore').strip() response_lines.append(line) - + if not line: break - + return response_lines - + @classmethod - async def handshake( - cls, - writer: asyncio.StreamWriter, - reader: asyncio.StreamReader, - destination: Tuple[str, int], - username: Optional[str] = None, - password: Optional[str] = None, - timeout: int = 10, - ) -> None: + async def handshake(cls, writer: asyncio.StreamWriter, reader: asyncio.StreamReader, + destination: Tuple[str, int], *, username: Optional[str] = None, + password: Optional[str] = None, timeout: int = 10) -> None: """Perform HTTP proxy handshake.""" host, port = destination request = cls.build_request(host, port, username, password) - - writer.write(request.encode()) + sanitized_request = cls.sanitize_request(request) + + writer.write(sanitized_request.encode()) await writer.drain() - + response_lines = await cls.read_response(reader, timeout) - - if not response_lines or not response_lines[0].startswith("HTTP/1."): + + if not response_lines or not response_lines[0].startswith('HTTP/1.'): raise ConnectionError("Invalid HTTP proxy response") - - status_parts = response_lines[0].split(" ", 2) - if len(status_parts) < 2 or status_parts[1] != "200": + + status_parts = response_lines[0].split(' ', 2) + if len(status_parts) < 2 or status_parts[1] != '200': raise ConnectionError(f"HTTP proxy connection failed: {response_lines[0]}") @@ -301,81 +260,83 @@ class TCP: self.lock = asyncio.Lock() - def _get_proxy_handler( - self, scheme: str - ) -> Union[SOCKS5Handler, SOCKS4Handler, HTTPProxyHandler]: + def _get_proxy_handler(self, scheme: str) -> Union[SOCKS5Handler, SOCKS4Handler, HTTPProxyHandler]: """Get appropriate proxy handler based on scheme.""" handlers = { "SOCKS5": SOCKS5Handler, "SOCKS4": SOCKS4Handler, - "HTTP": HTTPProxyHandler, + "HTTP": HTTPProxyHandler } - + handler_class = handlers.get(scheme.upper()) if not handler_class: raise ValueError(f"Unknown proxy type {scheme}") - + return handler_class + def _validate_proxy_config(self) -> None: + """Validate proxy configuration.""" + if not self.proxy.get("scheme"): + raise ValueError("No scheme specified") + + if self.proxy["scheme"].upper() not in proxy_type_by_scheme: + raise ValueError(f"Unknown proxy type {self.proxy['scheme']}") + + async def _establish_proxy_connection(self, hostname: str, port: int, proxy_family: int) -> None: + """Establish connection to proxy server.""" + self.reader, self.writer = await asyncio.wait_for( + asyncio.open_connection(host=hostname, port=port, family=proxy_family), + timeout=self.TIMEOUT + ) + + def _get_proxy_family(self, hostname: str) -> int: + """Determine address family for proxy connection.""" + try: + ip_address = ipaddress.ip_address(hostname) + return (socket.AF_INET6 if isinstance(ip_address, ipaddress.IPv6Address) + else socket.AF_INET) + except ValueError: + return socket.AF_INET + + async def _perform_handshake(self, scheme: str, destination: Tuple[str, int], + username: Optional[str], password: Optional[str]) -> None: + """Perform proxy handshake based on scheme.""" + handler = self._get_proxy_handler(scheme) + + if scheme.upper() == "SOCKS5": + await handler.handshake(self.writer, self.reader, destination, + username=username, password=password, timeout=self.TIMEOUT) + elif scheme.upper() == "SOCKS4": + await handler.handshake(self.writer, self.reader, destination, self.TIMEOUT) + elif scheme.upper() == "HTTP": + await handler.handshake(self.writer, self.reader, destination, + username=username, password=password, timeout=self.TIMEOUT) + async def _connect_via_proxy(self, destination: Tuple[str, int]) -> None: """Connect through proxy server.""" - scheme = self.proxy.get("scheme") - if not scheme: - raise ValueError("No scheme specified") - + self._validate_proxy_config() + + scheme = self.proxy["scheme"] hostname = self.proxy.get("hostname") port = self.proxy.get("port") username = self.proxy.get("username") password = self.proxy.get("password") - - # Determine proxy address family - try: - ip_address = ipaddress.ip_address(hostname) - proxy_family = ( - socket.AF_INET6 - if isinstance(ip_address, ipaddress.IPv6Address) - else socket.AF_INET - ) - except ValueError: - proxy_family = socket.AF_INET + + proxy_family = self._get_proxy_family(hostname) try: - # Connect to proxy server - self.reader, self.writer = await asyncio.wait_for( - asyncio.open_connection(host=hostname, port=port, family=proxy_family), - timeout=self.TIMEOUT, - ) - - # Perform proxy handshake - handler = self._get_proxy_handler(scheme) - if scheme.upper() == "SOCKS5": - await handler.handshake( - self.writer, - self.reader, - destination, - username, - password, - self.TIMEOUT, - ) - elif scheme.upper() == "SOCKS4": - await handler.handshake( - self.writer, self.reader, destination, self.TIMEOUT - ) - elif scheme.upper() == "HTTP": - await handler.handshake( - self.writer, - self.reader, - destination, - username, - password, - self.TIMEOUT, - ) - - except Exception: + await self._establish_proxy_connection(hostname, port, proxy_family) + await self._perform_handshake(scheme, destination, username, password) + except (ConnectionError, AuthenticationError, ValueError): if self.writer: self.writer.close() await self.writer.wait_closed() raise + except Exception as e: + if self.writer: + self.writer.close() + await self.writer.wait_closed() + raise ConnectionError(f"Proxy connection failed: {e}") from e async def _connect_via_direct(self, destination: Tuple[str, int]) -> None: """Connect directly to destination.""" @@ -429,7 +390,8 @@ class TCP: while len(data) < length: try: chunk = await asyncio.wait_for( - self.reader.read(length - len(data)), self.TIMEOUT + self.reader.read(length - len(data)), + self.TIMEOUT ) except (OSError, asyncio.TimeoutError): return None @@ -439,4 +401,4 @@ class TCP: else: return None - return data + return data \ No newline at end of file