diff --git a/pkg/database/method_user.go b/pkg/database/method_user.go index 01a4571..50762fd 100644 --- a/pkg/database/method_user.go +++ b/pkg/database/method_user.go @@ -23,7 +23,22 @@ func (database *Database) LoginAsAnonymous() (*User, error) { } func (database *Database) Register(username string, password string, usertype int64) (*User, error) { - _, err := database.stmt.insertUser.Exec(username, password, usertype, 0) + countAdmin, err := database.CountAdmin() + if err != nil { + return nil, err + } + + active := false + if countAdmin == 0 { + active = true + } + + // active normal user by default + if usertype == 2 { + active = true + } + + _, err = database.stmt.insertUser.Exec(username, password, usertype, active, 0) if err != nil { return nil, err } @@ -40,3 +55,12 @@ func (database *Database) GetUserById(id int64) (*User, error) { } return user, nil } + +func (database *Database) CountAdmin() (int64, error) { + var count int64 + err := database.stmt.countAdmin.QueryRow().Scan(&count) + if err != nil { + return 0, err + } + return count, nil +} diff --git a/pkg/database/sql_stmt.go b/pkg/database/sql_stmt.go index 9e9d50a..561e1c1 100644 --- a/pkg/database/sql_stmt.go +++ b/pkg/database/sql_stmt.go @@ -32,6 +32,7 @@ var initUsersTableQuery = `CREATE TABLE IF NOT EXISTS users ( username TEXT NOT NULL UNIQUE, password TEXT NOT NULL, role INTEGER NOT NULL, + active BOOLEAN NOT NULL, avatar_id INTEGER NOT NULL, FOREIGN KEY(avatar_id) REFERENCES avatars(id) );` @@ -170,8 +171,8 @@ LIMIT ?;` var insertFeedbackQuery = `INSERT INTO feedbacks (time, feedback, header) VALUES (?, ?, ?);` -var insertUserQuery = `INSERT INTO users (username, password, role, avatar_id) -VALUES (?, ?, ?, ?);` +var insertUserQuery = `INSERT INTO users (username, password, role, active, avatar_id) +VALUES (?, ?, ?, ?, ?);` var countUserQuery = `SELECT count(*) FROM users;`