BREAKING: update mautrix to 0.15.x

This commit is contained in:
Aine
2023-06-01 14:32:20 +00:00
parent a6b20a75ab
commit 2bdb8ca635
222 changed files with 7851 additions and 23986 deletions

View File

@@ -7,13 +7,19 @@
package crypto
import (
"context"
"database/sql"
"database/sql/driver"
"encoding/json"
"errors"
"fmt"
"strings"
"sync"
"time"
"github.com/rs/zerolog"
"maunium.net/go/mautrix"
"maunium.net/go/mautrix/crypto/olm"
"maunium.net/go/mautrix/crypto/sql_store_upgrade"
"maunium.net/go/mautrix/event"
@@ -80,6 +86,35 @@ func (store *SQLCryptoStore) GetNextBatch() (string, error) {
return store.SyncToken, nil
}
var _ mautrix.SyncStore = (*SQLCryptoStore)(nil)
func (store *SQLCryptoStore) SaveFilterID(_ id.UserID, _ string) {}
func (store *SQLCryptoStore) LoadFilterID(_ id.UserID) string { return "" }
func (store *SQLCryptoStore) SaveNextBatch(_ id.UserID, nextBatchToken string) {
err := store.PutNextBatch(nextBatchToken)
if err != nil {
// TODO handle error
}
}
func (store *SQLCryptoStore) LoadNextBatch(_ id.UserID) string {
nb, err := store.GetNextBatch()
if err != nil {
// TODO handle error
}
return nb
}
func (store *SQLCryptoStore) FindDeviceID() (deviceID id.DeviceID) {
err := store.DB.QueryRow("SELECT device_id FROM crypto_account WHERE account_id=$1", store.AccountID).Scan(&deviceID)
if err != nil && !errors.Is(err, sql.ErrNoRows) {
// TODO return error
store.DB.Log.Warn("Failed to scan device ID: %v", err)
}
return
}
// PutAccount stores an OlmAccount in the database.
func (store *SQLCryptoStore) PutAccount(account *OlmAccount) error {
store.Account = account
@@ -220,37 +255,72 @@ func (store *SQLCryptoStore) UpdateSession(_ id.SenderKey, session *OlmSession)
return err
}
func intishPtr[T int | int64](i T) *T {
if i == 0 {
return nil
}
return &i
}
func datePtr(t time.Time) *time.Time {
if t.IsZero() {
return nil
}
return &t
}
// PutGroupSession stores an inbound Megolm group session for a room, sender and session.
func (store *SQLCryptoStore) PutGroupSession(roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID, session *InboundGroupSession) error {
sessionBytes := session.Internal.Pickle(store.PickleKey)
forwardingChains := strings.Join(session.ForwardingChains, ",")
_, err := store.DB.Exec(`
INSERT INTO crypto_megolm_inbound_session
(session_id, sender_key, signing_key, room_id, session, forwarding_chains, account_id)
VALUES ($1, $2, $3, $4, $5, $6, $7)
ratchetSafety, err := json.Marshal(&session.RatchetSafety)
if err != nil {
return fmt.Errorf("failed to marshal ratchet safety info: %w", err)
}
_, err = store.DB.Exec(`
INSERT INTO crypto_megolm_inbound_session (
session_id, sender_key, signing_key, room_id, session, forwarding_chains,
ratchet_safety, received_at, max_age, max_messages, is_scheduled, account_id
) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12)
ON CONFLICT (session_id, account_id) DO UPDATE
SET withheld_code=NULL, withheld_reason=NULL, sender_key=excluded.sender_key, signing_key=excluded.signing_key,
room_id=excluded.room_id, session=excluded.session, forwarding_chains=excluded.forwarding_chains
`, sessionID, senderKey, session.SigningKey, roomID, sessionBytes, forwardingChains, store.AccountID)
room_id=excluded.room_id, session=excluded.session, forwarding_chains=excluded.forwarding_chains,
ratchet_safety=excluded.ratchet_safety, received_at=excluded.received_at,
max_age=excluded.max_age, max_messages=excluded.max_messages, is_scheduled=excluded.is_scheduled
`,
sessionID, senderKey, session.SigningKey, roomID, sessionBytes, forwardingChains,
ratchetSafety, datePtr(session.ReceivedAt), intishPtr(session.MaxAge), intishPtr(session.MaxMessages),
session.IsScheduled, store.AccountID,
)
return err
}
// GetGroupSession retrieves an inbound Megolm group session for a room, sender and session.
func (store *SQLCryptoStore) GetGroupSession(roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID) (*InboundGroupSession, error) {
var signingKey, forwardingChains, withheldCode sql.NullString
var sessionBytes []byte
var senderKeyDB, signingKey, forwardingChains, withheldCode, withheldReason sql.NullString
var sessionBytes, ratchetSafetyBytes []byte
var receivedAt sql.NullTime
var maxAge, maxMessages sql.NullInt64
var isScheduled bool
err := store.DB.QueryRow(`
SELECT signing_key, session, forwarding_chains, withheld_code
SELECT sender_key, signing_key, session, forwarding_chains, withheld_code, withheld_reason, ratchet_safety, received_at, max_age, max_messages, is_scheduled
FROM crypto_megolm_inbound_session
WHERE room_id=$1 AND sender_key=$2 AND session_id=$3 AND account_id=$4`,
WHERE room_id=$1 AND (sender_key=$2 OR $2 = '') AND session_id=$3 AND account_id=$4`,
roomID, senderKey, sessionID, store.AccountID,
).Scan(&signingKey, &sessionBytes, &forwardingChains, &withheldCode)
).Scan(&senderKeyDB, &signingKey, &sessionBytes, &forwardingChains, &withheldCode, &withheldReason, &ratchetSafetyBytes, &receivedAt, &maxAge, &maxMessages, &isScheduled)
if err == sql.ErrNoRows {
return nil, nil
} else if err != nil {
return nil, err
} else if withheldCode.Valid {
return nil, fmt.Errorf("%w (%s)", ErrGroupSessionWithheld, withheldCode.String)
return nil, &event.RoomKeyWithheldEventContent{
RoomID: roomID,
Algorithm: id.AlgorithmMegolmV1,
SessionID: sessionID,
SenderKey: senderKey,
Code: event.RoomKeyWithheldCode(withheldCode.String),
Reason: withheldReason.String,
}
}
igs := olm.NewBlankInboundGroupSession()
err = igs.Unpickle(sessionBytes, store.PickleKey)
@@ -261,18 +331,96 @@ func (store *SQLCryptoStore) GetGroupSession(roomID id.RoomID, senderKey id.Send
if forwardingChains.String != "" {
chains = strings.Split(forwardingChains.String, ",")
}
var rs RatchetSafety
if len(ratchetSafetyBytes) > 0 {
err = json.Unmarshal(ratchetSafetyBytes, &rs)
if err != nil {
return nil, fmt.Errorf("failed to unmarshal ratchet safety info: %w", err)
}
}
if senderKey == "" {
senderKey = id.Curve25519(senderKeyDB.String)
}
return &InboundGroupSession{
Internal: *igs,
SigningKey: id.Ed25519(signingKey.String),
SenderKey: senderKey,
RoomID: roomID,
ForwardingChains: chains,
RatchetSafety: rs,
ReceivedAt: receivedAt.Time,
MaxAge: maxAge.Int64,
MaxMessages: int(maxMessages.Int64),
IsScheduled: isScheduled,
}, nil
}
func (store *SQLCryptoStore) RedactGroupSession(_ id.RoomID, _ id.SenderKey, sessionID id.SessionID, reason string) error {
_, err := store.DB.Exec(`
UPDATE crypto_megolm_inbound_session
SET withheld_code=$1, withheld_reason=$2, session=NULL, forwarding_chains=NULL
WHERE session_id=$3 AND account_id=$4 AND session IS NOT NULL
`, event.RoomKeyWithheldBeeperRedacted, "Session redacted: "+reason, sessionID, store.AccountID)
return err
}
func (store *SQLCryptoStore) RedactGroupSessions(roomID id.RoomID, senderKey id.SenderKey, reason string) ([]id.SessionID, error) {
if roomID == "" && senderKey == "" {
return nil, fmt.Errorf("room ID or sender key must be provided for redacting sessions")
}
res, err := store.DB.Query(`
UPDATE crypto_megolm_inbound_session
SET withheld_code=$1, withheld_reason=$2, session=NULL, forwarding_chains=NULL
WHERE (room_id=$3 OR $3='') AND (sender_key=$4 OR $4='') AND account_id=$5
AND session IS NOT NULL AND is_scheduled=false AND received_at IS NOT NULL
RETURNING session_id
`, event.RoomKeyWithheldBeeperRedacted, "Session redacted: "+reason, roomID, senderKey, store.AccountID)
var sessionIDs []id.SessionID
for res.Next() {
var sessionID id.SessionID
_ = res.Scan(&sessionID)
sessionIDs = append(sessionIDs, sessionID)
}
return sessionIDs, err
}
func (store *SQLCryptoStore) RedactExpiredGroupSessions() ([]id.SessionID, error) {
var query string
switch store.DB.Dialect {
case dbutil.Postgres:
query = `
UPDATE crypto_megolm_inbound_session
SET withheld_code=$1, withheld_reason=$2, session=NULL, forwarding_chains=NULL
WHERE account_id=$3 AND session IS NOT NULL AND is_scheduled=false
AND received_at IS NOT NULL and max_age IS NOT NULL
AND received_at + 2 * (max_age * interval '1 millisecond') < now()
RETURNING session_id
`
case dbutil.SQLite:
query = `
UPDATE crypto_megolm_inbound_session
SET withheld_code=$1, withheld_reason=$2, session=NULL, forwarding_chains=NULL
WHERE account_id=$3 AND session IS NOT NULL AND is_scheduled=false
AND received_at IS NOT NULL and max_age IS NOT NULL
AND unixepoch(received_at) + (2 * max_age / 1000) < unixepoch(date('now'))
RETURNING session_id
`
default:
return nil, fmt.Errorf("unsupported dialect")
}
res, err := store.DB.Query(query, event.RoomKeyWithheldBeeperRedacted, "Session redacted: expired", store.AccountID)
var sessionIDs []id.SessionID
for res.Next() {
var sessionID id.SessionID
_ = res.Scan(&sessionID)
sessionIDs = append(sessionIDs, sessionID)
}
return sessionIDs, err
}
func (store *SQLCryptoStore) PutWithheldGroupSession(content event.RoomKeyWithheldEventContent) error {
_, err := store.DB.Exec("INSERT INTO crypto_megolm_inbound_session (session_id, sender_key, room_id, withheld_code, withheld_reason, account_id) VALUES ($1, $2, $3, $4, $5, $6)",
content.SessionID, content.SenderKey, content.RoomID, content.Code, content.Reason, store.AccountID)
_, err := store.DB.Exec("INSERT INTO crypto_megolm_inbound_session (session_id, sender_key, room_id, withheld_code, withheld_reason, received_at, account_id) VALUES ($1, $2, $3, $4, $5, $6, $7)",
content.SessionID, content.SenderKey, content.RoomID, content.Code, content.Reason, time.Now().UTC(), store.AccountID)
return err
}
@@ -302,8 +450,11 @@ func (store *SQLCryptoStore) scanGroupSessionList(rows dbutil.Rows) (result []*I
for rows.Next() {
var roomID id.RoomID
var signingKey, senderKey, forwardingChains sql.NullString
var sessionBytes []byte
err = rows.Scan(&roomID, &signingKey, &senderKey, &sessionBytes, &forwardingChains)
var sessionBytes, ratchetSafetyBytes []byte
var receivedAt sql.NullTime
var maxAge, maxMessages sql.NullInt64
var isScheduled bool
err = rows.Scan(&roomID, &signingKey, &senderKey, &sessionBytes, &forwardingChains, &ratchetSafetyBytes, &receivedAt, &maxAge, &maxMessages, &isScheduled)
if err != nil {
return
}
@@ -316,12 +467,24 @@ func (store *SQLCryptoStore) scanGroupSessionList(rows dbutil.Rows) (result []*I
if forwardingChains.String != "" {
chains = strings.Split(forwardingChains.String, ",")
}
var rs RatchetSafety
if len(ratchetSafetyBytes) > 0 {
err = json.Unmarshal(ratchetSafetyBytes, &rs)
if err != nil {
return nil, fmt.Errorf("failed to unmarshal ratchet safety info: %w", err)
}
}
result = append(result, &InboundGroupSession{
Internal: *igs,
SigningKey: id.Ed25519(signingKey.String),
SenderKey: id.Curve25519(senderKey.String),
RoomID: roomID,
ForwardingChains: chains,
RatchetSafety: rs,
ReceivedAt: receivedAt.Time,
MaxAge: maxAge.Int64,
MaxMessages: int(maxMessages.Int64),
IsScheduled: isScheduled,
})
}
return
@@ -329,7 +492,7 @@ func (store *SQLCryptoStore) scanGroupSessionList(rows dbutil.Rows) (result []*I
func (store *SQLCryptoStore) GetGroupSessionsForRoom(roomID id.RoomID) ([]*InboundGroupSession, error) {
rows, err := store.DB.Query(`
SELECT room_id, signing_key, sender_key, session, forwarding_chains
SELECT room_id, signing_key, sender_key, session, forwarding_chains, ratchet_safety, received_at, max_age, max_messages, is_scheduled
FROM crypto_megolm_inbound_session WHERE room_id=$1 AND account_id=$2 AND session IS NOT NULL`,
roomID, store.AccountID,
)
@@ -343,7 +506,7 @@ func (store *SQLCryptoStore) GetGroupSessionsForRoom(roomID id.RoomID) ([]*Inbou
func (store *SQLCryptoStore) GetAllGroupSessions() ([]*InboundGroupSession, error) {
rows, err := store.DB.Query(`
SELECT room_id, signing_key, sender_key, session, forwarding_chains
SELECT room_id, signing_key, sender_key, session, forwarding_chains, ratchet_safety, received_at, max_age, max_messages, is_scheduled
FROM crypto_megolm_inbound_session WHERE account_id=$2 AND session IS NOT NULL`,
store.AccountID,
)
@@ -367,7 +530,7 @@ func (store *SQLCryptoStore) AddOutboundGroupSession(session *OutboundGroupSessi
max_messages=excluded.max_messages, message_count=excluded.message_count, max_age=excluded.max_age,
created_at=excluded.created_at, last_used=excluded.last_used, account_id=excluded.account_id
`, session.RoomID, session.ID(), sessionBytes, session.Shared, session.MaxMessages, session.MessageCount,
session.MaxAge, session.CreationTime, session.LastEncryptedTime, store.AccountID)
session.MaxAge.Milliseconds(), session.CreationTime, session.LastEncryptedTime, store.AccountID)
return err
}
@@ -383,11 +546,12 @@ func (store *SQLCryptoStore) UpdateOutboundGroupSession(session *OutboundGroupSe
func (store *SQLCryptoStore) GetOutboundGroupSession(roomID id.RoomID) (*OutboundGroupSession, error) {
var ogs OutboundGroupSession
var sessionBytes []byte
var maxAgeMS int64
err := store.DB.QueryRow(`
SELECT session, shared, max_messages, message_count, max_age, created_at, last_used
FROM crypto_megolm_outbound_session WHERE room_id=$1 AND account_id=$2`,
roomID, store.AccountID,
).Scan(&sessionBytes, &ogs.Shared, &ogs.MaxMessages, &ogs.MessageCount, &ogs.MaxAge, &ogs.CreationTime, &ogs.LastEncryptedTime)
).Scan(&sessionBytes, &ogs.Shared, &ogs.MaxMessages, &ogs.MessageCount, &maxAgeMS, &ogs.CreationTime, &ogs.LastEncryptedTime)
if err == sql.ErrNoRows {
return nil, nil
} else if err != nil {
@@ -400,6 +564,7 @@ func (store *SQLCryptoStore) GetOutboundGroupSession(roomID id.RoomID) (*Outboun
}
ogs.Internal = *intOGS
ogs.RoomID = roomID
ogs.MaxAge = time.Duration(maxAgeMS) * time.Millisecond
return &ogs, nil
}
@@ -412,7 +577,7 @@ func (store *SQLCryptoStore) RemoveOutboundGroupSession(roomID id.RoomID) error
// ValidateMessageIndex returns whether the given event information match the ones stored in the database
// for the given sender key, session ID and index. If the index hasn't been stored, this will store it.
func (store *SQLCryptoStore) ValidateMessageIndex(senderKey id.SenderKey, sessionID id.SessionID, eventID id.EventID, index uint, timestamp int64) (bool, error) {
func (store *SQLCryptoStore) ValidateMessageIndex(ctx context.Context, senderKey id.SenderKey, sessionID id.SessionID, eventID id.EventID, index uint, timestamp int64) (bool, error) {
const validateQuery = `
INSERT INTO crypto_message_index (sender_key, session_id, "index", event_id, timestamp)
VALUES ($1, $2, $3, $4, $5)
@@ -422,11 +587,19 @@ func (store *SQLCryptoStore) ValidateMessageIndex(senderKey id.SenderKey, sessio
`
var expectedEventID id.EventID
var expectedTimestamp int64
err := store.DB.QueryRow(validateQuery, senderKey, sessionID, index, eventID, timestamp).Scan(&expectedEventID, &expectedTimestamp)
err := store.DB.QueryRowContext(ctx, validateQuery, senderKey, sessionID, index, eventID, timestamp).Scan(&expectedEventID, &expectedTimestamp)
if err != nil {
return false, err
} else if expectedEventID != eventID || expectedTimestamp != timestamp {
zerolog.Ctx(ctx).Debug().
Uint("message_index", index).
Str("expected_event_id", expectedEventID.String()).
Int64("expected_timestamp", expectedTimestamp).
Int64("actual_timestamp", timestamp).
Msg("Failed to validate that message index wasn't duplicated")
return false, nil
}
return expectedEventID == eventID && expectedTimestamp == timestamp, nil
return true, nil
}
// GetDevices returns a map of device IDs to device identities, including the identity and signing keys, for a given user ID.