update deps; experiment: log security
This commit is contained in:
90
vendor/maunium.net/go/mautrix/util/dbutil/connlog.go
generated
vendored
90
vendor/maunium.net/go/mautrix/util/dbutil/connlog.go
generated
vendored
@@ -15,7 +15,7 @@ import (
|
||||
// LoggingExecable is a wrapper for anything with database Exec methods (i.e. sql.Conn, sql.DB and sql.Tx)
|
||||
// that can preprocess queries (e.g. replacing $ with ? on SQLite) and log query durations.
|
||||
type LoggingExecable struct {
|
||||
UnderlyingExecable Execable
|
||||
UnderlyingExecable UnderlyingExecable
|
||||
db *Database
|
||||
}
|
||||
|
||||
@@ -23,23 +23,30 @@ func (le *LoggingExecable) ExecContext(ctx context.Context, query string, args .
|
||||
start := time.Now()
|
||||
query = le.db.mutateQuery(query)
|
||||
res, err := le.UnderlyingExecable.ExecContext(ctx, query, args...)
|
||||
le.db.Log.QueryTiming(ctx, "Exec", query, args, time.Since(start))
|
||||
le.db.Log.QueryTiming(ctx, "Exec", query, args, -1, time.Since(start))
|
||||
return res, err
|
||||
}
|
||||
|
||||
func (le *LoggingExecable) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) {
|
||||
func (le *LoggingExecable) QueryContext(ctx context.Context, query string, args ...interface{}) (Rows, error) {
|
||||
start := time.Now()
|
||||
query = le.db.mutateQuery(query)
|
||||
rows, err := le.UnderlyingExecable.QueryContext(ctx, query, args...)
|
||||
le.db.Log.QueryTiming(ctx, "Query", query, args, time.Since(start))
|
||||
return rows, err
|
||||
le.db.Log.QueryTiming(ctx, "Query", query, args, -1, time.Since(start))
|
||||
return &LoggingRows{
|
||||
ctx: ctx,
|
||||
db: le.db,
|
||||
query: query,
|
||||
args: args,
|
||||
rs: rows,
|
||||
start: start,
|
||||
}, err
|
||||
}
|
||||
|
||||
func (le *LoggingExecable) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row {
|
||||
start := time.Now()
|
||||
query = le.db.mutateQuery(query)
|
||||
row := le.UnderlyingExecable.QueryRowContext(ctx, query, args...)
|
||||
le.db.Log.QueryTiming(ctx, "QueryRow", query, args, time.Since(start))
|
||||
le.db.Log.QueryTiming(ctx, "QueryRow", query, args, -1, time.Since(start))
|
||||
return row
|
||||
}
|
||||
|
||||
@@ -47,7 +54,7 @@ func (le *LoggingExecable) Exec(query string, args ...interface{}) (sql.Result,
|
||||
return le.ExecContext(context.Background(), query, args...)
|
||||
}
|
||||
|
||||
func (le *LoggingExecable) Query(query string, args ...interface{}) (*sql.Rows, error) {
|
||||
func (le *LoggingExecable) Query(query string, args ...interface{}) (Rows, error) {
|
||||
return le.QueryContext(context.Background(), query, args...)
|
||||
}
|
||||
|
||||
@@ -66,7 +73,7 @@ type loggingDB struct {
|
||||
func (ld *loggingDB) BeginTx(ctx context.Context, opts *sql.TxOptions) (*LoggingTxn, error) {
|
||||
start := time.Now()
|
||||
tx, err := ld.db.RawDB.BeginTx(ctx, opts)
|
||||
ld.db.Log.QueryTiming(ctx, "Begin", "", nil, time.Since(start))
|
||||
ld.db.Log.QueryTiming(ctx, "Begin", "", nil, -1, time.Since(start))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -90,13 +97,76 @@ type LoggingTxn struct {
|
||||
func (lt *LoggingTxn) Commit() error {
|
||||
start := time.Now()
|
||||
err := lt.UnderlyingTx.Commit()
|
||||
lt.db.Log.QueryTiming(lt.ctx, "Commit", "", nil, time.Since(start))
|
||||
lt.db.Log.QueryTiming(lt.ctx, "Commit", "", nil, -1, time.Since(start))
|
||||
return err
|
||||
}
|
||||
|
||||
func (lt *LoggingTxn) Rollback() error {
|
||||
start := time.Now()
|
||||
err := lt.UnderlyingTx.Rollback()
|
||||
lt.db.Log.QueryTiming(lt.ctx, "Rollback", "", nil, time.Since(start))
|
||||
lt.db.Log.QueryTiming(lt.ctx, "Rollback", "", nil, -1, time.Since(start))
|
||||
return err
|
||||
}
|
||||
|
||||
type LoggingRows struct {
|
||||
ctx context.Context
|
||||
db *Database
|
||||
query string
|
||||
args []interface{}
|
||||
rs Rows
|
||||
start time.Time
|
||||
nrows int
|
||||
}
|
||||
|
||||
func (lrs *LoggingRows) stopTiming() {
|
||||
if !lrs.start.IsZero() {
|
||||
lrs.db.Log.QueryTiming(lrs.ctx, "EndRows", lrs.query, lrs.args, lrs.nrows, time.Since(lrs.start))
|
||||
lrs.start = time.Time{}
|
||||
}
|
||||
}
|
||||
|
||||
func (lrs *LoggingRows) Close() error {
|
||||
err := lrs.rs.Close()
|
||||
lrs.stopTiming()
|
||||
return err
|
||||
}
|
||||
|
||||
func (lrs *LoggingRows) ColumnTypes() ([]*sql.ColumnType, error) {
|
||||
return lrs.rs.ColumnTypes()
|
||||
}
|
||||
|
||||
func (lrs *LoggingRows) Columns() ([]string, error) {
|
||||
return lrs.rs.Columns()
|
||||
}
|
||||
|
||||
func (lrs *LoggingRows) Err() error {
|
||||
return lrs.rs.Err()
|
||||
}
|
||||
|
||||
func (lrs *LoggingRows) Next() bool {
|
||||
hasNext := lrs.rs.Next()
|
||||
|
||||
if !hasNext {
|
||||
lrs.stopTiming()
|
||||
} else {
|
||||
lrs.nrows++
|
||||
}
|
||||
|
||||
return hasNext
|
||||
}
|
||||
|
||||
func (lrs *LoggingRows) NextResultSet() bool {
|
||||
hasNext := lrs.rs.NextResultSet()
|
||||
|
||||
if !hasNext {
|
||||
lrs.stopTiming()
|
||||
} else {
|
||||
lrs.nrows++
|
||||
}
|
||||
|
||||
return hasNext
|
||||
}
|
||||
|
||||
func (lrs *LoggingRows) Scan(dest ...any) error {
|
||||
return lrs.rs.Scan(dest...)
|
||||
}
|
||||
|
||||
41
vendor/maunium.net/go/mautrix/util/dbutil/database.go
generated
vendored
41
vendor/maunium.net/go/mautrix/util/dbutil/database.go
generated
vendored
@@ -40,13 +40,23 @@ func ParseDialect(engine string) (Dialect, error) {
|
||||
switch strings.ToLower(engine) {
|
||||
case "postgres", "postgresql":
|
||||
return Postgres, nil
|
||||
case "sqlite3", "sqlite", "litestream":
|
||||
case "sqlite3", "sqlite", "litestream", "sqlite3-fk-wal":
|
||||
return SQLite, nil
|
||||
default:
|
||||
return DialectUnknown, fmt.Errorf("unknown dialect '%s'", engine)
|
||||
}
|
||||
}
|
||||
|
||||
type Rows interface {
|
||||
Close() error
|
||||
ColumnTypes() ([]*sql.ColumnType, error)
|
||||
Columns() ([]string, error)
|
||||
Err() error
|
||||
Next() bool
|
||||
NextResultSet() bool
|
||||
Scan(...any) error
|
||||
}
|
||||
|
||||
type Scannable interface {
|
||||
Scan(...interface{}) error
|
||||
}
|
||||
@@ -54,19 +64,32 @@ type Scannable interface {
|
||||
// Expected implementations of Scannable
|
||||
var (
|
||||
_ Scannable = (*sql.Row)(nil)
|
||||
_ Scannable = (*sql.Rows)(nil)
|
||||
_ Scannable = (Rows)(nil)
|
||||
)
|
||||
|
||||
type ContextExecable interface {
|
||||
type UnderlyingContextExecable interface {
|
||||
ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error)
|
||||
QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error)
|
||||
QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row
|
||||
}
|
||||
|
||||
type ContextExecable interface {
|
||||
ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error)
|
||||
QueryContext(ctx context.Context, query string, args ...interface{}) (Rows, error)
|
||||
QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row
|
||||
}
|
||||
|
||||
type UnderlyingExecable interface {
|
||||
UnderlyingContextExecable
|
||||
Exec(query string, args ...interface{}) (sql.Result, error)
|
||||
Query(query string, args ...interface{}) (*sql.Rows, error)
|
||||
QueryRow(query string, args ...interface{}) *sql.Row
|
||||
}
|
||||
|
||||
type Execable interface {
|
||||
ContextExecable
|
||||
Exec(query string, args ...interface{}) (sql.Result, error)
|
||||
Query(query string, args ...interface{}) (*sql.Rows, error)
|
||||
Query(query string, args ...interface{}) (Rows, error)
|
||||
QueryRow(query string, args ...interface{}) *sql.Row
|
||||
}
|
||||
|
||||
@@ -78,11 +101,11 @@ type Transaction interface {
|
||||
|
||||
// Expected implementations of Execable
|
||||
var (
|
||||
_ Execable = (*sql.Tx)(nil)
|
||||
_ Execable = (*sql.DB)(nil)
|
||||
_ Execable = (*LoggingExecable)(nil)
|
||||
_ Transaction = (*LoggingTxn)(nil)
|
||||
_ ContextExecable = (*sql.Conn)(nil)
|
||||
_ UnderlyingExecable = (*sql.Tx)(nil)
|
||||
_ UnderlyingExecable = (*sql.DB)(nil)
|
||||
_ Execable = (*LoggingExecable)(nil)
|
||||
_ Transaction = (*LoggingTxn)(nil)
|
||||
_ UnderlyingContextExecable = (*sql.Conn)(nil)
|
||||
)
|
||||
|
||||
type Database struct {
|
||||
|
||||
22
vendor/maunium.net/go/mautrix/util/dbutil/log.go
generated
vendored
22
vendor/maunium.net/go/mautrix/util/dbutil/log.go
generated
vendored
@@ -11,10 +11,10 @@ import (
|
||||
)
|
||||
|
||||
type DatabaseLogger interface {
|
||||
QueryTiming(ctx context.Context, method, query string, args []interface{}, duration time.Duration)
|
||||
QueryTiming(ctx context.Context, method, query string, args []interface{}, nrows int, duration time.Duration)
|
||||
WarnUnsupportedVersion(current, latest int)
|
||||
PrepareUpgrade(current, latest int)
|
||||
DoUpgrade(from, to int, message string)
|
||||
DoUpgrade(from, to int, message string, txn bool)
|
||||
// Deprecated: legacy warning method, return errors instead
|
||||
Warn(msg string, args ...interface{})
|
||||
}
|
||||
@@ -25,10 +25,11 @@ var NoopLogger DatabaseLogger = &noopLogger{}
|
||||
|
||||
func (n noopLogger) WarnUnsupportedVersion(_, _ int) {}
|
||||
func (n noopLogger) PrepareUpgrade(_, _ int) {}
|
||||
func (n noopLogger) DoUpgrade(_, _ int, _ string) {}
|
||||
func (n noopLogger) DoUpgrade(_, _ int, _ string, _ bool) {}
|
||||
func (n noopLogger) Warn(msg string, args ...interface{}) {}
|
||||
|
||||
func (n noopLogger) QueryTiming(_ context.Context, _, _ string, _ []interface{}, _ time.Duration) {}
|
||||
func (n noopLogger) QueryTiming(_ context.Context, _, _ string, _ []interface{}, _ int, _ time.Duration) {
|
||||
}
|
||||
|
||||
type mauLogger struct {
|
||||
l maulogger.Logger
|
||||
@@ -46,11 +47,11 @@ func (m mauLogger) PrepareUpgrade(current, latest int) {
|
||||
m.l.Infofln("Database currently on v%d, latest: v%d", current, latest)
|
||||
}
|
||||
|
||||
func (m mauLogger) DoUpgrade(from, to int, message string) {
|
||||
func (m mauLogger) DoUpgrade(from, to int, message string, _ bool) {
|
||||
m.l.Infofln("Upgrading database from v%d to v%d: %s", from, to, message)
|
||||
}
|
||||
|
||||
func (m mauLogger) QueryTiming(_ context.Context, method, query string, _ []interface{}, duration time.Duration) {
|
||||
func (m mauLogger) QueryTiming(_ context.Context, method, query string, _ []interface{}, _ int, duration time.Duration) {
|
||||
if duration > 1*time.Second {
|
||||
m.l.Warnfln("%s(%s) took %.3f seconds", method, query, duration.Seconds())
|
||||
}
|
||||
@@ -90,17 +91,18 @@ func (z zeroLogger) PrepareUpgrade(current, latest int) {
|
||||
}
|
||||
}
|
||||
|
||||
func (z zeroLogger) DoUpgrade(from, to int, message string) {
|
||||
func (z zeroLogger) DoUpgrade(from, to int, message string, txn bool) {
|
||||
z.l.Info().
|
||||
Int("from", from).
|
||||
Int("to", to).
|
||||
Bool("single_txn", txn).
|
||||
Str("description", message).
|
||||
Msg("Upgrading database")
|
||||
}
|
||||
|
||||
var whitespaceRegex = regexp.MustCompile(`\s+`)
|
||||
|
||||
func (z zeroLogger) QueryTiming(ctx context.Context, method, query string, args []interface{}, duration time.Duration) {
|
||||
func (z zeroLogger) QueryTiming(ctx context.Context, method, query string, args []interface{}, nrows int, duration time.Duration) {
|
||||
log := zerolog.Ctx(ctx)
|
||||
if log.GetLevel() == zerolog.Disabled {
|
||||
log = z.l
|
||||
@@ -108,6 +110,10 @@ func (z zeroLogger) QueryTiming(ctx context.Context, method, query string, args
|
||||
if log.GetLevel() != zerolog.TraceLevel && duration < 1*time.Second {
|
||||
return
|
||||
}
|
||||
if nrows > -1 {
|
||||
rowLog := log.With().Int("rows", nrows).Logger()
|
||||
log = &rowLog
|
||||
}
|
||||
query = strings.TrimSpace(whitespaceRegex.ReplaceAllLiteralString(query, " "))
|
||||
log.Trace().
|
||||
Int64("duration_µs", duration.Microseconds()).
|
||||
|
||||
4
vendor/maunium.net/go/mautrix/util/dbutil/samples/04-notxn.sql
generated
vendored
Normal file
4
vendor/maunium.net/go/mautrix/util/dbutil/samples/04-notxn.sql
generated
vendored
Normal file
@@ -0,0 +1,4 @@
|
||||
-- v4: Sample outside transaction
|
||||
-- transaction: off
|
||||
|
||||
INSERT INTO foo VALUES ('meow', '{}');
|
||||
1
vendor/maunium.net/go/mautrix/util/dbutil/samples/output/04-postgres.sql
generated
vendored
Normal file
1
vendor/maunium.net/go/mautrix/util/dbutil/samples/output/04-postgres.sql
generated
vendored
Normal file
@@ -0,0 +1 @@
|
||||
INSERT INTO foo VALUES ('meow', '{}');
|
||||
1
vendor/maunium.net/go/mautrix/util/dbutil/samples/output/04-sqlite3.sql
generated
vendored
Normal file
1
vendor/maunium.net/go/mautrix/util/dbutil/samples/output/04-sqlite3.sql
generated
vendored
Normal file
@@ -0,0 +1 @@
|
||||
INSERT INTO foo VALUES ('meow', '{}');
|
||||
33
vendor/maunium.net/go/mautrix/util/dbutil/upgrades.go
generated
vendored
33
vendor/maunium.net/go/mautrix/util/dbutil/upgrades.go
generated
vendored
@@ -12,13 +12,14 @@ import (
|
||||
"fmt"
|
||||
)
|
||||
|
||||
type upgradeFunc func(Transaction, *Database) error
|
||||
type upgradeFunc func(Execable, *Database) error
|
||||
|
||||
type upgrade struct {
|
||||
message string
|
||||
fn upgradeFunc
|
||||
|
||||
upgradesTo int
|
||||
upgradesTo int
|
||||
transaction bool
|
||||
}
|
||||
|
||||
var ErrUnsupportedDatabaseVersion = fmt.Errorf("unsupported database schema version")
|
||||
@@ -93,7 +94,7 @@ func (db *Database) checkDatabaseOwner() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (db *Database) setVersion(tx Transaction, version int) error {
|
||||
func (db *Database) setVersion(tx Execable, version int) error {
|
||||
_, err := tx.Exec(fmt.Sprintf("DELETE FROM %s", db.VersionTable))
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -129,25 +130,33 @@ func (db *Database) Upgrade() error {
|
||||
version++
|
||||
continue
|
||||
}
|
||||
db.Log.DoUpgrade(logVersion, upgradeItem.upgradesTo, upgradeItem.message)
|
||||
db.Log.DoUpgrade(logVersion, upgradeItem.upgradesTo, upgradeItem.message, upgradeItem.transaction)
|
||||
var tx Transaction
|
||||
tx, err = db.Begin()
|
||||
if err != nil {
|
||||
return err
|
||||
var upgradeConn Execable
|
||||
if upgradeItem.transaction {
|
||||
tx, err = db.Begin()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
upgradeConn = tx
|
||||
} else {
|
||||
upgradeConn = db
|
||||
}
|
||||
err = upgradeItem.fn(tx, db)
|
||||
err = upgradeItem.fn(upgradeConn, db)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
version = upgradeItem.upgradesTo
|
||||
logVersion = version
|
||||
err = db.setVersion(tx, version)
|
||||
err = db.setVersion(upgradeConn, version)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = tx.Commit()
|
||||
if err != nil {
|
||||
return err
|
||||
if tx != nil {
|
||||
err = tx.Commit()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
|
||||
42
vendor/maunium.net/go/mautrix/util/dbutil/upgradetable.go
generated
vendored
42
vendor/maunium.net/go/mautrix/util/dbutil/upgradetable.go
generated
vendored
@@ -29,14 +29,14 @@ func (ut *UpgradeTable) extend(toSize int) {
|
||||
}
|
||||
}
|
||||
|
||||
func (ut *UpgradeTable) Register(from, to int, message string, fn upgradeFunc) {
|
||||
func (ut *UpgradeTable) Register(from, to int, message string, txn bool, fn upgradeFunc) {
|
||||
if from < 0 {
|
||||
from += to
|
||||
}
|
||||
if from < 0 {
|
||||
panic("invalid from value in UpgradeTable.Register() call")
|
||||
}
|
||||
upg := upgrade{message: message, fn: fn, upgradesTo: to}
|
||||
upg := upgrade{message: message, fn: fn, upgradesTo: to, transaction: txn}
|
||||
if len(*ut) == from {
|
||||
*ut = append(*ut, upg)
|
||||
return
|
||||
@@ -57,7 +57,14 @@ func (ut *UpgradeTable) Register(from, to int, message string, fn upgradeFunc) {
|
||||
// -- v1: Message
|
||||
var upgradeHeaderRegex = regexp.MustCompile(`^-- (?:v(\d+) -> )?v(\d+): (.+)$`)
|
||||
|
||||
func parseFileHeader(file []byte) (from, to int, message string, lines [][]byte, err error) {
|
||||
// To disable wrapping the upgrade in a single transaction, put `--transaction: off` on the second line.
|
||||
//
|
||||
// -- v5: Upgrade without transaction
|
||||
// -- transaction: off
|
||||
// // do dangerous stuff
|
||||
var transactionDisableRegex = regexp.MustCompile(`^-- transaction: (\w*)`)
|
||||
|
||||
func parseFileHeader(file []byte) (from, to int, message string, txn bool, lines [][]byte, err error) {
|
||||
lines = bytes.Split(file, []byte("\n"))
|
||||
if len(lines) < 2 {
|
||||
err = errors.New("upgrade file too short")
|
||||
@@ -81,6 +88,15 @@ func parseFileHeader(file []byte) (from, to int, message string, lines [][]byte,
|
||||
from = -1
|
||||
}
|
||||
message = string(match[3])
|
||||
txn = true
|
||||
match = transactionDisableRegex.FindSubmatch(lines[0])
|
||||
if match != nil {
|
||||
lines = lines[1:]
|
||||
if string(match[1]) != "off" {
|
||||
err = fmt.Errorf("invalid value %q for transaction flag", match[1])
|
||||
}
|
||||
txn = false
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
@@ -163,7 +179,7 @@ func (db *Database) filterSQLUpgrade(lines [][]byte) (string, error) {
|
||||
}
|
||||
|
||||
func sqlUpgradeFunc(fileName string, lines [][]byte) upgradeFunc {
|
||||
return func(tx Transaction, db *Database) error {
|
||||
return func(tx Execable, db *Database) error {
|
||||
if skip, err := db.parseDialectFilter(lines[0]); err == nil && skip == skipNextLine {
|
||||
return nil
|
||||
} else if upgradeSQL, err := db.filterSQLUpgrade(lines); err != nil {
|
||||
@@ -176,7 +192,7 @@ func sqlUpgradeFunc(fileName string, lines [][]byte) upgradeFunc {
|
||||
}
|
||||
|
||||
func splitSQLUpgradeFunc(sqliteData, postgresData string) upgradeFunc {
|
||||
return func(tx Transaction, database *Database) (err error) {
|
||||
return func(tx Execable, database *Database) (err error) {
|
||||
switch database.Dialect {
|
||||
case SQLite:
|
||||
_, err = tx.Exec(sqliteData)
|
||||
@@ -189,7 +205,7 @@ func splitSQLUpgradeFunc(sqliteData, postgresData string) upgradeFunc {
|
||||
}
|
||||
}
|
||||
|
||||
func parseSplitSQLUpgrade(name string, fs fullFS, skipNames map[string]struct{}) (from, to int, message string, fn upgradeFunc) {
|
||||
func parseSplitSQLUpgrade(name string, fs fullFS, skipNames map[string]struct{}) (from, to int, message string, txn bool, fn upgradeFunc) {
|
||||
postgresName := fmt.Sprintf("%s.postgres.sql", name)
|
||||
sqliteName := fmt.Sprintf("%s.sqlite.sql", name)
|
||||
skipNames[postgresName] = struct{}{}
|
||||
@@ -202,11 +218,11 @@ func parseSplitSQLUpgrade(name string, fs fullFS, skipNames map[string]struct{})
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
from, to, message, _, err = parseFileHeader(postgresData)
|
||||
from, to, message, txn, _, err = parseFileHeader(postgresData)
|
||||
if err != nil {
|
||||
panic(fmt.Errorf("failed to parse header in %s: %w", postgresName, err))
|
||||
}
|
||||
sqliteFrom, sqliteTo, sqliteMessage, _, err := parseFileHeader(sqliteData)
|
||||
sqliteFrom, sqliteTo, sqliteMessage, sqliteTxn, _, err := parseFileHeader(sqliteData)
|
||||
if err != nil {
|
||||
panic(fmt.Errorf("failed to parse header in %s: %w", sqliteName, err))
|
||||
}
|
||||
@@ -214,6 +230,8 @@ func parseSplitSQLUpgrade(name string, fs fullFS, skipNames map[string]struct{})
|
||||
panic(fmt.Errorf("mismatching versions in postgres and sqlite versions of %s: %d/%d -> %d/%d", name, from, sqliteFrom, to, sqliteTo))
|
||||
} else if message != sqliteMessage {
|
||||
panic(fmt.Errorf("mismatching message in postgres and sqlite versions of %s: %q != %q", name, message, sqliteMessage))
|
||||
} else if txn != sqliteTxn {
|
||||
panic(fmt.Errorf("mismatching transaction flag in postgres and sqlite versions of %s: %t != %t", name, txn, sqliteTxn))
|
||||
}
|
||||
fn = splitSQLUpgradeFunc(string(sqliteData), string(postgresData))
|
||||
return
|
||||
@@ -242,14 +260,14 @@ func (ut *UpgradeTable) RegisterFSPath(fs fullFS, dir string) {
|
||||
} else if _, skip := skipNames[file.Name()]; skip {
|
||||
// also do nothing
|
||||
} else if splitName := splitFileNameRegex.FindStringSubmatch(file.Name()); splitName != nil {
|
||||
from, to, message, fn := parseSplitSQLUpgrade(splitName[1], fs, skipNames)
|
||||
ut.Register(from, to, message, fn)
|
||||
from, to, message, txn, fn := parseSplitSQLUpgrade(splitName[1], fs, skipNames)
|
||||
ut.Register(from, to, message, txn, fn)
|
||||
} else if data, err := fs.ReadFile(filepath.Join(dir, file.Name())); err != nil {
|
||||
panic(err)
|
||||
} else if from, to, message, lines, err := parseFileHeader(data); err != nil {
|
||||
} else if from, to, message, txn, lines, err := parseFileHeader(data); err != nil {
|
||||
panic(fmt.Errorf("failed to parse header in %s: %w", file.Name(), err))
|
||||
} else {
|
||||
ut.Register(from, to, message, sqlUpgradeFunc(file.Name(), lines))
|
||||
ut.Register(from, to, message, txn, sqlUpgradeFunc(file.Name(), lines))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user