update deps; experiment: log security

This commit is contained in:
Aine
2022-11-16 23:00:58 +02:00
parent 225ba2ee9b
commit 99a89ef87a
55 changed files with 883 additions and 308 deletions

View File

@@ -1,3 +1,30 @@
## v0.12.3 (2022-11-16)
* **Breaking change:** Added logging for row iteration in the dbutil package.
This changes the return type of `Query` methods from `*sql.Rows` to a new
`dbutil.Rows` interface.
* Added flag to disable wrapping database upgrades in a transaction (e.g. to
allow setting `PRAGMA`s for advanced table mutations on SQLite).
* Deprecated `MessageEventContent.GetReplyTo` in favor of directly using
`RelatesTo.GetReplyTo`. RelatesTo methods are nil-safe, so checking if
RelatesTo is nil is not necessary for using those methods.
* Added wrapper for space hierarchyendpoint (thanks to [@mgcm] in [#100]).
* Added bridge config option to handle transactions asynchronously.
* Added separate channels for to-device events in appservice transaction
handler to avoid blocking to-device events behind normal events.
* Added `RelatesTo.GetNonFallbackReplyTo` utility method to get the reply event
ID, unless the reply is a thread fallback.
* Added `event.TextToHTML` as an utility method to HTML-escape a string and
replace newlines with `<br/>`.
* Added check to bridge encryption helper to make sure the e2ee keys are still
on the server. Synapse is known to sometimes lose keys randomly.
* Changed bridge crypto syncer to crash on `M_UNKNOWN_TOKEN` errors instead of
retrying forever pointlessly.
* Fixed verifying signatures of fallback one-time keys.
[@mgcm]: https://github.com/mgcm
[#100]: https://github.com/mautrix/go/pull/100
## v0.12.2 (2022-10-16)
* Added utility method to redact bridge commands.

View File

@@ -96,11 +96,12 @@ type AppService struct {
txnIDC *TransactionIDCache
Events chan *event.Event `yaml:"-"`
DeviceLists chan *mautrix.DeviceLists `yaml:"-"`
OTKCounts chan *mautrix.OTKCount `yaml:"-"`
QueryHandler QueryHandler `yaml:"-"`
StateStore StateStore `yaml:"-"`
Events chan *event.Event `yaml:"-"`
ToDeviceEvents chan *event.Event `yaml:"-"`
DeviceLists chan *mautrix.DeviceLists `yaml:"-"`
OTKCounts chan *mautrix.OTKCount `yaml:"-"`
QueryHandler QueryHandler `yaml:"-"`
StateStore StateStore `yaml:"-"`
Router *mux.Router `yaml:"-"`
UserAgent string `yaml:"-"`
@@ -275,6 +276,7 @@ func (as *AppService) BotClient() *mautrix.Client {
// Init initializes the logger and loads the registration of this appservice.
func (as *AppService) Init() (bool, error) {
as.Events = make(chan *event.Event, EventChannelSize)
as.ToDeviceEvents = make(chan *event.Event, EventChannelSize)
as.OTKCounts = make(chan *mautrix.OTKCount, OTKChannelSize)
as.DeviceLists = make(chan *mautrix.DeviceLists, EventChannelSize)
as.QueryHandler = &QueryHandlerStub{}

View File

@@ -137,12 +137,22 @@ func (ep *EventProcessor) Dispatch(evt *event.Event) {
}
}
}
func (ep *EventProcessor) Start() {
func (ep *EventProcessor) startEvents() {
for {
select {
case evt := <-ep.as.Events:
ep.Dispatch(evt)
case <-ep.stop:
return
}
}
}
func (ep *EventProcessor) startEncryption() {
for {
select {
case evt := <-ep.as.ToDeviceEvents:
ep.Dispatch(evt)
case otk := <-ep.as.OTKCounts:
ep.DispatchOTK(otk)
case dl := <-ep.as.DeviceLists:
@@ -153,6 +163,11 @@ func (ep *EventProcessor) Start() {
}
}
func (ep *EventProcessor) Stop() {
ep.stop <- struct{}{}
func (ep *EventProcessor) Start() {
go ep.startEvents()
go ep.startEncryption()
}
func (ep *EventProcessor) Stop() {
close(ep.stop)
}

View File

@@ -206,13 +206,19 @@ func (as *AppService) handleEvents(evts []*event.Event, defaultTypeClass event.T
}
if evt.Type.IsState() {
// TODO remove this check after https://github.com/matrix-org/synapse/pull/11265
// TODO remove this check after making sure the log doesn't happen
historical, ok := evt.Content.Raw["org.matrix.msc2716.historical"].(bool)
if !ok || !historical {
if ok && historical {
as.Log.Warnfln("Received historical state event %s (%s/%s)", evt.ID, evt.Type.Type, evt.GetStateKey())
} else {
as.UpdateState(evt)
}
}
as.Events <- evt
if evt.Type.Class == event.ToDeviceEventType {
as.ToDeviceEvents <- evt
} else {
as.Events <- evt
}
}
}

View File

@@ -58,16 +58,17 @@ type AppserviceConfig struct {
ASToken string `yaml:"as_token"`
HSToken string `yaml:"hs_token"`
EphemeralEvents bool `yaml:"ephemeral_events"`
EphemeralEvents bool `yaml:"ephemeral_events"`
AsyncTransactions bool `yaml:"async_transactions"`
}
func (config *BaseConfig) MakeUserIDRegex() *regexp.Regexp {
usernamePlaceholder := util.RandomString(16)
func (config *BaseConfig) MakeUserIDRegex(matcher string) *regexp.Regexp {
usernamePlaceholder := strings.ToLower(util.RandomString(16))
usernameTemplate := fmt.Sprintf("@%s:%s",
config.Bridge.FormatUsername(usernamePlaceholder),
config.Homeserver.Domain)
usernameTemplate = regexp.QuoteMeta(usernameTemplate)
usernameTemplate = strings.Replace(usernameTemplate, usernamePlaceholder, ".+", 1)
usernameTemplate = strings.Replace(usernameTemplate, usernamePlaceholder, matcher, 1)
usernameTemplate = fmt.Sprintf("^%s$", usernameTemplate)
return regexp.MustCompile(usernameTemplate)
}
@@ -84,7 +85,7 @@ func (config *BaseConfig) GenerateRegistration() *appservice.Registration {
regexp.QuoteMeta(config.AppService.Bot.Username),
regexp.QuoteMeta(config.Homeserver.Domain)))
registration.Namespaces.UserIDs.Register(botRegex, true)
registration.Namespaces.UserIDs.Register(config.MakeUserIDRegex(), true)
registration.Namespaces.UserIDs.Register(config.MakeUserIDRegex(".*"), true)
return registration
}
@@ -230,6 +231,7 @@ func doUpgrade(helper *up.Helper) {
helper.Copy(up.Str, "appservice", "bot", "displayname")
helper.Copy(up.Str, "appservice", "bot", "avatar")
helper.Copy(up.Bool, "appservice", "ephemeral_events")
helper.Copy(up.Bool, "appservice", "async_transactions")
helper.Copy(up.Str, "appservice", "as_token")
helper.Copy(up.Str, "appservice", "hs_token")

View File

@@ -1468,6 +1468,18 @@ func (cli *Client) JoinedRooms() (resp *RespJoinedRooms, err error) {
return
}
// Hierarchy returns a list of rooms that are in the room's hierarchy. See https://spec.matrix.org/v1.4/client-server-api/#get_matrixclientv1roomsroomidhierarchy
//
// The hierarchy API is provided to walk the space tree and discover the rooms with their aesthetic details. works in a depth-first manner:
// when it encounters another space as a child it recurses into that space before returning non-space children.
//
// The second function parameter specifies query parameters to limit the response. No query parameters will be added if it's nil.
func (cli *Client) Hierarchy(roomID id.RoomID, req *ReqHierarchy) (resp *RespHierarchy, err error) {
urlPath := cli.BuildURLWithQuery(ClientURLPath{"v1", "rooms", roomID, "hierarchy"}, req.Query())
_, err = cli.MakeRequest(http.MethodGet, urlPath, nil, &resp)
return
}
// Messages returns a list of message and state events for a room. It uses
// pagination query parameters to paginate history in the room.
// See https://spec.matrix.org/v1.2/client-server-api/#get_matrixclientv3roomsroomidmessages
@@ -1760,6 +1772,9 @@ func (cli *Client) BatchSend(roomID id.RoomID, req *ReqBatchSend) (resp *RespBat
if req.BeeperNewMessages {
query["com.beeper.new_messages"] = "true"
}
if req.BeeperMarkReadBy != "" {
query["com.beeper.mark_read_by"] = req.BeeperMarkReadBy.String()
}
if len(req.BatchID) > 0 {
query["batch_id"] = req.BatchID.String()
}

View File

@@ -97,7 +97,7 @@ func (mach *OlmMachine) createOutboundSessions(input map[id.UserID]map[id.Device
continue
}
identity := input[userID][deviceID]
if ok, err := olm.VerifySignatureJSON(oneTimeKey, userID, deviceID.String(), identity.SigningKey); err != nil {
if ok, err := olm.VerifySignatureJSON(oneTimeKey.RawData, userID, deviceID.String(), identity.SigningKey); err != nil {
mach.Log.Error("Failed to verify signature for %s of %s: %v", deviceID, userID, err)
} else if !ok {
mach.Log.Warn("Invalid signature for %s of %s", deviceID, userID)

View File

@@ -445,15 +445,20 @@ func (mach *OlmMachine) WaitForSession(roomID id.RoomID, senderKey id.SenderKey,
mach.keyWaitersLock.Lock()
ch, ok := mach.keyWaiters[sessionID]
if !ok {
ch := make(chan struct{})
ch = make(chan struct{})
mach.keyWaiters[sessionID] = ch
}
mach.keyWaitersLock.Unlock()
// Handle race conditions where a session appears between the failed decryption and WaitForSession call.
sess, err := mach.CryptoStore.GetGroupSession(roomID, senderKey, sessionID)
if sess != nil || errors.Is(err, ErrGroupSessionWithheld) {
return true
}
select {
case <-ch:
return true
case <-time.After(timeout):
sess, err := mach.CryptoStore.GetGroupSession(roomID, senderKey, sessionID)
sess, err = mach.CryptoStore.GetGroupSession(roomID, senderKey, sessionID)
// Check if the session somehow appeared in the store without telling us
// We accept withheld sessions as received, as then the decryption attempt will show the error.
return sess != nil || errors.Is(err, ErrGroupSessionWithheld)

View File

@@ -107,9 +107,13 @@ func (u *Utility) VerifySignature(message string, key id.Ed25519, signature stri
// https://matrix.org/speculator/spec/drafts%2Fe2e/appendices.html#signing-json
// If the _obj is a struct, the `json` tags will be honored.
func (u *Utility) VerifySignatureJSON(obj interface{}, userID id.UserID, keyName string, key id.Ed25519) (bool, error) {
objJSON, err := json.Marshal(obj)
if err != nil {
return false, err
var err error
objJSON, ok := obj.(json.RawMessage)
if !ok {
objJSON, err = json.Marshal(obj)
if err != nil {
return false, err
}
}
sig := gjson.GetBytes(objJSON, util.GJSONPath("signatures", string(userID), fmt.Sprintf("ed25519:%s", keyName)))
if !sig.Exists() || sig.Type != gjson.String {

View File

@@ -302,7 +302,7 @@ func (store *SQLCryptoStore) GetWithheldGroupSession(roomID id.RoomID, senderKey
}, nil
}
func (store *SQLCryptoStore) scanGroupSessionList(rows *sql.Rows) (result []*InboundGroupSession, err error) {
func (store *SQLCryptoStore) scanGroupSessionList(rows dbutil.Rows) (result []*InboundGroupSession, err error) {
for rows.Next() {
var roomID id.RoomID
var signingKey, senderKey, forwardingChains sql.NullString
@@ -577,7 +577,7 @@ func (store *SQLCryptoStore) PutDevices(userID id.UserID, devices map[id.DeviceI
// FilterTrackedUsers finds all the user IDs out of the given ones for which the database contains identity information.
func (store *SQLCryptoStore) FilterTrackedUsers(users []id.UserID) ([]id.UserID, error) {
var rows *sql.Rows
var rows dbutil.Rows
var err error
if store.DB.Dialect == dbutil.Postgres && PostgresArrayWrapper != nil {
rows, err = store.DB.Query("SELECT user_id FROM crypto_tracked_user WHERE user_id = ANY($1)", PostgresArrayWrapper(users))

View File

@@ -22,7 +22,7 @@ const VersionTableName = "crypto_version"
var fs embed.FS
func init() {
Table.Register(-1, 3, "Unsupported version", func(tx dbutil.Transaction, database *dbutil.Database) error {
Table.Register(-1, 3, "Unsupported version", false, func(tx dbutil.Execable, database *dbutil.Database) error {
return fmt.Errorf("upgrading from versions 1 and 2 of the crypto store is no longer supported in mautrix-go v0.12+")
})
Table.RegisterFS(fs)

View File

@@ -58,6 +58,8 @@ var (
// The client attempted to join a room that has a version the server does not support.
// Inspect the room_version property of the error response for the room's version.
MIncompatibleRoomVersion = RespError{ErrCode: "M_INCOMPATIBLE_ROOM_VERSION"}
// The client specified a parameter that has the wrong value.
MInvalidParam = RespError{ErrCode: "M_INVALID_PARAM"}
)
// HTTPError An HTTP Error response, which may wrap an underlying native Go Error.

View File

@@ -124,9 +124,10 @@ func (evt *Event) GetStateKey() string {
}
type StrippedState struct {
Content Content `json:"content"`
Type Type `json:"type"`
StateKey string `json:"state_key"`
Content Content `json:"content"`
Type Type `json:"type"`
StateKey string `json:"state_key"`
Sender id.UserID `json:"sender"`
}
type Unsigned struct {

View File

@@ -138,9 +138,13 @@ func (content *MessageEventContent) SetEdit(original id.EventID) {
}
}
func TextToHTML(text string) string {
return strings.ReplaceAll(html.EscapeString(text), "\n", "<br/>")
}
func (content *MessageEventContent) EnsureHasHTML() {
if len(content.FormattedBody) == 0 || content.Format != FormatHTML {
content.FormattedBody = strings.ReplaceAll(html.EscapeString(content.Body), "\n", "<br/>")
content.FormattedBody = TextToHTML(content.Body)
content.Format = FormatHTML
}
}

View File

@@ -70,6 +70,13 @@ func (rel *RelatesTo) GetReplyTo() id.EventID {
return ""
}
func (rel *RelatesTo) GetNonFallbackReplyTo() id.EventID {
if rel != nil && rel.InReplyTo != nil && !rel.IsFallingBack {
return rel.InReplyTo.EventID
}
return ""
}
func (rel *RelatesTo) GetAnnotationID() id.EventID {
if rel != nil && rel.Type == RelAnnotation {
return rel.EventID

View File

@@ -35,7 +35,7 @@ func TrimReplyFallbackText(text string) string {
}
func (content *MessageEventContent) RemoveReplyFallback() {
if len(content.GetReplyTo()) > 0 && !content.replyFallbackRemoved {
if len(content.RelatesTo.GetReplyTo()) > 0 && !content.replyFallbackRemoved {
if content.Format == FormatHTML {
content.FormattedBody = TrimReplyFallbackHTML(content.FormattedBody)
}
@@ -44,11 +44,9 @@ func (content *MessageEventContent) RemoveReplyFallback() {
}
}
// Deprecated: RelatesTo methods are nil-safe, so RelatesTo.GetReplyTo can be used directly
func (content *MessageEventContent) GetReplyTo() id.EventID {
if content.RelatesTo != nil {
return content.RelatesTo.GetReplyTo()
}
return ""
return content.RelatesTo.GetReplyTo()
}
const ReplyFormat = `<mx-reply><blockquote><a href="https://matrix.to/#/%s/%s">In reply to</a> <a href="https://matrix.to/#/%s">%s</a><br>%s</blockquote></mx-reply>`

View File

@@ -2,6 +2,7 @@ package mautrix
import (
"encoding/json"
"strconv"
"maunium.net/go/mautrix/event"
"maunium.net/go/mautrix/id"
@@ -172,10 +173,15 @@ type ReqAliasCreate struct {
}
type OneTimeKey struct {
Key id.Curve25519 `json:"key"`
IsSigned bool `json:"-"`
Signatures Signatures `json:"signatures,omitempty"`
Unsigned map[string]interface{} `json:"unsigned,omitempty"`
Key id.Curve25519 `json:"key"`
Fallback bool `json:"fallback,omitempty"`
Signatures Signatures `json:"signatures,omitempty"`
Unsigned map[string]any `json:"unsigned,omitempty"`
IsSigned bool `json:"-"`
// Raw data in the one-time key. This must be used for signature verification to ensure unrecognized fields
// aren't thrown away (because that would invalidate the signature).
RawData json.RawMessage `json:"-"`
}
type serializableOTK OneTimeKey
@@ -188,6 +194,7 @@ func (otk *OneTimeKey) UnmarshalJSON(data []byte) (err error) {
otk.IsSigned = false
} else {
err = json.Unmarshal(data, (*serializableOTK)(otk))
otk.RawData = data
otk.IsSigned = true
}
return err
@@ -319,7 +326,8 @@ type ReqBatchSend struct {
PrevEventID id.EventID `json:"-"`
BatchID id.BatchID `json:"-"`
BeeperNewMessages bool `json:"-"`
BeeperNewMessages bool `json:"-"`
BeeperMarkReadBy id.UserID `json:"-"`
StateEventsAtStart []*event.Event `json:"state_events_at_start"`
Events []*event.Event `json:"events"`
@@ -334,3 +342,41 @@ type ReqSetReadMarkers struct {
BeeperReadPrivateExtra interface{} `json:"com.beeper.read.private.extra"`
BeeperFullyReadExtra interface{} `json:"com.beeper.fully_read.extra"`
}
// ReqHierarchy contains the parameters for https://spec.matrix.org/v1.4/client-server-api/#get_matrixclientv1roomsroomidhierarchy
//
// As it's a GET method, there is no JSON body, so this is only query parameters.
type ReqHierarchy struct {
// A pagination token from a previous Hierarchy call.
// If specified, max_depth and suggested_only cannot be changed from the first request.
From string
// Limit for the maximum number of rooms to include per response.
// The server will apply a default value if a limit isn't provided.
Limit int
// Limit for how far to go into the space. When reached, no further child rooms will be returned.
// The server will apply a default value if a max depth isn't provided.
MaxDepth *int
// Flag to indicate whether the server should only consider suggested rooms.
// Suggested rooms are annotated in their m.space.child event contents.
SuggestedOnly bool
}
func (req *ReqHierarchy) Query() map[string]string {
query := map[string]string{}
if req == nil {
return query
}
if req.From != "" {
query["from"] = req.From
}
if req.Limit > 0 {
query["limit"] = strconv.Itoa(req.Limit)
}
if req.MaxDepth != nil {
query["max_depth"] = strconv.Itoa(*req.MaxDepth)
}
if req.SuggestedOnly {
query["suggested_only"] = "true"
}
return query
}

View File

@@ -12,6 +12,7 @@ import (
"maunium.net/go/mautrix/event"
"maunium.net/go/mautrix/id"
"maunium.net/go/mautrix/util"
"maunium.net/go/mautrix/util/jsontime"
)
// RespWhoami is the JSON response for https://spec.matrix.org/v1.2/client-server-api/#get_matrixclientv3accountwhoami
@@ -514,3 +515,28 @@ func (vers *CapRoomVersions) IsAvailable(version string) bool {
_, available := vers.Available[version]
return available
}
// RespHierarchy is the JSON response for https://spec.matrix.org/v1.4/client-server-api/#get_matrixclientv1roomsroomidhierarchy
type RespHierarchy struct {
NextBatch string `json:"next_batch,omitempty"`
Rooms []ChildRoomsChunk `json:"rooms"`
}
type ChildRoomsChunk struct {
AvatarURL id.ContentURI `json:"avatar_url,omitempty"`
CanonicalAlias id.RoomAlias `json:"canonical_alias,omitempty"`
ChildrenState []StrippedStateWithTime `json:"children_state"`
GuestCanJoin bool `json:"guest_can_join"`
JoinRule event.JoinRule `json:"join_rule,omitempty"`
Name string `json:"name,omitempty"`
NumJoinedMembers int `json:"num_joined_members"`
RoomID id.RoomID `json:"room_id"`
RoomType event.RoomType `json:"room_type"`
Topic string `json:"topic,omitempty"`
WorldReadble bool `json:"world_readable"`
}
type StrippedStateWithTime struct {
event.StrippedState
Timestamp jsontime.UnixMilli `json:"origin_server_ts"`
}

View File

@@ -15,7 +15,7 @@ import (
// LoggingExecable is a wrapper for anything with database Exec methods (i.e. sql.Conn, sql.DB and sql.Tx)
// that can preprocess queries (e.g. replacing $ with ? on SQLite) and log query durations.
type LoggingExecable struct {
UnderlyingExecable Execable
UnderlyingExecable UnderlyingExecable
db *Database
}
@@ -23,23 +23,30 @@ func (le *LoggingExecable) ExecContext(ctx context.Context, query string, args .
start := time.Now()
query = le.db.mutateQuery(query)
res, err := le.UnderlyingExecable.ExecContext(ctx, query, args...)
le.db.Log.QueryTiming(ctx, "Exec", query, args, time.Since(start))
le.db.Log.QueryTiming(ctx, "Exec", query, args, -1, time.Since(start))
return res, err
}
func (le *LoggingExecable) QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error) {
func (le *LoggingExecable) QueryContext(ctx context.Context, query string, args ...interface{}) (Rows, error) {
start := time.Now()
query = le.db.mutateQuery(query)
rows, err := le.UnderlyingExecable.QueryContext(ctx, query, args...)
le.db.Log.QueryTiming(ctx, "Query", query, args, time.Since(start))
return rows, err
le.db.Log.QueryTiming(ctx, "Query", query, args, -1, time.Since(start))
return &LoggingRows{
ctx: ctx,
db: le.db,
query: query,
args: args,
rs: rows,
start: start,
}, err
}
func (le *LoggingExecable) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row {
start := time.Now()
query = le.db.mutateQuery(query)
row := le.UnderlyingExecable.QueryRowContext(ctx, query, args...)
le.db.Log.QueryTiming(ctx, "QueryRow", query, args, time.Since(start))
le.db.Log.QueryTiming(ctx, "QueryRow", query, args, -1, time.Since(start))
return row
}
@@ -47,7 +54,7 @@ func (le *LoggingExecable) Exec(query string, args ...interface{}) (sql.Result,
return le.ExecContext(context.Background(), query, args...)
}
func (le *LoggingExecable) Query(query string, args ...interface{}) (*sql.Rows, error) {
func (le *LoggingExecable) Query(query string, args ...interface{}) (Rows, error) {
return le.QueryContext(context.Background(), query, args...)
}
@@ -66,7 +73,7 @@ type loggingDB struct {
func (ld *loggingDB) BeginTx(ctx context.Context, opts *sql.TxOptions) (*LoggingTxn, error) {
start := time.Now()
tx, err := ld.db.RawDB.BeginTx(ctx, opts)
ld.db.Log.QueryTiming(ctx, "Begin", "", nil, time.Since(start))
ld.db.Log.QueryTiming(ctx, "Begin", "", nil, -1, time.Since(start))
if err != nil {
return nil, err
}
@@ -90,13 +97,76 @@ type LoggingTxn struct {
func (lt *LoggingTxn) Commit() error {
start := time.Now()
err := lt.UnderlyingTx.Commit()
lt.db.Log.QueryTiming(lt.ctx, "Commit", "", nil, time.Since(start))
lt.db.Log.QueryTiming(lt.ctx, "Commit", "", nil, -1, time.Since(start))
return err
}
func (lt *LoggingTxn) Rollback() error {
start := time.Now()
err := lt.UnderlyingTx.Rollback()
lt.db.Log.QueryTiming(lt.ctx, "Rollback", "", nil, time.Since(start))
lt.db.Log.QueryTiming(lt.ctx, "Rollback", "", nil, -1, time.Since(start))
return err
}
type LoggingRows struct {
ctx context.Context
db *Database
query string
args []interface{}
rs Rows
start time.Time
nrows int
}
func (lrs *LoggingRows) stopTiming() {
if !lrs.start.IsZero() {
lrs.db.Log.QueryTiming(lrs.ctx, "EndRows", lrs.query, lrs.args, lrs.nrows, time.Since(lrs.start))
lrs.start = time.Time{}
}
}
func (lrs *LoggingRows) Close() error {
err := lrs.rs.Close()
lrs.stopTiming()
return err
}
func (lrs *LoggingRows) ColumnTypes() ([]*sql.ColumnType, error) {
return lrs.rs.ColumnTypes()
}
func (lrs *LoggingRows) Columns() ([]string, error) {
return lrs.rs.Columns()
}
func (lrs *LoggingRows) Err() error {
return lrs.rs.Err()
}
func (lrs *LoggingRows) Next() bool {
hasNext := lrs.rs.Next()
if !hasNext {
lrs.stopTiming()
} else {
lrs.nrows++
}
return hasNext
}
func (lrs *LoggingRows) NextResultSet() bool {
hasNext := lrs.rs.NextResultSet()
if !hasNext {
lrs.stopTiming()
} else {
lrs.nrows++
}
return hasNext
}
func (lrs *LoggingRows) Scan(dest ...any) error {
return lrs.rs.Scan(dest...)
}

View File

@@ -40,13 +40,23 @@ func ParseDialect(engine string) (Dialect, error) {
switch strings.ToLower(engine) {
case "postgres", "postgresql":
return Postgres, nil
case "sqlite3", "sqlite", "litestream":
case "sqlite3", "sqlite", "litestream", "sqlite3-fk-wal":
return SQLite, nil
default:
return DialectUnknown, fmt.Errorf("unknown dialect '%s'", engine)
}
}
type Rows interface {
Close() error
ColumnTypes() ([]*sql.ColumnType, error)
Columns() ([]string, error)
Err() error
Next() bool
NextResultSet() bool
Scan(...any) error
}
type Scannable interface {
Scan(...interface{}) error
}
@@ -54,19 +64,32 @@ type Scannable interface {
// Expected implementations of Scannable
var (
_ Scannable = (*sql.Row)(nil)
_ Scannable = (*sql.Rows)(nil)
_ Scannable = (Rows)(nil)
)
type ContextExecable interface {
type UnderlyingContextExecable interface {
ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error)
QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error)
QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row
}
type ContextExecable interface {
ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error)
QueryContext(ctx context.Context, query string, args ...interface{}) (Rows, error)
QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row
}
type UnderlyingExecable interface {
UnderlyingContextExecable
Exec(query string, args ...interface{}) (sql.Result, error)
Query(query string, args ...interface{}) (*sql.Rows, error)
QueryRow(query string, args ...interface{}) *sql.Row
}
type Execable interface {
ContextExecable
Exec(query string, args ...interface{}) (sql.Result, error)
Query(query string, args ...interface{}) (*sql.Rows, error)
Query(query string, args ...interface{}) (Rows, error)
QueryRow(query string, args ...interface{}) *sql.Row
}
@@ -78,11 +101,11 @@ type Transaction interface {
// Expected implementations of Execable
var (
_ Execable = (*sql.Tx)(nil)
_ Execable = (*sql.DB)(nil)
_ Execable = (*LoggingExecable)(nil)
_ Transaction = (*LoggingTxn)(nil)
_ ContextExecable = (*sql.Conn)(nil)
_ UnderlyingExecable = (*sql.Tx)(nil)
_ UnderlyingExecable = (*sql.DB)(nil)
_ Execable = (*LoggingExecable)(nil)
_ Transaction = (*LoggingTxn)(nil)
_ UnderlyingContextExecable = (*sql.Conn)(nil)
)
type Database struct {

View File

@@ -11,10 +11,10 @@ import (
)
type DatabaseLogger interface {
QueryTiming(ctx context.Context, method, query string, args []interface{}, duration time.Duration)
QueryTiming(ctx context.Context, method, query string, args []interface{}, nrows int, duration time.Duration)
WarnUnsupportedVersion(current, latest int)
PrepareUpgrade(current, latest int)
DoUpgrade(from, to int, message string)
DoUpgrade(from, to int, message string, txn bool)
// Deprecated: legacy warning method, return errors instead
Warn(msg string, args ...interface{})
}
@@ -25,10 +25,11 @@ var NoopLogger DatabaseLogger = &noopLogger{}
func (n noopLogger) WarnUnsupportedVersion(_, _ int) {}
func (n noopLogger) PrepareUpgrade(_, _ int) {}
func (n noopLogger) DoUpgrade(_, _ int, _ string) {}
func (n noopLogger) DoUpgrade(_, _ int, _ string, _ bool) {}
func (n noopLogger) Warn(msg string, args ...interface{}) {}
func (n noopLogger) QueryTiming(_ context.Context, _, _ string, _ []interface{}, _ time.Duration) {}
func (n noopLogger) QueryTiming(_ context.Context, _, _ string, _ []interface{}, _ int, _ time.Duration) {
}
type mauLogger struct {
l maulogger.Logger
@@ -46,11 +47,11 @@ func (m mauLogger) PrepareUpgrade(current, latest int) {
m.l.Infofln("Database currently on v%d, latest: v%d", current, latest)
}
func (m mauLogger) DoUpgrade(from, to int, message string) {
func (m mauLogger) DoUpgrade(from, to int, message string, _ bool) {
m.l.Infofln("Upgrading database from v%d to v%d: %s", from, to, message)
}
func (m mauLogger) QueryTiming(_ context.Context, method, query string, _ []interface{}, duration time.Duration) {
func (m mauLogger) QueryTiming(_ context.Context, method, query string, _ []interface{}, _ int, duration time.Duration) {
if duration > 1*time.Second {
m.l.Warnfln("%s(%s) took %.3f seconds", method, query, duration.Seconds())
}
@@ -90,17 +91,18 @@ func (z zeroLogger) PrepareUpgrade(current, latest int) {
}
}
func (z zeroLogger) DoUpgrade(from, to int, message string) {
func (z zeroLogger) DoUpgrade(from, to int, message string, txn bool) {
z.l.Info().
Int("from", from).
Int("to", to).
Bool("single_txn", txn).
Str("description", message).
Msg("Upgrading database")
}
var whitespaceRegex = regexp.MustCompile(`\s+`)
func (z zeroLogger) QueryTiming(ctx context.Context, method, query string, args []interface{}, duration time.Duration) {
func (z zeroLogger) QueryTiming(ctx context.Context, method, query string, args []interface{}, nrows int, duration time.Duration) {
log := zerolog.Ctx(ctx)
if log.GetLevel() == zerolog.Disabled {
log = z.l
@@ -108,6 +110,10 @@ func (z zeroLogger) QueryTiming(ctx context.Context, method, query string, args
if log.GetLevel() != zerolog.TraceLevel && duration < 1*time.Second {
return
}
if nrows > -1 {
rowLog := log.With().Int("rows", nrows).Logger()
log = &rowLog
}
query = strings.TrimSpace(whitespaceRegex.ReplaceAllLiteralString(query, " "))
log.Trace().
Int64("duration_µs", duration.Microseconds()).

View File

@@ -0,0 +1,4 @@
-- v4: Sample outside transaction
-- transaction: off
INSERT INTO foo VALUES ('meow', '{}');

View File

@@ -0,0 +1 @@
INSERT INTO foo VALUES ('meow', '{}');

View File

@@ -0,0 +1 @@
INSERT INTO foo VALUES ('meow', '{}');

View File

@@ -12,13 +12,14 @@ import (
"fmt"
)
type upgradeFunc func(Transaction, *Database) error
type upgradeFunc func(Execable, *Database) error
type upgrade struct {
message string
fn upgradeFunc
upgradesTo int
upgradesTo int
transaction bool
}
var ErrUnsupportedDatabaseVersion = fmt.Errorf("unsupported database schema version")
@@ -93,7 +94,7 @@ func (db *Database) checkDatabaseOwner() error {
return nil
}
func (db *Database) setVersion(tx Transaction, version int) error {
func (db *Database) setVersion(tx Execable, version int) error {
_, err := tx.Exec(fmt.Sprintf("DELETE FROM %s", db.VersionTable))
if err != nil {
return err
@@ -129,25 +130,33 @@ func (db *Database) Upgrade() error {
version++
continue
}
db.Log.DoUpgrade(logVersion, upgradeItem.upgradesTo, upgradeItem.message)
db.Log.DoUpgrade(logVersion, upgradeItem.upgradesTo, upgradeItem.message, upgradeItem.transaction)
var tx Transaction
tx, err = db.Begin()
if err != nil {
return err
var upgradeConn Execable
if upgradeItem.transaction {
tx, err = db.Begin()
if err != nil {
return err
}
upgradeConn = tx
} else {
upgradeConn = db
}
err = upgradeItem.fn(tx, db)
err = upgradeItem.fn(upgradeConn, db)
if err != nil {
return err
}
version = upgradeItem.upgradesTo
logVersion = version
err = db.setVersion(tx, version)
err = db.setVersion(upgradeConn, version)
if err != nil {
return err
}
err = tx.Commit()
if err != nil {
return err
if tx != nil {
err = tx.Commit()
if err != nil {
return err
}
}
}
return nil

View File

@@ -29,14 +29,14 @@ func (ut *UpgradeTable) extend(toSize int) {
}
}
func (ut *UpgradeTable) Register(from, to int, message string, fn upgradeFunc) {
func (ut *UpgradeTable) Register(from, to int, message string, txn bool, fn upgradeFunc) {
if from < 0 {
from += to
}
if from < 0 {
panic("invalid from value in UpgradeTable.Register() call")
}
upg := upgrade{message: message, fn: fn, upgradesTo: to}
upg := upgrade{message: message, fn: fn, upgradesTo: to, transaction: txn}
if len(*ut) == from {
*ut = append(*ut, upg)
return
@@ -57,7 +57,14 @@ func (ut *UpgradeTable) Register(from, to int, message string, fn upgradeFunc) {
// -- v1: Message
var upgradeHeaderRegex = regexp.MustCompile(`^-- (?:v(\d+) -> )?v(\d+): (.+)$`)
func parseFileHeader(file []byte) (from, to int, message string, lines [][]byte, err error) {
// To disable wrapping the upgrade in a single transaction, put `--transaction: off` on the second line.
//
// -- v5: Upgrade without transaction
// -- transaction: off
// // do dangerous stuff
var transactionDisableRegex = regexp.MustCompile(`^-- transaction: (\w*)`)
func parseFileHeader(file []byte) (from, to int, message string, txn bool, lines [][]byte, err error) {
lines = bytes.Split(file, []byte("\n"))
if len(lines) < 2 {
err = errors.New("upgrade file too short")
@@ -81,6 +88,15 @@ func parseFileHeader(file []byte) (from, to int, message string, lines [][]byte,
from = -1
}
message = string(match[3])
txn = true
match = transactionDisableRegex.FindSubmatch(lines[0])
if match != nil {
lines = lines[1:]
if string(match[1]) != "off" {
err = fmt.Errorf("invalid value %q for transaction flag", match[1])
}
txn = false
}
}
return
}
@@ -163,7 +179,7 @@ func (db *Database) filterSQLUpgrade(lines [][]byte) (string, error) {
}
func sqlUpgradeFunc(fileName string, lines [][]byte) upgradeFunc {
return func(tx Transaction, db *Database) error {
return func(tx Execable, db *Database) error {
if skip, err := db.parseDialectFilter(lines[0]); err == nil && skip == skipNextLine {
return nil
} else if upgradeSQL, err := db.filterSQLUpgrade(lines); err != nil {
@@ -176,7 +192,7 @@ func sqlUpgradeFunc(fileName string, lines [][]byte) upgradeFunc {
}
func splitSQLUpgradeFunc(sqliteData, postgresData string) upgradeFunc {
return func(tx Transaction, database *Database) (err error) {
return func(tx Execable, database *Database) (err error) {
switch database.Dialect {
case SQLite:
_, err = tx.Exec(sqliteData)
@@ -189,7 +205,7 @@ func splitSQLUpgradeFunc(sqliteData, postgresData string) upgradeFunc {
}
}
func parseSplitSQLUpgrade(name string, fs fullFS, skipNames map[string]struct{}) (from, to int, message string, fn upgradeFunc) {
func parseSplitSQLUpgrade(name string, fs fullFS, skipNames map[string]struct{}) (from, to int, message string, txn bool, fn upgradeFunc) {
postgresName := fmt.Sprintf("%s.postgres.sql", name)
sqliteName := fmt.Sprintf("%s.sqlite.sql", name)
skipNames[postgresName] = struct{}{}
@@ -202,11 +218,11 @@ func parseSplitSQLUpgrade(name string, fs fullFS, skipNames map[string]struct{})
if err != nil {
panic(err)
}
from, to, message, _, err = parseFileHeader(postgresData)
from, to, message, txn, _, err = parseFileHeader(postgresData)
if err != nil {
panic(fmt.Errorf("failed to parse header in %s: %w", postgresName, err))
}
sqliteFrom, sqliteTo, sqliteMessage, _, err := parseFileHeader(sqliteData)
sqliteFrom, sqliteTo, sqliteMessage, sqliteTxn, _, err := parseFileHeader(sqliteData)
if err != nil {
panic(fmt.Errorf("failed to parse header in %s: %w", sqliteName, err))
}
@@ -214,6 +230,8 @@ func parseSplitSQLUpgrade(name string, fs fullFS, skipNames map[string]struct{})
panic(fmt.Errorf("mismatching versions in postgres and sqlite versions of %s: %d/%d -> %d/%d", name, from, sqliteFrom, to, sqliteTo))
} else if message != sqliteMessage {
panic(fmt.Errorf("mismatching message in postgres and sqlite versions of %s: %q != %q", name, message, sqliteMessage))
} else if txn != sqliteTxn {
panic(fmt.Errorf("mismatching transaction flag in postgres and sqlite versions of %s: %t != %t", name, txn, sqliteTxn))
}
fn = splitSQLUpgradeFunc(string(sqliteData), string(postgresData))
return
@@ -242,14 +260,14 @@ func (ut *UpgradeTable) RegisterFSPath(fs fullFS, dir string) {
} else if _, skip := skipNames[file.Name()]; skip {
// also do nothing
} else if splitName := splitFileNameRegex.FindStringSubmatch(file.Name()); splitName != nil {
from, to, message, fn := parseSplitSQLUpgrade(splitName[1], fs, skipNames)
ut.Register(from, to, message, fn)
from, to, message, txn, fn := parseSplitSQLUpgrade(splitName[1], fs, skipNames)
ut.Register(from, to, message, txn, fn)
} else if data, err := fs.ReadFile(filepath.Join(dir, file.Name())); err != nil {
panic(err)
} else if from, to, message, lines, err := parseFileHeader(data); err != nil {
} else if from, to, message, txn, lines, err := parseFileHeader(data); err != nil {
panic(fmt.Errorf("failed to parse header in %s: %w", file.Name(), err))
} else {
ut.Register(from, to, message, sqlUpgradeFunc(file.Name(), lines))
ut.Register(from, to, message, txn, sqlUpgradeFunc(file.Name(), lines))
}
}
}

View File

@@ -0,0 +1,86 @@
// Copyright (c) 2022 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
package jsontime
import (
"encoding/json"
"time"
)
func UM(time time.Time) UnixMilli {
return UnixMilli{Time: time}
}
func UMInt(ts int64) UnixMilli {
return UM(time.UnixMilli(ts))
}
func UnixMilliNow() UnixMilli {
return UM(time.Now())
}
type UnixMilli struct {
time.Time
}
func (um UnixMilli) MarshalJSON() ([]byte, error) {
if um.IsZero() {
return []byte{'0'}, nil
}
return json.Marshal(um.UnixMilli())
}
func (um *UnixMilli) UnmarshalJSON(data []byte) error {
var val int64
err := json.Unmarshal(data, &val)
if err != nil {
return err
}
if val == 0 {
um.Time = time.Time{}
} else {
um.Time = time.UnixMilli(val)
}
return nil
}
func U(time time.Time) Unix {
return Unix{Time: time}
}
func UInt(ts int64) Unix {
return U(time.Unix(ts, 0))
}
func UnixNow() Unix {
return U(time.Now())
}
type Unix struct {
time.Time
}
func (u Unix) MarshalJSON() ([]byte, error) {
if u.IsZero() {
return []byte{'0'}, nil
}
return json.Marshal(u.Unix())
}
func (u *Unix) UnmarshalJSON(data []byte) error {
var val int64
err := json.Unmarshal(data, &val)
if err != nil {
return err
}
if val == 0 {
u.Time = time.Time{}
} else {
u.Time = time.Unix(val, 0)
}
return nil
}

View File

@@ -1,5 +1,5 @@
package mautrix
const Version = "v0.12.2"
const Version = "v0.12.3"
var DefaultUserAgent = "mautrix-go/" + Version