refactor to mautrix 0.17.x; update deps

This commit is contained in:
Aine
2024-02-11 20:47:04 +02:00
parent 0a9701f4c9
commit dd0ad4c245
237 changed files with 9091 additions and 3317 deletions

View File

@@ -9,6 +9,9 @@ package dbutil
import (
"context"
"database/sql"
"errors"
"strconv"
"strings"
"time"
)
@@ -19,18 +22,61 @@ type LoggingExecable struct {
db *Database
}
func (le *LoggingExecable) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) {
type pqError interface {
Get(k byte) string
}
type PQErrorWithLine struct {
Underlying error
Line string
}
func (pqe *PQErrorWithLine) Error() string {
return pqe.Underlying.Error()
}
func (pqe *PQErrorWithLine) Unwrap() error {
return pqe.Underlying
}
func addErrorLine(query string, err error) error {
if err == nil {
return err
}
var pqe pqError
if !errors.As(err, &pqe) {
return err
}
pos, _ := strconv.Atoi(pqe.Get('P'))
pos--
if pos <= 0 {
return err
}
lines := strings.Split(query, "\n")
for _, line := range lines {
lineRunes := []rune(line)
if pos < len(lineRunes)+1 {
return &PQErrorWithLine{Underlying: err, Line: line}
}
pos -= len(lineRunes) + 1
}
return err
}
func (le *LoggingExecable) ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error) {
start := time.Now()
query = le.db.mutateQuery(query)
res, err := le.UnderlyingExecable.ExecContext(ctx, query, args...)
err = addErrorLine(query, err)
le.db.Log.QueryTiming(ctx, "Exec", query, args, -1, time.Since(start), err)
return res, err
}
func (le *LoggingExecable) QueryContext(ctx context.Context, query string, args ...interface{}) (Rows, error) {
func (le *LoggingExecable) QueryContext(ctx context.Context, query string, args ...any) (Rows, error) {
start := time.Now()
query = le.db.mutateQuery(query)
rows, err := le.UnderlyingExecable.QueryContext(ctx, query, args...)
err = addErrorLine(query, err)
le.db.Log.QueryTiming(ctx, "Query", query, args, -1, time.Since(start), err)
return &LoggingRows{
ctx: ctx,
@@ -42,7 +88,7 @@ func (le *LoggingExecable) QueryContext(ctx context.Context, query string, args
}, err
}
func (le *LoggingExecable) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row {
func (le *LoggingExecable) QueryRowContext(ctx context.Context, query string, args ...any) *sql.Row {
start := time.Now()
query = le.db.mutateQuery(query)
row := le.UnderlyingExecable.QueryRowContext(ctx, query, args...)
@@ -50,18 +96,6 @@ func (le *LoggingExecable) QueryRowContext(ctx context.Context, query string, ar
return row
}
func (le *LoggingExecable) Exec(query string, args ...interface{}) (sql.Result, error) {
return le.ExecContext(context.Background(), query, args...)
}
func (le *LoggingExecable) Query(query string, args ...interface{}) (Rows, error) {
return le.QueryContext(context.Background(), query, args...)
}
func (le *LoggingExecable) QueryRow(query string, args ...interface{}) *sql.Row {
return le.QueryRowContext(context.Background(), query, args...)
}
// loggingDB is a wrapper for LoggingExecable that allows access to BeginTx.
//
// While LoggingExecable has a pointer to the database and could use BeginTx, it's not technically safe since
@@ -89,10 +123,6 @@ func (ld *loggingDB) BeginTx(ctx context.Context, opts *sql.TxOptions) (*Logging
}, nil
}
func (ld *loggingDB) Begin() (*LoggingTxn, error) {
return ld.BeginTx(context.Background(), nil)
}
type LoggingTxn struct {
LoggingExecable
UnderlyingTx *sql.Tx
@@ -129,7 +159,7 @@ type LoggingRows struct {
ctx context.Context
db *Database
query string
args []interface{}
args []any
rs Rows
start time.Time
nrows int

View File

@@ -58,7 +58,7 @@ type Rows interface {
}
type Scannable interface {
Scan(...interface{}) error
Scan(...any) error
}
// Expected implementations of Scannable
@@ -67,30 +67,16 @@ var (
_ Scannable = (Rows)(nil)
)
type UnderlyingContextExecable interface {
ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error)
QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error)
QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row
}
type ContextExecable interface {
ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error)
QueryContext(ctx context.Context, query string, args ...interface{}) (Rows, error)
QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row
}
type UnderlyingExecable interface {
UnderlyingContextExecable
Exec(query string, args ...interface{}) (sql.Result, error)
Query(query string, args ...interface{}) (*sql.Rows, error)
QueryRow(query string, args ...interface{}) *sql.Row
ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error)
QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error)
QueryRowContext(ctx context.Context, query string, args ...any) *sql.Row
}
type Execable interface {
ContextExecable
Exec(query string, args ...interface{}) (sql.Result, error)
Query(query string, args ...interface{}) (Rows, error)
QueryRow(query string, args ...interface{}) *sql.Row
ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error)
QueryContext(ctx context.Context, query string, args ...any) (Rows, error)
QueryRowContext(ctx context.Context, query string, args ...any) *sql.Row
}
type Transaction interface {
@@ -101,15 +87,15 @@ type Transaction interface {
// Expected implementations of Execable
var (
_ UnderlyingExecable = (*sql.Tx)(nil)
_ UnderlyingExecable = (*sql.DB)(nil)
_ Execable = (*LoggingExecable)(nil)
_ Transaction = (*LoggingTxn)(nil)
_ UnderlyingContextExecable = (*sql.Conn)(nil)
_ UnderlyingExecable = (*sql.Tx)(nil)
_ UnderlyingExecable = (*sql.DB)(nil)
_ UnderlyingExecable = (*sql.Conn)(nil)
_ Execable = (*LoggingExecable)(nil)
_ Transaction = (*LoggingTxn)(nil)
)
type Database struct {
loggingDB
LoggingDB loggingDB
RawDB *sql.DB
ReadOnlyDB *sql.DB
Owner string
@@ -139,7 +125,7 @@ func (db *Database) Child(versionTable string, upgradeTable UpgradeTable, log Da
}
return &Database{
RawDB: db.RawDB,
loggingDB: db.loggingDB,
LoggingDB: db.LoggingDB,
Owner: "",
VersionTable: versionTable,
UpgradeTable: upgradeTable,
@@ -164,8 +150,8 @@ func NewWithDB(db *sql.DB, rawDialect string) (*Database, error) {
IgnoreForeignTables: true,
VersionTable: "version",
}
wrappedDB.loggingDB.UnderlyingExecable = db
wrappedDB.loggingDB.db = wrappedDB
wrappedDB.LoggingDB.UnderlyingExecable = db
wrappedDB.LoggingDB.db = wrappedDB
return wrappedDB, nil
}
@@ -259,7 +245,7 @@ func NewFromConfig(owner string, cfg Config, logger DatabaseLogger) (*Database,
if roUri == "" {
uriParts := strings.Split(cfg.URI, "?")
var qs url.Values
qs := url.Values{}
if len(uriParts) == 2 {
var err error
qs, err = url.ParseQuery(uriParts[1])

72
vendor/go.mau.fi/util/dbutil/iter.go vendored Normal file
View File

@@ -0,0 +1,72 @@
// Copyright (c) 2023 Sumner Evans
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
package dbutil
// RowIter is a wrapper for [Rows] that allows conveniently iterating over rows
// with a predefined scanner function.
type RowIter[T any] interface {
// Iter iterates over the rows and calls the given function for each row.
//
// If the function returns false, the iteration is stopped.
// If the function returns an error, the iteration is stopped and the error is
// returned.
Iter(func(T) (bool, error)) error
// AsList collects all rows into a slice.
AsList() ([]T, error)
}
type rowIterImpl[T any] struct {
Rows
ConvertRow func(Scannable) (T, error)
}
// NewRowIter creates a new RowIter from the given Rows and scanner function.
func NewRowIter[T any](rows Rows, convertFn func(Scannable) (T, error)) RowIter[T] {
return &rowIterImpl[T]{Rows: rows, ConvertRow: convertFn}
}
func ScanSingleColumn[T any](rows Scannable) (val T, err error) {
err = rows.Scan(&val)
return
}
type NewableDataStruct[T any] interface {
DataStruct[T]
New() T
}
func ScanDataStruct[T NewableDataStruct[T]](rows Scannable) (T, error) {
var val T
return val.New().Scan(rows)
}
func (i *rowIterImpl[T]) Iter(fn func(T) (bool, error)) error {
if i == nil || i.Rows == nil {
return nil
}
defer i.Rows.Close()
for i.Rows.Next() {
if item, err := i.ConvertRow(i.Rows); err != nil {
return err
} else if cont, err := fn(item); err != nil {
return err
} else if !cont {
break
}
}
return i.Rows.Err()
}
func (i *rowIterImpl[T]) AsList() (list []T, err error) {
err = i.Iter(func(item T) (bool, error) {
list = append(list, item)
return true, nil
})
return
}

View File

@@ -10,12 +10,12 @@ import (
)
type DatabaseLogger interface {
QueryTiming(ctx context.Context, method, query string, args []interface{}, nrows int, duration time.Duration, err error)
QueryTiming(ctx context.Context, method, query string, args []any, nrows int, duration time.Duration, err error)
WarnUnsupportedVersion(current, compat, latest int)
PrepareUpgrade(current, compat, latest int)
DoUpgrade(from, to int, message string, txn bool)
// Deprecated: legacy warning method, return errors instead
Warn(msg string, args ...interface{})
Warn(msg string, args ...any)
}
type noopLogger struct{}
@@ -25,9 +25,9 @@ var NoopLogger DatabaseLogger = &noopLogger{}
func (n noopLogger) WarnUnsupportedVersion(_, _, _ int) {}
func (n noopLogger) PrepareUpgrade(_, _, _ int) {}
func (n noopLogger) DoUpgrade(_, _ int, _ string, _ bool) {}
func (n noopLogger) Warn(msg string, args ...interface{}) {}
func (n noopLogger) Warn(msg string, args ...any) {}
func (n noopLogger) QueryTiming(_ context.Context, _, _ string, _ []interface{}, _ int, _ time.Duration, _ error) {
func (n noopLogger) QueryTiming(_ context.Context, _, _ string, _ []any, _ int, _ time.Duration, _ error) {
}
type zeroLogger struct {
@@ -92,7 +92,7 @@ func (z zeroLogger) DoUpgrade(from, to int, message string, txn bool) {
var whitespaceRegex = regexp.MustCompile(`\s+`)
func (z zeroLogger) QueryTiming(ctx context.Context, method, query string, args []interface{}, nrows int, duration time.Duration, err error) {
func (z zeroLogger) QueryTiming(ctx context.Context, method, query string, args []any, nrows int, duration time.Duration, err error) {
log := zerolog.Ctx(ctx)
if log.GetLevel() == zerolog.Disabled || log == zerolog.DefaultContextLogger {
log = z.l
@@ -124,6 +124,6 @@ func (z zeroLogger) QueryTiming(ctx context.Context, method, query string, args
}
}
func (z zeroLogger) Warn(msg string, args ...interface{}) {
z.l.Warn().Msgf(msg, args...)
func (z zeroLogger) Warn(msg string, args ...any) {
z.l.Warn().Msgf(msg, args...) // zerolog-allow-msgf
}

View File

@@ -0,0 +1,105 @@
// Copyright (c) 2023 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
package dbutil
import (
"context"
"database/sql"
"errors"
"golang.org/x/exp/constraints"
)
// DataStruct is an interface for structs that represent a single database row.
type DataStruct[T any] interface {
Scan(row Scannable) (T, error)
}
// QueryHelper is a generic helper struct for SQL query execution boilerplate.
//
// After implementing the Scan and Init methods in a data struct, the query
// helper allows writing query functions in a single line.
type QueryHelper[T DataStruct[T]] struct {
db *Database
newFunc func(qh *QueryHelper[T]) T
}
func MakeQueryHelper[T DataStruct[T]](db *Database, new func(qh *QueryHelper[T]) T) *QueryHelper[T] {
return &QueryHelper[T]{db: db, newFunc: new}
}
// ValueOrErr is a helper function that returns the value if err is nil, or
// returns nil and the error if err is not nil. It can be used to avoid
// `if err != nil { return nil, err }` boilerplate in certain cases like
// DataStruct.Scan implementations.
func ValueOrErr[T any](val *T, err error) (*T, error) {
if err != nil {
return nil, err
}
return val, nil
}
// StrPtr returns a pointer to the given string, or nil if the string is empty.
func StrPtr[T ~string](val T) *string {
if val == "" {
return nil
}
strVal := string(val)
return &strVal
}
// NumPtr returns a pointer to the given number, or nil if the number is zero.
func NumPtr[T constraints.Integer | constraints.Float](val T) *T {
if val == 0 {
return nil
}
return &val
}
func (qh *QueryHelper[T]) GetDB() *Database {
return qh.db
}
func (qh *QueryHelper[T]) New() T {
return qh.newFunc(qh)
}
// Exec executes a query with ExecContext and returns the error.
//
// It omits the sql.Result return value, as it is rarely used. When the result
// is wanted, use `qh.GetDB().Exec(...)` instead, which is
// otherwise equivalent.
func (qh *QueryHelper[T]) Exec(ctx context.Context, query string, args ...any) error {
_, err := qh.db.Exec(ctx, query, args...)
return err
}
func (qh *QueryHelper[T]) scanNew(row Scannable) (T, error) {
return qh.New().Scan(row)
}
// QueryOne executes a query with QueryRowContext, uses the associated DataStruct
// to scan it, and returns the value. If the query returns no rows, it returns nil
// and no error.
func (qh *QueryHelper[T]) QueryOne(ctx context.Context, query string, args ...any) (val T, err error) {
val, err = qh.scanNew(qh.db.QueryRow(ctx, query, args...))
if errors.Is(err, sql.ErrNoRows) {
err = nil
}
return val, err
}
// QueryMany executes a query with QueryContext, uses the associated DataStruct
// to scan each row, and returns the values. If the query returns no rows, it
// returns a non-nil zero-length slice and no error.
func (qh *QueryHelper[T]) QueryMany(ctx context.Context, query string, args ...any) ([]T, error) {
rows, err := qh.db.Query(ctx, query, args...)
if err != nil {
return nil, err
}
return NewRowIter(rows, qh.scanNew).AsList()
}

View File

@@ -32,6 +32,22 @@ const (
ContextKeyDoTxnCallerSkip
)
func (db *Database) Exec(ctx context.Context, query string, args ...any) (sql.Result, error) {
return db.Conn(ctx).ExecContext(ctx, query, args...)
}
func (db *Database) Query(ctx context.Context, query string, args ...any) (Rows, error) {
return db.Conn(ctx).QueryContext(ctx, query, args...)
}
func (db *Database) QueryRow(ctx context.Context, query string, args ...any) *sql.Row {
return db.Conn(ctx).QueryRowContext(ctx, query, args...)
}
func (db *Database) BeginTx(ctx context.Context, opts *sql.TxOptions) (*LoggingTxn, error) {
return db.LoggingDB.BeginTx(ctx, opts)
}
func (db *Database) DoTxn(ctx context.Context, opts *sql.TxOptions, fn func(ctx context.Context) error) error {
if ctx.Value(ContextKeyDatabaseTransaction) != nil {
zerolog.Ctx(ctx).Trace().Msg("Already in a transaction, not creating a new one")
@@ -82,13 +98,13 @@ func (db *Database) DoTxn(ctx context.Context, opts *sql.TxOptions, fn func(ctx
return nil
}
func (db *Database) Conn(ctx context.Context) ContextExecable {
func (db *Database) Conn(ctx context.Context) Execable {
if ctx == nil {
return db
return &db.LoggingDB
}
txn, ok := ctx.Value(ContextKeyDatabaseTransaction).(Transaction)
if ok {
return txn
}
return db
return &db.LoggingDB
}

View File

@@ -7,12 +7,13 @@
package dbutil
import (
"context"
"database/sql"
"errors"
"fmt"
)
type upgradeFunc func(Execable, *Database) error
type upgradeFunc func(context.Context, *Database) error
type upgrade struct {
message string
@@ -28,19 +29,19 @@ var ErrForeignTables = errors.New("the database contains foreign tables")
var ErrNotOwned = errors.New("the database is owned by")
var ErrUnsupportedDialect = errors.New("unsupported database dialect")
func (db *Database) upgradeVersionTable() error {
if compatColumnExists, err := db.ColumnExists(nil, db.VersionTable, "compat"); err != nil {
func (db *Database) upgradeVersionTable(ctx context.Context) error {
if compatColumnExists, err := db.ColumnExists(ctx, db.VersionTable, "compat"); err != nil {
return fmt.Errorf("failed to check if version table is up to date: %w", err)
} else if !compatColumnExists {
if tableExists, err := db.TableExists(nil, db.VersionTable); err != nil {
if tableExists, err := db.TableExists(ctx, db.VersionTable); err != nil {
return fmt.Errorf("failed to check if version table exists: %w", err)
} else if !tableExists {
_, err = db.Exec(fmt.Sprintf("CREATE TABLE %s (version INTEGER, compat INTEGER)", db.VersionTable))
_, err = db.Exec(ctx, fmt.Sprintf("CREATE TABLE %s (version INTEGER, compat INTEGER)", db.VersionTable))
if err != nil {
return fmt.Errorf("failed to create version table: %w", err)
}
} else {
_, err = db.Exec(fmt.Sprintf("ALTER TABLE %s ADD COLUMN compat INTEGER", db.VersionTable))
_, err = db.Exec(ctx, fmt.Sprintf("ALTER TABLE %s ADD COLUMN compat INTEGER", db.VersionTable))
if err != nil {
return fmt.Errorf("failed to add compat column to version table: %w", err)
}
@@ -49,13 +50,13 @@ func (db *Database) upgradeVersionTable() error {
return nil
}
func (db *Database) getVersion() (version, compat int, err error) {
if err = db.upgradeVersionTable(); err != nil {
func (db *Database) getVersion(ctx context.Context) (version, compat int, err error) {
if err = db.upgradeVersionTable(ctx); err != nil {
return
}
var compatNull sql.NullInt32
err = db.QueryRow(fmt.Sprintf("SELECT version, compat FROM %s LIMIT 1", db.VersionTable)).Scan(&version, &compatNull)
err = db.QueryRow(ctx, fmt.Sprintf("SELECT version, compat FROM %s LIMIT 1", db.VersionTable)).Scan(&version, &compatNull)
if errors.Is(err, sql.ErrNoRows) {
err = nil
}
@@ -72,15 +73,12 @@ const (
tableExistsSQLite = "SELECT EXISTS(SELECT 1 FROM sqlite_master WHERE type='table' AND tbl_name=?1)"
)
func (db *Database) TableExists(tx Execable, table string) (exists bool, err error) {
if tx == nil {
tx = db
}
func (db *Database) TableExists(ctx context.Context, table string) (exists bool, err error) {
switch db.Dialect {
case SQLite:
err = db.QueryRow(tableExistsSQLite, table).Scan(&exists)
err = db.QueryRow(ctx, tableExistsSQLite, table).Scan(&exists)
case Postgres:
err = db.QueryRow(tableExistsPostgres, table).Scan(&exists)
err = db.QueryRow(ctx, tableExistsPostgres, table).Scan(&exists)
default:
err = ErrUnsupportedDialect
}
@@ -92,23 +90,20 @@ const (
columnExistsSQLite = "SELECT EXISTS(SELECT 1 FROM pragma_table_info(?1) WHERE name=?2)"
)
func (db *Database) ColumnExists(tx Execable, table, column string) (exists bool, err error) {
if tx == nil {
tx = db
}
func (db *Database) ColumnExists(ctx context.Context, table, column string) (exists bool, err error) {
switch db.Dialect {
case SQLite:
err = db.QueryRow(columnExistsSQLite, table, column).Scan(&exists)
err = db.QueryRow(ctx, columnExistsSQLite, table, column).Scan(&exists)
case Postgres:
err = db.QueryRow(columnExistsPostgres, table, column).Scan(&exists)
err = db.QueryRow(ctx, columnExistsPostgres, table, column).Scan(&exists)
default:
err = ErrUnsupportedDialect
}
return
}
func (db *Database) tableExistsNoError(table string) bool {
exists, err := db.TableExists(nil, table)
func (db *Database) tableExistsNoError(ctx context.Context, table string) bool {
exists, err := db.TableExists(ctx, table)
if err != nil {
panic(fmt.Errorf("failed to check if table exists: %w", err))
}
@@ -122,22 +117,22 @@ CREATE TABLE IF NOT EXISTS database_owner (
)
`
func (db *Database) checkDatabaseOwner() error {
func (db *Database) checkDatabaseOwner(ctx context.Context) error {
var owner string
if !db.IgnoreForeignTables {
if db.tableExistsNoError("state_groups_state") {
if db.tableExistsNoError(ctx, "state_groups_state") {
return fmt.Errorf("%w (found state_groups_state, likely belonging to Synapse)", ErrForeignTables)
} else if db.tableExistsNoError("roomserver_rooms") {
} else if db.tableExistsNoError(ctx, "roomserver_rooms") {
return fmt.Errorf("%w (found roomserver_rooms, likely belonging to Dendrite)", ErrForeignTables)
}
}
if db.Owner == "" {
return nil
}
if _, err := db.Exec(createOwnerTable); err != nil {
if _, err := db.Exec(ctx, createOwnerTable); err != nil {
return fmt.Errorf("failed to ensure database owner table exists: %w", err)
} else if err = db.QueryRow("SELECT owner FROM database_owner WHERE key=0").Scan(&owner); errors.Is(err, sql.ErrNoRows) {
_, err = db.Exec("INSERT INTO database_owner (key, owner) VALUES (0, $1)", db.Owner)
} else if err = db.QueryRow(ctx, "SELECT owner FROM database_owner WHERE key=0").Scan(&owner); errors.Is(err, sql.ErrNoRows) {
_, err = db.Exec(ctx, "INSERT INTO database_owner (key, owner) VALUES (0, $1)", db.Owner)
if err != nil {
return fmt.Errorf("failed to insert database owner: %w", err)
}
@@ -149,22 +144,22 @@ func (db *Database) checkDatabaseOwner() error {
return nil
}
func (db *Database) setVersion(tx Execable, version, compat int) error {
_, err := tx.Exec(fmt.Sprintf("DELETE FROM %s", db.VersionTable))
func (db *Database) setVersion(ctx context.Context, version, compat int) error {
_, err := db.Exec(ctx, fmt.Sprintf("DELETE FROM %s", db.VersionTable))
if err != nil {
return err
}
_, err = tx.Exec(fmt.Sprintf("INSERT INTO %s (version, compat) VALUES ($1, $2)", db.VersionTable), version, compat)
_, err = db.Exec(ctx, fmt.Sprintf("INSERT INTO %s (version, compat) VALUES ($1, $2)", db.VersionTable), version, compat)
return err
}
func (db *Database) Upgrade() error {
err := db.checkDatabaseOwner()
func (db *Database) Upgrade(ctx context.Context) error {
err := db.checkDatabaseOwner(ctx)
if err != nil {
return err
}
version, compat, err := db.getVersion()
version, compat, err := db.getVersion(ctx)
if err != nil {
return err
}
@@ -185,34 +180,28 @@ func (db *Database) Upgrade() error {
version++
continue
}
doUpgrade := func(ctx context.Context) error {
err = upgradeItem.fn(ctx, db)
if err != nil {
return fmt.Errorf("failed to run upgrade #%d: %w", version, err)
}
version = upgradeItem.upgradesTo
logVersion = version
err = db.setVersion(ctx, version, upgradeItem.compatVersion)
if err != nil {
return err
}
return nil
}
db.Log.DoUpgrade(logVersion, upgradeItem.upgradesTo, upgradeItem.message, upgradeItem.transaction)
var tx Transaction
var upgradeConn Execable
if upgradeItem.transaction {
tx, err = db.Begin()
if err != nil {
return err
}
upgradeConn = tx
err = db.DoTxn(ctx, nil, doUpgrade)
} else {
upgradeConn = db
err = doUpgrade(ctx)
}
err = upgradeItem.fn(upgradeConn, db)
if err != nil {
return err
}
version = upgradeItem.upgradesTo
logVersion = version
err = db.setVersion(upgradeConn, version, upgradeItem.compatVersion)
if err != nil {
return err
}
if tx != nil {
err = tx.Commit()
if err != nil {
return err
}
}
}
return nil
}

View File

@@ -8,6 +8,7 @@ package dbutil
import (
"bytes"
"context"
"errors"
"fmt"
"io/fs"
@@ -189,27 +190,27 @@ func (db *Database) filterSQLUpgrade(lines [][]byte) (string, error) {
}
func sqlUpgradeFunc(fileName string, lines [][]byte) upgradeFunc {
return func(tx Execable, db *Database) error {
return func(ctx context.Context, db *Database) error {
if skip, err := db.parseDialectFilter(lines[0]); err == nil && skip == skipNextLine {
return nil
} else if upgradeSQL, err := db.filterSQLUpgrade(lines); err != nil {
panic(fmt.Errorf("failed to parse upgrade %s: %w", fileName, err))
} else {
_, err = tx.Exec(upgradeSQL)
_, err = db.Exec(ctx, upgradeSQL)
return err
}
}
}
func splitSQLUpgradeFunc(sqliteData, postgresData string) upgradeFunc {
return func(tx Execable, database *Database) (err error) {
switch database.Dialect {
return func(ctx context.Context, db *Database) (err error) {
switch db.Dialect {
case SQLite:
_, err = tx.Exec(sqliteData)
_, err = db.Exec(ctx, sqliteData)
case Postgres:
_, err = tx.Exec(postgresData)
_, err = db.Exec(ctx, postgresData)
default:
err = fmt.Errorf("unknown dialect %s", database.Dialect)
err = fmt.Errorf("unknown dialect %s", db.Dialect)
}
return
}

23
vendor/go.mau.fi/util/exerrors/must.go vendored Normal file
View File

@@ -0,0 +1,23 @@
// Copyright (c) 2024 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
package exerrors
func Must[T any](val T, err error) T {
PanicIfNotNil(err)
return val
}
func Must2[T any, T2 any](val T, val2 T2, err error) (T, T2) {
PanicIfNotNil(err)
return val, val2
}
func PanicIfNotNil(err error) {
if err != nil {
panic(err)
}
}

View File

@@ -7,10 +7,16 @@
package jsontime
import (
"database/sql"
"database/sql/driver"
"encoding/json"
"errors"
"fmt"
"time"
)
var ErrNotInteger = errors.New("value is not an integer")
func parseTime(data []byte, unixConv func(int64) time.Time, into *time.Time) error {
var val int64
err := json.Unmarshal(data, &val)
@@ -25,6 +31,28 @@ func parseTime(data []byte, unixConv func(int64) time.Time, into *time.Time) err
return nil
}
func anyIntegerToTime(src any, unixConv func(int64) time.Time, into *time.Time) error {
switch v := src.(type) {
case int:
*into = unixConv(int64(v))
case int8:
*into = unixConv(int64(v))
case int16:
*into = unixConv(int64(v))
case int32:
*into = unixConv(int64(v))
case int64:
*into = unixConv(int64(v))
default:
return fmt.Errorf("%w: %T", ErrNotInteger, src)
}
return nil
}
var _ sql.Scanner = &UnixMilli{}
var _ driver.Valuer = UnixMilli{}
type UnixMilli struct {
time.Time
}
@@ -40,6 +68,17 @@ func (um *UnixMilli) UnmarshalJSON(data []byte) error {
return parseTime(data, time.UnixMilli, &um.Time)
}
func (um UnixMilli) Value() (driver.Value, error) {
return um.UnixMilli(), nil
}
func (um *UnixMilli) Scan(src any) error {
return anyIntegerToTime(src, time.UnixMilli, &um.Time)
}
var _ sql.Scanner = &UnixMicro{}
var _ driver.Valuer = UnixMicro{}
type UnixMicro struct {
time.Time
}
@@ -55,6 +94,17 @@ func (um *UnixMicro) UnmarshalJSON(data []byte) error {
return parseTime(data, time.UnixMicro, &um.Time)
}
func (um UnixMicro) Value() (driver.Value, error) {
return um.UnixMicro(), nil
}
func (um *UnixMicro) Scan(src any) error {
return anyIntegerToTime(src, time.UnixMicro, &um.Time)
}
var _ sql.Scanner = &UnixNano{}
var _ driver.Valuer = UnixNano{}
type UnixNano struct {
time.Time
}
@@ -72,6 +122,16 @@ func (un *UnixNano) UnmarshalJSON(data []byte) error {
}, &un.Time)
}
func (un UnixNano) Value() (driver.Value, error) {
return un.UnixNano(), nil
}
func (un *UnixNano) Scan(src any) error {
return anyIntegerToTime(src, func(i int64) time.Time {
return time.Unix(0, i)
}, &un.Time)
}
type Unix struct {
time.Time
}
@@ -83,8 +143,21 @@ func (u Unix) MarshalJSON() ([]byte, error) {
return json.Marshal(u.Unix())
}
var _ sql.Scanner = &Unix{}
var _ driver.Valuer = Unix{}
func (u *Unix) UnmarshalJSON(data []byte) error {
return parseTime(data, func(i int64) time.Time {
return time.Unix(i, 0)
}, &u.Time)
}
func (u Unix) Value() (driver.Value, error) {
return u.Unix(), nil
}
func (u *Unix) Scan(src any) error {
return anyIntegerToTime(src, func(i int64) time.Time {
return time.Unix(i, 0)
}, &u.Time)
}