BREAKING: update mautrix to 0.15.x

This commit is contained in:
Aine
2023-06-01 14:32:20 +00:00
parent a6b20a75ab
commit 2bdb8ca635
222 changed files with 7851 additions and 23986 deletions

View File

@@ -1,4 +1,4 @@
// Copyright (c) 2020 Tulir Asokan
// Copyright (c) 2023 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
@@ -7,10 +7,15 @@
package crypto
import (
"context"
"encoding/base64"
"encoding/binary"
"encoding/json"
"errors"
"fmt"
"github.com/rs/zerolog"
"maunium.net/go/mautrix"
"maunium.net/go/mautrix/event"
"maunium.net/go/mautrix/id"
@@ -33,6 +38,18 @@ func getRelatesTo(content interface{}) *event.RelatesTo {
return nil
}
func getMentions(content interface{}) *event.Mentions {
contentStruct, ok := content.(*event.Content)
if ok {
content = contentStruct.Parsed
}
message, ok := content.(*event.MessageEventContent)
if ok {
return message.Mentions
}
return nil
}
type rawMegolmEvent struct {
RoomID id.RoomID `json:"room_id"`
Type event.Type `json:"type"`
@@ -44,12 +61,29 @@ func IsShareError(err error) bool {
return err == SessionExpired || err == SessionNotShared || err == NoGroupSession
}
func parseMessageIndex(ciphertext []byte) (uint64, error) {
decoded := make([]byte, base64.RawStdEncoding.DecodedLen(len(ciphertext)))
var err error
_, err = base64.RawStdEncoding.Decode(decoded, ciphertext)
if err != nil {
return 0, err
} else if decoded[0] != 3 || decoded[1] != 8 {
return 0, fmt.Errorf("unexpected initial bytes %d and %d", decoded[0], decoded[1])
}
index, read := binary.Uvarint(decoded[2 : 2+binary.MaxVarintLen64])
if read <= 0 {
return 0, fmt.Errorf("failed to decode varint, read value %d", read)
}
return index, nil
}
// EncryptMegolmEvent encrypts data with the m.megolm.v1.aes-sha2 algorithm.
//
// If you use the event.Content struct, make sure you pass a pointer to the struct,
// as JSON serialization will not work correctly otherwise.
func (mach *OlmMachine) EncryptMegolmEvent(roomID id.RoomID, evtType event.Type, content interface{}) (*event.EncryptedEventContent, error) {
mach.Log.Trace("Encrypting event of type %s for %s", evtType.Type, roomID)
func (mach *OlmMachine) EncryptMegolmEvent(ctx context.Context, roomID id.RoomID, evtType event.Type, content interface{}) (*event.EncryptedEventContent, error) {
mach.megolmEncryptLock.Lock()
defer mach.megolmEncryptLock.Unlock()
session, err := mach.CryptoStore.GetOutboundGroupSession(roomID)
if err != nil {
return nil, fmt.Errorf("failed to get outbound group session: %w", err)
@@ -64,15 +98,28 @@ func (mach *OlmMachine) EncryptMegolmEvent(roomID id.RoomID, evtType event.Type,
if err != nil {
return nil, err
}
log := mach.machOrContextLog(ctx).With().
Str("event_type", evtType.Type).
Str("room_id", roomID.String()).
Str("session_id", session.ID().String()).
Logger()
log.Trace().Msg("Encrypting event...")
ciphertext, err := session.Encrypt(plaintext)
if err != nil {
return nil, err
}
idx, err := parseMessageIndex(ciphertext)
if err != nil {
log.Warn().Err(err).Msg("Failed to get megolm message index of encrypted event")
} else {
log = log.With().Uint64("message_index", idx).Logger()
}
log.Debug().Msg("Encrypted event successfully")
err = mach.CryptoStore.UpdateOutboundGroupSession(session)
if err != nil {
mach.Log.Warn("Failed to update megolm session in crypto store after encrypting: %v", err)
log.Warn().Err(err).Msg("Failed to update megolm session in crypto store after encrypting")
}
return &event.EncryptedEventContent{
encrypted := &event.EncryptedEventContent{
Algorithm: id.AlgorithmMegolmV1,
SessionID: session.ID(),
MegolmCiphertext: ciphertext,
@@ -81,13 +128,19 @@ func (mach *OlmMachine) EncryptMegolmEvent(roomID id.RoomID, evtType event.Type,
// These are deprecated
SenderKey: mach.account.IdentityKey(),
DeviceID: mach.Client.DeviceID,
}, nil
}
if mach.PlaintextMentions {
encrypted.Mentions = getMentions(content)
}
return encrypted, nil
}
func (mach *OlmMachine) newOutboundGroupSession(roomID id.RoomID) *OutboundGroupSession {
func (mach *OlmMachine) newOutboundGroupSession(ctx context.Context, roomID id.RoomID) *OutboundGroupSession {
session := NewOutboundGroupSession(roomID, mach.StateStore.GetEncryptionEvent(roomID))
signingKey, idKey := mach.account.Keys()
mach.createGroupSession(idKey, signingKey, roomID, session.ID(), session.Internal.Key(), "create")
if !mach.DontStoreOutboundKeys {
signingKey, idKey := mach.account.Keys()
mach.createGroupSession(ctx, idKey, signingKey, roomID, session.ID(), session.Internal.Key(), session.MaxAge, session.MaxMessages, false)
}
return session
}
@@ -96,21 +149,38 @@ type deviceSessionWrapper struct {
identity *id.Device
}
func strishArray[T ~string](arr []T) []string {
out := make([]string, len(arr))
for i, item := range arr {
out[i] = string(item)
}
return out
}
// ShareGroupSession shares a group session for a specific room with all the devices of the given user list.
//
// For devices with TrustStateBlacklisted, a m.room_key.withheld event with code=m.blacklisted is sent.
// If AllowUnverifiedDevices is false, a similar event with code=m.unverified is sent to devices with TrustStateUnset
func (mach *OlmMachine) ShareGroupSession(roomID id.RoomID, users []id.UserID) error {
mach.Log.Debug("Sharing group session for room %s to %v", roomID, users)
func (mach *OlmMachine) ShareGroupSession(ctx context.Context, roomID id.RoomID, users []id.UserID) error {
mach.megolmEncryptLock.Lock()
defer mach.megolmEncryptLock.Unlock()
session, err := mach.CryptoStore.GetOutboundGroupSession(roomID)
if err != nil {
return fmt.Errorf("failed to get previous outbound group session: %w", err)
} else if session != nil && session.Shared && !session.Expired() {
return AlreadyShared
}
log := mach.machOrContextLog(ctx).With().
Str("room_id", roomID.String()).
Str("action", "share megolm session").
Logger()
ctx = log.WithContext(ctx)
if session == nil || session.Expired() {
session = mach.newOutboundGroupSession(roomID)
session = mach.newOutboundGroupSession(ctx, roomID)
}
log = log.With().Str("session_id", session.ID().String()).Logger()
ctx = log.WithContext(ctx)
log.Debug().Strs("users", strishArray(users)).Msg("Sharing group session for room")
withheldCount := 0
toDeviceWithheld := &mautrix.ReqSendToDevice{Messages: make(map[id.UserID]map[id.DeviceID]*event.Content)}
@@ -120,20 +190,25 @@ func (mach *OlmMachine) ShareGroupSession(roomID id.RoomID, users []id.UserID) e
var fetchKeys []id.UserID
for _, userID := range users {
log := log.With().Str("target_user_id", userID.String()).Logger()
devices, err := mach.CryptoStore.GetDevices(userID)
if err != nil {
mach.Log.Error("Failed to get devices of %s", userID)
log.Error().Err(err).Msg("Failed to get devices of user")
} else if devices == nil {
mach.Log.Trace("GetDevices returned nil for %s, will fetch keys and retry", userID)
log.Debug().Msg("GetDevices returned nil, will fetch keys and retry")
fetchKeys = append(fetchKeys, userID)
} else if len(devices) == 0 {
mach.Log.Trace("%s has no devices, skipping", userID)
log.Trace().Msg("User has no devices, skipping")
} else {
mach.Log.Trace("Trying to find olm sessions to encrypt %s for %s", session.ID(), userID)
log.Trace().Msg("Trying to find olm session to encrypt megolm session for user")
toDeviceWithheld.Messages[userID] = make(map[id.DeviceID]*event.Content)
olmSessions[userID] = make(map[id.DeviceID]deviceSessionWrapper)
mach.findOlmSessionsForUser(session, userID, devices, olmSessions[userID], toDeviceWithheld.Messages[userID], missingUserSessions)
mach.Log.Trace("Found %d sessions, withholding from %d sessions and missing %d sessions to encrypt %s for for %s", len(olmSessions[userID]), len(toDeviceWithheld.Messages[userID]), len(missingUserSessions), session.ID(), userID)
mach.findOlmSessionsForUser(ctx, session, userID, devices, olmSessions[userID], toDeviceWithheld.Messages[userID], missingUserSessions)
log.Debug().
Int("olm_session_count", len(olmSessions[userID])).
Int("withheld_count", len(toDeviceWithheld.Messages[userID])).
Int("missing_count", len(missingUserSessions)).
Msg("Completed first pass of finding olm sessions")
withheldCount += len(toDeviceWithheld.Messages[userID])
if len(missingUserSessions) > 0 {
missingSessions[userID] = missingUserSessions
@@ -146,18 +221,21 @@ func (mach *OlmMachine) ShareGroupSession(roomID id.RoomID, users []id.UserID) e
}
if len(fetchKeys) > 0 {
mach.Log.Trace("Fetching missing keys for %v", fetchKeys)
for userID, devices := range mach.fetchKeys(fetchKeys, "", true) {
mach.Log.Trace("Got %d device keys for %s", len(devices), userID)
log.Debug().Strs("users", strishArray(fetchKeys)).Msg("Fetching missing keys")
for userID, devices := range mach.fetchKeys(ctx, fetchKeys, "", true) {
log.Debug().
Int("device_count", len(devices)).
Str("target_user_id", userID.String()).
Msg("Got device keys for user")
missingSessions[userID] = devices
}
}
if len(missingSessions) > 0 {
mach.Log.Trace("Creating missing outbound sessions")
err = mach.createOutboundSessions(missingSessions)
log.Debug().Msg("Creating missing olm sessions")
err = mach.createOutboundSessions(ctx, missingSessions)
if err != nil {
mach.Log.Error("Failed to create missing outbound sessions: %v", err)
log.Error().Err(err).Msg("Failed to create missing olm sessions")
}
}
@@ -176,42 +254,51 @@ func (mach *OlmMachine) ShareGroupSession(roomID id.RoomID, users []id.UserID) e
withheld = make(map[id.DeviceID]*event.Content)
toDeviceWithheld.Messages[userID] = withheld
}
mach.Log.Trace("Trying to find olm sessions to encrypt %s for %s (post-fetch retry)", session.ID(), userID)
mach.findOlmSessionsForUser(session, userID, devices, output, withheld, nil)
mach.Log.Trace("Found %d sessions and withholding from %d sessions to encrypt %s for for %s (post-fetch retry)", len(output), len(withheld), session.ID(), userID)
log := log.With().Str("target_user_id", userID.String()).Logger()
log.Trace().Msg("Trying to find olm session to encrypt megolm session for user (post-fetch retry)")
mach.findOlmSessionsForUser(ctx, session, userID, devices, output, withheld, nil)
log.Debug().
Int("olm_session_count", len(output)).
Int("withheld_count", len(withheld)).
Msg("Completed post-fetch retry of finding olm sessions")
withheldCount += len(toDeviceWithheld.Messages[userID])
if len(toDeviceWithheld.Messages[userID]) == 0 {
delete(toDeviceWithheld.Messages, userID)
}
}
err = mach.encryptAndSendGroupSession(session, olmSessions)
err = mach.encryptAndSendGroupSession(ctx, session, olmSessions)
if err != nil {
return fmt.Errorf("failed to share group session: %w", err)
}
if len(toDeviceWithheld.Messages) > 0 {
mach.Log.Trace("Sending to-device messages to %d devices of %d users to report withheld keys in %s", withheldCount, len(toDeviceWithheld.Messages), roomID)
log.Debug().
Int("device_count", withheldCount).
Int("user_count", len(toDeviceWithheld.Messages)).
Msg("Sending to-device messages to report withheld key")
// TODO remove the next 4 lines once clients support m.room_key.withheld
_, err = mach.Client.SendToDevice(event.ToDeviceOrgMatrixRoomKeyWithheld, toDeviceWithheld)
if err != nil {
mach.Log.Warn("Failed to report withheld keys in %s (legacy event type): %v", roomID, err)
log.Warn().Err(err).Msg("Failed to report withheld keys (legacy event type)")
}
_, err = mach.Client.SendToDevice(event.ToDeviceRoomKeyWithheld, toDeviceWithheld)
if err != nil {
mach.Log.Warn("Failed to report withheld keys in %s: %v", roomID, err)
log.Warn().Err(err).Msg("Failed to report withheld keys")
}
}
mach.Log.Debug("Group session %s for %s successfully shared", session.ID(), roomID)
log.Debug().Msg("Group session successfully shared")
session.Shared = true
return mach.CryptoStore.AddOutboundGroupSession(session)
}
func (mach *OlmMachine) encryptAndSendGroupSession(session *OutboundGroupSession, olmSessions map[id.UserID]map[id.DeviceID]deviceSessionWrapper) error {
func (mach *OlmMachine) encryptAndSendGroupSession(ctx context.Context, session *OutboundGroupSession, olmSessions map[id.UserID]map[id.DeviceID]deviceSessionWrapper) error {
mach.olmLock.Lock()
defer mach.olmLock.Unlock()
mach.Log.Trace("Encrypting group session %s for all found devices", session.ID())
log := zerolog.Ctx(ctx)
log.Trace().Msg("Encrypting group session for all found devices")
deviceCount := 0
toDevice := &mautrix.ReqSendToDevice{Messages: make(map[id.UserID]map[id.DeviceID]*event.Content)}
for userID, sessions := range olmSessions {
@@ -221,31 +308,41 @@ func (mach *OlmMachine) encryptAndSendGroupSession(session *OutboundGroupSession
output := make(map[id.DeviceID]*event.Content)
toDevice.Messages[userID] = output
for deviceID, device := range sessions {
mach.Log.Trace("Encrypting group session %s for %s of %s", session.ID(), deviceID, userID)
content := mach.encryptOlmEvent(device.session, device.identity, event.ToDeviceRoomKey, session.ShareContent())
log.Trace().
Str("target_user_id", userID.String()).
Str("target_device_id", deviceID.String()).
Msg("Encrypting group session for device")
content := mach.encryptOlmEvent(ctx, device.session, device.identity, event.ToDeviceRoomKey, session.ShareContent())
output[deviceID] = &event.Content{Parsed: content}
deviceCount++
mach.Log.Trace("Encrypted group session %s for %s of %s", session.ID(), deviceID, userID)
log.Debug().
Str("target_user_id", userID.String()).
Str("target_device_id", deviceID.String()).
Msg("Encrypted group session for device")
}
}
mach.Log.Trace("Sending to-device to %d devices of %d users to share group session %s", deviceCount, len(toDevice.Messages), session.ID())
log.Debug().
Int("device_count", deviceCount).
Int("user_count", len(toDevice.Messages)).
Msg("Sending to-device messages to share group session")
_, err := mach.Client.SendToDevice(event.ToDeviceEncrypted, toDevice)
return err
}
func (mach *OlmMachine) findOlmSessionsForUser(session *OutboundGroupSession, userID id.UserID, devices map[id.DeviceID]*id.Device, output map[id.DeviceID]deviceSessionWrapper, withheld map[id.DeviceID]*event.Content, missingOutput map[id.DeviceID]*id.Device) {
func (mach *OlmMachine) findOlmSessionsForUser(ctx context.Context, session *OutboundGroupSession, userID id.UserID, devices map[id.DeviceID]*id.Device, output map[id.DeviceID]deviceSessionWrapper, withheld map[id.DeviceID]*event.Content, missingOutput map[id.DeviceID]*id.Device) {
for deviceID, device := range devices {
log := zerolog.Ctx(ctx).With().
Str("target_user_id", userID.String()).
Str("target_device_id", deviceID.String()).
Logger()
userKey := UserDevice{UserID: userID, DeviceID: deviceID}
if state := session.Users[userKey]; state != OGSNotShared {
continue
} else if userID == mach.Client.UserID && deviceID == mach.Client.DeviceID {
session.Users[userKey] = OGSIgnored
} else if device.Trust == id.TrustStateBlacklisted {
mach.Log.Debug(
"Not encrypting group session %s for %s of %s: device is blacklisted",
session.ID(), deviceID, userID,
)
log.Debug().Msg("Not encrypting group session for device: device is blacklisted")
withheld[deviceID] = &event.Content{Parsed: &event.RoomKeyWithheldEventContent{
RoomID: session.RoomID,
Algorithm: id.AlgorithmMegolmV1,
@@ -256,10 +353,10 @@ func (mach *OlmMachine) findOlmSessionsForUser(session *OutboundGroupSession, us
}}
session.Users[userKey] = OGSIgnored
} else if trustState := mach.ResolveTrust(device); trustState < mach.SendKeysMinTrust {
mach.Log.Debug(
"Not encrypting group session %s for %s of %s: device is not verified (minimum: %s, device: %s)",
session.ID(), deviceID, userID, mach.SendKeysMinTrust, trustState,
)
log.Debug().
Str("min_trust", mach.SendKeysMinTrust.String()).
Str("device_trust", trustState.String()).
Msg("Not encrypting group session for device: device is not trusted")
withheld[deviceID] = &event.Content{Parsed: &event.RoomKeyWithheldEventContent{
RoomID: session.RoomID,
Algorithm: id.AlgorithmMegolmV1,
@@ -270,9 +367,9 @@ func (mach *OlmMachine) findOlmSessionsForUser(session *OutboundGroupSession, us
}}
session.Users[userKey] = OGSIgnored
} else if deviceSession, err := mach.CryptoStore.GetLatestSession(device.IdentityKey); err != nil {
mach.Log.Error("Failed to get session for %s of %s: %v", deviceID, userID, err)
log.Error().Err(err).Msg("Failed to get olm session to encrypt group session")
} else if deviceSession == nil {
mach.Log.Warn("Didn't find a session for %s of %s", deviceID, userID)
log.Warn().Err(err).Msg("Didn't find olm session to encrypt group session")
if missingOutput != nil {
missingOutput[deviceID] = device
}