From 9c2f2e381b09e1523da79e6f14345e875b215fc2 Mon Sep 17 00:00:00 2001 From: sentriz Date: Wed, 7 Aug 2019 13:53:02 +0100 Subject: [PATCH] remove response writer from most admin handlers --- server/ctrladmin/ctrl.go | 42 ++++++++--- server/ctrladmin/handlers.go | 126 ++++++++++++------------------- server/ctrladmin/handlers_raw.go | 41 ++++++++++ server/ctrladmin/middleware.go | 10 +-- server/ctrlsubsonic/ctrl.go | 2 +- server/ctrlsubsonic/testdata/db | Bin 114688 -> 114688 bytes server/server.go | 4 +- 7 files changed, 128 insertions(+), 97 deletions(-) create mode 100644 server/ctrladmin/handlers_raw.go diff --git a/server/ctrladmin/ctrl.go b/server/ctrladmin/ctrl.go index 2574773..5a8a8aa 100644 --- a/server/ctrladmin/ctrl.go +++ b/server/ctrladmin/ctrl.go @@ -103,7 +103,7 @@ type templateData struct { SelectedUser *model.User } -type adminHandler func(w http.ResponseWriter, r *http.Request) *Response +type adminHandler func(r *http.Request) *Response type Response struct { // code is 200 @@ -111,6 +111,8 @@ type Response struct { data *templateData // code is 303 redirect string + flashN string // normal + flashW string // warning // code is >= 400 code int err string @@ -118,7 +120,16 @@ type Response struct { func (c *Controller) H(h adminHandler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - resp := h(w, r) + resp := h(r) + session, ok := r.Context().Value(key.Session).(*sessions.Session) + if ok { + sessAddFlashN(session, resp.flashN) + sessAddFlashW(session, resp.flashW) + if err := session.Save(r, w); err != nil { + http.Error(w, fmt.Sprint("error saving session: %v", err), 500) + return + } + } if resp.redirect != "" { http.Redirect(w, r, resp.redirect, http.StatusSeeOther) return @@ -134,9 +145,12 @@ func (c *Controller) H(h adminHandler) http.Handler { if resp.data == nil { resp.data = &templateData{} } - if session, ok := r.Context().Value(key.Session).(*sessions.Session); ok { + if session != nil { resp.data.Flashes = session.Flashes() - sessLogSave(w, r, session) + if err := session.Save(r, w); err != nil { + http.Error(w, fmt.Sprint("error saving session: %v", err), 500) + return + } } if user, ok := r.Context().Value(key.User).(*model.User); ok { resp.data.User = user @@ -172,7 +186,7 @@ func firstExisting(or string, strings ...string) string { return or } -func sessLogSave(w http.ResponseWriter, r *http.Request, s *sessions.Session) { +func sessLogSave(s *sessions.Session, w http.ResponseWriter, r *http.Request) { if err := s.Save(r, w); err != nil { log.Printf("error saving session: %v\n", err) } @@ -183,26 +197,32 @@ type Flash struct { Type string } -func sessAddFlashW(message string, s *sessions.Session) { +func sessAddFlashW(s *sessions.Session, message string) { + if message == "" { + return + } s.AddFlash(Flash{ Message: message, Type: "warning", }) } -func sessAddFlashWf(message string, s *sessions.Session, a ...interface{}) { - sessAddFlashW(fmt.Sprintf(message, a...), s) +func sessAddFlashWf(s *sessions.Session, message string, a ...interface{}) { + sessAddFlashW(s, fmt.Sprintf(message, a...)) } -func sessAddFlashN(message string, s *sessions.Session) { +func sessAddFlashN(s *sessions.Session, message string) { + if message == "" { + return + } s.AddFlash(Flash{ Message: message, Type: "normal", }) } -func sessAddFlashNf(message string, s *sessions.Session, a ...interface{}) { - sessAddFlashN(fmt.Sprintf(message, a...), s) +func sessAddFlashNf(s *sessions.Session, message string, a ...interface{}) { + sessAddFlashN(s, fmt.Sprintf(message, a...)) } // ## begin validation diff --git a/server/ctrladmin/handlers.go b/server/ctrladmin/handlers.go index a6d39a9..5023a3e 100644 --- a/server/ctrladmin/handlers.go +++ b/server/ctrladmin/handlers.go @@ -7,53 +7,21 @@ import ( "strconv" "time" - "github.com/gorilla/sessions" - "senan.xyz/g/gonic/model" "senan.xyz/g/gonic/scanner" "senan.xyz/g/gonic/server/key" "senan.xyz/g/gonic/server/lastfm" ) -func (c *Controller) ServeNotFound(w http.ResponseWriter, r *http.Request) *Response { +func (c *Controller) ServeNotFound(r *http.Request) *Response { return &Response{template: "not_found.tmpl"} } -func (c *Controller) ServeLogin(w http.ResponseWriter, r *http.Request) *Response { +func (c *Controller) ServeLogin(r *http.Request) *Response { return &Response{template: "login.tmpl"} } -func (c *Controller) ServeLoginDo(w http.ResponseWriter, r *http.Request) *Response { - session := r.Context().Value(key.Session).(*sessions.Session) - username := r.FormValue("username") - password := r.FormValue("password") - if username == "" || password == "" { - sessAddFlashW("please provide both a username and password", session) - sessLogSave(w, r, session) - return &Response{redirect: r.Referer()} - } - user := c.DB.GetUserFromName(username) - if user == nil || password != user.Password { - sessAddFlashW("invalid username / password", session) - sessLogSave(w, r, session) - return &Response{redirect: r.Referer()} - } - // put the user name into the session. future endpoints after this one - // are wrapped with WithUserSession() which will get the name from the - // session and put the row into the request context - session.Values["user"] = user.Name - sessLogSave(w, r, session) - return &Response{redirect: "/admin/home"} -} - -func (c *Controller) ServeLogout(w http.ResponseWriter, r *http.Request) *Response { - session := r.Context().Value(key.Session).(*sessions.Session) - session.Options.MaxAge = -1 - sessLogSave(w, r, session) - return &Response{redirect: "/admin/login"} -} - -func (c *Controller) ServeHome(w http.ResponseWriter, r *http.Request) *Response { +func (c *Controller) ServeHome(r *http.Request) *Response { data := &templateData{} // // stats box @@ -97,19 +65,19 @@ func (c *Controller) ServeHome(w http.ResponseWriter, r *http.Request) *Response } } -func (c *Controller) ServeChangeOwnPassword(w http.ResponseWriter, r *http.Request) *Response { +func (c *Controller) ServeChangeOwnPassword(r *http.Request) *Response { return &Response{template: "change_own_password.tmpl"} } -func (c *Controller) ServeChangeOwnPasswordDo(w http.ResponseWriter, r *http.Request) *Response { - session := r.Context().Value(key.Session).(*sessions.Session) +func (c *Controller) ServeChangeOwnPasswordDo(r *http.Request) *Response { passwordOne := r.FormValue("password_one") passwordTwo := r.FormValue("password_two") err := validatePasswords(passwordOne, passwordTwo) if err != nil { - sessAddFlashW(err.Error(), session) - sessLogSave(w, r, session) - return &Response{redirect: r.Referer()} + return &Response{ + redirect: r.Referer(), + flashW: err.Error(), + } } user := r.Context().Value(key.User).(*model.User) user.Password = passwordOne @@ -117,7 +85,7 @@ func (c *Controller) ServeChangeOwnPasswordDo(w http.ResponseWriter, r *http.Req return &Response{redirect: "/admin/home"} } -func (c *Controller) ServeLinkLastFMDo(w http.ResponseWriter, r *http.Request) *Response { +func (c *Controller) ServeLinkLastFMDo(r *http.Request) *Response { token := r.URL.Query().Get("token") if token == "" { return &Response{ @@ -131,10 +99,10 @@ func (c *Controller) ServeLinkLastFMDo(w http.ResponseWriter, r *http.Request) * token, ) if err != nil { - session := r.Context().Value(key.Session).(*sessions.Session) - sessAddFlashW(err.Error(), session) - sessLogSave(w, r, session) - return &Response{redirect: "/admin/home"} + return &Response{ + redirect: "/admin/home", + flashW: err.Error(), + } } user := r.Context().Value(key.User).(*model.User) user.LastFMSession = sessionKey @@ -142,14 +110,14 @@ func (c *Controller) ServeLinkLastFMDo(w http.ResponseWriter, r *http.Request) * return &Response{redirect: "/admin/home"} } -func (c *Controller) ServeUnlinkLastFMDo(w http.ResponseWriter, r *http.Request) *Response { +func (c *Controller) ServeUnlinkLastFMDo(r *http.Request) *Response { user := r.Context().Value(key.User).(*model.User) user.LastFMSession = "" c.DB.Save(&user) return &Response{redirect: "/admin/home"} } -func (c *Controller) ServeChangePassword(w http.ResponseWriter, r *http.Request) *Response { +func (c *Controller) ServeChangePassword(r *http.Request) *Response { username := r.URL.Query().Get("user") if username == "" { return &Response{ @@ -172,16 +140,16 @@ func (c *Controller) ServeChangePassword(w http.ResponseWriter, r *http.Request) } } -func (c *Controller) ServeChangePasswordDo(w http.ResponseWriter, r *http.Request) *Response { - session := r.Context().Value(key.Session).(*sessions.Session) +func (c *Controller) ServeChangePasswordDo(r *http.Request) *Response { username := r.URL.Query().Get("user") passwordOne := r.FormValue("password_one") passwordTwo := r.FormValue("password_two") err := validatePasswords(passwordOne, passwordTwo) if err != nil { - sessAddFlashW(err.Error(), session) - sessLogSave(w, r, session) - return &Response{redirect: r.Referer()} + return &Response{ + redirect: r.Referer(), + flashW: err.Error(), + } } user := c.DB.GetUserFromName(username) user.Password = passwordOne @@ -189,7 +157,7 @@ func (c *Controller) ServeChangePasswordDo(w http.ResponseWriter, r *http.Reques return &Response{redirect: "/admin/home"} } -func (c *Controller) ServeDeleteUser(w http.ResponseWriter, r *http.Request) *Response { +func (c *Controller) ServeDeleteUser(r *http.Request) *Response { username := r.URL.Query().Get("user") if username == "" { return &Response{ @@ -212,33 +180,34 @@ func (c *Controller) ServeDeleteUser(w http.ResponseWriter, r *http.Request) *Re } } -func (c *Controller) ServeDeleteUserDo(w http.ResponseWriter, r *http.Request) *Response { +func (c *Controller) ServeDeleteUserDo(r *http.Request) *Response { username := r.URL.Query().Get("user") user := c.DB.GetUserFromName(username) c.DB.Delete(user) return &Response{redirect: "/admin/home"} } -func (c *Controller) ServeCreateUser(w http.ResponseWriter, r *http.Request) *Response { +func (c *Controller) ServeCreateUser(r *http.Request) *Response { return &Response{template: "create_user.tmpl"} } -func (c *Controller) ServeCreateUserDo(w http.ResponseWriter, r *http.Request) *Response { - session := r.Context().Value(key.Session).(*sessions.Session) +func (c *Controller) ServeCreateUserDo(r *http.Request) *Response { username := r.FormValue("username") err := validateUsername(username) if err != nil { - sessAddFlashW(err.Error(), session) - sessLogSave(w, r, session) - return &Response{redirect: r.Referer()} + return &Response{ + redirect: r.Referer(), + flashW: err.Error(), + } } passwordOne := r.FormValue("password_one") passwordTwo := r.FormValue("password_two") err = validatePasswords(passwordOne, passwordTwo) if err != nil { - sessAddFlashW(err.Error(), session) - sessLogSave(w, r, session) - return &Response{redirect: r.Referer()} + return &Response{ + redirect: r.Referer(), + flashW: err.Error(), + } } user := model.User{ Name: username, @@ -246,14 +215,15 @@ func (c *Controller) ServeCreateUserDo(w http.ResponseWriter, r *http.Request) * } err = c.DB.Create(&user).Error if err != nil { - sessAddFlashWf("could not create user `%s`: %v", session, username, err) - sessLogSave(w, r, session) - return &Response{redirect: r.Referer()} + return &Response{ + redirect: r.Referer(), + flashW: fmt.Sprintf("could not create user `%s`: %v", username, err), + } } return &Response{redirect: "/admin/home"} } -func (c *Controller) ServeUpdateLastFMAPIKey(w http.ResponseWriter, r *http.Request) *Response { +func (c *Controller) ServeUpdateLastFMAPIKey(r *http.Request) *Response { data := &templateData{} data.CurrentLastFMAPIKey = c.DB.GetSetting("lastfm_api_key") data.CurrentLastFMAPISecret = c.DB.GetSetting("lastfm_secret") @@ -263,21 +233,21 @@ func (c *Controller) ServeUpdateLastFMAPIKey(w http.ResponseWriter, r *http.Requ } } -func (c *Controller) ServeUpdateLastFMAPIKeyDo(w http.ResponseWriter, r *http.Request) *Response { - session := r.Context().Value(key.Session).(*sessions.Session) +func (c *Controller) ServeUpdateLastFMAPIKeyDo(r *http.Request) *Response { apiKey := r.FormValue("api_key") secret := r.FormValue("secret") if err := validateAPIKey(apiKey, secret); err != nil { - sessAddFlashW(err.Error(), session) - sessLogSave(w, r, session) - return &Response{redirect: r.Referer()} + return &Response{ + redirect: r.Referer(), + flashW: err.Error(), + } } c.DB.SetSetting("lastfm_api_key", apiKey) c.DB.SetSetting("lastfm_secret", secret) return &Response{redirect: "/admin/home"} } -func (c *Controller) ServeStartScanDo(w http.ResponseWriter, r *http.Request) *Response { +func (c *Controller) ServeStartScanDo(r *http.Request) *Response { defer func() { go func() { err := scanner. @@ -288,8 +258,8 @@ func (c *Controller) ServeStartScanDo(w http.ResponseWriter, r *http.Request) *R } }() }() - session := r.Context().Value(key.Session).(*sessions.Session) - sessAddFlashN("scan started. refresh for results", session) - sessLogSave(w, r, session) - return &Response{redirect: "/admin/home"} + return &Response{ + redirect: "/admin/home", + flashN: "scan started. refresh for results", + } } diff --git a/server/ctrladmin/handlers_raw.go b/server/ctrladmin/handlers_raw.go new file mode 100644 index 0000000..4fad088 --- /dev/null +++ b/server/ctrladmin/handlers_raw.go @@ -0,0 +1,41 @@ +package ctrladmin + +import ( + "net/http" + + "github.com/gorilla/sessions" + + "senan.xyz/g/gonic/server/key" +) + +func (c *Controller) ServeLoginDo(w http.ResponseWriter, r *http.Request) { + session := r.Context().Value(key.Session).(*sessions.Session) + username := r.FormValue("username") + password := r.FormValue("password") + if username == "" || password == "" { + sessAddFlashW(session, "please provide both a username and password") + sessLogSave(session, w, r) + http.Redirect(w, r, r.Referer(), http.StatusSeeOther) + return + } + user := c.DB.GetUserFromName(username) + if user == nil || password != user.Password { + sessAddFlashW(session, "invalid username / password") + sessLogSave(session, w, r) + http.Redirect(w, r, r.Referer(), http.StatusSeeOther) + return + } + // put the user name into the session. future endpoints after this one + // are wrapped with WithUserSession() which will get the name from the + // session and put the row into the request context + session.Values["user"] = user.Name + sessLogSave(session, w, r) + http.Redirect(w, r, "/admin/home", http.StatusSeeOther) +} + +func (c *Controller) ServeLogout(w http.ResponseWriter, r *http.Request) { + session := r.Context().Value(key.Session).(*sessions.Session) + session.Options.MaxAge = -1 + sessLogSave(session, w, r) + http.Redirect(w, r, "/admin/login", http.StatusSeeOther) +} diff --git a/server/ctrladmin/middleware.go b/server/ctrladmin/middleware.go index cf56df3..ed84f02 100644 --- a/server/ctrladmin/middleware.go +++ b/server/ctrladmin/middleware.go @@ -24,8 +24,8 @@ func (c *Controller) WithUserSession(next http.Handler) http.Handler { session := r.Context().Value(key.Session).(*sessions.Session) username, ok := session.Values["user"].(string) if !ok { - sessAddFlashW("you are not authenticated", session) - sessLogSave(w, r, session) + sessAddFlashW(session, "you are not authenticated") + sessLogSave(session, w, r) http.Redirect(w, r, "/admin/login", http.StatusSeeOther) return } @@ -35,7 +35,7 @@ func (c *Controller) WithUserSession(next http.Handler) http.Handler { // the username in the client's session no longer relates to a // user in the database (maybe the user was deleted) session.Options.MaxAge = -1 - sessLogSave(w, r, session) + sessLogSave(session, w, r) http.Redirect(w, r, "/admin/login", http.StatusSeeOther) return } @@ -50,8 +50,8 @@ func (c *Controller) WithAdminSession(next http.Handler) http.Handler { session := r.Context().Value(key.Session).(*sessions.Session) user := r.Context().Value(key.User).(*model.User) if !user.IsAdmin { - sessAddFlashW("you are not an admin", session) - sessLogSave(w, r, session) + sessAddFlashW(session, "you are not an admin") + sessLogSave(session, w, r) http.Redirect(w, r, "/admin/login", http.StatusSeeOther) return } diff --git a/server/ctrlsubsonic/ctrl.go b/server/ctrlsubsonic/ctrl.go index 2b2c684..535803c 100644 --- a/server/ctrlsubsonic/ctrl.go +++ b/server/ctrlsubsonic/ctrl.go @@ -84,7 +84,7 @@ func (c *Controller) H(h subsonicHandler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { response := h(r) if response == nil { - log.Println("error: non raw subsonic handler returned a nil response\n") + log.Println("error: non raw subsonic handler returned a nil response") return } if err := writeResp(w, r, response); err != nil { diff --git a/server/ctrlsubsonic/testdata/db b/server/ctrlsubsonic/testdata/db index af72b97db4f6937f73a30ff100d0fd05d4bea200..ff3a467b14fc9b5ccdad3280099f3026b2c660a2 100644 GIT binary patch delta 22 dcmZo@U~gz(pCHW`G*QNxF{m+NYXakf{Qyz#2c`f3 delta 22 dcmZo@U~gz(pCHZXH&Mo!(XTOKYXakf{Qyy$2c7@` diff --git a/server/server.go b/server/server.go index 385c15e..7cc71aa 100644 --- a/server/server.go +++ b/server/server.go @@ -65,7 +65,7 @@ func (s *Server) SetupAdmin() error { routPublic.Use(ctrl.WithSession) routPublic.NotFoundHandler = ctrl.H(ctrl.ServeNotFound) routPublic.Handle("/login", ctrl.H(ctrl.ServeLogin)) - routPublic.Handle("/login_do", ctrl.H(ctrl.ServeLoginDo)) + routPublic.HandleFunc("/login_do", ctrl.ServeLoginDo) // "raw" handler, updates session assets.PrefixDo("static", func(path string, asset *assets.EmbeddedAsset) { _, name := filepath.Split(path) route := filepath.Join("/static", name) @@ -78,7 +78,7 @@ func (s *Server) SetupAdmin() error { // begin user routes (if session is valid) routUser := routPublic.NewRoute().Subrouter() routUser.Use(ctrl.WithUserSession) - routUser.Handle("/logout", ctrl.H(ctrl.ServeLogout)) + routUser.HandleFunc("/logout", ctrl.ServeLogout) // "raw" handler, updates session routUser.Handle("/home", ctrl.H(ctrl.ServeHome)) routUser.Handle("/change_own_password", ctrl.H(ctrl.ServeChangeOwnPassword)) routUser.Handle("/change_own_password_do", ctrl.H(ctrl.ServeChangeOwnPasswordDo))