Extract recovering gaps into a separate method

Signed-off-by: wulan17 <wulan17@nusantararom.org>
This commit is contained in:
KurimuzonAkuma 2024-08-28 21:07:22 +03:00 committed by wulan17
parent 930da9b858
commit ee8f11e2a0
No known key found for this signature in database
GPG key ID: 318CD6CD3A6AC0A5
2 changed files with 91 additions and 90 deletions

View file

@ -33,7 +33,7 @@ from importlib import import_module
from io import StringIO, BytesIO from io import StringIO, BytesIO
from mimetypes import MimeTypes from mimetypes import MimeTypes
from pathlib import Path from pathlib import Path
from typing import Union, List, Optional, Callable, AsyncGenerator, Type from typing import Union, List, Optional, Callable, AsyncGenerator, Type, Tuple
import pyrogram import pyrogram
from pyrogram import __version__, __license__ from pyrogram import __version__, __license__
@ -45,7 +45,7 @@ from pyrogram.errors import CDNFileHashMismatch
from pyrogram.errors import ( from pyrogram.errors import (
SessionPasswordNeeded, SessionPasswordNeeded,
VolumeLocNotFound, ChannelPrivate, VolumeLocNotFound, ChannelPrivate,
BadRequest BadRequest, ChannelInvalid
) )
from pyrogram.handlers.handler import Handler from pyrogram.handlers.handler import Handler
from pyrogram.methods import Methods from pyrogram.methods import Methods
@ -662,7 +662,8 @@ class Client(Methods):
)] )]
), ),
pts=pts - pts_count, pts=pts - pts_count,
limit=pts limit=pts,
force=False
) )
) )
except ChannelPrivate: except ChannelPrivate:
@ -710,6 +711,92 @@ class Client(Methods):
elif isinstance(updates, raw.types.UpdatesTooLong): elif isinstance(updates, raw.types.UpdatesTooLong):
log.info(updates) log.info(updates)
async def recover_gaps(self) -> Tuple[int, int]:
states = await self.storage.update_state()
message_updates_counter = 0
other_updates_counter = 0
if not states:
log.info("No states found, skipping recovery.")
return (message_updates_counter, other_updates_counter)
for state in states:
id, local_pts, _, local_date, _ = state
prev_pts = 0
while True:
try:
diff = await self.invoke(
raw.functions.updates.GetChannelDifference(
channel=await self.resolve_peer(id),
filter=raw.types.ChannelMessagesFilterEmpty(),
pts=local_pts,
limit=10000,
force=False
) if id < 0 else
raw.functions.updates.GetDifference(
pts=local_pts,
date=local_date,
qts=0
)
)
except (ChannelPrivate, ChannelInvalid):
break
if isinstance(diff, raw.types.updates.DifferenceEmpty):
break
elif isinstance(diff, raw.types.updates.DifferenceTooLong):
break
elif isinstance(diff, raw.types.updates.Difference):
local_pts = diff.state.pts
elif isinstance(diff, raw.types.updates.DifferenceSlice):
local_pts = diff.intermediate_state.pts
local_date = diff.intermediate_state.date
if prev_pts == local_pts:
break
prev_pts = local_pts
elif isinstance(diff, raw.types.updates.ChannelDifferenceEmpty):
break
elif isinstance(diff, raw.types.updates.ChannelDifferenceTooLong):
break
elif isinstance(diff, raw.types.updates.ChannelDifference):
local_pts = diff.pts
users = {i.id: i for i in diff.users}
chats = {i.id: i for i in diff.chats}
for message in diff.new_messages:
message_updates_counter += 1
self.dispatcher.updates_queue.put_nowait(
(
raw.types.UpdateNewMessage(
message=message,
pts=local_pts,
pts_count=-1
),
users,
chats
)
)
for update in diff.other_updates:
other_updates_counter += 1
self.dispatcher.updates_queue.put_nowait(
(update, users, chats)
)
if isinstance(diff, (raw.types.updates.Difference, raw.types.updates.ChannelDifference)):
break
await self.storage.update_state(id)
log.info("Recovered %s messages and %s updates.", message_updates_counter, other_updates_counter)
return (message_updates_counter, other_updates_counter)
async def load_session(self): async def load_session(self):
await self.storage.open() await self.storage.open()

View file

@ -274,93 +274,7 @@ class Dispatcher:
log.info("Started %s HandlerTasks", self.client.workers) log.info("Started %s HandlerTasks", self.client.workers)
if not self.client.skip_updates: if not self.client.skip_updates:
states = await self.client.storage.update_state() await self.client.recover_gaps()
if not states:
log.info("No states found, skipping recovery.")
return
message_updates_counter = 0
other_updates_counter = 0
for state in states:
id, local_pts, _, local_date, _ = state
prev_pts = 0
while True:
try:
diff = await self.client.invoke(
raw.functions.updates.GetChannelDifference(
channel=await self.client.resolve_peer(id),
filter=raw.types.ChannelMessagesFilterEmpty(),
pts=local_pts,
limit=10000
) if id < 0 else
raw.functions.updates.GetDifference(
pts=local_pts,
date=local_date,
qts=0
)
)
except (errors.ChannelPrivate, errors.ChannelInvalid):
break
if isinstance(diff, raw.types.updates.DifferenceEmpty):
break
elif isinstance(diff, raw.types.updates.DifferenceTooLong):
break
elif isinstance(diff, raw.types.updates.Difference):
local_pts = diff.state.pts
elif isinstance(diff, raw.types.updates.DifferenceSlice):
local_pts = diff.intermediate_state.pts
local_date = diff.intermediate_state.date
if prev_pts == local_pts:
break
prev_pts = local_pts
elif isinstance(diff, raw.types.updates.ChannelDifferenceEmpty):
break
elif isinstance(diff, raw.types.updates.ChannelDifferenceTooLong):
break
elif isinstance(diff, raw.types.updates.ChannelDifference):
local_pts = diff.pts
users = {i.id: i for i in diff.users}
chats = {i.id: i for i in diff.chats}
for message in diff.new_messages:
message_updates_counter += 1
self.updates_queue.put_nowait(
(
raw.types.UpdateNewMessage(
message=message,
pts=local_pts,
pts_count=-1
) if id == self.client.me.id else
raw.types.UpdateNewChannelMessage(
message=message,
pts=local_pts,
pts_count=-1
),
users,
chats
)
)
for update in diff.other_updates:
other_updates_counter += 1
self.updates_queue.put_nowait(
(update, users, chats)
)
if isinstance(diff, (raw.types.updates.Difference, raw.types.updates.ChannelDifference)):
break
await self.client.storage.update_state(id)
log.info("Recovered %s messages and %s updates.", message_updates_counter, other_updates_counter)
async def stop(self): async def stop(self):
if not self.client.no_updates: if not self.client.no_updates: