Add: backend session support and bug fix
This commit is contained in:
5
go.mod
5
go.mod
@@ -2,4 +2,7 @@ module msw-open-music
|
|||||||
|
|
||||||
go 1.16
|
go 1.16
|
||||||
|
|
||||||
require github.com/mattn/go-sqlite3 v1.14.7 // indirect
|
require (
|
||||||
|
github.com/gorilla/sessions v1.2.1
|
||||||
|
github.com/mattn/go-sqlite3 v1.14.7
|
||||||
|
)
|
||||||
|
|||||||
4
go.sum
4
go.sum
@@ -1,2 +1,6 @@
|
|||||||
|
github.com/gorilla/securecookie v1.1.1 h1:miw7JPhV+b/lAHSXz4qd/nN9jRiAFV5FwjeKyCS8BvQ=
|
||||||
|
github.com/gorilla/securecookie v1.1.1/go.mod h1:ra0sb63/xPlUeL+yeDciTfxMRAA+MP+HVt/4epWDjd4=
|
||||||
|
github.com/gorilla/sessions v1.2.1 h1:DHd3rPN5lE3Ts3D8rKkQ8x/0kqfeNmBAaiSi+o7FsgI=
|
||||||
|
github.com/gorilla/sessions v1.2.1/go.mod h1:dk2InVEVJ0sfLlnXv9EAgkf6ecYs/i80K/zI+bUmuGM=
|
||||||
github.com/mattn/go-sqlite3 v1.14.7 h1:fxWBnXkxfM6sRiuH3bqJ4CfzZojMOLVc0UTsTglEghA=
|
github.com/mattn/go-sqlite3 v1.14.7 h1:fxWBnXkxfM6sRiuH3bqJ4CfzZojMOLVc0UTsTglEghA=
|
||||||
github.com/mattn/go-sqlite3 v1.14.7/go.mod h1:NyWgC/yNuGj7Q9rpYnZvas74GogHl5/Z4A/KQRfk6bU=
|
github.com/mattn/go-sqlite3 v1.14.7/go.mod h1:NyWgC/yNuGj7Q9rpYnZvas74GogHl5/Z4A/KQRfk6bU=
|
||||||
|
|||||||
@@ -1,9 +1,11 @@
|
|||||||
package api
|
package api
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"github.com/gorilla/sessions"
|
||||||
"msw-open-music/pkg/database"
|
"msw-open-music/pkg/database"
|
||||||
"msw-open-music/pkg/tmpfs"
|
"msw-open-music/pkg/tmpfs"
|
||||||
"net/http"
|
"net/http"
|
||||||
|
"os"
|
||||||
)
|
)
|
||||||
|
|
||||||
type API struct {
|
type API struct {
|
||||||
@@ -12,6 +14,8 @@ type API struct {
|
|||||||
token string
|
token string
|
||||||
APIConfig APIConfig
|
APIConfig APIConfig
|
||||||
Tmpfs *tmpfs.Tmpfs
|
Tmpfs *tmpfs.Tmpfs
|
||||||
|
store *sessions.CookieStore
|
||||||
|
defaultSessionName string
|
||||||
}
|
}
|
||||||
|
|
||||||
func NewAPIConfig() APIConfig {
|
func NewAPIConfig() APIConfig {
|
||||||
@@ -43,6 +47,8 @@ func NewAPI(config Config) (*API, error) {
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
store := sessions.NewCookieStore([]byte(os.Getenv("SESSION_KEY")))
|
||||||
|
|
||||||
mux := http.NewServeMux()
|
mux := http.NewServeMux()
|
||||||
apiMux := http.NewServeMux()
|
apiMux := http.NewServeMux()
|
||||||
|
|
||||||
@@ -53,6 +59,8 @@ func NewAPI(config Config) (*API, error) {
|
|||||||
Handler: mux,
|
Handler: mux,
|
||||||
},
|
},
|
||||||
APIConfig: apiConfig,
|
APIConfig: apiConfig,
|
||||||
|
store: store,
|
||||||
|
defaultSessionName: "msw-open-music",
|
||||||
}
|
}
|
||||||
api.Tmpfs = tmpfs.NewTmpfs(tmpfsConfig)
|
api.Tmpfs = tmpfs.NewTmpfs(tmpfsConfig)
|
||||||
|
|
||||||
|
|||||||
@@ -1,6 +1,7 @@
|
|||||||
package api
|
package api
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"database/sql"
|
||||||
"encoding/json"
|
"encoding/json"
|
||||||
"log"
|
"log"
|
||||||
"msw-open-music/pkg/database"
|
"msw-open-music/pkg/database"
|
||||||
@@ -16,32 +17,85 @@ type LoginResponse struct {
|
|||||||
User *database.User `json:"user"`
|
User *database.User `json:"user"`
|
||||||
}
|
}
|
||||||
|
|
||||||
func (api *API) HandleLogin(w http.ResponseWriter, r *http.Request) {
|
func (api *API) LoginAsAnonymous(w http.ResponseWriter, r *http.Request) {
|
||||||
// Get method will login as anonymous user
|
user, err := api.Db.LoginAsAnonymous()
|
||||||
if r.Method == "GET" {
|
|
||||||
log.Println("Login as anonymous user")
|
|
||||||
user, err := api.Db.LoginAsAnonymous()
|
|
||||||
if err != nil {
|
|
||||||
api.HandleError(w, r, err)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
resp := &LoginResponse{
|
|
||||||
User: user,
|
|
||||||
}
|
|
||||||
err = json.NewEncoder(w).Encode(resp)
|
|
||||||
return
|
|
||||||
}
|
|
||||||
|
|
||||||
var request LoginRequest
|
|
||||||
err := json.NewDecoder(r.Body).Decode(&request)
|
|
||||||
if err != nil {
|
if err != nil {
|
||||||
api.HandleError(w, r, err)
|
api.HandleError(w, r, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
log.Println("Login as user", request.Username)
|
session, _ := api.store.Get(r, api.defaultSessionName)
|
||||||
|
|
||||||
user, err := api.Db.Login(request.Username, request.Password)
|
// save session
|
||||||
|
session.Values["userId"] = user.ID
|
||||||
|
err = session.Save(r, w)
|
||||||
|
if err != nil {
|
||||||
|
api.HandleError(w, r, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
resp := &LoginResponse{
|
||||||
|
User: user,
|
||||||
|
}
|
||||||
|
|
||||||
|
err = json.NewEncoder(w).Encode(resp)
|
||||||
|
if err != nil {
|
||||||
|
api.HandleError(w, r, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
func (api *API) HandleLogin(w http.ResponseWriter, r *http.Request) {
|
||||||
|
var user *database.User
|
||||||
|
var err error
|
||||||
|
session, _ := api.store.Get(r, api.defaultSessionName)
|
||||||
|
log.Println("Session:", session.Values)
|
||||||
|
|
||||||
|
// Get method will login current or anonymous user
|
||||||
|
if r.Method == "GET" {
|
||||||
|
|
||||||
|
// if user already logged in
|
||||||
|
if userId, ok := session.Values["userId"]; ok {
|
||||||
|
user, err = api.Db.GetUserById(userId.(int64))
|
||||||
|
if err != nil {
|
||||||
|
if err != sql.ErrNoRows {
|
||||||
|
api.HandleError(w, r, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
log.Println("User not found")
|
||||||
|
// login as anonymous user
|
||||||
|
api.LoginAsAnonymous(w, r)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
log.Println("User already logged in:", user)
|
||||||
|
|
||||||
|
} else {
|
||||||
|
// login as anonymous user
|
||||||
|
log.Println("Login as anonymous user")
|
||||||
|
api.LoginAsAnonymous(w, r)
|
||||||
|
}
|
||||||
|
|
||||||
|
} else {
|
||||||
|
|
||||||
|
var request LoginRequest
|
||||||
|
err := json.NewDecoder(r.Body).Decode(&request)
|
||||||
|
if err != nil {
|
||||||
|
api.HandleError(w, r, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
|
||||||
|
log.Println("Login as user", request.Username)
|
||||||
|
|
||||||
|
user, err = api.Db.Login(request.Username, request.Password)
|
||||||
|
if err != nil {
|
||||||
|
api.HandleError(w, r, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
|
}
|
||||||
|
|
||||||
|
// save session
|
||||||
|
session.Values["userId"] = user.ID
|
||||||
|
err = session.Save(r, w)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
api.HandleError(w, r, err)
|
api.HandleError(w, r, err)
|
||||||
return
|
return
|
||||||
@@ -73,11 +127,19 @@ func (api *API) HandleRegister(w http.ResponseWriter, r *http.Request) {
|
|||||||
|
|
||||||
log.Println("Register user", request.Username)
|
log.Println("Register user", request.Username)
|
||||||
|
|
||||||
err = api.Db.Register(request.Username, request.Password, request.Role)
|
user, err := api.Db.Register(request.Username, request.Password, request.Role)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
api.HandleError(w, r, err)
|
api.HandleError(w, r, err)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
|
|
||||||
api.HandleOK(w, r)
|
resp := &LoginResponse{
|
||||||
|
User: user,
|
||||||
|
}
|
||||||
|
|
||||||
|
err = json.NewEncoder(w).Encode(resp)
|
||||||
|
if err != nil {
|
||||||
|
api.HandleError(w, r, err)
|
||||||
|
return
|
||||||
|
}
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -22,10 +22,21 @@ func (database *Database) LoginAsAnonymous() (*User, error) {
|
|||||||
return user, nil
|
return user, nil
|
||||||
}
|
}
|
||||||
|
|
||||||
func (database *Database) Register(username string, password string, usertype int64) (error) {
|
func (database *Database) Register(username string, password string, usertype int64) (*User, error) {
|
||||||
_, err := database.stmt.insertUser.Exec(username, password, usertype, 0)
|
_, err := database.stmt.insertUser.Exec(username, password, usertype, 0)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
return err
|
return nil, err
|
||||||
}
|
}
|
||||||
return nil
|
return database.Login(username, password)
|
||||||
|
}
|
||||||
|
|
||||||
|
func (database *Database) GetUserById(id int64) (*User, error) {
|
||||||
|
user := &User{}
|
||||||
|
|
||||||
|
// get user from database
|
||||||
|
err := database.stmt.getUserById.QueryRow(id).Scan(&user.ID, &user.Username, &user.Role, &user.AvatarId)
|
||||||
|
if err != nil {
|
||||||
|
return user, err
|
||||||
|
}
|
||||||
|
return user, nil
|
||||||
}
|
}
|
||||||
|
|||||||
@@ -165,6 +165,8 @@ var countAdminQuery = `SELECT count(*) FROM users WHERE role= 1;`
|
|||||||
|
|
||||||
var getUserQuery = `SELECT id, username, role, avatar_id FROM users WHERE username = ? AND password = ? LIMIT 1;`
|
var getUserQuery = `SELECT id, username, role, avatar_id FROM users WHERE username = ? AND password = ? LIMIT 1;`
|
||||||
|
|
||||||
|
var getUserByIdQuery = `SELECT id, username, role, avatar_id FROM users WHERE id = ? LIMIT 1;`
|
||||||
|
|
||||||
var getAnonymousUserQuery = `SELECT id, username, role, avatar_id FROM users WHERE role = 0 LIMIT 1;`
|
var getAnonymousUserQuery = `SELECT id, username, role, avatar_id FROM users WHERE role = 0 LIMIT 1;`
|
||||||
|
|
||||||
type Stmt struct {
|
type Stmt struct {
|
||||||
@@ -197,6 +199,7 @@ type Stmt struct {
|
|||||||
countUser *sql.Stmt
|
countUser *sql.Stmt
|
||||||
countAdmin *sql.Stmt
|
countAdmin *sql.Stmt
|
||||||
getUser *sql.Stmt
|
getUser *sql.Stmt
|
||||||
|
getUserById *sql.Stmt
|
||||||
getAnonymousUser *sql.Stmt
|
getAnonymousUser *sql.Stmt
|
||||||
}
|
}
|
||||||
|
|
||||||
@@ -429,6 +432,12 @@ func NewPreparedStatement(sqlConn *sql.DB) (*Stmt, error) {
|
|||||||
return nil, err
|
return nil, err
|
||||||
}
|
}
|
||||||
|
|
||||||
|
// init getUserById
|
||||||
|
stmt.getUserById, err = sqlConn.Prepare(getUserByIdQuery)
|
||||||
|
if err != nil {
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
|
||||||
// init getAnonymousUser
|
// init getAnonymousUser
|
||||||
stmt.getAnonymousUser, err = sqlConn.Prepare(getAnonymousUserQuery)
|
stmt.getAnonymousUser, err = sqlConn.Prepare(getAnonymousUserQuery)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
|
|||||||
Reference in New Issue
Block a user