From 6fcf41d8572e5d64d31ae97f3b40dccc232ab5bd Mon Sep 17 00:00:00 2001
From: Dan <14043624+delivrance@users.noreply.github.com>
Date: Wed, 20 Jun 2018 11:41:22 +0200
Subject: [PATCH] Client becomes async
---
pyrogram/client/client.py | 276 +++++++++++++++++---------------------
1 file changed, 126 insertions(+), 150 deletions(-)
diff --git a/pyrogram/client/client.py b/pyrogram/client/client.py
index 8eba760a..a3197d0c 100644
--- a/pyrogram/client/client.py
+++ b/pyrogram/client/client.py
@@ -16,6 +16,7 @@
# You should have received a copy of the GNU Lesser General Public License
# along with Pyrogram. If not, see .
+import asyncio
import base64
import binascii
import getpass
@@ -28,7 +29,6 @@ import re
import shutil
import struct
import tempfile
-import threading
import time
from configparser import ConfigParser
from datetime import datetime
@@ -43,11 +43,11 @@ from pyrogram.api.errors import (
PhoneCodeExpired, PhoneCodeEmpty, SessionPasswordNeeded,
PasswordHashInvalid, FloodWait, PeerIdInvalid, FirstnameInvalid, PhoneNumberBanned,
VolumeLocNotFound, UserMigrate, FileIdInvalid)
-from pyrogram.client.handlers import DisconnectHandler
from pyrogram.crypto import AES
from pyrogram.session import Auth, Session
from .dispatcher import Dispatcher
-from .ext import utils, Syncer, BaseClient
+from .ext import BaseClient, Syncer, utils
+from .handlers import DisconnectHandler
from .methods import Methods
# Custom format for nice looking log lines
@@ -114,7 +114,7 @@ class Client(Methods, BaseClient):
be an empty string: "". Only applicable for new sessions.
workers (``int``, *optional*):
- Thread pool size for handling incoming updates. Defaults to 4.
+ Number of maximum concurrent workers for handling incoming updates. Defaults to 4.
workdir (``str``, *optional*):
Define a custom working directory. The working directory is the location in your filesystem
@@ -168,15 +168,10 @@ class Client(Methods, BaseClient):
self._proxy["enabled"] = True
self._proxy.update(value)
- async def start(self, debug: bool = False):
+ async def start(self):
"""Use this method to start the Client after creating it.
Requires no parameters.
- Args:
- debug (``bool``, *optional*):
- Enable or disable debug mode. When enabled, extra logging
- lines will be printed out on your console.
-
Raises:
:class:`Error `
"""
@@ -188,7 +183,7 @@ class Client(Methods, BaseClient):
self.session_name = self.session_name.split(":")[0]
self.load_config()
- self.load_session()
+ await self.load_session()
self.session = Session(
self.dc_id,
@@ -204,9 +199,9 @@ class Client(Methods, BaseClient):
if self.user_id is None:
if self.token is None:
- self.authorize_user()
+ await self.authorize_user()
else:
- self.authorize_bot()
+ await self.authorize_bot()
self.save_session()
@@ -217,38 +212,27 @@ class Client(Methods, BaseClient):
self.peers_by_username = {}
self.peers_by_phone = {}
- self.get_dialogs()
- self.get_contacts()
+ await self.get_dialogs()
+ await self.get_contacts()
else:
- self.send(functions.messages.GetPinnedDialogs())
- self.get_dialogs_chunk(0)
+ await self.send(functions.messages.GetPinnedDialogs())
+ await self.get_dialogs_chunk(0)
else:
await self.send(functions.updates.GetState())
- # for i in range(self.UPDATES_WORKERS):
- # self.updates_workers_list.append(
- # Thread(
- # target=self.updates_worker,
- # name="UpdatesWorker#{}".format(i + 1)
- # )
- # )
- #
- # self.updates_workers_list[-1].start()
- #
- # for i in range(self.DOWNLOAD_WORKERS):
- # self.download_workers_list.append(
- # Thread(
- # target=self.download_worker,
- # name="DownloadWorker#{}".format(i + 1)
- # )
- # )
- #
- # self.download_workers_list[-1].start()
- #
- # self.dispatcher.start()
+ self.updates_worker_task = asyncio.ensure_future(self.updates_worker())
+
+ for _ in range(Client.DOWNLOAD_WORKERS):
+ self.download_worker_tasks.append(
+ asyncio.ensure_future(self.download_worker())
+ )
+
+ log.info("Started {} DownloadWorkerTasks".format(Client.DOWNLOAD_WORKERS))
+
+ await self.dispatcher.start()
+ await Syncer.add(self)
mimetypes.init()
- # Syncer.add(self)
async def stop(self):
"""Use this method to manually stop the Client.
@@ -257,29 +241,26 @@ class Client(Methods, BaseClient):
if not self.is_started:
raise ConnectionError("Client is already stopped")
- # Syncer.remove(self)
- # self.dispatcher.stop()
- #
- # for _ in range(self.DOWNLOAD_WORKERS):
- # self.download_queue.put(None)
- #
- # for i in self.download_workers_list:
- # i.join()
- #
- # self.download_workers_list.clear()
- #
- # for _ in range(self.UPDATES_WORKERS):
- # self.updates_queue.put(None)
- #
- # for i in self.updates_workers_list:
- # i.join()
- #
- # self.updates_workers_list.clear()
- #
- # for i in self.media_sessions.values():
- # i.stop()
- #
- # self.media_sessions.clear()
+ await Syncer.remove(self)
+ await self.dispatcher.stop()
+
+ for _ in range(Client.DOWNLOAD_WORKERS):
+ self.download_queue.put_nowait(None)
+
+ for task in self.download_worker_tasks:
+ await task
+
+ self.download_worker_tasks.clear()
+
+ log.info("Stopped {} DownloadWorkerTasks".format(Client.DOWNLOAD_WORKERS))
+
+ self.updates_queue.put_nowait(None)
+ await self.updates_worker_task
+
+ for media_session in self.media_sessions.values():
+ await media_session.stop()
+
+ self.media_sessions.clear()
self.is_started = False
await self.session.stop()
@@ -327,9 +308,9 @@ class Client(Methods, BaseClient):
else:
self.dispatcher.remove_handler(handler, group)
- def authorize_bot(self):
+ async def authorize_bot(self):
try:
- r = self.send(
+ r = await self.send(
functions.auth.ImportBotAuthorization(
flags=0,
api_id=self.api_id,
@@ -338,10 +319,10 @@ class Client(Methods, BaseClient):
)
)
except UserMigrate as e:
- self.session.stop()
+ await self.session.stop()
self.dc_id = e.x
- self.auth_key = Auth(self.dc_id, self.test_mode, self._proxy).create()
+ self.auth_key = await Auth(self.dc_id, self.test_mode, self._proxy).create()
self.session = Session(
self.dc_id,
@@ -352,12 +333,12 @@ class Client(Methods, BaseClient):
client=self
)
- self.session.start()
- self.authorize_bot()
+ await self.session.start()
+ await self.authorize_bot()
else:
self.user_id = r.user.id
- def authorize_user(self):
+ async def authorize_user(self):
phone_number_invalid_raises = self.phone_number is not None
phone_code_invalid_raises = self.phone_code is not None
password_hash_invalid_raises = self.password is not None
@@ -378,7 +359,7 @@ class Client(Methods, BaseClient):
self.phone_number = self.phone_number.strip("+")
try:
- r = self.send(
+ r = await self.send(
functions.auth.SendCode(
self.phone_number,
self.api_id,
@@ -386,10 +367,10 @@ class Client(Methods, BaseClient):
)
)
except (PhoneMigrate, NetworkMigrate) as e:
- self.session.stop()
+ await self.session.stop()
self.dc_id = e.x
- self.auth_key = Auth(self.dc_id, self.test_mode, self._proxy).create()
+ self.auth_key = await Auth(self.dc_id, self.test_mode, self._proxy).create()
self.session = Session(
self.dc_id,
@@ -399,9 +380,9 @@ class Client(Methods, BaseClient):
self.api_id,
client=self
)
- self.session.start()
+ await self.session.start()
- r = self.send(
+ r = await self.send(
functions.auth.SendCode(
self.phone_number,
self.api_id,
@@ -430,7 +411,7 @@ class Client(Methods, BaseClient):
phone_code_hash = r.phone_code_hash
if self.force_sms:
- self.send(
+ await self.send(
functions.auth.ResendCode(
phone_number=self.phone_number,
phone_code_hash=phone_code_hash
@@ -446,7 +427,7 @@ class Client(Methods, BaseClient):
try:
if phone_registered:
- r = self.send(
+ r = await self.send(
functions.auth.SignIn(
self.phone_number,
phone_code_hash,
@@ -455,7 +436,7 @@ class Client(Methods, BaseClient):
)
else:
try:
- self.send(
+ await self.send(
functions.auth.SignIn(
self.phone_number,
phone_code_hash,
@@ -468,7 +449,7 @@ class Client(Methods, BaseClient):
self.first_name = self.first_name if self.first_name is not None else input("First name: ")
self.last_name = self.last_name if self.last_name is not None else input("Last name: ")
- r = self.send(
+ r = await self.send(
functions.auth.SignUp(
self.phone_number,
phone_code_hash,
@@ -491,7 +472,7 @@ class Client(Methods, BaseClient):
self.first_name = None
except SessionPasswordNeeded as e:
print(e.MESSAGE)
- r = self.send(functions.account.GetPassword())
+ r = await self.send(functions.account.GetPassword())
while True:
try:
@@ -505,7 +486,7 @@ class Client(Methods, BaseClient):
password_hash = sha256(self.password).digest()
- r = self.send(functions.auth.CheckPassword(password_hash))
+ r = await self.send(functions.auth.CheckPassword(password_hash))
except PasswordHashInvalid as e:
if password_hash_invalid_raises:
raise
@@ -594,12 +575,9 @@ class Client(Methods, BaseClient):
if username is not None:
self.peers_by_username[username.lower()] = input_peer
- def download_worker(self):
- name = threading.current_thread().name
- log.debug("{} started".format(name))
-
+ async def download_worker(self):
while True:
- media = self.download_queue.get()
+ media = await self.download_queue.get()
if media is None:
break
@@ -666,7 +644,7 @@ class Client(Methods, BaseClient):
extension
)
- temp_file_path = self.get_file(
+ temp_file_path = await self.get_file(
dc_id=dc_id,
id=id,
access_hash=access_hash,
@@ -697,14 +675,11 @@ class Client(Methods, BaseClient):
finally:
done.set()
- log.debug("{} stopped".format(name))
-
- def updates_worker(self):
- name = threading.current_thread().name
- log.debug("{} started".format(name))
+ async def updates_worker(self):
+ log.info("UpdatesWorkerTask started")
while True:
- updates = self.updates_queue.get()
+ updates = await self.updates_queue.get()
if updates is None:
break
@@ -730,9 +705,9 @@ class Client(Methods, BaseClient):
message = update.message
if not isinstance(message, types.MessageEmpty):
- diff = self.send(
+ diff = await self.send(
functions.updates.GetChannelDifference(
- channel=self.resolve_peer(int("-100" + str(channel_id))),
+ channel=await self.resolve_peer(int("-100" + str(channel_id))),
filter=types.ChannelMessagesFilter(
ranges=[types.MessageRange(
min_id=update.message.id,
@@ -760,9 +735,9 @@ class Client(Methods, BaseClient):
if len(self.channels_pts[channel_id]) > 50:
self.channels_pts[channel_id] = self.channels_pts[channel_id][25:]
- self.dispatcher.updates.put((update, updates.users, updates.chats))
+ self.dispatcher.updates.put_nowait((update, updates.users, updates.chats))
elif isinstance(updates, (types.UpdateShortMessage, types.UpdateShortChatMessage)):
- diff = self.send(
+ diff = await self.send(
functions.updates.GetDifference(
pts=updates.pts - updates.pts_count,
date=updates.date,
@@ -771,7 +746,7 @@ class Client(Methods, BaseClient):
)
if diff.new_messages:
- self.dispatcher.updates.put((
+ self.dispatcher.updates.put_nowait((
types.UpdateNewMessage(
message=diff.new_messages[0],
pts=updates.pts,
@@ -781,18 +756,19 @@ class Client(Methods, BaseClient):
diff.chats
))
else:
- self.dispatcher.updates.put((diff.other_updates[0], [], []))
+ self.dispatcher.updates.put_nowait((diff.other_updates[0], [], []))
elif isinstance(updates, types.UpdateShort):
- self.dispatcher.updates.put((updates.update, [], []))
+ self.dispatcher.updates.put_nowait((updates.update, [], []))
except Exception as e:
log.error(e, exc_info=True)
- log.debug("{} stopped".format(name))
+ log.info("UpdatesWorkerTask stopped")
def signal_handler(self, *args):
+ log.info("Stop signal received ({}). Exiting...".format(args[0]))
self.is_idle = False
- def idle(self, stop_signals: tuple = (SIGINT, SIGTERM, SIGABRT)):
+ async def idle(self, stop_signals: tuple = (SIGINT, SIGTERM, SIGABRT)):
"""Blocks the program execution until one of the signals are received,
then gently stop the Client by closing the underlying connection.
@@ -807,9 +783,9 @@ class Client(Methods, BaseClient):
self.is_idle = True
while self.is_idle:
- time.sleep(1)
+ await asyncio.sleep(1)
- self.stop()
+ await self.stop()
async def send(self, data: Object):
"""Use this method to send Raw Function queries.
@@ -863,14 +839,14 @@ class Client(Methods, BaseClient):
self._proxy["username"] = parser.get("proxy", "username", fallback=None) or None
self._proxy["password"] = parser.get("proxy", "password", fallback=None) or None
- def load_session(self):
+ async def load_session(self):
try:
with open(os.path.join(self.workdir, "{}.session".format(self.session_name)), encoding="utf-8") as f:
s = json.load(f)
except FileNotFoundError:
self.dc_id = 1
self.date = 0
- self.auth_key = Auth(self.dc_id, self.test_mode, self._proxy).create()
+ self.auth_key = await Auth(self.dc_id, self.test_mode, self._proxy).create()
else:
self.dc_id = s["dc_id"]
self.test_mode = s["test_mode"]
@@ -912,10 +888,10 @@ class Client(Methods, BaseClient):
indent=4
)
- def get_dialogs_chunk(self, offset_date):
+ async def get_dialogs_chunk(self, offset_date):
while True:
try:
- r = self.send(
+ r = await self.send(
functions.messages.GetDialogs(
offset_date, 0, types.InputPeerEmpty(),
self.DIALOGS_AT_ONCE, True
@@ -923,24 +899,24 @@ class Client(Methods, BaseClient):
)
except FloodWait as e:
log.warning("get_dialogs flood: waiting {} seconds".format(e.x))
- time.sleep(e.x)
+ await asyncio.sleep(e.x)
else:
log.info("Total peers: {}".format(len(self.peers_by_id)))
return r
- def get_dialogs(self):
- self.send(functions.messages.GetPinnedDialogs())
+ async def get_dialogs(self):
+ await self.send(functions.messages.GetPinnedDialogs())
- dialogs = self.get_dialogs_chunk(0)
+ dialogs = await self.get_dialogs_chunk(0)
offset_date = utils.get_offset_date(dialogs)
while len(dialogs.dialogs) == self.DIALOGS_AT_ONCE:
- dialogs = self.get_dialogs_chunk(offset_date)
+ dialogs = await self.get_dialogs_chunk(offset_date)
offset_date = utils.get_offset_date(dialogs)
- self.get_dialogs_chunk(0)
+ await self.get_dialogs_chunk(0)
- def resolve_peer(self, peer_id: int or str):
+ async def resolve_peer(self, peer_id: int or str):
"""Use this method to get the *InputPeer* of a known *peer_id*.
It is intended to be used when working with Raw Functions (i.e: a Telegram API method you wish to use which is
@@ -968,7 +944,7 @@ class Client(Methods, BaseClient):
try:
decoded = base64.b64decode(match.group(1) + "=" * (-len(match.group(1)) % 4), "-_")
- return self.resolve_peer(struct.unpack(">2iq", decoded)[1])
+ return await self.resolve_peer(struct.unpack(">2iq", decoded)[1])
except (AttributeError, binascii.Error, struct.error):
pass
@@ -980,7 +956,7 @@ class Client(Methods, BaseClient):
try:
return self.peers_by_username[peer_id]
except KeyError:
- self.send(functions.contacts.ResolveUsername(peer_id))
+ await self.send(functions.contacts.ResolveUsername(peer_id))
return self.peers_by_username[peer_id]
else:
try:
@@ -1007,12 +983,12 @@ class Client(Methods, BaseClient):
except (KeyError, ValueError):
raise PeerIdInvalid
- def save_file(self,
- path: str,
- file_id: int = None,
- file_part: int = 0,
- progress: callable = None,
- progress_args: tuple = ()):
+ async def save_file(self,
+ path: str,
+ file_id: int = None,
+ file_part: int = 0,
+ progress: callable = None,
+ progress_args: tuple = ()):
part_size = 512 * 1024
file_size = os.path.getsize(path)
file_total_parts = int(math.ceil(file_size / part_size))
@@ -1022,7 +998,7 @@ class Client(Methods, BaseClient):
md5_sum = md5() if not is_big and not is_missing_part else None
session = Session(self.dc_id, self.test_mode, self._proxy, self.auth_key, self.api_id)
- session.start()
+ await session.start()
try:
with open(path, "rb") as f:
@@ -1050,7 +1026,7 @@ class Client(Methods, BaseClient):
bytes=chunk
)
- assert self.send(rpc), "Couldn't upload file"
+ assert await session.send(rpc), "Couldn't upload file"
if is_missing_part:
return
@@ -1080,25 +1056,25 @@ class Client(Methods, BaseClient):
md5_checksum=md5_sum
)
finally:
- session.stop()
+ await session.stop()
- def get_file(self,
- dc_id: int,
- id: int = None,
- access_hash: int = None,
- volume_id: int = None,
- local_id: int = None,
- secret: int = None,
- version: int = 0,
- size: int = None,
- progress: callable = None,
- progress_args: tuple = None) -> str:
- with self.media_sessions_lock:
+ async def get_file(self,
+ dc_id: int,
+ id: int = None,
+ access_hash: int = None,
+ volume_id: int = None,
+ local_id: int = None,
+ secret: int = None,
+ version: int = 0,
+ size: int = None,
+ progress: callable = None,
+ progress_args: tuple = None) -> str:
+ with await self.media_sessions_lock:
session = self.media_sessions.get(dc_id, None)
if session is None:
if dc_id != self.dc_id:
- exported_auth = self.send(
+ exported_auth = await self.send(
functions.auth.ExportAuthorization(
dc_id=dc_id
)
@@ -1108,15 +1084,15 @@ class Client(Methods, BaseClient):
dc_id,
self.test_mode,
self._proxy,
- Auth(dc_id, self.test_mode, self._proxy).create(),
+ await Auth(dc_id, self.test_mode, self._proxy).create(),
self.api_id
)
- session.start()
+ await session.start()
self.media_sessions[dc_id] = session
- session.send(
+ await session.send(
functions.auth.ImportAuthorization(
id=exported_auth.id,
bytes=exported_auth.bytes
@@ -1131,7 +1107,7 @@ class Client(Methods, BaseClient):
self.api_id
)
- session.start()
+ await session.start()
self.media_sessions[dc_id] = session
@@ -1153,7 +1129,7 @@ class Client(Methods, BaseClient):
file_name = ""
try:
- r = session.send(
+ r = await session.send(
functions.upload.GetFile(
location=location,
offset=offset,
@@ -1180,7 +1156,7 @@ class Client(Methods, BaseClient):
if progress:
progress(self, min(offset, size), size, *progress_args)
- r = session.send(
+ r = await session.send(
functions.upload.GetFile(
location=location,
offset=offset,
@@ -1189,7 +1165,7 @@ class Client(Methods, BaseClient):
)
elif isinstance(r, types.upload.FileCdnRedirect):
- with self.media_sessions_lock:
+ with await self.media_sessions_lock:
cdn_session = self.media_sessions.get(r.dc_id, None)
if cdn_session is None:
@@ -1197,12 +1173,12 @@ class Client(Methods, BaseClient):
r.dc_id,
self.test_mode,
self._proxy,
- Auth(r.dc_id, self.test_mode, self._proxy).create(),
+ await Auth(r.dc_id, self.test_mode, self._proxy).create(),
self.api_id,
is_cdn=True
)
- cdn_session.start()
+ await cdn_session.start()
self.media_sessions[r.dc_id] = cdn_session
@@ -1211,7 +1187,7 @@ class Client(Methods, BaseClient):
file_name = f.name
while True:
- r2 = cdn_session.send(
+ r2 = await cdn_session.send(
functions.upload.GetCdnFile(
file_token=r.file_token,
offset=offset,
@@ -1221,7 +1197,7 @@ class Client(Methods, BaseClient):
if isinstance(r2, types.upload.CdnFileReuploadNeeded):
try:
- session.send(
+ await session.send(
functions.upload.ReuploadCdnFile(
file_token=r.file_token,
request_token=r2.request_token
@@ -1244,7 +1220,7 @@ class Client(Methods, BaseClient):
)
)
- hashes = session.send(
+ hashes = await session.send(
functions.upload.GetCdnFileHashes(
r.file_token,
offset