diff --git a/db/db.go b/db/db.go index 9576e65..c2afe2e 100644 --- a/db/db.go +++ b/db/db.go @@ -24,7 +24,38 @@ var ( } ) -func New(path string) (*gorm.DB, error) { +type DB struct { + *gorm.DB +} + +func (db *DB) GetSetting(key string) string { + setting := &model.Setting{} + db. + Where("key = ?", key). + First(setting) + return setting.Value +} + +func (db *DB) SetSetting(key, value string) { + db. + Where(model.Setting{Key: key}). + Assign(model.Setting{Value: value}). + FirstOrCreate(&model.Setting{}) +} + +func (db *DB) GetUserFromName(name string) *model.User { + user := &model.User{} + err := db. + Where("name = ?", name). + First(user). + Error + if gorm.IsRecordNotFoundError(err) { + return nil + } + return user +} + +func New(path string) (*DB, error) { pathAndArgs := fmt.Sprintf("%s?%s", path, dbOptions.Encode()) db, err := gorm.Open("sqlite3", pathAndArgs) if err != nil { @@ -45,5 +76,5 @@ func New(path string) (*gorm.DB, error) { Password: "admin", IsAdmin: true, }) - return db, nil + return &DB{DB: db}, nil } diff --git a/scanner/scanner.go b/scanner/scanner.go index 027a577..5478185 100644 --- a/scanner/scanner.go +++ b/scanner/scanner.go @@ -14,6 +14,7 @@ import ( "github.com/pkg/errors" "github.com/rainycape/unidecode" + "github.com/sentriz/gonic/db" "github.com/sentriz/gonic/mime" "github.com/sentriz/gonic/model" "github.com/sentriz/gonic/scanner/stack" @@ -51,14 +52,14 @@ func decoded(in string) string { return result } -func withTx(db *gorm.DB, cb func(tx *gorm.DB)) { +func withTx(db *db.DB, cb func(tx *gorm.DB)) { tx := db.Begin() defer tx.Commit() cb(tx) } type Scanner struct { - db *gorm.DB + db *db.DB musicPath string // these two are for the transaction we do for every folder. // the boolean is there so we dont begin or commit multiple @@ -78,7 +79,7 @@ type Scanner struct { seenTracksErr int // n tracks we we couldn't scan } -func New(db *gorm.DB, musicPath string) *Scanner { +func New(db *db.DB, musicPath string) *Scanner { return &Scanner{ db: db, musicPath: musicPath, diff --git a/scanner/scanner_test.go b/scanner/scanner_test.go index be9d653..dbea10d 100644 --- a/scanner/scanner_test.go +++ b/scanner/scanner_test.go @@ -5,7 +5,6 @@ import ( "log" "testing" - "github.com/jinzhu/gorm" _ "github.com/jinzhu/gorm/dialects/sqlite" "github.com/sentriz/gonic/db" @@ -24,7 +23,7 @@ func init() { log.SetOutput(ioutil.Discard) } -func resetTables(db *gorm.DB) { +func resetTables(db *db.DB) { tx := db.Begin() defer tx.Commit() tx.Exec("delete from tracks") @@ -32,7 +31,7 @@ func resetTables(db *gorm.DB) { tx.Exec("delete from albums") } -func resetTablesPause(db *gorm.DB, b *testing.B) { +func resetTablesPause(db *db.DB, b *testing.B) { b.StopTimer() defer b.StartTimer() resetTables(db) diff --git a/server/handler/handler.go b/server/handler/handler.go index 3877293..27e5bc2 100644 --- a/server/handler/handler.go +++ b/server/handler/handler.go @@ -3,10 +3,9 @@ package handler import ( "html/template" - "github.com/jinzhu/gorm" "github.com/wader/gormstore" - "github.com/sentriz/gonic/model" + "github.com/sentriz/gonic/db" ) type contextKey int @@ -17,35 +16,8 @@ const ( ) type Controller struct { - DB *gorm.DB + DB *db.DB SessDB *gormstore.Store Templates map[string]*template.Template MusicPath string } - -func (c *Controller) GetSetting(key string) string { - setting := &model.Setting{} - c.DB. - Where("key = ?", key). - First(setting) - return setting.Value -} - -func (c *Controller) SetSetting(key, value string) { - c.DB. - Where(model.Setting{Key: key}). - Assign(model.Setting{Value: value}). - FirstOrCreate(&model.Setting{}) -} - -func (c *Controller) GetUserFromName(name string) *model.User { - user := &model.User{} - err := c.DB. - Where("name = ?", name). - First(user). - Error - if gorm.IsRecordNotFoundError(err) { - return nil - } - return user -} diff --git a/server/handler/handler_admin.go b/server/handler/handler_admin.go index 8ded909..09c760c 100644 --- a/server/handler/handler_admin.go +++ b/server/handler/handler_admin.go @@ -27,7 +27,7 @@ func (c *Controller) ServeLoginDo(w http.ResponseWriter, r *http.Request) { http.Redirect(w, r, r.Header.Get("Referer"), http.StatusSeeOther) return } - user := c.GetUserFromName(username) + user := c.DB.GetUserFromName(username) if user == nil || password != user.Password { session.AddFlash("invalid username / password") sessionLogSave(w, r, session) @@ -60,7 +60,7 @@ func (c *Controller) ServeHome(w http.ResponseWriter, r *http.Request) { Order("updated_at DESC"). Limit(8). Find(&data.RecentFolders) - data.CurrentLastFMAPIKey = c.GetSetting("lastfm_api_key") + data.CurrentLastFMAPIKey = c.DB.GetSetting("lastfm_api_key") scheme := firstExisting( "http", // fallback r.Header.Get("X-Forwarded-Proto"), @@ -104,8 +104,8 @@ func (c *Controller) ServeLinkLastFMDo(w http.ResponseWriter, r *http.Request) { return } sessionKey, err := lastfm.GetSession( - c.GetSetting("lastfm_api_key"), - c.GetSetting("lastfm_secret"), + c.DB.GetSetting("lastfm_api_key"), + c.DB.GetSetting("lastfm_secret"), token, ) session := r.Context().Value(contextSessionKey).(*sessions.Session) @@ -240,8 +240,8 @@ func (c *Controller) ServeCreateUserDo(w http.ResponseWriter, r *http.Request) { func (c *Controller) ServeUpdateLastFMAPIKey(w http.ResponseWriter, r *http.Request) { data := &templateData{} - data.CurrentLastFMAPIKey = c.GetSetting("lastfm_api_key") - data.CurrentLastFMAPISecret = c.GetSetting("lastfm_secret") + data.CurrentLastFMAPIKey = c.DB.GetSetting("lastfm_api_key") + data.CurrentLastFMAPISecret = c.DB.GetSetting("lastfm_secret") renderTemplate(w, r, c.Templates["update_lastfm_api_key.tmpl"], data) } @@ -256,8 +256,8 @@ func (c *Controller) ServeUpdateLastFMAPIKeyDo(w http.ResponseWriter, r *http.Re http.Redirect(w, r, r.Header.Get("Referer"), http.StatusSeeOther) return } - c.SetSetting("lastfm_api_key", apiKey) - c.SetSetting("lastfm_secret", secret) + c.DB.SetSetting("lastfm_api_key", apiKey) + c.DB.SetSetting("lastfm_secret", secret) http.Redirect(w, r, "/admin/home", http.StatusSeeOther) } diff --git a/server/handler/handler_sub_by_folder.go b/server/handler/handler_sub_by_folder.go index cdf0823..23fd0c6 100644 --- a/server/handler/handler_sub_by_folder.go +++ b/server/handler/handler_sub_by_folder.go @@ -107,7 +107,7 @@ func (c *Controller) GetAlbumList(w http.ResponseWriter, r *http.Request) { respondError(w, r, 10, "please provide a `type` parameter") return } - q := c.DB + q := c.DB.DB switch listType { case "alphabeticalByArtist": q = q.Joins(` diff --git a/server/handler/handler_sub_by_tags.go b/server/handler/handler_sub_by_tags.go index 79bfc2d..c21040e 100644 --- a/server/handler/handler_sub_by_tags.go +++ b/server/handler/handler_sub_by_tags.go @@ -103,7 +103,7 @@ func (c *Controller) GetAlbumListTwo(w http.ResponseWriter, r *http.Request) { respondError(w, r, 10, "please provide a `type` parameter") return } - q := c.DB + q := c.DB.DB switch listType { case "alphabeticalByArtist": q = q.Joins(` diff --git a/server/handler/handler_sub_common.go b/server/handler/handler_sub_common.go index 9868f6e..2812d84 100644 --- a/server/handler/handler_sub_common.go +++ b/server/handler/handler_sub_common.go @@ -129,8 +129,8 @@ func (c *Controller) Scrobble(w http.ResponseWriter, r *http.Request) { First(track, id) // scrobble with above info err = lastfm.Scrobble( - c.GetSetting("lastfm_api_key"), - c.GetSetting("lastfm_secret"), + c.DB.GetSetting("lastfm_api_key"), + c.DB.GetSetting("lastfm_secret"), user.LastFMSession, track, // clients will provide time in miliseconds, so use that or diff --git a/server/handler/middleware_admin.go b/server/handler/middleware_admin.go index 8558edb..964f3b6 100644 --- a/server/handler/middleware_admin.go +++ b/server/handler/middleware_admin.go @@ -31,7 +31,7 @@ func (c *Controller) WithUserSession(next http.HandlerFunc) http.HandlerFunc { return } // take username from sesion and add the user row to the context - user := c.GetUserFromName(username) + user := c.DB.GetUserFromName(username) if user == nil { // the username in the client's session no longer relates to a // user in the database (maybe the user was deleted) diff --git a/server/handler/middleware_sub.go b/server/handler/middleware_sub.go index 7b49e58..c11a145 100644 --- a/server/handler/middleware_sub.go +++ b/server/handler/middleware_sub.go @@ -58,7 +58,7 @@ func (c *Controller) WithValidSubsonicArgs(next http.HandlerFunc) http.HandlerFu "please provide parameters `t` and `s`, or just `p`") return } - user := c.GetUserFromName(username) + user := c.DB.GetUserFromName(username) if user == nil { respondError(w, r, 40, "invalid username `%s`", username) return diff --git a/server/handler/respond_admin.go b/server/handler/respond_admin.go index d109c1f..3a39dbb 100644 --- a/server/handler/respond_admin.go +++ b/server/handler/respond_admin.go @@ -11,31 +11,35 @@ import ( ) type templateData struct { - AlbumCount int - AllUsers []*model.User - ArtistCount int + // common + Flashes []interface{} + User *model.User + // home + AlbumCount int + ArtistCount int + TrackCount int + RequestRoot string + RecentFolders []*model.Album + AllUsers []*model.User + // CurrentLastFMAPIKey string CurrentLastFMAPISecret string - Flashes []interface{} - RecentFolders []*model.Album - RequestRoot string SelectedUser *model.User - TrackCount int - User *model.User } -func renderTemplate(w http.ResponseWriter, r *http.Request, - tmpl *template.Template, data *templateData) { - session := r.Context().Value(contextSessionKey).(*sessions.Session) +func renderTemplate( + w http.ResponseWriter, + r *http.Request, + tmpl *template.Template, + data *templateData, +) { if data == nil { data = &templateData{} } + session := r.Context().Value(contextSessionKey).(*sessions.Session) data.Flashes = session.Flashes() sessionLogSave(w, r, session) - user, ok := r.Context().Value(contextUserKey).(*model.User) - if ok { - data.User = user - } + data.User = r.Context().Value(contextUserKey).(*model.User) err := tmpl.Execute(w, data) if err != nil { log.Printf("error executing template: %v\n", err) diff --git a/server/server.go b/server/server.go index 29afe06..b914c31 100644 --- a/server/server.go +++ b/server/server.go @@ -5,8 +5,7 @@ import ( "net/http" "time" - "github.com/jinzhu/gorm" - + "github.com/sentriz/gonic/db" "github.com/sentriz/gonic/server/handler" ) @@ -18,7 +17,7 @@ type Server struct { } func New( - db *gorm.DB, + db *db.DB, musicPath string, listenAddr string, assetPath string, diff --git a/server/server_admin.go b/server/server_admin.go index 3b8cbc8..24ea248 100644 --- a/server/server_admin.go +++ b/server/server_admin.go @@ -80,12 +80,12 @@ func staticHandler(assets *Assets, path string) http.HandlerFunc { } func (s *Server) SetupAdmin() error { - sessionKey := []byte(s.GetSetting("session_key")) + sessionKey := []byte(s.DB.GetSetting("session_key")) if len(sessionKey) == 0 { sessionKey = securecookie.GenerateRandomKey(32) - s.SetSetting("session_key", string(sessionKey)) + s.DB.SetSetting("session_key", string(sessionKey)) } - s.SessDB = gormstore.New(s.DB, sessionKey) + s.SessDB = gormstore.New(s.DB.DB, sessionKey) go s.SessDB.PeriodicCleanup(time.Hour, nil) // tmplBase := template.