update deps; experiment: log security

This commit is contained in:
Aine
2022-11-16 23:00:58 +02:00
parent 225ba2ee9b
commit 99a89ef87a
55 changed files with 883 additions and 308 deletions

View File

@@ -15,7 +15,7 @@ import (
// LoggingExecable is a wrapper for anything with database Exec methods (i.e. sql.Conn, sql.DB and sql.Tx)
// that can preprocess queries (e.g. replacing $ with ? on SQLite) and log query durations.
type LoggingExecable struct {
UnderlyingExecable Execable
UnderlyingExecable UnderlyingExecable
db *Database
}
@@ -23,23 +23,30 @@ func (le *LoggingExecable) ExecContext(ctx context.Context, query string, args .
start := time.Now()
query = le.db.mutateQuery(query)
res, err := le.UnderlyingExecable.ExecContext(ctx, query, args...)
le.db.Log.QueryTiming(ctx, "Exec", query, args, time.Since(start))
le.db.Log.QueryTiming(ctx, "Exec", query, args, -1, time.Since(start))
return res, err
}
func (le *LoggingExecable) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) {
func (le *LoggingExecable) QueryContext(ctx context.Context, query string, args ...interface{}) (Rows, error) {
start := time.Now()
query = le.db.mutateQuery(query)
rows, err := le.UnderlyingExecable.QueryContext(ctx, query, args...)
le.db.Log.QueryTiming(ctx, "Query", query, args, time.Since(start))
return rows, err
le.db.Log.QueryTiming(ctx, "Query", query, args, -1, time.Since(start))
return &LoggingRows{
ctx: ctx,
db: le.db,
query: query,
args: args,
rs: rows,
start: start,
}, err
}
func (le *LoggingExecable) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row {
start := time.Now()
query = le.db.mutateQuery(query)
row := le.UnderlyingExecable.QueryRowContext(ctx, query, args...)
le.db.Log.QueryTiming(ctx, "QueryRow", query, args, time.Since(start))
le.db.Log.QueryTiming(ctx, "QueryRow", query, args, -1, time.Since(start))
return row
}
@@ -47,7 +54,7 @@ func (le *LoggingExecable) Exec(query string, args ...interface{}) (sql.Result,
return le.ExecContext(context.Background(), query, args...)
}
func (le *LoggingExecable) Query(query string, args ...interface{}) (*sql.Rows, error) {
func (le *LoggingExecable) Query(query string, args ...interface{}) (Rows, error) {
return le.QueryContext(context.Background(), query, args...)
}
@@ -66,7 +73,7 @@ type loggingDB struct {
func (ld *loggingDB) BeginTx(ctx context.Context, opts *sql.TxOptions) (*LoggingTxn, error) {
start := time.Now()
tx, err := ld.db.RawDB.BeginTx(ctx, opts)
ld.db.Log.QueryTiming(ctx, "Begin", "", nil, time.Since(start))
ld.db.Log.QueryTiming(ctx, "Begin", "", nil, -1, time.Since(start))
if err != nil {
return nil, err
}
@@ -90,13 +97,76 @@ type LoggingTxn struct {
func (lt *LoggingTxn) Commit() error {
start := time.Now()
err := lt.UnderlyingTx.Commit()
lt.db.Log.QueryTiming(lt.ctx, "Commit", "", nil, time.Since(start))
lt.db.Log.QueryTiming(lt.ctx, "Commit", "", nil, -1, time.Since(start))
return err
}
func (lt *LoggingTxn) Rollback() error {
start := time.Now()
err := lt.UnderlyingTx.Rollback()
lt.db.Log.QueryTiming(lt.ctx, "Rollback", "", nil, time.Since(start))
lt.db.Log.QueryTiming(lt.ctx, "Rollback", "", nil, -1, time.Since(start))
return err
}
type LoggingRows struct {
ctx context.Context
db *Database
query string
args []interface{}
rs Rows
start time.Time
nrows int
}
func (lrs *LoggingRows) stopTiming() {
if !lrs.start.IsZero() {
lrs.db.Log.QueryTiming(lrs.ctx, "EndRows", lrs.query, lrs.args, lrs.nrows, time.Since(lrs.start))
lrs.start = time.Time{}
}
}
func (lrs *LoggingRows) Close() error {
err := lrs.rs.Close()
lrs.stopTiming()
return err
}
func (lrs *LoggingRows) ColumnTypes() ([]*sql.ColumnType, error) {
return lrs.rs.ColumnTypes()
}
func (lrs *LoggingRows) Columns() ([]string, error) {
return lrs.rs.Columns()
}
func (lrs *LoggingRows) Err() error {
return lrs.rs.Err()
}
func (lrs *LoggingRows) Next() bool {
hasNext := lrs.rs.Next()
if !hasNext {
lrs.stopTiming()
} else {
lrs.nrows++
}
return hasNext
}
func (lrs *LoggingRows) NextResultSet() bool {
hasNext := lrs.rs.NextResultSet()
if !hasNext {
lrs.stopTiming()
} else {
lrs.nrows++
}
return hasNext
}
func (lrs *LoggingRows) Scan(dest ...any) error {
return lrs.rs.Scan(dest...)
}

View File

@@ -40,13 +40,23 @@ func ParseDialect(engine string) (Dialect, error) {
switch strings.ToLower(engine) {
case "postgres", "postgresql":
return Postgres, nil
case "sqlite3", "sqlite", "litestream":
case "sqlite3", "sqlite", "litestream", "sqlite3-fk-wal":
return SQLite, nil
default:
return DialectUnknown, fmt.Errorf("unknown dialect '%s'", engine)
}
}
type Rows interface {
Close() error
ColumnTypes() ([]*sql.ColumnType, error)
Columns() ([]string, error)
Err() error
Next() bool
NextResultSet() bool
Scan(...any) error
}
type Scannable interface {
Scan(...interface{}) error
}
@@ -54,19 +64,32 @@ type Scannable interface {
// Expected implementations of Scannable
var (
_ Scannable = (*sql.Row)(nil)
_ Scannable = (*sql.Rows)(nil)
_ Scannable = (Rows)(nil)
)
type ContextExecable interface {
type UnderlyingContextExecable interface {
ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error)
QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error)
QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row
}
type ContextExecable interface {
ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error)
QueryContext(ctx context.Context, query string, args ...interface{}) (Rows, error)
QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row
}
type UnderlyingExecable interface {
UnderlyingContextExecable
Exec(query string, args ...interface{}) (sql.Result, error)
Query(query string, args ...interface{}) (*sql.Rows, error)
QueryRow(query string, args ...interface{}) *sql.Row
}
type Execable interface {
ContextExecable
Exec(query string, args ...interface{}) (sql.Result, error)
Query(query string, args ...interface{}) (*sql.Rows, error)
Query(query string, args ...interface{}) (Rows, error)
QueryRow(query string, args ...interface{}) *sql.Row
}
@@ -78,11 +101,11 @@ type Transaction interface {
// Expected implementations of Execable
var (
_ Execable = (*sql.Tx)(nil)
_ Execable = (*sql.DB)(nil)
_ Execable = (*LoggingExecable)(nil)
_ Transaction = (*LoggingTxn)(nil)
_ ContextExecable = (*sql.Conn)(nil)
_ UnderlyingExecable = (*sql.Tx)(nil)
_ UnderlyingExecable = (*sql.DB)(nil)
_ Execable = (*LoggingExecable)(nil)
_ Transaction = (*LoggingTxn)(nil)
_ UnderlyingContextExecable = (*sql.Conn)(nil)
)
type Database struct {

View File

@@ -11,10 +11,10 @@ import (
)
type DatabaseLogger interface {
QueryTiming(ctx context.Context, method, query string, args []interface{}, duration time.Duration)
QueryTiming(ctx context.Context, method, query string, args []interface{}, nrows int, duration time.Duration)
WarnUnsupportedVersion(current, latest int)
PrepareUpgrade(current, latest int)
DoUpgrade(from, to int, message string)
DoUpgrade(from, to int, message string, txn bool)
// Deprecated: legacy warning method, return errors instead
Warn(msg string, args ...interface{})
}
@@ -25,10 +25,11 @@ var NoopLogger DatabaseLogger = &noopLogger{}
func (n noopLogger) WarnUnsupportedVersion(_, _ int) {}
func (n noopLogger) PrepareUpgrade(_, _ int) {}
func (n noopLogger) DoUpgrade(_, _ int, _ string) {}
func (n noopLogger) DoUpgrade(_, _ int, _ string, _ bool) {}
func (n noopLogger) Warn(msg string, args ...interface{}) {}
func (n noopLogger) QueryTiming(_ context.Context, _, _ string, _ []interface{}, _ time.Duration) {}
func (n noopLogger) QueryTiming(_ context.Context, _, _ string, _ []interface{}, _ int, _ time.Duration) {
}
type mauLogger struct {
l maulogger.Logger
@@ -46,11 +47,11 @@ func (m mauLogger) PrepareUpgrade(current, latest int) {
m.l.Infofln("Database currently on v%d, latest: v%d", current, latest)
}
func (m mauLogger) DoUpgrade(from, to int, message string) {
func (m mauLogger) DoUpgrade(from, to int, message string, _ bool) {
m.l.Infofln("Upgrading database from v%d to v%d: %s", from, to, message)
}
func (m mauLogger) QueryTiming(_ context.Context, method, query string, _ []interface{}, duration time.Duration) {
func (m mauLogger) QueryTiming(_ context.Context, method, query string, _ []interface{}, _ int, duration time.Duration) {
if duration > 1*time.Second {
m.l.Warnfln("%s(%s) took %.3f seconds", method, query, duration.Seconds())
}
@@ -90,17 +91,18 @@ func (z zeroLogger) PrepareUpgrade(current, latest int) {
}
}
func (z zeroLogger) DoUpgrade(from, to int, message string) {
func (z zeroLogger) DoUpgrade(from, to int, message string, txn bool) {
z.l.Info().
Int("from", from).
Int("to", to).
Bool("single_txn", txn).
Str("description", message).
Msg("Upgrading database")
}
var whitespaceRegex = regexp.MustCompile(`\s+`)
func (z zeroLogger) QueryTiming(ctx context.Context, method, query string, args []interface{}, duration time.Duration) {
func (z zeroLogger) QueryTiming(ctx context.Context, method, query string, args []interface{}, nrows int, duration time.Duration) {
log := zerolog.Ctx(ctx)
if log.GetLevel() == zerolog.Disabled {
log = z.l
@@ -108,6 +110,10 @@ func (z zeroLogger) QueryTiming(ctx context.Context, method, query string, args
if log.GetLevel() != zerolog.TraceLevel && duration < 1*time.Second {
return
}
if nrows > -1 {
rowLog := log.With().Int("rows", nrows).Logger()
log = &rowLog
}
query = strings.TrimSpace(whitespaceRegex.ReplaceAllLiteralString(query, " "))
log.Trace().
Int64("duration_µs", duration.Microseconds()).

View File

@@ -0,0 +1,4 @@
-- v4: Sample outside transaction
-- transaction: off
INSERT INTO foo VALUES ('meow', '{}');

View File

@@ -0,0 +1 @@
INSERT INTO foo VALUES ('meow', '{}');

View File

@@ -0,0 +1 @@
INSERT INTO foo VALUES ('meow', '{}');

View File

@@ -12,13 +12,14 @@ import (
"fmt"
)
type upgradeFunc func(Transaction, *Database) error
type upgradeFunc func(Execable, *Database) error
type upgrade struct {
message string
fn upgradeFunc
upgradesTo int
upgradesTo int
transaction bool
}
var ErrUnsupportedDatabaseVersion = fmt.Errorf("unsupported database schema version")
@@ -93,7 +94,7 @@ func (db *Database) checkDatabaseOwner() error {
return nil
}
func (db *Database) setVersion(tx Transaction, version int) error {
func (db *Database) setVersion(tx Execable, version int) error {
_, err := tx.Exec(fmt.Sprintf("DELETE FROM %s", db.VersionTable))
if err != nil {
return err
@@ -129,25 +130,33 @@ func (db *Database) Upgrade() error {
version++
continue
}
db.Log.DoUpgrade(logVersion, upgradeItem.upgradesTo, upgradeItem.message)
db.Log.DoUpgrade(logVersion, upgradeItem.upgradesTo, upgradeItem.message, upgradeItem.transaction)
var tx Transaction
tx, err = db.Begin()
if err != nil {
return err
var upgradeConn Execable
if upgradeItem.transaction {
tx, err = db.Begin()
if err != nil {
return err
}
upgradeConn = tx
} else {
upgradeConn = db
}
err = upgradeItem.fn(tx, db)
err = upgradeItem.fn(upgradeConn, db)
if err != nil {
return err
}
version = upgradeItem.upgradesTo
logVersion = version
err = db.setVersion(tx, version)
err = db.setVersion(upgradeConn, version)
if err != nil {
return err
}
err = tx.Commit()
if err != nil {
return err
if tx != nil {
err = tx.Commit()
if err != nil {
return err
}
}
}
return nil

View File

@@ -29,14 +29,14 @@ func (ut *UpgradeTable) extend(toSize int) {
}
}
func (ut *UpgradeTable) Register(from, to int, message string, fn upgradeFunc) {
func (ut *UpgradeTable) Register(from, to int, message string, txn bool, fn upgradeFunc) {
if from < 0 {
from += to
}
if from < 0 {
panic("invalid from value in UpgradeTable.Register() call")
}
upg := upgrade{message: message, fn: fn, upgradesTo: to}
upg := upgrade{message: message, fn: fn, upgradesTo: to, transaction: txn}
if len(*ut) == from {
*ut = append(*ut, upg)
return
@@ -57,7 +57,14 @@ func (ut *UpgradeTable) Register(from, to int, message string, fn upgradeFunc) {
// -- v1: Message
var upgradeHeaderRegex = regexp.MustCompile(`^-- (?:v(\d+) -> )?v(\d+): (.+)$`)
func parseFileHeader(file []byte) (from, to int, message string, lines [][]byte, err error) {
// To disable wrapping the upgrade in a single transaction, put `--transaction: off` on the second line.
//
// -- v5: Upgrade without transaction
// -- transaction: off
// // do dangerous stuff
var transactionDisableRegex = regexp.MustCompile(`^-- transaction: (\w*)`)
func parseFileHeader(file []byte) (from, to int, message string, txn bool, lines [][]byte, err error) {
lines = bytes.Split(file, []byte("\n"))
if len(lines) < 2 {
err = errors.New("upgrade file too short")
@@ -81,6 +88,15 @@ func parseFileHeader(file []byte) (from, to int, message string, lines [][]byte,
from = -1
}
message = string(match[3])
txn = true
match = transactionDisableRegex.FindSubmatch(lines[0])
if match != nil {
lines = lines[1:]
if string(match[1]) != "off" {
err = fmt.Errorf("invalid value %q for transaction flag", match[1])
}
txn = false
}
}
return
}
@@ -163,7 +179,7 @@ func (db *Database) filterSQLUpgrade(lines [][]byte) (string, error) {
}
func sqlUpgradeFunc(fileName string, lines [][]byte) upgradeFunc {
return func(tx Transaction, db *Database) error {
return func(tx Execable, db *Database) error {
if skip, err := db.parseDialectFilter(lines[0]); err == nil && skip == skipNextLine {
return nil
} else if upgradeSQL, err := db.filterSQLUpgrade(lines); err != nil {
@@ -176,7 +192,7 @@ func sqlUpgradeFunc(fileName string, lines [][]byte) upgradeFunc {
}
func splitSQLUpgradeFunc(sqliteData, postgresData string) upgradeFunc {
return func(tx Transaction, database *Database) (err error) {
return func(tx Execable, database *Database) (err error) {
switch database.Dialect {
case SQLite:
_, err = tx.Exec(sqliteData)
@@ -189,7 +205,7 @@ func splitSQLUpgradeFunc(sqliteData, postgresData string) upgradeFunc {
}
}
func parseSplitSQLUpgrade(name string, fs fullFS, skipNames map[string]struct{}) (from, to int, message string, fn upgradeFunc) {
func parseSplitSQLUpgrade(name string, fs fullFS, skipNames map[string]struct{}) (from, to int, message string, txn bool, fn upgradeFunc) {
postgresName := fmt.Sprintf("%s.postgres.sql", name)
sqliteName := fmt.Sprintf("%s.sqlite.sql", name)
skipNames[postgresName] = struct{}{}
@@ -202,11 +218,11 @@ func parseSplitSQLUpgrade(name string, fs fullFS, skipNames map[string]struct{})
if err != nil {
panic(err)
}
from, to, message, _, err = parseFileHeader(postgresData)
from, to, message, txn, _, err = parseFileHeader(postgresData)
if err != nil {
panic(fmt.Errorf("failed to parse header in %s: %w", postgresName, err))
}
sqliteFrom, sqliteTo, sqliteMessage, _, err := parseFileHeader(sqliteData)
sqliteFrom, sqliteTo, sqliteMessage, sqliteTxn, _, err := parseFileHeader(sqliteData)
if err != nil {
panic(fmt.Errorf("failed to parse header in %s: %w", sqliteName, err))
}
@@ -214,6 +230,8 @@ func parseSplitSQLUpgrade(name string, fs fullFS, skipNames map[string]struct{})
panic(fmt.Errorf("mismatching versions in postgres and sqlite versions of %s: %d/%d -> %d/%d", name, from, sqliteFrom, to, sqliteTo))
} else if message != sqliteMessage {
panic(fmt.Errorf("mismatching message in postgres and sqlite versions of %s: %q != %q", name, message, sqliteMessage))
} else if txn != sqliteTxn {
panic(fmt.Errorf("mismatching transaction flag in postgres and sqlite versions of %s: %t != %t", name, txn, sqliteTxn))
}
fn = splitSQLUpgradeFunc(string(sqliteData), string(postgresData))
return
@@ -242,14 +260,14 @@ func (ut *UpgradeTable) RegisterFSPath(fs fullFS, dir string) {
} else if _, skip := skipNames[file.Name()]; skip {
// also do nothing
} else if splitName := splitFileNameRegex.FindStringSubmatch(file.Name()); splitName != nil {
from, to, message, fn := parseSplitSQLUpgrade(splitName[1], fs, skipNames)
ut.Register(from, to, message, fn)
from, to, message, txn, fn := parseSplitSQLUpgrade(splitName[1], fs, skipNames)
ut.Register(from, to, message, txn, fn)
} else if data, err := fs.ReadFile(filepath.Join(dir, file.Name())); err != nil {
panic(err)
} else if from, to, message, lines, err := parseFileHeader(data); err != nil {
} else if from, to, message, txn, lines, err := parseFileHeader(data); err != nil {
panic(fmt.Errorf("failed to parse header in %s: %w", file.Name(), err))
} else {
ut.Register(from, to, message, sqlUpgradeFunc(file.Name(), lines))
ut.Register(from, to, message, txn, sqlUpgradeFunc(file.Name(), lines))
}
}
}