diff --git a/pyrogram/storage/dummy_client.py b/pyrogram/storage/dummy_client.py new file mode 100644 index 00000000..30929c01 --- /dev/null +++ b/pyrogram/storage/dummy_client.py @@ -0,0 +1,44 @@ +from pymongo.client_session import TransactionOptions +from bson.codec_options import CodecOptions +from pymongo.read_concern import ReadConcern +from pymongo.read_preferences import ( + Nearest, + Primary, + PrimaryPreferred, + Secondary, + SecondaryPreferred, +) +from pymongo.write_concern import WriteConcern +from typing import Any, Optional, Union + +try: + from typing import Protocol, runtime_checkable +except ImportError: + from typing_extensions import Protocol, runtime_checkable + +ReadPreferences = Union[Primary, PrimaryPreferred, Secondary, SecondaryPreferred, Nearest] + +@runtime_checkable +class DummyMongoClient(Protocol): + def __init__(self, *args: Any, **kwargs: Any) -> None: + raise NotImplementedError + + def get_database( + self, + name: Optional[str] = None, + *, + codec_options: Optional[CodecOptions] = None, + read_preference: Optional[ReadPreferences] = None, + write_concern: Optional[WriteConcern] = None, + read_concern: Optional[ReadConcern] = None, + ): + raise NotImplementedError + + async def start_session( + self, + *, + causal_consistency: Optional[bool] = None, + default_transaction_options: Optional[TransactionOptions] = None, + snapshot: bool = False, + ): + raise NotImplementedError diff --git a/pyrogram/storage/mongo_storage.py b/pyrogram/storage/mongo_storage.py index 416e6453..5240917e 100644 --- a/pyrogram/storage/mongo_storage.py +++ b/pyrogram/storage/mongo_storage.py @@ -3,7 +3,8 @@ import inspect import time from typing import List, Tuple, Any -from pymongo import UpdateOne +from .dummy_client import DummyMongoClient +from pymongo import MongoClient, UpdateOne from pyrogram.storage.storage import Storage from pyrogram.storage.sqlite_storage import get_input_peer @@ -38,33 +39,18 @@ class MongoStorage(Storage): def __init__( self, name: str, - connection: object, + connection: DummyMongoClient, remove_peers: bool = False ): super().__init__(name=name) database = None - try: - import async_pymongo - except ImportError: - pass - else: - if isinstance(connection, async_pymongo.AsyncClient): - database = connection[name] - try: - from motor.motor_asyncio import AsyncIOMotorClient - except ImportError: - pass + if isinstance(connection, DummyMongoClient): + if isinstance(connection, MongoClient): + raise Exception("Pymongo MongoClient object is not supported! please use async mongodb driver such as async_pymongo and motor.") + database = connection[name] else: - if database: - pass - elif isinstance(connection, AsyncIOMotorClient): - database = connection[name] - else: - raise Exception("Wrong connection object type! please pass valid connection object to connection parameter!") - - if not database: - raise Exception("Please install one of following modules!: async_pymongo, motor") + raise Exception("Wrong connection object type! please pass valid connection object to connection parameter!") self.lock = asyncio.Lock() self.database = database