# pyrofork/pyrogram/storage/mongo_storage.py
# Juan Simon .D 89feb4a4d8
# Pyrofork: Use session name as database name, add some parameters informations and some cleanup (#3)
# Changes to be committed:
# 	modified:   pyrogram/client.py
# 	modified:   pyrogram/storage/mongo_storage.py
#
# Signed-off-by: wulan17 <wulan17@nusantararom.org>
# 2023-05-21 20:17:05 +07:00
#
# 171 lines
# 5.3 KiB
# Python

import asyncio
import inspect
import logging
import time
from typing import Any, List, Optional, Tuple

from motor.motor_asyncio import AsyncIOMotorClient
from pymongo import UpdateOne

from pyrogram.storage.storage import Storage
from pyrogram.storage.sqlite_storage import get_input_peer
class MongoStorage(Storage):
    """
    MongoDB-backed session storage.

    Parameters:
        - name (`str`):
            The session name, also used as the MongoDB database name.

        - uri (`str`):
            MongoDB Connection String URI.
            For more information refer to https://www.mongodb.com/docs/manual/reference/connection-string

        - remove_peers (`bool`, *optional*):
            Flag to remove data in the peers collection. If set to True,
            all stored peers are removed every time the client logs out.
            If set to False or None, the data will not be removed.

    Example:
        session = MongoStorage("my_session", uri="mongodb://...", remove_peers=True)
    """
    lock: asyncio.Lock
    # Seconds after which a cached username -> peer mapping is considered stale
    # (compared against time.time() in get_peer_by_username).
    USERNAME_TTL = 8 * 60 * 60

    def __init__(self, name: str, uri: str, remove_peers: bool = False):
        super().__init__(name=name)
        # One database per session: the session name is the database name.
        database = AsyncIOMotorClient(uri)[name]
        self.lock = asyncio.Lock()
        self.database = database
        self._peer = database['peers']
        self._session = database['session']
        self._remove_peers = remove_peers

    async def open(self):
        """
        Ensure the single session document (``_id`` 0) exists.

        If it is missing it is created with default values:
        dc_id=2, api_id=None, test_mode=None, auth_key=b'', date=0,
        user_id=0, is_bot=0. An existing document is left untouched.
        """
        if await self._session.find_one({'_id': 0}, {'_id': 1}):
            return
        await self._session.insert_one(
            {
                '_id': 0,
                'dc_id': 2,
                'api_id': None,
                'test_mode': None,
                'auth_key': b'',
                'date': 0,
                'user_id': 0,
                'is_bot': 0,
            }
        )

    async def save(self):
        """No-op: every setter and peer update is written to MongoDB immediately."""
        pass

    async def close(self):
        """No-op: nothing is buffered locally, so there is nothing to flush."""
        pass

    async def delete(self):
        """
        Remove the session document and, when ``remove_peers`` is enabled,
        every document in the peers collection.

        Best-effort: a database error is logged and not re-raised.
        """
        try:
            await self._session.delete_one({'_id': 0})
            if self._remove_peers:
                # Collection.remove() was removed in PyMongo 4 / Motor 3;
                # delete_many({}) is the supported way to empty a collection.
                await self._peer.delete_many({})
        except Exception:
            logging.getLogger(__name__).exception(
                "Failed to delete MongoDB session data for %r", self.database.name
            )

    async def update_peers(self, peers: List[Tuple[int, int, str, str, str]]):
        """
        Upsert peers into the peers collection.

        Each tuple is (id, access_hash, type, username, phone_number). The
        peer id becomes the document ``_id`` and every upserted document is
        stamped with the current Unix time in ``last_update_on``.
        """
        now = int(time.time())
        bulk = [
            UpdateOne(
                {'_id': peer[0]},
                {'$set': {
                    'access_hash': peer[1],
                    'type': peer[2],
                    'username': peer[3],
                    'phone_number': peer[4],
                    'last_update_on': now,
                }},
                upsert=True,
            )
            for peer in peers
        ]
        if not bulk:
            # Nothing to write; skip the round-trip entirely.
            return
        await self._peer.bulk_write(bulk)

    async def get_peer_by_id(self, peer_id: int):
        """
        Build an input peer from the stored (id, access_hash, type).

        Raises:
            KeyError: no peer with ``peer_id`` is stored.
        """
        r = await self._peer.find_one({'_id': peer_id}, {'_id': 1, 'access_hash': 1, 'type': 1})
        if not r:
            raise KeyError(f"ID not found: {peer_id}")
        return get_input_peer(r['_id'], r['access_hash'], r['type'])

    async def get_peer_by_username(self, username: str):
        """
        Build an input peer for ``username``.

        Raises:
            KeyError: the username is unknown, or its record is older than
                ``USERNAME_TTL`` seconds.
        """
        r = await self._peer.find_one(
            {'username': username},
            {'_id': 1, 'access_hash': 1, 'type': 1, 'last_update_on': 1},
        )
        if r is None:
            raise KeyError(f"Username not found: {username}")
        if abs(time.time() - r['last_update_on']) > self.USERNAME_TTL:
            raise KeyError(f"Username expired: {username}")
        return get_input_peer(r['_id'], r['access_hash'], r['type'])

    async def get_peer_by_phone_number(self, phone_number: str):
        """
        Build an input peer for ``phone_number``.

        Raises:
            KeyError: no peer with that phone number is stored.
        """
        r = await self._peer.find_one(
            {'phone_number': phone_number},
            {'_id': 1, 'access_hash': 1, 'type': 1},
        )
        if r is None:
            raise KeyError(f"Phone number not found: {phone_number}")
        return get_input_peer(r['_id'], r['access_hash'], r['type'])

    async def _get(self, attr: Optional[str] = None) -> Any:
        """
        Read one field of the session document.

        ``attr`` is the field name. When omitted, it falls back to the legacy
        convention: the name of the function two frames up (the public
        accessor that called ``_accessor``).
        Returns None if the session document or the field is missing.
        """
        if attr is None:
            attr = inspect.currentframe().f_back.f_back.f_code.co_name
        doc = await self._session.find_one({'_id': 0}, {attr: 1})
        if not doc:
            return None
        # The field can be absent, e.g. when the document was recreated by an
        # upserting _set() after delete(); report that as "unset".
        return doc.get(attr)

    async def _set(self, value: Any, attr: Optional[str] = None):
        """
        Write one field of the session document, creating it if needed.

        ``attr`` is the field name. When omitted, it falls back to the legacy
        convention: the name of the function two frames up.
        """
        if attr is None:
            attr = inspect.currentframe().f_back.f_back.f_code.co_name
        await self._session.update_one({'_id': 0}, {'$set': {attr: value}}, upsert=True)

    async def _accessor(self, value: Any = object, attr: Optional[str] = None):
        """
        Get the field (``value`` omitted) or set it to ``value``.

        The ``object`` class itself is the "not given" sentinel, so None can
        be stored explicitly. When ``attr`` is omitted, the field is named
        after the calling method.
        """
        if attr is None:
            attr = inspect.currentframe().f_back.f_code.co_name
        if value is object:
            return await self._get(attr)
        return await self._set(value, attr)

    async def dc_id(self, value: int = object):
        return await self._accessor(value, 'dc_id')

    async def api_id(self, value: int = object):
        return await self._accessor(value, 'api_id')

    async def test_mode(self, value: bool = object):
        return await self._accessor(value, 'test_mode')

    async def auth_key(self, value: bytes = object):
        return await self._accessor(value, 'auth_key')

    async def date(self, value: int = object):
        return await self._accessor(value, 'date')

    async def user_id(self, value: int = object):
        return await self._accessor(value, 'user_id')

    async def is_bot(self, value: bool = object):
        return await self._accessor(value, 'is_bot')