refactor to mautrix 0.17.x; update deps

This commit is contained in:
Aine
2024-02-11 20:47:04 +02:00
parent 0a9701f4c9
commit dd0ad4c245
237 changed files with 9091 additions and 3317 deletions

View File

@@ -7,12 +7,13 @@
package dbutil
import (
"context"
"database/sql"
"errors"
"fmt"
)
type upgradeFunc func(Execable, *Database) error
type upgradeFunc func(context.Context, *Database) error
type upgrade struct {
message string
@@ -28,19 +29,19 @@ var ErrForeignTables = errors.New("the database contains foreign tables")
var ErrNotOwned = errors.New("the database is owned by")
var ErrUnsupportedDialect = errors.New("unsupported database dialect")
func (db *Database) upgradeVersionTable() error {
if compatColumnExists, err := db.ColumnExists(nil, db.VersionTable, "compat"); err != nil {
func (db *Database) upgradeVersionTable(ctx context.Context) error {
if compatColumnExists, err := db.ColumnExists(ctx, db.VersionTable, "compat"); err != nil {
return fmt.Errorf("failed to check if version table is up to date: %w", err)
} else if !compatColumnExists {
if tableExists, err := db.TableExists(nil, db.VersionTable); err != nil {
if tableExists, err := db.TableExists(ctx, db.VersionTable); err != nil {
return fmt.Errorf("failed to check if version table exists: %w", err)
} else if !tableExists {
_, err = db.Exec(fmt.Sprintf("CREATE TABLE %s (version INTEGER, compat INTEGER)", db.VersionTable))
_, err = db.Exec(ctx, fmt.Sprintf("CREATE TABLE %s (version INTEGER, compat INTEGER)", db.VersionTable))
if err != nil {
return fmt.Errorf("failed to create version table: %w", err)
}
} else {
_, err = db.Exec(fmt.Sprintf("ALTER TABLE %s ADD COLUMN compat INTEGER", db.VersionTable))
_, err = db.Exec(ctx, fmt.Sprintf("ALTER TABLE %s ADD COLUMN compat INTEGER", db.VersionTable))
if err != nil {
return fmt.Errorf("failed to add compat column to version table: %w", err)
}
@@ -49,13 +50,13 @@ func (db *Database) upgradeVersionTable() error {
return nil
}
func (db *Database) getVersion() (version, compat int, err error) {
if err = db.upgradeVersionTable(); err != nil {
func (db *Database) getVersion(ctx context.Context) (version, compat int, err error) {
if err = db.upgradeVersionTable(ctx); err != nil {
return
}
var compatNull sql.NullInt32
err = db.QueryRow(fmt.Sprintf("SELECT version, compat FROM %s LIMIT 1", db.VersionTable)).Scan(&version, &compatNull)
err = db.QueryRow(ctx, fmt.Sprintf("SELECT version, compat FROM %s LIMIT 1", db.VersionTable)).Scan(&version, &compatNull)
if errors.Is(err, sql.ErrNoRows) {
err = nil
}
@@ -72,15 +73,12 @@ const (
tableExistsSQLite = "SELECT EXISTS(SELECT 1 FROM sqlite_master WHERE type='table' AND tbl_name=?1)"
)
func (db *Database) TableExists(tx Execable, table string) (exists bool, err error) {
if tx == nil {
tx = db
}
func (db *Database) TableExists(ctx context.Context, table string) (exists bool, err error) {
switch db.Dialect {
case SQLite:
err = db.QueryRow(tableExistsSQLite, table).Scan(&exists)
err = db.QueryRow(ctx, tableExistsSQLite, table).Scan(&exists)
case Postgres:
err = db.QueryRow(tableExistsPostgres, table).Scan(&exists)
err = db.QueryRow(ctx, tableExistsPostgres, table).Scan(&exists)
default:
err = ErrUnsupportedDialect
}
@@ -92,23 +90,20 @@ const (
columnExistsSQLite = "SELECT EXISTS(SELECT 1 FROM pragma_table_info(?1) WHERE name=?2)"
)
func (db *Database) ColumnExists(tx Execable, table, column string) (exists bool, err error) {
if tx == nil {
tx = db
}
func (db *Database) ColumnExists(ctx context.Context, table, column string) (exists bool, err error) {
switch db.Dialect {
case SQLite:
err = db.QueryRow(columnExistsSQLite, table, column).Scan(&exists)
err = db.QueryRow(ctx, columnExistsSQLite, table, column).Scan(&exists)
case Postgres:
err = db.QueryRow(columnExistsPostgres, table, column).Scan(&exists)
err = db.QueryRow(ctx, columnExistsPostgres, table, column).Scan(&exists)
default:
err = ErrUnsupportedDialect
}
return
}
func (db *Database) tableExistsNoError(table string) bool {
exists, err := db.TableExists(nil, table)
func (db *Database) tableExistsNoError(ctx context.Context, table string) bool {
exists, err := db.TableExists(ctx, table)
if err != nil {
panic(fmt.Errorf("failed to check if table exists: %w", err))
}
@@ -122,22 +117,22 @@ CREATE TABLE IF NOT EXISTS database_owner (
)
`
func (db *Database) checkDatabaseOwner() error {
func (db *Database) checkDatabaseOwner(ctx context.Context) error {
var owner string
if !db.IgnoreForeignTables {
if db.tableExistsNoError("state_groups_state") {
if db.tableExistsNoError(ctx, "state_groups_state") {
return fmt.Errorf("%w (found state_groups_state, likely belonging to Synapse)", ErrForeignTables)
} else if db.tableExistsNoError("roomserver_rooms") {
} else if db.tableExistsNoError(ctx, "roomserver_rooms") {
return fmt.Errorf("%w (found roomserver_rooms, likely belonging to Dendrite)", ErrForeignTables)
}
}
if db.Owner == "" {
return nil
}
if _, err := db.Exec(createOwnerTable); err != nil {
if _, err := db.Exec(ctx, createOwnerTable); err != nil {
return fmt.Errorf("failed to ensure database owner table exists: %w", err)
} else if err = db.QueryRow("SELECT owner FROM database_owner WHERE key=0").Scan(&owner); errors.Is(err, sql.ErrNoRows) {
_, err = db.Exec("INSERT INTO database_owner (key, owner) VALUES (0, $1)", db.Owner)
} else if err = db.QueryRow(ctx, "SELECT owner FROM database_owner WHERE key=0").Scan(&owner); errors.Is(err, sql.ErrNoRows) {
_, err = db.Exec(ctx, "INSERT INTO database_owner (key, owner) VALUES (0, $1)", db.Owner)
if err != nil {
return fmt.Errorf("failed to insert database owner: %w", err)
}
@@ -149,22 +144,22 @@ func (db *Database) checkDatabaseOwner() error {
return nil
}
func (db *Database) setVersion(tx Execable, version, compat int) error {
_, err := tx.Exec(fmt.Sprintf("DELETE FROM %s", db.VersionTable))
func (db *Database) setVersion(ctx context.Context, version, compat int) error {
_, err := db.Exec(ctx, fmt.Sprintf("DELETE FROM %s", db.VersionTable))
if err != nil {
return err
}
_, err = tx.Exec(fmt.Sprintf("INSERT INTO %s (version, compat) VALUES ($1, $2)", db.VersionTable), version, compat)
_, err = db.Exec(ctx, fmt.Sprintf("INSERT INTO %s (version, compat) VALUES ($1, $2)", db.VersionTable), version, compat)
return err
}
func (db *Database) Upgrade() error {
err := db.checkDatabaseOwner()
func (db *Database) Upgrade(ctx context.Context) error {
err := db.checkDatabaseOwner(ctx)
if err != nil {
return err
}
version, compat, err := db.getVersion()
version, compat, err := db.getVersion(ctx)
if err != nil {
return err
}
@@ -185,34 +180,28 @@ func (db *Database) Upgrade() error {
version++
continue
}
doUpgrade := func(ctx context.Context) error {
err = upgradeItem.fn(ctx, db)
if err != nil {
return fmt.Errorf("failed to run upgrade #%d: %w", version, err)
}
version = upgradeItem.upgradesTo
logVersion = version
err = db.setVersion(ctx, version, upgradeItem.compatVersion)
if err != nil {
return err
}
return nil
}
db.Log.DoUpgrade(logVersion, upgradeItem.upgradesTo, upgradeItem.message, upgradeItem.transaction)
var tx Transaction
var upgradeConn Execable
if upgradeItem.transaction {
tx, err = db.Begin()
if err != nil {
return err
}
upgradeConn = tx
err = db.DoTxn(ctx, nil, doUpgrade)
} else {
upgradeConn = db
err = doUpgrade(ctx)
}
err = upgradeItem.fn(upgradeConn, db)
if err != nil {
return err
}
version = upgradeItem.upgradesTo
logVersion = version
err = db.setVersion(upgradeConn, version, upgradeItem.compatVersion)
if err != nil {
return err
}
if tx != nil {
err = tx.Commit()
if err != nil {
return err
}
}
}
return nil
}