Use factory method for session_manager

Also add some abstraction over the SQL to allow for different SQL
dialects
This commit is contained in:
Anton Melser
2019-01-28 21:23:07 +08:00
parent ea0cbc669b
commit 50cc6a12d9
4 changed files with 251 additions and 157 deletions

View File

@@ -41,6 +41,7 @@ from anki.consts import SYNC_VER, SYNC_ZIP_SIZE, SYNC_ZIP_COUNT
from anki.consts import REM_CARD, REM_NOTE
from ankisyncd.users import get_user_manager
from ankisyncd.sessions import get_session_manager
logger = logging.getLogger("ankisyncd")
@@ -382,26 +383,6 @@ class SyncUserSession:
handler.col = col
return handler
class SimpleSessionManager:
"""A simple session manager that keeps the sessions in memory."""
def __init__(self):
self.sessions = {}
def load(self, hkey, session_factory=None):
return self.sessions.get(hkey)
def load_from_skey(self, skey, session_factory=None):
for i in self.sessions:
if self.sessions[i].skey == skey:
return self.sessions[i]
def save(self, hkey, session):
self.sessions[hkey] = session
def delete(self, hkey):
del self.sessions[hkey]
class SyncApp:
valid_urls = SyncCollectionHandler.operations + SyncMediaHandler.operations + ['hostKey', 'upload', 'download']
@@ -416,12 +397,8 @@ class SyncApp:
self.prehooks = {}
self.posthooks = {}
if "session_db_path" in config:
self.session_manager = SqliteSessionManager(config['session_db_path'])
else:
self.session_manager = SimpleSessionManager()
self.user_manager = get_user_manager(config)
self.session_manager = get_session_manager(config)
self.collection_manager = getCollectionManager()
# make sure the base_url has a trailing slash
@@ -680,73 +657,6 @@ class SyncApp:
return result
class SqliteSessionManager(SimpleSessionManager):
"""Stores sessions in a SQLite database to prevent the user from being logged out
everytime the SyncApp is restarted."""
def __init__(self, session_db_path):
SimpleSessionManager.__init__(self)
self.session_db_path = os.path.realpath(session_db_path)
def _conn(self):
new = not os.path.exists(self.session_db_path)
conn = sqlite.connect(self.session_db_path)
if new:
cursor = conn.cursor()
cursor.execute("CREATE TABLE session (hkey VARCHAR PRIMARY KEY, skey VARCHAR, user VARCHAR, path VARCHAR)")
return conn
def load(self, hkey, session_factory=None):
session = SimpleSessionManager.load(self, hkey)
if session is not None:
return session
conn = self._conn()
cursor = conn.cursor()
cursor.execute("SELECT skey, user, path FROM session WHERE hkey=?", (hkey,))
res = cursor.fetchone()
if res is not None:
session = self.sessions[hkey] = session_factory(res[1], res[2])
session.skey = res[0]
return session
def load_from_skey(self, skey, session_factory=None):
session = SimpleSessionManager.load_from_skey(self, skey)
if session is not None:
return session
conn = self._conn()
cursor = conn.cursor()
cursor.execute("SELECT hkey, user, path FROM session WHERE skey=?", (skey,))
res = cursor.fetchone()
if res is not None:
session = self.sessions[res[0]] = session_factory(res[1], res[2])
session.skey = skey
return session
def save(self, hkey, session):
SimpleSessionManager.save(self, hkey, session)
conn = self._conn()
cursor = conn.cursor()
cursor.execute("INSERT OR REPLACE INTO session (hkey, skey, user, path) VALUES (?, ?, ?, ?)",
(hkey, session.skey, session.name, session.path))
conn.commit()
def delete(self, hkey):
SimpleSessionManager.delete(self, hkey)
conn = self._conn()
cursor = conn.cursor()
cursor.execute("DELETE FROM session WHERE hkey=?", (hkey,))
conn.commit()
def make_app(global_conf, **local_conf):
return SyncApp(**local_conf)