pyrofork: Add skip_updates parameter to Client class

Signed-off-by: wulan17 <wulan17@nusantararom.org>
This commit is contained in:
KurimuzonAkuma 2024-03-06 15:28:00 +03:00 committed by wulan17
parent c4362ad535
commit 4aa4d1a74a
No known key found for this signature in database
GPG key ID: 318CD6CD3A6AC0A5
6 changed files with 183 additions and 2 deletions

View file

@ -174,6 +174,10 @@ class Client(Methods):
Useful for batch programs that don't need to deal with updates. Useful for batch programs that don't need to deal with updates.
Defaults to False (updates enabled and received). Defaults to False (updates enabled and received).
skip_updates (``bool``, *optional*):
Pass True to skip pending updates that arrived while the client was offline.
Defaults to True.
takeout (``bool``, *optional*): takeout (``bool``, *optional*):
Pass True to let the client use a takeout session instead of a normal one, implies *no_updates=True*. Pass True to let the client use a takeout session instead of a normal one, implies *no_updates=True*.
Useful for exporting Telegram data. Methods invoked inside a takeout session (such as get_chat_history, Useful for exporting Telegram data. Methods invoked inside a takeout session (such as get_chat_history,
@ -248,7 +252,8 @@ class Client(Methods):
plugins: Optional[dict] = None, plugins: Optional[dict] = None,
parse_mode: "enums.ParseMode" = enums.ParseMode.DEFAULT, parse_mode: "enums.ParseMode" = enums.ParseMode.DEFAULT,
no_updates: Optional[bool] = None, no_updates: Optional[bool] = None,
takeout: Optional[bool] = None, skip_updates: bool = True,
takeout: bool = None,
sleep_threshold: int = Session.SLEEP_THRESHOLD, sleep_threshold: int = Session.SLEEP_THRESHOLD,
hide_password: Optional[bool] = False, hide_password: Optional[bool] = False,
max_concurrent_transmissions: int = MAX_CONCURRENT_TRANSMISSIONS, max_concurrent_transmissions: int = MAX_CONCURRENT_TRANSMISSIONS,
@ -279,6 +284,7 @@ class Client(Methods):
self.plugins = plugins self.plugins = plugins
self.parse_mode = parse_mode self.parse_mode = parse_mode
self.no_updates = no_updates self.no_updates = no_updates
self.skip_updates = skip_updates
self.takeout = takeout self.takeout = takeout
self.sleep_threshold = sleep_threshold self.sleep_threshold = sleep_threshold
self.hide_password = hide_password self.hide_password = hide_password
@ -607,6 +613,17 @@ class Client(Methods):
pts = getattr(update, "pts", None) pts = getattr(update, "pts", None)
pts_count = getattr(update, "pts_count", None) pts_count = getattr(update, "pts_count", None)
if pts:
await self.storage.update_state(
(
utils.get_channel_id(channel_id) if channel_id else self.me.id,
pts,
None,
updates.date,
None
)
)
if isinstance(update, raw.types.UpdateChannelTooLong): if isinstance(update, raw.types.UpdateChannelTooLong):
log.info(update) log.info(update)
@ -637,6 +654,16 @@ class Client(Methods):
self.dispatcher.updates_queue.put_nowait((update, users, chats)) self.dispatcher.updates_queue.put_nowait((update, users, chats))
elif isinstance(updates, (raw.types.UpdateShortMessage, raw.types.UpdateShortChatMessage)): elif isinstance(updates, (raw.types.UpdateShortMessage, raw.types.UpdateShortChatMessage)):
await self.storage.update_state(
(
self.me.id,
updates.pts,
None,
updates.date,
None
)
)
diff = await self.invoke( diff = await self.invoke(
raw.functions.updates.GetDifference( raw.functions.updates.GetDifference(
pts=updates.pts - updates.pts_count, pts=updates.pts - updates.pts_count,

View file

@ -24,6 +24,7 @@ from collections import OrderedDict
import pyrogram import pyrogram
from pyrogram import utils from pyrogram import utils
from pyrogram import raw
from pyrogram.handlers import ( from pyrogram.handlers import (
BotBusinessConnectHandler, BotBusinessConnectHandler,
BotBusinessMessageHandler, BotBusinessMessageHandler,
@ -251,6 +252,87 @@ class Dispatcher:
log.info("Started %s HandlerTasks", self.client.workers) log.info("Started %s HandlerTasks", self.client.workers)
if not self.client.skip_updates:
states = await self.client.storage.update_state()
if not states:
log.info("No states found, skipping recovery.")
return
message_updates_counter = 0
other_updates_counter = 0
for state in states:
id, local_pts, _, local_date, _ = state
prev_pts = 0
while True:
diff = await self.client.invoke(
raw.functions.updates.GetDifference(
pts=local_pts,
date=local_date,
qts=0
) if id == self.client.me.id else
raw.functions.updates.GetChannelDifference(
channel=await self.client.resolve_peer(id),
filter=raw.types.ChannelMessagesFilterEmpty(),
pts=local_pts,
limit=10000
)
)
if isinstance(diff, (raw.types.updates.DifferenceEmpty, raw.types.updates.ChannelDifferenceEmpty)):
break
elif isinstance(diff, (raw.types.updates.DifferenceTooLong, raw.types.updates.ChannelDifferenceTooLong)):
break
elif isinstance(diff, raw.types.updates.ChannelDifference):
local_pts = diff.pts
elif isinstance(diff, raw.types.updates.Difference):
local_pts = diff.state.pts
elif isinstance(diff, raw.types.updates.DifferenceSlice):
local_pts = diff.intermediate_state.pts
local_date = diff.intermediate_state.date
if prev_pts == local_pts:
break
prev_pts = local_pts
users = {i.id: i for i in diff.users}
chats = {i.id: i for i in diff.chats}
for message in diff.new_messages:
message_updates_counter += 1
self.updates_queue.put_nowait(
(
raw.types.UpdateNewMessage(
message=message,
pts=local_pts,
pts_count=-1
) if id == self.client.me.id else
raw.types.UpdateNewChannelMessage(
message=message,
pts=local_pts,
pts_count=-1
),
users,
chats
)
)
for update in diff.other_updates:
other_updates_counter += 1
self.updates_queue.put_nowait(
(update, users, chats)
)
if isinstance(diff, (raw.types.updates.Difference, raw.types.updates.ChannelDifference)):
break
await self.client.storage.update_state(None)
log.info("Recovered %s messages and %s updates.", message_updates_counter, other_updates_counter)
async def stop(self): async def stop(self):
if not self.client.no_updates: if not self.client.no_updates:
for i in range(self.client.workers): for i in range(self.client.workers):

View file

@ -27,6 +27,18 @@ from .sqlite_storage import SQLiteStorage
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
UPDATE_STATE_SCHEMA = """
CREATE TABLE update_state
(
id INTEGER PRIMARY KEY,
pts INTEGER,
qts INTEGER,
date INTEGER,
seq INTEGER
);
"""
class FileStorage(SQLiteStorage): class FileStorage(SQLiteStorage):
FILE_EXTENSION = ".session" FILE_EXTENSION = ".session"
@ -50,6 +62,12 @@ class FileStorage(SQLiteStorage):
version += 1 version += 1
if version == 3:
with self.conn:
self.conn.executescript(UPDATE_STATE_SCHEMA)
version += 1
self.version(version) self.version(version)
async def open(self): async def open(self):

View file

@ -76,6 +76,7 @@ class MongoStorage(Storage):
self._peer = database['peers'] self._peer = database['peers']
self._session = database['session'] self._session = database['session']
self._usernames = database['usernames'] self._usernames = database['usernames']
self._states = database['update_state']
self._remove_peers = remove_peers self._remove_peers = remove_peers
async def open(self): async def open(self):
@ -167,6 +168,16 @@ class MongoStorage(Storage):
bulk bulk
) )
async def update_state(self, value: Tuple[int, int, int, int, int] = object):
if value == object:
states = [[state['_id'],state['pts'],state['qts'],state['date'],state['seq']] async for state in self._states.find()]
return states if len(states) > 0 else None
else:
if value is None:
await self._states.drop()
else:
await self._states.update_one({'_id': value[0]}, {'$set': {'pts': value[1], 'qts': value[2], 'date': value[3], 'seq': value[4]}}, upsert=True)
async def get_peer_by_id(self, peer_id: int): async def get_peer_by_id(self, peer_id: int):
# id, access_hash, type # id, access_hash, type
r = await self._peer.find_one({'_id': peer_id}, {'_id': 1, 'access_hash': 1, 'type': 1}) r = await self._peer.find_one({'_id': peer_id}, {'_id': 1, 'access_hash': 1, 'type': 1})

View file

@ -49,6 +49,15 @@ CREATE TABLE peers
last_update_on INTEGER NOT NULL DEFAULT (CAST(STRFTIME('%s', 'now') AS INTEGER)) last_update_on INTEGER NOT NULL DEFAULT (CAST(STRFTIME('%s', 'now') AS INTEGER))
); );
CREATE TABLE update_state
(
id INTEGER PRIMARY KEY,
pts INTEGER,
qts INTEGER,
date INTEGER,
seq INTEGER
);
CREATE TABLE version CREATE TABLE version
( (
number INTEGER PRIMARY KEY number INTEGER PRIMARY KEY
@ -110,7 +119,7 @@ def get_input_peer(peer_id: int, access_hash: int, peer_type: str):
class SQLiteStorage(Storage): class SQLiteStorage(Storage):
VERSION = 3 VERSION = 4
USERNAME_TTL = 8 * 60 * 60 USERNAME_TTL = 8 * 60 * 60
def __init__(self, name: str): def __init__(self, name: str):
@ -166,6 +175,24 @@ class SQLiteStorage(Storage):
usernames usernames
) )
async def update_state(self, value: Tuple[int, int, int, int, int] = object):
if value == object:
return self.conn.execute(
"SELECT id, pts, qts, date, seq FROM update_state"
).fetchall()
else:
with self.conn:
if value is None:
self.conn.execute(
"DELETE FROM update_state"
)
else:
self.conn.execute(
"REPLACE INTO update_state (id, pts, qts, date, seq)"
"VALUES (?, ?, ?, ?, ?)",
value
)
async def get_peer_by_id(self, peer_id: int): async def get_peer_by_id(self, peer_id: int):
r = self.conn.execute( r = self.conn.execute(
"SELECT id, access_hash, type FROM peers WHERE id = ?", "SELECT id, access_hash, type FROM peers WHERE id = ?",

View file

@ -19,6 +19,7 @@
import base64 import base64
import struct import struct
from abc import abstractmethod
from typing import List, Tuple from typing import List, Tuple
@ -51,6 +52,21 @@ class Storage:
async def update_usernames(self, usernames: List[Tuple[int, str]]): async def update_usernames(self, usernames: List[Tuple[int, str]]):
raise NotImplementedError raise NotImplementedError
@abstractmethod
async def update_state(self, update_state: Tuple[int, int, int, int, int] = object):
"""Get or set the update state of the current session.
Parameters:
update_state (``Tuple[int, int, int, int, int]``): A tuple containing the update state to set.
Tuple must contain the following information:
- ``int``: The id of the entity.
- ``int``: The pts.
- ``int``: The qts.
- ``int``: The date.
- ``int``: The seq.
"""
raise NotImplementedError
async def get_peer_by_id(self, peer_id: int): async def get_peer_by_id(self, peer_id: int):
raise NotImplementedError raise NotImplementedError