This commit is contained in:
rpt0 2025-08-06 10:42:49 -04:00
parent 996ac69fa0
commit 88fca87fa5

View file

@ -44,74 +44,54 @@ class Proxy(TypedDict):
class ProxyError(Exception): class ProxyError(Exception):
"""Base exception for proxy-related errors.""" """Base exception for proxy-related errors."""
pass
class AuthenticationError(ProxyError): class AuthenticationError(ProxyError):
"""Authentication failed.""" """Authentication failed."""
pass
class ConnectionError(ProxyError): class ConnectionError(ProxyError):
"""Connection failed.""" """Connection failed."""
pass
class SOCKS5Handler: class SOCKS5Handler:
"""Handles SOCKS5 proxy operations.""" """Handles SOCKS5 proxy operations."""
@staticmethod @staticmethod
async def negotiate_auth( async def negotiate_auth(writer: asyncio.StreamWriter, reader: asyncio.StreamReader,
writer: asyncio.StreamWriter, username: Optional[str], password: Optional[str], timeout: int) -> None:
reader: asyncio.StreamReader,
username: Optional[str],
password: Optional[str],
timeout: int,
) -> None:
"""Handle SOCKS5 authentication negotiation.""" """Handle SOCKS5 authentication negotiation."""
auth_methods = b"\x05\x02\x00\x02" if username and password else b"\x05\x01\x00" auth_methods = b'\x05\x02\x00\x02' if username and password else b'\x05\x01\x00'
writer.write(auth_methods) writer.write(auth_methods)
await writer.drain() await writer.drain()
response = await asyncio.wait_for(reader.read(2), timeout) response = await asyncio.wait_for(reader.read(2), timeout)
if len(response) != 2 or response[0] != 0x05: if len(response) != 2 or response[0] != 0x05:
raise ConnectionError("Invalid SOCKS5 response") raise ConnectionError("Invalid SOCKS5 response")
return response[1] return response[1]
@staticmethod @staticmethod
async def authenticate( async def authenticate(writer: asyncio.StreamWriter, reader: asyncio.StreamReader,
writer: asyncio.StreamWriter, username: str, password: str, timeout: int) -> None:
reader: asyncio.StreamReader,
username: str,
password: str,
timeout: int,
) -> None:
"""Perform username/password authentication.""" """Perform username/password authentication."""
username_bytes = username.encode("utf-8") username_bytes = username.encode('utf-8')
password_bytes = password.encode("utf-8") password_bytes = password.encode('utf-8')
auth_request = ( auth_request = (bytes([0x01, len(username_bytes)]) + username_bytes +
bytes([0x01, len(username_bytes)]) bytes([len(password_bytes)]) + password_bytes)
+ username_bytes
+ bytes([len(password_bytes)])
+ password_bytes
)
writer.write(auth_request) writer.write(auth_request)
await writer.drain() await writer.drain()
auth_response = await asyncio.wait_for(reader.read(2), timeout) auth_response = await asyncio.wait_for(reader.read(2), timeout)
if len(auth_response) != 2 or auth_response[1] != 0x00: if len(auth_response) != 2 or auth_response[1] != 0x00:
raise AuthenticationError("SOCKS5 authentication failed") raise AuthenticationError("SOCKS5 authentication failed")
@staticmethod @staticmethod
def build_connect_request(host: str, port: int) -> bytes: def build_connect_request(host: str, port: int) -> bytes:
"""Build SOCKS5 connection request.""" """Build SOCKS5 connection request."""
request = bytearray([0x05, 0x01, 0x00]) request = bytearray([0x05, 0x01, 0x00])
try: try:
ip = ipaddress.ip_address(host) ip = ipaddress.ip_address(host)
if isinstance(ip, ipaddress.IPv4Address): if isinstance(ip, ipaddress.IPv4Address):
@ -121,14 +101,14 @@ class SOCKS5Handler:
request.append(0x04) request.append(0x04)
request.extend(ip.packed) request.extend(ip.packed)
except ValueError: except ValueError:
host_bytes = host.encode("utf-8") host_bytes = host.encode('utf-8')
request.append(0x03) request.append(0x03)
request.append(len(host_bytes)) request.append(len(host_bytes))
request.extend(host_bytes) request.extend(host_bytes)
request.extend(struct.pack(">H", port)) request.extend(struct.pack('>H', port))
return bytes(request) return bytes(request)
@staticmethod @staticmethod
async def read_bound_address(reader: asyncio.StreamReader, addr_type: int) -> None: async def read_bound_address(reader: asyncio.StreamReader, addr_type: int) -> None:
"""Read bound address from SOCKS5 response.""" """Read bound address from SOCKS5 response."""
@ -139,25 +119,17 @@ class SOCKS5Handler:
await reader.read(domain_len + 2) await reader.read(domain_len + 2)
elif addr_type == 0x04: elif addr_type == 0x04:
await reader.read(18) await reader.read(18)
@classmethod @classmethod
async def handshake( async def handshake(cls, writer: asyncio.StreamWriter, reader: asyncio.StreamReader,
cls, destination: Tuple[str, int], *, username: Optional[str] = None,
writer: asyncio.StreamWriter, password: Optional[str] = None, timeout: int = 10) -> None:
reader: asyncio.StreamReader,
destination: Tuple[str, int],
username: Optional[str] = None,
password: Optional[str] = None,
timeout: int = 10,
) -> None:
"""Perform complete SOCKS5 handshake.""" """Perform complete SOCKS5 handshake."""
host, port = destination host, port = destination
# Authentication negotiation # Authentication negotiation
selected_method = await cls.negotiate_auth( selected_method = await cls.negotiate_auth(writer, reader, username, password, timeout)
writer, reader, username, password, timeout
)
# Handle authentication # Handle authentication
if selected_method == 0x02: if selected_method == 0x02:
if not username or not password: if not username or not password:
@ -165,62 +137,50 @@ class SOCKS5Handler:
await cls.authenticate(writer, reader, username, password, timeout) await cls.authenticate(writer, reader, username, password, timeout)
elif selected_method != 0x00: elif selected_method != 0x00:
raise ConnectionError(f"Unsupported SOCKS5 auth method: {selected_method}") raise ConnectionError(f"Unsupported SOCKS5 auth method: {selected_method}")
# Connection request # Connection request
request = cls.build_connect_request(host, port) request = cls.build_connect_request(host, port)
writer.write(request) writer.write(request)
await writer.drain() await writer.drain()
# Read connection response # Read connection response
conn_response = await asyncio.wait_for(reader.read(4), timeout) conn_response = await asyncio.wait_for(reader.read(4), timeout)
if ( if len(conn_response) != 4 or conn_response[0] != 0x05 or conn_response[1] != 0x00:
len(conn_response) != 4
or conn_response[0] != 0x05
or conn_response[1] != 0x00
):
raise ConnectionError("SOCKS5 connection failed") raise ConnectionError("SOCKS5 connection failed")
# Read bound address # Read bound address
await cls.read_bound_address(reader, conn_response[3]) await cls.read_bound_address(reader, conn_response[3])
class SOCKS4Handler: class SOCKS4Handler:
"""Handles SOCKS4 proxy operations.""" """Handles SOCKS4 proxy operations."""
@staticmethod @staticmethod
def build_request(host: str, port: int) -> bytes: def build_request(host: str, port: int) -> bytes:
"""Build SOCKS4 connection request.""" """Build SOCKS4 connection request."""
try: try:
ip = ipaddress.IPv4Address(host) ip = ipaddress.IPv4Address(host)
ip_bytes = ip.packed ip_bytes = ip.packed
request = struct.pack(">BBH", 0x04, 0x01, port) + ip_bytes + b"pyrogram\x00" request = struct.pack('>BBH', 0x04, 0x01, port) + ip_bytes + b'pyrogram\x00'
except ValueError: except ValueError:
ip_bytes = b"\x00\x00\x00\x01" ip_bytes = b'\x00\x00\x00\x01'
request = ( request = (struct.pack('>BBH', 0x04, 0x01, port) + ip_bytes +
struct.pack(">BBH", 0x04, 0x01, port) b'pyrogram\x00' + host.encode('utf-8') + b'\x00')
+ ip_bytes
+ b"pyrogram\x00" @staticmethod
+ host.encode("utf-8") def build_request(host: str, port: int, username: Optional[str] = None,
+ b"\x00" password: Optional[str] = None) -> str:
)
return request
@classmethod @classmethod
async def handshake( async def handshake(cls, writer: asyncio.StreamWriter, reader: asyncio.StreamReader,
cls, destination: Tuple[str, int], timeout: int = 10) -> None:
writer: asyncio.StreamWriter,
reader: asyncio.StreamReader,
destination: Tuple[str, int],
timeout: int = 10,
) -> None:
"""Perform SOCKS4 handshake.""" """Perform SOCKS4 handshake."""
host, port = destination host, port = destination
request = cls.build_request(host, port) request = cls.build_request(host, port)
writer.write(request) writer.write(request)
await writer.drain() await writer.drain()
response = await asyncio.wait_for(reader.read(8), timeout) response = await asyncio.wait_for(reader.read(8), timeout)
if len(response) != 8 or response[0] != 0x00 or response[1] != 0x5A: if len(response) != 8 or response[0] != 0x00 or response[1] != 0x5A:
raise ConnectionError("SOCKS4 connection failed") raise ConnectionError("SOCKS4 connection failed")
@ -228,23 +188,27 @@ class SOCKS4Handler:
class HTTPProxyHandler: class HTTPProxyHandler:
"""Handles HTTP proxy operations.""" """Handles HTTP proxy operations."""
@staticmethod @staticmethod
def build_request( def sanitize_request(request: str) -> str:
host: str, """Sanitize HTTP request to prevent injection attacks."""
port: int, if '\r\n\r\n' not in request:
username: Optional[str] = None, raise ValueError("Invalid HTTP request format")
password: Optional[str] = None,
) -> str: lines = request.split('\r\n')
if not lines[0].startswith('CONNECT ') or 'HTTP/1.1' not in lines[0]:
raise ValueError("Invalid CONNECT request")
return request
"""Build HTTP CONNECT request.""" """Build HTTP CONNECT request."""
request = f"CONNECT {host}:{port} HTTP/1.1\r\nHost: {host}:{port}\r\n" request = f"CONNECT {host}:{port} HTTP/1.1\r\nHost: {host}:{port}\r\n"
if username and password: if username and password:
credentials = base64.b64encode(f"{username}:{password}".encode()).decode() credentials = base64.b64encode(f"{username}:{password}".encode()).decode()
request += f"Proxy-Authorization: Basic {credentials}\r\n" request += f"Proxy-Authorization: Basic {credentials}\r\n"
return request + "\r\n" return request + "\r\n"
@staticmethod @staticmethod
async def read_response(reader: asyncio.StreamReader, timeout: int) -> list: async def read_response(reader: asyncio.StreamReader, timeout: int) -> list:
"""Read HTTP proxy response.""" """Read HTTP proxy response."""
@ -253,39 +217,34 @@ class HTTPProxyHandler:
line = await asyncio.wait_for(reader.readline(), timeout) line = await asyncio.wait_for(reader.readline(), timeout)
if not line: if not line:
raise ConnectionError("HTTP proxy connection closed") raise ConnectionError("HTTP proxy connection closed")
line = line.decode("utf-8", errors="ignore").strip() line = line.decode('utf-8', errors='ignore').strip()
response_lines.append(line) response_lines.append(line)
if not line: if not line:
break break
return response_lines return response_lines
@classmethod @classmethod
async def handshake( async def handshake(cls, writer: asyncio.StreamWriter, reader: asyncio.StreamReader,
cls, destination: Tuple[str, int], *, username: Optional[str] = None,
writer: asyncio.StreamWriter, password: Optional[str] = None, timeout: int = 10) -> None:
reader: asyncio.StreamReader,
destination: Tuple[str, int],
username: Optional[str] = None,
password: Optional[str] = None,
timeout: int = 10,
) -> None:
"""Perform HTTP proxy handshake.""" """Perform HTTP proxy handshake."""
host, port = destination host, port = destination
request = cls.build_request(host, port, username, password) request = cls.build_request(host, port, username, password)
sanitized_request = cls.sanitize_request(request)
writer.write(request.encode())
writer.write(sanitized_request.encode())
await writer.drain() await writer.drain()
response_lines = await cls.read_response(reader, timeout) response_lines = await cls.read_response(reader, timeout)
if not response_lines or not response_lines[0].startswith("HTTP/1."): if not response_lines or not response_lines[0].startswith('HTTP/1.'):
raise ConnectionError("Invalid HTTP proxy response") raise ConnectionError("Invalid HTTP proxy response")
status_parts = response_lines[0].split(" ", 2) status_parts = response_lines[0].split(' ', 2)
if len(status_parts) < 2 or status_parts[1] != "200": if len(status_parts) < 2 or status_parts[1] != '200':
raise ConnectionError(f"HTTP proxy connection failed: {response_lines[0]}") raise ConnectionError(f"HTTP proxy connection failed: {response_lines[0]}")
@ -301,81 +260,83 @@ class TCP:
self.lock = asyncio.Lock() self.lock = asyncio.Lock()
def _get_proxy_handler( def _get_proxy_handler(self, scheme: str) -> Union[SOCKS5Handler, SOCKS4Handler, HTTPProxyHandler]:
self, scheme: str
) -> Union[SOCKS5Handler, SOCKS4Handler, HTTPProxyHandler]:
"""Get appropriate proxy handler based on scheme.""" """Get appropriate proxy handler based on scheme."""
handlers = { handlers = {
"SOCKS5": SOCKS5Handler, "SOCKS5": SOCKS5Handler,
"SOCKS4": SOCKS4Handler, "SOCKS4": SOCKS4Handler,
"HTTP": HTTPProxyHandler, "HTTP": HTTPProxyHandler
} }
handler_class = handlers.get(scheme.upper()) handler_class = handlers.get(scheme.upper())
if not handler_class: if not handler_class:
raise ValueError(f"Unknown proxy type {scheme}") raise ValueError(f"Unknown proxy type {scheme}")
return handler_class return handler_class
def _validate_proxy_config(self) -> None:
"""Validate proxy configuration."""
if not self.proxy.get("scheme"):
raise ValueError("No scheme specified")
if self.proxy["scheme"].upper() not in proxy_type_by_scheme:
raise ValueError(f"Unknown proxy type {self.proxy['scheme']}")
async def _establish_proxy_connection(self, hostname: str, port: int, proxy_family: int) -> None:
"""Establish connection to proxy server."""
self.reader, self.writer = await asyncio.wait_for(
asyncio.open_connection(host=hostname, port=port, family=proxy_family),
timeout=self.TIMEOUT
)
def _get_proxy_family(self, hostname: str) -> int:
"""Determine address family for proxy connection."""
try:
ip_address = ipaddress.ip_address(hostname)
return (socket.AF_INET6 if isinstance(ip_address, ipaddress.IPv6Address)
else socket.AF_INET)
except ValueError:
return socket.AF_INET
async def _perform_handshake(self, scheme: str, destination: Tuple[str, int],
username: Optional[str], password: Optional[str]) -> None:
"""Perform proxy handshake based on scheme."""
handler = self._get_proxy_handler(scheme)
if scheme.upper() == "SOCKS5":
await handler.handshake(self.writer, self.reader, destination,
username=username, password=password, timeout=self.TIMEOUT)
elif scheme.upper() == "SOCKS4":
await handler.handshake(self.writer, self.reader, destination, self.TIMEOUT)
elif scheme.upper() == "HTTP":
await handler.handshake(self.writer, self.reader, destination,
username=username, password=password, timeout=self.TIMEOUT)
async def _connect_via_proxy(self, destination: Tuple[str, int]) -> None: async def _connect_via_proxy(self, destination: Tuple[str, int]) -> None:
"""Connect through proxy server.""" """Connect through proxy server."""
scheme = self.proxy.get("scheme") self._validate_proxy_config()
if not scheme:
raise ValueError("No scheme specified") scheme = self.proxy["scheme"]
hostname = self.proxy.get("hostname") hostname = self.proxy.get("hostname")
port = self.proxy.get("port") port = self.proxy.get("port")
username = self.proxy.get("username") username = self.proxy.get("username")
password = self.proxy.get("password") password = self.proxy.get("password")
# Determine proxy address family proxy_family = self._get_proxy_family(hostname)
try:
ip_address = ipaddress.ip_address(hostname)
proxy_family = (
socket.AF_INET6
if isinstance(ip_address, ipaddress.IPv6Address)
else socket.AF_INET
)
except ValueError:
proxy_family = socket.AF_INET
try: try:
# Connect to proxy server await self._establish_proxy_connection(hostname, port, proxy_family)
self.reader, self.writer = await asyncio.wait_for( await self._perform_handshake(scheme, destination, username, password)
asyncio.open_connection(host=hostname, port=port, family=proxy_family), except (ConnectionError, AuthenticationError, ValueError):
timeout=self.TIMEOUT,
)
# Perform proxy handshake
handler = self._get_proxy_handler(scheme)
if scheme.upper() == "SOCKS5":
await handler.handshake(
self.writer,
self.reader,
destination,
username,
password,
self.TIMEOUT,
)
elif scheme.upper() == "SOCKS4":
await handler.handshake(
self.writer, self.reader, destination, self.TIMEOUT
)
elif scheme.upper() == "HTTP":
await handler.handshake(
self.writer,
self.reader,
destination,
username,
password,
self.TIMEOUT,
)
except Exception:
if self.writer: if self.writer:
self.writer.close() self.writer.close()
await self.writer.wait_closed() await self.writer.wait_closed()
raise raise
except Exception as e:
if self.writer:
self.writer.close()
await self.writer.wait_closed()
raise ConnectionError(f"Proxy connection failed: {e}") from e
async def _connect_via_direct(self, destination: Tuple[str, int]) -> None: async def _connect_via_direct(self, destination: Tuple[str, int]) -> None:
"""Connect directly to destination.""" """Connect directly to destination."""
@ -429,7 +390,8 @@ class TCP:
while len(data) < length: while len(data) < length:
try: try:
chunk = await asyncio.wait_for( chunk = await asyncio.wait_for(
self.reader.read(length - len(data)), self.TIMEOUT self.reader.read(length - len(data)),
self.TIMEOUT
) )
except (OSError, asyncio.TimeoutError): except (OSError, asyncio.TimeoutError):
return None return None
@ -439,4 +401,4 @@ class TCP:
else: else:
return None return None
return data return data