diff --git a/misskaty/core/misskaty_patch/listen/listen.py b/misskaty/core/misskaty_patch/listen/listen.py index 1d8fe490..fc52eec7 100644 --- a/misskaty/core/misskaty_patch/listen/listen.py +++ b/misskaty/core/misskaty_patch/listen/listen.py @@ -1,133 +1,348 @@ """ pyromod - A monkeypatcher add-on for Pyrogram Copyright (C) 2020 Cezar H. - This file is part of pyromod. - pyromod is free software: you can redistribute it and/or modify it under the terms of the GNU General Public License as published by the Free Software Foundation, either version 3 of the License, or (at your option) any later version. - pyromod is distributed in the hope that it will be useful, but WITHOUT ANY WARRANTY; without even the implied warranty of MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the GNU General Public License for more details. - You should have received a copy of the GNU General Public License along with pyromod. If not, see . """ import asyncio -import functools import pyrogram -from ..utils import patch, patchable +from inspect import iscoroutinefunction +from typing import Optional, Callable, Union +from enum import Enum -loop = asyncio.get_event_loop() +from ..utils import patch, patchable, PyromodConfig -class ListenerCanceled(Exception): +class ListenerStopped(Exception): pass -pyrogram.errors.ListenerCanceled = ListenerCanceled +class ListenerTimeout(Exception): + pass + + +class ListenerTypes(Enum): + MESSAGE = "message" + CALLBACK_QUERY = "callback_query" @patch(pyrogram.client.Client) class Client: @patchable def __init__(self, *args, **kwargs): - self.listening = {} - self.using_mod = True - + self.listeners = {listener_type: {} for listener_type in ListenerTypes} self.old__init__(*args, **kwargs) @patchable - async def listen(self, chat_id, filters=None, timeout=None): - if type(chat_id) != int: - chat = await self.get_chat(chat_id) - chat_id = chat.id + async def listen( + self, + identifier: tuple, + filters=None, + listener_type=ListenerTypes.MESSAGE, + timeout=None, + unallowed_click_alert=True, + ): + if type(listener_type) != ListenerTypes: + raise TypeError("Parameter listener_type should be a" " value from pyromod.listen.ListenerTypes") - future = loop.create_future() - future.add_done_callback(functools.partial(self.clear_listener, chat_id)) - self.listening.update({chat_id: {"future": future, "filters": filters}}) - return await asyncio.wait_for(future, timeout) + future = self.loop.create_future() + future.add_done_callback(lambda f: self.stop_listening(identifier, listener_type)) + + listener_data = { + "future": future, + "filters": filters, + "unallowed_click_alert": unallowed_click_alert, + } + + self.listeners[listener_type].update({identifier: listener_data}) + + try: + return await asyncio.wait_for(future, timeout) + except asyncio.exceptions.TimeoutError: + if callable(PyromodConfig.timeout_handler): + PyromodConfig.timeout_handler(identifier, listener_data, timeout) + elif PyromodConfig.throw_exceptions: + raise ListenerTimeout(timeout) @patchable - async def ask(self, chat_id, text, filters=None, timeout=None, *args, **kwargs): - request = await self.send_message(chat_id, text, *args, **kwargs) - response = await self.listen(chat_id, filters, timeout) - response.request = request + async def ask( + self, + text, + identifier: tuple, + filters=None, + listener_type=ListenerTypes.MESSAGE, + timeout=None, + *args, + **kwargs, + ): + request = await self.send_message(identifier[0], text, *args, **kwargs) + response = await self.listen(identifier, filters, listener_type, timeout) + if response: + response.request = request + return response - @patchable - def clear_listener(self, chat_id, future): - if future == self.listening[chat_id]["future"]: - self.listening.pop(chat_id, None) + """ + needed for matching when message_id or + user_id is null, and to take precedence + """ @patchable - def cancel_listener(self, chat_id): - listener = self.listening.get(chat_id) - if not listener or listener["future"].done(): + def match_listener( + self, + data: Optional[tuple] = None, + listener_type: ListenerTypes = ListenerTypes.MESSAGE, + identifier_pattern: Optional[tuple] = None, + ) -> tuple: + if data: + listeners = self.listeners[listener_type] + # case with 3 args on identifier + # most probably waiting for a specific user + # to click a button in a specific message + if data in listeners: + return listeners[data], data + + # cases with 2 args on identifier + # (None, user, message) does not make + # sense since the message_id is not unique + elif (data[0], data[1], None) in listeners: + matched = (data[0], data[1], None) + elif (data[0], None, data[2]) in listeners: + matched = (data[0], None, data[2]) + + # cases with 1 arg on identifier + # (None, None, message) does not make sense as well + elif (data[0], None, None) in listeners: + matched = (data[0], None, None) + elif (None, data[1], None) in listeners: + matched = (None, data[1], None) + else: + return None, None + + return listeners[matched], matched + elif identifier_pattern: + + def match_identifier(pattern, identifier): + comparison = ( + pattern[0] in (identifier[0], None), + pattern[1] in (identifier[1], None), + pattern[2] in (identifier[2], None), + ) + return comparison == (True, True, True) + + for identifier, listener in self.listeners[listener_type].items(): + if match_identifier(identifier_pattern, identifier): + return listener, identifier + return None, None + + @patchable + def stop_listening( + self, + data: Optional[tuple] = None, + listener_type: ListenerTypes = ListenerTypes.MESSAGE, + identifier_pattern: Optional[tuple] = None, + ): + listener, identifier = self.match_listener(data, listener_type, identifier_pattern) + + if not listener: + return + elif listener["future"].done(): + del self.listeners[listener_type][identifier] return - listener["future"].set_exception(ListenerCanceled()) - self.clear_listener(chat_id, listener["future"]) + if callable(PyromodConfig.stopped_handler): + PyromodConfig.stopped_handler(identifier, listener) + elif PyromodConfig.throw_exceptions: + listener["future"].set_exception(ListenerStopped()) + + del self.listeners[listener_type][identifier] @patch(pyrogram.handlers.message_handler.MessageHandler) class MessageHandler: @patchable - def __init__(self, callback: callable, filters=None): - self.user_callback = callback - self.old__init__(self.resolve_listener, filters) + def __init__(self, callback: Callable, filters=None): + self.registered_handler = callback + self.old__init__(self.resolve_future, filters) @patchable - async def resolve_listener(self, client, message, *args): - listener = client.listening.get(message.chat.id) - if listener and not listener["future"].done(): - listener["future"].set_result(message) + async def check(self, client, message): + if user := getattr(message, "from_user", None): + user = user.id + listener = client.match_listener( + (message.chat.id, user, message.id), + ListenerTypes.MESSAGE, + )[0] + + listener_does_match = handler_does_match = False + + if listener: + filters = listener["filters"] + if callable(filters): + if iscoroutinefunction(filters.__call__): + listener_does_match = await filters(client, message) + else: + listener_does_match = await client.loop.run_in_executor(None, filters, client, message) + else: + listener_does_match = True + + if callable(self.filters): + if iscoroutinefunction(self.filters.__call__): + handler_does_match = await self.filters(client, message) + else: + handler_does_match = await client.loop.run_in_executor(None, self.filters, client, message) else: - if listener and listener["future"].done(): - client.clear_listener(message.chat.id, listener["future"]) - await self.user_callback(client, message, *args) + handler_does_match = True + + # let handler get the chance to handle if listener + # exists but its filters doesn't match + return listener_does_match or handler_does_match @patchable - async def check(self, client, update): - listener = client.listening.get(update.chat.id) + async def resolve_future(self, client, message, *args): + listener_type = ListenerTypes.MESSAGE + if user := getattr(message, "from_user", None): + user = user.id + listener, identifier = client.match_listener( + (message.chat.id, user, message.id), + listener_type, + ) + listener_does_match = False + if listener: + filters = listener["filters"] + if callable(filters): + if iscoroutinefunction(filters.__call__): + listener_does_match = await filters(client, message) + else: + listener_does_match = await client.loop.run_in_executor(None, filters, client, message) + else: + listener_does_match = True + + if listener_does_match: + if not listener["future"].done(): + listener["future"].set_result(message) + del client.listeners[listener_type][identifier] + raise pyrogram.StopPropagation + else: + await self.registered_handler(client, message, *args) + + +@patch(pyrogram.handlers.callback_query_handler.CallbackQueryHandler) +class CallbackQueryHandler: + @patchable + def __init__(self, callback: Callable, filters=None): + self.registered_handler = callback + self.old__init__(self.resolve_future, filters) + + @patchable + async def check(self, client, query): + chatID, mID = None, None + if message := getattr(query, "message", None): + chatID, mID = message.chat.id, message.id + listener = client.match_listener( + (chatID, query.from_user.id, mID), + ListenerTypes.CALLBACK_QUERY, + )[0] + + # managing unallowed user clicks + if PyromodConfig.unallowed_click_alert: + permissive_listener = client.match_listener( + identifier_pattern=( + chatID, + None, + mID, + ), + listener_type=ListenerTypes.CALLBACK_QUERY, + )[0] + + if (permissive_listener and not listener) and permissive_listener["unallowed_click_alert"]: + alert = permissive_listener["unallowed_click_alert"] if type(permissive_listener["unallowed_click_alert"]) == str else PyromodConfig.unallowed_click_alert_text + await query.answer(alert) + return False + + filters = listener["filters"] if listener else self.filters + + if callable(filters): + if iscoroutinefunction(filters.__call__): + return await filters(client, query) + else: + return await client.loop.run_in_executor(None, filters, client, query) + else: + return True + + @patchable + async def resolve_future(self, client, query, *args): + listener_type = ListenerTypes.CALLBACK_QUERY + chatID, mID = None, None + if message := getattr(query, "message", None): + chatID, mID = message.chat.id, message.id + listener, identifier = client.match_listener( + (chatID, query.from_user.id, mID), + listener_type, + ) if listener and not listener["future"].done(): - return await listener["filters"](client, update) if callable(listener["filters"]) else True + listener["future"].set_result(query) + del client.listeners[listener_type][identifier] + else: + await self.registered_handler(client, query, *args) - return await self.filters(client, update) if callable(self.filters) else True + +@patch(pyrogram.types.messages_and_media.message.Message) +class Message(pyrogram.types.messages_and_media.message.Message): + @patchable + async def wait_for_click( + self, + from_user_id: Optional[int] = None, + timeout: Optional[int] = None, + filters=None, + alert: Union[str, bool] = True, + ): + return await self._client.listen( + (self.chat.id, from_user_id, self.id), + listener_type=ListenerTypes.CALLBACK_QUERY, + timeout=timeout, + filters=filters, + unallowed_click_alert=alert, + ) @patch(pyrogram.types.user_and_chats.chat.Chat) class Chat(pyrogram.types.Chat): @patchable def listen(self, *args, **kwargs): - return self._client.listen(self.id, *args, **kwargs) + return self._client.listen((self.id, None, None), *args, **kwargs) @patchable - def ask(self, *args, **kwargs): - return self._client.ask(self.id, *args, **kwargs) + def ask(self, text, *args, **kwargs): + return self._client.ask(text, (self.id, None, None), *args, **kwargs) @patchable - def cancel_listener(self): - return self._client.cancel_listener(self.id) + def stop_listening(self, *args, **kwargs): + return self._client.stop_listening(*args, identifier_pattern=(self.id, None, None), **kwargs) @patch(pyrogram.types.user_and_chats.user.User) class User(pyrogram.types.User): @patchable def listen(self, *args, **kwargs): - return self._client.listen(self.id, *args, **kwargs) + return self._client.listen((None, self.id, None), *args, **kwargs) @patchable - def ask(self, *args, **kwargs): - return self._client.ask(self.id, *args, **kwargs) + def ask(self, text, *args, **kwargs): + return self._client.ask(text, (self.id, self.id, None), *args, **kwargs) @patchable - def cancel_listener(self): - return self._client.cancel_listener(self.id) + def stop_listening(self, *args, **kwargs): + return self._client.stop_listening(*args, identifier_pattern=(None, self.id, None), **kwargs) \ No newline at end of file diff --git a/misskaty/core/misskaty_patch/utils/__init__.py b/misskaty/core/misskaty_patch/utils/__init__.py index cefd2cad..7100e092 100644 --- a/misskaty/core/misskaty_patch/utils/__init__.py +++ b/misskaty/core/misskaty_patch/utils/__init__.py @@ -1 +1 @@ -from .utils import patch, patchable +from .utils import patch, patchable, PyromodConfig diff --git a/misskaty/core/misskaty_patch/utils/utils.py b/misskaty/core/misskaty_patch/utils/utils.py index 22022607..8f26cf59 100644 --- a/misskaty/core/misskaty_patch/utils/utils.py +++ b/misskaty/core/misskaty_patch/utils/utils.py @@ -19,6 +19,14 @@ along with pyromod. If not, see . """ +class PyromodConfig: + timeout_handler = None + stopped_handler = None + throw_exceptions = True + unallowed_click_alert = True + unallowed_click_alert_text = "[misskaty] You're not authorized to click this button." + + def patch(obj): def is_patchable(item): return getattr(item[1], "patchable", False) @@ -26,7 +34,8 @@ def patch(obj): def wrapper(container): for name, func in filter(is_patchable, container.__dict__.items()): old = getattr(obj, name, None) - setattr(obj, f"old{name}", old) + if old is not None: # Not adding 'old' to new func + setattr(obj, f"old{name}", old) setattr(obj, name, func) return container @@ -35,4 +44,4 @@ def patch(obj): def patchable(func): func.patchable = True - return func + return func \ No newline at end of file diff --git a/misskaty/vars.py b/misskaty/vars.py index 231b73f9..ad16b077 100644 --- a/misskaty/vars.py +++ b/misskaty/vars.py @@ -57,7 +57,6 @@ SUDO = list( } ) SUPPORT_CHAT = environ.get("SUPPORT_CHAT", "YasirPediaChannel") -NIGHTMODE = environ.get("NIGHTMODE", False) OPENAI_API = getConfig("OPENAI_API") ## Config For AUtoForwarder