automatically ignore known forwarded addresses, fixes #64
This commit is contained in:
289
vendor/go.mau.fi/util/dbutil/database.go
vendored
Normal file
289
vendor/go.mau.fi/util/dbutil/database.go
vendored
Normal file
@@ -0,0 +1,289 @@
|
||||
// Copyright (c) 2022 Tulir Asokan
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||
// License, v. 2.0. If a copy of the MPL was not distributed with this
|
||||
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
|
||||
package dbutil
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"fmt"
|
||||
"net/url"
|
||||
"regexp"
|
||||
"strings"
|
||||
"time"
|
||||
)
|
||||
|
||||
type Dialect int
|
||||
|
||||
const (
|
||||
DialectUnknown Dialect = iota
|
||||
Postgres
|
||||
SQLite
|
||||
)
|
||||
|
||||
func (dialect Dialect) String() string {
|
||||
switch dialect {
|
||||
case Postgres:
|
||||
return "postgres"
|
||||
case SQLite:
|
||||
return "sqlite3"
|
||||
default:
|
||||
return ""
|
||||
}
|
||||
}
|
||||
|
||||
func ParseDialect(engine string) (Dialect, error) {
|
||||
engine = strings.ToLower(engine)
|
||||
|
||||
if strings.HasPrefix(engine, "postgres") || engine == "pgx" {
|
||||
return Postgres, nil
|
||||
} else if strings.HasPrefix(engine, "sqlite") || strings.HasPrefix(engine, "litestream") {
|
||||
return SQLite, nil
|
||||
} else {
|
||||
return DialectUnknown, fmt.Errorf("unknown dialect '%s'", engine)
|
||||
}
|
||||
}
|
||||
|
||||
type Rows interface {
|
||||
Close() error
|
||||
ColumnTypes() ([]*sql.ColumnType, error)
|
||||
Columns() ([]string, error)
|
||||
Err() error
|
||||
Next() bool
|
||||
NextResultSet() bool
|
||||
Scan(...any) error
|
||||
}
|
||||
|
||||
type Scannable interface {
|
||||
Scan(...interface{}) error
|
||||
}
|
||||
|
||||
// Expected implementations of Scannable
|
||||
var (
|
||||
_ Scannable = (*sql.Row)(nil)
|
||||
_ Scannable = (Rows)(nil)
|
||||
)
|
||||
|
||||
type UnderlyingContextExecable interface {
|
||||
ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error)
|
||||
QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error)
|
||||
QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row
|
||||
}
|
||||
|
||||
type ContextExecable interface {
|
||||
ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error)
|
||||
QueryContext(ctx context.Context, query string, args ...interface{}) (Rows, error)
|
||||
QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row
|
||||
}
|
||||
|
||||
type UnderlyingExecable interface {
|
||||
UnderlyingContextExecable
|
||||
Exec(query string, args ...interface{}) (sql.Result, error)
|
||||
Query(query string, args ...interface{}) (*sql.Rows, error)
|
||||
QueryRow(query string, args ...interface{}) *sql.Row
|
||||
}
|
||||
|
||||
type Execable interface {
|
||||
ContextExecable
|
||||
Exec(query string, args ...interface{}) (sql.Result, error)
|
||||
Query(query string, args ...interface{}) (Rows, error)
|
||||
QueryRow(query string, args ...interface{}) *sql.Row
|
||||
}
|
||||
|
||||
type Transaction interface {
|
||||
Execable
|
||||
Commit() error
|
||||
Rollback() error
|
||||
}
|
||||
|
||||
// Expected implementations of Execable
|
||||
var (
|
||||
_ UnderlyingExecable = (*sql.Tx)(nil)
|
||||
_ UnderlyingExecable = (*sql.DB)(nil)
|
||||
_ Execable = (*LoggingExecable)(nil)
|
||||
_ Transaction = (*LoggingTxn)(nil)
|
||||
_ UnderlyingContextExecable = (*sql.Conn)(nil)
|
||||
)
|
||||
|
||||
type Database struct {
|
||||
loggingDB
|
||||
RawDB *sql.DB
|
||||
ReadOnlyDB *sql.DB
|
||||
Owner string
|
||||
VersionTable string
|
||||
Log DatabaseLogger
|
||||
Dialect Dialect
|
||||
UpgradeTable UpgradeTable
|
||||
|
||||
IgnoreForeignTables bool
|
||||
IgnoreUnsupportedDatabase bool
|
||||
}
|
||||
|
||||
var positionalParamPattern = regexp.MustCompile(`\$(\d+)`)
|
||||
|
||||
func (db *Database) mutateQuery(query string) string {
|
||||
switch db.Dialect {
|
||||
case SQLite:
|
||||
return positionalParamPattern.ReplaceAllString(query, "?$1")
|
||||
default:
|
||||
return query
|
||||
}
|
||||
}
|
||||
|
||||
func (db *Database) Child(versionTable string, upgradeTable UpgradeTable, log DatabaseLogger) *Database {
|
||||
if log == nil {
|
||||
log = db.Log
|
||||
}
|
||||
return &Database{
|
||||
RawDB: db.RawDB,
|
||||
loggingDB: db.loggingDB,
|
||||
Owner: "",
|
||||
VersionTable: versionTable,
|
||||
UpgradeTable: upgradeTable,
|
||||
Log: log,
|
||||
Dialect: db.Dialect,
|
||||
|
||||
IgnoreForeignTables: true,
|
||||
IgnoreUnsupportedDatabase: db.IgnoreUnsupportedDatabase,
|
||||
}
|
||||
}
|
||||
|
||||
func NewWithDB(db *sql.DB, rawDialect string) (*Database, error) {
|
||||
dialect, err := ParseDialect(rawDialect)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
wrappedDB := &Database{
|
||||
RawDB: db,
|
||||
Dialect: dialect,
|
||||
Log: NoopLogger,
|
||||
|
||||
IgnoreForeignTables: true,
|
||||
VersionTable: "version",
|
||||
}
|
||||
wrappedDB.loggingDB.UnderlyingExecable = db
|
||||
wrappedDB.loggingDB.db = wrappedDB
|
||||
return wrappedDB, nil
|
||||
}
|
||||
|
||||
func NewWithDialect(uri, rawDialect string) (*Database, error) {
|
||||
db, err := sql.Open(rawDialect, uri)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return NewWithDB(db, rawDialect)
|
||||
}
|
||||
|
||||
type PoolConfig struct {
|
||||
Type string `yaml:"type"`
|
||||
URI string `yaml:"uri"`
|
||||
|
||||
MaxOpenConns int `yaml:"max_open_conns"`
|
||||
MaxIdleConns int `yaml:"max_idle_conns"`
|
||||
|
||||
ConnMaxIdleTime string `yaml:"conn_max_idle_time"`
|
||||
ConnMaxLifetime string `yaml:"conn_max_lifetime"`
|
||||
}
|
||||
|
||||
type Config struct {
|
||||
PoolConfig `yaml:",inline"`
|
||||
ReadOnlyPool PoolConfig `yaml:"ro_pool"`
|
||||
}
|
||||
|
||||
func (db *Database) Close() error {
|
||||
err := db.RawDB.Close()
|
||||
if db.ReadOnlyDB != nil {
|
||||
err2 := db.ReadOnlyDB.Close()
|
||||
if err == nil {
|
||||
err = fmt.Errorf("closing read-only db failed: %w", err)
|
||||
} else {
|
||||
err = fmt.Errorf("%w (closing read-only db also failed: %v)", err, err2)
|
||||
}
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (db *Database) Configure(cfg Config) error {
|
||||
if err := db.configure(db.ReadOnlyDB, cfg.ReadOnlyPool); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
return db.configure(db.RawDB, cfg.PoolConfig)
|
||||
}
|
||||
|
||||
func (db *Database) configure(rawDB *sql.DB, cfg PoolConfig) error {
|
||||
if rawDB == nil {
|
||||
return nil
|
||||
}
|
||||
|
||||
rawDB.SetMaxOpenConns(cfg.MaxOpenConns)
|
||||
rawDB.SetMaxIdleConns(cfg.MaxIdleConns)
|
||||
if len(cfg.ConnMaxIdleTime) > 0 {
|
||||
maxIdleTimeDuration, err := time.ParseDuration(cfg.ConnMaxIdleTime)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse max_conn_idle_time: %w", err)
|
||||
}
|
||||
rawDB.SetConnMaxIdleTime(maxIdleTimeDuration)
|
||||
}
|
||||
if len(cfg.ConnMaxLifetime) > 0 {
|
||||
maxLifetimeDuration, err := time.ParseDuration(cfg.ConnMaxLifetime)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse max_conn_idle_time: %w", err)
|
||||
}
|
||||
rawDB.SetConnMaxLifetime(maxLifetimeDuration)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func NewFromConfig(owner string, cfg Config, logger DatabaseLogger) (*Database, error) {
|
||||
wrappedDB, err := NewWithDialect(cfg.URI, cfg.Type)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
wrappedDB.Owner = owner
|
||||
if logger != nil {
|
||||
wrappedDB.Log = logger
|
||||
}
|
||||
|
||||
if cfg.ReadOnlyPool.MaxOpenConns > 0 {
|
||||
if cfg.ReadOnlyPool.Type == "" {
|
||||
cfg.ReadOnlyPool.Type = cfg.Type
|
||||
}
|
||||
|
||||
roUri := cfg.ReadOnlyPool.URI
|
||||
if roUri == "" {
|
||||
uriParts := strings.Split(cfg.URI, "?")
|
||||
|
||||
var qs url.Values
|
||||
if len(uriParts) == 2 {
|
||||
var err error
|
||||
qs, err = url.ParseQuery(uriParts[1])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
qs.Del("_txlock")
|
||||
}
|
||||
qs.Set("_query_only", "true")
|
||||
|
||||
roUri = uriParts[0] + "?" + qs.Encode()
|
||||
}
|
||||
|
||||
wrappedDB.ReadOnlyDB, err = sql.Open(cfg.ReadOnlyPool.Type, roUri)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
err = wrappedDB.Configure(cfg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return wrappedDB, nil
|
||||
}
|
||||
Reference in New Issue
Block a user