mirror of
https://github.com/Mayuri-Chan/pyrofork.git
synced 2025-12-29 12:04:51 +00:00
Fix: Register RawUpdateHandler and Refactor Dispatcher for Better Modularity.
Signed-off-by: Ling-ex <nekochan@rizkiofficial.com> Signed-off-by: wulan17 <wulan17@komodos.id>
This commit is contained in:
parent
315f61d1ed
commit
7d10a6fb9c
1 changed files with 72 additions and 50 deletions
|
|
@ -21,6 +21,7 @@ import asyncio
|
||||||
import inspect
|
import inspect
|
||||||
import logging
|
import logging
|
||||||
from collections import OrderedDict
|
from collections import OrderedDict
|
||||||
|
from typing import Any
|
||||||
|
|
||||||
import pyrogram
|
import pyrogram
|
||||||
from pyrogram import raw, types, utils
|
from pyrogram import raw, types, utils
|
||||||
|
|
@ -337,47 +338,12 @@ class Dispatcher:
|
||||||
async def handler_worker(self, lock: asyncio.Lock):
|
async def handler_worker(self, lock: asyncio.Lock):
|
||||||
while True:
|
while True:
|
||||||
packet = await self.updates_queue.get()
|
packet = await self.updates_queue.get()
|
||||||
|
|
||||||
if packet is None:
|
if packet is None:
|
||||||
break
|
break
|
||||||
await self._process_packet(packet, lock)
|
|
||||||
|
|
||||||
async def _process_packet(
|
|
||||||
self,
|
|
||||||
packet: tuple[raw.core.TLObject, dict[int, types.Update], dict[int, types.Update]],
|
|
||||||
lock: asyncio.Lock,
|
|
||||||
):
|
|
||||||
try:
|
try:
|
||||||
update, users, chats = packet
|
await self._handle_packet(packet, lock)
|
||||||
parser = self.update_parsers.get(type(update))
|
|
||||||
|
|
||||||
if parser is not None:
|
|
||||||
parsed_result = parser(update, users, chats)
|
|
||||||
if inspect.isawaitable(parsed_result):
|
|
||||||
parsed_update, handler_type = await parsed_result
|
|
||||||
else:
|
|
||||||
parsed_update, handler_type = parsed_result
|
|
||||||
else:
|
|
||||||
parsed_update, handler_type = (None, type(None))
|
|
||||||
|
|
||||||
async with lock:
|
|
||||||
for group in self.groups.values():
|
|
||||||
for handler in group:
|
|
||||||
try:
|
|
||||||
if parsed_update is not None:
|
|
||||||
if isinstance(handler, handler_type) and await handler.check(
|
|
||||||
self.client, parsed_update
|
|
||||||
):
|
|
||||||
await self._execute_callback(handler, parsed_update)
|
|
||||||
break
|
|
||||||
elif isinstance(handler, RawUpdateHandler):
|
|
||||||
await self._execute_callback(handler, update, users, chats)
|
|
||||||
break
|
|
||||||
except (pyrogram.StopPropagation, pyrogram.ContinuePropagation) as e:
|
|
||||||
if isinstance(e, pyrogram.StopPropagation):
|
|
||||||
raise
|
|
||||||
except Exception as exception:
|
|
||||||
if parsed_update is not None:
|
|
||||||
await self._handle_exception(parsed_update, exception)
|
|
||||||
except pyrogram.StopPropagation:
|
except pyrogram.StopPropagation:
|
||||||
pass
|
pass
|
||||||
except Exception as e:
|
except Exception as e:
|
||||||
|
|
@ -385,11 +351,75 @@ class Dispatcher:
|
||||||
finally:
|
finally:
|
||||||
self.updates_queue.task_done()
|
self.updates_queue.task_done()
|
||||||
|
|
||||||
async def _handle_exception(self, parsed_update: types.Update, exception: Exception):
|
async def _handle_packet(self, packet, lock: asyncio.Lock):
|
||||||
|
update, users, chats = packet
|
||||||
|
parser = self.update_parsers.get(type(update))
|
||||||
|
|
||||||
|
parsed_update, handler_type = (
|
||||||
|
await parser(update, users, chats)
|
||||||
|
if parser is not None else (None, type(None))
|
||||||
|
)
|
||||||
|
async with lock:
|
||||||
|
await self._dispatch_to_handlers(update, users, chats, parsed_update, handler_type)
|
||||||
|
|
||||||
|
|
||||||
|
async def _dispatch_to_handlers(
|
||||||
|
self, update, users, chats, parsed_update, handler_type,
|
||||||
|
):
|
||||||
|
for group in self.groups.values():
|
||||||
|
for handler in group:
|
||||||
|
args = await self._match_handler(
|
||||||
|
handler, update, users, chats, parsed_update, handler_type,
|
||||||
|
)
|
||||||
|
if args is None:
|
||||||
|
continue
|
||||||
|
|
||||||
|
try:
|
||||||
|
await self._execute_handler(handler, *args)
|
||||||
|
except pyrogram.StopPropagation:
|
||||||
|
raise
|
||||||
|
except pyrogram.ContinuePropagation:
|
||||||
|
continue
|
||||||
|
except Exception as error:
|
||||||
|
if parsed_update is not None:
|
||||||
|
await self._handle_exception(parsed_update, error)
|
||||||
|
break
|
||||||
|
|
||||||
|
async def _match_handler(
|
||||||
|
self, handler, update, users, chats, parsed_update, handler_type,
|
||||||
|
):
|
||||||
|
try:
|
||||||
|
if isinstance(handler, handler_type):
|
||||||
|
if await handler.check(self.client, parsed_update):
|
||||||
|
return (parsed_update,)
|
||||||
|
elif isinstance(handler, RawUpdateHandler):
|
||||||
|
if await handler.check(self.client, update):
|
||||||
|
return (update, users, chats)
|
||||||
|
except Exception as e:
|
||||||
|
log.exception(e)
|
||||||
|
|
||||||
|
return None
|
||||||
|
|
||||||
|
async def _execute_handler(self, handler, *args: Any):
|
||||||
|
if inspect.iscoroutinefunction(handler.callback):
|
||||||
|
await handler.callback(self.client, *args)
|
||||||
|
else:
|
||||||
|
await self.loop.run_in_executor(
|
||||||
|
self.client.executor,
|
||||||
|
handler.callback,
|
||||||
|
self.client,
|
||||||
|
*args
|
||||||
|
)
|
||||||
|
|
||||||
|
async def _handle_exception(
|
||||||
|
self, parsed_update: types.Update, exception: Exception,
|
||||||
|
):
|
||||||
handled_error = False
|
handled_error = False
|
||||||
for error_handler in self.error_handlers:
|
for error_handler in self.error_handlers:
|
||||||
try:
|
try:
|
||||||
if await error_handler.check(self.client, parsed_update, exception):
|
if await error_handler.check(
|
||||||
|
self.client, parsed_update, exception,
|
||||||
|
):
|
||||||
handled_error = True
|
handled_error = True
|
||||||
break
|
break
|
||||||
except pyrogram.StopPropagation:
|
except pyrogram.StopPropagation:
|
||||||
|
|
@ -401,11 +431,3 @@ class Dispatcher:
|
||||||
|
|
||||||
if not handled_error:
|
if not handled_error:
|
||||||
log.exception("Unhandled exception: %s", exception)
|
log.exception("Unhandled exception: %s", exception)
|
||||||
|
|
||||||
async def _execute_callback(self, handler: Handler, *args):
|
|
||||||
if inspect.iscoroutinefunction(handler.callback):
|
|
||||||
await handler.callback(self.client, *args)
|
|
||||||
else:
|
|
||||||
await self.client.loop.run_in_executor(
|
|
||||||
self.client.executor, handler.callback, self.client, *args
|
|
||||||
)
|
|
||||||
Loading…
Reference in a new issue