mirror of
https://github.com/Mayuri-Chan/pyrofork.git
synced 2026-01-03 14:04:51 +00:00
pyrofork: Add skip_updates parameter to Client class
Signed-off-by: wulan17 <wulan17@nusantararom.org>
This commit is contained in:
parent
c4362ad535
commit
4aa4d1a74a
6 changed files with 183 additions and 2 deletions
|
|
@ -174,6 +174,10 @@ class Client(Methods):
|
||||||
Useful for batch programs that don't need to deal with updates.
|
Useful for batch programs that don't need to deal with updates.
|
||||||
Defaults to False (updates enabled and received).
|
Defaults to False (updates enabled and received).
|
||||||
|
|
||||||
|
skip_updates (``bool``, *optional*):
|
||||||
|
Pass True to skip pending updates that arrived while the client was offline.
|
||||||
|
Defaults to True.
|
||||||
|
|
||||||
takeout (``bool``, *optional*):
|
takeout (``bool``, *optional*):
|
||||||
Pass True to let the client use a takeout session instead of a normal one, implies *no_updates=True*.
|
Pass True to let the client use a takeout session instead of a normal one, implies *no_updates=True*.
|
||||||
Useful for exporting Telegram data. Methods invoked inside a takeout session (such as get_chat_history,
|
Useful for exporting Telegram data. Methods invoked inside a takeout session (such as get_chat_history,
|
||||||
|
|
@ -248,7 +252,8 @@ class Client(Methods):
|
||||||
plugins: Optional[dict] = None,
|
plugins: Optional[dict] = None,
|
||||||
parse_mode: "enums.ParseMode" = enums.ParseMode.DEFAULT,
|
parse_mode: "enums.ParseMode" = enums.ParseMode.DEFAULT,
|
||||||
no_updates: Optional[bool] = None,
|
no_updates: Optional[bool] = None,
|
||||||
takeout: Optional[bool] = None,
|
skip_updates: bool = True,
|
||||||
|
takeout: bool = None,
|
||||||
sleep_threshold: int = Session.SLEEP_THRESHOLD,
|
sleep_threshold: int = Session.SLEEP_THRESHOLD,
|
||||||
hide_password: Optional[bool] = False,
|
hide_password: Optional[bool] = False,
|
||||||
max_concurrent_transmissions: int = MAX_CONCURRENT_TRANSMISSIONS,
|
max_concurrent_transmissions: int = MAX_CONCURRENT_TRANSMISSIONS,
|
||||||
|
|
@ -279,6 +284,7 @@ class Client(Methods):
|
||||||
self.plugins = plugins
|
self.plugins = plugins
|
||||||
self.parse_mode = parse_mode
|
self.parse_mode = parse_mode
|
||||||
self.no_updates = no_updates
|
self.no_updates = no_updates
|
||||||
|
self.skip_updates = skip_updates
|
||||||
self.takeout = takeout
|
self.takeout = takeout
|
||||||
self.sleep_threshold = sleep_threshold
|
self.sleep_threshold = sleep_threshold
|
||||||
self.hide_password = hide_password
|
self.hide_password = hide_password
|
||||||
|
|
@ -607,6 +613,17 @@ class Client(Methods):
|
||||||
pts = getattr(update, "pts", None)
|
pts = getattr(update, "pts", None)
|
||||||
pts_count = getattr(update, "pts_count", None)
|
pts_count = getattr(update, "pts_count", None)
|
||||||
|
|
||||||
|
if pts:
|
||||||
|
await self.storage.update_state(
|
||||||
|
(
|
||||||
|
utils.get_channel_id(channel_id) if channel_id else self.me.id,
|
||||||
|
pts,
|
||||||
|
None,
|
||||||
|
updates.date,
|
||||||
|
None
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
if isinstance(update, raw.types.UpdateChannelTooLong):
|
if isinstance(update, raw.types.UpdateChannelTooLong):
|
||||||
log.info(update)
|
log.info(update)
|
||||||
|
|
||||||
|
|
@ -637,6 +654,16 @@ class Client(Methods):
|
||||||
|
|
||||||
self.dispatcher.updates_queue.put_nowait((update, users, chats))
|
self.dispatcher.updates_queue.put_nowait((update, users, chats))
|
||||||
elif isinstance(updates, (raw.types.UpdateShortMessage, raw.types.UpdateShortChatMessage)):
|
elif isinstance(updates, (raw.types.UpdateShortMessage, raw.types.UpdateShortChatMessage)):
|
||||||
|
await self.storage.update_state(
|
||||||
|
(
|
||||||
|
self.me.id,
|
||||||
|
updates.pts,
|
||||||
|
None,
|
||||||
|
updates.date,
|
||||||
|
None
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
diff = await self.invoke(
|
diff = await self.invoke(
|
||||||
raw.functions.updates.GetDifference(
|
raw.functions.updates.GetDifference(
|
||||||
pts=updates.pts - updates.pts_count,
|
pts=updates.pts - updates.pts_count,
|
||||||
|
|
|
||||||
|
|
@ -24,6 +24,7 @@ from collections import OrderedDict
|
||||||
|
|
||||||
import pyrogram
|
import pyrogram
|
||||||
from pyrogram import utils
|
from pyrogram import utils
|
||||||
|
from pyrogram import raw
|
||||||
from pyrogram.handlers import (
|
from pyrogram.handlers import (
|
||||||
BotBusinessConnectHandler,
|
BotBusinessConnectHandler,
|
||||||
BotBusinessMessageHandler,
|
BotBusinessMessageHandler,
|
||||||
|
|
@ -251,6 +252,87 @@ class Dispatcher:
|
||||||
|
|
||||||
log.info("Started %s HandlerTasks", self.client.workers)
|
log.info("Started %s HandlerTasks", self.client.workers)
|
||||||
|
|
||||||
|
if not self.client.skip_updates:
|
||||||
|
states = await self.client.storage.update_state()
|
||||||
|
|
||||||
|
if not states:
|
||||||
|
log.info("No states found, skipping recovery.")
|
||||||
|
return
|
||||||
|
|
||||||
|
message_updates_counter = 0
|
||||||
|
other_updates_counter = 0
|
||||||
|
|
||||||
|
for state in states:
|
||||||
|
id, local_pts, _, local_date, _ = state
|
||||||
|
|
||||||
|
prev_pts = 0
|
||||||
|
|
||||||
|
while True:
|
||||||
|
diff = await self.client.invoke(
|
||||||
|
raw.functions.updates.GetDifference(
|
||||||
|
pts=local_pts,
|
||||||
|
date=local_date,
|
||||||
|
qts=0
|
||||||
|
) if id == self.client.me.id else
|
||||||
|
raw.functions.updates.GetChannelDifference(
|
||||||
|
channel=await self.client.resolve_peer(id),
|
||||||
|
filter=raw.types.ChannelMessagesFilterEmpty(),
|
||||||
|
pts=local_pts,
|
||||||
|
limit=10000
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
if isinstance(diff, (raw.types.updates.DifferenceEmpty, raw.types.updates.ChannelDifferenceEmpty)):
|
||||||
|
break
|
||||||
|
elif isinstance(diff, (raw.types.updates.DifferenceTooLong, raw.types.updates.ChannelDifferenceTooLong)):
|
||||||
|
break
|
||||||
|
elif isinstance(diff, raw.types.updates.ChannelDifference):
|
||||||
|
local_pts = diff.pts
|
||||||
|
elif isinstance(diff, raw.types.updates.Difference):
|
||||||
|
local_pts = diff.state.pts
|
||||||
|
elif isinstance(diff, raw.types.updates.DifferenceSlice):
|
||||||
|
local_pts = diff.intermediate_state.pts
|
||||||
|
local_date = diff.intermediate_state.date
|
||||||
|
|
||||||
|
if prev_pts == local_pts:
|
||||||
|
break
|
||||||
|
|
||||||
|
prev_pts = local_pts
|
||||||
|
|
||||||
|
users = {i.id: i for i in diff.users}
|
||||||
|
chats = {i.id: i for i in diff.chats}
|
||||||
|
|
||||||
|
for message in diff.new_messages:
|
||||||
|
message_updates_counter += 1
|
||||||
|
self.updates_queue.put_nowait(
|
||||||
|
(
|
||||||
|
raw.types.UpdateNewMessage(
|
||||||
|
message=message,
|
||||||
|
pts=local_pts,
|
||||||
|
pts_count=-1
|
||||||
|
) if id == self.client.me.id else
|
||||||
|
raw.types.UpdateNewChannelMessage(
|
||||||
|
message=message,
|
||||||
|
pts=local_pts,
|
||||||
|
pts_count=-1
|
||||||
|
),
|
||||||
|
users,
|
||||||
|
chats
|
||||||
|
)
|
||||||
|
)
|
||||||
|
|
||||||
|
for update in diff.other_updates:
|
||||||
|
other_updates_counter += 1
|
||||||
|
self.updates_queue.put_nowait(
|
||||||
|
(update, users, chats)
|
||||||
|
)
|
||||||
|
|
||||||
|
if isinstance(diff, (raw.types.updates.Difference, raw.types.updates.ChannelDifference)):
|
||||||
|
break
|
||||||
|
|
||||||
|
await self.client.storage.update_state(None)
|
||||||
|
log.info("Recovered %s messages and %s updates.", message_updates_counter, other_updates_counter)
|
||||||
|
|
||||||
async def stop(self):
|
async def stop(self):
|
||||||
if not self.client.no_updates:
|
if not self.client.no_updates:
|
||||||
for i in range(self.client.workers):
|
for i in range(self.client.workers):
|
||||||
|
|
|
||||||
|
|
@ -27,6 +27,18 @@ from .sqlite_storage import SQLiteStorage
|
||||||
log = logging.getLogger(__name__)
|
log = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
UPDATE_STATE_SCHEMA = """
|
||||||
|
CREATE TABLE update_state
|
||||||
|
(
|
||||||
|
id INTEGER PRIMARY KEY,
|
||||||
|
pts INTEGER,
|
||||||
|
qts INTEGER,
|
||||||
|
date INTEGER,
|
||||||
|
seq INTEGER
|
||||||
|
);
|
||||||
|
"""
|
||||||
|
|
||||||
|
|
||||||
class FileStorage(SQLiteStorage):
|
class FileStorage(SQLiteStorage):
|
||||||
FILE_EXTENSION = ".session"
|
FILE_EXTENSION = ".session"
|
||||||
|
|
||||||
|
|
@ -50,6 +62,12 @@ class FileStorage(SQLiteStorage):
|
||||||
|
|
||||||
version += 1
|
version += 1
|
||||||
|
|
||||||
|
if version == 3:
|
||||||
|
with self.conn:
|
||||||
|
self.conn.executescript(UPDATE_STATE_SCHEMA)
|
||||||
|
|
||||||
|
version += 1
|
||||||
|
|
||||||
self.version(version)
|
self.version(version)
|
||||||
|
|
||||||
async def open(self):
|
async def open(self):
|
||||||
|
|
|
||||||
|
|
@ -76,6 +76,7 @@ class MongoStorage(Storage):
|
||||||
self._peer = database['peers']
|
self._peer = database['peers']
|
||||||
self._session = database['session']
|
self._session = database['session']
|
||||||
self._usernames = database['usernames']
|
self._usernames = database['usernames']
|
||||||
|
self._states = database['update_state']
|
||||||
self._remove_peers = remove_peers
|
self._remove_peers = remove_peers
|
||||||
|
|
||||||
async def open(self):
|
async def open(self):
|
||||||
|
|
@ -167,6 +168,16 @@ class MongoStorage(Storage):
|
||||||
bulk
|
bulk
|
||||||
)
|
)
|
||||||
|
|
||||||
|
async def update_state(self, value: Tuple[int, int, int, int, int] = object):
|
||||||
|
if value == object:
|
||||||
|
states = [[state['_id'],state['pts'],state['qts'],state['date'],state['seq']] async for state in self._states.find()]
|
||||||
|
return states if len(states) > 0 else None
|
||||||
|
else:
|
||||||
|
if value is None:
|
||||||
|
await self._states.drop()
|
||||||
|
else:
|
||||||
|
await self._states.update_one({'_id': value[0]}, {'$set': {'pts': value[1], 'qts': value[2], 'date': value[3], 'seq': value[4]}}, upsert=True)
|
||||||
|
|
||||||
async def get_peer_by_id(self, peer_id: int):
|
async def get_peer_by_id(self, peer_id: int):
|
||||||
# id, access_hash, type
|
# id, access_hash, type
|
||||||
r = await self._peer.find_one({'_id': peer_id}, {'_id': 1, 'access_hash': 1, 'type': 1})
|
r = await self._peer.find_one({'_id': peer_id}, {'_id': 1, 'access_hash': 1, 'type': 1})
|
||||||
|
|
|
||||||
|
|
@ -49,6 +49,15 @@ CREATE TABLE peers
|
||||||
last_update_on INTEGER NOT NULL DEFAULT (CAST(STRFTIME('%s', 'now') AS INTEGER))
|
last_update_on INTEGER NOT NULL DEFAULT (CAST(STRFTIME('%s', 'now') AS INTEGER))
|
||||||
);
|
);
|
||||||
|
|
||||||
|
CREATE TABLE update_state
|
||||||
|
(
|
||||||
|
id INTEGER PRIMARY KEY,
|
||||||
|
pts INTEGER,
|
||||||
|
qts INTEGER,
|
||||||
|
date INTEGER,
|
||||||
|
seq INTEGER
|
||||||
|
);
|
||||||
|
|
||||||
CREATE TABLE version
|
CREATE TABLE version
|
||||||
(
|
(
|
||||||
number INTEGER PRIMARY KEY
|
number INTEGER PRIMARY KEY
|
||||||
|
|
@ -110,7 +119,7 @@ def get_input_peer(peer_id: int, access_hash: int, peer_type: str):
|
||||||
|
|
||||||
|
|
||||||
class SQLiteStorage(Storage):
|
class SQLiteStorage(Storage):
|
||||||
VERSION = 3
|
VERSION = 4
|
||||||
USERNAME_TTL = 8 * 60 * 60
|
USERNAME_TTL = 8 * 60 * 60
|
||||||
|
|
||||||
def __init__(self, name: str):
|
def __init__(self, name: str):
|
||||||
|
|
@ -166,6 +175,24 @@ class SQLiteStorage(Storage):
|
||||||
usernames
|
usernames
|
||||||
)
|
)
|
||||||
|
|
||||||
|
async def update_state(self, value: Tuple[int, int, int, int, int] = object):
|
||||||
|
if value == object:
|
||||||
|
return self.conn.execute(
|
||||||
|
"SELECT id, pts, qts, date, seq FROM update_state"
|
||||||
|
).fetchall()
|
||||||
|
else:
|
||||||
|
with self.conn:
|
||||||
|
if value is None:
|
||||||
|
self.conn.execute(
|
||||||
|
"DELETE FROM update_state"
|
||||||
|
)
|
||||||
|
else:
|
||||||
|
self.conn.execute(
|
||||||
|
"REPLACE INTO update_state (id, pts, qts, date, seq)"
|
||||||
|
"VALUES (?, ?, ?, ?, ?)",
|
||||||
|
value
|
||||||
|
)
|
||||||
|
|
||||||
async def get_peer_by_id(self, peer_id: int):
|
async def get_peer_by_id(self, peer_id: int):
|
||||||
r = self.conn.execute(
|
r = self.conn.execute(
|
||||||
"SELECT id, access_hash, type FROM peers WHERE id = ?",
|
"SELECT id, access_hash, type FROM peers WHERE id = ?",
|
||||||
|
|
|
||||||
|
|
@ -19,6 +19,7 @@
|
||||||
|
|
||||||
import base64
|
import base64
|
||||||
import struct
|
import struct
|
||||||
|
from abc import abstractmethod
|
||||||
from typing import List, Tuple
|
from typing import List, Tuple
|
||||||
|
|
||||||
|
|
||||||
|
|
@ -51,6 +52,21 @@ class Storage:
|
||||||
async def update_usernames(self, usernames: List[Tuple[int, str]]):
|
async def update_usernames(self, usernames: List[Tuple[int, str]]):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
async def update_state(self, update_state: Tuple[int, int, int, int, int] = object):
|
||||||
|
"""Get or set the update state of the current session.
|
||||||
|
|
||||||
|
Parameters:
|
||||||
|
update_state (``Tuple[int, int, int, int, int]``): A tuple containing the update state to set.
|
||||||
|
Tuple must contain the following information:
|
||||||
|
- ``int``: The id of the entity.
|
||||||
|
- ``int``: The pts.
|
||||||
|
- ``int``: The qts.
|
||||||
|
- ``int``: The date.
|
||||||
|
- ``int``: The seq.
|
||||||
|
"""
|
||||||
|
raise NotImplementedError
|
||||||
|
|
||||||
async def get_peer_by_id(self, peer_id: int):
|
async def get_peer_by_id(self, peer_id: int):
|
||||||
raise NotImplementedError
|
raise NotImplementedError
|
||||||
|
|
||||||
|
|
|
||||||
Loading…
Reference in a new issue