refactor to mautrix 0.17.x; update deps
This commit is contained in:
4
vendor/maunium.net/go/mautrix/.pre-commit-config.yaml
generated
vendored
4
vendor/maunium.net/go/mautrix/.pre-commit-config.yaml
generated
vendored
@@ -12,4 +12,8 @@ repos:
|
||||
rev: v1.0.0-rc.1
|
||||
hooks:
|
||||
- id: go-imports-repo
|
||||
args:
|
||||
- "-local"
|
||||
- "maunium.net/go/mautrix"
|
||||
- "-w"
|
||||
- id: go-vet-repo-mod
|
||||
|
||||
24
vendor/maunium.net/go/mautrix/CHANGELOG.md
generated
vendored
24
vendor/maunium.net/go/mautrix/CHANGELOG.md
generated
vendored
@@ -1,3 +1,27 @@
|
||||
## v0.17.0 (2024-01-16)
|
||||
|
||||
* **Breaking change *(bridge)*** Added raw event to portal membership handling
|
||||
functions.
|
||||
* **Breaking change *(everything)*** Added context parameters to all functions
|
||||
(started by [@recht] in [#144]).
|
||||
* **Breaking change *(client)*** Moved `EventSource` to `event.Source`.
|
||||
* *(client)* Removed deprecated `OldEventIgnorer`. The non-deprecated version
|
||||
(`Client.DontProcessOldEvents`) is still available.
|
||||
* *(crypto)* Added experimental pure Go Olm implementation to replace libolm
|
||||
(thanks to [@DerLukas15] in [#106]).
|
||||
* You can use the `goolm` build tag to the new implementation.
|
||||
* *(bridge)* Added context parameter for bridge command events.
|
||||
* *(bridge)* Added method to allow custom validation for the entire config.
|
||||
* *(client)* Changed default syncer to not drop unknown events.
|
||||
* The syncer will still drop known events if parsing the content fails.
|
||||
* The behavior can be changed by changing the `ParseErrorHandler` function.
|
||||
* *(crypto)* Fixed some places using math/rand instead of crypto/rand.
|
||||
|
||||
[@DerLukas15]: https://github.com/DerLukas15
|
||||
[@recht]: https://github.com/recht
|
||||
[#106]: https://github.com/mautrix/go/pull/106
|
||||
[#144]: https://github.com/mautrix/go/pull/144
|
||||
|
||||
## v0.16.2 (2023-11-16)
|
||||
|
||||
* *(event)* Added `Redacts` field to `RedactionEventContent` for room v11+.
|
||||
|
||||
5
vendor/maunium.net/go/mautrix/README.md
generated
vendored
5
vendor/maunium.net/go/mautrix/README.md
generated
vendored
@@ -17,8 +17,3 @@ In addition to the basic client API features the original project has, this fram
|
||||
* Structs for parsing event content
|
||||
* Helpers for parsing and generating Matrix HTML
|
||||
* Helpers for handling push rules
|
||||
|
||||
This project contains modules that are licensed under Apache 2.0:
|
||||
|
||||
* [maunium.net/go/mautrix/crypto/canonicaljson](crypto/canonicaljson)
|
||||
* [maunium.net/go/mautrix/crypto/olm](crypto/olm)
|
||||
|
||||
642
vendor/maunium.net/go/mautrix/client.go
generated
vendored
642
vendor/maunium.net/go/mautrix/client.go
generated
vendored
File diff suppressed because it is too large
Load Diff
2
vendor/maunium.net/go/mautrix/crypto/account.go
generated
vendored
2
vendor/maunium.net/go/mautrix/crypto/account.go
generated
vendored
@@ -1,4 +1,4 @@
|
||||
// Copyright (c) 2020 Tulir Asokan
|
||||
// Copyright (c) 2024 Tulir Asokan
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||
// License, v. 2.0. If a copy of the MPL was not distributed with this
|
||||
|
||||
177
vendor/maunium.net/go/mautrix/crypto/canonicaljson/LICENSE
generated
vendored
177
vendor/maunium.net/go/mautrix/crypto/canonicaljson/LICENSE
generated
vendored
@@ -1,177 +0,0 @@
|
||||
|
||||
Apache License
|
||||
Version 2.0, January 2004
|
||||
http://www.apache.org/licenses/
|
||||
|
||||
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
||||
|
||||
1. Definitions.
|
||||
|
||||
"License" shall mean the terms and conditions for use, reproduction,
|
||||
and distribution as defined by Sections 1 through 9 of this document.
|
||||
|
||||
"Licensor" shall mean the copyright owner or entity authorized by
|
||||
the copyright owner that is granting the License.
|
||||
|
||||
"Legal Entity" shall mean the union of the acting entity and all
|
||||
other entities that control, are controlled by, or are under common
|
||||
control with that entity. For the purposes of this definition,
|
||||
"control" means (i) the power, direct or indirect, to cause the
|
||||
direction or management of such entity, whether by contract or
|
||||
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
||||
outstanding shares, or (iii) beneficial ownership of such entity.
|
||||
|
||||
"You" (or "Your") shall mean an individual or Legal Entity
|
||||
exercising permissions granted by this License.
|
||||
|
||||
"Source" form shall mean the preferred form for making modifications,
|
||||
including but not limited to software source code, documentation
|
||||
source, and configuration files.
|
||||
|
||||
"Object" form shall mean any form resulting from mechanical
|
||||
transformation or translation of a Source form, including but
|
||||
not limited to compiled object code, generated documentation,
|
||||
and conversions to other media types.
|
||||
|
||||
"Work" shall mean the work of authorship, whether in Source or
|
||||
Object form, made available under the License, as indicated by a
|
||||
copyright notice that is included in or attached to the work
|
||||
(an example is provided in the Appendix below).
|
||||
|
||||
"Derivative Works" shall mean any work, whether in Source or Object
|
||||
form, that is based on (or derived from) the Work and for which the
|
||||
editorial revisions, annotations, elaborations, or other modifications
|
||||
represent, as a whole, an original work of authorship. For the purposes
|
||||
of this License, Derivative Works shall not include works that remain
|
||||
separable from, or merely link (or bind by name) to the interfaces of,
|
||||
the Work and Derivative Works thereof.
|
||||
|
||||
"Contribution" shall mean any work of authorship, including
|
||||
the original version of the Work and any modifications or additions
|
||||
to that Work or Derivative Works thereof, that is intentionally
|
||||
submitted to Licensor for inclusion in the Work by the copyright owner
|
||||
or by an individual or Legal Entity authorized to submit on behalf of
|
||||
the copyright owner. For the purposes of this definition, "submitted"
|
||||
means any form of electronic, verbal, or written communication sent
|
||||
to the Licensor or its representatives, including but not limited to
|
||||
communication on electronic mailing lists, source code control systems,
|
||||
and issue tracking systems that are managed by, or on behalf of, the
|
||||
Licensor for the purpose of discussing and improving the Work, but
|
||||
excluding communication that is conspicuously marked or otherwise
|
||||
designated in writing by the copyright owner as "Not a Contribution."
|
||||
|
||||
"Contributor" shall mean Licensor and any individual or Legal Entity
|
||||
on behalf of whom a Contribution has been received by Licensor and
|
||||
subsequently incorporated within the Work.
|
||||
|
||||
2. Grant of Copyright License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
copyright license to reproduce, prepare Derivative Works of,
|
||||
publicly display, publicly perform, sublicense, and distribute the
|
||||
Work and such Derivative Works in Source or Object form.
|
||||
|
||||
3. Grant of Patent License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
(except as stated in this section) patent license to make, have made,
|
||||
use, offer to sell, sell, import, and otherwise transfer the Work,
|
||||
where such license applies only to those patent claims licensable
|
||||
by such Contributor that are necessarily infringed by their
|
||||
Contribution(s) alone or by combination of their Contribution(s)
|
||||
with the Work to which such Contribution(s) was submitted. If You
|
||||
institute patent litigation against any entity (including a
|
||||
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
||||
or a Contribution incorporated within the Work constitutes direct
|
||||
or contributory patent infringement, then any patent licenses
|
||||
granted to You under this License for that Work shall terminate
|
||||
as of the date such litigation is filed.
|
||||
|
||||
4. Redistribution. You may reproduce and distribute copies of the
|
||||
Work or Derivative Works thereof in any medium, with or without
|
||||
modifications, and in Source or Object form, provided that You
|
||||
meet the following conditions:
|
||||
|
||||
(a) You must give any other recipients of the Work or
|
||||
Derivative Works a copy of this License; and
|
||||
|
||||
(b) You must cause any modified files to carry prominent notices
|
||||
stating that You changed the files; and
|
||||
|
||||
(c) You must retain, in the Source form of any Derivative Works
|
||||
that You distribute, all copyright, patent, trademark, and
|
||||
attribution notices from the Source form of the Work,
|
||||
excluding those notices that do not pertain to any part of
|
||||
the Derivative Works; and
|
||||
|
||||
(d) If the Work includes a "NOTICE" text file as part of its
|
||||
distribution, then any Derivative Works that You distribute must
|
||||
include a readable copy of the attribution notices contained
|
||||
within such NOTICE file, excluding those notices that do not
|
||||
pertain to any part of the Derivative Works, in at least one
|
||||
of the following places: within a NOTICE text file distributed
|
||||
as part of the Derivative Works; within the Source form or
|
||||
documentation, if provided along with the Derivative Works; or,
|
||||
within a display generated by the Derivative Works, if and
|
||||
wherever such third-party notices normally appear. The contents
|
||||
of the NOTICE file are for informational purposes only and
|
||||
do not modify the License. You may add Your own attribution
|
||||
notices within Derivative Works that You distribute, alongside
|
||||
or as an addendum to the NOTICE text from the Work, provided
|
||||
that such additional attribution notices cannot be construed
|
||||
as modifying the License.
|
||||
|
||||
You may add Your own copyright statement to Your modifications and
|
||||
may provide additional or different license terms and conditions
|
||||
for use, reproduction, or distribution of Your modifications, or
|
||||
for any such Derivative Works as a whole, provided Your use,
|
||||
reproduction, and distribution of the Work otherwise complies with
|
||||
the conditions stated in this License.
|
||||
|
||||
5. Submission of Contributions. Unless You explicitly state otherwise,
|
||||
any Contribution intentionally submitted for inclusion in the Work
|
||||
by You to the Licensor shall be under the terms and conditions of
|
||||
this License, without any additional terms or conditions.
|
||||
Notwithstanding the above, nothing herein shall supersede or modify
|
||||
the terms of any separate license agreement you may have executed
|
||||
with Licensor regarding such Contributions.
|
||||
|
||||
6. Trademarks. This License does not grant permission to use the trade
|
||||
names, trademarks, service marks, or product names of the Licensor,
|
||||
except as required for reasonable and customary use in describing the
|
||||
origin of the Work and reproducing the content of the NOTICE file.
|
||||
|
||||
7. Disclaimer of Warranty. Unless required by applicable law or
|
||||
agreed to in writing, Licensor provides the Work (and each
|
||||
Contributor provides its Contributions) on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||
implied, including, without limitation, any warranties or conditions
|
||||
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
||||
PARTICULAR PURPOSE. You are solely responsible for determining the
|
||||
appropriateness of using or redistributing the Work and assume any
|
||||
risks associated with Your exercise of permissions under this License.
|
||||
|
||||
8. Limitation of Liability. In no event and under no legal theory,
|
||||
whether in tort (including negligence), contract, or otherwise,
|
||||
unless required by applicable law (such as deliberate and grossly
|
||||
negligent acts) or agreed to in writing, shall any Contributor be
|
||||
liable to You for damages, including any direct, indirect, special,
|
||||
incidental, or consequential damages of any character arising as a
|
||||
result of this License or out of the use or inability to use the
|
||||
Work (including but not limited to damages for loss of goodwill,
|
||||
work stoppage, computer failure or malfunction, or any and all
|
||||
other commercial damages or losses), even if such Contributor
|
||||
has been advised of the possibility of such damages.
|
||||
|
||||
9. Accepting Warranty or Additional Liability. While redistributing
|
||||
the Work or Derivative Works thereof, You may choose to offer,
|
||||
and charge a fee for, acceptance of support, warranty, indemnity,
|
||||
or other liability obligations and/or rights consistent with this
|
||||
License. However, in accepting such obligations, You may act only
|
||||
on Your own behalf and on Your sole responsibility, not on behalf
|
||||
of any other Contributor, and only if You agree to indemnify,
|
||||
defend, and hold each Contributor harmless for any liability
|
||||
incurred by, or claims asserted against, such Contributor by reason
|
||||
of your accepting any such warranty or additional liability.
|
||||
|
||||
END OF TERMS AND CONDITIONS
|
||||
2
vendor/maunium.net/go/mautrix/crypto/canonicaljson/README.md
generated
vendored
2
vendor/maunium.net/go/mautrix/crypto/canonicaljson/README.md
generated
vendored
@@ -2,3 +2,5 @@
|
||||
This is a Go package to produce Matrix [Canonical JSON](https://matrix.org/docs/spec/appendices#canonical-json).
|
||||
It is essentially just [json.go](https://github.com/matrix-org/gomatrixserverlib/blob/master/json.go)
|
||||
from gomatrixserverlib without all the other files that are completely useless for non-server use cases.
|
||||
|
||||
The original project is licensed under the Apache 2.0 license.
|
||||
|
||||
5
vendor/maunium.net/go/mautrix/crypto/cross_sign_key.go
generated
vendored
5
vendor/maunium.net/go/mautrix/crypto/cross_sign_key.go
generated
vendored
@@ -8,6 +8,7 @@
|
||||
package crypto
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"maunium.net/go/mautrix"
|
||||
@@ -89,7 +90,7 @@ func (mach *OlmMachine) GenerateCrossSigningKeys() (*CrossSigningKeysCache, erro
|
||||
}
|
||||
|
||||
// PublishCrossSigningKeys signs and uploads the public keys of the given cross-signing keys to the server.
|
||||
func (mach *OlmMachine) PublishCrossSigningKeys(keys *CrossSigningKeysCache, uiaCallback mautrix.UIACallback) error {
|
||||
func (mach *OlmMachine) PublishCrossSigningKeys(ctx context.Context, keys *CrossSigningKeysCache, uiaCallback mautrix.UIACallback) error {
|
||||
userID := mach.Client.UserID
|
||||
masterKeyID := id.NewKeyID(id.KeyAlgorithmEd25519, keys.MasterKey.PublicKey.String())
|
||||
masterKey := mautrix.CrossSigningKeys{
|
||||
@@ -134,7 +135,7 @@ func (mach *OlmMachine) PublishCrossSigningKeys(keys *CrossSigningKeysCache, uia
|
||||
},
|
||||
}
|
||||
|
||||
err = mach.Client.UploadCrossSigningKeys(&mautrix.UploadCrossSigningKeysReq{
|
||||
err = mach.Client.UploadCrossSigningKeys(ctx, &mautrix.UploadCrossSigningKeysReq{
|
||||
Master: masterKey,
|
||||
SelfSigning: selfKey,
|
||||
UserSigning: userKey,
|
||||
|
||||
13
vendor/maunium.net/go/mautrix/crypto/cross_sign_pubkey.go
generated
vendored
13
vendor/maunium.net/go/mautrix/crypto/cross_sign_pubkey.go
generated
vendored
@@ -1,4 +1,4 @@
|
||||
// Copyright (c) 2023 Tulir Asokan
|
||||
// Copyright (c) 2024 Tulir Asokan
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||
// License, v. 2.0. If a copy of the MPL was not distributed with this
|
||||
@@ -7,6 +7,7 @@
|
||||
package crypto
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"maunium.net/go/mautrix"
|
||||
@@ -19,7 +20,7 @@ type CrossSigningPublicKeysCache struct {
|
||||
UserSigningKey id.Ed25519
|
||||
}
|
||||
|
||||
func (mach *OlmMachine) GetOwnCrossSigningPublicKeys() *CrossSigningPublicKeysCache {
|
||||
func (mach *OlmMachine) GetOwnCrossSigningPublicKeys(ctx context.Context) *CrossSigningPublicKeysCache {
|
||||
if mach.crossSigningPubkeys != nil {
|
||||
return mach.crossSigningPubkeys
|
||||
}
|
||||
@@ -30,7 +31,7 @@ func (mach *OlmMachine) GetOwnCrossSigningPublicKeys() *CrossSigningPublicKeysCa
|
||||
if mach.crossSigningPubkeysFetched {
|
||||
return nil
|
||||
}
|
||||
cspk, err := mach.GetCrossSigningPublicKeys(mach.Client.UserID)
|
||||
cspk, err := mach.GetCrossSigningPublicKeys(ctx, mach.Client.UserID)
|
||||
if err != nil {
|
||||
mach.Log.Error().Err(err).Msg("Failed to get own cross-signing public keys")
|
||||
return nil
|
||||
@@ -40,8 +41,8 @@ func (mach *OlmMachine) GetOwnCrossSigningPublicKeys() *CrossSigningPublicKeysCa
|
||||
return mach.crossSigningPubkeys
|
||||
}
|
||||
|
||||
func (mach *OlmMachine) GetCrossSigningPublicKeys(userID id.UserID) (*CrossSigningPublicKeysCache, error) {
|
||||
dbKeys, err := mach.CryptoStore.GetCrossSigningKeys(userID)
|
||||
func (mach *OlmMachine) GetCrossSigningPublicKeys(ctx context.Context, userID id.UserID) (*CrossSigningPublicKeysCache, error) {
|
||||
dbKeys, err := mach.CryptoStore.GetCrossSigningKeys(ctx, userID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get keys from database: %w", err)
|
||||
}
|
||||
@@ -58,7 +59,7 @@ func (mach *OlmMachine) GetCrossSigningPublicKeys(userID id.UserID) (*CrossSigni
|
||||
}
|
||||
}
|
||||
|
||||
keys, err := mach.Client.QueryKeys(&mautrix.ReqQueryKeys{
|
||||
keys, err := mach.Client.QueryKeys(ctx, &mautrix.ReqQueryKeys{
|
||||
DeviceKeys: mautrix.DeviceKeysRequest{
|
||||
userID: mautrix.DeviceIDList{},
|
||||
},
|
||||
|
||||
35
vendor/maunium.net/go/mautrix/crypto/cross_sign_signing.go
generated
vendored
35
vendor/maunium.net/go/mautrix/crypto/cross_sign_signing.go
generated
vendored
@@ -1,5 +1,5 @@
|
||||
// Copyright (c) 2020 Nikos Filippakis
|
||||
// Copyright (c) 2023 Tulir Asokan
|
||||
// Copyright (c) 2024 Tulir Asokan
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||
// License, v. 2.0. If a copy of the MPL was not distributed with this
|
||||
@@ -8,6 +8,7 @@
|
||||
package crypto
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
@@ -33,8 +34,8 @@ var (
|
||||
ErrMismatchingMasterKeyMAC = errors.New("mismatching cross-signing master key MAC")
|
||||
)
|
||||
|
||||
func (mach *OlmMachine) fetchMasterKey(device *id.Device, content *event.VerificationMacEventContent, verState *verificationState, transactionID string) (id.Ed25519, error) {
|
||||
crossSignKeys, err := mach.CryptoStore.GetCrossSigningKeys(device.UserID)
|
||||
func (mach *OlmMachine) fetchMasterKey(ctx context.Context, device *id.Device, content *event.VerificationMacEventContent, verState *verificationState, transactionID string) (id.Ed25519, error) {
|
||||
crossSignKeys, err := mach.CryptoStore.GetCrossSigningKeys(ctx, device.UserID)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to fetch cross-signing keys: %w", err)
|
||||
}
|
||||
@@ -59,7 +60,7 @@ func (mach *OlmMachine) fetchMasterKey(device *id.Device, content *event.Verific
|
||||
}
|
||||
|
||||
// SignUser creates a cross-signing signature for a user, stores it and uploads it to the server.
|
||||
func (mach *OlmMachine) SignUser(userID id.UserID, masterKey id.Ed25519) error {
|
||||
func (mach *OlmMachine) SignUser(ctx context.Context, userID id.UserID, masterKey id.Ed25519) error {
|
||||
if userID == mach.Client.UserID {
|
||||
return ErrCantSignOwnMasterKey
|
||||
} else if mach.CrossSigningKeys == nil || mach.CrossSigningKeys.UserSigningKey == nil {
|
||||
@@ -74,7 +75,7 @@ func (mach *OlmMachine) SignUser(userID id.UserID, masterKey id.Ed25519) error {
|
||||
},
|
||||
}
|
||||
|
||||
signature, err := mach.signAndUpload(masterKeyObj, userID, masterKey.String(), mach.CrossSigningKeys.UserSigningKey)
|
||||
signature, err := mach.signAndUpload(ctx, masterKeyObj, userID, masterKey.String(), mach.CrossSigningKeys.UserSigningKey)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -84,7 +85,7 @@ func (mach *OlmMachine) SignUser(userID id.UserID, masterKey id.Ed25519) error {
|
||||
Str("signature", signature).
|
||||
Msg("Signed master key of user with our user-signing key")
|
||||
|
||||
if err := mach.CryptoStore.PutSignature(userID, masterKey, mach.Client.UserID, mach.CrossSigningKeys.UserSigningKey.PublicKey, signature); err != nil {
|
||||
if err := mach.CryptoStore.PutSignature(ctx, userID, masterKey, mach.Client.UserID, mach.CrossSigningKeys.UserSigningKey.PublicKey, signature); err != nil {
|
||||
return fmt.Errorf("error storing signature in crypto store: %w", err)
|
||||
}
|
||||
|
||||
@@ -92,7 +93,7 @@ func (mach *OlmMachine) SignUser(userID id.UserID, masterKey id.Ed25519) error {
|
||||
}
|
||||
|
||||
// SignOwnMasterKey uses the current account for signing the current user's master key and uploads the signature.
|
||||
func (mach *OlmMachine) SignOwnMasterKey() error {
|
||||
func (mach *OlmMachine) SignOwnMasterKey(ctx context.Context) error {
|
||||
if mach.CrossSigningKeys == nil {
|
||||
return ErrCrossSigningKeysNotCached
|
||||
} else if mach.account == nil {
|
||||
@@ -124,7 +125,7 @@ func (mach *OlmMachine) SignOwnMasterKey() error {
|
||||
Str("signature", signature).
|
||||
Msg("Signed own master key with own device key")
|
||||
|
||||
resp, err := mach.Client.UploadSignatures(&mautrix.ReqUploadSignatures{
|
||||
resp, err := mach.Client.UploadSignatures(ctx, &mautrix.ReqUploadSignatures{
|
||||
userID: map[string]mautrix.ReqKeysSignatures{
|
||||
masterKey.String(): masterKeyObj,
|
||||
},
|
||||
@@ -136,7 +137,7 @@ func (mach *OlmMachine) SignOwnMasterKey() error {
|
||||
return fmt.Errorf("%w: %+v", ErrSignatureUploadFail, resp.Failures)
|
||||
}
|
||||
|
||||
if err := mach.CryptoStore.PutSignature(userID, masterKey, userID, mach.account.SigningKey(), signature); err != nil {
|
||||
if err := mach.CryptoStore.PutSignature(ctx, userID, masterKey, userID, mach.account.SigningKey(), signature); err != nil {
|
||||
return fmt.Errorf("error storing signature in crypto store: %w", err)
|
||||
}
|
||||
|
||||
@@ -144,14 +145,14 @@ func (mach *OlmMachine) SignOwnMasterKey() error {
|
||||
}
|
||||
|
||||
// SignOwnDevice creates a cross-signing signature for a device belonging to the current user and uploads it to the server.
|
||||
func (mach *OlmMachine) SignOwnDevice(device *id.Device) error {
|
||||
func (mach *OlmMachine) SignOwnDevice(ctx context.Context, device *id.Device) error {
|
||||
if device.UserID != mach.Client.UserID {
|
||||
return ErrCantSignOtherDevice
|
||||
} else if mach.CrossSigningKeys == nil || mach.CrossSigningKeys.SelfSigningKey == nil {
|
||||
return ErrSelfSigningKeyNotCached
|
||||
}
|
||||
|
||||
deviceKeys, err := mach.getFullDeviceKeys(device)
|
||||
deviceKeys, err := mach.getFullDeviceKeys(ctx, device)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -166,7 +167,7 @@ func (mach *OlmMachine) SignOwnDevice(device *id.Device) error {
|
||||
deviceKeyObj.Keys[id.KeyID(keyID)] = key
|
||||
}
|
||||
|
||||
signature, err := mach.signAndUpload(deviceKeyObj, device.UserID, device.DeviceID.String(), mach.CrossSigningKeys.SelfSigningKey)
|
||||
signature, err := mach.signAndUpload(ctx, deviceKeyObj, device.UserID, device.DeviceID.String(), mach.CrossSigningKeys.SelfSigningKey)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -177,7 +178,7 @@ func (mach *OlmMachine) SignOwnDevice(device *id.Device) error {
|
||||
Str("signature", signature).
|
||||
Msg("Signed own device key with self-signing key")
|
||||
|
||||
if err := mach.CryptoStore.PutSignature(device.UserID, device.SigningKey, mach.Client.UserID, mach.CrossSigningKeys.SelfSigningKey.PublicKey, signature); err != nil {
|
||||
if err := mach.CryptoStore.PutSignature(ctx, device.UserID, device.SigningKey, mach.Client.UserID, mach.CrossSigningKeys.SelfSigningKey.PublicKey, signature); err != nil {
|
||||
return fmt.Errorf("error storing signature in crypto store: %w", err)
|
||||
}
|
||||
|
||||
@@ -186,8 +187,8 @@ func (mach *OlmMachine) SignOwnDevice(device *id.Device) error {
|
||||
|
||||
// getFullDeviceKeys gets the full device keys object for the given device.
|
||||
// This is used because we don't cache some of the details like list of algorithms and unsupported key types.
|
||||
func (mach *OlmMachine) getFullDeviceKeys(device *id.Device) (*mautrix.DeviceKeys, error) {
|
||||
devicesKeys, err := mach.Client.QueryKeys(&mautrix.ReqQueryKeys{
|
||||
func (mach *OlmMachine) getFullDeviceKeys(ctx context.Context, device *id.Device) (*mautrix.DeviceKeys, error) {
|
||||
devicesKeys, err := mach.Client.QueryKeys(ctx, &mautrix.ReqQueryKeys{
|
||||
DeviceKeys: mautrix.DeviceKeysRequest{
|
||||
device.UserID: mautrix.DeviceIDList{device.DeviceID},
|
||||
},
|
||||
@@ -208,7 +209,7 @@ func (mach *OlmMachine) getFullDeviceKeys(device *id.Device) (*mautrix.DeviceKey
|
||||
}
|
||||
|
||||
// signAndUpload signs the given key signatures object and uploads it to the server.
|
||||
func (mach *OlmMachine) signAndUpload(req mautrix.ReqKeysSignatures, userID id.UserID, signedThing string, key *olm.PkSigning) (string, error) {
|
||||
func (mach *OlmMachine) signAndUpload(ctx context.Context, req mautrix.ReqKeysSignatures, userID id.UserID, signedThing string, key *olm.PkSigning) (string, error) {
|
||||
signature, err := key.SignJSON(req)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to sign JSON: %w", err)
|
||||
@@ -219,7 +220,7 @@ func (mach *OlmMachine) signAndUpload(req mautrix.ReqKeysSignatures, userID id.U
|
||||
},
|
||||
}
|
||||
|
||||
resp, err := mach.Client.UploadSignatures(&mautrix.ReqUploadSignatures{
|
||||
resp, err := mach.Client.UploadSignatures(ctx, &mautrix.ReqUploadSignatures{
|
||||
userID: map[string]mautrix.ReqKeysSignatures{
|
||||
signedThing: req,
|
||||
},
|
||||
|
||||
31
vendor/maunium.net/go/mautrix/crypto/cross_sign_ssss.go
generated
vendored
31
vendor/maunium.net/go/mautrix/crypto/cross_sign_ssss.go
generated
vendored
@@ -7,6 +7,7 @@
|
||||
package crypto
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"maunium.net/go/mautrix"
|
||||
@@ -16,16 +17,16 @@ import (
|
||||
)
|
||||
|
||||
// FetchCrossSigningKeysFromSSSS fetches all the cross-signing keys from SSSS, decrypts them using the given key and stores them in the olm machine.
|
||||
func (mach *OlmMachine) FetchCrossSigningKeysFromSSSS(key *ssss.Key) error {
|
||||
masterKey, err := mach.retrieveDecryptXSigningKey(event.AccountDataCrossSigningMaster, key)
|
||||
func (mach *OlmMachine) FetchCrossSigningKeysFromSSSS(ctx context.Context, key *ssss.Key) error {
|
||||
masterKey, err := mach.retrieveDecryptXSigningKey(ctx, event.AccountDataCrossSigningMaster, key)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
selfSignKey, err := mach.retrieveDecryptXSigningKey(event.AccountDataCrossSigningSelf, key)
|
||||
selfSignKey, err := mach.retrieveDecryptXSigningKey(ctx, event.AccountDataCrossSigningSelf, key)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
userSignKey, err := mach.retrieveDecryptXSigningKey(event.AccountDataCrossSigningUser, key)
|
||||
userSignKey, err := mach.retrieveDecryptXSigningKey(ctx, event.AccountDataCrossSigningUser, key)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -38,12 +39,12 @@ func (mach *OlmMachine) FetchCrossSigningKeysFromSSSS(key *ssss.Key) error {
|
||||
}
|
||||
|
||||
// retrieveDecryptXSigningKey retrieves the requested cross-signing key from SSSS and decrypts it using the given SSSS key.
|
||||
func (mach *OlmMachine) retrieveDecryptXSigningKey(keyName event.Type, key *ssss.Key) ([utils.AESCTRKeyLength]byte, error) {
|
||||
func (mach *OlmMachine) retrieveDecryptXSigningKey(ctx context.Context, keyName event.Type, key *ssss.Key) ([utils.AESCTRKeyLength]byte, error) {
|
||||
var decryptedKey [utils.AESCTRKeyLength]byte
|
||||
var encData ssss.EncryptedAccountDataEventContent
|
||||
|
||||
// retrieve and parse the account data for this key type from SSSS
|
||||
err := mach.Client.GetAccountData(keyName.Type, &encData)
|
||||
err := mach.Client.GetAccountData(ctx, keyName.Type, &encData)
|
||||
if err != nil {
|
||||
return decryptedKey, err
|
||||
}
|
||||
@@ -62,8 +63,8 @@ func (mach *OlmMachine) retrieveDecryptXSigningKey(keyName event.Type, key *ssss
|
||||
// is used. The base58-formatted recovery key is the first return parameter.
|
||||
//
|
||||
// The account password of the user is required for uploading keys to the server.
|
||||
func (mach *OlmMachine) GenerateAndUploadCrossSigningKeys(userPassword, passphrase string) (string, error) {
|
||||
key, err := mach.SSSS.GenerateAndUploadKey(passphrase)
|
||||
func (mach *OlmMachine) GenerateAndUploadCrossSigningKeys(ctx context.Context, userPassword, passphrase string) (string, error) {
|
||||
key, err := mach.SSSS.GenerateAndUploadKey(ctx, passphrase)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("failed to generate and upload SSSS key: %w", err)
|
||||
}
|
||||
@@ -77,12 +78,12 @@ func (mach *OlmMachine) GenerateAndUploadCrossSigningKeys(userPassword, passphra
|
||||
recoveryKey := key.RecoveryKey()
|
||||
|
||||
// Store the private keys in SSSS
|
||||
if err := mach.UploadCrossSigningKeysToSSSS(key, keysCache); err != nil {
|
||||
if err := mach.UploadCrossSigningKeysToSSSS(ctx, key, keysCache); err != nil {
|
||||
return recoveryKey, fmt.Errorf("failed to upload cross-signing keys to SSSS: %w", err)
|
||||
}
|
||||
|
||||
// Publish cross-signing keys
|
||||
err = mach.PublishCrossSigningKeys(keysCache, func(uiResp *mautrix.RespUserInteractive) interface{} {
|
||||
err = mach.PublishCrossSigningKeys(ctx, keysCache, func(uiResp *mautrix.RespUserInteractive) interface{} {
|
||||
return &mautrix.ReqUIAuthLogin{
|
||||
BaseAuthData: mautrix.BaseAuthData{
|
||||
Type: mautrix.AuthTypePassword,
|
||||
@@ -96,7 +97,7 @@ func (mach *OlmMachine) GenerateAndUploadCrossSigningKeys(userPassword, passphra
|
||||
return recoveryKey, fmt.Errorf("failed to publish cross-signing keys: %w", err)
|
||||
}
|
||||
|
||||
err = mach.SSSS.SetDefaultKeyID(key.ID)
|
||||
err = mach.SSSS.SetDefaultKeyID(ctx, key.ID)
|
||||
if err != nil {
|
||||
return recoveryKey, fmt.Errorf("failed to mark %s as the default key: %w", key.ID, err)
|
||||
}
|
||||
@@ -105,14 +106,14 @@ func (mach *OlmMachine) GenerateAndUploadCrossSigningKeys(userPassword, passphra
|
||||
}
|
||||
|
||||
// UploadCrossSigningKeysToSSSS stores the given cross-signing keys on the server encrypted with the given key.
|
||||
func (mach *OlmMachine) UploadCrossSigningKeysToSSSS(key *ssss.Key, keys *CrossSigningKeysCache) error {
|
||||
if err := mach.SSSS.SetEncryptedAccountData(event.AccountDataCrossSigningMaster, keys.MasterKey.Seed, key); err != nil {
|
||||
func (mach *OlmMachine) UploadCrossSigningKeysToSSSS(ctx context.Context, key *ssss.Key, keys *CrossSigningKeysCache) error {
|
||||
if err := mach.SSSS.SetEncryptedAccountData(ctx, event.AccountDataCrossSigningMaster, keys.MasterKey.Seed, key); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := mach.SSSS.SetEncryptedAccountData(event.AccountDataCrossSigningSelf, keys.SelfSigningKey.Seed, key); err != nil {
|
||||
if err := mach.SSSS.SetEncryptedAccountData(ctx, event.AccountDataCrossSigningSelf, keys.SelfSigningKey.Seed, key); err != nil {
|
||||
return err
|
||||
}
|
||||
if err := mach.SSSS.SetEncryptedAccountData(event.AccountDataCrossSigningUser, keys.UserSigningKey.Seed, key); err != nil {
|
||||
if err := mach.SSSS.SetEncryptedAccountData(ctx, event.AccountDataCrossSigningUser, keys.UserSigningKey.Seed, key); err != nil {
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
|
||||
10
vendor/maunium.net/go/mautrix/crypto/cross_sign_store.go
generated
vendored
10
vendor/maunium.net/go/mautrix/crypto/cross_sign_store.go
generated
vendored
@@ -1,5 +1,5 @@
|
||||
// Copyright (c) 2020 Nikos Filippakis
|
||||
// Copyright (c) 2023 Tulir Asokan
|
||||
// Copyright (c) 2024 Tulir Asokan
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||
// License, v. 2.0. If a copy of the MPL was not distributed with this
|
||||
@@ -19,7 +19,7 @@ func (mach *OlmMachine) storeCrossSigningKeys(ctx context.Context, crossSigningK
|
||||
log := mach.machOrContextLog(ctx)
|
||||
for userID, userKeys := range crossSigningKeys {
|
||||
log := log.With().Str("user_id", userID.String()).Logger()
|
||||
currentKeys, err := mach.CryptoStore.GetCrossSigningKeys(userID)
|
||||
currentKeys, err := mach.CryptoStore.GetCrossSigningKeys(ctx, userID)
|
||||
if err != nil {
|
||||
log.Error().Err(err).
|
||||
Msg("Error fetching current cross-signing keys of user")
|
||||
@@ -32,7 +32,7 @@ func (mach *OlmMachine) storeCrossSigningKeys(ctx context.Context, crossSigningK
|
||||
if newKeyUsage == curKeyUsage {
|
||||
if _, ok := userKeys.Keys[id.NewKeyID(id.KeyAlgorithmEd25519, curKey.Key.String())]; !ok {
|
||||
// old key is not in the new key map, so we drop signatures made by it
|
||||
if count, err := mach.CryptoStore.DropSignaturesByKey(userID, curKey.Key); err != nil {
|
||||
if count, err := mach.CryptoStore.DropSignaturesByKey(ctx, userID, curKey.Key); err != nil {
|
||||
log.Error().Err(err).Msg("Error deleting old signatures made by user")
|
||||
} else {
|
||||
log.Debug().
|
||||
@@ -50,7 +50,7 @@ func (mach *OlmMachine) storeCrossSigningKeys(ctx context.Context, crossSigningK
|
||||
log := log.With().Str("key", key.String()).Strs("usages", strishArray(userKeys.Usage)).Logger()
|
||||
for _, usage := range userKeys.Usage {
|
||||
log.Debug().Str("usage", string(usage)).Msg("Storing cross-signing key")
|
||||
if err = mach.CryptoStore.PutCrossSigningKey(userID, usage, key); err != nil {
|
||||
if err = mach.CryptoStore.PutCrossSigningKey(ctx, userID, usage, key); err != nil {
|
||||
log.Error().Err(err).Msg("Error storing cross-signing key")
|
||||
}
|
||||
}
|
||||
@@ -85,7 +85,7 @@ func (mach *OlmMachine) storeCrossSigningKeys(ctx context.Context, crossSigningK
|
||||
} else {
|
||||
if verified {
|
||||
log.Debug().Err(err).Msg("Cross-signing key signature verified")
|
||||
err = mach.CryptoStore.PutSignature(userID, key, signUserID, signingKey, signature)
|
||||
err = mach.CryptoStore.PutSignature(ctx, userID, key, signUserID, signingKey, signature)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("Error storing cross-signing key signature")
|
||||
}
|
||||
|
||||
16
vendor/maunium.net/go/mautrix/crypto/cross_sign_validation.go
generated
vendored
16
vendor/maunium.net/go/mautrix/crypto/cross_sign_validation.go
generated
vendored
@@ -1,5 +1,5 @@
|
||||
// Copyright (c) 2020 Nikos Filippakis
|
||||
// Copyright (c) 2023 Tulir Asokan
|
||||
// Copyright (c) 2024 Tulir Asokan
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||
// License, v. 2.0. If a copy of the MPL was not distributed with this
|
||||
@@ -23,7 +23,7 @@ func (mach *OlmMachine) ResolveTrustContext(ctx context.Context, device *id.Devi
|
||||
if device.Trust == id.TrustStateVerified || device.Trust == id.TrustStateBlacklisted {
|
||||
return device.Trust, nil
|
||||
}
|
||||
theirKeys, err := mach.CryptoStore.GetCrossSigningKeys(device.UserID)
|
||||
theirKeys, err := mach.CryptoStore.GetCrossSigningKeys(ctx, device.UserID)
|
||||
if err != nil {
|
||||
mach.machOrContextLog(ctx).Error().Err(err).
|
||||
Str("user_id", device.UserID.String()).
|
||||
@@ -44,7 +44,7 @@ func (mach *OlmMachine) ResolveTrustContext(ctx context.Context, device *id.Devi
|
||||
Msg("Self-signing key of user not found")
|
||||
return id.TrustStateUnset, nil
|
||||
}
|
||||
sskSigExists, err := mach.CryptoStore.IsKeySignedBy(device.UserID, theirSSK.Key, device.UserID, theirMSK.Key)
|
||||
sskSigExists, err := mach.CryptoStore.IsKeySignedBy(ctx, device.UserID, theirSSK.Key, device.UserID, theirMSK.Key)
|
||||
if err != nil {
|
||||
mach.machOrContextLog(ctx).Error().Err(err).
|
||||
Str("user_id", device.UserID.String()).
|
||||
@@ -57,7 +57,7 @@ func (mach *OlmMachine) ResolveTrustContext(ctx context.Context, device *id.Devi
|
||||
Msg("Self-signing key of user is not signed by their master key")
|
||||
return id.TrustStateUnset, nil
|
||||
}
|
||||
deviceSigExists, err := mach.CryptoStore.IsKeySignedBy(device.UserID, device.SigningKey, device.UserID, theirSSK.Key)
|
||||
deviceSigExists, err := mach.CryptoStore.IsKeySignedBy(ctx, device.UserID, device.SigningKey, device.UserID, theirSSK.Key)
|
||||
if err != nil {
|
||||
mach.machOrContextLog(ctx).Error().Err(err).
|
||||
Str("user_id", device.UserID.String()).
|
||||
@@ -89,7 +89,7 @@ func (mach *OlmMachine) IsDeviceTrusted(device *id.Device) bool {
|
||||
// IsUserTrusted returns whether a user has been determined to be trusted by our user-signing key having signed their master key.
|
||||
// In the case the user ID is our own and we have successfully retrieved our cross-signing keys, we trust our own user.
|
||||
func (mach *OlmMachine) IsUserTrusted(ctx context.Context, userID id.UserID) (bool, error) {
|
||||
csPubkeys := mach.GetOwnCrossSigningPublicKeys()
|
||||
csPubkeys := mach.GetOwnCrossSigningPublicKeys(ctx)
|
||||
if csPubkeys == nil {
|
||||
return false, nil
|
||||
}
|
||||
@@ -97,14 +97,14 @@ func (mach *OlmMachine) IsUserTrusted(ctx context.Context, userID id.UserID) (bo
|
||||
return true, nil
|
||||
}
|
||||
// first we verify our user-signing key
|
||||
ourUserSigningKeyTrusted, err := mach.CryptoStore.IsKeySignedBy(mach.Client.UserID, csPubkeys.UserSigningKey, mach.Client.UserID, csPubkeys.MasterKey)
|
||||
ourUserSigningKeyTrusted, err := mach.CryptoStore.IsKeySignedBy(ctx, mach.Client.UserID, csPubkeys.UserSigningKey, mach.Client.UserID, csPubkeys.MasterKey)
|
||||
if err != nil {
|
||||
mach.machOrContextLog(ctx).Error().Err(err).Msg("Error retrieving our self-signing key signatures from database")
|
||||
return false, err
|
||||
} else if !ourUserSigningKeyTrusted {
|
||||
return false, nil
|
||||
}
|
||||
theirKeys, err := mach.CryptoStore.GetCrossSigningKeys(userID)
|
||||
theirKeys, err := mach.CryptoStore.GetCrossSigningKeys(ctx, userID)
|
||||
if err != nil {
|
||||
mach.machOrContextLog(ctx).Error().Err(err).
|
||||
Str("user_id", userID.String()).
|
||||
@@ -118,7 +118,7 @@ func (mach *OlmMachine) IsUserTrusted(ctx context.Context, userID id.UserID) (bo
|
||||
Msg("Master key of user not found")
|
||||
return false, nil
|
||||
}
|
||||
sigExists, err := mach.CryptoStore.IsKeySignedBy(userID, theirMskKey.Key, mach.Client.UserID, csPubkeys.UserSigningKey)
|
||||
sigExists, err := mach.CryptoStore.IsKeySignedBy(ctx, userID, theirMskKey.Key, mach.Client.UserID, csPubkeys.UserSigningKey)
|
||||
if err != nil {
|
||||
mach.machOrContextLog(ctx).Error().Err(err).
|
||||
Str("user_id", userID.String()).
|
||||
|
||||
69
vendor/maunium.net/go/mautrix/crypto/cryptohelper/cryptohelper.go
generated
vendored
69
vendor/maunium.net/go/mautrix/crypto/cryptohelper/cryptohelper.go
generated
vendored
@@ -105,7 +105,7 @@ func NewCryptoHelper(cli *mautrix.Client, pickleKey []byte, store any) (*CryptoH
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (helper *CryptoHelper) Init() error {
|
||||
func (helper *CryptoHelper) Init(ctx context.Context) error {
|
||||
if helper == nil {
|
||||
return fmt.Errorf("crypto helper is nil")
|
||||
}
|
||||
@@ -116,7 +116,7 @@ func (helper *CryptoHelper) Init() error {
|
||||
|
||||
var stateStore crypto.StateStore
|
||||
if helper.managedStateStore != nil {
|
||||
err := helper.managedStateStore.Upgrade()
|
||||
err := helper.managedStateStore.Upgrade(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to upgrade client state store: %w", err)
|
||||
}
|
||||
@@ -132,11 +132,14 @@ func (helper *CryptoHelper) Init() error {
|
||||
} else if _, isMemory := helper.client.Store.(*mautrix.MemorySyncStore); isMemory {
|
||||
helper.client.Store = managedCryptoStore
|
||||
}
|
||||
err := managedCryptoStore.DB.Upgrade()
|
||||
err := managedCryptoStore.DB.Upgrade(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to upgrade crypto state store: %w", err)
|
||||
}
|
||||
storedDeviceID := managedCryptoStore.FindDeviceID()
|
||||
storedDeviceID, err := managedCryptoStore.FindDeviceID(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to find existing device ID: %w", err)
|
||||
}
|
||||
if helper.LoginAs != nil {
|
||||
if storedDeviceID != "" {
|
||||
helper.LoginAs.DeviceID = storedDeviceID
|
||||
@@ -146,7 +149,7 @@ func (helper *CryptoHelper) Init() error {
|
||||
Str("username", helper.LoginAs.Identifier.User).
|
||||
Str("device_id", helper.LoginAs.DeviceID.String()).
|
||||
Msg("Logging in")
|
||||
_, err = helper.client.Login(helper.LoginAs)
|
||||
_, err = helper.client.Login(ctx, helper.LoginAs)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -167,10 +170,10 @@ func (helper *CryptoHelper) Init() error {
|
||||
return fmt.Errorf("the client must be logged in")
|
||||
}
|
||||
helper.mach = crypto.NewOlmMachine(helper.client, &helper.log, cryptoStore, stateStore)
|
||||
err := helper.mach.Load()
|
||||
err := helper.mach.Load(ctx)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to load olm account: %w", err)
|
||||
} else if err = helper.verifyDeviceKeysOnServer(); err != nil {
|
||||
} else if err = helper.verifyDeviceKeysOnServer(ctx); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -204,9 +207,9 @@ func (helper *CryptoHelper) Machine() *crypto.OlmMachine {
|
||||
return helper.mach
|
||||
}
|
||||
|
||||
func (helper *CryptoHelper) verifyDeviceKeysOnServer() error {
|
||||
func (helper *CryptoHelper) verifyDeviceKeysOnServer(ctx context.Context) error {
|
||||
helper.log.Debug().Msg("Making sure our device has the expected keys on the server")
|
||||
resp, err := helper.client.QueryKeys(&mautrix.ReqQueryKeys{
|
||||
resp, err := helper.client.QueryKeys(ctx, &mautrix.ReqQueryKeys{
|
||||
DeviceKeys: map[id.UserID]mautrix.DeviceIDList{
|
||||
helper.client.UserID: {helper.client.DeviceID},
|
||||
},
|
||||
@@ -242,27 +245,29 @@ var NoSessionFound = crypto.NoSessionFound
|
||||
const initialSessionWaitTimeout = 3 * time.Second
|
||||
const extendedSessionWaitTimeout = 22 * time.Second
|
||||
|
||||
func (helper *CryptoHelper) HandleEncrypted(src mautrix.EventSource, evt *event.Event) {
|
||||
func (helper *CryptoHelper) HandleEncrypted(ctx context.Context, evt *event.Event) {
|
||||
if helper == nil {
|
||||
return
|
||||
}
|
||||
content := evt.Content.AsEncrypted()
|
||||
// TODO use context log instead of helper?
|
||||
log := helper.log.With().
|
||||
Str("event_id", evt.ID.String()).
|
||||
Str("session_id", content.SessionID.String()).
|
||||
Logger()
|
||||
log.Debug().Msg("Decrypting received event")
|
||||
ctx = log.WithContext(ctx)
|
||||
|
||||
decrypted, err := helper.Decrypt(evt)
|
||||
decrypted, err := helper.Decrypt(ctx, evt)
|
||||
if errors.Is(err, NoSessionFound) {
|
||||
log.Debug().
|
||||
Int("wait_seconds", int(initialSessionWaitTimeout.Seconds())).
|
||||
Msg("Couldn't find session, waiting for keys to arrive...")
|
||||
if helper.mach.WaitForSession(evt.RoomID, content.SenderKey, content.SessionID, initialSessionWaitTimeout) {
|
||||
if helper.mach.WaitForSession(ctx, evt.RoomID, content.SenderKey, content.SessionID, initialSessionWaitTimeout) {
|
||||
log.Debug().Msg("Got keys after waiting, trying to decrypt event again")
|
||||
decrypted, err = helper.Decrypt(evt)
|
||||
decrypted, err = helper.Decrypt(ctx, evt)
|
||||
} else {
|
||||
go helper.waitLongerForSession(log, src, evt)
|
||||
go helper.waitLongerForSession(ctx, log, evt)
|
||||
return
|
||||
}
|
||||
}
|
||||
@@ -271,14 +276,15 @@ func (helper *CryptoHelper) HandleEncrypted(src mautrix.EventSource, evt *event.
|
||||
helper.DecryptErrorCallback(evt, err)
|
||||
return
|
||||
}
|
||||
helper.postDecrypt(src, decrypted)
|
||||
helper.postDecrypt(ctx, decrypted)
|
||||
}
|
||||
|
||||
func (helper *CryptoHelper) postDecrypt(src mautrix.EventSource, decrypted *event.Event) {
|
||||
helper.client.Syncer.(mautrix.DispatchableSyncer).Dispatch(src|mautrix.EventSourceDecrypted, decrypted)
|
||||
func (helper *CryptoHelper) postDecrypt(ctx context.Context, decrypted *event.Event) {
|
||||
decrypted.Mautrix.EventSource |= event.SourceDecrypted
|
||||
helper.client.Syncer.(mautrix.DispatchableSyncer).Dispatch(ctx, decrypted)
|
||||
}
|
||||
|
||||
func (helper *CryptoHelper) RequestSession(roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID, userID id.UserID, deviceID id.DeviceID) {
|
||||
func (helper *CryptoHelper) RequestSession(ctx context.Context, roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID, userID id.UserID, deviceID id.DeviceID) {
|
||||
if helper == nil {
|
||||
return
|
||||
}
|
||||
@@ -294,7 +300,7 @@ func (helper *CryptoHelper) RequestSession(roomID id.RoomID, senderKey id.Sender
|
||||
Str("device_id", deviceID.String()).
|
||||
Str("room_id", roomID.String()).
|
||||
Logger()
|
||||
err := helper.mach.SendRoomKeyRequest(roomID, senderKey, sessionID, "", map[id.UserID][]id.DeviceID{
|
||||
err := helper.mach.SendRoomKeyRequest(ctx, roomID, senderKey, sessionID, "", map[id.UserID][]id.DeviceID{
|
||||
userID: {deviceID},
|
||||
helper.client.UserID: {"*"},
|
||||
})
|
||||
@@ -305,55 +311,54 @@ func (helper *CryptoHelper) RequestSession(roomID id.RoomID, senderKey id.Sender
|
||||
}
|
||||
}
|
||||
|
||||
func (helper *CryptoHelper) waitLongerForSession(log zerolog.Logger, src mautrix.EventSource, evt *event.Event) {
|
||||
func (helper *CryptoHelper) waitLongerForSession(ctx context.Context, log zerolog.Logger, evt *event.Event) {
|
||||
content := evt.Content.AsEncrypted()
|
||||
log.Debug().Int("wait_seconds", int(extendedSessionWaitTimeout.Seconds())).Msg("Couldn't find session, requesting keys and waiting longer...")
|
||||
|
||||
go helper.RequestSession(evt.RoomID, content.SenderKey, content.SessionID, evt.Sender, content.DeviceID)
|
||||
go helper.RequestSession(context.TODO(), evt.RoomID, content.SenderKey, content.SessionID, evt.Sender, content.DeviceID)
|
||||
|
||||
if !helper.mach.WaitForSession(evt.RoomID, content.SenderKey, content.SessionID, extendedSessionWaitTimeout) {
|
||||
if !helper.mach.WaitForSession(ctx, evt.RoomID, content.SenderKey, content.SessionID, extendedSessionWaitTimeout) {
|
||||
log.Debug().Msg("Didn't get session, giving up")
|
||||
helper.DecryptErrorCallback(evt, NoSessionFound)
|
||||
return
|
||||
}
|
||||
|
||||
log.Debug().Msg("Got keys after waiting longer, trying to decrypt event again")
|
||||
decrypted, err := helper.Decrypt(evt)
|
||||
decrypted, err := helper.Decrypt(ctx, evt)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("Failed to decrypt event")
|
||||
helper.DecryptErrorCallback(evt, err)
|
||||
return
|
||||
}
|
||||
|
||||
helper.postDecrypt(src, decrypted)
|
||||
helper.postDecrypt(ctx, decrypted)
|
||||
}
|
||||
|
||||
func (helper *CryptoHelper) WaitForSession(roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID, timeout time.Duration) bool {
|
||||
func (helper *CryptoHelper) WaitForSession(ctx context.Context, roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID, timeout time.Duration) bool {
|
||||
if helper == nil {
|
||||
return false
|
||||
}
|
||||
helper.lock.RLock()
|
||||
defer helper.lock.RUnlock()
|
||||
return helper.mach.WaitForSession(roomID, senderKey, sessionID, timeout)
|
||||
return helper.mach.WaitForSession(ctx, roomID, senderKey, sessionID, timeout)
|
||||
}
|
||||
|
||||
func (helper *CryptoHelper) Decrypt(evt *event.Event) (*event.Event, error) {
|
||||
func (helper *CryptoHelper) Decrypt(ctx context.Context, evt *event.Event) (*event.Event, error) {
|
||||
if helper == nil {
|
||||
return nil, fmt.Errorf("crypto helper is nil")
|
||||
}
|
||||
return helper.mach.DecryptMegolmEvent(context.TODO(), evt)
|
||||
return helper.mach.DecryptMegolmEvent(ctx, evt)
|
||||
}
|
||||
|
||||
func (helper *CryptoHelper) Encrypt(roomID id.RoomID, evtType event.Type, content any) (encrypted *event.EncryptedEventContent, err error) {
|
||||
func (helper *CryptoHelper) Encrypt(ctx context.Context, roomID id.RoomID, evtType event.Type, content any) (encrypted *event.EncryptedEventContent, err error) {
|
||||
if helper == nil {
|
||||
return nil, fmt.Errorf("crypto helper is nil")
|
||||
}
|
||||
helper.lock.RLock()
|
||||
defer helper.lock.RUnlock()
|
||||
ctx := context.TODO()
|
||||
encrypted, err = helper.mach.EncryptMegolmEvent(ctx, roomID, evtType, content)
|
||||
if err != nil {
|
||||
if err != crypto.SessionExpired && err != crypto.SessionNotShared && err != crypto.NoGroupSession {
|
||||
if !errors.Is(err, crypto.SessionExpired) && err != crypto.NoGroupSession && !errors.Is(err, crypto.SessionNotShared) {
|
||||
return
|
||||
}
|
||||
helper.log.Debug().
|
||||
@@ -361,7 +366,7 @@ func (helper *CryptoHelper) Encrypt(roomID id.RoomID, evtType event.Type, conten
|
||||
Str("room_id", roomID.String()).
|
||||
Msg("Got session error while encrypting event, sharing group session and trying again")
|
||||
var users []id.UserID
|
||||
users, err = helper.client.StateStore.GetRoomJoinedOrInvitedMembers(roomID)
|
||||
users, err = helper.client.StateStore.GetRoomJoinedOrInvitedMembers(ctx, roomID)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("failed to get room member list: %w", err)
|
||||
} else if err = helper.mach.ShareGroupSession(ctx, roomID, users); err != nil {
|
||||
|
||||
12
vendor/maunium.net/go/mautrix/crypto/decryptmegolm.go
generated
vendored
12
vendor/maunium.net/go/mautrix/crypto/decryptmegolm.go
generated
vendored
@@ -1,4 +1,4 @@
|
||||
// Copyright (c) 2023 Tulir Asokan
|
||||
// Copyright (c) 2024 Tulir Asokan
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||
// License, v. 2.0. If a copy of the MPL was not distributed with this
|
||||
@@ -91,7 +91,7 @@ func (mach *OlmMachine) DecryptMegolmEvent(ctx context.Context, evt *event.Event
|
||||
} else {
|
||||
forwardedKeys = true
|
||||
lastChainItem := sess.ForwardingChains[len(sess.ForwardingChains)-1]
|
||||
device, _ = mach.CryptoStore.FindDeviceByKey(evt.Sender, id.IdentityKey(lastChainItem))
|
||||
device, _ = mach.CryptoStore.FindDeviceByKey(ctx, evt.Sender, id.IdentityKey(lastChainItem))
|
||||
if device != nil {
|
||||
trustLevel = mach.ResolveTrust(device)
|
||||
} else {
|
||||
@@ -188,7 +188,7 @@ func (mach *OlmMachine) actuallyDecryptMegolmEvent(ctx context.Context, evt *eve
|
||||
mach.megolmDecryptLock.Lock()
|
||||
defer mach.megolmDecryptLock.Unlock()
|
||||
|
||||
sess, err := mach.CryptoStore.GetGroupSession(encryptionRoomID, content.SenderKey, content.SessionID)
|
||||
sess, err := mach.CryptoStore.GetGroupSession(ctx, encryptionRoomID, content.SenderKey, content.SessionID)
|
||||
if err != nil {
|
||||
return nil, nil, 0, fmt.Errorf("failed to get group session: %w", err)
|
||||
} else if sess == nil {
|
||||
@@ -250,7 +250,7 @@ func (mach *OlmMachine) actuallyDecryptMegolmEvent(ctx context.Context, evt *eve
|
||||
Int("max_messages", sess.MaxMessages).
|
||||
Logger()
|
||||
if sess.MaxMessages > 0 && int(ratchetTargetIndex) >= sess.MaxMessages && len(sess.RatchetSafety.MissedIndices) == 0 && mach.DeleteFullyUsedKeysOnDecrypt {
|
||||
err = mach.CryptoStore.RedactGroupSession(sess.RoomID, sess.SenderKey, sess.ID(), "maximum messages reached")
|
||||
err = mach.CryptoStore.RedactGroupSession(ctx, sess.RoomID, sess.SenderKey, sess.ID(), "maximum messages reached")
|
||||
if err != nil {
|
||||
log.Err(err).Msg("Failed to delete fully used session")
|
||||
return sess, plaintext, messageIndex, RatchetError
|
||||
@@ -261,14 +261,14 @@ func (mach *OlmMachine) actuallyDecryptMegolmEvent(ctx context.Context, evt *eve
|
||||
if err = sess.RatchetTo(ratchetTargetIndex); err != nil {
|
||||
log.Err(err).Msg("Failed to ratchet session")
|
||||
return sess, plaintext, messageIndex, RatchetError
|
||||
} else if err = mach.CryptoStore.PutGroupSession(sess.RoomID, sess.SenderKey, sess.ID(), sess); err != nil {
|
||||
} else if err = mach.CryptoStore.PutGroupSession(ctx, sess.RoomID, sess.SenderKey, sess.ID(), sess); err != nil {
|
||||
log.Err(err).Msg("Failed to store ratcheted session")
|
||||
return sess, plaintext, messageIndex, RatchetError
|
||||
} else {
|
||||
log.Info().Msg("Ratcheted session forward")
|
||||
}
|
||||
} else if didModify {
|
||||
if err = mach.CryptoStore.PutGroupSession(sess.RoomID, sess.SenderKey, sess.ID(), sess); err != nil {
|
||||
if err = mach.CryptoStore.PutGroupSession(ctx, sess.RoomID, sess.SenderKey, sess.ID(), sess); err != nil {
|
||||
log.Err(err).Msg("Failed to store updated ratchet safety data")
|
||||
return sess, plaintext, messageIndex, RatchetError
|
||||
} else {
|
||||
|
||||
14
vendor/maunium.net/go/mautrix/crypto/decryptolm.go
generated
vendored
14
vendor/maunium.net/go/mautrix/crypto/decryptolm.go
generated
vendored
@@ -1,4 +1,4 @@
|
||||
// Copyright (c) 2023 Tulir Asokan
|
||||
// Copyright (c) 2024 Tulir Asokan
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||
// License, v. 2.0. If a copy of the MPL was not distributed with this
|
||||
@@ -159,7 +159,7 @@ func (mach *OlmMachine) tryDecryptOlmCiphertext(ctx context.Context, sender id.U
|
||||
}
|
||||
|
||||
endTimeTrace = mach.timeTrace(ctx, "updating new session in database", time.Second)
|
||||
err = mach.CryptoStore.UpdateSession(senderKey, session)
|
||||
err = mach.CryptoStore.UpdateSession(ctx, senderKey, session)
|
||||
endTimeTrace()
|
||||
if err != nil {
|
||||
log.Warn().Err(err).Msg("Failed to update new olm session in crypto store after decrypting")
|
||||
@@ -170,7 +170,7 @@ func (mach *OlmMachine) tryDecryptOlmCiphertext(ctx context.Context, sender id.U
|
||||
func (mach *OlmMachine) tryDecryptOlmCiphertextWithExistingSession(ctx context.Context, senderKey id.SenderKey, olmType id.OlmMsgType, ciphertext string) ([]byte, error) {
|
||||
log := *zerolog.Ctx(ctx)
|
||||
endTimeTrace := mach.timeTrace(ctx, "getting sessions with sender key", time.Second)
|
||||
sessions, err := mach.CryptoStore.GetSessions(senderKey)
|
||||
sessions, err := mach.CryptoStore.GetSessions(ctx, senderKey)
|
||||
endTimeTrace()
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get session for %s: %w", senderKey, err)
|
||||
@@ -199,7 +199,7 @@ func (mach *OlmMachine) tryDecryptOlmCiphertextWithExistingSession(ctx context.C
|
||||
}
|
||||
} else {
|
||||
endTimeTrace = mach.timeTrace(ctx, "updating session in database", time.Second)
|
||||
err = mach.CryptoStore.UpdateSession(senderKey, session)
|
||||
err = mach.CryptoStore.UpdateSession(ctx, senderKey, session)
|
||||
endTimeTrace()
|
||||
if err != nil {
|
||||
log.Warn().Err(err).Msg("Failed to update olm session in crypto store after decrypting")
|
||||
@@ -216,8 +216,8 @@ func (mach *OlmMachine) createInboundSession(ctx context.Context, senderKey id.S
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
mach.saveAccount()
|
||||
err = mach.CryptoStore.AddSession(senderKey, session)
|
||||
mach.saveAccount(ctx)
|
||||
err = mach.CryptoStore.AddSession(ctx, senderKey, session)
|
||||
if err != nil {
|
||||
zerolog.Ctx(ctx).Error().Err(err).Msg("Failed to store created inbound session")
|
||||
}
|
||||
@@ -228,7 +228,7 @@ const MinUnwedgeInterval = 1 * time.Hour
|
||||
|
||||
func (mach *OlmMachine) unwedgeDevice(log zerolog.Logger, sender id.UserID, senderKey id.SenderKey) {
|
||||
log = log.With().Str("action", "unwedge olm session").Logger()
|
||||
ctx := log.WithContext(context.Background())
|
||||
ctx := log.WithContext(context.TODO())
|
||||
mach.recentlyUnwedgedLock.Lock()
|
||||
prevUnwedge, ok := mach.recentlyUnwedged[senderKey]
|
||||
delta := time.Now().Sub(prevUnwedge)
|
||||
|
||||
60
vendor/maunium.net/go/mautrix/crypto/devicelist.go
generated
vendored
60
vendor/maunium.net/go/mautrix/crypto/devicelist.go
generated
vendored
@@ -1,4 +1,4 @@
|
||||
// Copyright (c) 2023 Tulir Asokan
|
||||
// Copyright (c) 2024 Tulir Asokan
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||
// License, v. 2.0. If a copy of the MPL was not distributed with this
|
||||
@@ -27,9 +27,16 @@ var (
|
||||
InvalidKeySignature = errors.New("invalid signature on device keys")
|
||||
)
|
||||
|
||||
func (mach *OlmMachine) LoadDevices(user id.UserID) map[id.DeviceID]*id.Device {
|
||||
// TODO proper context?
|
||||
return mach.fetchKeys(context.TODO(), []id.UserID{user}, "", true)[user]
|
||||
func (mach *OlmMachine) LoadDevices(ctx context.Context, user id.UserID) (keys map[id.DeviceID]*id.Device) {
|
||||
log := zerolog.Ctx(ctx)
|
||||
|
||||
if keys, err := mach.FetchKeys(ctx, []id.UserID{user}, true); err != nil {
|
||||
log.Err(err).Msg("Failed to load devices")
|
||||
} else if keys != nil {
|
||||
return keys[user]
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
func (mach *OlmMachine) storeDeviceSelfSignatures(ctx context.Context, userID id.UserID, deviceID id.DeviceID, resp *mautrix.RespQueryKeys) {
|
||||
@@ -53,7 +60,7 @@ func (mach *OlmMachine) storeDeviceSelfSignatures(ctx context.Context, userID id
|
||||
Str("signed_device_id", deviceID.String()).
|
||||
Str("signature", signature).
|
||||
Msg("Verified self-signing signature")
|
||||
err = mach.CryptoStore.PutSignature(userID, id.Ed25519(signKey), signerUserID, pubKey, signature)
|
||||
err = mach.CryptoStore.PutSignature(ctx, userID, id.Ed25519(signKey), signerUserID, pubKey, signature)
|
||||
if err != nil {
|
||||
log.Warn().Err(err).
|
||||
Str("signer_user_id", signerUserID.String()).
|
||||
@@ -74,7 +81,7 @@ func (mach *OlmMachine) storeDeviceSelfSignatures(ctx context.Context, userID id
|
||||
}
|
||||
// save signature of device made by its own device signing key
|
||||
if signKey, ok := deviceKeys.Keys[id.DeviceKeyID(signerKey)]; ok {
|
||||
err := mach.CryptoStore.PutSignature(userID, id.Ed25519(signKey), signerUserID, id.Ed25519(signKey), signature)
|
||||
err := mach.CryptoStore.PutSignature(ctx, userID, id.Ed25519(signKey), signerUserID, id.Ed25519(signKey), signature)
|
||||
if err != nil {
|
||||
log.Warn().Err(err).
|
||||
Str("signer_user_id", signerUserID.String()).
|
||||
@@ -86,19 +93,16 @@ func (mach *OlmMachine) storeDeviceSelfSignatures(ctx context.Context, userID id
|
||||
}
|
||||
}
|
||||
|
||||
func (mach *OlmMachine) fetchKeys(ctx context.Context, users []id.UserID, sinceToken string, includeUntracked bool) (data map[id.UserID]map[id.DeviceID]*id.Device) {
|
||||
// TODO this function should probably return errors
|
||||
func (mach *OlmMachine) FetchKeys(ctx context.Context, users []id.UserID, includeUntracked bool) (data map[id.UserID]map[id.DeviceID]*id.Device, err error) {
|
||||
req := &mautrix.ReqQueryKeys{
|
||||
DeviceKeys: mautrix.DeviceKeysRequest{},
|
||||
Timeout: 10 * 1000,
|
||||
Token: sinceToken,
|
||||
}
|
||||
log := mach.machOrContextLog(ctx)
|
||||
if !includeUntracked {
|
||||
var err error
|
||||
users, err = mach.CryptoStore.FilterTrackedUsers(users)
|
||||
users, err = mach.CryptoStore.FilterTrackedUsers(ctx, users)
|
||||
if err != nil {
|
||||
log.Warn().Err(err).Msg("Failed to filter tracked user list")
|
||||
return nil, fmt.Errorf("failed to filter tracked user list: %w", err)
|
||||
}
|
||||
}
|
||||
if len(users) == 0 {
|
||||
@@ -108,10 +112,9 @@ func (mach *OlmMachine) fetchKeys(ctx context.Context, users []id.UserID, sinceT
|
||||
req.DeviceKeys[userID] = mautrix.DeviceIDList{}
|
||||
}
|
||||
log.Debug().Strs("users", strishArray(users)).Msg("Querying keys for users")
|
||||
resp, err := mach.Client.QueryKeys(req)
|
||||
resp, err := mach.Client.QueryKeys(ctx, req)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("Failed to query keys")
|
||||
return
|
||||
return nil, fmt.Errorf("failed to query keys: %w", err)
|
||||
}
|
||||
for server, err := range resp.Failures {
|
||||
log.Warn().Interface("query_error", err).Str("server", server).Msg("Query keys failure for server")
|
||||
@@ -123,7 +126,7 @@ func (mach *OlmMachine) fetchKeys(ctx context.Context, users []id.UserID, sinceT
|
||||
delete(req.DeviceKeys, userID)
|
||||
|
||||
newDevices := make(map[id.DeviceID]*id.Device)
|
||||
existingDevices, err := mach.CryptoStore.GetDevices(userID)
|
||||
existingDevices, err := mach.CryptoStore.GetDevices(ctx, userID)
|
||||
if err != nil {
|
||||
log.Warn().Err(err).Msg("Failed to get existing devices for user")
|
||||
existingDevices = make(map[id.DeviceID]*id.Device)
|
||||
@@ -151,7 +154,7 @@ func (mach *OlmMachine) fetchKeys(ctx context.Context, users []id.UserID, sinceT
|
||||
}
|
||||
}
|
||||
log.Trace().Int("new_device_count", len(newDevices)).Msg("Storing new device list")
|
||||
err = mach.CryptoStore.PutDevices(userID, newDevices)
|
||||
err = mach.CryptoStore.PutDevices(ctx, userID, newDevices)
|
||||
if err != nil {
|
||||
log.Warn().Err(err).Msg("Failed to update device list")
|
||||
}
|
||||
@@ -169,7 +172,7 @@ func (mach *OlmMachine) fetchKeys(ctx context.Context, users []id.UserID, sinceT
|
||||
Str("identity_key", device.IdentityKey.String()).
|
||||
Str("signing_key", device.SigningKey.String()).
|
||||
Logger()
|
||||
sessionIDs, err := mach.CryptoStore.RedactGroupSessions("", device.IdentityKey, "device removed")
|
||||
sessionIDs, err := mach.CryptoStore.RedactGroupSessions(ctx, "", device.IdentityKey, "device removed")
|
||||
if err != nil {
|
||||
log.Err(err).Msg("Failed to redact megolm sessions from deleted device")
|
||||
} else {
|
||||
@@ -179,7 +182,7 @@ func (mach *OlmMachine) fetchKeys(ctx context.Context, users []id.UserID, sinceT
|
||||
}
|
||||
}
|
||||
}
|
||||
mach.OnDevicesChanged(userID)
|
||||
mach.OnDevicesChanged(ctx, userID)
|
||||
}
|
||||
}
|
||||
for userID := range req.DeviceKeys {
|
||||
@@ -190,25 +193,32 @@ func (mach *OlmMachine) fetchKeys(ctx context.Context, users []id.UserID, sinceT
|
||||
mach.storeCrossSigningKeys(ctx, resp.SelfSigningKeys, resp.DeviceKeys)
|
||||
mach.storeCrossSigningKeys(ctx, resp.UserSigningKeys, resp.DeviceKeys)
|
||||
|
||||
return data
|
||||
return data, nil
|
||||
}
|
||||
|
||||
// OnDevicesChanged finds all shared rooms with the given user and invalidates outbound sessions in those rooms.
|
||||
//
|
||||
// This is called automatically whenever a device list change is noticed in ProcessSyncResponse and usually does
|
||||
// not need to be called manually.
|
||||
func (mach *OlmMachine) OnDevicesChanged(userID id.UserID) {
|
||||
func (mach *OlmMachine) OnDevicesChanged(ctx context.Context, userID id.UserID) {
|
||||
if mach.DisableDeviceChangeKeyRotation {
|
||||
return
|
||||
}
|
||||
for _, roomID := range mach.StateStore.FindSharedRooms(userID) {
|
||||
mach.Log.Debug().
|
||||
rooms, err := mach.StateStore.FindSharedRooms(ctx, userID)
|
||||
if err != nil {
|
||||
mach.machOrContextLog(ctx).Err(err).
|
||||
Stringer("with_user_id", userID).
|
||||
Msg("Failed to find shared rooms to invalidate group sessions")
|
||||
return
|
||||
}
|
||||
for _, roomID := range rooms {
|
||||
mach.machOrContextLog(ctx).Debug().
|
||||
Str("user_id", userID.String()).
|
||||
Str("room_id", roomID.String()).
|
||||
Msg("Invalidating group session in room due to device change notification")
|
||||
err := mach.CryptoStore.RemoveOutboundGroupSession(roomID)
|
||||
err = mach.CryptoStore.RemoveOutboundGroupSession(ctx, roomID)
|
||||
if err != nil {
|
||||
mach.Log.Warn().Err(err).
|
||||
mach.machOrContextLog(ctx).Err(err).
|
||||
Str("user_id", userID.String()).
|
||||
Str("room_id", roomID.String()).
|
||||
Msg("Failed to invalidate outbound group session")
|
||||
|
||||
44
vendor/maunium.net/go/mautrix/crypto/encryptmegolm.go
generated
vendored
44
vendor/maunium.net/go/mautrix/crypto/encryptmegolm.go
generated
vendored
@@ -1,4 +1,4 @@
|
||||
// Copyright (c) 2023 Tulir Asokan
|
||||
// Copyright (c) 2024 Tulir Asokan
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||
// License, v. 2.0. If a copy of the MPL was not distributed with this
|
||||
@@ -84,7 +84,7 @@ func parseMessageIndex(ciphertext []byte) (uint, error) {
|
||||
func (mach *OlmMachine) EncryptMegolmEvent(ctx context.Context, roomID id.RoomID, evtType event.Type, content interface{}) (*event.EncryptedEventContent, error) {
|
||||
mach.megolmEncryptLock.Lock()
|
||||
defer mach.megolmEncryptLock.Unlock()
|
||||
session, err := mach.CryptoStore.GetOutboundGroupSession(roomID)
|
||||
session, err := mach.CryptoStore.GetOutboundGroupSession(ctx, roomID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get outbound group session: %w", err)
|
||||
} else if session == nil {
|
||||
@@ -116,7 +116,7 @@ func (mach *OlmMachine) EncryptMegolmEvent(ctx context.Context, roomID id.RoomID
|
||||
log = log.With().Uint("message_index", idx).Logger()
|
||||
}
|
||||
log.Debug().Msg("Encrypted event successfully")
|
||||
err = mach.CryptoStore.UpdateOutboundGroupSession(session)
|
||||
err = mach.CryptoStore.UpdateOutboundGroupSession(ctx, session)
|
||||
if err != nil {
|
||||
log.Warn().Err(err).Msg("Failed to update megolm session in crypto store after encrypting")
|
||||
}
|
||||
@@ -137,7 +137,13 @@ func (mach *OlmMachine) EncryptMegolmEvent(ctx context.Context, roomID id.RoomID
|
||||
}
|
||||
|
||||
func (mach *OlmMachine) newOutboundGroupSession(ctx context.Context, roomID id.RoomID) *OutboundGroupSession {
|
||||
session := NewOutboundGroupSession(roomID, mach.StateStore.GetEncryptionEvent(roomID))
|
||||
encryptionEvent, err := mach.StateStore.GetEncryptionEvent(ctx, roomID)
|
||||
if err != nil {
|
||||
mach.machOrContextLog(ctx).Err(err).
|
||||
Stringer("room_id", roomID).
|
||||
Msg("Failed to get encryption event in room")
|
||||
}
|
||||
session := NewOutboundGroupSession(roomID, encryptionEvent)
|
||||
if !mach.DontStoreOutboundKeys {
|
||||
signingKey, idKey := mach.account.Keys()
|
||||
mach.createGroupSession(ctx, idKey, signingKey, roomID, session.ID(), session.Internal.Key(), session.MaxAge, session.MaxMessages, false)
|
||||
@@ -165,7 +171,7 @@ func strishArray[T ~string](arr []T) []string {
|
||||
func (mach *OlmMachine) ShareGroupSession(ctx context.Context, roomID id.RoomID, users []id.UserID) error {
|
||||
mach.megolmEncryptLock.Lock()
|
||||
defer mach.megolmEncryptLock.Unlock()
|
||||
session, err := mach.CryptoStore.GetOutboundGroupSession(roomID)
|
||||
session, err := mach.CryptoStore.GetOutboundGroupSession(ctx, roomID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get previous outbound group session: %w", err)
|
||||
} else if session != nil && session.Shared && !session.Expired() {
|
||||
@@ -192,7 +198,7 @@ func (mach *OlmMachine) ShareGroupSession(ctx context.Context, roomID id.RoomID,
|
||||
|
||||
for _, userID := range users {
|
||||
log := log.With().Str("target_user_id", userID.String()).Logger()
|
||||
devices, err := mach.CryptoStore.GetDevices(userID)
|
||||
devices, err := mach.CryptoStore.GetDevices(ctx, userID)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("Failed to get devices of user")
|
||||
} else if devices == nil {
|
||||
@@ -223,12 +229,16 @@ func (mach *OlmMachine) ShareGroupSession(ctx context.Context, roomID id.RoomID,
|
||||
|
||||
if len(fetchKeys) > 0 {
|
||||
log.Debug().Strs("users", strishArray(fetchKeys)).Msg("Fetching missing keys")
|
||||
for userID, devices := range mach.fetchKeys(ctx, fetchKeys, "", true) {
|
||||
log.Debug().
|
||||
Int("device_count", len(devices)).
|
||||
Str("target_user_id", userID.String()).
|
||||
Msg("Got device keys for user")
|
||||
missingSessions[userID] = devices
|
||||
if keys, err := mach.FetchKeys(ctx, fetchKeys, true); err != nil {
|
||||
log.Err(err).Strs("users", strishArray(fetchKeys)).Msg("Failed to fetch missing keys")
|
||||
} else if keys != nil {
|
||||
for userID, devices := range keys {
|
||||
log.Debug().
|
||||
Int("device_count", len(devices)).
|
||||
Str("target_user_id", userID.String()).
|
||||
Msg("Got device keys for user")
|
||||
missingSessions[userID] = devices
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -280,11 +290,11 @@ func (mach *OlmMachine) ShareGroupSession(ctx context.Context, roomID id.RoomID,
|
||||
Int("user_count", len(toDeviceWithheld.Messages)).
|
||||
Msg("Sending to-device messages to report withheld key")
|
||||
// TODO remove the next 4 lines once clients support m.room_key.withheld
|
||||
_, err = mach.Client.SendToDevice(event.ToDeviceOrgMatrixRoomKeyWithheld, toDeviceWithheld)
|
||||
_, err = mach.Client.SendToDevice(ctx, event.ToDeviceOrgMatrixRoomKeyWithheld, toDeviceWithheld)
|
||||
if err != nil {
|
||||
log.Warn().Err(err).Msg("Failed to report withheld keys (legacy event type)")
|
||||
}
|
||||
_, err = mach.Client.SendToDevice(event.ToDeviceRoomKeyWithheld, toDeviceWithheld)
|
||||
_, err = mach.Client.SendToDevice(ctx, event.ToDeviceRoomKeyWithheld, toDeviceWithheld)
|
||||
if err != nil {
|
||||
log.Warn().Err(err).Msg("Failed to report withheld keys")
|
||||
}
|
||||
@@ -292,7 +302,7 @@ func (mach *OlmMachine) ShareGroupSession(ctx context.Context, roomID id.RoomID,
|
||||
|
||||
log.Debug().Msg("Group session successfully shared")
|
||||
session.Shared = true
|
||||
return mach.CryptoStore.AddOutboundGroupSession(session)
|
||||
return mach.CryptoStore.AddOutboundGroupSession(ctx, session)
|
||||
}
|
||||
|
||||
func (mach *OlmMachine) encryptAndSendGroupSession(ctx context.Context, session *OutboundGroupSession, olmSessions map[id.UserID]map[id.DeviceID]deviceSessionWrapper) error {
|
||||
@@ -327,7 +337,7 @@ func (mach *OlmMachine) encryptAndSendGroupSession(ctx context.Context, session
|
||||
Int("device_count", deviceCount).
|
||||
Int("user_count", len(toDevice.Messages)).
|
||||
Msg("Sending to-device messages to share group session")
|
||||
_, err := mach.Client.SendToDevice(event.ToDeviceEncrypted, toDevice)
|
||||
_, err := mach.Client.SendToDevice(ctx, event.ToDeviceEncrypted, toDevice)
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -367,7 +377,7 @@ func (mach *OlmMachine) findOlmSessionsForUser(ctx context.Context, session *Out
|
||||
Reason: "This device does not encrypt messages for unverified devices",
|
||||
}}
|
||||
session.Users[userKey] = OGSIgnored
|
||||
} else if deviceSession, err := mach.CryptoStore.GetLatestSession(device.IdentityKey); err != nil {
|
||||
} else if deviceSession, err := mach.CryptoStore.GetLatestSession(ctx, device.IdentityKey); err != nil {
|
||||
log.Error().Err(err).Msg("Failed to get olm session to encrypt group session")
|
||||
} else if deviceSession == nil {
|
||||
log.Warn().Err(err).Msg("Didn't find olm session to encrypt group session")
|
||||
|
||||
14
vendor/maunium.net/go/mautrix/crypto/encryptolm.go
generated
vendored
14
vendor/maunium.net/go/mautrix/crypto/encryptolm.go
generated
vendored
@@ -1,4 +1,4 @@
|
||||
// Copyright (c) 2023 Tulir Asokan
|
||||
// Copyright (c) 2024 Tulir Asokan
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||
// License, v. 2.0. If a copy of the MPL was not distributed with this
|
||||
@@ -38,7 +38,7 @@ func (mach *OlmMachine) encryptOlmEvent(ctx context.Context, session *OlmSession
|
||||
Str("olm_session_description", session.Describe()).
|
||||
Msg("Encrypting olm message")
|
||||
msgType, ciphertext := session.Encrypt(plaintext)
|
||||
err = mach.CryptoStore.UpdateSession(recipient.IdentityKey, session)
|
||||
err = mach.CryptoStore.UpdateSession(ctx, recipient.IdentityKey, session)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("Failed to update olm session in crypto store after encrypting")
|
||||
}
|
||||
@@ -54,8 +54,8 @@ func (mach *OlmMachine) encryptOlmEvent(ctx context.Context, session *OlmSession
|
||||
}
|
||||
}
|
||||
|
||||
func (mach *OlmMachine) shouldCreateNewSession(identityKey id.IdentityKey) bool {
|
||||
if !mach.CryptoStore.HasSession(identityKey) {
|
||||
func (mach *OlmMachine) shouldCreateNewSession(ctx context.Context, identityKey id.IdentityKey) bool {
|
||||
if !mach.CryptoStore.HasSession(ctx, identityKey) {
|
||||
return true
|
||||
}
|
||||
mach.devicesToUnwedgeLock.Lock()
|
||||
@@ -72,7 +72,7 @@ func (mach *OlmMachine) createOutboundSessions(ctx context.Context, input map[id
|
||||
for userID, devices := range input {
|
||||
request[userID] = make(map[id.DeviceID]id.KeyAlgorithm)
|
||||
for deviceID, identity := range devices {
|
||||
if mach.shouldCreateNewSession(identity.IdentityKey) {
|
||||
if mach.shouldCreateNewSession(ctx, identity.IdentityKey) {
|
||||
request[userID][deviceID] = id.KeyAlgorithmSignedCurve25519
|
||||
}
|
||||
}
|
||||
@@ -83,7 +83,7 @@ func (mach *OlmMachine) createOutboundSessions(ctx context.Context, input map[id
|
||||
if len(request) == 0 {
|
||||
return nil
|
||||
}
|
||||
resp, err := mach.Client.ClaimKeys(&mautrix.ReqClaimKeys{
|
||||
resp, err := mach.Client.ClaimKeys(ctx, &mautrix.ReqClaimKeys{
|
||||
OneTimeKeys: request,
|
||||
Timeout: 10 * 1000,
|
||||
})
|
||||
@@ -117,7 +117,7 @@ func (mach *OlmMachine) createOutboundSessions(ctx context.Context, input map[id
|
||||
log.Error().Err(err).Msg("Failed to create outbound session with claimed one-time key")
|
||||
} else {
|
||||
wrapped := wrapSession(sess)
|
||||
err = mach.CryptoStore.AddSession(identity.IdentityKey, wrapped)
|
||||
err = mach.CryptoStore.AddSession(ctx, identity.IdentityKey, wrapped)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("Failed to store created outbound session")
|
||||
} else {
|
||||
|
||||
5
vendor/maunium.net/go/mautrix/crypto/goolm/README.md
generated
vendored
Normal file
5
vendor/maunium.net/go/mautrix/crypto/goolm/README.md
generated
vendored
Normal file
@@ -0,0 +1,5 @@
|
||||
# go-olm
|
||||
This is a fork of [DerLukas's goolm](https://codeberg.org/DerLukas/goolm),
|
||||
a pure Go implementation of libolm.
|
||||
|
||||
The original project is licensed under the MIT license.
|
||||
522
vendor/maunium.net/go/mautrix/crypto/goolm/account/account.go
generated
vendored
Normal file
522
vendor/maunium.net/go/mautrix/crypto/goolm/account/account.go
generated
vendored
Normal file
@@ -0,0 +1,522 @@
|
||||
// account packages an account which stores the identity, one time keys and fallback keys.
|
||||
package account
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
|
||||
"maunium.net/go/mautrix/id"
|
||||
|
||||
"maunium.net/go/mautrix/crypto/goolm"
|
||||
"maunium.net/go/mautrix/crypto/goolm/cipher"
|
||||
"maunium.net/go/mautrix/crypto/goolm/crypto"
|
||||
"maunium.net/go/mautrix/crypto/goolm/libolmpickle"
|
||||
"maunium.net/go/mautrix/crypto/goolm/session"
|
||||
"maunium.net/go/mautrix/crypto/goolm/utilities"
|
||||
)
|
||||
|
||||
const (
|
||||
accountPickleVersionJSON byte = 1
|
||||
accountPickleVersionLibOLM uint32 = 4
|
||||
MaxOneTimeKeys int = 100 //maximum number of stored one time keys per Account
|
||||
)
|
||||
|
||||
// Account stores an account for end to end encrypted messaging via the olm protocol.
|
||||
// An Account can not be used to en/decrypt messages. However it can be used to contruct new olm sessions, which in turn do the en/decryption.
|
||||
// There is no tracking of sessions in an account.
|
||||
type Account struct {
|
||||
IdKeys struct {
|
||||
Ed25519 crypto.Ed25519KeyPair `json:"ed25519,omitempty"`
|
||||
Curve25519 crypto.Curve25519KeyPair `json:"curve25519,omitempty"`
|
||||
} `json:"identity_keys"`
|
||||
OTKeys []crypto.OneTimeKey `json:"one_time_keys"`
|
||||
CurrentFallbackKey crypto.OneTimeKey `json:"current_fallback_key,omitempty"`
|
||||
PrevFallbackKey crypto.OneTimeKey `json:"prev_fallback_key,omitempty"`
|
||||
NextOneTimeKeyID uint32 `json:"next_one_time_key_id,omitempty"`
|
||||
NumFallbackKeys uint8 `json:"number_fallback_keys"`
|
||||
}
|
||||
|
||||
// AccountFromJSONPickled loads the Account details from a pickled base64 string. The input is decrypted with the supplied key.
|
||||
func AccountFromJSONPickled(pickled, key []byte) (*Account, error) {
|
||||
if len(pickled) == 0 {
|
||||
return nil, fmt.Errorf("accountFromPickled: %w", goolm.ErrEmptyInput)
|
||||
}
|
||||
a := &Account{}
|
||||
err := a.UnpickleAsJSON(pickled, key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return a, nil
|
||||
}
|
||||
|
||||
// AccountFromPickled loads the Account details from a pickled base64 string. The input is decrypted with the supplied key.
|
||||
func AccountFromPickled(pickled, key []byte) (*Account, error) {
|
||||
if len(pickled) == 0 {
|
||||
return nil, fmt.Errorf("accountFromPickled: %w", goolm.ErrEmptyInput)
|
||||
}
|
||||
a := &Account{}
|
||||
err := a.Unpickle(pickled, key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return a, nil
|
||||
}
|
||||
|
||||
// NewAccount creates a new Account. If reader is nil, crypto/rand is used for the key creation.
|
||||
func NewAccount(reader io.Reader) (*Account, error) {
|
||||
a := &Account{}
|
||||
kPEd25519, err := crypto.Ed25519GenerateKey(reader)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
a.IdKeys.Ed25519 = kPEd25519
|
||||
kPCurve25519, err := crypto.Curve25519GenerateKey(reader)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
a.IdKeys.Curve25519 = kPCurve25519
|
||||
return a, nil
|
||||
}
|
||||
|
||||
// PickleAsJSON returns an Account as a base64 string encrypted using the supplied key. The unencrypted representation of the Account is in JSON format.
|
||||
func (a Account) PickleAsJSON(key []byte) ([]byte, error) {
|
||||
return utilities.PickleAsJSON(a, accountPickleVersionJSON, key)
|
||||
}
|
||||
|
||||
// UnpickleAsJSON updates an Account by a base64 encrypted string using the supplied key. The unencrypted representation has to be in JSON format.
|
||||
func (a *Account) UnpickleAsJSON(pickled, key []byte) error {
|
||||
return utilities.UnpickleAsJSON(a, pickled, key, accountPickleVersionJSON)
|
||||
}
|
||||
|
||||
// IdentityKeysJSON returns the public parts of the identity keys for the Account in a JSON string.
|
||||
func (a Account) IdentityKeysJSON() ([]byte, error) {
|
||||
res := struct {
|
||||
Ed25519 string `json:"ed25519"`
|
||||
Curve25519 string `json:"curve25519"`
|
||||
}{}
|
||||
ed25519, curve25519 := a.IdentityKeys()
|
||||
res.Ed25519 = string(ed25519)
|
||||
res.Curve25519 = string(curve25519)
|
||||
return json.Marshal(res)
|
||||
}
|
||||
|
||||
// IdentityKeys returns the public parts of the Ed25519 and Curve25519 identity keys for the Account.
|
||||
func (a Account) IdentityKeys() (id.Ed25519, id.Curve25519) {
|
||||
ed25519 := id.Ed25519(base64.RawStdEncoding.EncodeToString(a.IdKeys.Ed25519.PublicKey))
|
||||
curve25519 := id.Curve25519(base64.RawStdEncoding.EncodeToString(a.IdKeys.Curve25519.PublicKey))
|
||||
return ed25519, curve25519
|
||||
}
|
||||
|
||||
// Sign returns the signature of a message using the Ed25519 key for this Account.
|
||||
func (a Account) Sign(message []byte) ([]byte, error) {
|
||||
if len(message) == 0 {
|
||||
return nil, fmt.Errorf("sign: %w", goolm.ErrEmptyInput)
|
||||
}
|
||||
return goolm.Base64Encode(a.IdKeys.Ed25519.Sign(message)), nil
|
||||
}
|
||||
|
||||
// OneTimeKeys returns the public parts of the unpublished one time keys of the Account.
|
||||
//
|
||||
// The returned data is a map with the mapping of key id to base64-encoded Curve25519 key.
|
||||
func (a Account) OneTimeKeys() map[string]id.Curve25519 {
|
||||
oneTimeKeys := make(map[string]id.Curve25519)
|
||||
for _, curKey := range a.OTKeys {
|
||||
if !curKey.Published {
|
||||
oneTimeKeys[curKey.KeyIDEncoded()] = id.Curve25519(curKey.PublicKeyEncoded())
|
||||
}
|
||||
}
|
||||
return oneTimeKeys
|
||||
}
|
||||
|
||||
//OneTimeKeysJSON returns the public parts of the unpublished one time keys of the Account as a JSON string.
|
||||
//
|
||||
//The returned JSON is of format:
|
||||
/*
|
||||
{
|
||||
Curve25519: {
|
||||
"AAAAAA": "wo76WcYtb0Vk/pBOdmduiGJ0wIEjW4IBMbbQn7aSnTo",
|
||||
"AAAAAB": "LRvjo46L1X2vx69sS9QNFD29HWulxrmW11Up5AfAjgU"
|
||||
}
|
||||
}
|
||||
*/
|
||||
func (a Account) OneTimeKeysJSON() ([]byte, error) {
|
||||
res := make(map[string]map[string]id.Curve25519)
|
||||
otKeys := a.OneTimeKeys()
|
||||
res["Curve25519"] = otKeys
|
||||
return json.Marshal(res)
|
||||
}
|
||||
|
||||
// MarkKeysAsPublished marks the current set of one time keys and the fallback key as being
|
||||
// published.
|
||||
func (a *Account) MarkKeysAsPublished() {
|
||||
for keyIndex := range a.OTKeys {
|
||||
if !a.OTKeys[keyIndex].Published {
|
||||
a.OTKeys[keyIndex].Published = true
|
||||
}
|
||||
}
|
||||
a.CurrentFallbackKey.Published = true
|
||||
}
|
||||
|
||||
// GenOneTimeKeys generates a number of new one time keys. If the total number
|
||||
// of keys stored by this Account exceeds MaxOneTimeKeys then the older
|
||||
// keys are discarded. If reader is nil, crypto/rand is used for the key creation.
|
||||
func (a *Account) GenOneTimeKeys(reader io.Reader, num uint) error {
|
||||
for i := uint(0); i < num; i++ {
|
||||
key := crypto.OneTimeKey{
|
||||
Published: false,
|
||||
ID: a.NextOneTimeKeyID,
|
||||
}
|
||||
newKP, err := crypto.Curve25519GenerateKey(reader)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
key.Key = newKP
|
||||
a.NextOneTimeKeyID++
|
||||
a.OTKeys = append([]crypto.OneTimeKey{key}, a.OTKeys...)
|
||||
}
|
||||
if len(a.OTKeys) > MaxOneTimeKeys {
|
||||
a.OTKeys = a.OTKeys[:MaxOneTimeKeys]
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// NewOutboundSession creates a new outbound session to a
|
||||
// given curve25519 identity Key and one time key.
|
||||
func (a Account) NewOutboundSession(theirIdentityKey, theirOneTimeKey id.Curve25519) (*session.OlmSession, error) {
|
||||
if len(theirIdentityKey) == 0 || len(theirOneTimeKey) == 0 {
|
||||
return nil, fmt.Errorf("outbound session: %w", goolm.ErrEmptyInput)
|
||||
}
|
||||
theirIdentityKeyDecoded, err := base64.RawStdEncoding.DecodeString(string(theirIdentityKey))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
theirOneTimeKeyDecoded, err := base64.RawStdEncoding.DecodeString(string(theirOneTimeKey))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
s, err := session.NewOutboundOlmSession(a.IdKeys.Curve25519, theirIdentityKeyDecoded, theirOneTimeKeyDecoded)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return s, nil
|
||||
}
|
||||
|
||||
// NewInboundSession creates a new inbound session from an incoming PRE_KEY message.
|
||||
func (a Account) NewInboundSession(theirIdentityKey *id.Curve25519, oneTimeKeyMsg []byte) (*session.OlmSession, error) {
|
||||
if len(oneTimeKeyMsg) == 0 {
|
||||
return nil, fmt.Errorf("inbound session: %w", goolm.ErrEmptyInput)
|
||||
}
|
||||
var theirIdentityKeyDecoded *crypto.Curve25519PublicKey
|
||||
var err error
|
||||
if theirIdentityKey != nil {
|
||||
theirIdentityKeyDecodedByte, err := base64.RawStdEncoding.DecodeString(string(*theirIdentityKey))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
theirIdentityKeyCurve := crypto.Curve25519PublicKey(theirIdentityKeyDecodedByte)
|
||||
theirIdentityKeyDecoded = &theirIdentityKeyCurve
|
||||
}
|
||||
|
||||
s, err := session.NewInboundOlmSession(theirIdentityKeyDecoded, oneTimeKeyMsg, a.searchOTKForOur, a.IdKeys.Curve25519)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return s, nil
|
||||
}
|
||||
|
||||
func (a Account) searchOTKForOur(toFind crypto.Curve25519PublicKey) *crypto.OneTimeKey {
|
||||
for curIndex := range a.OTKeys {
|
||||
if a.OTKeys[curIndex].Key.PublicKey.Equal(toFind) {
|
||||
return &a.OTKeys[curIndex]
|
||||
}
|
||||
}
|
||||
if a.NumFallbackKeys >= 1 && a.CurrentFallbackKey.Key.PublicKey.Equal(toFind) {
|
||||
return &a.CurrentFallbackKey
|
||||
}
|
||||
if a.NumFallbackKeys >= 2 && a.PrevFallbackKey.Key.PublicKey.Equal(toFind) {
|
||||
return &a.PrevFallbackKey
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// RemoveOneTimeKeys removes the one time key in this Account which matches the one time key in the session s.
|
||||
func (a *Account) RemoveOneTimeKeys(s *session.OlmSession) {
|
||||
toFind := s.BobOneTimeKey
|
||||
for curIndex := range a.OTKeys {
|
||||
if a.OTKeys[curIndex].Key.PublicKey.Equal(toFind) {
|
||||
//Remove and return
|
||||
a.OTKeys[curIndex] = a.OTKeys[len(a.OTKeys)-1]
|
||||
a.OTKeys = a.OTKeys[:len(a.OTKeys)-1]
|
||||
return
|
||||
}
|
||||
}
|
||||
//if the key is a fallback or prevFallback, don't remove it
|
||||
}
|
||||
|
||||
// GenFallbackKey generates a new fallback key. The old fallback key is stored in a.PrevFallbackKey overwriting any previous PrevFallbackKey. If reader is nil, crypto/rand is used for the key creation.
|
||||
func (a *Account) GenFallbackKey(reader io.Reader) error {
|
||||
a.PrevFallbackKey = a.CurrentFallbackKey
|
||||
key := crypto.OneTimeKey{
|
||||
Published: false,
|
||||
ID: a.NextOneTimeKeyID,
|
||||
}
|
||||
newKP, err := crypto.Curve25519GenerateKey(reader)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
key.Key = newKP
|
||||
a.NextOneTimeKeyID++
|
||||
if a.NumFallbackKeys < 2 {
|
||||
a.NumFallbackKeys++
|
||||
}
|
||||
a.CurrentFallbackKey = key
|
||||
return nil
|
||||
}
|
||||
|
||||
// FallbackKey returns the public part of the current fallback key of the Account.
|
||||
// The returned data is a map with the mapping of key id to base64-encoded Curve25519 key.
|
||||
func (a Account) FallbackKey() map[string]id.Curve25519 {
|
||||
keys := make(map[string]id.Curve25519)
|
||||
if a.NumFallbackKeys >= 1 {
|
||||
keys[a.CurrentFallbackKey.KeyIDEncoded()] = id.Curve25519(a.CurrentFallbackKey.PublicKeyEncoded())
|
||||
}
|
||||
return keys
|
||||
}
|
||||
|
||||
//FallbackKeyJSON returns the public part of the current fallback key of the Account as a JSON string.
|
||||
//
|
||||
//The returned JSON is of format:
|
||||
/*
|
||||
{
|
||||
curve25519: {
|
||||
"AAAAAA": "wo76WcYtb0Vk/pBOdmduiGJ0wIEjW4IBMbbQn7aSnTo"
|
||||
}
|
||||
}
|
||||
*/
|
||||
func (a Account) FallbackKeyJSON() ([]byte, error) {
|
||||
res := make(map[string]map[string]id.Curve25519)
|
||||
fbk := a.FallbackKey()
|
||||
res["curve25519"] = fbk
|
||||
return json.Marshal(res)
|
||||
}
|
||||
|
||||
// FallbackKeyUnpublished returns the public part of the current fallback key of the Account only if it is unpublished.
|
||||
// The returned data is a map with the mapping of key id to base64-encoded Curve25519 key.
|
||||
func (a Account) FallbackKeyUnpublished() map[string]id.Curve25519 {
|
||||
keys := make(map[string]id.Curve25519)
|
||||
if a.NumFallbackKeys >= 1 && !a.CurrentFallbackKey.Published {
|
||||
keys[a.CurrentFallbackKey.KeyIDEncoded()] = id.Curve25519(a.CurrentFallbackKey.PublicKeyEncoded())
|
||||
}
|
||||
return keys
|
||||
}
|
||||
|
||||
//FallbackKeyUnpublishedJSON returns the public part of the current fallback key, only if it is unpublished, of the Account as a JSON string.
|
||||
//
|
||||
//The returned JSON is of format:
|
||||
/*
|
||||
{
|
||||
curve25519: {
|
||||
"AAAAAA": "wo76WcYtb0Vk/pBOdmduiGJ0wIEjW4IBMbbQn7aSnTo"
|
||||
}
|
||||
}
|
||||
*/
|
||||
func (a Account) FallbackKeyUnpublishedJSON() ([]byte, error) {
|
||||
res := make(map[string]map[string]id.Curve25519)
|
||||
fbk := a.FallbackKeyUnpublished()
|
||||
res["curve25519"] = fbk
|
||||
return json.Marshal(res)
|
||||
}
|
||||
|
||||
// ForgetOldFallbackKey resets the previous fallback key in the account.
|
||||
func (a *Account) ForgetOldFallbackKey() {
|
||||
if a.NumFallbackKeys >= 2 {
|
||||
a.NumFallbackKeys = 1
|
||||
a.PrevFallbackKey = crypto.OneTimeKey{}
|
||||
}
|
||||
}
|
||||
|
||||
// Unpickle decodes the base64 encoded string and decrypts the result with the key.
|
||||
// The decrypted value is then passed to UnpickleLibOlm.
|
||||
func (a *Account) Unpickle(pickled, key []byte) error {
|
||||
decrypted, err := cipher.Unpickle(key, pickled)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = a.UnpickleLibOlm(decrypted)
|
||||
return err
|
||||
}
|
||||
|
||||
// UnpickleLibOlm decodes the unencryted value and populates the Account accordingly. It returns the number of bytes read.
|
||||
func (a *Account) UnpickleLibOlm(value []byte) (int, error) {
|
||||
//First 4 bytes are the accountPickleVersion
|
||||
pickledVersion, curPos, err := libolmpickle.UnpickleUInt32(value)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
switch pickledVersion {
|
||||
case accountPickleVersionLibOLM, 3, 2:
|
||||
default:
|
||||
return 0, fmt.Errorf("unpickle account: %w", goolm.ErrBadVersion)
|
||||
}
|
||||
//read ed25519 key pair
|
||||
readBytes, err := a.IdKeys.Ed25519.UnpickleLibOlm(value[curPos:])
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
curPos += readBytes
|
||||
//read curve25519 key pair
|
||||
readBytes, err = a.IdKeys.Curve25519.UnpickleLibOlm(value[curPos:])
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
curPos += readBytes
|
||||
//Read number of onetimeKeys
|
||||
numberOTKeys, readBytes, err := libolmpickle.UnpickleUInt32(value[curPos:])
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
curPos += readBytes
|
||||
//Read i one time keys
|
||||
a.OTKeys = make([]crypto.OneTimeKey, numberOTKeys)
|
||||
for i := uint32(0); i < numberOTKeys; i++ {
|
||||
readBytes, err := a.OTKeys[i].UnpickleLibOlm(value[curPos:])
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
curPos += readBytes
|
||||
}
|
||||
if pickledVersion <= 2 {
|
||||
// version 2 did not have fallback keys
|
||||
a.NumFallbackKeys = 0
|
||||
} else if pickledVersion == 3 {
|
||||
// version 3 used the published flag to indicate how many fallback keys
|
||||
// were present (we'll have to assume that the keys were published)
|
||||
readBytes, err := a.CurrentFallbackKey.UnpickleLibOlm(value[curPos:])
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
curPos += readBytes
|
||||
readBytes, err = a.PrevFallbackKey.UnpickleLibOlm(value[curPos:])
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
curPos += readBytes
|
||||
if a.CurrentFallbackKey.Published {
|
||||
if a.PrevFallbackKey.Published {
|
||||
a.NumFallbackKeys = 2
|
||||
} else {
|
||||
a.NumFallbackKeys = 1
|
||||
}
|
||||
} else {
|
||||
a.NumFallbackKeys = 0
|
||||
}
|
||||
} else {
|
||||
//Read number of fallback keys
|
||||
numFallbackKeys, readBytes, err := libolmpickle.UnpickleUInt8(value[curPos:])
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
curPos += readBytes
|
||||
a.NumFallbackKeys = numFallbackKeys
|
||||
if a.NumFallbackKeys >= 1 {
|
||||
readBytes, err := a.CurrentFallbackKey.UnpickleLibOlm(value[curPos:])
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
curPos += readBytes
|
||||
if a.NumFallbackKeys >= 2 {
|
||||
readBytes, err := a.PrevFallbackKey.UnpickleLibOlm(value[curPos:])
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
curPos += readBytes
|
||||
}
|
||||
}
|
||||
}
|
||||
//Read next onetime key id
|
||||
nextOTKeyID, readBytes, err := libolmpickle.UnpickleUInt32(value[curPos:])
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
curPos += readBytes
|
||||
a.NextOneTimeKeyID = nextOTKeyID
|
||||
return curPos, nil
|
||||
}
|
||||
|
||||
// Pickle returns a base64 encoded and with key encrypted pickled account using PickleLibOlm().
|
||||
func (a Account) Pickle(key []byte) ([]byte, error) {
|
||||
pickeledBytes := make([]byte, a.PickleLen())
|
||||
written, err := a.PickleLibOlm(pickeledBytes)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if written != len(pickeledBytes) {
|
||||
return nil, errors.New("number of written bytes not correct")
|
||||
}
|
||||
encrypted, err := cipher.Pickle(key, pickeledBytes)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return encrypted, nil
|
||||
}
|
||||
|
||||
// PickleLibOlm encodes the Account into target. target has to have a size of at least PickleLen() and is written to from index 0.
|
||||
// It returns the number of bytes written.
|
||||
func (a Account) PickleLibOlm(target []byte) (int, error) {
|
||||
if len(target) < a.PickleLen() {
|
||||
return 0, fmt.Errorf("pickle account: %w", goolm.ErrValueTooShort)
|
||||
}
|
||||
written := libolmpickle.PickleUInt32(accountPickleVersionLibOLM, target)
|
||||
writtenEdKey, err := a.IdKeys.Ed25519.PickleLibOlm(target[written:])
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("pickle account: %w", err)
|
||||
}
|
||||
written += writtenEdKey
|
||||
writtenCurveKey, err := a.IdKeys.Curve25519.PickleLibOlm(target[written:])
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("pickle account: %w", err)
|
||||
}
|
||||
written += writtenCurveKey
|
||||
written += libolmpickle.PickleUInt32(uint32(len(a.OTKeys)), target[written:])
|
||||
for _, curOTKey := range a.OTKeys {
|
||||
writtenOT, err := curOTKey.PickleLibOlm(target[written:])
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("pickle account: %w", err)
|
||||
}
|
||||
written += writtenOT
|
||||
}
|
||||
written += libolmpickle.PickleUInt8(a.NumFallbackKeys, target[written:])
|
||||
if a.NumFallbackKeys >= 1 {
|
||||
writtenOT, err := a.CurrentFallbackKey.PickleLibOlm(target[written:])
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("pickle account: %w", err)
|
||||
}
|
||||
written += writtenOT
|
||||
|
||||
if a.NumFallbackKeys >= 2 {
|
||||
writtenOT, err := a.PrevFallbackKey.PickleLibOlm(target[written:])
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("pickle account: %w", err)
|
||||
}
|
||||
written += writtenOT
|
||||
}
|
||||
}
|
||||
written += libolmpickle.PickleUInt32(a.NextOneTimeKeyID, target[written:])
|
||||
return written, nil
|
||||
}
|
||||
|
||||
// PickleLen returns the number of bytes the pickled Account will have.
|
||||
func (a Account) PickleLen() int {
|
||||
length := libolmpickle.PickleUInt32Len(accountPickleVersionLibOLM)
|
||||
length += a.IdKeys.Ed25519.PickleLen()
|
||||
length += a.IdKeys.Curve25519.PickleLen()
|
||||
length += libolmpickle.PickleUInt32Len(uint32(len(a.OTKeys)))
|
||||
length += (len(a.OTKeys) * (&crypto.OneTimeKey{}).PickleLen())
|
||||
length += libolmpickle.PickleUInt8Len(a.NumFallbackKeys)
|
||||
length += (int(a.NumFallbackKeys) * (&crypto.OneTimeKey{}).PickleLen())
|
||||
length += libolmpickle.PickleUInt32Len(a.NextOneTimeKeyID)
|
||||
return length
|
||||
}
|
||||
22
vendor/maunium.net/go/mautrix/crypto/goolm/base64.go
generated
vendored
Normal file
22
vendor/maunium.net/go/mautrix/crypto/goolm/base64.go
generated
vendored
Normal file
@@ -0,0 +1,22 @@
|
||||
package goolm
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
)
|
||||
|
||||
// Deprecated: base64.RawStdEncoding should be used directly
|
||||
func Base64Decode(input []byte) ([]byte, error) {
|
||||
decoded := make([]byte, base64.RawStdEncoding.DecodedLen(len(input)))
|
||||
writtenBytes, err := base64.RawStdEncoding.Decode(decoded, input)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return decoded[:writtenBytes], nil
|
||||
}
|
||||
|
||||
// Deprecated: base64.RawStdEncoding should be used directly
|
||||
func Base64Encode(input []byte) []byte {
|
||||
encoded := make([]byte, base64.RawStdEncoding.EncodedLen(len(input)))
|
||||
base64.RawStdEncoding.Encode(encoded, input)
|
||||
return encoded
|
||||
}
|
||||
96
vendor/maunium.net/go/mautrix/crypto/goolm/cipher/aes_sha256.go
generated
vendored
Normal file
96
vendor/maunium.net/go/mautrix/crypto/goolm/cipher/aes_sha256.go
generated
vendored
Normal file
@@ -0,0 +1,96 @@
|
||||
package cipher
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"io"
|
||||
|
||||
"maunium.net/go/mautrix/crypto/goolm/crypto"
|
||||
)
|
||||
|
||||
// derivedAESKeys stores the derived keys for the AESSHA256 cipher
|
||||
type derivedAESKeys struct {
|
||||
key []byte
|
||||
hmacKey []byte
|
||||
iv []byte
|
||||
}
|
||||
|
||||
// deriveAESKeys derives three keys for the AESSHA256 cipher
|
||||
func deriveAESKeys(kdfInfo []byte, key []byte) (*derivedAESKeys, error) {
|
||||
hkdf := crypto.HKDFSHA256(key, nil, kdfInfo)
|
||||
keys := &derivedAESKeys{
|
||||
key: make([]byte, 32),
|
||||
hmacKey: make([]byte, 32),
|
||||
iv: make([]byte, 16),
|
||||
}
|
||||
if _, err := io.ReadFull(hkdf, keys.key); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if _, err := io.ReadFull(hkdf, keys.hmacKey); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if _, err := io.ReadFull(hkdf, keys.iv); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return keys, nil
|
||||
}
|
||||
|
||||
// AESSha512BlockSize resturns the blocksize of the cipher AESSHA256.
|
||||
func AESSha512BlockSize() int {
|
||||
return crypto.AESCBCBlocksize()
|
||||
}
|
||||
|
||||
// AESSHA256 is a valid cipher using AES with CBC and HKDFSha256.
|
||||
type AESSHA256 struct {
|
||||
kdfInfo []byte
|
||||
}
|
||||
|
||||
// NewAESSHA256 returns a new AESSHA256 cipher with the key derive function info (kdfInfo).
|
||||
func NewAESSHA256(kdfInfo []byte) *AESSHA256 {
|
||||
return &AESSHA256{
|
||||
kdfInfo: kdfInfo,
|
||||
}
|
||||
}
|
||||
|
||||
// Encrypt encrypts the plaintext with the key. The key is used to derive the actual encryption key (32 bytes) as well as the iv (16 bytes).
|
||||
func (c AESSHA256) Encrypt(key, plaintext []byte) (ciphertext []byte, err error) {
|
||||
keys, err := deriveAESKeys(c.kdfInfo, key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ciphertext, err = crypto.AESCBCEncrypt(keys.key, keys.iv, plaintext)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return ciphertext, nil
|
||||
}
|
||||
|
||||
// Decrypt decrypts the ciphertext with the key. The key is used to derive the actual encryption key (32 bytes) as well as the iv (16 bytes).
|
||||
func (c AESSHA256) Decrypt(key, ciphertext []byte) (plaintext []byte, err error) {
|
||||
keys, err := deriveAESKeys(c.kdfInfo, key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
plaintext, err = crypto.AESCBCDecrypt(keys.key, keys.iv, ciphertext)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return plaintext, nil
|
||||
}
|
||||
|
||||
// MAC returns the MAC for the message using the key. The key is used to derive the actual mac key (32 bytes).
|
||||
func (c AESSHA256) MAC(key, message []byte) ([]byte, error) {
|
||||
keys, err := deriveAESKeys(c.kdfInfo, key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return crypto.HMACSHA256(keys.hmacKey, message), nil
|
||||
}
|
||||
|
||||
// Verify checks the MAC of the message using the key against the givenMAC. The key is used to derive the actual mac key (32 bytes).
|
||||
func (c AESSHA256) Verify(key, message, givenMAC []byte) (bool, error) {
|
||||
mac, err := c.MAC(key, message)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return bytes.Equal(givenMAC, mac[:len(givenMAC)]), nil
|
||||
}
|
||||
17
vendor/maunium.net/go/mautrix/crypto/goolm/cipher/main.go
generated
vendored
Normal file
17
vendor/maunium.net/go/mautrix/crypto/goolm/cipher/main.go
generated
vendored
Normal file
@@ -0,0 +1,17 @@
|
||||
// cipher provides the methods and structs to do encryptions for olm/megolm.
|
||||
package cipher
|
||||
|
||||
// Cipher defines a valid cipher.
|
||||
type Cipher interface {
|
||||
// Encrypt encrypts the plaintext.
|
||||
Encrypt(key, plaintext []byte) (ciphertext []byte, err error)
|
||||
|
||||
// Decrypt decrypts the ciphertext.
|
||||
Decrypt(key, ciphertext []byte) (plaintext []byte, err error)
|
||||
|
||||
//MAC returns the MAC of the message calculated with the key.
|
||||
MAC(key, message []byte) ([]byte, error)
|
||||
|
||||
//Verify checks the MAC of the message calculated with the key against the givenMAC.
|
||||
Verify(key, message, givenMAC []byte) (bool, error)
|
||||
}
|
||||
58
vendor/maunium.net/go/mautrix/crypto/goolm/cipher/pickle.go
generated
vendored
Normal file
58
vendor/maunium.net/go/mautrix/crypto/goolm/cipher/pickle.go
generated
vendored
Normal file
@@ -0,0 +1,58 @@
|
||||
package cipher
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"maunium.net/go/mautrix/crypto/goolm"
|
||||
)
|
||||
|
||||
const (
|
||||
kdfPickle = "Pickle" //used to derive the keys for encryption
|
||||
pickleMACLength = 8
|
||||
)
|
||||
|
||||
// PickleBlockSize returns the blocksize of the used cipher.
|
||||
func PickleBlockSize() int {
|
||||
return AESSha512BlockSize()
|
||||
}
|
||||
|
||||
// Pickle encrypts the input with the key and the cipher AESSHA256. The result is then encoded in base64.
|
||||
func Pickle(key, input []byte) ([]byte, error) {
|
||||
pickleCipher := NewAESSHA256([]byte(kdfPickle))
|
||||
ciphertext, err := pickleCipher.Encrypt(key, input)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
mac, err := pickleCipher.MAC(key, ciphertext)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ciphertext = append(ciphertext, mac[:pickleMACLength]...)
|
||||
encoded := goolm.Base64Encode(ciphertext)
|
||||
return encoded, nil
|
||||
}
|
||||
|
||||
// Unpickle decodes the input from base64 and decrypts the decoded input with the key and the cipher AESSHA256.
|
||||
func Unpickle(key, input []byte) ([]byte, error) {
|
||||
pickleCipher := NewAESSHA256([]byte(kdfPickle))
|
||||
ciphertext, err := goolm.Base64Decode(input)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
//remove mac and check
|
||||
verified, err := pickleCipher.Verify(key, ciphertext[:len(ciphertext)-pickleMACLength], ciphertext[len(ciphertext)-pickleMACLength:])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if !verified {
|
||||
return nil, fmt.Errorf("decrypt pickle: %w", goolm.ErrBadMAC)
|
||||
}
|
||||
//Set to next block size
|
||||
targetCipherText := make([]byte, int(len(ciphertext)/PickleBlockSize())*PickleBlockSize())
|
||||
copy(targetCipherText, ciphertext)
|
||||
plaintext, err := pickleCipher.Decrypt(key, targetCipherText)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return plaintext, nil
|
||||
}
|
||||
75
vendor/maunium.net/go/mautrix/crypto/goolm/crypto/aes_cbc.go
generated
vendored
Normal file
75
vendor/maunium.net/go/mautrix/crypto/goolm/crypto/aes_cbc.go
generated
vendored
Normal file
@@ -0,0 +1,75 @@
|
||||
package crypto
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"fmt"
|
||||
|
||||
"maunium.net/go/mautrix/crypto/goolm"
|
||||
)
|
||||
|
||||
// AESCBCBlocksize returns the blocksize of the encryption method
|
||||
func AESCBCBlocksize() int {
|
||||
return aes.BlockSize
|
||||
}
|
||||
|
||||
// AESCBCEncrypt encrypts the plaintext with the key and iv. len(iv) must be equal to the blocksize!
|
||||
func AESCBCEncrypt(key, iv, plaintext []byte) ([]byte, error) {
|
||||
if len(key) == 0 {
|
||||
return nil, fmt.Errorf("AESCBCEncrypt: %w", goolm.ErrNoKeyProvided)
|
||||
}
|
||||
if len(iv) != AESCBCBlocksize() {
|
||||
return nil, fmt.Errorf("iv: %w", goolm.ErrNotBlocksize)
|
||||
}
|
||||
var cipherText []byte
|
||||
plaintext = pkcs5Padding(plaintext, AESCBCBlocksize())
|
||||
if len(plaintext)%AESCBCBlocksize() != 0 {
|
||||
return nil, fmt.Errorf("message: %w", goolm.ErrNotMultipleBlocksize)
|
||||
}
|
||||
block, err := aes.NewCipher(key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
cipherText = make([]byte, len(plaintext))
|
||||
cbc := cipher.NewCBCEncrypter(block, iv)
|
||||
cbc.CryptBlocks(cipherText, plaintext)
|
||||
return cipherText, nil
|
||||
}
|
||||
|
||||
// AESCBCDecrypt decrypts the ciphertext with the key and iv. len(iv) must be equal to the blocksize!
|
||||
func AESCBCDecrypt(key, iv, ciphertext []byte) ([]byte, error) {
|
||||
if len(key) == 0 {
|
||||
return nil, fmt.Errorf("AESCBCEncrypt: %w", goolm.ErrNoKeyProvided)
|
||||
}
|
||||
if len(iv) != AESCBCBlocksize() {
|
||||
return nil, fmt.Errorf("iv: %w", goolm.ErrNotBlocksize)
|
||||
}
|
||||
var block cipher.Block
|
||||
var err error
|
||||
block, err = aes.NewCipher(key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if len(ciphertext) < AESCBCBlocksize() {
|
||||
return nil, fmt.Errorf("ciphertext: %w", goolm.ErrNotMultipleBlocksize)
|
||||
}
|
||||
|
||||
cbc := cipher.NewCBCDecrypter(block, iv)
|
||||
cbc.CryptBlocks(ciphertext, ciphertext)
|
||||
return pkcs5Unpadding(ciphertext), nil
|
||||
}
|
||||
|
||||
// pkcs5Padding paddes the plaintext to be used in the AESCBC encryption.
|
||||
func pkcs5Padding(plaintext []byte, blockSize int) []byte {
|
||||
padding := (blockSize - len(plaintext)%blockSize)
|
||||
padtext := bytes.Repeat([]byte{byte(padding)}, padding)
|
||||
return append(plaintext, padtext...)
|
||||
}
|
||||
|
||||
// pkcs5Unpadding undoes the padding to the plaintext after AESCBC decryption.
|
||||
func pkcs5Unpadding(plaintext []byte) []byte {
|
||||
length := len(plaintext)
|
||||
unpadding := int(plaintext[length-1])
|
||||
return plaintext[:(length - unpadding)]
|
||||
}
|
||||
186
vendor/maunium.net/go/mautrix/crypto/goolm/crypto/curve25519.go
generated
vendored
Normal file
186
vendor/maunium.net/go/mautrix/crypto/goolm/crypto/curve25519.go
generated
vendored
Normal file
@@ -0,0 +1,186 @@
|
||||
package crypto
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"io"
|
||||
|
||||
"golang.org/x/crypto/curve25519"
|
||||
|
||||
"maunium.net/go/mautrix/crypto/goolm"
|
||||
"maunium.net/go/mautrix/crypto/goolm/libolmpickle"
|
||||
"maunium.net/go/mautrix/id"
|
||||
)
|
||||
|
||||
const (
|
||||
Curve25519KeyLength = curve25519.ScalarSize //The length of the private key.
|
||||
curve25519PubKeyLength = 32
|
||||
)
|
||||
|
||||
// Curve25519GenerateKey creates a new curve25519 key pair. If reader is nil, the random data is taken from crypto/rand.
|
||||
func Curve25519GenerateKey(reader io.Reader) (Curve25519KeyPair, error) {
|
||||
privateKeyByte := make([]byte, Curve25519KeyLength)
|
||||
if reader == nil {
|
||||
_, err := rand.Read(privateKeyByte)
|
||||
if err != nil {
|
||||
return Curve25519KeyPair{}, err
|
||||
}
|
||||
} else {
|
||||
_, err := reader.Read(privateKeyByte)
|
||||
if err != nil {
|
||||
return Curve25519KeyPair{}, err
|
||||
}
|
||||
}
|
||||
|
||||
privateKey := Curve25519PrivateKey(privateKeyByte)
|
||||
|
||||
publicKey, err := privateKey.PubKey()
|
||||
if err != nil {
|
||||
return Curve25519KeyPair{}, err
|
||||
}
|
||||
return Curve25519KeyPair{
|
||||
PrivateKey: Curve25519PrivateKey(privateKey),
|
||||
PublicKey: Curve25519PublicKey(publicKey),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Curve25519GenerateFromPrivate creates a new curve25519 key pair with the private key given.
|
||||
func Curve25519GenerateFromPrivate(private Curve25519PrivateKey) (Curve25519KeyPair, error) {
|
||||
publicKey, err := private.PubKey()
|
||||
if err != nil {
|
||||
return Curve25519KeyPair{}, err
|
||||
}
|
||||
return Curve25519KeyPair{
|
||||
PrivateKey: private,
|
||||
PublicKey: Curve25519PublicKey(publicKey),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Curve25519KeyPair stores both parts of a curve25519 key.
|
||||
type Curve25519KeyPair struct {
|
||||
PrivateKey Curve25519PrivateKey `json:"private,omitempty"`
|
||||
PublicKey Curve25519PublicKey `json:"public,omitempty"`
|
||||
}
|
||||
|
||||
// B64Encoded returns a base64 encoded string of the public key.
|
||||
func (c Curve25519KeyPair) B64Encoded() id.Curve25519 {
|
||||
return c.PublicKey.B64Encoded()
|
||||
}
|
||||
|
||||
// SharedSecret returns the shared secret between the key pair and the given public key.
|
||||
func (c Curve25519KeyPair) SharedSecret(pubKey Curve25519PublicKey) ([]byte, error) {
|
||||
return c.PrivateKey.SharedSecret(pubKey)
|
||||
}
|
||||
|
||||
// PickleLibOlm encodes the key pair into target. target has to have a size of at least PickleLen() and is written to from index 0.
|
||||
// It returns the number of bytes written.
|
||||
func (c Curve25519KeyPair) PickleLibOlm(target []byte) (int, error) {
|
||||
if len(target) < c.PickleLen() {
|
||||
return 0, fmt.Errorf("pickle curve25519 key pair: %w", goolm.ErrValueTooShort)
|
||||
}
|
||||
written, err := c.PublicKey.PickleLibOlm(target)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("pickle curve25519 key pair: %w", err)
|
||||
}
|
||||
if len(c.PrivateKey) != Curve25519KeyLength {
|
||||
written += libolmpickle.PickleBytes(make([]byte, Curve25519KeyLength), target[written:])
|
||||
} else {
|
||||
written += libolmpickle.PickleBytes(c.PrivateKey, target[written:])
|
||||
}
|
||||
return written, nil
|
||||
}
|
||||
|
||||
// UnpickleLibOlm decodes the unencryted value and populates the key pair accordingly. It returns the number of bytes read.
|
||||
func (c *Curve25519KeyPair) UnpickleLibOlm(value []byte) (int, error) {
|
||||
//unpickle PubKey
|
||||
read, err := c.PublicKey.UnpickleLibOlm(value)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
//unpickle PrivateKey
|
||||
privKey, readPriv, err := libolmpickle.UnpickleBytes(value[read:], Curve25519KeyLength)
|
||||
if err != nil {
|
||||
return read, err
|
||||
}
|
||||
c.PrivateKey = privKey
|
||||
return read + readPriv, nil
|
||||
}
|
||||
|
||||
// PickleLen returns the number of bytes the pickled key pair will have.
|
||||
func (c Curve25519KeyPair) PickleLen() int {
|
||||
lenPublic := c.PublicKey.PickleLen()
|
||||
var lenPrivate int
|
||||
if len(c.PrivateKey) != Curve25519KeyLength {
|
||||
lenPrivate = libolmpickle.PickleBytesLen(make([]byte, Curve25519KeyLength))
|
||||
} else {
|
||||
lenPrivate = libolmpickle.PickleBytesLen(c.PrivateKey)
|
||||
}
|
||||
return lenPublic + lenPrivate
|
||||
}
|
||||
|
||||
// Curve25519PrivateKey represents the private key for curve25519 usage
|
||||
type Curve25519PrivateKey []byte
|
||||
|
||||
// Equal compares the private key to the given private key.
|
||||
func (c Curve25519PrivateKey) Equal(x Curve25519PrivateKey) bool {
|
||||
return bytes.Equal(c, x)
|
||||
}
|
||||
|
||||
// PubKey returns the public key derived from the private key.
|
||||
func (c Curve25519PrivateKey) PubKey() (Curve25519PublicKey, error) {
|
||||
publicKey, err := curve25519.X25519(c, curve25519.Basepoint)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return publicKey, nil
|
||||
}
|
||||
|
||||
// SharedSecret returns the shared secret between the private key and the given public key.
|
||||
func (c Curve25519PrivateKey) SharedSecret(pubKey Curve25519PublicKey) ([]byte, error) {
|
||||
return curve25519.X25519(c, pubKey)
|
||||
}
|
||||
|
||||
// Curve25519PublicKey represents the public key for curve25519 usage
|
||||
type Curve25519PublicKey []byte
|
||||
|
||||
// Equal compares the public key to the given public key.
|
||||
func (c Curve25519PublicKey) Equal(x Curve25519PublicKey) bool {
|
||||
return bytes.Equal(c, x)
|
||||
}
|
||||
|
||||
// B64Encoded returns a base64 encoded string of the public key.
|
||||
func (c Curve25519PublicKey) B64Encoded() id.Curve25519 {
|
||||
return id.Curve25519(base64.RawStdEncoding.EncodeToString(c))
|
||||
}
|
||||
|
||||
// PickleLibOlm encodes the public key into target. target has to have a size of at least PickleLen() and is written to from index 0.
|
||||
// It returns the number of bytes written.
|
||||
func (c Curve25519PublicKey) PickleLibOlm(target []byte) (int, error) {
|
||||
if len(target) < c.PickleLen() {
|
||||
return 0, fmt.Errorf("pickle curve25519 public key: %w", goolm.ErrValueTooShort)
|
||||
}
|
||||
if len(c) != curve25519PubKeyLength {
|
||||
return libolmpickle.PickleBytes(make([]byte, curve25519PubKeyLength), target), nil
|
||||
}
|
||||
return libolmpickle.PickleBytes(c, target), nil
|
||||
}
|
||||
|
||||
// UnpickleLibOlm decodes the unencryted value and populates the public key accordingly. It returns the number of bytes read.
|
||||
func (c *Curve25519PublicKey) UnpickleLibOlm(value []byte) (int, error) {
|
||||
unpickled, readBytes, err := libolmpickle.UnpickleBytes(value, curve25519PubKeyLength)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
*c = unpickled
|
||||
return readBytes, nil
|
||||
}
|
||||
|
||||
// PickleLen returns the number of bytes the pickled public key will have.
|
||||
func (c Curve25519PublicKey) PickleLen() int {
|
||||
if len(c) != curve25519PubKeyLength {
|
||||
return libolmpickle.PickleBytesLen(make([]byte, curve25519PubKeyLength))
|
||||
}
|
||||
return libolmpickle.PickleBytesLen(c)
|
||||
}
|
||||
181
vendor/maunium.net/go/mautrix/crypto/goolm/crypto/ed25519.go
generated
vendored
Normal file
181
vendor/maunium.net/go/mautrix/crypto/goolm/crypto/ed25519.go
generated
vendored
Normal file
@@ -0,0 +1,181 @@
|
||||
package crypto
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"crypto/ed25519"
|
||||
"encoding/base64"
|
||||
"fmt"
|
||||
"io"
|
||||
|
||||
"maunium.net/go/mautrix/crypto/goolm"
|
||||
"maunium.net/go/mautrix/crypto/goolm/libolmpickle"
|
||||
"maunium.net/go/mautrix/id"
|
||||
)
|
||||
|
||||
const (
|
||||
ED25519SignatureSize = ed25519.SignatureSize //The length of a signature
|
||||
)
|
||||
|
||||
// Ed25519GenerateKey creates a new ed25519 key pair. If reader is nil, the random data is taken from crypto/rand.
|
||||
func Ed25519GenerateKey(reader io.Reader) (Ed25519KeyPair, error) {
|
||||
publicKey, privateKey, err := ed25519.GenerateKey(reader)
|
||||
if err != nil {
|
||||
return Ed25519KeyPair{}, err
|
||||
}
|
||||
return Ed25519KeyPair{
|
||||
PrivateKey: Ed25519PrivateKey(privateKey),
|
||||
PublicKey: Ed25519PublicKey(publicKey),
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Ed25519GenerateFromPrivate creates a new ed25519 key pair with the private key given.
|
||||
func Ed25519GenerateFromPrivate(privKey Ed25519PrivateKey) Ed25519KeyPair {
|
||||
return Ed25519KeyPair{
|
||||
PrivateKey: privKey,
|
||||
PublicKey: privKey.PubKey(),
|
||||
}
|
||||
}
|
||||
|
||||
// Ed25519GenerateFromSeed creates a new ed25519 key pair with a given seed.
|
||||
func Ed25519GenerateFromSeed(seed []byte) Ed25519KeyPair {
|
||||
privKey := Ed25519PrivateKey(ed25519.NewKeyFromSeed(seed))
|
||||
return Ed25519KeyPair{
|
||||
PrivateKey: privKey,
|
||||
PublicKey: privKey.PubKey(),
|
||||
}
|
||||
}
|
||||
|
||||
// Ed25519KeyPair stores both parts of a ed25519 key.
|
||||
type Ed25519KeyPair struct {
|
||||
PrivateKey Ed25519PrivateKey `json:"private,omitempty"`
|
||||
PublicKey Ed25519PublicKey `json:"public,omitempty"`
|
||||
}
|
||||
|
||||
// B64Encoded returns a base64 encoded string of the public key.
|
||||
func (c Ed25519KeyPair) B64Encoded() id.Ed25519 {
|
||||
return id.Ed25519(base64.RawStdEncoding.EncodeToString(c.PublicKey))
|
||||
}
|
||||
|
||||
// Sign returns the signature for the message.
|
||||
func (c Ed25519KeyPair) Sign(message []byte) []byte {
|
||||
return c.PrivateKey.Sign(message)
|
||||
}
|
||||
|
||||
// Verify checks the signature of the message against the givenSignature
|
||||
func (c Ed25519KeyPair) Verify(message, givenSignature []byte) bool {
|
||||
return c.PublicKey.Verify(message, givenSignature)
|
||||
}
|
||||
|
||||
// PickleLibOlm encodes the key pair into target. target has to have a size of at least PickleLen() and is written to from index 0.
|
||||
// It returns the number of bytes written.
|
||||
func (c Ed25519KeyPair) PickleLibOlm(target []byte) (int, error) {
|
||||
if len(target) < c.PickleLen() {
|
||||
return 0, fmt.Errorf("pickle ed25519 key pair: %w", goolm.ErrValueTooShort)
|
||||
}
|
||||
written, err := c.PublicKey.PickleLibOlm(target)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("pickle ed25519 key pair: %w", err)
|
||||
}
|
||||
|
||||
if len(c.PrivateKey) != ed25519.PrivateKeySize {
|
||||
written += libolmpickle.PickleBytes(make([]byte, ed25519.PrivateKeySize), target[written:])
|
||||
} else {
|
||||
written += libolmpickle.PickleBytes(c.PrivateKey, target[written:])
|
||||
}
|
||||
return written, nil
|
||||
}
|
||||
|
||||
// UnpickleLibOlm decodes the unencryted value and populates the key pair accordingly. It returns the number of bytes read.
|
||||
func (c *Ed25519KeyPair) UnpickleLibOlm(value []byte) (int, error) {
|
||||
//unpickle PubKey
|
||||
read, err := c.PublicKey.UnpickleLibOlm(value)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
//unpickle PrivateKey
|
||||
privKey, readPriv, err := libolmpickle.UnpickleBytes(value[read:], ed25519.PrivateKeySize)
|
||||
if err != nil {
|
||||
return read, err
|
||||
}
|
||||
c.PrivateKey = privKey
|
||||
return read + readPriv, nil
|
||||
}
|
||||
|
||||
// PickleLen returns the number of bytes the pickled key pair will have.
|
||||
func (c Ed25519KeyPair) PickleLen() int {
|
||||
lenPublic := c.PublicKey.PickleLen()
|
||||
var lenPrivate int
|
||||
if len(c.PrivateKey) != ed25519.PrivateKeySize {
|
||||
lenPrivate = libolmpickle.PickleBytesLen(make([]byte, ed25519.PrivateKeySize))
|
||||
} else {
|
||||
lenPrivate = libolmpickle.PickleBytesLen(c.PrivateKey)
|
||||
}
|
||||
return lenPublic + lenPrivate
|
||||
}
|
||||
|
||||
// Curve25519PrivateKey represents the private key for ed25519 usage. This is just a wrapper.
|
||||
type Ed25519PrivateKey ed25519.PrivateKey
|
||||
|
||||
// Equal compares the private key to the given private key.
|
||||
func (c Ed25519PrivateKey) Equal(x Ed25519PrivateKey) bool {
|
||||
return bytes.Equal(c, x)
|
||||
}
|
||||
|
||||
// PubKey returns the public key derived from the private key.
|
||||
func (c Ed25519PrivateKey) PubKey() Ed25519PublicKey {
|
||||
publicKey := ed25519.PrivateKey(c).Public()
|
||||
return Ed25519PublicKey(publicKey.(ed25519.PublicKey))
|
||||
}
|
||||
|
||||
// Sign returns the signature for the message.
|
||||
func (c Ed25519PrivateKey) Sign(message []byte) []byte {
|
||||
return ed25519.Sign(ed25519.PrivateKey(c), message)
|
||||
}
|
||||
|
||||
// Ed25519PublicKey represents the public key for ed25519 usage. This is just a wrapper.
|
||||
type Ed25519PublicKey ed25519.PublicKey
|
||||
|
||||
// Equal compares the public key to the given public key.
|
||||
func (c Ed25519PublicKey) Equal(x Ed25519PublicKey) bool {
|
||||
return bytes.Equal(c, x)
|
||||
}
|
||||
|
||||
// B64Encoded returns a base64 encoded string of the public key.
|
||||
func (c Ed25519PublicKey) B64Encoded() id.Curve25519 {
|
||||
return id.Curve25519(base64.RawStdEncoding.EncodeToString(c))
|
||||
}
|
||||
|
||||
// Verify checks the signature of the message against the givenSignature
|
||||
func (c Ed25519PublicKey) Verify(message, givenSignature []byte) bool {
|
||||
return ed25519.Verify(ed25519.PublicKey(c), message, givenSignature)
|
||||
}
|
||||
|
||||
// PickleLibOlm encodes the public key into target. target has to have a size of at least PickleLen() and is written to from index 0.
|
||||
// It returns the number of bytes written.
|
||||
func (c Ed25519PublicKey) PickleLibOlm(target []byte) (int, error) {
|
||||
if len(target) < c.PickleLen() {
|
||||
return 0, fmt.Errorf("pickle ed25519 public key: %w", goolm.ErrValueTooShort)
|
||||
}
|
||||
if len(c) != ed25519.PublicKeySize {
|
||||
return libolmpickle.PickleBytes(make([]byte, ed25519.PublicKeySize), target), nil
|
||||
}
|
||||
return libolmpickle.PickleBytes(c, target), nil
|
||||
}
|
||||
|
||||
// UnpickleLibOlm decodes the unencryted value and populates the public key accordingly. It returns the number of bytes read.
|
||||
func (c *Ed25519PublicKey) UnpickleLibOlm(value []byte) (int, error) {
|
||||
unpickled, readBytes, err := libolmpickle.UnpickleBytes(value, ed25519.PublicKeySize)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
*c = unpickled
|
||||
return readBytes, nil
|
||||
}
|
||||
|
||||
// PickleLen returns the number of bytes the pickled public key will have.
|
||||
func (c Ed25519PublicKey) PickleLen() int {
|
||||
if len(c) != ed25519.PublicKeySize {
|
||||
return libolmpickle.PickleBytesLen(make([]byte, ed25519.PublicKeySize))
|
||||
}
|
||||
return libolmpickle.PickleBytesLen(c)
|
||||
}
|
||||
29
vendor/maunium.net/go/mautrix/crypto/goolm/crypto/hmac.go
generated
vendored
Normal file
29
vendor/maunium.net/go/mautrix/crypto/goolm/crypto/hmac.go
generated
vendored
Normal file
@@ -0,0 +1,29 @@
|
||||
package crypto
|
||||
|
||||
import (
|
||||
"crypto/hmac"
|
||||
"crypto/sha256"
|
||||
"io"
|
||||
|
||||
"golang.org/x/crypto/hkdf"
|
||||
)
|
||||
|
||||
// HMACSHA256 returns the hash message authentication code with SHA-256 of the input with the key.
|
||||
func HMACSHA256(key, input []byte) []byte {
|
||||
hash := hmac.New(sha256.New, key)
|
||||
hash.Write(input)
|
||||
return hash.Sum(nil)
|
||||
}
|
||||
|
||||
// SHA256 return the SHA-256 of the value.
|
||||
func SHA256(value []byte) []byte {
|
||||
hash := sha256.New()
|
||||
hash.Write(value)
|
||||
return hash.Sum(nil)
|
||||
}
|
||||
|
||||
// HKDFSHA256 is the key deivation function based on HMAC and returns a reader based on input. salt and info can both be nil.
|
||||
// The reader can be used to read an arbitary length of bytes which are based on all parameters.
|
||||
func HKDFSHA256(input, salt, info []byte) io.Reader {
|
||||
return hkdf.New(sha256.New, input, salt, info)
|
||||
}
|
||||
2
vendor/maunium.net/go/mautrix/crypto/goolm/crypto/main.go
generated
vendored
Normal file
2
vendor/maunium.net/go/mautrix/crypto/goolm/crypto/main.go
generated
vendored
Normal file
@@ -0,0 +1,2 @@
|
||||
// crpyto provides the nessesary encryption methods for olm/megolm
|
||||
package crypto
|
||||
95
vendor/maunium.net/go/mautrix/crypto/goolm/crypto/one_time_key.go
generated
vendored
Normal file
95
vendor/maunium.net/go/mautrix/crypto/goolm/crypto/one_time_key.go
generated
vendored
Normal file
@@ -0,0 +1,95 @@
|
||||
package crypto
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
|
||||
"maunium.net/go/mautrix/crypto/goolm"
|
||||
"maunium.net/go/mautrix/crypto/goolm/libolmpickle"
|
||||
"maunium.net/go/mautrix/id"
|
||||
)
|
||||
|
||||
// OneTimeKey stores the information about a one time key.
|
||||
type OneTimeKey struct {
|
||||
ID uint32 `json:"id"`
|
||||
Published bool `json:"published"`
|
||||
Key Curve25519KeyPair `json:"key,omitempty"`
|
||||
}
|
||||
|
||||
// Equal compares the one time key to the given one.
|
||||
func (otk OneTimeKey) Equal(s OneTimeKey) bool {
|
||||
if otk.ID != s.ID {
|
||||
return false
|
||||
}
|
||||
if otk.Published != s.Published {
|
||||
return false
|
||||
}
|
||||
if !otk.Key.PrivateKey.Equal(s.Key.PrivateKey) {
|
||||
return false
|
||||
}
|
||||
if !otk.Key.PublicKey.Equal(s.Key.PublicKey) {
|
||||
return false
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
// PickleLibOlm encodes the key pair into target. target has to have a size of at least PickleLen() and is written to from index 0.
|
||||
// It returns the number of bytes written.
|
||||
func (c OneTimeKey) PickleLibOlm(target []byte) (int, error) {
|
||||
if len(target) < c.PickleLen() {
|
||||
return 0, fmt.Errorf("pickle one time key: %w", goolm.ErrValueTooShort)
|
||||
}
|
||||
written := libolmpickle.PickleUInt32(uint32(c.ID), target)
|
||||
written += libolmpickle.PickleBool(c.Published, target[written:])
|
||||
writtenKey, err := c.Key.PickleLibOlm(target[written:])
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("pickle one time key: %w", err)
|
||||
}
|
||||
written += writtenKey
|
||||
return written, nil
|
||||
}
|
||||
|
||||
// UnpickleLibOlm decodes the unencryted value and populates the OneTimeKey accordingly. It returns the number of bytes read.
|
||||
func (c *OneTimeKey) UnpickleLibOlm(value []byte) (int, error) {
|
||||
totalReadBytes := 0
|
||||
id, readBytes, err := libolmpickle.UnpickleUInt32(value)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
totalReadBytes += readBytes
|
||||
c.ID = id
|
||||
published, readBytes, err := libolmpickle.UnpickleBool(value[totalReadBytes:])
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
totalReadBytes += readBytes
|
||||
c.Published = published
|
||||
readBytes, err = c.Key.UnpickleLibOlm(value[totalReadBytes:])
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
totalReadBytes += readBytes
|
||||
return totalReadBytes, nil
|
||||
}
|
||||
|
||||
// PickleLen returns the number of bytes the pickled OneTimeKey will have.
|
||||
func (c OneTimeKey) PickleLen() int {
|
||||
length := 0
|
||||
length += libolmpickle.PickleUInt32Len(c.ID)
|
||||
length += libolmpickle.PickleBoolLen(c.Published)
|
||||
length += c.Key.PickleLen()
|
||||
return length
|
||||
}
|
||||
|
||||
// KeyIDEncoded returns the base64 encoded id.
|
||||
func (c OneTimeKey) KeyIDEncoded() string {
|
||||
resSlice := make([]byte, 4)
|
||||
binary.BigEndian.PutUint32(resSlice, c.ID)
|
||||
return base64.RawStdEncoding.EncodeToString(resSlice)
|
||||
}
|
||||
|
||||
// PublicKeyEncoded returns the base64 encoded public key
|
||||
func (c OneTimeKey) PublicKeyEncoded() id.Curve25519 {
|
||||
return c.Key.PublicKey.B64Encoded()
|
||||
}
|
||||
30
vendor/maunium.net/go/mautrix/crypto/goolm/errors.go
generated
vendored
Normal file
30
vendor/maunium.net/go/mautrix/crypto/goolm/errors.go
generated
vendored
Normal file
@@ -0,0 +1,30 @@
|
||||
package goolm
|
||||
|
||||
import (
|
||||
"errors"
|
||||
)
|
||||
|
||||
// Those are the most common used errors
|
||||
var (
|
||||
ErrBadSignature = errors.New("bad signature")
|
||||
ErrBadMAC = errors.New("bad mac")
|
||||
ErrBadMessageFormat = errors.New("bad message format")
|
||||
ErrBadVerification = errors.New("bad verification")
|
||||
ErrWrongProtocolVersion = errors.New("wrong protocol version")
|
||||
ErrEmptyInput = errors.New("empty input")
|
||||
ErrNoKeyProvided = errors.New("no key")
|
||||
ErrBadMessageKeyID = errors.New("bad message key id")
|
||||
ErrRatchetNotAvailable = errors.New("ratchet not available: attempt to decode a message whose index is earlier than our earliest known session key")
|
||||
ErrMsgIndexTooHigh = errors.New("message index too high")
|
||||
ErrProtocolViolation = errors.New("not protocol message order")
|
||||
ErrMessageKeyNotFound = errors.New("message key not found")
|
||||
ErrChainTooHigh = errors.New("chain index too high")
|
||||
ErrBadInput = errors.New("bad input")
|
||||
ErrBadVersion = errors.New("wrong version")
|
||||
ErrNotBlocksize = errors.New("length != blocksize")
|
||||
ErrNotMultipleBlocksize = errors.New("length not a multiple of the blocksize")
|
||||
ErrWrongPickleVersion = errors.New("wrong pickle version")
|
||||
ErrValueTooShort = errors.New("value too short")
|
||||
ErrInputToSmall = errors.New("input too small (truncated?)")
|
||||
ErrOverflow = errors.New("overflow")
|
||||
)
|
||||
41
vendor/maunium.net/go/mautrix/crypto/goolm/libolmpickle/pickle.go
generated
vendored
Normal file
41
vendor/maunium.net/go/mautrix/crypto/goolm/libolmpickle/pickle.go
generated
vendored
Normal file
@@ -0,0 +1,41 @@
|
||||
package libolmpickle
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
)
|
||||
|
||||
func PickleUInt8(value uint8, target []byte) int {
|
||||
target[0] = value
|
||||
return 1
|
||||
}
|
||||
func PickleUInt8Len(value uint8) int {
|
||||
return 1
|
||||
}
|
||||
|
||||
func PickleBool(value bool, target []byte) int {
|
||||
if value {
|
||||
target[0] = 0x01
|
||||
} else {
|
||||
target[0] = 0x00
|
||||
}
|
||||
return 1
|
||||
}
|
||||
func PickleBoolLen(value bool) int {
|
||||
return 1
|
||||
}
|
||||
|
||||
func PickleBytes(value, target []byte) int {
|
||||
return copy(target, value)
|
||||
}
|
||||
func PickleBytesLen(value []byte) int {
|
||||
return len(value)
|
||||
}
|
||||
|
||||
func PickleUInt32(value uint32, target []byte) int {
|
||||
res := make([]byte, 4) //4 bytes for int32
|
||||
binary.BigEndian.PutUint32(res, value)
|
||||
return copy(target, res)
|
||||
}
|
||||
func PickleUInt32Len(value uint32) int {
|
||||
return 4
|
||||
}
|
||||
53
vendor/maunium.net/go/mautrix/crypto/goolm/libolmpickle/unpickle.go
generated
vendored
Normal file
53
vendor/maunium.net/go/mautrix/crypto/goolm/libolmpickle/unpickle.go
generated
vendored
Normal file
@@ -0,0 +1,53 @@
|
||||
package libolmpickle
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"maunium.net/go/mautrix/crypto/goolm"
|
||||
)
|
||||
|
||||
func isZeroByteSlice(bytes []byte) bool {
|
||||
b := byte(0)
|
||||
for _, s := range bytes {
|
||||
b |= s
|
||||
}
|
||||
return b == 0
|
||||
}
|
||||
|
||||
func UnpickleUInt8(value []byte) (uint8, int, error) {
|
||||
if len(value) < 1 {
|
||||
return 0, 0, fmt.Errorf("unpickle uint8: %w", goolm.ErrValueTooShort)
|
||||
}
|
||||
return value[0], 1, nil
|
||||
}
|
||||
|
||||
func UnpickleBool(value []byte) (bool, int, error) {
|
||||
if len(value) < 1 {
|
||||
return false, 0, fmt.Errorf("unpickle bool: %w", goolm.ErrValueTooShort)
|
||||
}
|
||||
return value[0] != uint8(0x00), 1, nil
|
||||
}
|
||||
|
||||
func UnpickleBytes(value []byte, length int) ([]byte, int, error) {
|
||||
if len(value) < length {
|
||||
return nil, 0, fmt.Errorf("unpickle bytes: %w", goolm.ErrValueTooShort)
|
||||
}
|
||||
resp := value[:length]
|
||||
if isZeroByteSlice(resp) {
|
||||
return nil, length, nil
|
||||
}
|
||||
return resp, length, nil
|
||||
}
|
||||
|
||||
func UnpickleUInt32(value []byte) (uint32, int, error) {
|
||||
if len(value) < 4 {
|
||||
return 0, 0, fmt.Errorf("unpickle uint32: %w", goolm.ErrValueTooShort)
|
||||
}
|
||||
var res uint32
|
||||
count := 0
|
||||
for i := 3; i >= 0; i-- {
|
||||
res |= uint32(value[count]) << (8 * i)
|
||||
count++
|
||||
}
|
||||
return res, 4, nil
|
||||
}
|
||||
6
vendor/maunium.net/go/mautrix/crypto/goolm/main.go
generated
vendored
Normal file
6
vendor/maunium.net/go/mautrix/crypto/goolm/main.go
generated
vendored
Normal file
@@ -0,0 +1,6 @@
|
||||
// Package goolm is a pure Go implementation of libolm. Libolm is a cryptographic library used for end-to-end encryption in Matrix and written in C++.
|
||||
// With goolm there is no need to use cgo when building Matrix clients in go.
|
||||
/*
|
||||
This package contains the possible errors which can occur as well as some simple functions. All the 'action' happens in the subdirectories.
|
||||
*/
|
||||
package goolm
|
||||
234
vendor/maunium.net/go/mautrix/crypto/goolm/megolm/megolm.go
generated
vendored
Normal file
234
vendor/maunium.net/go/mautrix/crypto/goolm/megolm/megolm.go
generated
vendored
Normal file
@@ -0,0 +1,234 @@
|
||||
// megolm provides the ratchet used by the megolm protocol
|
||||
package megolm
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"fmt"
|
||||
|
||||
"maunium.net/go/mautrix/crypto/goolm"
|
||||
"maunium.net/go/mautrix/crypto/goolm/cipher"
|
||||
"maunium.net/go/mautrix/crypto/goolm/crypto"
|
||||
"maunium.net/go/mautrix/crypto/goolm/libolmpickle"
|
||||
"maunium.net/go/mautrix/crypto/goolm/message"
|
||||
"maunium.net/go/mautrix/crypto/goolm/utilities"
|
||||
)
|
||||
|
||||
const (
|
||||
megolmPickleVersion uint8 = 1
|
||||
)
|
||||
|
||||
const (
|
||||
protocolVersion = 3
|
||||
RatchetParts = 4 // number of ratchet parts
|
||||
RatchetPartLength = 256 / 8 // length of each ratchet part in bytes
|
||||
)
|
||||
|
||||
var RatchetCipher = cipher.NewAESSHA256([]byte("MEGOLM_KEYS"))
|
||||
|
||||
// hasKeySeed are the seed for the different ratchet parts
|
||||
var hashKeySeeds [RatchetParts][]byte = [RatchetParts][]byte{
|
||||
{0x00},
|
||||
{0x01},
|
||||
{0x02},
|
||||
{0x03},
|
||||
}
|
||||
|
||||
// Ratchet represents the megolm ratchet as described in
|
||||
//
|
||||
// https://gitlab.matrix.org/matrix-org/olm/-/blob/master/docs/megolm.md
|
||||
type Ratchet struct {
|
||||
Data [RatchetParts * RatchetPartLength]byte `json:"data"`
|
||||
Counter uint32 `json:"counter"`
|
||||
}
|
||||
|
||||
// New creates a new ratchet with counter set to counter and the ratchet data set to data.
|
||||
func New(counter uint32, data [RatchetParts * RatchetPartLength]byte) (*Ratchet, error) {
|
||||
m := &Ratchet{
|
||||
Counter: counter,
|
||||
Data: data,
|
||||
}
|
||||
return m, nil
|
||||
}
|
||||
|
||||
// NewWithRandom creates a new ratchet with counter set to counter an the data filled with random values.
|
||||
func NewWithRandom(counter uint32) (*Ratchet, error) {
|
||||
var data [RatchetParts * RatchetPartLength]byte
|
||||
_, err := rand.Read(data[:])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return New(counter, data)
|
||||
}
|
||||
|
||||
// rehashPart rehases the part of the ratchet data with the base defined as from storing into the target to.
|
||||
func (m *Ratchet) rehashPart(from, to int) {
|
||||
newData := crypto.HMACSHA256(m.Data[from*RatchetPartLength:from*RatchetPartLength+RatchetPartLength], hashKeySeeds[to])
|
||||
copy(m.Data[to*RatchetPartLength:], newData[:RatchetPartLength])
|
||||
}
|
||||
|
||||
// Advance advances the ratchet one step.
|
||||
func (m *Ratchet) Advance() {
|
||||
var mask uint32 = 0x00FFFFFF
|
||||
var h int
|
||||
m.Counter++
|
||||
|
||||
// figure out how much we need to rekey
|
||||
for h < RatchetParts {
|
||||
if (m.Counter & mask) == 0 {
|
||||
break
|
||||
}
|
||||
h++
|
||||
mask >>= 8
|
||||
}
|
||||
|
||||
// now update R(h)...R(3) based on R(h)
|
||||
for i := RatchetParts - 1; i >= h; i-- {
|
||||
m.rehashPart(h, i)
|
||||
}
|
||||
}
|
||||
|
||||
// AdvanceTo advances the ratchet so that the ratchet counter = target
|
||||
func (m *Ratchet) AdvanceTo(target uint32) {
|
||||
//starting with R0, see if we need to update each part of the hash
|
||||
for j := 0; j < RatchetParts; j++ {
|
||||
shift := uint32((RatchetParts - j - 1) * 8)
|
||||
mask := (^uint32(0)) << shift
|
||||
|
||||
// how many times do we need to rehash this part?
|
||||
// '& 0xff' ensures we handle integer wraparound correctly
|
||||
steps := ((target >> shift) - m.Counter>>shift) & uint32(0xff)
|
||||
|
||||
if steps == 0 {
|
||||
/*
|
||||
deal with the edge case where m.Counter is slightly larger
|
||||
than target. This should only happen for R(0), and implies
|
||||
that target has wrapped around and we need to advance R(0)
|
||||
256 times.
|
||||
*/
|
||||
if target < m.Counter {
|
||||
steps = 0x100
|
||||
} else {
|
||||
continue
|
||||
}
|
||||
}
|
||||
// for all but the last step, we can just bump R(j) without regard to R(j+1)...R(3).
|
||||
for steps > 1 {
|
||||
m.rehashPart(j, j)
|
||||
steps--
|
||||
}
|
||||
/*
|
||||
on the last step we also need to bump R(j+1)...R(3).
|
||||
|
||||
(Theoretically, we could skip bumping R(j+2) if we're going to bump
|
||||
R(j+1) again, but the code to figure that out is a bit baroque and
|
||||
doesn't save us much).
|
||||
*/
|
||||
for k := 3; k >= j; k-- {
|
||||
m.rehashPart(j, k)
|
||||
}
|
||||
m.Counter = target & mask
|
||||
}
|
||||
}
|
||||
|
||||
// Encrypt encrypts the message in a message.GroupMessage with MAC and signature.
|
||||
// The output is base64 encoded.
|
||||
func (r *Ratchet) Encrypt(plaintext []byte, key *crypto.Ed25519KeyPair) ([]byte, error) {
|
||||
var err error
|
||||
encryptedText, err := RatchetCipher.Encrypt(r.Data[:], plaintext)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("cipher encrypt: %w", err)
|
||||
}
|
||||
|
||||
message := &message.GroupMessage{}
|
||||
message.Version = protocolVersion
|
||||
message.MessageIndex = r.Counter
|
||||
message.Ciphertext = encryptedText
|
||||
//creating the mac and signing is done in encode
|
||||
output, err := message.EncodeAndMacAndSign(r.Data[:], RatchetCipher, key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
r.Advance()
|
||||
return output, nil
|
||||
}
|
||||
|
||||
// SessionSharingMessage creates a message in the session sharing format.
|
||||
func (r Ratchet) SessionSharingMessage(key crypto.Ed25519KeyPair) ([]byte, error) {
|
||||
m := message.MegolmSessionSharing{}
|
||||
m.Counter = r.Counter
|
||||
m.RatchetData = r.Data
|
||||
encoded := m.EncodeAndSign(key)
|
||||
return goolm.Base64Encode(encoded), nil
|
||||
}
|
||||
|
||||
// SessionExportMessage creates a message in the session export format.
|
||||
func (r Ratchet) SessionExportMessage(key crypto.Ed25519PublicKey) ([]byte, error) {
|
||||
m := message.MegolmSessionExport{}
|
||||
m.Counter = r.Counter
|
||||
m.RatchetData = r.Data
|
||||
m.PublicKey = key
|
||||
encoded := m.Encode()
|
||||
return goolm.Base64Encode(encoded), nil
|
||||
}
|
||||
|
||||
// Decrypt decrypts the ciphertext and verifies the MAC but not the signature.
|
||||
func (r Ratchet) Decrypt(ciphertext []byte, signingkey *crypto.Ed25519PublicKey, msg *message.GroupMessage) ([]byte, error) {
|
||||
//verify mac
|
||||
verifiedMAC, err := msg.VerifyMACInline(r.Data[:], RatchetCipher, ciphertext)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if !verifiedMAC {
|
||||
return nil, fmt.Errorf("decrypt: %w", goolm.ErrBadMAC)
|
||||
}
|
||||
|
||||
return RatchetCipher.Decrypt(r.Data[:], msg.Ciphertext)
|
||||
}
|
||||
|
||||
// PickleAsJSON returns a ratchet as a base64 string encrypted using the supplied key. The unencrypted representation of the Account is in JSON format.
|
||||
func (r Ratchet) PickleAsJSON(key []byte) ([]byte, error) {
|
||||
return utilities.PickleAsJSON(r, megolmPickleVersion, key)
|
||||
}
|
||||
|
||||
// UnpickleAsJSON updates a ratchet by a base64 encrypted string using the supplied key. The unencrypted representation has to be in JSON format.
|
||||
func (r *Ratchet) UnpickleAsJSON(pickled, key []byte) error {
|
||||
return utilities.UnpickleAsJSON(r, pickled, key, megolmPickleVersion)
|
||||
}
|
||||
|
||||
// UnpickleLibOlm decodes the unencryted value and populates the Ratchet accordingly. It returns the number of bytes read.
|
||||
func (r *Ratchet) UnpickleLibOlm(unpickled []byte) (int, error) {
|
||||
//read ratchet data
|
||||
curPos := 0
|
||||
ratchetData, readBytes, err := libolmpickle.UnpickleBytes(unpickled, RatchetParts*RatchetPartLength)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
copy(r.Data[:], ratchetData)
|
||||
curPos += readBytes
|
||||
//Read counter
|
||||
counter, readBytes, err := libolmpickle.UnpickleUInt32(unpickled[curPos:])
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
curPos += readBytes
|
||||
r.Counter = counter
|
||||
return curPos, nil
|
||||
}
|
||||
|
||||
// PickleLibOlm encodes the ratchet into target. target has to have a size of at least PickleLen() and is written to from index 0.
|
||||
// It returns the number of bytes written.
|
||||
func (r Ratchet) PickleLibOlm(target []byte) (int, error) {
|
||||
if len(target) < r.PickleLen() {
|
||||
return 0, fmt.Errorf("pickle account: %w", goolm.ErrValueTooShort)
|
||||
}
|
||||
written := libolmpickle.PickleBytes(r.Data[:], target)
|
||||
written += libolmpickle.PickleUInt32(r.Counter, target[written:])
|
||||
return written, nil
|
||||
}
|
||||
|
||||
// PickleLen returns the number of bytes the pickled ratchet will have.
|
||||
func (r Ratchet) PickleLen() int {
|
||||
length := libolmpickle.PickleBytesLen(r.Data[:])
|
||||
length += libolmpickle.PickleUInt32Len(r.Counter)
|
||||
return length
|
||||
}
|
||||
70
vendor/maunium.net/go/mautrix/crypto/goolm/message/decoder.go
generated
vendored
Normal file
70
vendor/maunium.net/go/mautrix/crypto/goolm/message/decoder.go
generated
vendored
Normal file
@@ -0,0 +1,70 @@
|
||||
package message
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
|
||||
"maunium.net/go/mautrix/crypto/goolm"
|
||||
)
|
||||
|
||||
// checkDecodeErr checks if there was an error during decode.
|
||||
func checkDecodeErr(readBytes int) error {
|
||||
if readBytes == 0 {
|
||||
//end reached
|
||||
return goolm.ErrInputToSmall
|
||||
}
|
||||
if readBytes < 0 {
|
||||
return goolm.ErrOverflow
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
// decodeVarInt decodes a single big-endian encoded varint.
|
||||
func decodeVarInt(input []byte) (uint32, int) {
|
||||
value, readBytes := binary.Uvarint(input)
|
||||
return uint32(value), readBytes
|
||||
}
|
||||
|
||||
// decodeVarString decodes the length of the string (varint) and returns the actual string
|
||||
func decodeVarString(input []byte) ([]byte, int) {
|
||||
stringLen, readBytes := decodeVarInt(input)
|
||||
if readBytes <= 0 {
|
||||
return nil, readBytes
|
||||
}
|
||||
input = input[readBytes:]
|
||||
value := input[:stringLen]
|
||||
readBytes += int(stringLen)
|
||||
return value, readBytes
|
||||
}
|
||||
|
||||
// encodeVarIntByteLength returns the number of bytes needed to encode the uint32.
|
||||
func encodeVarIntByteLength(input uint32) int {
|
||||
result := 1
|
||||
for input >= 128 {
|
||||
result++
|
||||
input >>= 7
|
||||
}
|
||||
return result
|
||||
}
|
||||
|
||||
// encodeVarStringByteLength returns the number of bytes needed to encode the input.
|
||||
func encodeVarStringByteLength(input []byte) int {
|
||||
result := encodeVarIntByteLength(uint32(len(input)))
|
||||
result += len(input)
|
||||
return result
|
||||
}
|
||||
|
||||
// encodeVarInt encodes a single uint32
|
||||
func encodeVarInt(input uint32) []byte {
|
||||
out := make([]byte, encodeVarIntByteLength(input))
|
||||
binary.PutUvarint(out, uint64(input))
|
||||
return out
|
||||
}
|
||||
|
||||
// encodeVarString encodes the length of the input (varint) and appends the actual input
|
||||
func encodeVarString(input []byte) []byte {
|
||||
out := make([]byte, encodeVarStringByteLength(input))
|
||||
length := encodeVarInt(uint32(len(input)))
|
||||
copy(out, length)
|
||||
copy(out[len(length):], input)
|
||||
return out
|
||||
}
|
||||
144
vendor/maunium.net/go/mautrix/crypto/goolm/message/group_message.go
generated
vendored
Normal file
144
vendor/maunium.net/go/mautrix/crypto/goolm/message/group_message.go
generated
vendored
Normal file
@@ -0,0 +1,144 @@
|
||||
package message
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
|
||||
"maunium.net/go/mautrix/crypto/goolm/cipher"
|
||||
"maunium.net/go/mautrix/crypto/goolm/crypto"
|
||||
)
|
||||
|
||||
const (
|
||||
messageIndexTag = 0x08
|
||||
cipherTextTag = 0x12
|
||||
countMACBytesGroupMessage = 8
|
||||
)
|
||||
|
||||
// GroupMessage represents a message in the group message format.
|
||||
type GroupMessage struct {
|
||||
Version byte `json:"version"`
|
||||
MessageIndex uint32 `json:"index"`
|
||||
Ciphertext []byte `json:"ciphertext"`
|
||||
HasMessageIndex bool `json:"has_index"`
|
||||
}
|
||||
|
||||
// Decodes decodes the input and populates the corresponding fileds. MAC and signature are ignored but have to be present.
|
||||
func (r *GroupMessage) Decode(input []byte) error {
|
||||
r.Version = 0
|
||||
r.MessageIndex = 0
|
||||
r.Ciphertext = nil
|
||||
if len(input) == 0 {
|
||||
return nil
|
||||
}
|
||||
//first Byte is always version
|
||||
r.Version = input[0]
|
||||
curPos := 1
|
||||
for curPos < len(input)-countMACBytesGroupMessage-crypto.ED25519SignatureSize {
|
||||
//Read Key
|
||||
curKey, readBytes := decodeVarInt(input[curPos:])
|
||||
if err := checkDecodeErr(readBytes); err != nil {
|
||||
return err
|
||||
}
|
||||
curPos += readBytes
|
||||
if (curKey & 0b111) == 0 {
|
||||
//The value is of type varint
|
||||
value, readBytes := decodeVarInt(input[curPos:])
|
||||
if err := checkDecodeErr(readBytes); err != nil {
|
||||
return err
|
||||
}
|
||||
curPos += readBytes
|
||||
switch curKey {
|
||||
case messageIndexTag:
|
||||
r.MessageIndex = value
|
||||
r.HasMessageIndex = true
|
||||
}
|
||||
} else if (curKey & 0b111) == 2 {
|
||||
//The value is of type string
|
||||
value, readBytes := decodeVarString(input[curPos:])
|
||||
if err := checkDecodeErr(readBytes); err != nil {
|
||||
return err
|
||||
}
|
||||
curPos += readBytes
|
||||
switch curKey {
|
||||
case cipherTextTag:
|
||||
r.Ciphertext = value
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// EncodeAndMacAndSign encodes the message, creates the mac with the key and the cipher and signs the message.
|
||||
// If macKey or cipher is nil, no mac is appended. If signKey is nil, no signature is appended.
|
||||
func (r *GroupMessage) EncodeAndMacAndSign(macKey []byte, cipher cipher.Cipher, signKey *crypto.Ed25519KeyPair) ([]byte, error) {
|
||||
var lengthOfMessage int
|
||||
lengthOfMessage += 1 //Version
|
||||
lengthOfMessage += encodeVarIntByteLength(messageIndexTag) + encodeVarIntByteLength(r.MessageIndex)
|
||||
lengthOfMessage += encodeVarIntByteLength(cipherTextTag) + encodeVarStringByteLength(r.Ciphertext)
|
||||
out := make([]byte, lengthOfMessage)
|
||||
out[0] = r.Version
|
||||
curPos := 1
|
||||
encodedTag := encodeVarInt(messageIndexTag)
|
||||
copy(out[curPos:], encodedTag)
|
||||
curPos += len(encodedTag)
|
||||
encodedValue := encodeVarInt(r.MessageIndex)
|
||||
copy(out[curPos:], encodedValue)
|
||||
curPos += len(encodedValue)
|
||||
encodedTag = encodeVarInt(cipherTextTag)
|
||||
copy(out[curPos:], encodedTag)
|
||||
curPos += len(encodedTag)
|
||||
encodedValue = encodeVarString(r.Ciphertext)
|
||||
copy(out[curPos:], encodedValue)
|
||||
curPos += len(encodedValue)
|
||||
if len(macKey) != 0 && cipher != nil {
|
||||
mac, err := r.MAC(macKey, cipher, out)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
out = append(out, mac[:countMACBytesGroupMessage]...)
|
||||
}
|
||||
if signKey != nil {
|
||||
signature := signKey.Sign(out)
|
||||
out = append(out, signature...)
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// MAC returns the MAC of the message calculated with cipher and key. The length of the MAC is truncated to the correct length.
|
||||
func (r *GroupMessage) MAC(key []byte, cipher cipher.Cipher, message []byte) ([]byte, error) {
|
||||
mac, err := cipher.MAC(key, message)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return mac[:countMACBytesGroupMessage], nil
|
||||
}
|
||||
|
||||
// VerifySignature verifies the givenSignature to the calculated signature of the message.
|
||||
func (r *GroupMessage) VerifySignature(key crypto.Ed25519PublicKey, message, givenSignature []byte) bool {
|
||||
return key.Verify(message, givenSignature)
|
||||
}
|
||||
|
||||
// VerifySignature verifies the signature taken from the message to the calculated signature of the message.
|
||||
func (r *GroupMessage) VerifySignatureInline(key crypto.Ed25519PublicKey, message []byte) bool {
|
||||
signature := message[len(message)-crypto.ED25519SignatureSize:]
|
||||
message = message[:len(message)-crypto.ED25519SignatureSize]
|
||||
return key.Verify(message, signature)
|
||||
}
|
||||
|
||||
// VerifyMAC verifies the givenMAC to the calculated MAC of the message.
|
||||
func (r *GroupMessage) VerifyMAC(key []byte, cipher cipher.Cipher, message, givenMAC []byte) (bool, error) {
|
||||
checkMac, err := r.MAC(key, cipher, message)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return bytes.Equal(checkMac[:countMACBytesGroupMessage], givenMAC), nil
|
||||
}
|
||||
|
||||
// VerifyMACInline verifies the MAC taken from the message to the calculated MAC of the message.
|
||||
func (r *GroupMessage) VerifyMACInline(key []byte, cipher cipher.Cipher, message []byte) (bool, error) {
|
||||
startMAC := len(message) - countMACBytesGroupMessage - crypto.ED25519SignatureSize
|
||||
endMAC := startMAC + countMACBytesGroupMessage
|
||||
suplMac := message[startMAC:endMAC]
|
||||
message = message[:startMAC]
|
||||
return r.VerifyMAC(key, cipher, message, suplMac)
|
||||
}
|
||||
129
vendor/maunium.net/go/mautrix/crypto/goolm/message/message.go
generated
vendored
Normal file
129
vendor/maunium.net/go/mautrix/crypto/goolm/message/message.go
generated
vendored
Normal file
@@ -0,0 +1,129 @@
|
||||
package message
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
|
||||
"maunium.net/go/mautrix/crypto/goolm/cipher"
|
||||
"maunium.net/go/mautrix/crypto/goolm/crypto"
|
||||
)
|
||||
|
||||
const (
|
||||
ratchetKeyTag = 0x0A
|
||||
counterTag = 0x10
|
||||
cipherTextKeyTag = 0x22
|
||||
countMACBytesMessage = 8
|
||||
)
|
||||
|
||||
// GroupMessage represents a message in the message format.
|
||||
type Message struct {
|
||||
Version byte `json:"version"`
|
||||
HasCounter bool `json:"has_counter"`
|
||||
Counter uint32 `json:"counter"`
|
||||
RatchetKey crypto.Curve25519PublicKey `json:"ratchet_key"`
|
||||
Ciphertext []byte `json:"ciphertext"`
|
||||
}
|
||||
|
||||
// Decodes decodes the input and populates the corresponding fileds. MAC is ignored but has to be present.
|
||||
func (r *Message) Decode(input []byte) error {
|
||||
r.Version = 0
|
||||
r.HasCounter = false
|
||||
r.Counter = 0
|
||||
r.RatchetKey = nil
|
||||
r.Ciphertext = nil
|
||||
if len(input) == 0 {
|
||||
return nil
|
||||
}
|
||||
//first Byte is always version
|
||||
r.Version = input[0]
|
||||
curPos := 1
|
||||
for curPos < len(input)-countMACBytesMessage {
|
||||
//Read Key
|
||||
curKey, readBytes := decodeVarInt(input[curPos:])
|
||||
if err := checkDecodeErr(readBytes); err != nil {
|
||||
return err
|
||||
}
|
||||
curPos += readBytes
|
||||
if (curKey & 0b111) == 0 {
|
||||
//The value is of type varint
|
||||
value, readBytes := decodeVarInt(input[curPos:])
|
||||
if err := checkDecodeErr(readBytes); err != nil {
|
||||
return err
|
||||
}
|
||||
curPos += readBytes
|
||||
switch curKey {
|
||||
case counterTag:
|
||||
r.HasCounter = true
|
||||
r.Counter = value
|
||||
}
|
||||
} else if (curKey & 0b111) == 2 {
|
||||
//The value is of type string
|
||||
value, readBytes := decodeVarString(input[curPos:])
|
||||
if err := checkDecodeErr(readBytes); err != nil {
|
||||
return err
|
||||
}
|
||||
curPos += readBytes
|
||||
switch curKey {
|
||||
case ratchetKeyTag:
|
||||
r.RatchetKey = value
|
||||
case cipherTextKeyTag:
|
||||
r.Ciphertext = value
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// EncodeAndMAC encodes the message and creates the MAC with the key and the cipher.
|
||||
// If key or cipher is nil, no MAC is appended.
|
||||
func (r *Message) EncodeAndMAC(key []byte, cipher cipher.Cipher) ([]byte, error) {
|
||||
var lengthOfMessage int
|
||||
lengthOfMessage += 1 //Version
|
||||
lengthOfMessage += encodeVarIntByteLength(ratchetKeyTag) + encodeVarStringByteLength(r.RatchetKey)
|
||||
lengthOfMessage += encodeVarIntByteLength(counterTag) + encodeVarIntByteLength(r.Counter)
|
||||
lengthOfMessage += encodeVarIntByteLength(cipherTextKeyTag) + encodeVarStringByteLength(r.Ciphertext)
|
||||
out := make([]byte, lengthOfMessage)
|
||||
out[0] = r.Version
|
||||
curPos := 1
|
||||
encodedTag := encodeVarInt(ratchetKeyTag)
|
||||
copy(out[curPos:], encodedTag)
|
||||
curPos += len(encodedTag)
|
||||
encodedValue := encodeVarString(r.RatchetKey)
|
||||
copy(out[curPos:], encodedValue)
|
||||
curPos += len(encodedValue)
|
||||
encodedTag = encodeVarInt(counterTag)
|
||||
copy(out[curPos:], encodedTag)
|
||||
curPos += len(encodedTag)
|
||||
encodedValue = encodeVarInt(r.Counter)
|
||||
copy(out[curPos:], encodedValue)
|
||||
curPos += len(encodedValue)
|
||||
encodedTag = encodeVarInt(cipherTextKeyTag)
|
||||
copy(out[curPos:], encodedTag)
|
||||
curPos += len(encodedTag)
|
||||
encodedValue = encodeVarString(r.Ciphertext)
|
||||
copy(out[curPos:], encodedValue)
|
||||
curPos += len(encodedValue)
|
||||
if len(key) != 0 && cipher != nil {
|
||||
mac, err := cipher.MAC(key, out)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
out = append(out, mac[:countMACBytesMessage]...)
|
||||
}
|
||||
return out, nil
|
||||
}
|
||||
|
||||
// VerifyMAC verifies the givenMAC to the calculated MAC of the message.
|
||||
func (r *Message) VerifyMAC(key []byte, cipher cipher.Cipher, message, givenMAC []byte) (bool, error) {
|
||||
checkMAC, err := cipher.MAC(key, message)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return bytes.Equal(checkMAC[:countMACBytesMessage], givenMAC), nil
|
||||
}
|
||||
|
||||
// VerifyMACInline verifies the MAC taken from the message to the calculated MAC of the message.
|
||||
func (r *Message) VerifyMACInline(key []byte, cipher cipher.Cipher, message []byte) (bool, error) {
|
||||
givenMAC := message[len(message)-countMACBytesMessage:]
|
||||
return r.VerifyMAC(key, cipher, message[:len(message)-countMACBytesMessage], givenMAC)
|
||||
}
|
||||
120
vendor/maunium.net/go/mautrix/crypto/goolm/message/prekey_message.go
generated
vendored
Normal file
120
vendor/maunium.net/go/mautrix/crypto/goolm/message/prekey_message.go
generated
vendored
Normal file
@@ -0,0 +1,120 @@
|
||||
package message
|
||||
|
||||
import (
|
||||
"maunium.net/go/mautrix/crypto/goolm/crypto"
|
||||
)
|
||||
|
||||
const (
|
||||
oneTimeKeyIdTag = 0x0A
|
||||
baseKeyTag = 0x12
|
||||
identityKeyTag = 0x1A
|
||||
messageTag = 0x22
|
||||
)
|
||||
|
||||
type PreKeyMessage struct {
|
||||
Version byte `json:"version"`
|
||||
IdentityKey crypto.Curve25519PublicKey `json:"id_key"`
|
||||
BaseKey crypto.Curve25519PublicKey `json:"base_key"`
|
||||
OneTimeKey crypto.Curve25519PublicKey `json:"one_time_key"`
|
||||
Message []byte `json:"message"`
|
||||
}
|
||||
|
||||
// Decodes decodes the input and populates the corresponding fileds.
|
||||
func (r *PreKeyMessage) Decode(input []byte) error {
|
||||
r.Version = 0
|
||||
r.IdentityKey = nil
|
||||
r.BaseKey = nil
|
||||
r.OneTimeKey = nil
|
||||
r.Message = nil
|
||||
if len(input) == 0 {
|
||||
return nil
|
||||
}
|
||||
//first Byte is always version
|
||||
r.Version = input[0]
|
||||
curPos := 1
|
||||
for curPos < len(input) {
|
||||
//Read Key
|
||||
curKey, readBytes := decodeVarInt(input[curPos:])
|
||||
if err := checkDecodeErr(readBytes); err != nil {
|
||||
return err
|
||||
}
|
||||
curPos += readBytes
|
||||
if (curKey & 0b111) == 0 {
|
||||
//The value is of type varint
|
||||
_, readBytes := decodeVarInt(input[curPos:])
|
||||
if err := checkDecodeErr(readBytes); err != nil {
|
||||
return err
|
||||
}
|
||||
curPos += readBytes
|
||||
} else if (curKey & 0b111) == 2 {
|
||||
//The value is of type string
|
||||
value, readBytes := decodeVarString(input[curPos:])
|
||||
if err := checkDecodeErr(readBytes); err != nil {
|
||||
return err
|
||||
}
|
||||
curPos += readBytes
|
||||
switch curKey {
|
||||
case oneTimeKeyIdTag:
|
||||
r.OneTimeKey = value
|
||||
case baseKeyTag:
|
||||
r.BaseKey = value
|
||||
case identityKeyTag:
|
||||
r.IdentityKey = value
|
||||
case messageTag:
|
||||
r.Message = value
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
}
|
||||
|
||||
// CheckField verifies the fields. If theirIdentityKey is nil, it is not compared to the key in the message.
|
||||
func (r *PreKeyMessage) CheckFields(theirIdentityKey *crypto.Curve25519PublicKey) bool {
|
||||
ok := true
|
||||
ok = ok && (theirIdentityKey != nil || r.IdentityKey != nil)
|
||||
if r.IdentityKey != nil {
|
||||
ok = ok && (len(r.IdentityKey) == crypto.Curve25519KeyLength)
|
||||
}
|
||||
ok = ok && len(r.Message) != 0
|
||||
ok = ok && len(r.BaseKey) == crypto.Curve25519KeyLength
|
||||
ok = ok && len(r.OneTimeKey) == crypto.Curve25519KeyLength
|
||||
return ok
|
||||
}
|
||||
|
||||
// Encode encodes the message.
|
||||
func (r *PreKeyMessage) Encode() ([]byte, error) {
|
||||
var lengthOfMessage int
|
||||
lengthOfMessage += 1 //Version
|
||||
lengthOfMessage += encodeVarIntByteLength(oneTimeKeyIdTag) + encodeVarStringByteLength(r.OneTimeKey)
|
||||
lengthOfMessage += encodeVarIntByteLength(identityKeyTag) + encodeVarStringByteLength(r.IdentityKey)
|
||||
lengthOfMessage += encodeVarIntByteLength(baseKeyTag) + encodeVarStringByteLength(r.BaseKey)
|
||||
lengthOfMessage += encodeVarIntByteLength(messageTag) + encodeVarStringByteLength(r.Message)
|
||||
out := make([]byte, lengthOfMessage)
|
||||
out[0] = r.Version
|
||||
curPos := 1
|
||||
encodedTag := encodeVarInt(oneTimeKeyIdTag)
|
||||
copy(out[curPos:], encodedTag)
|
||||
curPos += len(encodedTag)
|
||||
encodedValue := encodeVarString(r.OneTimeKey)
|
||||
copy(out[curPos:], encodedValue)
|
||||
curPos += len(encodedValue)
|
||||
encodedTag = encodeVarInt(identityKeyTag)
|
||||
copy(out[curPos:], encodedTag)
|
||||
curPos += len(encodedTag)
|
||||
encodedValue = encodeVarString(r.IdentityKey)
|
||||
copy(out[curPos:], encodedValue)
|
||||
curPos += len(encodedValue)
|
||||
encodedTag = encodeVarInt(baseKeyTag)
|
||||
copy(out[curPos:], encodedTag)
|
||||
curPos += len(encodedTag)
|
||||
encodedValue = encodeVarString(r.BaseKey)
|
||||
copy(out[curPos:], encodedValue)
|
||||
curPos += len(encodedValue)
|
||||
encodedTag = encodeVarInt(messageTag)
|
||||
copy(out[curPos:], encodedTag)
|
||||
curPos += len(encodedTag)
|
||||
encodedValue = encodeVarString(r.Message)
|
||||
copy(out[curPos:], encodedValue)
|
||||
return out, nil
|
||||
}
|
||||
44
vendor/maunium.net/go/mautrix/crypto/goolm/message/session_export.go
generated
vendored
Normal file
44
vendor/maunium.net/go/mautrix/crypto/goolm/message/session_export.go
generated
vendored
Normal file
@@ -0,0 +1,44 @@
|
||||
package message
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
|
||||
"maunium.net/go/mautrix/crypto/goolm"
|
||||
"maunium.net/go/mautrix/crypto/goolm/crypto"
|
||||
)
|
||||
|
||||
const (
|
||||
sessionExportVersion = 0x01
|
||||
)
|
||||
|
||||
// MegolmSessionExport represents a message in the session export format.
|
||||
type MegolmSessionExport struct {
|
||||
Counter uint32 `json:"counter"`
|
||||
RatchetData [128]byte `json:"data"`
|
||||
PublicKey crypto.Ed25519PublicKey `json:"public_key"`
|
||||
}
|
||||
|
||||
// Encode returns the encoded message in the correct format.
|
||||
func (s MegolmSessionExport) Encode() []byte {
|
||||
output := make([]byte, 165)
|
||||
output[0] = sessionExportVersion
|
||||
binary.BigEndian.PutUint32(output[1:], s.Counter)
|
||||
copy(output[5:], s.RatchetData[:])
|
||||
copy(output[133:], s.PublicKey)
|
||||
return output
|
||||
}
|
||||
|
||||
// Decode populates the struct with the data encoded in input.
|
||||
func (s *MegolmSessionExport) Decode(input []byte) error {
|
||||
if len(input) != 165 {
|
||||
return fmt.Errorf("decrypt: %w", goolm.ErrBadInput)
|
||||
}
|
||||
if input[0] != sessionExportVersion {
|
||||
return fmt.Errorf("decrypt: %w", goolm.ErrBadVersion)
|
||||
}
|
||||
s.Counter = binary.BigEndian.Uint32(input[1:5])
|
||||
copy(s.RatchetData[:], input[5:133])
|
||||
s.PublicKey = input[133:]
|
||||
return nil
|
||||
}
|
||||
50
vendor/maunium.net/go/mautrix/crypto/goolm/message/session_sharing.go
generated
vendored
Normal file
50
vendor/maunium.net/go/mautrix/crypto/goolm/message/session_sharing.go
generated
vendored
Normal file
@@ -0,0 +1,50 @@
|
||||
package message
|
||||
|
||||
import (
|
||||
"encoding/binary"
|
||||
"fmt"
|
||||
|
||||
"maunium.net/go/mautrix/crypto/goolm"
|
||||
"maunium.net/go/mautrix/crypto/goolm/crypto"
|
||||
)
|
||||
|
||||
const (
|
||||
sessionSharingVersion = 0x02
|
||||
)
|
||||
|
||||
// MegolmSessionSharing represents a message in the session sharing format.
|
||||
type MegolmSessionSharing struct {
|
||||
Counter uint32 `json:"counter"`
|
||||
RatchetData [128]byte `json:"data"`
|
||||
PublicKey crypto.Ed25519PublicKey `json:"-"` //only used when decrypting messages
|
||||
}
|
||||
|
||||
// Encode returns the encoded message in the correct format with the signature by key appended.
|
||||
func (s MegolmSessionSharing) EncodeAndSign(key crypto.Ed25519KeyPair) []byte {
|
||||
output := make([]byte, 229)
|
||||
output[0] = sessionSharingVersion
|
||||
binary.BigEndian.PutUint32(output[1:], s.Counter)
|
||||
copy(output[5:], s.RatchetData[:])
|
||||
copy(output[133:], key.PublicKey)
|
||||
signature := key.Sign(output[:165])
|
||||
copy(output[165:], signature)
|
||||
return output
|
||||
}
|
||||
|
||||
// VerifyAndDecode verifies the input and populates the struct with the data encoded in input.
|
||||
func (s *MegolmSessionSharing) VerifyAndDecode(input []byte) error {
|
||||
if len(input) != 229 {
|
||||
return fmt.Errorf("verify: %w", goolm.ErrBadInput)
|
||||
}
|
||||
publicKey := crypto.Ed25519PublicKey(input[133:165])
|
||||
if !publicKey.Verify(input[:165], input[165:]) {
|
||||
return fmt.Errorf("verify: %w", goolm.ErrBadVerification)
|
||||
}
|
||||
s.PublicKey = publicKey
|
||||
if input[0] != sessionSharingVersion {
|
||||
return fmt.Errorf("verify: %w", goolm.ErrBadVersion)
|
||||
}
|
||||
s.Counter = binary.BigEndian.Uint32(input[1:5])
|
||||
copy(s.RatchetData[:], input[5:133])
|
||||
return nil
|
||||
}
|
||||
258
vendor/maunium.net/go/mautrix/crypto/goolm/olm/chain.go
generated
vendored
Normal file
258
vendor/maunium.net/go/mautrix/crypto/goolm/olm/chain.go
generated
vendored
Normal file
@@ -0,0 +1,258 @@
|
||||
package olm
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"maunium.net/go/mautrix/crypto/goolm"
|
||||
"maunium.net/go/mautrix/crypto/goolm/crypto"
|
||||
"maunium.net/go/mautrix/crypto/goolm/libolmpickle"
|
||||
)
|
||||
|
||||
const (
|
||||
chainKeySeed = 0x02
|
||||
messageKeyLength = 32
|
||||
)
|
||||
|
||||
// chainKey wraps the index and the public key
|
||||
type chainKey struct {
|
||||
Index uint32 `json:"index"`
|
||||
Key crypto.Curve25519PublicKey `json:"key"`
|
||||
}
|
||||
|
||||
// advance advances the chain
|
||||
func (c *chainKey) advance() {
|
||||
c.Key = crypto.HMACSHA256(c.Key, []byte{chainKeySeed})
|
||||
c.Index++
|
||||
}
|
||||
|
||||
// UnpickleLibOlm decodes the unencryted value and populates the chain key accordingly. It returns the number of bytes read.
|
||||
func (r *chainKey) UnpickleLibOlm(value []byte) (int, error) {
|
||||
curPos := 0
|
||||
readBytes, err := r.Key.UnpickleLibOlm(value)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
curPos += readBytes
|
||||
r.Index, readBytes, err = libolmpickle.UnpickleUInt32(value[curPos:])
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
curPos += readBytes
|
||||
return curPos, nil
|
||||
}
|
||||
|
||||
// PickleLibOlm encodes the chain key into target. target has to have a size of at least PickleLen() and is written to from index 0.
|
||||
// It returns the number of bytes written.
|
||||
func (r chainKey) PickleLibOlm(target []byte) (int, error) {
|
||||
if len(target) < r.PickleLen() {
|
||||
return 0, fmt.Errorf("pickle chain key: %w", goolm.ErrValueTooShort)
|
||||
}
|
||||
written, err := r.Key.PickleLibOlm(target)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("pickle chain key: %w", err)
|
||||
}
|
||||
written += libolmpickle.PickleUInt32(r.Index, target[written:])
|
||||
return written, nil
|
||||
}
|
||||
|
||||
// PickleLen returns the number of bytes the pickled chain key will have.
|
||||
func (r chainKey) PickleLen() int {
|
||||
length := r.Key.PickleLen()
|
||||
length += libolmpickle.PickleUInt32Len(r.Index)
|
||||
return length
|
||||
}
|
||||
|
||||
// senderChain is a chain for sending messages
|
||||
type senderChain struct {
|
||||
RKey crypto.Curve25519KeyPair `json:"ratchet_key"`
|
||||
CKey chainKey `json:"chain_key"`
|
||||
IsSet bool `json:"set"`
|
||||
}
|
||||
|
||||
// newSenderChain returns a sender chain initialized with chainKey and ratchet key pair.
|
||||
func newSenderChain(key crypto.Curve25519PublicKey, ratchet crypto.Curve25519KeyPair) *senderChain {
|
||||
return &senderChain{
|
||||
RKey: ratchet,
|
||||
CKey: chainKey{
|
||||
Index: 0,
|
||||
Key: key,
|
||||
},
|
||||
IsSet: true,
|
||||
}
|
||||
}
|
||||
|
||||
// advance advances the chain
|
||||
func (s *senderChain) advance() {
|
||||
s.CKey.advance()
|
||||
}
|
||||
|
||||
// ratchetKey returns the ratchet key pair.
|
||||
func (s senderChain) ratchetKey() crypto.Curve25519KeyPair {
|
||||
return s.RKey
|
||||
}
|
||||
|
||||
// chainKey returns the current chainKey.
|
||||
func (s senderChain) chainKey() chainKey {
|
||||
return s.CKey
|
||||
}
|
||||
|
||||
// UnpickleLibOlm decodes the unencryted value and populates the chain accordingly. It returns the number of bytes read.
|
||||
func (r *senderChain) UnpickleLibOlm(value []byte) (int, error) {
|
||||
curPos := 0
|
||||
readBytes, err := r.RKey.UnpickleLibOlm(value)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
curPos += readBytes
|
||||
readBytes, err = r.CKey.UnpickleLibOlm(value[curPos:])
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
curPos += readBytes
|
||||
return curPos, nil
|
||||
}
|
||||
|
||||
// PickleLibOlm encodes the chain into target. target has to have a size of at least PickleLen() and is written to from index 0.
|
||||
// It returns the number of bytes written.
|
||||
func (r senderChain) PickleLibOlm(target []byte) (int, error) {
|
||||
if len(target) < r.PickleLen() {
|
||||
return 0, fmt.Errorf("pickle sender chain: %w", goolm.ErrValueTooShort)
|
||||
}
|
||||
written, err := r.RKey.PickleLibOlm(target)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("pickle sender chain: %w", err)
|
||||
}
|
||||
writtenChain, err := r.CKey.PickleLibOlm(target[written:])
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("pickle sender chain: %w", err)
|
||||
}
|
||||
written += writtenChain
|
||||
return written, nil
|
||||
}
|
||||
|
||||
// PickleLen returns the number of bytes the pickled chain will have.
|
||||
func (r senderChain) PickleLen() int {
|
||||
length := r.RKey.PickleLen()
|
||||
length += r.CKey.PickleLen()
|
||||
return length
|
||||
}
|
||||
|
||||
// senderChain is a chain for receiving messages
|
||||
type receiverChain struct {
|
||||
RKey crypto.Curve25519PublicKey `json:"ratchet_key"`
|
||||
CKey chainKey `json:"chain_key"`
|
||||
}
|
||||
|
||||
// newReceiverChain returns a receiver chain initialized with chainKey and ratchet public key.
|
||||
func newReceiverChain(chain crypto.Curve25519PublicKey, ratchet crypto.Curve25519PublicKey) *receiverChain {
|
||||
return &receiverChain{
|
||||
RKey: ratchet,
|
||||
CKey: chainKey{
|
||||
Index: 0,
|
||||
Key: chain,
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// advance advances the chain
|
||||
func (s *receiverChain) advance() {
|
||||
s.CKey.advance()
|
||||
}
|
||||
|
||||
// ratchetKey returns the ratchet public key.
|
||||
func (s receiverChain) ratchetKey() crypto.Curve25519PublicKey {
|
||||
return s.RKey
|
||||
}
|
||||
|
||||
// chainKey returns the current chainKey.
|
||||
func (s receiverChain) chainKey() chainKey {
|
||||
return s.CKey
|
||||
}
|
||||
|
||||
// UnpickleLibOlm decodes the unencryted value and populates the chain accordingly. It returns the number of bytes read.
|
||||
func (r *receiverChain) UnpickleLibOlm(value []byte) (int, error) {
|
||||
curPos := 0
|
||||
readBytes, err := r.RKey.UnpickleLibOlm(value)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
curPos += readBytes
|
||||
readBytes, err = r.CKey.UnpickleLibOlm(value[curPos:])
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
curPos += readBytes
|
||||
return curPos, nil
|
||||
}
|
||||
|
||||
// PickleLibOlm encodes the chain into target. target has to have a size of at least PickleLen() and is written to from index 0.
|
||||
// It returns the number of bytes written.
|
||||
func (r receiverChain) PickleLibOlm(target []byte) (int, error) {
|
||||
if len(target) < r.PickleLen() {
|
||||
return 0, fmt.Errorf("pickle sender chain: %w", goolm.ErrValueTooShort)
|
||||
}
|
||||
written, err := r.RKey.PickleLibOlm(target)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("pickle sender chain: %w", err)
|
||||
}
|
||||
writtenChain, err := r.CKey.PickleLibOlm(target)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("pickle sender chain: %w", err)
|
||||
}
|
||||
written += writtenChain
|
||||
return written, nil
|
||||
}
|
||||
|
||||
// PickleLen returns the number of bytes the pickled chain will have.
|
||||
func (r receiverChain) PickleLen() int {
|
||||
length := r.RKey.PickleLen()
|
||||
length += r.CKey.PickleLen()
|
||||
return length
|
||||
}
|
||||
|
||||
// messageKey wraps the index and the key of a message
|
||||
type messageKey struct {
|
||||
Index uint32 `json:"index"`
|
||||
Key []byte `json:"key"`
|
||||
}
|
||||
|
||||
// UnpickleLibOlm decodes the unencryted value and populates the message key accordingly. It returns the number of bytes read.
|
||||
func (m *messageKey) UnpickleLibOlm(value []byte) (int, error) {
|
||||
curPos := 0
|
||||
ratchetKey, readBytes, err := libolmpickle.UnpickleBytes(value, messageKeyLength)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
m.Key = ratchetKey
|
||||
curPos += readBytes
|
||||
keyID, readBytes, err := libolmpickle.UnpickleUInt32(value[:curPos])
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
curPos += readBytes
|
||||
m.Index = keyID
|
||||
return curPos, nil
|
||||
}
|
||||
|
||||
// PickleLibOlm encodes the message key into target. target has to have a size of at least PickleLen() and is written to from index 0.
|
||||
// It returns the number of bytes written.
|
||||
func (m messageKey) PickleLibOlm(target []byte) (int, error) {
|
||||
if len(target) < m.PickleLen() {
|
||||
return 0, fmt.Errorf("pickle message key: %w", goolm.ErrValueTooShort)
|
||||
}
|
||||
written := 0
|
||||
if len(m.Key) != messageKeyLength {
|
||||
written += libolmpickle.PickleBytes(make([]byte, messageKeyLength), target)
|
||||
} else {
|
||||
written += libolmpickle.PickleBytes(m.Key, target)
|
||||
}
|
||||
written += libolmpickle.PickleUInt32(m.Index, target[written:])
|
||||
return written, nil
|
||||
}
|
||||
|
||||
// PickleLen returns the number of bytes the pickled message key will have.
|
||||
func (r messageKey) PickleLen() int {
|
||||
length := libolmpickle.PickleBytesLen(make([]byte, messageKeyLength))
|
||||
length += libolmpickle.PickleUInt32Len(r.Index)
|
||||
return length
|
||||
}
|
||||
432
vendor/maunium.net/go/mautrix/crypto/goolm/olm/olm.go
generated
vendored
Normal file
432
vendor/maunium.net/go/mautrix/crypto/goolm/olm/olm.go
generated
vendored
Normal file
@@ -0,0 +1,432 @@
|
||||
// olm provides the ratchet used by the olm protocol
|
||||
package olm
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
|
||||
"maunium.net/go/mautrix/crypto/goolm"
|
||||
"maunium.net/go/mautrix/crypto/goolm/cipher"
|
||||
"maunium.net/go/mautrix/crypto/goolm/crypto"
|
||||
"maunium.net/go/mautrix/crypto/goolm/libolmpickle"
|
||||
"maunium.net/go/mautrix/crypto/goolm/message"
|
||||
"maunium.net/go/mautrix/crypto/goolm/utilities"
|
||||
)
|
||||
|
||||
const (
|
||||
olmPickleVersion uint8 = 1
|
||||
)
|
||||
|
||||
const (
|
||||
maxReceiverChains = 5
|
||||
maxSkippedMessageKeys = 40
|
||||
protocolVersion = 3
|
||||
messageKeySeed = 0x01
|
||||
|
||||
maxMessageGap = 2000
|
||||
sharedKeyLength = 32
|
||||
)
|
||||
|
||||
// KdfInfo has the infos used for the kdf
|
||||
var KdfInfo = struct {
|
||||
Root []byte
|
||||
Ratchet []byte
|
||||
}{
|
||||
Root: []byte("OLM_ROOT"),
|
||||
Ratchet: []byte("OLM_RATCHET"),
|
||||
}
|
||||
|
||||
var RatchetCipher = cipher.NewAESSHA256([]byte("OLM_KEYS"))
|
||||
|
||||
// Ratchet represents the olm ratchet as described in
|
||||
//
|
||||
// https://gitlab.matrix.org/matrix-org/olm/-/blob/master/docs/olm.md
|
||||
type Ratchet struct {
|
||||
// The root key is used to generate chain keys from the ephemeral keys.
|
||||
// A new root_key is derived each time a new chain is started.
|
||||
RootKey crypto.Curve25519PublicKey `json:"root_key"`
|
||||
|
||||
// The sender chain is used to send messages. Each time a new ephemeral
|
||||
// key is received from the remote server we generate a new sender chain
|
||||
// with a new ephemeral key when we next send a message.
|
||||
SenderChains senderChain `json:"sender_chain"`
|
||||
|
||||
// The receiver chain is used to decrypt received messages. We store the
|
||||
// last few chains so we can decrypt any out of order messages we haven't
|
||||
// received yet.
|
||||
// New chains are prepended for easier access.
|
||||
ReceiverChains []receiverChain `json:"receiver_chains"`
|
||||
|
||||
// Storing the keys of missed messages for future use.
|
||||
// The order of the elements is not important.
|
||||
SkippedMessageKeys []skippedMessageKey `json:"skipped_message_keys"`
|
||||
}
|
||||
|
||||
// New creates a new ratchet, setting the kdfInfos and cipher.
|
||||
func New() *Ratchet {
|
||||
r := &Ratchet{}
|
||||
return r
|
||||
}
|
||||
|
||||
// InitializeAsBob initializes this ratchet from a receiving point of view (only first message).
|
||||
func (r *Ratchet) InitializeAsBob(sharedSecret []byte, theirRatchetKey crypto.Curve25519PublicKey) error {
|
||||
derivedSecretsReader := crypto.HKDFSHA256(sharedSecret, nil, KdfInfo.Root)
|
||||
derivedSecrets := make([]byte, 2*sharedKeyLength)
|
||||
if _, err := io.ReadFull(derivedSecretsReader, derivedSecrets); err != nil {
|
||||
return err
|
||||
}
|
||||
r.RootKey = derivedSecrets[0:sharedKeyLength]
|
||||
newReceiverChain := newReceiverChain(derivedSecrets[sharedKeyLength:], theirRatchetKey)
|
||||
r.ReceiverChains = append([]receiverChain{*newReceiverChain}, r.ReceiverChains...)
|
||||
return nil
|
||||
}
|
||||
|
||||
// InitializeAsAlice initializes this ratchet from a sending point of view (only first message).
|
||||
func (r *Ratchet) InitializeAsAlice(sharedSecret []byte, ourRatchetKey crypto.Curve25519KeyPair) error {
|
||||
derivedSecretsReader := crypto.HKDFSHA256(sharedSecret, nil, KdfInfo.Root)
|
||||
derivedSecrets := make([]byte, 2*sharedKeyLength)
|
||||
if _, err := io.ReadFull(derivedSecretsReader, derivedSecrets); err != nil {
|
||||
return err
|
||||
}
|
||||
r.RootKey = derivedSecrets[0:sharedKeyLength]
|
||||
newSenderChain := newSenderChain(derivedSecrets[sharedKeyLength:], ourRatchetKey)
|
||||
r.SenderChains = *newSenderChain
|
||||
return nil
|
||||
}
|
||||
|
||||
// Encrypt encrypts the message in a message.Message with MAC. If reader is nil, crypto/rand is used for key generations.
|
||||
func (r *Ratchet) Encrypt(plaintext []byte, reader io.Reader) ([]byte, error) {
|
||||
var err error
|
||||
if !r.SenderChains.IsSet {
|
||||
newRatchetKey, err := crypto.Curve25519GenerateKey(reader)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
newChainKey, err := r.advanceRootKey(newRatchetKey, r.ReceiverChains[0].ratchetKey())
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
newSenderChain := newSenderChain(newChainKey, newRatchetKey)
|
||||
r.SenderChains = *newSenderChain
|
||||
}
|
||||
|
||||
messageKey := r.createMessageKeys(r.SenderChains.chainKey())
|
||||
r.SenderChains.advance()
|
||||
|
||||
encryptedText, err := RatchetCipher.Encrypt(messageKey.Key, plaintext)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("cipher encrypt: %w", err)
|
||||
}
|
||||
|
||||
message := &message.Message{}
|
||||
message.Version = protocolVersion
|
||||
message.Counter = messageKey.Index
|
||||
message.RatchetKey = r.SenderChains.ratchetKey().PublicKey
|
||||
message.Ciphertext = encryptedText
|
||||
//creating the mac is done in encode
|
||||
output, err := message.EncodeAndMAC(messageKey.Key, RatchetCipher)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
return output, nil
|
||||
}
|
||||
|
||||
// Decrypt decrypts the ciphertext and verifies the MAC. If reader is nil, crypto/rand is used for key generations.
|
||||
func (r *Ratchet) Decrypt(input []byte) ([]byte, error) {
|
||||
message := &message.Message{}
|
||||
//The mac is not verified here, as we do not know the key yet
|
||||
err := message.Decode(input)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if message.Version != protocolVersion {
|
||||
return nil, fmt.Errorf("decrypt: %w", goolm.ErrWrongProtocolVersion)
|
||||
}
|
||||
if !message.HasCounter || len(message.RatchetKey) == 0 || len(message.Ciphertext) == 0 {
|
||||
return nil, fmt.Errorf("decrypt: %w", goolm.ErrBadMessageFormat)
|
||||
}
|
||||
var receiverChainFromMessage *receiverChain
|
||||
for curChainIndex := range r.ReceiverChains {
|
||||
if r.ReceiverChains[curChainIndex].ratchetKey().Equal(message.RatchetKey) {
|
||||
receiverChainFromMessage = &r.ReceiverChains[curChainIndex]
|
||||
break
|
||||
}
|
||||
}
|
||||
var result []byte
|
||||
if receiverChainFromMessage == nil {
|
||||
//Advancing the chain is done in this method
|
||||
result, err = r.decryptForNewChain(message, input)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
} else if receiverChainFromMessage.chainKey().Index > message.Counter {
|
||||
// No need to advance the chain
|
||||
// Chain already advanced beyond the key for this message
|
||||
// Check if the message keys are in the skipped key list.
|
||||
foundSkippedKey := false
|
||||
for curSkippedIndex := range r.SkippedMessageKeys {
|
||||
if message.Counter == r.SkippedMessageKeys[curSkippedIndex].MKey.Index {
|
||||
// Found the key for this message. Check the MAC.
|
||||
verified, err := message.VerifyMACInline(r.SkippedMessageKeys[curSkippedIndex].MKey.Key, RatchetCipher, input)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if !verified {
|
||||
return nil, fmt.Errorf("decrypt from skipped message keys: %w", goolm.ErrBadMAC)
|
||||
}
|
||||
result, err = RatchetCipher.Decrypt(r.SkippedMessageKeys[curSkippedIndex].MKey.Key, message.Ciphertext)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("cipher decrypt: %w", err)
|
||||
}
|
||||
if len(result) != 0 {
|
||||
// Remove the key from the skipped keys now that we've
|
||||
// decoded the message it corresponds to.
|
||||
r.SkippedMessageKeys[curSkippedIndex] = r.SkippedMessageKeys[len(r.SkippedMessageKeys)-1]
|
||||
r.SkippedMessageKeys = r.SkippedMessageKeys[:len(r.SkippedMessageKeys)-1]
|
||||
}
|
||||
foundSkippedKey = true
|
||||
}
|
||||
}
|
||||
if !foundSkippedKey {
|
||||
return nil, fmt.Errorf("decrypt: %w", goolm.ErrMessageKeyNotFound)
|
||||
}
|
||||
} else {
|
||||
//Advancing the chain is done in this method
|
||||
result, err = r.decryptForExistingChain(receiverChainFromMessage, message, input)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
|
||||
return result, nil
|
||||
}
|
||||
|
||||
// advanceRootKey created the next root key and returns the next chainKey
|
||||
func (r *Ratchet) advanceRootKey(newRatchetKey crypto.Curve25519KeyPair, oldRatchetKey crypto.Curve25519PublicKey) (crypto.Curve25519PublicKey, error) {
|
||||
sharedSecret, err := newRatchetKey.SharedSecret(oldRatchetKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
derivedSecretsReader := crypto.HKDFSHA256(sharedSecret, r.RootKey, KdfInfo.Ratchet)
|
||||
derivedSecrets := make([]byte, 2*sharedKeyLength)
|
||||
if _, err := io.ReadFull(derivedSecretsReader, derivedSecrets); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
r.RootKey = derivedSecrets[:sharedKeyLength]
|
||||
return derivedSecrets[sharedKeyLength:], nil
|
||||
}
|
||||
|
||||
// createMessageKeys returns the messageKey derived from the chainKey
|
||||
func (r Ratchet) createMessageKeys(chainKey chainKey) messageKey {
|
||||
res := messageKey{}
|
||||
res.Key = crypto.HMACSHA256(chainKey.Key, []byte{messageKeySeed})
|
||||
res.Index = chainKey.Index
|
||||
return res
|
||||
}
|
||||
|
||||
// decryptForExistingChain returns the decrypted message by using the chain. The MAC of the rawMessage is verified.
|
||||
func (r *Ratchet) decryptForExistingChain(chain *receiverChain, message *message.Message, rawMessage []byte) ([]byte, error) {
|
||||
if message.Counter < chain.CKey.Index {
|
||||
return nil, fmt.Errorf("decrypt: %w", goolm.ErrChainTooHigh)
|
||||
}
|
||||
// Limit the number of hashes we're prepared to compute
|
||||
if message.Counter-chain.CKey.Index > maxMessageGap {
|
||||
return nil, fmt.Errorf("decrypt from existing chain: %w", goolm.ErrMsgIndexTooHigh)
|
||||
}
|
||||
for chain.CKey.Index < message.Counter {
|
||||
messageKey := r.createMessageKeys(chain.chainKey())
|
||||
skippedKey := skippedMessageKey{
|
||||
MKey: messageKey,
|
||||
RKey: chain.ratchetKey(),
|
||||
}
|
||||
r.SkippedMessageKeys = append(r.SkippedMessageKeys, skippedKey)
|
||||
chain.advance()
|
||||
}
|
||||
messageKey := r.createMessageKeys(chain.chainKey())
|
||||
chain.advance()
|
||||
verified, err := message.VerifyMACInline(messageKey.Key, RatchetCipher, rawMessage)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if !verified {
|
||||
return nil, fmt.Errorf("decrypt from existing chain: %w", goolm.ErrBadMAC)
|
||||
}
|
||||
return RatchetCipher.Decrypt(messageKey.Key, message.Ciphertext)
|
||||
}
|
||||
|
||||
// decryptForNewChain returns the decrypted message by creating a new chain and advancing the root key.
|
||||
func (r *Ratchet) decryptForNewChain(message *message.Message, rawMessage []byte) ([]byte, error) {
|
||||
// They shouldn't move to a new chain until we've sent them a message
|
||||
// acknowledging the last one
|
||||
if !r.SenderChains.IsSet {
|
||||
return nil, fmt.Errorf("decrypt for new chain: %w", goolm.ErrProtocolViolation)
|
||||
}
|
||||
// Limit the number of hashes we're prepared to compute
|
||||
if message.Counter > maxMessageGap {
|
||||
return nil, fmt.Errorf("decrypt for new chain: %w", goolm.ErrMsgIndexTooHigh)
|
||||
}
|
||||
|
||||
newChainKey, err := r.advanceRootKey(r.SenderChains.ratchetKey(), message.RatchetKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
newChain := newReceiverChain(newChainKey, message.RatchetKey)
|
||||
r.ReceiverChains = append([]receiverChain{*newChain}, r.ReceiverChains...)
|
||||
/*
|
||||
They have started using a new ephemeral ratchet key.
|
||||
We needed to derive a new set of chain keys.
|
||||
We can discard our previous ephemeral ratchet key.
|
||||
We will generate a new key when we send the next message.
|
||||
*/
|
||||
r.SenderChains = senderChain{}
|
||||
|
||||
decrypted, err := r.decryptForExistingChain(&r.ReceiverChains[0], message, rawMessage)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return decrypted, nil
|
||||
}
|
||||
|
||||
// PickleAsJSON returns a ratchet as a base64 string encrypted using the supplied key. The unencrypted representation of the Account is in JSON format.
|
||||
func (r Ratchet) PickleAsJSON(key []byte) ([]byte, error) {
|
||||
return utilities.PickleAsJSON(r, olmPickleVersion, key)
|
||||
}
|
||||
|
||||
// UnpickleAsJSON updates a ratchet by a base64 encrypted string using the supplied key. The unencrypted representation has to be in JSON format.
|
||||
func (r *Ratchet) UnpickleAsJSON(pickled, key []byte) error {
|
||||
return utilities.UnpickleAsJSON(r, pickled, key, olmPickleVersion)
|
||||
}
|
||||
|
||||
// UnpickleLibOlm decodes the unencryted value and populates the Ratchet accordingly. It returns the number of bytes read.
|
||||
func (r *Ratchet) UnpickleLibOlm(value []byte, includesChainIndex bool) (int, error) {
|
||||
//read ratchet data
|
||||
curPos := 0
|
||||
readBytes, err := r.RootKey.UnpickleLibOlm(value)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
curPos += readBytes
|
||||
countSenderChains, readBytes, err := libolmpickle.UnpickleUInt32(value[curPos:]) //Length of sender chain
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
curPos += readBytes
|
||||
for i := uint32(0); i < countSenderChains; i++ {
|
||||
if i == 0 {
|
||||
//only first is stored
|
||||
readBytes, err := r.SenderChains.UnpickleLibOlm(value[curPos:])
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
curPos += readBytes
|
||||
r.SenderChains.IsSet = true
|
||||
} else {
|
||||
dummy := senderChain{}
|
||||
readBytes, err := dummy.UnpickleLibOlm(value[curPos:])
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
curPos += readBytes
|
||||
}
|
||||
}
|
||||
countReceivChains, readBytes, err := libolmpickle.UnpickleUInt32(value[curPos:]) //Length of recevier chain
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
curPos += readBytes
|
||||
r.ReceiverChains = make([]receiverChain, countReceivChains)
|
||||
for i := uint32(0); i < countReceivChains; i++ {
|
||||
readBytes, err := r.ReceiverChains[i].UnpickleLibOlm(value[curPos:])
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
curPos += readBytes
|
||||
}
|
||||
countSkippedMessageKeys, readBytes, err := libolmpickle.UnpickleUInt32(value[curPos:]) //Length of skippedMessageKeys
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
curPos += readBytes
|
||||
r.SkippedMessageKeys = make([]skippedMessageKey, countSkippedMessageKeys)
|
||||
for i := uint32(0); i < countSkippedMessageKeys; i++ {
|
||||
readBytes, err := r.SkippedMessageKeys[i].UnpickleLibOlm(value[curPos:])
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
curPos += readBytes
|
||||
}
|
||||
// pickle v 0x80000001 includes a chain index; pickle v1 does not.
|
||||
if includesChainIndex {
|
||||
_, readBytes, err := libolmpickle.UnpickleUInt32(value[curPos:])
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
curPos += readBytes
|
||||
}
|
||||
return curPos, nil
|
||||
}
|
||||
|
||||
// PickleLibOlm encodes the ratchet into target. target has to have a size of at least PickleLen() and is written to from index 0.
|
||||
// It returns the number of bytes written.
|
||||
func (r Ratchet) PickleLibOlm(target []byte) (int, error) {
|
||||
if len(target) < r.PickleLen() {
|
||||
return 0, fmt.Errorf("pickle ratchet: %w", goolm.ErrValueTooShort)
|
||||
}
|
||||
written, err := r.RootKey.PickleLibOlm(target)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("pickle ratchet: %w", err)
|
||||
}
|
||||
if r.SenderChains.IsSet {
|
||||
written += libolmpickle.PickleUInt32(1, target[written:]) //Length of sender chain, always 1
|
||||
writtenSender, err := r.SenderChains.PickleLibOlm(target[written:])
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("pickle ratchet: %w", err)
|
||||
}
|
||||
written += writtenSender
|
||||
} else {
|
||||
written += libolmpickle.PickleUInt32(0, target[written:]) //Length of sender chain
|
||||
}
|
||||
written += libolmpickle.PickleUInt32(uint32(len(r.ReceiverChains)), target[written:])
|
||||
for _, curChain := range r.ReceiverChains {
|
||||
writtenChain, err := curChain.PickleLibOlm(target[written:])
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("pickle ratchet: %w", err)
|
||||
}
|
||||
written += writtenChain
|
||||
}
|
||||
written += libolmpickle.PickleUInt32(uint32(len(r.SkippedMessageKeys)), target[written:])
|
||||
for _, curChain := range r.SkippedMessageKeys {
|
||||
writtenChain, err := curChain.PickleLibOlm(target[written:])
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("pickle ratchet: %w", err)
|
||||
}
|
||||
written += writtenChain
|
||||
}
|
||||
return written, nil
|
||||
}
|
||||
|
||||
// PickleLen returns the actual number of bytes the pickled ratchet will have.
|
||||
func (r Ratchet) PickleLen() int {
|
||||
length := r.RootKey.PickleLen()
|
||||
if r.SenderChains.IsSet {
|
||||
length += libolmpickle.PickleUInt32Len(1)
|
||||
length += r.SenderChains.PickleLen()
|
||||
} else {
|
||||
length += libolmpickle.PickleUInt32Len(0)
|
||||
}
|
||||
length += libolmpickle.PickleUInt32Len(uint32(len(r.ReceiverChains)))
|
||||
length += len(r.ReceiverChains) * receiverChain{}.PickleLen()
|
||||
length += libolmpickle.PickleUInt32Len(uint32(len(r.SkippedMessageKeys)))
|
||||
length += len(r.SkippedMessageKeys) * skippedMessageKey{}.PickleLen()
|
||||
return length
|
||||
}
|
||||
|
||||
// PickleLen returns the minimum number of bytes the pickled ratchet must have.
|
||||
func (r Ratchet) PickleLenMin() int {
|
||||
length := r.RootKey.PickleLen()
|
||||
length += libolmpickle.PickleUInt32Len(0)
|
||||
length += libolmpickle.PickleUInt32Len(0)
|
||||
length += libolmpickle.PickleUInt32Len(0)
|
||||
return length
|
||||
}
|
||||
55
vendor/maunium.net/go/mautrix/crypto/goolm/olm/skipped_message.go
generated
vendored
Normal file
55
vendor/maunium.net/go/mautrix/crypto/goolm/olm/skipped_message.go
generated
vendored
Normal file
@@ -0,0 +1,55 @@
|
||||
package olm
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"maunium.net/go/mautrix/crypto/goolm"
|
||||
"maunium.net/go/mautrix/crypto/goolm/crypto"
|
||||
)
|
||||
|
||||
// skippedMessageKey stores a skipped message key
|
||||
type skippedMessageKey struct {
|
||||
RKey crypto.Curve25519PublicKey `json:"ratchet_key"`
|
||||
MKey messageKey `json:"message_key"`
|
||||
}
|
||||
|
||||
// UnpickleLibOlm decodes the unencryted value and populates the chain accordingly. It returns the number of bytes read.
|
||||
func (r *skippedMessageKey) UnpickleLibOlm(value []byte) (int, error) {
|
||||
curPos := 0
|
||||
readBytes, err := r.RKey.UnpickleLibOlm(value)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
curPos += readBytes
|
||||
readBytes, err = r.MKey.UnpickleLibOlm(value[curPos:])
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
curPos += readBytes
|
||||
return curPos, nil
|
||||
}
|
||||
|
||||
// PickleLibOlm encodes the chain into target. target has to have a size of at least PickleLen() and is written to from index 0.
|
||||
// It returns the number of bytes written.
|
||||
func (r skippedMessageKey) PickleLibOlm(target []byte) (int, error) {
|
||||
if len(target) < r.PickleLen() {
|
||||
return 0, fmt.Errorf("pickle sender chain: %w", goolm.ErrValueTooShort)
|
||||
}
|
||||
written, err := r.RKey.PickleLibOlm(target)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("pickle sender chain: %w", err)
|
||||
}
|
||||
writtenChain, err := r.MKey.PickleLibOlm(target)
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("pickle sender chain: %w", err)
|
||||
}
|
||||
written += writtenChain
|
||||
return written, nil
|
||||
}
|
||||
|
||||
// PickleLen returns the number of bytes the pickled chain will have.
|
||||
func (r skippedMessageKey) PickleLen() int {
|
||||
length := r.RKey.PickleLen()
|
||||
length += r.MKey.PickleLen()
|
||||
return length
|
||||
}
|
||||
165
vendor/maunium.net/go/mautrix/crypto/goolm/pk/decryption.go
generated
vendored
Normal file
165
vendor/maunium.net/go/mautrix/crypto/goolm/pk/decryption.go
generated
vendored
Normal file
@@ -0,0 +1,165 @@
|
||||
package pk
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"maunium.net/go/mautrix/crypto/goolm"
|
||||
"maunium.net/go/mautrix/crypto/goolm/cipher"
|
||||
"maunium.net/go/mautrix/crypto/goolm/crypto"
|
||||
"maunium.net/go/mautrix/crypto/goolm/libolmpickle"
|
||||
"maunium.net/go/mautrix/crypto/goolm/utilities"
|
||||
"maunium.net/go/mautrix/id"
|
||||
)
|
||||
|
||||
const (
|
||||
decryptionPickleVersionJSON uint8 = 1
|
||||
decryptionPickleVersionLibOlm uint32 = 1
|
||||
)
|
||||
|
||||
// Decryption is used to decrypt pk messages
|
||||
type Decryption struct {
|
||||
KeyPair crypto.Curve25519KeyPair `json:"key_pair"`
|
||||
}
|
||||
|
||||
// NewDecryption returns a new Decryption with a new generated key pair.
|
||||
func NewDecryption() (*Decryption, error) {
|
||||
keyPair, err := crypto.Curve25519GenerateKey(nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &Decryption{
|
||||
KeyPair: keyPair,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// NewDescriptionFromPrivate resturns a new Decryption with the private key fixed.
|
||||
func NewDecryptionFromPrivate(privateKey crypto.Curve25519PrivateKey) (*Decryption, error) {
|
||||
s := &Decryption{}
|
||||
keyPair, err := crypto.Curve25519GenerateFromPrivate(privateKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
s.KeyPair = keyPair
|
||||
return s, nil
|
||||
}
|
||||
|
||||
// PubKey returns the public key base 64 encoded.
|
||||
func (s Decryption) PubKey() id.Curve25519 {
|
||||
return s.KeyPair.B64Encoded()
|
||||
}
|
||||
|
||||
// PrivateKey returns the private key.
|
||||
func (s Decryption) PrivateKey() crypto.Curve25519PrivateKey {
|
||||
return s.KeyPair.PrivateKey
|
||||
}
|
||||
|
||||
// Decrypt decrypts the ciphertext and verifies the MAC. The base64 encoded key is used to construct the shared secret.
|
||||
func (s Decryption) Decrypt(ciphertext, mac []byte, key id.Curve25519) ([]byte, error) {
|
||||
keyDecoded, err := base64.RawStdEncoding.DecodeString(string(key))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
sharedSecret, err := s.KeyPair.SharedSecret(keyDecoded)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
decodedMAC, err := goolm.Base64Decode(mac)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
cipher := cipher.NewAESSHA256(nil)
|
||||
verified, err := cipher.Verify(sharedSecret, ciphertext, decodedMAC)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if !verified {
|
||||
return nil, fmt.Errorf("decrypt: %w", goolm.ErrBadMAC)
|
||||
}
|
||||
plaintext, err := cipher.Decrypt(sharedSecret, ciphertext)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return plaintext, nil
|
||||
}
|
||||
|
||||
// PickleAsJSON returns an Decryption as a base64 string encrypted using the supplied key. The unencrypted representation of the Account is in JSON format.
|
||||
func (a Decryption) PickleAsJSON(key []byte) ([]byte, error) {
|
||||
return utilities.PickleAsJSON(a, decryptionPickleVersionJSON, key)
|
||||
}
|
||||
|
||||
// UnpickleAsJSON updates an Decryption by a base64 encrypted string using the supplied key. The unencrypted representation has to be in JSON format.
|
||||
func (a *Decryption) UnpickleAsJSON(pickled, key []byte) error {
|
||||
return utilities.UnpickleAsJSON(a, pickled, key, decryptionPickleVersionJSON)
|
||||
}
|
||||
|
||||
// Unpickle decodes the base64 encoded string and decrypts the result with the key.
|
||||
// The decrypted value is then passed to UnpickleLibOlm.
|
||||
func (a *Decryption) Unpickle(pickled, key []byte) error {
|
||||
decrypted, err := cipher.Unpickle(key, pickled)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = a.UnpickleLibOlm(decrypted)
|
||||
return err
|
||||
}
|
||||
|
||||
// UnpickleLibOlm decodes the unencryted value and populates the Decryption accordingly. It returns the number of bytes read.
|
||||
func (a *Decryption) UnpickleLibOlm(value []byte) (int, error) {
|
||||
//First 4 bytes are the accountPickleVersion
|
||||
pickledVersion, curPos, err := libolmpickle.UnpickleUInt32(value)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
switch pickledVersion {
|
||||
case decryptionPickleVersionLibOlm:
|
||||
default:
|
||||
return 0, fmt.Errorf("unpickle olmSession: %w", goolm.ErrBadVersion)
|
||||
}
|
||||
readBytes, err := a.KeyPair.UnpickleLibOlm(value[curPos:])
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
curPos += readBytes
|
||||
return curPos, nil
|
||||
}
|
||||
|
||||
// Pickle returns a base64 encoded and with key encrypted pickled Decryption using PickleLibOlm().
|
||||
func (a Decryption) Pickle(key []byte) ([]byte, error) {
|
||||
pickeledBytes := make([]byte, a.PickleLen())
|
||||
written, err := a.PickleLibOlm(pickeledBytes)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if written != len(pickeledBytes) {
|
||||
return nil, errors.New("number of written bytes not correct")
|
||||
}
|
||||
encrypted, err := cipher.Pickle(key, pickeledBytes)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return encrypted, nil
|
||||
}
|
||||
|
||||
// PickleLibOlm encodes the Decryption into target. target has to have a size of at least PickleLen() and is written to from index 0.
|
||||
// It returns the number of bytes written.
|
||||
func (a Decryption) PickleLibOlm(target []byte) (int, error) {
|
||||
if len(target) < a.PickleLen() {
|
||||
return 0, fmt.Errorf("pickle Decryption: %w", goolm.ErrValueTooShort)
|
||||
}
|
||||
written := libolmpickle.PickleUInt32(decryptionPickleVersionLibOlm, target)
|
||||
writtenKey, err := a.KeyPair.PickleLibOlm(target[written:])
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("pickle Decryption: %w", err)
|
||||
}
|
||||
written += writtenKey
|
||||
return written, nil
|
||||
}
|
||||
|
||||
// PickleLen returns the number of bytes the pickled Decryption will have.
|
||||
func (a Decryption) PickleLen() int {
|
||||
length := libolmpickle.PickleUInt32Len(decryptionPickleVersionLibOlm)
|
||||
length += a.KeyPair.PickleLen()
|
||||
return length
|
||||
}
|
||||
49
vendor/maunium.net/go/mautrix/crypto/goolm/pk/encryption.go
generated
vendored
Normal file
49
vendor/maunium.net/go/mautrix/crypto/goolm/pk/encryption.go
generated
vendored
Normal file
@@ -0,0 +1,49 @@
|
||||
package pk
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
|
||||
"maunium.net/go/mautrix/id"
|
||||
|
||||
"maunium.net/go/mautrix/crypto/goolm"
|
||||
"maunium.net/go/mautrix/crypto/goolm/cipher"
|
||||
"maunium.net/go/mautrix/crypto/goolm/crypto"
|
||||
)
|
||||
|
||||
// Encryption is used to encrypt pk messages
|
||||
type Encryption struct {
|
||||
RecipientKey crypto.Curve25519PublicKey `json:"recipient_key"`
|
||||
}
|
||||
|
||||
// NewEncryption returns a new Encryption with the base64 encoded public key of the recipient
|
||||
func NewEncryption(pubKey id.Curve25519) (*Encryption, error) {
|
||||
pubKeyDecoded, err := base64.RawStdEncoding.DecodeString(string(pubKey))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &Encryption{
|
||||
RecipientKey: pubKeyDecoded,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// Encrypt encrypts the plaintext with the privateKey and returns the ciphertext and base64 encoded MAC.
|
||||
func (e Encryption) Encrypt(plaintext []byte, privateKey crypto.Curve25519PrivateKey) (ciphertext, mac []byte, err error) {
|
||||
keyPair, err := crypto.Curve25519GenerateFromPrivate(privateKey)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
sharedSecret, err := keyPair.SharedSecret(e.RecipientKey)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
cipher := cipher.NewAESSHA256(nil)
|
||||
ciphertext, err = cipher.Encrypt(sharedSecret, plaintext)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
mac, err = cipher.MAC(sharedSecret, ciphertext)
|
||||
if err != nil {
|
||||
return nil, nil, err
|
||||
}
|
||||
return ciphertext, goolm.Base64Encode(mac), nil
|
||||
}
|
||||
44
vendor/maunium.net/go/mautrix/crypto/goolm/pk/signing.go
generated
vendored
Normal file
44
vendor/maunium.net/go/mautrix/crypto/goolm/pk/signing.go
generated
vendored
Normal file
@@ -0,0 +1,44 @@
|
||||
package pk
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
|
||||
"maunium.net/go/mautrix/crypto/goolm"
|
||||
"maunium.net/go/mautrix/crypto/goolm/crypto"
|
||||
"maunium.net/go/mautrix/id"
|
||||
)
|
||||
|
||||
// Signing is used for signing a pk
|
||||
type Signing struct {
|
||||
KeyPair crypto.Ed25519KeyPair `json:"key_pair"`
|
||||
Seed []byte `json:"seed"`
|
||||
}
|
||||
|
||||
// NewSigningFromSeed constructs a new Signing based on a seed.
|
||||
func NewSigningFromSeed(seed []byte) (*Signing, error) {
|
||||
s := &Signing{}
|
||||
s.Seed = seed
|
||||
s.KeyPair = crypto.Ed25519GenerateFromSeed(seed)
|
||||
return s, nil
|
||||
}
|
||||
|
||||
// NewSigning returns a Signing based on a random seed
|
||||
func NewSigning() (*Signing, error) {
|
||||
seed := make([]byte, 32)
|
||||
_, err := rand.Read(seed)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return NewSigningFromSeed(seed)
|
||||
}
|
||||
|
||||
// Sign returns the signature of the message base64 encoded.
|
||||
func (s Signing) Sign(message []byte) []byte {
|
||||
signature := s.KeyPair.Sign(message)
|
||||
return goolm.Base64Encode(signature)
|
||||
}
|
||||
|
||||
// PublicKey returns the public key of the key pair base 64 encoded.
|
||||
func (s Signing) PublicKey() id.Ed25519 {
|
||||
return s.KeyPair.B64Encoded()
|
||||
}
|
||||
76
vendor/maunium.net/go/mautrix/crypto/goolm/sas/main.go
generated
vendored
Normal file
76
vendor/maunium.net/go/mautrix/crypto/goolm/sas/main.go
generated
vendored
Normal file
@@ -0,0 +1,76 @@
|
||||
// sas provides the means to do SAS between keys
|
||||
package sas
|
||||
|
||||
import (
|
||||
"io"
|
||||
|
||||
"maunium.net/go/mautrix/crypto/goolm"
|
||||
"maunium.net/go/mautrix/crypto/goolm/crypto"
|
||||
)
|
||||
|
||||
// SAS contains the key pair and secret for SAS.
|
||||
type SAS struct {
|
||||
KeyPair crypto.Curve25519KeyPair
|
||||
Secret []byte
|
||||
}
|
||||
|
||||
// New creates a new SAS with a new key pair.
|
||||
func New() (*SAS, error) {
|
||||
kp, err := crypto.Curve25519GenerateKey(nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
s := &SAS{
|
||||
KeyPair: kp,
|
||||
}
|
||||
return s, nil
|
||||
}
|
||||
|
||||
// GetPubkey returns the public key of the key pair base64 encoded
|
||||
func (s SAS) GetPubkey() []byte {
|
||||
return goolm.Base64Encode(s.KeyPair.PublicKey)
|
||||
}
|
||||
|
||||
// SetTheirKey sets the key of the other party and computes the shared secret.
|
||||
func (s *SAS) SetTheirKey(key []byte) error {
|
||||
keyDecoded, err := goolm.Base64Decode(key)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
sharedSecret, err := s.KeyPair.SharedSecret(keyDecoded)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
s.Secret = sharedSecret
|
||||
return nil
|
||||
}
|
||||
|
||||
// GenerateBytes creates length bytes from the shared secret and info.
|
||||
func (s SAS) GenerateBytes(info []byte, length uint) ([]byte, error) {
|
||||
byteReader := crypto.HKDFSHA256(s.Secret, nil, info)
|
||||
output := make([]byte, length)
|
||||
if _, err := io.ReadFull(byteReader, output); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return output, nil
|
||||
}
|
||||
|
||||
// calculateMAC returns a base64 encoded MAC of input.
|
||||
func (s *SAS) calculateMAC(input, info []byte, length uint) ([]byte, error) {
|
||||
key, err := s.GenerateBytes(info, length)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
mac := crypto.HMACSHA256(key, input)
|
||||
return goolm.Base64Encode(mac), nil
|
||||
}
|
||||
|
||||
// CalculateMACFixes returns a base64 encoded, 32 byte long MAC of input.
|
||||
func (s SAS) CalculateMAC(input, info []byte) ([]byte, error) {
|
||||
return s.calculateMAC(input, info, 32)
|
||||
}
|
||||
|
||||
// CalculateMACLongKDF returns a base64 encoded, 256 byte long MAC of input.
|
||||
func (s SAS) CalculateMACLongKDF(input, info []byte) ([]byte, error) {
|
||||
return s.calculateMAC(input, info, 256)
|
||||
}
|
||||
2
vendor/maunium.net/go/mautrix/crypto/goolm/session/main.go
generated
vendored
Normal file
2
vendor/maunium.net/go/mautrix/crypto/goolm/session/main.go
generated
vendored
Normal file
@@ -0,0 +1,2 @@
|
||||
// session provides the different types of sessions for en/decrypting of messages
|
||||
package session
|
||||
276
vendor/maunium.net/go/mautrix/crypto/goolm/session/megolm_inbound_session.go
generated
vendored
Normal file
276
vendor/maunium.net/go/mautrix/crypto/goolm/session/megolm_inbound_session.go
generated
vendored
Normal file
@@ -0,0 +1,276 @@
|
||||
package session
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"maunium.net/go/mautrix/crypto/goolm"
|
||||
"maunium.net/go/mautrix/crypto/goolm/cipher"
|
||||
"maunium.net/go/mautrix/crypto/goolm/crypto"
|
||||
"maunium.net/go/mautrix/crypto/goolm/libolmpickle"
|
||||
"maunium.net/go/mautrix/crypto/goolm/megolm"
|
||||
"maunium.net/go/mautrix/crypto/goolm/message"
|
||||
"maunium.net/go/mautrix/crypto/goolm/utilities"
|
||||
"maunium.net/go/mautrix/id"
|
||||
)
|
||||
|
||||
const (
|
||||
megolmInboundSessionPickleVersionJSON byte = 1
|
||||
megolmInboundSessionPickleVersionLibOlm uint32 = 2
|
||||
)
|
||||
|
||||
// MegolmInboundSession stores information about the sessions of receive.
|
||||
type MegolmInboundSession struct {
|
||||
Ratchet megolm.Ratchet `json:"ratchet"`
|
||||
SigningKey crypto.Ed25519PublicKey `json:"signing_key"`
|
||||
InitialRatchet megolm.Ratchet `json:"initial_ratchet"`
|
||||
SigningKeyVerified bool `json:"signing_key_verified"` //not used for now
|
||||
}
|
||||
|
||||
// NewMegolmInboundSession creates a new MegolmInboundSession from a base64 encoded session sharing message.
|
||||
func NewMegolmInboundSession(input []byte) (*MegolmInboundSession, error) {
|
||||
var err error
|
||||
input, err = goolm.Base64Decode(input)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
msg := message.MegolmSessionSharing{}
|
||||
err = msg.VerifyAndDecode(input)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
o := &MegolmInboundSession{}
|
||||
o.SigningKey = msg.PublicKey
|
||||
o.SigningKeyVerified = true
|
||||
ratchet, err := megolm.New(msg.Counter, msg.RatchetData)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
o.Ratchet = *ratchet
|
||||
o.InitialRatchet = *ratchet
|
||||
return o, nil
|
||||
}
|
||||
|
||||
// NewMegolmInboundSessionFromExport creates a new MegolmInboundSession from a base64 encoded session export message.
|
||||
func NewMegolmInboundSessionFromExport(input []byte) (*MegolmInboundSession, error) {
|
||||
var err error
|
||||
input, err = goolm.Base64Decode(input)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
msg := message.MegolmSessionExport{}
|
||||
err = msg.Decode(input)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
o := &MegolmInboundSession{}
|
||||
o.SigningKey = msg.PublicKey
|
||||
ratchet, err := megolm.New(msg.Counter, msg.RatchetData)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
o.Ratchet = *ratchet
|
||||
o.InitialRatchet = *ratchet
|
||||
return o, nil
|
||||
}
|
||||
|
||||
// MegolmInboundSessionFromPickled loads the MegolmInboundSession details from a pickled base64 string. The input is decrypted with the supplied key.
|
||||
func MegolmInboundSessionFromPickled(pickled, key []byte) (*MegolmInboundSession, error) {
|
||||
if len(pickled) == 0 {
|
||||
return nil, fmt.Errorf("megolmInboundSessionFromPickled: %w", goolm.ErrEmptyInput)
|
||||
}
|
||||
a := &MegolmInboundSession{}
|
||||
err := a.Unpickle(pickled, key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return a, nil
|
||||
}
|
||||
|
||||
// getRatchet tries to find the correct ratchet for a messageIndex.
|
||||
func (o MegolmInboundSession) getRatchet(messageIndex uint32) (*megolm.Ratchet, error) {
|
||||
// pick a megolm instance to use. if we are at or beyond the latest ratchet value, use that
|
||||
if (messageIndex - o.Ratchet.Counter) < uint32(1<<31) {
|
||||
o.Ratchet.AdvanceTo(messageIndex)
|
||||
return &o.Ratchet, nil
|
||||
}
|
||||
if (messageIndex - o.InitialRatchet.Counter) >= uint32(1<<31) {
|
||||
// the counter is before our initial ratchet - we can't decode this
|
||||
return nil, fmt.Errorf("decrypt: %w", goolm.ErrRatchetNotAvailable)
|
||||
}
|
||||
// otherwise, start from the initial ratchet. Take a copy so that we don't overwrite the initial ratchet
|
||||
copiedRatchet := o.InitialRatchet
|
||||
copiedRatchet.AdvanceTo(messageIndex)
|
||||
return &copiedRatchet, nil
|
||||
|
||||
}
|
||||
|
||||
// Decrypt decrypts a base64 encoded group message.
|
||||
func (o *MegolmInboundSession) Decrypt(ciphertext []byte) ([]byte, uint32, error) {
|
||||
if o.SigningKey == nil {
|
||||
return nil, 0, fmt.Errorf("decrypt: %w", goolm.ErrBadMessageFormat)
|
||||
}
|
||||
decoded, err := goolm.Base64Decode(ciphertext)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
msg := &message.GroupMessage{}
|
||||
err = msg.Decode(decoded)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
if msg.Version != protocolVersion {
|
||||
return nil, 0, fmt.Errorf("decrypt: %w", goolm.ErrWrongProtocolVersion)
|
||||
}
|
||||
if msg.Ciphertext == nil || !msg.HasMessageIndex {
|
||||
return nil, 0, fmt.Errorf("decrypt: %w", goolm.ErrBadMessageFormat)
|
||||
}
|
||||
|
||||
// verify signature
|
||||
verifiedSignature := msg.VerifySignatureInline(o.SigningKey, decoded)
|
||||
if !verifiedSignature {
|
||||
return nil, 0, fmt.Errorf("decrypt: %w", goolm.ErrBadSignature)
|
||||
}
|
||||
|
||||
targetRatch, err := o.getRatchet(msg.MessageIndex)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
|
||||
decrypted, err := targetRatch.Decrypt(decoded, &o.SigningKey, msg)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
o.SigningKeyVerified = true
|
||||
return decrypted, msg.MessageIndex, nil
|
||||
|
||||
}
|
||||
|
||||
// SessionID returns the base64 endoded signing key
|
||||
func (o MegolmInboundSession) SessionID() id.SessionID {
|
||||
return id.SessionID(base64.RawStdEncoding.EncodeToString(o.SigningKey))
|
||||
}
|
||||
|
||||
// PickleAsJSON returns an MegolmInboundSession as a base64 string encrypted using the supplied key. The unencrypted representation of the Account is in JSON format.
|
||||
func (o MegolmInboundSession) PickleAsJSON(key []byte) ([]byte, error) {
|
||||
return utilities.PickleAsJSON(o, megolmInboundSessionPickleVersionJSON, key)
|
||||
}
|
||||
|
||||
// UnpickleAsJSON updates an MegolmInboundSession by a base64 encrypted string using the supplied key. The unencrypted representation has to be in JSON format.
|
||||
func (o *MegolmInboundSession) UnpickleAsJSON(pickled, key []byte) error {
|
||||
return utilities.UnpickleAsJSON(o, pickled, key, megolmInboundSessionPickleVersionJSON)
|
||||
}
|
||||
|
||||
// SessionExportMessage creates an base64 encoded export of the session.
|
||||
func (o MegolmInboundSession) SessionExportMessage(messageIndex uint32) ([]byte, error) {
|
||||
ratchet, err := o.getRatchet(messageIndex)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return ratchet.SessionExportMessage(o.SigningKey)
|
||||
}
|
||||
|
||||
// Unpickle decodes the base64 encoded string and decrypts the result with the key.
|
||||
// The decrypted value is then passed to UnpickleLibOlm.
|
||||
func (o *MegolmInboundSession) Unpickle(pickled, key []byte) error {
|
||||
decrypted, err := cipher.Unpickle(key, pickled)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = o.UnpickleLibOlm(decrypted)
|
||||
return err
|
||||
}
|
||||
|
||||
// UnpickleLibOlm decodes the unencryted value and populates the Session accordingly. It returns the number of bytes read.
|
||||
func (o *MegolmInboundSession) UnpickleLibOlm(value []byte) (int, error) {
|
||||
//First 4 bytes are the accountPickleVersion
|
||||
pickledVersion, curPos, err := libolmpickle.UnpickleUInt32(value)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
switch pickledVersion {
|
||||
case megolmInboundSessionPickleVersionLibOlm, 1:
|
||||
default:
|
||||
return 0, fmt.Errorf("unpickle MegolmInboundSession: %w", goolm.ErrBadVersion)
|
||||
}
|
||||
readBytes, err := o.InitialRatchet.UnpickleLibOlm(value[curPos:])
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
curPos += readBytes
|
||||
readBytes, err = o.Ratchet.UnpickleLibOlm(value[curPos:])
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
curPos += readBytes
|
||||
readBytes, err = o.SigningKey.UnpickleLibOlm(value[curPos:])
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
curPos += readBytes
|
||||
if pickledVersion == 1 {
|
||||
// pickle v1 had no signing_key_verified field (all keyshares were verified at import time)
|
||||
o.SigningKeyVerified = true
|
||||
} else {
|
||||
o.SigningKeyVerified, readBytes, err = libolmpickle.UnpickleBool(value[curPos:])
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
curPos += readBytes
|
||||
}
|
||||
return curPos, nil
|
||||
}
|
||||
|
||||
// Pickle returns a base64 encoded and with key encrypted pickled MegolmInboundSession using PickleLibOlm().
|
||||
func (o MegolmInboundSession) Pickle(key []byte) ([]byte, error) {
|
||||
pickeledBytes := make([]byte, o.PickleLen())
|
||||
written, err := o.PickleLibOlm(pickeledBytes)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if written != len(pickeledBytes) {
|
||||
return nil, errors.New("number of written bytes not correct")
|
||||
}
|
||||
encrypted, err := cipher.Pickle(key, pickeledBytes)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return encrypted, nil
|
||||
}
|
||||
|
||||
// PickleLibOlm encodes the session into target. target has to have a size of at least PickleLen() and is written to from index 0.
|
||||
// It returns the number of bytes written.
|
||||
func (o MegolmInboundSession) PickleLibOlm(target []byte) (int, error) {
|
||||
if len(target) < o.PickleLen() {
|
||||
return 0, fmt.Errorf("pickle MegolmInboundSession: %w", goolm.ErrValueTooShort)
|
||||
}
|
||||
written := libolmpickle.PickleUInt32(megolmInboundSessionPickleVersionLibOlm, target)
|
||||
writtenInitRatchet, err := o.InitialRatchet.PickleLibOlm(target[written:])
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("pickle MegolmInboundSession: %w", err)
|
||||
}
|
||||
written += writtenInitRatchet
|
||||
writtenRatchet, err := o.Ratchet.PickleLibOlm(target[written:])
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("pickle MegolmInboundSession: %w", err)
|
||||
}
|
||||
written += writtenRatchet
|
||||
writtenPubKey, err := o.SigningKey.PickleLibOlm(target[written:])
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("pickle MegolmInboundSession: %w", err)
|
||||
}
|
||||
written += writtenPubKey
|
||||
written += libolmpickle.PickleBool(o.SigningKeyVerified, target[written:])
|
||||
return written, nil
|
||||
}
|
||||
|
||||
// PickleLen returns the number of bytes the pickled session will have.
|
||||
func (o MegolmInboundSession) PickleLen() int {
|
||||
length := libolmpickle.PickleUInt32Len(megolmInboundSessionPickleVersionLibOlm)
|
||||
length += o.InitialRatchet.PickleLen()
|
||||
length += o.Ratchet.PickleLen()
|
||||
length += o.SigningKey.PickleLen()
|
||||
length += libolmpickle.PickleBoolLen(o.SigningKeyVerified)
|
||||
return length
|
||||
}
|
||||
171
vendor/maunium.net/go/mautrix/crypto/goolm/session/megolm_outbound_session.go
generated
vendored
Normal file
171
vendor/maunium.net/go/mautrix/crypto/goolm/session/megolm_outbound_session.go
generated
vendored
Normal file
@@ -0,0 +1,171 @@
|
||||
package session
|
||||
|
||||
import (
|
||||
"crypto/rand"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"maunium.net/go/mautrix/id"
|
||||
|
||||
"maunium.net/go/mautrix/crypto/goolm"
|
||||
"maunium.net/go/mautrix/crypto/goolm/cipher"
|
||||
"maunium.net/go/mautrix/crypto/goolm/crypto"
|
||||
"maunium.net/go/mautrix/crypto/goolm/libolmpickle"
|
||||
"maunium.net/go/mautrix/crypto/goolm/megolm"
|
||||
"maunium.net/go/mautrix/crypto/goolm/utilities"
|
||||
)
|
||||
|
||||
const (
|
||||
megolmOutboundSessionPickleVersion byte = 1
|
||||
megolmOutboundSessionPickleVersionLibOlm uint32 = 1
|
||||
)
|
||||
|
||||
// MegolmOutboundSession stores information about the sessions to send.
|
||||
type MegolmOutboundSession struct {
|
||||
Ratchet megolm.Ratchet `json:"ratchet"`
|
||||
SigningKey crypto.Ed25519KeyPair `json:"signing_key"`
|
||||
}
|
||||
|
||||
// NewMegolmOutboundSession creates a new MegolmOutboundSession.
|
||||
func NewMegolmOutboundSession() (*MegolmOutboundSession, error) {
|
||||
o := &MegolmOutboundSession{}
|
||||
var err error
|
||||
o.SigningKey, err = crypto.Ed25519GenerateKey(nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var randomData [megolm.RatchetParts * megolm.RatchetPartLength]byte
|
||||
_, err = rand.Read(randomData[:])
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ratchet, err := megolm.New(0, randomData)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
o.Ratchet = *ratchet
|
||||
return o, nil
|
||||
}
|
||||
|
||||
// MegolmOutboundSessionFromPickled loads the MegolmOutboundSession details from a pickled base64 string. The input is decrypted with the supplied key.
|
||||
func MegolmOutboundSessionFromPickled(pickled, key []byte) (*MegolmOutboundSession, error) {
|
||||
if len(pickled) == 0 {
|
||||
return nil, fmt.Errorf("megolmOutboundSessionFromPickled: %w", goolm.ErrEmptyInput)
|
||||
}
|
||||
a := &MegolmOutboundSession{}
|
||||
err := a.Unpickle(pickled, key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return a, nil
|
||||
}
|
||||
|
||||
// Encrypt encrypts the plaintext as a base64 encoded group message.
|
||||
func (o *MegolmOutboundSession) Encrypt(plaintext []byte) ([]byte, error) {
|
||||
encrypted, err := o.Ratchet.Encrypt(plaintext, &o.SigningKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return goolm.Base64Encode(encrypted), nil
|
||||
}
|
||||
|
||||
// SessionID returns the base64 endoded public signing key
|
||||
func (o MegolmOutboundSession) SessionID() id.SessionID {
|
||||
return id.SessionID(base64.RawStdEncoding.EncodeToString(o.SigningKey.PublicKey))
|
||||
}
|
||||
|
||||
// PickleAsJSON returns an Session as a base64 string encrypted using the supplied key. The unencrypted representation of the Account is in JSON format.
|
||||
func (o MegolmOutboundSession) PickleAsJSON(key []byte) ([]byte, error) {
|
||||
return utilities.PickleAsJSON(o, megolmOutboundSessionPickleVersion, key)
|
||||
}
|
||||
|
||||
// UnpickleAsJSON updates an Session by a base64 encrypted string with the key. The unencrypted representation has to be in JSON format.
|
||||
func (o *MegolmOutboundSession) UnpickleAsJSON(pickled, key []byte) error {
|
||||
return utilities.UnpickleAsJSON(o, pickled, key, megolmOutboundSessionPickleVersion)
|
||||
}
|
||||
|
||||
// Unpickle decodes the base64 encoded string and decrypts the result with the key.
|
||||
// The decrypted value is then passed to UnpickleLibOlm.
|
||||
func (o *MegolmOutboundSession) Unpickle(pickled, key []byte) error {
|
||||
decrypted, err := cipher.Unpickle(key, pickled)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = o.UnpickleLibOlm(decrypted)
|
||||
return err
|
||||
}
|
||||
|
||||
// UnpickleLibOlm decodes the unencryted value and populates the Session accordingly. It returns the number of bytes read.
|
||||
func (o *MegolmOutboundSession) UnpickleLibOlm(value []byte) (int, error) {
|
||||
//First 4 bytes are the accountPickleVersion
|
||||
pickledVersion, curPos, err := libolmpickle.UnpickleUInt32(value)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
switch pickledVersion {
|
||||
case megolmOutboundSessionPickleVersionLibOlm:
|
||||
default:
|
||||
return 0, fmt.Errorf("unpickle MegolmInboundSession: %w", goolm.ErrBadVersion)
|
||||
}
|
||||
readBytes, err := o.Ratchet.UnpickleLibOlm(value[curPos:])
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
curPos += readBytes
|
||||
readBytes, err = o.SigningKey.UnpickleLibOlm(value[curPos:])
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
curPos += readBytes
|
||||
return curPos, nil
|
||||
}
|
||||
|
||||
// Pickle returns a base64 encoded and with key encrypted pickled MegolmOutboundSession using PickleLibOlm().
|
||||
func (o MegolmOutboundSession) Pickle(key []byte) ([]byte, error) {
|
||||
pickeledBytes := make([]byte, o.PickleLen())
|
||||
written, err := o.PickleLibOlm(pickeledBytes)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if written != len(pickeledBytes) {
|
||||
return nil, errors.New("number of written bytes not correct")
|
||||
}
|
||||
encrypted, err := cipher.Pickle(key, pickeledBytes)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return encrypted, nil
|
||||
}
|
||||
|
||||
// PickleLibOlm encodes the session into target. target has to have a size of at least PickleLen() and is written to from index 0.
|
||||
// It returns the number of bytes written.
|
||||
func (o MegolmOutboundSession) PickleLibOlm(target []byte) (int, error) {
|
||||
if len(target) < o.PickleLen() {
|
||||
return 0, fmt.Errorf("pickle MegolmOutboundSession: %w", goolm.ErrValueTooShort)
|
||||
}
|
||||
written := libolmpickle.PickleUInt32(megolmOutboundSessionPickleVersionLibOlm, target)
|
||||
writtenRatchet, err := o.Ratchet.PickleLibOlm(target[written:])
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("pickle MegolmOutboundSession: %w", err)
|
||||
}
|
||||
written += writtenRatchet
|
||||
writtenPubKey, err := o.SigningKey.PickleLibOlm(target[written:])
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("pickle MegolmOutboundSession: %w", err)
|
||||
}
|
||||
written += writtenPubKey
|
||||
return written, nil
|
||||
}
|
||||
|
||||
// PickleLen returns the number of bytes the pickled session will have.
|
||||
func (o MegolmOutboundSession) PickleLen() int {
|
||||
length := libolmpickle.PickleUInt32Len(megolmOutboundSessionPickleVersionLibOlm)
|
||||
length += o.Ratchet.PickleLen()
|
||||
length += o.SigningKey.PickleLen()
|
||||
return length
|
||||
}
|
||||
|
||||
func (o MegolmOutboundSession) SessionSharingMessage() ([]byte, error) {
|
||||
return o.Ratchet.SessionSharingMessage(o.SigningKey)
|
||||
}
|
||||
476
vendor/maunium.net/go/mautrix/crypto/goolm/session/olm_session.go
generated
vendored
Normal file
476
vendor/maunium.net/go/mautrix/crypto/goolm/session/olm_session.go
generated
vendored
Normal file
@@ -0,0 +1,476 @@
|
||||
package session
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"encoding/base64"
|
||||
"errors"
|
||||
"fmt"
|
||||
"io"
|
||||
|
||||
"maunium.net/go/mautrix/crypto/goolm"
|
||||
"maunium.net/go/mautrix/crypto/goolm/cipher"
|
||||
"maunium.net/go/mautrix/crypto/goolm/crypto"
|
||||
"maunium.net/go/mautrix/crypto/goolm/libolmpickle"
|
||||
"maunium.net/go/mautrix/crypto/goolm/message"
|
||||
"maunium.net/go/mautrix/crypto/goolm/olm"
|
||||
"maunium.net/go/mautrix/crypto/goolm/utilities"
|
||||
"maunium.net/go/mautrix/id"
|
||||
)
|
||||
|
||||
const (
|
||||
olmSessionPickleVersionJSON uint8 = 1
|
||||
olmSessionPickleVersionLibOlm uint32 = 1
|
||||
)
|
||||
|
||||
const (
|
||||
protocolVersion = 0x3
|
||||
)
|
||||
|
||||
// OlmSession stores all information for an olm session
|
||||
type OlmSession struct {
|
||||
ReceivedMessage bool `json:"received_message"`
|
||||
AliceIdentityKey crypto.Curve25519PublicKey `json:"alice_id_key"`
|
||||
AliceBaseKey crypto.Curve25519PublicKey `json:"alice_base_key"`
|
||||
BobOneTimeKey crypto.Curve25519PublicKey `json:"bob_one_time_key"`
|
||||
Ratchet olm.Ratchet `json:"ratchet"`
|
||||
}
|
||||
|
||||
// SearchOTKFunc is used to retrieve a crypto.OneTimeKey from a public key.
|
||||
type SearchOTKFunc = func(crypto.Curve25519PublicKey) *crypto.OneTimeKey
|
||||
|
||||
// OlmSessionFromJSONPickled loads an OlmSession from a pickled base64 string. Decrypts
|
||||
// the Session using the supplied key.
|
||||
func OlmSessionFromJSONPickled(pickled, key []byte) (*OlmSession, error) {
|
||||
if len(pickled) == 0 {
|
||||
return nil, fmt.Errorf("sessionFromPickled: %w", goolm.ErrEmptyInput)
|
||||
}
|
||||
a := &OlmSession{}
|
||||
err := a.UnpickleAsJSON(pickled, key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return a, nil
|
||||
}
|
||||
|
||||
// OlmSessionFromPickled loads the OlmSession details from a pickled base64 string. The input is decrypted with the supplied key.
|
||||
func OlmSessionFromPickled(pickled, key []byte) (*OlmSession, error) {
|
||||
if len(pickled) == 0 {
|
||||
return nil, fmt.Errorf("sessionFromPickled: %w", goolm.ErrEmptyInput)
|
||||
}
|
||||
a := &OlmSession{}
|
||||
err := a.Unpickle(pickled, key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return a, nil
|
||||
}
|
||||
|
||||
// NewOlmSession creates a new Session.
|
||||
func NewOlmSession() *OlmSession {
|
||||
s := &OlmSession{}
|
||||
s.Ratchet = *olm.New()
|
||||
return s
|
||||
}
|
||||
|
||||
// NewOutboundOlmSession creates a new outbound session for sending the first message to a
|
||||
// given curve25519 identityKey and oneTimeKey.
|
||||
func NewOutboundOlmSession(identityKeyAlice crypto.Curve25519KeyPair, identityKeyBob crypto.Curve25519PublicKey, oneTimeKeyBob crypto.Curve25519PublicKey) (*OlmSession, error) {
|
||||
s := NewOlmSession()
|
||||
//generate E_A
|
||||
baseKey, err := crypto.Curve25519GenerateKey(nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
//generate T_0
|
||||
ratchetKey, err := crypto.Curve25519GenerateKey(nil)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
//Calculate shared secret via Triple Diffie-Hellman
|
||||
var secret []byte
|
||||
//ECDH(I_A,E_B)
|
||||
idSecret, err := identityKeyAlice.SharedSecret(oneTimeKeyBob)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
//ECDH(E_A,I_B)
|
||||
baseIdSecret, err := baseKey.SharedSecret(identityKeyBob)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
//ECDH(E_A,E_B)
|
||||
baseOneTimeSecret, err := baseKey.SharedSecret(oneTimeKeyBob)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
secret = append(secret, idSecret...)
|
||||
secret = append(secret, baseIdSecret...)
|
||||
secret = append(secret, baseOneTimeSecret...)
|
||||
//Init Ratchet
|
||||
s.Ratchet.InitializeAsAlice(secret, ratchetKey)
|
||||
s.AliceIdentityKey = identityKeyAlice.PublicKey
|
||||
s.AliceBaseKey = baseKey.PublicKey
|
||||
s.BobOneTimeKey = oneTimeKeyBob
|
||||
return s, nil
|
||||
}
|
||||
|
||||
// NewInboundOlmSession creates a new inbound session from receiving the first message.
|
||||
func NewInboundOlmSession(identityKeyAlice *crypto.Curve25519PublicKey, receivedOTKMsg []byte, searchBobOTK SearchOTKFunc, identityKeyBob crypto.Curve25519KeyPair) (*OlmSession, error) {
|
||||
decodedOTKMsg, err := goolm.Base64Decode(receivedOTKMsg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
s := NewOlmSession()
|
||||
|
||||
//decode OneTimeKeyMessage
|
||||
oneTimeMsg := message.PreKeyMessage{}
|
||||
err = oneTimeMsg.Decode(decodedOTKMsg)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("OneTimeKeyMessage decode: %w", err)
|
||||
}
|
||||
if !oneTimeMsg.CheckFields(identityKeyAlice) {
|
||||
return nil, fmt.Errorf("OneTimeKeyMessage check fields: %w", goolm.ErrBadMessageFormat)
|
||||
}
|
||||
|
||||
//Either the identityKeyAlice is set and/or the oneTimeMsg.IdentityKey is set, which is checked
|
||||
// by oneTimeMsg.CheckFields
|
||||
if identityKeyAlice != nil && len(oneTimeMsg.IdentityKey) != 0 {
|
||||
//if both are set, compare them
|
||||
if !identityKeyAlice.Equal(oneTimeMsg.IdentityKey) {
|
||||
return nil, fmt.Errorf("OneTimeKeyMessage identity keys: %w", goolm.ErrBadMessageKeyID)
|
||||
}
|
||||
}
|
||||
if identityKeyAlice == nil {
|
||||
//for downstream use set
|
||||
identityKeyAlice = &oneTimeMsg.IdentityKey
|
||||
}
|
||||
|
||||
oneTimeKeyBob := searchBobOTK(oneTimeMsg.OneTimeKey)
|
||||
if oneTimeKeyBob == nil {
|
||||
return nil, fmt.Errorf("ourOneTimeKey: %w", goolm.ErrBadMessageKeyID)
|
||||
}
|
||||
|
||||
//Calculate shared secret via Triple Diffie-Hellman
|
||||
var secret []byte
|
||||
//ECDH(E_B,I_A)
|
||||
idSecret, err := oneTimeKeyBob.Key.SharedSecret(*identityKeyAlice)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
//ECDH(I_B,E_A)
|
||||
baseIdSecret, err := identityKeyBob.SharedSecret(oneTimeMsg.BaseKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
//ECDH(E_B,E_A)
|
||||
baseOneTimeSecret, err := oneTimeKeyBob.Key.SharedSecret(oneTimeMsg.BaseKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
secret = append(secret, idSecret...)
|
||||
secret = append(secret, baseIdSecret...)
|
||||
secret = append(secret, baseOneTimeSecret...)
|
||||
//decode message
|
||||
msg := message.Message{}
|
||||
err = msg.Decode(oneTimeMsg.Message)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("Message decode: %w", err)
|
||||
}
|
||||
|
||||
if len(msg.RatchetKey) == 0 {
|
||||
return nil, fmt.Errorf("Message missing ratchet key: %w", goolm.ErrBadMessageFormat)
|
||||
}
|
||||
//Init Ratchet
|
||||
s.Ratchet.InitializeAsBob(secret, msg.RatchetKey)
|
||||
s.AliceBaseKey = oneTimeMsg.BaseKey
|
||||
s.AliceIdentityKey = oneTimeMsg.IdentityKey
|
||||
s.BobOneTimeKey = oneTimeKeyBob.Key.PublicKey
|
||||
|
||||
//https://gitlab.matrix.org/matrix-org/olm/blob/master/docs/olm.md states to remove the oneTimeKey
|
||||
//this is done via the account itself
|
||||
return s, nil
|
||||
}
|
||||
|
||||
// PickleAsJSON returns an Session as a base64 string encrypted using the supplied key. The unencrypted representation of the Account is in JSON format.
|
||||
func (a OlmSession) PickleAsJSON(key []byte) ([]byte, error) {
|
||||
return utilities.PickleAsJSON(a, olmSessionPickleVersionJSON, key)
|
||||
}
|
||||
|
||||
// UnpickleAsJSON updates an Session by a base64 encrypted string with the key. The unencrypted representation has to be in JSON format.
|
||||
func (a *OlmSession) UnpickleAsJSON(pickled, key []byte) error {
|
||||
return utilities.UnpickleAsJSON(a, pickled, key, olmSessionPickleVersionJSON)
|
||||
}
|
||||
|
||||
// ID returns an identifier for this Session. Will be the same for both ends of the conversation.
|
||||
// Generated by hashing the public keys used to create the session.
|
||||
func (s OlmSession) ID() id.SessionID {
|
||||
message := make([]byte, 3*crypto.Curve25519KeyLength)
|
||||
copy(message, s.AliceIdentityKey)
|
||||
copy(message[crypto.Curve25519KeyLength:], s.AliceBaseKey)
|
||||
copy(message[2*crypto.Curve25519KeyLength:], s.BobOneTimeKey)
|
||||
hash := crypto.SHA256(message)
|
||||
res := id.SessionID(goolm.Base64Encode(hash))
|
||||
return res
|
||||
}
|
||||
|
||||
// HasReceivedMessage returns true if this session has received any message.
|
||||
func (s OlmSession) HasReceivedMessage() bool {
|
||||
return s.ReceivedMessage
|
||||
}
|
||||
|
||||
// MatchesInboundSessionFrom checks if the oneTimeKeyMsg message is set for this inbound
|
||||
// Session. This can happen if multiple messages are sent to this Account
|
||||
// before this Account sends a message in reply. Returns true if the session
|
||||
// matches. Returns false if the session does not match.
|
||||
func (s OlmSession) MatchesInboundSessionFrom(theirIdentityKeyEncoded *id.Curve25519, receivedOTKMsg []byte) (bool, error) {
|
||||
if len(receivedOTKMsg) == 0 {
|
||||
return false, fmt.Errorf("inbound match: %w", goolm.ErrEmptyInput)
|
||||
}
|
||||
decodedOTKMsg, err := goolm.Base64Decode(receivedOTKMsg)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
|
||||
var theirIdentityKey *crypto.Curve25519PublicKey
|
||||
if theirIdentityKeyEncoded != nil {
|
||||
decodedKey, err := base64.RawStdEncoding.DecodeString(string(*theirIdentityKeyEncoded))
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
theirIdentityKeyByte := crypto.Curve25519PublicKey(decodedKey)
|
||||
theirIdentityKey = &theirIdentityKeyByte
|
||||
}
|
||||
|
||||
msg := message.PreKeyMessage{}
|
||||
err = msg.Decode(decodedOTKMsg)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
if !msg.CheckFields(theirIdentityKey) {
|
||||
return false, nil
|
||||
}
|
||||
|
||||
same := true
|
||||
if msg.IdentityKey != nil {
|
||||
same = same && msg.IdentityKey.Equal(s.AliceIdentityKey)
|
||||
}
|
||||
if theirIdentityKey != nil {
|
||||
same = same && theirIdentityKey.Equal(s.AliceIdentityKey)
|
||||
}
|
||||
same = same && bytes.Equal(msg.BaseKey, s.AliceBaseKey)
|
||||
same = same && bytes.Equal(msg.OneTimeKey, s.BobOneTimeKey)
|
||||
return same, nil
|
||||
}
|
||||
|
||||
// EncryptMsgType returns the type of the next message that Encrypt will
|
||||
// return. Returns MsgTypePreKey if the message will be a oneTimeKeyMsg.
|
||||
// Returns MsgTypeMsg if the message will be a normal message.
|
||||
func (s OlmSession) EncryptMsgType() id.OlmMsgType {
|
||||
if s.ReceivedMessage {
|
||||
return id.OlmMsgTypeMsg
|
||||
}
|
||||
return id.OlmMsgTypePreKey
|
||||
}
|
||||
|
||||
// Encrypt encrypts a message using the Session. Returns the encrypted message base64 encoded. If reader is nil, crypto/rand is used for key generations.
|
||||
func (s *OlmSession) Encrypt(plaintext []byte, reader io.Reader) (id.OlmMsgType, []byte, error) {
|
||||
if len(plaintext) == 0 {
|
||||
return 0, nil, fmt.Errorf("encrypt: %w", goolm.ErrEmptyInput)
|
||||
}
|
||||
messageType := s.EncryptMsgType()
|
||||
encrypted, err := s.Ratchet.Encrypt(plaintext, reader)
|
||||
if err != nil {
|
||||
return 0, nil, err
|
||||
}
|
||||
result := encrypted
|
||||
if !s.ReceivedMessage {
|
||||
msg := message.PreKeyMessage{}
|
||||
msg.Version = protocolVersion
|
||||
msg.OneTimeKey = s.BobOneTimeKey
|
||||
msg.IdentityKey = s.AliceIdentityKey
|
||||
msg.BaseKey = s.AliceBaseKey
|
||||
msg.Message = encrypted
|
||||
|
||||
var err error
|
||||
messageBody, err := msg.Encode()
|
||||
if err != nil {
|
||||
return 0, nil, err
|
||||
}
|
||||
result = messageBody
|
||||
}
|
||||
|
||||
return messageType, goolm.Base64Encode(result), nil
|
||||
}
|
||||
|
||||
// Decrypt decrypts a base64 encoded message using the Session.
|
||||
func (s *OlmSession) Decrypt(crypttext []byte, msgType id.OlmMsgType) ([]byte, error) {
|
||||
if len(crypttext) == 0 {
|
||||
return nil, fmt.Errorf("decrypt: %w", goolm.ErrEmptyInput)
|
||||
}
|
||||
decodedCrypttext, err := goolm.Base64Decode(crypttext)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
msgBody := decodedCrypttext
|
||||
if msgType != id.OlmMsgTypeMsg {
|
||||
//Pre-Key Message
|
||||
msg := message.PreKeyMessage{}
|
||||
err := msg.Decode(decodedCrypttext)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
msgBody = msg.Message
|
||||
}
|
||||
plaintext, err := s.Ratchet.Decrypt(msgBody)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
s.ReceivedMessage = true
|
||||
return plaintext, nil
|
||||
}
|
||||
|
||||
// Unpickle decodes the base64 encoded string and decrypts the result with the key.
|
||||
// The decrypted value is then passed to UnpickleLibOlm.
|
||||
func (o *OlmSession) Unpickle(pickled, key []byte) error {
|
||||
decrypted, err := cipher.Unpickle(key, pickled)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = o.UnpickleLibOlm(decrypted)
|
||||
return err
|
||||
}
|
||||
|
||||
// UnpickleLibOlm decodes the unencryted value and populates the Session accordingly. It returns the number of bytes read.
|
||||
func (o *OlmSession) UnpickleLibOlm(value []byte) (int, error) {
|
||||
//First 4 bytes are the accountPickleVersion
|
||||
pickledVersion, curPos, err := libolmpickle.UnpickleUInt32(value)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
includesChainIndex := true
|
||||
switch pickledVersion {
|
||||
case olmSessionPickleVersionLibOlm:
|
||||
includesChainIndex = false
|
||||
case uint32(0x80000001):
|
||||
includesChainIndex = true
|
||||
default:
|
||||
return 0, fmt.Errorf("unpickle olmSession: %w", goolm.ErrBadVersion)
|
||||
}
|
||||
var readBytes int
|
||||
o.ReceivedMessage, readBytes, err = libolmpickle.UnpickleBool(value[curPos:])
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
curPos += readBytes
|
||||
readBytes, err = o.AliceIdentityKey.UnpickleLibOlm(value[curPos:])
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
curPos += readBytes
|
||||
readBytes, err = o.AliceBaseKey.UnpickleLibOlm(value[curPos:])
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
curPos += readBytes
|
||||
readBytes, err = o.BobOneTimeKey.UnpickleLibOlm(value[curPos:])
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
curPos += readBytes
|
||||
readBytes, err = o.Ratchet.UnpickleLibOlm(value[curPos:], includesChainIndex)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
curPos += readBytes
|
||||
return curPos, nil
|
||||
}
|
||||
|
||||
// Pickle returns a base64 encoded and with key encrypted pickled olmSession using PickleLibOlm().
|
||||
func (o OlmSession) Pickle(key []byte) ([]byte, error) {
|
||||
pickeledBytes := make([]byte, o.PickleLen())
|
||||
written, err := o.PickleLibOlm(pickeledBytes)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
if written != len(pickeledBytes) {
|
||||
return nil, errors.New("number of written bytes not correct")
|
||||
}
|
||||
encrypted, err := cipher.Pickle(key, pickeledBytes)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return encrypted, nil
|
||||
}
|
||||
|
||||
// PickleLibOlm encodes the session into target. target has to have a size of at least PickleLen() and is written to from index 0.
|
||||
// It returns the number of bytes written.
|
||||
func (o OlmSession) PickleLibOlm(target []byte) (int, error) {
|
||||
if len(target) < o.PickleLen() {
|
||||
return 0, fmt.Errorf("pickle MegolmOutboundSession: %w", goolm.ErrValueTooShort)
|
||||
}
|
||||
written := libolmpickle.PickleUInt32(olmSessionPickleVersionLibOlm, target)
|
||||
written += libolmpickle.PickleBool(o.ReceivedMessage, target[written:])
|
||||
writtenRatchet, err := o.AliceIdentityKey.PickleLibOlm(target[written:])
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("pickle MegolmOutboundSession: %w", err)
|
||||
}
|
||||
written += writtenRatchet
|
||||
writtenRatchet, err = o.AliceBaseKey.PickleLibOlm(target[written:])
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("pickle MegolmOutboundSession: %w", err)
|
||||
}
|
||||
written += writtenRatchet
|
||||
writtenRatchet, err = o.BobOneTimeKey.PickleLibOlm(target[written:])
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("pickle MegolmOutboundSession: %w", err)
|
||||
}
|
||||
written += writtenRatchet
|
||||
writtenRatchet, err = o.Ratchet.PickleLibOlm(target[written:])
|
||||
if err != nil {
|
||||
return 0, fmt.Errorf("pickle MegolmOutboundSession: %w", err)
|
||||
}
|
||||
written += writtenRatchet
|
||||
return written, nil
|
||||
}
|
||||
|
||||
// PickleLen returns the actual number of bytes the pickled session will have.
|
||||
func (o OlmSession) PickleLen() int {
|
||||
length := libolmpickle.PickleUInt32Len(olmSessionPickleVersionLibOlm)
|
||||
length += libolmpickle.PickleBoolLen(o.ReceivedMessage)
|
||||
length += o.AliceIdentityKey.PickleLen()
|
||||
length += o.AliceBaseKey.PickleLen()
|
||||
length += o.BobOneTimeKey.PickleLen()
|
||||
length += o.Ratchet.PickleLen()
|
||||
return length
|
||||
}
|
||||
|
||||
// PickleLenMin returns the minimum number of bytes the pickled session must have.
|
||||
func (o OlmSession) PickleLenMin() int {
|
||||
length := libolmpickle.PickleUInt32Len(olmSessionPickleVersionLibOlm)
|
||||
length += libolmpickle.PickleBoolLen(o.ReceivedMessage)
|
||||
length += o.AliceIdentityKey.PickleLen()
|
||||
length += o.AliceBaseKey.PickleLen()
|
||||
length += o.BobOneTimeKey.PickleLen()
|
||||
length += o.Ratchet.PickleLenMin()
|
||||
return length
|
||||
}
|
||||
|
||||
// Describe returns a string describing the current state of the session for debugging.
|
||||
func (o OlmSession) Describe() string {
|
||||
var res string
|
||||
if o.Ratchet.SenderChains.IsSet {
|
||||
res += fmt.Sprintf("sender chain index: %d ", o.Ratchet.SenderChains.CKey.Index)
|
||||
} else {
|
||||
res += "sender chain index: "
|
||||
}
|
||||
res += "receiver chain indicies:"
|
||||
for _, curChain := range o.Ratchet.ReceiverChains {
|
||||
res += fmt.Sprintf(" %d", curChain.CKey.Index)
|
||||
}
|
||||
res += " skipped message keys:"
|
||||
for _, curSkip := range o.Ratchet.SkippedMessageKeys {
|
||||
res += fmt.Sprintf(" %d", curSkip.MKey.Index)
|
||||
}
|
||||
return res
|
||||
}
|
||||
23
vendor/maunium.net/go/mautrix/crypto/goolm/utilities/main.go
generated
vendored
Normal file
23
vendor/maunium.net/go/mautrix/crypto/goolm/utilities/main.go
generated
vendored
Normal file
@@ -0,0 +1,23 @@
|
||||
package utilities
|
||||
|
||||
import (
|
||||
"encoding/base64"
|
||||
|
||||
"maunium.net/go/mautrix/crypto/goolm"
|
||||
"maunium.net/go/mautrix/crypto/goolm/crypto"
|
||||
"maunium.net/go/mautrix/id"
|
||||
)
|
||||
|
||||
// VerifySignature verifies an ed25519 signature.
|
||||
func VerifySignature(message []byte, key id.Ed25519, signature []byte) (ok bool, err error) {
|
||||
keyDecoded, err := base64.RawStdEncoding.DecodeString(string(key))
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
signatureDecoded, err := goolm.Base64Decode(signature)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
publicKey := crypto.Ed25519PublicKey(keyDecoded)
|
||||
return publicKey.Verify(message, signatureDecoded), nil
|
||||
}
|
||||
60
vendor/maunium.net/go/mautrix/crypto/goolm/utilities/pickle.go
generated
vendored
Normal file
60
vendor/maunium.net/go/mautrix/crypto/goolm/utilities/pickle.go
generated
vendored
Normal file
@@ -0,0 +1,60 @@
|
||||
package utilities
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"maunium.net/go/mautrix/crypto/goolm"
|
||||
"maunium.net/go/mautrix/crypto/goolm/cipher"
|
||||
)
|
||||
|
||||
// PickleAsJSON returns an object as a base64 string encrypted using the supplied key. The unencrypted representation of the object is in JSON format.
|
||||
func PickleAsJSON(object any, pickleVersion byte, key []byte) ([]byte, error) {
|
||||
if len(key) == 0 {
|
||||
return nil, fmt.Errorf("pickle: %w", goolm.ErrNoKeyProvided)
|
||||
}
|
||||
marshaled, err := json.Marshal(object)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("pickle marshal: %w", err)
|
||||
}
|
||||
marshaled = append([]byte{pickleVersion}, marshaled...)
|
||||
toEncrypt := make([]byte, len(marshaled))
|
||||
copy(toEncrypt, marshaled)
|
||||
//pad marshaled to get block size
|
||||
if len(marshaled)%cipher.PickleBlockSize() != 0 {
|
||||
padding := cipher.PickleBlockSize() - len(marshaled)%cipher.PickleBlockSize()
|
||||
toEncrypt = make([]byte, len(marshaled)+padding)
|
||||
copy(toEncrypt, marshaled)
|
||||
}
|
||||
encrypted, err := cipher.Pickle(key, toEncrypt)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("pickle encrypt: %w", err)
|
||||
}
|
||||
return encrypted, nil
|
||||
}
|
||||
|
||||
// UnpickleAsJSON updates the object by a base64 encrypted string using the supplied key. The unencrypted representation has to be in JSON format.
|
||||
func UnpickleAsJSON(object any, pickled, key []byte, pickleVersion byte) error {
|
||||
if len(key) == 0 {
|
||||
return fmt.Errorf("unpickle: %w", goolm.ErrNoKeyProvided)
|
||||
}
|
||||
decrypted, err := cipher.Unpickle(key, pickled)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unpickle decrypt: %w", err)
|
||||
}
|
||||
//unpad decrypted so unmarshal works
|
||||
for i := len(decrypted) - 1; i >= 0; i-- {
|
||||
if decrypted[i] != 0 {
|
||||
decrypted = decrypted[:i+1]
|
||||
break
|
||||
}
|
||||
}
|
||||
if decrypted[0] != pickleVersion {
|
||||
return fmt.Errorf("unpickle: %w", goolm.ErrWrongPickleVersion)
|
||||
}
|
||||
err = json.Unmarshal(decrypted[1:], object)
|
||||
if err != nil {
|
||||
return fmt.Errorf("unpickle unmarshal: %w", err)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
16
vendor/maunium.net/go/mautrix/crypto/keyimport.go
generated
vendored
16
vendor/maunium.net/go/mautrix/crypto/keyimport.go
generated
vendored
@@ -1,4 +1,4 @@
|
||||
// Copyright (c) 2023 Tulir Asokan
|
||||
// Copyright (c) 2024 Tulir Asokan
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||
// License, v. 2.0. If a copy of the MPL was not distributed with this
|
||||
@@ -8,6 +8,7 @@ package crypto
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"context"
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"crypto/hmac"
|
||||
@@ -91,7 +92,7 @@ func decryptKeyExport(passphrase string, exportData []byte) ([]ExportedSession,
|
||||
return sessionsJSON, nil
|
||||
}
|
||||
|
||||
func (mach *OlmMachine) importExportedRoomKey(session ExportedSession) (bool, error) {
|
||||
func (mach *OlmMachine) importExportedRoomKey(ctx context.Context, session ExportedSession) (bool, error) {
|
||||
if session.Algorithm != id.AlgorithmMegolmV1 {
|
||||
return false, ErrInvalidExportedAlgorithm
|
||||
}
|
||||
@@ -112,12 +113,12 @@ func (mach *OlmMachine) importExportedRoomKey(session ExportedSession) (bool, er
|
||||
|
||||
ReceivedAt: time.Now().UTC(),
|
||||
}
|
||||
existingIGS, _ := mach.CryptoStore.GetGroupSession(igs.RoomID, igs.SenderKey, igs.ID())
|
||||
existingIGS, _ := mach.CryptoStore.GetGroupSession(ctx, igs.RoomID, igs.SenderKey, igs.ID())
|
||||
if existingIGS != nil && existingIGS.Internal.FirstKnownIndex() <= igs.Internal.FirstKnownIndex() {
|
||||
// We already have an equivalent or better session in the store, so don't override it.
|
||||
return false, nil
|
||||
}
|
||||
err = mach.CryptoStore.PutGroupSession(igs.RoomID, igs.SenderKey, igs.ID(), igs)
|
||||
err = mach.CryptoStore.PutGroupSession(ctx, igs.RoomID, igs.SenderKey, igs.ID(), igs)
|
||||
if err != nil {
|
||||
return false, fmt.Errorf("failed to store imported session: %w", err)
|
||||
}
|
||||
@@ -127,7 +128,7 @@ func (mach *OlmMachine) importExportedRoomKey(session ExportedSession) (bool, er
|
||||
|
||||
// ImportKeys imports data that was exported with the format specified in the Matrix spec.
|
||||
// See https://spec.matrix.org/v1.2/client-server-api/#key-exports
|
||||
func (mach *OlmMachine) ImportKeys(passphrase string, data []byte) (int, int, error) {
|
||||
func (mach *OlmMachine) ImportKeys(ctx context.Context, passphrase string, data []byte) (int, int, error) {
|
||||
exportData, err := decodeKeyExport(data)
|
||||
if err != nil {
|
||||
return 0, 0, err
|
||||
@@ -143,8 +144,11 @@ func (mach *OlmMachine) ImportKeys(passphrase string, data []byte) (int, int, er
|
||||
Str("room_id", session.RoomID.String()).
|
||||
Str("session_id", session.SessionID.String()).
|
||||
Logger()
|
||||
imported, err := mach.importExportedRoomKey(session)
|
||||
imported, err := mach.importExportedRoomKey(ctx, session)
|
||||
if err != nil {
|
||||
if ctx.Err() != nil {
|
||||
return count, len(sessions), ctx.Err()
|
||||
}
|
||||
log.Error().Err(err).Msg("Failed to import Megolm session from file")
|
||||
} else if imported {
|
||||
log.Debug().Msg("Imported Megolm session from file")
|
||||
|
||||
39
vendor/maunium.net/go/mautrix/crypto/keysharing.go
generated
vendored
39
vendor/maunium.net/go/mautrix/crypto/keysharing.go
generated
vendored
@@ -1,5 +1,5 @@
|
||||
// Copyright (c) 2020 Nikos Filippakis
|
||||
// Copyright (c) 2023 Tulir Asokan
|
||||
// Copyright (c) 2024 Tulir Asokan
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||
// License, v. 2.0. If a copy of the MPL was not distributed with this
|
||||
@@ -48,7 +48,7 @@ func (mach *OlmMachine) RequestRoomKey(ctx context.Context, toUser id.UserID, to
|
||||
keyResponseReceived := make(chan struct{})
|
||||
mach.roomKeyRequestFilled.Store(sessionID, keyResponseReceived)
|
||||
|
||||
err := mach.SendRoomKeyRequest(roomID, senderKey, sessionID, requestID, map[id.UserID][]id.DeviceID{toUser: {toDevice}})
|
||||
err := mach.SendRoomKeyRequest(ctx, roomID, senderKey, sessionID, requestID, map[id.UserID][]id.DeviceID{toUser: {toDevice}})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -85,7 +85,7 @@ func (mach *OlmMachine) RequestRoomKey(ctx context.Context, toUser id.UserID, to
|
||||
},
|
||||
}
|
||||
|
||||
mach.Client.SendToDevice(event.ToDeviceRoomKeyRequest, toDeviceCancel)
|
||||
mach.Client.SendToDevice(ctx, event.ToDeviceRoomKeyRequest, toDeviceCancel)
|
||||
}()
|
||||
return resChan, nil
|
||||
}
|
||||
@@ -99,7 +99,7 @@ func (mach *OlmMachine) RequestRoomKey(ctx context.Context, toUser id.UserID, to
|
||||
// to the specific key request, but currently it only supports a single target device and is therefore deprecated.
|
||||
// A future function may properly support multiple targets and automatically canceling the other requests when receiving
|
||||
// the first response.
|
||||
func (mach *OlmMachine) SendRoomKeyRequest(roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID, requestID string, users map[id.UserID][]id.DeviceID) error {
|
||||
func (mach *OlmMachine) SendRoomKeyRequest(ctx context.Context, roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID, requestID string, users map[id.UserID][]id.DeviceID) error {
|
||||
if len(requestID) == 0 {
|
||||
requestID = mach.Client.TxnID()
|
||||
}
|
||||
@@ -126,7 +126,7 @@ func (mach *OlmMachine) SendRoomKeyRequest(roomID id.RoomID, senderKey id.Sender
|
||||
toDeviceReq.Messages[user][device] = requestEvent
|
||||
}
|
||||
}
|
||||
_, err := mach.Client.SendToDevice(event.ToDeviceRoomKeyRequest, toDeviceReq)
|
||||
_, err := mach.Client.SendToDevice(ctx, event.ToDeviceRoomKeyRequest, toDeviceReq)
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -152,7 +152,10 @@ func (mach *OlmMachine) importForwardedRoomKey(ctx context.Context, evt *Decrypt
|
||||
Msg("Mismatched session ID while creating inbound group session from forward")
|
||||
return false
|
||||
}
|
||||
config := mach.StateStore.GetEncryptionEvent(content.RoomID)
|
||||
config, err := mach.StateStore.GetEncryptionEvent(ctx, content.RoomID)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("Failed to get encryption event for room")
|
||||
}
|
||||
var maxAge time.Duration
|
||||
var maxMessages int
|
||||
if config != nil {
|
||||
@@ -178,7 +181,7 @@ func (mach *OlmMachine) importForwardedRoomKey(ctx context.Context, evt *Decrypt
|
||||
MaxMessages: maxMessages,
|
||||
IsScheduled: content.IsScheduled,
|
||||
}
|
||||
err = mach.CryptoStore.PutGroupSession(content.RoomID, content.SenderKey, content.SessionID, igs)
|
||||
err = mach.CryptoStore.PutGroupSession(ctx, content.RoomID, content.SenderKey, content.SessionID, igs)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("Failed to store new inbound group session")
|
||||
return false
|
||||
@@ -188,7 +191,7 @@ func (mach *OlmMachine) importForwardedRoomKey(ctx context.Context, evt *Decrypt
|
||||
return true
|
||||
}
|
||||
|
||||
func (mach *OlmMachine) rejectKeyRequest(rejection KeyShareRejection, device *id.Device, request event.RequestedKeyInfo) {
|
||||
func (mach *OlmMachine) rejectKeyRequest(ctx context.Context, rejection KeyShareRejection, device *id.Device, request event.RequestedKeyInfo) {
|
||||
if rejection.Code == "" {
|
||||
// If the rejection code is empty, it means don't share keys, but also don't tell the requester.
|
||||
return
|
||||
@@ -201,7 +204,7 @@ func (mach *OlmMachine) rejectKeyRequest(rejection KeyShareRejection, device *id
|
||||
Code: rejection.Code,
|
||||
Reason: rejection.Reason,
|
||||
}
|
||||
err := mach.sendToOneDevice(device.UserID, device.DeviceID, event.ToDeviceRoomKeyWithheld, &content)
|
||||
err := mach.sendToOneDevice(ctx, device.UserID, device.DeviceID, event.ToDeviceRoomKeyWithheld, &content)
|
||||
if err != nil {
|
||||
mach.Log.Warn().Err(err).
|
||||
Str("code", string(rejection.Code)).
|
||||
@@ -209,7 +212,7 @@ func (mach *OlmMachine) rejectKeyRequest(rejection KeyShareRejection, device *id
|
||||
Str("device_id", device.DeviceID.String()).
|
||||
Msg("Failed to send key share rejection")
|
||||
}
|
||||
err = mach.sendToOneDevice(device.UserID, device.DeviceID, event.ToDeviceOrgMatrixRoomKeyWithheld, &content)
|
||||
err = mach.sendToOneDevice(ctx, device.UserID, device.DeviceID, event.ToDeviceOrgMatrixRoomKeyWithheld, &content)
|
||||
if err != nil {
|
||||
mach.Log.Warn().Err(err).
|
||||
Str("code", string(rejection.Code)).
|
||||
@@ -270,23 +273,23 @@ func (mach *OlmMachine) handleRoomKeyRequest(ctx context.Context, sender id.User
|
||||
|
||||
rejection := mach.AllowKeyShare(ctx, device, content.Body)
|
||||
if rejection != nil {
|
||||
mach.rejectKeyRequest(*rejection, device, content.Body)
|
||||
mach.rejectKeyRequest(ctx, *rejection, device, content.Body)
|
||||
return
|
||||
}
|
||||
|
||||
igs, err := mach.CryptoStore.GetGroupSession(content.Body.RoomID, content.Body.SenderKey, content.Body.SessionID)
|
||||
igs, err := mach.CryptoStore.GetGroupSession(ctx, content.Body.RoomID, content.Body.SenderKey, content.Body.SessionID)
|
||||
if err != nil {
|
||||
if errors.Is(err, ErrGroupSessionWithheld) {
|
||||
log.Debug().Err(err).Msg("Requested group session not available")
|
||||
mach.rejectKeyRequest(KeyShareRejectUnavailable, device, content.Body)
|
||||
mach.rejectKeyRequest(ctx, KeyShareRejectUnavailable, device, content.Body)
|
||||
} else {
|
||||
log.Error().Err(err).Msg("Failed to get group session to forward")
|
||||
mach.rejectKeyRequest(KeyShareRejectInternalError, device, content.Body)
|
||||
mach.rejectKeyRequest(ctx, KeyShareRejectInternalError, device, content.Body)
|
||||
}
|
||||
return
|
||||
} else if igs == nil {
|
||||
log.Error().Msg("Didn't find group session to forward")
|
||||
mach.rejectKeyRequest(KeyShareRejectUnavailable, device, content.Body)
|
||||
mach.rejectKeyRequest(ctx, KeyShareRejectUnavailable, device, content.Body)
|
||||
return
|
||||
}
|
||||
if internalID := igs.ID(); internalID != content.Body.SessionID {
|
||||
@@ -299,7 +302,7 @@ func (mach *OlmMachine) handleRoomKeyRequest(ctx context.Context, sender id.User
|
||||
exportedKey, err := igs.Internal.Export(firstKnownIndex)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("Failed to export group session to forward")
|
||||
mach.rejectKeyRequest(KeyShareRejectInternalError, device, content.Body)
|
||||
mach.rejectKeyRequest(ctx, KeyShareRejectInternalError, device, content.Body)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -331,7 +334,7 @@ func (mach *OlmMachine) handleBeeperRoomKeyAck(ctx context.Context, sender id.Us
|
||||
Int("first_message_index", content.FirstMessageIndex).
|
||||
Logger()
|
||||
|
||||
sess, err := mach.CryptoStore.GetGroupSession(content.RoomID, "", content.SessionID)
|
||||
sess, err := mach.CryptoStore.GetGroupSession(ctx, content.RoomID, "", content.SessionID)
|
||||
if err != nil {
|
||||
if errors.Is(err, ErrGroupSessionWithheld) {
|
||||
log.Debug().Err(err).Msg("Acked group session was already redacted")
|
||||
@@ -351,7 +354,7 @@ func (mach *OlmMachine) handleBeeperRoomKeyAck(ctx context.Context, sender id.Us
|
||||
isInbound := sess.SenderKey == mach.OwnIdentity().IdentityKey
|
||||
if isInbound && mach.DeleteOutboundKeysOnAck && content.FirstMessageIndex == 0 {
|
||||
log.Debug().Msg("Redacting inbound copy of outbound group session after ack")
|
||||
err = mach.CryptoStore.RedactGroupSession(content.RoomID, sess.SenderKey, content.SessionID, "outbound session acked")
|
||||
err = mach.CryptoStore.RedactGroupSession(ctx, content.RoomID, sess.SenderKey, content.SessionID, "outbound session acked")
|
||||
if err != nil {
|
||||
log.Err(err).Msg("Failed to redact group session")
|
||||
}
|
||||
|
||||
131
vendor/maunium.net/go/mautrix/crypto/machine.go
generated
vendored
131
vendor/maunium.net/go/mautrix/crypto/machine.go
generated
vendored
@@ -1,4 +1,4 @@
|
||||
// Copyright (c) 2023 Tulir Asokan
|
||||
// Copyright (c) 2024 Tulir Asokan
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||
// License, v. 2.0. If a copy of the MPL was not distributed with this
|
||||
@@ -33,6 +33,9 @@ type OlmMachine struct {
|
||||
|
||||
PlaintextMentions bool
|
||||
|
||||
// Never ask the server for keys automatically as a side effect.
|
||||
DisableKeyFetching bool
|
||||
|
||||
SendKeysMinTrust id.TrustState
|
||||
ShareKeysMinTrust id.TrustState
|
||||
|
||||
@@ -80,11 +83,11 @@ type OlmMachine struct {
|
||||
// StateStore is used by OlmMachine to get room state information that's needed for encryption.
|
||||
type StateStore interface {
|
||||
// IsEncrypted returns whether a room is encrypted.
|
||||
IsEncrypted(id.RoomID) bool
|
||||
IsEncrypted(context.Context, id.RoomID) (bool, error)
|
||||
// GetEncryptionEvent returns the encryption event's content for an encrypted room.
|
||||
GetEncryptionEvent(id.RoomID) *event.EncryptionEventContent
|
||||
GetEncryptionEvent(context.Context, id.RoomID) (*event.EncryptionEventContent, error)
|
||||
// FindSharedRooms returns the encrypted rooms that another user is also in for a user ID.
|
||||
FindSharedRooms(id.UserID) []id.RoomID
|
||||
FindSharedRooms(context.Context, id.UserID) ([]id.RoomID, error)
|
||||
}
|
||||
|
||||
// NewOlmMachine creates an OlmMachine with the given client, logger and stores.
|
||||
@@ -131,8 +134,8 @@ func (mach *OlmMachine) machOrContextLog(ctx context.Context) *zerolog.Logger {
|
||||
|
||||
// Load loads the Olm account information from the crypto store. If there's no olm account, a new one is created.
|
||||
// This must be called before using the machine.
|
||||
func (mach *OlmMachine) Load() (err error) {
|
||||
mach.account, err = mach.CryptoStore.GetAccount()
|
||||
func (mach *OlmMachine) Load(ctx context.Context) (err error) {
|
||||
mach.account, err = mach.CryptoStore.GetAccount(ctx)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
@@ -142,16 +145,16 @@ func (mach *OlmMachine) Load() (err error) {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (mach *OlmMachine) saveAccount() {
|
||||
err := mach.CryptoStore.PutAccount(mach.account)
|
||||
func (mach *OlmMachine) saveAccount(ctx context.Context) {
|
||||
err := mach.CryptoStore.PutAccount(ctx, mach.account)
|
||||
if err != nil {
|
||||
mach.Log.Error().Err(err).Msg("Failed to save account")
|
||||
}
|
||||
}
|
||||
|
||||
// FlushStore calls the Flush method of the CryptoStore.
|
||||
func (mach *OlmMachine) FlushStore() error {
|
||||
return mach.CryptoStore.Flush()
|
||||
func (mach *OlmMachine) FlushStore(ctx context.Context) error {
|
||||
return mach.CryptoStore.Flush(ctx)
|
||||
}
|
||||
|
||||
func (mach *OlmMachine) timeTrace(ctx context.Context, thing string, expectedDuration time.Duration) func() {
|
||||
@@ -194,9 +197,9 @@ func (mach *OlmMachine) OwnIdentity() *id.Device {
|
||||
}
|
||||
|
||||
type asEventProcessor interface {
|
||||
On(evtType event.Type, handler func(evt *event.Event))
|
||||
OnOTK(func(otk *mautrix.OTKCount))
|
||||
OnDeviceList(func(lists *mautrix.DeviceLists, since string))
|
||||
On(evtType event.Type, handler func(ctx context.Context, evt *event.Event))
|
||||
OnOTK(func(ctx context.Context, otk *mautrix.OTKCount))
|
||||
OnDeviceList(func(ctx context.Context, lists *mautrix.DeviceLists, since string))
|
||||
}
|
||||
|
||||
func (mach *OlmMachine) AddAppserviceListener(ep asEventProcessor) {
|
||||
@@ -217,19 +220,23 @@ func (mach *OlmMachine) AddAppserviceListener(ep asEventProcessor) {
|
||||
mach.Log.Debug().Msg("Added listeners for encryption data coming from appservice transactions")
|
||||
}
|
||||
|
||||
func (mach *OlmMachine) HandleDeviceLists(dl *mautrix.DeviceLists, since string) {
|
||||
func (mach *OlmMachine) HandleDeviceLists(ctx context.Context, dl *mautrix.DeviceLists, since string) {
|
||||
if len(dl.Changed) > 0 {
|
||||
traceID := time.Now().Format("15:04:05.000000")
|
||||
mach.Log.Debug().
|
||||
Str("trace_id", traceID).
|
||||
Interface("changes", dl.Changed).
|
||||
Msg("Device list changes in /sync")
|
||||
mach.fetchKeys(context.TODO(), dl.Changed, since, false)
|
||||
if mach.DisableKeyFetching {
|
||||
mach.CryptoStore.MarkTrackedUsersOutdated(ctx, dl.Changed)
|
||||
} else {
|
||||
mach.FetchKeys(ctx, dl.Changed, false)
|
||||
}
|
||||
mach.Log.Debug().Str("trace_id", traceID).Msg("Finished handling device list changes")
|
||||
}
|
||||
}
|
||||
|
||||
func (mach *OlmMachine) HandleOTKCounts(otkCount *mautrix.OTKCount) {
|
||||
func (mach *OlmMachine) HandleOTKCounts(ctx context.Context, otkCount *mautrix.OTKCount) {
|
||||
if (len(otkCount.UserID) > 0 && otkCount.UserID != mach.Client.UserID) || (len(otkCount.DeviceID) > 0 && otkCount.DeviceID != mach.Client.DeviceID) {
|
||||
// TODO This log probably needs to be silence-able if someone wants to use encrypted appservices with multiple e2ee sessions
|
||||
mach.Log.Warn().
|
||||
@@ -243,7 +250,7 @@ func (mach *OlmMachine) HandleOTKCounts(otkCount *mautrix.OTKCount) {
|
||||
if otkCount.SignedCurve25519 < int(minCount) {
|
||||
traceID := time.Now().Format("15:04:05.000000")
|
||||
log := mach.Log.With().Str("trace_id", traceID).Logger()
|
||||
ctx := log.WithContext(context.Background())
|
||||
ctx = log.WithContext(ctx)
|
||||
log.Debug().
|
||||
Int("keys_left", otkCount.Curve25519).
|
||||
Msg("Sync response said we have less than 50 signed curve25519 keys left, sharing new ones...")
|
||||
@@ -261,8 +268,8 @@ func (mach *OlmMachine) HandleOTKCounts(otkCount *mautrix.OTKCount) {
|
||||
// This can be easily registered into a mautrix client using .OnSync():
|
||||
//
|
||||
// client.Syncer.(mautrix.ExtensibleSyncer).OnSync(c.crypto.ProcessSyncResponse)
|
||||
func (mach *OlmMachine) ProcessSyncResponse(resp *mautrix.RespSync, since string) bool {
|
||||
mach.HandleDeviceLists(&resp.DeviceLists, since)
|
||||
func (mach *OlmMachine) ProcessSyncResponse(ctx context.Context, resp *mautrix.RespSync, since string) bool {
|
||||
mach.HandleDeviceLists(ctx, &resp.DeviceLists, since)
|
||||
|
||||
for _, evt := range resp.ToDevice.Events {
|
||||
evt.Type.Class = event.ToDeviceEventType
|
||||
@@ -271,10 +278,10 @@ func (mach *OlmMachine) ProcessSyncResponse(resp *mautrix.RespSync, since string
|
||||
mach.Log.Warn().Str("event_type", evt.Type.Type).Err(err).Msg("Failed to parse to-device event")
|
||||
continue
|
||||
}
|
||||
mach.HandleToDeviceEvent(evt)
|
||||
mach.HandleToDeviceEvent(ctx, evt)
|
||||
}
|
||||
|
||||
mach.HandleOTKCounts(&resp.DeviceOTKCount)
|
||||
mach.HandleOTKCounts(ctx, &resp.DeviceOTKCount)
|
||||
return true
|
||||
}
|
||||
|
||||
@@ -283,8 +290,12 @@ func (mach *OlmMachine) ProcessSyncResponse(resp *mautrix.RespSync, since string
|
||||
// Currently this is not automatically called, so you must add a listener yourself:
|
||||
//
|
||||
// client.Syncer.(mautrix.ExtensibleSyncer).OnEventType(event.StateMember, c.crypto.HandleMemberEvent)
|
||||
func (mach *OlmMachine) HandleMemberEvent(_ mautrix.EventSource, evt *event.Event) {
|
||||
if !mach.StateStore.IsEncrypted(evt.RoomID) {
|
||||
func (mach *OlmMachine) HandleMemberEvent(ctx context.Context, evt *event.Event) {
|
||||
if isEncrypted, err := mach.StateStore.IsEncrypted(ctx, evt.RoomID); err != nil {
|
||||
mach.machOrContextLog(ctx).Err(err).Stringer("room_id", evt.RoomID).
|
||||
Msg("Failed to check if room is encrypted to handle member event")
|
||||
return
|
||||
} else if !isEncrypted {
|
||||
return
|
||||
}
|
||||
content := evt.Content.AsMember()
|
||||
@@ -311,7 +322,7 @@ func (mach *OlmMachine) HandleMemberEvent(_ mautrix.EventSource, evt *event.Even
|
||||
Str("prev_membership", string(prevContent.Membership)).
|
||||
Str("new_membership", string(content.Membership)).
|
||||
Msg("Got membership state change, invalidating group session in room")
|
||||
err := mach.CryptoStore.RemoveOutboundGroupSession(evt.RoomID)
|
||||
err := mach.CryptoStore.RemoveOutboundGroupSession(ctx, evt.RoomID)
|
||||
if err != nil {
|
||||
mach.Log.Warn().Str("room_id", evt.RoomID.String()).Msg("Failed to invalidate outbound group session")
|
||||
}
|
||||
@@ -319,7 +330,7 @@ func (mach *OlmMachine) HandleMemberEvent(_ mautrix.EventSource, evt *event.Even
|
||||
|
||||
// HandleToDeviceEvent handles a single to-device event. This is automatically called by ProcessSyncResponse, so you
|
||||
// don't need to add any custom handlers if you use that method.
|
||||
func (mach *OlmMachine) HandleToDeviceEvent(evt *event.Event) {
|
||||
func (mach *OlmMachine) HandleToDeviceEvent(ctx context.Context, evt *event.Event) {
|
||||
if len(evt.ToUserID) > 0 && (evt.ToUserID != mach.Client.UserID || evt.ToDeviceID != mach.Client.DeviceID) {
|
||||
// TODO This log probably needs to be silence-able if someone wants to use encrypted appservices with multiple e2ee sessions
|
||||
mach.Log.Debug().
|
||||
@@ -329,12 +340,13 @@ func (mach *OlmMachine) HandleToDeviceEvent(evt *event.Event) {
|
||||
return
|
||||
}
|
||||
traceID := time.Now().Format("15:04:05.000000")
|
||||
// TODO use context log?
|
||||
log := mach.Log.With().
|
||||
Str("trace_id", traceID).
|
||||
Str("sender", evt.Sender.String()).
|
||||
Str("type", evt.Type.Type).
|
||||
Logger()
|
||||
ctx := log.WithContext(context.Background())
|
||||
ctx = log.WithContext(ctx)
|
||||
if evt.Type != event.ToDeviceEncrypted {
|
||||
log.Debug().Msg("Starting handling to-device event")
|
||||
}
|
||||
@@ -344,7 +356,7 @@ func (mach *OlmMachine) HandleToDeviceEvent(evt *event.Event) {
|
||||
Str("sender_key", content.SenderKey.String()).
|
||||
Logger()
|
||||
log.Debug().Msg("Handling encrypted to-device event")
|
||||
ctx = log.WithContext(context.Background())
|
||||
ctx = log.WithContext(ctx)
|
||||
decryptedEvt, err := mach.decryptOlmEvent(ctx, evt)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("Failed to decrypt to-device event")
|
||||
@@ -381,17 +393,17 @@ func (mach *OlmMachine) HandleToDeviceEvent(evt *event.Event) {
|
||||
mach.handleBeeperRoomKeyAck(ctx, evt.Sender, content)
|
||||
// verification cases
|
||||
case *event.VerificationStartEventContent:
|
||||
mach.handleVerificationStart(evt.Sender, content, content.TransactionID, 10*time.Minute, "")
|
||||
mach.handleVerificationStart(ctx, evt.Sender, content, content.TransactionID, 10*time.Minute, "")
|
||||
case *event.VerificationAcceptEventContent:
|
||||
mach.handleVerificationAccept(evt.Sender, content, content.TransactionID)
|
||||
mach.handleVerificationAccept(ctx, evt.Sender, content, content.TransactionID)
|
||||
case *event.VerificationKeyEventContent:
|
||||
mach.handleVerificationKey(evt.Sender, content, content.TransactionID)
|
||||
mach.handleVerificationKey(ctx, evt.Sender, content, content.TransactionID)
|
||||
case *event.VerificationMacEventContent:
|
||||
mach.handleVerificationMAC(evt.Sender, content, content.TransactionID)
|
||||
mach.handleVerificationMAC(ctx, evt.Sender, content, content.TransactionID)
|
||||
case *event.VerificationCancelEventContent:
|
||||
mach.handleVerificationCancel(evt.Sender, content, content.TransactionID)
|
||||
case *event.VerificationRequestEventContent:
|
||||
mach.handleVerificationRequest(evt.Sender, content, content.TransactionID, "")
|
||||
mach.handleVerificationRequest(ctx, evt.Sender, content, content.TransactionID, "")
|
||||
case *event.RoomKeyWithheldEventContent:
|
||||
mach.handleRoomKeyWithheld(ctx, content)
|
||||
default:
|
||||
@@ -405,14 +417,15 @@ func (mach *OlmMachine) HandleToDeviceEvent(evt *event.Event) {
|
||||
// GetOrFetchDevice attempts to retrieve the device identity for the given device from the store
|
||||
// and if it's not found it asks the server for it.
|
||||
func (mach *OlmMachine) GetOrFetchDevice(ctx context.Context, userID id.UserID, deviceID id.DeviceID) (*id.Device, error) {
|
||||
device, err := mach.CryptoStore.GetDevice(userID, deviceID)
|
||||
device, err := mach.CryptoStore.GetDevice(ctx, userID, deviceID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get sender device from store: %w", err)
|
||||
} else if device != nil {
|
||||
} else if device != nil || mach.DisableKeyFetching {
|
||||
return device, nil
|
||||
}
|
||||
usersToDevices := mach.fetchKeys(ctx, []id.UserID{userID}, "", true)
|
||||
if devices, ok := usersToDevices[userID]; ok {
|
||||
if usersToDevices, err := mach.FetchKeys(ctx, []id.UserID{userID}, true); err != nil {
|
||||
return nil, fmt.Errorf("failed to fetch keys: %w", err)
|
||||
} else if devices, ok := usersToDevices[userID]; ok {
|
||||
if device, ok = devices[deviceID]; ok {
|
||||
return device, nil
|
||||
}
|
||||
@@ -425,15 +438,15 @@ func (mach *OlmMachine) GetOrFetchDevice(ctx context.Context, userID id.UserID,
|
||||
// store and if it's not found it asks the server for it. This returns nil if the server doesn't return a device with
|
||||
// the given identity key.
|
||||
func (mach *OlmMachine) GetOrFetchDeviceByKey(ctx context.Context, userID id.UserID, identityKey id.IdentityKey) (*id.Device, error) {
|
||||
deviceIdentity, err := mach.CryptoStore.FindDeviceByKey(userID, identityKey)
|
||||
if err != nil || deviceIdentity != nil {
|
||||
deviceIdentity, err := mach.CryptoStore.FindDeviceByKey(ctx, userID, identityKey)
|
||||
if err != nil || deviceIdentity != nil || mach.DisableKeyFetching {
|
||||
return deviceIdentity, err
|
||||
}
|
||||
mach.machOrContextLog(ctx).Debug().
|
||||
Str("user_id", userID.String()).
|
||||
Str("identity_key", identityKey.String()).
|
||||
Msg("Didn't find identity in crypto store, fetching from server")
|
||||
devices := mach.LoadDevices(userID)
|
||||
devices := mach.LoadDevices(ctx, userID)
|
||||
for _, device := range devices {
|
||||
if device.IdentityKey == identityKey {
|
||||
return device, nil
|
||||
@@ -455,7 +468,7 @@ func (mach *OlmMachine) SendEncryptedToDevice(ctx context.Context, device *id.De
|
||||
mach.olmLock.Lock()
|
||||
defer mach.olmLock.Unlock()
|
||||
|
||||
olmSess, err := mach.CryptoStore.GetLatestSession(device.IdentityKey)
|
||||
olmSess, err := mach.CryptoStore.GetLatestSession(ctx, device.IdentityKey)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -473,7 +486,7 @@ func (mach *OlmMachine) SendEncryptedToDevice(ctx context.Context, device *id.De
|
||||
Str("to_identity_key", device.IdentityKey.String()).
|
||||
Str("olm_session_id", olmSess.ID().String()).
|
||||
Msg("Sending encrypted to-device event")
|
||||
_, err = mach.Client.SendToDevice(event.ToDeviceEncrypted,
|
||||
_, err = mach.Client.SendToDevice(ctx, event.ToDeviceEncrypted,
|
||||
&mautrix.ReqSendToDevice{
|
||||
Messages: map[id.UserID]map[id.DeviceID]*event.Content{
|
||||
device.UserID: {
|
||||
@@ -499,7 +512,7 @@ func (mach *OlmMachine) createGroupSession(ctx context.Context, senderKey id.Sen
|
||||
Msg("Mismatched session ID while creating inbound group session")
|
||||
return
|
||||
}
|
||||
err = mach.CryptoStore.PutGroupSession(roomID, senderKey, sessionID, igs)
|
||||
err = mach.CryptoStore.PutGroupSession(ctx, roomID, senderKey, sessionID, igs)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Str("session_id", sessionID.String()).Msg("Failed to store new inbound group session")
|
||||
return
|
||||
@@ -525,7 +538,7 @@ func (mach *OlmMachine) markSessionReceived(id id.SessionID) {
|
||||
}
|
||||
|
||||
// WaitForSession waits for the given Megolm session to arrive.
|
||||
func (mach *OlmMachine) WaitForSession(roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID, timeout time.Duration) bool {
|
||||
func (mach *OlmMachine) WaitForSession(ctx context.Context, roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID, timeout time.Duration) bool {
|
||||
mach.keyWaitersLock.Lock()
|
||||
ch, ok := mach.keyWaiters[sessionID]
|
||||
if !ok {
|
||||
@@ -534,7 +547,7 @@ func (mach *OlmMachine) WaitForSession(roomID id.RoomID, senderKey id.SenderKey,
|
||||
}
|
||||
mach.keyWaitersLock.Unlock()
|
||||
// Handle race conditions where a session appears between the failed decryption and WaitForSession call.
|
||||
sess, err := mach.CryptoStore.GetGroupSession(roomID, senderKey, sessionID)
|
||||
sess, err := mach.CryptoStore.GetGroupSession(ctx, roomID, senderKey, sessionID)
|
||||
if sess != nil || errors.Is(err, ErrGroupSessionWithheld) {
|
||||
return true
|
||||
}
|
||||
@@ -542,10 +555,12 @@ func (mach *OlmMachine) WaitForSession(roomID id.RoomID, senderKey id.SenderKey,
|
||||
case <-ch:
|
||||
return true
|
||||
case <-time.After(timeout):
|
||||
sess, err = mach.CryptoStore.GetGroupSession(roomID, senderKey, sessionID)
|
||||
sess, err = mach.CryptoStore.GetGroupSession(ctx, roomID, senderKey, sessionID)
|
||||
// Check if the session somehow appeared in the store without telling us
|
||||
// We accept withheld sessions as received, as then the decryption attempt will show the error.
|
||||
return sess != nil || errors.Is(err, ErrGroupSessionWithheld)
|
||||
case <-ctx.Done():
|
||||
return false
|
||||
}
|
||||
}
|
||||
|
||||
@@ -568,7 +583,10 @@ func (mach *OlmMachine) receiveRoomKey(ctx context.Context, evt *DecryptedOlmEve
|
||||
return
|
||||
}
|
||||
|
||||
config := mach.StateStore.GetEncryptionEvent(content.RoomID)
|
||||
config, err := mach.StateStore.GetEncryptionEvent(ctx, content.RoomID)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("Failed to get encryption event for room")
|
||||
}
|
||||
var maxAge time.Duration
|
||||
var maxMessages int
|
||||
if config != nil {
|
||||
@@ -589,7 +607,7 @@ func (mach *OlmMachine) receiveRoomKey(ctx context.Context, evt *DecryptedOlmEve
|
||||
}
|
||||
if mach.DeletePreviousKeysOnReceive && !content.IsScheduled {
|
||||
log.Debug().Msg("Redacting previous megolm sessions from sender in room")
|
||||
sessionIDs, err := mach.CryptoStore.RedactGroupSessions(content.RoomID, evt.SenderKey, "received new key from device")
|
||||
sessionIDs, err := mach.CryptoStore.RedactGroupSessions(ctx, content.RoomID, evt.SenderKey, "received new key from device")
|
||||
if err != nil {
|
||||
log.Err(err).Msg("Failed to redact previous megolm sessions")
|
||||
} else {
|
||||
@@ -606,7 +624,7 @@ func (mach *OlmMachine) handleRoomKeyWithheld(ctx context.Context, content *even
|
||||
zerolog.Ctx(ctx).Debug().Interface("content", content).Msg("Non-megolm room key withheld event")
|
||||
return
|
||||
}
|
||||
err := mach.CryptoStore.PutWithheldGroupSession(*content)
|
||||
err := mach.CryptoStore.PutWithheldGroupSession(ctx, *content)
|
||||
if err != nil {
|
||||
zerolog.Ctx(ctx).Error().Err(err).Msg("Failed to save room key withheld event")
|
||||
}
|
||||
@@ -624,7 +642,7 @@ func (mach *OlmMachine) ShareKeys(ctx context.Context, currentOTKCount int) erro
|
||||
defer mach.otkUploadLock.Unlock()
|
||||
if mach.lastOTKUpload.Add(1*time.Minute).After(start) || currentOTKCount < 0 {
|
||||
log.Debug().Msg("Checking OTK count from server due to suspiciously close share keys requests or negative OTK count")
|
||||
resp, err := mach.Client.UploadKeys(&mautrix.ReqUploadKeys{})
|
||||
resp, err := mach.Client.UploadKeys(ctx, &mautrix.ReqUploadKeys{})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to check current OTK counts: %w", err)
|
||||
}
|
||||
@@ -637,6 +655,15 @@ func (mach *OlmMachine) ShareKeys(ctx context.Context, currentOTKCount int) erro
|
||||
var deviceKeys *mautrix.DeviceKeys
|
||||
if !mach.account.Shared {
|
||||
deviceKeys = mach.account.getInitialKeys(mach.Client.UserID, mach.Client.DeviceID)
|
||||
err := mach.CryptoStore.PutDevice(ctx, mach.Client.UserID, &id.Device{
|
||||
UserID: mach.Client.UserID,
|
||||
DeviceID: mach.Client.DeviceID,
|
||||
IdentityKey: deviceKeys.Keys.GetCurve25519(mach.Client.DeviceID),
|
||||
SigningKey: deviceKeys.Keys.GetEd25519(mach.Client.DeviceID),
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to save initial keys: %w", err)
|
||||
}
|
||||
log.Debug().Msg("Going to upload initial account keys")
|
||||
}
|
||||
oneTimeKeys := mach.account.getOneTimeKeys(mach.Client.UserID, mach.Client.DeviceID, currentOTKCount)
|
||||
@@ -649,20 +676,20 @@ func (mach *OlmMachine) ShareKeys(ctx context.Context, currentOTKCount int) erro
|
||||
OneTimeKeys: oneTimeKeys,
|
||||
}
|
||||
log.Debug().Int("count", len(oneTimeKeys)).Msg("Uploading one-time keys")
|
||||
_, err := mach.Client.UploadKeys(req)
|
||||
_, err := mach.Client.UploadKeys(ctx, req)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
mach.lastOTKUpload = time.Now()
|
||||
mach.account.Shared = true
|
||||
mach.saveAccount()
|
||||
mach.saveAccount(ctx)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (mach *OlmMachine) ExpiredKeyDeleteLoop(ctx context.Context) {
|
||||
log := mach.Log.With().Str("action", "redact expired sessions").Logger()
|
||||
for {
|
||||
sessionIDs, err := mach.CryptoStore.RedactExpiredGroupSessions()
|
||||
sessionIDs, err := mach.CryptoStore.RedactExpiredGroupSessions(ctx)
|
||||
if err != nil {
|
||||
log.Err(err).Msg("Failed to redact expired megolm sessions")
|
||||
} else if len(sessionIDs) > 0 {
|
||||
|
||||
177
vendor/maunium.net/go/mautrix/crypto/olm/LICENSE
generated
vendored
177
vendor/maunium.net/go/mautrix/crypto/olm/LICENSE
generated
vendored
@@ -1,177 +0,0 @@
|
||||
|
||||
Apache License
|
||||
Version 2.0, January 2004
|
||||
http://www.apache.org/licenses/
|
||||
|
||||
TERMS AND CONDITIONS FOR USE, REPRODUCTION, AND DISTRIBUTION
|
||||
|
||||
1. Definitions.
|
||||
|
||||
"License" shall mean the terms and conditions for use, reproduction,
|
||||
and distribution as defined by Sections 1 through 9 of this document.
|
||||
|
||||
"Licensor" shall mean the copyright owner or entity authorized by
|
||||
the copyright owner that is granting the License.
|
||||
|
||||
"Legal Entity" shall mean the union of the acting entity and all
|
||||
other entities that control, are controlled by, or are under common
|
||||
control with that entity. For the purposes of this definition,
|
||||
"control" means (i) the power, direct or indirect, to cause the
|
||||
direction or management of such entity, whether by contract or
|
||||
otherwise, or (ii) ownership of fifty percent (50%) or more of the
|
||||
outstanding shares, or (iii) beneficial ownership of such entity.
|
||||
|
||||
"You" (or "Your") shall mean an individual or Legal Entity
|
||||
exercising permissions granted by this License.
|
||||
|
||||
"Source" form shall mean the preferred form for making modifications,
|
||||
including but not limited to software source code, documentation
|
||||
source, and configuration files.
|
||||
|
||||
"Object" form shall mean any form resulting from mechanical
|
||||
transformation or translation of a Source form, including but
|
||||
not limited to compiled object code, generated documentation,
|
||||
and conversions to other media types.
|
||||
|
||||
"Work" shall mean the work of authorship, whether in Source or
|
||||
Object form, made available under the License, as indicated by a
|
||||
copyright notice that is included in or attached to the work
|
||||
(an example is provided in the Appendix below).
|
||||
|
||||
"Derivative Works" shall mean any work, whether in Source or Object
|
||||
form, that is based on (or derived from) the Work and for which the
|
||||
editorial revisions, annotations, elaborations, or other modifications
|
||||
represent, as a whole, an original work of authorship. For the purposes
|
||||
of this License, Derivative Works shall not include works that remain
|
||||
separable from, or merely link (or bind by name) to the interfaces of,
|
||||
the Work and Derivative Works thereof.
|
||||
|
||||
"Contribution" shall mean any work of authorship, including
|
||||
the original version of the Work and any modifications or additions
|
||||
to that Work or Derivative Works thereof, that is intentionally
|
||||
submitted to Licensor for inclusion in the Work by the copyright owner
|
||||
or by an individual or Legal Entity authorized to submit on behalf of
|
||||
the copyright owner. For the purposes of this definition, "submitted"
|
||||
means any form of electronic, verbal, or written communication sent
|
||||
to the Licensor or its representatives, including but not limited to
|
||||
communication on electronic mailing lists, source code control systems,
|
||||
and issue tracking systems that are managed by, or on behalf of, the
|
||||
Licensor for the purpose of discussing and improving the Work, but
|
||||
excluding communication that is conspicuously marked or otherwise
|
||||
designated in writing by the copyright owner as "Not a Contribution."
|
||||
|
||||
"Contributor" shall mean Licensor and any individual or Legal Entity
|
||||
on behalf of whom a Contribution has been received by Licensor and
|
||||
subsequently incorporated within the Work.
|
||||
|
||||
2. Grant of Copyright License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
copyright license to reproduce, prepare Derivative Works of,
|
||||
publicly display, publicly perform, sublicense, and distribute the
|
||||
Work and such Derivative Works in Source or Object form.
|
||||
|
||||
3. Grant of Patent License. Subject to the terms and conditions of
|
||||
this License, each Contributor hereby grants to You a perpetual,
|
||||
worldwide, non-exclusive, no-charge, royalty-free, irrevocable
|
||||
(except as stated in this section) patent license to make, have made,
|
||||
use, offer to sell, sell, import, and otherwise transfer the Work,
|
||||
where such license applies only to those patent claims licensable
|
||||
by such Contributor that are necessarily infringed by their
|
||||
Contribution(s) alone or by combination of their Contribution(s)
|
||||
with the Work to which such Contribution(s) was submitted. If You
|
||||
institute patent litigation against any entity (including a
|
||||
cross-claim or counterclaim in a lawsuit) alleging that the Work
|
||||
or a Contribution incorporated within the Work constitutes direct
|
||||
or contributory patent infringement, then any patent licenses
|
||||
granted to You under this License for that Work shall terminate
|
||||
as of the date such litigation is filed.
|
||||
|
||||
4. Redistribution. You may reproduce and distribute copies of the
|
||||
Work or Derivative Works thereof in any medium, with or without
|
||||
modifications, and in Source or Object form, provided that You
|
||||
meet the following conditions:
|
||||
|
||||
(a) You must give any other recipients of the Work or
|
||||
Derivative Works a copy of this License; and
|
||||
|
||||
(b) You must cause any modified files to carry prominent notices
|
||||
stating that You changed the files; and
|
||||
|
||||
(c) You must retain, in the Source form of any Derivative Works
|
||||
that You distribute, all copyright, patent, trademark, and
|
||||
attribution notices from the Source form of the Work,
|
||||
excluding those notices that do not pertain to any part of
|
||||
the Derivative Works; and
|
||||
|
||||
(d) If the Work includes a "NOTICE" text file as part of its
|
||||
distribution, then any Derivative Works that You distribute must
|
||||
include a readable copy of the attribution notices contained
|
||||
within such NOTICE file, excluding those notices that do not
|
||||
pertain to any part of the Derivative Works, in at least one
|
||||
of the following places: within a NOTICE text file distributed
|
||||
as part of the Derivative Works; within the Source form or
|
||||
documentation, if provided along with the Derivative Works; or,
|
||||
within a display generated by the Derivative Works, if and
|
||||
wherever such third-party notices normally appear. The contents
|
||||
of the NOTICE file are for informational purposes only and
|
||||
do not modify the License. You may add Your own attribution
|
||||
notices within Derivative Works that You distribute, alongside
|
||||
or as an addendum to the NOTICE text from the Work, provided
|
||||
that such additional attribution notices cannot be construed
|
||||
as modifying the License.
|
||||
|
||||
You may add Your own copyright statement to Your modifications and
|
||||
may provide additional or different license terms and conditions
|
||||
for use, reproduction, or distribution of Your modifications, or
|
||||
for any such Derivative Works as a whole, provided Your use,
|
||||
reproduction, and distribution of the Work otherwise complies with
|
||||
the conditions stated in this License.
|
||||
|
||||
5. Submission of Contributions. Unless You explicitly state otherwise,
|
||||
any Contribution intentionally submitted for inclusion in the Work
|
||||
by You to the Licensor shall be under the terms and conditions of
|
||||
this License, without any additional terms or conditions.
|
||||
Notwithstanding the above, nothing herein shall supersede or modify
|
||||
the terms of any separate license agreement you may have executed
|
||||
with Licensor regarding such Contributions.
|
||||
|
||||
6. Trademarks. This License does not grant permission to use the trade
|
||||
names, trademarks, service marks, or product names of the Licensor,
|
||||
except as required for reasonable and customary use in describing the
|
||||
origin of the Work and reproducing the content of the NOTICE file.
|
||||
|
||||
7. Disclaimer of Warranty. Unless required by applicable law or
|
||||
agreed to in writing, Licensor provides the Work (and each
|
||||
Contributor provides its Contributions) on an "AS IS" BASIS,
|
||||
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or
|
||||
implied, including, without limitation, any warranties or conditions
|
||||
of TITLE, NON-INFRINGEMENT, MERCHANTABILITY, or FITNESS FOR A
|
||||
PARTICULAR PURPOSE. You are solely responsible for determining the
|
||||
appropriateness of using or redistributing the Work and assume any
|
||||
risks associated with Your exercise of permissions under this License.
|
||||
|
||||
8. Limitation of Liability. In no event and under no legal theory,
|
||||
whether in tort (including negligence), contract, or otherwise,
|
||||
unless required by applicable law (such as deliberate and grossly
|
||||
negligent acts) or agreed to in writing, shall any Contributor be
|
||||
liable to You for damages, including any direct, indirect, special,
|
||||
incidental, or consequential damages of any character arising as a
|
||||
result of this License or out of the use or inability to use the
|
||||
Work (including but not limited to damages for loss of goodwill,
|
||||
work stoppage, computer failure or malfunction, or any and all
|
||||
other commercial damages or losses), even if such Contributor
|
||||
has been advised of the possibility of such damages.
|
||||
|
||||
9. Accepting Warranty or Additional Liability. While redistributing
|
||||
the Work or Derivative Works thereof, You may choose to offer,
|
||||
and charge a fee for, acceptance of support, warranty, indemnity,
|
||||
or other liability obligations and/or rights consistent with this
|
||||
License. However, in accepting such obligations, You may act only
|
||||
on Your own behalf and on Your sole responsibility, not on behalf
|
||||
of any other Contributor, and only if You agree to indemnify,
|
||||
defend, and hold each Contributor harmless for any liability
|
||||
incurred by, or claims asserted against, such Contributor by reason
|
||||
of your accepting any such warranty or additional liability.
|
||||
|
||||
END OF TERMS AND CONDITIONS
|
||||
2
vendor/maunium.net/go/mautrix/crypto/olm/README.md
generated
vendored
2
vendor/maunium.net/go/mautrix/crypto/olm/README.md
generated
vendored
@@ -1,2 +1,4 @@
|
||||
# Go olm bindings
|
||||
Based on [Dhole/go-olm](https://github.com/Dhole/go-olm)
|
||||
|
||||
The original project is licensed under the Apache 2.0 license.
|
||||
|
||||
6
vendor/maunium.net/go/mautrix/crypto/olm/account.go
generated
vendored
6
vendor/maunium.net/go/mautrix/crypto/olm/account.go
generated
vendored
@@ -1,3 +1,5 @@
|
||||
//go:build !goolm
|
||||
|
||||
package olm
|
||||
|
||||
// #cgo LDFLAGS: -lolm -lstdc++
|
||||
@@ -155,6 +157,7 @@ func (a *Account) Unpickle(pickled, key []byte) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Deprecated
|
||||
func (a *Account) GobEncode() ([]byte, error) {
|
||||
pickled := a.Pickle(pickleKey)
|
||||
length := base64.RawStdEncoding.DecodedLen(len(pickled))
|
||||
@@ -163,6 +166,7 @@ func (a *Account) GobEncode() ([]byte, error) {
|
||||
return rawPickled, err
|
||||
}
|
||||
|
||||
// Deprecated
|
||||
func (a *Account) GobDecode(rawPickled []byte) error {
|
||||
if a.int == nil {
|
||||
*a = *NewBlankAccount()
|
||||
@@ -173,6 +177,7 @@ func (a *Account) GobDecode(rawPickled []byte) error {
|
||||
return a.Unpickle(pickled, pickleKey)
|
||||
}
|
||||
|
||||
// Deprecated
|
||||
func (a *Account) MarshalJSON() ([]byte, error) {
|
||||
pickled := a.Pickle(pickleKey)
|
||||
quotes := make([]byte, len(pickled)+2)
|
||||
@@ -182,6 +187,7 @@ func (a *Account) MarshalJSON() ([]byte, error) {
|
||||
return quotes, nil
|
||||
}
|
||||
|
||||
// Deprecated
|
||||
func (a *Account) UnmarshalJSON(data []byte) error {
|
||||
if len(data) == 0 || data[0] != '"' || data[len(data)-1] != '"' {
|
||||
return InputNotJSONString
|
||||
|
||||
154
vendor/maunium.net/go/mautrix/crypto/olm/account_goolm.go
generated
vendored
Normal file
154
vendor/maunium.net/go/mautrix/crypto/olm/account_goolm.go
generated
vendored
Normal file
@@ -0,0 +1,154 @@
|
||||
//go:build goolm
|
||||
|
||||
package olm
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
|
||||
"github.com/tidwall/sjson"
|
||||
|
||||
"maunium.net/go/mautrix/crypto/canonicaljson"
|
||||
"maunium.net/go/mautrix/crypto/goolm/account"
|
||||
"maunium.net/go/mautrix/id"
|
||||
)
|
||||
|
||||
// Account stores a device account for end to end encrypted messaging.
|
||||
type Account struct {
|
||||
account.Account
|
||||
}
|
||||
|
||||
// NewAccount creates a new Account.
|
||||
func NewAccount() *Account {
|
||||
a, err := account.NewAccount(nil)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
ac := &Account{}
|
||||
ac.Account = *a
|
||||
return ac
|
||||
}
|
||||
|
||||
func NewBlankAccount() *Account {
|
||||
return &Account{}
|
||||
}
|
||||
|
||||
// Clear clears the memory used to back this Account.
|
||||
func (a *Account) Clear() error {
|
||||
a.Account = account.Account{}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Pickle returns an Account as a base64 string. Encrypts the Account using the
|
||||
// supplied key.
|
||||
func (a *Account) Pickle(key []byte) []byte {
|
||||
if len(key) == 0 {
|
||||
panic(NoKeyProvided)
|
||||
}
|
||||
pickled, err := a.Account.Pickle(key)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return pickled
|
||||
}
|
||||
|
||||
// IdentityKeysJSON returns the public parts of the identity keys for the Account.
|
||||
func (a *Account) IdentityKeysJSON() []byte {
|
||||
identityKeys, err := a.Account.IdentityKeysJSON()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return identityKeys
|
||||
}
|
||||
|
||||
// Sign returns the signature of a message using the ed25519 key for this
|
||||
// Account.
|
||||
func (a *Account) Sign(message []byte) []byte {
|
||||
if len(message) == 0 {
|
||||
panic(EmptyInput)
|
||||
}
|
||||
signature, err := a.Account.Sign(message)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return signature
|
||||
}
|
||||
|
||||
// SignJSON signs the given JSON object following the Matrix specification:
|
||||
// https://matrix.org/docs/spec/appendices#signing-json
|
||||
func (a *Account) SignJSON(obj interface{}) (string, error) {
|
||||
objJSON, err := json.Marshal(obj)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
objJSON, _ = sjson.DeleteBytes(objJSON, "unsigned")
|
||||
objJSON, _ = sjson.DeleteBytes(objJSON, "signatures")
|
||||
return string(a.Sign(canonicaljson.CanonicalJSONAssumeValid(objJSON))), nil
|
||||
}
|
||||
|
||||
// MaxNumberOfOneTimeKeys returns the largest number of one time keys this
|
||||
// Account can store.
|
||||
func (a *Account) MaxNumberOfOneTimeKeys() uint {
|
||||
return uint(account.MaxOneTimeKeys)
|
||||
}
|
||||
|
||||
// GenOneTimeKeys generates a number of new one time keys. If the total number
|
||||
// of keys stored by this Account exceeds MaxNumberOfOneTimeKeys then the old
|
||||
// keys are discarded.
|
||||
func (a *Account) GenOneTimeKeys(num uint) {
|
||||
err := a.Account.GenOneTimeKeys(nil, num)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}
|
||||
|
||||
// NewOutboundSession creates a new out-bound session for sending messages to a
|
||||
// given curve25519 identityKey and oneTimeKey. Returns error on failure.
|
||||
func (a *Account) NewOutboundSession(theirIdentityKey, theirOneTimeKey id.Curve25519) (*Session, error) {
|
||||
if len(theirIdentityKey) == 0 || len(theirOneTimeKey) == 0 {
|
||||
return nil, EmptyInput
|
||||
}
|
||||
s := &Session{}
|
||||
newSession, err := a.Account.NewOutboundSession(theirIdentityKey, theirOneTimeKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
s.OlmSession = *newSession
|
||||
return s, nil
|
||||
}
|
||||
|
||||
// NewInboundSession creates a new in-bound session for sending/receiving
|
||||
// messages from an incoming PRE_KEY message. Returns error on failure.
|
||||
func (a *Account) NewInboundSession(oneTimeKeyMsg string) (*Session, error) {
|
||||
if len(oneTimeKeyMsg) == 0 {
|
||||
return nil, EmptyInput
|
||||
}
|
||||
s := &Session{}
|
||||
newSession, err := a.Account.NewInboundSession(nil, []byte(oneTimeKeyMsg))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
s.OlmSession = *newSession
|
||||
return s, nil
|
||||
}
|
||||
|
||||
// NewInboundSessionFrom creates a new in-bound session for sending/receiving
|
||||
// messages from an incoming PRE_KEY message. Returns error on failure.
|
||||
func (a *Account) NewInboundSessionFrom(theirIdentityKey id.Curve25519, oneTimeKeyMsg string) (*Session, error) {
|
||||
if len(theirIdentityKey) == 0 || len(oneTimeKeyMsg) == 0 {
|
||||
return nil, EmptyInput
|
||||
}
|
||||
s := &Session{}
|
||||
newSession, err := a.Account.NewInboundSession(&theirIdentityKey, []byte(oneTimeKeyMsg))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
s.OlmSession = *newSession
|
||||
return s, nil
|
||||
}
|
||||
|
||||
// RemoveOneTimeKeys removes the one time keys that the session used from the
|
||||
// Account. Returns error on failure.
|
||||
func (a *Account) RemoveOneTimeKeys(s *Session) error {
|
||||
a.Account.RemoveOneTimeKeys(&s.OlmSession)
|
||||
return nil
|
||||
}
|
||||
2
vendor/maunium.net/go/mautrix/crypto/olm/error.go
generated
vendored
2
vendor/maunium.net/go/mautrix/crypto/olm/error.go
generated
vendored
@@ -1,3 +1,5 @@
|
||||
//go:build !goolm
|
||||
|
||||
package olm
|
||||
|
||||
import (
|
||||
|
||||
23
vendor/maunium.net/go/mautrix/crypto/olm/error_goolm.go
generated
vendored
Normal file
23
vendor/maunium.net/go/mautrix/crypto/olm/error_goolm.go
generated
vendored
Normal file
@@ -0,0 +1,23 @@
|
||||
//go:build goolm
|
||||
|
||||
package olm
|
||||
|
||||
import (
|
||||
"errors"
|
||||
|
||||
"maunium.net/go/mautrix/crypto/goolm"
|
||||
)
|
||||
|
||||
// Error codes from go-olm
|
||||
var (
|
||||
EmptyInput = goolm.ErrEmptyInput
|
||||
NoKeyProvided = goolm.ErrNoKeyProvided
|
||||
NotEnoughGoRandom = errors.New("couldn't get enough randomness from crypto/rand")
|
||||
SignatureNotFound = errors.New("input JSON doesn't contain signature from specified device")
|
||||
InputNotJSONString = errors.New("input doesn't look like a JSON string")
|
||||
)
|
||||
|
||||
// Error codes from olm code
|
||||
var (
|
||||
UnknownMessageIndex = goolm.ErrRatchetNotAvailable
|
||||
)
|
||||
6
vendor/maunium.net/go/mautrix/crypto/olm/inboundgroupsession.go
generated
vendored
6
vendor/maunium.net/go/mautrix/crypto/olm/inboundgroupsession.go
generated
vendored
@@ -1,3 +1,5 @@
|
||||
//go:build !goolm
|
||||
|
||||
package olm
|
||||
|
||||
// #cgo LDFLAGS: -lolm -lstdc++
|
||||
@@ -147,6 +149,7 @@ func (s *InboundGroupSession) Unpickle(pickled, key []byte) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Deprecated
|
||||
func (s *InboundGroupSession) GobEncode() ([]byte, error) {
|
||||
pickled := s.Pickle(pickleKey)
|
||||
length := base64.RawStdEncoding.DecodedLen(len(pickled))
|
||||
@@ -155,6 +158,7 @@ func (s *InboundGroupSession) GobEncode() ([]byte, error) {
|
||||
return rawPickled, err
|
||||
}
|
||||
|
||||
// Deprecated
|
||||
func (s *InboundGroupSession) GobDecode(rawPickled []byte) error {
|
||||
if s == nil || s.int == nil {
|
||||
*s = *NewBlankInboundGroupSession()
|
||||
@@ -165,6 +169,7 @@ func (s *InboundGroupSession) GobDecode(rawPickled []byte) error {
|
||||
return s.Unpickle(pickled, pickleKey)
|
||||
}
|
||||
|
||||
// Deprecated
|
||||
func (s *InboundGroupSession) MarshalJSON() ([]byte, error) {
|
||||
pickled := s.Pickle(pickleKey)
|
||||
quotes := make([]byte, len(pickled)+2)
|
||||
@@ -174,6 +179,7 @@ func (s *InboundGroupSession) MarshalJSON() ([]byte, error) {
|
||||
return quotes, nil
|
||||
}
|
||||
|
||||
// Deprecated
|
||||
func (s *InboundGroupSession) UnmarshalJSON(data []byte) error {
|
||||
if len(data) == 0 || data[0] != '"' || data[len(data)-1] != '"' {
|
||||
return InputNotJSONString
|
||||
|
||||
149
vendor/maunium.net/go/mautrix/crypto/olm/inboundgroupsession_goolm.go
generated
vendored
Normal file
149
vendor/maunium.net/go/mautrix/crypto/olm/inboundgroupsession_goolm.go
generated
vendored
Normal file
@@ -0,0 +1,149 @@
|
||||
//go:build goolm
|
||||
|
||||
package olm
|
||||
|
||||
import (
|
||||
"maunium.net/go/mautrix/crypto/goolm/session"
|
||||
"maunium.net/go/mautrix/id"
|
||||
)
|
||||
|
||||
// InboundGroupSession stores an inbound encrypted messaging session for a
|
||||
// group.
|
||||
type InboundGroupSession struct {
|
||||
session.MegolmInboundSession
|
||||
}
|
||||
|
||||
// InboundGroupSessionFromPickled loads an InboundGroupSession from a pickled
|
||||
// base64 string. Decrypts the InboundGroupSession using the supplied key.
|
||||
// Returns error on failure.
|
||||
func InboundGroupSessionFromPickled(pickled, key []byte) (*InboundGroupSession, error) {
|
||||
if len(pickled) == 0 {
|
||||
return nil, EmptyInput
|
||||
}
|
||||
lenKey := len(key)
|
||||
if lenKey == 0 {
|
||||
key = []byte(" ")
|
||||
}
|
||||
megolmSession, err := session.MegolmInboundSessionFromPickled(pickled, key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &InboundGroupSession{
|
||||
MegolmInboundSession: *megolmSession,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// NewInboundGroupSession creates a new inbound group session from a key
|
||||
// exported from OutboundGroupSession.Key(). Returns error on failure.
|
||||
func NewInboundGroupSession(sessionKey []byte) (*InboundGroupSession, error) {
|
||||
if len(sessionKey) == 0 {
|
||||
return nil, EmptyInput
|
||||
}
|
||||
megolmSession, err := session.NewMegolmInboundSession(sessionKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &InboundGroupSession{
|
||||
MegolmInboundSession: *megolmSession,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// InboundGroupSessionImport imports an inbound group session from a previous
|
||||
// export. Returns error on failure.
|
||||
func InboundGroupSessionImport(sessionKey []byte) (*InboundGroupSession, error) {
|
||||
if len(sessionKey) == 0 {
|
||||
return nil, EmptyInput
|
||||
}
|
||||
megolmSession, err := session.NewMegolmInboundSessionFromExport(sessionKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &InboundGroupSession{
|
||||
MegolmInboundSession: *megolmSession,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func NewBlankInboundGroupSession() *InboundGroupSession {
|
||||
return &InboundGroupSession{}
|
||||
}
|
||||
|
||||
// Clear clears the memory used to back this InboundGroupSession.
|
||||
func (s *InboundGroupSession) Clear() error {
|
||||
s.MegolmInboundSession = session.MegolmInboundSession{}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Pickle returns an InboundGroupSession as a base64 string. Encrypts the
|
||||
// InboundGroupSession using the supplied key.
|
||||
func (s *InboundGroupSession) Pickle(key []byte) []byte {
|
||||
if len(key) == 0 {
|
||||
panic(NoKeyProvided)
|
||||
}
|
||||
pickled, err := s.MegolmInboundSession.Pickle(key)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return pickled
|
||||
}
|
||||
|
||||
func (s *InboundGroupSession) Unpickle(pickled, key []byte) error {
|
||||
if len(key) == 0 {
|
||||
return NoKeyProvided
|
||||
} else if len(pickled) == 0 {
|
||||
return EmptyInput
|
||||
}
|
||||
sOlm, err := session.MegolmInboundSessionFromPickled(pickled, key)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
s.MegolmInboundSession = *sOlm
|
||||
return nil
|
||||
}
|
||||
|
||||
// Decrypt decrypts a message using the InboundGroupSession. Returns the the
|
||||
// plain-text and message index on success. Returns error on failure.
|
||||
func (s *InboundGroupSession) Decrypt(message []byte) ([]byte, uint, error) {
|
||||
if len(message) == 0 {
|
||||
return nil, 0, EmptyInput
|
||||
}
|
||||
plaintext, messageIndex, err := s.MegolmInboundSession.Decrypt(message)
|
||||
if err != nil {
|
||||
return nil, 0, err
|
||||
}
|
||||
return plaintext, uint(messageIndex), nil
|
||||
}
|
||||
|
||||
// ID returns a base64-encoded identifier for this session.
|
||||
func (s *InboundGroupSession) ID() id.SessionID {
|
||||
return s.MegolmInboundSession.SessionID()
|
||||
}
|
||||
|
||||
// FirstKnownIndex returns the first message index we know how to decrypt.
|
||||
func (s *InboundGroupSession) FirstKnownIndex() uint32 {
|
||||
return s.MegolmInboundSession.InitialRatchet.Counter
|
||||
}
|
||||
|
||||
// IsVerified check if the session has been verified as a valid session. (A
|
||||
// session is verified either because the original session share was signed, or
|
||||
// because we have subsequently successfully decrypted a message.)
|
||||
func (s *InboundGroupSession) IsVerified() uint {
|
||||
if s.MegolmInboundSession.SigningKeyVerified {
|
||||
return 1
|
||||
}
|
||||
return 0
|
||||
}
|
||||
|
||||
// Export returns the base64-encoded ratchet key for this session, at the given
|
||||
// index, in a format which can be used by
|
||||
// InboundGroupSession.InboundGroupSessionImport(). Encrypts the
|
||||
// InboundGroupSession using the supplied key. Returns error on failure.
|
||||
// if we do not have a session key corresponding to the given index (ie, it was
|
||||
// sent before the session key was shared with us) the error will be
|
||||
// returned.
|
||||
func (s *InboundGroupSession) Export(messageIndex uint32) ([]byte, error) {
|
||||
res, err := s.MegolmInboundSession.SessionExportMessage(messageIndex)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return res, nil
|
||||
}
|
||||
2
vendor/maunium.net/go/mautrix/crypto/olm/olm.go
generated
vendored
2
vendor/maunium.net/go/mautrix/crypto/olm/olm.go
generated
vendored
@@ -1,3 +1,5 @@
|
||||
//go:build !goolm
|
||||
|
||||
package olm
|
||||
|
||||
// #cgo LDFLAGS: -lolm -lstdc++
|
||||
|
||||
20
vendor/maunium.net/go/mautrix/crypto/olm/olm_goolm.go
generated
vendored
Normal file
20
vendor/maunium.net/go/mautrix/crypto/olm/olm_goolm.go
generated
vendored
Normal file
@@ -0,0 +1,20 @@
|
||||
//go:build goolm
|
||||
|
||||
package olm
|
||||
|
||||
import (
|
||||
"maunium.net/go/mautrix/id"
|
||||
)
|
||||
|
||||
// Signatures is the data structure used to sign JSON objects.
|
||||
type Signatures map[id.UserID]map[id.DeviceKeyID]string
|
||||
|
||||
// Version returns the version number of the olm library.
|
||||
func Version() (major, minor, patch uint8) {
|
||||
return 3, 2, 15
|
||||
}
|
||||
|
||||
// SetPickleKey sets the global pickle key used when encoding structs with Gob or JSON.
|
||||
func SetPickleKey(key []byte) {
|
||||
panic("gob and json encoding is deprecated and not supported with goolm")
|
||||
}
|
||||
6
vendor/maunium.net/go/mautrix/crypto/olm/outboundgroupsession.go
generated
vendored
6
vendor/maunium.net/go/mautrix/crypto/olm/outboundgroupsession.go
generated
vendored
@@ -1,3 +1,5 @@
|
||||
//go:build !goolm
|
||||
|
||||
package olm
|
||||
|
||||
// #cgo LDFLAGS: -lolm -lstdc++
|
||||
@@ -122,6 +124,7 @@ func (s *OutboundGroupSession) Unpickle(pickled, key []byte) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Deprecated
|
||||
func (s *OutboundGroupSession) GobEncode() ([]byte, error) {
|
||||
pickled := s.Pickle(pickleKey)
|
||||
length := base64.RawStdEncoding.DecodedLen(len(pickled))
|
||||
@@ -130,6 +133,7 @@ func (s *OutboundGroupSession) GobEncode() ([]byte, error) {
|
||||
return rawPickled, err
|
||||
}
|
||||
|
||||
// Deprecated
|
||||
func (s *OutboundGroupSession) GobDecode(rawPickled []byte) error {
|
||||
if s == nil || s.int == nil {
|
||||
*s = *NewBlankOutboundGroupSession()
|
||||
@@ -140,6 +144,7 @@ func (s *OutboundGroupSession) GobDecode(rawPickled []byte) error {
|
||||
return s.Unpickle(pickled, pickleKey)
|
||||
}
|
||||
|
||||
// Deprecated
|
||||
func (s *OutboundGroupSession) MarshalJSON() ([]byte, error) {
|
||||
pickled := s.Pickle(pickleKey)
|
||||
quotes := make([]byte, len(pickled)+2)
|
||||
@@ -149,6 +154,7 @@ func (s *OutboundGroupSession) MarshalJSON() ([]byte, error) {
|
||||
return quotes, nil
|
||||
}
|
||||
|
||||
// Deprecated
|
||||
func (s *OutboundGroupSession) UnmarshalJSON(data []byte) error {
|
||||
if len(data) == 0 || data[0] != '"' || data[len(data)-1] != '"' {
|
||||
return InputNotJSONString
|
||||
|
||||
111
vendor/maunium.net/go/mautrix/crypto/olm/outboundgroupsession_goolm.go
generated
vendored
Normal file
111
vendor/maunium.net/go/mautrix/crypto/olm/outboundgroupsession_goolm.go
generated
vendored
Normal file
@@ -0,0 +1,111 @@
|
||||
//go:build goolm
|
||||
|
||||
package olm
|
||||
|
||||
import (
|
||||
"maunium.net/go/mautrix/crypto/goolm/session"
|
||||
"maunium.net/go/mautrix/id"
|
||||
)
|
||||
|
||||
// OutboundGroupSession stores an outbound encrypted messaging session for a
|
||||
// group.
|
||||
type OutboundGroupSession struct {
|
||||
session.MegolmOutboundSession
|
||||
}
|
||||
|
||||
// OutboundGroupSessionFromPickled loads an OutboundGroupSession from a pickled
|
||||
// base64 string. Decrypts the OutboundGroupSession using the supplied key.
|
||||
// Returns error on failure. If the key doesn't match the one used to encrypt
|
||||
// the OutboundGroupSession then the error will be "BAD_SESSION_KEY". If the
|
||||
// base64 couldn't be decoded then the error will be "INVALID_BASE64".
|
||||
func OutboundGroupSessionFromPickled(pickled, key []byte) (*OutboundGroupSession, error) {
|
||||
if len(pickled) == 0 {
|
||||
return nil, EmptyInput
|
||||
}
|
||||
lenKey := len(key)
|
||||
if lenKey == 0 {
|
||||
key = []byte(" ")
|
||||
}
|
||||
megolmSession, err := session.MegolmOutboundSessionFromPickled(pickled, key)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &OutboundGroupSession{
|
||||
MegolmOutboundSession: *megolmSession,
|
||||
}, nil
|
||||
}
|
||||
|
||||
// NewOutboundGroupSession creates a new outbound group session.
|
||||
func NewOutboundGroupSession() *OutboundGroupSession {
|
||||
megolmSession, err := session.NewMegolmOutboundSession()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return &OutboundGroupSession{
|
||||
MegolmOutboundSession: *megolmSession,
|
||||
}
|
||||
}
|
||||
|
||||
// newOutboundGroupSession initialises an empty OutboundGroupSession.
|
||||
func NewBlankOutboundGroupSession() *OutboundGroupSession {
|
||||
return &OutboundGroupSession{}
|
||||
}
|
||||
|
||||
// Clear clears the memory used to back this OutboundGroupSession.
|
||||
func (s *OutboundGroupSession) Clear() error {
|
||||
s.MegolmOutboundSession = session.MegolmOutboundSession{}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Pickle returns an OutboundGroupSession as a base64 string. Encrypts the
|
||||
// OutboundGroupSession using the supplied key.
|
||||
func (s *OutboundGroupSession) Pickle(key []byte) []byte {
|
||||
if len(key) == 0 {
|
||||
panic(NoKeyProvided)
|
||||
}
|
||||
pickled, err := s.MegolmOutboundSession.Pickle(key)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return pickled
|
||||
}
|
||||
|
||||
func (s *OutboundGroupSession) Unpickle(pickled, key []byte) error {
|
||||
if len(key) == 0 {
|
||||
return NoKeyProvided
|
||||
}
|
||||
return s.MegolmOutboundSession.Unpickle(pickled, key)
|
||||
}
|
||||
|
||||
// Encrypt encrypts a message using the Session. Returns the encrypted message
|
||||
// as base64.
|
||||
func (s *OutboundGroupSession) Encrypt(plaintext []byte) []byte {
|
||||
if len(plaintext) == 0 {
|
||||
panic(EmptyInput)
|
||||
}
|
||||
message, err := s.MegolmOutboundSession.Encrypt(plaintext)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return message
|
||||
}
|
||||
|
||||
// ID returns a base64-encoded identifier for this session.
|
||||
func (s *OutboundGroupSession) ID() id.SessionID {
|
||||
return s.MegolmOutboundSession.SessionID()
|
||||
}
|
||||
|
||||
// MessageIndex returns the message index for this session. Each message is
|
||||
// sent with an increasing index; this returns the index for the next message.
|
||||
func (s *OutboundGroupSession) MessageIndex() uint {
|
||||
return uint(s.MegolmOutboundSession.Ratchet.Counter)
|
||||
}
|
||||
|
||||
// Key returns the base64-encoded current ratchet key for this session.
|
||||
func (s *OutboundGroupSession) Key() string {
|
||||
message, err := s.MegolmOutboundSession.SessionSharingMessage()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return string(message)
|
||||
}
|
||||
2
vendor/maunium.net/go/mautrix/crypto/olm/pk.go
generated
vendored
2
vendor/maunium.net/go/mautrix/crypto/olm/pk.go
generated
vendored
@@ -1,3 +1,5 @@
|
||||
//go:build !goolm
|
||||
|
||||
package olm
|
||||
|
||||
// #cgo LDFLAGS: -lolm -lstdc++
|
||||
|
||||
71
vendor/maunium.net/go/mautrix/crypto/olm/pk_goolm.go
generated
vendored
Normal file
71
vendor/maunium.net/go/mautrix/crypto/olm/pk_goolm.go
generated
vendored
Normal file
@@ -0,0 +1,71 @@
|
||||
//go:build goolm
|
||||
|
||||
package olm
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
|
||||
"github.com/tidwall/sjson"
|
||||
|
||||
"maunium.net/go/mautrix/crypto/canonicaljson"
|
||||
"maunium.net/go/mautrix/crypto/goolm/pk"
|
||||
"maunium.net/go/mautrix/id"
|
||||
)
|
||||
|
||||
// PkSigning stores a key pair for signing messages.
|
||||
type PkSigning struct {
|
||||
pk.Signing
|
||||
PublicKey id.Ed25519
|
||||
Seed []byte
|
||||
}
|
||||
|
||||
// Clear clears the underlying memory of a PkSigning object.
|
||||
func (p *PkSigning) Clear() {
|
||||
p.Signing = pk.Signing{}
|
||||
}
|
||||
|
||||
// NewPkSigningFromSeed creates a new PkSigning object using the given seed.
|
||||
func NewPkSigningFromSeed(seed []byte) (*PkSigning, error) {
|
||||
p := &PkSigning{}
|
||||
signing, err := pk.NewSigningFromSeed(seed)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
p.Signing = *signing
|
||||
p.Seed = seed
|
||||
p.PublicKey = p.Signing.PublicKey()
|
||||
return p, nil
|
||||
}
|
||||
|
||||
// NewPkSigning creates a new PkSigning object, containing a key pair for signing messages.
|
||||
func NewPkSigning() (*PkSigning, error) {
|
||||
p := &PkSigning{}
|
||||
signing, err := pk.NewSigning()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
p.Signing = *signing
|
||||
p.Seed = signing.Seed
|
||||
p.PublicKey = p.Signing.PublicKey()
|
||||
return p, err
|
||||
}
|
||||
|
||||
// Sign creates a signature for the given message using this key.
|
||||
func (p *PkSigning) Sign(message []byte) ([]byte, error) {
|
||||
return p.Signing.Sign(message), nil
|
||||
}
|
||||
|
||||
// SignJSON creates a signature for the given object after encoding it to canonical JSON.
|
||||
func (p *PkSigning) SignJSON(obj interface{}) (string, error) {
|
||||
objJSON, err := json.Marshal(obj)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
objJSON, _ = sjson.DeleteBytes(objJSON, "unsigned")
|
||||
objJSON, _ = sjson.DeleteBytes(objJSON, "signatures")
|
||||
signature, err := p.Sign(canonicaljson.CanonicalJSONAssumeValid(objJSON))
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return string(signature), nil
|
||||
}
|
||||
6
vendor/maunium.net/go/mautrix/crypto/olm/session.go
generated
vendored
6
vendor/maunium.net/go/mautrix/crypto/olm/session.go
generated
vendored
@@ -1,3 +1,5 @@
|
||||
//go:build !goolm
|
||||
|
||||
package olm
|
||||
|
||||
// #cgo LDFLAGS: -lolm -lstdc++
|
||||
@@ -155,6 +157,7 @@ func (s *Session) Unpickle(pickled, key []byte) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// Deprecated
|
||||
func (s *Session) GobEncode() ([]byte, error) {
|
||||
pickled := s.Pickle(pickleKey)
|
||||
length := base64.RawStdEncoding.DecodedLen(len(pickled))
|
||||
@@ -163,6 +166,7 @@ func (s *Session) GobEncode() ([]byte, error) {
|
||||
return rawPickled, err
|
||||
}
|
||||
|
||||
// Deprecated
|
||||
func (s *Session) GobDecode(rawPickled []byte) error {
|
||||
if s == nil || s.int == nil {
|
||||
*s = *NewBlankSession()
|
||||
@@ -173,6 +177,7 @@ func (s *Session) GobDecode(rawPickled []byte) error {
|
||||
return s.Unpickle(pickled, pickleKey)
|
||||
}
|
||||
|
||||
// Deprecated
|
||||
func (s *Session) MarshalJSON() ([]byte, error) {
|
||||
pickled := s.Pickle(pickleKey)
|
||||
quotes := make([]byte, len(pickled)+2)
|
||||
@@ -182,6 +187,7 @@ func (s *Session) MarshalJSON() ([]byte, error) {
|
||||
return quotes, nil
|
||||
}
|
||||
|
||||
// Deprecated
|
||||
func (s *Session) UnmarshalJSON(data []byte) error {
|
||||
if len(data) == 0 || len(data) == 0 || data[0] != '"' || data[len(data)-1] != '"' {
|
||||
return InputNotJSONString
|
||||
|
||||
110
vendor/maunium.net/go/mautrix/crypto/olm/session_goolm.go
generated
vendored
Normal file
110
vendor/maunium.net/go/mautrix/crypto/olm/session_goolm.go
generated
vendored
Normal file
@@ -0,0 +1,110 @@
|
||||
//go:build goolm
|
||||
|
||||
package olm
|
||||
|
||||
import (
|
||||
"maunium.net/go/mautrix/crypto/goolm/session"
|
||||
"maunium.net/go/mautrix/id"
|
||||
)
|
||||
|
||||
// Session stores an end to end encrypted messaging session.
|
||||
type Session struct {
|
||||
session.OlmSession
|
||||
}
|
||||
|
||||
// SessionFromPickled loads a Session from a pickled base64 string. Decrypts
|
||||
// the Session using the supplied key. Returns error on failure.
|
||||
func SessionFromPickled(pickled, key []byte) (*Session, error) {
|
||||
if len(pickled) == 0 {
|
||||
return nil, EmptyInput
|
||||
}
|
||||
s := NewBlankSession()
|
||||
return s, s.Unpickle(pickled, key)
|
||||
}
|
||||
|
||||
func NewBlankSession() *Session {
|
||||
return &Session{}
|
||||
}
|
||||
|
||||
// Clear clears the memory used to back this Session.
|
||||
func (s *Session) Clear() error {
|
||||
s.OlmSession = session.OlmSession{}
|
||||
return nil
|
||||
}
|
||||
|
||||
// Pickle returns a Session as a base64 string. Encrypts the Session using the
|
||||
// supplied key.
|
||||
func (s *Session) Pickle(key []byte) []byte {
|
||||
if len(key) == 0 {
|
||||
panic(NoKeyProvided)
|
||||
}
|
||||
pickled, err := s.OlmSession.Pickle(key)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return pickled
|
||||
}
|
||||
|
||||
func (s *Session) Unpickle(pickled, key []byte) error {
|
||||
if len(key) == 0 {
|
||||
return NoKeyProvided
|
||||
} else if len(pickled) == 0 {
|
||||
return EmptyInput
|
||||
}
|
||||
sOlm, err := session.OlmSessionFromPickled(pickled, key)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
s.OlmSession = *sOlm
|
||||
return nil
|
||||
}
|
||||
|
||||
// MatchesInboundSession checks if the PRE_KEY message is for this in-bound
|
||||
// Session. This can happen if multiple messages are sent to this Account
|
||||
// before this Account sends a message in reply. Returns true if the session
|
||||
// matches. Returns false if the session does not match. Returns error on
|
||||
// failure.
|
||||
func (s *Session) MatchesInboundSession(oneTimeKeyMsg string) (bool, error) {
|
||||
return s.MatchesInboundSessionFrom("", oneTimeKeyMsg)
|
||||
}
|
||||
|
||||
// MatchesInboundSessionFrom checks if the PRE_KEY message is for this in-bound
|
||||
// Session. This can happen if multiple messages are sent to this Account
|
||||
// before this Account sends a message in reply. Returns true if the session
|
||||
// matches. Returns false if the session does not match. Returns error on
|
||||
// failure.
|
||||
func (s *Session) MatchesInboundSessionFrom(theirIdentityKey, oneTimeKeyMsg string) (bool, error) {
|
||||
if theirIdentityKey != "" {
|
||||
theirKey := id.Curve25519(theirIdentityKey)
|
||||
return s.OlmSession.MatchesInboundSessionFrom(&theirKey, []byte(oneTimeKeyMsg))
|
||||
}
|
||||
return s.OlmSession.MatchesInboundSessionFrom(nil, []byte(oneTimeKeyMsg))
|
||||
|
||||
}
|
||||
|
||||
// Encrypt encrypts a message using the Session. Returns the encrypted message
|
||||
// as base64.
|
||||
func (s *Session) Encrypt(plaintext []byte) (id.OlmMsgType, []byte) {
|
||||
if len(plaintext) == 0 {
|
||||
panic(EmptyInput)
|
||||
}
|
||||
messageType, message, err := s.OlmSession.Encrypt(plaintext, nil)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return messageType, message
|
||||
}
|
||||
|
||||
// Decrypt decrypts a message using the Session. Returns the the plain-text on
|
||||
// success. Returns error on failure.
|
||||
func (s *Session) Decrypt(message string, msgType id.OlmMsgType) ([]byte, error) {
|
||||
if len(message) == 0 {
|
||||
return nil, EmptyInput
|
||||
}
|
||||
return s.OlmSession.Decrypt([]byte(message), msgType)
|
||||
}
|
||||
|
||||
// Describe generates a string describing the internal state of an olm session for debugging and logging purposes.
|
||||
func (s *Session) Describe() string {
|
||||
return s.OlmSession.Describe()
|
||||
}
|
||||
2
vendor/maunium.net/go/mautrix/crypto/olm/utility.go
generated
vendored
2
vendor/maunium.net/go/mautrix/crypto/olm/utility.go
generated
vendored
@@ -1,3 +1,5 @@
|
||||
//go:build !goolm
|
||||
|
||||
package olm
|
||||
|
||||
// #cgo LDFLAGS: -lolm -lstdc++
|
||||
|
||||
92
vendor/maunium.net/go/mautrix/crypto/olm/utility_goolm.go
generated
vendored
Normal file
92
vendor/maunium.net/go/mautrix/crypto/olm/utility_goolm.go
generated
vendored
Normal file
@@ -0,0 +1,92 @@
|
||||
//go:build goolm
|
||||
|
||||
package olm
|
||||
|
||||
import (
|
||||
"crypto/sha256"
|
||||
"encoding/base64"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
"go.mau.fi/util/exgjson"
|
||||
|
||||
"maunium.net/go/mautrix/crypto/canonicaljson"
|
||||
"maunium.net/go/mautrix/crypto/goolm/utilities"
|
||||
"maunium.net/go/mautrix/id"
|
||||
)
|
||||
|
||||
// Utility stores the necessary state to perform hash and signature
|
||||
// verification operations.
|
||||
type Utility struct{}
|
||||
|
||||
// Clear clears the memory used to back this utility.
|
||||
func (u *Utility) Clear() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// NewUtility creates a new utility.
|
||||
func NewUtility() *Utility {
|
||||
return &Utility{}
|
||||
}
|
||||
|
||||
// Sha256 calculates the SHA-256 hash of the input and encodes it as base64.
|
||||
func (u *Utility) Sha256(input string) string {
|
||||
if len(input) == 0 {
|
||||
panic(EmptyInput)
|
||||
}
|
||||
hash := sha256.Sum256([]byte(input))
|
||||
return base64.RawStdEncoding.EncodeToString(hash[:])
|
||||
}
|
||||
|
||||
// VerifySignature verifies an ed25519 signature. Returns true if the verification
|
||||
// suceeds or false otherwise. Returns error on failure. If the key was too
|
||||
// small then the error will be "INVALID_BASE64".
|
||||
func (u *Utility) VerifySignature(message string, key id.Ed25519, signature string) (ok bool, err error) {
|
||||
if len(message) == 0 || len(key) == 0 || len(signature) == 0 {
|
||||
return false, EmptyInput
|
||||
}
|
||||
return utilities.VerifySignature([]byte(message), key, []byte(signature))
|
||||
}
|
||||
|
||||
// VerifySignatureJSON verifies the signature in the JSON object _obj following
|
||||
// the Matrix specification:
|
||||
// https://matrix.org/speculator/spec/drafts%2Fe2e/appendices.html#signing-json
|
||||
// If the _obj is a struct, the `json` tags will be honored.
|
||||
func (u *Utility) VerifySignatureJSON(obj interface{}, userID id.UserID, keyName string, key id.Ed25519) (bool, error) {
|
||||
var err error
|
||||
objJSON, ok := obj.(json.RawMessage)
|
||||
if !ok {
|
||||
objJSON, err = json.Marshal(obj)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
}
|
||||
sig := gjson.GetBytes(objJSON, exgjson.Path("signatures", string(userID), fmt.Sprintf("ed25519:%s", keyName)))
|
||||
if !sig.Exists() || sig.Type != gjson.String {
|
||||
return false, SignatureNotFound
|
||||
}
|
||||
objJSON, err = sjson.DeleteBytes(objJSON, "unsigned")
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
objJSON, err = sjson.DeleteBytes(objJSON, "signatures")
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
objJSONString := string(canonicaljson.CanonicalJSONAssumeValid(objJSON))
|
||||
return u.VerifySignature(objJSONString, key, sig.Str)
|
||||
}
|
||||
|
||||
// VerifySignatureJSON verifies the signature in the JSON object _obj following
|
||||
// the Matrix specification:
|
||||
// https://matrix.org/speculator/spec/drafts%2Fe2e/appendices.html#signing-json
|
||||
// This function is a wrapper over Utility.VerifySignatureJSON that creates and
|
||||
// destroys the Utility object transparently.
|
||||
// If the _obj is a struct, the `json` tags will be honored.
|
||||
func VerifySignatureJSON(obj interface{}, userID id.UserID, keyName string, key id.Ed25519) (bool, error) {
|
||||
u := NewUtility()
|
||||
defer u.Clear()
|
||||
return u.VerifySignatureJSON(obj, userID, keyName, key)
|
||||
}
|
||||
2
vendor/maunium.net/go/mautrix/crypto/olm/verification.go
generated
vendored
2
vendor/maunium.net/go/mautrix/crypto/olm/verification.go
generated
vendored
@@ -1,3 +1,5 @@
|
||||
//go:build !nosas && !goolm
|
||||
|
||||
package olm
|
||||
|
||||
// #cgo LDFLAGS: -lolm -lstdc++
|
||||
|
||||
23
vendor/maunium.net/go/mautrix/crypto/olm/verification_goolm.go
generated
vendored
Normal file
23
vendor/maunium.net/go/mautrix/crypto/olm/verification_goolm.go
generated
vendored
Normal file
@@ -0,0 +1,23 @@
|
||||
//go:build !nosas && goolm
|
||||
|
||||
package olm
|
||||
|
||||
import (
|
||||
"maunium.net/go/mautrix/crypto/goolm/sas"
|
||||
)
|
||||
|
||||
// SAS stores an Olm Short Authentication String (SAS) object.
|
||||
type SAS struct {
|
||||
sas.SAS
|
||||
}
|
||||
|
||||
// NewSAS creates a new SAS object.
|
||||
func NewSAS() *SAS {
|
||||
newSAS, err := sas.New()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
return &SAS{
|
||||
SAS: *newSAS,
|
||||
}
|
||||
}
|
||||
527
vendor/maunium.net/go/mautrix/crypto/sql_store.go
generated
vendored
527
vendor/maunium.net/go/mautrix/crypto/sql_store.go
generated
vendored
@@ -1,4 +1,4 @@
|
||||
// Copyright (c) 2022 Tulir Asokan
|
||||
// Copyright (c) 2024 Tulir Asokan
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||
// License, v. 2.0. If a copy of the MPL was not distributed with this
|
||||
@@ -27,7 +27,7 @@ import (
|
||||
"maunium.net/go/mautrix/id"
|
||||
)
|
||||
|
||||
var PostgresArrayWrapper func(interface{}) interface {
|
||||
var PostgresArrayWrapper func(any) interface {
|
||||
driver.Valuer
|
||||
sql.Scanner
|
||||
}
|
||||
@@ -62,22 +62,22 @@ func NewSQLCryptoStore(db *dbutil.Database, log dbutil.DatabaseLogger, accountID
|
||||
}
|
||||
|
||||
// Flush does nothing for this implementation as data is already persisted in the database.
|
||||
func (store *SQLCryptoStore) Flush() error {
|
||||
func (store *SQLCryptoStore) Flush(_ context.Context) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
// PutNextBatch stores the next sync batch token for the current account.
|
||||
func (store *SQLCryptoStore) PutNextBatch(nextBatch string) error {
|
||||
func (store *SQLCryptoStore) PutNextBatch(ctx context.Context, nextBatch string) error {
|
||||
store.SyncToken = nextBatch
|
||||
_, err := store.DB.Exec(`UPDATE crypto_account SET sync_token=$1 WHERE account_id=$2`, store.SyncToken, store.AccountID)
|
||||
_, err := store.DB.Exec(ctx, `UPDATE crypto_account SET sync_token=$1 WHERE account_id=$2`, store.SyncToken, store.AccountID)
|
||||
return err
|
||||
}
|
||||
|
||||
// GetNextBatch retrieves the next sync batch token for the current account.
|
||||
func (store *SQLCryptoStore) GetNextBatch() (string, error) {
|
||||
func (store *SQLCryptoStore) GetNextBatch(ctx context.Context) (string, error) {
|
||||
if store.SyncToken == "" {
|
||||
err := store.DB.
|
||||
QueryRow("SELECT sync_token FROM crypto_account WHERE account_id=$1", store.AccountID).
|
||||
err := store.DB.Conn(ctx).
|
||||
QueryRowContext(ctx, "SELECT sync_token FROM crypto_account WHERE account_id=$1", store.AccountID).
|
||||
Scan(&store.SyncToken)
|
||||
if !errors.Is(err, sql.ErrNoRows) {
|
||||
return "", err
|
||||
@@ -88,38 +88,42 @@ func (store *SQLCryptoStore) GetNextBatch() (string, error) {
|
||||
|
||||
var _ mautrix.SyncStore = (*SQLCryptoStore)(nil)
|
||||
|
||||
func (store *SQLCryptoStore) SaveFilterID(_ id.UserID, _ string) {}
|
||||
func (store *SQLCryptoStore) LoadFilterID(_ id.UserID) string { return "" }
|
||||
|
||||
func (store *SQLCryptoStore) SaveNextBatch(_ id.UserID, nextBatchToken string) {
|
||||
err := store.PutNextBatch(nextBatchToken)
|
||||
if err != nil {
|
||||
// TODO handle error
|
||||
}
|
||||
func (store *SQLCryptoStore) SaveFilterID(ctx context.Context, _ id.UserID, _ string) error {
|
||||
return nil
|
||||
}
|
||||
func (store *SQLCryptoStore) LoadFilterID(ctx context.Context, _ id.UserID) (string, error) {
|
||||
return "", nil
|
||||
}
|
||||
|
||||
func (store *SQLCryptoStore) LoadNextBatch(_ id.UserID) string {
|
||||
nb, err := store.GetNextBatch()
|
||||
func (store *SQLCryptoStore) SaveNextBatch(ctx context.Context, _ id.UserID, nextBatchToken string) error {
|
||||
err := store.PutNextBatch(ctx, nextBatchToken)
|
||||
if err != nil {
|
||||
// TODO handle error
|
||||
return fmt.Errorf("unable to store batch: %w", err)
|
||||
}
|
||||
return nb
|
||||
return nil
|
||||
}
|
||||
|
||||
func (store *SQLCryptoStore) FindDeviceID() (deviceID id.DeviceID) {
|
||||
err := store.DB.QueryRow("SELECT device_id FROM crypto_account WHERE account_id=$1", store.AccountID).Scan(&deviceID)
|
||||
if err != nil && !errors.Is(err, sql.ErrNoRows) {
|
||||
// TODO return error
|
||||
store.DB.Log.Warn("Failed to scan device ID: %v", err)
|
||||
func (store *SQLCryptoStore) LoadNextBatch(ctx context.Context, _ id.UserID) (string, error) {
|
||||
nb, err := store.GetNextBatch(ctx)
|
||||
if err != nil {
|
||||
return "", fmt.Errorf("unable to load batch: %w", err)
|
||||
}
|
||||
return nb, nil
|
||||
}
|
||||
|
||||
func (store *SQLCryptoStore) FindDeviceID(ctx context.Context) (deviceID id.DeviceID, err error) {
|
||||
err = store.DB.QueryRow(ctx, "SELECT device_id FROM crypto_account WHERE account_id=$1", store.AccountID).Scan(&deviceID)
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
err = nil
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// PutAccount stores an OlmAccount in the database.
|
||||
func (store *SQLCryptoStore) PutAccount(account *OlmAccount) error {
|
||||
func (store *SQLCryptoStore) PutAccount(ctx context.Context, account *OlmAccount) error {
|
||||
store.Account = account
|
||||
bytes := account.Internal.Pickle(store.PickleKey)
|
||||
_, err := store.DB.Exec(`
|
||||
_, err := store.DB.Exec(ctx, `
|
||||
INSERT INTO crypto_account (device_id, shared, sync_token, account, account_id) VALUES ($1, $2, $3, $4, $5)
|
||||
ON CONFLICT (account_id) DO UPDATE SET shared=excluded.shared, sync_token=excluded.sync_token,
|
||||
account=excluded.account, account_id=excluded.account_id
|
||||
@@ -128,9 +132,9 @@ func (store *SQLCryptoStore) PutAccount(account *OlmAccount) error {
|
||||
}
|
||||
|
||||
// GetAccount retrieves an OlmAccount from the database.
|
||||
func (store *SQLCryptoStore) GetAccount() (*OlmAccount, error) {
|
||||
func (store *SQLCryptoStore) GetAccount(ctx context.Context) (*OlmAccount, error) {
|
||||
if store.Account == nil {
|
||||
row := store.DB.QueryRow("SELECT shared, sync_token, account FROM crypto_account WHERE account_id=$1", store.AccountID)
|
||||
row := store.DB.QueryRow(ctx, "SELECT shared, sync_token, account FROM crypto_account WHERE account_id=$1", store.AccountID)
|
||||
acc := &OlmAccount{Internal: *olm.NewBlankAccount()}
|
||||
var accountBytes []byte
|
||||
err := row.Scan(&acc.Shared, &store.SyncToken, &accountBytes)
|
||||
@@ -149,7 +153,7 @@ func (store *SQLCryptoStore) GetAccount() (*OlmAccount, error) {
|
||||
}
|
||||
|
||||
// HasSession returns whether there is an Olm session for the given sender key.
|
||||
func (store *SQLCryptoStore) HasSession(key id.SenderKey) bool {
|
||||
func (store *SQLCryptoStore) HasSession(ctx context.Context, key id.SenderKey) bool {
|
||||
store.olmSessionCacheLock.Lock()
|
||||
cache, ok := store.olmSessionCache[key]
|
||||
store.olmSessionCacheLock.Unlock()
|
||||
@@ -157,17 +161,17 @@ func (store *SQLCryptoStore) HasSession(key id.SenderKey) bool {
|
||||
return true
|
||||
}
|
||||
var sessionID id.SessionID
|
||||
err := store.DB.QueryRow("SELECT session_id FROM crypto_olm_session WHERE sender_key=$1 AND account_id=$2 LIMIT 1",
|
||||
err := store.DB.QueryRow(ctx, "SELECT session_id FROM crypto_olm_session WHERE sender_key=$1 AND account_id=$2 LIMIT 1",
|
||||
key, store.AccountID).Scan(&sessionID)
|
||||
if err == sql.ErrNoRows {
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return false
|
||||
}
|
||||
return len(sessionID) > 0
|
||||
}
|
||||
|
||||
// GetSessions returns all the known Olm sessions for a sender key.
|
||||
func (store *SQLCryptoStore) GetSessions(key id.SenderKey) (OlmSessionList, error) {
|
||||
rows, err := store.DB.Query("SELECT session_id, session, created_at, last_encrypted, last_decrypted FROM crypto_olm_session WHERE sender_key=$1 AND account_id=$2 ORDER BY last_decrypted DESC",
|
||||
func (store *SQLCryptoStore) GetSessions(ctx context.Context, key id.SenderKey) (OlmSessionList, error) {
|
||||
rows, err := store.DB.Query(ctx, "SELECT session_id, session, created_at, last_encrypted, last_decrypted FROM crypto_olm_session WHERE sender_key=$1 AND account_id=$2 ORDER BY last_decrypted DESC",
|
||||
key, store.AccountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -207,11 +211,11 @@ func (store *SQLCryptoStore) getOlmSessionCache(key id.SenderKey) map[id.Session
|
||||
}
|
||||
|
||||
// GetLatestSession retrieves the Olm session for a given sender key from the database that has the largest ID.
|
||||
func (store *SQLCryptoStore) GetLatestSession(key id.SenderKey) (*OlmSession, error) {
|
||||
func (store *SQLCryptoStore) GetLatestSession(ctx context.Context, key id.SenderKey) (*OlmSession, error) {
|
||||
store.olmSessionCacheLock.Lock()
|
||||
defer store.olmSessionCacheLock.Unlock()
|
||||
|
||||
row := store.DB.QueryRow("SELECT session_id, session, created_at, last_encrypted, last_decrypted FROM crypto_olm_session WHERE sender_key=$1 AND account_id=$2 ORDER BY last_decrypted DESC LIMIT 1",
|
||||
row := store.DB.QueryRow(ctx, "SELECT session_id, session, created_at, last_encrypted, last_decrypted FROM crypto_olm_session WHERE sender_key=$1 AND account_id=$2 ORDER BY last_decrypted DESC LIMIT 1",
|
||||
key, store.AccountID)
|
||||
|
||||
sess := OlmSession{Internal: *olm.NewBlankSession()}
|
||||
@@ -219,7 +223,7 @@ func (store *SQLCryptoStore) GetLatestSession(key id.SenderKey) (*OlmSession, er
|
||||
var sessionID id.SessionID
|
||||
|
||||
err := row.Scan(&sessionID, &sessionBytes, &sess.CreationTime, &sess.LastEncryptedTime, &sess.LastDecryptedTime)
|
||||
if err == sql.ErrNoRows {
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return nil, nil
|
||||
} else if err != nil {
|
||||
return nil, err
|
||||
@@ -237,20 +241,20 @@ func (store *SQLCryptoStore) GetLatestSession(key id.SenderKey) (*OlmSession, er
|
||||
}
|
||||
|
||||
// AddSession persists an Olm session for a sender in the database.
|
||||
func (store *SQLCryptoStore) AddSession(key id.SenderKey, session *OlmSession) error {
|
||||
func (store *SQLCryptoStore) AddSession(ctx context.Context, key id.SenderKey, session *OlmSession) error {
|
||||
store.olmSessionCacheLock.Lock()
|
||||
defer store.olmSessionCacheLock.Unlock()
|
||||
sessionBytes := session.Internal.Pickle(store.PickleKey)
|
||||
_, err := store.DB.Exec("INSERT INTO crypto_olm_session (session_id, sender_key, session, created_at, last_encrypted, last_decrypted, account_id) VALUES ($1, $2, $3, $4, $5, $6, $7)",
|
||||
_, err := store.DB.Exec(ctx, "INSERT INTO crypto_olm_session (session_id, sender_key, session, created_at, last_encrypted, last_decrypted, account_id) VALUES ($1, $2, $3, $4, $5, $6, $7)",
|
||||
session.ID(), key, sessionBytes, session.CreationTime, session.LastEncryptedTime, session.LastDecryptedTime, store.AccountID)
|
||||
store.getOlmSessionCache(key)[session.ID()] = session
|
||||
return err
|
||||
}
|
||||
|
||||
// UpdateSession replaces the Olm session for a sender in the database.
|
||||
func (store *SQLCryptoStore) UpdateSession(_ id.SenderKey, session *OlmSession) error {
|
||||
func (store *SQLCryptoStore) UpdateSession(ctx context.Context, _ id.SenderKey, session *OlmSession) error {
|
||||
sessionBytes := session.Internal.Pickle(store.PickleKey)
|
||||
_, err := store.DB.Exec("UPDATE crypto_olm_session SET session=$1, last_encrypted=$2, last_decrypted=$3 WHERE session_id=$4 AND account_id=$5",
|
||||
_, err := store.DB.Exec(ctx, "UPDATE crypto_olm_session SET session=$1, last_encrypted=$2, last_decrypted=$3 WHERE session_id=$4 AND account_id=$5",
|
||||
sessionBytes, session.LastEncryptedTime, session.LastDecryptedTime, session.ID(), store.AccountID)
|
||||
return err
|
||||
}
|
||||
@@ -270,14 +274,14 @@ func datePtr(t time.Time) *time.Time {
|
||||
}
|
||||
|
||||
// PutGroupSession stores an inbound Megolm group session for a room, sender and session.
|
||||
func (store *SQLCryptoStore) PutGroupSession(roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID, session *InboundGroupSession) error {
|
||||
func (store *SQLCryptoStore) PutGroupSession(ctx context.Context, roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID, session *InboundGroupSession) error {
|
||||
sessionBytes := session.Internal.Pickle(store.PickleKey)
|
||||
forwardingChains := strings.Join(session.ForwardingChains, ",")
|
||||
ratchetSafety, err := json.Marshal(&session.RatchetSafety)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal ratchet safety info: %w", err)
|
||||
}
|
||||
_, err = store.DB.Exec(`
|
||||
_, err = store.DB.Exec(ctx, `
|
||||
INSERT INTO crypto_megolm_inbound_session (
|
||||
session_id, sender_key, signing_key, room_id, session, forwarding_chains,
|
||||
ratchet_safety, received_at, max_age, max_messages, is_scheduled, account_id
|
||||
@@ -296,19 +300,19 @@ func (store *SQLCryptoStore) PutGroupSession(roomID id.RoomID, senderKey id.Send
|
||||
}
|
||||
|
||||
// GetGroupSession retrieves an inbound Megolm group session for a room, sender and session.
|
||||
func (store *SQLCryptoStore) GetGroupSession(roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID) (*InboundGroupSession, error) {
|
||||
func (store *SQLCryptoStore) GetGroupSession(ctx context.Context, roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID) (*InboundGroupSession, error) {
|
||||
var senderKeyDB, signingKey, forwardingChains, withheldCode, withheldReason sql.NullString
|
||||
var sessionBytes, ratchetSafetyBytes []byte
|
||||
var receivedAt sql.NullTime
|
||||
var maxAge, maxMessages sql.NullInt64
|
||||
var isScheduled bool
|
||||
err := store.DB.QueryRow(`
|
||||
err := store.DB.QueryRow(ctx, `
|
||||
SELECT sender_key, signing_key, session, forwarding_chains, withheld_code, withheld_reason, ratchet_safety, received_at, max_age, max_messages, is_scheduled
|
||||
FROM crypto_megolm_inbound_session
|
||||
WHERE room_id=$1 AND (sender_key=$2 OR $2 = '') AND session_id=$3 AND account_id=$4`,
|
||||
roomID, senderKey, sessionID, store.AccountID,
|
||||
).Scan(&senderKeyDB, &signingKey, &sessionBytes, &forwardingChains, &withheldCode, &withheldReason, &ratchetSafetyBytes, &receivedAt, &maxAge, &maxMessages, &isScheduled)
|
||||
if err == sql.ErrNoRows {
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return nil, nil
|
||||
} else if err != nil {
|
||||
return nil, err
|
||||
@@ -322,22 +326,7 @@ func (store *SQLCryptoStore) GetGroupSession(roomID id.RoomID, senderKey id.Send
|
||||
Reason: withheldReason.String,
|
||||
}
|
||||
}
|
||||
igs := olm.NewBlankInboundGroupSession()
|
||||
err = igs.Unpickle(sessionBytes, store.PickleKey)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var chains []string
|
||||
if forwardingChains.String != "" {
|
||||
chains = strings.Split(forwardingChains.String, ",")
|
||||
}
|
||||
var rs RatchetSafety
|
||||
if len(ratchetSafetyBytes) > 0 {
|
||||
err = json.Unmarshal(ratchetSafetyBytes, &rs)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to unmarshal ratchet safety info: %w", err)
|
||||
}
|
||||
}
|
||||
igs, chains, rs, err := store.postScanInboundGroupSession(sessionBytes, ratchetSafetyBytes, forwardingChains.String)
|
||||
if senderKey == "" {
|
||||
senderKey = id.Curve25519(senderKeyDB.String)
|
||||
}
|
||||
@@ -355,8 +344,8 @@ func (store *SQLCryptoStore) GetGroupSession(roomID id.RoomID, senderKey id.Send
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (store *SQLCryptoStore) RedactGroupSession(_ id.RoomID, _ id.SenderKey, sessionID id.SessionID, reason string) error {
|
||||
_, err := store.DB.Exec(`
|
||||
func (store *SQLCryptoStore) RedactGroupSession(ctx context.Context, _ id.RoomID, _ id.SenderKey, sessionID id.SessionID, reason string) error {
|
||||
_, err := store.DB.Exec(ctx, `
|
||||
UPDATE crypto_megolm_inbound_session
|
||||
SET withheld_code=$1, withheld_reason=$2, session=NULL, forwarding_chains=NULL
|
||||
WHERE session_id=$3 AND account_id=$4 AND session IS NOT NULL
|
||||
@@ -364,27 +353,24 @@ func (store *SQLCryptoStore) RedactGroupSession(_ id.RoomID, _ id.SenderKey, ses
|
||||
return err
|
||||
}
|
||||
|
||||
func (store *SQLCryptoStore) RedactGroupSessions(roomID id.RoomID, senderKey id.SenderKey, reason string) ([]id.SessionID, error) {
|
||||
func (store *SQLCryptoStore) RedactGroupSessions(ctx context.Context, roomID id.RoomID, senderKey id.SenderKey, reason string) ([]id.SessionID, error) {
|
||||
if roomID == "" && senderKey == "" {
|
||||
return nil, fmt.Errorf("room ID or sender key must be provided for redacting sessions")
|
||||
}
|
||||
res, err := store.DB.Query(`
|
||||
res, err := store.DB.Query(ctx, `
|
||||
UPDATE crypto_megolm_inbound_session
|
||||
SET withheld_code=$1, withheld_reason=$2, session=NULL, forwarding_chains=NULL
|
||||
WHERE (room_id=$3 OR $3='') AND (sender_key=$4 OR $4='') AND account_id=$5
|
||||
AND session IS NOT NULL AND is_scheduled=false AND received_at IS NOT NULL
|
||||
RETURNING session_id
|
||||
`, event.RoomKeyWithheldBeeperRedacted, "Session redacted: "+reason, roomID, senderKey, store.AccountID)
|
||||
var sessionIDs []id.SessionID
|
||||
for res.Next() {
|
||||
var sessionID id.SessionID
|
||||
_ = res.Scan(&sessionID)
|
||||
sessionIDs = append(sessionIDs, sessionID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return sessionIDs, err
|
||||
return dbutil.NewRowIter(res, dbutil.ScanSingleColumn[id.SessionID]).AsList()
|
||||
}
|
||||
|
||||
func (store *SQLCryptoStore) RedactExpiredGroupSessions() ([]id.SessionID, error) {
|
||||
func (store *SQLCryptoStore) RedactExpiredGroupSessions(ctx context.Context) ([]id.SessionID, error) {
|
||||
var query string
|
||||
switch store.DB.Dialect {
|
||||
case dbutil.Postgres:
|
||||
@@ -408,46 +394,40 @@ func (store *SQLCryptoStore) RedactExpiredGroupSessions() ([]id.SessionID, error
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported dialect")
|
||||
}
|
||||
res, err := store.DB.Query(query, event.RoomKeyWithheldBeeperRedacted, "Session redacted: expired", store.AccountID)
|
||||
var sessionIDs []id.SessionID
|
||||
for res.Next() {
|
||||
var sessionID id.SessionID
|
||||
_ = res.Scan(&sessionID)
|
||||
sessionIDs = append(sessionIDs, sessionID)
|
||||
res, err := store.DB.Query(ctx, query, event.RoomKeyWithheldBeeperRedacted, "Session redacted: expired", store.AccountID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return sessionIDs, err
|
||||
return dbutil.NewRowIter(res, dbutil.ScanSingleColumn[id.SessionID]).AsList()
|
||||
}
|
||||
|
||||
func (store *SQLCryptoStore) RedactOutdatedGroupSessions() ([]id.SessionID, error) {
|
||||
res, err := store.DB.Query(`
|
||||
func (store *SQLCryptoStore) RedactOutdatedGroupSessions(ctx context.Context) ([]id.SessionID, error) {
|
||||
res, err := store.DB.Query(ctx, `
|
||||
UPDATE crypto_megolm_inbound_session
|
||||
SET withheld_code=$1, withheld_reason=$2, session=NULL, forwarding_chains=NULL
|
||||
WHERE account_id=$3 AND session IS NOT NULL AND received_at IS NULL
|
||||
RETURNING session_id
|
||||
`, event.RoomKeyWithheldBeeperRedacted, "Session redacted: outdated", store.AccountID)
|
||||
var sessionIDs []id.SessionID
|
||||
for res.Next() {
|
||||
var sessionID id.SessionID
|
||||
_ = res.Scan(&sessionID)
|
||||
sessionIDs = append(sessionIDs, sessionID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return sessionIDs, err
|
||||
return dbutil.NewRowIter(res, dbutil.ScanSingleColumn[id.SessionID]).AsList()
|
||||
}
|
||||
|
||||
func (store *SQLCryptoStore) PutWithheldGroupSession(content event.RoomKeyWithheldEventContent) error {
|
||||
_, err := store.DB.Exec("INSERT INTO crypto_megolm_inbound_session (session_id, sender_key, room_id, withheld_code, withheld_reason, received_at, account_id) VALUES ($1, $2, $3, $4, $5, $6, $7)",
|
||||
func (store *SQLCryptoStore) PutWithheldGroupSession(ctx context.Context, content event.RoomKeyWithheldEventContent) error {
|
||||
_, err := store.DB.Exec(ctx, "INSERT INTO crypto_megolm_inbound_session (session_id, sender_key, room_id, withheld_code, withheld_reason, received_at, account_id) VALUES ($1, $2, $3, $4, $5, $6, $7)",
|
||||
content.SessionID, content.SenderKey, content.RoomID, content.Code, content.Reason, time.Now().UTC(), store.AccountID)
|
||||
return err
|
||||
}
|
||||
|
||||
func (store *SQLCryptoStore) GetWithheldGroupSession(roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID) (*event.RoomKeyWithheldEventContent, error) {
|
||||
func (store *SQLCryptoStore) GetWithheldGroupSession(ctx context.Context, roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID) (*event.RoomKeyWithheldEventContent, error) {
|
||||
var code, reason sql.NullString
|
||||
err := store.DB.QueryRow(`
|
||||
err := store.DB.QueryRow(ctx, `
|
||||
SELECT withheld_code, withheld_reason FROM crypto_megolm_inbound_session
|
||||
WHERE room_id=$1 AND sender_key=$2 AND session_id=$3 AND account_id=$4`,
|
||||
roomID, senderKey, sessionID, store.AccountID,
|
||||
).Scan(&code, &reason)
|
||||
if err == sql.ErrNoRows {
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return nil, nil
|
||||
} else if err != nil || !code.Valid {
|
||||
return nil, err
|
||||
@@ -462,82 +442,79 @@ func (store *SQLCryptoStore) GetWithheldGroupSession(roomID id.RoomID, senderKey
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (store *SQLCryptoStore) scanGroupSessionList(rows dbutil.Rows) (result []*InboundGroupSession, err error) {
|
||||
for rows.Next() {
|
||||
var roomID id.RoomID
|
||||
var signingKey, senderKey, forwardingChains sql.NullString
|
||||
var sessionBytes, ratchetSafetyBytes []byte
|
||||
var receivedAt sql.NullTime
|
||||
var maxAge, maxMessages sql.NullInt64
|
||||
var isScheduled bool
|
||||
err = rows.Scan(&roomID, &signingKey, &senderKey, &sessionBytes, &forwardingChains, &ratchetSafetyBytes, &receivedAt, &maxAge, &maxMessages, &isScheduled)
|
||||
func (store *SQLCryptoStore) postScanInboundGroupSession(sessionBytes, ratchetSafetyBytes []byte, forwardingChains string) (igs *olm.InboundGroupSession, chains []string, safety RatchetSafety, err error) {
|
||||
igs = olm.NewBlankInboundGroupSession()
|
||||
err = igs.Unpickle(sessionBytes, store.PickleKey)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if forwardingChains != "" {
|
||||
chains = strings.Split(forwardingChains, ",")
|
||||
}
|
||||
var rs RatchetSafety
|
||||
if len(ratchetSafetyBytes) > 0 {
|
||||
err = json.Unmarshal(ratchetSafetyBytes, &rs)
|
||||
if err != nil {
|
||||
return
|
||||
err = fmt.Errorf("failed to unmarshal ratchet safety info: %w", err)
|
||||
}
|
||||
igs := olm.NewBlankInboundGroupSession()
|
||||
err = igs.Unpickle(sessionBytes, store.PickleKey)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
var chains []string
|
||||
if forwardingChains.String != "" {
|
||||
chains = strings.Split(forwardingChains.String, ",")
|
||||
}
|
||||
var rs RatchetSafety
|
||||
if len(ratchetSafetyBytes) > 0 {
|
||||
err = json.Unmarshal(ratchetSafetyBytes, &rs)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to unmarshal ratchet safety info: %w", err)
|
||||
}
|
||||
}
|
||||
result = append(result, &InboundGroupSession{
|
||||
Internal: *igs,
|
||||
SigningKey: id.Ed25519(signingKey.String),
|
||||
SenderKey: id.Curve25519(senderKey.String),
|
||||
RoomID: roomID,
|
||||
ForwardingChains: chains,
|
||||
RatchetSafety: rs,
|
||||
ReceivedAt: receivedAt.Time,
|
||||
MaxAge: maxAge.Int64,
|
||||
MaxMessages: int(maxMessages.Int64),
|
||||
IsScheduled: isScheduled,
|
||||
})
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (store *SQLCryptoStore) GetGroupSessionsForRoom(roomID id.RoomID) ([]*InboundGroupSession, error) {
|
||||
rows, err := store.DB.Query(`
|
||||
SELECT room_id, signing_key, sender_key, session, forwarding_chains, ratchet_safety, received_at, max_age, max_messages, is_scheduled
|
||||
func (store *SQLCryptoStore) scanInboundGroupSession(rows dbutil.Scannable) (*InboundGroupSession, error) {
|
||||
var roomID id.RoomID
|
||||
var signingKey, senderKey, forwardingChains sql.NullString
|
||||
var sessionBytes, ratchetSafetyBytes []byte
|
||||
var receivedAt sql.NullTime
|
||||
var maxAge, maxMessages sql.NullInt64
|
||||
var isScheduled bool
|
||||
err := rows.Scan(&roomID, &senderKey, &signingKey, &sessionBytes, &forwardingChains, &ratchetSafetyBytes, &receivedAt, &maxAge, &maxMessages, &isScheduled)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
igs, chains, rs, err := store.postScanInboundGroupSession(sessionBytes, ratchetSafetyBytes, forwardingChains.String)
|
||||
return &InboundGroupSession{
|
||||
Internal: *igs,
|
||||
SigningKey: id.Ed25519(signingKey.String),
|
||||
SenderKey: id.Curve25519(senderKey.String),
|
||||
RoomID: roomID,
|
||||
ForwardingChains: chains,
|
||||
RatchetSafety: rs,
|
||||
ReceivedAt: receivedAt.Time,
|
||||
MaxAge: maxAge.Int64,
|
||||
MaxMessages: int(maxMessages.Int64),
|
||||
IsScheduled: isScheduled,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (store *SQLCryptoStore) GetGroupSessionsForRoom(ctx context.Context, roomID id.RoomID) ([]*InboundGroupSession, error) {
|
||||
rows, err := store.DB.Query(ctx, `
|
||||
SELECT room_id, sender_key, signing_key, session, forwarding_chains, withheld_code, withheld_reason, ratchet_safety, received_at, max_age, max_messages, is_scheduled
|
||||
FROM crypto_megolm_inbound_session WHERE room_id=$1 AND account_id=$2 AND session IS NOT NULL`,
|
||||
roomID, store.AccountID,
|
||||
)
|
||||
if err == sql.ErrNoRows {
|
||||
return []*InboundGroupSession{}, nil
|
||||
} else if err != nil {
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return store.scanGroupSessionList(rows)
|
||||
return dbutil.NewRowIter(rows, store.scanInboundGroupSession).AsList()
|
||||
}
|
||||
|
||||
func (store *SQLCryptoStore) GetAllGroupSessions() ([]*InboundGroupSession, error) {
|
||||
rows, err := store.DB.Query(`
|
||||
SELECT room_id, signing_key, sender_key, session, forwarding_chains, ratchet_safety, received_at, max_age, max_messages, is_scheduled
|
||||
func (store *SQLCryptoStore) GetAllGroupSessions(ctx context.Context) ([]*InboundGroupSession, error) {
|
||||
rows, err := store.DB.Query(ctx, `
|
||||
SELECT room_id, sender_key, signing_key, session, forwarding_chains, withheld_code, withheld_reason, ratchet_safety, received_at, max_age, max_messages, is_scheduled
|
||||
FROM crypto_megolm_inbound_session WHERE account_id=$2 AND session IS NOT NULL`,
|
||||
store.AccountID,
|
||||
)
|
||||
if err == sql.ErrNoRows {
|
||||
return []*InboundGroupSession{}, nil
|
||||
} else if err != nil {
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return store.scanGroupSessionList(rows)
|
||||
return dbutil.NewRowIter(rows, store.scanInboundGroupSession).AsList()
|
||||
}
|
||||
|
||||
// AddOutboundGroupSession stores an outbound Megolm session, along with the information about the room and involved devices.
|
||||
func (store *SQLCryptoStore) AddOutboundGroupSession(session *OutboundGroupSession) error {
|
||||
func (store *SQLCryptoStore) AddOutboundGroupSession(ctx context.Context, session *OutboundGroupSession) error {
|
||||
sessionBytes := session.Internal.Pickle(store.PickleKey)
|
||||
_, err := store.DB.Exec(`
|
||||
_, err := store.DB.Exec(ctx, `
|
||||
INSERT INTO crypto_megolm_outbound_session
|
||||
(room_id, session_id, session, shared, max_messages, message_count, max_age, created_at, last_used, account_id)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10)
|
||||
@@ -551,24 +528,24 @@ func (store *SQLCryptoStore) AddOutboundGroupSession(session *OutboundGroupSessi
|
||||
}
|
||||
|
||||
// UpdateOutboundGroupSession replaces an outbound Megolm session with for same room and session ID.
|
||||
func (store *SQLCryptoStore) UpdateOutboundGroupSession(session *OutboundGroupSession) error {
|
||||
func (store *SQLCryptoStore) UpdateOutboundGroupSession(ctx context.Context, session *OutboundGroupSession) error {
|
||||
sessionBytes := session.Internal.Pickle(store.PickleKey)
|
||||
_, err := store.DB.Exec("UPDATE crypto_megolm_outbound_session SET session=$1, message_count=$2, last_used=$3 WHERE room_id=$4 AND session_id=$5 AND account_id=$6",
|
||||
_, err := store.DB.Exec(ctx, "UPDATE crypto_megolm_outbound_session SET session=$1, message_count=$2, last_used=$3 WHERE room_id=$4 AND session_id=$5 AND account_id=$6",
|
||||
sessionBytes, session.MessageCount, session.LastEncryptedTime, session.RoomID, session.ID(), store.AccountID)
|
||||
return err
|
||||
}
|
||||
|
||||
// GetOutboundGroupSession retrieves the outbound Megolm session for the given room ID.
|
||||
func (store *SQLCryptoStore) GetOutboundGroupSession(roomID id.RoomID) (*OutboundGroupSession, error) {
|
||||
func (store *SQLCryptoStore) GetOutboundGroupSession(ctx context.Context, roomID id.RoomID) (*OutboundGroupSession, error) {
|
||||
var ogs OutboundGroupSession
|
||||
var sessionBytes []byte
|
||||
var maxAgeMS int64
|
||||
err := store.DB.QueryRow(`
|
||||
err := store.DB.QueryRow(ctx, `
|
||||
SELECT session, shared, max_messages, message_count, max_age, created_at, last_used
|
||||
FROM crypto_megolm_outbound_session WHERE room_id=$1 AND account_id=$2`,
|
||||
roomID, store.AccountID,
|
||||
).Scan(&sessionBytes, &ogs.Shared, &ogs.MaxMessages, &ogs.MessageCount, &maxAgeMS, &ogs.CreationTime, &ogs.LastEncryptedTime)
|
||||
if err == sql.ErrNoRows {
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return nil, nil
|
||||
} else if err != nil {
|
||||
return nil, err
|
||||
@@ -585,8 +562,8 @@ func (store *SQLCryptoStore) GetOutboundGroupSession(roomID id.RoomID) (*Outboun
|
||||
}
|
||||
|
||||
// RemoveOutboundGroupSession removes the outbound Megolm session for the given room ID.
|
||||
func (store *SQLCryptoStore) RemoveOutboundGroupSession(roomID id.RoomID) error {
|
||||
_, err := store.DB.Exec("DELETE FROM crypto_megolm_outbound_session WHERE room_id=$1 AND account_id=$2",
|
||||
func (store *SQLCryptoStore) RemoveOutboundGroupSession(ctx context.Context, roomID id.RoomID) error {
|
||||
_, err := store.DB.Exec(ctx, "DELETE FROM crypto_megolm_outbound_session WHERE room_id=$1 AND account_id=$2",
|
||||
roomID, store.AccountID)
|
||||
return err
|
||||
}
|
||||
@@ -603,7 +580,7 @@ func (store *SQLCryptoStore) ValidateMessageIndex(ctx context.Context, senderKey
|
||||
`
|
||||
var expectedEventID id.EventID
|
||||
var expectedTimestamp int64
|
||||
err := store.DB.QueryRowContext(ctx, validateQuery, senderKey, sessionID, index, eventID, timestamp).Scan(&expectedEventID, &expectedTimestamp)
|
||||
err := store.DB.QueryRow(ctx, validateQuery, senderKey, sessionID, index, eventID, timestamp).Scan(&expectedEventID, &expectedTimestamp)
|
||||
if err != nil {
|
||||
return false, err
|
||||
} else if expectedEventID != eventID || expectedTimestamp != timestamp {
|
||||
@@ -618,69 +595,58 @@ func (store *SQLCryptoStore) ValidateMessageIndex(ctx context.Context, senderKey
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func scanDevice(rows dbutil.Scannable) (*id.Device, error) {
|
||||
var device id.Device
|
||||
err := rows.Scan(&device.UserID, &device.DeviceID, &device.IdentityKey, &device.SigningKey, &device.Trust, &device.Deleted, &device.Name)
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return nil, nil
|
||||
} else if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &device, nil
|
||||
}
|
||||
|
||||
// GetDevices returns a map of device IDs to device identities, including the identity and signing keys, for a given user ID.
|
||||
func (store *SQLCryptoStore) GetDevices(userID id.UserID) (map[id.DeviceID]*id.Device, error) {
|
||||
func (store *SQLCryptoStore) GetDevices(ctx context.Context, userID id.UserID) (map[id.DeviceID]*id.Device, error) {
|
||||
var ignore id.UserID
|
||||
err := store.DB.QueryRow("SELECT user_id FROM crypto_tracked_user WHERE user_id=$1", userID).Scan(&ignore)
|
||||
if err == sql.ErrNoRows {
|
||||
err := store.DB.QueryRow(ctx, "SELECT user_id FROM crypto_tracked_user WHERE user_id=$1", userID).Scan(&ignore)
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return nil, nil
|
||||
} else if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
rows, err := store.DB.Query("SELECT device_id, identity_key, signing_key, trust, deleted, name FROM crypto_device WHERE user_id=$1 AND deleted=false", userID)
|
||||
rows, err := store.DB.Query(ctx, "SELECT user_id, device_id, identity_key, signing_key, trust, deleted, name FROM crypto_device WHERE user_id=$1 AND deleted=false", userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
data := make(map[id.DeviceID]*id.Device)
|
||||
for rows.Next() {
|
||||
var identity id.Device
|
||||
err := rows.Scan(&identity.DeviceID, &identity.IdentityKey, &identity.SigningKey, &identity.Trust, &identity.Deleted, &identity.Name)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
identity.UserID = userID
|
||||
data[identity.DeviceID] = &identity
|
||||
err = dbutil.NewRowIter(rows, scanDevice).Iter(func(device *id.Device) (bool, error) {
|
||||
data[device.DeviceID] = device
|
||||
return true, nil
|
||||
})
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return data, nil
|
||||
}
|
||||
|
||||
// GetDevice returns the device dentity for a given user and device ID.
|
||||
func (store *SQLCryptoStore) GetDevice(userID id.UserID, deviceID id.DeviceID) (*id.Device, error) {
|
||||
var identity id.Device
|
||||
err := store.DB.QueryRow(`
|
||||
SELECT identity_key, signing_key, trust, deleted, name
|
||||
func (store *SQLCryptoStore) GetDevice(ctx context.Context, userID id.UserID, deviceID id.DeviceID) (*id.Device, error) {
|
||||
return scanDevice(store.DB.QueryRow(ctx, `
|
||||
SELECT user_id, device_id, identity_key, signing_key, trust, deleted, name
|
||||
FROM crypto_device WHERE user_id=$1 AND device_id=$2`,
|
||||
userID, deviceID,
|
||||
).Scan(&identity.IdentityKey, &identity.SigningKey, &identity.Trust, &identity.Deleted, &identity.Name)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
identity.UserID = userID
|
||||
identity.DeviceID = deviceID
|
||||
return &identity, nil
|
||||
))
|
||||
}
|
||||
|
||||
// FindDeviceByKey finds a specific device by its sender key.
|
||||
func (store *SQLCryptoStore) FindDeviceByKey(userID id.UserID, identityKey id.IdentityKey) (*id.Device, error) {
|
||||
var identity id.Device
|
||||
err := store.DB.QueryRow(`
|
||||
SELECT device_id, signing_key, trust, deleted, name
|
||||
func (store *SQLCryptoStore) FindDeviceByKey(ctx context.Context, userID id.UserID, identityKey id.IdentityKey) (*id.Device, error) {
|
||||
return scanDevice(store.DB.QueryRow(ctx, `
|
||||
SELECT user_id, device_id, identity_key, signing_key, trust, deleted, name
|
||||
FROM crypto_device WHERE user_id=$1 AND identity_key=$2`,
|
||||
userID, identityKey,
|
||||
).Scan(&identity.DeviceID, &identity.SigningKey, &identity.Trust, &identity.Deleted, &identity.Name)
|
||||
if err != nil {
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, nil
|
||||
}
|
||||
return nil, err
|
||||
}
|
||||
identity.UserID = userID
|
||||
identity.IdentityKey = identityKey
|
||||
return &identity, nil
|
||||
))
|
||||
}
|
||||
|
||||
const deviceInsertQuery = `
|
||||
@@ -693,106 +659,115 @@ ON CONFLICT (user_id, device_id) DO UPDATE
|
||||
var deviceMassInsertTemplate = strings.ReplaceAll(deviceInsertQuery, "($1, $2, $3, $4, $5, $6, $7)", "%s")
|
||||
|
||||
// PutDevice stores a single device for a user, replacing it if it exists already.
|
||||
func (store *SQLCryptoStore) PutDevice(userID id.UserID, device *id.Device) error {
|
||||
_, err := store.DB.Exec(deviceInsertQuery,
|
||||
func (store *SQLCryptoStore) PutDevice(ctx context.Context, userID id.UserID, device *id.Device) error {
|
||||
_, err := store.DB.Exec(ctx, deviceInsertQuery,
|
||||
userID, device.DeviceID, device.IdentityKey, device.SigningKey, device.Trust, device.Deleted, device.Name)
|
||||
return err
|
||||
}
|
||||
|
||||
const trackedUserUpsertQuery = `
|
||||
INSERT INTO crypto_tracked_user (user_id, devices_outdated)
|
||||
VALUES ($1, false)
|
||||
ON CONFLICT (user_id) DO UPDATE
|
||||
SET devices_outdated = EXCLUDED.devices_outdated
|
||||
`
|
||||
|
||||
// PutDevices stores the device identity information for the given user ID.
|
||||
func (store *SQLCryptoStore) PutDevices(userID id.UserID, devices map[id.DeviceID]*id.Device) error {
|
||||
tx, err := store.DB.Begin()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
_, err = tx.Exec("INSERT INTO crypto_tracked_user (user_id) VALUES ($1) ON CONFLICT (user_id) DO NOTHING", userID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to add user to tracked users list: %w", err)
|
||||
}
|
||||
|
||||
_, err = tx.Exec("UPDATE crypto_device SET deleted=true WHERE user_id=$1", userID)
|
||||
if err != nil {
|
||||
_ = tx.Rollback()
|
||||
return fmt.Errorf("failed to delete old devices: %w", err)
|
||||
}
|
||||
if len(devices) == 0 {
|
||||
err = tx.Commit()
|
||||
func (store *SQLCryptoStore) PutDevices(ctx context.Context, userID id.UserID, devices map[id.DeviceID]*id.Device) error {
|
||||
return store.DB.DoTxn(ctx, nil, func(ctx context.Context) error {
|
||||
_, err := store.DB.Exec(ctx, trackedUserUpsertQuery, userID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to commit changes (no devices added): %w", err)
|
||||
return fmt.Errorf("failed to upsert user to tracked users list: %w", err)
|
||||
}
|
||||
|
||||
_, err = store.DB.Exec(ctx, "UPDATE crypto_device SET deleted=true WHERE user_id=$1", userID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to delete old devices: %w", err)
|
||||
}
|
||||
if len(devices) == 0 {
|
||||
return nil
|
||||
}
|
||||
deviceBatchLen := 5 // how many devices will be inserted per query
|
||||
deviceIDs := make([]id.DeviceID, 0, len(devices))
|
||||
for deviceID := range devices {
|
||||
deviceIDs = append(deviceIDs, deviceID)
|
||||
}
|
||||
const valueStringFormat = "($1, $%d, $%d, $%d, $%d, $%d, $%d)"
|
||||
for batchDeviceIdx := 0; batchDeviceIdx < len(deviceIDs); batchDeviceIdx += deviceBatchLen {
|
||||
var batchDevices []id.DeviceID
|
||||
if batchDeviceIdx+deviceBatchLen < len(deviceIDs) {
|
||||
batchDevices = deviceIDs[batchDeviceIdx : batchDeviceIdx+deviceBatchLen]
|
||||
} else {
|
||||
batchDevices = deviceIDs[batchDeviceIdx:]
|
||||
}
|
||||
values := make([]interface{}, 1, len(devices)*6+1)
|
||||
values[0] = userID
|
||||
valueStrings := make([]string, 0, len(devices))
|
||||
i := 2
|
||||
for _, deviceID := range batchDevices {
|
||||
identity := devices[deviceID]
|
||||
values = append(values, deviceID, identity.IdentityKey, identity.SigningKey, identity.Trust, identity.Deleted, identity.Name)
|
||||
valueStrings = append(valueStrings, fmt.Sprintf(valueStringFormat, i, i+1, i+2, i+3, i+4, i+5))
|
||||
i += 6
|
||||
}
|
||||
valueString := strings.Join(valueStrings, ",")
|
||||
_, err = store.DB.Exec(ctx, fmt.Sprintf(deviceMassInsertTemplate, valueString), values...)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to insert new devices: %w", err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
deviceBatchLen := 5 // how many devices will be inserted per query
|
||||
deviceIDs := make([]id.DeviceID, 0, len(devices))
|
||||
for deviceID := range devices {
|
||||
deviceIDs = append(deviceIDs, deviceID)
|
||||
}
|
||||
const valueStringFormat = "($1, $%d, $%d, $%d, $%d, $%d, $%d)"
|
||||
for batchDeviceIdx := 0; batchDeviceIdx < len(deviceIDs); batchDeviceIdx += deviceBatchLen {
|
||||
var batchDevices []id.DeviceID
|
||||
if batchDeviceIdx+deviceBatchLen < len(deviceIDs) {
|
||||
batchDevices = deviceIDs[batchDeviceIdx : batchDeviceIdx+deviceBatchLen]
|
||||
} else {
|
||||
batchDevices = deviceIDs[batchDeviceIdx:]
|
||||
}
|
||||
values := make([]interface{}, 1, len(devices)*6+1)
|
||||
values[0] = userID
|
||||
valueStrings := make([]string, 0, len(devices))
|
||||
i := 2
|
||||
for _, deviceID := range batchDevices {
|
||||
identity := devices[deviceID]
|
||||
values = append(values, deviceID, identity.IdentityKey, identity.SigningKey, identity.Trust, identity.Deleted, identity.Name)
|
||||
valueStrings = append(valueStrings, fmt.Sprintf(valueStringFormat, i, i+1, i+2, i+3, i+4, i+5))
|
||||
i += 6
|
||||
}
|
||||
valueString := strings.Join(valueStrings, ",")
|
||||
_, err = tx.Exec(fmt.Sprintf(deviceMassInsertTemplate, valueString), values...)
|
||||
if err != nil {
|
||||
_ = tx.Rollback()
|
||||
return fmt.Errorf("failed to insert new devices: %w", err)
|
||||
}
|
||||
}
|
||||
err = tx.Commit()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to commit changes: %w", err)
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// FilterTrackedUsers finds all the user IDs out of the given ones for which the database contains identity information.
|
||||
func (store *SQLCryptoStore) FilterTrackedUsers(users []id.UserID) ([]id.UserID, error) {
|
||||
func (store *SQLCryptoStore) FilterTrackedUsers(ctx context.Context, users []id.UserID) ([]id.UserID, error) {
|
||||
var rows dbutil.Rows
|
||||
var err error
|
||||
if store.DB.Dialect == dbutil.Postgres && PostgresArrayWrapper != nil {
|
||||
rows, err = store.DB.Query("SELECT user_id FROM crypto_tracked_user WHERE user_id = ANY($1)", PostgresArrayWrapper(users))
|
||||
rows, err = store.DB.Query(ctx, "SELECT user_id FROM crypto_tracked_user WHERE user_id = ANY($1)", PostgresArrayWrapper(users))
|
||||
} else {
|
||||
queryString := make([]string, len(users))
|
||||
params := make([]interface{}, len(users))
|
||||
for i, user := range users {
|
||||
queryString[i] = fmt.Sprintf("$%d", i+1)
|
||||
queryString[i] = fmt.Sprintf("?%d", i+1)
|
||||
params[i] = user
|
||||
}
|
||||
rows, err = store.DB.Query("SELECT user_id FROM crypto_tracked_user WHERE user_id IN ("+strings.Join(queryString, ",")+")", params...)
|
||||
rows, err = store.DB.Query(ctx, "SELECT user_id FROM crypto_tracked_user WHERE user_id IN ("+strings.Join(queryString, ",")+")", params...)
|
||||
}
|
||||
if err != nil {
|
||||
return users, err
|
||||
}
|
||||
var ptr int
|
||||
for rows.Next() {
|
||||
err = rows.Scan(&users[ptr])
|
||||
if err != nil {
|
||||
return users, err
|
||||
} else {
|
||||
ptr++
|
||||
return dbutil.NewRowIter(rows, dbutil.ScanSingleColumn[id.UserID]).AsList()
|
||||
}
|
||||
|
||||
// MarkTrackedUsersOutdated flags that the device list for given users are outdated.
|
||||
func (store *SQLCryptoStore) MarkTrackedUsersOutdated(ctx context.Context, users []id.UserID) error {
|
||||
return store.DB.DoTxn(ctx, nil, func(ctx context.Context) error {
|
||||
// TODO refactor to use a single query
|
||||
for _, userID := range users {
|
||||
_, err := store.DB.Exec(ctx, "UPDATE crypto_tracked_user SET devices_outdated = true WHERE user_id = $1", userID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to update user in the tracked users list: %w", err)
|
||||
}
|
||||
}
|
||||
|
||||
return nil
|
||||
})
|
||||
}
|
||||
|
||||
// GetOutdatedTrackerUsers gets all tracked users whose devices need to be updated.
|
||||
func (store *SQLCryptoStore) GetOutdatedTrackedUsers(ctx context.Context) ([]id.UserID, error) {
|
||||
rows, err := store.DB.Query(ctx, "SELECT user_id FROM crypto_tracked_user WHERE devices_outdated = TRUE")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return users[:ptr], nil
|
||||
return dbutil.NewRowIter(rows, dbutil.ScanSingleColumn[id.UserID]).AsList()
|
||||
}
|
||||
|
||||
// PutCrossSigningKey stores a cross-signing key of some user along with its usage.
|
||||
func (store *SQLCryptoStore) PutCrossSigningKey(userID id.UserID, usage id.CrossSigningUsage, key id.Ed25519) error {
|
||||
_, err := store.DB.Exec(`
|
||||
func (store *SQLCryptoStore) PutCrossSigningKey(ctx context.Context, userID id.UserID, usage id.CrossSigningUsage, key id.Ed25519) error {
|
||||
_, err := store.DB.Exec(ctx, `
|
||||
INSERT INTO crypto_cross_signing_keys (user_id, usage, key, first_seen_key) VALUES ($1, $2, $3, $4)
|
||||
ON CONFLICT (user_id, usage) DO UPDATE SET key=excluded.key
|
||||
`, userID, usage, key, key)
|
||||
@@ -800,8 +775,8 @@ func (store *SQLCryptoStore) PutCrossSigningKey(userID id.UserID, usage id.Cross
|
||||
}
|
||||
|
||||
// GetCrossSigningKeys retrieves a user's stored cross-signing keys.
|
||||
func (store *SQLCryptoStore) GetCrossSigningKeys(userID id.UserID) (map[id.CrossSigningUsage]id.CrossSigningKey, error) {
|
||||
rows, err := store.DB.Query("SELECT usage, key, first_seen_key FROM crypto_cross_signing_keys WHERE user_id=$1", userID)
|
||||
func (store *SQLCryptoStore) GetCrossSigningKeys(ctx context.Context, userID id.UserID) (map[id.CrossSigningUsage]id.CrossSigningKey, error) {
|
||||
rows, err := store.DB.Query(ctx, "SELECT usage, key, first_seen_key FROM crypto_cross_signing_keys WHERE user_id=$1", userID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -820,8 +795,8 @@ func (store *SQLCryptoStore) GetCrossSigningKeys(userID id.UserID) (map[id.Cross
|
||||
}
|
||||
|
||||
// PutSignature stores a signature of a cross-signing or device key along with the signer's user ID and key.
|
||||
func (store *SQLCryptoStore) PutSignature(signedUserID id.UserID, signedKey id.Ed25519, signerUserID id.UserID, signerKey id.Ed25519, signature string) error {
|
||||
_, err := store.DB.Exec(`
|
||||
func (store *SQLCryptoStore) PutSignature(ctx context.Context, signedUserID id.UserID, signedKey id.Ed25519, signerUserID id.UserID, signerKey id.Ed25519, signature string) error {
|
||||
_, err := store.DB.Exec(ctx, `
|
||||
INSERT INTO crypto_cross_signing_signatures (signed_user_id, signed_key, signer_user_id, signer_key, signature) VALUES ($1, $2, $3, $4, $5)
|
||||
ON CONFLICT (signed_user_id, signed_key, signer_user_id, signer_key) DO UPDATE SET signature=excluded.signature
|
||||
`, signedUserID, signedKey, signerUserID, signerKey, signature)
|
||||
@@ -829,8 +804,8 @@ func (store *SQLCryptoStore) PutSignature(signedUserID id.UserID, signedKey id.E
|
||||
}
|
||||
|
||||
// GetSignaturesForKeyBy retrieves the stored signatures for a given cross-signing or device key, by the given signer.
|
||||
func (store *SQLCryptoStore) GetSignaturesForKeyBy(userID id.UserID, key id.Ed25519, signerID id.UserID) (map[id.Ed25519]string, error) {
|
||||
rows, err := store.DB.Query("SELECT signer_key, signature FROM crypto_cross_signing_signatures WHERE signed_user_id=$1 AND signed_key=$2 AND signer_user_id=$3", userID, key, signerID)
|
||||
func (store *SQLCryptoStore) GetSignaturesForKeyBy(ctx context.Context, userID id.UserID, key id.Ed25519, signerID id.UserID) (map[id.Ed25519]string, error) {
|
||||
rows, err := store.DB.Query(ctx, "SELECT signer_key, signature FROM crypto_cross_signing_signatures WHERE signed_user_id=$1 AND signed_key=$2 AND signer_user_id=$3", userID, key, signerID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -849,18 +824,18 @@ func (store *SQLCryptoStore) GetSignaturesForKeyBy(userID id.UserID, key id.Ed25
|
||||
}
|
||||
|
||||
// IsKeySignedBy returns whether a cross-signing or device key is signed by the given signer.
|
||||
func (store *SQLCryptoStore) IsKeySignedBy(signedUserID id.UserID, signedKey id.Ed25519, signerUserID id.UserID, signerKey id.Ed25519) (isSigned bool, err error) {
|
||||
func (store *SQLCryptoStore) IsKeySignedBy(ctx context.Context, signedUserID id.UserID, signedKey id.Ed25519, signerUserID id.UserID, signerKey id.Ed25519) (isSigned bool, err error) {
|
||||
q := `SELECT EXISTS(
|
||||
SELECT 1 FROM crypto_cross_signing_signatures
|
||||
WHERE signed_user_id=$1 AND signed_key=$2 AND signer_user_id=$3 AND signer_key=$4
|
||||
)`
|
||||
err = store.DB.QueryRow(q, signedUserID, signedKey, signerUserID, signerKey).Scan(&isSigned)
|
||||
err = store.DB.QueryRow(ctx, q, signedUserID, signedKey, signerUserID, signerKey).Scan(&isSigned)
|
||||
return
|
||||
}
|
||||
|
||||
// DropSignaturesByKey deletes the signatures made by the given user and key from the store. It returns the number of signatures deleted.
|
||||
func (store *SQLCryptoStore) DropSignaturesByKey(userID id.UserID, key id.Ed25519) (int64, error) {
|
||||
res, err := store.DB.Exec("DELETE FROM crypto_cross_signing_signatures WHERE signer_user_id=$1 AND signer_key=$2", userID, key)
|
||||
func (store *SQLCryptoStore) DropSignaturesByKey(ctx context.Context, userID id.UserID, key id.Ed25519) (int64, error) {
|
||||
res, err := store.DB.Exec(ctx, "DELETE FROM crypto_cross_signing_signatures WHERE signer_user_id=$1 AND signer_key=$2", userID, key)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
|
||||
5
vendor/maunium.net/go/mautrix/crypto/sql_store_upgrade/00-latest-revision.sql
generated
vendored
5
vendor/maunium.net/go/mautrix/crypto/sql_store_upgrade/00-latest-revision.sql
generated
vendored
@@ -1,4 +1,4 @@
|
||||
-- v0 -> v10: Latest revision
|
||||
-- v0 -> v11: Latest revision
|
||||
CREATE TABLE IF NOT EXISTS crypto_account (
|
||||
account_id TEXT PRIMARY KEY,
|
||||
device_id TEXT NOT NULL,
|
||||
@@ -17,7 +17,8 @@ CREATE TABLE IF NOT EXISTS crypto_message_index (
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS crypto_tracked_user (
|
||||
user_id TEXT PRIMARY KEY
|
||||
user_id TEXT PRIMARY KEY,
|
||||
devices_outdated BOOLEAN NOT NULL DEFAULT FALSE
|
||||
);
|
||||
|
||||
CREATE TABLE IF NOT EXISTS crypto_device (
|
||||
|
||||
2
vendor/maunium.net/go/mautrix/crypto/sql_store_upgrade/11-outdated-devices.sql
generated
vendored
Normal file
2
vendor/maunium.net/go/mautrix/crypto/sql_store_upgrade/11-outdated-devices.sql
generated
vendored
Normal file
@@ -0,0 +1,2 @@
|
||||
-- v11: Add devices_outdated field to crypto_tracked_user
|
||||
ALTER TABLE crypto_tracked_user ADD COLUMN devices_outdated BOOLEAN NOT NULL DEFAULT FALSE;
|
||||
3
vendor/maunium.net/go/mautrix/crypto/sql_store_upgrade/upgrade.go
generated
vendored
3
vendor/maunium.net/go/mautrix/crypto/sql_store_upgrade/upgrade.go
generated
vendored
@@ -7,6 +7,7 @@
|
||||
package sql_store_upgrade
|
||||
|
||||
import (
|
||||
"context"
|
||||
"embed"
|
||||
"fmt"
|
||||
|
||||
@@ -21,7 +22,7 @@ const VersionTableName = "crypto_version"
|
||||
var fs embed.FS
|
||||
|
||||
func init() {
|
||||
Table.Register(-1, 3, 0, "Unsupported version", false, func(tx dbutil.Execable, database *dbutil.Database) error {
|
||||
Table.Register(-1, 3, 0, "Unsupported version", false, func(ctx context.Context, database *dbutil.Database) error {
|
||||
return fmt.Errorf("upgrading from versions 1 and 2 of the crypto store is no longer supported in mautrix-go v0.12+")
|
||||
})
|
||||
Table.RegisterFS(fs)
|
||||
|
||||
35
vendor/maunium.net/go/mautrix/crypto/ssss/client.go
generated
vendored
35
vendor/maunium.net/go/mautrix/crypto/ssss/client.go
generated
vendored
@@ -7,6 +7,7 @@
|
||||
package ssss
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"maunium.net/go/mautrix"
|
||||
@@ -29,9 +30,9 @@ type DefaultSecretStorageKeyContent struct {
|
||||
}
|
||||
|
||||
// GetDefaultKeyID retrieves the default key ID for this account from SSSS.
|
||||
func (mach *Machine) GetDefaultKeyID() (string, error) {
|
||||
func (mach *Machine) GetDefaultKeyID(ctx context.Context) (string, error) {
|
||||
var data DefaultSecretStorageKeyContent
|
||||
err := mach.Client.GetAccountData(event.AccountDataSecretStorageDefaultKey.Type, &data)
|
||||
err := mach.Client.GetAccountData(ctx, event.AccountDataSecretStorageDefaultKey.Type, &data)
|
||||
if err != nil {
|
||||
if httpErr, ok := err.(mautrix.HTTPError); ok && httpErr.RespError != nil && httpErr.RespError.ErrCode == "M_NOT_FOUND" {
|
||||
return "", ErrNoDefaultKeyAccountDataEvent
|
||||
@@ -45,36 +46,36 @@ func (mach *Machine) GetDefaultKeyID() (string, error) {
|
||||
}
|
||||
|
||||
// SetDefaultKeyID sets the default key ID for this account on the server.
|
||||
func (mach *Machine) SetDefaultKeyID(keyID string) error {
|
||||
return mach.Client.SetAccountData(event.AccountDataSecretStorageDefaultKey.Type, &DefaultSecretStorageKeyContent{keyID})
|
||||
func (mach *Machine) SetDefaultKeyID(ctx context.Context, keyID string) error {
|
||||
return mach.Client.SetAccountData(ctx, event.AccountDataSecretStorageDefaultKey.Type, &DefaultSecretStorageKeyContent{keyID})
|
||||
}
|
||||
|
||||
// GetKeyData gets the details about the given key ID.
|
||||
func (mach *Machine) GetKeyData(keyID string) (keyData *KeyMetadata, err error) {
|
||||
func (mach *Machine) GetKeyData(ctx context.Context, keyID string) (keyData *KeyMetadata, err error) {
|
||||
keyData = &KeyMetadata{id: keyID}
|
||||
err = mach.Client.GetAccountData(fmt.Sprintf("%s.%s", event.AccountDataSecretStorageKey.Type, keyID), keyData)
|
||||
err = mach.Client.GetAccountData(ctx, fmt.Sprintf("%s.%s", event.AccountDataSecretStorageKey.Type, keyID), keyData)
|
||||
return
|
||||
}
|
||||
|
||||
// SetKeyData stores SSSS key metadata on the server.
|
||||
func (mach *Machine) SetKeyData(keyID string, keyData *KeyMetadata) error {
|
||||
return mach.Client.SetAccountData(fmt.Sprintf("%s.%s", event.AccountDataSecretStorageKey.Type, keyID), keyData)
|
||||
func (mach *Machine) SetKeyData(ctx context.Context, keyID string, keyData *KeyMetadata) error {
|
||||
return mach.Client.SetAccountData(ctx, fmt.Sprintf("%s.%s", event.AccountDataSecretStorageKey.Type, keyID), keyData)
|
||||
}
|
||||
|
||||
// GetDefaultKeyData gets the details about the default key ID (see GetDefaultKeyID).
|
||||
func (mach *Machine) GetDefaultKeyData() (keyID string, keyData *KeyMetadata, err error) {
|
||||
keyID, err = mach.GetDefaultKeyID()
|
||||
func (mach *Machine) GetDefaultKeyData(ctx context.Context) (keyID string, keyData *KeyMetadata, err error) {
|
||||
keyID, err = mach.GetDefaultKeyID(ctx)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
keyData, err = mach.GetKeyData(keyID)
|
||||
keyData, err = mach.GetKeyData(ctx, keyID)
|
||||
return
|
||||
}
|
||||
|
||||
// GetDecryptedAccountData gets the account data event with the given event type and decrypts it using the given key.
|
||||
func (mach *Machine) GetDecryptedAccountData(eventType event.Type, key *Key) ([]byte, error) {
|
||||
func (mach *Machine) GetDecryptedAccountData(ctx context.Context, eventType event.Type, key *Key) ([]byte, error) {
|
||||
var encData EncryptedAccountDataEventContent
|
||||
err := mach.Client.GetAccountData(eventType.Type, &encData)
|
||||
err := mach.Client.GetAccountData(ctx, eventType.Type, &encData)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -82,7 +83,7 @@ func (mach *Machine) GetDecryptedAccountData(eventType event.Type, key *Key) ([]
|
||||
}
|
||||
|
||||
// SetEncryptedAccountData encrypts the given data with the given keys and stores it on the server.
|
||||
func (mach *Machine) SetEncryptedAccountData(eventType event.Type, data []byte, keys ...*Key) error {
|
||||
func (mach *Machine) SetEncryptedAccountData(ctx context.Context, eventType event.Type, data []byte, keys ...*Key) error {
|
||||
if len(keys) == 0 {
|
||||
return ErrNoKeyGiven
|
||||
}
|
||||
@@ -90,17 +91,17 @@ func (mach *Machine) SetEncryptedAccountData(eventType event.Type, data []byte,
|
||||
for _, key := range keys {
|
||||
encrypted[key.ID] = key.Encrypt(eventType.Type, data)
|
||||
}
|
||||
return mach.Client.SetAccountData(eventType.Type, &EncryptedAccountDataEventContent{Encrypted: encrypted})
|
||||
return mach.Client.SetAccountData(ctx, eventType.Type, &EncryptedAccountDataEventContent{Encrypted: encrypted})
|
||||
}
|
||||
|
||||
// GenerateAndUploadKey generates a new SSSS key and stores the metadata on the server.
|
||||
func (mach *Machine) GenerateAndUploadKey(passphrase string) (key *Key, err error) {
|
||||
func (mach *Machine) GenerateAndUploadKey(ctx context.Context, passphrase string) (key *Key, err error) {
|
||||
key, err = NewKey(passphrase)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to generate new key: %w", err)
|
||||
}
|
||||
|
||||
err = mach.SetKeyData(key.ID, key.Metadata)
|
||||
err = mach.SetKeyData(ctx, key.ID, key.Metadata)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("failed to upload key: %w", err)
|
||||
}
|
||||
|
||||
168
vendor/maunium.net/go/mautrix/crypto/store.go
generated
vendored
168
vendor/maunium.net/go/mautrix/crypto/store.go
generated
vendored
@@ -1,4 +1,4 @@
|
||||
// Copyright (c) 2022 Tulir Asokan
|
||||
// Copyright (c) 2024 Tulir Asokan
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||
// License, v. 2.0. If a copy of the MPL was not distributed with this
|
||||
@@ -26,64 +26,64 @@ var ErrGroupSessionWithheld error = &event.RoomKeyWithheldEventContent{}
|
||||
type Store interface {
|
||||
// Flush ensures that everything in the store is persisted to disk.
|
||||
// This doesn't have to do anything, e.g. for database-backed implementations that persist everything immediately.
|
||||
Flush() error
|
||||
Flush(context.Context) error
|
||||
|
||||
// PutAccount updates the OlmAccount in the store.
|
||||
PutAccount(*OlmAccount) error
|
||||
PutAccount(context.Context, *OlmAccount) error
|
||||
// GetAccount returns the OlmAccount in the store that was previously inserted with PutAccount.
|
||||
GetAccount() (*OlmAccount, error)
|
||||
GetAccount(ctx context.Context) (*OlmAccount, error)
|
||||
|
||||
// AddSession inserts an Olm session into the store.
|
||||
AddSession(id.SenderKey, *OlmSession) error
|
||||
AddSession(context.Context, id.SenderKey, *OlmSession) error
|
||||
// HasSession returns whether or not the store has an Olm session with the given sender key.
|
||||
HasSession(id.SenderKey) bool
|
||||
HasSession(context.Context, id.SenderKey) bool
|
||||
// GetSessions returns all Olm sessions in the store with the given sender key.
|
||||
GetSessions(id.SenderKey) (OlmSessionList, error)
|
||||
GetSessions(context.Context, id.SenderKey) (OlmSessionList, error)
|
||||
// GetLatestSession returns the session with the highest session ID (lexiographically sorting).
|
||||
// It's usually safe to return the most recently added session if sorting by session ID is too difficult.
|
||||
GetLatestSession(id.SenderKey) (*OlmSession, error)
|
||||
GetLatestSession(context.Context, id.SenderKey) (*OlmSession, error)
|
||||
// UpdateSession updates a session that has previously been inserted with AddSession.
|
||||
UpdateSession(id.SenderKey, *OlmSession) error
|
||||
UpdateSession(context.Context, id.SenderKey, *OlmSession) error
|
||||
|
||||
// PutGroupSession inserts an inbound Megolm session into the store. If an earlier withhold event has been inserted
|
||||
// with PutWithheldGroupSession, this call should replace that. However, PutWithheldGroupSession must not replace
|
||||
// sessions inserted with this call.
|
||||
PutGroupSession(id.RoomID, id.SenderKey, id.SessionID, *InboundGroupSession) error
|
||||
PutGroupSession(context.Context, id.RoomID, id.SenderKey, id.SessionID, *InboundGroupSession) error
|
||||
// GetGroupSession gets an inbound Megolm session from the store. If the group session has been withheld
|
||||
// (i.e. a room key withheld event has been saved with PutWithheldGroupSession), this should return the
|
||||
// ErrGroupSessionWithheld error. The caller may use GetWithheldGroupSession to find more details.
|
||||
GetGroupSession(id.RoomID, id.SenderKey, id.SessionID) (*InboundGroupSession, error)
|
||||
GetGroupSession(context.Context, id.RoomID, id.SenderKey, id.SessionID) (*InboundGroupSession, error)
|
||||
// RedactGroupSession removes the session data for the given inbound Megolm session from the store.
|
||||
RedactGroupSession(id.RoomID, id.SenderKey, id.SessionID, string) error
|
||||
RedactGroupSession(context.Context, id.RoomID, id.SenderKey, id.SessionID, string) error
|
||||
// RedactGroupSessions removes the session data for all inbound Megolm sessions from a specific device and/or in a specific room.
|
||||
RedactGroupSessions(id.RoomID, id.SenderKey, string) ([]id.SessionID, error)
|
||||
RedactGroupSessions(context.Context, id.RoomID, id.SenderKey, string) ([]id.SessionID, error)
|
||||
// RedactExpiredGroupSessions removes the session data for all inbound Megolm sessions that have expired.
|
||||
RedactExpiredGroupSessions() ([]id.SessionID, error)
|
||||
RedactExpiredGroupSessions(context.Context) ([]id.SessionID, error)
|
||||
// RedactOutdatedGroupSessions removes the session data for all inbound Megolm sessions that are lacking the expiration metadata.
|
||||
RedactOutdatedGroupSessions() ([]id.SessionID, error)
|
||||
RedactOutdatedGroupSessions(context.Context) ([]id.SessionID, error)
|
||||
// PutWithheldGroupSession tells the store that a specific Megolm session was withheld.
|
||||
PutWithheldGroupSession(event.RoomKeyWithheldEventContent) error
|
||||
PutWithheldGroupSession(context.Context, event.RoomKeyWithheldEventContent) error
|
||||
// GetWithheldGroupSession gets the event content that was previously inserted with PutWithheldGroupSession.
|
||||
GetWithheldGroupSession(id.RoomID, id.SenderKey, id.SessionID) (*event.RoomKeyWithheldEventContent, error)
|
||||
GetWithheldGroupSession(context.Context, id.RoomID, id.SenderKey, id.SessionID) (*event.RoomKeyWithheldEventContent, error)
|
||||
|
||||
// GetGroupSessionsForRoom gets all the inbound Megolm sessions for a specific room. This is used for creating key
|
||||
// export files. Unlike GetGroupSession, this should not return any errors about withheld keys.
|
||||
GetGroupSessionsForRoom(id.RoomID) ([]*InboundGroupSession, error)
|
||||
GetGroupSessionsForRoom(context.Context, id.RoomID) ([]*InboundGroupSession, error)
|
||||
// GetAllGroupSessions gets all the inbound Megolm sessions in the store. This is used for creating key export
|
||||
// files. Unlike GetGroupSession, this should not return any errors about withheld keys.
|
||||
GetAllGroupSessions() ([]*InboundGroupSession, error)
|
||||
GetAllGroupSessions(context.Context) ([]*InboundGroupSession, error)
|
||||
|
||||
// AddOutboundGroupSession inserts the given outbound Megolm session into the store.
|
||||
//
|
||||
// The store should index inserted sessions by the RoomID field to support getting and removing sessions.
|
||||
// There will only be one outbound session per room ID at a time.
|
||||
AddOutboundGroupSession(*OutboundGroupSession) error
|
||||
AddOutboundGroupSession(context.Context, *OutboundGroupSession) error
|
||||
// UpdateOutboundGroupSession updates the given outbound Megolm session in the store.
|
||||
UpdateOutboundGroupSession(*OutboundGroupSession) error
|
||||
UpdateOutboundGroupSession(context.Context, *OutboundGroupSession) error
|
||||
// GetOutboundGroupSession gets the stored outbound Megolm session for the given room ID from the store.
|
||||
GetOutboundGroupSession(id.RoomID) (*OutboundGroupSession, error)
|
||||
GetOutboundGroupSession(context.Context, id.RoomID) (*OutboundGroupSession, error)
|
||||
// RemoveOutboundGroupSession removes the stored outbound Megolm session for the given room ID.
|
||||
RemoveOutboundGroupSession(id.RoomID) error
|
||||
RemoveOutboundGroupSession(context.Context, id.RoomID) error
|
||||
|
||||
// ValidateMessageIndex validates that the given message details aren't from a replay attack.
|
||||
//
|
||||
@@ -96,29 +96,33 @@ type Store interface {
|
||||
ValidateMessageIndex(ctx context.Context, senderKey id.SenderKey, sessionID id.SessionID, eventID id.EventID, index uint, timestamp int64) (bool, error)
|
||||
|
||||
// GetDevices returns a map from device ID to id.Device struct containing all devices of a given user.
|
||||
GetDevices(id.UserID) (map[id.DeviceID]*id.Device, error)
|
||||
GetDevices(context.Context, id.UserID) (map[id.DeviceID]*id.Device, error)
|
||||
// GetDevice returns a specific device of a given user.
|
||||
GetDevice(id.UserID, id.DeviceID) (*id.Device, error)
|
||||
GetDevice(context.Context, id.UserID, id.DeviceID) (*id.Device, error)
|
||||
// PutDevice stores a single device for a user, replacing it if it exists already.
|
||||
PutDevice(id.UserID, *id.Device) error
|
||||
PutDevice(context.Context, id.UserID, *id.Device) error
|
||||
// PutDevices overrides the stored device list for the given user with the given list.
|
||||
PutDevices(id.UserID, map[id.DeviceID]*id.Device) error
|
||||
PutDevices(context.Context, id.UserID, map[id.DeviceID]*id.Device) error
|
||||
// FindDeviceByKey finds a specific device by its identity key.
|
||||
FindDeviceByKey(id.UserID, id.IdentityKey) (*id.Device, error)
|
||||
FindDeviceByKey(context.Context, id.UserID, id.IdentityKey) (*id.Device, error)
|
||||
// FilterTrackedUsers returns a filtered version of the given list that only includes user IDs whose device lists
|
||||
// have been stored with PutDevices. A user is considered tracked even if the PutDevices list was empty.
|
||||
FilterTrackedUsers([]id.UserID) ([]id.UserID, error)
|
||||
FilterTrackedUsers(context.Context, []id.UserID) ([]id.UserID, error)
|
||||
// MarkTrackedUsersOutdated flags that the device list for given users are outdated.
|
||||
MarkTrackedUsersOutdated(context.Context, []id.UserID) error
|
||||
// GetOutdatedTrackerUsers gets all tracked users whose devices need to be updated.
|
||||
GetOutdatedTrackedUsers(context.Context) ([]id.UserID, error)
|
||||
|
||||
// PutCrossSigningKey stores a cross-signing key of some user along with its usage.
|
||||
PutCrossSigningKey(id.UserID, id.CrossSigningUsage, id.Ed25519) error
|
||||
PutCrossSigningKey(context.Context, id.UserID, id.CrossSigningUsage, id.Ed25519) error
|
||||
// GetCrossSigningKeys retrieves a user's stored cross-signing keys.
|
||||
GetCrossSigningKeys(id.UserID) (map[id.CrossSigningUsage]id.CrossSigningKey, error)
|
||||
GetCrossSigningKeys(context.Context, id.UserID) (map[id.CrossSigningUsage]id.CrossSigningKey, error)
|
||||
// PutSignature stores a signature of a cross-signing or device key along with the signer's user ID and key.
|
||||
PutSignature(signedUser id.UserID, signedKey id.Ed25519, signerUser id.UserID, signerKey id.Ed25519, signature string) error
|
||||
PutSignature(ctx context.Context, signedUser id.UserID, signedKey id.Ed25519, signerUser id.UserID, signerKey id.Ed25519, signature string) error
|
||||
// IsKeySignedBy returns whether a cross-signing or device key is signed by the given signer.
|
||||
IsKeySignedBy(userID id.UserID, key id.Ed25519, signedByUser id.UserID, signedByKey id.Ed25519) (bool, error)
|
||||
IsKeySignedBy(ctx context.Context, userID id.UserID, key id.Ed25519, signedByUser id.UserID, signedByKey id.Ed25519) (bool, error)
|
||||
// DropSignaturesByKey deletes the signatures made by the given user and key from the store. It returns the number of signatures deleted.
|
||||
DropSignaturesByKey(id.UserID, id.Ed25519) (int64, error)
|
||||
DropSignaturesByKey(context.Context, id.UserID, id.Ed25519) (int64, error)
|
||||
}
|
||||
|
||||
type messageIndexKey struct {
|
||||
@@ -148,6 +152,7 @@ type MemoryStore struct {
|
||||
Devices map[id.UserID]map[id.DeviceID]*id.Device
|
||||
CrossSigningKeys map[id.UserID]map[id.CrossSigningUsage]id.CrossSigningKey
|
||||
KeySignatures map[id.UserID]map[id.Ed25519]map[id.UserID]map[id.Ed25519]string
|
||||
OutdatedUsers map[id.UserID]struct{}
|
||||
}
|
||||
|
||||
var _ Store = (*MemoryStore)(nil)
|
||||
@@ -167,21 +172,22 @@ func NewMemoryStore(saveCallback func() error) *MemoryStore {
|
||||
Devices: make(map[id.UserID]map[id.DeviceID]*id.Device),
|
||||
CrossSigningKeys: make(map[id.UserID]map[id.CrossSigningUsage]id.CrossSigningKey),
|
||||
KeySignatures: make(map[id.UserID]map[id.Ed25519]map[id.UserID]map[id.Ed25519]string),
|
||||
OutdatedUsers: make(map[id.UserID]struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
func (gs *MemoryStore) Flush() error {
|
||||
func (gs *MemoryStore) Flush(_ context.Context) error {
|
||||
gs.lock.Lock()
|
||||
err := gs.save()
|
||||
gs.lock.Unlock()
|
||||
return err
|
||||
}
|
||||
|
||||
func (gs *MemoryStore) GetAccount() (*OlmAccount, error) {
|
||||
func (gs *MemoryStore) GetAccount(_ context.Context) (*OlmAccount, error) {
|
||||
return gs.Account, nil
|
||||
}
|
||||
|
||||
func (gs *MemoryStore) PutAccount(account *OlmAccount) error {
|
||||
func (gs *MemoryStore) PutAccount(_ context.Context, account *OlmAccount) error {
|
||||
gs.lock.Lock()
|
||||
gs.Account = account
|
||||
err := gs.save()
|
||||
@@ -189,7 +195,7 @@ func (gs *MemoryStore) PutAccount(account *OlmAccount) error {
|
||||
return err
|
||||
}
|
||||
|
||||
func (gs *MemoryStore) GetSessions(senderKey id.SenderKey) (OlmSessionList, error) {
|
||||
func (gs *MemoryStore) GetSessions(_ context.Context, senderKey id.SenderKey) (OlmSessionList, error) {
|
||||
gs.lock.Lock()
|
||||
sessions, ok := gs.Sessions[senderKey]
|
||||
if !ok {
|
||||
@@ -200,7 +206,7 @@ func (gs *MemoryStore) GetSessions(senderKey id.SenderKey) (OlmSessionList, erro
|
||||
return sessions, nil
|
||||
}
|
||||
|
||||
func (gs *MemoryStore) AddSession(senderKey id.SenderKey, session *OlmSession) error {
|
||||
func (gs *MemoryStore) AddSession(_ context.Context, senderKey id.SenderKey, session *OlmSession) error {
|
||||
gs.lock.Lock()
|
||||
sessions, _ := gs.Sessions[senderKey]
|
||||
gs.Sessions[senderKey] = append(sessions, session)
|
||||
@@ -210,19 +216,19 @@ func (gs *MemoryStore) AddSession(senderKey id.SenderKey, session *OlmSession) e
|
||||
return err
|
||||
}
|
||||
|
||||
func (gs *MemoryStore) UpdateSession(_ id.SenderKey, _ *OlmSession) error {
|
||||
func (gs *MemoryStore) UpdateSession(_ context.Context, _ id.SenderKey, _ *OlmSession) error {
|
||||
// we don't need to do anything here because the session is a pointer and already stored in our map
|
||||
return gs.save()
|
||||
}
|
||||
|
||||
func (gs *MemoryStore) HasSession(senderKey id.SenderKey) bool {
|
||||
func (gs *MemoryStore) HasSession(_ context.Context, senderKey id.SenderKey) bool {
|
||||
gs.lock.RLock()
|
||||
sessions, ok := gs.Sessions[senderKey]
|
||||
gs.lock.RUnlock()
|
||||
return ok && len(sessions) > 0 && !sessions[0].Expired()
|
||||
}
|
||||
|
||||
func (gs *MemoryStore) GetLatestSession(senderKey id.SenderKey) (*OlmSession, error) {
|
||||
func (gs *MemoryStore) GetLatestSession(_ context.Context, senderKey id.SenderKey) (*OlmSession, error) {
|
||||
gs.lock.RLock()
|
||||
sessions, ok := gs.Sessions[senderKey]
|
||||
gs.lock.RUnlock()
|
||||
@@ -246,7 +252,7 @@ func (gs *MemoryStore) getGroupSessions(roomID id.RoomID, senderKey id.SenderKey
|
||||
return sender
|
||||
}
|
||||
|
||||
func (gs *MemoryStore) PutGroupSession(roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID, igs *InboundGroupSession) error {
|
||||
func (gs *MemoryStore) PutGroupSession(_ context.Context, roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID, igs *InboundGroupSession) error {
|
||||
gs.lock.Lock()
|
||||
gs.getGroupSessions(roomID, senderKey)[sessionID] = igs
|
||||
err := gs.save()
|
||||
@@ -254,7 +260,7 @@ func (gs *MemoryStore) PutGroupSession(roomID id.RoomID, senderKey id.SenderKey,
|
||||
return err
|
||||
}
|
||||
|
||||
func (gs *MemoryStore) GetGroupSession(roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID) (*InboundGroupSession, error) {
|
||||
func (gs *MemoryStore) GetGroupSession(_ context.Context, roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID) (*InboundGroupSession, error) {
|
||||
gs.lock.Lock()
|
||||
session, ok := gs.getGroupSessions(roomID, senderKey)[sessionID]
|
||||
if !ok {
|
||||
@@ -269,7 +275,7 @@ func (gs *MemoryStore) GetGroupSession(roomID id.RoomID, senderKey id.SenderKey,
|
||||
return session, nil
|
||||
}
|
||||
|
||||
func (gs *MemoryStore) RedactGroupSession(roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID, reason string) error {
|
||||
func (gs *MemoryStore) RedactGroupSession(_ context.Context, roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID, reason string) error {
|
||||
gs.lock.Lock()
|
||||
delete(gs.getGroupSessions(roomID, senderKey), sessionID)
|
||||
err := gs.save()
|
||||
@@ -277,7 +283,7 @@ func (gs *MemoryStore) RedactGroupSession(roomID id.RoomID, senderKey id.SenderK
|
||||
return err
|
||||
}
|
||||
|
||||
func (gs *MemoryStore) RedactGroupSessions(roomID id.RoomID, senderKey id.SenderKey, reason string) ([]id.SessionID, error) {
|
||||
func (gs *MemoryStore) RedactGroupSessions(_ context.Context, roomID id.RoomID, senderKey id.SenderKey, reason string) ([]id.SessionID, error) {
|
||||
gs.lock.Lock()
|
||||
var sessionIDs []id.SessionID
|
||||
if roomID != "" && senderKey != "" {
|
||||
@@ -315,11 +321,11 @@ func (gs *MemoryStore) RedactGroupSessions(roomID id.RoomID, senderKey id.Sender
|
||||
return sessionIDs, err
|
||||
}
|
||||
|
||||
func (gs *MemoryStore) RedactExpiredGroupSessions() ([]id.SessionID, error) {
|
||||
func (gs *MemoryStore) RedactExpiredGroupSessions(_ context.Context) ([]id.SessionID, error) {
|
||||
return nil, fmt.Errorf("not implemented")
|
||||
}
|
||||
|
||||
func (gs *MemoryStore) RedactOutdatedGroupSessions() ([]id.SessionID, error) {
|
||||
func (gs *MemoryStore) RedactOutdatedGroupSessions(_ context.Context) ([]id.SessionID, error) {
|
||||
return nil, fmt.Errorf("not implemented")
|
||||
}
|
||||
|
||||
@@ -337,7 +343,7 @@ func (gs *MemoryStore) getWithheldGroupSessions(roomID id.RoomID, senderKey id.S
|
||||
return sender
|
||||
}
|
||||
|
||||
func (gs *MemoryStore) PutWithheldGroupSession(content event.RoomKeyWithheldEventContent) error {
|
||||
func (gs *MemoryStore) PutWithheldGroupSession(_ context.Context, content event.RoomKeyWithheldEventContent) error {
|
||||
gs.lock.Lock()
|
||||
gs.getWithheldGroupSessions(content.RoomID, content.SenderKey)[content.SessionID] = &content
|
||||
err := gs.save()
|
||||
@@ -345,7 +351,7 @@ func (gs *MemoryStore) PutWithheldGroupSession(content event.RoomKeyWithheldEven
|
||||
return err
|
||||
}
|
||||
|
||||
func (gs *MemoryStore) GetWithheldGroupSession(roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID) (*event.RoomKeyWithheldEventContent, error) {
|
||||
func (gs *MemoryStore) GetWithheldGroupSession(_ context.Context, roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID) (*event.RoomKeyWithheldEventContent, error) {
|
||||
gs.lock.Lock()
|
||||
session, ok := gs.getWithheldGroupSessions(roomID, senderKey)[sessionID]
|
||||
gs.lock.Unlock()
|
||||
@@ -355,7 +361,7 @@ func (gs *MemoryStore) GetWithheldGroupSession(roomID id.RoomID, senderKey id.Se
|
||||
return session, nil
|
||||
}
|
||||
|
||||
func (gs *MemoryStore) GetGroupSessionsForRoom(roomID id.RoomID) ([]*InboundGroupSession, error) {
|
||||
func (gs *MemoryStore) GetGroupSessionsForRoom(_ context.Context, roomID id.RoomID) ([]*InboundGroupSession, error) {
|
||||
gs.lock.Lock()
|
||||
defer gs.lock.Unlock()
|
||||
room, ok := gs.GroupSessions[roomID]
|
||||
@@ -371,7 +377,7 @@ func (gs *MemoryStore) GetGroupSessionsForRoom(roomID id.RoomID) ([]*InboundGrou
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (gs *MemoryStore) GetAllGroupSessions() ([]*InboundGroupSession, error) {
|
||||
func (gs *MemoryStore) GetAllGroupSessions(_ context.Context) ([]*InboundGroupSession, error) {
|
||||
gs.lock.Lock()
|
||||
var result []*InboundGroupSession
|
||||
for _, room := range gs.GroupSessions {
|
||||
@@ -385,7 +391,7 @@ func (gs *MemoryStore) GetAllGroupSessions() ([]*InboundGroupSession, error) {
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (gs *MemoryStore) AddOutboundGroupSession(session *OutboundGroupSession) error {
|
||||
func (gs *MemoryStore) AddOutboundGroupSession(_ context.Context, session *OutboundGroupSession) error {
|
||||
gs.lock.Lock()
|
||||
gs.OutGroupSessions[session.RoomID] = session
|
||||
err := gs.save()
|
||||
@@ -393,12 +399,12 @@ func (gs *MemoryStore) AddOutboundGroupSession(session *OutboundGroupSession) er
|
||||
return err
|
||||
}
|
||||
|
||||
func (gs *MemoryStore) UpdateOutboundGroupSession(_ *OutboundGroupSession) error {
|
||||
func (gs *MemoryStore) UpdateOutboundGroupSession(_ context.Context, _ *OutboundGroupSession) error {
|
||||
// we don't need to do anything here because the session is a pointer and already stored in our map
|
||||
return gs.save()
|
||||
}
|
||||
|
||||
func (gs *MemoryStore) GetOutboundGroupSession(roomID id.RoomID) (*OutboundGroupSession, error) {
|
||||
func (gs *MemoryStore) GetOutboundGroupSession(_ context.Context, roomID id.RoomID) (*OutboundGroupSession, error) {
|
||||
gs.lock.RLock()
|
||||
session, ok := gs.OutGroupSessions[roomID]
|
||||
gs.lock.RUnlock()
|
||||
@@ -408,7 +414,7 @@ func (gs *MemoryStore) GetOutboundGroupSession(roomID id.RoomID) (*OutboundGroup
|
||||
return session, nil
|
||||
}
|
||||
|
||||
func (gs *MemoryStore) RemoveOutboundGroupSession(roomID id.RoomID) error {
|
||||
func (gs *MemoryStore) RemoveOutboundGroupSession(_ context.Context, roomID id.RoomID) error {
|
||||
gs.lock.Lock()
|
||||
session, ok := gs.OutGroupSessions[roomID]
|
||||
if !ok || session == nil {
|
||||
@@ -443,7 +449,7 @@ func (gs *MemoryStore) ValidateMessageIndex(_ context.Context, senderKey id.Send
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func (gs *MemoryStore) GetDevices(userID id.UserID) (map[id.DeviceID]*id.Device, error) {
|
||||
func (gs *MemoryStore) GetDevices(_ context.Context, userID id.UserID) (map[id.DeviceID]*id.Device, error) {
|
||||
gs.lock.RLock()
|
||||
devices, ok := gs.Devices[userID]
|
||||
if !ok {
|
||||
@@ -453,7 +459,7 @@ func (gs *MemoryStore) GetDevices(userID id.UserID) (map[id.DeviceID]*id.Device,
|
||||
return devices, nil
|
||||
}
|
||||
|
||||
func (gs *MemoryStore) GetDevice(userID id.UserID, deviceID id.DeviceID) (*id.Device, error) {
|
||||
func (gs *MemoryStore) GetDevice(_ context.Context, userID id.UserID, deviceID id.DeviceID) (*id.Device, error) {
|
||||
gs.lock.RLock()
|
||||
defer gs.lock.RUnlock()
|
||||
devices, ok := gs.Devices[userID]
|
||||
@@ -467,7 +473,7 @@ func (gs *MemoryStore) GetDevice(userID id.UserID, deviceID id.DeviceID) (*id.De
|
||||
return device, nil
|
||||
}
|
||||
|
||||
func (gs *MemoryStore) FindDeviceByKey(userID id.UserID, identityKey id.IdentityKey) (*id.Device, error) {
|
||||
func (gs *MemoryStore) FindDeviceByKey(_ context.Context, userID id.UserID, identityKey id.IdentityKey) (*id.Device, error) {
|
||||
gs.lock.RLock()
|
||||
defer gs.lock.RUnlock()
|
||||
devices, ok := gs.Devices[userID]
|
||||
@@ -482,7 +488,7 @@ func (gs *MemoryStore) FindDeviceByKey(userID id.UserID, identityKey id.Identity
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (gs *MemoryStore) PutDevice(userID id.UserID, device *id.Device) error {
|
||||
func (gs *MemoryStore) PutDevice(_ context.Context, userID id.UserID, device *id.Device) error {
|
||||
gs.lock.Lock()
|
||||
devices, ok := gs.Devices[userID]
|
||||
if !ok {
|
||||
@@ -495,15 +501,18 @@ func (gs *MemoryStore) PutDevice(userID id.UserID, device *id.Device) error {
|
||||
return err
|
||||
}
|
||||
|
||||
func (gs *MemoryStore) PutDevices(userID id.UserID, devices map[id.DeviceID]*id.Device) error {
|
||||
func (gs *MemoryStore) PutDevices(_ context.Context, userID id.UserID, devices map[id.DeviceID]*id.Device) error {
|
||||
gs.lock.Lock()
|
||||
gs.Devices[userID] = devices
|
||||
err := gs.save()
|
||||
if err == nil {
|
||||
delete(gs.OutdatedUsers, userID)
|
||||
}
|
||||
gs.lock.Unlock()
|
||||
return err
|
||||
}
|
||||
|
||||
func (gs *MemoryStore) FilterTrackedUsers(users []id.UserID) ([]id.UserID, error) {
|
||||
func (gs *MemoryStore) FilterTrackedUsers(_ context.Context, users []id.UserID) ([]id.UserID, error) {
|
||||
gs.lock.RLock()
|
||||
var ptr int
|
||||
for _, userID := range users {
|
||||
@@ -517,7 +526,28 @@ func (gs *MemoryStore) FilterTrackedUsers(users []id.UserID) ([]id.UserID, error
|
||||
return users[:ptr], nil
|
||||
}
|
||||
|
||||
func (gs *MemoryStore) PutCrossSigningKey(userID id.UserID, usage id.CrossSigningUsage, key id.Ed25519) error {
|
||||
func (gs *MemoryStore) MarkTrackedUsersOutdated(_ context.Context, users []id.UserID) error {
|
||||
gs.lock.Lock()
|
||||
for _, userID := range users {
|
||||
if _, ok := gs.Devices[userID]; ok {
|
||||
gs.OutdatedUsers[userID] = struct{}{}
|
||||
}
|
||||
}
|
||||
gs.lock.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (gs *MemoryStore) GetOutdatedTrackedUsers(_ context.Context) ([]id.UserID, error) {
|
||||
gs.lock.RLock()
|
||||
users := make([]id.UserID, 0, len(gs.OutdatedUsers))
|
||||
for userID := range gs.OutdatedUsers {
|
||||
users = append(users, userID)
|
||||
}
|
||||
gs.lock.RUnlock()
|
||||
return users, nil
|
||||
}
|
||||
|
||||
func (gs *MemoryStore) PutCrossSigningKey(_ context.Context, userID id.UserID, usage id.CrossSigningUsage, key id.Ed25519) error {
|
||||
gs.lock.RLock()
|
||||
userKeys, ok := gs.CrossSigningKeys[userID]
|
||||
if !ok {
|
||||
@@ -539,7 +569,7 @@ func (gs *MemoryStore) PutCrossSigningKey(userID id.UserID, usage id.CrossSignin
|
||||
return err
|
||||
}
|
||||
|
||||
func (gs *MemoryStore) GetCrossSigningKeys(userID id.UserID) (map[id.CrossSigningUsage]id.CrossSigningKey, error) {
|
||||
func (gs *MemoryStore) GetCrossSigningKeys(_ context.Context, userID id.UserID) (map[id.CrossSigningUsage]id.CrossSigningKey, error) {
|
||||
gs.lock.RLock()
|
||||
defer gs.lock.RUnlock()
|
||||
keys, ok := gs.CrossSigningKeys[userID]
|
||||
@@ -549,7 +579,7 @@ func (gs *MemoryStore) GetCrossSigningKeys(userID id.UserID) (map[id.CrossSignin
|
||||
return keys, nil
|
||||
}
|
||||
|
||||
func (gs *MemoryStore) PutSignature(signedUserID id.UserID, signedKey id.Ed25519, signerUserID id.UserID, signerKey id.Ed25519, signature string) error {
|
||||
func (gs *MemoryStore) PutSignature(_ context.Context, signedUserID id.UserID, signedKey id.Ed25519, signerUserID id.UserID, signerKey id.Ed25519, signature string) error {
|
||||
gs.lock.RLock()
|
||||
signedUserSigs, ok := gs.KeySignatures[signedUserID]
|
||||
if !ok {
|
||||
@@ -572,7 +602,7 @@ func (gs *MemoryStore) PutSignature(signedUserID id.UserID, signedKey id.Ed25519
|
||||
return err
|
||||
}
|
||||
|
||||
func (gs *MemoryStore) GetSignaturesForKeyBy(userID id.UserID, key id.Ed25519, signerID id.UserID) (map[id.Ed25519]string, error) {
|
||||
func (gs *MemoryStore) GetSignaturesForKeyBy(_ context.Context, userID id.UserID, key id.Ed25519, signerID id.UserID) (map[id.Ed25519]string, error) {
|
||||
gs.lock.RLock()
|
||||
defer gs.lock.RUnlock()
|
||||
userKeys, ok := gs.KeySignatures[userID]
|
||||
@@ -590,8 +620,8 @@ func (gs *MemoryStore) GetSignaturesForKeyBy(userID id.UserID, key id.Ed25519, s
|
||||
return sigsBySigner, nil
|
||||
}
|
||||
|
||||
func (gs *MemoryStore) IsKeySignedBy(userID id.UserID, key id.Ed25519, signerID id.UserID, signerKey id.Ed25519) (bool, error) {
|
||||
sigs, err := gs.GetSignaturesForKeyBy(userID, key, signerID)
|
||||
func (gs *MemoryStore) IsKeySignedBy(ctx context.Context, userID id.UserID, key id.Ed25519, signerID id.UserID, signerKey id.Ed25519) (bool, error) {
|
||||
sigs, err := gs.GetSignaturesForKeyBy(ctx, userID, key, signerID)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
@@ -599,7 +629,7 @@ func (gs *MemoryStore) IsKeySignedBy(userID id.UserID, key id.Ed25519, signerID
|
||||
return ok, nil
|
||||
}
|
||||
|
||||
func (gs *MemoryStore) DropSignaturesByKey(userID id.UserID, key id.Ed25519) (int64, error) {
|
||||
func (gs *MemoryStore) DropSignaturesByKey(_ context.Context, userID id.UserID, key id.Ed25519) (int64, error) {
|
||||
var count int64
|
||||
gs.lock.RLock()
|
||||
for _, userSigs := range gs.KeySignatures {
|
||||
|
||||
2
vendor/maunium.net/go/mautrix/crypto/utils/utils.go
generated
vendored
2
vendor/maunium.net/go/mautrix/crypto/utils/utils.go
generated
vendored
@@ -10,10 +10,10 @@ import (
|
||||
"crypto/aes"
|
||||
"crypto/cipher"
|
||||
"crypto/hmac"
|
||||
"crypto/rand"
|
||||
"crypto/sha256"
|
||||
"crypto/sha512"
|
||||
"encoding/base64"
|
||||
"math/rand"
|
||||
"strings"
|
||||
|
||||
"go.mau.fi/util/base58"
|
||||
|
||||
166
vendor/maunium.net/go/mautrix/crypto/verification.go
generated
vendored
166
vendor/maunium.net/go/mautrix/crypto/verification.go
generated
vendored
@@ -54,8 +54,8 @@ const (
|
||||
)
|
||||
|
||||
// sendToOneDevice sends a to-device event to a single device.
|
||||
func (mach *OlmMachine) sendToOneDevice(userID id.UserID, deviceID id.DeviceID, eventType event.Type, content interface{}) error {
|
||||
_, err := mach.Client.SendToDevice(eventType, &mautrix.ReqSendToDevice{
|
||||
func (mach *OlmMachine) sendToOneDevice(ctx context.Context, userID id.UserID, deviceID id.DeviceID, eventType event.Type, content interface{}) error {
|
||||
_, err := mach.Client.SendToDevice(ctx, eventType, &mautrix.ReqSendToDevice{
|
||||
Messages: map[id.UserID]map[id.DeviceID]*event.Content{
|
||||
userID: {
|
||||
deviceID: {
|
||||
@@ -118,19 +118,19 @@ type verificationState struct {
|
||||
}
|
||||
|
||||
// getTransactionState retrieves the given transaction's state, or cancels the transaction if it cannot be found or there is a mismatch.
|
||||
func (mach *OlmMachine) getTransactionState(transactionID string, userID id.UserID) (*verificationState, error) {
|
||||
func (mach *OlmMachine) getTransactionState(ctx context.Context, transactionID string, userID id.UserID) (*verificationState, error) {
|
||||
verStateInterface, ok := mach.keyVerificationTransactionState.Load(userID.String() + ":" + transactionID)
|
||||
if !ok {
|
||||
_ = mach.SendSASVerificationCancel(userID, id.DeviceID("*"), transactionID, "Unknown transaction: "+transactionID, event.VerificationCancelUnknownTransaction)
|
||||
_ = mach.SendSASVerificationCancel(ctx, userID, id.DeviceID("*"), transactionID, "Unknown transaction: "+transactionID, event.VerificationCancelUnknownTransaction)
|
||||
return nil, ErrUnknownTransaction
|
||||
}
|
||||
verState := verStateInterface.(*verificationState)
|
||||
if verState.otherDevice.UserID != userID {
|
||||
reason := fmt.Sprintf("Unknown user for transaction %v: %v", transactionID, userID)
|
||||
if verState.inRoomID == "" {
|
||||
_ = mach.SendSASVerificationCancel(userID, id.DeviceID("*"), transactionID, reason, event.VerificationCancelUserMismatch)
|
||||
_ = mach.SendSASVerificationCancel(ctx, userID, id.DeviceID("*"), transactionID, reason, event.VerificationCancelUserMismatch)
|
||||
} else {
|
||||
_ = mach.SendInRoomSASVerificationCancel(verState.inRoomID, userID, transactionID, reason, event.VerificationCancelUserMismatch)
|
||||
_ = mach.SendInRoomSASVerificationCancel(ctx, verState.inRoomID, userID, transactionID, reason, event.VerificationCancelUserMismatch)
|
||||
}
|
||||
mach.keyVerificationTransactionState.Delete(userID.String() + ":" + transactionID)
|
||||
return nil, fmt.Errorf("%w %s: %s", ErrUnknownUserForTransaction, transactionID, userID)
|
||||
@@ -140,9 +140,9 @@ func (mach *OlmMachine) getTransactionState(transactionID string, userID id.User
|
||||
|
||||
// handleVerificationStart handles an incoming m.key.verification.start message.
|
||||
// It initializes the state for this SAS verification process and stores it.
|
||||
func (mach *OlmMachine) handleVerificationStart(userID id.UserID, content *event.VerificationStartEventContent, transactionID string, timeout time.Duration, inRoomID id.RoomID) {
|
||||
func (mach *OlmMachine) handleVerificationStart(ctx context.Context, userID id.UserID, content *event.VerificationStartEventContent, transactionID string, timeout time.Duration, inRoomID id.RoomID) {
|
||||
mach.Log.Debug().Msgf("Received verification start from %v", content.FromDevice)
|
||||
otherDevice, err := mach.GetOrFetchDevice(context.TODO(), userID, content.FromDevice)
|
||||
otherDevice, err := mach.GetOrFetchDevice(ctx, userID, content.FromDevice)
|
||||
if err != nil {
|
||||
mach.Log.Error().Msgf("Could not find device %v of user %v", content.FromDevice, userID)
|
||||
return
|
||||
@@ -150,9 +150,9 @@ func (mach *OlmMachine) handleVerificationStart(userID id.UserID, content *event
|
||||
warnAndCancel := func(logReason, cancelReason string) {
|
||||
mach.Log.Warn().Msgf("Canceling verification transaction %v as it %s", transactionID, logReason)
|
||||
if inRoomID == "" {
|
||||
_ = mach.SendSASVerificationCancel(otherDevice.UserID, otherDevice.DeviceID, transactionID, cancelReason, event.VerificationCancelUnknownMethod)
|
||||
_ = mach.SendSASVerificationCancel(ctx, otherDevice.UserID, otherDevice.DeviceID, transactionID, cancelReason, event.VerificationCancelUnknownMethod)
|
||||
} else {
|
||||
_ = mach.SendInRoomSASVerificationCancel(inRoomID, otherDevice.UserID, transactionID, cancelReason, event.VerificationCancelUnknownMethod)
|
||||
_ = mach.SendInRoomSASVerificationCancel(ctx, inRoomID, otherDevice.UserID, transactionID, cancelReason, event.VerificationCancelUnknownMethod)
|
||||
}
|
||||
}
|
||||
switch {
|
||||
@@ -168,21 +168,21 @@ func (mach *OlmMachine) handleVerificationStart(userID id.UserID, content *event
|
||||
case !content.SupportsSASMethod(event.SASDecimal):
|
||||
warnAndCancel("does not support decimal SAS", "Decimal SAS method must be supported")
|
||||
default:
|
||||
mach.actuallyStartVerification(userID, content, otherDevice, transactionID, timeout, inRoomID)
|
||||
mach.actuallyStartVerification(ctx, userID, content, otherDevice, transactionID, timeout, inRoomID)
|
||||
}
|
||||
}
|
||||
|
||||
func (mach *OlmMachine) actuallyStartVerification(userID id.UserID, content *event.VerificationStartEventContent, otherDevice *id.Device, transactionID string, timeout time.Duration, inRoomID id.RoomID) {
|
||||
func (mach *OlmMachine) actuallyStartVerification(ctx context.Context, userID id.UserID, content *event.VerificationStartEventContent, otherDevice *id.Device, transactionID string, timeout time.Duration, inRoomID id.RoomID) {
|
||||
if inRoomID != "" && transactionID != "" {
|
||||
verState, err := mach.getTransactionState(transactionID, userID)
|
||||
verState, err := mach.getTransactionState(ctx, transactionID, userID)
|
||||
if err != nil {
|
||||
mach.Log.Error().Msgf("Failed to get transaction state for in-room verification %s start: %v", transactionID, err)
|
||||
_ = mach.SendInRoomSASVerificationCancel(inRoomID, otherDevice.UserID, transactionID, "Internal state error in gomuks :(", "net.maunium.internal_error")
|
||||
_ = mach.SendInRoomSASVerificationCancel(ctx, inRoomID, otherDevice.UserID, transactionID, "Internal state error in gomuks :(", "net.maunium.internal_error")
|
||||
return
|
||||
}
|
||||
mach.timeoutAfter(verState, transactionID, timeout)
|
||||
mach.timeoutAfter(ctx, verState, transactionID, timeout)
|
||||
sasMethods := commonSASMethods(verState.hooks, content.ShortAuthenticationString)
|
||||
err = mach.SendInRoomSASVerificationAccept(inRoomID, userID, content, transactionID, verState.sas.GetPubkey(), sasMethods)
|
||||
err = mach.SendInRoomSASVerificationAccept(ctx, inRoomID, userID, content, transactionID, verState.sas.GetPubkey(), sasMethods)
|
||||
if err != nil {
|
||||
mach.Log.Error().Msgf("Error accepting in-room SAS verification: %v", err)
|
||||
}
|
||||
@@ -196,9 +196,9 @@ func (mach *OlmMachine) actuallyStartVerification(userID id.UserID, content *eve
|
||||
if len(sasMethods) == 0 {
|
||||
mach.Log.Error().Msgf("No common SAS methods: %v", content.ShortAuthenticationString)
|
||||
if inRoomID == "" {
|
||||
_ = mach.SendSASVerificationCancel(otherDevice.UserID, otherDevice.DeviceID, transactionID, "No common SAS methods", event.VerificationCancelUnknownMethod)
|
||||
_ = mach.SendSASVerificationCancel(ctx, otherDevice.UserID, otherDevice.DeviceID, transactionID, "No common SAS methods", event.VerificationCancelUnknownMethod)
|
||||
} else {
|
||||
_ = mach.SendInRoomSASVerificationCancel(inRoomID, otherDevice.UserID, transactionID, "No common SAS methods", event.VerificationCancelUnknownMethod)
|
||||
_ = mach.SendInRoomSASVerificationCancel(ctx, inRoomID, otherDevice.UserID, transactionID, "No common SAS methods", event.VerificationCancelUnknownMethod)
|
||||
}
|
||||
return
|
||||
}
|
||||
@@ -221,20 +221,20 @@ func (mach *OlmMachine) actuallyStartVerification(userID id.UserID, content *eve
|
||||
// transaction already exists
|
||||
mach.Log.Error().Msgf("Transaction %v already exists, canceling", transactionID)
|
||||
if inRoomID == "" {
|
||||
_ = mach.SendSASVerificationCancel(otherDevice.UserID, otherDevice.DeviceID, transactionID, "Transaction already exists", event.VerificationCancelUnexpectedMessage)
|
||||
_ = mach.SendSASVerificationCancel(ctx, otherDevice.UserID, otherDevice.DeviceID, transactionID, "Transaction already exists", event.VerificationCancelUnexpectedMessage)
|
||||
} else {
|
||||
_ = mach.SendInRoomSASVerificationCancel(inRoomID, otherDevice.UserID, transactionID, "Transaction already exists", event.VerificationCancelUnexpectedMessage)
|
||||
_ = mach.SendInRoomSASVerificationCancel(ctx, inRoomID, otherDevice.UserID, transactionID, "Transaction already exists", event.VerificationCancelUnexpectedMessage)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
mach.timeoutAfter(verState, transactionID, timeout)
|
||||
mach.timeoutAfter(ctx, verState, transactionID, timeout)
|
||||
|
||||
var err error
|
||||
if inRoomID == "" {
|
||||
err = mach.SendSASVerificationAccept(userID, content, verState.sas.GetPubkey(), sasMethods)
|
||||
err = mach.SendSASVerificationAccept(ctx, userID, content, verState.sas.GetPubkey(), sasMethods)
|
||||
} else {
|
||||
err = mach.SendInRoomSASVerificationAccept(inRoomID, userID, content, transactionID, verState.sas.GetPubkey(), sasMethods)
|
||||
err = mach.SendInRoomSASVerificationAccept(ctx, inRoomID, userID, content, transactionID, verState.sas.GetPubkey(), sasMethods)
|
||||
}
|
||||
if err != nil {
|
||||
mach.Log.Error().Msgf("Error accepting SAS verification: %v", err)
|
||||
@@ -243,9 +243,9 @@ func (mach *OlmMachine) actuallyStartVerification(userID id.UserID, content *eve
|
||||
mach.Log.Debug().Msgf("Not accepting SAS verification %v from %v of user %v", transactionID, otherDevice.DeviceID, otherDevice.UserID)
|
||||
var err error
|
||||
if inRoomID == "" {
|
||||
err = mach.SendSASVerificationCancel(otherDevice.UserID, otherDevice.DeviceID, transactionID, "Not accepted by user", event.VerificationCancelByUser)
|
||||
err = mach.SendSASVerificationCancel(ctx, otherDevice.UserID, otherDevice.DeviceID, transactionID, "Not accepted by user", event.VerificationCancelByUser)
|
||||
} else {
|
||||
err = mach.SendInRoomSASVerificationCancel(inRoomID, otherDevice.UserID, transactionID, "Not accepted by user", event.VerificationCancelByUser)
|
||||
err = mach.SendInRoomSASVerificationCancel(ctx, inRoomID, otherDevice.UserID, transactionID, "Not accepted by user", event.VerificationCancelByUser)
|
||||
}
|
||||
if err != nil {
|
||||
mach.Log.Error().Msgf("Error canceling SAS verification: %v", err)
|
||||
@@ -255,8 +255,8 @@ func (mach *OlmMachine) actuallyStartVerification(userID id.UserID, content *eve
|
||||
}
|
||||
}
|
||||
|
||||
func (mach *OlmMachine) timeoutAfter(verState *verificationState, transactionID string, timeout time.Duration) {
|
||||
timeoutCtx, timeoutCancel := context.WithTimeout(context.Background(), timeout)
|
||||
func (mach *OlmMachine) timeoutAfter(ctx context.Context, verState *verificationState, transactionID string, timeout time.Duration) {
|
||||
timeoutCtx, timeoutCancel := context.WithTimeout(ctx, timeout)
|
||||
verState.extendTimeout = timeoutCancel
|
||||
go func() {
|
||||
mapKey := verState.otherDevice.UserID.String() + ":" + transactionID
|
||||
@@ -272,7 +272,7 @@ func (mach *OlmMachine) timeoutAfter(verState *verificationState, transactionID
|
||||
if timeoutCtx.Err() == context.DeadlineExceeded {
|
||||
// if deadline exceeded cancel due to timeout
|
||||
mach.keyVerificationTransactionState.Delete(mapKey)
|
||||
_ = mach.callbackAndCancelSASVerification(verState, transactionID, "Timed out", event.VerificationCancelByTimeout)
|
||||
_ = mach.callbackAndCancelSASVerification(ctx, verState, transactionID, "Timed out", event.VerificationCancelByTimeout)
|
||||
mach.Log.Warn().Msgf("Verification transaction %v is canceled due to timing out", transactionID)
|
||||
verState.lock.Unlock()
|
||||
return
|
||||
@@ -288,9 +288,9 @@ func (mach *OlmMachine) timeoutAfter(verState *verificationState, transactionID
|
||||
|
||||
// handleVerificationAccept handles an incoming m.key.verification.accept message.
|
||||
// It continues the SAS verification process by sending the SAS key message to the other device.
|
||||
func (mach *OlmMachine) handleVerificationAccept(userID id.UserID, content *event.VerificationAcceptEventContent, transactionID string) {
|
||||
func (mach *OlmMachine) handleVerificationAccept(ctx context.Context, userID id.UserID, content *event.VerificationAcceptEventContent, transactionID string) {
|
||||
mach.Log.Debug().Msgf("Received verification accept for transaction %v", transactionID)
|
||||
verState, err := mach.getTransactionState(transactionID, userID)
|
||||
verState, err := mach.getTransactionState(ctx, transactionID, userID)
|
||||
if err != nil {
|
||||
mach.Log.Error().Msgf("Error getting transaction state: %v", err)
|
||||
return
|
||||
@@ -303,7 +303,7 @@ func (mach *OlmMachine) handleVerificationAccept(userID id.UserID, content *even
|
||||
// unexpected accept at this point
|
||||
mach.Log.Warn().Msgf("Unexpected verification accept message for transaction %v", transactionID)
|
||||
mach.keyVerificationTransactionState.Delete(userID.String() + ":" + transactionID)
|
||||
_ = mach.callbackAndCancelSASVerification(verState, transactionID, "Unexpected accept message", event.VerificationCancelUnexpectedMessage)
|
||||
_ = mach.callbackAndCancelSASVerification(ctx, verState, transactionID, "Unexpected accept message", event.VerificationCancelUnexpectedMessage)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -315,7 +315,7 @@ func (mach *OlmMachine) handleVerificationAccept(userID id.UserID, content *even
|
||||
|
||||
mach.Log.Warn().Msgf("Canceling verification transaction %v due to unknown parameter", transactionID)
|
||||
mach.keyVerificationTransactionState.Delete(userID.String() + ":" + transactionID)
|
||||
_ = mach.callbackAndCancelSASVerification(verState, transactionID, "Verification uses unknown method", event.VerificationCancelUnknownMethod)
|
||||
_ = mach.callbackAndCancelSASVerification(ctx, verState, transactionID, "Verification uses unknown method", event.VerificationCancelUnknownMethod)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -325,9 +325,9 @@ func (mach *OlmMachine) handleVerificationAccept(userID id.UserID, content *even
|
||||
verState.verificationStarted = true
|
||||
|
||||
if verState.inRoomID == "" {
|
||||
err = mach.SendSASVerificationKey(userID, verState.otherDevice.DeviceID, transactionID, string(key))
|
||||
err = mach.SendSASVerificationKey(ctx, userID, verState.otherDevice.DeviceID, transactionID, string(key))
|
||||
} else {
|
||||
err = mach.SendInRoomSASVerificationKey(verState.inRoomID, userID, transactionID, string(key))
|
||||
err = mach.SendInRoomSASVerificationKey(ctx, verState.inRoomID, userID, transactionID, string(key))
|
||||
}
|
||||
if err != nil {
|
||||
mach.Log.Error().Msgf("Error sending SAS key to other device: %v", err)
|
||||
@@ -337,9 +337,9 @@ func (mach *OlmMachine) handleVerificationAccept(userID id.UserID, content *even
|
||||
|
||||
// handleVerificationKey handles an incoming m.key.verification.key message.
|
||||
// It stores the other device's public key in order to acquire the SAS shared secret.
|
||||
func (mach *OlmMachine) handleVerificationKey(userID id.UserID, content *event.VerificationKeyEventContent, transactionID string) {
|
||||
func (mach *OlmMachine) handleVerificationKey(ctx context.Context, userID id.UserID, content *event.VerificationKeyEventContent, transactionID string) {
|
||||
mach.Log.Debug().Msgf("Got verification key for transaction %v: %v", transactionID, content.Key)
|
||||
verState, err := mach.getTransactionState(transactionID, userID)
|
||||
verState, err := mach.getTransactionState(ctx, transactionID, userID)
|
||||
if err != nil {
|
||||
mach.Log.Error().Msgf("Error getting transaction state: %v", err)
|
||||
return
|
||||
@@ -354,7 +354,7 @@ func (mach *OlmMachine) handleVerificationKey(userID id.UserID, content *event.V
|
||||
// unexpected key at this point
|
||||
mach.Log.Warn().Msgf("Unexpected verification key message for transaction %v", transactionID)
|
||||
mach.keyVerificationTransactionState.Delete(userID.String() + ":" + transactionID)
|
||||
_ = mach.callbackAndCancelSASVerification(verState, transactionID, "Unexpected key message", event.VerificationCancelUnexpectedMessage)
|
||||
_ = mach.callbackAndCancelSASVerification(ctx, verState, transactionID, "Unexpected key message", event.VerificationCancelUnexpectedMessage)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -372,7 +372,7 @@ func (mach *OlmMachine) handleVerificationKey(userID id.UserID, content *event.V
|
||||
if expectedCommitment != verState.commitment {
|
||||
mach.Log.Warn().Msgf("Canceling verification transaction %v due to commitment mismatch", transactionID)
|
||||
mach.keyVerificationTransactionState.Delete(userID.String() + ":" + transactionID)
|
||||
_ = mach.callbackAndCancelSASVerification(verState, transactionID, "Commitment mismatch", event.VerificationCancelCommitmentMismatch)
|
||||
_ = mach.callbackAndCancelSASVerification(ctx, verState, transactionID, "Commitment mismatch", event.VerificationCancelCommitmentMismatch)
|
||||
return
|
||||
}
|
||||
} else {
|
||||
@@ -380,9 +380,9 @@ func (mach *OlmMachine) handleVerificationKey(userID id.UserID, content *event.V
|
||||
key := verState.sas.GetPubkey()
|
||||
|
||||
if verState.inRoomID == "" {
|
||||
err = mach.SendSASVerificationKey(userID, device.DeviceID, transactionID, string(key))
|
||||
err = mach.SendSASVerificationKey(ctx, userID, device.DeviceID, transactionID, string(key))
|
||||
} else {
|
||||
err = mach.SendInRoomSASVerificationKey(verState.inRoomID, userID, transactionID, string(key))
|
||||
err = mach.SendInRoomSASVerificationKey(ctx, verState.inRoomID, userID, transactionID, string(key))
|
||||
}
|
||||
if err != nil {
|
||||
mach.Log.Error().Msgf("Error sending SAS key to other device: %v", err)
|
||||
@@ -419,13 +419,13 @@ func (mach *OlmMachine) handleVerificationKey(userID id.UserID, content *event.V
|
||||
mach.Log.Debug().Msgf("Generated SAS (%v): %v", sasMethod.Type(), sas)
|
||||
go func() {
|
||||
result := verState.hooks.VerifySASMatch(device, sas)
|
||||
mach.sasCompared(result, transactionID, verState)
|
||||
mach.sasCompared(ctx, result, transactionID, verState)
|
||||
}()
|
||||
}
|
||||
|
||||
// sasCompared is called asynchronously. It waits for the SAS to be compared for the verification to proceed.
|
||||
// If the SAS match, then our MAC is sent out. Otherwise the transaction is canceled.
|
||||
func (mach *OlmMachine) sasCompared(didMatch bool, transactionID string, verState *verificationState) {
|
||||
func (mach *OlmMachine) sasCompared(ctx context.Context, didMatch bool, transactionID string, verState *verificationState) {
|
||||
verState.lock.Lock()
|
||||
defer verState.lock.Unlock()
|
||||
verState.extendTimeout()
|
||||
@@ -433,9 +433,9 @@ func (mach *OlmMachine) sasCompared(didMatch bool, transactionID string, verStat
|
||||
verState.sasMatched <- true
|
||||
var err error
|
||||
if verState.inRoomID == "" {
|
||||
err = mach.SendSASVerificationMAC(verState.otherDevice.UserID, verState.otherDevice.DeviceID, transactionID, verState.sas)
|
||||
err = mach.SendSASVerificationMAC(ctx, verState.otherDevice.UserID, verState.otherDevice.DeviceID, transactionID, verState.sas)
|
||||
} else {
|
||||
err = mach.SendInRoomSASVerificationMAC(verState.inRoomID, verState.otherDevice.UserID, verState.otherDevice.DeviceID, transactionID, verState.sas)
|
||||
err = mach.SendInRoomSASVerificationMAC(ctx, verState.inRoomID, verState.otherDevice.UserID, verState.otherDevice.DeviceID, transactionID, verState.sas)
|
||||
}
|
||||
if err != nil {
|
||||
mach.Log.Error().Msgf("Error sending verification MAC to other device: %v", err)
|
||||
@@ -447,9 +447,9 @@ func (mach *OlmMachine) sasCompared(didMatch bool, transactionID string, verStat
|
||||
|
||||
// handleVerificationMAC handles an incoming m.key.verification.mac message.
|
||||
// It verifies the other device's MAC and if the MAC is valid it marks the device as trusted.
|
||||
func (mach *OlmMachine) handleVerificationMAC(userID id.UserID, content *event.VerificationMacEventContent, transactionID string) {
|
||||
func (mach *OlmMachine) handleVerificationMAC(ctx context.Context, userID id.UserID, content *event.VerificationMacEventContent, transactionID string) {
|
||||
mach.Log.Debug().Msgf("Got MAC for verification %v: %v, MAC for keys: %v", transactionID, content.Mac, content.Keys)
|
||||
verState, err := mach.getTransactionState(transactionID, userID)
|
||||
verState, err := mach.getTransactionState(ctx, transactionID, userID)
|
||||
if err != nil {
|
||||
mach.Log.Error().Msgf("Error getting transaction state: %v", err)
|
||||
return
|
||||
@@ -466,7 +466,7 @@ func (mach *OlmMachine) handleVerificationMAC(userID id.UserID, content *event.V
|
||||
if !verState.verificationStarted || !verState.keyReceived {
|
||||
// unexpected MAC at this point
|
||||
mach.Log.Warn().Msgf("Unexpected MAC message for transaction %v", transactionID)
|
||||
_ = mach.callbackAndCancelSASVerification(verState, transactionID, "Unexpected MAC message", event.VerificationCancelUnexpectedMessage)
|
||||
_ = mach.callbackAndCancelSASVerification(ctx, verState, transactionID, "Unexpected MAC message", event.VerificationCancelUnexpectedMessage)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -478,7 +478,7 @@ func (mach *OlmMachine) handleVerificationMAC(userID id.UserID, content *event.V
|
||||
|
||||
if !matched {
|
||||
mach.Log.Warn().Msgf("SAS do not match! Canceling transaction %v", transactionID)
|
||||
_ = mach.callbackAndCancelSASVerification(verState, transactionID, "SAS do not match", event.VerificationCancelSASMismatch)
|
||||
_ = mach.callbackAndCancelSASVerification(ctx, verState, transactionID, "SAS do not match", event.VerificationCancelSASMismatch)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -494,38 +494,38 @@ func (mach *OlmMachine) handleVerificationMAC(userID id.UserID, content *event.V
|
||||
mach.Log.Debug().Msgf("Expected %s keys MAC, got %s", expectedKeysMAC, content.Keys)
|
||||
if content.Keys != expectedKeysMAC {
|
||||
mach.Log.Warn().Msgf("Canceling verification transaction %v due to mismatched keys MAC", transactionID)
|
||||
_ = mach.callbackAndCancelSASVerification(verState, transactionID, "Mismatched keys MACs", event.VerificationCancelKeyMismatch)
|
||||
_ = mach.callbackAndCancelSASVerification(ctx, verState, transactionID, "Mismatched keys MACs", event.VerificationCancelKeyMismatch)
|
||||
return
|
||||
}
|
||||
|
||||
mach.Log.Debug().Msgf("Expected %s PK MAC, got %s", expectedPKMAC, content.Mac[keyID])
|
||||
if content.Mac[keyID] != expectedPKMAC {
|
||||
mach.Log.Warn().Msgf("Canceling verification transaction %v due to mismatched PK MAC", transactionID)
|
||||
_ = mach.callbackAndCancelSASVerification(verState, transactionID, "Mismatched PK MACs", event.VerificationCancelKeyMismatch)
|
||||
_ = mach.callbackAndCancelSASVerification(ctx, verState, transactionID, "Mismatched PK MACs", event.VerificationCancelKeyMismatch)
|
||||
return
|
||||
}
|
||||
|
||||
// we can finally trust this device
|
||||
device.Trust = id.TrustStateVerified
|
||||
err = mach.CryptoStore.PutDevice(device.UserID, device)
|
||||
err = mach.CryptoStore.PutDevice(ctx, device.UserID, device)
|
||||
if err != nil {
|
||||
mach.Log.Warn().Msgf("Failed to put device after verifying: %v", err)
|
||||
}
|
||||
|
||||
if mach.CrossSigningKeys != nil {
|
||||
if device.UserID == mach.Client.UserID {
|
||||
err := mach.SignOwnDevice(device)
|
||||
err := mach.SignOwnDevice(ctx, device)
|
||||
if err != nil {
|
||||
mach.Log.Error().Msgf("Failed to cross-sign own device %s: %v", device.DeviceID, err)
|
||||
} else {
|
||||
mach.Log.Debug().Msgf("Cross-signed own device %v after SAS verification", device.DeviceID)
|
||||
}
|
||||
} else {
|
||||
masterKey, err := mach.fetchMasterKey(device, content, verState, transactionID)
|
||||
masterKey, err := mach.fetchMasterKey(ctx, device, content, verState, transactionID)
|
||||
if err != nil {
|
||||
mach.Log.Warn().Msgf("Failed to fetch %s's master key: %v", device.UserID, err)
|
||||
} else {
|
||||
if err := mach.SignUser(device.UserID, masterKey); err != nil {
|
||||
if err := mach.SignUser(ctx, device.UserID, masterKey); err != nil {
|
||||
mach.Log.Error().Msgf("Failed to cross-sign master key of %s: %v", device.UserID, err)
|
||||
} else {
|
||||
mach.Log.Debug().Msgf("Cross-signed master key of %v after SAS verification", device.UserID)
|
||||
@@ -559,9 +559,9 @@ func (mach *OlmMachine) handleVerificationCancel(userID id.UserID, content *even
|
||||
}
|
||||
|
||||
// handleVerificationRequest handles an incoming m.key.verification.request message.
|
||||
func (mach *OlmMachine) handleVerificationRequest(userID id.UserID, content *event.VerificationRequestEventContent, transactionID string, inRoomID id.RoomID) {
|
||||
func (mach *OlmMachine) handleVerificationRequest(ctx context.Context, userID id.UserID, content *event.VerificationRequestEventContent, transactionID string, inRoomID id.RoomID) {
|
||||
mach.Log.Debug().Msgf("Received verification request from %v", content.FromDevice)
|
||||
otherDevice, err := mach.GetOrFetchDevice(context.TODO(), userID, content.FromDevice)
|
||||
otherDevice, err := mach.GetOrFetchDevice(ctx, userID, content.FromDevice)
|
||||
if err != nil {
|
||||
mach.Log.Error().Msgf("Could not find device %v of user %v", content.FromDevice, userID)
|
||||
return
|
||||
@@ -569,9 +569,9 @@ func (mach *OlmMachine) handleVerificationRequest(userID id.UserID, content *eve
|
||||
if !content.SupportsVerificationMethod(event.VerificationMethodSAS) {
|
||||
mach.Log.Warn().Msgf("Canceling verification transaction %v as SAS is not supported", transactionID)
|
||||
if inRoomID == "" {
|
||||
_ = mach.SendSASVerificationCancel(otherDevice.UserID, otherDevice.DeviceID, transactionID, "Only SAS method is supported", event.VerificationCancelUnknownMethod)
|
||||
_ = mach.SendSASVerificationCancel(ctx, otherDevice.UserID, otherDevice.DeviceID, transactionID, "Only SAS method is supported", event.VerificationCancelUnknownMethod)
|
||||
} else {
|
||||
_ = mach.SendInRoomSASVerificationCancel(inRoomID, otherDevice.UserID, transactionID, "Only SAS method is supported", event.VerificationCancelUnknownMethod)
|
||||
_ = mach.SendInRoomSASVerificationCancel(ctx, inRoomID, otherDevice.UserID, transactionID, "Only SAS method is supported", event.VerificationCancelUnknownMethod)
|
||||
}
|
||||
return
|
||||
}
|
||||
@@ -579,14 +579,14 @@ func (mach *OlmMachine) handleVerificationRequest(userID id.UserID, content *eve
|
||||
if resp == AcceptRequest {
|
||||
mach.Log.Debug().Msgf("Accepting SAS verification %v from %v of user %v", transactionID, otherDevice.DeviceID, otherDevice.UserID)
|
||||
if inRoomID == "" {
|
||||
_, err = mach.NewSASVerificationWith(otherDevice, hooks, transactionID, mach.DefaultSASTimeout)
|
||||
_, err = mach.NewSASVerificationWith(ctx, otherDevice, hooks, transactionID, mach.DefaultSASTimeout)
|
||||
} else {
|
||||
if err := mach.SendInRoomSASVerificationReady(inRoomID, transactionID); err != nil {
|
||||
if err := mach.SendInRoomSASVerificationReady(ctx, inRoomID, transactionID); err != nil {
|
||||
mach.Log.Error().Msgf("Error sending in-room SAS verification ready: %v", err)
|
||||
}
|
||||
if mach.Client.UserID < otherDevice.UserID {
|
||||
// up to us to send the start message
|
||||
_, err = mach.newInRoomSASVerificationWithInner(inRoomID, otherDevice, hooks, transactionID, mach.DefaultSASTimeout)
|
||||
_, err = mach.newInRoomSASVerificationWithInner(ctx, inRoomID, otherDevice, hooks, transactionID, mach.DefaultSASTimeout)
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
@@ -595,9 +595,9 @@ func (mach *OlmMachine) handleVerificationRequest(userID id.UserID, content *eve
|
||||
} else if resp == RejectRequest {
|
||||
mach.Log.Debug().Msgf("Rejecting SAS verification %v from %v of user %v", transactionID, otherDevice.DeviceID, otherDevice.UserID)
|
||||
if inRoomID == "" {
|
||||
_ = mach.SendSASVerificationCancel(otherDevice.UserID, otherDevice.DeviceID, transactionID, "Not accepted by user", event.VerificationCancelByUser)
|
||||
_ = mach.SendSASVerificationCancel(ctx, otherDevice.UserID, otherDevice.DeviceID, transactionID, "Not accepted by user", event.VerificationCancelByUser)
|
||||
} else {
|
||||
_ = mach.SendInRoomSASVerificationCancel(inRoomID, otherDevice.UserID, transactionID, "Not accepted by user", event.VerificationCancelByUser)
|
||||
_ = mach.SendInRoomSASVerificationCancel(ctx, inRoomID, otherDevice.UserID, transactionID, "Not accepted by user", event.VerificationCancelByUser)
|
||||
}
|
||||
} else {
|
||||
mach.Log.Debug().Msgf("Ignoring SAS verification %v from %v of user %v", transactionID, otherDevice.DeviceID, otherDevice.UserID)
|
||||
@@ -606,14 +606,14 @@ func (mach *OlmMachine) handleVerificationRequest(userID id.UserID, content *eve
|
||||
|
||||
// NewSimpleSASVerificationWith starts the SAS verification process with another device with a default timeout,
|
||||
// a generated transaction ID and support for both emoji and decimal SAS methods.
|
||||
func (mach *OlmMachine) NewSimpleSASVerificationWith(device *id.Device, hooks VerificationHooks) (string, error) {
|
||||
return mach.NewSASVerificationWith(device, hooks, "", mach.DefaultSASTimeout)
|
||||
func (mach *OlmMachine) NewSimpleSASVerificationWith(ctx context.Context, device *id.Device, hooks VerificationHooks) (string, error) {
|
||||
return mach.NewSASVerificationWith(ctx, device, hooks, "", mach.DefaultSASTimeout)
|
||||
}
|
||||
|
||||
// NewSASVerificationWith starts the SAS verification process with another device.
|
||||
// If the other device accepts the verification transaction, the methods in `hooks` will be used to verify the SAS match and to complete the transaction..
|
||||
// If the transaction ID is empty, a new one is generated.
|
||||
func (mach *OlmMachine) NewSASVerificationWith(device *id.Device, hooks VerificationHooks, transactionID string, timeout time.Duration) (string, error) {
|
||||
func (mach *OlmMachine) NewSASVerificationWith(ctx context.Context, device *id.Device, hooks VerificationHooks, transactionID string, timeout time.Duration) (string, error) {
|
||||
if transactionID == "" {
|
||||
transactionID = strconv.Itoa(rand.Int())
|
||||
}
|
||||
@@ -631,7 +631,7 @@ func (mach *OlmMachine) NewSASVerificationWith(device *id.Device, hooks Verifica
|
||||
verState.lock.Lock()
|
||||
defer verState.lock.Unlock()
|
||||
|
||||
startEvent, err := mach.SendSASVerificationStart(device.UserID, device.DeviceID, transactionID, hooks.VerificationMethods())
|
||||
startEvent, err := mach.SendSASVerificationStart(ctx, device.UserID, device.DeviceID, transactionID, hooks.VerificationMethods())
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
@@ -651,13 +651,13 @@ func (mach *OlmMachine) NewSASVerificationWith(device *id.Device, hooks Verifica
|
||||
return "", ErrTransactionAlreadyExists
|
||||
}
|
||||
|
||||
mach.timeoutAfter(verState, transactionID, timeout)
|
||||
mach.timeoutAfter(ctx, verState, transactionID, timeout)
|
||||
|
||||
return transactionID, nil
|
||||
}
|
||||
|
||||
// CancelSASVerification is used by the user to cancel a SAS verification process with the given reason.
|
||||
func (mach *OlmMachine) CancelSASVerification(userID id.UserID, transactionID, reason string) error {
|
||||
func (mach *OlmMachine) CancelSASVerification(ctx context.Context, userID id.UserID, transactionID, reason string) error {
|
||||
mapKey := userID.String() + ":" + transactionID
|
||||
verStateInterface, ok := mach.keyVerificationTransactionState.Load(mapKey)
|
||||
if !ok {
|
||||
@@ -668,21 +668,21 @@ func (mach *OlmMachine) CancelSASVerification(userID id.UserID, transactionID, r
|
||||
defer verState.lock.Unlock()
|
||||
mach.Log.Trace().Msgf("User canceled verification transaction %v with reason: %v", transactionID, reason)
|
||||
mach.keyVerificationTransactionState.Delete(mapKey)
|
||||
return mach.callbackAndCancelSASVerification(verState, transactionID, reason, event.VerificationCancelByUser)
|
||||
return mach.callbackAndCancelSASVerification(ctx, verState, transactionID, reason, event.VerificationCancelByUser)
|
||||
}
|
||||
|
||||
// SendSASVerificationCancel is used to manually send a SAS cancel message process with the given reason and cancellation code.
|
||||
func (mach *OlmMachine) SendSASVerificationCancel(userID id.UserID, deviceID id.DeviceID, transactionID string, reason string, code event.VerificationCancelCode) error {
|
||||
func (mach *OlmMachine) SendSASVerificationCancel(ctx context.Context, userID id.UserID, deviceID id.DeviceID, transactionID string, reason string, code event.VerificationCancelCode) error {
|
||||
content := &event.VerificationCancelEventContent{
|
||||
TransactionID: transactionID,
|
||||
Reason: reason,
|
||||
Code: code,
|
||||
}
|
||||
return mach.sendToOneDevice(userID, deviceID, event.ToDeviceVerificationCancel, content)
|
||||
return mach.sendToOneDevice(ctx, userID, deviceID, event.ToDeviceVerificationCancel, content)
|
||||
}
|
||||
|
||||
// SendSASVerificationStart is used to manually send the SAS verification start message to another device.
|
||||
func (mach *OlmMachine) SendSASVerificationStart(toUserID id.UserID, toDeviceID id.DeviceID, transactionID string, methods []VerificationMethod) (*event.VerificationStartEventContent, error) {
|
||||
func (mach *OlmMachine) SendSASVerificationStart(ctx context.Context, toUserID id.UserID, toDeviceID id.DeviceID, transactionID string, methods []VerificationMethod) (*event.VerificationStartEventContent, error) {
|
||||
sasMethods := make([]event.SASMethod, len(methods))
|
||||
for i, method := range methods {
|
||||
sasMethods[i] = method.Type()
|
||||
@@ -696,14 +696,14 @@ func (mach *OlmMachine) SendSASVerificationStart(toUserID id.UserID, toDeviceID
|
||||
MessageAuthenticationCodes: []event.MACMethod{event.HKDFHMACSHA256},
|
||||
ShortAuthenticationString: sasMethods,
|
||||
}
|
||||
return content, mach.sendToOneDevice(toUserID, toDeviceID, event.ToDeviceVerificationStart, content)
|
||||
return content, mach.sendToOneDevice(ctx, toUserID, toDeviceID, event.ToDeviceVerificationStart, content)
|
||||
}
|
||||
|
||||
// SendSASVerificationAccept is used to manually send an accept for a SAS verification process from a received m.key.verification.start event.
|
||||
func (mach *OlmMachine) SendSASVerificationAccept(fromUser id.UserID, startEvent *event.VerificationStartEventContent, publicKey []byte, methods []VerificationMethod) error {
|
||||
func (mach *OlmMachine) SendSASVerificationAccept(ctx context.Context, fromUser id.UserID, startEvent *event.VerificationStartEventContent, publicKey []byte, methods []VerificationMethod) error {
|
||||
if startEvent.Method != event.VerificationMethodSAS {
|
||||
reason := "Unknown verification method: " + string(startEvent.Method)
|
||||
if err := mach.SendSASVerificationCancel(fromUser, startEvent.FromDevice, startEvent.TransactionID, reason, event.VerificationCancelUnknownMethod); err != nil {
|
||||
if err := mach.SendSASVerificationCancel(ctx, fromUser, startEvent.FromDevice, startEvent.TransactionID, reason, event.VerificationCancelUnknownMethod); err != nil {
|
||||
return err
|
||||
}
|
||||
return ErrUnknownVerificationMethod
|
||||
@@ -730,25 +730,25 @@ func (mach *OlmMachine) SendSASVerificationAccept(fromUser id.UserID, startEvent
|
||||
ShortAuthenticationString: sasMethods,
|
||||
Commitment: hash,
|
||||
}
|
||||
return mach.sendToOneDevice(fromUser, startEvent.FromDevice, event.ToDeviceVerificationAccept, content)
|
||||
return mach.sendToOneDevice(ctx, fromUser, startEvent.FromDevice, event.ToDeviceVerificationAccept, content)
|
||||
}
|
||||
|
||||
func (mach *OlmMachine) callbackAndCancelSASVerification(verState *verificationState, transactionID, reason string, code event.VerificationCancelCode) error {
|
||||
func (mach *OlmMachine) callbackAndCancelSASVerification(ctx context.Context, verState *verificationState, transactionID, reason string, code event.VerificationCancelCode) error {
|
||||
go verState.hooks.OnCancel(true, reason, code)
|
||||
return mach.SendSASVerificationCancel(verState.otherDevice.UserID, verState.otherDevice.DeviceID, transactionID, reason, code)
|
||||
return mach.SendSASVerificationCancel(ctx, verState.otherDevice.UserID, verState.otherDevice.DeviceID, transactionID, reason, code)
|
||||
}
|
||||
|
||||
// SendSASVerificationKey sends the ephemeral public key for a device to the partner device.
|
||||
func (mach *OlmMachine) SendSASVerificationKey(userID id.UserID, deviceID id.DeviceID, transactionID string, key string) error {
|
||||
func (mach *OlmMachine) SendSASVerificationKey(ctx context.Context, userID id.UserID, deviceID id.DeviceID, transactionID string, key string) error {
|
||||
content := &event.VerificationKeyEventContent{
|
||||
TransactionID: transactionID,
|
||||
Key: key,
|
||||
}
|
||||
return mach.sendToOneDevice(userID, deviceID, event.ToDeviceVerificationKey, content)
|
||||
return mach.sendToOneDevice(ctx, userID, deviceID, event.ToDeviceVerificationKey, content)
|
||||
}
|
||||
|
||||
// SendSASVerificationMAC is use the MAC of a device's key to the partner device.
|
||||
func (mach *OlmMachine) SendSASVerificationMAC(userID id.UserID, deviceID id.DeviceID, transactionID string, sas *olm.SAS) error {
|
||||
func (mach *OlmMachine) SendSASVerificationMAC(ctx context.Context, userID id.UserID, deviceID id.DeviceID, transactionID string, sas *olm.SAS) error {
|
||||
keyID := id.NewKeyID(id.KeyAlgorithmEd25519, mach.Client.DeviceID.String())
|
||||
|
||||
signingKey := mach.account.SigningKey()
|
||||
@@ -784,7 +784,7 @@ func (mach *OlmMachine) SendSASVerificationMAC(userID id.UserID, deviceID id.Dev
|
||||
Mac: macMap,
|
||||
}
|
||||
|
||||
return mach.sendToOneDevice(userID, deviceID, event.ToDeviceVerificationMAC, content)
|
||||
return mach.sendToOneDevice(ctx, userID, deviceID, event.ToDeviceVerificationMAC, content)
|
||||
}
|
||||
|
||||
func commonSASMethods(hooks VerificationHooks, otherDeviceMethods []event.SASMethod) []VerificationMethod {
|
||||
|
||||
77
vendor/maunium.net/go/mautrix/crypto/verification_in_room.go
generated
vendored
77
vendor/maunium.net/go/mautrix/crypto/verification_in_room.go
generated
vendored
@@ -38,6 +38,7 @@ func (mach *OlmMachine) ProcessInRoomVerification(evt *event.Event) error {
|
||||
return ErrNoRelatesTo
|
||||
}
|
||||
|
||||
ctx := context.TODO()
|
||||
switch content := evt.Content.Parsed.(type) {
|
||||
case *event.MessageEventContent:
|
||||
if content.MsgType == event.MsgVerificationRequest {
|
||||
@@ -54,18 +55,18 @@ func (mach *OlmMachine) ProcessInRoomVerification(evt *event.Event) error {
|
||||
Timestamp: evt.Timestamp,
|
||||
TransactionID: evt.ID.String(),
|
||||
}
|
||||
mach.handleVerificationRequest(evt.Sender, newContent, evt.ID.String(), evt.RoomID)
|
||||
mach.handleVerificationRequest(ctx, evt.Sender, newContent, evt.ID.String(), evt.RoomID)
|
||||
}
|
||||
case *event.VerificationStartEventContent:
|
||||
mach.handleVerificationStart(evt.Sender, content, content.RelatesTo.EventID.String(), 10*time.Minute, evt.RoomID)
|
||||
mach.handleVerificationStart(ctx, evt.Sender, content, content.RelatesTo.EventID.String(), 10*time.Minute, evt.RoomID)
|
||||
case *event.VerificationReadyEventContent:
|
||||
mach.handleInRoomVerificationReady(evt.Sender, evt.RoomID, content, content.RelatesTo.EventID.String())
|
||||
mach.handleInRoomVerificationReady(ctx, evt.Sender, evt.RoomID, content, content.RelatesTo.EventID.String())
|
||||
case *event.VerificationAcceptEventContent:
|
||||
mach.handleVerificationAccept(evt.Sender, content, content.RelatesTo.EventID.String())
|
||||
mach.handleVerificationAccept(ctx, evt.Sender, content, content.RelatesTo.EventID.String())
|
||||
case *event.VerificationKeyEventContent:
|
||||
mach.handleVerificationKey(evt.Sender, content, content.RelatesTo.EventID.String())
|
||||
mach.handleVerificationKey(ctx, evt.Sender, content, content.RelatesTo.EventID.String())
|
||||
case *event.VerificationMacEventContent:
|
||||
mach.handleVerificationMAC(evt.Sender, content, content.RelatesTo.EventID.String())
|
||||
mach.handleVerificationMAC(ctx, evt.Sender, content, content.RelatesTo.EventID.String())
|
||||
case *event.VerificationCancelEventContent:
|
||||
mach.handleVerificationCancel(evt.Sender, content, content.RelatesTo.EventID.String())
|
||||
}
|
||||
@@ -73,7 +74,7 @@ func (mach *OlmMachine) ProcessInRoomVerification(evt *event.Event) error {
|
||||
}
|
||||
|
||||
// SendInRoomSASVerificationCancel is used to manually send an in-room SAS cancel message process with the given reason and cancellation code.
|
||||
func (mach *OlmMachine) SendInRoomSASVerificationCancel(roomID id.RoomID, userID id.UserID, transactionID string, reason string, code event.VerificationCancelCode) error {
|
||||
func (mach *OlmMachine) SendInRoomSASVerificationCancel(ctx context.Context, roomID id.RoomID, userID id.UserID, transactionID string, reason string, code event.VerificationCancelCode) error {
|
||||
content := &event.VerificationCancelEventContent{
|
||||
RelatesTo: &event.RelatesTo{Type: event.RelReference, EventID: id.EventID(transactionID)},
|
||||
Reason: reason,
|
||||
@@ -81,16 +82,16 @@ func (mach *OlmMachine) SendInRoomSASVerificationCancel(roomID id.RoomID, userID
|
||||
To: userID,
|
||||
}
|
||||
|
||||
encrypted, err := mach.EncryptMegolmEvent(context.TODO(), roomID, event.InRoomVerificationCancel, content)
|
||||
encrypted, err := mach.EncryptMegolmEvent(ctx, roomID, event.InRoomVerificationCancel, content)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = mach.Client.SendMessageEvent(roomID, event.EventEncrypted, encrypted)
|
||||
_, err = mach.Client.SendMessageEvent(ctx, roomID, event.EventEncrypted, encrypted)
|
||||
return err
|
||||
}
|
||||
|
||||
// SendInRoomSASVerificationRequest is used to manually send an in-room SAS verification request message to another user.
|
||||
func (mach *OlmMachine) SendInRoomSASVerificationRequest(roomID id.RoomID, toUserID id.UserID, methods []VerificationMethod) (string, error) {
|
||||
func (mach *OlmMachine) SendInRoomSASVerificationRequest(ctx context.Context, roomID id.RoomID, toUserID id.UserID, methods []VerificationMethod) (string, error) {
|
||||
content := &event.MessageEventContent{
|
||||
MsgType: event.MsgVerificationRequest,
|
||||
FromDevice: mach.Client.DeviceID,
|
||||
@@ -98,11 +99,11 @@ func (mach *OlmMachine) SendInRoomSASVerificationRequest(roomID id.RoomID, toUse
|
||||
To: toUserID,
|
||||
}
|
||||
|
||||
encrypted, err := mach.EncryptMegolmEvent(context.TODO(), roomID, event.EventMessage, content)
|
||||
encrypted, err := mach.EncryptMegolmEvent(ctx, roomID, event.EventMessage, content)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
resp, err := mach.Client.SendMessageEvent(roomID, event.EventEncrypted, encrypted)
|
||||
resp, err := mach.Client.SendMessageEvent(ctx, roomID, event.EventEncrypted, encrypted)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
@@ -110,23 +111,23 @@ func (mach *OlmMachine) SendInRoomSASVerificationRequest(roomID id.RoomID, toUse
|
||||
}
|
||||
|
||||
// SendInRoomSASVerificationReady is used to manually send an in-room SAS verification ready message to another user.
|
||||
func (mach *OlmMachine) SendInRoomSASVerificationReady(roomID id.RoomID, transactionID string) error {
|
||||
func (mach *OlmMachine) SendInRoomSASVerificationReady(ctx context.Context, roomID id.RoomID, transactionID string) error {
|
||||
content := &event.VerificationReadyEventContent{
|
||||
FromDevice: mach.Client.DeviceID,
|
||||
Methods: []event.VerificationMethod{event.VerificationMethodSAS},
|
||||
RelatesTo: &event.RelatesTo{Type: event.RelReference, EventID: id.EventID(transactionID)},
|
||||
}
|
||||
|
||||
encrypted, err := mach.EncryptMegolmEvent(context.TODO(), roomID, event.InRoomVerificationReady, content)
|
||||
encrypted, err := mach.EncryptMegolmEvent(ctx, roomID, event.InRoomVerificationReady, content)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = mach.Client.SendMessageEvent(roomID, event.EventEncrypted, encrypted)
|
||||
_, err = mach.Client.SendMessageEvent(ctx, roomID, event.EventEncrypted, encrypted)
|
||||
return err
|
||||
}
|
||||
|
||||
// SendInRoomSASVerificationStart is used to manually send the in-room SAS verification start message to another user.
|
||||
func (mach *OlmMachine) SendInRoomSASVerificationStart(roomID id.RoomID, toUserID id.UserID, transactionID string, methods []VerificationMethod) (*event.VerificationStartEventContent, error) {
|
||||
func (mach *OlmMachine) SendInRoomSASVerificationStart(ctx context.Context, roomID id.RoomID, toUserID id.UserID, transactionID string, methods []VerificationMethod) (*event.VerificationStartEventContent, error) {
|
||||
sasMethods := make([]event.SASMethod, len(methods))
|
||||
for i, method := range methods {
|
||||
sasMethods[i] = method.Type()
|
||||
@@ -142,19 +143,19 @@ func (mach *OlmMachine) SendInRoomSASVerificationStart(roomID id.RoomID, toUserI
|
||||
To: toUserID,
|
||||
}
|
||||
|
||||
encrypted, err := mach.EncryptMegolmEvent(context.TODO(), roomID, event.InRoomVerificationStart, content)
|
||||
encrypted, err := mach.EncryptMegolmEvent(ctx, roomID, event.InRoomVerificationStart, content)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
_, err = mach.Client.SendMessageEvent(roomID, event.EventEncrypted, encrypted)
|
||||
_, err = mach.Client.SendMessageEvent(ctx, roomID, event.EventEncrypted, encrypted)
|
||||
return content, err
|
||||
}
|
||||
|
||||
// SendInRoomSASVerificationAccept is used to manually send an accept for an in-room SAS verification process from a received m.key.verification.start event.
|
||||
func (mach *OlmMachine) SendInRoomSASVerificationAccept(roomID id.RoomID, fromUser id.UserID, startEvent *event.VerificationStartEventContent, transactionID string, publicKey []byte, methods []VerificationMethod) error {
|
||||
func (mach *OlmMachine) SendInRoomSASVerificationAccept(ctx context.Context, roomID id.RoomID, fromUser id.UserID, startEvent *event.VerificationStartEventContent, transactionID string, publicKey []byte, methods []VerificationMethod) error {
|
||||
if startEvent.Method != event.VerificationMethodSAS {
|
||||
reason := "Unknown verification method: " + string(startEvent.Method)
|
||||
if err := mach.SendInRoomSASVerificationCancel(roomID, fromUser, transactionID, reason, event.VerificationCancelUnknownMethod); err != nil {
|
||||
if err := mach.SendInRoomSASVerificationCancel(ctx, roomID, fromUser, transactionID, reason, event.VerificationCancelUnknownMethod); err != nil {
|
||||
return err
|
||||
}
|
||||
return ErrUnknownVerificationMethod
|
||||
@@ -183,32 +184,32 @@ func (mach *OlmMachine) SendInRoomSASVerificationAccept(roomID id.RoomID, fromUs
|
||||
To: fromUser,
|
||||
}
|
||||
|
||||
encrypted, err := mach.EncryptMegolmEvent(context.TODO(), roomID, event.InRoomVerificationAccept, content)
|
||||
encrypted, err := mach.EncryptMegolmEvent(ctx, roomID, event.InRoomVerificationAccept, content)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = mach.Client.SendMessageEvent(roomID, event.EventEncrypted, encrypted)
|
||||
_, err = mach.Client.SendMessageEvent(ctx, roomID, event.EventEncrypted, encrypted)
|
||||
return err
|
||||
}
|
||||
|
||||
// SendInRoomSASVerificationKey sends the ephemeral public key for a device to the partner device for an in-room verification.
|
||||
func (mach *OlmMachine) SendInRoomSASVerificationKey(roomID id.RoomID, userID id.UserID, transactionID string, key string) error {
|
||||
func (mach *OlmMachine) SendInRoomSASVerificationKey(ctx context.Context, roomID id.RoomID, userID id.UserID, transactionID string, key string) error {
|
||||
content := &event.VerificationKeyEventContent{
|
||||
RelatesTo: &event.RelatesTo{Type: event.RelReference, EventID: id.EventID(transactionID)},
|
||||
Key: key,
|
||||
To: userID,
|
||||
}
|
||||
|
||||
encrypted, err := mach.EncryptMegolmEvent(context.TODO(), roomID, event.InRoomVerificationKey, content)
|
||||
encrypted, err := mach.EncryptMegolmEvent(ctx, roomID, event.InRoomVerificationKey, content)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = mach.Client.SendMessageEvent(roomID, event.EventEncrypted, encrypted)
|
||||
_, err = mach.Client.SendMessageEvent(ctx, roomID, event.EventEncrypted, encrypted)
|
||||
return err
|
||||
}
|
||||
|
||||
// SendInRoomSASVerificationMAC sends the MAC of a device's key to the partner device for an in-room verification.
|
||||
func (mach *OlmMachine) SendInRoomSASVerificationMAC(roomID id.RoomID, userID id.UserID, deviceID id.DeviceID, transactionID string, sas *olm.SAS) error {
|
||||
func (mach *OlmMachine) SendInRoomSASVerificationMAC(ctx context.Context, roomID id.RoomID, userID id.UserID, deviceID id.DeviceID, transactionID string, sas *olm.SAS) error {
|
||||
keyID := id.NewKeyID(id.KeyAlgorithmEd25519, mach.Client.DeviceID.String())
|
||||
|
||||
signingKey := mach.account.SigningKey()
|
||||
@@ -245,28 +246,28 @@ func (mach *OlmMachine) SendInRoomSASVerificationMAC(roomID id.RoomID, userID id
|
||||
To: userID,
|
||||
}
|
||||
|
||||
encrypted, err := mach.EncryptMegolmEvent(context.TODO(), roomID, event.InRoomVerificationMAC, content)
|
||||
encrypted, err := mach.EncryptMegolmEvent(ctx, roomID, event.InRoomVerificationMAC, content)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = mach.Client.SendMessageEvent(roomID, event.EventEncrypted, encrypted)
|
||||
_, err = mach.Client.SendMessageEvent(ctx, roomID, event.EventEncrypted, encrypted)
|
||||
return err
|
||||
}
|
||||
|
||||
// NewInRoomSASVerificationWith starts the in-room SAS verification process with another user in the given room.
|
||||
// It returns the generated transaction ID.
|
||||
func (mach *OlmMachine) NewInRoomSASVerificationWith(inRoomID id.RoomID, userID id.UserID, hooks VerificationHooks, timeout time.Duration) (string, error) {
|
||||
return mach.newInRoomSASVerificationWithInner(inRoomID, &id.Device{UserID: userID}, hooks, "", timeout)
|
||||
func (mach *OlmMachine) NewInRoomSASVerificationWith(ctx context.Context, inRoomID id.RoomID, userID id.UserID, hooks VerificationHooks, timeout time.Duration) (string, error) {
|
||||
return mach.newInRoomSASVerificationWithInner(ctx, inRoomID, &id.Device{UserID: userID}, hooks, "", timeout)
|
||||
}
|
||||
|
||||
func (mach *OlmMachine) newInRoomSASVerificationWithInner(inRoomID id.RoomID, device *id.Device, hooks VerificationHooks, transactionID string, timeout time.Duration) (string, error) {
|
||||
func (mach *OlmMachine) newInRoomSASVerificationWithInner(ctx context.Context, inRoomID id.RoomID, device *id.Device, hooks VerificationHooks, transactionID string, timeout time.Duration) (string, error) {
|
||||
mach.Log.Debug().Msgf("Starting new in-room verification transaction user %v", device.UserID)
|
||||
|
||||
request := transactionID == ""
|
||||
if request {
|
||||
var err error
|
||||
// get new transaction ID from the request message event ID
|
||||
transactionID, err = mach.SendInRoomSASVerificationRequest(inRoomID, device.UserID, hooks.VerificationMethods())
|
||||
transactionID, err = mach.SendInRoomSASVerificationRequest(ctx, inRoomID, device.UserID, hooks.VerificationMethods())
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
@@ -286,7 +287,7 @@ func (mach *OlmMachine) newInRoomSASVerificationWithInner(inRoomID id.RoomID, de
|
||||
|
||||
if !request {
|
||||
// start in-room verification
|
||||
startEvent, err := mach.SendInRoomSASVerificationStart(inRoomID, device.UserID, transactionID, hooks.VerificationMethods())
|
||||
startEvent, err := mach.SendInRoomSASVerificationStart(ctx, inRoomID, device.UserID, transactionID, hooks.VerificationMethods())
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
@@ -305,19 +306,19 @@ func (mach *OlmMachine) newInRoomSASVerificationWithInner(inRoomID id.RoomID, de
|
||||
|
||||
mach.keyVerificationTransactionState.Store(device.UserID.String()+":"+transactionID, verState)
|
||||
|
||||
mach.timeoutAfter(verState, transactionID, timeout)
|
||||
mach.timeoutAfter(ctx, verState, transactionID, timeout)
|
||||
|
||||
return transactionID, nil
|
||||
}
|
||||
|
||||
func (mach *OlmMachine) handleInRoomVerificationReady(userID id.UserID, roomID id.RoomID, content *event.VerificationReadyEventContent, transactionID string) {
|
||||
device, err := mach.GetOrFetchDevice(context.TODO(), userID, content.FromDevice)
|
||||
func (mach *OlmMachine) handleInRoomVerificationReady(ctx context.Context, userID id.UserID, roomID id.RoomID, content *event.VerificationReadyEventContent, transactionID string) {
|
||||
device, err := mach.GetOrFetchDevice(ctx, userID, content.FromDevice)
|
||||
if err != nil {
|
||||
mach.Log.Error().Msgf("Error fetching device %v of user %v: %v", content.FromDevice, userID, err)
|
||||
return
|
||||
}
|
||||
|
||||
verState, err := mach.getTransactionState(transactionID, userID)
|
||||
verState, err := mach.getTransactionState(ctx, transactionID, userID)
|
||||
if err != nil {
|
||||
mach.Log.Error().Msgf("Error getting transaction state: %v", err)
|
||||
return
|
||||
@@ -327,7 +328,7 @@ func (mach *OlmMachine) handleInRoomVerificationReady(userID id.UserID, roomID i
|
||||
if mach.Client.UserID < userID {
|
||||
// up to us to send the start message
|
||||
verState.lock.Lock()
|
||||
mach.newInRoomSASVerificationWithInner(roomID, device, verState.hooks, transactionID, 10*time.Minute)
|
||||
mach.newInRoomSASVerificationWithInner(ctx, roomID, device, verState.hooks, transactionID, 10*time.Minute)
|
||||
verState.lock.Unlock()
|
||||
}
|
||||
}
|
||||
|
||||
2
vendor/maunium.net/go/mautrix/event/events.go
generated
vendored
2
vendor/maunium.net/go/mautrix/event/events.go
generated
vendored
@@ -105,6 +105,8 @@ func (evt *Event) MarshalJSON() ([]byte, error) {
|
||||
}
|
||||
|
||||
type MautrixInfo struct {
|
||||
EventSource Source
|
||||
|
||||
TrustState id.TrustState
|
||||
ForwardedKeys bool
|
||||
WasEncrypted bool
|
||||
|
||||
72
vendor/maunium.net/go/mautrix/event/eventsource.go
generated
vendored
Normal file
72
vendor/maunium.net/go/mautrix/event/eventsource.go
generated
vendored
Normal file
@@ -0,0 +1,72 @@
|
||||
// Copyright (c) 2024 Tulir Asokan
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||
// License, v. 2.0. If a copy of the MPL was not distributed with this
|
||||
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
|
||||
package event
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
)
|
||||
|
||||
// Source represents the part of the sync response that an event came from.
|
||||
type Source int
|
||||
|
||||
const (
|
||||
SourcePresence Source = 1 << iota
|
||||
SourceJoin
|
||||
SourceInvite
|
||||
SourceLeave
|
||||
SourceAccountData
|
||||
SourceTimeline
|
||||
SourceState
|
||||
SourceEphemeral
|
||||
SourceToDevice
|
||||
SourceDecrypted
|
||||
)
|
||||
|
||||
const primaryTypes = SourcePresence | SourceAccountData | SourceToDevice | SourceTimeline | SourceState
|
||||
const roomSections = SourceJoin | SourceInvite | SourceLeave
|
||||
const roomableTypes = SourceAccountData | SourceTimeline | SourceState
|
||||
const encryptableTypes = roomableTypes | SourceToDevice
|
||||
|
||||
func (es Source) String() string {
|
||||
var typeName string
|
||||
switch es & primaryTypes {
|
||||
case SourcePresence:
|
||||
typeName = "presence"
|
||||
case SourceAccountData:
|
||||
typeName = "account data"
|
||||
case SourceToDevice:
|
||||
typeName = "to-device"
|
||||
case SourceTimeline:
|
||||
typeName = "timeline"
|
||||
case SourceState:
|
||||
typeName = "state"
|
||||
default:
|
||||
return fmt.Sprintf("unknown (%d)", es)
|
||||
}
|
||||
if es&roomableTypes != 0 {
|
||||
switch es & roomSections {
|
||||
case SourceJoin:
|
||||
typeName = "joined room " + typeName
|
||||
case SourceInvite:
|
||||
typeName = "invited room " + typeName
|
||||
case SourceLeave:
|
||||
typeName = "left room " + typeName
|
||||
default:
|
||||
return fmt.Sprintf("unknown (%s+%d)", typeName, es)
|
||||
}
|
||||
es &^= roomSections
|
||||
}
|
||||
if es&encryptableTypes != 0 && es&SourceDecrypted != 0 {
|
||||
typeName += " (decrypted)"
|
||||
es &^= SourceDecrypted
|
||||
}
|
||||
es &^= primaryTypes
|
||||
if es != 0 {
|
||||
return fmt.Sprintf("unknown (%s+%d)", typeName, es)
|
||||
}
|
||||
return typeName
|
||||
}
|
||||
20
vendor/maunium.net/go/mautrix/event/message.go
generated
vendored
20
vendor/maunium.net/go/mautrix/event/message.go
generated
vendored
@@ -199,10 +199,14 @@ type FileInfo struct {
|
||||
ThumbnailInfo *FileInfo `json:"thumbnail_info,omitempty"`
|
||||
ThumbnailURL id.ContentURIString `json:"thumbnail_url,omitempty"`
|
||||
ThumbnailFile *EncryptedFileInfo `json:"thumbnail_file,omitempty"`
|
||||
Width int `json:"-"`
|
||||
Height int `json:"-"`
|
||||
Duration int `json:"-"`
|
||||
Size int `json:"-"`
|
||||
|
||||
Blurhash string `json:"blurhash,omitempty"`
|
||||
AnoaBlurhash string `json:"xyz.amorgan.blurhash,omitempty"`
|
||||
|
||||
Width int `json:"-"`
|
||||
Height int `json:"-"`
|
||||
Duration int `json:"-"`
|
||||
Size int `json:"-"`
|
||||
}
|
||||
|
||||
type serializableFileInfo struct {
|
||||
@@ -211,6 +215,9 @@ type serializableFileInfo struct {
|
||||
ThumbnailURL id.ContentURIString `json:"thumbnail_url,omitempty"`
|
||||
ThumbnailFile *EncryptedFileInfo `json:"thumbnail_file,omitempty"`
|
||||
|
||||
Blurhash string `json:"blurhash,omitempty"`
|
||||
AnoaBlurhash string `json:"xyz.amorgan.blurhash,omitempty"`
|
||||
|
||||
Width json.Number `json:"w,omitempty"`
|
||||
Height json.Number `json:"h,omitempty"`
|
||||
Duration json.Number `json:"duration,omitempty"`
|
||||
@@ -226,6 +233,9 @@ func (sfi *serializableFileInfo) CopyFrom(fileInfo *FileInfo) *serializableFileI
|
||||
ThumbnailURL: fileInfo.ThumbnailURL,
|
||||
ThumbnailInfo: (&serializableFileInfo{}).CopyFrom(fileInfo.ThumbnailInfo),
|
||||
ThumbnailFile: fileInfo.ThumbnailFile,
|
||||
|
||||
Blurhash: fileInfo.Blurhash,
|
||||
AnoaBlurhash: fileInfo.AnoaBlurhash,
|
||||
}
|
||||
if fileInfo.Width > 0 {
|
||||
sfi.Width = json.Number(strconv.Itoa(fileInfo.Width))
|
||||
@@ -252,6 +262,8 @@ func (sfi *serializableFileInfo) CopyTo(fileInfo *FileInfo) {
|
||||
MimeType: sfi.MimeType,
|
||||
ThumbnailURL: sfi.ThumbnailURL,
|
||||
ThumbnailFile: sfi.ThumbnailFile,
|
||||
Blurhash: sfi.Blurhash,
|
||||
AnoaBlurhash: sfi.AnoaBlurhash,
|
||||
}
|
||||
if sfi.ThumbnailInfo != nil {
|
||||
fileInfo.ThumbnailInfo = &FileInfo{}
|
||||
|
||||
4
vendor/maunium.net/go/mautrix/requests.go
generated
vendored
4
vendor/maunium.net/go/mautrix/requests.go
generated
vendored
@@ -287,9 +287,7 @@ type Signatures map[id.UserID]map[id.KeyID]string
|
||||
|
||||
type ReqQueryKeys struct {
|
||||
DeviceKeys DeviceKeysRequest `json:"device_keys"`
|
||||
|
||||
Timeout int64 `json:"timeout,omitempty"`
|
||||
Token string `json:"token,omitempty"`
|
||||
Timeout int64 `json:"timeout,omitempty"`
|
||||
}
|
||||
|
||||
type DeviceKeysRequest map[id.UserID]DeviceIDList
|
||||
|
||||
309
vendor/maunium.net/go/mautrix/sqlstatestore/statestore.go
generated
vendored
309
vendor/maunium.net/go/mautrix/sqlstatestore/statestore.go
generated
vendored
@@ -1,4 +1,4 @@
|
||||
// Copyright (c) 2022 Tulir Asokan
|
||||
// Copyright (c) 2024 Tulir Asokan
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||
// License, v. 2.0. If a copy of the MPL was not distributed with this
|
||||
@@ -7,6 +7,7 @@
|
||||
package sqlstatestore
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"embed"
|
||||
"encoding/json"
|
||||
@@ -15,6 +16,7 @@ import (
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
"go.mau.fi/util/dbutil"
|
||||
|
||||
"maunium.net/go/mautrix/event"
|
||||
@@ -44,26 +46,28 @@ func NewSQLStateStore(db *dbutil.Database, log dbutil.DatabaseLogger, isBridge b
|
||||
}
|
||||
}
|
||||
|
||||
func (store *SQLStateStore) IsRegistered(userID id.UserID) bool {
|
||||
func (store *SQLStateStore) IsRegistered(ctx context.Context, userID id.UserID) (bool, error) {
|
||||
var isRegistered bool
|
||||
err := store.
|
||||
QueryRow("SELECT EXISTS(SELECT 1 FROM mx_registrations WHERE user_id=$1)", userID).
|
||||
QueryRow(ctx, "SELECT EXISTS(SELECT 1 FROM mx_registrations WHERE user_id=$1)", userID).
|
||||
Scan(&isRegistered)
|
||||
if err != nil {
|
||||
store.Log.Warn("Failed to scan registration existence for %s: %v", userID, err)
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
err = nil
|
||||
}
|
||||
return isRegistered
|
||||
return isRegistered, err
|
||||
}
|
||||
|
||||
func (store *SQLStateStore) MarkRegistered(userID id.UserID) {
|
||||
_, err := store.Exec("INSERT INTO mx_registrations (user_id) VALUES ($1) ON CONFLICT (user_id) DO NOTHING", userID)
|
||||
if err != nil {
|
||||
store.Log.Warn("Failed to mark %s as registered: %v", userID, err)
|
||||
}
|
||||
func (store *SQLStateStore) MarkRegistered(ctx context.Context, userID id.UserID) error {
|
||||
_, err := store.Exec(ctx, "INSERT INTO mx_registrations (user_id) VALUES ($1) ON CONFLICT (user_id) DO NOTHING", userID)
|
||||
return err
|
||||
}
|
||||
|
||||
func (store *SQLStateStore) GetRoomMembers(roomID id.RoomID, memberships ...event.Membership) map[id.UserID]*event.MemberEventContent {
|
||||
members := make(map[id.UserID]*event.MemberEventContent)
|
||||
type Member struct {
|
||||
id.UserID
|
||||
event.MemberEventContent
|
||||
}
|
||||
|
||||
func (store *SQLStateStore) GetRoomMembers(ctx context.Context, roomID id.RoomID, memberships ...event.Membership) (map[id.UserID]*event.MemberEventContent, error) {
|
||||
args := make([]any, len(memberships)+1)
|
||||
args[0] = roomID
|
||||
query := "SELECT user_id, membership, displayname, avatar_url FROM mx_user_profile WHERE room_id=$1"
|
||||
@@ -75,25 +79,26 @@ func (store *SQLStateStore) GetRoomMembers(roomID id.RoomID, memberships ...even
|
||||
}
|
||||
query = fmt.Sprintf("%s AND membership IN (%s)", query, strings.Join(placeholders, ","))
|
||||
}
|
||||
rows, err := store.Query(query, args...)
|
||||
rows, err := store.Query(ctx, query, args...)
|
||||
if err != nil {
|
||||
return members
|
||||
return nil, err
|
||||
}
|
||||
var userID id.UserID
|
||||
var member event.MemberEventContent
|
||||
for rows.Next() {
|
||||
err = rows.Scan(&userID, &member.Membership, &member.Displayname, &member.AvatarURL)
|
||||
if err != nil {
|
||||
store.Log.Warn("Failed to scan member in %s: %v", roomID, err)
|
||||
} else {
|
||||
members[userID] = &member
|
||||
}
|
||||
}
|
||||
return members
|
||||
members := make(map[id.UserID]*event.MemberEventContent)
|
||||
return members, dbutil.NewRowIter(rows, func(row dbutil.Scannable) (ret Member, err error) {
|
||||
err = row.Scan(&ret.UserID, &ret.Membership, &ret.Displayname, &ret.AvatarURL)
|
||||
return
|
||||
}).Iter(func(m Member) (bool, error) {
|
||||
members[m.UserID] = &m.MemberEventContent
|
||||
return true, nil
|
||||
})
|
||||
}
|
||||
|
||||
func (store *SQLStateStore) GetRoomJoinedOrInvitedMembers(roomID id.RoomID) (members []id.UserID, err error) {
|
||||
memberMap := store.GetRoomMembers(roomID, event.MembershipJoin, event.MembershipInvite)
|
||||
func (store *SQLStateStore) GetRoomJoinedOrInvitedMembers(ctx context.Context, roomID id.RoomID) (members []id.UserID, err error) {
|
||||
var memberMap map[id.UserID]*event.MemberEventContent
|
||||
memberMap, err = store.GetRoomMembers(ctx, roomID, event.MembershipJoin, event.MembershipInvite)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
members = make([]id.UserID, len(memberMap))
|
||||
i := 0
|
||||
for userID := range memberMap {
|
||||
@@ -103,37 +108,39 @@ func (store *SQLStateStore) GetRoomJoinedOrInvitedMembers(roomID id.RoomID) (mem
|
||||
return
|
||||
}
|
||||
|
||||
func (store *SQLStateStore) GetMembership(roomID id.RoomID, userID id.UserID) event.Membership {
|
||||
membership := event.MembershipLeave
|
||||
err := store.
|
||||
QueryRow("SELECT membership FROM mx_user_profile WHERE room_id=$1 AND user_id=$2", roomID, userID).
|
||||
func (store *SQLStateStore) GetMembership(ctx context.Context, roomID id.RoomID, userID id.UserID) (membership event.Membership, err error) {
|
||||
err = store.
|
||||
QueryRow(ctx, "SELECT membership FROM mx_user_profile WHERE room_id=$1 AND user_id=$2", roomID, userID).
|
||||
Scan(&membership)
|
||||
if err != nil && !errors.Is(err, sql.ErrNoRows) {
|
||||
store.Log.Warn("Failed to scan membership of %s in %s: %v", userID, roomID, err)
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
membership = event.MembershipLeave
|
||||
err = nil
|
||||
}
|
||||
return membership
|
||||
return
|
||||
}
|
||||
|
||||
func (store *SQLStateStore) GetMember(roomID id.RoomID, userID id.UserID) *event.MemberEventContent {
|
||||
member, ok := store.TryGetMember(roomID, userID)
|
||||
if !ok {
|
||||
member.Membership = event.MembershipLeave
|
||||
func (store *SQLStateStore) GetMember(ctx context.Context, roomID id.RoomID, userID id.UserID) (*event.MemberEventContent, error) {
|
||||
member, err := store.TryGetMember(ctx, roomID, userID)
|
||||
if member == nil && err == nil {
|
||||
member = &event.MemberEventContent{Membership: event.MembershipLeave}
|
||||
}
|
||||
return member
|
||||
return member, err
|
||||
}
|
||||
|
||||
func (store *SQLStateStore) TryGetMember(roomID id.RoomID, userID id.UserID) (*event.MemberEventContent, bool) {
|
||||
func (store *SQLStateStore) TryGetMember(ctx context.Context, roomID id.RoomID, userID id.UserID) (*event.MemberEventContent, error) {
|
||||
var member event.MemberEventContent
|
||||
err := store.
|
||||
QueryRow("SELECT membership, displayname, avatar_url FROM mx_user_profile WHERE room_id=$1 AND user_id=$2", roomID, userID).
|
||||
QueryRow(ctx, "SELECT membership, displayname, avatar_url FROM mx_user_profile WHERE room_id=$1 AND user_id=$2", roomID, userID).
|
||||
Scan(&member.Membership, &member.Displayname, &member.AvatarURL)
|
||||
if err != nil && !errors.Is(err, sql.ErrNoRows) {
|
||||
store.Log.Warn("Failed to scan member info of %s in %s: %v", userID, roomID, err)
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return nil, nil
|
||||
} else if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &member, err == nil
|
||||
return &member, nil
|
||||
}
|
||||
|
||||
func (store *SQLStateStore) FindSharedRooms(userID id.UserID) (rooms []id.RoomID) {
|
||||
func (store *SQLStateStore) FindSharedRooms(ctx context.Context, userID id.UserID) ([]id.RoomID, error) {
|
||||
query := `
|
||||
SELECT room_id FROM mx_user_profile
|
||||
LEFT JOIN portal ON portal.mxid=mx_user_profile.room_id
|
||||
@@ -141,38 +148,32 @@ func (store *SQLStateStore) FindSharedRooms(userID id.UserID) (rooms []id.RoomID
|
||||
`
|
||||
if !store.IsBridge {
|
||||
query = `
|
||||
SELECT mx_user_profile.room_id FROM mx_user_profile
|
||||
LEFT JOIN mx_room_state ON mx_room_state.room_id=mx_user_profile.room_id
|
||||
WHERE mx_user_profile.user_id=$1 AND mx_room_state.encryption IS NOT NULL
|
||||
`
|
||||
SELECT mx_user_profile.room_id FROM mx_user_profile
|
||||
LEFT JOIN mx_room_state ON mx_room_state.room_id=mx_user_profile.room_id
|
||||
WHERE mx_user_profile.user_id=$1 AND mx_room_state.encryption IS NOT NULL
|
||||
`
|
||||
}
|
||||
rows, err := store.Query(query, userID)
|
||||
rows, err := store.Query(ctx, query, userID)
|
||||
if err != nil {
|
||||
store.Log.Warn("Failed to query shared rooms with %s: %v", userID, err)
|
||||
return
|
||||
return nil, err
|
||||
}
|
||||
for rows.Next() {
|
||||
var roomID id.RoomID
|
||||
err = rows.Scan(&roomID)
|
||||
if err != nil {
|
||||
store.Log.Warn("Failed to scan room ID: %v", err)
|
||||
} else {
|
||||
rooms = append(rooms, roomID)
|
||||
}
|
||||
return dbutil.NewRowIter(rows, dbutil.ScanSingleColumn[id.RoomID]).AsList()
|
||||
}
|
||||
|
||||
func (store *SQLStateStore) IsInRoom(ctx context.Context, roomID id.RoomID, userID id.UserID) bool {
|
||||
return store.IsMembership(ctx, roomID, userID, "join")
|
||||
}
|
||||
|
||||
func (store *SQLStateStore) IsInvited(ctx context.Context, roomID id.RoomID, userID id.UserID) bool {
|
||||
return store.IsMembership(ctx, roomID, userID, "join", "invite")
|
||||
}
|
||||
|
||||
func (store *SQLStateStore) IsMembership(ctx context.Context, roomID id.RoomID, userID id.UserID, allowedMemberships ...event.Membership) bool {
|
||||
membership, err := store.GetMembership(ctx, roomID, userID)
|
||||
if err != nil {
|
||||
zerolog.Ctx(ctx).Err(err).Msg("Failed to get membership")
|
||||
return false
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (store *SQLStateStore) IsInRoom(roomID id.RoomID, userID id.UserID) bool {
|
||||
return store.IsMembership(roomID, userID, "join")
|
||||
}
|
||||
|
||||
func (store *SQLStateStore) IsInvited(roomID id.RoomID, userID id.UserID) bool {
|
||||
return store.IsMembership(roomID, userID, "join", "invite")
|
||||
}
|
||||
|
||||
func (store *SQLStateStore) IsMembership(roomID id.RoomID, userID id.UserID, allowedMemberships ...event.Membership) bool {
|
||||
membership := store.GetMembership(roomID, userID)
|
||||
for _, allowedMembership := range allowedMemberships {
|
||||
if allowedMembership == membership {
|
||||
return true
|
||||
@@ -181,27 +182,23 @@ func (store *SQLStateStore) IsMembership(roomID id.RoomID, userID id.UserID, all
|
||||
return false
|
||||
}
|
||||
|
||||
func (store *SQLStateStore) SetMembership(roomID id.RoomID, userID id.UserID, membership event.Membership) {
|
||||
_, err := store.Exec(`
|
||||
func (store *SQLStateStore) SetMembership(ctx context.Context, roomID id.RoomID, userID id.UserID, membership event.Membership) error {
|
||||
_, err := store.Exec(ctx, `
|
||||
INSERT INTO mx_user_profile (room_id, user_id, membership, displayname, avatar_url) VALUES ($1, $2, $3, '', '')
|
||||
ON CONFLICT (room_id, user_id) DO UPDATE SET membership=excluded.membership
|
||||
`, roomID, userID, membership)
|
||||
if err != nil {
|
||||
store.Log.Warn("Failed to set membership of %s in %s to %s: %v", userID, roomID, membership, err)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (store *SQLStateStore) SetMember(roomID id.RoomID, userID id.UserID, member *event.MemberEventContent) {
|
||||
_, err := store.Exec(`
|
||||
func (store *SQLStateStore) SetMember(ctx context.Context, roomID id.RoomID, userID id.UserID, member *event.MemberEventContent) error {
|
||||
_, err := store.Exec(ctx, `
|
||||
INSERT INTO mx_user_profile (room_id, user_id, membership, displayname, avatar_url) VALUES ($1, $2, $3, $4, $5)
|
||||
ON CONFLICT (room_id, user_id) DO UPDATE SET membership=excluded.membership, displayname=excluded.displayname, avatar_url=excluded.avatar_url
|
||||
`, roomID, userID, member.Membership, member.Displayname, member.AvatarURL)
|
||||
if err != nil {
|
||||
store.Log.Warn("Failed to set membership of %s in %s to %s: %v", userID, roomID, member, err)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (store *SQLStateStore) ClearCachedMembers(roomID id.RoomID, memberships ...event.Membership) {
|
||||
func (store *SQLStateStore) ClearCachedMembers(ctx context.Context, roomID id.RoomID, memberships ...event.Membership) error {
|
||||
query := "DELETE FROM mx_user_profile WHERE room_id=$1"
|
||||
params := make([]any, len(memberships)+1)
|
||||
params[0] = roomID
|
||||
@@ -213,109 +210,85 @@ func (store *SQLStateStore) ClearCachedMembers(roomID id.RoomID, memberships ...
|
||||
}
|
||||
query += fmt.Sprintf(" AND membership IN (%s)", strings.Join(placeholders, ","))
|
||||
}
|
||||
_, err := store.Exec(query, params...)
|
||||
if err != nil {
|
||||
store.Log.Warn("Failed to clear cached members of %s: %v", roomID, err)
|
||||
}
|
||||
_, err := store.Exec(ctx, query, params...)
|
||||
return err
|
||||
}
|
||||
|
||||
func (store *SQLStateStore) SetEncryptionEvent(roomID id.RoomID, content *event.EncryptionEventContent) {
|
||||
func (store *SQLStateStore) SetEncryptionEvent(ctx context.Context, roomID id.RoomID, content *event.EncryptionEventContent) error {
|
||||
contentBytes, err := json.Marshal(content)
|
||||
if err != nil {
|
||||
store.Log.Warn("Failed to marshal encryption config of %s: %v", roomID, err)
|
||||
return
|
||||
return fmt.Errorf("failed to marshal content JSON: %w", err)
|
||||
}
|
||||
_, err = store.Exec(`
|
||||
_, err = store.Exec(ctx, `
|
||||
INSERT INTO mx_room_state (room_id, encryption) VALUES ($1, $2)
|
||||
ON CONFLICT (room_id) DO UPDATE SET encryption=excluded.encryption
|
||||
`, roomID, contentBytes)
|
||||
if err != nil {
|
||||
store.Log.Warn("Failed to store encryption config of %s: %v", roomID, err)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (store *SQLStateStore) GetEncryptionEvent(roomID id.RoomID) *event.EncryptionEventContent {
|
||||
func (store *SQLStateStore) GetEncryptionEvent(ctx context.Context, roomID id.RoomID) (*event.EncryptionEventContent, error) {
|
||||
var data []byte
|
||||
err := store.
|
||||
QueryRow("SELECT encryption FROM mx_room_state WHERE room_id=$1", roomID).
|
||||
QueryRow(ctx, "SELECT encryption FROM mx_room_state WHERE room_id=$1", roomID).
|
||||
Scan(&data)
|
||||
if err != nil {
|
||||
if !errors.Is(err, sql.ErrNoRows) {
|
||||
store.Log.Warn("Failed to scan encryption config of %s: %v", roomID, err)
|
||||
}
|
||||
return nil
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
return nil, nil
|
||||
} else if err != nil {
|
||||
return nil, err
|
||||
} else if data == nil {
|
||||
return nil
|
||||
return nil, nil
|
||||
}
|
||||
content := &event.EncryptionEventContent{}
|
||||
err = json.Unmarshal(data, content)
|
||||
var content event.EncryptionEventContent
|
||||
err = json.Unmarshal(data, &content)
|
||||
if err != nil {
|
||||
store.Log.Warn("Failed to parse encryption config of %s: %v", roomID, err)
|
||||
return nil
|
||||
return nil, fmt.Errorf("failed to parse content JSON: %w", err)
|
||||
}
|
||||
return content
|
||||
return &content, nil
|
||||
}
|
||||
|
||||
func (store *SQLStateStore) IsEncrypted(roomID id.RoomID) bool {
|
||||
cfg := store.GetEncryptionEvent(roomID)
|
||||
return cfg != nil && cfg.Algorithm == id.AlgorithmMegolmV1
|
||||
func (store *SQLStateStore) IsEncrypted(ctx context.Context, roomID id.RoomID) (bool, error) {
|
||||
cfg, err := store.GetEncryptionEvent(ctx, roomID)
|
||||
return cfg != nil && cfg.Algorithm == id.AlgorithmMegolmV1, err
|
||||
}
|
||||
|
||||
func (store *SQLStateStore) SetPowerLevels(roomID id.RoomID, levels *event.PowerLevelsEventContent) {
|
||||
levelsBytes, err := json.Marshal(levels)
|
||||
if err != nil {
|
||||
store.Log.Warn("Failed to marshal power levels of %s: %v", roomID, err)
|
||||
return
|
||||
}
|
||||
_, err = store.Exec(`
|
||||
func (store *SQLStateStore) SetPowerLevels(ctx context.Context, roomID id.RoomID, levels *event.PowerLevelsEventContent) error {
|
||||
_, err := store.Exec(ctx, `
|
||||
INSERT INTO mx_room_state (room_id, power_levels) VALUES ($1, $2)
|
||||
ON CONFLICT (room_id) DO UPDATE SET power_levels=excluded.power_levels
|
||||
`, roomID, levelsBytes)
|
||||
if err != nil {
|
||||
store.Log.Warn("Failed to store power levels of %s: %v", roomID, err)
|
||||
}
|
||||
`, roomID, dbutil.JSON{Data: levels})
|
||||
return err
|
||||
}
|
||||
|
||||
func (store *SQLStateStore) GetPowerLevels(roomID id.RoomID) (levels *event.PowerLevelsEventContent) {
|
||||
var data []byte
|
||||
err := store.
|
||||
QueryRow("SELECT power_levels FROM mx_room_state WHERE room_id=$1", roomID).
|
||||
Scan(&data)
|
||||
if err != nil {
|
||||
if !errors.Is(err, sql.ErrNoRows) {
|
||||
store.Log.Warn("Failed to scan power levels of %s: %v", roomID, err)
|
||||
}
|
||||
return
|
||||
} else if data == nil {
|
||||
return
|
||||
}
|
||||
levels = &event.PowerLevelsEventContent{}
|
||||
err = json.Unmarshal(data, levels)
|
||||
if err != nil {
|
||||
store.Log.Warn("Failed to parse power levels of %s: %v", roomID, err)
|
||||
return nil
|
||||
func (store *SQLStateStore) GetPowerLevels(ctx context.Context, roomID id.RoomID) (levels *event.PowerLevelsEventContent, err error) {
|
||||
err = store.
|
||||
QueryRow(ctx, "SELECT power_levels FROM mx_room_state WHERE room_id=$1", roomID).
|
||||
Scan(&dbutil.JSON{Data: &levels})
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
err = nil
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (store *SQLStateStore) GetPowerLevel(roomID id.RoomID, userID id.UserID) int {
|
||||
func (store *SQLStateStore) GetPowerLevel(ctx context.Context, roomID id.RoomID, userID id.UserID) (int, error) {
|
||||
if store.Dialect == dbutil.Postgres {
|
||||
var powerLevel int
|
||||
err := store.
|
||||
QueryRow(`
|
||||
QueryRow(ctx, `
|
||||
SELECT COALESCE((power_levels->'users'->$2)::int, (power_levels->'users_default')::int, 0)
|
||||
FROM mx_room_state WHERE room_id=$1
|
||||
`, roomID, userID).
|
||||
Scan(&powerLevel)
|
||||
if err != nil && !errors.Is(err, sql.ErrNoRows) {
|
||||
store.Log.Warn("Failed to scan power level of %s in %s: %v", userID, roomID, err)
|
||||
return powerLevel, err
|
||||
} else {
|
||||
levels, err := store.GetPowerLevels(ctx, roomID)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return powerLevel
|
||||
return levels.GetUserLevel(userID), nil
|
||||
}
|
||||
return store.GetPowerLevels(roomID).GetUserLevel(userID)
|
||||
}
|
||||
|
||||
func (store *SQLStateStore) GetPowerLevelRequirement(roomID id.RoomID, eventType event.Type) int {
|
||||
func (store *SQLStateStore) GetPowerLevelRequirement(ctx context.Context, roomID id.RoomID, eventType event.Type) (int, error) {
|
||||
if store.Dialect == dbutil.Postgres {
|
||||
defaultType := "events_default"
|
||||
defaultValue := 0
|
||||
@@ -325,23 +298,26 @@ func (store *SQLStateStore) GetPowerLevelRequirement(roomID id.RoomID, eventType
|
||||
}
|
||||
var powerLevel int
|
||||
err := store.
|
||||
QueryRow(`
|
||||
QueryRow(ctx, `
|
||||
SELECT COALESCE((power_levels->'events'->$2)::int, (power_levels->'$3')::int, $4)
|
||||
FROM mx_room_state WHERE room_id=$1
|
||||
`, roomID, eventType.Type, defaultType, defaultValue).
|
||||
Scan(&powerLevel)
|
||||
if err != nil {
|
||||
if !errors.Is(err, sql.ErrNoRows) {
|
||||
store.Log.Warn("Failed to scan power level for %s in %s: %v", eventType, roomID, err)
|
||||
}
|
||||
return defaultValue
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
err = nil
|
||||
powerLevel = defaultValue
|
||||
}
|
||||
return powerLevel
|
||||
return powerLevel, err
|
||||
} else {
|
||||
levels, err := store.GetPowerLevels(ctx, roomID)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
}
|
||||
return levels.GetEventLevel(eventType), nil
|
||||
}
|
||||
return store.GetPowerLevels(roomID).GetEventLevel(eventType)
|
||||
}
|
||||
|
||||
func (store *SQLStateStore) HasPowerLevel(roomID id.RoomID, userID id.UserID, eventType event.Type) bool {
|
||||
func (store *SQLStateStore) HasPowerLevel(ctx context.Context, roomID id.RoomID, userID id.UserID, eventType event.Type) (bool, error) {
|
||||
if store.Dialect == dbutil.Postgres {
|
||||
defaultType := "events_default"
|
||||
defaultValue := 0
|
||||
@@ -351,19 +327,22 @@ func (store *SQLStateStore) HasPowerLevel(roomID id.RoomID, userID id.UserID, ev
|
||||
}
|
||||
var hasPower bool
|
||||
err := store.
|
||||
QueryRow(`SELECT
|
||||
QueryRow(ctx, `SELECT
|
||||
COALESCE((power_levels->'users'->$2)::int, (power_levels->'users_default')::int, 0)
|
||||
>=
|
||||
COALESCE((power_levels->'events'->$3)::int, (power_levels->'$4')::int, $5)
|
||||
FROM mx_room_state WHERE room_id=$1`, roomID, userID, eventType.Type, defaultType, defaultValue).
|
||||
Scan(&hasPower)
|
||||
if err != nil {
|
||||
if !errors.Is(err, sql.ErrNoRows) {
|
||||
store.Log.Warn("Failed to scan power level for %s in %s: %v", eventType, roomID, err)
|
||||
}
|
||||
return defaultValue == 0
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
err = nil
|
||||
hasPower = defaultValue == 0
|
||||
}
|
||||
return hasPower
|
||||
return hasPower, err
|
||||
} else {
|
||||
levels, err := store.GetPowerLevels(ctx, roomID)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
return levels.GetUserLevel(userID) >= levels.GetEventLevel(eventType), nil
|
||||
}
|
||||
return store.GetPowerLevel(roomID, userID) >= store.GetPowerLevelRequirement(roomID, eventType)
|
||||
}
|
||||
|
||||
7
vendor/maunium.net/go/mautrix/sqlstatestore/v05-mark-encryption-state-resync.go
generated
vendored
7
vendor/maunium.net/go/mautrix/sqlstatestore/v05-mark-encryption-state-resync.go
generated
vendored
@@ -1,19 +1,20 @@
|
||||
package sqlstatestore
|
||||
|
||||
import (
|
||||
"context"
|
||||
"fmt"
|
||||
|
||||
"go.mau.fi/util/dbutil"
|
||||
)
|
||||
|
||||
func init() {
|
||||
UpgradeTable.Register(-1, 5, 0, "Mark rooms that need crypto state event resynced", true, func(tx dbutil.Execable, db *dbutil.Database) error {
|
||||
portalExists, err := db.TableExists(tx, "portal")
|
||||
UpgradeTable.Register(-1, 5, 0, "Mark rooms that need crypto state event resynced", true, func(ctx context.Context, db *dbutil.Database) error {
|
||||
portalExists, err := db.TableExists(ctx, "portal")
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to check if portal table exists")
|
||||
}
|
||||
if portalExists {
|
||||
_, err = tx.Exec(`
|
||||
_, err = db.Exec(ctx, `
|
||||
INSERT INTO mx_room_state (room_id, encryption)
|
||||
SELECT portal.mxid, '{"resync":true}' FROM portal WHERE portal.encrypted=true AND portal.mxid IS NOT NULL
|
||||
ON CONFLICT (room_id) DO UPDATE
|
||||
|
||||
136
vendor/maunium.net/go/mautrix/statestore.go
generated
vendored
136
vendor/maunium.net/go/mautrix/statestore.go
generated
vendored
@@ -7,33 +7,37 @@
|
||||
package mautrix
|
||||
|
||||
import (
|
||||
"context"
|
||||
"sync"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
"go.mau.fi/util/exerrors"
|
||||
|
||||
"maunium.net/go/mautrix/event"
|
||||
"maunium.net/go/mautrix/id"
|
||||
)
|
||||
|
||||
// StateStore is an interface for storing basic room state information.
|
||||
type StateStore interface {
|
||||
IsInRoom(roomID id.RoomID, userID id.UserID) bool
|
||||
IsInvited(roomID id.RoomID, userID id.UserID) bool
|
||||
IsMembership(roomID id.RoomID, userID id.UserID, allowedMemberships ...event.Membership) bool
|
||||
GetMember(roomID id.RoomID, userID id.UserID) *event.MemberEventContent
|
||||
TryGetMember(roomID id.RoomID, userID id.UserID) (*event.MemberEventContent, bool)
|
||||
SetMembership(roomID id.RoomID, userID id.UserID, membership event.Membership)
|
||||
SetMember(roomID id.RoomID, userID id.UserID, member *event.MemberEventContent)
|
||||
ClearCachedMembers(roomID id.RoomID, memberships ...event.Membership)
|
||||
IsInRoom(ctx context.Context, roomID id.RoomID, userID id.UserID) bool
|
||||
IsInvited(ctx context.Context, roomID id.RoomID, userID id.UserID) bool
|
||||
IsMembership(ctx context.Context, roomID id.RoomID, userID id.UserID, allowedMemberships ...event.Membership) bool
|
||||
GetMember(ctx context.Context, roomID id.RoomID, userID id.UserID) (*event.MemberEventContent, error)
|
||||
TryGetMember(ctx context.Context, roomID id.RoomID, userID id.UserID) (*event.MemberEventContent, error)
|
||||
SetMembership(ctx context.Context, roomID id.RoomID, userID id.UserID, membership event.Membership) error
|
||||
SetMember(ctx context.Context, roomID id.RoomID, userID id.UserID, member *event.MemberEventContent) error
|
||||
ClearCachedMembers(ctx context.Context, roomID id.RoomID, memberships ...event.Membership) error
|
||||
|
||||
SetPowerLevels(roomID id.RoomID, levels *event.PowerLevelsEventContent)
|
||||
GetPowerLevels(roomID id.RoomID) *event.PowerLevelsEventContent
|
||||
SetPowerLevels(ctx context.Context, roomID id.RoomID, levels *event.PowerLevelsEventContent) error
|
||||
GetPowerLevels(ctx context.Context, roomID id.RoomID) (*event.PowerLevelsEventContent, error)
|
||||
|
||||
SetEncryptionEvent(roomID id.RoomID, content *event.EncryptionEventContent)
|
||||
IsEncrypted(roomID id.RoomID) bool
|
||||
SetEncryptionEvent(ctx context.Context, roomID id.RoomID, content *event.EncryptionEventContent) error
|
||||
IsEncrypted(ctx context.Context, roomID id.RoomID) (bool, error)
|
||||
|
||||
GetRoomJoinedOrInvitedMembers(roomID id.RoomID) ([]id.UserID, error)
|
||||
GetRoomJoinedOrInvitedMembers(ctx context.Context, roomID id.RoomID) ([]id.UserID, error)
|
||||
}
|
||||
|
||||
func UpdateStateStore(store StateStore, evt *event.Event) {
|
||||
func UpdateStateStore(ctx context.Context, store StateStore, evt *event.Event) {
|
||||
if store == nil || evt == nil || evt.StateKey == nil {
|
||||
return
|
||||
}
|
||||
@@ -41,13 +45,20 @@ func UpdateStateStore(store StateStore, evt *event.Event) {
|
||||
if evt.Type != event.StateMember && evt.GetStateKey() != "" {
|
||||
return
|
||||
}
|
||||
var err error
|
||||
switch content := evt.Content.Parsed.(type) {
|
||||
case *event.MemberEventContent:
|
||||
store.SetMember(evt.RoomID, id.UserID(evt.GetStateKey()), content)
|
||||
err = store.SetMember(ctx, evt.RoomID, id.UserID(evt.GetStateKey()), content)
|
||||
case *event.PowerLevelsEventContent:
|
||||
store.SetPowerLevels(evt.RoomID, content)
|
||||
err = store.SetPowerLevels(ctx, evt.RoomID, content)
|
||||
case *event.EncryptionEventContent:
|
||||
store.SetEncryptionEvent(evt.RoomID, content)
|
||||
err = store.SetEncryptionEvent(ctx, evt.RoomID, content)
|
||||
}
|
||||
if err != nil {
|
||||
zerolog.Ctx(ctx).Warn().Err(err).
|
||||
Stringer("event_id", evt.ID).
|
||||
Str("event_type", evt.Type.Type).
|
||||
Msg("Failed to update state store")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -56,8 +67,8 @@ func UpdateStateStore(store StateStore, evt *event.Event) {
|
||||
// client.Syncer.(mautrix.ExtensibleSyncer).OnEvent(client.StateStoreSyncHandler)
|
||||
//
|
||||
// DefaultSyncer.ParseEventContent must also be true for this to work (which it is by default).
|
||||
func (cli *Client) StateStoreSyncHandler(_ EventSource, evt *event.Event) {
|
||||
UpdateStateStore(cli.StateStore, evt)
|
||||
func (cli *Client) StateStoreSyncHandler(ctx context.Context, evt *event.Event) {
|
||||
UpdateStateStore(ctx, cli.StateStore, evt)
|
||||
}
|
||||
|
||||
type MemoryStateStore struct {
|
||||
@@ -81,20 +92,21 @@ func NewMemoryStateStore() StateStore {
|
||||
}
|
||||
}
|
||||
|
||||
func (store *MemoryStateStore) IsRegistered(userID id.UserID) bool {
|
||||
func (store *MemoryStateStore) IsRegistered(_ context.Context, userID id.UserID) (bool, error) {
|
||||
store.registrationsLock.RLock()
|
||||
defer store.registrationsLock.RUnlock()
|
||||
registered, ok := store.Registrations[userID]
|
||||
return ok && registered
|
||||
return ok && registered, nil
|
||||
}
|
||||
|
||||
func (store *MemoryStateStore) MarkRegistered(userID id.UserID) {
|
||||
func (store *MemoryStateStore) MarkRegistered(_ context.Context, userID id.UserID) error {
|
||||
store.registrationsLock.Lock()
|
||||
defer store.registrationsLock.Unlock()
|
||||
store.Registrations[userID] = true
|
||||
return nil
|
||||
}
|
||||
|
||||
func (store *MemoryStateStore) GetRoomMembers(roomID id.RoomID) map[id.UserID]*event.MemberEventContent {
|
||||
func (store *MemoryStateStore) GetRoomMembers(_ context.Context, roomID id.RoomID) (map[id.UserID]*event.MemberEventContent, error) {
|
||||
store.membersLock.RLock()
|
||||
members, ok := store.Members[roomID]
|
||||
store.membersLock.RUnlock()
|
||||
@@ -104,11 +116,14 @@ func (store *MemoryStateStore) GetRoomMembers(roomID id.RoomID) map[id.UserID]*e
|
||||
store.Members[roomID] = members
|
||||
store.membersLock.Unlock()
|
||||
}
|
||||
return members
|
||||
return members, nil
|
||||
}
|
||||
|
||||
func (store *MemoryStateStore) GetRoomJoinedOrInvitedMembers(roomID id.RoomID) ([]id.UserID, error) {
|
||||
members := store.GetRoomMembers(roomID)
|
||||
func (store *MemoryStateStore) GetRoomJoinedOrInvitedMembers(ctx context.Context, roomID id.RoomID) ([]id.UserID, error) {
|
||||
members, err := store.GetRoomMembers(ctx, roomID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
ids := make([]id.UserID, 0, len(members))
|
||||
for id := range members {
|
||||
ids = append(ids, id)
|
||||
@@ -116,39 +131,39 @@ func (store *MemoryStateStore) GetRoomJoinedOrInvitedMembers(roomID id.RoomID) (
|
||||
return ids, nil
|
||||
}
|
||||
|
||||
func (store *MemoryStateStore) GetMembership(roomID id.RoomID, userID id.UserID) event.Membership {
|
||||
return store.GetMember(roomID, userID).Membership
|
||||
func (store *MemoryStateStore) GetMembership(ctx context.Context, roomID id.RoomID, userID id.UserID) (event.Membership, error) {
|
||||
return exerrors.Must(store.GetMember(ctx, roomID, userID)).Membership, nil
|
||||
}
|
||||
|
||||
func (store *MemoryStateStore) GetMember(roomID id.RoomID, userID id.UserID) *event.MemberEventContent {
|
||||
member, ok := store.TryGetMember(roomID, userID)
|
||||
if !ok {
|
||||
func (store *MemoryStateStore) GetMember(ctx context.Context, roomID id.RoomID, userID id.UserID) (*event.MemberEventContent, error) {
|
||||
member, err := store.TryGetMember(ctx, roomID, userID)
|
||||
if member == nil && err == nil {
|
||||
member = &event.MemberEventContent{Membership: event.MembershipLeave}
|
||||
}
|
||||
return member
|
||||
return member, err
|
||||
}
|
||||
|
||||
func (store *MemoryStateStore) TryGetMember(roomID id.RoomID, userID id.UserID) (member *event.MemberEventContent, ok bool) {
|
||||
func (store *MemoryStateStore) TryGetMember(_ context.Context, roomID id.RoomID, userID id.UserID) (member *event.MemberEventContent, err error) {
|
||||
store.membersLock.RLock()
|
||||
defer store.membersLock.RUnlock()
|
||||
members, membersOk := store.Members[roomID]
|
||||
if !membersOk {
|
||||
return
|
||||
}
|
||||
member, ok = members[userID]
|
||||
member = members[userID]
|
||||
return
|
||||
}
|
||||
|
||||
func (store *MemoryStateStore) IsInRoom(roomID id.RoomID, userID id.UserID) bool {
|
||||
return store.IsMembership(roomID, userID, "join")
|
||||
func (store *MemoryStateStore) IsInRoom(ctx context.Context, roomID id.RoomID, userID id.UserID) bool {
|
||||
return store.IsMembership(ctx, roomID, userID, "join")
|
||||
}
|
||||
|
||||
func (store *MemoryStateStore) IsInvited(roomID id.RoomID, userID id.UserID) bool {
|
||||
return store.IsMembership(roomID, userID, "join", "invite")
|
||||
func (store *MemoryStateStore) IsInvited(ctx context.Context, roomID id.RoomID, userID id.UserID) bool {
|
||||
return store.IsMembership(ctx, roomID, userID, "join", "invite")
|
||||
}
|
||||
|
||||
func (store *MemoryStateStore) IsMembership(roomID id.RoomID, userID id.UserID, allowedMemberships ...event.Membership) bool {
|
||||
membership := store.GetMembership(roomID, userID)
|
||||
func (store *MemoryStateStore) IsMembership(ctx context.Context, roomID id.RoomID, userID id.UserID, allowedMemberships ...event.Membership) bool {
|
||||
membership := exerrors.Must(store.GetMembership(ctx, roomID, userID))
|
||||
for _, allowedMembership := range allowedMemberships {
|
||||
if allowedMembership == membership {
|
||||
return true
|
||||
@@ -157,7 +172,7 @@ func (store *MemoryStateStore) IsMembership(roomID id.RoomID, userID id.UserID,
|
||||
return false
|
||||
}
|
||||
|
||||
func (store *MemoryStateStore) SetMembership(roomID id.RoomID, userID id.UserID, membership event.Membership) {
|
||||
func (store *MemoryStateStore) SetMembership(_ context.Context, roomID id.RoomID, userID id.UserID, membership event.Membership) error {
|
||||
store.membersLock.Lock()
|
||||
members, ok := store.Members[roomID]
|
||||
if !ok {
|
||||
@@ -175,9 +190,10 @@ func (store *MemoryStateStore) SetMembership(roomID id.RoomID, userID id.UserID,
|
||||
}
|
||||
store.Members[roomID] = members
|
||||
store.membersLock.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (store *MemoryStateStore) SetMember(roomID id.RoomID, userID id.UserID, member *event.MemberEventContent) {
|
||||
func (store *MemoryStateStore) SetMember(_ context.Context, roomID id.RoomID, userID id.UserID, member *event.MemberEventContent) error {
|
||||
store.membersLock.Lock()
|
||||
members, ok := store.Members[roomID]
|
||||
if !ok {
|
||||
@@ -189,14 +205,15 @@ func (store *MemoryStateStore) SetMember(roomID id.RoomID, userID id.UserID, mem
|
||||
}
|
||||
store.Members[roomID] = members
|
||||
store.membersLock.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (store *MemoryStateStore) ClearCachedMembers(roomID id.RoomID, memberships ...event.Membership) {
|
||||
func (store *MemoryStateStore) ClearCachedMembers(_ context.Context, roomID id.RoomID, memberships ...event.Membership) error {
|
||||
store.membersLock.Lock()
|
||||
defer store.membersLock.Unlock()
|
||||
members, ok := store.Members[roomID]
|
||||
if !ok {
|
||||
return
|
||||
return nil
|
||||
}
|
||||
for userID, member := range members {
|
||||
for _, membership := range memberships {
|
||||
@@ -206,46 +223,49 @@ func (store *MemoryStateStore) ClearCachedMembers(roomID id.RoomID, memberships
|
||||
}
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (store *MemoryStateStore) SetPowerLevels(roomID id.RoomID, levels *event.PowerLevelsEventContent) {
|
||||
func (store *MemoryStateStore) SetPowerLevels(_ context.Context, roomID id.RoomID, levels *event.PowerLevelsEventContent) error {
|
||||
store.powerLevelsLock.Lock()
|
||||
store.PowerLevels[roomID] = levels
|
||||
store.powerLevelsLock.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (store *MemoryStateStore) GetPowerLevels(roomID id.RoomID) (levels *event.PowerLevelsEventContent) {
|
||||
func (store *MemoryStateStore) GetPowerLevels(_ context.Context, roomID id.RoomID) (levels *event.PowerLevelsEventContent, err error) {
|
||||
store.powerLevelsLock.RLock()
|
||||
levels = store.PowerLevels[roomID]
|
||||
store.powerLevelsLock.RUnlock()
|
||||
return
|
||||
}
|
||||
|
||||
func (store *MemoryStateStore) GetPowerLevel(roomID id.RoomID, userID id.UserID) int {
|
||||
return store.GetPowerLevels(roomID).GetUserLevel(userID)
|
||||
func (store *MemoryStateStore) GetPowerLevel(ctx context.Context, roomID id.RoomID, userID id.UserID) (int, error) {
|
||||
return exerrors.Must(store.GetPowerLevels(ctx, roomID)).GetUserLevel(userID), nil
|
||||
}
|
||||
|
||||
func (store *MemoryStateStore) GetPowerLevelRequirement(roomID id.RoomID, eventType event.Type) int {
|
||||
return store.GetPowerLevels(roomID).GetEventLevel(eventType)
|
||||
func (store *MemoryStateStore) GetPowerLevelRequirement(ctx context.Context, roomID id.RoomID, eventType event.Type) (int, error) {
|
||||
return exerrors.Must(store.GetPowerLevels(ctx, roomID)).GetEventLevel(eventType), nil
|
||||
}
|
||||
|
||||
func (store *MemoryStateStore) HasPowerLevel(roomID id.RoomID, userID id.UserID, eventType event.Type) bool {
|
||||
return store.GetPowerLevel(roomID, userID) >= store.GetPowerLevelRequirement(roomID, eventType)
|
||||
func (store *MemoryStateStore) HasPowerLevel(ctx context.Context, roomID id.RoomID, userID id.UserID, eventType event.Type) (bool, error) {
|
||||
return exerrors.Must(store.GetPowerLevel(ctx, roomID, userID)) >= exerrors.Must(store.GetPowerLevelRequirement(ctx, roomID, eventType)), nil
|
||||
}
|
||||
|
||||
func (store *MemoryStateStore) SetEncryptionEvent(roomID id.RoomID, content *event.EncryptionEventContent) {
|
||||
func (store *MemoryStateStore) SetEncryptionEvent(_ context.Context, roomID id.RoomID, content *event.EncryptionEventContent) error {
|
||||
store.encryptionLock.Lock()
|
||||
store.Encryption[roomID] = content
|
||||
store.encryptionLock.Unlock()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (store *MemoryStateStore) GetEncryptionEvent(roomID id.RoomID) *event.EncryptionEventContent {
|
||||
func (store *MemoryStateStore) GetEncryptionEvent(_ context.Context, roomID id.RoomID) (*event.EncryptionEventContent, error) {
|
||||
store.encryptionLock.RLock()
|
||||
defer store.encryptionLock.RUnlock()
|
||||
return store.Encryption[roomID]
|
||||
return store.Encryption[roomID], nil
|
||||
}
|
||||
|
||||
func (store *MemoryStateStore) IsEncrypted(roomID id.RoomID) bool {
|
||||
cfg := store.GetEncryptionEvent(roomID)
|
||||
return cfg != nil && cfg.Algorithm == id.AlgorithmMegolmV1
|
||||
func (store *MemoryStateStore) IsEncrypted(ctx context.Context, roomID id.RoomID) (bool, error) {
|
||||
cfg, err := store.GetEncryptionEvent(ctx, roomID)
|
||||
return cfg != nil && cfg.Algorithm == id.AlgorithmMegolmV1, err
|
||||
}
|
||||
|
||||
145
vendor/maunium.net/go/mautrix/sync.go
generated
vendored
145
vendor/maunium.net/go/mautrix/sync.go
generated
vendored
@@ -1,4 +1,4 @@
|
||||
// Copyright (c) 2020 Tulir Asokan
|
||||
// Copyright (c) 2024 Tulir Asokan
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||
// License, v. 2.0. If a copy of the MPL was not distributed with this
|
||||
@@ -7,6 +7,7 @@
|
||||
package mautrix
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"runtime/debug"
|
||||
@@ -16,78 +17,17 @@ import (
|
||||
"maunium.net/go/mautrix/id"
|
||||
)
|
||||
|
||||
// EventSource represents the part of the sync response that an event came from.
|
||||
type EventSource int
|
||||
|
||||
const (
|
||||
EventSourcePresence EventSource = 1 << iota
|
||||
EventSourceJoin
|
||||
EventSourceInvite
|
||||
EventSourceLeave
|
||||
EventSourceAccountData
|
||||
EventSourceTimeline
|
||||
EventSourceState
|
||||
EventSourceEphemeral
|
||||
EventSourceToDevice
|
||||
EventSourceDecrypted
|
||||
)
|
||||
|
||||
const primaryTypes = EventSourcePresence | EventSourceAccountData | EventSourceToDevice | EventSourceTimeline | EventSourceState
|
||||
const roomSections = EventSourceJoin | EventSourceInvite | EventSourceLeave
|
||||
const roomableTypes = EventSourceAccountData | EventSourceTimeline | EventSourceState
|
||||
const encryptableTypes = roomableTypes | EventSourceToDevice
|
||||
|
||||
func (es EventSource) String() string {
|
||||
var typeName string
|
||||
switch es & primaryTypes {
|
||||
case EventSourcePresence:
|
||||
typeName = "presence"
|
||||
case EventSourceAccountData:
|
||||
typeName = "account data"
|
||||
case EventSourceToDevice:
|
||||
typeName = "to-device"
|
||||
case EventSourceTimeline:
|
||||
typeName = "timeline"
|
||||
case EventSourceState:
|
||||
typeName = "state"
|
||||
default:
|
||||
return fmt.Sprintf("unknown (%d)", es)
|
||||
}
|
||||
if es&roomableTypes != 0 {
|
||||
switch es & roomSections {
|
||||
case EventSourceJoin:
|
||||
typeName = "joined room " + typeName
|
||||
case EventSourceInvite:
|
||||
typeName = "invited room " + typeName
|
||||
case EventSourceLeave:
|
||||
typeName = "left room " + typeName
|
||||
default:
|
||||
return fmt.Sprintf("unknown (%s+%d)", typeName, es)
|
||||
}
|
||||
es &^= roomSections
|
||||
}
|
||||
if es&encryptableTypes != 0 && es&EventSourceDecrypted != 0 {
|
||||
typeName += " (decrypted)"
|
||||
es &^= EventSourceDecrypted
|
||||
}
|
||||
es &^= primaryTypes
|
||||
if es != 0 {
|
||||
return fmt.Sprintf("unknown (%s+%d)", typeName, es)
|
||||
}
|
||||
return typeName
|
||||
}
|
||||
|
||||
// EventHandler handles a single event from a sync response.
|
||||
type EventHandler func(source EventSource, evt *event.Event)
|
||||
type EventHandler func(ctx context.Context, evt *event.Event)
|
||||
|
||||
// SyncHandler handles a whole sync response. If the return value is false, handling will be stopped completely.
|
||||
type SyncHandler func(resp *RespSync, since string) bool
|
||||
type SyncHandler func(ctx context.Context, resp *RespSync, since string) bool
|
||||
|
||||
// Syncer is an interface that must be satisfied in order to do /sync requests on a client.
|
||||
type Syncer interface {
|
||||
// ProcessResponse processes the /sync response. The since parameter is the since= value that was used to produce the response.
|
||||
// This is useful for detecting the very first sync (since=""). If an error is return, Syncing will be stopped permanently.
|
||||
ProcessResponse(resp *RespSync, since string) error
|
||||
ProcessResponse(ctx context.Context, resp *RespSync, since string) error
|
||||
// OnFailedSync returns either the time to wait before retrying or an error to stop syncing permanently.
|
||||
OnFailedSync(res *RespSync, err error) (time.Duration, error)
|
||||
// GetFilterJSON for the given user ID. NOT the filter ID.
|
||||
@@ -101,7 +41,7 @@ type ExtensibleSyncer interface {
|
||||
}
|
||||
|
||||
type DispatchableSyncer interface {
|
||||
Dispatch(source EventSource, evt *event.Event)
|
||||
Dispatch(ctx context.Context, evt *event.Event)
|
||||
}
|
||||
|
||||
// DefaultSyncer is the default syncing implementation. You can either write your own syncer, or selectively
|
||||
@@ -134,14 +74,17 @@ func NewDefaultSyncer() *DefaultSyncer {
|
||||
globalListeners: []EventHandler{},
|
||||
ParseEventContent: true,
|
||||
ParseErrorHandler: func(evt *event.Event, err error) bool {
|
||||
return false
|
||||
// By default, drop known events that can't be parsed, but let unknown events through
|
||||
return errors.Is(err, event.ErrUnsupportedContentType) ||
|
||||
// Also allow events that had their content already parsed by some other function
|
||||
errors.Is(err, event.ErrContentAlreadyParsed)
|
||||
},
|
||||
}
|
||||
}
|
||||
|
||||
// ProcessResponse processes the /sync response in a way suitable for bots. "Suitable for bots" means a stream of
|
||||
// unrepeating events. Returns a fatal error if a listener panics.
|
||||
func (s *DefaultSyncer) ProcessResponse(res *RespSync, since string) (err error) {
|
||||
func (s *DefaultSyncer) ProcessResponse(ctx context.Context, res *RespSync, since string) (err error) {
|
||||
defer func() {
|
||||
if r := recover(); r != nil {
|
||||
err = fmt.Errorf("ProcessResponse panicked! since=%s panic=%s\n%s", since, r, debug.Stack())
|
||||
@@ -149,38 +92,38 @@ func (s *DefaultSyncer) ProcessResponse(res *RespSync, since string) (err error)
|
||||
}()
|
||||
|
||||
for _, listener := range s.syncListeners {
|
||||
if !listener(res, since) {
|
||||
if !listener(ctx, res, since) {
|
||||
return
|
||||
}
|
||||
}
|
||||
|
||||
s.processSyncEvents("", res.ToDevice.Events, EventSourceToDevice)
|
||||
s.processSyncEvents("", res.Presence.Events, EventSourcePresence)
|
||||
s.processSyncEvents("", res.AccountData.Events, EventSourceAccountData)
|
||||
s.processSyncEvents(ctx, "", res.ToDevice.Events, event.SourceToDevice)
|
||||
s.processSyncEvents(ctx, "", res.Presence.Events, event.SourcePresence)
|
||||
s.processSyncEvents(ctx, "", res.AccountData.Events, event.SourceAccountData)
|
||||
|
||||
for roomID, roomData := range res.Rooms.Join {
|
||||
s.processSyncEvents(roomID, roomData.State.Events, EventSourceJoin|EventSourceState)
|
||||
s.processSyncEvents(roomID, roomData.Timeline.Events, EventSourceJoin|EventSourceTimeline)
|
||||
s.processSyncEvents(roomID, roomData.Ephemeral.Events, EventSourceJoin|EventSourceEphemeral)
|
||||
s.processSyncEvents(roomID, roomData.AccountData.Events, EventSourceJoin|EventSourceAccountData)
|
||||
s.processSyncEvents(ctx, roomID, roomData.State.Events, event.SourceJoin|event.SourceState)
|
||||
s.processSyncEvents(ctx, roomID, roomData.Timeline.Events, event.SourceJoin|event.SourceTimeline)
|
||||
s.processSyncEvents(ctx, roomID, roomData.Ephemeral.Events, event.SourceJoin|event.SourceEphemeral)
|
||||
s.processSyncEvents(ctx, roomID, roomData.AccountData.Events, event.SourceJoin|event.SourceAccountData)
|
||||
}
|
||||
for roomID, roomData := range res.Rooms.Invite {
|
||||
s.processSyncEvents(roomID, roomData.State.Events, EventSourceInvite|EventSourceState)
|
||||
s.processSyncEvents(ctx, roomID, roomData.State.Events, event.SourceInvite|event.SourceState)
|
||||
}
|
||||
for roomID, roomData := range res.Rooms.Leave {
|
||||
s.processSyncEvents(roomID, roomData.State.Events, EventSourceLeave|EventSourceState)
|
||||
s.processSyncEvents(roomID, roomData.Timeline.Events, EventSourceLeave|EventSourceTimeline)
|
||||
s.processSyncEvents(ctx, roomID, roomData.State.Events, event.SourceLeave|event.SourceState)
|
||||
s.processSyncEvents(ctx, roomID, roomData.Timeline.Events, event.SourceLeave|event.SourceTimeline)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (s *DefaultSyncer) processSyncEvents(roomID id.RoomID, events []*event.Event, source EventSource) {
|
||||
func (s *DefaultSyncer) processSyncEvents(ctx context.Context, roomID id.RoomID, events []*event.Event, source event.Source) {
|
||||
for _, evt := range events {
|
||||
s.processSyncEvent(roomID, evt, source)
|
||||
s.processSyncEvent(ctx, roomID, evt, source)
|
||||
}
|
||||
}
|
||||
|
||||
func (s *DefaultSyncer) processSyncEvent(roomID id.RoomID, evt *event.Event, source EventSource) {
|
||||
func (s *DefaultSyncer) processSyncEvent(ctx context.Context, roomID id.RoomID, evt *event.Event, source event.Source) {
|
||||
evt.RoomID = roomID
|
||||
|
||||
// Ensure the type class is correct. It's safe to mutate the class since the event type is not a pointer.
|
||||
@@ -188,11 +131,11 @@ func (s *DefaultSyncer) processSyncEvent(roomID id.RoomID, evt *event.Event, sou
|
||||
switch {
|
||||
case evt.StateKey != nil:
|
||||
evt.Type.Class = event.StateEventType
|
||||
case source == EventSourcePresence, source&EventSourceEphemeral != 0:
|
||||
case source == event.SourcePresence, source&event.SourceEphemeral != 0:
|
||||
evt.Type.Class = event.EphemeralEventType
|
||||
case source&EventSourceAccountData != 0:
|
||||
case source&event.SourceAccountData != 0:
|
||||
evt.Type.Class = event.AccountDataEventType
|
||||
case source == EventSourceToDevice:
|
||||
case source == event.SourceToDevice:
|
||||
evt.Type.Class = event.ToDeviceEventType
|
||||
default:
|
||||
evt.Type.Class = event.MessageEventType
|
||||
@@ -205,17 +148,18 @@ func (s *DefaultSyncer) processSyncEvent(roomID id.RoomID, evt *event.Event, sou
|
||||
}
|
||||
}
|
||||
|
||||
s.Dispatch(source, evt)
|
||||
evt.Mautrix.EventSource = source
|
||||
s.Dispatch(ctx, evt)
|
||||
}
|
||||
|
||||
func (s *DefaultSyncer) Dispatch(source EventSource, evt *event.Event) {
|
||||
func (s *DefaultSyncer) Dispatch(ctx context.Context, evt *event.Event) {
|
||||
for _, fn := range s.globalListeners {
|
||||
fn(source, evt)
|
||||
fn(ctx, evt)
|
||||
}
|
||||
listeners, exists := s.listeners[evt.Type]
|
||||
if exists {
|
||||
for _, fn := range listeners {
|
||||
fn(source, evt)
|
||||
fn(ctx, evt)
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -263,31 +207,18 @@ func (s *DefaultSyncer) GetFilterJSON(userID id.UserID) *Filter {
|
||||
return s.FilterJSON
|
||||
}
|
||||
|
||||
// OldEventIgnorer is a utility struct for bots to ignore events from before the bot joined the room.
|
||||
//
|
||||
// Deprecated: Use Client.DontProcessOldEvents instead.
|
||||
type OldEventIgnorer struct {
|
||||
UserID id.UserID
|
||||
}
|
||||
|
||||
func (oei *OldEventIgnorer) Register(syncer ExtensibleSyncer) {
|
||||
syncer.OnSync(oei.DontProcessOldEvents)
|
||||
}
|
||||
|
||||
func (oei *OldEventIgnorer) DontProcessOldEvents(resp *RespSync, since string) bool {
|
||||
return dontProcessOldEvents(oei.UserID, resp, since)
|
||||
}
|
||||
|
||||
// DontProcessOldEvents is a sync handler that removes rooms that the user just joined.
|
||||
// It's meant for bots to ignore events from before the bot joined the room.
|
||||
//
|
||||
// To use it, register it with your Syncer, e.g.:
|
||||
//
|
||||
// cli.Syncer.(mautrix.ExtensibleSyncer).OnSync(cli.DontProcessOldEvents)
|
||||
func (cli *Client) DontProcessOldEvents(resp *RespSync, since string) bool {
|
||||
func (cli *Client) DontProcessOldEvents(_ context.Context, resp *RespSync, since string) bool {
|
||||
return dontProcessOldEvents(cli.UserID, resp, since)
|
||||
}
|
||||
|
||||
var _ SyncHandler = (*Client)(nil).DontProcessOldEvents
|
||||
|
||||
func dontProcessOldEvents(userID id.UserID, resp *RespSync, since string) bool {
|
||||
if since == "" {
|
||||
return false
|
||||
@@ -324,7 +255,7 @@ func dontProcessOldEvents(userID id.UserID, resp *RespSync, since string) bool {
|
||||
// To use it, register it with your Syncer, e.g.:
|
||||
//
|
||||
// cli.Syncer.(mautrix.ExtensibleSyncer).OnSync(cli.MoveInviteState)
|
||||
func (cli *Client) MoveInviteState(resp *RespSync, _ string) bool {
|
||||
func (cli *Client) MoveInviteState(ctx context.Context, resp *RespSync, _ string) bool {
|
||||
for _, meta := range resp.Rooms.Invite {
|
||||
var inviteState []event.StrippedState
|
||||
var inviteEvt *event.Event
|
||||
@@ -349,3 +280,5 @@ func (cli *Client) MoveInviteState(resp *RespSync, _ string) bool {
|
||||
}
|
||||
return true
|
||||
}
|
||||
|
||||
var _ SyncHandler = (*Client)(nil).MoveInviteState
|
||||
|
||||
53
vendor/maunium.net/go/mautrix/syncstore.go
generated
vendored
53
vendor/maunium.net/go/mautrix/syncstore.go
generated
vendored
@@ -1,21 +1,26 @@
|
||||
package mautrix
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"maunium.net/go/mautrix/id"
|
||||
)
|
||||
|
||||
var _ SyncStore = (*MemorySyncStore)(nil)
|
||||
var _ SyncStore = (*AccountDataStore)(nil)
|
||||
|
||||
// SyncStore is an interface which must be satisfied to store client data.
|
||||
//
|
||||
// You can either write a struct which persists this data to disk, or you can use the
|
||||
// provided "MemorySyncStore" which just keeps data around in-memory which is lost on
|
||||
// restarts.
|
||||
type SyncStore interface {
|
||||
SaveFilterID(userID id.UserID, filterID string)
|
||||
LoadFilterID(userID id.UserID) string
|
||||
SaveNextBatch(userID id.UserID, nextBatchToken string)
|
||||
LoadNextBatch(userID id.UserID) string
|
||||
SaveFilterID(ctx context.Context, userID id.UserID, filterID string) error
|
||||
LoadFilterID(ctx context.Context, userID id.UserID) (string, error)
|
||||
SaveNextBatch(ctx context.Context, userID id.UserID, nextBatchToken string) error
|
||||
LoadNextBatch(ctx context.Context, userID id.UserID) (string, error)
|
||||
}
|
||||
|
||||
// Deprecated: renamed to SyncStore
|
||||
@@ -32,23 +37,25 @@ type MemorySyncStore struct {
|
||||
}
|
||||
|
||||
// SaveFilterID to memory.
|
||||
func (s *MemorySyncStore) SaveFilterID(userID id.UserID, filterID string) {
|
||||
func (s *MemorySyncStore) SaveFilterID(ctx context.Context, userID id.UserID, filterID string) error {
|
||||
s.Filters[userID] = filterID
|
||||
return nil
|
||||
}
|
||||
|
||||
// LoadFilterID from memory.
|
||||
func (s *MemorySyncStore) LoadFilterID(userID id.UserID) string {
|
||||
return s.Filters[userID]
|
||||
func (s *MemorySyncStore) LoadFilterID(ctx context.Context, userID id.UserID) (string, error) {
|
||||
return s.Filters[userID], nil
|
||||
}
|
||||
|
||||
// SaveNextBatch to memory.
|
||||
func (s *MemorySyncStore) SaveNextBatch(userID id.UserID, nextBatchToken string) {
|
||||
func (s *MemorySyncStore) SaveNextBatch(ctx context.Context, userID id.UserID, nextBatchToken string) error {
|
||||
s.NextBatch[userID] = nextBatchToken
|
||||
return nil
|
||||
}
|
||||
|
||||
// LoadNextBatch from memory.
|
||||
func (s *MemorySyncStore) LoadNextBatch(userID id.UserID) string {
|
||||
return s.NextBatch[userID]
|
||||
func (s *MemorySyncStore) LoadNextBatch(ctx context.Context, userID id.UserID) (string, error) {
|
||||
return s.NextBatch[userID], nil
|
||||
}
|
||||
|
||||
// NewMemorySyncStore constructs a new MemorySyncStore.
|
||||
@@ -72,34 +79,35 @@ type accountData struct {
|
||||
NextBatch string `json:"next_batch"`
|
||||
}
|
||||
|
||||
func (s *AccountDataStore) SaveFilterID(userID id.UserID, filterID string) {
|
||||
func (s *AccountDataStore) SaveFilterID(ctx context.Context, userID id.UserID, filterID string) error {
|
||||
if userID.String() != s.client.UserID.String() {
|
||||
panic("AccountDataStore must only be used with a single account")
|
||||
}
|
||||
s.FilterID = filterID
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *AccountDataStore) LoadFilterID(userID id.UserID) string {
|
||||
func (s *AccountDataStore) LoadFilterID(ctx context.Context, userID id.UserID) (string, error) {
|
||||
if userID.String() != s.client.UserID.String() {
|
||||
panic("AccountDataStore must only be used with a single account")
|
||||
}
|
||||
return s.FilterID
|
||||
return s.FilterID, nil
|
||||
}
|
||||
|
||||
func (s *AccountDataStore) SaveNextBatch(userID id.UserID, nextBatchToken string) {
|
||||
func (s *AccountDataStore) SaveNextBatch(ctx context.Context, userID id.UserID, nextBatchToken string) error {
|
||||
if userID.String() != s.client.UserID.String() {
|
||||
panic("AccountDataStore must only be used with a single account")
|
||||
} else if nextBatchToken == s.nextBatch {
|
||||
return
|
||||
return nil
|
||||
}
|
||||
|
||||
data := accountData{
|
||||
NextBatch: nextBatchToken,
|
||||
}
|
||||
|
||||
err := s.client.SetAccountData(s.EventType, data)
|
||||
err := s.client.SetAccountData(ctx, s.EventType, data)
|
||||
if err != nil {
|
||||
s.client.Log.Warn().Err(err).Msg("Failed to save next batch token to account data")
|
||||
return fmt.Errorf("failed to save next batch token to account data: %w", err)
|
||||
} else {
|
||||
s.client.Log.Debug().
|
||||
Str("old_token", s.nextBatch).
|
||||
@@ -107,28 +115,29 @@ func (s *AccountDataStore) SaveNextBatch(userID id.UserID, nextBatchToken string
|
||||
Msg("Saved next batch token")
|
||||
s.nextBatch = nextBatchToken
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (s *AccountDataStore) LoadNextBatch(userID id.UserID) string {
|
||||
func (s *AccountDataStore) LoadNextBatch(ctx context.Context, userID id.UserID) (string, error) {
|
||||
if userID.String() != s.client.UserID.String() {
|
||||
panic("AccountDataStore must only be used with a single account")
|
||||
}
|
||||
|
||||
data := &accountData{}
|
||||
|
||||
err := s.client.GetAccountData(s.EventType, data)
|
||||
err := s.client.GetAccountData(ctx, s.EventType, data)
|
||||
if err != nil {
|
||||
if errors.Is(err, MNotFound) {
|
||||
s.client.Log.Debug().Msg("No next batch token found in account data")
|
||||
return "", nil
|
||||
} else {
|
||||
s.client.Log.Warn().Err(err).Msg("Failed to load next batch token from account data")
|
||||
return "", fmt.Errorf("failed to load next batch token from account data: %w", err)
|
||||
}
|
||||
return ""
|
||||
}
|
||||
s.nextBatch = data.NextBatch
|
||||
s.client.Log.Debug().Str("next_batch", data.NextBatch).Msg("Loaded next batch token from account data")
|
||||
|
||||
return s.nextBatch
|
||||
return s.nextBatch, nil
|
||||
}
|
||||
|
||||
// NewAccountDataStore returns a new AccountDataStore, which stores
|
||||
|
||||
2
vendor/maunium.net/go/mautrix/version.go
generated
vendored
2
vendor/maunium.net/go/mautrix/version.go
generated
vendored
@@ -7,7 +7,7 @@ import (
|
||||
"strings"
|
||||
)
|
||||
|
||||
const Version = "v0.16.2"
|
||||
const Version = "v0.17.0"
|
||||
|
||||
var GoModVersion = ""
|
||||
var Commit = ""
|
||||
|
||||
2
vendor/maunium.net/go/mautrix/versions.go
generated
vendored
2
vendor/maunium.net/go/mautrix/versions.go
generated
vendored
@@ -93,6 +93,8 @@ var (
|
||||
SpecV15 = MustParseSpecVersion("v1.5")
|
||||
SpecV16 = MustParseSpecVersion("v1.6")
|
||||
SpecV17 = MustParseSpecVersion("v1.7")
|
||||
SpecV18 = MustParseSpecVersion("v1.8")
|
||||
SpecV19 = MustParseSpecVersion("v1.9")
|
||||
)
|
||||
|
||||
func (svf SpecVersionFormat) String() string {
|
||||
|
||||
Reference in New Issue
Block a user