refactor(scanner): make the watcher a little easier to reason about (#394)

* add a test for symlinks escaping defined music paths

* refactor(scanner): make the watcher a little easier to reason about
This commit is contained in:
Senan Kelly
2023-10-24 23:57:42 +01:00
committed by GitHub
parent cae37255d6
commit c947404923
3 changed files with 76 additions and 72 deletions

View File

@@ -300,12 +300,13 @@ func main() {
defer logJob("scan watcher")() defer logJob("scan watcher")()
done := make(chan struct{})
errgrp.Go(func() error { errgrp.Go(func() error {
<-ctx.Done() <-ctx.Done()
scannr.CancelWatch() done <- struct{}{}
return nil return nil
}) })
return scannr.ExecuteWatch() return scannr.ExecuteWatch(done)
}) })
errgrp.Go(func() error { errgrp.Go(func() error {

View File

@@ -34,9 +34,6 @@ type Scanner struct {
tagReader tagcommon.Reader tagReader tagcommon.Reader
excludePattern *regexp.Regexp excludePattern *regexp.Regexp
scanning *int32 scanning *int32
watcher *fsnotify.Watcher
watchMap map[string]string // maps watched dirs back to root music dir
watchDone chan bool
} }
func New(musicDirs []string, db *db.DB, multiValueSettings map[Tag]MultiValueSetting, tagReader tagcommon.Reader, excludePattern string) *Scanner { func New(musicDirs []string, db *db.DB, multiValueSettings map[Tag]MultiValueSetting, tagReader tagcommon.Reader, excludePattern string) *Scanner {
@@ -52,22 +49,12 @@ func New(musicDirs []string, db *db.DB, multiValueSettings map[Tag]MultiValueSet
tagReader: tagReader, tagReader: tagReader,
excludePattern: excludePatternRegExp, excludePattern: excludePatternRegExp,
scanning: new(int32), scanning: new(int32),
watchMap: make(map[string]string),
watchDone: make(chan bool),
} }
} }
func (s *Scanner) IsScanning() bool { func (s *Scanner) IsScanning() bool { return atomic.LoadInt32(s.scanning) == 1 }
return atomic.LoadInt32(s.scanning) == 1 func (s *Scanner) StartScanning() bool { return atomic.CompareAndSwapInt32(s.scanning, 0, 1) }
} func (s *Scanner) StopScanning() { defer atomic.StoreInt32(s.scanning, 0) }
func (s *Scanner) StartScanning() bool {
return atomic.CompareAndSwapInt32(s.scanning, 0, 1)
}
func (s *Scanner) StopScanning() {
defer atomic.StoreInt32(s.scanning, 0)
}
type ScanOptions struct { type ScanOptions struct {
IsFull bool IsFull bool
@@ -94,7 +81,7 @@ func (s *Scanner) ScanAndClean(opts ScanOptions) (*Context, error) {
for _, dir := range s.musicDirs { for _, dir := range s.musicDirs {
err := filepath.WalkDir(dir, func(absPath string, d fs.DirEntry, err error) error { err := filepath.WalkDir(dir, func(absPath string, d fs.DirEntry, err error) error {
return s.scanCallback(c, dir, absPath, d, err) return s.scanCallback(c, absPath, d, err)
}) })
if err != nil { if err != nil {
return nil, fmt.Errorf("walk: %w", err) return nil, fmt.Errorf("walk: %w", err)
@@ -121,93 +108,82 @@ func (s *Scanner) ScanAndClean(opts ScanOptions) (*Context, error) {
return c, errors.Join(c.errs...) return c, errors.Join(c.errs...)
} }
func (s *Scanner) ExecuteWatch() error { func (s *Scanner) ExecuteWatch(done <-chan struct{}) error {
var err error watcher, err := fsnotify.NewWatcher()
s.watcher, err = fsnotify.NewWatcher()
if err != nil { if err != nil {
log.Printf("error creating watcher: %v\n", err) return fmt.Errorf("creating watcher: %w", err)
return err
} }
defer s.watcher.Close() defer watcher.Close()
t := time.NewTimer(10 * time.Second) const batchInterval = 10 * time.Second
if !t.Stop() { batchT := time.NewTimer(batchInterval)
<-t.C batchT.Stop()
}
for _, dir := range s.musicDirs { for _, dir := range s.musicDirs {
err := filepath.WalkDir(dir, func(absPath string, d fs.DirEntry, err error) error { err := filepath.WalkDir(dir, func(absPath string, d fs.DirEntry, err error) error {
return s.watchCallback(dir, absPath, d, err) return watchCallback(watcher, absPath, d, err)
}) })
if err != nil { if err != nil {
log.Printf("error watching directory tree: %v\n", err) log.Printf("error watching directory tree: %v\n", err)
continue
} }
} }
scanList := map[string]struct{}{} batchSeen := map[string]struct{}{}
for { for {
select { select {
case <-t.C: case <-batchT.C:
if !s.StartScanning() { if !s.StartScanning() {
scanList = map[string]struct{}{}
break break
} }
for dirName := range scanList { for absPath := range batchSeen {
c := &Context{ c := &Context{
seenTracks: map[int]struct{}{}, seenTracks: map[int]struct{}{},
seenAlbums: map[int]struct{}{}, seenAlbums: map[int]struct{}{},
isFull: false,
} }
musicDirName := s.watchMap[dirName] err = filepath.WalkDir(absPath, func(absPath string, d fs.DirEntry, err error) error {
if musicDirName == "" { return watchCallback(watcher, absPath, d, err)
musicDirName = s.watchMap[filepath.Dir(dirName)]
}
err = filepath.WalkDir(dirName, func(absPath string, d fs.DirEntry, err error) error {
return s.watchCallback(musicDirName, absPath, d, err)
}) })
if err != nil { if err != nil {
log.Printf("error watching directory tree: %v\n", err) log.Printf("error watching directory tree: %v\n", err)
continue
} }
err = filepath.WalkDir(dirName, func(absPath string, d fs.DirEntry, err error) error { err = filepath.WalkDir(absPath, func(absPath string, d fs.DirEntry, err error) error {
return s.scanCallback(c, musicDirName, absPath, d, err) return s.scanCallback(c, absPath, d, err)
}) })
if err != nil { if err != nil {
log.Printf("error walking: %v", err) log.Printf("error walking: %v", err)
continue
} }
} }
scanList = map[string]struct{}{}
s.StopScanning() s.StopScanning()
case event := <-s.watcher.Events: clear(batchSeen)
var dirName string
case event := <-watcher.Events:
if event.Op&(fsnotify.Create|fsnotify.Write) == 0 { if event.Op&(fsnotify.Create|fsnotify.Write) == 0 {
break break
} }
if len(scanList) == 0 {
t.Reset(10 * time.Second)
}
fileInfo, err := os.Stat(event.Name) fileInfo, err := os.Stat(event.Name)
if err != nil { if err != nil {
break break
} }
if fileInfo.IsDir() { if fileInfo.IsDir() {
dirName = event.Name batchSeen[event.Name] = struct{}{}
} else { } else {
dirName = filepath.Dir(event.Name) batchSeen[filepath.Dir(event.Name)] = struct{}{}
} }
scanList[dirName] = struct{}{} batchT.Reset(batchInterval)
case err = <-s.watcher.Errors:
case err = <-watcher.Errors:
log.Printf("error from watcher: %v\n", err) log.Printf("error from watcher: %v\n", err)
case <-s.watchDone:
case <-done:
return nil return nil
} }
} }
} }
func (s *Scanner) CancelWatch() { func watchCallback(watcher *fsnotify.Watcher, absPath string, d fs.DirEntry, err error) error {
s.watchDone <- true
}
func (s *Scanner) watchCallback(dir string, absPath string, d fs.DirEntry, err error) error {
if err != nil { if err != nil {
return err return err
} }
@@ -218,25 +194,21 @@ func (s *Scanner) watchCallback(dir string, absPath string, d fs.DirEntry, err e
eval, _ := filepath.EvalSymlinks(absPath) eval, _ := filepath.EvalSymlinks(absPath)
return filepath.WalkDir(eval, func(subAbs string, d fs.DirEntry, err error) error { return filepath.WalkDir(eval, func(subAbs string, d fs.DirEntry, err error) error {
subAbs = strings.Replace(subAbs, eval, absPath, 1) subAbs = strings.Replace(subAbs, eval, absPath, 1)
return s.watchCallback(dir, subAbs, d, err) return watchCallback(watcher, subAbs, d, err)
}) })
default: default:
return nil return nil
} }
if s.watchMap[absPath] == "" { if err := watcher.Add(absPath); err != nil {
s.watchMap[absPath] = dir return fmt.Errorf("add path to watcher: %w", err)
err = s.watcher.Add(absPath)
} }
return err
}
func (s *Scanner) scanCallback(c *Context, dir string, absPath string, d fs.DirEntry, err error) error {
if err != nil {
c.errs = append(c.errs, err)
return nil return nil
} }
if dir == absPath {
func (s *Scanner) scanCallback(c *Context, absPath string, d fs.DirEntry, err error) error {
if err != nil {
c.errs = append(c.errs, err)
return nil return nil
} }
@@ -246,7 +218,7 @@ func (s *Scanner) scanCallback(c *Context, dir string, absPath string, d fs.DirE
eval, _ := filepath.EvalSymlinks(absPath) eval, _ := filepath.EvalSymlinks(absPath)
return filepath.WalkDir(eval, func(subAbs string, d fs.DirEntry, err error) error { return filepath.WalkDir(eval, func(subAbs string, d fs.DirEntry, err error) error {
subAbs = strings.Replace(subAbs, eval, absPath, 1) subAbs = strings.Replace(subAbs, eval, absPath, 1)
return s.scanCallback(c, dir, subAbs, d, err) return s.scanCallback(c, subAbs, d, err)
}) })
default: default:
return nil return nil
@@ -260,7 +232,7 @@ func (s *Scanner) scanCallback(c *Context, dir string, absPath string, d fs.DirE
log.Printf("processing folder %q", absPath) log.Printf("processing folder %q", absPath)
tx := s.db.Begin() tx := s.db.Begin()
if err := s.scanDir(tx, c, dir, absPath); err != nil { if err := s.scanDir(tx, c, absPath); err != nil {
c.errs = append(c.errs, fmt.Errorf("%q: %w", absPath, err)) c.errs = append(c.errs, fmt.Errorf("%q: %w", absPath, err))
tx.Rollback() tx.Rollback()
return nil return nil
@@ -272,7 +244,12 @@ func (s *Scanner) scanCallback(c *Context, dir string, absPath string, d fs.DirE
return nil return nil
} }
func (s *Scanner) scanDir(tx *db.DB, c *Context, musicDir string, absPath string) error { func (s *Scanner) scanDir(tx *db.DB, c *Context, absPath string) error {
musicDir, relPath := musicDirRelative(s.musicDirs, absPath)
if musicDir == absPath {
return nil
}
items, err := os.ReadDir(absPath) items, err := os.ReadDir(absPath)
if err != nil { if err != nil {
return err return err
@@ -300,7 +277,6 @@ func (s *Scanner) scanDir(tx *db.DB, c *Context, musicDir string, absPath string
} }
} }
relPath, _ := filepath.Rel(musicDir, absPath)
pdir, pbasename := filepath.Split(filepath.Dir(relPath)) pdir, pbasename := filepath.Split(filepath.Dir(relPath))
var parent db.Album var parent db.Album
if err := tx.Where("root_dir=? AND left_path=? AND right_path=?", musicDir, pdir, pbasename).Assign(db.Album{RootDir: musicDir, LeftPath: pdir, RightPath: pbasename}).FirstOrCreate(&parent).Error; err != nil { if err := tx.Where("root_dir=? AND left_path=? AND right_path=?", musicDir, pdir, pbasename).Assign(db.Album{RootDir: musicDir, LeftPath: pdir, RightPath: pbasename}).FirstOrCreate(&parent).Error; err != nil {
@@ -701,3 +677,13 @@ func parseMulti(parser tagcommon.Info, setting MultiValueSetting, getMulti func(
} }
return parts return parts
} }
func musicDirRelative(musicDirs []string, absPath string) (musicDir, relPath string) {
for _, musicDir := range musicDirs {
if strings.HasPrefix(absPath, musicDir) {
relPath, _ = filepath.Rel(musicDir, absPath)
return musicDir, relPath
}
}
return
}

View File

@@ -487,6 +487,23 @@ func TestSymlinkedSubdiscs(t *testing.T) {
assert.NotZero(t, info.ModTime()) // track resolves assert.NotZero(t, info.ModTime()) // track resolves
} }
func TestSymlinkEscapesMusicDirs(t *testing.T) {
t.Parallel()
m := mockfs.NewWithDirs(t, []string{"scandir"})
require.NoError(t, os.MkdirAll(filepath.Join(m.TmpDir(), "otherdir", "artist", "album-test"), os.ModePerm))
require.NoError(t, os.Symlink(
filepath.Join(m.TmpDir(), "otherdir", "artist"),
filepath.Join(m.TmpDir(), "scandir", "artist"),
))
m.ScanAndClean()
var albums []*db.Album
require.NoError(t, m.DB().Find(&albums).Error)
require.Len(t, albums, 3)
}
func TestTagErrors(t *testing.T) { func TestTagErrors(t *testing.T) {
t.Parallel() t.Parallel()
m := mockfs.New(t) m := mockfs.New(t)