From 4aa4d1a74a13a8ccb75311d68f5c04815e777df2 Mon Sep 17 00:00:00 2001 From: KurimuzonAkuma Date: Wed, 6 Mar 2024 15:28:00 +0300 Subject: [PATCH] pyrofork: Add skip_updates parameter to Client class Signed-off-by: wulan17 --- pyrogram/client.py | 29 ++++++++++- pyrogram/dispatcher.py | 82 ++++++++++++++++++++++++++++++ pyrogram/storage/file_storage.py | 18 +++++++ pyrogram/storage/mongo_storage.py | 11 ++++ pyrogram/storage/sqlite_storage.py | 29 ++++++++++- pyrogram/storage/storage.py | 16 ++++++ 6 files changed, 183 insertions(+), 2 deletions(-) diff --git a/pyrogram/client.py b/pyrogram/client.py index 1034db9b..d7c55050 100644 --- a/pyrogram/client.py +++ b/pyrogram/client.py @@ -174,6 +174,10 @@ class Client(Methods): Useful for batch programs that don't need to deal with updates. Defaults to False (updates enabled and received). + skip_updates (``bool``, *optional*): + Pass True to skip pending updates that arrived while the client was offline. + Defaults to True. + takeout (``bool``, *optional*): Pass True to let the client use a takeout session instead of a normal one, implies *no_updates=True*. Useful for exporting Telegram data. Methods invoked inside a takeout session (such as get_chat_history, @@ -248,7 +252,8 @@ class Client(Methods): plugins: Optional[dict] = None, parse_mode: "enums.ParseMode" = enums.ParseMode.DEFAULT, no_updates: Optional[bool] = None, - takeout: Optional[bool] = None, + skip_updates: bool = True, + takeout: bool = None, sleep_threshold: int = Session.SLEEP_THRESHOLD, hide_password: Optional[bool] = False, max_concurrent_transmissions: int = MAX_CONCURRENT_TRANSMISSIONS, @@ -279,6 +284,7 @@ class Client(Methods): self.plugins = plugins self.parse_mode = parse_mode self.no_updates = no_updates + self.skip_updates = skip_updates self.takeout = takeout self.sleep_threshold = sleep_threshold self.hide_password = hide_password @@ -607,6 +613,17 @@ class Client(Methods): pts = getattr(update, "pts", None) pts_count = getattr(update, "pts_count", None) + if pts: + await self.storage.update_state( + ( + utils.get_channel_id(channel_id) if channel_id else self.me.id, + pts, + None, + updates.date, + None + ) + ) + if isinstance(update, raw.types.UpdateChannelTooLong): log.info(update) @@ -637,6 +654,16 @@ class Client(Methods): self.dispatcher.updates_queue.put_nowait((update, users, chats)) elif isinstance(updates, (raw.types.UpdateShortMessage, raw.types.UpdateShortChatMessage)): + await self.storage.update_state( + ( + self.me.id, + updates.pts, + None, + updates.date, + None + ) + ) + diff = await self.invoke( raw.functions.updates.GetDifference( pts=updates.pts - updates.pts_count, diff --git a/pyrogram/dispatcher.py b/pyrogram/dispatcher.py index b97bdb6f..99c31be0 100644 --- a/pyrogram/dispatcher.py +++ b/pyrogram/dispatcher.py @@ -24,6 +24,7 @@ from collections import OrderedDict import pyrogram from pyrogram import utils +from pyrogram import raw from pyrogram.handlers import ( BotBusinessConnectHandler, BotBusinessMessageHandler, @@ -251,6 +252,87 @@ class Dispatcher: log.info("Started %s HandlerTasks", self.client.workers) + if not self.client.skip_updates: + states = await self.client.storage.update_state() + + if not states: + log.info("No states found, skipping recovery.") + return + + message_updates_counter = 0 + other_updates_counter = 0 + + for state in states: + id, local_pts, _, local_date, _ = state + + prev_pts = 0 + + while True: + diff = await self.client.invoke( + raw.functions.updates.GetDifference( + pts=local_pts, + date=local_date, + qts=0 + ) if id == self.client.me.id else + raw.functions.updates.GetChannelDifference( + channel=await self.client.resolve_peer(id), + filter=raw.types.ChannelMessagesFilterEmpty(), + pts=local_pts, + limit=10000 + ) + ) + + if isinstance(diff, (raw.types.updates.DifferenceEmpty, raw.types.updates.ChannelDifferenceEmpty)): + break + elif isinstance(diff, (raw.types.updates.DifferenceTooLong, raw.types.updates.ChannelDifferenceTooLong)): + break + elif isinstance(diff, raw.types.updates.ChannelDifference): + local_pts = diff.pts + elif isinstance(diff, raw.types.updates.Difference): + local_pts = diff.state.pts + elif isinstance(diff, raw.types.updates.DifferenceSlice): + local_pts = diff.intermediate_state.pts + local_date = diff.intermediate_state.date + + if prev_pts == local_pts: + break + + prev_pts = local_pts + + users = {i.id: i for i in diff.users} + chats = {i.id: i for i in diff.chats} + + for message in diff.new_messages: + message_updates_counter += 1 + self.updates_queue.put_nowait( + ( + raw.types.UpdateNewMessage( + message=message, + pts=local_pts, + pts_count=-1 + ) if id == self.client.me.id else + raw.types.UpdateNewChannelMessage( + message=message, + pts=local_pts, + pts_count=-1 + ), + users, + chats + ) + ) + + for update in diff.other_updates: + other_updates_counter += 1 + self.updates_queue.put_nowait( + (update, users, chats) + ) + + if isinstance(diff, (raw.types.updates.Difference, raw.types.updates.ChannelDifference)): + break + + await self.client.storage.update_state(None) + log.info("Recovered %s messages and %s updates.", message_updates_counter, other_updates_counter) + async def stop(self): if not self.client.no_updates: for i in range(self.client.workers): diff --git a/pyrogram/storage/file_storage.py b/pyrogram/storage/file_storage.py index 244866dc..031cb4ac 100644 --- a/pyrogram/storage/file_storage.py +++ b/pyrogram/storage/file_storage.py @@ -27,6 +27,18 @@ from .sqlite_storage import SQLiteStorage log = logging.getLogger(__name__) +UPDATE_STATE_SCHEMA = """ +CREATE TABLE update_state +( + id INTEGER PRIMARY KEY, + pts INTEGER, + qts INTEGER, + date INTEGER, + seq INTEGER +); +""" + + class FileStorage(SQLiteStorage): FILE_EXTENSION = ".session" @@ -50,6 +62,12 @@ class FileStorage(SQLiteStorage): version += 1 + if version == 3: + with self.conn: + self.conn.executescript(UPDATE_STATE_SCHEMA) + + version += 1 + self.version(version) async def open(self): diff --git a/pyrogram/storage/mongo_storage.py b/pyrogram/storage/mongo_storage.py index fd72f564..39952dee 100644 --- a/pyrogram/storage/mongo_storage.py +++ b/pyrogram/storage/mongo_storage.py @@ -76,6 +76,7 @@ class MongoStorage(Storage): self._peer = database['peers'] self._session = database['session'] self._usernames = database['usernames'] + self._states = database['update_state'] self._remove_peers = remove_peers async def open(self): @@ -167,6 +168,16 @@ class MongoStorage(Storage): bulk ) + async def update_state(self, value: Tuple[int, int, int, int, int] = object): + if value == object: + states = [[state['_id'],state['pts'],state['qts'],state['date'],state['seq']] async for state in self._states.find()] + return states if len(states) > 0 else None + else: + if value is None: + await self._states.drop() + else: + await self._states.update_one({'_id': value[0]}, {'$set': {'pts': value[1], 'qts': value[2], 'date': value[3], 'seq': value[4]}}, upsert=True) + async def get_peer_by_id(self, peer_id: int): # id, access_hash, type r = await self._peer.find_one({'_id': peer_id}, {'_id': 1, 'access_hash': 1, 'type': 1}) diff --git a/pyrogram/storage/sqlite_storage.py b/pyrogram/storage/sqlite_storage.py index 74d785aa..53c25191 100644 --- a/pyrogram/storage/sqlite_storage.py +++ b/pyrogram/storage/sqlite_storage.py @@ -49,6 +49,15 @@ CREATE TABLE peers last_update_on INTEGER NOT NULL DEFAULT (CAST(STRFTIME('%s', 'now') AS INTEGER)) ); +CREATE TABLE update_state +( + id INTEGER PRIMARY KEY, + pts INTEGER, + qts INTEGER, + date INTEGER, + seq INTEGER +); + CREATE TABLE version ( number INTEGER PRIMARY KEY @@ -110,7 +119,7 @@ def get_input_peer(peer_id: int, access_hash: int, peer_type: str): class SQLiteStorage(Storage): - VERSION = 3 + VERSION = 4 USERNAME_TTL = 8 * 60 * 60 def __init__(self, name: str): @@ -166,6 +175,24 @@ class SQLiteStorage(Storage): usernames ) + async def update_state(self, value: Tuple[int, int, int, int, int] = object): + if value == object: + return self.conn.execute( + "SELECT id, pts, qts, date, seq FROM update_state" + ).fetchall() + else: + with self.conn: + if value is None: + self.conn.execute( + "DELETE FROM update_state" + ) + else: + self.conn.execute( + "REPLACE INTO update_state (id, pts, qts, date, seq)" + "VALUES (?, ?, ?, ?, ?)", + value + ) + async def get_peer_by_id(self, peer_id: int): r = self.conn.execute( "SELECT id, access_hash, type FROM peers WHERE id = ?", diff --git a/pyrogram/storage/storage.py b/pyrogram/storage/storage.py index 1e47527c..7484076a 100644 --- a/pyrogram/storage/storage.py +++ b/pyrogram/storage/storage.py @@ -19,6 +19,7 @@ import base64 import struct +from abc import abstractmethod from typing import List, Tuple @@ -51,6 +52,21 @@ class Storage: async def update_usernames(self, usernames: List[Tuple[int, str]]): raise NotImplementedError + @abstractmethod + async def update_state(self, update_state: Tuple[int, int, int, int, int] = object): + """Get or set the update state of the current session. + + Parameters: + update_state (``Tuple[int, int, int, int, int]``): A tuple containing the update state to set. + Tuple must contain the following information: + - ``int``: The id of the entity. + - ``int``: The pts. + - ``int``: The qts. + - ``int``: The date. + - ``int``: The seq. + """ + raise NotImplementedError + async def get_peer_by_id(self, peer_id: int): raise NotImplementedError