diff --git a/go.mod b/go.mod index c87cde8..4e6b581 100644 --- a/go.mod +++ b/go.mod @@ -2,4 +2,7 @@ module msw-open-music go 1.16 -require github.com/mattn/go-sqlite3 v1.14.7 // indirect +require ( + github.com/gorilla/sessions v1.2.1 + github.com/mattn/go-sqlite3 v1.14.7 +) diff --git a/go.sum b/go.sum index 96ff824..26f85c8 100644 --- a/go.sum +++ b/go.sum @@ -1,2 +1,6 @@ +github.com/gorilla/securecookie v1.1.1 h1:miw7JPhV+b/lAHSXz4qd/nN9jRiAFV5FwjeKyCS8BvQ= +github.com/gorilla/securecookie v1.1.1/go.mod h1:ra0sb63/xPlUeL+yeDciTfxMRAA+MP+HVt/4epWDjd4= +github.com/gorilla/sessions v1.2.1 h1:DHd3rPN5lE3Ts3D8rKkQ8x/0kqfeNmBAaiSi+o7FsgI= +github.com/gorilla/sessions v1.2.1/go.mod h1:dk2InVEVJ0sfLlnXv9EAgkf6ecYs/i80K/zI+bUmuGM= github.com/mattn/go-sqlite3 v1.14.7 h1:fxWBnXkxfM6sRiuH3bqJ4CfzZojMOLVc0UTsTglEghA= github.com/mattn/go-sqlite3 v1.14.7/go.mod h1:NyWgC/yNuGj7Q9rpYnZvas74GogHl5/Z4A/KQRfk6bU= diff --git a/pkg/api/api.go b/pkg/api/api.go index ba7c2f9..b81e80e 100644 --- a/pkg/api/api.go +++ b/pkg/api/api.go @@ -1,9 +1,11 @@ package api import ( + "github.com/gorilla/sessions" "msw-open-music/pkg/database" "msw-open-music/pkg/tmpfs" "net/http" + "os" ) type API struct { @@ -12,6 +14,8 @@ type API struct { token string APIConfig APIConfig Tmpfs *tmpfs.Tmpfs + store *sessions.CookieStore + defaultSessionName string } func NewAPIConfig() APIConfig { @@ -43,6 +47,8 @@ func NewAPI(config Config) (*API, error) { return nil, err } + store := sessions.NewCookieStore([]byte(os.Getenv("SESSION_KEY"))) + mux := http.NewServeMux() apiMux := http.NewServeMux() @@ -53,6 +59,8 @@ func NewAPI(config Config) (*API, error) { Handler: mux, }, APIConfig: apiConfig, + store: store, + defaultSessionName: "msw-open-music", } api.Tmpfs = tmpfs.NewTmpfs(tmpfsConfig) diff --git a/pkg/api/handle_user.go b/pkg/api/handle_user.go index 92d19be..121d06d 100644 --- a/pkg/api/handle_user.go +++ b/pkg/api/handle_user.go @@ -1,6 +1,7 @@ package api import ( + "database/sql" "encoding/json" "log" "msw-open-music/pkg/database" @@ -16,32 +17,85 @@ type LoginResponse struct { User *database.User `json:"user"` } -func (api *API) HandleLogin(w http.ResponseWriter, r *http.Request) { - // Get method will login as anonymous user - if r.Method == "GET" { - log.Println("Login as anonymous user") - user, err := api.Db.LoginAsAnonymous() - if err != nil { - api.HandleError(w, r, err) - return - } - resp := &LoginResponse{ - User: user, - } - err = json.NewEncoder(w).Encode(resp) - return - } - - var request LoginRequest - err := json.NewDecoder(r.Body).Decode(&request) +func (api *API) LoginAsAnonymous(w http.ResponseWriter, r *http.Request) { + user, err := api.Db.LoginAsAnonymous() if err != nil { api.HandleError(w, r, err) return } - log.Println("Login as user", request.Username) + session, _ := api.store.Get(r, api.defaultSessionName) - user, err := api.Db.Login(request.Username, request.Password) + // save session + session.Values["userId"] = user.ID + err = session.Save(r, w) + if err != nil { + api.HandleError(w, r, err) + return + } + + resp := &LoginResponse{ + User: user, + } + + err = json.NewEncoder(w).Encode(resp) + if err != nil { + api.HandleError(w, r, err) + return + } +} + +func (api *API) HandleLogin(w http.ResponseWriter, r *http.Request) { + var user *database.User + var err error + session, _ := api.store.Get(r, api.defaultSessionName) + log.Println("Session:", session.Values) + + // Get method will login current or anonymous user + if r.Method == "GET" { + + // if user already logged in + if userId, ok := session.Values["userId"]; ok { + user, err = api.Db.GetUserById(userId.(int64)) + if err != nil { + if err != sql.ErrNoRows { + api.HandleError(w, r, err) + return + } + log.Println("User not found") + // login as anonymous user + api.LoginAsAnonymous(w, r) + return + } + log.Println("User already logged in:", user) + + } else { + // login as anonymous user + log.Println("Login as anonymous user") + api.LoginAsAnonymous(w, r) + } + + } else { + + var request LoginRequest + err := json.NewDecoder(r.Body).Decode(&request) + if err != nil { + api.HandleError(w, r, err) + return + } + + log.Println("Login as user", request.Username) + + user, err = api.Db.Login(request.Username, request.Password) + if err != nil { + api.HandleError(w, r, err) + return + } + } + + // save session + session.Values["userId"] = user.ID + err = session.Save(r, w) if err != nil { api.HandleError(w, r, err) return @@ -73,11 +127,19 @@ func (api *API) HandleRegister(w http.ResponseWriter, r *http.Request) { log.Println("Register user", request.Username) - err = api.Db.Register(request.Username, request.Password, request.Role) + user, err := api.Db.Register(request.Username, request.Password, request.Role) if err != nil { api.HandleError(w, r, err) return } - api.HandleOK(w, r) + resp := &LoginResponse{ + User: user, + } + + err = json.NewEncoder(w).Encode(resp) + if err != nil { + api.HandleError(w, r, err) + return + } } diff --git a/pkg/database/method_user.go b/pkg/database/method_user.go index ecd3f5c..01a4571 100644 --- a/pkg/database/method_user.go +++ b/pkg/database/method_user.go @@ -22,10 +22,21 @@ func (database *Database) LoginAsAnonymous() (*User, error) { return user, nil } -func (database *Database) Register(username string, password string, usertype int64) (error) { +func (database *Database) Register(username string, password string, usertype int64) (*User, error) { _, err := database.stmt.insertUser.Exec(username, password, usertype, 0) if err != nil { - return err + return nil, err } - return nil + return database.Login(username, password) +} + +func (database *Database) GetUserById(id int64) (*User, error) { + user := &User{} + + // get user from database + err := database.stmt.getUserById.QueryRow(id).Scan(&user.ID, &user.Username, &user.Role, &user.AvatarId) + if err != nil { + return user, err + } + return user, nil } diff --git a/pkg/database/sql_stmt.go b/pkg/database/sql_stmt.go index e4bad08..a04b2ef 100644 --- a/pkg/database/sql_stmt.go +++ b/pkg/database/sql_stmt.go @@ -165,6 +165,8 @@ var countAdminQuery = `SELECT count(*) FROM users WHERE role= 1;` var getUserQuery = `SELECT id, username, role, avatar_id FROM users WHERE username = ? AND password = ? LIMIT 1;` +var getUserByIdQuery = `SELECT id, username, role, avatar_id FROM users WHERE id = ? LIMIT 1;` + var getAnonymousUserQuery = `SELECT id, username, role, avatar_id FROM users WHERE role = 0 LIMIT 1;` type Stmt struct { @@ -197,6 +199,7 @@ type Stmt struct { countUser *sql.Stmt countAdmin *sql.Stmt getUser *sql.Stmt + getUserById *sql.Stmt getAnonymousUser *sql.Stmt } @@ -429,6 +432,12 @@ func NewPreparedStatement(sqlConn *sql.DB) (*Stmt, error) { return nil, err } + // init getUserById + stmt.getUserById, err = sqlConn.Prepare(getUserByIdQuery) + if err != nil { + return nil, err + } + // init getAnonymousUser stmt.getAnonymousUser, err = sqlConn.Prepare(getAnonymousUserQuery) if err != nil {