refactor to mautrix 0.17.x; update deps
This commit is contained in:
70
vendor/go.mau.fi/util/dbutil/connlog.go
vendored
70
vendor/go.mau.fi/util/dbutil/connlog.go
vendored
@@ -9,6 +9,9 @@ package dbutil
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"strconv"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
@@ -19,18 +22,61 @@ type LoggingExecable struct {
|
||||
db *Database
|
||||
}
|
||||
|
||||
func (le *LoggingExecable) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) {
|
||||
type pqError interface {
|
||||
Get(k byte) string
|
||||
}
|
||||
|
||||
type PQErrorWithLine struct {
|
||||
Underlying error
|
||||
Line string
|
||||
}
|
||||
|
||||
func (pqe *PQErrorWithLine) Error() string {
|
||||
return pqe.Underlying.Error()
|
||||
}
|
||||
|
||||
func (pqe *PQErrorWithLine) Unwrap() error {
|
||||
return pqe.Underlying
|
||||
}
|
||||
|
||||
func addErrorLine(query string, err error) error {
|
||||
if err == nil {
|
||||
return err
|
||||
}
|
||||
var pqe pqError
|
||||
if !errors.As(err, &pqe) {
|
||||
return err
|
||||
}
|
||||
pos, _ := strconv.Atoi(pqe.Get('P'))
|
||||
pos--
|
||||
if pos <= 0 {
|
||||
return err
|
||||
}
|
||||
lines := strings.Split(query, "\n")
|
||||
for _, line := range lines {
|
||||
lineRunes := []rune(line)
|
||||
if pos < len(lineRunes)+1 {
|
||||
return &PQErrorWithLine{Underlying: err, Line: line}
|
||||
}
|
||||
pos -= len(lineRunes) + 1
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (le *LoggingExecable) ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error) {
|
||||
start := time.Now()
|
||||
query = le.db.mutateQuery(query)
|
||||
res, err := le.UnderlyingExecable.ExecContext(ctx, query, args...)
|
||||
err = addErrorLine(query, err)
|
||||
le.db.Log.QueryTiming(ctx, "Exec", query, args, -1, time.Since(start), err)
|
||||
return res, err
|
||||
}
|
||||
|
||||
func (le *LoggingExecable) QueryContext(ctx context.Context, query string, args ...interface{}) (Rows, error) {
|
||||
func (le *LoggingExecable) QueryContext(ctx context.Context, query string, args ...any) (Rows, error) {
|
||||
start := time.Now()
|
||||
query = le.db.mutateQuery(query)
|
||||
rows, err := le.UnderlyingExecable.QueryContext(ctx, query, args...)
|
||||
err = addErrorLine(query, err)
|
||||
le.db.Log.QueryTiming(ctx, "Query", query, args, -1, time.Since(start), err)
|
||||
return &LoggingRows{
|
||||
ctx: ctx,
|
||||
@@ -42,7 +88,7 @@ func (le *LoggingExecable) QueryContext(ctx context.Context, query string, args
|
||||
}, err
|
||||
}
|
||||
|
||||
func (le *LoggingExecable) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row {
|
||||
func (le *LoggingExecable) QueryRowContext(ctx context.Context, query string, args ...any) *sql.Row {
|
||||
start := time.Now()
|
||||
query = le.db.mutateQuery(query)
|
||||
row := le.UnderlyingExecable.QueryRowContext(ctx, query, args...)
|
||||
@@ -50,18 +96,6 @@ func (le *LoggingExecable) QueryRowContext(ctx context.Context, query string, ar
|
||||
return row
|
||||
}
|
||||
|
||||
func (le *LoggingExecable) Exec(query string, args ...interface{}) (sql.Result, error) {
|
||||
return le.ExecContext(context.Background(), query, args...)
|
||||
}
|
||||
|
||||
func (le *LoggingExecable) Query(query string, args ...interface{}) (Rows, error) {
|
||||
return le.QueryContext(context.Background(), query, args...)
|
||||
}
|
||||
|
||||
func (le *LoggingExecable) QueryRow(query string, args ...interface{}) *sql.Row {
|
||||
return le.QueryRowContext(context.Background(), query, args...)
|
||||
}
|
||||
|
||||
// loggingDB is a wrapper for LoggingExecable that allows access to BeginTx.
|
||||
//
|
||||
// While LoggingExecable has a pointer to the database and could use BeginTx, it's not technically safe since
|
||||
@@ -89,10 +123,6 @@ func (ld *loggingDB) BeginTx(ctx context.Context, opts *sql.TxOptions) (*Logging
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (ld *loggingDB) Begin() (*LoggingTxn, error) {
|
||||
return ld.BeginTx(context.Background(), nil)
|
||||
}
|
||||
|
||||
type LoggingTxn struct {
|
||||
LoggingExecable
|
||||
UnderlyingTx *sql.Tx
|
||||
@@ -129,7 +159,7 @@ type LoggingRows struct {
|
||||
ctx context.Context
|
||||
db *Database
|
||||
query string
|
||||
args []interface{}
|
||||
args []any
|
||||
rs Rows
|
||||
start time.Time
|
||||
nrows int
|
||||
|
||||
48
vendor/go.mau.fi/util/dbutil/database.go
vendored
48
vendor/go.mau.fi/util/dbutil/database.go
vendored
@@ -58,7 +58,7 @@ type Rows interface {
|
||||
}
|
||||
|
||||
type Scannable interface {
|
||||
Scan(...interface{}) error
|
||||
Scan(...any) error
|
||||
}
|
||||
|
||||
// Expected implementations of Scannable
|
||||
@@ -67,30 +67,16 @@ var (
|
||||
_ Scannable = (Rows)(nil)
|
||||
)
|
||||
|
||||
type UnderlyingContextExecable interface {
|
||||
ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error)
|
||||
QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error)
|
||||
QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row
|
||||
}
|
||||
|
||||
type ContextExecable interface {
|
||||
ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error)
|
||||
QueryContext(ctx context.Context, query string, args ...interface{}) (Rows, error)
|
||||
QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row
|
||||
}
|
||||
|
||||
type UnderlyingExecable interface {
|
||||
UnderlyingContextExecable
|
||||
Exec(query string, args ...interface{}) (sql.Result, error)
|
||||
Query(query string, args ...interface{}) (*sql.Rows, error)
|
||||
QueryRow(query string, args ...interface{}) *sql.Row
|
||||
ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error)
|
||||
QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error)
|
||||
QueryRowContext(ctx context.Context, query string, args ...any) *sql.Row
|
||||
}
|
||||
|
||||
type Execable interface {
|
||||
ContextExecable
|
||||
Exec(query string, args ...interface{}) (sql.Result, error)
|
||||
Query(query string, args ...interface{}) (Rows, error)
|
||||
QueryRow(query string, args ...interface{}) *sql.Row
|
||||
ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error)
|
||||
QueryContext(ctx context.Context, query string, args ...any) (Rows, error)
|
||||
QueryRowContext(ctx context.Context, query string, args ...any) *sql.Row
|
||||
}
|
||||
|
||||
type Transaction interface {
|
||||
@@ -101,15 +87,15 @@ type Transaction interface {
|
||||
|
||||
// Expected implementations of Execable
|
||||
var (
|
||||
_ UnderlyingExecable = (*sql.Tx)(nil)
|
||||
_ UnderlyingExecable = (*sql.DB)(nil)
|
||||
_ Execable = (*LoggingExecable)(nil)
|
||||
_ Transaction = (*LoggingTxn)(nil)
|
||||
_ UnderlyingContextExecable = (*sql.Conn)(nil)
|
||||
_ UnderlyingExecable = (*sql.Tx)(nil)
|
||||
_ UnderlyingExecable = (*sql.DB)(nil)
|
||||
_ UnderlyingExecable = (*sql.Conn)(nil)
|
||||
_ Execable = (*LoggingExecable)(nil)
|
||||
_ Transaction = (*LoggingTxn)(nil)
|
||||
)
|
||||
|
||||
type Database struct {
|
||||
loggingDB
|
||||
LoggingDB loggingDB
|
||||
RawDB *sql.DB
|
||||
ReadOnlyDB *sql.DB
|
||||
Owner string
|
||||
@@ -139,7 +125,7 @@ func (db *Database) Child(versionTable string, upgradeTable UpgradeTable, log Da
|
||||
}
|
||||
return &Database{
|
||||
RawDB: db.RawDB,
|
||||
loggingDB: db.loggingDB,
|
||||
LoggingDB: db.LoggingDB,
|
||||
Owner: "",
|
||||
VersionTable: versionTable,
|
||||
UpgradeTable: upgradeTable,
|
||||
@@ -164,8 +150,8 @@ func NewWithDB(db *sql.DB, rawDialect string) (*Database, error) {
|
||||
IgnoreForeignTables: true,
|
||||
VersionTable: "version",
|
||||
}
|
||||
wrappedDB.loggingDB.UnderlyingExecable = db
|
||||
wrappedDB.loggingDB.db = wrappedDB
|
||||
wrappedDB.LoggingDB.UnderlyingExecable = db
|
||||
wrappedDB.LoggingDB.db = wrappedDB
|
||||
return wrappedDB, nil
|
||||
}
|
||||
|
||||
@@ -259,7 +245,7 @@ func NewFromConfig(owner string, cfg Config, logger DatabaseLogger) (*Database,
|
||||
if roUri == "" {
|
||||
uriParts := strings.Split(cfg.URI, "?")
|
||||
|
||||
var qs url.Values
|
||||
qs := url.Values{}
|
||||
if len(uriParts) == 2 {
|
||||
var err error
|
||||
qs, err = url.ParseQuery(uriParts[1])
|
||||
|
||||
72
vendor/go.mau.fi/util/dbutil/iter.go
vendored
Normal file
72
vendor/go.mau.fi/util/dbutil/iter.go
vendored
Normal file
@@ -0,0 +1,72 @@
|
||||
// Copyright (c) 2023 Sumner Evans
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||
// License, v. 2.0. If a copy of the MPL was not distributed with this
|
||||
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
|
||||
package dbutil
|
||||
|
||||
// RowIter is a wrapper for [Rows] that allows conveniently iterating over rows
|
||||
// with a predefined scanner function.
|
||||
type RowIter[T any] interface {
|
||||
// Iter iterates over the rows and calls the given function for each row.
|
||||
//
|
||||
// If the function returns false, the iteration is stopped.
|
||||
// If the function returns an error, the iteration is stopped and the error is
|
||||
// returned.
|
||||
Iter(func(T) (bool, error)) error
|
||||
|
||||
// AsList collects all rows into a slice.
|
||||
AsList() ([]T, error)
|
||||
}
|
||||
|
||||
type rowIterImpl[T any] struct {
|
||||
Rows
|
||||
ConvertRow func(Scannable) (T, error)
|
||||
}
|
||||
|
||||
// NewRowIter creates a new RowIter from the given Rows and scanner function.
|
||||
func NewRowIter[T any](rows Rows, convertFn func(Scannable) (T, error)) RowIter[T] {
|
||||
return &rowIterImpl[T]{Rows: rows, ConvertRow: convertFn}
|
||||
}
|
||||
|
||||
func ScanSingleColumn[T any](rows Scannable) (val T, err error) {
|
||||
err = rows.Scan(&val)
|
||||
return
|
||||
}
|
||||
|
||||
type NewableDataStruct[T any] interface {
|
||||
DataStruct[T]
|
||||
New() T
|
||||
}
|
||||
|
||||
func ScanDataStruct[T NewableDataStruct[T]](rows Scannable) (T, error) {
|
||||
var val T
|
||||
return val.New().Scan(rows)
|
||||
}
|
||||
|
||||
func (i *rowIterImpl[T]) Iter(fn func(T) (bool, error)) error {
|
||||
if i == nil || i.Rows == nil {
|
||||
return nil
|
||||
}
|
||||
defer i.Rows.Close()
|
||||
|
||||
for i.Rows.Next() {
|
||||
if item, err := i.ConvertRow(i.Rows); err != nil {
|
||||
return err
|
||||
} else if cont, err := fn(item); err != nil {
|
||||
return err
|
||||
} else if !cont {
|
||||
break
|
||||
}
|
||||
}
|
||||
return i.Rows.Err()
|
||||
}
|
||||
|
||||
func (i *rowIterImpl[T]) AsList() (list []T, err error) {
|
||||
err = i.Iter(func(item T) (bool, error) {
|
||||
list = append(list, item)
|
||||
return true, nil
|
||||
})
|
||||
return
|
||||
}
|
||||
14
vendor/go.mau.fi/util/dbutil/log.go
vendored
14
vendor/go.mau.fi/util/dbutil/log.go
vendored
@@ -10,12 +10,12 @@ import (
|
||||
)
|
||||
|
||||
type DatabaseLogger interface {
|
||||
QueryTiming(ctx context.Context, method, query string, args []interface{}, nrows int, duration time.Duration, err error)
|
||||
QueryTiming(ctx context.Context, method, query string, args []any, nrows int, duration time.Duration, err error)
|
||||
WarnUnsupportedVersion(current, compat, latest int)
|
||||
PrepareUpgrade(current, compat, latest int)
|
||||
DoUpgrade(from, to int, message string, txn bool)
|
||||
// Deprecated: legacy warning method, return errors instead
|
||||
Warn(msg string, args ...interface{})
|
||||
Warn(msg string, args ...any)
|
||||
}
|
||||
|
||||
type noopLogger struct{}
|
||||
@@ -25,9 +25,9 @@ var NoopLogger DatabaseLogger = &noopLogger{}
|
||||
func (n noopLogger) WarnUnsupportedVersion(_, _, _ int) {}
|
||||
func (n noopLogger) PrepareUpgrade(_, _, _ int) {}
|
||||
func (n noopLogger) DoUpgrade(_, _ int, _ string, _ bool) {}
|
||||
func (n noopLogger) Warn(msg string, args ...interface{}) {}
|
||||
func (n noopLogger) Warn(msg string, args ...any) {}
|
||||
|
||||
func (n noopLogger) QueryTiming(_ context.Context, _, _ string, _ []interface{}, _ int, _ time.Duration, _ error) {
|
||||
func (n noopLogger) QueryTiming(_ context.Context, _, _ string, _ []any, _ int, _ time.Duration, _ error) {
|
||||
}
|
||||
|
||||
type zeroLogger struct {
|
||||
@@ -92,7 +92,7 @@ func (z zeroLogger) DoUpgrade(from, to int, message string, txn bool) {
|
||||
|
||||
var whitespaceRegex = regexp.MustCompile(`\s+`)
|
||||
|
||||
func (z zeroLogger) QueryTiming(ctx context.Context, method, query string, args []interface{}, nrows int, duration time.Duration, err error) {
|
||||
func (z zeroLogger) QueryTiming(ctx context.Context, method, query string, args []any, nrows int, duration time.Duration, err error) {
|
||||
log := zerolog.Ctx(ctx)
|
||||
if log.GetLevel() == zerolog.Disabled || log == zerolog.DefaultContextLogger {
|
||||
log = z.l
|
||||
@@ -124,6 +124,6 @@ func (z zeroLogger) QueryTiming(ctx context.Context, method, query string, args
|
||||
}
|
||||
}
|
||||
|
||||
func (z zeroLogger) Warn(msg string, args ...interface{}) {
|
||||
z.l.Warn().Msgf(msg, args...)
|
||||
func (z zeroLogger) Warn(msg string, args ...any) {
|
||||
z.l.Warn().Msgf(msg, args...) // zerolog-allow-msgf
|
||||
}
|
||||
|
||||
105
vendor/go.mau.fi/util/dbutil/queryhelper.go
vendored
Normal file
105
vendor/go.mau.fi/util/dbutil/queryhelper.go
vendored
Normal file
@@ -0,0 +1,105 @@
|
||||
// Copyright (c) 2023 Tulir Asokan
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||
// License, v. 2.0. If a copy of the MPL was not distributed with this
|
||||
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
|
||||
package dbutil
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
|
||||
"golang.org/x/exp/constraints"
|
||||
)
|
||||
|
||||
// DataStruct is an interface for structs that represent a single database row.
|
||||
type DataStruct[T any] interface {
|
||||
Scan(row Scannable) (T, error)
|
||||
}
|
||||
|
||||
// QueryHelper is a generic helper struct for SQL query execution boilerplate.
|
||||
//
|
||||
// After implementing the Scan and Init methods in a data struct, the query
|
||||
// helper allows writing query functions in a single line.
|
||||
type QueryHelper[T DataStruct[T]] struct {
|
||||
db *Database
|
||||
newFunc func(qh *QueryHelper[T]) T
|
||||
}
|
||||
|
||||
func MakeQueryHelper[T DataStruct[T]](db *Database, new func(qh *QueryHelper[T]) T) *QueryHelper[T] {
|
||||
return &QueryHelper[T]{db: db, newFunc: new}
|
||||
}
|
||||
|
||||
// ValueOrErr is a helper function that returns the value if err is nil, or
|
||||
// returns nil and the error if err is not nil. It can be used to avoid
|
||||
// `if err != nil { return nil, err }` boilerplate in certain cases like
|
||||
// DataStruct.Scan implementations.
|
||||
func ValueOrErr[T any](val *T, err error) (*T, error) {
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return val, nil
|
||||
}
|
||||
|
||||
// StrPtr returns a pointer to the given string, or nil if the string is empty.
|
||||
func StrPtr[T ~string](val T) *string {
|
||||
if val == "" {
|
||||
return nil
|
||||
}
|
||||
strVal := string(val)
|
||||
return &strVal
|
||||
}
|
||||
|
||||
// NumPtr returns a pointer to the given number, or nil if the number is zero.
|
||||
func NumPtr[T constraints.Integer | constraints.Float](val T) *T {
|
||||
if val == 0 {
|
||||
return nil
|
||||
}
|
||||
return &val
|
||||
}
|
||||
|
||||
func (qh *QueryHelper[T]) GetDB() *Database {
|
||||
return qh.db
|
||||
}
|
||||
|
||||
func (qh *QueryHelper[T]) New() T {
|
||||
return qh.newFunc(qh)
|
||||
}
|
||||
|
||||
// Exec executes a query with ExecContext and returns the error.
|
||||
//
|
||||
// It omits the sql.Result return value, as it is rarely used. When the result
|
||||
// is wanted, use `qh.GetDB().Exec(...)` instead, which is
|
||||
// otherwise equivalent.
|
||||
func (qh *QueryHelper[T]) Exec(ctx context.Context, query string, args ...any) error {
|
||||
_, err := qh.db.Exec(ctx, query, args...)
|
||||
return err
|
||||
}
|
||||
|
||||
func (qh *QueryHelper[T]) scanNew(row Scannable) (T, error) {
|
||||
return qh.New().Scan(row)
|
||||
}
|
||||
|
||||
// QueryOne executes a query with QueryRowContext, uses the associated DataStruct
|
||||
// to scan it, and returns the value. If the query returns no rows, it returns nil
|
||||
// and no error.
|
||||
func (qh *QueryHelper[T]) QueryOne(ctx context.Context, query string, args ...any) (val T, err error) {
|
||||
val, err = qh.scanNew(qh.db.QueryRow(ctx, query, args...))
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
err = nil
|
||||
}
|
||||
return val, err
|
||||
}
|
||||
|
||||
// QueryMany executes a query with QueryContext, uses the associated DataStruct
|
||||
// to scan each row, and returns the values. If the query returns no rows, it
|
||||
// returns a non-nil zero-length slice and no error.
|
||||
func (qh *QueryHelper[T]) QueryMany(ctx context.Context, query string, args ...any) ([]T, error) {
|
||||
rows, err := qh.db.Query(ctx, query, args...)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return NewRowIter(rows, qh.scanNew).AsList()
|
||||
}
|
||||
22
vendor/go.mau.fi/util/dbutil/transaction.go
vendored
22
vendor/go.mau.fi/util/dbutil/transaction.go
vendored
@@ -32,6 +32,22 @@ const (
|
||||
ContextKeyDoTxnCallerSkip
|
||||
)
|
||||
|
||||
func (db *Database) Exec(ctx context.Context, query string, args ...any) (sql.Result, error) {
|
||||
return db.Conn(ctx).ExecContext(ctx, query, args...)
|
||||
}
|
||||
|
||||
func (db *Database) Query(ctx context.Context, query string, args ...any) (Rows, error) {
|
||||
return db.Conn(ctx).QueryContext(ctx, query, args...)
|
||||
}
|
||||
|
||||
func (db *Database) QueryRow(ctx context.Context, query string, args ...any) *sql.Row {
|
||||
return db.Conn(ctx).QueryRowContext(ctx, query, args...)
|
||||
}
|
||||
|
||||
func (db *Database) BeginTx(ctx context.Context, opts *sql.TxOptions) (*LoggingTxn, error) {
|
||||
return db.LoggingDB.BeginTx(ctx, opts)
|
||||
}
|
||||
|
||||
func (db *Database) DoTxn(ctx context.Context, opts *sql.TxOptions, fn func(ctx context.Context) error) error {
|
||||
if ctx.Value(ContextKeyDatabaseTransaction) != nil {
|
||||
zerolog.Ctx(ctx).Trace().Msg("Already in a transaction, not creating a new one")
|
||||
@@ -82,13 +98,13 @@ func (db *Database) DoTxn(ctx context.Context, opts *sql.TxOptions, fn func(ctx
|
||||
return nil
|
||||
}
|
||||
|
||||
func (db *Database) Conn(ctx context.Context) ContextExecable {
|
||||
func (db *Database) Conn(ctx context.Context) Execable {
|
||||
if ctx == nil {
|
||||
return db
|
||||
return &db.LoggingDB
|
||||
}
|
||||
txn, ok := ctx.Value(ContextKeyDatabaseTransaction).(Transaction)
|
||||
if ok {
|
||||
return txn
|
||||
}
|
||||
return db
|
||||
return &db.LoggingDB
|
||||
}
|
||||
|
||||
101
vendor/go.mau.fi/util/dbutil/upgrades.go
vendored
101
vendor/go.mau.fi/util/dbutil/upgrades.go
vendored
@@ -7,12 +7,13 @@
|
||||
package dbutil
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
)
|
||||
|
||||
type upgradeFunc func(Execable, *Database) error
|
||||
type upgradeFunc func(context.Context, *Database) error
|
||||
|
||||
type upgrade struct {
|
||||
message string
|
||||
@@ -28,19 +29,19 @@ var ErrForeignTables = errors.New("the database contains foreign tables")
|
||||
var ErrNotOwned = errors.New("the database is owned by")
|
||||
var ErrUnsupportedDialect = errors.New("unsupported database dialect")
|
||||
|
||||
func (db *Database) upgradeVersionTable() error {
|
||||
if compatColumnExists, err := db.ColumnExists(nil, db.VersionTable, "compat"); err != nil {
|
||||
func (db *Database) upgradeVersionTable(ctx context.Context) error {
|
||||
if compatColumnExists, err := db.ColumnExists(ctx, db.VersionTable, "compat"); err != nil {
|
||||
return fmt.Errorf("failed to check if version table is up to date: %w", err)
|
||||
} else if !compatColumnExists {
|
||||
if tableExists, err := db.TableExists(nil, db.VersionTable); err != nil {
|
||||
if tableExists, err := db.TableExists(ctx, db.VersionTable); err != nil {
|
||||
return fmt.Errorf("failed to check if version table exists: %w", err)
|
||||
} else if !tableExists {
|
||||
_, err = db.Exec(fmt.Sprintf("CREATE TABLE %s (version INTEGER, compat INTEGER)", db.VersionTable))
|
||||
_, err = db.Exec(ctx, fmt.Sprintf("CREATE TABLE %s (version INTEGER, compat INTEGER)", db.VersionTable))
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create version table: %w", err)
|
||||
}
|
||||
} else {
|
||||
_, err = db.Exec(fmt.Sprintf("ALTER TABLE %s ADD COLUMN compat INTEGER", db.VersionTable))
|
||||
_, err = db.Exec(ctx, fmt.Sprintf("ALTER TABLE %s ADD COLUMN compat INTEGER", db.VersionTable))
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to add compat column to version table: %w", err)
|
||||
}
|
||||
@@ -49,13 +50,13 @@ func (db *Database) upgradeVersionTable() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (db *Database) getVersion() (version, compat int, err error) {
|
||||
if err = db.upgradeVersionTable(); err != nil {
|
||||
func (db *Database) getVersion(ctx context.Context) (version, compat int, err error) {
|
||||
if err = db.upgradeVersionTable(ctx); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
var compatNull sql.NullInt32
|
||||
err = db.QueryRow(fmt.Sprintf("SELECT version, compat FROM %s LIMIT 1", db.VersionTable)).Scan(&version, &compatNull)
|
||||
err = db.QueryRow(ctx, fmt.Sprintf("SELECT version, compat FROM %s LIMIT 1", db.VersionTable)).Scan(&version, &compatNull)
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
err = nil
|
||||
}
|
||||
@@ -72,15 +73,12 @@ const (
|
||||
tableExistsSQLite = "SELECT EXISTS(SELECT 1 FROM sqlite_master WHERE type='table' AND tbl_name=?1)"
|
||||
)
|
||||
|
||||
func (db *Database) TableExists(tx Execable, table string) (exists bool, err error) {
|
||||
if tx == nil {
|
||||
tx = db
|
||||
}
|
||||
func (db *Database) TableExists(ctx context.Context, table string) (exists bool, err error) {
|
||||
switch db.Dialect {
|
||||
case SQLite:
|
||||
err = db.QueryRow(tableExistsSQLite, table).Scan(&exists)
|
||||
err = db.QueryRow(ctx, tableExistsSQLite, table).Scan(&exists)
|
||||
case Postgres:
|
||||
err = db.QueryRow(tableExistsPostgres, table).Scan(&exists)
|
||||
err = db.QueryRow(ctx, tableExistsPostgres, table).Scan(&exists)
|
||||
default:
|
||||
err = ErrUnsupportedDialect
|
||||
}
|
||||
@@ -92,23 +90,20 @@ const (
|
||||
columnExistsSQLite = "SELECT EXISTS(SELECT 1 FROM pragma_table_info(?1) WHERE name=?2)"
|
||||
)
|
||||
|
||||
func (db *Database) ColumnExists(tx Execable, table, column string) (exists bool, err error) {
|
||||
if tx == nil {
|
||||
tx = db
|
||||
}
|
||||
func (db *Database) ColumnExists(ctx context.Context, table, column string) (exists bool, err error) {
|
||||
switch db.Dialect {
|
||||
case SQLite:
|
||||
err = db.QueryRow(columnExistsSQLite, table, column).Scan(&exists)
|
||||
err = db.QueryRow(ctx, columnExistsSQLite, table, column).Scan(&exists)
|
||||
case Postgres:
|
||||
err = db.QueryRow(columnExistsPostgres, table, column).Scan(&exists)
|
||||
err = db.QueryRow(ctx, columnExistsPostgres, table, column).Scan(&exists)
|
||||
default:
|
||||
err = ErrUnsupportedDialect
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (db *Database) tableExistsNoError(table string) bool {
|
||||
exists, err := db.TableExists(nil, table)
|
||||
func (db *Database) tableExistsNoError(ctx context.Context, table string) bool {
|
||||
exists, err := db.TableExists(ctx, table)
|
||||
if err != nil {
|
||||
panic(fmt.Errorf("failed to check if table exists: %w", err))
|
||||
}
|
||||
@@ -122,22 +117,22 @@ CREATE TABLE IF NOT EXISTS database_owner (
|
||||
)
|
||||
`
|
||||
|
||||
func (db *Database) checkDatabaseOwner() error {
|
||||
func (db *Database) checkDatabaseOwner(ctx context.Context) error {
|
||||
var owner string
|
||||
if !db.IgnoreForeignTables {
|
||||
if db.tableExistsNoError("state_groups_state") {
|
||||
if db.tableExistsNoError(ctx, "state_groups_state") {
|
||||
return fmt.Errorf("%w (found state_groups_state, likely belonging to Synapse)", ErrForeignTables)
|
||||
} else if db.tableExistsNoError("roomserver_rooms") {
|
||||
} else if db.tableExistsNoError(ctx, "roomserver_rooms") {
|
||||
return fmt.Errorf("%w (found roomserver_rooms, likely belonging to Dendrite)", ErrForeignTables)
|
||||
}
|
||||
}
|
||||
if db.Owner == "" {
|
||||
return nil
|
||||
}
|
||||
if _, err := db.Exec(createOwnerTable); err != nil {
|
||||
if _, err := db.Exec(ctx, createOwnerTable); err != nil {
|
||||
return fmt.Errorf("failed to ensure database owner table exists: %w", err)
|
||||
} else if err = db.QueryRow("SELECT owner FROM database_owner WHERE key=0").Scan(&owner); errors.Is(err, sql.ErrNoRows) {
|
||||
_, err = db.Exec("INSERT INTO database_owner (key, owner) VALUES (0, $1)", db.Owner)
|
||||
} else if err = db.QueryRow(ctx, "SELECT owner FROM database_owner WHERE key=0").Scan(&owner); errors.Is(err, sql.ErrNoRows) {
|
||||
_, err = db.Exec(ctx, "INSERT INTO database_owner (key, owner) VALUES (0, $1)", db.Owner)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to insert database owner: %w", err)
|
||||
}
|
||||
@@ -149,22 +144,22 @@ func (db *Database) checkDatabaseOwner() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (db *Database) setVersion(tx Execable, version, compat int) error {
|
||||
_, err := tx.Exec(fmt.Sprintf("DELETE FROM %s", db.VersionTable))
|
||||
func (db *Database) setVersion(ctx context.Context, version, compat int) error {
|
||||
_, err := db.Exec(ctx, fmt.Sprintf("DELETE FROM %s", db.VersionTable))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = tx.Exec(fmt.Sprintf("INSERT INTO %s (version, compat) VALUES ($1, $2)", db.VersionTable), version, compat)
|
||||
_, err = db.Exec(ctx, fmt.Sprintf("INSERT INTO %s (version, compat) VALUES ($1, $2)", db.VersionTable), version, compat)
|
||||
return err
|
||||
}
|
||||
|
||||
func (db *Database) Upgrade() error {
|
||||
err := db.checkDatabaseOwner()
|
||||
func (db *Database) Upgrade(ctx context.Context) error {
|
||||
err := db.checkDatabaseOwner(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
version, compat, err := db.getVersion()
|
||||
version, compat, err := db.getVersion(ctx)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -185,34 +180,28 @@ func (db *Database) Upgrade() error {
|
||||
version++
|
||||
continue
|
||||
}
|
||||
doUpgrade := func(ctx context.Context) error {
|
||||
err = upgradeItem.fn(ctx, db)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to run upgrade #%d: %w", version, err)
|
||||
}
|
||||
version = upgradeItem.upgradesTo
|
||||
logVersion = version
|
||||
err = db.setVersion(ctx, version, upgradeItem.compatVersion)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
db.Log.DoUpgrade(logVersion, upgradeItem.upgradesTo, upgradeItem.message, upgradeItem.transaction)
|
||||
var tx Transaction
|
||||
var upgradeConn Execable
|
||||
if upgradeItem.transaction {
|
||||
tx, err = db.Begin()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
upgradeConn = tx
|
||||
err = db.DoTxn(ctx, nil, doUpgrade)
|
||||
} else {
|
||||
upgradeConn = db
|
||||
err = doUpgrade(ctx)
|
||||
}
|
||||
err = upgradeItem.fn(upgradeConn, db)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
version = upgradeItem.upgradesTo
|
||||
logVersion = version
|
||||
err = db.setVersion(upgradeConn, version, upgradeItem.compatVersion)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if tx != nil {
|
||||
err = tx.Commit()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
15
vendor/go.mau.fi/util/dbutil/upgradetable.go
vendored
15
vendor/go.mau.fi/util/dbutil/upgradetable.go
vendored
@@ -8,6 +8,7 @@ package dbutil
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io/fs"
|
||||
@@ -189,27 +190,27 @@ func (db *Database) filterSQLUpgrade(lines [][]byte) (string, error) {
|
||||
}
|
||||
|
||||
func sqlUpgradeFunc(fileName string, lines [][]byte) upgradeFunc {
|
||||
return func(tx Execable, db *Database) error {
|
||||
return func(ctx context.Context, db *Database) error {
|
||||
if skip, err := db.parseDialectFilter(lines[0]); err == nil && skip == skipNextLine {
|
||||
return nil
|
||||
} else if upgradeSQL, err := db.filterSQLUpgrade(lines); err != nil {
|
||||
panic(fmt.Errorf("failed to parse upgrade %s: %w", fileName, err))
|
||||
} else {
|
||||
_, err = tx.Exec(upgradeSQL)
|
||||
_, err = db.Exec(ctx, upgradeSQL)
|
||||
return err
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func splitSQLUpgradeFunc(sqliteData, postgresData string) upgradeFunc {
|
||||
return func(tx Execable, database *Database) (err error) {
|
||||
switch database.Dialect {
|
||||
return func(ctx context.Context, db *Database) (err error) {
|
||||
switch db.Dialect {
|
||||
case SQLite:
|
||||
_, err = tx.Exec(sqliteData)
|
||||
_, err = db.Exec(ctx, sqliteData)
|
||||
case Postgres:
|
||||
_, err = tx.Exec(postgresData)
|
||||
_, err = db.Exec(ctx, postgresData)
|
||||
default:
|
||||
err = fmt.Errorf("unknown dialect %s", database.Dialect)
|
||||
err = fmt.Errorf("unknown dialect %s", db.Dialect)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
23
vendor/go.mau.fi/util/exerrors/must.go
vendored
Normal file
23
vendor/go.mau.fi/util/exerrors/must.go
vendored
Normal file
@@ -0,0 +1,23 @@
|
||||
// Copyright (c) 2024 Tulir Asokan
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||
// License, v. 2.0. If a copy of the MPL was not distributed with this
|
||||
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
|
||||
package exerrors
|
||||
|
||||
func Must[T any](val T, err error) T {
|
||||
PanicIfNotNil(err)
|
||||
return val
|
||||
}
|
||||
|
||||
func Must2[T any, T2 any](val T, val2 T2, err error) (T, T2) {
|
||||
PanicIfNotNil(err)
|
||||
return val, val2
|
||||
}
|
||||
|
||||
func PanicIfNotNil(err error) {
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
73
vendor/go.mau.fi/util/jsontime/integer.go
vendored
73
vendor/go.mau.fi/util/jsontime/integer.go
vendored
@@ -7,10 +7,16 @@
|
||||
package jsontime
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"database/sql/driver"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
)
|
||||
|
||||
var ErrNotInteger = errors.New("value is not an integer")
|
||||
|
||||
func parseTime(data []byte, unixConv func(int64) time.Time, into *time.Time) error {
|
||||
var val int64
|
||||
err := json.Unmarshal(data, &val)
|
||||
@@ -25,6 +31,28 @@ func parseTime(data []byte, unixConv func(int64) time.Time, into *time.Time) err
|
||||
return nil
|
||||
}
|
||||
|
||||
func anyIntegerToTime(src any, unixConv func(int64) time.Time, into *time.Time) error {
|
||||
switch v := src.(type) {
|
||||
case int:
|
||||
*into = unixConv(int64(v))
|
||||
case int8:
|
||||
*into = unixConv(int64(v))
|
||||
case int16:
|
||||
*into = unixConv(int64(v))
|
||||
case int32:
|
||||
*into = unixConv(int64(v))
|
||||
case int64:
|
||||
*into = unixConv(int64(v))
|
||||
default:
|
||||
return fmt.Errorf("%w: %T", ErrNotInteger, src)
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
var _ sql.Scanner = &UnixMilli{}
|
||||
var _ driver.Valuer = UnixMilli{}
|
||||
|
||||
type UnixMilli struct {
|
||||
time.Time
|
||||
}
|
||||
@@ -40,6 +68,17 @@ func (um *UnixMilli) UnmarshalJSON(data []byte) error {
|
||||
return parseTime(data, time.UnixMilli, &um.Time)
|
||||
}
|
||||
|
||||
func (um UnixMilli) Value() (driver.Value, error) {
|
||||
return um.UnixMilli(), nil
|
||||
}
|
||||
|
||||
func (um *UnixMilli) Scan(src any) error {
|
||||
return anyIntegerToTime(src, time.UnixMilli, &um.Time)
|
||||
}
|
||||
|
||||
var _ sql.Scanner = &UnixMicro{}
|
||||
var _ driver.Valuer = UnixMicro{}
|
||||
|
||||
type UnixMicro struct {
|
||||
time.Time
|
||||
}
|
||||
@@ -55,6 +94,17 @@ func (um *UnixMicro) UnmarshalJSON(data []byte) error {
|
||||
return parseTime(data, time.UnixMicro, &um.Time)
|
||||
}
|
||||
|
||||
func (um UnixMicro) Value() (driver.Value, error) {
|
||||
return um.UnixMicro(), nil
|
||||
}
|
||||
|
||||
func (um *UnixMicro) Scan(src any) error {
|
||||
return anyIntegerToTime(src, time.UnixMicro, &um.Time)
|
||||
}
|
||||
|
||||
var _ sql.Scanner = &UnixNano{}
|
||||
var _ driver.Valuer = UnixNano{}
|
||||
|
||||
type UnixNano struct {
|
||||
time.Time
|
||||
}
|
||||
@@ -72,6 +122,16 @@ func (un *UnixNano) UnmarshalJSON(data []byte) error {
|
||||
}, &un.Time)
|
||||
}
|
||||
|
||||
func (un UnixNano) Value() (driver.Value, error) {
|
||||
return un.UnixNano(), nil
|
||||
}
|
||||
|
||||
func (un *UnixNano) Scan(src any) error {
|
||||
return anyIntegerToTime(src, func(i int64) time.Time {
|
||||
return time.Unix(0, i)
|
||||
}, &un.Time)
|
||||
}
|
||||
|
||||
type Unix struct {
|
||||
time.Time
|
||||
}
|
||||
@@ -83,8 +143,21 @@ func (u Unix) MarshalJSON() ([]byte, error) {
|
||||
return json.Marshal(u.Unix())
|
||||
}
|
||||
|
||||
var _ sql.Scanner = &Unix{}
|
||||
var _ driver.Valuer = Unix{}
|
||||
|
||||
func (u *Unix) UnmarshalJSON(data []byte) error {
|
||||
return parseTime(data, func(i int64) time.Time {
|
||||
return time.Unix(i, 0)
|
||||
}, &u.Time)
|
||||
}
|
||||
|
||||
func (u Unix) Value() (driver.Value, error) {
|
||||
return u.Unix(), nil
|
||||
}
|
||||
|
||||
func (u *Unix) Scan(src any) error {
|
||||
return anyIntegerToTime(src, func(i int64) time.Time {
|
||||
return time.Unix(i, 0)
|
||||
}, &u.Time)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user