diff --git a/pyrogram/client/client.py b/pyrogram/client/client.py
index 0e8d5554..f17a054b 100644
--- a/pyrogram/client/client.py
+++ b/pyrogram/client/client.py
@@ -49,12 +49,15 @@ from pyrogram.api.errors import (
from pyrogram.client.handlers import DisconnectHandler
from pyrogram.client.handlers.handler import Handler
from pyrogram.client.methods.password.utils import compute_check
+from pyrogram.client.session_storage import BaseSessionConfig
from pyrogram.crypto import AES
from pyrogram.session import Auth, Session
from .dispatcher import Dispatcher
from .ext import utils, Syncer, BaseClient
from .methods import Methods
-from .session_storage import BaseSessionStorage, JsonSessionStorage, SessionDoesNotExist
+from .session_storage import SessionDoesNotExist
+from .session_storage.json_session_storage import JsonSessionStorage
+from .session_storage.string_session_storage import StringSessionStorage
log = logging.getLogger(__name__)
@@ -176,7 +179,7 @@ class Client(Methods, BaseClient):
"""
def __init__(self,
- session_name: str,
+ session_name: Union[str, BaseSessionConfig],
api_id: Union[int, str] = None,
api_hash: str = None,
app_version: str = None,
@@ -198,11 +201,21 @@ class Client(Methods, BaseClient):
config_file: str = BaseClient.CONFIG_FILE,
plugins: dict = None,
no_updates: bool = None,
- takeout: bool = None,
- session_storage_cls: Type[BaseSessionStorage] = JsonSessionStorage):
- super().__init__(session_storage_cls(self))
+ takeout: bool = None):
- self.session_name = session_name
+ if isinstance(session_name, str):
+ if session_name.startswith(':'):
+ session_storage = StringSessionStorage(self, session_name)
+ else:
+ session_storage = JsonSessionStorage(self, session_name)
+ elif isinstance(session_name, BaseSessionConfig):
+ session_storage = session_name.session_storage_cls(self, session_name)
+ else:
+ raise RuntimeError('Wrong session_name passed, expected str or BaseSessionConfig subclass')
+
+ super().__init__(session_storage)
+
+ self.session_name = str(session_name) # TODO: build correct session name
self.api_id = int(api_id) if api_id else None
self.api_hash = api_hash
self.app_version = app_version
@@ -1101,12 +1114,9 @@ class Client(Methods, BaseClient):
def load_session(self):
try:
- self.session_storage.load_session(self.session_name)
+ self.session_storage.load_session()
except SessionDoesNotExist:
- session_name = self.session_name[:32]
- if session_name != self.session_name:
- session_name += '...'
- log.info('Could not load session "{}", initializing new one'.format(self.session_name))
+ log.info('Could not load session "{}", initiate new one'.format(self.session_name))
self.auth_key = Auth(self.dc_id, self.test_mode, self.ipv6, self._proxy).create()
def load_plugins(self):
@@ -1214,7 +1224,7 @@ class Client(Methods, BaseClient):
log.warning('No plugin loaded from "{}"'.format(root))
def save_session(self):
- self.session_storage.save_session(self.session_name)
+ self.session_storage.save_session()
def get_initial_dialogs_chunk(self,
offset_date: int = 0):
diff --git a/pyrogram/client/ext/syncer.py b/pyrogram/client/ext/syncer.py
index 8930b13e..70955624 100644
--- a/pyrogram/client/ext/syncer.py
+++ b/pyrogram/client/ext/syncer.py
@@ -83,10 +83,10 @@ class Syncer:
def sync(cls, client):
client.date = int(time.time())
try:
- client.session_storage.save_session(client.session_name, sync=True)
+ client.session_storage.save_session(sync=True)
except Exception as e:
log.critical(e, exc_info=True)
else:
log.info("Synced {}".format(client.session_name))
finally:
- client.session_storage.sync_cleanup(client.session_name)
+ client.session_storage.sync_cleanup()
diff --git a/pyrogram/client/session_storage/__init__.py b/pyrogram/client/session_storage/__init__.py
index ced103ce..611ec9b7 100644
--- a/pyrogram/client/session_storage/__init__.py
+++ b/pyrogram/client/session_storage/__init__.py
@@ -17,6 +17,4 @@
# along with Pyrogram. If not, see .
from .session_storage_mixin import SessionStorageMixin
-from .base_session_storage import BaseSessionStorage, SessionDoesNotExist
-from .json_session_storage import JsonSessionStorage
-from .string_session_storage import StringSessionStorage
+from .base_session_storage import BaseSessionStorage, BaseSessionConfig, SessionDoesNotExist
diff --git a/pyrogram/client/session_storage/base_session_storage.py b/pyrogram/client/session_storage/base_session_storage.py
index 75e416b4..a5c879f1 100644
--- a/pyrogram/client/session_storage/base_session_storage.py
+++ b/pyrogram/client/session_storage/base_session_storage.py
@@ -17,6 +17,7 @@
# along with Pyrogram. If not, see .
import abc
+from typing import Type
import pyrogram
@@ -26,8 +27,9 @@ class SessionDoesNotExist(Exception):
class BaseSessionStorage(abc.ABC):
- def __init__(self, client: 'pyrogram.client.BaseClient'):
+ def __init__(self, client: 'pyrogram.client.BaseClient', session_data):
self.client = client
+ self.session_data = session_data
self.dc_id = 1
self.test_mode = None
self.auth_key = None
@@ -38,13 +40,20 @@ class BaseSessionStorage(abc.ABC):
self.peers_by_phone = {}
@abc.abstractmethod
- def load_session(self, name: str):
+ def load_session(self):
...
@abc.abstractmethod
- def save_session(self, name: str, sync=False):
+ def save_session(self, sync=False):
...
@abc.abstractmethod
- def sync_cleanup(self, name: str):
+ def sync_cleanup(self):
+ ...
+
+
+class BaseSessionConfig(abc.ABC):
+ @property
+ @abc.abstractmethod
+ def session_storage_cls(self) -> Type[BaseSessionStorage]:
...
diff --git a/pyrogram/client/session_storage/json_session_storage.py b/pyrogram/client/session_storage/json_session_storage.py
index 679a21f3..f41091af 100644
--- a/pyrogram/client/session_storage/json_session_storage.py
+++ b/pyrogram/client/session_storage/json_session_storage.py
@@ -35,8 +35,8 @@ class JsonSessionStorage(BaseSessionStorage):
name += '.session'
return os.path.join(self.client.workdir, name)
- def load_session(self, name: str):
- file_path = self._get_file_name(name)
+ def load_session(self):
+ file_path = self._get_file_name(self.session_data)
log.info('Loading JSON session from {}'.format(file_path))
try:
@@ -66,8 +66,8 @@ class JsonSessionStorage(BaseSessionStorage):
if peer:
self.peers_by_phone[k] = peer
- def save_session(self, name: str, sync=False):
- file_path = self._get_file_name(name)
+ def save_session(self, sync=False):
+ file_path = self._get_file_name(self.session_data)
if sync:
file_path += '.tmp'
@@ -107,10 +107,10 @@ class JsonSessionStorage(BaseSessionStorage):
# execution won't be here if an error has occurred earlier
if sync:
- shutil.move(file_path, self._get_file_name(name))
+ shutil.move(file_path, self._get_file_name(self.session_data))
- def sync_cleanup(self, name: str):
+ def sync_cleanup(self):
try:
- os.remove(self._get_file_name(name) + '.tmp')
+ os.remove(self._get_file_name(self.session_data) + '.tmp')
except OSError:
pass
diff --git a/pyrogram/client/session_storage/string_session_storage.py b/pyrogram/client/session_storage/string_session_storage.py
index 9b6ebf0e..c01a2b35 100644
--- a/pyrogram/client/session_storage/string_session_storage.py
+++ b/pyrogram/client/session_storage/string_session_storage.py
@@ -5,34 +5,33 @@ import struct
from . import BaseSessionStorage, SessionDoesNotExist
-def StringSessionStorage(print_session: bool = False):
- class StringSessionStorageClass(BaseSessionStorage):
- """
- Packs session data as following (forcing little-endian byte order):
- Char dc_id (1 byte, unsigned)
- Boolean test_mode (1 byte)
- Long long user_id (8 bytes, signed)
- Bytes auth_key (256 bytes)
+class StringSessionStorage(BaseSessionStorage):
+ """
+ Packs session data as following (forcing little-endian byte order):
+ Char dc_id (1 byte, unsigned)
+ Boolean test_mode (1 byte)
+ Long long user_id (8 bytes, signed)
+ Bytes auth_key (256 bytes)
- Uses Base64 encoding for printable representation
- """
- PACK_FORMAT = '