BREAKING: update mautrix to 0.15.x

This commit is contained in:
Aine
2023-06-01 14:32:20 +00:00
parent a6b20a75ab
commit 2bdb8ca635
222 changed files with 7851 additions and 23986 deletions

View File

@@ -1,6 +1,6 @@
# maulogger
A logger in Go.
A logger in Go. Deprecated in favor of [zerolog](https://github.com/rs/zerolog).
Docs: [godoc.org/maunium.net/go/maulogger](https://godoc.org/maunium.net/go/maulogger)
Go get: `go get maunium.net/go/maulogger`
Utilities for migrating gracefully can be found in the maulogadapt package,
it includes both wrapping a zerolog in the maulogger interface, and wrapping a
maulogger as a zerolog output writer.

View File

@@ -0,0 +1,185 @@
// Copyright (c) 2023 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
package maulogadapt
import (
"fmt"
"io"
"strings"
"github.com/rs/zerolog"
"maunium.net/go/maulogger/v2"
)
type MauZeroLog struct {
*zerolog.Logger
orig *zerolog.Logger
mod string
}
func ZeroAsMau(log *zerolog.Logger) maulogger.Logger {
return MauZeroLog{log, log, ""}
}
var _ maulogger.Logger = (*MauZeroLog)(nil)
func (m MauZeroLog) Sub(module string) maulogger.Logger {
return m.Subm(module, map[string]interface{}{})
}
func (m MauZeroLog) Subm(module string, metadata map[string]interface{}) maulogger.Logger {
if m.mod != "" {
module = fmt.Sprintf("%s/%s", m.mod, module)
}
var orig zerolog.Logger
if m.orig != nil {
orig = *m.orig
} else {
orig = *m.Logger
}
if len(metadata) > 0 {
with := m.orig.With()
for key, value := range metadata {
with = with.Interface(key, value)
}
orig = with.Logger()
}
log := orig.With().Str("module", module).Logger()
return MauZeroLog{&log, &orig, module}
}
func (m MauZeroLog) WithDefaultLevel(_ maulogger.Level) maulogger.Logger {
return m
}
func (m MauZeroLog) GetParent() maulogger.Logger {
return nil
}
type nopWriteCloser struct {
io.Writer
}
func (nopWriteCloser) Close() error { return nil }
func (m MauZeroLog) Writer(level maulogger.Level) io.WriteCloser {
return nopWriteCloser{m.Logger.With().Str(zerolog.LevelFieldName, zerolog.LevelFieldMarshalFunc(mauToZeroLevel(level))).Logger()}
}
func mauToZeroLevel(level maulogger.Level) zerolog.Level {
switch level {
case maulogger.LevelDebug:
return zerolog.DebugLevel
case maulogger.LevelInfo:
return zerolog.InfoLevel
case maulogger.LevelWarn:
return zerolog.WarnLevel
case maulogger.LevelError:
return zerolog.ErrorLevel
case maulogger.LevelFatal:
return zerolog.FatalLevel
default:
return zerolog.TraceLevel
}
}
func (m MauZeroLog) Log(level maulogger.Level, parts ...interface{}) {
m.Logger.WithLevel(mauToZeroLevel(level)).Msg(fmt.Sprint(parts...))
}
func (m MauZeroLog) Logln(level maulogger.Level, parts ...interface{}) {
m.Logger.WithLevel(mauToZeroLevel(level)).Msg(strings.TrimSuffix(fmt.Sprintln(parts...), "\n"))
}
func (m MauZeroLog) Logf(level maulogger.Level, message string, args ...interface{}) {
m.Logger.WithLevel(mauToZeroLevel(level)).Msg(fmt.Sprintf(message, args...))
}
func (m MauZeroLog) Logfln(level maulogger.Level, message string, args ...interface{}) {
m.Logger.WithLevel(mauToZeroLevel(level)).Msg(fmt.Sprintf(message, args...))
}
func (m MauZeroLog) Debug(parts ...interface{}) {
m.Logger.Debug().Msg(fmt.Sprint(parts...))
}
func (m MauZeroLog) Debugln(parts ...interface{}) {
m.Logger.Debug().Msg(strings.TrimSuffix(fmt.Sprintln(parts...), "\n"))
}
func (m MauZeroLog) Debugf(message string, args ...interface{}) {
m.Logger.Debug().Msg(fmt.Sprintf(message, args...))
}
func (m MauZeroLog) Debugfln(message string, args ...interface{}) {
m.Logger.Debug().Msg(fmt.Sprintf(message, args...))
}
func (m MauZeroLog) Info(parts ...interface{}) {
m.Logger.Info().Msg(fmt.Sprint(parts...))
}
func (m MauZeroLog) Infoln(parts ...interface{}) {
m.Logger.Info().Msg(strings.TrimSuffix(fmt.Sprintln(parts...), "\n"))
}
func (m MauZeroLog) Infof(message string, args ...interface{}) {
m.Logger.Info().Msg(fmt.Sprintf(message, args...))
}
func (m MauZeroLog) Infofln(message string, args ...interface{}) {
m.Logger.Info().Msg(fmt.Sprintf(message, args...))
}
func (m MauZeroLog) Warn(parts ...interface{}) {
m.Logger.Warn().Msg(fmt.Sprint(parts...))
}
func (m MauZeroLog) Warnln(parts ...interface{}) {
m.Logger.Warn().Msg(strings.TrimSuffix(fmt.Sprintln(parts...), "\n"))
}
func (m MauZeroLog) Warnf(message string, args ...interface{}) {
m.Logger.Warn().Msg(fmt.Sprintf(message, args...))
}
func (m MauZeroLog) Warnfln(message string, args ...interface{}) {
m.Logger.Warn().Msg(fmt.Sprintf(message, args...))
}
func (m MauZeroLog) Error(parts ...interface{}) {
m.Logger.Error().Msg(fmt.Sprint(parts...))
}
func (m MauZeroLog) Errorln(parts ...interface{}) {
m.Logger.Error().Msg(strings.TrimSuffix(fmt.Sprintln(parts...), "\n"))
}
func (m MauZeroLog) Errorf(message string, args ...interface{}) {
m.Logger.Error().Msg(fmt.Sprintf(message, args...))
}
func (m MauZeroLog) Errorfln(message string, args ...interface{}) {
m.Logger.Error().Msg(fmt.Sprintf(message, args...))
}
func (m MauZeroLog) Fatal(parts ...interface{}) {
m.Logger.WithLevel(zerolog.FatalLevel).Msg(fmt.Sprint(parts...))
}
func (m MauZeroLog) Fatalln(parts ...interface{}) {
m.Logger.WithLevel(zerolog.FatalLevel).Msg(strings.TrimSuffix(fmt.Sprintln(parts...), "\n"))
}
func (m MauZeroLog) Fatalf(message string, args ...interface{}) {
m.Logger.WithLevel(zerolog.FatalLevel).Msg(fmt.Sprintf(message, args...))
}
func (m MauZeroLog) Fatalfln(message string, args ...interface{}) {
m.Logger.WithLevel(zerolog.FatalLevel).Msg(fmt.Sprintf(message, args...))
}

View File

@@ -0,0 +1,73 @@
// Copyright (c) 2023 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
package maulogadapt
import (
"bytes"
"github.com/rs/zerolog"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
"maunium.net/go/maulogger/v2"
)
// ZeroMauLog is a simple wrapper for a maulogger that can be set as the output writer for zerolog.
type ZeroMauLog struct {
maulogger.Logger
}
func MauAsZero(log maulogger.Logger) *zerolog.Logger {
zero := zerolog.New(&ZeroMauLog{log})
return &zero
}
var _ zerolog.LevelWriter = (*ZeroMauLog)(nil)
func (z *ZeroMauLog) Write(p []byte) (n int, err error) {
return 0, nil
}
func (z *ZeroMauLog) WriteLevel(level zerolog.Level, p []byte) (n int, err error) {
var mauLevel maulogger.Level
switch level {
case zerolog.DebugLevel:
mauLevel = maulogger.LevelDebug
case zerolog.InfoLevel, zerolog.NoLevel:
mauLevel = maulogger.LevelInfo
case zerolog.WarnLevel:
mauLevel = maulogger.LevelWarn
case zerolog.ErrorLevel:
mauLevel = maulogger.LevelError
case zerolog.FatalLevel, zerolog.PanicLevel:
mauLevel = maulogger.LevelFatal
case zerolog.Disabled, zerolog.TraceLevel:
fallthrough
default:
return 0, nil
}
p = bytes.TrimSuffix(p, []byte{'\n'})
msg := gjson.GetBytes(p, zerolog.MessageFieldName).Str
p, err = sjson.DeleteBytes(p, zerolog.MessageFieldName)
if err != nil {
return
}
p, err = sjson.DeleteBytes(p, zerolog.LevelFieldName)
if err != nil {
return
}
p, err = sjson.DeleteBytes(p, zerolog.TimestampFieldName)
if err != nil {
return
}
if len(p) > 2 {
msg += " " + string(p)
}
z.Log(mauLevel, msg)
return len(p), nil
}

View File

@@ -1,2 +1,4 @@
.idea/
.vscode/
*.db
*.log

View File

@@ -1,6 +1,6 @@
repos:
- repo: https://github.com/pre-commit/pre-commit-hooks
rev: v4.1.0
rev: v4.4.0
hooks:
- id: trailing-whitespace
exclude_types: [markdown]
@@ -9,7 +9,7 @@ repos:
- id: check-added-large-files
- repo: https://github.com/tekwizely/pre-commit-golang
rev: v1.0.0-beta.5
rev: v1.0.0-rc.1
hooks:
- id: go-imports-repo
- id: go-vet-repo-mod

View File

@@ -1,3 +1,110 @@
## v0.15.2 (2023-05-16)
* *(client)* Changed member-fetching methods to clear existing member info in
state store.
* *(client)* Added support for inserting mautrix-go commit hash into default
user agent at compile time.
* *(bridge)* Fixed bridge bot intent not having state store set.
* *(client)* Fixed `RespError` marshaling mutating the `ExtraData` map and
potentially causing panics.
* *(util/dbutil)* Added `DoTxn` method for an easier way to manage database
transactions.
* *(util)* Added a zerolog `CallerMarshalFunc` implementation that includes the
function name.
* *(bridge)* Added error reply to encrypted messages if the bridge isn't
configured to do encryption.
## v0.15.1 (2023-04-16)
* *(crypto, bridge)* Added options to automatically ratchet/delete megolm
sessions to minimize access to old messages.
* *(pushrules)* Added method to get entire push rule that matched (instead of
only the list of actions).
* *(pushrules)* Deprecated `NotifySpecified` as there's no reason to read it.
* *(crypto)* Changed `max_age` column in `crypto_megolm_inbound_session` table
to be milliseconds instead of nanoseconds.
* *(util)* Added method for iterating `RingBuffer`.
* *(crypto/cryptohelper)* Changed decryption errors to request session from all
own devices in addition to the sender, instead of only asking the sender.
* *(sqlstatestore)* Fixed `FindSharedRooms` throwing an error when using from
a non-bridge context.
* *(client)* Optimized `AccountDataSyncStore` to not resend save requests if
the sync token didn't change.
* *(types)* Added `Clone()` method for `PowerLevelEventContent`.
## v0.15.0 (2023-03-16)
### beta.3 (2023-03-15)
* **Breaking change *(appservice)*** Removed `Load()` and `AppService.Init()`
functions. The struct should just be created with `Create()` and the relevant
fields should be filled manually.
* **Breaking change *(appservice)*** Removed public `HomeserverURL` field and
replaced it with a `SetHomeserverURL` method.
* *(appservice)* Added support for unix sockets for homeserver URL and
appservice HTTP server.
* *(client)* Changed request logging to log durations as floats instead of
strings (using zerolog's `Dur()`, so the exact output can be configured).
* *(bridge)* Changed zerolog to use nanosecond precision timestamps.
* *(crypto)* Added message index to log after encrypting/decrypting megolm
events, and when failing to decrypt due to duplicate index.
* *(sqlstatestore)* Fixed warning log for rooms that don't have encryption
enabled.
### beta.2 (2023-03-02)
* *(bridge)* Fixed building with `nocrypto` tag.
* *(bridge)* Fixed legacy logging config migration not disabling file writer
when `file_name_format` was empty.
* *(bridge)* Added option to require room power level to run commands.
* *(event)* Added structs for [MSC3952]: Intentional Mentions.
* *(util/variationselector)* Added `FullyQualify` method to add necessary emoji
variation selectors without adding all possible ones.
[MSC3952]: https://github.com/matrix-org/matrix-spec-proposals/pull/3952
### beta.1 (2023-02-24)
* Bumped minimum Go version to 1.19.
* **Breaking changes**
* *(all)* Switched to zerolog for logging.
* The `Client` and `Bridge` structs still include a legacy logger for
backwards compatibility.
* *(client, appservice)* Moved `SQLStateStore` from appservice module to the
top-level (client) module.
* *(client, appservice)* Removed unused `Typing` map in `SQLStateStore`.
* *(client)* Removed unused `SaveRoom` and `LoadRoom` methods in `Storer`.
* *(client, appservice)* Removed deprecated `SendVideo` and `SendImage` methods.
* *(client)* Replaced `AppServiceUserID` field with `SetAppServiceUserID` boolean.
The `UserID` field is used as the value for the query param.
* *(crypto)* Renamed `GobStore` to `MemoryStore` and removed the file saving
features. The data can still be persisted, but the persistence part must be
implemented separately.
* *(crypto)* Removed deprecated `DeviceIdentity` alias
(renamed to `id.Device` long ago).
* *(client)* Removed `Stringifable` interface as it's the same as `fmt.Stringer`.
* *(client)* Renamed `Storer` interface to `SyncStore`. A type alias exists for
backwards-compatibility.
* *(crypto/cryptohelper)* Added package for a simplified crypto interface for clients.
* *(example)* Added e2ee support to example using crypto helper.
* *(client)* Changed default syncer to stop syncing on `M_UNKNOWN_TOKEN` errors.
## v0.14.0 (2023-02-16)
* **Breaking change *(format)*** Refactored the HTML parser `Context` to have
more data.
* *(id)* Fixed escaping path components when forming matrix.to URLs
or `matrix:` URIs.
* *(bridge)* Bumped default timeouts for decrypting incoming messages.
* *(bridge)* Added `RawArgs` to commands to allow accessing non-split input.
* *(bridge)* Added `ReplyAdvanced` to commands to allow setting markdown
settings.
* *(event)* Added `notifications` key to `PowerLevelEventContent`.
* *(event)* Changed `SetEdit` to cut off edit fallback if the message is long.
* *(util)* Added `SyncMap` as a simple generic wrapper for a map with a mutex.
* *(util)* Added `ReturnableOnce` as a wrapper for `sync.Once` with a return
value.
## v0.13.0 (2023-01-16)
* **Breaking change:** Removed `IsTyping` and `SetTyping` in `appservice.StateStore`

View File

@@ -1,410 +0,0 @@
// Copyright (c) 2020 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
package appservice
import (
"errors"
"fmt"
"html/template"
"io/ioutil"
"net/http"
"net/http/cookiejar"
"os"
"path/filepath"
"strings"
"sync"
"syscall"
"time"
"github.com/gorilla/mux"
"github.com/gorilla/websocket"
"golang.org/x/net/publicsuffix"
"gopkg.in/yaml.v3"
"maunium.net/go/maulogger/v2"
"maunium.net/go/mautrix"
"maunium.net/go/mautrix/event"
"maunium.net/go/mautrix/id"
)
// EventChannelSize is the size for the Events channel in Appservice instances.
var EventChannelSize = 64
var OTKChannelSize = 4
// Create a blank appservice instance.
func Create() *AppService {
jar, _ := cookiejar.New(&cookiejar.Options{PublicSuffixList: publicsuffix.List})
return &AppService{
LogConfig: CreateLogConfig(),
clients: make(map[id.UserID]*mautrix.Client),
intents: make(map[id.UserID]*IntentAPI),
HTTPClient: &http.Client{Timeout: 180 * time.Second, Jar: jar},
StateStore: NewBasicStateStore(),
Router: mux.NewRouter(),
UserAgent: mautrix.DefaultUserAgent,
txnIDC: NewTransactionIDCache(128),
Live: true,
Ready: false,
}
}
// Load an appservice config from a file.
func Load(path string) (*AppService, error) {
data, readErr := ioutil.ReadFile(path)
if readErr != nil {
return nil, readErr
}
config := Create()
return config, yaml.Unmarshal(data, config)
}
// QueryHandler handles room alias and user ID queries from the homeserver.
type QueryHandler interface {
QueryAlias(alias string) bool
QueryUser(userID id.UserID) bool
}
type QueryHandlerStub struct{}
func (qh *QueryHandlerStub) QueryAlias(alias string) bool {
return false
}
func (qh *QueryHandlerStub) QueryUser(userID id.UserID) bool {
return false
}
type WebsocketHandler func(WebsocketCommand) (ok bool, data interface{})
// AppService is the main config for all appservices.
// It also serves as the appservice instance struct.
type AppService struct {
HomeserverDomain string `yaml:"homeserver_domain"`
HomeserverURL string `yaml:"homeserver_url"`
RegistrationPath string `yaml:"registration"`
Host HostConfig `yaml:"host"`
LogConfig LogConfig `yaml:"logging"`
Registration *Registration `yaml:"-"`
Log maulogger.Logger `yaml:"-"`
txnIDC *TransactionIDCache
Events chan *event.Event `yaml:"-"`
ToDeviceEvents chan *event.Event `yaml:"-"`
DeviceLists chan *mautrix.DeviceLists `yaml:"-"`
OTKCounts chan *mautrix.OTKCount `yaml:"-"`
QueryHandler QueryHandler `yaml:"-"`
StateStore StateStore `yaml:"-"`
Router *mux.Router `yaml:"-"`
UserAgent string `yaml:"-"`
server *http.Server
HTTPClient *http.Client
botClient *mautrix.Client
botIntent *IntentAPI
DefaultHTTPRetries int
Live bool
Ready bool
clients map[id.UserID]*mautrix.Client
clientsLock sync.RWMutex
intents map[id.UserID]*IntentAPI
intentsLock sync.RWMutex
ws *websocket.Conn
wsWriteLock sync.Mutex
StopWebsocket func(error)
websocketHandlers map[string]WebsocketHandler
websocketHandlersLock sync.RWMutex
websocketRequests map[int]chan<- *WebsocketCommand
websocketRequestsLock sync.RWMutex
websocketRequestID int32
// ProcessID is an identifier sent to the websocket proxy for debugging connections
ProcessID string
DoublePuppetValue string
GetProfile func(userID id.UserID, roomID id.RoomID) *event.MemberEventContent
}
const DoublePuppetKey = "fi.mau.double_puppet_source"
func getDefaultProcessID() string {
pid := syscall.Getpid()
uid := syscall.Getuid()
hostname, _ := os.Hostname()
return fmt.Sprintf("%s-%d-%d", hostname, uid, pid)
}
func (as *AppService) PrepareWebsocket() {
as.websocketHandlersLock.Lock()
defer as.websocketHandlersLock.Unlock()
if as.websocketHandlers == nil {
as.websocketHandlers = make(map[string]WebsocketHandler, 32)
as.websocketRequests = make(map[int]chan<- *WebsocketCommand)
}
}
// HostConfig contains info about how to host the appservice.
type HostConfig struct {
Hostname string `yaml:"hostname"`
Port uint16 `yaml:"port"`
TLSKey string `yaml:"tls_key,omitempty"`
TLSCert string `yaml:"tls_cert,omitempty"`
}
// Address gets the whole address of the Appservice.
func (hc *HostConfig) Address() string {
return fmt.Sprintf("%s:%d", hc.Hostname, hc.Port)
}
// Save saves this config into a file at the given path.
func (as *AppService) Save(path string) error {
data, err := yaml.Marshal(as)
if err != nil {
return err
}
return ioutil.WriteFile(path, data, 0644)
}
// YAML returns the config in YAML format.
func (as *AppService) YAML() (string, error) {
data, err := yaml.Marshal(as)
if err != nil {
return "", err
}
return string(data), nil
}
func (as *AppService) BotMXID() id.UserID {
return id.NewUserID(as.Registration.SenderLocalpart, as.HomeserverDomain)
}
func (as *AppService) makeIntent(userID id.UserID) *IntentAPI {
as.intentsLock.Lock()
defer as.intentsLock.Unlock()
intent, ok := as.intents[userID]
if ok {
return intent
}
localpart, homeserver, err := userID.Parse()
if err != nil || len(localpart) == 0 || homeserver != as.HomeserverDomain {
if err != nil {
as.Log.Fatalfln("Failed to parse user ID %s: %v", userID, err)
} else if len(localpart) == 0 {
as.Log.Fatalfln("Failed to make intent for %s: localpart is empty", userID)
} else if homeserver != as.HomeserverDomain {
as.Log.Fatalfln("Failed to make intent for %s: homeserver isn't %s", userID, as.HomeserverDomain)
}
return nil
}
intent = as.NewIntentAPI(localpart)
as.intents[userID] = intent
return intent
}
func (as *AppService) Intent(userID id.UserID) *IntentAPI {
as.intentsLock.RLock()
intent, ok := as.intents[userID]
as.intentsLock.RUnlock()
if !ok {
return as.makeIntent(userID)
}
return intent
}
func (as *AppService) BotIntent() *IntentAPI {
if as.botIntent == nil {
as.botIntent = as.makeIntent(as.BotMXID())
}
return as.botIntent
}
func (as *AppService) makeClient(userID id.UserID) *mautrix.Client {
as.clientsLock.Lock()
defer as.clientsLock.Unlock()
client, ok := as.clients[userID]
if ok {
return client
}
client, err := mautrix.NewClient(as.HomeserverURL, userID, as.Registration.AppToken)
if err != nil {
as.Log.Fatalln("Failed to create mautrix client instance:", err)
return nil
}
client.UserAgent = as.UserAgent
client.Syncer = nil
client.Store = nil
client.AppServiceUserID = userID
client.Logger = as.Log.Sub(string(userID))
client.Client = as.HTTPClient
client.DefaultHTTPRetries = as.DefaultHTTPRetries
as.clients[userID] = client
return client
}
func (as *AppService) Client(userID id.UserID) *mautrix.Client {
as.clientsLock.RLock()
client, ok := as.clients[userID]
as.clientsLock.RUnlock()
if !ok {
return as.makeClient(userID)
}
return client
}
func (as *AppService) BotClient() *mautrix.Client {
if as.botClient == nil {
as.botClient = as.makeClient(as.BotMXID())
as.botClient.Logger = as.Log.Sub("Bot")
}
return as.botClient
}
// Init initializes the logger and loads the registration of this appservice.
func (as *AppService) Init() (bool, error) {
as.Events = make(chan *event.Event, EventChannelSize)
as.ToDeviceEvents = make(chan *event.Event, EventChannelSize)
as.OTKCounts = make(chan *mautrix.OTKCount, OTKChannelSize)
as.DeviceLists = make(chan *mautrix.DeviceLists, EventChannelSize)
as.QueryHandler = &QueryHandlerStub{}
if len(as.UserAgent) == 0 {
as.UserAgent = mautrix.DefaultUserAgent
}
if len(as.ProcessID) == 0 {
as.ProcessID = getDefaultProcessID()
}
as.Log = maulogger.Create()
as.LogConfig.Configure(as.Log)
as.Log.Debugln("Logger initialized successfully.")
if len(as.RegistrationPath) > 0 {
var err error
as.Registration, err = LoadRegistration(as.RegistrationPath)
if err != nil {
return false, err
}
}
as.Log.Debugln("Appservice initialized successfully.")
return true, nil
}
// LogConfig contains configs for the logger.
type LogConfig struct {
Directory string `yaml:"directory"`
FileNameFormat string `yaml:"file_name_format"`
FileDateFormat string `yaml:"file_date_format"`
FileMode uint32 `yaml:"file_mode"`
TimestampFormat string `yaml:"timestamp_format"`
RawPrintLevel string `yaml:"print_level"`
JSONStdout bool `yaml:"print_json"`
JSONFile bool `yaml:"file_json"`
PrintLevel int `yaml:"-"`
}
type umLogConfig LogConfig
func (lc *LogConfig) UnmarshalYAML(unmarshal func(interface{}) error) error {
err := unmarshal((*umLogConfig)(lc))
if err != nil {
return err
}
switch strings.ToUpper(lc.RawPrintLevel) {
case "TRACE":
lc.PrintLevel = -10
case "DEBUG":
lc.PrintLevel = maulogger.LevelDebug.Severity
case "INFO":
lc.PrintLevel = maulogger.LevelInfo.Severity
case "WARN", "WARNING":
lc.PrintLevel = maulogger.LevelWarn.Severity
case "ERR", "ERROR":
lc.PrintLevel = maulogger.LevelError.Severity
case "FATAL":
lc.PrintLevel = maulogger.LevelFatal.Severity
default:
return errors.New("invalid print level " + lc.RawPrintLevel)
}
return err
}
func (lc *LogConfig) MarshalYAML() (interface{}, error) {
switch {
case lc.PrintLevel >= maulogger.LevelFatal.Severity:
lc.RawPrintLevel = maulogger.LevelFatal.Name
case lc.PrintLevel >= maulogger.LevelError.Severity:
lc.RawPrintLevel = maulogger.LevelError.Name
case lc.PrintLevel >= maulogger.LevelWarn.Severity:
lc.RawPrintLevel = maulogger.LevelWarn.Name
case lc.PrintLevel >= maulogger.LevelInfo.Severity:
lc.RawPrintLevel = maulogger.LevelInfo.Name
default:
lc.RawPrintLevel = maulogger.LevelDebug.Name
}
return lc, nil
}
// CreateLogConfig creates a basic LogConfig.
func CreateLogConfig() LogConfig {
return LogConfig{
Directory: "./logs",
FileNameFormat: "%[1]s-%02[2]d.log",
TimestampFormat: "Jan _2, 2006 15:04:05",
FileMode: 0600,
FileDateFormat: "2006-01-02",
PrintLevel: 10,
}
}
type FileFormatData struct {
Date string
Index int
}
// GetFileFormat returns a mauLogger-compatible logger file format based on the data in the struct.
func (lc LogConfig) GetFileFormat() maulogger.LoggerFileFormat {
if len(lc.Directory) > 0 {
_ = os.MkdirAll(lc.Directory, 0700)
}
path := filepath.Join(lc.Directory, lc.FileNameFormat)
tpl, _ := template.New("fileformat").Parse(path)
return func(now string, i int) string {
var buf strings.Builder
_ = tpl.Execute(&buf, FileFormatData{
Date: now,
Index: i,
})
return buf.String()
}
}
// Configure configures a mauLogger instance with the data in this struct.
func (lc LogConfig) Configure(log maulogger.Logger) {
basicLogger := log.(*maulogger.BasicLogger)
basicLogger.FileFormat = lc.GetFileFormat()
basicLogger.FileMode = os.FileMode(lc.FileMode)
basicLogger.FileTimeFormat = lc.FileDateFormat
basicLogger.TimeFormat = lc.TimestampFormat
basicLogger.PrintLevel = lc.PrintLevel
basicLogger.JSONFile = lc.JSONFile
if lc.JSONStdout {
basicLogger.EnableJSONStdout()
}
}

View File

@@ -1,173 +0,0 @@
// Copyright (c) 2020 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
package appservice
import (
"encoding/json"
"runtime/debug"
log "maunium.net/go/maulogger/v2"
"maunium.net/go/mautrix"
"maunium.net/go/mautrix/event"
)
type ExecMode uint8
const (
AsyncHandlers ExecMode = iota
AsyncLoop
Sync
)
type EventHandler func(evt *event.Event)
type OTKHandler func(otk *mautrix.OTKCount)
type DeviceListHandler func(otk *mautrix.DeviceLists, since string)
type EventProcessor struct {
ExecMode ExecMode
as *AppService
log log.Logger
stop chan struct{}
handlers map[event.Type][]EventHandler
otkHandlers []OTKHandler
deviceListHandlers []DeviceListHandler
}
func NewEventProcessor(as *AppService) *EventProcessor {
return &EventProcessor{
ExecMode: AsyncHandlers,
as: as,
log: as.Log.Sub("Events"),
stop: make(chan struct{}, 1),
handlers: make(map[event.Type][]EventHandler),
otkHandlers: make([]OTKHandler, 0),
deviceListHandlers: make([]DeviceListHandler, 0),
}
}
func (ep *EventProcessor) On(evtType event.Type, handler EventHandler) {
handlers, ok := ep.handlers[evtType]
if !ok {
handlers = []EventHandler{handler}
} else {
handlers = append(handlers, handler)
}
ep.handlers[evtType] = handlers
}
func (ep *EventProcessor) PrependHandler(evtType event.Type, handler EventHandler) {
handlers, ok := ep.handlers[evtType]
if !ok {
handlers = []EventHandler{handler}
} else {
handlers = append([]EventHandler{handler}, handlers...)
}
ep.handlers[evtType] = handlers
}
func (ep *EventProcessor) OnOTK(handler OTKHandler) {
ep.otkHandlers = append(ep.otkHandlers, handler)
}
func (ep *EventProcessor) OnDeviceList(handler DeviceListHandler) {
ep.deviceListHandlers = append(ep.deviceListHandlers, handler)
}
func (ep *EventProcessor) recoverFunc(data interface{}) {
if err := recover(); err != nil {
d, _ := json.Marshal(data)
ep.log.Errorfln("Panic in Matrix event handler: %v (event content: %s):\n%s", err, string(d), string(debug.Stack()))
}
}
func (ep *EventProcessor) callHandler(handler EventHandler, evt *event.Event) {
defer ep.recoverFunc(evt)
handler(evt)
}
func (ep *EventProcessor) callOTKHandler(handler OTKHandler, otk *mautrix.OTKCount) {
defer ep.recoverFunc(otk)
handler(otk)
}
func (ep *EventProcessor) callDeviceListHandler(handler DeviceListHandler, dl *mautrix.DeviceLists) {
defer ep.recoverFunc(dl)
handler(dl, "")
}
func (ep *EventProcessor) DispatchOTK(otk *mautrix.OTKCount) {
for _, handler := range ep.otkHandlers {
go ep.callOTKHandler(handler, otk)
}
}
func (ep *EventProcessor) DispatchDeviceList(dl *mautrix.DeviceLists) {
for _, handler := range ep.deviceListHandlers {
go ep.callDeviceListHandler(handler, dl)
}
}
func (ep *EventProcessor) Dispatch(evt *event.Event) {
handlers, ok := ep.handlers[evt.Type]
if !ok {
return
}
switch ep.ExecMode {
case AsyncHandlers:
for _, handler := range handlers {
go ep.callHandler(handler, evt)
}
case AsyncLoop:
go func() {
for _, handler := range handlers {
ep.callHandler(handler, evt)
}
}()
case Sync:
for _, handler := range handlers {
ep.callHandler(handler, evt)
}
}
}
func (ep *EventProcessor) startEvents() {
for {
select {
case evt := <-ep.as.Events:
ep.Dispatch(evt)
case <-ep.stop:
return
}
}
}
func (ep *EventProcessor) startEncryption() {
for {
select {
case evt := <-ep.as.ToDeviceEvents:
ep.Dispatch(evt)
case otk := <-ep.as.OTKCounts:
ep.DispatchOTK(otk)
case dl := <-ep.as.DeviceLists:
ep.DispatchDeviceList(dl)
case <-ep.stop:
return
}
}
}
func (ep *EventProcessor) Start() {
go ep.startEvents()
go ep.startEncryption()
}
func (ep *EventProcessor) Stop() {
close(ep.stop)
}

View File

@@ -1,281 +0,0 @@
// Copyright (c) 2022 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
package appservice
import (
"context"
"encoding/json"
"errors"
"io/ioutil"
"net/http"
"strings"
"time"
"github.com/gorilla/mux"
"maunium.net/go/mautrix"
"maunium.net/go/mautrix/event"
"maunium.net/go/mautrix/id"
)
// Start starts the HTTP server that listens for calls from the Matrix homeserver.
func (as *AppService) Start() {
as.Router.HandleFunc("/transactions/{txnID}", as.PutTransaction).Methods(http.MethodPut)
as.Router.HandleFunc("/rooms/{roomAlias}", as.GetRoom).Methods(http.MethodGet)
as.Router.HandleFunc("/users/{userID}", as.GetUser).Methods(http.MethodGet)
as.Router.HandleFunc("/_matrix/app/v1/transactions/{txnID}", as.PutTransaction).Methods(http.MethodPut)
as.Router.HandleFunc("/_matrix/app/v1/rooms/{roomAlias}", as.GetRoom).Methods(http.MethodGet)
as.Router.HandleFunc("/_matrix/app/v1/users/{userID}", as.GetUser).Methods(http.MethodGet)
as.Router.HandleFunc("/_matrix/mau/live", as.GetLive).Methods(http.MethodGet)
as.Router.HandleFunc("/_matrix/mau/ready", as.GetReady).Methods(http.MethodGet)
var err error
as.server = &http.Server{
Addr: as.Host.Address(),
Handler: as.Router,
}
as.Log.Infoln("Listening on", as.Host.Address())
if len(as.Host.TLSCert) == 0 || len(as.Host.TLSKey) == 0 {
err = as.server.ListenAndServe()
} else {
err = as.server.ListenAndServeTLS(as.Host.TLSCert, as.Host.TLSKey)
}
if err != nil && !errors.Is(err, http.ErrServerClosed) {
as.Log.Fatalln("Error while listening:", err)
} else {
as.Log.Debugln("Listener stopped.")
}
}
func (as *AppService) Stop() {
if as.server == nil {
return
}
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
defer cancel()
_ = as.server.Shutdown(ctx)
as.server = nil
}
// CheckServerToken checks if the given request originated from the Matrix homeserver.
func (as *AppService) CheckServerToken(w http.ResponseWriter, r *http.Request) (isValid bool) {
authHeader := r.Header.Get("Authorization")
if len(authHeader) > 0 && strings.HasPrefix(authHeader, "Bearer ") {
isValid = authHeader[len("Bearer "):] == as.Registration.ServerToken
} else {
queryToken := r.URL.Query().Get("access_token")
if len(queryToken) > 0 {
isValid = queryToken == as.Registration.ServerToken
} else {
Error{
ErrorCode: ErrUnknownToken,
HTTPStatus: http.StatusForbidden,
Message: "Missing access token",
}.Write(w)
return
}
}
if !isValid {
Error{
ErrorCode: ErrUnknownToken,
HTTPStatus: http.StatusForbidden,
Message: "Incorrect access token",
}.Write(w)
}
return
}
// PutTransaction handles a /transactions PUT call from the homeserver.
func (as *AppService) PutTransaction(w http.ResponseWriter, r *http.Request) {
if !as.CheckServerToken(w, r) {
return
}
vars := mux.Vars(r)
txnID := vars["txnID"]
if len(txnID) == 0 {
Error{
ErrorCode: ErrNoTransactionID,
HTTPStatus: http.StatusBadRequest,
Message: "Missing transaction ID",
}.Write(w)
return
}
defer r.Body.Close()
body, err := ioutil.ReadAll(r.Body)
if err != nil || len(body) == 0 {
Error{
ErrorCode: ErrNotJSON,
HTTPStatus: http.StatusBadRequest,
Message: "Missing request body",
}.Write(w)
return
}
if as.txnIDC.IsProcessed(txnID) {
// Duplicate transaction ID: no-op
WriteBlankOK(w)
as.Log.Debugfln("Ignoring duplicate transaction %s", txnID)
return
}
var txn Transaction
err = json.Unmarshal(body, &txn)
if err != nil {
as.Log.Warnfln("Failed to parse JSON of transaction %s: %v", txnID, err)
Error{
ErrorCode: ErrBadJSON,
HTTPStatus: http.StatusBadRequest,
Message: "Failed to parse body JSON",
}.Write(w)
} else {
as.handleTransaction(txnID, &txn)
WriteBlankOK(w)
}
}
func (as *AppService) handleTransaction(id string, txn *Transaction) {
as.Log.Debugfln("Starting handling of transaction %s (%s)", id, txn.ContentString())
if as.Registration.EphemeralEvents {
if txn.EphemeralEvents != nil {
as.handleEvents(txn.EphemeralEvents, event.EphemeralEventType)
} else if txn.MSC2409EphemeralEvents != nil {
as.handleEvents(txn.MSC2409EphemeralEvents, event.EphemeralEventType)
}
if txn.ToDeviceEvents != nil {
as.handleEvents(txn.ToDeviceEvents, event.ToDeviceEventType)
} else if txn.MSC2409ToDeviceEvents != nil {
as.handleEvents(txn.MSC2409ToDeviceEvents, event.ToDeviceEventType)
}
}
as.handleEvents(txn.Events, event.UnknownEventType)
if txn.DeviceLists != nil {
as.handleDeviceLists(txn.DeviceLists)
} else if txn.MSC3202DeviceLists != nil {
as.handleDeviceLists(txn.MSC3202DeviceLists)
}
if txn.DeviceOTKCount != nil {
as.handleOTKCounts(txn.DeviceOTKCount)
} else if txn.MSC3202DeviceOTKCount != nil {
as.handleOTKCounts(txn.MSC3202DeviceOTKCount)
}
as.txnIDC.MarkProcessed(id)
}
func (as *AppService) handleOTKCounts(otks OTKCountMap) {
for userID, devices := range otks {
for deviceID, otkCounts := range devices {
otkCounts.UserID = userID
otkCounts.DeviceID = deviceID
select {
case as.OTKCounts <- &otkCounts:
default:
as.Log.Warnfln("Dropped OTK count update for %s because channel is full", userID)
}
}
}
}
func (as *AppService) handleDeviceLists(dl *mautrix.DeviceLists) {
select {
case as.DeviceLists <- dl:
default:
as.Log.Warnln("Dropped device list update because channel is full")
}
}
func (as *AppService) handleEvents(evts []*event.Event, defaultTypeClass event.TypeClass) {
for _, evt := range evts {
evt.Mautrix.ReceivedAt = time.Now()
if defaultTypeClass != event.UnknownEventType {
evt.Type.Class = defaultTypeClass
} else if evt.StateKey != nil {
evt.Type.Class = event.StateEventType
} else {
evt.Type.Class = event.MessageEventType
}
err := evt.Content.ParseRaw(evt.Type)
if errors.Is(err, event.ErrUnsupportedContentType) {
as.Log.Debugfln("Not parsing content of %s: %v", evt.ID, err)
} else if err != nil {
as.Log.Debugfln("Failed to parse content of %s (type %s): %v", evt.ID, evt.Type.Type, err)
}
if evt.Type.IsState() {
// TODO remove this check after making sure the log doesn't happen
historical, ok := evt.Content.Raw["org.matrix.msc2716.historical"].(bool)
if ok && historical {
as.Log.Warnfln("Received historical state event %s (%s/%s)", evt.ID, evt.Type.Type, evt.GetStateKey())
} else {
as.UpdateState(evt)
}
}
if evt.Type.Class == event.ToDeviceEventType {
as.ToDeviceEvents <- evt
} else {
as.Events <- evt
}
}
}
// GetRoom handles a /rooms GET call from the homeserver.
func (as *AppService) GetRoom(w http.ResponseWriter, r *http.Request) {
if !as.CheckServerToken(w, r) {
return
}
vars := mux.Vars(r)
roomAlias := vars["roomAlias"]
ok := as.QueryHandler.QueryAlias(roomAlias)
if ok {
WriteBlankOK(w)
} else {
Error{
ErrorCode: ErrUnknown,
HTTPStatus: http.StatusNotFound,
}.Write(w)
}
}
// GetUser handles a /users GET call from the homeserver.
func (as *AppService) GetUser(w http.ResponseWriter, r *http.Request) {
if !as.CheckServerToken(w, r) {
return
}
vars := mux.Vars(r)
userID := id.UserID(vars["userID"])
ok := as.QueryHandler.QueryUser(userID)
if ok {
WriteBlankOK(w)
} else {
Error{
ErrorCode: ErrUnknown,
HTTPStatus: http.StatusNotFound,
}.Write(w)
}
}
func (as *AppService) GetLive(w http.ResponseWriter, r *http.Request) {
w.Header().Add("Content-Type", "application/json")
if as.Live {
w.WriteHeader(http.StatusOK)
} else {
w.WriteHeader(http.StatusInternalServerError)
}
w.Write([]byte("{}"))
}
func (as *AppService) GetReady(w http.ResponseWriter, r *http.Request) {
w.Header().Add("Content-Type", "application/json")
if as.Ready {
w.WriteHeader(http.StatusOK)
} else {
w.WriteHeader(http.StatusInternalServerError)
}
w.Write([]byte("{}"))
}

View File

@@ -1,537 +0,0 @@
// Copyright (c) 2020 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
package appservice
import (
"encoding/json"
"errors"
"fmt"
"strings"
"maunium.net/go/mautrix"
"maunium.net/go/mautrix/event"
"maunium.net/go/mautrix/id"
)
type IntentAPI struct {
*mautrix.Client
bot *mautrix.Client
as *AppService
Localpart string
UserID id.UserID
IsCustomPuppet bool
}
func (as *AppService) NewIntentAPI(localpart string) *IntentAPI {
userID := id.NewUserID(localpart, as.HomeserverDomain)
bot := as.BotClient()
if userID == bot.UserID {
bot = nil
}
return &IntentAPI{
Client: as.Client(userID),
bot: bot,
as: as,
Localpart: localpart,
UserID: userID,
IsCustomPuppet: false,
}
}
func (intent *IntentAPI) Register() error {
_, _, err := intent.Client.Register(&mautrix.ReqRegister{
Username: intent.Localpart,
Type: mautrix.AuthTypeAppservice,
InhibitLogin: true,
})
return err
}
func (intent *IntentAPI) EnsureRegistered() error {
if intent.IsCustomPuppet || intent.as.StateStore.IsRegistered(intent.UserID) {
return nil
}
err := intent.Register()
if err != nil && !errors.Is(err, mautrix.MUserInUse) {
return fmt.Errorf("failed to ensure registered: %w", err)
}
intent.as.StateStore.MarkRegistered(intent.UserID)
return nil
}
type EnsureJoinedParams struct {
IgnoreCache bool
BotOverride *mautrix.Client
}
func (intent *IntentAPI) EnsureJoined(roomID id.RoomID, extra ...EnsureJoinedParams) error {
var params EnsureJoinedParams
if len(extra) > 1 {
panic("invalid number of extra parameters")
} else if len(extra) == 1 {
params = extra[0]
}
if intent.as.StateStore.IsInRoom(roomID, intent.UserID) && !params.IgnoreCache {
return nil
}
if err := intent.EnsureRegistered(); err != nil {
return fmt.Errorf("failed to ensure joined: %w", err)
}
resp, err := intent.JoinRoomByID(roomID)
if err != nil {
bot := intent.bot
if params.BotOverride != nil {
bot = params.BotOverride
}
if !errors.Is(err, mautrix.MForbidden) || bot == nil {
return fmt.Errorf("failed to ensure joined: %w", err)
}
_, inviteErr := bot.InviteUser(roomID, &mautrix.ReqInviteUser{
UserID: intent.UserID,
})
if inviteErr != nil {
return fmt.Errorf("failed to invite in ensure joined: %w", inviteErr)
}
resp, err = intent.JoinRoomByID(roomID)
if err != nil {
return fmt.Errorf("failed to ensure joined after invite: %w", err)
}
}
intent.as.StateStore.SetMembership(resp.RoomID, intent.UserID, event.MembershipJoin)
return nil
}
func (intent *IntentAPI) AddDoublePuppetValue(into interface{}) interface{} {
if !intent.IsCustomPuppet || intent.as.DoublePuppetValue == "" {
return into
}
switch val := into.(type) {
case *map[string]interface{}:
if *val == nil {
valNonPtr := make(map[string]interface{})
*val = valNonPtr
}
(*val)[DoublePuppetKey] = intent.as.DoublePuppetValue
return val
case map[string]interface{}:
val[DoublePuppetKey] = intent.as.DoublePuppetValue
return val
case *event.Content:
if val.Raw == nil {
val.Raw = make(map[string]interface{})
}
val.Raw[DoublePuppetKey] = intent.as.DoublePuppetValue
return val
case event.Content:
if val.Raw == nil {
val.Raw = make(map[string]interface{})
}
val.Raw[DoublePuppetKey] = intent.as.DoublePuppetValue
return val
default:
return &event.Content{
Raw: map[string]interface{}{
DoublePuppetKey: intent.as.DoublePuppetValue,
},
Parsed: val,
}
}
}
func (intent *IntentAPI) SendMessageEvent(roomID id.RoomID, eventType event.Type, contentJSON interface{}) (*mautrix.RespSendEvent, error) {
if err := intent.EnsureJoined(roomID); err != nil {
return nil, err
}
contentJSON = intent.AddDoublePuppetValue(contentJSON)
return intent.Client.SendMessageEvent(roomID, eventType, contentJSON)
}
func (intent *IntentAPI) SendMassagedMessageEvent(roomID id.RoomID, eventType event.Type, contentJSON interface{}, ts int64) (*mautrix.RespSendEvent, error) {
if err := intent.EnsureJoined(roomID); err != nil {
return nil, err
}
contentJSON = intent.AddDoublePuppetValue(contentJSON)
return intent.Client.SendMessageEvent(roomID, eventType, contentJSON, mautrix.ReqSendEvent{Timestamp: ts})
}
func (intent *IntentAPI) updateStoreWithOutgoingEvent(roomID id.RoomID, eventType event.Type, stateKey string, contentJSON interface{}, eventID id.EventID) {
fakeEvt := &event.Event{
StateKey: &stateKey,
Sender: intent.UserID,
Type: eventType,
ID: eventID,
RoomID: roomID,
Content: event.Content{},
}
var err error
fakeEvt.Content.VeryRaw, err = json.Marshal(contentJSON)
if err != nil {
intent.Logger.Debugfln("Failed to marshal state event content to update state store: %v", err)
return
}
err = json.Unmarshal(fakeEvt.Content.VeryRaw, &fakeEvt.Content.Raw)
if err != nil {
intent.Logger.Debugfln("Failed to unmarshal state event content to update state store: %v", err)
return
}
err = fakeEvt.Content.ParseRaw(fakeEvt.Type)
if err != nil {
intent.Logger.Debugfln("Failed to parse state event content to update state store: %v", err)
return
}
intent.as.UpdateState(fakeEvt)
}
func (intent *IntentAPI) SendStateEvent(roomID id.RoomID, eventType event.Type, stateKey string, contentJSON interface{}) (*mautrix.RespSendEvent, error) {
if eventType != event.StateMember || stateKey != string(intent.UserID) {
if err := intent.EnsureJoined(roomID); err != nil {
return nil, err
}
}
contentJSON = intent.AddDoublePuppetValue(contentJSON)
resp, err := intent.Client.SendStateEvent(roomID, eventType, stateKey, contentJSON)
if err == nil && resp != nil {
intent.updateStoreWithOutgoingEvent(roomID, eventType, stateKey, contentJSON, resp.EventID)
}
return resp, err
}
func (intent *IntentAPI) SendMassagedStateEvent(roomID id.RoomID, eventType event.Type, stateKey string, contentJSON interface{}, ts int64) (*mautrix.RespSendEvent, error) {
if err := intent.EnsureJoined(roomID); err != nil {
return nil, err
}
contentJSON = intent.AddDoublePuppetValue(contentJSON)
resp, err := intent.Client.SendMassagedStateEvent(roomID, eventType, stateKey, contentJSON, ts)
if err == nil && resp != nil {
intent.updateStoreWithOutgoingEvent(roomID, eventType, stateKey, contentJSON, resp.EventID)
}
return resp, err
}
func (intent *IntentAPI) StateEvent(roomID id.RoomID, eventType event.Type, stateKey string, outContent interface{}) error {
if err := intent.EnsureJoined(roomID); err != nil {
return err
}
err := intent.Client.StateEvent(roomID, eventType, stateKey, outContent)
if err == nil {
intent.updateStoreWithOutgoingEvent(roomID, eventType, stateKey, outContent, "")
}
return err
}
func (intent *IntentAPI) State(roomID id.RoomID) (mautrix.RoomStateMap, error) {
if err := intent.EnsureJoined(roomID); err != nil {
return nil, err
}
state, err := intent.Client.State(roomID)
if err == nil {
for _, events := range state {
for _, evt := range events {
intent.as.UpdateState(evt)
}
}
}
return state, err
}
func (intent *IntentAPI) SendCustomMembershipEvent(roomID id.RoomID, target id.UserID, membership event.Membership, reason string, extraContent ...map[string]interface{}) (*mautrix.RespSendEvent, error) {
content := &event.MemberEventContent{
Membership: membership,
Reason: reason,
}
memberContent, ok := intent.as.StateStore.TryGetMember(roomID, target)
if !ok {
if intent.as.GetProfile != nil {
memberContent = intent.as.GetProfile(target, roomID)
ok = memberContent != nil
}
if !ok {
profile, err := intent.GetProfile(target)
if err != nil {
intent.Logger.Debugfln("Failed to get profile for %s to fill new %s membership event: %v", target, membership, err)
} else {
content.Displayname = profile.DisplayName
content.AvatarURL = profile.AvatarURL.CUString()
}
}
}
if ok && memberContent != nil {
content.Displayname = memberContent.Displayname
content.AvatarURL = memberContent.AvatarURL
}
var extra map[string]interface{}
if len(extraContent) > 0 {
extra = extraContent[0]
}
return intent.SendStateEvent(roomID, event.StateMember, target.String(), &event.Content{
Parsed: content,
Raw: extra,
})
}
func (intent *IntentAPI) JoinRoomByID(roomID id.RoomID, extraContent ...map[string]interface{}) (resp *mautrix.RespJoinRoom, err error) {
if intent.IsCustomPuppet || len(extraContent) > 0 {
_, err = intent.SendCustomMembershipEvent(roomID, intent.UserID, event.MembershipJoin, "", extraContent...)
return &mautrix.RespJoinRoom{}, err
}
resp, err = intent.Client.JoinRoomByID(roomID)
if err == nil {
intent.as.StateStore.SetMembership(roomID, intent.UserID, event.MembershipJoin)
}
return
}
func (intent *IntentAPI) LeaveRoom(roomID id.RoomID, extra ...interface{}) (resp *mautrix.RespLeaveRoom, err error) {
var extraContent map[string]interface{}
leaveReq := &mautrix.ReqLeave{}
for _, item := range extra {
switch val := item.(type) {
case map[string]interface{}:
extraContent = val
case *mautrix.ReqLeave:
leaveReq = val
}
}
if intent.IsCustomPuppet || extraContent != nil {
_, err = intent.SendCustomMembershipEvent(roomID, intent.UserID, event.MembershipLeave, leaveReq.Reason, extraContent)
return &mautrix.RespLeaveRoom{}, err
}
resp, err = intent.Client.LeaveRoom(roomID, leaveReq)
if err == nil {
intent.as.StateStore.SetMembership(roomID, intent.UserID, event.MembershipLeave)
}
return
}
func (intent *IntentAPI) InviteUser(roomID id.RoomID, req *mautrix.ReqInviteUser, extraContent ...map[string]interface{}) (resp *mautrix.RespInviteUser, err error) {
if intent.IsCustomPuppet || len(extraContent) > 0 {
_, err = intent.SendCustomMembershipEvent(roomID, req.UserID, event.MembershipInvite, req.Reason, extraContent...)
return &mautrix.RespInviteUser{}, err
}
resp, err = intent.Client.InviteUser(roomID, req)
if err == nil {
intent.as.StateStore.SetMembership(roomID, req.UserID, event.MembershipInvite)
}
return
}
func (intent *IntentAPI) KickUser(roomID id.RoomID, req *mautrix.ReqKickUser, extraContent ...map[string]interface{}) (resp *mautrix.RespKickUser, err error) {
if intent.IsCustomPuppet || len(extraContent) > 0 {
_, err = intent.SendCustomMembershipEvent(roomID, req.UserID, event.MembershipLeave, req.Reason, extraContent...)
return &mautrix.RespKickUser{}, err
}
resp, err = intent.Client.KickUser(roomID, req)
if err == nil {
intent.as.StateStore.SetMembership(roomID, req.UserID, event.MembershipLeave)
}
return
}
func (intent *IntentAPI) BanUser(roomID id.RoomID, req *mautrix.ReqBanUser, extraContent ...map[string]interface{}) (resp *mautrix.RespBanUser, err error) {
if intent.IsCustomPuppet || len(extraContent) > 0 {
_, err = intent.SendCustomMembershipEvent(roomID, req.UserID, event.MembershipBan, req.Reason, extraContent...)
return &mautrix.RespBanUser{}, err
}
resp, err = intent.Client.BanUser(roomID, req)
if err == nil {
intent.as.StateStore.SetMembership(roomID, req.UserID, event.MembershipBan)
}
return
}
func (intent *IntentAPI) UnbanUser(roomID id.RoomID, req *mautrix.ReqUnbanUser, extraContent ...map[string]interface{}) (resp *mautrix.RespUnbanUser, err error) {
if intent.IsCustomPuppet || len(extraContent) > 0 {
_, err = intent.SendCustomMembershipEvent(roomID, req.UserID, event.MembershipLeave, req.Reason, extraContent...)
return &mautrix.RespUnbanUser{}, err
}
resp, err = intent.Client.UnbanUser(roomID, req)
if err == nil {
intent.as.StateStore.SetMembership(roomID, req.UserID, event.MembershipLeave)
}
return
}
func (intent *IntentAPI) Member(roomID id.RoomID, userID id.UserID) *event.MemberEventContent {
member, ok := intent.as.StateStore.TryGetMember(roomID, userID)
if !ok {
_ = intent.StateEvent(roomID, event.StateMember, string(userID), &member)
intent.as.StateStore.SetMember(roomID, userID, member)
}
return member
}
func (intent *IntentAPI) PowerLevels(roomID id.RoomID) (pl *event.PowerLevelsEventContent, err error) {
pl = intent.as.StateStore.GetPowerLevels(roomID)
if pl == nil {
pl = &event.PowerLevelsEventContent{}
err = intent.StateEvent(roomID, event.StatePowerLevels, "", pl)
if err == nil {
intent.as.StateStore.SetPowerLevels(roomID, pl)
}
}
return
}
func (intent *IntentAPI) SetPowerLevels(roomID id.RoomID, levels *event.PowerLevelsEventContent) (resp *mautrix.RespSendEvent, err error) {
resp, err = intent.SendStateEvent(roomID, event.StatePowerLevels, "", &levels)
if err == nil {
intent.as.StateStore.SetPowerLevels(roomID, levels)
}
return
}
func (intent *IntentAPI) SetPowerLevel(roomID id.RoomID, userID id.UserID, level int) (*mautrix.RespSendEvent, error) {
pl, err := intent.PowerLevels(roomID)
if err != nil {
return nil, err
}
if pl.GetUserLevel(userID) != level {
pl.SetUserLevel(userID, level)
return intent.SendStateEvent(roomID, event.StatePowerLevels, "", &pl)
}
return nil, nil
}
func (intent *IntentAPI) SendText(roomID id.RoomID, text string) (*mautrix.RespSendEvent, error) {
if err := intent.EnsureJoined(roomID); err != nil {
return nil, err
}
return intent.Client.SendText(roomID, text)
}
// Deprecated: This does not allow setting image metadata, you should prefer SendMessageEvent with a properly filled &event.MessageEventContent
func (intent *IntentAPI) SendImage(roomID id.RoomID, body string, url id.ContentURI) (*mautrix.RespSendEvent, error) {
if err := intent.EnsureJoined(roomID); err != nil {
return nil, err
}
return intent.Client.SendImage(roomID, body, url)
}
// Deprecated: This does not allow setting video metadata, you should prefer SendMessageEvent with a properly filled &event.MessageEventContent
func (intent *IntentAPI) SendVideo(roomID id.RoomID, body string, url id.ContentURI) (*mautrix.RespSendEvent, error) {
if err := intent.EnsureJoined(roomID); err != nil {
return nil, err
}
return intent.Client.SendVideo(roomID, body, url)
}
func (intent *IntentAPI) SendNotice(roomID id.RoomID, text string) (*mautrix.RespSendEvent, error) {
if err := intent.EnsureJoined(roomID); err != nil {
return nil, err
}
return intent.Client.SendNotice(roomID, text)
}
func (intent *IntentAPI) RedactEvent(roomID id.RoomID, eventID id.EventID, extra ...mautrix.ReqRedact) (*mautrix.RespSendEvent, error) {
if err := intent.EnsureJoined(roomID); err != nil {
return nil, err
}
var req mautrix.ReqRedact
if len(extra) > 0 {
req = extra[0]
}
intent.AddDoublePuppetValue(&req.Extra)
return intent.Client.RedactEvent(roomID, eventID, req)
}
func (intent *IntentAPI) SetRoomName(roomID id.RoomID, roomName string) (*mautrix.RespSendEvent, error) {
return intent.SendStateEvent(roomID, event.StateRoomName, "", map[string]interface{}{
"name": roomName,
})
}
func (intent *IntentAPI) SetRoomAvatar(roomID id.RoomID, avatarURL id.ContentURI) (*mautrix.RespSendEvent, error) {
return intent.SendStateEvent(roomID, event.StateRoomAvatar, "", map[string]interface{}{
"url": avatarURL.String(),
})
}
func (intent *IntentAPI) SetRoomTopic(roomID id.RoomID, topic string) (*mautrix.RespSendEvent, error) {
return intent.SendStateEvent(roomID, event.StateTopic, "", map[string]interface{}{
"topic": topic,
})
}
func (intent *IntentAPI) SetDisplayName(displayName string) error {
if err := intent.EnsureRegistered(); err != nil {
return err
}
resp, err := intent.Client.GetOwnDisplayName()
if err != nil {
return fmt.Errorf("failed to check current displayname: %w", err)
} else if resp.DisplayName == displayName {
// No need to update
return nil
}
return intent.Client.SetDisplayName(displayName)
}
func (intent *IntentAPI) SetAvatarURL(avatarURL id.ContentURI) error {
if err := intent.EnsureRegistered(); err != nil {
return err
}
resp, err := intent.Client.GetOwnAvatarURL()
if err != nil {
return fmt.Errorf("failed to check current avatar URL: %w", err)
} else if resp.FileID == avatarURL.FileID && resp.Homeserver == avatarURL.Homeserver {
// No need to update
return nil
}
return intent.Client.SetAvatarURL(avatarURL)
}
func (intent *IntentAPI) Whoami() (*mautrix.RespWhoami, error) {
if err := intent.EnsureRegistered(); err != nil {
return nil, err
}
return intent.Client.Whoami()
}
func (intent *IntentAPI) JoinedMembers(roomID id.RoomID) (resp *mautrix.RespJoinedMembers, err error) {
resp, err = intent.Client.JoinedMembers(roomID)
if err != nil {
return
}
for userID, member := range resp.Joined {
intent.as.StateStore.SetMember(roomID, userID, &event.MemberEventContent{
Membership: event.MembershipJoin,
AvatarURL: id.ContentURIString(member.AvatarURL),
Displayname: member.DisplayName,
})
}
return
}
func (intent *IntentAPI) Members(roomID id.RoomID, req ...mautrix.ReqMembers) (resp *mautrix.RespMembers, err error) {
resp, err = intent.Client.Members(roomID, req...)
if err != nil {
return
}
for _, evt := range resp.Chunk {
intent.as.UpdateState(evt)
}
return
}
func (intent *IntentAPI) EnsureInvited(roomID id.RoomID, userID id.UserID) error {
if !intent.as.StateStore.IsInvited(roomID, userID) {
_, err := intent.InviteUser(roomID, &mautrix.ReqInviteUser{
UserID: userID,
})
if httpErr, ok := err.(mautrix.HTTPError); ok && httpErr.RespError != nil && strings.Contains(httpErr.RespError.Err, "is already in the room") {
return nil
}
return err
}
return nil
}

View File

@@ -1,138 +0,0 @@
// Copyright (c) 2022 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
package appservice
import (
"encoding/json"
"fmt"
"net/http"
"strings"
"github.com/rs/zerolog"
"maunium.net/go/mautrix"
"maunium.net/go/mautrix/event"
"maunium.net/go/mautrix/id"
)
type OTKCountMap = map[id.UserID]map[id.DeviceID]mautrix.OTKCount
type FallbackKeyMap = map[id.UserID]map[id.DeviceID][]id.KeyAlgorithm
// Transaction contains a list of events.
type Transaction struct {
Events []*event.Event `json:"events"`
EphemeralEvents []*event.Event `json:"ephemeral,omitempty"`
ToDeviceEvents []*event.Event `json:"to_device,omitempty"`
DeviceLists *mautrix.DeviceLists `json:"device_lists,omitempty"`
DeviceOTKCount OTKCountMap `json:"device_one_time_keys_count,omitempty"`
FallbackKeys FallbackKeyMap `json:"device_unused_fallback_key_types,omitempty"`
MSC2409EphemeralEvents []*event.Event `json:"de.sorunome.msc2409.ephemeral,omitempty"`
MSC2409ToDeviceEvents []*event.Event `json:"de.sorunome.msc2409.to_device,omitempty"`
MSC3202DeviceLists *mautrix.DeviceLists `json:"org.matrix.msc3202.device_lists,omitempty"`
MSC3202DeviceOTKCount OTKCountMap `json:"org.matrix.msc3202.device_one_time_keys_count,omitempty"`
MSC3202FallbackKeys FallbackKeyMap `json:"org.matrix.msc3202.device_unused_fallback_key_types,omitempty"`
}
func (txn *Transaction) MarshalZerologObject(ctx *zerolog.Event) {
ctx.Int("pdu", len(txn.Events))
ctx.Int("edu", len(txn.EphemeralEvents))
ctx.Int("to_device", len(txn.ToDeviceEvents))
if len(txn.DeviceOTKCount) > 0 {
ctx.Int("otk_count_users", len(txn.DeviceOTKCount))
}
if txn.DeviceLists != nil {
ctx.Int("device_changes", len(txn.DeviceLists.Changed))
}
if txn.FallbackKeys != nil {
ctx.Int("fallback_key_users", len(txn.FallbackKeys))
}
}
func (txn *Transaction) ContentString() string {
var parts []string
if len(txn.Events) > 0 {
parts = append(parts, fmt.Sprintf("%d PDUs", len(txn.Events)))
}
if len(txn.EphemeralEvents) > 0 {
parts = append(parts, fmt.Sprintf("%d EDUs", len(txn.EphemeralEvents)))
} else if len(txn.MSC2409EphemeralEvents) > 0 {
parts = append(parts, fmt.Sprintf("%d EDUs (unstable)", len(txn.MSC2409EphemeralEvents)))
}
if len(txn.ToDeviceEvents) > 0 {
parts = append(parts, fmt.Sprintf("%d to-device events", len(txn.ToDeviceEvents)))
} else if len(txn.MSC2409ToDeviceEvents) > 0 {
parts = append(parts, fmt.Sprintf("%d to-device events (unstable)", len(txn.MSC2409ToDeviceEvents)))
}
if len(txn.DeviceOTKCount) > 0 {
parts = append(parts, fmt.Sprintf("OTK counts for %d users", len(txn.DeviceOTKCount)))
} else if len(txn.MSC3202DeviceOTKCount) > 0 {
parts = append(parts, fmt.Sprintf("OTK counts for %d users (unstable)", len(txn.MSC3202DeviceOTKCount)))
}
if txn.DeviceLists != nil {
parts = append(parts, fmt.Sprintf("%d device list changes", len(txn.DeviceLists.Changed)))
} else if txn.MSC3202DeviceLists != nil {
parts = append(parts, fmt.Sprintf("%d device list changes (unstable)", len(txn.MSC3202DeviceLists.Changed)))
}
if txn.FallbackKeys != nil {
parts = append(parts, fmt.Sprintf("unused fallback key counts for %d users", len(txn.FallbackKeys)))
} else if txn.MSC3202FallbackKeys != nil {
parts = append(parts, fmt.Sprintf("unused fallback key counts for %d users (unstable)", len(txn.MSC3202FallbackKeys)))
}
return strings.Join(parts, ", ")
}
// EventListener is a function that receives events.
type EventListener func(evt *event.Event)
// WriteBlankOK writes a blank OK message as a reply to a HTTP request.
func WriteBlankOK(w http.ResponseWriter) {
w.Header().Add("Content-Type", "application/json")
w.WriteHeader(http.StatusOK)
_, _ = w.Write([]byte("{}"))
}
// Respond responds to a HTTP request with a JSON object.
func Respond(w http.ResponseWriter, data interface{}) error {
w.Header().Add("Content-Type", "application/json")
dataStr, err := json.Marshal(data)
if err != nil {
return err
}
_, err = w.Write(dataStr)
return err
}
// Error represents a Matrix protocol error.
type Error struct {
HTTPStatus int `json:"-"`
ErrorCode ErrorCode `json:"errcode"`
Message string `json:"error"`
}
func (err Error) Write(w http.ResponseWriter) {
w.Header().Add("Content-Type", "application/json")
w.WriteHeader(err.HTTPStatus)
_ = Respond(w, &err)
}
// ErrorCode is the machine-readable code in an Error.
type ErrorCode string
// Native ErrorCodes
const (
ErrUnknownToken ErrorCode = "M_UNKNOWN_TOKEN"
ErrBadJSON ErrorCode = "M_BAD_JSON"
ErrNotJSON ErrorCode = "M_NOT_JSON"
ErrUnknown ErrorCode = "M_UNKNOWN"
)
// Custom ErrorCodes
const (
ErrNoTransactionID ErrorCode = "NET.MAUNIUM.NO_TRANSACTION_ID"
)

View File

@@ -1,100 +0,0 @@
// Copyright (c) 2022 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
package appservice
import (
"io/ioutil"
"regexp"
"gopkg.in/yaml.v3"
"maunium.net/go/mautrix/util"
)
// Registration contains the data in a Matrix appservice registration.
// See https://spec.matrix.org/v1.2/application-service-api/#registration
type Registration struct {
ID string `yaml:"id" json:"id"`
URL string `yaml:"url" json:"url"`
AppToken string `yaml:"as_token" json:"as_token"`
ServerToken string `yaml:"hs_token" json:"hs_token"`
SenderLocalpart string `yaml:"sender_localpart" json:"sender_localpart"`
RateLimited *bool `yaml:"rate_limited,omitempty" json:"rate_limited,omitempty"`
Namespaces Namespaces `yaml:"namespaces" json:"namespaces"`
Protocols []string `yaml:"protocols,omitempty" json:"protocols,omitempty"`
SoruEphemeralEvents bool `yaml:"de.sorunome.msc2409.push_ephemeral,omitempty" json:"de.sorunome.msc2409.push_ephemeral,omitempty"`
EphemeralEvents bool `yaml:"push_ephemeral,omitempty" json:"push_ephemeral,omitempty"`
}
// CreateRegistration creates a Registration with random appservice and homeserver tokens.
func CreateRegistration() *Registration {
return &Registration{
AppToken: util.RandomString(64),
ServerToken: util.RandomString(64),
}
}
// LoadRegistration loads a YAML file and turns it into a Registration.
func LoadRegistration(path string) (*Registration, error) {
data, err := ioutil.ReadFile(path)
if err != nil {
return nil, err
}
reg := &Registration{}
err = yaml.Unmarshal(data, reg)
if err != nil {
return nil, err
}
return reg, nil
}
// Save saves this Registration into a file at the given path.
func (reg *Registration) Save(path string) error {
data, err := yaml.Marshal(reg)
if err != nil {
return err
}
return ioutil.WriteFile(path, data, 0600)
}
// YAML returns the registration in YAML format.
func (reg *Registration) YAML() (string, error) {
data, err := yaml.Marshal(reg)
if err != nil {
return "", err
}
return string(data), nil
}
// Namespaces contains the three areas that appservices can reserve parts of.
type Namespaces struct {
UserIDs NamespaceList `yaml:"users,omitempty" json:"users,omitempty"`
RoomAliases NamespaceList `yaml:"aliases,omitempty" json:"aliases,omitempty"`
RoomIDs NamespaceList `yaml:"rooms,omitempty" json:"rooms,omitempty"`
}
// Namespace is a reserved namespace in any area.
type Namespace struct {
Regex string `yaml:"regex" json:"regex"`
Exclusive bool `yaml:"exclusive" json:"exclusive"`
}
type NamespaceList []Namespace
func (nsl *NamespaceList) Register(regex *regexp.Regexp, exclusive bool) {
ns := Namespace{
Regex: regex.String(),
Exclusive: exclusive,
}
if nsl == nil {
*nsl = []Namespace{ns}
} else {
*nsl = append(*nsl, ns)
}
}

View File

@@ -1,186 +0,0 @@
// Copyright (c) 2020 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
package appservice
import (
"sync"
"maunium.net/go/mautrix/event"
"maunium.net/go/mautrix/id"
)
type StateStore interface {
IsRegistered(userID id.UserID) bool
MarkRegistered(userID id.UserID)
IsInRoom(roomID id.RoomID, userID id.UserID) bool
IsInvited(roomID id.RoomID, userID id.UserID) bool
IsMembership(roomID id.RoomID, userID id.UserID, allowedMemberships ...event.Membership) bool
GetMember(roomID id.RoomID, userID id.UserID) *event.MemberEventContent
TryGetMember(roomID id.RoomID, userID id.UserID) (*event.MemberEventContent, bool)
SetMembership(roomID id.RoomID, userID id.UserID, membership event.Membership)
SetMember(roomID id.RoomID, userID id.UserID, member *event.MemberEventContent)
SetPowerLevels(roomID id.RoomID, levels *event.PowerLevelsEventContent)
GetPowerLevels(roomID id.RoomID) *event.PowerLevelsEventContent
GetPowerLevel(roomID id.RoomID, userID id.UserID) int
GetPowerLevelRequirement(roomID id.RoomID, eventType event.Type) int
HasPowerLevel(roomID id.RoomID, userID id.UserID, eventType event.Type) bool
}
func (as *AppService) UpdateState(evt *event.Event) {
switch content := evt.Content.Parsed.(type) {
case *event.MemberEventContent:
as.StateStore.SetMember(evt.RoomID, id.UserID(evt.GetStateKey()), content)
case *event.PowerLevelsEventContent:
as.StateStore.SetPowerLevels(evt.RoomID, content)
}
}
type BasicStateStore struct {
Registrations map[id.UserID]bool `json:"registrations"`
Members map[id.RoomID]map[id.UserID]*event.MemberEventContent `json:"memberships"`
PowerLevels map[id.RoomID]*event.PowerLevelsEventContent `json:"power_levels"`
registrationsLock sync.RWMutex
membersLock sync.RWMutex
powerLevelsLock sync.RWMutex
}
func NewBasicStateStore() StateStore {
return &BasicStateStore{
Registrations: make(map[id.UserID]bool),
Members: make(map[id.RoomID]map[id.UserID]*event.MemberEventContent),
PowerLevels: make(map[id.RoomID]*event.PowerLevelsEventContent),
}
}
func (store *BasicStateStore) IsRegistered(userID id.UserID) bool {
store.registrationsLock.RLock()
defer store.registrationsLock.RUnlock()
registered, ok := store.Registrations[userID]
return ok && registered
}
func (store *BasicStateStore) MarkRegistered(userID id.UserID) {
store.registrationsLock.Lock()
defer store.registrationsLock.Unlock()
store.Registrations[userID] = true
}
func (store *BasicStateStore) GetRoomMembers(roomID id.RoomID) map[id.UserID]*event.MemberEventContent {
store.membersLock.RLock()
members, ok := store.Members[roomID]
store.membersLock.RUnlock()
if !ok {
members = make(map[id.UserID]*event.MemberEventContent)
store.membersLock.Lock()
store.Members[roomID] = members
store.membersLock.Unlock()
}
return members
}
func (store *BasicStateStore) GetMembership(roomID id.RoomID, userID id.UserID) event.Membership {
return store.GetMember(roomID, userID).Membership
}
func (store *BasicStateStore) GetMember(roomID id.RoomID, userID id.UserID) *event.MemberEventContent {
member, ok := store.TryGetMember(roomID, userID)
if !ok {
member = &event.MemberEventContent{Membership: event.MembershipLeave}
}
return member
}
func (store *BasicStateStore) TryGetMember(roomID id.RoomID, userID id.UserID) (member *event.MemberEventContent, ok bool) {
store.membersLock.RLock()
defer store.membersLock.RUnlock()
members, membersOk := store.Members[roomID]
if !membersOk {
return
}
member, ok = members[userID]
return
}
func (store *BasicStateStore) IsInRoom(roomID id.RoomID, userID id.UserID) bool {
return store.IsMembership(roomID, userID, "join")
}
func (store *BasicStateStore) IsInvited(roomID id.RoomID, userID id.UserID) bool {
return store.IsMembership(roomID, userID, "join", "invite")
}
func (store *BasicStateStore) IsMembership(roomID id.RoomID, userID id.UserID, allowedMemberships ...event.Membership) bool {
membership := store.GetMembership(roomID, userID)
for _, allowedMembership := range allowedMemberships {
if allowedMembership == membership {
return true
}
}
return false
}
func (store *BasicStateStore) SetMembership(roomID id.RoomID, userID id.UserID, membership event.Membership) {
store.membersLock.Lock()
members, ok := store.Members[roomID]
if !ok {
members = map[id.UserID]*event.MemberEventContent{
userID: {Membership: membership},
}
} else {
member, ok := members[userID]
if !ok {
members[userID] = &event.MemberEventContent{Membership: membership}
} else {
member.Membership = membership
members[userID] = member
}
}
store.Members[roomID] = members
store.membersLock.Unlock()
}
func (store *BasicStateStore) SetMember(roomID id.RoomID, userID id.UserID, member *event.MemberEventContent) {
store.membersLock.Lock()
members, ok := store.Members[roomID]
if !ok {
members = map[id.UserID]*event.MemberEventContent{
userID: member,
}
} else {
members[userID] = member
}
store.Members[roomID] = members
store.membersLock.Unlock()
}
func (store *BasicStateStore) SetPowerLevels(roomID id.RoomID, levels *event.PowerLevelsEventContent) {
store.powerLevelsLock.Lock()
store.PowerLevels[roomID] = levels
store.powerLevelsLock.Unlock()
}
func (store *BasicStateStore) GetPowerLevels(roomID id.RoomID) (levels *event.PowerLevelsEventContent) {
store.powerLevelsLock.RLock()
levels = store.PowerLevels[roomID]
store.powerLevelsLock.RUnlock()
return
}
func (store *BasicStateStore) GetPowerLevel(roomID id.RoomID, userID id.UserID) int {
return store.GetPowerLevels(roomID).GetUserLevel(userID)
}
func (store *BasicStateStore) GetPowerLevelRequirement(roomID id.RoomID, eventType event.Type) int {
return store.GetPowerLevels(roomID).GetEventLevel(eventType)
}
func (store *BasicStateStore) HasPowerLevel(roomID id.RoomID, userID id.UserID, eventType event.Type) bool {
return store.GetPowerLevel(roomID, userID) >= store.GetPowerLevelRequirement(roomID, eventType)
}

View File

@@ -1,43 +0,0 @@
// Copyright (c) 2021 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
package appservice
import "sync"
type TransactionIDCache struct {
array []string
arrayPtr int
hash map[string]struct{}
lock sync.RWMutex
}
func NewTransactionIDCache(size int) *TransactionIDCache {
return &TransactionIDCache{
array: make([]string, size),
hash: make(map[string]struct{}),
}
}
func (txnIDC *TransactionIDCache) IsProcessed(txnID string) bool {
txnIDC.lock.RLock()
_, exists := txnIDC.hash[txnID]
txnIDC.lock.RUnlock()
return exists
}
func (txnIDC *TransactionIDCache) MarkProcessed(txnID string) {
txnIDC.lock.Lock()
txnIDC.hash[txnID] = struct{}{}
if txnIDC.array[txnIDC.arrayPtr] != "" {
for i := 0; i < len(txnIDC.array)/8; i++ {
delete(txnIDC.hash, txnIDC.array[txnIDC.arrayPtr+i])
txnIDC.array[txnIDC.arrayPtr+i] = ""
}
}
txnIDC.array[txnIDC.arrayPtr] = txnID
txnIDC.lock.Unlock()
}

View File

@@ -1,392 +0,0 @@
// Copyright (c) 2022 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
package appservice
import (
"context"
"encoding/json"
"errors"
"fmt"
"net/http"
"net/url"
"path/filepath"
"strings"
"sync"
"sync/atomic"
"time"
"github.com/gorilla/websocket"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
type WebsocketRequest struct {
ReqID int `json:"id,omitempty"`
Command string `json:"command"`
Data interface{} `json:"data"`
Deadline time.Duration `json:"-"`
}
type WebsocketCommand struct {
ReqID int `json:"id,omitempty"`
Command string `json:"command"`
Data json.RawMessage `json:"data"`
}
func (wsc *WebsocketCommand) MakeResponse(ok bool, data interface{}) *WebsocketRequest {
if wsc.ReqID == 0 || wsc.Command == "response" || wsc.Command == "error" {
return nil
}
cmd := "response"
if !ok {
cmd = "error"
}
if err, isError := data.(error); isError {
var errorData json.RawMessage
var jsonErr error
unwrappedErr := err
var prefixMessage string
for unwrappedErr != nil {
errorData, jsonErr = json.Marshal(unwrappedErr)
if errorData != nil && len(errorData) > 2 && jsonErr == nil {
prefixMessage = strings.Replace(err.Error(), unwrappedErr.Error(), "", 1)
prefixMessage = strings.TrimRight(prefixMessage, ": ")
break
}
unwrappedErr = errors.Unwrap(unwrappedErr)
}
if errorData != nil {
if !gjson.GetBytes(errorData, "message").Exists() {
errorData, _ = sjson.SetBytes(errorData, "message", err.Error())
} // else: marshaled error contains a message already
} else {
errorData, _ = sjson.SetBytes(nil, "message", err.Error())
}
if len(prefixMessage) > 0 {
errorData, _ = sjson.SetBytes(errorData, "prefix_message", prefixMessage)
}
data = errorData
}
return &WebsocketRequest{
ReqID: wsc.ReqID,
Command: cmd,
Data: data,
}
}
type WebsocketTransaction struct {
Status string `json:"status"`
TxnID string `json:"txn_id"`
Transaction
}
type WebsocketTransactionResponse struct {
TxnID string `json:"txn_id"`
}
type WebsocketMessage struct {
WebsocketTransaction
WebsocketCommand
}
const (
WebsocketCloseConnReplaced = 4001
WebsocketCloseTxnNotAcknowledged = 4002
)
type MeowWebsocketCloseCode string
const (
MeowServerShuttingDown MeowWebsocketCloseCode = "server_shutting_down"
MeowConnectionReplaced MeowWebsocketCloseCode = "conn_replaced"
MeowTxnNotAcknowledged MeowWebsocketCloseCode = "transactions_not_acknowledged"
)
var (
ErrWebsocketManualStop = errors.New("the websocket was disconnected manually")
ErrWebsocketOverridden = errors.New("a new call to StartWebsocket overrode the previous connection")
ErrWebsocketUnknownError = errors.New("an unknown error occurred")
ErrWebsocketNotConnected = errors.New("websocket not connected")
ErrWebsocketClosed = errors.New("websocket closed before response received")
)
func (mwcc MeowWebsocketCloseCode) String() string {
switch mwcc {
case MeowServerShuttingDown:
return "the server is shutting down"
case MeowConnectionReplaced:
return "the connection was replaced by another client"
case MeowTxnNotAcknowledged:
return "transactions were not acknowledged"
default:
return string(mwcc)
}
}
type CloseCommand struct {
Code int `json:"-"`
Command string `json:"command"`
Status MeowWebsocketCloseCode `json:"status"`
}
func (cc CloseCommand) Error() string {
return fmt.Sprintf("websocket: close %d: %s", cc.Code, cc.Status.String())
}
func parseCloseError(err error) error {
closeError := &websocket.CloseError{}
if !errors.As(err, &closeError) {
return err
}
var closeCommand CloseCommand
closeCommand.Code = closeError.Code
closeCommand.Command = "disconnect"
if len(closeError.Text) > 0 {
jsonErr := json.Unmarshal([]byte(closeError.Text), &closeCommand)
if jsonErr != nil {
return err
}
}
if len(closeCommand.Status) == 0 {
if closeCommand.Code == WebsocketCloseConnReplaced {
closeCommand.Status = MeowConnectionReplaced
} else if closeCommand.Code == websocket.CloseServiceRestart {
closeCommand.Status = MeowServerShuttingDown
}
}
return &closeCommand
}
func (as *AppService) HasWebsocket() bool {
return as.ws != nil
}
func (as *AppService) SendWebsocket(cmd *WebsocketRequest) error {
ws := as.ws
if cmd == nil {
return nil
} else if ws == nil {
return ErrWebsocketNotConnected
}
as.wsWriteLock.Lock()
defer as.wsWriteLock.Unlock()
if cmd.Deadline == 0 {
cmd.Deadline = 3 * time.Minute
}
_ = ws.SetWriteDeadline(time.Now().Add(cmd.Deadline))
return ws.WriteJSON(cmd)
}
func (as *AppService) clearWebsocketResponseWaiters() {
as.websocketRequestsLock.Lock()
for _, waiter := range as.websocketRequests {
waiter <- &WebsocketCommand{Command: "__websocket_closed"}
}
as.websocketRequests = make(map[int]chan<- *WebsocketCommand)
as.websocketRequestsLock.Unlock()
}
func (as *AppService) addWebsocketResponseWaiter(reqID int, waiter chan<- *WebsocketCommand) {
as.websocketRequestsLock.Lock()
as.websocketRequests[reqID] = waiter
as.websocketRequestsLock.Unlock()
}
func (as *AppService) removeWebsocketResponseWaiter(reqID int, waiter chan<- *WebsocketCommand) {
as.websocketRequestsLock.Lock()
existingWaiter, ok := as.websocketRequests[reqID]
if ok && existingWaiter == waiter {
delete(as.websocketRequests, reqID)
}
close(waiter)
as.websocketRequestsLock.Unlock()
}
type ErrorResponse struct {
Code string `json:"code"`
Message string `json:"message"`
}
func (er *ErrorResponse) Error() string {
return fmt.Sprintf("%s: %s", er.Code, er.Message)
}
func (as *AppService) RequestWebsocket(ctx context.Context, cmd *WebsocketRequest, response interface{}) error {
cmd.ReqID = int(atomic.AddInt32(&as.websocketRequestID, 1))
respChan := make(chan *WebsocketCommand, 1)
as.addWebsocketResponseWaiter(cmd.ReqID, respChan)
defer as.removeWebsocketResponseWaiter(cmd.ReqID, respChan)
err := as.SendWebsocket(cmd)
if err != nil {
return err
}
select {
case resp := <-respChan:
if resp.Command == "__websocket_closed" {
return ErrWebsocketClosed
} else if resp.Command == "error" {
var respErr ErrorResponse
err = json.Unmarshal(resp.Data, &respErr)
if err != nil {
return fmt.Errorf("failed to parse error JSON: %w", err)
}
return &respErr
} else if response != nil {
err = json.Unmarshal(resp.Data, &response)
if err != nil {
return fmt.Errorf("failed to parse response JSON: %w", err)
}
return nil
} else {
return nil
}
case <-ctx.Done():
return ctx.Err()
}
}
func (as *AppService) unknownCommandHandler(cmd WebsocketCommand) (bool, interface{}) {
as.Log.Warnfln("No handler for websocket command %s (%d)", cmd.Command, cmd.ReqID)
return false, fmt.Errorf("unknown request type")
}
func (as *AppService) SetWebsocketCommandHandler(cmd string, handler WebsocketHandler) {
as.websocketHandlersLock.Lock()
as.websocketHandlers[cmd] = handler
as.websocketHandlersLock.Unlock()
}
func (as *AppService) consumeWebsocket(stopFunc func(error), ws *websocket.Conn) {
defer stopFunc(ErrWebsocketUnknownError)
for {
var msg WebsocketMessage
err := ws.ReadJSON(&msg)
if err != nil {
as.Log.Debugln("Error reading from websocket:", err)
stopFunc(parseCloseError(err))
return
}
if msg.Command == "" || msg.Command == "transaction" {
if msg.TxnID == "" || !as.txnIDC.IsProcessed(msg.TxnID) {
as.handleTransaction(msg.TxnID, &msg.Transaction)
} else {
as.Log.Debugfln("Ignoring duplicate transaction %s (%s)", msg.TxnID, msg.Transaction.ContentString())
}
go func() {
err = as.SendWebsocket(msg.MakeResponse(true, &WebsocketTransactionResponse{TxnID: msg.TxnID}))
if err != nil {
as.Log.Warnfln("Failed to send response to %s %d: %v", msg.Command, msg.ReqID, err)
}
}()
} else if msg.Command == "connect" {
as.Log.Debugln("Websocket connect confirmation received")
} else if msg.Command == "response" || msg.Command == "error" {
as.websocketRequestsLock.RLock()
respChan, ok := as.websocketRequests[msg.ReqID]
if ok {
select {
case respChan <- &msg.WebsocketCommand:
default:
as.Log.Warnfln("Failed to handle response to %d: channel didn't accept response", msg.ReqID)
}
} else {
as.Log.Warnfln("Dropping response to %d: unknown request ID", msg.ReqID)
}
as.websocketRequestsLock.RUnlock()
} else {
as.Log.Debugfln("Received command request %s %d", msg.Command, msg.ReqID)
as.websocketHandlersLock.RLock()
handler, ok := as.websocketHandlers[msg.Command]
as.websocketHandlersLock.RUnlock()
if !ok {
handler = as.unknownCommandHandler
}
go func() {
okResp, data := handler(msg.WebsocketCommand)
err = as.SendWebsocket(msg.MakeResponse(okResp, data))
if err != nil {
as.Log.Warnfln("Failed to send response to %s %d: %v", msg.Command, msg.ReqID, err)
} else if okResp {
as.Log.Debugfln("Sent success response to %s %d", msg.Command, msg.ReqID)
} else {
as.Log.Debugfln("Sent error response to %s %d", msg.Command, msg.ReqID)
}
}()
}
}
}
func (as *AppService) StartWebsocket(baseURL string, onConnect func()) error {
parsed, err := url.Parse(baseURL)
if err != nil {
return fmt.Errorf("failed to parse URL: %w", err)
}
parsed.Path = filepath.Join(parsed.Path, "_matrix/client/unstable/fi.mau.as_sync")
if parsed.Scheme == "http" {
parsed.Scheme = "ws"
} else if parsed.Scheme == "https" {
parsed.Scheme = "wss"
}
ws, resp, err := websocket.DefaultDialer.Dial(parsed.String(), http.Header{
"Authorization": []string{fmt.Sprintf("Bearer %s", as.Registration.AppToken)},
"User-Agent": []string{as.BotClient().UserAgent},
"X-Mautrix-Process-ID": []string{as.ProcessID},
"X-Mautrix-Websocket-Version": []string{"3"},
})
if resp != nil && resp.StatusCode >= 400 {
var errResp Error
err = json.NewDecoder(resp.Body).Decode(&errResp)
if err != nil {
return fmt.Errorf("websocket request returned HTTP %d with non-JSON body", resp.StatusCode)
} else {
return fmt.Errorf("websocket request returned %s (HTTP %d): %s", errResp.ErrorCode, resp.StatusCode, errResp.Message)
}
} else if err != nil {
return fmt.Errorf("failed to open websocket: %w", err)
}
if as.StopWebsocket != nil {
as.StopWebsocket(ErrWebsocketOverridden)
}
closeChan := make(chan error)
closeChanOnce := sync.Once{}
stopFunc := func(err error) {
closeChanOnce.Do(func() {
closeChan <- err
})
}
as.ws = ws
as.StopWebsocket = stopFunc
as.PrepareWebsocket()
as.Log.Debugln("Appservice transaction websocket connected")
go as.consumeWebsocket(stopFunc, ws)
if onConnect != nil {
onConnect()
}
closeErr := <-closeChan
if as.ws == ws {
as.clearWebsocketResponseWaiters()
as.ws = nil
}
_ = ws.SetWriteDeadline(time.Now().Add(3 * time.Second))
err = ws.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseGoingAway, ""))
if err != nil && !errors.Is(err, websocket.ErrCloseSent) {
as.Log.Warnln("Error writing close message to websocket:", err)
}
err = ws.Close()
if err != nil {
as.Log.Warnln("Error closing websocket:", err)
}
return closeErr
}

View File

@@ -1,253 +0,0 @@
// Copyright (c) 2022 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
package bridgeconfig
import (
"fmt"
"regexp"
"strings"
"maunium.net/go/mautrix/appservice"
"maunium.net/go/mautrix/id"
"maunium.net/go/mautrix/util"
up "maunium.net/go/mautrix/util/configupgrade"
)
type HomeserverSoftware string
const (
SoftwareStandard HomeserverSoftware = "standard"
SoftwareAsmux HomeserverSoftware = "asmux"
SoftwareHungry HomeserverSoftware = "hungry"
)
var AllowedHomeserverSoftware = map[HomeserverSoftware]bool{
SoftwareStandard: true,
SoftwareAsmux: true,
SoftwareHungry: true,
}
type HomeserverConfig struct {
Address string `yaml:"address"`
Domain string `yaml:"domain"`
AsyncMedia bool `yaml:"async_media"`
Software HomeserverSoftware `yaml:"software"`
StatusEndpoint string `yaml:"status_endpoint"`
MessageSendCheckpointEndpoint string `yaml:"message_send_checkpoint_endpoint"`
WSProxy string `yaml:"websocket_proxy"`
WSPingInterval int `yaml:"ping_interval_seconds"`
}
type AppserviceConfig struct {
Address string `yaml:"address"`
Hostname string `yaml:"hostname"`
Port uint16 `yaml:"port"`
Database DatabaseConfig `yaml:"database"`
ID string `yaml:"id"`
Bot BotUserConfig `yaml:"bot"`
ASToken string `yaml:"as_token"`
HSToken string `yaml:"hs_token"`
EphemeralEvents bool `yaml:"ephemeral_events"`
AsyncTransactions bool `yaml:"async_transactions"`
}
func (config *BaseConfig) MakeUserIDRegex(matcher string) *regexp.Regexp {
usernamePlaceholder := strings.ToLower(util.RandomString(16))
usernameTemplate := fmt.Sprintf("@%s:%s",
config.Bridge.FormatUsername(usernamePlaceholder),
config.Homeserver.Domain)
usernameTemplate = regexp.QuoteMeta(usernameTemplate)
usernameTemplate = strings.Replace(usernameTemplate, usernamePlaceholder, matcher, 1)
usernameTemplate = fmt.Sprintf("^%s$", usernameTemplate)
return regexp.MustCompile(usernameTemplate)
}
// GenerateRegistration generates a registration file for the homeserver.
func (config *BaseConfig) GenerateRegistration() *appservice.Registration {
registration := appservice.CreateRegistration()
config.AppService.HSToken = registration.ServerToken
config.AppService.ASToken = registration.AppToken
config.AppService.copyToRegistration(registration)
registration.SenderLocalpart = util.RandomString(32)
botRegex := regexp.MustCompile(fmt.Sprintf("^@%s:%s$",
regexp.QuoteMeta(config.AppService.Bot.Username),
regexp.QuoteMeta(config.Homeserver.Domain)))
registration.Namespaces.UserIDs.Register(botRegex, true)
registration.Namespaces.UserIDs.Register(config.MakeUserIDRegex(".*"), true)
return registration
}
func (config *BaseConfig) MakeAppService() *appservice.AppService {
as := appservice.Create()
as.HomeserverDomain = config.Homeserver.Domain
as.HomeserverURL = config.Homeserver.Address
as.Host.Hostname = config.AppService.Hostname
as.Host.Port = config.AppService.Port
as.DefaultHTTPRetries = 4
as.Registration = config.AppService.GetRegistration()
return as
}
// GetRegistration copies the data from the bridge config into an *appservice.Registration struct.
// This can't be used with the homeserver, see GenerateRegistration for generating files for the homeserver.
func (asc *AppserviceConfig) GetRegistration() *appservice.Registration {
reg := &appservice.Registration{}
asc.copyToRegistration(reg)
reg.SenderLocalpart = asc.Bot.Username
reg.ServerToken = asc.HSToken
reg.AppToken = asc.ASToken
return reg
}
func (asc *AppserviceConfig) copyToRegistration(registration *appservice.Registration) {
registration.ID = asc.ID
registration.URL = asc.Address
falseVal := false
registration.RateLimited = &falseVal
registration.EphemeralEvents = asc.EphemeralEvents
registration.SoruEphemeralEvents = asc.EphemeralEvents
}
type BotUserConfig struct {
Username string `yaml:"username"`
Displayname string `yaml:"displayname"`
Avatar string `yaml:"avatar"`
ParsedAvatar id.ContentURI `yaml:"-"`
}
type serializableBUC BotUserConfig
func (buc *BotUserConfig) UnmarshalYAML(unmarshal func(interface{}) error) error {
var sbuc serializableBUC
err := unmarshal(&sbuc)
if err != nil {
return err
}
*buc = (BotUserConfig)(sbuc)
if buc.Avatar != "" && buc.Avatar != "remove" {
buc.ParsedAvatar, err = id.ParseContentURI(buc.Avatar)
if err != nil {
return fmt.Errorf("%w in bot avatar", err)
}
}
return nil
}
type DatabaseConfig struct {
Type string `yaml:"type"`
URI string `yaml:"uri"`
MaxOpenConns int `yaml:"max_open_conns"`
MaxIdleConns int `yaml:"max_idle_conns"`
ConnMaxIdleTime string `yaml:"conn_max_idle_time"`
ConnMaxLifetime string `yaml:"conn_max_lifetime"`
}
type BridgeConfig interface {
FormatUsername(username string) string
GetEncryptionConfig() EncryptionConfig
GetCommandPrefix() string
GetManagementRoomTexts() ManagementRoomTexts
GetResendBridgeInfo() bool
EnableMessageStatusEvents() bool
EnableMessageErrorNotices() bool
Validate() error
}
type EncryptionConfig struct {
Allow bool `yaml:"allow"`
Default bool `yaml:"default"`
Require bool `yaml:"require"`
Appservice bool `yaml:"appservice"`
VerificationLevels struct {
Receive id.TrustState `yaml:"receive"`
Send id.TrustState `yaml:"send"`
Share id.TrustState `yaml:"share"`
} `yaml:"verification_levels"`
AllowKeySharing bool `yaml:"allow_key_sharing"`
Rotation struct {
EnableCustom bool `yaml:"enable_custom"`
Milliseconds int64 `yaml:"milliseconds"`
Messages int `yaml:"messages"`
} `yaml:"rotation"`
}
type ManagementRoomTexts struct {
Welcome string `yaml:"welcome"`
WelcomeConnected string `yaml:"welcome_connected"`
WelcomeUnconnected string `yaml:"welcome_unconnected"`
AdditionalHelp string `yaml:"additional_help"`
}
type BaseConfig struct {
Homeserver HomeserverConfig `yaml:"homeserver"`
AppService AppserviceConfig `yaml:"appservice"`
Bridge BridgeConfig `yaml:"-"`
Logging appservice.LogConfig `yaml:"logging"`
}
func doUpgrade(helper *up.Helper) {
helper.Copy(up.Str, "homeserver", "address")
helper.Copy(up.Str, "homeserver", "domain")
if legacyAsmuxFlag, ok := helper.Get(up.Bool, "homeserver", "asmux"); ok && legacyAsmuxFlag == "true" {
helper.Set(up.Str, string(SoftwareAsmux), "homeserver", "software")
} else {
helper.Copy(up.Str, "homeserver", "software")
}
helper.Copy(up.Str|up.Null, "homeserver", "status_endpoint")
helper.Copy(up.Str|up.Null, "homeserver", "message_send_checkpoint_endpoint")
helper.Copy(up.Bool, "homeserver", "async_media")
helper.Copy(up.Str|up.Null, "homeserver", "websocket_proxy")
helper.Copy(up.Int, "homeserver", "ping_interval_seconds")
helper.Copy(up.Str, "appservice", "address")
helper.Copy(up.Str, "appservice", "hostname")
helper.Copy(up.Int, "appservice", "port")
if dbType, ok := helper.Get(up.Str, "appservice", "database", "type"); ok && dbType == "sqlite3" {
helper.Set(up.Str, "sqlite3-fk-wal", "appservice", "database", "type")
} else {
helper.Copy(up.Str, "appservice", "database", "type")
}
helper.Copy(up.Str, "appservice", "database", "uri")
helper.Copy(up.Int, "appservice", "database", "max_open_conns")
helper.Copy(up.Int, "appservice", "database", "max_idle_conns")
helper.Copy(up.Str|up.Null, "appservice", "database", "max_conn_idle_time")
helper.Copy(up.Str|up.Null, "appservice", "database", "max_conn_lifetime")
helper.Copy(up.Str, "appservice", "id")
helper.Copy(up.Str, "appservice", "bot", "username")
helper.Copy(up.Str, "appservice", "bot", "displayname")
helper.Copy(up.Str, "appservice", "bot", "avatar")
helper.Copy(up.Bool, "appservice", "ephemeral_events")
helper.Copy(up.Bool, "appservice", "async_transactions")
helper.Copy(up.Str, "appservice", "as_token")
helper.Copy(up.Str, "appservice", "hs_token")
helper.Copy(up.Str, "logging", "directory")
helper.Copy(up.Str|up.Null, "logging", "file_name_format")
helper.Copy(up.Str|up.Timestamp, "logging", "file_date_format")
helper.Copy(up.Int, "logging", "file_mode")
helper.Copy(up.Str|up.Timestamp, "logging", "timestamp_format")
helper.Copy(up.Str, "logging", "print_level")
helper.Copy(up.Bool, "logging", "print_json")
helper.Copy(up.Bool, "logging", "file_json")
}
// Upgrader is a config upgrader that copies the default fields in the homeserver, appservice and logging blocks.
var Upgrader = up.SimpleUpgrader(doUpgrade)

View File

@@ -1,71 +0,0 @@
// Copyright (c) 2022 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
package bridgeconfig
import (
"strconv"
"strings"
"maunium.net/go/mautrix/id"
)
type PermissionConfig map[string]PermissionLevel
type PermissionLevel int
const (
PermissionLevelBlock PermissionLevel = 0
PermissionLevelRelay PermissionLevel = 5
PermissionLevelUser PermissionLevel = 10
PermissionLevelAdmin PermissionLevel = 100
)
var namesToLevels = map[string]PermissionLevel{
"block": PermissionLevelBlock,
"relay": PermissionLevelRelay,
"user": PermissionLevelUser,
"admin": PermissionLevelAdmin,
}
func RegisterPermissionLevel(name string, level PermissionLevel) {
namesToLevels[name] = level
}
func (pc *PermissionConfig) UnmarshalYAML(unmarshal func(interface{}) error) error {
rawPC := make(map[string]string)
err := unmarshal(&rawPC)
if err != nil {
return err
}
if *pc == nil {
*pc = make(map[string]PermissionLevel)
}
for key, value := range rawPC {
level, ok := namesToLevels[strings.ToLower(value)]
if ok {
(*pc)[key] = level
} else if val, err := strconv.Atoi(value); err == nil {
(*pc)[key] = PermissionLevel(val)
} else {
(*pc)[key] = PermissionLevelBlock
}
}
return nil
}
func (pc PermissionConfig) Get(userID id.UserID) PermissionLevel {
if level, ok := pc[string(userID)]; ok {
return level
} else if level, ok = pc[userID.Homeserver()]; len(userID.Homeserver()) > 0 && ok {
return level
} else if level, ok = pc["*"]; ok {
return level
} else {
return PermissionLevelBlock
}
}

View File

@@ -18,32 +18,33 @@ import (
"sync/atomic"
"time"
"github.com/rs/zerolog"
"maunium.net/go/maulogger/v2/maulogadapt"
"maunium.net/go/mautrix/event"
"maunium.net/go/mautrix/id"
"maunium.net/go/mautrix/pushrules"
)
type CryptoHelper interface {
Encrypt(id.RoomID, event.Type, any) (*event.EncryptedEventContent, error)
Decrypt(*event.Event) (*event.Event, error)
WaitForSession(id.RoomID, id.SenderKey, id.SessionID, time.Duration) bool
RequestSession(id.RoomID, id.SenderKey, id.SessionID, id.UserID, id.DeviceID)
Init() error
}
// Deprecated: switch to zerolog
type Logger interface {
Debugfln(message string, args ...interface{})
}
// StubLogger is an implementation of Logger that does nothing
type StubLogger struct{}
func (sl *StubLogger) Debugfln(message string, args ...interface{}) {}
func (sl *StubLogger) Warnfln(message string, args ...interface{}) {}
var stubLogger = &StubLogger{}
// Deprecated: switch to zerolog
type WarnLogger interface {
Logger
Warnfln(message string, args ...interface{})
}
type Stringifiable interface {
String() string
}
// Client represents a Matrix client.
type Client struct {
HomeserverURL *url.URL // The base homeserver URL
@@ -53,9 +54,18 @@ type Client struct {
UserAgent string // The value for the User-Agent header
Client *http.Client // The underlying HTTP client which will be used to make HTTP requests.
Syncer Syncer // The thing which can process /sync responses
Store Storer // The thing which can store rooms/tokens/ids
Logger Logger
SyncPresence event.Presence
Store SyncStore // The thing which can store tokens/ids
StateStore StateStore
Crypto CryptoHelper
Log zerolog.Logger
// Deprecated: switch to the zerolog instance in Log
Logger Logger
RequestHook func(req *http.Request)
ResponseHook func(req *http.Request, resp *http.Response, duration time.Duration)
SyncPresence event.Presence
StreamSyncMinAge time.Duration
@@ -67,10 +77,9 @@ type Client struct {
txnID int32
// The ?user_id= query parameter for application services. This must be set *prior* to calling a method.
// If this is empty, no user_id parameter will be sent.
// See https://spec.matrix.org/v1.2/application-service-api/#identity-assertion
AppServiceUserID id.UserID
// Should the ?user_id= query parameter be set in requests?
// See https://spec.matrix.org/v1.6/application-service-api/#identity-assertion
SetAppServiceUserID bool
syncingID uint32 // Identifies the current Sync. Only one Sync can be active at any given time.
}
@@ -180,7 +189,7 @@ func (cli *Client) SyncWithContext(ctx context.Context) error {
for {
streamResp := false
if cli.StreamSyncMinAge > 0 && time.Since(lastSuccessfulSync) > cli.StreamSyncMinAge {
cli.Logger.Debugfln("Last sync is old, will stream next response")
cli.Log.Debug().Msg("Last sync is old, will stream next response")
streamResp = true
}
resSync, err := cli.FullSyncRequest(ReqSync{
@@ -242,40 +251,52 @@ func (cli *Client) StopSync() {
cli.incrementSyncingID()
}
const logBodyContextKey = "fi.mau.mautrix.log_body"
const logRequestIDContextKey = "fi.mau.mautrix.request_id"
type contextKey int
const (
LogBodyContextKey contextKey = iota
LogRequestIDContextKey
)
func (cli *Client) LogRequest(req *http.Request) {
if cli.Logger == stubLogger {
return
if cli.RequestHook != nil {
cli.RequestHook(req)
}
body, ok := req.Context().Value(logBodyContextKey).(string)
reqID, _ := req.Context().Value(logRequestIDContextKey).(int)
if ok && len(body) > 0 {
cli.Logger.Debugfln("req #%d: %s %s %s", reqID, req.Method, req.URL.String(), body)
} else {
cli.Logger.Debugfln("req #%d: %s %s", reqID, req.Method, req.URL.String())
evt := zerolog.Ctx(req.Context()).Debug().
Str("method", req.Method).
Str("url", req.URL.String())
body := req.Context().Value(LogBodyContextKey)
if body != nil {
evt.Interface("body", body)
}
evt.Msg("Sending request")
}
func (cli *Client) LogRequestDone(req *http.Request, resp *http.Response, handlerErr error, contentLength int, duration time.Duration) {
if cli.Logger == stubLogger {
return
if cli.ResponseHook != nil {
cli.ResponseHook(req, resp, duration)
}
reqID, _ := req.Context().Value(logRequestIDContextKey).(int)
mime := resp.Header.Get("Content-Type")
var suffix string
if handlerErr != nil {
suffix = fmt.Sprintf(" (but parsing the body failed)")
}
length := resp.ContentLength
if length == -1 && contentLength > 0 {
length = int64(contentLength)
}
cli.Logger.Debugfln(
"req #%d (%s) completed in %s with status %d and %d bytes of %s body%s",
reqID, strings.TrimPrefix(req.URL.Path, "/_matrix/client"), duration, resp.StatusCode, length, mime, suffix,
)
path := strings.TrimPrefix(req.URL.Path, cli.HomeserverURL.Path)
path = strings.TrimPrefix(path, "/_matrix/client")
evt := zerolog.Ctx(req.Context()).Debug().
Str("method", req.Method).
Str("path", path).
Int("status_code", resp.StatusCode).
Int64("response_length", length).
Str("response_mime", mime).
Dur("duration", duration)
if handlerErr != nil {
evt.AnErr("body_parse_err", handlerErr)
}
if serverRequestID := resp.Header.Get("X-Beeper-Request-ID"); serverRequestID != "" {
evt.Str("beeper_request_id", serverRequestID)
}
evt.Msg("Request completed")
}
func (cli *Client) MakeRequest(method string, httpURL string, reqBody interface{}, resBody interface{}) ([]byte, error) {
@@ -297,13 +318,14 @@ type FullRequest struct {
MaxAttempts int
SensitiveContent bool
Handler ClientResponseHandler
Logger *zerolog.Logger
}
var requestID int32
var logSensitiveContent = os.Getenv("MAUTRIX_LOG_SENSITIVE_CONTENT") == "yes"
func (params *FullRequest) compileRequest() (*http.Request, error) {
var logBody string
var logBody any
reqBody := params.RequestBody
if params.Context == nil {
params.Context = context.Background()
@@ -319,7 +341,7 @@ func (params *FullRequest) compileRequest() (*http.Request, error) {
if params.SensitiveContent && !logSensitiveContent {
logBody = "<sensitive content omitted>"
} else {
logBody = string(jsonStr)
logBody = params.RequestJSON
}
reqBody = bytes.NewReader(jsonStr)
} else if params.RequestBytes != nil {
@@ -330,12 +352,20 @@ func (params *FullRequest) compileRequest() (*http.Request, error) {
logBody = fmt.Sprintf("<%d bytes>", params.RequestLength)
} else if params.Method != http.MethodGet && params.Method != http.MethodHead {
params.RequestJSON = struct{}{}
logBody = "<default empty object>"
logBody = params.RequestJSON
reqBody = bytes.NewReader([]byte("{}"))
}
ctx := context.WithValue(params.Context, logBodyContextKey, logBody)
reqID := atomic.AddInt32(&requestID, 1)
ctx = context.WithValue(ctx, logRequestIDContextKey, int(reqID))
ctx := params.Context
logger := zerolog.Ctx(ctx)
if logger.GetLevel() == zerolog.Disabled || logger == zerolog.DefaultContextLogger {
logger = params.Logger
}
ctx = logger.With().
Int32("req_id", reqID).
Logger().WithContext(ctx)
ctx = context.WithValue(ctx, LogBodyContextKey, logBody)
ctx = context.WithValue(ctx, LogRequestIDContextKey, int(reqID))
req, err := http.NewRequestWithContext(ctx, params.Method, params.URL, reqBody)
if err != nil {
return nil, HTTPError{
@@ -365,6 +395,9 @@ func (cli *Client) MakeFullRequest(params FullRequest) ([]byte, error) {
if params.MaxAttempts == 0 {
params.MaxAttempts = 1 + cli.DefaultHTTPRetries
}
if params.Logger == nil {
params.Logger = &cli.Log
}
req, err := params.compileRequest()
if err != nil {
return nil, err
@@ -379,30 +412,31 @@ func (cli *Client) MakeFullRequest(params FullRequest) ([]byte, error) {
return cli.executeCompiledRequest(req, params.MaxAttempts-1, 4*time.Second, params.ResponseJSON, params.Handler)
}
func (cli *Client) logWarning(format string, args ...interface{}) {
warnLogger, ok := cli.Logger.(WarnLogger)
if ok {
warnLogger.Warnfln(format, args...)
} else {
cli.Logger.Debugfln(format, args...)
func (cli *Client) cliOrContextLog(ctx context.Context) *zerolog.Logger {
log := zerolog.Ctx(ctx)
if log.GetLevel() == zerolog.Disabled || log == zerolog.DefaultContextLogger {
return &cli.Log
}
return log
}
func (cli *Client) doRetry(req *http.Request, cause error, retries int, backoff time.Duration, responseJSON interface{}, handler ClientResponseHandler) ([]byte, error) {
reqID, _ := req.Context().Value(logRequestIDContextKey).(int)
log := zerolog.Ctx(req.Context())
if req.Body != nil {
if req.GetBody == nil {
cli.logWarning("Failed to get new body to retry request #%d: GetBody is nil", reqID)
log.Warn().Msg("Failed to get new body to retry request: GetBody is nil")
return nil, cause
}
var err error
req.Body, err = req.GetBody()
if err != nil {
cli.logWarning("Failed to get new body to retry request #%d: %v", reqID, err)
log.Warn().Err(err).Msg("Failed to get new body to retry request")
return nil, cause
}
}
cli.logWarning("Request #%d failed: %v, retrying in %d seconds", reqID, cause, int(backoff.Seconds()))
log.Warn().Err(cause).
Int("retry_in_seconds", int(backoff.Seconds())).
Msg("Request failed, retrying")
time.Sleep(backoff)
return cli.executeCompiledRequest(req, retries-1, backoff*2, responseJSON, handler)
}
@@ -421,22 +455,23 @@ func (cli *Client) readRequestBody(req *http.Request, res *http.Response) ([]byt
return contents, nil
}
func (cli *Client) closeTemp(file *os.File) {
func closeTemp(log *zerolog.Logger, file *os.File) {
_ = file.Close()
err := os.Remove(file.Name())
if err != nil {
cli.logWarning("Failed to remove temp file %s: %v", file.Name(), err)
log.Warn().Err(err).Str("file_name", file.Name()).Msg("Failed to remove response temp file")
}
}
func (cli *Client) streamResponse(req *http.Request, res *http.Response, responseJSON interface{}) ([]byte, error) {
log := zerolog.Ctx(req.Context())
file, err := os.CreateTemp("", "mautrix-response-")
if err != nil {
cli.logWarning("Failed to create temporary file: %v", err)
log.Warn().Err(err).Msg("Failed to create temporary file for streaming response")
_, err = cli.handleNormalResponse(req, res, responseJSON)
return nil, err
}
defer cli.closeTemp(file)
defer closeTemp(log, file)
if _, err = io.Copy(file, res.Body); err != nil {
return nil, fmt.Errorf("failed to copy response to file: %w", err)
} else if _, err = file.Seek(0, 0); err != nil {
@@ -487,7 +522,7 @@ func (cli *Client) handleResponseError(req *http.Request, res *http.Response) ([
// parseBackoffFromResponse extracts the backoff time specified in the Retry-After header if present. See
// https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Retry-After.
func (cli *Client) parseBackoffFromResponse(res *http.Response, now time.Time, fallback time.Duration) time.Duration {
func (cli *Client) parseBackoffFromResponse(req *http.Request, res *http.Response, now time.Time, fallback time.Duration) time.Duration {
retryAfterHeaderValue := res.Header.Get("Retry-After")
if retryAfterHeaderValue == "" {
return fallback
@@ -501,7 +536,9 @@ func (cli *Client) parseBackoffFromResponse(res *http.Response, now time.Time, f
return time.Duration(seconds) * time.Second
}
cli.logWarning(`Failed to parse Retry-After header value "%s"`, retryAfterHeaderValue)
zerolog.Ctx(req.Context()).Warn().
Str("retry_after", retryAfterHeaderValue).
Msg("Failed to parse Retry-After header value")
return fallback
}
@@ -536,7 +573,7 @@ func (cli *Client) executeCompiledRequest(req *http.Request, retries int, backof
if retries > 0 && cli.shouldRetry(res) {
if res.StatusCode == http.StatusTooManyRequests {
backoff = cli.parseBackoffFromResponse(res, time.Now(), backoff)
backoff = cli.parseBackoffFromResponse(req, res, time.Now(), backoff)
}
return cli.doRetry(req, fmt.Errorf("HTTP %d", res.StatusCode), retries, backoff, responseJSON, handler)
}
@@ -631,7 +668,11 @@ func (cli *Client) FullSyncRequest(req ReqSync) (resp *RespSync, err error) {
buffer = 1 * time.Minute
}
if err == nil && duration > timeout+buffer {
cli.logWarning("Sync request (%s) took %s with timeout %s", req.Since, duration, timeout)
cli.cliOrContextLog(fullReq.Context).Warn().
Str("since", req.Since).
Dur("duration", duration).
Dur("timeout", timeout).
Msg("Sync request took unusually long")
}
return
}
@@ -760,15 +801,24 @@ func (cli *Client) Login(req *ReqLogin) (resp *RespLogin, err error) {
cli.DeviceID = resp.DeviceID
cli.AccessToken = resp.AccessToken
cli.UserID = resp.UserID
cli.Logger.Debugfln("Stored credentials for %s/%s after login", cli.UserID, cli.DeviceID)
cli.Log.Debug().
Str("user_id", cli.UserID.String()).
Str("device_id", cli.DeviceID.String()).
Msg("Stored credentials after login")
}
if req.StoreHomeserverURL && err == nil && resp.WellKnown != nil && len(resp.WellKnown.Homeserver.BaseURL) > 0 {
var urlErr error
cli.HomeserverURL, urlErr = url.Parse(resp.WellKnown.Homeserver.BaseURL)
if urlErr != nil {
cli.logWarning("Failed to parse homeserver URL '%s' in login response: %v", resp.WellKnown.Homeserver.BaseURL, urlErr)
cli.Log.Warn().
Err(urlErr).
Str("homeserver_url", resp.WellKnown.Homeserver.BaseURL).
Msg("Failed to parse homeserver URL in login response")
} else {
cli.Logger.Debugfln("Updated homeserver URL to %s after login", cli.HomeserverURL.String())
cli.Log.Debug().
Str("homeserver_url", cli.HomeserverURL.String()).
Msg("Updated homeserver URL after login")
}
}
return
@@ -818,6 +868,9 @@ func (cli *Client) JoinRoom(roomIDorAlias, serverName string, content interface{
urlPath = cli.BuildClientURL("v3", "join", roomIDorAlias)
}
_, err = cli.MakeRequest("POST", urlPath, content, &resp)
if err == nil && cli.StateStore != nil {
cli.StateStore.SetMembership(resp.RoomID, cli.UserID, event.MembershipJoin)
}
return
}
@@ -827,6 +880,9 @@ func (cli *Client) JoinRoom(roomIDorAlias, serverName string, content interface{
// It's mostly intended for bridges and other things where it's already certain that the server is in the room.
func (cli *Client) JoinRoomByID(roomID id.RoomID) (resp *RespJoinRoom, err error) {
_, err = cli.MakeRequest("POST", cli.BuildClientURL("v3", "rooms", roomID, "join"), nil, &resp)
if err == nil && cli.StateStore != nil {
cli.StateStore.SetMembership(resp.RoomID, cli.UserID, event.MembershipJoin)
}
return
}
@@ -892,6 +948,13 @@ func (cli *Client) SetAvatarURL(url id.ContentURI) (err error) {
return nil
}
// BeeperUpdateProfile sets custom fields in the user's profile.
func (cli *Client) BeeperUpdateProfile(data map[string]any) (err error) {
urlPath := cli.BuildClientURL("v3", "profile", cli.UserID)
_, err = cli.MakeRequest("PATCH", urlPath, &data, nil)
return
}
// GetAccountData gets the user's account data of this type. See https://spec.matrix.org/v1.2/client-server-api/#get_matrixclientv3useruseridaccount_datatype
func (cli *Client) GetAccountData(name string, output interface{}) (err error) {
urlPath := cli.BuildClientURL("v3", "user", cli.UserID, "account_data", name)
@@ -932,6 +995,8 @@ type ReqSendEvent struct {
Timestamp int64
TransactionID string
DontEncrypt bool
MeowEventID id.EventID
}
@@ -958,8 +1023,16 @@ func (cli *Client) SendMessageEvent(roomID id.RoomID, eventType event.Type, cont
queryParams["fi.mau.event_id"] = req.MeowEventID.String()
}
urlData := ClientURLPath{"v3", "rooms", roomID, "send", eventType.String(), txnID}
if !req.DontEncrypt && cli.Crypto != nil && eventType != event.EventReaction && eventType != event.EventEncrypted && cli.StateStore.IsEncrypted(roomID) {
contentJSON, err = cli.Crypto.Encrypt(roomID, eventType, contentJSON)
if err != nil {
err = fmt.Errorf("failed to encrypt event: %w", err)
return
}
eventType = event.EventEncrypted
}
urlData := ClientURLPath{"v3", "rooms", roomID, "send", eventType.String(), txnID}
urlPath := cli.BuildURLWithQuery(urlData, queryParams)
_, err = cli.MakeRequest("PUT", urlPath, contentJSON, &resp)
return
@@ -970,6 +1043,9 @@ func (cli *Client) SendMessageEvent(roomID id.RoomID, eventType event.Type, cont
func (cli *Client) SendStateEvent(roomID id.RoomID, eventType event.Type, stateKey string, contentJSON interface{}) (resp *RespSendEvent, err error) {
urlPath := cli.BuildClientURL("v3", "rooms", roomID, "state", eventType.String(), stateKey)
_, err = cli.MakeRequest("PUT", urlPath, contentJSON, &resp)
if err == nil && cli.StateStore != nil {
cli.updateStoreWithOutgoingEvent(roomID, eventType, stateKey, contentJSON)
}
return
}
@@ -980,6 +1056,9 @@ func (cli *Client) SendMassagedStateEvent(roomID id.RoomID, eventType event.Type
"ts": strconv.FormatInt(ts, 10),
})
_, err = cli.MakeRequest("PUT", urlPath, contentJSON, &resp)
if err == nil && cli.StateStore != nil {
cli.updateStoreWithOutgoingEvent(roomID, eventType, stateKey, contentJSON)
}
return
}
@@ -992,30 +1071,6 @@ func (cli *Client) SendText(roomID id.RoomID, text string) (*RespSendEvent, erro
})
}
// SendImage sends an m.room.message event into the given room with a msgtype of m.image
// See https://spec.matrix.org/v1.2/client-server-api/#mimage
//
// Deprecated: This does not allow setting image metadata, you should prefer SendMessageEvent with a properly filled &event.MessageEventContent
func (cli *Client) SendImage(roomID id.RoomID, body string, url id.ContentURI) (*RespSendEvent, error) {
return cli.SendMessageEvent(roomID, event.EventMessage, &event.MessageEventContent{
MsgType: event.MsgImage,
Body: body,
URL: url.CUString(),
})
}
// SendVideo sends an m.room.message event into the given room with a msgtype of m.video
// See https://spec.matrix.org/v1.2/client-server-api/#mvideo
//
// Deprecated: This does not allow setting video metadata, you should prefer SendMessageEvent with a properly filled &event.MessageEventContent
func (cli *Client) SendVideo(roomID id.RoomID, body string, url id.ContentURI) (*RespSendEvent, error) {
return cli.SendMessageEvent(roomID, event.EventMessage, &event.MessageEventContent{
MsgType: event.MsgVideo,
Body: body,
URL: url.CUString(),
})
}
// SendNotice sends an m.room.message event into the given room with a msgtype of m.notice
// See https://spec.matrix.org/v1.2/client-server-api/#mnotice
func (cli *Client) SendNotice(roomID id.RoomID, text string) (*RespSendEvent, error) {
@@ -1067,6 +1122,22 @@ func (cli *Client) RedactEvent(roomID id.RoomID, eventID id.EventID, extra ...Re
func (cli *Client) CreateRoom(req *ReqCreateRoom) (resp *RespCreateRoom, err error) {
urlPath := cli.BuildClientURL("v3", "createRoom")
_, err = cli.MakeRequest("POST", urlPath, req, &resp)
if err == nil && cli.StateStore != nil {
cli.StateStore.SetMembership(resp.RoomID, cli.UserID, event.MembershipJoin)
for _, evt := range req.InitialState {
UpdateStateStore(cli.StateStore, evt)
}
inviteMembership := event.MembershipInvite
if req.BeeperAutoJoinInvites {
inviteMembership = event.MembershipJoin
}
for _, invitee := range req.Invite {
cli.StateStore.SetMembership(resp.RoomID, invitee, inviteMembership)
}
for _, evt := range req.InitialState {
cli.updateStoreWithOutgoingEvent(resp.RoomID, evt.Type, evt.GetStateKey(), &evt.Content)
}
}
return
}
@@ -1080,6 +1151,9 @@ func (cli *Client) LeaveRoom(roomID id.RoomID, optionalReq ...*ReqLeave) (resp *
}
u := cli.BuildClientURL("v3", "rooms", roomID, "leave")
_, err = cli.MakeRequest("POST", u, req, &resp)
if err == nil && cli.StateStore != nil {
cli.StateStore.SetMembership(roomID, cli.UserID, event.MembershipLeave)
}
return
}
@@ -1094,6 +1168,9 @@ func (cli *Client) ForgetRoom(roomID id.RoomID) (resp *RespForgetRoom, err error
func (cli *Client) InviteUser(roomID id.RoomID, req *ReqInviteUser) (resp *RespInviteUser, err error) {
u := cli.BuildClientURL("v3", "rooms", roomID, "invite")
_, err = cli.MakeRequest("POST", u, req, &resp)
if err == nil && cli.StateStore != nil {
cli.StateStore.SetMembership(roomID, req.UserID, event.MembershipInvite)
}
return
}
@@ -1108,6 +1185,9 @@ func (cli *Client) InviteUserByThirdParty(roomID id.RoomID, req *ReqInvite3PID)
func (cli *Client) KickUser(roomID id.RoomID, req *ReqKickUser) (resp *RespKickUser, err error) {
u := cli.BuildClientURL("v3", "rooms", roomID, "kick")
_, err = cli.MakeRequest("POST", u, req, &resp)
if err == nil && cli.StateStore != nil {
cli.StateStore.SetMembership(roomID, req.UserID, event.MembershipLeave)
}
return
}
@@ -1115,6 +1195,9 @@ func (cli *Client) KickUser(roomID id.RoomID, req *ReqKickUser) (resp *RespKickU
func (cli *Client) BanUser(roomID id.RoomID, req *ReqBanUser) (resp *RespBanUser, err error) {
u := cli.BuildClientURL("v3", "rooms", roomID, "ban")
_, err = cli.MakeRequest("POST", u, req, &resp)
if err == nil && cli.StateStore != nil {
cli.StateStore.SetMembership(roomID, req.UserID, event.MembershipBan)
}
return
}
@@ -1122,6 +1205,9 @@ func (cli *Client) BanUser(roomID id.RoomID, req *ReqBanUser) (resp *RespBanUser
func (cli *Client) UnbanUser(roomID id.RoomID, req *ReqUnbanUser) (resp *RespUnbanUser, err error) {
u := cli.BuildClientURL("v3", "rooms", roomID, "unban")
_, err = cli.MakeRequest("POST", u, req, &resp)
if err == nil && cli.StateStore != nil {
cli.StateStore.SetMembership(roomID, req.UserID, event.MembershipLeave)
}
return
}
@@ -1153,12 +1239,48 @@ func (cli *Client) SetPresence(status event.Presence) (err error) {
return
}
func (cli *Client) updateStoreWithOutgoingEvent(roomID id.RoomID, eventType event.Type, stateKey string, contentJSON interface{}) {
if cli.StateStore == nil {
return
}
fakeEvt := &event.Event{
StateKey: &stateKey,
Type: eventType,
RoomID: roomID,
}
var err error
fakeEvt.Content.VeryRaw, err = json.Marshal(contentJSON)
if err != nil {
cli.Log.Warn().Err(err).Msg("Failed to marshal state event content to update state store")
return
}
err = json.Unmarshal(fakeEvt.Content.VeryRaw, &fakeEvt.Content.Raw)
if err != nil {
cli.Log.Warn().Err(err).Msg("Failed to unmarshal state event content to update state store")
return
}
err = fakeEvt.Content.ParseRaw(fakeEvt.Type)
if err != nil {
switch fakeEvt.Type {
case event.StateMember, event.StatePowerLevels, event.StateEncryption:
cli.Log.Warn().Err(err).Msg("Failed to parse state event content to update state store")
default:
cli.Log.Debug().Err(err).Msg("Failed to parse state event content to update state store")
}
return
}
UpdateStateStore(cli.StateStore, fakeEvt)
}
// StateEvent gets a single state event in a room. It will attempt to JSON unmarshal into the given "outContent" struct with
// the HTTP response body, or return an error.
// See https://spec.matrix.org/v1.2/client-server-api/#get_matrixclientv3roomsroomidstateeventtypestatekey
func (cli *Client) StateEvent(roomID id.RoomID, eventType event.Type, stateKey string, outContent interface{}) (err error) {
u := cli.BuildClientURL("v3", "rooms", roomID, "state", eventType.String(), stateKey)
_, err = cli.MakeRequest("GET", u, nil, outContent)
if err == nil && cli.StateStore != nil {
cli.updateStoreWithOutgoingEvent(roomID, eventType, stateKey, outContent)
}
return
}
@@ -1182,6 +1304,7 @@ func parseRoomStateArray(_ *http.Request, res *http.Response, responseJSON inter
if err != nil {
return nil, fmt.Errorf("failed to parse state array item #%d: %v", i, err)
}
evt.Type.Class = event.StateEventType
_ = evt.Content.ParseRaw(evt.Type)
subMap, ok := response[evt.Type]
if !ok {
@@ -1209,6 +1332,14 @@ func (cli *Client) State(roomID id.RoomID) (stateMap RoomStateMap, err error) {
ResponseJSON: &stateMap,
Handler: parseRoomStateArray,
})
if err == nil && cli.StateStore != nil {
cli.StateStore.ClearCachedMembers(roomID)
for _, evts := range stateMap {
for _, evt := range evts {
UpdateStateStore(cli.StateStore, evt)
}
}
}
return
}
@@ -1245,6 +1376,10 @@ func (cli *Client) DownloadContext(ctx context.Context, mxcURL id.ContentURI) (i
}
func (cli *Client) downloadContext(ctx context.Context, mxcURL id.ContentURI) (*http.Request, *http.Response, error) {
ctxLog := zerolog.Ctx(ctx)
if ctxLog.GetLevel() == zerolog.Disabled || ctxLog == zerolog.DefaultContextLogger {
ctx = cli.Log.WithContext(ctx)
}
req, err := http.NewRequestWithContext(ctx, http.MethodGet, cli.GetDownloadURL(mxcURL), nil)
if err != nil {
return req, nil, err
@@ -1303,7 +1438,7 @@ func (cli *Client) UnstableUploadAsync(req ReqUploadMedia) (*RespCreateMXC, erro
go func() {
_, err = cli.UploadMedia(req)
if err != nil {
cli.logWarning("Failed to upload %s: %v", req.UnstableMXC, err)
cli.Log.Error().Str("mxc", req.UnstableMXC.String()).Err(err).Msg("Async upload of media failed")
}
}()
return resp, nil
@@ -1349,7 +1484,7 @@ type ReqUploadMedia struct {
}
func (cli *Client) tryUploadMediaToURL(url, contentType string, content io.Reader) (*http.Response, error) {
cli.Logger.Debugfln("Uploading media to external URL %s", url)
cli.Log.Debug().Str("url", url).Msg("Uploading media to external URL")
req, err := http.NewRequest(http.MethodPut, url, content)
if err != nil {
return nil, err
@@ -1382,10 +1517,10 @@ func (cli *Client) uploadMediaToURL(data ReqUploadMedia) (*RespMediaUpload, erro
err = fmt.Errorf("HTTP %d", resp.StatusCode)
}
if retries <= 0 {
cli.logWarning("Error uploading media to %s: %v, not retrying", data.UploadURL, err)
cli.Log.Warn().Str("url", data.UploadURL).Err(err).Msg("Error uploading media to external URL, not retrying")
return nil, err
}
cli.Logger.Debugfln("Error uploading media to %s: %v, retrying", data.UploadURL, err)
cli.Log.Warn().Str("url", data.UploadURL).Err(err).Msg("Error uploading media to external URL, retrying")
retries--
}
@@ -1464,6 +1599,16 @@ func (cli *Client) GetURLPreview(url string) (*RespPreviewURL, error) {
func (cli *Client) JoinedMembers(roomID id.RoomID) (resp *RespJoinedMembers, err error) {
u := cli.BuildClientURL("v3", "rooms", roomID, "joined_members")
_, err = cli.MakeRequest("GET", u, nil, &resp)
if err == nil && cli.StateStore != nil {
cli.StateStore.ClearCachedMembers(roomID, event.MembershipJoin)
for userID, member := range resp.Joined {
cli.StateStore.SetMember(roomID, userID, &event.MemberEventContent{
Membership: event.MembershipJoin,
AvatarURL: id.ContentURIString(member.AvatarURL),
Displayname: member.DisplayName,
})
}
}
return
}
@@ -1484,6 +1629,18 @@ func (cli *Client) Members(roomID id.RoomID, req ...ReqMembers) (resp *RespMembe
}
u := cli.BuildURLWithQuery(ClientURLPath{"v3", "rooms", roomID, "members"}, query)
_, err = cli.MakeRequest("GET", u, nil, &resp)
if err == nil && cli.StateStore != nil {
var clearMemberships []event.Membership
if extra.Membership != "" {
clearMemberships = append(clearMemberships, extra.Membership)
}
if extra.NotMembership == "" {
cli.StateStore.ClearCachedMembers(roomID, clearMemberships...)
}
for _, evt := range resp.Chunk {
UpdateStateStore(cli.StateStore, evt)
}
}
return
}
@@ -1833,6 +1990,18 @@ func (cli *Client) BatchSend(roomID id.RoomID, req *ReqBatchSend) (resp *RespBat
return
}
func (cli *Client) AppservicePing(id, txnID string) (resp *RespAppservicePing, err error) {
_, err = cli.MakeFullRequest(FullRequest{
Method: http.MethodPost,
URL: cli.BuildClientURL("v1", "appservice", id, "ping"),
RequestJSON: &ReqAppservicePing{TxnID: txnID},
ResponseJSON: &resp,
// This endpoint intentionally returns 50x, so don't retry
MaxAttempts: 1,
})
return
}
func (cli *Client) BeeperMergeRooms(req *ReqBeeperMergeRoom) (resp *RespBeeperMergeRoom, err error) {
urlPath := cli.BuildClientURL("unstable", "com.beeper.chatmerging", "merge")
_, err = cli.MakeRequest(http.MethodPost, urlPath, req, &resp)
@@ -1859,21 +2028,23 @@ func (cli *Client) TxnID() string {
// NewClient creates a new Matrix Client ready for syncing
func NewClient(homeserverURL string, userID id.UserID, accessToken string) (*Client, error) {
hsURL, err := parseAndNormalizeBaseURL(homeserverURL)
hsURL, err := ParseAndNormalizeBaseURL(homeserverURL)
if err != nil {
return nil, err
}
return &Client{
cli := &Client{
AccessToken: accessToken,
UserAgent: DefaultUserAgent,
HomeserverURL: hsURL,
UserID: userID,
Client: &http.Client{Timeout: 180 * time.Second},
Syncer: NewDefaultSyncer(),
Logger: stubLogger,
Log: zerolog.Nop(),
// By default, use an in-memory store which will never save filter ids / next batch tokens to disk.
// The client will work with this storer: it just won't remember across restarts.
// In practice, a database backend should be used.
Store: NewInMemoryStore(),
}, nil
Store: NewMemorySyncStore(),
}
cli.Logger = maulogadapt.ZeroAsMau(&cli.Log)
return cli, nil
}

View File

@@ -1,4 +1,5 @@
// Copyright (c) 2020 Nikos Filippakis
// Copyright (c) 2023 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
@@ -55,8 +56,11 @@ func (mach *OlmMachine) ImportCrossSigningKeys(keys CrossSigningSeeds) (err erro
return
}
mach.Log.Trace("Got cross-signing keys: Master `%v` Self-signing `%v` User-signing `%v`",
keysCache.MasterKey.PublicKey, keysCache.SelfSigningKey.PublicKey, keysCache.UserSigningKey.PublicKey)
mach.Log.Debug().
Str("master", keysCache.MasterKey.PublicKey.String()).
Str("self_signing", keysCache.SelfSigningKey.PublicKey.String()).
Str("user_signing", keysCache.UserSigningKey.PublicKey.String()).
Msg("Imported own cross-signing keys")
mach.CrossSigningKeys = &keysCache
mach.crossSigningPubkeys = keysCache.PublicKeys()
@@ -76,8 +80,11 @@ func (mach *OlmMachine) GenerateCrossSigningKeys() (*CrossSigningKeysCache, erro
if keysCache.UserSigningKey, err = olm.NewPkSigning(); err != nil {
return nil, fmt.Errorf("failed to generate user-signing key: %w", err)
}
mach.Log.Debug("Generated cross-signing keys: Master: `%v` Self-signing: `%v` User-signing: `%v`",
keysCache.MasterKey.PublicKey, keysCache.SelfSigningKey.PublicKey, keysCache.UserSigningKey.PublicKey)
mach.Log.Debug().
Str("master", keysCache.MasterKey.PublicKey.String()).
Str("self_signing", keysCache.SelfSigningKey.PublicKey.String()).
Str("user_signing", keysCache.UserSigningKey.PublicKey.String()).
Msg("Generated cross-signing keys")
return &keysCache, nil
}

View File

@@ -1,4 +1,4 @@
// Copyright (c) 2022 Tulir Asokan
// Copyright (c) 2023 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
@@ -32,7 +32,7 @@ func (mach *OlmMachine) GetOwnCrossSigningPublicKeys() *CrossSigningPublicKeysCa
}
cspk, err := mach.GetCrossSigningPublicKeys(mach.Client.UserID)
if err != nil {
mach.Log.Error("Failed to get own cross-signing public keys: %v", err)
mach.Log.Error().Err(err).Msg("Failed to get own cross-signing public keys")
return nil
}
mach.crossSigningPubkeys = cspk

View File

@@ -1,5 +1,5 @@
// Copyright (c) 2020 Nikos Filippakis
// Copyright (c) 2022 Tulir Asokan
// Copyright (c) 2023 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
@@ -79,7 +79,10 @@ func (mach *OlmMachine) SignUser(userID id.UserID, masterKey id.Ed25519) error {
return err
}
mach.Log.Trace("Signed master key of %s with user-signing key: `%v`", userID, signature)
mach.Log.Debug().
Str("user_id", userID.String()).
Str("signature", signature).
Msg("Signed master key of user with our user-signing key")
if err := mach.CryptoStore.PutSignature(userID, masterKey, mach.Client.UserID, mach.CrossSigningKeys.UserSigningKey.PublicKey, signature); err != nil {
return fmt.Errorf("error storing signature in crypto store: %w", err)
@@ -116,7 +119,10 @@ func (mach *OlmMachine) SignOwnMasterKey() error {
id.NewKeyID(id.KeyAlgorithmEd25519, deviceID.String()): signature,
},
}
mach.Log.Trace("Signed own master key with device %v: `%v`", deviceID, signature)
mach.Log.Debug().
Str("device_id", deviceID.String()).
Str("signature", signature).
Msg("Signed own master key with own device key")
resp, err := mach.Client.UploadSignatures(&mautrix.ReqUploadSignatures{
userID: map[string]mautrix.ReqKeysSignatures{
@@ -165,7 +171,11 @@ func (mach *OlmMachine) SignOwnDevice(device *id.Device) error {
return err
}
mach.Log.Trace("Signed own device %s with self-signing key: `%v`", device.UserID, device.DeviceID, signature)
mach.Log.Debug().
Str("user_id", device.UserID.String()).
Str("device_id", device.DeviceID.String()).
Str("signature", signature).
Msg("Signed own device key with self-signing key")
if err := mach.CryptoStore.PutSignature(device.UserID, device.SigningKey, mach.Client.UserID, mach.CrossSigningKeys.SelfSigningKey.PublicKey, signature); err != nil {
return fmt.Errorf("error storing signature in crypto store: %w", err)

View File

@@ -1,5 +1,5 @@
// Copyright (c) 2020 Nikos Filippakis
// Copyright (c) 2022 Tulir Asokan
// Copyright (c) 2023 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
@@ -8,28 +8,36 @@
package crypto
import (
"context"
"maunium.net/go/mautrix"
"maunium.net/go/mautrix/crypto/olm"
"maunium.net/go/mautrix/id"
)
func (mach *OlmMachine) storeCrossSigningKeys(crossSigningKeys map[id.UserID]mautrix.CrossSigningKeys, deviceKeys map[id.UserID]map[id.DeviceID]mautrix.DeviceKeys) {
func (mach *OlmMachine) storeCrossSigningKeys(ctx context.Context, crossSigningKeys map[id.UserID]mautrix.CrossSigningKeys, deviceKeys map[id.UserID]map[id.DeviceID]mautrix.DeviceKeys) {
log := mach.machOrContextLog(ctx)
for userID, userKeys := range crossSigningKeys {
log := log.With().Str("user_id", userID.String()).Logger()
currentKeys, err := mach.CryptoStore.GetCrossSigningKeys(userID)
if err != nil {
mach.Log.Error("Error fetching current cross-signing keys of user %v: %v", userID, err)
log.Error().Err(err).
Msg("Error fetching current cross-signing keys of user")
}
if currentKeys != nil {
for curKeyUsage, curKey := range currentKeys {
log := log.With().Str("old_key", curKey.Key.String()).Str("old_key_usage", string(curKeyUsage)).Logger()
// got a new key with the same usage as an existing key
for _, newKeyUsage := range userKeys.Usage {
if newKeyUsage == curKeyUsage {
if _, ok := userKeys.Keys[id.NewKeyID(id.KeyAlgorithmEd25519, curKey.Key.String())]; !ok {
// old key is not in the new key map, so we drop signatures made by it
if count, err := mach.CryptoStore.DropSignaturesByKey(userID, curKey.Key); err != nil {
mach.Log.Error("Error deleting old signatures made by %s (%s): %v", curKey, curKeyUsage, err)
log.Error().Err(err).Msg("Error deleting old signatures made by user")
} else {
mach.Log.Debug("Dropped %d signatures made by key %s (%s) as it has been replaced", count, curKey, curKeyUsage)
log.Debug().
Int64("signature_count", count).
Msg("Dropped signatures made by old key as it has been replaced")
}
}
break
@@ -39,10 +47,11 @@ func (mach *OlmMachine) storeCrossSigningKeys(crossSigningKeys map[id.UserID]mau
}
for _, key := range userKeys.Keys {
log := log.With().Str("key", key.String()).Strs("usages", strishArray(userKeys.Usage)).Logger()
for _, usage := range userKeys.Usage {
mach.Log.Debug("Storing cross-signing key for %s: %s (type %s)", userID, key, usage)
log.Debug().Str("usage", string(usage)).Msg("Storing cross-signing key")
if err = mach.CryptoStore.PutCrossSigningKey(userID, usage, key); err != nil {
mach.Log.Error("Error storing cross-signing key: %v", err)
log.Error().Err(err).Msg("Error storing cross-signing key")
}
}
@@ -50,31 +59,38 @@ func (mach *OlmMachine) storeCrossSigningKeys(crossSigningKeys map[id.UserID]mau
for signKeyID, signature := range keySigs {
_, signKeyName := signKeyID.Parse()
signingKey := id.Ed25519(signKeyName)
log := log.With().
Str("sign_key_id", signKeyID.String()).
Str("signer_user_id", signUserID.String()).
Str("signing_key", signingKey.String()).
Logger()
// if the signer is one of this user's own devices, find the key from the key ID
if signUserID == userID {
ownDeviceID := id.DeviceID(signKeyName)
if ownDeviceKeys, ok := deviceKeys[userID][ownDeviceID]; ok {
signingKey = ownDeviceKeys.Keys.GetEd25519(ownDeviceID)
mach.Log.Trace("Treating %s as the device ID -> signing key %s", signKeyName, signingKey)
log.Trace().
Str("device_id", signKeyName).
Msg("Treating key name as device ID")
}
}
if len(signingKey) != 43 {
mach.Log.Trace("Cross-signing key %s/%s/%v has a signature from an unknown key %s", userID, key, userKeys.Usage, signKeyID)
log.Debug().Msg("Cross-signing key has a signature from an unknown key")
continue
}
mach.Log.Debug("Verifying cross-signing key %s/%s/%v with key %s/%s", userID, key, userKeys.Usage, signUserID, signingKey)
log.Debug().Msg("Verifying cross-signing key signature")
if verified, err := olm.VerifySignatureJSON(userKeys, signUserID, signKeyName, signingKey); err != nil {
mach.Log.Warn("Error while verifying signature from %s for %s: %v", signingKey, key, err)
log.Warn().Err(err).Msg("Error verifying cross-signing key signature")
} else {
if verified {
mach.Log.Debug("Signature from %s for %s verified", signingKey, key)
log.Debug().Err(err).Msg("Cross-signing key signature verified")
err = mach.CryptoStore.PutSignature(userID, key, signUserID, signingKey, signature)
if err != nil {
mach.Log.Warn("Failed to store signature from %s for %s: %v", signingKey, key, err)
log.Error().Err(err).Msg("Error storing cross-signing key signature")
}
} else {
mach.Log.Error("Invalid signature from %s for %s", signingKey, key)
log.Warn().Err(err).Msg("Cross-siging key signature is invalid")
}
}
}

View File

@@ -1,5 +1,5 @@
// Copyright (c) 2020 Nikos Filippakis
// Copyright (c) 2022 Tulir Asokan
// Copyright (c) 2023 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
@@ -8,52 +8,72 @@
package crypto
import (
"context"
"maunium.net/go/mautrix/id"
)
// ResolveTrust resolves the trust state of the device from cross-signing.
func (mach *OlmMachine) ResolveTrust(device *id.Device) id.TrustState {
state, _ := mach.ResolveTrustContext(context.Background(), device)
return state
}
// ResolveTrustContext resolves the trust state of the device from cross-signing.
func (mach *OlmMachine) ResolveTrustContext(ctx context.Context, device *id.Device) (id.TrustState, error) {
if device.Trust == id.TrustStateVerified || device.Trust == id.TrustStateBlacklisted {
return device.Trust
return device.Trust, nil
}
theirKeys, err := mach.CryptoStore.GetCrossSigningKeys(device.UserID)
if err != nil {
mach.Log.Error("Error retrieving cross-singing key of user %v from database: %v", device.UserID, err)
return id.TrustStateUnset
mach.machOrContextLog(ctx).Error().Err(err).
Str("user_id", device.UserID.String()).
Msg("Error retrieving cross-signing key of user from database")
return id.TrustStateUnset, err
}
theirMSK, ok := theirKeys[id.XSUsageMaster]
if !ok {
mach.Log.Error("Master key of user %v not found", device.UserID)
return id.TrustStateUnset
mach.machOrContextLog(ctx).Error().
Str("user_id", device.UserID.String()).
Msg("Master key of user not found")
return id.TrustStateUnset, nil
}
theirSSK, ok := theirKeys[id.XSUsageSelfSigning]
if !ok {
mach.Log.Error("Self-signing key of user %v not found", device.UserID)
return id.TrustStateUnset
mach.machOrContextLog(ctx).Error().
Str("user_id", device.UserID.String()).
Msg("Self-signing key of user not found")
return id.TrustStateUnset, nil
}
sskSigExists, err := mach.CryptoStore.IsKeySignedBy(device.UserID, theirSSK.Key, device.UserID, theirMSK.Key)
if err != nil {
mach.Log.Error("Error retrieving cross-singing signatures for master key of user %v from database: %v", device.UserID, err)
return id.TrustStateUnset
mach.machOrContextLog(ctx).Error().Err(err).
Str("user_id", device.UserID.String()).
Msg("Error retrieving cross-signing signatures for master key of user from database")
return id.TrustStateUnset, err
}
if !sskSigExists {
mach.Log.Warn("Self-signing key of user %v is not signed by their master key", device.UserID)
return id.TrustStateUnset
mach.machOrContextLog(ctx).Error().
Str("user_id", device.UserID.String()).
Msg("Self-signing key of user is not signed by their master key")
return id.TrustStateUnset, nil
}
deviceSigExists, err := mach.CryptoStore.IsKeySignedBy(device.UserID, device.SigningKey, device.UserID, theirSSK.Key)
if err != nil {
mach.Log.Error("Error retrieving cross-singing signatures for master key of user %v from database: %v", device.UserID, err)
return id.TrustStateUnset
mach.machOrContextLog(ctx).Error().Err(err).
Str("user_id", device.UserID.String()).
Str("device_key", device.SigningKey.String()).
Msg("Error retrieving cross-signing signatures for device from database")
return id.TrustStateUnset, err
}
if deviceSigExists {
if mach.IsUserTrusted(device.UserID) {
return id.TrustStateCrossSignedVerified
if trusted, err := mach.IsUserTrusted(ctx, device.UserID); !trusted {
return id.TrustStateCrossSignedVerified, err
} else if theirMSK.Key == theirMSK.First {
return id.TrustStateCrossSignedTOFU
return id.TrustStateCrossSignedTOFU, nil
}
return id.TrustStateCrossSignedUntrusted
return id.TrustStateCrossSignedUntrusted, nil
}
return id.TrustStateUnset
return id.TrustStateUnset, nil
}
// IsDeviceTrusted returns whether a device has been determined to be trusted either through verification or cross-signing.
@@ -68,36 +88,42 @@ func (mach *OlmMachine) IsDeviceTrusted(device *id.Device) bool {
// IsUserTrusted returns whether a user has been determined to be trusted by our user-signing key having signed their master key.
// In the case the user ID is our own and we have successfully retrieved our cross-signing keys, we trust our own user.
func (mach *OlmMachine) IsUserTrusted(userID id.UserID) bool {
func (mach *OlmMachine) IsUserTrusted(ctx context.Context, userID id.UserID) (bool, error) {
csPubkeys := mach.GetOwnCrossSigningPublicKeys()
if csPubkeys == nil {
return false
return false, nil
}
if userID == mach.Client.UserID {
return true
return true, nil
}
// first we verify our user-signing key
ourUserSigningKeyTrusted, err := mach.CryptoStore.IsKeySignedBy(mach.Client.UserID, csPubkeys.UserSigningKey, mach.Client.UserID, csPubkeys.MasterKey)
if err != nil {
mach.Log.Error("Error retrieving our self-singing key signatures: %v", err)
return false
mach.machOrContextLog(ctx).Error().Err(err).Msg("Error retrieving our self-signing key signatures from database")
return false, err
} else if !ourUserSigningKeyTrusted {
return false
return false, nil
}
theirKeys, err := mach.CryptoStore.GetCrossSigningKeys(userID)
if err != nil {
mach.Log.Error("Error retrieving cross-singing key of user %v from database: %v", userID, err)
return false
mach.machOrContextLog(ctx).Error().Err(err).
Str("user_id", userID.String()).
Msg("Error retrieving cross-signing key of user from database")
return false, err
}
theirMskKey, ok := theirKeys[id.XSUsageMaster]
if !ok {
mach.Log.Error("Master key of user %v not found", userID)
return false
mach.machOrContextLog(ctx).Error().
Str("user_id", userID.String()).
Msg("Master key of user not found")
return false, nil
}
sigExists, err := mach.CryptoStore.IsKeySignedBy(userID, theirMskKey.Key, mach.Client.UserID, csPubkeys.UserSigningKey)
if err != nil {
mach.Log.Error("Error retrieving cross-singing signatures for master key of user %v from database: %v", userID, err)
return false
mach.machOrContextLog(ctx).Error().Err(err).
Str("user_id", userID.String()).
Msg("Error retrieving cross-signing signatures for master key of user from database")
return false, err
}
return sigExists
return sigExists, nil
}

View File

@@ -0,0 +1,374 @@
// Copyright (c) 2023 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
package cryptohelper
import (
"context"
"errors"
"fmt"
"sync"
"time"
"github.com/rs/zerolog"
"maunium.net/go/mautrix"
"maunium.net/go/mautrix/crypto"
"maunium.net/go/mautrix/event"
"maunium.net/go/mautrix/id"
"maunium.net/go/mautrix/sqlstatestore"
"maunium.net/go/mautrix/util/dbutil"
)
type CryptoHelper struct {
client *mautrix.Client
mach *crypto.OlmMachine
log zerolog.Logger
lock sync.RWMutex
pickleKey []byte
managedStateStore *sqlstatestore.SQLStateStore
unmanagedCryptoStore crypto.Store
dbForManagedStores *dbutil.Database
DecryptErrorCallback func(*event.Event, error)
LoginAs *mautrix.ReqLogin
DBAccountID string
}
var _ mautrix.CryptoHelper = (*CryptoHelper)(nil)
// NewCryptoHelper creates a struct that helps a mautrix client struct with Matrix e2ee operations.
//
// The client and pickle key are always required. Additionally, you must either:
// - Provide a crypto.Store here and set a StateStore in the client, or
// - Provide a dbutil.Database here to automatically create missing stores.
// - Provide a string here to use it as a path to a SQLite database, and then automatically create missing stores.
//
// The same database may be shared across multiple clients, but note that doing that will allow all clients access to
// decryption keys received by any one of the clients. For that reason, the pickle key must also be same for all clients
// using the same database.
func NewCryptoHelper(cli *mautrix.Client, pickleKey []byte, store any) (*CryptoHelper, error) {
if len(pickleKey) == 0 {
return nil, fmt.Errorf("pickle key must be provided")
}
_, isExtensible := cli.Syncer.(mautrix.ExtensibleSyncer)
if !isExtensible {
return nil, fmt.Errorf("the client syncer must implement ExtensibleSyncer")
}
var managedStateStore *sqlstatestore.SQLStateStore
var dbForManagedStores *dbutil.Database
var unmanagedCryptoStore crypto.Store
switch typedStore := store.(type) {
case crypto.Store:
if cli.StateStore == nil {
return nil, fmt.Errorf("when passing a crypto.Store to NewCryptoHelper, the client must have a state store set beforehand")
} else if _, isCryptoCompatible := cli.StateStore.(crypto.StateStore); !isCryptoCompatible {
return nil, fmt.Errorf("the client state store must implement crypto.StateStore")
}
unmanagedCryptoStore = typedStore
case string:
db, err := dbutil.NewWithDialect(typedStore, "sqlite3")
if err != nil {
return nil, err
}
dbForManagedStores = db
case *dbutil.Database:
dbForManagedStores = typedStore
default:
return nil, fmt.Errorf("you must pass a *dbutil.Database or *crypto.StateStore to NewCryptoHelper")
}
log := cli.Log.With().Str("component", "crypto").Logger()
if cli.StateStore == nil && dbForManagedStores != nil {
managedStateStore = sqlstatestore.NewSQLStateStore(dbForManagedStores, dbutil.ZeroLogger(log.With().Str("db_section", "matrix_state").Logger()), false)
cli.StateStore = managedStateStore
} else if _, isCryptoCompatible := cli.StateStore.(crypto.StateStore); !isCryptoCompatible {
return nil, fmt.Errorf("the client state store must implement crypto.StateStore")
}
return &CryptoHelper{
client: cli,
log: log,
pickleKey: pickleKey,
unmanagedCryptoStore: unmanagedCryptoStore,
managedStateStore: managedStateStore,
dbForManagedStores: dbForManagedStores,
DecryptErrorCallback: func(_ *event.Event, _ error) {},
}, nil
}
func (helper *CryptoHelper) Init() error {
if helper == nil {
return fmt.Errorf("crypto helper is nil")
}
syncer, ok := helper.client.Syncer.(mautrix.ExtensibleSyncer)
if !ok {
return fmt.Errorf("the client syncer must implement ExtensibleSyncer")
}
var stateStore crypto.StateStore
if helper.managedStateStore != nil {
err := helper.managedStateStore.Upgrade()
if err != nil {
return fmt.Errorf("failed to upgrade client state store: %w", err)
}
stateStore = helper.managedStateStore
} else {
stateStore = helper.client.StateStore.(crypto.StateStore)
}
var cryptoStore crypto.Store
if helper.unmanagedCryptoStore == nil {
managedCryptoStore := crypto.NewSQLCryptoStore(helper.dbForManagedStores, dbutil.ZeroLogger(helper.log.With().Str("db_section", "crypto").Logger()), helper.DBAccountID, helper.client.DeviceID, helper.pickleKey)
if helper.client.Store == nil {
helper.client.Store = managedCryptoStore
} else if _, isMemory := helper.client.Store.(*mautrix.MemorySyncStore); isMemory {
helper.client.Store = managedCryptoStore
}
err := managedCryptoStore.DB.Upgrade()
if err != nil {
return fmt.Errorf("failed to upgrade crypto state store: %w", err)
}
storedDeviceID := managedCryptoStore.FindDeviceID()
if helper.LoginAs != nil {
if storedDeviceID != "" {
helper.LoginAs.DeviceID = storedDeviceID
}
helper.LoginAs.StoreCredentials = true
helper.log.Debug().
Str("username", helper.LoginAs.Identifier.User).
Str("device_id", helper.LoginAs.DeviceID.String()).
Msg("Logging in")
_, err = helper.client.Login(helper.LoginAs)
if err != nil {
return err
}
if storedDeviceID == "" {
managedCryptoStore.DeviceID = helper.client.DeviceID
}
} else if storedDeviceID != "" && storedDeviceID != helper.client.DeviceID {
return fmt.Errorf("mismatching device ID in client and crypto store (%q != %q)", storedDeviceID, helper.client.DeviceID)
}
cryptoStore = managedCryptoStore
} else {
if helper.LoginAs != nil {
return fmt.Errorf("LoginAs can only be used with a managed crypto store")
}
cryptoStore = helper.unmanagedCryptoStore
}
if helper.client.DeviceID == "" || helper.client.UserID == "" {
return fmt.Errorf("the client must be logged in")
}
helper.mach = crypto.NewOlmMachine(helper.client, &helper.log, cryptoStore, stateStore)
err := helper.mach.Load()
if err != nil {
return fmt.Errorf("failed to load olm account: %w", err)
} else if err = helper.verifyDeviceKeysOnServer(); err != nil {
return err
}
syncer.OnSync(helper.mach.ProcessSyncResponse)
syncer.OnEventType(event.StateMember, helper.mach.HandleMemberEvent)
if _, ok = helper.client.Syncer.(mautrix.DispatchableSyncer); ok {
syncer.OnEventType(event.EventEncrypted, helper.HandleEncrypted)
} else {
helper.log.Warn().Msg("Client syncer does not implement DispatchableSyncer. Events will not be decrypted automatically.")
}
if helper.managedStateStore != nil {
syncer.OnEvent(helper.client.StateStoreSyncHandler)
}
return nil
}
func (helper *CryptoHelper) Close() error {
if helper != nil && helper.dbForManagedStores != nil {
err := helper.dbForManagedStores.RawDB.Close()
if err != nil {
return err
}
}
return nil
}
func (helper *CryptoHelper) Machine() *crypto.OlmMachine {
if helper == nil || helper.mach == nil {
panic("Machine() called before initing CryptoHelper")
}
return helper.mach
}
func (helper *CryptoHelper) verifyDeviceKeysOnServer() error {
helper.log.Debug().Msg("Making sure our device has the expected keys on the server")
resp, err := helper.client.QueryKeys(&mautrix.ReqQueryKeys{
DeviceKeys: map[id.UserID]mautrix.DeviceIDList{
helper.client.UserID: {helper.client.DeviceID},
},
})
if err != nil {
return fmt.Errorf("failed to query own keys to make sure device is properly configured: %w", err)
}
ownID := helper.mach.OwnIdentity()
isShared := helper.mach.GetAccount().Shared
device, ok := resp.DeviceKeys[helper.client.UserID][helper.client.DeviceID]
if !ok || len(device.Keys) == 0 {
if isShared {
return fmt.Errorf("olm account is marked as shared, keys seem to have disappeared from the server")
} else {
helper.log.Debug().Msg("Olm account not shared and keys not on server, so device is probably fine")
return nil
}
} else if !isShared {
return fmt.Errorf("olm account is not marked as shared, but there are keys on the server")
} else if ed := device.Keys.GetEd25519(helper.client.DeviceID); ownID.SigningKey != ed {
return fmt.Errorf("mismatching identity key on server (%q != %q)", ownID.SigningKey, ed)
}
if !isShared {
helper.log.Debug().Msg("Olm account not marked as shared, but keys on server match?")
} else {
helper.log.Debug().Msg("Olm account marked as shared and keys on server match, device is fine")
}
return nil
}
var NoSessionFound = crypto.NoSessionFound
const initialSessionWaitTimeout = 3 * time.Second
const extendedSessionWaitTimeout = 22 * time.Second
func (helper *CryptoHelper) HandleEncrypted(src mautrix.EventSource, evt *event.Event) {
if helper == nil {
return
}
content := evt.Content.AsEncrypted()
log := helper.log.With().
Str("event_id", evt.ID.String()).
Str("session_id", content.SessionID.String()).
Logger()
log.Debug().Msg("Decrypting received event")
decrypted, err := helper.Decrypt(evt)
if errors.Is(err, NoSessionFound) {
log.Debug().
Int("wait_seconds", int(initialSessionWaitTimeout.Seconds())).
Msg("Couldn't find session, waiting for keys to arrive...")
if helper.mach.WaitForSession(evt.RoomID, content.SenderKey, content.SessionID, initialSessionWaitTimeout) {
log.Debug().Msg("Got keys after waiting, trying to decrypt event again")
decrypted, err = helper.Decrypt(evt)
} else {
go helper.waitLongerForSession(log, src, evt)
return
}
}
if err != nil {
log.Warn().Err(err).Msg("Failed to decrypt event")
helper.DecryptErrorCallback(evt, err)
return
}
helper.postDecrypt(src, decrypted)
}
func (helper *CryptoHelper) postDecrypt(src mautrix.EventSource, decrypted *event.Event) {
helper.client.Syncer.(mautrix.DispatchableSyncer).Dispatch(src|mautrix.EventSourceDecrypted, decrypted)
}
func (helper *CryptoHelper) RequestSession(roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID, userID id.UserID, deviceID id.DeviceID) {
if helper == nil {
return
}
helper.lock.RLock()
defer helper.lock.RUnlock()
if deviceID == "" {
deviceID = "*"
}
// TODO get log from context
log := helper.log.With().
Str("session_id", sessionID.String()).
Str("user_id", userID.String()).
Str("device_id", deviceID.String()).
Str("room_id", roomID.String()).
Logger()
err := helper.mach.SendRoomKeyRequest(roomID, senderKey, sessionID, "", map[id.UserID][]id.DeviceID{
userID: {deviceID},
helper.client.UserID: {"*"},
})
if err != nil {
log.Warn().Err(err).Msg("Failed to send key request")
} else {
log.Debug().Msg("Sent key request")
}
}
func (helper *CryptoHelper) waitLongerForSession(log zerolog.Logger, src mautrix.EventSource, evt *event.Event) {
content := evt.Content.AsEncrypted()
log.Debug().Int("wait_seconds", int(extendedSessionWaitTimeout.Seconds())).Msg("Couldn't find session, requesting keys and waiting longer...")
go helper.RequestSession(evt.RoomID, content.SenderKey, content.SessionID, evt.Sender, content.DeviceID)
if !helper.mach.WaitForSession(evt.RoomID, content.SenderKey, content.SessionID, extendedSessionWaitTimeout) {
log.Debug().Msg("Didn't get session, giving up")
helper.DecryptErrorCallback(evt, NoSessionFound)
return
}
log.Debug().Msg("Got keys after waiting longer, trying to decrypt event again")
decrypted, err := helper.Decrypt(evt)
if err != nil {
log.Error().Err(err).Msg("Failed to decrypt event")
helper.DecryptErrorCallback(evt, err)
return
}
helper.postDecrypt(src, decrypted)
}
func (helper *CryptoHelper) WaitForSession(roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID, timeout time.Duration) bool {
if helper == nil {
return false
}
helper.lock.RLock()
defer helper.lock.RUnlock()
return helper.mach.WaitForSession(roomID, senderKey, sessionID, timeout)
}
func (helper *CryptoHelper) Decrypt(evt *event.Event) (*event.Event, error) {
if helper == nil {
return nil, fmt.Errorf("crypto helper is nil")
}
return helper.mach.DecryptMegolmEvent(context.TODO(), evt)
}
func (helper *CryptoHelper) Encrypt(roomID id.RoomID, evtType event.Type, content any) (encrypted *event.EncryptedEventContent, err error) {
if helper == nil {
return nil, fmt.Errorf("crypto helper is nil")
}
helper.lock.RLock()
defer helper.lock.RUnlock()
ctx := context.TODO()
encrypted, err = helper.mach.EncryptMegolmEvent(ctx, roomID, evtType, content)
if err != nil {
if err != crypto.SessionExpired && err != crypto.SessionNotShared && err != crypto.NoGroupSession {
return
}
helper.log.Debug().
Err(err).
Str("room_id", roomID.String()).
Msg("Got session error while encrypting event, sharing group session and trying again")
var users []id.UserID
users, err = helper.client.StateStore.GetRoomJoinedOrInvitedMembers(roomID)
if err != nil {
err = fmt.Errorf("failed to get room member list: %w", err)
} else if err = helper.mach.ShareGroupSession(ctx, roomID, users); err != nil {
err = fmt.Errorf("failed to share group session: %w", err)
} else if encrypted, err = helper.mach.EncryptMegolmEvent(ctx, roomID, evtType, content); err != nil {
err = fmt.Errorf("failed to encrypt event after re-sharing group session: %w", err)
}
}
return
}

View File

@@ -1,4 +1,4 @@
// Copyright (c) 2020 Tulir Asokan
// Copyright (c) 2023 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
@@ -7,11 +7,14 @@
package crypto
import (
"context"
"encoding/json"
"errors"
"fmt"
"strings"
"github.com/rs/zerolog"
"maunium.net/go/mautrix/event"
"maunium.net/go/mautrix/id"
)
@@ -23,6 +26,7 @@ var (
WrongRoom = errors.New("encrypted megolm event is not intended for this room")
DeviceKeyMismatch = errors.New("device keys in event and verified device info do not match")
SenderKeyMismatch = errors.New("sender keys in content and megolm session do not match")
RatchetError = errors.New("failed to ratchet session after use")
)
type megolmEvent struct {
@@ -32,13 +36,21 @@ type megolmEvent struct {
}
// DecryptMegolmEvent decrypts an m.room.encrypted event where the algorithm is m.megolm.v1.aes-sha2
func (mach *OlmMachine) DecryptMegolmEvent(evt *event.Event) (*event.Event, error) {
func (mach *OlmMachine) DecryptMegolmEvent(ctx context.Context, evt *event.Event) (*event.Event, error) {
content, ok := evt.Content.Parsed.(*event.EncryptedEventContent)
if !ok {
return nil, IncorrectEncryptedContentType
} else if content.Algorithm != id.AlgorithmMegolmV1 {
return nil, UnsupportedAlgorithm
}
log := mach.machOrContextLog(ctx).With().
Str("action", "decrypt megolm event").
Str("event_id", evt.ID.String()).
Str("sender", evt.Sender.String()).
Str("sender_key", content.SenderKey.String()).
Str("session_id", content.SessionID.String()).
Logger()
ctx = log.WithContext(ctx)
encryptionRoomID := evt.RoomID
// Allow the server to move encrypted events between rooms if both the real room and target room are on a non-federatable .local domain.
// The message index checks to prevent replay attacks still apply and aren't based on the room ID,
@@ -46,22 +58,11 @@ func (mach *OlmMachine) DecryptMegolmEvent(evt *event.Event) (*event.Event, erro
if origRoomID, ok := evt.Content.Raw["com.beeper.original_room_id"].(string); ok && strings.HasSuffix(origRoomID, ".local") && strings.HasSuffix(evt.RoomID.String(), ".local") {
encryptionRoomID = id.RoomID(origRoomID)
}
sess, err := mach.CryptoStore.GetGroupSession(encryptionRoomID, content.SenderKey, content.SessionID)
sess, plaintext, messageIndex, err := mach.actuallyDecryptMegolmEvent(ctx, evt, encryptionRoomID, content)
if err != nil {
return nil, fmt.Errorf("failed to get group session: %w", err)
} else if sess == nil {
return nil, fmt.Errorf("%w (ID %s)", NoSessionFound, content.SessionID)
} else if content.SenderKey != "" && content.SenderKey != sess.SenderKey {
return nil, SenderKeyMismatch
}
plaintext, messageIndex, err := sess.Internal.Decrypt(content.MegolmCiphertext)
if err != nil {
return nil, fmt.Errorf("failed to decrypt megolm event: %w", err)
} else if ok, err = mach.CryptoStore.ValidateMessageIndex(sess.SenderKey, content.SessionID, evt.ID, messageIndex, evt.Timestamp); err != nil {
return nil, fmt.Errorf("failed to check if message index is duplicate: %w", err)
} else if !ok {
return nil, DuplicateMessageIndex
return nil, err
}
log = log.With().Uint("message_index", messageIndex).Logger()
var trustLevel id.TrustState
var forwardedKeys bool
@@ -70,14 +71,16 @@ func (mach *OlmMachine) DecryptMegolmEvent(evt *event.Event) (*event.Event, erro
if sess.SigningKey == ownSigningKey && sess.SenderKey == ownIdentityKey && len(sess.ForwardingChains) == 0 {
trustLevel = id.TrustStateVerified
} else {
device, err = mach.GetOrFetchDeviceByKey(evt.Sender, sess.SenderKey)
device, err = mach.GetOrFetchDeviceByKey(ctx, evt.Sender, sess.SenderKey)
if err != nil {
// We don't want to throw these errors as the message can still be decrypted.
mach.Log.Debug("Failed to get device %s/%s to verify session %s: %v", evt.Sender, sess.SenderKey, sess.ID(), err)
log.Debug().Err(err).Msg("Failed to get device to verify session")
trustLevel = id.TrustStateUnknownDevice
} else if len(sess.ForwardingChains) == 0 || (len(sess.ForwardingChains) == 1 && sess.ForwardingChains[0] == sess.SenderKey.String()) {
if device == nil {
mach.Log.Debug("Couldn't resolve trust level of session %s: sent by unknown device %s/%s", sess.ID(), evt.Sender, sess.SenderKey)
log.Debug().Err(err).
Str("session_sender_key", sess.SenderKey.String()).
Msg("Couldn't resolve trust level of session: sent by unknown device")
trustLevel = id.TrustStateUnknownDevice
} else if device.SigningKey != sess.SigningKey || device.IdentityKey != sess.SenderKey {
return nil, DeviceKeyMismatch
@@ -91,7 +94,9 @@ func (mach *OlmMachine) DecryptMegolmEvent(evt *event.Event) (*event.Event, erro
if device != nil {
trustLevel = mach.ResolveTrust(device)
} else {
mach.Log.Debug("Couldn't resolve trust level of session %s: forwarding chain ends with unknown device %s", sess.ID(), lastChainItem)
log.Debug().
Str("forward_last_sender_key", lastChainItem).
Msg("Couldn't resolve trust level of session: forwarding chain ends with unknown device")
trustLevel = id.TrustStateForwarded
}
}
@@ -105,10 +110,11 @@ func (mach *OlmMachine) DecryptMegolmEvent(evt *event.Event) (*event.Event, erro
return nil, WrongRoom
}
megolmEvt.Type.Class = evt.Type.Class
log = log.With().Str("decrypted_event_type", megolmEvt.Type.Repr()).Logger()
err = megolmEvt.Content.ParseRaw(megolmEvt.Type)
if err != nil {
if errors.Is(err, event.ErrUnsupportedContentType) {
mach.Log.Warn("Unsupported event type %s in encrypted event %s", megolmEvt.Type.Repr(), evt.ID)
log.Warn().Msg("Unsupported event type in encrypted event")
} else {
return nil, fmt.Errorf("failed to parse content of megolm payload event: %w", err)
}
@@ -119,13 +125,14 @@ func (mach *OlmMachine) DecryptMegolmEvent(evt *event.Event) (*event.Event, erro
if relatable.OptionalGetRelatesTo() == nil {
relatable.SetRelatesTo(content.RelatesTo)
} else {
mach.Log.Trace("Not overriding relation data in %s, as encrypted payload already has it", evt.ID)
log.Trace().Msg("Not overriding relation data as encrypted payload already has it")
}
}
if _, hasRelation := megolmEvt.Content.Raw["m.relates_to"]; !hasRelation {
megolmEvt.Content.Raw["m.relates_to"] = evt.Content.Raw["m.relates_to"]
}
}
log.Debug().Msg("Event decrypted successfully")
megolmEvt.Type.Class = evt.Type.Class
return &event.Event{
Sender: evt.Sender,
@@ -144,3 +151,106 @@ func (mach *OlmMachine) DecryptMegolmEvent(evt *event.Event) (*event.Event, erro
},
}, nil
}
func removeItem(slice []uint, item uint) ([]uint, bool) {
for i, s := range slice {
if s == item {
return append(slice[:i], slice[i+1:]...), true
}
}
return slice, false
}
const missedIndexCutoff = 10
func (mach *OlmMachine) actuallyDecryptMegolmEvent(ctx context.Context, evt *event.Event, encryptionRoomID id.RoomID, content *event.EncryptedEventContent) (*InboundGroupSession, []byte, uint, error) {
mach.megolmDecryptLock.Lock()
defer mach.megolmDecryptLock.Unlock()
sess, err := mach.CryptoStore.GetGroupSession(encryptionRoomID, content.SenderKey, content.SessionID)
if err != nil {
return nil, nil, 0, fmt.Errorf("failed to get group session: %w", err)
} else if sess == nil {
return nil, nil, 0, fmt.Errorf("%w (ID %s)", NoSessionFound, content.SessionID)
} else if content.SenderKey != "" && content.SenderKey != sess.SenderKey {
return sess, nil, 0, SenderKeyMismatch
}
plaintext, messageIndex, err := sess.Internal.Decrypt(content.MegolmCiphertext)
if err != nil {
return sess, nil, 0, fmt.Errorf("failed to decrypt megolm event: %w", err)
} else if ok, err := mach.CryptoStore.ValidateMessageIndex(ctx, sess.SenderKey, content.SessionID, evt.ID, messageIndex, evt.Timestamp); err != nil {
return sess, nil, messageIndex, fmt.Errorf("failed to check if message index is duplicate: %w", err)
} else if !ok {
return sess, nil, messageIndex, DuplicateMessageIndex
}
expectedMessageIndex := sess.RatchetSafety.NextIndex
didModify := false
switch {
case messageIndex > expectedMessageIndex:
// When the index jumps, add indices in between to the missed indices list.
for i := expectedMessageIndex; i < messageIndex; i++ {
sess.RatchetSafety.MissedIndices = append(sess.RatchetSafety.MissedIndices, i)
}
fallthrough
case messageIndex == expectedMessageIndex:
// When the index moves forward (to the next one or jumping ahead), update the last received index.
sess.RatchetSafety.NextIndex = messageIndex + 1
didModify = true
default:
sess.RatchetSafety.MissedIndices, didModify = removeItem(sess.RatchetSafety.MissedIndices, messageIndex)
}
// Use presence of ReceivedAt as a sign that this is a recent megolm session,
// and therefore it's safe to drop missed indices entirely.
if !sess.ReceivedAt.IsZero() && len(sess.RatchetSafety.MissedIndices) > 0 && int(sess.RatchetSafety.MissedIndices[0]) < int(sess.RatchetSafety.NextIndex)-missedIndexCutoff {
limit := sess.RatchetSafety.NextIndex - missedIndexCutoff
var cutoff int
for ; cutoff < len(sess.RatchetSafety.MissedIndices) && sess.RatchetSafety.MissedIndices[cutoff] < limit; cutoff++ {
}
sess.RatchetSafety.LostIndices = append(sess.RatchetSafety.LostIndices, sess.RatchetSafety.MissedIndices[:cutoff]...)
sess.RatchetSafety.MissedIndices = sess.RatchetSafety.MissedIndices[cutoff:]
didModify = true
}
ratchetTargetIndex := uint32(sess.RatchetSafety.NextIndex)
if len(sess.RatchetSafety.MissedIndices) > 0 {
ratchetTargetIndex = uint32(sess.RatchetSafety.MissedIndices[0])
}
ratchetCurrentIndex := sess.Internal.FirstKnownIndex()
log := zerolog.Ctx(ctx).With().
Uint32("prev_ratchet_index", ratchetCurrentIndex).
Uint32("new_ratchet_index", ratchetTargetIndex).
Uint("next_new_index", sess.RatchetSafety.NextIndex).
Uints("missed_indices", sess.RatchetSafety.MissedIndices).
Uints("lost_indices", sess.RatchetSafety.LostIndices).
Int("max_messages", sess.MaxMessages).
Logger()
if sess.MaxMessages > 0 && int(ratchetTargetIndex) >= sess.MaxMessages && len(sess.RatchetSafety.MissedIndices) == 0 && mach.DeleteFullyUsedKeysOnDecrypt {
err = mach.CryptoStore.RedactGroupSession(sess.RoomID, sess.SenderKey, sess.ID(), "maximum messages reached")
if err != nil {
log.Err(err).Msg("Failed to delete fully used session")
return sess, plaintext, messageIndex, RatchetError
} else {
log.Info().Msg("Deleted fully used session")
}
} else if ratchetCurrentIndex < ratchetTargetIndex && mach.RatchetKeysOnDecrypt {
if err = sess.RatchetTo(ratchetTargetIndex); err != nil {
log.Err(err).Msg("Failed to ratchet session")
return sess, plaintext, messageIndex, RatchetError
} else if err = mach.CryptoStore.PutGroupSession(sess.RoomID, sess.SenderKey, sess.ID(), sess); err != nil {
log.Err(err).Msg("Failed to store ratcheted session")
return sess, plaintext, messageIndex, RatchetError
} else {
log.Info().Msg("Ratcheted session forward")
}
} else if didModify {
if err = mach.CryptoStore.PutGroupSession(sess.RoomID, sess.SenderKey, sess.ID(), sess); err != nil {
log.Err(err).Msg("Failed to store updated ratchet safety data")
return sess, plaintext, messageIndex, RatchetError
} else {
log.Debug().Msg("Ratchet safety data changed (ratchet state didn't change)")
}
} else {
log.Debug().Msg("Ratchet safety data didn't change")
}
return sess, plaintext, messageIndex, nil
}

View File

@@ -1,4 +1,4 @@
// Copyright (c) 2021 Tulir Asokan
// Copyright (c) 2023 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
@@ -7,11 +7,14 @@
package crypto
import (
"context"
"encoding/json"
"errors"
"fmt"
"time"
"github.com/rs/zerolog"
"maunium.net/go/mautrix/event"
"maunium.net/go/mautrix/id"
)
@@ -43,7 +46,7 @@ type DecryptedOlmEvent struct {
Content event.Content `json:"content"`
}
func (mach *OlmMachine) decryptOlmEvent(evt *event.Event, traceID string) (*DecryptedOlmEvent, error) {
func (mach *OlmMachine) decryptOlmEvent(ctx context.Context, evt *event.Event) (*DecryptedOlmEvent, error) {
content, ok := evt.Content.Parsed.(*event.EncryptedEventContent)
if !ok {
return nil, IncorrectEncryptedContentType
@@ -54,7 +57,7 @@ func (mach *OlmMachine) decryptOlmEvent(evt *event.Event, traceID string) (*Decr
if !ok {
return nil, NotEncryptedForMe
}
decrypted, err := mach.decryptAndParseOlmCiphertext(evt.Sender, content.SenderKey, ownContent.Type, ownContent.Body, traceID)
decrypted, err := mach.decryptAndParseOlmCiphertext(ctx, evt.Sender, content.SenderKey, ownContent.Type, ownContent.Body)
if err != nil {
return nil, err
}
@@ -66,19 +69,19 @@ type OlmEventKeys struct {
Ed25519 id.Ed25519 `json:"ed25519"`
}
func (mach *OlmMachine) decryptAndParseOlmCiphertext(sender id.UserID, senderKey id.SenderKey, olmType id.OlmMsgType, ciphertext string, traceID string) (*DecryptedOlmEvent, error) {
func (mach *OlmMachine) decryptAndParseOlmCiphertext(ctx context.Context, sender id.UserID, senderKey id.SenderKey, olmType id.OlmMsgType, ciphertext string) (*DecryptedOlmEvent, error) {
if olmType != id.OlmMsgTypePreKey && olmType != id.OlmMsgTypeMsg {
return nil, UnsupportedOlmMessageType
}
endTimeTrace := mach.timeTrace("decrypting olm ciphertext", traceID, 5*time.Second)
plaintext, err := mach.tryDecryptOlmCiphertext(sender, senderKey, olmType, ciphertext, traceID)
endTimeTrace := mach.timeTrace(ctx, "decrypting olm ciphertext", 5*time.Second)
plaintext, err := mach.tryDecryptOlmCiphertext(ctx, sender, senderKey, olmType, ciphertext)
endTimeTrace()
if err != nil {
return nil, err
}
defer mach.timeTrace("parsing decrypted olm event", traceID, time.Second)()
defer mach.timeTrace(ctx, "parsing decrypted olm event", time.Second)()
var olmEvt DecryptedOlmEvent
err = json.Unmarshal(plaintext, &olmEvt)
@@ -103,17 +106,18 @@ func (mach *OlmMachine) decryptAndParseOlmCiphertext(sender id.UserID, senderKey
return &olmEvt, nil
}
func (mach *OlmMachine) tryDecryptOlmCiphertext(sender id.UserID, senderKey id.SenderKey, olmType id.OlmMsgType, ciphertext string, traceID string) ([]byte, error) {
endTimeTrace := mach.timeTrace("waiting for olm lock", traceID, 5*time.Second)
func (mach *OlmMachine) tryDecryptOlmCiphertext(ctx context.Context, sender id.UserID, senderKey id.SenderKey, olmType id.OlmMsgType, ciphertext string) ([]byte, error) {
log := *zerolog.Ctx(ctx)
endTimeTrace := mach.timeTrace(ctx, "waiting for olm lock", 5*time.Second)
mach.olmLock.Lock()
endTimeTrace()
defer mach.olmLock.Unlock()
plaintext, err := mach.tryDecryptOlmCiphertextWithExistingSession(senderKey, olmType, ciphertext, traceID)
plaintext, err := mach.tryDecryptOlmCiphertextWithExistingSession(ctx, senderKey, olmType, ciphertext)
if err != nil {
if err == DecryptionFailedWithMatchingSession {
mach.Log.Warn("Found matching session yet decryption failed for sender %s with key %s", sender, senderKey)
go mach.unwedgeDevice(sender, senderKey)
log.Warn().Msg("Found matching session, but decryption failed")
go mach.unwedgeDevice(log, sender, senderKey)
}
return nil, fmt.Errorf("failed to decrypt olm event: %w", err)
}
@@ -128,39 +132,44 @@ func (mach *OlmMachine) tryDecryptOlmCiphertext(sender id.UserID, senderKey id.S
// New sessions can only be created if it's a prekey message, we can't decrypt the message
// if it isn't one at this point in time anymore, so return early.
if olmType != id.OlmMsgTypePreKey {
go mach.unwedgeDevice(sender, senderKey)
go mach.unwedgeDevice(log, sender, senderKey)
return nil, DecryptionFailedForNormalMessage
}
mach.Log.Trace("Trying to create inbound session for %s/%s", sender, senderKey)
endTimeTrace = mach.timeTrace("creating inbound olm session", traceID, time.Second)
session, err := mach.createInboundSession(senderKey, ciphertext)
log.Trace().Msg("Trying to create inbound session")
endTimeTrace = mach.timeTrace(ctx, "creating inbound olm session", time.Second)
session, err := mach.createInboundSession(ctx, senderKey, ciphertext)
endTimeTrace()
if err != nil {
go mach.unwedgeDevice(sender, senderKey)
go mach.unwedgeDevice(log, sender, senderKey)
return nil, fmt.Errorf("failed to create new session from prekey message: %w", err)
}
mach.Log.Debug("Created inbound olm session %s for %s/%s: %s", session.ID(), sender, senderKey, session.Describe())
log = log.With().Str("new_olm_session_id", session.ID().String()).Logger()
log.Debug().
Str("olm_session_description", session.Describe()).
Msg("Created inbound olm session")
ctx = log.WithContext(ctx)
endTimeTrace = mach.timeTrace(fmt.Sprintf("decrypting prekey olm message with %s/%s", senderKey, session.ID()), traceID, time.Second)
endTimeTrace = mach.timeTrace(ctx, "decrypting prekey olm message", time.Second)
plaintext, err = session.Decrypt(ciphertext, olmType)
endTimeTrace()
if err != nil {
go mach.unwedgeDevice(sender, senderKey)
go mach.unwedgeDevice(log, sender, senderKey)
return nil, fmt.Errorf("failed to decrypt olm event with session created from prekey message: %w", err)
}
endTimeTrace = mach.timeTrace(fmt.Sprintf("updating new session %s/%s in database", senderKey, session.ID()), traceID, time.Second)
endTimeTrace = mach.timeTrace(ctx, "updating new session in database", time.Second)
err = mach.CryptoStore.UpdateSession(senderKey, session)
endTimeTrace()
if err != nil {
mach.Log.Warn("Failed to update new olm session in crypto store after decrypting: %v", err)
log.Warn().Err(err).Msg("Failed to update new olm session in crypto store after decrypting")
}
return plaintext, nil
}
func (mach *OlmMachine) tryDecryptOlmCiphertextWithExistingSession(senderKey id.SenderKey, olmType id.OlmMsgType, ciphertext string, traceID string) ([]byte, error) {
endTimeTrace := mach.timeTrace(fmt.Sprintf("getting sessions with %s", senderKey), traceID, time.Second)
func (mach *OlmMachine) tryDecryptOlmCiphertextWithExistingSession(ctx context.Context, senderKey id.SenderKey, olmType id.OlmMsgType, ciphertext string) ([]byte, error) {
log := *zerolog.Ctx(ctx)
endTimeTrace := mach.timeTrace(ctx, "getting sessions with sender key", time.Second)
sessions, err := mach.CryptoStore.GetSessions(senderKey)
endTimeTrace()
if err != nil {
@@ -168,8 +177,10 @@ func (mach *OlmMachine) tryDecryptOlmCiphertextWithExistingSession(senderKey id.
}
for _, session := range sessions {
log := log.With().Str("olm_session_id", session.ID().String()).Logger()
ctx := log.WithContext(ctx)
if olmType == id.OlmMsgTypePreKey {
endTimeTrace = mach.timeTrace(fmt.Sprintf("checking if prekey olm message matches session %s/%s", senderKey, session.ID()), traceID, time.Second)
endTimeTrace = mach.timeTrace(ctx, "checking if prekey olm message matches session", time.Second)
matches, err := session.Internal.MatchesInboundSession(ciphertext)
endTimeTrace()
if err != nil {
@@ -178,8 +189,8 @@ func (mach *OlmMachine) tryDecryptOlmCiphertextWithExistingSession(senderKey id.
continue
}
}
mach.Log.Trace("Trying to decrypt olm message from %s with session %s: %s", senderKey, session.ID(), session.Describe())
endTimeTrace = mach.timeTrace(fmt.Sprintf("decrypting olm message with %s/%s", senderKey, session.ID()), traceID, time.Second)
log.Debug().Str("session_description", session.Describe()).Msg("Trying to decrypt olm message")
endTimeTrace = mach.timeTrace(ctx, "decrypting olm message", time.Second)
plaintext, err := session.Decrypt(ciphertext, olmType)
endTimeTrace()
if err != nil {
@@ -187,20 +198,20 @@ func (mach *OlmMachine) tryDecryptOlmCiphertextWithExistingSession(senderKey id.
return nil, DecryptionFailedWithMatchingSession
}
} else {
endTimeTrace = mach.timeTrace(fmt.Sprintf("updating session %s/%s in database", senderKey, session.ID()), traceID, time.Second)
endTimeTrace = mach.timeTrace(ctx, "updating session in database", time.Second)
err = mach.CryptoStore.UpdateSession(senderKey, session)
endTimeTrace()
if err != nil {
mach.Log.Warn("Failed to update olm session in crypto store after decrypting: %v", err)
log.Warn().Err(err).Msg("Failed to update olm session in crypto store after decrypting")
}
mach.Log.Trace("Decrypted olm message from %s with session %s", senderKey, session.ID())
log.Debug().Msg("Decrypted olm message")
return plaintext, nil
}
}
return nil, nil
}
func (mach *OlmMachine) createInboundSession(senderKey id.SenderKey, ciphertext string) (*OlmSession, error) {
func (mach *OlmMachine) createInboundSession(ctx context.Context, senderKey id.SenderKey, ciphertext string) (*OlmSession, error) {
session, err := mach.account.NewInboundSessionFrom(senderKey, ciphertext)
if err != nil {
return nil, err
@@ -208,40 +219,44 @@ func (mach *OlmMachine) createInboundSession(senderKey id.SenderKey, ciphertext
mach.saveAccount()
err = mach.CryptoStore.AddSession(senderKey, session)
if err != nil {
mach.Log.Error("Failed to store created inbound session: %v", err)
zerolog.Ctx(ctx).Error().Err(err).Msg("Failed to store created inbound session")
}
return session, nil
}
const MinUnwedgeInterval = 1 * time.Hour
func (mach *OlmMachine) unwedgeDevice(sender id.UserID, senderKey id.SenderKey) {
func (mach *OlmMachine) unwedgeDevice(log zerolog.Logger, sender id.UserID, senderKey id.SenderKey) {
log = log.With().Str("action", "unwedge olm session").Logger()
ctx := log.WithContext(context.Background())
mach.recentlyUnwedgedLock.Lock()
prevUnwedge, ok := mach.recentlyUnwedged[senderKey]
delta := time.Now().Sub(prevUnwedge)
if ok && delta < MinUnwedgeInterval {
mach.Log.Debug("Not creating new Olm session with %s/%s, previous recreation was %s ago", sender, senderKey, delta)
log.Debug().
Str("previous_recreation", delta.String()).
Msg("Not creating new Olm session as it was already recreated recently")
mach.recentlyUnwedgedLock.Unlock()
return
}
mach.recentlyUnwedged[senderKey] = time.Now()
mach.recentlyUnwedgedLock.Unlock()
deviceIdentity, err := mach.GetOrFetchDeviceByKey(sender, senderKey)
deviceIdentity, err := mach.GetOrFetchDeviceByKey(ctx, sender, senderKey)
if err != nil {
mach.Log.Error("Failed to find device info by identity key: %v", err)
log.Error().Err(err).Msg("Failed to find device info by identity key")
return
} else if deviceIdentity == nil {
mach.Log.Warn("Didn't find identity of %s/%s, can't unwedge session", sender, senderKey)
log.Warn().Msg("Didn't find identity for device")
return
}
mach.Log.Debug("Creating new Olm session with %s/%s (key: %s)", sender, deviceIdentity.DeviceID, senderKey)
log.Debug().Str("device_id", deviceIdentity.DeviceID.String()).Msg("Creating new Olm session")
mach.devicesToUnwedgeLock.Lock()
mach.devicesToUnwedge[senderKey] = true
mach.devicesToUnwedgeLock.Unlock()
err = mach.SendEncryptedToDevice(deviceIdentity, event.ToDeviceDummy, event.Content{})
err = mach.SendEncryptedToDevice(ctx, deviceIdentity, event.ToDeviceDummy, event.Content{})
if err != nil {
mach.Log.Error("Failed to send dummy event to unwedge session with %s/%s: %v", sender, senderKey, err)
log.Error().Err(err).Msg("Failed to send dummy event to unwedge session")
}
}

View File

@@ -1,4 +1,4 @@
// Copyright (c) 2020 Tulir Asokan
// Copyright (c) 2023 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
@@ -7,9 +7,12 @@
package crypto
import (
"context"
"errors"
"fmt"
"github.com/rs/zerolog"
"maunium.net/go/mautrix"
"maunium.net/go/mautrix/crypto/olm"
"maunium.net/go/mautrix/id"
@@ -25,10 +28,12 @@ var (
)
func (mach *OlmMachine) LoadDevices(user id.UserID) map[id.DeviceID]*id.Device {
return mach.fetchKeys([]id.UserID{user}, "", true)[user]
// TODO proper context?
return mach.fetchKeys(context.TODO(), []id.UserID{user}, "", true)[user]
}
func (mach *OlmMachine) storeDeviceSelfSignatures(userID id.UserID, deviceID id.DeviceID, resp *mautrix.RespQueryKeys) {
func (mach *OlmMachine) storeDeviceSelfSignatures(ctx context.Context, userID id.UserID, deviceID id.DeviceID, resp *mautrix.RespQueryKeys) {
log := zerolog.Ctx(ctx)
deviceKeys := resp.DeviceKeys[userID][deviceID]
for signerUserID, signerKeys := range deviceKeys.Signatures {
for signerKey, signature := range signerKeys {
@@ -43,17 +48,27 @@ func (mach *OlmMachine) storeDeviceSelfSignatures(userID id.UserID, deviceID id.
if verified, err := olm.VerifySignatureJSON(deviceKeys, signerUserID, pubKey.String(), pubKey); verified {
if signKey, ok := deviceKeys.Keys[id.DeviceKeyID(signerKey)]; ok {
signature := deviceKeys.Signatures[signerUserID][id.NewKeyID(id.KeyAlgorithmEd25519, pubKey.String())]
mach.Log.Trace("Verified self-signing signature for device %s/%s: %s", signerUserID, deviceID, signature)
log.Trace().Err(err).
Str("signer_user_id", signerUserID.String()).
Str("signed_device_id", deviceID.String()).
Str("signature", signature).
Msg("Verified self-signing signature")
err = mach.CryptoStore.PutSignature(userID, id.Ed25519(signKey), signerUserID, pubKey, signature)
if err != nil {
mach.Log.Warn("Failed to store self-signing signature for device %s/%s: %v", signerUserID, deviceID, err)
log.Warn().Err(err).
Str("signer_user_id", signerUserID.String()).
Str("signed_device_id", deviceID.String()).
Msg("Failed to store self-signing signature")
}
}
} else {
if err == nil {
err = errors.New("invalid signature")
}
mach.Log.Warn("Could not verify device self-signing signature for %s/%s: %v", signerUserID, deviceID, err)
log.Warn().Err(err).
Str("signer_user_id", signerUserID.String()).
Str("signed_device_id", deviceID.String()).
Msg("Failed to verify self-signing signature")
}
}
}
@@ -61,25 +76,29 @@ func (mach *OlmMachine) storeDeviceSelfSignatures(userID id.UserID, deviceID id.
if signKey, ok := deviceKeys.Keys[id.DeviceKeyID(signerKey)]; ok {
err := mach.CryptoStore.PutSignature(userID, id.Ed25519(signKey), signerUserID, id.Ed25519(signKey), signature)
if err != nil {
mach.Log.Warn("Failed to store self-signing signature for %s/%s: %v", signerUserID, signKey, err)
log.Warn().Err(err).
Str("signer_user_id", signerUserID.String()).
Str("signer_key", signKey).
Msg("Failed to store self-signing signature")
}
}
}
}
}
func (mach *OlmMachine) fetchKeys(users []id.UserID, sinceToken string, includeUntracked bool) (data map[id.UserID]map[id.DeviceID]*id.Device) {
func (mach *OlmMachine) fetchKeys(ctx context.Context, users []id.UserID, sinceToken string, includeUntracked bool) (data map[id.UserID]map[id.DeviceID]*id.Device) {
// TODO this function should probably return errors
req := &mautrix.ReqQueryKeys{
DeviceKeys: mautrix.DeviceKeysRequest{},
Timeout: 10 * 1000,
Token: sinceToken,
}
log := mach.machOrContextLog(ctx)
if !includeUntracked {
var err error
users, err = mach.CryptoStore.FilterTrackedUsers(users)
if err != nil {
mach.Log.Warn("Failed to filter tracked user list: %v", err)
log.Warn().Err(err).Msg("Failed to filter tracked user list")
}
}
if len(users) == 0 {
@@ -88,62 +107,88 @@ func (mach *OlmMachine) fetchKeys(users []id.UserID, sinceToken string, includeU
for _, userID := range users {
req.DeviceKeys[userID] = mautrix.DeviceIDList{}
}
mach.Log.Trace("Querying keys for %v", users)
log.Debug().Strs("users", strishArray(users)).Msg("Querying keys for users")
resp, err := mach.Client.QueryKeys(req)
if err != nil {
mach.Log.Warn("Failed to query keys: %v", err)
log.Error().Err(err).Msg("Failed to query keys")
return
}
for server, err := range resp.Failures {
mach.Log.Warn("Query keys failure for %s: %v", server, err)
log.Warn().Interface("query_error", err).Str("server", server).Msg("Query keys failure for server")
}
mach.Log.Trace("Query key result received with %d users", len(resp.DeviceKeys))
log.Trace().Int("user_count", len(resp.DeviceKeys)).Msg("Query key result received")
data = make(map[id.UserID]map[id.DeviceID]*id.Device)
for userID, devices := range resp.DeviceKeys {
log := log.With().Str("user_id", userID.String()).Logger()
delete(req.DeviceKeys, userID)
newDevices := make(map[id.DeviceID]*id.Device)
existingDevices, err := mach.CryptoStore.GetDevices(userID)
if err != nil {
mach.Log.Warn("Failed to get existing devices for %s: %v", userID, err)
log.Warn().Err(err).Msg("Failed to get existing devices for user")
existingDevices = make(map[id.DeviceID]*id.Device)
}
mach.Log.Trace("Updating devices for %s, got %d devices, have %d in store", userID, len(devices), len(existingDevices))
log.Debug().
Int("new_device_count", len(devices)).
Int("old_device_count", len(existingDevices)).
Msg("Updating devices in store")
changed := false
for deviceID, deviceKeys := range devices {
log := log.With().Str("device_id", deviceID.String()).Logger()
existing, ok := existingDevices[deviceID]
if !ok {
// New device
changed = true
}
mach.Log.Trace("Validating device %s of %s", deviceID, userID)
log.Trace().Msg("Validating device")
newDevice, err := mach.validateDevice(userID, deviceID, deviceKeys, existing)
if err != nil {
mach.Log.Error("Failed to validate device %s of %s: %v", deviceID, userID, err)
log.Error().Err(err).Msg("Failed to validate device")
} else if newDevice != nil {
newDevices[deviceID] = newDevice
mach.storeDeviceSelfSignatures(userID, deviceID, resp)
mach.storeDeviceSelfSignatures(ctx, userID, deviceID, resp)
}
}
mach.Log.Trace("Storing new device list for %s containing %d devices", userID, len(newDevices))
log.Trace().Int("new_device_count", len(newDevices)).Msg("Storing new device list")
err = mach.CryptoStore.PutDevices(userID, newDevices)
if err != nil {
mach.Log.Warn("Failed to update device list for %s: %v", userID, err)
log.Warn().Err(err).Msg("Failed to update device list")
}
data[userID] = newDevices
changed = changed || len(newDevices) != len(existingDevices)
if changed {
if mach.DeleteKeysOnDeviceDelete {
for deviceID := range newDevices {
delete(existingDevices, deviceID)
}
for _, device := range existingDevices {
log := log.With().
Str("device_id", device.DeviceID.String()).
Str("identity_key", device.IdentityKey.String()).
Str("signing_key", device.SigningKey.String()).
Logger()
sessionIDs, err := mach.CryptoStore.RedactGroupSessions("", device.IdentityKey, "device removed")
if err != nil {
log.Err(err).Msg("Failed to redact megolm sessions from deleted device")
} else {
log.Info().
Strs("session_ids", stringifyArray(sessionIDs)).
Msg("Redacted megolm sessions from deleted device")
}
}
}
mach.OnDevicesChanged(userID)
}
}
for userID := range req.DeviceKeys {
mach.Log.Warn("Didn't get any keys for user %s", userID)
log.Warn().Str("user_id", userID.String()).Msg("Didn't get any keys for user")
}
mach.storeCrossSigningKeys(resp.MasterKeys, resp.DeviceKeys)
mach.storeCrossSigningKeys(resp.SelfSigningKeys, resp.DeviceKeys)
mach.storeCrossSigningKeys(resp.UserSigningKeys, resp.DeviceKeys)
mach.storeCrossSigningKeys(ctx, resp.MasterKeys, resp.DeviceKeys)
mach.storeCrossSigningKeys(ctx, resp.SelfSigningKeys, resp.DeviceKeys)
mach.storeCrossSigningKeys(ctx, resp.UserSigningKeys, resp.DeviceKeys)
return data
}
@@ -154,10 +199,16 @@ func (mach *OlmMachine) fetchKeys(users []id.UserID, sinceToken string, includeU
// not need to be called manually.
func (mach *OlmMachine) OnDevicesChanged(userID id.UserID) {
for _, roomID := range mach.StateStore.FindSharedRooms(userID) {
mach.Log.Debug("Devices of %s changed, invalidating group session for %s", userID, roomID)
mach.Log.Debug().
Str("user_id", userID.String()).
Str("room_id", roomID.String()).
Msg("Invalidating group session in room due to device change notification")
err := mach.CryptoStore.RemoveOutboundGroupSession(roomID)
if err != nil {
mach.Log.Warn("Failed to invalidate outbound group session of %s on device change for %s: %v", roomID, userID, err)
mach.Log.Warn().Err(err).
Str("user_id", userID.String()).
Str("room_id", roomID.String()).
Msg("Failed to invalidate outbound group session")
}
}
}

View File

@@ -1,4 +1,4 @@
// Copyright (c) 2020 Tulir Asokan
// Copyright (c) 2023 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
@@ -7,10 +7,15 @@
package crypto
import (
"context"
"encoding/base64"
"encoding/binary"
"encoding/json"
"errors"
"fmt"
"github.com/rs/zerolog"
"maunium.net/go/mautrix"
"maunium.net/go/mautrix/event"
"maunium.net/go/mautrix/id"
@@ -33,6 +38,18 @@ func getRelatesTo(content interface{}) *event.RelatesTo {
return nil
}
func getMentions(content interface{}) *event.Mentions {
contentStruct, ok := content.(*event.Content)
if ok {
content = contentStruct.Parsed
}
message, ok := content.(*event.MessageEventContent)
if ok {
return message.Mentions
}
return nil
}
type rawMegolmEvent struct {
RoomID id.RoomID `json:"room_id"`
Type event.Type `json:"type"`
@@ -44,12 +61,29 @@ func IsShareError(err error) bool {
return err == SessionExpired || err == SessionNotShared || err == NoGroupSession
}
func parseMessageIndex(ciphertext []byte) (uint64, error) {
decoded := make([]byte, base64.RawStdEncoding.DecodedLen(len(ciphertext)))
var err error
_, err = base64.RawStdEncoding.Decode(decoded, ciphertext)
if err != nil {
return 0, err
} else if decoded[0] != 3 || decoded[1] != 8 {
return 0, fmt.Errorf("unexpected initial bytes %d and %d", decoded[0], decoded[1])
}
index, read := binary.Uvarint(decoded[2 : 2+binary.MaxVarintLen64])
if read <= 0 {
return 0, fmt.Errorf("failed to decode varint, read value %d", read)
}
return index, nil
}
// EncryptMegolmEvent encrypts data with the m.megolm.v1.aes-sha2 algorithm.
//
// If you use the event.Content struct, make sure you pass a pointer to the struct,
// as JSON serialization will not work correctly otherwise.
func (mach *OlmMachine) EncryptMegolmEvent(roomID id.RoomID, evtType event.Type, content interface{}) (*event.EncryptedEventContent, error) {
mach.Log.Trace("Encrypting event of type %s for %s", evtType.Type, roomID)
func (mach *OlmMachine) EncryptMegolmEvent(ctx context.Context, roomID id.RoomID, evtType event.Type, content interface{}) (*event.EncryptedEventContent, error) {
mach.megolmEncryptLock.Lock()
defer mach.megolmEncryptLock.Unlock()
session, err := mach.CryptoStore.GetOutboundGroupSession(roomID)
if err != nil {
return nil, fmt.Errorf("failed to get outbound group session: %w", err)
@@ -64,15 +98,28 @@ func (mach *OlmMachine) EncryptMegolmEvent(roomID id.RoomID, evtType event.Type,
if err != nil {
return nil, err
}
log := mach.machOrContextLog(ctx).With().
Str("event_type", evtType.Type).
Str("room_id", roomID.String()).
Str("session_id", session.ID().String()).
Logger()
log.Trace().Msg("Encrypting event...")
ciphertext, err := session.Encrypt(plaintext)
if err != nil {
return nil, err
}
idx, err := parseMessageIndex(ciphertext)
if err != nil {
log.Warn().Err(err).Msg("Failed to get megolm message index of encrypted event")
} else {
log = log.With().Uint64("message_index", idx).Logger()
}
log.Debug().Msg("Encrypted event successfully")
err = mach.CryptoStore.UpdateOutboundGroupSession(session)
if err != nil {
mach.Log.Warn("Failed to update megolm session in crypto store after encrypting: %v", err)
log.Warn().Err(err).Msg("Failed to update megolm session in crypto store after encrypting")
}
return &event.EncryptedEventContent{
encrypted := &event.EncryptedEventContent{
Algorithm: id.AlgorithmMegolmV1,
SessionID: session.ID(),
MegolmCiphertext: ciphertext,
@@ -81,13 +128,19 @@ func (mach *OlmMachine) EncryptMegolmEvent(roomID id.RoomID, evtType event.Type,
// These are deprecated
SenderKey: mach.account.IdentityKey(),
DeviceID: mach.Client.DeviceID,
}, nil
}
if mach.PlaintextMentions {
encrypted.Mentions = getMentions(content)
}
return encrypted, nil
}
func (mach *OlmMachine) newOutboundGroupSession(roomID id.RoomID) *OutboundGroupSession {
func (mach *OlmMachine) newOutboundGroupSession(ctx context.Context, roomID id.RoomID) *OutboundGroupSession {
session := NewOutboundGroupSession(roomID, mach.StateStore.GetEncryptionEvent(roomID))
signingKey, idKey := mach.account.Keys()
mach.createGroupSession(idKey, signingKey, roomID, session.ID(), session.Internal.Key(), "create")
if !mach.DontStoreOutboundKeys {
signingKey, idKey := mach.account.Keys()
mach.createGroupSession(ctx, idKey, signingKey, roomID, session.ID(), session.Internal.Key(), session.MaxAge, session.MaxMessages, false)
}
return session
}
@@ -96,21 +149,38 @@ type deviceSessionWrapper struct {
identity *id.Device
}
func strishArray[T ~string](arr []T) []string {
out := make([]string, len(arr))
for i, item := range arr {
out[i] = string(item)
}
return out
}
// ShareGroupSession shares a group session for a specific room with all the devices of the given user list.
//
// For devices with TrustStateBlacklisted, a m.room_key.withheld event with code=m.blacklisted is sent.
// If AllowUnverifiedDevices is false, a similar event with code=m.unverified is sent to devices with TrustStateUnset
func (mach *OlmMachine) ShareGroupSession(roomID id.RoomID, users []id.UserID) error {
mach.Log.Debug("Sharing group session for room %s to %v", roomID, users)
func (mach *OlmMachine) ShareGroupSession(ctx context.Context, roomID id.RoomID, users []id.UserID) error {
mach.megolmEncryptLock.Lock()
defer mach.megolmEncryptLock.Unlock()
session, err := mach.CryptoStore.GetOutboundGroupSession(roomID)
if err != nil {
return fmt.Errorf("failed to get previous outbound group session: %w", err)
} else if session != nil && session.Shared && !session.Expired() {
return AlreadyShared
}
log := mach.machOrContextLog(ctx).With().
Str("room_id", roomID.String()).
Str("action", "share megolm session").
Logger()
ctx = log.WithContext(ctx)
if session == nil || session.Expired() {
session = mach.newOutboundGroupSession(roomID)
session = mach.newOutboundGroupSession(ctx, roomID)
}
log = log.With().Str("session_id", session.ID().String()).Logger()
ctx = log.WithContext(ctx)
log.Debug().Strs("users", strishArray(users)).Msg("Sharing group session for room")
withheldCount := 0
toDeviceWithheld := &mautrix.ReqSendToDevice{Messages: make(map[id.UserID]map[id.DeviceID]*event.Content)}
@@ -120,20 +190,25 @@ func (mach *OlmMachine) ShareGroupSession(roomID id.RoomID, users []id.UserID) e
var fetchKeys []id.UserID
for _, userID := range users {
log := log.With().Str("target_user_id", userID.String()).Logger()
devices, err := mach.CryptoStore.GetDevices(userID)
if err != nil {
mach.Log.Error("Failed to get devices of %s", userID)
log.Error().Err(err).Msg("Failed to get devices of user")
} else if devices == nil {
mach.Log.Trace("GetDevices returned nil for %s, will fetch keys and retry", userID)
log.Debug().Msg("GetDevices returned nil, will fetch keys and retry")
fetchKeys = append(fetchKeys, userID)
} else if len(devices) == 0 {
mach.Log.Trace("%s has no devices, skipping", userID)
log.Trace().Msg("User has no devices, skipping")
} else {
mach.Log.Trace("Trying to find olm sessions to encrypt %s for %s", session.ID(), userID)
log.Trace().Msg("Trying to find olm session to encrypt megolm session for user")
toDeviceWithheld.Messages[userID] = make(map[id.DeviceID]*event.Content)
olmSessions[userID] = make(map[id.DeviceID]deviceSessionWrapper)
mach.findOlmSessionsForUser(session, userID, devices, olmSessions[userID], toDeviceWithheld.Messages[userID], missingUserSessions)
mach.Log.Trace("Found %d sessions, withholding from %d sessions and missing %d sessions to encrypt %s for for %s", len(olmSessions[userID]), len(toDeviceWithheld.Messages[userID]), len(missingUserSessions), session.ID(), userID)
mach.findOlmSessionsForUser(ctx, session, userID, devices, olmSessions[userID], toDeviceWithheld.Messages[userID], missingUserSessions)
log.Debug().
Int("olm_session_count", len(olmSessions[userID])).
Int("withheld_count", len(toDeviceWithheld.Messages[userID])).
Int("missing_count", len(missingUserSessions)).
Msg("Completed first pass of finding olm sessions")
withheldCount += len(toDeviceWithheld.Messages[userID])
if len(missingUserSessions) > 0 {
missingSessions[userID] = missingUserSessions
@@ -146,18 +221,21 @@ func (mach *OlmMachine) ShareGroupSession(roomID id.RoomID, users []id.UserID) e
}
if len(fetchKeys) > 0 {
mach.Log.Trace("Fetching missing keys for %v", fetchKeys)
for userID, devices := range mach.fetchKeys(fetchKeys, "", true) {
mach.Log.Trace("Got %d device keys for %s", len(devices), userID)
log.Debug().Strs("users", strishArray(fetchKeys)).Msg("Fetching missing keys")
for userID, devices := range mach.fetchKeys(ctx, fetchKeys, "", true) {
log.Debug().
Int("device_count", len(devices)).
Str("target_user_id", userID.String()).
Msg("Got device keys for user")
missingSessions[userID] = devices
}
}
if len(missingSessions) > 0 {
mach.Log.Trace("Creating missing outbound sessions")
err = mach.createOutboundSessions(missingSessions)
log.Debug().Msg("Creating missing olm sessions")
err = mach.createOutboundSessions(ctx, missingSessions)
if err != nil {
mach.Log.Error("Failed to create missing outbound sessions: %v", err)
log.Error().Err(err).Msg("Failed to create missing olm sessions")
}
}
@@ -176,42 +254,51 @@ func (mach *OlmMachine) ShareGroupSession(roomID id.RoomID, users []id.UserID) e
withheld = make(map[id.DeviceID]*event.Content)
toDeviceWithheld.Messages[userID] = withheld
}
mach.Log.Trace("Trying to find olm sessions to encrypt %s for %s (post-fetch retry)", session.ID(), userID)
mach.findOlmSessionsForUser(session, userID, devices, output, withheld, nil)
mach.Log.Trace("Found %d sessions and withholding from %d sessions to encrypt %s for for %s (post-fetch retry)", len(output), len(withheld), session.ID(), userID)
log := log.With().Str("target_user_id", userID.String()).Logger()
log.Trace().Msg("Trying to find olm session to encrypt megolm session for user (post-fetch retry)")
mach.findOlmSessionsForUser(ctx, session, userID, devices, output, withheld, nil)
log.Debug().
Int("olm_session_count", len(output)).
Int("withheld_count", len(withheld)).
Msg("Completed post-fetch retry of finding olm sessions")
withheldCount += len(toDeviceWithheld.Messages[userID])
if len(toDeviceWithheld.Messages[userID]) == 0 {
delete(toDeviceWithheld.Messages, userID)
}
}
err = mach.encryptAndSendGroupSession(session, olmSessions)
err = mach.encryptAndSendGroupSession(ctx, session, olmSessions)
if err != nil {
return fmt.Errorf("failed to share group session: %w", err)
}
if len(toDeviceWithheld.Messages) > 0 {
mach.Log.Trace("Sending to-device messages to %d devices of %d users to report withheld keys in %s", withheldCount, len(toDeviceWithheld.Messages), roomID)
log.Debug().
Int("device_count", withheldCount).
Int("user_count", len(toDeviceWithheld.Messages)).
Msg("Sending to-device messages to report withheld key")
// TODO remove the next 4 lines once clients support m.room_key.withheld
_, err = mach.Client.SendToDevice(event.ToDeviceOrgMatrixRoomKeyWithheld, toDeviceWithheld)
if err != nil {
mach.Log.Warn("Failed to report withheld keys in %s (legacy event type): %v", roomID, err)
log.Warn().Err(err).Msg("Failed to report withheld keys (legacy event type)")
}
_, err = mach.Client.SendToDevice(event.ToDeviceRoomKeyWithheld, toDeviceWithheld)
if err != nil {
mach.Log.Warn("Failed to report withheld keys in %s: %v", roomID, err)
log.Warn().Err(err).Msg("Failed to report withheld keys")
}
}
mach.Log.Debug("Group session %s for %s successfully shared", session.ID(), roomID)
log.Debug().Msg("Group session successfully shared")
session.Shared = true
return mach.CryptoStore.AddOutboundGroupSession(session)
}
func (mach *OlmMachine) encryptAndSendGroupSession(session *OutboundGroupSession, olmSessions map[id.UserID]map[id.DeviceID]deviceSessionWrapper) error {
func (mach *OlmMachine) encryptAndSendGroupSession(ctx context.Context, session *OutboundGroupSession, olmSessions map[id.UserID]map[id.DeviceID]deviceSessionWrapper) error {
mach.olmLock.Lock()
defer mach.olmLock.Unlock()
mach.Log.Trace("Encrypting group session %s for all found devices", session.ID())
log := zerolog.Ctx(ctx)
log.Trace().Msg("Encrypting group session for all found devices")
deviceCount := 0
toDevice := &mautrix.ReqSendToDevice{Messages: make(map[id.UserID]map[id.DeviceID]*event.Content)}
for userID, sessions := range olmSessions {
@@ -221,31 +308,41 @@ func (mach *OlmMachine) encryptAndSendGroupSession(session *OutboundGroupSession
output := make(map[id.DeviceID]*event.Content)
toDevice.Messages[userID] = output
for deviceID, device := range sessions {
mach.Log.Trace("Encrypting group session %s for %s of %s", session.ID(), deviceID, userID)
content := mach.encryptOlmEvent(device.session, device.identity, event.ToDeviceRoomKey, session.ShareContent())
log.Trace().
Str("target_user_id", userID.String()).
Str("target_device_id", deviceID.String()).
Msg("Encrypting group session for device")
content := mach.encryptOlmEvent(ctx, device.session, device.identity, event.ToDeviceRoomKey, session.ShareContent())
output[deviceID] = &event.Content{Parsed: content}
deviceCount++
mach.Log.Trace("Encrypted group session %s for %s of %s", session.ID(), deviceID, userID)
log.Debug().
Str("target_user_id", userID.String()).
Str("target_device_id", deviceID.String()).
Msg("Encrypted group session for device")
}
}
mach.Log.Trace("Sending to-device to %d devices of %d users to share group session %s", deviceCount, len(toDevice.Messages), session.ID())
log.Debug().
Int("device_count", deviceCount).
Int("user_count", len(toDevice.Messages)).
Msg("Sending to-device messages to share group session")
_, err := mach.Client.SendToDevice(event.ToDeviceEncrypted, toDevice)
return err
}
func (mach *OlmMachine) findOlmSessionsForUser(session *OutboundGroupSession, userID id.UserID, devices map[id.DeviceID]*id.Device, output map[id.DeviceID]deviceSessionWrapper, withheld map[id.DeviceID]*event.Content, missingOutput map[id.DeviceID]*id.Device) {
func (mach *OlmMachine) findOlmSessionsForUser(ctx context.Context, session *OutboundGroupSession, userID id.UserID, devices map[id.DeviceID]*id.Device, output map[id.DeviceID]deviceSessionWrapper, withheld map[id.DeviceID]*event.Content, missingOutput map[id.DeviceID]*id.Device) {
for deviceID, device := range devices {
log := zerolog.Ctx(ctx).With().
Str("target_user_id", userID.String()).
Str("target_device_id", deviceID.String()).
Logger()
userKey := UserDevice{UserID: userID, DeviceID: deviceID}
if state := session.Users[userKey]; state != OGSNotShared {
continue
} else if userID == mach.Client.UserID && deviceID == mach.Client.DeviceID {
session.Users[userKey] = OGSIgnored
} else if device.Trust == id.TrustStateBlacklisted {
mach.Log.Debug(
"Not encrypting group session %s for %s of %s: device is blacklisted",
session.ID(), deviceID, userID,
)
log.Debug().Msg("Not encrypting group session for device: device is blacklisted")
withheld[deviceID] = &event.Content{Parsed: &event.RoomKeyWithheldEventContent{
RoomID: session.RoomID,
Algorithm: id.AlgorithmMegolmV1,
@@ -256,10 +353,10 @@ func (mach *OlmMachine) findOlmSessionsForUser(session *OutboundGroupSession, us
}}
session.Users[userKey] = OGSIgnored
} else if trustState := mach.ResolveTrust(device); trustState < mach.SendKeysMinTrust {
mach.Log.Debug(
"Not encrypting group session %s for %s of %s: device is not verified (minimum: %s, device: %s)",
session.ID(), deviceID, userID, mach.SendKeysMinTrust, trustState,
)
log.Debug().
Str("min_trust", mach.SendKeysMinTrust.String()).
Str("device_trust", trustState.String()).
Msg("Not encrypting group session for device: device is not trusted")
withheld[deviceID] = &event.Content{Parsed: &event.RoomKeyWithheldEventContent{
RoomID: session.RoomID,
Algorithm: id.AlgorithmMegolmV1,
@@ -270,9 +367,9 @@ func (mach *OlmMachine) findOlmSessionsForUser(session *OutboundGroupSession, us
}}
session.Users[userKey] = OGSIgnored
} else if deviceSession, err := mach.CryptoStore.GetLatestSession(device.IdentityKey); err != nil {
mach.Log.Error("Failed to get session for %s of %s: %v", deviceID, userID, err)
log.Error().Err(err).Msg("Failed to get olm session to encrypt group session")
} else if deviceSession == nil {
mach.Log.Warn("Didn't find a session for %s of %s", deviceID, userID)
log.Warn().Err(err).Msg("Didn't find olm session to encrypt group session")
if missingOutput != nil {
missingOutput[deviceID] = device
}

View File

@@ -1,4 +1,4 @@
// Copyright (c) 2020 Tulir Asokan
// Copyright (c) 2023 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
@@ -7,6 +7,7 @@
package crypto
import (
"context"
"encoding/json"
"fmt"
@@ -16,7 +17,7 @@ import (
"maunium.net/go/mautrix/id"
)
func (mach *OlmMachine) encryptOlmEvent(session *OlmSession, recipient *id.Device, evtType event.Type, content event.Content) *event.EncryptedEventContent {
func (mach *OlmMachine) encryptOlmEvent(ctx context.Context, session *OlmSession, recipient *id.Device, evtType event.Type, content event.Content) *event.EncryptedEventContent {
evt := &DecryptedOlmEvent{
Sender: mach.Client.UserID,
SenderDevice: mach.Client.DeviceID,
@@ -30,11 +31,16 @@ func (mach *OlmMachine) encryptOlmEvent(session *OlmSession, recipient *id.Devic
if err != nil {
panic(err)
}
mach.Log.Trace("Encrypting olm message for %s with session %s: %s", recipient.IdentityKey, session.ID(), session.Describe())
log := mach.machOrContextLog(ctx)
log.Debug().
Str("recipient_identity_key", recipient.IdentityKey.String()).
Str("olm_session_id", session.ID().String()).
Str("olm_session_description", session.Describe()).
Msg("Encrypting olm message")
msgType, ciphertext := session.Encrypt(plaintext)
err = mach.CryptoStore.UpdateSession(recipient.IdentityKey, session)
if err != nil {
mach.Log.Warn("Failed to update olm session in crypto store after encrypting: %v", err)
log.Error().Err(err).Msg("Failed to update olm session in crypto store after encrypting")
}
return &event.EncryptedEventContent{
Algorithm: id.AlgorithmOlmV1,
@@ -61,7 +67,7 @@ func (mach *OlmMachine) shouldCreateNewSession(identityKey id.IdentityKey) bool
return shouldUnwedge
}
func (mach *OlmMachine) createOutboundSessions(input map[id.UserID]map[id.DeviceID]*id.Device) error {
func (mach *OlmMachine) createOutboundSessions(ctx context.Context, input map[id.UserID]map[id.DeviceID]*id.Device) error {
request := make(mautrix.OneTimeKeysRequest)
for userID, devices := range input {
request[userID] = make(map[id.DeviceID]id.KeyAlgorithm)
@@ -84,6 +90,7 @@ func (mach *OlmMachine) createOutboundSessions(input map[id.UserID]map[id.Device
if err != nil {
return fmt.Errorf("failed to claim keys: %w", err)
}
log := mach.machOrContextLog(ctx)
for userID, user := range resp.OneTimeKeys {
for deviceID, oneTimeKeys := range user {
var oneTimeKey mautrix.OneTimeKey
@@ -91,25 +98,30 @@ func (mach *OlmMachine) createOutboundSessions(input map[id.UserID]map[id.Device
for keyID, oneTimeKey = range oneTimeKeys {
break
}
keyAlg, keyIndex := keyID.Parse()
log := log.With().
Str("peer_user_id", userID.String()).
Str("peer_device_id", deviceID.String()).
Str("peer_otk_id", keyID.String()).
Logger()
keyAlg, _ := keyID.Parse()
if keyAlg != id.KeyAlgorithmSignedCurve25519 {
mach.Log.Warn("Unexpected key ID algorithm in one-time key response for %s of %s: %s", deviceID, userID, keyID)
log.Warn().Msg("Unexpected key ID algorithm in one-time key response")
continue
}
identity := input[userID][deviceID]
if ok, err := olm.VerifySignatureJSON(oneTimeKey.RawData, userID, deviceID.String(), identity.SigningKey); err != nil {
mach.Log.Error("Failed to verify signature for %s of %s: %v", deviceID, userID, err)
log.Error().Err(err).Msg("Failed to verify signature of one-time key")
} else if !ok {
mach.Log.Warn("Invalid signature for %s of %s", deviceID, userID)
log.Warn().Msg("One-time key has invalid signature from device")
} else if sess, err := mach.account.Internal.NewOutboundSession(identity.IdentityKey, oneTimeKey.Key); err != nil {
mach.Log.Error("Failed to create outbound session for %s of %s: %v", deviceID, userID, err)
log.Error().Err(err).Msg("Failed to create outbound session with claimed one-time key")
} else {
wrapped := wrapSession(sess)
err = mach.CryptoStore.AddSession(identity.IdentityKey, wrapped)
if err != nil {
mach.Log.Error("Failed to store created session for %s of %s: %v", deviceID, userID, err)
log.Error().Err(err).Msg("Failed to store created outbound session")
} else {
mach.Log.Debug("Created new Olm session with %s/%s (OTK ID: %s)", userID, deviceID, keyIndex)
log.Debug().Msg("Created new Olm session")
}
}
}

View File

@@ -103,7 +103,7 @@ func exportSessions(sessions []*InboundGroupSession) ([]ExportedSession, error)
SenderKey: session.SenderKey,
SenderClaimedKeys: SenderClaimedKeys{},
SessionID: session.ID(),
SessionKey: key,
SessionKey: string(key),
}
}
return export, nil

View File

@@ -1,4 +1,4 @@
// Copyright (c) 2020 Tulir Asokan
// Copyright (c) 2023 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
@@ -17,6 +17,7 @@ import (
"encoding/json"
"errors"
"fmt"
"time"
"maunium.net/go/mautrix/crypto/olm"
"maunium.net/go/mautrix/id"
@@ -108,6 +109,8 @@ func (mach *OlmMachine) importExportedRoomKey(session ExportedSession) (bool, er
RoomID: session.RoomID,
// TODO should we add something here to mark the signing key as unverified like key requests do?
ForwardingChains: session.ForwardingChains,
ReceivedAt: time.Now().UTC(),
}
existingIGS, _ := mach.CryptoStore.GetGroupSession(igs.RoomID, igs.SenderKey, igs.ID())
if existingIGS != nil && existingIGS.Internal.FirstKnownIndex() <= igs.Internal.FirstKnownIndex() {
@@ -136,14 +139,18 @@ func (mach *OlmMachine) ImportKeys(passphrase string, data []byte) (int, int, er
count := 0
for _, session := range sessions {
log := mach.Log.With().
Str("room_id", session.RoomID.String()).
Str("session_id", session.SessionID.String()).
Logger()
imported, err := mach.importExportedRoomKey(session)
if err != nil {
mach.Log.Warn("Failed to import Megolm session %s/%s from file: %v", session.RoomID, session.SessionID, err)
log.Error().Err(err).Msg("Failed to import Megolm session from file")
} else if imported {
mach.Log.Debug("Imported Megolm session %s/%s from file", session.RoomID, session.SessionID)
log.Debug().Msg("Imported Megolm session from file")
count++
} else {
mach.Log.Debug("Skipped Megolm session %s/%s: already in store", session.RoomID, session.SessionID)
log.Debug().Msg("Skipped Megolm session which is already in the store")
}
}
return count, len(sessions), nil

View File

@@ -1,16 +1,18 @@
// Copyright (c) 2020 Nikos Filippakis
// Copyright (c) 2023 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
//go:build !nosas
// +build !nosas
package crypto
import (
"context"
"errors"
"time"
"github.com/rs/zerolog"
"maunium.net/go/mautrix/crypto/olm"
"maunium.net/go/mautrix/id"
@@ -56,11 +58,11 @@ func (mach *OlmMachine) RequestRoomKey(ctx context.Context, toUser id.UserID, to
select {
case <-keyResponseReceived:
// key request successful
mach.Log.Debug("Key for session %v was received, cancelling other key requests", sessionID)
mach.Log.Debug().Msgf("Key for session %v was received, cancelling other key requests", sessionID)
resChan <- true
case <-ctx.Done():
// if the context is done, key request was unsuccessful
mach.Log.Debug("Context closed (%v) before forwared key for session %v received, sending key request cancellation", ctx.Err(), sessionID)
mach.Log.Debug().Msgf("Context closed (%v) before forwared key for session %v received, sending key request cancellation", ctx.Err(), sessionID)
resChan <- false
}
@@ -128,20 +130,41 @@ func (mach *OlmMachine) SendRoomKeyRequest(roomID id.RoomID, senderKey id.Sender
return err
}
func (mach *OlmMachine) importForwardedRoomKey(evt *DecryptedOlmEvent, content *event.ForwardedRoomKeyEventContent) bool {
func (mach *OlmMachine) importForwardedRoomKey(ctx context.Context, evt *DecryptedOlmEvent, content *event.ForwardedRoomKeyEventContent) bool {
log := zerolog.Ctx(ctx).With().
Str("session_id", content.SessionID.String()).
Str("room_id", content.RoomID.String()).
Logger()
if content.Algorithm != id.AlgorithmMegolmV1 || evt.Keys.Ed25519 == "" {
mach.Log.Debug("Ignoring weird forwarded room key from %s/%s: alg=%s, ed25519=%s, sessionid=%s, roomid=%s", evt.Sender, evt.SenderDevice, content.Algorithm, evt.Keys.Ed25519, content.SessionID, content.RoomID)
log.Debug().
Str("algorithm", string(content.Algorithm)).
Msg("Ignoring weird forwarded room key")
return false
}
igsInternal, err := olm.InboundGroupSessionImport([]byte(content.SessionKey))
if err != nil {
mach.Log.Error("Failed to import inbound group session: %v", err)
log.Error().Err(err).Msg("Failed to import inbound group session")
return false
} else if igsInternal.ID() != content.SessionID {
mach.Log.Warn("Mismatched session ID while creating inbound group session")
log.Warn().
Str("actual_session_id", igsInternal.ID().String()).
Msg("Mismatched session ID while creating inbound group session from forward")
return false
}
config := mach.StateStore.GetEncryptionEvent(content.RoomID)
var maxAge time.Duration
var maxMessages int
if config != nil {
maxAge = time.Duration(config.RotationPeriodMillis) * time.Millisecond
maxMessages = config.RotationPeriodMessages
}
if content.MaxAge != 0 {
maxAge = time.Duration(content.MaxAge) * time.Millisecond
}
if content.MaxMessages != 0 {
maxMessages = content.MaxMessages
}
igs := &InboundGroupSession{
Internal: *igsInternal,
SigningKey: evt.Keys.Ed25519,
@@ -149,14 +172,19 @@ func (mach *OlmMachine) importForwardedRoomKey(evt *DecryptedOlmEvent, content *
RoomID: content.RoomID,
ForwardingChains: append(content.ForwardingKeyChain, evt.SenderKey.String()),
id: content.SessionID,
ReceivedAt: time.Now().UTC(),
MaxAge: maxAge.Milliseconds(),
MaxMessages: maxMessages,
IsScheduled: content.IsScheduled,
}
err = mach.CryptoStore.PutGroupSession(content.RoomID, content.SenderKey, content.SessionID, igs)
if err != nil {
mach.Log.Error("Failed to store new inbound group session: %v", err)
log.Error().Err(err).Msg("Failed to store new inbound group session")
return false
}
mach.markSessionReceived(content.SessionID)
mach.Log.Trace("Received forwarded inbound group session %s/%s/%s", content.RoomID, content.SenderKey, content.SessionID)
log.Debug().Msg("Received forwarded inbound group session")
return true
}
@@ -175,50 +203,72 @@ func (mach *OlmMachine) rejectKeyRequest(rejection KeyShareRejection, device *id
}
err := mach.sendToOneDevice(device.UserID, device.DeviceID, event.ToDeviceRoomKeyWithheld, &content)
if err != nil {
mach.Log.Warn("Failed to send key share rejection %s to %s/%s: %v", rejection.Code, device.UserID, device.DeviceID, err)
mach.Log.Warn().Err(err).
Str("code", string(rejection.Code)).
Str("user_id", device.UserID.String()).
Str("device_id", device.DeviceID.String()).
Msg("Failed to send key share rejection")
}
err = mach.sendToOneDevice(device.UserID, device.DeviceID, event.ToDeviceOrgMatrixRoomKeyWithheld, &content)
if err != nil {
mach.Log.Warn("Failed to send key share rejection %s (org.matrix.) to %s/%s: %v", rejection.Code, device.UserID, device.DeviceID, err)
mach.Log.Warn().Err(err).
Str("code", string(rejection.Code)).
Str("user_id", device.UserID.String()).
Str("device_id", device.DeviceID.String()).
Msg("Failed to send key share rejection (legacy event type)")
}
}
func (mach *OlmMachine) defaultAllowKeyShare(device *id.Device, _ event.RequestedKeyInfo) *KeyShareRejection {
func (mach *OlmMachine) defaultAllowKeyShare(ctx context.Context, device *id.Device, _ event.RequestedKeyInfo) *KeyShareRejection {
log := mach.machOrContextLog(ctx)
if mach.Client.UserID != device.UserID {
mach.Log.Debug("Ignoring key request from a different user (%s)", device.UserID)
log.Debug().Msg("Rejecting key request from a different user")
return &KeyShareRejectOtherUser
} else if mach.Client.DeviceID == device.DeviceID {
mach.Log.Debug("Ignoring key request from ourselves")
log.Debug().Msg("Ignoring key request from ourselves")
return &KeyShareRejectNoResponse
} else if device.Trust == id.TrustStateBlacklisted {
mach.Log.Debug("Ignoring key request from blacklisted device %s", device.DeviceID)
log.Debug().Msg("Rejecting key request from blacklisted device")
return &KeyShareRejectBlacklisted
} else if trustState := mach.ResolveTrust(device); trustState >= mach.ShareKeysMinTrust {
mach.Log.Debug("Accepting key request from device %s (trust state: %s)", device.DeviceID, trustState)
log.Debug().
Str("min_trust", mach.SendKeysMinTrust.String()).
Str("device_trust", trustState.String()).
Msg("Accepting key request from trusted device")
return nil
} else {
mach.Log.Debug("Ignoring key request from unverified device %s (trust state: %s)", device.DeviceID, trustState)
log.Debug().
Str("min_trust", mach.SendKeysMinTrust.String()).
Str("device_trust", trustState.String()).
Msg("Rejecting key request from untrusted device")
return &KeyShareRejectUnverified
}
}
func (mach *OlmMachine) handleRoomKeyRequest(sender id.UserID, content *event.RoomKeyRequestEventContent) {
func (mach *OlmMachine) handleRoomKeyRequest(ctx context.Context, sender id.UserID, content *event.RoomKeyRequestEventContent) {
log := zerolog.Ctx(ctx).With().
Str("request_id", content.RequestID).
Str("device_id", content.RequestingDeviceID.String()).
Str("room_id", content.Body.RoomID.String()).
Str("session_id", content.Body.SessionID.String()).
Logger()
ctx = log.WithContext(ctx)
if content.Action != event.KeyRequestActionRequest {
return
} else if content.RequestingDeviceID == mach.Client.DeviceID && sender == mach.Client.UserID {
mach.Log.Debug("Ignoring key request %s from ourselves", content.RequestID)
log.Debug().Msg("Ignoring key request from ourselves")
return
}
mach.Log.Debug("Received key request %s for %s from %s/%s", content.RequestID, content.Body.SessionID, sender, content.RequestingDeviceID)
log.Debug().Msg("Received key request")
device, err := mach.GetOrFetchDevice(sender, content.RequestingDeviceID)
device, err := mach.GetOrFetchDevice(ctx, sender, content.RequestingDeviceID)
if err != nil {
mach.Log.Error("Failed to fetch device %s/%s that requested keys: %v", sender, content.RequestingDeviceID, err)
log.Error().Err(err).Msg("Failed to fetch device that requested keys")
return
}
rejection := mach.AllowKeyShare(device, content.Body)
rejection := mach.AllowKeyShare(ctx, device, content.Body)
if rejection != nil {
mach.rejectKeyRequest(*rejection, device, content.Body)
return
@@ -226,18 +276,29 @@ func (mach *OlmMachine) handleRoomKeyRequest(sender id.UserID, content *event.Ro
igs, err := mach.CryptoStore.GetGroupSession(content.Body.RoomID, content.Body.SenderKey, content.Body.SessionID)
if err != nil {
mach.Log.Error("Failed to fetch group session to forward to %s/%s: %v", device.UserID, device.DeviceID, err)
mach.rejectKeyRequest(KeyShareRejectInternalError, device, content.Body)
if errors.Is(err, ErrGroupSessionWithheld) {
log.Debug().Err(err).Msg("Requested group session not available")
mach.rejectKeyRequest(KeyShareRejectUnavailable, device, content.Body)
} else {
log.Error().Err(err).Msg("Failed to get group session to forward")
mach.rejectKeyRequest(KeyShareRejectInternalError, device, content.Body)
}
return
} else if igs == nil {
mach.Log.Warn("Didn't find group session %s to forward to %s/%s", content.Body.SessionID, device.UserID, device.DeviceID)
log.Error().Msg("Didn't find group session to forward")
mach.rejectKeyRequest(KeyShareRejectUnavailable, device, content.Body)
return
}
if internalID := igs.ID(); internalID != content.Body.SessionID {
// Should this be an error?
log = log.With().Str("unexpected_session_id", internalID.String()).Logger()
}
exportedKey, err := igs.Internal.Export(igs.Internal.FirstKnownIndex())
firstKnownIndex := igs.Internal.FirstKnownIndex()
log = log.With().Uint32("first_known_index", firstKnownIndex).Logger()
exportedKey, err := igs.Internal.Export(firstKnownIndex)
if err != nil {
mach.Log.Error("Failed to export session %s to forward to %s/%s: %v", igs.ID(), device.UserID, device.DeviceID, err)
log.Error().Err(err).Msg("Failed to export group session to forward")
mach.rejectKeyRequest(KeyShareRejectInternalError, device, content.Body)
return
}
@@ -248,7 +309,7 @@ func (mach *OlmMachine) handleRoomKeyRequest(sender id.UserID, content *event.Ro
Algorithm: id.AlgorithmMegolmV1,
RoomID: igs.RoomID,
SessionID: igs.ID(),
SessionKey: exportedKey,
SessionKey: string(exportedKey),
},
SenderKey: content.Body.SenderKey,
ForwardingKeyChain: igs.ForwardingChains,
@@ -256,9 +317,42 @@ func (mach *OlmMachine) handleRoomKeyRequest(sender id.UserID, content *event.Ro
},
}
if err := mach.SendEncryptedToDevice(device, event.ToDeviceForwardedRoomKey, forwardedRoomKey); err != nil {
mach.Log.Error("Failed to send encrypted forwarded key %s to %s/%s: %v", igs.ID(), device.UserID, device.DeviceID, err)
if err = mach.SendEncryptedToDevice(ctx, device, event.ToDeviceForwardedRoomKey, forwardedRoomKey); err != nil {
log.Error().Err(err).Msg("Failed to encrypt and send group session")
} else {
log.Debug().Msg("Successfully sent forwarded group session")
}
}
func (mach *OlmMachine) handleBeeperRoomKeyAck(ctx context.Context, sender id.UserID, content *event.BeeperRoomKeyAckEventContent) {
log := mach.machOrContextLog(ctx).With().
Str("room_id", content.RoomID.String()).
Str("session_id", content.SessionID.String()).
Int("first_message_index", content.FirstMessageIndex).
Logger()
sess, err := mach.CryptoStore.GetGroupSession(content.RoomID, "", content.SessionID)
if err != nil {
if errors.Is(err, ErrGroupSessionWithheld) {
log.Debug().Err(err).Msg("Acked group session was already redacted")
} else {
log.Err(err).Msg("Failed to get group session to check if it should be redacted")
}
return
}
log = log.With().
Str("sender_key", sess.SenderKey.String()).
Str("own_identity", mach.OwnIdentity().IdentityKey.String()).
Logger()
isInbound := sess.SenderKey == mach.OwnIdentity().IdentityKey
if isInbound && mach.DeleteOutboundKeysOnAck && content.FirstMessageIndex == 0 {
log.Debug().Msg("Redacting inbound copy of outbound group session after ack")
err = mach.CryptoStore.RedactGroupSession(content.RoomID, sess.SenderKey, content.SessionID, "outbound session acked")
if err != nil {
log.Err(err).Msg("Failed to redact group session")
}
} else {
log.Debug().Bool("inbound", isInbound).Msg("Received room key ack")
}
mach.Log.Debug("Sent encrypted forwarded key to device %s/%s for session %s", device.UserID, device.DeviceID, igs.ID())
}

View File

@@ -1,4 +1,4 @@
// Copyright (c) 2020 Tulir Asokan
// Copyright (c) 2023 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
@@ -7,12 +7,14 @@
package crypto
import (
"context"
"errors"
"fmt"
"sync"
"time"
"maunium.net/go/mautrix/appservice"
"github.com/rs/zerolog"
"maunium.net/go/mautrix/crypto/ssss"
"maunium.net/go/mautrix/id"
@@ -20,28 +22,21 @@ import (
"maunium.net/go/mautrix/event"
)
// Logger is a simple logging struct for OlmMachine.
// Implementations are recommended to use fmt.Sprintf and manually add a newline after the message.
type Logger interface {
Error(message string, args ...interface{})
Warn(message string, args ...interface{})
Debug(message string, args ...interface{})
Trace(message string, args ...interface{})
}
// OlmMachine is the main struct for handling Matrix end-to-end encryption.
type OlmMachine struct {
Client *mautrix.Client
SSSS *ssss.Machine
Log Logger
Log *zerolog.Logger
CryptoStore Store
StateStore StateStore
PlaintextMentions bool
SendKeysMinTrust id.TrustState
ShareKeysMinTrust id.TrustState
AllowKeyShare func(*id.Device, event.RequestedKeyInfo) *KeyShareRejection
AllowKeyShare func(context.Context, *id.Device, event.RequestedKeyInfo) *KeyShareRejection
DefaultSASTimeout time.Duration
// AcceptVerificationFrom determines whether the machine will accept verification requests from this device.
@@ -60,7 +55,9 @@ type OlmMachine struct {
recentlyUnwedged map[id.IdentityKey]time.Time
recentlyUnwedgedLock sync.Mutex
olmLock sync.Mutex
olmLock sync.Mutex
megolmEncryptLock sync.Mutex
megolmDecryptLock sync.Mutex
otkUploadLock sync.Mutex
lastOTKUpload time.Time
@@ -69,6 +66,13 @@ type OlmMachine struct {
crossSigningPubkeys *CrossSigningPublicKeysCache
crossSigningPubkeysFetched bool
DeleteOutboundKeysOnAck bool
DontStoreOutboundKeys bool
DeletePreviousKeysOnReceive bool
RatchetKeysOnDecrypt bool
DeleteFullyUsedKeysOnDecrypt bool
DeleteKeysOnDeviceDelete bool
}
// StateStore is used by OlmMachine to get room state information that's needed for encryption.
@@ -82,7 +86,11 @@ type StateStore interface {
}
// NewOlmMachine creates an OlmMachine with the given client, logger and stores.
func NewOlmMachine(client *mautrix.Client, log Logger, cryptoStore Store, stateStore StateStore) *OlmMachine {
func NewOlmMachine(client *mautrix.Client, log *zerolog.Logger, cryptoStore Store, stateStore StateStore) *OlmMachine {
if log == nil {
logPtr := zerolog.Nop()
log = &logPtr
}
mach := &OlmMachine{
Client: client,
SSSS: ssss.NewSSSSMachine(client),
@@ -111,6 +119,14 @@ func NewOlmMachine(client *mautrix.Client, log Logger, cryptoStore Store, stateS
return mach
}
func (mach *OlmMachine) machOrContextLog(ctx context.Context) *zerolog.Logger {
log := zerolog.Ctx(ctx)
if log.GetLevel() == zerolog.Disabled || log == zerolog.DefaultContextLogger {
return mach.Log
}
return log
}
// Load loads the Olm account information from the crypto store. If there's no olm account, a new one is created.
// This must be called before using the machine.
func (mach *OlmMachine) Load() (err error) {
@@ -127,7 +143,7 @@ func (mach *OlmMachine) Load() (err error) {
func (mach *OlmMachine) saveAccount() {
err := mach.CryptoStore.PutAccount(mach.account)
if err != nil {
mach.Log.Error("Failed to save account: %v", err)
mach.Log.Error().Err(err).Msg("Failed to save account")
}
}
@@ -136,12 +152,15 @@ func (mach *OlmMachine) FlushStore() error {
return mach.CryptoStore.Flush()
}
func (mach *OlmMachine) timeTrace(thing, trace string, expectedDuration time.Duration) func() {
func (mach *OlmMachine) timeTrace(ctx context.Context, thing string, expectedDuration time.Duration) func() {
start := time.Now()
return func() {
duration := time.Now().Sub(start)
if duration > expectedDuration {
mach.Log.Warn("%s took %s (trace: %s)", thing, duration, trace)
zerolog.Ctx(ctx).Warn().
Str("action", thing).
Dur("duration", duration).
Msg("Executing encryption function took longer than expected")
}
}
}
@@ -156,7 +175,11 @@ func (mach *OlmMachine) Fingerprint() string {
return mach.account.SigningKey().Fingerprint()
}
// OwnIdentity returns this device's DeviceIdentity struct
func (mach *OlmMachine) GetAccount() *OlmAccount {
return mach.account
}
// OwnIdentity returns this device's id.Device struct
func (mach *OlmMachine) OwnIdentity() *id.Device {
return &id.Device{
UserID: mach.Client.UserID,
@@ -168,11 +191,18 @@ func (mach *OlmMachine) OwnIdentity() *id.Device {
}
}
func (mach *OlmMachine) AddAppserviceListener(ep *appservice.EventProcessor, az *appservice.AppService) {
type asEventProcessor interface {
On(evtType event.Type, handler func(evt *event.Event))
OnOTK(func(otk *mautrix.OTKCount))
OnDeviceList(func(lists *mautrix.DeviceLists, since string))
}
func (mach *OlmMachine) AddAppserviceListener(ep asEventProcessor) {
// ToDeviceForwardedRoomKey and ToDeviceRoomKey should only be present inside encrypted to-device events
ep.On(event.ToDeviceEncrypted, mach.HandleToDeviceEvent)
ep.On(event.ToDeviceRoomKeyRequest, mach.HandleToDeviceEvent)
ep.On(event.ToDeviceRoomKeyWithheld, mach.HandleToDeviceEvent)
ep.On(event.ToDeviceBeeperRoomKeyAck, mach.HandleToDeviceEvent)
ep.On(event.ToDeviceOrgMatrixRoomKeyWithheld, mach.HandleToDeviceEvent)
ep.On(event.ToDeviceVerificationRequest, mach.HandleToDeviceEvent)
ep.On(event.ToDeviceVerificationStart, mach.HandleToDeviceEvent)
@@ -182,34 +212,44 @@ func (mach *OlmMachine) AddAppserviceListener(ep *appservice.EventProcessor, az
ep.On(event.ToDeviceVerificationCancel, mach.HandleToDeviceEvent)
ep.OnOTK(mach.HandleOTKCounts)
ep.OnDeviceList(mach.HandleDeviceLists)
mach.Log.Trace("Added listeners for encryption data coming from appservice transactions")
mach.Log.Debug().Msg("Added listeners for encryption data coming from appservice transactions")
}
func (mach *OlmMachine) HandleDeviceLists(dl *mautrix.DeviceLists, since string) {
if len(dl.Changed) > 0 {
traceID := time.Now().Format("15:04:05.000000")
mach.Log.Trace("Device list changes in /sync: %v (trace: %s)", dl.Changed, traceID)
mach.fetchKeys(dl.Changed, since, false)
mach.Log.Trace("Finished handling device list changes (trace: %s)", traceID)
mach.Log.Debug().
Str("trace_id", traceID).
Interface("changes", dl.Changed).
Msg("Device list changes in /sync")
mach.fetchKeys(context.TODO(), dl.Changed, since, false)
mach.Log.Debug().Str("trace_id", traceID).Msg("Finished handling device list changes")
}
}
func (mach *OlmMachine) HandleOTKCounts(otkCount *mautrix.OTKCount) {
if (len(otkCount.UserID) > 0 && otkCount.UserID != mach.Client.UserID) || (len(otkCount.DeviceID) > 0 && otkCount.DeviceID != mach.Client.DeviceID) {
// TODO This log probably needs to be silence-able if someone wants to use encrypted appservices with multiple e2ee sessions
mach.Log.Debug("Dropping OTK counts targeted to %s/%s (not us)", otkCount.UserID, otkCount.DeviceID)
mach.Log.Warn().
Str("target_user_id", otkCount.UserID.String()).
Str("target_device_id", otkCount.DeviceID.String()).
Msg("Dropping OTK counts targeted to someone else")
return
}
minCount := mach.account.Internal.MaxNumberOfOneTimeKeys() / 2
if otkCount.SignedCurve25519 < int(minCount) {
traceID := time.Now().Format("15:04:05.000000")
mach.Log.Debug("Sync response said we have %d signed curve25519 keys left, sharing new ones... (trace: %s)", otkCount.SignedCurve25519, traceID)
err := mach.ShareKeys(otkCount.SignedCurve25519)
log := mach.Log.With().Str("trace_id", traceID).Logger()
ctx := log.WithContext(context.Background())
log.Debug().
Int("keys_left", otkCount.Curve25519).
Msg("Sync response said we have less than 50 signed curve25519 keys left, sharing new ones...")
err := mach.ShareKeys(ctx, otkCount.SignedCurve25519)
if err != nil {
mach.Log.Error("Failed to share keys: %v (trace: %s)", err, traceID)
log.Error().Err(err).Msg("Failed to share keys")
} else {
mach.Log.Debug("Successfully shared keys (trace: %s)", traceID)
log.Debug().Msg("Successfully shared keys")
}
}
}
@@ -218,7 +258,7 @@ func (mach *OlmMachine) HandleOTKCounts(otkCount *mautrix.OTKCount) {
//
// This can be easily registered into a mautrix client using .OnSync():
//
// client.Syncer.(*mautrix.DefaultSyncer).OnSync(c.crypto.ProcessSyncResponse)
// client.Syncer.(mautrix.ExtensibleSyncer).OnSync(c.crypto.ProcessSyncResponse)
func (mach *OlmMachine) ProcessSyncResponse(resp *mautrix.RespSync, since string) bool {
mach.HandleDeviceLists(&resp.DeviceLists, since)
@@ -226,7 +266,7 @@ func (mach *OlmMachine) ProcessSyncResponse(resp *mautrix.RespSync, since string
evt.Type.Class = event.ToDeviceEventType
err := evt.Content.ParseRaw(evt.Type)
if err != nil {
mach.Log.Warn("Failed to parse to-device event of type %s: %v", evt.Type.Type, err)
mach.Log.Warn().Str("event_type", evt.Type.Type).Err(err).Msg("Failed to parse to-device event")
continue
}
mach.HandleToDeviceEvent(evt)
@@ -240,8 +280,8 @@ func (mach *OlmMachine) ProcessSyncResponse(resp *mautrix.RespSync, since string
//
// Currently this is not automatically called, so you must add a listener yourself:
//
// client.Syncer.(*mautrix.DefaultSyncer).OnEventType(event.StateMember, c.crypto.HandleMemberEvent)
func (mach *OlmMachine) HandleMemberEvent(evt *event.Event) {
// client.Syncer.(mautrix.ExtensibleSyncer).OnEventType(event.StateMember, c.crypto.HandleMemberEvent)
func (mach *OlmMachine) HandleMemberEvent(_ mautrix.EventSource, evt *event.Event) {
if !mach.StateStore.IsEncrypted(evt.RoomID) {
return
}
@@ -263,10 +303,15 @@ func (mach *OlmMachine) HandleMemberEvent(evt *event.Event) {
(prevContent.Membership == event.MembershipLeave && content.Membership == event.MembershipBan) {
return
}
mach.Log.Trace("Got membership state event in %s changing %s from %s to %s, invalidating group session", evt.RoomID, evt.GetStateKey(), prevContent.Membership, content.Membership)
mach.Log.Trace().
Str("room_id", evt.RoomID.String()).
Str("user_id", evt.GetStateKey()).
Str("prev_membership", string(prevContent.Membership)).
Str("new_membership", string(content.Membership)).
Msg("Got membership state change, invalidating group session in room")
err := mach.CryptoStore.RemoveOutboundGroupSession(evt.RoomID)
if err != nil {
mach.Log.Warn("Failed to invalidate outbound group session of %s: %v", evt.RoomID, err)
mach.Log.Warn().Str("room_id", evt.RoomID.String()).Msg("Failed to invalidate outbound group session")
}
}
@@ -275,43 +320,63 @@ func (mach *OlmMachine) HandleMemberEvent(evt *event.Event) {
func (mach *OlmMachine) HandleToDeviceEvent(evt *event.Event) {
if len(evt.ToUserID) > 0 && (evt.ToUserID != mach.Client.UserID || evt.ToDeviceID != mach.Client.DeviceID) {
// TODO This log probably needs to be silence-able if someone wants to use encrypted appservices with multiple e2ee sessions
mach.Log.Debug("Dropping to-device event targeted to %s/%s (not us)", evt.ToUserID, evt.ToDeviceID)
mach.Log.Debug().
Str("target_user_id", evt.ToUserID.String()).
Str("target_device_id", evt.ToDeviceID.String()).
Msg("Dropping to-device event targeted to someone else")
return
}
traceID := time.Now().Format("15:04:05.000000")
log := mach.Log.With().
Str("trace_id", traceID).
Str("sender", evt.Sender.String()).
Str("type", evt.Type.Type).
Logger()
ctx := log.WithContext(context.Background())
if evt.Type != event.ToDeviceEncrypted {
mach.Log.Trace("Starting handling to-device event of type %s from %s (trace: %s)", evt.Type.Type, evt.Sender, traceID)
log.Debug().Msg("Starting handling to-device event")
}
switch content := evt.Content.Parsed.(type) {
case *event.EncryptedEventContent:
mach.Log.Debug("Handling encrypted to-device event from %s/%s (trace: %s)", evt.Sender, content.SenderKey, traceID)
decryptedEvt, err := mach.decryptOlmEvent(evt, traceID)
log = log.With().
Str("sender_key", content.SenderKey.String()).
Logger()
log.Debug().Msg("Handling encrypted to-device event")
ctx = log.WithContext(context.Background())
decryptedEvt, err := mach.decryptOlmEvent(ctx, evt)
if err != nil {
mach.Log.Error("Failed to decrypt to-device event: %v (trace: %s)", err, traceID)
log.Error().Err(err).Msg("Failed to decrypt to-device event")
return
}
mach.Log.Trace("Successfully decrypted to-device from %s/%s into type %s (sender key: %s, trace: %s)", decryptedEvt.Sender, decryptedEvt.SenderDevice, decryptedEvt.Type.String(), decryptedEvt.SenderKey, traceID)
log = log.With().
Str("decrypted_type", decryptedEvt.Type.Type).
Str("sender_device", decryptedEvt.SenderDevice.String()).
Str("sender_signing_key", decryptedEvt.Keys.Ed25519.String()).
Logger()
log.Trace().Msg("Successfully decrypted to-device event")
switch decryptedContent := decryptedEvt.Content.Parsed.(type) {
case *event.RoomKeyEventContent:
mach.receiveRoomKey(decryptedEvt, decryptedContent, traceID)
mach.Log.Trace("Handled room key event from %s/%s (trace: %s)", decryptedEvt.Sender, decryptedEvt.SenderDevice, traceID)
mach.receiveRoomKey(ctx, decryptedEvt, decryptedContent)
log.Trace().Msg("Handled room key event")
case *event.ForwardedRoomKeyEventContent:
if mach.importForwardedRoomKey(decryptedEvt, decryptedContent) {
if mach.importForwardedRoomKey(ctx, decryptedEvt, decryptedContent) {
if ch, ok := mach.roomKeyRequestFilled.Load(decryptedContent.SessionID); ok {
// close channel to notify listener that the key was received
close(ch.(chan struct{}))
}
}
mach.Log.Trace("Handled forwarded room key event from %s/%s (trace: %s)", decryptedEvt.Sender, decryptedEvt.SenderDevice, traceID)
log.Trace().Msg("Handled forwarded room key event")
case *event.DummyEventContent:
mach.Log.Debug("Received encrypted dummy event from %s/%s (trace: %s)", decryptedEvt.Sender, decryptedEvt.SenderDevice, traceID)
log.Debug().Msg("Received encrypted dummy event")
default:
mach.Log.Debug("Unhandled encrypted to-device event of type %s from %s/%s (trace: %s)", decryptedEvt.Type.String(), decryptedEvt.Sender, decryptedEvt.SenderDevice, traceID)
log.Debug().Msg("Unhandled encrypted to-device event")
}
return
case *event.RoomKeyRequestEventContent:
go mach.handleRoomKeyRequest(evt.Sender, content)
go mach.handleRoomKeyRequest(ctx, evt.Sender, content)
case *event.BeeperRoomKeyAckEventContent:
mach.handleBeeperRoomKeyAck(ctx, evt.Sender, content)
// verification cases
case *event.VerificationStartEventContent:
mach.handleVerificationStart(evt.Sender, content, content.TransactionID, 10*time.Minute, "")
@@ -326,27 +391,25 @@ func (mach *OlmMachine) HandleToDeviceEvent(evt *event.Event) {
case *event.VerificationRequestEventContent:
mach.handleVerificationRequest(evt.Sender, content, content.TransactionID, "")
case *event.RoomKeyWithheldEventContent:
mach.handleRoomKeyWithheld(content)
mach.handleRoomKeyWithheld(ctx, content)
default:
deviceID, _ := evt.Content.Raw["device_id"].(string)
mach.Log.Trace("Unhandled to-device event of type %s from %s/%s (trace: %s)", evt.Type.Type, evt.Sender, deviceID, traceID)
log.Debug().Str("maybe_device_id", deviceID).Msg("Unhandled to-device event")
return
}
mach.Log.Trace("Finished handling to-device event of type %s from %s (trace: %s)", evt.Type.Type, evt.Sender, traceID)
log.Debug().Msg("Finished handling to-device event")
}
// GetOrFetchDevice attempts to retrieve the device identity for the given device from the store
// and if it's not found it asks the server for it.
func (mach *OlmMachine) GetOrFetchDevice(userID id.UserID, deviceID id.DeviceID) (*id.Device, error) {
// get device identity
func (mach *OlmMachine) GetOrFetchDevice(ctx context.Context, userID id.UserID, deviceID id.DeviceID) (*id.Device, error) {
device, err := mach.CryptoStore.GetDevice(userID, deviceID)
if err != nil {
return nil, fmt.Errorf("failed to get sender device from store: %w", err)
} else if device != nil {
return device, nil
}
// try to fetch if not found
usersToDevices := mach.fetchKeys([]id.UserID{userID}, "", true)
usersToDevices := mach.fetchKeys(ctx, []id.UserID{userID}, "", true)
if devices, ok := usersToDevices[userID]; ok {
if device, ok = devices[deviceID]; ok {
return device, nil
@@ -359,12 +422,15 @@ func (mach *OlmMachine) GetOrFetchDevice(userID id.UserID, deviceID id.DeviceID)
// GetOrFetchDeviceByKey attempts to retrieve the device identity for the device with the given identity key from the
// store and if it's not found it asks the server for it. This returns nil if the server doesn't return a device with
// the given identity key.
func (mach *OlmMachine) GetOrFetchDeviceByKey(userID id.UserID, identityKey id.IdentityKey) (*id.Device, error) {
func (mach *OlmMachine) GetOrFetchDeviceByKey(ctx context.Context, userID id.UserID, identityKey id.IdentityKey) (*id.Device, error) {
deviceIdentity, err := mach.CryptoStore.FindDeviceByKey(userID, identityKey)
if err != nil || deviceIdentity != nil {
return deviceIdentity, err
}
mach.Log.Debug("Didn't find identity of %s/%s in crypto store, fetching from server", userID, identityKey)
mach.machOrContextLog(ctx).Debug().
Str("user_id", userID.String()).
Str("identity_key", identityKey.String()).
Msg("Didn't find identity in crypto store, fetching from server")
devices := mach.LoadDevices(userID)
for _, device := range devices {
if device.IdentityKey == identityKey {
@@ -375,8 +441,8 @@ func (mach *OlmMachine) GetOrFetchDeviceByKey(userID id.UserID, identityKey id.I
}
// SendEncryptedToDevice sends an Olm-encrypted event to the given user device.
func (mach *OlmMachine) SendEncryptedToDevice(device *id.Device, evtType event.Type, content event.Content) error {
if err := mach.createOutboundSessions(map[id.UserID]map[id.DeviceID]*id.Device{
func (mach *OlmMachine) SendEncryptedToDevice(ctx context.Context, device *id.Device, evtType event.Type, content event.Content) error {
if err := mach.createOutboundSessions(ctx, map[id.UserID]map[id.DeviceID]*id.Device{
device.UserID: {
device.DeviceID: device,
},
@@ -395,10 +461,16 @@ func (mach *OlmMachine) SendEncryptedToDevice(device *id.Device, evtType event.T
return fmt.Errorf("didn't find created outbound session for device %s of %s", device.DeviceID, device.UserID)
}
encrypted := mach.encryptOlmEvent(olmSess, device, evtType, content)
encrypted := mach.encryptOlmEvent(ctx, olmSess, device, evtType, content)
encryptedContent := &event.Content{Parsed: &encrypted}
mach.Log.Debug("Sending encrypted to-device event of type %s to %s/%s (identity key: %s, olm session ID: %s)", evtType.Type, device.UserID, device.DeviceID, device.IdentityKey, olmSess.ID())
mach.machOrContextLog(ctx).Debug().
Str("decrypted_type", evtType.Type).
Str("to_user_id", device.UserID.String()).
Str("to_device_id", device.DeviceID.String()).
Str("to_identity_key", device.IdentityKey.String()).
Str("olm_session_id", olmSess.ID().String()).
Msg("Sending encrypted to-device event")
_, err = mach.Client.SendToDevice(event.ToDeviceEncrypted,
&mautrix.ReqSendToDevice{
Messages: map[id.UserID]map[id.DeviceID]*event.Content{
@@ -412,22 +484,32 @@ func (mach *OlmMachine) SendEncryptedToDevice(device *id.Device, evtType event.T
return err
}
func (mach *OlmMachine) createGroupSession(senderKey id.SenderKey, signingKey id.Ed25519, roomID id.RoomID, sessionID id.SessionID, sessionKey string, traceID string) {
igs, err := NewInboundGroupSession(senderKey, signingKey, roomID, sessionKey)
func (mach *OlmMachine) createGroupSession(ctx context.Context, senderKey id.SenderKey, signingKey id.Ed25519, roomID id.RoomID, sessionID id.SessionID, sessionKey string, maxAge time.Duration, maxMessages int, isScheduled bool) {
log := zerolog.Ctx(ctx)
igs, err := NewInboundGroupSession(senderKey, signingKey, roomID, sessionKey, maxAge, maxMessages, isScheduled)
if err != nil {
mach.Log.Error("Failed to create inbound group session: %v", err)
log.Error().Err(err).Msg("Failed to create inbound group session")
return
} else if igs.ID() != sessionID {
mach.Log.Warn("Mismatched session ID while creating inbound group session")
log.Warn().
Str("expected_session_id", sessionID.String()).
Str("actual_session_id", igs.ID().String()).
Msg("Mismatched session ID while creating inbound group session")
return
}
err = mach.CryptoStore.PutGroupSession(roomID, senderKey, sessionID, igs)
if err != nil {
mach.Log.Error("Failed to store new inbound group session: %v", err)
log.Error().Err(err).Str("session_id", sessionID.String()).Msg("Failed to store new inbound group session")
return
}
mach.markSessionReceived(sessionID)
mach.Log.Debug("Received inbound group session %s / %s / %s", roomID, senderKey, sessionID)
log.Debug().
Str("session_id", sessionID.String()).
Str("sender_key", senderKey.String()).
Str("max_age", maxAge.String()).
Int("max_messages", maxMessages).
Bool("is_scheduled", isScheduled).
Msg("Received inbound group session")
}
func (mach *OlmMachine) markSessionReceived(id id.SessionID) {
@@ -465,24 +547,66 @@ func (mach *OlmMachine) WaitForSession(roomID id.RoomID, senderKey id.SenderKey,
}
}
func (mach *OlmMachine) receiveRoomKey(evt *DecryptedOlmEvent, content *event.RoomKeyEventContent, traceID string) {
// TODO nio had a comment saying "handle this better" for the case where evt.Keys.Ed25519 is none?
func stringifyArray[T ~string](arr []T) []string {
strs := make([]string, len(arr))
for i, v := range arr {
strs[i] = string(v)
}
return strs
}
func (mach *OlmMachine) receiveRoomKey(ctx context.Context, evt *DecryptedOlmEvent, content *event.RoomKeyEventContent) {
log := zerolog.Ctx(ctx).With().
Str("algorithm", string(content.Algorithm)).
Str("session_id", content.SessionID.String()).
Str("room_id", content.RoomID.String()).
Logger()
if content.Algorithm != id.AlgorithmMegolmV1 || evt.Keys.Ed25519 == "" {
mach.Log.Debug("Ignoring weird room key from %s/%s: alg=%s, ed25519=%s, sessionid=%s, roomid=%s", evt.Sender, evt.SenderDevice, content.Algorithm, evt.Keys.Ed25519, content.SessionID, content.RoomID)
log.Debug().Msg("Ignoring weird room key")
return
}
mach.createGroupSession(evt.SenderKey, evt.Keys.Ed25519, content.RoomID, content.SessionID, content.SessionKey, traceID)
config := mach.StateStore.GetEncryptionEvent(content.RoomID)
var maxAge time.Duration
var maxMessages int
if config != nil {
maxAge = time.Duration(config.RotationPeriodMillis) * time.Millisecond
if maxAge == 0 {
maxAge = 7 * 24 * time.Hour
}
maxMessages = config.RotationPeriodMessages
if maxMessages == 0 {
maxMessages = 100
}
}
if content.MaxAge != 0 {
maxAge = time.Duration(content.MaxAge) * time.Millisecond
}
if content.MaxMessages != 0 {
maxMessages = content.MaxMessages
}
if mach.DeletePreviousKeysOnReceive && !content.IsScheduled {
log.Debug().Msg("Redacting previous megolm sessions from sender in room")
sessionIDs, err := mach.CryptoStore.RedactGroupSessions(content.RoomID, evt.SenderKey, "received new key from device")
if err != nil {
log.Err(err).Msg("Failed to redact previous megolm sessions")
} else {
log.Info().
Strs("session_ids", stringifyArray(sessionIDs)).
Msg("Redacted previous megolm sessions")
}
}
mach.createGroupSession(ctx, evt.SenderKey, evt.Keys.Ed25519, content.RoomID, content.SessionID, content.SessionKey, maxAge, maxMessages, content.IsScheduled)
}
func (mach *OlmMachine) handleRoomKeyWithheld(content *event.RoomKeyWithheldEventContent) {
func (mach *OlmMachine) handleRoomKeyWithheld(ctx context.Context, content *event.RoomKeyWithheldEventContent) {
if content.Algorithm != id.AlgorithmMegolmV1 {
mach.Log.Debug("Non-megolm room key withheld event: %+v", content)
zerolog.Ctx(ctx).Debug().Interface("content", content).Msg("Non-megolm room key withheld event")
return
}
err := mach.CryptoStore.PutWithheldGroupSession(*content)
if err != nil {
mach.Log.Error("Failed to save room key withheld event: %v", err)
zerolog.Ctx(ctx).Error().Err(err).Msg("Failed to save room key withheld event")
}
}
@@ -491,34 +615,38 @@ func (mach *OlmMachine) handleRoomKeyWithheld(content *event.RoomKeyWithheldEven
// If the Olm account hasn't been shared, the account keys will be uploaded.
// If currentOTKCount is less than half of the limit (100 / 2 = 50), enough one-time keys will be uploaded so exactly
// half of the limit is filled.
func (mach *OlmMachine) ShareKeys(currentOTKCount int) error {
func (mach *OlmMachine) ShareKeys(ctx context.Context, currentOTKCount int) error {
log := mach.machOrContextLog(ctx)
start := time.Now()
mach.otkUploadLock.Lock()
defer mach.otkUploadLock.Unlock()
if mach.lastOTKUpload.Add(1 * time.Minute).After(start) {
mach.Log.Trace("Checking OTK count from server due to suspiciously close share keys requests")
log.Debug().Msg("Checking OTK count from server due to suspiciously close share keys requests")
resp, err := mach.Client.UploadKeys(&mautrix.ReqUploadKeys{})
if err != nil {
return fmt.Errorf("failed to check current OTK counts: %w", err)
}
mach.Log.Trace("Fetched current OTK count (%d) from server (input count was %d)", resp.OneTimeKeyCounts.SignedCurve25519, currentOTKCount)
log.Debug().
Int("input_count", currentOTKCount).
Int("server_count", resp.OneTimeKeyCounts.SignedCurve25519).
Msg("Fetched current OTK count from server")
currentOTKCount = resp.OneTimeKeyCounts.SignedCurve25519
}
var deviceKeys *mautrix.DeviceKeys
if !mach.account.Shared {
deviceKeys = mach.account.getInitialKeys(mach.Client.UserID, mach.Client.DeviceID)
mach.Log.Trace("Going to upload initial account keys")
log.Debug().Msg("Going to upload initial account keys")
}
oneTimeKeys := mach.account.getOneTimeKeys(mach.Client.UserID, mach.Client.DeviceID, currentOTKCount)
if len(oneTimeKeys) == 0 && deviceKeys == nil {
mach.Log.Trace("No one-time keys nor device keys got when trying to share keys")
log.Debug().Msg("No one-time keys nor device keys got when trying to share keys")
return nil
}
req := &mautrix.ReqUploadKeys{
DeviceKeys: deviceKeys,
OneTimeKeys: oneTimeKeys,
}
mach.Log.Trace("Uploading %d one-time keys", len(oneTimeKeys))
log.Debug().Int("count", len(oneTimeKeys)).Msg("Uploading one-time keys")
_, err := mach.Client.UploadKeys(req)
if err != nil {
return err
@@ -528,3 +656,23 @@ func (mach *OlmMachine) ShareKeys(currentOTKCount int) error {
mach.saveAccount()
return nil
}
func (mach *OlmMachine) ExpiredKeyDeleteLoop(ctx context.Context) {
log := mach.Log.With().Str("action", "redact expired sessions").Logger()
for {
sessionIDs, err := mach.CryptoStore.RedactExpiredGroupSessions()
if err != nil {
log.Err(err).Msg("Failed to redact expired megolm sessions")
} else if len(sessionIDs) > 0 {
log.Info().Strs("session_ids", stringifyArray(sessionIDs)).Msg("Redacted expired megolm sessions")
} else {
log.Debug().Msg("Didn't find any expired megolm sessions")
}
select {
case <-ctx.Done():
log.Debug().Msg("Loop stopped")
return
case <-time.After(24 * time.Hour):
}
}
}

View File

@@ -288,7 +288,7 @@ func (s *InboundGroupSession) exportLen() uint {
// if we do not have a session key corresponding to the given index (ie, it was
// sent before the session key was shared with us) the error will be
// "OLM_UNKNOWN_MESSAGE_INDEX".
func (s *InboundGroupSession) Export(messageIndex uint32) (string, error) {
func (s *InboundGroupSession) Export(messageIndex uint32) ([]byte, error) {
key := make([]byte, s.exportLen())
r := C.olm_export_inbound_group_session(
(*C.OlmInboundGroupSession)(s.int),
@@ -296,7 +296,7 @@ func (s *InboundGroupSession) Export(messageIndex uint32) (string, error) {
C.size_t(len(key)),
C.uint32_t(messageIndex))
if r == errorVal() {
return "", s.lastError()
return nil, s.lastError()
}
return string(key[:r]), nil
return key[:r], nil
}

View File

@@ -1,6 +1,3 @@
//go:build !nosas
// +build !nosas
package olm
// #cgo LDFLAGS: -lolm -lstdc++

View File

@@ -89,6 +89,12 @@ func (session *OlmSession) Decrypt(ciphertext string, msgType id.OlmMsgType) ([]
return msg, err
}
type RatchetSafety struct {
NextIndex uint `json:"next_index"`
MissedIndices []uint `json:"missed_indices,omitempty"`
LostIndices []uint `json:"lost_indices,omitempty"`
}
type InboundGroupSession struct {
Internal olm.InboundGroupSession
@@ -97,11 +103,17 @@ type InboundGroupSession struct {
RoomID id.RoomID
ForwardingChains []string
RatchetSafety RatchetSafety
ReceivedAt time.Time
MaxAge int64
MaxMessages int
IsScheduled bool
id id.SessionID
}
func NewInboundGroupSession(senderKey id.SenderKey, signingKey id.Ed25519, roomID id.RoomID, sessionKey string) (*InboundGroupSession, error) {
func NewInboundGroupSession(senderKey id.SenderKey, signingKey id.Ed25519, roomID id.RoomID, sessionKey string, maxAge time.Duration, maxMessages int, isScheduled bool) (*InboundGroupSession, error) {
igs, err := olm.NewInboundGroupSession([]byte(sessionKey))
if err != nil {
return nil, err
@@ -112,6 +124,10 @@ func NewInboundGroupSession(senderKey id.SenderKey, signingKey id.Ed25519, roomI
SenderKey: senderKey,
RoomID: roomID,
ForwardingChains: nil,
ReceivedAt: time.Now().UTC(),
MaxAge: maxAge.Milliseconds(),
MaxMessages: maxMessages,
IsScheduled: isScheduled,
}, nil
}
@@ -122,6 +138,19 @@ func (igs *InboundGroupSession) ID() id.SessionID {
return igs.id
}
func (igs *InboundGroupSession) RatchetTo(index uint32) error {
exported, err := igs.Internal.Export(index)
if err != nil {
return err
}
imported, err := olm.InboundGroupSessionImport(exported)
if err != nil {
return err
}
igs.Internal = *imported
return nil
}
type OGSState int
const (

View File

@@ -7,13 +7,19 @@
package crypto
import (
"context"
"database/sql"
"database/sql/driver"
"encoding/json"
"errors"
"fmt"
"strings"
"sync"
"time"
"github.com/rs/zerolog"
"maunium.net/go/mautrix"
"maunium.net/go/mautrix/crypto/olm"
"maunium.net/go/mautrix/crypto/sql_store_upgrade"
"maunium.net/go/mautrix/event"
@@ -80,6 +86,35 @@ func (store *SQLCryptoStore) GetNextBatch() (string, error) {
return store.SyncToken, nil
}
var _ mautrix.SyncStore = (*SQLCryptoStore)(nil)
func (store *SQLCryptoStore) SaveFilterID(_ id.UserID, _ string) {}
func (store *SQLCryptoStore) LoadFilterID(_ id.UserID) string { return "" }
func (store *SQLCryptoStore) SaveNextBatch(_ id.UserID, nextBatchToken string) {
err := store.PutNextBatch(nextBatchToken)
if err != nil {
// TODO handle error
}
}
func (store *SQLCryptoStore) LoadNextBatch(_ id.UserID) string {
nb, err := store.GetNextBatch()
if err != nil {
// TODO handle error
}
return nb
}
func (store *SQLCryptoStore) FindDeviceID() (deviceID id.DeviceID) {
err := store.DB.QueryRow("SELECT device_id FROM crypto_account WHERE account_id=$1", store.AccountID).Scan(&deviceID)
if err != nil && !errors.Is(err, sql.ErrNoRows) {
// TODO return error
store.DB.Log.Warn("Failed to scan device ID: %v", err)
}
return
}
// PutAccount stores an OlmAccount in the database.
func (store *SQLCryptoStore) PutAccount(account *OlmAccount) error {
store.Account = account
@@ -220,37 +255,72 @@ func (store *SQLCryptoStore) UpdateSession(_ id.SenderKey, session *OlmSession)
return err
}
func intishPtr[T int | int64](i T) *T {
if i == 0 {
return nil
}
return &i
}
func datePtr(t time.Time) *time.Time {
if t.IsZero() {
return nil
}
return &t
}
// PutGroupSession stores an inbound Megolm group session for a room, sender and session.
func (store *SQLCryptoStore) PutGroupSession(roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID, session *InboundGroupSession) error {
sessionBytes := session.Internal.Pickle(store.PickleKey)
forwardingChains := strings.Join(session.ForwardingChains, ",")
_, err := store.DB.Exec(`
INSERT INTO crypto_megolm_inbound_session
(session_id, sender_key, signing_key, room_id, session, forwarding_chains, account_id)
VALUES ($1, $2, $3, $4, $5, $6, $7)
ratchetSafety, err := json.Marshal(&session.RatchetSafety)
if err != nil {
return fmt.Errorf("failed to marshal ratchet safety info: %w", err)
}
_, err = store.DB.Exec(`
INSERT INTO crypto_megolm_inbound_session (
session_id, sender_key, signing_key, room_id, session, forwarding_chains,
ratchet_safety, received_at, max_age, max_messages, is_scheduled, account_id
) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12)
ON CONFLICT (session_id, account_id) DO UPDATE
SET withheld_code=NULL, withheld_reason=NULL, sender_key=excluded.sender_key, signing_key=excluded.signing_key,
room_id=excluded.room_id, session=excluded.session, forwarding_chains=excluded.forwarding_chains
`, sessionID, senderKey, session.SigningKey, roomID, sessionBytes, forwardingChains, store.AccountID)
room_id=excluded.room_id, session=excluded.session, forwarding_chains=excluded.forwarding_chains,
ratchet_safety=excluded.ratchet_safety, received_at=excluded.received_at,
max_age=excluded.max_age, max_messages=excluded.max_messages, is_scheduled=excluded.is_scheduled
`,
sessionID, senderKey, session.SigningKey, roomID, sessionBytes, forwardingChains,
ratchetSafety, datePtr(session.ReceivedAt), intishPtr(session.MaxAge), intishPtr(session.MaxMessages),
session.IsScheduled, store.AccountID,
)
return err
}
// GetGroupSession retrieves an inbound Megolm group session for a room, sender and session.
func (store *SQLCryptoStore) GetGroupSession(roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID) (*InboundGroupSession, error) {
var signingKey, forwardingChains, withheldCode sql.NullString
var sessionBytes []byte
var senderKeyDB, signingKey, forwardingChains, withheldCode, withheldReason sql.NullString
var sessionBytes, ratchetSafetyBytes []byte
var receivedAt sql.NullTime
var maxAge, maxMessages sql.NullInt64
var isScheduled bool
err := store.DB.QueryRow(`
SELECT signing_key, session, forwarding_chains, withheld_code
SELECT sender_key, signing_key, session, forwarding_chains, withheld_code, withheld_reason, ratchet_safety, received_at, max_age, max_messages, is_scheduled
FROM crypto_megolm_inbound_session
WHERE room_id=$1 AND sender_key=$2 AND session_id=$3 AND account_id=$4`,
WHERE room_id=$1 AND (sender_key=$2 OR $2 = '') AND session_id=$3 AND account_id=$4`,
roomID, senderKey, sessionID, store.AccountID,
).Scan(&signingKey, &sessionBytes, &forwardingChains, &withheldCode)
).Scan(&senderKeyDB, &signingKey, &sessionBytes, &forwardingChains, &withheldCode, &withheldReason, &ratchetSafetyBytes, &receivedAt, &maxAge, &maxMessages, &isScheduled)
if err == sql.ErrNoRows {
return nil, nil
} else if err != nil {
return nil, err
} else if withheldCode.Valid {
return nil, fmt.Errorf("%w (%s)", ErrGroupSessionWithheld, withheldCode.String)
return nil, &event.RoomKeyWithheldEventContent{
RoomID: roomID,
Algorithm: id.AlgorithmMegolmV1,
SessionID: sessionID,
SenderKey: senderKey,
Code: event.RoomKeyWithheldCode(withheldCode.String),
Reason: withheldReason.String,
}
}
igs := olm.NewBlankInboundGroupSession()
err = igs.Unpickle(sessionBytes, store.PickleKey)
@@ -261,18 +331,96 @@ func (store *SQLCryptoStore) GetGroupSession(roomID id.RoomID, senderKey id.Send
if forwardingChains.String != "" {
chains = strings.Split(forwardingChains.String, ",")
}
var rs RatchetSafety
if len(ratchetSafetyBytes) > 0 {
err = json.Unmarshal(ratchetSafetyBytes, &rs)
if err != nil {
return nil, fmt.Errorf("failed to unmarshal ratchet safety info: %w", err)
}
}
if senderKey == "" {
senderKey = id.Curve25519(senderKeyDB.String)
}
return &InboundGroupSession{
Internal: *igs,
SigningKey: id.Ed25519(signingKey.String),
SenderKey: senderKey,
RoomID: roomID,
ForwardingChains: chains,
RatchetSafety: rs,
ReceivedAt: receivedAt.Time,
MaxAge: maxAge.Int64,
MaxMessages: int(maxMessages.Int64),
IsScheduled: isScheduled,
}, nil
}
func (store *SQLCryptoStore) RedactGroupSession(_ id.RoomID, _ id.SenderKey, sessionID id.SessionID, reason string) error {
_, err := store.DB.Exec(`
UPDATE crypto_megolm_inbound_session
SET withheld_code=$1, withheld_reason=$2, session=NULL, forwarding_chains=NULL
WHERE session_id=$3 AND account_id=$4 AND session IS NOT NULL
`, event.RoomKeyWithheldBeeperRedacted, "Session redacted: "+reason, sessionID, store.AccountID)
return err
}
func (store *SQLCryptoStore) RedactGroupSessions(roomID id.RoomID, senderKey id.SenderKey, reason string) ([]id.SessionID, error) {
if roomID == "" && senderKey == "" {
return nil, fmt.Errorf("room ID or sender key must be provided for redacting sessions")
}
res, err := store.DB.Query(`
UPDATE crypto_megolm_inbound_session
SET withheld_code=$1, withheld_reason=$2, session=NULL, forwarding_chains=NULL
WHERE (room_id=$3 OR $3='') AND (sender_key=$4 OR $4='') AND account_id=$5
AND session IS NOT NULL AND is_scheduled=false AND received_at IS NOT NULL
RETURNING session_id
`, event.RoomKeyWithheldBeeperRedacted, "Session redacted: "+reason, roomID, senderKey, store.AccountID)
var sessionIDs []id.SessionID
for res.Next() {
var sessionID id.SessionID
_ = res.Scan(&sessionID)
sessionIDs = append(sessionIDs, sessionID)
}
return sessionIDs, err
}
func (store *SQLCryptoStore) RedactExpiredGroupSessions() ([]id.SessionID, error) {
var query string
switch store.DB.Dialect {
case dbutil.Postgres:
query = `
UPDATE crypto_megolm_inbound_session
SET withheld_code=$1, withheld_reason=$2, session=NULL, forwarding_chains=NULL
WHERE account_id=$3 AND session IS NOT NULL AND is_scheduled=false
AND received_at IS NOT NULL and max_age IS NOT NULL
AND received_at + 2 * (max_age * interval '1 millisecond') < now()
RETURNING session_id
`
case dbutil.SQLite:
query = `
UPDATE crypto_megolm_inbound_session
SET withheld_code=$1, withheld_reason=$2, session=NULL, forwarding_chains=NULL
WHERE account_id=$3 AND session IS NOT NULL AND is_scheduled=false
AND received_at IS NOT NULL and max_age IS NOT NULL
AND unixepoch(received_at) + (2 * max_age / 1000) < unixepoch(date('now'))
RETURNING session_id
`
default:
return nil, fmt.Errorf("unsupported dialect")
}
res, err := store.DB.Query(query, event.RoomKeyWithheldBeeperRedacted, "Session redacted: expired", store.AccountID)
var sessionIDs []id.SessionID
for res.Next() {
var sessionID id.SessionID
_ = res.Scan(&sessionID)
sessionIDs = append(sessionIDs, sessionID)
}
return sessionIDs, err
}
func (store *SQLCryptoStore) PutWithheldGroupSession(content event.RoomKeyWithheldEventContent) error {
_, err := store.DB.Exec("INSERT INTO crypto_megolm_inbound_session (session_id, sender_key, room_id, withheld_code, withheld_reason, account_id) VALUES ($1, $2, $3, $4, $5, $6)",
content.SessionID, content.SenderKey, content.RoomID, content.Code, content.Reason, store.AccountID)
_, err := store.DB.Exec("INSERT INTO crypto_megolm_inbound_session (session_id, sender_key, room_id, withheld_code, withheld_reason, received_at, account_id) VALUES ($1, $2, $3, $4, $5, $6, $7)",
content.SessionID, content.SenderKey, content.RoomID, content.Code, content.Reason, time.Now().UTC(), store.AccountID)
return err
}
@@ -302,8 +450,11 @@ func (store *SQLCryptoStore) scanGroupSessionList(rows dbutil.Rows) (result []*I
for rows.Next() {
var roomID id.RoomID
var signingKey, senderKey, forwardingChains sql.NullString
var sessionBytes []byte
err = rows.Scan(&roomID, &signingKey, &senderKey, &sessionBytes, &forwardingChains)
var sessionBytes, ratchetSafetyBytes []byte
var receivedAt sql.NullTime
var maxAge, maxMessages sql.NullInt64
var isScheduled bool
err = rows.Scan(&roomID, &signingKey, &senderKey, &sessionBytes, &forwardingChains, &ratchetSafetyBytes, &receivedAt, &maxAge, &maxMessages, &isScheduled)
if err != nil {
return
}
@@ -316,12 +467,24 @@ func (store *SQLCryptoStore) scanGroupSessionList(rows dbutil.Rows) (result []*I
if forwardingChains.String != "" {
chains = strings.Split(forwardingChains.String, ",")
}
var rs RatchetSafety
if len(ratchetSafetyBytes) > 0 {
err = json.Unmarshal(ratchetSafetyBytes, &rs)
if err != nil {
return nil, fmt.Errorf("failed to unmarshal ratchet safety info: %w", err)
}
}
result = append(result, &InboundGroupSession{
Internal: *igs,
SigningKey: id.Ed25519(signingKey.String),
SenderKey: id.Curve25519(senderKey.String),
RoomID: roomID,
ForwardingChains: chains,
RatchetSafety: rs,
ReceivedAt: receivedAt.Time,
MaxAge: maxAge.Int64,
MaxMessages: int(maxMessages.Int64),
IsScheduled: isScheduled,
})
}
return
@@ -329,7 +492,7 @@ func (store *SQLCryptoStore) scanGroupSessionList(rows dbutil.Rows) (result []*I
func (store *SQLCryptoStore) GetGroupSessionsForRoom(roomID id.RoomID) ([]*InboundGroupSession, error) {
rows, err := store.DB.Query(`
SELECT room_id, signing_key, sender_key, session, forwarding_chains
SELECT room_id, signing_key, sender_key, session, forwarding_chains, ratchet_safety, received_at, max_age, max_messages, is_scheduled
FROM crypto_megolm_inbound_session WHERE room_id=$1 AND account_id=$2 AND session IS NOT NULL`,
roomID, store.AccountID,
)
@@ -343,7 +506,7 @@ func (store *SQLCryptoStore) GetGroupSessionsForRoom(roomID id.RoomID) ([]*Inbou
func (store *SQLCryptoStore) GetAllGroupSessions() ([]*InboundGroupSession, error) {
rows, err := store.DB.Query(`
SELECT room_id, signing_key, sender_key, session, forwarding_chains
SELECT room_id, signing_key, sender_key, session, forwarding_chains, ratchet_safety, received_at, max_age, max_messages, is_scheduled
FROM crypto_megolm_inbound_session WHERE account_id=$2 AND session IS NOT NULL`,
store.AccountID,
)
@@ -367,7 +530,7 @@ func (store *SQLCryptoStore) AddOutboundGroupSession(session *OutboundGroupSessi
max_messages=excluded.max_messages, message_count=excluded.message_count, max_age=excluded.max_age,
created_at=excluded.created_at, last_used=excluded.last_used, account_id=excluded.account_id
`, session.RoomID, session.ID(), sessionBytes, session.Shared, session.MaxMessages, session.MessageCount,
session.MaxAge, session.CreationTime, session.LastEncryptedTime, store.AccountID)
session.MaxAge.Milliseconds(), session.CreationTime, session.LastEncryptedTime, store.AccountID)
return err
}
@@ -383,11 +546,12 @@ func (store *SQLCryptoStore) UpdateOutboundGroupSession(session *OutboundGroupSe
func (store *SQLCryptoStore) GetOutboundGroupSession(roomID id.RoomID) (*OutboundGroupSession, error) {
var ogs OutboundGroupSession
var sessionBytes []byte
var maxAgeMS int64
err := store.DB.QueryRow(`
SELECT session, shared, max_messages, message_count, max_age, created_at, last_used
FROM crypto_megolm_outbound_session WHERE room_id=$1 AND account_id=$2`,
roomID, store.AccountID,
).Scan(&sessionBytes, &ogs.Shared, &ogs.MaxMessages, &ogs.MessageCount, &ogs.MaxAge, &ogs.CreationTime, &ogs.LastEncryptedTime)
).Scan(&sessionBytes, &ogs.Shared, &ogs.MaxMessages, &ogs.MessageCount, &maxAgeMS, &ogs.CreationTime, &ogs.LastEncryptedTime)
if err == sql.ErrNoRows {
return nil, nil
} else if err != nil {
@@ -400,6 +564,7 @@ func (store *SQLCryptoStore) GetOutboundGroupSession(roomID id.RoomID) (*Outboun
}
ogs.Internal = *intOGS
ogs.RoomID = roomID
ogs.MaxAge = time.Duration(maxAgeMS) * time.Millisecond
return &ogs, nil
}
@@ -412,7 +577,7 @@ func (store *SQLCryptoStore) RemoveOutboundGroupSession(roomID id.RoomID) error
// ValidateMessageIndex returns whether the given event information match the ones stored in the database
// for the given sender key, session ID and index. If the index hasn't been stored, this will store it.
func (store *SQLCryptoStore) ValidateMessageIndex(senderKey id.SenderKey, sessionID id.SessionID, eventID id.EventID, index uint, timestamp int64) (bool, error) {
func (store *SQLCryptoStore) ValidateMessageIndex(ctx context.Context, senderKey id.SenderKey, sessionID id.SessionID, eventID id.EventID, index uint, timestamp int64) (bool, error) {
const validateQuery = `
INSERT INTO crypto_message_index (sender_key, session_id, "index", event_id, timestamp)
VALUES ($1, $2, $3, $4, $5)
@@ -422,11 +587,19 @@ func (store *SQLCryptoStore) ValidateMessageIndex(senderKey id.SenderKey, sessio
`
var expectedEventID id.EventID
var expectedTimestamp int64
err := store.DB.QueryRow(validateQuery, senderKey, sessionID, index, eventID, timestamp).Scan(&expectedEventID, &expectedTimestamp)
err := store.DB.QueryRowContext(ctx, validateQuery, senderKey, sessionID, index, eventID, timestamp).Scan(&expectedEventID, &expectedTimestamp)
if err != nil {
return false, err
} else if expectedEventID != eventID || expectedTimestamp != timestamp {
zerolog.Ctx(ctx).Debug().
Uint("message_index", index).
Str("expected_event_id", expectedEventID.String()).
Int64("expected_timestamp", expectedTimestamp).
Int64("actual_timestamp", timestamp).
Msg("Failed to validate that message index wasn't duplicated")
return false, nil
}
return expectedEventID == eventID && expectedTimestamp == timestamp, nil
return true, nil
}
// GetDevices returns a map of device IDs to device identities, including the identity and signing keys, for a given user ID.

View File

@@ -1,4 +1,4 @@
-- v0 -> v8: Latest revision
-- v0 -> v10: Latest revision
CREATE TABLE IF NOT EXISTS crypto_account (
account_id TEXT PRIMARY KEY,
device_id TEXT NOT NULL,
@@ -52,6 +52,11 @@ CREATE TABLE IF NOT EXISTS crypto_megolm_inbound_session (
forwarding_chains bytea,
withheld_code TEXT,
withheld_reason TEXT,
ratchet_safety jsonb,
received_at timestamp,
max_age BIGINT,
max_messages INTEGER,
is_scheduled BOOLEAN NOT NULL DEFAULT false,
PRIMARY KEY (account_id, session_id)
);

View File

@@ -0,0 +1,2 @@
-- v9: Change outbound megolm session max_age column to milliseconds
UPDATE crypto_megolm_outbound_session SET max_age=max_age/1000000;

View File

@@ -0,0 +1,6 @@
-- v10: Add metadata for detecting when megolm sessions are safe to delete
ALTER TABLE crypto_megolm_inbound_session ADD COLUMN ratchet_safety jsonb;
ALTER TABLE crypto_megolm_inbound_session ADD COLUMN received_at timestamp;
ALTER TABLE crypto_megolm_inbound_session ADD COLUMN max_age BIGINT;
ALTER TABLE crypto_megolm_inbound_session ADD COLUMN max_messages INTEGER;
ALTER TABLE crypto_megolm_inbound_session ADD COLUMN is_scheduled BOOLEAN NOT NULL DEFAULT false;

View File

@@ -21,7 +21,7 @@ const VersionTableName = "crypto_version"
var fs embed.FS
func init() {
Table.Register(-1, 3, "Unsupported version", false, func(tx dbutil.Execable, database *dbutil.Database) error {
Table.Register(-1, 3, 0, "Unsupported version", false, func(tx dbutil.Execable, database *dbutil.Database) error {
return fmt.Errorf("upgrading from versions 1 and 2 of the crypto store is no longer supported in mautrix-go v0.12+")
})
Table.RegisterFS(fs)

View File

@@ -7,10 +7,8 @@
package crypto
import (
"encoding/gob"
"errors"
"context"
"fmt"
"os"
"sort"
"sync"
@@ -18,10 +16,7 @@ import (
"maunium.net/go/mautrix/id"
)
// Deprecated: moved to id.Device
type DeviceIdentity = id.Device
var ErrGroupSessionWithheld = errors.New("group session has been withheld")
var ErrGroupSessionWithheld error = &event.RoomKeyWithheldEventContent{}
// Store is used by OlmMachine to store Olm and Megolm sessions, user device lists and message indices.
//
@@ -58,6 +53,12 @@ type Store interface {
// (i.e. a room key withheld event has been saved with PutWithheldGroupSession), this should return the
// ErrGroupSessionWithheld error. The caller may use GetWithheldGroupSession to find more details.
GetGroupSession(id.RoomID, id.SenderKey, id.SessionID) (*InboundGroupSession, error)
// RedactGroupSession removes the session data for the given inbound Megolm session from the store.
RedactGroupSession(id.RoomID, id.SenderKey, id.SessionID, string) error
// RedactGroupSessions removes the session data for all inbound Megolm sessions from a specific device and/or in a specific room.
RedactGroupSessions(id.RoomID, id.SenderKey, string) ([]id.SessionID, error)
// RedactExpiredGroupSessions removes the session data for all inbound Megolm sessions that have expired.
RedactExpiredGroupSessions() ([]id.SessionID, error)
// PutWithheldGroupSession tells the store that a specific Megolm session was withheld.
PutWithheldGroupSession(event.RoomKeyWithheldEventContent) error
// GetWithheldGroupSession gets the event content that was previously inserted with PutWithheldGroupSession.
@@ -90,9 +91,9 @@ type Store interface {
// * If the map key doesn't exist, the given values should be stored and this should return true.
// * If the map key exists and the stored values match the given values, this should return true.
// * If the map key exists, but the stored values do not match the given values, this should return false.
ValidateMessageIndex(senderKey id.SenderKey, sessionID id.SessionID, eventID id.EventID, index uint, timestamp int64) (bool, error)
ValidateMessageIndex(ctx context.Context, senderKey id.SenderKey, sessionID id.SessionID, eventID id.EventID, index uint, timestamp int64) (bool, error)
// GetDevices returns a map from device ID to DeviceIdentity containing all devices of a given user.
// GetDevices returns a map from device ID to id.Device struct containing all devices of a given user.
GetDevices(id.UserID) (map[id.DeviceID]*id.Device, error)
// GetDevice returns a specific device of a given user.
GetDevice(id.UserID, id.DeviceID) (*id.Device, error)
@@ -129,12 +130,12 @@ type messageIndexValue struct {
Timestamp int64
}
// GobStore is a simple Store implementation that dumps everything into a .gob file.
//
// Deprecated: this is not atomic and can lose data. Using SQLCryptoStore or a custom implementation is recommended.
type GobStore struct {
// MemoryStore is a simple in-memory Store implementation. It can optionally have a callback function for saving data,
// but the actual storage must be implemented manually.
type MemoryStore struct {
lock sync.RWMutex
path string
save func() error
Account *OlmAccount
Sessions map[id.SenderKey]OlmSessionList
@@ -147,14 +148,15 @@ type GobStore struct {
KeySignatures map[id.UserID]map[id.Ed25519]map[id.UserID]map[id.Ed25519]string
}
var _ Store = (*GobStore)(nil)
var _ Store = (*MemoryStore)(nil)
func NewMemoryStore(saveCallback func() error) *MemoryStore {
if saveCallback == nil {
saveCallback = func() error { return nil }
}
return &MemoryStore{
save: saveCallback,
// NewGobStore creates a new GobStore that saves everything to the given file.
//
// Deprecated: this is not atomic and can lose data. Using SQLCryptoStore or a custom implementation is recommended.
func NewGobStore(path string) (*GobStore, error) {
gs := &GobStore{
path: path,
Sessions: make(map[id.SenderKey]OlmSessionList),
GroupSessions: make(map[id.RoomID]map[id.SenderKey]map[id.SessionID]*InboundGroupSession),
WithheldGroupSessions: make(map[id.RoomID]map[id.SenderKey]map[id.SessionID]*event.RoomKeyWithheldEventContent),
@@ -164,44 +166,20 @@ func NewGobStore(path string) (*GobStore, error) {
CrossSigningKeys: make(map[id.UserID]map[id.CrossSigningUsage]id.CrossSigningKey),
KeySignatures: make(map[id.UserID]map[id.Ed25519]map[id.UserID]map[id.Ed25519]string),
}
return gs, gs.load()
}
func (gs *GobStore) save() error {
file, err := os.OpenFile(gs.path, os.O_CREATE|os.O_WRONLY, 0600)
if err != nil {
return err
}
err = gob.NewEncoder(file).Encode(gs)
_ = file.Close()
return err
}
func (gs *GobStore) load() error {
file, err := os.OpenFile(gs.path, os.O_RDONLY, 0600)
if err != nil {
if os.IsNotExist(err) {
return nil
}
return err
}
err = gob.NewDecoder(file).Decode(gs)
_ = file.Close()
return err
}
func (gs *GobStore) Flush() error {
func (gs *MemoryStore) Flush() error {
gs.lock.Lock()
err := gs.save()
gs.lock.Unlock()
return err
}
func (gs *GobStore) GetAccount() (*OlmAccount, error) {
func (gs *MemoryStore) GetAccount() (*OlmAccount, error) {
return gs.Account, nil
}
func (gs *GobStore) PutAccount(account *OlmAccount) error {
func (gs *MemoryStore) PutAccount(account *OlmAccount) error {
gs.lock.Lock()
gs.Account = account
err := gs.save()
@@ -209,7 +187,7 @@ func (gs *GobStore) PutAccount(account *OlmAccount) error {
return err
}
func (gs *GobStore) GetSessions(senderKey id.SenderKey) (OlmSessionList, error) {
func (gs *MemoryStore) GetSessions(senderKey id.SenderKey) (OlmSessionList, error) {
gs.lock.Lock()
sessions, ok := gs.Sessions[senderKey]
if !ok {
@@ -220,7 +198,7 @@ func (gs *GobStore) GetSessions(senderKey id.SenderKey) (OlmSessionList, error)
return sessions, nil
}
func (gs *GobStore) AddSession(senderKey id.SenderKey, session *OlmSession) error {
func (gs *MemoryStore) AddSession(senderKey id.SenderKey, session *OlmSession) error {
gs.lock.Lock()
sessions, _ := gs.Sessions[senderKey]
gs.Sessions[senderKey] = append(sessions, session)
@@ -230,19 +208,19 @@ func (gs *GobStore) AddSession(senderKey id.SenderKey, session *OlmSession) erro
return err
}
func (gs *GobStore) UpdateSession(_ id.SenderKey, _ *OlmSession) error {
func (gs *MemoryStore) UpdateSession(_ id.SenderKey, _ *OlmSession) error {
// we don't need to do anything here because the session is a pointer and already stored in our map
return gs.save()
}
func (gs *GobStore) HasSession(senderKey id.SenderKey) bool {
func (gs *MemoryStore) HasSession(senderKey id.SenderKey) bool {
gs.lock.RLock()
sessions, ok := gs.Sessions[senderKey]
gs.lock.RUnlock()
return ok && len(sessions) > 0 && !sessions[0].Expired()
}
func (gs *GobStore) GetLatestSession(senderKey id.SenderKey) (*OlmSession, error) {
func (gs *MemoryStore) GetLatestSession(senderKey id.SenderKey) (*OlmSession, error) {
gs.lock.RLock()
sessions, ok := gs.Sessions[senderKey]
gs.lock.RUnlock()
@@ -252,7 +230,7 @@ func (gs *GobStore) GetLatestSession(senderKey id.SenderKey) (*OlmSession, error
return sessions[0], nil
}
func (gs *GobStore) getGroupSessions(roomID id.RoomID, senderKey id.SenderKey) map[id.SessionID]*InboundGroupSession {
func (gs *MemoryStore) getGroupSessions(roomID id.RoomID, senderKey id.SenderKey) map[id.SessionID]*InboundGroupSession {
room, ok := gs.GroupSessions[roomID]
if !ok {
room = make(map[id.SenderKey]map[id.SessionID]*InboundGroupSession)
@@ -266,7 +244,7 @@ func (gs *GobStore) getGroupSessions(roomID id.RoomID, senderKey id.SenderKey) m
return sender
}
func (gs *GobStore) PutGroupSession(roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID, igs *InboundGroupSession) error {
func (gs *MemoryStore) PutGroupSession(roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID, igs *InboundGroupSession) error {
gs.lock.Lock()
gs.getGroupSessions(roomID, senderKey)[sessionID] = igs
err := gs.save()
@@ -274,7 +252,7 @@ func (gs *GobStore) PutGroupSession(roomID id.RoomID, senderKey id.SenderKey, se
return err
}
func (gs *GobStore) GetGroupSession(roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID) (*InboundGroupSession, error) {
func (gs *MemoryStore) GetGroupSession(roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID) (*InboundGroupSession, error) {
gs.lock.Lock()
session, ok := gs.getGroupSessions(roomID, senderKey)[sessionID]
if !ok {
@@ -289,7 +267,57 @@ func (gs *GobStore) GetGroupSession(roomID id.RoomID, senderKey id.SenderKey, se
return session, nil
}
func (gs *GobStore) getWithheldGroupSessions(roomID id.RoomID, senderKey id.SenderKey) map[id.SessionID]*event.RoomKeyWithheldEventContent {
func (gs *MemoryStore) RedactGroupSession(roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID, reason string) error {
gs.lock.Lock()
delete(gs.getGroupSessions(roomID, senderKey), sessionID)
err := gs.save()
gs.lock.Unlock()
return err
}
func (gs *MemoryStore) RedactGroupSessions(roomID id.RoomID, senderKey id.SenderKey, reason string) ([]id.SessionID, error) {
gs.lock.Lock()
var sessionIDs []id.SessionID
if roomID != "" && senderKey != "" {
sessions := gs.getGroupSessions(roomID, senderKey)
for sessionID := range sessions {
sessionIDs = append(sessionIDs, sessionID)
delete(sessions, sessionID)
}
} else if senderKey != "" {
for _, room := range gs.GroupSessions {
sessions, ok := room[senderKey]
if ok {
for sessionID := range sessions {
sessionIDs = append(sessionIDs, sessionID)
}
delete(room, senderKey)
}
}
} else if roomID != "" {
room, ok := gs.GroupSessions[roomID]
if ok {
for senderKey := range room {
sessions := room[senderKey]
for sessionID := range sessions {
sessionIDs = append(sessionIDs, sessionID)
}
}
delete(gs.GroupSessions, roomID)
}
} else {
return nil, fmt.Errorf("room ID or sender key must be provided for redacting sessions")
}
err := gs.save()
gs.lock.Unlock()
return sessionIDs, err
}
func (gs *MemoryStore) RedactExpiredGroupSessions() ([]id.SessionID, error) {
return nil, fmt.Errorf("not implemented")
}
func (gs *MemoryStore) getWithheldGroupSessions(roomID id.RoomID, senderKey id.SenderKey) map[id.SessionID]*event.RoomKeyWithheldEventContent {
room, ok := gs.WithheldGroupSessions[roomID]
if !ok {
room = make(map[id.SenderKey]map[id.SessionID]*event.RoomKeyWithheldEventContent)
@@ -303,7 +331,7 @@ func (gs *GobStore) getWithheldGroupSessions(roomID id.RoomID, senderKey id.Send
return sender
}
func (gs *GobStore) PutWithheldGroupSession(content event.RoomKeyWithheldEventContent) error {
func (gs *MemoryStore) PutWithheldGroupSession(content event.RoomKeyWithheldEventContent) error {
gs.lock.Lock()
gs.getWithheldGroupSessions(content.RoomID, content.SenderKey)[content.SessionID] = &content
err := gs.save()
@@ -311,7 +339,7 @@ func (gs *GobStore) PutWithheldGroupSession(content event.RoomKeyWithheldEventCo
return err
}
func (gs *GobStore) GetWithheldGroupSession(roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID) (*event.RoomKeyWithheldEventContent, error) {
func (gs *MemoryStore) GetWithheldGroupSession(roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID) (*event.RoomKeyWithheldEventContent, error) {
gs.lock.Lock()
session, ok := gs.getWithheldGroupSessions(roomID, senderKey)[sessionID]
gs.lock.Unlock()
@@ -321,7 +349,7 @@ func (gs *GobStore) GetWithheldGroupSession(roomID id.RoomID, senderKey id.Sende
return session, nil
}
func (gs *GobStore) GetGroupSessionsForRoom(roomID id.RoomID) ([]*InboundGroupSession, error) {
func (gs *MemoryStore) GetGroupSessionsForRoom(roomID id.RoomID) ([]*InboundGroupSession, error) {
gs.lock.Lock()
defer gs.lock.Unlock()
room, ok := gs.GroupSessions[roomID]
@@ -337,7 +365,7 @@ func (gs *GobStore) GetGroupSessionsForRoom(roomID id.RoomID) ([]*InboundGroupSe
return result, nil
}
func (gs *GobStore) GetAllGroupSessions() ([]*InboundGroupSession, error) {
func (gs *MemoryStore) GetAllGroupSessions() ([]*InboundGroupSession, error) {
gs.lock.Lock()
var result []*InboundGroupSession
for _, room := range gs.GroupSessions {
@@ -351,7 +379,7 @@ func (gs *GobStore) GetAllGroupSessions() ([]*InboundGroupSession, error) {
return result, nil
}
func (gs *GobStore) AddOutboundGroupSession(session *OutboundGroupSession) error {
func (gs *MemoryStore) AddOutboundGroupSession(session *OutboundGroupSession) error {
gs.lock.Lock()
gs.OutGroupSessions[session.RoomID] = session
err := gs.save()
@@ -359,12 +387,12 @@ func (gs *GobStore) AddOutboundGroupSession(session *OutboundGroupSession) error
return err
}
func (gs *GobStore) UpdateOutboundGroupSession(_ *OutboundGroupSession) error {
func (gs *MemoryStore) UpdateOutboundGroupSession(_ *OutboundGroupSession) error {
// we don't need to do anything here because the session is a pointer and already stored in our map
return gs.save()
}
func (gs *GobStore) GetOutboundGroupSession(roomID id.RoomID) (*OutboundGroupSession, error) {
func (gs *MemoryStore) GetOutboundGroupSession(roomID id.RoomID) (*OutboundGroupSession, error) {
gs.lock.RLock()
session, ok := gs.OutGroupSessions[roomID]
gs.lock.RUnlock()
@@ -374,7 +402,7 @@ func (gs *GobStore) GetOutboundGroupSession(roomID id.RoomID) (*OutboundGroupSes
return session, nil
}
func (gs *GobStore) RemoveOutboundGroupSession(roomID id.RoomID) error {
func (gs *MemoryStore) RemoveOutboundGroupSession(roomID id.RoomID) error {
gs.lock.Lock()
session, ok := gs.OutGroupSessions[roomID]
if !ok || session == nil {
@@ -386,7 +414,7 @@ func (gs *GobStore) RemoveOutboundGroupSession(roomID id.RoomID) error {
return nil
}
func (gs *GobStore) ValidateMessageIndex(senderKey id.SenderKey, sessionID id.SessionID, eventID id.EventID, index uint, timestamp int64) (bool, error) {
func (gs *MemoryStore) ValidateMessageIndex(_ context.Context, senderKey id.SenderKey, sessionID id.SessionID, eventID id.EventID, index uint, timestamp int64) (bool, error) {
gs.lock.Lock()
defer gs.lock.Unlock()
key := messageIndexKey{
@@ -409,7 +437,7 @@ func (gs *GobStore) ValidateMessageIndex(senderKey id.SenderKey, sessionID id.Se
return true, nil
}
func (gs *GobStore) GetDevices(userID id.UserID) (map[id.DeviceID]*id.Device, error) {
func (gs *MemoryStore) GetDevices(userID id.UserID) (map[id.DeviceID]*id.Device, error) {
gs.lock.RLock()
devices, ok := gs.Devices[userID]
if !ok {
@@ -419,7 +447,7 @@ func (gs *GobStore) GetDevices(userID id.UserID) (map[id.DeviceID]*id.Device, er
return devices, nil
}
func (gs *GobStore) GetDevice(userID id.UserID, deviceID id.DeviceID) (*id.Device, error) {
func (gs *MemoryStore) GetDevice(userID id.UserID, deviceID id.DeviceID) (*id.Device, error) {
gs.lock.RLock()
defer gs.lock.RUnlock()
devices, ok := gs.Devices[userID]
@@ -433,7 +461,7 @@ func (gs *GobStore) GetDevice(userID id.UserID, deviceID id.DeviceID) (*id.Devic
return device, nil
}
func (gs *GobStore) FindDeviceByKey(userID id.UserID, identityKey id.IdentityKey) (*id.Device, error) {
func (gs *MemoryStore) FindDeviceByKey(userID id.UserID, identityKey id.IdentityKey) (*id.Device, error) {
gs.lock.RLock()
defer gs.lock.RUnlock()
devices, ok := gs.Devices[userID]
@@ -448,7 +476,7 @@ func (gs *GobStore) FindDeviceByKey(userID id.UserID, identityKey id.IdentityKey
return nil, nil
}
func (gs *GobStore) PutDevice(userID id.UserID, device *id.Device) error {
func (gs *MemoryStore) PutDevice(userID id.UserID, device *id.Device) error {
gs.lock.Lock()
devices, ok := gs.Devices[userID]
if !ok {
@@ -461,7 +489,7 @@ func (gs *GobStore) PutDevice(userID id.UserID, device *id.Device) error {
return err
}
func (gs *GobStore) PutDevices(userID id.UserID, devices map[id.DeviceID]*id.Device) error {
func (gs *MemoryStore) PutDevices(userID id.UserID, devices map[id.DeviceID]*id.Device) error {
gs.lock.Lock()
gs.Devices[userID] = devices
err := gs.save()
@@ -469,7 +497,7 @@ func (gs *GobStore) PutDevices(userID id.UserID, devices map[id.DeviceID]*id.Dev
return err
}
func (gs *GobStore) FilterTrackedUsers(users []id.UserID) ([]id.UserID, error) {
func (gs *MemoryStore) FilterTrackedUsers(users []id.UserID) ([]id.UserID, error) {
gs.lock.RLock()
var ptr int
for _, userID := range users {
@@ -483,7 +511,7 @@ func (gs *GobStore) FilterTrackedUsers(users []id.UserID) ([]id.UserID, error) {
return users[:ptr], nil
}
func (gs *GobStore) PutCrossSigningKey(userID id.UserID, usage id.CrossSigningUsage, key id.Ed25519) error {
func (gs *MemoryStore) PutCrossSigningKey(userID id.UserID, usage id.CrossSigningUsage, key id.Ed25519) error {
gs.lock.RLock()
userKeys, ok := gs.CrossSigningKeys[userID]
if !ok {
@@ -505,7 +533,7 @@ func (gs *GobStore) PutCrossSigningKey(userID id.UserID, usage id.CrossSigningUs
return err
}
func (gs *GobStore) GetCrossSigningKeys(userID id.UserID) (map[id.CrossSigningUsage]id.CrossSigningKey, error) {
func (gs *MemoryStore) GetCrossSigningKeys(userID id.UserID) (map[id.CrossSigningUsage]id.CrossSigningKey, error) {
gs.lock.RLock()
defer gs.lock.RUnlock()
keys, ok := gs.CrossSigningKeys[userID]
@@ -515,7 +543,7 @@ func (gs *GobStore) GetCrossSigningKeys(userID id.UserID) (map[id.CrossSigningUs
return keys, nil
}
func (gs *GobStore) PutSignature(signedUserID id.UserID, signedKey id.Ed25519, signerUserID id.UserID, signerKey id.Ed25519, signature string) error {
func (gs *MemoryStore) PutSignature(signedUserID id.UserID, signedKey id.Ed25519, signerUserID id.UserID, signerKey id.Ed25519, signature string) error {
gs.lock.RLock()
signedUserSigs, ok := gs.KeySignatures[signedUserID]
if !ok {
@@ -538,7 +566,7 @@ func (gs *GobStore) PutSignature(signedUserID id.UserID, signedKey id.Ed25519, s
return err
}
func (gs *GobStore) GetSignaturesForKeyBy(userID id.UserID, key id.Ed25519, signerID id.UserID) (map[id.Ed25519]string, error) {
func (gs *MemoryStore) GetSignaturesForKeyBy(userID id.UserID, key id.Ed25519, signerID id.UserID) (map[id.Ed25519]string, error) {
gs.lock.RLock()
defer gs.lock.RUnlock()
userKeys, ok := gs.KeySignatures[userID]
@@ -556,7 +584,7 @@ func (gs *GobStore) GetSignaturesForKeyBy(userID id.UserID, key id.Ed25519, sign
return sigsBySigner, nil
}
func (gs *GobStore) IsKeySignedBy(userID id.UserID, key id.Ed25519, signerID id.UserID, signerKey id.Ed25519) (bool, error) {
func (gs *MemoryStore) IsKeySignedBy(userID id.UserID, key id.Ed25519, signerID id.UserID, signerKey id.Ed25519) (bool, error) {
sigs, err := gs.GetSignaturesForKeyBy(userID, key, signerID)
if err != nil {
return false, err
@@ -565,7 +593,7 @@ func (gs *GobStore) IsKeySignedBy(userID id.UserID, key id.Ed25519, signerID id.
return ok, nil
}
func (gs *GobStore) DropSignaturesByKey(userID id.UserID, key id.Ed25519) (int64, error) {
func (gs *MemoryStore) DropSignaturesByKey(userID id.UserID, key id.Ed25519) (int64, error) {
var count int64
gs.lock.RLock()
for _, userSigs := range gs.KeySignatures {

View File

@@ -4,9 +4,6 @@
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
//go:build !nosas
// +build !nosas
package crypto
import (
@@ -92,13 +89,13 @@ func (mach *OlmMachine) getPKAndKeysMAC(sas *olm.SAS, sendingUser id.UserID, sen
if err != nil {
return "", "", err
}
mach.Log.Trace("sas.CalculateMAC(\"%s\", \"%s\") -> \"%s\"", signingKey, sasInfo+mainKeyID.String(), string(pubKeyMac))
mach.Log.Trace().Msgf("sas.CalculateMAC(\"%s\", \"%s\") -> \"%s\"", signingKey, sasInfo+mainKeyID.String(), string(pubKeyMac))
keysMac, err := sas.CalculateMAC([]byte(keyIDString), []byte(sasInfo+"KEY_IDS"))
if err != nil {
return "", "", err
}
mach.Log.Trace("sas.CalculateMAC(\"%s\", \"%s\") -> \"%s\"", keyIDString, sasInfo+"KEY_IDS", string(keysMac))
mach.Log.Trace().Msgf("sas.CalculateMAC(\"%s\", \"%s\") -> \"%s\"", keyIDString, sasInfo+"KEY_IDS", string(keysMac))
return string(pubKeyMac), string(keysMac), nil
}
@@ -144,14 +141,14 @@ func (mach *OlmMachine) getTransactionState(transactionID string, userID id.User
// handleVerificationStart handles an incoming m.key.verification.start message.
// It initializes the state for this SAS verification process and stores it.
func (mach *OlmMachine) handleVerificationStart(userID id.UserID, content *event.VerificationStartEventContent, transactionID string, timeout time.Duration, inRoomID id.RoomID) {
mach.Log.Debug("Received verification start from %v", content.FromDevice)
otherDevice, err := mach.GetOrFetchDevice(userID, content.FromDevice)
mach.Log.Debug().Msgf("Received verification start from %v", content.FromDevice)
otherDevice, err := mach.GetOrFetchDevice(context.TODO(), userID, content.FromDevice)
if err != nil {
mach.Log.Error("Could not find device %v of user %v", content.FromDevice, userID)
mach.Log.Error().Msgf("Could not find device %v of user %v", content.FromDevice, userID)
return
}
warnAndCancel := func(logReason, cancelReason string) {
mach.Log.Warn("Canceling verification transaction %v as it %s", transactionID, logReason)
mach.Log.Warn().Msgf("Canceling verification transaction %v as it %s", transactionID, logReason)
if inRoomID == "" {
_ = mach.SendSASVerificationCancel(otherDevice.UserID, otherDevice.DeviceID, transactionID, cancelReason, event.VerificationCancelUnknownMethod)
} else {
@@ -179,7 +176,7 @@ func (mach *OlmMachine) actuallyStartVerification(userID id.UserID, content *eve
if inRoomID != "" && transactionID != "" {
verState, err := mach.getTransactionState(transactionID, userID)
if err != nil {
mach.Log.Error("Failed to get transaction state for in-room verification %s start: %v", transactionID, err)
mach.Log.Error().Msgf("Failed to get transaction state for in-room verification %s start: %v", transactionID, err)
_ = mach.SendInRoomSASVerificationCancel(inRoomID, otherDevice.UserID, transactionID, "Internal state error in gomuks :(", "net.maunium.internal_error")
return
}
@@ -187,7 +184,7 @@ func (mach *OlmMachine) actuallyStartVerification(userID id.UserID, content *eve
sasMethods := commonSASMethods(verState.hooks, content.ShortAuthenticationString)
err = mach.SendInRoomSASVerificationAccept(inRoomID, userID, content, transactionID, verState.sas.GetPubkey(), sasMethods)
if err != nil {
mach.Log.Error("Error accepting in-room SAS verification: %v", err)
mach.Log.Error().Msgf("Error accepting in-room SAS verification: %v", err)
}
verState.chosenSASMethod = sasMethods[0]
verState.verificationStarted = true
@@ -197,7 +194,7 @@ func (mach *OlmMachine) actuallyStartVerification(userID id.UserID, content *eve
if resp == AcceptRequest {
sasMethods := commonSASMethods(hooks, content.ShortAuthenticationString)
if len(sasMethods) == 0 {
mach.Log.Error("No common SAS methods: %v", content.ShortAuthenticationString)
mach.Log.Error().Msgf("No common SAS methods: %v", content.ShortAuthenticationString)
if inRoomID == "" {
_ = mach.SendSASVerificationCancel(otherDevice.UserID, otherDevice.DeviceID, transactionID, "No common SAS methods", event.VerificationCancelUnknownMethod)
} else {
@@ -222,7 +219,7 @@ func (mach *OlmMachine) actuallyStartVerification(userID id.UserID, content *eve
_, loaded := mach.keyVerificationTransactionState.LoadOrStore(userID.String()+":"+transactionID, verState)
if loaded {
// transaction already exists
mach.Log.Error("Transaction %v already exists, canceling", transactionID)
mach.Log.Error().Msgf("Transaction %v already exists, canceling", transactionID)
if inRoomID == "" {
_ = mach.SendSASVerificationCancel(otherDevice.UserID, otherDevice.DeviceID, transactionID, "Transaction already exists", event.VerificationCancelUnexpectedMessage)
} else {
@@ -240,10 +237,10 @@ func (mach *OlmMachine) actuallyStartVerification(userID id.UserID, content *eve
err = mach.SendInRoomSASVerificationAccept(inRoomID, userID, content, transactionID, verState.sas.GetPubkey(), sasMethods)
}
if err != nil {
mach.Log.Error("Error accepting SAS verification: %v", err)
mach.Log.Error().Msgf("Error accepting SAS verification: %v", err)
}
} else if resp == RejectRequest {
mach.Log.Debug("Not accepting SAS verification %v from %v of user %v", transactionID, otherDevice.DeviceID, otherDevice.UserID)
mach.Log.Debug().Msgf("Not accepting SAS verification %v from %v of user %v", transactionID, otherDevice.DeviceID, otherDevice.UserID)
var err error
if inRoomID == "" {
err = mach.SendSASVerificationCancel(otherDevice.UserID, otherDevice.DeviceID, transactionID, "Not accepted by user", event.VerificationCancelByUser)
@@ -251,10 +248,10 @@ func (mach *OlmMachine) actuallyStartVerification(userID id.UserID, content *eve
err = mach.SendInRoomSASVerificationCancel(inRoomID, otherDevice.UserID, transactionID, "Not accepted by user", event.VerificationCancelByUser)
}
if err != nil {
mach.Log.Error("Error canceling SAS verification: %v", err)
mach.Log.Error().Msgf("Error canceling SAS verification: %v", err)
}
} else {
mach.Log.Debug("Ignoring SAS verification %v from %v of user %v", transactionID, otherDevice.DeviceID, otherDevice.UserID)
mach.Log.Debug().Msgf("Ignoring SAS verification %v from %v of user %v", transactionID, otherDevice.DeviceID, otherDevice.UserID)
}
}
@@ -276,12 +273,12 @@ func (mach *OlmMachine) timeoutAfter(verState *verificationState, transactionID
// if deadline exceeded cancel due to timeout
mach.keyVerificationTransactionState.Delete(mapKey)
_ = mach.callbackAndCancelSASVerification(verState, transactionID, "Timed out", event.VerificationCancelByTimeout)
mach.Log.Warn("Verification transaction %v is canceled due to timing out", transactionID)
mach.Log.Warn().Msgf("Verification transaction %v is canceled due to timing out", transactionID)
verState.lock.Unlock()
return
}
// otherwise the cancel func was called, so the timeout is reset
mach.Log.Debug("Extending timeout for transaction %v", transactionID)
mach.Log.Debug().Msgf("Extending timeout for transaction %v", transactionID)
timeoutCtx, timeoutCancel = context.WithTimeout(context.Background(), timeout)
verState.extendTimeout = timeoutCancel
verState.lock.Unlock()
@@ -292,10 +289,10 @@ func (mach *OlmMachine) timeoutAfter(verState *verificationState, transactionID
// handleVerificationAccept handles an incoming m.key.verification.accept message.
// It continues the SAS verification process by sending the SAS key message to the other device.
func (mach *OlmMachine) handleVerificationAccept(userID id.UserID, content *event.VerificationAcceptEventContent, transactionID string) {
mach.Log.Debug("Received verification accept for transaction %v", transactionID)
mach.Log.Debug().Msgf("Received verification accept for transaction %v", transactionID)
verState, err := mach.getTransactionState(transactionID, userID)
if err != nil {
mach.Log.Error("Error getting transaction state: %v", err)
mach.Log.Error().Msgf("Error getting transaction state: %v", err)
return
}
verState.lock.Lock()
@@ -304,7 +301,7 @@ func (mach *OlmMachine) handleVerificationAccept(userID id.UserID, content *even
if !verState.initiatedByUs || verState.verificationStarted {
// unexpected accept at this point
mach.Log.Warn("Unexpected verification accept message for transaction %v", transactionID)
mach.Log.Warn().Msgf("Unexpected verification accept message for transaction %v", transactionID)
mach.keyVerificationTransactionState.Delete(userID.String() + ":" + transactionID)
_ = mach.callbackAndCancelSASVerification(verState, transactionID, "Unexpected accept message", event.VerificationCancelUnexpectedMessage)
return
@@ -316,7 +313,7 @@ func (mach *OlmMachine) handleVerificationAccept(userID id.UserID, content *even
content.MessageAuthenticationCode != event.HKDFHMACSHA256 ||
len(sasMethods) == 0 {
mach.Log.Warn("Canceling verification transaction %v due to unknown parameter", transactionID)
mach.Log.Warn().Msgf("Canceling verification transaction %v due to unknown parameter", transactionID)
mach.keyVerificationTransactionState.Delete(userID.String() + ":" + transactionID)
_ = mach.callbackAndCancelSASVerification(verState, transactionID, "Verification uses unknown method", event.VerificationCancelUnknownMethod)
return
@@ -333,7 +330,7 @@ func (mach *OlmMachine) handleVerificationAccept(userID id.UserID, content *even
err = mach.SendInRoomSASVerificationKey(verState.inRoomID, userID, transactionID, string(key))
}
if err != nil {
mach.Log.Error("Error sending SAS key to other device: %v", err)
mach.Log.Error().Msgf("Error sending SAS key to other device: %v", err)
return
}
}
@@ -341,10 +338,10 @@ func (mach *OlmMachine) handleVerificationAccept(userID id.UserID, content *even
// handleVerificationKey handles an incoming m.key.verification.key message.
// It stores the other device's public key in order to acquire the SAS shared secret.
func (mach *OlmMachine) handleVerificationKey(userID id.UserID, content *event.VerificationKeyEventContent, transactionID string) {
mach.Log.Debug("Got verification key for transaction %v: %v", transactionID, content.Key)
mach.Log.Debug().Msgf("Got verification key for transaction %v: %v", transactionID, content.Key)
verState, err := mach.getTransactionState(transactionID, userID)
if err != nil {
mach.Log.Error("Error getting transaction state: %v", err)
mach.Log.Error().Msgf("Error getting transaction state: %v", err)
return
}
verState.lock.Lock()
@@ -355,14 +352,14 @@ func (mach *OlmMachine) handleVerificationKey(userID id.UserID, content *event.V
if !verState.verificationStarted || verState.keyReceived {
// unexpected key at this point
mach.Log.Warn("Unexpected verification key message for transaction %v", transactionID)
mach.Log.Warn().Msgf("Unexpected verification key message for transaction %v", transactionID)
mach.keyVerificationTransactionState.Delete(userID.String() + ":" + transactionID)
_ = mach.callbackAndCancelSASVerification(verState, transactionID, "Unexpected key message", event.VerificationCancelUnexpectedMessage)
return
}
if err := verState.sas.SetTheirKey([]byte(content.Key)); err != nil {
mach.Log.Error("Error setting other device's key: %v", err)
mach.Log.Error().Msgf("Error setting other device's key: %v", err)
return
}
@@ -371,9 +368,9 @@ func (mach *OlmMachine) handleVerificationKey(userID id.UserID, content *event.V
if verState.initiatedByUs {
// verify commitment string from accept message now
expectedCommitment := olm.NewUtility().Sha256(content.Key + verState.startEventCanonical)
mach.Log.Debug("Received commitment: %v Expected: %v", verState.commitment, expectedCommitment)
mach.Log.Debug().Msgf("Received commitment: %v Expected: %v", verState.commitment, expectedCommitment)
if expectedCommitment != verState.commitment {
mach.Log.Warn("Canceling verification transaction %v due to commitment mismatch", transactionID)
mach.Log.Warn().Msgf("Canceling verification transaction %v due to commitment mismatch", transactionID)
mach.keyVerificationTransactionState.Delete(userID.String() + ":" + transactionID)
_ = mach.callbackAndCancelSASVerification(verState, transactionID, "Commitment mismatch", event.VerificationCancelCommitmentMismatch)
return
@@ -388,7 +385,7 @@ func (mach *OlmMachine) handleVerificationKey(userID id.UserID, content *event.V
err = mach.SendInRoomSASVerificationKey(verState.inRoomID, userID, transactionID, string(key))
}
if err != nil {
mach.Log.Error("Error sending SAS key to other device: %v", err)
mach.Log.Error().Msgf("Error sending SAS key to other device: %v", err)
return
}
}
@@ -416,10 +413,10 @@ func (mach *OlmMachine) handleVerificationKey(userID id.UserID, content *event.V
sasMethod := verState.chosenSASMethod
sas, err := sasMethod.GetVerificationSAS(initUserID, initDeviceID, initKey, acceptUserID, acceptDeviceID, acceptKey, transactionID, verState.sas)
if err != nil {
mach.Log.Error("Error generating SAS (method %v): %v", sasMethod.Type(), err)
mach.Log.Error().Msgf("Error generating SAS (method %v): %v", sasMethod.Type(), err)
return
}
mach.Log.Debug("Generated SAS (%v): %v", sasMethod.Type(), sas)
mach.Log.Debug().Msgf("Generated SAS (%v): %v", sasMethod.Type(), sas)
go func() {
result := verState.hooks.VerifySASMatch(device, sas)
mach.sasCompared(result, transactionID, verState)
@@ -441,7 +438,7 @@ func (mach *OlmMachine) sasCompared(didMatch bool, transactionID string, verStat
err = mach.SendInRoomSASVerificationMAC(verState.inRoomID, verState.otherDevice.UserID, verState.otherDevice.DeviceID, transactionID, verState.sas)
}
if err != nil {
mach.Log.Error("Error sending verification MAC to other device: %v", err)
mach.Log.Error().Msgf("Error sending verification MAC to other device: %v", err)
}
} else {
verState.sasMatched <- false
@@ -451,10 +448,10 @@ func (mach *OlmMachine) sasCompared(didMatch bool, transactionID string, verStat
// handleVerificationMAC handles an incoming m.key.verification.mac message.
// It verifies the other device's MAC and if the MAC is valid it marks the device as trusted.
func (mach *OlmMachine) handleVerificationMAC(userID id.UserID, content *event.VerificationMacEventContent, transactionID string) {
mach.Log.Debug("Got MAC for verification %v: %v, MAC for keys: %v", transactionID, content.Mac, content.Keys)
mach.Log.Debug().Msgf("Got MAC for verification %v: %v, MAC for keys: %v", transactionID, content.Mac, content.Keys)
verState, err := mach.getTransactionState(transactionID, userID)
if err != nil {
mach.Log.Error("Error getting transaction state: %v", err)
mach.Log.Error().Msgf("Error getting transaction state: %v", err)
return
}
verState.lock.Lock()
@@ -468,7 +465,7 @@ func (mach *OlmMachine) handleVerificationMAC(userID id.UserID, content *event.V
if !verState.verificationStarted || !verState.keyReceived {
// unexpected MAC at this point
mach.Log.Warn("Unexpected MAC message for transaction %v", transactionID)
mach.Log.Warn().Msgf("Unexpected MAC message for transaction %v", transactionID)
_ = mach.callbackAndCancelSASVerification(verState, transactionID, "Unexpected MAC message", event.VerificationCancelUnexpectedMessage)
return
}
@@ -480,7 +477,7 @@ func (mach *OlmMachine) handleVerificationMAC(userID id.UserID, content *event.V
defer verState.lock.Unlock()
if !matched {
mach.Log.Warn("SAS do not match! Canceling transaction %v", transactionID)
mach.Log.Warn().Msgf("SAS do not match! Canceling transaction %v", transactionID)
_ = mach.callbackAndCancelSASVerification(verState, transactionID, "SAS do not match", event.VerificationCancelSASMismatch)
return
}
@@ -490,20 +487,20 @@ func (mach *OlmMachine) handleVerificationMAC(userID id.UserID, content *event.V
expectedPKMAC, expectedKeysMAC, err := mach.getPKAndKeysMAC(verState.sas, device.UserID, device.DeviceID,
mach.Client.UserID, mach.Client.DeviceID, transactionID, device.SigningKey, keyID, content.Mac)
if err != nil {
mach.Log.Error("Error generating MAC to match with received MAC: %v", err)
mach.Log.Error().Msgf("Error generating MAC to match with received MAC: %v", err)
return
}
mach.Log.Debug("Expected %s keys MAC, got %s", expectedKeysMAC, content.Keys)
mach.Log.Debug().Msgf("Expected %s keys MAC, got %s", expectedKeysMAC, content.Keys)
if content.Keys != expectedKeysMAC {
mach.Log.Warn("Canceling verification transaction %v due to mismatched keys MAC", transactionID)
mach.Log.Warn().Msgf("Canceling verification transaction %v due to mismatched keys MAC", transactionID)
_ = mach.callbackAndCancelSASVerification(verState, transactionID, "Mismatched keys MACs", event.VerificationCancelKeyMismatch)
return
}
mach.Log.Debug("Expected %s PK MAC, got %s", expectedPKMAC, content.Mac[keyID])
mach.Log.Debug().Msgf("Expected %s PK MAC, got %s", expectedPKMAC, content.Mac[keyID])
if content.Mac[keyID] != expectedPKMAC {
mach.Log.Warn("Canceling verification transaction %v due to mismatched PK MAC", transactionID)
mach.Log.Warn().Msgf("Canceling verification transaction %v due to mismatched PK MAC", transactionID)
_ = mach.callbackAndCancelSASVerification(verState, transactionID, "Mismatched PK MACs", event.VerificationCancelKeyMismatch)
return
}
@@ -512,35 +509,35 @@ func (mach *OlmMachine) handleVerificationMAC(userID id.UserID, content *event.V
device.Trust = id.TrustStateVerified
err = mach.CryptoStore.PutDevice(device.UserID, device)
if err != nil {
mach.Log.Warn("Failed to put device after verifying: %v", err)
mach.Log.Warn().Msgf("Failed to put device after verifying: %v", err)
}
if mach.CrossSigningKeys != nil {
if device.UserID == mach.Client.UserID {
err := mach.SignOwnDevice(device)
if err != nil {
mach.Log.Error("Failed to cross-sign own device %s: %v", device.DeviceID, err)
mach.Log.Error().Msgf("Failed to cross-sign own device %s: %v", device.DeviceID, err)
} else {
mach.Log.Debug("Cross-signed own device %v after SAS verification", device.DeviceID)
mach.Log.Debug().Msgf("Cross-signed own device %v after SAS verification", device.DeviceID)
}
} else {
masterKey, err := mach.fetchMasterKey(device, content, verState, transactionID)
if err != nil {
mach.Log.Warn("Failed to fetch %s's master key: %v", device.UserID, err)
mach.Log.Warn().Msgf("Failed to fetch %s's master key: %v", device.UserID, err)
} else {
if err := mach.SignUser(device.UserID, masterKey); err != nil {
mach.Log.Error("Failed to cross-sign master key of %s: %v", device.UserID, err)
mach.Log.Error().Msgf("Failed to cross-sign master key of %s: %v", device.UserID, err)
} else {
mach.Log.Debug("Cross-signed master key of %v after SAS verification", device.UserID)
mach.Log.Debug().Msgf("Cross-signed master key of %v after SAS verification", device.UserID)
}
}
}
} else {
// TODO ask user to unlock cross-signing keys?
mach.Log.Debug("Cross-signing keys not cached, not signing %s/%s", device.UserID, device.DeviceID)
mach.Log.Debug().Msgf("Cross-signing keys not cached, not signing %s/%s", device.UserID, device.DeviceID)
}
mach.Log.Debug("Device %v of user %v verified successfully!", device.DeviceID, device.UserID)
mach.Log.Debug().Msgf("Device %v of user %v verified successfully!", device.DeviceID, device.UserID)
verState.hooks.OnSuccess()
}()
@@ -557,20 +554,20 @@ func (mach *OlmMachine) handleVerificationCancel(userID id.UserID, content *even
}
mach.keyVerificationTransactionState.Delete(userID.String() + ":" + transactionID)
mach.Log.Warn("SAS verification %v was canceled by %v with reason: %v (%v)",
mach.Log.Warn().Msgf("SAS verification %v was canceled by %v with reason: %v (%v)",
transactionID, userID, content.Reason, content.Code)
}
// handleVerificationRequest handles an incoming m.key.verification.request message.
func (mach *OlmMachine) handleVerificationRequest(userID id.UserID, content *event.VerificationRequestEventContent, transactionID string, inRoomID id.RoomID) {
mach.Log.Debug("Received verification request from %v", content.FromDevice)
otherDevice, err := mach.GetOrFetchDevice(userID, content.FromDevice)
mach.Log.Debug().Msgf("Received verification request from %v", content.FromDevice)
otherDevice, err := mach.GetOrFetchDevice(context.TODO(), userID, content.FromDevice)
if err != nil {
mach.Log.Error("Could not find device %v of user %v", content.FromDevice, userID)
mach.Log.Error().Msgf("Could not find device %v of user %v", content.FromDevice, userID)
return
}
if !content.SupportsVerificationMethod(event.VerificationMethodSAS) {
mach.Log.Warn("Canceling verification transaction %v as SAS is not supported", transactionID)
mach.Log.Warn().Msgf("Canceling verification transaction %v as SAS is not supported", transactionID)
if inRoomID == "" {
_ = mach.SendSASVerificationCancel(otherDevice.UserID, otherDevice.DeviceID, transactionID, "Only SAS method is supported", event.VerificationCancelUnknownMethod)
} else {
@@ -580,12 +577,12 @@ func (mach *OlmMachine) handleVerificationRequest(userID id.UserID, content *eve
}
resp, hooks := mach.AcceptVerificationFrom(transactionID, otherDevice, inRoomID)
if resp == AcceptRequest {
mach.Log.Debug("Accepting SAS verification %v from %v of user %v", transactionID, otherDevice.DeviceID, otherDevice.UserID)
mach.Log.Debug().Msgf("Accepting SAS verification %v from %v of user %v", transactionID, otherDevice.DeviceID, otherDevice.UserID)
if inRoomID == "" {
_, err = mach.NewSASVerificationWith(otherDevice, hooks, transactionID, mach.DefaultSASTimeout)
} else {
if err := mach.SendInRoomSASVerificationReady(inRoomID, transactionID); err != nil {
mach.Log.Error("Error sending in-room SAS verification ready: %v", err)
mach.Log.Error().Msgf("Error sending in-room SAS verification ready: %v", err)
}
if mach.Client.UserID < otherDevice.UserID {
// up to us to send the start message
@@ -593,17 +590,17 @@ func (mach *OlmMachine) handleVerificationRequest(userID id.UserID, content *eve
}
}
if err != nil {
mach.Log.Error("Error accepting SAS verification request: %v", err)
mach.Log.Error().Msgf("Error accepting SAS verification request: %v", err)
}
} else if resp == RejectRequest {
mach.Log.Debug("Rejecting SAS verification %v from %v of user %v", transactionID, otherDevice.DeviceID, otherDevice.UserID)
mach.Log.Debug().Msgf("Rejecting SAS verification %v from %v of user %v", transactionID, otherDevice.DeviceID, otherDevice.UserID)
if inRoomID == "" {
_ = mach.SendSASVerificationCancel(otherDevice.UserID, otherDevice.DeviceID, transactionID, "Not accepted by user", event.VerificationCancelByUser)
} else {
_ = mach.SendInRoomSASVerificationCancel(inRoomID, otherDevice.UserID, transactionID, "Not accepted by user", event.VerificationCancelByUser)
}
} else {
mach.Log.Debug("Ignoring SAS verification %v from %v of user %v", transactionID, otherDevice.DeviceID, otherDevice.UserID)
mach.Log.Debug().Msgf("Ignoring SAS verification %v from %v of user %v", transactionID, otherDevice.DeviceID, otherDevice.UserID)
}
}
@@ -620,7 +617,7 @@ func (mach *OlmMachine) NewSASVerificationWith(device *id.Device, hooks Verifica
if transactionID == "" {
transactionID = strconv.Itoa(rand.Int())
}
mach.Log.Debug("Starting new verification transaction %v with device %v of user %v", transactionID, device.DeviceID, device.UserID)
mach.Log.Debug().Msgf("Starting new verification transaction %v with device %v of user %v", transactionID, device.DeviceID, device.UserID)
verState := &verificationState{
sas: olm.NewSAS(),
@@ -669,7 +666,7 @@ func (mach *OlmMachine) CancelSASVerification(userID id.UserID, transactionID, r
verState := verStateInterface.(*verificationState)
verState.lock.Lock()
defer verState.lock.Unlock()
mach.Log.Trace("User canceled verification transaction %v with reason: %v", transactionID, reason)
mach.Log.Trace().Msgf("User canceled verification transaction %v with reason: %v", transactionID, reason)
mach.keyVerificationTransactionState.Delete(mapKey)
return mach.callbackAndCancelSASVerification(verState, transactionID, reason, event.VerificationCancelByUser)
}
@@ -766,9 +763,9 @@ func (mach *OlmMachine) SendSASVerificationMAC(userID id.UserID, deviceID id.Dev
masterKeyMAC, _, err := mach.getPKAndKeysMAC(sas, mach.Client.UserID, mach.Client.DeviceID,
userID, deviceID, transactionID, masterKey, masterKeyID, keyIDsMap)
if err != nil {
mach.Log.Error("Error generating master key MAC: %v", err)
mach.Log.Error().Msgf("Error generating master key MAC: %v", err)
} else {
mach.Log.Debug("Generated master key `%v` MAC: %v", masterKey, masterKeyMAC)
mach.Log.Debug().Msgf("Generated master key `%v` MAC: %v", masterKey, masterKeyMAC)
macMap[masterKeyID] = masterKeyMAC
}
}
@@ -777,8 +774,8 @@ func (mach *OlmMachine) SendSASVerificationMAC(userID id.UserID, deviceID id.Dev
if err != nil {
return err
}
mach.Log.Debug("MAC of key %s is: %s", signingKey, pubKeyMac)
mach.Log.Debug("MAC of key ID(s) %s is: %s", keyID, keysMac)
mach.Log.Debug().Msgf("MAC of key %s is: %s", signingKey, pubKeyMac)
mach.Log.Debug().Msgf("MAC of key ID(s) %s is: %s", keyID, keysMac)
macMap[keyID] = pubKeyMac
content := &event.VerificationMacEventContent{

View File

@@ -7,6 +7,7 @@
package crypto
import (
"context"
"encoding/json"
"errors"
"time"
@@ -80,7 +81,7 @@ func (mach *OlmMachine) SendInRoomSASVerificationCancel(roomID id.RoomID, userID
To: userID,
}
encrypted, err := mach.EncryptMegolmEvent(roomID, event.InRoomVerificationCancel, content)
encrypted, err := mach.EncryptMegolmEvent(context.TODO(), roomID, event.InRoomVerificationCancel, content)
if err != nil {
return err
}
@@ -97,7 +98,7 @@ func (mach *OlmMachine) SendInRoomSASVerificationRequest(roomID id.RoomID, toUse
To: toUserID,
}
encrypted, err := mach.EncryptMegolmEvent(roomID, event.EventMessage, content)
encrypted, err := mach.EncryptMegolmEvent(context.TODO(), roomID, event.EventMessage, content)
if err != nil {
return "", err
}
@@ -116,7 +117,7 @@ func (mach *OlmMachine) SendInRoomSASVerificationReady(roomID id.RoomID, transac
RelatesTo: &event.RelatesTo{Type: event.RelReference, EventID: id.EventID(transactionID)},
}
encrypted, err := mach.EncryptMegolmEvent(roomID, event.InRoomVerificationReady, content)
encrypted, err := mach.EncryptMegolmEvent(context.TODO(), roomID, event.InRoomVerificationReady, content)
if err != nil {
return err
}
@@ -141,7 +142,7 @@ func (mach *OlmMachine) SendInRoomSASVerificationStart(roomID id.RoomID, toUserI
To: toUserID,
}
encrypted, err := mach.EncryptMegolmEvent(roomID, event.InRoomVerificationStart, content)
encrypted, err := mach.EncryptMegolmEvent(context.TODO(), roomID, event.InRoomVerificationStart, content)
if err != nil {
return nil, err
}
@@ -182,7 +183,7 @@ func (mach *OlmMachine) SendInRoomSASVerificationAccept(roomID id.RoomID, fromUs
To: fromUser,
}
encrypted, err := mach.EncryptMegolmEvent(roomID, event.InRoomVerificationAccept, content)
encrypted, err := mach.EncryptMegolmEvent(context.TODO(), roomID, event.InRoomVerificationAccept, content)
if err != nil {
return err
}
@@ -198,7 +199,7 @@ func (mach *OlmMachine) SendInRoomSASVerificationKey(roomID id.RoomID, userID id
To: userID,
}
encrypted, err := mach.EncryptMegolmEvent(roomID, event.InRoomVerificationKey, content)
encrypted, err := mach.EncryptMegolmEvent(context.TODO(), roomID, event.InRoomVerificationKey, content)
if err != nil {
return err
}
@@ -222,9 +223,9 @@ func (mach *OlmMachine) SendInRoomSASVerificationMAC(roomID id.RoomID, userID id
masterKeyMAC, _, err := mach.getPKAndKeysMAC(sas, mach.Client.UserID, mach.Client.DeviceID,
userID, deviceID, transactionID, masterKey, masterKeyID, keyIDsMap)
if err != nil {
mach.Log.Error("Error generating master key MAC: %v", err)
mach.Log.Error().Msgf("Error generating master key MAC: %v", err)
} else {
mach.Log.Debug("Generated master key `%v` MAC: %v", masterKey, masterKeyMAC)
mach.Log.Debug().Msgf("Generated master key `%v` MAC: %v", masterKey, masterKeyMAC)
macMap[masterKeyID] = masterKeyMAC
}
}
@@ -233,8 +234,8 @@ func (mach *OlmMachine) SendInRoomSASVerificationMAC(roomID id.RoomID, userID id
if err != nil {
return err
}
mach.Log.Debug("MAC of key %s is: %s", signingKey, pubKeyMac)
mach.Log.Debug("MAC of key ID(s) %s is: %s", keyID, keysMac)
mach.Log.Debug().Msgf("MAC of key %s is: %s", signingKey, pubKeyMac)
mach.Log.Debug().Msgf("MAC of key ID(s) %s is: %s", keyID, keysMac)
macMap[keyID] = pubKeyMac
content := &event.VerificationMacEventContent{
@@ -244,7 +245,7 @@ func (mach *OlmMachine) SendInRoomSASVerificationMAC(roomID id.RoomID, userID id
To: userID,
}
encrypted, err := mach.EncryptMegolmEvent(roomID, event.InRoomVerificationMAC, content)
encrypted, err := mach.EncryptMegolmEvent(context.TODO(), roomID, event.InRoomVerificationMAC, content)
if err != nil {
return err
}
@@ -259,7 +260,7 @@ func (mach *OlmMachine) NewInRoomSASVerificationWith(inRoomID id.RoomID, userID
}
func (mach *OlmMachine) newInRoomSASVerificationWithInner(inRoomID id.RoomID, device *id.Device, hooks VerificationHooks, transactionID string, timeout time.Duration) (string, error) {
mach.Log.Debug("Starting new in-room verification transaction user %v", device.UserID)
mach.Log.Debug().Msgf("Starting new in-room verification transaction user %v", device.UserID)
request := transactionID == ""
if request {
@@ -310,15 +311,15 @@ func (mach *OlmMachine) newInRoomSASVerificationWithInner(inRoomID id.RoomID, de
}
func (mach *OlmMachine) handleInRoomVerificationReady(userID id.UserID, roomID id.RoomID, content *event.VerificationReadyEventContent, transactionID string) {
device, err := mach.GetOrFetchDevice(userID, content.FromDevice)
device, err := mach.GetOrFetchDevice(context.TODO(), userID, content.FromDevice)
if err != nil {
mach.Log.Error("Error fetching device %v of user %v: %v", content.FromDevice, userID, err)
mach.Log.Error().Msgf("Error fetching device %v of user %v: %v", content.FromDevice, userID, err)
return
}
verState, err := mach.getTransactionState(transactionID, userID)
if err != nil {
mach.Log.Error("Error getting transaction state: %v", err)
mach.Log.Error().Msgf("Error getting transaction state: %v", err)
return
}
//mach.keyVerificationTransactionState.Delete(userID.String() + ":" + transactionID)

View File

@@ -4,9 +4,6 @@
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
//go:build !nosas
// +build !nosas
package crypto
import (

View File

@@ -11,6 +11,8 @@ import (
"errors"
"fmt"
"net/http"
"golang.org/x/exp/maps"
)
// Common error codes from https://matrix.org/docs/spec/client_server/latest#api-standards
@@ -60,6 +62,11 @@ var (
MIncompatibleRoomVersion = RespError{ErrCode: "M_INCOMPATIBLE_ROOM_VERSION"}
// The client specified a parameter that has the wrong value.
MInvalidParam = RespError{ErrCode: "M_INVALID_PARAM"}
MURLNotSet = RespError{ErrCode: "M_URL_NOT_SET"}
MBadStatus = RespError{ErrCode: "M_BAD_STATUS"}
MConnectionTimeout = RespError{ErrCode: "M_CONNECTION_TIMEOUT"}
MConnectionFailed = RespError{ErrCode: "M_CONNECTION_FAILED"}
)
// HTTPError An HTTP Error response, which may wrap an underlying native Go Error.
@@ -124,12 +131,13 @@ func (e *RespError) UnmarshalJSON(data []byte) error {
}
func (e *RespError) MarshalJSON() ([]byte, error) {
if e.ExtraData == nil {
e.ExtraData = make(map[string]interface{})
data := maps.Clone(e.ExtraData)
if data == nil {
data = make(map[string]any)
}
e.ExtraData["errcode"] = e.ErrCode
e.ExtraData["error"] = e.Err
return json.Marshal(&e.ExtraData)
data["errcode"] = e.ErrCode
data["error"] = e.Err
return json.Marshal(data)
}
// Error returns the errcode and error message.

View File

@@ -49,3 +49,9 @@ type BeeperRetryMetadata struct {
RetryCount int `json:"retry_count"`
// last_retry is also present, but not used by bridges
}
type BeeperRoomKeyAckEventContent struct {
RoomID id.RoomID `json:"room_id"`
SessionID id.SessionID `json:"session_id"`
FirstMessageIndex int `json:"first_message_index"`
}

View File

@@ -80,6 +80,8 @@ var TypeMap = map[Type]reflect.Type{
ToDeviceOrgMatrixRoomKeyWithheld: reflect.TypeOf(RoomKeyWithheldEventContent{}),
ToDeviceBeeperRoomKeyAck: reflect.TypeOf(BeeperRoomKeyAckEventContent{}),
CallInvite: reflect.TypeOf(CallInviteEventContent{}),
CallCandidates: reflect.TypeOf(CallCandidatesEventContent{}),
CallAnswer: reflect.TypeOf(CallAnswerEventContent{}),

View File

@@ -8,6 +8,7 @@ package event
import (
"encoding/json"
"fmt"
"maunium.net/go/mautrix/id"
)
@@ -41,6 +42,7 @@ type EncryptedEventContent struct {
OlmCiphertext OlmCiphertexts `json:"-"`
RelatesTo *RelatesTo `json:"m.relates_to,omitempty"`
Mentions *Mentions `json:"m.mentions,omitempty"`
}
type OlmCiphertexts map[id.Curve25519]struct {
@@ -92,6 +94,10 @@ type RoomKeyEventContent struct {
RoomID id.RoomID `json:"room_id"`
SessionID id.SessionID `json:"session_id"`
SessionKey string `json:"session_key"`
MaxAge int64 `json:"com.beeper.max_age_ms"`
MaxMessages int `json:"com.beeper.max_messages"`
IsScheduled bool `json:"com.beeper.is_scheduled"`
}
// ForwardedRoomKeyEventContent represents the content of a m.forwarded_room_key to_device event.
@@ -101,6 +107,10 @@ type ForwardedRoomKeyEventContent struct {
SenderKey id.SenderKey `json:"sender_key"`
SenderClaimedKey id.Ed25519 `json:"sender_claimed_ed25519_key"`
ForwardingKeyChain []string `json:"forwarding_curve25519_key_chain"`
MaxAge int64 `json:"com.beeper.max_age_ms"`
MaxMessages int `json:"com.beeper.max_messages"`
IsScheduled bool `json:"com.beeper.is_scheduled"`
}
type KeyRequestAction string
@@ -134,6 +144,8 @@ const (
RoomKeyWithheldUnauthorized RoomKeyWithheldCode = "m.unauthorised"
RoomKeyWithheldUnavailable RoomKeyWithheldCode = "m.unavailable"
RoomKeyWithheldNoOlmSession RoomKeyWithheldCode = "m.no_olm"
RoomKeyWithheldBeeperRedacted RoomKeyWithheldCode = "com.beeper.redacted"
)
type RoomKeyWithheldEventContent struct {
@@ -145,4 +157,23 @@ type RoomKeyWithheldEventContent struct {
Reason string `json:"reason,omitempty"`
}
const groupSessionWithheldMsg = "group session has been withheld: %s"
func (withheld *RoomKeyWithheldEventContent) Error() string {
switch withheld.Code {
case RoomKeyWithheldBlacklisted, RoomKeyWithheldUnverified, RoomKeyWithheldUnauthorized, RoomKeyWithheldUnavailable, RoomKeyWithheldNoOlmSession:
return fmt.Sprintf(groupSessionWithheldMsg, withheld.Code)
default:
return fmt.Sprintf(groupSessionWithheldMsg+" (%s)", withheld.Code, withheld.Reason)
}
}
func (withheld *RoomKeyWithheldEventContent) Is(other error) bool {
otherWithheld, ok := other.(*RoomKeyWithheldEventContent)
if !ok {
return false
}
return withheld.Code == "" || otherWithheld.Code == "" || withheld.Code == otherWithheld.Code
}
type DummyEventContent struct{}

View File

@@ -111,6 +111,8 @@ type MautrixInfo struct {
TrustSource *id.Device
ReceivedAt time.Time
EditedAt time.Time
LastEditID id.EventID
DecryptionDuration time.Duration
CheckpointSent bool

View File

@@ -1,4 +1,4 @@
// Copyright (c) 2020 Tulir Asokan
// Copyright (c) 2023 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
@@ -97,6 +97,9 @@ type MessageEventContent struct {
FileName string `json:"filename,omitempty"`
Mentions *Mentions `json:"m.mentions,omitempty"`
UnstableMentions *Mentions `json:"org.matrix.msc3952.mentions,omitempty"`
// Edits and relations
NewContent *MessageEventContent `json:"m.new_content,omitempty"`
RelatesTo *RelatesTo `json:"m.relates_to,omitempty"`
@@ -135,6 +138,12 @@ func (content *MessageEventContent) SetEdit(original id.EventID) {
if content.Format == FormatHTML && len(content.FormattedBody) > 0 {
content.FormattedBody = "* " + content.FormattedBody
}
// If the message is long, remove most of the useless edit fallback to avoid event size issues.
if len(content.Body) > 10000 {
content.FormattedBody = ""
content.Format = ""
content.Body = content.Body[:50] + "[edit fallback cut…]"
}
}
}
@@ -163,6 +172,11 @@ func (content *MessageEventContent) GetInfo() *FileInfo {
return content.Info
}
type Mentions struct {
UserIDs []id.UserID `json:"user_ids,omitempty"`
Room bool `json:"room,omitempty"`
}
type EncryptedFileInfo struct {
attachment.EncryptedFile
URL id.ContentURIString `json:"url"`

View File

@@ -13,7 +13,7 @@ import (
)
// PowerLevelsEventContent represents the content of a m.room.power_levels state event content.
// https://spec.matrix.org/v1.2/client-server-api/#mroompower_levels
// https://spec.matrix.org/v1.5/client-server-api/#mroompower_levels
type PowerLevelsEventContent struct {
usersLock sync.RWMutex
Users map[id.UserID]int `json:"users,omitempty"`
@@ -23,6 +23,8 @@ type PowerLevelsEventContent struct {
Events map[string]int `json:"events,omitempty"`
EventsDefault int `json:"events_default,omitempty"`
Notifications *NotificationPowerLevels `json:"notifications,omitempty"`
StateDefaultPtr *int `json:"state_default,omitempty"`
InvitePtr *int `json:"invite,omitempty"`
@@ -32,6 +34,66 @@ type PowerLevelsEventContent struct {
HistoricalPtr *int `json:"historical,omitempty"`
}
func copyPtr(ptr *int) *int {
if ptr == nil {
return nil
}
val := *ptr
return &val
}
func copyMap[Key comparable](m map[Key]int) map[Key]int {
if m == nil {
return nil
}
copied := make(map[Key]int, len(m))
for k, v := range m {
copied[k] = v
}
return copied
}
func (pl *PowerLevelsEventContent) Clone() *PowerLevelsEventContent {
if pl == nil {
return nil
}
return &PowerLevelsEventContent{
Users: copyMap(pl.Users),
UsersDefault: pl.UsersDefault,
Events: copyMap(pl.Events),
EventsDefault: pl.EventsDefault,
StateDefaultPtr: copyPtr(pl.StateDefaultPtr),
Notifications: pl.Notifications.Clone(),
InvitePtr: copyPtr(pl.InvitePtr),
KickPtr: copyPtr(pl.KickPtr),
BanPtr: copyPtr(pl.BanPtr),
RedactPtr: copyPtr(pl.RedactPtr),
HistoricalPtr: copyPtr(pl.HistoricalPtr),
}
}
type NotificationPowerLevels struct {
RoomPtr *int `json:"room,omitempty"`
}
func (npl *NotificationPowerLevels) Clone() *NotificationPowerLevels {
if npl == nil {
return nil
}
return &NotificationPowerLevels{
RoomPtr: copyPtr(npl.RoomPtr),
}
}
func (npl *NotificationPowerLevels) Room() int {
if npl != nil && npl.RoomPtr != nil {
return *npl.RoomPtr
}
return 50
}
func (pl *PowerLevelsEventContent) Invite() int {
if pl.InvitePtr != nil {
return *pl.InvitePtr

View File

@@ -32,6 +32,8 @@ type RelatesTo struct {
type InReplyTo struct {
EventID id.EventID `json:"event_id,omitempty"`
UnstableRoomID id.RoomID `json:"room_id,omitempty"`
}
func (rel *RelatesTo) Copy() *RelatesTo {

View File

@@ -124,7 +124,8 @@ func (et *Type) GuessClass() TypeClass {
CallInvite.Type, CallCandidates.Type, CallAnswer.Type, CallReject.Type, CallSelectAnswer.Type,
CallNegotiate.Type, CallHangup.Type, BeeperMessageStatus.Type:
return MessageEventType
case ToDeviceRoomKey.Type, ToDeviceRoomKeyRequest.Type, ToDeviceForwardedRoomKey.Type, ToDeviceRoomKeyWithheld.Type:
case ToDeviceRoomKey.Type, ToDeviceRoomKeyRequest.Type, ToDeviceForwardedRoomKey.Type, ToDeviceRoomKeyWithheld.Type,
ToDeviceBeeperRoomKeyAck.Type:
return ToDeviceEventType
default:
return UnknownEventType
@@ -253,4 +254,6 @@ var (
ToDeviceVerificationCancel = Type{"m.key.verification.cancel", ToDeviceEventType}
ToDeviceOrgMatrixRoomKeyWithheld = Type{"org.matrix.room_key.withheld", ToDeviceEventType}
ToDeviceBeeperRoomKeyAck = Type{"com.beeper.room_key.ack", ToDeviceEventType}
)

View File

@@ -17,7 +17,45 @@ import (
"maunium.net/go/mautrix/id"
)
type Context map[string]interface{}
type TagStack []string
func (ts TagStack) Index(tag string) int {
for i := len(ts) - 1; i >= 0; i-- {
if ts[i] == tag {
return i
}
}
return -1
}
func (ts TagStack) Has(tag string) bool {
return ts.Index(tag) >= 0
}
type Context struct {
ReturnData map[string]any
TagStack TagStack
PreserveWhitespace bool
}
func NewContext() Context {
return Context{
ReturnData: map[string]any{},
TagStack: make(TagStack, 0, 4),
}
}
func (ctx Context) WithTag(tag string) Context {
ctx.TagStack = append(ctx.TagStack, tag)
return ctx
}
func (ctx Context) WithWhitespace() Context {
ctx.PreserveWhitespace = true
return ctx
}
type TextConverter func(string, Context) string
type SpoilerConverter func(text, reason string, ctx Context) string
type LinkConverter func(text, href string, ctx Context) string
@@ -93,9 +131,9 @@ func Digits(num int) int {
return int(math.Floor(math.Log10(float64(num))) + 1)
}
func (parser *HTMLParser) listToString(node *html.Node, stripLinebreak bool, ctx Context) string {
func (parser *HTMLParser) listToString(node *html.Node, ctx Context) string {
ordered := node.Data == "ol"
taggedChildren := parser.nodeToTaggedStrings(node.FirstChild, stripLinebreak, ctx)
taggedChildren := parser.nodeToTaggedStrings(node.FirstChild, ctx)
counter := 1
indentLength := 0
if ordered {
@@ -137,8 +175,27 @@ func (parser *HTMLParser) listToString(node *html.Node, stripLinebreak bool, ctx
return strings.Join(children, "\n")
}
func (parser *HTMLParser) basicFormatToString(node *html.Node, stripLinebreak bool, ctx Context) string {
str := parser.nodeToTagAwareString(node.FirstChild, stripLinebreak, ctx)
func LongestSequence(in string, of rune) int {
currentSeq := 0
maxSeq := 0
for _, chr := range in {
if chr == of {
currentSeq++
} else {
if currentSeq > maxSeq {
maxSeq = currentSeq
}
currentSeq = 0
}
}
if currentSeq > maxSeq {
maxSeq = currentSeq
}
return maxSeq
}
func (parser *HTMLParser) basicFormatToString(node *html.Node, ctx Context) string {
str := parser.nodeToTagAwareString(node.FirstChild, ctx)
switch node.Data {
case "b", "strong":
if parser.BoldConverter != nil {
@@ -163,13 +220,14 @@ func (parser *HTMLParser) basicFormatToString(node *html.Node, stripLinebreak bo
if parser.MonospaceConverter != nil {
return parser.MonospaceConverter(str, ctx)
}
return fmt.Sprintf("`%s`", str)
surround := strings.Repeat("`", LongestSequence(str, '`')+1)
return fmt.Sprintf("%s%s%s", surround, str, surround)
}
return str
}
func (parser *HTMLParser) spanToString(node *html.Node, stripLinebreak bool, ctx Context) string {
str := parser.nodeToTagAwareString(node.FirstChild, stripLinebreak, ctx)
func (parser *HTMLParser) spanToString(node *html.Node, ctx Context) string {
str := parser.nodeToTagAwareString(node.FirstChild, ctx)
if node.Data == "span" {
reason, isSpoiler := parser.maybeGetAttribute(node, "data-mx-spoiler")
if isSpoiler {
@@ -195,15 +253,15 @@ func (parser *HTMLParser) spanToString(node *html.Node, stripLinebreak bool, ctx
return str
}
func (parser *HTMLParser) headerToString(node *html.Node, stripLinebreak bool, ctx Context) string {
children := parser.nodeToStrings(node.FirstChild, stripLinebreak, ctx)
func (parser *HTMLParser) headerToString(node *html.Node, ctx Context) string {
children := parser.nodeToStrings(node.FirstChild, ctx)
length := int(node.Data[1] - '0')
prefix := strings.Repeat("#", length) + " "
return prefix + strings.Join(children, "")
}
func (parser *HTMLParser) blockquoteToString(node *html.Node, stripLinebreak bool, ctx Context) string {
str := parser.nodeToTagAwareString(node.FirstChild, stripLinebreak, ctx)
func (parser *HTMLParser) blockquoteToString(node *html.Node, ctx Context) string {
str := parser.nodeToTagAwareString(node.FirstChild, ctx)
childrenArr := strings.Split(strings.TrimSpace(str), "\n")
// TODO make blockquote prefix configurable
for index, child := range childrenArr {
@@ -212,8 +270,8 @@ func (parser *HTMLParser) blockquoteToString(node *html.Node, stripLinebreak boo
return strings.Join(childrenArr, "\n")
}
func (parser *HTMLParser) linkToString(node *html.Node, stripLinebreak bool, ctx Context) string {
str := parser.nodeToTagAwareString(node.FirstChild, stripLinebreak, ctx)
func (parser *HTMLParser) linkToString(node *html.Node, ctx Context) string {
str := parser.nodeToTagAwareString(node.FirstChild, ctx)
href := parser.getAttribute(node, "href")
if len(href) == 0 {
return str
@@ -232,24 +290,25 @@ func (parser *HTMLParser) linkToString(node *html.Node, stripLinebreak bool, ctx
return fmt.Sprintf("%s (%s)", str, href)
}
func (parser *HTMLParser) tagToString(node *html.Node, stripLinebreak bool, ctx Context) string {
func (parser *HTMLParser) tagToString(node *html.Node, ctx Context) string {
ctx = ctx.WithTag(node.Data)
switch node.Data {
case "blockquote":
return parser.blockquoteToString(node, stripLinebreak, ctx)
return parser.blockquoteToString(node, ctx)
case "ol", "ul":
return parser.listToString(node, stripLinebreak, ctx)
return parser.listToString(node, ctx)
case "h1", "h2", "h3", "h4", "h5", "h6":
return parser.headerToString(node, stripLinebreak, ctx)
return parser.headerToString(node, ctx)
case "br":
return parser.Newline
case "b", "strong", "i", "em", "s", "strike", "del", "u", "ins", "tt", "code":
return parser.basicFormatToString(node, stripLinebreak, ctx)
return parser.basicFormatToString(node, ctx)
case "span", "font":
return parser.spanToString(node, stripLinebreak, ctx)
return parser.spanToString(node, ctx)
case "a":
return parser.linkToString(node, stripLinebreak, ctx)
return parser.linkToString(node, ctx)
case "p":
return parser.nodeToTagAwareString(node.FirstChild, stripLinebreak, ctx)
return parser.nodeToTagAwareString(node.FirstChild, ctx)
case "hr":
return parser.HorizontalLine
case "pre":
@@ -259,9 +318,9 @@ func (parser *HTMLParser) tagToString(node *html.Node, stripLinebreak bool, ctx
if strings.HasPrefix(class, "language-") {
language = class[len("language-"):]
}
preStr = parser.nodeToString(node.FirstChild.FirstChild, false, ctx)
preStr = parser.nodeToString(node.FirstChild.FirstChild, ctx.WithWhitespace())
} else {
preStr = parser.nodeToString(node.FirstChild, false, ctx)
preStr = parser.nodeToString(node.FirstChild, ctx.WithWhitespace())
}
if parser.MonospaceBlockConverter != nil {
return parser.MonospaceBlockConverter(preStr, language, ctx)
@@ -271,14 +330,14 @@ func (parser *HTMLParser) tagToString(node *html.Node, stripLinebreak bool, ctx
}
return fmt.Sprintf("```%s\n%s```", language, preStr)
default:
return parser.nodeToTagAwareString(node.FirstChild, stripLinebreak, ctx)
return parser.nodeToTagAwareString(node.FirstChild, ctx)
}
}
func (parser *HTMLParser) singleNodeToString(node *html.Node, stripLinebreak bool, ctx Context) TaggedString {
func (parser *HTMLParser) singleNodeToString(node *html.Node, ctx Context) TaggedString {
switch node.Type {
case html.TextNode:
if stripLinebreak {
if !ctx.PreserveWhitespace {
node.Data = strings.Replace(node.Data, "\n", "", -1)
}
if parser.TextConverter != nil {
@@ -286,17 +345,17 @@ func (parser *HTMLParser) singleNodeToString(node *html.Node, stripLinebreak boo
}
return TaggedString{node.Data, "text"}
case html.ElementNode:
return TaggedString{parser.tagToString(node, stripLinebreak, ctx), node.Data}
return TaggedString{parser.tagToString(node, ctx), node.Data}
case html.DocumentNode:
return TaggedString{parser.nodeToTagAwareString(node.FirstChild, stripLinebreak, ctx), "html"}
return TaggedString{parser.nodeToTagAwareString(node.FirstChild, ctx), "html"}
default:
return TaggedString{"", "unknown"}
}
}
func (parser *HTMLParser) nodeToTaggedStrings(node *html.Node, stripLinebreak bool, ctx Context) (strs []TaggedString) {
func (parser *HTMLParser) nodeToTaggedStrings(node *html.Node, ctx Context) (strs []TaggedString) {
for ; node != nil; node = node.NextSibling {
strs = append(strs, parser.singleNodeToString(node, stripLinebreak, ctx))
strs = append(strs, parser.singleNodeToString(node, ctx))
}
return
}
@@ -312,8 +371,8 @@ func (parser *HTMLParser) isBlockTag(tag string) bool {
return false
}
func (parser *HTMLParser) nodeToTagAwareString(node *html.Node, stripLinebreak bool, ctx Context) string {
strs := parser.nodeToTaggedStrings(node, stripLinebreak, ctx)
func (parser *HTMLParser) nodeToTagAwareString(node *html.Node, ctx Context) string {
strs := parser.nodeToTaggedStrings(node, ctx)
var output strings.Builder
for _, str := range strs {
tstr := str.string
@@ -325,15 +384,15 @@ func (parser *HTMLParser) nodeToTagAwareString(node *html.Node, stripLinebreak b
return strings.TrimSpace(output.String())
}
func (parser *HTMLParser) nodeToStrings(node *html.Node, stripLinebreak bool, ctx Context) (strs []string) {
func (parser *HTMLParser) nodeToStrings(node *html.Node, ctx Context) (strs []string) {
for ; node != nil; node = node.NextSibling {
strs = append(strs, parser.singleNodeToString(node, stripLinebreak, ctx).string)
strs = append(strs, parser.singleNodeToString(node, ctx).string)
}
return
}
func (parser *HTMLParser) nodeToString(node *html.Node, stripLinebreak bool, ctx Context) string {
return strings.Join(parser.nodeToStrings(node, stripLinebreak, ctx), "")
func (parser *HTMLParser) nodeToString(node *html.Node, ctx Context) string {
return strings.Join(parser.nodeToStrings(node, ctx), "")
}
// Parse converts Matrix HTML into text using the settings in this parser.
@@ -342,7 +401,7 @@ func (parser *HTMLParser) Parse(htmlData string, ctx Context) string {
htmlData = strings.Replace(htmlData, "\t", strings.Repeat(" ", parser.TabsToSpaces), -1)
}
node, _ := html.Parse(strings.NewReader(htmlData))
return parser.nodeToTagAwareString(node, true, ctx)
return parser.nodeToTagAwareString(node, ctx)
}
// HTMLToText converts Matrix HTML into text with the default settings.
@@ -352,7 +411,7 @@ func HTMLToText(html string) string {
Newline: "\n",
HorizontalLine: "\n---\n",
PillConverter: DefaultPillConverter,
}).Parse(html, make(Context))
}).Parse(html, NewContext())
}
// HTMLToMarkdown converts Matrix HTML into markdown with the default settings.
@@ -370,5 +429,5 @@ func HTMLToMarkdown(html string) string {
}
return fmt.Sprintf("[%s](%s)", text, href)
},
}).Parse(html, make(Context))
}).Parse(html, NewContext())
}

View File

@@ -0,0 +1,48 @@
// Copyright (c) 2023 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
package mdext
import (
"reflect"
"github.com/yuin/goldmark/parser"
"github.com/yuin/goldmark/util"
)
func filterParsers(list []util.PrioritizedValue, forbidden map[reflect.Type]struct{}) []util.PrioritizedValue {
n := 0
for _, item := range list {
if _, isForbidden := forbidden[reflect.TypeOf(item.Value)]; isForbidden {
continue
}
list[n] = item
n++
}
return list[:n]
}
// ParserWithoutFeatures returns a Goldmark parser with the provided default features removed.
//
// e.g. to disable lists, use
//
// markdown := goldmark.New(goldmark.WithParser(
// mdext.ParserWithoutFeatures(goldmark.NewListParser(), goldmark.NewListItemParser())
// ))
func ParserWithoutFeatures(features ...any) parser.Parser {
forbiddenTypes := make(map[reflect.Type]struct{}, len(features))
for _, feature := range features {
forbiddenTypes[reflect.TypeOf(feature)] = struct{}{}
}
filteredBlockParsers := filterParsers(parser.DefaultBlockParsers(), forbiddenTypes)
filteredInlineParsers := filterParsers(parser.DefaultInlineParsers(), forbiddenTypes)
filteredParagraphTransformers := filterParsers(parser.DefaultParagraphTransformers(), forbiddenTypes)
return parser.NewParser(
parser.WithBlockParsers(filteredBlockParsers...),
parser.WithInlineParsers(filteredInlineParsers...),
parser.WithParagraphTransformers(filteredParagraphTransformers...),
)
}

View File

@@ -67,10 +67,10 @@ func (uri *MatrixURI) getQuery() url.Values {
func (uri *MatrixURI) String() string {
parts := []string{
SigilToPathSegment[uri.Sigil1],
uri.MXID1,
url.PathEscape(uri.MXID1),
}
if uri.Sigil2 != 0 {
parts = append(parts, SigilToPathSegment[uri.Sigil2], uri.MXID2)
parts = append(parts, SigilToPathSegment[uri.Sigil2], url.PathEscape(uri.MXID2))
}
return (&url.URL{
Scheme: "matrix",
@@ -81,9 +81,9 @@ func (uri *MatrixURI) String() string {
// MatrixToURL converts to parsed matrix: URI into a matrix.to URL
func (uri *MatrixURI) MatrixToURL() string {
fragment := fmt.Sprintf("#/%s", url.QueryEscape(uri.PrimaryIdentifier()))
fragment := fmt.Sprintf("#/%s", url.PathEscape(uri.PrimaryIdentifier()))
if uri.Sigil2 != 0 {
fragment = fmt.Sprintf("%s/%s", fragment, url.QueryEscape(uri.SecondaryIdentifier()))
fragment = fmt.Sprintf("%s/%s", fragment, url.PathEscape(uri.SecondaryIdentifier()))
}
query := uri.getQuery().Encode()
if len(query) > 0 {

View File

@@ -33,14 +33,15 @@ type PushActionArray []*PushAction
// PushActionArrayShould contains the important information parsed from a PushActionArray.
type PushActionArrayShould struct {
// Whether or not the array contained a Notify, DontNotify or Coalesce action type.
// Whether the array contained a Notify, DontNotify or Coalesce action type.
// Deprecated: an empty array should be treated as no notification, so there's no reason to check this field.
NotifySpecified bool
// Whether or not the event in question should trigger a notification.
// Whether the event in question should trigger a notification.
Notify bool
// Whether or not the event in question should be highlighted.
// Whether the event in question should be highlighted.
Highlight bool
// Whether or not the event in question should trigger a sound alert.
// Whether the event in question should trigger a sound alert.
PlaySound bool
// The name of the sound to play if PlaySound is true.
SoundName string

View File

@@ -20,6 +20,7 @@ func init() {
}
type PushRuleCollection interface {
GetMatchingRule(room Room, evt *event.Event) *PushRule
GetActions(room Room, evt *event.Event) PushActionArray
}
@@ -32,16 +33,20 @@ func (rules PushRuleArray) SetType(typ PushRuleType) PushRuleArray {
return rules
}
func (rules PushRuleArray) GetActions(room Room, evt *event.Event) PushActionArray {
func (rules PushRuleArray) GetMatchingRule(room Room, evt *event.Event) *PushRule {
for _, rule := range rules {
if !rule.Match(room, evt) {
continue
}
return rule.Actions
return rule
}
return nil
}
func (rules PushRuleArray) GetActions(room Room, evt *event.Event) PushActionArray {
return rules.GetMatchingRule(room, evt).GetActions()
}
type PushRuleMap struct {
Map map[string]*PushRule
Type PushRuleType
@@ -59,7 +64,7 @@ func (rules PushRuleArray) SetTypeAndMap(typ PushRuleType) PushRuleMap {
return data
}
func (ruleMap PushRuleMap) GetActions(room Room, evt *event.Event) PushActionArray {
func (ruleMap PushRuleMap) GetMatchingRule(room Room, evt *event.Event) *PushRule {
var rule *PushRule
var found bool
switch ruleMap.Type {
@@ -69,11 +74,15 @@ func (ruleMap PushRuleMap) GetActions(room Room, evt *event.Event) PushActionArr
rule, found = ruleMap.Map[string(evt.Sender)]
}
if found && rule.Match(room, evt) {
return rule.Actions
return rule
}
return nil
}
func (ruleMap PushRuleMap) GetActions(room Room, evt *event.Event) PushActionArray {
return ruleMap.GetMatchingRule(room, evt).GetActions()
}
func (ruleMap PushRuleMap) Unmap() PushRuleArray {
array := make(PushRuleArray, len(ruleMap.Map))
index := 0
@@ -114,8 +123,15 @@ type PushRule struct {
Pattern string `json:"pattern,omitempty"`
}
func (rule *PushRule) GetActions() PushActionArray {
if rule == nil {
return nil
}
return rule.Actions
}
func (rule *PushRule) Match(room Room, evt *event.Event) bool {
if !rule.Enabled {
if rule == nil || !rule.Enabled {
return false
}
switch rule.Type {

View File

@@ -67,10 +67,7 @@ func (rs *PushRuleset) MarshalJSON() ([]byte, error) {
// collections in a Ruleset match the event given to GetActions()
var DefaultPushActions = PushActionArray{&PushAction{Action: ActionDontNotify}}
// GetActions matches the given event against all of the push rule
// collections in this push ruleset in the order of priority as
// specified in spec section 11.12.1.4.
func (rs *PushRuleset) GetActions(room Room, evt *event.Event) (match PushActionArray) {
func (rs *PushRuleset) GetMatchingRule(room Room, evt *event.Event) (rule *PushRule) {
// Add push rule collections to array in priority order
arrays := []PushRuleCollection{rs.Override, rs.Content, rs.Room, rs.Sender, rs.Underride}
// Loop until one of the push rule collections matches the room/event combo.
@@ -78,11 +75,23 @@ func (rs *PushRuleset) GetActions(room Room, evt *event.Event) (match PushAction
if pra == nil {
continue
}
if match = pra.GetActions(room, evt); match != nil {
if rule = pra.GetMatchingRule(room, evt); rule != nil {
// Match found, return it.
return
}
}
// No match found, return default actions.
return DefaultPushActions
// No match found
return nil
}
// GetActions matches the given event against all of the push rule
// collections in this push ruleset in the order of priority as
// specified in spec section 11.12.1.4.
func (rs *PushRuleset) GetActions(room Room, evt *event.Event) (match PushActionArray) {
actions := rs.GetMatchingRule(room, evt).GetActions()
if actions == nil {
// No match found, return default actions.
return DefaultPushActions
}
return actions
}

View File

@@ -392,6 +392,10 @@ func (req *ReqHierarchy) Query() map[string]string {
return query
}
type ReqAppservicePing struct {
TxnID string `json:"transaction_id,omitempty"`
}
type ReqBeeperMergeRoom struct {
NewRoom ReqCreateRoom `json:"create"`
Key string `json:"key"`
@@ -411,3 +415,23 @@ type ReqBeeperSplitRoom struct {
Key string `json:"key"`
Parts []BeeperSplitRoomPart `json:"parts"`
}
type ReqRoomKeysVersionCreate struct {
Algorithm string `json:"algorithm"`
AuthData json.RawMessage `json:"auth_data"`
}
type ReqRoomKeysUpdate struct {
Rooms map[id.RoomID]ReqRoomKeysRoomUpdate `json:"rooms"`
}
type ReqRoomKeysRoomUpdate struct {
Sessions map[id.SessionID]ReqRoomKeysSessionUpdate `json:"sessions"`
}
type ReqRoomKeysSessionUpdate struct {
FirstMessageIndex int `json:"first_message_index"`
ForwardedCount int `json:"forwarded_count"`
IsVerified bool `json:"is_verified"`
SessionData json.RawMessage `json:"session_data"`
}

View File

@@ -552,6 +552,10 @@ type StrippedStateWithTime struct {
Timestamp jsontime.UnixMilli `json:"origin_server_ts"`
}
type RespAppservicePing struct {
DurationMS int64 `json:"duration_ms"`
}
type RespBeeperMergeRoom RespCreateRoom
type RespBeeperSplitRoom struct {
@@ -562,3 +566,35 @@ type RespTimestampToEvent struct {
EventID id.EventID `json:"event_id"`
Timestamp jsontime.UnixMilli `json:"origin_server_ts"`
}
type RespRoomKeysVersionCreate struct {
Version string `json:"version"`
}
type RespRoomKeysVersion struct {
Algorithm string `json:"algorithm"`
AuthData json.RawMessage `json:"auth_data"`
Count int `json:"count"`
ETag string `json:"etag"`
Version string `json:"version"`
}
type RespRoomKeys struct {
Rooms map[id.RoomID]RespRoomKeysRoom `json:"rooms"`
}
type RespRoomKeysRoom struct {
Sessions map[id.SessionID]RespRoomKeysSession `json:"sessions"`
}
type RespRoomKeysSession struct {
FirstMessageIndex int `json:"first_message_index"`
ForwardedCount int `json:"forwarded_count"`
IsVerified bool `json:"is_verified"`
SessionData json.RawMessage `json:"session_data"`
}
type RespRoomKeysUpdate struct {
Count int `json:"count"`
ETag string `json:"etag"`
}

View File

@@ -0,0 +1,368 @@
// Copyright (c) 2022 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
package sqlstatestore
import (
"database/sql"
"embed"
"encoding/json"
"errors"
"fmt"
"strconv"
"strings"
"maunium.net/go/mautrix/event"
"maunium.net/go/mautrix/id"
"maunium.net/go/mautrix/util/dbutil"
)
//go:embed *.sql
var rawUpgrades embed.FS
var UpgradeTable dbutil.UpgradeTable
func init() {
UpgradeTable.RegisterFS(rawUpgrades)
}
const VersionTableName = "mx_version"
type SQLStateStore struct {
*dbutil.Database
IsBridge bool
}
func NewSQLStateStore(db *dbutil.Database, log dbutil.DatabaseLogger, isBridge bool) *SQLStateStore {
return &SQLStateStore{
Database: db.Child(VersionTableName, UpgradeTable, log),
IsBridge: isBridge,
}
}
func (store *SQLStateStore) IsRegistered(userID id.UserID) bool {
var isRegistered bool
err := store.
QueryRow("SELECT EXISTS(SELECT 1 FROM mx_registrations WHERE user_id=$1)", userID).
Scan(&isRegistered)
if err != nil {
store.Log.Warn("Failed to scan registration existence for %s: %v", userID, err)
}
return isRegistered
}
func (store *SQLStateStore) MarkRegistered(userID id.UserID) {
_, err := store.Exec("INSERT INTO mx_registrations (user_id) VALUES ($1) ON CONFLICT (user_id) DO NOTHING", userID)
if err != nil {
store.Log.Warn("Failed to mark %s as registered: %v", userID, err)
}
}
func (store *SQLStateStore) GetRoomMembers(roomID id.RoomID, memberships ...event.Membership) map[id.UserID]*event.MemberEventContent {
members := make(map[id.UserID]*event.MemberEventContent)
args := make([]any, len(memberships)+1)
args[0] = roomID
query := "SELECT user_id, membership, displayname, avatar_url FROM mx_user_profile WHERE room_id=$1"
if len(memberships) > 0 {
placeholders := make([]string, len(memberships))
for i, membership := range memberships {
args[i+1] = string(membership)
placeholders[i] = fmt.Sprintf("$%d", i+2)
}
query = fmt.Sprintf("%s AND membership IN (%s)", query, strings.Join(placeholders, ","))
}
rows, err := store.Query(query, args...)
if err != nil {
return members
}
var userID id.UserID
var member event.MemberEventContent
for rows.Next() {
err = rows.Scan(&userID, &member.Membership, &member.Displayname, &member.AvatarURL)
if err != nil {
store.Log.Warn("Failed to scan member in %s: %v", roomID, err)
} else {
members[userID] = &member
}
}
return members
}
func (store *SQLStateStore) GetRoomJoinedOrInvitedMembers(roomID id.RoomID) (members []id.UserID, err error) {
memberMap := store.GetRoomMembers(roomID, event.MembershipJoin, event.MembershipInvite)
members = make([]id.UserID, len(memberMap))
i := 0
for userID := range memberMap {
members[i] = userID
i++
}
return
}
func (store *SQLStateStore) GetMembership(roomID id.RoomID, userID id.UserID) event.Membership {
membership := event.MembershipLeave
err := store.
QueryRow("SELECT membership FROM mx_user_profile WHERE room_id=$1 AND user_id=$2", roomID, userID).
Scan(&membership)
if err != nil && !errors.Is(err, sql.ErrNoRows) {
store.Log.Warn("Failed to scan membership of %s in %s: %v", userID, roomID, err)
}
return membership
}
func (store *SQLStateStore) GetMember(roomID id.RoomID, userID id.UserID) *event.MemberEventContent {
member, ok := store.TryGetMember(roomID, userID)
if !ok {
member.Membership = event.MembershipLeave
}
return member
}
func (store *SQLStateStore) TryGetMember(roomID id.RoomID, userID id.UserID) (*event.MemberEventContent, bool) {
var member event.MemberEventContent
err := store.
QueryRow("SELECT membership, displayname, avatar_url FROM mx_user_profile WHERE room_id=$1 AND user_id=$2", roomID, userID).
Scan(&member.Membership, &member.Displayname, &member.AvatarURL)
if err != nil && !errors.Is(err, sql.ErrNoRows) {
store.Log.Warn("Failed to scan member info of %s in %s: %v", userID, roomID, err)
}
return &member, err == nil
}
func (store *SQLStateStore) FindSharedRooms(userID id.UserID) (rooms []id.RoomID) {
query := `
SELECT room_id FROM mx_user_profile
LEFT JOIN portal ON portal.mxid=mx_user_profile.room_id
WHERE mx_user_profile.user_id=$1 AND portal.encrypted=true
`
if !store.IsBridge {
query = `
SELECT mx_user_profile.room_id FROM mx_user_profile
LEFT JOIN mx_room_state ON mx_room_state.room_id=mx_user_profile.room_id
WHERE mx_user_profile.user_id=$1 AND mx_room_state.encryption IS NOT NULL
`
}
rows, err := store.Query(query, userID)
if err != nil {
store.Log.Warn("Failed to query shared rooms with %s: %v", userID, err)
return
}
for rows.Next() {
var roomID id.RoomID
err = rows.Scan(&roomID)
if err != nil {
store.Log.Warn("Failed to scan room ID: %v", err)
} else {
rooms = append(rooms, roomID)
}
}
return
}
func (store *SQLStateStore) IsInRoom(roomID id.RoomID, userID id.UserID) bool {
return store.IsMembership(roomID, userID, "join")
}
func (store *SQLStateStore) IsInvited(roomID id.RoomID, userID id.UserID) bool {
return store.IsMembership(roomID, userID, "join", "invite")
}
func (store *SQLStateStore) IsMembership(roomID id.RoomID, userID id.UserID, allowedMemberships ...event.Membership) bool {
membership := store.GetMembership(roomID, userID)
for _, allowedMembership := range allowedMemberships {
if allowedMembership == membership {
return true
}
}
return false
}
func (store *SQLStateStore) SetMembership(roomID id.RoomID, userID id.UserID, membership event.Membership) {
_, err := store.Exec(`
INSERT INTO mx_user_profile (room_id, user_id, membership, displayname, avatar_url) VALUES ($1, $2, $3, '', '')
ON CONFLICT (room_id, user_id) DO UPDATE SET membership=excluded.membership
`, roomID, userID, membership)
if err != nil {
store.Log.Warn("Failed to set membership of %s in %s to %s: %v", userID, roomID, membership, err)
}
}
func (store *SQLStateStore) SetMember(roomID id.RoomID, userID id.UserID, member *event.MemberEventContent) {
_, err := store.Exec(`
INSERT INTO mx_user_profile (room_id, user_id, membership, displayname, avatar_url) VALUES ($1, $2, $3, $4, $5)
ON CONFLICT (room_id, user_id) DO UPDATE SET membership=excluded.membership, displayname=excluded.displayname, avatar_url=excluded.avatar_url
`, roomID, userID, member.Membership, member.Displayname, member.AvatarURL)
if err != nil {
store.Log.Warn("Failed to set membership of %s in %s to %s: %v", userID, roomID, member, err)
}
}
func (store *SQLStateStore) ClearCachedMembers(roomID id.RoomID, memberships ...event.Membership) {
query := "DELETE FROM mx_user_profile WHERE room_id=$1"
params := make([]any, len(memberships)+1)
params[0] = roomID
if len(memberships) > 0 {
placeholders := make([]string, len(memberships))
for i, membership := range memberships {
placeholders[i] = "$" + strconv.Itoa(i+2)
params[i+1] = string(membership)
}
query += fmt.Sprintf(" AND membership IN (%s)", strings.Join(placeholders, ","))
}
_, err := store.Exec(query, params...)
if err != nil {
store.Log.Warn("Failed to clear cached members of %s: %v", roomID, err)
}
}
func (store *SQLStateStore) SetEncryptionEvent(roomID id.RoomID, content *event.EncryptionEventContent) {
contentBytes, err := json.Marshal(content)
if err != nil {
store.Log.Warn("Failed to marshal encryption config of %s: %v", roomID, err)
return
}
_, err = store.Exec(`
INSERT INTO mx_room_state (room_id, encryption) VALUES ($1, $2)
ON CONFLICT (room_id) DO UPDATE SET encryption=excluded.encryption
`, roomID, contentBytes)
if err != nil {
store.Log.Warn("Failed to store encryption config of %s: %v", roomID, err)
}
}
func (store *SQLStateStore) GetEncryptionEvent(roomID id.RoomID) *event.EncryptionEventContent {
var data []byte
err := store.
QueryRow("SELECT encryption FROM mx_room_state WHERE room_id=$1", roomID).
Scan(&data)
if err != nil {
if !errors.Is(err, sql.ErrNoRows) {
store.Log.Warn("Failed to scan encryption config of %s: %v", roomID, err)
}
return nil
} else if data == nil {
return nil
}
content := &event.EncryptionEventContent{}
err = json.Unmarshal(data, content)
if err != nil {
store.Log.Warn("Failed to parse encryption config of %s: %v", roomID, err)
return nil
}
return content
}
func (store *SQLStateStore) IsEncrypted(roomID id.RoomID) bool {
cfg := store.GetEncryptionEvent(roomID)
return cfg != nil && cfg.Algorithm == id.AlgorithmMegolmV1
}
func (store *SQLStateStore) SetPowerLevels(roomID id.RoomID, levels *event.PowerLevelsEventContent) {
levelsBytes, err := json.Marshal(levels)
if err != nil {
store.Log.Warn("Failed to marshal power levels of %s: %v", roomID, err)
return
}
_, err = store.Exec(`
INSERT INTO mx_room_state (room_id, power_levels) VALUES ($1, $2)
ON CONFLICT (room_id) DO UPDATE SET power_levels=excluded.power_levels
`, roomID, levelsBytes)
if err != nil {
store.Log.Warn("Failed to store power levels of %s: %v", roomID, err)
}
}
func (store *SQLStateStore) GetPowerLevels(roomID id.RoomID) (levels *event.PowerLevelsEventContent) {
var data []byte
err := store.
QueryRow("SELECT power_levels FROM mx_room_state WHERE room_id=$1", roomID).
Scan(&data)
if err != nil {
if !errors.Is(err, sql.ErrNoRows) {
store.Log.Warn("Failed to scan power levels of %s: %v", roomID, err)
}
return
} else if data == nil {
return
}
levels = &event.PowerLevelsEventContent{}
err = json.Unmarshal(data, levels)
if err != nil {
store.Log.Warn("Failed to parse power levels of %s: %v", roomID, err)
return nil
}
return
}
func (store *SQLStateStore) GetPowerLevel(roomID id.RoomID, userID id.UserID) int {
if store.Dialect == dbutil.Postgres {
var powerLevel int
err := store.
QueryRow(`
SELECT COALESCE((power_levels->'users'->$2)::int, (power_levels->'users_default')::int, 0)
FROM mx_room_state WHERE room_id=$1
`, roomID, userID).
Scan(&powerLevel)
if err != nil && !errors.Is(err, sql.ErrNoRows) {
store.Log.Warn("Failed to scan power level of %s in %s: %v", userID, roomID, err)
}
return powerLevel
}
return store.GetPowerLevels(roomID).GetUserLevel(userID)
}
func (store *SQLStateStore) GetPowerLevelRequirement(roomID id.RoomID, eventType event.Type) int {
if store.Dialect == dbutil.Postgres {
defaultType := "events_default"
defaultValue := 0
if eventType.IsState() {
defaultType = "state_default"
defaultValue = 50
}
var powerLevel int
err := store.
QueryRow(`
SELECT COALESCE((power_levels->'events'->$2)::int, (power_levels->'$3')::int, $4)
FROM mx_room_state WHERE room_id=$1
`, roomID, eventType.Type, defaultType, defaultValue).
Scan(&powerLevel)
if err != nil {
if !errors.Is(err, sql.ErrNoRows) {
store.Log.Warn("Failed to scan power level for %s in %s: %v", eventType, roomID, err)
}
return defaultValue
}
return powerLevel
}
return store.GetPowerLevels(roomID).GetEventLevel(eventType)
}
func (store *SQLStateStore) HasPowerLevel(roomID id.RoomID, userID id.UserID, eventType event.Type) bool {
if store.Dialect == dbutil.Postgres {
defaultType := "events_default"
defaultValue := 0
if eventType.IsState() {
defaultType = "state_default"
defaultValue = 50
}
var hasPower bool
err := store.
QueryRow(`SELECT
COALESCE((power_levels->'users'->$2)::int, (power_levels->'users_default')::int, 0)
>=
COALESCE((power_levels->'events'->$3)::int, (power_levels->'$4')::int, $5)
FROM mx_room_state WHERE room_id=$1`, roomID, userID, eventType.Type, defaultType, defaultValue).
Scan(&hasPower)
if err != nil {
if !errors.Is(err, sql.ErrNoRows) {
store.Log.Warn("Failed to scan power level for %s in %s: %v", eventType, roomID, err)
}
return defaultValue == 0
}
return hasPower
}
return store.GetPowerLevel(roomID, userID) >= store.GetPowerLevelRequirement(roomID, eventType)
}

View File

@@ -0,0 +1,23 @@
-- v0 -> v5: Latest revision
CREATE TABLE mx_registrations (
user_id TEXT PRIMARY KEY
);
-- only: postgres
CREATE TYPE membership AS ENUM ('join', 'leave', 'invite', 'ban', 'knock');
CREATE TABLE mx_user_profile (
room_id TEXT,
user_id TEXT,
membership membership NOT NULL,
displayname TEXT NOT NULL DEFAULT '',
avatar_url TEXT NOT NULL DEFAULT '',
PRIMARY KEY (room_id, user_id)
);
CREATE TABLE mx_room_state (
room_id TEXT PRIMARY KEY,
power_levels jsonb,
encryption jsonb
);

View File

@@ -0,0 +1,6 @@
-- v2: Use enum for membership field on Postgres
-- only: postgres
CREATE TYPE membership AS ENUM ('join', 'leave', 'invite', 'ban', 'knock');
UPDATE mx_user_profile SET membership='leave' WHERE LOWER(membership) NOT IN ('join', 'leave', 'invite', 'ban', 'knock');
ALTER TABLE mx_user_profile ALTER COLUMN membership TYPE membership USING LOWER(membership)::membership;

View File

@@ -0,0 +1,10 @@
-- v3: Disable nulls in mx_user_profile
UPDATE mx_user_profile SET displayname='' WHERE displayname IS NULL;
UPDATE mx_user_profile SET avatar_url='' WHERE avatar_url IS NULL;
-- only: postgres for next 4 lines
ALTER TABLE mx_user_profile ALTER COLUMN displayname SET DEFAULT '';
ALTER TABLE mx_user_profile ALTER COLUMN displayname SET NOT NULL;
ALTER TABLE mx_user_profile ALTER COLUMN avatar_url SET DEFAULT '';
ALTER TABLE mx_user_profile ALTER COLUMN avatar_url SET NOT NULL;

View File

@@ -0,0 +1,2 @@
-- v4: Store room encryption configuration
ALTER TABLE mx_room_state ADD COLUMN encryption jsonb;

View File

@@ -0,0 +1,29 @@
package sqlstatestore
import (
"fmt"
"maunium.net/go/mautrix/util/dbutil"
)
func init() {
UpgradeTable.Register(-1, 5, 0, "Mark rooms that need crypto state event resynced", true, func(tx dbutil.Execable, db *dbutil.Database) error {
portalExists, err := db.TableExists(tx, "portal")
if err != nil {
return fmt.Errorf("failed to check if portal table exists")
}
if portalExists {
_, err = tx.Exec(`
INSERT INTO mx_room_state (room_id, encryption)
SELECT portal.mxid, '{"resync":true}' FROM portal WHERE portal.encrypted=true AND portal.mxid IS NOT NULL
ON CONFLICT (room_id) DO UPDATE
SET encryption=excluded.encryption
WHERE mx_room_state.encryption IS NULL
`)
if err != nil {
return err
}
}
return nil
})
}

251
vendor/maunium.net/go/mautrix/statestore.go generated vendored Normal file
View File

@@ -0,0 +1,251 @@
// Copyright (c) 2023 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
package mautrix
import (
"sync"
"maunium.net/go/mautrix/event"
"maunium.net/go/mautrix/id"
)
// StateStore is an interface for storing basic room state information.
type StateStore interface {
IsInRoom(roomID id.RoomID, userID id.UserID) bool
IsInvited(roomID id.RoomID, userID id.UserID) bool
IsMembership(roomID id.RoomID, userID id.UserID, allowedMemberships ...event.Membership) bool
GetMember(roomID id.RoomID, userID id.UserID) *event.MemberEventContent
TryGetMember(roomID id.RoomID, userID id.UserID) (*event.MemberEventContent, bool)
SetMembership(roomID id.RoomID, userID id.UserID, membership event.Membership)
SetMember(roomID id.RoomID, userID id.UserID, member *event.MemberEventContent)
ClearCachedMembers(roomID id.RoomID, memberships ...event.Membership)
SetPowerLevels(roomID id.RoomID, levels *event.PowerLevelsEventContent)
GetPowerLevels(roomID id.RoomID) *event.PowerLevelsEventContent
SetEncryptionEvent(roomID id.RoomID, content *event.EncryptionEventContent)
IsEncrypted(roomID id.RoomID) bool
GetRoomJoinedOrInvitedMembers(roomID id.RoomID) ([]id.UserID, error)
}
func UpdateStateStore(store StateStore, evt *event.Event) {
if store == nil || evt == nil || evt.StateKey == nil {
return
}
// We only care about events without a state key (power levels, encryption) or member events with state key
if evt.Type != event.StateMember && evt.GetStateKey() != "" {
return
}
switch content := evt.Content.Parsed.(type) {
case *event.MemberEventContent:
store.SetMember(evt.RoomID, id.UserID(evt.GetStateKey()), content)
case *event.PowerLevelsEventContent:
store.SetPowerLevels(evt.RoomID, content)
case *event.EncryptionEventContent:
store.SetEncryptionEvent(evt.RoomID, content)
}
}
// StateStoreSyncHandler can be added as an event handler in the syncer to update the state store automatically.
//
// client.Syncer.(mautrix.ExtensibleSyncer).OnEvent(client.StateStoreSyncHandler)
//
// DefaultSyncer.ParseEventContent must also be true for this to work (which it is by default).
func (cli *Client) StateStoreSyncHandler(_ EventSource, evt *event.Event) {
UpdateStateStore(cli.StateStore, evt)
}
type MemoryStateStore struct {
Registrations map[id.UserID]bool `json:"registrations"`
Members map[id.RoomID]map[id.UserID]*event.MemberEventContent `json:"memberships"`
PowerLevels map[id.RoomID]*event.PowerLevelsEventContent `json:"power_levels"`
Encryption map[id.RoomID]*event.EncryptionEventContent `json:"encryption"`
registrationsLock sync.RWMutex
membersLock sync.RWMutex
powerLevelsLock sync.RWMutex
encryptionLock sync.RWMutex
}
func NewMemoryStateStore() StateStore {
return &MemoryStateStore{
Registrations: make(map[id.UserID]bool),
Members: make(map[id.RoomID]map[id.UserID]*event.MemberEventContent),
PowerLevels: make(map[id.RoomID]*event.PowerLevelsEventContent),
Encryption: make(map[id.RoomID]*event.EncryptionEventContent),
}
}
func (store *MemoryStateStore) IsRegistered(userID id.UserID) bool {
store.registrationsLock.RLock()
defer store.registrationsLock.RUnlock()
registered, ok := store.Registrations[userID]
return ok && registered
}
func (store *MemoryStateStore) MarkRegistered(userID id.UserID) {
store.registrationsLock.Lock()
defer store.registrationsLock.Unlock()
store.Registrations[userID] = true
}
func (store *MemoryStateStore) GetRoomMembers(roomID id.RoomID) map[id.UserID]*event.MemberEventContent {
store.membersLock.RLock()
members, ok := store.Members[roomID]
store.membersLock.RUnlock()
if !ok {
members = make(map[id.UserID]*event.MemberEventContent)
store.membersLock.Lock()
store.Members[roomID] = members
store.membersLock.Unlock()
}
return members
}
func (store *MemoryStateStore) GetRoomJoinedOrInvitedMembers(roomID id.RoomID) ([]id.UserID, error) {
members := store.GetRoomMembers(roomID)
ids := make([]id.UserID, 0, len(members))
for id := range members {
ids = append(ids, id)
}
return ids, nil
}
func (store *MemoryStateStore) GetMembership(roomID id.RoomID, userID id.UserID) event.Membership {
return store.GetMember(roomID, userID).Membership
}
func (store *MemoryStateStore) GetMember(roomID id.RoomID, userID id.UserID) *event.MemberEventContent {
member, ok := store.TryGetMember(roomID, userID)
if !ok {
member = &event.MemberEventContent{Membership: event.MembershipLeave}
}
return member
}
func (store *MemoryStateStore) TryGetMember(roomID id.RoomID, userID id.UserID) (member *event.MemberEventContent, ok bool) {
store.membersLock.RLock()
defer store.membersLock.RUnlock()
members, membersOk := store.Members[roomID]
if !membersOk {
return
}
member, ok = members[userID]
return
}
func (store *MemoryStateStore) IsInRoom(roomID id.RoomID, userID id.UserID) bool {
return store.IsMembership(roomID, userID, "join")
}
func (store *MemoryStateStore) IsInvited(roomID id.RoomID, userID id.UserID) bool {
return store.IsMembership(roomID, userID, "join", "invite")
}
func (store *MemoryStateStore) IsMembership(roomID id.RoomID, userID id.UserID, allowedMemberships ...event.Membership) bool {
membership := store.GetMembership(roomID, userID)
for _, allowedMembership := range allowedMemberships {
if allowedMembership == membership {
return true
}
}
return false
}
func (store *MemoryStateStore) SetMembership(roomID id.RoomID, userID id.UserID, membership event.Membership) {
store.membersLock.Lock()
members, ok := store.Members[roomID]
if !ok {
members = map[id.UserID]*event.MemberEventContent{
userID: {Membership: membership},
}
} else {
member, ok := members[userID]
if !ok {
members[userID] = &event.MemberEventContent{Membership: membership}
} else {
member.Membership = membership
members[userID] = member
}
}
store.Members[roomID] = members
store.membersLock.Unlock()
}
func (store *MemoryStateStore) SetMember(roomID id.RoomID, userID id.UserID, member *event.MemberEventContent) {
store.membersLock.Lock()
members, ok := store.Members[roomID]
if !ok {
members = map[id.UserID]*event.MemberEventContent{
userID: member,
}
} else {
members[userID] = member
}
store.Members[roomID] = members
store.membersLock.Unlock()
}
func (store *MemoryStateStore) ClearCachedMembers(roomID id.RoomID, memberships ...event.Membership) {
store.membersLock.Lock()
defer store.membersLock.Unlock()
members, ok := store.Members[roomID]
if !ok {
return
}
for userID, member := range members {
for _, membership := range memberships {
if membership == member.Membership {
delete(members, userID)
break
}
}
}
}
func (store *MemoryStateStore) SetPowerLevels(roomID id.RoomID, levels *event.PowerLevelsEventContent) {
store.powerLevelsLock.Lock()
store.PowerLevels[roomID] = levels
store.powerLevelsLock.Unlock()
}
func (store *MemoryStateStore) GetPowerLevels(roomID id.RoomID) (levels *event.PowerLevelsEventContent) {
store.powerLevelsLock.RLock()
levels = store.PowerLevels[roomID]
store.powerLevelsLock.RUnlock()
return
}
func (store *MemoryStateStore) GetPowerLevel(roomID id.RoomID, userID id.UserID) int {
return store.GetPowerLevels(roomID).GetUserLevel(userID)
}
func (store *MemoryStateStore) GetPowerLevelRequirement(roomID id.RoomID, eventType event.Type) int {
return store.GetPowerLevels(roomID).GetEventLevel(eventType)
}
func (store *MemoryStateStore) HasPowerLevel(roomID id.RoomID, userID id.UserID, eventType event.Type) bool {
return store.GetPowerLevel(roomID, userID) >= store.GetPowerLevelRequirement(roomID, eventType)
}
func (store *MemoryStateStore) SetEncryptionEvent(roomID id.RoomID, content *event.EncryptionEventContent) {
store.encryptionLock.Lock()
store.Encryption[roomID] = content
store.encryptionLock.Unlock()
}
func (store *MemoryStateStore) GetEncryptionEvent(roomID id.RoomID) *event.EncryptionEventContent {
store.encryptionLock.RLock()
defer store.encryptionLock.RUnlock()
return store.Encryption[roomID]
}
func (store *MemoryStateStore) IsEncrypted(roomID id.RoomID) bool {
cfg := store.GetEncryptionEvent(roomID)
return cfg != nil && cfg.Algorithm == id.AlgorithmMegolmV1
}

View File

@@ -1,161 +0,0 @@
package mautrix
import (
"maunium.net/go/mautrix/event"
"maunium.net/go/mautrix/id"
)
// Storer is an interface which must be satisfied to store client data.
//
// You can either write a struct which persists this data to disk, or you can use the
// provided "InMemoryStore" which just keeps data around in-memory which is lost on
// restarts.
type Storer interface {
SaveFilterID(userID id.UserID, filterID string)
LoadFilterID(userID id.UserID) string
SaveNextBatch(userID id.UserID, nextBatchToken string)
LoadNextBatch(userID id.UserID) string
SaveRoom(room *Room)
LoadRoom(roomID id.RoomID) *Room
}
// InMemoryStore implements the Storer interface.
//
// Everything is persisted in-memory as maps. It is not safe to load/save filter IDs
// or next batch tokens on any goroutine other than the syncing goroutine: the one
// which called Client.Sync().
type InMemoryStore struct {
Filters map[id.UserID]string
NextBatch map[id.UserID]string
Rooms map[id.RoomID]*Room
}
// SaveFilterID to memory.
func (s *InMemoryStore) SaveFilterID(userID id.UserID, filterID string) {
s.Filters[userID] = filterID
}
// LoadFilterID from memory.
func (s *InMemoryStore) LoadFilterID(userID id.UserID) string {
return s.Filters[userID]
}
// SaveNextBatch to memory.
func (s *InMemoryStore) SaveNextBatch(userID id.UserID, nextBatchToken string) {
s.NextBatch[userID] = nextBatchToken
}
// LoadNextBatch from memory.
func (s *InMemoryStore) LoadNextBatch(userID id.UserID) string {
return s.NextBatch[userID]
}
// SaveRoom to memory.
func (s *InMemoryStore) SaveRoom(room *Room) {
s.Rooms[room.ID] = room
}
// LoadRoom from memory.
func (s *InMemoryStore) LoadRoom(roomID id.RoomID) *Room {
return s.Rooms[roomID]
}
// UpdateState stores a state event. This can be passed to DefaultSyncer.OnEvent to keep all room state cached.
func (s *InMemoryStore) UpdateState(_ EventSource, evt *event.Event) {
if !evt.Type.IsState() {
return
}
room := s.LoadRoom(evt.RoomID)
if room == nil {
room = NewRoom(evt.RoomID)
s.SaveRoom(room)
}
room.UpdateState(evt)
}
// NewInMemoryStore constructs a new InMemoryStore.
func NewInMemoryStore() *InMemoryStore {
return &InMemoryStore{
Filters: make(map[id.UserID]string),
NextBatch: make(map[id.UserID]string),
Rooms: make(map[id.RoomID]*Room),
}
}
// AccountDataStore uses account data to store the next batch token, and
// reuses the InMemoryStore for all other operations.
type AccountDataStore struct {
*InMemoryStore
eventType string
client *Client
}
type accountData struct {
NextBatch string `json:"next_batch"`
}
// SaveNextBatch to account data.
func (s *AccountDataStore) SaveNextBatch(userID id.UserID, nextBatchToken string) {
if userID.String() != s.client.UserID.String() {
panic("AccountDataStore must only be used with bots")
}
data := accountData{
NextBatch: nextBatchToken,
}
err := s.client.SetAccountData(s.eventType, data)
if err != nil {
if s.client.Logger != nil {
s.client.Logger.Debugfln("failed to save next batch token to account data: %s", err.Error())
}
}
}
// LoadNextBatch from account data.
func (s *AccountDataStore) LoadNextBatch(userID id.UserID) string {
if userID.String() != s.client.UserID.String() {
panic("AccountDataStore must only be used with bots")
}
data := &accountData{}
err := s.client.GetAccountData(s.eventType, data)
if err != nil {
if s.client.Logger != nil {
s.client.Logger.Debugfln("failed to load next batch token to account data: %s", err.Error())
}
return ""
}
return data.NextBatch
}
// NewAccountDataStore returns a new AccountDataStore, which stores
// the next_batch token as a custom event in account data in the
// homeserver.
//
// AccountDataStore is only appropriate for bots, not appservices.
//
// eventType should be a reversed DNS name like tld.domain.sub.internal and
// must be unique for a client. The data stored in it is considered internal
// and must not be modified through outside means. You should also add a filter
// for account data changes of this event type, to avoid ending up in a sync
// loop:
//
// mautrix.Filter{
// AccountData: mautrix.FilterPart{
// Limit: 20,
// NotTypes: []event.Type{
// event.NewEventType(eventType),
// },
// },
// }
// mautrix.Client.CreateFilter(...)
func NewAccountDataStore(eventType string, client *Client) *AccountDataStore {
return &AccountDataStore{
InMemoryStore: NewInMemoryStore(),
eventType: eventType,
client: client,
}
}

118
vendor/maunium.net/go/mautrix/sync.go generated vendored
View File

@@ -7,6 +7,7 @@
package mautrix
import (
"errors"
"fmt"
"runtime/debug"
"time"
@@ -28,44 +29,52 @@ const (
EventSourceState
EventSourceEphemeral
EventSourceToDevice
EventSourceDecrypted
)
const primaryTypes = EventSourcePresence | EventSourceAccountData | EventSourceToDevice | EventSourceTimeline | EventSourceState
const roomSections = EventSourceJoin | EventSourceInvite | EventSourceLeave
const roomableTypes = EventSourceAccountData | EventSourceTimeline | EventSourceState
const encryptableTypes = roomableTypes | EventSourceToDevice
func (es EventSource) String() string {
switch {
case es == EventSourcePresence:
return "presence"
case es == EventSourceAccountData:
return "user account data"
case es == EventSourceToDevice:
return "to-device"
case es&EventSourceJoin != 0:
es -= EventSourceJoin
switch es {
case EventSourceState:
return "joined state"
case EventSourceTimeline:
return "joined timeline"
case EventSourceEphemeral:
return "room ephemeral (joined)"
case EventSourceAccountData:
return "room account data (joined)"
}
case es&EventSourceInvite != 0:
es -= EventSourceInvite
switch es {
case EventSourceState:
return "invited state"
}
case es&EventSourceLeave != 0:
es -= EventSourceLeave
switch es {
case EventSourceState:
return "left state"
case EventSourceTimeline:
return "left timeline"
}
var typeName string
switch es & primaryTypes {
case EventSourcePresence:
typeName = "presence"
case EventSourceAccountData:
typeName = "account data"
case EventSourceToDevice:
typeName = "to-device"
case EventSourceTimeline:
typeName = "timeline"
case EventSourceState:
typeName = "state"
default:
return fmt.Sprintf("unknown (%d)", es)
}
return fmt.Sprintf("unknown (%d)", es)
if es&roomableTypes != 0 {
switch es & roomSections {
case EventSourceJoin:
typeName = "joined room " + typeName
case EventSourceInvite:
typeName = "invited room " + typeName
case EventSourceLeave:
typeName = "left room " + typeName
default:
return fmt.Sprintf("unknown (%d)", es)
}
es &^= roomableTypes
}
if es&encryptableTypes != 0 && es&EventSourceDecrypted != 0 {
typeName += " (decrypted)"
es &^= EventSourceDecrypted
}
es &^= primaryTypes
if es != 0 {
return fmt.Sprintf("unknown (%d)", es)
}
return typeName
}
// EventHandler handles a single event from a sync response.
@@ -76,9 +85,8 @@ type SyncHandler func(resp *RespSync, since string) bool
// Syncer is an interface that must be satisfied in order to do /sync requests on a client.
type Syncer interface {
// Process the /sync response. The since parameter is the since= value that was used to produce the response.
// This is useful for detecting the very first sync (since=""). If an error is return, Syncing will be stopped
// permanently.
// ProcessResponse processes the /sync response. The since parameter is the since= value that was used to produce the response.
// This is useful for detecting the very first sync (since=""). If an error is return, Syncing will be stopped permanently.
ProcessResponse(resp *RespSync, since string) error
// OnFailedSync returns either the time to wait before retrying or an error to stop syncing permanently.
OnFailedSync(res *RespSync, err error) (time.Duration, error)
@@ -92,6 +100,10 @@ type ExtensibleSyncer interface {
OnEventType(eventType event.Type, callback EventHandler)
}
type DispatchableSyncer interface {
Dispatch(source EventSource, evt *event.Event)
}
// DefaultSyncer is the default syncing implementation. You can either write your own syncer, or selectively
// replace parts of this default syncer (e.g. the ProcessResponse method). The default syncer uses the observer
// pattern to notify callers about incoming events. See DefaultSyncer.OnEventType for more information.
@@ -107,6 +119,8 @@ type DefaultSyncer struct {
// ParseErrorHandler is called when event.Content.ParseRaw returns an error.
// If it returns false, the event will not be forwarded to listeners.
ParseErrorHandler func(evt *event.Event, err error) bool
// FilterJSON is used when the client starts syncing and doesn't get an existing filter ID from SyncStore's LoadFilterID.
FilterJSON *Filter
}
var _ Syncer = (*DefaultSyncer)(nil)
@@ -191,10 +205,10 @@ func (s *DefaultSyncer) processSyncEvent(roomID id.RoomID, evt *event.Event, sou
}
}
s.notifyListeners(source, evt)
s.Dispatch(source, evt)
}
func (s *DefaultSyncer) notifyListeners(source EventSource, evt *event.Event) {
func (s *DefaultSyncer) Dispatch(source EventSource, evt *event.Event) {
for _, fn := range s.globalListeners {
fn(source, evt)
}
@@ -226,22 +240,34 @@ func (s *DefaultSyncer) OnEvent(callback EventHandler) {
// OnFailedSync always returns a 10 second wait period between failed /syncs, never a fatal error.
func (s *DefaultSyncer) OnFailedSync(res *RespSync, err error) (time.Duration, error) {
if errors.Is(err, MUnknownToken) {
return 0, err
}
return 10 * time.Second, nil
}
var defaultFilter = Filter{
Room: RoomFilter{
Timeline: FilterPart{
Limit: 50,
},
},
}
// GetFilterJSON returns a filter with a timeline limit of 50.
func (s *DefaultSyncer) GetFilterJSON(userID id.UserID) *Filter {
return &Filter{
Room: RoomFilter{
Timeline: FilterPart{
Limit: 50,
},
},
if s.FilterJSON == nil {
defaultFilterCopy := defaultFilter
s.FilterJSON = &defaultFilterCopy
}
return s.FilterJSON
}
// OldEventIgnorer is an utility struct for bots to ignore events from before the bot joined the room.
// Create a struct and call Register with your DefaultSyncer to register the sync handler.
//
// Create a struct and call Register with your DefaultSyncer to register the sync handler, e.g.:
//
// (&OldEventIgnorer{UserID: cli.UserID}).Register(cli.Syncer.(mautrix.ExtensibleSyncer))
type OldEventIgnorer struct {
UserID id.UserID
}

166
vendor/maunium.net/go/mautrix/syncstore.go generated vendored Normal file
View File

@@ -0,0 +1,166 @@
package mautrix
import (
"errors"
"maunium.net/go/mautrix/id"
)
// SyncStore is an interface which must be satisfied to store client data.
//
// You can either write a struct which persists this data to disk, or you can use the
// provided "MemorySyncStore" which just keeps data around in-memory which is lost on
// restarts.
type SyncStore interface {
SaveFilterID(userID id.UserID, filterID string)
LoadFilterID(userID id.UserID) string
SaveNextBatch(userID id.UserID, nextBatchToken string)
LoadNextBatch(userID id.UserID) string
}
// Deprecated: renamed to SyncStore
type Storer = SyncStore
// MemorySyncStore implements the Storer interface.
//
// Everything is persisted in-memory as maps. It is not safe to load/save filter IDs
// or next batch tokens on any goroutine other than the syncing goroutine: the one
// which called Client.Sync().
type MemorySyncStore struct {
Filters map[id.UserID]string
NextBatch map[id.UserID]string
}
// SaveFilterID to memory.
func (s *MemorySyncStore) SaveFilterID(userID id.UserID, filterID string) {
s.Filters[userID] = filterID
}
// LoadFilterID from memory.
func (s *MemorySyncStore) LoadFilterID(userID id.UserID) string {
return s.Filters[userID]
}
// SaveNextBatch to memory.
func (s *MemorySyncStore) SaveNextBatch(userID id.UserID, nextBatchToken string) {
s.NextBatch[userID] = nextBatchToken
}
// LoadNextBatch from memory.
func (s *MemorySyncStore) LoadNextBatch(userID id.UserID) string {
return s.NextBatch[userID]
}
// NewMemorySyncStore constructs a new MemorySyncStore.
func NewMemorySyncStore() *MemorySyncStore {
return &MemorySyncStore{
Filters: make(map[id.UserID]string),
NextBatch: make(map[id.UserID]string),
}
}
// AccountDataStore uses account data to store the next batch token, and stores the filter ID in memory
// (as filters can be safely recreated every startup).
type AccountDataStore struct {
FilterID string
EventType string
client *Client
nextBatch string
}
type accountData struct {
NextBatch string `json:"next_batch"`
}
func (s *AccountDataStore) SaveFilterID(userID id.UserID, filterID string) {
if userID.String() != s.client.UserID.String() {
panic("AccountDataStore must only be used with a single account")
}
s.FilterID = filterID
}
func (s *AccountDataStore) LoadFilterID(userID id.UserID) string {
if userID.String() != s.client.UserID.String() {
panic("AccountDataStore must only be used with a single account")
}
return s.FilterID
}
func (s *AccountDataStore) SaveNextBatch(userID id.UserID, nextBatchToken string) {
if userID.String() != s.client.UserID.String() {
panic("AccountDataStore must only be used with a single account")
} else if nextBatchToken == s.nextBatch {
return
}
data := accountData{
NextBatch: nextBatchToken,
}
err := s.client.SetAccountData(s.EventType, data)
if err != nil {
s.client.Log.Warn().Err(err).Msg("Failed to save next batch token to account data")
} else {
s.client.Log.Debug().
Str("old_token", s.nextBatch).
Str("new_token", nextBatchToken).
Msg("Saved next batch token")
s.nextBatch = nextBatchToken
}
}
func (s *AccountDataStore) LoadNextBatch(userID id.UserID) string {
if userID.String() != s.client.UserID.String() {
panic("AccountDataStore must only be used with a single account")
}
data := &accountData{}
err := s.client.GetAccountData(s.EventType, data)
if err != nil {
if errors.Is(err, MNotFound) {
s.client.Log.Debug().Msg("No next batch token found in account data")
} else {
s.client.Log.Warn().Err(err).Msg("Failed to load next batch token from account data")
}
return ""
}
s.nextBatch = data.NextBatch
s.client.Log.Debug().Str("next_batch", data.NextBatch).Msg("Loaded next batch token from account data")
return s.nextBatch
}
// NewAccountDataStore returns a new AccountDataStore, which stores
// the next_batch token as a custom event in account data in the
// homeserver.
//
// AccountDataStore is only appropriate for bots, not appservices.
//
// The event type should be a reversed DNS name like tld.domain.sub.internal and
// must be unique for a client. The data stored in it is considered internal
// and must not be modified through outside means. You should also add a filter
// for account data changes of this event type, to avoid ending up in a sync
// loop:
//
// filter := mautrix.Filter{
// AccountData: mautrix.FilterPart{
// Limit: 20,
// NotTypes: []event.Type{
// event.NewEventType(eventType),
// },
// },
// }
// // If you use a custom Syncer, set the filter there, not like this
// client.Syncer.(*mautrix.DefaultSyncer).FilterJSON = &filter
// client.Store = mautrix.NewAccountDataStore("com.example.mybot.store", client)
// go func() {
// err := client.Sync()
// // don't forget to check err
// }()
func NewAccountDataStore(eventType string, client *Client) *AccountDataStore {
return &AccountDataStore{
EventType: eventType,
client: client,
}
}

View File

@@ -13,7 +13,7 @@ import (
"strings"
)
func parseAndNormalizeBaseURL(homeserverURL string) (*url.URL, error) {
func ParseAndNormalizeBaseURL(homeserverURL string) (*url.URL, error) {
hsURL, err := url.Parse(homeserverURL)
if err != nil {
return nil, err
@@ -43,7 +43,7 @@ func BuildURL(baseURL *url.URL, path ...interface{}) *url.URL {
parts[i+1] = casted
case int:
parts[i+1] = strconv.Itoa(casted)
case Stringifiable:
case fmt.Stringer:
parts[i+1] = casted.String()
default:
parts[i+1] = fmt.Sprint(casted)
@@ -93,8 +93,8 @@ func (mup MediaURLPath) FullPath() []interface{} {
func (cli *Client) BuildURLWithQuery(urlPath PrefixableURLPath, urlQuery map[string]string) string {
hsURL := *BuildURL(cli.HomeserverURL, urlPath.FullPath()...)
query := hsURL.Query()
if cli.AppServiceUserID != "" {
query.Set("user_id", string(cli.AppServiceUserID))
if cli.SetAppServiceUserID {
query.Set("user_id", string(cli.UserID))
}
if urlQuery != nil {
for k, v := range urlQuery {

28
vendor/maunium.net/go/mautrix/util/callermarshal.go generated vendored Normal file
View File

@@ -0,0 +1,28 @@
// Copyright (c) 2023 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
package util
import (
"fmt"
"runtime"
"strings"
)
// CallerWithFunctionName is an implementation for zerolog.CallerMarshalFunc that includes the caller function name
// in addition to the file and line number.
//
// Use as
//
// zerolog.CallerMarshalFunc = util.CallerWithFunctionName
func CallerWithFunctionName(pc uintptr, file string, line int) string {
files := strings.Split(file, "/")
file = files[len(files)-1]
name := runtime.FuncForPC(pc).Name()
fns := strings.Split(name, ".")
name = fns[len(fns)-1]
return fmt.Sprintf("%s:%d:%s()", file, line, name)
}

View File

@@ -1,288 +0,0 @@
// Copyright (c) 2022 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
package configupgrade
import (
"fmt"
"os"
"strings"
"gopkg.in/yaml.v3"
)
type YAMLMap map[string]YAMLNode
type YAMLList []YAMLNode
type YAMLNode struct {
*yaml.Node
Map YAMLMap
List YAMLList
Key *yaml.Node
}
type YAMLType uint32
const (
Null YAMLType = 1 << iota
Bool
Str
Int
Float
Timestamp
List
Map
Binary
)
func (t YAMLType) String() string {
switch t {
case Null:
return NullTag
case Bool:
return BoolTag
case Str:
return StrTag
case Int:
return IntTag
case Float:
return FloatTag
case Timestamp:
return TimestampTag
case List:
return SeqTag
case Map:
return MapTag
case Binary:
return BinaryTag
default:
panic(fmt.Errorf("can't convert type %d to string", t))
}
}
func tagToType(tag string) YAMLType {
switch tag {
case NullTag:
return Null
case BoolTag:
return Bool
case StrTag:
return Str
case IntTag:
return Int
case FloatTag:
return Float
case TimestampTag:
return Timestamp
case SeqTag:
return List
case MapTag:
return Map
case BinaryTag:
return Binary
default:
return 0
}
}
const (
NullTag = "!!null"
BoolTag = "!!bool"
StrTag = "!!str"
IntTag = "!!int"
FloatTag = "!!float"
TimestampTag = "!!timestamp"
SeqTag = "!!seq"
MapTag = "!!map"
BinaryTag = "!!binary"
)
func fromNode(node, key *yaml.Node) YAMLNode {
switch node.Kind {
case yaml.DocumentNode:
return fromNode(node.Content[0], nil)
case yaml.AliasNode:
return fromNode(node.Alias, nil)
case yaml.MappingNode:
return YAMLNode{
Node: node,
Map: parseYAMLMap(node),
Key: key,
}
case yaml.SequenceNode:
return YAMLNode{
Node: node,
List: parseYAMLList(node),
}
default:
return YAMLNode{Node: node, Key: key}
}
}
func (yn *YAMLNode) toNode() *yaml.Node {
yn.UpdateContent()
return yn.Node
}
func (yn *YAMLNode) UpdateContent() {
switch {
case yn.Map != nil && yn.Node.Kind == yaml.MappingNode:
yn.Content = yn.Map.toNodes()
case yn.List != nil && yn.Node.Kind == yaml.SequenceNode:
yn.Content = yn.List.toNodes()
}
}
func parseYAMLList(node *yaml.Node) YAMLList {
data := make(YAMLList, len(node.Content))
for i, item := range node.Content {
data[i] = fromNode(item, nil)
}
return data
}
func (yl YAMLList) toNodes() []*yaml.Node {
nodes := make([]*yaml.Node, len(yl))
for i, item := range yl {
nodes[i] = item.toNode()
}
return nodes
}
func parseYAMLMap(node *yaml.Node) YAMLMap {
if len(node.Content)%2 != 0 {
panic(fmt.Errorf("uneven number of items in YAML map (%d)", len(node.Content)))
}
data := make(YAMLMap, len(node.Content)/2)
for i := 0; i < len(node.Content); i += 2 {
key := node.Content[i]
value := node.Content[i+1]
if key.Kind == yaml.ScalarNode {
data[key.Value] = fromNode(value, key)
}
}
return data
}
func (ym YAMLMap) toNodes() []*yaml.Node {
nodes := make([]*yaml.Node, len(ym)*2)
i := 0
for key, value := range ym {
nodes[i] = makeStringNode(key)
nodes[i+1] = value.toNode()
i += 2
}
return nodes
}
func makeStringNode(val string) *yaml.Node {
var node yaml.Node
node.SetString(val)
return &node
}
func StringNode(val string) YAMLNode {
return YAMLNode{Node: makeStringNode(val)}
}
type Helper struct {
Base YAMLNode
Config YAMLNode
}
func NewHelper(base, cfg *yaml.Node) *Helper {
return &Helper{
Base: fromNode(base, nil),
Config: fromNode(cfg, nil),
}
}
func (helper *Helper) AddSpaceBeforeComment(path ...string) {
node := helper.GetBaseNode(path...)
if node == nil || node.Key == nil {
panic(fmt.Errorf("didn't find key at %+v", path))
}
node.Key.HeadComment = "\n" + node.Key.HeadComment
}
func (helper *Helper) Copy(allowedTypes YAMLType, path ...string) {
base, cfg := helper.Base, helper.Config
var ok bool
for _, item := range path {
cfg, ok = cfg.Map[item]
if !ok {
return
}
base, ok = base.Map[item]
if !ok {
_, _ = fmt.Fprintf(os.Stderr, "Ignoring config field %s which is missing in base config\n", strings.Join(path, "->"))
return
}
}
if allowedTypes&tagToType(cfg.Tag) == 0 {
_, _ = fmt.Fprintf(os.Stderr, "Ignoring incorrect config field type %s at %s\n", cfg.Tag, strings.Join(path, "->"))
return
}
base.Tag = cfg.Tag
base.Style = cfg.Style
switch base.Kind {
case yaml.ScalarNode:
base.Value = cfg.Value
case yaml.SequenceNode, yaml.MappingNode:
base.Content = cfg.Content
}
}
func getNode(cfg YAMLNode, path []string) *YAMLNode {
var ok bool
for _, item := range path {
cfg, ok = cfg.Map[item]
if !ok {
return nil
}
}
return &cfg
}
func (helper *Helper) GetNode(path ...string) *YAMLNode {
return getNode(helper.Config, path)
}
func (helper *Helper) GetBaseNode(path ...string) *YAMLNode {
return getNode(helper.Base, path)
}
func (helper *Helper) Get(tag YAMLType, path ...string) (string, bool) {
node := helper.GetNode(path...)
if node == nil || node.Kind != yaml.ScalarNode || tag&tagToType(node.Tag) == 0 {
return "", false
}
return node.Value, true
}
func (helper *Helper) GetBase(path ...string) string {
return helper.GetBaseNode(path...).Value
}
func (helper *Helper) Set(tag YAMLType, value string, path ...string) {
base := helper.Base
for _, item := range path {
base = base.Map[item]
}
base.Tag = tag.String()
base.Value = value
}
func (helper *Helper) SetMap(value YAMLMap, path ...string) {
base := helper.Base
for _, item := range path {
base = base.Map[item]
}
if base.Tag != MapTag || base.Kind != yaml.MappingNode {
panic(fmt.Errorf("invalid target for SetMap(%+v): tag:%s, kind:%d", path, base.Tag, base.Kind))
}
base.Content = value.toNodes()
}

View File

@@ -1,108 +0,0 @@
// Copyright (c) 2022 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
package configupgrade
import (
"fmt"
"os"
"path"
"gopkg.in/yaml.v3"
)
type Upgrader interface {
DoUpgrade(helper *Helper)
}
type SpacedUpgrader interface {
Upgrader
SpacedBlocks() [][]string
}
type BaseUpgrader interface {
Upgrader
GetBase() string
}
type StructUpgrader struct {
SimpleUpgrader
Blocks [][]string
Base string
}
func (su *StructUpgrader) SpacedBlocks() [][]string {
return su.Blocks
}
func (su *StructUpgrader) GetBase() string {
return su.Base
}
type SimpleUpgrader func(helper *Helper)
func (su SimpleUpgrader) DoUpgrade(helper *Helper) {
su(helper)
}
func (helper *Helper) apply(upgrader Upgrader) {
upgrader.DoUpgrade(helper)
helper.addSpaces(upgrader)
}
func (helper *Helper) addSpaces(upgrader Upgrader) {
spaced, ok := upgrader.(SpacedUpgrader)
if ok {
for _, spacePath := range spaced.SpacedBlocks() {
helper.AddSpaceBeforeComment(spacePath...)
}
}
}
func Do(configPath string, save bool, upgrader BaseUpgrader, additional ...Upgrader) ([]byte, bool, error) {
sourceData, err := os.ReadFile(configPath)
if err != nil {
return nil, false, fmt.Errorf("failed to read config: %w", err)
}
var base, cfg yaml.Node
err = yaml.Unmarshal([]byte(upgrader.GetBase()), &base)
if err != nil {
return sourceData, false, fmt.Errorf("failed to unmarshal example config: %w", err)
}
err = yaml.Unmarshal(sourceData, &cfg)
if err != nil {
return sourceData, false, fmt.Errorf("failed to unmarshal config: %w", err)
}
helper := NewHelper(&base, &cfg)
helper.apply(upgrader)
for _, add := range additional {
helper.apply(add)
}
output, err := yaml.Marshal(&base)
if err != nil {
return sourceData, false, fmt.Errorf("failed to marshal updated config: %w", err)
}
if save {
var tempFile *os.File
tempFile, err = os.CreateTemp(path.Dir(configPath), "mautrix-config-*.yaml")
if err != nil {
return output, true, fmt.Errorf("failed to create temp file for writing config: %w", err)
}
_, err = tempFile.Write(output)
if err != nil {
_ = os.Remove(tempFile.Name())
return output, true, fmt.Errorf("failed to write updated config to temp file: %w", err)
}
err = os.Rename(tempFile.Name(), configPath)
if err != nil {
_ = os.Remove(tempFile.Name())
return output, true, fmt.Errorf("failed to override current config with temp file: %w", err)
}
}
return output, true, nil
}

View File

@@ -23,7 +23,7 @@ func (le *LoggingExecable) ExecContext(ctx context.Context, query string, args .
start := time.Now()
query = le.db.mutateQuery(query)
res, err := le.UnderlyingExecable.ExecContext(ctx, query, args...)
le.db.Log.QueryTiming(ctx, "Exec", query, args, -1, time.Since(start))
le.db.Log.QueryTiming(ctx, "Exec", query, args, -1, time.Since(start), err)
return res, err
}
@@ -31,7 +31,7 @@ func (le *LoggingExecable) QueryContext(ctx context.Context, query string, args
start := time.Now()
query = le.db.mutateQuery(query)
rows, err := le.UnderlyingExecable.QueryContext(ctx, query, args...)
le.db.Log.QueryTiming(ctx, "Query", query, args, -1, time.Since(start))
le.db.Log.QueryTiming(ctx, "Query", query, args, -1, time.Since(start), err)
return &LoggingRows{
ctx: ctx,
db: le.db,
@@ -46,7 +46,7 @@ func (le *LoggingExecable) QueryRowContext(ctx context.Context, query string, ar
start := time.Now()
query = le.db.mutateQuery(query)
row := le.UnderlyingExecable.QueryRowContext(ctx, query, args...)
le.db.Log.QueryTiming(ctx, "QueryRow", query, args, -1, time.Since(start))
le.db.Log.QueryTiming(ctx, "QueryRow", query, args, -1, time.Since(start), nil)
return row
}
@@ -73,7 +73,7 @@ type loggingDB struct {
func (ld *loggingDB) BeginTx(ctx context.Context, opts *sql.TxOptions) (*LoggingTxn, error) {
start := time.Now()
tx, err := ld.db.RawDB.BeginTx(ctx, opts)
ld.db.Log.QueryTiming(ctx, "Begin", "", nil, -1, time.Since(start))
ld.db.Log.QueryTiming(ctx, "Begin", "", nil, -1, time.Since(start), err)
if err != nil {
return nil, err
}
@@ -81,6 +81,7 @@ func (ld *loggingDB) BeginTx(ctx context.Context, opts *sql.TxOptions) (*Logging
LoggingExecable: LoggingExecable{UnderlyingExecable: tx, db: ld.db},
UnderlyingTx: tx,
ctx: ctx,
StartTime: start,
}, nil
}
@@ -92,22 +93,35 @@ type LoggingTxn struct {
LoggingExecable
UnderlyingTx *sql.Tx
ctx context.Context
StartTime time.Time
EndTime time.Time
noTotalLog bool
}
func (lt *LoggingTxn) Commit() error {
start := time.Now()
err := lt.UnderlyingTx.Commit()
lt.db.Log.QueryTiming(lt.ctx, "Commit", "", nil, -1, time.Since(start))
lt.endLog()
lt.db.Log.QueryTiming(lt.ctx, "Commit", "", nil, -1, time.Since(start), err)
return err
}
func (lt *LoggingTxn) Rollback() error {
start := time.Now()
err := lt.UnderlyingTx.Rollback()
lt.db.Log.QueryTiming(lt.ctx, "Rollback", "", nil, -1, time.Since(start))
lt.endLog()
lt.db.Log.QueryTiming(lt.ctx, "Rollback", "", nil, -1, time.Since(start), err)
return err
}
func (lt *LoggingTxn) endLog() {
lt.EndTime = time.Now()
if !lt.noTotalLog {
lt.db.Log.QueryTiming(lt.ctx, "<Transaction>", "", nil, -1, lt.EndTime.Sub(lt.StartTime), nil)
}
}
type LoggingRows struct {
ctx context.Context
db *Database
@@ -120,7 +134,7 @@ type LoggingRows struct {
func (lrs *LoggingRows) stopTiming() {
if !lrs.start.IsZero() {
lrs.db.Log.QueryTiming(lrs.ctx, "EndRows", lrs.query, lrs.args, lrs.nrows, time.Since(lrs.start))
lrs.db.Log.QueryTiming(lrs.ctx, "EndRows", lrs.query, lrs.args, lrs.nrows, time.Since(lrs.start), lrs.rs.Err())
lrs.start = time.Time{}
}
}

View File

@@ -13,8 +13,6 @@ import (
"regexp"
"strings"
"time"
"maunium.net/go/mautrix/bridge/bridgeconfig"
)
type Dialect int
@@ -38,7 +36,7 @@ func (dialect Dialect) String() string {
func ParseDialect(engine string) (Dialect, error) {
switch strings.ToLower(engine) {
case "postgres", "postgresql":
case "postgres", "postgresql", "pgx":
return Postgres, nil
case "sqlite3", "sqlite", "litestream", "sqlite3-fk-wal":
return SQLite, nil
@@ -176,7 +174,18 @@ func NewWithDialect(uri, rawDialect string) (*Database, error) {
return NewWithDB(db, rawDialect)
}
func (db *Database) Configure(cfg bridgeconfig.DatabaseConfig) error {
type Config struct {
Type string `yaml:"type"`
URI string `yaml:"uri"`
MaxOpenConns int `yaml:"max_open_conns"`
MaxIdleConns int `yaml:"max_idle_conns"`
ConnMaxIdleTime string `yaml:"conn_max_idle_time"`
ConnMaxLifetime string `yaml:"conn_max_lifetime"`
}
func (db *Database) Configure(cfg Config) error {
db.RawDB.SetMaxOpenConns(cfg.MaxOpenConns)
db.RawDB.SetMaxIdleConns(cfg.MaxIdleConns)
if len(cfg.ConnMaxIdleTime) > 0 {
@@ -196,7 +205,7 @@ func (db *Database) Configure(cfg bridgeconfig.DatabaseConfig) error {
return nil
}
func NewFromConfig(owner string, cfg bridgeconfig.DatabaseConfig, logger DatabaseLogger) (*Database, error) {
func NewFromConfig(owner string, cfg Config, logger DatabaseLogger) (*Database, error) {
dialect, err := ParseDialect(cfg.Type)
if err != nil {
return nil, err

View File

@@ -11,9 +11,9 @@ import (
)
type DatabaseLogger interface {
QueryTiming(ctx context.Context, method, query string, args []interface{}, nrows int, duration time.Duration)
WarnUnsupportedVersion(current, latest int)
PrepareUpgrade(current, latest int)
QueryTiming(ctx context.Context, method, query string, args []interface{}, nrows int, duration time.Duration, err error)
WarnUnsupportedVersion(current, compat, latest int)
PrepareUpgrade(current, compat, latest int)
DoUpgrade(from, to int, message string, txn bool)
// Deprecated: legacy warning method, return errors instead
Warn(msg string, args ...interface{})
@@ -23,35 +23,36 @@ type noopLogger struct{}
var NoopLogger DatabaseLogger = &noopLogger{}
func (n noopLogger) WarnUnsupportedVersion(_, _ int) {}
func (n noopLogger) PrepareUpgrade(_, _ int) {}
func (n noopLogger) WarnUnsupportedVersion(_, _, _ int) {}
func (n noopLogger) PrepareUpgrade(_, _, _ int) {}
func (n noopLogger) DoUpgrade(_, _ int, _ string, _ bool) {}
func (n noopLogger) Warn(msg string, args ...interface{}) {}
func (n noopLogger) QueryTiming(_ context.Context, _, _ string, _ []interface{}, _ int, _ time.Duration) {
func (n noopLogger) QueryTiming(_ context.Context, _, _ string, _ []interface{}, _ int, _ time.Duration, _ error) {
}
type mauLogger struct {
l maulogger.Logger
}
// Deprecated: Use zerolog instead
func MauLogger(log maulogger.Logger) DatabaseLogger {
return &mauLogger{l: log}
}
func (m mauLogger) WarnUnsupportedVersion(current, latest int) {
m.l.Warnfln("Unsupported database schema version: currently on v%d, latest known: v%d - continuing anyway", current, latest)
func (m mauLogger) WarnUnsupportedVersion(current, compat, latest int) {
m.l.Warnfln("Unsupported database schema version: currently on v%d (compatible down to v%d), latest known: v%d - continuing anyway", current, compat, latest)
}
func (m mauLogger) PrepareUpgrade(current, latest int) {
m.l.Infofln("Database currently on v%d, latest: v%d", current, latest)
func (m mauLogger) PrepareUpgrade(current, compat, latest int) {
m.l.Infofln("Database currently on v%d (compat: v%d), latest known: v%d", current, compat, latest)
}
func (m mauLogger) DoUpgrade(from, to int, message string, _ bool) {
m.l.Infofln("Upgrading database from v%d to v%d: %s", from, to, message)
}
func (m mauLogger) QueryTiming(_ context.Context, method, query string, _ []interface{}, _ int, duration time.Duration) {
func (m mauLogger) QueryTiming(_ context.Context, method, query string, _ []interface{}, _ int, duration time.Duration, _ error) {
if duration > 1*time.Second {
m.l.Warnfln("%s(%s) took %.3f seconds", method, query, duration.Seconds())
}
@@ -63,28 +64,40 @@ func (m mauLogger) Warn(msg string, args ...interface{}) {
type zeroLogger struct {
l *zerolog.Logger
ZeroLogSettings
}
func ZeroLogger(log zerolog.Logger) DatabaseLogger {
return ZeroLoggerPtr(&log)
type ZeroLogSettings struct {
CallerSkipFrame int
Caller bool
}
func ZeroLoggerPtr(log *zerolog.Logger) DatabaseLogger {
return &zeroLogger{l: log}
func ZeroLogger(log zerolog.Logger, cfg ...ZeroLogSettings) DatabaseLogger {
return ZeroLoggerPtr(&log, cfg...)
}
func (z zeroLogger) WarnUnsupportedVersion(current, latest int) {
func ZeroLoggerPtr(log *zerolog.Logger, cfg ...ZeroLogSettings) DatabaseLogger {
wrapped := &zeroLogger{l: log}
if len(cfg) > 0 {
wrapped.ZeroLogSettings = cfg[0]
}
return wrapped
}
func (z zeroLogger) WarnUnsupportedVersion(current, compat, latest int) {
z.l.Warn().
Int("current_db_version", current).
Int("current_version", current).
Int("oldest_compatible_version", compat).
Int("latest_known_version", latest).
Msg("Unsupported database schema version, continuing anyway")
}
func (z zeroLogger) PrepareUpgrade(current, latest int) {
func (z zeroLogger) PrepareUpgrade(current, compat, latest int) {
evt := z.l.Info().
Int("current_db_version", current).
Int("current_version", current).
Int("oldest_compatible_version", compat).
Int("latest_known_version", latest)
if current == latest {
if current >= latest {
evt.Msg("Database is up to date")
} else {
evt.Msg("Preparing to update database schema")
@@ -102,9 +115,9 @@ func (z zeroLogger) DoUpgrade(from, to int, message string, txn bool) {
var whitespaceRegex = regexp.MustCompile(`\s+`)
func (z zeroLogger) QueryTiming(ctx context.Context, method, query string, args []interface{}, nrows int, duration time.Duration) {
func (z zeroLogger) QueryTiming(ctx context.Context, method, query string, args []interface{}, nrows int, duration time.Duration, err error) {
log := zerolog.Ctx(ctx)
if log.GetLevel() == zerolog.Disabled {
if log.GetLevel() == zerolog.Disabled || log == zerolog.DefaultContextLogger {
log = z.l
}
if log.GetLevel() != zerolog.TraceLevel && duration < 1*time.Second {
@@ -116,17 +129,21 @@ func (z zeroLogger) QueryTiming(ctx context.Context, method, query string, args
}
query = strings.TrimSpace(whitespaceRegex.ReplaceAllLiteralString(query, " "))
log.Trace().
Err(err).
Int64("duration_µs", duration.Microseconds()).
Str("method", method).
Str("query", query).
Interface("query_args", args).
Msg("Query")
if duration >= 1*time.Second {
log.Warn().
evt := log.Warn().
Float64("duration_seconds", duration.Seconds()).
Str("method", method).
Str("query", query).
Msg("Query took long")
Str("query", query)
if z.Caller {
evt = evt.Caller(z.CallerSkipFrame)
}
evt.Msg("Query took long")
}
}

View File

@@ -0,0 +1,3 @@
-- v5 (compatible with v3+): Sample backwards-compatible upgrade
INSERT INTO foo VALUES ('meow 2', '{}');

View File

@@ -0,0 +1 @@
INSERT INTO foo VALUES ('meow 2', '{}');

View File

@@ -0,0 +1 @@
INSERT INTO foo VALUES ('meow 2', '{}');

View File

@@ -0,0 +1,82 @@
// Copyright (c) 2023 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
package dbutil
import (
"context"
"database/sql"
"errors"
"fmt"
"time"
"github.com/rs/zerolog"
"maunium.net/go/mautrix/util"
)
var (
ErrTxn = errors.New("transaction")
ErrTxnBegin = fmt.Errorf("%w: begin", ErrTxn)
ErrTxnCommit = fmt.Errorf("%w: commit", ErrTxn)
)
type contextKey int
const (
ContextKeyDatabaseTransaction contextKey = iota
ContextKeyDoTxnCallerSkip
)
func (db *Database) DoTxn(ctx context.Context, opts *sql.TxOptions, fn func(ctx context.Context) error) error {
if ctx.Value(ContextKeyDatabaseTransaction) != nil {
zerolog.Ctx(ctx).Debug().Msg("Already in a transaction, not creating a new one")
return fn(ctx)
}
log := zerolog.Ctx(ctx).With().Str("db_txn_id", util.RandomString(12)).Logger()
start := time.Now()
defer func() {
dur := time.Since(start)
if dur > time.Second {
val := ctx.Value(ContextKeyDoTxnCallerSkip)
callerSkip := 2
if val != nil {
callerSkip += val.(int)
}
log.Warn().
Float64("duration_seconds", dur.Seconds()).
Caller(callerSkip).
Msg("Transaction took a long time")
}
}()
tx, err := db.BeginTx(ctx, opts)
if err != nil {
log.Trace().Err(err).Msg("Failed to begin transaction")
return util.NewDualError(ErrTxnBegin, err)
}
log.Trace().Msg("Transaction started")
tx.noTotalLog = true
ctx = log.WithContext(ctx)
ctx = context.WithValue(ctx, ContextKeyDatabaseTransaction, tx)
err = fn(ctx)
if err != nil {
log.Trace().Err(err).Msg("Database transaction failed, rolling back")
rollbackErr := tx.Rollback()
if rollbackErr != nil {
log.Warn().Err(rollbackErr).Msg("Rollback after transaction error failed")
} else {
log.Trace().Msg("Rollback successful")
}
return err
}
err = tx.Commit()
if err != nil {
log.Trace().Err(err).Msg("Commit failed")
return util.NewDualError(ErrTxnCommit, err)
}
log.Trace().Msg("Commit successful")
return nil
}

View File

@@ -1,4 +1,4 @@
// Copyright (c) 2022 Tulir Asokan
// Copyright (c) 2023 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
@@ -18,42 +18,97 @@ type upgrade struct {
message string
fn upgradeFunc
upgradesTo int
transaction bool
upgradesTo int
compatVersion int
transaction bool
}
var ErrUnsupportedDatabaseVersion = fmt.Errorf("unsupported database schema version")
var ErrForeignTables = fmt.Errorf("the database contains foreign tables")
var ErrNotOwned = fmt.Errorf("the database is owned by")
var ErrUnsupportedDatabaseVersion = errors.New("unsupported database schema version")
var ErrForeignTables = errors.New("the database contains foreign tables")
var ErrNotOwned = errors.New("the database is owned by")
var ErrUnsupportedDialect = errors.New("unsupported database dialect")
func (db *Database) getVersion() (int, error) {
_, err := db.Exec(fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s (version INTEGER)", db.VersionTable))
if err != nil {
return -1, err
func (db *Database) upgradeVersionTable() error {
if compatColumnExists, err := db.ColumnExists(nil, db.VersionTable, "compat"); err != nil {
return fmt.Errorf("failed to check if version table is up to date: %w", err)
} else if !compatColumnExists {
if tableExists, err := db.TableExists(nil, db.VersionTable); err != nil {
return fmt.Errorf("failed to check if version table exists: %w", err)
} else if !tableExists {
_, err = db.Exec(fmt.Sprintf("CREATE TABLE %s (version INTEGER, compat INTEGER)", db.VersionTable))
if err != nil {
return fmt.Errorf("failed to create version table: %w", err)
}
} else {
_, err = db.Exec(fmt.Sprintf("ALTER TABLE %s ADD COLUMN compat INTEGER", db.VersionTable))
if err != nil {
return fmt.Errorf("failed to add compat column to version table: %w", err)
}
}
}
version := 0
err = db.QueryRow(fmt.Sprintf("SELECT version FROM %s LIMIT 1", db.VersionTable)).Scan(&version)
if err != nil && !errors.Is(err, sql.ErrNoRows) {
return -1, err
}
return version, nil
return nil
}
const tableExistsPostgres = "SELECT EXISTS(SELECT 1 FROM information_schema.tables WHERE table_name=$1)"
const tableExistsSQLite = "SELECT EXISTS(SELECT 1 FROM sqlite_master WHERE type='table' AND tbl_name=$1)"
func (db *Database) getVersion() (version, compat int, err error) {
if err = db.upgradeVersionTable(); err != nil {
return
}
func (db *Database) tableExists(table string) (exists bool, err error) {
if db.Dialect == SQLite {
var compatNull sql.NullInt32
err = db.QueryRow(fmt.Sprintf("SELECT version, compat FROM %s LIMIT 1", db.VersionTable)).Scan(&version, &compatNull)
if errors.Is(err, sql.ErrNoRows) {
err = nil
}
if compatNull.Valid && compatNull.Int32 != 0 {
compat = int(compatNull.Int32)
} else {
compat = version
}
return
}
const (
tableExistsPostgres = "SELECT EXISTS(SELECT 1 FROM information_schema.tables WHERE table_name=$1)"
tableExistsSQLite = "SELECT EXISTS(SELECT 1 FROM sqlite_master WHERE type='table' AND tbl_name=?1)"
)
func (db *Database) TableExists(tx Execable, table string) (exists bool, err error) {
if tx == nil {
tx = db
}
switch db.Dialect {
case SQLite:
err = db.QueryRow(tableExistsSQLite, table).Scan(&exists)
} else if db.Dialect == Postgres {
case Postgres:
err = db.QueryRow(tableExistsPostgres, table).Scan(&exists)
default:
err = ErrUnsupportedDialect
}
return
}
const (
columnExistsPostgres = "SELECT EXISTS(SELECT 1 FROM information_schema.columns WHERE table_name=$1 AND column_name=$2)"
columnExistsSQLite = "SELECT EXISTS(SELECT 1 FROM pragma_table_info(?1) WHERE name=?2)"
)
func (db *Database) ColumnExists(tx Execable, table, column string) (exists bool, err error) {
if tx == nil {
tx = db
}
switch db.Dialect {
case SQLite:
err = db.QueryRow(columnExistsSQLite, table, column).Scan(&exists)
case Postgres:
err = db.QueryRow(columnExistsPostgres, table, column).Scan(&exists)
default:
err = ErrUnsupportedDialect
}
return
}
func (db *Database) tableExistsNoError(table string) bool {
exists, err := db.tableExists(table)
exists, err := db.TableExists(nil, table)
if err != nil {
panic(fmt.Errorf("failed to check if table exists: %w", err))
}
@@ -94,12 +149,12 @@ func (db *Database) checkDatabaseOwner() error {
return nil
}
func (db *Database) setVersion(tx Execable, version int) error {
func (db *Database) setVersion(tx Execable, version, compat int) error {
_, err := tx.Exec(fmt.Sprintf("DELETE FROM %s", db.VersionTable))
if err != nil {
return err
}
_, err = tx.Exec(fmt.Sprintf("INSERT INTO %s (version) VALUES ($1)", db.VersionTable), version)
_, err = tx.Exec(fmt.Sprintf("INSERT INTO %s (version, compat) VALUES ($1, $2)", db.VersionTable), version, compat)
return err
}
@@ -109,20 +164,20 @@ func (db *Database) Upgrade() error {
return err
}
version, err := db.getVersion()
version, compat, err := db.getVersion()
if err != nil {
return err
}
if version > len(db.UpgradeTable) {
if compat > len(db.UpgradeTable) {
if db.IgnoreUnsupportedDatabase {
db.Log.WarnUnsupportedVersion(version, len(db.UpgradeTable))
db.Log.WarnUnsupportedVersion(version, compat, len(db.UpgradeTable))
return nil
}
return fmt.Errorf("%w: currently on v%d, latest known: v%d", ErrUnsupportedDatabaseVersion, version, len(db.UpgradeTable))
return fmt.Errorf("%w: currently on v%d (compatible down to v%d), latest known: v%d", ErrUnsupportedDatabaseVersion, version, compat, len(db.UpgradeTable))
}
db.Log.PrepareUpgrade(version, len(db.UpgradeTable))
db.Log.PrepareUpgrade(version, compat, len(db.UpgradeTable))
logVersion := version
for version < len(db.UpgradeTable) {
upgradeItem := db.UpgradeTable[version]
@@ -148,7 +203,7 @@ func (db *Database) Upgrade() error {
}
version = upgradeItem.upgradesTo
logVersion = version
err = db.setVersion(upgradeConn, version)
err = db.setVersion(upgradeConn, version, upgradeItem.compatVersion)
if err != nil {
return err
}

View File

@@ -1,4 +1,4 @@
// Copyright (c) 2022 Tulir Asokan
// Copyright (c) 2023 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
@@ -29,14 +29,17 @@ func (ut *UpgradeTable) extend(toSize int) {
}
}
func (ut *UpgradeTable) Register(from, to int, message string, txn bool, fn upgradeFunc) {
func (ut *UpgradeTable) Register(from, to, compat int, message string, txn bool, fn upgradeFunc) {
if from < 0 {
from += to
}
if from < 0 {
panic("invalid from value in UpgradeTable.Register() call")
}
upg := upgrade{message: message, fn: fn, upgradesTo: to, transaction: txn}
if compat <= 0 {
compat = to
}
upg := upgrade{message: message, fn: fn, upgradesTo: to, compatVersion: compat, transaction: txn}
if len(*ut) == from {
*ut = append(*ut, upg)
return
@@ -55,7 +58,11 @@ func (ut *UpgradeTable) Register(from, to int, message string, txn bool, fn upgr
// or
//
// -- v1: Message
var upgradeHeaderRegex = regexp.MustCompile(`^-- (?:v(\d+) -> )?v(\d+): (.+)$`)
//
// Both syntaxes may also have a compatibility notice before the colon:
//
// -- v5 (compatible with v3+): Upgrade with backwards compatibility
var upgradeHeaderRegex = regexp.MustCompile(`^-- (?:v(\d+) -> )?v(\d+)(?: \(compatible with v(\d+)\+\))?: (.+)$`)
// To disable wrapping the upgrade in a single transaction, put `--transaction: off` on the second line.
//
@@ -64,7 +71,7 @@ var upgradeHeaderRegex = regexp.MustCompile(`^-- (?:v(\d+) -> )?v(\d+): (.+)$`)
// // do dangerous stuff
var transactionDisableRegex = regexp.MustCompile(`^-- transaction: (\w*)`)
func parseFileHeader(file []byte) (from, to int, message string, txn bool, lines [][]byte, err error) {
func parseFileHeader(file []byte) (from, to, compat int, message string, txn bool, lines [][]byte, err error) {
lines = bytes.Split(file, []byte("\n"))
if len(lines) < 2 {
err = errors.New("upgrade file too short")
@@ -75,19 +82,22 @@ func parseFileHeader(file []byte) (from, to int, message string, txn bool, lines
lines = lines[1:]
if match == nil {
err = errors.New("header not found")
} else if len(match) != 4 {
} else if len(match) != 5 {
err = errors.New("unexpected number of items in regex match")
} else if maybeFrom, err = strconv.Atoi(string(match[1])); len(match[1]) > 0 && err != nil {
err = fmt.Errorf("invalid source version: %w", err)
} else if to, err = strconv.Atoi(string(match[2])); err != nil {
err = fmt.Errorf("invalid target version: %w", err)
} else if compat, err = strconv.Atoi(string(match[3])); len(match[3]) > 0 && err != nil {
err = fmt.Errorf("invalid compatible version: %w", err)
} else {
err = nil
if len(match[1]) > 0 {
from = maybeFrom
} else {
from = -1
}
message = string(match[3])
message = string(match[4])
txn = true
match = transactionDisableRegex.FindSubmatch(lines[0])
if match != nil {
@@ -205,7 +215,7 @@ func splitSQLUpgradeFunc(sqliteData, postgresData string) upgradeFunc {
}
}
func parseSplitSQLUpgrade(name string, fs fullFS, skipNames map[string]struct{}) (from, to int, message string, txn bool, fn upgradeFunc) {
func parseSplitSQLUpgrade(name string, fs fullFS, skipNames map[string]struct{}) (from, to, compat int, message string, txn bool, fn upgradeFunc) {
postgresName := fmt.Sprintf("%s.postgres.sql", name)
sqliteName := fmt.Sprintf("%s.sqlite.sql", name)
skipNames[postgresName] = struct{}{}
@@ -218,15 +228,15 @@ func parseSplitSQLUpgrade(name string, fs fullFS, skipNames map[string]struct{})
if err != nil {
panic(err)
}
from, to, message, txn, _, err = parseFileHeader(postgresData)
from, to, compat, message, txn, _, err = parseFileHeader(postgresData)
if err != nil {
panic(fmt.Errorf("failed to parse header in %s: %w", postgresName, err))
}
sqliteFrom, sqliteTo, sqliteMessage, sqliteTxn, _, err := parseFileHeader(sqliteData)
sqliteFrom, sqliteTo, sqliteCompat, sqliteMessage, sqliteTxn, _, err := parseFileHeader(sqliteData)
if err != nil {
panic(fmt.Errorf("failed to parse header in %s: %w", sqliteName, err))
}
if from != sqliteFrom || to != sqliteTo {
if from != sqliteFrom || to != sqliteTo || compat != sqliteCompat {
panic(fmt.Errorf("mismatching versions in postgres and sqlite versions of %s: %d/%d -> %d/%d", name, from, sqliteFrom, to, sqliteTo))
} else if message != sqliteMessage {
panic(fmt.Errorf("mismatching message in postgres and sqlite versions of %s: %q != %q", name, message, sqliteMessage))
@@ -260,14 +270,14 @@ func (ut *UpgradeTable) RegisterFSPath(fs fullFS, dir string) {
} else if _, skip := skipNames[file.Name()]; skip {
// also do nothing
} else if splitName := splitFileNameRegex.FindStringSubmatch(file.Name()); splitName != nil {
from, to, message, txn, fn := parseSplitSQLUpgrade(splitName[1], fs, skipNames)
ut.Register(from, to, message, txn, fn)
from, to, compat, message, txn, fn := parseSplitSQLUpgrade(splitName[1], fs, skipNames)
ut.Register(from, to, compat, message, txn, fn)
} else if data, err := fs.ReadFile(filepath.Join(dir, file.Name())); err != nil {
panic(err)
} else if from, to, message, txn, lines, err := parseFileHeader(data); err != nil {
} else if from, to, compat, message, txn, lines, err := parseFileHeader(data); err != nil {
panic(fmt.Errorf("failed to parse header in %s: %w", file.Name(), err))
} else {
ut.Register(from, to, message, txn, sqlUpgradeFunc(file.Name(), lines))
ut.Register(from, to, compat, message, txn, sqlUpgradeFunc(file.Name(), lines))
}
}
}

23
vendor/maunium.net/go/mautrix/util/returnonce.go generated vendored Normal file
View File

@@ -0,0 +1,23 @@
// Copyright (c) 2023 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
package util
import "sync"
// ReturnableOnce is a wrapper for sync.Once that can return a value
type ReturnableOnce[Value any] struct {
once sync.Once
output Value
err error
}
func (ronce *ReturnableOnce[Value]) Do(fn func() (Value, error)) (Value, error) {
ronce.once.Do(func() {
ronce.output, ronce.err = fn()
})
return ronce.output, ronce.err
}

139
vendor/maunium.net/go/mautrix/util/ringbuffer.go generated vendored Normal file
View File

@@ -0,0 +1,139 @@
// Copyright (c) 2023 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
package util
import (
"errors"
"sync"
)
type pair[Key comparable, Value any] struct {
Set bool
Key Key
Value Value
}
type RingBuffer[Key comparable, Value any] struct {
ptr int
data []pair[Key, Value]
lock sync.RWMutex
size int
}
func NewRingBuffer[Key comparable, Value any](size int) *RingBuffer[Key, Value] {
return &RingBuffer[Key, Value]{
data: make([]pair[Key, Value], size),
}
}
var (
// StopIteration can be returned by the RingBuffer.Iter or MapRingBuffer callbacks to stop iteration immediately.
StopIteration = errors.New("stop iteration")
// SkipItem can be returned by the MapRingBuffer callback to skip adding a specific item.
SkipItem = errors.New("skip item")
)
func (rb *RingBuffer[Key, Value]) unlockedIter(callback func(key Key, val Value) error) error {
end := rb.ptr
for i := clamp(end-1, len(rb.data)); i != end; i = clamp(i-1, len(rb.data)) {
entry := rb.data[i]
if !entry.Set {
break
}
err := callback(entry.Key, entry.Value)
if err != nil {
if errors.Is(err, StopIteration) {
return nil
}
return err
}
}
return nil
}
func (rb *RingBuffer[Key, Value]) Iter(callback func(key Key, val Value) error) error {
rb.lock.RLock()
defer rb.lock.RUnlock()
return rb.unlockedIter(callback)
}
func MapRingBuffer[Key comparable, Value, Output any](rb *RingBuffer[Key, Value], callback func(key Key, val Value) (Output, error)) ([]Output, error) {
rb.lock.RLock()
defer rb.lock.RUnlock()
output := make([]Output, 0, rb.size)
err := rb.unlockedIter(func(key Key, val Value) error {
item, err := callback(key, val)
if err != nil {
if errors.Is(err, SkipItem) {
return nil
}
return err
}
output = append(output, item)
return nil
})
return output, err
}
func (rb *RingBuffer[Key, Value]) Size() int {
rb.lock.RLock()
defer rb.lock.RUnlock()
return rb.size
}
func (rb *RingBuffer[Key, Value]) Contains(val Key) bool {
_, ok := rb.Get(val)
return ok
}
func (rb *RingBuffer[Key, Value]) Get(key Key) (val Value, found bool) {
rb.lock.RLock()
end := rb.ptr
for i := clamp(end-1, len(rb.data)); i != end; i = clamp(i-1, len(rb.data)) {
if rb.data[i].Key == key {
val = rb.data[i].Value
found = true
break
}
}
rb.lock.RUnlock()
return
}
func (rb *RingBuffer[Key, Value]) Replace(key Key, val Value) bool {
rb.lock.Lock()
defer rb.lock.Unlock()
end := rb.ptr
for i := clamp(end-1, len(rb.data)); i != end; i = clamp(i-1, len(rb.data)) {
if rb.data[i].Key == key {
rb.data[i].Value = val
return true
}
}
return false
}
func (rb *RingBuffer[Key, Value]) Push(key Key, val Value) {
rb.lock.Lock()
rb.data[rb.ptr] = pair[Key, Value]{Key: key, Value: val, Set: true}
rb.ptr = (rb.ptr + 1) % len(rb.data)
if rb.size < len(rb.data) {
rb.size++
}
rb.lock.Unlock()
}
func clamp(index, len int) int {
if index < 0 {
return len + index
} else if index >= len {
return len - index
} else {
return index
}
}

94
vendor/maunium.net/go/mautrix/util/syncmap.go generated vendored Normal file
View File

@@ -0,0 +1,94 @@
// Copyright (c) 2023 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
package util
import "sync"
// SyncMap is a simple map with a built-in mutex.
type SyncMap[Key comparable, Value any] struct {
data map[Key]Value
lock sync.RWMutex
}
func NewSyncMap[Key comparable, Value any]() *SyncMap[Key, Value] {
return &SyncMap[Key, Value]{
data: make(map[Key]Value),
}
}
// Set stores a value in the map.
func (sm *SyncMap[Key, Value]) Set(key Key, value Value) {
sm.Swap(key, value)
}
// Swap sets a value in the map and returns the old value.
//
// The boolean return parameter is true if the value already existed, false if not.
func (sm *SyncMap[Key, Value]) Swap(key Key, value Value) (oldValue Value, wasReplaced bool) {
sm.lock.Lock()
oldValue, wasReplaced = sm.data[key]
sm.data[key] = value
sm.lock.Unlock()
return
}
// Delete removes a key from the map.
func (sm *SyncMap[Key, Value]) Delete(key Key) {
sm.Pop(key)
}
// Pop removes a key from the map and returns the old value.
//
// The boolean return parameter is the same as with normal Go map access (true if the key exists, false if not).
func (sm *SyncMap[Key, Value]) Pop(key Key) (value Value, ok bool) {
sm.lock.Lock()
value, ok = sm.data[key]
delete(sm.data, key)
sm.lock.Unlock()
return
}
// Get gets a value in the map.
//
// The boolean return parameter is the same as with normal Go map access (true if the key exists, false if not).
func (sm *SyncMap[Key, Value]) Get(key Key) (value Value, ok bool) {
sm.lock.RLock()
value, ok = sm.data[key]
sm.lock.RUnlock()
return
}
// GetOrSet gets a value in the map if the key already exists, otherwise inserts the given value and returns it.
//
// The boolean return parameter is true if the key already exists, and false if the given value was inserted.
func (sm *SyncMap[Key, Value]) GetOrSet(key Key, value Value) (actual Value, wasGet bool) {
sm.lock.Lock()
defer sm.lock.Unlock()
actual, wasGet = sm.data[key]
if wasGet {
return
}
sm.data[key] = value
actual = value
return
}
// Clone returns a copy of the map.
func (sm *SyncMap[Key, Value]) Clone() *SyncMap[Key, Value] {
return &SyncMap[Key, Value]{data: sm.CopyData()}
}
// CopyData returns a copy of the data in the map as a normal (non-atomic) map.
func (sm *SyncMap[Key, Value]) CopyData() map[Key]Value {
sm.lock.RLock()
copied := make(map[Key]Value, len(sm.data))
for key, value := range sm.data {
copied[key] = value
}
sm.lock.RUnlock()
return copied
}

View File

@@ -1,5 +1,31 @@
package mautrix
const Version = "v0.13.0"
import (
"fmt"
"regexp"
"runtime"
"strings"
)
var DefaultUserAgent = "mautrix-go/" + Version
const Version = "v0.15.2"
var GoModVersion = ""
var Commit = ""
var VersionWithCommit = Version
var DefaultUserAgent = "mautrix-go/" + Version + " go/" + strings.TrimPrefix(runtime.Version(), "go")
var goModVersionRegex = regexp.MustCompile(`v.+\d{14}-([0-9a-f]{12})`)
func init() {
if GoModVersion != "" {
match := goModVersionRegex.FindStringSubmatch(GoModVersion)
if match != nil {
Commit = match[1]
}
}
if Commit != "" {
VersionWithCommit = fmt.Sprintf("%s+dev.%s", Version, Commit[:8])
DefaultUserAgent = strings.Replace(DefaultUserAgent, "mautrix-go/"+Version, "mautrix-go/"+VersionWithCommit, 1)
}
}

View File

@@ -71,6 +71,8 @@ var (
SpecV13 = MustParseSpecVersion("v1.3")
SpecV14 = MustParseSpecVersion("v1.4")
SpecV15 = MustParseSpecVersion("v1.5")
SpecV16 = MustParseSpecVersion("v1.6")
SpecV17 = MustParseSpecVersion("v1.7")
)
func (svf SpecVersionFormat) String() string {