mirror of
https://github.com/Mayuri-Chan/pyrofork.git
synced 2025-12-29 12:04:51 +00:00
Pyrofork: Add Mongodb Session Storage
Signed-off-by: wulan17 <wulan17@nusantararom.org> Co-authored-by: wulan17 <wulan17@nusantararom.org>
This commit is contained in:
parent
91ff2dc82d
commit
6b29171e29
5 changed files with 198 additions and 2 deletions
|
|
@ -61,6 +61,27 @@ In case you don't want to have any session file saved to disk, you can use an in
|
|||
This storage engine is still backed by SQLite, but the database exists purely in memory. This means that, once you stop
|
||||
a client, the entire database is discarded and the session details used for logging in again will be lost forever.
|
||||
|
||||
Mongodb Storage
|
||||
^^^^^^^^^^^^^^^
|
||||
|
||||
In case you want to have persistent session but you don't have persistent storage you can use mongodb storage by passing
|
||||
mongodb config as ``dict`` to the ``mongodb`` parameter of the :obj:`~pyrogram.Client` constructor:
|
||||
.. code-block:: python
|
||||
|
||||
from pyrogram import Client
|
||||
|
||||
# uri (``str``):
|
||||
# mongodb database uri
|
||||
# db_name (``str``, *optional*):
|
||||
# custom database name, default = pyrofork-session
|
||||
# remove_peers (``bool``, *optional*):
|
||||
# remove peers collection on logout, default = False
|
||||
async with Client("my_account", mongodb=dict(uri="mongodb://...", db_name="pyrofork-session", remove_peers=False)) as app:
|
||||
print(await app.get_me())
|
||||
|
||||
This storage engine is backed by MongoDB, a session will be created and saved to mongodb database. Any subsequent client
|
||||
restart will make PyroFork search for a database named that way and the session database will be automatically loaded.
|
||||
|
||||
Session Strings
|
||||
---------------
|
||||
|
||||
|
|
|
|||
|
|
@ -48,7 +48,7 @@ from pyrogram.errors import (
|
|||
from pyrogram.handlers.handler import Handler
|
||||
from pyrogram.methods import Methods
|
||||
from pyrogram.session import Auth, Session
|
||||
from pyrogram.storage import FileStorage, MemoryStorage
|
||||
from pyrogram.storage import FileStorage, MemoryStorage, MongoStorage
|
||||
from pyrogram.types import User, TermsOfService
|
||||
from pyrogram.utils import ainput
|
||||
from .dispatcher import Dispatcher
|
||||
|
|
@ -120,6 +120,10 @@ class Client(Methods):
|
|||
pass to the ``session_string`` parameter.
|
||||
Defaults to False.
|
||||
|
||||
mongodb (``dict``, *optional*):
|
||||
Mongodb config as dict, e.g.: *dict(uri="mongodb://...", db_name="pyrofork-session", remove_peers=False)*.
|
||||
Only applicable for new sessions.
|
||||
|
||||
phone_number (``str``, *optional*):
|
||||
Pass the phone number as string (with the Country Code prefix included) to avoid entering it manually.
|
||||
Only applicable for new sessions.
|
||||
|
|
@ -203,6 +207,7 @@ class Client(Methods):
|
|||
bot_token: str = None,
|
||||
session_string: str = None,
|
||||
in_memory: bool = None,
|
||||
mongodb: dict = None,
|
||||
phone_number: str = None,
|
||||
phone_code: str = None,
|
||||
password: str = None,
|
||||
|
|
@ -230,6 +235,7 @@ class Client(Methods):
|
|||
self.bot_token = bot_token
|
||||
self.session_string = session_string
|
||||
self.in_memory = in_memory
|
||||
self.mongodb = mongodb
|
||||
self.phone_number = phone_number
|
||||
self.phone_code = phone_code
|
||||
self.password = password
|
||||
|
|
@ -248,6 +254,8 @@ class Client(Methods):
|
|||
self.storage = MemoryStorage(self.name, self.session_string)
|
||||
elif self.in_memory:
|
||||
self.storage = MemoryStorage(self.name)
|
||||
elif self.mongodb:
|
||||
self.storage = MongoStorage(self.mongodb)
|
||||
else:
|
||||
self.storage = FileStorage(self.name, self.workdir)
|
||||
|
||||
|
|
|
|||
|
|
@ -18,4 +18,5 @@
|
|||
|
||||
from .file_storage import FileStorage
|
||||
from .memory_storage import MemoryStorage
|
||||
from .mongo_storage import MongoStorage
|
||||
from .storage import Storage
|
||||
|
|
|
|||
164
pyrogram/storage/mongo_storage.py
Normal file
164
pyrogram/storage/mongo_storage.py
Normal file
|
|
@ -0,0 +1,164 @@
|
|||
import asyncio
|
||||
import inspect
|
||||
import time
|
||||
from typing import List, Tuple, Any
|
||||
|
||||
from motor.motor_asyncio import AsyncIOMotorClient
|
||||
from pymongo import UpdateOne
|
||||
from pyrogram.storage.storage import Storage
|
||||
from pyrogram.storage.sqlite_storage import get_input_peer
|
||||
|
||||
|
||||
class MongoStorage(Storage):
|
||||
"""
|
||||
config (``dict``)
|
||||
Mongodb config as dict, e.g.: *dict(uri="mongodb://...", db_name="pyrofork-session", remove_peers=False)*.
|
||||
Only applicable for new sessions.
|
||||
"""
|
||||
lock: asyncio.Lock
|
||||
USERNAME_TTL = 8 * 60 * 60
|
||||
|
||||
def __init__(self, config: dict):
|
||||
super().__init__('')
|
||||
db_name = "pyrofork-session"
|
||||
db_uri = config["uri"]
|
||||
remove_peers = False
|
||||
if "db_name" in config:
|
||||
db_name = config["db_name"]
|
||||
if "remove_peers" in config:
|
||||
remove_peers = config["remove_peers"]
|
||||
database = AsyncIOMotorClient(db_uri)[db_name]
|
||||
self.lock = asyncio.Lock()
|
||||
self.database = database
|
||||
self._peer = database['peers']
|
||||
self._session = database['session']
|
||||
self._remove_peers = remove_peers
|
||||
|
||||
async def open(self):
|
||||
"""
|
||||
|
||||
dc_id INTEGER PRIMARY KEY,
|
||||
api_id INTEGER,
|
||||
test_mode INTEGER,
|
||||
auth_key BLOB,
|
||||
date INTEGER NOT NULL,
|
||||
user_id INTEGER,
|
||||
is_bot INTEGER
|
||||
"""
|
||||
if await self._session.find_one({'_id': 0}, {}):
|
||||
return
|
||||
await self._session.insert_one(
|
||||
{
|
||||
'_id': 0,
|
||||
'dc_id': 2,
|
||||
'api_id': None,
|
||||
'test_mode': None,
|
||||
'auth_key': b'',
|
||||
'date': 0,
|
||||
'user_id': 0,
|
||||
'is_bot': 0,
|
||||
|
||||
}
|
||||
)
|
||||
|
||||
async def save(self):
|
||||
pass
|
||||
|
||||
async def close(self):
|
||||
pass
|
||||
|
||||
async def delete(self):
|
||||
try:
|
||||
await self._session.delete_one({'_id': 0})
|
||||
if self._remove_peers:
|
||||
await self._peer.remove({})
|
||||
except Exception as _:
|
||||
return
|
||||
|
||||
async def update_peers(self, peers: List[Tuple[int, int, str, str, str]]):
|
||||
"""(id, access_hash, type, username, phone_number)"""
|
||||
s = int(time.time())
|
||||
bulk = [
|
||||
UpdateOne(
|
||||
{'_id': i[0]},
|
||||
{'$set': {
|
||||
'access_hash': i[1],
|
||||
'type': i[2],
|
||||
'username': i[3],
|
||||
'phone_number': i[4],
|
||||
'last_update_on': s
|
||||
}},
|
||||
upsert=True
|
||||
) for i in peers
|
||||
]
|
||||
if not bulk:
|
||||
return
|
||||
await self._peer.bulk_write(
|
||||
bulk
|
||||
)
|
||||
|
||||
async def get_peer_by_id(self, peer_id: int):
|
||||
# id, access_hash, type
|
||||
r = await self._peer.find_one({'_id': peer_id}, {'_id': 1, 'access_hash': 1, 'type': 1})
|
||||
if not r:
|
||||
raise KeyError(f"ID not found: {peer_id}")
|
||||
return get_input_peer(*r.values())
|
||||
|
||||
async def get_peer_by_username(self, username: str):
|
||||
# id, access_hash, type, last_update_on,
|
||||
r = await self._peer.find_one({'username': username},
|
||||
{'_id': 1, 'access_hash': 1, 'type': 1, 'last_update_on': 1})
|
||||
|
||||
if r is None:
|
||||
raise KeyError(f"Username not found: {username}")
|
||||
|
||||
if abs(time.time() - r['last_update_on']) > self.USERNAME_TTL:
|
||||
raise KeyError(f"Username expired: {username}")
|
||||
|
||||
return get_input_peer(*list(r.values())[:3])
|
||||
|
||||
async def get_peer_by_phone_number(self, phone_number: str):
|
||||
|
||||
# _id, access_hash, type,
|
||||
r = await self._peer.find_one({'phone_number': phone_number},
|
||||
{'_id': 1, 'access_hash': 1, 'type': 1})
|
||||
|
||||
if r is None:
|
||||
raise KeyError(f"Phone number not found: {phone_number}")
|
||||
|
||||
return get_input_peer(*r)
|
||||
|
||||
async def _get(self):
|
||||
attr = inspect.stack()[2].function
|
||||
d = await self._session.find_one({'_id': 0}, {attr: 1})
|
||||
if not d:
|
||||
return
|
||||
return d[attr]
|
||||
|
||||
async def _set(self, value: Any):
|
||||
attr = inspect.stack()[2].function
|
||||
await self._session.update_one({'_id': 0}, {'$set': {attr: value}}, upsert=True)
|
||||
|
||||
async def _accessor(self, value: Any = object):
|
||||
return await self._get() if value == object else await self._set(value)
|
||||
|
||||
async def dc_id(self, value: int = object):
|
||||
return await self._accessor(value)
|
||||
|
||||
async def api_id(self, value: int = object):
|
||||
return await self._accessor(value)
|
||||
|
||||
async def test_mode(self, value: bool = object):
|
||||
return await self._accessor(value)
|
||||
|
||||
async def auth_key(self, value: bytes = object):
|
||||
return await self._accessor(value)
|
||||
|
||||
async def date(self, value: int = object):
|
||||
return await self._accessor(value)
|
||||
|
||||
async def user_id(self, value: int = object):
|
||||
return await self._accessor(value)
|
||||
|
||||
async def is_bot(self, value: bool = object):
|
||||
return await self._accessor(value)
|
||||
|
|
@ -1,4 +1,6 @@
|
|||
aiosqlite>=0.17.0,<0.19.0
|
||||
motor==3.1.2
|
||||
pyaes==1.6.1
|
||||
pymediainfo==6.0.1
|
||||
pymongo==4.3.3
|
||||
pysocks==1.7.1
|
||||
aiosqlite>=0.17.0,<0.19.0
|
||||
|
|
|
|||
Loading…
Reference in a new issue