[DNM] Revert "feat(storage): migrated to aiosqlite"

This reverts commit 8193d87e2c.

Signed-off-by: wulan17 <wulan17@nusantararom.org>
This commit is contained in:
wulan17 2024-03-23 18:43:15 +07:00
parent 9bdc824a61
commit 3b5db0b988
No known key found for this signature in database
GPG key ID: 318CD6CD3A6AC0A5
4 changed files with 88 additions and 87 deletions

View file

@ -3,12 +3,7 @@ name = "pyrofork"
dynamic = ["version"] dynamic = ["version"]
description = "Fork of pyrogram. Elegant, modern and asynchronous Telegram MTProto API framework in Python for users and bots" description = "Fork of pyrogram. Elegant, modern and asynchronous Telegram MTProto API framework in Python for users and bots"
authors = [{ name = "wulan17", email = "mayuri@mayuri.my.id" }] authors = [{ name = "wulan17", email = "mayuri@mayuri.my.id" }]
dependencies = [ dependencies = ["pyaes==1.6.1", "pysocks==1.7.1", "pymediainfo-pyrofork>=6.0.1,<7.0.0"]
"aiosqlite>=0.19.0",
"pyaes==1.6.1",
"pysocks==1.7.1",
"pymediainfo-pyrofork>=6.0.1,<7.0.0"
]
readme = "README.md" readme = "README.md"
license = "LGPL-3.0-or-later" license = "LGPL-3.0-or-later"
requires-python = "~=3.8" requires-python = "~=3.8"

View file

@ -18,10 +18,10 @@
# along with Pyrofork. If not, see <http://www.gnu.org/licenses/>. # along with Pyrofork. If not, see <http://www.gnu.org/licenses/>.
import logging import logging
import os
import sqlite3
from pathlib import Path from pathlib import Path
import aiosqlite
from .sqlite_storage import SQLiteStorage from .sqlite_storage import SQLiteStorage
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
@ -35,38 +35,36 @@ class FileStorage(SQLiteStorage):
self.database = workdir / (self.name + self.FILE_EXTENSION) self.database = workdir / (self.name + self.FILE_EXTENSION)
async def update(self): def update(self):
version = await self.version() version = self.version()
if version == 1: if version == 1:
await self.conn.execute("DELETE FROM peers") with self.conn:
await self.conn.commit() self.conn.execute("DELETE FROM peers")
version += 1 version += 1
if version == 2: if version == 2:
await self.conn.execute("ALTER TABLE sessions ADD api_id INTEGER") with self.conn:
await self.conn.commit() self.conn.execute("ALTER TABLE sessions ADD api_id INTEGER")
version += 1 version += 1
await self.version(version) self.version(version)
async def open(self): async def open(self):
path = self.database path = self.database
file_exists = path.is_file() file_exists = path.is_file()
self.conn = await aiosqlite.connect(str(path), timeout=1) self.conn = sqlite3.connect(str(path), timeout=1, check_same_thread=False)
await self.conn.execute("PRAGMA journal_mode=WAL")
if not file_exists: if not file_exists:
await self.create() self.create()
else: else:
await self.update() self.update()
await self.conn.execute("VACUUM") with self.conn:
await self.conn.commit() self.conn.execute("VACUUM")
async def delete(self): async def delete(self):
Path(self.database).unlink() os.remove(self.database)

View file

@ -19,10 +19,9 @@
import base64 import base64
import logging import logging
import sqlite3
import struct import struct
import aiosqlite
from .sqlite_storage import SQLiteStorage from .sqlite_storage import SQLiteStorage
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
@ -35,8 +34,8 @@ class MemoryStorage(SQLiteStorage):
self.session_string = session_string self.session_string = session_string
async def open(self): async def open(self):
self.conn = await aiosqlite.connect(":memory:") self.conn = sqlite3.connect(":memory:", check_same_thread=False)
await self.create() self.create()
if self.session_string: if self.session_string:
# Old format # Old format

View file

@ -18,13 +18,13 @@
# along with Pyrofork. If not, see <http://www.gnu.org/licenses/>. # along with Pyrofork. If not, see <http://www.gnu.org/licenses/>.
import inspect import inspect
import sqlite3
import time import time
from typing import Any, List, Tuple from typing import List, Tuple, Any
import aiosqlite from pyrogram import raw
from pyrogram import raw, utils
from .storage import Storage from .storage import Storage
from .. import utils
# language=SQLite # language=SQLite
SCHEMA = """ SCHEMA = """
@ -116,55 +116,61 @@ class SQLiteStorage(Storage):
def __init__(self, name: str): def __init__(self, name: str):
super().__init__(name) super().__init__(name)
self.conn: aiosqlite.Connection = None self.conn = None # type: sqlite3.Connection
async def create(self): def create(self):
await self.conn.executescript(SCHEMA) with self.conn:
await self.conn.execute("INSERT INTO version VALUES (?)", (self.VERSION,)) self.conn.executescript(SCHEMA)
await self.conn.execute( self.conn.executescript(UNAME_SCHEMA)
"INSERT INTO sessions VALUES (?, ?, ?, ?, ?, ?, ?)",
(2, None, None, None, 0, None, None), self.conn.execute(
) "INSERT INTO version VALUES (?)",
await self.conn.commit() (self.VERSION,)
)
self.conn.execute(
"INSERT INTO sessions VALUES (?, ?, ?, ?, ?, ?, ?)",
(2, None, None, None, 0, None, None)
)
async def open(self): async def open(self):
raise NotImplementedError raise NotImplementedError
async def save(self): async def save(self):
await self.date(int(time.time())) await self.date(int(time.time()))
await self.conn.commit() self.conn.commit()
async def close(self): async def close(self):
await self.conn.close() self.conn.close()
async def delete(self): async def delete(self):
raise NotImplementedError raise NotImplementedError
async def update_peers(self, peers: List[Tuple[int, int, str, str, str]]): async def update_peers(self, peers: List[Tuple[int, int, str, str, str]]):
await self.conn.executemany( self.conn.executemany(
"REPLACE INTO peers (id, access_hash, type, username, phone_number)" "REPLACE INTO peers (id, access_hash, type, username, phone_number)"
"VALUES (?, ?, ?, ?, ?)", "VALUES (?, ?, ?, ?, ?)",
peers peers
) )
async def update_usernames(self, usernames: List[Tuple[int, str]]): async def update_usernames(self, usernames: List[Tuple[int, str]]):
await self.conn.executescript(UNAME_SCHEMA) self.conn.executescript(UNAME_SCHEMA)
for user in usernames: for user in usernames:
await self.conn.execute( self.conn.execute(
"DELETE FROM usernames WHERE peer_id=?", "DELETE FROM usernames WHERE peer_id=?",
(user[0],) (user[0],)
) )
await self.conn.executemany( self.conn.executemany(
"REPLACE INTO usernames (peer_id, id)" "REPLACE INTO usernames (peer_id, id)"
"VALUES (?, ?)", "VALUES (?, ?)",
usernames usernames
) )
async def get_peer_by_id(self, peer_id: int): async def get_peer_by_id(self, peer_id: int):
q = await self.conn.execute( r = self.conn.execute(
"SELECT id, access_hash, type FROM peers WHERE id = ?", (peer_id,) "SELECT id, access_hash, type FROM peers WHERE id = ?",
) (peer_id,)
r = await q.fetchone() ).fetchone()
if r is None: if r is None:
raise KeyError(f"ID not found: {peer_id}") raise KeyError(f"ID not found: {peer_id}")
@ -172,30 +178,27 @@ class SQLiteStorage(Storage):
return get_input_peer(*r) return get_input_peer(*r)
async def get_peer_by_username(self, username: str): async def get_peer_by_username(self, username: str):
q = await self.conn.execute( r = self.conn.execute(
"SELECT id, access_hash, type, last_update_on FROM peers WHERE username = ?" "SELECT id, access_hash, type, last_update_on FROM peers WHERE username = ?"
"ORDER BY last_update_on DESC", "ORDER BY last_update_on DESC",
(username,), (username,)
) ).fetchone()
r = await q.fetchone()
if r is None: if r is None:
r2 = await self.conn.execute( r2 = self.conn.execute(
"SELECT peer_id, last_update_on FROM usernames WHERE id = ?" "SELECT peer_id, last_update_on FROM usernames WHERE id = ?"
"ORDER BY last_update_on DESC", "ORDER BY last_update_on DESC",
(username,) (username,)
) ).fetchone()
r2 = await r2.fetchone()
if r2 is None: if r2 is None:
raise KeyError(f"Username not found: {username}") raise KeyError(f"Username not found: {username}")
if abs(time.time() - r2[1]) > self.USERNAME_TTL: if abs(time.time() - r2[1]) > self.USERNAME_TTL:
raise KeyError(f"Username expired: {username}") raise KeyError(f"Username expired: {username}")
r = await self.conn.execute( r = r = self.conn.execute(
"SELECT id, access_hash, type, last_update_on FROM peers WHERE id = ?" "SELECT id, access_hash, type, last_update_on FROM peers WHERE id = ?"
"ORDER BY last_update_on DESC", "ORDER BY last_update_on DESC",
(r2[0],) (r2[0],)
) ).fetchone()
r = await r.fetchone()
if r is None: if r is None:
raise KeyError(f"Username not found: {username}") raise KeyError(f"Username not found: {username}")
@ -205,58 +208,64 @@ class SQLiteStorage(Storage):
return get_input_peer(*r[:3]) return get_input_peer(*r[:3])
async def get_peer_by_phone_number(self, phone_number: str): async def get_peer_by_phone_number(self, phone_number: str):
q = await self.conn.execute( r = self.conn.execute(
"SELECT id, access_hash, type FROM peers WHERE phone_number = ?", "SELECT id, access_hash, type FROM peers WHERE phone_number = ?",
(phone_number,), (phone_number,)
) ).fetchone()
r = await q.fetchone()
if r is None: if r is None:
raise KeyError(f"Phone number not found: {phone_number}") raise KeyError(f"Phone number not found: {phone_number}")
return get_input_peer(*r) return get_input_peer(*r)
async def _get(self): def _get(self):
attr = inspect.stack()[2].function attr = inspect.stack()[2].function
q = await self.conn.execute(f"SELECT {attr} FROM sessions") return self.conn.execute(
row = await q.fetchone() f"SELECT {attr} FROM sessions"
return row[0] if row else None ).fetchone()[0]
async def _set(self, value: Any): def _set(self, value: Any):
attr = inspect.stack()[2].function attr = inspect.stack()[2].function
await self.conn.execute(f"UPDATE sessions SET {attr} = ?", (value,))
await self.conn.commit()
async def _accessor(self, value: Any = object): with self.conn:
return await self._get() if value == object else await self._set(value) self.conn.execute(
f"UPDATE sessions SET {attr} = ?",
(value,)
)
def _accessor(self, value: Any = object):
return self._get() if value == object else self._set(value)
async def dc_id(self, value: int = object): async def dc_id(self, value: int = object):
return await self._accessor(value) return self._accessor(value)
async def api_id(self, value: int = object): async def api_id(self, value: int = object):
return await self._accessor(value) return self._accessor(value)
async def test_mode(self, value: bool = object): async def test_mode(self, value: bool = object):
return await self._accessor(value) return self._accessor(value)
async def auth_key(self, value: bytes = object): async def auth_key(self, value: bytes = object):
return await self._accessor(value) return self._accessor(value)
async def date(self, value: int = object): async def date(self, value: int = object):
return await self._accessor(value) return self._accessor(value)
async def user_id(self, value: int = object): async def user_id(self, value: int = object):
return await self._accessor(value) return self._accessor(value)
async def is_bot(self, value: bool = object): async def is_bot(self, value: bool = object):
return await self._accessor(value) return self._accessor(value)
async def version(self, value: int = object): def version(self, value: int = object):
if value == object: if value == object:
q = await self.conn.execute("SELECT number FROM version") return self.conn.execute(
row = await q.fetchone() "SELECT number FROM version"
return row[0] if row else None ).fetchone()[0]
await self.conn.execute("UPDATE version SET number = ?", (value,)) else:
await self.conn.commit() with self.conn:
return None self.conn.execute(
"UPDATE version SET number = ?",
(value,)
)