diff --git a/server/ctrladmin/handlers.go b/server/ctrladmin/handlers.go index ce11fb8..3908261 100644 --- a/server/ctrladmin/handlers.go +++ b/server/ctrladmin/handlers.go @@ -151,7 +151,7 @@ func (c *Controller) ServeChangePassword(r *http.Request) *Response { code: 400, } } - user := c.DB.GetUserFromName(username) + user := c.DB.GetUserByName(username) if user == nil { return &Response{ err: "couldn't find a user with that name", @@ -176,7 +176,7 @@ func (c *Controller) ServeChangePasswordDo(r *http.Request) *Response { flashW: []string{err.Error()}, } } - user := c.DB.GetUserFromName(username) + user := c.DB.GetUserByName(username) user.Password = passwordOne c.DB.Save(user) return &Response{redirect: "/admin/home"} @@ -190,7 +190,7 @@ func (c *Controller) ServeDeleteUser(r *http.Request) *Response { code: 400, } } - user := c.DB.GetUserFromName(username) + user := c.DB.GetUserByName(username) if user == nil { return &Response{ err: "couldn't find a user with that name", @@ -207,7 +207,7 @@ func (c *Controller) ServeDeleteUser(r *http.Request) *Response { func (c *Controller) ServeDeleteUserDo(r *http.Request) *Response { username := r.URL.Query().Get("user") - user := c.DB.GetUserFromName(username) + user := c.DB.GetUserByName(username) if user.IsAdmin { return &Response{ redirect: "/admin/home", diff --git a/server/ctrladmin/handlers_raw.go b/server/ctrladmin/handlers_raw.go index 8232b0a..adbf1eb 100644 --- a/server/ctrladmin/handlers_raw.go +++ b/server/ctrladmin/handlers_raw.go @@ -16,7 +16,7 @@ func (c *Controller) ServeLoginDo(w http.ResponseWriter, r *http.Request) { http.Redirect(w, r, r.Referer(), http.StatusSeeOther) return } - user := c.DB.GetUserFromName(username) + user := c.DB.GetUserByName(username) if user == nil || password != user.Password { sessAddFlashW(session, []string{"invalid username / password"}) sessLogSave(session, w, r) @@ -26,7 +26,7 @@ func (c *Controller) ServeLoginDo(w http.ResponseWriter, r *http.Request) { // put the user name into the session. future endpoints after this one // are wrapped with WithUserSession() which will get the name from the // session and put the row into the request context - session.Values["user"] = user.Name + session.Values["user"] = user.ID sessLogSave(session, w, r) http.Redirect(w, r, c.Path("/admin/home"), http.StatusSeeOther) } diff --git a/server/ctrladmin/middleware.go b/server/ctrladmin/middleware.go index b5472f3..e89932f 100644 --- a/server/ctrladmin/middleware.go +++ b/server/ctrladmin/middleware.go @@ -21,7 +21,7 @@ func (c *Controller) WithUserSession(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { // session exists at this point session := r.Context().Value(CtxSession).(*sessions.Session) - username, ok := session.Values["user"].(string) + userID, ok := session.Values["user"].(int) if !ok { sessAddFlashW(session, []string{"you are not authenticated"}) sessLogSave(session, w, r) @@ -29,7 +29,7 @@ func (c *Controller) WithUserSession(next http.Handler) http.Handler { return } // take username from sesion and add the user row to the context - user := c.DB.GetUserFromName(username) + user := c.DB.GetUserByID(userID) if user == nil { // the username in the client's session no longer relates to a // user in the database (maybe the user was deleted) diff --git a/server/ctrlsubsonic/middleware.go b/server/ctrlsubsonic/middleware.go index 3d4daf4..5dc2a7f 100644 --- a/server/ctrlsubsonic/middleware.go +++ b/server/ctrlsubsonic/middleware.go @@ -67,7 +67,7 @@ func (c *Controller) WithUser(next http.Handler) http.Handler { "please provide `t` and `s`, or just `p`")) return } - user := c.DB.GetUserFromName(username) + user := c.DB.GetUserByName(username) if user == nil { _ = writeResp(w, r, spec.NewError(40, "invalid username `%s`", username)) diff --git a/server/db/db.go b/server/db/db.go index d1bf57a..6945125 100644 --- a/server/db/db.go +++ b/server/db/db.go @@ -111,7 +111,19 @@ func (db *DB) GetOrCreateKey(key string) string { return value } -func (db *DB) GetUserFromName(name string) *User { +func (db *DB) GetUserByID(id int) *User { + user := &User{} + err := db. + Where("id=?", id). + First(user). + Error + if gorm.IsRecordNotFoundError(err) { + return nil + } + return user +} + +func (db *DB) GetUserByName(name string) *User { user := &User{} err := db. Where("name=?", name).