Start refactoring Connection to accommodate asyncio

This commit is contained in:
Dan 2018-06-08 13:10:07 +02:00
parent 244b4f15ce
commit de39c181ef

View file

@ -16,9 +16,8 @@
# You should have received a copy of the GNU Lesser General Public License # You should have received a copy of the GNU Lesser General Public License
# along with Pyrogram. If not, see <http://www.gnu.org/licenses/>. # along with Pyrogram. If not, see <http://www.gnu.org/licenses/>.
import asyncio
import logging import logging
import threading
import time
from .transport import * from .transport import *
@ -36,23 +35,23 @@ class Connection:
4: TCPIntermediateO 4: TCPIntermediateO
} }
def __init__(self, address: tuple, proxy: dict, mode: int = 1): def __init__(self, address: tuple, proxy: dict, mode: int = 2):
self.address = address self.address = address
self.proxy = proxy self.proxy = proxy
self.mode = self.MODES.get(mode, TCPAbridged) self.mode = self.MODES.get(mode, TCPAbridged)
self.lock = threading.Lock()
self.connection = None self.connection = None
def connect(self): async def connect(self):
for i in range(Connection.MAX_RETRIES): for i in range(Connection.MAX_RETRIES):
self.connection = self.mode(self.proxy) self.connection = self.mode(self.proxy)
try: try:
log.info("Connecting...") log.info("Connecting...")
self.connection.connect(self.address) await self.connection.connect(self.address)
except OSError: except OSError:
self.connection.close() self.connection.close()
time.sleep(1) await asyncio.sleep(1)
else: else:
break break
else: else:
@ -62,9 +61,8 @@ class Connection:
self.connection.close() self.connection.close()
log.info("Disconnected") log.info("Disconnected")
def send(self, data: bytes): async def send(self, data: bytes):
with self.lock: await self.connection.send(data)
self.connection.sendall(data)
def recv(self) -> bytes or None: async def recv(self) -> bytes or None:
return self.connection.recvall() return await self.connection.recv()