166 lines
5.0 KiB
Go
166 lines
5.0 KiB
Go
package pk
|
|
|
|
import (
|
|
"encoding/base64"
|
|
"errors"
|
|
"fmt"
|
|
|
|
"maunium.net/go/mautrix/crypto/goolm"
|
|
"maunium.net/go/mautrix/crypto/goolm/cipher"
|
|
"maunium.net/go/mautrix/crypto/goolm/crypto"
|
|
"maunium.net/go/mautrix/crypto/goolm/libolmpickle"
|
|
"maunium.net/go/mautrix/crypto/goolm/utilities"
|
|
"maunium.net/go/mautrix/id"
|
|
)
|
|
|
|
// Version tags for the two pickle formats supported by Decryption.
const (
	// decryptionPickleVersionLibOlm is written as the leading uint32 of the
	// libolm-compatible pickle produced by PickleLibOlm and is checked by
	// UnpickleLibOlm.
	decryptionPickleVersionLibOlm uint32 = 1

	// decryptionPickleVersionJSON is the version passed to the JSON
	// pickling helpers used by PickleAsJSON and UnpickleAsJSON.
	decryptionPickleVersionJSON uint8 = 1
)
|
|
|
|
// Decryption is used to decrypt pk messages
|
|
type Decryption struct {
|
|
KeyPair crypto.Curve25519KeyPair `json:"key_pair"`
|
|
}
|
|
|
|
// NewDecryption returns a new Decryption with a new generated key pair.
|
|
func NewDecryption() (*Decryption, error) {
|
|
keyPair, err := crypto.Curve25519GenerateKey(nil)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return &Decryption{
|
|
KeyPair: keyPair,
|
|
}, nil
|
|
}
|
|
|
|
// NewDescriptionFromPrivate resturns a new Decryption with the private key fixed.
|
|
func NewDecryptionFromPrivate(privateKey crypto.Curve25519PrivateKey) (*Decryption, error) {
|
|
s := &Decryption{}
|
|
keyPair, err := crypto.Curve25519GenerateFromPrivate(privateKey)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
s.KeyPair = keyPair
|
|
return s, nil
|
|
}
|
|
|
|
// PublicKey returns the public key base 64 encoded.
|
|
func (s Decryption) PublicKey() id.Curve25519 {
|
|
return s.KeyPair.B64Encoded()
|
|
}
|
|
|
|
// PrivateKey returns the private key.
|
|
func (s Decryption) PrivateKey() crypto.Curve25519PrivateKey {
|
|
return s.KeyPair.PrivateKey
|
|
}
|
|
|
|
// Decrypt decrypts the ciphertext and verifies the MAC. The base64 encoded key is used to construct the shared secret.
|
|
func (s Decryption) Decrypt(ciphertext, mac []byte, key id.Curve25519) ([]byte, error) {
|
|
keyDecoded, err := base64.RawStdEncoding.DecodeString(string(key))
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
sharedSecret, err := s.KeyPair.SharedSecret(keyDecoded)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
decodedMAC, err := goolm.Base64Decode(mac)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
cipher := cipher.NewAESSHA256(nil)
|
|
verified, err := cipher.Verify(sharedSecret, ciphertext, decodedMAC)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if !verified {
|
|
return nil, fmt.Errorf("decrypt: %w", goolm.ErrBadMAC)
|
|
}
|
|
plaintext, err := cipher.Decrypt(sharedSecret, ciphertext)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return plaintext, nil
|
|
}
|
|
|
|
// PickleAsJSON returns an Decryption as a base64 string encrypted using the supplied key. The unencrypted representation of the Account is in JSON format.
|
|
func (a Decryption) PickleAsJSON(key []byte) ([]byte, error) {
|
|
return utilities.PickleAsJSON(a, decryptionPickleVersionJSON, key)
|
|
}
|
|
|
|
// UnpickleAsJSON updates an Decryption by a base64 encrypted string using the supplied key. The unencrypted representation has to be in JSON format.
|
|
func (a *Decryption) UnpickleAsJSON(pickled, key []byte) error {
|
|
return utilities.UnpickleAsJSON(a, pickled, key, decryptionPickleVersionJSON)
|
|
}
|
|
|
|
// Unpickle decodes the base64 encoded string and decrypts the result with the key.
|
|
// The decrypted value is then passed to UnpickleLibOlm.
|
|
func (a *Decryption) Unpickle(pickled, key []byte) error {
|
|
decrypted, err := cipher.Unpickle(key, pickled)
|
|
if err != nil {
|
|
return err
|
|
}
|
|
_, err = a.UnpickleLibOlm(decrypted)
|
|
return err
|
|
}
|
|
|
|
// UnpickleLibOlm decodes the unencryted value and populates the Decryption accordingly. It returns the number of bytes read.
|
|
func (a *Decryption) UnpickleLibOlm(value []byte) (int, error) {
|
|
//First 4 bytes are the accountPickleVersion
|
|
pickledVersion, curPos, err := libolmpickle.UnpickleUInt32(value)
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
switch pickledVersion {
|
|
case decryptionPickleVersionLibOlm:
|
|
default:
|
|
return 0, fmt.Errorf("unpickle olmSession: %w", goolm.ErrBadVersion)
|
|
}
|
|
readBytes, err := a.KeyPair.UnpickleLibOlm(value[curPos:])
|
|
if err != nil {
|
|
return 0, err
|
|
}
|
|
curPos += readBytes
|
|
return curPos, nil
|
|
}
|
|
|
|
// Pickle returns a base64 encoded and with key encrypted pickled Decryption using PickleLibOlm().
|
|
func (a Decryption) Pickle(key []byte) ([]byte, error) {
|
|
pickeledBytes := make([]byte, a.PickleLen())
|
|
written, err := a.PickleLibOlm(pickeledBytes)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
if written != len(pickeledBytes) {
|
|
return nil, errors.New("number of written bytes not correct")
|
|
}
|
|
encrypted, err := cipher.Pickle(key, pickeledBytes)
|
|
if err != nil {
|
|
return nil, err
|
|
}
|
|
return encrypted, nil
|
|
}
|
|
|
|
// PickleLibOlm encodes the Decryption into target. target has to have a size of at least PickleLen() and is written to from index 0.
|
|
// It returns the number of bytes written.
|
|
func (a Decryption) PickleLibOlm(target []byte) (int, error) {
|
|
if len(target) < a.PickleLen() {
|
|
return 0, fmt.Errorf("pickle Decryption: %w", goolm.ErrValueTooShort)
|
|
}
|
|
written := libolmpickle.PickleUInt32(decryptionPickleVersionLibOlm, target)
|
|
writtenKey, err := a.KeyPair.PickleLibOlm(target[written:])
|
|
if err != nil {
|
|
return 0, fmt.Errorf("pickle Decryption: %w", err)
|
|
}
|
|
written += writtenKey
|
|
return written, nil
|
|
}
|
|
|
|
// PickleLen returns the number of bytes the pickled Decryption will have.
|
|
func (a Decryption) PickleLen() int {
|
|
length := libolmpickle.PickleUInt32Len(decryptionPickleVersionLibOlm)
|
|
length += a.KeyPair.PickleLen()
|
|
return length
|
|
}
|