From 6b29171e29d2c5dcc9d53ea10fbc668dd02a2e2e Mon Sep 17 00:00:00 2001 From: Animesh Murmu <47624427+animeshxd@users.noreply.github.com> Date: Wed, 10 May 2023 18:50:44 +0700 Subject: [PATCH] Pyrofork: Add Mongodb Session Storage Signed-off-by: wulan17 Co-authored-by: wulan17 --- docs/source/topics/storage-engines.rst | 21 ++++ pyrogram/client.py | 10 +- pyrogram/storage/__init__.py | 1 + pyrogram/storage/mongo_storage.py | 164 +++++++++++++++++++++++++ requirements.txt | 4 +- 5 files changed, 198 insertions(+), 2 deletions(-) create mode 100644 pyrogram/storage/mongo_storage.py diff --git a/docs/source/topics/storage-engines.rst b/docs/source/topics/storage-engines.rst index 4db1d9b0..6a6f76ca 100644 --- a/docs/source/topics/storage-engines.rst +++ b/docs/source/topics/storage-engines.rst @@ -61,6 +61,27 @@ In case you don't want to have any session file saved to disk, you can use an in This storage engine is still backed by SQLite, but the database exists purely in memory. This means that, once you stop a client, the entire database is discarded and the session details used for logging in again will be lost forever. +Mongodb Storage +^^^^^^^^^^^^^^^ + +In case you want to have persistent session but you don't have persistent storage you can use mongodb storage by passing +mongodb config as ``dict`` to the ``mongodb`` parameter of the :obj:`~pyrogram.Client` constructor: +.. code-block:: python + + from pyrogram import Client + + # uri (``str``): + # mongodb database uri + # db_name (``str``, *optional*): + # custom database name, default = pyrofork-session + # remove_peers (``bool``, *optional*): + # remove peers collection on logout, default = False + async with Client("my_account", mongodb=dict(uri="mongodb://...", db_name="pyrofork-session", remove_peers=False)) as app: + print(await app.get_me()) + +This storage engine is backed by MongoDB, a session will be created and saved to mongodb database. Any subsequent client +restart will make PyroFork search for a database named that way and the session database will be automatically loaded. + Session Strings --------------- diff --git a/pyrogram/client.py b/pyrogram/client.py index 3e3fbbf0..b9477f56 100644 --- a/pyrogram/client.py +++ b/pyrogram/client.py @@ -48,7 +48,7 @@ from pyrogram.errors import ( from pyrogram.handlers.handler import Handler from pyrogram.methods import Methods from pyrogram.session import Auth, Session -from pyrogram.storage import FileStorage, MemoryStorage +from pyrogram.storage import FileStorage, MemoryStorage, MongoStorage from pyrogram.types import User, TermsOfService from pyrogram.utils import ainput from .dispatcher import Dispatcher @@ -120,6 +120,10 @@ class Client(Methods): pass to the ``session_string`` parameter. Defaults to False. + mongodb (``dict``, *optional*): + Mongodb config as dict, e.g.: *dict(uri="mongodb://...", db_name="pyrofork-session", remove_peers=False)*. + Only applicable for new sessions. + phone_number (``str``, *optional*): Pass the phone number as string (with the Country Code prefix included) to avoid entering it manually. Only applicable for new sessions. @@ -203,6 +207,7 @@ class Client(Methods): bot_token: str = None, session_string: str = None, in_memory: bool = None, + mongodb: dict = None, phone_number: str = None, phone_code: str = None, password: str = None, @@ -230,6 +235,7 @@ class Client(Methods): self.bot_token = bot_token self.session_string = session_string self.in_memory = in_memory + self.mongodb = mongodb self.phone_number = phone_number self.phone_code = phone_code self.password = password @@ -248,6 +254,8 @@ class Client(Methods): self.storage = MemoryStorage(self.name, self.session_string) elif self.in_memory: self.storage = MemoryStorage(self.name) + elif self.mongodb: + self.storage = MongoStorage(self.mongodb) else: self.storage = FileStorage(self.name, self.workdir) diff --git a/pyrogram/storage/__init__.py b/pyrogram/storage/__init__.py index 2a43309a..09ff0e86 100644 --- a/pyrogram/storage/__init__.py +++ b/pyrogram/storage/__init__.py @@ -18,4 +18,5 @@ from .file_storage import FileStorage from .memory_storage import MemoryStorage +from .mongo_storage import MongoStorage from .storage import Storage diff --git a/pyrogram/storage/mongo_storage.py b/pyrogram/storage/mongo_storage.py new file mode 100644 index 00000000..49a11f07 --- /dev/null +++ b/pyrogram/storage/mongo_storage.py @@ -0,0 +1,164 @@ +import asyncio +import inspect +import time +from typing import List, Tuple, Any + +from motor.motor_asyncio import AsyncIOMotorClient +from pymongo import UpdateOne +from pyrogram.storage.storage import Storage +from pyrogram.storage.sqlite_storage import get_input_peer + + +class MongoStorage(Storage): + """ + config (``dict``) + Mongodb config as dict, e.g.: *dict(uri="mongodb://...", db_name="pyrofork-session", remove_peers=False)*. + Only applicable for new sessions. + """ + lock: asyncio.Lock + USERNAME_TTL = 8 * 60 * 60 + + def __init__(self, config: dict): + super().__init__('') + db_name = "pyrofork-session" + db_uri = config["uri"] + remove_peers = False + if "db_name" in config: + db_name = config["db_name"] + if "remove_peers" in config: + remove_peers = config["remove_peers"] + database = AsyncIOMotorClient(db_uri)[db_name] + self.lock = asyncio.Lock() + self.database = database + self._peer = database['peers'] + self._session = database['session'] + self._remove_peers = remove_peers + + async def open(self): + """ + + dc_id INTEGER PRIMARY KEY, + api_id INTEGER, + test_mode INTEGER, + auth_key BLOB, + date INTEGER NOT NULL, + user_id INTEGER, + is_bot INTEGER + """ + if await self._session.find_one({'_id': 0}, {}): + return + await self._session.insert_one( + { + '_id': 0, + 'dc_id': 2, + 'api_id': None, + 'test_mode': None, + 'auth_key': b'', + 'date': 0, + 'user_id': 0, + 'is_bot': 0, + + } + ) + + async def save(self): + pass + + async def close(self): + pass + + async def delete(self): + try: + await self._session.delete_one({'_id': 0}) + if self._remove_peers: + await self._peer.remove({}) + except Exception as _: + return + + async def update_peers(self, peers: List[Tuple[int, int, str, str, str]]): + """(id, access_hash, type, username, phone_number)""" + s = int(time.time()) + bulk = [ + UpdateOne( + {'_id': i[0]}, + {'$set': { + 'access_hash': i[1], + 'type': i[2], + 'username': i[3], + 'phone_number': i[4], + 'last_update_on': s + }}, + upsert=True + ) for i in peers + ] + if not bulk: + return + await self._peer.bulk_write( + bulk + ) + + async def get_peer_by_id(self, peer_id: int): + # id, access_hash, type + r = await self._peer.find_one({'_id': peer_id}, {'_id': 1, 'access_hash': 1, 'type': 1}) + if not r: + raise KeyError(f"ID not found: {peer_id}") + return get_input_peer(*r.values()) + + async def get_peer_by_username(self, username: str): + # id, access_hash, type, last_update_on, + r = await self._peer.find_one({'username': username}, + {'_id': 1, 'access_hash': 1, 'type': 1, 'last_update_on': 1}) + + if r is None: + raise KeyError(f"Username not found: {username}") + + if abs(time.time() - r['last_update_on']) > self.USERNAME_TTL: + raise KeyError(f"Username expired: {username}") + + return get_input_peer(*list(r.values())[:3]) + + async def get_peer_by_phone_number(self, phone_number: str): + + # _id, access_hash, type, + r = await self._peer.find_one({'phone_number': phone_number}, + {'_id': 1, 'access_hash': 1, 'type': 1}) + + if r is None: + raise KeyError(f"Phone number not found: {phone_number}") + + return get_input_peer(*r) + + async def _get(self): + attr = inspect.stack()[2].function + d = await self._session.find_one({'_id': 0}, {attr: 1}) + if not d: + return + return d[attr] + + async def _set(self, value: Any): + attr = inspect.stack()[2].function + await self._session.update_one({'_id': 0}, {'$set': {attr: value}}, upsert=True) + + async def _accessor(self, value: Any = object): + return await self._get() if value == object else await self._set(value) + + async def dc_id(self, value: int = object): + return await self._accessor(value) + + async def api_id(self, value: int = object): + return await self._accessor(value) + + async def test_mode(self, value: bool = object): + return await self._accessor(value) + + async def auth_key(self, value: bytes = object): + return await self._accessor(value) + + async def date(self, value: int = object): + return await self._accessor(value) + + async def user_id(self, value: int = object): + return await self._accessor(value) + + async def is_bot(self, value: bool = object): + return await self._accessor(value) diff --git a/requirements.txt b/requirements.txt index 5176ec0e..d0b2161a 100644 --- a/requirements.txt +++ b/requirements.txt @@ -1,4 +1,6 @@ +aiosqlite>=0.17.0,<0.19.0 +motor==3.1.2 pyaes==1.6.1 pymediainfo==6.0.1 +pymongo==4.3.3 pysocks==1.7.1 -aiosqlite>=0.17.0,<0.19.0