From f1e8dcfad495b5bfd4f423d49b1cbee345b2a72a Mon Sep 17 00:00:00 2001 From: heimoshuiyu Date: Mon, 13 Dec 2021 14:28:24 +0800 Subject: [PATCH] Add: handle not active user --- pkg/api/handle_error.go | 1 + pkg/api/handle_user.go | 6 ++++++ pkg/database/method_user.go | 4 ++-- pkg/database/sql_stmt.go | 4 ++-- 4 files changed, 11 insertions(+), 4 deletions(-) diff --git a/pkg/api/handle_error.go b/pkg/api/handle_error.go index cdc0363..347dc2a 100644 --- a/pkg/api/handle_error.go +++ b/pkg/api/handle_error.go @@ -12,6 +12,7 @@ var ( ErrNotAdmin = errors.New("not admin") ErrEmpty = errors.New("Empty field detected, please fill in all fields") ErrAnonymous = errors.New("Anonymous user detected, please login") + ErrNotActive = errors.New("User is not active") ) type Error struct { diff --git a/pkg/api/handle_user.go b/pkg/api/handle_user.go index cf06872..cabe04f 100644 --- a/pkg/api/handle_user.go +++ b/pkg/api/handle_user.go @@ -94,6 +94,12 @@ func (api *API) HandleLogin(w http.ResponseWriter, r *http.Request) { } } + // if user is not active + if !user.Active { + api.HandleError(w, r, ErrNotActive) + return + } + // save session session.Values["userId"] = user.ID err = session.Save(r, w) diff --git a/pkg/database/method_user.go b/pkg/database/method_user.go index e5f493c..cb092c1 100644 --- a/pkg/database/method_user.go +++ b/pkg/database/method_user.go @@ -4,7 +4,7 @@ func (database *Database) Login(username string, password string) (*User, error) user := &User{} // get user from database - err := database.stmt.getUser.QueryRow(username, password).Scan(&user.ID, &user.Username, &user.Role, &user.AvatarId) + err := database.stmt.getUser.QueryRow(username, password).Scan(&user.ID, &user.Username, &user.Role, &user.Active, &user.AvatarId) if err != nil { return user, err } @@ -49,7 +49,7 @@ func (database *Database) GetUserById(id int64) (*User, error) { user := &User{} // get user from database - err := database.stmt.getUserById.QueryRow(id).Scan(&user.ID, &user.Username, &user.Role, &user.AvatarId) + err := database.stmt.getUserById.QueryRow(id).Scan(&user.ID, &user.Username, &user.Role, &user.Active, &user.AvatarId) if err != nil { return user, err } diff --git a/pkg/database/sql_stmt.go b/pkg/database/sql_stmt.go index 6dd0156..fa415ac 100644 --- a/pkg/database/sql_stmt.go +++ b/pkg/database/sql_stmt.go @@ -178,11 +178,11 @@ var countUserQuery = `SELECT count(*) FROM users;` var countAdminQuery = `SELECT count(*) FROM users WHERE role= 1;` -var getUserQuery = `SELECT id, username, role, avatar_id FROM users WHERE username = ? AND password = ? LIMIT 1;` +var getUserQuery = `SELECT id, username, role, active, avatar_id FROM users WHERE username = ? AND password = ? LIMIT 1;` var getUsersQuery = `SELECT id, username, role, active, avatar_id FROM users;` -var getUserByIdQuery = `SELECT id, username, role, avatar_id FROM users WHERE id = ? LIMIT 1;` +var getUserByIdQuery = `SELECT id, username, role, active, avatar_id FROM users WHERE id = ? LIMIT 1;` var updateUserActiveQuery = `UPDATE users SET active = ? WHERE id = ?;`