Move to aiosqlite and enable WAL

Signed-off-by: wulan17 <wulan17@nusantararom.org>
This commit is contained in:
Hitalo 2022-12-25 16:29:11 -03:00 committed by wulan17
parent c034a2da27
commit afd2e3f0bc
No known key found for this signature in database
GPG key ID: 318CD6CD3A6AC0A5
4 changed files with 73 additions and 75 deletions

View file

@ -18,7 +18,7 @@
import logging import logging
import os import os
import sqlite3 import aiosqlite
from pathlib import Path from pathlib import Path
from .sqlite_storage import SQLiteStorage from .sqlite_storage import SQLiteStorage
@ -34,39 +34,38 @@ class FileStorage(SQLiteStorage):
self.database = workdir / (self.name + self.FILE_EXTENSION) self.database = workdir / (self.name + self.FILE_EXTENSION)
def update(self): async def update(self):
version = self.version() version = await self.version()
if version == 1: if version == 1:
with self.lock, self.conn: await self.conn.execute("DELETE FROM peers")
self.conn.execute("DELETE FROM peers")
version += 1 version += 1
if version == 2: if version == 2:
with self.lock, self.conn: await self.conn.execute("ALTER TABLE sessions ADD api_id INTEGER")
self.conn.execute("ALTER TABLE sessions ADD api_id INTEGER")
version += 1 version += 1
self.version(version) await self.version(version)
async def open(self): async def open(self):
path = self.database path = self.database
file_exists = path.is_file() file_exists = path.is_file()
self.conn = sqlite3.connect(str(path), timeout=1, check_same_thread=False) self.conn = await aiosqlite.connect(str(path), timeout=1)
await self.conn.execute("PRAGMA journal_mode=WAL")
if not file_exists: if not file_exists:
self.create() await self.create()
else: else:
self.update() await self.update()
with self.conn: try: # Python 3.6.0 (exactly this version) is bugged and won't successfully execute the vacuum
try: # Python 3.6.0 (exactly this version) is bugged and won't successfully execute the vacuum await self.conn.execute("VACUUM")
self.conn.execute("VACUUM") except aiosqlite.OperationalError:
except sqlite3.OperationalError: pass
pass
async def delete(self): async def delete(self):
os.remove(self.database) os.remove(self.database)

View file

@ -18,7 +18,7 @@
import base64 import base64
import logging import logging
import sqlite3 import aiosqlite
import struct import struct
from .sqlite_storage import SQLiteStorage from .sqlite_storage import SQLiteStorage
@ -33,8 +33,8 @@ class MemoryStorage(SQLiteStorage):
self.session_string = session_string self.session_string = session_string
async def open(self): async def open(self):
self.conn = sqlite3.connect(":memory:", check_same_thread=False) self.conn = await aiosqlite.connect(":memory:")
self.create() await self.create()
if self.session_string: if self.session_string:
# Old format # Old format

View file

@ -17,7 +17,7 @@
# along with Pyrogram. If not, see <http://www.gnu.org/licenses/>. # along with Pyrogram. If not, see <http://www.gnu.org/licenses/>.
import inspect import inspect
import sqlite3 import aiosqlite
import time import time
from threading import Lock from threading import Lock
from typing import List, Tuple, Any from typing import List, Tuple, Any
@ -97,22 +97,20 @@ class SQLiteStorage(Storage):
def __init__(self, name: str): def __init__(self, name: str):
super().__init__(name) super().__init__(name)
self.conn = None # type: sqlite3.Connection self.conn = None # type: aiosqlite.Connection
self.lock = Lock()
def create(self): async def create(self):
with self.lock, self.conn: await self.conn.executescript(SCHEMA)
self.conn.executescript(SCHEMA)
self.conn.execute( await self.conn.execute(
"INSERT INTO version VALUES (?)", "INSERT INTO version VALUES (?)",
(self.VERSION,) (self.VERSION,)
) )
self.conn.execute( await self.conn.execute(
"INSERT INTO sessions VALUES (?, ?, ?, ?, ?, ?, ?)", "INSERT INTO sessions VALUES (?, ?, ?, ?, ?, ?, ?)",
(2, None, None, None, 0, None, None) (2, None, None, None, 0, None, None)
) )
async def open(self): async def open(self):
raise NotImplementedError raise NotImplementedError
@ -120,29 +118,27 @@ class SQLiteStorage(Storage):
async def save(self): async def save(self):
await self.date(int(time.time())) await self.date(int(time.time()))
with self.lock: await self.conn.commit()
self.conn.commit()
async def close(self): async def close(self):
with self.lock: await self.conn.close()
self.conn.close()
async def delete(self): async def delete(self):
raise NotImplementedError raise NotImplementedError
async def update_peers(self, peers: List[Tuple[int, int, str, str, str]]): async def update_peers(self, peers: List[Tuple[int, int, str, str, str]]):
with self.lock: await self.conn.executemany(
self.conn.executemany( "REPLACE INTO peers (id, access_hash, type, username, phone_number)"
"REPLACE INTO peers (id, access_hash, type, username, phone_number)" "VALUES (?, ?, ?, ?, ?)",
"VALUES (?, ?, ?, ?, ?)", peers
peers )
)
async def get_peer_by_id(self, peer_id: int): async def get_peer_by_id(self, peer_id: int):
r = self.conn.execute( q = await self.conn.execute(
"SELECT id, access_hash, type FROM peers WHERE id = ?", "SELECT id, access_hash, type FROM peers WHERE id = ?",
(peer_id,) (peer_id,)
).fetchone() )
r = await q.fetchone()
if r is None: if r is None:
raise KeyError(f"ID not found: {peer_id}") raise KeyError(f"ID not found: {peer_id}")
@ -150,11 +146,12 @@ class SQLiteStorage(Storage):
return get_input_peer(*r) return get_input_peer(*r)
async def get_peer_by_username(self, username: str): async def get_peer_by_username(self, username: str):
r = self.conn.execute( q = await self.conn.execute(
"SELECT id, access_hash, type, last_update_on FROM peers WHERE username = ?" "SELECT id, access_hash, type, last_update_on FROM peers WHERE username = ?"
"ORDER BY last_update_on DESC", "ORDER BY last_update_on DESC",
(username,) (username,)
).fetchone() )
r = await q.fetchone()
if r is None: if r is None:
raise KeyError(f"Username not found: {username}") raise KeyError(f"Username not found: {username}")
@ -165,64 +162,65 @@ class SQLiteStorage(Storage):
return get_input_peer(*r[:3]) return get_input_peer(*r[:3])
async def get_peer_by_phone_number(self, phone_number: str): async def get_peer_by_phone_number(self, phone_number: str):
r = self.conn.execute( q = await self.conn.execute(
"SELECT id, access_hash, type FROM peers WHERE phone_number = ?", "SELECT id, access_hash, type FROM peers WHERE phone_number = ?",
(phone_number,) (phone_number,)
).fetchone() )
r = await q.fetchone()
if r is None: if r is None:
raise KeyError(f"Phone number not found: {phone_number}") raise KeyError(f"Phone number not found: {phone_number}")
return get_input_peer(*r) return get_input_peer(*r)
def _get(self): async def _get(self):
attr = inspect.stack()[2].function attr = inspect.stack()[2].function
return self.conn.execute( q = await self.conn.execute(
f"SELECT {attr} FROM sessions" f"SELECT {attr} FROM sessions"
).fetchone()[0] )
return (await q.fetchone())[0]
def _set(self, value: Any): async def _set(self, value: Any):
attr = inspect.stack()[2].function attr = inspect.stack()[2].function
with self.lock, self.conn: await self.conn.execute(
self.conn.execute( f"UPDATE sessions SET {attr} = ?",
f"UPDATE sessions SET {attr} = ?", (value,)
(value,) )
)
def _accessor(self, value: Any = object): async def _accessor(self, value: Any = object):
return self._get() if value == object else self._set(value) return await self._get() if value == object else await self._set(value)
async def dc_id(self, value: int = object): async def dc_id(self, value: int = object):
return self._accessor(value) return await self._accessor(value)
async def api_id(self, value: int = object): async def api_id(self, value: int = object):
return self._accessor(value) return await self._accessor(value)
async def test_mode(self, value: bool = object): async def test_mode(self, value: bool = object):
return self._accessor(value) return await self._accessor(value)
async def auth_key(self, value: bytes = object): async def auth_key(self, value: bytes = object):
return self._accessor(value) return await self._accessor(value)
async def date(self, value: int = object): async def date(self, value: int = object):
return self._accessor(value) return await self._accessor(value)
async def user_id(self, value: int = object): async def user_id(self, value: int = object):
return self._accessor(value) return await self._accessor(value)
async def is_bot(self, value: bool = object): async def is_bot(self, value: bool = object):
return self._accessor(value) return await self._accessor(value)
def version(self, value: int = object): async def version(self, value: int = object):
if value == object: if value == object:
return self.conn.execute( q = await self.conn.execute(
"SELECT number FROM version" "SELECT number FROM version"
).fetchone()[0] )
return (await q.fetchone())[0]
else: else:
with self.lock, self.conn: await self.conn.execute(
self.conn.execute( "UPDATE version SET number = ?",
"UPDATE version SET number = ?", (value,)
(value,) )
)

View file

@ -1,3 +1,4 @@
pyaes==1.6.1 pyaes==1.6.1
pymediainfo==6.0.1 pymediainfo==6.0.1
pysocks==1.7.1 pysocks==1.7.1
aiosqlite>=0.16.0,<0.18.0