diff --git a/server/ctrladmin/ctrl.go b/server/ctrladmin/ctrl.go index 715fc2a..a4b4b94 100644 --- a/server/ctrladmin/ctrl.go +++ b/server/ctrladmin/ctrl.go @@ -19,7 +19,6 @@ import ( "senan.xyz/g/gonic/assets" "senan.xyz/g/gonic/model" "senan.xyz/g/gonic/server/ctrlbase" - "senan.xyz/g/gonic/server/key" "senan.xyz/g/gonic/version" ) @@ -27,6 +26,13 @@ func init() { gob.Register(&Flash{}) } +type CtxKey int + +const ( + CtxUser CtxKey = iota + CtxSession +) + // extendFromPaths /extends/ the given template for every asset // with given prefix func extendFromPaths(b *template.Template, p string) *template.Template { @@ -124,7 +130,7 @@ type Response struct { func (c *Controller) H(h adminHandler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { resp := h(r) - session, ok := r.Context().Value(key.Session).(*sessions.Session) + session, ok := r.Context().Value(CtxSession).(*sessions.Session) if ok { sessAddFlashN(session, resp.flashN) sessAddFlashW(session, resp.flashW) @@ -156,7 +162,7 @@ func (c *Controller) H(h adminHandler) http.Handler { return } } - if user, ok := r.Context().Value(key.User).(*model.User); ok { + if user, ok := r.Context().Value(CtxUser).(*model.User); ok { resp.data.User = user } buff := c.buffPool.Get() diff --git a/server/ctrladmin/handlers.go b/server/ctrladmin/handlers.go index 620a38e..73dd5e5 100644 --- a/server/ctrladmin/handlers.go +++ b/server/ctrladmin/handlers.go @@ -9,7 +9,6 @@ import ( "senan.xyz/g/gonic/model" "senan.xyz/g/gonic/scanner" - "senan.xyz/g/gonic/server/key" "senan.xyz/g/gonic/server/lastfm" ) @@ -60,7 +59,7 @@ func (c *Controller) ServeHome(r *http.Request) *Response { } // // playlists box - user := r.Context().Value(key.User).(*model.User) + user := r.Context().Value(CtxUser).(*model.User) c.DB. Select("*, count(items.id) as track_count"). Joins(` @@ -92,7 +91,7 @@ func (c *Controller) ServeChangeOwnPasswordDo(r *http.Request) *Response { flashW: []string{err.Error()}, } } - user := r.Context().Value(key.User).(*model.User) + user := r.Context().Value(CtxUser).(*model.User) user.Password = passwordOne c.DB.Save(user) return &Response{redirect: "/admin/home"} @@ -117,14 +116,14 @@ func (c *Controller) ServeLinkLastFMDo(r *http.Request) *Response { flashW: []string{err.Error()}, } } - user := r.Context().Value(key.User).(*model.User) + user := r.Context().Value(CtxUser).(*model.User) user.LastFMSession = sessionKey c.DB.Save(&user) return &Response{redirect: "/admin/home"} } func (c *Controller) ServeUnlinkLastFMDo(r *http.Request) *Response { - user := r.Context().Value(key.User).(*model.User) + user := r.Context().Value(CtxUser).(*model.User) user.LastFMSession = "" c.DB.Save(&user) return &Response{redirect: "/admin/home"} @@ -241,7 +240,7 @@ func (c *Controller) ServeUpdateLastFMAPIKey(r *http.Request) *Response { data.CurrentLastFMAPIKey = c.DB.GetSetting("lastfm_api_key") data.CurrentLastFMAPISecret = c.DB.GetSetting("lastfm_secret") return &Response{ - template: "update_lastfm_api_key.tmpl", + template: "update_lastfm_api_key", data: data, } } @@ -285,7 +284,7 @@ func (c *Controller) ServeUploadPlaylistDo(r *http.Request) *Response { code: 500, } } - user := r.Context().Value(key.User).(*model.User) + user := r.Context().Value(CtxUser).(*model.User) var playlistCount int var errors []string for _, headers := range r.MultipartForm.File { diff --git a/server/ctrladmin/handlers_raw.go b/server/ctrladmin/handlers_raw.go index 30c7a43..7a0f10d 100644 --- a/server/ctrladmin/handlers_raw.go +++ b/server/ctrladmin/handlers_raw.go @@ -4,12 +4,10 @@ import ( "net/http" "github.com/gorilla/sessions" - - "senan.xyz/g/gonic/server/key" ) func (c *Controller) ServeLoginDo(w http.ResponseWriter, r *http.Request) { - session := r.Context().Value(key.Session).(*sessions.Session) + session := r.Context().Value(CtxSession).(*sessions.Session) username := r.FormValue("username") password := r.FormValue("password") if username == "" || password == "" { @@ -34,7 +32,7 @@ func (c *Controller) ServeLoginDo(w http.ResponseWriter, r *http.Request) { } func (c *Controller) ServeLogout(w http.ResponseWriter, r *http.Request) { - session := r.Context().Value(key.Session).(*sessions.Session) + session := r.Context().Value(CtxSession).(*sessions.Session) session.Options.MaxAge = -1 sessLogSave(session, w, r) http.Redirect(w, r, "/admin/login", http.StatusSeeOther) diff --git a/server/ctrladmin/middleware.go b/server/ctrladmin/middleware.go index bac6689..e536ecd 100644 --- a/server/ctrladmin/middleware.go +++ b/server/ctrladmin/middleware.go @@ -7,13 +7,12 @@ import ( "github.com/gorilla/sessions" "senan.xyz/g/gonic/model" - "senan.xyz/g/gonic/server/key" ) func (c *Controller) WithSession(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { session, _ := c.sessDB.Get(r, "gonic") - withSession := context.WithValue(r.Context(), key.Session, session) + withSession := context.WithValue(r.Context(), CtxSession, session) next.ServeHTTP(w, r.WithContext(withSession)) }) } @@ -21,7 +20,7 @@ func (c *Controller) WithSession(next http.Handler) http.Handler { func (c *Controller) WithUserSession(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { // session exists at this point - session := r.Context().Value(key.Session).(*sessions.Session) + session := r.Context().Value(CtxSession).(*sessions.Session) username, ok := session.Values["user"].(string) if !ok { sessAddFlashW(session, []string{"you are not authenticated"}) @@ -39,7 +38,7 @@ func (c *Controller) WithUserSession(next http.Handler) http.Handler { http.Redirect(w, r, "/admin/login", http.StatusSeeOther) return } - withUser := context.WithValue(r.Context(), key.User, user) + withUser := context.WithValue(r.Context(), CtxUser, user) next.ServeHTTP(w, r.WithContext(withUser)) }) } @@ -47,8 +46,8 @@ func (c *Controller) WithUserSession(next http.Handler) http.Handler { func (c *Controller) WithAdminSession(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { // session and user exist at this point - session := r.Context().Value(key.Session).(*sessions.Session) - user := r.Context().Value(key.User).(*model.User) + session := r.Context().Value(CtxSession).(*sessions.Session) + user := r.Context().Value(CtxUser).(*model.User) if !user.IsAdmin { sessAddFlashW(session, []string{"you are not an admin"}) sessLogSave(session, w, r) diff --git a/server/ctrlsubsonic/ctrl.go b/server/ctrlsubsonic/ctrl.go index 535803c..b78d000 100644 --- a/server/ctrlsubsonic/ctrl.go +++ b/server/ctrlsubsonic/ctrl.go @@ -10,8 +10,16 @@ import ( "github.com/pkg/errors" "senan.xyz/g/gonic/server/ctrlbase" + "senan.xyz/g/gonic/server/ctrlsubsonic/params" "senan.xyz/g/gonic/server/ctrlsubsonic/spec" - "senan.xyz/g/gonic/server/parsing" +) + +type CtxKey int + +const ( + CtxUser CtxKey = iota + CtxSession + CtxParams ) type Controller struct { @@ -42,12 +50,10 @@ func (ew *errWriter) write(buf []byte) { } func writeResp(w http.ResponseWriter, r *http.Request, resp *spec.Response) error { - if resp.Error != nil { - w.WriteHeader(http.StatusBadRequest) - } res := metaResponse{Response: resp} + params := r.Context().Value(CtxParams).(params.Params) ew := &errWriter{w: w} - switch parsing.GetStrParam(r, "f") { + switch params.Get("f") { case "json": w.Header().Set("Content-Type", "application/json") data, err := json.Marshal(res) @@ -62,7 +68,7 @@ func writeResp(w http.ResponseWriter, r *http.Request, resp *spec.Response) erro return errors.Wrap(err, "marshal to jsonp") } // TODO: error if no callback provided instead of using a default - pCall := parsing.GetStrParamOr(r, "callback", "cb") + pCall := params.GetOr("callback", "cb") ew.write([]byte(pCall)) ew.write([]byte("(")) ew.write(data) diff --git a/server/ctrlsubsonic/handlers_by_folder.go b/server/ctrlsubsonic/handlers_by_folder.go index 4f3f663..8c1ffed 100644 --- a/server/ctrlsubsonic/handlers_by_folder.go +++ b/server/ctrlsubsonic/handlers_by_folder.go @@ -9,9 +9,8 @@ import ( "github.com/jinzhu/gorm" "senan.xyz/g/gonic/model" + "senan.xyz/g/gonic/server/ctrlsubsonic/params" "senan.xyz/g/gonic/server/ctrlsubsonic/spec" - "senan.xyz/g/gonic/server/key" - "senan.xyz/g/gonic/server/parsing" ) // the subsonic spec metions "artist" a lot when talking about the @@ -60,7 +59,8 @@ func (c *Controller) ServeGetIndexes(r *http.Request) *spec.Response { } func (c *Controller) ServeGetMusicDirectory(r *http.Request) *spec.Response { - id, err := parsing.GetIntParam(r, "id") + params := r.Context().Value(CtxParams).(params.Params) + id, err := params.GetInt("id") if err != nil { return spec.NewError(10, "please provide an `id` parameter") } @@ -86,7 +86,7 @@ func (c *Controller) ServeGetMusicDirectory(r *http.Request) *spec.Response { Find(&childTracks) for _, c := range childTracks { toAppend := spec.NewTCTrackByFolder(c, folder) - if parsing.GetStrParam(r, "c") == "Jamstash" { + if params.Get("c") == "Jamstash" { // jamstash thinks it can't play flacs toAppend.ContentType = "audio/mpeg" toAppend.Suffix = "mp3" @@ -103,7 +103,8 @@ func (c *Controller) ServeGetMusicDirectory(r *http.Request) *spec.Response { // changes to this function should be reflected in in _by_tags.go's // getAlbumListTwo() function func (c *Controller) ServeGetAlbumList(r *http.Request) *spec.Response { - listType := parsing.GetStrParam(r, "type") + params := r.Context().Value(CtxParams).(params.Params) + listType := params.Get("type") if listType == "" { return spec.NewError(10, "please provide a `type` parameter") } @@ -117,7 +118,7 @@ func (c *Controller) ServeGetAlbumList(r *http.Request) *spec.Response { case "alphabeticalByName": q = q.Order("right_path") case "frequent": - user := r.Context().Value(key.User).(*model.User) + user := r.Context().Value(CtxUser).(*model.User) q = q.Joins(` JOIN plays ON albums.id = plays.album_id AND plays.user_id = ?`, @@ -128,7 +129,7 @@ func (c *Controller) ServeGetAlbumList(r *http.Request) *spec.Response { case "random": q = q.Order(gorm.Expr("random()")) case "recent": - user := r.Context().Value(key.User).(*model.User) + user := r.Context().Value(CtxUser).(*model.User) q = q.Joins(` JOIN plays ON albums.id = plays.album_id AND plays.user_id = ?`, @@ -140,8 +141,8 @@ func (c *Controller) ServeGetAlbumList(r *http.Request) *spec.Response { var folders []*model.Album q. Where("albums.tag_artist_id IS NOT NULL"). - Offset(parsing.GetIntParamOr(r, "offset", 0)). - Limit(parsing.GetIntParamOr(r, "size", 10)). + Offset(params.GetIntOr("offset", 0)). + Limit(params.GetIntOr("size", 10)). Preload("Parent"). Find(&folders) sub := spec.NewResponse() @@ -155,7 +156,8 @@ func (c *Controller) ServeGetAlbumList(r *http.Request) *spec.Response { } func (c *Controller) ServeSearchTwo(r *http.Request) *spec.Response { - query := parsing.GetStrParam(r, "query") + params := r.Context().Value(CtxParams).(params.Params) + query := params.Get("query") if query == "" { return spec.NewError(10, "please provide a `query` parameter") } @@ -170,8 +172,8 @@ func (c *Controller) ServeSearchTwo(r *http.Request) *spec.Response { AND (right_path LIKE ? OR right_path_u_dec LIKE ?) `, query, query). - Offset(parsing.GetIntParamOr(r, "artistOffset", 0)). - Limit(parsing.GetIntParamOr(r, "artistCount", 20)). + Offset(params.GetIntOr("artistOffset", 0)). + Limit(params.GetIntOr("artistCount", 20)). Find(&artists) for _, a := range artists { results.Artists = append(results.Artists, @@ -186,8 +188,8 @@ func (c *Controller) ServeSearchTwo(r *http.Request) *spec.Response { AND (right_path LIKE ? OR right_path_u_dec LIKE ?) `, query, query). - Offset(parsing.GetIntParamOr(r, "albumOffset", 0)). - Limit(parsing.GetIntParamOr(r, "albumCount", 20)). + Offset(params.GetIntOr("albumOffset", 0)). + Limit(params.GetIntOr("albumCount", 20)). Find(&albums) for _, a := range albums { results.Albums = append(results.Albums, spec.NewTCAlbumByFolder(a)) @@ -201,8 +203,8 @@ func (c *Controller) ServeSearchTwo(r *http.Request) *spec.Response { filename LIKE ? OR filename_u_dec LIKE ? `, query, query). - Offset(parsing.GetIntParamOr(r, "songOffset", 0)). - Limit(parsing.GetIntParamOr(r, "songCount", 20)). + Offset(params.GetIntOr("songOffset", 0)). + Limit(params.GetIntOr("songCount", 20)). Find(&tracks) for _, t := range tracks { results.Tracks = append(results.Tracks, diff --git a/server/ctrlsubsonic/handlers_by_tags.go b/server/ctrlsubsonic/handlers_by_tags.go index 13cc305..eeaa1ff 100644 --- a/server/ctrlsubsonic/handlers_by_tags.go +++ b/server/ctrlsubsonic/handlers_by_tags.go @@ -9,9 +9,8 @@ import ( "github.com/jinzhu/gorm" "senan.xyz/g/gonic/model" + "senan.xyz/g/gonic/server/ctrlsubsonic/params" "senan.xyz/g/gonic/server/ctrlsubsonic/spec" - "senan.xyz/g/gonic/server/key" - "senan.xyz/g/gonic/server/parsing" ) func (c *Controller) ServeGetArtists(r *http.Request) *spec.Response { @@ -52,7 +51,8 @@ func (c *Controller) ServeGetArtists(r *http.Request) *spec.Response { } func (c *Controller) ServeGetArtist(r *http.Request) *spec.Response { - id, err := parsing.GetIntParam(r, "id") + params := r.Context().Value(CtxParams).(params.Params) + id, err := params.GetInt("id") if err != nil { return spec.NewError(10, "please provide an `id` parameter") } @@ -70,7 +70,8 @@ func (c *Controller) ServeGetArtist(r *http.Request) *spec.Response { } func (c *Controller) ServeGetAlbum(r *http.Request) *spec.Response { - id, err := parsing.GetIntParam(r, "id") + params := r.Context().Value(CtxParams).(params.Params) + id, err := params.GetInt("id") if err != nil { return spec.NewError(10, "please provide an `id` parameter") } @@ -97,7 +98,8 @@ func (c *Controller) ServeGetAlbum(r *http.Request) *spec.Response { // changes to this function should be reflected in in _by_folder.go's // getAlbumList() function func (c *Controller) ServeGetAlbumListTwo(r *http.Request) *spec.Response { - listType := parsing.GetStrParam(r, "type") + params := r.Context().Value(CtxParams).(params.Params) + listType := params.Get("type") if listType == "" { return spec.NewError(10, "please provide a `type` parameter") } @@ -113,11 +115,11 @@ func (c *Controller) ServeGetAlbumListTwo(r *http.Request) *spec.Response { case "byYear": q = q.Where( "tag_year BETWEEN ? AND ?", - parsing.GetIntParamOr(r, "fromYear", 1800), - parsing.GetIntParamOr(r, "toYear", 2200)) + params.GetIntOr("fromYear", 1800), + params.GetIntOr("toYear", 2200)) q = q.Order("tag_year") case "frequent": - user := r.Context().Value(key.User).(*model.User) + user := r.Context().Value(CtxUser).(*model.User) q = q.Joins(` JOIN plays ON albums.id = plays.album_id AND plays.user_id = ?`, @@ -128,7 +130,7 @@ func (c *Controller) ServeGetAlbumListTwo(r *http.Request) *spec.Response { case "random": q = q.Order(gorm.Expr("random()")) case "recent": - user := r.Context().Value(key.User).(*model.User) + user := r.Context().Value(CtxUser).(*model.User) q = q.Joins(` JOIN plays ON albums.id = plays.album_id AND plays.user_id = ?`, @@ -140,8 +142,8 @@ func (c *Controller) ServeGetAlbumListTwo(r *http.Request) *spec.Response { var albums []*model.Album q. Where("albums.tag_artist_id IS NOT NULL"). - Offset(parsing.GetIntParamOr(r, "offset", 0)). - Limit(parsing.GetIntParamOr(r, "size", 10)). + Offset(params.GetIntOr("offset", 0)). + Limit(params.GetIntOr("size", 10)). Preload("TagArtist"). Find(&albums) sub := spec.NewResponse() @@ -155,7 +157,8 @@ func (c *Controller) ServeGetAlbumListTwo(r *http.Request) *spec.Response { } func (c *Controller) ServeSearchThree(r *http.Request) *spec.Response { - query := parsing.GetStrParam(r, "query") + params := r.Context().Value(CtxParams).(params.Params) + query := params.Get("query") if query == "" { return spec.NewError(10, "please provide a `query` parameter") } @@ -170,8 +173,8 @@ func (c *Controller) ServeSearchThree(r *http.Request) *spec.Response { name LIKE ? OR name_u_dec LIKE ? `, query, query). - Offset(parsing.GetIntParamOr(r, "artistOffset", 0)). - Limit(parsing.GetIntParamOr(r, "artistCount", 20)). + Offset(params.GetIntOr("artistOffset", 0)). + Limit(params.GetIntOr("artistCount", 20)). Find(&artists) for _, a := range artists { results.Artists = append(results.Artists, @@ -186,8 +189,8 @@ func (c *Controller) ServeSearchThree(r *http.Request) *spec.Response { tag_title LIKE ? OR tag_title_u_dec LIKE ? `, query, query). - Offset(parsing.GetIntParamOr(r, "albumOffset", 0)). - Limit(parsing.GetIntParamOr(r, "albumCount", 20)). + Offset(params.GetIntOr("albumOffset", 0)). + Limit(params.GetIntOr("albumCount", 20)). Find(&albums) for _, a := range albums { results.Albums = append(results.Albums, @@ -202,8 +205,8 @@ func (c *Controller) ServeSearchThree(r *http.Request) *spec.Response { tag_title LIKE ? OR tag_title_u_dec LIKE ? `, query, query). - Offset(parsing.GetIntParamOr(r, "songOffset", 0)). - Limit(parsing.GetIntParamOr(r, "songCount", 20)). + Offset(params.GetIntOr("songOffset", 0)). + Limit(params.GetIntOr("songCount", 20)). Find(&tracks) for _, t := range tracks { results.Tracks = append(results.Tracks, diff --git a/server/ctrlsubsonic/handlers_common.go b/server/ctrlsubsonic/handlers_common.go index 353d60f..7c0484e 100644 --- a/server/ctrlsubsonic/handlers_common.go +++ b/server/ctrlsubsonic/handlers_common.go @@ -11,10 +11,9 @@ import ( "senan.xyz/g/gonic/model" "senan.xyz/g/gonic/scanner" + "senan.xyz/g/gonic/server/ctrlsubsonic/params" "senan.xyz/g/gonic/server/ctrlsubsonic/spec" - "senan.xyz/g/gonic/server/key" "senan.xyz/g/gonic/server/lastfm" - "senan.xyz/g/gonic/server/parsing" ) func lowerUDecOrHash(in string) string { @@ -38,12 +37,13 @@ func (c *Controller) ServePing(r *http.Request) *spec.Response { } func (c *Controller) ServeScrobble(r *http.Request) *spec.Response { - id, err := parsing.GetIntParam(r, "id") + params := r.Context().Value(CtxParams).(params.Params) + id, err := params.GetInt("id") if err != nil { return spec.NewError(10, "please provide an `id` parameter") } // fetch user to get lastfm session - user := r.Context().Value(key.User).(*model.User) + user := r.Context().Value(CtxUser).(*model.User) if user.LastFMSession == "" { return spec.NewError(0, "you don't have a last.fm session") } @@ -61,8 +61,8 @@ func (c *Controller) ServeScrobble(r *http.Request) *spec.Response { track, // clients will provide time in miliseconds, so use that or // instead convert UnixNano to miliseconds - parsing.GetIntParamOr(r, "time", int(time.Now().UnixNano()/1e6)), - parsing.GetStrParamOr(r, "submission", "true") != "false", + params.GetIntOr("time", int(time.Now().UnixNano()/1e6)), + params.GetOr("submission", "true") != "false", ) if err != nil { return spec.NewError(0, "error when submitting: %v", err) @@ -103,7 +103,7 @@ func (c *Controller) ServeGetScanStatus(r *http.Request) *spec.Response { } func (c *Controller) ServeGetUser(r *http.Request) *spec.Response { - user := r.Context().Value(key.User).(*model.User) + user := r.Context().Value(CtxUser).(*model.User) sub := spec.NewResponse() sub.User = &spec.User{ Username: user.Name, @@ -119,7 +119,7 @@ func (c *Controller) ServeNotFound(r *http.Request) *spec.Response { } func (c *Controller) ServeGetPlaylists(r *http.Request) *spec.Response { - user := r.Context().Value(key.User).(*model.User) + user := r.Context().Value(CtxUser).(*model.User) var playlists []*model.Playlist c.DB. Where("user_id = ?", user.ID). @@ -136,7 +136,8 @@ func (c *Controller) ServeGetPlaylists(r *http.Request) *spec.Response { } func (c *Controller) ServeGetPlaylist(r *http.Request) *spec.Response { - playlistID, err := parsing.GetIntParam(r, "id") + params := r.Context().Value(CtxParams).(params.Params) + playlistID, err := params.GetInt("id") if err != nil { return spec.NewError(10, "please provide an `id` parameter") } @@ -159,7 +160,7 @@ func (c *Controller) ServeGetPlaylist(r *http.Request) *spec.Response { Order("playlist_items.created_at"). Preload("Album"). Find(&tracks) - user := r.Context().Value(key.User).(*model.User) + user := r.Context().Value(CtxUser).(*model.User) sub := spec.NewResponse() sub.Playlist = spec.NewPlaylist(&playlist) sub.Playlist.Owner = user.Name @@ -171,23 +172,32 @@ func (c *Controller) ServeGetPlaylist(r *http.Request) *spec.Response { } func (c *Controller) ServeUpdatePlaylist(r *http.Request) *spec.Response { - playlistID, _ := parsing.GetFirstIntParamOf(r, "id", "playlistId") + params := r.Context().Value(CtxParams).(params.Params) + user := r.Context().Value(CtxUser).(*model.User) + var playlistID int + for _, key := range []string{"id", "playlistId"} { + if val, err := params.GetInt(key); err != nil { + playlistID = val + } + } // begin updating meta + // playlist ID may still be 0 here, if so it's okay, + // we get a new playlist playlist := &model.Playlist{} c.DB. Where("id = ?", playlistID). First(playlist) - user := r.Context().Value(key.User).(*model.User) playlist.UserID = user.ID - if name := parsing.GetStrParam(r, "name"); name != "" { - playlist.Name = name + if val := params.Get("name"); val != "" { + playlist.Name = val } - if comment := parsing.GetStrParam(r, "comment"); comment != "" { - playlist.Comment = comment + if val := params.Get("comment"); val != "" { + playlist.Comment = val } c.DB.Save(playlist) // begin delete tracks - if indexes, ok := r.URL.Query()["songIndexToRemove"]; ok { + indexes, ok := params.GetList("songIndexToRemove") + if ok { trackIDs := []int{} c.DB. Order("created_at"). @@ -204,24 +214,30 @@ func (c *Controller) ServeUpdatePlaylist(r *http.Request) *spec.Response { } } // begin add tracks - if toAdd := parsing.GetFirstParamOf(r, "songId", "songIdToAdd"); toAdd != nil { - for _, trackIDStr := range toAdd { - trackID, err := strconv.Atoi(trackIDStr) - if err != nil { - continue - } - c.DB.Save(&model.PlaylistItem{ - PlaylistID: playlist.ID, - TrackID: trackID, - }) + var toAdd []string + for _, val := range []string{"songId", "songIdToAdd"} { + toAdd, ok := params.GetList(val) + if ok { + break } } + for _, trackIDStr := range toAdd { + trackID, err := strconv.Atoi(trackIDStr) + if err != nil { + continue + } + c.DB.Save(&model.PlaylistItem{ + PlaylistID: playlist.ID, + TrackID: trackID, + }) + } return spec.NewResponse() } func (c *Controller) ServeDeletePlaylist(r *http.Request) *spec.Response { + params := r.Context().Value(CtxParams).(params.Params) c.DB. - Where("id = ?", parsing.GetIntParamOr(r, "id", 0)). + Where("id = ?", params.GetIntOr("id", 0)). Delete(&model.Playlist{}) return spec.NewResponse() } diff --git a/server/ctrlsubsonic/handlers_raw.go b/server/ctrlsubsonic/handlers_raw.go index 9f0ed30..b097afd 100644 --- a/server/ctrlsubsonic/handlers_raw.go +++ b/server/ctrlsubsonic/handlers_raw.go @@ -8,9 +8,8 @@ import ( "github.com/jinzhu/gorm" "senan.xyz/g/gonic/model" + "senan.xyz/g/gonic/server/ctrlsubsonic/params" "senan.xyz/g/gonic/server/ctrlsubsonic/spec" - "senan.xyz/g/gonic/server/key" - "senan.xyz/g/gonic/server/parsing" ) // "raw" handlers are ones that don't always return a spec response. @@ -20,7 +19,8 @@ import ( // _but not both_ func (c *Controller) ServeGetCoverArt(w http.ResponseWriter, r *http.Request) *spec.Response { - id, err := parsing.GetIntParam(r, "id") + params := r.Context().Value(CtxParams).(params.Params) + id, err := params.GetInt("id") if err != nil { return spec.NewError(10, "please provide an `id` parameter") } @@ -46,7 +46,8 @@ func (c *Controller) ServeGetCoverArt(w http.ResponseWriter, r *http.Request) *s } func (c *Controller) ServeStream(w http.ResponseWriter, r *http.Request) *spec.Response { - id, err := parsing.GetIntParam(r, "id") + params := r.Context().Value(CtxParams).(params.Params) + id, err := params.GetInt("id") if err != nil { return spec.NewError(10, "please provide an `id` parameter") } @@ -67,7 +68,7 @@ func (c *Controller) ServeStream(w http.ResponseWriter, r *http.Request) *spec.R http.ServeFile(w, r, absPath) // // after we've served the file, mark the album as played - user := r.Context().Value(key.User).(*model.User) + user := r.Context().Value(CtxUser).(*model.User) play := model.Play{ AlbumID: track.Album.ID, UserID: user.ID, diff --git a/server/ctrlsubsonic/middleware.go b/server/ctrlsubsonic/middleware.go index 4037f48..b9f68a3 100644 --- a/server/ctrlsubsonic/middleware.go +++ b/server/ctrlsubsonic/middleware.go @@ -6,28 +6,11 @@ import ( "encoding/hex" "fmt" "net/http" - "net/url" + "senan.xyz/g/gonic/server/ctrlsubsonic/params" "senan.xyz/g/gonic/server/ctrlsubsonic/spec" - "senan.xyz/g/gonic/server/key" - "senan.xyz/g/gonic/server/parsing" ) -var requiredParameters = []string{ - "u", "v", "c", -} - -func checkHasAllParams(params url.Values) error { - for _, req := range requiredParameters { - param := params.Get(req) - if param != "" { - continue - } - return fmt.Errorf("please provide a `%s` parameter", req) - } - return nil -} - func checkCredsToken(password, token, salt string) bool { toHash := fmt.Sprintf("%s%s", password, salt) hash := md5.Sum([]byte(toHash)) @@ -43,16 +26,34 @@ func checkCredsBasic(password, given string) bool { return password == given } -func (c *Controller) WithValidSubsonicArgs(next http.Handler) http.Handler { +func (c *Controller) WithParams(next http.Handler) http.Handler { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if err := checkHasAllParams(r.URL.Query()); err != nil { - writeResp(w, r, spec.NewError(10, err.Error())) + params := params.New(r) + withParams := context.WithValue(r.Context(), CtxParams, params) + next.ServeHTTP(w, r.WithContext(withParams)) + }) +} + +func (c *Controller) WithUser(next http.Handler) http.Handler { + requiredParameters := []string{ + "u", "v", "c", + } + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + params := r.Context().Value(CtxParams).(params.Params) + for _, req := range requiredParameters { + param := params.Get(req) + if param != "" { + continue + } + writeResp(w, r, spec.NewError(10, "please provide a `%s` parameter", req)) return } - username := parsing.GetStrParam(r, "u") - password := parsing.GetStrParam(r, "p") - token := parsing.GetStrParam(r, "t") - salt := parsing.GetStrParam(r, "s") + // + username := params.Get("u") + password := params.Get("p") + token := params.Get("t") + salt := params.Get("s") + // passwordAuth := token == "" && salt == "" tokenAuth := password == "" if tokenAuth == passwordAuth { @@ -74,7 +75,7 @@ func (c *Controller) WithValidSubsonicArgs(next http.Handler) http.Handler { writeResp(w, r, spec.NewError(40, "invalid password")) return } - withUser := context.WithValue(r.Context(), key.User, user) + withUser := context.WithValue(r.Context(), CtxUser, user) next.ServeHTTP(w, r.WithContext(withUser)) }) } diff --git a/server/ctrlsubsonic/params/params.go b/server/ctrlsubsonic/params/params.go new file mode 100644 index 0000000..66220fd --- /dev/null +++ b/server/ctrlsubsonic/params/params.go @@ -0,0 +1,57 @@ +package params + +import ( + "fmt" + "net/http" + "net/url" + "strconv" +) + +type Params struct { + values url.Values +} + +func New(r *http.Request) Params { + // first load params from the url + params := r.URL.Query() + // also if there's any in the post body, use those too + if err := r.ParseForm(); err != nil { + return Params{params} + } + for k, v := range r.Form { + params[k] = v + } + return Params{params} +} + +func (p Params) Get(key string) string { + return p.values.Get(key) +} + +func (p Params) GetOr(key, or string) string { + val := p.Get(key) + if val == "" { + return or + } + return val +} + +func (p Params) GetInt(key string) (int, error) { + strVal := p.values.Get(key) + if strVal == "" { + return 0, fmt.Errorf("no param with key `%s`", key) + } + val, err := strconv.Atoi(strVal) + if err != nil { + return 0, fmt.Errorf("not an int `%s`", strVal) + } + return val, nil +} + +func (p Params) GetIntOr(key string, or int) int { + val, err := p.GetInt(key) + if err != nil { + return or + } + return val +} diff --git a/server/key/key.go b/server/key/key.go deleted file mode 100644 index 5f2b94f..0000000 --- a/server/key/key.go +++ /dev/null @@ -1,8 +0,0 @@ -package key - -type Key int - -const ( - User Key = iota - Session -) diff --git a/server/parsing/parsing.go b/server/parsing/parsing.go deleted file mode 100644 index 3f2438a..0000000 --- a/server/parsing/parsing.go +++ /dev/null @@ -1,57 +0,0 @@ -package parsing - -import ( - "fmt" - "net/http" - "strconv" -) - -func GetStrParam(r *http.Request, key string) string { - return r.URL.Query().Get(key) -} - -func GetStrParamOr(r *http.Request, key, or string) string { - val := GetStrParam(r, key) - if val == "" { - return or - } - return val -} - -func GetIntParam(r *http.Request, key string) (int, error) { - strVal := r.URL.Query().Get(key) - if strVal == "" { - return 0, fmt.Errorf("no param with key `%s`", key) - } - val, err := strconv.Atoi(strVal) - if err != nil { - return 0, fmt.Errorf("not an int `%s`", strVal) - } - return val, nil -} - -func GetIntParamOr(r *http.Request, key string, or int) int { - val, err := GetIntParam(r, key) - if err != nil { - return or - } - return val -} - -func GetFirstParamOf(r *http.Request, keys ...string) []string { - for _, key := range keys { - if val, ok := r.URL.Query()[key]; ok { - return val - } - } - return nil -} - -func GetFirstIntParamOf(r *http.Request, keys ...string) (int, bool) { - for _, key := range keys { - if v, err := GetIntParam(r, key); err == nil { - return v, true - } - } - return 0, false -}