import asyncio
import inspect
import logging
import time
from typing import List, Tuple, Any

from motor.motor_asyncio import AsyncIOMotorClient
from pymongo import UpdateOne

from pyrogram.storage.storage import Storage
from pyrogram.storage.sqlite_storage import get_input_peer

log = logging.getLogger(__name__)


class MongoStorage(Storage):
    """Session storage backed by MongoDB (through Motor).

    Session fields live in a single document (``_id`` 0) of the ``session``
    collection; known peers live in the ``peers`` collection, keyed by peer id.

    Parameters:
        config (``dict``):
            Mongodb config as dict, e.g.:
            *dict(uri="mongodb://...", db_name="pyrofork-session", remove_peers=False)*.
            ``uri`` is required; ``db_name`` defaults to ``"pyrofork-session"``
            and ``remove_peers`` defaults to ``False``.
            Only applicable for new sessions.
    """

    lock: asyncio.Lock

    # Maximum age, in seconds, of a cached username -> peer mapping.
    USERNAME_TTL = 8 * 60 * 60

    def __init__(self, config: dict):
        super().__init__('')

        db_uri = config["uri"]
        db_name = config.get("db_name", "pyrofork-session")
        remove_peers = config.get("remove_peers", False)

        database = AsyncIOMotorClient(db_uri)[db_name]

        self.lock = asyncio.Lock()
        self.database = database
        self._peer = database['peers']
        self._session = database['session']
        self._remove_peers = remove_peers

    async def open(self):
        """Create the session document with default values if it is missing.

        Session document fields (mirroring the SQLite session schema):

            dc_id     INTEGER PRIMARY KEY,
            api_id    INTEGER,
            test_mode INTEGER,
            auth_key  BLOB,
            date      INTEGER NOT NULL,
            user_id   INTEGER,
            is_bot    INTEGER
        """
        if await self._session.find_one({'_id': 0}, {}):
            return

        await self._session.insert_one(
            {
                '_id': 0,
                'dc_id': 2,
                'api_id': None,
                'test_mode': None,
                'auth_key': b'',
                'date': 0,
                'user_id': 0,
                'is_bot': 0,
            }
        )

    async def save(self):
        """No-op: every write goes straight to the database."""
        pass

    async def close(self):
        """No-op: nothing is released here."""
        pass

    async def delete(self):
        """Delete the session document and, if configured, all stored peers.

        This is best-effort: failures are logged and not propagated.
        """
        try:
            await self._session.delete_one({'_id': 0})
            if self._remove_peers:
                # ``Collection.remove`` no longer exists in current
                # PyMongo/Motor; ``delete_many({})`` empties the collection.
                await self._peer.delete_many({})
        except Exception:
            log.warning("Failed to delete MongoDB session data", exc_info=True)

    async def update_peers(self, peers: List[Tuple[int, int, str, str, str]]):
        """Upsert peers given as (id, access_hash, type, username, phone_number)."""
        now = int(time.time())

        operations = [
            UpdateOne(
                {'_id': peer[0]},
                {'$set': {
                    'access_hash': peer[1],
                    'type': peer[2],
                    'username': peer[3],
                    'phone_number': peer[4],
                    'last_update_on': now
                }},
                upsert=True
            )
            for peer in peers
        ]

        # bulk_write rejects an empty list of operations.
        if not operations:
            return

        await self._peer.bulk_write(operations)

    async def get_peer_by_id(self, peer_id: int):
        """Return the input peer stored under ``peer_id``.

        Raises:
            KeyError: If no peer with this id is stored.
        """
        r = await self._peer.find_one(
            {'_id': peer_id},
            {'_id': 1, 'access_hash': 1, 'type': 1}
        )

        if not r:
            raise KeyError(f"ID not found: {peer_id}")

        # Read fields by name instead of relying on document field order.
        return get_input_peer(r['_id'], r['access_hash'], r['type'])

    async def get_peer_by_username(self, username: str):
        """Return the input peer stored under ``username``.

        Raises:
            KeyError: If the username is unknown, or if its record is older
                than ``USERNAME_TTL`` seconds.
        """
        r = await self._peer.find_one(
            {'username': username},
            {'_id': 1, 'access_hash': 1, 'type': 1, 'last_update_on': 1}
        )

        if r is None:
            raise KeyError(f"Username not found: {username}")

        if abs(time.time() - r['last_update_on']) > self.USERNAME_TTL:
            raise KeyError(f"Username expired: {username}")

        return get_input_peer(r['_id'], r['access_hash'], r['type'])

    async def get_peer_by_phone_number(self, phone_number: str):
        """Return the input peer stored under ``phone_number``.

        Raises:
            KeyError: If no peer with this phone number is stored.
        """
        r = await self._peer.find_one(
            {'phone_number': phone_number},
            {'_id': 1, 'access_hash': 1, 'type': 1}
        )

        if r is None:
            raise KeyError(f"Phone number not found: {phone_number}")

        # Unpacking the document itself would pass its keys, not its values.
        return get_input_peer(r['_id'], r['access_hash'], r['type'])

    async def _get(self):
        """Read the session field named after the public accessor in use.

        The field name is the name of the function two frames up the call
        stack (accessor -> ``_accessor`` -> ``_get``). Returns ``None`` when
        the session document or the field is missing.
        """
        # context=0 skips loading source lines; only the name is needed.
        attr = inspect.stack(0)[2].function
        d = await self._session.find_one({'_id': 0}, {attr: 1})

        if not d:
            return None

        return d.get(attr)

    async def _set(self, value: Any):
        """Write ``value`` to the session field named after the public accessor.

        The field name is resolved the same way as in ``_get``. The session
        document is created if it does not exist yet.
        """
        attr = inspect.stack(0)[2].function
        await self._session.update_one(
            {'_id': 0},
            {'$set': {attr: value}},
            upsert=True
        )

    async def _accessor(self, value: Any = object):
        """Read the field when called without a value, otherwise write it.

        The ``object`` class itself is the "no value given" sentinel, so it
        is tested by identity; ``None`` remains a storable value.
        """
        if value is object:
            return await self._get()

        return await self._set(value)

    async def dc_id(self, value: int = object):
        """Get or set the ``dc_id`` session field."""
        return await self._accessor(value)

    async def api_id(self, value: int = object):
        """Get or set the ``api_id`` session field."""
        return await self._accessor(value)

    async def test_mode(self, value: bool = object):
        """Get or set the ``test_mode`` session field."""
        return await self._accessor(value)

    async def auth_key(self, value: bytes = object):
        """Get or set the ``auth_key`` session field."""
        return await self._accessor(value)

    async def date(self, value: int = object):
        """Get or set the ``date`` session field."""
        return await self._accessor(value)

    async def user_id(self, value: int = object):
        """Get or set the ``user_id`` session field."""
        return await self._accessor(value)

    async def is_bot(self, value: bool = object):
        """Get or set the ``is_bot`` session field."""
        return await self._accessor(value)