diff --git a/pyrogram/client/client.py b/pyrogram/client/client.py index 8eba760a..a3197d0c 100644 --- a/pyrogram/client/client.py +++ b/pyrogram/client/client.py @@ -16,6 +16,7 @@ # You should have received a copy of the GNU Lesser General Public License # along with Pyrogram. If not, see . +import asyncio import base64 import binascii import getpass @@ -28,7 +29,6 @@ import re import shutil import struct import tempfile -import threading import time from configparser import ConfigParser from datetime import datetime @@ -43,11 +43,11 @@ from pyrogram.api.errors import ( PhoneCodeExpired, PhoneCodeEmpty, SessionPasswordNeeded, PasswordHashInvalid, FloodWait, PeerIdInvalid, FirstnameInvalid, PhoneNumberBanned, VolumeLocNotFound, UserMigrate, FileIdInvalid) -from pyrogram.client.handlers import DisconnectHandler from pyrogram.crypto import AES from pyrogram.session import Auth, Session from .dispatcher import Dispatcher -from .ext import utils, Syncer, BaseClient +from .ext import BaseClient, Syncer, utils +from .handlers import DisconnectHandler from .methods import Methods # Custom format for nice looking log lines @@ -114,7 +114,7 @@ class Client(Methods, BaseClient): be an empty string: "". Only applicable for new sessions. workers (``int``, *optional*): - Thread pool size for handling incoming updates. Defaults to 4. + Number of maximum concurrent workers for handling incoming updates. Defaults to 4. workdir (``str``, *optional*): Define a custom working directory. The working directory is the location in your filesystem @@ -168,15 +168,10 @@ class Client(Methods, BaseClient): self._proxy["enabled"] = True self._proxy.update(value) - async def start(self, debug: bool = False): + async def start(self): """Use this method to start the Client after creating it. Requires no parameters. - Args: - debug (``bool``, *optional*): - Enable or disable debug mode. When enabled, extra logging - lines will be printed out on your console. - Raises: :class:`Error ` """ @@ -188,7 +183,7 @@ class Client(Methods, BaseClient): self.session_name = self.session_name.split(":")[0] self.load_config() - self.load_session() + await self.load_session() self.session = Session( self.dc_id, @@ -204,9 +199,9 @@ class Client(Methods, BaseClient): if self.user_id is None: if self.token is None: - self.authorize_user() + await self.authorize_user() else: - self.authorize_bot() + await self.authorize_bot() self.save_session() @@ -217,38 +212,27 @@ class Client(Methods, BaseClient): self.peers_by_username = {} self.peers_by_phone = {} - self.get_dialogs() - self.get_contacts() + await self.get_dialogs() + await self.get_contacts() else: - self.send(functions.messages.GetPinnedDialogs()) - self.get_dialogs_chunk(0) + await self.send(functions.messages.GetPinnedDialogs()) + await self.get_dialogs_chunk(0) else: await self.send(functions.updates.GetState()) - # for i in range(self.UPDATES_WORKERS): - # self.updates_workers_list.append( - # Thread( - # target=self.updates_worker, - # name="UpdatesWorker#{}".format(i + 1) - # ) - # ) - # - # self.updates_workers_list[-1].start() - # - # for i in range(self.DOWNLOAD_WORKERS): - # self.download_workers_list.append( - # Thread( - # target=self.download_worker, - # name="DownloadWorker#{}".format(i + 1) - # ) - # ) - # - # self.download_workers_list[-1].start() - # - # self.dispatcher.start() + self.updates_worker_task = asyncio.ensure_future(self.updates_worker()) + + for _ in range(Client.DOWNLOAD_WORKERS): + self.download_worker_tasks.append( + asyncio.ensure_future(self.download_worker()) + ) + + log.info("Started {} DownloadWorkerTasks".format(Client.DOWNLOAD_WORKERS)) + + await self.dispatcher.start() + await Syncer.add(self) mimetypes.init() - # Syncer.add(self) async def stop(self): """Use this method to manually stop the Client. @@ -257,29 +241,26 @@ class Client(Methods, BaseClient): if not self.is_started: raise ConnectionError("Client is already stopped") - # Syncer.remove(self) - # self.dispatcher.stop() - # - # for _ in range(self.DOWNLOAD_WORKERS): - # self.download_queue.put(None) - # - # for i in self.download_workers_list: - # i.join() - # - # self.download_workers_list.clear() - # - # for _ in range(self.UPDATES_WORKERS): - # self.updates_queue.put(None) - # - # for i in self.updates_workers_list: - # i.join() - # - # self.updates_workers_list.clear() - # - # for i in self.media_sessions.values(): - # i.stop() - # - # self.media_sessions.clear() + await Syncer.remove(self) + await self.dispatcher.stop() + + for _ in range(Client.DOWNLOAD_WORKERS): + self.download_queue.put_nowait(None) + + for task in self.download_worker_tasks: + await task + + self.download_worker_tasks.clear() + + log.info("Stopped {} DownloadWorkerTasks".format(Client.DOWNLOAD_WORKERS)) + + self.updates_queue.put_nowait(None) + await self.updates_worker_task + + for media_session in self.media_sessions.values(): + await media_session.stop() + + self.media_sessions.clear() self.is_started = False await self.session.stop() @@ -327,9 +308,9 @@ class Client(Methods, BaseClient): else: self.dispatcher.remove_handler(handler, group) - def authorize_bot(self): + async def authorize_bot(self): try: - r = self.send( + r = await self.send( functions.auth.ImportBotAuthorization( flags=0, api_id=self.api_id, @@ -338,10 +319,10 @@ class Client(Methods, BaseClient): ) ) except UserMigrate as e: - self.session.stop() + await self.session.stop() self.dc_id = e.x - self.auth_key = Auth(self.dc_id, self.test_mode, self._proxy).create() + self.auth_key = await Auth(self.dc_id, self.test_mode, self._proxy).create() self.session = Session( self.dc_id, @@ -352,12 +333,12 @@ class Client(Methods, BaseClient): client=self ) - self.session.start() - self.authorize_bot() + await self.session.start() + await self.authorize_bot() else: self.user_id = r.user.id - def authorize_user(self): + async def authorize_user(self): phone_number_invalid_raises = self.phone_number is not None phone_code_invalid_raises = self.phone_code is not None password_hash_invalid_raises = self.password is not None @@ -378,7 +359,7 @@ class Client(Methods, BaseClient): self.phone_number = self.phone_number.strip("+") try: - r = self.send( + r = await self.send( functions.auth.SendCode( self.phone_number, self.api_id, @@ -386,10 +367,10 @@ class Client(Methods, BaseClient): ) ) except (PhoneMigrate, NetworkMigrate) as e: - self.session.stop() + await self.session.stop() self.dc_id = e.x - self.auth_key = Auth(self.dc_id, self.test_mode, self._proxy).create() + self.auth_key = await Auth(self.dc_id, self.test_mode, self._proxy).create() self.session = Session( self.dc_id, @@ -399,9 +380,9 @@ class Client(Methods, BaseClient): self.api_id, client=self ) - self.session.start() + await self.session.start() - r = self.send( + r = await self.send( functions.auth.SendCode( self.phone_number, self.api_id, @@ -430,7 +411,7 @@ class Client(Methods, BaseClient): phone_code_hash = r.phone_code_hash if self.force_sms: - self.send( + await self.send( functions.auth.ResendCode( phone_number=self.phone_number, phone_code_hash=phone_code_hash @@ -446,7 +427,7 @@ class Client(Methods, BaseClient): try: if phone_registered: - r = self.send( + r = await self.send( functions.auth.SignIn( self.phone_number, phone_code_hash, @@ -455,7 +436,7 @@ class Client(Methods, BaseClient): ) else: try: - self.send( + await self.send( functions.auth.SignIn( self.phone_number, phone_code_hash, @@ -468,7 +449,7 @@ class Client(Methods, BaseClient): self.first_name = self.first_name if self.first_name is not None else input("First name: ") self.last_name = self.last_name if self.last_name is not None else input("Last name: ") - r = self.send( + r = await self.send( functions.auth.SignUp( self.phone_number, phone_code_hash, @@ -491,7 +472,7 @@ class Client(Methods, BaseClient): self.first_name = None except SessionPasswordNeeded as e: print(e.MESSAGE) - r = self.send(functions.account.GetPassword()) + r = await self.send(functions.account.GetPassword()) while True: try: @@ -505,7 +486,7 @@ class Client(Methods, BaseClient): password_hash = sha256(self.password).digest() - r = self.send(functions.auth.CheckPassword(password_hash)) + r = await self.send(functions.auth.CheckPassword(password_hash)) except PasswordHashInvalid as e: if password_hash_invalid_raises: raise @@ -594,12 +575,9 @@ class Client(Methods, BaseClient): if username is not None: self.peers_by_username[username.lower()] = input_peer - def download_worker(self): - name = threading.current_thread().name - log.debug("{} started".format(name)) - + async def download_worker(self): while True: - media = self.download_queue.get() + media = await self.download_queue.get() if media is None: break @@ -666,7 +644,7 @@ class Client(Methods, BaseClient): extension ) - temp_file_path = self.get_file( + temp_file_path = await self.get_file( dc_id=dc_id, id=id, access_hash=access_hash, @@ -697,14 +675,11 @@ class Client(Methods, BaseClient): finally: done.set() - log.debug("{} stopped".format(name)) - - def updates_worker(self): - name = threading.current_thread().name - log.debug("{} started".format(name)) + async def updates_worker(self): + log.info("UpdatesWorkerTask started") while True: - updates = self.updates_queue.get() + updates = await self.updates_queue.get() if updates is None: break @@ -730,9 +705,9 @@ class Client(Methods, BaseClient): message = update.message if not isinstance(message, types.MessageEmpty): - diff = self.send( + diff = await self.send( functions.updates.GetChannelDifference( - channel=self.resolve_peer(int("-100" + str(channel_id))), + channel=await self.resolve_peer(int("-100" + str(channel_id))), filter=types.ChannelMessagesFilter( ranges=[types.MessageRange( min_id=update.message.id, @@ -760,9 +735,9 @@ class Client(Methods, BaseClient): if len(self.channels_pts[channel_id]) > 50: self.channels_pts[channel_id] = self.channels_pts[channel_id][25:] - self.dispatcher.updates.put((update, updates.users, updates.chats)) + self.dispatcher.updates.put_nowait((update, updates.users, updates.chats)) elif isinstance(updates, (types.UpdateShortMessage, types.UpdateShortChatMessage)): - diff = self.send( + diff = await self.send( functions.updates.GetDifference( pts=updates.pts - updates.pts_count, date=updates.date, @@ -771,7 +746,7 @@ class Client(Methods, BaseClient): ) if diff.new_messages: - self.dispatcher.updates.put(( + self.dispatcher.updates.put_nowait(( types.UpdateNewMessage( message=diff.new_messages[0], pts=updates.pts, @@ -781,18 +756,19 @@ class Client(Methods, BaseClient): diff.chats )) else: - self.dispatcher.updates.put((diff.other_updates[0], [], [])) + self.dispatcher.updates.put_nowait((diff.other_updates[0], [], [])) elif isinstance(updates, types.UpdateShort): - self.dispatcher.updates.put((updates.update, [], [])) + self.dispatcher.updates.put_nowait((updates.update, [], [])) except Exception as e: log.error(e, exc_info=True) - log.debug("{} stopped".format(name)) + log.info("UpdatesWorkerTask stopped") def signal_handler(self, *args): + log.info("Stop signal received ({}). Exiting...".format(args[0])) self.is_idle = False - def idle(self, stop_signals: tuple = (SIGINT, SIGTERM, SIGABRT)): + async def idle(self, stop_signals: tuple = (SIGINT, SIGTERM, SIGABRT)): """Blocks the program execution until one of the signals are received, then gently stop the Client by closing the underlying connection. @@ -807,9 +783,9 @@ class Client(Methods, BaseClient): self.is_idle = True while self.is_idle: - time.sleep(1) + await asyncio.sleep(1) - self.stop() + await self.stop() async def send(self, data: Object): """Use this method to send Raw Function queries. @@ -863,14 +839,14 @@ class Client(Methods, BaseClient): self._proxy["username"] = parser.get("proxy", "username", fallback=None) or None self._proxy["password"] = parser.get("proxy", "password", fallback=None) or None - def load_session(self): + async def load_session(self): try: with open(os.path.join(self.workdir, "{}.session".format(self.session_name)), encoding="utf-8") as f: s = json.load(f) except FileNotFoundError: self.dc_id = 1 self.date = 0 - self.auth_key = Auth(self.dc_id, self.test_mode, self._proxy).create() + self.auth_key = await Auth(self.dc_id, self.test_mode, self._proxy).create() else: self.dc_id = s["dc_id"] self.test_mode = s["test_mode"] @@ -912,10 +888,10 @@ class Client(Methods, BaseClient): indent=4 ) - def get_dialogs_chunk(self, offset_date): + async def get_dialogs_chunk(self, offset_date): while True: try: - r = self.send( + r = await self.send( functions.messages.GetDialogs( offset_date, 0, types.InputPeerEmpty(), self.DIALOGS_AT_ONCE, True @@ -923,24 +899,24 @@ class Client(Methods, BaseClient): ) except FloodWait as e: log.warning("get_dialogs flood: waiting {} seconds".format(e.x)) - time.sleep(e.x) + await asyncio.sleep(e.x) else: log.info("Total peers: {}".format(len(self.peers_by_id))) return r - def get_dialogs(self): - self.send(functions.messages.GetPinnedDialogs()) + async def get_dialogs(self): + await self.send(functions.messages.GetPinnedDialogs()) - dialogs = self.get_dialogs_chunk(0) + dialogs = await self.get_dialogs_chunk(0) offset_date = utils.get_offset_date(dialogs) while len(dialogs.dialogs) == self.DIALOGS_AT_ONCE: - dialogs = self.get_dialogs_chunk(offset_date) + dialogs = await self.get_dialogs_chunk(offset_date) offset_date = utils.get_offset_date(dialogs) - self.get_dialogs_chunk(0) + await self.get_dialogs_chunk(0) - def resolve_peer(self, peer_id: int or str): + async def resolve_peer(self, peer_id: int or str): """Use this method to get the *InputPeer* of a known *peer_id*. It is intended to be used when working with Raw Functions (i.e: a Telegram API method you wish to use which is @@ -968,7 +944,7 @@ class Client(Methods, BaseClient): try: decoded = base64.b64decode(match.group(1) + "=" * (-len(match.group(1)) % 4), "-_") - return self.resolve_peer(struct.unpack(">2iq", decoded)[1]) + return await self.resolve_peer(struct.unpack(">2iq", decoded)[1]) except (AttributeError, binascii.Error, struct.error): pass @@ -980,7 +956,7 @@ class Client(Methods, BaseClient): try: return self.peers_by_username[peer_id] except KeyError: - self.send(functions.contacts.ResolveUsername(peer_id)) + await self.send(functions.contacts.ResolveUsername(peer_id)) return self.peers_by_username[peer_id] else: try: @@ -1007,12 +983,12 @@ class Client(Methods, BaseClient): except (KeyError, ValueError): raise PeerIdInvalid - def save_file(self, - path: str, - file_id: int = None, - file_part: int = 0, - progress: callable = None, - progress_args: tuple = ()): + async def save_file(self, + path: str, + file_id: int = None, + file_part: int = 0, + progress: callable = None, + progress_args: tuple = ()): part_size = 512 * 1024 file_size = os.path.getsize(path) file_total_parts = int(math.ceil(file_size / part_size)) @@ -1022,7 +998,7 @@ class Client(Methods, BaseClient): md5_sum = md5() if not is_big and not is_missing_part else None session = Session(self.dc_id, self.test_mode, self._proxy, self.auth_key, self.api_id) - session.start() + await session.start() try: with open(path, "rb") as f: @@ -1050,7 +1026,7 @@ class Client(Methods, BaseClient): bytes=chunk ) - assert self.send(rpc), "Couldn't upload file" + assert await session.send(rpc), "Couldn't upload file" if is_missing_part: return @@ -1080,25 +1056,25 @@ class Client(Methods, BaseClient): md5_checksum=md5_sum ) finally: - session.stop() + await session.stop() - def get_file(self, - dc_id: int, - id: int = None, - access_hash: int = None, - volume_id: int = None, - local_id: int = None, - secret: int = None, - version: int = 0, - size: int = None, - progress: callable = None, - progress_args: tuple = None) -> str: - with self.media_sessions_lock: + async def get_file(self, + dc_id: int, + id: int = None, + access_hash: int = None, + volume_id: int = None, + local_id: int = None, + secret: int = None, + version: int = 0, + size: int = None, + progress: callable = None, + progress_args: tuple = None) -> str: + with await self.media_sessions_lock: session = self.media_sessions.get(dc_id, None) if session is None: if dc_id != self.dc_id: - exported_auth = self.send( + exported_auth = await self.send( functions.auth.ExportAuthorization( dc_id=dc_id ) @@ -1108,15 +1084,15 @@ class Client(Methods, BaseClient): dc_id, self.test_mode, self._proxy, - Auth(dc_id, self.test_mode, self._proxy).create(), + await Auth(dc_id, self.test_mode, self._proxy).create(), self.api_id ) - session.start() + await session.start() self.media_sessions[dc_id] = session - session.send( + await session.send( functions.auth.ImportAuthorization( id=exported_auth.id, bytes=exported_auth.bytes @@ -1131,7 +1107,7 @@ class Client(Methods, BaseClient): self.api_id ) - session.start() + await session.start() self.media_sessions[dc_id] = session @@ -1153,7 +1129,7 @@ class Client(Methods, BaseClient): file_name = "" try: - r = session.send( + r = await session.send( functions.upload.GetFile( location=location, offset=offset, @@ -1180,7 +1156,7 @@ class Client(Methods, BaseClient): if progress: progress(self, min(offset, size), size, *progress_args) - r = session.send( + r = await session.send( functions.upload.GetFile( location=location, offset=offset, @@ -1189,7 +1165,7 @@ class Client(Methods, BaseClient): ) elif isinstance(r, types.upload.FileCdnRedirect): - with self.media_sessions_lock: + with await self.media_sessions_lock: cdn_session = self.media_sessions.get(r.dc_id, None) if cdn_session is None: @@ -1197,12 +1173,12 @@ class Client(Methods, BaseClient): r.dc_id, self.test_mode, self._proxy, - Auth(r.dc_id, self.test_mode, self._proxy).create(), + await Auth(r.dc_id, self.test_mode, self._proxy).create(), self.api_id, is_cdn=True ) - cdn_session.start() + await cdn_session.start() self.media_sessions[r.dc_id] = cdn_session @@ -1211,7 +1187,7 @@ class Client(Methods, BaseClient): file_name = f.name while True: - r2 = cdn_session.send( + r2 = await cdn_session.send( functions.upload.GetCdnFile( file_token=r.file_token, offset=offset, @@ -1221,7 +1197,7 @@ class Client(Methods, BaseClient): if isinstance(r2, types.upload.CdnFileReuploadNeeded): try: - session.send( + await session.send( functions.upload.ReuploadCdnFile( file_token=r.file_token, request_token=r2.request_token @@ -1244,7 +1220,7 @@ class Client(Methods, BaseClient): ) ) - hashes = session.send( + hashes = await session.send( functions.upload.GetCdnFileHashes( r.file_token, offset