refactor to mautrix 0.17.x; update deps

This commit is contained in:
Aine
2024-02-11 20:47:04 +02:00
parent 0a9701f4c9
commit dd0ad4c245
237 changed files with 9091 additions and 3317 deletions

5
vendor/maunium.net/go/mautrix/crypto/goolm/README.md generated vendored Normal file
View File

@@ -0,0 +1,5 @@
# go-olm
This is a fork of [DerLukas's goolm](https://codeberg.org/DerLukas/goolm),
a pure Go implementation of libolm.
The original project is licensed under the MIT license.

View File

@@ -0,0 +1,522 @@
// account packages an account which stores the identity, one time keys and fallback keys.
package account
import (
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"io"
"maunium.net/go/mautrix/id"
"maunium.net/go/mautrix/crypto/goolm"
"maunium.net/go/mautrix/crypto/goolm/cipher"
"maunium.net/go/mautrix/crypto/goolm/crypto"
"maunium.net/go/mautrix/crypto/goolm/libolmpickle"
"maunium.net/go/mautrix/crypto/goolm/session"
"maunium.net/go/mautrix/crypto/goolm/utilities"
)
const (
accountPickleVersionJSON byte = 1
accountPickleVersionLibOLM uint32 = 4
MaxOneTimeKeys int = 100 //maximum number of stored one time keys per Account
)
// Account stores an account for end to end encrypted messaging via the olm protocol.
// An Account can not be used to en/decrypt messages. However it can be used to contruct new olm sessions, which in turn do the en/decryption.
// There is no tracking of sessions in an account.
type Account struct {
IdKeys struct {
Ed25519 crypto.Ed25519KeyPair `json:"ed25519,omitempty"`
Curve25519 crypto.Curve25519KeyPair `json:"curve25519,omitempty"`
} `json:"identity_keys"`
OTKeys []crypto.OneTimeKey `json:"one_time_keys"`
CurrentFallbackKey crypto.OneTimeKey `json:"current_fallback_key,omitempty"`
PrevFallbackKey crypto.OneTimeKey `json:"prev_fallback_key,omitempty"`
NextOneTimeKeyID uint32 `json:"next_one_time_key_id,omitempty"`
NumFallbackKeys uint8 `json:"number_fallback_keys"`
}
// AccountFromJSONPickled loads the Account details from a pickled base64 string. The input is decrypted with the supplied key.
func AccountFromJSONPickled(pickled, key []byte) (*Account, error) {
if len(pickled) == 0 {
return nil, fmt.Errorf("accountFromPickled: %w", goolm.ErrEmptyInput)
}
a := &Account{}
err := a.UnpickleAsJSON(pickled, key)
if err != nil {
return nil, err
}
return a, nil
}
// AccountFromPickled loads the Account details from a pickled base64 string. The input is decrypted with the supplied key.
func AccountFromPickled(pickled, key []byte) (*Account, error) {
if len(pickled) == 0 {
return nil, fmt.Errorf("accountFromPickled: %w", goolm.ErrEmptyInput)
}
a := &Account{}
err := a.Unpickle(pickled, key)
if err != nil {
return nil, err
}
return a, nil
}
// NewAccount creates a new Account. If reader is nil, crypto/rand is used for the key creation.
func NewAccount(reader io.Reader) (*Account, error) {
a := &Account{}
kPEd25519, err := crypto.Ed25519GenerateKey(reader)
if err != nil {
return nil, err
}
a.IdKeys.Ed25519 = kPEd25519
kPCurve25519, err := crypto.Curve25519GenerateKey(reader)
if err != nil {
return nil, err
}
a.IdKeys.Curve25519 = kPCurve25519
return a, nil
}
// PickleAsJSON returns an Account as a base64 string encrypted using the supplied key. The unencrypted representation of the Account is in JSON format.
func (a Account) PickleAsJSON(key []byte) ([]byte, error) {
return utilities.PickleAsJSON(a, accountPickleVersionJSON, key)
}
// UnpickleAsJSON updates an Account by a base64 encrypted string using the supplied key. The unencrypted representation has to be in JSON format.
func (a *Account) UnpickleAsJSON(pickled, key []byte) error {
return utilities.UnpickleAsJSON(a, pickled, key, accountPickleVersionJSON)
}
// IdentityKeysJSON returns the public parts of the identity keys for the Account in a JSON string.
func (a Account) IdentityKeysJSON() ([]byte, error) {
res := struct {
Ed25519 string `json:"ed25519"`
Curve25519 string `json:"curve25519"`
}{}
ed25519, curve25519 := a.IdentityKeys()
res.Ed25519 = string(ed25519)
res.Curve25519 = string(curve25519)
return json.Marshal(res)
}
// IdentityKeys returns the public parts of the Ed25519 and Curve25519 identity keys for the Account.
func (a Account) IdentityKeys() (id.Ed25519, id.Curve25519) {
ed25519 := id.Ed25519(base64.RawStdEncoding.EncodeToString(a.IdKeys.Ed25519.PublicKey))
curve25519 := id.Curve25519(base64.RawStdEncoding.EncodeToString(a.IdKeys.Curve25519.PublicKey))
return ed25519, curve25519
}
// Sign returns the signature of a message using the Ed25519 key for this Account.
func (a Account) Sign(message []byte) ([]byte, error) {
if len(message) == 0 {
return nil, fmt.Errorf("sign: %w", goolm.ErrEmptyInput)
}
return goolm.Base64Encode(a.IdKeys.Ed25519.Sign(message)), nil
}
// OneTimeKeys returns the public parts of the unpublished one time keys of the Account.
//
// The returned data is a map with the mapping of key id to base64-encoded Curve25519 key.
func (a Account) OneTimeKeys() map[string]id.Curve25519 {
oneTimeKeys := make(map[string]id.Curve25519)
for _, curKey := range a.OTKeys {
if !curKey.Published {
oneTimeKeys[curKey.KeyIDEncoded()] = id.Curve25519(curKey.PublicKeyEncoded())
}
}
return oneTimeKeys
}
//OneTimeKeysJSON returns the public parts of the unpublished one time keys of the Account as a JSON string.
//
//The returned JSON is of format:
/*
{
Curve25519: {
"AAAAAA": "wo76WcYtb0Vk/pBOdmduiGJ0wIEjW4IBMbbQn7aSnTo",
"AAAAAB": "LRvjo46L1X2vx69sS9QNFD29HWulxrmW11Up5AfAjgU"
}
}
*/
func (a Account) OneTimeKeysJSON() ([]byte, error) {
res := make(map[string]map[string]id.Curve25519)
otKeys := a.OneTimeKeys()
res["Curve25519"] = otKeys
return json.Marshal(res)
}
// MarkKeysAsPublished marks the current set of one time keys and the fallback key as being
// published.
func (a *Account) MarkKeysAsPublished() {
for keyIndex := range a.OTKeys {
if !a.OTKeys[keyIndex].Published {
a.OTKeys[keyIndex].Published = true
}
}
a.CurrentFallbackKey.Published = true
}
// GenOneTimeKeys generates a number of new one time keys. If the total number
// of keys stored by this Account exceeds MaxOneTimeKeys then the older
// keys are discarded. If reader is nil, crypto/rand is used for the key creation.
func (a *Account) GenOneTimeKeys(reader io.Reader, num uint) error {
for i := uint(0); i < num; i++ {
key := crypto.OneTimeKey{
Published: false,
ID: a.NextOneTimeKeyID,
}
newKP, err := crypto.Curve25519GenerateKey(reader)
if err != nil {
return err
}
key.Key = newKP
a.NextOneTimeKeyID++
a.OTKeys = append([]crypto.OneTimeKey{key}, a.OTKeys...)
}
if len(a.OTKeys) > MaxOneTimeKeys {
a.OTKeys = a.OTKeys[:MaxOneTimeKeys]
}
return nil
}
// NewOutboundSession creates a new outbound session to a
// given curve25519 identity Key and one time key.
func (a Account) NewOutboundSession(theirIdentityKey, theirOneTimeKey id.Curve25519) (*session.OlmSession, error) {
if len(theirIdentityKey) == 0 || len(theirOneTimeKey) == 0 {
return nil, fmt.Errorf("outbound session: %w", goolm.ErrEmptyInput)
}
theirIdentityKeyDecoded, err := base64.RawStdEncoding.DecodeString(string(theirIdentityKey))
if err != nil {
return nil, err
}
theirOneTimeKeyDecoded, err := base64.RawStdEncoding.DecodeString(string(theirOneTimeKey))
if err != nil {
return nil, err
}
s, err := session.NewOutboundOlmSession(a.IdKeys.Curve25519, theirIdentityKeyDecoded, theirOneTimeKeyDecoded)
if err != nil {
return nil, err
}
return s, nil
}
// NewInboundSession creates a new inbound session from an incoming PRE_KEY message.
func (a Account) NewInboundSession(theirIdentityKey *id.Curve25519, oneTimeKeyMsg []byte) (*session.OlmSession, error) {
if len(oneTimeKeyMsg) == 0 {
return nil, fmt.Errorf("inbound session: %w", goolm.ErrEmptyInput)
}
var theirIdentityKeyDecoded *crypto.Curve25519PublicKey
var err error
if theirIdentityKey != nil {
theirIdentityKeyDecodedByte, err := base64.RawStdEncoding.DecodeString(string(*theirIdentityKey))
if err != nil {
return nil, err
}
theirIdentityKeyCurve := crypto.Curve25519PublicKey(theirIdentityKeyDecodedByte)
theirIdentityKeyDecoded = &theirIdentityKeyCurve
}
s, err := session.NewInboundOlmSession(theirIdentityKeyDecoded, oneTimeKeyMsg, a.searchOTKForOur, a.IdKeys.Curve25519)
if err != nil {
return nil, err
}
return s, nil
}
func (a Account) searchOTKForOur(toFind crypto.Curve25519PublicKey) *crypto.OneTimeKey {
for curIndex := range a.OTKeys {
if a.OTKeys[curIndex].Key.PublicKey.Equal(toFind) {
return &a.OTKeys[curIndex]
}
}
if a.NumFallbackKeys >= 1 && a.CurrentFallbackKey.Key.PublicKey.Equal(toFind) {
return &a.CurrentFallbackKey
}
if a.NumFallbackKeys >= 2 && a.PrevFallbackKey.Key.PublicKey.Equal(toFind) {
return &a.PrevFallbackKey
}
return nil
}
// RemoveOneTimeKeys removes the one time key in this Account which matches the one time key in the session s.
func (a *Account) RemoveOneTimeKeys(s *session.OlmSession) {
toFind := s.BobOneTimeKey
for curIndex := range a.OTKeys {
if a.OTKeys[curIndex].Key.PublicKey.Equal(toFind) {
//Remove and return
a.OTKeys[curIndex] = a.OTKeys[len(a.OTKeys)-1]
a.OTKeys = a.OTKeys[:len(a.OTKeys)-1]
return
}
}
//if the key is a fallback or prevFallback, don't remove it
}
// GenFallbackKey generates a new fallback key. The old fallback key is stored in a.PrevFallbackKey overwriting any previous PrevFallbackKey. If reader is nil, crypto/rand is used for the key creation.
func (a *Account) GenFallbackKey(reader io.Reader) error {
a.PrevFallbackKey = a.CurrentFallbackKey
key := crypto.OneTimeKey{
Published: false,
ID: a.NextOneTimeKeyID,
}
newKP, err := crypto.Curve25519GenerateKey(reader)
if err != nil {
return err
}
key.Key = newKP
a.NextOneTimeKeyID++
if a.NumFallbackKeys < 2 {
a.NumFallbackKeys++
}
a.CurrentFallbackKey = key
return nil
}
// FallbackKey returns the public part of the current fallback key of the Account.
// The returned data is a map with the mapping of key id to base64-encoded Curve25519 key.
func (a Account) FallbackKey() map[string]id.Curve25519 {
keys := make(map[string]id.Curve25519)
if a.NumFallbackKeys >= 1 {
keys[a.CurrentFallbackKey.KeyIDEncoded()] = id.Curve25519(a.CurrentFallbackKey.PublicKeyEncoded())
}
return keys
}
//FallbackKeyJSON returns the public part of the current fallback key of the Account as a JSON string.
//
//The returned JSON is of format:
/*
{
curve25519: {
"AAAAAA": "wo76WcYtb0Vk/pBOdmduiGJ0wIEjW4IBMbbQn7aSnTo"
}
}
*/
func (a Account) FallbackKeyJSON() ([]byte, error) {
res := make(map[string]map[string]id.Curve25519)
fbk := a.FallbackKey()
res["curve25519"] = fbk
return json.Marshal(res)
}
// FallbackKeyUnpublished returns the public part of the current fallback key of the Account only if it is unpublished.
// The returned data is a map with the mapping of key id to base64-encoded Curve25519 key.
func (a Account) FallbackKeyUnpublished() map[string]id.Curve25519 {
keys := make(map[string]id.Curve25519)
if a.NumFallbackKeys >= 1 && !a.CurrentFallbackKey.Published {
keys[a.CurrentFallbackKey.KeyIDEncoded()] = id.Curve25519(a.CurrentFallbackKey.PublicKeyEncoded())
}
return keys
}
//FallbackKeyUnpublishedJSON returns the public part of the current fallback key, only if it is unpublished, of the Account as a JSON string.
//
//The returned JSON is of format:
/*
{
curve25519: {
"AAAAAA": "wo76WcYtb0Vk/pBOdmduiGJ0wIEjW4IBMbbQn7aSnTo"
}
}
*/
func (a Account) FallbackKeyUnpublishedJSON() ([]byte, error) {
res := make(map[string]map[string]id.Curve25519)
fbk := a.FallbackKeyUnpublished()
res["curve25519"] = fbk
return json.Marshal(res)
}
// ForgetOldFallbackKey resets the previous fallback key in the account.
func (a *Account) ForgetOldFallbackKey() {
if a.NumFallbackKeys >= 2 {
a.NumFallbackKeys = 1
a.PrevFallbackKey = crypto.OneTimeKey{}
}
}
// Unpickle decodes the base64 encoded string and decrypts the result with the key.
// The decrypted value is then passed to UnpickleLibOlm.
func (a *Account) Unpickle(pickled, key []byte) error {
decrypted, err := cipher.Unpickle(key, pickled)
if err != nil {
return err
}
_, err = a.UnpickleLibOlm(decrypted)
return err
}
// UnpickleLibOlm decodes the unencryted value and populates the Account accordingly. It returns the number of bytes read.
func (a *Account) UnpickleLibOlm(value []byte) (int, error) {
//First 4 bytes are the accountPickleVersion
pickledVersion, curPos, err := libolmpickle.UnpickleUInt32(value)
if err != nil {
return 0, err
}
switch pickledVersion {
case accountPickleVersionLibOLM, 3, 2:
default:
return 0, fmt.Errorf("unpickle account: %w", goolm.ErrBadVersion)
}
//read ed25519 key pair
readBytes, err := a.IdKeys.Ed25519.UnpickleLibOlm(value[curPos:])
if err != nil {
return 0, err
}
curPos += readBytes
//read curve25519 key pair
readBytes, err = a.IdKeys.Curve25519.UnpickleLibOlm(value[curPos:])
if err != nil {
return 0, err
}
curPos += readBytes
//Read number of onetimeKeys
numberOTKeys, readBytes, err := libolmpickle.UnpickleUInt32(value[curPos:])
if err != nil {
return 0, err
}
curPos += readBytes
//Read i one time keys
a.OTKeys = make([]crypto.OneTimeKey, numberOTKeys)
for i := uint32(0); i < numberOTKeys; i++ {
readBytes, err := a.OTKeys[i].UnpickleLibOlm(value[curPos:])
if err != nil {
return 0, err
}
curPos += readBytes
}
if pickledVersion <= 2 {
// version 2 did not have fallback keys
a.NumFallbackKeys = 0
} else if pickledVersion == 3 {
// version 3 used the published flag to indicate how many fallback keys
// were present (we'll have to assume that the keys were published)
readBytes, err := a.CurrentFallbackKey.UnpickleLibOlm(value[curPos:])
if err != nil {
return 0, err
}
curPos += readBytes
readBytes, err = a.PrevFallbackKey.UnpickleLibOlm(value[curPos:])
if err != nil {
return 0, err
}
curPos += readBytes
if a.CurrentFallbackKey.Published {
if a.PrevFallbackKey.Published {
a.NumFallbackKeys = 2
} else {
a.NumFallbackKeys = 1
}
} else {
a.NumFallbackKeys = 0
}
} else {
//Read number of fallback keys
numFallbackKeys, readBytes, err := libolmpickle.UnpickleUInt8(value[curPos:])
if err != nil {
return 0, err
}
curPos += readBytes
a.NumFallbackKeys = numFallbackKeys
if a.NumFallbackKeys >= 1 {
readBytes, err := a.CurrentFallbackKey.UnpickleLibOlm(value[curPos:])
if err != nil {
return 0, err
}
curPos += readBytes
if a.NumFallbackKeys >= 2 {
readBytes, err := a.PrevFallbackKey.UnpickleLibOlm(value[curPos:])
if err != nil {
return 0, err
}
curPos += readBytes
}
}
}
//Read next onetime key id
nextOTKeyID, readBytes, err := libolmpickle.UnpickleUInt32(value[curPos:])
if err != nil {
return 0, err
}
curPos += readBytes
a.NextOneTimeKeyID = nextOTKeyID
return curPos, nil
}
// Pickle returns a base64 encoded and with key encrypted pickled account using PickleLibOlm().
func (a Account) Pickle(key []byte) ([]byte, error) {
pickeledBytes := make([]byte, a.PickleLen())
written, err := a.PickleLibOlm(pickeledBytes)
if err != nil {
return nil, err
}
if written != len(pickeledBytes) {
return nil, errors.New("number of written bytes not correct")
}
encrypted, err := cipher.Pickle(key, pickeledBytes)
if err != nil {
return nil, err
}
return encrypted, nil
}
// PickleLibOlm encodes the Account into target. target has to have a size of at least PickleLen() and is written to from index 0.
// It returns the number of bytes written.
func (a Account) PickleLibOlm(target []byte) (int, error) {
if len(target) < a.PickleLen() {
return 0, fmt.Errorf("pickle account: %w", goolm.ErrValueTooShort)
}
written := libolmpickle.PickleUInt32(accountPickleVersionLibOLM, target)
writtenEdKey, err := a.IdKeys.Ed25519.PickleLibOlm(target[written:])
if err != nil {
return 0, fmt.Errorf("pickle account: %w", err)
}
written += writtenEdKey
writtenCurveKey, err := a.IdKeys.Curve25519.PickleLibOlm(target[written:])
if err != nil {
return 0, fmt.Errorf("pickle account: %w", err)
}
written += writtenCurveKey
written += libolmpickle.PickleUInt32(uint32(len(a.OTKeys)), target[written:])
for _, curOTKey := range a.OTKeys {
writtenOT, err := curOTKey.PickleLibOlm(target[written:])
if err != nil {
return 0, fmt.Errorf("pickle account: %w", err)
}
written += writtenOT
}
written += libolmpickle.PickleUInt8(a.NumFallbackKeys, target[written:])
if a.NumFallbackKeys >= 1 {
writtenOT, err := a.CurrentFallbackKey.PickleLibOlm(target[written:])
if err != nil {
return 0, fmt.Errorf("pickle account: %w", err)
}
written += writtenOT
if a.NumFallbackKeys >= 2 {
writtenOT, err := a.PrevFallbackKey.PickleLibOlm(target[written:])
if err != nil {
return 0, fmt.Errorf("pickle account: %w", err)
}
written += writtenOT
}
}
written += libolmpickle.PickleUInt32(a.NextOneTimeKeyID, target[written:])
return written, nil
}
// PickleLen returns the number of bytes the pickled Account will have.
func (a Account) PickleLen() int {
length := libolmpickle.PickleUInt32Len(accountPickleVersionLibOLM)
length += a.IdKeys.Ed25519.PickleLen()
length += a.IdKeys.Curve25519.PickleLen()
length += libolmpickle.PickleUInt32Len(uint32(len(a.OTKeys)))
length += (len(a.OTKeys) * (&crypto.OneTimeKey{}).PickleLen())
length += libolmpickle.PickleUInt8Len(a.NumFallbackKeys)
length += (int(a.NumFallbackKeys) * (&crypto.OneTimeKey{}).PickleLen())
length += libolmpickle.PickleUInt32Len(a.NextOneTimeKeyID)
return length
}

22
vendor/maunium.net/go/mautrix/crypto/goolm/base64.go generated vendored Normal file
View File

@@ -0,0 +1,22 @@
package goolm
import (
"encoding/base64"
)
// Deprecated: base64.RawStdEncoding should be used directly
func Base64Decode(input []byte) ([]byte, error) {
decoded := make([]byte, base64.RawStdEncoding.DecodedLen(len(input)))
writtenBytes, err := base64.RawStdEncoding.Decode(decoded, input)
if err != nil {
return nil, err
}
return decoded[:writtenBytes], nil
}
// Deprecated: base64.RawStdEncoding should be used directly
func Base64Encode(input []byte) []byte {
encoded := make([]byte, base64.RawStdEncoding.EncodedLen(len(input)))
base64.RawStdEncoding.Encode(encoded, input)
return encoded
}

View File

@@ -0,0 +1,96 @@
package cipher
import (
"bytes"
"io"
"maunium.net/go/mautrix/crypto/goolm/crypto"
)
// derivedAESKeys stores the derived keys for the AESSHA256 cipher
type derivedAESKeys struct {
key []byte
hmacKey []byte
iv []byte
}
// deriveAESKeys derives three keys for the AESSHA256 cipher
func deriveAESKeys(kdfInfo []byte, key []byte) (*derivedAESKeys, error) {
hkdf := crypto.HKDFSHA256(key, nil, kdfInfo)
keys := &derivedAESKeys{
key: make([]byte, 32),
hmacKey: make([]byte, 32),
iv: make([]byte, 16),
}
if _, err := io.ReadFull(hkdf, keys.key); err != nil {
return nil, err
}
if _, err := io.ReadFull(hkdf, keys.hmacKey); err != nil {
return nil, err
}
if _, err := io.ReadFull(hkdf, keys.iv); err != nil {
return nil, err
}
return keys, nil
}
// AESSha512BlockSize resturns the blocksize of the cipher AESSHA256.
func AESSha512BlockSize() int {
return crypto.AESCBCBlocksize()
}
// AESSHA256 is a valid cipher using AES with CBC and HKDFSha256.
type AESSHA256 struct {
kdfInfo []byte
}
// NewAESSHA256 returns a new AESSHA256 cipher with the key derive function info (kdfInfo).
func NewAESSHA256(kdfInfo []byte) *AESSHA256 {
return &AESSHA256{
kdfInfo: kdfInfo,
}
}
// Encrypt encrypts the plaintext with the key. The key is used to derive the actual encryption key (32 bytes) as well as the iv (16 bytes).
func (c AESSHA256) Encrypt(key, plaintext []byte) (ciphertext []byte, err error) {
keys, err := deriveAESKeys(c.kdfInfo, key)
if err != nil {
return nil, err
}
ciphertext, err = crypto.AESCBCEncrypt(keys.key, keys.iv, plaintext)
if err != nil {
return nil, err
}
return ciphertext, nil
}
// Decrypt decrypts the ciphertext with the key. The key is used to derive the actual encryption key (32 bytes) as well as the iv (16 bytes).
func (c AESSHA256) Decrypt(key, ciphertext []byte) (plaintext []byte, err error) {
keys, err := deriveAESKeys(c.kdfInfo, key)
if err != nil {
return nil, err
}
plaintext, err = crypto.AESCBCDecrypt(keys.key, keys.iv, ciphertext)
if err != nil {
return nil, err
}
return plaintext, nil
}
// MAC returns the MAC for the message using the key. The key is used to derive the actual mac key (32 bytes).
func (c AESSHA256) MAC(key, message []byte) ([]byte, error) {
keys, err := deriveAESKeys(c.kdfInfo, key)
if err != nil {
return nil, err
}
return crypto.HMACSHA256(keys.hmacKey, message), nil
}
// Verify checks the MAC of the message using the key against the givenMAC. The key is used to derive the actual mac key (32 bytes).
func (c AESSHA256) Verify(key, message, givenMAC []byte) (bool, error) {
mac, err := c.MAC(key, message)
if err != nil {
return false, err
}
return bytes.Equal(givenMAC, mac[:len(givenMAC)]), nil
}

View File

@@ -0,0 +1,17 @@
// cipher provides the methods and structs to do encryptions for olm/megolm.
package cipher
// Cipher defines a valid cipher.
type Cipher interface {
// Encrypt encrypts the plaintext.
Encrypt(key, plaintext []byte) (ciphertext []byte, err error)
// Decrypt decrypts the ciphertext.
Decrypt(key, ciphertext []byte) (plaintext []byte, err error)
//MAC returns the MAC of the message calculated with the key.
MAC(key, message []byte) ([]byte, error)
//Verify checks the MAC of the message calculated with the key against the givenMAC.
Verify(key, message, givenMAC []byte) (bool, error)
}

View File

@@ -0,0 +1,58 @@
package cipher
import (
"fmt"
"maunium.net/go/mautrix/crypto/goolm"
)
const (
kdfPickle = "Pickle" //used to derive the keys for encryption
pickleMACLength = 8
)
// PickleBlockSize returns the blocksize of the used cipher.
func PickleBlockSize() int {
return AESSha512BlockSize()
}
// Pickle encrypts the input with the key and the cipher AESSHA256. The result is then encoded in base64.
func Pickle(key, input []byte) ([]byte, error) {
pickleCipher := NewAESSHA256([]byte(kdfPickle))
ciphertext, err := pickleCipher.Encrypt(key, input)
if err != nil {
return nil, err
}
mac, err := pickleCipher.MAC(key, ciphertext)
if err != nil {
return nil, err
}
ciphertext = append(ciphertext, mac[:pickleMACLength]...)
encoded := goolm.Base64Encode(ciphertext)
return encoded, nil
}
// Unpickle decodes the input from base64 and decrypts the decoded input with the key and the cipher AESSHA256.
func Unpickle(key, input []byte) ([]byte, error) {
pickleCipher := NewAESSHA256([]byte(kdfPickle))
ciphertext, err := goolm.Base64Decode(input)
if err != nil {
return nil, err
}
//remove mac and check
verified, err := pickleCipher.Verify(key, ciphertext[:len(ciphertext)-pickleMACLength], ciphertext[len(ciphertext)-pickleMACLength:])
if err != nil {
return nil, err
}
if !verified {
return nil, fmt.Errorf("decrypt pickle: %w", goolm.ErrBadMAC)
}
//Set to next block size
targetCipherText := make([]byte, int(len(ciphertext)/PickleBlockSize())*PickleBlockSize())
copy(targetCipherText, ciphertext)
plaintext, err := pickleCipher.Decrypt(key, targetCipherText)
if err != nil {
return nil, err
}
return plaintext, nil
}

View File

@@ -0,0 +1,75 @@
package crypto
import (
"bytes"
"crypto/aes"
"crypto/cipher"
"fmt"
"maunium.net/go/mautrix/crypto/goolm"
)
// AESCBCBlocksize returns the blocksize of the encryption method
func AESCBCBlocksize() int {
return aes.BlockSize
}
// AESCBCEncrypt encrypts the plaintext with the key and iv. len(iv) must be equal to the blocksize!
func AESCBCEncrypt(key, iv, plaintext []byte) ([]byte, error) {
if len(key) == 0 {
return nil, fmt.Errorf("AESCBCEncrypt: %w", goolm.ErrNoKeyProvided)
}
if len(iv) != AESCBCBlocksize() {
return nil, fmt.Errorf("iv: %w", goolm.ErrNotBlocksize)
}
var cipherText []byte
plaintext = pkcs5Padding(plaintext, AESCBCBlocksize())
if len(plaintext)%AESCBCBlocksize() != 0 {
return nil, fmt.Errorf("message: %w", goolm.ErrNotMultipleBlocksize)
}
block, err := aes.NewCipher(key)
if err != nil {
return nil, err
}
cipherText = make([]byte, len(plaintext))
cbc := cipher.NewCBCEncrypter(block, iv)
cbc.CryptBlocks(cipherText, plaintext)
return cipherText, nil
}
// AESCBCDecrypt decrypts the ciphertext with the key and iv. len(iv) must be equal to the blocksize!
func AESCBCDecrypt(key, iv, ciphertext []byte) ([]byte, error) {
if len(key) == 0 {
return nil, fmt.Errorf("AESCBCEncrypt: %w", goolm.ErrNoKeyProvided)
}
if len(iv) != AESCBCBlocksize() {
return nil, fmt.Errorf("iv: %w", goolm.ErrNotBlocksize)
}
var block cipher.Block
var err error
block, err = aes.NewCipher(key)
if err != nil {
return nil, err
}
if len(ciphertext) < AESCBCBlocksize() {
return nil, fmt.Errorf("ciphertext: %w", goolm.ErrNotMultipleBlocksize)
}
cbc := cipher.NewCBCDecrypter(block, iv)
cbc.CryptBlocks(ciphertext, ciphertext)
return pkcs5Unpadding(ciphertext), nil
}
// pkcs5Padding paddes the plaintext to be used in the AESCBC encryption.
func pkcs5Padding(plaintext []byte, blockSize int) []byte {
padding := (blockSize - len(plaintext)%blockSize)
padtext := bytes.Repeat([]byte{byte(padding)}, padding)
return append(plaintext, padtext...)
}
// pkcs5Unpadding undoes the padding to the plaintext after AESCBC decryption.
func pkcs5Unpadding(plaintext []byte) []byte {
length := len(plaintext)
unpadding := int(plaintext[length-1])
return plaintext[:(length - unpadding)]
}

View File

@@ -0,0 +1,186 @@
package crypto
import (
"bytes"
"crypto/rand"
"encoding/base64"
"fmt"
"io"
"golang.org/x/crypto/curve25519"
"maunium.net/go/mautrix/crypto/goolm"
"maunium.net/go/mautrix/crypto/goolm/libolmpickle"
"maunium.net/go/mautrix/id"
)
const (
Curve25519KeyLength = curve25519.ScalarSize //The length of the private key.
curve25519PubKeyLength = 32
)
// Curve25519GenerateKey creates a new curve25519 key pair. If reader is nil, the random data is taken from crypto/rand.
func Curve25519GenerateKey(reader io.Reader) (Curve25519KeyPair, error) {
privateKeyByte := make([]byte, Curve25519KeyLength)
if reader == nil {
_, err := rand.Read(privateKeyByte)
if err != nil {
return Curve25519KeyPair{}, err
}
} else {
_, err := reader.Read(privateKeyByte)
if err != nil {
return Curve25519KeyPair{}, err
}
}
privateKey := Curve25519PrivateKey(privateKeyByte)
publicKey, err := privateKey.PubKey()
if err != nil {
return Curve25519KeyPair{}, err
}
return Curve25519KeyPair{
PrivateKey: Curve25519PrivateKey(privateKey),
PublicKey: Curve25519PublicKey(publicKey),
}, nil
}
// Curve25519GenerateFromPrivate creates a new curve25519 key pair with the private key given.
func Curve25519GenerateFromPrivate(private Curve25519PrivateKey) (Curve25519KeyPair, error) {
publicKey, err := private.PubKey()
if err != nil {
return Curve25519KeyPair{}, err
}
return Curve25519KeyPair{
PrivateKey: private,
PublicKey: Curve25519PublicKey(publicKey),
}, nil
}
// Curve25519KeyPair stores both parts of a curve25519 key.
type Curve25519KeyPair struct {
PrivateKey Curve25519PrivateKey `json:"private,omitempty"`
PublicKey Curve25519PublicKey `json:"public,omitempty"`
}
// B64Encoded returns a base64 encoded string of the public key.
func (c Curve25519KeyPair) B64Encoded() id.Curve25519 {
return c.PublicKey.B64Encoded()
}
// SharedSecret returns the shared secret between the key pair and the given public key.
func (c Curve25519KeyPair) SharedSecret(pubKey Curve25519PublicKey) ([]byte, error) {
return c.PrivateKey.SharedSecret(pubKey)
}
// PickleLibOlm encodes the key pair into target. target has to have a size of at least PickleLen() and is written to from index 0.
// It returns the number of bytes written.
func (c Curve25519KeyPair) PickleLibOlm(target []byte) (int, error) {
if len(target) < c.PickleLen() {
return 0, fmt.Errorf("pickle curve25519 key pair: %w", goolm.ErrValueTooShort)
}
written, err := c.PublicKey.PickleLibOlm(target)
if err != nil {
return 0, fmt.Errorf("pickle curve25519 key pair: %w", err)
}
if len(c.PrivateKey) != Curve25519KeyLength {
written += libolmpickle.PickleBytes(make([]byte, Curve25519KeyLength), target[written:])
} else {
written += libolmpickle.PickleBytes(c.PrivateKey, target[written:])
}
return written, nil
}
// UnpickleLibOlm decodes the unencryted value and populates the key pair accordingly. It returns the number of bytes read.
func (c *Curve25519KeyPair) UnpickleLibOlm(value []byte) (int, error) {
//unpickle PubKey
read, err := c.PublicKey.UnpickleLibOlm(value)
if err != nil {
return 0, err
}
//unpickle PrivateKey
privKey, readPriv, err := libolmpickle.UnpickleBytes(value[read:], Curve25519KeyLength)
if err != nil {
return read, err
}
c.PrivateKey = privKey
return read + readPriv, nil
}
// PickleLen returns the number of bytes the pickled key pair will have.
func (c Curve25519KeyPair) PickleLen() int {
lenPublic := c.PublicKey.PickleLen()
var lenPrivate int
if len(c.PrivateKey) != Curve25519KeyLength {
lenPrivate = libolmpickle.PickleBytesLen(make([]byte, Curve25519KeyLength))
} else {
lenPrivate = libolmpickle.PickleBytesLen(c.PrivateKey)
}
return lenPublic + lenPrivate
}
// Curve25519PrivateKey represents the private key for curve25519 usage
type Curve25519PrivateKey []byte
// Equal compares the private key to the given private key.
func (c Curve25519PrivateKey) Equal(x Curve25519PrivateKey) bool {
return bytes.Equal(c, x)
}
// PubKey returns the public key derived from the private key.
func (c Curve25519PrivateKey) PubKey() (Curve25519PublicKey, error) {
publicKey, err := curve25519.X25519(c, curve25519.Basepoint)
if err != nil {
return nil, err
}
return publicKey, nil
}
// SharedSecret returns the shared secret between the private key and the given public key.
func (c Curve25519PrivateKey) SharedSecret(pubKey Curve25519PublicKey) ([]byte, error) {
return curve25519.X25519(c, pubKey)
}
// Curve25519PublicKey represents the public key for curve25519 usage
type Curve25519PublicKey []byte
// Equal compares the public key to the given public key.
func (c Curve25519PublicKey) Equal(x Curve25519PublicKey) bool {
return bytes.Equal(c, x)
}
// B64Encoded returns a base64 encoded string of the public key.
func (c Curve25519PublicKey) B64Encoded() id.Curve25519 {
return id.Curve25519(base64.RawStdEncoding.EncodeToString(c))
}
// PickleLibOlm encodes the public key into target. target has to have a size of at least PickleLen() and is written to from index 0.
// It returns the number of bytes written.
func (c Curve25519PublicKey) PickleLibOlm(target []byte) (int, error) {
if len(target) < c.PickleLen() {
return 0, fmt.Errorf("pickle curve25519 public key: %w", goolm.ErrValueTooShort)
}
if len(c) != curve25519PubKeyLength {
return libolmpickle.PickleBytes(make([]byte, curve25519PubKeyLength), target), nil
}
return libolmpickle.PickleBytes(c, target), nil
}
// UnpickleLibOlm decodes the unencryted value and populates the public key accordingly. It returns the number of bytes read.
func (c *Curve25519PublicKey) UnpickleLibOlm(value []byte) (int, error) {
unpickled, readBytes, err := libolmpickle.UnpickleBytes(value, curve25519PubKeyLength)
if err != nil {
return 0, err
}
*c = unpickled
return readBytes, nil
}
// PickleLen returns the number of bytes the pickled public key will have.
func (c Curve25519PublicKey) PickleLen() int {
if len(c) != curve25519PubKeyLength {
return libolmpickle.PickleBytesLen(make([]byte, curve25519PubKeyLength))
}
return libolmpickle.PickleBytesLen(c)
}

View File

@@ -0,0 +1,181 @@
package crypto
import (
"bytes"
"crypto/ed25519"
"encoding/base64"
"fmt"
"io"
"maunium.net/go/mautrix/crypto/goolm"
"maunium.net/go/mautrix/crypto/goolm/libolmpickle"
"maunium.net/go/mautrix/id"
)
const (
ED25519SignatureSize = ed25519.SignatureSize //The length of a signature
)
// Ed25519GenerateKey creates a new ed25519 key pair. If reader is nil, the random data is taken from crypto/rand.
func Ed25519GenerateKey(reader io.Reader) (Ed25519KeyPair, error) {
publicKey, privateKey, err := ed25519.GenerateKey(reader)
if err != nil {
return Ed25519KeyPair{}, err
}
return Ed25519KeyPair{
PrivateKey: Ed25519PrivateKey(privateKey),
PublicKey: Ed25519PublicKey(publicKey),
}, nil
}
// Ed25519GenerateFromPrivate creates a new ed25519 key pair with the private key given.
func Ed25519GenerateFromPrivate(privKey Ed25519PrivateKey) Ed25519KeyPair {
return Ed25519KeyPair{
PrivateKey: privKey,
PublicKey: privKey.PubKey(),
}
}
// Ed25519GenerateFromSeed creates a new ed25519 key pair with a given seed.
func Ed25519GenerateFromSeed(seed []byte) Ed25519KeyPair {
privKey := Ed25519PrivateKey(ed25519.NewKeyFromSeed(seed))
return Ed25519KeyPair{
PrivateKey: privKey,
PublicKey: privKey.PubKey(),
}
}
// Ed25519KeyPair stores both parts of a ed25519 key.
type Ed25519KeyPair struct {
PrivateKey Ed25519PrivateKey `json:"private,omitempty"`
PublicKey Ed25519PublicKey `json:"public,omitempty"`
}
// B64Encoded returns a base64 encoded string of the public key.
func (c Ed25519KeyPair) B64Encoded() id.Ed25519 {
return id.Ed25519(base64.RawStdEncoding.EncodeToString(c.PublicKey))
}
// Sign returns the signature for the message.
func (c Ed25519KeyPair) Sign(message []byte) []byte {
return c.PrivateKey.Sign(message)
}
// Verify checks the signature of the message against the givenSignature
func (c Ed25519KeyPair) Verify(message, givenSignature []byte) bool {
return c.PublicKey.Verify(message, givenSignature)
}
// PickleLibOlm encodes the key pair into target. target has to have a size of at least PickleLen() and is written to from index 0.
// It returns the number of bytes written.
func (c Ed25519KeyPair) PickleLibOlm(target []byte) (int, error) {
if len(target) < c.PickleLen() {
return 0, fmt.Errorf("pickle ed25519 key pair: %w", goolm.ErrValueTooShort)
}
written, err := c.PublicKey.PickleLibOlm(target)
if err != nil {
return 0, fmt.Errorf("pickle ed25519 key pair: %w", err)
}
if len(c.PrivateKey) != ed25519.PrivateKeySize {
written += libolmpickle.PickleBytes(make([]byte, ed25519.PrivateKeySize), target[written:])
} else {
written += libolmpickle.PickleBytes(c.PrivateKey, target[written:])
}
return written, nil
}
// UnpickleLibOlm decodes the unencryted value and populates the key pair accordingly. It returns the number of bytes read.
func (c *Ed25519KeyPair) UnpickleLibOlm(value []byte) (int, error) {
//unpickle PubKey
read, err := c.PublicKey.UnpickleLibOlm(value)
if err != nil {
return 0, err
}
//unpickle PrivateKey
privKey, readPriv, err := libolmpickle.UnpickleBytes(value[read:], ed25519.PrivateKeySize)
if err != nil {
return read, err
}
c.PrivateKey = privKey
return read + readPriv, nil
}
// PickleLen returns the number of bytes the pickled key pair will have.
func (c Ed25519KeyPair) PickleLen() int {
lenPublic := c.PublicKey.PickleLen()
var lenPrivate int
if len(c.PrivateKey) != ed25519.PrivateKeySize {
lenPrivate = libolmpickle.PickleBytesLen(make([]byte, ed25519.PrivateKeySize))
} else {
lenPrivate = libolmpickle.PickleBytesLen(c.PrivateKey)
}
return lenPublic + lenPrivate
}
// Curve25519PrivateKey represents the private key for ed25519 usage. This is just a wrapper.
type Ed25519PrivateKey ed25519.PrivateKey
// Equal compares the private key to the given private key.
func (c Ed25519PrivateKey) Equal(x Ed25519PrivateKey) bool {
return bytes.Equal(c, x)
}
// PubKey returns the public key derived from the private key.
func (c Ed25519PrivateKey) PubKey() Ed25519PublicKey {
publicKey := ed25519.PrivateKey(c).Public()
return Ed25519PublicKey(publicKey.(ed25519.PublicKey))
}
// Sign returns the signature for the message.
func (c Ed25519PrivateKey) Sign(message []byte) []byte {
return ed25519.Sign(ed25519.PrivateKey(c), message)
}
// Ed25519PublicKey represents the public key for ed25519 usage. This is just a wrapper.
type Ed25519PublicKey ed25519.PublicKey
// Equal compares the public key to the given public key.
func (c Ed25519PublicKey) Equal(x Ed25519PublicKey) bool {
return bytes.Equal(c, x)
}
// B64Encoded returns a base64 encoded string of the public key.
func (c Ed25519PublicKey) B64Encoded() id.Curve25519 {
return id.Curve25519(base64.RawStdEncoding.EncodeToString(c))
}
// Verify checks the signature of the message against the givenSignature
func (c Ed25519PublicKey) Verify(message, givenSignature []byte) bool {
return ed25519.Verify(ed25519.PublicKey(c), message, givenSignature)
}
// PickleLibOlm encodes the public key into target. target has to have a size of at least PickleLen() and is written to from index 0.
// It returns the number of bytes written.
func (c Ed25519PublicKey) PickleLibOlm(target []byte) (int, error) {
if len(target) < c.PickleLen() {
return 0, fmt.Errorf("pickle ed25519 public key: %w", goolm.ErrValueTooShort)
}
if len(c) != ed25519.PublicKeySize {
return libolmpickle.PickleBytes(make([]byte, ed25519.PublicKeySize), target), nil
}
return libolmpickle.PickleBytes(c, target), nil
}
// UnpickleLibOlm decodes the unencryted value and populates the public key accordingly. It returns the number of bytes read.
func (c *Ed25519PublicKey) UnpickleLibOlm(value []byte) (int, error) {
unpickled, readBytes, err := libolmpickle.UnpickleBytes(value, ed25519.PublicKeySize)
if err != nil {
return 0, err
}
*c = unpickled
return readBytes, nil
}
// PickleLen returns the number of bytes the pickled public key will have.
func (c Ed25519PublicKey) PickleLen() int {
if len(c) != ed25519.PublicKeySize {
return libolmpickle.PickleBytesLen(make([]byte, ed25519.PublicKeySize))
}
return libolmpickle.PickleBytesLen(c)
}

View File

@@ -0,0 +1,29 @@
package crypto
import (
"crypto/hmac"
"crypto/sha256"
"io"
"golang.org/x/crypto/hkdf"
)
// HMACSHA256 returns the hash message authentication code with SHA-256 of the input with the key.
func HMACSHA256(key, input []byte) []byte {
hash := hmac.New(sha256.New, key)
hash.Write(input)
return hash.Sum(nil)
}
// SHA256 return the SHA-256 of the value.
func SHA256(value []byte) []byte {
hash := sha256.New()
hash.Write(value)
return hash.Sum(nil)
}
// HKDFSHA256 is the key deivation function based on HMAC and returns a reader based on input. salt and info can both be nil.
// The reader can be used to read an arbitary length of bytes which are based on all parameters.
func HKDFSHA256(input, salt, info []byte) io.Reader {
return hkdf.New(sha256.New, input, salt, info)
}

View File

@@ -0,0 +1,2 @@
// crpyto provides the nessesary encryption methods for olm/megolm
package crypto

View File

@@ -0,0 +1,95 @@
package crypto
import (
"encoding/base64"
"encoding/binary"
"fmt"
"maunium.net/go/mautrix/crypto/goolm"
"maunium.net/go/mautrix/crypto/goolm/libolmpickle"
"maunium.net/go/mautrix/id"
)
// OneTimeKey stores the information about a one time key.
type OneTimeKey struct {
ID uint32 `json:"id"`
Published bool `json:"published"`
Key Curve25519KeyPair `json:"key,omitempty"`
}
// Equal compares the one time key to the given one.
func (otk OneTimeKey) Equal(s OneTimeKey) bool {
if otk.ID != s.ID {
return false
}
if otk.Published != s.Published {
return false
}
if !otk.Key.PrivateKey.Equal(s.Key.PrivateKey) {
return false
}
if !otk.Key.PublicKey.Equal(s.Key.PublicKey) {
return false
}
return true
}
// PickleLibOlm encodes the key pair into target. target has to have a size of at least PickleLen() and is written to from index 0.
// It returns the number of bytes written.
func (c OneTimeKey) PickleLibOlm(target []byte) (int, error) {
if len(target) < c.PickleLen() {
return 0, fmt.Errorf("pickle one time key: %w", goolm.ErrValueTooShort)
}
written := libolmpickle.PickleUInt32(uint32(c.ID), target)
written += libolmpickle.PickleBool(c.Published, target[written:])
writtenKey, err := c.Key.PickleLibOlm(target[written:])
if err != nil {
return 0, fmt.Errorf("pickle one time key: %w", err)
}
written += writtenKey
return written, nil
}
// UnpickleLibOlm decodes the unencryted value and populates the OneTimeKey accordingly. It returns the number of bytes read.
func (c *OneTimeKey) UnpickleLibOlm(value []byte) (int, error) {
totalReadBytes := 0
id, readBytes, err := libolmpickle.UnpickleUInt32(value)
if err != nil {
return 0, err
}
totalReadBytes += readBytes
c.ID = id
published, readBytes, err := libolmpickle.UnpickleBool(value[totalReadBytes:])
if err != nil {
return 0, err
}
totalReadBytes += readBytes
c.Published = published
readBytes, err = c.Key.UnpickleLibOlm(value[totalReadBytes:])
if err != nil {
return 0, err
}
totalReadBytes += readBytes
return totalReadBytes, nil
}
// PickleLen returns the number of bytes the pickled OneTimeKey will have.
func (c OneTimeKey) PickleLen() int {
length := 0
length += libolmpickle.PickleUInt32Len(c.ID)
length += libolmpickle.PickleBoolLen(c.Published)
length += c.Key.PickleLen()
return length
}
// KeyIDEncoded returns the base64 encoded id.
func (c OneTimeKey) KeyIDEncoded() string {
resSlice := make([]byte, 4)
binary.BigEndian.PutUint32(resSlice, c.ID)
return base64.RawStdEncoding.EncodeToString(resSlice)
}
// PublicKeyEncoded returns the base64 encoded public key
func (c OneTimeKey) PublicKeyEncoded() id.Curve25519 {
return c.Key.PublicKey.B64Encoded()
}

30
vendor/maunium.net/go/mautrix/crypto/goolm/errors.go generated vendored Normal file
View File

@@ -0,0 +1,30 @@
package goolm
import (
"errors"
)
// Those are the most common used errors
var (
ErrBadSignature = errors.New("bad signature")
ErrBadMAC = errors.New("bad mac")
ErrBadMessageFormat = errors.New("bad message format")
ErrBadVerification = errors.New("bad verification")
ErrWrongProtocolVersion = errors.New("wrong protocol version")
ErrEmptyInput = errors.New("empty input")
ErrNoKeyProvided = errors.New("no key")
ErrBadMessageKeyID = errors.New("bad message key id")
ErrRatchetNotAvailable = errors.New("ratchet not available: attempt to decode a message whose index is earlier than our earliest known session key")
ErrMsgIndexTooHigh = errors.New("message index too high")
ErrProtocolViolation = errors.New("not protocol message order")
ErrMessageKeyNotFound = errors.New("message key not found")
ErrChainTooHigh = errors.New("chain index too high")
ErrBadInput = errors.New("bad input")
ErrBadVersion = errors.New("wrong version")
ErrNotBlocksize = errors.New("length != blocksize")
ErrNotMultipleBlocksize = errors.New("length not a multiple of the blocksize")
ErrWrongPickleVersion = errors.New("wrong pickle version")
ErrValueTooShort = errors.New("value too short")
ErrInputToSmall = errors.New("input too small (truncated?)")
ErrOverflow = errors.New("overflow")
)

View File

@@ -0,0 +1,41 @@
package libolmpickle
import (
"encoding/binary"
)
func PickleUInt8(value uint8, target []byte) int {
target[0] = value
return 1
}
func PickleUInt8Len(value uint8) int {
return 1
}
func PickleBool(value bool, target []byte) int {
if value {
target[0] = 0x01
} else {
target[0] = 0x00
}
return 1
}
func PickleBoolLen(value bool) int {
return 1
}
func PickleBytes(value, target []byte) int {
return copy(target, value)
}
func PickleBytesLen(value []byte) int {
return len(value)
}
func PickleUInt32(value uint32, target []byte) int {
res := make([]byte, 4) //4 bytes for int32
binary.BigEndian.PutUint32(res, value)
return copy(target, res)
}
func PickleUInt32Len(value uint32) int {
return 4
}

View File

@@ -0,0 +1,53 @@
package libolmpickle
import (
"fmt"
"maunium.net/go/mautrix/crypto/goolm"
)
func isZeroByteSlice(bytes []byte) bool {
b := byte(0)
for _, s := range bytes {
b |= s
}
return b == 0
}
func UnpickleUInt8(value []byte) (uint8, int, error) {
if len(value) < 1 {
return 0, 0, fmt.Errorf("unpickle uint8: %w", goolm.ErrValueTooShort)
}
return value[0], 1, nil
}
func UnpickleBool(value []byte) (bool, int, error) {
if len(value) < 1 {
return false, 0, fmt.Errorf("unpickle bool: %w", goolm.ErrValueTooShort)
}
return value[0] != uint8(0x00), 1, nil
}
func UnpickleBytes(value []byte, length int) ([]byte, int, error) {
if len(value) < length {
return nil, 0, fmt.Errorf("unpickle bytes: %w", goolm.ErrValueTooShort)
}
resp := value[:length]
if isZeroByteSlice(resp) {
return nil, length, nil
}
return resp, length, nil
}
func UnpickleUInt32(value []byte) (uint32, int, error) {
if len(value) < 4 {
return 0, 0, fmt.Errorf("unpickle uint32: %w", goolm.ErrValueTooShort)
}
var res uint32
count := 0
for i := 3; i >= 0; i-- {
res |= uint32(value[count]) << (8 * i)
count++
}
return res, 4, nil
}

6
vendor/maunium.net/go/mautrix/crypto/goolm/main.go generated vendored Normal file
View File

@@ -0,0 +1,6 @@
// Package goolm is a pure Go implementation of libolm. Libolm is a cryptographic library used for end-to-end encryption in Matrix and written in C++.
// With goolm there is no need to use cgo when building Matrix clients in go.
/*
This package contains the possible errors which can occur as well as some simple functions. All the 'action' happens in the subdirectories.
*/
package goolm

View File

@@ -0,0 +1,234 @@
// megolm provides the ratchet used by the megolm protocol
package megolm
import (
"crypto/rand"
"fmt"
"maunium.net/go/mautrix/crypto/goolm"
"maunium.net/go/mautrix/crypto/goolm/cipher"
"maunium.net/go/mautrix/crypto/goolm/crypto"
"maunium.net/go/mautrix/crypto/goolm/libolmpickle"
"maunium.net/go/mautrix/crypto/goolm/message"
"maunium.net/go/mautrix/crypto/goolm/utilities"
)
const (
megolmPickleVersion uint8 = 1
)
const (
protocolVersion = 3
RatchetParts = 4 // number of ratchet parts
RatchetPartLength = 256 / 8 // length of each ratchet part in bytes
)
var RatchetCipher = cipher.NewAESSHA256([]byte("MEGOLM_KEYS"))
// hasKeySeed are the seed for the different ratchet parts
var hashKeySeeds [RatchetParts][]byte = [RatchetParts][]byte{
{0x00},
{0x01},
{0x02},
{0x03},
}
// Ratchet represents the megolm ratchet as described in
//
// https://gitlab.matrix.org/matrix-org/olm/-/blob/master/docs/megolm.md
type Ratchet struct {
Data [RatchetParts * RatchetPartLength]byte `json:"data"`
Counter uint32 `json:"counter"`
}
// New creates a new ratchet with counter set to counter and the ratchet data set to data.
func New(counter uint32, data [RatchetParts * RatchetPartLength]byte) (*Ratchet, error) {
m := &Ratchet{
Counter: counter,
Data: data,
}
return m, nil
}
// NewWithRandom creates a new ratchet with counter set to counter an the data filled with random values.
func NewWithRandom(counter uint32) (*Ratchet, error) {
var data [RatchetParts * RatchetPartLength]byte
_, err := rand.Read(data[:])
if err != nil {
return nil, err
}
return New(counter, data)
}
// rehashPart rehases the part of the ratchet data with the base defined as from storing into the target to.
func (m *Ratchet) rehashPart(from, to int) {
newData := crypto.HMACSHA256(m.Data[from*RatchetPartLength:from*RatchetPartLength+RatchetPartLength], hashKeySeeds[to])
copy(m.Data[to*RatchetPartLength:], newData[:RatchetPartLength])
}
// Advance advances the ratchet one step.
func (m *Ratchet) Advance() {
var mask uint32 = 0x00FFFFFF
var h int
m.Counter++
// figure out how much we need to rekey
for h < RatchetParts {
if (m.Counter & mask) == 0 {
break
}
h++
mask >>= 8
}
// now update R(h)...R(3) based on R(h)
for i := RatchetParts - 1; i >= h; i-- {
m.rehashPart(h, i)
}
}
// AdvanceTo advances the ratchet so that the ratchet counter = target
func (m *Ratchet) AdvanceTo(target uint32) {
//starting with R0, see if we need to update each part of the hash
for j := 0; j < RatchetParts; j++ {
shift := uint32((RatchetParts - j - 1) * 8)
mask := (^uint32(0)) << shift
// how many times do we need to rehash this part?
// '& 0xff' ensures we handle integer wraparound correctly
steps := ((target >> shift) - m.Counter>>shift) & uint32(0xff)
if steps == 0 {
/*
deal with the edge case where m.Counter is slightly larger
than target. This should only happen for R(0), and implies
that target has wrapped around and we need to advance R(0)
256 times.
*/
if target < m.Counter {
steps = 0x100
} else {
continue
}
}
// for all but the last step, we can just bump R(j) without regard to R(j+1)...R(3).
for steps > 1 {
m.rehashPart(j, j)
steps--
}
/*
on the last step we also need to bump R(j+1)...R(3).
(Theoretically, we could skip bumping R(j+2) if we're going to bump
R(j+1) again, but the code to figure that out is a bit baroque and
doesn't save us much).
*/
for k := 3; k >= j; k-- {
m.rehashPart(j, k)
}
m.Counter = target & mask
}
}
// Encrypt encrypts the message in a message.GroupMessage with MAC and signature.
// The output is base64 encoded.
func (r *Ratchet) Encrypt(plaintext []byte, key *crypto.Ed25519KeyPair) ([]byte, error) {
var err error
encryptedText, err := RatchetCipher.Encrypt(r.Data[:], plaintext)
if err != nil {
return nil, fmt.Errorf("cipher encrypt: %w", err)
}
message := &message.GroupMessage{}
message.Version = protocolVersion
message.MessageIndex = r.Counter
message.Ciphertext = encryptedText
//creating the mac and signing is done in encode
output, err := message.EncodeAndMacAndSign(r.Data[:], RatchetCipher, key)
if err != nil {
return nil, err
}
r.Advance()
return output, nil
}
// SessionSharingMessage creates a message in the session sharing format.
func (r Ratchet) SessionSharingMessage(key crypto.Ed25519KeyPair) ([]byte, error) {
m := message.MegolmSessionSharing{}
m.Counter = r.Counter
m.RatchetData = r.Data
encoded := m.EncodeAndSign(key)
return goolm.Base64Encode(encoded), nil
}
// SessionExportMessage creates a message in the session export format.
func (r Ratchet) SessionExportMessage(key crypto.Ed25519PublicKey) ([]byte, error) {
m := message.MegolmSessionExport{}
m.Counter = r.Counter
m.RatchetData = r.Data
m.PublicKey = key
encoded := m.Encode()
return goolm.Base64Encode(encoded), nil
}
// Decrypt decrypts the ciphertext and verifies the MAC but not the signature.
func (r Ratchet) Decrypt(ciphertext []byte, signingkey *crypto.Ed25519PublicKey, msg *message.GroupMessage) ([]byte, error) {
//verify mac
verifiedMAC, err := msg.VerifyMACInline(r.Data[:], RatchetCipher, ciphertext)
if err != nil {
return nil, err
}
if !verifiedMAC {
return nil, fmt.Errorf("decrypt: %w", goolm.ErrBadMAC)
}
return RatchetCipher.Decrypt(r.Data[:], msg.Ciphertext)
}
// PickleAsJSON returns a ratchet as a base64 string encrypted using the supplied key. The unencrypted representation of the Account is in JSON format.
func (r Ratchet) PickleAsJSON(key []byte) ([]byte, error) {
return utilities.PickleAsJSON(r, megolmPickleVersion, key)
}
// UnpickleAsJSON updates a ratchet by a base64 encrypted string using the supplied key. The unencrypted representation has to be in JSON format.
func (r *Ratchet) UnpickleAsJSON(pickled, key []byte) error {
return utilities.UnpickleAsJSON(r, pickled, key, megolmPickleVersion)
}
// UnpickleLibOlm decodes the unencryted value and populates the Ratchet accordingly. It returns the number of bytes read.
func (r *Ratchet) UnpickleLibOlm(unpickled []byte) (int, error) {
//read ratchet data
curPos := 0
ratchetData, readBytes, err := libolmpickle.UnpickleBytes(unpickled, RatchetParts*RatchetPartLength)
if err != nil {
return 0, err
}
copy(r.Data[:], ratchetData)
curPos += readBytes
//Read counter
counter, readBytes, err := libolmpickle.UnpickleUInt32(unpickled[curPos:])
if err != nil {
return 0, err
}
curPos += readBytes
r.Counter = counter
return curPos, nil
}
// PickleLibOlm encodes the ratchet into target. target has to have a size of at least PickleLen() and is written to from index 0.
// It returns the number of bytes written.
func (r Ratchet) PickleLibOlm(target []byte) (int, error) {
if len(target) < r.PickleLen() {
return 0, fmt.Errorf("pickle account: %w", goolm.ErrValueTooShort)
}
written := libolmpickle.PickleBytes(r.Data[:], target)
written += libolmpickle.PickleUInt32(r.Counter, target[written:])
return written, nil
}
// PickleLen returns the number of bytes the pickled ratchet will have.
func (r Ratchet) PickleLen() int {
length := libolmpickle.PickleBytesLen(r.Data[:])
length += libolmpickle.PickleUInt32Len(r.Counter)
return length
}

View File

@@ -0,0 +1,70 @@
package message
import (
"encoding/binary"
"maunium.net/go/mautrix/crypto/goolm"
)
// checkDecodeErr checks if there was an error during decode.
func checkDecodeErr(readBytes int) error {
if readBytes == 0 {
//end reached
return goolm.ErrInputToSmall
}
if readBytes < 0 {
return goolm.ErrOverflow
}
return nil
}
// decodeVarInt decodes a single big-endian encoded varint.
func decodeVarInt(input []byte) (uint32, int) {
value, readBytes := binary.Uvarint(input)
return uint32(value), readBytes
}
// decodeVarString decodes the length of the string (varint) and returns the actual string
func decodeVarString(input []byte) ([]byte, int) {
stringLen, readBytes := decodeVarInt(input)
if readBytes <= 0 {
return nil, readBytes
}
input = input[readBytes:]
value := input[:stringLen]
readBytes += int(stringLen)
return value, readBytes
}
// encodeVarIntByteLength returns the number of bytes needed to encode the uint32.
func encodeVarIntByteLength(input uint32) int {
result := 1
for input >= 128 {
result++
input >>= 7
}
return result
}
// encodeVarStringByteLength returns the number of bytes needed to encode the input.
func encodeVarStringByteLength(input []byte) int {
result := encodeVarIntByteLength(uint32(len(input)))
result += len(input)
return result
}
// encodeVarInt encodes a single uint32
func encodeVarInt(input uint32) []byte {
out := make([]byte, encodeVarIntByteLength(input))
binary.PutUvarint(out, uint64(input))
return out
}
// encodeVarString encodes the length of the input (varint) and appends the actual input
func encodeVarString(input []byte) []byte {
out := make([]byte, encodeVarStringByteLength(input))
length := encodeVarInt(uint32(len(input)))
copy(out, length)
copy(out[len(length):], input)
return out
}

View File

@@ -0,0 +1,144 @@
package message
import (
"bytes"
"maunium.net/go/mautrix/crypto/goolm/cipher"
"maunium.net/go/mautrix/crypto/goolm/crypto"
)
const (
messageIndexTag = 0x08
cipherTextTag = 0x12
countMACBytesGroupMessage = 8
)
// GroupMessage represents a message in the group message format.
type GroupMessage struct {
Version byte `json:"version"`
MessageIndex uint32 `json:"index"`
Ciphertext []byte `json:"ciphertext"`
HasMessageIndex bool `json:"has_index"`
}
// Decodes decodes the input and populates the corresponding fileds. MAC and signature are ignored but have to be present.
func (r *GroupMessage) Decode(input []byte) error {
r.Version = 0
r.MessageIndex = 0
r.Ciphertext = nil
if len(input) == 0 {
return nil
}
//first Byte is always version
r.Version = input[0]
curPos := 1
for curPos < len(input)-countMACBytesGroupMessage-crypto.ED25519SignatureSize {
//Read Key
curKey, readBytes := decodeVarInt(input[curPos:])
if err := checkDecodeErr(readBytes); err != nil {
return err
}
curPos += readBytes
if (curKey & 0b111) == 0 {
//The value is of type varint
value, readBytes := decodeVarInt(input[curPos:])
if err := checkDecodeErr(readBytes); err != nil {
return err
}
curPos += readBytes
switch curKey {
case messageIndexTag:
r.MessageIndex = value
r.HasMessageIndex = true
}
} else if (curKey & 0b111) == 2 {
//The value is of type string
value, readBytes := decodeVarString(input[curPos:])
if err := checkDecodeErr(readBytes); err != nil {
return err
}
curPos += readBytes
switch curKey {
case cipherTextTag:
r.Ciphertext = value
}
}
}
return nil
}
// EncodeAndMacAndSign encodes the message, creates the mac with the key and the cipher and signs the message.
// If macKey or cipher is nil, no mac is appended. If signKey is nil, no signature is appended.
func (r *GroupMessage) EncodeAndMacAndSign(macKey []byte, cipher cipher.Cipher, signKey *crypto.Ed25519KeyPair) ([]byte, error) {
var lengthOfMessage int
lengthOfMessage += 1 //Version
lengthOfMessage += encodeVarIntByteLength(messageIndexTag) + encodeVarIntByteLength(r.MessageIndex)
lengthOfMessage += encodeVarIntByteLength(cipherTextTag) + encodeVarStringByteLength(r.Ciphertext)
out := make([]byte, lengthOfMessage)
out[0] = r.Version
curPos := 1
encodedTag := encodeVarInt(messageIndexTag)
copy(out[curPos:], encodedTag)
curPos += len(encodedTag)
encodedValue := encodeVarInt(r.MessageIndex)
copy(out[curPos:], encodedValue)
curPos += len(encodedValue)
encodedTag = encodeVarInt(cipherTextTag)
copy(out[curPos:], encodedTag)
curPos += len(encodedTag)
encodedValue = encodeVarString(r.Ciphertext)
copy(out[curPos:], encodedValue)
curPos += len(encodedValue)
if len(macKey) != 0 && cipher != nil {
mac, err := r.MAC(macKey, cipher, out)
if err != nil {
return nil, err
}
out = append(out, mac[:countMACBytesGroupMessage]...)
}
if signKey != nil {
signature := signKey.Sign(out)
out = append(out, signature...)
}
return out, nil
}
// MAC returns the MAC of the message calculated with cipher and key. The length of the MAC is truncated to the correct length.
func (r *GroupMessage) MAC(key []byte, cipher cipher.Cipher, message []byte) ([]byte, error) {
mac, err := cipher.MAC(key, message)
if err != nil {
return nil, err
}
return mac[:countMACBytesGroupMessage], nil
}
// VerifySignature verifies the givenSignature to the calculated signature of the message.
func (r *GroupMessage) VerifySignature(key crypto.Ed25519PublicKey, message, givenSignature []byte) bool {
return key.Verify(message, givenSignature)
}
// VerifySignature verifies the signature taken from the message to the calculated signature of the message.
func (r *GroupMessage) VerifySignatureInline(key crypto.Ed25519PublicKey, message []byte) bool {
signature := message[len(message)-crypto.ED25519SignatureSize:]
message = message[:len(message)-crypto.ED25519SignatureSize]
return key.Verify(message, signature)
}
// VerifyMAC verifies the givenMAC to the calculated MAC of the message.
func (r *GroupMessage) VerifyMAC(key []byte, cipher cipher.Cipher, message, givenMAC []byte) (bool, error) {
checkMac, err := r.MAC(key, cipher, message)
if err != nil {
return false, err
}
return bytes.Equal(checkMac[:countMACBytesGroupMessage], givenMAC), nil
}
// VerifyMACInline verifies the MAC taken from the message to the calculated MAC of the message.
func (r *GroupMessage) VerifyMACInline(key []byte, cipher cipher.Cipher, message []byte) (bool, error) {
startMAC := len(message) - countMACBytesGroupMessage - crypto.ED25519SignatureSize
endMAC := startMAC + countMACBytesGroupMessage
suplMac := message[startMAC:endMAC]
message = message[:startMAC]
return r.VerifyMAC(key, cipher, message, suplMac)
}

View File

@@ -0,0 +1,129 @@
package message
import (
"bytes"
"maunium.net/go/mautrix/crypto/goolm/cipher"
"maunium.net/go/mautrix/crypto/goolm/crypto"
)
const (
ratchetKeyTag = 0x0A
counterTag = 0x10
cipherTextKeyTag = 0x22
countMACBytesMessage = 8
)
// GroupMessage represents a message in the message format.
type Message struct {
Version byte `json:"version"`
HasCounter bool `json:"has_counter"`
Counter uint32 `json:"counter"`
RatchetKey crypto.Curve25519PublicKey `json:"ratchet_key"`
Ciphertext []byte `json:"ciphertext"`
}
// Decodes decodes the input and populates the corresponding fileds. MAC is ignored but has to be present.
func (r *Message) Decode(input []byte) error {
r.Version = 0
r.HasCounter = false
r.Counter = 0
r.RatchetKey = nil
r.Ciphertext = nil
if len(input) == 0 {
return nil
}
//first Byte is always version
r.Version = input[0]
curPos := 1
for curPos < len(input)-countMACBytesMessage {
//Read Key
curKey, readBytes := decodeVarInt(input[curPos:])
if err := checkDecodeErr(readBytes); err != nil {
return err
}
curPos += readBytes
if (curKey & 0b111) == 0 {
//The value is of type varint
value, readBytes := decodeVarInt(input[curPos:])
if err := checkDecodeErr(readBytes); err != nil {
return err
}
curPos += readBytes
switch curKey {
case counterTag:
r.HasCounter = true
r.Counter = value
}
} else if (curKey & 0b111) == 2 {
//The value is of type string
value, readBytes := decodeVarString(input[curPos:])
if err := checkDecodeErr(readBytes); err != nil {
return err
}
curPos += readBytes
switch curKey {
case ratchetKeyTag:
r.RatchetKey = value
case cipherTextKeyTag:
r.Ciphertext = value
}
}
}
return nil
}
// EncodeAndMAC encodes the message and creates the MAC with the key and the cipher.
// If key or cipher is nil, no MAC is appended.
func (r *Message) EncodeAndMAC(key []byte, cipher cipher.Cipher) ([]byte, error) {
var lengthOfMessage int
lengthOfMessage += 1 //Version
lengthOfMessage += encodeVarIntByteLength(ratchetKeyTag) + encodeVarStringByteLength(r.RatchetKey)
lengthOfMessage += encodeVarIntByteLength(counterTag) + encodeVarIntByteLength(r.Counter)
lengthOfMessage += encodeVarIntByteLength(cipherTextKeyTag) + encodeVarStringByteLength(r.Ciphertext)
out := make([]byte, lengthOfMessage)
out[0] = r.Version
curPos := 1
encodedTag := encodeVarInt(ratchetKeyTag)
copy(out[curPos:], encodedTag)
curPos += len(encodedTag)
encodedValue := encodeVarString(r.RatchetKey)
copy(out[curPos:], encodedValue)
curPos += len(encodedValue)
encodedTag = encodeVarInt(counterTag)
copy(out[curPos:], encodedTag)
curPos += len(encodedTag)
encodedValue = encodeVarInt(r.Counter)
copy(out[curPos:], encodedValue)
curPos += len(encodedValue)
encodedTag = encodeVarInt(cipherTextKeyTag)
copy(out[curPos:], encodedTag)
curPos += len(encodedTag)
encodedValue = encodeVarString(r.Ciphertext)
copy(out[curPos:], encodedValue)
curPos += len(encodedValue)
if len(key) != 0 && cipher != nil {
mac, err := cipher.MAC(key, out)
if err != nil {
return nil, err
}
out = append(out, mac[:countMACBytesMessage]...)
}
return out, nil
}
// VerifyMAC verifies the givenMAC to the calculated MAC of the message.
func (r *Message) VerifyMAC(key []byte, cipher cipher.Cipher, message, givenMAC []byte) (bool, error) {
checkMAC, err := cipher.MAC(key, message)
if err != nil {
return false, err
}
return bytes.Equal(checkMAC[:countMACBytesMessage], givenMAC), nil
}
// VerifyMACInline verifies the MAC taken from the message to the calculated MAC of the message.
func (r *Message) VerifyMACInline(key []byte, cipher cipher.Cipher, message []byte) (bool, error) {
givenMAC := message[len(message)-countMACBytesMessage:]
return r.VerifyMAC(key, cipher, message[:len(message)-countMACBytesMessage], givenMAC)
}

View File

@@ -0,0 +1,120 @@
package message
import (
"maunium.net/go/mautrix/crypto/goolm/crypto"
)
const (
oneTimeKeyIdTag = 0x0A
baseKeyTag = 0x12
identityKeyTag = 0x1A
messageTag = 0x22
)
type PreKeyMessage struct {
Version byte `json:"version"`
IdentityKey crypto.Curve25519PublicKey `json:"id_key"`
BaseKey crypto.Curve25519PublicKey `json:"base_key"`
OneTimeKey crypto.Curve25519PublicKey `json:"one_time_key"`
Message []byte `json:"message"`
}
// Decodes decodes the input and populates the corresponding fileds.
func (r *PreKeyMessage) Decode(input []byte) error {
r.Version = 0
r.IdentityKey = nil
r.BaseKey = nil
r.OneTimeKey = nil
r.Message = nil
if len(input) == 0 {
return nil
}
//first Byte is always version
r.Version = input[0]
curPos := 1
for curPos < len(input) {
//Read Key
curKey, readBytes := decodeVarInt(input[curPos:])
if err := checkDecodeErr(readBytes); err != nil {
return err
}
curPos += readBytes
if (curKey & 0b111) == 0 {
//The value is of type varint
_, readBytes := decodeVarInt(input[curPos:])
if err := checkDecodeErr(readBytes); err != nil {
return err
}
curPos += readBytes
} else if (curKey & 0b111) == 2 {
//The value is of type string
value, readBytes := decodeVarString(input[curPos:])
if err := checkDecodeErr(readBytes); err != nil {
return err
}
curPos += readBytes
switch curKey {
case oneTimeKeyIdTag:
r.OneTimeKey = value
case baseKeyTag:
r.BaseKey = value
case identityKeyTag:
r.IdentityKey = value
case messageTag:
r.Message = value
}
}
}
return nil
}
// CheckField verifies the fields. If theirIdentityKey is nil, it is not compared to the key in the message.
func (r *PreKeyMessage) CheckFields(theirIdentityKey *crypto.Curve25519PublicKey) bool {
ok := true
ok = ok && (theirIdentityKey != nil || r.IdentityKey != nil)
if r.IdentityKey != nil {
ok = ok && (len(r.IdentityKey) == crypto.Curve25519KeyLength)
}
ok = ok && len(r.Message) != 0
ok = ok && len(r.BaseKey) == crypto.Curve25519KeyLength
ok = ok && len(r.OneTimeKey) == crypto.Curve25519KeyLength
return ok
}
// Encode encodes the message.
func (r *PreKeyMessage) Encode() ([]byte, error) {
var lengthOfMessage int
lengthOfMessage += 1 //Version
lengthOfMessage += encodeVarIntByteLength(oneTimeKeyIdTag) + encodeVarStringByteLength(r.OneTimeKey)
lengthOfMessage += encodeVarIntByteLength(identityKeyTag) + encodeVarStringByteLength(r.IdentityKey)
lengthOfMessage += encodeVarIntByteLength(baseKeyTag) + encodeVarStringByteLength(r.BaseKey)
lengthOfMessage += encodeVarIntByteLength(messageTag) + encodeVarStringByteLength(r.Message)
out := make([]byte, lengthOfMessage)
out[0] = r.Version
curPos := 1
encodedTag := encodeVarInt(oneTimeKeyIdTag)
copy(out[curPos:], encodedTag)
curPos += len(encodedTag)
encodedValue := encodeVarString(r.OneTimeKey)
copy(out[curPos:], encodedValue)
curPos += len(encodedValue)
encodedTag = encodeVarInt(identityKeyTag)
copy(out[curPos:], encodedTag)
curPos += len(encodedTag)
encodedValue = encodeVarString(r.IdentityKey)
copy(out[curPos:], encodedValue)
curPos += len(encodedValue)
encodedTag = encodeVarInt(baseKeyTag)
copy(out[curPos:], encodedTag)
curPos += len(encodedTag)
encodedValue = encodeVarString(r.BaseKey)
copy(out[curPos:], encodedValue)
curPos += len(encodedValue)
encodedTag = encodeVarInt(messageTag)
copy(out[curPos:], encodedTag)
curPos += len(encodedTag)
encodedValue = encodeVarString(r.Message)
copy(out[curPos:], encodedValue)
return out, nil
}

View File

@@ -0,0 +1,44 @@
package message
import (
"encoding/binary"
"fmt"
"maunium.net/go/mautrix/crypto/goolm"
"maunium.net/go/mautrix/crypto/goolm/crypto"
)
const (
sessionExportVersion = 0x01
)
// MegolmSessionExport represents a message in the session export format.
type MegolmSessionExport struct {
Counter uint32 `json:"counter"`
RatchetData [128]byte `json:"data"`
PublicKey crypto.Ed25519PublicKey `json:"public_key"`
}
// Encode returns the encoded message in the correct format.
func (s MegolmSessionExport) Encode() []byte {
output := make([]byte, 165)
output[0] = sessionExportVersion
binary.BigEndian.PutUint32(output[1:], s.Counter)
copy(output[5:], s.RatchetData[:])
copy(output[133:], s.PublicKey)
return output
}
// Decode populates the struct with the data encoded in input.
func (s *MegolmSessionExport) Decode(input []byte) error {
if len(input) != 165 {
return fmt.Errorf("decrypt: %w", goolm.ErrBadInput)
}
if input[0] != sessionExportVersion {
return fmt.Errorf("decrypt: %w", goolm.ErrBadVersion)
}
s.Counter = binary.BigEndian.Uint32(input[1:5])
copy(s.RatchetData[:], input[5:133])
s.PublicKey = input[133:]
return nil
}

View File

@@ -0,0 +1,50 @@
package message
import (
"encoding/binary"
"fmt"
"maunium.net/go/mautrix/crypto/goolm"
"maunium.net/go/mautrix/crypto/goolm/crypto"
)
const (
sessionSharingVersion = 0x02
)
// MegolmSessionSharing represents a message in the session sharing format.
type MegolmSessionSharing struct {
Counter uint32 `json:"counter"`
RatchetData [128]byte `json:"data"`
PublicKey crypto.Ed25519PublicKey `json:"-"` //only used when decrypting messages
}
// Encode returns the encoded message in the correct format with the signature by key appended.
func (s MegolmSessionSharing) EncodeAndSign(key crypto.Ed25519KeyPair) []byte {
output := make([]byte, 229)
output[0] = sessionSharingVersion
binary.BigEndian.PutUint32(output[1:], s.Counter)
copy(output[5:], s.RatchetData[:])
copy(output[133:], key.PublicKey)
signature := key.Sign(output[:165])
copy(output[165:], signature)
return output
}
// VerifyAndDecode verifies the input and populates the struct with the data encoded in input.
func (s *MegolmSessionSharing) VerifyAndDecode(input []byte) error {
if len(input) != 229 {
return fmt.Errorf("verify: %w", goolm.ErrBadInput)
}
publicKey := crypto.Ed25519PublicKey(input[133:165])
if !publicKey.Verify(input[:165], input[165:]) {
return fmt.Errorf("verify: %w", goolm.ErrBadVerification)
}
s.PublicKey = publicKey
if input[0] != sessionSharingVersion {
return fmt.Errorf("verify: %w", goolm.ErrBadVersion)
}
s.Counter = binary.BigEndian.Uint32(input[1:5])
copy(s.RatchetData[:], input[5:133])
return nil
}

258
vendor/maunium.net/go/mautrix/crypto/goolm/olm/chain.go generated vendored Normal file
View File

@@ -0,0 +1,258 @@
package olm
import (
"fmt"
"maunium.net/go/mautrix/crypto/goolm"
"maunium.net/go/mautrix/crypto/goolm/crypto"
"maunium.net/go/mautrix/crypto/goolm/libolmpickle"
)
const (
chainKeySeed = 0x02
messageKeyLength = 32
)
// chainKey wraps the index and the public key
type chainKey struct {
Index uint32 `json:"index"`
Key crypto.Curve25519PublicKey `json:"key"`
}
// advance advances the chain
func (c *chainKey) advance() {
c.Key = crypto.HMACSHA256(c.Key, []byte{chainKeySeed})
c.Index++
}
// UnpickleLibOlm decodes the unencryted value and populates the chain key accordingly. It returns the number of bytes read.
func (r *chainKey) UnpickleLibOlm(value []byte) (int, error) {
curPos := 0
readBytes, err := r.Key.UnpickleLibOlm(value)
if err != nil {
return 0, err
}
curPos += readBytes
r.Index, readBytes, err = libolmpickle.UnpickleUInt32(value[curPos:])
if err != nil {
return 0, err
}
curPos += readBytes
return curPos, nil
}
// PickleLibOlm encodes the chain key into target. target has to have a size of at least PickleLen() and is written to from index 0.
// It returns the number of bytes written.
func (r chainKey) PickleLibOlm(target []byte) (int, error) {
if len(target) < r.PickleLen() {
return 0, fmt.Errorf("pickle chain key: %w", goolm.ErrValueTooShort)
}
written, err := r.Key.PickleLibOlm(target)
if err != nil {
return 0, fmt.Errorf("pickle chain key: %w", err)
}
written += libolmpickle.PickleUInt32(r.Index, target[written:])
return written, nil
}
// PickleLen returns the number of bytes the pickled chain key will have.
func (r chainKey) PickleLen() int {
length := r.Key.PickleLen()
length += libolmpickle.PickleUInt32Len(r.Index)
return length
}
// senderChain is a chain for sending messages
type senderChain struct {
RKey crypto.Curve25519KeyPair `json:"ratchet_key"`
CKey chainKey `json:"chain_key"`
IsSet bool `json:"set"`
}
// newSenderChain returns a sender chain initialized with chainKey and ratchet key pair.
func newSenderChain(key crypto.Curve25519PublicKey, ratchet crypto.Curve25519KeyPair) *senderChain {
return &senderChain{
RKey: ratchet,
CKey: chainKey{
Index: 0,
Key: key,
},
IsSet: true,
}
}
// advance advances the chain
func (s *senderChain) advance() {
s.CKey.advance()
}
// ratchetKey returns the ratchet key pair.
func (s senderChain) ratchetKey() crypto.Curve25519KeyPair {
return s.RKey
}
// chainKey returns the current chainKey.
func (s senderChain) chainKey() chainKey {
return s.CKey
}
// UnpickleLibOlm decodes the unencryted value and populates the chain accordingly. It returns the number of bytes read.
func (r *senderChain) UnpickleLibOlm(value []byte) (int, error) {
curPos := 0
readBytes, err := r.RKey.UnpickleLibOlm(value)
if err != nil {
return 0, err
}
curPos += readBytes
readBytes, err = r.CKey.UnpickleLibOlm(value[curPos:])
if err != nil {
return 0, err
}
curPos += readBytes
return curPos, nil
}
// PickleLibOlm encodes the chain into target. target has to have a size of at least PickleLen() and is written to from index 0.
// It returns the number of bytes written.
func (r senderChain) PickleLibOlm(target []byte) (int, error) {
if len(target) < r.PickleLen() {
return 0, fmt.Errorf("pickle sender chain: %w", goolm.ErrValueTooShort)
}
written, err := r.RKey.PickleLibOlm(target)
if err != nil {
return 0, fmt.Errorf("pickle sender chain: %w", err)
}
writtenChain, err := r.CKey.PickleLibOlm(target[written:])
if err != nil {
return 0, fmt.Errorf("pickle sender chain: %w", err)
}
written += writtenChain
return written, nil
}
// PickleLen returns the number of bytes the pickled chain will have.
func (r senderChain) PickleLen() int {
length := r.RKey.PickleLen()
length += r.CKey.PickleLen()
return length
}
// senderChain is a chain for receiving messages
type receiverChain struct {
RKey crypto.Curve25519PublicKey `json:"ratchet_key"`
CKey chainKey `json:"chain_key"`
}
// newReceiverChain returns a receiver chain initialized with chainKey and ratchet public key.
func newReceiverChain(chain crypto.Curve25519PublicKey, ratchet crypto.Curve25519PublicKey) *receiverChain {
return &receiverChain{
RKey: ratchet,
CKey: chainKey{
Index: 0,
Key: chain,
},
}
}
// advance advances the chain
func (s *receiverChain) advance() {
s.CKey.advance()
}
// ratchetKey returns the ratchet public key.
func (s receiverChain) ratchetKey() crypto.Curve25519PublicKey {
return s.RKey
}
// chainKey returns the current chainKey.
func (s receiverChain) chainKey() chainKey {
return s.CKey
}
// UnpickleLibOlm decodes the unencryted value and populates the chain accordingly. It returns the number of bytes read.
func (r *receiverChain) UnpickleLibOlm(value []byte) (int, error) {
curPos := 0
readBytes, err := r.RKey.UnpickleLibOlm(value)
if err != nil {
return 0, err
}
curPos += readBytes
readBytes, err = r.CKey.UnpickleLibOlm(value[curPos:])
if err != nil {
return 0, err
}
curPos += readBytes
return curPos, nil
}
// PickleLibOlm encodes the chain into target. target has to have a size of at least PickleLen() and is written to from index 0.
// It returns the number of bytes written.
func (r receiverChain) PickleLibOlm(target []byte) (int, error) {
if len(target) < r.PickleLen() {
return 0, fmt.Errorf("pickle sender chain: %w", goolm.ErrValueTooShort)
}
written, err := r.RKey.PickleLibOlm(target)
if err != nil {
return 0, fmt.Errorf("pickle sender chain: %w", err)
}
writtenChain, err := r.CKey.PickleLibOlm(target)
if err != nil {
return 0, fmt.Errorf("pickle sender chain: %w", err)
}
written += writtenChain
return written, nil
}
// PickleLen returns the number of bytes the pickled chain will have.
func (r receiverChain) PickleLen() int {
length := r.RKey.PickleLen()
length += r.CKey.PickleLen()
return length
}
// messageKey wraps the index and the key of a message
type messageKey struct {
Index uint32 `json:"index"`
Key []byte `json:"key"`
}
// UnpickleLibOlm decodes the unencryted value and populates the message key accordingly. It returns the number of bytes read.
func (m *messageKey) UnpickleLibOlm(value []byte) (int, error) {
curPos := 0
ratchetKey, readBytes, err := libolmpickle.UnpickleBytes(value, messageKeyLength)
if err != nil {
return 0, err
}
m.Key = ratchetKey
curPos += readBytes
keyID, readBytes, err := libolmpickle.UnpickleUInt32(value[:curPos])
if err != nil {
return 0, err
}
curPos += readBytes
m.Index = keyID
return curPos, nil
}
// PickleLibOlm encodes the message key into target. target has to have a size of at least PickleLen() and is written to from index 0.
// It returns the number of bytes written.
func (m messageKey) PickleLibOlm(target []byte) (int, error) {
if len(target) < m.PickleLen() {
return 0, fmt.Errorf("pickle message key: %w", goolm.ErrValueTooShort)
}
written := 0
if len(m.Key) != messageKeyLength {
written += libolmpickle.PickleBytes(make([]byte, messageKeyLength), target)
} else {
written += libolmpickle.PickleBytes(m.Key, target)
}
written += libolmpickle.PickleUInt32(m.Index, target[written:])
return written, nil
}
// PickleLen returns the number of bytes the pickled message key will have.
func (r messageKey) PickleLen() int {
length := libolmpickle.PickleBytesLen(make([]byte, messageKeyLength))
length += libolmpickle.PickleUInt32Len(r.Index)
return length
}

432
vendor/maunium.net/go/mautrix/crypto/goolm/olm/olm.go generated vendored Normal file
View File

@@ -0,0 +1,432 @@
// olm provides the ratchet used by the olm protocol
package olm
import (
"fmt"
"io"
"maunium.net/go/mautrix/crypto/goolm"
"maunium.net/go/mautrix/crypto/goolm/cipher"
"maunium.net/go/mautrix/crypto/goolm/crypto"
"maunium.net/go/mautrix/crypto/goolm/libolmpickle"
"maunium.net/go/mautrix/crypto/goolm/message"
"maunium.net/go/mautrix/crypto/goolm/utilities"
)
const (
olmPickleVersion uint8 = 1
)
const (
maxReceiverChains = 5
maxSkippedMessageKeys = 40
protocolVersion = 3
messageKeySeed = 0x01
maxMessageGap = 2000
sharedKeyLength = 32
)
// KdfInfo has the infos used for the kdf
var KdfInfo = struct {
Root []byte
Ratchet []byte
}{
Root: []byte("OLM_ROOT"),
Ratchet: []byte("OLM_RATCHET"),
}
var RatchetCipher = cipher.NewAESSHA256([]byte("OLM_KEYS"))
// Ratchet represents the olm ratchet as described in
//
// https://gitlab.matrix.org/matrix-org/olm/-/blob/master/docs/olm.md
type Ratchet struct {
// The root key is used to generate chain keys from the ephemeral keys.
// A new root_key is derived each time a new chain is started.
RootKey crypto.Curve25519PublicKey `json:"root_key"`
// The sender chain is used to send messages. Each time a new ephemeral
// key is received from the remote server we generate a new sender chain
// with a new ephemeral key when we next send a message.
SenderChains senderChain `json:"sender_chain"`
// The receiver chain is used to decrypt received messages. We store the
// last few chains so we can decrypt any out of order messages we haven't
// received yet.
// New chains are prepended for easier access.
ReceiverChains []receiverChain `json:"receiver_chains"`
// Storing the keys of missed messages for future use.
// The order of the elements is not important.
SkippedMessageKeys []skippedMessageKey `json:"skipped_message_keys"`
}
// New creates a new ratchet, setting the kdfInfos and cipher.
func New() *Ratchet {
r := &Ratchet{}
return r
}
// InitializeAsBob initializes this ratchet from a receiving point of view (only first message).
func (r *Ratchet) InitializeAsBob(sharedSecret []byte, theirRatchetKey crypto.Curve25519PublicKey) error {
derivedSecretsReader := crypto.HKDFSHA256(sharedSecret, nil, KdfInfo.Root)
derivedSecrets := make([]byte, 2*sharedKeyLength)
if _, err := io.ReadFull(derivedSecretsReader, derivedSecrets); err != nil {
return err
}
r.RootKey = derivedSecrets[0:sharedKeyLength]
newReceiverChain := newReceiverChain(derivedSecrets[sharedKeyLength:], theirRatchetKey)
r.ReceiverChains = append([]receiverChain{*newReceiverChain}, r.ReceiverChains...)
return nil
}
// InitializeAsAlice initializes this ratchet from a sending point of view (only first message).
func (r *Ratchet) InitializeAsAlice(sharedSecret []byte, ourRatchetKey crypto.Curve25519KeyPair) error {
derivedSecretsReader := crypto.HKDFSHA256(sharedSecret, nil, KdfInfo.Root)
derivedSecrets := make([]byte, 2*sharedKeyLength)
if _, err := io.ReadFull(derivedSecretsReader, derivedSecrets); err != nil {
return err
}
r.RootKey = derivedSecrets[0:sharedKeyLength]
newSenderChain := newSenderChain(derivedSecrets[sharedKeyLength:], ourRatchetKey)
r.SenderChains = *newSenderChain
return nil
}
// Encrypt encrypts the message in a message.Message with MAC. If reader is nil, crypto/rand is used for key generations.
func (r *Ratchet) Encrypt(plaintext []byte, reader io.Reader) ([]byte, error) {
var err error
if !r.SenderChains.IsSet {
newRatchetKey, err := crypto.Curve25519GenerateKey(reader)
if err != nil {
return nil, err
}
newChainKey, err := r.advanceRootKey(newRatchetKey, r.ReceiverChains[0].ratchetKey())
if err != nil {
return nil, err
}
newSenderChain := newSenderChain(newChainKey, newRatchetKey)
r.SenderChains = *newSenderChain
}
messageKey := r.createMessageKeys(r.SenderChains.chainKey())
r.SenderChains.advance()
encryptedText, err := RatchetCipher.Encrypt(messageKey.Key, plaintext)
if err != nil {
return nil, fmt.Errorf("cipher encrypt: %w", err)
}
message := &message.Message{}
message.Version = protocolVersion
message.Counter = messageKey.Index
message.RatchetKey = r.SenderChains.ratchetKey().PublicKey
message.Ciphertext = encryptedText
//creating the mac is done in encode
output, err := message.EncodeAndMAC(messageKey.Key, RatchetCipher)
if err != nil {
return nil, err
}
return output, nil
}
// Decrypt decrypts the ciphertext and verifies the MAC. If reader is nil, crypto/rand is used for key generations.
func (r *Ratchet) Decrypt(input []byte) ([]byte, error) {
message := &message.Message{}
//The mac is not verified here, as we do not know the key yet
err := message.Decode(input)
if err != nil {
return nil, err
}
if message.Version != protocolVersion {
return nil, fmt.Errorf("decrypt: %w", goolm.ErrWrongProtocolVersion)
}
if !message.HasCounter || len(message.RatchetKey) == 0 || len(message.Ciphertext) == 0 {
return nil, fmt.Errorf("decrypt: %w", goolm.ErrBadMessageFormat)
}
var receiverChainFromMessage *receiverChain
for curChainIndex := range r.ReceiverChains {
if r.ReceiverChains[curChainIndex].ratchetKey().Equal(message.RatchetKey) {
receiverChainFromMessage = &r.ReceiverChains[curChainIndex]
break
}
}
var result []byte
if receiverChainFromMessage == nil {
//Advancing the chain is done in this method
result, err = r.decryptForNewChain(message, input)
if err != nil {
return nil, err
}
} else if receiverChainFromMessage.chainKey().Index > message.Counter {
// No need to advance the chain
// Chain already advanced beyond the key for this message
// Check if the message keys are in the skipped key list.
foundSkippedKey := false
for curSkippedIndex := range r.SkippedMessageKeys {
if message.Counter == r.SkippedMessageKeys[curSkippedIndex].MKey.Index {
// Found the key for this message. Check the MAC.
verified, err := message.VerifyMACInline(r.SkippedMessageKeys[curSkippedIndex].MKey.Key, RatchetCipher, input)
if err != nil {
return nil, err
}
if !verified {
return nil, fmt.Errorf("decrypt from skipped message keys: %w", goolm.ErrBadMAC)
}
result, err = RatchetCipher.Decrypt(r.SkippedMessageKeys[curSkippedIndex].MKey.Key, message.Ciphertext)
if err != nil {
return nil, fmt.Errorf("cipher decrypt: %w", err)
}
if len(result) != 0 {
// Remove the key from the skipped keys now that we've
// decoded the message it corresponds to.
r.SkippedMessageKeys[curSkippedIndex] = r.SkippedMessageKeys[len(r.SkippedMessageKeys)-1]
r.SkippedMessageKeys = r.SkippedMessageKeys[:len(r.SkippedMessageKeys)-1]
}
foundSkippedKey = true
}
}
if !foundSkippedKey {
return nil, fmt.Errorf("decrypt: %w", goolm.ErrMessageKeyNotFound)
}
} else {
//Advancing the chain is done in this method
result, err = r.decryptForExistingChain(receiverChainFromMessage, message, input)
if err != nil {
return nil, err
}
}
return result, nil
}
// advanceRootKey created the next root key and returns the next chainKey
func (r *Ratchet) advanceRootKey(newRatchetKey crypto.Curve25519KeyPair, oldRatchetKey crypto.Curve25519PublicKey) (crypto.Curve25519PublicKey, error) {
sharedSecret, err := newRatchetKey.SharedSecret(oldRatchetKey)
if err != nil {
return nil, err
}
derivedSecretsReader := crypto.HKDFSHA256(sharedSecret, r.RootKey, KdfInfo.Ratchet)
derivedSecrets := make([]byte, 2*sharedKeyLength)
if _, err := io.ReadFull(derivedSecretsReader, derivedSecrets); err != nil {
return nil, err
}
r.RootKey = derivedSecrets[:sharedKeyLength]
return derivedSecrets[sharedKeyLength:], nil
}
// createMessageKeys returns the messageKey derived from the chainKey
func (r Ratchet) createMessageKeys(chainKey chainKey) messageKey {
res := messageKey{}
res.Key = crypto.HMACSHA256(chainKey.Key, []byte{messageKeySeed})
res.Index = chainKey.Index
return res
}
// decryptForExistingChain returns the decrypted message by using the chain. The MAC of the rawMessage is verified.
func (r *Ratchet) decryptForExistingChain(chain *receiverChain, message *message.Message, rawMessage []byte) ([]byte, error) {
if message.Counter < chain.CKey.Index {
return nil, fmt.Errorf("decrypt: %w", goolm.ErrChainTooHigh)
}
// Limit the number of hashes we're prepared to compute
if message.Counter-chain.CKey.Index > maxMessageGap {
return nil, fmt.Errorf("decrypt from existing chain: %w", goolm.ErrMsgIndexTooHigh)
}
for chain.CKey.Index < message.Counter {
messageKey := r.createMessageKeys(chain.chainKey())
skippedKey := skippedMessageKey{
MKey: messageKey,
RKey: chain.ratchetKey(),
}
r.SkippedMessageKeys = append(r.SkippedMessageKeys, skippedKey)
chain.advance()
}
messageKey := r.createMessageKeys(chain.chainKey())
chain.advance()
verified, err := message.VerifyMACInline(messageKey.Key, RatchetCipher, rawMessage)
if err != nil {
return nil, err
}
if !verified {
return nil, fmt.Errorf("decrypt from existing chain: %w", goolm.ErrBadMAC)
}
return RatchetCipher.Decrypt(messageKey.Key, message.Ciphertext)
}
// decryptForNewChain returns the decrypted message by creating a new chain and advancing the root key.
func (r *Ratchet) decryptForNewChain(message *message.Message, rawMessage []byte) ([]byte, error) {
// They shouldn't move to a new chain until we've sent them a message
// acknowledging the last one
if !r.SenderChains.IsSet {
return nil, fmt.Errorf("decrypt for new chain: %w", goolm.ErrProtocolViolation)
}
// Limit the number of hashes we're prepared to compute
if message.Counter > maxMessageGap {
return nil, fmt.Errorf("decrypt for new chain: %w", goolm.ErrMsgIndexTooHigh)
}
newChainKey, err := r.advanceRootKey(r.SenderChains.ratchetKey(), message.RatchetKey)
if err != nil {
return nil, err
}
newChain := newReceiverChain(newChainKey, message.RatchetKey)
r.ReceiverChains = append([]receiverChain{*newChain}, r.ReceiverChains...)
/*
They have started using a new ephemeral ratchet key.
We needed to derive a new set of chain keys.
We can discard our previous ephemeral ratchet key.
We will generate a new key when we send the next message.
*/
r.SenderChains = senderChain{}
decrypted, err := r.decryptForExistingChain(&r.ReceiverChains[0], message, rawMessage)
if err != nil {
return nil, err
}
return decrypted, nil
}
// PickleAsJSON returns a ratchet as a base64 string encrypted using the supplied key. The unencrypted representation of the Account is in JSON format.
func (r Ratchet) PickleAsJSON(key []byte) ([]byte, error) {
return utilities.PickleAsJSON(r, olmPickleVersion, key)
}
// UnpickleAsJSON updates a ratchet by a base64 encrypted string using the supplied key. The unencrypted representation has to be in JSON format.
func (r *Ratchet) UnpickleAsJSON(pickled, key []byte) error {
return utilities.UnpickleAsJSON(r, pickled, key, olmPickleVersion)
}
// UnpickleLibOlm decodes the unencryted value and populates the Ratchet accordingly. It returns the number of bytes read.
func (r *Ratchet) UnpickleLibOlm(value []byte, includesChainIndex bool) (int, error) {
//read ratchet data
curPos := 0
readBytes, err := r.RootKey.UnpickleLibOlm(value)
if err != nil {
return 0, err
}
curPos += readBytes
countSenderChains, readBytes, err := libolmpickle.UnpickleUInt32(value[curPos:]) //Length of sender chain
if err != nil {
return 0, err
}
curPos += readBytes
for i := uint32(0); i < countSenderChains; i++ {
if i == 0 {
//only first is stored
readBytes, err := r.SenderChains.UnpickleLibOlm(value[curPos:])
if err != nil {
return 0, err
}
curPos += readBytes
r.SenderChains.IsSet = true
} else {
dummy := senderChain{}
readBytes, err := dummy.UnpickleLibOlm(value[curPos:])
if err != nil {
return 0, err
}
curPos += readBytes
}
}
countReceivChains, readBytes, err := libolmpickle.UnpickleUInt32(value[curPos:]) //Length of recevier chain
if err != nil {
return 0, err
}
curPos += readBytes
r.ReceiverChains = make([]receiverChain, countReceivChains)
for i := uint32(0); i < countReceivChains; i++ {
readBytes, err := r.ReceiverChains[i].UnpickleLibOlm(value[curPos:])
if err != nil {
return 0, err
}
curPos += readBytes
}
countSkippedMessageKeys, readBytes, err := libolmpickle.UnpickleUInt32(value[curPos:]) //Length of skippedMessageKeys
if err != nil {
return 0, err
}
curPos += readBytes
r.SkippedMessageKeys = make([]skippedMessageKey, countSkippedMessageKeys)
for i := uint32(0); i < countSkippedMessageKeys; i++ {
readBytes, err := r.SkippedMessageKeys[i].UnpickleLibOlm(value[curPos:])
if err != nil {
return 0, err
}
curPos += readBytes
}
// pickle v 0x80000001 includes a chain index; pickle v1 does not.
if includesChainIndex {
_, readBytes, err := libolmpickle.UnpickleUInt32(value[curPos:])
if err != nil {
return 0, err
}
curPos += readBytes
}
return curPos, nil
}
// PickleLibOlm encodes the ratchet into target. target has to have a size of at least PickleLen() and is written to from index 0.
// It returns the number of bytes written.
func (r Ratchet) PickleLibOlm(target []byte) (int, error) {
if len(target) < r.PickleLen() {
return 0, fmt.Errorf("pickle ratchet: %w", goolm.ErrValueTooShort)
}
written, err := r.RootKey.PickleLibOlm(target)
if err != nil {
return 0, fmt.Errorf("pickle ratchet: %w", err)
}
if r.SenderChains.IsSet {
written += libolmpickle.PickleUInt32(1, target[written:]) //Length of sender chain, always 1
writtenSender, err := r.SenderChains.PickleLibOlm(target[written:])
if err != nil {
return 0, fmt.Errorf("pickle ratchet: %w", err)
}
written += writtenSender
} else {
written += libolmpickle.PickleUInt32(0, target[written:]) //Length of sender chain
}
written += libolmpickle.PickleUInt32(uint32(len(r.ReceiverChains)), target[written:])
for _, curChain := range r.ReceiverChains {
writtenChain, err := curChain.PickleLibOlm(target[written:])
if err != nil {
return 0, fmt.Errorf("pickle ratchet: %w", err)
}
written += writtenChain
}
written += libolmpickle.PickleUInt32(uint32(len(r.SkippedMessageKeys)), target[written:])
for _, curChain := range r.SkippedMessageKeys {
writtenChain, err := curChain.PickleLibOlm(target[written:])
if err != nil {
return 0, fmt.Errorf("pickle ratchet: %w", err)
}
written += writtenChain
}
return written, nil
}
// PickleLen returns the actual number of bytes the pickled ratchet will have.
func (r Ratchet) PickleLen() int {
length := r.RootKey.PickleLen()
if r.SenderChains.IsSet {
length += libolmpickle.PickleUInt32Len(1)
length += r.SenderChains.PickleLen()
} else {
length += libolmpickle.PickleUInt32Len(0)
}
length += libolmpickle.PickleUInt32Len(uint32(len(r.ReceiverChains)))
length += len(r.ReceiverChains) * receiverChain{}.PickleLen()
length += libolmpickle.PickleUInt32Len(uint32(len(r.SkippedMessageKeys)))
length += len(r.SkippedMessageKeys) * skippedMessageKey{}.PickleLen()
return length
}
// PickleLen returns the minimum number of bytes the pickled ratchet must have.
func (r Ratchet) PickleLenMin() int {
length := r.RootKey.PickleLen()
length += libolmpickle.PickleUInt32Len(0)
length += libolmpickle.PickleUInt32Len(0)
length += libolmpickle.PickleUInt32Len(0)
return length
}

View File

@@ -0,0 +1,55 @@
package olm
import (
"fmt"
"maunium.net/go/mautrix/crypto/goolm"
"maunium.net/go/mautrix/crypto/goolm/crypto"
)
// skippedMessageKey stores a skipped message key
type skippedMessageKey struct {
RKey crypto.Curve25519PublicKey `json:"ratchet_key"`
MKey messageKey `json:"message_key"`
}
// UnpickleLibOlm decodes the unencryted value and populates the chain accordingly. It returns the number of bytes read.
func (r *skippedMessageKey) UnpickleLibOlm(value []byte) (int, error) {
curPos := 0
readBytes, err := r.RKey.UnpickleLibOlm(value)
if err != nil {
return 0, err
}
curPos += readBytes
readBytes, err = r.MKey.UnpickleLibOlm(value[curPos:])
if err != nil {
return 0, err
}
curPos += readBytes
return curPos, nil
}
// PickleLibOlm encodes the chain into target. target has to have a size of at least PickleLen() and is written to from index 0.
// It returns the number of bytes written.
func (r skippedMessageKey) PickleLibOlm(target []byte) (int, error) {
if len(target) < r.PickleLen() {
return 0, fmt.Errorf("pickle sender chain: %w", goolm.ErrValueTooShort)
}
written, err := r.RKey.PickleLibOlm(target)
if err != nil {
return 0, fmt.Errorf("pickle sender chain: %w", err)
}
writtenChain, err := r.MKey.PickleLibOlm(target)
if err != nil {
return 0, fmt.Errorf("pickle sender chain: %w", err)
}
written += writtenChain
return written, nil
}
// PickleLen returns the number of bytes the pickled chain will have.
func (r skippedMessageKey) PickleLen() int {
length := r.RKey.PickleLen()
length += r.MKey.PickleLen()
return length
}

View File

@@ -0,0 +1,165 @@
package pk
import (
"encoding/base64"
"errors"
"fmt"
"maunium.net/go/mautrix/crypto/goolm"
"maunium.net/go/mautrix/crypto/goolm/cipher"
"maunium.net/go/mautrix/crypto/goolm/crypto"
"maunium.net/go/mautrix/crypto/goolm/libolmpickle"
"maunium.net/go/mautrix/crypto/goolm/utilities"
"maunium.net/go/mautrix/id"
)
const (
decryptionPickleVersionJSON uint8 = 1
decryptionPickleVersionLibOlm uint32 = 1
)
// Decryption is used to decrypt pk messages
type Decryption struct {
KeyPair crypto.Curve25519KeyPair `json:"key_pair"`
}
// NewDecryption returns a new Decryption with a new generated key pair.
func NewDecryption() (*Decryption, error) {
keyPair, err := crypto.Curve25519GenerateKey(nil)
if err != nil {
return nil, err
}
return &Decryption{
KeyPair: keyPair,
}, nil
}
// NewDescriptionFromPrivate resturns a new Decryption with the private key fixed.
func NewDecryptionFromPrivate(privateKey crypto.Curve25519PrivateKey) (*Decryption, error) {
s := &Decryption{}
keyPair, err := crypto.Curve25519GenerateFromPrivate(privateKey)
if err != nil {
return nil, err
}
s.KeyPair = keyPair
return s, nil
}
// PubKey returns the public key base 64 encoded.
func (s Decryption) PubKey() id.Curve25519 {
return s.KeyPair.B64Encoded()
}
// PrivateKey returns the private key.
func (s Decryption) PrivateKey() crypto.Curve25519PrivateKey {
return s.KeyPair.PrivateKey
}
// Decrypt decrypts the ciphertext and verifies the MAC. The base64 encoded key is used to construct the shared secret.
func (s Decryption) Decrypt(ciphertext, mac []byte, key id.Curve25519) ([]byte, error) {
keyDecoded, err := base64.RawStdEncoding.DecodeString(string(key))
if err != nil {
return nil, err
}
sharedSecret, err := s.KeyPair.SharedSecret(keyDecoded)
if err != nil {
return nil, err
}
decodedMAC, err := goolm.Base64Decode(mac)
if err != nil {
return nil, err
}
cipher := cipher.NewAESSHA256(nil)
verified, err := cipher.Verify(sharedSecret, ciphertext, decodedMAC)
if err != nil {
return nil, err
}
if !verified {
return nil, fmt.Errorf("decrypt: %w", goolm.ErrBadMAC)
}
plaintext, err := cipher.Decrypt(sharedSecret, ciphertext)
if err != nil {
return nil, err
}
return plaintext, nil
}
// PickleAsJSON returns an Decryption as a base64 string encrypted using the supplied key. The unencrypted representation of the Account is in JSON format.
func (a Decryption) PickleAsJSON(key []byte) ([]byte, error) {
return utilities.PickleAsJSON(a, decryptionPickleVersionJSON, key)
}
// UnpickleAsJSON updates an Decryption by a base64 encrypted string using the supplied key. The unencrypted representation has to be in JSON format.
func (a *Decryption) UnpickleAsJSON(pickled, key []byte) error {
return utilities.UnpickleAsJSON(a, pickled, key, decryptionPickleVersionJSON)
}
// Unpickle decodes the base64 encoded string and decrypts the result with the key.
// The decrypted value is then passed to UnpickleLibOlm.
func (a *Decryption) Unpickle(pickled, key []byte) error {
decrypted, err := cipher.Unpickle(key, pickled)
if err != nil {
return err
}
_, err = a.UnpickleLibOlm(decrypted)
return err
}
// UnpickleLibOlm decodes the unencryted value and populates the Decryption accordingly. It returns the number of bytes read.
func (a *Decryption) UnpickleLibOlm(value []byte) (int, error) {
//First 4 bytes are the accountPickleVersion
pickledVersion, curPos, err := libolmpickle.UnpickleUInt32(value)
if err != nil {
return 0, err
}
switch pickledVersion {
case decryptionPickleVersionLibOlm:
default:
return 0, fmt.Errorf("unpickle olmSession: %w", goolm.ErrBadVersion)
}
readBytes, err := a.KeyPair.UnpickleLibOlm(value[curPos:])
if err != nil {
return 0, err
}
curPos += readBytes
return curPos, nil
}
// Pickle returns a base64 encoded and with key encrypted pickled Decryption using PickleLibOlm().
func (a Decryption) Pickle(key []byte) ([]byte, error) {
pickeledBytes := make([]byte, a.PickleLen())
written, err := a.PickleLibOlm(pickeledBytes)
if err != nil {
return nil, err
}
if written != len(pickeledBytes) {
return nil, errors.New("number of written bytes not correct")
}
encrypted, err := cipher.Pickle(key, pickeledBytes)
if err != nil {
return nil, err
}
return encrypted, nil
}
// PickleLibOlm encodes the Decryption into target. target has to have a size of at least PickleLen() and is written to from index 0.
// It returns the number of bytes written.
func (a Decryption) PickleLibOlm(target []byte) (int, error) {
if len(target) < a.PickleLen() {
return 0, fmt.Errorf("pickle Decryption: %w", goolm.ErrValueTooShort)
}
written := libolmpickle.PickleUInt32(decryptionPickleVersionLibOlm, target)
writtenKey, err := a.KeyPair.PickleLibOlm(target[written:])
if err != nil {
return 0, fmt.Errorf("pickle Decryption: %w", err)
}
written += writtenKey
return written, nil
}
// PickleLen returns the number of bytes the pickled Decryption will have.
func (a Decryption) PickleLen() int {
length := libolmpickle.PickleUInt32Len(decryptionPickleVersionLibOlm)
length += a.KeyPair.PickleLen()
return length
}

View File

@@ -0,0 +1,49 @@
package pk
import (
"encoding/base64"
"maunium.net/go/mautrix/id"
"maunium.net/go/mautrix/crypto/goolm"
"maunium.net/go/mautrix/crypto/goolm/cipher"
"maunium.net/go/mautrix/crypto/goolm/crypto"
)
// Encryption is used to encrypt pk messages
type Encryption struct {
RecipientKey crypto.Curve25519PublicKey `json:"recipient_key"`
}
// NewEncryption returns a new Encryption with the base64 encoded public key of the recipient
func NewEncryption(pubKey id.Curve25519) (*Encryption, error) {
pubKeyDecoded, err := base64.RawStdEncoding.DecodeString(string(pubKey))
if err != nil {
return nil, err
}
return &Encryption{
RecipientKey: pubKeyDecoded,
}, nil
}
// Encrypt encrypts the plaintext with the privateKey and returns the ciphertext and base64 encoded MAC.
func (e Encryption) Encrypt(plaintext []byte, privateKey crypto.Curve25519PrivateKey) (ciphertext, mac []byte, err error) {
keyPair, err := crypto.Curve25519GenerateFromPrivate(privateKey)
if err != nil {
return nil, nil, err
}
sharedSecret, err := keyPair.SharedSecret(e.RecipientKey)
if err != nil {
return nil, nil, err
}
cipher := cipher.NewAESSHA256(nil)
ciphertext, err = cipher.Encrypt(sharedSecret, plaintext)
if err != nil {
return nil, nil, err
}
mac, err = cipher.MAC(sharedSecret, ciphertext)
if err != nil {
return nil, nil, err
}
return ciphertext, goolm.Base64Encode(mac), nil
}

View File

@@ -0,0 +1,44 @@
package pk
import (
"crypto/rand"
"maunium.net/go/mautrix/crypto/goolm"
"maunium.net/go/mautrix/crypto/goolm/crypto"
"maunium.net/go/mautrix/id"
)
// Signing is used for signing a pk
type Signing struct {
KeyPair crypto.Ed25519KeyPair `json:"key_pair"`
Seed []byte `json:"seed"`
}
// NewSigningFromSeed constructs a new Signing based on a seed.
func NewSigningFromSeed(seed []byte) (*Signing, error) {
s := &Signing{}
s.Seed = seed
s.KeyPair = crypto.Ed25519GenerateFromSeed(seed)
return s, nil
}
// NewSigning returns a Signing based on a random seed
func NewSigning() (*Signing, error) {
seed := make([]byte, 32)
_, err := rand.Read(seed)
if err != nil {
return nil, err
}
return NewSigningFromSeed(seed)
}
// Sign returns the signature of the message base64 encoded.
func (s Signing) Sign(message []byte) []byte {
signature := s.KeyPair.Sign(message)
return goolm.Base64Encode(signature)
}
// PublicKey returns the public key of the key pair base 64 encoded.
func (s Signing) PublicKey() id.Ed25519 {
return s.KeyPair.B64Encoded()
}

76
vendor/maunium.net/go/mautrix/crypto/goolm/sas/main.go generated vendored Normal file
View File

@@ -0,0 +1,76 @@
// sas provides the means to do SAS between keys
package sas
import (
"io"
"maunium.net/go/mautrix/crypto/goolm"
"maunium.net/go/mautrix/crypto/goolm/crypto"
)
// SAS contains the key pair and secret for SAS.
type SAS struct {
KeyPair crypto.Curve25519KeyPair
Secret []byte
}
// New creates a new SAS with a new key pair.
func New() (*SAS, error) {
kp, err := crypto.Curve25519GenerateKey(nil)
if err != nil {
return nil, err
}
s := &SAS{
KeyPair: kp,
}
return s, nil
}
// GetPubkey returns the public key of the key pair base64 encoded
func (s SAS) GetPubkey() []byte {
return goolm.Base64Encode(s.KeyPair.PublicKey)
}
// SetTheirKey sets the key of the other party and computes the shared secret.
func (s *SAS) SetTheirKey(key []byte) error {
keyDecoded, err := goolm.Base64Decode(key)
if err != nil {
return err
}
sharedSecret, err := s.KeyPair.SharedSecret(keyDecoded)
if err != nil {
return err
}
s.Secret = sharedSecret
return nil
}
// GenerateBytes creates length bytes from the shared secret and info.
func (s SAS) GenerateBytes(info []byte, length uint) ([]byte, error) {
byteReader := crypto.HKDFSHA256(s.Secret, nil, info)
output := make([]byte, length)
if _, err := io.ReadFull(byteReader, output); err != nil {
return nil, err
}
return output, nil
}
// calculateMAC returns a base64 encoded MAC of input.
func (s *SAS) calculateMAC(input, info []byte, length uint) ([]byte, error) {
key, err := s.GenerateBytes(info, length)
if err != nil {
return nil, err
}
mac := crypto.HMACSHA256(key, input)
return goolm.Base64Encode(mac), nil
}
// CalculateMACFixes returns a base64 encoded, 32 byte long MAC of input.
func (s SAS) CalculateMAC(input, info []byte) ([]byte, error) {
return s.calculateMAC(input, info, 32)
}
// CalculateMACLongKDF returns a base64 encoded, 256 byte long MAC of input.
func (s SAS) CalculateMACLongKDF(input, info []byte) ([]byte, error) {
return s.calculateMAC(input, info, 256)
}

View File

@@ -0,0 +1,2 @@
// session provides the different types of sessions for en/decrypting of messages
package session

View File

@@ -0,0 +1,276 @@
package session
import (
"encoding/base64"
"errors"
"fmt"
"maunium.net/go/mautrix/crypto/goolm"
"maunium.net/go/mautrix/crypto/goolm/cipher"
"maunium.net/go/mautrix/crypto/goolm/crypto"
"maunium.net/go/mautrix/crypto/goolm/libolmpickle"
"maunium.net/go/mautrix/crypto/goolm/megolm"
"maunium.net/go/mautrix/crypto/goolm/message"
"maunium.net/go/mautrix/crypto/goolm/utilities"
"maunium.net/go/mautrix/id"
)
const (
megolmInboundSessionPickleVersionJSON byte = 1
megolmInboundSessionPickleVersionLibOlm uint32 = 2
)
// MegolmInboundSession stores information about the sessions of receive.
type MegolmInboundSession struct {
Ratchet megolm.Ratchet `json:"ratchet"`
SigningKey crypto.Ed25519PublicKey `json:"signing_key"`
InitialRatchet megolm.Ratchet `json:"initial_ratchet"`
SigningKeyVerified bool `json:"signing_key_verified"` //not used for now
}
// NewMegolmInboundSession creates a new MegolmInboundSession from a base64 encoded session sharing message.
func NewMegolmInboundSession(input []byte) (*MegolmInboundSession, error) {
var err error
input, err = goolm.Base64Decode(input)
if err != nil {
return nil, err
}
msg := message.MegolmSessionSharing{}
err = msg.VerifyAndDecode(input)
if err != nil {
return nil, err
}
o := &MegolmInboundSession{}
o.SigningKey = msg.PublicKey
o.SigningKeyVerified = true
ratchet, err := megolm.New(msg.Counter, msg.RatchetData)
if err != nil {
return nil, err
}
o.Ratchet = *ratchet
o.InitialRatchet = *ratchet
return o, nil
}
// NewMegolmInboundSessionFromExport creates a new MegolmInboundSession from a base64 encoded session export message.
func NewMegolmInboundSessionFromExport(input []byte) (*MegolmInboundSession, error) {
var err error
input, err = goolm.Base64Decode(input)
if err != nil {
return nil, err
}
msg := message.MegolmSessionExport{}
err = msg.Decode(input)
if err != nil {
return nil, err
}
o := &MegolmInboundSession{}
o.SigningKey = msg.PublicKey
ratchet, err := megolm.New(msg.Counter, msg.RatchetData)
if err != nil {
return nil, err
}
o.Ratchet = *ratchet
o.InitialRatchet = *ratchet
return o, nil
}
// MegolmInboundSessionFromPickled loads the MegolmInboundSession details from a pickled base64 string. The input is decrypted with the supplied key.
func MegolmInboundSessionFromPickled(pickled, key []byte) (*MegolmInboundSession, error) {
if len(pickled) == 0 {
return nil, fmt.Errorf("megolmInboundSessionFromPickled: %w", goolm.ErrEmptyInput)
}
a := &MegolmInboundSession{}
err := a.Unpickle(pickled, key)
if err != nil {
return nil, err
}
return a, nil
}
// getRatchet tries to find the correct ratchet for a messageIndex.
func (o MegolmInboundSession) getRatchet(messageIndex uint32) (*megolm.Ratchet, error) {
// pick a megolm instance to use. if we are at or beyond the latest ratchet value, use that
if (messageIndex - o.Ratchet.Counter) < uint32(1<<31) {
o.Ratchet.AdvanceTo(messageIndex)
return &o.Ratchet, nil
}
if (messageIndex - o.InitialRatchet.Counter) >= uint32(1<<31) {
// the counter is before our initial ratchet - we can't decode this
return nil, fmt.Errorf("decrypt: %w", goolm.ErrRatchetNotAvailable)
}
// otherwise, start from the initial ratchet. Take a copy so that we don't overwrite the initial ratchet
copiedRatchet := o.InitialRatchet
copiedRatchet.AdvanceTo(messageIndex)
return &copiedRatchet, nil
}
// Decrypt decrypts a base64 encoded group message.
func (o *MegolmInboundSession) Decrypt(ciphertext []byte) ([]byte, uint32, error) {
if o.SigningKey == nil {
return nil, 0, fmt.Errorf("decrypt: %w", goolm.ErrBadMessageFormat)
}
decoded, err := goolm.Base64Decode(ciphertext)
if err != nil {
return nil, 0, err
}
msg := &message.GroupMessage{}
err = msg.Decode(decoded)
if err != nil {
return nil, 0, err
}
if msg.Version != protocolVersion {
return nil, 0, fmt.Errorf("decrypt: %w", goolm.ErrWrongProtocolVersion)
}
if msg.Ciphertext == nil || !msg.HasMessageIndex {
return nil, 0, fmt.Errorf("decrypt: %w", goolm.ErrBadMessageFormat)
}
// verify signature
verifiedSignature := msg.VerifySignatureInline(o.SigningKey, decoded)
if !verifiedSignature {
return nil, 0, fmt.Errorf("decrypt: %w", goolm.ErrBadSignature)
}
targetRatch, err := o.getRatchet(msg.MessageIndex)
if err != nil {
return nil, 0, err
}
decrypted, err := targetRatch.Decrypt(decoded, &o.SigningKey, msg)
if err != nil {
return nil, 0, err
}
o.SigningKeyVerified = true
return decrypted, msg.MessageIndex, nil
}
// SessionID returns the base64 endoded signing key
func (o MegolmInboundSession) SessionID() id.SessionID {
return id.SessionID(base64.RawStdEncoding.EncodeToString(o.SigningKey))
}
// PickleAsJSON returns an MegolmInboundSession as a base64 string encrypted using the supplied key. The unencrypted representation of the Account is in JSON format.
func (o MegolmInboundSession) PickleAsJSON(key []byte) ([]byte, error) {
return utilities.PickleAsJSON(o, megolmInboundSessionPickleVersionJSON, key)
}
// UnpickleAsJSON updates an MegolmInboundSession by a base64 encrypted string using the supplied key. The unencrypted representation has to be in JSON format.
func (o *MegolmInboundSession) UnpickleAsJSON(pickled, key []byte) error {
return utilities.UnpickleAsJSON(o, pickled, key, megolmInboundSessionPickleVersionJSON)
}
// SessionExportMessage creates an base64 encoded export of the session.
func (o MegolmInboundSession) SessionExportMessage(messageIndex uint32) ([]byte, error) {
ratchet, err := o.getRatchet(messageIndex)
if err != nil {
return nil, err
}
return ratchet.SessionExportMessage(o.SigningKey)
}
// Unpickle decodes the base64 encoded string and decrypts the result with the key.
// The decrypted value is then passed to UnpickleLibOlm.
func (o *MegolmInboundSession) Unpickle(pickled, key []byte) error {
decrypted, err := cipher.Unpickle(key, pickled)
if err != nil {
return err
}
_, err = o.UnpickleLibOlm(decrypted)
return err
}
// UnpickleLibOlm decodes the unencryted value and populates the Session accordingly. It returns the number of bytes read.
func (o *MegolmInboundSession) UnpickleLibOlm(value []byte) (int, error) {
//First 4 bytes are the accountPickleVersion
pickledVersion, curPos, err := libolmpickle.UnpickleUInt32(value)
if err != nil {
return 0, err
}
switch pickledVersion {
case megolmInboundSessionPickleVersionLibOlm, 1:
default:
return 0, fmt.Errorf("unpickle MegolmInboundSession: %w", goolm.ErrBadVersion)
}
readBytes, err := o.InitialRatchet.UnpickleLibOlm(value[curPos:])
if err != nil {
return 0, err
}
curPos += readBytes
readBytes, err = o.Ratchet.UnpickleLibOlm(value[curPos:])
if err != nil {
return 0, err
}
curPos += readBytes
readBytes, err = o.SigningKey.UnpickleLibOlm(value[curPos:])
if err != nil {
return 0, err
}
curPos += readBytes
if pickledVersion == 1 {
// pickle v1 had no signing_key_verified field (all keyshares were verified at import time)
o.SigningKeyVerified = true
} else {
o.SigningKeyVerified, readBytes, err = libolmpickle.UnpickleBool(value[curPos:])
if err != nil {
return 0, err
}
curPos += readBytes
}
return curPos, nil
}
// Pickle returns a base64 encoded and with key encrypted pickled MegolmInboundSession using PickleLibOlm().
func (o MegolmInboundSession) Pickle(key []byte) ([]byte, error) {
pickeledBytes := make([]byte, o.PickleLen())
written, err := o.PickleLibOlm(pickeledBytes)
if err != nil {
return nil, err
}
if written != len(pickeledBytes) {
return nil, errors.New("number of written bytes not correct")
}
encrypted, err := cipher.Pickle(key, pickeledBytes)
if err != nil {
return nil, err
}
return encrypted, nil
}
// PickleLibOlm encodes the session into target. target has to have a size of at least PickleLen() and is written to from index 0.
// It returns the number of bytes written.
func (o MegolmInboundSession) PickleLibOlm(target []byte) (int, error) {
if len(target) < o.PickleLen() {
return 0, fmt.Errorf("pickle MegolmInboundSession: %w", goolm.ErrValueTooShort)
}
written := libolmpickle.PickleUInt32(megolmInboundSessionPickleVersionLibOlm, target)
writtenInitRatchet, err := o.InitialRatchet.PickleLibOlm(target[written:])
if err != nil {
return 0, fmt.Errorf("pickle MegolmInboundSession: %w", err)
}
written += writtenInitRatchet
writtenRatchet, err := o.Ratchet.PickleLibOlm(target[written:])
if err != nil {
return 0, fmt.Errorf("pickle MegolmInboundSession: %w", err)
}
written += writtenRatchet
writtenPubKey, err := o.SigningKey.PickleLibOlm(target[written:])
if err != nil {
return 0, fmt.Errorf("pickle MegolmInboundSession: %w", err)
}
written += writtenPubKey
written += libolmpickle.PickleBool(o.SigningKeyVerified, target[written:])
return written, nil
}
// PickleLen returns the number of bytes the pickled session will have.
func (o MegolmInboundSession) PickleLen() int {
length := libolmpickle.PickleUInt32Len(megolmInboundSessionPickleVersionLibOlm)
length += o.InitialRatchet.PickleLen()
length += o.Ratchet.PickleLen()
length += o.SigningKey.PickleLen()
length += libolmpickle.PickleBoolLen(o.SigningKeyVerified)
return length
}

View File

@@ -0,0 +1,171 @@
package session
import (
"crypto/rand"
"encoding/base64"
"errors"
"fmt"
"maunium.net/go/mautrix/id"
"maunium.net/go/mautrix/crypto/goolm"
"maunium.net/go/mautrix/crypto/goolm/cipher"
"maunium.net/go/mautrix/crypto/goolm/crypto"
"maunium.net/go/mautrix/crypto/goolm/libolmpickle"
"maunium.net/go/mautrix/crypto/goolm/megolm"
"maunium.net/go/mautrix/crypto/goolm/utilities"
)
const (
megolmOutboundSessionPickleVersion byte = 1
megolmOutboundSessionPickleVersionLibOlm uint32 = 1
)
// MegolmOutboundSession stores information about the sessions to send.
type MegolmOutboundSession struct {
Ratchet megolm.Ratchet `json:"ratchet"`
SigningKey crypto.Ed25519KeyPair `json:"signing_key"`
}
// NewMegolmOutboundSession creates a new MegolmOutboundSession.
func NewMegolmOutboundSession() (*MegolmOutboundSession, error) {
o := &MegolmOutboundSession{}
var err error
o.SigningKey, err = crypto.Ed25519GenerateKey(nil)
if err != nil {
return nil, err
}
var randomData [megolm.RatchetParts * megolm.RatchetPartLength]byte
_, err = rand.Read(randomData[:])
if err != nil {
return nil, err
}
ratchet, err := megolm.New(0, randomData)
if err != nil {
return nil, err
}
o.Ratchet = *ratchet
return o, nil
}
// MegolmOutboundSessionFromPickled loads the MegolmOutboundSession details from a pickled base64 string. The input is decrypted with the supplied key.
func MegolmOutboundSessionFromPickled(pickled, key []byte) (*MegolmOutboundSession, error) {
if len(pickled) == 0 {
return nil, fmt.Errorf("megolmOutboundSessionFromPickled: %w", goolm.ErrEmptyInput)
}
a := &MegolmOutboundSession{}
err := a.Unpickle(pickled, key)
if err != nil {
return nil, err
}
return a, nil
}
// Encrypt encrypts the plaintext as a base64 encoded group message.
func (o *MegolmOutboundSession) Encrypt(plaintext []byte) ([]byte, error) {
encrypted, err := o.Ratchet.Encrypt(plaintext, &o.SigningKey)
if err != nil {
return nil, err
}
return goolm.Base64Encode(encrypted), nil
}
// SessionID returns the base64 endoded public signing key
func (o MegolmOutboundSession) SessionID() id.SessionID {
return id.SessionID(base64.RawStdEncoding.EncodeToString(o.SigningKey.PublicKey))
}
// PickleAsJSON returns an Session as a base64 string encrypted using the supplied key. The unencrypted representation of the Account is in JSON format.
func (o MegolmOutboundSession) PickleAsJSON(key []byte) ([]byte, error) {
return utilities.PickleAsJSON(o, megolmOutboundSessionPickleVersion, key)
}
// UnpickleAsJSON updates an Session by a base64 encrypted string with the key. The unencrypted representation has to be in JSON format.
func (o *MegolmOutboundSession) UnpickleAsJSON(pickled, key []byte) error {
return utilities.UnpickleAsJSON(o, pickled, key, megolmOutboundSessionPickleVersion)
}
// Unpickle decodes the base64 encoded string and decrypts the result with the key.
// The decrypted value is then passed to UnpickleLibOlm.
func (o *MegolmOutboundSession) Unpickle(pickled, key []byte) error {
decrypted, err := cipher.Unpickle(key, pickled)
if err != nil {
return err
}
_, err = o.UnpickleLibOlm(decrypted)
return err
}
// UnpickleLibOlm decodes the unencryted value and populates the Session accordingly. It returns the number of bytes read.
func (o *MegolmOutboundSession) UnpickleLibOlm(value []byte) (int, error) {
//First 4 bytes are the accountPickleVersion
pickledVersion, curPos, err := libolmpickle.UnpickleUInt32(value)
if err != nil {
return 0, err
}
switch pickledVersion {
case megolmOutboundSessionPickleVersionLibOlm:
default:
return 0, fmt.Errorf("unpickle MegolmInboundSession: %w", goolm.ErrBadVersion)
}
readBytes, err := o.Ratchet.UnpickleLibOlm(value[curPos:])
if err != nil {
return 0, err
}
curPos += readBytes
readBytes, err = o.SigningKey.UnpickleLibOlm(value[curPos:])
if err != nil {
return 0, err
}
curPos += readBytes
return curPos, nil
}
// Pickle returns a base64 encoded and with key encrypted pickled MegolmOutboundSession using PickleLibOlm().
func (o MegolmOutboundSession) Pickle(key []byte) ([]byte, error) {
pickeledBytes := make([]byte, o.PickleLen())
written, err := o.PickleLibOlm(pickeledBytes)
if err != nil {
return nil, err
}
if written != len(pickeledBytes) {
return nil, errors.New("number of written bytes not correct")
}
encrypted, err := cipher.Pickle(key, pickeledBytes)
if err != nil {
return nil, err
}
return encrypted, nil
}
// PickleLibOlm encodes the session into target. target has to have a size of at least PickleLen() and is written to from index 0.
// It returns the number of bytes written.
func (o MegolmOutboundSession) PickleLibOlm(target []byte) (int, error) {
if len(target) < o.PickleLen() {
return 0, fmt.Errorf("pickle MegolmOutboundSession: %w", goolm.ErrValueTooShort)
}
written := libolmpickle.PickleUInt32(megolmOutboundSessionPickleVersionLibOlm, target)
writtenRatchet, err := o.Ratchet.PickleLibOlm(target[written:])
if err != nil {
return 0, fmt.Errorf("pickle MegolmOutboundSession: %w", err)
}
written += writtenRatchet
writtenPubKey, err := o.SigningKey.PickleLibOlm(target[written:])
if err != nil {
return 0, fmt.Errorf("pickle MegolmOutboundSession: %w", err)
}
written += writtenPubKey
return written, nil
}
// PickleLen returns the number of bytes the pickled session will have.
func (o MegolmOutboundSession) PickleLen() int {
length := libolmpickle.PickleUInt32Len(megolmOutboundSessionPickleVersionLibOlm)
length += o.Ratchet.PickleLen()
length += o.SigningKey.PickleLen()
return length
}
func (o MegolmOutboundSession) SessionSharingMessage() ([]byte, error) {
return o.Ratchet.SessionSharingMessage(o.SigningKey)
}

View File

@@ -0,0 +1,476 @@
package session
import (
"bytes"
"encoding/base64"
"errors"
"fmt"
"io"
"maunium.net/go/mautrix/crypto/goolm"
"maunium.net/go/mautrix/crypto/goolm/cipher"
"maunium.net/go/mautrix/crypto/goolm/crypto"
"maunium.net/go/mautrix/crypto/goolm/libolmpickle"
"maunium.net/go/mautrix/crypto/goolm/message"
"maunium.net/go/mautrix/crypto/goolm/olm"
"maunium.net/go/mautrix/crypto/goolm/utilities"
"maunium.net/go/mautrix/id"
)
const (
olmSessionPickleVersionJSON uint8 = 1
olmSessionPickleVersionLibOlm uint32 = 1
)
const (
protocolVersion = 0x3
)
// OlmSession stores all information for an olm session
type OlmSession struct {
ReceivedMessage bool `json:"received_message"`
AliceIdentityKey crypto.Curve25519PublicKey `json:"alice_id_key"`
AliceBaseKey crypto.Curve25519PublicKey `json:"alice_base_key"`
BobOneTimeKey crypto.Curve25519PublicKey `json:"bob_one_time_key"`
Ratchet olm.Ratchet `json:"ratchet"`
}
// SearchOTKFunc is used to retrieve a crypto.OneTimeKey from a public key.
type SearchOTKFunc = func(crypto.Curve25519PublicKey) *crypto.OneTimeKey
// OlmSessionFromJSONPickled loads an OlmSession from a pickled base64 string. Decrypts
// the Session using the supplied key.
func OlmSessionFromJSONPickled(pickled, key []byte) (*OlmSession, error) {
if len(pickled) == 0 {
return nil, fmt.Errorf("sessionFromPickled: %w", goolm.ErrEmptyInput)
}
a := &OlmSession{}
err := a.UnpickleAsJSON(pickled, key)
if err != nil {
return nil, err
}
return a, nil
}
// OlmSessionFromPickled loads the OlmSession details from a pickled base64 string. The input is decrypted with the supplied key.
func OlmSessionFromPickled(pickled, key []byte) (*OlmSession, error) {
if len(pickled) == 0 {
return nil, fmt.Errorf("sessionFromPickled: %w", goolm.ErrEmptyInput)
}
a := &OlmSession{}
err := a.Unpickle(pickled, key)
if err != nil {
return nil, err
}
return a, nil
}
// NewOlmSession creates a new Session.
func NewOlmSession() *OlmSession {
s := &OlmSession{}
s.Ratchet = *olm.New()
return s
}
// NewOutboundOlmSession creates a new outbound session for sending the first message to a
// given curve25519 identityKey and oneTimeKey.
func NewOutboundOlmSession(identityKeyAlice crypto.Curve25519KeyPair, identityKeyBob crypto.Curve25519PublicKey, oneTimeKeyBob crypto.Curve25519PublicKey) (*OlmSession, error) {
s := NewOlmSession()
//generate E_A
baseKey, err := crypto.Curve25519GenerateKey(nil)
if err != nil {
return nil, err
}
//generate T_0
ratchetKey, err := crypto.Curve25519GenerateKey(nil)
if err != nil {
return nil, err
}
//Calculate shared secret via Triple Diffie-Hellman
var secret []byte
//ECDH(I_A,E_B)
idSecret, err := identityKeyAlice.SharedSecret(oneTimeKeyBob)
if err != nil {
return nil, err
}
//ECDH(E_A,I_B)
baseIdSecret, err := baseKey.SharedSecret(identityKeyBob)
if err != nil {
return nil, err
}
//ECDH(E_A,E_B)
baseOneTimeSecret, err := baseKey.SharedSecret(oneTimeKeyBob)
if err != nil {
return nil, err
}
secret = append(secret, idSecret...)
secret = append(secret, baseIdSecret...)
secret = append(secret, baseOneTimeSecret...)
//Init Ratchet
s.Ratchet.InitializeAsAlice(secret, ratchetKey)
s.AliceIdentityKey = identityKeyAlice.PublicKey
s.AliceBaseKey = baseKey.PublicKey
s.BobOneTimeKey = oneTimeKeyBob
return s, nil
}
// NewInboundOlmSession creates a new inbound session from receiving the first message.
func NewInboundOlmSession(identityKeyAlice *crypto.Curve25519PublicKey, receivedOTKMsg []byte, searchBobOTK SearchOTKFunc, identityKeyBob crypto.Curve25519KeyPair) (*OlmSession, error) {
decodedOTKMsg, err := goolm.Base64Decode(receivedOTKMsg)
if err != nil {
return nil, err
}
s := NewOlmSession()
//decode OneTimeKeyMessage
oneTimeMsg := message.PreKeyMessage{}
err = oneTimeMsg.Decode(decodedOTKMsg)
if err != nil {
return nil, fmt.Errorf("OneTimeKeyMessage decode: %w", err)
}
if !oneTimeMsg.CheckFields(identityKeyAlice) {
return nil, fmt.Errorf("OneTimeKeyMessage check fields: %w", goolm.ErrBadMessageFormat)
}
//Either the identityKeyAlice is set and/or the oneTimeMsg.IdentityKey is set, which is checked
// by oneTimeMsg.CheckFields
if identityKeyAlice != nil && len(oneTimeMsg.IdentityKey) != 0 {
//if both are set, compare them
if !identityKeyAlice.Equal(oneTimeMsg.IdentityKey) {
return nil, fmt.Errorf("OneTimeKeyMessage identity keys: %w", goolm.ErrBadMessageKeyID)
}
}
if identityKeyAlice == nil {
//for downstream use set
identityKeyAlice = &oneTimeMsg.IdentityKey
}
oneTimeKeyBob := searchBobOTK(oneTimeMsg.OneTimeKey)
if oneTimeKeyBob == nil {
return nil, fmt.Errorf("ourOneTimeKey: %w", goolm.ErrBadMessageKeyID)
}
//Calculate shared secret via Triple Diffie-Hellman
var secret []byte
//ECDH(E_B,I_A)
idSecret, err := oneTimeKeyBob.Key.SharedSecret(*identityKeyAlice)
if err != nil {
return nil, err
}
//ECDH(I_B,E_A)
baseIdSecret, err := identityKeyBob.SharedSecret(oneTimeMsg.BaseKey)
if err != nil {
return nil, err
}
//ECDH(E_B,E_A)
baseOneTimeSecret, err := oneTimeKeyBob.Key.SharedSecret(oneTimeMsg.BaseKey)
if err != nil {
return nil, err
}
secret = append(secret, idSecret...)
secret = append(secret, baseIdSecret...)
secret = append(secret, baseOneTimeSecret...)
//decode message
msg := message.Message{}
err = msg.Decode(oneTimeMsg.Message)
if err != nil {
return nil, fmt.Errorf("Message decode: %w", err)
}
if len(msg.RatchetKey) == 0 {
return nil, fmt.Errorf("Message missing ratchet key: %w", goolm.ErrBadMessageFormat)
}
//Init Ratchet
s.Ratchet.InitializeAsBob(secret, msg.RatchetKey)
s.AliceBaseKey = oneTimeMsg.BaseKey
s.AliceIdentityKey = oneTimeMsg.IdentityKey
s.BobOneTimeKey = oneTimeKeyBob.Key.PublicKey
//https://gitlab.matrix.org/matrix-org/olm/blob/master/docs/olm.md states to remove the oneTimeKey
//this is done via the account itself
return s, nil
}
// PickleAsJSON returns an Session as a base64 string encrypted using the supplied key. The unencrypted representation of the Account is in JSON format.
func (a OlmSession) PickleAsJSON(key []byte) ([]byte, error) {
return utilities.PickleAsJSON(a, olmSessionPickleVersionJSON, key)
}
// UnpickleAsJSON updates an Session by a base64 encrypted string with the key. The unencrypted representation has to be in JSON format.
func (a *OlmSession) UnpickleAsJSON(pickled, key []byte) error {
return utilities.UnpickleAsJSON(a, pickled, key, olmSessionPickleVersionJSON)
}
// ID returns an identifier for this Session. Will be the same for both ends of the conversation.
// Generated by hashing the public keys used to create the session.
func (s OlmSession) ID() id.SessionID {
message := make([]byte, 3*crypto.Curve25519KeyLength)
copy(message, s.AliceIdentityKey)
copy(message[crypto.Curve25519KeyLength:], s.AliceBaseKey)
copy(message[2*crypto.Curve25519KeyLength:], s.BobOneTimeKey)
hash := crypto.SHA256(message)
res := id.SessionID(goolm.Base64Encode(hash))
return res
}
// HasReceivedMessage returns true if this session has received any message.
func (s OlmSession) HasReceivedMessage() bool {
return s.ReceivedMessage
}
// MatchesInboundSessionFrom checks if the oneTimeKeyMsg message is set for this inbound
// Session. This can happen if multiple messages are sent to this Account
// before this Account sends a message in reply. Returns true if the session
// matches. Returns false if the session does not match.
func (s OlmSession) MatchesInboundSessionFrom(theirIdentityKeyEncoded *id.Curve25519, receivedOTKMsg []byte) (bool, error) {
if len(receivedOTKMsg) == 0 {
return false, fmt.Errorf("inbound match: %w", goolm.ErrEmptyInput)
}
decodedOTKMsg, err := goolm.Base64Decode(receivedOTKMsg)
if err != nil {
return false, err
}
var theirIdentityKey *crypto.Curve25519PublicKey
if theirIdentityKeyEncoded != nil {
decodedKey, err := base64.RawStdEncoding.DecodeString(string(*theirIdentityKeyEncoded))
if err != nil {
return false, err
}
theirIdentityKeyByte := crypto.Curve25519PublicKey(decodedKey)
theirIdentityKey = &theirIdentityKeyByte
}
msg := message.PreKeyMessage{}
err = msg.Decode(decodedOTKMsg)
if err != nil {
return false, err
}
if !msg.CheckFields(theirIdentityKey) {
return false, nil
}
same := true
if msg.IdentityKey != nil {
same = same && msg.IdentityKey.Equal(s.AliceIdentityKey)
}
if theirIdentityKey != nil {
same = same && theirIdentityKey.Equal(s.AliceIdentityKey)
}
same = same && bytes.Equal(msg.BaseKey, s.AliceBaseKey)
same = same && bytes.Equal(msg.OneTimeKey, s.BobOneTimeKey)
return same, nil
}
// EncryptMsgType returns the type of the next message that Encrypt will
// return. Returns MsgTypePreKey if the message will be a oneTimeKeyMsg.
// Returns MsgTypeMsg if the message will be a normal message.
func (s OlmSession) EncryptMsgType() id.OlmMsgType {
if s.ReceivedMessage {
return id.OlmMsgTypeMsg
}
return id.OlmMsgTypePreKey
}
// Encrypt encrypts a message using the Session. Returns the encrypted message base64 encoded. If reader is nil, crypto/rand is used for key generations.
func (s *OlmSession) Encrypt(plaintext []byte, reader io.Reader) (id.OlmMsgType, []byte, error) {
if len(plaintext) == 0 {
return 0, nil, fmt.Errorf("encrypt: %w", goolm.ErrEmptyInput)
}
messageType := s.EncryptMsgType()
encrypted, err := s.Ratchet.Encrypt(plaintext, reader)
if err != nil {
return 0, nil, err
}
result := encrypted
if !s.ReceivedMessage {
msg := message.PreKeyMessage{}
msg.Version = protocolVersion
msg.OneTimeKey = s.BobOneTimeKey
msg.IdentityKey = s.AliceIdentityKey
msg.BaseKey = s.AliceBaseKey
msg.Message = encrypted
var err error
messageBody, err := msg.Encode()
if err != nil {
return 0, nil, err
}
result = messageBody
}
return messageType, goolm.Base64Encode(result), nil
}
// Decrypt decrypts a base64 encoded message using the Session.
func (s *OlmSession) Decrypt(crypttext []byte, msgType id.OlmMsgType) ([]byte, error) {
if len(crypttext) == 0 {
return nil, fmt.Errorf("decrypt: %w", goolm.ErrEmptyInput)
}
decodedCrypttext, err := goolm.Base64Decode(crypttext)
if err != nil {
return nil, err
}
msgBody := decodedCrypttext
if msgType != id.OlmMsgTypeMsg {
//Pre-Key Message
msg := message.PreKeyMessage{}
err := msg.Decode(decodedCrypttext)
if err != nil {
return nil, err
}
msgBody = msg.Message
}
plaintext, err := s.Ratchet.Decrypt(msgBody)
if err != nil {
return nil, err
}
s.ReceivedMessage = true
return plaintext, nil
}
// Unpickle decodes the base64 encoded string and decrypts the result with the key.
// The decrypted value is then passed to UnpickleLibOlm.
func (o *OlmSession) Unpickle(pickled, key []byte) error {
decrypted, err := cipher.Unpickle(key, pickled)
if err != nil {
return err
}
_, err = o.UnpickleLibOlm(decrypted)
return err
}
// UnpickleLibOlm decodes the unencryted value and populates the Session accordingly. It returns the number of bytes read.
func (o *OlmSession) UnpickleLibOlm(value []byte) (int, error) {
//First 4 bytes are the accountPickleVersion
pickledVersion, curPos, err := libolmpickle.UnpickleUInt32(value)
if err != nil {
return 0, err
}
includesChainIndex := true
switch pickledVersion {
case olmSessionPickleVersionLibOlm:
includesChainIndex = false
case uint32(0x80000001):
includesChainIndex = true
default:
return 0, fmt.Errorf("unpickle olmSession: %w", goolm.ErrBadVersion)
}
var readBytes int
o.ReceivedMessage, readBytes, err = libolmpickle.UnpickleBool(value[curPos:])
if err != nil {
return 0, err
}
curPos += readBytes
readBytes, err = o.AliceIdentityKey.UnpickleLibOlm(value[curPos:])
if err != nil {
return 0, err
}
curPos += readBytes
readBytes, err = o.AliceBaseKey.UnpickleLibOlm(value[curPos:])
if err != nil {
return 0, err
}
curPos += readBytes
readBytes, err = o.BobOneTimeKey.UnpickleLibOlm(value[curPos:])
if err != nil {
return 0, err
}
curPos += readBytes
readBytes, err = o.Ratchet.UnpickleLibOlm(value[curPos:], includesChainIndex)
if err != nil {
return 0, err
}
curPos += readBytes
return curPos, nil
}
// Pickle returns a base64 encoded and with key encrypted pickled olmSession using PickleLibOlm().
func (o OlmSession) Pickle(key []byte) ([]byte, error) {
pickeledBytes := make([]byte, o.PickleLen())
written, err := o.PickleLibOlm(pickeledBytes)
if err != nil {
return nil, err
}
if written != len(pickeledBytes) {
return nil, errors.New("number of written bytes not correct")
}
encrypted, err := cipher.Pickle(key, pickeledBytes)
if err != nil {
return nil, err
}
return encrypted, nil
}
// PickleLibOlm encodes the session into target. target has to have a size of at least PickleLen() and is written to from index 0.
// It returns the number of bytes written.
func (o OlmSession) PickleLibOlm(target []byte) (int, error) {
if len(target) < o.PickleLen() {
return 0, fmt.Errorf("pickle MegolmOutboundSession: %w", goolm.ErrValueTooShort)
}
written := libolmpickle.PickleUInt32(olmSessionPickleVersionLibOlm, target)
written += libolmpickle.PickleBool(o.ReceivedMessage, target[written:])
writtenRatchet, err := o.AliceIdentityKey.PickleLibOlm(target[written:])
if err != nil {
return 0, fmt.Errorf("pickle MegolmOutboundSession: %w", err)
}
written += writtenRatchet
writtenRatchet, err = o.AliceBaseKey.PickleLibOlm(target[written:])
if err != nil {
return 0, fmt.Errorf("pickle MegolmOutboundSession: %w", err)
}
written += writtenRatchet
writtenRatchet, err = o.BobOneTimeKey.PickleLibOlm(target[written:])
if err != nil {
return 0, fmt.Errorf("pickle MegolmOutboundSession: %w", err)
}
written += writtenRatchet
writtenRatchet, err = o.Ratchet.PickleLibOlm(target[written:])
if err != nil {
return 0, fmt.Errorf("pickle MegolmOutboundSession: %w", err)
}
written += writtenRatchet
return written, nil
}
// PickleLen returns the actual number of bytes the pickled session will have.
func (o OlmSession) PickleLen() int {
length := libolmpickle.PickleUInt32Len(olmSessionPickleVersionLibOlm)
length += libolmpickle.PickleBoolLen(o.ReceivedMessage)
length += o.AliceIdentityKey.PickleLen()
length += o.AliceBaseKey.PickleLen()
length += o.BobOneTimeKey.PickleLen()
length += o.Ratchet.PickleLen()
return length
}
// PickleLenMin returns the minimum number of bytes the pickled session must have.
func (o OlmSession) PickleLenMin() int {
length := libolmpickle.PickleUInt32Len(olmSessionPickleVersionLibOlm)
length += libolmpickle.PickleBoolLen(o.ReceivedMessage)
length += o.AliceIdentityKey.PickleLen()
length += o.AliceBaseKey.PickleLen()
length += o.BobOneTimeKey.PickleLen()
length += o.Ratchet.PickleLenMin()
return length
}
// Describe returns a string describing the current state of the session for debugging.
func (o OlmSession) Describe() string {
var res string
if o.Ratchet.SenderChains.IsSet {
res += fmt.Sprintf("sender chain index: %d ", o.Ratchet.SenderChains.CKey.Index)
} else {
res += "sender chain index: "
}
res += "receiver chain indicies:"
for _, curChain := range o.Ratchet.ReceiverChains {
res += fmt.Sprintf(" %d", curChain.CKey.Index)
}
res += " skipped message keys:"
for _, curSkip := range o.Ratchet.SkippedMessageKeys {
res += fmt.Sprintf(" %d", curSkip.MKey.Index)
}
return res
}

View File

@@ -0,0 +1,23 @@
package utilities
import (
"encoding/base64"
"maunium.net/go/mautrix/crypto/goolm"
"maunium.net/go/mautrix/crypto/goolm/crypto"
"maunium.net/go/mautrix/id"
)
// VerifySignature verifies an ed25519 signature.
func VerifySignature(message []byte, key id.Ed25519, signature []byte) (ok bool, err error) {
keyDecoded, err := base64.RawStdEncoding.DecodeString(string(key))
if err != nil {
return false, err
}
signatureDecoded, err := goolm.Base64Decode(signature)
if err != nil {
return false, err
}
publicKey := crypto.Ed25519PublicKey(keyDecoded)
return publicKey.Verify(message, signatureDecoded), nil
}

View File

@@ -0,0 +1,60 @@
package utilities
import (
"encoding/json"
"fmt"
"maunium.net/go/mautrix/crypto/goolm"
"maunium.net/go/mautrix/crypto/goolm/cipher"
)
// PickleAsJSON returns an object as a base64 string encrypted using the supplied key. The unencrypted representation of the object is in JSON format.
func PickleAsJSON(object any, pickleVersion byte, key []byte) ([]byte, error) {
if len(key) == 0 {
return nil, fmt.Errorf("pickle: %w", goolm.ErrNoKeyProvided)
}
marshaled, err := json.Marshal(object)
if err != nil {
return nil, fmt.Errorf("pickle marshal: %w", err)
}
marshaled = append([]byte{pickleVersion}, marshaled...)
toEncrypt := make([]byte, len(marshaled))
copy(toEncrypt, marshaled)
//pad marshaled to get block size
if len(marshaled)%cipher.PickleBlockSize() != 0 {
padding := cipher.PickleBlockSize() - len(marshaled)%cipher.PickleBlockSize()
toEncrypt = make([]byte, len(marshaled)+padding)
copy(toEncrypt, marshaled)
}
encrypted, err := cipher.Pickle(key, toEncrypt)
if err != nil {
return nil, fmt.Errorf("pickle encrypt: %w", err)
}
return encrypted, nil
}
// UnpickleAsJSON updates the object by a base64 encrypted string using the supplied key. The unencrypted representation has to be in JSON format.
func UnpickleAsJSON(object any, pickled, key []byte, pickleVersion byte) error {
if len(key) == 0 {
return fmt.Errorf("unpickle: %w", goolm.ErrNoKeyProvided)
}
decrypted, err := cipher.Unpickle(key, pickled)
if err != nil {
return fmt.Errorf("unpickle decrypt: %w", err)
}
//unpad decrypted so unmarshal works
for i := len(decrypted) - 1; i >= 0; i-- {
if decrypted[i] != 0 {
decrypted = decrypted[:i+1]
break
}
}
if decrypted[0] != pickleVersion {
return fmt.Errorf("unpickle: %w", goolm.ErrWrongPickleVersion)
}
err = json.Unmarshal(decrypted[1:], object)
if err != nil {
return fmt.Errorf("unpickle unmarshal: %w", err)
}
return nil
}