diff --git a/pyrogram/connection/transport/tcp/tcp.py b/pyrogram/connection/transport/tcp/tcp.py index ab8884a1..b10250aa 100644 --- a/pyrogram/connection/transport/tcp/tcp.py +++ b/pyrogram/connection/transport/tcp/tcp.py @@ -44,74 +44,54 @@ class Proxy(TypedDict): class ProxyError(Exception): """Base exception for proxy-related errors.""" - pass - class AuthenticationError(ProxyError): """Authentication failed.""" - pass - class ConnectionError(ProxyError): """Connection failed.""" - pass - class SOCKS5Handler: """Handles SOCKS5 proxy operations.""" - + @staticmethod - async def negotiate_auth( - writer: asyncio.StreamWriter, - reader: asyncio.StreamReader, - username: Optional[str], - password: Optional[str], - timeout: int, - ) -> None: + async def negotiate_auth(writer: asyncio.StreamWriter, reader: asyncio.StreamReader, + username: Optional[str], password: Optional[str], timeout: int) -> None: """Handle SOCKS5 authentication negotiation.""" - auth_methods = b"\x05\x02\x00\x02" if username and password else b"\x05\x01\x00" - + auth_methods = b'\x05\x02\x00\x02' if username and password else b'\x05\x01\x00' + writer.write(auth_methods) await writer.drain() - + response = await asyncio.wait_for(reader.read(2), timeout) if len(response) != 2 or response[0] != 0x05: raise ConnectionError("Invalid SOCKS5 response") - + return response[1] - + @staticmethod - async def authenticate( - writer: asyncio.StreamWriter, - reader: asyncio.StreamReader, - username: str, - password: str, - timeout: int, - ) -> None: + async def authenticate(writer: asyncio.StreamWriter, reader: asyncio.StreamReader, + username: str, password: str, timeout: int) -> None: """Perform username/password authentication.""" - username_bytes = username.encode("utf-8") - password_bytes = password.encode("utf-8") - auth_request = ( - bytes([0x01, len(username_bytes)]) - + username_bytes - + bytes([len(password_bytes)]) - + password_bytes - ) - + username_bytes = username.encode('utf-8') + password_bytes = password.encode('utf-8') + auth_request = (bytes([0x01, len(username_bytes)]) + username_bytes + + bytes([len(password_bytes)]) + password_bytes) + writer.write(auth_request) await writer.drain() - + auth_response = await asyncio.wait_for(reader.read(2), timeout) if len(auth_response) != 2 or auth_response[1] != 0x00: raise AuthenticationError("SOCKS5 authentication failed") - + @staticmethod def build_connect_request(host: str, port: int) -> bytes: """Build SOCKS5 connection request.""" request = bytearray([0x05, 0x01, 0x00]) - + try: ip = ipaddress.ip_address(host) if isinstance(ip, ipaddress.IPv4Address): @@ -121,14 +101,14 @@ class SOCKS5Handler: request.append(0x04) request.extend(ip.packed) except ValueError: - host_bytes = host.encode("utf-8") + host_bytes = host.encode('utf-8') request.append(0x03) request.append(len(host_bytes)) request.extend(host_bytes) - - request.extend(struct.pack(">H", port)) + + request.extend(struct.pack('>H', port)) return bytes(request) - + @staticmethod async def read_bound_address(reader: asyncio.StreamReader, addr_type: int) -> None: """Read bound address from SOCKS5 response.""" @@ -139,25 +119,17 @@ class SOCKS5Handler: await reader.read(domain_len + 2) elif addr_type == 0x04: await reader.read(18) - + @classmethod - async def handshake( - cls, - writer: asyncio.StreamWriter, - reader: asyncio.StreamReader, - destination: Tuple[str, int], - username: Optional[str] = None, - password: Optional[str] = None, - timeout: int = 10, - ) -> None: + async def handshake(cls, writer: asyncio.StreamWriter, reader: asyncio.StreamReader, + destination: Tuple[str, int], *, username: Optional[str] = None, + password: Optional[str] = None, timeout: int = 10) -> None: """Perform complete SOCKS5 handshake.""" host, port = destination - + # Authentication negotiation - selected_method = await cls.negotiate_auth( - writer, reader, username, password, timeout - ) - + selected_method = await cls.negotiate_auth(writer, reader, username, password, timeout) + # Handle authentication if selected_method == 0x02: if not username or not password: @@ -165,62 +137,50 @@ class SOCKS5Handler: await cls.authenticate(writer, reader, username, password, timeout) elif selected_method != 0x00: raise ConnectionError(f"Unsupported SOCKS5 auth method: {selected_method}") - + # Connection request request = cls.build_connect_request(host, port) writer.write(request) await writer.drain() - + # Read connection response conn_response = await asyncio.wait_for(reader.read(4), timeout) - if ( - len(conn_response) != 4 - or conn_response[0] != 0x05 - or conn_response[1] != 0x00 - ): + if len(conn_response) != 4 or conn_response[0] != 0x05 or conn_response[1] != 0x00: raise ConnectionError("SOCKS5 connection failed") - + # Read bound address await cls.read_bound_address(reader, conn_response[3]) class SOCKS4Handler: """Handles SOCKS4 proxy operations.""" - + @staticmethod def build_request(host: str, port: int) -> bytes: """Build SOCKS4 connection request.""" try: ip = ipaddress.IPv4Address(host) ip_bytes = ip.packed - request = struct.pack(">BBH", 0x04, 0x01, port) + ip_bytes + b"pyrogram\x00" + request = struct.pack('>BBH', 0x04, 0x01, port) + ip_bytes + b'pyrogram\x00' except ValueError: - ip_bytes = b"\x00\x00\x00\x01" - request = ( - struct.pack(">BBH", 0x04, 0x01, port) - + ip_bytes - + b"pyrogram\x00" - + host.encode("utf-8") - + b"\x00" - ) - - return request - + ip_bytes = b'\x00\x00\x00\x01' + request = (struct.pack('>BBH', 0x04, 0x01, port) + ip_bytes + + b'pyrogram\x00' + host.encode('utf-8') + b'\x00') + + @staticmethod + def build_request(host: str, port: int, username: Optional[str] = None, + password: Optional[str] = None) -> str: + @classmethod - async def handshake( - cls, - writer: asyncio.StreamWriter, - reader: asyncio.StreamReader, - destination: Tuple[str, int], - timeout: int = 10, - ) -> None: + async def handshake(cls, writer: asyncio.StreamWriter, reader: asyncio.StreamReader, + destination: Tuple[str, int], timeout: int = 10) -> None: """Perform SOCKS4 handshake.""" host, port = destination request = cls.build_request(host, port) - + writer.write(request) await writer.drain() - + response = await asyncio.wait_for(reader.read(8), timeout) if len(response) != 8 or response[0] != 0x00 or response[1] != 0x5A: raise ConnectionError("SOCKS4 connection failed") @@ -228,23 +188,27 @@ class SOCKS4Handler: class HTTPProxyHandler: """Handles HTTP proxy operations.""" - + @staticmethod - def build_request( - host: str, - port: int, - username: Optional[str] = None, - password: Optional[str] = None, - ) -> str: + def sanitize_request(request: str) -> str: + """Sanitize HTTP request to prevent injection attacks.""" + if '\r\n\r\n' not in request: + raise ValueError("Invalid HTTP request format") + + lines = request.split('\r\n') + if not lines[0].startswith('CONNECT ') or 'HTTP/1.1' not in lines[0]: + raise ValueError("Invalid CONNECT request") + + return request """Build HTTP CONNECT request.""" request = f"CONNECT {host}:{port} HTTP/1.1\r\nHost: {host}:{port}\r\n" - + if username and password: credentials = base64.b64encode(f"{username}:{password}".encode()).decode() request += f"Proxy-Authorization: Basic {credentials}\r\n" - + return request + "\r\n" - + @staticmethod async def read_response(reader: asyncio.StreamReader, timeout: int) -> list: """Read HTTP proxy response.""" @@ -253,39 +217,34 @@ class HTTPProxyHandler: line = await asyncio.wait_for(reader.readline(), timeout) if not line: raise ConnectionError("HTTP proxy connection closed") - - line = line.decode("utf-8", errors="ignore").strip() + + line = line.decode('utf-8', errors='ignore').strip() response_lines.append(line) - + if not line: break - + return response_lines - + @classmethod - async def handshake( - cls, - writer: asyncio.StreamWriter, - reader: asyncio.StreamReader, - destination: Tuple[str, int], - username: Optional[str] = None, - password: Optional[str] = None, - timeout: int = 10, - ) -> None: + async def handshake(cls, writer: asyncio.StreamWriter, reader: asyncio.StreamReader, + destination: Tuple[str, int], *, username: Optional[str] = None, + password: Optional[str] = None, timeout: int = 10) -> None: """Perform HTTP proxy handshake.""" host, port = destination request = cls.build_request(host, port, username, password) - - writer.write(request.encode()) + sanitized_request = cls.sanitize_request(request) + + writer.write(sanitized_request.encode()) await writer.drain() - + response_lines = await cls.read_response(reader, timeout) - - if not response_lines or not response_lines[0].startswith("HTTP/1."): + + if not response_lines or not response_lines[0].startswith('HTTP/1.'): raise ConnectionError("Invalid HTTP proxy response") - - status_parts = response_lines[0].split(" ", 2) - if len(status_parts) < 2 or status_parts[1] != "200": + + status_parts = response_lines[0].split(' ', 2) + if len(status_parts) < 2 or status_parts[1] != '200': raise ConnectionError(f"HTTP proxy connection failed: {response_lines[0]}") @@ -301,81 +260,83 @@ class TCP: self.lock = asyncio.Lock() - def _get_proxy_handler( - self, scheme: str - ) -> Union[SOCKS5Handler, SOCKS4Handler, HTTPProxyHandler]: + def _get_proxy_handler(self, scheme: str) -> Union[SOCKS5Handler, SOCKS4Handler, HTTPProxyHandler]: """Get appropriate proxy handler based on scheme.""" handlers = { "SOCKS5": SOCKS5Handler, "SOCKS4": SOCKS4Handler, - "HTTP": HTTPProxyHandler, + "HTTP": HTTPProxyHandler } - + handler_class = handlers.get(scheme.upper()) if not handler_class: raise ValueError(f"Unknown proxy type {scheme}") - + return handler_class + def _validate_proxy_config(self) -> None: + """Validate proxy configuration.""" + if not self.proxy.get("scheme"): + raise ValueError("No scheme specified") + + if self.proxy["scheme"].upper() not in proxy_type_by_scheme: + raise ValueError(f"Unknown proxy type {self.proxy['scheme']}") + + async def _establish_proxy_connection(self, hostname: str, port: int, proxy_family: int) -> None: + """Establish connection to proxy server.""" + self.reader, self.writer = await asyncio.wait_for( + asyncio.open_connection(host=hostname, port=port, family=proxy_family), + timeout=self.TIMEOUT + ) + + def _get_proxy_family(self, hostname: str) -> int: + """Determine address family for proxy connection.""" + try: + ip_address = ipaddress.ip_address(hostname) + return (socket.AF_INET6 if isinstance(ip_address, ipaddress.IPv6Address) + else socket.AF_INET) + except ValueError: + return socket.AF_INET + + async def _perform_handshake(self, scheme: str, destination: Tuple[str, int], + username: Optional[str], password: Optional[str]) -> None: + """Perform proxy handshake based on scheme.""" + handler = self._get_proxy_handler(scheme) + + if scheme.upper() == "SOCKS5": + await handler.handshake(self.writer, self.reader, destination, + username=username, password=password, timeout=self.TIMEOUT) + elif scheme.upper() == "SOCKS4": + await handler.handshake(self.writer, self.reader, destination, self.TIMEOUT) + elif scheme.upper() == "HTTP": + await handler.handshake(self.writer, self.reader, destination, + username=username, password=password, timeout=self.TIMEOUT) + async def _connect_via_proxy(self, destination: Tuple[str, int]) -> None: """Connect through proxy server.""" - scheme = self.proxy.get("scheme") - if not scheme: - raise ValueError("No scheme specified") - + self._validate_proxy_config() + + scheme = self.proxy["scheme"] hostname = self.proxy.get("hostname") port = self.proxy.get("port") username = self.proxy.get("username") password = self.proxy.get("password") - - # Determine proxy address family - try: - ip_address = ipaddress.ip_address(hostname) - proxy_family = ( - socket.AF_INET6 - if isinstance(ip_address, ipaddress.IPv6Address) - else socket.AF_INET - ) - except ValueError: - proxy_family = socket.AF_INET + + proxy_family = self._get_proxy_family(hostname) try: - # Connect to proxy server - self.reader, self.writer = await asyncio.wait_for( - asyncio.open_connection(host=hostname, port=port, family=proxy_family), - timeout=self.TIMEOUT, - ) - - # Perform proxy handshake - handler = self._get_proxy_handler(scheme) - if scheme.upper() == "SOCKS5": - await handler.handshake( - self.writer, - self.reader, - destination, - username, - password, - self.TIMEOUT, - ) - elif scheme.upper() == "SOCKS4": - await handler.handshake( - self.writer, self.reader, destination, self.TIMEOUT - ) - elif scheme.upper() == "HTTP": - await handler.handshake( - self.writer, - self.reader, - destination, - username, - password, - self.TIMEOUT, - ) - - except Exception: + await self._establish_proxy_connection(hostname, port, proxy_family) + await self._perform_handshake(scheme, destination, username, password) + except (ConnectionError, AuthenticationError, ValueError): if self.writer: self.writer.close() await self.writer.wait_closed() raise + except Exception as e: + if self.writer: + self.writer.close() + await self.writer.wait_closed() + raise ConnectionError(f"Proxy connection failed: {e}") from e async def _connect_via_direct(self, destination: Tuple[str, int]) -> None: """Connect directly to destination.""" @@ -429,7 +390,8 @@ class TCP: while len(data) < length: try: chunk = await asyncio.wait_for( - self.reader.read(length - len(data)), self.TIMEOUT + self.reader.read(length - len(data)), + self.TIMEOUT ) except (OSError, asyncio.TimeoutError): return None @@ -439,4 +401,4 @@ class TCP: else: return None - return data + return data \ No newline at end of file