diff --git a/db/db.go b/db/db.go index cb2b6e3..26c0f16 100644 --- a/db/db.go +++ b/db/db.go @@ -28,6 +28,34 @@ type DB struct { *gorm.DB } +func New(path string) (*DB, error) { + pathAndArgs := fmt.Sprintf("%s?%s", path, dbOptions.Encode()) + db, err := gorm.Open("sqlite3", pathAndArgs) + if err != nil { + return nil, errors.Wrap(err, "with gorm") + } + db.SetLogger(log.New(os.Stdout, "gorm ", 0)) + db.DB().SetMaxOpenConns(dbMaxOpenConns) + db.AutoMigrate( + model.Artist{}, + model.Track{}, + model.User{}, + model.Setting{}, + model.Play{}, + model.Album{}, + ) + db.FirstOrCreate(&model.User{}, model.User{ + Name: "admin", + Password: "admin", + IsAdmin: true, + }) + return &DB{DB: db}, nil +} + +func NewMock() (*DB, error) { + return New(":memory:") +} + func (db *DB) GetSetting(key string) string { setting := &model.Setting{} db. @@ -60,27 +88,3 @@ func (db *DB) WithTx(cb func(tx *gorm.DB)) { defer tx.Commit() cb(tx) } - -func New(path string) (*DB, error) { - pathAndArgs := fmt.Sprintf("%s?%s", path, dbOptions.Encode()) - db, err := gorm.Open("sqlite3", pathAndArgs) - if err != nil { - return nil, errors.Wrap(err, "with gorm") - } - db.SetLogger(log.New(os.Stdout, "gorm ", 0)) - db.DB().SetMaxOpenConns(dbMaxOpenConns) - db.AutoMigrate( - model.Artist{}, - model.Track{}, - model.User{}, - model.Setting{}, - model.Play{}, - model.Album{}, - ) - db.FirstOrCreate(&model.User{}, model.User{ - Name: "admin", - Password: "admin", - IsAdmin: true, - }) - return &DB{DB: db}, nil -} diff --git a/db/db_test.go b/db/db_test.go new file mode 100644 index 0000000..d2c4262 --- /dev/null +++ b/db/db_test.go @@ -0,0 +1,46 @@ +package db + +import ( + "log" + "math/rand" + "testing" + + _ "github.com/jinzhu/gorm/dialects/sqlite" +) + +var testDB *DB + +func init() { + var err error + testDB, err = NewMock() + if err != nil { + log.Fatalf("error opening database: %v\n", err) + } +} + +func randKey() string { + letters := []rune("abcdef0123456789") + b := make([]rune, 16) + for i := range b { + b[i] = letters[rand.Intn(len(letters))] + } + return string(b) +} + +func TestGetSetting(t *testing.T) { + key := randKey() + // new key + expected := "hello" + testDB.SetSetting(key, expected) + actual := testDB.GetSetting(key) + if actual != expected { + t.Errorf("expected %q, got %q", expected, actual) + } + // existing key + expected = "howdy" + testDB.SetSetting(key, expected) + actual = testDB.GetSetting(key) + if actual != expected { + t.Errorf("expected %q, got %q", expected, actual) + } +}