feat(storage): migrated to aiosqlite

Signed-off-by: wulan17 <wulan17@nusantararom.org>
This commit is contained in:
Hitalo M 2023-11-06 19:24:56 -03:00 committed by wulan17
parent ce46d49ec2
commit 8193d87e2c
No known key found for this signature in database
GPG key ID: 318CD6CD3A6AC0A5
4 changed files with 86 additions and 87 deletions

View file

@ -3,7 +3,12 @@ name = "pyrofork"
dynamic = ["version"]
description = "Fork of pyrogram. Elegant, modern and asynchronous Telegram MTProto API framework in Python for users and bots"
authors = [{ name = "wulan17", email = "mayuri@mayuri.my.id" }]
dependencies = ["pyaes==1.6.1", "pysocks==1.7.1", "pymediainfo-pyrofork>=6.0.1,<7.0.0"]
dependencies = [
"aiosqlite>=0.19.0",
"pyaes==1.6.1",
"pysocks==1.7.1",
"pymediainfo-pyrofork>=6.0.1,<7.0.0"
]
readme = "README.md"
license = "LGPL-3.0-or-later"
requires-python = "~=3.8"

View file

@ -18,10 +18,10 @@
# along with Pyrofork. If not, see <http://www.gnu.org/licenses/>.
import logging
import os
import sqlite3
from pathlib import Path
import aiosqlite
from .sqlite_storage import SQLiteStorage
log = logging.getLogger(__name__)
@ -35,36 +35,38 @@ class FileStorage(SQLiteStorage):
self.database = workdir / (self.name + self.FILE_EXTENSION)
def update(self):
version = self.version()
async def update(self):
version = await self.version()
if version == 1:
with self.conn:
self.conn.execute("DELETE FROM peers")
await self.conn.execute("DELETE FROM peers")
await self.conn.commit()
version += 1
if version == 2:
with self.conn:
self.conn.execute("ALTER TABLE sessions ADD api_id INTEGER")
await self.conn.execute("ALTER TABLE sessions ADD api_id INTEGER")
await self.conn.commit()
version += 1
self.version(version)
await self.version(version)
async def open(self):
path = self.database
file_exists = path.is_file()
self.conn = sqlite3.connect(str(path), timeout=1, check_same_thread=False)
self.conn = await aiosqlite.connect(str(path), timeout=1)
await self.conn.execute("PRAGMA journal_mode=WAL")
if not file_exists:
self.create()
await self.create()
else:
self.update()
await self.update()
with self.conn:
self.conn.execute("VACUUM")
await self.conn.execute("VACUUM")
await self.conn.commit()
async def delete(self):
os.remove(self.database)
Path(self.database).unlink()

View file

@ -19,9 +19,10 @@
import base64
import logging
import sqlite3
import struct
import aiosqlite
from .sqlite_storage import SQLiteStorage
log = logging.getLogger(__name__)
@ -34,8 +35,8 @@ class MemoryStorage(SQLiteStorage):
self.session_string = session_string
async def open(self):
self.conn = sqlite3.connect(":memory:", check_same_thread=False)
self.create()
self.conn = await aiosqlite.connect(":memory:")
await self.create()
if self.session_string:
# Old format

View file

@ -18,13 +18,13 @@
# along with Pyrofork. If not, see <http://www.gnu.org/licenses/>.
import inspect
import sqlite3
import time
from typing import List, Tuple, Any
from typing import Any, List, Tuple
from pyrogram import raw
import aiosqlite
from pyrogram import raw, utils
from .storage import Storage
from .. import utils
# language=SQLite
SCHEMA = """
@ -116,61 +116,55 @@ class SQLiteStorage(Storage):
def __init__(self, name: str):
super().__init__(name)
self.conn = None # type: sqlite3.Connection
self.conn: aiosqlite.Connection = None
def create(self):
with self.conn:
self.conn.executescript(SCHEMA)
self.conn.executescript(UNAME_SCHEMA)
self.conn.execute(
"INSERT INTO version VALUES (?)",
(self.VERSION,)
)
self.conn.execute(
"INSERT INTO sessions VALUES (?, ?, ?, ?, ?, ?, ?)",
(2, None, None, None, 0, None, None)
)
async def create(self):
await self.conn.executescript(SCHEMA)
await self.conn.execute("INSERT INTO version VALUES (?)", (self.VERSION,))
await self.conn.execute(
"INSERT INTO sessions VALUES (?, ?, ?, ?, ?, ?, ?)",
(2, None, None, None, 0, None, None),
)
await self.conn.commit()
async def open(self):
raise NotImplementedError
async def save(self):
await self.date(int(time.time()))
self.conn.commit()
await self.conn.commit()
async def close(self):
self.conn.close()
await self.conn.close()
async def delete(self):
raise NotImplementedError
async def update_peers(self, peers: List[Tuple[int, int, str, str, str]]):
self.conn.executemany(
await self.conn.executemany(
"REPLACE INTO peers (id, access_hash, type, username, phone_number)"
"VALUES (?, ?, ?, ?, ?)",
peers
)
async def update_usernames(self, usernames: List[Tuple[int, str]]):
self.conn.executescript(UNAME_SCHEMA)
await self.conn.executescript(UNAME_SCHEMA)
for user in usernames:
self.conn.execute(
await self.conn.execute(
"DELETE FROM usernames WHERE peer_id=?",
(user[0],)
)
self.conn.executemany(
await self.conn.executemany(
"REPLACE INTO usernames (peer_id, id)"
"VALUES (?, ?)",
usernames
)
async def get_peer_by_id(self, peer_id: int):
r = self.conn.execute(
"SELECT id, access_hash, type FROM peers WHERE id = ?",
(peer_id,)
).fetchone()
q = await self.conn.execute(
"SELECT id, access_hash, type FROM peers WHERE id = ?", (peer_id,)
)
r = await q.fetchone()
if r is None:
raise KeyError(f"ID not found: {peer_id}")
@ -178,27 +172,30 @@ class SQLiteStorage(Storage):
return get_input_peer(*r)
async def get_peer_by_username(self, username: str):
r = self.conn.execute(
q = await self.conn.execute(
"SELECT id, access_hash, type, last_update_on FROM peers WHERE username = ?"
"ORDER BY last_update_on DESC",
(username,)
).fetchone()
(username,),
)
r = await q.fetchone()
if r is None:
r2 = self.conn.execute(
"SELECT peer_id, last_update_on FROM usernames WHERE id = ?"
"ORDER BY last_update_on DESC",
(username,)
).fetchone()
)
r2 = await r2.fetchone()
if r2 is None:
raise KeyError(f"Username not found: {username}")
if abs(time.time() - r2[1]) > self.USERNAME_TTL:
raise KeyError(f"Username expired: {username}")
r = r = self.conn.execute(
r = await self.conn.execute(
"SELECT id, access_hash, type, last_update_on FROM peers WHERE id = ?"
"ORDER BY last_update_on DESC",
(r2[0],)
).fetchone()
)
r = await r.fetchone()
if r is None:
raise KeyError(f"Username not found: {username}")
@ -208,64 +205,58 @@ class SQLiteStorage(Storage):
return get_input_peer(*r[:3])
async def get_peer_by_phone_number(self, phone_number: str):
r = self.conn.execute(
q = await self.conn.execute(
"SELECT id, access_hash, type FROM peers WHERE phone_number = ?",
(phone_number,)
).fetchone()
(phone_number,),
)
r = await q.fetchone()
if r is None:
raise KeyError(f"Phone number not found: {phone_number}")
return get_input_peer(*r)
def _get(self):
async def _get(self):
attr = inspect.stack()[2].function
return self.conn.execute(
f"SELECT {attr} FROM sessions"
).fetchone()[0]
q = await self.conn.execute(f"SELECT {attr} FROM sessions")
row = await q.fetchone()
return row[0] if row else None
def _set(self, value: Any):
async def _set(self, value: Any):
attr = inspect.stack()[2].function
await self.conn.execute(f"UPDATE sessions SET {attr} = ?", (value,))
await self.conn.commit()
with self.conn:
self.conn.execute(
f"UPDATE sessions SET {attr} = ?",
(value,)
)
def _accessor(self, value: Any = object):
return self._get() if value == object else self._set(value)
async def _accessor(self, value: Any = object):
return await self._get() if value == object else await self._set(value)
async def dc_id(self, value: int = object):
return self._accessor(value)
return await self._accessor(value)
async def api_id(self, value: int = object):
return self._accessor(value)
return await self._accessor(value)
async def test_mode(self, value: bool = object):
return self._accessor(value)
return await self._accessor(value)
async def auth_key(self, value: bytes = object):
return self._accessor(value)
return await self._accessor(value)
async def date(self, value: int = object):
return self._accessor(value)
return await self._accessor(value)
async def user_id(self, value: int = object):
return self._accessor(value)
return await self._accessor(value)
async def is_bot(self, value: bool = object):
return self._accessor(value)
return await self._accessor(value)
def version(self, value: int = object):
async def version(self, value: int = object):
if value == object:
return self.conn.execute(
"SELECT number FROM version"
).fetchone()[0]
else:
with self.conn:
self.conn.execute(
"UPDATE version SET number = ?",
(value,)
)
q = await self.conn.execute("SELECT number FROM version")
row = await q.fetchone()
return row[0] if row else None
await self.conn.execute("UPDATE version SET number = ?", (value,))
await self.conn.commit()
return None