Use factory method for user_manager

Also add some abstraction over the SQL to allow for different SQL
dialects
This commit is contained in:
Anton Melser
2019-01-28 21:17:40 +08:00
parent bfeaeae2e5
commit ea0cbc669b
4 changed files with 102 additions and 27 deletions

View File

@@ -46,16 +46,29 @@ class SqliteUserManager(SimpleUserManager):
SimpleUserManager.__init__(self, collection_path)
self.auth_db_path = os.path.realpath(auth_db_path)
# Default to using sqlite3 but overridable for sub-classes using other
# DB API 2 driver variants
def auth_db_exists(self):
return os.path.isfile(self.auth_db_path)
# Default to using sqlite3 but overridable for sub-classes using other
# DB API 2 driver variants
def _conn(self):
return sqlite.connect(self.auth_db_path)
# Default to using sqlite3 syntax but overridable for sub-classes using other
# DB API 2 driver variants
@staticmethod
def fs(sql):
return sql
def user_list(self):
if not self.auth_db_exists():
raise ValueError("Auth DB {} doesn't exist".format(self.auth_db_path))
conn = sqlite.connect(self.auth_db_path)
conn = self._conn()
cursor = conn.cursor()
cursor.execute("SELECT user FROM auth")
cursor.execute(self.fs("SELECT username FROM auth"))
rows = cursor.fetchall()
conn.commit()
conn.close()
@@ -67,13 +80,14 @@ class SqliteUserManager(SimpleUserManager):
return username in users
def del_user(self, username):
# Warning, this doesn't remove the user directory or clean it
if not self.auth_db_exists():
raise ValueError("Auth DB {} doesn't exist".format(self.auth_db_path))
conn = sqlite.connect(self.auth_db_path)
conn = self._conn()
cursor = conn.cursor()
logger.info("Removing user '{}' from auth db".format(username))
cursor.execute("DELETE FROM auth WHERE user=?", (username,))
cursor.execute(self.fs("DELETE FROM auth WHERE username=?"), (username,))
conn.commit()
conn.close()
@@ -91,10 +105,10 @@ class SqliteUserManager(SimpleUserManager):
pass_hash = self._create_pass_hash(username, password)
conn = sqlite.connect(self.auth_db_path)
conn = self._conn()
cursor = conn.cursor()
logger.info("Adding user '{}' to auth db.".format(username))
cursor.execute("INSERT INTO auth VALUES (?, ?)",
cursor.execute(self.fs("INSERT INTO auth VALUES (?, ?)"),
(username, pass_hash))
conn.commit()
conn.close()
@@ -107,9 +121,9 @@ class SqliteUserManager(SimpleUserManager):
hash = self._create_pass_hash(username, new_password)
conn = sqlite.connect(self.auth_db_path)
conn = self._conn()
cursor = conn.cursor()
cursor.execute("UPDATE auth SET hash=? WHERE user=?", (hash, username))
cursor.execute(self.fs("UPDATE auth SET hash=? WHERE username=?"), (hash, username))
conn.commit()
conn.close()
@@ -118,10 +132,10 @@ class SqliteUserManager(SimpleUserManager):
def authenticate(self, username, password):
"""Returns True if this username is allowed to connect with this password. False otherwise."""
conn = sqlite.connect(self.auth_db_path)
conn = self._conn()
cursor = conn.cursor()
param = (username,)
cursor.execute("SELECT hash FROM auth WHERE user=?", param)
cursor.execute(self.fs("SELECT hash FROM auth WHERE username=?"), param)
db_hash = cursor.fetchone()
conn.close()
@@ -156,11 +170,32 @@ class SqliteUserManager(SimpleUserManager):
return pass_hash
def create_auth_db(self):
conn = sqlite.connect(self.auth_db_path)
conn = self._conn()
cursor = conn.cursor()
logger.info("Creating auth db at {}."
.format(self.auth_db_path))
cursor.execute("""CREATE TABLE IF NOT EXISTS auth
(user VARCHAR PRIMARY KEY, hash VARCHAR)""")
cursor.execute(self.fs("""CREATE TABLE IF NOT EXISTS auth
(username VARCHAR PRIMARY KEY, hash VARCHAR)"""))
conn.commit()
conn.close()
def get_user_manager(config):
if "auth_db_path" in config and config["auth_db_path"]:
logger.info("Found auth_db_path in config, using SqliteUserManager for auth")
return SqliteUserManager(config['auth_db_path'], config['data_root'])
elif "user_manager" in config and config["user_manager"]: # load from config
logger.info("Found user_manager in config, using {} for auth".format(config['user_manager']))
import importlib
import inspect
module_name, class_name = config['user_manager'].rsplit('.', 1)
module = importlib.import_module(module_name.strip())
class_ = getattr(module, class_name.strip())
if not SimpleUserManager in inspect.getmro(class_):
raise TypeError('''"user_manager" found in the conf file but it doesn''t
inherit from SimpleUserManager''')
return class_(config)
else:
logger.warning("neither auth_db_path nor user_manager set, ankisyncd will accept any password")
return SimpleUserManager()