diff --git a/cmd/gonic/gonic.go b/cmd/gonic/gonic.go index 96addf6..66f17a0 100644 --- a/cmd/gonic/gonic.go +++ b/cmd/gonic/gonic.go @@ -15,8 +15,11 @@ import ( "strings" "time" + // avatar encode/decode + _ "image/gif" + _ "image/png" + "github.com/google/shlex" - "github.com/gorilla/mux" "github.com/gorilla/securecookie" _ "github.com/jinzhu/gorm/dialects/sqlite" "github.com/oklog/run" @@ -25,6 +28,7 @@ import ( "go.senan.xyz/gonic" "go.senan.xyz/gonic/db" + "go.senan.xyz/gonic/handlerutil" "go.senan.xyz/gonic/jukebox" "go.senan.xyz/gonic/lastfm" "go.senan.xyz/gonic/listenbrainz" @@ -34,7 +38,6 @@ import ( "go.senan.xyz/gonic/scanner/tags" "go.senan.xyz/gonic/scrobble" "go.senan.xyz/gonic/server/ctrladmin" - "go.senan.xyz/gonic/server/ctrlbase" "go.senan.xyz/gonic/server/ctrlsubsonic" "go.senan.xyz/gonic/server/ctrlsubsonic/artistinfocache" "go.senan.xyz/gonic/transcode" @@ -166,7 +169,7 @@ func main() { tagger := &tags.TagReader{} scannr := scanner.New( - ctrlsubsonic.PathsOf(musicPaths), + ctrlsubsonic.MusicPaths(musicPaths), dbc, map[scanner.Tag]scanner.MultiValueSetting{ scanner.Genre: scanner.MultiValueSetting(confMultiValueGenre), @@ -218,37 +221,36 @@ func main() { artistInfoCache := artistinfocache.New(dbc, lastfmClient) - ctrlBase := &ctrlbase.Controller{ - DB: dbc, - PlaylistStore: playlistStore, - ProxyPrefix: *confProxyPrefix, - Scanner: scannr, + scrobblers := []scrobble.Scrobbler{lastfmClient, listenbrainzClient} + + resolveProxyPath := func(in string) string { + return path.Join(*confProxyPrefix, in) } - ctrlAdmin, err := ctrladmin.New(ctrlBase, sessDB, podcast, lastfmClient) + + ctrlAdmin, err := ctrladmin.New(dbc, sessDB, scannr, podcast, lastfmClient, resolveProxyPath) if err != nil { log.Panicf("error creating admin controller: %v\n", err) } - ctrlSubsonic := &ctrlsubsonic.Controller{ - Controller: ctrlBase, - MusicPaths: musicPaths, - PodcastsPath: *confPodcastPath, - CacheAudioPath: cacheDirAudio, - CacheCoverPath: cacheDirCovers, - LastFMClient: lastfmClient, - ArtistInfoCache: artistInfoCache, - Scrobblers: []scrobble.Scrobbler{ - lastfmClient, - listenbrainzClient, - }, - Podcasts: podcast, - Transcoder: transcoder, - Jukebox: jukebx, + ctrlSubsonic, err := ctrlsubsonic.New(dbc, scannr, musicPaths, *confPodcastPath, cacheDirAudio, cacheDirCovers, jukebx, playlistStore, scrobblers, podcast, transcoder, lastfmClient, artistInfoCache, resolveProxyPath) + if err != nil { + log.Panicf("error creating subsonic controller: %v\n", err) } - mux := mux.NewRouter() - ctrlbase.AddRoutes(ctrlBase, mux, *confHTTPLog) - ctrladmin.AddRoutes(ctrlAdmin, mux.PathPrefix("/admin").Subrouter()) - ctrlsubsonic.AddRoutes(ctrlSubsonic, mux.PathPrefix("/rest").Subrouter()) + chain := handlerutil.Chain() + if *confHTTPLog { + chain = handlerutil.Chain(handlerutil.Log) + } + chain = handlerutil.Chain( + chain, + handlerutil.BasicCORS, + ) + trim := handlerutil.TrimPathSuffix(".view") // /x.view and /x should match the same + + mux := http.NewServeMux() + mux.Handle("/admin/", http.StripPrefix("/admin", chain(ctrlAdmin))) + mux.Handle("/rest/", http.StripPrefix("/rest", chain(trim(ctrlSubsonic)))) + mux.Handle("/ping", chain(handlerutil.Message("ok"))) + mux.Handle("/", chain(handlerutil.Redirect(resolveProxyPath("/admin/home")))) if *confExpvar { mux.Handle("/debug/vars", expvar.Handler()) diff --git a/go.mod b/go.mod index bf28f87..7395f82 100644 --- a/go.mod +++ b/go.mod @@ -5,7 +5,6 @@ go 1.21 require ( github.com/Masterminds/sprig v2.22.0+incompatible github.com/andybalholm/cascadia v1.3.2 - github.com/davecgh/go-spew v1.1.1 github.com/dexterlb/mpvipc v0.0.0-20230829142118-145d6eabdc37 github.com/disintegration/imaging v1.6.2 github.com/dustin/go-humanize v1.0.1 @@ -13,8 +12,6 @@ require ( github.com/fsnotify/fsnotify v1.6.0 github.com/google/shlex v0.0.0-20191202100458-e7afc7fbc510 github.com/google/uuid v1.3.1 - github.com/gorilla/handlers v1.5.1 - github.com/gorilla/mux v1.8.0 github.com/gorilla/securecookie v1.1.1 github.com/gorilla/sessions v1.2.1 github.com/jinzhu/gorm v1.9.17-0.20211120011537-5c235b72a414 @@ -41,7 +38,7 @@ require ( github.com/Masterminds/goutils v1.1.1 // indirect github.com/Masterminds/semver v1.5.0 // indirect github.com/PuerkitoBio/goquery v1.8.1 // indirect - github.com/felixge/httpsnoop v1.0.3 // indirect + github.com/davecgh/go-spew v1.1.1 // indirect github.com/go-openapi/jsonpointer v0.19.5 // indirect github.com/go-openapi/swag v0.21.1 // indirect github.com/gorilla/context v1.1.1 // indirect diff --git a/go.sum b/go.sum index 234418e..d9026a2 100644 --- a/go.sum +++ b/go.sum @@ -30,9 +30,6 @@ github.com/erikstmartin/go-testdb v0.0.0-20160219214506-8d10e4a1bae5 h1:Yzb9+7DP github.com/erikstmartin/go-testdb v0.0.0-20160219214506-8d10e4a1bae5/go.mod h1:a2zkGnVExMxdzMo3M0Hi/3sEU+cWnZpSni0O6/Yb/P0= github.com/fatih/structs v1.1.0 h1:Q7juDM0QtcnhCpeyLGQKyg4TOIghuNXrkL32pHAUMxo= github.com/fatih/structs v1.1.0/go.mod h1:9NiDSp5zOcgEDl+j00MP/WkGVPOlPRLejGD8Ga6PJ7M= -github.com/felixge/httpsnoop v1.0.1/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U= -github.com/felixge/httpsnoop v1.0.3 h1:s/nj+GCswXYzN5v2DpNMuMQYe+0DDwt5WVCU6CWBdXk= -github.com/felixge/httpsnoop v1.0.3/go.mod h1:m8KPJKqk1gH5J9DgRY2ASl2lWCfGKXixSwevea8zH2U= github.com/fsnotify/fsnotify v1.6.0 h1:n+5WquG0fcWoWp6xPWfHdbskMCQaFnG6PfBrh1Ky4HY= github.com/fsnotify/fsnotify v1.6.0/go.mod h1:sl3t1tCWJFWoRz9R8WJCbQihKKwmorjAbSClcnxKAGw= github.com/go-openapi/jsonpointer v0.19.5 h1:gZr+CIYByUqjcgeLXnQu2gHYQC9o73G2XUeOFYEICuY= @@ -55,10 +52,6 @@ github.com/google/uuid v1.3.1 h1:KjJaJ9iWZ3jOFZIf1Lqf4laDRCasjl0BCmnEGxkdLb4= github.com/google/uuid v1.3.1/go.mod h1:TIyPZe4MgqvfeYDBFedMoGGpEw/LqOeaOT+nhxU+yHo= github.com/gorilla/context v1.1.1 h1:AWwleXJkX/nhcU9bZSnZoi3h/qGYqQAGhq6zZe/aQW8= github.com/gorilla/context v1.1.1/go.mod h1:kBGZzfjB9CEq2AlWe17Uuf7NDRt0dE0s8S51q0aT7Yg= -github.com/gorilla/handlers v1.5.1 h1:9lRY6j8DEeeBT10CvO9hGW0gmky0BprnvDI5vfhUHH4= -github.com/gorilla/handlers v1.5.1/go.mod h1:t8XrUpc4KVXb7HGyJ4/cEnwQiaxrX/hz1Zv/4g96P1Q= -github.com/gorilla/mux v1.8.0 h1:i40aqfkR1h2SlN9hojwV5ZA91wcXFOvkdNIeFDP5koI= -github.com/gorilla/mux v1.8.0/go.mod h1:DVbg23sWSpFRCP0SfiEN6jmj59UnW/n46BH5rLB71So= github.com/gorilla/securecookie v1.1.1 h1:miw7JPhV+b/lAHSXz4qd/nN9jRiAFV5FwjeKyCS8BvQ= github.com/gorilla/securecookie v1.1.1/go.mod h1:ra0sb63/xPlUeL+yeDciTfxMRAA+MP+HVt/4epWDjd4= github.com/gorilla/sessions v1.2.1 h1:DHd3rPN5lE3Ts3D8rKkQ8x/0kqfeNmBAaiSi+o7FsgI= diff --git a/server/ctrlbase/ctrl.go b/handlerutil/handlerutil.go similarity index 50% rename from server/ctrlbase/ctrl.go rename to handlerutil/handlerutil.go index 00d854b..860f91d 100644 --- a/server/ctrlbase/ctrl.go +++ b/handlerutil/handlerutil.go @@ -1,17 +1,93 @@ -package ctrlbase +package handlerutil import ( "fmt" "log" "net/http" - "path" "strings" - - "go.senan.xyz/gonic/db" - "go.senan.xyz/gonic/playlist" - "go.senan.xyz/gonic/scanner" ) +type Middleware func(http.Handler) http.Handler + +func Chain(middlewares ...Middleware) Middleware { + return func(final http.Handler) http.Handler { + for i := len(middlewares) - 1; i >= 0; i-- { + final = middlewares[i](final) + } + return final + } +} + +func TrimPathSuffix(suffix string) Middleware { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + r.URL.Path = strings.TrimSuffix(r.URL.Path, suffix) + next.ServeHTTP(w, r) + }) + } +} + +func Log(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + sw := &statusWriter{ResponseWriter: w} + next.ServeHTTP(sw, r) + log.Printf("response %s %s %v", statusToBlock(sw.status), r.Method, r.URL) + }) +} + +func BasicCORS(next http.Handler) http.Handler { + allowMethods := strings.Join( + []string{http.MethodPost, http.MethodGet, http.MethodOptions, http.MethodPut, http.MethodDelete}, + ", ", + ) + allowHeaders := strings.Join( + []string{"Accept", "Content-Type", "Content-Length", "Accept-Encoding", "X-CSRF-Token", "Authorization"}, + ", ", + ) + + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + w.Header().Set("Access-Control-Allow-Origin", "*") + w.Header().Set("Access-Control-Allow-Methods", allowMethods) + w.Header().Set("Access-Control-Allow-Headers", allowHeaders) + if r.Method == http.MethodOptions { + return + } + next.ServeHTTP(w, r) + }) +} + +func Redirect(to string) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + http.Redirect(w, r, to, http.StatusSeeOther) + }) +} + +func Message(message string) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + fmt.Fprintln(w, message) + }) +} + +func BaseURL(r *http.Request) string { + fallbackProtocoll := "http" + if r.TLS != nil { + fallbackProtocoll = "https" + } + fallbackHost := "localhost:4747" + scheme := first( + r.Header.Get("X-Forwarded-Proto"), + r.Header.Get("X-Forwarded-Scheme"), + r.URL.Scheme, + fallbackProtocoll, + ) + host := first( + r.Header.Get("X-Forwarded-Host"), + r.Host, + fallbackHost, + ) + return fmt.Sprintf("%s://%s", scheme, host) +} + type statusWriter struct { http.ResponseWriter status int @@ -32,94 +108,26 @@ func (w *statusWriter) Write(b []byte) (int, error) { func statusToBlock(code int) string { var bg int switch { - case 200 <= code && code <= 299: - bg = 42 // bright green, ok - case 300 <= code && code <= 399: - bg = 46 // bright cyan, redirect - case 400 <= code && code <= 499: - bg = 43 // bright orange, client error - case 500 <= code && code <= 599: - bg = 41 // bright red, server error + case code >= 500: + bg = 41 // bright red + case code >= 400: + bg = 43 // bright orange + case code >= 300: + bg = 46 // bright cyan + case code >= 200: + bg = 42 // bright green default: bg = 47 // bright white (grey) } return fmt.Sprintf("\u001b[%d;1m %d \u001b[0m", bg, code) } -type Controller struct { - DB *db.DB - PlaylistStore *playlist.Store - Scanner *scanner.Scanner - ProxyPrefix string -} - -// Path returns a URL path with the proxy prefix included -func (c *Controller) Path(rel string) string { - return path.Join(c.ProxyPrefix, rel) -} - -func (c *Controller) BaseURL(r *http.Request) string { - protocol := "http" - if r.TLS != nil { - protocol = "https" - } - scheme := firstExisting( - protocol, // fallback - r.Header.Get("X-Forwarded-Proto"), - r.Header.Get("X-Forwarded-Scheme"), - r.URL.Scheme, - ) - host := firstExisting( - "localhost:4747", // fallback - r.Header.Get("X-Forwarded-Host"), - r.Host, - ) - return fmt.Sprintf("%s://%s", scheme, host) -} - -func (c *Controller) WithLogging(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - // this is (should be) the first middleware. pass right though it - // by calling `next` first instead of last. when it completes all - // other middlewares and the custom ResponseWriter has been written - sw := &statusWriter{ResponseWriter: w} - next.ServeHTTP(sw, r) - - // sanitise password - if q := r.URL.Query(); q.Get("p") != "" { - q.Set("p", "REDACTED") - r.URL.RawQuery = q.Encode() - } - log.Printf("response %s for `%v`", statusToBlock(sw.status), r.URL) - }) -} - -func (c *Controller) WithCORS(next http.Handler) http.Handler { - allowMethods := strings.Join( - []string{http.MethodPost, http.MethodGet, http.MethodOptions, http.MethodPut, http.MethodDelete}, - ", ", - ) - allowHeaders := strings.Join( - []string{"Accept", "Content-Type", "Content-Length", "Accept-Encoding", "X-CSRF-Token", "Authorization"}, - ", ", - ) - - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - w.Header().Set("Access-Control-Allow-Origin", "*") - w.Header().Set("Access-Control-Allow-Methods", allowMethods) - w.Header().Set("Access-Control-Allow-Headers", allowHeaders) - if r.Method == http.MethodOptions { - return - } - next.ServeHTTP(w, r) - }) -} - -func firstExisting(or string, strings ...string) string { - for _, s := range strings { - if s != "" { +func first[T comparable](vs ...T) T { + var z T + for _, s := range vs { + if s != z { return s } } - return or + return z } diff --git a/server/ctrladmin/ctrl.go b/server/ctrladmin/ctrl.go index 6a22e95..e239de1 100644 --- a/server/ctrladmin/ctrl.go +++ b/server/ctrladmin/ctrl.go @@ -1,6 +1,8 @@ package ctrladmin import ( + "context" + "embed" "encoding/base64" "encoding/gob" "encoding/json" @@ -22,10 +24,11 @@ import ( "go.senan.xyz/gonic" "go.senan.xyz/gonic/db" + "go.senan.xyz/gonic/handlerutil" "go.senan.xyz/gonic/lastfm" "go.senan.xyz/gonic/podcasts" + "go.senan.xyz/gonic/scanner" "go.senan.xyz/gonic/server/ctrladmin/adminui" - "go.senan.xyz/gonic/server/ctrlbase" ) type CtxKey int @@ -35,6 +38,263 @@ const ( CtxSession ) +type Controller struct { + *http.ServeMux + + dbc *db.DB + sessDB *gormstore.Store + scanner *scanner.Scanner + podcasts *podcasts.Podcasts + lastfmClient *lastfm.Client + resolveProxyPath ProxyPathResolver +} + +type ProxyPathResolver func(in string) string + +func New(dbc *db.DB, sessDB *gormstore.Store, scanner *scanner.Scanner, podcasts *podcasts.Podcasts, lastfmClient *lastfm.Client, resolveProxyPath ProxyPathResolver) (*Controller, error) { + c := Controller{ + ServeMux: http.NewServeMux(), + + dbc: dbc, + sessDB: sessDB, + scanner: scanner, + podcasts: podcasts, + lastfmClient: lastfmClient, + resolveProxyPath: resolveProxyPath, + } + + resp := respHandler(adminui.TemplatesFS, resolveProxyPath) + + baseChain := withSession(sessDB) + userChain := handlerutil.Chain( + baseChain, + withUserSession(dbc, resolveProxyPath), + ) + adminChain := handlerutil.Chain( + userChain, + withAdminSession, + ) + + c.Handle("/static/", http.FileServer(http.FS(adminui.StaticFS))) + + // public routes (creates session) + c.Handle("/login", baseChain(resp(c.ServeLogin))) + c.Handle("/login_do", baseChain(respRaw(c.ServeLoginDo))) + + // user routes (if session is valid) + c.Handle("/logout", userChain(respRaw(c.ServeLogout))) + c.Handle("/home", userChain(resp(c.ServeHome))) + c.Handle("/change_username", userChain(resp(c.ServeChangeUsername))) + c.Handle("/change_username_do", userChain(resp(c.ServeChangeUsernameDo))) + c.Handle("/change_password", userChain(resp(c.ServeChangePassword))) + c.Handle("/change_password_do", userChain(resp(c.ServeChangePasswordDo))) + c.Handle("/change_avatar", userChain(resp(c.ServeChangeAvatar))) + c.Handle("/change_avatar_do", userChain(resp(c.ServeChangeAvatarDo))) + c.Handle("/delete_avatar_do", userChain(resp(c.ServeDeleteAvatarDo))) + c.Handle("/delete_user", userChain(resp(c.ServeDeleteUser))) + c.Handle("/delete_user_do", userChain(resp(c.ServeDeleteUserDo))) + c.Handle("/link_lastfm_do", userChain(resp(c.ServeLinkLastFMDo))) + c.Handle("/unlink_lastfm_do", userChain(resp(c.ServeUnlinkLastFMDo))) + c.Handle("/link_listenbrainz_do", userChain(resp(c.ServeLinkListenBrainzDo))) + c.Handle("/unlink_listenbrainz_do", userChain(resp(c.ServeUnlinkListenBrainzDo))) + c.Handle("/create_transcode_pref_do", userChain(resp(c.ServeCreateTranscodePrefDo))) + c.Handle("/delete_transcode_pref_do", userChain(resp(c.ServeDeleteTranscodePrefDo))) + + // admin routes (if session is valid, and is admin) + c.Handle("/create_user", adminChain(resp(c.ServeCreateUser))) + c.Handle("/create_user_do", adminChain(resp(c.ServeCreateUserDo))) + c.Handle("/update_lastfm_api_key", adminChain(resp(c.ServeUpdateLastFMAPIKey))) + c.Handle("/update_lastfm_api_key_do", adminChain(resp(c.ServeUpdateLastFMAPIKeyDo))) + c.Handle("/start_scan_inc_do", adminChain(resp(c.ServeStartScanIncDo))) + c.Handle("/start_scan_full_do", adminChain(resp(c.ServeStartScanFullDo))) + c.Handle("/add_podcast_do", adminChain(resp(c.ServePodcastAddDo))) + c.Handle("/delete_podcast_do", adminChain(resp(c.ServePodcastDeleteDo))) + c.Handle("/download_podcast_do", adminChain(resp(c.ServePodcastDownloadDo))) + c.Handle("/update_podcast_do", adminChain(resp(c.ServePodcastUpdateDo))) + c.Handle("/add_internet_radio_station_do", adminChain(resp(c.ServeInternetRadioStationAddDo))) + c.Handle("/delete_internet_radio_station_do", adminChain(resp(c.ServeInternetRadioStationDeleteDo))) + c.Handle("/update_internet_radio_station_do", adminChain(resp(c.ServeInternetRadioStationUpdateDo))) + + c.Handle("/", baseChain(resp(c.ServeNotFound))) + + return &c, nil +} + +func withSession(sessDB *gormstore.Store) handlerutil.Middleware { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + session, err := sessDB.Get(r, gonic.Name) + if err != nil { + http.Error(w, fmt.Sprintf("error getting session: %s", err), 500) + return + } + withSession := context.WithValue(r.Context(), CtxSession, session) + next.ServeHTTP(w, r.WithContext(withSession)) + }) + } +} + +func withUserSession(dbc *db.DB, resolvePath func(string) string) handlerutil.Middleware { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // session exists at this point + session := r.Context().Value(CtxSession).(*sessions.Session) + userID, ok := session.Values["user"].(int) + if !ok { + sessAddFlashW(session, []string{"you are not authenticated"}) + sessLogSave(session, w, r) + http.Redirect(w, r, resolvePath("/admin/login"), http.StatusSeeOther) + return + } + // take username from sesion and add the user row to the context + user := dbc.GetUserByID(userID) + if user == nil { + // the username in the client's session no longer relates to a + // user in the database (maybe the user was deleted) + session.Options.MaxAge = -1 + sessLogSave(session, w, r) + http.Redirect(w, r, "/", http.StatusSeeOther) + return + } + withUser := context.WithValue(r.Context(), CtxUser, user) + next.ServeHTTP(w, r.WithContext(withUser)) + }) + } +} + +func withAdminSession(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + // session and user exist at this point + session := r.Context().Value(CtxSession).(*sessions.Session) + user := r.Context().Value(CtxUser).(*db.User) + if !user.IsAdmin { + sessAddFlashW(session, []string{"you are not an admin"}) + sessLogSave(session, w, r) + http.Redirect(w, r, "/admin/login", http.StatusSeeOther) + return + } + next.ServeHTTP(w, r) + }) +} + +type Response struct { + // code is 200 + template string + data *templateData + // code is 303 + redirect string + flashN []string // normal + flashW []string // warning + // code is >= 400 + code int + err string +} + +type ( + handlerAdmin func(r *http.Request) *Response +) + +func respHandler(templateFS embed.FS, resolvePath func(string) string) func(next handlerAdmin) http.Handler { + tmpl := template.Must(template. + New("layout"). + Funcs(template.FuncMap(sprig.FuncMap())). + Funcs(funcMap()). + Funcs(template.FuncMap{"path": resolvePath}). + ParseFS(templateFS, "*.tmpl", "**/*.tmpl"), + ) + buffPool := bpool.NewBufferPool(64) + + return func(next handlerAdmin) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + resp := next(r) + session, ok := r.Context().Value(CtxSession).(*sessions.Session) + if ok { + sessAddFlashN(session, resp.flashN) + sessAddFlashW(session, resp.flashW) + if err := session.Save(r, w); err != nil { + http.Error(w, fmt.Sprintf("error saving session: %v", err), 500) + return + } + } + if resp.redirect != "" { + http.Redirect(w, r, resolvePath(resp.redirect), http.StatusSeeOther) + return + } + if resp.err != "" { + http.Error(w, resp.err, resp.code) + return + } + if resp.template == "" { + http.Error(w, "useless handler return", 500) + return + } + + if resp.data == nil { + resp.data = &templateData{} + } + resp.data.Version = gonic.Version + if session != nil { + resp.data.Flashes = session.Flashes() + if err := session.Save(r, w); err != nil { + http.Error(w, fmt.Sprintf("error saving session: %v", err), 500) + return + } + } + if user, ok := r.Context().Value(CtxUser).(*db.User); ok { + resp.data.User = user + } + + buff := buffPool.Get() + defer buffPool.Put(buff) + if err := tmpl.ExecuteTemplate(buff, resp.template, resp.data); err != nil { + http.Error(w, fmt.Sprintf("executing template: %v", err), 500) + return + } + + w.Header().Set("Content-Type", "text/html; charset=utf-8") + if resp.code != 0 { + w.WriteHeader(resp.code) + } + if _, err := buff.WriteTo(w); err != nil { + log.Printf("error writing to response buffer: %v\n", err) + } + }) + } +} + +func respRaw(h http.HandlerFunc) http.Handler { + return h // stub +} + +type templateData struct { + // common + Flashes []interface{} + User *db.User + Version string + // home + AlbumCount int + ArtistCount int + TrackCount int + RequestRoot string + RecentFolders []*db.Album + AllUsers []*db.User + LastScanTime time.Time + IsScanning bool + TranscodePreferences []*db.TranscodePreference + TranscodeProfiles []string + + CurrentLastFMAPIKey string + CurrentLastFMAPISecret string + DefaultListenBrainzURL string + SelectedUser *db.User + + Podcasts []*db.Podcast + InternetRadioStations []*db.InternetRadioStation + + // avatar + Avatar []byte +} + func funcMap() template.FuncMap { return template.FuncMap{ "str": func(in any) string { @@ -72,153 +332,7 @@ func funcMap() template.FuncMap { } } -type Controller struct { - *ctrlbase.Controller - buffPool *bpool.BufferPool - template *template.Template - sessDB *gormstore.Store - Podcasts *podcasts.Podcasts - lastfmClient *lastfm.Client -} - -func New(b *ctrlbase.Controller, sessDB *gormstore.Store, podcasts *podcasts.Podcasts, lastfmClient *lastfm.Client) (*Controller, error) { - tmpl, err := template. - New("layout"). - Funcs(template.FuncMap(sprig.FuncMap())). - Funcs(funcMap()). // static - Funcs(template.FuncMap{ // from base - "path": b.Path, - }). - ParseFS(adminui.TemplatesFS, "*.tmpl", "**/*.tmpl") - if err != nil { - return nil, fmt.Errorf("build template: %w", err) - } - return &Controller{ - Controller: b, - buffPool: bpool.NewBufferPool(64), - template: tmpl, - sessDB: sessDB, - Podcasts: podcasts, - lastfmClient: lastfmClient, - }, nil -} - -type templateData struct { - // common - Flashes []interface{} - User *db.User - Version string - // home - AlbumCount int - ArtistCount int - TrackCount int - RequestRoot string - RecentFolders []*db.Album - AllUsers []*db.User - LastScanTime time.Time - IsScanning bool - TranscodePreferences []*db.TranscodePreference - TranscodeProfiles []string - - CurrentLastFMAPIKey string - CurrentLastFMAPISecret string - DefaultListenBrainzURL string - SelectedUser *db.User - - Podcasts []*db.Podcast - InternetRadioStations []*db.InternetRadioStation - - // avatar - Avatar []byte -} - -type Response struct { - // code is 200 - template string - data *templateData - // code is 303 - redirect string - flashN []string // normal - flashW []string // warning - // code is >= 400 - code int - err string -} - -type ( - handlerAdmin func(r *http.Request) *Response - handlerAdminRaw func(w http.ResponseWriter, r *http.Request) -) - -func (c *Controller) H(h handlerAdmin) http.Handler { - // TODO: break this up a bit - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - resp := h(r) - session, ok := r.Context().Value(CtxSession).(*sessions.Session) - if ok { - sessAddFlashN(session, resp.flashN) - sessAddFlashW(session, resp.flashW) - if err := session.Save(r, w); err != nil { - http.Error(w, fmt.Sprintf("error saving session: %v", err), 500) - return - } - } - if resp.redirect != "" { - to := resp.redirect - if strings.HasPrefix(to, "/") { - to = c.Path(to) - } - http.Redirect(w, r, to, http.StatusSeeOther) - return - } - if resp.err != "" { - http.Error(w, resp.err, resp.code) - return - } - if resp.template == "" { - http.Error(w, "useless handler return", 500) - return - } - - if resp.data == nil { - resp.data = &templateData{} - } - resp.data.Version = gonic.Version - if session != nil { - resp.data.Flashes = session.Flashes() - if err := session.Save(r, w); err != nil { - http.Error(w, fmt.Sprintf("error saving session: %v", err), 500) - return - } - } - if user, ok := r.Context().Value(CtxUser).(*db.User); ok { - resp.data.User = user - } - - buff := c.buffPool.Get() - defer c.buffPool.Put(buff) - if err := c.template.ExecuteTemplate(buff, resp.template, resp.data); err != nil { - http.Error(w, fmt.Sprintf("executing template: %v", err), 500) - return - } - - w.Header().Set("Content-Type", "text/html; charset=utf-8") - if resp.code != 0 { - w.WriteHeader(resp.code) - } - if _, err := buff.WriteTo(w); err != nil { - log.Printf("error writing to response buffer: %v\n", err) - } - }) -} - -func (c *Controller) HR(h handlerAdminRaw) http.Handler { - return http.HandlerFunc(h) -} - -// ## begin utilities -// ## begin utilities -// ## begin utilities +// utilities type FlashType string @@ -268,9 +382,7 @@ func sessLogSave(s *sessions.Session, w http.ResponseWriter, r *http.Request) { } } -// ## begin validation -// ## begin validation -// ## begin validation +// validation var ( errValiNoUsername = errors.New("please enter a username") diff --git a/server/ctrladmin/handlers.go b/server/ctrladmin/handlers.go index 4944c50..6575d25 100644 --- a/server/ctrladmin/handlers.go +++ b/server/ctrladmin/handlers.go @@ -5,9 +5,7 @@ import ( "bytes" "fmt" "image" - _ "image/gif" // to decode uploaded GIF avatars "image/jpeg" - _ "image/png" // to decode uploaded PNG avatars "log" "net/http" "net/url" @@ -19,19 +17,12 @@ import ( "github.com/nfnt/resize" "go.senan.xyz/gonic/db" + "go.senan.xyz/gonic/handlerutil" "go.senan.xyz/gonic/listenbrainz" "go.senan.xyz/gonic/scanner" "go.senan.xyz/gonic/transcode" ) -func doScan(scanner *scanner.Scanner, opts scanner.ScanOptions) { - go func() { - if _, err := scanner.ScanAndClean(opts); err != nil { - log.Printf("error while scanning: %v\n", err) - } - }() -} - func (c *Controller) ServeNotFound(_ *http.Request) *Response { return &Response{template: "not_found.tmpl", code: 404} } @@ -45,35 +36,35 @@ func (c *Controller) ServeHome(r *http.Request) *Response { data := &templateData{} // stats box - c.DB.Model(&db.Artist{}).Count(&data.ArtistCount) - c.DB.Model(&db.Album{}).Count(&data.AlbumCount) - c.DB.Table("tracks").Count(&data.TrackCount) + c.dbc.Model(&db.Artist{}).Count(&data.ArtistCount) + c.dbc.Model(&db.Album{}).Count(&data.AlbumCount) + c.dbc.Table("tracks").Count(&data.TrackCount) // lastfm box - data.RequestRoot = c.BaseURL(r) - data.CurrentLastFMAPIKey, _ = c.DB.GetSetting(db.LastFMAPIKey) + data.RequestRoot = handlerutil.BaseURL(r) + data.CurrentLastFMAPIKey, _ = c.dbc.GetSetting(db.LastFMAPIKey) data.DefaultListenBrainzURL = listenbrainz.BaseURL // users box - allUsersQ := c.DB.DB + allUsersQ := c.dbc.DB if !user.IsAdmin { allUsersQ = allUsersQ.Where("name=?", user.Name) } allUsersQ.Find(&data.AllUsers) // recent folders box - c.DB. + c.dbc. Order("created_at DESC"). Limit(10). Find(&data.RecentFolders) - data.IsScanning = c.Scanner.IsScanning() - if tStr, _ := c.DB.GetSetting(db.LastScanTime); tStr != "" { + data.IsScanning = c.scanner.IsScanning() + if tStr, _ := c.dbc.GetSetting(db.LastScanTime); tStr != "" { i, _ := strconv.ParseInt(tStr, 10, 64) data.LastScanTime = time.Unix(i, 0) } // transcoding box - c.DB. + c.dbc. Where("user_id=?", user.ID). Find(&data.TranscodePreferences) for profile := range transcode.UserProfiles { @@ -81,10 +72,10 @@ func (c *Controller) ServeHome(r *http.Request) *Response { } sort.Strings(data.TranscodeProfiles) // podcasts box - c.DB.Find(&data.Podcasts) + c.dbc.Find(&data.Podcasts) // internet radio box - c.DB.Find(&data.InternetRadioStations) + c.dbc.Find(&data.InternetRadioStations) return &Response{ template: "home.tmpl", @@ -106,7 +97,7 @@ func (c *Controller) ServeLinkLastFMDo(r *http.Request) *Response { } user := r.Context().Value(CtxUser).(*db.User) user.LastFMSession = sessionKey - if err := c.DB.Save(user).Error; err != nil { + if err := c.dbc.Save(user).Error; err != nil { return &Response{redirect: r.Referer(), flashW: []string{fmt.Sprintf("save user: %v", err)}} } return &Response{redirect: "/admin/home"} @@ -115,7 +106,7 @@ func (c *Controller) ServeLinkLastFMDo(r *http.Request) *Response { func (c *Controller) ServeUnlinkLastFMDo(r *http.Request) *Response { user := r.Context().Value(CtxUser).(*db.User) user.LastFMSession = "" - if err := c.DB.Save(user).Error; err != nil { + if err := c.dbc.Save(user).Error; err != nil { return &Response{redirect: r.Referer(), flashW: []string{fmt.Sprintf("save user: %v", err)}} } return &Response{redirect: "/admin/home"} @@ -133,7 +124,7 @@ func (c *Controller) ServeLinkListenBrainzDo(r *http.Request) *Response { user := r.Context().Value(CtxUser).(*db.User) user.ListenBrainzURL = url user.ListenBrainzToken = token - if err := c.DB.Save(user).Error; err != nil { + if err := c.dbc.Save(user).Error; err != nil { return &Response{redirect: r.Referer(), flashW: []string{fmt.Sprintf("save user: %v", err)}} } return &Response{redirect: "/admin/home"} @@ -143,7 +134,7 @@ func (c *Controller) ServeUnlinkListenBrainzDo(r *http.Request) *Response { user := r.Context().Value(CtxUser).(*db.User) user.ListenBrainzURL = "" user.ListenBrainzToken = "" - if err := c.DB.Save(user).Error; err != nil { + if err := c.dbc.Save(user).Error; err != nil { return &Response{redirect: r.Referer(), flashW: []string{fmt.Sprintf("save user: %v", err)}} } return &Response{redirect: "/admin/home"} @@ -175,7 +166,7 @@ func (c *Controller) ServeChangeUsernameDo(r *http.Request) *Response { } } user.Name = usernameNew - if err := c.DB.Save(user).Error; err != nil { + if err := c.dbc.Save(user).Error; err != nil { return &Response{redirect: r.Referer(), flashW: []string{fmt.Sprintf("save username: %v", err)}} } return &Response{redirect: "/admin/home"} @@ -208,7 +199,7 @@ func (c *Controller) ServeChangePasswordDo(r *http.Request) *Response { } } user.Password = passwordOne - if err := c.DB.Save(user).Error; err != nil { + if err := c.dbc.Save(user).Error; err != nil { return &Response{redirect: r.Referer(), flashW: []string{fmt.Sprintf("save user: %v", err)}} } return &Response{redirect: "/admin/home"} @@ -240,7 +231,7 @@ func (c *Controller) ServeChangeAvatarDo(r *http.Request) *Response { } } user.Avatar = avatar - if err := c.DB.Save(user).Error; err != nil { + if err := c.dbc.Save(user).Error; err != nil { return &Response{redirect: r.Referer(), flashW: []string{fmt.Sprintf("save user: %v", err)}} } return &Response{ @@ -255,7 +246,7 @@ func (c *Controller) ServeDeleteAvatarDo(r *http.Request) *Response { return &Response{code: 400, err: err.Error()} } user.Avatar = nil - if err := c.DB.Save(user).Error; err != nil { + if err := c.dbc.Save(user).Error; err != nil { return &Response{redirect: r.Referer(), flashW: []string{fmt.Sprintf("save user: %v", err)}} } return &Response{ @@ -288,7 +279,7 @@ func (c *Controller) ServeDeleteUserDo(r *http.Request) *Response { flashW: []string{"can't delete the admin user"}, } } - if err := c.DB.Delete(user).Error; err != nil { + if err := c.dbc.Delete(user).Error; err != nil { return &Response{redirect: r.Referer(), flashW: []string{fmt.Sprintf("delete user: %v", err)}} } return &Response{redirect: "/admin/home"} @@ -318,7 +309,7 @@ func (c *Controller) ServeCreateUserDo(r *http.Request) *Response { Name: username, Password: passwordOne, } - if err := c.DB.Create(&user).Error; err != nil { + if err := c.dbc.Create(&user).Error; err != nil { return &Response{ redirect: r.Referer(), flashW: []string{fmt.Sprintf("could not create user `%s`: %v", username, err)}, @@ -330,10 +321,10 @@ func (c *Controller) ServeCreateUserDo(r *http.Request) *Response { func (c *Controller) ServeUpdateLastFMAPIKey(r *http.Request) *Response { data := &templateData{} var err error - if data.CurrentLastFMAPIKey, err = c.DB.GetSetting(db.LastFMAPIKey); err != nil { + if data.CurrentLastFMAPIKey, err = c.dbc.GetSetting(db.LastFMAPIKey); err != nil { return &Response{redirect: r.Referer(), flashW: []string{fmt.Sprintf("couldn't get api key: %v", err)}} } - if data.CurrentLastFMAPISecret, err = c.DB.GetSetting(db.LastFMSecret); err != nil { + if data.CurrentLastFMAPISecret, err = c.dbc.GetSetting(db.LastFMSecret); err != nil { return &Response{redirect: r.Referer(), flashW: []string{fmt.Sprintf("couldn't get secret: %v", err)}} } return &Response{ @@ -351,17 +342,17 @@ func (c *Controller) ServeUpdateLastFMAPIKeyDo(r *http.Request) *Response { flashW: []string{err.Error()}, } } - if err := c.DB.SetSetting(db.LastFMAPIKey, apiKey); err != nil { + if err := c.dbc.SetSetting(db.LastFMAPIKey, apiKey); err != nil { return &Response{redirect: r.Referer(), flashW: []string{fmt.Sprintf("couldn't set api key: %v", err)}} } - if err := c.DB.SetSetting(db.LastFMSecret, secret); err != nil { + if err := c.dbc.SetSetting(db.LastFMSecret, secret); err != nil { return &Response{redirect: r.Referer(), flashW: []string{fmt.Sprintf("couldn't set secret: %v", err)}} } return &Response{redirect: "/admin/home"} } func (c *Controller) ServeStartScanIncDo(_ *http.Request) *Response { - defer doScan(c.Scanner, scanner.ScanOptions{}) + defer doScan(c.scanner, scanner.ScanOptions{}) return &Response{ redirect: "/admin/home", flashN: []string{"incremental scan started. refresh for results"}, @@ -369,7 +360,7 @@ func (c *Controller) ServeStartScanIncDo(_ *http.Request) *Response { } func (c *Controller) ServeStartScanFullDo(_ *http.Request) *Response { - defer doScan(c.Scanner, scanner.ScanOptions{IsFull: true}) + defer doScan(c.scanner, scanner.ScanOptions{IsFull: true}) return &Response{ redirect: "/admin/home", flashN: []string{"full scan started. refresh for results"}, @@ -391,7 +382,7 @@ func (c *Controller) ServeCreateTranscodePrefDo(r *http.Request) *Response { Client: client, Profile: profile, } - if err := c.DB.Create(&pref).Error; err != nil { + if err := c.dbc.Create(&pref).Error; err != nil { return &Response{ redirect: "/admin/home", flashW: []string{fmt.Sprintf("could not create preference: %v", err)}, @@ -406,7 +397,7 @@ func (c *Controller) ServeDeleteTranscodePrefDo(r *http.Request) *Response { if client == "" { return &Response{code: 400, err: "please provide a client"} } - c.DB. + c.dbc. Where("user_id=? AND client=?", user.ID, client). Delete(db.TranscodePreference{}) return &Response{ @@ -424,7 +415,7 @@ func (c *Controller) ServePodcastAddDo(r *http.Request) *Response { flashW: []string{fmt.Sprintf("could not create feed: %v", err)}, } } - if _, err := c.Podcasts.AddNewPodcast(rssURL, feed); err != nil { + if _, err := c.podcasts.AddNewPodcast(rssURL, feed); err != nil { return &Response{ redirect: "/admin/home", flashW: []string{fmt.Sprintf("could not create feed: %v", err)}, @@ -440,7 +431,7 @@ func (c *Controller) ServePodcastDownloadDo(r *http.Request) *Response { if err != nil { return &Response{code: 400, err: "please provide a valid podcast id"} } - if err := c.Podcasts.DownloadPodcastAll(id); err != nil { + if err := c.podcasts.DownloadPodcastAll(id); err != nil { return &Response{redirect: r.Referer(), flashW: []string{fmt.Sprintf("error downloading: %v", err)}} } return &Response{ @@ -464,7 +455,7 @@ func (c *Controller) ServePodcastUpdateDo(r *http.Request) *Response { default: return &Response{code: 400, err: "please provide a valid podcast download type"} } - if err := c.Podcasts.SetAutoDownload(id, setting); err != nil { + if err := c.podcasts.SetAutoDownload(id, setting); err != nil { return &Response{ flashW: []string{fmt.Sprintf("could not update auto download setting: %v", err)}, code: 400, @@ -481,7 +472,7 @@ func (c *Controller) ServePodcastDeleteDo(r *http.Request) *Response { if err != nil { return &Response{code: 400, err: "please provide a valid podcast id"} } - if err := c.Podcasts.DeletePodcast(id); err != nil { + if err := c.podcasts.DeletePodcast(id); err != nil { return &Response{redirect: r.Referer(), flashW: []string{fmt.Sprintf("error deleting: %v", err)}} } return &Response{ @@ -512,7 +503,7 @@ func (c *Controller) ServeInternetRadioStationAddDo(r *http.Request) *Response { station.StreamURL = streamURL station.Name = name station.HomepageURL = homepageURL - if err := c.DB.Save(&station).Error; err != nil { + if err := c.dbc.Save(&station).Error; err != nil { return &Response{redirect: r.Referer(), flashW: []string{fmt.Sprintf("error saving station: %v", err)}} } @@ -555,14 +546,14 @@ func (c *Controller) ServeInternetRadioStationUpdateDo(r *http.Request) *Respons } var station db.InternetRadioStation - if err := c.DB.Where("id=?", stationID).First(&station).Error; err != nil { + if err := c.dbc.Where("id=?", stationID).First(&station).Error; err != nil { return &Response{redirect: r.Referer(), flashW: []string{fmt.Sprintf("find station by id: %v", err)}} } station.StreamURL = streamURL station.Name = name station.HomepageURL = homepageURL - if err := c.DB.Save(&station).Error; err != nil { + if err := c.dbc.Save(&station).Error; err != nil { return &Response{code: 500, err: "please provide a valid internet radio station id"} } @@ -578,11 +569,11 @@ func (c *Controller) ServeInternetRadioStationDeleteDo(r *http.Request) *Respons } var station db.InternetRadioStation - if err := c.DB.Where("id=?", stationID).First(&station).Error; err != nil { + if err := c.dbc.Where("id=?", stationID).First(&station).Error; err != nil { return &Response{redirect: r.Referer(), flashW: []string{fmt.Sprintf("find station by id: %v", err)}} } - if err := c.DB.Where("id=?", stationID).Delete(&db.InternetRadioStation{}).Error; err != nil { + if err := c.dbc.Where("id=?", stationID).Delete(&db.InternetRadioStation{}).Error; err != nil { return &Response{redirect: r.Referer(), flashW: []string{fmt.Sprintf("deleting radio station: %v", err)}} } @@ -621,6 +612,14 @@ func selectedUserIfAdmin(c *Controller, r *http.Request) (*db.User, error) { if !user.IsAdmin && user.Name != selectedUsername { return nil, fmt.Errorf("must be admin to perform actions for other users") } - selectedUser := c.DB.GetUserByName(selectedUsername) + selectedUser := c.dbc.GetUserByName(selectedUsername) return selectedUser, nil } + +func doScan(scanner *scanner.Scanner, opts scanner.ScanOptions) { + go func() { + if _, err := scanner.ScanAndClean(opts); err != nil { + log.Printf("error while scanning: %v\n", err) + } + }() +} diff --git a/server/ctrladmin/handlers_raw.go b/server/ctrladmin/handlers_raw.go index adbf1eb..49d9e2a 100644 --- a/server/ctrladmin/handlers_raw.go +++ b/server/ctrladmin/handlers_raw.go @@ -16,7 +16,7 @@ func (c *Controller) ServeLoginDo(w http.ResponseWriter, r *http.Request) { http.Redirect(w, r, r.Referer(), http.StatusSeeOther) return } - user := c.DB.GetUserByName(username) + user := c.dbc.GetUserByName(username) if user == nil || password != user.Password { sessAddFlashW(session, []string{"invalid username / password"}) sessLogSave(session, w, r) @@ -28,12 +28,12 @@ func (c *Controller) ServeLoginDo(w http.ResponseWriter, r *http.Request) { // session and put the row into the request context session.Values["user"] = user.ID sessLogSave(session, w, r) - http.Redirect(w, r, c.Path("/admin/home"), http.StatusSeeOther) + http.Redirect(w, r, c.resolveProxyPath("/admin/home"), http.StatusSeeOther) } func (c *Controller) ServeLogout(w http.ResponseWriter, r *http.Request) { session := r.Context().Value(CtxSession).(*sessions.Session) session.Options.MaxAge = -1 sessLogSave(session, w, r) - http.Redirect(w, r, c.Path("/admin/login"), http.StatusSeeOther) + http.Redirect(w, r, c.resolveProxyPath("/admin/login"), http.StatusSeeOther) } diff --git a/server/ctrladmin/middleware.go b/server/ctrladmin/middleware.go deleted file mode 100644 index db50f64..0000000 --- a/server/ctrladmin/middleware.go +++ /dev/null @@ -1,65 +0,0 @@ -package ctrladmin - -import ( - "context" - "fmt" - "net/http" - - "github.com/gorilla/sessions" - - "go.senan.xyz/gonic" - "go.senan.xyz/gonic/db" -) - -func (c *Controller) WithSession(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - session, err := c.sessDB.Get(r, gonic.Name) - if err != nil { - http.Error(w, fmt.Sprintf("error getting session: %s", err), 500) - return - } - withSession := context.WithValue(r.Context(), CtxSession, session) - next.ServeHTTP(w, r.WithContext(withSession)) - }) -} - -func (c *Controller) WithUserSession(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - // session exists at this point - session := r.Context().Value(CtxSession).(*sessions.Session) - userID, ok := session.Values["user"].(int) - if !ok { - sessAddFlashW(session, []string{"you are not authenticated"}) - sessLogSave(session, w, r) - http.Redirect(w, r, c.Path("/admin/login"), http.StatusSeeOther) - return - } - // take username from sesion and add the user row to the context - user := c.DB.GetUserByID(userID) - if user == nil { - // the username in the client's session no longer relates to a - // user in the database (maybe the user was deleted) - session.Options.MaxAge = -1 - sessLogSave(session, w, r) - http.Redirect(w, r, c.Path("/admin/login"), http.StatusSeeOther) - return - } - withUser := context.WithValue(r.Context(), CtxUser, user) - next.ServeHTTP(w, r.WithContext(withUser)) - }) -} - -func (c *Controller) WithAdminSession(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - // session and user exist at this point - session := r.Context().Value(CtxSession).(*sessions.Session) - user := r.Context().Value(CtxUser).(*db.User) - if !user.IsAdmin { - sessAddFlashW(session, []string{"you are not an admin"}) - sessLogSave(session, w, r) - http.Redirect(w, r, c.Path("/admin/login"), http.StatusSeeOther) - return - } - next.ServeHTTP(w, r) - }) -} diff --git a/server/ctrladmin/routes.go b/server/ctrladmin/routes.go deleted file mode 100644 index 12bda3f..0000000 --- a/server/ctrladmin/routes.go +++ /dev/null @@ -1,62 +0,0 @@ -package ctrladmin - -import ( - "net/http" - - "github.com/gorilla/mux" - "go.senan.xyz/gonic/server/ctrladmin/adminui" -) - -func AddRoutes(c *Controller, r *mux.Router) { - // public routes (creates session) - r.Use(c.WithSession) - r.Handle("/login", c.H(c.ServeLogin)) - r.Handle("/login_do", c.HR(c.ServeLoginDo)) // "raw" handler, updates session - - staticHandler := http.StripPrefix("/admin", http.FileServer(http.FS(adminui.StaticFS))) - r.PathPrefix("/static").Handler(staticHandler) - - // user routes (if session is valid) - routUser := r.NewRoute().Subrouter() - routUser.Use(c.WithUserSession) - routUser.Handle("/logout", c.HR(c.ServeLogout)) // "raw" handler, updates session - routUser.Handle("/home", c.H(c.ServeHome)) - routUser.Handle("/change_username", c.H(c.ServeChangeUsername)) - routUser.Handle("/change_username_do", c.H(c.ServeChangeUsernameDo)) - routUser.Handle("/change_password", c.H(c.ServeChangePassword)) - routUser.Handle("/change_password_do", c.H(c.ServeChangePasswordDo)) - routUser.Handle("/change_avatar", c.H(c.ServeChangeAvatar)) - routUser.Handle("/change_avatar_do", c.H(c.ServeChangeAvatarDo)) - routUser.Handle("/delete_avatar_do", c.H(c.ServeDeleteAvatarDo)) - routUser.Handle("/delete_user", c.H(c.ServeDeleteUser)) - routUser.Handle("/delete_user_do", c.H(c.ServeDeleteUserDo)) - routUser.Handle("/link_lastfm_do", c.H(c.ServeLinkLastFMDo)) - routUser.Handle("/unlink_lastfm_do", c.H(c.ServeUnlinkLastFMDo)) - routUser.Handle("/link_listenbrainz_do", c.H(c.ServeLinkListenBrainzDo)) - routUser.Handle("/unlink_listenbrainz_do", c.H(c.ServeUnlinkListenBrainzDo)) - routUser.Handle("/create_transcode_pref_do", c.H(c.ServeCreateTranscodePrefDo)) - routUser.Handle("/delete_transcode_pref_do", c.H(c.ServeDeleteTranscodePrefDo)) - - // admin routes (if session is valid, and is admin) - routAdmin := routUser.NewRoute().Subrouter() - routAdmin.Use(c.WithAdminSession) - routAdmin.Handle("/create_user", c.H(c.ServeCreateUser)) - routAdmin.Handle("/create_user_do", c.H(c.ServeCreateUserDo)) - routAdmin.Handle("/update_lastfm_api_key", c.H(c.ServeUpdateLastFMAPIKey)) - routAdmin.Handle("/update_lastfm_api_key_do", c.H(c.ServeUpdateLastFMAPIKeyDo)) - routAdmin.Handle("/start_scan_inc_do", c.H(c.ServeStartScanIncDo)) - routAdmin.Handle("/start_scan_full_do", c.H(c.ServeStartScanFullDo)) - routAdmin.Handle("/add_podcast_do", c.H(c.ServePodcastAddDo)) - routAdmin.Handle("/delete_podcast_do", c.H(c.ServePodcastDeleteDo)) - routAdmin.Handle("/download_podcast_do", c.H(c.ServePodcastDownloadDo)) - routAdmin.Handle("/update_podcast_do", c.H(c.ServePodcastUpdateDo)) - routAdmin.Handle("/add_internet_radio_station_do", c.H(c.ServeInternetRadioStationAddDo)) - routAdmin.Handle("/delete_internet_radio_station_do", c.H(c.ServeInternetRadioStationDeleteDo)) - routAdmin.Handle("/update_internet_radio_station_do", c.H(c.ServeInternetRadioStationUpdateDo)) - - // middlewares should be run for not found handler - // https://github.com/gorilla/mux/issues/416 - notFoundHandler := c.H(c.ServeNotFound) - notFoundRoute := r.NewRoute().Handler(notFoundHandler) - r.NotFoundHandler = notFoundRoute.GetHandler() -} diff --git a/server/ctrlbase/routes.go b/server/ctrlbase/routes.go deleted file mode 100644 index 7ade2f1..0000000 --- a/server/ctrlbase/routes.go +++ /dev/null @@ -1,34 +0,0 @@ -package ctrlbase - -import ( - "fmt" - "net/http" - - "github.com/gorilla/handlers" - "github.com/gorilla/mux" -) - -func AddRoutes(c *Controller, r *mux.Router, logHTTP bool) { - if logHTTP { - r.Use(c.WithLogging) - } - r.Use(c.WithCORS) - r.Use(handlers.RecoveryHandler(handlers.PrintRecoveryStack(true))) - - r.HandleFunc("/", func(w http.ResponseWriter, r *http.Request) { - adminHome := c.Path("/admin/home") - http.Redirect(w, r, adminHome, http.StatusSeeOther) - }) - // misc subsonic routes without /rest prefix - r.HandleFunc("/settings.view", func(w http.ResponseWriter, r *http.Request) { - adminHome := c.Path("/admin/home") - http.Redirect(w, r, adminHome, http.StatusSeeOther) - }) - r.HandleFunc("/musicFolderSettings.view", func(w http.ResponseWriter, r *http.Request) { - restScan := c.Path(fmt.Sprintf("/rest/startScan.view?%s", r.URL.Query().Encode())) - http.Redirect(w, r, restScan, http.StatusSeeOther) - }) - r.HandleFunc("/ping", func(w http.ResponseWriter, r *http.Request) { - fmt.Fprint(w, "OK") - }) -} diff --git a/server/ctrlsubsonic/ctrl.go b/server/ctrlsubsonic/ctrl.go index de58d51..f3e409c 100644 --- a/server/ctrlsubsonic/ctrl.go +++ b/server/ctrlsubsonic/ctrl.go @@ -1,6 +1,9 @@ package ctrlsubsonic import ( + "context" + "crypto/md5" + "encoding/hex" "encoding/json" "encoding/xml" "fmt" @@ -8,11 +11,14 @@ import ( "log" "net/http" + "go.senan.xyz/gonic/db" + "go.senan.xyz/gonic/handlerutil" "go.senan.xyz/gonic/jukebox" "go.senan.xyz/gonic/lastfm" + "go.senan.xyz/gonic/playlist" "go.senan.xyz/gonic/podcasts" + "go.senan.xyz/gonic/scanner" "go.senan.xyz/gonic/scrobble" - "go.senan.xyz/gonic/server/ctrlbase" "go.senan.xyz/gonic/server/ctrlsubsonic/artistinfocache" "go.senan.xyz/gonic/server/ctrlsubsonic/params" "go.senan.xyz/gonic/server/ctrlsubsonic/spec" @@ -31,7 +37,7 @@ type MusicPath struct { Alias, Path string } -func PathsOf(paths []MusicPath) []string { +func MusicPaths(paths []MusicPath) []string { var r []string for _, p := range paths { r = append(r, p.Path) @@ -39,23 +45,227 @@ func PathsOf(paths []MusicPath) []string { return r } +type ProxyPathResolver func(in string) string + type Controller struct { - *ctrlbase.Controller - MusicPaths []MusicPath - PodcastsPath string - CacheAudioPath string - CacheCoverPath string - Jukebox *jukebox.Jukebox - Scrobblers []scrobble.Scrobbler - Podcasts *podcasts.Podcasts - Transcoder transcode.Transcoder - LastFMClient *lastfm.Client - ArtistInfoCache *artistinfocache.ArtistInfoCache + *http.ServeMux + + dbc *db.DB + scanner *scanner.Scanner + musicPaths []MusicPath + podcastsPath string + cacheAudioPath string + cacheCoverPath string + jukebox *jukebox.Jukebox + playlistStore *playlist.Store + scrobblers []scrobble.Scrobbler + podcasts *podcasts.Podcasts + transcoder transcode.Transcoder + lastFMClient *lastfm.Client + artistInfoCache *artistinfocache.ArtistInfoCache + resolveProxyPath ProxyPathResolver } -type metaResponse struct { - XMLName xml.Name `xml:"subsonic-response" json:"-"` - *spec.Response `json:"subsonic-response"` +func New(dbc *db.DB, scannr *scanner.Scanner, musicPaths []MusicPath, podcastsPath string, cacheAudioPath string, cacheCoverPath string, jukebox *jukebox.Jukebox, playlistStore *playlist.Store, scrobblers []scrobble.Scrobbler, podcasts *podcasts.Podcasts, transcoder transcode.Transcoder, lastFMClient *lastfm.Client, artistInfoCache *artistinfocache.ArtistInfoCache, resolveProxyPath ProxyPathResolver) (*Controller, error) { + c := Controller{ + ServeMux: http.NewServeMux(), + + dbc: dbc, + scanner: scannr, + musicPaths: musicPaths, + podcastsPath: podcastsPath, + cacheAudioPath: cacheAudioPath, + cacheCoverPath: cacheCoverPath, + jukebox: jukebox, + playlistStore: playlistStore, + scrobblers: scrobblers, + podcasts: podcasts, + transcoder: transcoder, + lastFMClient: lastFMClient, + artistInfoCache: artistInfoCache, + resolveProxyPath: resolveProxyPath, + } + + chain := handlerutil.Chain( + withParams, + withRequiredParams, + withUser(dbc), + ) + + c.Handle("/getLicense", chain(resp(c.ServeGetLicence))) + c.Handle("/ping", chain(resp(c.ServePing))) + c.Handle("/getOpenSubsonicExtensions", chain(resp(c.ServeGetOpenSubsonicExtensions))) + + c.Handle("/getMusicFolders", chain(resp(c.ServeGetMusicFolders))) + c.Handle("/getScanStatus", chain(resp(c.ServeGetScanStatus))) + c.Handle("/scrobble", chain(resp(c.ServeScrobble))) + c.Handle("/startScan", chain(resp(c.ServeStartScan))) + c.Handle("/getUser", chain(resp(c.ServeGetUser))) + c.Handle("/getPlaylists", chain(resp(c.ServeGetPlaylists))) + c.Handle("/getPlaylist", chain(resp(c.ServeGetPlaylist))) + c.Handle("/createPlaylist", chain(resp(c.ServeCreatePlaylist))) + c.Handle("/updatePlaylist", chain(resp(c.ServeUpdatePlaylist))) + c.Handle("/deletePlaylist", chain(resp(c.ServeDeletePlaylist))) + c.Handle("/savePlayQueue", chain(resp(c.ServeSavePlayQueue))) + c.Handle("/getPlayQueue", chain(resp(c.ServeGetPlayQueue))) + c.Handle("/getSong", chain(resp(c.ServeGetSong))) + c.Handle("/getRandomSongs", chain(resp(c.ServeGetRandomSongs))) + c.Handle("/getSongsByGenre", chain(resp(c.ServeGetSongsByGenre))) + c.Handle("/jukeboxControl", chain(resp(c.ServeJukebox))) + c.Handle("/getBookmarks", chain(resp(c.ServeGetBookmarks))) + c.Handle("/createBookmark", chain(resp(c.ServeCreateBookmark))) + c.Handle("/deleteBookmark", chain(resp(c.ServeDeleteBookmark))) + c.Handle("/getTopSongs", chain(resp(c.ServeGetTopSongs))) + c.Handle("/getSimilarSongs", chain(resp(c.ServeGetSimilarSongs))) + c.Handle("/getSimilarSongs2", chain(resp(c.ServeGetSimilarSongsTwo))) + c.Handle("/getLyrics", chain(resp(c.ServeGetLyrics))) + + // raw + c.Handle("/getCoverArt", chain(respRaw(c.ServeGetCoverArt))) + c.Handle("/stream", chain(respRaw(c.ServeStream))) + c.Handle("/download", chain(respRaw(c.ServeStream))) + c.Handle("/getAvatar", chain(respRaw(c.ServeGetAvatar))) + + // browse by tag + c.Handle("/getAlbum", chain(resp(c.ServeGetAlbum))) + c.Handle("/getAlbumList2", chain(resp(c.ServeGetAlbumListTwo))) + c.Handle("/getArtist", chain(resp(c.ServeGetArtist))) + c.Handle("/getArtists", chain(resp(c.ServeGetArtists))) + c.Handle("/search3", chain(resp(c.ServeSearchThree))) + c.Handle("/getArtistInfo2", chain(resp(c.ServeGetArtistInfoTwo))) + c.Handle("/getStarred2", chain(resp(c.ServeGetStarredTwo))) + + // browse by folder + c.Handle("/getIndexes", chain(resp(c.ServeGetIndexes))) + c.Handle("/getMusicDirectory", chain(resp(c.ServeGetMusicDirectory))) + c.Handle("/getAlbumList", chain(resp(c.ServeGetAlbumList))) + c.Handle("/search2", chain(resp(c.ServeSearchTwo))) + c.Handle("/getGenres", chain(resp(c.ServeGetGenres))) + c.Handle("/getArtistInfo", chain(resp(c.ServeGetArtistInfo))) + c.Handle("/getStarred", chain(resp(c.ServeGetStarred))) + + // star / rating + c.Handle("/star", chain(resp(c.ServeStar))) + c.Handle("/unstar", chain(resp(c.ServeUnstar))) + c.Handle("/setRating", chain(resp(c.ServeSetRating))) + + // podcasts + c.Handle("/getPodcasts", chain(resp(c.ServeGetPodcasts))) + c.Handle("/getNewestPodcasts", chain(resp(c.ServeGetNewestPodcasts))) + c.Handle("/downloadPodcastEpisode", chain(resp(c.ServeDownloadPodcastEpisode))) + c.Handle("/createPodcastChannel", chain(resp(c.ServeCreatePodcastChannel))) + c.Handle("/refreshPodcasts", chain(resp(c.ServeRefreshPodcasts))) + c.Handle("/deletePodcastChannel", chain(resp(c.ServeDeletePodcastChannel))) + c.Handle("/deletePodcastEpisode", chain(resp(c.ServeDeletePodcastEpisode))) + + // internet radio + c.Handle("/getInternetRadioStations", chain(resp(c.ServeGetInternetRadioStations))) + c.Handle("/createInternetRadioStation", chain(resp(c.ServeCreateInternetRadioStation))) + c.Handle("/updateInternetRadioStation", chain(resp(c.ServeUpdateInternetRadioStation))) + c.Handle("/deleteInternetRadioStation", chain(resp(c.ServeDeleteInternetRadioStation))) + + c.Handle("/", chain(resp(c.ServeNotFound))) + + return &c, nil +} + +type ( + handlerSubsonic func(r *http.Request) *spec.Response + handlerSubsonicRaw func(w http.ResponseWriter, r *http.Request) *spec.Response +) + +func resp(h handlerSubsonic) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if err := writeResp(w, r, h(r)); err != nil { + log.Printf("error writing subsonic response: %v\n", err) + } + }) +} + +func respRaw(h handlerSubsonicRaw) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + if err := writeResp(w, r, h(w, r)); err != nil { + log.Printf("error writing raw subsonic response: %v\n", err) + } + }) +} + +func withParams(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + params := params.New(r) + withParams := context.WithValue(r.Context(), CtxParams, params) + next.ServeHTTP(w, r.WithContext(withParams)) + }) +} + +func withRequiredParams(next http.Handler) http.Handler { + requiredParameters := []string{ + "u", "c", + } + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + params := r.Context().Value(CtxParams).(params.Params) + for _, req := range requiredParameters { + if _, err := params.Get(req); err != nil { + _ = writeResp(w, r, spec.NewError(10, "please provide a `%s` parameter", req)) + return + } + } + next.ServeHTTP(w, r) + }) +} + +func withUser(dbc *db.DB) handlerutil.Middleware { + return func(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + params := r.Context().Value(CtxParams).(params.Params) + // ignoring errors here, a middleware has already ensured they exist + username, _ := params.Get("u") + password, _ := params.Get("p") + token, _ := params.Get("t") + salt, _ := params.Get("s") + + passwordAuth := token == "" && salt == "" + tokenAuth := password == "" + if tokenAuth == passwordAuth { + _ = writeResp(w, r, spec.NewError(10, + "please provide `t` and `s`, or just `p`")) + return + } + user := dbc.GetUserByName(username) + if user == nil { + _ = writeResp(w, r, spec.NewError(40, + "invalid username `%s`", username)) + return + } + var credsOk bool + if tokenAuth { + credsOk = checkCredsToken(user.Password, token, salt) + } else { + credsOk = checkCredsBasic(user.Password, password) + } + if !credsOk { + _ = writeResp(w, r, spec.NewError(40, "invalid password")) + return + } + withUser := context.WithValue(r.Context(), CtxUser, user) + next.ServeHTTP(w, r.WithContext(withUser)) + }) + } +} + +func checkCredsToken(password, token, salt string) bool { + toHash := fmt.Sprintf("%s%s", password, salt) + hash := md5.Sum([]byte(toHash)) + expToken := hex.EncodeToString(hash[:]) + return token == expToken +} + +func checkCredsBasic(password, given string) bool { + if len(given) >= 4 && given[:4] == "enc:" { + bytes, _ := hex.DecodeString(given[4:]) + given = string(bytes) + } + return password == given } type errWriter struct { @@ -78,8 +288,14 @@ func writeResp(w http.ResponseWriter, r *http.Request, resp *spec.Response) erro log.Printf("subsonic error code %d: %s", resp.Error.Code, resp.Error.Message) } - res := metaResponse{Response: resp} + var res struct { + XMLName xml.Name `xml:"subsonic-response" json:"-"` + *spec.Response `json:"subsonic-response"` + } + res.Response = resp + params := r.Context().Value(CtxParams).(params.Params) + ew := &errWriter{w: w} switch v, _ := params.Get("f"); v { case "json": @@ -89,6 +305,7 @@ func writeResp(w http.ResponseWriter, r *http.Request, resp *spec.Response) erro return fmt.Errorf("marshal to json: %w", err) } ew.write(data) + case "jsonp": w.Header().Set("Content-Type", "application/javascript") data, err := json.Marshal(res) @@ -101,34 +318,16 @@ func writeResp(w http.ResponseWriter, r *http.Request, resp *spec.Response) erro ew.write([]byte("(")) ew.write(data) ew.write([]byte(");")) + default: w.Header().Set("Content-Type", "application/xml") data, err := xml.MarshalIndent(res, "", " ") if err != nil { return fmt.Errorf("marshal to xml: %w", err) } + ew.write(data) } + return ew.err } - -type ( - handlerSubsonic func(r *http.Request) *spec.Response - handlerSubsonicRaw func(w http.ResponseWriter, r *http.Request) *spec.Response -) - -func (c *Controller) H(h handlerSubsonic) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if err := writeResp(w, r, h(r)); err != nil { - log.Printf("error writing subsonic response: %v\n", err) - } - }) -} - -func (c *Controller) HR(h handlerSubsonicRaw) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - if err := writeResp(w, r, h(w, r)); err != nil { - log.Printf("error writing raw subsonic response: %v\n", err) - } - }) -} diff --git a/server/ctrlsubsonic/ctrl_test.go b/server/ctrlsubsonic/ctrl_test.go index 05f5f90..1d1c28c 100644 --- a/server/ctrlsubsonic/ctrl_test.go +++ b/server/ctrlsubsonic/ctrl_test.go @@ -19,7 +19,6 @@ import ( "go.senan.xyz/gonic" "go.senan.xyz/gonic/db" "go.senan.xyz/gonic/mockfs" - "go.senan.xyz/gonic/server/ctrlbase" "go.senan.xyz/gonic/server/ctrlsubsonic/params" "go.senan.xyz/gonic/transcode" ) @@ -78,7 +77,7 @@ func makeHTTPMockWithAdmin(query url.Values) (*httptest.ResponseRecorder, *http. return rr, req } -func runQueryCases(t *testing.T, contr *Controller, h handlerSubsonic, cases []*queryCase) { +func runQueryCases(t *testing.T, h handlerSubsonic, cases []*queryCase) { t.Helper() for _, qc := range cases { qc := qc @@ -86,7 +85,7 @@ func runQueryCases(t *testing.T, contr *Controller, h handlerSubsonic, cases []* t.Parallel() rr, req := makeHTTPMock(qc.params) - contr.H(h).ServeHTTP(rr, req) + resp(h).ServeHTTP(rr, req) body := rr.Body.String() if status := rr.Code; status != http.StatusOK { t.Fatalf("didn't give a 200\n%s", body) @@ -149,11 +148,10 @@ func makec(tb testing.TB, roots []string, audio bool) *Controller { absRoots = append(absRoots, MusicPath{Path: filepath.Join(m.TmpDir(), root)}) } - base := &ctrlbase.Controller{DB: m.DB()} contr := &Controller{ - Controller: base, - MusicPaths: absRoots, - Transcoder: transcode.NewFFmpegTranscoder(), + dbc: m.DB(), + musicPaths: absRoots, + transcoder: transcode.NewFFmpegTranscoder(), } return contr diff --git a/server/ctrlsubsonic/handlers_bookmark.go b/server/ctrlsubsonic/handlers_bookmark.go index 037a87f..7c6822b 100644 --- a/server/ctrlsubsonic/handlers_bookmark.go +++ b/server/ctrlsubsonic/handlers_bookmark.go @@ -15,7 +15,7 @@ import ( func (c *Controller) ServeGetBookmarks(r *http.Request) *spec.Response { user := r.Context().Value(CtxUser).(*db.User) bookmarks := []*db.Bookmark{} - err := c.DB. + err := c.dbc. Where("user_id=?", user.ID). Find(&bookmarks). Error @@ -40,7 +40,7 @@ func (c *Controller) ServeGetBookmarks(r *http.Request) *spec.Response { switch specid.IDT(bookmark.EntryIDType) { case specid.Track: var track db.Track - err := c.DB. + err := c.dbc. Preload("Album"). Find(&track, "id=?", bookmark.EntryID). Error @@ -64,14 +64,14 @@ func (c *Controller) ServeCreateBookmark(r *http.Request) *spec.Response { return spec.NewError(10, "please provide an `id` parameter") } bookmark := &db.Bookmark{} - c.DB.FirstOrCreate(bookmark, db.Bookmark{ + c.dbc.FirstOrCreate(bookmark, db.Bookmark{ UserID: user.ID, EntryIDType: string(id.Type), EntryID: id.Value, }) bookmark.Comment = params.GetOr("comment", "") bookmark.Position = params.GetOrInt("position", 0) - c.DB.Save(bookmark) + c.dbc.Save(bookmark) return spec.NewResponse() } @@ -82,7 +82,7 @@ func (c *Controller) ServeDeleteBookmark(r *http.Request) *spec.Response { if err != nil { return spec.NewError(10, "please provide an `id` parameter") } - c.DB. + c.dbc. Where("user_id=? AND entry_id_type=? AND entry_id=?", user.ID, id.Type, id.Value). Delete(&db.Bookmark{}) return spec.NewResponse() diff --git a/server/ctrlsubsonic/handlers_by_folder.go b/server/ctrlsubsonic/handlers_by_folder.go index ed7cb2d..71e8933 100644 --- a/server/ctrlsubsonic/handlers_by_folder.go +++ b/server/ctrlsubsonic/handlers_by_folder.go @@ -21,16 +21,16 @@ import ( func (c *Controller) ServeGetIndexes(r *http.Request) *spec.Response { params := r.Context().Value(CtxParams).(params.Params) user := r.Context().Value(CtxUser).(*db.User) - rootQ := c.DB. + rootQ := c.dbc. Select("id"). Model(&db.Album{}). Where("parent_id IS NULL") - if m := getMusicFolder(c.MusicPaths, params); m != "" { + if m := getMusicFolder(c.musicPaths, params); m != "" { rootQ = rootQ. Where("root_dir=?", m) } var folders []*db.Album - c.DB. + c.dbc. Select("*, count(sub.id) child_count"). Preload("AlbumStar", "user_id=?", user.ID). Preload("AlbumRating", "user_id=?", user.ID). @@ -70,13 +70,13 @@ func (c *Controller) ServeGetMusicDirectory(r *http.Request) *spec.Response { user := r.Context().Value(CtxUser).(*db.User) childrenObj := []*spec.TrackChild{} folder := &db.Album{} - c.DB. + c.dbc. Preload("AlbumStar", "user_id=?", user.ID). Preload("AlbumRating", "user_id=?", user.ID). First(folder, id.Value) // start looking for child childFolders in the current dir var childFolders []*db.Album - c.DB. + c.dbc. Where("parent_id=?", id.Value). Preload("AlbumStar", "user_id=?", user.ID). Preload("AlbumRating", "user_id=?", user.ID). @@ -87,7 +87,7 @@ func (c *Controller) ServeGetMusicDirectory(r *http.Request) *spec.Response { } // start looking for child childTracks in the current dir var childTracks []*db.Track - c.DB. + c.dbc. Where("album_id=?", id.Value). Preload("Album"). Preload("Album.Artists"). @@ -96,7 +96,7 @@ func (c *Controller) ServeGetMusicDirectory(r *http.Request) *spec.Response { Order("filename"). Find(&childTracks) - transcodeMeta := streamGetTranscodeMeta(c.DB, user.ID, params.GetOr("c", "")) + transcodeMeta := streamGetTranscodeMeta(c.dbc, user.ID, params.GetOr("c", "")) for _, ch := range childTracks { toAppend := spec.NewTCTrackByFolder(ch, folder) @@ -120,7 +120,7 @@ func (c *Controller) ServeGetMusicDirectory(r *http.Request) *spec.Response { func (c *Controller) ServeGetAlbumList(r *http.Request) *spec.Response { params := r.Context().Value(CtxParams).(params.Params) user := r.Context().Value(CtxUser).(*db.User) - q := c.DB.DB + q := c.dbc.DB switch v, _ := params.Get("type"); v { case "alphabeticalByArtist": q = q.Joins(` @@ -163,7 +163,7 @@ func (c *Controller) ServeGetAlbumList(r *http.Request) *spec.Response { return spec.NewError(10, "unknown value `%s` for parameter 'type'", v) } - if m := getMusicFolder(c.MusicPaths, params); m != "" { + if m := getMusicFolder(c.musicPaths, params); m != "" { q = q.Where("root_dir=?", m) } var folders []*db.Album @@ -205,16 +205,16 @@ func (c *Controller) ServeSearchTwo(r *http.Request) *spec.Response { results := &spec.SearchResultTwo{} // search "artists" - rootQ := c.DB. + rootQ := c.dbc. Select("id"). Model(&db.Album{}). Where("parent_id IS NULL") - if m := getMusicFolder(c.MusicPaths, params); m != "" { + if m := getMusicFolder(c.musicPaths, params); m != "" { rootQ = rootQ.Where("root_dir=?", m) } var artists []*db.Album - q := c.DB.Where(`parent_id IN ?`, rootQ.SubQuery()) + q := c.dbc.Where(`parent_id IN ?`, rootQ.SubQuery()) for _, s := range queries { q = q.Where(`right_path LIKE ? OR right_path_u_dec LIKE ?`, s, s) } @@ -231,7 +231,7 @@ func (c *Controller) ServeSearchTwo(r *http.Request) *spec.Response { // search "albums" var albums []*db.Album - q = c.DB.Joins("JOIN album_artists ON album_artists.album_id=albums.id") + q = c.dbc.Joins("JOIN album_artists ON album_artists.album_id=albums.id") for _, s := range queries { q = q.Where(`right_path LIKE ? OR right_path_u_dec LIKE ?`, s, s) } @@ -239,7 +239,7 @@ func (c *Controller) ServeSearchTwo(r *http.Request) *spec.Response { Preload("AlbumRating", "user_id=?", user.ID). Offset(params.GetOrInt("albumOffset", 0)). Limit(params.GetOrInt("albumCount", 20)) - if m := getMusicFolder(c.MusicPaths, params); m != "" { + if m := getMusicFolder(c.musicPaths, params); m != "" { q = q.Where("root_dir=?", m) } if err := q.Find(&albums).Error; err != nil { @@ -251,7 +251,7 @@ func (c *Controller) ServeSearchTwo(r *http.Request) *spec.Response { // search tracks var tracks []*db.Track - q = c.DB.Preload("Album") + q = c.dbc.Preload("Album") for _, s := range queries { q = q.Where(`filename LIKE ? OR filename LIKE ?`, s, s) } @@ -259,7 +259,7 @@ func (c *Controller) ServeSearchTwo(r *http.Request) *spec.Response { Preload("TrackRating", "user_id=?", user.ID). Offset(params.GetOrInt("songOffset", 0)). Limit(params.GetOrInt("songCount", 20)) - if m := getMusicFolder(c.MusicPaths, params); m != "" { + if m := getMusicFolder(c.musicPaths, params); m != "" { q = q. Joins("JOIN albums ON albums.id=tracks.album_id"). Where("albums.root_dir=?", m) @@ -268,7 +268,7 @@ func (c *Controller) ServeSearchTwo(r *http.Request) *spec.Response { return spec.NewError(0, "find tracks: %v", err) } - transcodeMeta := streamGetTranscodeMeta(c.DB, user.ID, params.GetOr("c", "")) + transcodeMeta := streamGetTranscodeMeta(c.dbc, user.ID, params.GetOr("c", "")) for _, t := range tracks { track := spec.NewTCTrackByFolder(t, t.Album) @@ -292,16 +292,16 @@ func (c *Controller) ServeGetStarred(r *http.Request) *spec.Response { results := &spec.Starred{} // "artists" - rootQ := c.DB. + rootQ := c.dbc. Select("id"). Model(&db.Album{}). Where("parent_id IS NULL") - if m := getMusicFolder(c.MusicPaths, params); m != "" { + if m := getMusicFolder(c.musicPaths, params); m != "" { rootQ = rootQ.Where("root_dir=?", m) } var artists []*db.Album - q := c.DB. + q := c.dbc. Where(`parent_id IN ?`, rootQ.SubQuery()). Joins("JOIN album_stars ON albums.id=album_stars.album_id"). Where("album_stars.user_id=?", user.ID). @@ -316,13 +316,13 @@ func (c *Controller) ServeGetStarred(r *http.Request) *spec.Response { // "albums" var albums []*db.Album - q = c.DB. + q = c.dbc. Joins("JOIN album_artists ON album_artists.album_id=albums.id"). Joins("JOIN album_stars ON albums.id=album_stars.album_id"). Where("album_stars.user_id=?", user.ID). Preload("AlbumStar", "user_id=?", user.ID). Preload("AlbumRating", "user_id=?", user.ID) - if m := getMusicFolder(c.MusicPaths, params); m != "" { + if m := getMusicFolder(c.musicPaths, params); m != "" { q = q.Where("root_dir=?", m) } if err := q.Find(&albums).Error; err != nil { @@ -334,13 +334,13 @@ func (c *Controller) ServeGetStarred(r *http.Request) *spec.Response { // tracks var tracks []*db.Track - q = c.DB. + q = c.dbc. Preload("Album"). Joins("JOIN track_stars ON tracks.id=track_stars.track_id"). Where("track_stars.user_id=?", user.ID). Preload("TrackStar", "user_id=?", user.ID). Preload("TrackRating", "user_id=?", user.ID) - if m := getMusicFolder(c.MusicPaths, params); m != "" { + if m := getMusicFolder(c.musicPaths, params); m != "" { q = q. Joins("JOIN albums ON albums.id=tracks.album_id"). Where("albums.root_dir=?", m) @@ -349,7 +349,7 @@ func (c *Controller) ServeGetStarred(r *http.Request) *spec.Response { return spec.NewError(0, "find tracks: %v", err) } - transcodeMeta := streamGetTranscodeMeta(c.DB, user.ID, params.GetOr("c", "")) + transcodeMeta := streamGetTranscodeMeta(c.dbc, user.ID, params.GetOr("c", "")) for _, t := range tracks { track := spec.NewTCTrackByFolder(t, t.Album) diff --git a/server/ctrlsubsonic/handlers_by_folder_test.go b/server/ctrlsubsonic/handlers_by_folder_test.go index 27de6ad..0efc4e2 100644 --- a/server/ctrlsubsonic/handlers_by_folder_test.go +++ b/server/ctrlsubsonic/handlers_by_folder_test.go @@ -9,10 +9,8 @@ import ( func TestGetIndexes(t *testing.T) { t.Parallel() - contr := makeControllerRoots(t, []string{"m-0", "m-1"}) - - runQueryCases(t, contr, contr.ServeGetIndexes, []*queryCase{ + runQueryCases(t, contr.ServeGetIndexes, []*queryCase{ {url.Values{}, "no_args", false}, {url.Values{"musicFolderId": {"0"}}, "with_music_folder_1", false}, {url.Values{"musicFolderId": {"1"}}, "with_music_folder_2", false}, @@ -21,10 +19,8 @@ func TestGetIndexes(t *testing.T) { func TestGetMusicDirectory(t *testing.T) { t.Parallel() - contr := makeController(t) - - runQueryCases(t, contr, contr.ServeGetMusicDirectory, []*queryCase{ + runQueryCases(t, contr.ServeGetMusicDirectory, []*queryCase{ {url.Values{"id": {"al-2"}}, "without_tracks", false}, {url.Values{"id": {"al-3"}}, "with_tracks", false}, }) @@ -33,8 +29,7 @@ func TestGetMusicDirectory(t *testing.T) { func TestGetAlbumList(t *testing.T) { t.Parallel() contr := makeController(t) - - runQueryCases(t, contr, contr.ServeGetAlbumList, []*queryCase{ + runQueryCases(t, contr.ServeGetAlbumList, []*queryCase{ {url.Values{"type": {"alphabeticalByArtist"}}, "alpha_artist", false}, {url.Values{"type": {"alphabeticalByName"}}, "alpha_name", false}, {url.Values{"type": {"newest"}}, "newest", false}, @@ -45,8 +40,7 @@ func TestGetAlbumList(t *testing.T) { func TestSearchTwo(t *testing.T) { t.Parallel() contr := makeController(t) - - runQueryCases(t, contr, contr.ServeSearchTwo, []*queryCase{ + runQueryCases(t, contr.ServeSearchTwo, []*queryCase{ {url.Values{"query": {"art"}}, "q_art", false}, {url.Values{"query": {"alb"}}, "q_alb", false}, {url.Values{"query": {"tra"}}, "q_tra", false}, diff --git a/server/ctrlsubsonic/handlers_by_tags.go b/server/ctrlsubsonic/handlers_by_tags.go index 2fd4fc5..cb3a2df 100644 --- a/server/ctrlsubsonic/handlers_by_tags.go +++ b/server/ctrlsubsonic/handlers_by_tags.go @@ -14,6 +14,7 @@ import ( "github.com/jinzhu/gorm" "go.senan.xyz/gonic/db" + "go.senan.xyz/gonic/handlerutil" "go.senan.xyz/gonic/server/ctrlsubsonic/params" "go.senan.xyz/gonic/server/ctrlsubsonic/spec" "go.senan.xyz/gonic/server/ctrlsubsonic/specid" @@ -23,7 +24,7 @@ func (c *Controller) ServeGetArtists(r *http.Request) *spec.Response { params := r.Context().Value(CtxParams).(params.Params) user := r.Context().Value(CtxUser).(*db.User) var artists []*db.Artist - q := c.DB. + q := c.dbc. Select("*, count(sub.id) album_count"). Joins("JOIN album_artists ON album_artists.artist_id=artists.id"). Joins("JOIN albums sub ON sub.id=album_artists.album_id"). @@ -32,7 +33,7 @@ func (c *Controller) ServeGetArtists(r *http.Request) *spec.Response { Preload("Info"). Group("artists.id"). Order("artists.name COLLATE NOCASE") - if m := getMusicFolder(c.MusicPaths, params); m != "" { + if m := getMusicFolder(c.musicPaths, params); m != "" { q = q.Where("sub.root_dir=?", m) } if err := q.Find(&artists).Error; err != nil { @@ -67,7 +68,7 @@ func (c *Controller) ServeGetArtist(r *http.Request) *spec.Response { return spec.NewError(10, "please provide an `id` parameter") } artist := &db.Artist{} - c.DB. + c.dbc. Preload("Albums", func(db *gorm.DB) *gorm.DB { return db. Select("*, count(sub.id) child_count, sum(sub.length) duration"). @@ -99,7 +100,7 @@ func (c *Controller) ServeGetAlbum(r *http.Request) *spec.Response { return spec.NewError(10, "please provide an `id` parameter") } album := &db.Album{} - err = c.DB. + err = c.dbc. Select("albums.*, count(tracks.id) child_count, sum(tracks.length) duration"). Joins("LEFT JOIN tracks ON tracks.album_id=albums.id"). Preload("Artists"). @@ -121,7 +122,7 @@ func (c *Controller) ServeGetAlbum(r *http.Request) *spec.Response { sub.Album = spec.NewAlbumByTags(album, album.Artists) sub.Album.Tracks = make([]*spec.TrackChild, len(album.Tracks)) - transcodeMeta := streamGetTranscodeMeta(c.DB, user.ID, params.GetOr("c", "")) + transcodeMeta := streamGetTranscodeMeta(c.dbc, user.ID, params.GetOr("c", "")) for i, track := range album.Tracks { sub.Album.Tracks[i] = spec.NewTrackByTags(track, album) @@ -140,7 +141,7 @@ func (c *Controller) ServeGetAlbumListTwo(r *http.Request) *spec.Response { if err != nil { return spec.NewError(10, "please provide a `type` parameter") } - q := c.DB.DB + q := c.dbc.DB switch listType { case "alphabeticalByArtist": q = q.Joins("JOIN artists ON artists.id=album_artists.artist_id") @@ -175,7 +176,7 @@ func (c *Controller) ServeGetAlbumListTwo(r *http.Request) *spec.Response { default: return spec.NewError(10, "unknown value `%s` for parameter 'type'", listType) } - if m := getMusicFolder(c.MusicPaths, params); m != "" { + if m := getMusicFolder(c.musicPaths, params); m != "" { q = q.Where("root_dir=?", m) } var albums []*db.Album @@ -218,7 +219,7 @@ func (c *Controller) ServeSearchThree(r *http.Request) *spec.Response { // search artists var artists []*db.Artist - q := c.DB. + q := c.dbc. Select("*, count(albums.id) album_count"). Group("artists.id") for _, s := range queries { @@ -232,7 +233,7 @@ func (c *Controller) ServeSearchThree(r *http.Request) *spec.Response { Preload("Info"). Offset(params.GetOrInt("artistOffset", 0)). Limit(params.GetOrInt("artistCount", 20)) - if m := getMusicFolder(c.MusicPaths, params); m != "" { + if m := getMusicFolder(c.musicPaths, params); m != "" { q = q.Where("albums.root_dir=?", m) } if err := q.Find(&artists).Error; err != nil { @@ -244,7 +245,7 @@ func (c *Controller) ServeSearchThree(r *http.Request) *spec.Response { // search albums var albums []*db.Album - q = c.DB. + q = c.dbc. Preload("Artists"). Preload("Genres"). Preload("AlbumStar", "user_id=?", user.ID). @@ -255,7 +256,7 @@ func (c *Controller) ServeSearchThree(r *http.Request) *spec.Response { q = q. Offset(params.GetOrInt("albumOffset", 0)). Limit(params.GetOrInt("albumCount", 20)) - if m := getMusicFolder(c.MusicPaths, params); m != "" { + if m := getMusicFolder(c.musicPaths, params); m != "" { q = q.Where("root_dir=?", m) } if err := q.Find(&albums).Error; err != nil { @@ -267,7 +268,7 @@ func (c *Controller) ServeSearchThree(r *http.Request) *spec.Response { // search tracks var tracks []*db.Track - q = c.DB. + q = c.dbc. Preload("Album"). Preload("Album.Artists"). Preload("Genres"). @@ -278,7 +279,7 @@ func (c *Controller) ServeSearchThree(r *http.Request) *spec.Response { } q = q.Offset(params.GetOrInt("songOffset", 0)). Limit(params.GetOrInt("songCount", 20)) - if m := getMusicFolder(c.MusicPaths, params); m != "" { + if m := getMusicFolder(c.musicPaths, params); m != "" { q = q. Joins("JOIN albums ON albums.id=tracks.album_id"). Where("albums.root_dir=?", m) @@ -287,7 +288,7 @@ func (c *Controller) ServeSearchThree(r *http.Request) *spec.Response { return spec.NewError(0, "find tracks: %v", err) } - transcodeMeta := streamGetTranscodeMeta(c.DB, user.ID, params.GetOr("c", "")) + transcodeMeta := streamGetTranscodeMeta(c.dbc, user.ID, params.GetOr("c", "")) for _, t := range tracks { track := spec.NewTrackByTags(t, t.Album) @@ -308,7 +309,7 @@ func (c *Controller) ServeGetArtistInfoTwo(r *http.Request) *spec.Response { } var artist db.Artist - err = c.DB. + err = c.dbc. Where("id=?", id.Value). Find(&artist). Error @@ -319,7 +320,7 @@ func (c *Controller) ServeGetArtistInfoTwo(r *http.Request) *spec.Response { sub := spec.NewResponse() sub.ArtistInfoTwo = &spec.ArtistInfo{} - info, err := c.ArtistInfoCache.GetOrLookup(r.Context(), artist.ID) + info, err := c.artistInfoCache.GetOrLookup(r.Context(), artist.ID) if err != nil { log.Printf("error fetching artist info from lastfm: %v", err) return sub @@ -348,7 +349,7 @@ func (c *Controller) ServeGetArtistInfoTwo(r *http.Request) *spec.Response { break } var artist db.Artist - err = c.DB. + err = c.dbc. Select("artists.*, count(albums.id) album_count"). Where("name=?", similarName). Joins("LEFT JOIN album_artists ON album_artists.artist_id=artists.id"). @@ -378,7 +379,7 @@ func (c *Controller) ServeGetArtistInfoTwo(r *http.Request) *spec.Response { func (c *Controller) ServeGetGenres(_ *http.Request) *spec.Response { var genres []*db.Genre - c.DB. + c.dbc. Select(`*, (SELECT count(1) FROM album_genres WHERE genre_id=genres.id) album_count, (SELECT count(1) FROM track_genres WHERE genre_id=genres.id) track_count`). @@ -402,7 +403,7 @@ func (c *Controller) ServeGetSongsByGenre(r *http.Request) *spec.Response { return spec.NewError(10, "please provide an `genre` parameter") } var tracks []*db.Track - q := c.DB. + q := c.dbc. Joins("JOIN albums ON tracks.album_id=albums.id"). Joins("JOIN track_genres ON track_genres.track_id=tracks.id"). Joins("JOIN genres ON track_genres.genre_id=genres.id AND genres.name=?", genre). @@ -412,7 +413,7 @@ func (c *Controller) ServeGetSongsByGenre(r *http.Request) *spec.Response { Preload("TrackRating", "user_id=?", user.ID). Offset(params.GetOrInt("offset", 0)). Limit(params.GetOrInt("count", 10)) - if m := getMusicFolder(c.MusicPaths, params); m != "" { + if m := getMusicFolder(c.musicPaths, params); m != "" { q = q.Where("albums.root_dir=?", m) } q = q.Group("tracks.id") @@ -424,7 +425,7 @@ func (c *Controller) ServeGetSongsByGenre(r *http.Request) *spec.Response { List: make([]*spec.TrackChild, len(tracks)), } - transcodeMeta := streamGetTranscodeMeta(c.DB, user.ID, params.GetOr("c", "")) + transcodeMeta := streamGetTranscodeMeta(c.dbc, user.ID, params.GetOr("c", "")) for i, t := range tracks { sub.TracksByGenre.List[i] = spec.NewTrackByTags(t, t.Album) @@ -442,7 +443,7 @@ func (c *Controller) ServeGetStarredTwo(r *http.Request) *spec.Response { // artists var artists []*db.Artist - q := c.DB. + q := c.dbc. Joins("JOIN artist_stars ON artist_stars.artist_id=artists.id"). Where("artist_stars.user_id=?", user.ID). Joins("JOIN album_artists ON album_artists.artist_id=artists.id"). @@ -452,7 +453,7 @@ func (c *Controller) ServeGetStarredTwo(r *http.Request) *spec.Response { Preload("ArtistRating", "user_id=?", user.ID). Preload("Info"). Group("artists.id") - if m := getMusicFolder(c.MusicPaths, params); m != "" { + if m := getMusicFolder(c.musicPaths, params); m != "" { q = q.Where("albums.root_dir=?", m) } if err := q.Find(&artists).Error; err != nil { @@ -464,14 +465,14 @@ func (c *Controller) ServeGetStarredTwo(r *http.Request) *spec.Response { // albums var albums []*db.Album - q = c.DB. + q = c.dbc. Joins("JOIN album_stars ON album_stars.album_id=albums.id"). Where("album_stars.user_id=?", user.ID). Order("album_stars.star_date DESC"). Preload("Artists"). Preload("AlbumStar", "user_id=?", user.ID). Preload("AlbumRating", "user_id=?", user.ID) - if m := getMusicFolder(c.MusicPaths, params); m != "" { + if m := getMusicFolder(c.musicPaths, params); m != "" { q = q.Where("albums.root_dir=?", m) } if err := q.Find(&albums).Error; err != nil { @@ -483,7 +484,7 @@ func (c *Controller) ServeGetStarredTwo(r *http.Request) *spec.Response { // tracks var tracks []*db.Track - q = c.DB. + q = c.dbc. Joins("JOIN track_stars ON tracks.id=track_stars.track_id"). Where("track_stars.user_id=?", user.ID). Order("track_stars.star_date DESC"). @@ -491,7 +492,7 @@ func (c *Controller) ServeGetStarredTwo(r *http.Request) *spec.Response { Preload("Album.Artists"). Preload("TrackStar", "user_id=?", user.ID). Preload("TrackRating", "user_id=?", user.ID) - if m := getMusicFolder(c.MusicPaths, params); m != "" { + if m := getMusicFolder(c.musicPaths, params); m != "" { q = q. Joins("JOIN albums ON albums.id=tracks.album_id"). Where("albums.root_dir=?", m) @@ -500,7 +501,7 @@ func (c *Controller) ServeGetStarredTwo(r *http.Request) *spec.Response { return spec.NewError(0, "find tracks: %v", err) } - transcodeMeta := streamGetTranscodeMeta(c.DB, user.ID, params.GetOr("c", "")) + transcodeMeta := streamGetTranscodeMeta(c.dbc, user.ID, params.GetOr("c", "")) for _, t := range tracks { track := spec.NewTrackByTags(t, t.Album) @@ -514,8 +515,8 @@ func (c *Controller) ServeGetStarredTwo(r *http.Request) *spec.Response { } func (c *Controller) genArtistCoverURL(r *http.Request, artist *db.Artist, size int) string { - coverURL, _ := url.Parse(c.BaseURL(r)) - coverURL.Path = c.Path("/rest/getCoverArt") + coverURL, _ := url.Parse(handlerutil.BaseURL(r)) + coverURL.Path = c.resolveProxyPath("/rest/getCoverArt") query := r.URL.Query() query.Set("id", artist.SID().String()) @@ -534,11 +535,11 @@ func (c *Controller) ServeGetTopSongs(r *http.Request) *spec.Response { return spec.NewError(10, "please provide an `artist` parameter") } var artist db.Artist - if err := c.DB.Where("name=?", artistName).Find(&artist).Error; err != nil { + if err := c.dbc.Where("name=?", artistName).Find(&artist).Error; err != nil { return spec.NewError(0, "finding artist by name: %v", err) } - info, err := c.ArtistInfoCache.GetOrLookup(r.Context(), artist.ID) + info, err := c.artistInfoCache.GetOrLookup(r.Context(), artist.ID) if err != nil { log.Printf("error fetching artist info from lastfm: %v", err) return spec.NewResponse() @@ -555,7 +556,7 @@ func (c *Controller) ServeGetTopSongs(r *http.Request) *spec.Response { } var tracks []*db.Track - err = c.DB. + err = c.dbc. Preload("Album"). Joins("JOIN albums ON albums.id=tracks.album_id"). Joins("JOIN album_artists ON album_artists.album_id=albums.id"). @@ -573,7 +574,7 @@ func (c *Controller) ServeGetTopSongs(r *http.Request) *spec.Response { return sub } - transcodeMeta := streamGetTranscodeMeta(c.DB, user.ID, params.GetOr("c", "")) + transcodeMeta := streamGetTranscodeMeta(c.dbc, user.ID, params.GetOr("c", "")) for _, track := range tracks { tc := spec.NewTrackByTags(track, track.Album) @@ -593,7 +594,7 @@ func (c *Controller) ServeGetSimilarSongs(r *http.Request) *spec.Response { } var track db.Track - err = c.DB. + err = c.dbc. Preload("Album"). Where("id=?", id.Value). First(&track). @@ -602,7 +603,7 @@ func (c *Controller) ServeGetSimilarSongs(r *http.Request) *spec.Response { return spec.NewError(10, "couldn't find a track with that id") } - similarTracks, err := c.LastFMClient.TrackGetSimilarTracks(track.TagTrackArtist, track.TagTitle) + similarTracks, err := c.lastFMClient.TrackGetSimilarTracks(track.TagTrackArtist, track.TagTitle) if err != nil { log.Printf("error fetching similar songs from lastfm: %v", err) return spec.NewResponse() @@ -618,7 +619,7 @@ func (c *Controller) ServeGetSimilarSongs(r *http.Request) *spec.Response { } var tracks []*db.Track - err = c.DB. + err = c.dbc. Select("tracks.*"). Preload("Album"). Preload("TrackStar", "user_id=?", user.ID). @@ -640,7 +641,7 @@ func (c *Controller) ServeGetSimilarSongs(r *http.Request) *spec.Response { Tracks: make([]*spec.TrackChild, len(tracks)), } - transcodeMeta := streamGetTranscodeMeta(c.DB, user.ID, params.GetOr("c", "")) + transcodeMeta := streamGetTranscodeMeta(c.dbc, user.ID, params.GetOr("c", "")) for i, track := range tracks { sub.SimilarSongs.Tracks[i] = spec.NewTrackByTags(track, track.Album) @@ -659,7 +660,7 @@ func (c *Controller) ServeGetSimilarSongsTwo(r *http.Request) *spec.Response { } var artist db.Artist - err = c.DB. + err = c.dbc. Where("id=?", id.Value). First(&artist). Error @@ -667,7 +668,7 @@ func (c *Controller) ServeGetSimilarSongsTwo(r *http.Request) *spec.Response { return spec.NewError(0, "artist with id `%s` not found", id) } - similarArtists, err := c.LastFMClient.ArtistGetSimilar(artist.Name) + similarArtists, err := c.lastFMClient.ArtistGetSimilar(artist.Name) if err != nil { log.Printf("error fetching artist info from lastfm: %v", err) return spec.NewResponse() @@ -682,7 +683,7 @@ func (c *Controller) ServeGetSimilarSongsTwo(r *http.Request) *spec.Response { } var tracks []*db.Track - err = c.DB. + err = c.dbc. Preload("Album"). Preload("TrackStar", "user_id=?", user.ID). Preload("TrackRating", "user_id=?", user.ID). @@ -706,7 +707,7 @@ func (c *Controller) ServeGetSimilarSongsTwo(r *http.Request) *spec.Response { Tracks: make([]*spec.TrackChild, len(tracks)), } - transcodeMeta := streamGetTranscodeMeta(c.DB, user.ID, params.GetOr("c", "")) + transcodeMeta := streamGetTranscodeMeta(c.dbc, user.ID, params.GetOr("c", "")) for i, track := range tracks { sub.SimilarSongsTwo.Tracks[i] = spec.NewTrackByTags(track, track.Album) sub.SimilarSongsTwo.Tracks[i].TranscodeMeta = transcodeMeta @@ -737,33 +738,33 @@ func (c *Controller) ServeStar(r *http.Request) *spec.Response { stardate := time.Now() for _, id := range starIDsOfType(params, specid.Album) { var albumstar db.AlbumStar - _ = c.DB.Where("user_id=? AND album_id=?", user.ID, id).First(&albumstar).Error + _ = c.dbc.Where("user_id=? AND album_id=?", user.ID, id).First(&albumstar).Error albumstar.UserID = user.ID albumstar.AlbumID = id albumstar.StarDate = stardate - if err := c.DB.Save(&albumstar).Error; err != nil { + if err := c.dbc.Save(&albumstar).Error; err != nil { return spec.NewError(0, "save album star: %v", err) } } for _, id := range starIDsOfType(params, specid.Artist) { var artiststar db.ArtistStar - _ = c.DB.Where("user_id=? AND artist_id=?", user.ID, id).First(&artiststar).Error + _ = c.dbc.Where("user_id=? AND artist_id=?", user.ID, id).First(&artiststar).Error artiststar.UserID = user.ID artiststar.ArtistID = id artiststar.StarDate = stardate - if err := c.DB.Save(&artiststar).Error; err != nil { + if err := c.dbc.Save(&artiststar).Error; err != nil { return spec.NewError(0, "save artist star: %v", err) } } for _, id := range starIDsOfType(params, specid.Track) { var trackstar db.TrackStar - _ = c.DB.Where("user_id=? AND track_id=?", user.ID, id).First(&trackstar).Error + _ = c.dbc.Where("user_id=? AND track_id=?", user.ID, id).First(&trackstar).Error trackstar.UserID = user.ID trackstar.TrackID = id trackstar.StarDate = stardate - if err := c.DB.Save(&trackstar).Error; err != nil { + if err := c.dbc.Save(&trackstar).Error; err != nil { return spec.NewError(0, "save track star: %v", err) } } @@ -776,19 +777,19 @@ func (c *Controller) ServeUnstar(r *http.Request) *spec.Response { user := r.Context().Value(CtxUser).(*db.User) for _, id := range starIDsOfType(params, specid.Album) { - if err := c.DB.Where("user_id=? AND album_id=?", user.ID, id).Delete(db.AlbumStar{}).Error; err != nil && !errors.Is(err, gorm.ErrRecordNotFound) { + if err := c.dbc.Where("user_id=? AND album_id=?", user.ID, id).Delete(db.AlbumStar{}).Error; err != nil && !errors.Is(err, gorm.ErrRecordNotFound) { return spec.NewError(0, "delete album star: %v", err) } } for _, id := range starIDsOfType(params, specid.Artist) { - if err := c.DB.Where("user_id=? AND artist_id=?", user.ID, id).Delete(db.ArtistStar{}).Error; err != nil && !errors.Is(err, gorm.ErrRecordNotFound) { + if err := c.dbc.Where("user_id=? AND artist_id=?", user.ID, id).Delete(db.ArtistStar{}).Error; err != nil && !errors.Is(err, gorm.ErrRecordNotFound) { return spec.NewError(0, "delete artist star: %v", err) } } for _, id := range starIDsOfType(params, specid.Track) { - if err := c.DB.Where("user_id=? AND track_id=?", user.ID, id).Delete(db.TrackStar{}).Error; err != nil && !errors.Is(err, gorm.ErrRecordNotFound) { + if err := c.dbc.Where("user_id=? AND track_id=?", user.ID, id).Delete(db.TrackStar{}).Error; err != nil && !errors.Is(err, gorm.ErrRecordNotFound) { return spec.NewError(0, "delete track star: %v", err) } } @@ -813,95 +814,95 @@ func (c *Controller) ServeSetRating(r *http.Request) *spec.Response { switch id.Type { case specid.Album: var album db.Album - err := c.DB.Where("id=?", id.Value).First(&album).Error + err := c.dbc.Where("id=?", id.Value).First(&album).Error if err != nil { return spec.NewError(0, "fetch album: %v", err) } var albumRating db.AlbumRating - if err := c.DB.Where("user_id=? AND album_id=?", user.ID, id.Value).First(&albumRating).Error; err != nil && !errors.Is(err, gorm.ErrRecordNotFound) { + if err := c.dbc.Where("user_id=? AND album_id=?", user.ID, id.Value).First(&albumRating).Error; err != nil && !errors.Is(err, gorm.ErrRecordNotFound) { return spec.NewError(0, "fetch album rating: %v", err) } switch { case rating == 0 && albumRating.AlbumID == album.ID: - if err := c.DB.Delete(&albumRating).Error; err != nil { + if err := c.dbc.Delete(&albumRating).Error; err != nil { return spec.NewError(0, "delete album rating: %v", err) } case rating > 0: albumRating.UserID = user.ID albumRating.AlbumID = id.Value albumRating.Rating = rating - if err := c.DB.Save(&albumRating).Error; err != nil { + if err := c.dbc.Save(&albumRating).Error; err != nil { return spec.NewError(0, "save album rating: %v", err) } } var averageRating float64 - if err := c.DB.Model(db.AlbumRating{}).Select("coalesce(avg(rating), 0)").Where("album_id=?", id.Value).Row().Scan(&averageRating); err != nil { + if err := c.dbc.Model(db.AlbumRating{}).Select("coalesce(avg(rating), 0)").Where("album_id=?", id.Value).Row().Scan(&averageRating); err != nil { return spec.NewError(0, "find average album rating: %v", err) } album.AverageRating = math.Trunc(averageRating*100) / 100 - if err := c.DB.Save(&album).Error; err != nil { + if err := c.dbc.Save(&album).Error; err != nil { return spec.NewError(0, "save album: %v", err) } case specid.Artist: var artist db.Artist - err := c.DB.Where("id=?", id.Value).First(&artist).Error + err := c.dbc.Where("id=?", id.Value).First(&artist).Error if err != nil { return spec.NewError(0, "fetch artist: %v", err) } var artistRating db.ArtistRating - if err := c.DB.Where("user_id=? AND artist_id=?", user.ID, id.Value).First(&artistRating).Error; err != nil && !errors.Is(err, gorm.ErrRecordNotFound) { + if err := c.dbc.Where("user_id=? AND artist_id=?", user.ID, id.Value).First(&artistRating).Error; err != nil && !errors.Is(err, gorm.ErrRecordNotFound) { return spec.NewError(0, "fetch artist rating: %v", err) } switch { case rating == 0 && artistRating.ArtistID == artist.ID: - if err := c.DB.Delete(&artistRating).Error; err != nil { + if err := c.dbc.Delete(&artistRating).Error; err != nil { return spec.NewError(0, "delete artist rating: %v", err) } case rating > 0: artistRating.UserID = user.ID artistRating.ArtistID = id.Value artistRating.Rating = rating - if err := c.DB.Save(&artistRating).Error; err != nil { + if err := c.dbc.Save(&artistRating).Error; err != nil { return spec.NewError(0, "save artist rating: %v", err) } } var averageRating float64 - if err := c.DB.Model(db.ArtistRating{}).Select("coalesce(avg(rating), 0)").Where("artist_id=?", id.Value).Row().Scan(&averageRating); err != nil { + if err := c.dbc.Model(db.ArtistRating{}).Select("coalesce(avg(rating), 0)").Where("artist_id=?", id.Value).Row().Scan(&averageRating); err != nil { return spec.NewError(0, "find average artist rating: %v", err) } artist.AverageRating = math.Trunc(averageRating*100) / 100 - if err := c.DB.Save(&artist).Error; err != nil { + if err := c.dbc.Save(&artist).Error; err != nil { return spec.NewError(0, "save artist: %v", err) } case specid.Track: var track db.Track - err := c.DB.Where("id=?", id.Value).First(&track).Error + err := c.dbc.Where("id=?", id.Value).First(&track).Error if err != nil { return spec.NewError(0, "fetch track: %v", err) } var trackRating db.TrackRating - if err := c.DB.Where("user_id=? AND track_id=?", user.ID, id.Value).First(&trackRating).Error; err != nil && !errors.Is(err, gorm.ErrRecordNotFound) { + if err := c.dbc.Where("user_id=? AND track_id=?", user.ID, id.Value).First(&trackRating).Error; err != nil && !errors.Is(err, gorm.ErrRecordNotFound) { return spec.NewError(0, "fetch track rating: %v", err) } switch { case rating == 0 && trackRating.TrackID == track.ID: - if err := c.DB.Delete(&trackRating).Error; err != nil { + if err := c.dbc.Delete(&trackRating).Error; err != nil { return spec.NewError(0, "delete track rating: %v", err) } case rating > 0: trackRating.UserID = user.ID trackRating.TrackID = id.Value trackRating.Rating = rating - if err := c.DB.Save(&trackRating).Error; err != nil { + if err := c.dbc.Save(&trackRating).Error; err != nil { return spec.NewError(0, "save track rating: %v", err) } } var averageRating float64 - if err := c.DB.Model(db.TrackRating{}).Select("coalesce(avg(rating), 0)").Where("track_id=?", id.Value).Row().Scan(&averageRating); err != nil { + if err := c.dbc.Model(db.TrackRating{}).Select("coalesce(avg(rating), 0)").Where("track_id=?", id.Value).Row().Scan(&averageRating); err != nil { return spec.NewError(0, "find average track rating: %v", err) } track.AverageRating = math.Trunc(averageRating*100) / 100 - if err := c.DB.Save(&track).Error; err != nil { + if err := c.dbc.Save(&track).Error; err != nil { return spec.NewError(0, "save track: %v", err) } default: diff --git a/server/ctrlsubsonic/handlers_by_tags_test.go b/server/ctrlsubsonic/handlers_by_tags_test.go index 7ea1bf8..cc71485 100644 --- a/server/ctrlsubsonic/handlers_by_tags_test.go +++ b/server/ctrlsubsonic/handlers_by_tags_test.go @@ -8,8 +8,7 @@ import ( func TestGetArtists(t *testing.T) { t.Parallel() contr := makeControllerRoots(t, []string{"m-0", "m-1"}) - - runQueryCases(t, contr, contr.ServeGetArtists, []*queryCase{ + runQueryCases(t, contr.ServeGetArtists, []*queryCase{ {url.Values{}, "no_args", false}, {url.Values{"musicFolderId": {"0"}}, "with_music_folder_1", false}, {url.Values{"musicFolderId": {"1"}}, "with_music_folder_2", false}, @@ -19,8 +18,7 @@ func TestGetArtists(t *testing.T) { func TestGetArtist(t *testing.T) { t.Parallel() contr := makeController(t) - - runQueryCases(t, contr, contr.ServeGetArtist, []*queryCase{ + runQueryCases(t, contr.ServeGetArtist, []*queryCase{ {url.Values{"id": {"ar-1"}}, "id_one", false}, {url.Values{"id": {"ar-2"}}, "id_two", false}, {url.Values{"id": {"ar-3"}}, "id_three", false}, @@ -30,8 +28,7 @@ func TestGetArtist(t *testing.T) { func TestGetAlbum(t *testing.T) { t.Parallel() contr := makeController(t) - - runQueryCases(t, contr, contr.ServeGetAlbum, []*queryCase{ + runQueryCases(t, contr.ServeGetAlbum, []*queryCase{ {url.Values{"id": {"al-2"}}, "without_cover", false}, {url.Values{"id": {"al-3"}}, "with_cover", false}, }) @@ -40,8 +37,7 @@ func TestGetAlbum(t *testing.T) { func TestGetAlbumListTwo(t *testing.T) { t.Parallel() contr := makeController(t) - - runQueryCases(t, contr, contr.ServeGetAlbumListTwo, []*queryCase{ + runQueryCases(t, contr.ServeGetAlbumListTwo, []*queryCase{ {url.Values{"type": {"alphabeticalByArtist"}}, "alpha_artist", false}, {url.Values{"type": {"alphabeticalByName"}}, "alpha_name", false}, {url.Values{"type": {"newest"}}, "newest", false}, @@ -52,8 +48,7 @@ func TestGetAlbumListTwo(t *testing.T) { func TestSearchThree(t *testing.T) { t.Parallel() contr := makeController(t) - - runQueryCases(t, contr, contr.ServeSearchThree, []*queryCase{ + runQueryCases(t, contr.ServeSearchThree, []*queryCase{ {url.Values{"query": {"art"}}, "q_art", false}, {url.Values{"query": {"alb"}}, "q_alb", false}, {url.Values{"query": {"tit"}}, "q_tra", false}, diff --git a/server/ctrlsubsonic/handlers_common.go b/server/ctrlsubsonic/handlers_common.go index f8e9cb1..9bad487 100644 --- a/server/ctrlsubsonic/handlers_common.go +++ b/server/ctrlsubsonic/handlers_common.go @@ -58,7 +58,7 @@ func (c *Controller) ServeScrobble(r *http.Request) *spec.Response { switch id.Type { case specid.Track: var track db.Track - if err := c.DB.Preload("Album").Preload("Album.Artists").First(&track, id.Value).Error; err != nil { + if err := c.dbc.Preload("Album").Preload("Album.Artists").First(&track, id.Value).Error; err != nil { return spec.NewError(0, "error finding track: %v", err) } if track.Album == nil { @@ -75,13 +75,13 @@ func (c *Controller) ServeScrobble(r *http.Request) *spec.Response { scrobbleTrack.MusicBrainzID = track.TagBrainzID } - if err := scrobbleStatsUpdateTrack(c.DB, &track, user.ID, optStamp); err != nil { + if err := scrobbleStatsUpdateTrack(c.dbc, &track, user.ID, optStamp); err != nil { return spec.NewError(0, "error updating stats: %v", err) } case specid.PodcastEpisode: var podcastEpisode db.PodcastEpisode - if err := c.DB.Preload("Podcast").First(&podcastEpisode, id.Value).Error; err != nil { + if err := c.dbc.Preload("Podcast").First(&podcastEpisode, id.Value).Error; err != nil { return spec.NewError(0, "error finding podcast episode: %v", err) } @@ -89,13 +89,13 @@ func (c *Controller) ServeScrobble(r *http.Request) *spec.Response { scrobbleTrack.Artist = podcastEpisode.Podcast.Title scrobbleTrack.Duration = time.Second * time.Duration(podcastEpisode.Length) - if err := scrobbleStatsUpdatePodcastEpisode(c.DB, id.Value); err != nil { + if err := scrobbleStatsUpdatePodcastEpisode(c.dbc, id.Value); err != nil { return spec.NewError(0, "error updating stats: %v", err) } } var scrobbleErrs []error - for _, scrobbler := range c.Scrobblers { + for _, scrobbler := range c.scrobblers { if !scrobbler.IsUserAuthenticated(*user) { continue } @@ -113,7 +113,7 @@ func (c *Controller) ServeScrobble(r *http.Request) *spec.Response { func (c *Controller) ServeGetMusicFolders(_ *http.Request) *spec.Response { sub := spec.NewResponse() sub.MusicFolders = &spec.MusicFolders{} - for i, mp := range c.MusicPaths { + for i, mp := range c.musicPaths { alias := mp.Alias if alias == "" { alias = filepath.Base(mp.Path) @@ -125,7 +125,7 @@ func (c *Controller) ServeGetMusicFolders(_ *http.Request) *spec.Response { func (c *Controller) ServeStartScan(r *http.Request) *spec.Response { go func() { - if _, err := c.Scanner.ScanAndClean(scanner.ScanOptions{}); err != nil { + if _, err := c.scanner.ScanAndClean(scanner.ScanOptions{}); err != nil { log.Printf("error while scanning: %v\n", err) } }() @@ -134,13 +134,13 @@ func (c *Controller) ServeStartScan(r *http.Request) *spec.Response { func (c *Controller) ServeGetScanStatus(_ *http.Request) *spec.Response { var trackCount int - if err := c.DB.Model(db.Track{}).Count(&trackCount).Error; err != nil { + if err := c.dbc.Model(db.Track{}).Count(&trackCount).Error; err != nil { return spec.NewError(0, "error finding track count: %v", err) } sub := spec.NewResponse() sub.ScanStatus = &spec.ScanStatus{ - Scanning: c.Scanner.IsScanning(), + Scanning: c.scanner.IsScanning(), Count: trackCount, } return sub @@ -155,8 +155,8 @@ func (c *Controller) ServeGetUser(r *http.Request) *spec.Response { sub.User = &spec.User{ Username: user.Name, AdminRole: user.IsAdmin, - JukeboxRole: c.Jukebox != nil, - PodcastRole: c.Podcasts != nil, + JukeboxRole: c.jukebox != nil, + PodcastRole: c.podcasts != nil, DownloadRole: true, ScrobblingEnabled: hasLastFM || hasListenBrainz, Folder: []int{1}, @@ -172,7 +172,7 @@ func (c *Controller) ServeGetPlayQueue(r *http.Request) *spec.Response { params := r.Context().Value(CtxParams).(params.Params) user := r.Context().Value(CtxUser).(*db.User) var queue db.PlayQueue - err := c.DB. + err := c.dbc. Where("user_id=?", user.ID). Find(&queue). Error @@ -190,13 +190,13 @@ func (c *Controller) ServeGetPlayQueue(r *http.Request) *spec.Response { trackIDs := queue.GetItems() sub.PlayQueue.List = make([]*spec.TrackChild, len(trackIDs)) - transcodeMeta := streamGetTranscodeMeta(c.DB, user.ID, params.GetOr("c", "")) + transcodeMeta := streamGetTranscodeMeta(c.dbc, user.ID, params.GetOr("c", "")) for i, id := range trackIDs { switch id.Type { case specid.Track: track := db.Track{} - c.DB. + c.dbc. Where("id=?", id.Value). Preload("Album"). Preload("TrackStar", "user_id=?", user.ID). @@ -206,7 +206,7 @@ func (c *Controller) ServeGetPlayQueue(r *http.Request) *spec.Response { sub.PlayQueue.List[i].TranscodeMeta = transcodeMeta case specid.PodcastEpisode: pe := db.PodcastEpisode{} - c.DB. + c.dbc. Where("id=?", id.Value). Find(&pe) sub.PlayQueue.List[i] = spec.NewTCPodcastEpisode(&pe) @@ -233,13 +233,13 @@ func (c *Controller) ServeSavePlayQueue(r *http.Request) *spec.Response { } user := r.Context().Value(CtxUser).(*db.User) var queue db.PlayQueue - c.DB.Where("user_id=?", user.ID).First(&queue) + c.dbc.Where("user_id=?", user.ID).First(&queue) queue.UserID = user.ID queue.Current = params.GetOrID("current", specid.ID{}).String() queue.Position = params.GetOrInt("position", 0) queue.ChangedBy = params.GetOr("c", "") // must exist, middleware checks queue.SetItems(trackIDs) - c.DB.Save(&queue) + c.dbc.Save(&queue) return spec.NewResponse() } @@ -251,7 +251,7 @@ func (c *Controller) ServeGetSong(r *http.Request) *spec.Response { return spec.NewError(10, "provide an `id` parameter") } var track db.Track - err = c.DB. + err = c.dbc. Where("id=?", id.Value). Preload("Album"). Preload("Album.Artists"). @@ -263,7 +263,7 @@ func (c *Controller) ServeGetSong(r *http.Request) *spec.Response { return spec.NewError(10, "couldn't find a track with that id") } - transcodeMeta := streamGetTranscodeMeta(c.DB, user.ID, params.GetOr("c", "")) + transcodeMeta := streamGetTranscodeMeta(c.dbc, user.ID, params.GetOr("c", "")) sub := spec.NewResponse() sub.Track = spec.NewTrackByTags(&track, track.Album) @@ -277,7 +277,7 @@ func (c *Controller) ServeGetRandomSongs(r *http.Request) *spec.Response { params := r.Context().Value(CtxParams).(params.Params) user := r.Context().Value(CtxUser).(*db.User) var tracks []*db.Track - q := c.DB.DB. + q := c.dbc.DB. Limit(params.GetOrInt("size", 10)). Preload("Album"). Preload("Album.Artists"). @@ -295,7 +295,7 @@ func (c *Controller) ServeGetRandomSongs(r *http.Request) *spec.Response { q = q.Joins("JOIN track_genres ON track_genres.track_id=tracks.id") q = q.Joins("JOIN genres ON genres.id=track_genres.genre_id AND genres.name=?", genre) } - if m := getMusicFolder(c.MusicPaths, params); m != "" { + if m := getMusicFolder(c.musicPaths, params); m != "" { q = q.Where("albums.root_dir=?", m) } if err := q.Find(&tracks).Error; err != nil { @@ -305,7 +305,7 @@ func (c *Controller) ServeGetRandomSongs(r *http.Request) *spec.Response { sub.RandomTracks = &spec.RandomTracks{} sub.RandomTracks.List = make([]*spec.TrackChild, len(tracks)) - transcodeMeta := streamGetTranscodeMeta(c.DB, user.ID, params.GetOr("c", "")) + transcodeMeta := streamGetTranscodeMeta(c.dbc, user.ID, params.GetOr("c", "")) for i, track := range tracks { sub.RandomTracks.List[i] = spec.NewTrackByTags(track, track.Album) @@ -321,7 +321,7 @@ func (c *Controller) ServeJukebox(r *http.Request) *spec.Response { // nolint:go trackPaths := func(ids []specid.ID) ([]string, error) { var paths []string for _, id := range ids { - r, err := specidpaths.Locate(c.DB, id) + r, err := specidpaths.Locate(c.dbc, id) if err != nil { return nil, fmt.Errorf("find track by id: %w", err) } @@ -330,7 +330,7 @@ func (c *Controller) ServeJukebox(r *http.Request) *spec.Response { // nolint:go return paths, nil } getSpecStatus := func() (*spec.JukeboxStatus, error) { - status, err := c.Jukebox.GetStatus() + status, err := c.jukebox.GetStatus() if err != nil { return nil, fmt.Errorf("get status: %w", err) } @@ -343,12 +343,12 @@ func (c *Controller) ServeJukebox(r *http.Request) *spec.Response { // nolint:go } getSpecPlaylist := func() ([]*spec.TrackChild, error) { var ret []*spec.TrackChild - playlist, err := c.Jukebox.GetPlaylist() + playlist, err := c.jukebox.GetPlaylist() if err != nil { return nil, fmt.Errorf("get playlist: %w", err) } for _, path := range playlist { - file, err := specidpaths.Lookup(c.DB, PathsOf(c.MusicPaths), c.PodcastsPath, path) + file, err := specidpaths.Lookup(c.dbc, MusicPaths(c.musicPaths), c.podcastsPath, path) if err != nil { return nil, fmt.Errorf("fetch track: %w", err) } @@ -368,7 +368,7 @@ func (c *Controller) ServeJukebox(r *http.Request) *spec.Response { // nolint:go if err != nil { return spec.NewError(0, "error creating playlist items: %v", err) } - if err := c.Jukebox.SetPlaylist(paths); err != nil { + if err := c.jukebox.SetPlaylist(paths); err != nil { return spec.NewError(0, "error setting playlist: %v", err) } case "add": @@ -377,11 +377,11 @@ func (c *Controller) ServeJukebox(r *http.Request) *spec.Response { // nolint:go if err != nil { return spec.NewError(10, "error creating playlist items: %v", err) } - if err := c.Jukebox.AppendToPlaylist(paths); err != nil { + if err := c.jukebox.AppendToPlaylist(paths); err != nil { return spec.NewError(0, "error appending to playlist: %v", err) } case "clear": - if err := c.Jukebox.ClearPlaylist(); err != nil { + if err := c.jukebox.ClearPlaylist(); err != nil { return spec.NewError(0, "error clearing playlist: %v", err) } case "remove": @@ -389,15 +389,15 @@ func (c *Controller) ServeJukebox(r *http.Request) *spec.Response { // nolint:go if err != nil { return spec.NewError(10, "please provide an id for remove actions") } - if err := c.Jukebox.RemovePlaylistIndex(index); err != nil { + if err := c.jukebox.RemovePlaylistIndex(index); err != nil { return spec.NewError(0, "error removing: %v", err) } case "stop": - if err := c.Jukebox.Pause(); err != nil { + if err := c.jukebox.Pause(); err != nil { return spec.NewError(0, "error stopping: %v", err) } case "start": - if err := c.Jukebox.Play(); err != nil { + if err := c.jukebox.Play(); err != nil { return spec.NewError(0, "error starting: %v", err) } case "skip": @@ -406,7 +406,7 @@ func (c *Controller) ServeJukebox(r *http.Request) *spec.Response { // nolint:go return spec.NewError(10, "please provide an index for skip actions") } offset, _ := params.GetInt("offset") - if err := c.Jukebox.SkipToPlaylistIndex(index, offset); err != nil { + if err := c.jukebox.SkipToPlaylistIndex(index, offset); err != nil { return spec.NewError(0, "error skipping: %v", err) } case "get": @@ -429,7 +429,7 @@ func (c *Controller) ServeJukebox(r *http.Request) *spec.Response { // nolint:go if err != nil { return spec.NewError(10, "please provide a valid gain param") } - if err := c.Jukebox.SetVolumePct(int(math.Min(gain, 1) * 100)); err != nil { + if err := c.jukebox.SetVolumePct(int(math.Min(gain, 1) * 100)); err != nil { return spec.NewError(0, "error setting gain: %v", err) } } diff --git a/server/ctrlsubsonic/handlers_internet_radio.go b/server/ctrlsubsonic/handlers_internet_radio.go index f2fc21f..0d03eec 100644 --- a/server/ctrlsubsonic/handlers_internet_radio.go +++ b/server/ctrlsubsonic/handlers_internet_radio.go @@ -11,7 +11,7 @@ import ( func (c *Controller) ServeGetInternetRadioStations(_ *http.Request) *spec.Response { var stations []*db.InternetRadioStation - if err := c.DB.Find(&stations).Error; err != nil { + if err := c.dbc.Find(&stations).Error; err != nil { return spec.NewError(0, "find stations: %v", err) } sub := spec.NewResponse() @@ -55,7 +55,7 @@ func (c *Controller) ServeCreateInternetRadioStation(r *http.Request) *spec.Resp station.Name = name station.HomepageURL = homepageURL - if err := c.DB.Save(&station).Error; err != nil { + if err := c.dbc.Save(&station).Error; err != nil { return spec.NewError(0, "save station: %v", err) } @@ -92,7 +92,7 @@ func (c *Controller) ServeUpdateInternetRadioStation(r *http.Request) *spec.Resp } var station db.InternetRadioStation - if err := c.DB.Where("id=?", stationID.Value).First(&station).Error; err != nil { + if err := c.dbc.Where("id=?", stationID.Value).First(&station).Error; err != nil { return spec.NewError(70, "id not found: %v", err) } @@ -100,7 +100,7 @@ func (c *Controller) ServeUpdateInternetRadioStation(r *http.Request) *spec.Resp station.Name = name station.HomepageURL = homepageURL - if err := c.DB.Save(&station).Error; err != nil { + if err := c.dbc.Save(&station).Error; err != nil { return spec.NewError(0, "save station: %v", err) } return spec.NewResponse() @@ -119,11 +119,11 @@ func (c *Controller) ServeDeleteInternetRadioStation(r *http.Request) *spec.Resp } var station db.InternetRadioStation - if err := c.DB.Where("id=?", stationID.Value).First(&station).Error; err != nil { + if err := c.dbc.Where("id=?", stationID.Value).First(&station).Error; err != nil { return spec.NewError(70, "id not found: %v", err) } - if err := c.DB.Delete(&station).Error; err != nil { + if err := c.dbc.Delete(&station).Error; err != nil { return spec.NewError(70, "id not found: %v", err) } diff --git a/server/ctrlsubsonic/handlers_internet_radio_test.go b/server/ctrlsubsonic/handlers_internet_radio_test.go index cd8c028..c266ad5 100644 --- a/server/ctrlsubsonic/handlers_internet_radio_test.go +++ b/server/ctrlsubsonic/handlers_internet_radio_test.go @@ -61,7 +61,7 @@ func TestInternetRadio(t *testing.T) { t.Run("deletes", func(t *testing.T) { testInternetRadioDeletes(t, contr) }) } -func runTestCase(t *testing.T, contr *Controller, h handlerSubsonic, q url.Values, admin bool) *spec.SubsonicResponse { +func runTestCase(t *testing.T, h handlerSubsonic, q url.Values, admin bool) *spec.SubsonicResponse { t.Helper() var rr *httptest.ResponseRecorder @@ -72,7 +72,7 @@ func runTestCase(t *testing.T, contr *Controller, h handlerSubsonic, q url.Value } else { rr, req = makeHTTPMock(q) } - contr.H(h).ServeHTTP(rr, req) + resp(h).ServeHTTP(rr, req) body := rr.Body.String() if status := rr.Code; status != http.StatusOK { t.Fatalf("didn't give a 200\n%s", body) @@ -134,29 +134,29 @@ func testInternetRadioBadCreates(t *testing.T, contr *Controller) { var response *spec.SubsonicResponse // no parameters - response = runTestCase(t, contr, contr.ServeCreateInternetRadioStation, url.Values{}, true) + response = runTestCase(t, contr.ServeCreateInternetRadioStation, url.Values{}, true) checkMissingParameter(t, response) // just one required parameter - response = runTestCase(t, contr, contr.ServeCreateInternetRadioStation, + response = runTestCase(t, contr.ServeCreateInternetRadioStation, url.Values{"streamUrl": {station1StreamURL}}, true) checkMissingParameter(t, response) - response = runTestCase(t, contr, contr.ServeCreateInternetRadioStation, + response = runTestCase(t, contr.ServeCreateInternetRadioStation, url.Values{"name": {station1Name}}, true) checkMissingParameter(t, response) // bad URLs - response = runTestCase(t, contr, contr.ServeCreateInternetRadioStation, + response = runTestCase(t, contr.ServeCreateInternetRadioStation, url.Values{"streamUrl": {station1StreamURL}, "name": {station1Name}, "homepageUrl": {notAURL}}, true) checkBadParameter(t, response) - response = runTestCase(t, contr, contr.ServeCreateInternetRadioStation, + response = runTestCase(t, contr.ServeCreateInternetRadioStation, url.Values{"streamUrl": {notAURL}, "name": {station1Name}, "homepageUrl": {station1HomepageURL}}, true) checkBadParameter(t, response) // check for empty get after - response = runTestCase(t, contr, contr.ServeGetInternetRadioStations, url.Values{}, false) // no need to be admin + response = runTestCase(t, contr.ServeGetInternetRadioStations, url.Values{}, false) // no need to be admin checkSuccess(t, response) if (response.Response.InternetRadioStations == nil) || (len(response.Response.InternetRadioStations.List) != 0) { @@ -166,7 +166,7 @@ func testInternetRadioBadCreates(t *testing.T, contr *Controller) { func testInternetRadioInitialEmpty(t *testing.T, contr *Controller) { // check for empty get on new DB - response := runTestCase(t, contr, contr.ServeGetInternetRadioStations, url.Values{}, false) // no need to be admin + response := runTestCase(t, contr.ServeGetInternetRadioStations, url.Values{}, false) // no need to be admin checkSuccess(t, response) if (response.Response.InternetRadioStations == nil) || (len(response.Response.InternetRadioStations.List) != 0) { @@ -176,15 +176,15 @@ func testInternetRadioInitialEmpty(t *testing.T, contr *Controller) { func testInternetRadioInitialAdds(t *testing.T, contr *Controller) { // successful adds and read back - response := runTestCase(t, contr, contr.ServeCreateInternetRadioStation, + response := runTestCase(t, contr.ServeCreateInternetRadioStation, url.Values{"streamUrl": {station1StreamURL}, "name": {station1Name}, "homepageUrl": {station1HomepageURL}}, true) checkSuccess(t, response) - response = runTestCase(t, contr, contr.ServeCreateInternetRadioStation, + response = runTestCase(t, contr.ServeCreateInternetRadioStation, url.Values{"streamUrl": {station2StreamURL}, "name": {station2Name}}, true) // NOTE: no homepage Url checkSuccess(t, response) - response = runTestCase(t, contr, contr.ServeGetInternetRadioStations, url.Values{}, false) // no need to be admin + response = runTestCase(t, contr.ServeGetInternetRadioStations, url.Values{}, false) // no need to be admin checkSuccess(t, response) if response.Response.InternetRadioStations == nil { @@ -207,16 +207,16 @@ func testInternetRadioInitialAdds(t *testing.T, contr *Controller) { func testInternetRadioUpdateHomepage(t *testing.T, contr *Controller) { // update empty homepage URL without other parameters (fails) - response := runTestCase(t, contr, contr.ServeUpdateInternetRadioStation, + response := runTestCase(t, contr.ServeUpdateInternetRadioStation, url.Values{"id": {station2ID}, "homepageUrl": {station2HomepageURL}}, true) checkMissingParameter(t, response) // update empty homepage URL properly and read back - response = runTestCase(t, contr, contr.ServeUpdateInternetRadioStation, + response = runTestCase(t, contr.ServeUpdateInternetRadioStation, url.Values{"id": {station2ID}, "streamUrl": {station2StreamURL}, "name": {station2Name}, "homepageUrl": {station2HomepageURL}}, true) checkSuccess(t, response) - response = runTestCase(t, contr, contr.ServeGetInternetRadioStations, url.Values{}, false) // no need to be admin + response = runTestCase(t, contr.ServeGetInternetRadioStations, url.Values{}, false) // no need to be admin checkSuccess(t, response) if response.Response.InternetRadioStations == nil { @@ -239,19 +239,19 @@ func testInternetRadioUpdateHomepage(t *testing.T, contr *Controller) { func testInternetRadioNotAdmin(t *testing.T, contr *Controller) { // create, update, delete w/o admin privileges (fails and does not modify data) - response := runTestCase(t, contr, contr.ServeCreateInternetRadioStation, + response := runTestCase(t, contr.ServeCreateInternetRadioStation, url.Values{"streamUrl": {station1StreamURL}, "name": {station1Name}, "homepageUrl": {station1HomepageURL}}, false) checkNotAdmin(t, response) - response = runTestCase(t, contr, contr.ServeUpdateInternetRadioStation, + response = runTestCase(t, contr.ServeUpdateInternetRadioStation, url.Values{"id": {station1ID}, "streamUrl": {newstation1StreamURL}, "name": {newstation1Name}, "homepageUrl": {newstation1HomepageURL}}, false) checkNotAdmin(t, response) - response = runTestCase(t, contr, contr.ServeDeleteInternetRadioStation, + response = runTestCase(t, contr.ServeDeleteInternetRadioStation, url.Values{"id": {station1ID}}, false) checkNotAdmin(t, response) - response = runTestCase(t, contr, contr.ServeGetInternetRadioStations, url.Values{}, false) // no need to be admin + response = runTestCase(t, contr.ServeGetInternetRadioStations, url.Values{}, false) // no need to be admin checkSuccess(t, response) if response.Response.InternetRadioStations == nil { @@ -274,11 +274,11 @@ func testInternetRadioNotAdmin(t *testing.T, contr *Controller) { func testInternetRadioUpdates(t *testing.T, contr *Controller) { // replace station 1 and read back - response := runTestCase(t, contr, contr.ServeUpdateInternetRadioStation, + response := runTestCase(t, contr.ServeUpdateInternetRadioStation, url.Values{"id": {station1ID}, "streamUrl": {newstation1StreamURL}, "name": {newstation1Name}, "homepageUrl": {newstation1HomepageURL}}, true) checkSuccess(t, response) - response = runTestCase(t, contr, contr.ServeGetInternetRadioStations, url.Values{}, false) // no need to be admin + response = runTestCase(t, contr.ServeGetInternetRadioStations, url.Values{}, false) // no need to be admin checkSuccess(t, response) if response.Response.InternetRadioStations == nil { @@ -299,11 +299,11 @@ func testInternetRadioUpdates(t *testing.T, contr *Controller) { } // update station 2 but without homepage URL and read back - response = runTestCase(t, contr, contr.ServeUpdateInternetRadioStation, + response = runTestCase(t, contr.ServeUpdateInternetRadioStation, url.Values{"id": {station2ID}, "streamUrl": {newstation2StreamURL}, "name": {newstation2Name}}, true) checkSuccess(t, response) - response = runTestCase(t, contr, contr.ServeGetInternetRadioStations, url.Values{}, false) // no need to be admin + response = runTestCase(t, contr.ServeGetInternetRadioStations, url.Values{}, false) // no need to be admin checkSuccess(t, response) if response.Response.InternetRadioStations == nil { @@ -326,11 +326,11 @@ func testInternetRadioUpdates(t *testing.T, contr *Controller) { func testInternetRadioDeletes(t *testing.T, contr *Controller) { // delete non-existent station 3 (fails and does not modify data) - response := runTestCase(t, contr, contr.ServeDeleteInternetRadioStation, + response := runTestCase(t, contr.ServeDeleteInternetRadioStation, url.Values{"id": {station3ID}}, true) checkBadParameter(t, response) - response = runTestCase(t, contr, contr.ServeGetInternetRadioStations, url.Values{}, false) // no need to be admin + response = runTestCase(t, contr.ServeGetInternetRadioStations, url.Values{}, false) // no need to be admin checkSuccess(t, response) if response.Response.InternetRadioStations == nil { @@ -351,11 +351,11 @@ func testInternetRadioDeletes(t *testing.T, contr *Controller) { } // delete station 1 and recheck - response = runTestCase(t, contr, contr.ServeDeleteInternetRadioStation, + response = runTestCase(t, contr.ServeDeleteInternetRadioStation, url.Values{"id": {station1ID}}, true) checkSuccess(t, response) - response = runTestCase(t, contr, contr.ServeGetInternetRadioStations, url.Values{}, false) // no need to be admin + response = runTestCase(t, contr.ServeGetInternetRadioStations, url.Values{}, false) // no need to be admin checkSuccess(t, response) if response.Response.InternetRadioStations == nil { @@ -372,11 +372,11 @@ func testInternetRadioDeletes(t *testing.T, contr *Controller) { } // delete station 2 and check that they're all gone - response = runTestCase(t, contr, contr.ServeDeleteInternetRadioStation, + response = runTestCase(t, contr.ServeDeleteInternetRadioStation, url.Values{"id": {station2ID}}, true) checkSuccess(t, response) - response = runTestCase(t, contr, contr.ServeGetInternetRadioStations, url.Values{}, false) // no need to be admin + response = runTestCase(t, contr.ServeGetInternetRadioStations, url.Values{}, false) // no need to be admin checkSuccess(t, response) if (response.Response.InternetRadioStations == nil) || (len(response.Response.InternetRadioStations.List) != 0) { diff --git a/server/ctrlsubsonic/handlers_playlist.go b/server/ctrlsubsonic/handlers_playlist.go index 1b240ce..d196b70 100644 --- a/server/ctrlsubsonic/handlers_playlist.go +++ b/server/ctrlsubsonic/handlers_playlist.go @@ -22,7 +22,7 @@ import ( func (c *Controller) ServeGetPlaylists(r *http.Request) *spec.Response { params := r.Context().Value(CtxParams).(params.Params) user := r.Context().Value(CtxUser).(*db.User) - paths, err := c.PlaylistStore.List() + paths, err := c.playlistStore.List() if err != nil { return spec.NewError(0, "error listing playlists: %v", err) } @@ -31,7 +31,7 @@ func (c *Controller) ServeGetPlaylists(r *http.Request) *spec.Response { List: []*spec.Playlist{}, } for _, path := range paths { - playlist, err := c.PlaylistStore.Read(path) + playlist, err := c.playlistStore.Read(path) if err != nil { return spec.NewError(0, "error reading playlist %q: %v", path, err) } @@ -54,7 +54,7 @@ func (c *Controller) ServeGetPlaylist(r *http.Request) *spec.Response { if err != nil { return spec.NewError(10, "please provide an `id` parameter") } - playlist, err := c.PlaylistStore.Read(playlistIDDecode(playlistID)) + playlist, err := c.playlistStore.Read(playlistIDDecode(playlistID)) if err != nil { return spec.NewError(70, "playlist with id %s not found", playlistID) } @@ -75,7 +75,7 @@ func (c *Controller) ServeCreatePlaylist(r *http.Request) *spec.Response { playlistPath := playlistIDDecode(playlistID) var playlist playlistp.Playlist - if pl, _ := c.PlaylistStore.Read(playlistPath); pl != nil { + if pl, _ := c.playlistStore.Read(playlistPath); pl != nil { playlist = *pl } @@ -94,7 +94,7 @@ func (c *Controller) ServeCreatePlaylist(r *http.Request) *spec.Response { playlist.Items = nil ids := params.GetOrIDList("songId", nil) for _, id := range ids { - r, err := specidpaths.Locate(c.DB, id) + r, err := specidpaths.Locate(c.dbc, id) if err != nil { return spec.NewError(0, "lookup id %v: %v", id, err) } @@ -104,7 +104,7 @@ func (c *Controller) ServeCreatePlaylist(r *http.Request) *spec.Response { if playlistPath == "" { playlistPath = playlistp.NewPath(user.ID, fmt.Sprint(time.Now().UnixMilli())) } - if err := c.PlaylistStore.Write(playlistPath, &playlist); err != nil { + if err := c.playlistStore.Write(playlistPath, &playlist); err != nil { return spec.NewError(0, "save playlist: %v", err) } @@ -123,7 +123,7 @@ func (c *Controller) ServeUpdatePlaylist(r *http.Request) *spec.Response { playlistID := params.GetFirstOr( /* default */ "", "id", "playlistId") playlistPath := playlistIDDecode(playlistID) - playlist, err := c.PlaylistStore.Read(playlistPath) + playlist, err := c.playlistStore.Read(playlistPath) if err != nil { return spec.NewError(0, "find playlist: %v", err) } @@ -154,7 +154,7 @@ func (c *Controller) ServeUpdatePlaylist(r *http.Request) *spec.Response { // add items if ids, err := params.GetIDList("songIdToAdd"); err == nil { for _, id := range ids { - item, err := specidpaths.Locate(c.DB, id) + item, err := specidpaths.Locate(c.dbc, id) if err != nil { return spec.NewError(0, "locate id %q: %v", id, err) } @@ -162,7 +162,7 @@ func (c *Controller) ServeUpdatePlaylist(r *http.Request) *spec.Response { } } - if err := c.PlaylistStore.Write(playlistPath, playlist); err != nil { + if err := c.playlistStore.Write(playlistPath, playlist); err != nil { return spec.NewError(0, "save playlist: %v", err) } return spec.NewResponse() @@ -171,7 +171,7 @@ func (c *Controller) ServeUpdatePlaylist(r *http.Request) *spec.Response { func (c *Controller) ServeDeletePlaylist(r *http.Request) *spec.Response { params := r.Context().Value(CtxParams).(params.Params) playlistID := params.GetFirstOr( /* default */ "", "id", "playlistId") - if err := c.PlaylistStore.Delete(playlistIDDecode(playlistID)); err != nil { + if err := c.playlistStore.Delete(playlistIDDecode(playlistID)); err != nil { return spec.NewError(0, "delete playlist: %v", err) } return spec.NewResponse() @@ -188,7 +188,7 @@ func playlistIDDecode(id string) string { func playlistRender(c *Controller, params params.Params, playlistID string, playlist *playlistp.Playlist, withItems bool) (*spec.Playlist, error) { user := &db.User{} - if err := c.DB.Where("id=?", playlist.UserID).Find(user).Error; err != nil { + if err := c.dbc.Where("id=?", playlist.UserID).Find(user).Error; err != nil { return nil, fmt.Errorf("find user by id: %w", err) } @@ -205,10 +205,10 @@ func playlistRender(c *Controller, params params.Params, playlistID string, play return resp, nil } - transcodeMeta := streamGetTranscodeMeta(c.DB, user.ID, params.GetOr("c", "")) + transcodeMeta := streamGetTranscodeMeta(c.dbc, user.ID, params.GetOr("c", "")) for _, path := range playlist.Items { - file, err := specidpaths.Lookup(c.DB, PathsOf(c.MusicPaths), c.PodcastsPath, path) + file, err := specidpaths.Lookup(c.dbc, MusicPaths(c.musicPaths), c.podcastsPath, path) if err != nil { log.Printf("error looking up path %q: %s", path, err) continue @@ -218,14 +218,14 @@ func playlistRender(c *Controller, params params.Params, playlistID string, play switch id := file.SID(); id.Type { case specid.Track: var track db.Track - if err := c.DB.Where("id=?", id.Value).Preload("Album").Preload("Album.Artists").Preload("TrackStar", "user_id=?", user.ID).Preload("TrackRating", "user_id=?", user.ID).Find(&track).Error; errors.Is(err, gorm.ErrRecordNotFound) { + if err := c.dbc.Where("id=?", id.Value).Preload("Album").Preload("Album.Artists").Preload("TrackStar", "user_id=?", user.ID).Preload("TrackRating", "user_id=?", user.ID).Find(&track).Error; errors.Is(err, gorm.ErrRecordNotFound) { return nil, fmt.Errorf("load track by id: %w", err) } trch = spec.NewTCTrackByFolder(&track, track.Album) resp.Duration += track.Length case specid.PodcastEpisode: var pe db.PodcastEpisode - if err := c.DB.Preload("Podcast").Where("id=?", id.Value).Find(&pe).Error; errors.Is(err, gorm.ErrRecordNotFound) { + if err := c.dbc.Preload("Podcast").Where("id=?", id.Value).Find(&pe).Error; errors.Is(err, gorm.ErrRecordNotFound) { return nil, fmt.Errorf("load podcast episode by id: %w", err) } trch = spec.NewTCPodcastEpisode(&pe) diff --git a/server/ctrlsubsonic/handlers_podcast.go b/server/ctrlsubsonic/handlers_podcast.go index 90f827a..3af4852 100644 --- a/server/ctrlsubsonic/handlers_podcast.go +++ b/server/ctrlsubsonic/handlers_podcast.go @@ -15,7 +15,7 @@ func (c *Controller) ServeGetPodcasts(r *http.Request) *spec.Response { params := r.Context().Value(CtxParams).(params.Params) isIncludeEpisodes := params.GetOrBool("includeEpisodes", true) id, _ := params.GetID("id") - podcasts, err := c.Podcasts.GetPodcastOrAll(id.Value, isIncludeEpisodes) + podcasts, err := c.podcasts.GetPodcastOrAll(id.Value, isIncludeEpisodes) if err != nil { return spec.NewError(10, "failed get podcast(s): %s", err) } @@ -31,7 +31,7 @@ func (c *Controller) ServeGetPodcasts(r *http.Request) *spec.Response { func (c *Controller) ServeGetNewestPodcasts(r *http.Request) *spec.Response { params := r.Context().Value(CtxParams).(params.Params) count := params.GetOrInt("count", 10) - episodes, err := c.Podcasts.GetNewestPodcastEpisodes(count) + episodes, err := c.podcasts.GetNewestPodcastEpisodes(count) if err != nil { return spec.NewError(10, "failed get podcast(s): %s", err) } @@ -53,7 +53,7 @@ func (c *Controller) ServeDownloadPodcastEpisode(r *http.Request) *spec.Response if err != nil || id.Type != specid.PodcastEpisode { return spec.NewError(10, "please provide a valid podcast episode id") } - if err := c.Podcasts.DownloadEpisode(id.Value); err != nil { + if err := c.podcasts.DownloadEpisode(id.Value); err != nil { return spec.NewError(10, "failed to download episode: %s", err) } return spec.NewResponse() @@ -71,7 +71,7 @@ func (c *Controller) ServeCreatePodcastChannel(r *http.Request) *spec.Response { if err != nil { return spec.NewError(10, "failed to parse feed: %s", err) } - if _, err = c.Podcasts.AddNewPodcast(rssURL, feed); err != nil { + if _, err = c.podcasts.AddNewPodcast(rssURL, feed); err != nil { return spec.NewError(10, "failed to add feed: %s", err) } return spec.NewResponse() @@ -82,7 +82,7 @@ func (c *Controller) ServeRefreshPodcasts(r *http.Request) *spec.Response { if !user.IsAdmin { return spec.NewError(50, "user not admin") } - if err := c.Podcasts.RefreshPodcasts(); err != nil { + if err := c.podcasts.RefreshPodcasts(); err != nil { return spec.NewError(10, "failed to refresh feeds: %s", err) } return spec.NewResponse() @@ -98,7 +98,7 @@ func (c *Controller) ServeDeletePodcastChannel(r *http.Request) *spec.Response { if err != nil || id.Type != specid.Podcast { return spec.NewError(10, "please provide a valid podcast id") } - if err := c.Podcasts.DeletePodcast(id.Value); err != nil { + if err := c.podcasts.DeletePodcast(id.Value); err != nil { return spec.NewError(10, "failed to delete podcast: %s", err) } return spec.NewResponse() @@ -114,7 +114,7 @@ func (c *Controller) ServeDeletePodcastEpisode(r *http.Request) *spec.Response { if err != nil || id.Type != specid.PodcastEpisode { return spec.NewError(10, "please provide a valid podcast episode id") } - if err := c.Podcasts.DeletePodcastEpisode(id.Value); err != nil { + if err := c.podcasts.DeletePodcastEpisode(id.Value); err != nil { return spec.NewError(10, "failed to delete podcast: %s", err) } return spec.NewResponse() diff --git a/server/ctrlsubsonic/handlers_raw.go b/server/ctrlsubsonic/handlers_raw.go index 650572b..7146404 100644 --- a/server/ctrlsubsonic/handlers_raw.go +++ b/server/ctrlsubsonic/handlers_raw.go @@ -171,13 +171,13 @@ func (c *Controller) ServeGetCoverArt(w http.ResponseWriter, r *http.Request) *s } size := params.GetOrInt("size", coverDefaultSize) cachePath := filepath.Join( - c.CacheCoverPath, + c.cacheCoverPath, fmt.Sprintf("%s-%d.%s", id.String(), size, coverCacheFormat), ) _, err = os.Stat(cachePath) switch { case os.IsNotExist(err): - reader, err := coverFor(c.DB, c.ArtistInfoCache, id) + reader, err := coverFor(c.dbc, c.artistInfoCache, id) if err != nil { return spec.NewError(10, "couldn't find cover `%s`: %v", id, err) } @@ -206,7 +206,7 @@ func (c *Controller) ServeStream(w http.ResponseWriter, r *http.Request) *spec.R return spec.NewError(10, "please provide an `id` parameter") } - file, err := specidpaths.Locate(c.DB, id) + file, err := specidpaths.Locate(c.dbc, id) if err != nil { return spec.NewError(0, "error looking up id %s: %v", id, err) } @@ -224,7 +224,7 @@ func (c *Controller) ServeStream(w http.ResponseWriter, r *http.Request) *spec.R return nil } - pref, err := streamGetTransodePreference(c.DB, user.ID, params.GetOr("c", "")) + pref, err := streamGetTransodePreference(c.dbc, user.ID, params.GetOr("c", "")) if err != nil && !errors.Is(err, gorm.ErrRecordNotFound) { return spec.NewError(0, "couldn't find transcode preference: %v", err) } @@ -244,7 +244,7 @@ func (c *Controller) ServeStream(w http.ResponseWriter, r *http.Request) *spec.R log.Printf("trancoding to %q with max bitrate %dk", profile.MIME(), profile.BitRate()) w.Header().Set("Content-Type", profile.MIME()) - if err := c.Transcoder.Transcode(r.Context(), profile, file.AbsPath(), w); err != nil && !errors.Is(err, transcode.ErrFFmpegKilled) { + if err := c.transcoder.Transcode(r.Context(), profile, file.AbsPath(), w); err != nil && !errors.Is(err, transcode.ErrFFmpegKilled) { return spec.NewError(0, "error transcoding: %v", err) } @@ -261,7 +261,7 @@ func (c *Controller) ServeGetAvatar(w http.ResponseWriter, r *http.Request) *spe if err != nil { return spec.NewError(10, "please provide an `username` parameter") } - reqUser := c.DB.GetUserByName(username) + reqUser := c.dbc.GetUserByName(username) if (user != reqUser) && !user.IsAdmin { return spec.NewError(50, "user not admin") } diff --git a/server/ctrlsubsonic/middleware.go b/server/ctrlsubsonic/middleware.go deleted file mode 100644 index 11ab285..0000000 --- a/server/ctrlsubsonic/middleware.go +++ /dev/null @@ -1,89 +0,0 @@ -package ctrlsubsonic - -import ( - "context" - "crypto/md5" - "encoding/hex" - "fmt" - "net/http" - - "go.senan.xyz/gonic/server/ctrlsubsonic/params" - "go.senan.xyz/gonic/server/ctrlsubsonic/spec" -) - -func checkCredsToken(password, token, salt string) bool { - toHash := fmt.Sprintf("%s%s", password, salt) - hash := md5.Sum([]byte(toHash)) - expToken := hex.EncodeToString(hash[:]) - return token == expToken -} - -func checkCredsBasic(password, given string) bool { - if len(given) >= 4 && given[:4] == "enc:" { - bytes, _ := hex.DecodeString(given[4:]) - given = string(bytes) - } - return password == given -} - -func (c *Controller) WithParams(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - params := params.New(r) - withParams := context.WithValue(r.Context(), CtxParams, params) - next.ServeHTTP(w, r.WithContext(withParams)) - }) -} - -func (c *Controller) WithRequiredParams(next http.Handler) http.Handler { - requiredParameters := []string{ - "u", "c", - } - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - params := r.Context().Value(CtxParams).(params.Params) - for _, req := range requiredParameters { - if _, err := params.Get(req); err != nil { - _ = writeResp(w, r, spec.NewError(10, - "please provide a `%s` parameter", req)) - return - } - } - next.ServeHTTP(w, r) - }) -} - -func (c *Controller) WithUser(next http.Handler) http.Handler { - return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { - params := r.Context().Value(CtxParams).(params.Params) - // ignoring errors here, a middleware has already ensured they exist - username, _ := params.Get("u") - password, _ := params.Get("p") - token, _ := params.Get("t") - salt, _ := params.Get("s") - - passwordAuth := token == "" && salt == "" - tokenAuth := password == "" - if tokenAuth == passwordAuth { - _ = writeResp(w, r, spec.NewError(10, - "please provide `t` and `s`, or just `p`")) - return - } - user := c.DB.GetUserByName(username) - if user == nil { - _ = writeResp(w, r, spec.NewError(40, - "invalid username `%s`", username)) - return - } - var credsOk bool - if tokenAuth { - credsOk = checkCredsToken(user.Password, token, salt) - } else { - credsOk = checkCredsBasic(user.Password, password) - } - if !credsOk { - _ = writeResp(w, r, spec.NewError(40, "invalid password")) - return - } - withUser := context.WithValue(r.Context(), CtxUser, user) - next.ServeHTTP(w, r.WithContext(withUser)) - }) -} diff --git a/server/ctrlsubsonic/routes.go b/server/ctrlsubsonic/routes.go deleted file mode 100644 index 91872b8..0000000 --- a/server/ctrlsubsonic/routes.go +++ /dev/null @@ -1,88 +0,0 @@ -package ctrlsubsonic - -import "github.com/gorilla/mux" - -func AddRoutes(c *Controller, r *mux.Router) { - r.Use(c.WithParams) - r.Use(c.WithRequiredParams) - r.Use(c.WithUser) - - // common - r.Handle("/getLicense{_:(?:\\.view)?}", c.H(c.ServeGetLicence)) - r.Handle("/ping{_:(?:\\.view)?}", c.H(c.ServePing)) - r.Handle("/getOpenSubsonicExtensions{_:(?:\\.view)?}", c.H(c.ServeGetOpenSubsonicExtensions)) - - r.Handle("/getMusicFolders{_:(?:\\.view)?}", c.H(c.ServeGetMusicFolders)) - r.Handle("/getScanStatus{_:(?:\\.view)?}", c.H(c.ServeGetScanStatus)) - r.Handle("/scrobble{_:(?:\\.view)?}", c.H(c.ServeScrobble)) - r.Handle("/startScan{_:(?:\\.view)?}", c.H(c.ServeStartScan)) - r.Handle("/getUser{_:(?:\\.view)?}", c.H(c.ServeGetUser)) - r.Handle("/getPlaylists{_:(?:\\.view)?}", c.H(c.ServeGetPlaylists)) - r.Handle("/getPlaylist{_:(?:\\.view)?}", c.H(c.ServeGetPlaylist)) - r.Handle("/createPlaylist{_:(?:\\.view)?}", c.H(c.ServeCreatePlaylist)) - r.Handle("/updatePlaylist{_:(?:\\.view)?}", c.H(c.ServeUpdatePlaylist)) - r.Handle("/deletePlaylist{_:(?:\\.view)?}", c.H(c.ServeDeletePlaylist)) - r.Handle("/savePlayQueue{_:(?:\\.view)?}", c.H(c.ServeSavePlayQueue)) - r.Handle("/getPlayQueue{_:(?:\\.view)?}", c.H(c.ServeGetPlayQueue)) - r.Handle("/getSong{_:(?:\\.view)?}", c.H(c.ServeGetSong)) - r.Handle("/getRandomSongs{_:(?:\\.view)?}", c.H(c.ServeGetRandomSongs)) - r.Handle("/getSongsByGenre{_:(?:\\.view)?}", c.H(c.ServeGetSongsByGenre)) - r.Handle("/jukeboxControl{_:(?:\\.view)?}", c.H(c.ServeJukebox)) - r.Handle("/getBookmarks{_:(?:\\.view)?}", c.H(c.ServeGetBookmarks)) - r.Handle("/createBookmark{_:(?:\\.view)?}", c.H(c.ServeCreateBookmark)) - r.Handle("/deleteBookmark{_:(?:\\.view)?}", c.H(c.ServeDeleteBookmark)) - r.Handle("/getTopSongs{_:(?:\\.view)?}", c.H(c.ServeGetTopSongs)) - r.Handle("/getSimilarSongs{_:(?:\\.view)?}", c.H(c.ServeGetSimilarSongs)) - r.Handle("/getSimilarSongs2{_:(?:\\.view)?}", c.H(c.ServeGetSimilarSongsTwo)) - r.Handle("/getLyrics{_:(?:\\.view)?}", c.H(c.ServeGetLyrics)) - - // raw - r.Handle("/getCoverArt{_:(?:\\.view)?}", c.HR(c.ServeGetCoverArt)) - r.Handle("/stream{_:(?:\\.view)?}", c.HR(c.ServeStream)) - r.Handle("/download{_:(?:\\.view)?}", c.HR(c.ServeStream)) - r.Handle("/getAvatar{_:(?:\\.view)?}", c.HR(c.ServeGetAvatar)) - - // browse by tag - r.Handle("/getAlbum{_:(?:\\.view)?}", c.H(c.ServeGetAlbum)) - r.Handle("/getAlbumList2{_:(?:\\.view)?}", c.H(c.ServeGetAlbumListTwo)) - r.Handle("/getArtist{_:(?:\\.view)?}", c.H(c.ServeGetArtist)) - r.Handle("/getArtists{_:(?:\\.view)?}", c.H(c.ServeGetArtists)) - r.Handle("/search3{_:(?:\\.view)?}", c.H(c.ServeSearchThree)) - r.Handle("/getArtistInfo2{_:(?:\\.view)?}", c.H(c.ServeGetArtistInfoTwo)) - r.Handle("/getStarred2{_:(?:\\.view)?}", c.H(c.ServeGetStarredTwo)) - - // browse by folder - r.Handle("/getIndexes{_:(?:\\.view)?}", c.H(c.ServeGetIndexes)) - r.Handle("/getMusicDirectory{_:(?:\\.view)?}", c.H(c.ServeGetMusicDirectory)) - r.Handle("/getAlbumList{_:(?:\\.view)?}", c.H(c.ServeGetAlbumList)) - r.Handle("/search2{_:(?:\\.view)?}", c.H(c.ServeSearchTwo)) - r.Handle("/getGenres{_:(?:\\.view)?}", c.H(c.ServeGetGenres)) - r.Handle("/getArtistInfo{_:(?:\\.view)?}", c.H(c.ServeGetArtistInfo)) - r.Handle("/getStarred{_:(?:\\.view)?}", c.H(c.ServeGetStarred)) - - // star / rating - r.Handle("/star{_:(?:\\.view)?}", c.H(c.ServeStar)) - r.Handle("/unstar{_:(?:\\.view)?}", c.H(c.ServeUnstar)) - r.Handle("/setRating{_:(?:\\.view)?}", c.H(c.ServeSetRating)) - - // podcasts - r.Handle("/getPodcasts{_:(?:\\.view)?}", c.H(c.ServeGetPodcasts)) - r.Handle("/getNewestPodcasts{_:(?:\\.view)?}", c.H(c.ServeGetNewestPodcasts)) - r.Handle("/downloadPodcastEpisode{_:(?:\\.view)?}", c.H(c.ServeDownloadPodcastEpisode)) - r.Handle("/createPodcastChannel{_:(?:\\.view)?}", c.H(c.ServeCreatePodcastChannel)) - r.Handle("/refreshPodcasts{_:(?:\\.view)?}", c.H(c.ServeRefreshPodcasts)) - r.Handle("/deletePodcastChannel{_:(?:\\.view)?}", c.H(c.ServeDeletePodcastChannel)) - r.Handle("/deletePodcastEpisode{_:(?:\\.view)?}", c.H(c.ServeDeletePodcastEpisode)) - - // internet radio - r.Handle("/getInternetRadioStations{_:(?:\\.view)?}", c.H(c.ServeGetInternetRadioStations)) - r.Handle("/createInternetRadioStation{_:(?:\\.view)?}", c.H(c.ServeCreateInternetRadioStation)) - r.Handle("/updateInternetRadioStation{_:(?:\\.view)?}", c.H(c.ServeUpdateInternetRadioStation)) - r.Handle("/deleteInternetRadioStation{_:(?:\\.view)?}", c.H(c.ServeDeleteInternetRadioStation)) - - // middlewares should be run for not found handler - // https://github.com/gorilla/mux/issues/416 - notFoundHandler := c.H(c.ServeNotFound) - notFoundRoute := r.NewRoute().Handler(notFoundHandler) - r.NotFoundHandler = notFoundRoute.GetHandler() -}