add param abstraction to request context

This commit is contained in:
sentriz
2020-01-17 20:09:08 +00:00
parent 4af1e43389
commit 8e5d397082
13 changed files with 207 additions and 184 deletions

View File

@@ -19,7 +19,6 @@ import (
"senan.xyz/g/gonic/assets" "senan.xyz/g/gonic/assets"
"senan.xyz/g/gonic/model" "senan.xyz/g/gonic/model"
"senan.xyz/g/gonic/server/ctrlbase" "senan.xyz/g/gonic/server/ctrlbase"
"senan.xyz/g/gonic/server/key"
"senan.xyz/g/gonic/version" "senan.xyz/g/gonic/version"
) )
@@ -27,6 +26,13 @@ func init() {
gob.Register(&Flash{}) gob.Register(&Flash{})
} }
type CtxKey int
const (
CtxUser CtxKey = iota
CtxSession
)
// extendFromPaths /extends/ the given template for every asset // extendFromPaths /extends/ the given template for every asset
// with given prefix // with given prefix
func extendFromPaths(b *template.Template, p string) *template.Template { func extendFromPaths(b *template.Template, p string) *template.Template {
@@ -124,7 +130,7 @@ type Response struct {
func (c *Controller) H(h adminHandler) http.Handler { func (c *Controller) H(h adminHandler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
resp := h(r) resp := h(r)
session, ok := r.Context().Value(key.Session).(*sessions.Session) session, ok := r.Context().Value(CtxSession).(*sessions.Session)
if ok { if ok {
sessAddFlashN(session, resp.flashN) sessAddFlashN(session, resp.flashN)
sessAddFlashW(session, resp.flashW) sessAddFlashW(session, resp.flashW)
@@ -156,7 +162,7 @@ func (c *Controller) H(h adminHandler) http.Handler {
return return
} }
} }
if user, ok := r.Context().Value(key.User).(*model.User); ok { if user, ok := r.Context().Value(CtxUser).(*model.User); ok {
resp.data.User = user resp.data.User = user
} }
buff := c.buffPool.Get() buff := c.buffPool.Get()

View File

@@ -9,7 +9,6 @@ import (
"senan.xyz/g/gonic/model" "senan.xyz/g/gonic/model"
"senan.xyz/g/gonic/scanner" "senan.xyz/g/gonic/scanner"
"senan.xyz/g/gonic/server/key"
"senan.xyz/g/gonic/server/lastfm" "senan.xyz/g/gonic/server/lastfm"
) )
@@ -60,7 +59,7 @@ func (c *Controller) ServeHome(r *http.Request) *Response {
} }
// //
// playlists box // playlists box
user := r.Context().Value(key.User).(*model.User) user := r.Context().Value(CtxUser).(*model.User)
c.DB. c.DB.
Select("*, count(items.id) as track_count"). Select("*, count(items.id) as track_count").
Joins(` Joins(`
@@ -92,7 +91,7 @@ func (c *Controller) ServeChangeOwnPasswordDo(r *http.Request) *Response {
flashW: []string{err.Error()}, flashW: []string{err.Error()},
} }
} }
user := r.Context().Value(key.User).(*model.User) user := r.Context().Value(CtxUser).(*model.User)
user.Password = passwordOne user.Password = passwordOne
c.DB.Save(user) c.DB.Save(user)
return &Response{redirect: "/admin/home"} return &Response{redirect: "/admin/home"}
@@ -117,14 +116,14 @@ func (c *Controller) ServeLinkLastFMDo(r *http.Request) *Response {
flashW: []string{err.Error()}, flashW: []string{err.Error()},
} }
} }
user := r.Context().Value(key.User).(*model.User) user := r.Context().Value(CtxUser).(*model.User)
user.LastFMSession = sessionKey user.LastFMSession = sessionKey
c.DB.Save(&user) c.DB.Save(&user)
return &Response{redirect: "/admin/home"} return &Response{redirect: "/admin/home"}
} }
func (c *Controller) ServeUnlinkLastFMDo(r *http.Request) *Response { func (c *Controller) ServeUnlinkLastFMDo(r *http.Request) *Response {
user := r.Context().Value(key.User).(*model.User) user := r.Context().Value(CtxUser).(*model.User)
user.LastFMSession = "" user.LastFMSession = ""
c.DB.Save(&user) c.DB.Save(&user)
return &Response{redirect: "/admin/home"} return &Response{redirect: "/admin/home"}
@@ -241,7 +240,7 @@ func (c *Controller) ServeUpdateLastFMAPIKey(r *http.Request) *Response {
data.CurrentLastFMAPIKey = c.DB.GetSetting("lastfm_api_key") data.CurrentLastFMAPIKey = c.DB.GetSetting("lastfm_api_key")
data.CurrentLastFMAPISecret = c.DB.GetSetting("lastfm_secret") data.CurrentLastFMAPISecret = c.DB.GetSetting("lastfm_secret")
return &Response{ return &Response{
template: "update_lastfm_api_key.tmpl", template: "update_lastfm_api_key",
data: data, data: data,
} }
} }
@@ -285,7 +284,7 @@ func (c *Controller) ServeUploadPlaylistDo(r *http.Request) *Response {
code: 500, code: 500,
} }
} }
user := r.Context().Value(key.User).(*model.User) user := r.Context().Value(CtxUser).(*model.User)
var playlistCount int var playlistCount int
var errors []string var errors []string
for _, headers := range r.MultipartForm.File { for _, headers := range r.MultipartForm.File {

View File

@@ -4,12 +4,10 @@ import (
"net/http" "net/http"
"github.com/gorilla/sessions" "github.com/gorilla/sessions"
"senan.xyz/g/gonic/server/key"
) )
func (c *Controller) ServeLoginDo(w http.ResponseWriter, r *http.Request) { func (c *Controller) ServeLoginDo(w http.ResponseWriter, r *http.Request) {
session := r.Context().Value(key.Session).(*sessions.Session) session := r.Context().Value(CtxSession).(*sessions.Session)
username := r.FormValue("username") username := r.FormValue("username")
password := r.FormValue("password") password := r.FormValue("password")
if username == "" || password == "" { if username == "" || password == "" {
@@ -34,7 +32,7 @@ func (c *Controller) ServeLoginDo(w http.ResponseWriter, r *http.Request) {
} }
func (c *Controller) ServeLogout(w http.ResponseWriter, r *http.Request) { func (c *Controller) ServeLogout(w http.ResponseWriter, r *http.Request) {
session := r.Context().Value(key.Session).(*sessions.Session) session := r.Context().Value(CtxSession).(*sessions.Session)
session.Options.MaxAge = -1 session.Options.MaxAge = -1
sessLogSave(session, w, r) sessLogSave(session, w, r)
http.Redirect(w, r, "/admin/login", http.StatusSeeOther) http.Redirect(w, r, "/admin/login", http.StatusSeeOther)

View File

@@ -7,13 +7,12 @@ import (
"github.com/gorilla/sessions" "github.com/gorilla/sessions"
"senan.xyz/g/gonic/model" "senan.xyz/g/gonic/model"
"senan.xyz/g/gonic/server/key"
) )
func (c *Controller) WithSession(next http.Handler) http.Handler { func (c *Controller) WithSession(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
session, _ := c.sessDB.Get(r, "gonic") session, _ := c.sessDB.Get(r, "gonic")
withSession := context.WithValue(r.Context(), key.Session, session) withSession := context.WithValue(r.Context(), CtxSession, session)
next.ServeHTTP(w, r.WithContext(withSession)) next.ServeHTTP(w, r.WithContext(withSession))
}) })
} }
@@ -21,7 +20,7 @@ func (c *Controller) WithSession(next http.Handler) http.Handler {
func (c *Controller) WithUserSession(next http.Handler) http.Handler { func (c *Controller) WithUserSession(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// session exists at this point // session exists at this point
session := r.Context().Value(key.Session).(*sessions.Session) session := r.Context().Value(CtxSession).(*sessions.Session)
username, ok := session.Values["user"].(string) username, ok := session.Values["user"].(string)
if !ok { if !ok {
sessAddFlashW(session, []string{"you are not authenticated"}) sessAddFlashW(session, []string{"you are not authenticated"})
@@ -39,7 +38,7 @@ func (c *Controller) WithUserSession(next http.Handler) http.Handler {
http.Redirect(w, r, "/admin/login", http.StatusSeeOther) http.Redirect(w, r, "/admin/login", http.StatusSeeOther)
return return
} }
withUser := context.WithValue(r.Context(), key.User, user) withUser := context.WithValue(r.Context(), CtxUser, user)
next.ServeHTTP(w, r.WithContext(withUser)) next.ServeHTTP(w, r.WithContext(withUser))
}) })
} }
@@ -47,8 +46,8 @@ func (c *Controller) WithUserSession(next http.Handler) http.Handler {
func (c *Controller) WithAdminSession(next http.Handler) http.Handler { func (c *Controller) WithAdminSession(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
// session and user exist at this point // session and user exist at this point
session := r.Context().Value(key.Session).(*sessions.Session) session := r.Context().Value(CtxSession).(*sessions.Session)
user := r.Context().Value(key.User).(*model.User) user := r.Context().Value(CtxUser).(*model.User)
if !user.IsAdmin { if !user.IsAdmin {
sessAddFlashW(session, []string{"you are not an admin"}) sessAddFlashW(session, []string{"you are not an admin"})
sessLogSave(session, w, r) sessLogSave(session, w, r)

View File

@@ -10,8 +10,16 @@ import (
"github.com/pkg/errors" "github.com/pkg/errors"
"senan.xyz/g/gonic/server/ctrlbase" "senan.xyz/g/gonic/server/ctrlbase"
"senan.xyz/g/gonic/server/ctrlsubsonic/params"
"senan.xyz/g/gonic/server/ctrlsubsonic/spec" "senan.xyz/g/gonic/server/ctrlsubsonic/spec"
"senan.xyz/g/gonic/server/parsing" )
type CtxKey int
const (
CtxUser CtxKey = iota
CtxSession
CtxParams
) )
type Controller struct { type Controller struct {
@@ -42,12 +50,10 @@ func (ew *errWriter) write(buf []byte) {
} }
func writeResp(w http.ResponseWriter, r *http.Request, resp *spec.Response) error { func writeResp(w http.ResponseWriter, r *http.Request, resp *spec.Response) error {
if resp.Error != nil {
w.WriteHeader(http.StatusBadRequest)
}
res := metaResponse{Response: resp} res := metaResponse{Response: resp}
params := r.Context().Value(CtxParams).(params.Params)
ew := &errWriter{w: w} ew := &errWriter{w: w}
switch parsing.GetStrParam(r, "f") { switch params.Get("f") {
case "json": case "json":
w.Header().Set("Content-Type", "application/json") w.Header().Set("Content-Type", "application/json")
data, err := json.Marshal(res) data, err := json.Marshal(res)
@@ -62,7 +68,7 @@ func writeResp(w http.ResponseWriter, r *http.Request, resp *spec.Response) erro
return errors.Wrap(err, "marshal to jsonp") return errors.Wrap(err, "marshal to jsonp")
} }
// TODO: error if no callback provided instead of using a default // TODO: error if no callback provided instead of using a default
pCall := parsing.GetStrParamOr(r, "callback", "cb") pCall := params.GetOr("callback", "cb")
ew.write([]byte(pCall)) ew.write([]byte(pCall))
ew.write([]byte("(")) ew.write([]byte("("))
ew.write(data) ew.write(data)

View File

@@ -9,9 +9,8 @@ import (
"github.com/jinzhu/gorm" "github.com/jinzhu/gorm"
"senan.xyz/g/gonic/model" "senan.xyz/g/gonic/model"
"senan.xyz/g/gonic/server/ctrlsubsonic/params"
"senan.xyz/g/gonic/server/ctrlsubsonic/spec" "senan.xyz/g/gonic/server/ctrlsubsonic/spec"
"senan.xyz/g/gonic/server/key"
"senan.xyz/g/gonic/server/parsing"
) )
// the subsonic spec metions "artist" a lot when talking about the // the subsonic spec metions "artist" a lot when talking about the
@@ -60,7 +59,8 @@ func (c *Controller) ServeGetIndexes(r *http.Request) *spec.Response {
} }
func (c *Controller) ServeGetMusicDirectory(r *http.Request) *spec.Response { func (c *Controller) ServeGetMusicDirectory(r *http.Request) *spec.Response {
id, err := parsing.GetIntParam(r, "id") params := r.Context().Value(CtxParams).(params.Params)
id, err := params.GetInt("id")
if err != nil { if err != nil {
return spec.NewError(10, "please provide an `id` parameter") return spec.NewError(10, "please provide an `id` parameter")
} }
@@ -86,7 +86,7 @@ func (c *Controller) ServeGetMusicDirectory(r *http.Request) *spec.Response {
Find(&childTracks) Find(&childTracks)
for _, c := range childTracks { for _, c := range childTracks {
toAppend := spec.NewTCTrackByFolder(c, folder) toAppend := spec.NewTCTrackByFolder(c, folder)
if parsing.GetStrParam(r, "c") == "Jamstash" { if params.Get("c") == "Jamstash" {
// jamstash thinks it can't play flacs // jamstash thinks it can't play flacs
toAppend.ContentType = "audio/mpeg" toAppend.ContentType = "audio/mpeg"
toAppend.Suffix = "mp3" toAppend.Suffix = "mp3"
@@ -103,7 +103,8 @@ func (c *Controller) ServeGetMusicDirectory(r *http.Request) *spec.Response {
// changes to this function should be reflected in in _by_tags.go's // changes to this function should be reflected in in _by_tags.go's
// getAlbumListTwo() function // getAlbumListTwo() function
func (c *Controller) ServeGetAlbumList(r *http.Request) *spec.Response { func (c *Controller) ServeGetAlbumList(r *http.Request) *spec.Response {
listType := parsing.GetStrParam(r, "type") params := r.Context().Value(CtxParams).(params.Params)
listType := params.Get("type")
if listType == "" { if listType == "" {
return spec.NewError(10, "please provide a `type` parameter") return spec.NewError(10, "please provide a `type` parameter")
} }
@@ -117,7 +118,7 @@ func (c *Controller) ServeGetAlbumList(r *http.Request) *spec.Response {
case "alphabeticalByName": case "alphabeticalByName":
q = q.Order("right_path") q = q.Order("right_path")
case "frequent": case "frequent":
user := r.Context().Value(key.User).(*model.User) user := r.Context().Value(CtxUser).(*model.User)
q = q.Joins(` q = q.Joins(`
JOIN plays JOIN plays
ON albums.id = plays.album_id AND plays.user_id = ?`, ON albums.id = plays.album_id AND plays.user_id = ?`,
@@ -128,7 +129,7 @@ func (c *Controller) ServeGetAlbumList(r *http.Request) *spec.Response {
case "random": case "random":
q = q.Order(gorm.Expr("random()")) q = q.Order(gorm.Expr("random()"))
case "recent": case "recent":
user := r.Context().Value(key.User).(*model.User) user := r.Context().Value(CtxUser).(*model.User)
q = q.Joins(` q = q.Joins(`
JOIN plays JOIN plays
ON albums.id = plays.album_id AND plays.user_id = ?`, ON albums.id = plays.album_id AND plays.user_id = ?`,
@@ -140,8 +141,8 @@ func (c *Controller) ServeGetAlbumList(r *http.Request) *spec.Response {
var folders []*model.Album var folders []*model.Album
q. q.
Where("albums.tag_artist_id IS NOT NULL"). Where("albums.tag_artist_id IS NOT NULL").
Offset(parsing.GetIntParamOr(r, "offset", 0)). Offset(params.GetIntOr("offset", 0)).
Limit(parsing.GetIntParamOr(r, "size", 10)). Limit(params.GetIntOr("size", 10)).
Preload("Parent"). Preload("Parent").
Find(&folders) Find(&folders)
sub := spec.NewResponse() sub := spec.NewResponse()
@@ -155,7 +156,8 @@ func (c *Controller) ServeGetAlbumList(r *http.Request) *spec.Response {
} }
func (c *Controller) ServeSearchTwo(r *http.Request) *spec.Response { func (c *Controller) ServeSearchTwo(r *http.Request) *spec.Response {
query := parsing.GetStrParam(r, "query") params := r.Context().Value(CtxParams).(params.Params)
query := params.Get("query")
if query == "" { if query == "" {
return spec.NewError(10, "please provide a `query` parameter") return spec.NewError(10, "please provide a `query` parameter")
} }
@@ -170,8 +172,8 @@ func (c *Controller) ServeSearchTwo(r *http.Request) *spec.Response {
AND (right_path LIKE ? OR AND (right_path LIKE ? OR
right_path_u_dec LIKE ?) right_path_u_dec LIKE ?)
`, query, query). `, query, query).
Offset(parsing.GetIntParamOr(r, "artistOffset", 0)). Offset(params.GetIntOr("artistOffset", 0)).
Limit(parsing.GetIntParamOr(r, "artistCount", 20)). Limit(params.GetIntOr("artistCount", 20)).
Find(&artists) Find(&artists)
for _, a := range artists { for _, a := range artists {
results.Artists = append(results.Artists, results.Artists = append(results.Artists,
@@ -186,8 +188,8 @@ func (c *Controller) ServeSearchTwo(r *http.Request) *spec.Response {
AND (right_path LIKE ? OR AND (right_path LIKE ? OR
right_path_u_dec LIKE ?) right_path_u_dec LIKE ?)
`, query, query). `, query, query).
Offset(parsing.GetIntParamOr(r, "albumOffset", 0)). Offset(params.GetIntOr("albumOffset", 0)).
Limit(parsing.GetIntParamOr(r, "albumCount", 20)). Limit(params.GetIntOr("albumCount", 20)).
Find(&albums) Find(&albums)
for _, a := range albums { for _, a := range albums {
results.Albums = append(results.Albums, spec.NewTCAlbumByFolder(a)) results.Albums = append(results.Albums, spec.NewTCAlbumByFolder(a))
@@ -201,8 +203,8 @@ func (c *Controller) ServeSearchTwo(r *http.Request) *spec.Response {
filename LIKE ? OR filename LIKE ? OR
filename_u_dec LIKE ? filename_u_dec LIKE ?
`, query, query). `, query, query).
Offset(parsing.GetIntParamOr(r, "songOffset", 0)). Offset(params.GetIntOr("songOffset", 0)).
Limit(parsing.GetIntParamOr(r, "songCount", 20)). Limit(params.GetIntOr("songCount", 20)).
Find(&tracks) Find(&tracks)
for _, t := range tracks { for _, t := range tracks {
results.Tracks = append(results.Tracks, results.Tracks = append(results.Tracks,

View File

@@ -9,9 +9,8 @@ import (
"github.com/jinzhu/gorm" "github.com/jinzhu/gorm"
"senan.xyz/g/gonic/model" "senan.xyz/g/gonic/model"
"senan.xyz/g/gonic/server/ctrlsubsonic/params"
"senan.xyz/g/gonic/server/ctrlsubsonic/spec" "senan.xyz/g/gonic/server/ctrlsubsonic/spec"
"senan.xyz/g/gonic/server/key"
"senan.xyz/g/gonic/server/parsing"
) )
func (c *Controller) ServeGetArtists(r *http.Request) *spec.Response { func (c *Controller) ServeGetArtists(r *http.Request) *spec.Response {
@@ -52,7 +51,8 @@ func (c *Controller) ServeGetArtists(r *http.Request) *spec.Response {
} }
func (c *Controller) ServeGetArtist(r *http.Request) *spec.Response { func (c *Controller) ServeGetArtist(r *http.Request) *spec.Response {
id, err := parsing.GetIntParam(r, "id") params := r.Context().Value(CtxParams).(params.Params)
id, err := params.GetInt("id")
if err != nil { if err != nil {
return spec.NewError(10, "please provide an `id` parameter") return spec.NewError(10, "please provide an `id` parameter")
} }
@@ -70,7 +70,8 @@ func (c *Controller) ServeGetArtist(r *http.Request) *spec.Response {
} }
func (c *Controller) ServeGetAlbum(r *http.Request) *spec.Response { func (c *Controller) ServeGetAlbum(r *http.Request) *spec.Response {
id, err := parsing.GetIntParam(r, "id") params := r.Context().Value(CtxParams).(params.Params)
id, err := params.GetInt("id")
if err != nil { if err != nil {
return spec.NewError(10, "please provide an `id` parameter") return spec.NewError(10, "please provide an `id` parameter")
} }
@@ -97,7 +98,8 @@ func (c *Controller) ServeGetAlbum(r *http.Request) *spec.Response {
// changes to this function should be reflected in in _by_folder.go's // changes to this function should be reflected in in _by_folder.go's
// getAlbumList() function // getAlbumList() function
func (c *Controller) ServeGetAlbumListTwo(r *http.Request) *spec.Response { func (c *Controller) ServeGetAlbumListTwo(r *http.Request) *spec.Response {
listType := parsing.GetStrParam(r, "type") params := r.Context().Value(CtxParams).(params.Params)
listType := params.Get("type")
if listType == "" { if listType == "" {
return spec.NewError(10, "please provide a `type` parameter") return spec.NewError(10, "please provide a `type` parameter")
} }
@@ -113,11 +115,11 @@ func (c *Controller) ServeGetAlbumListTwo(r *http.Request) *spec.Response {
case "byYear": case "byYear":
q = q.Where( q = q.Where(
"tag_year BETWEEN ? AND ?", "tag_year BETWEEN ? AND ?",
parsing.GetIntParamOr(r, "fromYear", 1800), params.GetIntOr("fromYear", 1800),
parsing.GetIntParamOr(r, "toYear", 2200)) params.GetIntOr("toYear", 2200))
q = q.Order("tag_year") q = q.Order("tag_year")
case "frequent": case "frequent":
user := r.Context().Value(key.User).(*model.User) user := r.Context().Value(CtxUser).(*model.User)
q = q.Joins(` q = q.Joins(`
JOIN plays JOIN plays
ON albums.id = plays.album_id AND plays.user_id = ?`, ON albums.id = plays.album_id AND plays.user_id = ?`,
@@ -128,7 +130,7 @@ func (c *Controller) ServeGetAlbumListTwo(r *http.Request) *spec.Response {
case "random": case "random":
q = q.Order(gorm.Expr("random()")) q = q.Order(gorm.Expr("random()"))
case "recent": case "recent":
user := r.Context().Value(key.User).(*model.User) user := r.Context().Value(CtxUser).(*model.User)
q = q.Joins(` q = q.Joins(`
JOIN plays JOIN plays
ON albums.id = plays.album_id AND plays.user_id = ?`, ON albums.id = plays.album_id AND plays.user_id = ?`,
@@ -140,8 +142,8 @@ func (c *Controller) ServeGetAlbumListTwo(r *http.Request) *spec.Response {
var albums []*model.Album var albums []*model.Album
q. q.
Where("albums.tag_artist_id IS NOT NULL"). Where("albums.tag_artist_id IS NOT NULL").
Offset(parsing.GetIntParamOr(r, "offset", 0)). Offset(params.GetIntOr("offset", 0)).
Limit(parsing.GetIntParamOr(r, "size", 10)). Limit(params.GetIntOr("size", 10)).
Preload("TagArtist"). Preload("TagArtist").
Find(&albums) Find(&albums)
sub := spec.NewResponse() sub := spec.NewResponse()
@@ -155,7 +157,8 @@ func (c *Controller) ServeGetAlbumListTwo(r *http.Request) *spec.Response {
} }
func (c *Controller) ServeSearchThree(r *http.Request) *spec.Response { func (c *Controller) ServeSearchThree(r *http.Request) *spec.Response {
query := parsing.GetStrParam(r, "query") params := r.Context().Value(CtxParams).(params.Params)
query := params.Get("query")
if query == "" { if query == "" {
return spec.NewError(10, "please provide a `query` parameter") return spec.NewError(10, "please provide a `query` parameter")
} }
@@ -170,8 +173,8 @@ func (c *Controller) ServeSearchThree(r *http.Request) *spec.Response {
name LIKE ? OR name LIKE ? OR
name_u_dec LIKE ? name_u_dec LIKE ?
`, query, query). `, query, query).
Offset(parsing.GetIntParamOr(r, "artistOffset", 0)). Offset(params.GetIntOr("artistOffset", 0)).
Limit(parsing.GetIntParamOr(r, "artistCount", 20)). Limit(params.GetIntOr("artistCount", 20)).
Find(&artists) Find(&artists)
for _, a := range artists { for _, a := range artists {
results.Artists = append(results.Artists, results.Artists = append(results.Artists,
@@ -186,8 +189,8 @@ func (c *Controller) ServeSearchThree(r *http.Request) *spec.Response {
tag_title LIKE ? OR tag_title LIKE ? OR
tag_title_u_dec LIKE ? tag_title_u_dec LIKE ?
`, query, query). `, query, query).
Offset(parsing.GetIntParamOr(r, "albumOffset", 0)). Offset(params.GetIntOr("albumOffset", 0)).
Limit(parsing.GetIntParamOr(r, "albumCount", 20)). Limit(params.GetIntOr("albumCount", 20)).
Find(&albums) Find(&albums)
for _, a := range albums { for _, a := range albums {
results.Albums = append(results.Albums, results.Albums = append(results.Albums,
@@ -202,8 +205,8 @@ func (c *Controller) ServeSearchThree(r *http.Request) *spec.Response {
tag_title LIKE ? OR tag_title LIKE ? OR
tag_title_u_dec LIKE ? tag_title_u_dec LIKE ?
`, query, query). `, query, query).
Offset(parsing.GetIntParamOr(r, "songOffset", 0)). Offset(params.GetIntOr("songOffset", 0)).
Limit(parsing.GetIntParamOr(r, "songCount", 20)). Limit(params.GetIntOr("songCount", 20)).
Find(&tracks) Find(&tracks)
for _, t := range tracks { for _, t := range tracks {
results.Tracks = append(results.Tracks, results.Tracks = append(results.Tracks,

View File

@@ -11,10 +11,9 @@ import (
"senan.xyz/g/gonic/model" "senan.xyz/g/gonic/model"
"senan.xyz/g/gonic/scanner" "senan.xyz/g/gonic/scanner"
"senan.xyz/g/gonic/server/ctrlsubsonic/params"
"senan.xyz/g/gonic/server/ctrlsubsonic/spec" "senan.xyz/g/gonic/server/ctrlsubsonic/spec"
"senan.xyz/g/gonic/server/key"
"senan.xyz/g/gonic/server/lastfm" "senan.xyz/g/gonic/server/lastfm"
"senan.xyz/g/gonic/server/parsing"
) )
func lowerUDecOrHash(in string) string { func lowerUDecOrHash(in string) string {
@@ -38,12 +37,13 @@ func (c *Controller) ServePing(r *http.Request) *spec.Response {
} }
func (c *Controller) ServeScrobble(r *http.Request) *spec.Response { func (c *Controller) ServeScrobble(r *http.Request) *spec.Response {
id, err := parsing.GetIntParam(r, "id") params := r.Context().Value(CtxParams).(params.Params)
id, err := params.GetInt("id")
if err != nil { if err != nil {
return spec.NewError(10, "please provide an `id` parameter") return spec.NewError(10, "please provide an `id` parameter")
} }
// fetch user to get lastfm session // fetch user to get lastfm session
user := r.Context().Value(key.User).(*model.User) user := r.Context().Value(CtxUser).(*model.User)
if user.LastFMSession == "" { if user.LastFMSession == "" {
return spec.NewError(0, "you don't have a last.fm session") return spec.NewError(0, "you don't have a last.fm session")
} }
@@ -61,8 +61,8 @@ func (c *Controller) ServeScrobble(r *http.Request) *spec.Response {
track, track,
// clients will provide time in miliseconds, so use that or // clients will provide time in miliseconds, so use that or
// instead convert UnixNano to miliseconds // instead convert UnixNano to miliseconds
parsing.GetIntParamOr(r, "time", int(time.Now().UnixNano()/1e6)), params.GetIntOr("time", int(time.Now().UnixNano()/1e6)),
parsing.GetStrParamOr(r, "submission", "true") != "false", params.GetOr("submission", "true") != "false",
) )
if err != nil { if err != nil {
return spec.NewError(0, "error when submitting: %v", err) return spec.NewError(0, "error when submitting: %v", err)
@@ -103,7 +103,7 @@ func (c *Controller) ServeGetScanStatus(r *http.Request) *spec.Response {
} }
func (c *Controller) ServeGetUser(r *http.Request) *spec.Response { func (c *Controller) ServeGetUser(r *http.Request) *spec.Response {
user := r.Context().Value(key.User).(*model.User) user := r.Context().Value(CtxUser).(*model.User)
sub := spec.NewResponse() sub := spec.NewResponse()
sub.User = &spec.User{ sub.User = &spec.User{
Username: user.Name, Username: user.Name,
@@ -119,7 +119,7 @@ func (c *Controller) ServeNotFound(r *http.Request) *spec.Response {
} }
func (c *Controller) ServeGetPlaylists(r *http.Request) *spec.Response { func (c *Controller) ServeGetPlaylists(r *http.Request) *spec.Response {
user := r.Context().Value(key.User).(*model.User) user := r.Context().Value(CtxUser).(*model.User)
var playlists []*model.Playlist var playlists []*model.Playlist
c.DB. c.DB.
Where("user_id = ?", user.ID). Where("user_id = ?", user.ID).
@@ -136,7 +136,8 @@ func (c *Controller) ServeGetPlaylists(r *http.Request) *spec.Response {
} }
func (c *Controller) ServeGetPlaylist(r *http.Request) *spec.Response { func (c *Controller) ServeGetPlaylist(r *http.Request) *spec.Response {
playlistID, err := parsing.GetIntParam(r, "id") params := r.Context().Value(CtxParams).(params.Params)
playlistID, err := params.GetInt("id")
if err != nil { if err != nil {
return spec.NewError(10, "please provide an `id` parameter") return spec.NewError(10, "please provide an `id` parameter")
} }
@@ -159,7 +160,7 @@ func (c *Controller) ServeGetPlaylist(r *http.Request) *spec.Response {
Order("playlist_items.created_at"). Order("playlist_items.created_at").
Preload("Album"). Preload("Album").
Find(&tracks) Find(&tracks)
user := r.Context().Value(key.User).(*model.User) user := r.Context().Value(CtxUser).(*model.User)
sub := spec.NewResponse() sub := spec.NewResponse()
sub.Playlist = spec.NewPlaylist(&playlist) sub.Playlist = spec.NewPlaylist(&playlist)
sub.Playlist.Owner = user.Name sub.Playlist.Owner = user.Name
@@ -171,23 +172,32 @@ func (c *Controller) ServeGetPlaylist(r *http.Request) *spec.Response {
} }
func (c *Controller) ServeUpdatePlaylist(r *http.Request) *spec.Response { func (c *Controller) ServeUpdatePlaylist(r *http.Request) *spec.Response {
playlistID, _ := parsing.GetFirstIntParamOf(r, "id", "playlistId") params := r.Context().Value(CtxParams).(params.Params)
user := r.Context().Value(CtxUser).(*model.User)
var playlistID int
for _, key := range []string{"id", "playlistId"} {
if val, err := params.GetInt(key); err != nil {
playlistID = val
}
}
// begin updating meta // begin updating meta
// playlist ID may still be 0 here, if so it's okay,
// we get a new playlist
playlist := &model.Playlist{} playlist := &model.Playlist{}
c.DB. c.DB.
Where("id = ?", playlistID). Where("id = ?", playlistID).
First(playlist) First(playlist)
user := r.Context().Value(key.User).(*model.User)
playlist.UserID = user.ID playlist.UserID = user.ID
if name := parsing.GetStrParam(r, "name"); name != "" { if val := params.Get("name"); val != "" {
playlist.Name = name playlist.Name = val
} }
if comment := parsing.GetStrParam(r, "comment"); comment != "" { if val := params.Get("comment"); val != "" {
playlist.Comment = comment playlist.Comment = val
} }
c.DB.Save(playlist) c.DB.Save(playlist)
// begin delete tracks // begin delete tracks
if indexes, ok := r.URL.Query()["songIndexToRemove"]; ok { indexes, ok := params.GetList("songIndexToRemove")
if ok {
trackIDs := []int{} trackIDs := []int{}
c.DB. c.DB.
Order("created_at"). Order("created_at").
@@ -204,24 +214,30 @@ func (c *Controller) ServeUpdatePlaylist(r *http.Request) *spec.Response {
} }
} }
// begin add tracks // begin add tracks
if toAdd := parsing.GetFirstParamOf(r, "songId", "songIdToAdd"); toAdd != nil { var toAdd []string
for _, trackIDStr := range toAdd { for _, val := range []string{"songId", "songIdToAdd"} {
trackID, err := strconv.Atoi(trackIDStr) toAdd, ok := params.GetList(val)
if err != nil { if ok {
continue break
}
c.DB.Save(&model.PlaylistItem{
PlaylistID: playlist.ID,
TrackID: trackID,
})
} }
} }
for _, trackIDStr := range toAdd {
trackID, err := strconv.Atoi(trackIDStr)
if err != nil {
continue
}
c.DB.Save(&model.PlaylistItem{
PlaylistID: playlist.ID,
TrackID: trackID,
})
}
return spec.NewResponse() return spec.NewResponse()
} }
func (c *Controller) ServeDeletePlaylist(r *http.Request) *spec.Response { func (c *Controller) ServeDeletePlaylist(r *http.Request) *spec.Response {
params := r.Context().Value(CtxParams).(params.Params)
c.DB. c.DB.
Where("id = ?", parsing.GetIntParamOr(r, "id", 0)). Where("id = ?", params.GetIntOr("id", 0)).
Delete(&model.Playlist{}) Delete(&model.Playlist{})
return spec.NewResponse() return spec.NewResponse()
} }

View File

@@ -8,9 +8,8 @@ import (
"github.com/jinzhu/gorm" "github.com/jinzhu/gorm"
"senan.xyz/g/gonic/model" "senan.xyz/g/gonic/model"
"senan.xyz/g/gonic/server/ctrlsubsonic/params"
"senan.xyz/g/gonic/server/ctrlsubsonic/spec" "senan.xyz/g/gonic/server/ctrlsubsonic/spec"
"senan.xyz/g/gonic/server/key"
"senan.xyz/g/gonic/server/parsing"
) )
// "raw" handlers are ones that don't always return a spec response. // "raw" handlers are ones that don't always return a spec response.
@@ -20,7 +19,8 @@ import (
// _but not both_ // _but not both_
func (c *Controller) ServeGetCoverArt(w http.ResponseWriter, r *http.Request) *spec.Response { func (c *Controller) ServeGetCoverArt(w http.ResponseWriter, r *http.Request) *spec.Response {
id, err := parsing.GetIntParam(r, "id") params := r.Context().Value(CtxParams).(params.Params)
id, err := params.GetInt("id")
if err != nil { if err != nil {
return spec.NewError(10, "please provide an `id` parameter") return spec.NewError(10, "please provide an `id` parameter")
} }
@@ -46,7 +46,8 @@ func (c *Controller) ServeGetCoverArt(w http.ResponseWriter, r *http.Request) *s
} }
func (c *Controller) ServeStream(w http.ResponseWriter, r *http.Request) *spec.Response { func (c *Controller) ServeStream(w http.ResponseWriter, r *http.Request) *spec.Response {
id, err := parsing.GetIntParam(r, "id") params := r.Context().Value(CtxParams).(params.Params)
id, err := params.GetInt("id")
if err != nil { if err != nil {
return spec.NewError(10, "please provide an `id` parameter") return spec.NewError(10, "please provide an `id` parameter")
} }
@@ -67,7 +68,7 @@ func (c *Controller) ServeStream(w http.ResponseWriter, r *http.Request) *spec.R
http.ServeFile(w, r, absPath) http.ServeFile(w, r, absPath)
// //
// after we've served the file, mark the album as played // after we've served the file, mark the album as played
user := r.Context().Value(key.User).(*model.User) user := r.Context().Value(CtxUser).(*model.User)
play := model.Play{ play := model.Play{
AlbumID: track.Album.ID, AlbumID: track.Album.ID,
UserID: user.ID, UserID: user.ID,

View File

@@ -6,28 +6,11 @@ import (
"encoding/hex" "encoding/hex"
"fmt" "fmt"
"net/http" "net/http"
"net/url"
"senan.xyz/g/gonic/server/ctrlsubsonic/params"
"senan.xyz/g/gonic/server/ctrlsubsonic/spec" "senan.xyz/g/gonic/server/ctrlsubsonic/spec"
"senan.xyz/g/gonic/server/key"
"senan.xyz/g/gonic/server/parsing"
) )
var requiredParameters = []string{
"u", "v", "c",
}
func checkHasAllParams(params url.Values) error {
for _, req := range requiredParameters {
param := params.Get(req)
if param != "" {
continue
}
return fmt.Errorf("please provide a `%s` parameter", req)
}
return nil
}
func checkCredsToken(password, token, salt string) bool { func checkCredsToken(password, token, salt string) bool {
toHash := fmt.Sprintf("%s%s", password, salt) toHash := fmt.Sprintf("%s%s", password, salt)
hash := md5.Sum([]byte(toHash)) hash := md5.Sum([]byte(toHash))
@@ -43,16 +26,34 @@ func checkCredsBasic(password, given string) bool {
return password == given return password == given
} }
func (c *Controller) WithValidSubsonicArgs(next http.Handler) http.Handler { func (c *Controller) WithParams(next http.Handler) http.Handler {
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
if err := checkHasAllParams(r.URL.Query()); err != nil { params := params.New(r)
writeResp(w, r, spec.NewError(10, err.Error())) withParams := context.WithValue(r.Context(), CtxParams, params)
next.ServeHTTP(w, r.WithContext(withParams))
})
}
func (c *Controller) WithUser(next http.Handler) http.Handler {
requiredParameters := []string{
"u", "v", "c",
}
return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) {
params := r.Context().Value(CtxParams).(params.Params)
for _, req := range requiredParameters {
param := params.Get(req)
if param != "" {
continue
}
writeResp(w, r, spec.NewError(10, "please provide a `%s` parameter", req))
return return
} }
username := parsing.GetStrParam(r, "u") //
password := parsing.GetStrParam(r, "p") username := params.Get("u")
token := parsing.GetStrParam(r, "t") password := params.Get("p")
salt := parsing.GetStrParam(r, "s") token := params.Get("t")
salt := params.Get("s")
//
passwordAuth := token == "" && salt == "" passwordAuth := token == "" && salt == ""
tokenAuth := password == "" tokenAuth := password == ""
if tokenAuth == passwordAuth { if tokenAuth == passwordAuth {
@@ -74,7 +75,7 @@ func (c *Controller) WithValidSubsonicArgs(next http.Handler) http.Handler {
writeResp(w, r, spec.NewError(40, "invalid password")) writeResp(w, r, spec.NewError(40, "invalid password"))
return return
} }
withUser := context.WithValue(r.Context(), key.User, user) withUser := context.WithValue(r.Context(), CtxUser, user)
next.ServeHTTP(w, r.WithContext(withUser)) next.ServeHTTP(w, r.WithContext(withUser))
}) })
} }

View File

@@ -0,0 +1,57 @@
package params
import (
"fmt"
"net/http"
"net/url"
"strconv"
)
type Params struct {
values url.Values
}
func New(r *http.Request) Params {
// first load params from the url
params := r.URL.Query()
// also if there's any in the post body, use those too
if err := r.ParseForm(); err != nil {
return Params{params}
}
for k, v := range r.Form {
params[k] = v
}
return Params{params}
}
func (p Params) Get(key string) string {
return p.values.Get(key)
}
func (p Params) GetOr(key, or string) string {
val := p.Get(key)
if val == "" {
return or
}
return val
}
func (p Params) GetInt(key string) (int, error) {
strVal := p.values.Get(key)
if strVal == "" {
return 0, fmt.Errorf("no param with key `%s`", key)
}
val, err := strconv.Atoi(strVal)
if err != nil {
return 0, fmt.Errorf("not an int `%s`", strVal)
}
return val, nil
}
func (p Params) GetIntOr(key string, or int) int {
val, err := p.GetInt(key)
if err != nil {
return or
}
return val
}

View File

@@ -1,8 +0,0 @@
package key
type Key int
const (
User Key = iota
Session
)

View File

@@ -1,57 +0,0 @@
package parsing
import (
"fmt"
"net/http"
"strconv"
)
func GetStrParam(r *http.Request, key string) string {
return r.URL.Query().Get(key)
}
func GetStrParamOr(r *http.Request, key, or string) string {
val := GetStrParam(r, key)
if val == "" {
return or
}
return val
}
func GetIntParam(r *http.Request, key string) (int, error) {
strVal := r.URL.Query().Get(key)
if strVal == "" {
return 0, fmt.Errorf("no param with key `%s`", key)
}
val, err := strconv.Atoi(strVal)
if err != nil {
return 0, fmt.Errorf("not an int `%s`", strVal)
}
return val, nil
}
func GetIntParamOr(r *http.Request, key string, or int) int {
val, err := GetIntParam(r, key)
if err != nil {
return or
}
return val
}
func GetFirstParamOf(r *http.Request, keys ...string) []string {
for _, key := range keys {
if val, ok := r.URL.Query()[key]; ok {
return val
}
}
return nil
}
func GetFirstIntParamOf(r *http.Request, keys ...string) (int, bool) {
for _, key := range keys {
if v, err := GetIntParam(r, key); err == nil {
return v, true
}
}
return 0, false
}