Add: backend session support and bug fix

This commit is contained in:
2021-12-12 01:13:26 +08:00
parent f3a95973e9
commit e608a6b1df
6 changed files with 123 additions and 26 deletions

5
go.mod
View File

@@ -2,4 +2,7 @@ module msw-open-music
go 1.16
require github.com/mattn/go-sqlite3 v1.14.7 // indirect
require (
github.com/gorilla/sessions v1.2.1
github.com/mattn/go-sqlite3 v1.14.7
)

4
go.sum
View File

@@ -1,2 +1,6 @@
github.com/gorilla/securecookie v1.1.1 h1:miw7JPhV+b/lAHSXz4qd/nN9jRiAFV5FwjeKyCS8BvQ=
github.com/gorilla/securecookie v1.1.1/go.mod h1:ra0sb63/xPlUeL+yeDciTfxMRAA+MP+HVt/4epWDjd4=
github.com/gorilla/sessions v1.2.1 h1:DHd3rPN5lE3Ts3D8rKkQ8x/0kqfeNmBAaiSi+o7FsgI=
github.com/gorilla/sessions v1.2.1/go.mod h1:dk2InVEVJ0sfLlnXv9EAgkf6ecYs/i80K/zI+bUmuGM=
github.com/mattn/go-sqlite3 v1.14.7 h1:fxWBnXkxfM6sRiuH3bqJ4CfzZojMOLVc0UTsTglEghA=
github.com/mattn/go-sqlite3 v1.14.7/go.mod h1:NyWgC/yNuGj7Q9rpYnZvas74GogHl5/Z4A/KQRfk6bU=

View File

@@ -1,9 +1,11 @@
package api
import (
"github.com/gorilla/sessions"
"msw-open-music/pkg/database"
"msw-open-music/pkg/tmpfs"
"net/http"
"os"
)
type API struct {
@@ -12,6 +14,8 @@ type API struct {
token string
APIConfig APIConfig
Tmpfs *tmpfs.Tmpfs
store *sessions.CookieStore
defaultSessionName string
}
func NewAPIConfig() APIConfig {
@@ -43,6 +47,8 @@ func NewAPI(config Config) (*API, error) {
return nil, err
}
store := sessions.NewCookieStore([]byte(os.Getenv("SESSION_KEY")))
mux := http.NewServeMux()
apiMux := http.NewServeMux()
@@ -53,6 +59,8 @@ func NewAPI(config Config) (*API, error) {
Handler: mux,
},
APIConfig: apiConfig,
store: store,
defaultSessionName: "msw-open-music",
}
api.Tmpfs = tmpfs.NewTmpfs(tmpfsConfig)

View File

@@ -1,6 +1,7 @@
package api
import (
"database/sql"
"encoding/json"
"log"
"msw-open-music/pkg/database"
@@ -16,21 +17,65 @@ type LoginResponse struct {
User *database.User `json:"user"`
}
func (api *API) HandleLogin(w http.ResponseWriter, r *http.Request) {
// Get method will login as anonymous user
if r.Method == "GET" {
log.Println("Login as anonymous user")
func (api *API) LoginAsAnonymous(w http.ResponseWriter, r *http.Request) {
user, err := api.Db.LoginAsAnonymous()
if err != nil {
api.HandleError(w, r, err)
return
}
session, _ := api.store.Get(r, api.defaultSessionName)
// save session
session.Values["userId"] = user.ID
err = session.Save(r, w)
if err != nil {
api.HandleError(w, r, err)
return
}
resp := &LoginResponse{
User: user,
}
err = json.NewEncoder(w).Encode(resp)
if err != nil {
api.HandleError(w, r, err)
return
}
}
func (api *API) HandleLogin(w http.ResponseWriter, r *http.Request) {
var user *database.User
var err error
session, _ := api.store.Get(r, api.defaultSessionName)
log.Println("Session:", session.Values)
// Get method will login current or anonymous user
if r.Method == "GET" {
// if user already logged in
if userId, ok := session.Values["userId"]; ok {
user, err = api.Db.GetUserById(userId.(int64))
if err != nil {
if err != sql.ErrNoRows {
api.HandleError(w, r, err)
return
}
log.Println("User not found")
// login as anonymous user
api.LoginAsAnonymous(w, r)
return
}
log.Println("User already logged in:", user)
} else {
// login as anonymous user
log.Println("Login as anonymous user")
api.LoginAsAnonymous(w, r)
}
} else {
var request LoginRequest
err := json.NewDecoder(r.Body).Decode(&request)
@@ -41,7 +86,16 @@ func (api *API) HandleLogin(w http.ResponseWriter, r *http.Request) {
log.Println("Login as user", request.Username)
user, err := api.Db.Login(request.Username, request.Password)
user, err = api.Db.Login(request.Username, request.Password)
if err != nil {
api.HandleError(w, r, err)
return
}
}
// save session
session.Values["userId"] = user.ID
err = session.Save(r, w)
if err != nil {
api.HandleError(w, r, err)
return
@@ -73,11 +127,19 @@ func (api *API) HandleRegister(w http.ResponseWriter, r *http.Request) {
log.Println("Register user", request.Username)
err = api.Db.Register(request.Username, request.Password, request.Role)
user, err := api.Db.Register(request.Username, request.Password, request.Role)
if err != nil {
api.HandleError(w, r, err)
return
}
api.HandleOK(w, r)
resp := &LoginResponse{
User: user,
}
err = json.NewEncoder(w).Encode(resp)
if err != nil {
api.HandleError(w, r, err)
return
}
}

View File

@@ -22,10 +22,21 @@ func (database *Database) LoginAsAnonymous() (*User, error) {
return user, nil
}
func (database *Database) Register(username string, password string, usertype int64) (error) {
func (database *Database) Register(username string, password string, usertype int64) (*User, error) {
_, err := database.stmt.insertUser.Exec(username, password, usertype, 0)
if err != nil {
return err
return nil, err
}
return nil
return database.Login(username, password)
}
func (database *Database) GetUserById(id int64) (*User, error) {
user := &User{}
// get user from database
err := database.stmt.getUserById.QueryRow(id).Scan(&user.ID, &user.Username, &user.Role, &user.AvatarId)
if err != nil {
return user, err
}
return user, nil
}

View File

@@ -165,6 +165,8 @@ var countAdminQuery = `SELECT count(*) FROM users WHERE role= 1;`
var getUserQuery = `SELECT id, username, role, avatar_id FROM users WHERE username = ? AND password = ? LIMIT 1;`
var getUserByIdQuery = `SELECT id, username, role, avatar_id FROM users WHERE id = ? LIMIT 1;`
var getAnonymousUserQuery = `SELECT id, username, role, avatar_id FROM users WHERE role = 0 LIMIT 1;`
type Stmt struct {
@@ -197,6 +199,7 @@ type Stmt struct {
countUser *sql.Stmt
countAdmin *sql.Stmt
getUser *sql.Stmt
getUserById *sql.Stmt
getAnonymousUser *sql.Stmt
}
@@ -429,6 +432,12 @@ func NewPreparedStatement(sqlConn *sql.DB) (*Stmt, error) {
return nil, err
}
// init getUserById
stmt.getUserById, err = sqlConn.Prepare(getUserByIdQuery)
if err != nil {
return nil, err
}
// init getAnonymousUser
stmt.getAnonymousUser, err = sqlConn.Prepare(getAnonymousUserQuery)
if err != nil {