PyroFork: storage: mongo: Use existing database connection

support both async_pymongo and motor

Signed-off-by: wulan17 <wulan17@nusantararom.org>
This commit is contained in:
wulan17 2023-05-22 19:18:55 +07:00
parent 8ca1b35f81
commit b62adc60c8
No known key found for this signature in database
GPG key ID: 318CD6CD3A6AC0A5
3 changed files with 39 additions and 9 deletions

View file

@ -121,7 +121,7 @@ class Client(Methods):
Defaults to False.
mongodb (``dict``, *optional*):
Mongodb config as dict, e.g.: *dict(uri="mongodb://...", db_name="pyrofork-session", remove_peers=False)*.
Mongodb config as dict, e.g.: *dict(connection=async_pymongo.AsyncClient("mongodb://..."), remove_peers=False)*.
Only applicable for new sessions.
phone_number (``str``, *optional*):

View file

@ -3,7 +3,6 @@ import inspect
import time
from typing import List, Tuple, Any
from motor.motor_asyncio import AsyncIOMotorClient
from pymongo import UpdateOne
from pyrogram.storage.storage import Storage
from pyrogram.storage.sqlite_storage import get_input_peer
@ -17,9 +16,9 @@ class MongoStorage(Storage):
- name (`str`):
The session name used for database name.
- uri (`str`):
MongoDB Connection String URI.
For more information refer to https://www.mongodb.com/docs/manual/reference/connection-string
- connection (`obj`):
Mongodb connections object.
~async_pymongo.AsyncClient or ~motor.motor_asyncio.AsyncIOMotorClient object
- remove_peers (`bool`, *optional*):
Flag to remove data in the peers collection. If set to True,
@ -27,14 +26,46 @@ class MongoStorage(Storage):
If set to False or None, the data will not be removed.
Example:
session = MongoStorage("my_session", uri="mongodb://...", remove_peers=True)
import async_pymongo
conn = async_pymongo.AsyncClient("mongodb://...")
bot_db = conn["my_bot"]
session = MongoStorage("my_session", connection=conn, remove_peers=True)
"""
lock: asyncio.Lock
USERNAME_TTL = 8 * 60 * 60
def __init__(self, name: str, uri: str, remove_peers: bool = False):
def __init__(
self,
name: str,
connection: object,
remove_peers: bool = False
):
super().__init__(name=name)
database = AsyncIOMotorClient(uri)[name]
database = None
try:
import async_pymongo
except ImportError:
pass
else:
if isinstance(connection, async_pymongo.AsyncClient):
database = connection[name]
try:
from motor.motor_asyncio import AsyncIOMotorClient
except ImportError:
pass
else:
if database:
pass
elif isinstance(connection, AsyncIOMotorClient):
database = connection[name]
else:
raise Exception("Wrong connection object type! please pass valid connection object to connection parameter!")
if not database:
raise Exception("Please install one of following modules!: async_pymongo, motor")
self.lock = asyncio.Lock()
self.database = database
self._peer = database['peers']

View file

@ -1,5 +1,4 @@
aiosqlite>=0.17.0,<0.19.0
motor==3.1.2
pyaes==1.6.1
pymediainfo==6.0.1
pymongo==4.3.3