Move packages into src folder
This commit is contained in:
76
src/addon/__init__.py
Normal file
76
src/addon/__init__.py
Normal file
@@ -0,0 +1,76 @@
|
||||
from PyQt5.Qt import Qt, QCheckBox, QLabel, QHBoxLayout, QLineEdit
|
||||
from aqt.forms import preferences
|
||||
from anki.hooks import wrap, addHook
|
||||
import aqt
|
||||
import anki.consts
|
||||
import anki.sync
|
||||
|
||||
DEFAULT_ADDR = "http://localhost:27701/"
|
||||
config = aqt.mw.addonManager.getConfig(__name__)
|
||||
|
||||
# TODO: force the user to log out before changing any of the settings
|
||||
|
||||
def addui(self, _):
    """Append the custom-sync-server widgets to the preferences dialog.

    Wrapped "after" aqt.preferences.Preferences.__init__, so *self* is the
    Preferences instance and its .form has already been built.
    """
    self = self.form
    # NOTE(review): assumes tab_2 / vboxlayout are the Network tab and its
    # layout in aqt.forms.preferences -- confirm against the dialog's .ui file.
    parent_w = self.tab_2
    parent_l = self.vboxlayout
    self.useCustomServer = QCheckBox(parent_w)
    self.useCustomServer.setText("Use custom sync server")
    parent_l.addWidget(self.useCustomServer)
    cshl = QHBoxLayout()
    parent_l.addLayout(cshl)

    self.serverAddrLabel = QLabel(parent_w)
    self.serverAddrLabel.setText("Server address")
    cshl.addWidget(self.serverAddrLabel)
    self.customServerAddr = QLineEdit(parent_w)
    self.customServerAddr.setPlaceholderText(DEFAULT_ADDR)
    cshl.addWidget(self.customServerAddr)

    # Restore the saved per-profile state into the widgets.
    pconfig = getprofileconfig()
    if pconfig["enabled"]:
        self.useCustomServer.setCheckState(Qt.Checked)
    if pconfig["addr"]:
        self.customServerAddr.setText(pconfig["addr"])

    self.customServerAddr.textChanged.connect(lambda text: updateserver(self, text))
    def onchecked(state):
        pconfig["enabled"] = state == Qt.Checked
        updateui(self, state)
        # re-persist the current address under the new enabled state
        updateserver(self, self.customServerAddr.text())
    self.useCustomServer.stateChanged.connect(onchecked)

    # Grey out / enable the address widgets to match the initial checkbox state.
    updateui(self, self.useCustomServer.checkState())
|
||||
|
||||
def updateserver(self, text):
    """Persist the entered server address and re-apply it, when the custom
    server is enabled for the active profile."""
    profile_conf = getprofileconfig()
    if not profile_conf['enabled']:
        return
    # An empty field falls back to the placeholder (the default address).
    profile_conf['addr'] = text or self.customServerAddr.placeholderText()
    setserver()
    aqt.mw.addonManager.writeConfig(__name__, config)
|
||||
|
||||
def updateui(self, state):
    """Enable the address widgets only while the custom-server box is checked."""
    enabled = state == Qt.Checked
    self.serverAddrLabel.setEnabled(enabled)
    self.customServerAddr.setEnabled(enabled)
|
||||
|
||||
def setserver():
    """Point Anki's sync machinery at the configured server, or restore the
    stock endpoint when the custom server is disabled."""
    profile_conf = getprofileconfig()
    if not profile_conf['enabled']:
        anki.sync.SYNC_BASE = anki.consts.SYNC_BASE
        return
    # NOTE(review): hostNum is reset so the "%s" placeholder in SYNC_BASE
    # expands without a host number -- presumably; confirm against anki.sync.
    aqt.mw.pm.profile['hostNum'] = None
    anki.sync.SYNC_BASE = "%s" + profile_conf['addr']
|
||||
|
||||
def getprofileconfig():
    """Return the settings dict for the active profile, creating it on first use."""
    name = aqt.mw.pm.name
    if name not in config["profiles"]:
        # First time this profile is seen: seed it from the legacy global
        # settings used by earlier versions of the addon, when present.
        config["profiles"][name] = {
            "enabled": config.get("enabled", False),
            "addr": config.get("addr", DEFAULT_ADDR),
        }
        aqt.mw.addonManager.writeConfig(__name__, config)
    return config["profiles"][name]
|
||||
|
||||
# Re-apply the configured sync endpoint whenever a profile is loaded, and
# append our widgets after the preferences dialog has been constructed.
addHook("profileLoaded", setserver)
aqt.preferences.Preferences.__init__ = wrap(aqt.preferences.Preferences.__init__, addui, "after")
|
||||
1
src/addon/config.json
Normal file
1
src/addon/config.json
Normal file
@@ -0,0 +1 @@
|
||||
{"profiles":{}}
|
||||
82
src/ankisyncctl.py
Executable file
82
src/ankisyncctl.py
Executable file
@@ -0,0 +1,82 @@
|
||||
#!/usr/bin/env python
|
||||
import os
|
||||
import sys
|
||||
import getpass
|
||||
|
||||
import ankisyncd.config
|
||||
from ankisyncd.users import get_user_manager
|
||||
|
||||
|
||||
config = ankisyncd.config.load()
|
||||
|
||||
def usage():
    """Print the command-line help text to stdout."""
    for line in (
        "usage: {} <command> [<args>]".format(sys.argv[0]),
        "",
        "Commands:",
        " adduser <username> - add a new user",
        " deluser <username> - delete a user",
        " lsuser - list users",
        " passwd <username> - change password of a user",
    ):
        print(line)
|
||||
|
||||
def adduser(username):
    """Prompt for a password and register *username* with the user manager."""
    pw = getpass.getpass("Enter password for {}: ".format(username))
    get_user_manager(config).add_user(username, pw)
|
||||
|
||||
def deluser(username):
    """Delete *username*; report failure on stderr instead of raising."""
    manager = get_user_manager(config)
    try:
        manager.del_user(username)
    except ValueError as error:
        print("Could not delete user {}: {}".format(username, error), file=sys.stderr)
|
||||
|
||||
def lsuser():
    """Print every known username, one per line."""
    manager = get_user_manager(config)
    try:
        for username in manager.user_list():
            print(username)
    except ValueError as error:
        print("Could not list users: {}".format(error), file=sys.stderr)
|
||||
|
||||
def passwd(username):
    """Interactively change the password of an existing user."""
    manager = get_user_manager(config)

    if username not in manager.user_list():
        print("User {} doesn't exist".format(username))
        return

    password = getpass.getpass("Enter password for {}: ".format(username))
    try:
        manager.set_password_for_user(username, password)
    except ValueError as error:
        print("Could not set password for user {}: {}".format(username, error), file=sys.stderr)
|
||||
|
||||
def main():
    """Dispatch the first CLI argument to its command handler.

    Every remaining argument is passed to the handler in turn; commands that
    take no argument are called once.  Exits with status 1 on a missing or
    unknown command.
    """
    cmds = {
        "adduser": adduser,
        "deluser": deluser,
        "lsuser": lsuser,
        "passwd": passwd,
    }

    if len(sys.argv) < 2:
        usage()
        sys.exit(1)

    # Bug fix: the handler is resolved *before* being called.  The original
    # wrapped the call itself in `except KeyError`, so a KeyError raised
    # inside a command was silently misreported as a usage error.
    try:
        handler = cmds[sys.argv[1]]
    except KeyError:
        usage()
        sys.exit(1)

    if len(sys.argv) > 2:
        for arg in sys.argv[2:]:
            handler(arg)
    else:
        handler()
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
20
src/ankisyncd.conf
Normal file
20
src/ankisyncd.conf
Normal file
@@ -0,0 +1,20 @@
|
||||
[sync_app]
|
||||
# change to 127.0.0.1 if you don't want the server to be accessible from the internet
|
||||
host = 0.0.0.0
|
||||
port = 27701
|
||||
data_root = ./collections
|
||||
base_url = /sync/
|
||||
base_media_url = /msync/
|
||||
auth_db_path = ./auth.db
|
||||
# optional, for session persistence between restarts
|
||||
session_db_path = ./session.db
|
||||
|
||||
# optional, for overriding the default managers and wrappers
|
||||
# # must inherit from ankisyncd.full_sync.FullSyncManager, e.g,
|
||||
# full_sync_manager = great_stuff.postgres.PostgresFullSyncManager
|
||||
# # must inherit from ankisyncd.session.SimpleSessionManager, e.g,
|
||||
# session_manager = great_stuff.postgres.PostgresSessionManager
|
||||
# # must inherit from ankisyncd.users.SimpleUserManager, e.g,
|
||||
# user_manager = great_stuff.postgres.PostgresUserManager
|
||||
# # must inherit from ankisyncd.collection.CollectionWrapper, e.g,
|
||||
# collection_wrapper = great_stuff.postgres.PostgresCollectionWrapper
|
||||
33
src/ankisyncd/__init__.py
Normal file
33
src/ankisyncd/__init__.py
Normal file
@@ -0,0 +1,33 @@
|
||||
import os
|
||||
import sys
|
||||
|
||||
sys.path.insert(0, "/usr/share/anki")
|
||||
sys.path.insert(0, os.path.join(os.path.dirname(os.path.dirname(__file__)), "anki-bundled"))
|
||||
|
||||
_homepage = "https://github.com/tsudoko/anki-sync-server"
|
||||
_unknown_version = "[unknown version]"
|
||||
|
||||
|
||||
def _get_version():
    """Best-effort version string: the generated _version module when present,
    otherwise the output of `git describe`, otherwise a placeholder."""
    try:
        from ankisyncd._version import version
        return version
    except ImportError:
        pass

    import subprocess

    try:
        described = subprocess.run(
            ["git", "describe", "--always"],
            stdout=subprocess.PIPE,
            stderr=subprocess.PIPE,
        ).stdout.strip().decode()
    except (FileNotFoundError, subprocess.CalledProcessError):
        return _unknown_version
    # Empty output (e.g. not a git checkout) also falls back to the placeholder.
    return described or _unknown_version
|
||||
11
src/ankisyncd/__main__.py
Normal file
11
src/ankisyncd/__main__.py
Normal file
@@ -0,0 +1,11 @@
|
||||
import sys

# When run as a plain script (not via `python -m` and not frozen into a
# bundle), make the package's parent directory importable before importing
# anything from ankisyncd.
if __package__ is None and not hasattr(sys, "frozen"):
    import os.path
    path = os.path.realpath(os.path.abspath(__file__))
    sys.path.insert(0, os.path.dirname(os.path.dirname(path)))

import ankisyncd.sync_app

if __name__ == "__main__":
    ankisyncd.sync_app.main()
|
||||
138
src/ankisyncd/collection.py
Normal file
138
src/ankisyncd/collection.py
Normal file
@@ -0,0 +1,138 @@
|
||||
import anki
|
||||
import anki.storage
|
||||
|
||||
import ankisyncd.media
|
||||
|
||||
import os, errno
|
||||
import logging
|
||||
|
||||
logger = logging.getLogger("ankisyncd.collection")
|
||||
|
||||
|
||||
class CollectionWrapper:
    """A simple wrapper around an anki.storage.Collection object.

    This allows us to manage and refer to the collection, whether it's open or not. It
    also provides a special "continuation passing" interface for executing functions
    on the collection, which makes it easy to switch to a threading mode.

    See ThreadingCollectionWrapper for a version that maintains a separate thread for
    interacting with the collection.
    """

    def __init__(self, _config, path, setup_new_collection=None):
        self.path = os.path.realpath(path)
        # Collections live at <data_root>/<username>/<file>, so the parent
        # directory name is the user the collection belongs to.
        self.username = os.path.basename(os.path.dirname(self.path))
        self.setup_new_collection = setup_new_collection
        self.__col = None

    def __del__(self):
        """Close the collection if the user forgot to do so."""
        self.close()

    def execute(self, func, args=None, kw=None, waitForReturn=True):
        """Executes the given function with the underlying anki.storage.Collection
        object as the first argument and any additional arguments specified by *args
        and **kw.

        If 'waitForReturn' is True, then it will block until the function has
        executed and return its return value.  If False, the function MAY be
        executed some time later and None will be returned.
        """
        # Bug fix: the defaults used to be the mutable literals [] / {},
        # which Python shares between every call of the method.
        if args is None:
            args = []
        if kw is None:
            kw = {}

        # Open the collection and execute the function
        self.open()
        args = [self.__col] + list(args)
        ret = func(*args, **kw)

        # Only return the value if they requested it, so the interface remains
        # identical between this class and ThreadingCollectionWrapper
        if waitForReturn:
            return ret

    def __create_collection(self):
        """Creates a new collection and runs any special setup."""

        # mkdir -p the path, because it might not exist
        os.makedirs(os.path.dirname(self.path), exist_ok=True)

        col = self._get_collection()

        # Do any special setup
        if self.setup_new_collection is not None:
            self.setup_new_collection(col)

        return col

    def _get_collection(self):
        """Open the collection file and swap in the server media manager."""
        col = anki.storage.Collection(self.path)

        # Ugly hack, replace default media manager with our custom one
        col.media.close()
        col.media = ankisyncd.media.ServerMediaManager(col)

        return col

    def open(self):
        """Open the collection, or create it if it doesn't exist."""
        if self.__col is None:
            if os.path.exists(self.path):
                self.__col = self._get_collection()
            else:
                self.__col = self.__create_collection()

    def close(self):
        """Close the collection if opened."""
        if not self.opened():
            return

        self.__col.close()
        self.__col = None

    def opened(self):
        """Returns True if the collection is open, False otherwise."""
        return self.__col is not None
|
||||
|
||||
class CollectionManager:
    """Manages a set of CollectionWrapper objects, one per collection path."""

    # Subclasses may swap in a different wrapper implementation.
    collection_wrapper = CollectionWrapper

    def __init__(self, config):
        self.collections = {}
        self.config = config

    def get_collection(self, path, setup_new_collection=None):
        """Return the cached CollectionWrapper for *path*, creating it on first use."""
        path = os.path.realpath(path)
        if path not in self.collections:
            self.collections[path] = self.collection_wrapper(
                self.config, path, setup_new_collection)
        return self.collections[path]

    def shutdown(self):
        """Close all CollectionWrappers managed by this object."""
        for path in list(self.collections):
            self.collections.pop(path).close()
|
||||
|
||||
def get_collection_wrapper(config, path, setup_new_collection=None):
    """Instantiate the wrapper class named by ``collection_wrapper`` in *config*
    (a dotted ``module.Class`` path), falling back to the stock CollectionWrapper.

    Raises TypeError if the configured class does not inherit from
    CollectionWrapper.
    """
    if not ("collection_wrapper" in config and config["collection_wrapper"]):
        return CollectionWrapper(config, path, setup_new_collection)

    logger.info("Found collection_wrapper in config, using {} for "
                "user data persistence".format(config['collection_wrapper']))
    import importlib
    module_name, class_name = config['collection_wrapper'].rsplit('.', 1)
    module = importlib.import_module(module_name.strip())
    class_ = getattr(module, class_name.strip())

    # issubclass is the idiomatic equivalent of the old inspect.getmro()
    # membership test; the error message is no longer a garbled multi-line
    # triple-quoted string.
    if not issubclass(class_, CollectionWrapper):
        raise TypeError('"collection_wrapper" found in the conf file but it '
                        "doesn't inherit from CollectionWrapper")
    return class_(config, path, setup_new_collection)
|
||||
42
src/ankisyncd/config.py
Normal file
42
src/ankisyncd/config.py
Normal file
@@ -0,0 +1,42 @@
|
||||
import configparser
|
||||
import logging
|
||||
import os
|
||||
from os.path import dirname, realpath
|
||||
|
||||
logger = logging.getLogger("ankisyncd")
|
||||
|
||||
# Candidate config file locations, tried in order by load(): system-wide,
# then the user config dir (the and/or chain picks $XDG_CONFIG_HOME when it
# is set and non-empty, else ~/.config), then next to the package itself.
paths = [
    "/etc/ankisyncd/ankisyncd.conf",
    os.environ.get("XDG_CONFIG_HOME") and
    (os.path.join(os.environ['XDG_CONFIG_HOME'], "ankisyncd", "ankisyncd.conf")) or
    os.path.join(os.path.expanduser("~"), ".config", "ankisyncd", "ankisyncd.conf"),
    os.path.join(dirname(dirname(realpath(__file__))), "ankisyncd.conf"),
]
|
||||
|
||||
# Get values from ENV and update the config. To use this prepend `ANKISYNCD_`
# to the uppercase form of the key. E.g, `ANKISYNCD_SESSION_MANAGER` to set
# `session_manager`
def load_from_env(conf):
    """Override entries of *conf* in place from ANKISYNCD_* environment variables."""
    logger.debug("Loading/overriding config values from ENV")
    prefix = "ANKISYNCD_"
    # items() gives the value directly instead of a second os.getenv lookup,
    # and len(prefix) replaces the magic slice index 10.
    for env, value in os.environ.items():
        if env.startswith(prefix):
            config_key = env[len(prefix):].lower()
            conf[config_key] = value
            logger.info("Setting {} from ENV".format(config_key))
|
||||
|
||||
def load(path=None):
    """Return the [sync_app] section of the first readable config file.

    If *path* is given, only that file is tried; otherwise the standard
    locations in `paths` are tried in order.  ENV overrides are applied via
    load_from_env().  Raises Exception when no candidate has a [sync_app]
    section.
    """
    choices = [path] if path else paths
    parser = configparser.ConfigParser()
    for candidate in choices:
        # Bug fix: the original passed the path as a positional logging
        # argument with no %s placeholder, which logging reports as a
        # formatting error instead of printing the path.
        logger.debug("config.location: trying %s", candidate)
        try:
            parser.read(candidate)
            conf = parser['sync_app']
        except KeyError:
            # no [sync_app] section in this file; try the next one
            continue
        logger.info("Loaded config from {}".format(candidate))
        load_from_env(conf)
        return conf
    raise Exception("No config found, looked for {}".format(", ".join(choices)))
|
||||
59
src/ankisyncd/full_sync.py
Normal file
59
src/ankisyncd/full_sync.py
Normal file
@@ -0,0 +1,59 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
|
||||
import os
|
||||
from sqlite3 import dbapi2 as sqlite
|
||||
|
||||
import anki.db
|
||||
|
||||
class FullSyncManager:
    """Handles full (non-incremental) collection uploads and downloads."""

    def upload(self, col, data, session):
        """Replace the user's collection database with *data*.

        The bytes are written to a temp file and integrity-checked before the
        existing database is overwritten; *col* is reopened afterwards either
        way.  Returns the literal "OK" expected by the sync client.
        """
        # Bug fix: HTTPBadRequest was referenced without ever being imported
        # in this module, so a failed integrity check raised NameError
        # instead of the intended HTTP 400.
        from webob.exc import HTTPBadRequest

        # Verify integrity of the received database file before replacing our
        # existing db.
        temp_db_path = session.get_collection_path() + ".tmp"
        with open(temp_db_path, 'wb') as f:
            f.write(data)

        try:
            with anki.db.DB(temp_db_path) as test_db:
                if test_db.scalar("pragma integrity_check") != "ok":
                    raise HTTPBadRequest("Integrity check failed for uploaded "
                                         "collection database file.")
        except sqlite.Error:
            raise HTTPBadRequest("Uploaded collection database file is "
                                 "corrupt.")

        # Overwrite existing db.
        col.close()
        try:
            os.replace(temp_db_path, session.get_collection_path())
        finally:
            col.reopen()
            col.load()

        return "OK"

    def download(self, col, session):
        """Return the raw bytes of the user's collection database."""
        col.close()
        try:
            # with-block fixes the file handle the original left open
            with open(session.get_collection_path(), 'rb') as f:
                data = f.read()
        finally:
            col.reopen()
            col.load()
        return data
|
||||
|
||||
|
||||
def get_full_sync_manager(config):
    """Instantiate the class named by ``full_sync_manager`` in *config*
    (a dotted ``module.Class`` path), or the stock FullSyncManager.

    Raises TypeError if the configured class does not inherit from
    FullSyncManager.
    """
    if not ("full_sync_manager" in config and config["full_sync_manager"]):
        return FullSyncManager()

    import importlib
    module_name, class_name = config['full_sync_manager'].rsplit('.', 1)
    module = importlib.import_module(module_name.strip())
    class_ = getattr(module, class_name.strip())

    # issubclass replaces the old inspect.getmro() membership test; the error
    # message is no longer a garbled multi-line triple-quoted string.
    if not issubclass(class_, FullSyncManager):
        raise TypeError('"full_sync_manager" found in the conf file but it '
                        "doesn't inherit from FullSyncManager")
    return class_(config)
|
||||
67
src/ankisyncd/media.py
Normal file
67
src/ankisyncd/media.py
Normal file
@@ -0,0 +1,67 @@
|
||||
# Based on anki.media.MediaManager, © Ankitects Pty Ltd and contributors
|
||||
# License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html
|
||||
# Original source: https://raw.githubusercontent.com/dae/anki/62481ddc1aa78430cb8114cbf00a7739824318a8/anki/media.py
|
||||
|
||||
import logging
|
||||
import re
|
||||
import os
|
||||
import os.path
|
||||
|
||||
import anki.db
|
||||
|
||||
logger = logging.getLogger("ankisyncd.media")
|
||||
|
||||
|
||||
class ServerMediaManager:
    """Server-side replacement for anki.media.MediaManager.

    Media state is tracked in a single `media` table (fname, usn, csum) in a
    ".server.db" SQLite file next to the media directory; a NULL csum marks a
    deleted file.
    """

    def __init__(self, col):
        # The media dir lives next to the collection file: foo.anki2 -> foo.media
        self._dir = re.sub(r"(?i)\.(anki2)$", ".media", col.path)
        self.connect()

    def connect(self):
        """Open the media DB, creating the schema and migrating any old
        client-format database (<dir>.db2) on first use."""
        path = self.dir() + ".server.db"
        create = not os.path.exists(path)
        self.db = anki.db.DB(path)
        if create:
            self.db.executescript(
                """CREATE TABLE media (
                 fname TEXT NOT NULL PRIMARY KEY,
                 usn INT NOT NULL,
                 csum TEXT -- null if deleted
                );
                CREATE INDEX idx_media_usn ON media (usn);"""
            )
            oldpath = self.dir() + ".db2"
            if os.path.exists(oldpath):
                # Old client-side DB found: copy its rows into our schema.
                logger.info("Found client media database, migrating contents")
                self.db.execute("ATTACH ? AS old", oldpath)
                self.db.execute(
                    "INSERT INTO media SELECT fname, lastUsn, csum FROM old.media, old.meta"
                )
                self.db.commit()
                self.db.execute("DETACH old")

    def close(self):
        self.db.close()

    def dir(self):
        # Path of the media directory (same contract as anki's MediaManager.dir()).
        return self._dir

    def lastUsn(self):
        # Highest update sequence number seen so far; 0 for an empty table.
        return self.db.scalar("SELECT max(usn) FROM media") or 0

    def mediaCount(self):
        # Deleted files (csum IS NULL) are excluded from the count.
        return self.db.scalar("SELECT count() FROM media WHERE csum IS NOT NULL")

    # used only in unit tests
    def syncInfo(self, fname):
        return self.db.first("SELECT csum, 0 FROM media WHERE fname=?", fname)

    def syncDelete(self, fname):
        """Remove the file from disk and mark it deleted with a fresh usn."""
        fpath = os.path.join(self.dir(), fname)
        if os.path.exists(fpath):
            os.remove(fpath)
        self.db.execute(
            "UPDATE media SET csum = NULL, usn = ? WHERE fname = ?",
            self.lastUsn() + 1,
            fname,
        )
|
||||
141
src/ankisyncd/sessions.py
Normal file
141
src/ankisyncd/sessions.py
Normal file
@@ -0,0 +1,141 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
import os
|
||||
import logging
|
||||
from sqlite3 import dbapi2 as sqlite
|
||||
|
||||
logger = logging.getLogger("ankisyncd.sessions")
|
||||
|
||||
|
||||
class SimpleSessionManager:
    """A simple session manager that keeps the sessions in memory."""

    def __init__(self):
        self.sessions = {}

    def load(self, hkey, session_factory=None):
        """Return the session stored under *hkey*, or None."""
        return self.sessions.get(hkey)

    def load_from_skey(self, skey, session_factory=None):
        """Linear scan for the session whose .skey matches; None when absent."""
        for session in self.sessions.values():
            if session.skey == skey:
                return session
        return None

    def save(self, hkey, session):
        self.sessions[hkey] = session

    def delete(self, hkey):
        del self.sessions[hkey]
|
||||
|
||||
|
||||
class SqliteSessionManager(SimpleSessionManager):
    """Stores sessions in a SQLite database to prevent the user from being logged out
    every time the SyncApp is restarted.

    The in-memory dict of SimpleSessionManager acts as a cache in front of the
    `session` table (hkey, skey, username, path).
    """

    def __init__(self, session_db_path):
        SimpleSessionManager.__init__(self)

        self.session_db_path = os.path.realpath(session_db_path)
        self._ensure_schema_up_to_date()

    def _ensure_schema_up_to_date(self):
        # Refuse to run against the pre-rename schema where `user` was the
        # primary key; the migration script must be run first.
        if not os.path.exists(self.session_db_path):
            return True

        conn = self._conn()
        cursor = conn.cursor()
        cursor.execute("SELECT * FROM sqlite_master "
                       "WHERE sql LIKE '%user VARCHAR PRIMARY KEY%' "
                       "AND tbl_name = 'session'")
        res = cursor.fetchone()
        conn.close()
        if res is not None:
            raise Exception("Outdated database schema, run utils/migrate_user_tables.py")

    def _conn(self):
        # Opening a connection creates the schema when the DB file is new.
        new = not os.path.exists(self.session_db_path)
        conn = sqlite.connect(self.session_db_path)
        if new:
            cursor = conn.cursor()
            cursor.execute("CREATE TABLE session (hkey VARCHAR PRIMARY KEY, skey VARCHAR, username VARCHAR, path VARCHAR)")
        return conn

    # Default to using sqlite3 syntax but overridable for sub-classes using other
    # DB API 2 driver variants
    @staticmethod
    def fs(sql):
        return sql

    def load(self, hkey, session_factory=None):
        """Look up *hkey* in the in-memory cache first, then the database."""
        session = SimpleSessionManager.load(self, hkey)
        if session is not None:
            return session

        conn = self._conn()
        cursor = conn.cursor()

        cursor.execute(self.fs("SELECT skey, username, path FROM session WHERE hkey=?"), (hkey,))
        res = cursor.fetchone()

        if res is not None:
            # NOTE(review): assumes session_factory is provided whenever a DB
            # hit is possible -- confirm with callers.
            session = self.sessions[hkey] = session_factory(res[1], res[2])
            session.skey = res[0]
            return session

    def load_from_skey(self, skey, session_factory=None):
        """Same as load(), but keyed on the secondary session key."""
        session = SimpleSessionManager.load_from_skey(self, skey)
        if session is not None:
            return session

        conn = self._conn()
        cursor = conn.cursor()

        cursor.execute(self.fs("SELECT hkey, username, path FROM session WHERE skey=?"), (skey,))
        res = cursor.fetchone()

        if res is not None:
            session = self.sessions[res[0]] = session_factory(res[1], res[2])
            session.skey = skey
            return session

    def save(self, hkey, session):
        """Cache the session and upsert it into the database."""
        SimpleSessionManager.save(self, hkey, session)

        conn = self._conn()
        cursor = conn.cursor()

        cursor.execute("INSERT OR REPLACE INTO session (hkey, skey, username, path) VALUES (?, ?, ?, ?)",
                       (hkey, session.skey, session.name, session.path))

        conn.commit()

    def delete(self, hkey):
        """Drop the session from both the cache and the database."""
        SimpleSessionManager.delete(self, hkey)

        conn = self._conn()
        cursor = conn.cursor()

        cursor.execute(self.fs("DELETE FROM session WHERE hkey=?"), (hkey,))
        conn.commit()
|
||||
|
||||
def get_session_manager(config):
    """Build the session manager selected by *config*.

    Preference order: a SQLite-backed manager when ``session_db_path`` is
    set, then a custom class named by ``session_manager`` (dotted
    ``module.Class`` path), then the in-memory fallback.  Raises TypeError
    if the custom class does not inherit from SimpleSessionManager.
    """
    if "session_db_path" in config and config["session_db_path"]:
        logger.info("Found session_db_path in config, using SqliteSessionManager for auth")
        return SqliteSessionManager(config['session_db_path'])

    if "session_manager" in config and config["session_manager"]:  # load from config
        logger.info("Found session_manager in config, using {} for persisting sessions".format(
            config['session_manager'])
        )
        import importlib
        module_name, class_name = config['session_manager'].rsplit('.', 1)
        module = importlib.import_module(module_name.strip())
        class_ = getattr(module, class_name.strip())

        # issubclass replaces the old inspect.getmro() membership test; the
        # error message is no longer a garbled multi-line triple-quoted string.
        if not issubclass(class_, SimpleSessionManager):
            raise TypeError('"session_manager" found in the conf file but it '
                            "doesn't inherit from SimpleSessionManager")
        return class_(config)

    logger.warning("Neither session_db_path nor session_manager set, "
                   "ankisyncd will lose sessions on application restart")
    return SimpleSessionManager()
|
||||
675
src/ankisyncd/sync_app.py
Normal file
675
src/ankisyncd/sync_app.py
Normal file
@@ -0,0 +1,675 @@
|
||||
# ankisyncd - A personal Anki sync server
|
||||
# Copyright (C) 2013 David Snopek
|
||||
#
|
||||
# This program is free software: you can redistribute it and/or modify
|
||||
# it under the terms of the GNU Affero General Public License as
|
||||
# published by the Free Software Foundation, either version 3 of the
|
||||
# License, or (at your option) any later version.
|
||||
#
|
||||
# This program is distributed in the hope that it will be useful,
|
||||
# but WITHOUT ANY WARRANTY; without even the implied warranty of
|
||||
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
|
||||
# GNU Affero General Public License for more details.
|
||||
#
|
||||
# You should have received a copy of the GNU Affero General Public License
|
||||
# along with this program. If not, see <http://www.gnu.org/licenses/>.
|
||||
|
||||
import gzip
|
||||
import hashlib
|
||||
import io
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
import random
|
||||
import re
|
||||
import string
|
||||
import sys
|
||||
import time
|
||||
import unicodedata
|
||||
import zipfile
|
||||
from configparser import ConfigParser
|
||||
from sqlite3 import dbapi2 as sqlite
|
||||
|
||||
from webob import Response
|
||||
from webob.dec import wsgify
|
||||
from webob.exc import *
|
||||
|
||||
import anki.db
|
||||
import anki.sync
|
||||
import anki.utils
|
||||
from anki.consts import SYNC_VER, SYNC_ZIP_SIZE, SYNC_ZIP_COUNT
|
||||
from anki.consts import REM_CARD, REM_NOTE
|
||||
|
||||
from ankisyncd.users import get_user_manager
|
||||
from ankisyncd.sessions import get_session_manager
|
||||
from ankisyncd.full_sync import get_full_sync_manager
|
||||
|
||||
logger = logging.getLogger("ankisyncd")
|
||||
|
||||
|
||||
class SyncCollectionHandler(anki.sync.Syncer):
    """Server-side handler for the collection sync protocol.

    Reuses anki.sync.Syncer but plays the server role: minUsn/maxUsn bracket
    the changes exchanged with the client, and usnLim()/removed() are
    overridden accordingly.
    """

    # RPC method names the SyncApp is allowed to dispatch to this handler.
    operations = ['meta', 'applyChanges', 'start', 'applyGraves', 'chunk', 'applyChunk', 'sanityCheck2', 'finish']

    def __init__(self, col):
        # So that 'server' (the 3rd argument) can't get set
        anki.sync.Syncer.__init__(self, col)

    @staticmethod
    def _old_client(cv):
        """Return True if the client version string *cv* ("client,version,platform")
        identifies a client too old to sync with."""
        if not cv:
            return False

        note = {"alpha": 0, "beta": 0, "rc": 0}
        client, version, platform = cv.split(',')

        # Split off a pre-release tag (e.g. "2.3alpha3") into the note dict.
        for name in note.keys():
            if name in version:
                vs = version.split(name)
                version = vs[0]
                note[name] = int(vs[-1])

        # convert the version string, ignoring non-numeric suffixes like in beta versions of Anki
        version_nosuffix = re.sub(r'[^0-9.].*$', '', version)
        version_int = [int(x) for x in version_nosuffix.split('.')]

        if client == 'ankidesktop':
            return version_int < [2, 0, 27]
        elif client == 'ankidroid':
            if version_int == [2, 3]:
                if note["alpha"]:
                    return note["alpha"] < 4
            else:
                return version_int < [2, 2, 3]
        else:  # unknown client, assume current version
            return False

    def meta(self, v=None, cv=None):
        """Protocol handshake: report collection state, or refuse the client."""
        if self._old_client(cv):
            return Response(status=501)  # client needs upgrade
        # NOTE(review): v defaults to None yet is compared numerically below --
        # assumes clients always send a sync version; confirm.
        if v > SYNC_VER:
            return {"cont": False, "msg": "Your client is using unsupported sync protocol ({}, supported version: {})".format(v, SYNC_VER)}
        if v < 9 and self.col.schedVer() >= 2:
            return {"cont": False, "msg": "Your client doesn't support the v{} scheduler.".format(self.col.schedVer())}

        # Make sure the media database is open!
        if self.col.media.db is None:
            self.col.media.connect()

        return {
            'scm': self.col.scm,
            'ts': anki.utils.intTime(),
            'mod': self.col.mod,
            'usn': self.col._usn,
            'musn': self.col.media.lastUsn(),
            'msg': '',
            'cont': True,
        }

    def usnLim(self):
        # Server side: everything at or above the client's last-seen usn.
        return "usn >= %d" % self.minUsn

    # ankidesktop >=2.1rc2 sends graves in applyGraves, but still expects
    # server-side deletions to be returned by start
    def start(self, minUsn, lnewer, graves={"cards": [], "notes": [], "decks": []}, offset=None):
        if offset is not None:
            raise NotImplementedError('You are using the experimental V2 scheduler, which is not supported by the server.')
        self.maxUsn = self.col._usn
        self.minUsn = minUsn
        # Inverted because 'lnewer' arrives from the client's point of view.
        self.lnewer = not lnewer
        lgraves = self.removed()
        self.remove(graves)
        return lgraves

    def applyGraves(self, chunk):
        self.remove(chunk)

    def applyChanges(self, changes):
        self.rchg = changes
        lchg = self.changes()
        # merge our side before returning
        self.mergeChanges(lchg, self.rchg)
        return lchg

    def sanityCheck2(self, client):
        """Compare the client's sanity-check counts against our own."""
        server = self.sanityCheck()
        if client != server:
            return dict(status="bad", c=client, s=server)
        return dict(status="ok")

    def finish(self, mod=None):
        # The client-supplied mod time is ignored; the server clock wins.
        return anki.sync.Syncer.finish(self, anki.utils.intTime(1000))

    # This function had to be put here in its entirety because Syncer.removed()
    # doesn't use self.usnLim() (which we override in this class) in queries.
    # "usn=-1" has been replaced with "usn >= ?", self.minUsn by hand.
    def removed(self):
        cards = []
        notes = []
        decks = []

        curs = self.col.db.execute(
            "select oid, type from graves where usn >= ?", self.minUsn)

        for oid, type in curs:
            if type == REM_CARD:
                cards.append(oid)
            elif type == REM_NOTE:
                notes.append(oid)
            else:
                decks.append(oid)

        return dict(cards=cards, notes=notes, decks=decks)

    def getModels(self):
        return [m for m in self.col.models.all() if m['usn'] >= self.minUsn]

    def getDecks(self):
        return [
            [g for g in self.col.decks.all() if g['usn'] >= self.minUsn],
            [g for g in self.col.decks.allConf() if g['usn'] >= self.minUsn]
        ]

    def getTags(self):
        return [t for t, usn in self.col.tags.allItems()
                if usn >= self.minUsn]
|
||||
|
||||
class SyncMediaHandler:
    """Serves the media-sync endpoints of the Anki sync protocol for one
    user's opened collection.  Each name in ``operations`` is both a method
    on this class and a URL suffix that SyncApp routes here."""

    # Endpoint names this handler serves; SyncApp dispatches on this list.
    operations = ['begin', 'mediaChanges', 'mediaSanity', 'uploadChanges', 'downloadFiles']

    def __init__(self, col):
        # col: the user's opened Anki collection; media state lives on col.media.
        self.col = col

    def begin(self, skey):
        """Start a media sync: echo the session key and report the server's
        current media USN so the client can decide what to fetch/push."""
        return {
            'data': {
                'sk': skey,
                'usn': self.col.media.lastUsn(),
            },
            'err': '',
        }

    def uploadChanges(self, data):
        """
        The zip file contains files the client hasn't synced with the server
        yet ('dirty'), and info on files it has deleted from its own media dir.
        """

        with zipfile.ZipFile(io.BytesIO(data), "r") as z:
            # Reject oversized payloads before processing any contents.
            self._check_zip_data(z)
            processed_count = self._adopt_media_changes_from_zip(z)

        # The client checks the returned count and the new server USN.
        return {
            'data': [processed_count, self.col.media.lastUsn()],
            'err': '',
        }

    @staticmethod
    def _check_zip_data(zip_file):
        """Raise ValueError when the uploaded zip exceeds the size limits."""
        max_zip_size = 100*1024*1024   # 100 MiB total (uncompressed sizes)
        max_meta_file_size = 100000    # cap on the "_meta" JSON member

        meta_file_size = zip_file.getinfo("_meta").file_size
        sum_file_sizes = sum(info.file_size for info in zip_file.infolist())

        if meta_file_size > max_meta_file_size:
            raise ValueError("Zip file's metadata file is larger than %s "
                             "Bytes." % max_meta_file_size)
        elif sum_file_sizes > max_zip_size:
            raise ValueError("Zip file contents are larger than %s Bytes." %
                             max_zip_size)

    def _adopt_media_changes_from_zip(self, zip_file):
        """
        Adds and removes files to/from the database and media directory
        according to the data in zip file zipData.
        """

        # Get meta info first.
        # "_meta" is a JSON list of [filename, ordinal] pairs; zip members are
        # named by their ordinal (index into this list), and an empty ordinal
        # marks a file the client deleted.
        meta = json.loads(zip_file.read("_meta").decode())

        # Remove media files that were removed on the client.
        media_to_remove = []
        for normname, ordinal in meta:
            if ordinal == '':
                media_to_remove.append(self._normalize_filename(normname))

        # Add media files that were added on the client.
        media_to_add = []
        usn = self.col.media.lastUsn()
        oldUsn = usn  # remembered for the post-condition check below
        for i in zip_file.infolist():
            if i.filename == "_meta":  # Ignore previously retrieved metadata.
                continue

            file_data = zip_file.read(i)
            csum = anki.utils.checksum(file_data)
            # Member names are ordinals; map back to the real filename.
            filename = self._normalize_filename(meta[int(i.filename)][0])
            file_path = os.path.join(self.col.media.dir(), filename)

            # Save file to media directory.
            with open(file_path, 'wb') as f:
                f.write(file_data)

            # Each adopted file advances the server-side media USN by one.
            usn += 1
            media_to_add.append((filename, usn, csum))

        # We count all files we are to remove, even if we don't have them in
        # our media directory and our db doesn't know about them.
        processed_count = len(media_to_remove) + len(media_to_add)

        assert len(meta) == processed_count  # sanity check

        if media_to_remove:
            self._remove_media_files(media_to_remove)

        if media_to_add:
            self.col.media.db.executemany(
                "INSERT OR REPLACE INTO media VALUES (?,?,?)", media_to_add)
            self.col.media.db.commit()

        assert self.col.media.lastUsn() == oldUsn + processed_count  # TODO: move to some unit test
        return processed_count

    @staticmethod
    def _normalize_filename(filename):
        """
        Performs unicode normalization for file names. Logic taken from Anki's
        MediaManager.addFilesFromZip().
        """

        # Normalize name for platform.
        if anki.utils.isMac:  # global
            filename = unicodedata.normalize("NFD", filename)
        else:
            filename = unicodedata.normalize("NFC", filename)

        return filename

    def _remove_media_files(self, filenames):
        """
        Marks all files in list filenames as deleted and removes them from the
        media directory.
        """
        logger.debug('Removing %d files from media dir.' % len(filenames))
        for filename in filenames:
            try:
                self.col.media.syncDelete(filename)
                self.col.media.db.commit()
            except OSError as err:
                # Best effort: log the failure and keep deleting the rest.
                logger.error("Error when removing file '%s' from media dir: "
                             "%s" % (filename, str(err)))

    def downloadFiles(self, files):
        """Bundle the requested media files into a zip and return its bytes.

        Members are stored under ordinal names; the "_meta" member maps each
        ordinal back to the real filename.  Packing stops once the running
        size exceeds SYNC_ZIP_SIZE or the count exceeds SYNC_ZIP_COUNT.
        """
        flist = {}
        cnt = 0
        sz = 0
        f = io.BytesIO()

        with zipfile.ZipFile(f, "w", compression=zipfile.ZIP_DEFLATED) as z:
            for fname in files:
                z.write(os.path.join(self.col.media.dir(), fname), str(cnt))
                flist[str(cnt)] = fname
                sz += os.path.getsize(os.path.join(self.col.media.dir(), fname))
                if sz > SYNC_ZIP_SIZE or cnt > SYNC_ZIP_COUNT:
                    break
                cnt += 1

            z.writestr("_meta", json.dumps(flist))

        return f.getvalue()

    def mediaChanges(self, lastUsn):
        """Return the newest (server_lastUsn - lastUsn) media rows as
        [fname, usn, csum] lists, ordered oldest-first."""
        result = []
        server_lastUsn = self.col.media.lastUsn()
        fname = csum = None

        if lastUsn < server_lastUsn or lastUsn == 0:
            for fname,usn,csum, in self.col.media.db.execute("select fname,usn,csum from media order by usn desc limit ?", server_lastUsn - lastUsn):
                result.append([fname, usn, csum])

        # anki assumes server_lastUsn == result[-1][1]
        # ref: anki/sync.py:720 (commit cca3fcb2418880d0430a5c5c2e6b81ba260065b7)
        result.reverse()

        return {'data': result, 'err': ''}

    def mediaSanity(self, local=None):
        """Compare the client's reported media file count against ours."""
        if self.col.media.mediaCount() == local:
            result = "OK"
        else:
            result = "FAILED"

        return {'data': result, 'err': ''}
|
||||
|
||||
class SyncUserSession:
    """Per-user sync session state.

    Tracks the session keys, client version info and the lazily created
    collection/media handlers, and knows where the user's collection file
    lives on disk.
    """

    def __init__(self, name, path, collection_manager, setup_new_collection=None):
        self.skey = self._generate_session_key()
        self.name = name
        self.path = path
        self.collection_manager = collection_manager
        self.setup_new_collection = setup_new_collection
        self.version = None
        self.client_version = None
        self.created = time.time()
        # Handlers are created on first use by get_handler_for_operation().
        self.collection_handler = None
        self.media_handler = None

        # make sure the user path exists
        if not os.path.exists(path):
            os.mkdir(path)

    def _generate_session_key(self):
        """Return a short random key identifying this session."""
        return anki.utils.checksum(str(random.random()))[:8]

    def get_collection_path(self):
        """Absolute path of this user's collection file."""
        return os.path.realpath(os.path.join(self.path, 'collection.anki2'))

    def get_thread(self):
        """Return the collection thread/wrapper serving this user's collection."""
        return self.collection_manager.get_collection(self.get_collection_path(), self.setup_new_collection)

    def get_handler_for_operation(self, operation, col):
        """Return (creating on demand) the handler that serves *operation*."""
        if operation in SyncCollectionHandler.operations:
            attr = 'collection_handler'
            handler_class = SyncCollectionHandler
        elif operation in SyncMediaHandler.operations:
            attr = 'media_handler'
            handler_class = SyncMediaHandler
        else:
            raise Exception("no handler for {}".format(operation))

        handler = getattr(self, attr)
        if handler is None:
            handler = handler_class(col)
            setattr(self, attr, handler)
        # The col object may actually be new now! This happens when we close a collection
        # for inactivity and then later re-open it (creating a new Collection object).
        handler.col = col
        return handler
|
||||
|
||||
class SyncApp:
    """WSGI application implementing the Anki sync protocol.

    Requests under ``base_url`` are collection-sync operations and full
    uploads/downloads; requests under ``base_media_url`` are media-sync
    operations.  Each operation is dispatched to the session's handler and
    executed on that user's dedicated collection thread.
    """

    # Every URL suffix this app will serve; anything else is a 404.
    valid_urls = SyncCollectionHandler.operations + SyncMediaHandler.operations + ['hostKey', 'upload', 'download']

    def __init__(self, config):
        from ankisyncd.thread import get_collection_manager

        self.data_root = os.path.abspath(config['data_root'])
        self.base_url = config['base_url']
        self.base_media_url = config['base_media_url']
        # Optional callable used to initialize brand-new collections.
        self.setup_new_collection = None

        # url name -> hook callable; run on the collection thread before
        # (prehooks) or after (posthooks) the matching operation.
        self.prehooks = {}
        self.posthooks = {}

        self.user_manager = get_user_manager(config)
        self.session_manager = get_session_manager(config)
        self.full_sync_manager = get_full_sync_manager(config)
        self.collection_manager = get_collection_manager(config)

        # make sure the base_url has a trailing slash
        if not self.base_url.endswith('/'):
            self.base_url += '/'
        if not self.base_media_url.endswith('/'):
            self.base_media_url += '/'

    # backwards compat
    # The hook_* properties below are aliases onto prehooks/posthooks entries.
    @property
    def hook_pre_sync(self):
        return self.prehooks.get("start")

    @hook_pre_sync.setter
    def hook_pre_sync(self, value):
        self.prehooks['start'] = value

    @property
    def hook_post_sync(self):
        return self.posthooks.get("finish")

    @hook_post_sync.setter
    def hook_post_sync(self, value):
        self.posthooks['finish'] = value

    @property
    def hook_upload(self):
        return self.prehooks.get("upload")

    @hook_upload.setter
    def hook_upload(self, value):
        self.prehooks['upload'] = value

    @property
    def hook_download(self):
        return self.posthooks.get("download")

    @hook_download.setter
    def hook_download(self, value):
        self.posthooks['download'] = value

    def generateHostKey(self, username):
        """Generates a new host key to be used by the given username to identify their session.
        This values is random."""

        import hashlib, time, random, string
        chars = string.ascii_letters + string.digits
        # md5 over username:timestamp:random — used only as an opaque ID.
        val = ':'.join([username, str(int(time.time())), ''.join(random.choice(chars) for x in range(8))]).encode()
        return hashlib.md5(val).hexdigest()

    def create_session(self, username, user_path):
        """Build a fresh SyncUserSession rooted at the user's data directory."""
        return SyncUserSession(username, user_path, self.collection_manager, self.setup_new_collection)

    def _decode_data(self, data, compression=0):
        """Gunzip (when compression is set) then JSON-decode the POST body.

        Bodies that are not valid JSON text (e.g. raw zip/db payloads) are
        wrapped as {'data': <raw bytes>} instead.
        """
        if compression:
            with gzip.GzipFile(mode="rb", fileobj=io.BytesIO(data)) as gz:
                data = gz.read()

        try:
            data = json.loads(data.decode())
        except (ValueError, UnicodeDecodeError):
            data = {'data': data}

        return data

    def operation_hostKey(self, username, password):
        """Authenticate and open a session; returns {'key': hkey} or None."""
        if not self.user_manager.authenticate(username, password):
            return

        dirname = self.user_manager.userdir(username)
        if dirname is None:
            return

        hkey = self.generateHostKey(username)
        user_path = os.path.join(self.data_root, dirname)
        session = self.create_session(username, user_path)
        self.session_manager.save(hkey, session)

        return {'key': hkey}

    def operation_upload(self, col, data, session):
        # Verify integrity of the received database file before replacing our
        # existing db.

        return self.full_sync_manager.upload(col, data, session)

    def operation_download(self, col, session):
        # returns user data (not media) as a sqlite3 database for replacing their
        # local copy in Anki
        return self.full_sync_manager.download(col, session)

    @wsgify
    def __call__(self, req):
        """Route one sync HTTP request (see class docstring)."""
        # Get and verify the session
        try:
            hkey = req.POST['k']
        except KeyError:
            hkey = None

        session = self.session_manager.load(hkey, self.create_session)

        if session is None:
            # Media-sync requests identify themselves with the short-lived
            # session key ('sk') instead of the host key.
            try:
                skey = req.POST['sk']
                session = self.session_manager.load_from_skey(skey, self.create_session)
            except KeyError:
                skey = None

        try:
            compression = int(req.POST['c'])
        except KeyError:
            compression = 0

        try:
            data = req.POST['data'].file.read()
            data = self._decode_data(data, compression)
        except KeyError:
            data = {}

        if req.path.startswith(self.base_url):
            url = req.path[len(self.base_url):]
            if url not in self.valid_urls:
                raise HTTPNotFound()

            # hostKey is the only operation allowed without a session.
            if url == 'hostKey':
                result = self.operation_hostKey(data.get("u"), data.get("p"))
                if result:
                    return json.dumps(result)
                else:
                    # TODO: do I have to pass 'null' for the client to receive None?
                    raise HTTPForbidden('null')

            if session is None:
                raise HTTPForbidden()

            if url in SyncCollectionHandler.operations + SyncMediaHandler.operations:
                # 'meta' passes the SYNC_VER but it isn't used in the handler
                if url == 'meta':
                    # NOTE(review): `== None` — `is None` is the usual idiom.
                    if session.skey == None and 's' in req.POST:
                        session.skey = req.POST['s']
                    if 'v' in data:
                        session.version = data['v']
                    if 'cv' in data:
                        session.client_version = data['cv']

                    # Persist the updated session, then re-load it so the rest
                    # of the request sees the stored copy.
                    self.session_manager.save(hkey, session)
                    session = self.session_manager.load(hkey, self.create_session)

                thread = session.get_thread()

                if url in self.prehooks:
                    thread.execute(self.prehooks[url], [session])

                result = self._execute_handler_method_in_thread(url, data, session)

                # If it's a complex data type, we convert it to JSON
                if type(result) not in (str, bytes, Response):
                    result = json.dumps(result)

                if url in self.posthooks:
                    thread.execute(self.posthooks[url], [session])

                return result

            elif url == 'upload':
                # Full client->server collection replacement.
                thread = session.get_thread()
                if url in self.prehooks:
                    thread.execute(self.prehooks[url], [session])
                result = thread.execute(self.operation_upload, [data['data'], session])
                if url in self.posthooks:
                    thread.execute(self.posthooks[url], [session])
                return result

            elif url == 'download':
                # Full server->client collection replacement.
                thread = session.get_thread()
                if url in self.prehooks:
                    thread.execute(self.prehooks[url], [session])
                result = thread.execute(self.operation_download, [session])
                if url in self.posthooks:
                    thread.execute(self.posthooks[url], [session])
                return result

            # This was one of our operations but it didn't get handled... Oops!
            raise HTTPInternalServerError()

        # media sync
        elif req.path.startswith(self.base_media_url):
            if session is None:
                raise HTTPForbidden()

            url = req.path[len(self.base_media_url):]

            if url not in self.valid_urls:
                raise HTTPNotFound()

            if url == "begin":
                data['skey'] = session.skey

            result = self._execute_handler_method_in_thread(url, data, session)

            # If it's a complex data type, we convert it to JSON
            if type(result) not in (str, bytes):
                result = json.dumps(result)

            return result

        # Anything outside both base URLs gets a simple banner.
        return "Anki Sync Server"

    @staticmethod
    def _execute_handler_method_in_thread(method_name, keyword_args, session):
        """
        Gets and runs the handler method specified by method_name inside the
        thread for session. The handler method will access the collection as
        self.col.
        """

        def run_func(col, **keyword_args):
            # Retrieve the correct handler method.
            handler = session.get_handler_for_operation(method_name, col)
            handler_method = getattr(handler, method_name)

            res = handler_method(**keyword_args)

            # Persist any collection changes the handler made.
            col.save()
            return res

        run_func.__name__ = method_name  # More useful debugging messages.

        # Send the closure to the thread for execution.
        thread = session.get_thread()
        result = thread.execute(run_func, kw=keyword_args)

        return result
|
||||
|
||||
|
||||
def make_app(global_conf, **local_conf):
    """WSGI/Paste app factory: build a SyncApp from keyword configuration."""
    app = SyncApp(**local_conf)
    return app
|
||||
|
||||
def main():
    """Entry point: load configuration and serve SyncApp over wsgiref HTTP."""
    logging.basicConfig(level=logging.INFO, format="[%(asctime)s]:%(levelname)s:%(name)s:%(message)s")
    import ankisyncd
    logger.info("ankisyncd {} ({})".format(ankisyncd._get_version(), ankisyncd._homepage))
    from wsgiref.simple_server import make_server, WSGIRequestHandler
    from ankisyncd.thread import shutdown
    import ankisyncd.config

    # Route wsgiref's per-request logging through our logger instead of stderr.
    class RequestHandler(WSGIRequestHandler):
        logger = logging.getLogger("ankisyncd.http")

        def log_error(self, format, *args):
            self.logger.error("%s %s", self.address_string(), format%args)

        def log_message(self, format, *args):
            self.logger.info("%s %s", self.address_string(), format%args)

    if len(sys.argv) > 1:
        # backwards compat
        # an explicit config file path may be passed as the first argument
        config = ankisyncd.config.load(sys.argv[1])
    else:
        config = ankisyncd.config.load()

    ankiserver = SyncApp(config)
    httpd = make_server(config['host'], int(config['port']), ankiserver, handler_class=RequestHandler)

    try:
        logger.info("Serving HTTP on {} port {}...".format(*httpd.server_address))
        httpd.serve_forever()
    except KeyboardInterrupt:
        logger.info("Exiting...")
    finally:
        # Always stop the collection threads, even on unexpected errors.
        shutdown()
|
||||
|
||||
# Allow running this module directly as a standalone server.
if __name__ == '__main__': main()
|
||||
218
src/ankisyncd/thread.py
Normal file
218
src/ankisyncd/thread.py
Normal file
@@ -0,0 +1,218 @@
|
||||
from ankisyncd.collection import CollectionManager, get_collection_wrapper
|
||||
|
||||
from threading import Thread
|
||||
from queue import Queue
|
||||
|
||||
import time, logging
|
||||
|
||||
def short_repr(obj, logger=logging.getLogger(), maxlen=80):
    """Like repr, but shortens strings and bytestrings if logger's logging level
    is above DEBUG. Currently shallow and very limited, only implemented for
    dicts and lists.

    :param obj: object to render; dict and list values that are (byte)strings
                are truncated to *maxlen*; any other object is repr'd as-is.
    :param logger: logger whose effective level decides whether to shorten.
    :param maxlen: maximum number of characters/bytes kept per value.
    """
    if logger.isEnabledFor(logging.DEBUG):
        # Debug logging wants the full picture — never truncate.
        return repr(obj)

    def shorten(s):
        if isinstance(s, (bytes, str)) and len(s) > maxlen:
            return s[:maxlen] + ("..." if isinstance(s, str) else b"...")
        else:
            return s

    # Work on copies so the caller's object is never mutated.  The previous
    # version called obj.copy() unconditionally, which raised AttributeError
    # for types without a copy() method (tuples, ints, ...).
    if isinstance(obj, dict):
        o = obj.copy()
        for k in o:
            o[k] = shorten(o[k])
    elif isinstance(obj, list):
        o = [shorten(v) for v in obj]
    else:
        o = obj

    return repr(o)
|
||||
|
||||
class ThreadingCollectionWrapper:
    """Provides the same interface as CollectionWrapper, but it creates a new Thread to
    interact with the collection."""

    def __init__(self, config, path, setup_new_collection=None):
        self.path = path
        self.wrapper = get_collection_wrapper(config, path, setup_new_collection)
        self.logger = logging.getLogger("ankisyncd." + str(self))

        # Job queue consumed by the dedicated worker thread (see _run).
        self._queue = Queue()
        self._thread = None
        self._running = False
        # Timestamp of the most recently started job; read by the manager's
        # inactivity monitor.
        self.last_timestamp = time.time()

        self.start()

    def __str__(self):
        return "CollectionThread[{}]".format(self.wrapper.username)

    @property
    def running(self):
        # True while the worker thread's main loop is (or should be) active.
        return self._running

    def qempty(self):
        """Return True when no jobs are waiting on the work queue."""
        return self._queue.empty()

    def current(self):
        """Return True when called from this wrapper's own worker thread."""
        from threading import current_thread
        return current_thread() == self._thread

    def execute(self, func, args=[], kw={}, waitForReturn=True):
        """ Executes a given function on this thread with the *args and **kw.

        If 'waitForReturn' is True, then it will block until the function has
        executed and return its return value. If False, it will return None
        immediately and the function will be executed sometime later.
        """
        # NOTE(review): mutable default arguments ([] / {}) are shared across
        # calls; safe only while nothing mutates them.

        if waitForReturn:
            return_queue = Queue()
        else:
            return_queue = None

        self._queue.put((func, args, kw, return_queue))

        if return_queue is not None:
            ret = return_queue.get(True)
            # Exceptions raised on the worker thread are shipped back through
            # the queue and re-raised here on the calling thread.
            if isinstance(ret, Exception):
                raise ret
            return ret

    def _run(self):
        """Worker-thread main loop: pop jobs off the queue and run them
        against the wrapped collection until stopped."""
        self.logger.info("Starting...")

        try:
            while self._running:
                # Block until the next job arrives.
                func, args, kw, return_queue = self._queue.get(True)

                if hasattr(func, '__name__'):
                    func_name = func.__name__
                else:
                    func_name = func.__class__.__name__

                self.logger.info("Running %s(*%s, **%s)", func_name, short_repr(args, self.logger), short_repr(kw, self.logger))
                self.last_timestamp = time.time()

                try:
                    ret = self.wrapper.execute(func, args, kw, return_queue)
                except Exception as e:
                    self.logger.error("Unable to %s(*%s, **%s): %s",
                        func_name, repr(args), repr(kw), e, exc_info=True)
                    # we return the Exception which will be raise'd on the other end
                    ret = e

                if return_queue is not None:
                    return_queue.put(ret)
        except Exception as e:
            self.logger.error("Thread crashed! Exception: %s", e, exc_info=True)
        finally:
            self.wrapper.close()
            # clean out old thread object
            self._thread = None
            # in case we got here via an exception
            self._running = False

        self.logger.info("Stopped!")

    def start(self):
        """Start the worker thread if it isn't already running."""
        if not self._running:
            self._running = True
            assert self._thread is None
            self._thread = Thread(target=self._run)
            self._thread.start()

    def stop(self):
        """Ask the worker thread to stop after the currently queued jobs."""
        def _stop(col):
            self._running = False
        self.execute(_stop, waitForReturn=False)

    def stop_and_wait(self):
        """ Tell the thread to stop and wait for it to happen. """
        self.stop()
        if self._thread is not None:
            self._thread.join()

    #
    # Mimic the CollectionWrapper interface
    #

    def open(self):
        """Non-op. The collection will be opened on demand."""
        pass

    def close(self):
        """Closes the underlying collection without stopping the thread."""

        def _close(col):
            self.wrapper.close()
        self.execute(_close, waitForReturn=False)

    def opened(self):
        """Return True when the underlying collection is currently open."""
        return self.wrapper.opened()
|
||||
|
||||
class ThreadingCollectionManager(CollectionManager):
    """Manages a set of ThreadingCollectionWrapper objects.

    Also runs a daemon monitor thread that closes collections left idle for
    longer than ``monitor_inactivity`` seconds.
    """

    collection_wrapper = ThreadingCollectionWrapper

    def __init__(self, config):
        super(ThreadingCollectionManager, self).__init__(config)

        self.monitor_frequency = 15    # seconds between inactivity scans
        self.monitor_inactivity = 90   # idle seconds before a collection is closed
        self.logger = logging.getLogger("ankisyncd.ThreadingCollectionManager")

        monitor = Thread(target=self._monitor_run)
        monitor.daemon = True
        monitor.start()
        self._monitor_thread = monitor

    # TODO: we should raise some error if a collection is started on a manager that has already been shutdown!
    # or maybe we could support being restarted?

    # TODO: it would be awesome to have a safe way to stop inactive threads completely!
    # TODO: we need a way to inform other code that the collection has been closed
    def _monitor_run(self):
        """ Monitors threads for inactivity and closes the collection on them
        (leaves the thread itself running -- hopefully waiting peacefully with only a
        small memory footprint!) """
        while True:
            cur = time.time()
            # Iterate over a snapshot: request threads add/remove entries in
            # self.collections concurrently, and mutating a dict while its
            # live items() view is iterated raises RuntimeError.  (shutdown()
            # below already snapshots for the same reason.)
            for path, thread in list(self.collections.items()):
                if thread.running and thread.wrapper.opened() and thread.qempty() and cur - thread.last_timestamp >= self.monitor_inactivity:
                    self.logger.info("Monitor is closing collection on inactive %s", thread)
                    thread.close()
            time.sleep(self.monitor_frequency)

    def shutdown(self):
        """Stop every collection thread, then defer to the base class."""
        # TODO: stop the monitor thread!

        # stop all the threads
        for path, col in list(self.collections.items()):
            del self.collections[path]
            col.stop()

        # let the parent do whatever else it might want to do...
        super(ThreadingCollectionManager, self).shutdown()
|
||||
|
||||
#
|
||||
# For working with the global ThreadingCollectionManager:
|
||||
#
|
||||
|
||||
collection_manager = None  # lazily-created, process-wide singleton

def get_collection_manager(config):
    """Return the global ThreadingCollectionManager for this process."""
    global collection_manager
    manager = collection_manager
    if manager is None:
        manager = ThreadingCollectionManager(config)
        collection_manager = manager
    return manager
|
||||
|
||||
def shutdown():
    """If the global ThreadingCollectionManager exists, shut it down."""
    global collection_manager
    if collection_manager is None:
        return
    collection_manager.shutdown()
    collection_manager = None
|
||||
|
||||
217
src/ankisyncd/users.py
Normal file
217
src/ankisyncd/users.py
Normal file
@@ -0,0 +1,217 @@
|
||||
# -*- coding: utf-8 -*-
|
||||
import binascii
|
||||
import hashlib
|
||||
import logging
|
||||
import os
|
||||
import sqlite3 as sqlite
|
||||
|
||||
logger = logging.getLogger("ankisyncd.users")
|
||||
|
||||
|
||||
class SimpleUserManager:
    """A simple user manager that always allows any user."""

    def __init__(self, collection_path=''):
        self.collection_path = collection_path

    def authenticate(self, username, password):
        """
        Returns True if this username is allowed to connect with this password.
        False otherwise. Override this to change how users are authenticated.
        """
        # Permissive by design: every login attempt succeeds.
        return True

    def userdir(self, username):
        """
        Returns the directory name for the given user. By default, this is just
        the username. Override this to adjust the mapping between users and
        their directory.
        """
        return username

    def _create_user_dir(self, username):
        """Create the user's collection directory if it doesn't exist yet."""
        user_dir_path = os.path.join(self.collection_path, username)
        if os.path.isdir(user_dir_path):
            return
        logger.info("Creating collection directory for user '{}' at {}"
                    .format(username, user_dir_path))
        os.makedirs(user_dir_path)
|
||||
|
||||
|
||||
class SqliteUserManager(SimpleUserManager):
    """Authenticates users against a SQLite database."""

    def __init__(self, auth_db_path, collection_path=None):
        SimpleUserManager.__init__(self, collection_path)

        self.auth_db_path = os.path.realpath(auth_db_path)
        self._ensure_schema_up_to_date()

    def _ensure_schema_up_to_date(self):
        """Refuse to run against the legacy schema (old `user` column)."""
        if not self.auth_db_exists():
            # No DB yet — nothing to migrate.
            return True

        conn = self._conn()
        cursor = conn.cursor()
        # Probe sqlite_master for the outdated column definition.
        cursor.execute("SELECT * FROM sqlite_master "
                       "WHERE sql LIKE '%user VARCHAR PRIMARY KEY%' "
                       "AND tbl_name = 'auth'")
        res = cursor.fetchone()
        conn.close()
        if res is not None:
            raise Exception("Outdated database schema, run utils/migrate_user_tables.py")

    # Default to using sqlite3 but overridable for sub-classes using other
    # DB API 2 driver variants
    def auth_db_exists(self):
        return os.path.isfile(self.auth_db_path)

    # Default to using sqlite3 but overridable for sub-classes using other
    # DB API 2 driver variants
    def _conn(self):
        return sqlite.connect(self.auth_db_path)

    # Default to using sqlite3 syntax but overridable for sub-classes using other
    # DB API 2 driver variants
    @staticmethod
    def fs(sql):
        # Hook for dialect-specific SQL rewriting; identity for sqlite3.
        return sql

    def user_list(self):
        """Return every username present in the auth table."""
        if not self.auth_db_exists():
            raise ValueError("Auth DB {} doesn't exist".format(self.auth_db_path))

        conn = self._conn()
        cursor = conn.cursor()
        cursor.execute(self.fs("SELECT username FROM auth"))
        rows = cursor.fetchall()
        conn.commit()
        conn.close()

        return [row[0] for row in rows]

    def user_exists(self, username):
        """Return True if *username* has an entry in the auth table."""
        users = self.user_list()
        return username in users

    def del_user(self, username):
        """Delete the user's auth entry."""
        # Warning, this doesn't remove the user directory or clean it
        if not self.auth_db_exists():
            raise ValueError("Auth DB {} doesn't exist".format(self.auth_db_path))

        conn = self._conn()
        cursor = conn.cursor()
        logger.info("Removing user '{}' from auth db".format(username))
        cursor.execute(self.fs("DELETE FROM auth WHERE username=?"), (username,))
        conn.commit()
        conn.close()

    def add_user(self, username, password):
        """Register a new user and create their collection directory."""
        self._add_user_to_auth_db(username, password)
        self._create_user_dir(username)

    def add_users(self, users_data):
        """Register many users; *users_data* yields (username, password) pairs."""
        for username, password in users_data:
            self.add_user(username, password)

    def _add_user_to_auth_db(self, username, password):
        """Insert a salted password hash for *username*, creating the DB if needed."""
        if not self.auth_db_exists():
            self.create_auth_db()

        pass_hash = self._create_pass_hash(username, password)

        conn = self._conn()
        cursor = conn.cursor()
        logger.info("Adding user '{}' to auth db.".format(username))
        cursor.execute(self.fs("INSERT INTO auth VALUES (?, ?)"),
                       (username, pass_hash))
        conn.commit()
        conn.close()

    def set_password_for_user(self, username, new_password):
        """Replace the stored hash for an existing user."""
        if not self.auth_db_exists():
            raise ValueError("Auth DB {} doesn't exist".format(self.auth_db_path))
        elif not self.user_exists(username):
            raise ValueError("User {} doesn't exist".format(username))

        hash = self._create_pass_hash(username, new_password)

        conn = self._conn()
        cursor = conn.cursor()
        cursor.execute(self.fs("UPDATE auth SET hash=? WHERE username=?"), (hash, username))
        conn.commit()
        conn.close()

        logger.info("Changed password for user {}".format(username))

    def authenticate(self, username, password):
        """Returns True if this username is allowed to connect with this password. False otherwise."""

        conn = self._conn()
        cursor = conn.cursor()
        param = (username,)
        cursor.execute(self.fs("SELECT hash FROM auth WHERE username=?"), param)
        db_hash = cursor.fetchone()
        conn.close()

        if db_hash is None:
            logger.info("Authentication failed for nonexistent user {}."
                        .format(username))
            return False

        expected_value = str(db_hash[0])
        # The stored value is sha256-hexdigest + 16-char hex salt suffix
        # (see _create_pass_hash); recover the salt from the suffix.
        salt = self._extract_salt(expected_value)

        # Recompute sha256(username + password + salt) and compare.
        hashobj = hashlib.sha256()
        hashobj.update((username + password + salt).encode())
        actual_value = hashobj.hexdigest() + salt

        if actual_value == expected_value:
            logger.info("Authentication succeeded for user {}".format(username))
            return True
        else:
            logger.info("Authentication failed for user {}".format(username))
            return False

    @staticmethod
    def _extract_salt(hash):
        # The salt is the 16-character hex suffix appended by _create_pass_hash.
        return hash[-16:]

    @staticmethod
    def _create_pass_hash(username, password):
        """Return sha256(username + password + salt) hex digest with the hex
        salt appended, so authenticate() can recover the salt later."""
        salt = binascii.b2a_hex(os.urandom(8))
        pass_hash = (hashlib.sha256((username + password).encode() + salt).hexdigest() +
                     salt.decode())
        return pass_hash

    def create_auth_db(self):
        """Create an empty auth table at auth_db_path."""
        conn = self._conn()
        cursor = conn.cursor()
        logger.info("Creating auth db at {}."
                    .format(self.auth_db_path))
        cursor.execute(self.fs("""CREATE TABLE IF NOT EXISTS auth
                                  (username VARCHAR PRIMARY KEY, hash VARCHAR)"""))
        conn.commit()
        conn.close()
|
||||
|
||||
|
||||
def get_user_manager(config):
    """Factory: choose a user manager based on the server configuration.

    Precedence: a configured ``auth_db_path`` wins, then a dotted
    ``user_manager`` class path, and finally the permissive
    SimpleUserManager fallback (accepts any password).

    :param config: dict-like server configuration.
    :raises TypeError: if the configured class doesn't subclass SimpleUserManager.
    """
    if "auth_db_path" in config and config["auth_db_path"]:
        logger.info("Found auth_db_path in config, using SqliteUserManager for auth")
        return SqliteUserManager(config['auth_db_path'], config['data_root'])
    elif "user_manager" in config and config["user_manager"]:  # load from config
        logger.info("Found user_manager in config, using {} for auth".format(config['user_manager']))
        import importlib
        module_name, class_name = config['user_manager'].rsplit('.', 1)
        module = importlib.import_module(module_name.strip())
        class_ = getattr(module, class_name.strip())

        # issubclass is the idiomatic equivalent of checking inspect.getmro();
        # also fixes the doubled quote ("doesn''t") in the old error message.
        if not issubclass(class_, SimpleUserManager):
            raise TypeError('"user_manager" found in the conf file but it '
                            "doesn't inherit from SimpleUserManager")
        return class_(config)
    else:
        logger.warning("neither auth_db_path nor user_manager set, ankisyncd will accept any password")
        return SimpleUserManager()
|
||||
70
src/utils/migrate_user_tables.py
Executable file
70
src/utils/migrate_user_tables.py
Executable file
@@ -0,0 +1,70 @@
|
||||
#!/usr/bin/env python3
|
||||
# -*- coding: utf-8 -*-
|
||||
"""
|
||||
This script updates the auth and session sqlite3 databases to use the
|
||||
more compatible `username` column instead of `user`, which is a reserved
|
||||
word in many other SQL dialects.
|
||||
"""
|
||||
import os
|
||||
import sys
|
||||
path = os.path.realpath(os.path.abspath(os.path.join(__file__, '../')))
|
||||
sys.path.insert(0, os.path.dirname(path))
|
||||
|
||||
import sqlite3
|
||||
import ankisyncd.config
|
||||
conf = ankisyncd.config.load()
|
||||
|
||||
|
||||
def main():
    """Migrate the auth and session sqlite DBs from the legacy `user` column
    to `username` (see the module docstring). Each DB is skipped with a
    message when it is missing or already migrated."""

    if os.path.isfile(conf["auth_db_path"]):
        conn = sqlite3.connect(conf["auth_db_path"])

        cursor = conn.cursor()
        # Detect the legacy schema by probing sqlite_master for the old
        # column definition.
        cursor.execute("SELECT * FROM sqlite_master "
                       "WHERE sql LIKE '%user VARCHAR PRIMARY KEY%' "
                       "AND tbl_name = 'auth'")
        res = cursor.fetchone()

        if res is not None:
            # Rebuild the table under the new schema and copy the rows across:
            # rename old, create new, copy, drop old.
            cursor.execute("ALTER TABLE auth RENAME TO auth_old")
            cursor.execute("CREATE TABLE auth (username VARCHAR PRIMARY KEY, hash VARCHAR)")
            cursor.execute("INSERT INTO auth (username, hash) SELECT user, hash FROM auth_old")
            cursor.execute("DROP TABLE auth_old")
            conn.commit()
            print("Successfully updated table 'auth'")
        else:
            print("No outdated 'auth' table found.")

        conn.close()
    else:
        print("No auth DB found at the configured 'auth_db_path' path.")

    if os.path.isfile(conf["session_db_path"]):
        conn = sqlite3.connect(conf["session_db_path"])

        cursor = conn.cursor()
        # Same probe for the session table's legacy column.
        cursor.execute("SELECT * FROM sqlite_master "
                       "WHERE sql LIKE '%user VARCHAR%' "
                       "AND tbl_name = 'session'")
        res = cursor.fetchone()

        if res is not None:
            cursor.execute("ALTER TABLE session RENAME TO session_old")
            cursor.execute("CREATE TABLE session (hkey VARCHAR PRIMARY KEY, skey VARCHAR, "
                           "username VARCHAR, path VARCHAR)")
            cursor.execute("INSERT INTO session (hkey, skey, username, path) "
                           "SELECT hkey, skey, user, path FROM session_old")
            cursor.execute("DROP TABLE session_old")
            conn.commit()
            print("Successfully updated table 'session'")
        else:
            print("No outdated 'session' table found.")

        conn.close()
    else:
        print("No session DB found at the configured 'session_db_path' path.")


if __name__ == "__main__":
    main()
|
||||
Reference in New Issue
Block a user