From 23c06807965f731ec88b1a896837157fbf5c114c Mon Sep 17 00:00:00 2001 From: yasirarism <55983182+yasirarism@users.noreply.github.com> Date: Mon, 17 Apr 2023 09:03:39 +0700 Subject: [PATCH] Add Session Storage (#54) --- database/session_db.py | 156 ++++++++++++++++++++++++++++++++++++++++ misskaty/__init__.py | 15 ++-- misskaty/plugins/afk.py | 1 - 3 files changed, 164 insertions(+), 8 deletions(-) create mode 100644 database/session_db.py diff --git a/database/session_db.py b/database/session_db.py new file mode 100644 index 00000000..c7257bcc --- /dev/null +++ b/database/session_db.py @@ -0,0 +1,156 @@ +import asyncio +import inspect +import time +from typing import List, Tuple, Any + +from motor.motor_asyncio import AsyncIOMotorDatabase +from pymongo import UpdateOne +from pyrogram.storage.storage import Storage +from pyrogram.storage.sqlite_storage import get_input_peer + + +class MongoStorage(Storage): + """ + database: motor.motor_asyncio.AsyncIOMotorDatabase + required database object of motor + + remove_peers: bool = False + remove peers collection on logout (by default, it will not remove peers) + """ + + lock: asyncio.Lock + USERNAME_TTL = 8 * 60 * 60 + + def __init__(self, database: AsyncIOMotorDatabase, remove_peers: bool = False): + super().__init__("") + self.lock = asyncio.Lock() + self.database = database + self._peer = database["peers"] + self._session = database["session"] + self._remove_peers = remove_peers + + async def open(self): + """ + + dc_id INTEGER PRIMARY KEY, + api_id INTEGER, + test_mode INTEGER, + auth_key BLOB, + date INTEGER NOT NULL, + user_id INTEGER, + is_bot INTEGER + """ + if await self._session.find_one({"_id": 0}, {}): + return + await self._session.insert_one( + { + "_id": 0, + "dc_id": 2, + "api_id": None, + "test_mode": None, + "auth_key": b"", + "date": 0, + "user_id": 0, + "is_bot": 0, + } + ) + + async def save(self): + pass + + async def close(self): + pass + + async def delete(self): + try: + await self._session.delete_one({"_id": 0}) + if self._remove_peers: + await self._peer.remove({}) + except Exception: + return + + async def update_peers(self, peers: List[Tuple[int, int, str, str, str]]): + """(id, access_hash, type, username, phone_number)""" + s = int(time.time()) + if bulk := [ + UpdateOne( + {"_id": i[0]}, + { + "$set": { + "access_hash": i[1], + "type": i[2], + "username": i[3], + "phone_number": i[4], + "last_update_on": s, + } + }, + upsert=True, + ) + for i in peers + ]: + await self._peer.bulk_write(bulk) + else: + return + + async def get_peer_by_id(self, peer_id: int): + # id, access_hash, type + r = await self._peer.find_one({"_id": peer_id}, {"_id": 1, "access_hash": 1, "type": 1}) + if not r: + raise KeyError(f"ID not found: {peer_id}") + return get_input_peer(*r.values()) + + async def get_peer_by_username(self, username: str): + # id, access_hash, type, last_update_on, + r = await self._peer.find_one({"username": username}, {"_id": 1, "access_hash": 1, "type": 1, "last_update_on": 1}) + + if r is None: + raise KeyError(f"Username not found: {username}") + + if abs(time.time() - r["last_update_on"]) > self.USERNAME_TTL: + raise KeyError(f"Username expired: {username}") + + return get_input_peer(*list(r.values())[:3]) + + async def get_peer_by_phone_number(self, phone_number: str): + # _id, access_hash, type, + r = await self._peer.find_one({"phone_number": phone_number}, {"_id": 1, "access_hash": 1, "type": 1}) + + if r is None: + raise KeyError(f"Phone number not found: {phone_number}") + + return get_input_peer(*r) + + async def _get(self): + attr = inspect.stack()[2].function + d = await self._session.find_one({"_id": 0}, {attr: 1}) + if not d: + return + return d[attr] + + async def _set(self, value: Any): + attr = inspect.stack()[2].function + await self._session.update_one({"_id": 0}, {"$set": {attr: value}}, upsert=True) + + async def _accessor(self, value: Any = object): + return await self._get() if value == object else await self._set(value) + + async def dc_id(self, value: int = object): + return await self._accessor(value) + + async def api_id(self, value: int = object): + return await self._accessor(value) + + async def test_mode(self, value: bool = object): + return await self._accessor(value) + + async def auth_key(self, value: bytes = object): + return await self._accessor(value) + + async def date(self, value: int = object): + return await self._accessor(value) + + async def user_id(self, value: int = object): + return await self._accessor(value) + + async def is_bot(self, value: bool = object): + return await self._accessor(value) diff --git a/misskaty/__init__.py b/misskaty/__init__.py index 3c472ddf..ca1e1d17 100644 --- a/misskaty/__init__.py +++ b/misskaty/__init__.py @@ -1,14 +1,15 @@ -import os import time from logging import ERROR, INFO, StreamHandler, basicConfig, getLogger, handlers -from misskaty.core import misskaty_patch -from apscheduler.schedulers.asyncio import AsyncIOScheduler from apscheduler.jobstores.mongodb import MongoDBJobStore +from apscheduler.schedulers.asyncio import AsyncIOScheduler +from motor.motor_asyncio import AsyncIOMotorClient from pymongo import MongoClient from pyrogram import Client -from misskaty.vars import API_HASH, API_ID, BOT_TOKEN, DATABASE_URI, USER_SESSION, TZ +from database.session_db import MongoStorage +from misskaty.core import misskaty_patch +from misskaty.vars import API_HASH, API_ID, BOT_TOKEN, DATABASE_URI, TZ, USER_SESSION basicConfig( level=INFO, @@ -28,6 +29,8 @@ HELPABLE = {} cleanmode = {} botStartTime = time.time() +pymonclient = MongoClient(DATABASE_URI) +mongo = AsyncIOMotorClient(DATABASE_URI) # Pyrogram Bot Client app = Client( @@ -36,6 +39,7 @@ app = Client( api_hash=API_HASH, bot_token=BOT_TOKEN, ) +app.storage = MongoStorage(mongo["MissKatyDB"], remove_peers=False) # Pyrogram UserBot Client user = Client( @@ -43,10 +47,7 @@ user = Client( session_string=USER_SESSION, ) -pymonclient = MongoClient(DATABASE_URI) - jobstores = {"default": MongoDBJobStore(client=pymonclient, database="MissKatyDB", collection="nightmode")} - scheduler = AsyncIOScheduler(jobstores=jobstores, timezone=TZ) app.start() diff --git a/misskaty/plugins/afk.py b/misskaty/plugins/afk.py index 84eaad6a..3dc3c5bd 100644 --- a/misskaty/plugins/afk.py +++ b/misskaty/plugins/afk.py @@ -20,7 +20,6 @@ from misskaty import app from misskaty.core.decorator.errors import capture_err from misskaty.core.decorator.permissions import adminsOnly from misskaty.core.decorator.ratelimiter import ratelimiter -from misskaty.core.misskaty_patch.bound import message from misskaty.helper import get_readable_time2 from misskaty.helper.localization import use_chat_lang from misskaty.vars import COMMAND_HANDLER