Revert "PyroFork: wrap sqlite3"

This partial reverts of commit 2a93257fa2.

Signed-off-by: wulan17 <wulan17@nusantararom.org>
This commit is contained in:
wulan17 2023-08-13 18:53:11 +07:00
parent a2bd7f5ad0
commit 113bfb1900
No known key found for this signature in database
GPG key ID: 318CD6CD3A6AC0A5
6 changed files with 93 additions and 183 deletions

View file

@ -21,7 +21,6 @@ import os
import sqlite3 import sqlite3
from pathlib import Path from pathlib import Path
from . import sqlite
from .sqlite_storage import SQLiteStorage from .sqlite_storage import SQLiteStorage
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
@ -35,36 +34,39 @@ class FileStorage(SQLiteStorage):
self.database = workdir / (self.name + self.FILE_EXTENSION) self.database = workdir / (self.name + self.FILE_EXTENSION)
async def update(self): def update(self):
version = await self.version() version = self.version()
if version == 1: if version == 1:
await self.conn.execute("DELETE FROM peers") with self.lock, self.conn:
self.conn.execute("DELETE FROM peers")
version += 1 version += 1
if version == 2: if version == 2:
await self.conn.execute("ALTER TABLE sessions ADD api_id INTEGER") with self.lock, self.conn:
self.conn.execute("ALTER TABLE sessions ADD api_id INTEGER")
version += 1 version += 1
await self.version(version) self.version(version)
async def open(self): async def open(self):
path = self.database path = self.database
file_exists = path.is_file() file_exists = path.is_file()
self.conn = sqlite.AsyncSqlite(database=str(path), timeout=1, check_same_thread=False) self.conn = sqlite3.connect(str(path), timeout=1, check_same_thread=False)
if not file_exists: if not file_exists:
await self.create() self.create()
else: else:
await self.update() self.update()
try: # Python 3.6.0 (exactly this version) is bugged and won't successfully execute the vacuum with self.conn:
await self.conn.execute("VACUUM") try: # Python 3.6.0 (exactly this version) is bugged and won't successfully execute the vacuum
except sqlite3.OperationalError: self.conn.execute("VACUUM")
pass except sqlite3.OperationalError:
pass
async def delete(self): async def delete(self):
os.remove(self.database) os.remove(self.database)

View file

@ -18,9 +18,9 @@
import base64 import base64
import logging import logging
import sqlite3
import struct import struct
from . import sqlite
from .sqlite_storage import SQLiteStorage from .sqlite_storage import SQLiteStorage
log = logging.getLogger(__name__) log = logging.getLogger(__name__)
@ -33,8 +33,8 @@ class MemoryStorage(SQLiteStorage):
self.session_string = session_string self.session_string = session_string
async def open(self): async def open(self):
self.conn = sqlite.AsyncSqlite(database=":memory:", check_same_thread=False) self.conn = sqlite3.connect(":memory:", check_same_thread=False)
await self.create() self.create()
if self.session_string: if self.session_string:
# Old format # Old format

View file

@ -1,22 +0,0 @@
# Pyrofork - Telegram MTProto API Client Library for Python
# Copyright (C) 2022-present Mayuri-Chan <https://github.com/Mayuri-Chan>
#
# This file is part of Pyrofork.
#
# Pyrofork is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License as published
# by the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# Pyrofork is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with Pyrofork. If not, see <http://www.gnu.org/licenses/>.
from .cursor import AsyncCursor
from .sqlite import AsyncSqlite
__all__ = [AsyncSqlite, AsyncCursor]

View file

@ -1,29 +0,0 @@
# Pyrofork - Telegram MTProto API Client Library for Python
# Copyright (C) 2022-present Mayuri-Chan <https://github.com/Mayuri-Chan>
#
# This file is part of Pyrofork.
#
# Pyrofork is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License as published
# by the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# Pyrofork is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with Pyrofork. If not, see <http://www.gnu.org/licenses/>.
from pyrogram.utils import run_sync
from sqlite3 import Cursor
from threading import Thread
class AsyncCursor(Thread):
def __init__(self, cursor: Cursor):
super().__init__()
self.cursor = cursor
async def fetchone(self):
return await run_sync(self.cursor.fetchone)

View file

@ -1,46 +0,0 @@
# Pyrofork - Telegram MTProto API Client Library for Python
# Copyright (C) 2022-present Mayuri-Chan <https://github.com/Mayuri-Chan>
#
# This file is part of Pyrofork.
#
# Pyrofork is free software: you can redistribute it and/or modify
# it under the terms of the GNU Lesser General Public License as published
# by the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# Pyrofork is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Lesser General Public License for more details.
#
# You should have received a copy of the GNU Lesser General Public License
# along with Pyrofork. If not, see <http://www.gnu.org/licenses/>.
import sqlite3
from .cursor import AsyncCursor
from pathlib import Path
from pyrogram.utils import run_sync
from threading import Thread
from typing import Union
class AsyncSqlite(Thread):
def __init__(self, database: Union[str, Path], *args, **kwargs):
super().__init__()
self.connection = sqlite3.connect(database, *args, **kwargs)
async def commit(self):
return await run_sync(self.connection.commit)
async def close(self):
return await run_sync(self.connection.close)
async def execute(self, *args, **kwargs):
r = await run_sync(self.connection.execute, *args, **kwargs)
return AsyncCursor(r)
async def executemany(self, *args, **kwargs):
r = await run_sync(self.connection.executemany, *args, **kwargs)
return AsyncCursor(r)
async def executescript(self, *args, **kwargs):
r = await run_sync(self.connection.executescript, *args, **kwargs)

View file

@ -17,7 +17,9 @@
# along with Pyrogram. If not, see <http://www.gnu.org/licenses/>. # along with Pyrogram. If not, see <http://www.gnu.org/licenses/>.
import inspect import inspect
import sqlite3
import time import time
from threading import Lock
from typing import List, Tuple, Any from typing import List, Tuple, Any
from pyrogram import raw from pyrogram import raw
@ -66,6 +68,7 @@ BEGIN
END; END;
""" """
def get_input_peer(peer_id: int, access_hash: int, peer_type: str): def get_input_peer(peer_id: int, access_hash: int, peer_type: str):
if peer_type in ["user", "bot"]: if peer_type in ["user", "bot"]:
return raw.types.InputPeerUser( return raw.types.InputPeerUser(
@ -94,20 +97,22 @@ class SQLiteStorage(Storage):
def __init__(self, name: str): def __init__(self, name: str):
super().__init__(name) super().__init__(name)
self.conn = None self.conn = None # type: sqlite3.Connection
self.lock = Lock()
async def create(self): def create(self):
await self.conn.executescript(SCHEMA) with self.lock, self.conn:
self.conn.executescript(SCHEMA)
await self.conn.execute( self.conn.execute(
"INSERT INTO version VALUES (?)", "INSERT INTO version VALUES (?)",
(self.VERSION,) (self.VERSION,)
) )
await self.conn.execute( self.conn.execute(
"INSERT INTO sessions VALUES (?, ?, ?, ?, ?, ?, ?)", "INSERT INTO sessions VALUES (?, ?, ?, ?, ?, ?, ?)",
(2, None, None, None, 0, None, None) (2, None, None, None, 0, None, None)
) )
async def open(self): async def open(self):
raise NotImplementedError raise NotImplementedError
@ -115,27 +120,29 @@ class SQLiteStorage(Storage):
async def save(self): async def save(self):
await self.date(int(time.time())) await self.date(int(time.time()))
await self.conn.commit() with self.lock:
self.conn.commit()
async def close(self): async def close(self):
await self.conn.close() with self.lock:
self.conn.close()
async def delete(self): async def delete(self):
raise NotImplementedError raise NotImplementedError
async def update_peers(self, peers: List[Tuple[int, int, str, str, str]]): async def update_peers(self, peers: List[Tuple[int, int, str, str, str]]):
await self.conn.executemany( with self.lock:
"REPLACE INTO peers (id, access_hash, type, username, phone_number)" self.conn.executemany(
"VALUES (?, ?, ?, ?, ?)", "REPLACE INTO peers (id, access_hash, type, username, phone_number)"
peers "VALUES (?, ?, ?, ?, ?)",
) peers
)
async def get_peer_by_id(self, peer_id: int): async def get_peer_by_id(self, peer_id: int):
q = await self.conn.execute( r = self.conn.execute(
"SELECT id, access_hash, type FROM peers WHERE id = ?", "SELECT id, access_hash, type FROM peers WHERE id = ?",
(peer_id,) (peer_id,)
) ).fetchone()
r = await q.fetchone()
if r is None: if r is None:
raise KeyError(f"ID not found: {peer_id}") raise KeyError(f"ID not found: {peer_id}")
@ -143,12 +150,11 @@ class SQLiteStorage(Storage):
return get_input_peer(*r) return get_input_peer(*r)
async def get_peer_by_username(self, username: str): async def get_peer_by_username(self, username: str):
q = await self.conn.execute( r = self.conn.execute(
"SELECT id, access_hash, type, last_update_on FROM peers WHERE username = ?" "SELECT id, access_hash, type, last_update_on FROM peers WHERE username = ?"
"ORDER BY last_update_on DESC", "ORDER BY last_update_on DESC",
(username,) (username,)
) ).fetchone()
r = await q.fetchone()
if r is None: if r is None:
raise KeyError(f"Username not found: {username}") raise KeyError(f"Username not found: {username}")
@ -159,65 +165,64 @@ class SQLiteStorage(Storage):
return get_input_peer(*r[:3]) return get_input_peer(*r[:3])
async def get_peer_by_phone_number(self, phone_number: str): async def get_peer_by_phone_number(self, phone_number: str):
q = await self.conn.execute( r = self.conn.execute(
"SELECT id, access_hash, type FROM peers WHERE phone_number = ?", "SELECT id, access_hash, type FROM peers WHERE phone_number = ?",
(phone_number,) (phone_number,)
) ).fetchone()
r = await q.fetchone()
if r is None: if r is None:
raise KeyError(f"Phone number not found: {phone_number}") raise KeyError(f"Phone number not found: {phone_number}")
return get_input_peer(*r) return get_input_peer(*r)
async def _get(self): def _get(self):
attr = inspect.stack()[2].function attr = inspect.stack()[2].function
q = await self.conn.execute( return self.conn.execute(
f"SELECT {attr} FROM sessions" f"SELECT {attr} FROM sessions"
) ).fetchone()[0]
return (await q.fetchone())[0]
async def _set(self, value: Any): def _set(self, value: Any):
attr = inspect.stack()[2].function attr = inspect.stack()[2].function
await self.conn.execute( with self.lock, self.conn:
f"UPDATE sessions SET {attr} = ?", self.conn.execute(
(value,) f"UPDATE sessions SET {attr} = ?",
)
async def _accessor(self, value: Any = object):
return await self._get() if value == object else await self._set(value)
async def dc_id(self, value: int = object):
return await self._accessor(value)
async def api_id(self, value: int = object):
return await self._accessor(value)
async def test_mode(self, value: bool = object):
return await self._accessor(value)
async def auth_key(self, value: bytes = object):
return await self._accessor(value)
async def date(self, value: int = object):
return await self._accessor(value)
async def user_id(self, value: int = object):
return await self._accessor(value)
async def is_bot(self, value: bool = object):
return await self._accessor(value)
async def version(self, value: int = object):
if value == object:
q = await self.conn.execute(
"SELECT number FROM version"
)
return (await q.fetchone())[0]
else:
await self.conn.execute(
"UPDATE version SET number = ?",
(value,) (value,)
) )
def _accessor(self, value: Any = object):
return self._get() if value == object else self._set(value)
async def dc_id(self, value: int = object):
return self._accessor(value)
async def api_id(self, value: int = object):
return self._accessor(value)
async def test_mode(self, value: bool = object):
return self._accessor(value)
async def auth_key(self, value: bytes = object):
return self._accessor(value)
async def date(self, value: int = object):
return self._accessor(value)
async def user_id(self, value: int = object):
return self._accessor(value)
async def is_bot(self, value: bool = object):
return self._accessor(value)
def version(self, value: int = object):
if value == object:
return self.conn.execute(
"SELECT number FROM version"
).fetchone()[0]
else:
with self.lock, self.conn:
self.conn.execute(
"UPDATE version SET number = ?",
(value,)
)