refactor to mautrix 0.17.x; update deps

This commit is contained in:
Aine
2024-02-11 20:47:04 +02:00
parent 0a9701f4c9
commit dd0ad4c245
237 changed files with 9091 additions and 3317 deletions

View File

@@ -12,4 +12,8 @@ repos:
rev: v1.0.0-rc.1
hooks:
- id: go-imports-repo
args:
- "-local"
- "maunium.net/go/mautrix"
- "-w"
- id: go-vet-repo-mod

View File

@@ -1,3 +1,27 @@
## v0.17.0 (2024-01-16)
* **Breaking change *(bridge)*** Added raw event to portal membership handling
functions.
* **Breaking change *(everything)*** Added context parameters to all functions
(started by [@recht] in [#144]).
* **Breaking change *(client)*** Moved `EventSource` to `event.Source`.
* *(client)* Removed deprecated `OldEventIgnorer`. The non-deprecated version
(`Client.DontProcessOldEvents`) is still available.
* *(crypto)* Added experimental pure Go Olm implementation to replace libolm
(thanks to [@DerLukas15] in [#106]).
* You can use the `goolm` build tag to the new implementation.
* *(bridge)* Added context parameter for bridge command events.
* *(bridge)* Added method to allow custom validation for the entire config.
* *(client)* Changed default syncer to not drop unknown events.
* The syncer will still drop known events if parsing the content fails.
* The behavior can be changed by changing the `ParseErrorHandler` function.
* *(crypto)* Fixed some places using math/rand instead of crypto/rand.
[@DerLukas15]: https://github.com/DerLukas15
[@recht]: https://github.com/recht
[#106]: https://github.com/mautrix/go/pull/106
[#144]: https://github.com/mautrix/go/pull/144
## v0.16.2 (2023-11-16)
* *(event)* Added `Redacts` field to `RedactionEventContent` for room v11+.

View File

@@ -17,8 +17,3 @@ In addition to the basic client API features the original project has, this fram
* Structs for parsing event content
* Helpers for parsing and generating Matrix HTML
* Helpers for handling push rules
This project contains modules that are licensed under Apache 2.0:
* [maunium.net/go/mautrix/crypto/canonicaljson](crypto/canonicaljson)
* [maunium.net/go/mautrix/crypto/olm](crypto/olm)

File diff suppressed because it is too large Load Diff

View File

@@ -1,4 +1,4 @@
// Copyright (c) 2020 Tulir Asokan
// Copyright (c) 2024 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this

View File

@@ -1,177 +0,0 @@
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS

View File

@@ -2,3 +2,5 @@
This is a Go package to produce Matrix [Canonical JSON](https://matrix.org/docs/spec/appendices#canonical-json).
It is essentially just [json.go](https://github.com/matrix-org/gomatrixserverlib/blob/master/json.go)
from gomatrixserverlib without all the other files that are completely useless for non-server use cases.
The original project is licensed under the Apache 2.0 license.

View File

@@ -8,6 +8,7 @@
package crypto
import (
"context"
"fmt"
"maunium.net/go/mautrix"
@@ -89,7 +90,7 @@ func (mach *OlmMachine) GenerateCrossSigningKeys() (*CrossSigningKeysCache, erro
}
// PublishCrossSigningKeys signs and uploads the public keys of the given cross-signing keys to the server.
func (mach *OlmMachine) PublishCrossSigningKeys(keys *CrossSigningKeysCache, uiaCallback mautrix.UIACallback) error {
func (mach *OlmMachine) PublishCrossSigningKeys(ctx context.Context, keys *CrossSigningKeysCache, uiaCallback mautrix.UIACallback) error {
userID := mach.Client.UserID
masterKeyID := id.NewKeyID(id.KeyAlgorithmEd25519, keys.MasterKey.PublicKey.String())
masterKey := mautrix.CrossSigningKeys{
@@ -134,7 +135,7 @@ func (mach *OlmMachine) PublishCrossSigningKeys(keys *CrossSigningKeysCache, uia
},
}
err = mach.Client.UploadCrossSigningKeys(&mautrix.UploadCrossSigningKeysReq{
err = mach.Client.UploadCrossSigningKeys(ctx, &mautrix.UploadCrossSigningKeysReq{
Master: masterKey,
SelfSigning: selfKey,
UserSigning: userKey,

View File

@@ -1,4 +1,4 @@
// Copyright (c) 2023 Tulir Asokan
// Copyright (c) 2024 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
@@ -7,6 +7,7 @@
package crypto
import (
"context"
"fmt"
"maunium.net/go/mautrix"
@@ -19,7 +20,7 @@ type CrossSigningPublicKeysCache struct {
UserSigningKey id.Ed25519
}
func (mach *OlmMachine) GetOwnCrossSigningPublicKeys() *CrossSigningPublicKeysCache {
func (mach *OlmMachine) GetOwnCrossSigningPublicKeys(ctx context.Context) *CrossSigningPublicKeysCache {
if mach.crossSigningPubkeys != nil {
return mach.crossSigningPubkeys
}
@@ -30,7 +31,7 @@ func (mach *OlmMachine) GetOwnCrossSigningPublicKeys() *CrossSigningPublicKeysCa
if mach.crossSigningPubkeysFetched {
return nil
}
cspk, err := mach.GetCrossSigningPublicKeys(mach.Client.UserID)
cspk, err := mach.GetCrossSigningPublicKeys(ctx, mach.Client.UserID)
if err != nil {
mach.Log.Error().Err(err).Msg("Failed to get own cross-signing public keys")
return nil
@@ -40,8 +41,8 @@ func (mach *OlmMachine) GetOwnCrossSigningPublicKeys() *CrossSigningPublicKeysCa
return mach.crossSigningPubkeys
}
func (mach *OlmMachine) GetCrossSigningPublicKeys(userID id.UserID) (*CrossSigningPublicKeysCache, error) {
dbKeys, err := mach.CryptoStore.GetCrossSigningKeys(userID)
func (mach *OlmMachine) GetCrossSigningPublicKeys(ctx context.Context, userID id.UserID) (*CrossSigningPublicKeysCache, error) {
dbKeys, err := mach.CryptoStore.GetCrossSigningKeys(ctx, userID)
if err != nil {
return nil, fmt.Errorf("failed to get keys from database: %w", err)
}
@@ -58,7 +59,7 @@ func (mach *OlmMachine) GetCrossSigningPublicKeys(userID id.UserID) (*CrossSigni
}
}
keys, err := mach.Client.QueryKeys(&mautrix.ReqQueryKeys{
keys, err := mach.Client.QueryKeys(ctx, &mautrix.ReqQueryKeys{
DeviceKeys: mautrix.DeviceKeysRequest{
userID: mautrix.DeviceIDList{},
},

View File

@@ -1,5 +1,5 @@
// Copyright (c) 2020 Nikos Filippakis
// Copyright (c) 2023 Tulir Asokan
// Copyright (c) 2024 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
@@ -8,6 +8,7 @@
package crypto
import (
"context"
"errors"
"fmt"
@@ -33,8 +34,8 @@ var (
ErrMismatchingMasterKeyMAC = errors.New("mismatching cross-signing master key MAC")
)
func (mach *OlmMachine) fetchMasterKey(device *id.Device, content *event.VerificationMacEventContent, verState *verificationState, transactionID string) (id.Ed25519, error) {
crossSignKeys, err := mach.CryptoStore.GetCrossSigningKeys(device.UserID)
func (mach *OlmMachine) fetchMasterKey(ctx context.Context, device *id.Device, content *event.VerificationMacEventContent, verState *verificationState, transactionID string) (id.Ed25519, error) {
crossSignKeys, err := mach.CryptoStore.GetCrossSigningKeys(ctx, device.UserID)
if err != nil {
return "", fmt.Errorf("failed to fetch cross-signing keys: %w", err)
}
@@ -59,7 +60,7 @@ func (mach *OlmMachine) fetchMasterKey(device *id.Device, content *event.Verific
}
// SignUser creates a cross-signing signature for a user, stores it and uploads it to the server.
func (mach *OlmMachine) SignUser(userID id.UserID, masterKey id.Ed25519) error {
func (mach *OlmMachine) SignUser(ctx context.Context, userID id.UserID, masterKey id.Ed25519) error {
if userID == mach.Client.UserID {
return ErrCantSignOwnMasterKey
} else if mach.CrossSigningKeys == nil || mach.CrossSigningKeys.UserSigningKey == nil {
@@ -74,7 +75,7 @@ func (mach *OlmMachine) SignUser(userID id.UserID, masterKey id.Ed25519) error {
},
}
signature, err := mach.signAndUpload(masterKeyObj, userID, masterKey.String(), mach.CrossSigningKeys.UserSigningKey)
signature, err := mach.signAndUpload(ctx, masterKeyObj, userID, masterKey.String(), mach.CrossSigningKeys.UserSigningKey)
if err != nil {
return err
}
@@ -84,7 +85,7 @@ func (mach *OlmMachine) SignUser(userID id.UserID, masterKey id.Ed25519) error {
Str("signature", signature).
Msg("Signed master key of user with our user-signing key")
if err := mach.CryptoStore.PutSignature(userID, masterKey, mach.Client.UserID, mach.CrossSigningKeys.UserSigningKey.PublicKey, signature); err != nil {
if err := mach.CryptoStore.PutSignature(ctx, userID, masterKey, mach.Client.UserID, mach.CrossSigningKeys.UserSigningKey.PublicKey, signature); err != nil {
return fmt.Errorf("error storing signature in crypto store: %w", err)
}
@@ -92,7 +93,7 @@ func (mach *OlmMachine) SignUser(userID id.UserID, masterKey id.Ed25519) error {
}
// SignOwnMasterKey uses the current account for signing the current user's master key and uploads the signature.
func (mach *OlmMachine) SignOwnMasterKey() error {
func (mach *OlmMachine) SignOwnMasterKey(ctx context.Context) error {
if mach.CrossSigningKeys == nil {
return ErrCrossSigningKeysNotCached
} else if mach.account == nil {
@@ -124,7 +125,7 @@ func (mach *OlmMachine) SignOwnMasterKey() error {
Str("signature", signature).
Msg("Signed own master key with own device key")
resp, err := mach.Client.UploadSignatures(&mautrix.ReqUploadSignatures{
resp, err := mach.Client.UploadSignatures(ctx, &mautrix.ReqUploadSignatures{
userID: map[string]mautrix.ReqKeysSignatures{
masterKey.String(): masterKeyObj,
},
@@ -136,7 +137,7 @@ func (mach *OlmMachine) SignOwnMasterKey() error {
return fmt.Errorf("%w: %+v", ErrSignatureUploadFail, resp.Failures)
}
if err := mach.CryptoStore.PutSignature(userID, masterKey, userID, mach.account.SigningKey(), signature); err != nil {
if err := mach.CryptoStore.PutSignature(ctx, userID, masterKey, userID, mach.account.SigningKey(), signature); err != nil {
return fmt.Errorf("error storing signature in crypto store: %w", err)
}
@@ -144,14 +145,14 @@ func (mach *OlmMachine) SignOwnMasterKey() error {
}
// SignOwnDevice creates a cross-signing signature for a device belonging to the current user and uploads it to the server.
func (mach *OlmMachine) SignOwnDevice(device *id.Device) error {
func (mach *OlmMachine) SignOwnDevice(ctx context.Context, device *id.Device) error {
if device.UserID != mach.Client.UserID {
return ErrCantSignOtherDevice
} else if mach.CrossSigningKeys == nil || mach.CrossSigningKeys.SelfSigningKey == nil {
return ErrSelfSigningKeyNotCached
}
deviceKeys, err := mach.getFullDeviceKeys(device)
deviceKeys, err := mach.getFullDeviceKeys(ctx, device)
if err != nil {
return err
}
@@ -166,7 +167,7 @@ func (mach *OlmMachine) SignOwnDevice(device *id.Device) error {
deviceKeyObj.Keys[id.KeyID(keyID)] = key
}
signature, err := mach.signAndUpload(deviceKeyObj, device.UserID, device.DeviceID.String(), mach.CrossSigningKeys.SelfSigningKey)
signature, err := mach.signAndUpload(ctx, deviceKeyObj, device.UserID, device.DeviceID.String(), mach.CrossSigningKeys.SelfSigningKey)
if err != nil {
return err
}
@@ -177,7 +178,7 @@ func (mach *OlmMachine) SignOwnDevice(device *id.Device) error {
Str("signature", signature).
Msg("Signed own device key with self-signing key")
if err := mach.CryptoStore.PutSignature(device.UserID, device.SigningKey, mach.Client.UserID, mach.CrossSigningKeys.SelfSigningKey.PublicKey, signature); err != nil {
if err := mach.CryptoStore.PutSignature(ctx, device.UserID, device.SigningKey, mach.Client.UserID, mach.CrossSigningKeys.SelfSigningKey.PublicKey, signature); err != nil {
return fmt.Errorf("error storing signature in crypto store: %w", err)
}
@@ -186,8 +187,8 @@ func (mach *OlmMachine) SignOwnDevice(device *id.Device) error {
// getFullDeviceKeys gets the full device keys object for the given device.
// This is used because we don't cache some of the details like list of algorithms and unsupported key types.
func (mach *OlmMachine) getFullDeviceKeys(device *id.Device) (*mautrix.DeviceKeys, error) {
devicesKeys, err := mach.Client.QueryKeys(&mautrix.ReqQueryKeys{
func (mach *OlmMachine) getFullDeviceKeys(ctx context.Context, device *id.Device) (*mautrix.DeviceKeys, error) {
devicesKeys, err := mach.Client.QueryKeys(ctx, &mautrix.ReqQueryKeys{
DeviceKeys: mautrix.DeviceKeysRequest{
device.UserID: mautrix.DeviceIDList{device.DeviceID},
},
@@ -208,7 +209,7 @@ func (mach *OlmMachine) getFullDeviceKeys(device *id.Device) (*mautrix.DeviceKey
}
// signAndUpload signs the given key signatures object and uploads it to the server.
func (mach *OlmMachine) signAndUpload(req mautrix.ReqKeysSignatures, userID id.UserID, signedThing string, key *olm.PkSigning) (string, error) {
func (mach *OlmMachine) signAndUpload(ctx context.Context, req mautrix.ReqKeysSignatures, userID id.UserID, signedThing string, key *olm.PkSigning) (string, error) {
signature, err := key.SignJSON(req)
if err != nil {
return "", fmt.Errorf("failed to sign JSON: %w", err)
@@ -219,7 +220,7 @@ func (mach *OlmMachine) signAndUpload(req mautrix.ReqKeysSignatures, userID id.U
},
}
resp, err := mach.Client.UploadSignatures(&mautrix.ReqUploadSignatures{
resp, err := mach.Client.UploadSignatures(ctx, &mautrix.ReqUploadSignatures{
userID: map[string]mautrix.ReqKeysSignatures{
signedThing: req,
},

View File

@@ -7,6 +7,7 @@
package crypto
import (
"context"
"fmt"
"maunium.net/go/mautrix"
@@ -16,16 +17,16 @@ import (
)
// FetchCrossSigningKeysFromSSSS fetches all the cross-signing keys from SSSS, decrypts them using the given key and stores them in the olm machine.
func (mach *OlmMachine) FetchCrossSigningKeysFromSSSS(key *ssss.Key) error {
masterKey, err := mach.retrieveDecryptXSigningKey(event.AccountDataCrossSigningMaster, key)
func (mach *OlmMachine) FetchCrossSigningKeysFromSSSS(ctx context.Context, key *ssss.Key) error {
masterKey, err := mach.retrieveDecryptXSigningKey(ctx, event.AccountDataCrossSigningMaster, key)
if err != nil {
return err
}
selfSignKey, err := mach.retrieveDecryptXSigningKey(event.AccountDataCrossSigningSelf, key)
selfSignKey, err := mach.retrieveDecryptXSigningKey(ctx, event.AccountDataCrossSigningSelf, key)
if err != nil {
return err
}
userSignKey, err := mach.retrieveDecryptXSigningKey(event.AccountDataCrossSigningUser, key)
userSignKey, err := mach.retrieveDecryptXSigningKey(ctx, event.AccountDataCrossSigningUser, key)
if err != nil {
return err
}
@@ -38,12 +39,12 @@ func (mach *OlmMachine) FetchCrossSigningKeysFromSSSS(key *ssss.Key) error {
}
// retrieveDecryptXSigningKey retrieves the requested cross-signing key from SSSS and decrypts it using the given SSSS key.
func (mach *OlmMachine) retrieveDecryptXSigningKey(keyName event.Type, key *ssss.Key) ([utils.AESCTRKeyLength]byte, error) {
func (mach *OlmMachine) retrieveDecryptXSigningKey(ctx context.Context, keyName event.Type, key *ssss.Key) ([utils.AESCTRKeyLength]byte, error) {
var decryptedKey [utils.AESCTRKeyLength]byte
var encData ssss.EncryptedAccountDataEventContent
// retrieve and parse the account data for this key type from SSSS
err := mach.Client.GetAccountData(keyName.Type, &encData)
err := mach.Client.GetAccountData(ctx, keyName.Type, &encData)
if err != nil {
return decryptedKey, err
}
@@ -62,8 +63,8 @@ func (mach *OlmMachine) retrieveDecryptXSigningKey(keyName event.Type, key *ssss
// is used. The base58-formatted recovery key is the first return parameter.
//
// The account password of the user is required for uploading keys to the server.
func (mach *OlmMachine) GenerateAndUploadCrossSigningKeys(userPassword, passphrase string) (string, error) {
key, err := mach.SSSS.GenerateAndUploadKey(passphrase)
func (mach *OlmMachine) GenerateAndUploadCrossSigningKeys(ctx context.Context, userPassword, passphrase string) (string, error) {
key, err := mach.SSSS.GenerateAndUploadKey(ctx, passphrase)
if err != nil {
return "", fmt.Errorf("failed to generate and upload SSSS key: %w", err)
}
@@ -77,12 +78,12 @@ func (mach *OlmMachine) GenerateAndUploadCrossSigningKeys(userPassword, passphra
recoveryKey := key.RecoveryKey()
// Store the private keys in SSSS
if err := mach.UploadCrossSigningKeysToSSSS(key, keysCache); err != nil {
if err := mach.UploadCrossSigningKeysToSSSS(ctx, key, keysCache); err != nil {
return recoveryKey, fmt.Errorf("failed to upload cross-signing keys to SSSS: %w", err)
}
// Publish cross-signing keys
err = mach.PublishCrossSigningKeys(keysCache, func(uiResp *mautrix.RespUserInteractive) interface{} {
err = mach.PublishCrossSigningKeys(ctx, keysCache, func(uiResp *mautrix.RespUserInteractive) interface{} {
return &mautrix.ReqUIAuthLogin{
BaseAuthData: mautrix.BaseAuthData{
Type: mautrix.AuthTypePassword,
@@ -96,7 +97,7 @@ func (mach *OlmMachine) GenerateAndUploadCrossSigningKeys(userPassword, passphra
return recoveryKey, fmt.Errorf("failed to publish cross-signing keys: %w", err)
}
err = mach.SSSS.SetDefaultKeyID(key.ID)
err = mach.SSSS.SetDefaultKeyID(ctx, key.ID)
if err != nil {
return recoveryKey, fmt.Errorf("failed to mark %s as the default key: %w", key.ID, err)
}
@@ -105,14 +106,14 @@ func (mach *OlmMachine) GenerateAndUploadCrossSigningKeys(userPassword, passphra
}
// UploadCrossSigningKeysToSSSS stores the given cross-signing keys on the server encrypted with the given key.
func (mach *OlmMachine) UploadCrossSigningKeysToSSSS(key *ssss.Key, keys *CrossSigningKeysCache) error {
if err := mach.SSSS.SetEncryptedAccountData(event.AccountDataCrossSigningMaster, keys.MasterKey.Seed, key); err != nil {
func (mach *OlmMachine) UploadCrossSigningKeysToSSSS(ctx context.Context, key *ssss.Key, keys *CrossSigningKeysCache) error {
if err := mach.SSSS.SetEncryptedAccountData(ctx, event.AccountDataCrossSigningMaster, keys.MasterKey.Seed, key); err != nil {
return err
}
if err := mach.SSSS.SetEncryptedAccountData(event.AccountDataCrossSigningSelf, keys.SelfSigningKey.Seed, key); err != nil {
if err := mach.SSSS.SetEncryptedAccountData(ctx, event.AccountDataCrossSigningSelf, keys.SelfSigningKey.Seed, key); err != nil {
return err
}
if err := mach.SSSS.SetEncryptedAccountData(event.AccountDataCrossSigningUser, keys.UserSigningKey.Seed, key); err != nil {
if err := mach.SSSS.SetEncryptedAccountData(ctx, event.AccountDataCrossSigningUser, keys.UserSigningKey.Seed, key); err != nil {
return err
}
return nil

View File

@@ -1,5 +1,5 @@
// Copyright (c) 2020 Nikos Filippakis
// Copyright (c) 2023 Tulir Asokan
// Copyright (c) 2024 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
@@ -19,7 +19,7 @@ func (mach *OlmMachine) storeCrossSigningKeys(ctx context.Context, crossSigningK
log := mach.machOrContextLog(ctx)
for userID, userKeys := range crossSigningKeys {
log := log.With().Str("user_id", userID.String()).Logger()
currentKeys, err := mach.CryptoStore.GetCrossSigningKeys(userID)
currentKeys, err := mach.CryptoStore.GetCrossSigningKeys(ctx, userID)
if err != nil {
log.Error().Err(err).
Msg("Error fetching current cross-signing keys of user")
@@ -32,7 +32,7 @@ func (mach *OlmMachine) storeCrossSigningKeys(ctx context.Context, crossSigningK
if newKeyUsage == curKeyUsage {
if _, ok := userKeys.Keys[id.NewKeyID(id.KeyAlgorithmEd25519, curKey.Key.String())]; !ok {
// old key is not in the new key map, so we drop signatures made by it
if count, err := mach.CryptoStore.DropSignaturesByKey(userID, curKey.Key); err != nil {
if count, err := mach.CryptoStore.DropSignaturesByKey(ctx, userID, curKey.Key); err != nil {
log.Error().Err(err).Msg("Error deleting old signatures made by user")
} else {
log.Debug().
@@ -50,7 +50,7 @@ func (mach *OlmMachine) storeCrossSigningKeys(ctx context.Context, crossSigningK
log := log.With().Str("key", key.String()).Strs("usages", strishArray(userKeys.Usage)).Logger()
for _, usage := range userKeys.Usage {
log.Debug().Str("usage", string(usage)).Msg("Storing cross-signing key")
if err = mach.CryptoStore.PutCrossSigningKey(userID, usage, key); err != nil {
if err = mach.CryptoStore.PutCrossSigningKey(ctx, userID, usage, key); err != nil {
log.Error().Err(err).Msg("Error storing cross-signing key")
}
}
@@ -85,7 +85,7 @@ func (mach *OlmMachine) storeCrossSigningKeys(ctx context.Context, crossSigningK
} else {
if verified {
log.Debug().Err(err).Msg("Cross-signing key signature verified")
err = mach.CryptoStore.PutSignature(userID, key, signUserID, signingKey, signature)
err = mach.CryptoStore.PutSignature(ctx, userID, key, signUserID, signingKey, signature)
if err != nil {
log.Error().Err(err).Msg("Error storing cross-signing key signature")
}

View File

@@ -1,5 +1,5 @@
// Copyright (c) 2020 Nikos Filippakis
// Copyright (c) 2023 Tulir Asokan
// Copyright (c) 2024 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
@@ -23,7 +23,7 @@ func (mach *OlmMachine) ResolveTrustContext(ctx context.Context, device *id.Devi
if device.Trust == id.TrustStateVerified || device.Trust == id.TrustStateBlacklisted {
return device.Trust, nil
}
theirKeys, err := mach.CryptoStore.GetCrossSigningKeys(device.UserID)
theirKeys, err := mach.CryptoStore.GetCrossSigningKeys(ctx, device.UserID)
if err != nil {
mach.machOrContextLog(ctx).Error().Err(err).
Str("user_id", device.UserID.String()).
@@ -44,7 +44,7 @@ func (mach *OlmMachine) ResolveTrustContext(ctx context.Context, device *id.Devi
Msg("Self-signing key of user not found")
return id.TrustStateUnset, nil
}
sskSigExists, err := mach.CryptoStore.IsKeySignedBy(device.UserID, theirSSK.Key, device.UserID, theirMSK.Key)
sskSigExists, err := mach.CryptoStore.IsKeySignedBy(ctx, device.UserID, theirSSK.Key, device.UserID, theirMSK.Key)
if err != nil {
mach.machOrContextLog(ctx).Error().Err(err).
Str("user_id", device.UserID.String()).
@@ -57,7 +57,7 @@ func (mach *OlmMachine) ResolveTrustContext(ctx context.Context, device *id.Devi
Msg("Self-signing key of user is not signed by their master key")
return id.TrustStateUnset, nil
}
deviceSigExists, err := mach.CryptoStore.IsKeySignedBy(device.UserID, device.SigningKey, device.UserID, theirSSK.Key)
deviceSigExists, err := mach.CryptoStore.IsKeySignedBy(ctx, device.UserID, device.SigningKey, device.UserID, theirSSK.Key)
if err != nil {
mach.machOrContextLog(ctx).Error().Err(err).
Str("user_id", device.UserID.String()).
@@ -89,7 +89,7 @@ func (mach *OlmMachine) IsDeviceTrusted(device *id.Device) bool {
// IsUserTrusted returns whether a user has been determined to be trusted by our user-signing key having signed their master key.
// In the case the user ID is our own and we have successfully retrieved our cross-signing keys, we trust our own user.
func (mach *OlmMachine) IsUserTrusted(ctx context.Context, userID id.UserID) (bool, error) {
csPubkeys := mach.GetOwnCrossSigningPublicKeys()
csPubkeys := mach.GetOwnCrossSigningPublicKeys(ctx)
if csPubkeys == nil {
return false, nil
}
@@ -97,14 +97,14 @@ func (mach *OlmMachine) IsUserTrusted(ctx context.Context, userID id.UserID) (bo
return true, nil
}
// first we verify our user-signing key
ourUserSigningKeyTrusted, err := mach.CryptoStore.IsKeySignedBy(mach.Client.UserID, csPubkeys.UserSigningKey, mach.Client.UserID, csPubkeys.MasterKey)
ourUserSigningKeyTrusted, err := mach.CryptoStore.IsKeySignedBy(ctx, mach.Client.UserID, csPubkeys.UserSigningKey, mach.Client.UserID, csPubkeys.MasterKey)
if err != nil {
mach.machOrContextLog(ctx).Error().Err(err).Msg("Error retrieving our self-signing key signatures from database")
return false, err
} else if !ourUserSigningKeyTrusted {
return false, nil
}
theirKeys, err := mach.CryptoStore.GetCrossSigningKeys(userID)
theirKeys, err := mach.CryptoStore.GetCrossSigningKeys(ctx, userID)
if err != nil {
mach.machOrContextLog(ctx).Error().Err(err).
Str("user_id", userID.String()).
@@ -118,7 +118,7 @@ func (mach *OlmMachine) IsUserTrusted(ctx context.Context, userID id.UserID) (bo
Msg("Master key of user not found")
return false, nil
}
sigExists, err := mach.CryptoStore.IsKeySignedBy(userID, theirMskKey.Key, mach.Client.UserID, csPubkeys.UserSigningKey)
sigExists, err := mach.CryptoStore.IsKeySignedBy(ctx, userID, theirMskKey.Key, mach.Client.UserID, csPubkeys.UserSigningKey)
if err != nil {
mach.machOrContextLog(ctx).Error().Err(err).
Str("user_id", userID.String()).

View File

@@ -105,7 +105,7 @@ func NewCryptoHelper(cli *mautrix.Client, pickleKey []byte, store any) (*CryptoH
}, nil
}
func (helper *CryptoHelper) Init() error {
func (helper *CryptoHelper) Init(ctx context.Context) error {
if helper == nil {
return fmt.Errorf("crypto helper is nil")
}
@@ -116,7 +116,7 @@ func (helper *CryptoHelper) Init() error {
var stateStore crypto.StateStore
if helper.managedStateStore != nil {
err := helper.managedStateStore.Upgrade()
err := helper.managedStateStore.Upgrade(ctx)
if err != nil {
return fmt.Errorf("failed to upgrade client state store: %w", err)
}
@@ -132,11 +132,14 @@ func (helper *CryptoHelper) Init() error {
} else if _, isMemory := helper.client.Store.(*mautrix.MemorySyncStore); isMemory {
helper.client.Store = managedCryptoStore
}
err := managedCryptoStore.DB.Upgrade()
err := managedCryptoStore.DB.Upgrade(ctx)
if err != nil {
return fmt.Errorf("failed to upgrade crypto state store: %w", err)
}
storedDeviceID := managedCryptoStore.FindDeviceID()
storedDeviceID, err := managedCryptoStore.FindDeviceID(ctx)
if err != nil {
return fmt.Errorf("failed to find existing device ID: %w", err)
}
if helper.LoginAs != nil {
if storedDeviceID != "" {
helper.LoginAs.DeviceID = storedDeviceID
@@ -146,7 +149,7 @@ func (helper *CryptoHelper) Init() error {
Str("username", helper.LoginAs.Identifier.User).
Str("device_id", helper.LoginAs.DeviceID.String()).
Msg("Logging in")
_, err = helper.client.Login(helper.LoginAs)
_, err = helper.client.Login(ctx, helper.LoginAs)
if err != nil {
return err
}
@@ -167,10 +170,10 @@ func (helper *CryptoHelper) Init() error {
return fmt.Errorf("the client must be logged in")
}
helper.mach = crypto.NewOlmMachine(helper.client, &helper.log, cryptoStore, stateStore)
err := helper.mach.Load()
err := helper.mach.Load(ctx)
if err != nil {
return fmt.Errorf("failed to load olm account: %w", err)
} else if err = helper.verifyDeviceKeysOnServer(); err != nil {
} else if err = helper.verifyDeviceKeysOnServer(ctx); err != nil {
return err
}
@@ -204,9 +207,9 @@ func (helper *CryptoHelper) Machine() *crypto.OlmMachine {
return helper.mach
}
func (helper *CryptoHelper) verifyDeviceKeysOnServer() error {
func (helper *CryptoHelper) verifyDeviceKeysOnServer(ctx context.Context) error {
helper.log.Debug().Msg("Making sure our device has the expected keys on the server")
resp, err := helper.client.QueryKeys(&mautrix.ReqQueryKeys{
resp, err := helper.client.QueryKeys(ctx, &mautrix.ReqQueryKeys{
DeviceKeys: map[id.UserID]mautrix.DeviceIDList{
helper.client.UserID: {helper.client.DeviceID},
},
@@ -242,27 +245,29 @@ var NoSessionFound = crypto.NoSessionFound
const initialSessionWaitTimeout = 3 * time.Second
const extendedSessionWaitTimeout = 22 * time.Second
func (helper *CryptoHelper) HandleEncrypted(src mautrix.EventSource, evt *event.Event) {
func (helper *CryptoHelper) HandleEncrypted(ctx context.Context, evt *event.Event) {
if helper == nil {
return
}
content := evt.Content.AsEncrypted()
// TODO use context log instead of helper?
log := helper.log.With().
Str("event_id", evt.ID.String()).
Str("session_id", content.SessionID.String()).
Logger()
log.Debug().Msg("Decrypting received event")
ctx = log.WithContext(ctx)
decrypted, err := helper.Decrypt(evt)
decrypted, err := helper.Decrypt(ctx, evt)
if errors.Is(err, NoSessionFound) {
log.Debug().
Int("wait_seconds", int(initialSessionWaitTimeout.Seconds())).
Msg("Couldn't find session, waiting for keys to arrive...")
if helper.mach.WaitForSession(evt.RoomID, content.SenderKey, content.SessionID, initialSessionWaitTimeout) {
if helper.mach.WaitForSession(ctx, evt.RoomID, content.SenderKey, content.SessionID, initialSessionWaitTimeout) {
log.Debug().Msg("Got keys after waiting, trying to decrypt event again")
decrypted, err = helper.Decrypt(evt)
decrypted, err = helper.Decrypt(ctx, evt)
} else {
go helper.waitLongerForSession(log, src, evt)
go helper.waitLongerForSession(ctx, log, evt)
return
}
}
@@ -271,14 +276,15 @@ func (helper *CryptoHelper) HandleEncrypted(src mautrix.EventSource, evt *event.
helper.DecryptErrorCallback(evt, err)
return
}
helper.postDecrypt(src, decrypted)
helper.postDecrypt(ctx, decrypted)
}
func (helper *CryptoHelper) postDecrypt(src mautrix.EventSource, decrypted *event.Event) {
helper.client.Syncer.(mautrix.DispatchableSyncer).Dispatch(src|mautrix.EventSourceDecrypted, decrypted)
func (helper *CryptoHelper) postDecrypt(ctx context.Context, decrypted *event.Event) {
decrypted.Mautrix.EventSource |= event.SourceDecrypted
helper.client.Syncer.(mautrix.DispatchableSyncer).Dispatch(ctx, decrypted)
}
func (helper *CryptoHelper) RequestSession(roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID, userID id.UserID, deviceID id.DeviceID) {
func (helper *CryptoHelper) RequestSession(ctx context.Context, roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID, userID id.UserID, deviceID id.DeviceID) {
if helper == nil {
return
}
@@ -294,7 +300,7 @@ func (helper *CryptoHelper) RequestSession(roomID id.RoomID, senderKey id.Sender
Str("device_id", deviceID.String()).
Str("room_id", roomID.String()).
Logger()
err := helper.mach.SendRoomKeyRequest(roomID, senderKey, sessionID, "", map[id.UserID][]id.DeviceID{
err := helper.mach.SendRoomKeyRequest(ctx, roomID, senderKey, sessionID, "", map[id.UserID][]id.DeviceID{
userID: {deviceID},
helper.client.UserID: {"*"},
})
@@ -305,55 +311,54 @@ func (helper *CryptoHelper) RequestSession(roomID id.RoomID, senderKey id.Sender
}
}
func (helper *CryptoHelper) waitLongerForSession(log zerolog.Logger, src mautrix.EventSource, evt *event.Event) {
func (helper *CryptoHelper) waitLongerForSession(ctx context.Context, log zerolog.Logger, evt *event.Event) {
content := evt.Content.AsEncrypted()
log.Debug().Int("wait_seconds", int(extendedSessionWaitTimeout.Seconds())).Msg("Couldn't find session, requesting keys and waiting longer...")
go helper.RequestSession(evt.RoomID, content.SenderKey, content.SessionID, evt.Sender, content.DeviceID)
go helper.RequestSession(context.TODO(), evt.RoomID, content.SenderKey, content.SessionID, evt.Sender, content.DeviceID)
if !helper.mach.WaitForSession(evt.RoomID, content.SenderKey, content.SessionID, extendedSessionWaitTimeout) {
if !helper.mach.WaitForSession(ctx, evt.RoomID, content.SenderKey, content.SessionID, extendedSessionWaitTimeout) {
log.Debug().Msg("Didn't get session, giving up")
helper.DecryptErrorCallback(evt, NoSessionFound)
return
}
log.Debug().Msg("Got keys after waiting longer, trying to decrypt event again")
decrypted, err := helper.Decrypt(evt)
decrypted, err := helper.Decrypt(ctx, evt)
if err != nil {
log.Error().Err(err).Msg("Failed to decrypt event")
helper.DecryptErrorCallback(evt, err)
return
}
helper.postDecrypt(src, decrypted)
helper.postDecrypt(ctx, decrypted)
}
func (helper *CryptoHelper) WaitForSession(roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID, timeout time.Duration) bool {
func (helper *CryptoHelper) WaitForSession(ctx context.Context, roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID, timeout time.Duration) bool {
if helper == nil {
return false
}
helper.lock.RLock()
defer helper.lock.RUnlock()
return helper.mach.WaitForSession(roomID, senderKey, sessionID, timeout)
return helper.mach.WaitForSession(ctx, roomID, senderKey, sessionID, timeout)
}
func (helper *CryptoHelper) Decrypt(evt *event.Event) (*event.Event, error) {
func (helper *CryptoHelper) Decrypt(ctx context.Context, evt *event.Event) (*event.Event, error) {
if helper == nil {
return nil, fmt.Errorf("crypto helper is nil")
}
return helper.mach.DecryptMegolmEvent(context.TODO(), evt)
return helper.mach.DecryptMegolmEvent(ctx, evt)
}
func (helper *CryptoHelper) Encrypt(roomID id.RoomID, evtType event.Type, content any) (encrypted *event.EncryptedEventContent, err error) {
func (helper *CryptoHelper) Encrypt(ctx context.Context, roomID id.RoomID, evtType event.Type, content any) (encrypted *event.EncryptedEventContent, err error) {
if helper == nil {
return nil, fmt.Errorf("crypto helper is nil")
}
helper.lock.RLock()
defer helper.lock.RUnlock()
ctx := context.TODO()
encrypted, err = helper.mach.EncryptMegolmEvent(ctx, roomID, evtType, content)
if err != nil {
if err != crypto.SessionExpired && err != crypto.SessionNotShared && err != crypto.NoGroupSession {
if !errors.Is(err, crypto.SessionExpired) && err != crypto.NoGroupSession && !errors.Is(err, crypto.SessionNotShared) {
return
}
helper.log.Debug().
@@ -361,7 +366,7 @@ func (helper *CryptoHelper) Encrypt(roomID id.RoomID, evtType event.Type, conten
Str("room_id", roomID.String()).
Msg("Got session error while encrypting event, sharing group session and trying again")
var users []id.UserID
users, err = helper.client.StateStore.GetRoomJoinedOrInvitedMembers(roomID)
users, err = helper.client.StateStore.GetRoomJoinedOrInvitedMembers(ctx, roomID)
if err != nil {
err = fmt.Errorf("failed to get room member list: %w", err)
} else if err = helper.mach.ShareGroupSession(ctx, roomID, users); err != nil {

View File

@@ -1,4 +1,4 @@
// Copyright (c) 2023 Tulir Asokan
// Copyright (c) 2024 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
@@ -91,7 +91,7 @@ func (mach *OlmMachine) DecryptMegolmEvent(ctx context.Context, evt *event.Event
} else {
forwardedKeys = true
lastChainItem := sess.ForwardingChains[len(sess.ForwardingChains)-1]
device, _ = mach.CryptoStore.FindDeviceByKey(evt.Sender, id.IdentityKey(lastChainItem))
device, _ = mach.CryptoStore.FindDeviceByKey(ctx, evt.Sender, id.IdentityKey(lastChainItem))
if device != nil {
trustLevel = mach.ResolveTrust(device)
} else {
@@ -188,7 +188,7 @@ func (mach *OlmMachine) actuallyDecryptMegolmEvent(ctx context.Context, evt *eve
mach.megolmDecryptLock.Lock()
defer mach.megolmDecryptLock.Unlock()
sess, err := mach.CryptoStore.GetGroupSession(encryptionRoomID, content.SenderKey, content.SessionID)
sess, err := mach.CryptoStore.GetGroupSession(ctx, encryptionRoomID, content.SenderKey, content.SessionID)
if err != nil {
return nil, nil, 0, fmt.Errorf("failed to get group session: %w", err)
} else if sess == nil {
@@ -250,7 +250,7 @@ func (mach *OlmMachine) actuallyDecryptMegolmEvent(ctx context.Context, evt *eve
Int("max_messages", sess.MaxMessages).
Logger()
if sess.MaxMessages > 0 && int(ratchetTargetIndex) >= sess.MaxMessages && len(sess.RatchetSafety.MissedIndices) == 0 && mach.DeleteFullyUsedKeysOnDecrypt {
err = mach.CryptoStore.RedactGroupSession(sess.RoomID, sess.SenderKey, sess.ID(), "maximum messages reached")
err = mach.CryptoStore.RedactGroupSession(ctx, sess.RoomID, sess.SenderKey, sess.ID(), "maximum messages reached")
if err != nil {
log.Err(err).Msg("Failed to delete fully used session")
return sess, plaintext, messageIndex, RatchetError
@@ -261,14 +261,14 @@ func (mach *OlmMachine) actuallyDecryptMegolmEvent(ctx context.Context, evt *eve
if err = sess.RatchetTo(ratchetTargetIndex); err != nil {
log.Err(err).Msg("Failed to ratchet session")
return sess, plaintext, messageIndex, RatchetError
} else if err = mach.CryptoStore.PutGroupSession(sess.RoomID, sess.SenderKey, sess.ID(), sess); err != nil {
} else if err = mach.CryptoStore.PutGroupSession(ctx, sess.RoomID, sess.SenderKey, sess.ID(), sess); err != nil {
log.Err(err).Msg("Failed to store ratcheted session")
return sess, plaintext, messageIndex, RatchetError
} else {
log.Info().Msg("Ratcheted session forward")
}
} else if didModify {
if err = mach.CryptoStore.PutGroupSession(sess.RoomID, sess.SenderKey, sess.ID(), sess); err != nil {
if err = mach.CryptoStore.PutGroupSession(ctx, sess.RoomID, sess.SenderKey, sess.ID(), sess); err != nil {
log.Err(err).Msg("Failed to store updated ratchet safety data")
return sess, plaintext, messageIndex, RatchetError
} else {

View File

@@ -1,4 +1,4 @@
// Copyright (c) 2023 Tulir Asokan
// Copyright (c) 2024 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
@@ -159,7 +159,7 @@ func (mach *OlmMachine) tryDecryptOlmCiphertext(ctx context.Context, sender id.U
}
endTimeTrace = mach.timeTrace(ctx, "updating new session in database", time.Second)
err = mach.CryptoStore.UpdateSession(senderKey, session)
err = mach.CryptoStore.UpdateSession(ctx, senderKey, session)
endTimeTrace()
if err != nil {
log.Warn().Err(err).Msg("Failed to update new olm session in crypto store after decrypting")
@@ -170,7 +170,7 @@ func (mach *OlmMachine) tryDecryptOlmCiphertext(ctx context.Context, sender id.U
func (mach *OlmMachine) tryDecryptOlmCiphertextWithExistingSession(ctx context.Context, senderKey id.SenderKey, olmType id.OlmMsgType, ciphertext string) ([]byte, error) {
log := *zerolog.Ctx(ctx)
endTimeTrace := mach.timeTrace(ctx, "getting sessions with sender key", time.Second)
sessions, err := mach.CryptoStore.GetSessions(senderKey)
sessions, err := mach.CryptoStore.GetSessions(ctx, senderKey)
endTimeTrace()
if err != nil {
return nil, fmt.Errorf("failed to get session for %s: %w", senderKey, err)
@@ -199,7 +199,7 @@ func (mach *OlmMachine) tryDecryptOlmCiphertextWithExistingSession(ctx context.C
}
} else {
endTimeTrace = mach.timeTrace(ctx, "updating session in database", time.Second)
err = mach.CryptoStore.UpdateSession(senderKey, session)
err = mach.CryptoStore.UpdateSession(ctx, senderKey, session)
endTimeTrace()
if err != nil {
log.Warn().Err(err).Msg("Failed to update olm session in crypto store after decrypting")
@@ -216,8 +216,8 @@ func (mach *OlmMachine) createInboundSession(ctx context.Context, senderKey id.S
if err != nil {
return nil, err
}
mach.saveAccount()
err = mach.CryptoStore.AddSession(senderKey, session)
mach.saveAccount(ctx)
err = mach.CryptoStore.AddSession(ctx, senderKey, session)
if err != nil {
zerolog.Ctx(ctx).Error().Err(err).Msg("Failed to store created inbound session")
}
@@ -228,7 +228,7 @@ const MinUnwedgeInterval = 1 * time.Hour
func (mach *OlmMachine) unwedgeDevice(log zerolog.Logger, sender id.UserID, senderKey id.SenderKey) {
log = log.With().Str("action", "unwedge olm session").Logger()
ctx := log.WithContext(context.Background())
ctx := log.WithContext(context.TODO())
mach.recentlyUnwedgedLock.Lock()
prevUnwedge, ok := mach.recentlyUnwedged[senderKey]
delta := time.Now().Sub(prevUnwedge)

View File

@@ -1,4 +1,4 @@
// Copyright (c) 2023 Tulir Asokan
// Copyright (c) 2024 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
@@ -27,9 +27,16 @@ var (
InvalidKeySignature = errors.New("invalid signature on device keys")
)
func (mach *OlmMachine) LoadDevices(user id.UserID) map[id.DeviceID]*id.Device {
// TODO proper context?
return mach.fetchKeys(context.TODO(), []id.UserID{user}, "", true)[user]
func (mach *OlmMachine) LoadDevices(ctx context.Context, user id.UserID) (keys map[id.DeviceID]*id.Device) {
log := zerolog.Ctx(ctx)
if keys, err := mach.FetchKeys(ctx, []id.UserID{user}, true); err != nil {
log.Err(err).Msg("Failed to load devices")
} else if keys != nil {
return keys[user]
}
return nil
}
func (mach *OlmMachine) storeDeviceSelfSignatures(ctx context.Context, userID id.UserID, deviceID id.DeviceID, resp *mautrix.RespQueryKeys) {
@@ -53,7 +60,7 @@ func (mach *OlmMachine) storeDeviceSelfSignatures(ctx context.Context, userID id
Str("signed_device_id", deviceID.String()).
Str("signature", signature).
Msg("Verified self-signing signature")
err = mach.CryptoStore.PutSignature(userID, id.Ed25519(signKey), signerUserID, pubKey, signature)
err = mach.CryptoStore.PutSignature(ctx, userID, id.Ed25519(signKey), signerUserID, pubKey, signature)
if err != nil {
log.Warn().Err(err).
Str("signer_user_id", signerUserID.String()).
@@ -74,7 +81,7 @@ func (mach *OlmMachine) storeDeviceSelfSignatures(ctx context.Context, userID id
}
// save signature of device made by its own device signing key
if signKey, ok := deviceKeys.Keys[id.DeviceKeyID(signerKey)]; ok {
err := mach.CryptoStore.PutSignature(userID, id.Ed25519(signKey), signerUserID, id.Ed25519(signKey), signature)
err := mach.CryptoStore.PutSignature(ctx, userID, id.Ed25519(signKey), signerUserID, id.Ed25519(signKey), signature)
if err != nil {
log.Warn().Err(err).
Str("signer_user_id", signerUserID.String()).
@@ -86,19 +93,16 @@ func (mach *OlmMachine) storeDeviceSelfSignatures(ctx context.Context, userID id
}
}
func (mach *OlmMachine) fetchKeys(ctx context.Context, users []id.UserID, sinceToken string, includeUntracked bool) (data map[id.UserID]map[id.DeviceID]*id.Device) {
// TODO this function should probably return errors
func (mach *OlmMachine) FetchKeys(ctx context.Context, users []id.UserID, includeUntracked bool) (data map[id.UserID]map[id.DeviceID]*id.Device, err error) {
req := &mautrix.ReqQueryKeys{
DeviceKeys: mautrix.DeviceKeysRequest{},
Timeout: 10 * 1000,
Token: sinceToken,
}
log := mach.machOrContextLog(ctx)
if !includeUntracked {
var err error
users, err = mach.CryptoStore.FilterTrackedUsers(users)
users, err = mach.CryptoStore.FilterTrackedUsers(ctx, users)
if err != nil {
log.Warn().Err(err).Msg("Failed to filter tracked user list")
return nil, fmt.Errorf("failed to filter tracked user list: %w", err)
}
}
if len(users) == 0 {
@@ -108,10 +112,9 @@ func (mach *OlmMachine) fetchKeys(ctx context.Context, users []id.UserID, sinceT
req.DeviceKeys[userID] = mautrix.DeviceIDList{}
}
log.Debug().Strs("users", strishArray(users)).Msg("Querying keys for users")
resp, err := mach.Client.QueryKeys(req)
resp, err := mach.Client.QueryKeys(ctx, req)
if err != nil {
log.Error().Err(err).Msg("Failed to query keys")
return
return nil, fmt.Errorf("failed to query keys: %w", err)
}
for server, err := range resp.Failures {
log.Warn().Interface("query_error", err).Str("server", server).Msg("Query keys failure for server")
@@ -123,7 +126,7 @@ func (mach *OlmMachine) fetchKeys(ctx context.Context, users []id.UserID, sinceT
delete(req.DeviceKeys, userID)
newDevices := make(map[id.DeviceID]*id.Device)
existingDevices, err := mach.CryptoStore.GetDevices(userID)
existingDevices, err := mach.CryptoStore.GetDevices(ctx, userID)
if err != nil {
log.Warn().Err(err).Msg("Failed to get existing devices for user")
existingDevices = make(map[id.DeviceID]*id.Device)
@@ -151,7 +154,7 @@ func (mach *OlmMachine) fetchKeys(ctx context.Context, users []id.UserID, sinceT
}
}
log.Trace().Int("new_device_count", len(newDevices)).Msg("Storing new device list")
err = mach.CryptoStore.PutDevices(userID, newDevices)
err = mach.CryptoStore.PutDevices(ctx, userID, newDevices)
if err != nil {
log.Warn().Err(err).Msg("Failed to update device list")
}
@@ -169,7 +172,7 @@ func (mach *OlmMachine) fetchKeys(ctx context.Context, users []id.UserID, sinceT
Str("identity_key", device.IdentityKey.String()).
Str("signing_key", device.SigningKey.String()).
Logger()
sessionIDs, err := mach.CryptoStore.RedactGroupSessions("", device.IdentityKey, "device removed")
sessionIDs, err := mach.CryptoStore.RedactGroupSessions(ctx, "", device.IdentityKey, "device removed")
if err != nil {
log.Err(err).Msg("Failed to redact megolm sessions from deleted device")
} else {
@@ -179,7 +182,7 @@ func (mach *OlmMachine) fetchKeys(ctx context.Context, users []id.UserID, sinceT
}
}
}
mach.OnDevicesChanged(userID)
mach.OnDevicesChanged(ctx, userID)
}
}
for userID := range req.DeviceKeys {
@@ -190,25 +193,32 @@ func (mach *OlmMachine) fetchKeys(ctx context.Context, users []id.UserID, sinceT
mach.storeCrossSigningKeys(ctx, resp.SelfSigningKeys, resp.DeviceKeys)
mach.storeCrossSigningKeys(ctx, resp.UserSigningKeys, resp.DeviceKeys)
return data
return data, nil
}
// OnDevicesChanged finds all shared rooms with the given user and invalidates outbound sessions in those rooms.
//
// This is called automatically whenever a device list change is noticed in ProcessSyncResponse and usually does
// not need to be called manually.
func (mach *OlmMachine) OnDevicesChanged(userID id.UserID) {
func (mach *OlmMachine) OnDevicesChanged(ctx context.Context, userID id.UserID) {
if mach.DisableDeviceChangeKeyRotation {
return
}
for _, roomID := range mach.StateStore.FindSharedRooms(userID) {
mach.Log.Debug().
rooms, err := mach.StateStore.FindSharedRooms(ctx, userID)
if err != nil {
mach.machOrContextLog(ctx).Err(err).
Stringer("with_user_id", userID).
Msg("Failed to find shared rooms to invalidate group sessions")
return
}
for _, roomID := range rooms {
mach.machOrContextLog(ctx).Debug().
Str("user_id", userID.String()).
Str("room_id", roomID.String()).
Msg("Invalidating group session in room due to device change notification")
err := mach.CryptoStore.RemoveOutboundGroupSession(roomID)
err = mach.CryptoStore.RemoveOutboundGroupSession(ctx, roomID)
if err != nil {
mach.Log.Warn().Err(err).
mach.machOrContextLog(ctx).Err(err).
Str("user_id", userID.String()).
Str("room_id", roomID.String()).
Msg("Failed to invalidate outbound group session")

View File

@@ -1,4 +1,4 @@
// Copyright (c) 2023 Tulir Asokan
// Copyright (c) 2024 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
@@ -84,7 +84,7 @@ func parseMessageIndex(ciphertext []byte) (uint, error) {
func (mach *OlmMachine) EncryptMegolmEvent(ctx context.Context, roomID id.RoomID, evtType event.Type, content interface{}) (*event.EncryptedEventContent, error) {
mach.megolmEncryptLock.Lock()
defer mach.megolmEncryptLock.Unlock()
session, err := mach.CryptoStore.GetOutboundGroupSession(roomID)
session, err := mach.CryptoStore.GetOutboundGroupSession(ctx, roomID)
if err != nil {
return nil, fmt.Errorf("failed to get outbound group session: %w", err)
} else if session == nil {
@@ -116,7 +116,7 @@ func (mach *OlmMachine) EncryptMegolmEvent(ctx context.Context, roomID id.RoomID
log = log.With().Uint("message_index", idx).Logger()
}
log.Debug().Msg("Encrypted event successfully")
err = mach.CryptoStore.UpdateOutboundGroupSession(session)
err = mach.CryptoStore.UpdateOutboundGroupSession(ctx, session)
if err != nil {
log.Warn().Err(err).Msg("Failed to update megolm session in crypto store after encrypting")
}
@@ -137,7 +137,13 @@ func (mach *OlmMachine) EncryptMegolmEvent(ctx context.Context, roomID id.RoomID
}
func (mach *OlmMachine) newOutboundGroupSession(ctx context.Context, roomID id.RoomID) *OutboundGroupSession {
session := NewOutboundGroupSession(roomID, mach.StateStore.GetEncryptionEvent(roomID))
encryptionEvent, err := mach.StateStore.GetEncryptionEvent(ctx, roomID)
if err != nil {
mach.machOrContextLog(ctx).Err(err).
Stringer("room_id", roomID).
Msg("Failed to get encryption event in room")
}
session := NewOutboundGroupSession(roomID, encryptionEvent)
if !mach.DontStoreOutboundKeys {
signingKey, idKey := mach.account.Keys()
mach.createGroupSession(ctx, idKey, signingKey, roomID, session.ID(), session.Internal.Key(), session.MaxAge, session.MaxMessages, false)
@@ -165,7 +171,7 @@ func strishArray[T ~string](arr []T) []string {
func (mach *OlmMachine) ShareGroupSession(ctx context.Context, roomID id.RoomID, users []id.UserID) error {
mach.megolmEncryptLock.Lock()
defer mach.megolmEncryptLock.Unlock()
session, err := mach.CryptoStore.GetOutboundGroupSession(roomID)
session, err := mach.CryptoStore.GetOutboundGroupSession(ctx, roomID)
if err != nil {
return fmt.Errorf("failed to get previous outbound group session: %w", err)
} else if session != nil && session.Shared && !session.Expired() {
@@ -192,7 +198,7 @@ func (mach *OlmMachine) ShareGroupSession(ctx context.Context, roomID id.RoomID,
for _, userID := range users {
log := log.With().Str("target_user_id", userID.String()).Logger()
devices, err := mach.CryptoStore.GetDevices(userID)
devices, err := mach.CryptoStore.GetDevices(ctx, userID)
if err != nil {
log.Error().Err(err).Msg("Failed to get devices of user")
} else if devices == nil {
@@ -223,12 +229,16 @@ func (mach *OlmMachine) ShareGroupSession(ctx context.Context, roomID id.RoomID,
if len(fetchKeys) > 0 {
log.Debug().Strs("users", strishArray(fetchKeys)).Msg("Fetching missing keys")
for userID, devices := range mach.fetchKeys(ctx, fetchKeys, "", true) {
log.Debug().
Int("device_count", len(devices)).
Str("target_user_id", userID.String()).
Msg("Got device keys for user")
missingSessions[userID] = devices
if keys, err := mach.FetchKeys(ctx, fetchKeys, true); err != nil {
log.Err(err).Strs("users", strishArray(fetchKeys)).Msg("Failed to fetch missing keys")
} else if keys != nil {
for userID, devices := range keys {
log.Debug().
Int("device_count", len(devices)).
Str("target_user_id", userID.String()).
Msg("Got device keys for user")
missingSessions[userID] = devices
}
}
}
@@ -280,11 +290,11 @@ func (mach *OlmMachine) ShareGroupSession(ctx context.Context, roomID id.RoomID,
Int("user_count", len(toDeviceWithheld.Messages)).
Msg("Sending to-device messages to report withheld key")
// TODO remove the next 4 lines once clients support m.room_key.withheld
_, err = mach.Client.SendToDevice(event.ToDeviceOrgMatrixRoomKeyWithheld, toDeviceWithheld)
_, err = mach.Client.SendToDevice(ctx, event.ToDeviceOrgMatrixRoomKeyWithheld, toDeviceWithheld)
if err != nil {
log.Warn().Err(err).Msg("Failed to report withheld keys (legacy event type)")
}
_, err = mach.Client.SendToDevice(event.ToDeviceRoomKeyWithheld, toDeviceWithheld)
_, err = mach.Client.SendToDevice(ctx, event.ToDeviceRoomKeyWithheld, toDeviceWithheld)
if err != nil {
log.Warn().Err(err).Msg("Failed to report withheld keys")
}
@@ -292,7 +302,7 @@ func (mach *OlmMachine) ShareGroupSession(ctx context.Context, roomID id.RoomID,
log.Debug().Msg("Group session successfully shared")
session.Shared = true
return mach.CryptoStore.AddOutboundGroupSession(session)
return mach.CryptoStore.AddOutboundGroupSession(ctx, session)
}
func (mach *OlmMachine) encryptAndSendGroupSession(ctx context.Context, session *OutboundGroupSession, olmSessions map[id.UserID]map[id.DeviceID]deviceSessionWrapper) error {
@@ -327,7 +337,7 @@ func (mach *OlmMachine) encryptAndSendGroupSession(ctx context.Context, session
Int("device_count", deviceCount).
Int("user_count", len(toDevice.Messages)).
Msg("Sending to-device messages to share group session")
_, err := mach.Client.SendToDevice(event.ToDeviceEncrypted, toDevice)
_, err := mach.Client.SendToDevice(ctx, event.ToDeviceEncrypted, toDevice)
return err
}
@@ -367,7 +377,7 @@ func (mach *OlmMachine) findOlmSessionsForUser(ctx context.Context, session *Out
Reason: "This device does not encrypt messages for unverified devices",
}}
session.Users[userKey] = OGSIgnored
} else if deviceSession, err := mach.CryptoStore.GetLatestSession(device.IdentityKey); err != nil {
} else if deviceSession, err := mach.CryptoStore.GetLatestSession(ctx, device.IdentityKey); err != nil {
log.Error().Err(err).Msg("Failed to get olm session to encrypt group session")
} else if deviceSession == nil {
log.Warn().Err(err).Msg("Didn't find olm session to encrypt group session")

View File

@@ -1,4 +1,4 @@
// Copyright (c) 2023 Tulir Asokan
// Copyright (c) 2024 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
@@ -38,7 +38,7 @@ func (mach *OlmMachine) encryptOlmEvent(ctx context.Context, session *OlmSession
Str("olm_session_description", session.Describe()).
Msg("Encrypting olm message")
msgType, ciphertext := session.Encrypt(plaintext)
err = mach.CryptoStore.UpdateSession(recipient.IdentityKey, session)
err = mach.CryptoStore.UpdateSession(ctx, recipient.IdentityKey, session)
if err != nil {
log.Error().Err(err).Msg("Failed to update olm session in crypto store after encrypting")
}
@@ -54,8 +54,8 @@ func (mach *OlmMachine) encryptOlmEvent(ctx context.Context, session *OlmSession
}
}
func (mach *OlmMachine) shouldCreateNewSession(identityKey id.IdentityKey) bool {
if !mach.CryptoStore.HasSession(identityKey) {
func (mach *OlmMachine) shouldCreateNewSession(ctx context.Context, identityKey id.IdentityKey) bool {
if !mach.CryptoStore.HasSession(ctx, identityKey) {
return true
}
mach.devicesToUnwedgeLock.Lock()
@@ -72,7 +72,7 @@ func (mach *OlmMachine) createOutboundSessions(ctx context.Context, input map[id
for userID, devices := range input {
request[userID] = make(map[id.DeviceID]id.KeyAlgorithm)
for deviceID, identity := range devices {
if mach.shouldCreateNewSession(identity.IdentityKey) {
if mach.shouldCreateNewSession(ctx, identity.IdentityKey) {
request[userID][deviceID] = id.KeyAlgorithmSignedCurve25519
}
}
@@ -83,7 +83,7 @@ func (mach *OlmMachine) createOutboundSessions(ctx context.Context, input map[id
if len(request) == 0 {
return nil
}
resp, err := mach.Client.ClaimKeys(&mautrix.ReqClaimKeys{
resp, err := mach.Client.ClaimKeys(ctx, &mautrix.ReqClaimKeys{
OneTimeKeys: request,
Timeout: 10 * 1000,
})
@@ -117,7 +117,7 @@ func (mach *OlmMachine) createOutboundSessions(ctx context.Context, input map[id
log.Error().Err(err).Msg("Failed to create outbound session with claimed one-time key")
} else {
wrapped := wrapSession(sess)
err = mach.CryptoStore.AddSession(identity.IdentityKey, wrapped)
err = mach.CryptoStore.AddSession(ctx, identity.IdentityKey, wrapped)
if err != nil {
log.Error().Err(err).Msg("Failed to store created outbound session")
} else {

5
vendor/maunium.net/go/mautrix/crypto/goolm/README.md generated vendored Normal file
View File

@@ -0,0 +1,5 @@
# go-olm
This is a fork of [DerLukas's goolm](https://codeberg.org/DerLukas/goolm),
a pure Go implementation of libolm.
The original project is licensed under the MIT license.

View File

@@ -0,0 +1,522 @@
// account packages an account which stores the identity, one time keys and fallback keys.
package account
import (
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"io"
"maunium.net/go/mautrix/id"
"maunium.net/go/mautrix/crypto/goolm"
"maunium.net/go/mautrix/crypto/goolm/cipher"
"maunium.net/go/mautrix/crypto/goolm/crypto"
"maunium.net/go/mautrix/crypto/goolm/libolmpickle"
"maunium.net/go/mautrix/crypto/goolm/session"
"maunium.net/go/mautrix/crypto/goolm/utilities"
)
const (
accountPickleVersionJSON byte = 1
accountPickleVersionLibOLM uint32 = 4
MaxOneTimeKeys int = 100 //maximum number of stored one time keys per Account
)
// Account stores an account for end to end encrypted messaging via the olm protocol.
// An Account can not be used to en/decrypt messages. However it can be used to contruct new olm sessions, which in turn do the en/decryption.
// There is no tracking of sessions in an account.
type Account struct {
IdKeys struct {
Ed25519 crypto.Ed25519KeyPair `json:"ed25519,omitempty"`
Curve25519 crypto.Curve25519KeyPair `json:"curve25519,omitempty"`
} `json:"identity_keys"`
OTKeys []crypto.OneTimeKey `json:"one_time_keys"`
CurrentFallbackKey crypto.OneTimeKey `json:"current_fallback_key,omitempty"`
PrevFallbackKey crypto.OneTimeKey `json:"prev_fallback_key,omitempty"`
NextOneTimeKeyID uint32 `json:"next_one_time_key_id,omitempty"`
NumFallbackKeys uint8 `json:"number_fallback_keys"`
}
// AccountFromJSONPickled loads the Account details from a pickled base64 string. The input is decrypted with the supplied key.
func AccountFromJSONPickled(pickled, key []byte) (*Account, error) {
if len(pickled) == 0 {
return nil, fmt.Errorf("accountFromPickled: %w", goolm.ErrEmptyInput)
}
a := &Account{}
err := a.UnpickleAsJSON(pickled, key)
if err != nil {
return nil, err
}
return a, nil
}
// AccountFromPickled loads the Account details from a pickled base64 string. The input is decrypted with the supplied key.
func AccountFromPickled(pickled, key []byte) (*Account, error) {
if len(pickled) == 0 {
return nil, fmt.Errorf("accountFromPickled: %w", goolm.ErrEmptyInput)
}
a := &Account{}
err := a.Unpickle(pickled, key)
if err != nil {
return nil, err
}
return a, nil
}
// NewAccount creates a new Account. If reader is nil, crypto/rand is used for the key creation.
func NewAccount(reader io.Reader) (*Account, error) {
a := &Account{}
kPEd25519, err := crypto.Ed25519GenerateKey(reader)
if err != nil {
return nil, err
}
a.IdKeys.Ed25519 = kPEd25519
kPCurve25519, err := crypto.Curve25519GenerateKey(reader)
if err != nil {
return nil, err
}
a.IdKeys.Curve25519 = kPCurve25519
return a, nil
}
// PickleAsJSON returns an Account as a base64 string encrypted using the supplied key. The unencrypted representation of the Account is in JSON format.
func (a Account) PickleAsJSON(key []byte) ([]byte, error) {
return utilities.PickleAsJSON(a, accountPickleVersionJSON, key)
}
// UnpickleAsJSON updates an Account by a base64 encrypted string using the supplied key. The unencrypted representation has to be in JSON format.
func (a *Account) UnpickleAsJSON(pickled, key []byte) error {
return utilities.UnpickleAsJSON(a, pickled, key, accountPickleVersionJSON)
}
// IdentityKeysJSON returns the public parts of the identity keys for the Account in a JSON string.
func (a Account) IdentityKeysJSON() ([]byte, error) {
res := struct {
Ed25519 string `json:"ed25519"`
Curve25519 string `json:"curve25519"`
}{}
ed25519, curve25519 := a.IdentityKeys()
res.Ed25519 = string(ed25519)
res.Curve25519 = string(curve25519)
return json.Marshal(res)
}
// IdentityKeys returns the public parts of the Ed25519 and Curve25519 identity keys for the Account.
func (a Account) IdentityKeys() (id.Ed25519, id.Curve25519) {
ed25519 := id.Ed25519(base64.RawStdEncoding.EncodeToString(a.IdKeys.Ed25519.PublicKey))
curve25519 := id.Curve25519(base64.RawStdEncoding.EncodeToString(a.IdKeys.Curve25519.PublicKey))
return ed25519, curve25519
}
// Sign returns the signature of a message using the Ed25519 key for this Account.
func (a Account) Sign(message []byte) ([]byte, error) {
if len(message) == 0 {
return nil, fmt.Errorf("sign: %w", goolm.ErrEmptyInput)
}
return goolm.Base64Encode(a.IdKeys.Ed25519.Sign(message)), nil
}
// OneTimeKeys returns the public parts of the unpublished one time keys of the Account.
//
// The returned data is a map with the mapping of key id to base64-encoded Curve25519 key.
func (a Account) OneTimeKeys() map[string]id.Curve25519 {
oneTimeKeys := make(map[string]id.Curve25519)
for _, curKey := range a.OTKeys {
if !curKey.Published {
oneTimeKeys[curKey.KeyIDEncoded()] = id.Curve25519(curKey.PublicKeyEncoded())
}
}
return oneTimeKeys
}
//OneTimeKeysJSON returns the public parts of the unpublished one time keys of the Account as a JSON string.
//
//The returned JSON is of format:
/*
{
Curve25519: {
"AAAAAA": "wo76WcYtb0Vk/pBOdmduiGJ0wIEjW4IBMbbQn7aSnTo",
"AAAAAB": "LRvjo46L1X2vx69sS9QNFD29HWulxrmW11Up5AfAjgU"
}
}
*/
func (a Account) OneTimeKeysJSON() ([]byte, error) {
res := make(map[string]map[string]id.Curve25519)
otKeys := a.OneTimeKeys()
res["Curve25519"] = otKeys
return json.Marshal(res)
}
// MarkKeysAsPublished marks the current set of one time keys and the fallback key as being
// published.
func (a *Account) MarkKeysAsPublished() {
for keyIndex := range a.OTKeys {
if !a.OTKeys[keyIndex].Published {
a.OTKeys[keyIndex].Published = true
}
}
a.CurrentFallbackKey.Published = true
}
// GenOneTimeKeys generates a number of new one time keys. If the total number
// of keys stored by this Account exceeds MaxOneTimeKeys then the older
// keys are discarded. If reader is nil, crypto/rand is used for the key creation.
func (a *Account) GenOneTimeKeys(reader io.Reader, num uint) error {
for i := uint(0); i < num; i++ {
key := crypto.OneTimeKey{
Published: false,
ID: a.NextOneTimeKeyID,
}
newKP, err := crypto.Curve25519GenerateKey(reader)
if err != nil {
return err
}
key.Key = newKP
a.NextOneTimeKeyID++
a.OTKeys = append([]crypto.OneTimeKey{key}, a.OTKeys...)
}
if len(a.OTKeys) > MaxOneTimeKeys {
a.OTKeys = a.OTKeys[:MaxOneTimeKeys]
}
return nil
}
// NewOutboundSession creates a new outbound session to a
// given curve25519 identity Key and one time key.
func (a Account) NewOutboundSession(theirIdentityKey, theirOneTimeKey id.Curve25519) (*session.OlmSession, error) {
if len(theirIdentityKey) == 0 || len(theirOneTimeKey) == 0 {
return nil, fmt.Errorf("outbound session: %w", goolm.ErrEmptyInput)
}
theirIdentityKeyDecoded, err := base64.RawStdEncoding.DecodeString(string(theirIdentityKey))
if err != nil {
return nil, err
}
theirOneTimeKeyDecoded, err := base64.RawStdEncoding.DecodeString(string(theirOneTimeKey))
if err != nil {
return nil, err
}
s, err := session.NewOutboundOlmSession(a.IdKeys.Curve25519, theirIdentityKeyDecoded, theirOneTimeKeyDecoded)
if err != nil {
return nil, err
}
return s, nil
}
// NewInboundSession creates a new inbound session from an incoming PRE_KEY message.
func (a Account) NewInboundSession(theirIdentityKey *id.Curve25519, oneTimeKeyMsg []byte) (*session.OlmSession, error) {
if len(oneTimeKeyMsg) == 0 {
return nil, fmt.Errorf("inbound session: %w", goolm.ErrEmptyInput)
}
var theirIdentityKeyDecoded *crypto.Curve25519PublicKey
var err error
if theirIdentityKey != nil {
theirIdentityKeyDecodedByte, err := base64.RawStdEncoding.DecodeString(string(*theirIdentityKey))
if err != nil {
return nil, err
}
theirIdentityKeyCurve := crypto.Curve25519PublicKey(theirIdentityKeyDecodedByte)
theirIdentityKeyDecoded = &theirIdentityKeyCurve
}
s, err := session.NewInboundOlmSession(theirIdentityKeyDecoded, oneTimeKeyMsg, a.searchOTKForOur, a.IdKeys.Curve25519)
if err != nil {
return nil, err
}
return s, nil
}
func (a Account) searchOTKForOur(toFind crypto.Curve25519PublicKey) *crypto.OneTimeKey {
for curIndex := range a.OTKeys {
if a.OTKeys[curIndex].Key.PublicKey.Equal(toFind) {
return &a.OTKeys[curIndex]
}
}
if a.NumFallbackKeys >= 1 && a.CurrentFallbackKey.Key.PublicKey.Equal(toFind) {
return &a.CurrentFallbackKey
}
if a.NumFallbackKeys >= 2 && a.PrevFallbackKey.Key.PublicKey.Equal(toFind) {
return &a.PrevFallbackKey
}
return nil
}
// RemoveOneTimeKeys removes the one time key in this Account which matches the one time key in the session s.
func (a *Account) RemoveOneTimeKeys(s *session.OlmSession) {
toFind := s.BobOneTimeKey
for curIndex := range a.OTKeys {
if a.OTKeys[curIndex].Key.PublicKey.Equal(toFind) {
//Remove and return
a.OTKeys[curIndex] = a.OTKeys[len(a.OTKeys)-1]
a.OTKeys = a.OTKeys[:len(a.OTKeys)-1]
return
}
}
//if the key is a fallback or prevFallback, don't remove it
}
// GenFallbackKey generates a new fallback key. The old fallback key is stored in a.PrevFallbackKey overwriting any previous PrevFallbackKey. If reader is nil, crypto/rand is used for the key creation.
func (a *Account) GenFallbackKey(reader io.Reader) error {
a.PrevFallbackKey = a.CurrentFallbackKey
key := crypto.OneTimeKey{
Published: false,
ID: a.NextOneTimeKeyID,
}
newKP, err := crypto.Curve25519GenerateKey(reader)
if err != nil {
return err
}
key.Key = newKP
a.NextOneTimeKeyID++
if a.NumFallbackKeys < 2 {
a.NumFallbackKeys++
}
a.CurrentFallbackKey = key
return nil
}
// FallbackKey returns the public part of the current fallback key of the Account.
// The returned data is a map with the mapping of key id to base64-encoded Curve25519 key.
func (a Account) FallbackKey() map[string]id.Curve25519 {
keys := make(map[string]id.Curve25519)
if a.NumFallbackKeys >= 1 {
keys[a.CurrentFallbackKey.KeyIDEncoded()] = id.Curve25519(a.CurrentFallbackKey.PublicKeyEncoded())
}
return keys
}
//FallbackKeyJSON returns the public part of the current fallback key of the Account as a JSON string.
//
//The returned JSON is of format:
/*
{
curve25519: {
"AAAAAA": "wo76WcYtb0Vk/pBOdmduiGJ0wIEjW4IBMbbQn7aSnTo"
}
}
*/
func (a Account) FallbackKeyJSON() ([]byte, error) {
res := make(map[string]map[string]id.Curve25519)
fbk := a.FallbackKey()
res["curve25519"] = fbk
return json.Marshal(res)
}
// FallbackKeyUnpublished returns the public part of the current fallback key of the Account only if it is unpublished.
// The returned data is a map with the mapping of key id to base64-encoded Curve25519 key.
func (a Account) FallbackKeyUnpublished() map[string]id.Curve25519 {
keys := make(map[string]id.Curve25519)
if a.NumFallbackKeys >= 1 && !a.CurrentFallbackKey.Published {
keys[a.CurrentFallbackKey.KeyIDEncoded()] = id.Curve25519(a.CurrentFallbackKey.PublicKeyEncoded())
}
return keys
}
//FallbackKeyUnpublishedJSON returns the public part of the current fallback key, only if it is unpublished, of the Account as a JSON string.
//
//The returned JSON is of format:
/*
{
curve25519: {
"AAAAAA": "wo76WcYtb0Vk/pBOdmduiGJ0wIEjW4IBMbbQn7aSnTo"
}
}
*/
func (a Account) FallbackKeyUnpublishedJSON() ([]byte, error) {
res := make(map[string]map[string]id.Curve25519)
fbk := a.FallbackKeyUnpublished()
res["curve25519"] = fbk
return json.Marshal(res)
}
// ForgetOldFallbackKey resets the previous fallback key in the account.
func (a *Account) ForgetOldFallbackKey() {
if a.NumFallbackKeys >= 2 {
a.NumFallbackKeys = 1
a.PrevFallbackKey = crypto.OneTimeKey{}
}
}
// Unpickle decodes the base64 encoded string and decrypts the result with the key.
// The decrypted value is then passed to UnpickleLibOlm.
func (a *Account) Unpickle(pickled, key []byte) error {
decrypted, err := cipher.Unpickle(key, pickled)
if err != nil {
return err
}
_, err = a.UnpickleLibOlm(decrypted)
return err
}
// UnpickleLibOlm decodes the unencryted value and populates the Account accordingly. It returns the number of bytes read.
func (a *Account) UnpickleLibOlm(value []byte) (int, error) {
//First 4 bytes are the accountPickleVersion
pickledVersion, curPos, err := libolmpickle.UnpickleUInt32(value)
if err != nil {
return 0, err
}
switch pickledVersion {
case accountPickleVersionLibOLM, 3, 2:
default:
return 0, fmt.Errorf("unpickle account: %w", goolm.ErrBadVersion)
}
//read ed25519 key pair
readBytes, err := a.IdKeys.Ed25519.UnpickleLibOlm(value[curPos:])
if err != nil {
return 0, err
}
curPos += readBytes
//read curve25519 key pair
readBytes, err = a.IdKeys.Curve25519.UnpickleLibOlm(value[curPos:])
if err != nil {
return 0, err
}
curPos += readBytes
//Read number of onetimeKeys
numberOTKeys, readBytes, err := libolmpickle.UnpickleUInt32(value[curPos:])
if err != nil {
return 0, err
}
curPos += readBytes
//Read i one time keys
a.OTKeys = make([]crypto.OneTimeKey, numberOTKeys)
for i := uint32(0); i < numberOTKeys; i++ {
readBytes, err := a.OTKeys[i].UnpickleLibOlm(value[curPos:])
if err != nil {
return 0, err
}
curPos += readBytes
}
if pickledVersion <= 2 {
// version 2 did not have fallback keys
a.NumFallbackKeys = 0
} else if pickledVersion == 3 {
// version 3 used the published flag to indicate how many fallback keys
// were present (we'll have to assume that the keys were published)
readBytes, err := a.CurrentFallbackKey.UnpickleLibOlm(value[curPos:])
if err != nil {
return 0, err
}
curPos += readBytes
readBytes, err = a.PrevFallbackKey.UnpickleLibOlm(value[curPos:])
if err != nil {
return 0, err
}
curPos += readBytes
if a.CurrentFallbackKey.Published {
if a.PrevFallbackKey.Published {
a.NumFallbackKeys = 2
} else {
a.NumFallbackKeys = 1
}
} else {
a.NumFallbackKeys = 0
}
} else {
//Read number of fallback keys
numFallbackKeys, readBytes, err := libolmpickle.UnpickleUInt8(value[curPos:])
if err != nil {
return 0, err
}
curPos += readBytes
a.NumFallbackKeys = numFallbackKeys
if a.NumFallbackKeys >= 1 {
readBytes, err := a.CurrentFallbackKey.UnpickleLibOlm(value[curPos:])
if err != nil {
return 0, err
}
curPos += readBytes
if a.NumFallbackKeys >= 2 {
readBytes, err := a.PrevFallbackKey.UnpickleLibOlm(value[curPos:])
if err != nil {
return 0, err
}
curPos += readBytes
}
}
}
//Read next onetime key id
nextOTKeyID, readBytes, err := libolmpickle.UnpickleUInt32(value[curPos:])
if err != nil {
return 0, err
}
curPos += readBytes
a.NextOneTimeKeyID = nextOTKeyID
return curPos, nil
}
// Pickle returns a base64 encoded and with key encrypted pickled account using PickleLibOlm().
func (a Account) Pickle(key []byte) ([]byte, error) {
pickeledBytes := make([]byte, a.PickleLen())
written, err := a.PickleLibOlm(pickeledBytes)
if err != nil {
return nil, err
}
if written != len(pickeledBytes) {
return nil, errors.New("number of written bytes not correct")
}
encrypted, err := cipher.Pickle(key, pickeledBytes)
if err != nil {
return nil, err
}
return encrypted, nil
}
// PickleLibOlm encodes the Account into target. target has to have a size of at least PickleLen() and is written to from index 0.
// It returns the number of bytes written.
func (a Account) PickleLibOlm(target []byte) (int, error) {
if len(target) < a.PickleLen() {
return 0, fmt.Errorf("pickle account: %w", goolm.ErrValueTooShort)
}
written := libolmpickle.PickleUInt32(accountPickleVersionLibOLM, target)
writtenEdKey, err := a.IdKeys.Ed25519.PickleLibOlm(target[written:])
if err != nil {
return 0, fmt.Errorf("pickle account: %w", err)
}
written += writtenEdKey
writtenCurveKey, err := a.IdKeys.Curve25519.PickleLibOlm(target[written:])
if err != nil {
return 0, fmt.Errorf("pickle account: %w", err)
}
written += writtenCurveKey
written += libolmpickle.PickleUInt32(uint32(len(a.OTKeys)), target[written:])
for _, curOTKey := range a.OTKeys {
writtenOT, err := curOTKey.PickleLibOlm(target[written:])
if err != nil {
return 0, fmt.Errorf("pickle account: %w", err)
}
written += writtenOT
}
written += libolmpickle.PickleUInt8(a.NumFallbackKeys, target[written:])
if a.NumFallbackKeys >= 1 {
writtenOT, err := a.CurrentFallbackKey.PickleLibOlm(target[written:])
if err != nil {
return 0, fmt.Errorf("pickle account: %w", err)
}
written += writtenOT
if a.NumFallbackKeys >= 2 {
writtenOT, err := a.PrevFallbackKey.PickleLibOlm(target[written:])
if err != nil {
return 0, fmt.Errorf("pickle account: %w", err)
}
written += writtenOT
}
}
written += libolmpickle.PickleUInt32(a.NextOneTimeKeyID, target[written:])
return written, nil
}
// PickleLen returns the number of bytes the pickled Account will have.
func (a Account) PickleLen() int {
length := libolmpickle.PickleUInt32Len(accountPickleVersionLibOLM)
length += a.IdKeys.Ed25519.PickleLen()
length += a.IdKeys.Curve25519.PickleLen()
length += libolmpickle.PickleUInt32Len(uint32(len(a.OTKeys)))
length += (len(a.OTKeys) * (&crypto.OneTimeKey{}).PickleLen())
length += libolmpickle.PickleUInt8Len(a.NumFallbackKeys)
length += (int(a.NumFallbackKeys) * (&crypto.OneTimeKey{}).PickleLen())
length += libolmpickle.PickleUInt32Len(a.NextOneTimeKeyID)
return length
}

22
vendor/maunium.net/go/mautrix/crypto/goolm/base64.go generated vendored Normal file
View File

@@ -0,0 +1,22 @@
package goolm
import (
"encoding/base64"
)
// Deprecated: base64.RawStdEncoding should be used directly
func Base64Decode(input []byte) ([]byte, error) {
decoded := make([]byte, base64.RawStdEncoding.DecodedLen(len(input)))
writtenBytes, err := base64.RawStdEncoding.Decode(decoded, input)
if err != nil {
return nil, err
}
return decoded[:writtenBytes], nil
}
// Deprecated: base64.RawStdEncoding should be used directly
func Base64Encode(input []byte) []byte {
encoded := make([]byte, base64.RawStdEncoding.EncodedLen(len(input)))
base64.RawStdEncoding.Encode(encoded, input)
return encoded
}

View File

@@ -0,0 +1,96 @@
package cipher
import (
"bytes"
"io"
"maunium.net/go/mautrix/crypto/goolm/crypto"
)
// derivedAESKeys stores the derived keys for the AESSHA256 cipher
type derivedAESKeys struct {
key []byte
hmacKey []byte
iv []byte
}
// deriveAESKeys derives three keys for the AESSHA256 cipher
func deriveAESKeys(kdfInfo []byte, key []byte) (*derivedAESKeys, error) {
hkdf := crypto.HKDFSHA256(key, nil, kdfInfo)
keys := &derivedAESKeys{
key: make([]byte, 32),
hmacKey: make([]byte, 32),
iv: make([]byte, 16),
}
if _, err := io.ReadFull(hkdf, keys.key); err != nil {
return nil, err
}
if _, err := io.ReadFull(hkdf, keys.hmacKey); err != nil {
return nil, err
}
if _, err := io.ReadFull(hkdf, keys.iv); err != nil {
return nil, err
}
return keys, nil
}
// AESSha512BlockSize resturns the blocksize of the cipher AESSHA256.
func AESSha512BlockSize() int {
return crypto.AESCBCBlocksize()
}
// AESSHA256 is a valid cipher using AES with CBC and HKDFSha256.
type AESSHA256 struct {
kdfInfo []byte
}
// NewAESSHA256 returns a new AESSHA256 cipher with the key derive function info (kdfInfo).
func NewAESSHA256(kdfInfo []byte) *AESSHA256 {
return &AESSHA256{
kdfInfo: kdfInfo,
}
}
// Encrypt encrypts the plaintext with the key. The key is used to derive the actual encryption key (32 bytes) as well as the iv (16 bytes).
func (c AESSHA256) Encrypt(key, plaintext []byte) (ciphertext []byte, err error) {
keys, err := deriveAESKeys(c.kdfInfo, key)
if err != nil {
return nil, err
}
ciphertext, err = crypto.AESCBCEncrypt(keys.key, keys.iv, plaintext)
if err != nil {
return nil, err
}
return ciphertext, nil
}
// Decrypt decrypts the ciphertext with the key. The key is used to derive the actual encryption key (32 bytes) as well as the iv (16 bytes).
func (c AESSHA256) Decrypt(key, ciphertext []byte) (plaintext []byte, err error) {
keys, err := deriveAESKeys(c.kdfInfo, key)
if err != nil {
return nil, err
}
plaintext, err = crypto.AESCBCDecrypt(keys.key, keys.iv, ciphertext)
if err != nil {
return nil, err
}
return plaintext, nil
}
// MAC returns the MAC for the message using the key. The key is used to derive the actual mac key (32 bytes).
func (c AESSHA256) MAC(key, message []byte) ([]byte, error) {
keys, err := deriveAESKeys(c.kdfInfo, key)
if err != nil {
return nil, err
}
return crypto.HMACSHA256(keys.hmacKey, message), nil
}
// Verify checks the MAC of the message using the key against the givenMAC. The key is used to derive the actual mac key (32 bytes).
func (c AESSHA256) Verify(key, message, givenMAC []byte) (bool, error) {
mac, err := c.MAC(key, message)
if err != nil {
return false, err
}
return bytes.Equal(givenMAC, mac[:len(givenMAC)]), nil
}

View File

@@ -0,0 +1,17 @@
// cipher provides the methods and structs to do encryptions for olm/megolm.
package cipher
// Cipher defines a valid cipher.
type Cipher interface {
// Encrypt encrypts the plaintext.
Encrypt(key, plaintext []byte) (ciphertext []byte, err error)
// Decrypt decrypts the ciphertext.
Decrypt(key, ciphertext []byte) (plaintext []byte, err error)
//MAC returns the MAC of the message calculated with the key.
MAC(key, message []byte) ([]byte, error)
//Verify checks the MAC of the message calculated with the key against the givenMAC.
Verify(key, message, givenMAC []byte) (bool, error)
}

View File

@@ -0,0 +1,58 @@
package cipher
import (
"fmt"
"maunium.net/go/mautrix/crypto/goolm"
)
const (
kdfPickle = "Pickle" //used to derive the keys for encryption
pickleMACLength = 8
)
// PickleBlockSize returns the blocksize of the used cipher.
func PickleBlockSize() int {
return AESSha512BlockSize()
}
// Pickle encrypts the input with the key and the cipher AESSHA256. The result is then encoded in base64.
func Pickle(key, input []byte) ([]byte, error) {
pickleCipher := NewAESSHA256([]byte(kdfPickle))
ciphertext, err := pickleCipher.Encrypt(key, input)
if err != nil {
return nil, err
}
mac, err := pickleCipher.MAC(key, ciphertext)
if err != nil {
return nil, err
}
ciphertext = append(ciphertext, mac[:pickleMACLength]...)
encoded := goolm.Base64Encode(ciphertext)
return encoded, nil
}
// Unpickle decodes the input from base64 and decrypts the decoded input with the key and the cipher AESSHA256.
func Unpickle(key, input []byte) ([]byte, error) {
pickleCipher := NewAESSHA256([]byte(kdfPickle))
ciphertext, err := goolm.Base64Decode(input)
if err != nil {
return nil, err
}
//remove mac and check
verified, err := pickleCipher.Verify(key, ciphertext[:len(ciphertext)-pickleMACLength], ciphertext[len(ciphertext)-pickleMACLength:])
if err != nil {
return nil, err
}
if !verified {
return nil, fmt.Errorf("decrypt pickle: %w", goolm.ErrBadMAC)
}
//Set to next block size
targetCipherText := make([]byte, int(len(ciphertext)/PickleBlockSize())*PickleBlockSize())
copy(targetCipherText, ciphertext)
plaintext, err := pickleCipher.Decrypt(key, targetCipherText)
if err != nil {
return nil, err
}
return plaintext, nil
}

View File

@@ -0,0 +1,75 @@
package crypto
import (
"bytes"
"crypto/aes"
"crypto/cipher"
"fmt"
"maunium.net/go/mautrix/crypto/goolm"
)
// AESCBCBlocksize returns the blocksize of the encryption method
func AESCBCBlocksize() int {
return aes.BlockSize
}
// AESCBCEncrypt encrypts the plaintext with the key and iv. len(iv) must be equal to the blocksize!
func AESCBCEncrypt(key, iv, plaintext []byte) ([]byte, error) {
if len(key) == 0 {
return nil, fmt.Errorf("AESCBCEncrypt: %w", goolm.ErrNoKeyProvided)
}
if len(iv) != AESCBCBlocksize() {
return nil, fmt.Errorf("iv: %w", goolm.ErrNotBlocksize)
}
var cipherText []byte
plaintext = pkcs5Padding(plaintext, AESCBCBlocksize())
if len(plaintext)%AESCBCBlocksize() != 0 {
return nil, fmt.Errorf("message: %w", goolm.ErrNotMultipleBlocksize)
}
block, err := aes.NewCipher(key)
if err != nil {
return nil, err
}
cipherText = make([]byte, len(plaintext))
cbc := cipher.NewCBCEncrypter(block, iv)
cbc.CryptBlocks(cipherText, plaintext)
return cipherText, nil
}
// AESCBCDecrypt decrypts the ciphertext with the key and iv. len(iv) must be equal to the blocksize!
func AESCBCDecrypt(key, iv, ciphertext []byte) ([]byte, error) {
if len(key) == 0 {
return nil, fmt.Errorf("AESCBCEncrypt: %w", goolm.ErrNoKeyProvided)
}
if len(iv) != AESCBCBlocksize() {
return nil, fmt.Errorf("iv: %w", goolm.ErrNotBlocksize)
}
var block cipher.Block
var err error
block, err = aes.NewCipher(key)
if err != nil {
return nil, err
}
if len(ciphertext) < AESCBCBlocksize() {
return nil, fmt.Errorf("ciphertext: %w", goolm.ErrNotMultipleBlocksize)
}
cbc := cipher.NewCBCDecrypter(block, iv)
cbc.CryptBlocks(ciphertext, ciphertext)
return pkcs5Unpadding(ciphertext), nil
}
// pkcs5Padding paddes the plaintext to be used in the AESCBC encryption.
func pkcs5Padding(plaintext []byte, blockSize int) []byte {
padding := (blockSize - len(plaintext)%blockSize)
padtext := bytes.Repeat([]byte{byte(padding)}, padding)
return append(plaintext, padtext...)
}
// pkcs5Unpadding undoes the padding to the plaintext after AESCBC decryption.
func pkcs5Unpadding(plaintext []byte) []byte {
length := len(plaintext)
unpadding := int(plaintext[length-1])
return plaintext[:(length - unpadding)]
}

View File

@@ -0,0 +1,186 @@
package crypto
import (
"bytes"
"crypto/rand"
"encoding/base64"
"fmt"
"io"
"golang.org/x/crypto/curve25519"
"maunium.net/go/mautrix/crypto/goolm"
"maunium.net/go/mautrix/crypto/goolm/libolmpickle"
"maunium.net/go/mautrix/id"
)
const (
Curve25519KeyLength = curve25519.ScalarSize //The length of the private key.
curve25519PubKeyLength = 32
)
// Curve25519GenerateKey creates a new curve25519 key pair. If reader is nil, the random data is taken from crypto/rand.
func Curve25519GenerateKey(reader io.Reader) (Curve25519KeyPair, error) {
privateKeyByte := make([]byte, Curve25519KeyLength)
if reader == nil {
_, err := rand.Read(privateKeyByte)
if err != nil {
return Curve25519KeyPair{}, err
}
} else {
_, err := reader.Read(privateKeyByte)
if err != nil {
return Curve25519KeyPair{}, err
}
}
privateKey := Curve25519PrivateKey(privateKeyByte)
publicKey, err := privateKey.PubKey()
if err != nil {
return Curve25519KeyPair{}, err
}
return Curve25519KeyPair{
PrivateKey: Curve25519PrivateKey(privateKey),
PublicKey: Curve25519PublicKey(publicKey),
}, nil
}
// Curve25519GenerateFromPrivate creates a new curve25519 key pair with the private key given.
func Curve25519GenerateFromPrivate(private Curve25519PrivateKey) (Curve25519KeyPair, error) {
publicKey, err := private.PubKey()
if err != nil {
return Curve25519KeyPair{}, err
}
return Curve25519KeyPair{
PrivateKey: private,
PublicKey: Curve25519PublicKey(publicKey),
}, nil
}
// Curve25519KeyPair stores both parts of a curve25519 key.
type Curve25519KeyPair struct {
PrivateKey Curve25519PrivateKey `json:"private,omitempty"`
PublicKey Curve25519PublicKey `json:"public,omitempty"`
}
// B64Encoded returns a base64 encoded string of the public key.
func (c Curve25519KeyPair) B64Encoded() id.Curve25519 {
return c.PublicKey.B64Encoded()
}
// SharedSecret returns the shared secret between the key pair and the given public key.
func (c Curve25519KeyPair) SharedSecret(pubKey Curve25519PublicKey) ([]byte, error) {
return c.PrivateKey.SharedSecret(pubKey)
}
// PickleLibOlm encodes the key pair into target. target has to have a size of at least PickleLen() and is written to from index 0.
// It returns the number of bytes written.
func (c Curve25519KeyPair) PickleLibOlm(target []byte) (int, error) {
if len(target) < c.PickleLen() {
return 0, fmt.Errorf("pickle curve25519 key pair: %w", goolm.ErrValueTooShort)
}
written, err := c.PublicKey.PickleLibOlm(target)
if err != nil {
return 0, fmt.Errorf("pickle curve25519 key pair: %w", err)
}
if len(c.PrivateKey) != Curve25519KeyLength {
written += libolmpickle.PickleBytes(make([]byte, Curve25519KeyLength), target[written:])
} else {
written += libolmpickle.PickleBytes(c.PrivateKey, target[written:])
}
return written, nil
}
// UnpickleLibOlm decodes the unencryted value and populates the key pair accordingly. It returns the number of bytes read.
func (c *Curve25519KeyPair) UnpickleLibOlm(value []byte) (int, error) {
//unpickle PubKey
read, err := c.PublicKey.UnpickleLibOlm(value)
if err != nil {
return 0, err
}
//unpickle PrivateKey
privKey, readPriv, err := libolmpickle.UnpickleBytes(value[read:], Curve25519KeyLength)
if err != nil {
return read, err
}
c.PrivateKey = privKey
return read + readPriv, nil
}
// PickleLen returns the number of bytes the pickled key pair will have.
func (c Curve25519KeyPair) PickleLen() int {
lenPublic := c.PublicKey.PickleLen()
var lenPrivate int
if len(c.PrivateKey) != Curve25519KeyLength {
lenPrivate = libolmpickle.PickleBytesLen(make([]byte, Curve25519KeyLength))
} else {
lenPrivate = libolmpickle.PickleBytesLen(c.PrivateKey)
}
return lenPublic + lenPrivate
}
// Curve25519PrivateKey represents the private key for curve25519 usage
type Curve25519PrivateKey []byte
// Equal compares the private key to the given private key.
func (c Curve25519PrivateKey) Equal(x Curve25519PrivateKey) bool {
return bytes.Equal(c, x)
}
// PubKey returns the public key derived from the private key.
func (c Curve25519PrivateKey) PubKey() (Curve25519PublicKey, error) {
publicKey, err := curve25519.X25519(c, curve25519.Basepoint)
if err != nil {
return nil, err
}
return publicKey, nil
}
// SharedSecret returns the shared secret between the private key and the given public key.
func (c Curve25519PrivateKey) SharedSecret(pubKey Curve25519PublicKey) ([]byte, error) {
return curve25519.X25519(c, pubKey)
}
// Curve25519PublicKey represents the public key for curve25519 usage
type Curve25519PublicKey []byte
// Equal compares the public key to the given public key.
func (c Curve25519PublicKey) Equal(x Curve25519PublicKey) bool {
return bytes.Equal(c, x)
}
// B64Encoded returns a base64 encoded string of the public key.
func (c Curve25519PublicKey) B64Encoded() id.Curve25519 {
return id.Curve25519(base64.RawStdEncoding.EncodeToString(c))
}
// PickleLibOlm encodes the public key into target. target has to have a size of at least PickleLen() and is written to from index 0.
// It returns the number of bytes written.
func (c Curve25519PublicKey) PickleLibOlm(target []byte) (int, error) {
if len(target) < c.PickleLen() {
return 0, fmt.Errorf("pickle curve25519 public key: %w", goolm.ErrValueTooShort)
}
if len(c) != curve25519PubKeyLength {
return libolmpickle.PickleBytes(make([]byte, curve25519PubKeyLength), target), nil
}
return libolmpickle.PickleBytes(c, target), nil
}
// UnpickleLibOlm decodes the unencryted value and populates the public key accordingly. It returns the number of bytes read.
func (c *Curve25519PublicKey) UnpickleLibOlm(value []byte) (int, error) {
unpickled, readBytes, err := libolmpickle.UnpickleBytes(value, curve25519PubKeyLength)
if err != nil {
return 0, err
}
*c = unpickled
return readBytes, nil
}
// PickleLen returns the number of bytes the pickled public key will have.
func (c Curve25519PublicKey) PickleLen() int {
if len(c) != curve25519PubKeyLength {
return libolmpickle.PickleBytesLen(make([]byte, curve25519PubKeyLength))
}
return libolmpickle.PickleBytesLen(c)
}

View File

@@ -0,0 +1,181 @@
package crypto
import (
"bytes"
"crypto/ed25519"
"encoding/base64"
"fmt"
"io"
"maunium.net/go/mautrix/crypto/goolm"
"maunium.net/go/mautrix/crypto/goolm/libolmpickle"
"maunium.net/go/mautrix/id"
)
const (
ED25519SignatureSize = ed25519.SignatureSize //The length of a signature
)
// Ed25519GenerateKey creates a new ed25519 key pair. If reader is nil, the random data is taken from crypto/rand.
func Ed25519GenerateKey(reader io.Reader) (Ed25519KeyPair, error) {
publicKey, privateKey, err := ed25519.GenerateKey(reader)
if err != nil {
return Ed25519KeyPair{}, err
}
return Ed25519KeyPair{
PrivateKey: Ed25519PrivateKey(privateKey),
PublicKey: Ed25519PublicKey(publicKey),
}, nil
}
// Ed25519GenerateFromPrivate creates a new ed25519 key pair with the private key given.
func Ed25519GenerateFromPrivate(privKey Ed25519PrivateKey) Ed25519KeyPair {
return Ed25519KeyPair{
PrivateKey: privKey,
PublicKey: privKey.PubKey(),
}
}
// Ed25519GenerateFromSeed creates a new ed25519 key pair with a given seed.
func Ed25519GenerateFromSeed(seed []byte) Ed25519KeyPair {
privKey := Ed25519PrivateKey(ed25519.NewKeyFromSeed(seed))
return Ed25519KeyPair{
PrivateKey: privKey,
PublicKey: privKey.PubKey(),
}
}
// Ed25519KeyPair stores both parts of a ed25519 key.
type Ed25519KeyPair struct {
PrivateKey Ed25519PrivateKey `json:"private,omitempty"`
PublicKey Ed25519PublicKey `json:"public,omitempty"`
}
// B64Encoded returns a base64 encoded string of the public key.
func (c Ed25519KeyPair) B64Encoded() id.Ed25519 {
return id.Ed25519(base64.RawStdEncoding.EncodeToString(c.PublicKey))
}
// Sign returns the signature for the message.
func (c Ed25519KeyPair) Sign(message []byte) []byte {
return c.PrivateKey.Sign(message)
}
// Verify checks the signature of the message against the givenSignature
func (c Ed25519KeyPair) Verify(message, givenSignature []byte) bool {
return c.PublicKey.Verify(message, givenSignature)
}
// PickleLibOlm encodes the key pair into target. target has to have a size of at least PickleLen() and is written to from index 0.
// It returns the number of bytes written.
func (c Ed25519KeyPair) PickleLibOlm(target []byte) (int, error) {
if len(target) < c.PickleLen() {
return 0, fmt.Errorf("pickle ed25519 key pair: %w", goolm.ErrValueTooShort)
}
written, err := c.PublicKey.PickleLibOlm(target)
if err != nil {
return 0, fmt.Errorf("pickle ed25519 key pair: %w", err)
}
if len(c.PrivateKey) != ed25519.PrivateKeySize {
written += libolmpickle.PickleBytes(make([]byte, ed25519.PrivateKeySize), target[written:])
} else {
written += libolmpickle.PickleBytes(c.PrivateKey, target[written:])
}
return written, nil
}
// UnpickleLibOlm decodes the unencryted value and populates the key pair accordingly. It returns the number of bytes read.
func (c *Ed25519KeyPair) UnpickleLibOlm(value []byte) (int, error) {
//unpickle PubKey
read, err := c.PublicKey.UnpickleLibOlm(value)
if err != nil {
return 0, err
}
//unpickle PrivateKey
privKey, readPriv, err := libolmpickle.UnpickleBytes(value[read:], ed25519.PrivateKeySize)
if err != nil {
return read, err
}
c.PrivateKey = privKey
return read + readPriv, nil
}
// PickleLen returns the number of bytes the pickled key pair will have.
func (c Ed25519KeyPair) PickleLen() int {
lenPublic := c.PublicKey.PickleLen()
var lenPrivate int
if len(c.PrivateKey) != ed25519.PrivateKeySize {
lenPrivate = libolmpickle.PickleBytesLen(make([]byte, ed25519.PrivateKeySize))
} else {
lenPrivate = libolmpickle.PickleBytesLen(c.PrivateKey)
}
return lenPublic + lenPrivate
}
// Curve25519PrivateKey represents the private key for ed25519 usage. This is just a wrapper.
type Ed25519PrivateKey ed25519.PrivateKey
// Equal compares the private key to the given private key.
func (c Ed25519PrivateKey) Equal(x Ed25519PrivateKey) bool {
return bytes.Equal(c, x)
}
// PubKey returns the public key derived from the private key.
func (c Ed25519PrivateKey) PubKey() Ed25519PublicKey {
publicKey := ed25519.PrivateKey(c).Public()
return Ed25519PublicKey(publicKey.(ed25519.PublicKey))
}
// Sign returns the signature for the message.
func (c Ed25519PrivateKey) Sign(message []byte) []byte {
return ed25519.Sign(ed25519.PrivateKey(c), message)
}
// Ed25519PublicKey represents the public key for ed25519 usage. This is just a wrapper.
type Ed25519PublicKey ed25519.PublicKey
// Equal compares the public key to the given public key.
func (c Ed25519PublicKey) Equal(x Ed25519PublicKey) bool {
return bytes.Equal(c, x)
}
// B64Encoded returns a base64 encoded string of the public key.
func (c Ed25519PublicKey) B64Encoded() id.Curve25519 {
return id.Curve25519(base64.RawStdEncoding.EncodeToString(c))
}
// Verify checks the signature of the message against the givenSignature
func (c Ed25519PublicKey) Verify(message, givenSignature []byte) bool {
return ed25519.Verify(ed25519.PublicKey(c), message, givenSignature)
}
// PickleLibOlm encodes the public key into target. target has to have a size of at least PickleLen() and is written to from index 0.
// It returns the number of bytes written.
func (c Ed25519PublicKey) PickleLibOlm(target []byte) (int, error) {
if len(target) < c.PickleLen() {
return 0, fmt.Errorf("pickle ed25519 public key: %w", goolm.ErrValueTooShort)
}
if len(c) != ed25519.PublicKeySize {
return libolmpickle.PickleBytes(make([]byte, ed25519.PublicKeySize), target), nil
}
return libolmpickle.PickleBytes(c, target), nil
}
// UnpickleLibOlm decodes the unencryted value and populates the public key accordingly. It returns the number of bytes read.
func (c *Ed25519PublicKey) UnpickleLibOlm(value []byte) (int, error) {
unpickled, readBytes, err := libolmpickle.UnpickleBytes(value, ed25519.PublicKeySize)
if err != nil {
return 0, err
}
*c = unpickled
return readBytes, nil
}
// PickleLen returns the number of bytes the pickled public key will have.
func (c Ed25519PublicKey) PickleLen() int {
if len(c) != ed25519.PublicKeySize {
return libolmpickle.PickleBytesLen(make([]byte, ed25519.PublicKeySize))
}
return libolmpickle.PickleBytesLen(c)
}

View File

@@ -0,0 +1,29 @@
package crypto
import (
"crypto/hmac"
"crypto/sha256"
"io"
"golang.org/x/crypto/hkdf"
)
// HMACSHA256 returns the hash message authentication code with SHA-256 of the input with the key.
func HMACSHA256(key, input []byte) []byte {
hash := hmac.New(sha256.New, key)
hash.Write(input)
return hash.Sum(nil)
}
// SHA256 return the SHA-256 of the value.
func SHA256(value []byte) []byte {
hash := sha256.New()
hash.Write(value)
return hash.Sum(nil)
}
// HKDFSHA256 is the key deivation function based on HMAC and returns a reader based on input. salt and info can both be nil.
// The reader can be used to read an arbitary length of bytes which are based on all parameters.
func HKDFSHA256(input, salt, info []byte) io.Reader {
return hkdf.New(sha256.New, input, salt, info)
}

View File

@@ -0,0 +1,2 @@
// crpyto provides the nessesary encryption methods for olm/megolm
package crypto

View File

@@ -0,0 +1,95 @@
package crypto
import (
"encoding/base64"
"encoding/binary"
"fmt"
"maunium.net/go/mautrix/crypto/goolm"
"maunium.net/go/mautrix/crypto/goolm/libolmpickle"
"maunium.net/go/mautrix/id"
)
// OneTimeKey stores the information about a one time key.
type OneTimeKey struct {
ID uint32 `json:"id"`
Published bool `json:"published"`
Key Curve25519KeyPair `json:"key,omitempty"`
}
// Equal compares the one time key to the given one.
func (otk OneTimeKey) Equal(s OneTimeKey) bool {
if otk.ID != s.ID {
return false
}
if otk.Published != s.Published {
return false
}
if !otk.Key.PrivateKey.Equal(s.Key.PrivateKey) {
return false
}
if !otk.Key.PublicKey.Equal(s.Key.PublicKey) {
return false
}
return true
}
// PickleLibOlm encodes the key pair into target. target has to have a size of at least PickleLen() and is written to from index 0.
// It returns the number of bytes written.
func (c OneTimeKey) PickleLibOlm(target []byte) (int, error) {
if len(target) < c.PickleLen() {
return 0, fmt.Errorf("pickle one time key: %w", goolm.ErrValueTooShort)
}
written := libolmpickle.PickleUInt32(uint32(c.ID), target)
written += libolmpickle.PickleBool(c.Published, target[written:])
writtenKey, err := c.Key.PickleLibOlm(target[written:])
if err != nil {
return 0, fmt.Errorf("pickle one time key: %w", err)
}
written += writtenKey
return written, nil
}
// UnpickleLibOlm decodes the unencryted value and populates the OneTimeKey accordingly. It returns the number of bytes read.
func (c *OneTimeKey) UnpickleLibOlm(value []byte) (int, error) {
totalReadBytes := 0
id, readBytes, err := libolmpickle.UnpickleUInt32(value)
if err != nil {
return 0, err
}
totalReadBytes += readBytes
c.ID = id
published, readBytes, err := libolmpickle.UnpickleBool(value[totalReadBytes:])
if err != nil {
return 0, err
}
totalReadBytes += readBytes
c.Published = published
readBytes, err = c.Key.UnpickleLibOlm(value[totalReadBytes:])
if err != nil {
return 0, err
}
totalReadBytes += readBytes
return totalReadBytes, nil
}
// PickleLen returns the number of bytes the pickled OneTimeKey will have.
func (c OneTimeKey) PickleLen() int {
length := 0
length += libolmpickle.PickleUInt32Len(c.ID)
length += libolmpickle.PickleBoolLen(c.Published)
length += c.Key.PickleLen()
return length
}
// KeyIDEncoded returns the base64 encoded id.
func (c OneTimeKey) KeyIDEncoded() string {
resSlice := make([]byte, 4)
binary.BigEndian.PutUint32(resSlice, c.ID)
return base64.RawStdEncoding.EncodeToString(resSlice)
}
// PublicKeyEncoded returns the base64 encoded public key
func (c OneTimeKey) PublicKeyEncoded() id.Curve25519 {
return c.Key.PublicKey.B64Encoded()
}

30
vendor/maunium.net/go/mautrix/crypto/goolm/errors.go generated vendored Normal file
View File

@@ -0,0 +1,30 @@
package goolm
import (
"errors"
)
// Those are the most common used errors
var (
ErrBadSignature = errors.New("bad signature")
ErrBadMAC = errors.New("bad mac")
ErrBadMessageFormat = errors.New("bad message format")
ErrBadVerification = errors.New("bad verification")
ErrWrongProtocolVersion = errors.New("wrong protocol version")
ErrEmptyInput = errors.New("empty input")
ErrNoKeyProvided = errors.New("no key")
ErrBadMessageKeyID = errors.New("bad message key id")
ErrRatchetNotAvailable = errors.New("ratchet not available: attempt to decode a message whose index is earlier than our earliest known session key")
ErrMsgIndexTooHigh = errors.New("message index too high")
ErrProtocolViolation = errors.New("not protocol message order")
ErrMessageKeyNotFound = errors.New("message key not found")
ErrChainTooHigh = errors.New("chain index too high")
ErrBadInput = errors.New("bad input")
ErrBadVersion = errors.New("wrong version")
ErrNotBlocksize = errors.New("length != blocksize")
ErrNotMultipleBlocksize = errors.New("length not a multiple of the blocksize")
ErrWrongPickleVersion = errors.New("wrong pickle version")
ErrValueTooShort = errors.New("value too short")
ErrInputToSmall = errors.New("input too small (truncated?)")
ErrOverflow = errors.New("overflow")
)

View File

@@ -0,0 +1,41 @@
package libolmpickle
import (
"encoding/binary"
)
func PickleUInt8(value uint8, target []byte) int {
target[0] = value
return 1
}
func PickleUInt8Len(value uint8) int {
return 1
}
func PickleBool(value bool, target []byte) int {
if value {
target[0] = 0x01
} else {
target[0] = 0x00
}
return 1
}
func PickleBoolLen(value bool) int {
return 1
}
func PickleBytes(value, target []byte) int {
return copy(target, value)
}
func PickleBytesLen(value []byte) int {
return len(value)
}
func PickleUInt32(value uint32, target []byte) int {
res := make([]byte, 4) //4 bytes for int32
binary.BigEndian.PutUint32(res, value)
return copy(target, res)
}
func PickleUInt32Len(value uint32) int {
return 4
}

View File

@@ -0,0 +1,53 @@
package libolmpickle
import (
"fmt"
"maunium.net/go/mautrix/crypto/goolm"
)
func isZeroByteSlice(bytes []byte) bool {
b := byte(0)
for _, s := range bytes {
b |= s
}
return b == 0
}
func UnpickleUInt8(value []byte) (uint8, int, error) {
if len(value) < 1 {
return 0, 0, fmt.Errorf("unpickle uint8: %w", goolm.ErrValueTooShort)
}
return value[0], 1, nil
}
func UnpickleBool(value []byte) (bool, int, error) {
if len(value) < 1 {
return false, 0, fmt.Errorf("unpickle bool: %w", goolm.ErrValueTooShort)
}
return value[0] != uint8(0x00), 1, nil
}
func UnpickleBytes(value []byte, length int) ([]byte, int, error) {
if len(value) < length {
return nil, 0, fmt.Errorf("unpickle bytes: %w", goolm.ErrValueTooShort)
}
resp := value[:length]
if isZeroByteSlice(resp) {
return nil, length, nil
}
return resp, length, nil
}
func UnpickleUInt32(value []byte) (uint32, int, error) {
if len(value) < 4 {
return 0, 0, fmt.Errorf("unpickle uint32: %w", goolm.ErrValueTooShort)
}
var res uint32
count := 0
for i := 3; i >= 0; i-- {
res |= uint32(value[count]) << (8 * i)
count++
}
return res, 4, nil
}

6
vendor/maunium.net/go/mautrix/crypto/goolm/main.go generated vendored Normal file
View File

@@ -0,0 +1,6 @@
// Package goolm is a pure Go implementation of libolm. Libolm is a cryptographic library used for end-to-end encryption in Matrix and written in C++.
// With goolm there is no need to use cgo when building Matrix clients in go.
/*
This package contains the possible errors which can occur as well as some simple functions. All the 'action' happens in the subdirectories.
*/
package goolm

View File

@@ -0,0 +1,234 @@
// megolm provides the ratchet used by the megolm protocol
package megolm
import (
"crypto/rand"
"fmt"
"maunium.net/go/mautrix/crypto/goolm"
"maunium.net/go/mautrix/crypto/goolm/cipher"
"maunium.net/go/mautrix/crypto/goolm/crypto"
"maunium.net/go/mautrix/crypto/goolm/libolmpickle"
"maunium.net/go/mautrix/crypto/goolm/message"
"maunium.net/go/mautrix/crypto/goolm/utilities"
)
const (
megolmPickleVersion uint8 = 1
)
const (
protocolVersion = 3
RatchetParts = 4 // number of ratchet parts
RatchetPartLength = 256 / 8 // length of each ratchet part in bytes
)
var RatchetCipher = cipher.NewAESSHA256([]byte("MEGOLM_KEYS"))
// hasKeySeed are the seed for the different ratchet parts
var hashKeySeeds [RatchetParts][]byte = [RatchetParts][]byte{
{0x00},
{0x01},
{0x02},
{0x03},
}
// Ratchet represents the megolm ratchet as described in
//
// https://gitlab.matrix.org/matrix-org/olm/-/blob/master/docs/megolm.md
type Ratchet struct {
Data [RatchetParts * RatchetPartLength]byte `json:"data"`
Counter uint32 `json:"counter"`
}
// New creates a new ratchet with counter set to counter and the ratchet data set to data.
func New(counter uint32, data [RatchetParts * RatchetPartLength]byte) (*Ratchet, error) {
m := &Ratchet{
Counter: counter,
Data: data,
}
return m, nil
}
// NewWithRandom creates a new ratchet with counter set to counter an the data filled with random values.
func NewWithRandom(counter uint32) (*Ratchet, error) {
var data [RatchetParts * RatchetPartLength]byte
_, err := rand.Read(data[:])
if err != nil {
return nil, err
}
return New(counter, data)
}
// rehashPart rehases the part of the ratchet data with the base defined as from storing into the target to.
func (m *Ratchet) rehashPart(from, to int) {
newData := crypto.HMACSHA256(m.Data[from*RatchetPartLength:from*RatchetPartLength+RatchetPartLength], hashKeySeeds[to])
copy(m.Data[to*RatchetPartLength:], newData[:RatchetPartLength])
}
// Advance advances the ratchet one step.
func (m *Ratchet) Advance() {
var mask uint32 = 0x00FFFFFF
var h int
m.Counter++
// figure out how much we need to rekey
for h < RatchetParts {
if (m.Counter & mask) == 0 {
break
}
h++
mask >>= 8
}
// now update R(h)...R(3) based on R(h)
for i := RatchetParts - 1; i >= h; i-- {
m.rehashPart(h, i)
}
}
// AdvanceTo advances the ratchet so that the ratchet counter = target
func (m *Ratchet) AdvanceTo(target uint32) {
//starting with R0, see if we need to update each part of the hash
for j := 0; j < RatchetParts; j++ {
shift := uint32((RatchetParts - j - 1) * 8)
mask := (^uint32(0)) << shift
// how many times do we need to rehash this part?
// '& 0xff' ensures we handle integer wraparound correctly
steps := ((target >> shift) - m.Counter>>shift) & uint32(0xff)
if steps == 0 {
/*
deal with the edge case where m.Counter is slightly larger
than target. This should only happen for R(0), and implies
that target has wrapped around and we need to advance R(0)
256 times.
*/
if target < m.Counter {
steps = 0x100
} else {
continue
}
}
// for all but the last step, we can just bump R(j) without regard to R(j+1)...R(3).
for steps > 1 {
m.rehashPart(j, j)
steps--
}
/*
on the last step we also need to bump R(j+1)...R(3).
(Theoretically, we could skip bumping R(j+2) if we're going to bump
R(j+1) again, but the code to figure that out is a bit baroque and
doesn't save us much).
*/
for k := 3; k >= j; k-- {
m.rehashPart(j, k)
}
m.Counter = target & mask
}
}
// Encrypt encrypts the message in a message.GroupMessage with MAC and signature.
// The output is base64 encoded.
func (r *Ratchet) Encrypt(plaintext []byte, key *crypto.Ed25519KeyPair) ([]byte, error) {
var err error
encryptedText, err := RatchetCipher.Encrypt(r.Data[:], plaintext)
if err != nil {
return nil, fmt.Errorf("cipher encrypt: %w", err)
}
message := &message.GroupMessage{}
message.Version = protocolVersion
message.MessageIndex = r.Counter
message.Ciphertext = encryptedText
//creating the mac and signing is done in encode
output, err := message.EncodeAndMacAndSign(r.Data[:], RatchetCipher, key)
if err != nil {
return nil, err
}
r.Advance()
return output, nil
}
// SessionSharingMessage creates a message in the session sharing format.
func (r Ratchet) SessionSharingMessage(key crypto.Ed25519KeyPair) ([]byte, error) {
m := message.MegolmSessionSharing{}
m.Counter = r.Counter
m.RatchetData = r.Data
encoded := m.EncodeAndSign(key)
return goolm.Base64Encode(encoded), nil
}
// SessionExportMessage creates a message in the session export format.
func (r Ratchet) SessionExportMessage(key crypto.Ed25519PublicKey) ([]byte, error) {
m := message.MegolmSessionExport{}
m.Counter = r.Counter
m.RatchetData = r.Data
m.PublicKey = key
encoded := m.Encode()
return goolm.Base64Encode(encoded), nil
}
// Decrypt decrypts the ciphertext and verifies the MAC but not the signature.
func (r Ratchet) Decrypt(ciphertext []byte, signingkey *crypto.Ed25519PublicKey, msg *message.GroupMessage) ([]byte, error) {
//verify mac
verifiedMAC, err := msg.VerifyMACInline(r.Data[:], RatchetCipher, ciphertext)
if err != nil {
return nil, err
}
if !verifiedMAC {
return nil, fmt.Errorf("decrypt: %w", goolm.ErrBadMAC)
}
return RatchetCipher.Decrypt(r.Data[:], msg.Ciphertext)
}
// PickleAsJSON returns a ratchet as a base64 string encrypted using the supplied key. The unencrypted representation of the Account is in JSON format.
func (r Ratchet) PickleAsJSON(key []byte) ([]byte, error) {
return utilities.PickleAsJSON(r, megolmPickleVersion, key)
}
// UnpickleAsJSON updates a ratchet by a base64 encrypted string using the supplied key. The unencrypted representation has to be in JSON format.
func (r *Ratchet) UnpickleAsJSON(pickled, key []byte) error {
return utilities.UnpickleAsJSON(r, pickled, key, megolmPickleVersion)
}
// UnpickleLibOlm decodes the unencryted value and populates the Ratchet accordingly. It returns the number of bytes read.
func (r *Ratchet) UnpickleLibOlm(unpickled []byte) (int, error) {
//read ratchet data
curPos := 0
ratchetData, readBytes, err := libolmpickle.UnpickleBytes(unpickled, RatchetParts*RatchetPartLength)
if err != nil {
return 0, err
}
copy(r.Data[:], ratchetData)
curPos += readBytes
//Read counter
counter, readBytes, err := libolmpickle.UnpickleUInt32(unpickled[curPos:])
if err != nil {
return 0, err
}
curPos += readBytes
r.Counter = counter
return curPos, nil
}
// PickleLibOlm encodes the ratchet into target. target has to have a size of at least PickleLen() and is written to from index 0.
// It returns the number of bytes written.
func (r Ratchet) PickleLibOlm(target []byte) (int, error) {
if len(target) < r.PickleLen() {
return 0, fmt.Errorf("pickle account: %w", goolm.ErrValueTooShort)
}
written := libolmpickle.PickleBytes(r.Data[:], target)
written += libolmpickle.PickleUInt32(r.Counter, target[written:])
return written, nil
}
// PickleLen returns the number of bytes the pickled ratchet will have.
func (r Ratchet) PickleLen() int {
length := libolmpickle.PickleBytesLen(r.Data[:])
length += libolmpickle.PickleUInt32Len(r.Counter)
return length
}

View File

@@ -0,0 +1,70 @@
package message
import (
"encoding/binary"
"maunium.net/go/mautrix/crypto/goolm"
)
// checkDecodeErr checks if there was an error during decode.
func checkDecodeErr(readBytes int) error {
if readBytes == 0 {
//end reached
return goolm.ErrInputToSmall
}
if readBytes < 0 {
return goolm.ErrOverflow
}
return nil
}
// decodeVarInt decodes a single big-endian encoded varint.
func decodeVarInt(input []byte) (uint32, int) {
value, readBytes := binary.Uvarint(input)
return uint32(value), readBytes
}
// decodeVarString decodes the length of the string (varint) and returns the actual string
func decodeVarString(input []byte) ([]byte, int) {
stringLen, readBytes := decodeVarInt(input)
if readBytes <= 0 {
return nil, readBytes
}
input = input[readBytes:]
value := input[:stringLen]
readBytes += int(stringLen)
return value, readBytes
}
// encodeVarIntByteLength returns the number of bytes needed to encode the uint32.
func encodeVarIntByteLength(input uint32) int {
result := 1
for input >= 128 {
result++
input >>= 7
}
return result
}
// encodeVarStringByteLength returns the number of bytes needed to encode the input.
func encodeVarStringByteLength(input []byte) int {
result := encodeVarIntByteLength(uint32(len(input)))
result += len(input)
return result
}
// encodeVarInt encodes a single uint32
func encodeVarInt(input uint32) []byte {
out := make([]byte, encodeVarIntByteLength(input))
binary.PutUvarint(out, uint64(input))
return out
}
// encodeVarString encodes the length of the input (varint) and appends the actual input
func encodeVarString(input []byte) []byte {
out := make([]byte, encodeVarStringByteLength(input))
length := encodeVarInt(uint32(len(input)))
copy(out, length)
copy(out[len(length):], input)
return out
}

View File

@@ -0,0 +1,144 @@
package message
import (
"bytes"
"maunium.net/go/mautrix/crypto/goolm/cipher"
"maunium.net/go/mautrix/crypto/goolm/crypto"
)
const (
messageIndexTag = 0x08
cipherTextTag = 0x12
countMACBytesGroupMessage = 8
)
// GroupMessage represents a message in the group message format.
type GroupMessage struct {
Version byte `json:"version"`
MessageIndex uint32 `json:"index"`
Ciphertext []byte `json:"ciphertext"`
HasMessageIndex bool `json:"has_index"`
}
// Decodes decodes the input and populates the corresponding fileds. MAC and signature are ignored but have to be present.
func (r *GroupMessage) Decode(input []byte) error {
r.Version = 0
r.MessageIndex = 0
r.Ciphertext = nil
if len(input) == 0 {
return nil
}
//first Byte is always version
r.Version = input[0]
curPos := 1
for curPos < len(input)-countMACBytesGroupMessage-crypto.ED25519SignatureSize {
//Read Key
curKey, readBytes := decodeVarInt(input[curPos:])
if err := checkDecodeErr(readBytes); err != nil {
return err
}
curPos += readBytes
if (curKey & 0b111) == 0 {
//The value is of type varint
value, readBytes := decodeVarInt(input[curPos:])
if err := checkDecodeErr(readBytes); err != nil {
return err
}
curPos += readBytes
switch curKey {
case messageIndexTag:
r.MessageIndex = value
r.HasMessageIndex = true
}
} else if (curKey & 0b111) == 2 {
//The value is of type string
value, readBytes := decodeVarString(input[curPos:])
if err := checkDecodeErr(readBytes); err != nil {
return err
}
curPos += readBytes
switch curKey {
case cipherTextTag:
r.Ciphertext = value
}
}
}
return nil
}
// EncodeAndMacAndSign encodes the message, creates the mac with the key and the cipher and signs the message.
// If macKey or cipher is nil, no mac is appended. If signKey is nil, no signature is appended.
func (r *GroupMessage) EncodeAndMacAndSign(macKey []byte, cipher cipher.Cipher, signKey *crypto.Ed25519KeyPair) ([]byte, error) {
var lengthOfMessage int
lengthOfMessage += 1 //Version
lengthOfMessage += encodeVarIntByteLength(messageIndexTag) + encodeVarIntByteLength(r.MessageIndex)
lengthOfMessage += encodeVarIntByteLength(cipherTextTag) + encodeVarStringByteLength(r.Ciphertext)
out := make([]byte, lengthOfMessage)
out[0] = r.Version
curPos := 1
encodedTag := encodeVarInt(messageIndexTag)
copy(out[curPos:], encodedTag)
curPos += len(encodedTag)
encodedValue := encodeVarInt(r.MessageIndex)
copy(out[curPos:], encodedValue)
curPos += len(encodedValue)
encodedTag = encodeVarInt(cipherTextTag)
copy(out[curPos:], encodedTag)
curPos += len(encodedTag)
encodedValue = encodeVarString(r.Ciphertext)
copy(out[curPos:], encodedValue)
curPos += len(encodedValue)
if len(macKey) != 0 && cipher != nil {
mac, err := r.MAC(macKey, cipher, out)
if err != nil {
return nil, err
}
out = append(out, mac[:countMACBytesGroupMessage]...)
}
if signKey != nil {
signature := signKey.Sign(out)
out = append(out, signature...)
}
return out, nil
}
// MAC returns the MAC of the message calculated with cipher and key. The length of the MAC is truncated to the correct length.
func (r *GroupMessage) MAC(key []byte, cipher cipher.Cipher, message []byte) ([]byte, error) {
mac, err := cipher.MAC(key, message)
if err != nil {
return nil, err
}
return mac[:countMACBytesGroupMessage], nil
}
// VerifySignature verifies the givenSignature to the calculated signature of the message.
func (r *GroupMessage) VerifySignature(key crypto.Ed25519PublicKey, message, givenSignature []byte) bool {
return key.Verify(message, givenSignature)
}
// VerifySignature verifies the signature taken from the message to the calculated signature of the message.
func (r *GroupMessage) VerifySignatureInline(key crypto.Ed25519PublicKey, message []byte) bool {
signature := message[len(message)-crypto.ED25519SignatureSize:]
message = message[:len(message)-crypto.ED25519SignatureSize]
return key.Verify(message, signature)
}
// VerifyMAC verifies the givenMAC to the calculated MAC of the message.
func (r *GroupMessage) VerifyMAC(key []byte, cipher cipher.Cipher, message, givenMAC []byte) (bool, error) {
checkMac, err := r.MAC(key, cipher, message)
if err != nil {
return false, err
}
return bytes.Equal(checkMac[:countMACBytesGroupMessage], givenMAC), nil
}
// VerifyMACInline verifies the MAC taken from the message to the calculated MAC of the message.
func (r *GroupMessage) VerifyMACInline(key []byte, cipher cipher.Cipher, message []byte) (bool, error) {
startMAC := len(message) - countMACBytesGroupMessage - crypto.ED25519SignatureSize
endMAC := startMAC + countMACBytesGroupMessage
suplMac := message[startMAC:endMAC]
message = message[:startMAC]
return r.VerifyMAC(key, cipher, message, suplMac)
}

View File

@@ -0,0 +1,129 @@
package message
import (
"bytes"
"maunium.net/go/mautrix/crypto/goolm/cipher"
"maunium.net/go/mautrix/crypto/goolm/crypto"
)
const (
ratchetKeyTag = 0x0A
counterTag = 0x10
cipherTextKeyTag = 0x22
countMACBytesMessage = 8
)
// GroupMessage represents a message in the message format.
type Message struct {
Version byte `json:"version"`
HasCounter bool `json:"has_counter"`
Counter uint32 `json:"counter"`
RatchetKey crypto.Curve25519PublicKey `json:"ratchet_key"`
Ciphertext []byte `json:"ciphertext"`
}
// Decodes decodes the input and populates the corresponding fileds. MAC is ignored but has to be present.
func (r *Message) Decode(input []byte) error {
r.Version = 0
r.HasCounter = false
r.Counter = 0
r.RatchetKey = nil
r.Ciphertext = nil
if len(input) == 0 {
return nil
}
//first Byte is always version
r.Version = input[0]
curPos := 1
for curPos < len(input)-countMACBytesMessage {
//Read Key
curKey, readBytes := decodeVarInt(input[curPos:])
if err := checkDecodeErr(readBytes); err != nil {
return err
}
curPos += readBytes
if (curKey & 0b111) == 0 {
//The value is of type varint
value, readBytes := decodeVarInt(input[curPos:])
if err := checkDecodeErr(readBytes); err != nil {
return err
}
curPos += readBytes
switch curKey {
case counterTag:
r.HasCounter = true
r.Counter = value
}
} else if (curKey & 0b111) == 2 {
//The value is of type string
value, readBytes := decodeVarString(input[curPos:])
if err := checkDecodeErr(readBytes); err != nil {
return err
}
curPos += readBytes
switch curKey {
case ratchetKeyTag:
r.RatchetKey = value
case cipherTextKeyTag:
r.Ciphertext = value
}
}
}
return nil
}
// EncodeAndMAC encodes the message and creates the MAC with the key and the cipher.
// If key or cipher is nil, no MAC is appended.
func (r *Message) EncodeAndMAC(key []byte, cipher cipher.Cipher) ([]byte, error) {
var lengthOfMessage int
lengthOfMessage += 1 //Version
lengthOfMessage += encodeVarIntByteLength(ratchetKeyTag) + encodeVarStringByteLength(r.RatchetKey)
lengthOfMessage += encodeVarIntByteLength(counterTag) + encodeVarIntByteLength(r.Counter)
lengthOfMessage += encodeVarIntByteLength(cipherTextKeyTag) + encodeVarStringByteLength(r.Ciphertext)
out := make([]byte, lengthOfMessage)
out[0] = r.Version
curPos := 1
encodedTag := encodeVarInt(ratchetKeyTag)
copy(out[curPos:], encodedTag)
curPos += len(encodedTag)
encodedValue := encodeVarString(r.RatchetKey)
copy(out[curPos:], encodedValue)
curPos += len(encodedValue)
encodedTag = encodeVarInt(counterTag)
copy(out[curPos:], encodedTag)
curPos += len(encodedTag)
encodedValue = encodeVarInt(r.Counter)
copy(out[curPos:], encodedValue)
curPos += len(encodedValue)
encodedTag = encodeVarInt(cipherTextKeyTag)
copy(out[curPos:], encodedTag)
curPos += len(encodedTag)
encodedValue = encodeVarString(r.Ciphertext)
copy(out[curPos:], encodedValue)
curPos += len(encodedValue)
if len(key) != 0 && cipher != nil {
mac, err := cipher.MAC(key, out)
if err != nil {
return nil, err
}
out = append(out, mac[:countMACBytesMessage]...)
}
return out, nil
}
// VerifyMAC verifies the givenMAC to the calculated MAC of the message.
func (r *Message) VerifyMAC(key []byte, cipher cipher.Cipher, message, givenMAC []byte) (bool, error) {
checkMAC, err := cipher.MAC(key, message)
if err != nil {
return false, err
}
return bytes.Equal(checkMAC[:countMACBytesMessage], givenMAC), nil
}
// VerifyMACInline verifies the MAC taken from the message to the calculated MAC of the message.
func (r *Message) VerifyMACInline(key []byte, cipher cipher.Cipher, message []byte) (bool, error) {
givenMAC := message[len(message)-countMACBytesMessage:]
return r.VerifyMAC(key, cipher, message[:len(message)-countMACBytesMessage], givenMAC)
}

View File

@@ -0,0 +1,120 @@
package message
import (
"maunium.net/go/mautrix/crypto/goolm/crypto"
)
const (
oneTimeKeyIdTag = 0x0A
baseKeyTag = 0x12
identityKeyTag = 0x1A
messageTag = 0x22
)
type PreKeyMessage struct {
Version byte `json:"version"`
IdentityKey crypto.Curve25519PublicKey `json:"id_key"`
BaseKey crypto.Curve25519PublicKey `json:"base_key"`
OneTimeKey crypto.Curve25519PublicKey `json:"one_time_key"`
Message []byte `json:"message"`
}
// Decodes decodes the input and populates the corresponding fileds.
func (r *PreKeyMessage) Decode(input []byte) error {
r.Version = 0
r.IdentityKey = nil
r.BaseKey = nil
r.OneTimeKey = nil
r.Message = nil
if len(input) == 0 {
return nil
}
//first Byte is always version
r.Version = input[0]
curPos := 1
for curPos < len(input) {
//Read Key
curKey, readBytes := decodeVarInt(input[curPos:])
if err := checkDecodeErr(readBytes); err != nil {
return err
}
curPos += readBytes
if (curKey & 0b111) == 0 {
//The value is of type varint
_, readBytes := decodeVarInt(input[curPos:])
if err := checkDecodeErr(readBytes); err != nil {
return err
}
curPos += readBytes
} else if (curKey & 0b111) == 2 {
//The value is of type string
value, readBytes := decodeVarString(input[curPos:])
if err := checkDecodeErr(readBytes); err != nil {
return err
}
curPos += readBytes
switch curKey {
case oneTimeKeyIdTag:
r.OneTimeKey = value
case baseKeyTag:
r.BaseKey = value
case identityKeyTag:
r.IdentityKey = value
case messageTag:
r.Message = value
}
}
}
return nil
}
// CheckField verifies the fields. If theirIdentityKey is nil, it is not compared to the key in the message.
func (r *PreKeyMessage) CheckFields(theirIdentityKey *crypto.Curve25519PublicKey) bool {
ok := true
ok = ok && (theirIdentityKey != nil || r.IdentityKey != nil)
if r.IdentityKey != nil {
ok = ok && (len(r.IdentityKey) == crypto.Curve25519KeyLength)
}
ok = ok && len(r.Message) != 0
ok = ok && len(r.BaseKey) == crypto.Curve25519KeyLength
ok = ok && len(r.OneTimeKey) == crypto.Curve25519KeyLength
return ok
}
// Encode encodes the message.
func (r *PreKeyMessage) Encode() ([]byte, error) {
var lengthOfMessage int
lengthOfMessage += 1 //Version
lengthOfMessage += encodeVarIntByteLength(oneTimeKeyIdTag) + encodeVarStringByteLength(r.OneTimeKey)
lengthOfMessage += encodeVarIntByteLength(identityKeyTag) + encodeVarStringByteLength(r.IdentityKey)
lengthOfMessage += encodeVarIntByteLength(baseKeyTag) + encodeVarStringByteLength(r.BaseKey)
lengthOfMessage += encodeVarIntByteLength(messageTag) + encodeVarStringByteLength(r.Message)
out := make([]byte, lengthOfMessage)
out[0] = r.Version
curPos := 1
encodedTag := encodeVarInt(oneTimeKeyIdTag)
copy(out[curPos:], encodedTag)
curPos += len(encodedTag)
encodedValue := encodeVarString(r.OneTimeKey)
copy(out[curPos:], encodedValue)
curPos += len(encodedValue)
encodedTag = encodeVarInt(identityKeyTag)
copy(out[curPos:], encodedTag)
curPos += len(encodedTag)
encodedValue = encodeVarString(r.IdentityKey)
copy(out[curPos:], encodedValue)
curPos += len(encodedValue)
encodedTag = encodeVarInt(baseKeyTag)
copy(out[curPos:], encodedTag)
curPos += len(encodedTag)
encodedValue = encodeVarString(r.BaseKey)
copy(out[curPos:], encodedValue)
curPos += len(encodedValue)
encodedTag = encodeVarInt(messageTag)
copy(out[curPos:], encodedTag)
curPos += len(encodedTag)
encodedValue = encodeVarString(r.Message)
copy(out[curPos:], encodedValue)
return out, nil
}

View File

@@ -0,0 +1,44 @@
package message
import (
"encoding/binary"
"fmt"
"maunium.net/go/mautrix/crypto/goolm"
"maunium.net/go/mautrix/crypto/goolm/crypto"
)
const (
sessionExportVersion = 0x01
)
// MegolmSessionExport represents a message in the session export format.
type MegolmSessionExport struct {
Counter uint32 `json:"counter"`
RatchetData [128]byte `json:"data"`
PublicKey crypto.Ed25519PublicKey `json:"public_key"`
}
// Encode returns the encoded message in the correct format.
func (s MegolmSessionExport) Encode() []byte {
output := make([]byte, 165)
output[0] = sessionExportVersion
binary.BigEndian.PutUint32(output[1:], s.Counter)
copy(output[5:], s.RatchetData[:])
copy(output[133:], s.PublicKey)
return output
}
// Decode populates the struct with the data encoded in input.
func (s *MegolmSessionExport) Decode(input []byte) error {
if len(input) != 165 {
return fmt.Errorf("decrypt: %w", goolm.ErrBadInput)
}
if input[0] != sessionExportVersion {
return fmt.Errorf("decrypt: %w", goolm.ErrBadVersion)
}
s.Counter = binary.BigEndian.Uint32(input[1:5])
copy(s.RatchetData[:], input[5:133])
s.PublicKey = input[133:]
return nil
}

View File

@@ -0,0 +1,50 @@
package message
import (
"encoding/binary"
"fmt"
"maunium.net/go/mautrix/crypto/goolm"
"maunium.net/go/mautrix/crypto/goolm/crypto"
)
const (
sessionSharingVersion = 0x02
)
// MegolmSessionSharing represents a message in the session sharing format.
type MegolmSessionSharing struct {
Counter uint32 `json:"counter"`
RatchetData [128]byte `json:"data"`
PublicKey crypto.Ed25519PublicKey `json:"-"` //only used when decrypting messages
}
// Encode returns the encoded message in the correct format with the signature by key appended.
func (s MegolmSessionSharing) EncodeAndSign(key crypto.Ed25519KeyPair) []byte {
output := make([]byte, 229)
output[0] = sessionSharingVersion
binary.BigEndian.PutUint32(output[1:], s.Counter)
copy(output[5:], s.RatchetData[:])
copy(output[133:], key.PublicKey)
signature := key.Sign(output[:165])
copy(output[165:], signature)
return output
}
// VerifyAndDecode verifies the input and populates the struct with the data encoded in input.
func (s *MegolmSessionSharing) VerifyAndDecode(input []byte) error {
if len(input) != 229 {
return fmt.Errorf("verify: %w", goolm.ErrBadInput)
}
publicKey := crypto.Ed25519PublicKey(input[133:165])
if !publicKey.Verify(input[:165], input[165:]) {
return fmt.Errorf("verify: %w", goolm.ErrBadVerification)
}
s.PublicKey = publicKey
if input[0] != sessionSharingVersion {
return fmt.Errorf("verify: %w", goolm.ErrBadVersion)
}
s.Counter = binary.BigEndian.Uint32(input[1:5])
copy(s.RatchetData[:], input[5:133])
return nil
}

258
vendor/maunium.net/go/mautrix/crypto/goolm/olm/chain.go generated vendored Normal file
View File

@@ -0,0 +1,258 @@
package olm
import (
"fmt"
"maunium.net/go/mautrix/crypto/goolm"
"maunium.net/go/mautrix/crypto/goolm/crypto"
"maunium.net/go/mautrix/crypto/goolm/libolmpickle"
)
const (
chainKeySeed = 0x02
messageKeyLength = 32
)
// chainKey wraps the index and the public key
type chainKey struct {
Index uint32 `json:"index"`
Key crypto.Curve25519PublicKey `json:"key"`
}
// advance advances the chain
func (c *chainKey) advance() {
c.Key = crypto.HMACSHA256(c.Key, []byte{chainKeySeed})
c.Index++
}
// UnpickleLibOlm decodes the unencryted value and populates the chain key accordingly. It returns the number of bytes read.
func (r *chainKey) UnpickleLibOlm(value []byte) (int, error) {
curPos := 0
readBytes, err := r.Key.UnpickleLibOlm(value)
if err != nil {
return 0, err
}
curPos += readBytes
r.Index, readBytes, err = libolmpickle.UnpickleUInt32(value[curPos:])
if err != nil {
return 0, err
}
curPos += readBytes
return curPos, nil
}
// PickleLibOlm encodes the chain key into target. target has to have a size of at least PickleLen() and is written to from index 0.
// It returns the number of bytes written.
func (r chainKey) PickleLibOlm(target []byte) (int, error) {
if len(target) < r.PickleLen() {
return 0, fmt.Errorf("pickle chain key: %w", goolm.ErrValueTooShort)
}
written, err := r.Key.PickleLibOlm(target)
if err != nil {
return 0, fmt.Errorf("pickle chain key: %w", err)
}
written += libolmpickle.PickleUInt32(r.Index, target[written:])
return written, nil
}
// PickleLen returns the number of bytes the pickled chain key will have.
func (r chainKey) PickleLen() int {
length := r.Key.PickleLen()
length += libolmpickle.PickleUInt32Len(r.Index)
return length
}
// senderChain is a chain for sending messages
type senderChain struct {
RKey crypto.Curve25519KeyPair `json:"ratchet_key"`
CKey chainKey `json:"chain_key"`
IsSet bool `json:"set"`
}
// newSenderChain returns a sender chain initialized with chainKey and ratchet key pair.
func newSenderChain(key crypto.Curve25519PublicKey, ratchet crypto.Curve25519KeyPair) *senderChain {
return &senderChain{
RKey: ratchet,
CKey: chainKey{
Index: 0,
Key: key,
},
IsSet: true,
}
}
// advance advances the chain
func (s *senderChain) advance() {
s.CKey.advance()
}
// ratchetKey returns the ratchet key pair.
func (s senderChain) ratchetKey() crypto.Curve25519KeyPair {
return s.RKey
}
// chainKey returns the current chainKey.
func (s senderChain) chainKey() chainKey {
return s.CKey
}
// UnpickleLibOlm decodes the unencryted value and populates the chain accordingly. It returns the number of bytes read.
func (r *senderChain) UnpickleLibOlm(value []byte) (int, error) {
curPos := 0
readBytes, err := r.RKey.UnpickleLibOlm(value)
if err != nil {
return 0, err
}
curPos += readBytes
readBytes, err = r.CKey.UnpickleLibOlm(value[curPos:])
if err != nil {
return 0, err
}
curPos += readBytes
return curPos, nil
}
// PickleLibOlm encodes the chain into target. target has to have a size of at least PickleLen() and is written to from index 0.
// It returns the number of bytes written.
func (r senderChain) PickleLibOlm(target []byte) (int, error) {
if len(target) < r.PickleLen() {
return 0, fmt.Errorf("pickle sender chain: %w", goolm.ErrValueTooShort)
}
written, err := r.RKey.PickleLibOlm(target)
if err != nil {
return 0, fmt.Errorf("pickle sender chain: %w", err)
}
writtenChain, err := r.CKey.PickleLibOlm(target[written:])
if err != nil {
return 0, fmt.Errorf("pickle sender chain: %w", err)
}
written += writtenChain
return written, nil
}
// PickleLen returns the number of bytes the pickled chain will have.
func (r senderChain) PickleLen() int {
length := r.RKey.PickleLen()
length += r.CKey.PickleLen()
return length
}
// senderChain is a chain for receiving messages
type receiverChain struct {
RKey crypto.Curve25519PublicKey `json:"ratchet_key"`
CKey chainKey `json:"chain_key"`
}
// newReceiverChain returns a receiver chain initialized with chainKey and ratchet public key.
func newReceiverChain(chain crypto.Curve25519PublicKey, ratchet crypto.Curve25519PublicKey) *receiverChain {
return &receiverChain{
RKey: ratchet,
CKey: chainKey{
Index: 0,
Key: chain,
},
}
}
// advance advances the chain
func (s *receiverChain) advance() {
s.CKey.advance()
}
// ratchetKey returns the ratchet public key.
func (s receiverChain) ratchetKey() crypto.Curve25519PublicKey {
return s.RKey
}
// chainKey returns the current chainKey.
func (s receiverChain) chainKey() chainKey {
return s.CKey
}
// UnpickleLibOlm decodes the unencryted value and populates the chain accordingly. It returns the number of bytes read.
func (r *receiverChain) UnpickleLibOlm(value []byte) (int, error) {
curPos := 0
readBytes, err := r.RKey.UnpickleLibOlm(value)
if err != nil {
return 0, err
}
curPos += readBytes
readBytes, err = r.CKey.UnpickleLibOlm(value[curPos:])
if err != nil {
return 0, err
}
curPos += readBytes
return curPos, nil
}
// PickleLibOlm encodes the chain into target. target has to have a size of at least PickleLen() and is written to from index 0.
// It returns the number of bytes written.
func (r receiverChain) PickleLibOlm(target []byte) (int, error) {
if len(target) < r.PickleLen() {
return 0, fmt.Errorf("pickle sender chain: %w", goolm.ErrValueTooShort)
}
written, err := r.RKey.PickleLibOlm(target)
if err != nil {
return 0, fmt.Errorf("pickle sender chain: %w", err)
}
writtenChain, err := r.CKey.PickleLibOlm(target)
if err != nil {
return 0, fmt.Errorf("pickle sender chain: %w", err)
}
written += writtenChain
return written, nil
}
// PickleLen returns the number of bytes the pickled chain will have.
func (r receiverChain) PickleLen() int {
length := r.RKey.PickleLen()
length += r.CKey.PickleLen()
return length
}
// messageKey wraps the index and the key of a message
type messageKey struct {
Index uint32 `json:"index"`
Key []byte `json:"key"`
}
// UnpickleLibOlm decodes the unencryted value and populates the message key accordingly. It returns the number of bytes read.
func (m *messageKey) UnpickleLibOlm(value []byte) (int, error) {
curPos := 0
ratchetKey, readBytes, err := libolmpickle.UnpickleBytes(value, messageKeyLength)
if err != nil {
return 0, err
}
m.Key = ratchetKey
curPos += readBytes
keyID, readBytes, err := libolmpickle.UnpickleUInt32(value[:curPos])
if err != nil {
return 0, err
}
curPos += readBytes
m.Index = keyID
return curPos, nil
}
// PickleLibOlm encodes the message key into target. target has to have a size of at least PickleLen() and is written to from index 0.
// It returns the number of bytes written.
func (m messageKey) PickleLibOlm(target []byte) (int, error) {
if len(target) < m.PickleLen() {
return 0, fmt.Errorf("pickle message key: %w", goolm.ErrValueTooShort)
}
written := 0
if len(m.Key) != messageKeyLength {
written += libolmpickle.PickleBytes(make([]byte, messageKeyLength), target)
} else {
written += libolmpickle.PickleBytes(m.Key, target)
}
written += libolmpickle.PickleUInt32(m.Index, target[written:])
return written, nil
}
// PickleLen returns the number of bytes the pickled message key will have.
func (r messageKey) PickleLen() int {
length := libolmpickle.PickleBytesLen(make([]byte, messageKeyLength))
length += libolmpickle.PickleUInt32Len(r.Index)
return length
}

432
vendor/maunium.net/go/mautrix/crypto/goolm/olm/olm.go generated vendored Normal file
View File

@@ -0,0 +1,432 @@
// olm provides the ratchet used by the olm protocol
package olm
import (
"fmt"
"io"
"maunium.net/go/mautrix/crypto/goolm"
"maunium.net/go/mautrix/crypto/goolm/cipher"
"maunium.net/go/mautrix/crypto/goolm/crypto"
"maunium.net/go/mautrix/crypto/goolm/libolmpickle"
"maunium.net/go/mautrix/crypto/goolm/message"
"maunium.net/go/mautrix/crypto/goolm/utilities"
)
const (
olmPickleVersion uint8 = 1
)
const (
maxReceiverChains = 5
maxSkippedMessageKeys = 40
protocolVersion = 3
messageKeySeed = 0x01
maxMessageGap = 2000
sharedKeyLength = 32
)
// KdfInfo has the infos used for the kdf
var KdfInfo = struct {
Root []byte
Ratchet []byte
}{
Root: []byte("OLM_ROOT"),
Ratchet: []byte("OLM_RATCHET"),
}
var RatchetCipher = cipher.NewAESSHA256([]byte("OLM_KEYS"))
// Ratchet represents the olm ratchet as described in
//
// https://gitlab.matrix.org/matrix-org/olm/-/blob/master/docs/olm.md
type Ratchet struct {
// The root key is used to generate chain keys from the ephemeral keys.
// A new root_key is derived each time a new chain is started.
RootKey crypto.Curve25519PublicKey `json:"root_key"`
// The sender chain is used to send messages. Each time a new ephemeral
// key is received from the remote server we generate a new sender chain
// with a new ephemeral key when we next send a message.
SenderChains senderChain `json:"sender_chain"`
// The receiver chain is used to decrypt received messages. We store the
// last few chains so we can decrypt any out of order messages we haven't
// received yet.
// New chains are prepended for easier access.
ReceiverChains []receiverChain `json:"receiver_chains"`
// Storing the keys of missed messages for future use.
// The order of the elements is not important.
SkippedMessageKeys []skippedMessageKey `json:"skipped_message_keys"`
}
// New creates a new ratchet, setting the kdfInfos and cipher.
func New() *Ratchet {
r := &Ratchet{}
return r
}
// InitializeAsBob initializes this ratchet from a receiving point of view (only first message).
func (r *Ratchet) InitializeAsBob(sharedSecret []byte, theirRatchetKey crypto.Curve25519PublicKey) error {
derivedSecretsReader := crypto.HKDFSHA256(sharedSecret, nil, KdfInfo.Root)
derivedSecrets := make([]byte, 2*sharedKeyLength)
if _, err := io.ReadFull(derivedSecretsReader, derivedSecrets); err != nil {
return err
}
r.RootKey = derivedSecrets[0:sharedKeyLength]
newReceiverChain := newReceiverChain(derivedSecrets[sharedKeyLength:], theirRatchetKey)
r.ReceiverChains = append([]receiverChain{*newReceiverChain}, r.ReceiverChains...)
return nil
}
// InitializeAsAlice initializes this ratchet from a sending point of view (only first message).
func (r *Ratchet) InitializeAsAlice(sharedSecret []byte, ourRatchetKey crypto.Curve25519KeyPair) error {
derivedSecretsReader := crypto.HKDFSHA256(sharedSecret, nil, KdfInfo.Root)
derivedSecrets := make([]byte, 2*sharedKeyLength)
if _, err := io.ReadFull(derivedSecretsReader, derivedSecrets); err != nil {
return err
}
r.RootKey = derivedSecrets[0:sharedKeyLength]
newSenderChain := newSenderChain(derivedSecrets[sharedKeyLength:], ourRatchetKey)
r.SenderChains = *newSenderChain
return nil
}
// Encrypt encrypts the message in a message.Message with MAC. If reader is nil, crypto/rand is used for key generations.
func (r *Ratchet) Encrypt(plaintext []byte, reader io.Reader) ([]byte, error) {
var err error
if !r.SenderChains.IsSet {
newRatchetKey, err := crypto.Curve25519GenerateKey(reader)
if err != nil {
return nil, err
}
newChainKey, err := r.advanceRootKey(newRatchetKey, r.ReceiverChains[0].ratchetKey())
if err != nil {
return nil, err
}
newSenderChain := newSenderChain(newChainKey, newRatchetKey)
r.SenderChains = *newSenderChain
}
messageKey := r.createMessageKeys(r.SenderChains.chainKey())
r.SenderChains.advance()
encryptedText, err := RatchetCipher.Encrypt(messageKey.Key, plaintext)
if err != nil {
return nil, fmt.Errorf("cipher encrypt: %w", err)
}
message := &message.Message{}
message.Version = protocolVersion
message.Counter = messageKey.Index
message.RatchetKey = r.SenderChains.ratchetKey().PublicKey
message.Ciphertext = encryptedText
//creating the mac is done in encode
output, err := message.EncodeAndMAC(messageKey.Key, RatchetCipher)
if err != nil {
return nil, err
}
return output, nil
}
// Decrypt decrypts the ciphertext and verifies the MAC. If reader is nil, crypto/rand is used for key generations.
func (r *Ratchet) Decrypt(input []byte) ([]byte, error) {
message := &message.Message{}
//The mac is not verified here, as we do not know the key yet
err := message.Decode(input)
if err != nil {
return nil, err
}
if message.Version != protocolVersion {
return nil, fmt.Errorf("decrypt: %w", goolm.ErrWrongProtocolVersion)
}
if !message.HasCounter || len(message.RatchetKey) == 0 || len(message.Ciphertext) == 0 {
return nil, fmt.Errorf("decrypt: %w", goolm.ErrBadMessageFormat)
}
var receiverChainFromMessage *receiverChain
for curChainIndex := range r.ReceiverChains {
if r.ReceiverChains[curChainIndex].ratchetKey().Equal(message.RatchetKey) {
receiverChainFromMessage = &r.ReceiverChains[curChainIndex]
break
}
}
var result []byte
if receiverChainFromMessage == nil {
//Advancing the chain is done in this method
result, err = r.decryptForNewChain(message, input)
if err != nil {
return nil, err
}
} else if receiverChainFromMessage.chainKey().Index > message.Counter {
// No need to advance the chain
// Chain already advanced beyond the key for this message
// Check if the message keys are in the skipped key list.
foundSkippedKey := false
for curSkippedIndex := range r.SkippedMessageKeys {
if message.Counter == r.SkippedMessageKeys[curSkippedIndex].MKey.Index {
// Found the key for this message. Check the MAC.
verified, err := message.VerifyMACInline(r.SkippedMessageKeys[curSkippedIndex].MKey.Key, RatchetCipher, input)
if err != nil {
return nil, err
}
if !verified {
return nil, fmt.Errorf("decrypt from skipped message keys: %w", goolm.ErrBadMAC)
}
result, err = RatchetCipher.Decrypt(r.SkippedMessageKeys[curSkippedIndex].MKey.Key, message.Ciphertext)
if err != nil {
return nil, fmt.Errorf("cipher decrypt: %w", err)
}
if len(result) != 0 {
// Remove the key from the skipped keys now that we've
// decoded the message it corresponds to.
r.SkippedMessageKeys[curSkippedIndex] = r.SkippedMessageKeys[len(r.SkippedMessageKeys)-1]
r.SkippedMessageKeys = r.SkippedMessageKeys[:len(r.SkippedMessageKeys)-1]
}
foundSkippedKey = true
}
}
if !foundSkippedKey {
return nil, fmt.Errorf("decrypt: %w", goolm.ErrMessageKeyNotFound)
}
} else {
//Advancing the chain is done in this method
result, err = r.decryptForExistingChain(receiverChainFromMessage, message, input)
if err != nil {
return nil, err
}
}
return result, nil
}
// advanceRootKey created the next root key and returns the next chainKey
func (r *Ratchet) advanceRootKey(newRatchetKey crypto.Curve25519KeyPair, oldRatchetKey crypto.Curve25519PublicKey) (crypto.Curve25519PublicKey, error) {
sharedSecret, err := newRatchetKey.SharedSecret(oldRatchetKey)
if err != nil {
return nil, err
}
derivedSecretsReader := crypto.HKDFSHA256(sharedSecret, r.RootKey, KdfInfo.Ratchet)
derivedSecrets := make([]byte, 2*sharedKeyLength)
if _, err := io.ReadFull(derivedSecretsReader, derivedSecrets); err != nil {
return nil, err
}
r.RootKey = derivedSecrets[:sharedKeyLength]
return derivedSecrets[sharedKeyLength:], nil
}
// createMessageKeys returns the messageKey derived from the chainKey
func (r Ratchet) createMessageKeys(chainKey chainKey) messageKey {
res := messageKey{}
res.Key = crypto.HMACSHA256(chainKey.Key, []byte{messageKeySeed})
res.Index = chainKey.Index
return res
}
// decryptForExistingChain returns the decrypted message by using the chain. The MAC of the rawMessage is verified.
func (r *Ratchet) decryptForExistingChain(chain *receiverChain, message *message.Message, rawMessage []byte) ([]byte, error) {
if message.Counter < chain.CKey.Index {
return nil, fmt.Errorf("decrypt: %w", goolm.ErrChainTooHigh)
}
// Limit the number of hashes we're prepared to compute
if message.Counter-chain.CKey.Index > maxMessageGap {
return nil, fmt.Errorf("decrypt from existing chain: %w", goolm.ErrMsgIndexTooHigh)
}
for chain.CKey.Index < message.Counter {
messageKey := r.createMessageKeys(chain.chainKey())
skippedKey := skippedMessageKey{
MKey: messageKey,
RKey: chain.ratchetKey(),
}
r.SkippedMessageKeys = append(r.SkippedMessageKeys, skippedKey)
chain.advance()
}
messageKey := r.createMessageKeys(chain.chainKey())
chain.advance()
verified, err := message.VerifyMACInline(messageKey.Key, RatchetCipher, rawMessage)
if err != nil {
return nil, err
}
if !verified {
return nil, fmt.Errorf("decrypt from existing chain: %w", goolm.ErrBadMAC)
}
return RatchetCipher.Decrypt(messageKey.Key, message.Ciphertext)
}
// decryptForNewChain returns the decrypted message by creating a new chain and advancing the root key.
func (r *Ratchet) decryptForNewChain(message *message.Message, rawMessage []byte) ([]byte, error) {
// They shouldn't move to a new chain until we've sent them a message
// acknowledging the last one
if !r.SenderChains.IsSet {
return nil, fmt.Errorf("decrypt for new chain: %w", goolm.ErrProtocolViolation)
}
// Limit the number of hashes we're prepared to compute
if message.Counter > maxMessageGap {
return nil, fmt.Errorf("decrypt for new chain: %w", goolm.ErrMsgIndexTooHigh)
}
newChainKey, err := r.advanceRootKey(r.SenderChains.ratchetKey(), message.RatchetKey)
if err != nil {
return nil, err
}
newChain := newReceiverChain(newChainKey, message.RatchetKey)
r.ReceiverChains = append([]receiverChain{*newChain}, r.ReceiverChains...)
/*
They have started using a new ephemeral ratchet key.
We needed to derive a new set of chain keys.
We can discard our previous ephemeral ratchet key.
We will generate a new key when we send the next message.
*/
r.SenderChains = senderChain{}
decrypted, err := r.decryptForExistingChain(&r.ReceiverChains[0], message, rawMessage)
if err != nil {
return nil, err
}
return decrypted, nil
}
// PickleAsJSON returns a ratchet as a base64 string encrypted using the supplied key. The unencrypted representation of the Account is in JSON format.
func (r Ratchet) PickleAsJSON(key []byte) ([]byte, error) {
return utilities.PickleAsJSON(r, olmPickleVersion, key)
}
// UnpickleAsJSON updates a ratchet by a base64 encrypted string using the supplied key. The unencrypted representation has to be in JSON format.
func (r *Ratchet) UnpickleAsJSON(pickled, key []byte) error {
return utilities.UnpickleAsJSON(r, pickled, key, olmPickleVersion)
}
// UnpickleLibOlm decodes the unencryted value and populates the Ratchet accordingly. It returns the number of bytes read.
func (r *Ratchet) UnpickleLibOlm(value []byte, includesChainIndex bool) (int, error) {
//read ratchet data
curPos := 0
readBytes, err := r.RootKey.UnpickleLibOlm(value)
if err != nil {
return 0, err
}
curPos += readBytes
countSenderChains, readBytes, err := libolmpickle.UnpickleUInt32(value[curPos:]) //Length of sender chain
if err != nil {
return 0, err
}
curPos += readBytes
for i := uint32(0); i < countSenderChains; i++ {
if i == 0 {
//only first is stored
readBytes, err := r.SenderChains.UnpickleLibOlm(value[curPos:])
if err != nil {
return 0, err
}
curPos += readBytes
r.SenderChains.IsSet = true
} else {
dummy := senderChain{}
readBytes, err := dummy.UnpickleLibOlm(value[curPos:])
if err != nil {
return 0, err
}
curPos += readBytes
}
}
countReceivChains, readBytes, err := libolmpickle.UnpickleUInt32(value[curPos:]) //Length of recevier chain
if err != nil {
return 0, err
}
curPos += readBytes
r.ReceiverChains = make([]receiverChain, countReceivChains)
for i := uint32(0); i < countReceivChains; i++ {
readBytes, err := r.ReceiverChains[i].UnpickleLibOlm(value[curPos:])
if err != nil {
return 0, err
}
curPos += readBytes
}
countSkippedMessageKeys, readBytes, err := libolmpickle.UnpickleUInt32(value[curPos:]) //Length of skippedMessageKeys
if err != nil {
return 0, err
}
curPos += readBytes
r.SkippedMessageKeys = make([]skippedMessageKey, countSkippedMessageKeys)
for i := uint32(0); i < countSkippedMessageKeys; i++ {
readBytes, err := r.SkippedMessageKeys[i].UnpickleLibOlm(value[curPos:])
if err != nil {
return 0, err
}
curPos += readBytes
}
// pickle v 0x80000001 includes a chain index; pickle v1 does not.
if includesChainIndex {
_, readBytes, err := libolmpickle.UnpickleUInt32(value[curPos:])
if err != nil {
return 0, err
}
curPos += readBytes
}
return curPos, nil
}
// PickleLibOlm encodes the ratchet into target. target has to have a size of at least PickleLen() and is written to from index 0.
// It returns the number of bytes written.
func (r Ratchet) PickleLibOlm(target []byte) (int, error) {
if len(target) < r.PickleLen() {
return 0, fmt.Errorf("pickle ratchet: %w", goolm.ErrValueTooShort)
}
written, err := r.RootKey.PickleLibOlm(target)
if err != nil {
return 0, fmt.Errorf("pickle ratchet: %w", err)
}
if r.SenderChains.IsSet {
written += libolmpickle.PickleUInt32(1, target[written:]) //Length of sender chain, always 1
writtenSender, err := r.SenderChains.PickleLibOlm(target[written:])
if err != nil {
return 0, fmt.Errorf("pickle ratchet: %w", err)
}
written += writtenSender
} else {
written += libolmpickle.PickleUInt32(0, target[written:]) //Length of sender chain
}
written += libolmpickle.PickleUInt32(uint32(len(r.ReceiverChains)), target[written:])
for _, curChain := range r.ReceiverChains {
writtenChain, err := curChain.PickleLibOlm(target[written:])
if err != nil {
return 0, fmt.Errorf("pickle ratchet: %w", err)
}
written += writtenChain
}
written += libolmpickle.PickleUInt32(uint32(len(r.SkippedMessageKeys)), target[written:])
for _, curChain := range r.SkippedMessageKeys {
writtenChain, err := curChain.PickleLibOlm(target[written:])
if err != nil {
return 0, fmt.Errorf("pickle ratchet: %w", err)
}
written += writtenChain
}
return written, nil
}
// PickleLen returns the actual number of bytes the pickled ratchet will have.
func (r Ratchet) PickleLen() int {
length := r.RootKey.PickleLen()
if r.SenderChains.IsSet {
length += libolmpickle.PickleUInt32Len(1)
length += r.SenderChains.PickleLen()
} else {
length += libolmpickle.PickleUInt32Len(0)
}
length += libolmpickle.PickleUInt32Len(uint32(len(r.ReceiverChains)))
length += len(r.ReceiverChains) * receiverChain{}.PickleLen()
length += libolmpickle.PickleUInt32Len(uint32(len(r.SkippedMessageKeys)))
length += len(r.SkippedMessageKeys) * skippedMessageKey{}.PickleLen()
return length
}
// PickleLen returns the minimum number of bytes the pickled ratchet must have.
func (r Ratchet) PickleLenMin() int {
length := r.RootKey.PickleLen()
length += libolmpickle.PickleUInt32Len(0)
length += libolmpickle.PickleUInt32Len(0)
length += libolmpickle.PickleUInt32Len(0)
return length
}

View File

@@ -0,0 +1,55 @@
package olm
import (
"fmt"
"maunium.net/go/mautrix/crypto/goolm"
"maunium.net/go/mautrix/crypto/goolm/crypto"
)
// skippedMessageKey stores a skipped message key
type skippedMessageKey struct {
RKey crypto.Curve25519PublicKey `json:"ratchet_key"`
MKey messageKey `json:"message_key"`
}
// UnpickleLibOlm decodes the unencryted value and populates the chain accordingly. It returns the number of bytes read.
func (r *skippedMessageKey) UnpickleLibOlm(value []byte) (int, error) {
curPos := 0
readBytes, err := r.RKey.UnpickleLibOlm(value)
if err != nil {
return 0, err
}
curPos += readBytes
readBytes, err = r.MKey.UnpickleLibOlm(value[curPos:])
if err != nil {
return 0, err
}
curPos += readBytes
return curPos, nil
}
// PickleLibOlm encodes the chain into target. target has to have a size of at least PickleLen() and is written to from index 0.
// It returns the number of bytes written.
func (r skippedMessageKey) PickleLibOlm(target []byte) (int, error) {
if len(target) < r.PickleLen() {
return 0, fmt.Errorf("pickle sender chain: %w", goolm.ErrValueTooShort)
}
written, err := r.RKey.PickleLibOlm(target)
if err != nil {
return 0, fmt.Errorf("pickle sender chain: %w", err)
}
writtenChain, err := r.MKey.PickleLibOlm(target)
if err != nil {
return 0, fmt.Errorf("pickle sender chain: %w", err)
}
written += writtenChain
return written, nil
}
// PickleLen returns the number of bytes the pickled chain will have.
func (r skippedMessageKey) PickleLen() int {
length := r.RKey.PickleLen()
length += r.MKey.PickleLen()
return length
}

View File

@@ -0,0 +1,165 @@
package pk
import (
"encoding/base64"
"errors"
"fmt"
"maunium.net/go/mautrix/crypto/goolm"
"maunium.net/go/mautrix/crypto/goolm/cipher"
"maunium.net/go/mautrix/crypto/goolm/crypto"
"maunium.net/go/mautrix/crypto/goolm/libolmpickle"
"maunium.net/go/mautrix/crypto/goolm/utilities"
"maunium.net/go/mautrix/id"
)
const (
decryptionPickleVersionJSON uint8 = 1
decryptionPickleVersionLibOlm uint32 = 1
)
// Decryption is used to decrypt pk messages
type Decryption struct {
KeyPair crypto.Curve25519KeyPair `json:"key_pair"`
}
// NewDecryption returns a new Decryption with a new generated key pair.
func NewDecryption() (*Decryption, error) {
keyPair, err := crypto.Curve25519GenerateKey(nil)
if err != nil {
return nil, err
}
return &Decryption{
KeyPair: keyPair,
}, nil
}
// NewDescriptionFromPrivate resturns a new Decryption with the private key fixed.
func NewDecryptionFromPrivate(privateKey crypto.Curve25519PrivateKey) (*Decryption, error) {
s := &Decryption{}
keyPair, err := crypto.Curve25519GenerateFromPrivate(privateKey)
if err != nil {
return nil, err
}
s.KeyPair = keyPair
return s, nil
}
// PubKey returns the public key base 64 encoded.
func (s Decryption) PubKey() id.Curve25519 {
return s.KeyPair.B64Encoded()
}
// PrivateKey returns the private key.
func (s Decryption) PrivateKey() crypto.Curve25519PrivateKey {
return s.KeyPair.PrivateKey
}
// Decrypt decrypts the ciphertext and verifies the MAC. The base64 encoded key is used to construct the shared secret.
func (s Decryption) Decrypt(ciphertext, mac []byte, key id.Curve25519) ([]byte, error) {
keyDecoded, err := base64.RawStdEncoding.DecodeString(string(key))
if err != nil {
return nil, err
}
sharedSecret, err := s.KeyPair.SharedSecret(keyDecoded)
if err != nil {
return nil, err
}
decodedMAC, err := goolm.Base64Decode(mac)
if err != nil {
return nil, err
}
cipher := cipher.NewAESSHA256(nil)
verified, err := cipher.Verify(sharedSecret, ciphertext, decodedMAC)
if err != nil {
return nil, err
}
if !verified {
return nil, fmt.Errorf("decrypt: %w", goolm.ErrBadMAC)
}
plaintext, err := cipher.Decrypt(sharedSecret, ciphertext)
if err != nil {
return nil, err
}
return plaintext, nil
}
// PickleAsJSON returns an Decryption as a base64 string encrypted using the supplied key. The unencrypted representation of the Account is in JSON format.
func (a Decryption) PickleAsJSON(key []byte) ([]byte, error) {
return utilities.PickleAsJSON(a, decryptionPickleVersionJSON, key)
}
// UnpickleAsJSON updates an Decryption by a base64 encrypted string using the supplied key. The unencrypted representation has to be in JSON format.
func (a *Decryption) UnpickleAsJSON(pickled, key []byte) error {
return utilities.UnpickleAsJSON(a, pickled, key, decryptionPickleVersionJSON)
}
// Unpickle decodes the base64 encoded string and decrypts the result with the key.
// The decrypted value is then passed to UnpickleLibOlm.
func (a *Decryption) Unpickle(pickled, key []byte) error {
decrypted, err := cipher.Unpickle(key, pickled)
if err != nil {
return err
}
_, err = a.UnpickleLibOlm(decrypted)
return err
}
// UnpickleLibOlm decodes the unencryted value and populates the Decryption accordingly. It returns the number of bytes read.
func (a *Decryption) UnpickleLibOlm(value []byte) (int, error) {
//First 4 bytes are the accountPickleVersion
pickledVersion, curPos, err := libolmpickle.UnpickleUInt32(value)
if err != nil {
return 0, err
}
switch pickledVersion {
case decryptionPickleVersionLibOlm:
default:
return 0, fmt.Errorf("unpickle olmSession: %w", goolm.ErrBadVersion)
}
readBytes, err := a.KeyPair.UnpickleLibOlm(value[curPos:])
if err != nil {
return 0, err
}
curPos += readBytes
return curPos, nil
}
// Pickle returns a base64 encoded and with key encrypted pickled Decryption using PickleLibOlm().
func (a Decryption) Pickle(key []byte) ([]byte, error) {
pickeledBytes := make([]byte, a.PickleLen())
written, err := a.PickleLibOlm(pickeledBytes)
if err != nil {
return nil, err
}
if written != len(pickeledBytes) {
return nil, errors.New("number of written bytes not correct")
}
encrypted, err := cipher.Pickle(key, pickeledBytes)
if err != nil {
return nil, err
}
return encrypted, nil
}
// PickleLibOlm encodes the Decryption into target. target has to have a size of at least PickleLen() and is written to from index 0.
// It returns the number of bytes written.
func (a Decryption) PickleLibOlm(target []byte) (int, error) {
if len(target) < a.PickleLen() {
return 0, fmt.Errorf("pickle Decryption: %w", goolm.ErrValueTooShort)
}
written := libolmpickle.PickleUInt32(decryptionPickleVersionLibOlm, target)
writtenKey, err := a.KeyPair.PickleLibOlm(target[written:])
if err != nil {
return 0, fmt.Errorf("pickle Decryption: %w", err)
}
written += writtenKey
return written, nil
}
// PickleLen returns the number of bytes the pickled Decryption will have.
func (a Decryption) PickleLen() int {
length := libolmpickle.PickleUInt32Len(decryptionPickleVersionLibOlm)
length += a.KeyPair.PickleLen()
return length
}

View File

@@ -0,0 +1,49 @@
package pk
import (
"encoding/base64"
"maunium.net/go/mautrix/id"
"maunium.net/go/mautrix/crypto/goolm"
"maunium.net/go/mautrix/crypto/goolm/cipher"
"maunium.net/go/mautrix/crypto/goolm/crypto"
)
// Encryption is used to encrypt pk messages
type Encryption struct {
RecipientKey crypto.Curve25519PublicKey `json:"recipient_key"`
}
// NewEncryption returns a new Encryption with the base64 encoded public key of the recipient
func NewEncryption(pubKey id.Curve25519) (*Encryption, error) {
pubKeyDecoded, err := base64.RawStdEncoding.DecodeString(string(pubKey))
if err != nil {
return nil, err
}
return &Encryption{
RecipientKey: pubKeyDecoded,
}, nil
}
// Encrypt encrypts the plaintext with the privateKey and returns the ciphertext and base64 encoded MAC.
func (e Encryption) Encrypt(plaintext []byte, privateKey crypto.Curve25519PrivateKey) (ciphertext, mac []byte, err error) {
keyPair, err := crypto.Curve25519GenerateFromPrivate(privateKey)
if err != nil {
return nil, nil, err
}
sharedSecret, err := keyPair.SharedSecret(e.RecipientKey)
if err != nil {
return nil, nil, err
}
cipher := cipher.NewAESSHA256(nil)
ciphertext, err = cipher.Encrypt(sharedSecret, plaintext)
if err != nil {
return nil, nil, err
}
mac, err = cipher.MAC(sharedSecret, ciphertext)
if err != nil {
return nil, nil, err
}
return ciphertext, goolm.Base64Encode(mac), nil
}

View File

@@ -0,0 +1,44 @@
package pk
import (
"crypto/rand"
"maunium.net/go/mautrix/crypto/goolm"
"maunium.net/go/mautrix/crypto/goolm/crypto"
"maunium.net/go/mautrix/id"
)
// Signing is used for signing a pk
type Signing struct {
KeyPair crypto.Ed25519KeyPair `json:"key_pair"`
Seed []byte `json:"seed"`
}
// NewSigningFromSeed constructs a new Signing based on a seed.
func NewSigningFromSeed(seed []byte) (*Signing, error) {
s := &Signing{}
s.Seed = seed
s.KeyPair = crypto.Ed25519GenerateFromSeed(seed)
return s, nil
}
// NewSigning returns a Signing based on a random seed
func NewSigning() (*Signing, error) {
seed := make([]byte, 32)
_, err := rand.Read(seed)
if err != nil {
return nil, err
}
return NewSigningFromSeed(seed)
}
// Sign returns the signature of the message base64 encoded.
func (s Signing) Sign(message []byte) []byte {
signature := s.KeyPair.Sign(message)
return goolm.Base64Encode(signature)
}
// PublicKey returns the public key of the key pair base 64 encoded.
func (s Signing) PublicKey() id.Ed25519 {
return s.KeyPair.B64Encoded()
}

76
vendor/maunium.net/go/mautrix/crypto/goolm/sas/main.go generated vendored Normal file
View File

@@ -0,0 +1,76 @@
// sas provides the means to do SAS between keys
package sas
import (
"io"
"maunium.net/go/mautrix/crypto/goolm"
"maunium.net/go/mautrix/crypto/goolm/crypto"
)
// SAS contains the key pair and secret for SAS.
type SAS struct {
KeyPair crypto.Curve25519KeyPair
Secret []byte
}
// New creates a new SAS with a new key pair.
func New() (*SAS, error) {
kp, err := crypto.Curve25519GenerateKey(nil)
if err != nil {
return nil, err
}
s := &SAS{
KeyPair: kp,
}
return s, nil
}
// GetPubkey returns the public key of the key pair base64 encoded
func (s SAS) GetPubkey() []byte {
return goolm.Base64Encode(s.KeyPair.PublicKey)
}
// SetTheirKey sets the key of the other party and computes the shared secret.
func (s *SAS) SetTheirKey(key []byte) error {
keyDecoded, err := goolm.Base64Decode(key)
if err != nil {
return err
}
sharedSecret, err := s.KeyPair.SharedSecret(keyDecoded)
if err != nil {
return err
}
s.Secret = sharedSecret
return nil
}
// GenerateBytes creates length bytes from the shared secret and info.
func (s SAS) GenerateBytes(info []byte, length uint) ([]byte, error) {
byteReader := crypto.HKDFSHA256(s.Secret, nil, info)
output := make([]byte, length)
if _, err := io.ReadFull(byteReader, output); err != nil {
return nil, err
}
return output, nil
}
// calculateMAC returns a base64 encoded MAC of input.
func (s *SAS) calculateMAC(input, info []byte, length uint) ([]byte, error) {
key, err := s.GenerateBytes(info, length)
if err != nil {
return nil, err
}
mac := crypto.HMACSHA256(key, input)
return goolm.Base64Encode(mac), nil
}
// CalculateMACFixes returns a base64 encoded, 32 byte long MAC of input.
func (s SAS) CalculateMAC(input, info []byte) ([]byte, error) {
return s.calculateMAC(input, info, 32)
}
// CalculateMACLongKDF returns a base64 encoded, 256 byte long MAC of input.
func (s SAS) CalculateMACLongKDF(input, info []byte) ([]byte, error) {
return s.calculateMAC(input, info, 256)
}

View File

@@ -0,0 +1,2 @@
// session provides the different types of sessions for en/decrypting of messages
package session

View File

@@ -0,0 +1,276 @@
package session
import (
"encoding/base64"
"errors"
"fmt"
"maunium.net/go/mautrix/crypto/goolm"
"maunium.net/go/mautrix/crypto/goolm/cipher"
"maunium.net/go/mautrix/crypto/goolm/crypto"
"maunium.net/go/mautrix/crypto/goolm/libolmpickle"
"maunium.net/go/mautrix/crypto/goolm/megolm"
"maunium.net/go/mautrix/crypto/goolm/message"
"maunium.net/go/mautrix/crypto/goolm/utilities"
"maunium.net/go/mautrix/id"
)
const (
megolmInboundSessionPickleVersionJSON byte = 1
megolmInboundSessionPickleVersionLibOlm uint32 = 2
)
// MegolmInboundSession stores information about the sessions of receive.
type MegolmInboundSession struct {
Ratchet megolm.Ratchet `json:"ratchet"`
SigningKey crypto.Ed25519PublicKey `json:"signing_key"`
InitialRatchet megolm.Ratchet `json:"initial_ratchet"`
SigningKeyVerified bool `json:"signing_key_verified"` //not used for now
}
// NewMegolmInboundSession creates a new MegolmInboundSession from a base64 encoded session sharing message.
func NewMegolmInboundSession(input []byte) (*MegolmInboundSession, error) {
var err error
input, err = goolm.Base64Decode(input)
if err != nil {
return nil, err
}
msg := message.MegolmSessionSharing{}
err = msg.VerifyAndDecode(input)
if err != nil {
return nil, err
}
o := &MegolmInboundSession{}
o.SigningKey = msg.PublicKey
o.SigningKeyVerified = true
ratchet, err := megolm.New(msg.Counter, msg.RatchetData)
if err != nil {
return nil, err
}
o.Ratchet = *ratchet
o.InitialRatchet = *ratchet
return o, nil
}
// NewMegolmInboundSessionFromExport creates a new MegolmInboundSession from a base64 encoded session export message.
func NewMegolmInboundSessionFromExport(input []byte) (*MegolmInboundSession, error) {
var err error
input, err = goolm.Base64Decode(input)
if err != nil {
return nil, err
}
msg := message.MegolmSessionExport{}
err = msg.Decode(input)
if err != nil {
return nil, err
}
o := &MegolmInboundSession{}
o.SigningKey = msg.PublicKey
ratchet, err := megolm.New(msg.Counter, msg.RatchetData)
if err != nil {
return nil, err
}
o.Ratchet = *ratchet
o.InitialRatchet = *ratchet
return o, nil
}
// MegolmInboundSessionFromPickled loads the MegolmInboundSession details from a pickled base64 string. The input is decrypted with the supplied key.
func MegolmInboundSessionFromPickled(pickled, key []byte) (*MegolmInboundSession, error) {
if len(pickled) == 0 {
return nil, fmt.Errorf("megolmInboundSessionFromPickled: %w", goolm.ErrEmptyInput)
}
a := &MegolmInboundSession{}
err := a.Unpickle(pickled, key)
if err != nil {
return nil, err
}
return a, nil
}
// getRatchet tries to find the correct ratchet for a messageIndex.
func (o MegolmInboundSession) getRatchet(messageIndex uint32) (*megolm.Ratchet, error) {
// pick a megolm instance to use. if we are at or beyond the latest ratchet value, use that
if (messageIndex - o.Ratchet.Counter) < uint32(1<<31) {
o.Ratchet.AdvanceTo(messageIndex)
return &o.Ratchet, nil
}
if (messageIndex - o.InitialRatchet.Counter) >= uint32(1<<31) {
// the counter is before our initial ratchet - we can't decode this
return nil, fmt.Errorf("decrypt: %w", goolm.ErrRatchetNotAvailable)
}
// otherwise, start from the initial ratchet. Take a copy so that we don't overwrite the initial ratchet
copiedRatchet := o.InitialRatchet
copiedRatchet.AdvanceTo(messageIndex)
return &copiedRatchet, nil
}
// Decrypt decrypts a base64 encoded group message.
func (o *MegolmInboundSession) Decrypt(ciphertext []byte) ([]byte, uint32, error) {
if o.SigningKey == nil {
return nil, 0, fmt.Errorf("decrypt: %w", goolm.ErrBadMessageFormat)
}
decoded, err := goolm.Base64Decode(ciphertext)
if err != nil {
return nil, 0, err
}
msg := &message.GroupMessage{}
err = msg.Decode(decoded)
if err != nil {
return nil, 0, err
}
if msg.Version != protocolVersion {
return nil, 0, fmt.Errorf("decrypt: %w", goolm.ErrWrongProtocolVersion)
}
if msg.Ciphertext == nil || !msg.HasMessageIndex {
return nil, 0, fmt.Errorf("decrypt: %w", goolm.ErrBadMessageFormat)
}
// verify signature
verifiedSignature := msg.VerifySignatureInline(o.SigningKey, decoded)
if !verifiedSignature {
return nil, 0, fmt.Errorf("decrypt: %w", goolm.ErrBadSignature)
}
targetRatch, err := o.getRatchet(msg.MessageIndex)
if err != nil {
return nil, 0, err
}
decrypted, err := targetRatch.Decrypt(decoded, &o.SigningKey, msg)
if err != nil {
return nil, 0, err
}
o.SigningKeyVerified = true
return decrypted, msg.MessageIndex, nil
}
// SessionID returns the base64 endoded signing key
func (o MegolmInboundSession) SessionID() id.SessionID {
return id.SessionID(base64.RawStdEncoding.EncodeToString(o.SigningKey))
}
// PickleAsJSON returns an MegolmInboundSession as a base64 string encrypted using the supplied key. The unencrypted representation of the Account is in JSON format.
func (o MegolmInboundSession) PickleAsJSON(key []byte) ([]byte, error) {
return utilities.PickleAsJSON(o, megolmInboundSessionPickleVersionJSON, key)
}
// UnpickleAsJSON updates an MegolmInboundSession by a base64 encrypted string using the supplied key. The unencrypted representation has to be in JSON format.
func (o *MegolmInboundSession) UnpickleAsJSON(pickled, key []byte) error {
return utilities.UnpickleAsJSON(o, pickled, key, megolmInboundSessionPickleVersionJSON)
}
// SessionExportMessage creates an base64 encoded export of the session.
func (o MegolmInboundSession) SessionExportMessage(messageIndex uint32) ([]byte, error) {
ratchet, err := o.getRatchet(messageIndex)
if err != nil {
return nil, err
}
return ratchet.SessionExportMessage(o.SigningKey)
}
// Unpickle decodes the base64 encoded string and decrypts the result with the key.
// The decrypted value is then passed to UnpickleLibOlm.
func (o *MegolmInboundSession) Unpickle(pickled, key []byte) error {
decrypted, err := cipher.Unpickle(key, pickled)
if err != nil {
return err
}
_, err = o.UnpickleLibOlm(decrypted)
return err
}
// UnpickleLibOlm decodes the unencryted value and populates the Session accordingly. It returns the number of bytes read.
func (o *MegolmInboundSession) UnpickleLibOlm(value []byte) (int, error) {
//First 4 bytes are the accountPickleVersion
pickledVersion, curPos, err := libolmpickle.UnpickleUInt32(value)
if err != nil {
return 0, err
}
switch pickledVersion {
case megolmInboundSessionPickleVersionLibOlm, 1:
default:
return 0, fmt.Errorf("unpickle MegolmInboundSession: %w", goolm.ErrBadVersion)
}
readBytes, err := o.InitialRatchet.UnpickleLibOlm(value[curPos:])
if err != nil {
return 0, err
}
curPos += readBytes
readBytes, err = o.Ratchet.UnpickleLibOlm(value[curPos:])
if err != nil {
return 0, err
}
curPos += readBytes
readBytes, err = o.SigningKey.UnpickleLibOlm(value[curPos:])
if err != nil {
return 0, err
}
curPos += readBytes
if pickledVersion == 1 {
// pickle v1 had no signing_key_verified field (all keyshares were verified at import time)
o.SigningKeyVerified = true
} else {
o.SigningKeyVerified, readBytes, err = libolmpickle.UnpickleBool(value[curPos:])
if err != nil {
return 0, err
}
curPos += readBytes
}
return curPos, nil
}
// Pickle returns a base64 encoded and with key encrypted pickled MegolmInboundSession using PickleLibOlm().
func (o MegolmInboundSession) Pickle(key []byte) ([]byte, error) {
pickeledBytes := make([]byte, o.PickleLen())
written, err := o.PickleLibOlm(pickeledBytes)
if err != nil {
return nil, err
}
if written != len(pickeledBytes) {
return nil, errors.New("number of written bytes not correct")
}
encrypted, err := cipher.Pickle(key, pickeledBytes)
if err != nil {
return nil, err
}
return encrypted, nil
}
// PickleLibOlm encodes the session into target. target has to have a size of at least PickleLen() and is written to from index 0.
// It returns the number of bytes written.
func (o MegolmInboundSession) PickleLibOlm(target []byte) (int, error) {
if len(target) < o.PickleLen() {
return 0, fmt.Errorf("pickle MegolmInboundSession: %w", goolm.ErrValueTooShort)
}
written := libolmpickle.PickleUInt32(megolmInboundSessionPickleVersionLibOlm, target)
writtenInitRatchet, err := o.InitialRatchet.PickleLibOlm(target[written:])
if err != nil {
return 0, fmt.Errorf("pickle MegolmInboundSession: %w", err)
}
written += writtenInitRatchet
writtenRatchet, err := o.Ratchet.PickleLibOlm(target[written:])
if err != nil {
return 0, fmt.Errorf("pickle MegolmInboundSession: %w", err)
}
written += writtenRatchet
writtenPubKey, err := o.SigningKey.PickleLibOlm(target[written:])
if err != nil {
return 0, fmt.Errorf("pickle MegolmInboundSession: %w", err)
}
written += writtenPubKey
written += libolmpickle.PickleBool(o.SigningKeyVerified, target[written:])
return written, nil
}
// PickleLen returns the number of bytes the pickled session will have.
func (o MegolmInboundSession) PickleLen() int {
length := libolmpickle.PickleUInt32Len(megolmInboundSessionPickleVersionLibOlm)
length += o.InitialRatchet.PickleLen()
length += o.Ratchet.PickleLen()
length += o.SigningKey.PickleLen()
length += libolmpickle.PickleBoolLen(o.SigningKeyVerified)
return length
}

View File

@@ -0,0 +1,171 @@
package session
import (
"crypto/rand"
"encoding/base64"
"errors"
"fmt"
"maunium.net/go/mautrix/id"
"maunium.net/go/mautrix/crypto/goolm"
"maunium.net/go/mautrix/crypto/goolm/cipher"
"maunium.net/go/mautrix/crypto/goolm/crypto"
"maunium.net/go/mautrix/crypto/goolm/libolmpickle"
"maunium.net/go/mautrix/crypto/goolm/megolm"
"maunium.net/go/mautrix/crypto/goolm/utilities"
)
const (
megolmOutboundSessionPickleVersion byte = 1
megolmOutboundSessionPickleVersionLibOlm uint32 = 1
)
// MegolmOutboundSession stores information about the sessions to send.
type MegolmOutboundSession struct {
Ratchet megolm.Ratchet `json:"ratchet"`
SigningKey crypto.Ed25519KeyPair `json:"signing_key"`
}
// NewMegolmOutboundSession creates a new MegolmOutboundSession.
func NewMegolmOutboundSession() (*MegolmOutboundSession, error) {
o := &MegolmOutboundSession{}
var err error
o.SigningKey, err = crypto.Ed25519GenerateKey(nil)
if err != nil {
return nil, err
}
var randomData [megolm.RatchetParts * megolm.RatchetPartLength]byte
_, err = rand.Read(randomData[:])
if err != nil {
return nil, err
}
ratchet, err := megolm.New(0, randomData)
if err != nil {
return nil, err
}
o.Ratchet = *ratchet
return o, nil
}
// MegolmOutboundSessionFromPickled loads the MegolmOutboundSession details from a pickled base64 string. The input is decrypted with the supplied key.
func MegolmOutboundSessionFromPickled(pickled, key []byte) (*MegolmOutboundSession, error) {
if len(pickled) == 0 {
return nil, fmt.Errorf("megolmOutboundSessionFromPickled: %w", goolm.ErrEmptyInput)
}
a := &MegolmOutboundSession{}
err := a.Unpickle(pickled, key)
if err != nil {
return nil, err
}
return a, nil
}
// Encrypt encrypts the plaintext as a base64 encoded group message.
func (o *MegolmOutboundSession) Encrypt(plaintext []byte) ([]byte, error) {
encrypted, err := o.Ratchet.Encrypt(plaintext, &o.SigningKey)
if err != nil {
return nil, err
}
return goolm.Base64Encode(encrypted), nil
}
// SessionID returns the base64 endoded public signing key
func (o MegolmOutboundSession) SessionID() id.SessionID {
return id.SessionID(base64.RawStdEncoding.EncodeToString(o.SigningKey.PublicKey))
}
// PickleAsJSON returns an Session as a base64 string encrypted using the supplied key. The unencrypted representation of the Account is in JSON format.
func (o MegolmOutboundSession) PickleAsJSON(key []byte) ([]byte, error) {
return utilities.PickleAsJSON(o, megolmOutboundSessionPickleVersion, key)
}
// UnpickleAsJSON updates an Session by a base64 encrypted string with the key. The unencrypted representation has to be in JSON format.
func (o *MegolmOutboundSession) UnpickleAsJSON(pickled, key []byte) error {
return utilities.UnpickleAsJSON(o, pickled, key, megolmOutboundSessionPickleVersion)
}
// Unpickle decodes the base64 encoded string and decrypts the result with the key.
// The decrypted value is then passed to UnpickleLibOlm.
func (o *MegolmOutboundSession) Unpickle(pickled, key []byte) error {
decrypted, err := cipher.Unpickle(key, pickled)
if err != nil {
return err
}
_, err = o.UnpickleLibOlm(decrypted)
return err
}
// UnpickleLibOlm decodes the unencryted value and populates the Session accordingly. It returns the number of bytes read.
func (o *MegolmOutboundSession) UnpickleLibOlm(value []byte) (int, error) {
//First 4 bytes are the accountPickleVersion
pickledVersion, curPos, err := libolmpickle.UnpickleUInt32(value)
if err != nil {
return 0, err
}
switch pickledVersion {
case megolmOutboundSessionPickleVersionLibOlm:
default:
return 0, fmt.Errorf("unpickle MegolmInboundSession: %w", goolm.ErrBadVersion)
}
readBytes, err := o.Ratchet.UnpickleLibOlm(value[curPos:])
if err != nil {
return 0, err
}
curPos += readBytes
readBytes, err = o.SigningKey.UnpickleLibOlm(value[curPos:])
if err != nil {
return 0, err
}
curPos += readBytes
return curPos, nil
}
// Pickle returns a base64 encoded and with key encrypted pickled MegolmOutboundSession using PickleLibOlm().
func (o MegolmOutboundSession) Pickle(key []byte) ([]byte, error) {
pickeledBytes := make([]byte, o.PickleLen())
written, err := o.PickleLibOlm(pickeledBytes)
if err != nil {
return nil, err
}
if written != len(pickeledBytes) {
return nil, errors.New("number of written bytes not correct")
}
encrypted, err := cipher.Pickle(key, pickeledBytes)
if err != nil {
return nil, err
}
return encrypted, nil
}
// PickleLibOlm encodes the session into target. target has to have a size of at least PickleLen() and is written to from index 0.
// It returns the number of bytes written.
func (o MegolmOutboundSession) PickleLibOlm(target []byte) (int, error) {
if len(target) < o.PickleLen() {
return 0, fmt.Errorf("pickle MegolmOutboundSession: %w", goolm.ErrValueTooShort)
}
written := libolmpickle.PickleUInt32(megolmOutboundSessionPickleVersionLibOlm, target)
writtenRatchet, err := o.Ratchet.PickleLibOlm(target[written:])
if err != nil {
return 0, fmt.Errorf("pickle MegolmOutboundSession: %w", err)
}
written += writtenRatchet
writtenPubKey, err := o.SigningKey.PickleLibOlm(target[written:])
if err != nil {
return 0, fmt.Errorf("pickle MegolmOutboundSession: %w", err)
}
written += writtenPubKey
return written, nil
}
// PickleLen returns the number of bytes the pickled session will have.
func (o MegolmOutboundSession) PickleLen() int {
length := libolmpickle.PickleUInt32Len(megolmOutboundSessionPickleVersionLibOlm)
length += o.Ratchet.PickleLen()
length += o.SigningKey.PickleLen()
return length
}
func (o MegolmOutboundSession) SessionSharingMessage() ([]byte, error) {
return o.Ratchet.SessionSharingMessage(o.SigningKey)
}

View File

@@ -0,0 +1,476 @@
package session
import (
"bytes"
"encoding/base64"
"errors"
"fmt"
"io"
"maunium.net/go/mautrix/crypto/goolm"
"maunium.net/go/mautrix/crypto/goolm/cipher"
"maunium.net/go/mautrix/crypto/goolm/crypto"
"maunium.net/go/mautrix/crypto/goolm/libolmpickle"
"maunium.net/go/mautrix/crypto/goolm/message"
"maunium.net/go/mautrix/crypto/goolm/olm"
"maunium.net/go/mautrix/crypto/goolm/utilities"
"maunium.net/go/mautrix/id"
)
const (
olmSessionPickleVersionJSON uint8 = 1
olmSessionPickleVersionLibOlm uint32 = 1
)
const (
protocolVersion = 0x3
)
// OlmSession stores all information for an olm session
type OlmSession struct {
ReceivedMessage bool `json:"received_message"`
AliceIdentityKey crypto.Curve25519PublicKey `json:"alice_id_key"`
AliceBaseKey crypto.Curve25519PublicKey `json:"alice_base_key"`
BobOneTimeKey crypto.Curve25519PublicKey `json:"bob_one_time_key"`
Ratchet olm.Ratchet `json:"ratchet"`
}
// SearchOTKFunc is used to retrieve a crypto.OneTimeKey from a public key.
type SearchOTKFunc = func(crypto.Curve25519PublicKey) *crypto.OneTimeKey
// OlmSessionFromJSONPickled loads an OlmSession from a pickled base64 string. Decrypts
// the Session using the supplied key.
func OlmSessionFromJSONPickled(pickled, key []byte) (*OlmSession, error) {
if len(pickled) == 0 {
return nil, fmt.Errorf("sessionFromPickled: %w", goolm.ErrEmptyInput)
}
a := &OlmSession{}
err := a.UnpickleAsJSON(pickled, key)
if err != nil {
return nil, err
}
return a, nil
}
// OlmSessionFromPickled loads the OlmSession details from a pickled base64 string. The input is decrypted with the supplied key.
func OlmSessionFromPickled(pickled, key []byte) (*OlmSession, error) {
if len(pickled) == 0 {
return nil, fmt.Errorf("sessionFromPickled: %w", goolm.ErrEmptyInput)
}
a := &OlmSession{}
err := a.Unpickle(pickled, key)
if err != nil {
return nil, err
}
return a, nil
}
// NewOlmSession creates a new Session.
func NewOlmSession() *OlmSession {
s := &OlmSession{}
s.Ratchet = *olm.New()
return s
}
// NewOutboundOlmSession creates a new outbound session for sending the first message to a
// given curve25519 identityKey and oneTimeKey.
func NewOutboundOlmSession(identityKeyAlice crypto.Curve25519KeyPair, identityKeyBob crypto.Curve25519PublicKey, oneTimeKeyBob crypto.Curve25519PublicKey) (*OlmSession, error) {
s := NewOlmSession()
//generate E_A
baseKey, err := crypto.Curve25519GenerateKey(nil)
if err != nil {
return nil, err
}
//generate T_0
ratchetKey, err := crypto.Curve25519GenerateKey(nil)
if err != nil {
return nil, err
}
//Calculate shared secret via Triple Diffie-Hellman
var secret []byte
//ECDH(I_A,E_B)
idSecret, err := identityKeyAlice.SharedSecret(oneTimeKeyBob)
if err != nil {
return nil, err
}
//ECDH(E_A,I_B)
baseIdSecret, err := baseKey.SharedSecret(identityKeyBob)
if err != nil {
return nil, err
}
//ECDH(E_A,E_B)
baseOneTimeSecret, err := baseKey.SharedSecret(oneTimeKeyBob)
if err != nil {
return nil, err
}
secret = append(secret, idSecret...)
secret = append(secret, baseIdSecret...)
secret = append(secret, baseOneTimeSecret...)
//Init Ratchet
s.Ratchet.InitializeAsAlice(secret, ratchetKey)
s.AliceIdentityKey = identityKeyAlice.PublicKey
s.AliceBaseKey = baseKey.PublicKey
s.BobOneTimeKey = oneTimeKeyBob
return s, nil
}
// NewInboundOlmSession creates a new inbound session from receiving the first message.
func NewInboundOlmSession(identityKeyAlice *crypto.Curve25519PublicKey, receivedOTKMsg []byte, searchBobOTK SearchOTKFunc, identityKeyBob crypto.Curve25519KeyPair) (*OlmSession, error) {
decodedOTKMsg, err := goolm.Base64Decode(receivedOTKMsg)
if err != nil {
return nil, err
}
s := NewOlmSession()
//decode OneTimeKeyMessage
oneTimeMsg := message.PreKeyMessage{}
err = oneTimeMsg.Decode(decodedOTKMsg)
if err != nil {
return nil, fmt.Errorf("OneTimeKeyMessage decode: %w", err)
}
if !oneTimeMsg.CheckFields(identityKeyAlice) {
return nil, fmt.Errorf("OneTimeKeyMessage check fields: %w", goolm.ErrBadMessageFormat)
}
//Either the identityKeyAlice is set and/or the oneTimeMsg.IdentityKey is set, which is checked
// by oneTimeMsg.CheckFields
if identityKeyAlice != nil && len(oneTimeMsg.IdentityKey) != 0 {
//if both are set, compare them
if !identityKeyAlice.Equal(oneTimeMsg.IdentityKey) {
return nil, fmt.Errorf("OneTimeKeyMessage identity keys: %w", goolm.ErrBadMessageKeyID)
}
}
if identityKeyAlice == nil {
//for downstream use set
identityKeyAlice = &oneTimeMsg.IdentityKey
}
oneTimeKeyBob := searchBobOTK(oneTimeMsg.OneTimeKey)
if oneTimeKeyBob == nil {
return nil, fmt.Errorf("ourOneTimeKey: %w", goolm.ErrBadMessageKeyID)
}
//Calculate shared secret via Triple Diffie-Hellman
var secret []byte
//ECDH(E_B,I_A)
idSecret, err := oneTimeKeyBob.Key.SharedSecret(*identityKeyAlice)
if err != nil {
return nil, err
}
//ECDH(I_B,E_A)
baseIdSecret, err := identityKeyBob.SharedSecret(oneTimeMsg.BaseKey)
if err != nil {
return nil, err
}
//ECDH(E_B,E_A)
baseOneTimeSecret, err := oneTimeKeyBob.Key.SharedSecret(oneTimeMsg.BaseKey)
if err != nil {
return nil, err
}
secret = append(secret, idSecret...)
secret = append(secret, baseIdSecret...)
secret = append(secret, baseOneTimeSecret...)
//decode message
msg := message.Message{}
err = msg.Decode(oneTimeMsg.Message)
if err != nil {
return nil, fmt.Errorf("Message decode: %w", err)
}
if len(msg.RatchetKey) == 0 {
return nil, fmt.Errorf("Message missing ratchet key: %w", goolm.ErrBadMessageFormat)
}
//Init Ratchet
s.Ratchet.InitializeAsBob(secret, msg.RatchetKey)
s.AliceBaseKey = oneTimeMsg.BaseKey
s.AliceIdentityKey = oneTimeMsg.IdentityKey
s.BobOneTimeKey = oneTimeKeyBob.Key.PublicKey
//https://gitlab.matrix.org/matrix-org/olm/blob/master/docs/olm.md states to remove the oneTimeKey
//this is done via the account itself
return s, nil
}
// PickleAsJSON returns an Session as a base64 string encrypted using the supplied key. The unencrypted representation of the Account is in JSON format.
func (a OlmSession) PickleAsJSON(key []byte) ([]byte, error) {
return utilities.PickleAsJSON(a, olmSessionPickleVersionJSON, key)
}
// UnpickleAsJSON updates an Session by a base64 encrypted string with the key. The unencrypted representation has to be in JSON format.
func (a *OlmSession) UnpickleAsJSON(pickled, key []byte) error {
return utilities.UnpickleAsJSON(a, pickled, key, olmSessionPickleVersionJSON)
}
// ID returns an identifier for this Session. Will be the same for both ends of the conversation.
// Generated by hashing the public keys used to create the session.
func (s OlmSession) ID() id.SessionID {
message := make([]byte, 3*crypto.Curve25519KeyLength)
copy(message, s.AliceIdentityKey)
copy(message[crypto.Curve25519KeyLength:], s.AliceBaseKey)
copy(message[2*crypto.Curve25519KeyLength:], s.BobOneTimeKey)
hash := crypto.SHA256(message)
res := id.SessionID(goolm.Base64Encode(hash))
return res
}
// HasReceivedMessage returns true if this session has received any message.
func (s OlmSession) HasReceivedMessage() bool {
return s.ReceivedMessage
}
// MatchesInboundSessionFrom checks if the oneTimeKeyMsg message is set for this inbound
// Session. This can happen if multiple messages are sent to this Account
// before this Account sends a message in reply. Returns true if the session
// matches. Returns false if the session does not match.
func (s OlmSession) MatchesInboundSessionFrom(theirIdentityKeyEncoded *id.Curve25519, receivedOTKMsg []byte) (bool, error) {
if len(receivedOTKMsg) == 0 {
return false, fmt.Errorf("inbound match: %w", goolm.ErrEmptyInput)
}
decodedOTKMsg, err := goolm.Base64Decode(receivedOTKMsg)
if err != nil {
return false, err
}
var theirIdentityKey *crypto.Curve25519PublicKey
if theirIdentityKeyEncoded != nil {
decodedKey, err := base64.RawStdEncoding.DecodeString(string(*theirIdentityKeyEncoded))
if err != nil {
return false, err
}
theirIdentityKeyByte := crypto.Curve25519PublicKey(decodedKey)
theirIdentityKey = &theirIdentityKeyByte
}
msg := message.PreKeyMessage{}
err = msg.Decode(decodedOTKMsg)
if err != nil {
return false, err
}
if !msg.CheckFields(theirIdentityKey) {
return false, nil
}
same := true
if msg.IdentityKey != nil {
same = same && msg.IdentityKey.Equal(s.AliceIdentityKey)
}
if theirIdentityKey != nil {
same = same && theirIdentityKey.Equal(s.AliceIdentityKey)
}
same = same && bytes.Equal(msg.BaseKey, s.AliceBaseKey)
same = same && bytes.Equal(msg.OneTimeKey, s.BobOneTimeKey)
return same, nil
}
// EncryptMsgType returns the type of the next message that Encrypt will
// return. Returns MsgTypePreKey if the message will be a oneTimeKeyMsg.
// Returns MsgTypeMsg if the message will be a normal message.
func (s OlmSession) EncryptMsgType() id.OlmMsgType {
if s.ReceivedMessage {
return id.OlmMsgTypeMsg
}
return id.OlmMsgTypePreKey
}
// Encrypt encrypts a message using the Session. Returns the encrypted message base64 encoded. If reader is nil, crypto/rand is used for key generations.
func (s *OlmSession) Encrypt(plaintext []byte, reader io.Reader) (id.OlmMsgType, []byte, error) {
if len(plaintext) == 0 {
return 0, nil, fmt.Errorf("encrypt: %w", goolm.ErrEmptyInput)
}
messageType := s.EncryptMsgType()
encrypted, err := s.Ratchet.Encrypt(plaintext, reader)
if err != nil {
return 0, nil, err
}
result := encrypted
if !s.ReceivedMessage {
msg := message.PreKeyMessage{}
msg.Version = protocolVersion
msg.OneTimeKey = s.BobOneTimeKey
msg.IdentityKey = s.AliceIdentityKey
msg.BaseKey = s.AliceBaseKey
msg.Message = encrypted
var err error
messageBody, err := msg.Encode()
if err != nil {
return 0, nil, err
}
result = messageBody
}
return messageType, goolm.Base64Encode(result), nil
}
// Decrypt decrypts a base64 encoded message using the Session.
func (s *OlmSession) Decrypt(crypttext []byte, msgType id.OlmMsgType) ([]byte, error) {
if len(crypttext) == 0 {
return nil, fmt.Errorf("decrypt: %w", goolm.ErrEmptyInput)
}
decodedCrypttext, err := goolm.Base64Decode(crypttext)
if err != nil {
return nil, err
}
msgBody := decodedCrypttext
if msgType != id.OlmMsgTypeMsg {
//Pre-Key Message
msg := message.PreKeyMessage{}
err := msg.Decode(decodedCrypttext)
if err != nil {
return nil, err
}
msgBody = msg.Message
}
plaintext, err := s.Ratchet.Decrypt(msgBody)
if err != nil {
return nil, err
}
s.ReceivedMessage = true
return plaintext, nil
}
// Unpickle decodes the base64 encoded string and decrypts the result with the key.
// The decrypted value is then passed to UnpickleLibOlm.
func (o *OlmSession) Unpickle(pickled, key []byte) error {
decrypted, err := cipher.Unpickle(key, pickled)
if err != nil {
return err
}
_, err = o.UnpickleLibOlm(decrypted)
return err
}
// UnpickleLibOlm decodes the unencryted value and populates the Session accordingly. It returns the number of bytes read.
func (o *OlmSession) UnpickleLibOlm(value []byte) (int, error) {
//First 4 bytes are the accountPickleVersion
pickledVersion, curPos, err := libolmpickle.UnpickleUInt32(value)
if err != nil {
return 0, err
}
includesChainIndex := true
switch pickledVersion {
case olmSessionPickleVersionLibOlm:
includesChainIndex = false
case uint32(0x80000001):
includesChainIndex = true
default:
return 0, fmt.Errorf("unpickle olmSession: %w", goolm.ErrBadVersion)
}
var readBytes int
o.ReceivedMessage, readBytes, err = libolmpickle.UnpickleBool(value[curPos:])
if err != nil {
return 0, err
}
curPos += readBytes
readBytes, err = o.AliceIdentityKey.UnpickleLibOlm(value[curPos:])
if err != nil {
return 0, err
}
curPos += readBytes
readBytes, err = o.AliceBaseKey.UnpickleLibOlm(value[curPos:])
if err != nil {
return 0, err
}
curPos += readBytes
readBytes, err = o.BobOneTimeKey.UnpickleLibOlm(value[curPos:])
if err != nil {
return 0, err
}
curPos += readBytes
readBytes, err = o.Ratchet.UnpickleLibOlm(value[curPos:], includesChainIndex)
if err != nil {
return 0, err
}
curPos += readBytes
return curPos, nil
}
// Pickle returns a base64 encoded and with key encrypted pickled olmSession using PickleLibOlm().
func (o OlmSession) Pickle(key []byte) ([]byte, error) {
pickeledBytes := make([]byte, o.PickleLen())
written, err := o.PickleLibOlm(pickeledBytes)
if err != nil {
return nil, err
}
if written != len(pickeledBytes) {
return nil, errors.New("number of written bytes not correct")
}
encrypted, err := cipher.Pickle(key, pickeledBytes)
if err != nil {
return nil, err
}
return encrypted, nil
}
// PickleLibOlm encodes the session into target. target has to have a size of at least PickleLen() and is written to from index 0.
// It returns the number of bytes written.
func (o OlmSession) PickleLibOlm(target []byte) (int, error) {
if len(target) < o.PickleLen() {
return 0, fmt.Errorf("pickle MegolmOutboundSession: %w", goolm.ErrValueTooShort)
}
written := libolmpickle.PickleUInt32(olmSessionPickleVersionLibOlm, target)
written += libolmpickle.PickleBool(o.ReceivedMessage, target[written:])
writtenRatchet, err := o.AliceIdentityKey.PickleLibOlm(target[written:])
if err != nil {
return 0, fmt.Errorf("pickle MegolmOutboundSession: %w", err)
}
written += writtenRatchet
writtenRatchet, err = o.AliceBaseKey.PickleLibOlm(target[written:])
if err != nil {
return 0, fmt.Errorf("pickle MegolmOutboundSession: %w", err)
}
written += writtenRatchet
writtenRatchet, err = o.BobOneTimeKey.PickleLibOlm(target[written:])
if err != nil {
return 0, fmt.Errorf("pickle MegolmOutboundSession: %w", err)
}
written += writtenRatchet
writtenRatchet, err = o.Ratchet.PickleLibOlm(target[written:])
if err != nil {
return 0, fmt.Errorf("pickle MegolmOutboundSession: %w", err)
}
written += writtenRatchet
return written, nil
}
// PickleLen returns the actual number of bytes the pickled session will have.
func (o OlmSession) PickleLen() int {
length := libolmpickle.PickleUInt32Len(olmSessionPickleVersionLibOlm)
length += libolmpickle.PickleBoolLen(o.ReceivedMessage)
length += o.AliceIdentityKey.PickleLen()
length += o.AliceBaseKey.PickleLen()
length += o.BobOneTimeKey.PickleLen()
length += o.Ratchet.PickleLen()
return length
}
// PickleLenMin returns the minimum number of bytes the pickled session must have.
func (o OlmSession) PickleLenMin() int {
length := libolmpickle.PickleUInt32Len(olmSessionPickleVersionLibOlm)
length += libolmpickle.PickleBoolLen(o.ReceivedMessage)
length += o.AliceIdentityKey.PickleLen()
length += o.AliceBaseKey.PickleLen()
length += o.BobOneTimeKey.PickleLen()
length += o.Ratchet.PickleLenMin()
return length
}
// Describe returns a string describing the current state of the session for debugging.
func (o OlmSession) Describe() string {
var res string
if o.Ratchet.SenderChains.IsSet {
res += fmt.Sprintf("sender chain index: %d ", o.Ratchet.SenderChains.CKey.Index)
} else {
res += "sender chain index: "
}
res += "receiver chain indicies:"
for _, curChain := range o.Ratchet.ReceiverChains {
res += fmt.Sprintf(" %d", curChain.CKey.Index)
}
res += " skipped message keys:"
for _, curSkip := range o.Ratchet.SkippedMessageKeys {
res += fmt.Sprintf(" %d", curSkip.MKey.Index)
}
return res
}

View File

@@ -0,0 +1,23 @@
package utilities
import (
"encoding/base64"
"maunium.net/go/mautrix/crypto/goolm"
"maunium.net/go/mautrix/crypto/goolm/crypto"
"maunium.net/go/mautrix/id"
)
// VerifySignature verifies an ed25519 signature.
func VerifySignature(message []byte, key id.Ed25519, signature []byte) (ok bool, err error) {
keyDecoded, err := base64.RawStdEncoding.DecodeString(string(key))
if err != nil {
return false, err
}
signatureDecoded, err := goolm.Base64Decode(signature)
if err != nil {
return false, err
}
publicKey := crypto.Ed25519PublicKey(keyDecoded)
return publicKey.Verify(message, signatureDecoded), nil
}

View File

@@ -0,0 +1,60 @@
package utilities
import (
"encoding/json"
"fmt"
"maunium.net/go/mautrix/crypto/goolm"
"maunium.net/go/mautrix/crypto/goolm/cipher"
)
// PickleAsJSON returns an object as a base64 string encrypted using the supplied key. The unencrypted representation of the object is in JSON format.
func PickleAsJSON(object any, pickleVersion byte, key []byte) ([]byte, error) {
if len(key) == 0 {
return nil, fmt.Errorf("pickle: %w", goolm.ErrNoKeyProvided)
}
marshaled, err := json.Marshal(object)
if err != nil {
return nil, fmt.Errorf("pickle marshal: %w", err)
}
marshaled = append([]byte{pickleVersion}, marshaled...)
toEncrypt := make([]byte, len(marshaled))
copy(toEncrypt, marshaled)
//pad marshaled to get block size
if len(marshaled)%cipher.PickleBlockSize() != 0 {
padding := cipher.PickleBlockSize() - len(marshaled)%cipher.PickleBlockSize()
toEncrypt = make([]byte, len(marshaled)+padding)
copy(toEncrypt, marshaled)
}
encrypted, err := cipher.Pickle(key, toEncrypt)
if err != nil {
return nil, fmt.Errorf("pickle encrypt: %w", err)
}
return encrypted, nil
}
// UnpickleAsJSON updates the object by a base64 encrypted string using the supplied key. The unencrypted representation has to be in JSON format.
func UnpickleAsJSON(object any, pickled, key []byte, pickleVersion byte) error {
if len(key) == 0 {
return fmt.Errorf("unpickle: %w", goolm.ErrNoKeyProvided)
}
decrypted, err := cipher.Unpickle(key, pickled)
if err != nil {
return fmt.Errorf("unpickle decrypt: %w", err)
}
//unpad decrypted so unmarshal works
for i := len(decrypted) - 1; i >= 0; i-- {
if decrypted[i] != 0 {
decrypted = decrypted[:i+1]
break
}
}
if decrypted[0] != pickleVersion {
return fmt.Errorf("unpickle: %w", goolm.ErrWrongPickleVersion)
}
err = json.Unmarshal(decrypted[1:], object)
if err != nil {
return fmt.Errorf("unpickle unmarshal: %w", err)
}
return nil
}

View File

@@ -1,4 +1,4 @@
// Copyright (c) 2023 Tulir Asokan
// Copyright (c) 2024 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
@@ -8,6 +8,7 @@ package crypto
import (
"bytes"
"context"
"crypto/aes"
"crypto/cipher"
"crypto/hmac"
@@ -91,7 +92,7 @@ func decryptKeyExport(passphrase string, exportData []byte) ([]ExportedSession,
return sessionsJSON, nil
}
func (mach *OlmMachine) importExportedRoomKey(session ExportedSession) (bool, error) {
func (mach *OlmMachine) importExportedRoomKey(ctx context.Context, session ExportedSession) (bool, error) {
if session.Algorithm != id.AlgorithmMegolmV1 {
return false, ErrInvalidExportedAlgorithm
}
@@ -112,12 +113,12 @@ func (mach *OlmMachine) importExportedRoomKey(session ExportedSession) (bool, er
ReceivedAt: time.Now().UTC(),
}
existingIGS, _ := mach.CryptoStore.GetGroupSession(igs.RoomID, igs.SenderKey, igs.ID())
existingIGS, _ := mach.CryptoStore.GetGroupSession(ctx, igs.RoomID, igs.SenderKey, igs.ID())
if existingIGS != nil && existingIGS.Internal.FirstKnownIndex() <= igs.Internal.FirstKnownIndex() {
// We already have an equivalent or better session in the store, so don't override it.
return false, nil
}
err = mach.CryptoStore.PutGroupSession(igs.RoomID, igs.SenderKey, igs.ID(), igs)
err = mach.CryptoStore.PutGroupSession(ctx, igs.RoomID, igs.SenderKey, igs.ID(), igs)
if err != nil {
return false, fmt.Errorf("failed to store imported session: %w", err)
}
@@ -127,7 +128,7 @@ func (mach *OlmMachine) importExportedRoomKey(session ExportedSession) (bool, er
// ImportKeys imports data that was exported with the format specified in the Matrix spec.
// See https://spec.matrix.org/v1.2/client-server-api/#key-exports
func (mach *OlmMachine) ImportKeys(passphrase string, data []byte) (int, int, error) {
func (mach *OlmMachine) ImportKeys(ctx context.Context, passphrase string, data []byte) (int, int, error) {
exportData, err := decodeKeyExport(data)
if err != nil {
return 0, 0, err
@@ -143,8 +144,11 @@ func (mach *OlmMachine) ImportKeys(passphrase string, data []byte) (int, int, er
Str("room_id", session.RoomID.String()).
Str("session_id", session.SessionID.String()).
Logger()
imported, err := mach.importExportedRoomKey(session)
imported, err := mach.importExportedRoomKey(ctx, session)
if err != nil {
if ctx.Err() != nil {
return count, len(sessions), ctx.Err()
}
log.Error().Err(err).Msg("Failed to import Megolm session from file")
} else if imported {
log.Debug().Msg("Imported Megolm session from file")

View File

@@ -1,5 +1,5 @@
// Copyright (c) 2020 Nikos Filippakis
// Copyright (c) 2023 Tulir Asokan
// Copyright (c) 2024 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
@@ -48,7 +48,7 @@ func (mach *OlmMachine) RequestRoomKey(ctx context.Context, toUser id.UserID, to
keyResponseReceived := make(chan struct{})
mach.roomKeyRequestFilled.Store(sessionID, keyResponseReceived)
err := mach.SendRoomKeyRequest(roomID, senderKey, sessionID, requestID, map[id.UserID][]id.DeviceID{toUser: {toDevice}})
err := mach.SendRoomKeyRequest(ctx, roomID, senderKey, sessionID, requestID, map[id.UserID][]id.DeviceID{toUser: {toDevice}})
if err != nil {
return nil, err
}
@@ -85,7 +85,7 @@ func (mach *OlmMachine) RequestRoomKey(ctx context.Context, toUser id.UserID, to
},
}
mach.Client.SendToDevice(event.ToDeviceRoomKeyRequest, toDeviceCancel)
mach.Client.SendToDevice(ctx, event.ToDeviceRoomKeyRequest, toDeviceCancel)
}()
return resChan, nil
}
@@ -99,7 +99,7 @@ func (mach *OlmMachine) RequestRoomKey(ctx context.Context, toUser id.UserID, to
// to the specific key request, but currently it only supports a single target device and is therefore deprecated.
// A future function may properly support multiple targets and automatically canceling the other requests when receiving
// the first response.
func (mach *OlmMachine) SendRoomKeyRequest(roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID, requestID string, users map[id.UserID][]id.DeviceID) error {
func (mach *OlmMachine) SendRoomKeyRequest(ctx context.Context, roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID, requestID string, users map[id.UserID][]id.DeviceID) error {
if len(requestID) == 0 {
requestID = mach.Client.TxnID()
}
@@ -126,7 +126,7 @@ func (mach *OlmMachine) SendRoomKeyRequest(roomID id.RoomID, senderKey id.Sender
toDeviceReq.Messages[user][device] = requestEvent
}
}
_, err := mach.Client.SendToDevice(event.ToDeviceRoomKeyRequest, toDeviceReq)
_, err := mach.Client.SendToDevice(ctx, event.ToDeviceRoomKeyRequest, toDeviceReq)
return err
}
@@ -152,7 +152,10 @@ func (mach *OlmMachine) importForwardedRoomKey(ctx context.Context, evt *Decrypt
Msg("Mismatched session ID while creating inbound group session from forward")
return false
}
config := mach.StateStore.GetEncryptionEvent(content.RoomID)
config, err := mach.StateStore.GetEncryptionEvent(ctx, content.RoomID)
if err != nil {
log.Error().Err(err).Msg("Failed to get encryption event for room")
}
var maxAge time.Duration
var maxMessages int
if config != nil {
@@ -178,7 +181,7 @@ func (mach *OlmMachine) importForwardedRoomKey(ctx context.Context, evt *Decrypt
MaxMessages: maxMessages,
IsScheduled: content.IsScheduled,
}
err = mach.CryptoStore.PutGroupSession(content.RoomID, content.SenderKey, content.SessionID, igs)
err = mach.CryptoStore.PutGroupSession(ctx, content.RoomID, content.SenderKey, content.SessionID, igs)
if err != nil {
log.Error().Err(err).Msg("Failed to store new inbound group session")
return false
@@ -188,7 +191,7 @@ func (mach *OlmMachine) importForwardedRoomKey(ctx context.Context, evt *Decrypt
return true
}
func (mach *OlmMachine) rejectKeyRequest(rejection KeyShareRejection, device *id.Device, request event.RequestedKeyInfo) {
func (mach *OlmMachine) rejectKeyRequest(ctx context.Context, rejection KeyShareRejection, device *id.Device, request event.RequestedKeyInfo) {
if rejection.Code == "" {
// If the rejection code is empty, it means don't share keys, but also don't tell the requester.
return
@@ -201,7 +204,7 @@ func (mach *OlmMachine) rejectKeyRequest(rejection KeyShareRejection, device *id
Code: rejection.Code,
Reason: rejection.Reason,
}
err := mach.sendToOneDevice(device.UserID, device.DeviceID, event.ToDeviceRoomKeyWithheld, &content)
err := mach.sendToOneDevice(ctx, device.UserID, device.DeviceID, event.ToDeviceRoomKeyWithheld, &content)
if err != nil {
mach.Log.Warn().Err(err).
Str("code", string(rejection.Code)).
@@ -209,7 +212,7 @@ func (mach *OlmMachine) rejectKeyRequest(rejection KeyShareRejection, device *id
Str("device_id", device.DeviceID.String()).
Msg("Failed to send key share rejection")
}
err = mach.sendToOneDevice(device.UserID, device.DeviceID, event.ToDeviceOrgMatrixRoomKeyWithheld, &content)
err = mach.sendToOneDevice(ctx, device.UserID, device.DeviceID, event.ToDeviceOrgMatrixRoomKeyWithheld, &content)
if err != nil {
mach.Log.Warn().Err(err).
Str("code", string(rejection.Code)).
@@ -270,23 +273,23 @@ func (mach *OlmMachine) handleRoomKeyRequest(ctx context.Context, sender id.User
rejection := mach.AllowKeyShare(ctx, device, content.Body)
if rejection != nil {
mach.rejectKeyRequest(*rejection, device, content.Body)
mach.rejectKeyRequest(ctx, *rejection, device, content.Body)
return
}
igs, err := mach.CryptoStore.GetGroupSession(content.Body.RoomID, content.Body.SenderKey, content.Body.SessionID)
igs, err := mach.CryptoStore.GetGroupSession(ctx, content.Body.RoomID, content.Body.SenderKey, content.Body.SessionID)
if err != nil {
if errors.Is(err, ErrGroupSessionWithheld) {
log.Debug().Err(err).Msg("Requested group session not available")
mach.rejectKeyRequest(KeyShareRejectUnavailable, device, content.Body)
mach.rejectKeyRequest(ctx, KeyShareRejectUnavailable, device, content.Body)
} else {
log.Error().Err(err).Msg("Failed to get group session to forward")
mach.rejectKeyRequest(KeyShareRejectInternalError, device, content.Body)
mach.rejectKeyRequest(ctx, KeyShareRejectInternalError, device, content.Body)
}
return
} else if igs == nil {
log.Error().Msg("Didn't find group session to forward")
mach.rejectKeyRequest(KeyShareRejectUnavailable, device, content.Body)
mach.rejectKeyRequest(ctx, KeyShareRejectUnavailable, device, content.Body)
return
}
if internalID := igs.ID(); internalID != content.Body.SessionID {
@@ -299,7 +302,7 @@ func (mach *OlmMachine) handleRoomKeyRequest(ctx context.Context, sender id.User
exportedKey, err := igs.Internal.Export(firstKnownIndex)
if err != nil {
log.Error().Err(err).Msg("Failed to export group session to forward")
mach.rejectKeyRequest(KeyShareRejectInternalError, device, content.Body)
mach.rejectKeyRequest(ctx, KeyShareRejectInternalError, device, content.Body)
return
}
@@ -331,7 +334,7 @@ func (mach *OlmMachine) handleBeeperRoomKeyAck(ctx context.Context, sender id.Us
Int("first_message_index", content.FirstMessageIndex).
Logger()
sess, err := mach.CryptoStore.GetGroupSession(content.RoomID, "", content.SessionID)
sess, err := mach.CryptoStore.GetGroupSession(ctx, content.RoomID, "", content.SessionID)
if err != nil {
if errors.Is(err, ErrGroupSessionWithheld) {
log.Debug().Err(err).Msg("Acked group session was already redacted")
@@ -351,7 +354,7 @@ func (mach *OlmMachine) handleBeeperRoomKeyAck(ctx context.Context, sender id.Us
isInbound := sess.SenderKey == mach.OwnIdentity().IdentityKey
if isInbound && mach.DeleteOutboundKeysOnAck && content.FirstMessageIndex == 0 {
log.Debug().Msg("Redacting inbound copy of outbound group session after ack")
err = mach.CryptoStore.RedactGroupSession(content.RoomID, sess.SenderKey, content.SessionID, "outbound session acked")
err = mach.CryptoStore.RedactGroupSession(ctx, content.RoomID, sess.SenderKey, content.SessionID, "outbound session acked")
if err != nil {
log.Err(err).Msg("Failed to redact group session")
}

View File

@@ -1,4 +1,4 @@
// Copyright (c) 2023 Tulir Asokan
// Copyright (c) 2024 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
@@ -33,6 +33,9 @@ type OlmMachine struct {
PlaintextMentions bool
// Never ask the server for keys automatically as a side effect.
DisableKeyFetching bool
SendKeysMinTrust id.TrustState
ShareKeysMinTrust id.TrustState
@@ -80,11 +83,11 @@ type OlmMachine struct {
// StateStore is used by OlmMachine to get room state information that's needed for encryption.
type StateStore interface {
// IsEncrypted returns whether a room is encrypted.
IsEncrypted(id.RoomID) bool
IsEncrypted(context.Context, id.RoomID) (bool, error)
// GetEncryptionEvent returns the encryption event's content for an encrypted room.
GetEncryptionEvent(id.RoomID) *event.EncryptionEventContent
GetEncryptionEvent(context.Context, id.RoomID) (*event.EncryptionEventContent, error)
// FindSharedRooms returns the encrypted rooms that another user is also in for a user ID.
FindSharedRooms(id.UserID) []id.RoomID
FindSharedRooms(context.Context, id.UserID) ([]id.RoomID, error)
}
// NewOlmMachine creates an OlmMachine with the given client, logger and stores.
@@ -131,8 +134,8 @@ func (mach *OlmMachine) machOrContextLog(ctx context.Context) *zerolog.Logger {
// Load loads the Olm account information from the crypto store. If there's no olm account, a new one is created.
// This must be called before using the machine.
func (mach *OlmMachine) Load() (err error) {
mach.account, err = mach.CryptoStore.GetAccount()
func (mach *OlmMachine) Load(ctx context.Context) (err error) {
mach.account, err = mach.CryptoStore.GetAccount(ctx)
if err != nil {
return
}
@@ -142,16 +145,16 @@ func (mach *OlmMachine) Load() (err error) {
return nil
}
func (mach *OlmMachine) saveAccount() {
err := mach.CryptoStore.PutAccount(mach.account)
func (mach *OlmMachine) saveAccount(ctx context.Context) {
err := mach.CryptoStore.PutAccount(ctx, mach.account)
if err != nil {
mach.Log.Error().Err(err).Msg("Failed to save account")
}
}
// FlushStore calls the Flush method of the CryptoStore.
func (mach *OlmMachine) FlushStore() error {
return mach.CryptoStore.Flush()
func (mach *OlmMachine) FlushStore(ctx context.Context) error {
return mach.CryptoStore.Flush(ctx)
}
func (mach *OlmMachine) timeTrace(ctx context.Context, thing string, expectedDuration time.Duration) func() {
@@ -194,9 +197,9 @@ func (mach *OlmMachine) OwnIdentity() *id.Device {
}
type asEventProcessor interface {
On(evtType event.Type, handler func(evt *event.Event))
OnOTK(func(otk *mautrix.OTKCount))
OnDeviceList(func(lists *mautrix.DeviceLists, since string))
On(evtType event.Type, handler func(ctx context.Context, evt *event.Event))
OnOTK(func(ctx context.Context, otk *mautrix.OTKCount))
OnDeviceList(func(ctx context.Context, lists *mautrix.DeviceLists, since string))
}
func (mach *OlmMachine) AddAppserviceListener(ep asEventProcessor) {
@@ -217,19 +220,23 @@ func (mach *OlmMachine) AddAppserviceListener(ep asEventProcessor) {
mach.Log.Debug().Msg("Added listeners for encryption data coming from appservice transactions")
}
func (mach *OlmMachine) HandleDeviceLists(dl *mautrix.DeviceLists, since string) {
func (mach *OlmMachine) HandleDeviceLists(ctx context.Context, dl *mautrix.DeviceLists, since string) {
if len(dl.Changed) > 0 {
traceID := time.Now().Format("15:04:05.000000")
mach.Log.Debug().
Str("trace_id", traceID).
Interface("changes", dl.Changed).
Msg("Device list changes in /sync")
mach.fetchKeys(context.TODO(), dl.Changed, since, false)
if mach.DisableKeyFetching {
mach.CryptoStore.MarkTrackedUsersOutdated(ctx, dl.Changed)
} else {
mach.FetchKeys(ctx, dl.Changed, false)
}
mach.Log.Debug().Str("trace_id", traceID).Msg("Finished handling device list changes")
}
}
func (mach *OlmMachine) HandleOTKCounts(otkCount *mautrix.OTKCount) {
func (mach *OlmMachine) HandleOTKCounts(ctx context.Context, otkCount *mautrix.OTKCount) {
if (len(otkCount.UserID) > 0 && otkCount.UserID != mach.Client.UserID) || (len(otkCount.DeviceID) > 0 && otkCount.DeviceID != mach.Client.DeviceID) {
// TODO This log probably needs to be silence-able if someone wants to use encrypted appservices with multiple e2ee sessions
mach.Log.Warn().
@@ -243,7 +250,7 @@ func (mach *OlmMachine) HandleOTKCounts(otkCount *mautrix.OTKCount) {
if otkCount.SignedCurve25519 < int(minCount) {
traceID := time.Now().Format("15:04:05.000000")
log := mach.Log.With().Str("trace_id", traceID).Logger()
ctx := log.WithContext(context.Background())
ctx = log.WithContext(ctx)
log.Debug().
Int("keys_left", otkCount.Curve25519).
Msg("Sync response said we have less than 50 signed curve25519 keys left, sharing new ones...")
@@ -261,8 +268,8 @@ func (mach *OlmMachine) HandleOTKCounts(otkCount *mautrix.OTKCount) {
// This can be easily registered into a mautrix client using .OnSync():
//
// client.Syncer.(mautrix.ExtensibleSyncer).OnSync(c.crypto.ProcessSyncResponse)
func (mach *OlmMachine) ProcessSyncResponse(resp *mautrix.RespSync, since string) bool {
mach.HandleDeviceLists(&resp.DeviceLists, since)
func (mach *OlmMachine) ProcessSyncResponse(ctx context.Context, resp *mautrix.RespSync, since string) bool {
mach.HandleDeviceLists(ctx, &resp.DeviceLists, since)
for _, evt := range resp.ToDevice.Events {
evt.Type.Class = event.ToDeviceEventType
@@ -271,10 +278,10 @@ func (mach *OlmMachine) ProcessSyncResponse(resp *mautrix.RespSync, since string
mach.Log.Warn().Str("event_type", evt.Type.Type).Err(err).Msg("Failed to parse to-device event")
continue
}
mach.HandleToDeviceEvent(evt)
mach.HandleToDeviceEvent(ctx, evt)
}
mach.HandleOTKCounts(&resp.DeviceOTKCount)
mach.HandleOTKCounts(ctx, &resp.DeviceOTKCount)
return true
}
@@ -283,8 +290,12 @@ func (mach *OlmMachine) ProcessSyncResponse(resp *mautrix.RespSync, since string
// Currently this is not automatically called, so you must add a listener yourself:
//
// client.Syncer.(mautrix.ExtensibleSyncer).OnEventType(event.StateMember, c.crypto.HandleMemberEvent)
func (mach *OlmMachine) HandleMemberEvent(_ mautrix.EventSource, evt *event.Event) {
if !mach.StateStore.IsEncrypted(evt.RoomID) {
func (mach *OlmMachine) HandleMemberEvent(ctx context.Context, evt *event.Event) {
if isEncrypted, err := mach.StateStore.IsEncrypted(ctx, evt.RoomID); err != nil {
mach.machOrContextLog(ctx).Err(err).Stringer("room_id", evt.RoomID).
Msg("Failed to check if room is encrypted to handle member event")
return
} else if !isEncrypted {
return
}
content := evt.Content.AsMember()
@@ -311,7 +322,7 @@ func (mach *OlmMachine) HandleMemberEvent(_ mautrix.EventSource, evt *event.Even
Str("prev_membership", string(prevContent.Membership)).
Str("new_membership", string(content.Membership)).
Msg("Got membership state change, invalidating group session in room")
err := mach.CryptoStore.RemoveOutboundGroupSession(evt.RoomID)
err := mach.CryptoStore.RemoveOutboundGroupSession(ctx, evt.RoomID)
if err != nil {
mach.Log.Warn().Str("room_id", evt.RoomID.String()).Msg("Failed to invalidate outbound group session")
}
@@ -319,7 +330,7 @@ func (mach *OlmMachine) HandleMemberEvent(_ mautrix.EventSource, evt *event.Even
// HandleToDeviceEvent handles a single to-device event. This is automatically called by ProcessSyncResponse, so you
// don't need to add any custom handlers if you use that method.
func (mach *OlmMachine) HandleToDeviceEvent(evt *event.Event) {
func (mach *OlmMachine) HandleToDeviceEvent(ctx context.Context, evt *event.Event) {
if len(evt.ToUserID) > 0 && (evt.ToUserID != mach.Client.UserID || evt.ToDeviceID != mach.Client.DeviceID) {
// TODO This log probably needs to be silence-able if someone wants to use encrypted appservices with multiple e2ee sessions
mach.Log.Debug().
@@ -329,12 +340,13 @@ func (mach *OlmMachine) HandleToDeviceEvent(evt *event.Event) {
return
}
traceID := time.Now().Format("15:04:05.000000")
// TODO use context log?
log := mach.Log.With().
Str("trace_id", traceID).
Str("sender", evt.Sender.String()).
Str("type", evt.Type.Type).
Logger()
ctx := log.WithContext(context.Background())
ctx = log.WithContext(ctx)
if evt.Type != event.ToDeviceEncrypted {
log.Debug().Msg("Starting handling to-device event")
}
@@ -344,7 +356,7 @@ func (mach *OlmMachine) HandleToDeviceEvent(evt *event.Event) {
Str("sender_key", content.SenderKey.String()).
Logger()
log.Debug().Msg("Handling encrypted to-device event")
ctx = log.WithContext(context.Background())
ctx = log.WithContext(ctx)
decryptedEvt, err := mach.decryptOlmEvent(ctx, evt)
if err != nil {
log.Error().Err(err).Msg("Failed to decrypt to-device event")
@@ -381,17 +393,17 @@ func (mach *OlmMachine) HandleToDeviceEvent(evt *event.Event) {
mach.handleBeeperRoomKeyAck(ctx, evt.Sender, content)
// verification cases
case *event.VerificationStartEventContent:
mach.handleVerificationStart(evt.Sender, content, content.TransactionID, 10*time.Minute, "")
mach.handleVerificationStart(ctx, evt.Sender, content, content.TransactionID, 10*time.Minute, "")
case *event.VerificationAcceptEventContent:
mach.handleVerificationAccept(evt.Sender, content, content.TransactionID)
mach.handleVerificationAccept(ctx, evt.Sender, content, content.TransactionID)
case *event.VerificationKeyEventContent:
mach.handleVerificationKey(evt.Sender, content, content.TransactionID)
mach.handleVerificationKey(ctx, evt.Sender, content, content.TransactionID)
case *event.VerificationMacEventContent:
mach.handleVerificationMAC(evt.Sender, content, content.TransactionID)
mach.handleVerificationMAC(ctx, evt.Sender, content, content.TransactionID)
case *event.VerificationCancelEventContent:
mach.handleVerificationCancel(evt.Sender, content, content.TransactionID)
case *event.VerificationRequestEventContent:
mach.handleVerificationRequest(evt.Sender, content, content.TransactionID, "")
mach.handleVerificationRequest(ctx, evt.Sender, content, content.TransactionID, "")
case *event.RoomKeyWithheldEventContent:
mach.handleRoomKeyWithheld(ctx, content)
default:
@@ -405,14 +417,15 @@ func (mach *OlmMachine) HandleToDeviceEvent(evt *event.Event) {
// GetOrFetchDevice attempts to retrieve the device identity for the given device from the store
// and if it's not found it asks the server for it.
func (mach *OlmMachine) GetOrFetchDevice(ctx context.Context, userID id.UserID, deviceID id.DeviceID) (*id.Device, error) {
device, err := mach.CryptoStore.GetDevice(userID, deviceID)
device, err := mach.CryptoStore.GetDevice(ctx, userID, deviceID)
if err != nil {
return nil, fmt.Errorf("failed to get sender device from store: %w", err)
} else if device != nil {
} else if device != nil || mach.DisableKeyFetching {
return device, nil
}
usersToDevices := mach.fetchKeys(ctx, []id.UserID{userID}, "", true)
if devices, ok := usersToDevices[userID]; ok {
if usersToDevices, err := mach.FetchKeys(ctx, []id.UserID{userID}, true); err != nil {
return nil, fmt.Errorf("failed to fetch keys: %w", err)
} else if devices, ok := usersToDevices[userID]; ok {
if device, ok = devices[deviceID]; ok {
return device, nil
}
@@ -425,15 +438,15 @@ func (mach *OlmMachine) GetOrFetchDevice(ctx context.Context, userID id.UserID,
// store and if it's not found it asks the server for it. This returns nil if the server doesn't return a device with
// the given identity key.
func (mach *OlmMachine) GetOrFetchDeviceByKey(ctx context.Context, userID id.UserID, identityKey id.IdentityKey) (*id.Device, error) {
deviceIdentity, err := mach.CryptoStore.FindDeviceByKey(userID, identityKey)
if err != nil || deviceIdentity != nil {
deviceIdentity, err := mach.CryptoStore.FindDeviceByKey(ctx, userID, identityKey)
if err != nil || deviceIdentity != nil || mach.DisableKeyFetching {
return deviceIdentity, err
}
mach.machOrContextLog(ctx).Debug().
Str("user_id", userID.String()).
Str("identity_key", identityKey.String()).
Msg("Didn't find identity in crypto store, fetching from server")
devices := mach.LoadDevices(userID)
devices := mach.LoadDevices(ctx, userID)
for _, device := range devices {
if device.IdentityKey == identityKey {
return device, nil
@@ -455,7 +468,7 @@ func (mach *OlmMachine) SendEncryptedToDevice(ctx context.Context, device *id.De
mach.olmLock.Lock()
defer mach.olmLock.Unlock()
olmSess, err := mach.CryptoStore.GetLatestSession(device.IdentityKey)
olmSess, err := mach.CryptoStore.GetLatestSession(ctx, device.IdentityKey)
if err != nil {
return err
}
@@ -473,7 +486,7 @@ func (mach *OlmMachine) SendEncryptedToDevice(ctx context.Context, device *id.De
Str("to_identity_key", device.IdentityKey.String()).
Str("olm_session_id", olmSess.ID().String()).
Msg("Sending encrypted to-device event")
_, err = mach.Client.SendToDevice(event.ToDeviceEncrypted,
_, err = mach.Client.SendToDevice(ctx, event.ToDeviceEncrypted,
&mautrix.ReqSendToDevice{
Messages: map[id.UserID]map[id.DeviceID]*event.Content{
device.UserID: {
@@ -499,7 +512,7 @@ func (mach *OlmMachine) createGroupSession(ctx context.Context, senderKey id.Sen
Msg("Mismatched session ID while creating inbound group session")
return
}
err = mach.CryptoStore.PutGroupSession(roomID, senderKey, sessionID, igs)
err = mach.CryptoStore.PutGroupSession(ctx, roomID, senderKey, sessionID, igs)
if err != nil {
log.Error().Err(err).Str("session_id", sessionID.String()).Msg("Failed to store new inbound group session")
return
@@ -525,7 +538,7 @@ func (mach *OlmMachine) markSessionReceived(id id.SessionID) {
}
// WaitForSession waits for the given Megolm session to arrive.
func (mach *OlmMachine) WaitForSession(roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID, timeout time.Duration) bool {
func (mach *OlmMachine) WaitForSession(ctx context.Context, roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID, timeout time.Duration) bool {
mach.keyWaitersLock.Lock()
ch, ok := mach.keyWaiters[sessionID]
if !ok {
@@ -534,7 +547,7 @@ func (mach *OlmMachine) WaitForSession(roomID id.RoomID, senderKey id.SenderKey,
}
mach.keyWaitersLock.Unlock()
// Handle race conditions where a session appears between the failed decryption and WaitForSession call.
sess, err := mach.CryptoStore.GetGroupSession(roomID, senderKey, sessionID)
sess, err := mach.CryptoStore.GetGroupSession(ctx, roomID, senderKey, sessionID)
if sess != nil || errors.Is(err, ErrGroupSessionWithheld) {
return true
}
@@ -542,10 +555,12 @@ func (mach *OlmMachine) WaitForSession(roomID id.RoomID, senderKey id.SenderKey,
case <-ch:
return true
case <-time.After(timeout):
sess, err = mach.CryptoStore.GetGroupSession(roomID, senderKey, sessionID)
sess, err = mach.CryptoStore.GetGroupSession(ctx, roomID, senderKey, sessionID)
// Check if the session somehow appeared in the store without telling us
// We accept withheld sessions as received, as then the decryption attempt will show the error.
return sess != nil || errors.Is(err, ErrGroupSessionWithheld)
case <-ctx.Done():
return false
}
}
@@ -568,7 +583,10 @@ func (mach *OlmMachine) receiveRoomKey(ctx context.Context, evt *DecryptedOlmEve
return
}
config := mach.StateStore.GetEncryptionEvent(content.RoomID)
config, err := mach.StateStore.GetEncryptionEvent(ctx, content.RoomID)
if err != nil {
log.Error().Err(err).Msg("Failed to get encryption event for room")
}
var maxAge time.Duration
var maxMessages int
if config != nil {
@@ -589,7 +607,7 @@ func (mach *OlmMachine) receiveRoomKey(ctx context.Context, evt *DecryptedOlmEve
}
if mach.DeletePreviousKeysOnReceive && !content.IsScheduled {
log.Debug().Msg("Redacting previous megolm sessions from sender in room")
sessionIDs, err := mach.CryptoStore.RedactGroupSessions(content.RoomID, evt.SenderKey, "received new key from device")
sessionIDs, err := mach.CryptoStore.RedactGroupSessions(ctx, content.RoomID, evt.SenderKey, "received new key from device")
if err != nil {
log.Err(err).Msg("Failed to redact previous megolm sessions")
} else {
@@ -606,7 +624,7 @@ func (mach *OlmMachine) handleRoomKeyWithheld(ctx context.Context, content *even
zerolog.Ctx(ctx).Debug().Interface("content", content).Msg("Non-megolm room key withheld event")
return
}
err := mach.CryptoStore.PutWithheldGroupSession(*content)
err := mach.CryptoStore.PutWithheldGroupSession(ctx, *content)
if err != nil {
zerolog.Ctx(ctx).Error().Err(err).Msg("Failed to save room key withheld event")
}
@@ -624,7 +642,7 @@ func (mach *OlmMachine) ShareKeys(ctx context.Context, currentOTKCount int) erro
defer mach.otkUploadLock.Unlock()
if mach.lastOTKUpload.Add(1*time.Minute).After(start) || currentOTKCount < 0 {
log.Debug().Msg("Checking OTK count from server due to suspiciously close share keys requests or negative OTK count")
resp, err := mach.Client.UploadKeys(&mautrix.ReqUploadKeys{})
resp, err := mach.Client.UploadKeys(ctx, &mautrix.ReqUploadKeys{})
if err != nil {
return fmt.Errorf("failed to check current OTK counts: %w", err)
}
@@ -637,6 +655,15 @@ func (mach *OlmMachine) ShareKeys(ctx context.Context, currentOTKCount int) erro
var deviceKeys *mautrix.DeviceKeys
if !mach.account.Shared {
deviceKeys = mach.account.getInitialKeys(mach.Client.UserID, mach.Client.DeviceID)
err := mach.CryptoStore.PutDevice(ctx, mach.Client.UserID, &id.Device{
UserID: mach.Client.UserID,
DeviceID: mach.Client.DeviceID,
IdentityKey: deviceKeys.Keys.GetCurve25519(mach.Client.DeviceID),
SigningKey: deviceKeys.Keys.GetEd25519(mach.Client.DeviceID),
})
if err != nil {
return fmt.Errorf("failed to save initial keys: %w", err)
}
log.Debug().Msg("Going to upload initial account keys")
}
oneTimeKeys := mach.account.getOneTimeKeys(mach.Client.UserID, mach.Client.DeviceID, currentOTKCount)
@@ -649,20 +676,20 @@ func (mach *OlmMachine) ShareKeys(ctx context.Context, currentOTKCount int) erro
OneTimeKeys: oneTimeKeys,
}
log.Debug().Int("count", len(oneTimeKeys)).Msg("Uploading one-time keys")
_, err := mach.Client.UploadKeys(req)
_, err := mach.Client.UploadKeys(ctx, req)
if err != nil {
return err
}
mach.lastOTKUpload = time.Now()
mach.account.Shared = true
mach.saveAccount()
mach.saveAccount(ctx)
return nil
}
func (mach *OlmMachine) ExpiredKeyDeleteLoop(ctx context.Context) {
log := mach.Log.With().Str("action", "redact expired sessions").Logger()
for {
sessionIDs, err := mach.CryptoStore.RedactExpiredGroupSessions()
sessionIDs, err := mach.CryptoStore.RedactExpiredGroupSessions(ctx)
if err != nil {
log.Err(err).Msg("Failed to redact expired megolm sessions")
} else if len(sessionIDs) > 0 {

View File

@@ -1,177 +0,0 @@
Apache License
Version 2.0, January 2004
http://www.apache.org/licenses/
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
1. Definitions.
"License" shall mean the terms and conditions for use, reproduction,
and distribution as defined by Sections 1 through 9 of this document.
"Licensor" shall mean the copyright owner or entity authorized by
the copyright owner that is granting the License.
"Legal Entity" shall mean the union of the acting entity and all
other entities that control, are controlled by, or are under common
control with that entity. For the purposes of this definition,
"control" means (i) the power, direct or indirect, to cause the
direction or management of such entity, whether by contract or
otherwise, or (ii) ownership of fifty percent (50%) or more of the
outstanding shares, or (iii) beneficial ownership of such entity.
"You" (or "Your") shall mean an individual or Legal Entity
exercising permissions granted by this License.
"Source" form shall mean the preferred form for making modifications,
including but not limited to software source code, documentation
source, and configuration files.
"Object" form shall mean any form resulting from mechanical
transformation or translation of a Source form, including but
not limited to compiled object code, generated documentation,
and conversions to other media types.
"Work" shall mean the work of authorship, whether in Source or
Object form, made available under the License, as indicated by a
copyright notice that is included in or attached to the work
(an example is provided in the Appendix below).
"Derivative Works" shall mean any work, whether in Source or Object
form, that is based on (or derived from) the Work and for which the
editorial revisions, annotations, elaborations, or other modifications
represent, as a whole, an original work of authorship. For the purposes
of this License, Derivative Works shall not include works that remain
separable from, or merely link (or bind by name) to the interfaces of,
the Work and Derivative Works thereof.
"Contribution" shall mean any work of authorship, including
the original version of the Work and any modifications or additions
to that Work or Derivative Works thereof, that is intentionally
submitted to Licensor for inclusion in the Work by the copyright owner
or by an individual or Legal Entity authorized to submit on behalf of
the copyright owner. For the purposes of this definition, "submitted"
means any form of electronic, verbal, or written communication sent
to the Licensor or its representatives, including but not limited to
communication on electronic mailing lists, source code control systems,
and issue tracking systems that are managed by, or on behalf of, the
Licensor for the purpose of discussing and improving the Work, but
excluding communication that is conspicuously marked or otherwise
designated in writing by the copyright owner as "Not a Contribution."
"Contributor" shall mean Licensor and any individual or Legal Entity
on behalf of whom a Contribution has been received by Licensor and
subsequently incorporated within the Work.
2. Grant of Copyright License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
copyright license to reproduce, prepare Derivative Works of,
publicly display, publicly perform, sublicense, and distribute the
Work and such Derivative Works in Source or Object form.
3. Grant of Patent License. Subject to the terms and conditions of
this License, each Contributor hereby grants to You a perpetual,
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
(except as stated in this section) patent license to make, have made,
use, offer to sell, sell, import, and otherwise transfer the Work,
where such license applies only to those patent claims licensable
by such Contributor that are necessarily infringed by their
Contribution(s) alone or by combination of their Contribution(s)
with the Work to which such Contribution(s) was submitted. If You
institute patent litigation against any entity (including a
cross-claim or counterclaim in a lawsuit) alleging that the Work
or a Contribution incorporated within the Work constitutes direct
or contributory patent infringement, then any patent licenses
granted to You under this License for that Work shall terminate
as of the date such litigation is filed.
4. Redistribution. You may reproduce and distribute copies of the
Work or Derivative Works thereof in any medium, with or without
modifications, and in Source or Object form, provided that You
meet the following conditions:
(a) You must give any other recipients of the Work or
Derivative Works a copy of this License; and
(b) You must cause any modified files to carry prominent notices
stating that You changed the files; and
(c) You must retain, in the Source form of any Derivative Works
that You distribute, all copyright, patent, trademark, and
attribution notices from the Source form of the Work,
excluding those notices that do not pertain to any part of
the Derivative Works; and
(d) If the Work includes a "NOTICE" text file as part of its
distribution, then any Derivative Works that You distribute must
include a readable copy of the attribution notices contained
within such NOTICE file, excluding those notices that do not
pertain to any part of the Derivative Works, in at least one
of the following places: within a NOTICE text file distributed
as part of the Derivative Works; within the Source form or
documentation, if provided along with the Derivative Works; or,
within a display generated by the Derivative Works, if and
wherever such third-party notices normally appear. The contents
of the NOTICE file are for informational purposes only and
do not modify the License. You may add Your own attribution
notices within Derivative Works that You distribute, alongside
or as an addendum to the NOTICE text from the Work, provided
that such additional attribution notices cannot be construed
as modifying the License.
You may add Your own copyright statement to Your modifications and
may provide additional or different license terms and conditions
for use, reproduction, or distribution of Your modifications, or
for any such Derivative Works as a whole, provided Your use,
reproduction, and distribution of the Work otherwise complies with
the conditions stated in this License.
5. Submission of Contributions. Unless You explicitly state otherwise,
any Contribution intentionally submitted for inclusion in the Work
by You to the Licensor shall be under the terms and conditions of
this License, without any additional terms or conditions.
Notwithstanding the above, nothing herein shall supersede or modify
the terms of any separate license agreement you may have executed
with Licensor regarding such Contributions.
6. Trademarks. This License does not grant permission to use the trade
names, trademarks, service marks, or product names of the Licensor,
except as required for reasonable and customary use in describing the
origin of the Work and reproducing the content of the NOTICE file.
7. Disclaimer of Warranty. Unless required by applicable law or
agreed to in writing, Licensor provides the Work (and each
Contributor provides its Contributions) on an "AS IS" BASIS,
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
implied, including, without limitation, any warranties or conditions
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
PARTICULAR PURPOSE. You are solely responsible for determining the
appropriateness of using or redistributing the Work and assume any
risks associated with Your exercise of permissions under this License.
8. Limitation of Liability. In no event and under no legal theory,
whether in tort (including negligence), contract, or otherwise,
unless required by applicable law (such as deliberate and grossly
negligent acts) or agreed to in writing, shall any Contributor be
liable to You for damages, including any direct, indirect, special,
incidental, or consequential damages of any character arising as a
result of this License or out of the use or inability to use the
Work (including but not limited to damages for loss of goodwill,
work stoppage, computer failure or malfunction, or any and all
other commercial damages or losses), even if such Contributor
has been advised of the possibility of such damages.
9. Accepting Warranty or Additional Liability. While redistributing
the Work or Derivative Works thereof, You may choose to offer,
and charge a fee for, acceptance of support, warranty, indemnity,
or other liability obligations and/or rights consistent with this
License. However, in accepting such obligations, You may act only
on Your own behalf and on Your sole responsibility, not on behalf
of any other Contributor, and only if You agree to indemnify,
defend, and hold each Contributor harmless for any liability
incurred by, or claims asserted against, such Contributor by reason
of your accepting any such warranty or additional liability.
END OF TERMS AND CONDITIONS

View File

@@ -1,2 +1,4 @@
# Go olm bindings
Based on [Dhole/go-olm](https://github.com/Dhole/go-olm)
The original project is licensed under the Apache 2.0 license.

View File

@@ -1,3 +1,5 @@
//go:build !goolm
package olm
// #cgo LDFLAGS: -lolm -lstdc++
@@ -155,6 +157,7 @@ func (a *Account) Unpickle(pickled, key []byte) error {
return nil
}
// Deprecated
func (a *Account) GobEncode() ([]byte, error) {
pickled := a.Pickle(pickleKey)
length := base64.RawStdEncoding.DecodedLen(len(pickled))
@@ -163,6 +166,7 @@ func (a *Account) GobEncode() ([]byte, error) {
return rawPickled, err
}
// Deprecated
func (a *Account) GobDecode(rawPickled []byte) error {
if a.int == nil {
*a = *NewBlankAccount()
@@ -173,6 +177,7 @@ func (a *Account) GobDecode(rawPickled []byte) error {
return a.Unpickle(pickled, pickleKey)
}
// Deprecated
func (a *Account) MarshalJSON() ([]byte, error) {
pickled := a.Pickle(pickleKey)
quotes := make([]byte, len(pickled)+2)
@@ -182,6 +187,7 @@ func (a *Account) MarshalJSON() ([]byte, error) {
return quotes, nil
}
// Deprecated
func (a *Account) UnmarshalJSON(data []byte) error {
if len(data) == 0 || data[0] != '"' || data[len(data)-1] != '"' {
return InputNotJSONString

View File

@@ -0,0 +1,154 @@
//go:build goolm
package olm
import (
"encoding/json"
"github.com/tidwall/sjson"
"maunium.net/go/mautrix/crypto/canonicaljson"
"maunium.net/go/mautrix/crypto/goolm/account"
"maunium.net/go/mautrix/id"
)
// Account stores a device account for end to end encrypted messaging.
type Account struct {
account.Account
}
// NewAccount creates a new Account.
func NewAccount() *Account {
a, err := account.NewAccount(nil)
if err != nil {
panic(err)
}
ac := &Account{}
ac.Account = *a
return ac
}
func NewBlankAccount() *Account {
return &Account{}
}
// Clear clears the memory used to back this Account.
func (a *Account) Clear() error {
a.Account = account.Account{}
return nil
}
// Pickle returns an Account as a base64 string. Encrypts the Account using the
// supplied key.
func (a *Account) Pickle(key []byte) []byte {
if len(key) == 0 {
panic(NoKeyProvided)
}
pickled, err := a.Account.Pickle(key)
if err != nil {
panic(err)
}
return pickled
}
// IdentityKeysJSON returns the public parts of the identity keys for the Account.
func (a *Account) IdentityKeysJSON() []byte {
identityKeys, err := a.Account.IdentityKeysJSON()
if err != nil {
panic(err)
}
return identityKeys
}
// Sign returns the signature of a message using the ed25519 key for this
// Account.
func (a *Account) Sign(message []byte) []byte {
if len(message) == 0 {
panic(EmptyInput)
}
signature, err := a.Account.Sign(message)
if err != nil {
panic(err)
}
return signature
}
// SignJSON signs the given JSON object following the Matrix specification:
// https://matrix.org/docs/spec/appendices#signing-json
func (a *Account) SignJSON(obj interface{}) (string, error) {
objJSON, err := json.Marshal(obj)
if err != nil {
return "", err
}
objJSON, _ = sjson.DeleteBytes(objJSON, "unsigned")
objJSON, _ = sjson.DeleteBytes(objJSON, "signatures")
return string(a.Sign(canonicaljson.CanonicalJSONAssumeValid(objJSON))), nil
}
// MaxNumberOfOneTimeKeys returns the largest number of one time keys this
// Account can store.
func (a *Account) MaxNumberOfOneTimeKeys() uint {
return uint(account.MaxOneTimeKeys)
}
// GenOneTimeKeys generates a number of new one time keys. If the total number
// of keys stored by this Account exceeds MaxNumberOfOneTimeKeys then the old
// keys are discarded.
func (a *Account) GenOneTimeKeys(num uint) {
err := a.Account.GenOneTimeKeys(nil, num)
if err != nil {
panic(err)
}
}
// NewOutboundSession creates a new out-bound session for sending messages to a
// given curve25519 identityKey and oneTimeKey. Returns error on failure.
func (a *Account) NewOutboundSession(theirIdentityKey, theirOneTimeKey id.Curve25519) (*Session, error) {
if len(theirIdentityKey) == 0 || len(theirOneTimeKey) == 0 {
return nil, EmptyInput
}
s := &Session{}
newSession, err := a.Account.NewOutboundSession(theirIdentityKey, theirOneTimeKey)
if err != nil {
return nil, err
}
s.OlmSession = *newSession
return s, nil
}
// NewInboundSession creates a new in-bound session for sending/receiving
// messages from an incoming PRE_KEY message. Returns error on failure.
func (a *Account) NewInboundSession(oneTimeKeyMsg string) (*Session, error) {
if len(oneTimeKeyMsg) == 0 {
return nil, EmptyInput
}
s := &Session{}
newSession, err := a.Account.NewInboundSession(nil, []byte(oneTimeKeyMsg))
if err != nil {
return nil, err
}
s.OlmSession = *newSession
return s, nil
}
// NewInboundSessionFrom creates a new in-bound session for sending/receiving
// messages from an incoming PRE_KEY message. Returns error on failure.
func (a *Account) NewInboundSessionFrom(theirIdentityKey id.Curve25519, oneTimeKeyMsg string) (*Session, error) {
if len(theirIdentityKey) == 0 || len(oneTimeKeyMsg) == 0 {
return nil, EmptyInput
}
s := &Session{}
newSession, err := a.Account.NewInboundSession(&theirIdentityKey, []byte(oneTimeKeyMsg))
if err != nil {
return nil, err
}
s.OlmSession = *newSession
return s, nil
}
// RemoveOneTimeKeys removes the one time keys that the session used from the
// Account. Returns error on failure.
func (a *Account) RemoveOneTimeKeys(s *Session) error {
a.Account.RemoveOneTimeKeys(&s.OlmSession)
return nil
}

View File

@@ -1,3 +1,5 @@
//go:build !goolm
package olm
import (

View File

@@ -0,0 +1,23 @@
//go:build goolm
package olm
import (
"errors"
"maunium.net/go/mautrix/crypto/goolm"
)
// Error codes from go-olm
var (
EmptyInput = goolm.ErrEmptyInput
NoKeyProvided = goolm.ErrNoKeyProvided
NotEnoughGoRandom = errors.New("couldn't get enough randomness from crypto/rand")
SignatureNotFound = errors.New("input JSON doesn't contain signature from specified device")
InputNotJSONString = errors.New("input doesn't look like a JSON string")
)
// Error codes from olm code
var (
UnknownMessageIndex = goolm.ErrRatchetNotAvailable
)

View File

@@ -1,3 +1,5 @@
//go:build !goolm
package olm
// #cgo LDFLAGS: -lolm -lstdc++
@@ -147,6 +149,7 @@ func (s *InboundGroupSession) Unpickle(pickled, key []byte) error {
return nil
}
// Deprecated
func (s *InboundGroupSession) GobEncode() ([]byte, error) {
pickled := s.Pickle(pickleKey)
length := base64.RawStdEncoding.DecodedLen(len(pickled))
@@ -155,6 +158,7 @@ func (s *InboundGroupSession) GobEncode() ([]byte, error) {
return rawPickled, err
}
// Deprecated
func (s *InboundGroupSession) GobDecode(rawPickled []byte) error {
if s == nil || s.int == nil {
*s = *NewBlankInboundGroupSession()
@@ -165,6 +169,7 @@ func (s *InboundGroupSession) GobDecode(rawPickled []byte) error {
return s.Unpickle(pickled, pickleKey)
}
// Deprecated
func (s *InboundGroupSession) MarshalJSON() ([]byte, error) {
pickled := s.Pickle(pickleKey)
quotes := make([]byte, len(pickled)+2)
@@ -174,6 +179,7 @@ func (s *InboundGroupSession) MarshalJSON() ([]byte, error) {
return quotes, nil
}
// Deprecated
func (s *InboundGroupSession) UnmarshalJSON(data []byte) error {
if len(data) == 0 || data[0] != '"' || data[len(data)-1] != '"' {
return InputNotJSONString

View File

@@ -0,0 +1,149 @@
//go:build goolm
package olm
import (
"maunium.net/go/mautrix/crypto/goolm/session"
"maunium.net/go/mautrix/id"
)
// InboundGroupSession stores an inbound encrypted messaging session for a
// group.
type InboundGroupSession struct {
session.MegolmInboundSession
}
// InboundGroupSessionFromPickled loads an InboundGroupSession from a pickled
// base64 string. Decrypts the InboundGroupSession using the supplied key.
// Returns error on failure.
func InboundGroupSessionFromPickled(pickled, key []byte) (*InboundGroupSession, error) {
if len(pickled) == 0 {
return nil, EmptyInput
}
lenKey := len(key)
if lenKey == 0 {
key = []byte(" ")
}
megolmSession, err := session.MegolmInboundSessionFromPickled(pickled, key)
if err != nil {
return nil, err
}
return &InboundGroupSession{
MegolmInboundSession: *megolmSession,
}, nil
}
// NewInboundGroupSession creates a new inbound group session from a key
// exported from OutboundGroupSession.Key(). Returns error on failure.
func NewInboundGroupSession(sessionKey []byte) (*InboundGroupSession, error) {
if len(sessionKey) == 0 {
return nil, EmptyInput
}
megolmSession, err := session.NewMegolmInboundSession(sessionKey)
if err != nil {
return nil, err
}
return &InboundGroupSession{
MegolmInboundSession: *megolmSession,
}, nil
}
// InboundGroupSessionImport imports an inbound group session from a previous
// export. Returns error on failure.
func InboundGroupSessionImport(sessionKey []byte) (*InboundGroupSession, error) {
if len(sessionKey) == 0 {
return nil, EmptyInput
}
megolmSession, err := session.NewMegolmInboundSessionFromExport(sessionKey)
if err != nil {
return nil, err
}
return &InboundGroupSession{
MegolmInboundSession: *megolmSession,
}, nil
}
func NewBlankInboundGroupSession() *InboundGroupSession {
return &InboundGroupSession{}
}
// Clear clears the memory used to back this InboundGroupSession.
func (s *InboundGroupSession) Clear() error {
s.MegolmInboundSession = session.MegolmInboundSession{}
return nil
}
// Pickle returns an InboundGroupSession as a base64 string. Encrypts the
// InboundGroupSession using the supplied key.
func (s *InboundGroupSession) Pickle(key []byte) []byte {
if len(key) == 0 {
panic(NoKeyProvided)
}
pickled, err := s.MegolmInboundSession.Pickle(key)
if err != nil {
panic(err)
}
return pickled
}
func (s *InboundGroupSession) Unpickle(pickled, key []byte) error {
if len(key) == 0 {
return NoKeyProvided
} else if len(pickled) == 0 {
return EmptyInput
}
sOlm, err := session.MegolmInboundSessionFromPickled(pickled, key)
if err != nil {
return err
}
s.MegolmInboundSession = *sOlm
return nil
}
// Decrypt decrypts a message using the InboundGroupSession. Returns the the
// plain-text and message index on success. Returns error on failure.
func (s *InboundGroupSession) Decrypt(message []byte) ([]byte, uint, error) {
if len(message) == 0 {
return nil, 0, EmptyInput
}
plaintext, messageIndex, err := s.MegolmInboundSession.Decrypt(message)
if err != nil {
return nil, 0, err
}
return plaintext, uint(messageIndex), nil
}
// ID returns a base64-encoded identifier for this session.
func (s *InboundGroupSession) ID() id.SessionID {
return s.MegolmInboundSession.SessionID()
}
// FirstKnownIndex returns the first message index we know how to decrypt.
func (s *InboundGroupSession) FirstKnownIndex() uint32 {
return s.MegolmInboundSession.InitialRatchet.Counter
}
// IsVerified check if the session has been verified as a valid session. (A
// session is verified either because the original session share was signed, or
// because we have subsequently successfully decrypted a message.)
func (s *InboundGroupSession) IsVerified() uint {
if s.MegolmInboundSession.SigningKeyVerified {
return 1
}
return 0
}
// Export returns the base64-encoded ratchet key for this session, at the given
// index, in a format which can be used by
// InboundGroupSession.InboundGroupSessionImport(). Encrypts the
// InboundGroupSession using the supplied key. Returns error on failure.
// if we do not have a session key corresponding to the given index (ie, it was
// sent before the session key was shared with us) the error will be
// returned.
func (s *InboundGroupSession) Export(messageIndex uint32) ([]byte, error) {
res, err := s.MegolmInboundSession.SessionExportMessage(messageIndex)
if err != nil {
return nil, err
}
return res, nil
}

View File

@@ -1,3 +1,5 @@
//go:build !goolm
package olm
// #cgo LDFLAGS: -lolm -lstdc++

20
vendor/maunium.net/go/mautrix/crypto/olm/olm_goolm.go generated vendored Normal file
View File

@@ -0,0 +1,20 @@
//go:build goolm
package olm
import (
"maunium.net/go/mautrix/id"
)
// Signatures is the data structure used to sign JSON objects.
type Signatures map[id.UserID]map[id.DeviceKeyID]string
// Version returns the version number of the olm library.
func Version() (major, minor, patch uint8) {
return 3, 2, 15
}
// SetPickleKey sets the global pickle key used when encoding structs with Gob or JSON.
func SetPickleKey(key []byte) {
panic("gob and json encoding is deprecated and not supported with goolm")
}

View File

@@ -1,3 +1,5 @@
//go:build !goolm
package olm
// #cgo LDFLAGS: -lolm -lstdc++
@@ -122,6 +124,7 @@ func (s *OutboundGroupSession) Unpickle(pickled, key []byte) error {
return nil
}
// Deprecated
func (s *OutboundGroupSession) GobEncode() ([]byte, error) {
pickled := s.Pickle(pickleKey)
length := base64.RawStdEncoding.DecodedLen(len(pickled))
@@ -130,6 +133,7 @@ func (s *OutboundGroupSession) GobEncode() ([]byte, error) {
return rawPickled, err
}
// Deprecated
func (s *OutboundGroupSession) GobDecode(rawPickled []byte) error {
if s == nil || s.int == nil {
*s = *NewBlankOutboundGroupSession()
@@ -140,6 +144,7 @@ func (s *OutboundGroupSession) GobDecode(rawPickled []byte) error {
return s.Unpickle(pickled, pickleKey)
}
// Deprecated
func (s *OutboundGroupSession) MarshalJSON() ([]byte, error) {
pickled := s.Pickle(pickleKey)
quotes := make([]byte, len(pickled)+2)
@@ -149,6 +154,7 @@ func (s *OutboundGroupSession) MarshalJSON() ([]byte, error) {
return quotes, nil
}
// Deprecated
func (s *OutboundGroupSession) UnmarshalJSON(data []byte) error {
if len(data) == 0 || data[0] != '"' || data[len(data)-1] != '"' {
return InputNotJSONString

View File

@@ -0,0 +1,111 @@
//go:build goolm
package olm
import (
"maunium.net/go/mautrix/crypto/goolm/session"
"maunium.net/go/mautrix/id"
)
// OutboundGroupSession stores an outbound encrypted messaging session for a
// group.
type OutboundGroupSession struct {
session.MegolmOutboundSession
}
// OutboundGroupSessionFromPickled loads an OutboundGroupSession from a pickled
// base64 string. Decrypts the OutboundGroupSession using the supplied key.
// Returns error on failure. If the key doesn't match the one used to encrypt
// the OutboundGroupSession then the error will be "BAD_SESSION_KEY". If the
// base64 couldn't be decoded then the error will be "INVALID_BASE64".
func OutboundGroupSessionFromPickled(pickled, key []byte) (*OutboundGroupSession, error) {
if len(pickled) == 0 {
return nil, EmptyInput
}
lenKey := len(key)
if lenKey == 0 {
key = []byte(" ")
}
megolmSession, err := session.MegolmOutboundSessionFromPickled(pickled, key)
if err != nil {
return nil, err
}
return &OutboundGroupSession{
MegolmOutboundSession: *megolmSession,
}, nil
}
// NewOutboundGroupSession creates a new outbound group session.
func NewOutboundGroupSession() *OutboundGroupSession {
megolmSession, err := session.NewMegolmOutboundSession()
if err != nil {
panic(err)
}
return &OutboundGroupSession{
MegolmOutboundSession: *megolmSession,
}
}
// newOutboundGroupSession initialises an empty OutboundGroupSession.
func NewBlankOutboundGroupSession() *OutboundGroupSession {
return &OutboundGroupSession{}
}
// Clear clears the memory used to back this OutboundGroupSession.
func (s *OutboundGroupSession) Clear() error {
s.MegolmOutboundSession = session.MegolmOutboundSession{}
return nil
}
// Pickle returns an OutboundGroupSession as a base64 string. Encrypts the
// OutboundGroupSession using the supplied key.
func (s *OutboundGroupSession) Pickle(key []byte) []byte {
if len(key) == 0 {
panic(NoKeyProvided)
}
pickled, err := s.MegolmOutboundSession.Pickle(key)
if err != nil {
panic(err)
}
return pickled
}
func (s *OutboundGroupSession) Unpickle(pickled, key []byte) error {
if len(key) == 0 {
return NoKeyProvided
}
return s.MegolmOutboundSession.Unpickle(pickled, key)
}
// Encrypt encrypts a message using the Session. Returns the encrypted message
// as base64.
func (s *OutboundGroupSession) Encrypt(plaintext []byte) []byte {
if len(plaintext) == 0 {
panic(EmptyInput)
}
message, err := s.MegolmOutboundSession.Encrypt(plaintext)
if err != nil {
panic(err)
}
return message
}
// ID returns a base64-encoded identifier for this session.
func (s *OutboundGroupSession) ID() id.SessionID {
return s.MegolmOutboundSession.SessionID()
}
// MessageIndex returns the message index for this session. Each message is
// sent with an increasing index; this returns the index for the next message.
func (s *OutboundGroupSession) MessageIndex() uint {
return uint(s.MegolmOutboundSession.Ratchet.Counter)
}
// Key returns the base64-encoded current ratchet key for this session.
func (s *OutboundGroupSession) Key() string {
message, err := s.MegolmOutboundSession.SessionSharingMessage()
if err != nil {
panic(err)
}
return string(message)
}

View File

@@ -1,3 +1,5 @@
//go:build !goolm
package olm
// #cgo LDFLAGS: -lolm -lstdc++

71
vendor/maunium.net/go/mautrix/crypto/olm/pk_goolm.go generated vendored Normal file
View File

@@ -0,0 +1,71 @@
//go:build goolm
package olm
import (
"encoding/json"
"github.com/tidwall/sjson"
"maunium.net/go/mautrix/crypto/canonicaljson"
"maunium.net/go/mautrix/crypto/goolm/pk"
"maunium.net/go/mautrix/id"
)
// PkSigning stores a key pair for signing messages.
type PkSigning struct {
pk.Signing
PublicKey id.Ed25519
Seed []byte
}
// Clear clears the underlying memory of a PkSigning object.
func (p *PkSigning) Clear() {
p.Signing = pk.Signing{}
}
// NewPkSigningFromSeed creates a new PkSigning object using the given seed.
func NewPkSigningFromSeed(seed []byte) (*PkSigning, error) {
p := &PkSigning{}
signing, err := pk.NewSigningFromSeed(seed)
if err != nil {
return nil, err
}
p.Signing = *signing
p.Seed = seed
p.PublicKey = p.Signing.PublicKey()
return p, nil
}
// NewPkSigning creates a new PkSigning object, containing a key pair for signing messages.
func NewPkSigning() (*PkSigning, error) {
p := &PkSigning{}
signing, err := pk.NewSigning()
if err != nil {
return nil, err
}
p.Signing = *signing
p.Seed = signing.Seed
p.PublicKey = p.Signing.PublicKey()
return p, err
}
// Sign creates a signature for the given message using this key.
func (p *PkSigning) Sign(message []byte) ([]byte, error) {
return p.Signing.Sign(message), nil
}
// SignJSON creates a signature for the given object after encoding it to canonical JSON.
func (p *PkSigning) SignJSON(obj interface{}) (string, error) {
objJSON, err := json.Marshal(obj)
if err != nil {
return "", err
}
objJSON, _ = sjson.DeleteBytes(objJSON, "unsigned")
objJSON, _ = sjson.DeleteBytes(objJSON, "signatures")
signature, err := p.Sign(canonicaljson.CanonicalJSONAssumeValid(objJSON))
if err != nil {
return "", err
}
return string(signature), nil
}

View File

@@ -1,3 +1,5 @@
//go:build !goolm
package olm
// #cgo LDFLAGS: -lolm -lstdc++
@@ -155,6 +157,7 @@ func (s *Session) Unpickle(pickled, key []byte) error {
return nil
}
// Deprecated
func (s *Session) GobEncode() ([]byte, error) {
pickled := s.Pickle(pickleKey)
length := base64.RawStdEncoding.DecodedLen(len(pickled))
@@ -163,6 +166,7 @@ func (s *Session) GobEncode() ([]byte, error) {
return rawPickled, err
}
// Deprecated
func (s *Session) GobDecode(rawPickled []byte) error {
if s == nil || s.int == nil {
*s = *NewBlankSession()
@@ -173,6 +177,7 @@ func (s *Session) GobDecode(rawPickled []byte) error {
return s.Unpickle(pickled, pickleKey)
}
// Deprecated
func (s *Session) MarshalJSON() ([]byte, error) {
pickled := s.Pickle(pickleKey)
quotes := make([]byte, len(pickled)+2)
@@ -182,6 +187,7 @@ func (s *Session) MarshalJSON() ([]byte, error) {
return quotes, nil
}
// Deprecated
func (s *Session) UnmarshalJSON(data []byte) error {
if len(data) == 0 || len(data) == 0 || data[0] != '"' || data[len(data)-1] != '"' {
return InputNotJSONString

View File

@@ -0,0 +1,110 @@
//go:build goolm
package olm
import (
"maunium.net/go/mautrix/crypto/goolm/session"
"maunium.net/go/mautrix/id"
)
// Session stores an end to end encrypted messaging session.
type Session struct {
session.OlmSession
}
// SessionFromPickled loads a Session from a pickled base64 string. Decrypts
// the Session using the supplied key. Returns error on failure.
func SessionFromPickled(pickled, key []byte) (*Session, error) {
if len(pickled) == 0 {
return nil, EmptyInput
}
s := NewBlankSession()
return s, s.Unpickle(pickled, key)
}
func NewBlankSession() *Session {
return &Session{}
}
// Clear clears the memory used to back this Session.
func (s *Session) Clear() error {
s.OlmSession = session.OlmSession{}
return nil
}
// Pickle returns a Session as a base64 string. Encrypts the Session using the
// supplied key.
func (s *Session) Pickle(key []byte) []byte {
if len(key) == 0 {
panic(NoKeyProvided)
}
pickled, err := s.OlmSession.Pickle(key)
if err != nil {
panic(err)
}
return pickled
}
func (s *Session) Unpickle(pickled, key []byte) error {
if len(key) == 0 {
return NoKeyProvided
} else if len(pickled) == 0 {
return EmptyInput
}
sOlm, err := session.OlmSessionFromPickled(pickled, key)
if err != nil {
return err
}
s.OlmSession = *sOlm
return nil
}
// MatchesInboundSession checks if the PRE_KEY message is for this in-bound
// Session. This can happen if multiple messages are sent to this Account
// before this Account sends a message in reply. Returns true if the session
// matches. Returns false if the session does not match. Returns error on
// failure.
func (s *Session) MatchesInboundSession(oneTimeKeyMsg string) (bool, error) {
return s.MatchesInboundSessionFrom("", oneTimeKeyMsg)
}
// MatchesInboundSessionFrom checks if the PRE_KEY message is for this in-bound
// Session. This can happen if multiple messages are sent to this Account
// before this Account sends a message in reply. Returns true if the session
// matches. Returns false if the session does not match. Returns error on
// failure.
func (s *Session) MatchesInboundSessionFrom(theirIdentityKey, oneTimeKeyMsg string) (bool, error) {
if theirIdentityKey != "" {
theirKey := id.Curve25519(theirIdentityKey)
return s.OlmSession.MatchesInboundSessionFrom(&theirKey, []byte(oneTimeKeyMsg))
}
return s.OlmSession.MatchesInboundSessionFrom(nil, []byte(oneTimeKeyMsg))
}
// Encrypt encrypts a message using the Session. Returns the encrypted message
// as base64.
func (s *Session) Encrypt(plaintext []byte) (id.OlmMsgType, []byte) {
if len(plaintext) == 0 {
panic(EmptyInput)
}
messageType, message, err := s.OlmSession.Encrypt(plaintext, nil)
if err != nil {
panic(err)
}
return messageType, message
}
// Decrypt decrypts a message using the Session. Returns the the plain-text on
// success. Returns error on failure.
func (s *Session) Decrypt(message string, msgType id.OlmMsgType) ([]byte, error) {
if len(message) == 0 {
return nil, EmptyInput
}
return s.OlmSession.Decrypt([]byte(message), msgType)
}
// Describe generates a string describing the internal state of an olm session for debugging and logging purposes.
func (s *Session) Describe() string {
return s.OlmSession.Describe()
}

View File

@@ -1,3 +1,5 @@
//go:build !goolm
package olm
// #cgo LDFLAGS: -lolm -lstdc++

View File

@@ -0,0 +1,92 @@
//go:build goolm
package olm
import (
"crypto/sha256"
"encoding/base64"
"encoding/json"
"fmt"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
"go.mau.fi/util/exgjson"
"maunium.net/go/mautrix/crypto/canonicaljson"
"maunium.net/go/mautrix/crypto/goolm/utilities"
"maunium.net/go/mautrix/id"
)
// Utility stores the necessary state to perform hash and signature
// verification operations.
type Utility struct{}
// Clear clears the memory used to back this utility.
func (u *Utility) Clear() error {
return nil
}
// NewUtility creates a new utility.
func NewUtility() *Utility {
return &Utility{}
}
// Sha256 calculates the SHA-256 hash of the input and encodes it as base64.
func (u *Utility) Sha256(input string) string {
if len(input) == 0 {
panic(EmptyInput)
}
hash := sha256.Sum256([]byte(input))
return base64.RawStdEncoding.EncodeToString(hash[:])
}
// VerifySignature verifies an ed25519 signature. Returns true if the verification
// suceeds or false otherwise. Returns error on failure. If the key was too
// small then the error will be "INVALID_BASE64".
func (u *Utility) VerifySignature(message string, key id.Ed25519, signature string) (ok bool, err error) {
if len(message) == 0 || len(key) == 0 || len(signature) == 0 {
return false, EmptyInput
}
return utilities.VerifySignature([]byte(message), key, []byte(signature))
}
// VerifySignatureJSON verifies the signature in the JSON object _obj following
// the Matrix specification:
// https://matrix.org/speculator/spec/drafts%2Fe2e/appendices.html#signing-json
// If the _obj is a struct, the `json` tags will be honored.
func (u *Utility) VerifySignatureJSON(obj interface{}, userID id.UserID, keyName string, key id.Ed25519) (bool, error) {
var err error
objJSON, ok := obj.(json.RawMessage)
if !ok {
objJSON, err = json.Marshal(obj)
if err != nil {
return false, err
}
}
sig := gjson.GetBytes(objJSON, exgjson.Path("signatures", string(userID), fmt.Sprintf("ed25519:%s", keyName)))
if !sig.Exists() || sig.Type != gjson.String {
return false, SignatureNotFound
}
objJSON, err = sjson.DeleteBytes(objJSON, "unsigned")
if err != nil {
return false, err
}
objJSON, err = sjson.DeleteBytes(objJSON, "signatures")
if err != nil {
return false, err
}
objJSONString := string(canonicaljson.CanonicalJSONAssumeValid(objJSON))
return u.VerifySignature(objJSONString, key, sig.Str)
}
// VerifySignatureJSON verifies the signature in the JSON object _obj following
// the Matrix specification:
// https://matrix.org/speculator/spec/drafts%2Fe2e/appendices.html#signing-json
// This function is a wrapper over Utility.VerifySignatureJSON that creates and
// destroys the Utility object transparently.
// If the _obj is a struct, the `json` tags will be honored.
func VerifySignatureJSON(obj interface{}, userID id.UserID, keyName string, key id.Ed25519) (bool, error) {
u := NewUtility()
defer u.Clear()
return u.VerifySignatureJSON(obj, userID, keyName, key)
}

View File

@@ -1,3 +1,5 @@
//go:build !nosas && !goolm
package olm
// #cgo LDFLAGS: -lolm -lstdc++

View File

@@ -0,0 +1,23 @@
//go:build !nosas && goolm
package olm
import (
"maunium.net/go/mautrix/crypto/goolm/sas"
)
// SAS stores an Olm Short Authentication String (SAS) object.
type SAS struct {
sas.SAS
}
// NewSAS creates a new SAS object.
func NewSAS() *SAS {
newSAS, err := sas.New()
if err != nil {
panic(err)
}
return &SAS{
SAS: *newSAS,
}
}

View File

@@ -1,4 +1,4 @@
// Copyright (c) 2022 Tulir Asokan
// Copyright (c) 2024 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
@@ -27,7 +27,7 @@ import (
"maunium.net/go/mautrix/id"
)
var PostgresArrayWrapper func(interface{}) interface {
var PostgresArrayWrapper func(any) interface {
driver.Valuer
sql.Scanner
}
@@ -62,22 +62,22 @@ func NewSQLCryptoStore(db *dbutil.Database, log dbutil.DatabaseLogger, accountID
}
// Flush does nothing for this implementation as data is already persisted in the database.
func (store *SQLCryptoStore) Flush() error {
func (store *SQLCryptoStore) Flush(_ context.Context) error {
return nil
}
// PutNextBatch stores the next sync batch token for the current account.
func (store *SQLCryptoStore) PutNextBatch(nextBatch string) error {
func (store *SQLCryptoStore) PutNextBatch(ctx context.Context, nextBatch string) error {
store.SyncToken = nextBatch
_, err := store.DB.Exec(`UPDATE crypto_account SET sync_token=$1 WHERE account_id=$2`, store.SyncToken, store.AccountID)
_, err := store.DB.Exec(ctx, `UPDATE crypto_account SET sync_token=$1 WHERE account_id=$2`, store.SyncToken, store.AccountID)
return err
}
// GetNextBatch retrieves the next sync batch token for the current account.
func (store *SQLCryptoStore) GetNextBatch() (string, error) {
func (store *SQLCryptoStore) GetNextBatch(ctx context.Context) (string, error) {
if store.SyncToken == "" {
err := store.DB.
QueryRow("SELECT sync_token FROM crypto_account WHERE account_id=$1", store.AccountID).
err := store.DB.Conn(ctx).
QueryRowContext(ctx, "SELECT sync_token FROM crypto_account WHERE account_id=$1", store.AccountID).
Scan(&store.SyncToken)
if !errors.Is(err, sql.ErrNoRows) {
return "", err
@@ -88,38 +88,42 @@ func (store *SQLCryptoStore) GetNextBatch() (string, error) {
var _ mautrix.SyncStore = (*SQLCryptoStore)(nil)
func (store *SQLCryptoStore) SaveFilterID(_ id.UserID, _ string) {}
func (store *SQLCryptoStore) LoadFilterID(_ id.UserID) string { return "" }
func (store *SQLCryptoStore) SaveNextBatch(_ id.UserID, nextBatchToken string) {
err := store.PutNextBatch(nextBatchToken)
if err != nil {
// TODO handle error
}
func (store *SQLCryptoStore) SaveFilterID(ctx context.Context, _ id.UserID, _ string) error {
return nil
}
func (store *SQLCryptoStore) LoadFilterID(ctx context.Context, _ id.UserID) (string, error) {
return "", nil
}
func (store *SQLCryptoStore) LoadNextBatch(_ id.UserID) string {
nb, err := store.GetNextBatch()
func (store *SQLCryptoStore) SaveNextBatch(ctx context.Context, _ id.UserID, nextBatchToken string) error {
err := store.PutNextBatch(ctx, nextBatchToken)
if err != nil {
// TODO handle error
return fmt.Errorf("unable to store batch: %w", err)
}
return nb
return nil
}
func (store *SQLCryptoStore) FindDeviceID() (deviceID id.DeviceID) {
err := store.DB.QueryRow("SELECT device_id FROM crypto_account WHERE account_id=$1", store.AccountID).Scan(&deviceID)
if err != nil && !errors.Is(err, sql.ErrNoRows) {
// TODO return error
store.DB.Log.Warn("Failed to scan device ID: %v", err)
func (store *SQLCryptoStore) LoadNextBatch(ctx context.Context, _ id.UserID) (string, error) {
nb, err := store.GetNextBatch(ctx)
if err != nil {
return "", fmt.Errorf("unable to load batch: %w", err)
}
return nb, nil
}
func (store *SQLCryptoStore) FindDeviceID(ctx context.Context) (deviceID id.DeviceID, err error) {
err = store.DB.QueryRow(ctx, "SELECT device_id FROM crypto_account WHERE account_id=$1", store.AccountID).Scan(&deviceID)
if errors.Is(err, sql.ErrNoRows) {
err = nil
}
return
}
// PutAccount stores an OlmAccount in the database.
func (store *SQLCryptoStore) PutAccount(account *OlmAccount) error {
func (store *SQLCryptoStore) PutAccount(ctx context.Context, account *OlmAccount) error {
store.Account = account
bytes := account.Internal.Pickle(store.PickleKey)
_, err := store.DB.Exec(`
_, err := store.DB.Exec(ctx, `
INSERT INTO crypto_account (device_id, shared, sync_token, account, account_id) VALUES ($1, $2, $3, $4, $5)
ON CONFLICT (account_id) DO UPDATE SET shared=excluded.shared, sync_token=excluded.sync_token,
account=excluded.account, account_id=excluded.account_id
@@ -128,9 +132,9 @@ func (store *SQLCryptoStore) PutAccount(account *OlmAccount) error {
}
// GetAccount retrieves an OlmAccount from the database.
func (store *SQLCryptoStore) GetAccount() (*OlmAccount, error) {
func (store *SQLCryptoStore) GetAccount(ctx context.Context) (*OlmAccount, error) {
if store.Account == nil {
row := store.DB.QueryRow("SELECT shared, sync_token, account FROM crypto_account WHERE account_id=$1", store.AccountID)
row := store.DB.QueryRow(ctx, "SELECT shared, sync_token, account FROM crypto_account WHERE account_id=$1", store.AccountID)
acc := &OlmAccount{Internal: *olm.NewBlankAccount()}
var accountBytes []byte
err := row.Scan(&acc.Shared, &store.SyncToken, &accountBytes)
@@ -149,7 +153,7 @@ func (store *SQLCryptoStore) GetAccount() (*OlmAccount, error) {
}
// HasSession returns whether there is an Olm session for the given sender key.
func (store *SQLCryptoStore) HasSession(key id.SenderKey) bool {
func (store *SQLCryptoStore) HasSession(ctx context.Context, key id.SenderKey) bool {
store.olmSessionCacheLock.Lock()
cache, ok := store.olmSessionCache[key]
store.olmSessionCacheLock.Unlock()
@@ -157,17 +161,17 @@ func (store *SQLCryptoStore) HasSession(key id.SenderKey) bool {
return true
}
var sessionID id.SessionID
err := store.DB.QueryRow("SELECT session_id FROM crypto_olm_session WHERE sender_key=$1 AND account_id=$2 LIMIT 1",
err := store.DB.QueryRow(ctx, "SELECT session_id FROM crypto_olm_session WHERE sender_key=$1 AND account_id=$2 LIMIT 1",
key, store.AccountID).Scan(&sessionID)
if err == sql.ErrNoRows {
if errors.Is(err, sql.ErrNoRows) {
return false
}
return len(sessionID) > 0
}
// GetSessions returns all the known Olm sessions for a sender key.
func (store *SQLCryptoStore) GetSessions(key id.SenderKey) (OlmSessionList, error) {
rows, err := store.DB.Query("SELECT session_id, session, created_at, last_encrypted, last_decrypted FROM crypto_olm_session WHERE sender_key=$1 AND account_id=$2 ORDER BY last_decrypted DESC",
func (store *SQLCryptoStore) GetSessions(ctx context.Context, key id.SenderKey) (OlmSessionList, error) {
rows, err := store.DB.Query(ctx, "SELECT session_id, session, created_at, last_encrypted, last_decrypted FROM crypto_olm_session WHERE sender_key=$1 AND account_id=$2 ORDER BY last_decrypted DESC",
key, store.AccountID)
if err != nil {
return nil, err
@@ -207,11 +211,11 @@ func (store *SQLCryptoStore) getOlmSessionCache(key id.SenderKey) map[id.Session
}
// GetLatestSession retrieves the Olm session for a given sender key from the database that has the largest ID.
func (store *SQLCryptoStore) GetLatestSession(key id.SenderKey) (*OlmSession, error) {
func (store *SQLCryptoStore) GetLatestSession(ctx context.Context, key id.SenderKey) (*OlmSession, error) {
store.olmSessionCacheLock.Lock()
defer store.olmSessionCacheLock.Unlock()
row := store.DB.QueryRow("SELECT session_id, session, created_at, last_encrypted, last_decrypted FROM crypto_olm_session WHERE sender_key=$1 AND account_id=$2 ORDER BY last_decrypted DESC LIMIT 1",
row := store.DB.QueryRow(ctx, "SELECT session_id, session, created_at, last_encrypted, last_decrypted FROM crypto_olm_session WHERE sender_key=$1 AND account_id=$2 ORDER BY last_decrypted DESC LIMIT 1",
key, store.AccountID)
sess := OlmSession{Internal: *olm.NewBlankSession()}
@@ -219,7 +223,7 @@ func (store *SQLCryptoStore) GetLatestSession(key id.SenderKey) (*OlmSession, er
var sessionID id.SessionID
err := row.Scan(&sessionID, &sessionBytes, &sess.CreationTime, &sess.LastEncryptedTime, &sess.LastDecryptedTime)
if err == sql.ErrNoRows {
if errors.Is(err, sql.ErrNoRows) {
return nil, nil
} else if err != nil {
return nil, err
@@ -237,20 +241,20 @@ func (store *SQLCryptoStore) GetLatestSession(key id.SenderKey) (*OlmSession, er
}
// AddSession persists an Olm session for a sender in the database.
func (store *SQLCryptoStore) AddSession(key id.SenderKey, session *OlmSession) error {
func (store *SQLCryptoStore) AddSession(ctx context.Context, key id.SenderKey, session *OlmSession) error {
store.olmSessionCacheLock.Lock()
defer store.olmSessionCacheLock.Unlock()
sessionBytes := session.Internal.Pickle(store.PickleKey)
_, err := store.DB.Exec("INSERT INTO crypto_olm_session (session_id, sender_key, session, created_at, last_encrypted, last_decrypted, account_id) VALUES ($1, $2, $3, $4, $5, $6, $7)",
_, err := store.DB.Exec(ctx, "INSERT INTO crypto_olm_session (session_id, sender_key, session, created_at, last_encrypted, last_decrypted, account_id) VALUES ($1, $2, $3, $4, $5, $6, $7)",
session.ID(), key, sessionBytes, session.CreationTime, session.LastEncryptedTime, session.LastDecryptedTime, store.AccountID)
store.getOlmSessionCache(key)[session.ID()] = session
return err
}
// UpdateSession replaces the Olm session for a sender in the database.
func (store *SQLCryptoStore) UpdateSession(_ id.SenderKey, session *OlmSession) error {
func (store *SQLCryptoStore) UpdateSession(ctx context.Context, _ id.SenderKey, session *OlmSession) error {
sessionBytes := session.Internal.Pickle(store.PickleKey)
_, err := store.DB.Exec("UPDATE crypto_olm_session SET session=$1, last_encrypted=$2, last_decrypted=$3 WHERE session_id=$4 AND account_id=$5",
_, err := store.DB.Exec(ctx, "UPDATE crypto_olm_session SET session=$1, last_encrypted=$2, last_decrypted=$3 WHERE session_id=$4 AND account_id=$5",
sessionBytes, session.LastEncryptedTime, session.LastDecryptedTime, session.ID(), store.AccountID)
return err
}
@@ -270,14 +274,14 @@ func datePtr(t time.Time) *time.Time {
}
// PutGroupSession stores an inbound Megolm group session for a room, sender and session.
func (store *SQLCryptoStore) PutGroupSession(roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID, session *InboundGroupSession) error {
func (store *SQLCryptoStore) PutGroupSession(ctx context.Context, roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID, session *InboundGroupSession) error {
sessionBytes := session.Internal.Pickle(store.PickleKey)
forwardingChains := strings.Join(session.ForwardingChains, ",")
ratchetSafety, err := json.Marshal(&session.RatchetSafety)
if err != nil {
return fmt.Errorf("failed to marshal ratchet safety info: %w", err)
}
_, err = store.DB.Exec(`
_, err = store.DB.Exec(ctx, `
INSERT INTO crypto_megolm_inbound_session (
session_id, sender_key, signing_key, room_id, session, forwarding_chains,
ratchet_safety, received_at, max_age, max_messages, is_scheduled, account_id
@@ -296,19 +300,19 @@ func (store *SQLCryptoStore) PutGroupSession(roomID id.RoomID, senderKey id.Send
}
// GetGroupSession retrieves an inbound Megolm group session for a room, sender and session.
func (store *SQLCryptoStore) GetGroupSession(roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID) (*InboundGroupSession, error) {
func (store *SQLCryptoStore) GetGroupSession(ctx context.Context, roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID) (*InboundGroupSession, error) {
var senderKeyDB, signingKey, forwardingChains, withheldCode, withheldReason sql.NullString
var sessionBytes, ratchetSafetyBytes []byte
var receivedAt sql.NullTime
var maxAge, maxMessages sql.NullInt64
var isScheduled bool
err := store.DB.QueryRow(`
err := store.DB.QueryRow(ctx, `
SELECT sender_key, signing_key, session, forwarding_chains, withheld_code, withheld_reason, ratchet_safety, received_at, max_age, max_messages, is_scheduled
FROM crypto_megolm_inbound_session
WHERE room_id=$1 AND (sender_key=$2 OR $2 = '') AND session_id=$3 AND account_id=$4`,
roomID, senderKey, sessionID, store.AccountID,
).Scan(&senderKeyDB, &signingKey, &sessionBytes, &forwardingChains, &withheldCode, &withheldReason, &ratchetSafetyBytes, &receivedAt, &maxAge, &maxMessages, &isScheduled)
if err == sql.ErrNoRows {
if errors.Is(err, sql.ErrNoRows) {
return nil, nil
} else if err != nil {
return nil, err
@@ -322,22 +326,7 @@ func (store *SQLCryptoStore) GetGroupSession(roomID id.RoomID, senderKey id.Send
Reason: withheldReason.String,
}
}
igs := olm.NewBlankInboundGroupSession()
err = igs.Unpickle(sessionBytes, store.PickleKey)
if err != nil {
return nil, err
}
var chains []string
if forwardingChains.String != "" {
chains = strings.Split(forwardingChains.String, ",")
}
var rs RatchetSafety
if len(ratchetSafetyBytes) > 0 {
err = json.Unmarshal(ratchetSafetyBytes, &rs)
if err != nil {
return nil, fmt.Errorf("failed to unmarshal ratchet safety info: %w", err)
}
}
igs, chains, rs, err := store.postScanInboundGroupSession(sessionBytes, ratchetSafetyBytes, forwardingChains.String)
if senderKey == "" {
senderKey = id.Curve25519(senderKeyDB.String)
}
@@ -355,8 +344,8 @@ func (store *SQLCryptoStore) GetGroupSession(roomID id.RoomID, senderKey id.Send
}, nil
}
func (store *SQLCryptoStore) RedactGroupSession(_ id.RoomID, _ id.SenderKey, sessionID id.SessionID, reason string) error {
_, err := store.DB.Exec(`
func (store *SQLCryptoStore) RedactGroupSession(ctx context.Context, _ id.RoomID, _ id.SenderKey, sessionID id.SessionID, reason string) error {
_, err := store.DB.Exec(ctx, `
UPDATE crypto_megolm_inbound_session
SET withheld_code=$1, withheld_reason=$2, session=NULL, forwarding_chains=NULL
WHERE session_id=$3 AND account_id=$4 AND session IS NOT NULL
@@ -364,27 +353,24 @@ func (store *SQLCryptoStore) RedactGroupSession(_ id.RoomID, _ id.SenderKey, ses
return err
}
func (store *SQLCryptoStore) RedactGroupSessions(roomID id.RoomID, senderKey id.SenderKey, reason string) ([]id.SessionID, error) {
func (store *SQLCryptoStore) RedactGroupSessions(ctx context.Context, roomID id.RoomID, senderKey id.SenderKey, reason string) ([]id.SessionID, error) {
if roomID == "" && senderKey == "" {
return nil, fmt.Errorf("room ID or sender key must be provided for redacting sessions")
}
res, err := store.DB.Query(`
res, err := store.DB.Query(ctx, `
UPDATE crypto_megolm_inbound_session
SET withheld_code=$1, withheld_reason=$2, session=NULL, forwarding_chains=NULL
WHERE (room_id=$3 OR $3='') AND (sender_key=$4 OR $4='') AND account_id=$5
AND session IS NOT NULL AND is_scheduled=false AND received_at IS NOT NULL
RETURNING session_id
`, event.RoomKeyWithheldBeeperRedacted, "Session redacted: "+reason, roomID, senderKey, store.AccountID)
var sessionIDs []id.SessionID
for res.Next() {
var sessionID id.SessionID
_ = res.Scan(&sessionID)
sessionIDs = append(sessionIDs, sessionID)
if err != nil {
return nil, err
}
return sessionIDs, err
return dbutil.NewRowIter(res, dbutil.ScanSingleColumn[id.SessionID]).AsList()
}
func (store *SQLCryptoStore) RedactExpiredGroupSessions() ([]id.SessionID, error) {
func (store *SQLCryptoStore) RedactExpiredGroupSessions(ctx context.Context) ([]id.SessionID, error) {
var query string
switch store.DB.Dialect {
case dbutil.Postgres:
@@ -408,46 +394,40 @@ func (store *SQLCryptoStore) RedactExpiredGroupSessions() ([]id.SessionID, error
default:
return nil, fmt.Errorf("unsupported dialect")
}
res, err := store.DB.Query(query, event.RoomKeyWithheldBeeperRedacted, "Session redacted: expired", store.AccountID)
var sessionIDs []id.SessionID
for res.Next() {
var sessionID id.SessionID
_ = res.Scan(&sessionID)
sessionIDs = append(sessionIDs, sessionID)
res, err := store.DB.Query(ctx, query, event.RoomKeyWithheldBeeperRedacted, "Session redacted: expired", store.AccountID)
if err != nil {
return nil, err
}
return sessionIDs, err
return dbutil.NewRowIter(res, dbutil.ScanSingleColumn[id.SessionID]).AsList()
}
func (store *SQLCryptoStore) RedactOutdatedGroupSessions() ([]id.SessionID, error) {
res, err := store.DB.Query(`
func (store *SQLCryptoStore) RedactOutdatedGroupSessions(ctx context.Context) ([]id.SessionID, error) {
res, err := store.DB.Query(ctx, `
UPDATE crypto_megolm_inbound_session
SET withheld_code=$1, withheld_reason=$2, session=NULL, forwarding_chains=NULL
WHERE account_id=$3 AND session IS NOT NULL AND received_at IS NULL
RETURNING session_id
`, event.RoomKeyWithheldBeeperRedacted, "Session redacted: outdated", store.AccountID)
var sessionIDs []id.SessionID
for res.Next() {
var sessionID id.SessionID
_ = res.Scan(&sessionID)
sessionIDs = append(sessionIDs, sessionID)
if err != nil {
return nil, err
}
return sessionIDs, err
return dbutil.NewRowIter(res, dbutil.ScanSingleColumn[id.SessionID]).AsList()
}
func (store *SQLCryptoStore) PutWithheldGroupSession(content event.RoomKeyWithheldEventContent) error {
_, err := store.DB.Exec("INSERT INTO crypto_megolm_inbound_session (session_id, sender_key, room_id, withheld_code, withheld_reason, received_at, account_id) VALUES ($1, $2, $3, $4, $5, $6, $7)",
func (store *SQLCryptoStore) PutWithheldGroupSession(ctx context.Context, content event.RoomKeyWithheldEventContent) error {
_, err := store.DB.Exec(ctx, "INSERT INTO crypto_megolm_inbound_session (session_id, sender_key, room_id, withheld_code, withheld_reason, received_at, account_id) VALUES ($1, $2, $3, $4, $5, $6, $7)",
content.SessionID, content.SenderKey, content.RoomID, content.Code, content.Reason, time.Now().UTC(), store.AccountID)
return err
}
func (store *SQLCryptoStore) GetWithheldGroupSession(roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID) (*event.RoomKeyWithheldEventContent, error) {
func (store *SQLCryptoStore) GetWithheldGroupSession(ctx context.Context, roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID) (*event.RoomKeyWithheldEventContent, error) {
var code, reason sql.NullString
err := store.DB.QueryRow(`
err := store.DB.QueryRow(ctx, `
SELECT withheld_code, withheld_reason FROM crypto_megolm_inbound_session
WHERE room_id=$1 AND sender_key=$2 AND session_id=$3 AND account_id=$4`,
roomID, senderKey, sessionID, store.AccountID,
).Scan(&code, &reason)
if err == sql.ErrNoRows {
if errors.Is(err, sql.ErrNoRows) {
return nil, nil
} else if err != nil || !code.Valid {
return nil, err
@@ -462,82 +442,79 @@ func (store *SQLCryptoStore) GetWithheldGroupSession(roomID id.RoomID, senderKey
}, nil
}
func (store *SQLCryptoStore) scanGroupSessionList(rows dbutil.Rows) (result []*InboundGroupSession, err error) {
for rows.Next() {
var roomID id.RoomID
var signingKey, senderKey, forwardingChains sql.NullString
var sessionBytes, ratchetSafetyBytes []byte
var receivedAt sql.NullTime
var maxAge, maxMessages sql.NullInt64
var isScheduled bool
err = rows.Scan(&roomID, &signingKey, &senderKey, &sessionBytes, &forwardingChains, &ratchetSafetyBytes, &receivedAt, &maxAge, &maxMessages, &isScheduled)
func (store *SQLCryptoStore) postScanInboundGroupSession(sessionBytes, ratchetSafetyBytes []byte, forwardingChains string) (igs *olm.InboundGroupSession, chains []string, safety RatchetSafety, err error) {
igs = olm.NewBlankInboundGroupSession()
err = igs.Unpickle(sessionBytes, store.PickleKey)
if err != nil {
return
}
if forwardingChains != "" {
chains = strings.Split(forwardingChains, ",")
}
var rs RatchetSafety
if len(ratchetSafetyBytes) > 0 {
err = json.Unmarshal(ratchetSafetyBytes, &rs)
if err != nil {
return
err = fmt.Errorf("failed to unmarshal ratchet safety info: %w", err)
}
igs := olm.NewBlankInboundGroupSession()
err = igs.Unpickle(sessionBytes, store.PickleKey)
if err != nil {
return
}
var chains []string
if forwardingChains.String != "" {
chains = strings.Split(forwardingChains.String, ",")
}
var rs RatchetSafety
if len(ratchetSafetyBytes) > 0 {
err = json.Unmarshal(ratchetSafetyBytes, &rs)
if err != nil {
return nil, fmt.Errorf("failed to unmarshal ratchet safety info: %w", err)
}
}
result = append(result, &InboundGroupSession{
Internal: *igs,
SigningKey: id.Ed25519(signingKey.String),
SenderKey: id.Curve25519(senderKey.String),
RoomID: roomID,
ForwardingChains: chains,
RatchetSafety: rs,
ReceivedAt: receivedAt.Time,
MaxAge: maxAge.Int64,
MaxMessages: int(maxMessages.Int64),
IsScheduled: isScheduled,
})
}
return
}
func (store *SQLCryptoStore) GetGroupSessionsForRoom(roomID id.RoomID) ([]*InboundGroupSession, error) {
rows, err := store.DB.Query(`
SELECT room_id, signing_key, sender_key, session, forwarding_chains, ratchet_safety, received_at, max_age, max_messages, is_scheduled
func (store *SQLCryptoStore) scanInboundGroupSession(rows dbutil.Scannable) (*InboundGroupSession, error) {
var roomID id.RoomID
var signingKey, senderKey, forwardingChains sql.NullString
var sessionBytes, ratchetSafetyBytes []byte
var receivedAt sql.NullTime
var maxAge, maxMessages sql.NullInt64
var isScheduled bool
err := rows.Scan(&roomID, &senderKey, &signingKey, &sessionBytes, &forwardingChains, &ratchetSafetyBytes, &receivedAt, &maxAge, &maxMessages, &isScheduled)
if err != nil {
return nil, err
}
igs, chains, rs, err := store.postScanInboundGroupSession(sessionBytes, ratchetSafetyBytes, forwardingChains.String)
return &InboundGroupSession{
Internal: *igs,
SigningKey: id.Ed25519(signingKey.String),
SenderKey: id.Curve25519(senderKey.String),
RoomID: roomID,
ForwardingChains: chains,
RatchetSafety: rs,
ReceivedAt: receivedAt.Time,
MaxAge: maxAge.Int64,
MaxMessages: int(maxMessages.Int64),
IsScheduled: isScheduled,
}, nil
}
func (store *SQLCryptoStore) GetGroupSessionsForRoom(ctx context.Context, roomID id.RoomID) ([]*InboundGroupSession, error) {
rows, err := store.DB.Query(ctx, `
SELECT room_id, sender_key, signing_key, session, forwarding_chains, withheld_code, withheld_reason, ratchet_safety, received_at, max_age, max_messages, is_scheduled
FROM crypto_megolm_inbound_session WHERE room_id=$1 AND account_id=$2 AND session IS NOT NULL`,
roomID, store.AccountID,
)
if err == sql.ErrNoRows {
return []*InboundGroupSession{}, nil
} else if err != nil {
if err != nil {
return nil, err
}
return store.scanGroupSessionList(rows)
return dbutil.NewRowIter(rows, store.scanInboundGroupSession).AsList()
}
func (store *SQLCryptoStore) GetAllGroupSessions() ([]*InboundGroupSession, error) {
rows, err := store.DB.Query(`
SELECT room_id, signing_key, sender_key, session, forwarding_chains, ratchet_safety, received_at, max_age, max_messages, is_scheduled
func (store *SQLCryptoStore) GetAllGroupSessions(ctx context.Context) ([]*InboundGroupSession, error) {
rows, err := store.DB.Query(ctx, `
SELECT room_id, sender_key, signing_key, session, forwarding_chains, withheld_code, withheld_reason, ratchet_safety, received_at, max_age, max_messages, is_scheduled
FROM crypto_megolm_inbound_session WHERE account_id=$2 AND session IS NOT NULL`,
store.AccountID,
)
if err == sql.ErrNoRows {
return []*InboundGroupSession{}, nil
} else if err != nil {
if err != nil {
return nil, err
}
return store.scanGroupSessionList(rows)
return dbutil.NewRowIter(rows, store.scanInboundGroupSession).AsList()
}
// AddOutboundGroupSession stores an outbound Megolm session, along with the information about the room and involved devices.
func (store *SQLCryptoStore) AddOutboundGroupSession(session *OutboundGroupSession) error {
func (store *SQLCryptoStore) AddOutboundGroupSession(ctx context.Context, session *OutboundGroupSession) error {
sessionBytes := session.Internal.Pickle(store.PickleKey)
_, err := store.DB.Exec(`
_, err := store.DB.Exec(ctx, `
INSERT INTO crypto_megolm_outbound_session
(room_id, session_id, session, shared, max_messages, message_count, max_age, created_at, last_used, account_id)
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)
@@ -551,24 +528,24 @@ func (store *SQLCryptoStore) AddOutboundGroupSession(session *OutboundGroupSessi
}
// UpdateOutboundGroupSession replaces an outbound Megolm session with for same room and session ID.
func (store *SQLCryptoStore) UpdateOutboundGroupSession(session *OutboundGroupSession) error {
func (store *SQLCryptoStore) UpdateOutboundGroupSession(ctx context.Context, session *OutboundGroupSession) error {
sessionBytes := session.Internal.Pickle(store.PickleKey)
_, err := store.DB.Exec("UPDATE crypto_megolm_outbound_session SET session=$1, message_count=$2, last_used=$3 WHERE room_id=$4 AND session_id=$5 AND account_id=$6",
_, err := store.DB.Exec(ctx, "UPDATE crypto_megolm_outbound_session SET session=$1, message_count=$2, last_used=$3 WHERE room_id=$4 AND session_id=$5 AND account_id=$6",
sessionBytes, session.MessageCount, session.LastEncryptedTime, session.RoomID, session.ID(), store.AccountID)
return err
}
// GetOutboundGroupSession retrieves the outbound Megolm session for the given room ID.
func (store *SQLCryptoStore) GetOutboundGroupSession(roomID id.RoomID) (*OutboundGroupSession, error) {
func (store *SQLCryptoStore) GetOutboundGroupSession(ctx context.Context, roomID id.RoomID) (*OutboundGroupSession, error) {
var ogs OutboundGroupSession
var sessionBytes []byte
var maxAgeMS int64
err := store.DB.QueryRow(`
err := store.DB.QueryRow(ctx, `
SELECT session, shared, max_messages, message_count, max_age, created_at, last_used
FROM crypto_megolm_outbound_session WHERE room_id=$1 AND account_id=$2`,
roomID, store.AccountID,
).Scan(&sessionBytes, &ogs.Shared, &ogs.MaxMessages, &ogs.MessageCount, &maxAgeMS, &ogs.CreationTime, &ogs.LastEncryptedTime)
if err == sql.ErrNoRows {
if errors.Is(err, sql.ErrNoRows) {
return nil, nil
} else if err != nil {
return nil, err
@@ -585,8 +562,8 @@ func (store *SQLCryptoStore) GetOutboundGroupSession(roomID id.RoomID) (*Outboun
}
// RemoveOutboundGroupSession removes the outbound Megolm session for the given room ID.
func (store *SQLCryptoStore) RemoveOutboundGroupSession(roomID id.RoomID) error {
_, err := store.DB.Exec("DELETE FROM crypto_megolm_outbound_session WHERE room_id=$1 AND account_id=$2",
func (store *SQLCryptoStore) RemoveOutboundGroupSession(ctx context.Context, roomID id.RoomID) error {
_, err := store.DB.Exec(ctx, "DELETE FROM crypto_megolm_outbound_session WHERE room_id=$1 AND account_id=$2",
roomID, store.AccountID)
return err
}
@@ -603,7 +580,7 @@ func (store *SQLCryptoStore) ValidateMessageIndex(ctx context.Context, senderKey
`
var expectedEventID id.EventID
var expectedTimestamp int64
err := store.DB.QueryRowContext(ctx, validateQuery, senderKey, sessionID, index, eventID, timestamp).Scan(&expectedEventID, &expectedTimestamp)
err := store.DB.QueryRow(ctx, validateQuery, senderKey, sessionID, index, eventID, timestamp).Scan(&expectedEventID, &expectedTimestamp)
if err != nil {
return false, err
} else if expectedEventID != eventID || expectedTimestamp != timestamp {
@@ -618,69 +595,58 @@ func (store *SQLCryptoStore) ValidateMessageIndex(ctx context.Context, senderKey
return true, nil
}
func scanDevice(rows dbutil.Scannable) (*id.Device, error) {
var device id.Device
err := rows.Scan(&device.UserID, &device.DeviceID, &device.IdentityKey, &device.SigningKey, &device.Trust, &device.Deleted, &device.Name)
if errors.Is(err, sql.ErrNoRows) {
return nil, nil
} else if err != nil {
return nil, err
}
return &device, nil
}
// GetDevices returns a map of device IDs to device identities, including the identity and signing keys, for a given user ID.
func (store *SQLCryptoStore) GetDevices(userID id.UserID) (map[id.DeviceID]*id.Device, error) {
func (store *SQLCryptoStore) GetDevices(ctx context.Context, userID id.UserID) (map[id.DeviceID]*id.Device, error) {
var ignore id.UserID
err := store.DB.QueryRow("SELECT user_id FROM crypto_tracked_user WHERE user_id=$1", userID).Scan(&ignore)
if err == sql.ErrNoRows {
err := store.DB.QueryRow(ctx, "SELECT user_id FROM crypto_tracked_user WHERE user_id=$1", userID).Scan(&ignore)
if errors.Is(err, sql.ErrNoRows) {
return nil, nil
} else if err != nil {
return nil, err
}
rows, err := store.DB.Query("SELECT device_id, identity_key, signing_key, trust, deleted, name FROM crypto_device WHERE user_id=$1 AND deleted=false", userID)
rows, err := store.DB.Query(ctx, "SELECT user_id, device_id, identity_key, signing_key, trust, deleted, name FROM crypto_device WHERE user_id=$1 AND deleted=false", userID)
if err != nil {
return nil, err
}
data := make(map[id.DeviceID]*id.Device)
for rows.Next() {
var identity id.Device
err := rows.Scan(&identity.DeviceID, &identity.IdentityKey, &identity.SigningKey, &identity.Trust, &identity.Deleted, &identity.Name)
if err != nil {
return nil, err
}
identity.UserID = userID
data[identity.DeviceID] = &identity
err = dbutil.NewRowIter(rows, scanDevice).Iter(func(device *id.Device) (bool, error) {
data[device.DeviceID] = device
return true, nil
})
if err != nil {
return nil, err
}
return data, nil
}
// GetDevice returns the device dentity for a given user and device ID.
func (store *SQLCryptoStore) GetDevice(userID id.UserID, deviceID id.DeviceID) (*id.Device, error) {
var identity id.Device
err := store.DB.QueryRow(`
SELECT identity_key, signing_key, trust, deleted, name
func (store *SQLCryptoStore) GetDevice(ctx context.Context, userID id.UserID, deviceID id.DeviceID) (*id.Device, error) {
return scanDevice(store.DB.QueryRow(ctx, `
SELECT user_id, device_id, identity_key, signing_key, trust, deleted, name
FROM crypto_device WHERE user_id=$1 AND device_id=$2`,
userID, deviceID,
).Scan(&identity.IdentityKey, &identity.SigningKey, &identity.Trust, &identity.Deleted, &identity.Name)
if err != nil {
if err == sql.ErrNoRows {
return nil, nil
}
return nil, err
}
identity.UserID = userID
identity.DeviceID = deviceID
return &identity, nil
))
}
// FindDeviceByKey finds a specific device by its sender key.
func (store *SQLCryptoStore) FindDeviceByKey(userID id.UserID, identityKey id.IdentityKey) (*id.Device, error) {
var identity id.Device
err := store.DB.QueryRow(`
SELECT device_id, signing_key, trust, deleted, name
func (store *SQLCryptoStore) FindDeviceByKey(ctx context.Context, userID id.UserID, identityKey id.IdentityKey) (*id.Device, error) {
return scanDevice(store.DB.QueryRow(ctx, `
SELECT user_id, device_id, identity_key, signing_key, trust, deleted, name
FROM crypto_device WHERE user_id=$1 AND identity_key=$2`,
userID, identityKey,
).Scan(&identity.DeviceID, &identity.SigningKey, &identity.Trust, &identity.Deleted, &identity.Name)
if err != nil {
if err == sql.ErrNoRows {
return nil, nil
}
return nil, err
}
identity.UserID = userID
identity.IdentityKey = identityKey
return &identity, nil
))
}
const deviceInsertQuery = `
@@ -693,106 +659,115 @@ ON CONFLICT (user_id, device_id) DO UPDATE
var deviceMassInsertTemplate = strings.ReplaceAll(deviceInsertQuery, "($1, $2, $3, $4, $5, $6, $7)", "%s")
// PutDevice stores a single device for a user, replacing it if it exists already.
func (store *SQLCryptoStore) PutDevice(userID id.UserID, device *id.Device) error {
_, err := store.DB.Exec(deviceInsertQuery,
func (store *SQLCryptoStore) PutDevice(ctx context.Context, userID id.UserID, device *id.Device) error {
_, err := store.DB.Exec(ctx, deviceInsertQuery,
userID, device.DeviceID, device.IdentityKey, device.SigningKey, device.Trust, device.Deleted, device.Name)
return err
}
const trackedUserUpsertQuery = `
INSERT INTO crypto_tracked_user (user_id, devices_outdated)
VALUES ($1, false)
ON CONFLICT (user_id) DO UPDATE
SET devices_outdated = EXCLUDED.devices_outdated
`
// PutDevices stores the device identity information for the given user ID.
func (store *SQLCryptoStore) PutDevices(userID id.UserID, devices map[id.DeviceID]*id.Device) error {
tx, err := store.DB.Begin()
if err != nil {
return err
}
_, err = tx.Exec("INSERT INTO crypto_tracked_user (user_id) VALUES ($1) ON CONFLICT (user_id) DO NOTHING", userID)
if err != nil {
return fmt.Errorf("failed to add user to tracked users list: %w", err)
}
_, err = tx.Exec("UPDATE crypto_device SET deleted=true WHERE user_id=$1", userID)
if err != nil {
_ = tx.Rollback()
return fmt.Errorf("failed to delete old devices: %w", err)
}
if len(devices) == 0 {
err = tx.Commit()
func (store *SQLCryptoStore) PutDevices(ctx context.Context, userID id.UserID, devices map[id.DeviceID]*id.Device) error {
return store.DB.DoTxn(ctx, nil, func(ctx context.Context) error {
_, err := store.DB.Exec(ctx, trackedUserUpsertQuery, userID)
if err != nil {
return fmt.Errorf("failed to commit changes (no devices added): %w", err)
return fmt.Errorf("failed to upsert user to tracked users list: %w", err)
}
_, err = store.DB.Exec(ctx, "UPDATE crypto_device SET deleted=true WHERE user_id=$1", userID)
if err != nil {
return fmt.Errorf("failed to delete old devices: %w", err)
}
if len(devices) == 0 {
return nil
}
deviceBatchLen := 5 // how many devices will be inserted per query
deviceIDs := make([]id.DeviceID, 0, len(devices))
for deviceID := range devices {
deviceIDs = append(deviceIDs, deviceID)
}
const valueStringFormat = "($1, $%d, $%d, $%d, $%d, $%d, $%d)"
for batchDeviceIdx := 0; batchDeviceIdx < len(deviceIDs); batchDeviceIdx += deviceBatchLen {
var batchDevices []id.DeviceID
if batchDeviceIdx+deviceBatchLen < len(deviceIDs) {
batchDevices = deviceIDs[batchDeviceIdx : batchDeviceIdx+deviceBatchLen]
} else {
batchDevices = deviceIDs[batchDeviceIdx:]
}
values := make([]interface{}, 1, len(devices)*6+1)
values[0] = userID
valueStrings := make([]string, 0, len(devices))
i := 2
for _, deviceID := range batchDevices {
identity := devices[deviceID]
values = append(values, deviceID, identity.IdentityKey, identity.SigningKey, identity.Trust, identity.Deleted, identity.Name)
valueStrings = append(valueStrings, fmt.Sprintf(valueStringFormat, i, i+1, i+2, i+3, i+4, i+5))
i += 6
}
valueString := strings.Join(valueStrings, ",")
_, err = store.DB.Exec(ctx, fmt.Sprintf(deviceMassInsertTemplate, valueString), values...)
if err != nil {
return fmt.Errorf("failed to insert new devices: %w", err)
}
}
return nil
}
deviceBatchLen := 5 // how many devices will be inserted per query
deviceIDs := make([]id.DeviceID, 0, len(devices))
for deviceID := range devices {
deviceIDs = append(deviceIDs, deviceID)
}
const valueStringFormat = "($1, $%d, $%d, $%d, $%d, $%d, $%d)"
for batchDeviceIdx := 0; batchDeviceIdx < len(deviceIDs); batchDeviceIdx += deviceBatchLen {
var batchDevices []id.DeviceID
if batchDeviceIdx+deviceBatchLen < len(deviceIDs) {
batchDevices = deviceIDs[batchDeviceIdx : batchDeviceIdx+deviceBatchLen]
} else {
batchDevices = deviceIDs[batchDeviceIdx:]
}
values := make([]interface{}, 1, len(devices)*6+1)
values[0] = userID
valueStrings := make([]string, 0, len(devices))
i := 2
for _, deviceID := range batchDevices {
identity := devices[deviceID]
values = append(values, deviceID, identity.IdentityKey, identity.SigningKey, identity.Trust, identity.Deleted, identity.Name)
valueStrings = append(valueStrings, fmt.Sprintf(valueStringFormat, i, i+1, i+2, i+3, i+4, i+5))
i += 6
}
valueString := strings.Join(valueStrings, ",")
_, err = tx.Exec(fmt.Sprintf(deviceMassInsertTemplate, valueString), values...)
if err != nil {
_ = tx.Rollback()
return fmt.Errorf("failed to insert new devices: %w", err)
}
}
err = tx.Commit()
if err != nil {
return fmt.Errorf("failed to commit changes: %w", err)
}
return nil
})
}
// FilterTrackedUsers finds all the user IDs out of the given ones for which the database contains identity information.
func (store *SQLCryptoStore) FilterTrackedUsers(users []id.UserID) ([]id.UserID, error) {
func (store *SQLCryptoStore) FilterTrackedUsers(ctx context.Context, users []id.UserID) ([]id.UserID, error) {
var rows dbutil.Rows
var err error
if store.DB.Dialect == dbutil.Postgres && PostgresArrayWrapper != nil {
rows, err = store.DB.Query("SELECT user_id FROM crypto_tracked_user WHERE user_id = ANY($1)", PostgresArrayWrapper(users))
rows, err = store.DB.Query(ctx, "SELECT user_id FROM crypto_tracked_user WHERE user_id = ANY($1)", PostgresArrayWrapper(users))
} else {
queryString := make([]string, len(users))
params := make([]interface{}, len(users))
for i, user := range users {
queryString[i] = fmt.Sprintf("$%d", i+1)
queryString[i] = fmt.Sprintf("?%d", i+1)
params[i] = user
}
rows, err = store.DB.Query("SELECT user_id FROM crypto_tracked_user WHERE user_id IN ("+strings.Join(queryString, ",")+")", params...)
rows, err = store.DB.Query(ctx, "SELECT user_id FROM crypto_tracked_user WHERE user_id IN ("+strings.Join(queryString, ",")+")", params...)
}
if err != nil {
return users, err
}
var ptr int
for rows.Next() {
err = rows.Scan(&users[ptr])
if err != nil {
return users, err
} else {
ptr++
return dbutil.NewRowIter(rows, dbutil.ScanSingleColumn[id.UserID]).AsList()
}
// MarkTrackedUsersOutdated flags that the device list for given users are outdated.
func (store *SQLCryptoStore) MarkTrackedUsersOutdated(ctx context.Context, users []id.UserID) error {
return store.DB.DoTxn(ctx, nil, func(ctx context.Context) error {
// TODO refactor to use a single query
for _, userID := range users {
_, err := store.DB.Exec(ctx, "UPDATE crypto_tracked_user SET devices_outdated = true WHERE user_id = $1", userID)
if err != nil {
return fmt.Errorf("failed to update user in the tracked users list: %w", err)
}
}
return nil
})
}
// GetOutdatedTrackerUsers gets all tracked users whose devices need to be updated.
func (store *SQLCryptoStore) GetOutdatedTrackedUsers(ctx context.Context) ([]id.UserID, error) {
rows, err := store.DB.Query(ctx, "SELECT user_id FROM crypto_tracked_user WHERE devices_outdated = TRUE")
if err != nil {
return nil, err
}
return users[:ptr], nil
return dbutil.NewRowIter(rows, dbutil.ScanSingleColumn[id.UserID]).AsList()
}
// PutCrossSigningKey stores a cross-signing key of some user along with its usage.
func (store *SQLCryptoStore) PutCrossSigningKey(userID id.UserID, usage id.CrossSigningUsage, key id.Ed25519) error {
_, err := store.DB.Exec(`
func (store *SQLCryptoStore) PutCrossSigningKey(ctx context.Context, userID id.UserID, usage id.CrossSigningUsage, key id.Ed25519) error {
_, err := store.DB.Exec(ctx, `
INSERT INTO crypto_cross_signing_keys (user_id, usage, key, first_seen_key) VALUES ($1, $2, $3, $4)
ON CONFLICT (user_id, usage) DO UPDATE SET key=excluded.key
`, userID, usage, key, key)
@@ -800,8 +775,8 @@ func (store *SQLCryptoStore) PutCrossSigningKey(userID id.UserID, usage id.Cross
}
// GetCrossSigningKeys retrieves a user's stored cross-signing keys.
func (store *SQLCryptoStore) GetCrossSigningKeys(userID id.UserID) (map[id.CrossSigningUsage]id.CrossSigningKey, error) {
rows, err := store.DB.Query("SELECT usage, key, first_seen_key FROM crypto_cross_signing_keys WHERE user_id=$1", userID)
func (store *SQLCryptoStore) GetCrossSigningKeys(ctx context.Context, userID id.UserID) (map[id.CrossSigningUsage]id.CrossSigningKey, error) {
rows, err := store.DB.Query(ctx, "SELECT usage, key, first_seen_key FROM crypto_cross_signing_keys WHERE user_id=$1", userID)
if err != nil {
return nil, err
}
@@ -820,8 +795,8 @@ func (store *SQLCryptoStore) GetCrossSigningKeys(userID id.UserID) (map[id.Cross
}
// PutSignature stores a signature of a cross-signing or device key along with the signer's user ID and key.
func (store *SQLCryptoStore) PutSignature(signedUserID id.UserID, signedKey id.Ed25519, signerUserID id.UserID, signerKey id.Ed25519, signature string) error {
_, err := store.DB.Exec(`
func (store *SQLCryptoStore) PutSignature(ctx context.Context, signedUserID id.UserID, signedKey id.Ed25519, signerUserID id.UserID, signerKey id.Ed25519, signature string) error {
_, err := store.DB.Exec(ctx, `
INSERT INTO crypto_cross_signing_signatures (signed_user_id, signed_key, signer_user_id, signer_key, signature) VALUES ($1, $2, $3, $4, $5)
ON CONFLICT (signed_user_id, signed_key, signer_user_id, signer_key) DO UPDATE SET signature=excluded.signature
`, signedUserID, signedKey, signerUserID, signerKey, signature)
@@ -829,8 +804,8 @@ func (store *SQLCryptoStore) PutSignature(signedUserID id.UserID, signedKey id.E
}
// GetSignaturesForKeyBy retrieves the stored signatures for a given cross-signing or device key, by the given signer.
func (store *SQLCryptoStore) GetSignaturesForKeyBy(userID id.UserID, key id.Ed25519, signerID id.UserID) (map[id.Ed25519]string, error) {
rows, err := store.DB.Query("SELECT signer_key, signature FROM crypto_cross_signing_signatures WHERE signed_user_id=$1 AND signed_key=$2 AND signer_user_id=$3", userID, key, signerID)
func (store *SQLCryptoStore) GetSignaturesForKeyBy(ctx context.Context, userID id.UserID, key id.Ed25519, signerID id.UserID) (map[id.Ed25519]string, error) {
rows, err := store.DB.Query(ctx, "SELECT signer_key, signature FROM crypto_cross_signing_signatures WHERE signed_user_id=$1 AND signed_key=$2 AND signer_user_id=$3", userID, key, signerID)
if err != nil {
return nil, err
}
@@ -849,18 +824,18 @@ func (store *SQLCryptoStore) GetSignaturesForKeyBy(userID id.UserID, key id.Ed25
}
// IsKeySignedBy returns whether a cross-signing or device key is signed by the given signer.
func (store *SQLCryptoStore) IsKeySignedBy(signedUserID id.UserID, signedKey id.Ed25519, signerUserID id.UserID, signerKey id.Ed25519) (isSigned bool, err error) {
func (store *SQLCryptoStore) IsKeySignedBy(ctx context.Context, signedUserID id.UserID, signedKey id.Ed25519, signerUserID id.UserID, signerKey id.Ed25519) (isSigned bool, err error) {
q := `SELECT EXISTS(
SELECT 1 FROM crypto_cross_signing_signatures
WHERE signed_user_id=$1 AND signed_key=$2 AND signer_user_id=$3 AND signer_key=$4
)`
err = store.DB.QueryRow(q, signedUserID, signedKey, signerUserID, signerKey).Scan(&isSigned)
err = store.DB.QueryRow(ctx, q, signedUserID, signedKey, signerUserID, signerKey).Scan(&isSigned)
return
}
// DropSignaturesByKey deletes the signatures made by the given user and key from the store. It returns the number of signatures deleted.
func (store *SQLCryptoStore) DropSignaturesByKey(userID id.UserID, key id.Ed25519) (int64, error) {
res, err := store.DB.Exec("DELETE FROM crypto_cross_signing_signatures WHERE signer_user_id=$1 AND signer_key=$2", userID, key)
func (store *SQLCryptoStore) DropSignaturesByKey(ctx context.Context, userID id.UserID, key id.Ed25519) (int64, error) {
res, err := store.DB.Exec(ctx, "DELETE FROM crypto_cross_signing_signatures WHERE signer_user_id=$1 AND signer_key=$2", userID, key)
if err != nil {
return 0, err
}

View File

@@ -1,4 +1,4 @@
-- v0 -> v10: Latest revision
-- v0 -> v11: Latest revision
CREATE TABLE IF NOT EXISTS crypto_account (
account_id TEXT PRIMARY KEY,
device_id TEXT NOT NULL,
@@ -17,7 +17,8 @@ CREATE TABLE IF NOT EXISTS crypto_message_index (
);
CREATE TABLE IF NOT EXISTS crypto_tracked_user (
user_id TEXT PRIMARY KEY
user_id TEXT PRIMARY KEY,
devices_outdated BOOLEAN NOT NULL DEFAULT FALSE
);
CREATE TABLE IF NOT EXISTS crypto_device (

View File

@@ -0,0 +1,2 @@
-- v11: Add devices_outdated field to crypto_tracked_user
ALTER TABLE crypto_tracked_user ADD COLUMN devices_outdated BOOLEAN NOT NULL DEFAULT FALSE;

View File

@@ -7,6 +7,7 @@
package sql_store_upgrade
import (
"context"
"embed"
"fmt"
@@ -21,7 +22,7 @@ const VersionTableName = "crypto_version"
var fs embed.FS
func init() {
Table.Register(-1, 3, 0, "Unsupported version", false, func(tx dbutil.Execable, database *dbutil.Database) error {
Table.Register(-1, 3, 0, "Unsupported version", false, func(ctx context.Context, database *dbutil.Database) error {
return fmt.Errorf("upgrading from versions 1 and 2 of the crypto store is no longer supported in mautrix-go v0.12+")
})
Table.RegisterFS(fs)

View File

@@ -7,6 +7,7 @@
package ssss
import (
"context"
"fmt"
"maunium.net/go/mautrix"
@@ -29,9 +30,9 @@ type DefaultSecretStorageKeyContent struct {
}
// GetDefaultKeyID retrieves the default key ID for this account from SSSS.
func (mach *Machine) GetDefaultKeyID() (string, error) {
func (mach *Machine) GetDefaultKeyID(ctx context.Context) (string, error) {
var data DefaultSecretStorageKeyContent
err := mach.Client.GetAccountData(event.AccountDataSecretStorageDefaultKey.Type, &data)
err := mach.Client.GetAccountData(ctx, event.AccountDataSecretStorageDefaultKey.Type, &data)
if err != nil {
if httpErr, ok := err.(mautrix.HTTPError); ok && httpErr.RespError != nil && httpErr.RespError.ErrCode == "M_NOT_FOUND" {
return "", ErrNoDefaultKeyAccountDataEvent
@@ -45,36 +46,36 @@ func (mach *Machine) GetDefaultKeyID() (string, error) {
}
// SetDefaultKeyID sets the default key ID for this account on the server.
func (mach *Machine) SetDefaultKeyID(keyID string) error {
return mach.Client.SetAccountData(event.AccountDataSecretStorageDefaultKey.Type, &DefaultSecretStorageKeyContent{keyID})
func (mach *Machine) SetDefaultKeyID(ctx context.Context, keyID string) error {
return mach.Client.SetAccountData(ctx, event.AccountDataSecretStorageDefaultKey.Type, &DefaultSecretStorageKeyContent{keyID})
}
// GetKeyData gets the details about the given key ID.
func (mach *Machine) GetKeyData(keyID string) (keyData *KeyMetadata, err error) {
func (mach *Machine) GetKeyData(ctx context.Context, keyID string) (keyData *KeyMetadata, err error) {
keyData = &KeyMetadata{id: keyID}
err = mach.Client.GetAccountData(fmt.Sprintf("%s.%s", event.AccountDataSecretStorageKey.Type, keyID), keyData)
err = mach.Client.GetAccountData(ctx, fmt.Sprintf("%s.%s", event.AccountDataSecretStorageKey.Type, keyID), keyData)
return
}
// SetKeyData stores SSSS key metadata on the server.
func (mach *Machine) SetKeyData(keyID string, keyData *KeyMetadata) error {
return mach.Client.SetAccountData(fmt.Sprintf("%s.%s", event.AccountDataSecretStorageKey.Type, keyID), keyData)
func (mach *Machine) SetKeyData(ctx context.Context, keyID string, keyData *KeyMetadata) error {
return mach.Client.SetAccountData(ctx, fmt.Sprintf("%s.%s", event.AccountDataSecretStorageKey.Type, keyID), keyData)
}
// GetDefaultKeyData gets the details about the default key ID (see GetDefaultKeyID).
func (mach *Machine) GetDefaultKeyData() (keyID string, keyData *KeyMetadata, err error) {
keyID, err = mach.GetDefaultKeyID()
func (mach *Machine) GetDefaultKeyData(ctx context.Context) (keyID string, keyData *KeyMetadata, err error) {
keyID, err = mach.GetDefaultKeyID(ctx)
if err != nil {
return
}
keyData, err = mach.GetKeyData(keyID)
keyData, err = mach.GetKeyData(ctx, keyID)
return
}
// GetDecryptedAccountData gets the account data event with the given event type and decrypts it using the given key.
func (mach *Machine) GetDecryptedAccountData(eventType event.Type, key *Key) ([]byte, error) {
func (mach *Machine) GetDecryptedAccountData(ctx context.Context, eventType event.Type, key *Key) ([]byte, error) {
var encData EncryptedAccountDataEventContent
err := mach.Client.GetAccountData(eventType.Type, &encData)
err := mach.Client.GetAccountData(ctx, eventType.Type, &encData)
if err != nil {
return nil, err
}
@@ -82,7 +83,7 @@ func (mach *Machine) GetDecryptedAccountData(eventType event.Type, key *Key) ([]
}
// SetEncryptedAccountData encrypts the given data with the given keys and stores it on the server.
func (mach *Machine) SetEncryptedAccountData(eventType event.Type, data []byte, keys ...*Key) error {
func (mach *Machine) SetEncryptedAccountData(ctx context.Context, eventType event.Type, data []byte, keys ...*Key) error {
if len(keys) == 0 {
return ErrNoKeyGiven
}
@@ -90,17 +91,17 @@ func (mach *Machine) SetEncryptedAccountData(eventType event.Type, data []byte,
for _, key := range keys {
encrypted[key.ID] = key.Encrypt(eventType.Type, data)
}
return mach.Client.SetAccountData(eventType.Type, &EncryptedAccountDataEventContent{Encrypted: encrypted})
return mach.Client.SetAccountData(ctx, eventType.Type, &EncryptedAccountDataEventContent{Encrypted: encrypted})
}
// GenerateAndUploadKey generates a new SSSS key and stores the metadata on the server.
func (mach *Machine) GenerateAndUploadKey(passphrase string) (key *Key, err error) {
func (mach *Machine) GenerateAndUploadKey(ctx context.Context, passphrase string) (key *Key, err error) {
key, err = NewKey(passphrase)
if err != nil {
return nil, fmt.Errorf("failed to generate new key: %w", err)
}
err = mach.SetKeyData(key.ID, key.Metadata)
err = mach.SetKeyData(ctx, key.ID, key.Metadata)
if err != nil {
err = fmt.Errorf("failed to upload key: %w", err)
}

View File

@@ -1,4 +1,4 @@
// Copyright (c) 2022 Tulir Asokan
// Copyright (c) 2024 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
@@ -26,64 +26,64 @@ var ErrGroupSessionWithheld error = &event.RoomKeyWithheldEventContent{}
type Store interface {
// Flush ensures that everything in the store is persisted to disk.
// This doesn't have to do anything, e.g. for database-backed implementations that persist everything immediately.
Flush() error
Flush(context.Context) error
// PutAccount updates the OlmAccount in the store.
PutAccount(*OlmAccount) error
PutAccount(context.Context, *OlmAccount) error
// GetAccount returns the OlmAccount in the store that was previously inserted with PutAccount.
GetAccount() (*OlmAccount, error)
GetAccount(ctx context.Context) (*OlmAccount, error)
// AddSession inserts an Olm session into the store.
AddSession(id.SenderKey, *OlmSession) error
AddSession(context.Context, id.SenderKey, *OlmSession) error
// HasSession returns whether or not the store has an Olm session with the given sender key.
HasSession(id.SenderKey) bool
HasSession(context.Context, id.SenderKey) bool
// GetSessions returns all Olm sessions in the store with the given sender key.
GetSessions(id.SenderKey) (OlmSessionList, error)
GetSessions(context.Context, id.SenderKey) (OlmSessionList, error)
// GetLatestSession returns the session with the highest session ID (lexiographically sorting).
// It's usually safe to return the most recently added session if sorting by session ID is too difficult.
GetLatestSession(id.SenderKey) (*OlmSession, error)
GetLatestSession(context.Context, id.SenderKey) (*OlmSession, error)
// UpdateSession updates a session that has previously been inserted with AddSession.
UpdateSession(id.SenderKey, *OlmSession) error
UpdateSession(context.Context, id.SenderKey, *OlmSession) error
// PutGroupSession inserts an inbound Megolm session into the store. If an earlier withhold event has been inserted
// with PutWithheldGroupSession, this call should replace that. However, PutWithheldGroupSession must not replace
// sessions inserted with this call.
PutGroupSession(id.RoomID, id.SenderKey, id.SessionID, *InboundGroupSession) error
PutGroupSession(context.Context, id.RoomID, id.SenderKey, id.SessionID, *InboundGroupSession) error
// GetGroupSession gets an inbound Megolm session from the store. If the group session has been withheld
// (i.e. a room key withheld event has been saved with PutWithheldGroupSession), this should return the
// ErrGroupSessionWithheld error. The caller may use GetWithheldGroupSession to find more details.
GetGroupSession(id.RoomID, id.SenderKey, id.SessionID) (*InboundGroupSession, error)
GetGroupSession(context.Context, id.RoomID, id.SenderKey, id.SessionID) (*InboundGroupSession, error)
// RedactGroupSession removes the session data for the given inbound Megolm session from the store.
RedactGroupSession(id.RoomID, id.SenderKey, id.SessionID, string) error
RedactGroupSession(context.Context, id.RoomID, id.SenderKey, id.SessionID, string) error
// RedactGroupSessions removes the session data for all inbound Megolm sessions from a specific device and/or in a specific room.
RedactGroupSessions(id.RoomID, id.SenderKey, string) ([]id.SessionID, error)
RedactGroupSessions(context.Context, id.RoomID, id.SenderKey, string) ([]id.SessionID, error)
// RedactExpiredGroupSessions removes the session data for all inbound Megolm sessions that have expired.
RedactExpiredGroupSessions() ([]id.SessionID, error)
RedactExpiredGroupSessions(context.Context) ([]id.SessionID, error)
// RedactOutdatedGroupSessions removes the session data for all inbound Megolm sessions that are lacking the expiration metadata.
RedactOutdatedGroupSessions() ([]id.SessionID, error)
RedactOutdatedGroupSessions(context.Context) ([]id.SessionID, error)
// PutWithheldGroupSession tells the store that a specific Megolm session was withheld.
PutWithheldGroupSession(event.RoomKeyWithheldEventContent) error
PutWithheldGroupSession(context.Context, event.RoomKeyWithheldEventContent) error
// GetWithheldGroupSession gets the event content that was previously inserted with PutWithheldGroupSession.
GetWithheldGroupSession(id.RoomID, id.SenderKey, id.SessionID) (*event.RoomKeyWithheldEventContent, error)
GetWithheldGroupSession(context.Context, id.RoomID, id.SenderKey, id.SessionID) (*event.RoomKeyWithheldEventContent, error)
// GetGroupSessionsForRoom gets all the inbound Megolm sessions for a specific room. This is used for creating key
// export files. Unlike GetGroupSession, this should not return any errors about withheld keys.
GetGroupSessionsForRoom(id.RoomID) ([]*InboundGroupSession, error)
GetGroupSessionsForRoom(context.Context, id.RoomID) ([]*InboundGroupSession, error)
// GetAllGroupSessions gets all the inbound Megolm sessions in the store. This is used for creating key export
// files. Unlike GetGroupSession, this should not return any errors about withheld keys.
GetAllGroupSessions() ([]*InboundGroupSession, error)
GetAllGroupSessions(context.Context) ([]*InboundGroupSession, error)
// AddOutboundGroupSession inserts the given outbound Megolm session into the store.
//
// The store should index inserted sessions by the RoomID field to support getting and removing sessions.
// There will only be one outbound session per room ID at a time.
AddOutboundGroupSession(*OutboundGroupSession) error
AddOutboundGroupSession(context.Context, *OutboundGroupSession) error
// UpdateOutboundGroupSession updates the given outbound Megolm session in the store.
UpdateOutboundGroupSession(*OutboundGroupSession) error
UpdateOutboundGroupSession(context.Context, *OutboundGroupSession) error
// GetOutboundGroupSession gets the stored outbound Megolm session for the given room ID from the store.
GetOutboundGroupSession(id.RoomID) (*OutboundGroupSession, error)
GetOutboundGroupSession(context.Context, id.RoomID) (*OutboundGroupSession, error)
// RemoveOutboundGroupSession removes the stored outbound Megolm session for the given room ID.
RemoveOutboundGroupSession(id.RoomID) error
RemoveOutboundGroupSession(context.Context, id.RoomID) error
// ValidateMessageIndex validates that the given message details aren't from a replay attack.
//
@@ -96,29 +96,33 @@ type Store interface {
ValidateMessageIndex(ctx context.Context, senderKey id.SenderKey, sessionID id.SessionID, eventID id.EventID, index uint, timestamp int64) (bool, error)
// GetDevices returns a map from device ID to id.Device struct containing all devices of a given user.
GetDevices(id.UserID) (map[id.DeviceID]*id.Device, error)
GetDevices(context.Context, id.UserID) (map[id.DeviceID]*id.Device, error)
// GetDevice returns a specific device of a given user.
GetDevice(id.UserID, id.DeviceID) (*id.Device, error)
GetDevice(context.Context, id.UserID, id.DeviceID) (*id.Device, error)
// PutDevice stores a single device for a user, replacing it if it exists already.
PutDevice(id.UserID, *id.Device) error
PutDevice(context.Context, id.UserID, *id.Device) error
// PutDevices overrides the stored device list for the given user with the given list.
PutDevices(id.UserID, map[id.DeviceID]*id.Device) error
PutDevices(context.Context, id.UserID, map[id.DeviceID]*id.Device) error
// FindDeviceByKey finds a specific device by its identity key.
FindDeviceByKey(id.UserID, id.IdentityKey) (*id.Device, error)
FindDeviceByKey(context.Context, id.UserID, id.IdentityKey) (*id.Device, error)
// FilterTrackedUsers returns a filtered version of the given list that only includes user IDs whose device lists
// have been stored with PutDevices. A user is considered tracked even if the PutDevices list was empty.
FilterTrackedUsers([]id.UserID) ([]id.UserID, error)
FilterTrackedUsers(context.Context, []id.UserID) ([]id.UserID, error)
// MarkTrackedUsersOutdated flags that the device list for given users are outdated.
MarkTrackedUsersOutdated(context.Context, []id.UserID) error
// GetOutdatedTrackerUsers gets all tracked users whose devices need to be updated.
GetOutdatedTrackedUsers(context.Context) ([]id.UserID, error)
// PutCrossSigningKey stores a cross-signing key of some user along with its usage.
PutCrossSigningKey(id.UserID, id.CrossSigningUsage, id.Ed25519) error
PutCrossSigningKey(context.Context, id.UserID, id.CrossSigningUsage, id.Ed25519) error
// GetCrossSigningKeys retrieves a user's stored cross-signing keys.
GetCrossSigningKeys(id.UserID) (map[id.CrossSigningUsage]id.CrossSigningKey, error)
GetCrossSigningKeys(context.Context, id.UserID) (map[id.CrossSigningUsage]id.CrossSigningKey, error)
// PutSignature stores a signature of a cross-signing or device key along with the signer's user ID and key.
PutSignature(signedUser id.UserID, signedKey id.Ed25519, signerUser id.UserID, signerKey id.Ed25519, signature string) error
PutSignature(ctx context.Context, signedUser id.UserID, signedKey id.Ed25519, signerUser id.UserID, signerKey id.Ed25519, signature string) error
// IsKeySignedBy returns whether a cross-signing or device key is signed by the given signer.
IsKeySignedBy(userID id.UserID, key id.Ed25519, signedByUser id.UserID, signedByKey id.Ed25519) (bool, error)
IsKeySignedBy(ctx context.Context, userID id.UserID, key id.Ed25519, signedByUser id.UserID, signedByKey id.Ed25519) (bool, error)
// DropSignaturesByKey deletes the signatures made by the given user and key from the store. It returns the number of signatures deleted.
DropSignaturesByKey(id.UserID, id.Ed25519) (int64, error)
DropSignaturesByKey(context.Context, id.UserID, id.Ed25519) (int64, error)
}
type messageIndexKey struct {
@@ -148,6 +152,7 @@ type MemoryStore struct {
Devices map[id.UserID]map[id.DeviceID]*id.Device
CrossSigningKeys map[id.UserID]map[id.CrossSigningUsage]id.CrossSigningKey
KeySignatures map[id.UserID]map[id.Ed25519]map[id.UserID]map[id.Ed25519]string
OutdatedUsers map[id.UserID]struct{}
}
var _ Store = (*MemoryStore)(nil)
@@ -167,21 +172,22 @@ func NewMemoryStore(saveCallback func() error) *MemoryStore {
Devices: make(map[id.UserID]map[id.DeviceID]*id.Device),
CrossSigningKeys: make(map[id.UserID]map[id.CrossSigningUsage]id.CrossSigningKey),
KeySignatures: make(map[id.UserID]map[id.Ed25519]map[id.UserID]map[id.Ed25519]string),
OutdatedUsers: make(map[id.UserID]struct{}),
}
}
func (gs *MemoryStore) Flush() error {
func (gs *MemoryStore) Flush(_ context.Context) error {
gs.lock.Lock()
err := gs.save()
gs.lock.Unlock()
return err
}
func (gs *MemoryStore) GetAccount() (*OlmAccount, error) {
func (gs *MemoryStore) GetAccount(_ context.Context) (*OlmAccount, error) {
return gs.Account, nil
}
func (gs *MemoryStore) PutAccount(account *OlmAccount) error {
func (gs *MemoryStore) PutAccount(_ context.Context, account *OlmAccount) error {
gs.lock.Lock()
gs.Account = account
err := gs.save()
@@ -189,7 +195,7 @@ func (gs *MemoryStore) PutAccount(account *OlmAccount) error {
return err
}
func (gs *MemoryStore) GetSessions(senderKey id.SenderKey) (OlmSessionList, error) {
func (gs *MemoryStore) GetSessions(_ context.Context, senderKey id.SenderKey) (OlmSessionList, error) {
gs.lock.Lock()
sessions, ok := gs.Sessions[senderKey]
if !ok {
@@ -200,7 +206,7 @@ func (gs *MemoryStore) GetSessions(senderKey id.SenderKey) (OlmSessionList, erro
return sessions, nil
}
func (gs *MemoryStore) AddSession(senderKey id.SenderKey, session *OlmSession) error {
func (gs *MemoryStore) AddSession(_ context.Context, senderKey id.SenderKey, session *OlmSession) error {
gs.lock.Lock()
sessions, _ := gs.Sessions[senderKey]
gs.Sessions[senderKey] = append(sessions, session)
@@ -210,19 +216,19 @@ func (gs *MemoryStore) AddSession(senderKey id.SenderKey, session *OlmSession) e
return err
}
func (gs *MemoryStore) UpdateSession(_ id.SenderKey, _ *OlmSession) error {
func (gs *MemoryStore) UpdateSession(_ context.Context, _ id.SenderKey, _ *OlmSession) error {
// we don't need to do anything here because the session is a pointer and already stored in our map
return gs.save()
}
func (gs *MemoryStore) HasSession(senderKey id.SenderKey) bool {
func (gs *MemoryStore) HasSession(_ context.Context, senderKey id.SenderKey) bool {
gs.lock.RLock()
sessions, ok := gs.Sessions[senderKey]
gs.lock.RUnlock()
return ok && len(sessions) > 0 && !sessions[0].Expired()
}
func (gs *MemoryStore) GetLatestSession(senderKey id.SenderKey) (*OlmSession, error) {
func (gs *MemoryStore) GetLatestSession(_ context.Context, senderKey id.SenderKey) (*OlmSession, error) {
gs.lock.RLock()
sessions, ok := gs.Sessions[senderKey]
gs.lock.RUnlock()
@@ -246,7 +252,7 @@ func (gs *MemoryStore) getGroupSessions(roomID id.RoomID, senderKey id.SenderKey
return sender
}
func (gs *MemoryStore) PutGroupSession(roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID, igs *InboundGroupSession) error {
func (gs *MemoryStore) PutGroupSession(_ context.Context, roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID, igs *InboundGroupSession) error {
gs.lock.Lock()
gs.getGroupSessions(roomID, senderKey)[sessionID] = igs
err := gs.save()
@@ -254,7 +260,7 @@ func (gs *MemoryStore) PutGroupSession(roomID id.RoomID, senderKey id.SenderKey,
return err
}
func (gs *MemoryStore) GetGroupSession(roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID) (*InboundGroupSession, error) {
func (gs *MemoryStore) GetGroupSession(_ context.Context, roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID) (*InboundGroupSession, error) {
gs.lock.Lock()
session, ok := gs.getGroupSessions(roomID, senderKey)[sessionID]
if !ok {
@@ -269,7 +275,7 @@ func (gs *MemoryStore) GetGroupSession(roomID id.RoomID, senderKey id.SenderKey,
return session, nil
}
func (gs *MemoryStore) RedactGroupSession(roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID, reason string) error {
func (gs *MemoryStore) RedactGroupSession(_ context.Context, roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID, reason string) error {
gs.lock.Lock()
delete(gs.getGroupSessions(roomID, senderKey), sessionID)
err := gs.save()
@@ -277,7 +283,7 @@ func (gs *MemoryStore) RedactGroupSession(roomID id.RoomID, senderKey id.SenderK
return err
}
func (gs *MemoryStore) RedactGroupSessions(roomID id.RoomID, senderKey id.SenderKey, reason string) ([]id.SessionID, error) {
func (gs *MemoryStore) RedactGroupSessions(_ context.Context, roomID id.RoomID, senderKey id.SenderKey, reason string) ([]id.SessionID, error) {
gs.lock.Lock()
var sessionIDs []id.SessionID
if roomID != "" && senderKey != "" {
@@ -315,11 +321,11 @@ func (gs *MemoryStore) RedactGroupSessions(roomID id.RoomID, senderKey id.Sender
return sessionIDs, err
}
func (gs *MemoryStore) RedactExpiredGroupSessions() ([]id.SessionID, error) {
func (gs *MemoryStore) RedactExpiredGroupSessions(_ context.Context) ([]id.SessionID, error) {
return nil, fmt.Errorf("not implemented")
}
func (gs *MemoryStore) RedactOutdatedGroupSessions() ([]id.SessionID, error) {
func (gs *MemoryStore) RedactOutdatedGroupSessions(_ context.Context) ([]id.SessionID, error) {
return nil, fmt.Errorf("not implemented")
}
@@ -337,7 +343,7 @@ func (gs *MemoryStore) getWithheldGroupSessions(roomID id.RoomID, senderKey id.S
return sender
}
func (gs *MemoryStore) PutWithheldGroupSession(content event.RoomKeyWithheldEventContent) error {
func (gs *MemoryStore) PutWithheldGroupSession(_ context.Context, content event.RoomKeyWithheldEventContent) error {
gs.lock.Lock()
gs.getWithheldGroupSessions(content.RoomID, content.SenderKey)[content.SessionID] = &content
err := gs.save()
@@ -345,7 +351,7 @@ func (gs *MemoryStore) PutWithheldGroupSession(content event.RoomKeyWithheldEven
return err
}
func (gs *MemoryStore) GetWithheldGroupSession(roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID) (*event.RoomKeyWithheldEventContent, error) {
func (gs *MemoryStore) GetWithheldGroupSession(_ context.Context, roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID) (*event.RoomKeyWithheldEventContent, error) {
gs.lock.Lock()
session, ok := gs.getWithheldGroupSessions(roomID, senderKey)[sessionID]
gs.lock.Unlock()
@@ -355,7 +361,7 @@ func (gs *MemoryStore) GetWithheldGroupSession(roomID id.RoomID, senderKey id.Se
return session, nil
}
func (gs *MemoryStore) GetGroupSessionsForRoom(roomID id.RoomID) ([]*InboundGroupSession, error) {
func (gs *MemoryStore) GetGroupSessionsForRoom(_ context.Context, roomID id.RoomID) ([]*InboundGroupSession, error) {
gs.lock.Lock()
defer gs.lock.Unlock()
room, ok := gs.GroupSessions[roomID]
@@ -371,7 +377,7 @@ func (gs *MemoryStore) GetGroupSessionsForRoom(roomID id.RoomID) ([]*InboundGrou
return result, nil
}
func (gs *MemoryStore) GetAllGroupSessions() ([]*InboundGroupSession, error) {
func (gs *MemoryStore) GetAllGroupSessions(_ context.Context) ([]*InboundGroupSession, error) {
gs.lock.Lock()
var result []*InboundGroupSession
for _, room := range gs.GroupSessions {
@@ -385,7 +391,7 @@ func (gs *MemoryStore) GetAllGroupSessions() ([]*InboundGroupSession, error) {
return result, nil
}
func (gs *MemoryStore) AddOutboundGroupSession(session *OutboundGroupSession) error {
func (gs *MemoryStore) AddOutboundGroupSession(_ context.Context, session *OutboundGroupSession) error {
gs.lock.Lock()
gs.OutGroupSessions[session.RoomID] = session
err := gs.save()
@@ -393,12 +399,12 @@ func (gs *MemoryStore) AddOutboundGroupSession(session *OutboundGroupSession) er
return err
}
func (gs *MemoryStore) UpdateOutboundGroupSession(_ *OutboundGroupSession) error {
func (gs *MemoryStore) UpdateOutboundGroupSession(_ context.Context, _ *OutboundGroupSession) error {
// we don't need to do anything here because the session is a pointer and already stored in our map
return gs.save()
}
func (gs *MemoryStore) GetOutboundGroupSession(roomID id.RoomID) (*OutboundGroupSession, error) {
func (gs *MemoryStore) GetOutboundGroupSession(_ context.Context, roomID id.RoomID) (*OutboundGroupSession, error) {
gs.lock.RLock()
session, ok := gs.OutGroupSessions[roomID]
gs.lock.RUnlock()
@@ -408,7 +414,7 @@ func (gs *MemoryStore) GetOutboundGroupSession(roomID id.RoomID) (*OutboundGroup
return session, nil
}
func (gs *MemoryStore) RemoveOutboundGroupSession(roomID id.RoomID) error {
func (gs *MemoryStore) RemoveOutboundGroupSession(_ context.Context, roomID id.RoomID) error {
gs.lock.Lock()
session, ok := gs.OutGroupSessions[roomID]
if !ok || session == nil {
@@ -443,7 +449,7 @@ func (gs *MemoryStore) ValidateMessageIndex(_ context.Context, senderKey id.Send
return true, nil
}
func (gs *MemoryStore) GetDevices(userID id.UserID) (map[id.DeviceID]*id.Device, error) {
func (gs *MemoryStore) GetDevices(_ context.Context, userID id.UserID) (map[id.DeviceID]*id.Device, error) {
gs.lock.RLock()
devices, ok := gs.Devices[userID]
if !ok {
@@ -453,7 +459,7 @@ func (gs *MemoryStore) GetDevices(userID id.UserID) (map[id.DeviceID]*id.Device,
return devices, nil
}
func (gs *MemoryStore) GetDevice(userID id.UserID, deviceID id.DeviceID) (*id.Device, error) {
func (gs *MemoryStore) GetDevice(_ context.Context, userID id.UserID, deviceID id.DeviceID) (*id.Device, error) {
gs.lock.RLock()
defer gs.lock.RUnlock()
devices, ok := gs.Devices[userID]
@@ -467,7 +473,7 @@ func (gs *MemoryStore) GetDevice(userID id.UserID, deviceID id.DeviceID) (*id.De
return device, nil
}
func (gs *MemoryStore) FindDeviceByKey(userID id.UserID, identityKey id.IdentityKey) (*id.Device, error) {
func (gs *MemoryStore) FindDeviceByKey(_ context.Context, userID id.UserID, identityKey id.IdentityKey) (*id.Device, error) {
gs.lock.RLock()
defer gs.lock.RUnlock()
devices, ok := gs.Devices[userID]
@@ -482,7 +488,7 @@ func (gs *MemoryStore) FindDeviceByKey(userID id.UserID, identityKey id.Identity
return nil, nil
}
func (gs *MemoryStore) PutDevice(userID id.UserID, device *id.Device) error {
func (gs *MemoryStore) PutDevice(_ context.Context, userID id.UserID, device *id.Device) error {
gs.lock.Lock()
devices, ok := gs.Devices[userID]
if !ok {
@@ -495,15 +501,18 @@ func (gs *MemoryStore) PutDevice(userID id.UserID, device *id.Device) error {
return err
}
func (gs *MemoryStore) PutDevices(userID id.UserID, devices map[id.DeviceID]*id.Device) error {
func (gs *MemoryStore) PutDevices(_ context.Context, userID id.UserID, devices map[id.DeviceID]*id.Device) error {
gs.lock.Lock()
gs.Devices[userID] = devices
err := gs.save()
if err == nil {
delete(gs.OutdatedUsers, userID)
}
gs.lock.Unlock()
return err
}
func (gs *MemoryStore) FilterTrackedUsers(users []id.UserID) ([]id.UserID, error) {
func (gs *MemoryStore) FilterTrackedUsers(_ context.Context, users []id.UserID) ([]id.UserID, error) {
gs.lock.RLock()
var ptr int
for _, userID := range users {
@@ -517,7 +526,28 @@ func (gs *MemoryStore) FilterTrackedUsers(users []id.UserID) ([]id.UserID, error
return users[:ptr], nil
}
func (gs *MemoryStore) PutCrossSigningKey(userID id.UserID, usage id.CrossSigningUsage, key id.Ed25519) error {
func (gs *MemoryStore) MarkTrackedUsersOutdated(_ context.Context, users []id.UserID) error {
gs.lock.Lock()
for _, userID := range users {
if _, ok := gs.Devices[userID]; ok {
gs.OutdatedUsers[userID] = struct{}{}
}
}
gs.lock.Unlock()
return nil
}
func (gs *MemoryStore) GetOutdatedTrackedUsers(_ context.Context) ([]id.UserID, error) {
gs.lock.RLock()
users := make([]id.UserID, 0, len(gs.OutdatedUsers))
for userID := range gs.OutdatedUsers {
users = append(users, userID)
}
gs.lock.RUnlock()
return users, nil
}
func (gs *MemoryStore) PutCrossSigningKey(_ context.Context, userID id.UserID, usage id.CrossSigningUsage, key id.Ed25519) error {
gs.lock.RLock()
userKeys, ok := gs.CrossSigningKeys[userID]
if !ok {
@@ -539,7 +569,7 @@ func (gs *MemoryStore) PutCrossSigningKey(userID id.UserID, usage id.CrossSignin
return err
}
func (gs *MemoryStore) GetCrossSigningKeys(userID id.UserID) (map[id.CrossSigningUsage]id.CrossSigningKey, error) {
func (gs *MemoryStore) GetCrossSigningKeys(_ context.Context, userID id.UserID) (map[id.CrossSigningUsage]id.CrossSigningKey, error) {
gs.lock.RLock()
defer gs.lock.RUnlock()
keys, ok := gs.CrossSigningKeys[userID]
@@ -549,7 +579,7 @@ func (gs *MemoryStore) GetCrossSigningKeys(userID id.UserID) (map[id.CrossSignin
return keys, nil
}
func (gs *MemoryStore) PutSignature(signedUserID id.UserID, signedKey id.Ed25519, signerUserID id.UserID, signerKey id.Ed25519, signature string) error {
func (gs *MemoryStore) PutSignature(_ context.Context, signedUserID id.UserID, signedKey id.Ed25519, signerUserID id.UserID, signerKey id.Ed25519, signature string) error {
gs.lock.RLock()
signedUserSigs, ok := gs.KeySignatures[signedUserID]
if !ok {
@@ -572,7 +602,7 @@ func (gs *MemoryStore) PutSignature(signedUserID id.UserID, signedKey id.Ed25519
return err
}
func (gs *MemoryStore) GetSignaturesForKeyBy(userID id.UserID, key id.Ed25519, signerID id.UserID) (map[id.Ed25519]string, error) {
func (gs *MemoryStore) GetSignaturesForKeyBy(_ context.Context, userID id.UserID, key id.Ed25519, signerID id.UserID) (map[id.Ed25519]string, error) {
gs.lock.RLock()
defer gs.lock.RUnlock()
userKeys, ok := gs.KeySignatures[userID]
@@ -590,8 +620,8 @@ func (gs *MemoryStore) GetSignaturesForKeyBy(userID id.UserID, key id.Ed25519, s
return sigsBySigner, nil
}
func (gs *MemoryStore) IsKeySignedBy(userID id.UserID, key id.Ed25519, signerID id.UserID, signerKey id.Ed25519) (bool, error) {
sigs, err := gs.GetSignaturesForKeyBy(userID, key, signerID)
func (gs *MemoryStore) IsKeySignedBy(ctx context.Context, userID id.UserID, key id.Ed25519, signerID id.UserID, signerKey id.Ed25519) (bool, error) {
sigs, err := gs.GetSignaturesForKeyBy(ctx, userID, key, signerID)
if err != nil {
return false, err
}
@@ -599,7 +629,7 @@ func (gs *MemoryStore) IsKeySignedBy(userID id.UserID, key id.Ed25519, signerID
return ok, nil
}
func (gs *MemoryStore) DropSignaturesByKey(userID id.UserID, key id.Ed25519) (int64, error) {
func (gs *MemoryStore) DropSignaturesByKey(_ context.Context, userID id.UserID, key id.Ed25519) (int64, error) {
var count int64
gs.lock.RLock()
for _, userSigs := range gs.KeySignatures {

View File

@@ -10,10 +10,10 @@ import (
"crypto/aes"
"crypto/cipher"
"crypto/hmac"
"crypto/rand"
"crypto/sha256"
"crypto/sha512"
"encoding/base64"
"math/rand"
"strings"
"go.mau.fi/util/base58"

View File

@@ -54,8 +54,8 @@ const (
)
// sendToOneDevice sends a to-device event to a single device.
func (mach *OlmMachine) sendToOneDevice(userID id.UserID, deviceID id.DeviceID, eventType event.Type, content interface{}) error {
_, err := mach.Client.SendToDevice(eventType, &mautrix.ReqSendToDevice{
func (mach *OlmMachine) sendToOneDevice(ctx context.Context, userID id.UserID, deviceID id.DeviceID, eventType event.Type, content interface{}) error {
_, err := mach.Client.SendToDevice(ctx, eventType, &mautrix.ReqSendToDevice{
Messages: map[id.UserID]map[id.DeviceID]*event.Content{
userID: {
deviceID: {
@@ -118,19 +118,19 @@ type verificationState struct {
}
// getTransactionState retrieves the given transaction's state, or cancels the transaction if it cannot be found or there is a mismatch.
func (mach *OlmMachine) getTransactionState(transactionID string, userID id.UserID) (*verificationState, error) {
func (mach *OlmMachine) getTransactionState(ctx context.Context, transactionID string, userID id.UserID) (*verificationState, error) {
verStateInterface, ok := mach.keyVerificationTransactionState.Load(userID.String() + ":" + transactionID)
if !ok {
_ = mach.SendSASVerificationCancel(userID, id.DeviceID("*"), transactionID, "Unknown transaction: "+transactionID, event.VerificationCancelUnknownTransaction)
_ = mach.SendSASVerificationCancel(ctx, userID, id.DeviceID("*"), transactionID, "Unknown transaction: "+transactionID, event.VerificationCancelUnknownTransaction)
return nil, ErrUnknownTransaction
}
verState := verStateInterface.(*verificationState)
if verState.otherDevice.UserID != userID {
reason := fmt.Sprintf("Unknown user for transaction %v: %v", transactionID, userID)
if verState.inRoomID == "" {
_ = mach.SendSASVerificationCancel(userID, id.DeviceID("*"), transactionID, reason, event.VerificationCancelUserMismatch)
_ = mach.SendSASVerificationCancel(ctx, userID, id.DeviceID("*"), transactionID, reason, event.VerificationCancelUserMismatch)
} else {
_ = mach.SendInRoomSASVerificationCancel(verState.inRoomID, userID, transactionID, reason, event.VerificationCancelUserMismatch)
_ = mach.SendInRoomSASVerificationCancel(ctx, verState.inRoomID, userID, transactionID, reason, event.VerificationCancelUserMismatch)
}
mach.keyVerificationTransactionState.Delete(userID.String() + ":" + transactionID)
return nil, fmt.Errorf("%w %s: %s", ErrUnknownUserForTransaction, transactionID, userID)
@@ -140,9 +140,9 @@ func (mach *OlmMachine) getTransactionState(transactionID string, userID id.User
// handleVerificationStart handles an incoming m.key.verification.start message.
// It initializes the state for this SAS verification process and stores it.
func (mach *OlmMachine) handleVerificationStart(userID id.UserID, content *event.VerificationStartEventContent, transactionID string, timeout time.Duration, inRoomID id.RoomID) {
func (mach *OlmMachine) handleVerificationStart(ctx context.Context, userID id.UserID, content *event.VerificationStartEventContent, transactionID string, timeout time.Duration, inRoomID id.RoomID) {
mach.Log.Debug().Msgf("Received verification start from %v", content.FromDevice)
otherDevice, err := mach.GetOrFetchDevice(context.TODO(), userID, content.FromDevice)
otherDevice, err := mach.GetOrFetchDevice(ctx, userID, content.FromDevice)
if err != nil {
mach.Log.Error().Msgf("Could not find device %v of user %v", content.FromDevice, userID)
return
@@ -150,9 +150,9 @@ func (mach *OlmMachine) handleVerificationStart(userID id.UserID, content *event
warnAndCancel := func(logReason, cancelReason string) {
mach.Log.Warn().Msgf("Canceling verification transaction %v as it %s", transactionID, logReason)
if inRoomID == "" {
_ = mach.SendSASVerificationCancel(otherDevice.UserID, otherDevice.DeviceID, transactionID, cancelReason, event.VerificationCancelUnknownMethod)
_ = mach.SendSASVerificationCancel(ctx, otherDevice.UserID, otherDevice.DeviceID, transactionID, cancelReason, event.VerificationCancelUnknownMethod)
} else {
_ = mach.SendInRoomSASVerificationCancel(inRoomID, otherDevice.UserID, transactionID, cancelReason, event.VerificationCancelUnknownMethod)
_ = mach.SendInRoomSASVerificationCancel(ctx, inRoomID, otherDevice.UserID, transactionID, cancelReason, event.VerificationCancelUnknownMethod)
}
}
switch {
@@ -168,21 +168,21 @@ func (mach *OlmMachine) handleVerificationStart(userID id.UserID, content *event
case !content.SupportsSASMethod(event.SASDecimal):
warnAndCancel("does not support decimal SAS", "Decimal SAS method must be supported")
default:
mach.actuallyStartVerification(userID, content, otherDevice, transactionID, timeout, inRoomID)
mach.actuallyStartVerification(ctx, userID, content, otherDevice, transactionID, timeout, inRoomID)
}
}
func (mach *OlmMachine) actuallyStartVerification(userID id.UserID, content *event.VerificationStartEventContent, otherDevice *id.Device, transactionID string, timeout time.Duration, inRoomID id.RoomID) {
func (mach *OlmMachine) actuallyStartVerification(ctx context.Context, userID id.UserID, content *event.VerificationStartEventContent, otherDevice *id.Device, transactionID string, timeout time.Duration, inRoomID id.RoomID) {
if inRoomID != "" && transactionID != "" {
verState, err := mach.getTransactionState(transactionID, userID)
verState, err := mach.getTransactionState(ctx, transactionID, userID)
if err != nil {
mach.Log.Error().Msgf("Failed to get transaction state for in-room verification %s start: %v", transactionID, err)
_ = mach.SendInRoomSASVerificationCancel(inRoomID, otherDevice.UserID, transactionID, "Internal state error in gomuks :(", "net.maunium.internal_error")
_ = mach.SendInRoomSASVerificationCancel(ctx, inRoomID, otherDevice.UserID, transactionID, "Internal state error in gomuks :(", "net.maunium.internal_error")
return
}
mach.timeoutAfter(verState, transactionID, timeout)
mach.timeoutAfter(ctx, verState, transactionID, timeout)
sasMethods := commonSASMethods(verState.hooks, content.ShortAuthenticationString)
err = mach.SendInRoomSASVerificationAccept(inRoomID, userID, content, transactionID, verState.sas.GetPubkey(), sasMethods)
err = mach.SendInRoomSASVerificationAccept(ctx, inRoomID, userID, content, transactionID, verState.sas.GetPubkey(), sasMethods)
if err != nil {
mach.Log.Error().Msgf("Error accepting in-room SAS verification: %v", err)
}
@@ -196,9 +196,9 @@ func (mach *OlmMachine) actuallyStartVerification(userID id.UserID, content *eve
if len(sasMethods) == 0 {
mach.Log.Error().Msgf("No common SAS methods: %v", content.ShortAuthenticationString)
if inRoomID == "" {
_ = mach.SendSASVerificationCancel(otherDevice.UserID, otherDevice.DeviceID, transactionID, "No common SAS methods", event.VerificationCancelUnknownMethod)
_ = mach.SendSASVerificationCancel(ctx, otherDevice.UserID, otherDevice.DeviceID, transactionID, "No common SAS methods", event.VerificationCancelUnknownMethod)
} else {
_ = mach.SendInRoomSASVerificationCancel(inRoomID, otherDevice.UserID, transactionID, "No common SAS methods", event.VerificationCancelUnknownMethod)
_ = mach.SendInRoomSASVerificationCancel(ctx, inRoomID, otherDevice.UserID, transactionID, "No common SAS methods", event.VerificationCancelUnknownMethod)
}
return
}
@@ -221,20 +221,20 @@ func (mach *OlmMachine) actuallyStartVerification(userID id.UserID, content *eve
// transaction already exists
mach.Log.Error().Msgf("Transaction %v already exists, canceling", transactionID)
if inRoomID == "" {
_ = mach.SendSASVerificationCancel(otherDevice.UserID, otherDevice.DeviceID, transactionID, "Transaction already exists", event.VerificationCancelUnexpectedMessage)
_ = mach.SendSASVerificationCancel(ctx, otherDevice.UserID, otherDevice.DeviceID, transactionID, "Transaction already exists", event.VerificationCancelUnexpectedMessage)
} else {
_ = mach.SendInRoomSASVerificationCancel(inRoomID, otherDevice.UserID, transactionID, "Transaction already exists", event.VerificationCancelUnexpectedMessage)
_ = mach.SendInRoomSASVerificationCancel(ctx, inRoomID, otherDevice.UserID, transactionID, "Transaction already exists", event.VerificationCancelUnexpectedMessage)
}
return
}
mach.timeoutAfter(verState, transactionID, timeout)
mach.timeoutAfter(ctx, verState, transactionID, timeout)
var err error
if inRoomID == "" {
err = mach.SendSASVerificationAccept(userID, content, verState.sas.GetPubkey(), sasMethods)
err = mach.SendSASVerificationAccept(ctx, userID, content, verState.sas.GetPubkey(), sasMethods)
} else {
err = mach.SendInRoomSASVerificationAccept(inRoomID, userID, content, transactionID, verState.sas.GetPubkey(), sasMethods)
err = mach.SendInRoomSASVerificationAccept(ctx, inRoomID, userID, content, transactionID, verState.sas.GetPubkey(), sasMethods)
}
if err != nil {
mach.Log.Error().Msgf("Error accepting SAS verification: %v", err)
@@ -243,9 +243,9 @@ func (mach *OlmMachine) actuallyStartVerification(userID id.UserID, content *eve
mach.Log.Debug().Msgf("Not accepting SAS verification %v from %v of user %v", transactionID, otherDevice.DeviceID, otherDevice.UserID)
var err error
if inRoomID == "" {
err = mach.SendSASVerificationCancel(otherDevice.UserID, otherDevice.DeviceID, transactionID, "Not accepted by user", event.VerificationCancelByUser)
err = mach.SendSASVerificationCancel(ctx, otherDevice.UserID, otherDevice.DeviceID, transactionID, "Not accepted by user", event.VerificationCancelByUser)
} else {
err = mach.SendInRoomSASVerificationCancel(inRoomID, otherDevice.UserID, transactionID, "Not accepted by user", event.VerificationCancelByUser)
err = mach.SendInRoomSASVerificationCancel(ctx, inRoomID, otherDevice.UserID, transactionID, "Not accepted by user", event.VerificationCancelByUser)
}
if err != nil {
mach.Log.Error().Msgf("Error canceling SAS verification: %v", err)
@@ -255,8 +255,8 @@ func (mach *OlmMachine) actuallyStartVerification(userID id.UserID, content *eve
}
}
func (mach *OlmMachine) timeoutAfter(verState *verificationState, transactionID string, timeout time.Duration) {
timeoutCtx, timeoutCancel := context.WithTimeout(context.Background(), timeout)
func (mach *OlmMachine) timeoutAfter(ctx context.Context, verState *verificationState, transactionID string, timeout time.Duration) {
timeoutCtx, timeoutCancel := context.WithTimeout(ctx, timeout)
verState.extendTimeout = timeoutCancel
go func() {
mapKey := verState.otherDevice.UserID.String() + ":" + transactionID
@@ -272,7 +272,7 @@ func (mach *OlmMachine) timeoutAfter(verState *verificationState, transactionID
if timeoutCtx.Err() == context.DeadlineExceeded {
// if deadline exceeded cancel due to timeout
mach.keyVerificationTransactionState.Delete(mapKey)
_ = mach.callbackAndCancelSASVerification(verState, transactionID, "Timed out", event.VerificationCancelByTimeout)
_ = mach.callbackAndCancelSASVerification(ctx, verState, transactionID, "Timed out", event.VerificationCancelByTimeout)
mach.Log.Warn().Msgf("Verification transaction %v is canceled due to timing out", transactionID)
verState.lock.Unlock()
return
@@ -288,9 +288,9 @@ func (mach *OlmMachine) timeoutAfter(verState *verificationState, transactionID
// handleVerificationAccept handles an incoming m.key.verification.accept message.
// It continues the SAS verification process by sending the SAS key message to the other device.
func (mach *OlmMachine) handleVerificationAccept(userID id.UserID, content *event.VerificationAcceptEventContent, transactionID string) {
func (mach *OlmMachine) handleVerificationAccept(ctx context.Context, userID id.UserID, content *event.VerificationAcceptEventContent, transactionID string) {
mach.Log.Debug().Msgf("Received verification accept for transaction %v", transactionID)
verState, err := mach.getTransactionState(transactionID, userID)
verState, err := mach.getTransactionState(ctx, transactionID, userID)
if err != nil {
mach.Log.Error().Msgf("Error getting transaction state: %v", err)
return
@@ -303,7 +303,7 @@ func (mach *OlmMachine) handleVerificationAccept(userID id.UserID, content *even
// unexpected accept at this point
mach.Log.Warn().Msgf("Unexpected verification accept message for transaction %v", transactionID)
mach.keyVerificationTransactionState.Delete(userID.String() + ":" + transactionID)
_ = mach.callbackAndCancelSASVerification(verState, transactionID, "Unexpected accept message", event.VerificationCancelUnexpectedMessage)
_ = mach.callbackAndCancelSASVerification(ctx, verState, transactionID, "Unexpected accept message", event.VerificationCancelUnexpectedMessage)
return
}
@@ -315,7 +315,7 @@ func (mach *OlmMachine) handleVerificationAccept(userID id.UserID, content *even
mach.Log.Warn().Msgf("Canceling verification transaction %v due to unknown parameter", transactionID)
mach.keyVerificationTransactionState.Delete(userID.String() + ":" + transactionID)
_ = mach.callbackAndCancelSASVerification(verState, transactionID, "Verification uses unknown method", event.VerificationCancelUnknownMethod)
_ = mach.callbackAndCancelSASVerification(ctx, verState, transactionID, "Verification uses unknown method", event.VerificationCancelUnknownMethod)
return
}
@@ -325,9 +325,9 @@ func (mach *OlmMachine) handleVerificationAccept(userID id.UserID, content *even
verState.verificationStarted = true
if verState.inRoomID == "" {
err = mach.SendSASVerificationKey(userID, verState.otherDevice.DeviceID, transactionID, string(key))
err = mach.SendSASVerificationKey(ctx, userID, verState.otherDevice.DeviceID, transactionID, string(key))
} else {
err = mach.SendInRoomSASVerificationKey(verState.inRoomID, userID, transactionID, string(key))
err = mach.SendInRoomSASVerificationKey(ctx, verState.inRoomID, userID, transactionID, string(key))
}
if err != nil {
mach.Log.Error().Msgf("Error sending SAS key to other device: %v", err)
@@ -337,9 +337,9 @@ func (mach *OlmMachine) handleVerificationAccept(userID id.UserID, content *even
// handleVerificationKey handles an incoming m.key.verification.key message.
// It stores the other device's public key in order to acquire the SAS shared secret.
func (mach *OlmMachine) handleVerificationKey(userID id.UserID, content *event.VerificationKeyEventContent, transactionID string) {
func (mach *OlmMachine) handleVerificationKey(ctx context.Context, userID id.UserID, content *event.VerificationKeyEventContent, transactionID string) {
mach.Log.Debug().Msgf("Got verification key for transaction %v: %v", transactionID, content.Key)
verState, err := mach.getTransactionState(transactionID, userID)
verState, err := mach.getTransactionState(ctx, transactionID, userID)
if err != nil {
mach.Log.Error().Msgf("Error getting transaction state: %v", err)
return
@@ -354,7 +354,7 @@ func (mach *OlmMachine) handleVerificationKey(userID id.UserID, content *event.V
// unexpected key at this point
mach.Log.Warn().Msgf("Unexpected verification key message for transaction %v", transactionID)
mach.keyVerificationTransactionState.Delete(userID.String() + ":" + transactionID)
_ = mach.callbackAndCancelSASVerification(verState, transactionID, "Unexpected key message", event.VerificationCancelUnexpectedMessage)
_ = mach.callbackAndCancelSASVerification(ctx, verState, transactionID, "Unexpected key message", event.VerificationCancelUnexpectedMessage)
return
}
@@ -372,7 +372,7 @@ func (mach *OlmMachine) handleVerificationKey(userID id.UserID, content *event.V
if expectedCommitment != verState.commitment {
mach.Log.Warn().Msgf("Canceling verification transaction %v due to commitment mismatch", transactionID)
mach.keyVerificationTransactionState.Delete(userID.String() + ":" + transactionID)
_ = mach.callbackAndCancelSASVerification(verState, transactionID, "Commitment mismatch", event.VerificationCancelCommitmentMismatch)
_ = mach.callbackAndCancelSASVerification(ctx, verState, transactionID, "Commitment mismatch", event.VerificationCancelCommitmentMismatch)
return
}
} else {
@@ -380,9 +380,9 @@ func (mach *OlmMachine) handleVerificationKey(userID id.UserID, content *event.V
key := verState.sas.GetPubkey()
if verState.inRoomID == "" {
err = mach.SendSASVerificationKey(userID, device.DeviceID, transactionID, string(key))
err = mach.SendSASVerificationKey(ctx, userID, device.DeviceID, transactionID, string(key))
} else {
err = mach.SendInRoomSASVerificationKey(verState.inRoomID, userID, transactionID, string(key))
err = mach.SendInRoomSASVerificationKey(ctx, verState.inRoomID, userID, transactionID, string(key))
}
if err != nil {
mach.Log.Error().Msgf("Error sending SAS key to other device: %v", err)
@@ -419,13 +419,13 @@ func (mach *OlmMachine) handleVerificationKey(userID id.UserID, content *event.V
mach.Log.Debug().Msgf("Generated SAS (%v): %v", sasMethod.Type(), sas)
go func() {
result := verState.hooks.VerifySASMatch(device, sas)
mach.sasCompared(result, transactionID, verState)
mach.sasCompared(ctx, result, transactionID, verState)
}()
}
// sasCompared is called asynchronously. It waits for the SAS to be compared for the verification to proceed.
// If the SAS match, then our MAC is sent out. Otherwise the transaction is canceled.
func (mach *OlmMachine) sasCompared(didMatch bool, transactionID string, verState *verificationState) {
func (mach *OlmMachine) sasCompared(ctx context.Context, didMatch bool, transactionID string, verState *verificationState) {
verState.lock.Lock()
defer verState.lock.Unlock()
verState.extendTimeout()
@@ -433,9 +433,9 @@ func (mach *OlmMachine) sasCompared(didMatch bool, transactionID string, verStat
verState.sasMatched <- true
var err error
if verState.inRoomID == "" {
err = mach.SendSASVerificationMAC(verState.otherDevice.UserID, verState.otherDevice.DeviceID, transactionID, verState.sas)
err = mach.SendSASVerificationMAC(ctx, verState.otherDevice.UserID, verState.otherDevice.DeviceID, transactionID, verState.sas)
} else {
err = mach.SendInRoomSASVerificationMAC(verState.inRoomID, verState.otherDevice.UserID, verState.otherDevice.DeviceID, transactionID, verState.sas)
err = mach.SendInRoomSASVerificationMAC(ctx, verState.inRoomID, verState.otherDevice.UserID, verState.otherDevice.DeviceID, transactionID, verState.sas)
}
if err != nil {
mach.Log.Error().Msgf("Error sending verification MAC to other device: %v", err)
@@ -447,9 +447,9 @@ func (mach *OlmMachine) sasCompared(didMatch bool, transactionID string, verStat
// handleVerificationMAC handles an incoming m.key.verification.mac message.
// It verifies the other device's MAC and if the MAC is valid it marks the device as trusted.
func (mach *OlmMachine) handleVerificationMAC(userID id.UserID, content *event.VerificationMacEventContent, transactionID string) {
func (mach *OlmMachine) handleVerificationMAC(ctx context.Context, userID id.UserID, content *event.VerificationMacEventContent, transactionID string) {
mach.Log.Debug().Msgf("Got MAC for verification %v: %v, MAC for keys: %v", transactionID, content.Mac, content.Keys)
verState, err := mach.getTransactionState(transactionID, userID)
verState, err := mach.getTransactionState(ctx, transactionID, userID)
if err != nil {
mach.Log.Error().Msgf("Error getting transaction state: %v", err)
return
@@ -466,7 +466,7 @@ func (mach *OlmMachine) handleVerificationMAC(userID id.UserID, content *event.V
if !verState.verificationStarted || !verState.keyReceived {
// unexpected MAC at this point
mach.Log.Warn().Msgf("Unexpected MAC message for transaction %v", transactionID)
_ = mach.callbackAndCancelSASVerification(verState, transactionID, "Unexpected MAC message", event.VerificationCancelUnexpectedMessage)
_ = mach.callbackAndCancelSASVerification(ctx, verState, transactionID, "Unexpected MAC message", event.VerificationCancelUnexpectedMessage)
return
}
@@ -478,7 +478,7 @@ func (mach *OlmMachine) handleVerificationMAC(userID id.UserID, content *event.V
if !matched {
mach.Log.Warn().Msgf("SAS do not match! Canceling transaction %v", transactionID)
_ = mach.callbackAndCancelSASVerification(verState, transactionID, "SAS do not match", event.VerificationCancelSASMismatch)
_ = mach.callbackAndCancelSASVerification(ctx, verState, transactionID, "SAS do not match", event.VerificationCancelSASMismatch)
return
}
@@ -494,38 +494,38 @@ func (mach *OlmMachine) handleVerificationMAC(userID id.UserID, content *event.V
mach.Log.Debug().Msgf("Expected %s keys MAC, got %s", expectedKeysMAC, content.Keys)
if content.Keys != expectedKeysMAC {
mach.Log.Warn().Msgf("Canceling verification transaction %v due to mismatched keys MAC", transactionID)
_ = mach.callbackAndCancelSASVerification(verState, transactionID, "Mismatched keys MACs", event.VerificationCancelKeyMismatch)
_ = mach.callbackAndCancelSASVerification(ctx, verState, transactionID, "Mismatched keys MACs", event.VerificationCancelKeyMismatch)
return
}
mach.Log.Debug().Msgf("Expected %s PK MAC, got %s", expectedPKMAC, content.Mac[keyID])
if content.Mac[keyID] != expectedPKMAC {
mach.Log.Warn().Msgf("Canceling verification transaction %v due to mismatched PK MAC", transactionID)
_ = mach.callbackAndCancelSASVerification(verState, transactionID, "Mismatched PK MACs", event.VerificationCancelKeyMismatch)
_ = mach.callbackAndCancelSASVerification(ctx, verState, transactionID, "Mismatched PK MACs", event.VerificationCancelKeyMismatch)
return
}
// we can finally trust this device
device.Trust = id.TrustStateVerified
err = mach.CryptoStore.PutDevice(device.UserID, device)
err = mach.CryptoStore.PutDevice(ctx, device.UserID, device)
if err != nil {
mach.Log.Warn().Msgf("Failed to put device after verifying: %v", err)
}
if mach.CrossSigningKeys != nil {
if device.UserID == mach.Client.UserID {
err := mach.SignOwnDevice(device)
err := mach.SignOwnDevice(ctx, device)
if err != nil {
mach.Log.Error().Msgf("Failed to cross-sign own device %s: %v", device.DeviceID, err)
} else {
mach.Log.Debug().Msgf("Cross-signed own device %v after SAS verification", device.DeviceID)
}
} else {
masterKey, err := mach.fetchMasterKey(device, content, verState, transactionID)
masterKey, err := mach.fetchMasterKey(ctx, device, content, verState, transactionID)
if err != nil {
mach.Log.Warn().Msgf("Failed to fetch %s's master key: %v", device.UserID, err)
} else {
if err := mach.SignUser(device.UserID, masterKey); err != nil {
if err := mach.SignUser(ctx, device.UserID, masterKey); err != nil {
mach.Log.Error().Msgf("Failed to cross-sign master key of %s: %v", device.UserID, err)
} else {
mach.Log.Debug().Msgf("Cross-signed master key of %v after SAS verification", device.UserID)
@@ -559,9 +559,9 @@ func (mach *OlmMachine) handleVerificationCancel(userID id.UserID, content *even
}
// handleVerificationRequest handles an incoming m.key.verification.request message.
func (mach *OlmMachine) handleVerificationRequest(userID id.UserID, content *event.VerificationRequestEventContent, transactionID string, inRoomID id.RoomID) {
func (mach *OlmMachine) handleVerificationRequest(ctx context.Context, userID id.UserID, content *event.VerificationRequestEventContent, transactionID string, inRoomID id.RoomID) {
mach.Log.Debug().Msgf("Received verification request from %v", content.FromDevice)
otherDevice, err := mach.GetOrFetchDevice(context.TODO(), userID, content.FromDevice)
otherDevice, err := mach.GetOrFetchDevice(ctx, userID, content.FromDevice)
if err != nil {
mach.Log.Error().Msgf("Could not find device %v of user %v", content.FromDevice, userID)
return
@@ -569,9 +569,9 @@ func (mach *OlmMachine) handleVerificationRequest(userID id.UserID, content *eve
if !content.SupportsVerificationMethod(event.VerificationMethodSAS) {
mach.Log.Warn().Msgf("Canceling verification transaction %v as SAS is not supported", transactionID)
if inRoomID == "" {
_ = mach.SendSASVerificationCancel(otherDevice.UserID, otherDevice.DeviceID, transactionID, "Only SAS method is supported", event.VerificationCancelUnknownMethod)
_ = mach.SendSASVerificationCancel(ctx, otherDevice.UserID, otherDevice.DeviceID, transactionID, "Only SAS method is supported", event.VerificationCancelUnknownMethod)
} else {
_ = mach.SendInRoomSASVerificationCancel(inRoomID, otherDevice.UserID, transactionID, "Only SAS method is supported", event.VerificationCancelUnknownMethod)
_ = mach.SendInRoomSASVerificationCancel(ctx, inRoomID, otherDevice.UserID, transactionID, "Only SAS method is supported", event.VerificationCancelUnknownMethod)
}
return
}
@@ -579,14 +579,14 @@ func (mach *OlmMachine) handleVerificationRequest(userID id.UserID, content *eve
if resp == AcceptRequest {
mach.Log.Debug().Msgf("Accepting SAS verification %v from %v of user %v", transactionID, otherDevice.DeviceID, otherDevice.UserID)
if inRoomID == "" {
_, err = mach.NewSASVerificationWith(otherDevice, hooks, transactionID, mach.DefaultSASTimeout)
_, err = mach.NewSASVerificationWith(ctx, otherDevice, hooks, transactionID, mach.DefaultSASTimeout)
} else {
if err := mach.SendInRoomSASVerificationReady(inRoomID, transactionID); err != nil {
if err := mach.SendInRoomSASVerificationReady(ctx, inRoomID, transactionID); err != nil {
mach.Log.Error().Msgf("Error sending in-room SAS verification ready: %v", err)
}
if mach.Client.UserID < otherDevice.UserID {
// up to us to send the start message
_, err = mach.newInRoomSASVerificationWithInner(inRoomID, otherDevice, hooks, transactionID, mach.DefaultSASTimeout)
_, err = mach.newInRoomSASVerificationWithInner(ctx, inRoomID, otherDevice, hooks, transactionID, mach.DefaultSASTimeout)
}
}
if err != nil {
@@ -595,9 +595,9 @@ func (mach *OlmMachine) handleVerificationRequest(userID id.UserID, content *eve
} else if resp == RejectRequest {
mach.Log.Debug().Msgf("Rejecting SAS verification %v from %v of user %v", transactionID, otherDevice.DeviceID, otherDevice.UserID)
if inRoomID == "" {
_ = mach.SendSASVerificationCancel(otherDevice.UserID, otherDevice.DeviceID, transactionID, "Not accepted by user", event.VerificationCancelByUser)
_ = mach.SendSASVerificationCancel(ctx, otherDevice.UserID, otherDevice.DeviceID, transactionID, "Not accepted by user", event.VerificationCancelByUser)
} else {
_ = mach.SendInRoomSASVerificationCancel(inRoomID, otherDevice.UserID, transactionID, "Not accepted by user", event.VerificationCancelByUser)
_ = mach.SendInRoomSASVerificationCancel(ctx, inRoomID, otherDevice.UserID, transactionID, "Not accepted by user", event.VerificationCancelByUser)
}
} else {
mach.Log.Debug().Msgf("Ignoring SAS verification %v from %v of user %v", transactionID, otherDevice.DeviceID, otherDevice.UserID)
@@ -606,14 +606,14 @@ func (mach *OlmMachine) handleVerificationRequest(userID id.UserID, content *eve
// NewSimpleSASVerificationWith starts the SAS verification process with another device with a default timeout,
// a generated transaction ID and support for both emoji and decimal SAS methods.
func (mach *OlmMachine) NewSimpleSASVerificationWith(device *id.Device, hooks VerificationHooks) (string, error) {
return mach.NewSASVerificationWith(device, hooks, "", mach.DefaultSASTimeout)
func (mach *OlmMachine) NewSimpleSASVerificationWith(ctx context.Context, device *id.Device, hooks VerificationHooks) (string, error) {
return mach.NewSASVerificationWith(ctx, device, hooks, "", mach.DefaultSASTimeout)
}
// NewSASVerificationWith starts the SAS verification process with another device.
// If the other device accepts the verification transaction, the methods in `hooks` will be used to verify the SAS match and to complete the transaction..
// If the transaction ID is empty, a new one is generated.
func (mach *OlmMachine) NewSASVerificationWith(device *id.Device, hooks VerificationHooks, transactionID string, timeout time.Duration) (string, error) {
func (mach *OlmMachine) NewSASVerificationWith(ctx context.Context, device *id.Device, hooks VerificationHooks, transactionID string, timeout time.Duration) (string, error) {
if transactionID == "" {
transactionID = strconv.Itoa(rand.Int())
}
@@ -631,7 +631,7 @@ func (mach *OlmMachine) NewSASVerificationWith(device *id.Device, hooks Verifica
verState.lock.Lock()
defer verState.lock.Unlock()
startEvent, err := mach.SendSASVerificationStart(device.UserID, device.DeviceID, transactionID, hooks.VerificationMethods())
startEvent, err := mach.SendSASVerificationStart(ctx, device.UserID, device.DeviceID, transactionID, hooks.VerificationMethods())
if err != nil {
return "", err
}
@@ -651,13 +651,13 @@ func (mach *OlmMachine) NewSASVerificationWith(device *id.Device, hooks Verifica
return "", ErrTransactionAlreadyExists
}
mach.timeoutAfter(verState, transactionID, timeout)
mach.timeoutAfter(ctx, verState, transactionID, timeout)
return transactionID, nil
}
// CancelSASVerification is used by the user to cancel a SAS verification process with the given reason.
func (mach *OlmMachine) CancelSASVerification(userID id.UserID, transactionID, reason string) error {
func (mach *OlmMachine) CancelSASVerification(ctx context.Context, userID id.UserID, transactionID, reason string) error {
mapKey := userID.String() + ":" + transactionID
verStateInterface, ok := mach.keyVerificationTransactionState.Load(mapKey)
if !ok {
@@ -668,21 +668,21 @@ func (mach *OlmMachine) CancelSASVerification(userID id.UserID, transactionID, r
defer verState.lock.Unlock()
mach.Log.Trace().Msgf("User canceled verification transaction %v with reason: %v", transactionID, reason)
mach.keyVerificationTransactionState.Delete(mapKey)
return mach.callbackAndCancelSASVerification(verState, transactionID, reason, event.VerificationCancelByUser)
return mach.callbackAndCancelSASVerification(ctx, verState, transactionID, reason, event.VerificationCancelByUser)
}
// SendSASVerificationCancel is used to manually send a SAS cancel message process with the given reason and cancellation code.
func (mach *OlmMachine) SendSASVerificationCancel(userID id.UserID, deviceID id.DeviceID, transactionID string, reason string, code event.VerificationCancelCode) error {
func (mach *OlmMachine) SendSASVerificationCancel(ctx context.Context, userID id.UserID, deviceID id.DeviceID, transactionID string, reason string, code event.VerificationCancelCode) error {
content := &event.VerificationCancelEventContent{
TransactionID: transactionID,
Reason: reason,
Code: code,
}
return mach.sendToOneDevice(userID, deviceID, event.ToDeviceVerificationCancel, content)
return mach.sendToOneDevice(ctx, userID, deviceID, event.ToDeviceVerificationCancel, content)
}
// SendSASVerificationStart is used to manually send the SAS verification start message to another device.
func (mach *OlmMachine) SendSASVerificationStart(toUserID id.UserID, toDeviceID id.DeviceID, transactionID string, methods []VerificationMethod) (*event.VerificationStartEventContent, error) {
func (mach *OlmMachine) SendSASVerificationStart(ctx context.Context, toUserID id.UserID, toDeviceID id.DeviceID, transactionID string, methods []VerificationMethod) (*event.VerificationStartEventContent, error) {
sasMethods := make([]event.SASMethod, len(methods))
for i, method := range methods {
sasMethods[i] = method.Type()
@@ -696,14 +696,14 @@ func (mach *OlmMachine) SendSASVerificationStart(toUserID id.UserID, toDeviceID
MessageAuthenticationCodes: []event.MACMethod{event.HKDFHMACSHA256},
ShortAuthenticationString: sasMethods,
}
return content, mach.sendToOneDevice(toUserID, toDeviceID, event.ToDeviceVerificationStart, content)
return content, mach.sendToOneDevice(ctx, toUserID, toDeviceID, event.ToDeviceVerificationStart, content)
}
// SendSASVerificationAccept is used to manually send an accept for a SAS verification process from a received m.key.verification.start event.
func (mach *OlmMachine) SendSASVerificationAccept(fromUser id.UserID, startEvent *event.VerificationStartEventContent, publicKey []byte, methods []VerificationMethod) error {
func (mach *OlmMachine) SendSASVerificationAccept(ctx context.Context, fromUser id.UserID, startEvent *event.VerificationStartEventContent, publicKey []byte, methods []VerificationMethod) error {
if startEvent.Method != event.VerificationMethodSAS {
reason := "Unknown verification method: " + string(startEvent.Method)
if err := mach.SendSASVerificationCancel(fromUser, startEvent.FromDevice, startEvent.TransactionID, reason, event.VerificationCancelUnknownMethod); err != nil {
if err := mach.SendSASVerificationCancel(ctx, fromUser, startEvent.FromDevice, startEvent.TransactionID, reason, event.VerificationCancelUnknownMethod); err != nil {
return err
}
return ErrUnknownVerificationMethod
@@ -730,25 +730,25 @@ func (mach *OlmMachine) SendSASVerificationAccept(fromUser id.UserID, startEvent
ShortAuthenticationString: sasMethods,
Commitment: hash,
}
return mach.sendToOneDevice(fromUser, startEvent.FromDevice, event.ToDeviceVerificationAccept, content)
return mach.sendToOneDevice(ctx, fromUser, startEvent.FromDevice, event.ToDeviceVerificationAccept, content)
}
func (mach *OlmMachine) callbackAndCancelSASVerification(verState *verificationState, transactionID, reason string, code event.VerificationCancelCode) error {
func (mach *OlmMachine) callbackAndCancelSASVerification(ctx context.Context, verState *verificationState, transactionID, reason string, code event.VerificationCancelCode) error {
go verState.hooks.OnCancel(true, reason, code)
return mach.SendSASVerificationCancel(verState.otherDevice.UserID, verState.otherDevice.DeviceID, transactionID, reason, code)
return mach.SendSASVerificationCancel(ctx, verState.otherDevice.UserID, verState.otherDevice.DeviceID, transactionID, reason, code)
}
// SendSASVerificationKey sends the ephemeral public key for a device to the partner device.
func (mach *OlmMachine) SendSASVerificationKey(userID id.UserID, deviceID id.DeviceID, transactionID string, key string) error {
func (mach *OlmMachine) SendSASVerificationKey(ctx context.Context, userID id.UserID, deviceID id.DeviceID, transactionID string, key string) error {
content := &event.VerificationKeyEventContent{
TransactionID: transactionID,
Key: key,
}
return mach.sendToOneDevice(userID, deviceID, event.ToDeviceVerificationKey, content)
return mach.sendToOneDevice(ctx, userID, deviceID, event.ToDeviceVerificationKey, content)
}
// SendSASVerificationMAC is use the MAC of a device's key to the partner device.
func (mach *OlmMachine) SendSASVerificationMAC(userID id.UserID, deviceID id.DeviceID, transactionID string, sas *olm.SAS) error {
func (mach *OlmMachine) SendSASVerificationMAC(ctx context.Context, userID id.UserID, deviceID id.DeviceID, transactionID string, sas *olm.SAS) error {
keyID := id.NewKeyID(id.KeyAlgorithmEd25519, mach.Client.DeviceID.String())
signingKey := mach.account.SigningKey()
@@ -784,7 +784,7 @@ func (mach *OlmMachine) SendSASVerificationMAC(userID id.UserID, deviceID id.Dev
Mac: macMap,
}
return mach.sendToOneDevice(userID, deviceID, event.ToDeviceVerificationMAC, content)
return mach.sendToOneDevice(ctx, userID, deviceID, event.ToDeviceVerificationMAC, content)
}
func commonSASMethods(hooks VerificationHooks, otherDeviceMethods []event.SASMethod) []VerificationMethod {

View File

@@ -38,6 +38,7 @@ func (mach *OlmMachine) ProcessInRoomVerification(evt *event.Event) error {
return ErrNoRelatesTo
}
ctx := context.TODO()
switch content := evt.Content.Parsed.(type) {
case *event.MessageEventContent:
if content.MsgType == event.MsgVerificationRequest {
@@ -54,18 +55,18 @@ func (mach *OlmMachine) ProcessInRoomVerification(evt *event.Event) error {
Timestamp: evt.Timestamp,
TransactionID: evt.ID.String(),
}
mach.handleVerificationRequest(evt.Sender, newContent, evt.ID.String(), evt.RoomID)
mach.handleVerificationRequest(ctx, evt.Sender, newContent, evt.ID.String(), evt.RoomID)
}
case *event.VerificationStartEventContent:
mach.handleVerificationStart(evt.Sender, content, content.RelatesTo.EventID.String(), 10*time.Minute, evt.RoomID)
mach.handleVerificationStart(ctx, evt.Sender, content, content.RelatesTo.EventID.String(), 10*time.Minute, evt.RoomID)
case *event.VerificationReadyEventContent:
mach.handleInRoomVerificationReady(evt.Sender, evt.RoomID, content, content.RelatesTo.EventID.String())
mach.handleInRoomVerificationReady(ctx, evt.Sender, evt.RoomID, content, content.RelatesTo.EventID.String())
case *event.VerificationAcceptEventContent:
mach.handleVerificationAccept(evt.Sender, content, content.RelatesTo.EventID.String())
mach.handleVerificationAccept(ctx, evt.Sender, content, content.RelatesTo.EventID.String())
case *event.VerificationKeyEventContent:
mach.handleVerificationKey(evt.Sender, content, content.RelatesTo.EventID.String())
mach.handleVerificationKey(ctx, evt.Sender, content, content.RelatesTo.EventID.String())
case *event.VerificationMacEventContent:
mach.handleVerificationMAC(evt.Sender, content, content.RelatesTo.EventID.String())
mach.handleVerificationMAC(ctx, evt.Sender, content, content.RelatesTo.EventID.String())
case *event.VerificationCancelEventContent:
mach.handleVerificationCancel(evt.Sender, content, content.RelatesTo.EventID.String())
}
@@ -73,7 +74,7 @@ func (mach *OlmMachine) ProcessInRoomVerification(evt *event.Event) error {
}
// SendInRoomSASVerificationCancel is used to manually send an in-room SAS cancel message process with the given reason and cancellation code.
func (mach *OlmMachine) SendInRoomSASVerificationCancel(roomID id.RoomID, userID id.UserID, transactionID string, reason string, code event.VerificationCancelCode) error {
func (mach *OlmMachine) SendInRoomSASVerificationCancel(ctx context.Context, roomID id.RoomID, userID id.UserID, transactionID string, reason string, code event.VerificationCancelCode) error {
content := &event.VerificationCancelEventContent{
RelatesTo: &event.RelatesTo{Type: event.RelReference, EventID: id.EventID(transactionID)},
Reason: reason,
@@ -81,16 +82,16 @@ func (mach *OlmMachine) SendInRoomSASVerificationCancel(roomID id.RoomID, userID
To: userID,
}
encrypted, err := mach.EncryptMegolmEvent(context.TODO(), roomID, event.InRoomVerificationCancel, content)
encrypted, err := mach.EncryptMegolmEvent(ctx, roomID, event.InRoomVerificationCancel, content)
if err != nil {
return err
}
_, err = mach.Client.SendMessageEvent(roomID, event.EventEncrypted, encrypted)
_, err = mach.Client.SendMessageEvent(ctx, roomID, event.EventEncrypted, encrypted)
return err
}
// SendInRoomSASVerificationRequest is used to manually send an in-room SAS verification request message to another user.
func (mach *OlmMachine) SendInRoomSASVerificationRequest(roomID id.RoomID, toUserID id.UserID, methods []VerificationMethod) (string, error) {
func (mach *OlmMachine) SendInRoomSASVerificationRequest(ctx context.Context, roomID id.RoomID, toUserID id.UserID, methods []VerificationMethod) (string, error) {
content := &event.MessageEventContent{
MsgType: event.MsgVerificationRequest,
FromDevice: mach.Client.DeviceID,
@@ -98,11 +99,11 @@ func (mach *OlmMachine) SendInRoomSASVerificationRequest(roomID id.RoomID, toUse
To: toUserID,
}
encrypted, err := mach.EncryptMegolmEvent(context.TODO(), roomID, event.EventMessage, content)
encrypted, err := mach.EncryptMegolmEvent(ctx, roomID, event.EventMessage, content)
if err != nil {
return "", err
}
resp, err := mach.Client.SendMessageEvent(roomID, event.EventEncrypted, encrypted)
resp, err := mach.Client.SendMessageEvent(ctx, roomID, event.EventEncrypted, encrypted)
if err != nil {
return "", err
}
@@ -110,23 +111,23 @@ func (mach *OlmMachine) SendInRoomSASVerificationRequest(roomID id.RoomID, toUse
}
// SendInRoomSASVerificationReady is used to manually send an in-room SAS verification ready message to another user.
func (mach *OlmMachine) SendInRoomSASVerificationReady(roomID id.RoomID, transactionID string) error {
func (mach *OlmMachine) SendInRoomSASVerificationReady(ctx context.Context, roomID id.RoomID, transactionID string) error {
content := &event.VerificationReadyEventContent{
FromDevice: mach.Client.DeviceID,
Methods: []event.VerificationMethod{event.VerificationMethodSAS},
RelatesTo: &event.RelatesTo{Type: event.RelReference, EventID: id.EventID(transactionID)},
}
encrypted, err := mach.EncryptMegolmEvent(context.TODO(), roomID, event.InRoomVerificationReady, content)
encrypted, err := mach.EncryptMegolmEvent(ctx, roomID, event.InRoomVerificationReady, content)
if err != nil {
return err
}
_, err = mach.Client.SendMessageEvent(roomID, event.EventEncrypted, encrypted)
_, err = mach.Client.SendMessageEvent(ctx, roomID, event.EventEncrypted, encrypted)
return err
}
// SendInRoomSASVerificationStart is used to manually send the in-room SAS verification start message to another user.
func (mach *OlmMachine) SendInRoomSASVerificationStart(roomID id.RoomID, toUserID id.UserID, transactionID string, methods []VerificationMethod) (*event.VerificationStartEventContent, error) {
func (mach *OlmMachine) SendInRoomSASVerificationStart(ctx context.Context, roomID id.RoomID, toUserID id.UserID, transactionID string, methods []VerificationMethod) (*event.VerificationStartEventContent, error) {
sasMethods := make([]event.SASMethod, len(methods))
for i, method := range methods {
sasMethods[i] = method.Type()
@@ -142,19 +143,19 @@ func (mach *OlmMachine) SendInRoomSASVerificationStart(roomID id.RoomID, toUserI
To: toUserID,
}
encrypted, err := mach.EncryptMegolmEvent(context.TODO(), roomID, event.InRoomVerificationStart, content)
encrypted, err := mach.EncryptMegolmEvent(ctx, roomID, event.InRoomVerificationStart, content)
if err != nil {
return nil, err
}
_, err = mach.Client.SendMessageEvent(roomID, event.EventEncrypted, encrypted)
_, err = mach.Client.SendMessageEvent(ctx, roomID, event.EventEncrypted, encrypted)
return content, err
}
// SendInRoomSASVerificationAccept is used to manually send an accept for an in-room SAS verification process from a received m.key.verification.start event.
func (mach *OlmMachine) SendInRoomSASVerificationAccept(roomID id.RoomID, fromUser id.UserID, startEvent *event.VerificationStartEventContent, transactionID string, publicKey []byte, methods []VerificationMethod) error {
func (mach *OlmMachine) SendInRoomSASVerificationAccept(ctx context.Context, roomID id.RoomID, fromUser id.UserID, startEvent *event.VerificationStartEventContent, transactionID string, publicKey []byte, methods []VerificationMethod) error {
if startEvent.Method != event.VerificationMethodSAS {
reason := "Unknown verification method: " + string(startEvent.Method)
if err := mach.SendInRoomSASVerificationCancel(roomID, fromUser, transactionID, reason, event.VerificationCancelUnknownMethod); err != nil {
if err := mach.SendInRoomSASVerificationCancel(ctx, roomID, fromUser, transactionID, reason, event.VerificationCancelUnknownMethod); err != nil {
return err
}
return ErrUnknownVerificationMethod
@@ -183,32 +184,32 @@ func (mach *OlmMachine) SendInRoomSASVerificationAccept(roomID id.RoomID, fromUs
To: fromUser,
}
encrypted, err := mach.EncryptMegolmEvent(context.TODO(), roomID, event.InRoomVerificationAccept, content)
encrypted, err := mach.EncryptMegolmEvent(ctx, roomID, event.InRoomVerificationAccept, content)
if err != nil {
return err
}
_, err = mach.Client.SendMessageEvent(roomID, event.EventEncrypted, encrypted)
_, err = mach.Client.SendMessageEvent(ctx, roomID, event.EventEncrypted, encrypted)
return err
}
// SendInRoomSASVerificationKey sends the ephemeral public key for a device to the partner device for an in-room verification.
func (mach *OlmMachine) SendInRoomSASVerificationKey(roomID id.RoomID, userID id.UserID, transactionID string, key string) error {
func (mach *OlmMachine) SendInRoomSASVerificationKey(ctx context.Context, roomID id.RoomID, userID id.UserID, transactionID string, key string) error {
content := &event.VerificationKeyEventContent{
RelatesTo: &event.RelatesTo{Type: event.RelReference, EventID: id.EventID(transactionID)},
Key: key,
To: userID,
}
encrypted, err := mach.EncryptMegolmEvent(context.TODO(), roomID, event.InRoomVerificationKey, content)
encrypted, err := mach.EncryptMegolmEvent(ctx, roomID, event.InRoomVerificationKey, content)
if err != nil {
return err
}
_, err = mach.Client.SendMessageEvent(roomID, event.EventEncrypted, encrypted)
_, err = mach.Client.SendMessageEvent(ctx, roomID, event.EventEncrypted, encrypted)
return err
}
// SendInRoomSASVerificationMAC sends the MAC of a device's key to the partner device for an in-room verification.
func (mach *OlmMachine) SendInRoomSASVerificationMAC(roomID id.RoomID, userID id.UserID, deviceID id.DeviceID, transactionID string, sas *olm.SAS) error {
func (mach *OlmMachine) SendInRoomSASVerificationMAC(ctx context.Context, roomID id.RoomID, userID id.UserID, deviceID id.DeviceID, transactionID string, sas *olm.SAS) error {
keyID := id.NewKeyID(id.KeyAlgorithmEd25519, mach.Client.DeviceID.String())
signingKey := mach.account.SigningKey()
@@ -245,28 +246,28 @@ func (mach *OlmMachine) SendInRoomSASVerificationMAC(roomID id.RoomID, userID id
To: userID,
}
encrypted, err := mach.EncryptMegolmEvent(context.TODO(), roomID, event.InRoomVerificationMAC, content)
encrypted, err := mach.EncryptMegolmEvent(ctx, roomID, event.InRoomVerificationMAC, content)
if err != nil {
return err
}
_, err = mach.Client.SendMessageEvent(roomID, event.EventEncrypted, encrypted)
_, err = mach.Client.SendMessageEvent(ctx, roomID, event.EventEncrypted, encrypted)
return err
}
// NewInRoomSASVerificationWith starts the in-room SAS verification process with another user in the given room.
// It returns the generated transaction ID.
func (mach *OlmMachine) NewInRoomSASVerificationWith(inRoomID id.RoomID, userID id.UserID, hooks VerificationHooks, timeout time.Duration) (string, error) {
return mach.newInRoomSASVerificationWithInner(inRoomID, &id.Device{UserID: userID}, hooks, "", timeout)
func (mach *OlmMachine) NewInRoomSASVerificationWith(ctx context.Context, inRoomID id.RoomID, userID id.UserID, hooks VerificationHooks, timeout time.Duration) (string, error) {
return mach.newInRoomSASVerificationWithInner(ctx, inRoomID, &id.Device{UserID: userID}, hooks, "", timeout)
}
func (mach *OlmMachine) newInRoomSASVerificationWithInner(inRoomID id.RoomID, device *id.Device, hooks VerificationHooks, transactionID string, timeout time.Duration) (string, error) {
func (mach *OlmMachine) newInRoomSASVerificationWithInner(ctx context.Context, inRoomID id.RoomID, device *id.Device, hooks VerificationHooks, transactionID string, timeout time.Duration) (string, error) {
mach.Log.Debug().Msgf("Starting new in-room verification transaction user %v", device.UserID)
request := transactionID == ""
if request {
var err error
// get new transaction ID from the request message event ID
transactionID, err = mach.SendInRoomSASVerificationRequest(inRoomID, device.UserID, hooks.VerificationMethods())
transactionID, err = mach.SendInRoomSASVerificationRequest(ctx, inRoomID, device.UserID, hooks.VerificationMethods())
if err != nil {
return "", err
}
@@ -286,7 +287,7 @@ func (mach *OlmMachine) newInRoomSASVerificationWithInner(inRoomID id.RoomID, de
if !request {
// start in-room verification
startEvent, err := mach.SendInRoomSASVerificationStart(inRoomID, device.UserID, transactionID, hooks.VerificationMethods())
startEvent, err := mach.SendInRoomSASVerificationStart(ctx, inRoomID, device.UserID, transactionID, hooks.VerificationMethods())
if err != nil {
return "", err
}
@@ -305,19 +306,19 @@ func (mach *OlmMachine) newInRoomSASVerificationWithInner(inRoomID id.RoomID, de
mach.keyVerificationTransactionState.Store(device.UserID.String()+":"+transactionID, verState)
mach.timeoutAfter(verState, transactionID, timeout)
mach.timeoutAfter(ctx, verState, transactionID, timeout)
return transactionID, nil
}
func (mach *OlmMachine) handleInRoomVerificationReady(userID id.UserID, roomID id.RoomID, content *event.VerificationReadyEventContent, transactionID string) {
device, err := mach.GetOrFetchDevice(context.TODO(), userID, content.FromDevice)
func (mach *OlmMachine) handleInRoomVerificationReady(ctx context.Context, userID id.UserID, roomID id.RoomID, content *event.VerificationReadyEventContent, transactionID string) {
device, err := mach.GetOrFetchDevice(ctx, userID, content.FromDevice)
if err != nil {
mach.Log.Error().Msgf("Error fetching device %v of user %v: %v", content.FromDevice, userID, err)
return
}
verState, err := mach.getTransactionState(transactionID, userID)
verState, err := mach.getTransactionState(ctx, transactionID, userID)
if err != nil {
mach.Log.Error().Msgf("Error getting transaction state: %v", err)
return
@@ -327,7 +328,7 @@ func (mach *OlmMachine) handleInRoomVerificationReady(userID id.UserID, roomID i
if mach.Client.UserID < userID {
// up to us to send the start message
verState.lock.Lock()
mach.newInRoomSASVerificationWithInner(roomID, device, verState.hooks, transactionID, 10*time.Minute)
mach.newInRoomSASVerificationWithInner(ctx, roomID, device, verState.hooks, transactionID, 10*time.Minute)
verState.lock.Unlock()
}
}

View File

@@ -105,6 +105,8 @@ func (evt *Event) MarshalJSON() ([]byte, error) {
}
type MautrixInfo struct {
EventSource Source
TrustState id.TrustState
ForwardedKeys bool
WasEncrypted bool

72
vendor/maunium.net/go/mautrix/event/eventsource.go generated vendored Normal file
View File

@@ -0,0 +1,72 @@
// Copyright (c) 2024 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
package event
import (
"fmt"
)
// Source represents the part of the sync response that an event came from.
type Source int
const (
SourcePresence Source = 1 << iota
SourceJoin
SourceInvite
SourceLeave
SourceAccountData
SourceTimeline
SourceState
SourceEphemeral
SourceToDevice
SourceDecrypted
)
const primaryTypes = SourcePresence | SourceAccountData | SourceToDevice | SourceTimeline | SourceState
const roomSections = SourceJoin | SourceInvite | SourceLeave
const roomableTypes = SourceAccountData | SourceTimeline | SourceState
const encryptableTypes = roomableTypes | SourceToDevice
func (es Source) String() string {
var typeName string
switch es & primaryTypes {
case SourcePresence:
typeName = "presence"
case SourceAccountData:
typeName = "account data"
case SourceToDevice:
typeName = "to-device"
case SourceTimeline:
typeName = "timeline"
case SourceState:
typeName = "state"
default:
return fmt.Sprintf("unknown (%d)", es)
}
if es&roomableTypes != 0 {
switch es & roomSections {
case SourceJoin:
typeName = "joined room " + typeName
case SourceInvite:
typeName = "invited room " + typeName
case SourceLeave:
typeName = "left room " + typeName
default:
return fmt.Sprintf("unknown (%s+%d)", typeName, es)
}
es &^= roomSections
}
if es&encryptableTypes != 0 && es&SourceDecrypted != 0 {
typeName += " (decrypted)"
es &^= SourceDecrypted
}
es &^= primaryTypes
if es != 0 {
return fmt.Sprintf("unknown (%s+%d)", typeName, es)
}
return typeName
}

View File

@@ -199,10 +199,14 @@ type FileInfo struct {
ThumbnailInfo *FileInfo `json:"thumbnail_info,omitempty"`
ThumbnailURL id.ContentURIString `json:"thumbnail_url,omitempty"`
ThumbnailFile *EncryptedFileInfo `json:"thumbnail_file,omitempty"`
Width int `json:"-"`
Height int `json:"-"`
Duration int `json:"-"`
Size int `json:"-"`
Blurhash string `json:"blurhash,omitempty"`
AnoaBlurhash string `json:"xyz.amorgan.blurhash,omitempty"`
Width int `json:"-"`
Height int `json:"-"`
Duration int `json:"-"`
Size int `json:"-"`
}
type serializableFileInfo struct {
@@ -211,6 +215,9 @@ type serializableFileInfo struct {
ThumbnailURL id.ContentURIString `json:"thumbnail_url,omitempty"`
ThumbnailFile *EncryptedFileInfo `json:"thumbnail_file,omitempty"`
Blurhash string `json:"blurhash,omitempty"`
AnoaBlurhash string `json:"xyz.amorgan.blurhash,omitempty"`
Width json.Number `json:"w,omitempty"`
Height json.Number `json:"h,omitempty"`
Duration json.Number `json:"duration,omitempty"`
@@ -226,6 +233,9 @@ func (sfi *serializableFileInfo) CopyFrom(fileInfo *FileInfo) *serializableFileI
ThumbnailURL: fileInfo.ThumbnailURL,
ThumbnailInfo: (&serializableFileInfo{}).CopyFrom(fileInfo.ThumbnailInfo),
ThumbnailFile: fileInfo.ThumbnailFile,
Blurhash: fileInfo.Blurhash,
AnoaBlurhash: fileInfo.AnoaBlurhash,
}
if fileInfo.Width > 0 {
sfi.Width = json.Number(strconv.Itoa(fileInfo.Width))
@@ -252,6 +262,8 @@ func (sfi *serializableFileInfo) CopyTo(fileInfo *FileInfo) {
MimeType: sfi.MimeType,
ThumbnailURL: sfi.ThumbnailURL,
ThumbnailFile: sfi.ThumbnailFile,
Blurhash: sfi.Blurhash,
AnoaBlurhash: sfi.AnoaBlurhash,
}
if sfi.ThumbnailInfo != nil {
fileInfo.ThumbnailInfo = &FileInfo{}

View File

@@ -287,9 +287,7 @@ type Signatures map[id.UserID]map[id.KeyID]string
type ReqQueryKeys struct {
DeviceKeys DeviceKeysRequest `json:"device_keys"`
Timeout int64 `json:"timeout,omitempty"`
Token string `json:"token,omitempty"`
Timeout int64 `json:"timeout,omitempty"`
}
type DeviceKeysRequest map[id.UserID]DeviceIDList

View File

@@ -1,4 +1,4 @@
// Copyright (c) 2022 Tulir Asokan
// Copyright (c) 2024 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
@@ -7,6 +7,7 @@
package sqlstatestore
import (
"context"
"database/sql"
"embed"
"encoding/json"
@@ -15,6 +16,7 @@ import (
"strconv"
"strings"
"github.com/rs/zerolog"
"go.mau.fi/util/dbutil"
"maunium.net/go/mautrix/event"
@@ -44,26 +46,28 @@ func NewSQLStateStore(db *dbutil.Database, log dbutil.DatabaseLogger, isBridge b
}
}
func (store *SQLStateStore) IsRegistered(userID id.UserID) bool {
func (store *SQLStateStore) IsRegistered(ctx context.Context, userID id.UserID) (bool, error) {
var isRegistered bool
err := store.
QueryRow("SELECT EXISTS(SELECT 1 FROM mx_registrations WHERE user_id=$1)", userID).
QueryRow(ctx, "SELECT EXISTS(SELECT 1 FROM mx_registrations WHERE user_id=$1)", userID).
Scan(&isRegistered)
if err != nil {
store.Log.Warn("Failed to scan registration existence for %s: %v", userID, err)
if errors.Is(err, sql.ErrNoRows) {
err = nil
}
return isRegistered
return isRegistered, err
}
func (store *SQLStateStore) MarkRegistered(userID id.UserID) {
_, err := store.Exec("INSERT INTO mx_registrations (user_id) VALUES ($1) ON CONFLICT (user_id) DO NOTHING", userID)
if err != nil {
store.Log.Warn("Failed to mark %s as registered: %v", userID, err)
}
func (store *SQLStateStore) MarkRegistered(ctx context.Context, userID id.UserID) error {
_, err := store.Exec(ctx, "INSERT INTO mx_registrations (user_id) VALUES ($1) ON CONFLICT (user_id) DO NOTHING", userID)
return err
}
func (store *SQLStateStore) GetRoomMembers(roomID id.RoomID, memberships ...event.Membership) map[id.UserID]*event.MemberEventContent {
members := make(map[id.UserID]*event.MemberEventContent)
type Member struct {
id.UserID
event.MemberEventContent
}
func (store *SQLStateStore) GetRoomMembers(ctx context.Context, roomID id.RoomID, memberships ...event.Membership) (map[id.UserID]*event.MemberEventContent, error) {
args := make([]any, len(memberships)+1)
args[0] = roomID
query := "SELECT user_id, membership, displayname, avatar_url FROM mx_user_profile WHERE room_id=$1"
@@ -75,25 +79,26 @@ func (store *SQLStateStore) GetRoomMembers(roomID id.RoomID, memberships ...even
}
query = fmt.Sprintf("%s AND membership IN (%s)", query, strings.Join(placeholders, ","))
}
rows, err := store.Query(query, args...)
rows, err := store.Query(ctx, query, args...)
if err != nil {
return members
return nil, err
}
var userID id.UserID
var member event.MemberEventContent
for rows.Next() {
err = rows.Scan(&userID, &member.Membership, &member.Displayname, &member.AvatarURL)
if err != nil {
store.Log.Warn("Failed to scan member in %s: %v", roomID, err)
} else {
members[userID] = &member
}
}
return members
members := make(map[id.UserID]*event.MemberEventContent)
return members, dbutil.NewRowIter(rows, func(row dbutil.Scannable) (ret Member, err error) {
err = row.Scan(&ret.UserID, &ret.Membership, &ret.Displayname, &ret.AvatarURL)
return
}).Iter(func(m Member) (bool, error) {
members[m.UserID] = &m.MemberEventContent
return true, nil
})
}
func (store *SQLStateStore) GetRoomJoinedOrInvitedMembers(roomID id.RoomID) (members []id.UserID, err error) {
memberMap := store.GetRoomMembers(roomID, event.MembershipJoin, event.MembershipInvite)
func (store *SQLStateStore) GetRoomJoinedOrInvitedMembers(ctx context.Context, roomID id.RoomID) (members []id.UserID, err error) {
var memberMap map[id.UserID]*event.MemberEventContent
memberMap, err = store.GetRoomMembers(ctx, roomID, event.MembershipJoin, event.MembershipInvite)
if err != nil {
return
}
members = make([]id.UserID, len(memberMap))
i := 0
for userID := range memberMap {
@@ -103,37 +108,39 @@ func (store *SQLStateStore) GetRoomJoinedOrInvitedMembers(roomID id.RoomID) (mem
return
}
func (store *SQLStateStore) GetMembership(roomID id.RoomID, userID id.UserID) event.Membership {
membership := event.MembershipLeave
err := store.
QueryRow("SELECT membership FROM mx_user_profile WHERE room_id=$1 AND user_id=$2", roomID, userID).
func (store *SQLStateStore) GetMembership(ctx context.Context, roomID id.RoomID, userID id.UserID) (membership event.Membership, err error) {
err = store.
QueryRow(ctx, "SELECT membership FROM mx_user_profile WHERE room_id=$1 AND user_id=$2", roomID, userID).
Scan(&membership)
if err != nil && !errors.Is(err, sql.ErrNoRows) {
store.Log.Warn("Failed to scan membership of %s in %s: %v", userID, roomID, err)
if errors.Is(err, sql.ErrNoRows) {
membership = event.MembershipLeave
err = nil
}
return membership
return
}
func (store *SQLStateStore) GetMember(roomID id.RoomID, userID id.UserID) *event.MemberEventContent {
member, ok := store.TryGetMember(roomID, userID)
if !ok {
member.Membership = event.MembershipLeave
func (store *SQLStateStore) GetMember(ctx context.Context, roomID id.RoomID, userID id.UserID) (*event.MemberEventContent, error) {
member, err := store.TryGetMember(ctx, roomID, userID)
if member == nil && err == nil {
member = &event.MemberEventContent{Membership: event.MembershipLeave}
}
return member
return member, err
}
func (store *SQLStateStore) TryGetMember(roomID id.RoomID, userID id.UserID) (*event.MemberEventContent, bool) {
func (store *SQLStateStore) TryGetMember(ctx context.Context, roomID id.RoomID, userID id.UserID) (*event.MemberEventContent, error) {
var member event.MemberEventContent
err := store.
QueryRow("SELECT membership, displayname, avatar_url FROM mx_user_profile WHERE room_id=$1 AND user_id=$2", roomID, userID).
QueryRow(ctx, "SELECT membership, displayname, avatar_url FROM mx_user_profile WHERE room_id=$1 AND user_id=$2", roomID, userID).
Scan(&member.Membership, &member.Displayname, &member.AvatarURL)
if err != nil && !errors.Is(err, sql.ErrNoRows) {
store.Log.Warn("Failed to scan member info of %s in %s: %v", userID, roomID, err)
if errors.Is(err, sql.ErrNoRows) {
return nil, nil
} else if err != nil {
return nil, err
}
return &member, err == nil
return &member, nil
}
func (store *SQLStateStore) FindSharedRooms(userID id.UserID) (rooms []id.RoomID) {
func (store *SQLStateStore) FindSharedRooms(ctx context.Context, userID id.UserID) ([]id.RoomID, error) {
query := `
SELECT room_id FROM mx_user_profile
LEFT JOIN portal ON portal.mxid=mx_user_profile.room_id
@@ -141,38 +148,32 @@ func (store *SQLStateStore) FindSharedRooms(userID id.UserID) (rooms []id.RoomID
`
if !store.IsBridge {
query = `
SELECT mx_user_profile.room_id FROM mx_user_profile
LEFT JOIN mx_room_state ON mx_room_state.room_id=mx_user_profile.room_id
WHERE mx_user_profile.user_id=$1 AND mx_room_state.encryption IS NOT NULL
`
SELECT mx_user_profile.room_id FROM mx_user_profile
LEFT JOIN mx_room_state ON mx_room_state.room_id=mx_user_profile.room_id
WHERE mx_user_profile.user_id=$1 AND mx_room_state.encryption IS NOT NULL
`
}
rows, err := store.Query(query, userID)
rows, err := store.Query(ctx, query, userID)
if err != nil {
store.Log.Warn("Failed to query shared rooms with %s: %v", userID, err)
return
return nil, err
}
for rows.Next() {
var roomID id.RoomID
err = rows.Scan(&roomID)
if err != nil {
store.Log.Warn("Failed to scan room ID: %v", err)
} else {
rooms = append(rooms, roomID)
}
return dbutil.NewRowIter(rows, dbutil.ScanSingleColumn[id.RoomID]).AsList()
}
func (store *SQLStateStore) IsInRoom(ctx context.Context, roomID id.RoomID, userID id.UserID) bool {
return store.IsMembership(ctx, roomID, userID, "join")
}
func (store *SQLStateStore) IsInvited(ctx context.Context, roomID id.RoomID, userID id.UserID) bool {
return store.IsMembership(ctx, roomID, userID, "join", "invite")
}
func (store *SQLStateStore) IsMembership(ctx context.Context, roomID id.RoomID, userID id.UserID, allowedMemberships ...event.Membership) bool {
membership, err := store.GetMembership(ctx, roomID, userID)
if err != nil {
zerolog.Ctx(ctx).Err(err).Msg("Failed to get membership")
return false
}
return
}
func (store *SQLStateStore) IsInRoom(roomID id.RoomID, userID id.UserID) bool {
return store.IsMembership(roomID, userID, "join")
}
func (store *SQLStateStore) IsInvited(roomID id.RoomID, userID id.UserID) bool {
return store.IsMembership(roomID, userID, "join", "invite")
}
func (store *SQLStateStore) IsMembership(roomID id.RoomID, userID id.UserID, allowedMemberships ...event.Membership) bool {
membership := store.GetMembership(roomID, userID)
for _, allowedMembership := range allowedMemberships {
if allowedMembership == membership {
return true
@@ -181,27 +182,23 @@ func (store *SQLStateStore) IsMembership(roomID id.RoomID, userID id.UserID, all
return false
}
func (store *SQLStateStore) SetMembership(roomID id.RoomID, userID id.UserID, membership event.Membership) {
_, err := store.Exec(`
func (store *SQLStateStore) SetMembership(ctx context.Context, roomID id.RoomID, userID id.UserID, membership event.Membership) error {
_, err := store.Exec(ctx, `
INSERT INTO mx_user_profile (room_id, user_id, membership, displayname, avatar_url) VALUES ($1, $2, $3, '', '')
ON CONFLICT (room_id, user_id) DO UPDATE SET membership=excluded.membership
`, roomID, userID, membership)
if err != nil {
store.Log.Warn("Failed to set membership of %s in %s to %s: %v", userID, roomID, membership, err)
}
return err
}
func (store *SQLStateStore) SetMember(roomID id.RoomID, userID id.UserID, member *event.MemberEventContent) {
_, err := store.Exec(`
func (store *SQLStateStore) SetMember(ctx context.Context, roomID id.RoomID, userID id.UserID, member *event.MemberEventContent) error {
_, err := store.Exec(ctx, `
INSERT INTO mx_user_profile (room_id, user_id, membership, displayname, avatar_url) VALUES ($1, $2, $3, $4, $5)
ON CONFLICT (room_id, user_id) DO UPDATE SET membership=excluded.membership, displayname=excluded.displayname, avatar_url=excluded.avatar_url
`, roomID, userID, member.Membership, member.Displayname, member.AvatarURL)
if err != nil {
store.Log.Warn("Failed to set membership of %s in %s to %s: %v", userID, roomID, member, err)
}
return err
}
func (store *SQLStateStore) ClearCachedMembers(roomID id.RoomID, memberships ...event.Membership) {
func (store *SQLStateStore) ClearCachedMembers(ctx context.Context, roomID id.RoomID, memberships ...event.Membership) error {
query := "DELETE FROM mx_user_profile WHERE room_id=$1"
params := make([]any, len(memberships)+1)
params[0] = roomID
@@ -213,109 +210,85 @@ func (store *SQLStateStore) ClearCachedMembers(roomID id.RoomID, memberships ...
}
query += fmt.Sprintf(" AND membership IN (%s)", strings.Join(placeholders, ","))
}
_, err := store.Exec(query, params...)
if err != nil {
store.Log.Warn("Failed to clear cached members of %s: %v", roomID, err)
}
_, err := store.Exec(ctx, query, params...)
return err
}
func (store *SQLStateStore) SetEncryptionEvent(roomID id.RoomID, content *event.EncryptionEventContent) {
func (store *SQLStateStore) SetEncryptionEvent(ctx context.Context, roomID id.RoomID, content *event.EncryptionEventContent) error {
contentBytes, err := json.Marshal(content)
if err != nil {
store.Log.Warn("Failed to marshal encryption config of %s: %v", roomID, err)
return
return fmt.Errorf("failed to marshal content JSON: %w", err)
}
_, err = store.Exec(`
_, err = store.Exec(ctx, `
INSERT INTO mx_room_state (room_id, encryption) VALUES ($1, $2)
ON CONFLICT (room_id) DO UPDATE SET encryption=excluded.encryption
`, roomID, contentBytes)
if err != nil {
store.Log.Warn("Failed to store encryption config of %s: %v", roomID, err)
}
return err
}
func (store *SQLStateStore) GetEncryptionEvent(roomID id.RoomID) *event.EncryptionEventContent {
func (store *SQLStateStore) GetEncryptionEvent(ctx context.Context, roomID id.RoomID) (*event.EncryptionEventContent, error) {
var data []byte
err := store.
QueryRow("SELECT encryption FROM mx_room_state WHERE room_id=$1", roomID).
QueryRow(ctx, "SELECT encryption FROM mx_room_state WHERE room_id=$1", roomID).
Scan(&data)
if err != nil {
if !errors.Is(err, sql.ErrNoRows) {
store.Log.Warn("Failed to scan encryption config of %s: %v", roomID, err)
}
return nil
if errors.Is(err, sql.ErrNoRows) {
return nil, nil
} else if err != nil {
return nil, err
} else if data == nil {
return nil
return nil, nil
}
content := &event.EncryptionEventContent{}
err = json.Unmarshal(data, content)
var content event.EncryptionEventContent
err = json.Unmarshal(data, &content)
if err != nil {
store.Log.Warn("Failed to parse encryption config of %s: %v", roomID, err)
return nil
return nil, fmt.Errorf("failed to parse content JSON: %w", err)
}
return content
return &content, nil
}
func (store *SQLStateStore) IsEncrypted(roomID id.RoomID) bool {
cfg := store.GetEncryptionEvent(roomID)
return cfg != nil && cfg.Algorithm == id.AlgorithmMegolmV1
func (store *SQLStateStore) IsEncrypted(ctx context.Context, roomID id.RoomID) (bool, error) {
cfg, err := store.GetEncryptionEvent(ctx, roomID)
return cfg != nil && cfg.Algorithm == id.AlgorithmMegolmV1, err
}
func (store *SQLStateStore) SetPowerLevels(roomID id.RoomID, levels *event.PowerLevelsEventContent) {
levelsBytes, err := json.Marshal(levels)
if err != nil {
store.Log.Warn("Failed to marshal power levels of %s: %v", roomID, err)
return
}
_, err = store.Exec(`
func (store *SQLStateStore) SetPowerLevels(ctx context.Context, roomID id.RoomID, levels *event.PowerLevelsEventContent) error {
_, err := store.Exec(ctx, `
INSERT INTO mx_room_state (room_id, power_levels) VALUES ($1, $2)
ON CONFLICT (room_id) DO UPDATE SET power_levels=excluded.power_levels
`, roomID, levelsBytes)
if err != nil {
store.Log.Warn("Failed to store power levels of %s: %v", roomID, err)
}
`, roomID, dbutil.JSON{Data: levels})
return err
}
func (store *SQLStateStore) GetPowerLevels(roomID id.RoomID) (levels *event.PowerLevelsEventContent) {
var data []byte
err := store.
QueryRow("SELECT power_levels FROM mx_room_state WHERE room_id=$1", roomID).
Scan(&data)
if err != nil {
if !errors.Is(err, sql.ErrNoRows) {
store.Log.Warn("Failed to scan power levels of %s: %v", roomID, err)
}
return
} else if data == nil {
return
}
levels = &event.PowerLevelsEventContent{}
err = json.Unmarshal(data, levels)
if err != nil {
store.Log.Warn("Failed to parse power levels of %s: %v", roomID, err)
return nil
func (store *SQLStateStore) GetPowerLevels(ctx context.Context, roomID id.RoomID) (levels *event.PowerLevelsEventContent, err error) {
err = store.
QueryRow(ctx, "SELECT power_levels FROM mx_room_state WHERE room_id=$1", roomID).
Scan(&dbutil.JSON{Data: &levels})
if errors.Is(err, sql.ErrNoRows) {
err = nil
}
return
}
func (store *SQLStateStore) GetPowerLevel(roomID id.RoomID, userID id.UserID) int {
func (store *SQLStateStore) GetPowerLevel(ctx context.Context, roomID id.RoomID, userID id.UserID) (int, error) {
if store.Dialect == dbutil.Postgres {
var powerLevel int
err := store.
QueryRow(`
QueryRow(ctx, `
SELECT COALESCE((power_levels->'users'->$2)::int, (power_levels->'users_default')::int, 0)
FROM mx_room_state WHERE room_id=$1
`, roomID, userID).
Scan(&powerLevel)
if err != nil && !errors.Is(err, sql.ErrNoRows) {
store.Log.Warn("Failed to scan power level of %s in %s: %v", userID, roomID, err)
return powerLevel, err
} else {
levels, err := store.GetPowerLevels(ctx, roomID)
if err != nil {
return 0, err
}
return powerLevel
return levels.GetUserLevel(userID), nil
}
return store.GetPowerLevels(roomID).GetUserLevel(userID)
}
func (store *SQLStateStore) GetPowerLevelRequirement(roomID id.RoomID, eventType event.Type) int {
func (store *SQLStateStore) GetPowerLevelRequirement(ctx context.Context, roomID id.RoomID, eventType event.Type) (int, error) {
if store.Dialect == dbutil.Postgres {
defaultType := "events_default"
defaultValue := 0
@@ -325,23 +298,26 @@ func (store *SQLStateStore) GetPowerLevelRequirement(roomID id.RoomID, eventType
}
var powerLevel int
err := store.
QueryRow(`
QueryRow(ctx, `
SELECT COALESCE((power_levels->'events'->$2)::int, (power_levels->'$3')::int, $4)
FROM mx_room_state WHERE room_id=$1
`, roomID, eventType.Type, defaultType, defaultValue).
Scan(&powerLevel)
if err != nil {
if !errors.Is(err, sql.ErrNoRows) {
store.Log.Warn("Failed to scan power level for %s in %s: %v", eventType, roomID, err)
}
return defaultValue
if errors.Is(err, sql.ErrNoRows) {
err = nil
powerLevel = defaultValue
}
return powerLevel
return powerLevel, err
} else {
levels, err := store.GetPowerLevels(ctx, roomID)
if err != nil {
return 0, err
}
return levels.GetEventLevel(eventType), nil
}
return store.GetPowerLevels(roomID).GetEventLevel(eventType)
}
func (store *SQLStateStore) HasPowerLevel(roomID id.RoomID, userID id.UserID, eventType event.Type) bool {
func (store *SQLStateStore) HasPowerLevel(ctx context.Context, roomID id.RoomID, userID id.UserID, eventType event.Type) (bool, error) {
if store.Dialect == dbutil.Postgres {
defaultType := "events_default"
defaultValue := 0
@@ -351,19 +327,22 @@ func (store *SQLStateStore) HasPowerLevel(roomID id.RoomID, userID id.UserID, ev
}
var hasPower bool
err := store.
QueryRow(`SELECT
QueryRow(ctx, `SELECT
COALESCE((power_levels->'users'->$2)::int, (power_levels->'users_default')::int, 0)
>=
COALESCE((power_levels->'events'->$3)::int, (power_levels->'$4')::int, $5)
FROM mx_room_state WHERE room_id=$1`, roomID, userID, eventType.Type, defaultType, defaultValue).
Scan(&hasPower)
if err != nil {
if !errors.Is(err, sql.ErrNoRows) {
store.Log.Warn("Failed to scan power level for %s in %s: %v", eventType, roomID, err)
}
return defaultValue == 0
if errors.Is(err, sql.ErrNoRows) {
err = nil
hasPower = defaultValue == 0
}
return hasPower
return hasPower, err
} else {
levels, err := store.GetPowerLevels(ctx, roomID)
if err != nil {
return false, err
}
return levels.GetUserLevel(userID) >= levels.GetEventLevel(eventType), nil
}
return store.GetPowerLevel(roomID, userID) >= store.GetPowerLevelRequirement(roomID, eventType)
}

View File

@@ -1,19 +1,20 @@
package sqlstatestore
import (
"context"
"fmt"
"go.mau.fi/util/dbutil"
)
func init() {
UpgradeTable.Register(-1, 5, 0, "Mark rooms that need crypto state event resynced", true, func(tx dbutil.Execable, db *dbutil.Database) error {
portalExists, err := db.TableExists(tx, "portal")
UpgradeTable.Register(-1, 5, 0, "Mark rooms that need crypto state event resynced", true, func(ctx context.Context, db *dbutil.Database) error {
portalExists, err := db.TableExists(ctx, "portal")
if err != nil {
return fmt.Errorf("failed to check if portal table exists")
}
if portalExists {
_, err = tx.Exec(`
_, err = db.Exec(ctx, `
INSERT INTO mx_room_state (room_id, encryption)
SELECT portal.mxid, '{"resync":true}' FROM portal WHERE portal.encrypted=true AND portal.mxid IS NOT NULL
ON CONFLICT (room_id) DO UPDATE

View File

@@ -7,33 +7,37 @@
package mautrix
import (
"context"
"sync"
"github.com/rs/zerolog"
"go.mau.fi/util/exerrors"
"maunium.net/go/mautrix/event"
"maunium.net/go/mautrix/id"
)
// StateStore is an interface for storing basic room state information.
type StateStore interface {
IsInRoom(roomID id.RoomID, userID id.UserID) bool
IsInvited(roomID id.RoomID, userID id.UserID) bool
IsMembership(roomID id.RoomID, userID id.UserID, allowedMemberships ...event.Membership) bool
GetMember(roomID id.RoomID, userID id.UserID) *event.MemberEventContent
TryGetMember(roomID id.RoomID, userID id.UserID) (*event.MemberEventContent, bool)
SetMembership(roomID id.RoomID, userID id.UserID, membership event.Membership)
SetMember(roomID id.RoomID, userID id.UserID, member *event.MemberEventContent)
ClearCachedMembers(roomID id.RoomID, memberships ...event.Membership)
IsInRoom(ctx context.Context, roomID id.RoomID, userID id.UserID) bool
IsInvited(ctx context.Context, roomID id.RoomID, userID id.UserID) bool
IsMembership(ctx context.Context, roomID id.RoomID, userID id.UserID, allowedMemberships ...event.Membership) bool
GetMember(ctx context.Context, roomID id.RoomID, userID id.UserID) (*event.MemberEventContent, error)
TryGetMember(ctx context.Context, roomID id.RoomID, userID id.UserID) (*event.MemberEventContent, error)
SetMembership(ctx context.Context, roomID id.RoomID, userID id.UserID, membership event.Membership) error
SetMember(ctx context.Context, roomID id.RoomID, userID id.UserID, member *event.MemberEventContent) error
ClearCachedMembers(ctx context.Context, roomID id.RoomID, memberships ...event.Membership) error
SetPowerLevels(roomID id.RoomID, levels *event.PowerLevelsEventContent)
GetPowerLevels(roomID id.RoomID) *event.PowerLevelsEventContent
SetPowerLevels(ctx context.Context, roomID id.RoomID, levels *event.PowerLevelsEventContent) error
GetPowerLevels(ctx context.Context, roomID id.RoomID) (*event.PowerLevelsEventContent, error)
SetEncryptionEvent(roomID id.RoomID, content *event.EncryptionEventContent)
IsEncrypted(roomID id.RoomID) bool
SetEncryptionEvent(ctx context.Context, roomID id.RoomID, content *event.EncryptionEventContent) error
IsEncrypted(ctx context.Context, roomID id.RoomID) (bool, error)
GetRoomJoinedOrInvitedMembers(roomID id.RoomID) ([]id.UserID, error)
GetRoomJoinedOrInvitedMembers(ctx context.Context, roomID id.RoomID) ([]id.UserID, error)
}
func UpdateStateStore(store StateStore, evt *event.Event) {
func UpdateStateStore(ctx context.Context, store StateStore, evt *event.Event) {
if store == nil || evt == nil || evt.StateKey == nil {
return
}
@@ -41,13 +45,20 @@ func UpdateStateStore(store StateStore, evt *event.Event) {
if evt.Type != event.StateMember && evt.GetStateKey() != "" {
return
}
var err error
switch content := evt.Content.Parsed.(type) {
case *event.MemberEventContent:
store.SetMember(evt.RoomID, id.UserID(evt.GetStateKey()), content)
err = store.SetMember(ctx, evt.RoomID, id.UserID(evt.GetStateKey()), content)
case *event.PowerLevelsEventContent:
store.SetPowerLevels(evt.RoomID, content)
err = store.SetPowerLevels(ctx, evt.RoomID, content)
case *event.EncryptionEventContent:
store.SetEncryptionEvent(evt.RoomID, content)
err = store.SetEncryptionEvent(ctx, evt.RoomID, content)
}
if err != nil {
zerolog.Ctx(ctx).Warn().Err(err).
Stringer("event_id", evt.ID).
Str("event_type", evt.Type.Type).
Msg("Failed to update state store")
}
}
@@ -56,8 +67,8 @@ func UpdateStateStore(store StateStore, evt *event.Event) {
// client.Syncer.(mautrix.ExtensibleSyncer).OnEvent(client.StateStoreSyncHandler)
//
// DefaultSyncer.ParseEventContent must also be true for this to work (which it is by default).
func (cli *Client) StateStoreSyncHandler(_ EventSource, evt *event.Event) {
UpdateStateStore(cli.StateStore, evt)
func (cli *Client) StateStoreSyncHandler(ctx context.Context, evt *event.Event) {
UpdateStateStore(ctx, cli.StateStore, evt)
}
type MemoryStateStore struct {
@@ -81,20 +92,21 @@ func NewMemoryStateStore() StateStore {
}
}
func (store *MemoryStateStore) IsRegistered(userID id.UserID) bool {
func (store *MemoryStateStore) IsRegistered(_ context.Context, userID id.UserID) (bool, error) {
store.registrationsLock.RLock()
defer store.registrationsLock.RUnlock()
registered, ok := store.Registrations[userID]
return ok && registered
return ok && registered, nil
}
func (store *MemoryStateStore) MarkRegistered(userID id.UserID) {
func (store *MemoryStateStore) MarkRegistered(_ context.Context, userID id.UserID) error {
store.registrationsLock.Lock()
defer store.registrationsLock.Unlock()
store.Registrations[userID] = true
return nil
}
func (store *MemoryStateStore) GetRoomMembers(roomID id.RoomID) map[id.UserID]*event.MemberEventContent {
func (store *MemoryStateStore) GetRoomMembers(_ context.Context, roomID id.RoomID) (map[id.UserID]*event.MemberEventContent, error) {
store.membersLock.RLock()
members, ok := store.Members[roomID]
store.membersLock.RUnlock()
@@ -104,11 +116,14 @@ func (store *MemoryStateStore) GetRoomMembers(roomID id.RoomID) map[id.UserID]*e
store.Members[roomID] = members
store.membersLock.Unlock()
}
return members
return members, nil
}
func (store *MemoryStateStore) GetRoomJoinedOrInvitedMembers(roomID id.RoomID) ([]id.UserID, error) {
members := store.GetRoomMembers(roomID)
func (store *MemoryStateStore) GetRoomJoinedOrInvitedMembers(ctx context.Context, roomID id.RoomID) ([]id.UserID, error) {
members, err := store.GetRoomMembers(ctx, roomID)
if err != nil {
return nil, err
}
ids := make([]id.UserID, 0, len(members))
for id := range members {
ids = append(ids, id)
@@ -116,39 +131,39 @@ func (store *MemoryStateStore) GetRoomJoinedOrInvitedMembers(roomID id.RoomID) (
return ids, nil
}
func (store *MemoryStateStore) GetMembership(roomID id.RoomID, userID id.UserID) event.Membership {
return store.GetMember(roomID, userID).Membership
func (store *MemoryStateStore) GetMembership(ctx context.Context, roomID id.RoomID, userID id.UserID) (event.Membership, error) {
return exerrors.Must(store.GetMember(ctx, roomID, userID)).Membership, nil
}
func (store *MemoryStateStore) GetMember(roomID id.RoomID, userID id.UserID) *event.MemberEventContent {
member, ok := store.TryGetMember(roomID, userID)
if !ok {
func (store *MemoryStateStore) GetMember(ctx context.Context, roomID id.RoomID, userID id.UserID) (*event.MemberEventContent, error) {
member, err := store.TryGetMember(ctx, roomID, userID)
if member == nil && err == nil {
member = &event.MemberEventContent{Membership: event.MembershipLeave}
}
return member
return member, err
}
func (store *MemoryStateStore) TryGetMember(roomID id.RoomID, userID id.UserID) (member *event.MemberEventContent, ok bool) {
func (store *MemoryStateStore) TryGetMember(_ context.Context, roomID id.RoomID, userID id.UserID) (member *event.MemberEventContent, err error) {
store.membersLock.RLock()
defer store.membersLock.RUnlock()
members, membersOk := store.Members[roomID]
if !membersOk {
return
}
member, ok = members[userID]
member = members[userID]
return
}
func (store *MemoryStateStore) IsInRoom(roomID id.RoomID, userID id.UserID) bool {
return store.IsMembership(roomID, userID, "join")
func (store *MemoryStateStore) IsInRoom(ctx context.Context, roomID id.RoomID, userID id.UserID) bool {
return store.IsMembership(ctx, roomID, userID, "join")
}
func (store *MemoryStateStore) IsInvited(roomID id.RoomID, userID id.UserID) bool {
return store.IsMembership(roomID, userID, "join", "invite")
func (store *MemoryStateStore) IsInvited(ctx context.Context, roomID id.RoomID, userID id.UserID) bool {
return store.IsMembership(ctx, roomID, userID, "join", "invite")
}
func (store *MemoryStateStore) IsMembership(roomID id.RoomID, userID id.UserID, allowedMemberships ...event.Membership) bool {
membership := store.GetMembership(roomID, userID)
func (store *MemoryStateStore) IsMembership(ctx context.Context, roomID id.RoomID, userID id.UserID, allowedMemberships ...event.Membership) bool {
membership := exerrors.Must(store.GetMembership(ctx, roomID, userID))
for _, allowedMembership := range allowedMemberships {
if allowedMembership == membership {
return true
@@ -157,7 +172,7 @@ func (store *MemoryStateStore) IsMembership(roomID id.RoomID, userID id.UserID,
return false
}
func (store *MemoryStateStore) SetMembership(roomID id.RoomID, userID id.UserID, membership event.Membership) {
func (store *MemoryStateStore) SetMembership(_ context.Context, roomID id.RoomID, userID id.UserID, membership event.Membership) error {
store.membersLock.Lock()
members, ok := store.Members[roomID]
if !ok {
@@ -175,9 +190,10 @@ func (store *MemoryStateStore) SetMembership(roomID id.RoomID, userID id.UserID,
}
store.Members[roomID] = members
store.membersLock.Unlock()
return nil
}
func (store *MemoryStateStore) SetMember(roomID id.RoomID, userID id.UserID, member *event.MemberEventContent) {
func (store *MemoryStateStore) SetMember(_ context.Context, roomID id.RoomID, userID id.UserID, member *event.MemberEventContent) error {
store.membersLock.Lock()
members, ok := store.Members[roomID]
if !ok {
@@ -189,14 +205,15 @@ func (store *MemoryStateStore) SetMember(roomID id.RoomID, userID id.UserID, mem
}
store.Members[roomID] = members
store.membersLock.Unlock()
return nil
}
func (store *MemoryStateStore) ClearCachedMembers(roomID id.RoomID, memberships ...event.Membership) {
func (store *MemoryStateStore) ClearCachedMembers(_ context.Context, roomID id.RoomID, memberships ...event.Membership) error {
store.membersLock.Lock()
defer store.membersLock.Unlock()
members, ok := store.Members[roomID]
if !ok {
return
return nil
}
for userID, member := range members {
for _, membership := range memberships {
@@ -206,46 +223,49 @@ func (store *MemoryStateStore) ClearCachedMembers(roomID id.RoomID, memberships
}
}
}
return nil
}
func (store *MemoryStateStore) SetPowerLevels(roomID id.RoomID, levels *event.PowerLevelsEventContent) {
func (store *MemoryStateStore) SetPowerLevels(_ context.Context, roomID id.RoomID, levels *event.PowerLevelsEventContent) error {
store.powerLevelsLock.Lock()
store.PowerLevels[roomID] = levels
store.powerLevelsLock.Unlock()
return nil
}
func (store *MemoryStateStore) GetPowerLevels(roomID id.RoomID) (levels *event.PowerLevelsEventContent) {
func (store *MemoryStateStore) GetPowerLevels(_ context.Context, roomID id.RoomID) (levels *event.PowerLevelsEventContent, err error) {
store.powerLevelsLock.RLock()
levels = store.PowerLevels[roomID]
store.powerLevelsLock.RUnlock()
return
}
func (store *MemoryStateStore) GetPowerLevel(roomID id.RoomID, userID id.UserID) int {
return store.GetPowerLevels(roomID).GetUserLevel(userID)
func (store *MemoryStateStore) GetPowerLevel(ctx context.Context, roomID id.RoomID, userID id.UserID) (int, error) {
return exerrors.Must(store.GetPowerLevels(ctx, roomID)).GetUserLevel(userID), nil
}
func (store *MemoryStateStore) GetPowerLevelRequirement(roomID id.RoomID, eventType event.Type) int {
return store.GetPowerLevels(roomID).GetEventLevel(eventType)
func (store *MemoryStateStore) GetPowerLevelRequirement(ctx context.Context, roomID id.RoomID, eventType event.Type) (int, error) {
return exerrors.Must(store.GetPowerLevels(ctx, roomID)).GetEventLevel(eventType), nil
}
func (store *MemoryStateStore) HasPowerLevel(roomID id.RoomID, userID id.UserID, eventType event.Type) bool {
return store.GetPowerLevel(roomID, userID) >= store.GetPowerLevelRequirement(roomID, eventType)
func (store *MemoryStateStore) HasPowerLevel(ctx context.Context, roomID id.RoomID, userID id.UserID, eventType event.Type) (bool, error) {
return exerrors.Must(store.GetPowerLevel(ctx, roomID, userID)) >= exerrors.Must(store.GetPowerLevelRequirement(ctx, roomID, eventType)), nil
}
func (store *MemoryStateStore) SetEncryptionEvent(roomID id.RoomID, content *event.EncryptionEventContent) {
func (store *MemoryStateStore) SetEncryptionEvent(_ context.Context, roomID id.RoomID, content *event.EncryptionEventContent) error {
store.encryptionLock.Lock()
store.Encryption[roomID] = content
store.encryptionLock.Unlock()
return nil
}
func (store *MemoryStateStore) GetEncryptionEvent(roomID id.RoomID) *event.EncryptionEventContent {
func (store *MemoryStateStore) GetEncryptionEvent(_ context.Context, roomID id.RoomID) (*event.EncryptionEventContent, error) {
store.encryptionLock.RLock()
defer store.encryptionLock.RUnlock()
return store.Encryption[roomID]
return store.Encryption[roomID], nil
}
func (store *MemoryStateStore) IsEncrypted(roomID id.RoomID) bool {
cfg := store.GetEncryptionEvent(roomID)
return cfg != nil && cfg.Algorithm == id.AlgorithmMegolmV1
func (store *MemoryStateStore) IsEncrypted(ctx context.Context, roomID id.RoomID) (bool, error) {
cfg, err := store.GetEncryptionEvent(ctx, roomID)
return cfg != nil && cfg.Algorithm == id.AlgorithmMegolmV1, err
}

145
vendor/maunium.net/go/mautrix/sync.go generated vendored
View File

@@ -1,4 +1,4 @@
// Copyright (c) 2020 Tulir Asokan
// Copyright (c) 2024 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
@@ -7,6 +7,7 @@
package mautrix
import (
"context"
"errors"
"fmt"
"runtime/debug"
@@ -16,78 +17,17 @@ import (
"maunium.net/go/mautrix/id"
)
// EventSource represents the part of the sync response that an event came from.
type EventSource int
const (
EventSourcePresence EventSource = 1 << iota
EventSourceJoin
EventSourceInvite
EventSourceLeave
EventSourceAccountData
EventSourceTimeline
EventSourceState
EventSourceEphemeral
EventSourceToDevice
EventSourceDecrypted
)
const primaryTypes = EventSourcePresence | EventSourceAccountData | EventSourceToDevice | EventSourceTimeline | EventSourceState
const roomSections = EventSourceJoin | EventSourceInvite | EventSourceLeave
const roomableTypes = EventSourceAccountData | EventSourceTimeline | EventSourceState
const encryptableTypes = roomableTypes | EventSourceToDevice
func (es EventSource) String() string {
var typeName string
switch es & primaryTypes {
case EventSourcePresence:
typeName = "presence"
case EventSourceAccountData:
typeName = "account data"
case EventSourceToDevice:
typeName = "to-device"
case EventSourceTimeline:
typeName = "timeline"
case EventSourceState:
typeName = "state"
default:
return fmt.Sprintf("unknown (%d)", es)
}
if es&roomableTypes != 0 {
switch es & roomSections {
case EventSourceJoin:
typeName = "joined room " + typeName
case EventSourceInvite:
typeName = "invited room " + typeName
case EventSourceLeave:
typeName = "left room " + typeName
default:
return fmt.Sprintf("unknown (%s+%d)", typeName, es)
}
es &^= roomSections
}
if es&encryptableTypes != 0 && es&EventSourceDecrypted != 0 {
typeName += " (decrypted)"
es &^= EventSourceDecrypted
}
es &^= primaryTypes
if es != 0 {
return fmt.Sprintf("unknown (%s+%d)", typeName, es)
}
return typeName
}
// EventHandler handles a single event from a sync response.
type EventHandler func(source EventSource, evt *event.Event)
type EventHandler func(ctx context.Context, evt *event.Event)
// SyncHandler handles a whole sync response. If the return value is false, handling will be stopped completely.
type SyncHandler func(resp *RespSync, since string) bool
type SyncHandler func(ctx context.Context, resp *RespSync, since string) bool
// Syncer is an interface that must be satisfied in order to do /sync requests on a client.
type Syncer interface {
// ProcessResponse processes the /sync response. The since parameter is the since= value that was used to produce the response.
// This is useful for detecting the very first sync (since=""). If an error is return, Syncing will be stopped permanently.
ProcessResponse(resp *RespSync, since string) error
ProcessResponse(ctx context.Context, resp *RespSync, since string) error
// OnFailedSync returns either the time to wait before retrying or an error to stop syncing permanently.
OnFailedSync(res *RespSync, err error) (time.Duration, error)
// GetFilterJSON for the given user ID. NOT the filter ID.
@@ -101,7 +41,7 @@ type ExtensibleSyncer interface {
}
type DispatchableSyncer interface {
Dispatch(source EventSource, evt *event.Event)
Dispatch(ctx context.Context, evt *event.Event)
}
// DefaultSyncer is the default syncing implementation. You can either write your own syncer, or selectively
@@ -134,14 +74,17 @@ func NewDefaultSyncer() *DefaultSyncer {
globalListeners: []EventHandler{},
ParseEventContent: true,
ParseErrorHandler: func(evt *event.Event, err error) bool {
return false
// By default, drop known events that can't be parsed, but let unknown events through
return errors.Is(err, event.ErrUnsupportedContentType) ||
// Also allow events that had their content already parsed by some other function
errors.Is(err, event.ErrContentAlreadyParsed)
},
}
}
// ProcessResponse processes the /sync response in a way suitable for bots. "Suitable for bots" means a stream of
// unrepeating events. Returns a fatal error if a listener panics.
func (s *DefaultSyncer) ProcessResponse(res *RespSync, since string) (err error) {
func (s *DefaultSyncer) ProcessResponse(ctx context.Context, res *RespSync, since string) (err error) {
defer func() {
if r := recover(); r != nil {
err = fmt.Errorf("ProcessResponse panicked! since=%s panic=%s\n%s", since, r, debug.Stack())
@@ -149,38 +92,38 @@ func (s *DefaultSyncer) ProcessResponse(res *RespSync, since string) (err error)
}()
for _, listener := range s.syncListeners {
if !listener(res, since) {
if !listener(ctx, res, since) {
return
}
}
s.processSyncEvents("", res.ToDevice.Events, EventSourceToDevice)
s.processSyncEvents("", res.Presence.Events, EventSourcePresence)
s.processSyncEvents("", res.AccountData.Events, EventSourceAccountData)
s.processSyncEvents(ctx, "", res.ToDevice.Events, event.SourceToDevice)
s.processSyncEvents(ctx, "", res.Presence.Events, event.SourcePresence)
s.processSyncEvents(ctx, "", res.AccountData.Events, event.SourceAccountData)
for roomID, roomData := range res.Rooms.Join {
s.processSyncEvents(roomID, roomData.State.Events, EventSourceJoin|EventSourceState)
s.processSyncEvents(roomID, roomData.Timeline.Events, EventSourceJoin|EventSourceTimeline)
s.processSyncEvents(roomID, roomData.Ephemeral.Events, EventSourceJoin|EventSourceEphemeral)
s.processSyncEvents(roomID, roomData.AccountData.Events, EventSourceJoin|EventSourceAccountData)
s.processSyncEvents(ctx, roomID, roomData.State.Events, event.SourceJoin|event.SourceState)
s.processSyncEvents(ctx, roomID, roomData.Timeline.Events, event.SourceJoin|event.SourceTimeline)
s.processSyncEvents(ctx, roomID, roomData.Ephemeral.Events, event.SourceJoin|event.SourceEphemeral)
s.processSyncEvents(ctx, roomID, roomData.AccountData.Events, event.SourceJoin|event.SourceAccountData)
}
for roomID, roomData := range res.Rooms.Invite {
s.processSyncEvents(roomID, roomData.State.Events, EventSourceInvite|EventSourceState)
s.processSyncEvents(ctx, roomID, roomData.State.Events, event.SourceInvite|event.SourceState)
}
for roomID, roomData := range res.Rooms.Leave {
s.processSyncEvents(roomID, roomData.State.Events, EventSourceLeave|EventSourceState)
s.processSyncEvents(roomID, roomData.Timeline.Events, EventSourceLeave|EventSourceTimeline)
s.processSyncEvents(ctx, roomID, roomData.State.Events, event.SourceLeave|event.SourceState)
s.processSyncEvents(ctx, roomID, roomData.Timeline.Events, event.SourceLeave|event.SourceTimeline)
}
return
}
func (s *DefaultSyncer) processSyncEvents(roomID id.RoomID, events []*event.Event, source EventSource) {
func (s *DefaultSyncer) processSyncEvents(ctx context.Context, roomID id.RoomID, events []*event.Event, source event.Source) {
for _, evt := range events {
s.processSyncEvent(roomID, evt, source)
s.processSyncEvent(ctx, roomID, evt, source)
}
}
func (s *DefaultSyncer) processSyncEvent(roomID id.RoomID, evt *event.Event, source EventSource) {
func (s *DefaultSyncer) processSyncEvent(ctx context.Context, roomID id.RoomID, evt *event.Event, source event.Source) {
evt.RoomID = roomID
// Ensure the type class is correct. It's safe to mutate the class since the event type is not a pointer.
@@ -188,11 +131,11 @@ func (s *DefaultSyncer) processSyncEvent(roomID id.RoomID, evt *event.Event, sou
switch {
case evt.StateKey != nil:
evt.Type.Class = event.StateEventType
case source == EventSourcePresence, source&EventSourceEphemeral != 0:
case source == event.SourcePresence, source&event.SourceEphemeral != 0:
evt.Type.Class = event.EphemeralEventType
case source&EventSourceAccountData != 0:
case source&event.SourceAccountData != 0:
evt.Type.Class = event.AccountDataEventType
case source == EventSourceToDevice:
case source == event.SourceToDevice:
evt.Type.Class = event.ToDeviceEventType
default:
evt.Type.Class = event.MessageEventType
@@ -205,17 +148,18 @@ func (s *DefaultSyncer) processSyncEvent(roomID id.RoomID, evt *event.Event, sou
}
}
s.Dispatch(source, evt)
evt.Mautrix.EventSource = source
s.Dispatch(ctx, evt)
}
func (s *DefaultSyncer) Dispatch(source EventSource, evt *event.Event) {
func (s *DefaultSyncer) Dispatch(ctx context.Context, evt *event.Event) {
for _, fn := range s.globalListeners {
fn(source, evt)
fn(ctx, evt)
}
listeners, exists := s.listeners[evt.Type]
if exists {
for _, fn := range listeners {
fn(source, evt)
fn(ctx, evt)
}
}
}
@@ -263,31 +207,18 @@ func (s *DefaultSyncer) GetFilterJSON(userID id.UserID) *Filter {
return s.FilterJSON
}
// OldEventIgnorer is a utility struct for bots to ignore events from before the bot joined the room.
//
// Deprecated: Use Client.DontProcessOldEvents instead.
type OldEventIgnorer struct {
UserID id.UserID
}
func (oei *OldEventIgnorer) Register(syncer ExtensibleSyncer) {
syncer.OnSync(oei.DontProcessOldEvents)
}
func (oei *OldEventIgnorer) DontProcessOldEvents(resp *RespSync, since string) bool {
return dontProcessOldEvents(oei.UserID, resp, since)
}
// DontProcessOldEvents is a sync handler that removes rooms that the user just joined.
// It's meant for bots to ignore events from before the bot joined the room.
//
// To use it, register it with your Syncer, e.g.:
//
// cli.Syncer.(mautrix.ExtensibleSyncer).OnSync(cli.DontProcessOldEvents)
func (cli *Client) DontProcessOldEvents(resp *RespSync, since string) bool {
func (cli *Client) DontProcessOldEvents(_ context.Context, resp *RespSync, since string) bool {
return dontProcessOldEvents(cli.UserID, resp, since)
}
var _ SyncHandler = (*Client)(nil).DontProcessOldEvents
func dontProcessOldEvents(userID id.UserID, resp *RespSync, since string) bool {
if since == "" {
return false
@@ -324,7 +255,7 @@ func dontProcessOldEvents(userID id.UserID, resp *RespSync, since string) bool {
// To use it, register it with your Syncer, e.g.:
//
// cli.Syncer.(mautrix.ExtensibleSyncer).OnSync(cli.MoveInviteState)
func (cli *Client) MoveInviteState(resp *RespSync, _ string) bool {
func (cli *Client) MoveInviteState(ctx context.Context, resp *RespSync, _ string) bool {
for _, meta := range resp.Rooms.Invite {
var inviteState []event.StrippedState
var inviteEvt *event.Event
@@ -349,3 +280,5 @@ func (cli *Client) MoveInviteState(resp *RespSync, _ string) bool {
}
return true
}
var _ SyncHandler = (*Client)(nil).MoveInviteState

View File

@@ -1,21 +1,26 @@
package mautrix
import (
"context"
"errors"
"fmt"
"maunium.net/go/mautrix/id"
)
var _ SyncStore = (*MemorySyncStore)(nil)
var _ SyncStore = (*AccountDataStore)(nil)
// SyncStore is an interface which must be satisfied to store client data.
//
// You can either write a struct which persists this data to disk, or you can use the
// provided "MemorySyncStore" which just keeps data around in-memory which is lost on
// restarts.
type SyncStore interface {
SaveFilterID(userID id.UserID, filterID string)
LoadFilterID(userID id.UserID) string
SaveNextBatch(userID id.UserID, nextBatchToken string)
LoadNextBatch(userID id.UserID) string
SaveFilterID(ctx context.Context, userID id.UserID, filterID string) error
LoadFilterID(ctx context.Context, userID id.UserID) (string, error)
SaveNextBatch(ctx context.Context, userID id.UserID, nextBatchToken string) error
LoadNextBatch(ctx context.Context, userID id.UserID) (string, error)
}
// Deprecated: renamed to SyncStore
@@ -32,23 +37,25 @@ type MemorySyncStore struct {
}
// SaveFilterID to memory.
func (s *MemorySyncStore) SaveFilterID(userID id.UserID, filterID string) {
func (s *MemorySyncStore) SaveFilterID(ctx context.Context, userID id.UserID, filterID string) error {
s.Filters[userID] = filterID
return nil
}
// LoadFilterID from memory.
func (s *MemorySyncStore) LoadFilterID(userID id.UserID) string {
return s.Filters[userID]
func (s *MemorySyncStore) LoadFilterID(ctx context.Context, userID id.UserID) (string, error) {
return s.Filters[userID], nil
}
// SaveNextBatch to memory.
func (s *MemorySyncStore) SaveNextBatch(userID id.UserID, nextBatchToken string) {
func (s *MemorySyncStore) SaveNextBatch(ctx context.Context, userID id.UserID, nextBatchToken string) error {
s.NextBatch[userID] = nextBatchToken
return nil
}
// LoadNextBatch from memory.
func (s *MemorySyncStore) LoadNextBatch(userID id.UserID) string {
return s.NextBatch[userID]
func (s *MemorySyncStore) LoadNextBatch(ctx context.Context, userID id.UserID) (string, error) {
return s.NextBatch[userID], nil
}
// NewMemorySyncStore constructs a new MemorySyncStore.
@@ -72,34 +79,35 @@ type accountData struct {
NextBatch string `json:"next_batch"`
}
func (s *AccountDataStore) SaveFilterID(userID id.UserID, filterID string) {
func (s *AccountDataStore) SaveFilterID(ctx context.Context, userID id.UserID, filterID string) error {
if userID.String() != s.client.UserID.String() {
panic("AccountDataStore must only be used with a single account")
}
s.FilterID = filterID
return nil
}
func (s *AccountDataStore) LoadFilterID(userID id.UserID) string {
func (s *AccountDataStore) LoadFilterID(ctx context.Context, userID id.UserID) (string, error) {
if userID.String() != s.client.UserID.String() {
panic("AccountDataStore must only be used with a single account")
}
return s.FilterID
return s.FilterID, nil
}
func (s *AccountDataStore) SaveNextBatch(userID id.UserID, nextBatchToken string) {
func (s *AccountDataStore) SaveNextBatch(ctx context.Context, userID id.UserID, nextBatchToken string) error {
if userID.String() != s.client.UserID.String() {
panic("AccountDataStore must only be used with a single account")
} else if nextBatchToken == s.nextBatch {
return
return nil
}
data := accountData{
NextBatch: nextBatchToken,
}
err := s.client.SetAccountData(s.EventType, data)
err := s.client.SetAccountData(ctx, s.EventType, data)
if err != nil {
s.client.Log.Warn().Err(err).Msg("Failed to save next batch token to account data")
return fmt.Errorf("failed to save next batch token to account data: %w", err)
} else {
s.client.Log.Debug().
Str("old_token", s.nextBatch).
@@ -107,28 +115,29 @@ func (s *AccountDataStore) SaveNextBatch(userID id.UserID, nextBatchToken string
Msg("Saved next batch token")
s.nextBatch = nextBatchToken
}
return nil
}
func (s *AccountDataStore) LoadNextBatch(userID id.UserID) string {
func (s *AccountDataStore) LoadNextBatch(ctx context.Context, userID id.UserID) (string, error) {
if userID.String() != s.client.UserID.String() {
panic("AccountDataStore must only be used with a single account")
}
data := &accountData{}
err := s.client.GetAccountData(s.EventType, data)
err := s.client.GetAccountData(ctx, s.EventType, data)
if err != nil {
if errors.Is(err, MNotFound) {
s.client.Log.Debug().Msg("No next batch token found in account data")
return "", nil
} else {
s.client.Log.Warn().Err(err).Msg("Failed to load next batch token from account data")
return "", fmt.Errorf("failed to load next batch token from account data: %w", err)
}
return ""
}
s.nextBatch = data.NextBatch
s.client.Log.Debug().Str("next_batch", data.NextBatch).Msg("Loaded next batch token from account data")
return s.nextBatch
return s.nextBatch, nil
}
// NewAccountDataStore returns a new AccountDataStore, which stores

View File

@@ -7,7 +7,7 @@ import (
"strings"
)
const Version = "v0.16.2"
const Version = "v0.17.0"
var GoModVersion = ""
var Commit = ""

View File

@@ -93,6 +93,8 @@ var (
SpecV15 = MustParseSpecVersion("v1.5")
SpecV16 = MustParseSpecVersion("v1.6")
SpecV17 = MustParseSpecVersion("v1.7")
SpecV18 = MustParseSpecVersion("v1.8")
SpecV19 = MustParseSpecVersion("v1.9")
)
func (svf SpecVersionFormat) String() string {