automatically ignore known forwarded addresses, fixes #64

Author: Aine
Date: 2023-09-18 12:35:37 +03:00
parent e90925eceb
commit 60b4386dd8
187 changed files with 4070 additions and 2667 deletions

vendor/go.mau.fi/util/dbutil/connlog.go vendored Normal file

@@ -0,0 +1,189 @@
// Copyright (c) 2022 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
package dbutil
import (
"context"
"database/sql"
"time"
)
// LoggingExecable is a wrapper for anything with database Exec methods (i.e. sql.Conn, sql.DB and sql.Tx)
// that can preprocess queries (e.g. replacing $ with ? on SQLite) and log query durations.
type LoggingExecable struct {
UnderlyingExecable UnderlyingExecable
db *Database
}
func (le *LoggingExecable) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) {
start := time.Now()
query = le.db.mutateQuery(query)
res, err := le.UnderlyingExecable.ExecContext(ctx, query, args...)
le.db.Log.QueryTiming(ctx, "Exec", query, args, -1, time.Since(start), err)
return res, err
}
func (le *LoggingExecable) QueryContext(ctx context.Context, query string, args ...interface{}) (Rows, error) {
start := time.Now()
query = le.db.mutateQuery(query)
rows, err := le.UnderlyingExecable.QueryContext(ctx, query, args...)
le.db.Log.QueryTiming(ctx, "Query", query, args, -1, time.Since(start), err)
return &LoggingRows{
ctx: ctx,
db: le.db,
query: query,
args: args,
rs: rows,
start: start,
}, err
}
func (le *LoggingExecable) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row {
start := time.Now()
query = le.db.mutateQuery(query)
row := le.UnderlyingExecable.QueryRowContext(ctx, query, args...)
le.db.Log.QueryTiming(ctx, "QueryRow", query, args, -1, time.Since(start), nil)
return row
}
func (le *LoggingExecable) Exec(query string, args ...interface{}) (sql.Result, error) {
return le.ExecContext(context.Background(), query, args...)
}
func (le *LoggingExecable) Query(query string, args ...interface{}) (Rows, error) {
return le.QueryContext(context.Background(), query, args...)
}
func (le *LoggingExecable) QueryRow(query string, args ...interface{}) *sql.Row {
return le.QueryRowContext(context.Background(), query, args...)
}
// loggingDB is a wrapper for LoggingExecable that allows access to BeginTx.
//
// While LoggingExecable has a pointer to the database and could use BeginTx, it's not technically safe since
// the LoggingExecable could be for a transaction (where BeginTx wouldn't make sense).
type loggingDB struct {
LoggingExecable
}
func (ld *loggingDB) BeginTx(ctx context.Context, opts *sql.TxOptions) (*LoggingTxn, error) {
targetDB := ld.db.RawDB
if opts != nil && opts.ReadOnly && ld.db.ReadOnlyDB != nil {
targetDB = ld.db.ReadOnlyDB
}
start := time.Now()
tx, err := targetDB.BeginTx(ctx, opts)
ld.db.Log.QueryTiming(ctx, "Begin", "", nil, -1, time.Since(start), err)
if err != nil {
return nil, err
}
return &LoggingTxn{
LoggingExecable: LoggingExecable{UnderlyingExecable: tx, db: ld.db},
UnderlyingTx: tx,
ctx: ctx,
StartTime: start,
}, nil
}
func (ld *loggingDB) Begin() (*LoggingTxn, error) {
return ld.BeginTx(context.Background(), nil)
}
type LoggingTxn struct {
LoggingExecable
UnderlyingTx *sql.Tx
ctx context.Context
StartTime time.Time
EndTime time.Time
noTotalLog bool
}
func (lt *LoggingTxn) Commit() error {
start := time.Now()
err := lt.UnderlyingTx.Commit()
lt.EndTime = time.Now()
if !lt.noTotalLog {
lt.db.Log.QueryTiming(lt.ctx, "<Transaction>", "", nil, -1, lt.EndTime.Sub(lt.StartTime), nil)
}
lt.db.Log.QueryTiming(lt.ctx, "Commit", "", nil, -1, time.Since(start), err)
return err
}
func (lt *LoggingTxn) Rollback() error {
start := time.Now()
err := lt.UnderlyingTx.Rollback()
lt.EndTime = time.Now()
if !lt.noTotalLog {
lt.db.Log.QueryTiming(lt.ctx, "<Transaction>", "", nil, -1, lt.EndTime.Sub(lt.StartTime), nil)
}
lt.db.Log.QueryTiming(lt.ctx, "Rollback", "", nil, -1, time.Since(start), err)
return err
}
type LoggingRows struct {
ctx context.Context
db *Database
query string
args []interface{}
rs Rows
start time.Time
nrows int
}
func (lrs *LoggingRows) stopTiming() {
if !lrs.start.IsZero() {
lrs.db.Log.QueryTiming(lrs.ctx, "EndRows", lrs.query, lrs.args, lrs.nrows, time.Since(lrs.start), lrs.rs.Err())
lrs.start = time.Time{}
}
}
func (lrs *LoggingRows) Close() error {
err := lrs.rs.Close()
lrs.stopTiming()
return err
}
func (lrs *LoggingRows) ColumnTypes() ([]*sql.ColumnType, error) {
return lrs.rs.ColumnTypes()
}
func (lrs *LoggingRows) Columns() ([]string, error) {
return lrs.rs.Columns()
}
func (lrs *LoggingRows) Err() error {
return lrs.rs.Err()
}
func (lrs *LoggingRows) Next() bool {
hasNext := lrs.rs.Next()
if !hasNext {
lrs.stopTiming()
} else {
lrs.nrows++
}
return hasNext
}
func (lrs *LoggingRows) NextResultSet() bool {
hasNext := lrs.rs.NextResultSet()
if !hasNext {
lrs.stopTiming()
} else {
lrs.nrows++
}
return hasNext
}
func (lrs *LoggingRows) Scan(dest ...any) error {
return lrs.rs.Scan(dest...)
}
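
Because every query in this file funnels through the same wrappers, code written against the Execable interface (defined in database.go below) gets the query preprocessing and timing logs for free, whether it runs directly on the database or inside a LoggingTxn. A minimal sketch of that pattern; the helper name is hypothetical, and the foo table comes from the sample migrations further down:

package example

import "go.mau.fi/util/dbutil"

// insertRow works against anything that satisfies dbutil.Execable:
// *dbutil.Database, a *dbutil.LoggingTxn, or any other LoggingExecable.
// Every call ends up in DatabaseLogger.QueryTiming.
func insertRow(exec dbutil.Execable, key int, data string) error {
	_, err := exec.Exec("INSERT INTO foo (key, data) VALUES ($1, $2)", key, data)
	return err
}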

vendor/go.mau.fi/util/dbutil/database.go vendored Normal file

@@ -0,0 +1,289 @@
// Copyright (c) 2022 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
package dbutil
import (
"context"
"database/sql"
"fmt"
"net/url"
"regexp"
"strings"
"time"
)
type Dialect int
const (
DialectUnknown Dialect = iota
Postgres
SQLite
)
func (dialect Dialect) String() string {
switch dialect {
case Postgres:
return "postgres"
case SQLite:
return "sqlite3"
default:
return ""
}
}
func ParseDialect(engine string) (Dialect, error) {
engine = strings.ToLower(engine)
if strings.HasPrefix(engine, "postgres") || engine == "pgx" {
return Postgres, nil
} else if strings.HasPrefix(engine, "sqlite") || strings.HasPrefix(engine, "litestream") {
return SQLite, nil
} else {
return DialectUnknown, fmt.Errorf("unknown dialect '%s'", engine)
}
}
type Rows interface {
Close() error
ColumnTypes() ([]*sql.ColumnType, error)
Columns() ([]string, error)
Err() error
Next() bool
NextResultSet() bool
Scan(...any) error
}
type Scannable interface {
Scan(...interface{}) error
}
// Expected implementations of Scannable
var (
_ Scannable = (*sql.Row)(nil)
_ Scannable = (Rows)(nil)
)
type UnderlyingContextExecable interface {
ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error)
QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error)
QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row
}
type ContextExecable interface {
ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error)
QueryContext(ctx context.Context, query string, args ...interface{}) (Rows, error)
QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row
}
type UnderlyingExecable interface {
UnderlyingContextExecable
Exec(query string, args ...interface{}) (sql.Result, error)
Query(query string, args ...interface{}) (*sql.Rows, error)
QueryRow(query string, args ...interface{}) *sql.Row
}
type Execable interface {
ContextExecable
Exec(query string, args ...interface{}) (sql.Result, error)
Query(query string, args ...interface{}) (Rows, error)
QueryRow(query string, args ...interface{}) *sql.Row
}
type Transaction interface {
Execable
Commit() error
Rollback() error
}
// Expected implementations of Execable
var (
_ UnderlyingExecable = (*sql.Tx)(nil)
_ UnderlyingExecable = (*sql.DB)(nil)
_ Execable = (*LoggingExecable)(nil)
_ Transaction = (*LoggingTxn)(nil)
_ UnderlyingContextExecable = (*sql.Conn)(nil)
)
type Database struct {
loggingDB
RawDB *sql.DB
ReadOnlyDB *sql.DB
Owner string
VersionTable string
Log DatabaseLogger
Dialect Dialect
UpgradeTable UpgradeTable
IgnoreForeignTables bool
IgnoreUnsupportedDatabase bool
}
var positionalParamPattern = regexp.MustCompile(`\$(\d+)`)
func (db *Database) mutateQuery(query string) string {
switch db.Dialect {
case SQLite:
return positionalParamPattern.ReplaceAllString(query, "?$1")
default:
return query
}
}
func (db *Database) Child(versionTable string, upgradeTable UpgradeTable, log DatabaseLogger) *Database {
if log == nil {
log = db.Log
}
return &Database{
RawDB: db.RawDB,
loggingDB: db.loggingDB,
Owner: "",
VersionTable: versionTable,
UpgradeTable: upgradeTable,
Log: log,
Dialect: db.Dialect,
IgnoreForeignTables: true,
IgnoreUnsupportedDatabase: db.IgnoreUnsupportedDatabase,
}
}
func NewWithDB(db *sql.DB, rawDialect string) (*Database, error) {
dialect, err := ParseDialect(rawDialect)
if err != nil {
return nil, err
}
wrappedDB := &Database{
RawDB: db,
Dialect: dialect,
Log: NoopLogger,
IgnoreForeignTables: true,
VersionTable: "version",
}
wrappedDB.loggingDB.UnderlyingExecable = db
wrappedDB.loggingDB.db = wrappedDB
return wrappedDB, nil
}
func NewWithDialect(uri, rawDialect string) (*Database, error) {
db, err := sql.Open(rawDialect, uri)
if err != nil {
return nil, err
}
return NewWithDB(db, rawDialect)
}
type PoolConfig struct {
Type string `yaml:"type"`
URI string `yaml:"uri"`
MaxOpenConns int `yaml:"max_open_conns"`
MaxIdleConns int `yaml:"max_idle_conns"`
ConnMaxIdleTime string `yaml:"conn_max_idle_time"`
ConnMaxLifetime string `yaml:"conn_max_lifetime"`
}
type Config struct {
PoolConfig `yaml:",inline"`
ReadOnlyPool PoolConfig `yaml:"ro_pool"`
}
func (db *Database) Close() error {
err := db.RawDB.Close()
if db.ReadOnlyDB != nil {
err2 := db.ReadOnlyDB.Close()
if err == nil {
err = fmt.Errorf("closing read-only db failed: %w", err)
} else {
err = fmt.Errorf("%w (closing read-only db also failed: %v)", err, err2)
}
}
return err
}
func (db *Database) Configure(cfg Config) error {
if err := db.configure(db.ReadOnlyDB, cfg.ReadOnlyPool); err != nil {
return err
}
return db.configure(db.RawDB, cfg.PoolConfig)
}
func (db *Database) configure(rawDB *sql.DB, cfg PoolConfig) error {
if rawDB == nil {
return nil
}
rawDB.SetMaxOpenConns(cfg.MaxOpenConns)
rawDB.SetMaxIdleConns(cfg.MaxIdleConns)
if len(cfg.ConnMaxIdleTime) > 0 {
maxIdleTimeDuration, err := time.ParseDuration(cfg.ConnMaxIdleTime)
if err != nil {
return fmt.Errorf("failed to parse max_conn_idle_time: %w", err)
}
rawDB.SetConnMaxIdleTime(maxIdleTimeDuration)
}
if len(cfg.ConnMaxLifetime) > 0 {
maxLifetimeDuration, err := time.ParseDuration(cfg.ConnMaxLifetime)
if err != nil {
return fmt.Errorf("failed to parse max_conn_idle_time: %w", err)
}
rawDB.SetConnMaxLifetime(maxLifetimeDuration)
}
return nil
}
func NewFromConfig(owner string, cfg Config, logger DatabaseLogger) (*Database, error) {
wrappedDB, err := NewWithDialect(cfg.URI, cfg.Type)
if err != nil {
return nil, err
}
wrappedDB.Owner = owner
if logger != nil {
wrappedDB.Log = logger
}
if cfg.ReadOnlyPool.MaxOpenConns > 0 {
if cfg.ReadOnlyPool.Type == "" {
cfg.ReadOnlyPool.Type = cfg.Type
}
roUri := cfg.ReadOnlyPool.URI
if roUri == "" {
uriParts := strings.Split(cfg.URI, "?")
qs := url.Values{} // start with empty values so the Set call below is safe even when the URI has no query string
if len(uriParts) == 2 {
var err error
qs, err = url.ParseQuery(uriParts[1])
if err != nil {
return nil, err
}
qs.Del("_txlock")
}
qs.Set("_query_only", "true")
roUri = uriParts[0] + "?" + qs.Encode()
}
wrappedDB.ReadOnlyDB, err = sql.Open(cfg.ReadOnlyPool.Type, roUri)
if err != nil {
return nil, err
}
}
err = wrappedDB.Configure(cfg)
if err != nil {
return nil, err
}
return wrappedDB, nil
}
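
On SQLite, mutateQuery rewrites Postgres-style positional parameters before the query reaches the driver. A standalone illustration of the same regexp rewrite (not part of the vendored code):

package main

import (
	"fmt"
	"regexp"
)

func main() {
	// Same pattern and replacement that Database.mutateQuery applies for SQLite.
	positionalParamPattern := regexp.MustCompile(`\$(\d+)`)
	query := "INSERT INTO foo (key, data) VALUES ($1, $2)"
	fmt.Println(positionalParamPattern.ReplaceAllString(query, "?$1"))
	// Output: INSERT INTO foo (key, data) VALUES (?1, ?2)
}

And a sketch of opening a pool through NewFromConfig; the owner string and URI are placeholders, and the mattn/go-sqlite3 driver is an assumption here (any driver registered under the configured type works):

package main

import (
	"go.mau.fi/util/dbutil"

	_ "github.com/mattn/go-sqlite3" // assumed driver registering "sqlite3"
)

func main() {
	cfg := dbutil.Config{PoolConfig: dbutil.PoolConfig{
		Type:            "sqlite3",
		URI:             "file:example.db?_txlock=immediate",
		MaxOpenConns:    5,
		MaxIdleConns:    2,
		ConnMaxIdleTime: "15m",
	}}
	db, err := dbutil.NewFromConfig("example-owner", cfg, nil) // nil logger keeps NoopLogger
	if err != nil {
		panic(err)
	}
	defer db.Close()
}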

vendor/go.mau.fi/util/dbutil/json.go vendored Normal file

@@ -0,0 +1,33 @@
package dbutil
import (
"database/sql/driver"
"encoding/json"
"fmt"
)
// JSON is a utility type for using arbitrary JSON data as values in database Exec and Scan calls.
type JSON struct {
Data any
}
func (j JSON) Scan(i any) error {
switch value := i.(type) {
case nil:
return nil
case string:
return json.Unmarshal([]byte(value), j.Data)
case []byte:
return json.Unmarshal(value, j.Data)
default:
return fmt.Errorf("invalid type %T for dbutil.JSON.Scan", i)
}
}
func (j JSON) Value() (driver.Value, error) {
if j.Data == nil {
return nil, nil
}
v, err := json.Marshal(j.Data)
return string(v), err
}
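
A sketch of using the JSON helper on both sides of a query; the struct and helper names are hypothetical, and the foo table is the one from the sample migrations:

package example

import "go.mau.fi/util/dbutil"

type fooData struct {
	Action string `json:"action"`
	Index  int    `json:"index"`
}

func storeFoo(db *dbutil.Database, key int, data *fooData) error {
	// JSON implements driver.Valuer, so the struct is marshalled on the way in.
	_, err := db.Exec("INSERT INTO foo (key, data) VALUES ($1, $2)", key, dbutil.JSON{Data: data})
	return err
}

func loadFoo(db *dbutil.Database, key int) (*fooData, error) {
	var data fooData
	// JSON also implements sql.Scanner, unmarshalling into Data on the way out.
	err := db.QueryRow("SELECT data FROM foo WHERE key=$1", key).Scan(&dbutil.JSON{Data: &data})
	return &data, err
}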

vendor/go.mau.fi/util/dbutil/log.go vendored Normal file

@@ -0,0 +1,129 @@
package dbutil
import (
"context"
"regexp"
"strings"
"time"
"github.com/rs/zerolog"
)
type DatabaseLogger interface {
QueryTiming(ctx context.Context, method, query string, args []interface{}, nrows int, duration time.Duration, err error)
WarnUnsupportedVersion(current, compat, latest int)
PrepareUpgrade(current, compat, latest int)
DoUpgrade(from, to int, message string, txn bool)
// Deprecated: legacy warning method, return errors instead
Warn(msg string, args ...interface{})
}
type noopLogger struct{}
var NoopLogger DatabaseLogger = &noopLogger{}
func (n noopLogger) WarnUnsupportedVersion(_, _, _ int) {}
func (n noopLogger) PrepareUpgrade(_, _, _ int) {}
func (n noopLogger) DoUpgrade(_, _ int, _ string, _ bool) {}
func (n noopLogger) Warn(msg string, args ...interface{}) {}
func (n noopLogger) QueryTiming(_ context.Context, _, _ string, _ []interface{}, _ int, _ time.Duration, _ error) {
}
type zeroLogger struct {
l *zerolog.Logger
ZeroLogSettings
}
type ZeroLogSettings struct {
CallerSkipFrame int
Caller bool
// TraceLogAllQueries specifies whether or not all queries should be logged
// at the TRACE level.
TraceLogAllQueries bool
}
func ZeroLogger(log zerolog.Logger, cfg ...ZeroLogSettings) DatabaseLogger {
return ZeroLoggerPtr(&log, cfg...)
}
func ZeroLoggerPtr(log *zerolog.Logger, cfg ...ZeroLogSettings) DatabaseLogger {
wrapped := &zeroLogger{l: log}
if len(cfg) > 0 {
wrapped.ZeroLogSettings = cfg[0]
} else {
wrapped.ZeroLogSettings = ZeroLogSettings{
CallerSkipFrame: 2, // Skip LoggingExecable.ExecContext and zeroLogger.QueryTiming
Caller: true,
}
}
return wrapped
}
func (z zeroLogger) WarnUnsupportedVersion(current, compat, latest int) {
z.l.Warn().
Int("current_version", current).
Int("oldest_compatible_version", compat).
Int("latest_known_version", latest).
Msg("Unsupported database schema version, continuing anyway")
}
func (z zeroLogger) PrepareUpgrade(current, compat, latest int) {
evt := z.l.Info().
Int("current_version", current).
Int("oldest_compatible_version", compat).
Int("latest_known_version", latest)
if current >= latest {
evt.Msg("Database is up to date")
} else {
evt.Msg("Preparing to update database schema")
}
}
func (z zeroLogger) DoUpgrade(from, to int, message string, txn bool) {
z.l.Info().
Int("from", from).
Int("to", to).
Bool("single_txn", txn).
Str("description", message).
Msg("Upgrading database")
}
var whitespaceRegex = regexp.MustCompile(`\s+`)
func (z zeroLogger) QueryTiming(ctx context.Context, method, query string, args []interface{}, nrows int, duration time.Duration, err error) {
log := zerolog.Ctx(ctx)
if log.GetLevel() == zerolog.Disabled || log == zerolog.DefaultContextLogger {
log = z.l
}
if (!z.TraceLogAllQueries || log.GetLevel() != zerolog.TraceLevel) && duration < 1*time.Second {
return
}
if nrows > -1 {
rowLog := log.With().Int("rows", nrows).Logger()
log = &rowLog
}
query = strings.TrimSpace(whitespaceRegex.ReplaceAllLiteralString(query, " "))
log.Trace().
Err(err).
Int64("duration_µs", duration.Microseconds()).
Str("method", method).
Str("query", query).
Interface("query_args", args).
Msg("Query")
if duration >= 1*time.Second {
evt := log.Warn().
Float64("duration_seconds", duration.Seconds()).
Str("method", method).
Str("query", query)
if z.Caller {
evt = evt.Caller(z.CallerSkipFrame)
}
evt.Msg("Query took long")
}
}
func (z zeroLogger) Warn(msg string, args ...interface{}) {
z.l.Warn().Msgf(msg, args...)
}
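
A sketch of wiring up the zerolog-backed logger with explicit settings; with TraceLogAllQueries enabled and a TRACE-level logger, every query is logged instead of only those taking a second or longer. The logger construction itself is illustrative:

package example

import (
	"os"

	"github.com/rs/zerolog"
	"go.mau.fi/util/dbutil"
)

func newDBLogger() dbutil.DatabaseLogger {
	log := zerolog.New(os.Stdout).Level(zerolog.TraceLevel).With().Timestamp().Logger()
	return dbutil.ZeroLogger(log, dbutil.ZeroLogSettings{
		Caller:             true,
		CallerSkipFrame:    2, // same default as ZeroLoggerPtr uses
		TraceLogAllQueries: true,
	})
}

Any other logging backend can be plugged in by implementing the five methods of DatabaseLogger directly.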


@@ -0,0 +1,21 @@
-- v0 -> v3: Sample revision jump
CREATE TABLE foo (
-- only: postgres
key BIGINT PRIMARY KEY GENERATED BY DEFAULT AS IDENTITY,
-- only: sqlite
key INTEGER PRIMARY KEY,
data JSONB NOT NULL
);
-- only: sqlite until "end only"
CREATE TRIGGER test AFTER INSERT ON foo WHEN NEW.data->>'action' = 'delete' BEGIN
DELETE FROM test WHERE key <= NEW.data->>'index';
END;
-- end only sqlite
-- only: postgres until "end only"
CREATE FUNCTION delete_data() RETURNS TRIGGER LANGUAGE plpgsql AS $$ BEGIN
DELETE FROM test WHERE key <= NEW.data->>'index';
RETURN NEW;
END $$;
-- end only postgres


@@ -0,0 +1,4 @@
-- v4: Sample outside transaction
-- transaction: off
INSERT INTO foo VALUES ('meow', '{}');


@@ -0,0 +1,3 @@
-- v5 (compatible with v3+): Sample backwards-compatible upgrade
INSERT INTO foo VALUES ('meow 2', '{}');


@@ -0,0 +1,11 @@
CREATE TABLE foo (
key BIGINT PRIMARY KEY GENERATED BY DEFAULT AS IDENTITY,
data JSONB NOT NULL
);
CREATE FUNCTION delete_data() RETURNS TRIGGER LANGUAGE plpgsql AS $$ BEGIN
DELETE FROM test WHERE key <= NEW.data->>'index';
RETURN NEW;
END $$;
-- end only postgres


@@ -0,0 +1,10 @@
CREATE TABLE foo (
key INTEGER PRIMARY KEY,
data JSONB NOT NULL
);
CREATE TRIGGER test AFTER INSERT ON foo WHEN NEW.data->>'action' = 'delete' BEGIN
DELETE FROM test WHERE key <= NEW.data->>'index';
END;
-- end only sqlite


@@ -0,0 +1 @@
INSERT INTO foo VALUES ('meow', '{}');


@@ -0,0 +1 @@
INSERT INTO foo VALUES ('meow', '{}');


@@ -0,0 +1 @@
INSERT INTO foo VALUES ('meow 2', '{}');


@@ -0,0 +1 @@
INSERT INTO foo VALUES ('meow 2', '{}');


@@ -0,0 +1,94 @@
// Copyright (c) 2023 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
package dbutil
import (
"context"
"database/sql"
"errors"
"fmt"
"time"
"github.com/rs/zerolog"
"go.mau.fi/util/exerrors"
"go.mau.fi/util/random"
)
var (
ErrTxn = errors.New("transaction")
ErrTxnBegin = fmt.Errorf("%w: begin", ErrTxn)
ErrTxnCommit = fmt.Errorf("%w: commit", ErrTxn)
)
type contextKey int
const (
ContextKeyDatabaseTransaction contextKey = iota
ContextKeyDoTxnCallerSkip
)
func (db *Database) DoTxn(ctx context.Context, opts *sql.TxOptions, fn func(ctx context.Context) error) error {
if ctx.Value(ContextKeyDatabaseTransaction) != nil {
zerolog.Ctx(ctx).Trace().Msg("Already in a transaction, not creating a new one")
return fn(ctx)
}
log := zerolog.Ctx(ctx).With().Str("db_txn_id", random.String(12)).Logger()
start := time.Now()
defer func() {
dur := time.Since(start)
if dur > time.Second {
val := ctx.Value(ContextKeyDoTxnCallerSkip)
callerSkip := 2
if val != nil {
callerSkip += val.(int)
}
log.Warn().
Float64("duration_seconds", dur.Seconds()).
Caller(callerSkip).
Msg("Transaction took long")
}
}()
tx, err := db.BeginTx(ctx, opts)
if err != nil {
log.Trace().Err(err).Msg("Failed to begin transaction")
return exerrors.NewDualError(ErrTxnBegin, err)
}
log.Trace().Msg("Transaction started")
tx.noTotalLog = true
ctx = log.WithContext(ctx)
ctx = context.WithValue(ctx, ContextKeyDatabaseTransaction, tx)
err = fn(ctx)
if err != nil {
log.Trace().Err(err).Msg("Database transaction failed, rolling back")
rollbackErr := tx.Rollback()
if rollbackErr != nil {
log.Warn().Err(rollbackErr).Msg("Rollback after transaction error failed")
} else {
log.Trace().Msg("Rollback successful")
}
return err
}
err = tx.Commit()
if err != nil {
log.Trace().Err(err).Msg("Commit failed")
return exerrors.NewDualError(ErrTxnCommit, err)
}
log.Trace().Msg("Commit successful")
return nil
}
func (db *Database) Conn(ctx context.Context) ContextExecable {
if ctx == nil {
return db
}
txn, ok := ctx.Value(ContextKeyDatabaseTransaction).(Transaction)
if ok {
return txn
}
return db
}
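
A sketch of the intended DoTxn call pattern: the callback receives a context carrying the transaction, and db.Conn(ctx) hands that transaction back, so nested helpers join it instead of opening their own. The helper name is hypothetical; the foo table is from the sample migrations:

package example

import (
	"context"

	"go.mau.fi/util/dbutil"
)

func moveRow(ctx context.Context, db *dbutil.Database, from, to int) error {
	return db.DoTxn(ctx, nil, func(ctx context.Context) error {
		conn := db.Conn(ctx) // the LoggingTxn stored in ctx by DoTxn
		if _, err := conn.ExecContext(ctx, "DELETE FROM foo WHERE key=$1", from); err != nil {
			return err
		}
		_, err := conn.ExecContext(ctx, "INSERT INTO foo (key, data) VALUES ($1, '{}')", to)
		return err
	})
}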

vendor/go.mau.fi/util/dbutil/upgrades.go vendored Normal file

@@ -0,0 +1,218 @@
// Copyright (c) 2023 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
package dbutil
import (
"database/sql"
"errors"
"fmt"
)
type upgradeFunc func(Execable, *Database) error
type upgrade struct {
message string
fn upgradeFunc
upgradesTo int
compatVersion int
transaction bool
}
var ErrUnsupportedDatabaseVersion = errors.New("unsupported database schema version")
var ErrForeignTables = errors.New("the database contains foreign tables")
var ErrNotOwned = errors.New("the database is owned by")
var ErrUnsupportedDialect = errors.New("unsupported database dialect")
func (db *Database) upgradeVersionTable() error {
if compatColumnExists, err := db.ColumnExists(nil, db.VersionTable, "compat"); err != nil {
return fmt.Errorf("failed to check if version table is up to date: %w", err)
} else if !compatColumnExists {
if tableExists, err := db.TableExists(nil, db.VersionTable); err != nil {
return fmt.Errorf("failed to check if version table exists: %w", err)
} else if !tableExists {
_, err = db.Exec(fmt.Sprintf("CREATE TABLE %s (version INTEGER, compat INTEGER)", db.VersionTable))
if err != nil {
return fmt.Errorf("failed to create version table: %w", err)
}
} else {
_, err = db.Exec(fmt.Sprintf("ALTER TABLE %s ADD COLUMN compat INTEGER", db.VersionTable))
if err != nil {
return fmt.Errorf("failed to add compat column to version table: %w", err)
}
}
}
return nil
}
func (db *Database) getVersion() (version, compat int, err error) {
if err = db.upgradeVersionTable(); err != nil {
return
}
var compatNull sql.NullInt32
err = db.QueryRow(fmt.Sprintf("SELECT version, compat FROM %s LIMIT 1", db.VersionTable)).Scan(&version, &compatNull)
if errors.Is(err, sql.ErrNoRows) {
err = nil
}
if compatNull.Valid && compatNull.Int32 != 0 {
compat = int(compatNull.Int32)
} else {
compat = version
}
return
}
const (
tableExistsPostgres = "SELECT EXISTS(SELECT 1 FROM information_schema.tables WHERE table_name=$1)"
tableExistsSQLite = "SELECT EXISTS(SELECT 1 FROM sqlite_master WHERE type='table' AND tbl_name=?1)"
)
func (db *Database) TableExists(tx Execable, table string) (exists bool, err error) {
if tx == nil {
tx = db
}
switch db.Dialect {
case SQLite:
err = tx.QueryRow(tableExistsSQLite, table).Scan(&exists)
case Postgres:
err = tx.QueryRow(tableExistsPostgres, table).Scan(&exists)
default:
err = ErrUnsupportedDialect
}
return
}
const (
columnExistsPostgres = "SELECT EXISTS(SELECT 1 FROM information_schema.columns WHERE table_name=$1 AND column_name=$2)"
columnExistsSQLite = "SELECT EXISTS(SELECT 1 FROM pragma_table_info(?1) WHERE name=?2)"
)
func (db *Database) ColumnExists(tx Execable, table, column string) (exists bool, err error) {
if tx == nil {
tx = db
}
switch db.Dialect {
case SQLite:
err = tx.QueryRow(columnExistsSQLite, table, column).Scan(&exists)
case Postgres:
err = tx.QueryRow(columnExistsPostgres, table, column).Scan(&exists)
default:
err = ErrUnsupportedDialect
}
return
}
func (db *Database) tableExistsNoError(table string) bool {
exists, err := db.TableExists(nil, table)
if err != nil {
panic(fmt.Errorf("failed to check if table exists: %w", err))
}
return exists
}
const createOwnerTable = `
CREATE TABLE IF NOT EXISTS database_owner (
key INTEGER PRIMARY KEY DEFAULT 0,
owner TEXT NOT NULL
)
`
func (db *Database) checkDatabaseOwner() error {
var owner string
if !db.IgnoreForeignTables {
if db.tableExistsNoError("state_groups_state") {
return fmt.Errorf("%w (found state_groups_state, likely belonging to Synapse)", ErrForeignTables)
} else if db.tableExistsNoError("roomserver_rooms") {
return fmt.Errorf("%w (found roomserver_rooms, likely belonging to Dendrite)", ErrForeignTables)
}
}
if db.Owner == "" {
return nil
}
if _, err := db.Exec(createOwnerTable); err != nil {
return fmt.Errorf("failed to ensure database owner table exists: %w", err)
} else if err = db.QueryRow("SELECT owner FROM database_owner WHERE key=0").Scan(&owner); errors.Is(err, sql.ErrNoRows) {
_, err = db.Exec("INSERT INTO database_owner (key, owner) VALUES (0, $1)", db.Owner)
if err != nil {
return fmt.Errorf("failed to insert database owner: %w", err)
}
} else if err != nil {
return fmt.Errorf("failed to check database owner: %w", err)
} else if owner != db.Owner {
return fmt.Errorf("%w %s", ErrNotOwned, owner)
}
return nil
}
func (db *Database) setVersion(tx Execable, version, compat int) error {
_, err := tx.Exec(fmt.Sprintf("DELETE FROM %s", db.VersionTable))
if err != nil {
return err
}
_, err = tx.Exec(fmt.Sprintf("INSERT INTO %s (version, compat) VALUES ($1, $2)", db.VersionTable), version, compat)
return err
}
func (db *Database) Upgrade() error {
err := db.checkDatabaseOwner()
if err != nil {
return err
}
version, compat, err := db.getVersion()
if err != nil {
return err
}
if compat > len(db.UpgradeTable) {
if db.IgnoreUnsupportedDatabase {
db.Log.WarnUnsupportedVersion(version, compat, len(db.UpgradeTable))
return nil
}
return fmt.Errorf("%w: currently on v%d (compatible down to v%d), latest known: v%d", ErrUnsupportedDatabaseVersion, version, compat, len(db.UpgradeTable))
}
db.Log.PrepareUpgrade(version, compat, len(db.UpgradeTable))
logVersion := version
for version < len(db.UpgradeTable) {
upgradeItem := db.UpgradeTable[version]
if upgradeItem.fn == nil {
version++
continue
}
db.Log.DoUpgrade(logVersion, upgradeItem.upgradesTo, upgradeItem.message, upgradeItem.transaction)
var tx Transaction
var upgradeConn Execable
if upgradeItem.transaction {
tx, err = db.Begin()
if err != nil {
return err
}
upgradeConn = tx
} else {
upgradeConn = db
}
err = upgradeItem.fn(upgradeConn, db)
if err != nil {
return err
}
version = upgradeItem.upgradesTo
logVersion = version
err = db.setVersion(upgradeConn, version, upgradeItem.compatVersion)
if err != nil {
return err
}
if tx != nil {
err = tx.Commit()
if err != nil {
return err
}
}
}
return nil
}
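
Typical startup wiring for the upgrade machinery, assuming an UpgradeTable filled in as in the final sketch below; the owner and version-table names are placeholders:

package example

import "go.mau.fi/util/dbutil"

func initDatabase(db *dbutil.Database, table dbutil.UpgradeTable) error {
	db.Owner = "example-owner"          // enables the database_owner check in checkDatabaseOwner
	db.VersionTable = "example_version" // NewWithDB defaults this to "version"
	db.UpgradeTable = table
	return db.Upgrade()
}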


@@ -0,0 +1,283 @@
// Copyright (c) 2023 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
package dbutil
import (
"bytes"
"errors"
"fmt"
"io/fs"
"path/filepath"
"regexp"
"strconv"
"strings"
)
type UpgradeTable []upgrade
func (ut *UpgradeTable) extend(toSize int) {
if cap(*ut) >= toSize {
*ut = (*ut)[:toSize]
} else {
resized := make([]upgrade, toSize)
copy(resized, *ut)
*ut = resized
}
}
func (ut *UpgradeTable) Register(from, to, compat int, message string, txn bool, fn upgradeFunc) {
if from < 0 {
from += to
}
if from < 0 {
panic("invalid from value in UpgradeTable.Register() call")
}
if compat <= 0 {
compat = to
}
upg := upgrade{message: message, fn: fn, upgradesTo: to, compatVersion: compat, transaction: txn}
if len(*ut) == from {
*ut = append(*ut, upg)
return
} else if len(*ut) < from {
ut.extend(from + 1)
} else if (*ut)[from].fn != nil {
panic(fmt.Errorf("tried to override upgrade at %d ('%s') with '%s'", from, (*ut)[from].message, upg.message))
}
(*ut)[from] = upg
}
// Syntax is either
//
// -- v0 -> v1: Message
//
// or
//
// -- v1: Message
//
// Both syntaxes may also have a compatibility notice before the colon:
//
// -- v5 (compatible with v3+): Upgrade with backwards compatibility
var upgradeHeaderRegex = regexp.MustCompile(`^-- (?:v(\d+) -> )?v(\d+)(?: \(compatible with v(\d+)\+\))?: (.+)$`)
// To disable wrapping the upgrade in a single transaction, put `--transaction: off` on the second line.
//
// -- v5: Upgrade without transaction
// -- transaction: off
// // do dangerous stuff
var transactionDisableRegex = regexp.MustCompile(`^-- transaction: (\w*)`)
func parseFileHeader(file []byte) (from, to, compat int, message string, txn bool, lines [][]byte, err error) {
lines = bytes.Split(file, []byte("\n"))
if len(lines) < 2 {
err = errors.New("upgrade file too short")
return
}
var maybeFrom int
match := upgradeHeaderRegex.FindSubmatch(lines[0])
lines = lines[1:]
if match == nil {
err = errors.New("header not found")
} else if len(match) != 5 {
err = errors.New("unexpected number of items in regex match")
} else if maybeFrom, err = strconv.Atoi(string(match[1])); len(match[1]) > 0 && err != nil {
err = fmt.Errorf("invalid source version: %w", err)
} else if to, err = strconv.Atoi(string(match[2])); err != nil {
err = fmt.Errorf("invalid target version: %w", err)
} else if compat, err = strconv.Atoi(string(match[3])); len(match[3]) > 0 && err != nil {
err = fmt.Errorf("invalid compatible version: %w", err)
} else {
err = nil
if len(match[1]) > 0 {
from = maybeFrom
} else {
from = -1
}
message = string(match[4])
txn = true
match = transactionDisableRegex.FindSubmatch(lines[0])
if match != nil {
lines = lines[1:]
if string(match[1]) != "off" {
err = fmt.Errorf("invalid value %q for transaction flag", match[1])
}
txn = false
}
}
return
}
// To limit the next line to one dialect:
//
// -- only: postgres
//
// To limit the next N lines:
//
// -- only: sqlite for next 123 lines
//
// If the single-line limit is on the second line of the file, the whole file is limited to that dialect.
var dialectLineFilter = regexp.MustCompile(`^\s*-- only: (postgres|sqlite)(?: for next (\d+) lines| until "(end) only")?`)
// Constants used to make parseDialectFilter clearer
const (
skipUntilEndTag = -1
skipNothing = 0
skipCurrentLine = 1
skipNextLine = 2
)
func (db *Database) parseDialectFilter(line []byte) (int, error) {
match := dialectLineFilter.FindSubmatch(line)
if match == nil {
return skipNothing, nil
}
dialect, err := ParseDialect(string(match[1]))
if err != nil {
return skipNothing, err
} else if dialect == db.Dialect {
// Skip the dialect filter line
return skipCurrentLine, nil
} else if bytes.Equal(match[3], []byte("end")) {
return skipUntilEndTag, nil
} else if len(match[2]) == 0 {
// Skip the dialect filter and the next line
return skipNextLine, nil
} else {
// Parse number of lines to skip, add 1 for current line
lineCount, err := strconv.Atoi(string(match[2]))
if err != nil {
return skipNothing, fmt.Errorf("invalid line count '%s': %w", match[2], err)
}
return skipCurrentLine + lineCount, nil
}
}
var endLineFilter = regexp.MustCompile(`^\s*-- end only (postgres|sqlite)$`)
func (db *Database) filterSQLUpgrade(lines [][]byte) (string, error) {
output := make([][]byte, 0, len(lines))
for i := 0; i < len(lines); i++ {
skipLines, err := db.parseDialectFilter(lines[i])
if err != nil {
return "", err
} else if skipLines > 0 {
// Current line is implicitly skipped, so reduce one here
i += skipLines - 1
} else if skipLines == skipUntilEndTag {
startedAt := i
startedAtMatch := dialectLineFilter.FindSubmatch(lines[startedAt])
for ; i < len(lines); i++ {
if match := endLineFilter.FindSubmatch(lines[i]); match != nil {
if !bytes.Equal(match[1], startedAtMatch[1]) {
return "", fmt.Errorf(`unexpected end tag %q for %q start at line %d`, string(match[0]), string(startedAtMatch[1]), startedAt)
}
break
}
}
if i == len(lines) {
return "", fmt.Errorf(`didn't get end tag matching start %q at line %d`, string(startedAtMatch[1]), startedAt)
}
} else {
output = append(output, lines[i])
}
}
return string(bytes.Join(output, []byte("\n"))), nil
}
func sqlUpgradeFunc(fileName string, lines [][]byte) upgradeFunc {
return func(tx Execable, db *Database) error {
if skip, err := db.parseDialectFilter(lines[0]); err == nil && skip == skipNextLine {
return nil
} else if upgradeSQL, err := db.filterSQLUpgrade(lines); err != nil {
panic(fmt.Errorf("failed to parse upgrade %s: %w", fileName, err))
} else {
_, err = tx.Exec(upgradeSQL)
return err
}
}
}
func splitSQLUpgradeFunc(sqliteData, postgresData string) upgradeFunc {
return func(tx Execable, database *Database) (err error) {
switch database.Dialect {
case SQLite:
_, err = tx.Exec(sqliteData)
case Postgres:
_, err = tx.Exec(postgresData)
default:
err = fmt.Errorf("unknown dialect %s", database.Dialect)
}
return
}
}
func parseSplitSQLUpgrade(name string, fs fullFS, skipNames map[string]struct{}) (from, to, compat int, message string, txn bool, fn upgradeFunc) {
postgresName := fmt.Sprintf("%s.postgres.sql", name)
sqliteName := fmt.Sprintf("%s.sqlite.sql", name)
skipNames[postgresName] = struct{}{}
skipNames[sqliteName] = struct{}{}
postgresData, err := fs.ReadFile(postgresName)
if err != nil {
panic(err)
}
sqliteData, err := fs.ReadFile(sqliteName)
if err != nil {
panic(err)
}
from, to, compat, message, txn, _, err = parseFileHeader(postgresData)
if err != nil {
panic(fmt.Errorf("failed to parse header in %s: %w", postgresName, err))
}
sqliteFrom, sqliteTo, sqliteCompat, sqliteMessage, sqliteTxn, _, err := parseFileHeader(sqliteData)
if err != nil {
panic(fmt.Errorf("failed to parse header in %s: %w", sqliteName, err))
}
if from != sqliteFrom || to != sqliteTo || compat != sqliteCompat {
panic(fmt.Errorf("mismatching versions in postgres and sqlite versions of %s: %d/%d -> %d/%d", name, from, sqliteFrom, to, sqliteTo))
} else if message != sqliteMessage {
panic(fmt.Errorf("mismatching message in postgres and sqlite versions of %s: %q != %q", name, message, sqliteMessage))
} else if txn != sqliteTxn {
panic(fmt.Errorf("mismatching transaction flag in postgres and sqlite versions of %s: %t != %t", name, txn, sqliteTxn))
}
fn = splitSQLUpgradeFunc(string(sqliteData), string(postgresData))
return
}
type fullFS interface {
fs.ReadFileFS
fs.ReadDirFS
}
var splitFileNameRegex = regexp.MustCompile(`^(.+)\.(postgres|sqlite)\.sql$`)
func (ut *UpgradeTable) RegisterFS(fs fullFS) {
ut.RegisterFSPath(fs, ".")
}
func (ut *UpgradeTable) RegisterFSPath(fs fullFS, dir string) {
files, err := fs.ReadDir(dir)
if err != nil {
panic(err)
}
skipNames := map[string]struct{}{}
for _, file := range files {
if file.IsDir() || !strings.HasSuffix(file.Name(), ".sql") {
// do nothing
} else if _, skip := skipNames[file.Name()]; skip {
// also do nothing
} else if splitName := splitFileNameRegex.FindStringSubmatch(file.Name()); splitName != nil {
from, to, compat, message, txn, fn := parseSplitSQLUpgrade(splitName[1], fs, skipNames)
ut.Register(from, to, compat, message, txn, fn)
} else if data, err := fs.ReadFile(filepath.Join(dir, file.Name())); err != nil {
panic(err)
} else if from, to, compat, message, txn, lines, err := parseFileHeader(data); err != nil {
panic(fmt.Errorf("failed to parse header in %s: %w", file.Name(), err))
} else {
ut.Register(from, to, compat, message, txn, sqlUpgradeFunc(file.Name(), lines))
}
}
}
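
Finally, a sketch of filling an UpgradeTable: loading embedded SQL upgrade files through RegisterFS (embed.FS satisfies the fullFS requirement since it implements both fs.ReadFileFS and fs.ReadDirFS) and registering a Go-based upgrade alongside them. File names, versions and the example statement are illustrative:

package example

import (
	"embed"

	"go.mau.fi/util/dbutil"
)

// Assumes *.sql upgrade files with "-- vN: message" headers sit next to this
// source file; split files like 02-foo.postgres.sql / 02-foo.sqlite.sql are
// paired automatically by RegisterFSPath.
//go:embed *.sql
var rawUpgrades embed.FS

var UpgradeTable dbutil.UpgradeTable

func init() {
	UpgradeTable.RegisterFS(rawUpgrades)

	// A Go-based upgrade can be registered for a version not covered by the SQL
	// files; compat <= 0 falls back to the target version and txn=true runs the
	// upgrade inside a single transaction.
	UpgradeTable.Register(3, 4, 0, "Add extra column to foo", true, func(tx dbutil.Execable, db *dbutil.Database) error {
		_, err := tx.Exec("ALTER TABLE foo ADD COLUMN extra TEXT")
		return err
	})
}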