PyroFork: Add async helper, and wrap sqlite3

Signed-off-by: wulan17 <wulan17@nusantararom.org>
This commit is contained in:
wulan17 2023-06-22 01:15:30 +07:00
parent de68ba1919
commit 2a93257fa2
No known key found for this signature in database
GPG key ID: 318CD6CD3A6AC0A5
11 changed files with 160 additions and 78 deletions

View file

@ -138,6 +138,7 @@ def pyrogram_api():
start
stop
run
run_sync
restart
add_handler
remove_handler

View file

@ -74,7 +74,7 @@ Chats
{chats}
Stickers
-----
--------
.. autosummary::
:nosignatures:

View file

@ -21,6 +21,7 @@ from .export_session_string import ExportSessionString
from .remove_handler import RemoveHandler
from .restart import Restart
from .run import Run
from .run_sync import RunSync
from .start import Start
from .stop import Stop
from .stop_transmission import StopTransmission
@ -32,6 +33,7 @@ class Utilities(
RemoveHandler,
Restart,
Run,
RunSync,
Start,
Stop,
StopTransmission

View file

@ -0,0 +1,41 @@
"""PyroFork async utils"""
# Copyright (C) 2020 - 2023 UserbotIndo Team, <https://github.com/userbotindo.git>
# Copyright (C) 2023 Mayuri-Chan, <https://github.com/Mayuri-Chan.git>
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU General Public License as published by
# the Free Software Foundation, either version 3 of the License, or
# (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU General Public License for more details.
#
# You should have received a copy of the GNU General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import asyncio
import functools
from pyrogram import utils
from typing import Any, Callable, TypeVar
class RunSync:
Result = TypeVar("Result")
async def run_sync(self, func: Callable[..., Result], *args: Any, **kwargs: Any) -> Result:
"""
Runs the given sync function (optionally with arguments) on a separate thread.
Parameters:
func (``Callable``):
Sync function to run.
*args (``any``, *optional*):
Function argument.
**kwargs (``any``, *optional*):
Function extras arguments.
"""
return await utils.run_sync(func, *args, **kwargs)

View file

@ -21,6 +21,7 @@ import os
import sqlite3
from pathlib import Path
from . import sqlite
from .sqlite_storage import SQLiteStorage
log = logging.getLogger(__name__)
@ -34,39 +35,36 @@ class FileStorage(SQLiteStorage):
self.database = workdir / (self.name + self.FILE_EXTENSION)
def update(self):
version = self.version()
async def update(self):
version = await self.version()
if version == 1:
with self.lock, self.conn:
self.conn.execute("DELETE FROM peers")
await self.conn.execute("DELETE FROM peers")
version += 1
if version == 2:
with self.lock, self.conn:
self.conn.execute("ALTER TABLE sessions ADD api_id INTEGER")
await self.conn.execute("ALTER TABLE sessions ADD api_id INTEGER")
version += 1
self.version(version)
await self.version(version)
async def open(self):
path = self.database
file_exists = path.is_file()
self.conn = sqlite3.connect(str(path), timeout=1, check_same_thread=False)
self.conn = sqlite.AsyncSqlite(database=str(path), timeout=1, check_same_thread=False)
if not file_exists:
self.create()
await self.create()
else:
self.update()
await self.update()
with self.conn:
try: # Python 3.6.0 (exactly this version) is bugged and won't successfully execute the vacuum
self.conn.execute("VACUUM")
except sqlite3.OperationalError:
pass
try: # Python 3.6.0 (exactly this version) is bugged and won't successfully execute the vacuum
await self.conn.execute("VACUUM")
except sqlite3.OperationalError:
pass
async def delete(self):
os.remove(self.database)

View file

@ -18,9 +18,9 @@
import base64
import logging
import sqlite3
import struct
from . import sqlite
from .sqlite_storage import SQLiteStorage
log = logging.getLogger(__name__)
@ -33,8 +33,8 @@ class MemoryStorage(SQLiteStorage):
self.session_string = session_string
async def open(self):
self.conn = sqlite3.connect(":memory:", check_same_thread=False)
self.create()
self.conn = sqlite.AsyncSqlite(database=":memory:", check_same_thread=False)
await self.create()
if self.session_string:
# Old format

View file

@ -0,0 +1,4 @@
from .cursor import AsyncCursor
from .sqlite import AsyncSqlite
__all__ = [AsyncSqlite, AsyncCursor]

View file

@ -0,0 +1,10 @@
from pyrogram.utils import run_sync
from sqlite3 import Cursor
from threading import Thread
class AsyncCursor(Thread):
def __init__(self, cursor: Cursor):
self.cursor = cursor
async def fetchone(self):
return await run_sync(self.cursor.fetchone)

View file

@ -0,0 +1,27 @@
import sqlite3
from .cursor import AsyncCursor
from pathlib import Path
from pyrogram.utils import run_sync
from threading import Thread
from typing import Union
class AsyncSqlite(Thread):
def __init__(self, database: Union[str, Path], *args, **kwargs):
self.connection = sqlite3.connect(database, *args, **kwargs)
async def commit(self):
return await run_sync(self.connection.commit)
async def close(self):
return await run_sync(self.connection.close)
async def execute(self, *args, **kwargs):
r = await run_sync(self.connection.execute, *args, **kwargs)
return AsyncCursor(r)
async def executemany(self, *args, **kwargs):
r = await run_sync(self.connection.executemany, *args, **kwargs)
return AsyncCursor(r)
async def executescript(self, *args, **kwargs):
r = await run_sync(self.connection.executescript, *args, **kwargs)

View file

@ -17,9 +17,7 @@
# along with Pyrogram. If not, see <http://www.gnu.org/licenses/>.
import inspect
import sqlite3
import time
from threading import Lock
from typing import List, Tuple, Any
from pyrogram import raw
@ -68,7 +66,6 @@ BEGIN
END;
"""
def get_input_peer(peer_id: int, access_hash: int, peer_type: str):
if peer_type in ["user", "bot"]:
return raw.types.InputPeerUser(
@ -97,22 +94,20 @@ class SQLiteStorage(Storage):
def __init__(self, name: str):
super().__init__(name)
self.conn = None # type: sqlite3.Connection
self.lock = Lock()
self.conn = None
def create(self):
with self.lock, self.conn:
self.conn.executescript(SCHEMA)
async def create(self):
await self.conn.executescript(SCHEMA)
self.conn.execute(
"INSERT INTO version VALUES (?)",
(self.VERSION,)
)
await self.conn.execute(
"INSERT INTO version VALUES (?)",
(self.VERSION,)
)
self.conn.execute(
"INSERT INTO sessions VALUES (?, ?, ?, ?, ?, ?, ?)",
(2, None, None, None, 0, None, None)
)
await self.conn.execute(
"INSERT INTO sessions VALUES (?, ?, ?, ?, ?, ?, ?)",
(2, None, None, None, 0, None, None)
)
async def open(self):
raise NotImplementedError
@ -120,29 +115,27 @@ class SQLiteStorage(Storage):
async def save(self):
await self.date(int(time.time()))
with self.lock:
self.conn.commit()
await self.conn.commit()
async def close(self):
with self.lock:
self.conn.close()
await self.conn.close()
async def delete(self):
raise NotImplementedError
async def update_peers(self, peers: List[Tuple[int, int, str, str, str]]):
with self.lock:
self.conn.executemany(
"REPLACE INTO peers (id, access_hash, type, username, phone_number)"
"VALUES (?, ?, ?, ?, ?)",
peers
)
await self.conn.executemany(
"REPLACE INTO peers (id, access_hash, type, username, phone_number)"
"VALUES (?, ?, ?, ?, ?)",
peers
)
async def get_peer_by_id(self, peer_id: int):
r = self.conn.execute(
q = await self.conn.execute(
"SELECT id, access_hash, type FROM peers WHERE id = ?",
(peer_id,)
).fetchone()
)
r = await q.fetchone()
if r is None:
raise KeyError(f"ID not found: {peer_id}")
@ -150,11 +143,12 @@ class SQLiteStorage(Storage):
return get_input_peer(*r)
async def get_peer_by_username(self, username: str):
r = self.conn.execute(
q = await self.conn.execute(
"SELECT id, access_hash, type, last_update_on FROM peers WHERE username = ?"
"ORDER BY last_update_on DESC",
(username,)
).fetchone()
)
r = await q.fetchone()
if r is None:
raise KeyError(f"Username not found: {username}")
@ -165,64 +159,65 @@ class SQLiteStorage(Storage):
return get_input_peer(*r[:3])
async def get_peer_by_phone_number(self, phone_number: str):
r = self.conn.execute(
q = await self.conn.execute(
"SELECT id, access_hash, type FROM peers WHERE phone_number = ?",
(phone_number,)
).fetchone()
)
r = await q.fetchone()
if r is None:
raise KeyError(f"Phone number not found: {phone_number}")
return get_input_peer(*r)
def _get(self):
async def _get(self):
attr = inspect.stack()[2].function
return self.conn.execute(
q = await self.conn.execute(
f"SELECT {attr} FROM sessions"
).fetchone()[0]
)
return (await q.fetchone())[0]
def _set(self, value: Any):
async def _set(self, value: Any):
attr = inspect.stack()[2].function
with self.lock, self.conn:
self.conn.execute(
f"UPDATE sessions SET {attr} = ?",
(value,)
)
await self.conn.execute(
f"UPDATE sessions SET {attr} = ?",
(value,)
)
def _accessor(self, value: Any = object):
return self._get() if value == object else self._set(value)
async def _accessor(self, value: Any = object):
return await self._get() if value == object else await self._set(value)
async def dc_id(self, value: int = object):
return self._accessor(value)
return await self._accessor(value)
async def api_id(self, value: int = object):
return self._accessor(value)
return await self._accessor(value)
async def test_mode(self, value: bool = object):
return self._accessor(value)
return await self._accessor(value)
async def auth_key(self, value: bytes = object):
return self._accessor(value)
return await self._accessor(value)
async def date(self, value: int = object):
return self._accessor(value)
return await self._accessor(value)
async def user_id(self, value: int = object):
return self._accessor(value)
return await self._accessor(value)
async def is_bot(self, value: bool = object):
return self._accessor(value)
return await self._accessor(value)
def version(self, value: int = object):
async def version(self, value: int = object):
if value == object:
return self.conn.execute(
q = await self.conn.execute(
"SELECT number FROM version"
).fetchone()[0]
)
return (await q.fetchone())[0]
else:
with self.lock, self.conn:
self.conn.execute(
"UPDATE version SET number = ?",
(value,)
)
await self.conn.execute(
"UPDATE version SET number = ?",
(value,)
)

View file

@ -25,7 +25,7 @@ import struct
from concurrent.futures.thread import ThreadPoolExecutor
from datetime import datetime, timezone
from getpass import getpass
from typing import Union, List, Dict, Optional
from typing import Union, List, Dict, Optional, Any, Callable, TypeVar
import pyrogram
from pyrogram import raw, enums
@ -376,3 +376,7 @@ def timestamp_to_datetime(ts: Optional[int]) -> Optional[datetime]:
def datetime_to_timestamp(dt: Optional[datetime]) -> Optional[int]:
return int(dt.timestamp()) if dt else None
async def run_sync(func: Callable[..., TypeVar("Result")], *args: Any, **kwargs: Any) -> TypeVar("Result"):
loop = asyncio.get_event_loop()
return await loop.run_in_executor(None, functools.partial(func, *args, **kwargs))