Files
postmoogle/vendor/maunium.net/go/mautrix/crypto/goolm/account/account.go
2024-02-11 20:47:04 +02:00

523 lines
17 KiB
Go

// account packages an account which stores the identity, one time keys and fallback keys.
package account
import (
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"io"
"maunium.net/go/mautrix/id"
"maunium.net/go/mautrix/crypto/goolm"
"maunium.net/go/mautrix/crypto/goolm/cipher"
"maunium.net/go/mautrix/crypto/goolm/crypto"
"maunium.net/go/mautrix/crypto/goolm/libolmpickle"
"maunium.net/go/mautrix/crypto/goolm/session"
"maunium.net/go/mautrix/crypto/goolm/utilities"
)
const (
accountPickleVersionJSON byte = 1
accountPickleVersionLibOLM uint32 = 4
MaxOneTimeKeys int = 100 //maximum number of stored one time keys per Account
)
// Account stores an account for end to end encrypted messaging via the olm protocol.
// An Account can not be used to en/decrypt messages. However it can be used to contruct new olm sessions, which in turn do the en/decryption.
// There is no tracking of sessions in an account.
type Account struct {
IdKeys struct {
Ed25519 crypto.Ed25519KeyPair `json:"ed25519,omitempty"`
Curve25519 crypto.Curve25519KeyPair `json:"curve25519,omitempty"`
} `json:"identity_keys"`
OTKeys []crypto.OneTimeKey `json:"one_time_keys"`
CurrentFallbackKey crypto.OneTimeKey `json:"current_fallback_key,omitempty"`
PrevFallbackKey crypto.OneTimeKey `json:"prev_fallback_key,omitempty"`
NextOneTimeKeyID uint32 `json:"next_one_time_key_id,omitempty"`
NumFallbackKeys uint8 `json:"number_fallback_keys"`
}
// AccountFromJSONPickled loads the Account details from a pickled base64 string. The input is decrypted with the supplied key.
func AccountFromJSONPickled(pickled, key []byte) (*Account, error) {
if len(pickled) == 0 {
return nil, fmt.Errorf("accountFromPickled: %w", goolm.ErrEmptyInput)
}
a := &Account{}
err := a.UnpickleAsJSON(pickled, key)
if err != nil {
return nil, err
}
return a, nil
}
// AccountFromPickled loads the Account details from a pickled base64 string. The input is decrypted with the supplied key.
func AccountFromPickled(pickled, key []byte) (*Account, error) {
if len(pickled) == 0 {
return nil, fmt.Errorf("accountFromPickled: %w", goolm.ErrEmptyInput)
}
a := &Account{}
err := a.Unpickle(pickled, key)
if err != nil {
return nil, err
}
return a, nil
}
// NewAccount creates a new Account. If reader is nil, crypto/rand is used for the key creation.
func NewAccount(reader io.Reader) (*Account, error) {
a := &Account{}
kPEd25519, err := crypto.Ed25519GenerateKey(reader)
if err != nil {
return nil, err
}
a.IdKeys.Ed25519 = kPEd25519
kPCurve25519, err := crypto.Curve25519GenerateKey(reader)
if err != nil {
return nil, err
}
a.IdKeys.Curve25519 = kPCurve25519
return a, nil
}
// PickleAsJSON returns an Account as a base64 string encrypted using the supplied key. The unencrypted representation of the Account is in JSON format.
func (a Account) PickleAsJSON(key []byte) ([]byte, error) {
return utilities.PickleAsJSON(a, accountPickleVersionJSON, key)
}
// UnpickleAsJSON updates an Account by a base64 encrypted string using the supplied key. The unencrypted representation has to be in JSON format.
func (a *Account) UnpickleAsJSON(pickled, key []byte) error {
return utilities.UnpickleAsJSON(a, pickled, key, accountPickleVersionJSON)
}
// IdentityKeysJSON returns the public parts of the identity keys for the Account in a JSON string.
func (a Account) IdentityKeysJSON() ([]byte, error) {
res := struct {
Ed25519 string `json:"ed25519"`
Curve25519 string `json:"curve25519"`
}{}
ed25519, curve25519 := a.IdentityKeys()
res.Ed25519 = string(ed25519)
res.Curve25519 = string(curve25519)
return json.Marshal(res)
}
// IdentityKeys returns the public parts of the Ed25519 and Curve25519 identity keys for the Account.
func (a Account) IdentityKeys() (id.Ed25519, id.Curve25519) {
ed25519 := id.Ed25519(base64.RawStdEncoding.EncodeToString(a.IdKeys.Ed25519.PublicKey))
curve25519 := id.Curve25519(base64.RawStdEncoding.EncodeToString(a.IdKeys.Curve25519.PublicKey))
return ed25519, curve25519
}
// Sign returns the signature of a message using the Ed25519 key for this Account.
func (a Account) Sign(message []byte) ([]byte, error) {
if len(message) == 0 {
return nil, fmt.Errorf("sign: %w", goolm.ErrEmptyInput)
}
return goolm.Base64Encode(a.IdKeys.Ed25519.Sign(message)), nil
}
// OneTimeKeys returns the public parts of the unpublished one time keys of the Account.
//
// The returned data is a map with the mapping of key id to base64-encoded Curve25519 key.
func (a Account) OneTimeKeys() map[string]id.Curve25519 {
oneTimeKeys := make(map[string]id.Curve25519)
for _, curKey := range a.OTKeys {
if !curKey.Published {
oneTimeKeys[curKey.KeyIDEncoded()] = id.Curve25519(curKey.PublicKeyEncoded())
}
}
return oneTimeKeys
}
//OneTimeKeysJSON returns the public parts of the unpublished one time keys of the Account as a JSON string.
//
//The returned JSON is of format:
/*
{
Curve25519: {
"AAAAAA": "wo76WcYtb0Vk/pBOdmduiGJ0wIEjW4IBMbbQn7aSnTo",
"AAAAAB": "LRvjo46L1X2vx69sS9QNFD29HWulxrmW11Up5AfAjgU"
}
}
*/
func (a Account) OneTimeKeysJSON() ([]byte, error) {
res := make(map[string]map[string]id.Curve25519)
otKeys := a.OneTimeKeys()
res["Curve25519"] = otKeys
return json.Marshal(res)
}
// MarkKeysAsPublished marks the current set of one time keys and the fallback key as being
// published.
func (a *Account) MarkKeysAsPublished() {
for keyIndex := range a.OTKeys {
if !a.OTKeys[keyIndex].Published {
a.OTKeys[keyIndex].Published = true
}
}
a.CurrentFallbackKey.Published = true
}
// GenOneTimeKeys generates a number of new one time keys. If the total number
// of keys stored by this Account exceeds MaxOneTimeKeys then the older
// keys are discarded. If reader is nil, crypto/rand is used for the key creation.
func (a *Account) GenOneTimeKeys(reader io.Reader, num uint) error {
for i := uint(0); i < num; i++ {
key := crypto.OneTimeKey{
Published: false,
ID: a.NextOneTimeKeyID,
}
newKP, err := crypto.Curve25519GenerateKey(reader)
if err != nil {
return err
}
key.Key = newKP
a.NextOneTimeKeyID++
a.OTKeys = append([]crypto.OneTimeKey{key}, a.OTKeys...)
}
if len(a.OTKeys) > MaxOneTimeKeys {
a.OTKeys = a.OTKeys[:MaxOneTimeKeys]
}
return nil
}
// NewOutboundSession creates a new outbound session to a
// given curve25519 identity Key and one time key.
func (a Account) NewOutboundSession(theirIdentityKey, theirOneTimeKey id.Curve25519) (*session.OlmSession, error) {
if len(theirIdentityKey) == 0 || len(theirOneTimeKey) == 0 {
return nil, fmt.Errorf("outbound session: %w", goolm.ErrEmptyInput)
}
theirIdentityKeyDecoded, err := base64.RawStdEncoding.DecodeString(string(theirIdentityKey))
if err != nil {
return nil, err
}
theirOneTimeKeyDecoded, err := base64.RawStdEncoding.DecodeString(string(theirOneTimeKey))
if err != nil {
return nil, err
}
s, err := session.NewOutboundOlmSession(a.IdKeys.Curve25519, theirIdentityKeyDecoded, theirOneTimeKeyDecoded)
if err != nil {
return nil, err
}
return s, nil
}
// NewInboundSession creates a new inbound session from an incoming PRE_KEY message.
func (a Account) NewInboundSession(theirIdentityKey *id.Curve25519, oneTimeKeyMsg []byte) (*session.OlmSession, error) {
if len(oneTimeKeyMsg) == 0 {
return nil, fmt.Errorf("inbound session: %w", goolm.ErrEmptyInput)
}
var theirIdentityKeyDecoded *crypto.Curve25519PublicKey
var err error
if theirIdentityKey != nil {
theirIdentityKeyDecodedByte, err := base64.RawStdEncoding.DecodeString(string(*theirIdentityKey))
if err != nil {
return nil, err
}
theirIdentityKeyCurve := crypto.Curve25519PublicKey(theirIdentityKeyDecodedByte)
theirIdentityKeyDecoded = &theirIdentityKeyCurve
}
s, err := session.NewInboundOlmSession(theirIdentityKeyDecoded, oneTimeKeyMsg, a.searchOTKForOur, a.IdKeys.Curve25519)
if err != nil {
return nil, err
}
return s, nil
}
func (a Account) searchOTKForOur(toFind crypto.Curve25519PublicKey) *crypto.OneTimeKey {
for curIndex := range a.OTKeys {
if a.OTKeys[curIndex].Key.PublicKey.Equal(toFind) {
return &a.OTKeys[curIndex]
}
}
if a.NumFallbackKeys >= 1 && a.CurrentFallbackKey.Key.PublicKey.Equal(toFind) {
return &a.CurrentFallbackKey
}
if a.NumFallbackKeys >= 2 && a.PrevFallbackKey.Key.PublicKey.Equal(toFind) {
return &a.PrevFallbackKey
}
return nil
}
// RemoveOneTimeKeys removes the one time key in this Account which matches the one time key in the session s.
func (a *Account) RemoveOneTimeKeys(s *session.OlmSession) {
toFind := s.BobOneTimeKey
for curIndex := range a.OTKeys {
if a.OTKeys[curIndex].Key.PublicKey.Equal(toFind) {
//Remove and return
a.OTKeys[curIndex] = a.OTKeys[len(a.OTKeys)-1]
a.OTKeys = a.OTKeys[:len(a.OTKeys)-1]
return
}
}
//if the key is a fallback or prevFallback, don't remove it
}
// GenFallbackKey generates a new fallback key. The old fallback key is stored in a.PrevFallbackKey overwriting any previous PrevFallbackKey. If reader is nil, crypto/rand is used for the key creation.
func (a *Account) GenFallbackKey(reader io.Reader) error {
a.PrevFallbackKey = a.CurrentFallbackKey
key := crypto.OneTimeKey{
Published: false,
ID: a.NextOneTimeKeyID,
}
newKP, err := crypto.Curve25519GenerateKey(reader)
if err != nil {
return err
}
key.Key = newKP
a.NextOneTimeKeyID++
if a.NumFallbackKeys < 2 {
a.NumFallbackKeys++
}
a.CurrentFallbackKey = key
return nil
}
// FallbackKey returns the public part of the current fallback key of the Account.
// The returned data is a map with the mapping of key id to base64-encoded Curve25519 key.
func (a Account) FallbackKey() map[string]id.Curve25519 {
keys := make(map[string]id.Curve25519)
if a.NumFallbackKeys >= 1 {
keys[a.CurrentFallbackKey.KeyIDEncoded()] = id.Curve25519(a.CurrentFallbackKey.PublicKeyEncoded())
}
return keys
}
//FallbackKeyJSON returns the public part of the current fallback key of the Account as a JSON string.
//
//The returned JSON is of format:
/*
{
curve25519: {
"AAAAAA": "wo76WcYtb0Vk/pBOdmduiGJ0wIEjW4IBMbbQn7aSnTo"
}
}
*/
func (a Account) FallbackKeyJSON() ([]byte, error) {
res := make(map[string]map[string]id.Curve25519)
fbk := a.FallbackKey()
res["curve25519"] = fbk
return json.Marshal(res)
}
// FallbackKeyUnpublished returns the public part of the current fallback key of the Account only if it is unpublished.
// The returned data is a map with the mapping of key id to base64-encoded Curve25519 key.
func (a Account) FallbackKeyUnpublished() map[string]id.Curve25519 {
keys := make(map[string]id.Curve25519)
if a.NumFallbackKeys >= 1 && !a.CurrentFallbackKey.Published {
keys[a.CurrentFallbackKey.KeyIDEncoded()] = id.Curve25519(a.CurrentFallbackKey.PublicKeyEncoded())
}
return keys
}
//FallbackKeyUnpublishedJSON returns the public part of the current fallback key, only if it is unpublished, of the Account as a JSON string.
//
//The returned JSON is of format:
/*
{
curve25519: {
"AAAAAA": "wo76WcYtb0Vk/pBOdmduiGJ0wIEjW4IBMbbQn7aSnTo"
}
}
*/
func (a Account) FallbackKeyUnpublishedJSON() ([]byte, error) {
res := make(map[string]map[string]id.Curve25519)
fbk := a.FallbackKeyUnpublished()
res["curve25519"] = fbk
return json.Marshal(res)
}
// ForgetOldFallbackKey resets the previous fallback key in the account.
func (a *Account) ForgetOldFallbackKey() {
if a.NumFallbackKeys >= 2 {
a.NumFallbackKeys = 1
a.PrevFallbackKey = crypto.OneTimeKey{}
}
}
// Unpickle decodes the base64 encoded string and decrypts the result with the key.
// The decrypted value is then passed to UnpickleLibOlm.
func (a *Account) Unpickle(pickled, key []byte) error {
decrypted, err := cipher.Unpickle(key, pickled)
if err != nil {
return err
}
_, err = a.UnpickleLibOlm(decrypted)
return err
}
// UnpickleLibOlm decodes the unencryted value and populates the Account accordingly. It returns the number of bytes read.
func (a *Account) UnpickleLibOlm(value []byte) (int, error) {
//First 4 bytes are the accountPickleVersion
pickledVersion, curPos, err := libolmpickle.UnpickleUInt32(value)
if err != nil {
return 0, err
}
switch pickledVersion {
case accountPickleVersionLibOLM, 3, 2:
default:
return 0, fmt.Errorf("unpickle account: %w", goolm.ErrBadVersion)
}
//read ed25519 key pair
readBytes, err := a.IdKeys.Ed25519.UnpickleLibOlm(value[curPos:])
if err != nil {
return 0, err
}
curPos += readBytes
//read curve25519 key pair
readBytes, err = a.IdKeys.Curve25519.UnpickleLibOlm(value[curPos:])
if err != nil {
return 0, err
}
curPos += readBytes
//Read number of onetimeKeys
numberOTKeys, readBytes, err := libolmpickle.UnpickleUInt32(value[curPos:])
if err != nil {
return 0, err
}
curPos += readBytes
//Read i one time keys
a.OTKeys = make([]crypto.OneTimeKey, numberOTKeys)
for i := uint32(0); i < numberOTKeys; i++ {
readBytes, err := a.OTKeys[i].UnpickleLibOlm(value[curPos:])
if err != nil {
return 0, err
}
curPos += readBytes
}
if pickledVersion <= 2 {
// version 2 did not have fallback keys
a.NumFallbackKeys = 0
} else if pickledVersion == 3 {
// version 3 used the published flag to indicate how many fallback keys
// were present (we'll have to assume that the keys were published)
readBytes, err := a.CurrentFallbackKey.UnpickleLibOlm(value[curPos:])
if err != nil {
return 0, err
}
curPos += readBytes
readBytes, err = a.PrevFallbackKey.UnpickleLibOlm(value[curPos:])
if err != nil {
return 0, err
}
curPos += readBytes
if a.CurrentFallbackKey.Published {
if a.PrevFallbackKey.Published {
a.NumFallbackKeys = 2
} else {
a.NumFallbackKeys = 1
}
} else {
a.NumFallbackKeys = 0
}
} else {
//Read number of fallback keys
numFallbackKeys, readBytes, err := libolmpickle.UnpickleUInt8(value[curPos:])
if err != nil {
return 0, err
}
curPos += readBytes
a.NumFallbackKeys = numFallbackKeys
if a.NumFallbackKeys >= 1 {
readBytes, err := a.CurrentFallbackKey.UnpickleLibOlm(value[curPos:])
if err != nil {
return 0, err
}
curPos += readBytes
if a.NumFallbackKeys >= 2 {
readBytes, err := a.PrevFallbackKey.UnpickleLibOlm(value[curPos:])
if err != nil {
return 0, err
}
curPos += readBytes
}
}
}
//Read next onetime key id
nextOTKeyID, readBytes, err := libolmpickle.UnpickleUInt32(value[curPos:])
if err != nil {
return 0, err
}
curPos += readBytes
a.NextOneTimeKeyID = nextOTKeyID
return curPos, nil
}
// Pickle returns a base64 encoded and with key encrypted pickled account using PickleLibOlm().
func (a Account) Pickle(key []byte) ([]byte, error) {
pickeledBytes := make([]byte, a.PickleLen())
written, err := a.PickleLibOlm(pickeledBytes)
if err != nil {
return nil, err
}
if written != len(pickeledBytes) {
return nil, errors.New("number of written bytes not correct")
}
encrypted, err := cipher.Pickle(key, pickeledBytes)
if err != nil {
return nil, err
}
return encrypted, nil
}
// PickleLibOlm encodes the Account into target. target has to have a size of at least PickleLen() and is written to from index 0.
// It returns the number of bytes written.
func (a Account) PickleLibOlm(target []byte) (int, error) {
if len(target) < a.PickleLen() {
return 0, fmt.Errorf("pickle account: %w", goolm.ErrValueTooShort)
}
written := libolmpickle.PickleUInt32(accountPickleVersionLibOLM, target)
writtenEdKey, err := a.IdKeys.Ed25519.PickleLibOlm(target[written:])
if err != nil {
return 0, fmt.Errorf("pickle account: %w", err)
}
written += writtenEdKey
writtenCurveKey, err := a.IdKeys.Curve25519.PickleLibOlm(target[written:])
if err != nil {
return 0, fmt.Errorf("pickle account: %w", err)
}
written += writtenCurveKey
written += libolmpickle.PickleUInt32(uint32(len(a.OTKeys)), target[written:])
for _, curOTKey := range a.OTKeys {
writtenOT, err := curOTKey.PickleLibOlm(target[written:])
if err != nil {
return 0, fmt.Errorf("pickle account: %w", err)
}
written += writtenOT
}
written += libolmpickle.PickleUInt8(a.NumFallbackKeys, target[written:])
if a.NumFallbackKeys >= 1 {
writtenOT, err := a.CurrentFallbackKey.PickleLibOlm(target[written:])
if err != nil {
return 0, fmt.Errorf("pickle account: %w", err)
}
written += writtenOT
if a.NumFallbackKeys >= 2 {
writtenOT, err := a.PrevFallbackKey.PickleLibOlm(target[written:])
if err != nil {
return 0, fmt.Errorf("pickle account: %w", err)
}
written += writtenOT
}
}
written += libolmpickle.PickleUInt32(a.NextOneTimeKeyID, target[written:])
return written, nil
}
// PickleLen returns the number of bytes the pickled Account will have.
func (a Account) PickleLen() int {
length := libolmpickle.PickleUInt32Len(accountPickleVersionLibOLM)
length += a.IdKeys.Ed25519.PickleLen()
length += a.IdKeys.Curve25519.PickleLen()
length += libolmpickle.PickleUInt32Len(uint32(len(a.OTKeys)))
length += (len(a.OTKeys) * (&crypto.OneTimeKey{}).PickleLen())
length += libolmpickle.PickleUInt8Len(a.NumFallbackKeys)
length += (int(a.NumFallbackKeys) * (&crypto.OneTimeKey{}).PickleLen())
length += libolmpickle.PickleUInt32Len(a.NextOneTimeKeyID)
return length
}