fix auth; update deps
This commit is contained in:
6
vendor/go.mau.fi/util/dbutil/database.go
vendored
6
vendor/go.mau.fi/util/dbutil/database.go
vendored
@@ -104,6 +104,8 @@ type Database struct {
|
||||
Dialect Dialect
|
||||
UpgradeTable UpgradeTable
|
||||
|
||||
txnCtxKey contextKey
|
||||
|
||||
IgnoreForeignTables bool
|
||||
IgnoreUnsupportedDatabase bool
|
||||
}
|
||||
@@ -132,6 +134,8 @@ func (db *Database) Child(versionTable string, upgradeTable UpgradeTable, log Da
|
||||
Log: log,
|
||||
Dialect: db.Dialect,
|
||||
|
||||
txnCtxKey: db.txnCtxKey,
|
||||
|
||||
IgnoreForeignTables: true,
|
||||
IgnoreUnsupportedDatabase: db.IgnoreUnsupportedDatabase,
|
||||
}
|
||||
@@ -149,6 +153,8 @@ func NewWithDB(db *sql.DB, rawDialect string) (*Database, error) {
|
||||
|
||||
IgnoreForeignTables: true,
|
||||
VersionTable: "version",
|
||||
|
||||
txnCtxKey: contextKey(nextContextKeyDatabaseTransaction.Add(1)),
|
||||
}
|
||||
wrappedDB.LoggingDB.UnderlyingExecable = db
|
||||
wrappedDB.LoggingDB.db = wrappedDB
|
||||
|
||||
3
vendor/go.mau.fi/util/dbutil/json.go
vendored
3
vendor/go.mau.fi/util/dbutil/json.go
vendored
@@ -4,6 +4,7 @@ import (
|
||||
"database/sql/driver"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"unsafe"
|
||||
)
|
||||
|
||||
// JSON is a utility type for using arbitrary JSON data as values in database Exec and Scan calls.
|
||||
@@ -29,7 +30,7 @@ func (j JSON) Value() (driver.Value, error) {
|
||||
return nil, nil
|
||||
}
|
||||
v, err := json.Marshal(j.Data)
|
||||
return string(v), err
|
||||
return unsafe.String(unsafe.SliceData(v), len(v)), err
|
||||
}
|
||||
|
||||
// JSONPtr is a convenience function for wrapping a pointer to a value in the JSON utility, but removing typed nils
|
||||
|
||||
@@ -2,8 +2,8 @@
|
||||
CREATE TABLE foo (
|
||||
-- only: postgres
|
||||
key BIGINT PRIMARY KEY GENERATED BY DEFAULT AS IDENTITY,
|
||||
-- only: sqlite
|
||||
key INTEGER PRIMARY KEY,
|
||||
-- only: sqlite (line commented)
|
||||
-- key INTEGER PRIMARY KEY,
|
||||
|
||||
data JSONB NOT NULL
|
||||
);
|
||||
|
||||
@@ -8,4 +8,3 @@ CREATE FUNCTION delete_data() RETURNS TRIGGER LANGUAGE plpgsql AS $$ BEGIN
|
||||
DELETE FROM test WHERE key <= NEW.data->>'index';
|
||||
RETURN NEW;
|
||||
END $$;
|
||||
-- end only postgres
|
||||
|
||||
@@ -7,4 +7,3 @@ CREATE TABLE foo (
|
||||
CREATE TRIGGER test AFTER INSERT ON foo WHEN NEW.data->>'action' = 'delete' BEGIN
|
||||
DELETE FROM test WHERE key <= NEW.data->>'index';
|
||||
END;
|
||||
-- end only sqlite
|
||||
|
||||
20
vendor/go.mau.fi/util/dbutil/transaction.go
vendored
20
vendor/go.mau.fi/util/dbutil/transaction.go
vendored
@@ -12,6 +12,7 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"runtime"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
@@ -26,13 +27,18 @@ var (
|
||||
ErrTxnCommit = fmt.Errorf("%w: commit", ErrTxn)
|
||||
)
|
||||
|
||||
type contextKey int
|
||||
type contextKey int64
|
||||
|
||||
const (
|
||||
ContextKeyDatabaseTransaction contextKey = iota
|
||||
ContextKeyDoTxnCallerSkip
|
||||
ContextKeyDoTxnCallerSkip contextKey = 1
|
||||
)
|
||||
|
||||
var nextContextKeyDatabaseTransaction atomic.Uint64
|
||||
|
||||
func init() {
|
||||
nextContextKeyDatabaseTransaction.Store(1 << 61)
|
||||
}
|
||||
|
||||
func (db *Database) Exec(ctx context.Context, query string, args ...any) (sql.Result, error) {
|
||||
return db.Conn(ctx).ExecContext(ctx, query, args...)
|
||||
}
|
||||
@@ -56,7 +62,7 @@ func (db *Database) DoTxn(ctx context.Context, opts *sql.TxOptions, fn func(ctx
|
||||
if ctx == nil {
|
||||
panic("DoTxn() called with nil ctx")
|
||||
}
|
||||
if ctx.Value(ContextKeyDatabaseTransaction) != nil {
|
||||
if ctx.Value(db.txnCtxKey) != nil {
|
||||
zerolog.Ctx(ctx).Trace().Msg("Already in a transaction, not creating a new one")
|
||||
return fn(ctx)
|
||||
}
|
||||
@@ -82,7 +88,7 @@ func (db *Database) DoTxn(ctx context.Context, opts *sql.TxOptions, fn func(ctx
|
||||
select {
|
||||
case <-ticker.C:
|
||||
slowLog.Warn().
|
||||
Dur("duration_seconds", time.Since(start)).
|
||||
Float64("duration_seconds", time.Since(start).Seconds()).
|
||||
Msg("Transaction still running")
|
||||
case <-deadlockCh:
|
||||
return
|
||||
@@ -106,7 +112,7 @@ func (db *Database) DoTxn(ctx context.Context, opts *sql.TxOptions, fn func(ctx
|
||||
log.Trace().Msg("Transaction started")
|
||||
tx.noTotalLog = true
|
||||
ctx = log.WithContext(ctx)
|
||||
ctx = context.WithValue(ctx, ContextKeyDatabaseTransaction, tx)
|
||||
ctx = context.WithValue(ctx, db.txnCtxKey, tx)
|
||||
err = fn(ctx)
|
||||
if err != nil {
|
||||
log.Trace().Err(err).Msg("Database transaction failed, rolling back")
|
||||
@@ -131,7 +137,7 @@ func (db *Database) Conn(ctx context.Context) Execable {
|
||||
if ctx == nil {
|
||||
panic("Conn() called with nil ctx")
|
||||
}
|
||||
txn, ok := ctx.Value(ContextKeyDatabaseTransaction).(Transaction)
|
||||
txn, ok := ctx.Value(db.txnCtxKey).(Transaction)
|
||||
if ok {
|
||||
return txn
|
||||
}
|
||||
|
||||
77
vendor/go.mau.fi/util/dbutil/upgradetable.go
vendored
77
vendor/go.mau.fi/util/dbutil/upgradetable.go
vendored
@@ -121,40 +121,38 @@ func parseFileHeader(file []byte) (from, to, compat int, message string, txn boo
|
||||
// -- only: sqlite for next 123 lines
|
||||
//
|
||||
// If the single-line limit is on the second line of the file, the whole file is limited to that dialect.
|
||||
var dialectLineFilter = regexp.MustCompile(`^\s*-- only: (postgres|sqlite)(?: for next (\d+) lines| until "(end) only")?`)
|
||||
//
|
||||
// If the filter ends with `(lines commented)`, then ALL lines chosen by the filter will be uncommented.
|
||||
var dialectLineFilter = regexp.MustCompile(`^\s*-- only: (postgres|sqlite)(?: for next (\d+) lines| until "(end) only")?(?: \(lines? (commented)\))?`)
|
||||
|
||||
// Constants used to make parseDialectFilter clearer
|
||||
const (
|
||||
skipUntilEndTag = -1
|
||||
skipNothing = 0
|
||||
skipCurrentLine = 1
|
||||
skipNextLine = 2
|
||||
skipNextLine = 1
|
||||
)
|
||||
|
||||
func (db *Database) parseDialectFilter(line []byte) (int, error) {
|
||||
func (db *Database) parseDialectFilter(line []byte) (dialect Dialect, lineCount int, uncomment bool, err error) {
|
||||
match := dialectLineFilter.FindSubmatch(line)
|
||||
if match == nil {
|
||||
return skipNothing, nil
|
||||
return
|
||||
}
|
||||
dialect, err := ParseDialect(string(match[1]))
|
||||
dialect, err = ParseDialect(string(match[1]))
|
||||
if err != nil {
|
||||
return skipNothing, err
|
||||
} else if dialect == db.Dialect {
|
||||
// Skip the dialect filter line
|
||||
return skipCurrentLine, nil
|
||||
} else if bytes.Equal(match[3], []byte("end")) {
|
||||
return skipUntilEndTag, nil
|
||||
} else if len(match[2]) == 0 {
|
||||
// Skip the dialect filter and the next line
|
||||
return skipNextLine, nil
|
||||
} else {
|
||||
// Parse number of lines to skip, add 1 for current line
|
||||
lineCount, err := strconv.Atoi(string(match[2]))
|
||||
if err != nil {
|
||||
return skipNothing, fmt.Errorf("invalid line count '%s': %w", match[2], err)
|
||||
}
|
||||
return skipCurrentLine + lineCount, nil
|
||||
return
|
||||
}
|
||||
uncomment = bytes.Equal(match[4], []byte("commented"))
|
||||
if bytes.Equal(match[3], []byte("end")) {
|
||||
lineCount = skipUntilEndTag
|
||||
} else if len(match[2]) == 0 {
|
||||
lineCount = skipNextLine
|
||||
} else {
|
||||
lineCount, err = strconv.Atoi(string(match[2]))
|
||||
if err != nil {
|
||||
err = fmt.Errorf("invalid line count %q: %w", match[2], err)
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
var endLineFilter = regexp.MustCompile(`^\s*-- end only (postgres|sqlite)$`)
|
||||
@@ -162,15 +160,16 @@ var endLineFilter = regexp.MustCompile(`^\s*-- end only (postgres|sqlite)$`)
|
||||
func (db *Database) filterSQLUpgrade(lines [][]byte) (string, error) {
|
||||
output := make([][]byte, 0, len(lines))
|
||||
for i := 0; i < len(lines); i++ {
|
||||
skipLines, err := db.parseDialectFilter(lines[i])
|
||||
dialect, lineCount, uncomment, err := db.parseDialectFilter(lines[i])
|
||||
if err != nil {
|
||||
return "", err
|
||||
} else if skipLines > 0 {
|
||||
// Current line is implicitly skipped, so reduce one here
|
||||
i += skipLines - 1
|
||||
} else if skipLines == skipUntilEndTag {
|
||||
} else if lineCount == skipNothing {
|
||||
output = append(output, lines[i])
|
||||
} else if lineCount == skipUntilEndTag {
|
||||
startedAt := i
|
||||
startedAtMatch := dialectLineFilter.FindSubmatch(lines[startedAt])
|
||||
// Skip filter start line
|
||||
i++
|
||||
for ; i < len(lines); i++ {
|
||||
if match := endLineFilter.FindSubmatch(lines[i]); match != nil {
|
||||
if !bytes.Equal(match[1], startedAtMatch[1]) {
|
||||
@@ -178,12 +177,32 @@ func (db *Database) filterSQLUpgrade(lines [][]byte) (string, error) {
|
||||
}
|
||||
break
|
||||
}
|
||||
if dialect == db.Dialect {
|
||||
if uncomment {
|
||||
output = append(output, bytes.TrimPrefix(lines[i], []byte("--")))
|
||||
} else {
|
||||
output = append(output, lines[i])
|
||||
}
|
||||
}
|
||||
}
|
||||
if i == len(lines) {
|
||||
return "", fmt.Errorf(`didn't get end tag matching start %q at line %d`, string(startedAtMatch[1]), startedAt)
|
||||
}
|
||||
} else if dialect != db.Dialect {
|
||||
i += lineCount
|
||||
} else {
|
||||
output = append(output, lines[i])
|
||||
// Skip current line, uncomment the specified number of lines
|
||||
i++
|
||||
targetI := i + lineCount
|
||||
for ; i < targetI; i++ {
|
||||
if uncomment {
|
||||
output = append(output, bytes.TrimPrefix(lines[i], []byte("--")))
|
||||
} else {
|
||||
output = append(output, lines[i])
|
||||
}
|
||||
}
|
||||
// Decrement counter to avoid skipping the next line
|
||||
i--
|
||||
}
|
||||
}
|
||||
return string(bytes.Join(output, []byte("\n"))), nil
|
||||
@@ -191,7 +210,7 @@ func (db *Database) filterSQLUpgrade(lines [][]byte) (string, error) {
|
||||
|
||||
func sqlUpgradeFunc(fileName string, lines [][]byte) upgradeFunc {
|
||||
return func(ctx context.Context, db *Database) error {
|
||||
if skip, err := db.parseDialectFilter(lines[0]); err == nil && skip == skipNextLine {
|
||||
if dialect, skip, _, err := db.parseDialectFilter(lines[0]); err == nil && skip == skipNextLine && dialect != db.Dialect {
|
||||
return nil
|
||||
} else if upgradeSQL, err := db.filterSQLUpgrade(lines); err != nil {
|
||||
panic(fmt.Errorf("failed to parse upgrade %s: %w", fileName, err))
|
||||
|
||||
Reference in New Issue
Block a user