diff --git a/pyrogram/dispatcher.py b/pyrogram/dispatcher.py index 5cb97b5e..c5f166d5 100644 --- a/pyrogram/dispatcher.py +++ b/pyrogram/dispatcher.py @@ -21,6 +21,7 @@ import asyncio import inspect import logging from collections import OrderedDict +from typing import Any import pyrogram from pyrogram import raw, types, utils @@ -337,59 +338,88 @@ class Dispatcher: async def handler_worker(self, lock: asyncio.Lock): while True: packet = await self.updates_queue.get() + if packet is None: break - await self._process_packet(packet, lock) - async def _process_packet( - self, - packet: tuple[raw.core.TLObject, dict[int, types.Update], dict[int, types.Update]], - lock: asyncio.Lock, + try: + await self._handle_packet(packet, lock) + except pyrogram.StopPropagation: + pass + except Exception as e: + log.exception(e) + finally: + self.updates_queue.task_done() + + async def _handle_packet(self, packet, lock: asyncio.Lock): + update, users, chats = packet + parser = self.update_parsers.get(type(update)) + + parsed_update, handler_type = ( + await parser(update, users, chats) + if parser is not None else (None, type(None)) + ) + async with lock: + await self._dispatch_to_handlers(update, users, chats, parsed_update, handler_type) + + + async def _dispatch_to_handlers( + self, update, users, chats, parsed_update, handler_type, + ): + for group in self.groups.values(): + for handler in group: + args = await self._match_handler( + handler, update, users, chats, parsed_update, handler_type, + ) + if args is None: + continue + + try: + await self._execute_handler(handler, *args) + except pyrogram.StopPropagation: + raise + except pyrogram.ContinuePropagation: + continue + except Exception as error: + if parsed_update is not None: + await self._handle_exception(parsed_update, error) + break + + async def _match_handler( + self, handler, update, users, chats, parsed_update, handler_type, ): try: - update, users, chats = packet - parser = self.update_parsers.get(type(update)) - - if parser is not None: - parsed_result = parser(update, users, chats) - if inspect.isawaitable(parsed_result): - parsed_update, handler_type = await parsed_result - else: - parsed_update, handler_type = parsed_result - else: - parsed_update, handler_type = (None, type(None)) - - async with lock: - for group in self.groups.values(): - for handler in group: - try: - if parsed_update is not None: - if isinstance(handler, handler_type) and await handler.check( - self.client, parsed_update - ): - await self._execute_callback(handler, parsed_update) - break - elif isinstance(handler, RawUpdateHandler): - await self._execute_callback(handler, update, users, chats) - break - except (pyrogram.StopPropagation, pyrogram.ContinuePropagation) as e: - if isinstance(e, pyrogram.StopPropagation): - raise - except Exception as exception: - if parsed_update is not None: - await self._handle_exception(parsed_update, exception) - except pyrogram.StopPropagation: - pass + if isinstance(handler, handler_type): + if await handler.check(self.client, parsed_update): + return (parsed_update,) + elif isinstance(handler, RawUpdateHandler): + if await handler.check(self.client, update): + return (update, users, chats) except Exception as e: log.exception(e) - finally: - self.updates_queue.task_done() - async def _handle_exception(self, parsed_update: types.Update, exception: Exception): + return None + + async def _execute_handler(self, handler, *args: Any): + if inspect.iscoroutinefunction(handler.callback): + await handler.callback(self.client, *args) + else: + await self.loop.run_in_executor( + self.client.executor, + handler.callback, + self.client, + *args + ) + + async def _handle_exception( + self, parsed_update: types.Update, exception: Exception, + ): handled_error = False for error_handler in self.error_handlers: try: - if await error_handler.check(self.client, parsed_update, exception): + if await error_handler.check( + self.client, parsed_update, exception, + ): handled_error = True break except pyrogram.StopPropagation: @@ -401,11 +431,3 @@ class Dispatcher: if not handled_error: log.exception("Unhandled exception: %s", exception) - - async def _execute_callback(self, handler: Handler, *args): - if inspect.iscoroutinefunction(handler.callback): - await handler.callback(self.client, *args) - else: - await self.client.loop.run_in_executor( - self.client.executor, handler.callback, self.client, *args - ) \ No newline at end of file