update deps; experiment: log security
This commit is contained in:
90
vendor/maunium.net/go/mautrix/util/dbutil/connlog.go
generated
vendored
90
vendor/maunium.net/go/mautrix/util/dbutil/connlog.go
generated
vendored
@@ -15,7 +15,7 @@ import (
|
||||
// LoggingExecable is a wrapper for anything with database Exec methods (i.e. sql.Conn, sql.DB and sql.Tx)
|
||||
// that can preprocess queries (e.g. replacing $ with ? on SQLite) and log query durations.
|
||||
type LoggingExecable struct {
|
||||
UnderlyingExecable Execable
|
||||
UnderlyingExecable UnderlyingExecable
|
||||
db *Database
|
||||
}
|
||||
|
||||
@@ -23,23 +23,30 @@ func (le *LoggingExecable) ExecContext(ctx context.Context, query string, args .
|
||||
start := time.Now()
|
||||
query = le.db.mutateQuery(query)
|
||||
res, err := le.UnderlyingExecable.ExecContext(ctx, query, args...)
|
||||
le.db.Log.QueryTiming(ctx, "Exec", query, args, time.Since(start))
|
||||
le.db.Log.QueryTiming(ctx, "Exec", query, args, -1, time.Since(start))
|
||||
return res, err
|
||||
}
|
||||
|
||||
func (le *LoggingExecable) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) {
|
||||
func (le *LoggingExecable) QueryContext(ctx context.Context, query string, args ...interface{}) (Rows, error) {
|
||||
start := time.Now()
|
||||
query = le.db.mutateQuery(query)
|
||||
rows, err := le.UnderlyingExecable.QueryContext(ctx, query, args...)
|
||||
le.db.Log.QueryTiming(ctx, "Query", query, args, time.Since(start))
|
||||
return rows, err
|
||||
le.db.Log.QueryTiming(ctx, "Query", query, args, -1, time.Since(start))
|
||||
return &LoggingRows{
|
||||
ctx: ctx,
|
||||
db: le.db,
|
||||
query: query,
|
||||
args: args,
|
||||
rs: rows,
|
||||
start: start,
|
||||
}, err
|
||||
}
|
||||
|
||||
func (le *LoggingExecable) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row {
|
||||
start := time.Now()
|
||||
query = le.db.mutateQuery(query)
|
||||
row := le.UnderlyingExecable.QueryRowContext(ctx, query, args...)
|
||||
le.db.Log.QueryTiming(ctx, "QueryRow", query, args, time.Since(start))
|
||||
le.db.Log.QueryTiming(ctx, "QueryRow", query, args, -1, time.Since(start))
|
||||
return row
|
||||
}
|
||||
|
||||
@@ -47,7 +54,7 @@ func (le *LoggingExecable) Exec(query string, args ...interface{}) (sql.Result,
|
||||
return le.ExecContext(context.Background(), query, args...)
|
||||
}
|
||||
|
||||
func (le *LoggingExecable) Query(query string, args ...interface{}) (*sql.Rows, error) {
|
||||
func (le *LoggingExecable) Query(query string, args ...interface{}) (Rows, error) {
|
||||
return le.QueryContext(context.Background(), query, args...)
|
||||
}
|
||||
|
||||
@@ -66,7 +73,7 @@ type loggingDB struct {
|
||||
func (ld *loggingDB) BeginTx(ctx context.Context, opts *sql.TxOptions) (*LoggingTxn, error) {
|
||||
start := time.Now()
|
||||
tx, err := ld.db.RawDB.BeginTx(ctx, opts)
|
||||
ld.db.Log.QueryTiming(ctx, "Begin", "", nil, time.Since(start))
|
||||
ld.db.Log.QueryTiming(ctx, "Begin", "", nil, -1, time.Since(start))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -90,13 +97,76 @@ type LoggingTxn struct {
|
||||
func (lt *LoggingTxn) Commit() error {
|
||||
start := time.Now()
|
||||
err := lt.UnderlyingTx.Commit()
|
||||
lt.db.Log.QueryTiming(lt.ctx, "Commit", "", nil, time.Since(start))
|
||||
lt.db.Log.QueryTiming(lt.ctx, "Commit", "", nil, -1, time.Since(start))
|
||||
return err
|
||||
}
|
||||
|
||||
func (lt *LoggingTxn) Rollback() error {
|
||||
start := time.Now()
|
||||
err := lt.UnderlyingTx.Rollback()
|
||||
lt.db.Log.QueryTiming(lt.ctx, "Rollback", "", nil, time.Since(start))
|
||||
lt.db.Log.QueryTiming(lt.ctx, "Rollback", "", nil, -1, time.Since(start))
|
||||
return err
|
||||
}
|
||||
|
||||
type LoggingRows struct {
|
||||
ctx context.Context
|
||||
db *Database
|
||||
query string
|
||||
args []interface{}
|
||||
rs Rows
|
||||
start time.Time
|
||||
nrows int
|
||||
}
|
||||
|
||||
func (lrs *LoggingRows) stopTiming() {
|
||||
if !lrs.start.IsZero() {
|
||||
lrs.db.Log.QueryTiming(lrs.ctx, "EndRows", lrs.query, lrs.args, lrs.nrows, time.Since(lrs.start))
|
||||
lrs.start = time.Time{}
|
||||
}
|
||||
}
|
||||
|
||||
func (lrs *LoggingRows) Close() error {
|
||||
err := lrs.rs.Close()
|
||||
lrs.stopTiming()
|
||||
return err
|
||||
}
|
||||
|
||||
func (lrs *LoggingRows) ColumnTypes() ([]*sql.ColumnType, error) {
|
||||
return lrs.rs.ColumnTypes()
|
||||
}
|
||||
|
||||
func (lrs *LoggingRows) Columns() ([]string, error) {
|
||||
return lrs.rs.Columns()
|
||||
}
|
||||
|
||||
func (lrs *LoggingRows) Err() error {
|
||||
return lrs.rs.Err()
|
||||
}
|
||||
|
||||
func (lrs *LoggingRows) Next() bool {
|
||||
hasNext := lrs.rs.Next()
|
||||
|
||||
if !hasNext {
|
||||
lrs.stopTiming()
|
||||
} else {
|
||||
lrs.nrows++
|
||||
}
|
||||
|
||||
return hasNext
|
||||
}
|
||||
|
||||
func (lrs *LoggingRows) NextResultSet() bool {
|
||||
hasNext := lrs.rs.NextResultSet()
|
||||
|
||||
if !hasNext {
|
||||
lrs.stopTiming()
|
||||
} else {
|
||||
lrs.nrows++
|
||||
}
|
||||
|
||||
return hasNext
|
||||
}
|
||||
|
||||
func (lrs *LoggingRows) Scan(dest ...any) error {
|
||||
return lrs.rs.Scan(dest...)
|
||||
}
|
||||
|
||||
Reference in New Issue
Block a user