From e3d134492b089d36eb78b249301f6545e03d7e7d Mon Sep 17 00:00:00 2001 From: Animesh Murmu <47624427+animeshxd@users.noreply.github.com> Date: Wed, 10 May 2023 18:50:44 +0700 Subject: [PATCH] Pyrofork: Add Mongodb Session Storage Signed-off-by: wulan17 Co-authored-by: wulan17 --- pyrogram/client.py | 10 +- pyrogram/storage/__init__.py | 1 + pyrogram/storage/mongo_storage.py | 164 ++++++++++++++++++++++++++++++ requirements.txt | 4 +- 4 files changed, 177 insertions(+), 2 deletions(-) create mode 100644 pyrogram/storage/mongo_storage.py diff --git a/pyrogram/client.py b/pyrogram/client.py index 3e3fbbf0..b9477f56 100644 --- a/pyrogram/client.py +++ b/pyrogram/client.py @@ -48,7 +48,7 @@ from pyrogram.errors import ( from pyrogram.handlers.handler import Handler from pyrogram.methods import Methods from pyrogram.session import Auth, Session -from pyrogram.storage import FileStorage, MemoryStorage +from pyrogram.storage import FileStorage, MemoryStorage, MongoStorage from pyrogram.types import User, TermsOfService from pyrogram.utils import ainput from .dispatcher import Dispatcher @@ -120,6 +120,10 @@ class Client(Methods): pass to the ``session_string`` parameter. Defaults to False. + mongodb (``dict``, *optional*): + Mongodb config as dict, e.g.: *dict(uri="mongodb://...", db_name="pyrofork-session", remove_peers=False)*. + Only applicable for new sessions. + phone_number (``str``, *optional*): Pass the phone number as string (with the Country Code prefix included) to avoid entering it manually. Only applicable for new sessions. @@ -203,6 +207,7 @@ class Client(Methods): bot_token: str = None, session_string: str = None, in_memory: bool = None, + mongodb: dict = None, phone_number: str = None, phone_code: str = None, password: str = None, @@ -230,6 +235,7 @@ class Client(Methods): self.bot_token = bot_token self.session_string = session_string self.in_memory = in_memory + self.mongodb = mongodb self.phone_number = phone_number self.phone_code = phone_code self.password = password @@ -248,6 +254,8 @@ class Client(Methods): self.storage = MemoryStorage(self.name, self.session_string) elif self.in_memory: self.storage = MemoryStorage(self.name) + elif self.mongodb: + self.storage = MongoStorage(self.mongodb) else: self.storage = FileStorage(self.name, self.workdir) diff --git a/pyrogram/storage/__init__.py b/pyrogram/storage/__init__.py index 2a43309a..09ff0e86 100644 --- a/pyrogram/storage/__init__.py +++ b/pyrogram/storage/__init__.py @@ -18,4 +18,5 @@ from .file_storage import FileStorage from .memory_storage import MemoryStorage +from .mongo_storage import MongoStorage from .storage import Storage diff --git a/pyrogram/storage/mongo_storage.py b/pyrogram/storage/mongo_storage.py new file mode 100644 index 00000000..49a11f07 --- /dev/null +++ b/pyrogram/storage/mongo_storage.py @@ -0,0 +1,164 @@ +import asyncio +import inspect +import time +from typing import List, Tuple, Any + +from motor.motor_asyncio import AsyncIOMotorClient +from pymongo import UpdateOne +from pyrogram.storage.storage import Storage +from pyrogram.storage.sqlite_storage import get_input_peer + + +class MongoStorage(Storage): + """ + config (``dict``) + Mongodb config as dict, e.g.: *dict(uri="mongodb://...", db_name="pyrofork-session", remove_peers=False)*. + Only applicable for new sessions. + """ + lock: asyncio.Lock + USERNAME_TTL = 8 * 60 * 60 + + def __init__(self, config: dict): + super().__init__('') + db_name = "pyrofork-session" + db_uri = config["uri"] + remove_peers = False + if "db_name" in config: + db_name = config["db_name"] + if "remove_peers" in config: + remove_peers = config["remove_peers"] + database = AsyncIOMotorClient(db_uri)[db_name] + self.lock = asyncio.Lock() + self.database = database + self._peer = database['peers'] + self._session = database['session'] + self._remove_peers = remove_peers + + async def open(self): + """ + + dc_id INTEGER PRIMARY KEY, + api_id INTEGER, + test_mode INTEGER, + auth_key BLOB, + date INTEGER NOT NULL, + user_id INTEGER, + is_bot INTEGER + """ + if await self._session.find_one({'_id': 0}, {}): + return + await self._session.insert_one( + { + '_id': 0, + 'dc_id': 2, + 'api_id': None, + 'test_mode': None, + 'auth_key': b'', + 'date': 0, + 'user_id': 0, + 'is_bot': 0, + + } + ) + + async def save(self): + pass + + async def close(self): + pass + + async def delete(self): + try: + await self._session.delete_one({'_id': 0}) + if self._remove_peers: + await self._peer.remove({}) + except Exception as _: + return + + async def update_peers(self, peers: List[Tuple[int, int, str, str, str]]): + """(id, access_hash, type, username, phone_number)""" + s = int(time.time()) + bulk = [ + UpdateOne( + {'_id': i[0]}, + {'$set': { + 'access_hash': i[1], + 'type': i[2], + 'username': i[3], + 'phone_number': i[4], + 'last_update_on': s + }}, + upsert=True + ) for i in peers + ] + if not bulk: + return + await self._peer.bulk_write( + bulk + ) + + async def get_peer_by_id(self, peer_id: int): + # id, access_hash, type + r = await self._peer.find_one({'_id': peer_id}, {'_id': 1, 'access_hash': 1, 'type': 1}) + if not r: + raise KeyError(f"ID not found: {peer_id}") + return get_input_peer(*r.values()) + + async def get_peer_by_username(self, username: str): + # id, access_hash, type, last_update_on, + r = await self._peer.find_one({'username': username}, + {'_id': 1, 'access_hash': 1, 'type': 1, 'last_update_on': 1}) + + if r is None: + raise KeyError(f"Username not found: {username}") + + if abs(time.time() - r['last_update_on']) > self.USERNAME_TTL: + raise KeyError(f"Username expired: {username}") + + return get_input_peer(*list(r.values())[:3]) + + async def get_peer_by_phone_number(self, phone_number: str): + + # _id, access_hash, type, + r = await self._peer.find_one({'phone_number': phone_number}, + {'_id': 1, 'access_hash': 1, 'type': 1}) + + if r is None: + raise KeyError(f"Phone number not found: {phone_number}") + + return get_input_peer(*r) + + async def _get(self): + attr = inspect.stack()[2].function + d = await self._session.find_one({'_id': 0}, {attr: 1}) + if not d: + return + return d[attr] + + async def _set(self, value: Any): + attr = inspect.stack()[2].function + await self._session.update_one({'_id': 0}, {'$set': {attr: value}}, upsert=True) + + async def _accessor(self, value: Any = object): + return await self._get() if value == object else await self._set(value) + + async def dc_id(self, value: int = object): + return await self._accessor(value) + + async def api_id(self, value: int = object): + return await self._accessor(value) + + async def test_mode(self, value: bool = object): + return await self._accessor(value) + + async def auth_key(self, value: bytes = object): + return await self._accessor(value) + + async def date(self, value: int = object): + return await self._accessor(value) + + async def user_id(self, value: int = object): + return await self._accessor(value) + + async def is_bot(self, value: bool = object): + return await self._accessor(value) diff --git a/requirements.txt b/requirements.txt index 5176ec0e..d0b2161a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,6 @@ +aiosqlite>=0.17.0,<0.19.0 +motor==3.1.2 pyaes==1.6.1 pymediainfo==6.0.1 +pymongo==4.3.3 pysocks==1.7.1 -aiosqlite>=0.17.0,<0.19.0