Add: backend session support and bug fix
This commit is contained in:
5
go.mod
5
go.mod
@@ -2,4 +2,7 @@ module msw-open-music
|
||||
|
||||
go 1.16
|
||||
|
||||
require github.com/mattn/go-sqlite3 v1.14.7 // indirect
|
||||
require (
|
||||
github.com/gorilla/sessions v1.2.1
|
||||
github.com/mattn/go-sqlite3 v1.14.7
|
||||
)
|
||||
|
||||
4
go.sum
4
go.sum
@@ -1,2 +1,6 @@
|
||||
github.com/gorilla/securecookie v1.1.1 h1:miw7JPhV+b/lAHSXz4qd/nN9jRiAFV5FwjeKyCS8BvQ=
|
||||
github.com/gorilla/securecookie v1.1.1/go.mod h1:ra0sb63/xPlUeL+yeDciTfxMRAA+MP+HVt/4epWDjd4=
|
||||
github.com/gorilla/sessions v1.2.1 h1:DHd3rPN5lE3Ts3D8rKkQ8x/0kqfeNmBAaiSi+o7FsgI=
|
||||
github.com/gorilla/sessions v1.2.1/go.mod h1:dk2InVEVJ0sfLlnXv9EAgkf6ecYs/i80K/zI+bUmuGM=
|
||||
github.com/mattn/go-sqlite3 v1.14.7 h1:fxWBnXkxfM6sRiuH3bqJ4CfzZojMOLVc0UTsTglEghA=
|
||||
github.com/mattn/go-sqlite3 v1.14.7/go.mod h1:NyWgC/yNuGj7Q9rpYnZvas74GogHl5/Z4A/KQRfk6bU=
|
||||
|
||||
@@ -1,9 +1,11 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"github.com/gorilla/sessions"
|
||||
"msw-open-music/pkg/database"
|
||||
"msw-open-music/pkg/tmpfs"
|
||||
"net/http"
|
||||
"os"
|
||||
)
|
||||
|
||||
type API struct {
|
||||
@@ -12,6 +14,8 @@ type API struct {
|
||||
token string
|
||||
APIConfig APIConfig
|
||||
Tmpfs *tmpfs.Tmpfs
|
||||
store *sessions.CookieStore
|
||||
defaultSessionName string
|
||||
}
|
||||
|
||||
func NewAPIConfig() APIConfig {
|
||||
@@ -43,6 +47,8 @@ func NewAPI(config Config) (*API, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
store := sessions.NewCookieStore([]byte(os.Getenv("SESSION_KEY")))
|
||||
|
||||
mux := http.NewServeMux()
|
||||
apiMux := http.NewServeMux()
|
||||
|
||||
@@ -53,6 +59,8 @@ func NewAPI(config Config) (*API, error) {
|
||||
Handler: mux,
|
||||
},
|
||||
APIConfig: apiConfig,
|
||||
store: store,
|
||||
defaultSessionName: "msw-open-music",
|
||||
}
|
||||
api.Tmpfs = tmpfs.NewTmpfs(tmpfsConfig)
|
||||
|
||||
|
||||
@@ -1,6 +1,7 @@
|
||||
package api
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"encoding/json"
|
||||
"log"
|
||||
"msw-open-music/pkg/database"
|
||||
@@ -16,32 +17,85 @@ type LoginResponse struct {
|
||||
User *database.User `json:"user"`
|
||||
}
|
||||
|
||||
func (api *API) HandleLogin(w http.ResponseWriter, r *http.Request) {
|
||||
// Get method will login as anonymous user
|
||||
if r.Method == "GET" {
|
||||
log.Println("Login as anonymous user")
|
||||
user, err := api.Db.LoginAsAnonymous()
|
||||
if err != nil {
|
||||
api.HandleError(w, r, err)
|
||||
return
|
||||
}
|
||||
resp := &LoginResponse{
|
||||
User: user,
|
||||
}
|
||||
err = json.NewEncoder(w).Encode(resp)
|
||||
return
|
||||
}
|
||||
|
||||
var request LoginRequest
|
||||
err := json.NewDecoder(r.Body).Decode(&request)
|
||||
func (api *API) LoginAsAnonymous(w http.ResponseWriter, r *http.Request) {
|
||||
user, err := api.Db.LoginAsAnonymous()
|
||||
if err != nil {
|
||||
api.HandleError(w, r, err)
|
||||
return
|
||||
}
|
||||
|
||||
log.Println("Login as user", request.Username)
|
||||
session, _ := api.store.Get(r, api.defaultSessionName)
|
||||
|
||||
user, err := api.Db.Login(request.Username, request.Password)
|
||||
// save session
|
||||
session.Values["userId"] = user.ID
|
||||
err = session.Save(r, w)
|
||||
if err != nil {
|
||||
api.HandleError(w, r, err)
|
||||
return
|
||||
}
|
||||
|
||||
resp := &LoginResponse{
|
||||
User: user,
|
||||
}
|
||||
|
||||
err = json.NewEncoder(w).Encode(resp)
|
||||
if err != nil {
|
||||
api.HandleError(w, r, err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
func (api *API) HandleLogin(w http.ResponseWriter, r *http.Request) {
|
||||
var user *database.User
|
||||
var err error
|
||||
session, _ := api.store.Get(r, api.defaultSessionName)
|
||||
log.Println("Session:", session.Values)
|
||||
|
||||
// Get method will login current or anonymous user
|
||||
if r.Method == "GET" {
|
||||
|
||||
// if user already logged in
|
||||
if userId, ok := session.Values["userId"]; ok {
|
||||
user, err = api.Db.GetUserById(userId.(int64))
|
||||
if err != nil {
|
||||
if err != sql.ErrNoRows {
|
||||
api.HandleError(w, r, err)
|
||||
return
|
||||
}
|
||||
log.Println("User not found")
|
||||
// login as anonymous user
|
||||
api.LoginAsAnonymous(w, r)
|
||||
return
|
||||
}
|
||||
log.Println("User already logged in:", user)
|
||||
|
||||
} else {
|
||||
// login as anonymous user
|
||||
log.Println("Login as anonymous user")
|
||||
api.LoginAsAnonymous(w, r)
|
||||
}
|
||||
|
||||
} else {
|
||||
|
||||
var request LoginRequest
|
||||
err := json.NewDecoder(r.Body).Decode(&request)
|
||||
if err != nil {
|
||||
api.HandleError(w, r, err)
|
||||
return
|
||||
}
|
||||
|
||||
log.Println("Login as user", request.Username)
|
||||
|
||||
user, err = api.Db.Login(request.Username, request.Password)
|
||||
if err != nil {
|
||||
api.HandleError(w, r, err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
// save session
|
||||
session.Values["userId"] = user.ID
|
||||
err = session.Save(r, w)
|
||||
if err != nil {
|
||||
api.HandleError(w, r, err)
|
||||
return
|
||||
@@ -73,11 +127,19 @@ func (api *API) HandleRegister(w http.ResponseWriter, r *http.Request) {
|
||||
|
||||
log.Println("Register user", request.Username)
|
||||
|
||||
err = api.Db.Register(request.Username, request.Password, request.Role)
|
||||
user, err := api.Db.Register(request.Username, request.Password, request.Role)
|
||||
if err != nil {
|
||||
api.HandleError(w, r, err)
|
||||
return
|
||||
}
|
||||
|
||||
api.HandleOK(w, r)
|
||||
resp := &LoginResponse{
|
||||
User: user,
|
||||
}
|
||||
|
||||
err = json.NewEncoder(w).Encode(resp)
|
||||
if err != nil {
|
||||
api.HandleError(w, r, err)
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
@@ -22,10 +22,21 @@ func (database *Database) LoginAsAnonymous() (*User, error) {
|
||||
return user, nil
|
||||
}
|
||||
|
||||
func (database *Database) Register(username string, password string, usertype int64) (error) {
|
||||
func (database *Database) Register(username string, password string, usertype int64) (*User, error) {
|
||||
_, err := database.stmt.insertUser.Exec(username, password, usertype, 0)
|
||||
if err != nil {
|
||||
return err
|
||||
return nil, err
|
||||
}
|
||||
return nil
|
||||
return database.Login(username, password)
|
||||
}
|
||||
|
||||
func (database *Database) GetUserById(id int64) (*User, error) {
|
||||
user := &User{}
|
||||
|
||||
// get user from database
|
||||
err := database.stmt.getUserById.QueryRow(id).Scan(&user.ID, &user.Username, &user.Role, &user.AvatarId)
|
||||
if err != nil {
|
||||
return user, err
|
||||
}
|
||||
return user, nil
|
||||
}
|
||||
|
||||
@@ -165,6 +165,8 @@ var countAdminQuery = `SELECT count(*) FROM users WHERE role= 1;`
|
||||
|
||||
var getUserQuery = `SELECT id, username, role, avatar_id FROM users WHERE username = ? AND password = ? LIMIT 1;`
|
||||
|
||||
var getUserByIdQuery = `SELECT id, username, role, avatar_id FROM users WHERE id = ? LIMIT 1;`
|
||||
|
||||
var getAnonymousUserQuery = `SELECT id, username, role, avatar_id FROM users WHERE role = 0 LIMIT 1;`
|
||||
|
||||
type Stmt struct {
|
||||
@@ -197,6 +199,7 @@ type Stmt struct {
|
||||
countUser *sql.Stmt
|
||||
countAdmin *sql.Stmt
|
||||
getUser *sql.Stmt
|
||||
getUserById *sql.Stmt
|
||||
getAnonymousUser *sql.Stmt
|
||||
}
|
||||
|
||||
@@ -429,6 +432,12 @@ func NewPreparedStatement(sqlConn *sql.DB) (*Stmt, error) {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// init getUserById
|
||||
stmt.getUserById, err = sqlConn.Prepare(getUserByIdQuery)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
// init getAnonymousUser
|
||||
stmt.getAnonymousUser, err = sqlConn.Prepare(getAnonymousUserQuery)
|
||||
if err != nil {
|
||||
|
||||
Reference in New Issue
Block a user