mirror of
https://github.com/Mayuri-Chan/pyrofork.git
synced 2026-01-03 14:04:51 +00:00
fix
This commit is contained in:
parent
9fad774949
commit
19eb456ad2
2 changed files with 684 additions and 161 deletions
|
|
@ -22,7 +22,7 @@ import base64
|
||||||
import ipaddress
|
import ipaddress
|
||||||
import socket
|
import socket
|
||||||
import struct
|
import struct
|
||||||
from typing import Tuple, Dict, TypedDict, Optional
|
from typing import Tuple, Dict, TypedDict, Optional, Union
|
||||||
|
|
||||||
import socks
|
import socks
|
||||||
|
|
||||||
|
|
@ -41,6 +41,254 @@ class Proxy(TypedDict):
|
||||||
password: Optional[str]
|
password: Optional[str]
|
||||||
|
|
||||||
|
|
||||||
|
class ProxyError(Exception):
|
||||||
|
"""Base exception for proxy-related errors."""
|
||||||
|
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class AuthenticationError(ProxyError):
|
||||||
|
"""Authentication failed."""
|
||||||
|
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class ConnectionError(ProxyError):
|
||||||
|
"""Connection failed."""
|
||||||
|
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class SOCKS5Handler:
|
||||||
|
"""Handles SOCKS5 proxy operations."""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def negotiate_auth(
|
||||||
|
writer: asyncio.StreamWriter,
|
||||||
|
reader: asyncio.StreamReader,
|
||||||
|
username: Optional[str],
|
||||||
|
password: Optional[str],
|
||||||
|
timeout: int,
|
||||||
|
) -> None:
|
||||||
|
"""Handle SOCKS5 authentication negotiation."""
|
||||||
|
auth_methods = b"\x05\x02\x00\x02" if username and password else b"\x05\x01\x00"
|
||||||
|
|
||||||
|
writer.write(auth_methods)
|
||||||
|
await writer.drain()
|
||||||
|
|
||||||
|
response = await asyncio.wait_for(reader.read(2), timeout)
|
||||||
|
if len(response) != 2 or response[0] != 0x05:
|
||||||
|
raise ConnectionError("Invalid SOCKS5 response")
|
||||||
|
|
||||||
|
return response[1]
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def authenticate(
|
||||||
|
writer: asyncio.StreamWriter,
|
||||||
|
reader: asyncio.StreamReader,
|
||||||
|
username: str,
|
||||||
|
password: str,
|
||||||
|
timeout: int,
|
||||||
|
) -> None:
|
||||||
|
"""Perform username/password authentication."""
|
||||||
|
username_bytes = username.encode("utf-8")
|
||||||
|
password_bytes = password.encode("utf-8")
|
||||||
|
auth_request = (
|
||||||
|
bytes([0x01, len(username_bytes)])
|
||||||
|
+ username_bytes
|
||||||
|
+ bytes([len(password_bytes)])
|
||||||
|
+ password_bytes
|
||||||
|
)
|
||||||
|
|
||||||
|
writer.write(auth_request)
|
||||||
|
await writer.drain()
|
||||||
|
|
||||||
|
auth_response = await asyncio.wait_for(reader.read(2), timeout)
|
||||||
|
if len(auth_response) != 2 or auth_response[1] != 0x00:
|
||||||
|
raise AuthenticationError("SOCKS5 authentication failed")
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def build_connect_request(host: str, port: int) -> bytes:
|
||||||
|
"""Build SOCKS5 connection request."""
|
||||||
|
request = bytearray([0x05, 0x01, 0x00])
|
||||||
|
|
||||||
|
try:
|
||||||
|
ip = ipaddress.ip_address(host)
|
||||||
|
if isinstance(ip, ipaddress.IPv4Address):
|
||||||
|
request.append(0x01)
|
||||||
|
request.extend(ip.packed)
|
||||||
|
else:
|
||||||
|
request.append(0x04)
|
||||||
|
request.extend(ip.packed)
|
||||||
|
except ValueError:
|
||||||
|
host_bytes = host.encode("utf-8")
|
||||||
|
request.append(0x03)
|
||||||
|
request.append(len(host_bytes))
|
||||||
|
request.extend(host_bytes)
|
||||||
|
|
||||||
|
request.extend(struct.pack(">H", port))
|
||||||
|
return bytes(request)
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def read_bound_address(reader: asyncio.StreamReader, addr_type: int) -> None:
|
||||||
|
"""Read bound address from SOCKS5 response."""
|
||||||
|
if addr_type == 0x01:
|
||||||
|
await reader.read(6)
|
||||||
|
elif addr_type == 0x03:
|
||||||
|
domain_len = (await reader.read(1))[0]
|
||||||
|
await reader.read(domain_len + 2)
|
||||||
|
elif addr_type == 0x04:
|
||||||
|
await reader.read(18)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def handshake(
|
||||||
|
cls,
|
||||||
|
writer: asyncio.StreamWriter,
|
||||||
|
reader: asyncio.StreamReader,
|
||||||
|
destination: Tuple[str, int],
|
||||||
|
username: Optional[str] = None,
|
||||||
|
password: Optional[str] = None,
|
||||||
|
timeout: int = 10,
|
||||||
|
) -> None:
|
||||||
|
"""Perform complete SOCKS5 handshake."""
|
||||||
|
host, port = destination
|
||||||
|
|
||||||
|
# Authentication negotiation
|
||||||
|
selected_method = await cls.negotiate_auth(
|
||||||
|
writer, reader, username, password, timeout
|
||||||
|
)
|
||||||
|
|
||||||
|
# Handle authentication
|
||||||
|
if selected_method == 0x02:
|
||||||
|
if not username or not password:
|
||||||
|
raise ConnectionError("SOCKS5 server requires authentication")
|
||||||
|
await cls.authenticate(writer, reader, username, password, timeout)
|
||||||
|
elif selected_method != 0x00:
|
||||||
|
raise ConnectionError(f"Unsupported SOCKS5 auth method: {selected_method}")
|
||||||
|
|
||||||
|
# Connection request
|
||||||
|
request = cls.build_connect_request(host, port)
|
||||||
|
writer.write(request)
|
||||||
|
await writer.drain()
|
||||||
|
|
||||||
|
# Read connection response
|
||||||
|
conn_response = await asyncio.wait_for(reader.read(4), timeout)
|
||||||
|
if (
|
||||||
|
len(conn_response) != 4
|
||||||
|
or conn_response[0] != 0x05
|
||||||
|
or conn_response[1] != 0x00
|
||||||
|
):
|
||||||
|
raise ConnectionError("SOCKS5 connection failed")
|
||||||
|
|
||||||
|
# Read bound address
|
||||||
|
await cls.read_bound_address(reader, conn_response[3])
|
||||||
|
|
||||||
|
|
||||||
|
class SOCKS4Handler:
|
||||||
|
"""Handles SOCKS4 proxy operations."""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def build_request(host: str, port: int) -> bytes:
|
||||||
|
"""Build SOCKS4 connection request."""
|
||||||
|
try:
|
||||||
|
ip = ipaddress.IPv4Address(host)
|
||||||
|
ip_bytes = ip.packed
|
||||||
|
request = struct.pack(">BBH", 0x04, 0x01, port) + ip_bytes + b"pyrogram\x00"
|
||||||
|
except ValueError:
|
||||||
|
ip_bytes = b"\x00\x00\x00\x01"
|
||||||
|
request = (
|
||||||
|
struct.pack(">BBH", 0x04, 0x01, port)
|
||||||
|
+ ip_bytes
|
||||||
|
+ b"pyrogram\x00"
|
||||||
|
+ host.encode("utf-8")
|
||||||
|
+ b"\x00"
|
||||||
|
)
|
||||||
|
|
||||||
|
return request
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def handshake(
|
||||||
|
cls,
|
||||||
|
writer: asyncio.StreamWriter,
|
||||||
|
reader: asyncio.StreamReader,
|
||||||
|
destination: Tuple[str, int],
|
||||||
|
timeout: int = 10,
|
||||||
|
) -> None:
|
||||||
|
"""Perform SOCKS4 handshake."""
|
||||||
|
host, port = destination
|
||||||
|
request = cls.build_request(host, port)
|
||||||
|
|
||||||
|
writer.write(request)
|
||||||
|
await writer.drain()
|
||||||
|
|
||||||
|
response = await asyncio.wait_for(reader.read(8), timeout)
|
||||||
|
if len(response) != 8 or response[0] != 0x00 or response[1] != 0x5A:
|
||||||
|
raise ConnectionError("SOCKS4 connection failed")
|
||||||
|
|
||||||
|
|
||||||
|
class HTTPProxyHandler:
|
||||||
|
"""Handles HTTP proxy operations."""
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
def build_request(
|
||||||
|
host: str,
|
||||||
|
port: int,
|
||||||
|
username: Optional[str] = None,
|
||||||
|
password: Optional[str] = None,
|
||||||
|
) -> str:
|
||||||
|
"""Build HTTP CONNECT request."""
|
||||||
|
request = f"CONNECT {host}:{port} HTTP/1.1\r\nHost: {host}:{port}\r\n"
|
||||||
|
|
||||||
|
if username and password:
|
||||||
|
credentials = base64.b64encode(f"{username}:{password}".encode()).decode()
|
||||||
|
request += f"Proxy-Authorization: Basic {credentials}\r\n"
|
||||||
|
|
||||||
|
return request + "\r\n"
|
||||||
|
|
||||||
|
@staticmethod
|
||||||
|
async def read_response(reader: asyncio.StreamReader, timeout: int) -> list:
|
||||||
|
"""Read HTTP proxy response."""
|
||||||
|
response_lines = []
|
||||||
|
while True:
|
||||||
|
line = await asyncio.wait_for(reader.readline(), timeout)
|
||||||
|
if not line:
|
||||||
|
raise ConnectionError("HTTP proxy connection closed")
|
||||||
|
|
||||||
|
line = line.decode("utf-8", errors="ignore").strip()
|
||||||
|
response_lines.append(line)
|
||||||
|
|
||||||
|
if not line:
|
||||||
|
break
|
||||||
|
|
||||||
|
return response_lines
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
async def handshake(
|
||||||
|
cls,
|
||||||
|
writer: asyncio.StreamWriter,
|
||||||
|
reader: asyncio.StreamReader,
|
||||||
|
destination: Tuple[str, int],
|
||||||
|
username: Optional[str] = None,
|
||||||
|
password: Optional[str] = None,
|
||||||
|
timeout: int = 10,
|
||||||
|
) -> None:
|
||||||
|
"""Perform HTTP proxy handshake."""
|
||||||
|
host, port = destination
|
||||||
|
request = cls.build_request(host, port, username, password)
|
||||||
|
|
||||||
|
writer.write(request.encode())
|
||||||
|
await writer.drain()
|
||||||
|
|
||||||
|
response_lines = await cls.read_response(reader, timeout)
|
||||||
|
|
||||||
|
if not response_lines or not response_lines[0].startswith("HTTP/1."):
|
||||||
|
raise ConnectionError("Invalid HTTP proxy response")
|
||||||
|
|
||||||
|
status_parts = response_lines[0].split(" ", 2)
|
||||||
|
if len(status_parts) < 2 or status_parts[1] != "200":
|
||||||
|
raise ConnectionError(f"HTTP proxy connection failed: {response_lines[0]}")
|
||||||
|
|
||||||
|
|
||||||
class TCP:
|
class TCP:
|
||||||
TIMEOUT = 10
|
TIMEOUT = 10
|
||||||
|
|
||||||
|
|
@ -53,181 +301,76 @@ class TCP:
|
||||||
|
|
||||||
self.lock = asyncio.Lock()
|
self.lock = asyncio.Lock()
|
||||||
|
|
||||||
async def _socks5_handshake(self, writer: asyncio.StreamWriter, reader: asyncio.StreamReader,
|
def _get_proxy_handler(
|
||||||
destination: Tuple[str, int], username: Optional[str] = None,
|
self, scheme: str
|
||||||
password: Optional[str] = None) -> None:
|
) -> Union[SOCKS5Handler, SOCKS4Handler, HTTPProxyHandler]:
|
||||||
# Authentication negotiation
|
"""Get appropriate proxy handler based on scheme."""
|
||||||
if username and password:
|
handlers = {
|
||||||
auth_methods = b'\x05\x02\x00\x02'
|
"SOCKS5": SOCKS5Handler,
|
||||||
else:
|
"SOCKS4": SOCKS4Handler,
|
||||||
auth_methods = b'\x05\x01\x00'
|
"HTTP": HTTPProxyHandler,
|
||||||
|
}
|
||||||
writer.write(auth_methods)
|
|
||||||
await writer.drain()
|
|
||||||
|
|
||||||
response = await asyncio.wait_for(reader.read(2), self.TIMEOUT)
|
|
||||||
if len(response) != 2 or response[0] != 0x05:
|
|
||||||
raise ConnectionError("Invalid SOCKS5 response")
|
|
||||||
|
|
||||||
selected_method = response[1]
|
|
||||||
|
|
||||||
# Authentication if required
|
|
||||||
if selected_method == 0x02:
|
|
||||||
if not username or not password:
|
|
||||||
raise ConnectionError("SOCKS5 server requires authentication")
|
|
||||||
|
|
||||||
username_bytes = username.encode('utf-8')
|
|
||||||
password_bytes = password.encode('utf-8')
|
|
||||||
auth_request = bytes([0x01, len(username_bytes)]) + username_bytes + bytes([len(password_bytes)]) + password_bytes
|
|
||||||
|
|
||||||
writer.write(auth_request)
|
|
||||||
await writer.drain()
|
|
||||||
|
|
||||||
auth_response = await asyncio.wait_for(reader.read(2), self.TIMEOUT)
|
|
||||||
if len(auth_response) != 2 or auth_response[1] != 0x00:
|
|
||||||
raise ConnectionError("SOCKS5 authentication failed")
|
|
||||||
|
|
||||||
elif selected_method != 0x00:
|
|
||||||
raise ConnectionError(f"Unsupported SOCKS5 auth method: {selected_method}")
|
|
||||||
|
|
||||||
# Connection request
|
|
||||||
host, port = destination
|
|
||||||
request = bytearray([0x05, 0x01, 0x00])
|
|
||||||
|
|
||||||
try:
|
|
||||||
ip = ipaddress.ip_address(host)
|
|
||||||
if isinstance(ip, ipaddress.IPv4Address):
|
|
||||||
request.append(0x01)
|
|
||||||
request.extend(ip.packed)
|
|
||||||
else:
|
|
||||||
request.append(0x04)
|
|
||||||
request.extend(ip.packed)
|
|
||||||
except ValueError:
|
|
||||||
host_bytes = host.encode('utf-8')
|
|
||||||
request.append(0x03)
|
|
||||||
request.append(len(host_bytes))
|
|
||||||
request.extend(host_bytes)
|
|
||||||
|
|
||||||
request.extend(struct.pack('>H', port))
|
|
||||||
|
|
||||||
writer.write(request)
|
|
||||||
await writer.drain()
|
|
||||||
|
|
||||||
conn_response = await asyncio.wait_for(reader.read(4), self.TIMEOUT)
|
|
||||||
if len(conn_response) != 4 or conn_response[0] != 0x05 or conn_response[1] != 0x00:
|
|
||||||
raise ConnectionError("SOCKS5 connection failed")
|
|
||||||
|
|
||||||
# Read bound address
|
|
||||||
addr_type = conn_response[3]
|
|
||||||
if addr_type == 0x01:
|
|
||||||
await reader.read(6)
|
|
||||||
elif addr_type == 0x03:
|
|
||||||
domain_len = (await reader.read(1))[0]
|
|
||||||
await reader.read(domain_len + 2)
|
|
||||||
elif addr_type == 0x04:
|
|
||||||
await reader.read(18)
|
|
||||||
|
|
||||||
async def _socks4_handshake(self, writer: asyncio.StreamWriter, reader: asyncio.StreamReader,
|
handler_class = handlers.get(scheme.upper())
|
||||||
destination: Tuple[str, int]) -> None:
|
if not handler_class:
|
||||||
host, port = destination
|
raise ValueError(f"Unknown proxy type {scheme}")
|
||||||
|
|
||||||
try:
|
|
||||||
ip = ipaddress.IPv4Address(host)
|
|
||||||
ip_bytes = ip.packed
|
|
||||||
except ValueError:
|
|
||||||
ip_bytes = b'\x00\x00\x00\x01'
|
|
||||||
|
|
||||||
request = struct.pack('>BBH', 0x04, 0x01, port) + ip_bytes + b'pyrogram\x00'
|
|
||||||
|
|
||||||
if ip_bytes == b'\x00\x00\x00\x01':
|
|
||||||
request += host.encode('utf-8') + b'\x00'
|
|
||||||
|
|
||||||
writer.write(request)
|
|
||||||
await writer.drain()
|
|
||||||
|
|
||||||
response = await asyncio.wait_for(reader.read(8), self.TIMEOUT)
|
|
||||||
if len(response) != 8 or response[0] != 0x00 or response[1] != 0x5A:
|
|
||||||
raise ConnectionError("SOCKS4 connection failed")
|
|
||||||
|
|
||||||
async def _http_proxy_handshake(self, writer: asyncio.StreamWriter, reader: asyncio.StreamReader,
|
return handler_class
|
||||||
destination: Tuple[str, int], username: Optional[str] = None,
|
|
||||||
password: Optional[str] = None) -> None:
|
|
||||||
host, port = destination
|
|
||||||
|
|
||||||
request = f"CONNECT {host}:{port} HTTP/1.1\r\nHost: {host}:{port}\r\n"
|
|
||||||
|
|
||||||
if username and password:
|
|
||||||
credentials = base64.b64encode(f"{username}:{password}".encode()).decode()
|
|
||||||
request += f"Proxy-Authorization: Basic {credentials}\r\n"
|
|
||||||
|
|
||||||
request += "\r\n"
|
|
||||||
|
|
||||||
writer.write(request.encode())
|
|
||||||
await writer.drain()
|
|
||||||
|
|
||||||
response_lines = []
|
|
||||||
while True:
|
|
||||||
line = await asyncio.wait_for(reader.readline(), self.TIMEOUT)
|
|
||||||
if not line:
|
|
||||||
raise ConnectionError("HTTP proxy connection closed")
|
|
||||||
|
|
||||||
line = line.decode('utf-8', errors='ignore').strip()
|
|
||||||
response_lines.append(line)
|
|
||||||
|
|
||||||
if not line:
|
|
||||||
break
|
|
||||||
|
|
||||||
if not response_lines or not response_lines[0].startswith('HTTP/1.'):
|
|
||||||
raise ConnectionError("Invalid HTTP proxy response")
|
|
||||||
|
|
||||||
status_parts = response_lines[0].split(' ', 2)
|
|
||||||
if len(status_parts) < 2 or status_parts[1] != '200':
|
|
||||||
raise ConnectionError(f"HTTP proxy connection failed: {response_lines[0]}")
|
|
||||||
|
|
||||||
async def _connect_via_proxy(self, destination: Tuple[str, int]) -> None:
|
async def _connect_via_proxy(self, destination: Tuple[str, int]) -> None:
|
||||||
|
"""Connect through proxy server."""
|
||||||
scheme = self.proxy.get("scheme")
|
scheme = self.proxy.get("scheme")
|
||||||
if scheme is None:
|
if not scheme:
|
||||||
raise ValueError("No scheme specified")
|
raise ValueError("No scheme specified")
|
||||||
|
|
||||||
proxy_type = proxy_type_by_scheme.get(scheme.upper())
|
|
||||||
if proxy_type is None:
|
|
||||||
raise ValueError(f"Unknown proxy type {scheme}")
|
|
||||||
|
|
||||||
hostname = self.proxy.get("hostname")
|
hostname = self.proxy.get("hostname")
|
||||||
port = self.proxy.get("port")
|
port = self.proxy.get("port")
|
||||||
username = self.proxy.get("username")
|
username = self.proxy.get("username")
|
||||||
password = self.proxy.get("password")
|
password = self.proxy.get("password")
|
||||||
|
|
||||||
|
# Determine proxy address family
|
||||||
try:
|
try:
|
||||||
ip_address = ipaddress.ip_address(hostname)
|
ip_address = ipaddress.ip_address(hostname)
|
||||||
|
proxy_family = (
|
||||||
|
socket.AF_INET6
|
||||||
|
if isinstance(ip_address, ipaddress.IPv6Address)
|
||||||
|
else socket.AF_INET
|
||||||
|
)
|
||||||
except ValueError:
|
except ValueError:
|
||||||
is_proxy_ipv6 = False
|
proxy_family = socket.AF_INET
|
||||||
else:
|
|
||||||
is_proxy_ipv6 = isinstance(ip_address, ipaddress.IPv6Address)
|
|
||||||
|
|
||||||
proxy_family = socket.AF_INET6 if is_proxy_ipv6 else socket.AF_INET
|
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Connect to proxy server
|
# Connect to proxy server
|
||||||
self.reader, self.writer = await asyncio.wait_for(
|
self.reader, self.writer = await asyncio.wait_for(
|
||||||
asyncio.open_connection(
|
asyncio.open_connection(host=hostname, port=port, family=proxy_family),
|
||||||
host=hostname,
|
timeout=self.TIMEOUT,
|
||||||
port=port,
|
|
||||||
family=proxy_family
|
|
||||||
),
|
|
||||||
timeout=self.TIMEOUT
|
|
||||||
)
|
)
|
||||||
|
|
||||||
# Perform proxy handshake
|
# Perform proxy handshake
|
||||||
if proxy_type == socks.SOCKS5:
|
handler = self._get_proxy_handler(scheme)
|
||||||
await self._socks5_handshake(self.writer, self.reader, destination, username, password)
|
if scheme.upper() == "SOCKS5":
|
||||||
elif proxy_type == socks.SOCKS4:
|
await handler.handshake(
|
||||||
await self._socks4_handshake(self.writer, self.reader, destination)
|
self.writer,
|
||||||
elif proxy_type == socks.HTTP:
|
self.reader,
|
||||||
await self._http_proxy_handshake(self.writer, self.reader, destination, username, password)
|
destination,
|
||||||
else:
|
username,
|
||||||
raise ValueError(f"Unsupported proxy type: {scheme}")
|
password,
|
||||||
|
self.TIMEOUT,
|
||||||
|
)
|
||||||
|
elif scheme.upper() == "SOCKS4":
|
||||||
|
await handler.handshake(
|
||||||
|
self.writer, self.reader, destination, self.TIMEOUT
|
||||||
|
)
|
||||||
|
elif scheme.upper() == "HTTP":
|
||||||
|
await handler.handshake(
|
||||||
|
self.writer,
|
||||||
|
self.reader,
|
||||||
|
destination,
|
||||||
|
username,
|
||||||
|
password,
|
||||||
|
self.TIMEOUT,
|
||||||
|
)
|
||||||
|
|
||||||
except Exception:
|
except Exception:
|
||||||
if self.writer:
|
if self.writer:
|
||||||
self.writer.close()
|
self.writer.close()
|
||||||
|
|
@ -235,27 +378,29 @@ class TCP:
|
||||||
raise
|
raise
|
||||||
|
|
||||||
async def _connect_via_direct(self, destination: Tuple[str, int]) -> None:
|
async def _connect_via_direct(self, destination: Tuple[str, int]) -> None:
|
||||||
|
"""Connect directly to destination."""
|
||||||
host, port = destination
|
host, port = destination
|
||||||
family = socket.AF_INET6 if self.ipv6 else socket.AF_INET
|
family = socket.AF_INET6 if self.ipv6 else socket.AF_INET
|
||||||
self.reader, self.writer = await asyncio.open_connection(
|
self.reader, self.writer = await asyncio.open_connection(
|
||||||
host=host,
|
host=host, port=port, family=family
|
||||||
port=port,
|
|
||||||
family=family
|
|
||||||
)
|
)
|
||||||
|
|
||||||
async def _connect(self, destination: Tuple[str, int]) -> None:
|
async def _connect(self, destination: Tuple[str, int]) -> None:
|
||||||
|
"""Establish connection (direct or via proxy)."""
|
||||||
if self.proxy:
|
if self.proxy:
|
||||||
await self._connect_via_proxy(destination)
|
await self._connect_via_proxy(destination)
|
||||||
else:
|
else:
|
||||||
await self._connect_via_direct(destination)
|
await self._connect_via_direct(destination)
|
||||||
|
|
||||||
async def connect(self, address: Tuple[str, int]) -> None:
|
async def connect(self, address: Tuple[str, int]) -> None:
|
||||||
|
"""Connect to the specified address."""
|
||||||
try:
|
try:
|
||||||
await asyncio.wait_for(self._connect(address), self.TIMEOUT)
|
await asyncio.wait_for(self._connect(address), self.TIMEOUT)
|
||||||
except asyncio.TimeoutError:
|
except asyncio.TimeoutError:
|
||||||
raise TimeoutError("Connection timed out")
|
raise TimeoutError("Connection timed out")
|
||||||
|
|
||||||
async def close(self) -> None:
|
async def close(self) -> None:
|
||||||
|
"""Close the connection."""
|
||||||
if self.writer is None:
|
if self.writer is None:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
@ -266,6 +411,7 @@ class TCP:
|
||||||
pass
|
pass
|
||||||
|
|
||||||
async def send(self, data: bytes) -> None:
|
async def send(self, data: bytes) -> None:
|
||||||
|
"""Send data through the connection."""
|
||||||
if self.writer is None:
|
if self.writer is None:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
@ -277,13 +423,13 @@ class TCP:
|
||||||
raise OSError(e)
|
raise OSError(e)
|
||||||
|
|
||||||
async def recv(self, length: int = 0) -> Optional[bytes]:
|
async def recv(self, length: int = 0) -> Optional[bytes]:
|
||||||
|
"""Receive data from the connection."""
|
||||||
data = b""
|
data = b""
|
||||||
|
|
||||||
while len(data) < length:
|
while len(data) < length:
|
||||||
try:
|
try:
|
||||||
chunk = await asyncio.wait_for(
|
chunk = await asyncio.wait_for(
|
||||||
self.reader.read(length - len(data)),
|
self.reader.read(length - len(data)), self.TIMEOUT
|
||||||
self.TIMEOUT
|
|
||||||
)
|
)
|
||||||
except (OSError, asyncio.TimeoutError):
|
except (OSError, asyncio.TimeoutError):
|
||||||
return None
|
return None
|
||||||
|
|
@ -293,4 +439,4 @@ class TCP:
|
||||||
else:
|
else:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
return data
|
return data
|
||||||
|
|
|
||||||
377
pyrogram/connection/transport/tcp/test.py
Normal file
377
pyrogram/connection/transport/tcp/test.py
Normal file
|
|
@ -0,0 +1,377 @@
|
||||||
|
#!/usr/bin/env python3
|
||||||
|
"""
|
||||||
|
TCP Implementation Diagnostic Tool for Pyrofork
|
||||||
|
Deep analysis to identify specific issues in TCP transport
|
||||||
|
"""
|
||||||
|
|
||||||
|
import asyncio
|
||||||
|
import sys
|
||||||
|
import time
|
||||||
|
import traceback
|
||||||
|
import socket
|
||||||
|
import ssl
|
||||||
|
import os
|
||||||
|
from typing import Dict, Any, List, Tuple, Optional
|
||||||
|
import logging
|
||||||
|
|
||||||
|
# Setup path for local imports
|
||||||
|
current_dir = os.path.dirname(os.path.abspath(__file__))
|
||||||
|
root_dir = os.path.abspath(os.path.join(current_dir, '../../../../../'))
|
||||||
|
sys.path.insert(0, root_dir)
|
||||||
|
|
||||||
|
# Import TCP implementation
|
||||||
|
try:
|
||||||
|
from pyrogram.connection.transport.tcp import TCP
|
||||||
|
print("Successfully imported TCP from local source")
|
||||||
|
except ImportError as e:
|
||||||
|
print(f"Failed to import TCP: {e}")
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
# Configure detailed logging
|
||||||
|
logging.basicConfig(
|
||||||
|
level=logging.DEBUG,
|
||||||
|
format="%(asctime)s - %(name)s - %(levelname)s - %(message)s",
|
||||||
|
datefmt="%H:%M:%S"
|
||||||
|
)
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
class TCPDiagnostics:
|
||||||
|
"""Diagnostic tools for TCP implementation analysis"""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.telegram_servers = [
|
||||||
|
("149.154.175.50", 443, "DC1-Miami"),
|
||||||
|
("149.154.167.51", 443, "DC2-Amsterdam"),
|
||||||
|
("149.154.175.100", 443, "DC3-Miami"),
|
||||||
|
("149.154.167.91", 443, "DC4-Amsterdam"),
|
||||||
|
("91.108.56.130", 443, "DC5-Singapore"),
|
||||||
|
]
|
||||||
|
|
||||||
|
async def test_raw_socket_connection(self, host: str, port: int, name: str) -> Dict[str, Any]:
|
||||||
|
"""Test raw socket connection without Pyrofork TCP wrapper"""
|
||||||
|
print(f"\n[RAW SOCKET] Testing {name} ({host}:{port})")
|
||||||
|
|
||||||
|
start_time = time.time()
|
||||||
|
sock = None
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Create raw socket
|
||||||
|
sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
|
||||||
|
sock.settimeout(10.0)
|
||||||
|
|
||||||
|
# Test connection
|
||||||
|
sock.connect((host, port))
|
||||||
|
connect_time = time.time() - start_time
|
||||||
|
|
||||||
|
# Test basic send/receive
|
||||||
|
test_data = b'GET / HTTP/1.1\r\nHost: ' + host.encode() + b'\r\n\r\n'
|
||||||
|
sock.send(test_data)
|
||||||
|
|
||||||
|
# Try to receive response
|
||||||
|
try:
|
||||||
|
response = sock.recv(1024)
|
||||||
|
received_data = len(response) > 0
|
||||||
|
except socket.timeout:
|
||||||
|
received_data = False
|
||||||
|
|
||||||
|
sock.close()
|
||||||
|
sock = None
|
||||||
|
|
||||||
|
result = {
|
||||||
|
"success": True,
|
||||||
|
"connect_time": connect_time,
|
||||||
|
"received_data": received_data,
|
||||||
|
"error": None
|
||||||
|
}
|
||||||
|
|
||||||
|
print(f"[RAW SOCKET] {name}: SUCCESS - {connect_time:.2f}s")
|
||||||
|
return result
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
if sock:
|
||||||
|
sock.close()
|
||||||
|
|
||||||
|
error_time = time.time() - start_time
|
||||||
|
print(f"[RAW SOCKET] {name}: FAILED - {str(e)} ({error_time:.2f}s)")
|
||||||
|
|
||||||
|
return {
|
||||||
|
"success": False,
|
||||||
|
"connect_time": error_time,
|
||||||
|
"received_data": False,
|
||||||
|
"error": str(e)
|
||||||
|
}
|
||||||
|
|
||||||
|
async def test_tcp_implementation(self, host: str, port: int, name: str) -> Dict[str, Any]:
|
||||||
|
"""Test Pyrofork TCP implementation with detailed logging"""
|
||||||
|
print(f"\n[PYROFORK TCP] Testing {name} ({host}:{port})")
|
||||||
|
|
||||||
|
start_time = time.time()
|
||||||
|
tcp = None
|
||||||
|
|
||||||
|
try:
|
||||||
|
# Initialize TCP transport
|
||||||
|
print(f"[PYROFORK TCP] Initializing TCP transport...")
|
||||||
|
tcp = TCP(ipv6=False, proxy=None)
|
||||||
|
|
||||||
|
print(f"[PYROFORK TCP] TCP object created: {type(tcp)}")
|
||||||
|
print(f"[PYROFORK TCP] TCP attributes: {dir(tcp)}")
|
||||||
|
|
||||||
|
# Attempt connection
|
||||||
|
print(f"[PYROFORK TCP] Attempting connection...")
|
||||||
|
await asyncio.wait_for(tcp.connect((host, port)), timeout=10.0)
|
||||||
|
|
||||||
|
connect_time = time.time() - start_time
|
||||||
|
print(f"[PYROFORK TCP] Connection established in {connect_time:.2f}s")
|
||||||
|
|
||||||
|
# Test send operation
|
||||||
|
print(f"[PYROFORK TCP] Testing send operation...")
|
||||||
|
test_data = b'\x00\x00\x00\x00\x00\x00\x00\x00'
|
||||||
|
await tcp.send(test_data)
|
||||||
|
print(f"[PYROFORK TCP] Send operation successful")
|
||||||
|
|
||||||
|
# Test receive operation
|
||||||
|
print(f"[PYROFORK TCP] Testing receive operation...")
|
||||||
|
try:
|
||||||
|
response = await asyncio.wait_for(tcp.recv(8), timeout=3.0)
|
||||||
|
received_data = len(response) > 0 if response else False
|
||||||
|
print(f"[PYROFORK TCP] Received {len(response) if response else 0} bytes")
|
||||||
|
except asyncio.TimeoutError:
|
||||||
|
received_data = False
|
||||||
|
print(f"[PYROFORK TCP] Receive timeout (normal for test data)")
|
||||||
|
|
||||||
|
# Clean close
|
||||||
|
print(f"[PYROFORK TCP] Closing connection...")
|
||||||
|
await tcp.close()
|
||||||
|
tcp = None
|
||||||
|
print(f"[PYROFORK TCP] Connection closed successfully")
|
||||||
|
|
||||||
|
result = {
|
||||||
|
"success": True,
|
||||||
|
"connect_time": connect_time,
|
||||||
|
"received_data": received_data,
|
||||||
|
"error": None
|
||||||
|
}
|
||||||
|
|
||||||
|
print(f"[PYROFORK TCP] {name}: SUCCESS - {connect_time:.2f}s")
|
||||||
|
return result
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
error_time = time.time() - start_time
|
||||||
|
print(f"[PYROFORK TCP] {name}: FAILED - {str(e)}")
|
||||||
|
print(f"[PYROFORK TCP] Exception details:")
|
||||||
|
traceback.print_exc()
|
||||||
|
|
||||||
|
# Cleanup
|
||||||
|
if tcp:
|
||||||
|
try:
|
||||||
|
await tcp.close()
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
|
||||||
|
return {
|
||||||
|
"success": False,
|
||||||
|
"connect_time": error_time,
|
||||||
|
"received_data": False,
|
||||||
|
"error": str(e)
|
||||||
|
}
|
||||||
|
|
||||||
|
async def inspect_tcp_implementation(self):
|
||||||
|
"""Inspect TCP class implementation details"""
|
||||||
|
print(f"\n[INSPECTION] Analyzing TCP implementation...")
|
||||||
|
|
||||||
|
try:
|
||||||
|
tcp = TCP(ipv6=False, proxy=None)
|
||||||
|
|
||||||
|
print(f"[INSPECTION] TCP class: {TCP}")
|
||||||
|
print(f"[INSPECTION] TCP module: {TCP.__module__}")
|
||||||
|
print(f"[INSPECTION] TCP file: {TCP.__module__.replace('.', os.sep)}.py")
|
||||||
|
|
||||||
|
# Check methods
|
||||||
|
methods = [attr for attr in dir(tcp) if not attr.startswith('_')]
|
||||||
|
print(f"[INSPECTION] Available methods: {methods}")
|
||||||
|
|
||||||
|
# Check for required methods
|
||||||
|
required_methods = ['connect', 'send', 'recv', 'close']
|
||||||
|
for method in required_methods:
|
||||||
|
if hasattr(tcp, method):
|
||||||
|
method_obj = getattr(tcp, method)
|
||||||
|
print(f"[INSPECTION] {method}: {method_obj} (callable: {callable(method_obj)})")
|
||||||
|
else:
|
||||||
|
print(f"[INSPECTION] MISSING METHOD: {method}")
|
||||||
|
|
||||||
|
# Check initialization parameters
|
||||||
|
print(f"[INSPECTION] TCP instance attributes:")
|
||||||
|
for attr in dir(tcp):
|
||||||
|
if not attr.startswith('__'):
|
||||||
|
try:
|
||||||
|
value = getattr(tcp, attr)
|
||||||
|
if not callable(value):
|
||||||
|
print(f"[INSPECTION] {attr}: {value}")
|
||||||
|
except:
|
||||||
|
print(f"[INSPECTION] {attr}: <unable to access>")
|
||||||
|
|
||||||
|
except Exception as e:
|
||||||
|
print(f"[INSPECTION] Error during inspection: {e}")
|
||||||
|
traceback.print_exc()
|
||||||
|
|
||||||
|
async def test_concurrent_simple(self, host: str, port: int, count: int = 3) -> Dict[str, Any]:
|
||||||
|
"""Simple concurrent connection test"""
|
||||||
|
print(f"\n[CONCURRENT] Testing {count} concurrent connections to {host}:{port}")
|
||||||
|
|
||||||
|
async def single_connect():
|
||||||
|
try:
|
||||||
|
tcp = TCP(ipv6=False, proxy=None)
|
||||||
|
await asyncio.wait_for(tcp.connect((host, port)), timeout=5.0)
|
||||||
|
await tcp.close()
|
||||||
|
return True
|
||||||
|
except Exception as e:
|
||||||
|
print(f"[CONCURRENT] Connection failed: {str(e)[:50]}")
|
||||||
|
return False
|
||||||
|
|
||||||
|
start_time = time.time()
|
||||||
|
tasks = [single_connect() for _ in range(count)]
|
||||||
|
results = await asyncio.gather(*tasks, return_exceptions=True)
|
||||||
|
total_time = time.time() - start_time
|
||||||
|
|
||||||
|
successful = sum(1 for r in results if r is True)
|
||||||
|
success_rate = (successful / count) * 100
|
||||||
|
|
||||||
|
print(f"[CONCURRENT] Results: {successful}/{count} successful ({success_rate:.1f}%) in {total_time:.2f}s")
|
||||||
|
|
||||||
|
return {
|
||||||
|
"successful": successful,
|
||||||
|
"total": count,
|
||||||
|
"success_rate": success_rate,
|
||||||
|
"total_time": total_time
|
||||||
|
}
|
||||||
|
|
||||||
|
async def run_diagnostic_suite(self):
|
||||||
|
"""Run comprehensive diagnostic tests"""
|
||||||
|
print("="*80)
|
||||||
|
print("TCP IMPLEMENTATION DIAGNOSTIC SUITE")
|
||||||
|
print("="*80)
|
||||||
|
|
||||||
|
# Phase 1: Inspect implementation
|
||||||
|
await self.inspect_tcp_implementation()
|
||||||
|
|
||||||
|
# Phase 2: Raw socket tests (baseline)
|
||||||
|
print(f"\n" + "="*50)
|
||||||
|
print("PHASE 1: RAW SOCKET BASELINE TESTS")
|
||||||
|
print("="*50)
|
||||||
|
|
||||||
|
raw_results = []
|
||||||
|
for host, port, name in self.telegram_servers:
|
||||||
|
result = await self.test_raw_socket_connection(host, port, name)
|
||||||
|
raw_results.append((name, result))
|
||||||
|
await asyncio.sleep(0.5)
|
||||||
|
|
||||||
|
# Phase 3: Pyrofork TCP tests
|
||||||
|
print(f"\n" + "="*50)
|
||||||
|
print("PHASE 2: PYROFORK TCP IMPLEMENTATION TESTS")
|
||||||
|
print("="*50)
|
||||||
|
|
||||||
|
tcp_results = []
|
||||||
|
for host, port, name in self.telegram_servers:
|
||||||
|
result = await self.test_tcp_implementation(host, port, name)
|
||||||
|
tcp_results.append((name, result))
|
||||||
|
await asyncio.sleep(1.0)
|
||||||
|
|
||||||
|
# Phase 4: Concurrent test
|
||||||
|
print(f"\n" + "="*50)
|
||||||
|
print("PHASE 3: CONCURRENT CONNECTION TEST")
|
||||||
|
print("="*50)
|
||||||
|
|
||||||
|
concurrent_result = await self.test_concurrent_simple(
|
||||||
|
self.telegram_servers[0][0],
|
||||||
|
self.telegram_servers[0][1]
|
||||||
|
)
|
||||||
|
|
||||||
|
# Analysis and recommendations
|
||||||
|
print(f"\n" + "="*50)
|
||||||
|
print("DIAGNOSTIC ANALYSIS")
|
||||||
|
print("="*50)
|
||||||
|
|
||||||
|
raw_success = sum(1 for _, r in raw_results if r["success"])
|
||||||
|
tcp_success = sum(1 for _, r in tcp_results if r["success"])
|
||||||
|
|
||||||
|
print(f"\nResults Summary:")
|
||||||
|
print(f" Raw Socket Success: {raw_success}/{len(raw_results)} ({raw_success/len(raw_results)*100:.1f}%)")
|
||||||
|
print(f" Pyrofork TCP Success: {tcp_success}/{len(tcp_results)} ({tcp_success/len(tcp_results)*100:.1f}%)")
|
||||||
|
print(f" Concurrent Success: {concurrent_result['success_rate']:.1f}%")
|
||||||
|
|
||||||
|
print(f"\nDiagnostic Findings:")
|
||||||
|
|
||||||
|
if raw_success < len(raw_results) * 0.8:
|
||||||
|
print(" - Network connectivity issues detected")
|
||||||
|
print(" - Check firewall, DNS resolution, or internet connection")
|
||||||
|
else:
|
||||||
|
print(" - Network connectivity is good (raw sockets work)")
|
||||||
|
|
||||||
|
if tcp_success < raw_success:
|
||||||
|
print(" - Pyrofork TCP implementation has issues")
|
||||||
|
print(" - TCP wrapper is failing where raw sockets succeed")
|
||||||
|
print(" - Likely issues: async/await implementation, connection state management")
|
||||||
|
elif tcp_success == raw_success:
|
||||||
|
print(" - Pyrofork TCP implementation matches raw socket performance")
|
||||||
|
print(" - Implementation appears correct")
|
||||||
|
|
||||||
|
if concurrent_result['success_rate'] > tcp_success / len(tcp_results) * 100:
|
||||||
|
print(" - Concurrent connections work better than sequential")
|
||||||
|
print(" - Possible timing or state management issue in sequential testing")
|
||||||
|
|
||||||
|
# Specific recommendations
|
||||||
|
print(f"\nRecommendations:")
|
||||||
|
|
||||||
|
if tcp_success < raw_success:
|
||||||
|
print(" 1. Review TCP.connect() implementation for async/await correctness")
|
||||||
|
print(" 2. Check socket state management and cleanup")
|
||||||
|
print(" 3. Verify timeout handling in async operations")
|
||||||
|
print(" 4. Test with simpler connection sequence")
|
||||||
|
|
||||||
|
if raw_success < len(raw_results):
|
||||||
|
print(" 1. Check network connectivity to Telegram servers")
|
||||||
|
print(" 2. Verify DNS resolution")
|
||||||
|
print(" 3. Check for firewall blocking connections")
|
||||||
|
|
||||||
|
return {
|
||||||
|
"raw_results": raw_results,
|
||||||
|
"tcp_results": tcp_results,
|
||||||
|
"concurrent_result": concurrent_result
|
||||||
|
}
|
||||||
|
|
||||||
|
|
||||||
|
async def main():
|
||||||
|
"""Main diagnostic execution"""
|
||||||
|
try:
|
||||||
|
diagnostics = TCPDiagnostics()
|
||||||
|
results = await diagnostics.run_diagnostic_suite()
|
||||||
|
|
||||||
|
# Determine if ready for debugging or needs environment fixes
|
||||||
|
tcp_success_rate = sum(1 for _, r in results["tcp_results"] if r["success"]) / len(results["tcp_results"])
|
||||||
|
raw_success_rate = sum(1 for _, r in results["raw_results"] if r["success"]) / len(results["raw_results"])
|
||||||
|
|
||||||
|
if raw_success_rate < 0.8:
|
||||||
|
print(f"\nStatus: Environment issues detected - fix network connectivity first")
|
||||||
|
sys.exit(2)
|
||||||
|
elif tcp_success_rate < 0.5:
|
||||||
|
print(f"\nStatus: TCP implementation issues detected - needs code review")
|
||||||
|
sys.exit(1)
|
||||||
|
else:
|
||||||
|
print(f"\nStatus: Implementation appears functional - investigate edge cases")
|
||||||
|
sys.exit(0)
|
||||||
|
|
||||||
|
except KeyboardInterrupt:
|
||||||
|
print("\nDiagnostic interrupted by user")
|
||||||
|
sys.exit(130)
|
||||||
|
except Exception as e:
|
||||||
|
print(f"\nUnexpected error in diagnostic: {str(e)}")
|
||||||
|
traceback.print_exc()
|
||||||
|
sys.exit(1)
|
||||||
|
|
||||||
|
|
||||||
|
if __name__ == "__main__":
|
||||||
|
print("TCP Implementation Diagnostic Tool")
|
||||||
|
print("Deep analysis of Pyrofork TCP transport issues")
|
||||||
|
print("-" * 60)
|
||||||
|
|
||||||
|
asyncio.run(main())
|
||||||
Loading…
Reference in a new issue