BREAKING: update mautrix to 0.15.x
This commit is contained in:
219
vendor/maunium.net/go/mautrix/crypto/sql_store.go
generated
vendored
219
vendor/maunium.net/go/mautrix/crypto/sql_store.go
generated
vendored
@@ -7,13 +7,19 @@
|
||||
package crypto
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"database/sql/driver"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
|
||||
"maunium.net/go/mautrix"
|
||||
"maunium.net/go/mautrix/crypto/olm"
|
||||
"maunium.net/go/mautrix/crypto/sql_store_upgrade"
|
||||
"maunium.net/go/mautrix/event"
|
||||
@@ -80,6 +86,35 @@ func (store *SQLCryptoStore) GetNextBatch() (string, error) {
|
||||
return store.SyncToken, nil
|
||||
}
|
||||
|
||||
var _ mautrix.SyncStore = (*SQLCryptoStore)(nil)
|
||||
|
||||
func (store *SQLCryptoStore) SaveFilterID(_ id.UserID, _ string) {}
|
||||
func (store *SQLCryptoStore) LoadFilterID(_ id.UserID) string { return "" }
|
||||
|
||||
func (store *SQLCryptoStore) SaveNextBatch(_ id.UserID, nextBatchToken string) {
|
||||
err := store.PutNextBatch(nextBatchToken)
|
||||
if err != nil {
|
||||
// TODO handle error
|
||||
}
|
||||
}
|
||||
|
||||
func (store *SQLCryptoStore) LoadNextBatch(_ id.UserID) string {
|
||||
nb, err := store.GetNextBatch()
|
||||
if err != nil {
|
||||
// TODO handle error
|
||||
}
|
||||
return nb
|
||||
}
|
||||
|
||||
func (store *SQLCryptoStore) FindDeviceID() (deviceID id.DeviceID) {
|
||||
err := store.DB.QueryRow("SELECT device_id FROM crypto_account WHERE account_id=$1", store.AccountID).Scan(&deviceID)
|
||||
if err != nil && !errors.Is(err, sql.ErrNoRows) {
|
||||
// TODO return error
|
||||
store.DB.Log.Warn("Failed to scan device ID: %v", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// PutAccount stores an OlmAccount in the database.
|
||||
func (store *SQLCryptoStore) PutAccount(account *OlmAccount) error {
|
||||
store.Account = account
|
||||
@@ -220,37 +255,72 @@ func (store *SQLCryptoStore) UpdateSession(_ id.SenderKey, session *OlmSession)
|
||||
return err
|
||||
}
|
||||
|
||||
func intishPtr[T int | int64](i T) *T {
|
||||
if i == 0 {
|
||||
return nil
|
||||
}
|
||||
return &i
|
||||
}
|
||||
|
||||
func datePtr(t time.Time) *time.Time {
|
||||
if t.IsZero() {
|
||||
return nil
|
||||
}
|
||||
return &t
|
||||
}
|
||||
|
||||
// PutGroupSession stores an inbound Megolm group session for a room, sender and session.
|
||||
func (store *SQLCryptoStore) PutGroupSession(roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID, session *InboundGroupSession) error {
|
||||
sessionBytes := session.Internal.Pickle(store.PickleKey)
|
||||
forwardingChains := strings.Join(session.ForwardingChains, ",")
|
||||
_, err := store.DB.Exec(`
|
||||
INSERT INTO crypto_megolm_inbound_session
|
||||
(session_id, sender_key, signing_key, room_id, session, forwarding_chains, account_id)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7)
|
||||
ratchetSafety, err := json.Marshal(&session.RatchetSafety)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal ratchet safety info: %w", err)
|
||||
}
|
||||
_, err = store.DB.Exec(`
|
||||
INSERT INTO crypto_megolm_inbound_session (
|
||||
session_id, sender_key, signing_key, room_id, session, forwarding_chains,
|
||||
ratchet_safety, received_at, max_age, max_messages, is_scheduled, account_id
|
||||
) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12)
|
||||
ON CONFLICT (session_id, account_id) DO UPDATE
|
||||
SET withheld_code=NULL, withheld_reason=NULL, sender_key=excluded.sender_key, signing_key=excluded.signing_key,
|
||||
room_id=excluded.room_id, session=excluded.session, forwarding_chains=excluded.forwarding_chains
|
||||
`, sessionID, senderKey, session.SigningKey, roomID, sessionBytes, forwardingChains, store.AccountID)
|
||||
room_id=excluded.room_id, session=excluded.session, forwarding_chains=excluded.forwarding_chains,
|
||||
ratchet_safety=excluded.ratchet_safety, received_at=excluded.received_at,
|
||||
max_age=excluded.max_age, max_messages=excluded.max_messages, is_scheduled=excluded.is_scheduled
|
||||
`,
|
||||
sessionID, senderKey, session.SigningKey, roomID, sessionBytes, forwardingChains,
|
||||
ratchetSafety, datePtr(session.ReceivedAt), intishPtr(session.MaxAge), intishPtr(session.MaxMessages),
|
||||
session.IsScheduled, store.AccountID,
|
||||
)
|
||||
return err
|
||||
}
|
||||
|
||||
// GetGroupSession retrieves an inbound Megolm group session for a room, sender and session.
|
||||
func (store *SQLCryptoStore) GetGroupSession(roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID) (*InboundGroupSession, error) {
|
||||
var signingKey, forwardingChains, withheldCode sql.NullString
|
||||
var sessionBytes []byte
|
||||
var senderKeyDB, signingKey, forwardingChains, withheldCode, withheldReason sql.NullString
|
||||
var sessionBytes, ratchetSafetyBytes []byte
|
||||
var receivedAt sql.NullTime
|
||||
var maxAge, maxMessages sql.NullInt64
|
||||
var isScheduled bool
|
||||
err := store.DB.QueryRow(`
|
||||
SELECT signing_key, session, forwarding_chains, withheld_code
|
||||
SELECT sender_key, signing_key, session, forwarding_chains, withheld_code, withheld_reason, ratchet_safety, received_at, max_age, max_messages, is_scheduled
|
||||
FROM crypto_megolm_inbound_session
|
||||
WHERE room_id=$1 AND sender_key=$2 AND session_id=$3 AND account_id=$4`,
|
||||
WHERE room_id=$1 AND (sender_key=$2 OR $2 = '') AND session_id=$3 AND account_id=$4`,
|
||||
roomID, senderKey, sessionID, store.AccountID,
|
||||
).Scan(&signingKey, &sessionBytes, &forwardingChains, &withheldCode)
|
||||
).Scan(&senderKeyDB, &signingKey, &sessionBytes, &forwardingChains, &withheldCode, &withheldReason, &ratchetSafetyBytes, &receivedAt, &maxAge, &maxMessages, &isScheduled)
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, nil
|
||||
} else if err != nil {
|
||||
return nil, err
|
||||
} else if withheldCode.Valid {
|
||||
return nil, fmt.Errorf("%w (%s)", ErrGroupSessionWithheld, withheldCode.String)
|
||||
return nil, &event.RoomKeyWithheldEventContent{
|
||||
RoomID: roomID,
|
||||
Algorithm: id.AlgorithmMegolmV1,
|
||||
SessionID: sessionID,
|
||||
SenderKey: senderKey,
|
||||
Code: event.RoomKeyWithheldCode(withheldCode.String),
|
||||
Reason: withheldReason.String,
|
||||
}
|
||||
}
|
||||
igs := olm.NewBlankInboundGroupSession()
|
||||
err = igs.Unpickle(sessionBytes, store.PickleKey)
|
||||
@@ -261,18 +331,96 @@ func (store *SQLCryptoStore) GetGroupSession(roomID id.RoomID, senderKey id.Send
|
||||
if forwardingChains.String != "" {
|
||||
chains = strings.Split(forwardingChains.String, ",")
|
||||
}
|
||||
var rs RatchetSafety
|
||||
if len(ratchetSafetyBytes) > 0 {
|
||||
err = json.Unmarshal(ratchetSafetyBytes, &rs)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to unmarshal ratchet safety info: %w", err)
|
||||
}
|
||||
}
|
||||
if senderKey == "" {
|
||||
senderKey = id.Curve25519(senderKeyDB.String)
|
||||
}
|
||||
return &InboundGroupSession{
|
||||
Internal: *igs,
|
||||
SigningKey: id.Ed25519(signingKey.String),
|
||||
SenderKey: senderKey,
|
||||
RoomID: roomID,
|
||||
ForwardingChains: chains,
|
||||
RatchetSafety: rs,
|
||||
ReceivedAt: receivedAt.Time,
|
||||
MaxAge: maxAge.Int64,
|
||||
MaxMessages: int(maxMessages.Int64),
|
||||
IsScheduled: isScheduled,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (store *SQLCryptoStore) RedactGroupSession(_ id.RoomID, _ id.SenderKey, sessionID id.SessionID, reason string) error {
|
||||
_, err := store.DB.Exec(`
|
||||
UPDATE crypto_megolm_inbound_session
|
||||
SET withheld_code=$1, withheld_reason=$2, session=NULL, forwarding_chains=NULL
|
||||
WHERE session_id=$3 AND account_id=$4 AND session IS NOT NULL
|
||||
`, event.RoomKeyWithheldBeeperRedacted, "Session redacted: "+reason, sessionID, store.AccountID)
|
||||
return err
|
||||
}
|
||||
|
||||
func (store *SQLCryptoStore) RedactGroupSessions(roomID id.RoomID, senderKey id.SenderKey, reason string) ([]id.SessionID, error) {
|
||||
if roomID == "" && senderKey == "" {
|
||||
return nil, fmt.Errorf("room ID or sender key must be provided for redacting sessions")
|
||||
}
|
||||
res, err := store.DB.Query(`
|
||||
UPDATE crypto_megolm_inbound_session
|
||||
SET withheld_code=$1, withheld_reason=$2, session=NULL, forwarding_chains=NULL
|
||||
WHERE (room_id=$3 OR $3='') AND (sender_key=$4 OR $4='') AND account_id=$5
|
||||
AND session IS NOT NULL AND is_scheduled=false AND received_at IS NOT NULL
|
||||
RETURNING session_id
|
||||
`, event.RoomKeyWithheldBeeperRedacted, "Session redacted: "+reason, roomID, senderKey, store.AccountID)
|
||||
var sessionIDs []id.SessionID
|
||||
for res.Next() {
|
||||
var sessionID id.SessionID
|
||||
_ = res.Scan(&sessionID)
|
||||
sessionIDs = append(sessionIDs, sessionID)
|
||||
}
|
||||
return sessionIDs, err
|
||||
}
|
||||
|
||||
func (store *SQLCryptoStore) RedactExpiredGroupSessions() ([]id.SessionID, error) {
|
||||
var query string
|
||||
switch store.DB.Dialect {
|
||||
case dbutil.Postgres:
|
||||
query = `
|
||||
UPDATE crypto_megolm_inbound_session
|
||||
SET withheld_code=$1, withheld_reason=$2, session=NULL, forwarding_chains=NULL
|
||||
WHERE account_id=$3 AND session IS NOT NULL AND is_scheduled=false
|
||||
AND received_at IS NOT NULL and max_age IS NOT NULL
|
||||
AND received_at + 2 * (max_age * interval '1 millisecond') < now()
|
||||
RETURNING session_id
|
||||
`
|
||||
case dbutil.SQLite:
|
||||
query = `
|
||||
UPDATE crypto_megolm_inbound_session
|
||||
SET withheld_code=$1, withheld_reason=$2, session=NULL, forwarding_chains=NULL
|
||||
WHERE account_id=$3 AND session IS NOT NULL AND is_scheduled=false
|
||||
AND received_at IS NOT NULL and max_age IS NOT NULL
|
||||
AND unixepoch(received_at) + (2 * max_age / 1000) < unixepoch(date('now'))
|
||||
RETURNING session_id
|
||||
`
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported dialect")
|
||||
}
|
||||
res, err := store.DB.Query(query, event.RoomKeyWithheldBeeperRedacted, "Session redacted: expired", store.AccountID)
|
||||
var sessionIDs []id.SessionID
|
||||
for res.Next() {
|
||||
var sessionID id.SessionID
|
||||
_ = res.Scan(&sessionID)
|
||||
sessionIDs = append(sessionIDs, sessionID)
|
||||
}
|
||||
return sessionIDs, err
|
||||
}
|
||||
|
||||
func (store *SQLCryptoStore) PutWithheldGroupSession(content event.RoomKeyWithheldEventContent) error {
|
||||
_, err := store.DB.Exec("INSERT INTO crypto_megolm_inbound_session (session_id, sender_key, room_id, withheld_code, withheld_reason, account_id) VALUES ($1, $2, $3, $4, $5, $6)",
|
||||
content.SessionID, content.SenderKey, content.RoomID, content.Code, content.Reason, store.AccountID)
|
||||
_, err := store.DB.Exec("INSERT INTO crypto_megolm_inbound_session (session_id, sender_key, room_id, withheld_code, withheld_reason, received_at, account_id) VALUES ($1, $2, $3, $4, $5, $6, $7)",
|
||||
content.SessionID, content.SenderKey, content.RoomID, content.Code, content.Reason, time.Now().UTC(), store.AccountID)
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -302,8 +450,11 @@ func (store *SQLCryptoStore) scanGroupSessionList(rows dbutil.Rows) (result []*I
|
||||
for rows.Next() {
|
||||
var roomID id.RoomID
|
||||
var signingKey, senderKey, forwardingChains sql.NullString
|
||||
var sessionBytes []byte
|
||||
err = rows.Scan(&roomID, &signingKey, &senderKey, &sessionBytes, &forwardingChains)
|
||||
var sessionBytes, ratchetSafetyBytes []byte
|
||||
var receivedAt sql.NullTime
|
||||
var maxAge, maxMessages sql.NullInt64
|
||||
var isScheduled bool
|
||||
err = rows.Scan(&roomID, &signingKey, &senderKey, &sessionBytes, &forwardingChains, &ratchetSafetyBytes, &receivedAt, &maxAge, &maxMessages, &isScheduled)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
@@ -316,12 +467,24 @@ func (store *SQLCryptoStore) scanGroupSessionList(rows dbutil.Rows) (result []*I
|
||||
if forwardingChains.String != "" {
|
||||
chains = strings.Split(forwardingChains.String, ",")
|
||||
}
|
||||
var rs RatchetSafety
|
||||
if len(ratchetSafetyBytes) > 0 {
|
||||
err = json.Unmarshal(ratchetSafetyBytes, &rs)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to unmarshal ratchet safety info: %w", err)
|
||||
}
|
||||
}
|
||||
result = append(result, &InboundGroupSession{
|
||||
Internal: *igs,
|
||||
SigningKey: id.Ed25519(signingKey.String),
|
||||
SenderKey: id.Curve25519(senderKey.String),
|
||||
RoomID: roomID,
|
||||
ForwardingChains: chains,
|
||||
RatchetSafety: rs,
|
||||
ReceivedAt: receivedAt.Time,
|
||||
MaxAge: maxAge.Int64,
|
||||
MaxMessages: int(maxMessages.Int64),
|
||||
IsScheduled: isScheduled,
|
||||
})
|
||||
}
|
||||
return
|
||||
@@ -329,7 +492,7 @@ func (store *SQLCryptoStore) scanGroupSessionList(rows dbutil.Rows) (result []*I
|
||||
|
||||
func (store *SQLCryptoStore) GetGroupSessionsForRoom(roomID id.RoomID) ([]*InboundGroupSession, error) {
|
||||
rows, err := store.DB.Query(`
|
||||
SELECT room_id, signing_key, sender_key, session, forwarding_chains
|
||||
SELECT room_id, signing_key, sender_key, session, forwarding_chains, ratchet_safety, received_at, max_age, max_messages, is_scheduled
|
||||
FROM crypto_megolm_inbound_session WHERE room_id=$1 AND account_id=$2 AND session IS NOT NULL`,
|
||||
roomID, store.AccountID,
|
||||
)
|
||||
@@ -343,7 +506,7 @@ func (store *SQLCryptoStore) GetGroupSessionsForRoom(roomID id.RoomID) ([]*Inbou
|
||||
|
||||
func (store *SQLCryptoStore) GetAllGroupSessions() ([]*InboundGroupSession, error) {
|
||||
rows, err := store.DB.Query(`
|
||||
SELECT room_id, signing_key, sender_key, session, forwarding_chains
|
||||
SELECT room_id, signing_key, sender_key, session, forwarding_chains, ratchet_safety, received_at, max_age, max_messages, is_scheduled
|
||||
FROM crypto_megolm_inbound_session WHERE account_id=$2 AND session IS NOT NULL`,
|
||||
store.AccountID,
|
||||
)
|
||||
@@ -367,7 +530,7 @@ func (store *SQLCryptoStore) AddOutboundGroupSession(session *OutboundGroupSessi
|
||||
max_messages=excluded.max_messages, message_count=excluded.message_count, max_age=excluded.max_age,
|
||||
created_at=excluded.created_at, last_used=excluded.last_used, account_id=excluded.account_id
|
||||
`, session.RoomID, session.ID(), sessionBytes, session.Shared, session.MaxMessages, session.MessageCount,
|
||||
session.MaxAge, session.CreationTime, session.LastEncryptedTime, store.AccountID)
|
||||
session.MaxAge.Milliseconds(), session.CreationTime, session.LastEncryptedTime, store.AccountID)
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -383,11 +546,12 @@ func (store *SQLCryptoStore) UpdateOutboundGroupSession(session *OutboundGroupSe
|
||||
func (store *SQLCryptoStore) GetOutboundGroupSession(roomID id.RoomID) (*OutboundGroupSession, error) {
|
||||
var ogs OutboundGroupSession
|
||||
var sessionBytes []byte
|
||||
var maxAgeMS int64
|
||||
err := store.DB.QueryRow(`
|
||||
SELECT session, shared, max_messages, message_count, max_age, created_at, last_used
|
||||
FROM crypto_megolm_outbound_session WHERE room_id=$1 AND account_id=$2`,
|
||||
roomID, store.AccountID,
|
||||
).Scan(&sessionBytes, &ogs.Shared, &ogs.MaxMessages, &ogs.MessageCount, &ogs.MaxAge, &ogs.CreationTime, &ogs.LastEncryptedTime)
|
||||
).Scan(&sessionBytes, &ogs.Shared, &ogs.MaxMessages, &ogs.MessageCount, &maxAgeMS, &ogs.CreationTime, &ogs.LastEncryptedTime)
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, nil
|
||||
} else if err != nil {
|
||||
@@ -400,6 +564,7 @@ func (store *SQLCryptoStore) GetOutboundGroupSession(roomID id.RoomID) (*Outboun
|
||||
}
|
||||
ogs.Internal = *intOGS
|
||||
ogs.RoomID = roomID
|
||||
ogs.MaxAge = time.Duration(maxAgeMS) * time.Millisecond
|
||||
return &ogs, nil
|
||||
}
|
||||
|
||||
@@ -412,7 +577,7 @@ func (store *SQLCryptoStore) RemoveOutboundGroupSession(roomID id.RoomID) error
|
||||
|
||||
// ValidateMessageIndex returns whether the given event information match the ones stored in the database
|
||||
// for the given sender key, session ID and index. If the index hasn't been stored, this will store it.
|
||||
func (store *SQLCryptoStore) ValidateMessageIndex(senderKey id.SenderKey, sessionID id.SessionID, eventID id.EventID, index uint, timestamp int64) (bool, error) {
|
||||
func (store *SQLCryptoStore) ValidateMessageIndex(ctx context.Context, senderKey id.SenderKey, sessionID id.SessionID, eventID id.EventID, index uint, timestamp int64) (bool, error) {
|
||||
const validateQuery = `
|
||||
INSERT INTO crypto_message_index (sender_key, session_id, "index", event_id, timestamp)
|
||||
VALUES ($1, $2, $3, $4, $5)
|
||||
@@ -422,11 +587,19 @@ func (store *SQLCryptoStore) ValidateMessageIndex(senderKey id.SenderKey, sessio
|
||||
`
|
||||
var expectedEventID id.EventID
|
||||
var expectedTimestamp int64
|
||||
err := store.DB.QueryRow(validateQuery, senderKey, sessionID, index, eventID, timestamp).Scan(&expectedEventID, &expectedTimestamp)
|
||||
err := store.DB.QueryRowContext(ctx, validateQuery, senderKey, sessionID, index, eventID, timestamp).Scan(&expectedEventID, &expectedTimestamp)
|
||||
if err != nil {
|
||||
return false, err
|
||||
} else if expectedEventID != eventID || expectedTimestamp != timestamp {
|
||||
zerolog.Ctx(ctx).Debug().
|
||||
Uint("message_index", index).
|
||||
Str("expected_event_id", expectedEventID.String()).
|
||||
Int64("expected_timestamp", expectedTimestamp).
|
||||
Int64("actual_timestamp", timestamp).
|
||||
Msg("Failed to validate that message index wasn't duplicated")
|
||||
return false, nil
|
||||
}
|
||||
return expectedEventID == eventID && expectedTimestamp == timestamp, nil
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// GetDevices returns a map of device IDs to device identities, including the identity and signing keys, for a given user ID.
|
||||
|
||||
Reference in New Issue
Block a user