PyroFork: Use Dummy client object to check wether connection object is valid or not

Signed-off-by: wulan17 <wulan17@nusantararom.org>
This commit is contained in:
wulan17 2023-06-22 23:14:14 +07:00
parent 2a93257fa2
commit 56f4d2b8ad
No known key found for this signature in database
GPG key ID: 318CD6CD3A6AC0A5
2 changed files with 52 additions and 22 deletions

View file

@ -0,0 +1,44 @@
from pymongo.client_session import TransactionOptions
from bson.codec_options import CodecOptions
from pymongo.read_concern import ReadConcern
from pymongo.read_preferences import (
Nearest,
Primary,
PrimaryPreferred,
Secondary,
SecondaryPreferred,
)
from pymongo.write_concern import WriteConcern
from typing import Any, Optional, Union
try:
from typing import Protocol, runtime_checkable
except ImportError:
from typing_extensions import Protocol, runtime_checkable
ReadPreferences = Union[Primary, PrimaryPreferred, Secondary, SecondaryPreferred, Nearest]
@runtime_checkable
class DummyMongoClient(Protocol):
def __init__(self, *args: Any, **kwargs: Any) -> None:
raise NotImplementedError
def get_database(
self,
name: Optional[str] = None,
*,
codec_options: Optional[CodecOptions] = None,
read_preference: Optional[ReadPreferences] = None,
write_concern: Optional[WriteConcern] = None,
read_concern: Optional[ReadConcern] = None,
):
raise NotImplementedError
async def start_session(
self,
*,
causal_consistency: Optional[bool] = None,
default_transaction_options: Optional[TransactionOptions] = None,
snapshot: bool = False,
):
raise NotImplementedError

View file

@ -3,7 +3,8 @@ import inspect
import time
from typing import List, Tuple, Any
from pymongo import UpdateOne
from .dummy_client import DummyMongoClient
from pymongo import MongoClient, UpdateOne
from pyrogram.storage.storage import Storage
from pyrogram.storage.sqlite_storage import get_input_peer
@ -38,33 +39,18 @@ class MongoStorage(Storage):
def __init__(
self,
name: str,
connection: object,
connection: DummyMongoClient,
remove_peers: bool = False
):
super().__init__(name=name)
database = None
try:
import async_pymongo
except ImportError:
pass
else:
if isinstance(connection, async_pymongo.AsyncClient):
database = connection[name]
try:
from motor.motor_asyncio import AsyncIOMotorClient
except ImportError:
pass
if isinstance(connection, DummyMongoClient):
if isinstance(connection, MongoClient):
raise Exception("Pymongo MongoClient object is not supported! please use async mongodb driver such as async_pymongo and motor.")
database = connection[name]
else:
if database:
pass
elif isinstance(connection, AsyncIOMotorClient):
database = connection[name]
else:
raise Exception("Wrong connection object type! please pass valid connection object to connection parameter!")
if not database:
raise Exception("Please install one of following modules!: async_pymongo, motor")
raise Exception("Wrong connection object type! please pass valid connection object to connection parameter!")
self.lock = asyncio.Lock()
self.database = database