Move packages into src folder

This commit is contained in:
Vikash Kothary
2020-07-30 19:19:45 +01:00
parent 125f7bb1b5
commit 09da3d7337
15 changed files with 0 additions and 0 deletions

76
src/addon/__init__.py Normal file
View File

@@ -0,0 +1,76 @@
from PyQt5.Qt import Qt, QCheckBox, QLabel, QHBoxLayout, QLineEdit
from aqt.forms import preferences
from anki.hooks import wrap, addHook
import aqt
import anki.consts
import anki.sync
DEFAULT_ADDR = "http://localhost:27701/"
config = aqt.mw.addonManager.getConfig(__name__)
# TODO: force the user to log out before changing any of the settings
def addui(self, _):
self = self.form
parent_w = self.tab_2
parent_l = self.vboxlayout
self.useCustomServer = QCheckBox(parent_w)
self.useCustomServer.setText("Use custom sync server")
parent_l.addWidget(self.useCustomServer)
cshl = QHBoxLayout()
parent_l.addLayout(cshl)
self.serverAddrLabel = QLabel(parent_w)
self.serverAddrLabel.setText("Server address")
cshl.addWidget(self.serverAddrLabel)
self.customServerAddr = QLineEdit(parent_w)
self.customServerAddr.setPlaceholderText(DEFAULT_ADDR)
cshl.addWidget(self.customServerAddr)
pconfig = getprofileconfig()
if pconfig["enabled"]:
self.useCustomServer.setCheckState(Qt.Checked)
if pconfig["addr"]:
self.customServerAddr.setText(pconfig["addr"])
self.customServerAddr.textChanged.connect(lambda text: updateserver(self, text))
def onchecked(state):
pconfig["enabled"] = state == Qt.Checked
updateui(self, state)
updateserver(self, self.customServerAddr.text())
self.useCustomServer.stateChanged.connect(onchecked)
updateui(self, self.useCustomServer.checkState())
def updateserver(self, text):
pconfig = getprofileconfig()
if pconfig['enabled']:
addr = text or self.customServerAddr.placeholderText()
pconfig['addr'] = addr
setserver()
aqt.mw.addonManager.writeConfig(__name__, config)
def updateui(self, state):
self.serverAddrLabel.setEnabled(state == Qt.Checked)
self.customServerAddr.setEnabled(state == Qt.Checked)
def setserver():
pconfig = getprofileconfig()
if pconfig['enabled']:
aqt.mw.pm.profile['hostNum'] = None
anki.sync.SYNC_BASE = "%s" + pconfig['addr']
else:
anki.sync.SYNC_BASE = anki.consts.SYNC_BASE
def getprofileconfig():
if aqt.mw.pm.name not in config["profiles"]:
# inherit global settings if present (used in earlier versions of the addon)
config["profiles"][aqt.mw.pm.name] = {
"enabled": config.get("enabled", False),
"addr": config.get("addr", DEFAULT_ADDR),
}
aqt.mw.addonManager.writeConfig(__name__, config)
return config["profiles"][aqt.mw.pm.name]
addHook("profileLoaded", setserver)
aqt.preferences.Preferences.__init__ = wrap(aqt.preferences.Preferences.__init__, addui, "after")

1
src/addon/config.json Normal file
View File

@@ -0,0 +1 @@
{"profiles":{}}

82
src/ankisyncctl.py Executable file
View File

@@ -0,0 +1,82 @@
#!/usr/bin/env python
import os
import sys
import getpass
import ankisyncd.config
from ankisyncd.users import get_user_manager
config = ankisyncd.config.load()
def usage():
print("usage: {} <command> [<args>]".format(sys.argv[0]))
print()
print("Commands:")
print(" adduser <username> - add a new user")
print(" deluser <username> - delete a user")
print(" lsuser - list users")
print(" passwd <username> - change password of a user")
def adduser(username):
password = getpass.getpass("Enter password for {}: ".format(username))
user_manager = get_user_manager(config)
user_manager.add_user(username, password)
def deluser(username):
user_manager = get_user_manager(config)
try:
user_manager.del_user(username)
except ValueError as error:
print("Could not delete user {}: {}".format(username, error), file=sys.stderr)
def lsuser():
user_manager = get_user_manager(config)
try:
users = user_manager.user_list()
for username in users:
print(username)
except ValueError as error:
print("Could not list users: {}".format(error), file=sys.stderr)
def passwd(username):
user_manager = get_user_manager(config)
if username not in user_manager.user_list():
print("User {} doesn't exist".format(username))
return
password = getpass.getpass("Enter password for {}: ".format(username))
try:
user_manager.set_password_for_user(username, password)
except ValueError as error:
print("Could not set password for user {}: {}".format(username, error), file=sys.stderr)
def main():
argc = len(sys.argv)
cmds = {
"adduser": adduser,
"deluser": deluser,
"lsuser": lsuser,
"passwd": passwd,
}
if argc < 2:
usage()
exit(1)
c = sys.argv[1]
try:
if argc > 2:
for arg in sys.argv[2:]:
cmds[c](arg)
else:
cmds[c]()
except KeyError:
usage()
exit(1)
if __name__ == "__main__":
main()

20
src/ankisyncd.conf Normal file
View File

@@ -0,0 +1,20 @@
[sync_app]
# change to 127.0.0.1 if you don't want the server to be accessible from the internet
host = 0.0.0.0
port = 27701
data_root = ./collections
base_url = /sync/
base_media_url = /msync/
auth_db_path = ./auth.db
# optional, for session persistence between restarts
session_db_path = ./session.db
# optional, for overriding the default managers and wrappers
# # must inherit from ankisyncd.full_sync.FullSyncManager, e.g,
# full_sync_manager = great_stuff.postgres.PostgresFullSyncManager
# # must inherit from ankisyncd.session.SimpleSessionManager, e.g,
# session_manager = great_stuff.postgres.PostgresSessionManager
# # must inherit from ankisyncd.users.SimpleUserManager, e.g,
# user_manager = great_stuff.postgres.PostgresUserManager
# # must inherit from ankisyncd.collection.CollectionWrapper, e.g,
# collection_wrapper = great_stuff.postgres.PostgresCollectionWrapper

33
src/ankisyncd/__init__.py Normal file
View File

@@ -0,0 +1,33 @@
import os
import sys
sys.path.insert(0, "/usr/share/anki")
sys.path.insert(0, os.path.join(os.path.dirname(os.path.dirname(__file__)), "anki-bundled"))
_homepage = "https://github.com/tsudoko/anki-sync-server"
_unknown_version = "[unknown version]"
def _get_version():
try:
from ankisyncd._version import version
return version
except ImportError:
pass
import subprocess
try:
return (
subprocess.run(
["git", "describe", "--always"],
stdout=subprocess.PIPE,
stderr=subprocess.PIPE,
)
.stdout.strip()
.decode()
or _unknown_version
)
except (FileNotFoundError, subprocess.CalledProcessError):
return _unknown_version

11
src/ankisyncd/__main__.py Normal file
View File

@@ -0,0 +1,11 @@
import sys
if __package__ is None and not hasattr(sys, "frozen"):
import os.path
path = os.path.realpath(os.path.abspath(__file__))
sys.path.insert(0, os.path.dirname(os.path.dirname(path)))
import ankisyncd.sync_app
if __name__ == "__main__":
ankisyncd.sync_app.main()

138
src/ankisyncd/collection.py Normal file
View File

@@ -0,0 +1,138 @@
import anki
import anki.storage
import ankisyncd.media
import os, errno
import logging
logger = logging.getLogger("ankisyncd.collection")
class CollectionWrapper:
"""A simple wrapper around an anki.storage.Collection object.
This allows us to manage and refer to the collection, whether it's open or not. It
also provides a special "continuation passing" interface for executing functions
on the collection, which makes it easy to switch to a threading mode.
See ThreadingCollectionWrapper for a version that maintains a seperate thread for
interacting with the collection.
"""
def __init__(self, _config, path, setup_new_collection=None):
self.path = os.path.realpath(path)
self.username = os.path.basename(os.path.dirname(self.path))
self.setup_new_collection = setup_new_collection
self.__col = None
def __del__(self):
"""Close the collection if the user forgot to do so."""
self.close()
def execute(self, func, args=[], kw={}, waitForReturn=True):
""" Executes the given function with the underlying anki.storage.Collection
object as the first argument and any additional arguments specified by *args
and **kw.
If 'waitForReturn' is True, then it will block until the function has
executed and return its return value. If False, the function MAY be
executed some time later and None will be returned.
"""
# Open the collection and execute the function
self.open()
args = [self.__col] + args
ret = func(*args, **kw)
# Only return the value if they requested it, so the interface remains
# identical between this class and ThreadingCollectionWrapper
if waitForReturn:
return ret
def __create_collection(self):
"""Creates a new collection and runs any special setup."""
# mkdir -p the path, because it might not exist
os.makedirs(os.path.dirname(self.path), exist_ok=True)
col = self._get_collection()
# Do any special setup
if self.setup_new_collection is not None:
self.setup_new_collection(col)
return col
def _get_collection(self):
col = anki.storage.Collection(self.path)
# Ugly hack, replace default media manager with our custom one
col.media.close()
col.media = ankisyncd.media.ServerMediaManager(col)
return col
def open(self):
"""Open the collection, or create it if it doesn't exist."""
if self.__col is None:
if os.path.exists(self.path):
self.__col = self._get_collection()
else:
self.__col = self.__create_collection()
def close(self):
"""Close the collection if opened."""
if not self.opened():
return
self.__col.close()
self.__col = None
def opened(self):
"""Returns True if the collection is open, False otherwise."""
return self.__col is not None
class CollectionManager:
"""Manages a set of CollectionWrapper objects."""
collection_wrapper = CollectionWrapper
def __init__(self, config):
self.collections = {}
self.config = config
def get_collection(self, path, setup_new_collection=None):
"""Gets a CollectionWrapper for the given path."""
path = os.path.realpath(path)
try:
col = self.collections[path]
except KeyError:
col = self.collections[path] = self.collection_wrapper(self.config, path, setup_new_collection)
return col
def shutdown(self):
"""Close all CollectionWrappers managed by this object."""
for path, col in list(self.collections.items()):
del self.collections[path]
col.close()
def get_collection_wrapper(config, path, setup_new_collection = None):
if "collection_wrapper" in config and config["collection_wrapper"]:
logger.info("Found collection_wrapper in config, using {} for "
"user data persistence".format(config['collection_wrapper']))
import importlib
import inspect
module_name, class_name = config['collection_wrapper'].rsplit('.', 1)
module = importlib.import_module(module_name.strip())
class_ = getattr(module, class_name.strip())
if not CollectionWrapper in inspect.getmro(class_):
raise TypeError('''"collection_wrapper" found in the conf file but it doesn''t
inherit from CollectionWrapper''')
return class_(config, path, setup_new_collection)
else:
return CollectionWrapper(config, path, setup_new_collection)

42
src/ankisyncd/config.py Normal file
View File

@@ -0,0 +1,42 @@
import configparser
import logging
import os
from os.path import dirname, realpath
logger = logging.getLogger("ankisyncd")
paths = [
"/etc/ankisyncd/ankisyncd.conf",
os.environ.get("XDG_CONFIG_HOME") and
(os.path.join(os.environ['XDG_CONFIG_HOME'], "ankisyncd", "ankisyncd.conf")) or
os.path.join(os.path.expanduser("~"), ".config", "ankisyncd", "ankisyncd.conf"),
os.path.join(dirname(dirname(realpath(__file__))), "ankisyncd.conf"),
]
# Get values from ENV and update the config. To use this prepend `ANKISYNCD_`
# to the uppercase form of the key. E.g, `ANKISYNCD_SESSION_MANAGER` to set
# `session_manager`
def load_from_env(conf):
logger.debug("Loading/overriding config values from ENV")
for env in os.environ:
if env.startswith('ANKISYNCD_'):
config_key = env[10:].lower()
conf[config_key] = os.getenv(env)
logger.info("Setting {} from ENV".format(config_key))
def load(path=None):
choices = paths
parser = configparser.ConfigParser()
if path:
choices = [path]
for path in choices:
logger.debug("config.location: trying", path)
try:
parser.read(path)
conf = parser['sync_app']
logger.info("Loaded config from {}".format(path))
load_from_env(conf)
return conf
except KeyError:
pass
raise Exception("No config found, looked for {}".format(", ".join(choices)))

View File

@@ -0,0 +1,59 @@
# -*- coding: utf-8 -*-
import os
from sqlite3 import dbapi2 as sqlite
import anki.db
class FullSyncManager:
def upload(self, col, data, session):
# Verify integrity of the received database file before replacing our
# existing db.
temp_db_path = session.get_collection_path() + ".tmp"
with open(temp_db_path, 'wb') as f:
f.write(data)
try:
with anki.db.DB(temp_db_path) as test_db:
if test_db.scalar("pragma integrity_check") != "ok":
raise HTTPBadRequest("Integrity check failed for uploaded "
"collection database file.")
except sqlite.Error as e:
raise HTTPBadRequest("Uploaded collection database file is "
"corrupt.")
# Overwrite existing db.
col.close()
try:
os.replace(temp_db_path, session.get_collection_path())
finally:
col.reopen()
col.load()
return "OK"
def download(self, col, session):
col.close()
try:
data = open(session.get_collection_path(), 'rb').read()
finally:
col.reopen()
col.load()
return data
def get_full_sync_manager(config):
if "full_sync_manager" in config and config["full_sync_manager"]: # load from config
import importlib
import inspect
module_name, class_name = config['full_sync_manager'].rsplit('.', 1)
module = importlib.import_module(module_name.strip())
class_ = getattr(module, class_name.strip())
if not FullSyncManager in inspect.getmro(class_):
raise TypeError('''"full_sync_manager" found in the conf file but it doesn''t
inherit from FullSyncManager''')
return class_(config)
else:
return FullSyncManager()

67
src/ankisyncd/media.py Normal file
View File

@@ -0,0 +1,67 @@
# Based on anki.media.MediaManager, © Ankitects Pty Ltd and contributors
# License: GNU AGPL, version 3 or later; http://www.gnu.org/licenses/agpl.html
# Original source: https://raw.githubusercontent.com/dae/anki/62481ddc1aa78430cb8114cbf00a7739824318a8/anki/media.py
import logging
import re
import os
import os.path
import anki.db
logger = logging.getLogger("ankisyncd.media")
class ServerMediaManager:
def __init__(self, col):
self._dir = re.sub(r"(?i)\.(anki2)$", ".media", col.path)
self.connect()
def connect(self):
path = self.dir() + ".server.db"
create = not os.path.exists(path)
self.db = anki.db.DB(path)
if create:
self.db.executescript(
"""CREATE TABLE media (
fname TEXT NOT NULL PRIMARY KEY,
usn INT NOT NULL,
csum TEXT -- null if deleted
);
CREATE INDEX idx_media_usn ON media (usn);"""
)
oldpath = self.dir() + ".db2"
if os.path.exists(oldpath):
logger.info("Found client media database, migrating contents")
self.db.execute("ATTACH ? AS old", oldpath)
self.db.execute(
"INSERT INTO media SELECT fname, lastUsn, csum FROM old.media, old.meta"
)
self.db.commit()
self.db.execute("DETACH old")
def close(self):
self.db.close()
def dir(self):
return self._dir
def lastUsn(self):
return self.db.scalar("SELECT max(usn) FROM media") or 0
def mediaCount(self):
return self.db.scalar("SELECT count() FROM media WHERE csum IS NOT NULL")
# used only in unit tests
def syncInfo(self, fname):
return self.db.first("SELECT csum, 0 FROM media WHERE fname=?", fname)
def syncDelete(self, fname):
fpath = os.path.join(self.dir(), fname)
if os.path.exists(fpath):
os.remove(fpath)
self.db.execute(
"UPDATE media SET csum = NULL, usn = ? WHERE fname = ?",
self.lastUsn() + 1,
fname,
)

141
src/ankisyncd/sessions.py Normal file
View File

@@ -0,0 +1,141 @@
# -*- coding: utf-8 -*-
import os
import logging
from sqlite3 import dbapi2 as sqlite
logger = logging.getLogger("ankisyncd.sessions")
class SimpleSessionManager:
"""A simple session manager that keeps the sessions in memory."""
def __init__(self):
self.sessions = {}
def load(self, hkey, session_factory=None):
return self.sessions.get(hkey)
def load_from_skey(self, skey, session_factory=None):
for i in self.sessions:
if self.sessions[i].skey == skey:
return self.sessions[i]
def save(self, hkey, session):
self.sessions[hkey] = session
def delete(self, hkey):
del self.sessions[hkey]
class SqliteSessionManager(SimpleSessionManager):
"""Stores sessions in a SQLite database to prevent the user from being logged out
everytime the SyncApp is restarted."""
def __init__(self, session_db_path):
SimpleSessionManager.__init__(self)
self.session_db_path = os.path.realpath(session_db_path)
self._ensure_schema_up_to_date()
def _ensure_schema_up_to_date(self):
if not os.path.exists(self.session_db_path):
return True
conn = self._conn()
cursor = conn.cursor()
cursor.execute("SELECT * FROM sqlite_master "
"WHERE sql LIKE '%user VARCHAR PRIMARY KEY%' "
"AND tbl_name = 'session'")
res = cursor.fetchone()
conn.close()
if res is not None:
raise Exception("Outdated database schema, run utils/migrate_user_tables.py")
def _conn(self):
new = not os.path.exists(self.session_db_path)
conn = sqlite.connect(self.session_db_path)
if new:
cursor = conn.cursor()
cursor.execute("CREATE TABLE session (hkey VARCHAR PRIMARY KEY, skey VARCHAR, username VARCHAR, path VARCHAR)")
return conn
# Default to using sqlite3 syntax but overridable for sub-classes using other
# DB API 2 driver variants
@staticmethod
def fs(sql):
return sql
def load(self, hkey, session_factory=None):
session = SimpleSessionManager.load(self, hkey)
if session is not None:
return session
conn = self._conn()
cursor = conn.cursor()
cursor.execute(self.fs("SELECT skey, username, path FROM session WHERE hkey=?"), (hkey,))
res = cursor.fetchone()
if res is not None:
session = self.sessions[hkey] = session_factory(res[1], res[2])
session.skey = res[0]
return session
def load_from_skey(self, skey, session_factory=None):
session = SimpleSessionManager.load_from_skey(self, skey)
if session is not None:
return session
conn = self._conn()
cursor = conn.cursor()
cursor.execute(self.fs("SELECT hkey, username, path FROM session WHERE skey=?"), (skey,))
res = cursor.fetchone()
if res is not None:
session = self.sessions[res[0]] = session_factory(res[1], res[2])
session.skey = skey
return session
def save(self, hkey, session):
SimpleSessionManager.save(self, hkey, session)
conn = self._conn()
cursor = conn.cursor()
cursor.execute("INSERT OR REPLACE INTO session (hkey, skey, username, path) VALUES (?, ?, ?, ?)",
(hkey, session.skey, session.name, session.path))
conn.commit()
def delete(self, hkey):
SimpleSessionManager.delete(self, hkey)
conn = self._conn()
cursor = conn.cursor()
cursor.execute(self.fs("DELETE FROM session WHERE hkey=?"), (hkey,))
conn.commit()
def get_session_manager(config):
if "session_db_path" in config and config["session_db_path"]:
logger.info("Found session_db_path in config, using SqliteSessionManager for auth")
return SqliteSessionManager(config['session_db_path'])
elif "session_manager" in config and config["session_manager"]: # load from config
logger.info("Found session_manager in config, using {} for persisting sessions".format(
config['session_manager'])
)
import importlib
import inspect
module_name, class_name = config['session_manager'].rsplit('.', 1)
module = importlib.import_module(module_name.strip())
class_ = getattr(module, class_name.strip())
if not SimpleSessionManager in inspect.getmro(class_):
raise TypeError('''"session_manager" found in the conf file but it doesn''t
inherit from SimpleSessionManager''')
return class_(config)
else:
logger.warning("Neither session_db_path nor session_manager set, "
"ankisyncd will lose sessions on application restart")
return SimpleSessionManager()

675
src/ankisyncd/sync_app.py Normal file
View File

@@ -0,0 +1,675 @@
# ankisyncd - A personal Anki sync server
# Copyright (C) 2013 David Snopek
#
# This program is free software: you can redistribute it and/or modify
# it under the terms of the GNU Affero General Public License as
# published by the Free Software Foundation, either version 3 of the
# License, or (at your option) any later version.
#
# This program is distributed in the hope that it will be useful,
# but WITHOUT ANY WARRANTY; without even the implied warranty of
# MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
# GNU Affero General Public License for more details.
#
# You should have received a copy of the GNU Affero General Public License
# along with this program. If not, see <http://www.gnu.org/licenses/>.
import gzip
import hashlib
import io
import json
import logging
import os
import random
import re
import string
import sys
import time
import unicodedata
import zipfile
from configparser import ConfigParser
from sqlite3 import dbapi2 as sqlite
from webob import Response
from webob.dec import wsgify
from webob.exc import *
import anki.db
import anki.sync
import anki.utils
from anki.consts import SYNC_VER, SYNC_ZIP_SIZE, SYNC_ZIP_COUNT
from anki.consts import REM_CARD, REM_NOTE
from ankisyncd.users import get_user_manager
from ankisyncd.sessions import get_session_manager
from ankisyncd.full_sync import get_full_sync_manager
logger = logging.getLogger("ankisyncd")
class SyncCollectionHandler(anki.sync.Syncer):
operations = ['meta', 'applyChanges', 'start', 'applyGraves', 'chunk', 'applyChunk', 'sanityCheck2', 'finish']
def __init__(self, col):
# So that 'server' (the 3rd argument) can't get set
anki.sync.Syncer.__init__(self, col)
@staticmethod
def _old_client(cv):
if not cv:
return False
note = {"alpha": 0, "beta": 0, "rc": 0}
client, version, platform = cv.split(',')
for name in note.keys():
if name in version:
vs = version.split(name)
version = vs[0]
note[name] = int(vs[-1])
# convert the version string, ignoring non-numeric suffixes like in beta versions of Anki
version_nosuffix = re.sub(r'[^0-9.].*$', '', version)
version_int = [int(x) for x in version_nosuffix.split('.')]
if client == 'ankidesktop':
return version_int < [2, 0, 27]
elif client == 'ankidroid':
if version_int == [2, 3]:
if note["alpha"]:
return note["alpha"] < 4
else:
return version_int < [2, 2, 3]
else: # unknown client, assume current version
return False
def meta(self, v=None, cv=None):
if self._old_client(cv):
return Response(status=501) # client needs upgrade
if v > SYNC_VER:
return {"cont": False, "msg": "Your client is using unsupported sync protocol ({}, supported version: {})".format(v, SYNC_VER)}
if v < 9 and self.col.schedVer() >= 2:
return {"cont": False, "msg": "Your client doesn't support the v{} scheduler.".format(self.col.schedVer())}
# Make sure the media database is open!
if self.col.media.db is None:
self.col.media.connect()
return {
'scm': self.col.scm,
'ts': anki.utils.intTime(),
'mod': self.col.mod,
'usn': self.col._usn,
'musn': self.col.media.lastUsn(),
'msg': '',
'cont': True,
}
def usnLim(self):
return "usn >= %d" % self.minUsn
# ankidesktop >=2.1rc2 sends graves in applyGraves, but still expects
# server-side deletions to be returned by start
def start(self, minUsn, lnewer, graves={"cards": [], "notes": [], "decks": []}, offset=None):
if offset is not None:
raise NotImplementedError('You are using the experimental V2 scheduler, which is not supported by the server.')
self.maxUsn = self.col._usn
self.minUsn = minUsn
self.lnewer = not lnewer
lgraves = self.removed()
self.remove(graves)
return lgraves
def applyGraves(self, chunk):
self.remove(chunk)
def applyChanges(self, changes):
self.rchg = changes
lchg = self.changes()
# merge our side before returning
self.mergeChanges(lchg, self.rchg)
return lchg
def sanityCheck2(self, client):
server = self.sanityCheck()
if client != server:
return dict(status="bad", c=client, s=server)
return dict(status="ok")
def finish(self, mod=None):
return anki.sync.Syncer.finish(self, anki.utils.intTime(1000))
# This function had to be put here in its entirety because Syncer.removed()
# doesn't use self.usnLim() (which we override in this class) in queries.
# "usn=-1" has been replaced with "usn >= ?", self.minUsn by hand.
def removed(self):
cards = []
notes = []
decks = []
curs = self.col.db.execute(
"select oid, type from graves where usn >= ?", self.minUsn)
for oid, type in curs:
if type == REM_CARD:
cards.append(oid)
elif type == REM_NOTE:
notes.append(oid)
else:
decks.append(oid)
return dict(cards=cards, notes=notes, decks=decks)
def getModels(self):
return [m for m in self.col.models.all() if m['usn'] >= self.minUsn]
def getDecks(self):
return [
[g for g in self.col.decks.all() if g['usn'] >= self.minUsn],
[g for g in self.col.decks.allConf() if g['usn'] >= self.minUsn]
]
def getTags(self):
return [t for t, usn in self.col.tags.allItems()
if usn >= self.minUsn]
class SyncMediaHandler:
operations = ['begin', 'mediaChanges', 'mediaSanity', 'uploadChanges', 'downloadFiles']
def __init__(self, col):
self.col = col
def begin(self, skey):
return {
'data': {
'sk': skey,
'usn': self.col.media.lastUsn(),
},
'err': '',
}
def uploadChanges(self, data):
"""
The zip file contains files the client hasn't synced with the server
yet ('dirty'), and info on files it has deleted from its own media dir.
"""
with zipfile.ZipFile(io.BytesIO(data), "r") as z:
self._check_zip_data(z)
processed_count = self._adopt_media_changes_from_zip(z)
return {
'data': [processed_count, self.col.media.lastUsn()],
'err': '',
}
@staticmethod
def _check_zip_data(zip_file):
max_zip_size = 100*1024*1024
max_meta_file_size = 100000
meta_file_size = zip_file.getinfo("_meta").file_size
sum_file_sizes = sum(info.file_size for info in zip_file.infolist())
if meta_file_size > max_meta_file_size:
raise ValueError("Zip file's metadata file is larger than %s "
"Bytes." % max_meta_file_size)
elif sum_file_sizes > max_zip_size:
raise ValueError("Zip file contents are larger than %s Bytes." %
max_zip_size)
def _adopt_media_changes_from_zip(self, zip_file):
"""
Adds and removes files to/from the database and media directory
according to the data in zip file zipData.
"""
# Get meta info first.
meta = json.loads(zip_file.read("_meta").decode())
# Remove media files that were removed on the client.
media_to_remove = []
for normname, ordinal in meta:
if ordinal == '':
media_to_remove.append(self._normalize_filename(normname))
# Add media files that were added on the client.
media_to_add = []
usn = self.col.media.lastUsn()
oldUsn = usn
for i in zip_file.infolist():
if i.filename == "_meta": # Ignore previously retrieved metadata.
continue
file_data = zip_file.read(i)
csum = anki.utils.checksum(file_data)
filename = self._normalize_filename(meta[int(i.filename)][0])
file_path = os.path.join(self.col.media.dir(), filename)
# Save file to media directory.
with open(file_path, 'wb') as f:
f.write(file_data)
usn += 1
media_to_add.append((filename, usn, csum))
# We count all files we are to remove, even if we don't have them in
# our media directory and our db doesn't know about them.
processed_count = len(media_to_remove) + len(media_to_add)
assert len(meta) == processed_count # sanity check
if media_to_remove:
self._remove_media_files(media_to_remove)
if media_to_add:
self.col.media.db.executemany(
"INSERT OR REPLACE INTO media VALUES (?,?,?)", media_to_add)
self.col.media.db.commit()
assert self.col.media.lastUsn() == oldUsn + processed_count # TODO: move to some unit test
return processed_count
@staticmethod
def _normalize_filename(filename):
"""
Performs unicode normalization for file names. Logic taken from Anki's
MediaManager.addFilesFromZip().
"""
# Normalize name for platform.
if anki.utils.isMac: # global
filename = unicodedata.normalize("NFD", filename)
else:
filename = unicodedata.normalize("NFC", filename)
return filename
def _remove_media_files(self, filenames):
"""
Marks all files in list filenames as deleted and removes them from the
media directory.
"""
logger.debug('Removing %d files from media dir.' % len(filenames))
for filename in filenames:
try:
self.col.media.syncDelete(filename)
self.col.media.db.commit()
except OSError as err:
logger.error("Error when removing file '%s' from media dir: "
"%s" % (filename, str(err)))
def downloadFiles(self, files):
flist = {}
cnt = 0
sz = 0
f = io.BytesIO()
with zipfile.ZipFile(f, "w", compression=zipfile.ZIP_DEFLATED) as z:
for fname in files:
z.write(os.path.join(self.col.media.dir(), fname), str(cnt))
flist[str(cnt)] = fname
sz += os.path.getsize(os.path.join(self.col.media.dir(), fname))
if sz > SYNC_ZIP_SIZE or cnt > SYNC_ZIP_COUNT:
break
cnt += 1
z.writestr("_meta", json.dumps(flist))
return f.getvalue()
def mediaChanges(self, lastUsn):
result = []
server_lastUsn = self.col.media.lastUsn()
fname = csum = None
if lastUsn < server_lastUsn or lastUsn == 0:
for fname,usn,csum, in self.col.media.db.execute("select fname,usn,csum from media order by usn desc limit ?", server_lastUsn - lastUsn):
result.append([fname, usn, csum])
# anki assumes server_lastUsn == result[-1][1]
# ref: anki/sync.py:720 (commit cca3fcb2418880d0430a5c5c2e6b81ba260065b7)
result.reverse()
return {'data': result, 'err': ''}
def mediaSanity(self, local=None):
if self.col.media.mediaCount() == local:
result = "OK"
else:
result = "FAILED"
return {'data': result, 'err': ''}
class SyncUserSession:
def __init__(self, name, path, collection_manager, setup_new_collection=None):
self.skey = self._generate_session_key()
self.name = name
self.path = path
self.collection_manager = collection_manager
self.setup_new_collection = setup_new_collection
self.version = None
self.client_version = None
self.created = time.time()
self.collection_handler = None
self.media_handler = None
# make sure the user path exists
if not os.path.exists(path):
os.mkdir(path)
def _generate_session_key(self):
return anki.utils.checksum(str(random.random()))[:8]
def get_collection_path(self):
return os.path.realpath(os.path.join(self.path, 'collection.anki2'))
def get_thread(self):
return self.collection_manager.get_collection(self.get_collection_path(), self.setup_new_collection)
def get_handler_for_operation(self, operation, col):
if operation in SyncCollectionHandler.operations:
attr, handler_class = 'collection_handler', SyncCollectionHandler
elif operation in SyncMediaHandler.operations:
attr, handler_class = 'media_handler', SyncMediaHandler
else:
raise Exception("no handler for {}".format(operation))
if getattr(self, attr) is None:
setattr(self, attr, handler_class(col))
handler = getattr(self, attr)
# The col object may actually be new now! This happens when we close a collection
# for inactivity and then later re-open it (creating a new Collection object).
handler.col = col
return handler
class SyncApp:
valid_urls = SyncCollectionHandler.operations + SyncMediaHandler.operations + ['hostKey', 'upload', 'download']
def __init__(self, config):
from ankisyncd.thread import get_collection_manager
self.data_root = os.path.abspath(config['data_root'])
self.base_url = config['base_url']
self.base_media_url = config['base_media_url']
self.setup_new_collection = None
self.prehooks = {}
self.posthooks = {}
self.user_manager = get_user_manager(config)
self.session_manager = get_session_manager(config)
self.full_sync_manager = get_full_sync_manager(config)
self.collection_manager = get_collection_manager(config)
# make sure the base_url has a trailing slash
if not self.base_url.endswith('/'):
self.base_url += '/'
if not self.base_media_url.endswith('/'):
self.base_media_url += '/'
# backwards compat
@property
def hook_pre_sync(self):
return self.prehooks.get("start")
@hook_pre_sync.setter
def hook_pre_sync(self, value):
self.prehooks['start'] = value
@property
def hook_post_sync(self):
return self.posthooks.get("finish")
@hook_post_sync.setter
def hook_post_sync(self, value):
self.posthooks['finish'] = value
@property
def hook_upload(self):
return self.prehooks.get("upload")
@hook_upload.setter
def hook_upload(self, value):
self.prehooks['upload'] = value
@property
def hook_download(self):
return self.posthooks.get("download")
@hook_download.setter
def hook_download(self, value):
self.posthooks['download'] = value
def generateHostKey(self, username):
"""Generates a new host key to be used by the given username to identify their session.
This values is random."""
import hashlib, time, random, string
chars = string.ascii_letters + string.digits
val = ':'.join([username, str(int(time.time())), ''.join(random.choice(chars) for x in range(8))]).encode()
return hashlib.md5(val).hexdigest()
def create_session(self, username, user_path):
return SyncUserSession(username, user_path, self.collection_manager, self.setup_new_collection)
def _decode_data(self, data, compression=0):
if compression:
with gzip.GzipFile(mode="rb", fileobj=io.BytesIO(data)) as gz:
data = gz.read()
try:
data = json.loads(data.decode())
except (ValueError, UnicodeDecodeError):
data = {'data': data}
return data
def operation_hostKey(self, username, password):
if not self.user_manager.authenticate(username, password):
return
dirname = self.user_manager.userdir(username)
if dirname is None:
return
hkey = self.generateHostKey(username)
user_path = os.path.join(self.data_root, dirname)
session = self.create_session(username, user_path)
self.session_manager.save(hkey, session)
return {'key': hkey}
def operation_upload(self, col, data, session):
# Verify integrity of the received database file before replacing our
# existing db.
return self.full_sync_manager.upload(col, data, session)
def operation_download(self, col, session):
# returns user data (not media) as a sqlite3 database for replacing their
# local copy in Anki
return self.full_sync_manager.download(col, session)
@wsgify
def __call__(self, req):
# Get and verify the session
try:
hkey = req.POST['k']
except KeyError:
hkey = None
session = self.session_manager.load(hkey, self.create_session)
if session is None:
try:
skey = req.POST['sk']
session = self.session_manager.load_from_skey(skey, self.create_session)
except KeyError:
skey = None
try:
compression = int(req.POST['c'])
except KeyError:
compression = 0
try:
data = req.POST['data'].file.read()
data = self._decode_data(data, compression)
except KeyError:
data = {}
if req.path.startswith(self.base_url):
url = req.path[len(self.base_url):]
if url not in self.valid_urls:
raise HTTPNotFound()
if url == 'hostKey':
result = self.operation_hostKey(data.get("u"), data.get("p"))
if result:
return json.dumps(result)
else:
# TODO: do I have to pass 'null' for the client to receive None?
raise HTTPForbidden('null')
if session is None:
raise HTTPForbidden()
if url in SyncCollectionHandler.operations + SyncMediaHandler.operations:
# 'meta' passes the SYNC_VER but it isn't used in the handler
if url == 'meta':
if session.skey == None and 's' in req.POST:
session.skey = req.POST['s']
if 'v' in data:
session.version = data['v']
if 'cv' in data:
session.client_version = data['cv']
self.session_manager.save(hkey, session)
session = self.session_manager.load(hkey, self.create_session)
thread = session.get_thread()
if url in self.prehooks:
thread.execute(self.prehooks[url], [session])
result = self._execute_handler_method_in_thread(url, data, session)
# If it's a complex data type, we convert it to JSON
if type(result) not in (str, bytes, Response):
result = json.dumps(result)
if url in self.posthooks:
thread.execute(self.posthooks[url], [session])
return result
elif url == 'upload':
thread = session.get_thread()
if url in self.prehooks:
thread.execute(self.prehooks[url], [session])
result = thread.execute(self.operation_upload, [data['data'], session])
if url in self.posthooks:
thread.execute(self.posthooks[url], [session])
return result
elif url == 'download':
thread = session.get_thread()
if url in self.prehooks:
thread.execute(self.prehooks[url], [session])
result = thread.execute(self.operation_download, [session])
if url in self.posthooks:
thread.execute(self.posthooks[url], [session])
return result
# This was one of our operations but it didn't get handled... Oops!
raise HTTPInternalServerError()
# media sync
elif req.path.startswith(self.base_media_url):
if session is None:
raise HTTPForbidden()
url = req.path[len(self.base_media_url):]
if url not in self.valid_urls:
raise HTTPNotFound()
if url == "begin":
data['skey'] = session.skey
result = self._execute_handler_method_in_thread(url, data, session)
# If it's a complex data type, we convert it to JSON
if type(result) not in (str, bytes):
result = json.dumps(result)
return result
return "Anki Sync Server"
@staticmethod
def _execute_handler_method_in_thread(method_name, keyword_args, session):
"""
Gets and runs the handler method specified by method_name inside the
thread for session. The handler method will access the collection as
self.col.
"""
def run_func(col, **keyword_args):
# Retrieve the correct handler method.
handler = session.get_handler_for_operation(method_name, col)
handler_method = getattr(handler, method_name)
res = handler_method(**keyword_args)
col.save()
return res
run_func.__name__ = method_name # More useful debugging messages.
# Send the closure to the thread for execution.
thread = session.get_thread()
result = thread.execute(run_func, kw=keyword_args)
return result
def make_app(global_conf, **local_conf):
return SyncApp(**local_conf)
def main():
logging.basicConfig(level=logging.INFO, format="[%(asctime)s]:%(levelname)s:%(name)s:%(message)s")
import ankisyncd
logger.info("ankisyncd {} ({})".format(ankisyncd._get_version(), ankisyncd._homepage))
from wsgiref.simple_server import make_server, WSGIRequestHandler
from ankisyncd.thread import shutdown
import ankisyncd.config
class RequestHandler(WSGIRequestHandler):
logger = logging.getLogger("ankisyncd.http")
def log_error(self, format, *args):
self.logger.error("%s %s", self.address_string(), format%args)
def log_message(self, format, *args):
self.logger.info("%s %s", self.address_string(), format%args)
if len(sys.argv) > 1:
# backwards compat
config = ankisyncd.config.load(sys.argv[1])
else:
config = ankisyncd.config.load()
ankiserver = SyncApp(config)
httpd = make_server(config['host'], int(config['port']), ankiserver, handler_class=RequestHandler)
try:
logger.info("Serving HTTP on {} port {}...".format(*httpd.server_address))
httpd.serve_forever()
except KeyboardInterrupt:
logger.info("Exiting...")
finally:
shutdown()
if __name__ == '__main__': main()

218
src/ankisyncd/thread.py Normal file
View File

@@ -0,0 +1,218 @@
from ankisyncd.collection import CollectionManager, get_collection_wrapper
from threading import Thread
from queue import Queue
import time, logging
def short_repr(obj, logger=logging.getLogger(), maxlen=80):
"""Like repr, but shortens strings and bytestrings if logger's logging level
is above DEBUG. Currently shallow and very limited, only implemented for
dicts and lists."""
if logger.isEnabledFor(logging.DEBUG):
return repr(obj)
def shorten(s):
if isinstance(s, (bytes, str)) and len(s) > maxlen:
return s[:maxlen] + ("..." if isinstance(s, str) else b"...")
else:
return s
o = obj.copy()
if isinstance(o, dict):
for k in o:
o[k] = shorten(o[k])
elif isinstance(o, list):
for k in range(len(o)):
o[k] = shorten(o[k])
return repr(o)
class ThreadingCollectionWrapper:
"""Provides the same interface as CollectionWrapper, but it creates a new Thread to
interact with the collection."""
def __init__(self, config, path, setup_new_collection=None):
self.path = path
self.wrapper = get_collection_wrapper(config, path, setup_new_collection)
self.logger = logging.getLogger("ankisyncd." + str(self))
self._queue = Queue()
self._thread = None
self._running = False
self.last_timestamp = time.time()
self.start()
def __str__(self):
return "CollectionThread[{}]".format(self.wrapper.username)
@property
def running(self):
return self._running
def qempty(self):
return self._queue.empty()
def current(self):
from threading import current_thread
return current_thread() == self._thread
def execute(self, func, args=[], kw={}, waitForReturn=True):
""" Executes a given function on this thread with the *args and **kw.
If 'waitForReturn' is True, then it will block until the function has
executed and return its return value. If False, it will return None
immediately and the function will be executed sometime later.
"""
if waitForReturn:
return_queue = Queue()
else:
return_queue = None
self._queue.put((func, args, kw, return_queue))
if return_queue is not None:
ret = return_queue.get(True)
if isinstance(ret, Exception):
raise ret
return ret
def _run(self):
self.logger.info("Starting...")
try:
while self._running:
func, args, kw, return_queue = self._queue.get(True)
if hasattr(func, '__name__'):
func_name = func.__name__
else:
func_name = func.__class__.__name__
self.logger.info("Running %s(*%s, **%s)", func_name, short_repr(args, self.logger), short_repr(kw, self.logger))
self.last_timestamp = time.time()
try:
ret = self.wrapper.execute(func, args, kw, return_queue)
except Exception as e:
self.logger.error("Unable to %s(*%s, **%s): %s",
func_name, repr(args), repr(kw), e, exc_info=True)
# we return the Exception which will be raise'd on the other end
ret = e
if return_queue is not None:
return_queue.put(ret)
except Exception as e:
self.logger.error("Thread crashed! Exception: %s", e, exc_info=True)
finally:
self.wrapper.close()
# clean out old thread object
self._thread = None
# in case we got here via an exception
self._running = False
self.logger.info("Stopped!")
def start(self):
if not self._running:
self._running = True
assert self._thread is None
self._thread = Thread(target=self._run)
self._thread.start()
def stop(self):
def _stop(col):
self._running = False
self.execute(_stop, waitForReturn=False)
def stop_and_wait(self):
""" Tell the thread to stop and wait for it to happen. """
self.stop()
if self._thread is not None:
self._thread.join()
#
# Mimic the CollectionWrapper interface
#
def open(self):
"""Non-op. The collection will be opened on demand."""
pass
def close(self):
"""Closes the underlying collection without stopping the thread."""
def _close(col):
self.wrapper.close()
self.execute(_close, waitForReturn=False)
def opened(self):
return self.wrapper.opened()
class ThreadingCollectionManager(CollectionManager):
"""Manages a set of ThreadingCollectionWrapper objects."""
collection_wrapper = ThreadingCollectionWrapper
def __init__(self, config):
super(ThreadingCollectionManager, self).__init__(config)
self.monitor_frequency = 15
self.monitor_inactivity = 90
self.logger = logging.getLogger("ankisyncd.ThreadingCollectionManager")
monitor = Thread(target=self._monitor_run)
monitor.daemon = True
monitor.start()
self._monitor_thread = monitor
# TODO: we should raise some error if a collection is started on a manager that has already been shutdown!
# or maybe we could support being restarted?
# TODO: it would be awesome to have a safe way to stop inactive threads completely!
# TODO: we need a way to inform other code that the collection has been closed
def _monitor_run(self):
""" Monitors threads for inactivity and closes the collection on them
(leaves the thread itself running -- hopefully waiting peacefully with only a
small memory footprint!) """
while True:
cur = time.time()
for path, thread in self.collections.items():
if thread.running and thread.wrapper.opened() and thread.qempty() and cur - thread.last_timestamp >= self.monitor_inactivity:
self.logger.info("Monitor is closing collection on inactive %s", thread)
thread.close()
time.sleep(self.monitor_frequency)
def shutdown(self):
# TODO: stop the monitor thread!
# stop all the threads
for path, col in list(self.collections.items()):
del self.collections[path]
col.stop()
# let the parent do whatever else it might want to do...
super(ThreadingCollectionManager, self).shutdown()
#
# For working with the global ThreadingCollectionManager:
#
collection_manager = None
def get_collection_manager(config):
"""Return the global ThreadingCollectionManager for this process."""
global collection_manager
if collection_manager is None:
collection_manager = ThreadingCollectionManager(config)
return collection_manager
def shutdown():
"""If the global ThreadingCollectionManager exists, shut it down."""
global collection_manager
if collection_manager is not None:
collection_manager.shutdown()
collection_manager = None

217
src/ankisyncd/users.py Normal file
View File

@@ -0,0 +1,217 @@
# -*- coding: utf-8 -*-
import binascii
import hashlib
import logging
import os
import sqlite3 as sqlite
logger = logging.getLogger("ankisyncd.users")
class SimpleUserManager:
"""A simple user manager that always allows any user."""
def __init__(self, collection_path=''):
self.collection_path = collection_path
def authenticate(self, username, password):
"""
Returns True if this username is allowed to connect with this password.
False otherwise. Override this to change how users are authenticated.
"""
return True
def userdir(self, username):
"""
Returns the directory name for the given user. By default, this is just
the username. Override this to adjust the mapping between users and
their directory.
"""
return username
def _create_user_dir(self, username):
user_dir_path = os.path.join(self.collection_path, username)
if not os.path.isdir(user_dir_path):
logger.info("Creating collection directory for user '{}' at {}"
.format(username, user_dir_path))
os.makedirs(user_dir_path)
class SqliteUserManager(SimpleUserManager):
"""Authenticates users against a SQLite database."""
def __init__(self, auth_db_path, collection_path=None):
SimpleUserManager.__init__(self, collection_path)
self.auth_db_path = os.path.realpath(auth_db_path)
self._ensure_schema_up_to_date()
def _ensure_schema_up_to_date(self):
if not self.auth_db_exists():
return True
conn = self._conn()
cursor = conn.cursor()
cursor.execute("SELECT * FROM sqlite_master "
"WHERE sql LIKE '%user VARCHAR PRIMARY KEY%' "
"AND tbl_name = 'auth'")
res = cursor.fetchone()
conn.close()
if res is not None:
raise Exception("Outdated database schema, run utils/migrate_user_tables.py")
# Default to using sqlite3 but overridable for sub-classes using other
# DB API 2 driver variants
def auth_db_exists(self):
return os.path.isfile(self.auth_db_path)
# Default to using sqlite3 but overridable for sub-classes using other
# DB API 2 driver variants
def _conn(self):
return sqlite.connect(self.auth_db_path)
# Default to using sqlite3 syntax but overridable for sub-classes using other
# DB API 2 driver variants
@staticmethod
def fs(sql):
return sql
def user_list(self):
if not self.auth_db_exists():
raise ValueError("Auth DB {} doesn't exist".format(self.auth_db_path))
conn = self._conn()
cursor = conn.cursor()
cursor.execute(self.fs("SELECT username FROM auth"))
rows = cursor.fetchall()
conn.commit()
conn.close()
return [row[0] for row in rows]
def user_exists(self, username):
users = self.user_list()
return username in users
def del_user(self, username):
# Warning, this doesn't remove the user directory or clean it
if not self.auth_db_exists():
raise ValueError("Auth DB {} doesn't exist".format(self.auth_db_path))
conn = self._conn()
cursor = conn.cursor()
logger.info("Removing user '{}' from auth db".format(username))
cursor.execute(self.fs("DELETE FROM auth WHERE username=?"), (username,))
conn.commit()
conn.close()
def add_user(self, username, password):
self._add_user_to_auth_db(username, password)
self._create_user_dir(username)
def add_users(self, users_data):
for username, password in users_data:
self.add_user(username, password)
def _add_user_to_auth_db(self, username, password):
if not self.auth_db_exists():
self.create_auth_db()
pass_hash = self._create_pass_hash(username, password)
conn = self._conn()
cursor = conn.cursor()
logger.info("Adding user '{}' to auth db.".format(username))
cursor.execute(self.fs("INSERT INTO auth VALUES (?, ?)"),
(username, pass_hash))
conn.commit()
conn.close()
def set_password_for_user(self, username, new_password):
if not self.auth_db_exists():
raise ValueError("Auth DB {} doesn't exist".format(self.auth_db_path))
elif not self.user_exists(username):
raise ValueError("User {} doesn't exist".format(username))
hash = self._create_pass_hash(username, new_password)
conn = self._conn()
cursor = conn.cursor()
cursor.execute(self.fs("UPDATE auth SET hash=? WHERE username=?"), (hash, username))
conn.commit()
conn.close()
logger.info("Changed password for user {}".format(username))
def authenticate(self, username, password):
"""Returns True if this username is allowed to connect with this password. False otherwise."""
conn = self._conn()
cursor = conn.cursor()
param = (username,)
cursor.execute(self.fs("SELECT hash FROM auth WHERE username=?"), param)
db_hash = cursor.fetchone()
conn.close()
if db_hash is None:
logger.info("Authentication failed for nonexistent user {}."
.format(username))
return False
expected_value = str(db_hash[0])
salt = self._extract_salt(expected_value)
hashobj = hashlib.sha256()
hashobj.update((username + password + salt).encode())
actual_value = hashobj.hexdigest() + salt
if actual_value == expected_value:
logger.info("Authentication succeeded for user {}".format(username))
return True
else:
logger.info("Authentication failed for user {}".format(username))
return False
@staticmethod
def _extract_salt(hash):
return hash[-16:]
@staticmethod
def _create_pass_hash(username, password):
salt = binascii.b2a_hex(os.urandom(8))
pass_hash = (hashlib.sha256((username + password).encode() + salt).hexdigest() +
salt.decode())
return pass_hash
def create_auth_db(self):
conn = self._conn()
cursor = conn.cursor()
logger.info("Creating auth db at {}."
.format(self.auth_db_path))
cursor.execute(self.fs("""CREATE TABLE IF NOT EXISTS auth
(username VARCHAR PRIMARY KEY, hash VARCHAR)"""))
conn.commit()
conn.close()
def get_user_manager(config):
if "auth_db_path" in config and config["auth_db_path"]:
logger.info("Found auth_db_path in config, using SqliteUserManager for auth")
return SqliteUserManager(config['auth_db_path'], config['data_root'])
elif "user_manager" in config and config["user_manager"]: # load from config
logger.info("Found user_manager in config, using {} for auth".format(config['user_manager']))
import importlib
import inspect
module_name, class_name = config['user_manager'].rsplit('.', 1)
module = importlib.import_module(module_name.strip())
class_ = getattr(module, class_name.strip())
if not SimpleUserManager in inspect.getmro(class_):
raise TypeError('''"user_manager" found in the conf file but it doesn''t
inherit from SimpleUserManager''')
return class_(config)
else:
logger.warning("neither auth_db_path nor user_manager set, ankisyncd will accept any password")
return SimpleUserManager()

View File

@@ -0,0 +1,70 @@
#!/usr/bin/env python3
# -*- coding: utf-8 -*-
"""
This script updates the auth and session sqlite3 databases to use the
more compatible `username` column instead of `user`, which is a reserved
word in many other SQL dialects.
"""
import os
import sys
path = os.path.realpath(os.path.abspath(os.path.join(__file__, '../')))
sys.path.insert(0, os.path.dirname(path))
import sqlite3
import ankisyncd.config
conf = ankisyncd.config.load()
def main():
if os.path.isfile(conf["auth_db_path"]):
conn = sqlite3.connect(conf["auth_db_path"])
cursor = conn.cursor()
cursor.execute("SELECT * FROM sqlite_master "
"WHERE sql LIKE '%user VARCHAR PRIMARY KEY%' "
"AND tbl_name = 'auth'")
res = cursor.fetchone()
if res is not None:
cursor.execute("ALTER TABLE auth RENAME TO auth_old")
cursor.execute("CREATE TABLE auth (username VARCHAR PRIMARY KEY, hash VARCHAR)")
cursor.execute("INSERT INTO auth (username, hash) SELECT user, hash FROM auth_old")
cursor.execute("DROP TABLE auth_old")
conn.commit()
print("Successfully updated table 'auth'")
else:
print("No outdated 'auth' table found.")
conn.close()
else:
print("No auth DB found at the configured 'auth_db_path' path.")
if os.path.isfile(conf["session_db_path"]):
conn = sqlite3.connect(conf["session_db_path"])
cursor = conn.cursor()
cursor.execute("SELECT * FROM sqlite_master "
"WHERE sql LIKE '%user VARCHAR%' "
"AND tbl_name = 'session'")
res = cursor.fetchone()
if res is not None:
cursor.execute("ALTER TABLE session RENAME TO session_old")
cursor.execute("CREATE TABLE session (hkey VARCHAR PRIMARY KEY, skey VARCHAR, "
"username VARCHAR, path VARCHAR)")
cursor.execute("INSERT INTO session (hkey, skey, username, path) "
"SELECT hkey, skey, user, path FROM session_old")
cursor.execute("DROP TABLE session_old")
conn.commit()
print("Successfully updated table 'session'")
else:
print("No outdated 'session' table found.")
conn.close()
else:
print("No session DB found at the configured 'session_db_path' path.")
if __name__ == "__main__":
main()