BREAKING: update mautrix to 0.15.x

This commit is contained in:
Aine
2023-06-01 14:32:20 +00:00
parent a6b20a75ab
commit 2bdb8ca635
222 changed files with 7851 additions and 23986 deletions

View File

@@ -1,4 +1,5 @@
// Copyright (c) 2020 Nikos Filippakis
// Copyright (c) 2023 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
@@ -55,8 +56,11 @@ func (mach *OlmMachine) ImportCrossSigningKeys(keys CrossSigningSeeds) (err erro
return
}
mach.Log.Trace("Got cross-signing keys: Master `%v` Self-signing `%v` User-signing `%v`",
keysCache.MasterKey.PublicKey, keysCache.SelfSigningKey.PublicKey, keysCache.UserSigningKey.PublicKey)
mach.Log.Debug().
Str("master", keysCache.MasterKey.PublicKey.String()).
Str("self_signing", keysCache.SelfSigningKey.PublicKey.String()).
Str("user_signing", keysCache.UserSigningKey.PublicKey.String()).
Msg("Imported own cross-signing keys")
mach.CrossSigningKeys = &keysCache
mach.crossSigningPubkeys = keysCache.PublicKeys()
@@ -76,8 +80,11 @@ func (mach *OlmMachine) GenerateCrossSigningKeys() (*CrossSigningKeysCache, erro
if keysCache.UserSigningKey, err = olm.NewPkSigning(); err != nil {
return nil, fmt.Errorf("failed to generate user-signing key: %w", err)
}
mach.Log.Debug("Generated cross-signing keys: Master: `%v` Self-signing: `%v` User-signing: `%v`",
keysCache.MasterKey.PublicKey, keysCache.SelfSigningKey.PublicKey, keysCache.UserSigningKey.PublicKey)
mach.Log.Debug().
Str("master", keysCache.MasterKey.PublicKey.String()).
Str("self_signing", keysCache.SelfSigningKey.PublicKey.String()).
Str("user_signing", keysCache.UserSigningKey.PublicKey.String()).
Msg("Generated cross-signing keys")
return &keysCache, nil
}

View File

@@ -1,4 +1,4 @@
// Copyright (c) 2022 Tulir Asokan
// Copyright (c) 2023 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
@@ -32,7 +32,7 @@ func (mach *OlmMachine) GetOwnCrossSigningPublicKeys() *CrossSigningPublicKeysCa
}
cspk, err := mach.GetCrossSigningPublicKeys(mach.Client.UserID)
if err != nil {
mach.Log.Error("Failed to get own cross-signing public keys: %v", err)
mach.Log.Error().Err(err).Msg("Failed to get own cross-signing public keys")
return nil
}
mach.crossSigningPubkeys = cspk

View File

@@ -1,5 +1,5 @@
// Copyright (c) 2020 Nikos Filippakis
// Copyright (c) 2022 Tulir Asokan
// Copyright (c) 2023 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
@@ -79,7 +79,10 @@ func (mach *OlmMachine) SignUser(userID id.UserID, masterKey id.Ed25519) error {
return err
}
mach.Log.Trace("Signed master key of %s with user-signing key: `%v`", userID, signature)
mach.Log.Debug().
Str("user_id", userID.String()).
Str("signature", signature).
Msg("Signed master key of user with our user-signing key")
if err := mach.CryptoStore.PutSignature(userID, masterKey, mach.Client.UserID, mach.CrossSigningKeys.UserSigningKey.PublicKey, signature); err != nil {
return fmt.Errorf("error storing signature in crypto store: %w", err)
@@ -116,7 +119,10 @@ func (mach *OlmMachine) SignOwnMasterKey() error {
id.NewKeyID(id.KeyAlgorithmEd25519, deviceID.String()): signature,
},
}
mach.Log.Trace("Signed own master key with device %v: `%v`", deviceID, signature)
mach.Log.Debug().
Str("device_id", deviceID.String()).
Str("signature", signature).
Msg("Signed own master key with own device key")
resp, err := mach.Client.UploadSignatures(&mautrix.ReqUploadSignatures{
userID: map[string]mautrix.ReqKeysSignatures{
@@ -165,7 +171,11 @@ func (mach *OlmMachine) SignOwnDevice(device *id.Device) error {
return err
}
mach.Log.Trace("Signed own device %s with self-signing key: `%v`", device.UserID, device.DeviceID, signature)
mach.Log.Debug().
Str("user_id", device.UserID.String()).
Str("device_id", device.DeviceID.String()).
Str("signature", signature).
Msg("Signed own device key with self-signing key")
if err := mach.CryptoStore.PutSignature(device.UserID, device.SigningKey, mach.Client.UserID, mach.CrossSigningKeys.SelfSigningKey.PublicKey, signature); err != nil {
return fmt.Errorf("error storing signature in crypto store: %w", err)

View File

@@ -1,5 +1,5 @@
// Copyright (c) 2020 Nikos Filippakis
// Copyright (c) 2022 Tulir Asokan
// Copyright (c) 2023 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
@@ -8,28 +8,36 @@
package crypto
import (
"context"
"maunium.net/go/mautrix"
"maunium.net/go/mautrix/crypto/olm"
"maunium.net/go/mautrix/id"
)
func (mach *OlmMachine) storeCrossSigningKeys(crossSigningKeys map[id.UserID]mautrix.CrossSigningKeys, deviceKeys map[id.UserID]map[id.DeviceID]mautrix.DeviceKeys) {
func (mach *OlmMachine) storeCrossSigningKeys(ctx context.Context, crossSigningKeys map[id.UserID]mautrix.CrossSigningKeys, deviceKeys map[id.UserID]map[id.DeviceID]mautrix.DeviceKeys) {
log := mach.machOrContextLog(ctx)
for userID, userKeys := range crossSigningKeys {
log := log.With().Str("user_id", userID.String()).Logger()
currentKeys, err := mach.CryptoStore.GetCrossSigningKeys(userID)
if err != nil {
mach.Log.Error("Error fetching current cross-signing keys of user %v: %v", userID, err)
log.Error().Err(err).
Msg("Error fetching current cross-signing keys of user")
}
if currentKeys != nil {
for curKeyUsage, curKey := range currentKeys {
log := log.With().Str("old_key", curKey.Key.String()).Str("old_key_usage", string(curKeyUsage)).Logger()
// got a new key with the same usage as an existing key
for _, newKeyUsage := range userKeys.Usage {
if newKeyUsage == curKeyUsage {
if _, ok := userKeys.Keys[id.NewKeyID(id.KeyAlgorithmEd25519, curKey.Key.String())]; !ok {
// old key is not in the new key map, so we drop signatures made by it
if count, err := mach.CryptoStore.DropSignaturesByKey(userID, curKey.Key); err != nil {
mach.Log.Error("Error deleting old signatures made by %s (%s): %v", curKey, curKeyUsage, err)
log.Error().Err(err).Msg("Error deleting old signatures made by user")
} else {
mach.Log.Debug("Dropped %d signatures made by key %s (%s) as it has been replaced", count, curKey, curKeyUsage)
log.Debug().
Int64("signature_count", count).
Msg("Dropped signatures made by old key as it has been replaced")
}
}
break
@@ -39,10 +47,11 @@ func (mach *OlmMachine) storeCrossSigningKeys(crossSigningKeys map[id.UserID]mau
}
for _, key := range userKeys.Keys {
log := log.With().Str("key", key.String()).Strs("usages", strishArray(userKeys.Usage)).Logger()
for _, usage := range userKeys.Usage {
mach.Log.Debug("Storing cross-signing key for %s: %s (type %s)", userID, key, usage)
log.Debug().Str("usage", string(usage)).Msg("Storing cross-signing key")
if err = mach.CryptoStore.PutCrossSigningKey(userID, usage, key); err != nil {
mach.Log.Error("Error storing cross-signing key: %v", err)
log.Error().Err(err).Msg("Error storing cross-signing key")
}
}
@@ -50,31 +59,38 @@ func (mach *OlmMachine) storeCrossSigningKeys(crossSigningKeys map[id.UserID]mau
for signKeyID, signature := range keySigs {
_, signKeyName := signKeyID.Parse()
signingKey := id.Ed25519(signKeyName)
log := log.With().
Str("sign_key_id", signKeyID.String()).
Str("signer_user_id", signUserID.String()).
Str("signing_key", signingKey.String()).
Logger()
// if the signer is one of this user's own devices, find the key from the key ID
if signUserID == userID {
ownDeviceID := id.DeviceID(signKeyName)
if ownDeviceKeys, ok := deviceKeys[userID][ownDeviceID]; ok {
signingKey = ownDeviceKeys.Keys.GetEd25519(ownDeviceID)
mach.Log.Trace("Treating %s as the device ID -> signing key %s", signKeyName, signingKey)
log.Trace().
Str("device_id", signKeyName).
Msg("Treating key name as device ID")
}
}
if len(signingKey) != 43 {
mach.Log.Trace("Cross-signing key %s/%s/%v has a signature from an unknown key %s", userID, key, userKeys.Usage, signKeyID)
log.Debug().Msg("Cross-signing key has a signature from an unknown key")
continue
}
mach.Log.Debug("Verifying cross-signing key %s/%s/%v with key %s/%s", userID, key, userKeys.Usage, signUserID, signingKey)
log.Debug().Msg("Verifying cross-signing key signature")
if verified, err := olm.VerifySignatureJSON(userKeys, signUserID, signKeyName, signingKey); err != nil {
mach.Log.Warn("Error while verifying signature from %s for %s: %v", signingKey, key, err)
log.Warn().Err(err).Msg("Error verifying cross-signing key signature")
} else {
if verified {
mach.Log.Debug("Signature from %s for %s verified", signingKey, key)
log.Debug().Err(err).Msg("Cross-signing key signature verified")
err = mach.CryptoStore.PutSignature(userID, key, signUserID, signingKey, signature)
if err != nil {
mach.Log.Warn("Failed to store signature from %s for %s: %v", signingKey, key, err)
log.Error().Err(err).Msg("Error storing cross-signing key signature")
}
} else {
mach.Log.Error("Invalid signature from %s for %s", signingKey, key)
log.Warn().Err(err).Msg("Cross-siging key signature is invalid")
}
}
}

View File

@@ -1,5 +1,5 @@
// Copyright (c) 2020 Nikos Filippakis
// Copyright (c) 2022 Tulir Asokan
// Copyright (c) 2023 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
@@ -8,52 +8,72 @@
package crypto
import (
"context"
"maunium.net/go/mautrix/id"
)
// ResolveTrust resolves the trust state of the device from cross-signing.
func (mach *OlmMachine) ResolveTrust(device *id.Device) id.TrustState {
state, _ := mach.ResolveTrustContext(context.Background(), device)
return state
}
// ResolveTrustContext resolves the trust state of the device from cross-signing.
func (mach *OlmMachine) ResolveTrustContext(ctx context.Context, device *id.Device) (id.TrustState, error) {
if device.Trust == id.TrustStateVerified || device.Trust == id.TrustStateBlacklisted {
return device.Trust
return device.Trust, nil
}
theirKeys, err := mach.CryptoStore.GetCrossSigningKeys(device.UserID)
if err != nil {
mach.Log.Error("Error retrieving cross-singing key of user %v from database: %v", device.UserID, err)
return id.TrustStateUnset
mach.machOrContextLog(ctx).Error().Err(err).
Str("user_id", device.UserID.String()).
Msg("Error retrieving cross-signing key of user from database")
return id.TrustStateUnset, err
}
theirMSK, ok := theirKeys[id.XSUsageMaster]
if !ok {
mach.Log.Error("Master key of user %v not found", device.UserID)
return id.TrustStateUnset
mach.machOrContextLog(ctx).Error().
Str("user_id", device.UserID.String()).
Msg("Master key of user not found")
return id.TrustStateUnset, nil
}
theirSSK, ok := theirKeys[id.XSUsageSelfSigning]
if !ok {
mach.Log.Error("Self-signing key of user %v not found", device.UserID)
return id.TrustStateUnset
mach.machOrContextLog(ctx).Error().
Str("user_id", device.UserID.String()).
Msg("Self-signing key of user not found")
return id.TrustStateUnset, nil
}
sskSigExists, err := mach.CryptoStore.IsKeySignedBy(device.UserID, theirSSK.Key, device.UserID, theirMSK.Key)
if err != nil {
mach.Log.Error("Error retrieving cross-singing signatures for master key of user %v from database: %v", device.UserID, err)
return id.TrustStateUnset
mach.machOrContextLog(ctx).Error().Err(err).
Str("user_id", device.UserID.String()).
Msg("Error retrieving cross-signing signatures for master key of user from database")
return id.TrustStateUnset, err
}
if !sskSigExists {
mach.Log.Warn("Self-signing key of user %v is not signed by their master key", device.UserID)
return id.TrustStateUnset
mach.machOrContextLog(ctx).Error().
Str("user_id", device.UserID.String()).
Msg("Self-signing key of user is not signed by their master key")
return id.TrustStateUnset, nil
}
deviceSigExists, err := mach.CryptoStore.IsKeySignedBy(device.UserID, device.SigningKey, device.UserID, theirSSK.Key)
if err != nil {
mach.Log.Error("Error retrieving cross-singing signatures for master key of user %v from database: %v", device.UserID, err)
return id.TrustStateUnset
mach.machOrContextLog(ctx).Error().Err(err).
Str("user_id", device.UserID.String()).
Str("device_key", device.SigningKey.String()).
Msg("Error retrieving cross-signing signatures for device from database")
return id.TrustStateUnset, err
}
if deviceSigExists {
if mach.IsUserTrusted(device.UserID) {
return id.TrustStateCrossSignedVerified
if trusted, err := mach.IsUserTrusted(ctx, device.UserID); !trusted {
return id.TrustStateCrossSignedVerified, err
} else if theirMSK.Key == theirMSK.First {
return id.TrustStateCrossSignedTOFU
return id.TrustStateCrossSignedTOFU, nil
}
return id.TrustStateCrossSignedUntrusted
return id.TrustStateCrossSignedUntrusted, nil
}
return id.TrustStateUnset
return id.TrustStateUnset, nil
}
// IsDeviceTrusted returns whether a device has been determined to be trusted either through verification or cross-signing.
@@ -68,36 +88,42 @@ func (mach *OlmMachine) IsDeviceTrusted(device *id.Device) bool {
// IsUserTrusted returns whether a user has been determined to be trusted by our user-signing key having signed their master key.
// In the case the user ID is our own and we have successfully retrieved our cross-signing keys, we trust our own user.
func (mach *OlmMachine) IsUserTrusted(userID id.UserID) bool {
func (mach *OlmMachine) IsUserTrusted(ctx context.Context, userID id.UserID) (bool, error) {
csPubkeys := mach.GetOwnCrossSigningPublicKeys()
if csPubkeys == nil {
return false
return false, nil
}
if userID == mach.Client.UserID {
return true
return true, nil
}
// first we verify our user-signing key
ourUserSigningKeyTrusted, err := mach.CryptoStore.IsKeySignedBy(mach.Client.UserID, csPubkeys.UserSigningKey, mach.Client.UserID, csPubkeys.MasterKey)
if err != nil {
mach.Log.Error("Error retrieving our self-singing key signatures: %v", err)
return false
mach.machOrContextLog(ctx).Error().Err(err).Msg("Error retrieving our self-signing key signatures from database")
return false, err
} else if !ourUserSigningKeyTrusted {
return false
return false, nil
}
theirKeys, err := mach.CryptoStore.GetCrossSigningKeys(userID)
if err != nil {
mach.Log.Error("Error retrieving cross-singing key of user %v from database: %v", userID, err)
return false
mach.machOrContextLog(ctx).Error().Err(err).
Str("user_id", userID.String()).
Msg("Error retrieving cross-signing key of user from database")
return false, err
}
theirMskKey, ok := theirKeys[id.XSUsageMaster]
if !ok {
mach.Log.Error("Master key of user %v not found", userID)
return false
mach.machOrContextLog(ctx).Error().
Str("user_id", userID.String()).
Msg("Master key of user not found")
return false, nil
}
sigExists, err := mach.CryptoStore.IsKeySignedBy(userID, theirMskKey.Key, mach.Client.UserID, csPubkeys.UserSigningKey)
if err != nil {
mach.Log.Error("Error retrieving cross-singing signatures for master key of user %v from database: %v", userID, err)
return false
mach.machOrContextLog(ctx).Error().Err(err).
Str("user_id", userID.String()).
Msg("Error retrieving cross-signing signatures for master key of user from database")
return false, err
}
return sigExists
return sigExists, nil
}

View File

@@ -0,0 +1,374 @@
// Copyright (c) 2023 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
package cryptohelper
import (
"context"
"errors"
"fmt"
"sync"
"time"
"github.com/rs/zerolog"
"maunium.net/go/mautrix"
"maunium.net/go/mautrix/crypto"
"maunium.net/go/mautrix/event"
"maunium.net/go/mautrix/id"
"maunium.net/go/mautrix/sqlstatestore"
"maunium.net/go/mautrix/util/dbutil"
)
type CryptoHelper struct {
client *mautrix.Client
mach *crypto.OlmMachine
log zerolog.Logger
lock sync.RWMutex
pickleKey []byte
managedStateStore *sqlstatestore.SQLStateStore
unmanagedCryptoStore crypto.Store
dbForManagedStores *dbutil.Database
DecryptErrorCallback func(*event.Event, error)
LoginAs *mautrix.ReqLogin
DBAccountID string
}
var _ mautrix.CryptoHelper = (*CryptoHelper)(nil)
// NewCryptoHelper creates a struct that helps a mautrix client struct with Matrix e2ee operations.
//
// The client and pickle key are always required. Additionally, you must either:
// - Provide a crypto.Store here and set a StateStore in the client, or
// - Provide a dbutil.Database here to automatically create missing stores.
// - Provide a string here to use it as a path to a SQLite database, and then automatically create missing stores.
//
// The same database may be shared across multiple clients, but note that doing that will allow all clients access to
// decryption keys received by any one of the clients. For that reason, the pickle key must also be same for all clients
// using the same database.
func NewCryptoHelper(cli *mautrix.Client, pickleKey []byte, store any) (*CryptoHelper, error) {
if len(pickleKey) == 0 {
return nil, fmt.Errorf("pickle key must be provided")
}
_, isExtensible := cli.Syncer.(mautrix.ExtensibleSyncer)
if !isExtensible {
return nil, fmt.Errorf("the client syncer must implement ExtensibleSyncer")
}
var managedStateStore *sqlstatestore.SQLStateStore
var dbForManagedStores *dbutil.Database
var unmanagedCryptoStore crypto.Store
switch typedStore := store.(type) {
case crypto.Store:
if cli.StateStore == nil {
return nil, fmt.Errorf("when passing a crypto.Store to NewCryptoHelper, the client must have a state store set beforehand")
} else if _, isCryptoCompatible := cli.StateStore.(crypto.StateStore); !isCryptoCompatible {
return nil, fmt.Errorf("the client state store must implement crypto.StateStore")
}
unmanagedCryptoStore = typedStore
case string:
db, err := dbutil.NewWithDialect(typedStore, "sqlite3")
if err != nil {
return nil, err
}
dbForManagedStores = db
case *dbutil.Database:
dbForManagedStores = typedStore
default:
return nil, fmt.Errorf("you must pass a *dbutil.Database or *crypto.StateStore to NewCryptoHelper")
}
log := cli.Log.With().Str("component", "crypto").Logger()
if cli.StateStore == nil && dbForManagedStores != nil {
managedStateStore = sqlstatestore.NewSQLStateStore(dbForManagedStores, dbutil.ZeroLogger(log.With().Str("db_section", "matrix_state").Logger()), false)
cli.StateStore = managedStateStore
} else if _, isCryptoCompatible := cli.StateStore.(crypto.StateStore); !isCryptoCompatible {
return nil, fmt.Errorf("the client state store must implement crypto.StateStore")
}
return &CryptoHelper{
client: cli,
log: log,
pickleKey: pickleKey,
unmanagedCryptoStore: unmanagedCryptoStore,
managedStateStore: managedStateStore,
dbForManagedStores: dbForManagedStores,
DecryptErrorCallback: func(_ *event.Event, _ error) {},
}, nil
}
func (helper *CryptoHelper) Init() error {
if helper == nil {
return fmt.Errorf("crypto helper is nil")
}
syncer, ok := helper.client.Syncer.(mautrix.ExtensibleSyncer)
if !ok {
return fmt.Errorf("the client syncer must implement ExtensibleSyncer")
}
var stateStore crypto.StateStore
if helper.managedStateStore != nil {
err := helper.managedStateStore.Upgrade()
if err != nil {
return fmt.Errorf("failed to upgrade client state store: %w", err)
}
stateStore = helper.managedStateStore
} else {
stateStore = helper.client.StateStore.(crypto.StateStore)
}
var cryptoStore crypto.Store
if helper.unmanagedCryptoStore == nil {
managedCryptoStore := crypto.NewSQLCryptoStore(helper.dbForManagedStores, dbutil.ZeroLogger(helper.log.With().Str("db_section", "crypto").Logger()), helper.DBAccountID, helper.client.DeviceID, helper.pickleKey)
if helper.client.Store == nil {
helper.client.Store = managedCryptoStore
} else if _, isMemory := helper.client.Store.(*mautrix.MemorySyncStore); isMemory {
helper.client.Store = managedCryptoStore
}
err := managedCryptoStore.DB.Upgrade()
if err != nil {
return fmt.Errorf("failed to upgrade crypto state store: %w", err)
}
storedDeviceID := managedCryptoStore.FindDeviceID()
if helper.LoginAs != nil {
if storedDeviceID != "" {
helper.LoginAs.DeviceID = storedDeviceID
}
helper.LoginAs.StoreCredentials = true
helper.log.Debug().
Str("username", helper.LoginAs.Identifier.User).
Str("device_id", helper.LoginAs.DeviceID.String()).
Msg("Logging in")
_, err = helper.client.Login(helper.LoginAs)
if err != nil {
return err
}
if storedDeviceID == "" {
managedCryptoStore.DeviceID = helper.client.DeviceID
}
} else if storedDeviceID != "" && storedDeviceID != helper.client.DeviceID {
return fmt.Errorf("mismatching device ID in client and crypto store (%q != %q)", storedDeviceID, helper.client.DeviceID)
}
cryptoStore = managedCryptoStore
} else {
if helper.LoginAs != nil {
return fmt.Errorf("LoginAs can only be used with a managed crypto store")
}
cryptoStore = helper.unmanagedCryptoStore
}
if helper.client.DeviceID == "" || helper.client.UserID == "" {
return fmt.Errorf("the client must be logged in")
}
helper.mach = crypto.NewOlmMachine(helper.client, &helper.log, cryptoStore, stateStore)
err := helper.mach.Load()
if err != nil {
return fmt.Errorf("failed to load olm account: %w", err)
} else if err = helper.verifyDeviceKeysOnServer(); err != nil {
return err
}
syncer.OnSync(helper.mach.ProcessSyncResponse)
syncer.OnEventType(event.StateMember, helper.mach.HandleMemberEvent)
if _, ok = helper.client.Syncer.(mautrix.DispatchableSyncer); ok {
syncer.OnEventType(event.EventEncrypted, helper.HandleEncrypted)
} else {
helper.log.Warn().Msg("Client syncer does not implement DispatchableSyncer. Events will not be decrypted automatically.")
}
if helper.managedStateStore != nil {
syncer.OnEvent(helper.client.StateStoreSyncHandler)
}
return nil
}
func (helper *CryptoHelper) Close() error {
if helper != nil && helper.dbForManagedStores != nil {
err := helper.dbForManagedStores.RawDB.Close()
if err != nil {
return err
}
}
return nil
}
func (helper *CryptoHelper) Machine() *crypto.OlmMachine {
if helper == nil || helper.mach == nil {
panic("Machine() called before initing CryptoHelper")
}
return helper.mach
}
func (helper *CryptoHelper) verifyDeviceKeysOnServer() error {
helper.log.Debug().Msg("Making sure our device has the expected keys on the server")
resp, err := helper.client.QueryKeys(&mautrix.ReqQueryKeys{
DeviceKeys: map[id.UserID]mautrix.DeviceIDList{
helper.client.UserID: {helper.client.DeviceID},
},
})
if err != nil {
return fmt.Errorf("failed to query own keys to make sure device is properly configured: %w", err)
}
ownID := helper.mach.OwnIdentity()
isShared := helper.mach.GetAccount().Shared
device, ok := resp.DeviceKeys[helper.client.UserID][helper.client.DeviceID]
if !ok || len(device.Keys) == 0 {
if isShared {
return fmt.Errorf("olm account is marked as shared, keys seem to have disappeared from the server")
} else {
helper.log.Debug().Msg("Olm account not shared and keys not on server, so device is probably fine")
return nil
}
} else if !isShared {
return fmt.Errorf("olm account is not marked as shared, but there are keys on the server")
} else if ed := device.Keys.GetEd25519(helper.client.DeviceID); ownID.SigningKey != ed {
return fmt.Errorf("mismatching identity key on server (%q != %q)", ownID.SigningKey, ed)
}
if !isShared {
helper.log.Debug().Msg("Olm account not marked as shared, but keys on server match?")
} else {
helper.log.Debug().Msg("Olm account marked as shared and keys on server match, device is fine")
}
return nil
}
var NoSessionFound = crypto.NoSessionFound
const initialSessionWaitTimeout = 3 * time.Second
const extendedSessionWaitTimeout = 22 * time.Second
func (helper *CryptoHelper) HandleEncrypted(src mautrix.EventSource, evt *event.Event) {
if helper == nil {
return
}
content := evt.Content.AsEncrypted()
log := helper.log.With().
Str("event_id", evt.ID.String()).
Str("session_id", content.SessionID.String()).
Logger()
log.Debug().Msg("Decrypting received event")
decrypted, err := helper.Decrypt(evt)
if errors.Is(err, NoSessionFound) {
log.Debug().
Int("wait_seconds", int(initialSessionWaitTimeout.Seconds())).
Msg("Couldn't find session, waiting for keys to arrive...")
if helper.mach.WaitForSession(evt.RoomID, content.SenderKey, content.SessionID, initialSessionWaitTimeout) {
log.Debug().Msg("Got keys after waiting, trying to decrypt event again")
decrypted, err = helper.Decrypt(evt)
} else {
go helper.waitLongerForSession(log, src, evt)
return
}
}
if err != nil {
log.Warn().Err(err).Msg("Failed to decrypt event")
helper.DecryptErrorCallback(evt, err)
return
}
helper.postDecrypt(src, decrypted)
}
func (helper *CryptoHelper) postDecrypt(src mautrix.EventSource, decrypted *event.Event) {
helper.client.Syncer.(mautrix.DispatchableSyncer).Dispatch(src|mautrix.EventSourceDecrypted, decrypted)
}
func (helper *CryptoHelper) RequestSession(roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID, userID id.UserID, deviceID id.DeviceID) {
if helper == nil {
return
}
helper.lock.RLock()
defer helper.lock.RUnlock()
if deviceID == "" {
deviceID = "*"
}
// TODO get log from context
log := helper.log.With().
Str("session_id", sessionID.String()).
Str("user_id", userID.String()).
Str("device_id", deviceID.String()).
Str("room_id", roomID.String()).
Logger()
err := helper.mach.SendRoomKeyRequest(roomID, senderKey, sessionID, "", map[id.UserID][]id.DeviceID{
userID: {deviceID},
helper.client.UserID: {"*"},
})
if err != nil {
log.Warn().Err(err).Msg("Failed to send key request")
} else {
log.Debug().Msg("Sent key request")
}
}
func (helper *CryptoHelper) waitLongerForSession(log zerolog.Logger, src mautrix.EventSource, evt *event.Event) {
content := evt.Content.AsEncrypted()
log.Debug().Int("wait_seconds", int(extendedSessionWaitTimeout.Seconds())).Msg("Couldn't find session, requesting keys and waiting longer...")
go helper.RequestSession(evt.RoomID, content.SenderKey, content.SessionID, evt.Sender, content.DeviceID)
if !helper.mach.WaitForSession(evt.RoomID, content.SenderKey, content.SessionID, extendedSessionWaitTimeout) {
log.Debug().Msg("Didn't get session, giving up")
helper.DecryptErrorCallback(evt, NoSessionFound)
return
}
log.Debug().Msg("Got keys after waiting longer, trying to decrypt event again")
decrypted, err := helper.Decrypt(evt)
if err != nil {
log.Error().Err(err).Msg("Failed to decrypt event")
helper.DecryptErrorCallback(evt, err)
return
}
helper.postDecrypt(src, decrypted)
}
func (helper *CryptoHelper) WaitForSession(roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID, timeout time.Duration) bool {
if helper == nil {
return false
}
helper.lock.RLock()
defer helper.lock.RUnlock()
return helper.mach.WaitForSession(roomID, senderKey, sessionID, timeout)
}
func (helper *CryptoHelper) Decrypt(evt *event.Event) (*event.Event, error) {
if helper == nil {
return nil, fmt.Errorf("crypto helper is nil")
}
return helper.mach.DecryptMegolmEvent(context.TODO(), evt)
}
func (helper *CryptoHelper) Encrypt(roomID id.RoomID, evtType event.Type, content any) (encrypted *event.EncryptedEventContent, err error) {
if helper == nil {
return nil, fmt.Errorf("crypto helper is nil")
}
helper.lock.RLock()
defer helper.lock.RUnlock()
ctx := context.TODO()
encrypted, err = helper.mach.EncryptMegolmEvent(ctx, roomID, evtType, content)
if err != nil {
if err != crypto.SessionExpired && err != crypto.SessionNotShared && err != crypto.NoGroupSession {
return
}
helper.log.Debug().
Err(err).
Str("room_id", roomID.String()).
Msg("Got session error while encrypting event, sharing group session and trying again")
var users []id.UserID
users, err = helper.client.StateStore.GetRoomJoinedOrInvitedMembers(roomID)
if err != nil {
err = fmt.Errorf("failed to get room member list: %w", err)
} else if err = helper.mach.ShareGroupSession(ctx, roomID, users); err != nil {
err = fmt.Errorf("failed to share group session: %w", err)
} else if encrypted, err = helper.mach.EncryptMegolmEvent(ctx, roomID, evtType, content); err != nil {
err = fmt.Errorf("failed to encrypt event after re-sharing group session: %w", err)
}
}
return
}

View File

@@ -1,4 +1,4 @@
// Copyright (c) 2020 Tulir Asokan
// Copyright (c) 2023 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
@@ -7,11 +7,14 @@
package crypto
import (
"context"
"encoding/json"
"errors"
"fmt"
"strings"
"github.com/rs/zerolog"
"maunium.net/go/mautrix/event"
"maunium.net/go/mautrix/id"
)
@@ -23,6 +26,7 @@ var (
WrongRoom = errors.New("encrypted megolm event is not intended for this room")
DeviceKeyMismatch = errors.New("device keys in event and verified device info do not match")
SenderKeyMismatch = errors.New("sender keys in content and megolm session do not match")
RatchetError = errors.New("failed to ratchet session after use")
)
type megolmEvent struct {
@@ -32,13 +36,21 @@ type megolmEvent struct {
}
// DecryptMegolmEvent decrypts an m.room.encrypted event where the algorithm is m.megolm.v1.aes-sha2
func (mach *OlmMachine) DecryptMegolmEvent(evt *event.Event) (*event.Event, error) {
func (mach *OlmMachine) DecryptMegolmEvent(ctx context.Context, evt *event.Event) (*event.Event, error) {
content, ok := evt.Content.Parsed.(*event.EncryptedEventContent)
if !ok {
return nil, IncorrectEncryptedContentType
} else if content.Algorithm != id.AlgorithmMegolmV1 {
return nil, UnsupportedAlgorithm
}
log := mach.machOrContextLog(ctx).With().
Str("action", "decrypt megolm event").
Str("event_id", evt.ID.String()).
Str("sender", evt.Sender.String()).
Str("sender_key", content.SenderKey.String()).
Str("session_id", content.SessionID.String()).
Logger()
ctx = log.WithContext(ctx)
encryptionRoomID := evt.RoomID
// Allow the server to move encrypted events between rooms if both the real room and target room are on a non-federatable .local domain.
// The message index checks to prevent replay attacks still apply and aren't based on the room ID,
@@ -46,22 +58,11 @@ func (mach *OlmMachine) DecryptMegolmEvent(evt *event.Event) (*event.Event, erro
if origRoomID, ok := evt.Content.Raw["com.beeper.original_room_id"].(string); ok && strings.HasSuffix(origRoomID, ".local") && strings.HasSuffix(evt.RoomID.String(), ".local") {
encryptionRoomID = id.RoomID(origRoomID)
}
sess, err := mach.CryptoStore.GetGroupSession(encryptionRoomID, content.SenderKey, content.SessionID)
sess, plaintext, messageIndex, err := mach.actuallyDecryptMegolmEvent(ctx, evt, encryptionRoomID, content)
if err != nil {
return nil, fmt.Errorf("failed to get group session: %w", err)
} else if sess == nil {
return nil, fmt.Errorf("%w (ID %s)", NoSessionFound, content.SessionID)
} else if content.SenderKey != "" && content.SenderKey != sess.SenderKey {
return nil, SenderKeyMismatch
}
plaintext, messageIndex, err := sess.Internal.Decrypt(content.MegolmCiphertext)
if err != nil {
return nil, fmt.Errorf("failed to decrypt megolm event: %w", err)
} else if ok, err = mach.CryptoStore.ValidateMessageIndex(sess.SenderKey, content.SessionID, evt.ID, messageIndex, evt.Timestamp); err != nil {
return nil, fmt.Errorf("failed to check if message index is duplicate: %w", err)
} else if !ok {
return nil, DuplicateMessageIndex
return nil, err
}
log = log.With().Uint("message_index", messageIndex).Logger()
var trustLevel id.TrustState
var forwardedKeys bool
@@ -70,14 +71,16 @@ func (mach *OlmMachine) DecryptMegolmEvent(evt *event.Event) (*event.Event, erro
if sess.SigningKey == ownSigningKey && sess.SenderKey == ownIdentityKey && len(sess.ForwardingChains) == 0 {
trustLevel = id.TrustStateVerified
} else {
device, err = mach.GetOrFetchDeviceByKey(evt.Sender, sess.SenderKey)
device, err = mach.GetOrFetchDeviceByKey(ctx, evt.Sender, sess.SenderKey)
if err != nil {
// We don't want to throw these errors as the message can still be decrypted.
mach.Log.Debug("Failed to get device %s/%s to verify session %s: %v", evt.Sender, sess.SenderKey, sess.ID(), err)
log.Debug().Err(err).Msg("Failed to get device to verify session")
trustLevel = id.TrustStateUnknownDevice
} else if len(sess.ForwardingChains) == 0 || (len(sess.ForwardingChains) == 1 && sess.ForwardingChains[0] == sess.SenderKey.String()) {
if device == nil {
mach.Log.Debug("Couldn't resolve trust level of session %s: sent by unknown device %s/%s", sess.ID(), evt.Sender, sess.SenderKey)
log.Debug().Err(err).
Str("session_sender_key", sess.SenderKey.String()).
Msg("Couldn't resolve trust level of session: sent by unknown device")
trustLevel = id.TrustStateUnknownDevice
} else if device.SigningKey != sess.SigningKey || device.IdentityKey != sess.SenderKey {
return nil, DeviceKeyMismatch
@@ -91,7 +94,9 @@ func (mach *OlmMachine) DecryptMegolmEvent(evt *event.Event) (*event.Event, erro
if device != nil {
trustLevel = mach.ResolveTrust(device)
} else {
mach.Log.Debug("Couldn't resolve trust level of session %s: forwarding chain ends with unknown device %s", sess.ID(), lastChainItem)
log.Debug().
Str("forward_last_sender_key", lastChainItem).
Msg("Couldn't resolve trust level of session: forwarding chain ends with unknown device")
trustLevel = id.TrustStateForwarded
}
}
@@ -105,10 +110,11 @@ func (mach *OlmMachine) DecryptMegolmEvent(evt *event.Event) (*event.Event, erro
return nil, WrongRoom
}
megolmEvt.Type.Class = evt.Type.Class
log = log.With().Str("decrypted_event_type", megolmEvt.Type.Repr()).Logger()
err = megolmEvt.Content.ParseRaw(megolmEvt.Type)
if err != nil {
if errors.Is(err, event.ErrUnsupportedContentType) {
mach.Log.Warn("Unsupported event type %s in encrypted event %s", megolmEvt.Type.Repr(), evt.ID)
log.Warn().Msg("Unsupported event type in encrypted event")
} else {
return nil, fmt.Errorf("failed to parse content of megolm payload event: %w", err)
}
@@ -119,13 +125,14 @@ func (mach *OlmMachine) DecryptMegolmEvent(evt *event.Event) (*event.Event, erro
if relatable.OptionalGetRelatesTo() == nil {
relatable.SetRelatesTo(content.RelatesTo)
} else {
mach.Log.Trace("Not overriding relation data in %s, as encrypted payload already has it", evt.ID)
log.Trace().Msg("Not overriding relation data as encrypted payload already has it")
}
}
if _, hasRelation := megolmEvt.Content.Raw["m.relates_to"]; !hasRelation {
megolmEvt.Content.Raw["m.relates_to"] = evt.Content.Raw["m.relates_to"]
}
}
log.Debug().Msg("Event decrypted successfully")
megolmEvt.Type.Class = evt.Type.Class
return &event.Event{
Sender: evt.Sender,
@@ -144,3 +151,106 @@ func (mach *OlmMachine) DecryptMegolmEvent(evt *event.Event) (*event.Event, erro
},
}, nil
}
func removeItem(slice []uint, item uint) ([]uint, bool) {
for i, s := range slice {
if s == item {
return append(slice[:i], slice[i+1:]...), true
}
}
return slice, false
}
const missedIndexCutoff = 10
func (mach *OlmMachine) actuallyDecryptMegolmEvent(ctx context.Context, evt *event.Event, encryptionRoomID id.RoomID, content *event.EncryptedEventContent) (*InboundGroupSession, []byte, uint, error) {
mach.megolmDecryptLock.Lock()
defer mach.megolmDecryptLock.Unlock()
sess, err := mach.CryptoStore.GetGroupSession(encryptionRoomID, content.SenderKey, content.SessionID)
if err != nil {
return nil, nil, 0, fmt.Errorf("failed to get group session: %w", err)
} else if sess == nil {
return nil, nil, 0, fmt.Errorf("%w (ID %s)", NoSessionFound, content.SessionID)
} else if content.SenderKey != "" && content.SenderKey != sess.SenderKey {
return sess, nil, 0, SenderKeyMismatch
}
plaintext, messageIndex, err := sess.Internal.Decrypt(content.MegolmCiphertext)
if err != nil {
return sess, nil, 0, fmt.Errorf("failed to decrypt megolm event: %w", err)
} else if ok, err := mach.CryptoStore.ValidateMessageIndex(ctx, sess.SenderKey, content.SessionID, evt.ID, messageIndex, evt.Timestamp); err != nil {
return sess, nil, messageIndex, fmt.Errorf("failed to check if message index is duplicate: %w", err)
} else if !ok {
return sess, nil, messageIndex, DuplicateMessageIndex
}
expectedMessageIndex := sess.RatchetSafety.NextIndex
didModify := false
switch {
case messageIndex > expectedMessageIndex:
// When the index jumps, add indices in between to the missed indices list.
for i := expectedMessageIndex; i < messageIndex; i++ {
sess.RatchetSafety.MissedIndices = append(sess.RatchetSafety.MissedIndices, i)
}
fallthrough
case messageIndex == expectedMessageIndex:
// When the index moves forward (to the next one or jumping ahead), update the last received index.
sess.RatchetSafety.NextIndex = messageIndex + 1
didModify = true
default:
sess.RatchetSafety.MissedIndices, didModify = removeItem(sess.RatchetSafety.MissedIndices, messageIndex)
}
// Use presence of ReceivedAt as a sign that this is a recent megolm session,
// and therefore it's safe to drop missed indices entirely.
if !sess.ReceivedAt.IsZero() && len(sess.RatchetSafety.MissedIndices) > 0 && int(sess.RatchetSafety.MissedIndices[0]) < int(sess.RatchetSafety.NextIndex)-missedIndexCutoff {
limit := sess.RatchetSafety.NextIndex - missedIndexCutoff
var cutoff int
for ; cutoff < len(sess.RatchetSafety.MissedIndices) && sess.RatchetSafety.MissedIndices[cutoff] < limit; cutoff++ {
}
sess.RatchetSafety.LostIndices = append(sess.RatchetSafety.LostIndices, sess.RatchetSafety.MissedIndices[:cutoff]...)
sess.RatchetSafety.MissedIndices = sess.RatchetSafety.MissedIndices[cutoff:]
didModify = true
}
ratchetTargetIndex := uint32(sess.RatchetSafety.NextIndex)
if len(sess.RatchetSafety.MissedIndices) > 0 {
ratchetTargetIndex = uint32(sess.RatchetSafety.MissedIndices[0])
}
ratchetCurrentIndex := sess.Internal.FirstKnownIndex()
log := zerolog.Ctx(ctx).With().
Uint32("prev_ratchet_index", ratchetCurrentIndex).
Uint32("new_ratchet_index", ratchetTargetIndex).
Uint("next_new_index", sess.RatchetSafety.NextIndex).
Uints("missed_indices", sess.RatchetSafety.MissedIndices).
Uints("lost_indices", sess.RatchetSafety.LostIndices).
Int("max_messages", sess.MaxMessages).
Logger()
if sess.MaxMessages > 0 && int(ratchetTargetIndex) >= sess.MaxMessages && len(sess.RatchetSafety.MissedIndices) == 0 && mach.DeleteFullyUsedKeysOnDecrypt {
err = mach.CryptoStore.RedactGroupSession(sess.RoomID, sess.SenderKey, sess.ID(), "maximum messages reached")
if err != nil {
log.Err(err).Msg("Failed to delete fully used session")
return sess, plaintext, messageIndex, RatchetError
} else {
log.Info().Msg("Deleted fully used session")
}
} else if ratchetCurrentIndex < ratchetTargetIndex && mach.RatchetKeysOnDecrypt {
if err = sess.RatchetTo(ratchetTargetIndex); err != nil {
log.Err(err).Msg("Failed to ratchet session")
return sess, plaintext, messageIndex, RatchetError
} else if err = mach.CryptoStore.PutGroupSession(sess.RoomID, sess.SenderKey, sess.ID(), sess); err != nil {
log.Err(err).Msg("Failed to store ratcheted session")
return sess, plaintext, messageIndex, RatchetError
} else {
log.Info().Msg("Ratcheted session forward")
}
} else if didModify {
if err = mach.CryptoStore.PutGroupSession(sess.RoomID, sess.SenderKey, sess.ID(), sess); err != nil {
log.Err(err).Msg("Failed to store updated ratchet safety data")
return sess, plaintext, messageIndex, RatchetError
} else {
log.Debug().Msg("Ratchet safety data changed (ratchet state didn't change)")
}
} else {
log.Debug().Msg("Ratchet safety data didn't change")
}
return sess, plaintext, messageIndex, nil
}

View File

@@ -1,4 +1,4 @@
// Copyright (c) 2021 Tulir Asokan
// Copyright (c) 2023 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
@@ -7,11 +7,14 @@
package crypto
import (
"context"
"encoding/json"
"errors"
"fmt"
"time"
"github.com/rs/zerolog"
"maunium.net/go/mautrix/event"
"maunium.net/go/mautrix/id"
)
@@ -43,7 +46,7 @@ type DecryptedOlmEvent struct {
Content event.Content `json:"content"`
}
func (mach *OlmMachine) decryptOlmEvent(evt *event.Event, traceID string) (*DecryptedOlmEvent, error) {
func (mach *OlmMachine) decryptOlmEvent(ctx context.Context, evt *event.Event) (*DecryptedOlmEvent, error) {
content, ok := evt.Content.Parsed.(*event.EncryptedEventContent)
if !ok {
return nil, IncorrectEncryptedContentType
@@ -54,7 +57,7 @@ func (mach *OlmMachine) decryptOlmEvent(evt *event.Event, traceID string) (*Decr
if !ok {
return nil, NotEncryptedForMe
}
decrypted, err := mach.decryptAndParseOlmCiphertext(evt.Sender, content.SenderKey, ownContent.Type, ownContent.Body, traceID)
decrypted, err := mach.decryptAndParseOlmCiphertext(ctx, evt.Sender, content.SenderKey, ownContent.Type, ownContent.Body)
if err != nil {
return nil, err
}
@@ -66,19 +69,19 @@ type OlmEventKeys struct {
Ed25519 id.Ed25519 `json:"ed25519"`
}
func (mach *OlmMachine) decryptAndParseOlmCiphertext(sender id.UserID, senderKey id.SenderKey, olmType id.OlmMsgType, ciphertext string, traceID string) (*DecryptedOlmEvent, error) {
func (mach *OlmMachine) decryptAndParseOlmCiphertext(ctx context.Context, sender id.UserID, senderKey id.SenderKey, olmType id.OlmMsgType, ciphertext string) (*DecryptedOlmEvent, error) {
if olmType != id.OlmMsgTypePreKey && olmType != id.OlmMsgTypeMsg {
return nil, UnsupportedOlmMessageType
}
endTimeTrace := mach.timeTrace("decrypting olm ciphertext", traceID, 5*time.Second)
plaintext, err := mach.tryDecryptOlmCiphertext(sender, senderKey, olmType, ciphertext, traceID)
endTimeTrace := mach.timeTrace(ctx, "decrypting olm ciphertext", 5*time.Second)
plaintext, err := mach.tryDecryptOlmCiphertext(ctx, sender, senderKey, olmType, ciphertext)
endTimeTrace()
if err != nil {
return nil, err
}
defer mach.timeTrace("parsing decrypted olm event", traceID, time.Second)()
defer mach.timeTrace(ctx, "parsing decrypted olm event", time.Second)()
var olmEvt DecryptedOlmEvent
err = json.Unmarshal(plaintext, &olmEvt)
@@ -103,17 +106,18 @@ func (mach *OlmMachine) decryptAndParseOlmCiphertext(sender id.UserID, senderKey
return &olmEvt, nil
}
func (mach *OlmMachine) tryDecryptOlmCiphertext(sender id.UserID, senderKey id.SenderKey, olmType id.OlmMsgType, ciphertext string, traceID string) ([]byte, error) {
endTimeTrace := mach.timeTrace("waiting for olm lock", traceID, 5*time.Second)
func (mach *OlmMachine) tryDecryptOlmCiphertext(ctx context.Context, sender id.UserID, senderKey id.SenderKey, olmType id.OlmMsgType, ciphertext string) ([]byte, error) {
log := *zerolog.Ctx(ctx)
endTimeTrace := mach.timeTrace(ctx, "waiting for olm lock", 5*time.Second)
mach.olmLock.Lock()
endTimeTrace()
defer mach.olmLock.Unlock()
plaintext, err := mach.tryDecryptOlmCiphertextWithExistingSession(senderKey, olmType, ciphertext, traceID)
plaintext, err := mach.tryDecryptOlmCiphertextWithExistingSession(ctx, senderKey, olmType, ciphertext)
if err != nil {
if err == DecryptionFailedWithMatchingSession {
mach.Log.Warn("Found matching session yet decryption failed for sender %s with key %s", sender, senderKey)
go mach.unwedgeDevice(sender, senderKey)
log.Warn().Msg("Found matching session, but decryption failed")
go mach.unwedgeDevice(log, sender, senderKey)
}
return nil, fmt.Errorf("failed to decrypt olm event: %w", err)
}
@@ -128,39 +132,44 @@ func (mach *OlmMachine) tryDecryptOlmCiphertext(sender id.UserID, senderKey id.S
// New sessions can only be created if it's a prekey message, we can't decrypt the message
// if it isn't one at this point in time anymore, so return early.
if olmType != id.OlmMsgTypePreKey {
go mach.unwedgeDevice(sender, senderKey)
go mach.unwedgeDevice(log, sender, senderKey)
return nil, DecryptionFailedForNormalMessage
}
mach.Log.Trace("Trying to create inbound session for %s/%s", sender, senderKey)
endTimeTrace = mach.timeTrace("creating inbound olm session", traceID, time.Second)
session, err := mach.createInboundSession(senderKey, ciphertext)
log.Trace().Msg("Trying to create inbound session")
endTimeTrace = mach.timeTrace(ctx, "creating inbound olm session", time.Second)
session, err := mach.createInboundSession(ctx, senderKey, ciphertext)
endTimeTrace()
if err != nil {
go mach.unwedgeDevice(sender, senderKey)
go mach.unwedgeDevice(log, sender, senderKey)
return nil, fmt.Errorf("failed to create new session from prekey message: %w", err)
}
mach.Log.Debug("Created inbound olm session %s for %s/%s: %s", session.ID(), sender, senderKey, session.Describe())
log = log.With().Str("new_olm_session_id", session.ID().String()).Logger()
log.Debug().
Str("olm_session_description", session.Describe()).
Msg("Created inbound olm session")
ctx = log.WithContext(ctx)
endTimeTrace = mach.timeTrace(fmt.Sprintf("decrypting prekey olm message with %s/%s", senderKey, session.ID()), traceID, time.Second)
endTimeTrace = mach.timeTrace(ctx, "decrypting prekey olm message", time.Second)
plaintext, err = session.Decrypt(ciphertext, olmType)
endTimeTrace()
if err != nil {
go mach.unwedgeDevice(sender, senderKey)
go mach.unwedgeDevice(log, sender, senderKey)
return nil, fmt.Errorf("failed to decrypt olm event with session created from prekey message: %w", err)
}
endTimeTrace = mach.timeTrace(fmt.Sprintf("updating new session %s/%s in database", senderKey, session.ID()), traceID, time.Second)
endTimeTrace = mach.timeTrace(ctx, "updating new session in database", time.Second)
err = mach.CryptoStore.UpdateSession(senderKey, session)
endTimeTrace()
if err != nil {
mach.Log.Warn("Failed to update new olm session in crypto store after decrypting: %v", err)
log.Warn().Err(err).Msg("Failed to update new olm session in crypto store after decrypting")
}
return plaintext, nil
}
func (mach *OlmMachine) tryDecryptOlmCiphertextWithExistingSession(senderKey id.SenderKey, olmType id.OlmMsgType, ciphertext string, traceID string) ([]byte, error) {
endTimeTrace := mach.timeTrace(fmt.Sprintf("getting sessions with %s", senderKey), traceID, time.Second)
func (mach *OlmMachine) tryDecryptOlmCiphertextWithExistingSession(ctx context.Context, senderKey id.SenderKey, olmType id.OlmMsgType, ciphertext string) ([]byte, error) {
log := *zerolog.Ctx(ctx)
endTimeTrace := mach.timeTrace(ctx, "getting sessions with sender key", time.Second)
sessions, err := mach.CryptoStore.GetSessions(senderKey)
endTimeTrace()
if err != nil {
@@ -168,8 +177,10 @@ func (mach *OlmMachine) tryDecryptOlmCiphertextWithExistingSession(senderKey id.
}
for _, session := range sessions {
log := log.With().Str("olm_session_id", session.ID().String()).Logger()
ctx := log.WithContext(ctx)
if olmType == id.OlmMsgTypePreKey {
endTimeTrace = mach.timeTrace(fmt.Sprintf("checking if prekey olm message matches session %s/%s", senderKey, session.ID()), traceID, time.Second)
endTimeTrace = mach.timeTrace(ctx, "checking if prekey olm message matches session", time.Second)
matches, err := session.Internal.MatchesInboundSession(ciphertext)
endTimeTrace()
if err != nil {
@@ -178,8 +189,8 @@ func (mach *OlmMachine) tryDecryptOlmCiphertextWithExistingSession(senderKey id.
continue
}
}
mach.Log.Trace("Trying to decrypt olm message from %s with session %s: %s", senderKey, session.ID(), session.Describe())
endTimeTrace = mach.timeTrace(fmt.Sprintf("decrypting olm message with %s/%s", senderKey, session.ID()), traceID, time.Second)
log.Debug().Str("session_description", session.Describe()).Msg("Trying to decrypt olm message")
endTimeTrace = mach.timeTrace(ctx, "decrypting olm message", time.Second)
plaintext, err := session.Decrypt(ciphertext, olmType)
endTimeTrace()
if err != nil {
@@ -187,20 +198,20 @@ func (mach *OlmMachine) tryDecryptOlmCiphertextWithExistingSession(senderKey id.
return nil, DecryptionFailedWithMatchingSession
}
} else {
endTimeTrace = mach.timeTrace(fmt.Sprintf("updating session %s/%s in database", senderKey, session.ID()), traceID, time.Second)
endTimeTrace = mach.timeTrace(ctx, "updating session in database", time.Second)
err = mach.CryptoStore.UpdateSession(senderKey, session)
endTimeTrace()
if err != nil {
mach.Log.Warn("Failed to update olm session in crypto store after decrypting: %v", err)
log.Warn().Err(err).Msg("Failed to update olm session in crypto store after decrypting")
}
mach.Log.Trace("Decrypted olm message from %s with session %s", senderKey, session.ID())
log.Debug().Msg("Decrypted olm message")
return plaintext, nil
}
}
return nil, nil
}
func (mach *OlmMachine) createInboundSession(senderKey id.SenderKey, ciphertext string) (*OlmSession, error) {
func (mach *OlmMachine) createInboundSession(ctx context.Context, senderKey id.SenderKey, ciphertext string) (*OlmSession, error) {
session, err := mach.account.NewInboundSessionFrom(senderKey, ciphertext)
if err != nil {
return nil, err
@@ -208,40 +219,44 @@ func (mach *OlmMachine) createInboundSession(senderKey id.SenderKey, ciphertext
mach.saveAccount()
err = mach.CryptoStore.AddSession(senderKey, session)
if err != nil {
mach.Log.Error("Failed to store created inbound session: %v", err)
zerolog.Ctx(ctx).Error().Err(err).Msg("Failed to store created inbound session")
}
return session, nil
}
const MinUnwedgeInterval = 1 * time.Hour
func (mach *OlmMachine) unwedgeDevice(sender id.UserID, senderKey id.SenderKey) {
func (mach *OlmMachine) unwedgeDevice(log zerolog.Logger, sender id.UserID, senderKey id.SenderKey) {
log = log.With().Str("action", "unwedge olm session").Logger()
ctx := log.WithContext(context.Background())
mach.recentlyUnwedgedLock.Lock()
prevUnwedge, ok := mach.recentlyUnwedged[senderKey]
delta := time.Now().Sub(prevUnwedge)
if ok && delta < MinUnwedgeInterval {
mach.Log.Debug("Not creating new Olm session with %s/%s, previous recreation was %s ago", sender, senderKey, delta)
log.Debug().
Str("previous_recreation", delta.String()).
Msg("Not creating new Olm session as it was already recreated recently")
mach.recentlyUnwedgedLock.Unlock()
return
}
mach.recentlyUnwedged[senderKey] = time.Now()
mach.recentlyUnwedgedLock.Unlock()
deviceIdentity, err := mach.GetOrFetchDeviceByKey(sender, senderKey)
deviceIdentity, err := mach.GetOrFetchDeviceByKey(ctx, sender, senderKey)
if err != nil {
mach.Log.Error("Failed to find device info by identity key: %v", err)
log.Error().Err(err).Msg("Failed to find device info by identity key")
return
} else if deviceIdentity == nil {
mach.Log.Warn("Didn't find identity of %s/%s, can't unwedge session", sender, senderKey)
log.Warn().Msg("Didn't find identity for device")
return
}
mach.Log.Debug("Creating new Olm session with %s/%s (key: %s)", sender, deviceIdentity.DeviceID, senderKey)
log.Debug().Str("device_id", deviceIdentity.DeviceID.String()).Msg("Creating new Olm session")
mach.devicesToUnwedgeLock.Lock()
mach.devicesToUnwedge[senderKey] = true
mach.devicesToUnwedgeLock.Unlock()
err = mach.SendEncryptedToDevice(deviceIdentity, event.ToDeviceDummy, event.Content{})
err = mach.SendEncryptedToDevice(ctx, deviceIdentity, event.ToDeviceDummy, event.Content{})
if err != nil {
mach.Log.Error("Failed to send dummy event to unwedge session with %s/%s: %v", sender, senderKey, err)
log.Error().Err(err).Msg("Failed to send dummy event to unwedge session")
}
}

View File

@@ -1,4 +1,4 @@
// Copyright (c) 2020 Tulir Asokan
// Copyright (c) 2023 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
@@ -7,9 +7,12 @@
package crypto
import (
"context"
"errors"
"fmt"
"github.com/rs/zerolog"
"maunium.net/go/mautrix"
"maunium.net/go/mautrix/crypto/olm"
"maunium.net/go/mautrix/id"
@@ -25,10 +28,12 @@ var (
)
func (mach *OlmMachine) LoadDevices(user id.UserID) map[id.DeviceID]*id.Device {
return mach.fetchKeys([]id.UserID{user}, "", true)[user]
// TODO proper context?
return mach.fetchKeys(context.TODO(), []id.UserID{user}, "", true)[user]
}
func (mach *OlmMachine) storeDeviceSelfSignatures(userID id.UserID, deviceID id.DeviceID, resp *mautrix.RespQueryKeys) {
func (mach *OlmMachine) storeDeviceSelfSignatures(ctx context.Context, userID id.UserID, deviceID id.DeviceID, resp *mautrix.RespQueryKeys) {
log := zerolog.Ctx(ctx)
deviceKeys := resp.DeviceKeys[userID][deviceID]
for signerUserID, signerKeys := range deviceKeys.Signatures {
for signerKey, signature := range signerKeys {
@@ -43,17 +48,27 @@ func (mach *OlmMachine) storeDeviceSelfSignatures(userID id.UserID, deviceID id.
if verified, err := olm.VerifySignatureJSON(deviceKeys, signerUserID, pubKey.String(), pubKey); verified {
if signKey, ok := deviceKeys.Keys[id.DeviceKeyID(signerKey)]; ok {
signature := deviceKeys.Signatures[signerUserID][id.NewKeyID(id.KeyAlgorithmEd25519, pubKey.String())]
mach.Log.Trace("Verified self-signing signature for device %s/%s: %s", signerUserID, deviceID, signature)
log.Trace().Err(err).
Str("signer_user_id", signerUserID.String()).
Str("signed_device_id", deviceID.String()).
Str("signature", signature).
Msg("Verified self-signing signature")
err = mach.CryptoStore.PutSignature(userID, id.Ed25519(signKey), signerUserID, pubKey, signature)
if err != nil {
mach.Log.Warn("Failed to store self-signing signature for device %s/%s: %v", signerUserID, deviceID, err)
log.Warn().Err(err).
Str("signer_user_id", signerUserID.String()).
Str("signed_device_id", deviceID.String()).
Msg("Failed to store self-signing signature")
}
}
} else {
if err == nil {
err = errors.New("invalid signature")
}
mach.Log.Warn("Could not verify device self-signing signature for %s/%s: %v", signerUserID, deviceID, err)
log.Warn().Err(err).
Str("signer_user_id", signerUserID.String()).
Str("signed_device_id", deviceID.String()).
Msg("Failed to verify self-signing signature")
}
}
}
@@ -61,25 +76,29 @@ func (mach *OlmMachine) storeDeviceSelfSignatures(userID id.UserID, deviceID id.
if signKey, ok := deviceKeys.Keys[id.DeviceKeyID(signerKey)]; ok {
err := mach.CryptoStore.PutSignature(userID, id.Ed25519(signKey), signerUserID, id.Ed25519(signKey), signature)
if err != nil {
mach.Log.Warn("Failed to store self-signing signature for %s/%s: %v", signerUserID, signKey, err)
log.Warn().Err(err).
Str("signer_user_id", signerUserID.String()).
Str("signer_key", signKey).
Msg("Failed to store self-signing signature")
}
}
}
}
}
func (mach *OlmMachine) fetchKeys(users []id.UserID, sinceToken string, includeUntracked bool) (data map[id.UserID]map[id.DeviceID]*id.Device) {
func (mach *OlmMachine) fetchKeys(ctx context.Context, users []id.UserID, sinceToken string, includeUntracked bool) (data map[id.UserID]map[id.DeviceID]*id.Device) {
// TODO this function should probably return errors
req := &mautrix.ReqQueryKeys{
DeviceKeys: mautrix.DeviceKeysRequest{},
Timeout: 10 * 1000,
Token: sinceToken,
}
log := mach.machOrContextLog(ctx)
if !includeUntracked {
var err error
users, err = mach.CryptoStore.FilterTrackedUsers(users)
if err != nil {
mach.Log.Warn("Failed to filter tracked user list: %v", err)
log.Warn().Err(err).Msg("Failed to filter tracked user list")
}
}
if len(users) == 0 {
@@ -88,62 +107,88 @@ func (mach *OlmMachine) fetchKeys(users []id.UserID, sinceToken string, includeU
for _, userID := range users {
req.DeviceKeys[userID] = mautrix.DeviceIDList{}
}
mach.Log.Trace("Querying keys for %v", users)
log.Debug().Strs("users", strishArray(users)).Msg("Querying keys for users")
resp, err := mach.Client.QueryKeys(req)
if err != nil {
mach.Log.Warn("Failed to query keys: %v", err)
log.Error().Err(err).Msg("Failed to query keys")
return
}
for server, err := range resp.Failures {
mach.Log.Warn("Query keys failure for %s: %v", server, err)
log.Warn().Interface("query_error", err).Str("server", server).Msg("Query keys failure for server")
}
mach.Log.Trace("Query key result received with %d users", len(resp.DeviceKeys))
log.Trace().Int("user_count", len(resp.DeviceKeys)).Msg("Query key result received")
data = make(map[id.UserID]map[id.DeviceID]*id.Device)
for userID, devices := range resp.DeviceKeys {
log := log.With().Str("user_id", userID.String()).Logger()
delete(req.DeviceKeys, userID)
newDevices := make(map[id.DeviceID]*id.Device)
existingDevices, err := mach.CryptoStore.GetDevices(userID)
if err != nil {
mach.Log.Warn("Failed to get existing devices for %s: %v", userID, err)
log.Warn().Err(err).Msg("Failed to get existing devices for user")
existingDevices = make(map[id.DeviceID]*id.Device)
}
mach.Log.Trace("Updating devices for %s, got %d devices, have %d in store", userID, len(devices), len(existingDevices))
log.Debug().
Int("new_device_count", len(devices)).
Int("old_device_count", len(existingDevices)).
Msg("Updating devices in store")
changed := false
for deviceID, deviceKeys := range devices {
log := log.With().Str("device_id", deviceID.String()).Logger()
existing, ok := existingDevices[deviceID]
if !ok {
// New device
changed = true
}
mach.Log.Trace("Validating device %s of %s", deviceID, userID)
log.Trace().Msg("Validating device")
newDevice, err := mach.validateDevice(userID, deviceID, deviceKeys, existing)
if err != nil {
mach.Log.Error("Failed to validate device %s of %s: %v", deviceID, userID, err)
log.Error().Err(err).Msg("Failed to validate device")
} else if newDevice != nil {
newDevices[deviceID] = newDevice
mach.storeDeviceSelfSignatures(userID, deviceID, resp)
mach.storeDeviceSelfSignatures(ctx, userID, deviceID, resp)
}
}
mach.Log.Trace("Storing new device list for %s containing %d devices", userID, len(newDevices))
log.Trace().Int("new_device_count", len(newDevices)).Msg("Storing new device list")
err = mach.CryptoStore.PutDevices(userID, newDevices)
if err != nil {
mach.Log.Warn("Failed to update device list for %s: %v", userID, err)
log.Warn().Err(err).Msg("Failed to update device list")
}
data[userID] = newDevices
changed = changed || len(newDevices) != len(existingDevices)
if changed {
if mach.DeleteKeysOnDeviceDelete {
for deviceID := range newDevices {
delete(existingDevices, deviceID)
}
for _, device := range existingDevices {
log := log.With().
Str("device_id", device.DeviceID.String()).
Str("identity_key", device.IdentityKey.String()).
Str("signing_key", device.SigningKey.String()).
Logger()
sessionIDs, err := mach.CryptoStore.RedactGroupSessions("", device.IdentityKey, "device removed")
if err != nil {
log.Err(err).Msg("Failed to redact megolm sessions from deleted device")
} else {
log.Info().
Strs("session_ids", stringifyArray(sessionIDs)).
Msg("Redacted megolm sessions from deleted device")
}
}
}
mach.OnDevicesChanged(userID)
}
}
for userID := range req.DeviceKeys {
mach.Log.Warn("Didn't get any keys for user %s", userID)
log.Warn().Str("user_id", userID.String()).Msg("Didn't get any keys for user")
}
mach.storeCrossSigningKeys(resp.MasterKeys, resp.DeviceKeys)
mach.storeCrossSigningKeys(resp.SelfSigningKeys, resp.DeviceKeys)
mach.storeCrossSigningKeys(resp.UserSigningKeys, resp.DeviceKeys)
mach.storeCrossSigningKeys(ctx, resp.MasterKeys, resp.DeviceKeys)
mach.storeCrossSigningKeys(ctx, resp.SelfSigningKeys, resp.DeviceKeys)
mach.storeCrossSigningKeys(ctx, resp.UserSigningKeys, resp.DeviceKeys)
return data
}
@@ -154,10 +199,16 @@ func (mach *OlmMachine) fetchKeys(users []id.UserID, sinceToken string, includeU
// not need to be called manually.
func (mach *OlmMachine) OnDevicesChanged(userID id.UserID) {
for _, roomID := range mach.StateStore.FindSharedRooms(userID) {
mach.Log.Debug("Devices of %s changed, invalidating group session for %s", userID, roomID)
mach.Log.Debug().
Str("user_id", userID.String()).
Str("room_id", roomID.String()).
Msg("Invalidating group session in room due to device change notification")
err := mach.CryptoStore.RemoveOutboundGroupSession(roomID)
if err != nil {
mach.Log.Warn("Failed to invalidate outbound group session of %s on device change for %s: %v", roomID, userID, err)
mach.Log.Warn().Err(err).
Str("user_id", userID.String()).
Str("room_id", roomID.String()).
Msg("Failed to invalidate outbound group session")
}
}
}

View File

@@ -1,4 +1,4 @@
// Copyright (c) 2020 Tulir Asokan
// Copyright (c) 2023 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
@@ -7,10 +7,15 @@
package crypto
import (
"context"
"encoding/base64"
"encoding/binary"
"encoding/json"
"errors"
"fmt"
"github.com/rs/zerolog"
"maunium.net/go/mautrix"
"maunium.net/go/mautrix/event"
"maunium.net/go/mautrix/id"
@@ -33,6 +38,18 @@ func getRelatesTo(content interface{}) *event.RelatesTo {
return nil
}
func getMentions(content interface{}) *event.Mentions {
contentStruct, ok := content.(*event.Content)
if ok {
content = contentStruct.Parsed
}
message, ok := content.(*event.MessageEventContent)
if ok {
return message.Mentions
}
return nil
}
type rawMegolmEvent struct {
RoomID id.RoomID `json:"room_id"`
Type event.Type `json:"type"`
@@ -44,12 +61,29 @@ func IsShareError(err error) bool {
return err == SessionExpired || err == SessionNotShared || err == NoGroupSession
}
func parseMessageIndex(ciphertext []byte) (uint64, error) {
decoded := make([]byte, base64.RawStdEncoding.DecodedLen(len(ciphertext)))
var err error
_, err = base64.RawStdEncoding.Decode(decoded, ciphertext)
if err != nil {
return 0, err
} else if decoded[0] != 3 || decoded[1] != 8 {
return 0, fmt.Errorf("unexpected initial bytes %d and %d", decoded[0], decoded[1])
}
index, read := binary.Uvarint(decoded[2 : 2+binary.MaxVarintLen64])
if read <= 0 {
return 0, fmt.Errorf("failed to decode varint, read value %d", read)
}
return index, nil
}
// EncryptMegolmEvent encrypts data with the m.megolm.v1.aes-sha2 algorithm.
//
// If you use the event.Content struct, make sure you pass a pointer to the struct,
// as JSON serialization will not work correctly otherwise.
func (mach *OlmMachine) EncryptMegolmEvent(roomID id.RoomID, evtType event.Type, content interface{}) (*event.EncryptedEventContent, error) {
mach.Log.Trace("Encrypting event of type %s for %s", evtType.Type, roomID)
func (mach *OlmMachine) EncryptMegolmEvent(ctx context.Context, roomID id.RoomID, evtType event.Type, content interface{}) (*event.EncryptedEventContent, error) {
mach.megolmEncryptLock.Lock()
defer mach.megolmEncryptLock.Unlock()
session, err := mach.CryptoStore.GetOutboundGroupSession(roomID)
if err != nil {
return nil, fmt.Errorf("failed to get outbound group session: %w", err)
@@ -64,15 +98,28 @@ func (mach *OlmMachine) EncryptMegolmEvent(roomID id.RoomID, evtType event.Type,
if err != nil {
return nil, err
}
log := mach.machOrContextLog(ctx).With().
Str("event_type", evtType.Type).
Str("room_id", roomID.String()).
Str("session_id", session.ID().String()).
Logger()
log.Trace().Msg("Encrypting event...")
ciphertext, err := session.Encrypt(plaintext)
if err != nil {
return nil, err
}
idx, err := parseMessageIndex(ciphertext)
if err != nil {
log.Warn().Err(err).Msg("Failed to get megolm message index of encrypted event")
} else {
log = log.With().Uint64("message_index", idx).Logger()
}
log.Debug().Msg("Encrypted event successfully")
err = mach.CryptoStore.UpdateOutboundGroupSession(session)
if err != nil {
mach.Log.Warn("Failed to update megolm session in crypto store after encrypting: %v", err)
log.Warn().Err(err).Msg("Failed to update megolm session in crypto store after encrypting")
}
return &event.EncryptedEventContent{
encrypted := &event.EncryptedEventContent{
Algorithm: id.AlgorithmMegolmV1,
SessionID: session.ID(),
MegolmCiphertext: ciphertext,
@@ -81,13 +128,19 @@ func (mach *OlmMachine) EncryptMegolmEvent(roomID id.RoomID, evtType event.Type,
// These are deprecated
SenderKey: mach.account.IdentityKey(),
DeviceID: mach.Client.DeviceID,
}, nil
}
if mach.PlaintextMentions {
encrypted.Mentions = getMentions(content)
}
return encrypted, nil
}
func (mach *OlmMachine) newOutboundGroupSession(roomID id.RoomID) *OutboundGroupSession {
func (mach *OlmMachine) newOutboundGroupSession(ctx context.Context, roomID id.RoomID) *OutboundGroupSession {
session := NewOutboundGroupSession(roomID, mach.StateStore.GetEncryptionEvent(roomID))
signingKey, idKey := mach.account.Keys()
mach.createGroupSession(idKey, signingKey, roomID, session.ID(), session.Internal.Key(), "create")
if !mach.DontStoreOutboundKeys {
signingKey, idKey := mach.account.Keys()
mach.createGroupSession(ctx, idKey, signingKey, roomID, session.ID(), session.Internal.Key(), session.MaxAge, session.MaxMessages, false)
}
return session
}
@@ -96,21 +149,38 @@ type deviceSessionWrapper struct {
identity *id.Device
}
func strishArray[T ~string](arr []T) []string {
out := make([]string, len(arr))
for i, item := range arr {
out[i] = string(item)
}
return out
}
// ShareGroupSession shares a group session for a specific room with all the devices of the given user list.
//
// For devices with TrustStateBlacklisted, a m.room_key.withheld event with code=m.blacklisted is sent.
// If AllowUnverifiedDevices is false, a similar event with code=m.unverified is sent to devices with TrustStateUnset
func (mach *OlmMachine) ShareGroupSession(roomID id.RoomID, users []id.UserID) error {
mach.Log.Debug("Sharing group session for room %s to %v", roomID, users)
func (mach *OlmMachine) ShareGroupSession(ctx context.Context, roomID id.RoomID, users []id.UserID) error {
mach.megolmEncryptLock.Lock()
defer mach.megolmEncryptLock.Unlock()
session, err := mach.CryptoStore.GetOutboundGroupSession(roomID)
if err != nil {
return fmt.Errorf("failed to get previous outbound group session: %w", err)
} else if session != nil && session.Shared && !session.Expired() {
return AlreadyShared
}
log := mach.machOrContextLog(ctx).With().
Str("room_id", roomID.String()).
Str("action", "share megolm session").
Logger()
ctx = log.WithContext(ctx)
if session == nil || session.Expired() {
session = mach.newOutboundGroupSession(roomID)
session = mach.newOutboundGroupSession(ctx, roomID)
}
log = log.With().Str("session_id", session.ID().String()).Logger()
ctx = log.WithContext(ctx)
log.Debug().Strs("users", strishArray(users)).Msg("Sharing group session for room")
withheldCount := 0
toDeviceWithheld := &mautrix.ReqSendToDevice{Messages: make(map[id.UserID]map[id.DeviceID]*event.Content)}
@@ -120,20 +190,25 @@ func (mach *OlmMachine) ShareGroupSession(roomID id.RoomID, users []id.UserID) e
var fetchKeys []id.UserID
for _, userID := range users {
log := log.With().Str("target_user_id", userID.String()).Logger()
devices, err := mach.CryptoStore.GetDevices(userID)
if err != nil {
mach.Log.Error("Failed to get devices of %s", userID)
log.Error().Err(err).Msg("Failed to get devices of user")
} else if devices == nil {
mach.Log.Trace("GetDevices returned nil for %s, will fetch keys and retry", userID)
log.Debug().Msg("GetDevices returned nil, will fetch keys and retry")
fetchKeys = append(fetchKeys, userID)
} else if len(devices) == 0 {
mach.Log.Trace("%s has no devices, skipping", userID)
log.Trace().Msg("User has no devices, skipping")
} else {
mach.Log.Trace("Trying to find olm sessions to encrypt %s for %s", session.ID(), userID)
log.Trace().Msg("Trying to find olm session to encrypt megolm session for user")
toDeviceWithheld.Messages[userID] = make(map[id.DeviceID]*event.Content)
olmSessions[userID] = make(map[id.DeviceID]deviceSessionWrapper)
mach.findOlmSessionsForUser(session, userID, devices, olmSessions[userID], toDeviceWithheld.Messages[userID], missingUserSessions)
mach.Log.Trace("Found %d sessions, withholding from %d sessions and missing %d sessions to encrypt %s for for %s", len(olmSessions[userID]), len(toDeviceWithheld.Messages[userID]), len(missingUserSessions), session.ID(), userID)
mach.findOlmSessionsForUser(ctx, session, userID, devices, olmSessions[userID], toDeviceWithheld.Messages[userID], missingUserSessions)
log.Debug().
Int("olm_session_count", len(olmSessions[userID])).
Int("withheld_count", len(toDeviceWithheld.Messages[userID])).
Int("missing_count", len(missingUserSessions)).
Msg("Completed first pass of finding olm sessions")
withheldCount += len(toDeviceWithheld.Messages[userID])
if len(missingUserSessions) > 0 {
missingSessions[userID] = missingUserSessions
@@ -146,18 +221,21 @@ func (mach *OlmMachine) ShareGroupSession(roomID id.RoomID, users []id.UserID) e
}
if len(fetchKeys) > 0 {
mach.Log.Trace("Fetching missing keys for %v", fetchKeys)
for userID, devices := range mach.fetchKeys(fetchKeys, "", true) {
mach.Log.Trace("Got %d device keys for %s", len(devices), userID)
log.Debug().Strs("users", strishArray(fetchKeys)).Msg("Fetching missing keys")
for userID, devices := range mach.fetchKeys(ctx, fetchKeys, "", true) {
log.Debug().
Int("device_count", len(devices)).
Str("target_user_id", userID.String()).
Msg("Got device keys for user")
missingSessions[userID] = devices
}
}
if len(missingSessions) > 0 {
mach.Log.Trace("Creating missing outbound sessions")
err = mach.createOutboundSessions(missingSessions)
log.Debug().Msg("Creating missing olm sessions")
err = mach.createOutboundSessions(ctx, missingSessions)
if err != nil {
mach.Log.Error("Failed to create missing outbound sessions: %v", err)
log.Error().Err(err).Msg("Failed to create missing olm sessions")
}
}
@@ -176,42 +254,51 @@ func (mach *OlmMachine) ShareGroupSession(roomID id.RoomID, users []id.UserID) e
withheld = make(map[id.DeviceID]*event.Content)
toDeviceWithheld.Messages[userID] = withheld
}
mach.Log.Trace("Trying to find olm sessions to encrypt %s for %s (post-fetch retry)", session.ID(), userID)
mach.findOlmSessionsForUser(session, userID, devices, output, withheld, nil)
mach.Log.Trace("Found %d sessions and withholding from %d sessions to encrypt %s for for %s (post-fetch retry)", len(output), len(withheld), session.ID(), userID)
log := log.With().Str("target_user_id", userID.String()).Logger()
log.Trace().Msg("Trying to find olm session to encrypt megolm session for user (post-fetch retry)")
mach.findOlmSessionsForUser(ctx, session, userID, devices, output, withheld, nil)
log.Debug().
Int("olm_session_count", len(output)).
Int("withheld_count", len(withheld)).
Msg("Completed post-fetch retry of finding olm sessions")
withheldCount += len(toDeviceWithheld.Messages[userID])
if len(toDeviceWithheld.Messages[userID]) == 0 {
delete(toDeviceWithheld.Messages, userID)
}
}
err = mach.encryptAndSendGroupSession(session, olmSessions)
err = mach.encryptAndSendGroupSession(ctx, session, olmSessions)
if err != nil {
return fmt.Errorf("failed to share group session: %w", err)
}
if len(toDeviceWithheld.Messages) > 0 {
mach.Log.Trace("Sending to-device messages to %d devices of %d users to report withheld keys in %s", withheldCount, len(toDeviceWithheld.Messages), roomID)
log.Debug().
Int("device_count", withheldCount).
Int("user_count", len(toDeviceWithheld.Messages)).
Msg("Sending to-device messages to report withheld key")
// TODO remove the next 4 lines once clients support m.room_key.withheld
_, err = mach.Client.SendToDevice(event.ToDeviceOrgMatrixRoomKeyWithheld, toDeviceWithheld)
if err != nil {
mach.Log.Warn("Failed to report withheld keys in %s (legacy event type): %v", roomID, err)
log.Warn().Err(err).Msg("Failed to report withheld keys (legacy event type)")
}
_, err = mach.Client.SendToDevice(event.ToDeviceRoomKeyWithheld, toDeviceWithheld)
if err != nil {
mach.Log.Warn("Failed to report withheld keys in %s: %v", roomID, err)
log.Warn().Err(err).Msg("Failed to report withheld keys")
}
}
mach.Log.Debug("Group session %s for %s successfully shared", session.ID(), roomID)
log.Debug().Msg("Group session successfully shared")
session.Shared = true
return mach.CryptoStore.AddOutboundGroupSession(session)
}
func (mach *OlmMachine) encryptAndSendGroupSession(session *OutboundGroupSession, olmSessions map[id.UserID]map[id.DeviceID]deviceSessionWrapper) error {
func (mach *OlmMachine) encryptAndSendGroupSession(ctx context.Context, session *OutboundGroupSession, olmSessions map[id.UserID]map[id.DeviceID]deviceSessionWrapper) error {
mach.olmLock.Lock()
defer mach.olmLock.Unlock()
mach.Log.Trace("Encrypting group session %s for all found devices", session.ID())
log := zerolog.Ctx(ctx)
log.Trace().Msg("Encrypting group session for all found devices")
deviceCount := 0
toDevice := &mautrix.ReqSendToDevice{Messages: make(map[id.UserID]map[id.DeviceID]*event.Content)}
for userID, sessions := range olmSessions {
@@ -221,31 +308,41 @@ func (mach *OlmMachine) encryptAndSendGroupSession(session *OutboundGroupSession
output := make(map[id.DeviceID]*event.Content)
toDevice.Messages[userID] = output
for deviceID, device := range sessions {
mach.Log.Trace("Encrypting group session %s for %s of %s", session.ID(), deviceID, userID)
content := mach.encryptOlmEvent(device.session, device.identity, event.ToDeviceRoomKey, session.ShareContent())
log.Trace().
Str("target_user_id", userID.String()).
Str("target_device_id", deviceID.String()).
Msg("Encrypting group session for device")
content := mach.encryptOlmEvent(ctx, device.session, device.identity, event.ToDeviceRoomKey, session.ShareContent())
output[deviceID] = &event.Content{Parsed: content}
deviceCount++
mach.Log.Trace("Encrypted group session %s for %s of %s", session.ID(), deviceID, userID)
log.Debug().
Str("target_user_id", userID.String()).
Str("target_device_id", deviceID.String()).
Msg("Encrypted group session for device")
}
}
mach.Log.Trace("Sending to-device to %d devices of %d users to share group session %s", deviceCount, len(toDevice.Messages), session.ID())
log.Debug().
Int("device_count", deviceCount).
Int("user_count", len(toDevice.Messages)).
Msg("Sending to-device messages to share group session")
_, err := mach.Client.SendToDevice(event.ToDeviceEncrypted, toDevice)
return err
}
func (mach *OlmMachine) findOlmSessionsForUser(session *OutboundGroupSession, userID id.UserID, devices map[id.DeviceID]*id.Device, output map[id.DeviceID]deviceSessionWrapper, withheld map[id.DeviceID]*event.Content, missingOutput map[id.DeviceID]*id.Device) {
func (mach *OlmMachine) findOlmSessionsForUser(ctx context.Context, session *OutboundGroupSession, userID id.UserID, devices map[id.DeviceID]*id.Device, output map[id.DeviceID]deviceSessionWrapper, withheld map[id.DeviceID]*event.Content, missingOutput map[id.DeviceID]*id.Device) {
for deviceID, device := range devices {
log := zerolog.Ctx(ctx).With().
Str("target_user_id", userID.String()).
Str("target_device_id", deviceID.String()).
Logger()
userKey := UserDevice{UserID: userID, DeviceID: deviceID}
if state := session.Users[userKey]; state != OGSNotShared {
continue
} else if userID == mach.Client.UserID && deviceID == mach.Client.DeviceID {
session.Users[userKey] = OGSIgnored
} else if device.Trust == id.TrustStateBlacklisted {
mach.Log.Debug(
"Not encrypting group session %s for %s of %s: device is blacklisted",
session.ID(), deviceID, userID,
)
log.Debug().Msg("Not encrypting group session for device: device is blacklisted")
withheld[deviceID] = &event.Content{Parsed: &event.RoomKeyWithheldEventContent{
RoomID: session.RoomID,
Algorithm: id.AlgorithmMegolmV1,
@@ -256,10 +353,10 @@ func (mach *OlmMachine) findOlmSessionsForUser(session *OutboundGroupSession, us
}}
session.Users[userKey] = OGSIgnored
} else if trustState := mach.ResolveTrust(device); trustState < mach.SendKeysMinTrust {
mach.Log.Debug(
"Not encrypting group session %s for %s of %s: device is not verified (minimum: %s, device: %s)",
session.ID(), deviceID, userID, mach.SendKeysMinTrust, trustState,
)
log.Debug().
Str("min_trust", mach.SendKeysMinTrust.String()).
Str("device_trust", trustState.String()).
Msg("Not encrypting group session for device: device is not trusted")
withheld[deviceID] = &event.Content{Parsed: &event.RoomKeyWithheldEventContent{
RoomID: session.RoomID,
Algorithm: id.AlgorithmMegolmV1,
@@ -270,9 +367,9 @@ func (mach *OlmMachine) findOlmSessionsForUser(session *OutboundGroupSession, us
}}
session.Users[userKey] = OGSIgnored
} else if deviceSession, err := mach.CryptoStore.GetLatestSession(device.IdentityKey); err != nil {
mach.Log.Error("Failed to get session for %s of %s: %v", deviceID, userID, err)
log.Error().Err(err).Msg("Failed to get olm session to encrypt group session")
} else if deviceSession == nil {
mach.Log.Warn("Didn't find a session for %s of %s", deviceID, userID)
log.Warn().Err(err).Msg("Didn't find olm session to encrypt group session")
if missingOutput != nil {
missingOutput[deviceID] = device
}

View File

@@ -1,4 +1,4 @@
// Copyright (c) 2020 Tulir Asokan
// Copyright (c) 2023 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
@@ -7,6 +7,7 @@
package crypto
import (
"context"
"encoding/json"
"fmt"
@@ -16,7 +17,7 @@ import (
"maunium.net/go/mautrix/id"
)
func (mach *OlmMachine) encryptOlmEvent(session *OlmSession, recipient *id.Device, evtType event.Type, content event.Content) *event.EncryptedEventContent {
func (mach *OlmMachine) encryptOlmEvent(ctx context.Context, session *OlmSession, recipient *id.Device, evtType event.Type, content event.Content) *event.EncryptedEventContent {
evt := &DecryptedOlmEvent{
Sender: mach.Client.UserID,
SenderDevice: mach.Client.DeviceID,
@@ -30,11 +31,16 @@ func (mach *OlmMachine) encryptOlmEvent(session *OlmSession, recipient *id.Devic
if err != nil {
panic(err)
}
mach.Log.Trace("Encrypting olm message for %s with session %s: %s", recipient.IdentityKey, session.ID(), session.Describe())
log := mach.machOrContextLog(ctx)
log.Debug().
Str("recipient_identity_key", recipient.IdentityKey.String()).
Str("olm_session_id", session.ID().String()).
Str("olm_session_description", session.Describe()).
Msg("Encrypting olm message")
msgType, ciphertext := session.Encrypt(plaintext)
err = mach.CryptoStore.UpdateSession(recipient.IdentityKey, session)
if err != nil {
mach.Log.Warn("Failed to update olm session in crypto store after encrypting: %v", err)
log.Error().Err(err).Msg("Failed to update olm session in crypto store after encrypting")
}
return &event.EncryptedEventContent{
Algorithm: id.AlgorithmOlmV1,
@@ -61,7 +67,7 @@ func (mach *OlmMachine) shouldCreateNewSession(identityKey id.IdentityKey) bool
return shouldUnwedge
}
func (mach *OlmMachine) createOutboundSessions(input map[id.UserID]map[id.DeviceID]*id.Device) error {
func (mach *OlmMachine) createOutboundSessions(ctx context.Context, input map[id.UserID]map[id.DeviceID]*id.Device) error {
request := make(mautrix.OneTimeKeysRequest)
for userID, devices := range input {
request[userID] = make(map[id.DeviceID]id.KeyAlgorithm)
@@ -84,6 +90,7 @@ func (mach *OlmMachine) createOutboundSessions(input map[id.UserID]map[id.Device
if err != nil {
return fmt.Errorf("failed to claim keys: %w", err)
}
log := mach.machOrContextLog(ctx)
for userID, user := range resp.OneTimeKeys {
for deviceID, oneTimeKeys := range user {
var oneTimeKey mautrix.OneTimeKey
@@ -91,25 +98,30 @@ func (mach *OlmMachine) createOutboundSessions(input map[id.UserID]map[id.Device
for keyID, oneTimeKey = range oneTimeKeys {
break
}
keyAlg, keyIndex := keyID.Parse()
log := log.With().
Str("peer_user_id", userID.String()).
Str("peer_device_id", deviceID.String()).
Str("peer_otk_id", keyID.String()).
Logger()
keyAlg, _ := keyID.Parse()
if keyAlg != id.KeyAlgorithmSignedCurve25519 {
mach.Log.Warn("Unexpected key ID algorithm in one-time key response for %s of %s: %s", deviceID, userID, keyID)
log.Warn().Msg("Unexpected key ID algorithm in one-time key response")
continue
}
identity := input[userID][deviceID]
if ok, err := olm.VerifySignatureJSON(oneTimeKey.RawData, userID, deviceID.String(), identity.SigningKey); err != nil {
mach.Log.Error("Failed to verify signature for %s of %s: %v", deviceID, userID, err)
log.Error().Err(err).Msg("Failed to verify signature of one-time key")
} else if !ok {
mach.Log.Warn("Invalid signature for %s of %s", deviceID, userID)
log.Warn().Msg("One-time key has invalid signature from device")
} else if sess, err := mach.account.Internal.NewOutboundSession(identity.IdentityKey, oneTimeKey.Key); err != nil {
mach.Log.Error("Failed to create outbound session for %s of %s: %v", deviceID, userID, err)
log.Error().Err(err).Msg("Failed to create outbound session with claimed one-time key")
} else {
wrapped := wrapSession(sess)
err = mach.CryptoStore.AddSession(identity.IdentityKey, wrapped)
if err != nil {
mach.Log.Error("Failed to store created session for %s of %s: %v", deviceID, userID, err)
log.Error().Err(err).Msg("Failed to store created outbound session")
} else {
mach.Log.Debug("Created new Olm session with %s/%s (OTK ID: %s)", userID, deviceID, keyIndex)
log.Debug().Msg("Created new Olm session")
}
}
}

View File

@@ -103,7 +103,7 @@ func exportSessions(sessions []*InboundGroupSession) ([]ExportedSession, error)
SenderKey: session.SenderKey,
SenderClaimedKeys: SenderClaimedKeys{},
SessionID: session.ID(),
SessionKey: key,
SessionKey: string(key),
}
}
return export, nil

View File

@@ -1,4 +1,4 @@
// Copyright (c) 2020 Tulir Asokan
// Copyright (c) 2023 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
@@ -17,6 +17,7 @@ import (
"encoding/json"
"errors"
"fmt"
"time"
"maunium.net/go/mautrix/crypto/olm"
"maunium.net/go/mautrix/id"
@@ -108,6 +109,8 @@ func (mach *OlmMachine) importExportedRoomKey(session ExportedSession) (bool, er
RoomID: session.RoomID,
// TODO should we add something here to mark the signing key as unverified like key requests do?
ForwardingChains: session.ForwardingChains,
ReceivedAt: time.Now().UTC(),
}
existingIGS, _ := mach.CryptoStore.GetGroupSession(igs.RoomID, igs.SenderKey, igs.ID())
if existingIGS != nil && existingIGS.Internal.FirstKnownIndex() <= igs.Internal.FirstKnownIndex() {
@@ -136,14 +139,18 @@ func (mach *OlmMachine) ImportKeys(passphrase string, data []byte) (int, int, er
count := 0
for _, session := range sessions {
log := mach.Log.With().
Str("room_id", session.RoomID.String()).
Str("session_id", session.SessionID.String()).
Logger()
imported, err := mach.importExportedRoomKey(session)
if err != nil {
mach.Log.Warn("Failed to import Megolm session %s/%s from file: %v", session.RoomID, session.SessionID, err)
log.Error().Err(err).Msg("Failed to import Megolm session from file")
} else if imported {
mach.Log.Debug("Imported Megolm session %s/%s from file", session.RoomID, session.SessionID)
log.Debug().Msg("Imported Megolm session from file")
count++
} else {
mach.Log.Debug("Skipped Megolm session %s/%s: already in store", session.RoomID, session.SessionID)
log.Debug().Msg("Skipped Megolm session which is already in the store")
}
}
return count, len(sessions), nil

View File

@@ -1,16 +1,18 @@
// Copyright (c) 2020 Nikos Filippakis
// Copyright (c) 2023 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
//go:build !nosas
// +build !nosas
package crypto
import (
"context"
"errors"
"time"
"github.com/rs/zerolog"
"maunium.net/go/mautrix/crypto/olm"
"maunium.net/go/mautrix/id"
@@ -56,11 +58,11 @@ func (mach *OlmMachine) RequestRoomKey(ctx context.Context, toUser id.UserID, to
select {
case <-keyResponseReceived:
// key request successful
mach.Log.Debug("Key for session %v was received, cancelling other key requests", sessionID)
mach.Log.Debug().Msgf("Key for session %v was received, cancelling other key requests", sessionID)
resChan <- true
case <-ctx.Done():
// if the context is done, key request was unsuccessful
mach.Log.Debug("Context closed (%v) before forwared key for session %v received, sending key request cancellation", ctx.Err(), sessionID)
mach.Log.Debug().Msgf("Context closed (%v) before forwared key for session %v received, sending key request cancellation", ctx.Err(), sessionID)
resChan <- false
}
@@ -128,20 +130,41 @@ func (mach *OlmMachine) SendRoomKeyRequest(roomID id.RoomID, senderKey id.Sender
return err
}
func (mach *OlmMachine) importForwardedRoomKey(evt *DecryptedOlmEvent, content *event.ForwardedRoomKeyEventContent) bool {
func (mach *OlmMachine) importForwardedRoomKey(ctx context.Context, evt *DecryptedOlmEvent, content *event.ForwardedRoomKeyEventContent) bool {
log := zerolog.Ctx(ctx).With().
Str("session_id", content.SessionID.String()).
Str("room_id", content.RoomID.String()).
Logger()
if content.Algorithm != id.AlgorithmMegolmV1 || evt.Keys.Ed25519 == "" {
mach.Log.Debug("Ignoring weird forwarded room key from %s/%s: alg=%s, ed25519=%s, sessionid=%s, roomid=%s", evt.Sender, evt.SenderDevice, content.Algorithm, evt.Keys.Ed25519, content.SessionID, content.RoomID)
log.Debug().
Str("algorithm", string(content.Algorithm)).
Msg("Ignoring weird forwarded room key")
return false
}
igsInternal, err := olm.InboundGroupSessionImport([]byte(content.SessionKey))
if err != nil {
mach.Log.Error("Failed to import inbound group session: %v", err)
log.Error().Err(err).Msg("Failed to import inbound group session")
return false
} else if igsInternal.ID() != content.SessionID {
mach.Log.Warn("Mismatched session ID while creating inbound group session")
log.Warn().
Str("actual_session_id", igsInternal.ID().String()).
Msg("Mismatched session ID while creating inbound group session from forward")
return false
}
config := mach.StateStore.GetEncryptionEvent(content.RoomID)
var maxAge time.Duration
var maxMessages int
if config != nil {
maxAge = time.Duration(config.RotationPeriodMillis) * time.Millisecond
maxMessages = config.RotationPeriodMessages
}
if content.MaxAge != 0 {
maxAge = time.Duration(content.MaxAge) * time.Millisecond
}
if content.MaxMessages != 0 {
maxMessages = content.MaxMessages
}
igs := &InboundGroupSession{
Internal: *igsInternal,
SigningKey: evt.Keys.Ed25519,
@@ -149,14 +172,19 @@ func (mach *OlmMachine) importForwardedRoomKey(evt *DecryptedOlmEvent, content *
RoomID: content.RoomID,
ForwardingChains: append(content.ForwardingKeyChain, evt.SenderKey.String()),
id: content.SessionID,
ReceivedAt: time.Now().UTC(),
MaxAge: maxAge.Milliseconds(),
MaxMessages: maxMessages,
IsScheduled: content.IsScheduled,
}
err = mach.CryptoStore.PutGroupSession(content.RoomID, content.SenderKey, content.SessionID, igs)
if err != nil {
mach.Log.Error("Failed to store new inbound group session: %v", err)
log.Error().Err(err).Msg("Failed to store new inbound group session")
return false
}
mach.markSessionReceived(content.SessionID)
mach.Log.Trace("Received forwarded inbound group session %s/%s/%s", content.RoomID, content.SenderKey, content.SessionID)
log.Debug().Msg("Received forwarded inbound group session")
return true
}
@@ -175,50 +203,72 @@ func (mach *OlmMachine) rejectKeyRequest(rejection KeyShareRejection, device *id
}
err := mach.sendToOneDevice(device.UserID, device.DeviceID, event.ToDeviceRoomKeyWithheld, &content)
if err != nil {
mach.Log.Warn("Failed to send key share rejection %s to %s/%s: %v", rejection.Code, device.UserID, device.DeviceID, err)
mach.Log.Warn().Err(err).
Str("code", string(rejection.Code)).
Str("user_id", device.UserID.String()).
Str("device_id", device.DeviceID.String()).
Msg("Failed to send key share rejection")
}
err = mach.sendToOneDevice(device.UserID, device.DeviceID, event.ToDeviceOrgMatrixRoomKeyWithheld, &content)
if err != nil {
mach.Log.Warn("Failed to send key share rejection %s (org.matrix.) to %s/%s: %v", rejection.Code, device.UserID, device.DeviceID, err)
mach.Log.Warn().Err(err).
Str("code", string(rejection.Code)).
Str("user_id", device.UserID.String()).
Str("device_id", device.DeviceID.String()).
Msg("Failed to send key share rejection (legacy event type)")
}
}
func (mach *OlmMachine) defaultAllowKeyShare(device *id.Device, _ event.RequestedKeyInfo) *KeyShareRejection {
func (mach *OlmMachine) defaultAllowKeyShare(ctx context.Context, device *id.Device, _ event.RequestedKeyInfo) *KeyShareRejection {
log := mach.machOrContextLog(ctx)
if mach.Client.UserID != device.UserID {
mach.Log.Debug("Ignoring key request from a different user (%s)", device.UserID)
log.Debug().Msg("Rejecting key request from a different user")
return &KeyShareRejectOtherUser
} else if mach.Client.DeviceID == device.DeviceID {
mach.Log.Debug("Ignoring key request from ourselves")
log.Debug().Msg("Ignoring key request from ourselves")
return &KeyShareRejectNoResponse
} else if device.Trust == id.TrustStateBlacklisted {
mach.Log.Debug("Ignoring key request from blacklisted device %s", device.DeviceID)
log.Debug().Msg("Rejecting key request from blacklisted device")
return &KeyShareRejectBlacklisted
} else if trustState := mach.ResolveTrust(device); trustState >= mach.ShareKeysMinTrust {
mach.Log.Debug("Accepting key request from device %s (trust state: %s)", device.DeviceID, trustState)
log.Debug().
Str("min_trust", mach.SendKeysMinTrust.String()).
Str("device_trust", trustState.String()).
Msg("Accepting key request from trusted device")
return nil
} else {
mach.Log.Debug("Ignoring key request from unverified device %s (trust state: %s)", device.DeviceID, trustState)
log.Debug().
Str("min_trust", mach.SendKeysMinTrust.String()).
Str("device_trust", trustState.String()).
Msg("Rejecting key request from untrusted device")
return &KeyShareRejectUnverified
}
}
func (mach *OlmMachine) handleRoomKeyRequest(sender id.UserID, content *event.RoomKeyRequestEventContent) {
func (mach *OlmMachine) handleRoomKeyRequest(ctx context.Context, sender id.UserID, content *event.RoomKeyRequestEventContent) {
log := zerolog.Ctx(ctx).With().
Str("request_id", content.RequestID).
Str("device_id", content.RequestingDeviceID.String()).
Str("room_id", content.Body.RoomID.String()).
Str("session_id", content.Body.SessionID.String()).
Logger()
ctx = log.WithContext(ctx)
if content.Action != event.KeyRequestActionRequest {
return
} else if content.RequestingDeviceID == mach.Client.DeviceID && sender == mach.Client.UserID {
mach.Log.Debug("Ignoring key request %s from ourselves", content.RequestID)
log.Debug().Msg("Ignoring key request from ourselves")
return
}
mach.Log.Debug("Received key request %s for %s from %s/%s", content.RequestID, content.Body.SessionID, sender, content.RequestingDeviceID)
log.Debug().Msg("Received key request")
device, err := mach.GetOrFetchDevice(sender, content.RequestingDeviceID)
device, err := mach.GetOrFetchDevice(ctx, sender, content.RequestingDeviceID)
if err != nil {
mach.Log.Error("Failed to fetch device %s/%s that requested keys: %v", sender, content.RequestingDeviceID, err)
log.Error().Err(err).Msg("Failed to fetch device that requested keys")
return
}
rejection := mach.AllowKeyShare(device, content.Body)
rejection := mach.AllowKeyShare(ctx, device, content.Body)
if rejection != nil {
mach.rejectKeyRequest(*rejection, device, content.Body)
return
@@ -226,18 +276,29 @@ func (mach *OlmMachine) handleRoomKeyRequest(sender id.UserID, content *event.Ro
igs, err := mach.CryptoStore.GetGroupSession(content.Body.RoomID, content.Body.SenderKey, content.Body.SessionID)
if err != nil {
mach.Log.Error("Failed to fetch group session to forward to %s/%s: %v", device.UserID, device.DeviceID, err)
mach.rejectKeyRequest(KeyShareRejectInternalError, device, content.Body)
if errors.Is(err, ErrGroupSessionWithheld) {
log.Debug().Err(err).Msg("Requested group session not available")
mach.rejectKeyRequest(KeyShareRejectUnavailable, device, content.Body)
} else {
log.Error().Err(err).Msg("Failed to get group session to forward")
mach.rejectKeyRequest(KeyShareRejectInternalError, device, content.Body)
}
return
} else if igs == nil {
mach.Log.Warn("Didn't find group session %s to forward to %s/%s", content.Body.SessionID, device.UserID, device.DeviceID)
log.Error().Msg("Didn't find group session to forward")
mach.rejectKeyRequest(KeyShareRejectUnavailable, device, content.Body)
return
}
if internalID := igs.ID(); internalID != content.Body.SessionID {
// Should this be an error?
log = log.With().Str("unexpected_session_id", internalID.String()).Logger()
}
exportedKey, err := igs.Internal.Export(igs.Internal.FirstKnownIndex())
firstKnownIndex := igs.Internal.FirstKnownIndex()
log = log.With().Uint32("first_known_index", firstKnownIndex).Logger()
exportedKey, err := igs.Internal.Export(firstKnownIndex)
if err != nil {
mach.Log.Error("Failed to export session %s to forward to %s/%s: %v", igs.ID(), device.UserID, device.DeviceID, err)
log.Error().Err(err).Msg("Failed to export group session to forward")
mach.rejectKeyRequest(KeyShareRejectInternalError, device, content.Body)
return
}
@@ -248,7 +309,7 @@ func (mach *OlmMachine) handleRoomKeyRequest(sender id.UserID, content *event.Ro
Algorithm: id.AlgorithmMegolmV1,
RoomID: igs.RoomID,
SessionID: igs.ID(),
SessionKey: exportedKey,
SessionKey: string(exportedKey),
},
SenderKey: content.Body.SenderKey,
ForwardingKeyChain: igs.ForwardingChains,
@@ -256,9 +317,42 @@ func (mach *OlmMachine) handleRoomKeyRequest(sender id.UserID, content *event.Ro
},
}
if err := mach.SendEncryptedToDevice(device, event.ToDeviceForwardedRoomKey, forwardedRoomKey); err != nil {
mach.Log.Error("Failed to send encrypted forwarded key %s to %s/%s: %v", igs.ID(), device.UserID, device.DeviceID, err)
if err = mach.SendEncryptedToDevice(ctx, device, event.ToDeviceForwardedRoomKey, forwardedRoomKey); err != nil {
log.Error().Err(err).Msg("Failed to encrypt and send group session")
} else {
log.Debug().Msg("Successfully sent forwarded group session")
}
}
func (mach *OlmMachine) handleBeeperRoomKeyAck(ctx context.Context, sender id.UserID, content *event.BeeperRoomKeyAckEventContent) {
log := mach.machOrContextLog(ctx).With().
Str("room_id", content.RoomID.String()).
Str("session_id", content.SessionID.String()).
Int("first_message_index", content.FirstMessageIndex).
Logger()
sess, err := mach.CryptoStore.GetGroupSession(content.RoomID, "", content.SessionID)
if err != nil {
if errors.Is(err, ErrGroupSessionWithheld) {
log.Debug().Err(err).Msg("Acked group session was already redacted")
} else {
log.Err(err).Msg("Failed to get group session to check if it should be redacted")
}
return
}
log = log.With().
Str("sender_key", sess.SenderKey.String()).
Str("own_identity", mach.OwnIdentity().IdentityKey.String()).
Logger()
isInbound := sess.SenderKey == mach.OwnIdentity().IdentityKey
if isInbound && mach.DeleteOutboundKeysOnAck && content.FirstMessageIndex == 0 {
log.Debug().Msg("Redacting inbound copy of outbound group session after ack")
err = mach.CryptoStore.RedactGroupSession(content.RoomID, sess.SenderKey, content.SessionID, "outbound session acked")
if err != nil {
log.Err(err).Msg("Failed to redact group session")
}
} else {
log.Debug().Bool("inbound", isInbound).Msg("Received room key ack")
}
mach.Log.Debug("Sent encrypted forwarded key to device %s/%s for session %s", device.UserID, device.DeviceID, igs.ID())
}

View File

@@ -1,4 +1,4 @@
// Copyright (c) 2020 Tulir Asokan
// Copyright (c) 2023 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
@@ -7,12 +7,14 @@
package crypto
import (
"context"
"errors"
"fmt"
"sync"
"time"
"maunium.net/go/mautrix/appservice"
"github.com/rs/zerolog"
"maunium.net/go/mautrix/crypto/ssss"
"maunium.net/go/mautrix/id"
@@ -20,28 +22,21 @@ import (
"maunium.net/go/mautrix/event"
)
// Logger is a simple logging struct for OlmMachine.
// Implementations are recommended to use fmt.Sprintf and manually add a newline after the message.
type Logger interface {
Error(message string, args ...interface{})
Warn(message string, args ...interface{})
Debug(message string, args ...interface{})
Trace(message string, args ...interface{})
}
// OlmMachine is the main struct for handling Matrix end-to-end encryption.
type OlmMachine struct {
Client *mautrix.Client
SSSS *ssss.Machine
Log Logger
Log *zerolog.Logger
CryptoStore Store
StateStore StateStore
PlaintextMentions bool
SendKeysMinTrust id.TrustState
ShareKeysMinTrust id.TrustState
AllowKeyShare func(*id.Device, event.RequestedKeyInfo) *KeyShareRejection
AllowKeyShare func(context.Context, *id.Device, event.RequestedKeyInfo) *KeyShareRejection
DefaultSASTimeout time.Duration
// AcceptVerificationFrom determines whether the machine will accept verification requests from this device.
@@ -60,7 +55,9 @@ type OlmMachine struct {
recentlyUnwedged map[id.IdentityKey]time.Time
recentlyUnwedgedLock sync.Mutex
olmLock sync.Mutex
olmLock sync.Mutex
megolmEncryptLock sync.Mutex
megolmDecryptLock sync.Mutex
otkUploadLock sync.Mutex
lastOTKUpload time.Time
@@ -69,6 +66,13 @@ type OlmMachine struct {
crossSigningPubkeys *CrossSigningPublicKeysCache
crossSigningPubkeysFetched bool
DeleteOutboundKeysOnAck bool
DontStoreOutboundKeys bool
DeletePreviousKeysOnReceive bool
RatchetKeysOnDecrypt bool
DeleteFullyUsedKeysOnDecrypt bool
DeleteKeysOnDeviceDelete bool
}
// StateStore is used by OlmMachine to get room state information that's needed for encryption.
@@ -82,7 +86,11 @@ type StateStore interface {
}
// NewOlmMachine creates an OlmMachine with the given client, logger and stores.
func NewOlmMachine(client *mautrix.Client, log Logger, cryptoStore Store, stateStore StateStore) *OlmMachine {
func NewOlmMachine(client *mautrix.Client, log *zerolog.Logger, cryptoStore Store, stateStore StateStore) *OlmMachine {
if log == nil {
logPtr := zerolog.Nop()
log = &logPtr
}
mach := &OlmMachine{
Client: client,
SSSS: ssss.NewSSSSMachine(client),
@@ -111,6 +119,14 @@ func NewOlmMachine(client *mautrix.Client, log Logger, cryptoStore Store, stateS
return mach
}
func (mach *OlmMachine) machOrContextLog(ctx context.Context) *zerolog.Logger {
log := zerolog.Ctx(ctx)
if log.GetLevel() == zerolog.Disabled || log == zerolog.DefaultContextLogger {
return mach.Log
}
return log
}
// Load loads the Olm account information from the crypto store. If there's no olm account, a new one is created.
// This must be called before using the machine.
func (mach *OlmMachine) Load() (err error) {
@@ -127,7 +143,7 @@ func (mach *OlmMachine) Load() (err error) {
func (mach *OlmMachine) saveAccount() {
err := mach.CryptoStore.PutAccount(mach.account)
if err != nil {
mach.Log.Error("Failed to save account: %v", err)
mach.Log.Error().Err(err).Msg("Failed to save account")
}
}
@@ -136,12 +152,15 @@ func (mach *OlmMachine) FlushStore() error {
return mach.CryptoStore.Flush()
}
func (mach *OlmMachine) timeTrace(thing, trace string, expectedDuration time.Duration) func() {
func (mach *OlmMachine) timeTrace(ctx context.Context, thing string, expectedDuration time.Duration) func() {
start := time.Now()
return func() {
duration := time.Now().Sub(start)
if duration > expectedDuration {
mach.Log.Warn("%s took %s (trace: %s)", thing, duration, trace)
zerolog.Ctx(ctx).Warn().
Str("action", thing).
Dur("duration", duration).
Msg("Executing encryption function took longer than expected")
}
}
}
@@ -156,7 +175,11 @@ func (mach *OlmMachine) Fingerprint() string {
return mach.account.SigningKey().Fingerprint()
}
// OwnIdentity returns this device's DeviceIdentity struct
func (mach *OlmMachine) GetAccount() *OlmAccount {
return mach.account
}
// OwnIdentity returns this device's id.Device struct
func (mach *OlmMachine) OwnIdentity() *id.Device {
return &id.Device{
UserID: mach.Client.UserID,
@@ -168,11 +191,18 @@ func (mach *OlmMachine) OwnIdentity() *id.Device {
}
}
func (mach *OlmMachine) AddAppserviceListener(ep *appservice.EventProcessor, az *appservice.AppService) {
type asEventProcessor interface {
On(evtType event.Type, handler func(evt *event.Event))
OnOTK(func(otk *mautrix.OTKCount))
OnDeviceList(func(lists *mautrix.DeviceLists, since string))
}
func (mach *OlmMachine) AddAppserviceListener(ep asEventProcessor) {
// ToDeviceForwardedRoomKey and ToDeviceRoomKey should only be present inside encrypted to-device events
ep.On(event.ToDeviceEncrypted, mach.HandleToDeviceEvent)
ep.On(event.ToDeviceRoomKeyRequest, mach.HandleToDeviceEvent)
ep.On(event.ToDeviceRoomKeyWithheld, mach.HandleToDeviceEvent)
ep.On(event.ToDeviceBeeperRoomKeyAck, mach.HandleToDeviceEvent)
ep.On(event.ToDeviceOrgMatrixRoomKeyWithheld, mach.HandleToDeviceEvent)
ep.On(event.ToDeviceVerificationRequest, mach.HandleToDeviceEvent)
ep.On(event.ToDeviceVerificationStart, mach.HandleToDeviceEvent)
@@ -182,34 +212,44 @@ func (mach *OlmMachine) AddAppserviceListener(ep *appservice.EventProcessor, az
ep.On(event.ToDeviceVerificationCancel, mach.HandleToDeviceEvent)
ep.OnOTK(mach.HandleOTKCounts)
ep.OnDeviceList(mach.HandleDeviceLists)
mach.Log.Trace("Added listeners for encryption data coming from appservice transactions")
mach.Log.Debug().Msg("Added listeners for encryption data coming from appservice transactions")
}
func (mach *OlmMachine) HandleDeviceLists(dl *mautrix.DeviceLists, since string) {
if len(dl.Changed) > 0 {
traceID := time.Now().Format("15:04:05.000000")
mach.Log.Trace("Device list changes in /sync: %v (trace: %s)", dl.Changed, traceID)
mach.fetchKeys(dl.Changed, since, false)
mach.Log.Trace("Finished handling device list changes (trace: %s)", traceID)
mach.Log.Debug().
Str("trace_id", traceID).
Interface("changes", dl.Changed).
Msg("Device list changes in /sync")
mach.fetchKeys(context.TODO(), dl.Changed, since, false)
mach.Log.Debug().Str("trace_id", traceID).Msg("Finished handling device list changes")
}
}
func (mach *OlmMachine) HandleOTKCounts(otkCount *mautrix.OTKCount) {
if (len(otkCount.UserID) > 0 && otkCount.UserID != mach.Client.UserID) || (len(otkCount.DeviceID) > 0 && otkCount.DeviceID != mach.Client.DeviceID) {
// TODO This log probably needs to be silence-able if someone wants to use encrypted appservices with multiple e2ee sessions
mach.Log.Debug("Dropping OTK counts targeted to %s/%s (not us)", otkCount.UserID, otkCount.DeviceID)
mach.Log.Warn().
Str("target_user_id", otkCount.UserID.String()).
Str("target_device_id", otkCount.DeviceID.String()).
Msg("Dropping OTK counts targeted to someone else")
return
}
minCount := mach.account.Internal.MaxNumberOfOneTimeKeys() / 2
if otkCount.SignedCurve25519 < int(minCount) {
traceID := time.Now().Format("15:04:05.000000")
mach.Log.Debug("Sync response said we have %d signed curve25519 keys left, sharing new ones... (trace: %s)", otkCount.SignedCurve25519, traceID)
err := mach.ShareKeys(otkCount.SignedCurve25519)
log := mach.Log.With().Str("trace_id", traceID).Logger()
ctx := log.WithContext(context.Background())
log.Debug().
Int("keys_left", otkCount.Curve25519).
Msg("Sync response said we have less than 50 signed curve25519 keys left, sharing new ones...")
err := mach.ShareKeys(ctx, otkCount.SignedCurve25519)
if err != nil {
mach.Log.Error("Failed to share keys: %v (trace: %s)", err, traceID)
log.Error().Err(err).Msg("Failed to share keys")
} else {
mach.Log.Debug("Successfully shared keys (trace: %s)", traceID)
log.Debug().Msg("Successfully shared keys")
}
}
}
@@ -218,7 +258,7 @@ func (mach *OlmMachine) HandleOTKCounts(otkCount *mautrix.OTKCount) {
//
// This can be easily registered into a mautrix client using .OnSync():
//
// client.Syncer.(*mautrix.DefaultSyncer).OnSync(c.crypto.ProcessSyncResponse)
// client.Syncer.(mautrix.ExtensibleSyncer).OnSync(c.crypto.ProcessSyncResponse)
func (mach *OlmMachine) ProcessSyncResponse(resp *mautrix.RespSync, since string) bool {
mach.HandleDeviceLists(&resp.DeviceLists, since)
@@ -226,7 +266,7 @@ func (mach *OlmMachine) ProcessSyncResponse(resp *mautrix.RespSync, since string
evt.Type.Class = event.ToDeviceEventType
err := evt.Content.ParseRaw(evt.Type)
if err != nil {
mach.Log.Warn("Failed to parse to-device event of type %s: %v", evt.Type.Type, err)
mach.Log.Warn().Str("event_type", evt.Type.Type).Err(err).Msg("Failed to parse to-device event")
continue
}
mach.HandleToDeviceEvent(evt)
@@ -240,8 +280,8 @@ func (mach *OlmMachine) ProcessSyncResponse(resp *mautrix.RespSync, since string
//
// Currently this is not automatically called, so you must add a listener yourself:
//
// client.Syncer.(*mautrix.DefaultSyncer).OnEventType(event.StateMember, c.crypto.HandleMemberEvent)
func (mach *OlmMachine) HandleMemberEvent(evt *event.Event) {
// client.Syncer.(mautrix.ExtensibleSyncer).OnEventType(event.StateMember, c.crypto.HandleMemberEvent)
func (mach *OlmMachine) HandleMemberEvent(_ mautrix.EventSource, evt *event.Event) {
if !mach.StateStore.IsEncrypted(evt.RoomID) {
return
}
@@ -263,10 +303,15 @@ func (mach *OlmMachine) HandleMemberEvent(evt *event.Event) {
(prevContent.Membership == event.MembershipLeave && content.Membership == event.MembershipBan) {
return
}
mach.Log.Trace("Got membership state event in %s changing %s from %s to %s, invalidating group session", evt.RoomID, evt.GetStateKey(), prevContent.Membership, content.Membership)
mach.Log.Trace().
Str("room_id", evt.RoomID.String()).
Str("user_id", evt.GetStateKey()).
Str("prev_membership", string(prevContent.Membership)).
Str("new_membership", string(content.Membership)).
Msg("Got membership state change, invalidating group session in room")
err := mach.CryptoStore.RemoveOutboundGroupSession(evt.RoomID)
if err != nil {
mach.Log.Warn("Failed to invalidate outbound group session of %s: %v", evt.RoomID, err)
mach.Log.Warn().Str("room_id", evt.RoomID.String()).Msg("Failed to invalidate outbound group session")
}
}
@@ -275,43 +320,63 @@ func (mach *OlmMachine) HandleMemberEvent(evt *event.Event) {
func (mach *OlmMachine) HandleToDeviceEvent(evt *event.Event) {
if len(evt.ToUserID) > 0 && (evt.ToUserID != mach.Client.UserID || evt.ToDeviceID != mach.Client.DeviceID) {
// TODO This log probably needs to be silence-able if someone wants to use encrypted appservices with multiple e2ee sessions
mach.Log.Debug("Dropping to-device event targeted to %s/%s (not us)", evt.ToUserID, evt.ToDeviceID)
mach.Log.Debug().
Str("target_user_id", evt.ToUserID.String()).
Str("target_device_id", evt.ToDeviceID.String()).
Msg("Dropping to-device event targeted to someone else")
return
}
traceID := time.Now().Format("15:04:05.000000")
log := mach.Log.With().
Str("trace_id", traceID).
Str("sender", evt.Sender.String()).
Str("type", evt.Type.Type).
Logger()
ctx := log.WithContext(context.Background())
if evt.Type != event.ToDeviceEncrypted {
mach.Log.Trace("Starting handling to-device event of type %s from %s (trace: %s)", evt.Type.Type, evt.Sender, traceID)
log.Debug().Msg("Starting handling to-device event")
}
switch content := evt.Content.Parsed.(type) {
case *event.EncryptedEventContent:
mach.Log.Debug("Handling encrypted to-device event from %s/%s (trace: %s)", evt.Sender, content.SenderKey, traceID)
decryptedEvt, err := mach.decryptOlmEvent(evt, traceID)
log = log.With().
Str("sender_key", content.SenderKey.String()).
Logger()
log.Debug().Msg("Handling encrypted to-device event")
ctx = log.WithContext(context.Background())
decryptedEvt, err := mach.decryptOlmEvent(ctx, evt)
if err != nil {
mach.Log.Error("Failed to decrypt to-device event: %v (trace: %s)", err, traceID)
log.Error().Err(err).Msg("Failed to decrypt to-device event")
return
}
mach.Log.Trace("Successfully decrypted to-device from %s/%s into type %s (sender key: %s, trace: %s)", decryptedEvt.Sender, decryptedEvt.SenderDevice, decryptedEvt.Type.String(), decryptedEvt.SenderKey, traceID)
log = log.With().
Str("decrypted_type", decryptedEvt.Type.Type).
Str("sender_device", decryptedEvt.SenderDevice.String()).
Str("sender_signing_key", decryptedEvt.Keys.Ed25519.String()).
Logger()
log.Trace().Msg("Successfully decrypted to-device event")
switch decryptedContent := decryptedEvt.Content.Parsed.(type) {
case *event.RoomKeyEventContent:
mach.receiveRoomKey(decryptedEvt, decryptedContent, traceID)
mach.Log.Trace("Handled room key event from %s/%s (trace: %s)", decryptedEvt.Sender, decryptedEvt.SenderDevice, traceID)
mach.receiveRoomKey(ctx, decryptedEvt, decryptedContent)
log.Trace().Msg("Handled room key event")
case *event.ForwardedRoomKeyEventContent:
if mach.importForwardedRoomKey(decryptedEvt, decryptedContent) {
if mach.importForwardedRoomKey(ctx, decryptedEvt, decryptedContent) {
if ch, ok := mach.roomKeyRequestFilled.Load(decryptedContent.SessionID); ok {
// close channel to notify listener that the key was received
close(ch.(chan struct{}))
}
}
mach.Log.Trace("Handled forwarded room key event from %s/%s (trace: %s)", decryptedEvt.Sender, decryptedEvt.SenderDevice, traceID)
log.Trace().Msg("Handled forwarded room key event")
case *event.DummyEventContent:
mach.Log.Debug("Received encrypted dummy event from %s/%s (trace: %s)", decryptedEvt.Sender, decryptedEvt.SenderDevice, traceID)
log.Debug().Msg("Received encrypted dummy event")
default:
mach.Log.Debug("Unhandled encrypted to-device event of type %s from %s/%s (trace: %s)", decryptedEvt.Type.String(), decryptedEvt.Sender, decryptedEvt.SenderDevice, traceID)
log.Debug().Msg("Unhandled encrypted to-device event")
}
return
case *event.RoomKeyRequestEventContent:
go mach.handleRoomKeyRequest(evt.Sender, content)
go mach.handleRoomKeyRequest(ctx, evt.Sender, content)
case *event.BeeperRoomKeyAckEventContent:
mach.handleBeeperRoomKeyAck(ctx, evt.Sender, content)
// verification cases
case *event.VerificationStartEventContent:
mach.handleVerificationStart(evt.Sender, content, content.TransactionID, 10*time.Minute, "")
@@ -326,27 +391,25 @@ func (mach *OlmMachine) HandleToDeviceEvent(evt *event.Event) {
case *event.VerificationRequestEventContent:
mach.handleVerificationRequest(evt.Sender, content, content.TransactionID, "")
case *event.RoomKeyWithheldEventContent:
mach.handleRoomKeyWithheld(content)
mach.handleRoomKeyWithheld(ctx, content)
default:
deviceID, _ := evt.Content.Raw["device_id"].(string)
mach.Log.Trace("Unhandled to-device event of type %s from %s/%s (trace: %s)", evt.Type.Type, evt.Sender, deviceID, traceID)
log.Debug().Str("maybe_device_id", deviceID).Msg("Unhandled to-device event")
return
}
mach.Log.Trace("Finished handling to-device event of type %s from %s (trace: %s)", evt.Type.Type, evt.Sender, traceID)
log.Debug().Msg("Finished handling to-device event")
}
// GetOrFetchDevice attempts to retrieve the device identity for the given device from the store
// and if it's not found it asks the server for it.
func (mach *OlmMachine) GetOrFetchDevice(userID id.UserID, deviceID id.DeviceID) (*id.Device, error) {
// get device identity
func (mach *OlmMachine) GetOrFetchDevice(ctx context.Context, userID id.UserID, deviceID id.DeviceID) (*id.Device, error) {
device, err := mach.CryptoStore.GetDevice(userID, deviceID)
if err != nil {
return nil, fmt.Errorf("failed to get sender device from store: %w", err)
} else if device != nil {
return device, nil
}
// try to fetch if not found
usersToDevices := mach.fetchKeys([]id.UserID{userID}, "", true)
usersToDevices := mach.fetchKeys(ctx, []id.UserID{userID}, "", true)
if devices, ok := usersToDevices[userID]; ok {
if device, ok = devices[deviceID]; ok {
return device, nil
@@ -359,12 +422,15 @@ func (mach *OlmMachine) GetOrFetchDevice(userID id.UserID, deviceID id.DeviceID)
// GetOrFetchDeviceByKey attempts to retrieve the device identity for the device with the given identity key from the
// store and if it's not found it asks the server for it. This returns nil if the server doesn't return a device with
// the given identity key.
func (mach *OlmMachine) GetOrFetchDeviceByKey(userID id.UserID, identityKey id.IdentityKey) (*id.Device, error) {
func (mach *OlmMachine) GetOrFetchDeviceByKey(ctx context.Context, userID id.UserID, identityKey id.IdentityKey) (*id.Device, error) {
deviceIdentity, err := mach.CryptoStore.FindDeviceByKey(userID, identityKey)
if err != nil || deviceIdentity != nil {
return deviceIdentity, err
}
mach.Log.Debug("Didn't find identity of %s/%s in crypto store, fetching from server", userID, identityKey)
mach.machOrContextLog(ctx).Debug().
Str("user_id", userID.String()).
Str("identity_key", identityKey.String()).
Msg("Didn't find identity in crypto store, fetching from server")
devices := mach.LoadDevices(userID)
for _, device := range devices {
if device.IdentityKey == identityKey {
@@ -375,8 +441,8 @@ func (mach *OlmMachine) GetOrFetchDeviceByKey(userID id.UserID, identityKey id.I
}
// SendEncryptedToDevice sends an Olm-encrypted event to the given user device.
func (mach *OlmMachine) SendEncryptedToDevice(device *id.Device, evtType event.Type, content event.Content) error {
if err := mach.createOutboundSessions(map[id.UserID]map[id.DeviceID]*id.Device{
func (mach *OlmMachine) SendEncryptedToDevice(ctx context.Context, device *id.Device, evtType event.Type, content event.Content) error {
if err := mach.createOutboundSessions(ctx, map[id.UserID]map[id.DeviceID]*id.Device{
device.UserID: {
device.DeviceID: device,
},
@@ -395,10 +461,16 @@ func (mach *OlmMachine) SendEncryptedToDevice(device *id.Device, evtType event.T
return fmt.Errorf("didn't find created outbound session for device %s of %s", device.DeviceID, device.UserID)
}
encrypted := mach.encryptOlmEvent(olmSess, device, evtType, content)
encrypted := mach.encryptOlmEvent(ctx, olmSess, device, evtType, content)
encryptedContent := &event.Content{Parsed: &encrypted}
mach.Log.Debug("Sending encrypted to-device event of type %s to %s/%s (identity key: %s, olm session ID: %s)", evtType.Type, device.UserID, device.DeviceID, device.IdentityKey, olmSess.ID())
mach.machOrContextLog(ctx).Debug().
Str("decrypted_type", evtType.Type).
Str("to_user_id", device.UserID.String()).
Str("to_device_id", device.DeviceID.String()).
Str("to_identity_key", device.IdentityKey.String()).
Str("olm_session_id", olmSess.ID().String()).
Msg("Sending encrypted to-device event")
_, err = mach.Client.SendToDevice(event.ToDeviceEncrypted,
&mautrix.ReqSendToDevice{
Messages: map[id.UserID]map[id.DeviceID]*event.Content{
@@ -412,22 +484,32 @@ func (mach *OlmMachine) SendEncryptedToDevice(device *id.Device, evtType event.T
return err
}
func (mach *OlmMachine) createGroupSession(senderKey id.SenderKey, signingKey id.Ed25519, roomID id.RoomID, sessionID id.SessionID, sessionKey string, traceID string) {
igs, err := NewInboundGroupSession(senderKey, signingKey, roomID, sessionKey)
func (mach *OlmMachine) createGroupSession(ctx context.Context, senderKey id.SenderKey, signingKey id.Ed25519, roomID id.RoomID, sessionID id.SessionID, sessionKey string, maxAge time.Duration, maxMessages int, isScheduled bool) {
log := zerolog.Ctx(ctx)
igs, err := NewInboundGroupSession(senderKey, signingKey, roomID, sessionKey, maxAge, maxMessages, isScheduled)
if err != nil {
mach.Log.Error("Failed to create inbound group session: %v", err)
log.Error().Err(err).Msg("Failed to create inbound group session")
return
} else if igs.ID() != sessionID {
mach.Log.Warn("Mismatched session ID while creating inbound group session")
log.Warn().
Str("expected_session_id", sessionID.String()).
Str("actual_session_id", igs.ID().String()).
Msg("Mismatched session ID while creating inbound group session")
return
}
err = mach.CryptoStore.PutGroupSession(roomID, senderKey, sessionID, igs)
if err != nil {
mach.Log.Error("Failed to store new inbound group session: %v", err)
log.Error().Err(err).Str("session_id", sessionID.String()).Msg("Failed to store new inbound group session")
return
}
mach.markSessionReceived(sessionID)
mach.Log.Debug("Received inbound group session %s / %s / %s", roomID, senderKey, sessionID)
log.Debug().
Str("session_id", sessionID.String()).
Str("sender_key", senderKey.String()).
Str("max_age", maxAge.String()).
Int("max_messages", maxMessages).
Bool("is_scheduled", isScheduled).
Msg("Received inbound group session")
}
func (mach *OlmMachine) markSessionReceived(id id.SessionID) {
@@ -465,24 +547,66 @@ func (mach *OlmMachine) WaitForSession(roomID id.RoomID, senderKey id.SenderKey,
}
}
func (mach *OlmMachine) receiveRoomKey(evt *DecryptedOlmEvent, content *event.RoomKeyEventContent, traceID string) {
// TODO nio had a comment saying "handle this better" for the case where evt.Keys.Ed25519 is none?
func stringifyArray[T ~string](arr []T) []string {
strs := make([]string, len(arr))
for i, v := range arr {
strs[i] = string(v)
}
return strs
}
func (mach *OlmMachine) receiveRoomKey(ctx context.Context, evt *DecryptedOlmEvent, content *event.RoomKeyEventContent) {
log := zerolog.Ctx(ctx).With().
Str("algorithm", string(content.Algorithm)).
Str("session_id", content.SessionID.String()).
Str("room_id", content.RoomID.String()).
Logger()
if content.Algorithm != id.AlgorithmMegolmV1 || evt.Keys.Ed25519 == "" {
mach.Log.Debug("Ignoring weird room key from %s/%s: alg=%s, ed25519=%s, sessionid=%s, roomid=%s", evt.Sender, evt.SenderDevice, content.Algorithm, evt.Keys.Ed25519, content.SessionID, content.RoomID)
log.Debug().Msg("Ignoring weird room key")
return
}
mach.createGroupSession(evt.SenderKey, evt.Keys.Ed25519, content.RoomID, content.SessionID, content.SessionKey, traceID)
config := mach.StateStore.GetEncryptionEvent(content.RoomID)
var maxAge time.Duration
var maxMessages int
if config != nil {
maxAge = time.Duration(config.RotationPeriodMillis) * time.Millisecond
if maxAge == 0 {
maxAge = 7 * 24 * time.Hour
}
maxMessages = config.RotationPeriodMessages
if maxMessages == 0 {
maxMessages = 100
}
}
if content.MaxAge != 0 {
maxAge = time.Duration(content.MaxAge) * time.Millisecond
}
if content.MaxMessages != 0 {
maxMessages = content.MaxMessages
}
if mach.DeletePreviousKeysOnReceive && !content.IsScheduled {
log.Debug().Msg("Redacting previous megolm sessions from sender in room")
sessionIDs, err := mach.CryptoStore.RedactGroupSessions(content.RoomID, evt.SenderKey, "received new key from device")
if err != nil {
log.Err(err).Msg("Failed to redact previous megolm sessions")
} else {
log.Info().
Strs("session_ids", stringifyArray(sessionIDs)).
Msg("Redacted previous megolm sessions")
}
}
mach.createGroupSession(ctx, evt.SenderKey, evt.Keys.Ed25519, content.RoomID, content.SessionID, content.SessionKey, maxAge, maxMessages, content.IsScheduled)
}
func (mach *OlmMachine) handleRoomKeyWithheld(content *event.RoomKeyWithheldEventContent) {
func (mach *OlmMachine) handleRoomKeyWithheld(ctx context.Context, content *event.RoomKeyWithheldEventContent) {
if content.Algorithm != id.AlgorithmMegolmV1 {
mach.Log.Debug("Non-megolm room key withheld event: %+v", content)
zerolog.Ctx(ctx).Debug().Interface("content", content).Msg("Non-megolm room key withheld event")
return
}
err := mach.CryptoStore.PutWithheldGroupSession(*content)
if err != nil {
mach.Log.Error("Failed to save room key withheld event: %v", err)
zerolog.Ctx(ctx).Error().Err(err).Msg("Failed to save room key withheld event")
}
}
@@ -491,34 +615,38 @@ func (mach *OlmMachine) handleRoomKeyWithheld(content *event.RoomKeyWithheldEven
// If the Olm account hasn't been shared, the account keys will be uploaded.
// If currentOTKCount is less than half of the limit (100 / 2 = 50), enough one-time keys will be uploaded so exactly
// half of the limit is filled.
func (mach *OlmMachine) ShareKeys(currentOTKCount int) error {
func (mach *OlmMachine) ShareKeys(ctx context.Context, currentOTKCount int) error {
log := mach.machOrContextLog(ctx)
start := time.Now()
mach.otkUploadLock.Lock()
defer mach.otkUploadLock.Unlock()
if mach.lastOTKUpload.Add(1 * time.Minute).After(start) {
mach.Log.Trace("Checking OTK count from server due to suspiciously close share keys requests")
log.Debug().Msg("Checking OTK count from server due to suspiciously close share keys requests")
resp, err := mach.Client.UploadKeys(&mautrix.ReqUploadKeys{})
if err != nil {
return fmt.Errorf("failed to check current OTK counts: %w", err)
}
mach.Log.Trace("Fetched current OTK count (%d) from server (input count was %d)", resp.OneTimeKeyCounts.SignedCurve25519, currentOTKCount)
log.Debug().
Int("input_count", currentOTKCount).
Int("server_count", resp.OneTimeKeyCounts.SignedCurve25519).
Msg("Fetched current OTK count from server")
currentOTKCount = resp.OneTimeKeyCounts.SignedCurve25519
}
var deviceKeys *mautrix.DeviceKeys
if !mach.account.Shared {
deviceKeys = mach.account.getInitialKeys(mach.Client.UserID, mach.Client.DeviceID)
mach.Log.Trace("Going to upload initial account keys")
log.Debug().Msg("Going to upload initial account keys")
}
oneTimeKeys := mach.account.getOneTimeKeys(mach.Client.UserID, mach.Client.DeviceID, currentOTKCount)
if len(oneTimeKeys) == 0 && deviceKeys == nil {
mach.Log.Trace("No one-time keys nor device keys got when trying to share keys")
log.Debug().Msg("No one-time keys nor device keys got when trying to share keys")
return nil
}
req := &mautrix.ReqUploadKeys{
DeviceKeys: deviceKeys,
OneTimeKeys: oneTimeKeys,
}
mach.Log.Trace("Uploading %d one-time keys", len(oneTimeKeys))
log.Debug().Int("count", len(oneTimeKeys)).Msg("Uploading one-time keys")
_, err := mach.Client.UploadKeys(req)
if err != nil {
return err
@@ -528,3 +656,23 @@ func (mach *OlmMachine) ShareKeys(currentOTKCount int) error {
mach.saveAccount()
return nil
}
func (mach *OlmMachine) ExpiredKeyDeleteLoop(ctx context.Context) {
log := mach.Log.With().Str("action", "redact expired sessions").Logger()
for {
sessionIDs, err := mach.CryptoStore.RedactExpiredGroupSessions()
if err != nil {
log.Err(err).Msg("Failed to redact expired megolm sessions")
} else if len(sessionIDs) > 0 {
log.Info().Strs("session_ids", stringifyArray(sessionIDs)).Msg("Redacted expired megolm sessions")
} else {
log.Debug().Msg("Didn't find any expired megolm sessions")
}
select {
case <-ctx.Done():
log.Debug().Msg("Loop stopped")
return
case <-time.After(24 * time.Hour):
}
}
}

View File

@@ -288,7 +288,7 @@ func (s *InboundGroupSession) exportLen() uint {
// if we do not have a session key corresponding to the given index (ie, it was
// sent before the session key was shared with us) the error will be
// "OLM_UNKNOWN_MESSAGE_INDEX".
func (s *InboundGroupSession) Export(messageIndex uint32) (string, error) {
func (s *InboundGroupSession) Export(messageIndex uint32) ([]byte, error) {
key := make([]byte, s.exportLen())
r := C.olm_export_inbound_group_session(
(*C.OlmInboundGroupSession)(s.int),
@@ -296,7 +296,7 @@ func (s *InboundGroupSession) Export(messageIndex uint32) (string, error) {
C.size_t(len(key)),
C.uint32_t(messageIndex))
if r == errorVal() {
return "", s.lastError()
return nil, s.lastError()
}
return string(key[:r]), nil
return key[:r], nil
}

View File

@@ -1,6 +1,3 @@
//go:build !nosas
// +build !nosas
package olm
// #cgo LDFLAGS: -lolm -lstdc++

View File

@@ -89,6 +89,12 @@ func (session *OlmSession) Decrypt(ciphertext string, msgType id.OlmMsgType) ([]
return msg, err
}
type RatchetSafety struct {
NextIndex uint `json:"next_index"`
MissedIndices []uint `json:"missed_indices,omitempty"`
LostIndices []uint `json:"lost_indices,omitempty"`
}
type InboundGroupSession struct {
Internal olm.InboundGroupSession
@@ -97,11 +103,17 @@ type InboundGroupSession struct {
RoomID id.RoomID
ForwardingChains []string
RatchetSafety RatchetSafety
ReceivedAt time.Time
MaxAge int64
MaxMessages int
IsScheduled bool
id id.SessionID
}
func NewInboundGroupSession(senderKey id.SenderKey, signingKey id.Ed25519, roomID id.RoomID, sessionKey string) (*InboundGroupSession, error) {
func NewInboundGroupSession(senderKey id.SenderKey, signingKey id.Ed25519, roomID id.RoomID, sessionKey string, maxAge time.Duration, maxMessages int, isScheduled bool) (*InboundGroupSession, error) {
igs, err := olm.NewInboundGroupSession([]byte(sessionKey))
if err != nil {
return nil, err
@@ -112,6 +124,10 @@ func NewInboundGroupSession(senderKey id.SenderKey, signingKey id.Ed25519, roomI
SenderKey: senderKey,
RoomID: roomID,
ForwardingChains: nil,
ReceivedAt: time.Now().UTC(),
MaxAge: maxAge.Milliseconds(),
MaxMessages: maxMessages,
IsScheduled: isScheduled,
}, nil
}
@@ -122,6 +138,19 @@ func (igs *InboundGroupSession) ID() id.SessionID {
return igs.id
}
func (igs *InboundGroupSession) RatchetTo(index uint32) error {
exported, err := igs.Internal.Export(index)
if err != nil {
return err
}
imported, err := olm.InboundGroupSessionImport(exported)
if err != nil {
return err
}
igs.Internal = *imported
return nil
}
type OGSState int
const (

View File

@@ -7,13 +7,19 @@
package crypto
import (
"context"
"database/sql"
"database/sql/driver"
"encoding/json"
"errors"
"fmt"
"strings"
"sync"
"time"
"github.com/rs/zerolog"
"maunium.net/go/mautrix"
"maunium.net/go/mautrix/crypto/olm"
"maunium.net/go/mautrix/crypto/sql_store_upgrade"
"maunium.net/go/mautrix/event"
@@ -80,6 +86,35 @@ func (store *SQLCryptoStore) GetNextBatch() (string, error) {
return store.SyncToken, nil
}
var _ mautrix.SyncStore = (*SQLCryptoStore)(nil)
func (store *SQLCryptoStore) SaveFilterID(_ id.UserID, _ string) {}
func (store *SQLCryptoStore) LoadFilterID(_ id.UserID) string { return "" }
func (store *SQLCryptoStore) SaveNextBatch(_ id.UserID, nextBatchToken string) {
err := store.PutNextBatch(nextBatchToken)
if err != nil {
// TODO handle error
}
}
func (store *SQLCryptoStore) LoadNextBatch(_ id.UserID) string {
nb, err := store.GetNextBatch()
if err != nil {
// TODO handle error
}
return nb
}
func (store *SQLCryptoStore) FindDeviceID() (deviceID id.DeviceID) {
err := store.DB.QueryRow("SELECT device_id FROM crypto_account WHERE account_id=$1", store.AccountID).Scan(&deviceID)
if err != nil && !errors.Is(err, sql.ErrNoRows) {
// TODO return error
store.DB.Log.Warn("Failed to scan device ID: %v", err)
}
return
}
// PutAccount stores an OlmAccount in the database.
func (store *SQLCryptoStore) PutAccount(account *OlmAccount) error {
store.Account = account
@@ -220,37 +255,72 @@ func (store *SQLCryptoStore) UpdateSession(_ id.SenderKey, session *OlmSession)
return err
}
func intishPtr[T int | int64](i T) *T {
if i == 0 {
return nil
}
return &i
}
func datePtr(t time.Time) *time.Time {
if t.IsZero() {
return nil
}
return &t
}
// PutGroupSession stores an inbound Megolm group session for a room, sender and session.
func (store *SQLCryptoStore) PutGroupSession(roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID, session *InboundGroupSession) error {
sessionBytes := session.Internal.Pickle(store.PickleKey)
forwardingChains := strings.Join(session.ForwardingChains, ",")
_, err := store.DB.Exec(`
INSERT INTO crypto_megolm_inbound_session
(session_id, sender_key, signing_key, room_id, session, forwarding_chains, account_id)
VALUES ($1, $2, $3, $4, $5, $6, $7)
ratchetSafety, err := json.Marshal(&session.RatchetSafety)
if err != nil {
return fmt.Errorf("failed to marshal ratchet safety info: %w", err)
}
_, err = store.DB.Exec(`
INSERT INTO crypto_megolm_inbound_session (
session_id, sender_key, signing_key, room_id, session, forwarding_chains,
ratchet_safety, received_at, max_age, max_messages, is_scheduled, account_id
) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12)
ON CONFLICT (session_id, account_id) DO UPDATE
SET withheld_code=NULL, withheld_reason=NULL, sender_key=excluded.sender_key, signing_key=excluded.signing_key,
room_id=excluded.room_id, session=excluded.session, forwarding_chains=excluded.forwarding_chains
`, sessionID, senderKey, session.SigningKey, roomID, sessionBytes, forwardingChains, store.AccountID)
room_id=excluded.room_id, session=excluded.session, forwarding_chains=excluded.forwarding_chains,
ratchet_safety=excluded.ratchet_safety, received_at=excluded.received_at,
max_age=excluded.max_age, max_messages=excluded.max_messages, is_scheduled=excluded.is_scheduled
`,
sessionID, senderKey, session.SigningKey, roomID, sessionBytes, forwardingChains,
ratchetSafety, datePtr(session.ReceivedAt), intishPtr(session.MaxAge), intishPtr(session.MaxMessages),
session.IsScheduled, store.AccountID,
)
return err
}
// GetGroupSession retrieves an inbound Megolm group session for a room, sender and session.
func (store *SQLCryptoStore) GetGroupSession(roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID) (*InboundGroupSession, error) {
var signingKey, forwardingChains, withheldCode sql.NullString
var sessionBytes []byte
var senderKeyDB, signingKey, forwardingChains, withheldCode, withheldReason sql.NullString
var sessionBytes, ratchetSafetyBytes []byte
var receivedAt sql.NullTime
var maxAge, maxMessages sql.NullInt64
var isScheduled bool
err := store.DB.QueryRow(`
SELECT signing_key, session, forwarding_chains, withheld_code
SELECT sender_key, signing_key, session, forwarding_chains, withheld_code, withheld_reason, ratchet_safety, received_at, max_age, max_messages, is_scheduled
FROM crypto_megolm_inbound_session
WHERE room_id=$1 AND sender_key=$2 AND session_id=$3 AND account_id=$4`,
WHERE room_id=$1 AND (sender_key=$2 OR $2 = '') AND session_id=$3 AND account_id=$4`,
roomID, senderKey, sessionID, store.AccountID,
).Scan(&signingKey, &sessionBytes, &forwardingChains, &withheldCode)
).Scan(&senderKeyDB, &signingKey, &sessionBytes, &forwardingChains, &withheldCode, &withheldReason, &ratchetSafetyBytes, &receivedAt, &maxAge, &maxMessages, &isScheduled)
if err == sql.ErrNoRows {
return nil, nil
} else if err != nil {
return nil, err
} else if withheldCode.Valid {
return nil, fmt.Errorf("%w (%s)", ErrGroupSessionWithheld, withheldCode.String)
return nil, &event.RoomKeyWithheldEventContent{
RoomID: roomID,
Algorithm: id.AlgorithmMegolmV1,
SessionID: sessionID,
SenderKey: senderKey,
Code: event.RoomKeyWithheldCode(withheldCode.String),
Reason: withheldReason.String,
}
}
igs := olm.NewBlankInboundGroupSession()
err = igs.Unpickle(sessionBytes, store.PickleKey)
@@ -261,18 +331,96 @@ func (store *SQLCryptoStore) GetGroupSession(roomID id.RoomID, senderKey id.Send
if forwardingChains.String != "" {
chains = strings.Split(forwardingChains.String, ",")
}
var rs RatchetSafety
if len(ratchetSafetyBytes) > 0 {
err = json.Unmarshal(ratchetSafetyBytes, &rs)
if err != nil {
return nil, fmt.Errorf("failed to unmarshal ratchet safety info: %w", err)
}
}
if senderKey == "" {
senderKey = id.Curve25519(senderKeyDB.String)
}
return &InboundGroupSession{
Internal: *igs,
SigningKey: id.Ed25519(signingKey.String),
SenderKey: senderKey,
RoomID: roomID,
ForwardingChains: chains,
RatchetSafety: rs,
ReceivedAt: receivedAt.Time,
MaxAge: maxAge.Int64,
MaxMessages: int(maxMessages.Int64),
IsScheduled: isScheduled,
}, nil
}
func (store *SQLCryptoStore) RedactGroupSession(_ id.RoomID, _ id.SenderKey, sessionID id.SessionID, reason string) error {
_, err := store.DB.Exec(`
UPDATE crypto_megolm_inbound_session
SET withheld_code=$1, withheld_reason=$2, session=NULL, forwarding_chains=NULL
WHERE session_id=$3 AND account_id=$4 AND session IS NOT NULL
`, event.RoomKeyWithheldBeeperRedacted, "Session redacted: "+reason, sessionID, store.AccountID)
return err
}
func (store *SQLCryptoStore) RedactGroupSessions(roomID id.RoomID, senderKey id.SenderKey, reason string) ([]id.SessionID, error) {
if roomID == "" && senderKey == "" {
return nil, fmt.Errorf("room ID or sender key must be provided for redacting sessions")
}
res, err := store.DB.Query(`
UPDATE crypto_megolm_inbound_session
SET withheld_code=$1, withheld_reason=$2, session=NULL, forwarding_chains=NULL
WHERE (room_id=$3 OR $3='') AND (sender_key=$4 OR $4='') AND account_id=$5
AND session IS NOT NULL AND is_scheduled=false AND received_at IS NOT NULL
RETURNING session_id
`, event.RoomKeyWithheldBeeperRedacted, "Session redacted: "+reason, roomID, senderKey, store.AccountID)
var sessionIDs []id.SessionID
for res.Next() {
var sessionID id.SessionID
_ = res.Scan(&sessionID)
sessionIDs = append(sessionIDs, sessionID)
}
return sessionIDs, err
}
func (store *SQLCryptoStore) RedactExpiredGroupSessions() ([]id.SessionID, error) {
var query string
switch store.DB.Dialect {
case dbutil.Postgres:
query = `
UPDATE crypto_megolm_inbound_session
SET withheld_code=$1, withheld_reason=$2, session=NULL, forwarding_chains=NULL
WHERE account_id=$3 AND session IS NOT NULL AND is_scheduled=false
AND received_at IS NOT NULL and max_age IS NOT NULL
AND received_at + 2 * (max_age * interval '1 millisecond') < now()
RETURNING session_id
`
case dbutil.SQLite:
query = `
UPDATE crypto_megolm_inbound_session
SET withheld_code=$1, withheld_reason=$2, session=NULL, forwarding_chains=NULL
WHERE account_id=$3 AND session IS NOT NULL AND is_scheduled=false
AND received_at IS NOT NULL and max_age IS NOT NULL
AND unixepoch(received_at) + (2 * max_age / 1000) < unixepoch(date('now'))
RETURNING session_id
`
default:
return nil, fmt.Errorf("unsupported dialect")
}
res, err := store.DB.Query(query, event.RoomKeyWithheldBeeperRedacted, "Session redacted: expired", store.AccountID)
var sessionIDs []id.SessionID
for res.Next() {
var sessionID id.SessionID
_ = res.Scan(&sessionID)
sessionIDs = append(sessionIDs, sessionID)
}
return sessionIDs, err
}
func (store *SQLCryptoStore) PutWithheldGroupSession(content event.RoomKeyWithheldEventContent) error {
_, err := store.DB.Exec("INSERT INTO crypto_megolm_inbound_session (session_id, sender_key, room_id, withheld_code, withheld_reason, account_id) VALUES ($1, $2, $3, $4, $5, $6)",
content.SessionID, content.SenderKey, content.RoomID, content.Code, content.Reason, store.AccountID)
_, err := store.DB.Exec("INSERT INTO crypto_megolm_inbound_session (session_id, sender_key, room_id, withheld_code, withheld_reason, received_at, account_id) VALUES ($1, $2, $3, $4, $5, $6, $7)",
content.SessionID, content.SenderKey, content.RoomID, content.Code, content.Reason, time.Now().UTC(), store.AccountID)
return err
}
@@ -302,8 +450,11 @@ func (store *SQLCryptoStore) scanGroupSessionList(rows dbutil.Rows) (result []*I
for rows.Next() {
var roomID id.RoomID
var signingKey, senderKey, forwardingChains sql.NullString
var sessionBytes []byte
err = rows.Scan(&roomID, &signingKey, &senderKey, &sessionBytes, &forwardingChains)
var sessionBytes, ratchetSafetyBytes []byte
var receivedAt sql.NullTime
var maxAge, maxMessages sql.NullInt64
var isScheduled bool
err = rows.Scan(&roomID, &signingKey, &senderKey, &sessionBytes, &forwardingChains, &ratchetSafetyBytes, &receivedAt, &maxAge, &maxMessages, &isScheduled)
if err != nil {
return
}
@@ -316,12 +467,24 @@ func (store *SQLCryptoStore) scanGroupSessionList(rows dbutil.Rows) (result []*I
if forwardingChains.String != "" {
chains = strings.Split(forwardingChains.String, ",")
}
var rs RatchetSafety
if len(ratchetSafetyBytes) > 0 {
err = json.Unmarshal(ratchetSafetyBytes, &rs)
if err != nil {
return nil, fmt.Errorf("failed to unmarshal ratchet safety info: %w", err)
}
}
result = append(result, &InboundGroupSession{
Internal: *igs,
SigningKey: id.Ed25519(signingKey.String),
SenderKey: id.Curve25519(senderKey.String),
RoomID: roomID,
ForwardingChains: chains,
RatchetSafety: rs,
ReceivedAt: receivedAt.Time,
MaxAge: maxAge.Int64,
MaxMessages: int(maxMessages.Int64),
IsScheduled: isScheduled,
})
}
return
@@ -329,7 +492,7 @@ func (store *SQLCryptoStore) scanGroupSessionList(rows dbutil.Rows) (result []*I
func (store *SQLCryptoStore) GetGroupSessionsForRoom(roomID id.RoomID) ([]*InboundGroupSession, error) {
rows, err := store.DB.Query(`
SELECT room_id, signing_key, sender_key, session, forwarding_chains
SELECT room_id, signing_key, sender_key, session, forwarding_chains, ratchet_safety, received_at, max_age, max_messages, is_scheduled
FROM crypto_megolm_inbound_session WHERE room_id=$1 AND account_id=$2 AND session IS NOT NULL`,
roomID, store.AccountID,
)
@@ -343,7 +506,7 @@ func (store *SQLCryptoStore) GetGroupSessionsForRoom(roomID id.RoomID) ([]*Inbou
func (store *SQLCryptoStore) GetAllGroupSessions() ([]*InboundGroupSession, error) {
rows, err := store.DB.Query(`
SELECT room_id, signing_key, sender_key, session, forwarding_chains
SELECT room_id, signing_key, sender_key, session, forwarding_chains, ratchet_safety, received_at, max_age, max_messages, is_scheduled
FROM crypto_megolm_inbound_session WHERE account_id=$2 AND session IS NOT NULL`,
store.AccountID,
)
@@ -367,7 +530,7 @@ func (store *SQLCryptoStore) AddOutboundGroupSession(session *OutboundGroupSessi
max_messages=excluded.max_messages, message_count=excluded.message_count, max_age=excluded.max_age,
created_at=excluded.created_at, last_used=excluded.last_used, account_id=excluded.account_id
`, session.RoomID, session.ID(), sessionBytes, session.Shared, session.MaxMessages, session.MessageCount,
session.MaxAge, session.CreationTime, session.LastEncryptedTime, store.AccountID)
session.MaxAge.Milliseconds(), session.CreationTime, session.LastEncryptedTime, store.AccountID)
return err
}
@@ -383,11 +546,12 @@ func (store *SQLCryptoStore) UpdateOutboundGroupSession(session *OutboundGroupSe
func (store *SQLCryptoStore) GetOutboundGroupSession(roomID id.RoomID) (*OutboundGroupSession, error) {
var ogs OutboundGroupSession
var sessionBytes []byte
var maxAgeMS int64
err := store.DB.QueryRow(`
SELECT session, shared, max_messages, message_count, max_age, created_at, last_used
FROM crypto_megolm_outbound_session WHERE room_id=$1 AND account_id=$2`,
roomID, store.AccountID,
).Scan(&sessionBytes, &ogs.Shared, &ogs.MaxMessages, &ogs.MessageCount, &ogs.MaxAge, &ogs.CreationTime, &ogs.LastEncryptedTime)
).Scan(&sessionBytes, &ogs.Shared, &ogs.MaxMessages, &ogs.MessageCount, &maxAgeMS, &ogs.CreationTime, &ogs.LastEncryptedTime)
if err == sql.ErrNoRows {
return nil, nil
} else if err != nil {
@@ -400,6 +564,7 @@ func (store *SQLCryptoStore) GetOutboundGroupSession(roomID id.RoomID) (*Outboun
}
ogs.Internal = *intOGS
ogs.RoomID = roomID
ogs.MaxAge = time.Duration(maxAgeMS) * time.Millisecond
return &ogs, nil
}
@@ -412,7 +577,7 @@ func (store *SQLCryptoStore) RemoveOutboundGroupSession(roomID id.RoomID) error
// ValidateMessageIndex returns whether the given event information match the ones stored in the database
// for the given sender key, session ID and index. If the index hasn't been stored, this will store it.
func (store *SQLCryptoStore) ValidateMessageIndex(senderKey id.SenderKey, sessionID id.SessionID, eventID id.EventID, index uint, timestamp int64) (bool, error) {
func (store *SQLCryptoStore) ValidateMessageIndex(ctx context.Context, senderKey id.SenderKey, sessionID id.SessionID, eventID id.EventID, index uint, timestamp int64) (bool, error) {
const validateQuery = `
INSERT INTO crypto_message_index (sender_key, session_id, "index", event_id, timestamp)
VALUES ($1, $2, $3, $4, $5)
@@ -422,11 +587,19 @@ func (store *SQLCryptoStore) ValidateMessageIndex(senderKey id.SenderKey, sessio
`
var expectedEventID id.EventID
var expectedTimestamp int64
err := store.DB.QueryRow(validateQuery, senderKey, sessionID, index, eventID, timestamp).Scan(&expectedEventID, &expectedTimestamp)
err := store.DB.QueryRowContext(ctx, validateQuery, senderKey, sessionID, index, eventID, timestamp).Scan(&expectedEventID, &expectedTimestamp)
if err != nil {
return false, err
} else if expectedEventID != eventID || expectedTimestamp != timestamp {
zerolog.Ctx(ctx).Debug().
Uint("message_index", index).
Str("expected_event_id", expectedEventID.String()).
Int64("expected_timestamp", expectedTimestamp).
Int64("actual_timestamp", timestamp).
Msg("Failed to validate that message index wasn't duplicated")
return false, nil
}
return expectedEventID == eventID && expectedTimestamp == timestamp, nil
return true, nil
}
// GetDevices returns a map of device IDs to device identities, including the identity and signing keys, for a given user ID.

View File

@@ -1,4 +1,4 @@
-- v0 -> v8: Latest revision
-- v0 -> v10: Latest revision
CREATE TABLE IF NOT EXISTS crypto_account (
account_id TEXT PRIMARY KEY,
device_id TEXT NOT NULL,
@@ -52,6 +52,11 @@ CREATE TABLE IF NOT EXISTS crypto_megolm_inbound_session (
forwarding_chains bytea,
withheld_code TEXT,
withheld_reason TEXT,
ratchet_safety jsonb,
received_at timestamp,
max_age BIGINT,
max_messages INTEGER,
is_scheduled BOOLEAN NOT NULL DEFAULT false,
PRIMARY KEY (account_id, session_id)
);

View File

@@ -0,0 +1,2 @@
-- v9: Change outbound megolm session max_age column to milliseconds
UPDATE crypto_megolm_outbound_session SET max_age=max_age/1000000;

View File

@@ -0,0 +1,6 @@
-- v10: Add metadata for detecting when megolm sessions are safe to delete
ALTER TABLE crypto_megolm_inbound_session ADD COLUMN ratchet_safety jsonb;
ALTER TABLE crypto_megolm_inbound_session ADD COLUMN received_at timestamp;
ALTER TABLE crypto_megolm_inbound_session ADD COLUMN max_age BIGINT;
ALTER TABLE crypto_megolm_inbound_session ADD COLUMN max_messages INTEGER;
ALTER TABLE crypto_megolm_inbound_session ADD COLUMN is_scheduled BOOLEAN NOT NULL DEFAULT false;

View File

@@ -21,7 +21,7 @@ const VersionTableName = "crypto_version"
var fs embed.FS
func init() {
Table.Register(-1, 3, "Unsupported version", false, func(tx dbutil.Execable, database *dbutil.Database) error {
Table.Register(-1, 3, 0, "Unsupported version", false, func(tx dbutil.Execable, database *dbutil.Database) error {
return fmt.Errorf("upgrading from versions 1 and 2 of the crypto store is no longer supported in mautrix-go v0.12+")
})
Table.RegisterFS(fs)

View File

@@ -7,10 +7,8 @@
package crypto
import (
"encoding/gob"
"errors"
"context"
"fmt"
"os"
"sort"
"sync"
@@ -18,10 +16,7 @@ import (
"maunium.net/go/mautrix/id"
)
// Deprecated: moved to id.Device
type DeviceIdentity = id.Device
var ErrGroupSessionWithheld = errors.New("group session has been withheld")
var ErrGroupSessionWithheld error = &event.RoomKeyWithheldEventContent{}
// Store is used by OlmMachine to store Olm and Megolm sessions, user device lists and message indices.
//
@@ -58,6 +53,12 @@ type Store interface {
// (i.e. a room key withheld event has been saved with PutWithheldGroupSession), this should return the
// ErrGroupSessionWithheld error. The caller may use GetWithheldGroupSession to find more details.
GetGroupSession(id.RoomID, id.SenderKey, id.SessionID) (*InboundGroupSession, error)
// RedactGroupSession removes the session data for the given inbound Megolm session from the store.
RedactGroupSession(id.RoomID, id.SenderKey, id.SessionID, string) error
// RedactGroupSessions removes the session data for all inbound Megolm sessions from a specific device and/or in a specific room.
RedactGroupSessions(id.RoomID, id.SenderKey, string) ([]id.SessionID, error)
// RedactExpiredGroupSessions removes the session data for all inbound Megolm sessions that have expired.
RedactExpiredGroupSessions() ([]id.SessionID, error)
// PutWithheldGroupSession tells the store that a specific Megolm session was withheld.
PutWithheldGroupSession(event.RoomKeyWithheldEventContent) error
// GetWithheldGroupSession gets the event content that was previously inserted with PutWithheldGroupSession.
@@ -90,9 +91,9 @@ type Store interface {
// * If the map key doesn't exist, the given values should be stored and this should return true.
// * If the map key exists and the stored values match the given values, this should return true.
// * If the map key exists, but the stored values do not match the given values, this should return false.
ValidateMessageIndex(senderKey id.SenderKey, sessionID id.SessionID, eventID id.EventID, index uint, timestamp int64) (bool, error)
ValidateMessageIndex(ctx context.Context, senderKey id.SenderKey, sessionID id.SessionID, eventID id.EventID, index uint, timestamp int64) (bool, error)
// GetDevices returns a map from device ID to DeviceIdentity containing all devices of a given user.
// GetDevices returns a map from device ID to id.Device struct containing all devices of a given user.
GetDevices(id.UserID) (map[id.DeviceID]*id.Device, error)
// GetDevice returns a specific device of a given user.
GetDevice(id.UserID, id.DeviceID) (*id.Device, error)
@@ -129,12 +130,12 @@ type messageIndexValue struct {
Timestamp int64
}
// GobStore is a simple Store implementation that dumps everything into a .gob file.
//
// Deprecated: this is not atomic and can lose data. Using SQLCryptoStore or a custom implementation is recommended.
type GobStore struct {
// MemoryStore is a simple in-memory Store implementation. It can optionally have a callback function for saving data,
// but the actual storage must be implemented manually.
type MemoryStore struct {
lock sync.RWMutex
path string
save func() error
Account *OlmAccount
Sessions map[id.SenderKey]OlmSessionList
@@ -147,14 +148,15 @@ type GobStore struct {
KeySignatures map[id.UserID]map[id.Ed25519]map[id.UserID]map[id.Ed25519]string
}
var _ Store = (*GobStore)(nil)
var _ Store = (*MemoryStore)(nil)
func NewMemoryStore(saveCallback func() error) *MemoryStore {
if saveCallback == nil {
saveCallback = func() error { return nil }
}
return &MemoryStore{
save: saveCallback,
// NewGobStore creates a new GobStore that saves everything to the given file.
//
// Deprecated: this is not atomic and can lose data. Using SQLCryptoStore or a custom implementation is recommended.
func NewGobStore(path string) (*GobStore, error) {
gs := &GobStore{
path: path,
Sessions: make(map[id.SenderKey]OlmSessionList),
GroupSessions: make(map[id.RoomID]map[id.SenderKey]map[id.SessionID]*InboundGroupSession),
WithheldGroupSessions: make(map[id.RoomID]map[id.SenderKey]map[id.SessionID]*event.RoomKeyWithheldEventContent),
@@ -164,44 +166,20 @@ func NewGobStore(path string) (*GobStore, error) {
CrossSigningKeys: make(map[id.UserID]map[id.CrossSigningUsage]id.CrossSigningKey),
KeySignatures: make(map[id.UserID]map[id.Ed25519]map[id.UserID]map[id.Ed25519]string),
}
return gs, gs.load()
}
func (gs *GobStore) save() error {
file, err := os.OpenFile(gs.path, os.O_CREATE|os.O_WRONLY, 0600)
if err != nil {
return err
}
err = gob.NewEncoder(file).Encode(gs)
_ = file.Close()
return err
}
func (gs *GobStore) load() error {
file, err := os.OpenFile(gs.path, os.O_RDONLY, 0600)
if err != nil {
if os.IsNotExist(err) {
return nil
}
return err
}
err = gob.NewDecoder(file).Decode(gs)
_ = file.Close()
return err
}
func (gs *GobStore) Flush() error {
func (gs *MemoryStore) Flush() error {
gs.lock.Lock()
err := gs.save()
gs.lock.Unlock()
return err
}
func (gs *GobStore) GetAccount() (*OlmAccount, error) {
func (gs *MemoryStore) GetAccount() (*OlmAccount, error) {
return gs.Account, nil
}
func (gs *GobStore) PutAccount(account *OlmAccount) error {
func (gs *MemoryStore) PutAccount(account *OlmAccount) error {
gs.lock.Lock()
gs.Account = account
err := gs.save()
@@ -209,7 +187,7 @@ func (gs *GobStore) PutAccount(account *OlmAccount) error {
return err
}
func (gs *GobStore) GetSessions(senderKey id.SenderKey) (OlmSessionList, error) {
func (gs *MemoryStore) GetSessions(senderKey id.SenderKey) (OlmSessionList, error) {
gs.lock.Lock()
sessions, ok := gs.Sessions[senderKey]
if !ok {
@@ -220,7 +198,7 @@ func (gs *GobStore) GetSessions(senderKey id.SenderKey) (OlmSessionList, error)
return sessions, nil
}
func (gs *GobStore) AddSession(senderKey id.SenderKey, session *OlmSession) error {
func (gs *MemoryStore) AddSession(senderKey id.SenderKey, session *OlmSession) error {
gs.lock.Lock()
sessions, _ := gs.Sessions[senderKey]
gs.Sessions[senderKey] = append(sessions, session)
@@ -230,19 +208,19 @@ func (gs *GobStore) AddSession(senderKey id.SenderKey, session *OlmSession) erro
return err
}
func (gs *GobStore) UpdateSession(_ id.SenderKey, _ *OlmSession) error {
func (gs *MemoryStore) UpdateSession(_ id.SenderKey, _ *OlmSession) error {
// we don't need to do anything here because the session is a pointer and already stored in our map
return gs.save()
}
func (gs *GobStore) HasSession(senderKey id.SenderKey) bool {
func (gs *MemoryStore) HasSession(senderKey id.SenderKey) bool {
gs.lock.RLock()
sessions, ok := gs.Sessions[senderKey]
gs.lock.RUnlock()
return ok && len(sessions) > 0 && !sessions[0].Expired()
}
func (gs *GobStore) GetLatestSession(senderKey id.SenderKey) (*OlmSession, error) {
func (gs *MemoryStore) GetLatestSession(senderKey id.SenderKey) (*OlmSession, error) {
gs.lock.RLock()
sessions, ok := gs.Sessions[senderKey]
gs.lock.RUnlock()
@@ -252,7 +230,7 @@ func (gs *GobStore) GetLatestSession(senderKey id.SenderKey) (*OlmSession, error
return sessions[0], nil
}
func (gs *GobStore) getGroupSessions(roomID id.RoomID, senderKey id.SenderKey) map[id.SessionID]*InboundGroupSession {
func (gs *MemoryStore) getGroupSessions(roomID id.RoomID, senderKey id.SenderKey) map[id.SessionID]*InboundGroupSession {
room, ok := gs.GroupSessions[roomID]
if !ok {
room = make(map[id.SenderKey]map[id.SessionID]*InboundGroupSession)
@@ -266,7 +244,7 @@ func (gs *GobStore) getGroupSessions(roomID id.RoomID, senderKey id.SenderKey) m
return sender
}
func (gs *GobStore) PutGroupSession(roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID, igs *InboundGroupSession) error {
func (gs *MemoryStore) PutGroupSession(roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID, igs *InboundGroupSession) error {
gs.lock.Lock()
gs.getGroupSessions(roomID, senderKey)[sessionID] = igs
err := gs.save()
@@ -274,7 +252,7 @@ func (gs *GobStore) PutGroupSession(roomID id.RoomID, senderKey id.SenderKey, se
return err
}
func (gs *GobStore) GetGroupSession(roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID) (*InboundGroupSession, error) {
func (gs *MemoryStore) GetGroupSession(roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID) (*InboundGroupSession, error) {
gs.lock.Lock()
session, ok := gs.getGroupSessions(roomID, senderKey)[sessionID]
if !ok {
@@ -289,7 +267,57 @@ func (gs *GobStore) GetGroupSession(roomID id.RoomID, senderKey id.SenderKey, se
return session, nil
}
func (gs *GobStore) getWithheldGroupSessions(roomID id.RoomID, senderKey id.SenderKey) map[id.SessionID]*event.RoomKeyWithheldEventContent {
func (gs *MemoryStore) RedactGroupSession(roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID, reason string) error {
gs.lock.Lock()
delete(gs.getGroupSessions(roomID, senderKey), sessionID)
err := gs.save()
gs.lock.Unlock()
return err
}
func (gs *MemoryStore) RedactGroupSessions(roomID id.RoomID, senderKey id.SenderKey, reason string) ([]id.SessionID, error) {
gs.lock.Lock()
var sessionIDs []id.SessionID
if roomID != "" && senderKey != "" {
sessions := gs.getGroupSessions(roomID, senderKey)
for sessionID := range sessions {
sessionIDs = append(sessionIDs, sessionID)
delete(sessions, sessionID)
}
} else if senderKey != "" {
for _, room := range gs.GroupSessions {
sessions, ok := room[senderKey]
if ok {
for sessionID := range sessions {
sessionIDs = append(sessionIDs, sessionID)
}
delete(room, senderKey)
}
}
} else if roomID != "" {
room, ok := gs.GroupSessions[roomID]
if ok {
for senderKey := range room {
sessions := room[senderKey]
for sessionID := range sessions {
sessionIDs = append(sessionIDs, sessionID)
}
}
delete(gs.GroupSessions, roomID)
}
} else {
return nil, fmt.Errorf("room ID or sender key must be provided for redacting sessions")
}
err := gs.save()
gs.lock.Unlock()
return sessionIDs, err
}
func (gs *MemoryStore) RedactExpiredGroupSessions() ([]id.SessionID, error) {
return nil, fmt.Errorf("not implemented")
}
func (gs *MemoryStore) getWithheldGroupSessions(roomID id.RoomID, senderKey id.SenderKey) map[id.SessionID]*event.RoomKeyWithheldEventContent {
room, ok := gs.WithheldGroupSessions[roomID]
if !ok {
room = make(map[id.SenderKey]map[id.SessionID]*event.RoomKeyWithheldEventContent)
@@ -303,7 +331,7 @@ func (gs *GobStore) getWithheldGroupSessions(roomID id.RoomID, senderKey id.Send
return sender
}
func (gs *GobStore) PutWithheldGroupSession(content event.RoomKeyWithheldEventContent) error {
func (gs *MemoryStore) PutWithheldGroupSession(content event.RoomKeyWithheldEventContent) error {
gs.lock.Lock()
gs.getWithheldGroupSessions(content.RoomID, content.SenderKey)[content.SessionID] = &content
err := gs.save()
@@ -311,7 +339,7 @@ func (gs *GobStore) PutWithheldGroupSession(content event.RoomKeyWithheldEventCo
return err
}
func (gs *GobStore) GetWithheldGroupSession(roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID) (*event.RoomKeyWithheldEventContent, error) {
func (gs *MemoryStore) GetWithheldGroupSession(roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID) (*event.RoomKeyWithheldEventContent, error) {
gs.lock.Lock()
session, ok := gs.getWithheldGroupSessions(roomID, senderKey)[sessionID]
gs.lock.Unlock()
@@ -321,7 +349,7 @@ func (gs *GobStore) GetWithheldGroupSession(roomID id.RoomID, senderKey id.Sende
return session, nil
}
func (gs *GobStore) GetGroupSessionsForRoom(roomID id.RoomID) ([]*InboundGroupSession, error) {
func (gs *MemoryStore) GetGroupSessionsForRoom(roomID id.RoomID) ([]*InboundGroupSession, error) {
gs.lock.Lock()
defer gs.lock.Unlock()
room, ok := gs.GroupSessions[roomID]
@@ -337,7 +365,7 @@ func (gs *GobStore) GetGroupSessionsForRoom(roomID id.RoomID) ([]*InboundGroupSe
return result, nil
}
func (gs *GobStore) GetAllGroupSessions() ([]*InboundGroupSession, error) {
func (gs *MemoryStore) GetAllGroupSessions() ([]*InboundGroupSession, error) {
gs.lock.Lock()
var result []*InboundGroupSession
for _, room := range gs.GroupSessions {
@@ -351,7 +379,7 @@ func (gs *GobStore) GetAllGroupSessions() ([]*InboundGroupSession, error) {
return result, nil
}
func (gs *GobStore) AddOutboundGroupSession(session *OutboundGroupSession) error {
func (gs *MemoryStore) AddOutboundGroupSession(session *OutboundGroupSession) error {
gs.lock.Lock()
gs.OutGroupSessions[session.RoomID] = session
err := gs.save()
@@ -359,12 +387,12 @@ func (gs *GobStore) AddOutboundGroupSession(session *OutboundGroupSession) error
return err
}
func (gs *GobStore) UpdateOutboundGroupSession(_ *OutboundGroupSession) error {
func (gs *MemoryStore) UpdateOutboundGroupSession(_ *OutboundGroupSession) error {
// we don't need to do anything here because the session is a pointer and already stored in our map
return gs.save()
}
func (gs *GobStore) GetOutboundGroupSession(roomID id.RoomID) (*OutboundGroupSession, error) {
func (gs *MemoryStore) GetOutboundGroupSession(roomID id.RoomID) (*OutboundGroupSession, error) {
gs.lock.RLock()
session, ok := gs.OutGroupSessions[roomID]
gs.lock.RUnlock()
@@ -374,7 +402,7 @@ func (gs *GobStore) GetOutboundGroupSession(roomID id.RoomID) (*OutboundGroupSes
return session, nil
}
func (gs *GobStore) RemoveOutboundGroupSession(roomID id.RoomID) error {
func (gs *MemoryStore) RemoveOutboundGroupSession(roomID id.RoomID) error {
gs.lock.Lock()
session, ok := gs.OutGroupSessions[roomID]
if !ok || session == nil {
@@ -386,7 +414,7 @@ func (gs *GobStore) RemoveOutboundGroupSession(roomID id.RoomID) error {
return nil
}
func (gs *GobStore) ValidateMessageIndex(senderKey id.SenderKey, sessionID id.SessionID, eventID id.EventID, index uint, timestamp int64) (bool, error) {
func (gs *MemoryStore) ValidateMessageIndex(_ context.Context, senderKey id.SenderKey, sessionID id.SessionID, eventID id.EventID, index uint, timestamp int64) (bool, error) {
gs.lock.Lock()
defer gs.lock.Unlock()
key := messageIndexKey{
@@ -409,7 +437,7 @@ func (gs *GobStore) ValidateMessageIndex(senderKey id.SenderKey, sessionID id.Se
return true, nil
}
func (gs *GobStore) GetDevices(userID id.UserID) (map[id.DeviceID]*id.Device, error) {
func (gs *MemoryStore) GetDevices(userID id.UserID) (map[id.DeviceID]*id.Device, error) {
gs.lock.RLock()
devices, ok := gs.Devices[userID]
if !ok {
@@ -419,7 +447,7 @@ func (gs *GobStore) GetDevices(userID id.UserID) (map[id.DeviceID]*id.Device, er
return devices, nil
}
func (gs *GobStore) GetDevice(userID id.UserID, deviceID id.DeviceID) (*id.Device, error) {
func (gs *MemoryStore) GetDevice(userID id.UserID, deviceID id.DeviceID) (*id.Device, error) {
gs.lock.RLock()
defer gs.lock.RUnlock()
devices, ok := gs.Devices[userID]
@@ -433,7 +461,7 @@ func (gs *GobStore) GetDevice(userID id.UserID, deviceID id.DeviceID) (*id.Devic
return device, nil
}
func (gs *GobStore) FindDeviceByKey(userID id.UserID, identityKey id.IdentityKey) (*id.Device, error) {
func (gs *MemoryStore) FindDeviceByKey(userID id.UserID, identityKey id.IdentityKey) (*id.Device, error) {
gs.lock.RLock()
defer gs.lock.RUnlock()
devices, ok := gs.Devices[userID]
@@ -448,7 +476,7 @@ func (gs *GobStore) FindDeviceByKey(userID id.UserID, identityKey id.IdentityKey
return nil, nil
}
func (gs *GobStore) PutDevice(userID id.UserID, device *id.Device) error {
func (gs *MemoryStore) PutDevice(userID id.UserID, device *id.Device) error {
gs.lock.Lock()
devices, ok := gs.Devices[userID]
if !ok {
@@ -461,7 +489,7 @@ func (gs *GobStore) PutDevice(userID id.UserID, device *id.Device) error {
return err
}
func (gs *GobStore) PutDevices(userID id.UserID, devices map[id.DeviceID]*id.Device) error {
func (gs *MemoryStore) PutDevices(userID id.UserID, devices map[id.DeviceID]*id.Device) error {
gs.lock.Lock()
gs.Devices[userID] = devices
err := gs.save()
@@ -469,7 +497,7 @@ func (gs *GobStore) PutDevices(userID id.UserID, devices map[id.DeviceID]*id.Dev
return err
}
func (gs *GobStore) FilterTrackedUsers(users []id.UserID) ([]id.UserID, error) {
func (gs *MemoryStore) FilterTrackedUsers(users []id.UserID) ([]id.UserID, error) {
gs.lock.RLock()
var ptr int
for _, userID := range users {
@@ -483,7 +511,7 @@ func (gs *GobStore) FilterTrackedUsers(users []id.UserID) ([]id.UserID, error) {
return users[:ptr], nil
}
func (gs *GobStore) PutCrossSigningKey(userID id.UserID, usage id.CrossSigningUsage, key id.Ed25519) error {
func (gs *MemoryStore) PutCrossSigningKey(userID id.UserID, usage id.CrossSigningUsage, key id.Ed25519) error {
gs.lock.RLock()
userKeys, ok := gs.CrossSigningKeys[userID]
if !ok {
@@ -505,7 +533,7 @@ func (gs *GobStore) PutCrossSigningKey(userID id.UserID, usage id.CrossSigningUs
return err
}
func (gs *GobStore) GetCrossSigningKeys(userID id.UserID) (map[id.CrossSigningUsage]id.CrossSigningKey, error) {
func (gs *MemoryStore) GetCrossSigningKeys(userID id.UserID) (map[id.CrossSigningUsage]id.CrossSigningKey, error) {
gs.lock.RLock()
defer gs.lock.RUnlock()
keys, ok := gs.CrossSigningKeys[userID]
@@ -515,7 +543,7 @@ func (gs *GobStore) GetCrossSigningKeys(userID id.UserID) (map[id.CrossSigningUs
return keys, nil
}
func (gs *GobStore) PutSignature(signedUserID id.UserID, signedKey id.Ed25519, signerUserID id.UserID, signerKey id.Ed25519, signature string) error {
func (gs *MemoryStore) PutSignature(signedUserID id.UserID, signedKey id.Ed25519, signerUserID id.UserID, signerKey id.Ed25519, signature string) error {
gs.lock.RLock()
signedUserSigs, ok := gs.KeySignatures[signedUserID]
if !ok {
@@ -538,7 +566,7 @@ func (gs *GobStore) PutSignature(signedUserID id.UserID, signedKey id.Ed25519, s
return err
}
func (gs *GobStore) GetSignaturesForKeyBy(userID id.UserID, key id.Ed25519, signerID id.UserID) (map[id.Ed25519]string, error) {
func (gs *MemoryStore) GetSignaturesForKeyBy(userID id.UserID, key id.Ed25519, signerID id.UserID) (map[id.Ed25519]string, error) {
gs.lock.RLock()
defer gs.lock.RUnlock()
userKeys, ok := gs.KeySignatures[userID]
@@ -556,7 +584,7 @@ func (gs *GobStore) GetSignaturesForKeyBy(userID id.UserID, key id.Ed25519, sign
return sigsBySigner, nil
}
func (gs *GobStore) IsKeySignedBy(userID id.UserID, key id.Ed25519, signerID id.UserID, signerKey id.Ed25519) (bool, error) {
func (gs *MemoryStore) IsKeySignedBy(userID id.UserID, key id.Ed25519, signerID id.UserID, signerKey id.Ed25519) (bool, error) {
sigs, err := gs.GetSignaturesForKeyBy(userID, key, signerID)
if err != nil {
return false, err
@@ -565,7 +593,7 @@ func (gs *GobStore) IsKeySignedBy(userID id.UserID, key id.Ed25519, signerID id.
return ok, nil
}
func (gs *GobStore) DropSignaturesByKey(userID id.UserID, key id.Ed25519) (int64, error) {
func (gs *MemoryStore) DropSignaturesByKey(userID id.UserID, key id.Ed25519) (int64, error) {
var count int64
gs.lock.RLock()
for _, userSigs := range gs.KeySignatures {

View File

@@ -4,9 +4,6 @@
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
//go:build !nosas
// +build !nosas
package crypto
import (
@@ -92,13 +89,13 @@ func (mach *OlmMachine) getPKAndKeysMAC(sas *olm.SAS, sendingUser id.UserID, sen
if err != nil {
return "", "", err
}
mach.Log.Trace("sas.CalculateMAC(\"%s\", \"%s\") -> \"%s\"", signingKey, sasInfo+mainKeyID.String(), string(pubKeyMac))
mach.Log.Trace().Msgf("sas.CalculateMAC(\"%s\", \"%s\") -> \"%s\"", signingKey, sasInfo+mainKeyID.String(), string(pubKeyMac))
keysMac, err := sas.CalculateMAC([]byte(keyIDString), []byte(sasInfo+"KEY_IDS"))
if err != nil {
return "", "", err
}
mach.Log.Trace("sas.CalculateMAC(\"%s\", \"%s\") -> \"%s\"", keyIDString, sasInfo+"KEY_IDS", string(keysMac))
mach.Log.Trace().Msgf("sas.CalculateMAC(\"%s\", \"%s\") -> \"%s\"", keyIDString, sasInfo+"KEY_IDS", string(keysMac))
return string(pubKeyMac), string(keysMac), nil
}
@@ -144,14 +141,14 @@ func (mach *OlmMachine) getTransactionState(transactionID string, userID id.User
// handleVerificationStart handles an incoming m.key.verification.start message.
// It initializes the state for this SAS verification process and stores it.
func (mach *OlmMachine) handleVerificationStart(userID id.UserID, content *event.VerificationStartEventContent, transactionID string, timeout time.Duration, inRoomID id.RoomID) {
mach.Log.Debug("Received verification start from %v", content.FromDevice)
otherDevice, err := mach.GetOrFetchDevice(userID, content.FromDevice)
mach.Log.Debug().Msgf("Received verification start from %v", content.FromDevice)
otherDevice, err := mach.GetOrFetchDevice(context.TODO(), userID, content.FromDevice)
if err != nil {
mach.Log.Error("Could not find device %v of user %v", content.FromDevice, userID)
mach.Log.Error().Msgf("Could not find device %v of user %v", content.FromDevice, userID)
return
}
warnAndCancel := func(logReason, cancelReason string) {
mach.Log.Warn("Canceling verification transaction %v as it %s", transactionID, logReason)
mach.Log.Warn().Msgf("Canceling verification transaction %v as it %s", transactionID, logReason)
if inRoomID == "" {
_ = mach.SendSASVerificationCancel(otherDevice.UserID, otherDevice.DeviceID, transactionID, cancelReason, event.VerificationCancelUnknownMethod)
} else {
@@ -179,7 +176,7 @@ func (mach *OlmMachine) actuallyStartVerification(userID id.UserID, content *eve
if inRoomID != "" && transactionID != "" {
verState, err := mach.getTransactionState(transactionID, userID)
if err != nil {
mach.Log.Error("Failed to get transaction state for in-room verification %s start: %v", transactionID, err)
mach.Log.Error().Msgf("Failed to get transaction state for in-room verification %s start: %v", transactionID, err)
_ = mach.SendInRoomSASVerificationCancel(inRoomID, otherDevice.UserID, transactionID, "Internal state error in gomuks :(", "net.maunium.internal_error")
return
}
@@ -187,7 +184,7 @@ func (mach *OlmMachine) actuallyStartVerification(userID id.UserID, content *eve
sasMethods := commonSASMethods(verState.hooks, content.ShortAuthenticationString)
err = mach.SendInRoomSASVerificationAccept(inRoomID, userID, content, transactionID, verState.sas.GetPubkey(), sasMethods)
if err != nil {
mach.Log.Error("Error accepting in-room SAS verification: %v", err)
mach.Log.Error().Msgf("Error accepting in-room SAS verification: %v", err)
}
verState.chosenSASMethod = sasMethods[0]
verState.verificationStarted = true
@@ -197,7 +194,7 @@ func (mach *OlmMachine) actuallyStartVerification(userID id.UserID, content *eve
if resp == AcceptRequest {
sasMethods := commonSASMethods(hooks, content.ShortAuthenticationString)
if len(sasMethods) == 0 {
mach.Log.Error("No common SAS methods: %v", content.ShortAuthenticationString)
mach.Log.Error().Msgf("No common SAS methods: %v", content.ShortAuthenticationString)
if inRoomID == "" {
_ = mach.SendSASVerificationCancel(otherDevice.UserID, otherDevice.DeviceID, transactionID, "No common SAS methods", event.VerificationCancelUnknownMethod)
} else {
@@ -222,7 +219,7 @@ func (mach *OlmMachine) actuallyStartVerification(userID id.UserID, content *eve
_, loaded := mach.keyVerificationTransactionState.LoadOrStore(userID.String()+":"+transactionID, verState)
if loaded {
// transaction already exists
mach.Log.Error("Transaction %v already exists, canceling", transactionID)
mach.Log.Error().Msgf("Transaction %v already exists, canceling", transactionID)
if inRoomID == "" {
_ = mach.SendSASVerificationCancel(otherDevice.UserID, otherDevice.DeviceID, transactionID, "Transaction already exists", event.VerificationCancelUnexpectedMessage)
} else {
@@ -240,10 +237,10 @@ func (mach *OlmMachine) actuallyStartVerification(userID id.UserID, content *eve
err = mach.SendInRoomSASVerificationAccept(inRoomID, userID, content, transactionID, verState.sas.GetPubkey(), sasMethods)
}
if err != nil {
mach.Log.Error("Error accepting SAS verification: %v", err)
mach.Log.Error().Msgf("Error accepting SAS verification: %v", err)
}
} else if resp == RejectRequest {
mach.Log.Debug("Not accepting SAS verification %v from %v of user %v", transactionID, otherDevice.DeviceID, otherDevice.UserID)
mach.Log.Debug().Msgf("Not accepting SAS verification %v from %v of user %v", transactionID, otherDevice.DeviceID, otherDevice.UserID)
var err error
if inRoomID == "" {
err = mach.SendSASVerificationCancel(otherDevice.UserID, otherDevice.DeviceID, transactionID, "Not accepted by user", event.VerificationCancelByUser)
@@ -251,10 +248,10 @@ func (mach *OlmMachine) actuallyStartVerification(userID id.UserID, content *eve
err = mach.SendInRoomSASVerificationCancel(inRoomID, otherDevice.UserID, transactionID, "Not accepted by user", event.VerificationCancelByUser)
}
if err != nil {
mach.Log.Error("Error canceling SAS verification: %v", err)
mach.Log.Error().Msgf("Error canceling SAS verification: %v", err)
}
} else {
mach.Log.Debug("Ignoring SAS verification %v from %v of user %v", transactionID, otherDevice.DeviceID, otherDevice.UserID)
mach.Log.Debug().Msgf("Ignoring SAS verification %v from %v of user %v", transactionID, otherDevice.DeviceID, otherDevice.UserID)
}
}
@@ -276,12 +273,12 @@ func (mach *OlmMachine) timeoutAfter(verState *verificationState, transactionID
// if deadline exceeded cancel due to timeout
mach.keyVerificationTransactionState.Delete(mapKey)
_ = mach.callbackAndCancelSASVerification(verState, transactionID, "Timed out", event.VerificationCancelByTimeout)
mach.Log.Warn("Verification transaction %v is canceled due to timing out", transactionID)
mach.Log.Warn().Msgf("Verification transaction %v is canceled due to timing out", transactionID)
verState.lock.Unlock()
return
}
// otherwise the cancel func was called, so the timeout is reset
mach.Log.Debug("Extending timeout for transaction %v", transactionID)
mach.Log.Debug().Msgf("Extending timeout for transaction %v", transactionID)
timeoutCtx, timeoutCancel = context.WithTimeout(context.Background(), timeout)
verState.extendTimeout = timeoutCancel
verState.lock.Unlock()
@@ -292,10 +289,10 @@ func (mach *OlmMachine) timeoutAfter(verState *verificationState, transactionID
// handleVerificationAccept handles an incoming m.key.verification.accept message.
// It continues the SAS verification process by sending the SAS key message to the other device.
func (mach *OlmMachine) handleVerificationAccept(userID id.UserID, content *event.VerificationAcceptEventContent, transactionID string) {
mach.Log.Debug("Received verification accept for transaction %v", transactionID)
mach.Log.Debug().Msgf("Received verification accept for transaction %v", transactionID)
verState, err := mach.getTransactionState(transactionID, userID)
if err != nil {
mach.Log.Error("Error getting transaction state: %v", err)
mach.Log.Error().Msgf("Error getting transaction state: %v", err)
return
}
verState.lock.Lock()
@@ -304,7 +301,7 @@ func (mach *OlmMachine) handleVerificationAccept(userID id.UserID, content *even
if !verState.initiatedByUs || verState.verificationStarted {
// unexpected accept at this point
mach.Log.Warn("Unexpected verification accept message for transaction %v", transactionID)
mach.Log.Warn().Msgf("Unexpected verification accept message for transaction %v", transactionID)
mach.keyVerificationTransactionState.Delete(userID.String() + ":" + transactionID)
_ = mach.callbackAndCancelSASVerification(verState, transactionID, "Unexpected accept message", event.VerificationCancelUnexpectedMessage)
return
@@ -316,7 +313,7 @@ func (mach *OlmMachine) handleVerificationAccept(userID id.UserID, content *even
content.MessageAuthenticationCode != event.HKDFHMACSHA256 ||
len(sasMethods) == 0 {
mach.Log.Warn("Canceling verification transaction %v due to unknown parameter", transactionID)
mach.Log.Warn().Msgf("Canceling verification transaction %v due to unknown parameter", transactionID)
mach.keyVerificationTransactionState.Delete(userID.String() + ":" + transactionID)
_ = mach.callbackAndCancelSASVerification(verState, transactionID, "Verification uses unknown method", event.VerificationCancelUnknownMethod)
return
@@ -333,7 +330,7 @@ func (mach *OlmMachine) handleVerificationAccept(userID id.UserID, content *even
err = mach.SendInRoomSASVerificationKey(verState.inRoomID, userID, transactionID, string(key))
}
if err != nil {
mach.Log.Error("Error sending SAS key to other device: %v", err)
mach.Log.Error().Msgf("Error sending SAS key to other device: %v", err)
return
}
}
@@ -341,10 +338,10 @@ func (mach *OlmMachine) handleVerificationAccept(userID id.UserID, content *even
// handleVerificationKey handles an incoming m.key.verification.key message.
// It stores the other device's public key in order to acquire the SAS shared secret.
func (mach *OlmMachine) handleVerificationKey(userID id.UserID, content *event.VerificationKeyEventContent, transactionID string) {
mach.Log.Debug("Got verification key for transaction %v: %v", transactionID, content.Key)
mach.Log.Debug().Msgf("Got verification key for transaction %v: %v", transactionID, content.Key)
verState, err := mach.getTransactionState(transactionID, userID)
if err != nil {
mach.Log.Error("Error getting transaction state: %v", err)
mach.Log.Error().Msgf("Error getting transaction state: %v", err)
return
}
verState.lock.Lock()
@@ -355,14 +352,14 @@ func (mach *OlmMachine) handleVerificationKey(userID id.UserID, content *event.V
if !verState.verificationStarted || verState.keyReceived {
// unexpected key at this point
mach.Log.Warn("Unexpected verification key message for transaction %v", transactionID)
mach.Log.Warn().Msgf("Unexpected verification key message for transaction %v", transactionID)
mach.keyVerificationTransactionState.Delete(userID.String() + ":" + transactionID)
_ = mach.callbackAndCancelSASVerification(verState, transactionID, "Unexpected key message", event.VerificationCancelUnexpectedMessage)
return
}
if err := verState.sas.SetTheirKey([]byte(content.Key)); err != nil {
mach.Log.Error("Error setting other device's key: %v", err)
mach.Log.Error().Msgf("Error setting other device's key: %v", err)
return
}
@@ -371,9 +368,9 @@ func (mach *OlmMachine) handleVerificationKey(userID id.UserID, content *event.V
if verState.initiatedByUs {
// verify commitment string from accept message now
expectedCommitment := olm.NewUtility().Sha256(content.Key + verState.startEventCanonical)
mach.Log.Debug("Received commitment: %v Expected: %v", verState.commitment, expectedCommitment)
mach.Log.Debug().Msgf("Received commitment: %v Expected: %v", verState.commitment, expectedCommitment)
if expectedCommitment != verState.commitment {
mach.Log.Warn("Canceling verification transaction %v due to commitment mismatch", transactionID)
mach.Log.Warn().Msgf("Canceling verification transaction %v due to commitment mismatch", transactionID)
mach.keyVerificationTransactionState.Delete(userID.String() + ":" + transactionID)
_ = mach.callbackAndCancelSASVerification(verState, transactionID, "Commitment mismatch", event.VerificationCancelCommitmentMismatch)
return
@@ -388,7 +385,7 @@ func (mach *OlmMachine) handleVerificationKey(userID id.UserID, content *event.V
err = mach.SendInRoomSASVerificationKey(verState.inRoomID, userID, transactionID, string(key))
}
if err != nil {
mach.Log.Error("Error sending SAS key to other device: %v", err)
mach.Log.Error().Msgf("Error sending SAS key to other device: %v", err)
return
}
}
@@ -416,10 +413,10 @@ func (mach *OlmMachine) handleVerificationKey(userID id.UserID, content *event.V
sasMethod := verState.chosenSASMethod
sas, err := sasMethod.GetVerificationSAS(initUserID, initDeviceID, initKey, acceptUserID, acceptDeviceID, acceptKey, transactionID, verState.sas)
if err != nil {
mach.Log.Error("Error generating SAS (method %v): %v", sasMethod.Type(), err)
mach.Log.Error().Msgf("Error generating SAS (method %v): %v", sasMethod.Type(), err)
return
}
mach.Log.Debug("Generated SAS (%v): %v", sasMethod.Type(), sas)
mach.Log.Debug().Msgf("Generated SAS (%v): %v", sasMethod.Type(), sas)
go func() {
result := verState.hooks.VerifySASMatch(device, sas)
mach.sasCompared(result, transactionID, verState)
@@ -441,7 +438,7 @@ func (mach *OlmMachine) sasCompared(didMatch bool, transactionID string, verStat
err = mach.SendInRoomSASVerificationMAC(verState.inRoomID, verState.otherDevice.UserID, verState.otherDevice.DeviceID, transactionID, verState.sas)
}
if err != nil {
mach.Log.Error("Error sending verification MAC to other device: %v", err)
mach.Log.Error().Msgf("Error sending verification MAC to other device: %v", err)
}
} else {
verState.sasMatched <- false
@@ -451,10 +448,10 @@ func (mach *OlmMachine) sasCompared(didMatch bool, transactionID string, verStat
// handleVerificationMAC handles an incoming m.key.verification.mac message.
// It verifies the other device's MAC and if the MAC is valid it marks the device as trusted.
func (mach *OlmMachine) handleVerificationMAC(userID id.UserID, content *event.VerificationMacEventContent, transactionID string) {
mach.Log.Debug("Got MAC for verification %v: %v, MAC for keys: %v", transactionID, content.Mac, content.Keys)
mach.Log.Debug().Msgf("Got MAC for verification %v: %v, MAC for keys: %v", transactionID, content.Mac, content.Keys)
verState, err := mach.getTransactionState(transactionID, userID)
if err != nil {
mach.Log.Error("Error getting transaction state: %v", err)
mach.Log.Error().Msgf("Error getting transaction state: %v", err)
return
}
verState.lock.Lock()
@@ -468,7 +465,7 @@ func (mach *OlmMachine) handleVerificationMAC(userID id.UserID, content *event.V
if !verState.verificationStarted || !verState.keyReceived {
// unexpected MAC at this point
mach.Log.Warn("Unexpected MAC message for transaction %v", transactionID)
mach.Log.Warn().Msgf("Unexpected MAC message for transaction %v", transactionID)
_ = mach.callbackAndCancelSASVerification(verState, transactionID, "Unexpected MAC message", event.VerificationCancelUnexpectedMessage)
return
}
@@ -480,7 +477,7 @@ func (mach *OlmMachine) handleVerificationMAC(userID id.UserID, content *event.V
defer verState.lock.Unlock()
if !matched {
mach.Log.Warn("SAS do not match! Canceling transaction %v", transactionID)
mach.Log.Warn().Msgf("SAS do not match! Canceling transaction %v", transactionID)
_ = mach.callbackAndCancelSASVerification(verState, transactionID, "SAS do not match", event.VerificationCancelSASMismatch)
return
}
@@ -490,20 +487,20 @@ func (mach *OlmMachine) handleVerificationMAC(userID id.UserID, content *event.V
expectedPKMAC, expectedKeysMAC, err := mach.getPKAndKeysMAC(verState.sas, device.UserID, device.DeviceID,
mach.Client.UserID, mach.Client.DeviceID, transactionID, device.SigningKey, keyID, content.Mac)
if err != nil {
mach.Log.Error("Error generating MAC to match with received MAC: %v", err)
mach.Log.Error().Msgf("Error generating MAC to match with received MAC: %v", err)
return
}
mach.Log.Debug("Expected %s keys MAC, got %s", expectedKeysMAC, content.Keys)
mach.Log.Debug().Msgf("Expected %s keys MAC, got %s", expectedKeysMAC, content.Keys)
if content.Keys != expectedKeysMAC {
mach.Log.Warn("Canceling verification transaction %v due to mismatched keys MAC", transactionID)
mach.Log.Warn().Msgf("Canceling verification transaction %v due to mismatched keys MAC", transactionID)
_ = mach.callbackAndCancelSASVerification(verState, transactionID, "Mismatched keys MACs", event.VerificationCancelKeyMismatch)
return
}
mach.Log.Debug("Expected %s PK MAC, got %s", expectedPKMAC, content.Mac[keyID])
mach.Log.Debug().Msgf("Expected %s PK MAC, got %s", expectedPKMAC, content.Mac[keyID])
if content.Mac[keyID] != expectedPKMAC {
mach.Log.Warn("Canceling verification transaction %v due to mismatched PK MAC", transactionID)
mach.Log.Warn().Msgf("Canceling verification transaction %v due to mismatched PK MAC", transactionID)
_ = mach.callbackAndCancelSASVerification(verState, transactionID, "Mismatched PK MACs", event.VerificationCancelKeyMismatch)
return
}
@@ -512,35 +509,35 @@ func (mach *OlmMachine) handleVerificationMAC(userID id.UserID, content *event.V
device.Trust = id.TrustStateVerified
err = mach.CryptoStore.PutDevice(device.UserID, device)
if err != nil {
mach.Log.Warn("Failed to put device after verifying: %v", err)
mach.Log.Warn().Msgf("Failed to put device after verifying: %v", err)
}
if mach.CrossSigningKeys != nil {
if device.UserID == mach.Client.UserID {
err := mach.SignOwnDevice(device)
if err != nil {
mach.Log.Error("Failed to cross-sign own device %s: %v", device.DeviceID, err)
mach.Log.Error().Msgf("Failed to cross-sign own device %s: %v", device.DeviceID, err)
} else {
mach.Log.Debug("Cross-signed own device %v after SAS verification", device.DeviceID)
mach.Log.Debug().Msgf("Cross-signed own device %v after SAS verification", device.DeviceID)
}
} else {
masterKey, err := mach.fetchMasterKey(device, content, verState, transactionID)
if err != nil {
mach.Log.Warn("Failed to fetch %s's master key: %v", device.UserID, err)
mach.Log.Warn().Msgf("Failed to fetch %s's master key: %v", device.UserID, err)
} else {
if err := mach.SignUser(device.UserID, masterKey); err != nil {
mach.Log.Error("Failed to cross-sign master key of %s: %v", device.UserID, err)
mach.Log.Error().Msgf("Failed to cross-sign master key of %s: %v", device.UserID, err)
} else {
mach.Log.Debug("Cross-signed master key of %v after SAS verification", device.UserID)
mach.Log.Debug().Msgf("Cross-signed master key of %v after SAS verification", device.UserID)
}
}
}
} else {
// TODO ask user to unlock cross-signing keys?
mach.Log.Debug("Cross-signing keys not cached, not signing %s/%s", device.UserID, device.DeviceID)
mach.Log.Debug().Msgf("Cross-signing keys not cached, not signing %s/%s", device.UserID, device.DeviceID)
}
mach.Log.Debug("Device %v of user %v verified successfully!", device.DeviceID, device.UserID)
mach.Log.Debug().Msgf("Device %v of user %v verified successfully!", device.DeviceID, device.UserID)
verState.hooks.OnSuccess()
}()
@@ -557,20 +554,20 @@ func (mach *OlmMachine) handleVerificationCancel(userID id.UserID, content *even
}
mach.keyVerificationTransactionState.Delete(userID.String() + ":" + transactionID)
mach.Log.Warn("SAS verification %v was canceled by %v with reason: %v (%v)",
mach.Log.Warn().Msgf("SAS verification %v was canceled by %v with reason: %v (%v)",
transactionID, userID, content.Reason, content.Code)
}
// handleVerificationRequest handles an incoming m.key.verification.request message.
func (mach *OlmMachine) handleVerificationRequest(userID id.UserID, content *event.VerificationRequestEventContent, transactionID string, inRoomID id.RoomID) {
mach.Log.Debug("Received verification request from %v", content.FromDevice)
otherDevice, err := mach.GetOrFetchDevice(userID, content.FromDevice)
mach.Log.Debug().Msgf("Received verification request from %v", content.FromDevice)
otherDevice, err := mach.GetOrFetchDevice(context.TODO(), userID, content.FromDevice)
if err != nil {
mach.Log.Error("Could not find device %v of user %v", content.FromDevice, userID)
mach.Log.Error().Msgf("Could not find device %v of user %v", content.FromDevice, userID)
return
}
if !content.SupportsVerificationMethod(event.VerificationMethodSAS) {
mach.Log.Warn("Canceling verification transaction %v as SAS is not supported", transactionID)
mach.Log.Warn().Msgf("Canceling verification transaction %v as SAS is not supported", transactionID)
if inRoomID == "" {
_ = mach.SendSASVerificationCancel(otherDevice.UserID, otherDevice.DeviceID, transactionID, "Only SAS method is supported", event.VerificationCancelUnknownMethod)
} else {
@@ -580,12 +577,12 @@ func (mach *OlmMachine) handleVerificationRequest(userID id.UserID, content *eve
}
resp, hooks := mach.AcceptVerificationFrom(transactionID, otherDevice, inRoomID)
if resp == AcceptRequest {
mach.Log.Debug("Accepting SAS verification %v from %v of user %v", transactionID, otherDevice.DeviceID, otherDevice.UserID)
mach.Log.Debug().Msgf("Accepting SAS verification %v from %v of user %v", transactionID, otherDevice.DeviceID, otherDevice.UserID)
if inRoomID == "" {
_, err = mach.NewSASVerificationWith(otherDevice, hooks, transactionID, mach.DefaultSASTimeout)
} else {
if err := mach.SendInRoomSASVerificationReady(inRoomID, transactionID); err != nil {
mach.Log.Error("Error sending in-room SAS verification ready: %v", err)
mach.Log.Error().Msgf("Error sending in-room SAS verification ready: %v", err)
}
if mach.Client.UserID < otherDevice.UserID {
// up to us to send the start message
@@ -593,17 +590,17 @@ func (mach *OlmMachine) handleVerificationRequest(userID id.UserID, content *eve
}
}
if err != nil {
mach.Log.Error("Error accepting SAS verification request: %v", err)
mach.Log.Error().Msgf("Error accepting SAS verification request: %v", err)
}
} else if resp == RejectRequest {
mach.Log.Debug("Rejecting SAS verification %v from %v of user %v", transactionID, otherDevice.DeviceID, otherDevice.UserID)
mach.Log.Debug().Msgf("Rejecting SAS verification %v from %v of user %v", transactionID, otherDevice.DeviceID, otherDevice.UserID)
if inRoomID == "" {
_ = mach.SendSASVerificationCancel(otherDevice.UserID, otherDevice.DeviceID, transactionID, "Not accepted by user", event.VerificationCancelByUser)
} else {
_ = mach.SendInRoomSASVerificationCancel(inRoomID, otherDevice.UserID, transactionID, "Not accepted by user", event.VerificationCancelByUser)
}
} else {
mach.Log.Debug("Ignoring SAS verification %v from %v of user %v", transactionID, otherDevice.DeviceID, otherDevice.UserID)
mach.Log.Debug().Msgf("Ignoring SAS verification %v from %v of user %v", transactionID, otherDevice.DeviceID, otherDevice.UserID)
}
}
@@ -620,7 +617,7 @@ func (mach *OlmMachine) NewSASVerificationWith(device *id.Device, hooks Verifica
if transactionID == "" {
transactionID = strconv.Itoa(rand.Int())
}
mach.Log.Debug("Starting new verification transaction %v with device %v of user %v", transactionID, device.DeviceID, device.UserID)
mach.Log.Debug().Msgf("Starting new verification transaction %v with device %v of user %v", transactionID, device.DeviceID, device.UserID)
verState := &verificationState{
sas: olm.NewSAS(),
@@ -669,7 +666,7 @@ func (mach *OlmMachine) CancelSASVerification(userID id.UserID, transactionID, r
verState := verStateInterface.(*verificationState)
verState.lock.Lock()
defer verState.lock.Unlock()
mach.Log.Trace("User canceled verification transaction %v with reason: %v", transactionID, reason)
mach.Log.Trace().Msgf("User canceled verification transaction %v with reason: %v", transactionID, reason)
mach.keyVerificationTransactionState.Delete(mapKey)
return mach.callbackAndCancelSASVerification(verState, transactionID, reason, event.VerificationCancelByUser)
}
@@ -766,9 +763,9 @@ func (mach *OlmMachine) SendSASVerificationMAC(userID id.UserID, deviceID id.Dev
masterKeyMAC, _, err := mach.getPKAndKeysMAC(sas, mach.Client.UserID, mach.Client.DeviceID,
userID, deviceID, transactionID, masterKey, masterKeyID, keyIDsMap)
if err != nil {
mach.Log.Error("Error generating master key MAC: %v", err)
mach.Log.Error().Msgf("Error generating master key MAC: %v", err)
} else {
mach.Log.Debug("Generated master key `%v` MAC: %v", masterKey, masterKeyMAC)
mach.Log.Debug().Msgf("Generated master key `%v` MAC: %v", masterKey, masterKeyMAC)
macMap[masterKeyID] = masterKeyMAC
}
}
@@ -777,8 +774,8 @@ func (mach *OlmMachine) SendSASVerificationMAC(userID id.UserID, deviceID id.Dev
if err != nil {
return err
}
mach.Log.Debug("MAC of key %s is: %s", signingKey, pubKeyMac)
mach.Log.Debug("MAC of key ID(s) %s is: %s", keyID, keysMac)
mach.Log.Debug().Msgf("MAC of key %s is: %s", signingKey, pubKeyMac)
mach.Log.Debug().Msgf("MAC of key ID(s) %s is: %s", keyID, keysMac)
macMap[keyID] = pubKeyMac
content := &event.VerificationMacEventContent{

View File

@@ -7,6 +7,7 @@
package crypto
import (
"context"
"encoding/json"
"errors"
"time"
@@ -80,7 +81,7 @@ func (mach *OlmMachine) SendInRoomSASVerificationCancel(roomID id.RoomID, userID
To: userID,
}
encrypted, err := mach.EncryptMegolmEvent(roomID, event.InRoomVerificationCancel, content)
encrypted, err := mach.EncryptMegolmEvent(context.TODO(), roomID, event.InRoomVerificationCancel, content)
if err != nil {
return err
}
@@ -97,7 +98,7 @@ func (mach *OlmMachine) SendInRoomSASVerificationRequest(roomID id.RoomID, toUse
To: toUserID,
}
encrypted, err := mach.EncryptMegolmEvent(roomID, event.EventMessage, content)
encrypted, err := mach.EncryptMegolmEvent(context.TODO(), roomID, event.EventMessage, content)
if err != nil {
return "", err
}
@@ -116,7 +117,7 @@ func (mach *OlmMachine) SendInRoomSASVerificationReady(roomID id.RoomID, transac
RelatesTo: &event.RelatesTo{Type: event.RelReference, EventID: id.EventID(transactionID)},
}
encrypted, err := mach.EncryptMegolmEvent(roomID, event.InRoomVerificationReady, content)
encrypted, err := mach.EncryptMegolmEvent(context.TODO(), roomID, event.InRoomVerificationReady, content)
if err != nil {
return err
}
@@ -141,7 +142,7 @@ func (mach *OlmMachine) SendInRoomSASVerificationStart(roomID id.RoomID, toUserI
To: toUserID,
}
encrypted, err := mach.EncryptMegolmEvent(roomID, event.InRoomVerificationStart, content)
encrypted, err := mach.EncryptMegolmEvent(context.TODO(), roomID, event.InRoomVerificationStart, content)
if err != nil {
return nil, err
}
@@ -182,7 +183,7 @@ func (mach *OlmMachine) SendInRoomSASVerificationAccept(roomID id.RoomID, fromUs
To: fromUser,
}
encrypted, err := mach.EncryptMegolmEvent(roomID, event.InRoomVerificationAccept, content)
encrypted, err := mach.EncryptMegolmEvent(context.TODO(), roomID, event.InRoomVerificationAccept, content)
if err != nil {
return err
}
@@ -198,7 +199,7 @@ func (mach *OlmMachine) SendInRoomSASVerificationKey(roomID id.RoomID, userID id
To: userID,
}
encrypted, err := mach.EncryptMegolmEvent(roomID, event.InRoomVerificationKey, content)
encrypted, err := mach.EncryptMegolmEvent(context.TODO(), roomID, event.InRoomVerificationKey, content)
if err != nil {
return err
}
@@ -222,9 +223,9 @@ func (mach *OlmMachine) SendInRoomSASVerificationMAC(roomID id.RoomID, userID id
masterKeyMAC, _, err := mach.getPKAndKeysMAC(sas, mach.Client.UserID, mach.Client.DeviceID,
userID, deviceID, transactionID, masterKey, masterKeyID, keyIDsMap)
if err != nil {
mach.Log.Error("Error generating master key MAC: %v", err)
mach.Log.Error().Msgf("Error generating master key MAC: %v", err)
} else {
mach.Log.Debug("Generated master key `%v` MAC: %v", masterKey, masterKeyMAC)
mach.Log.Debug().Msgf("Generated master key `%v` MAC: %v", masterKey, masterKeyMAC)
macMap[masterKeyID] = masterKeyMAC
}
}
@@ -233,8 +234,8 @@ func (mach *OlmMachine) SendInRoomSASVerificationMAC(roomID id.RoomID, userID id
if err != nil {
return err
}
mach.Log.Debug("MAC of key %s is: %s", signingKey, pubKeyMac)
mach.Log.Debug("MAC of key ID(s) %s is: %s", keyID, keysMac)
mach.Log.Debug().Msgf("MAC of key %s is: %s", signingKey, pubKeyMac)
mach.Log.Debug().Msgf("MAC of key ID(s) %s is: %s", keyID, keysMac)
macMap[keyID] = pubKeyMac
content := &event.VerificationMacEventContent{
@@ -244,7 +245,7 @@ func (mach *OlmMachine) SendInRoomSASVerificationMAC(roomID id.RoomID, userID id
To: userID,
}
encrypted, err := mach.EncryptMegolmEvent(roomID, event.InRoomVerificationMAC, content)
encrypted, err := mach.EncryptMegolmEvent(context.TODO(), roomID, event.InRoomVerificationMAC, content)
if err != nil {
return err
}
@@ -259,7 +260,7 @@ func (mach *OlmMachine) NewInRoomSASVerificationWith(inRoomID id.RoomID, userID
}
func (mach *OlmMachine) newInRoomSASVerificationWithInner(inRoomID id.RoomID, device *id.Device, hooks VerificationHooks, transactionID string, timeout time.Duration) (string, error) {
mach.Log.Debug("Starting new in-room verification transaction user %v", device.UserID)
mach.Log.Debug().Msgf("Starting new in-room verification transaction user %v", device.UserID)
request := transactionID == ""
if request {
@@ -310,15 +311,15 @@ func (mach *OlmMachine) newInRoomSASVerificationWithInner(inRoomID id.RoomID, de
}
func (mach *OlmMachine) handleInRoomVerificationReady(userID id.UserID, roomID id.RoomID, content *event.VerificationReadyEventContent, transactionID string) {
device, err := mach.GetOrFetchDevice(userID, content.FromDevice)
device, err := mach.GetOrFetchDevice(context.TODO(), userID, content.FromDevice)
if err != nil {
mach.Log.Error("Error fetching device %v of user %v: %v", content.FromDevice, userID, err)
mach.Log.Error().Msgf("Error fetching device %v of user %v: %v", content.FromDevice, userID, err)
return
}
verState, err := mach.getTransactionState(transactionID, userID)
if err != nil {
mach.Log.Error("Error getting transaction state: %v", err)
mach.Log.Error().Msgf("Error getting transaction state: %v", err)
return
}
//mach.keyVerificationTransactionState.Delete(userID.String() + ":" + transactionID)

View File

@@ -4,9 +4,6 @@
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
//go:build !nosas
// +build !nosas
package crypto
import (