From 0932b5c0132ad6cacf3288d271efc31fc008fc8b Mon Sep 17 00:00:00 2001 From: sentriz Date: Thu, 18 Apr 2019 13:37:29 +0100 Subject: [PATCH] add user row to request context --- handler/admin.go | 30 +++++++++++------------------- handler/handler.go | 21 ++++++++++++++++++++- handler/middleware.go | 10 ++++++---- 3 files changed, 37 insertions(+), 24 deletions(-) diff --git a/handler/admin.go b/handler/admin.go index 3ada6dd..9e59508 100644 --- a/handler/admin.go +++ b/handler/admin.go @@ -25,15 +25,17 @@ func (c *Controller) ServeLoginDo(w http.ResponseWriter, r *http.Request) { http.Redirect(w, r, r.Header.Get("Referer"), 302) return } - var user db.User - c.DB.Where("name = ?", username).First(&user) + user := c.GetUserFromName(username) if !(username == user.Name && password == user.Password) { session.AddFlash("invalid username / password") session.Save(r, w) http.Redirect(w, r, r.Header.Get("Referer"), 302) return } - session.Values["user"] = user + // put the user name into the session. future endpoints after this one + // are wrapped with WithUserSession() which will get the name from the + // session and put the row into the request context. + session.Values["user"] = user.Name session.Save(r, w) http.Redirect(w, r, "/admin/home", 303) } @@ -72,7 +74,7 @@ func (c *Controller) ServeChangeOwnPasswordDo(w http.ResponseWriter, r *http.Req http.Redirect(w, r, r.Header.Get("Referer"), 302) return } - user, _ := session.Values["user"].(*db.User) + user := r.Context().Value("user").(*db.User) user.Password = passwordOne c.DB.Save(user) http.Redirect(w, r, "/admin/home", 303) @@ -100,7 +102,7 @@ func (c *Controller) ServeLinkLastFMCallback(w http.ResponseWriter, r *http.Requ http.Redirect(w, r, "/admin/home", 302) return } - user, _ := session.Values["user"].(*db.User) + user := r.Context().Value("user").(*db.User) user.LastFMSession = sessionKey c.DB.Save(&user) http.Redirect(w, r, "/admin/home", 302) @@ -183,12 +185,8 @@ func (c *Controller) ServeCreateUserDo(w http.ResponseWriter, r *http.Request) { func (c *Controller) ServeUpdateLastFMAPIKey(w http.ResponseWriter, r *http.Request) { var data templateData - var apiKey db.Setting - c.DB.Where("key = ?", "lastfm_api_key").First(&apiKey) - data.CurrentLastFMAPIKey = apiKey.Value - var secret db.Setting - c.DB.Where("key = ?", "lastfm_secret").First(&secret) - data.CurrentLastFMAPISecret = secret.Value + data.CurrentLastFMAPIKey = c.GetSetting("lastfm_api_key") + data.CurrentLastFMAPISecret = c.GetSetting("lastfm_secret") renderTemplate(w, r, "update_lastfm_api_key", &data) } @@ -203,13 +201,7 @@ func (c *Controller) ServeUpdateLastFMAPIKeyDo(w http.ResponseWriter, r *http.Re http.Redirect(w, r, r.Header.Get("Referer"), 302) return } - c.DB. - Where(db.Setting{Key: "lastfm_api_key"}). - Assign(db.Setting{Value: apiKey}). - FirstOrCreate(&db.Setting{}) - c.DB. - Where(db.Setting{Key: "lastfm_secret"}). - Assign(db.Setting{Value: secret}). - FirstOrCreate(&db.Setting{}) + c.SetSetting("lastfm_api_key", apiKey) + c.SetSetting("lastfm_secret", secret) http.Redirect(w, r, "/admin/home", 303) } diff --git a/handler/handler.go b/handler/handler.go index 9d39fcf..d87c43e 100644 --- a/handler/handler.go +++ b/handler/handler.go @@ -60,6 +60,25 @@ type Controller struct { SStore *gormstore.Store } +func (c *Controller) GetSetting(key string) string { + var setting db.Setting + c.DB.Where("key = ?", key).First(&setting) + return setting.Value +} + +func (c *Controller) SetSetting(key, value string) { + c.DB. + Where(db.Setting{Key: key}). + Assign(db.Setting{Value: value}). + FirstOrCreate(&db.Setting{}) +} + +func (c *Controller) GetUserFromName(name string) *db.User { + var user db.User + c.DB.Where("name = ?", name).First(&user) + return &user +} + type templateData struct { Flashes []interface{} User *db.User @@ -151,7 +170,7 @@ func renderTemplate(w http.ResponseWriter, r *http.Request, } data.Flashes = session.Flashes() session.Save(r, w) - user, ok := session.Values["user"].(*db.User) + user, ok := r.Context().Value("user").(*db.User) if ok { data.User = user } diff --git a/handler/middleware.go b/handler/middleware.go index 4bf73af..3fd3c3d 100644 --- a/handler/middleware.go +++ b/handler/middleware.go @@ -113,15 +113,17 @@ func (c *Controller) WithUserSession(next http.HandlerFunc) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { // session exists at this point session := r.Context().Value("session").(*sessions.Session) - _, ok := session.Values["user"] + username, ok := session.Values["user"].(string) if !ok { session.AddFlash("you are not authenticated") session.Save(r, w) http.Redirect(w, r, "/admin/login", 303) return } - withSession := context.WithValue(r.Context(), "session", session) - next.ServeHTTP(w, r.WithContext(withSession)) + // take username from sesion and add the user row + user := c.GetUserFromName(username) + withUser := context.WithValue(r.Context(), "user", user) + next.ServeHTTP(w, r.WithContext(withUser)) } } @@ -129,7 +131,7 @@ func (c *Controller) WithAdminSession(next http.HandlerFunc) http.HandlerFunc { return func(w http.ResponseWriter, r *http.Request) { // session and user exist at this point session := r.Context().Value("session").(*sessions.Session) - user := session.Values["user"].(*db.User) + user := r.Context().Value("user").(*db.User) if !user.IsAdmin { session.AddFlash("you are not an admin") session.Save(r, w)