BREAKING: update mautrix to 0.15.x
This commit is contained in:
8
vendor/maunium.net/go/maulogger/v2/README.md
generated
vendored
8
vendor/maunium.net/go/maulogger/v2/README.md
generated
vendored
@@ -1,6 +1,6 @@
|
||||
# maulogger
|
||||
A logger in Go.
|
||||
A logger in Go. Deprecated in favor of [zerolog](https://github.com/rs/zerolog).
|
||||
|
||||
Docs: [godoc.org/maunium.net/go/maulogger](https://godoc.org/maunium.net/go/maulogger)
|
||||
|
||||
Go get: `go get maunium.net/go/maulogger`
|
||||
Utilities for migrating gracefully can be found in the maulogadapt package,
|
||||
it includes both wrapping a zerolog in the maulogger interface, and wrapping a
|
||||
maulogger as a zerolog output writer.
|
||||
|
||||
185
vendor/maunium.net/go/maulogger/v2/maulogadapt/mauzerolog.go
generated
vendored
Normal file
185
vendor/maunium.net/go/maulogger/v2/maulogadapt/mauzerolog.go
generated
vendored
Normal file
@@ -0,0 +1,185 @@
|
||||
// Copyright (c) 2023 Tulir Asokan
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||
// License, v. 2.0. If a copy of the MPL was not distributed with this
|
||||
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
|
||||
package maulogadapt
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"io"
|
||||
"strings"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
|
||||
"maunium.net/go/maulogger/v2"
|
||||
)
|
||||
|
||||
type MauZeroLog struct {
|
||||
*zerolog.Logger
|
||||
orig *zerolog.Logger
|
||||
mod string
|
||||
}
|
||||
|
||||
func ZeroAsMau(log *zerolog.Logger) maulogger.Logger {
|
||||
return MauZeroLog{log, log, ""}
|
||||
}
|
||||
|
||||
var _ maulogger.Logger = (*MauZeroLog)(nil)
|
||||
|
||||
func (m MauZeroLog) Sub(module string) maulogger.Logger {
|
||||
return m.Subm(module, map[string]interface{}{})
|
||||
}
|
||||
|
||||
func (m MauZeroLog) Subm(module string, metadata map[string]interface{}) maulogger.Logger {
|
||||
if m.mod != "" {
|
||||
module = fmt.Sprintf("%s/%s", m.mod, module)
|
||||
}
|
||||
var orig zerolog.Logger
|
||||
if m.orig != nil {
|
||||
orig = *m.orig
|
||||
} else {
|
||||
orig = *m.Logger
|
||||
}
|
||||
if len(metadata) > 0 {
|
||||
with := m.orig.With()
|
||||
for key, value := range metadata {
|
||||
with = with.Interface(key, value)
|
||||
}
|
||||
orig = with.Logger()
|
||||
}
|
||||
log := orig.With().Str("module", module).Logger()
|
||||
return MauZeroLog{&log, &orig, module}
|
||||
}
|
||||
|
||||
func (m MauZeroLog) WithDefaultLevel(_ maulogger.Level) maulogger.Logger {
|
||||
return m
|
||||
}
|
||||
|
||||
func (m MauZeroLog) GetParent() maulogger.Logger {
|
||||
return nil
|
||||
}
|
||||
|
||||
type nopWriteCloser struct {
|
||||
io.Writer
|
||||
}
|
||||
|
||||
func (nopWriteCloser) Close() error { return nil }
|
||||
|
||||
func (m MauZeroLog) Writer(level maulogger.Level) io.WriteCloser {
|
||||
return nopWriteCloser{m.Logger.With().Str(zerolog.LevelFieldName, zerolog.LevelFieldMarshalFunc(mauToZeroLevel(level))).Logger()}
|
||||
}
|
||||
|
||||
func mauToZeroLevel(level maulogger.Level) zerolog.Level {
|
||||
switch level {
|
||||
case maulogger.LevelDebug:
|
||||
return zerolog.DebugLevel
|
||||
case maulogger.LevelInfo:
|
||||
return zerolog.InfoLevel
|
||||
case maulogger.LevelWarn:
|
||||
return zerolog.WarnLevel
|
||||
case maulogger.LevelError:
|
||||
return zerolog.ErrorLevel
|
||||
case maulogger.LevelFatal:
|
||||
return zerolog.FatalLevel
|
||||
default:
|
||||
return zerolog.TraceLevel
|
||||
}
|
||||
}
|
||||
|
||||
func (m MauZeroLog) Log(level maulogger.Level, parts ...interface{}) {
|
||||
m.Logger.WithLevel(mauToZeroLevel(level)).Msg(fmt.Sprint(parts...))
|
||||
}
|
||||
|
||||
func (m MauZeroLog) Logln(level maulogger.Level, parts ...interface{}) {
|
||||
m.Logger.WithLevel(mauToZeroLevel(level)).Msg(strings.TrimSuffix(fmt.Sprintln(parts...), "\n"))
|
||||
}
|
||||
|
||||
func (m MauZeroLog) Logf(level maulogger.Level, message string, args ...interface{}) {
|
||||
m.Logger.WithLevel(mauToZeroLevel(level)).Msg(fmt.Sprintf(message, args...))
|
||||
}
|
||||
|
||||
func (m MauZeroLog) Logfln(level maulogger.Level, message string, args ...interface{}) {
|
||||
m.Logger.WithLevel(mauToZeroLevel(level)).Msg(fmt.Sprintf(message, args...))
|
||||
}
|
||||
|
||||
func (m MauZeroLog) Debug(parts ...interface{}) {
|
||||
m.Logger.Debug().Msg(fmt.Sprint(parts...))
|
||||
}
|
||||
|
||||
func (m MauZeroLog) Debugln(parts ...interface{}) {
|
||||
m.Logger.Debug().Msg(strings.TrimSuffix(fmt.Sprintln(parts...), "\n"))
|
||||
}
|
||||
|
||||
func (m MauZeroLog) Debugf(message string, args ...interface{}) {
|
||||
m.Logger.Debug().Msg(fmt.Sprintf(message, args...))
|
||||
}
|
||||
|
||||
func (m MauZeroLog) Debugfln(message string, args ...interface{}) {
|
||||
m.Logger.Debug().Msg(fmt.Sprintf(message, args...))
|
||||
}
|
||||
|
||||
func (m MauZeroLog) Info(parts ...interface{}) {
|
||||
m.Logger.Info().Msg(fmt.Sprint(parts...))
|
||||
}
|
||||
|
||||
func (m MauZeroLog) Infoln(parts ...interface{}) {
|
||||
m.Logger.Info().Msg(strings.TrimSuffix(fmt.Sprintln(parts...), "\n"))
|
||||
}
|
||||
|
||||
func (m MauZeroLog) Infof(message string, args ...interface{}) {
|
||||
m.Logger.Info().Msg(fmt.Sprintf(message, args...))
|
||||
}
|
||||
|
||||
func (m MauZeroLog) Infofln(message string, args ...interface{}) {
|
||||
m.Logger.Info().Msg(fmt.Sprintf(message, args...))
|
||||
}
|
||||
|
||||
func (m MauZeroLog) Warn(parts ...interface{}) {
|
||||
m.Logger.Warn().Msg(fmt.Sprint(parts...))
|
||||
}
|
||||
|
||||
func (m MauZeroLog) Warnln(parts ...interface{}) {
|
||||
m.Logger.Warn().Msg(strings.TrimSuffix(fmt.Sprintln(parts...), "\n"))
|
||||
}
|
||||
|
||||
func (m MauZeroLog) Warnf(message string, args ...interface{}) {
|
||||
m.Logger.Warn().Msg(fmt.Sprintf(message, args...))
|
||||
}
|
||||
|
||||
func (m MauZeroLog) Warnfln(message string, args ...interface{}) {
|
||||
m.Logger.Warn().Msg(fmt.Sprintf(message, args...))
|
||||
}
|
||||
|
||||
func (m MauZeroLog) Error(parts ...interface{}) {
|
||||
m.Logger.Error().Msg(fmt.Sprint(parts...))
|
||||
}
|
||||
|
||||
func (m MauZeroLog) Errorln(parts ...interface{}) {
|
||||
m.Logger.Error().Msg(strings.TrimSuffix(fmt.Sprintln(parts...), "\n"))
|
||||
}
|
||||
|
||||
func (m MauZeroLog) Errorf(message string, args ...interface{}) {
|
||||
m.Logger.Error().Msg(fmt.Sprintf(message, args...))
|
||||
}
|
||||
|
||||
func (m MauZeroLog) Errorfln(message string, args ...interface{}) {
|
||||
m.Logger.Error().Msg(fmt.Sprintf(message, args...))
|
||||
}
|
||||
|
||||
func (m MauZeroLog) Fatal(parts ...interface{}) {
|
||||
m.Logger.WithLevel(zerolog.FatalLevel).Msg(fmt.Sprint(parts...))
|
||||
}
|
||||
|
||||
func (m MauZeroLog) Fatalln(parts ...interface{}) {
|
||||
m.Logger.WithLevel(zerolog.FatalLevel).Msg(strings.TrimSuffix(fmt.Sprintln(parts...), "\n"))
|
||||
}
|
||||
|
||||
func (m MauZeroLog) Fatalf(message string, args ...interface{}) {
|
||||
m.Logger.WithLevel(zerolog.FatalLevel).Msg(fmt.Sprintf(message, args...))
|
||||
}
|
||||
|
||||
func (m MauZeroLog) Fatalfln(message string, args ...interface{}) {
|
||||
m.Logger.WithLevel(zerolog.FatalLevel).Msg(fmt.Sprintf(message, args...))
|
||||
}
|
||||
73
vendor/maunium.net/go/maulogger/v2/maulogadapt/zeromaulog.go
generated
vendored
Normal file
73
vendor/maunium.net/go/maulogger/v2/maulogadapt/zeromaulog.go
generated
vendored
Normal file
@@ -0,0 +1,73 @@
|
||||
// Copyright (c) 2023 Tulir Asokan
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||
// License, v. 2.0. If a copy of the MPL was not distributed with this
|
||||
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
|
||||
package maulogadapt
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
|
||||
"maunium.net/go/maulogger/v2"
|
||||
)
|
||||
|
||||
// ZeroMauLog is a simple wrapper for a maulogger that can be set as the output writer for zerolog.
|
||||
type ZeroMauLog struct {
|
||||
maulogger.Logger
|
||||
}
|
||||
|
||||
func MauAsZero(log maulogger.Logger) *zerolog.Logger {
|
||||
zero := zerolog.New(&ZeroMauLog{log})
|
||||
return &zero
|
||||
}
|
||||
|
||||
var _ zerolog.LevelWriter = (*ZeroMauLog)(nil)
|
||||
|
||||
func (z *ZeroMauLog) Write(p []byte) (n int, err error) {
|
||||
return 0, nil
|
||||
}
|
||||
|
||||
func (z *ZeroMauLog) WriteLevel(level zerolog.Level, p []byte) (n int, err error) {
|
||||
var mauLevel maulogger.Level
|
||||
switch level {
|
||||
case zerolog.DebugLevel:
|
||||
mauLevel = maulogger.LevelDebug
|
||||
case zerolog.InfoLevel, zerolog.NoLevel:
|
||||
mauLevel = maulogger.LevelInfo
|
||||
case zerolog.WarnLevel:
|
||||
mauLevel = maulogger.LevelWarn
|
||||
case zerolog.ErrorLevel:
|
||||
mauLevel = maulogger.LevelError
|
||||
case zerolog.FatalLevel, zerolog.PanicLevel:
|
||||
mauLevel = maulogger.LevelFatal
|
||||
case zerolog.Disabled, zerolog.TraceLevel:
|
||||
fallthrough
|
||||
default:
|
||||
return 0, nil
|
||||
}
|
||||
p = bytes.TrimSuffix(p, []byte{'\n'})
|
||||
msg := gjson.GetBytes(p, zerolog.MessageFieldName).Str
|
||||
|
||||
p, err = sjson.DeleteBytes(p, zerolog.MessageFieldName)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
p, err = sjson.DeleteBytes(p, zerolog.LevelFieldName)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
p, err = sjson.DeleteBytes(p, zerolog.TimestampFieldName)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
if len(p) > 2 {
|
||||
msg += " " + string(p)
|
||||
}
|
||||
z.Log(mauLevel, msg)
|
||||
return len(p), nil
|
||||
}
|
||||
2
vendor/maunium.net/go/mautrix/.gitignore
generated
vendored
2
vendor/maunium.net/go/mautrix/.gitignore
generated
vendored
@@ -1,2 +1,4 @@
|
||||
.idea/
|
||||
.vscode/
|
||||
*.db
|
||||
*.log
|
||||
|
||||
4
vendor/maunium.net/go/mautrix/.pre-commit-config.yaml
generated
vendored
4
vendor/maunium.net/go/mautrix/.pre-commit-config.yaml
generated
vendored
@@ -1,6 +1,6 @@
|
||||
repos:
|
||||
- repo: https://github.com/pre-commit/pre-commit-hooks
|
||||
rev: v4.1.0
|
||||
rev: v4.4.0
|
||||
hooks:
|
||||
- id: trailing-whitespace
|
||||
exclude_types: [markdown]
|
||||
@@ -9,7 +9,7 @@ repos:
|
||||
- id: check-added-large-files
|
||||
|
||||
- repo: https://github.com/tekwizely/pre-commit-golang
|
||||
rev: v1.0.0-beta.5
|
||||
rev: v1.0.0-rc.1
|
||||
hooks:
|
||||
- id: go-imports-repo
|
||||
- id: go-vet-repo-mod
|
||||
|
||||
107
vendor/maunium.net/go/mautrix/CHANGELOG.md
generated
vendored
107
vendor/maunium.net/go/mautrix/CHANGELOG.md
generated
vendored
@@ -1,3 +1,110 @@
|
||||
## v0.15.2 (2023-05-16)
|
||||
|
||||
* *(client)* Changed member-fetching methods to clear existing member info in
|
||||
state store.
|
||||
* *(client)* Added support for inserting mautrix-go commit hash into default
|
||||
user agent at compile time.
|
||||
* *(bridge)* Fixed bridge bot intent not having state store set.
|
||||
* *(client)* Fixed `RespError` marshaling mutating the `ExtraData` map and
|
||||
potentially causing panics.
|
||||
* *(util/dbutil)* Added `DoTxn` method for an easier way to manage database
|
||||
transactions.
|
||||
* *(util)* Added a zerolog `CallerMarshalFunc` implementation that includes the
|
||||
function name.
|
||||
* *(bridge)* Added error reply to encrypted messages if the bridge isn't
|
||||
configured to do encryption.
|
||||
|
||||
## v0.15.1 (2023-04-16)
|
||||
|
||||
* *(crypto, bridge)* Added options to automatically ratchet/delete megolm
|
||||
sessions to minimize access to old messages.
|
||||
* *(pushrules)* Added method to get entire push rule that matched (instead of
|
||||
only the list of actions).
|
||||
* *(pushrules)* Deprecated `NotifySpecified` as there's no reason to read it.
|
||||
* *(crypto)* Changed `max_age` column in `crypto_megolm_inbound_session` table
|
||||
to be milliseconds instead of nanoseconds.
|
||||
* *(util)* Added method for iterating `RingBuffer`.
|
||||
* *(crypto/cryptohelper)* Changed decryption errors to request session from all
|
||||
own devices in addition to the sender, instead of only asking the sender.
|
||||
* *(sqlstatestore)* Fixed `FindSharedRooms` throwing an error when using from
|
||||
a non-bridge context.
|
||||
* *(client)* Optimized `AccountDataSyncStore` to not resend save requests if
|
||||
the sync token didn't change.
|
||||
* *(types)* Added `Clone()` method for `PowerLevelEventContent`.
|
||||
|
||||
## v0.15.0 (2023-03-16)
|
||||
|
||||
### beta.3 (2023-03-15)
|
||||
|
||||
* **Breaking change *(appservice)*** Removed `Load()` and `AppService.Init()`
|
||||
functions. The struct should just be created with `Create()` and the relevant
|
||||
fields should be filled manually.
|
||||
* **Breaking change *(appservice)*** Removed public `HomeserverURL` field and
|
||||
replaced it with a `SetHomeserverURL` method.
|
||||
* *(appservice)* Added support for unix sockets for homeserver URL and
|
||||
appservice HTTP server.
|
||||
* *(client)* Changed request logging to log durations as floats instead of
|
||||
strings (using zerolog's `Dur()`, so the exact output can be configured).
|
||||
* *(bridge)* Changed zerolog to use nanosecond precision timestamps.
|
||||
* *(crypto)* Added message index to log after encrypting/decrypting megolm
|
||||
events, and when failing to decrypt due to duplicate index.
|
||||
* *(sqlstatestore)* Fixed warning log for rooms that don't have encryption
|
||||
enabled.
|
||||
|
||||
### beta.2 (2023-03-02)
|
||||
|
||||
* *(bridge)* Fixed building with `nocrypto` tag.
|
||||
* *(bridge)* Fixed legacy logging config migration not disabling file writer
|
||||
when `file_name_format` was empty.
|
||||
* *(bridge)* Added option to require room power level to run commands.
|
||||
* *(event)* Added structs for [MSC3952]: Intentional Mentions.
|
||||
* *(util/variationselector)* Added `FullyQualify` method to add necessary emoji
|
||||
variation selectors without adding all possible ones.
|
||||
|
||||
[MSC3952]: https://github.com/matrix-org/matrix-spec-proposals/pull/3952
|
||||
|
||||
### beta.1 (2023-02-24)
|
||||
|
||||
* Bumped minimum Go version to 1.19.
|
||||
* **Breaking changes**
|
||||
* *(all)* Switched to zerolog for logging.
|
||||
* The `Client` and `Bridge` structs still include a legacy logger for
|
||||
backwards compatibility.
|
||||
* *(client, appservice)* Moved `SQLStateStore` from appservice module to the
|
||||
top-level (client) module.
|
||||
* *(client, appservice)* Removed unused `Typing` map in `SQLStateStore`.
|
||||
* *(client)* Removed unused `SaveRoom` and `LoadRoom` methods in `Storer`.
|
||||
* *(client, appservice)* Removed deprecated `SendVideo` and `SendImage` methods.
|
||||
* *(client)* Replaced `AppServiceUserID` field with `SetAppServiceUserID` boolean.
|
||||
The `UserID` field is used as the value for the query param.
|
||||
* *(crypto)* Renamed `GobStore` to `MemoryStore` and removed the file saving
|
||||
features. The data can still be persisted, but the persistence part must be
|
||||
implemented separately.
|
||||
* *(crypto)* Removed deprecated `DeviceIdentity` alias
|
||||
(renamed to `id.Device` long ago).
|
||||
* *(client)* Removed `Stringifable` interface as it's the same as `fmt.Stringer`.
|
||||
* *(client)* Renamed `Storer` interface to `SyncStore`. A type alias exists for
|
||||
backwards-compatibility.
|
||||
* *(crypto/cryptohelper)* Added package for a simplified crypto interface for clients.
|
||||
* *(example)* Added e2ee support to example using crypto helper.
|
||||
* *(client)* Changed default syncer to stop syncing on `M_UNKNOWN_TOKEN` errors.
|
||||
|
||||
## v0.14.0 (2023-02-16)
|
||||
|
||||
* **Breaking change *(format)*** Refactored the HTML parser `Context` to have
|
||||
more data.
|
||||
* *(id)* Fixed escaping path components when forming matrix.to URLs
|
||||
or `matrix:` URIs.
|
||||
* *(bridge)* Bumped default timeouts for decrypting incoming messages.
|
||||
* *(bridge)* Added `RawArgs` to commands to allow accessing non-split input.
|
||||
* *(bridge)* Added `ReplyAdvanced` to commands to allow setting markdown
|
||||
settings.
|
||||
* *(event)* Added `notifications` key to `PowerLevelEventContent`.
|
||||
* *(event)* Changed `SetEdit` to cut off edit fallback if the message is long.
|
||||
* *(util)* Added `SyncMap` as a simple generic wrapper for a map with a mutex.
|
||||
* *(util)* Added `ReturnableOnce` as a wrapper for `sync.Once` with a return
|
||||
value.
|
||||
|
||||
## v0.13.0 (2023-01-16)
|
||||
|
||||
* **Breaking change:** Removed `IsTyping` and `SetTyping` in `appservice.StateStore`
|
||||
|
||||
410
vendor/maunium.net/go/mautrix/appservice/appservice.go
generated
vendored
410
vendor/maunium.net/go/mautrix/appservice/appservice.go
generated
vendored
@@ -1,410 +0,0 @@
|
||||
// Copyright (c) 2020 Tulir Asokan
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||
// License, v. 2.0. If a copy of the MPL was not distributed with this
|
||||
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
|
||||
package appservice
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"html/template"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"net/http/cookiejar"
|
||||
"os"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"syscall"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/gorilla/websocket"
|
||||
"golang.org/x/net/publicsuffix"
|
||||
"gopkg.in/yaml.v3"
|
||||
|
||||
"maunium.net/go/maulogger/v2"
|
||||
|
||||
"maunium.net/go/mautrix"
|
||||
"maunium.net/go/mautrix/event"
|
||||
"maunium.net/go/mautrix/id"
|
||||
)
|
||||
|
||||
// EventChannelSize is the size for the Events channel in Appservice instances.
|
||||
var EventChannelSize = 64
|
||||
var OTKChannelSize = 4
|
||||
|
||||
// Create a blank appservice instance.
|
||||
func Create() *AppService {
|
||||
jar, _ := cookiejar.New(&cookiejar.Options{PublicSuffixList: publicsuffix.List})
|
||||
return &AppService{
|
||||
LogConfig: CreateLogConfig(),
|
||||
clients: make(map[id.UserID]*mautrix.Client),
|
||||
intents: make(map[id.UserID]*IntentAPI),
|
||||
HTTPClient: &http.Client{Timeout: 180 * time.Second, Jar: jar},
|
||||
StateStore: NewBasicStateStore(),
|
||||
Router: mux.NewRouter(),
|
||||
UserAgent: mautrix.DefaultUserAgent,
|
||||
txnIDC: NewTransactionIDCache(128),
|
||||
Live: true,
|
||||
Ready: false,
|
||||
}
|
||||
}
|
||||
|
||||
// Load an appservice config from a file.
|
||||
func Load(path string) (*AppService, error) {
|
||||
data, readErr := ioutil.ReadFile(path)
|
||||
if readErr != nil {
|
||||
return nil, readErr
|
||||
}
|
||||
|
||||
config := Create()
|
||||
return config, yaml.Unmarshal(data, config)
|
||||
}
|
||||
|
||||
// QueryHandler handles room alias and user ID queries from the homeserver.
|
||||
type QueryHandler interface {
|
||||
QueryAlias(alias string) bool
|
||||
QueryUser(userID id.UserID) bool
|
||||
}
|
||||
|
||||
type QueryHandlerStub struct{}
|
||||
|
||||
func (qh *QueryHandlerStub) QueryAlias(alias string) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (qh *QueryHandlerStub) QueryUser(userID id.UserID) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
type WebsocketHandler func(WebsocketCommand) (ok bool, data interface{})
|
||||
|
||||
// AppService is the main config for all appservices.
|
||||
// It also serves as the appservice instance struct.
|
||||
type AppService struct {
|
||||
HomeserverDomain string `yaml:"homeserver_domain"`
|
||||
HomeserverURL string `yaml:"homeserver_url"`
|
||||
RegistrationPath string `yaml:"registration"`
|
||||
Host HostConfig `yaml:"host"`
|
||||
LogConfig LogConfig `yaml:"logging"`
|
||||
|
||||
Registration *Registration `yaml:"-"`
|
||||
Log maulogger.Logger `yaml:"-"`
|
||||
|
||||
txnIDC *TransactionIDCache
|
||||
|
||||
Events chan *event.Event `yaml:"-"`
|
||||
ToDeviceEvents chan *event.Event `yaml:"-"`
|
||||
DeviceLists chan *mautrix.DeviceLists `yaml:"-"`
|
||||
OTKCounts chan *mautrix.OTKCount `yaml:"-"`
|
||||
QueryHandler QueryHandler `yaml:"-"`
|
||||
StateStore StateStore `yaml:"-"`
|
||||
|
||||
Router *mux.Router `yaml:"-"`
|
||||
UserAgent string `yaml:"-"`
|
||||
server *http.Server
|
||||
HTTPClient *http.Client
|
||||
botClient *mautrix.Client
|
||||
botIntent *IntentAPI
|
||||
|
||||
DefaultHTTPRetries int
|
||||
|
||||
Live bool
|
||||
Ready bool
|
||||
|
||||
clients map[id.UserID]*mautrix.Client
|
||||
clientsLock sync.RWMutex
|
||||
intents map[id.UserID]*IntentAPI
|
||||
intentsLock sync.RWMutex
|
||||
|
||||
ws *websocket.Conn
|
||||
wsWriteLock sync.Mutex
|
||||
StopWebsocket func(error)
|
||||
websocketHandlers map[string]WebsocketHandler
|
||||
websocketHandlersLock sync.RWMutex
|
||||
websocketRequests map[int]chan<- *WebsocketCommand
|
||||
websocketRequestsLock sync.RWMutex
|
||||
websocketRequestID int32
|
||||
// ProcessID is an identifier sent to the websocket proxy for debugging connections
|
||||
ProcessID string
|
||||
|
||||
DoublePuppetValue string
|
||||
GetProfile func(userID id.UserID, roomID id.RoomID) *event.MemberEventContent
|
||||
}
|
||||
|
||||
const DoublePuppetKey = "fi.mau.double_puppet_source"
|
||||
|
||||
func getDefaultProcessID() string {
|
||||
pid := syscall.Getpid()
|
||||
uid := syscall.Getuid()
|
||||
hostname, _ := os.Hostname()
|
||||
return fmt.Sprintf("%s-%d-%d", hostname, uid, pid)
|
||||
}
|
||||
|
||||
func (as *AppService) PrepareWebsocket() {
|
||||
as.websocketHandlersLock.Lock()
|
||||
defer as.websocketHandlersLock.Unlock()
|
||||
if as.websocketHandlers == nil {
|
||||
as.websocketHandlers = make(map[string]WebsocketHandler, 32)
|
||||
as.websocketRequests = make(map[int]chan<- *WebsocketCommand)
|
||||
}
|
||||
}
|
||||
|
||||
// HostConfig contains info about how to host the appservice.
|
||||
type HostConfig struct {
|
||||
Hostname string `yaml:"hostname"`
|
||||
Port uint16 `yaml:"port"`
|
||||
TLSKey string `yaml:"tls_key,omitempty"`
|
||||
TLSCert string `yaml:"tls_cert,omitempty"`
|
||||
}
|
||||
|
||||
// Address gets the whole address of the Appservice.
|
||||
func (hc *HostConfig) Address() string {
|
||||
return fmt.Sprintf("%s:%d", hc.Hostname, hc.Port)
|
||||
}
|
||||
|
||||
// Save saves this config into a file at the given path.
|
||||
func (as *AppService) Save(path string) error {
|
||||
data, err := yaml.Marshal(as)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return ioutil.WriteFile(path, data, 0644)
|
||||
}
|
||||
|
||||
// YAML returns the config in YAML format.
|
||||
func (as *AppService) YAML() (string, error) {
|
||||
data, err := yaml.Marshal(as)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return string(data), nil
|
||||
}
|
||||
|
||||
func (as *AppService) BotMXID() id.UserID {
|
||||
return id.NewUserID(as.Registration.SenderLocalpart, as.HomeserverDomain)
|
||||
}
|
||||
|
||||
func (as *AppService) makeIntent(userID id.UserID) *IntentAPI {
|
||||
as.intentsLock.Lock()
|
||||
defer as.intentsLock.Unlock()
|
||||
|
||||
intent, ok := as.intents[userID]
|
||||
if ok {
|
||||
return intent
|
||||
}
|
||||
|
||||
localpart, homeserver, err := userID.Parse()
|
||||
if err != nil || len(localpart) == 0 || homeserver != as.HomeserverDomain {
|
||||
if err != nil {
|
||||
as.Log.Fatalfln("Failed to parse user ID %s: %v", userID, err)
|
||||
} else if len(localpart) == 0 {
|
||||
as.Log.Fatalfln("Failed to make intent for %s: localpart is empty", userID)
|
||||
} else if homeserver != as.HomeserverDomain {
|
||||
as.Log.Fatalfln("Failed to make intent for %s: homeserver isn't %s", userID, as.HomeserverDomain)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
intent = as.NewIntentAPI(localpart)
|
||||
as.intents[userID] = intent
|
||||
return intent
|
||||
}
|
||||
|
||||
func (as *AppService) Intent(userID id.UserID) *IntentAPI {
|
||||
as.intentsLock.RLock()
|
||||
intent, ok := as.intents[userID]
|
||||
as.intentsLock.RUnlock()
|
||||
if !ok {
|
||||
return as.makeIntent(userID)
|
||||
}
|
||||
return intent
|
||||
}
|
||||
|
||||
func (as *AppService) BotIntent() *IntentAPI {
|
||||
if as.botIntent == nil {
|
||||
as.botIntent = as.makeIntent(as.BotMXID())
|
||||
}
|
||||
return as.botIntent
|
||||
}
|
||||
|
||||
func (as *AppService) makeClient(userID id.UserID) *mautrix.Client {
|
||||
as.clientsLock.Lock()
|
||||
defer as.clientsLock.Unlock()
|
||||
|
||||
client, ok := as.clients[userID]
|
||||
if ok {
|
||||
return client
|
||||
}
|
||||
|
||||
client, err := mautrix.NewClient(as.HomeserverURL, userID, as.Registration.AppToken)
|
||||
if err != nil {
|
||||
as.Log.Fatalln("Failed to create mautrix client instance:", err)
|
||||
return nil
|
||||
}
|
||||
client.UserAgent = as.UserAgent
|
||||
client.Syncer = nil
|
||||
client.Store = nil
|
||||
client.AppServiceUserID = userID
|
||||
client.Logger = as.Log.Sub(string(userID))
|
||||
client.Client = as.HTTPClient
|
||||
client.DefaultHTTPRetries = as.DefaultHTTPRetries
|
||||
as.clients[userID] = client
|
||||
return client
|
||||
}
|
||||
|
||||
func (as *AppService) Client(userID id.UserID) *mautrix.Client {
|
||||
as.clientsLock.RLock()
|
||||
client, ok := as.clients[userID]
|
||||
as.clientsLock.RUnlock()
|
||||
if !ok {
|
||||
return as.makeClient(userID)
|
||||
}
|
||||
return client
|
||||
}
|
||||
|
||||
func (as *AppService) BotClient() *mautrix.Client {
|
||||
if as.botClient == nil {
|
||||
as.botClient = as.makeClient(as.BotMXID())
|
||||
as.botClient.Logger = as.Log.Sub("Bot")
|
||||
}
|
||||
return as.botClient
|
||||
}
|
||||
|
||||
// Init initializes the logger and loads the registration of this appservice.
|
||||
func (as *AppService) Init() (bool, error) {
|
||||
as.Events = make(chan *event.Event, EventChannelSize)
|
||||
as.ToDeviceEvents = make(chan *event.Event, EventChannelSize)
|
||||
as.OTKCounts = make(chan *mautrix.OTKCount, OTKChannelSize)
|
||||
as.DeviceLists = make(chan *mautrix.DeviceLists, EventChannelSize)
|
||||
as.QueryHandler = &QueryHandlerStub{}
|
||||
|
||||
if len(as.UserAgent) == 0 {
|
||||
as.UserAgent = mautrix.DefaultUserAgent
|
||||
}
|
||||
if len(as.ProcessID) == 0 {
|
||||
as.ProcessID = getDefaultProcessID()
|
||||
}
|
||||
|
||||
as.Log = maulogger.Create()
|
||||
as.LogConfig.Configure(as.Log)
|
||||
as.Log.Debugln("Logger initialized successfully.")
|
||||
|
||||
if len(as.RegistrationPath) > 0 {
|
||||
var err error
|
||||
as.Registration, err = LoadRegistration(as.RegistrationPath)
|
||||
if err != nil {
|
||||
return false, err
|
||||
}
|
||||
}
|
||||
|
||||
as.Log.Debugln("Appservice initialized successfully.")
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// LogConfig contains configs for the logger.
|
||||
type LogConfig struct {
|
||||
Directory string `yaml:"directory"`
|
||||
FileNameFormat string `yaml:"file_name_format"`
|
||||
FileDateFormat string `yaml:"file_date_format"`
|
||||
FileMode uint32 `yaml:"file_mode"`
|
||||
TimestampFormat string `yaml:"timestamp_format"`
|
||||
RawPrintLevel string `yaml:"print_level"`
|
||||
JSONStdout bool `yaml:"print_json"`
|
||||
JSONFile bool `yaml:"file_json"`
|
||||
PrintLevel int `yaml:"-"`
|
||||
}
|
||||
|
||||
type umLogConfig LogConfig
|
||||
|
||||
func (lc *LogConfig) UnmarshalYAML(unmarshal func(interface{}) error) error {
|
||||
err := unmarshal((*umLogConfig)(lc))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
switch strings.ToUpper(lc.RawPrintLevel) {
|
||||
case "TRACE":
|
||||
lc.PrintLevel = -10
|
||||
case "DEBUG":
|
||||
lc.PrintLevel = maulogger.LevelDebug.Severity
|
||||
case "INFO":
|
||||
lc.PrintLevel = maulogger.LevelInfo.Severity
|
||||
case "WARN", "WARNING":
|
||||
lc.PrintLevel = maulogger.LevelWarn.Severity
|
||||
case "ERR", "ERROR":
|
||||
lc.PrintLevel = maulogger.LevelError.Severity
|
||||
case "FATAL":
|
||||
lc.PrintLevel = maulogger.LevelFatal.Severity
|
||||
default:
|
||||
return errors.New("invalid print level " + lc.RawPrintLevel)
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (lc *LogConfig) MarshalYAML() (interface{}, error) {
|
||||
switch {
|
||||
case lc.PrintLevel >= maulogger.LevelFatal.Severity:
|
||||
lc.RawPrintLevel = maulogger.LevelFatal.Name
|
||||
case lc.PrintLevel >= maulogger.LevelError.Severity:
|
||||
lc.RawPrintLevel = maulogger.LevelError.Name
|
||||
case lc.PrintLevel >= maulogger.LevelWarn.Severity:
|
||||
lc.RawPrintLevel = maulogger.LevelWarn.Name
|
||||
case lc.PrintLevel >= maulogger.LevelInfo.Severity:
|
||||
lc.RawPrintLevel = maulogger.LevelInfo.Name
|
||||
default:
|
||||
lc.RawPrintLevel = maulogger.LevelDebug.Name
|
||||
}
|
||||
return lc, nil
|
||||
}
|
||||
|
||||
// CreateLogConfig creates a basic LogConfig.
|
||||
func CreateLogConfig() LogConfig {
|
||||
return LogConfig{
|
||||
Directory: "./logs",
|
||||
FileNameFormat: "%[1]s-%02[2]d.log",
|
||||
TimestampFormat: "Jan _2, 2006 15:04:05",
|
||||
FileMode: 0600,
|
||||
FileDateFormat: "2006-01-02",
|
||||
PrintLevel: 10,
|
||||
}
|
||||
}
|
||||
|
||||
type FileFormatData struct {
|
||||
Date string
|
||||
Index int
|
||||
}
|
||||
|
||||
// GetFileFormat returns a mauLogger-compatible logger file format based on the data in the struct.
|
||||
func (lc LogConfig) GetFileFormat() maulogger.LoggerFileFormat {
|
||||
if len(lc.Directory) > 0 {
|
||||
_ = os.MkdirAll(lc.Directory, 0700)
|
||||
}
|
||||
path := filepath.Join(lc.Directory, lc.FileNameFormat)
|
||||
tpl, _ := template.New("fileformat").Parse(path)
|
||||
|
||||
return func(now string, i int) string {
|
||||
var buf strings.Builder
|
||||
_ = tpl.Execute(&buf, FileFormatData{
|
||||
Date: now,
|
||||
Index: i,
|
||||
})
|
||||
return buf.String()
|
||||
}
|
||||
}
|
||||
|
||||
// Configure configures a mauLogger instance with the data in this struct.
|
||||
func (lc LogConfig) Configure(log maulogger.Logger) {
|
||||
basicLogger := log.(*maulogger.BasicLogger)
|
||||
basicLogger.FileFormat = lc.GetFileFormat()
|
||||
basicLogger.FileMode = os.FileMode(lc.FileMode)
|
||||
basicLogger.FileTimeFormat = lc.FileDateFormat
|
||||
basicLogger.TimeFormat = lc.TimestampFormat
|
||||
basicLogger.PrintLevel = lc.PrintLevel
|
||||
basicLogger.JSONFile = lc.JSONFile
|
||||
if lc.JSONStdout {
|
||||
basicLogger.EnableJSONStdout()
|
||||
}
|
||||
}
|
||||
173
vendor/maunium.net/go/mautrix/appservice/eventprocessor.go
generated
vendored
173
vendor/maunium.net/go/mautrix/appservice/eventprocessor.go
generated
vendored
@@ -1,173 +0,0 @@
|
||||
// Copyright (c) 2020 Tulir Asokan
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||
// License, v. 2.0. If a copy of the MPL was not distributed with this
|
||||
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
|
||||
package appservice
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"runtime/debug"
|
||||
|
||||
log "maunium.net/go/maulogger/v2"
|
||||
|
||||
"maunium.net/go/mautrix"
|
||||
"maunium.net/go/mautrix/event"
|
||||
)
|
||||
|
||||
type ExecMode uint8
|
||||
|
||||
const (
|
||||
AsyncHandlers ExecMode = iota
|
||||
AsyncLoop
|
||||
Sync
|
||||
)
|
||||
|
||||
type EventHandler func(evt *event.Event)
|
||||
type OTKHandler func(otk *mautrix.OTKCount)
|
||||
type DeviceListHandler func(otk *mautrix.DeviceLists, since string)
|
||||
|
||||
type EventProcessor struct {
|
||||
ExecMode ExecMode
|
||||
|
||||
as *AppService
|
||||
log log.Logger
|
||||
stop chan struct{}
|
||||
handlers map[event.Type][]EventHandler
|
||||
|
||||
otkHandlers []OTKHandler
|
||||
deviceListHandlers []DeviceListHandler
|
||||
}
|
||||
|
||||
func NewEventProcessor(as *AppService) *EventProcessor {
|
||||
return &EventProcessor{
|
||||
ExecMode: AsyncHandlers,
|
||||
as: as,
|
||||
log: as.Log.Sub("Events"),
|
||||
stop: make(chan struct{}, 1),
|
||||
handlers: make(map[event.Type][]EventHandler),
|
||||
|
||||
otkHandlers: make([]OTKHandler, 0),
|
||||
deviceListHandlers: make([]DeviceListHandler, 0),
|
||||
}
|
||||
}
|
||||
|
||||
func (ep *EventProcessor) On(evtType event.Type, handler EventHandler) {
|
||||
handlers, ok := ep.handlers[evtType]
|
||||
if !ok {
|
||||
handlers = []EventHandler{handler}
|
||||
} else {
|
||||
handlers = append(handlers, handler)
|
||||
}
|
||||
ep.handlers[evtType] = handlers
|
||||
}
|
||||
|
||||
func (ep *EventProcessor) PrependHandler(evtType event.Type, handler EventHandler) {
|
||||
handlers, ok := ep.handlers[evtType]
|
||||
if !ok {
|
||||
handlers = []EventHandler{handler}
|
||||
} else {
|
||||
handlers = append([]EventHandler{handler}, handlers...)
|
||||
}
|
||||
ep.handlers[evtType] = handlers
|
||||
}
|
||||
|
||||
func (ep *EventProcessor) OnOTK(handler OTKHandler) {
|
||||
ep.otkHandlers = append(ep.otkHandlers, handler)
|
||||
}
|
||||
|
||||
func (ep *EventProcessor) OnDeviceList(handler DeviceListHandler) {
|
||||
ep.deviceListHandlers = append(ep.deviceListHandlers, handler)
|
||||
}
|
||||
|
||||
func (ep *EventProcessor) recoverFunc(data interface{}) {
|
||||
if err := recover(); err != nil {
|
||||
d, _ := json.Marshal(data)
|
||||
ep.log.Errorfln("Panic in Matrix event handler: %v (event content: %s):\n%s", err, string(d), string(debug.Stack()))
|
||||
}
|
||||
}
|
||||
|
||||
func (ep *EventProcessor) callHandler(handler EventHandler, evt *event.Event) {
|
||||
defer ep.recoverFunc(evt)
|
||||
handler(evt)
|
||||
}
|
||||
|
||||
func (ep *EventProcessor) callOTKHandler(handler OTKHandler, otk *mautrix.OTKCount) {
|
||||
defer ep.recoverFunc(otk)
|
||||
handler(otk)
|
||||
}
|
||||
|
||||
func (ep *EventProcessor) callDeviceListHandler(handler DeviceListHandler, dl *mautrix.DeviceLists) {
|
||||
defer ep.recoverFunc(dl)
|
||||
handler(dl, "")
|
||||
}
|
||||
|
||||
func (ep *EventProcessor) DispatchOTK(otk *mautrix.OTKCount) {
|
||||
for _, handler := range ep.otkHandlers {
|
||||
go ep.callOTKHandler(handler, otk)
|
||||
}
|
||||
}
|
||||
|
||||
func (ep *EventProcessor) DispatchDeviceList(dl *mautrix.DeviceLists) {
|
||||
for _, handler := range ep.deviceListHandlers {
|
||||
go ep.callDeviceListHandler(handler, dl)
|
||||
}
|
||||
}
|
||||
|
||||
func (ep *EventProcessor) Dispatch(evt *event.Event) {
|
||||
handlers, ok := ep.handlers[evt.Type]
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
switch ep.ExecMode {
|
||||
case AsyncHandlers:
|
||||
for _, handler := range handlers {
|
||||
go ep.callHandler(handler, evt)
|
||||
}
|
||||
case AsyncLoop:
|
||||
go func() {
|
||||
for _, handler := range handlers {
|
||||
ep.callHandler(handler, evt)
|
||||
}
|
||||
}()
|
||||
case Sync:
|
||||
for _, handler := range handlers {
|
||||
ep.callHandler(handler, evt)
|
||||
}
|
||||
}
|
||||
}
|
||||
func (ep *EventProcessor) startEvents() {
|
||||
for {
|
||||
select {
|
||||
case evt := <-ep.as.Events:
|
||||
ep.Dispatch(evt)
|
||||
case <-ep.stop:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (ep *EventProcessor) startEncryption() {
|
||||
for {
|
||||
select {
|
||||
case evt := <-ep.as.ToDeviceEvents:
|
||||
ep.Dispatch(evt)
|
||||
case otk := <-ep.as.OTKCounts:
|
||||
ep.DispatchOTK(otk)
|
||||
case dl := <-ep.as.DeviceLists:
|
||||
ep.DispatchDeviceList(dl)
|
||||
case <-ep.stop:
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (ep *EventProcessor) Start() {
|
||||
go ep.startEvents()
|
||||
go ep.startEncryption()
|
||||
}
|
||||
|
||||
func (ep *EventProcessor) Stop() {
|
||||
close(ep.stop)
|
||||
}
|
||||
281
vendor/maunium.net/go/mautrix/appservice/http.go
generated
vendored
281
vendor/maunium.net/go/mautrix/appservice/http.go
generated
vendored
@@ -1,281 +0,0 @@
|
||||
// Copyright (c) 2022 Tulir Asokan
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||
// License, v. 2.0. If a copy of the MPL was not distributed with this
|
||||
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
|
||||
package appservice
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"io/ioutil"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
|
||||
"maunium.net/go/mautrix"
|
||||
"maunium.net/go/mautrix/event"
|
||||
"maunium.net/go/mautrix/id"
|
||||
)
|
||||
|
||||
// Start starts the HTTP server that listens for calls from the Matrix homeserver.
|
||||
func (as *AppService) Start() {
|
||||
as.Router.HandleFunc("/transactions/{txnID}", as.PutTransaction).Methods(http.MethodPut)
|
||||
as.Router.HandleFunc("/rooms/{roomAlias}", as.GetRoom).Methods(http.MethodGet)
|
||||
as.Router.HandleFunc("/users/{userID}", as.GetUser).Methods(http.MethodGet)
|
||||
as.Router.HandleFunc("/_matrix/app/v1/transactions/{txnID}", as.PutTransaction).Methods(http.MethodPut)
|
||||
as.Router.HandleFunc("/_matrix/app/v1/rooms/{roomAlias}", as.GetRoom).Methods(http.MethodGet)
|
||||
as.Router.HandleFunc("/_matrix/app/v1/users/{userID}", as.GetUser).Methods(http.MethodGet)
|
||||
as.Router.HandleFunc("/_matrix/mau/live", as.GetLive).Methods(http.MethodGet)
|
||||
as.Router.HandleFunc("/_matrix/mau/ready", as.GetReady).Methods(http.MethodGet)
|
||||
|
||||
var err error
|
||||
as.server = &http.Server{
|
||||
Addr: as.Host.Address(),
|
||||
Handler: as.Router,
|
||||
}
|
||||
as.Log.Infoln("Listening on", as.Host.Address())
|
||||
if len(as.Host.TLSCert) == 0 || len(as.Host.TLSKey) == 0 {
|
||||
err = as.server.ListenAndServe()
|
||||
} else {
|
||||
err = as.server.ListenAndServeTLS(as.Host.TLSCert, as.Host.TLSKey)
|
||||
}
|
||||
if err != nil && !errors.Is(err, http.ErrServerClosed) {
|
||||
as.Log.Fatalln("Error while listening:", err)
|
||||
} else {
|
||||
as.Log.Debugln("Listener stopped.")
|
||||
}
|
||||
}
|
||||
|
||||
func (as *AppService) Stop() {
|
||||
if as.server == nil {
|
||||
return
|
||||
}
|
||||
|
||||
ctx, cancel := context.WithTimeout(context.Background(), 5*time.Second)
|
||||
defer cancel()
|
||||
_ = as.server.Shutdown(ctx)
|
||||
as.server = nil
|
||||
}
|
||||
|
||||
// CheckServerToken checks if the given request originated from the Matrix homeserver.
|
||||
func (as *AppService) CheckServerToken(w http.ResponseWriter, r *http.Request) (isValid bool) {
|
||||
authHeader := r.Header.Get("Authorization")
|
||||
if len(authHeader) > 0 && strings.HasPrefix(authHeader, "Bearer ") {
|
||||
isValid = authHeader[len("Bearer "):] == as.Registration.ServerToken
|
||||
} else {
|
||||
queryToken := r.URL.Query().Get("access_token")
|
||||
if len(queryToken) > 0 {
|
||||
isValid = queryToken == as.Registration.ServerToken
|
||||
} else {
|
||||
Error{
|
||||
ErrorCode: ErrUnknownToken,
|
||||
HTTPStatus: http.StatusForbidden,
|
||||
Message: "Missing access token",
|
||||
}.Write(w)
|
||||
return
|
||||
}
|
||||
}
|
||||
if !isValid {
|
||||
Error{
|
||||
ErrorCode: ErrUnknownToken,
|
||||
HTTPStatus: http.StatusForbidden,
|
||||
Message: "Incorrect access token",
|
||||
}.Write(w)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// PutTransaction handles a /transactions PUT call from the homeserver.
|
||||
func (as *AppService) PutTransaction(w http.ResponseWriter, r *http.Request) {
|
||||
if !as.CheckServerToken(w, r) {
|
||||
return
|
||||
}
|
||||
|
||||
vars := mux.Vars(r)
|
||||
txnID := vars["txnID"]
|
||||
if len(txnID) == 0 {
|
||||
Error{
|
||||
ErrorCode: ErrNoTransactionID,
|
||||
HTTPStatus: http.StatusBadRequest,
|
||||
Message: "Missing transaction ID",
|
||||
}.Write(w)
|
||||
return
|
||||
}
|
||||
defer r.Body.Close()
|
||||
body, err := ioutil.ReadAll(r.Body)
|
||||
if err != nil || len(body) == 0 {
|
||||
Error{
|
||||
ErrorCode: ErrNotJSON,
|
||||
HTTPStatus: http.StatusBadRequest,
|
||||
Message: "Missing request body",
|
||||
}.Write(w)
|
||||
return
|
||||
}
|
||||
if as.txnIDC.IsProcessed(txnID) {
|
||||
// Duplicate transaction ID: no-op
|
||||
WriteBlankOK(w)
|
||||
as.Log.Debugfln("Ignoring duplicate transaction %s", txnID)
|
||||
return
|
||||
}
|
||||
|
||||
var txn Transaction
|
||||
err = json.Unmarshal(body, &txn)
|
||||
if err != nil {
|
||||
as.Log.Warnfln("Failed to parse JSON of transaction %s: %v", txnID, err)
|
||||
Error{
|
||||
ErrorCode: ErrBadJSON,
|
||||
HTTPStatus: http.StatusBadRequest,
|
||||
Message: "Failed to parse body JSON",
|
||||
}.Write(w)
|
||||
} else {
|
||||
as.handleTransaction(txnID, &txn)
|
||||
WriteBlankOK(w)
|
||||
}
|
||||
}
|
||||
|
||||
func (as *AppService) handleTransaction(id string, txn *Transaction) {
|
||||
as.Log.Debugfln("Starting handling of transaction %s (%s)", id, txn.ContentString())
|
||||
if as.Registration.EphemeralEvents {
|
||||
if txn.EphemeralEvents != nil {
|
||||
as.handleEvents(txn.EphemeralEvents, event.EphemeralEventType)
|
||||
} else if txn.MSC2409EphemeralEvents != nil {
|
||||
as.handleEvents(txn.MSC2409EphemeralEvents, event.EphemeralEventType)
|
||||
}
|
||||
if txn.ToDeviceEvents != nil {
|
||||
as.handleEvents(txn.ToDeviceEvents, event.ToDeviceEventType)
|
||||
} else if txn.MSC2409ToDeviceEvents != nil {
|
||||
as.handleEvents(txn.MSC2409ToDeviceEvents, event.ToDeviceEventType)
|
||||
}
|
||||
}
|
||||
as.handleEvents(txn.Events, event.UnknownEventType)
|
||||
if txn.DeviceLists != nil {
|
||||
as.handleDeviceLists(txn.DeviceLists)
|
||||
} else if txn.MSC3202DeviceLists != nil {
|
||||
as.handleDeviceLists(txn.MSC3202DeviceLists)
|
||||
}
|
||||
if txn.DeviceOTKCount != nil {
|
||||
as.handleOTKCounts(txn.DeviceOTKCount)
|
||||
} else if txn.MSC3202DeviceOTKCount != nil {
|
||||
as.handleOTKCounts(txn.MSC3202DeviceOTKCount)
|
||||
}
|
||||
as.txnIDC.MarkProcessed(id)
|
||||
}
|
||||
|
||||
func (as *AppService) handleOTKCounts(otks OTKCountMap) {
|
||||
for userID, devices := range otks {
|
||||
for deviceID, otkCounts := range devices {
|
||||
otkCounts.UserID = userID
|
||||
otkCounts.DeviceID = deviceID
|
||||
select {
|
||||
case as.OTKCounts <- &otkCounts:
|
||||
default:
|
||||
as.Log.Warnfln("Dropped OTK count update for %s because channel is full", userID)
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (as *AppService) handleDeviceLists(dl *mautrix.DeviceLists) {
|
||||
select {
|
||||
case as.DeviceLists <- dl:
|
||||
default:
|
||||
as.Log.Warnln("Dropped device list update because channel is full")
|
||||
}
|
||||
}
|
||||
|
||||
func (as *AppService) handleEvents(evts []*event.Event, defaultTypeClass event.TypeClass) {
|
||||
for _, evt := range evts {
|
||||
evt.Mautrix.ReceivedAt = time.Now()
|
||||
if defaultTypeClass != event.UnknownEventType {
|
||||
evt.Type.Class = defaultTypeClass
|
||||
} else if evt.StateKey != nil {
|
||||
evt.Type.Class = event.StateEventType
|
||||
} else {
|
||||
evt.Type.Class = event.MessageEventType
|
||||
}
|
||||
err := evt.Content.ParseRaw(evt.Type)
|
||||
if errors.Is(err, event.ErrUnsupportedContentType) {
|
||||
as.Log.Debugfln("Not parsing content of %s: %v", evt.ID, err)
|
||||
} else if err != nil {
|
||||
as.Log.Debugfln("Failed to parse content of %s (type %s): %v", evt.ID, evt.Type.Type, err)
|
||||
}
|
||||
|
||||
if evt.Type.IsState() {
|
||||
// TODO remove this check after making sure the log doesn't happen
|
||||
historical, ok := evt.Content.Raw["org.matrix.msc2716.historical"].(bool)
|
||||
if ok && historical {
|
||||
as.Log.Warnfln("Received historical state event %s (%s/%s)", evt.ID, evt.Type.Type, evt.GetStateKey())
|
||||
} else {
|
||||
as.UpdateState(evt)
|
||||
}
|
||||
}
|
||||
if evt.Type.Class == event.ToDeviceEventType {
|
||||
as.ToDeviceEvents <- evt
|
||||
} else {
|
||||
as.Events <- evt
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// GetRoom handles a /rooms GET call from the homeserver.
|
||||
func (as *AppService) GetRoom(w http.ResponseWriter, r *http.Request) {
|
||||
if !as.CheckServerToken(w, r) {
|
||||
return
|
||||
}
|
||||
|
||||
vars := mux.Vars(r)
|
||||
roomAlias := vars["roomAlias"]
|
||||
ok := as.QueryHandler.QueryAlias(roomAlias)
|
||||
if ok {
|
||||
WriteBlankOK(w)
|
||||
} else {
|
||||
Error{
|
||||
ErrorCode: ErrUnknown,
|
||||
HTTPStatus: http.StatusNotFound,
|
||||
}.Write(w)
|
||||
}
|
||||
}
|
||||
|
||||
// GetUser handles a /users GET call from the homeserver.
|
||||
func (as *AppService) GetUser(w http.ResponseWriter, r *http.Request) {
|
||||
if !as.CheckServerToken(w, r) {
|
||||
return
|
||||
}
|
||||
|
||||
vars := mux.Vars(r)
|
||||
userID := id.UserID(vars["userID"])
|
||||
ok := as.QueryHandler.QueryUser(userID)
|
||||
if ok {
|
||||
WriteBlankOK(w)
|
||||
} else {
|
||||
Error{
|
||||
ErrorCode: ErrUnknown,
|
||||
HTTPStatus: http.StatusNotFound,
|
||||
}.Write(w)
|
||||
}
|
||||
}
|
||||
|
||||
func (as *AppService) GetLive(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Add("Content-Type", "application/json")
|
||||
if as.Live {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
} else {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
}
|
||||
w.Write([]byte("{}"))
|
||||
}
|
||||
|
||||
func (as *AppService) GetReady(w http.ResponseWriter, r *http.Request) {
|
||||
w.Header().Add("Content-Type", "application/json")
|
||||
if as.Ready {
|
||||
w.WriteHeader(http.StatusOK)
|
||||
} else {
|
||||
w.WriteHeader(http.StatusInternalServerError)
|
||||
}
|
||||
w.Write([]byte("{}"))
|
||||
}
|
||||
537
vendor/maunium.net/go/mautrix/appservice/intent.go
generated
vendored
537
vendor/maunium.net/go/mautrix/appservice/intent.go
generated
vendored
@@ -1,537 +0,0 @@
|
||||
// Copyright (c) 2020 Tulir Asokan
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||
// License, v. 2.0. If a copy of the MPL was not distributed with this
|
||||
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
|
||||
package appservice
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"maunium.net/go/mautrix"
|
||||
"maunium.net/go/mautrix/event"
|
||||
"maunium.net/go/mautrix/id"
|
||||
)
|
||||
|
||||
type IntentAPI struct {
|
||||
*mautrix.Client
|
||||
bot *mautrix.Client
|
||||
as *AppService
|
||||
Localpart string
|
||||
UserID id.UserID
|
||||
|
||||
IsCustomPuppet bool
|
||||
}
|
||||
|
||||
func (as *AppService) NewIntentAPI(localpart string) *IntentAPI {
|
||||
userID := id.NewUserID(localpart, as.HomeserverDomain)
|
||||
bot := as.BotClient()
|
||||
if userID == bot.UserID {
|
||||
bot = nil
|
||||
}
|
||||
return &IntentAPI{
|
||||
Client: as.Client(userID),
|
||||
bot: bot,
|
||||
as: as,
|
||||
Localpart: localpart,
|
||||
UserID: userID,
|
||||
|
||||
IsCustomPuppet: false,
|
||||
}
|
||||
}
|
||||
|
||||
func (intent *IntentAPI) Register() error {
|
||||
_, _, err := intent.Client.Register(&mautrix.ReqRegister{
|
||||
Username: intent.Localpart,
|
||||
Type: mautrix.AuthTypeAppservice,
|
||||
InhibitLogin: true,
|
||||
})
|
||||
return err
|
||||
}
|
||||
|
||||
func (intent *IntentAPI) EnsureRegistered() error {
|
||||
if intent.IsCustomPuppet || intent.as.StateStore.IsRegistered(intent.UserID) {
|
||||
return nil
|
||||
}
|
||||
|
||||
err := intent.Register()
|
||||
if err != nil && !errors.Is(err, mautrix.MUserInUse) {
|
||||
return fmt.Errorf("failed to ensure registered: %w", err)
|
||||
}
|
||||
intent.as.StateStore.MarkRegistered(intent.UserID)
|
||||
return nil
|
||||
}
|
||||
|
||||
type EnsureJoinedParams struct {
|
||||
IgnoreCache bool
|
||||
BotOverride *mautrix.Client
|
||||
}
|
||||
|
||||
func (intent *IntentAPI) EnsureJoined(roomID id.RoomID, extra ...EnsureJoinedParams) error {
|
||||
var params EnsureJoinedParams
|
||||
if len(extra) > 1 {
|
||||
panic("invalid number of extra parameters")
|
||||
} else if len(extra) == 1 {
|
||||
params = extra[0]
|
||||
}
|
||||
if intent.as.StateStore.IsInRoom(roomID, intent.UserID) && !params.IgnoreCache {
|
||||
return nil
|
||||
}
|
||||
|
||||
if err := intent.EnsureRegistered(); err != nil {
|
||||
return fmt.Errorf("failed to ensure joined: %w", err)
|
||||
}
|
||||
|
||||
resp, err := intent.JoinRoomByID(roomID)
|
||||
if err != nil {
|
||||
bot := intent.bot
|
||||
if params.BotOverride != nil {
|
||||
bot = params.BotOverride
|
||||
}
|
||||
if !errors.Is(err, mautrix.MForbidden) || bot == nil {
|
||||
return fmt.Errorf("failed to ensure joined: %w", err)
|
||||
}
|
||||
_, inviteErr := bot.InviteUser(roomID, &mautrix.ReqInviteUser{
|
||||
UserID: intent.UserID,
|
||||
})
|
||||
if inviteErr != nil {
|
||||
return fmt.Errorf("failed to invite in ensure joined: %w", inviteErr)
|
||||
}
|
||||
resp, err = intent.JoinRoomByID(roomID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to ensure joined after invite: %w", err)
|
||||
}
|
||||
}
|
||||
intent.as.StateStore.SetMembership(resp.RoomID, intent.UserID, event.MembershipJoin)
|
||||
return nil
|
||||
}
|
||||
|
||||
func (intent *IntentAPI) AddDoublePuppetValue(into interface{}) interface{} {
|
||||
if !intent.IsCustomPuppet || intent.as.DoublePuppetValue == "" {
|
||||
return into
|
||||
}
|
||||
switch val := into.(type) {
|
||||
case *map[string]interface{}:
|
||||
if *val == nil {
|
||||
valNonPtr := make(map[string]interface{})
|
||||
*val = valNonPtr
|
||||
}
|
||||
(*val)[DoublePuppetKey] = intent.as.DoublePuppetValue
|
||||
return val
|
||||
case map[string]interface{}:
|
||||
val[DoublePuppetKey] = intent.as.DoublePuppetValue
|
||||
return val
|
||||
case *event.Content:
|
||||
if val.Raw == nil {
|
||||
val.Raw = make(map[string]interface{})
|
||||
}
|
||||
val.Raw[DoublePuppetKey] = intent.as.DoublePuppetValue
|
||||
return val
|
||||
case event.Content:
|
||||
if val.Raw == nil {
|
||||
val.Raw = make(map[string]interface{})
|
||||
}
|
||||
val.Raw[DoublePuppetKey] = intent.as.DoublePuppetValue
|
||||
return val
|
||||
default:
|
||||
return &event.Content{
|
||||
Raw: map[string]interface{}{
|
||||
DoublePuppetKey: intent.as.DoublePuppetValue,
|
||||
},
|
||||
Parsed: val,
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (intent *IntentAPI) SendMessageEvent(roomID id.RoomID, eventType event.Type, contentJSON interface{}) (*mautrix.RespSendEvent, error) {
|
||||
if err := intent.EnsureJoined(roomID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
contentJSON = intent.AddDoublePuppetValue(contentJSON)
|
||||
return intent.Client.SendMessageEvent(roomID, eventType, contentJSON)
|
||||
}
|
||||
|
||||
func (intent *IntentAPI) SendMassagedMessageEvent(roomID id.RoomID, eventType event.Type, contentJSON interface{}, ts int64) (*mautrix.RespSendEvent, error) {
|
||||
if err := intent.EnsureJoined(roomID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
contentJSON = intent.AddDoublePuppetValue(contentJSON)
|
||||
return intent.Client.SendMessageEvent(roomID, eventType, contentJSON, mautrix.ReqSendEvent{Timestamp: ts})
|
||||
}
|
||||
|
||||
func (intent *IntentAPI) updateStoreWithOutgoingEvent(roomID id.RoomID, eventType event.Type, stateKey string, contentJSON interface{}, eventID id.EventID) {
|
||||
fakeEvt := &event.Event{
|
||||
StateKey: &stateKey,
|
||||
Sender: intent.UserID,
|
||||
Type: eventType,
|
||||
ID: eventID,
|
||||
RoomID: roomID,
|
||||
Content: event.Content{},
|
||||
}
|
||||
var err error
|
||||
fakeEvt.Content.VeryRaw, err = json.Marshal(contentJSON)
|
||||
if err != nil {
|
||||
intent.Logger.Debugfln("Failed to marshal state event content to update state store: %v", err)
|
||||
return
|
||||
}
|
||||
err = json.Unmarshal(fakeEvt.Content.VeryRaw, &fakeEvt.Content.Raw)
|
||||
if err != nil {
|
||||
intent.Logger.Debugfln("Failed to unmarshal state event content to update state store: %v", err)
|
||||
return
|
||||
}
|
||||
err = fakeEvt.Content.ParseRaw(fakeEvt.Type)
|
||||
if err != nil {
|
||||
intent.Logger.Debugfln("Failed to parse state event content to update state store: %v", err)
|
||||
return
|
||||
}
|
||||
intent.as.UpdateState(fakeEvt)
|
||||
}
|
||||
|
||||
func (intent *IntentAPI) SendStateEvent(roomID id.RoomID, eventType event.Type, stateKey string, contentJSON interface{}) (*mautrix.RespSendEvent, error) {
|
||||
if eventType != event.StateMember || stateKey != string(intent.UserID) {
|
||||
if err := intent.EnsureJoined(roomID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
}
|
||||
contentJSON = intent.AddDoublePuppetValue(contentJSON)
|
||||
resp, err := intent.Client.SendStateEvent(roomID, eventType, stateKey, contentJSON)
|
||||
if err == nil && resp != nil {
|
||||
intent.updateStoreWithOutgoingEvent(roomID, eventType, stateKey, contentJSON, resp.EventID)
|
||||
}
|
||||
return resp, err
|
||||
}
|
||||
|
||||
func (intent *IntentAPI) SendMassagedStateEvent(roomID id.RoomID, eventType event.Type, stateKey string, contentJSON interface{}, ts int64) (*mautrix.RespSendEvent, error) {
|
||||
if err := intent.EnsureJoined(roomID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
contentJSON = intent.AddDoublePuppetValue(contentJSON)
|
||||
resp, err := intent.Client.SendMassagedStateEvent(roomID, eventType, stateKey, contentJSON, ts)
|
||||
if err == nil && resp != nil {
|
||||
intent.updateStoreWithOutgoingEvent(roomID, eventType, stateKey, contentJSON, resp.EventID)
|
||||
}
|
||||
return resp, err
|
||||
}
|
||||
|
||||
func (intent *IntentAPI) StateEvent(roomID id.RoomID, eventType event.Type, stateKey string, outContent interface{}) error {
|
||||
if err := intent.EnsureJoined(roomID); err != nil {
|
||||
return err
|
||||
}
|
||||
err := intent.Client.StateEvent(roomID, eventType, stateKey, outContent)
|
||||
if err == nil {
|
||||
intent.updateStoreWithOutgoingEvent(roomID, eventType, stateKey, outContent, "")
|
||||
}
|
||||
return err
|
||||
}
|
||||
|
||||
func (intent *IntentAPI) State(roomID id.RoomID) (mautrix.RoomStateMap, error) {
|
||||
if err := intent.EnsureJoined(roomID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
state, err := intent.Client.State(roomID)
|
||||
if err == nil {
|
||||
for _, events := range state {
|
||||
for _, evt := range events {
|
||||
intent.as.UpdateState(evt)
|
||||
}
|
||||
}
|
||||
}
|
||||
return state, err
|
||||
}
|
||||
|
||||
func (intent *IntentAPI) SendCustomMembershipEvent(roomID id.RoomID, target id.UserID, membership event.Membership, reason string, extraContent ...map[string]interface{}) (*mautrix.RespSendEvent, error) {
|
||||
content := &event.MemberEventContent{
|
||||
Membership: membership,
|
||||
Reason: reason,
|
||||
}
|
||||
memberContent, ok := intent.as.StateStore.TryGetMember(roomID, target)
|
||||
if !ok {
|
||||
if intent.as.GetProfile != nil {
|
||||
memberContent = intent.as.GetProfile(target, roomID)
|
||||
ok = memberContent != nil
|
||||
}
|
||||
if !ok {
|
||||
profile, err := intent.GetProfile(target)
|
||||
if err != nil {
|
||||
intent.Logger.Debugfln("Failed to get profile for %s to fill new %s membership event: %v", target, membership, err)
|
||||
} else {
|
||||
content.Displayname = profile.DisplayName
|
||||
content.AvatarURL = profile.AvatarURL.CUString()
|
||||
}
|
||||
}
|
||||
}
|
||||
if ok && memberContent != nil {
|
||||
content.Displayname = memberContent.Displayname
|
||||
content.AvatarURL = memberContent.AvatarURL
|
||||
}
|
||||
var extra map[string]interface{}
|
||||
if len(extraContent) > 0 {
|
||||
extra = extraContent[0]
|
||||
}
|
||||
return intent.SendStateEvent(roomID, event.StateMember, target.String(), &event.Content{
|
||||
Parsed: content,
|
||||
Raw: extra,
|
||||
})
|
||||
}
|
||||
|
||||
func (intent *IntentAPI) JoinRoomByID(roomID id.RoomID, extraContent ...map[string]interface{}) (resp *mautrix.RespJoinRoom, err error) {
|
||||
if intent.IsCustomPuppet || len(extraContent) > 0 {
|
||||
_, err = intent.SendCustomMembershipEvent(roomID, intent.UserID, event.MembershipJoin, "", extraContent...)
|
||||
return &mautrix.RespJoinRoom{}, err
|
||||
}
|
||||
resp, err = intent.Client.JoinRoomByID(roomID)
|
||||
if err == nil {
|
||||
intent.as.StateStore.SetMembership(roomID, intent.UserID, event.MembershipJoin)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (intent *IntentAPI) LeaveRoom(roomID id.RoomID, extra ...interface{}) (resp *mautrix.RespLeaveRoom, err error) {
|
||||
var extraContent map[string]interface{}
|
||||
leaveReq := &mautrix.ReqLeave{}
|
||||
for _, item := range extra {
|
||||
switch val := item.(type) {
|
||||
case map[string]interface{}:
|
||||
extraContent = val
|
||||
case *mautrix.ReqLeave:
|
||||
leaveReq = val
|
||||
}
|
||||
}
|
||||
if intent.IsCustomPuppet || extraContent != nil {
|
||||
_, err = intent.SendCustomMembershipEvent(roomID, intent.UserID, event.MembershipLeave, leaveReq.Reason, extraContent)
|
||||
return &mautrix.RespLeaveRoom{}, err
|
||||
}
|
||||
resp, err = intent.Client.LeaveRoom(roomID, leaveReq)
|
||||
if err == nil {
|
||||
intent.as.StateStore.SetMembership(roomID, intent.UserID, event.MembershipLeave)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (intent *IntentAPI) InviteUser(roomID id.RoomID, req *mautrix.ReqInviteUser, extraContent ...map[string]interface{}) (resp *mautrix.RespInviteUser, err error) {
|
||||
if intent.IsCustomPuppet || len(extraContent) > 0 {
|
||||
_, err = intent.SendCustomMembershipEvent(roomID, req.UserID, event.MembershipInvite, req.Reason, extraContent...)
|
||||
return &mautrix.RespInviteUser{}, err
|
||||
}
|
||||
resp, err = intent.Client.InviteUser(roomID, req)
|
||||
if err == nil {
|
||||
intent.as.StateStore.SetMembership(roomID, req.UserID, event.MembershipInvite)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (intent *IntentAPI) KickUser(roomID id.RoomID, req *mautrix.ReqKickUser, extraContent ...map[string]interface{}) (resp *mautrix.RespKickUser, err error) {
|
||||
if intent.IsCustomPuppet || len(extraContent) > 0 {
|
||||
_, err = intent.SendCustomMembershipEvent(roomID, req.UserID, event.MembershipLeave, req.Reason, extraContent...)
|
||||
return &mautrix.RespKickUser{}, err
|
||||
}
|
||||
resp, err = intent.Client.KickUser(roomID, req)
|
||||
if err == nil {
|
||||
intent.as.StateStore.SetMembership(roomID, req.UserID, event.MembershipLeave)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (intent *IntentAPI) BanUser(roomID id.RoomID, req *mautrix.ReqBanUser, extraContent ...map[string]interface{}) (resp *mautrix.RespBanUser, err error) {
|
||||
if intent.IsCustomPuppet || len(extraContent) > 0 {
|
||||
_, err = intent.SendCustomMembershipEvent(roomID, req.UserID, event.MembershipBan, req.Reason, extraContent...)
|
||||
return &mautrix.RespBanUser{}, err
|
||||
}
|
||||
resp, err = intent.Client.BanUser(roomID, req)
|
||||
if err == nil {
|
||||
intent.as.StateStore.SetMembership(roomID, req.UserID, event.MembershipBan)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (intent *IntentAPI) UnbanUser(roomID id.RoomID, req *mautrix.ReqUnbanUser, extraContent ...map[string]interface{}) (resp *mautrix.RespUnbanUser, err error) {
|
||||
if intent.IsCustomPuppet || len(extraContent) > 0 {
|
||||
_, err = intent.SendCustomMembershipEvent(roomID, req.UserID, event.MembershipLeave, req.Reason, extraContent...)
|
||||
return &mautrix.RespUnbanUser{}, err
|
||||
}
|
||||
resp, err = intent.Client.UnbanUser(roomID, req)
|
||||
if err == nil {
|
||||
intent.as.StateStore.SetMembership(roomID, req.UserID, event.MembershipLeave)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (intent *IntentAPI) Member(roomID id.RoomID, userID id.UserID) *event.MemberEventContent {
|
||||
member, ok := intent.as.StateStore.TryGetMember(roomID, userID)
|
||||
if !ok {
|
||||
_ = intent.StateEvent(roomID, event.StateMember, string(userID), &member)
|
||||
intent.as.StateStore.SetMember(roomID, userID, member)
|
||||
}
|
||||
return member
|
||||
}
|
||||
|
||||
func (intent *IntentAPI) PowerLevels(roomID id.RoomID) (pl *event.PowerLevelsEventContent, err error) {
|
||||
pl = intent.as.StateStore.GetPowerLevels(roomID)
|
||||
if pl == nil {
|
||||
pl = &event.PowerLevelsEventContent{}
|
||||
err = intent.StateEvent(roomID, event.StatePowerLevels, "", pl)
|
||||
if err == nil {
|
||||
intent.as.StateStore.SetPowerLevels(roomID, pl)
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (intent *IntentAPI) SetPowerLevels(roomID id.RoomID, levels *event.PowerLevelsEventContent) (resp *mautrix.RespSendEvent, err error) {
|
||||
resp, err = intent.SendStateEvent(roomID, event.StatePowerLevels, "", &levels)
|
||||
if err == nil {
|
||||
intent.as.StateStore.SetPowerLevels(roomID, levels)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (intent *IntentAPI) SetPowerLevel(roomID id.RoomID, userID id.UserID, level int) (*mautrix.RespSendEvent, error) {
|
||||
pl, err := intent.PowerLevels(roomID)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
if pl.GetUserLevel(userID) != level {
|
||||
pl.SetUserLevel(userID, level)
|
||||
return intent.SendStateEvent(roomID, event.StatePowerLevels, "", &pl)
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (intent *IntentAPI) SendText(roomID id.RoomID, text string) (*mautrix.RespSendEvent, error) {
|
||||
if err := intent.EnsureJoined(roomID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return intent.Client.SendText(roomID, text)
|
||||
}
|
||||
|
||||
// Deprecated: This does not allow setting image metadata, you should prefer SendMessageEvent with a properly filled &event.MessageEventContent
|
||||
func (intent *IntentAPI) SendImage(roomID id.RoomID, body string, url id.ContentURI) (*mautrix.RespSendEvent, error) {
|
||||
if err := intent.EnsureJoined(roomID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return intent.Client.SendImage(roomID, body, url)
|
||||
}
|
||||
|
||||
// Deprecated: This does not allow setting video metadata, you should prefer SendMessageEvent with a properly filled &event.MessageEventContent
|
||||
func (intent *IntentAPI) SendVideo(roomID id.RoomID, body string, url id.ContentURI) (*mautrix.RespSendEvent, error) {
|
||||
if err := intent.EnsureJoined(roomID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return intent.Client.SendVideo(roomID, body, url)
|
||||
}
|
||||
|
||||
func (intent *IntentAPI) SendNotice(roomID id.RoomID, text string) (*mautrix.RespSendEvent, error) {
|
||||
if err := intent.EnsureJoined(roomID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return intent.Client.SendNotice(roomID, text)
|
||||
}
|
||||
|
||||
func (intent *IntentAPI) RedactEvent(roomID id.RoomID, eventID id.EventID, extra ...mautrix.ReqRedact) (*mautrix.RespSendEvent, error) {
|
||||
if err := intent.EnsureJoined(roomID); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
var req mautrix.ReqRedact
|
||||
if len(extra) > 0 {
|
||||
req = extra[0]
|
||||
}
|
||||
intent.AddDoublePuppetValue(&req.Extra)
|
||||
return intent.Client.RedactEvent(roomID, eventID, req)
|
||||
}
|
||||
|
||||
func (intent *IntentAPI) SetRoomName(roomID id.RoomID, roomName string) (*mautrix.RespSendEvent, error) {
|
||||
return intent.SendStateEvent(roomID, event.StateRoomName, "", map[string]interface{}{
|
||||
"name": roomName,
|
||||
})
|
||||
}
|
||||
|
||||
func (intent *IntentAPI) SetRoomAvatar(roomID id.RoomID, avatarURL id.ContentURI) (*mautrix.RespSendEvent, error) {
|
||||
return intent.SendStateEvent(roomID, event.StateRoomAvatar, "", map[string]interface{}{
|
||||
"url": avatarURL.String(),
|
||||
})
|
||||
}
|
||||
|
||||
func (intent *IntentAPI) SetRoomTopic(roomID id.RoomID, topic string) (*mautrix.RespSendEvent, error) {
|
||||
return intent.SendStateEvent(roomID, event.StateTopic, "", map[string]interface{}{
|
||||
"topic": topic,
|
||||
})
|
||||
}
|
||||
|
||||
func (intent *IntentAPI) SetDisplayName(displayName string) error {
|
||||
if err := intent.EnsureRegistered(); err != nil {
|
||||
return err
|
||||
}
|
||||
resp, err := intent.Client.GetOwnDisplayName()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to check current displayname: %w", err)
|
||||
} else if resp.DisplayName == displayName {
|
||||
// No need to update
|
||||
return nil
|
||||
}
|
||||
return intent.Client.SetDisplayName(displayName)
|
||||
}
|
||||
|
||||
func (intent *IntentAPI) SetAvatarURL(avatarURL id.ContentURI) error {
|
||||
if err := intent.EnsureRegistered(); err != nil {
|
||||
return err
|
||||
}
|
||||
resp, err := intent.Client.GetOwnAvatarURL()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to check current avatar URL: %w", err)
|
||||
} else if resp.FileID == avatarURL.FileID && resp.Homeserver == avatarURL.Homeserver {
|
||||
// No need to update
|
||||
return nil
|
||||
}
|
||||
return intent.Client.SetAvatarURL(avatarURL)
|
||||
}
|
||||
|
||||
func (intent *IntentAPI) Whoami() (*mautrix.RespWhoami, error) {
|
||||
if err := intent.EnsureRegistered(); err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return intent.Client.Whoami()
|
||||
}
|
||||
|
||||
func (intent *IntentAPI) JoinedMembers(roomID id.RoomID) (resp *mautrix.RespJoinedMembers, err error) {
|
||||
resp, err = intent.Client.JoinedMembers(roomID)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
for userID, member := range resp.Joined {
|
||||
intent.as.StateStore.SetMember(roomID, userID, &event.MemberEventContent{
|
||||
Membership: event.MembershipJoin,
|
||||
AvatarURL: id.ContentURIString(member.AvatarURL),
|
||||
Displayname: member.DisplayName,
|
||||
})
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (intent *IntentAPI) Members(roomID id.RoomID, req ...mautrix.ReqMembers) (resp *mautrix.RespMembers, err error) {
|
||||
resp, err = intent.Client.Members(roomID, req...)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
for _, evt := range resp.Chunk {
|
||||
intent.as.UpdateState(evt)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (intent *IntentAPI) EnsureInvited(roomID id.RoomID, userID id.UserID) error {
|
||||
if !intent.as.StateStore.IsInvited(roomID, userID) {
|
||||
_, err := intent.InviteUser(roomID, &mautrix.ReqInviteUser{
|
||||
UserID: userID,
|
||||
})
|
||||
if httpErr, ok := err.(mautrix.HTTPError); ok && httpErr.RespError != nil && strings.Contains(httpErr.RespError.Err, "is already in the room") {
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
return nil
|
||||
}
|
||||
138
vendor/maunium.net/go/mautrix/appservice/protocol.go
generated
vendored
138
vendor/maunium.net/go/mautrix/appservice/protocol.go
generated
vendored
@@ -1,138 +0,0 @@
|
||||
// Copyright (c) 2022 Tulir Asokan
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||
// License, v. 2.0. If a copy of the MPL was not distributed with this
|
||||
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
|
||||
package appservice
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"strings"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
|
||||
"maunium.net/go/mautrix"
|
||||
"maunium.net/go/mautrix/event"
|
||||
"maunium.net/go/mautrix/id"
|
||||
)
|
||||
|
||||
type OTKCountMap = map[id.UserID]map[id.DeviceID]mautrix.OTKCount
|
||||
type FallbackKeyMap = map[id.UserID]map[id.DeviceID][]id.KeyAlgorithm
|
||||
|
||||
// Transaction contains a list of events.
|
||||
type Transaction struct {
|
||||
Events []*event.Event `json:"events"`
|
||||
EphemeralEvents []*event.Event `json:"ephemeral,omitempty"`
|
||||
ToDeviceEvents []*event.Event `json:"to_device,omitempty"`
|
||||
|
||||
DeviceLists *mautrix.DeviceLists `json:"device_lists,omitempty"`
|
||||
DeviceOTKCount OTKCountMap `json:"device_one_time_keys_count,omitempty"`
|
||||
FallbackKeys FallbackKeyMap `json:"device_unused_fallback_key_types,omitempty"`
|
||||
|
||||
MSC2409EphemeralEvents []*event.Event `json:"de.sorunome.msc2409.ephemeral,omitempty"`
|
||||
MSC2409ToDeviceEvents []*event.Event `json:"de.sorunome.msc2409.to_device,omitempty"`
|
||||
MSC3202DeviceLists *mautrix.DeviceLists `json:"org.matrix.msc3202.device_lists,omitempty"`
|
||||
MSC3202DeviceOTKCount OTKCountMap `json:"org.matrix.msc3202.device_one_time_keys_count,omitempty"`
|
||||
MSC3202FallbackKeys FallbackKeyMap `json:"org.matrix.msc3202.device_unused_fallback_key_types,omitempty"`
|
||||
}
|
||||
|
||||
func (txn *Transaction) MarshalZerologObject(ctx *zerolog.Event) {
|
||||
ctx.Int("pdu", len(txn.Events))
|
||||
ctx.Int("edu", len(txn.EphemeralEvents))
|
||||
ctx.Int("to_device", len(txn.ToDeviceEvents))
|
||||
if len(txn.DeviceOTKCount) > 0 {
|
||||
ctx.Int("otk_count_users", len(txn.DeviceOTKCount))
|
||||
}
|
||||
if txn.DeviceLists != nil {
|
||||
ctx.Int("device_changes", len(txn.DeviceLists.Changed))
|
||||
}
|
||||
if txn.FallbackKeys != nil {
|
||||
ctx.Int("fallback_key_users", len(txn.FallbackKeys))
|
||||
}
|
||||
}
|
||||
|
||||
func (txn *Transaction) ContentString() string {
|
||||
var parts []string
|
||||
if len(txn.Events) > 0 {
|
||||
parts = append(parts, fmt.Sprintf("%d PDUs", len(txn.Events)))
|
||||
}
|
||||
if len(txn.EphemeralEvents) > 0 {
|
||||
parts = append(parts, fmt.Sprintf("%d EDUs", len(txn.EphemeralEvents)))
|
||||
} else if len(txn.MSC2409EphemeralEvents) > 0 {
|
||||
parts = append(parts, fmt.Sprintf("%d EDUs (unstable)", len(txn.MSC2409EphemeralEvents)))
|
||||
}
|
||||
if len(txn.ToDeviceEvents) > 0 {
|
||||
parts = append(parts, fmt.Sprintf("%d to-device events", len(txn.ToDeviceEvents)))
|
||||
} else if len(txn.MSC2409ToDeviceEvents) > 0 {
|
||||
parts = append(parts, fmt.Sprintf("%d to-device events (unstable)", len(txn.MSC2409ToDeviceEvents)))
|
||||
}
|
||||
if len(txn.DeviceOTKCount) > 0 {
|
||||
parts = append(parts, fmt.Sprintf("OTK counts for %d users", len(txn.DeviceOTKCount)))
|
||||
} else if len(txn.MSC3202DeviceOTKCount) > 0 {
|
||||
parts = append(parts, fmt.Sprintf("OTK counts for %d users (unstable)", len(txn.MSC3202DeviceOTKCount)))
|
||||
}
|
||||
if txn.DeviceLists != nil {
|
||||
parts = append(parts, fmt.Sprintf("%d device list changes", len(txn.DeviceLists.Changed)))
|
||||
} else if txn.MSC3202DeviceLists != nil {
|
||||
parts = append(parts, fmt.Sprintf("%d device list changes (unstable)", len(txn.MSC3202DeviceLists.Changed)))
|
||||
}
|
||||
if txn.FallbackKeys != nil {
|
||||
parts = append(parts, fmt.Sprintf("unused fallback key counts for %d users", len(txn.FallbackKeys)))
|
||||
} else if txn.MSC3202FallbackKeys != nil {
|
||||
parts = append(parts, fmt.Sprintf("unused fallback key counts for %d users (unstable)", len(txn.MSC3202FallbackKeys)))
|
||||
}
|
||||
return strings.Join(parts, ", ")
|
||||
}
|
||||
|
||||
// EventListener is a function that receives events.
|
||||
type EventListener func(evt *event.Event)
|
||||
|
||||
// WriteBlankOK writes a blank OK message as a reply to a HTTP request.
|
||||
func WriteBlankOK(w http.ResponseWriter) {
|
||||
w.Header().Add("Content-Type", "application/json")
|
||||
w.WriteHeader(http.StatusOK)
|
||||
_, _ = w.Write([]byte("{}"))
|
||||
}
|
||||
|
||||
// Respond responds to a HTTP request with a JSON object.
|
||||
func Respond(w http.ResponseWriter, data interface{}) error {
|
||||
w.Header().Add("Content-Type", "application/json")
|
||||
dataStr, err := json.Marshal(data)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = w.Write(dataStr)
|
||||
return err
|
||||
}
|
||||
|
||||
// Error represents a Matrix protocol error.
|
||||
type Error struct {
|
||||
HTTPStatus int `json:"-"`
|
||||
ErrorCode ErrorCode `json:"errcode"`
|
||||
Message string `json:"error"`
|
||||
}
|
||||
|
||||
func (err Error) Write(w http.ResponseWriter) {
|
||||
w.Header().Add("Content-Type", "application/json")
|
||||
w.WriteHeader(err.HTTPStatus)
|
||||
_ = Respond(w, &err)
|
||||
}
|
||||
|
||||
// ErrorCode is the machine-readable code in an Error.
|
||||
type ErrorCode string
|
||||
|
||||
// Native ErrorCodes
|
||||
const (
|
||||
ErrUnknownToken ErrorCode = "M_UNKNOWN_TOKEN"
|
||||
ErrBadJSON ErrorCode = "M_BAD_JSON"
|
||||
ErrNotJSON ErrorCode = "M_NOT_JSON"
|
||||
ErrUnknown ErrorCode = "M_UNKNOWN"
|
||||
)
|
||||
|
||||
// Custom ErrorCodes
|
||||
const (
|
||||
ErrNoTransactionID ErrorCode = "NET.MAUNIUM.NO_TRANSACTION_ID"
|
||||
)
|
||||
100
vendor/maunium.net/go/mautrix/appservice/registration.go
generated
vendored
100
vendor/maunium.net/go/mautrix/appservice/registration.go
generated
vendored
@@ -1,100 +0,0 @@
|
||||
// Copyright (c) 2022 Tulir Asokan
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||
// License, v. 2.0. If a copy of the MPL was not distributed with this
|
||||
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
|
||||
package appservice
|
||||
|
||||
import (
|
||||
"io/ioutil"
|
||||
"regexp"
|
||||
|
||||
"gopkg.in/yaml.v3"
|
||||
|
||||
"maunium.net/go/mautrix/util"
|
||||
)
|
||||
|
||||
// Registration contains the data in a Matrix appservice registration.
|
||||
// See https://spec.matrix.org/v1.2/application-service-api/#registration
|
||||
type Registration struct {
|
||||
ID string `yaml:"id" json:"id"`
|
||||
URL string `yaml:"url" json:"url"`
|
||||
AppToken string `yaml:"as_token" json:"as_token"`
|
||||
ServerToken string `yaml:"hs_token" json:"hs_token"`
|
||||
SenderLocalpart string `yaml:"sender_localpart" json:"sender_localpart"`
|
||||
RateLimited *bool `yaml:"rate_limited,omitempty" json:"rate_limited,omitempty"`
|
||||
Namespaces Namespaces `yaml:"namespaces" json:"namespaces"`
|
||||
Protocols []string `yaml:"protocols,omitempty" json:"protocols,omitempty"`
|
||||
|
||||
SoruEphemeralEvents bool `yaml:"de.sorunome.msc2409.push_ephemeral,omitempty" json:"de.sorunome.msc2409.push_ephemeral,omitempty"`
|
||||
EphemeralEvents bool `yaml:"push_ephemeral,omitempty" json:"push_ephemeral,omitempty"`
|
||||
}
|
||||
|
||||
// CreateRegistration creates a Registration with random appservice and homeserver tokens.
|
||||
func CreateRegistration() *Registration {
|
||||
return &Registration{
|
||||
AppToken: util.RandomString(64),
|
||||
ServerToken: util.RandomString(64),
|
||||
}
|
||||
}
|
||||
|
||||
// LoadRegistration loads a YAML file and turns it into a Registration.
|
||||
func LoadRegistration(path string) (*Registration, error) {
|
||||
data, err := ioutil.ReadFile(path)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
reg := &Registration{}
|
||||
err = yaml.Unmarshal(data, reg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return reg, nil
|
||||
}
|
||||
|
||||
// Save saves this Registration into a file at the given path.
|
||||
func (reg *Registration) Save(path string) error {
|
||||
data, err := yaml.Marshal(reg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
return ioutil.WriteFile(path, data, 0600)
|
||||
}
|
||||
|
||||
// YAML returns the registration in YAML format.
|
||||
func (reg *Registration) YAML() (string, error) {
|
||||
data, err := yaml.Marshal(reg)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
return string(data), nil
|
||||
}
|
||||
|
||||
// Namespaces contains the three areas that appservices can reserve parts of.
|
||||
type Namespaces struct {
|
||||
UserIDs NamespaceList `yaml:"users,omitempty" json:"users,omitempty"`
|
||||
RoomAliases NamespaceList `yaml:"aliases,omitempty" json:"aliases,omitempty"`
|
||||
RoomIDs NamespaceList `yaml:"rooms,omitempty" json:"rooms,omitempty"`
|
||||
}
|
||||
|
||||
// Namespace is a reserved namespace in any area.
|
||||
type Namespace struct {
|
||||
Regex string `yaml:"regex" json:"regex"`
|
||||
Exclusive bool `yaml:"exclusive" json:"exclusive"`
|
||||
}
|
||||
|
||||
type NamespaceList []Namespace
|
||||
|
||||
func (nsl *NamespaceList) Register(regex *regexp.Regexp, exclusive bool) {
|
||||
ns := Namespace{
|
||||
Regex: regex.String(),
|
||||
Exclusive: exclusive,
|
||||
}
|
||||
if nsl == nil {
|
||||
*nsl = []Namespace{ns}
|
||||
} else {
|
||||
*nsl = append(*nsl, ns)
|
||||
}
|
||||
}
|
||||
186
vendor/maunium.net/go/mautrix/appservice/statestore.go
generated
vendored
186
vendor/maunium.net/go/mautrix/appservice/statestore.go
generated
vendored
@@ -1,186 +0,0 @@
|
||||
// Copyright (c) 2020 Tulir Asokan
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||
// License, v. 2.0. If a copy of the MPL was not distributed with this
|
||||
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
|
||||
package appservice
|
||||
|
||||
import (
|
||||
"sync"
|
||||
|
||||
"maunium.net/go/mautrix/event"
|
||||
"maunium.net/go/mautrix/id"
|
||||
)
|
||||
|
||||
type StateStore interface {
|
||||
IsRegistered(userID id.UserID) bool
|
||||
MarkRegistered(userID id.UserID)
|
||||
|
||||
IsInRoom(roomID id.RoomID, userID id.UserID) bool
|
||||
IsInvited(roomID id.RoomID, userID id.UserID) bool
|
||||
IsMembership(roomID id.RoomID, userID id.UserID, allowedMemberships ...event.Membership) bool
|
||||
GetMember(roomID id.RoomID, userID id.UserID) *event.MemberEventContent
|
||||
TryGetMember(roomID id.RoomID, userID id.UserID) (*event.MemberEventContent, bool)
|
||||
SetMembership(roomID id.RoomID, userID id.UserID, membership event.Membership)
|
||||
SetMember(roomID id.RoomID, userID id.UserID, member *event.MemberEventContent)
|
||||
|
||||
SetPowerLevels(roomID id.RoomID, levels *event.PowerLevelsEventContent)
|
||||
GetPowerLevels(roomID id.RoomID) *event.PowerLevelsEventContent
|
||||
GetPowerLevel(roomID id.RoomID, userID id.UserID) int
|
||||
GetPowerLevelRequirement(roomID id.RoomID, eventType event.Type) int
|
||||
HasPowerLevel(roomID id.RoomID, userID id.UserID, eventType event.Type) bool
|
||||
}
|
||||
|
||||
func (as *AppService) UpdateState(evt *event.Event) {
|
||||
switch content := evt.Content.Parsed.(type) {
|
||||
case *event.MemberEventContent:
|
||||
as.StateStore.SetMember(evt.RoomID, id.UserID(evt.GetStateKey()), content)
|
||||
case *event.PowerLevelsEventContent:
|
||||
as.StateStore.SetPowerLevels(evt.RoomID, content)
|
||||
}
|
||||
}
|
||||
|
||||
type BasicStateStore struct {
|
||||
Registrations map[id.UserID]bool `json:"registrations"`
|
||||
Members map[id.RoomID]map[id.UserID]*event.MemberEventContent `json:"memberships"`
|
||||
PowerLevels map[id.RoomID]*event.PowerLevelsEventContent `json:"power_levels"`
|
||||
|
||||
registrationsLock sync.RWMutex
|
||||
membersLock sync.RWMutex
|
||||
powerLevelsLock sync.RWMutex
|
||||
}
|
||||
|
||||
func NewBasicStateStore() StateStore {
|
||||
return &BasicStateStore{
|
||||
Registrations: make(map[id.UserID]bool),
|
||||
Members: make(map[id.RoomID]map[id.UserID]*event.MemberEventContent),
|
||||
PowerLevels: make(map[id.RoomID]*event.PowerLevelsEventContent),
|
||||
}
|
||||
}
|
||||
|
||||
func (store *BasicStateStore) IsRegistered(userID id.UserID) bool {
|
||||
store.registrationsLock.RLock()
|
||||
defer store.registrationsLock.RUnlock()
|
||||
registered, ok := store.Registrations[userID]
|
||||
return ok && registered
|
||||
}
|
||||
|
||||
func (store *BasicStateStore) MarkRegistered(userID id.UserID) {
|
||||
store.registrationsLock.Lock()
|
||||
defer store.registrationsLock.Unlock()
|
||||
store.Registrations[userID] = true
|
||||
}
|
||||
|
||||
func (store *BasicStateStore) GetRoomMembers(roomID id.RoomID) map[id.UserID]*event.MemberEventContent {
|
||||
store.membersLock.RLock()
|
||||
members, ok := store.Members[roomID]
|
||||
store.membersLock.RUnlock()
|
||||
if !ok {
|
||||
members = make(map[id.UserID]*event.MemberEventContent)
|
||||
store.membersLock.Lock()
|
||||
store.Members[roomID] = members
|
||||
store.membersLock.Unlock()
|
||||
}
|
||||
return members
|
||||
}
|
||||
|
||||
func (store *BasicStateStore) GetMembership(roomID id.RoomID, userID id.UserID) event.Membership {
|
||||
return store.GetMember(roomID, userID).Membership
|
||||
}
|
||||
|
||||
func (store *BasicStateStore) GetMember(roomID id.RoomID, userID id.UserID) *event.MemberEventContent {
|
||||
member, ok := store.TryGetMember(roomID, userID)
|
||||
if !ok {
|
||||
member = &event.MemberEventContent{Membership: event.MembershipLeave}
|
||||
}
|
||||
return member
|
||||
}
|
||||
|
||||
func (store *BasicStateStore) TryGetMember(roomID id.RoomID, userID id.UserID) (member *event.MemberEventContent, ok bool) {
|
||||
store.membersLock.RLock()
|
||||
defer store.membersLock.RUnlock()
|
||||
members, membersOk := store.Members[roomID]
|
||||
if !membersOk {
|
||||
return
|
||||
}
|
||||
member, ok = members[userID]
|
||||
return
|
||||
}
|
||||
|
||||
func (store *BasicStateStore) IsInRoom(roomID id.RoomID, userID id.UserID) bool {
|
||||
return store.IsMembership(roomID, userID, "join")
|
||||
}
|
||||
|
||||
func (store *BasicStateStore) IsInvited(roomID id.RoomID, userID id.UserID) bool {
|
||||
return store.IsMembership(roomID, userID, "join", "invite")
|
||||
}
|
||||
|
||||
func (store *BasicStateStore) IsMembership(roomID id.RoomID, userID id.UserID, allowedMemberships ...event.Membership) bool {
|
||||
membership := store.GetMembership(roomID, userID)
|
||||
for _, allowedMembership := range allowedMemberships {
|
||||
if allowedMembership == membership {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (store *BasicStateStore) SetMembership(roomID id.RoomID, userID id.UserID, membership event.Membership) {
|
||||
store.membersLock.Lock()
|
||||
members, ok := store.Members[roomID]
|
||||
if !ok {
|
||||
members = map[id.UserID]*event.MemberEventContent{
|
||||
userID: {Membership: membership},
|
||||
}
|
||||
} else {
|
||||
member, ok := members[userID]
|
||||
if !ok {
|
||||
members[userID] = &event.MemberEventContent{Membership: membership}
|
||||
} else {
|
||||
member.Membership = membership
|
||||
members[userID] = member
|
||||
}
|
||||
}
|
||||
store.Members[roomID] = members
|
||||
store.membersLock.Unlock()
|
||||
}
|
||||
|
||||
func (store *BasicStateStore) SetMember(roomID id.RoomID, userID id.UserID, member *event.MemberEventContent) {
|
||||
store.membersLock.Lock()
|
||||
members, ok := store.Members[roomID]
|
||||
if !ok {
|
||||
members = map[id.UserID]*event.MemberEventContent{
|
||||
userID: member,
|
||||
}
|
||||
} else {
|
||||
members[userID] = member
|
||||
}
|
||||
store.Members[roomID] = members
|
||||
store.membersLock.Unlock()
|
||||
}
|
||||
|
||||
func (store *BasicStateStore) SetPowerLevels(roomID id.RoomID, levels *event.PowerLevelsEventContent) {
|
||||
store.powerLevelsLock.Lock()
|
||||
store.PowerLevels[roomID] = levels
|
||||
store.powerLevelsLock.Unlock()
|
||||
}
|
||||
|
||||
func (store *BasicStateStore) GetPowerLevels(roomID id.RoomID) (levels *event.PowerLevelsEventContent) {
|
||||
store.powerLevelsLock.RLock()
|
||||
levels = store.PowerLevels[roomID]
|
||||
store.powerLevelsLock.RUnlock()
|
||||
return
|
||||
}
|
||||
|
||||
func (store *BasicStateStore) GetPowerLevel(roomID id.RoomID, userID id.UserID) int {
|
||||
return store.GetPowerLevels(roomID).GetUserLevel(userID)
|
||||
}
|
||||
|
||||
func (store *BasicStateStore) GetPowerLevelRequirement(roomID id.RoomID, eventType event.Type) int {
|
||||
return store.GetPowerLevels(roomID).GetEventLevel(eventType)
|
||||
}
|
||||
|
||||
func (store *BasicStateStore) HasPowerLevel(roomID id.RoomID, userID id.UserID, eventType event.Type) bool {
|
||||
return store.GetPowerLevel(roomID, userID) >= store.GetPowerLevelRequirement(roomID, eventType)
|
||||
}
|
||||
43
vendor/maunium.net/go/mautrix/appservice/txnid.go
generated
vendored
43
vendor/maunium.net/go/mautrix/appservice/txnid.go
generated
vendored
@@ -1,43 +0,0 @@
|
||||
// Copyright (c) 2021 Tulir Asokan
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||
// License, v. 2.0. If a copy of the MPL was not distributed with this
|
||||
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
|
||||
package appservice
|
||||
|
||||
import "sync"
|
||||
|
||||
type TransactionIDCache struct {
|
||||
array []string
|
||||
arrayPtr int
|
||||
hash map[string]struct{}
|
||||
lock sync.RWMutex
|
||||
}
|
||||
|
||||
func NewTransactionIDCache(size int) *TransactionIDCache {
|
||||
return &TransactionIDCache{
|
||||
array: make([]string, size),
|
||||
hash: make(map[string]struct{}),
|
||||
}
|
||||
}
|
||||
|
||||
func (txnIDC *TransactionIDCache) IsProcessed(txnID string) bool {
|
||||
txnIDC.lock.RLock()
|
||||
_, exists := txnIDC.hash[txnID]
|
||||
txnIDC.lock.RUnlock()
|
||||
return exists
|
||||
}
|
||||
|
||||
func (txnIDC *TransactionIDCache) MarkProcessed(txnID string) {
|
||||
txnIDC.lock.Lock()
|
||||
txnIDC.hash[txnID] = struct{}{}
|
||||
if txnIDC.array[txnIDC.arrayPtr] != "" {
|
||||
for i := 0; i < len(txnIDC.array)/8; i++ {
|
||||
delete(txnIDC.hash, txnIDC.array[txnIDC.arrayPtr+i])
|
||||
txnIDC.array[txnIDC.arrayPtr+i] = ""
|
||||
}
|
||||
}
|
||||
txnIDC.array[txnIDC.arrayPtr] = txnID
|
||||
txnIDC.lock.Unlock()
|
||||
}
|
||||
392
vendor/maunium.net/go/mautrix/appservice/websocket.go
generated
vendored
392
vendor/maunium.net/go/mautrix/appservice/websocket.go
generated
vendored
@@ -1,392 +0,0 @@
|
||||
// Copyright (c) 2022 Tulir Asokan
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||
// License, v. 2.0. If a copy of the MPL was not distributed with this
|
||||
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
|
||||
package appservice
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
"net/url"
|
||||
"path/filepath"
|
||||
"strings"
|
||||
"sync"
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/tidwall/gjson"
|
||||
"github.com/tidwall/sjson"
|
||||
)
|
||||
|
||||
type WebsocketRequest struct {
|
||||
ReqID int `json:"id,omitempty"`
|
||||
Command string `json:"command"`
|
||||
Data interface{} `json:"data"`
|
||||
|
||||
Deadline time.Duration `json:"-"`
|
||||
}
|
||||
|
||||
type WebsocketCommand struct {
|
||||
ReqID int `json:"id,omitempty"`
|
||||
Command string `json:"command"`
|
||||
Data json.RawMessage `json:"data"`
|
||||
}
|
||||
|
||||
func (wsc *WebsocketCommand) MakeResponse(ok bool, data interface{}) *WebsocketRequest {
|
||||
if wsc.ReqID == 0 || wsc.Command == "response" || wsc.Command == "error" {
|
||||
return nil
|
||||
}
|
||||
cmd := "response"
|
||||
if !ok {
|
||||
cmd = "error"
|
||||
}
|
||||
if err, isError := data.(error); isError {
|
||||
var errorData json.RawMessage
|
||||
var jsonErr error
|
||||
unwrappedErr := err
|
||||
var prefixMessage string
|
||||
for unwrappedErr != nil {
|
||||
errorData, jsonErr = json.Marshal(unwrappedErr)
|
||||
if errorData != nil && len(errorData) > 2 && jsonErr == nil {
|
||||
prefixMessage = strings.Replace(err.Error(), unwrappedErr.Error(), "", 1)
|
||||
prefixMessage = strings.TrimRight(prefixMessage, ": ")
|
||||
break
|
||||
}
|
||||
unwrappedErr = errors.Unwrap(unwrappedErr)
|
||||
}
|
||||
if errorData != nil {
|
||||
if !gjson.GetBytes(errorData, "message").Exists() {
|
||||
errorData, _ = sjson.SetBytes(errorData, "message", err.Error())
|
||||
} // else: marshaled error contains a message already
|
||||
} else {
|
||||
errorData, _ = sjson.SetBytes(nil, "message", err.Error())
|
||||
}
|
||||
if len(prefixMessage) > 0 {
|
||||
errorData, _ = sjson.SetBytes(errorData, "prefix_message", prefixMessage)
|
||||
}
|
||||
data = errorData
|
||||
}
|
||||
return &WebsocketRequest{
|
||||
ReqID: wsc.ReqID,
|
||||
Command: cmd,
|
||||
Data: data,
|
||||
}
|
||||
}
|
||||
|
||||
type WebsocketTransaction struct {
|
||||
Status string `json:"status"`
|
||||
TxnID string `json:"txn_id"`
|
||||
Transaction
|
||||
}
|
||||
|
||||
type WebsocketTransactionResponse struct {
|
||||
TxnID string `json:"txn_id"`
|
||||
}
|
||||
|
||||
type WebsocketMessage struct {
|
||||
WebsocketTransaction
|
||||
WebsocketCommand
|
||||
}
|
||||
|
||||
const (
|
||||
WebsocketCloseConnReplaced = 4001
|
||||
WebsocketCloseTxnNotAcknowledged = 4002
|
||||
)
|
||||
|
||||
type MeowWebsocketCloseCode string
|
||||
|
||||
const (
|
||||
MeowServerShuttingDown MeowWebsocketCloseCode = "server_shutting_down"
|
||||
MeowConnectionReplaced MeowWebsocketCloseCode = "conn_replaced"
|
||||
MeowTxnNotAcknowledged MeowWebsocketCloseCode = "transactions_not_acknowledged"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrWebsocketManualStop = errors.New("the websocket was disconnected manually")
|
||||
ErrWebsocketOverridden = errors.New("a new call to StartWebsocket overrode the previous connection")
|
||||
ErrWebsocketUnknownError = errors.New("an unknown error occurred")
|
||||
|
||||
ErrWebsocketNotConnected = errors.New("websocket not connected")
|
||||
ErrWebsocketClosed = errors.New("websocket closed before response received")
|
||||
)
|
||||
|
||||
func (mwcc MeowWebsocketCloseCode) String() string {
|
||||
switch mwcc {
|
||||
case MeowServerShuttingDown:
|
||||
return "the server is shutting down"
|
||||
case MeowConnectionReplaced:
|
||||
return "the connection was replaced by another client"
|
||||
case MeowTxnNotAcknowledged:
|
||||
return "transactions were not acknowledged"
|
||||
default:
|
||||
return string(mwcc)
|
||||
}
|
||||
}
|
||||
|
||||
type CloseCommand struct {
|
||||
Code int `json:"-"`
|
||||
Command string `json:"command"`
|
||||
Status MeowWebsocketCloseCode `json:"status"`
|
||||
}
|
||||
|
||||
func (cc CloseCommand) Error() string {
|
||||
return fmt.Sprintf("websocket: close %d: %s", cc.Code, cc.Status.String())
|
||||
}
|
||||
|
||||
func parseCloseError(err error) error {
|
||||
closeError := &websocket.CloseError{}
|
||||
if !errors.As(err, &closeError) {
|
||||
return err
|
||||
}
|
||||
var closeCommand CloseCommand
|
||||
closeCommand.Code = closeError.Code
|
||||
closeCommand.Command = "disconnect"
|
||||
if len(closeError.Text) > 0 {
|
||||
jsonErr := json.Unmarshal([]byte(closeError.Text), &closeCommand)
|
||||
if jsonErr != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
if len(closeCommand.Status) == 0 {
|
||||
if closeCommand.Code == WebsocketCloseConnReplaced {
|
||||
closeCommand.Status = MeowConnectionReplaced
|
||||
} else if closeCommand.Code == websocket.CloseServiceRestart {
|
||||
closeCommand.Status = MeowServerShuttingDown
|
||||
}
|
||||
}
|
||||
return &closeCommand
|
||||
}
|
||||
|
||||
func (as *AppService) HasWebsocket() bool {
|
||||
return as.ws != nil
|
||||
}
|
||||
|
||||
func (as *AppService) SendWebsocket(cmd *WebsocketRequest) error {
|
||||
ws := as.ws
|
||||
if cmd == nil {
|
||||
return nil
|
||||
} else if ws == nil {
|
||||
return ErrWebsocketNotConnected
|
||||
}
|
||||
as.wsWriteLock.Lock()
|
||||
defer as.wsWriteLock.Unlock()
|
||||
if cmd.Deadline == 0 {
|
||||
cmd.Deadline = 3 * time.Minute
|
||||
}
|
||||
_ = ws.SetWriteDeadline(time.Now().Add(cmd.Deadline))
|
||||
return ws.WriteJSON(cmd)
|
||||
}
|
||||
|
||||
func (as *AppService) clearWebsocketResponseWaiters() {
|
||||
as.websocketRequestsLock.Lock()
|
||||
for _, waiter := range as.websocketRequests {
|
||||
waiter <- &WebsocketCommand{Command: "__websocket_closed"}
|
||||
}
|
||||
as.websocketRequests = make(map[int]chan<- *WebsocketCommand)
|
||||
as.websocketRequestsLock.Unlock()
|
||||
}
|
||||
|
||||
func (as *AppService) addWebsocketResponseWaiter(reqID int, waiter chan<- *WebsocketCommand) {
|
||||
as.websocketRequestsLock.Lock()
|
||||
as.websocketRequests[reqID] = waiter
|
||||
as.websocketRequestsLock.Unlock()
|
||||
}
|
||||
|
||||
func (as *AppService) removeWebsocketResponseWaiter(reqID int, waiter chan<- *WebsocketCommand) {
|
||||
as.websocketRequestsLock.Lock()
|
||||
existingWaiter, ok := as.websocketRequests[reqID]
|
||||
if ok && existingWaiter == waiter {
|
||||
delete(as.websocketRequests, reqID)
|
||||
}
|
||||
close(waiter)
|
||||
as.websocketRequestsLock.Unlock()
|
||||
}
|
||||
|
||||
type ErrorResponse struct {
|
||||
Code string `json:"code"`
|
||||
Message string `json:"message"`
|
||||
}
|
||||
|
||||
func (er *ErrorResponse) Error() string {
|
||||
return fmt.Sprintf("%s: %s", er.Code, er.Message)
|
||||
}
|
||||
|
||||
func (as *AppService) RequestWebsocket(ctx context.Context, cmd *WebsocketRequest, response interface{}) error {
|
||||
cmd.ReqID = int(atomic.AddInt32(&as.websocketRequestID, 1))
|
||||
respChan := make(chan *WebsocketCommand, 1)
|
||||
as.addWebsocketResponseWaiter(cmd.ReqID, respChan)
|
||||
defer as.removeWebsocketResponseWaiter(cmd.ReqID, respChan)
|
||||
err := as.SendWebsocket(cmd)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
select {
|
||||
case resp := <-respChan:
|
||||
if resp.Command == "__websocket_closed" {
|
||||
return ErrWebsocketClosed
|
||||
} else if resp.Command == "error" {
|
||||
var respErr ErrorResponse
|
||||
err = json.Unmarshal(resp.Data, &respErr)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse error JSON: %w", err)
|
||||
}
|
||||
return &respErr
|
||||
} else if response != nil {
|
||||
err = json.Unmarshal(resp.Data, &response)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse response JSON: %w", err)
|
||||
}
|
||||
return nil
|
||||
} else {
|
||||
return nil
|
||||
}
|
||||
case <-ctx.Done():
|
||||
return ctx.Err()
|
||||
}
|
||||
}
|
||||
|
||||
func (as *AppService) unknownCommandHandler(cmd WebsocketCommand) (bool, interface{}) {
|
||||
as.Log.Warnfln("No handler for websocket command %s (%d)", cmd.Command, cmd.ReqID)
|
||||
return false, fmt.Errorf("unknown request type")
|
||||
}
|
||||
|
||||
func (as *AppService) SetWebsocketCommandHandler(cmd string, handler WebsocketHandler) {
|
||||
as.websocketHandlersLock.Lock()
|
||||
as.websocketHandlers[cmd] = handler
|
||||
as.websocketHandlersLock.Unlock()
|
||||
}
|
||||
|
||||
func (as *AppService) consumeWebsocket(stopFunc func(error), ws *websocket.Conn) {
|
||||
defer stopFunc(ErrWebsocketUnknownError)
|
||||
for {
|
||||
var msg WebsocketMessage
|
||||
err := ws.ReadJSON(&msg)
|
||||
if err != nil {
|
||||
as.Log.Debugln("Error reading from websocket:", err)
|
||||
stopFunc(parseCloseError(err))
|
||||
return
|
||||
}
|
||||
if msg.Command == "" || msg.Command == "transaction" {
|
||||
if msg.TxnID == "" || !as.txnIDC.IsProcessed(msg.TxnID) {
|
||||
as.handleTransaction(msg.TxnID, &msg.Transaction)
|
||||
} else {
|
||||
as.Log.Debugfln("Ignoring duplicate transaction %s (%s)", msg.TxnID, msg.Transaction.ContentString())
|
||||
}
|
||||
go func() {
|
||||
err = as.SendWebsocket(msg.MakeResponse(true, &WebsocketTransactionResponse{TxnID: msg.TxnID}))
|
||||
if err != nil {
|
||||
as.Log.Warnfln("Failed to send response to %s %d: %v", msg.Command, msg.ReqID, err)
|
||||
}
|
||||
}()
|
||||
} else if msg.Command == "connect" {
|
||||
as.Log.Debugln("Websocket connect confirmation received")
|
||||
} else if msg.Command == "response" || msg.Command == "error" {
|
||||
as.websocketRequestsLock.RLock()
|
||||
respChan, ok := as.websocketRequests[msg.ReqID]
|
||||
if ok {
|
||||
select {
|
||||
case respChan <- &msg.WebsocketCommand:
|
||||
default:
|
||||
as.Log.Warnfln("Failed to handle response to %d: channel didn't accept response", msg.ReqID)
|
||||
}
|
||||
} else {
|
||||
as.Log.Warnfln("Dropping response to %d: unknown request ID", msg.ReqID)
|
||||
}
|
||||
as.websocketRequestsLock.RUnlock()
|
||||
} else {
|
||||
as.Log.Debugfln("Received command request %s %d", msg.Command, msg.ReqID)
|
||||
as.websocketHandlersLock.RLock()
|
||||
handler, ok := as.websocketHandlers[msg.Command]
|
||||
as.websocketHandlersLock.RUnlock()
|
||||
if !ok {
|
||||
handler = as.unknownCommandHandler
|
||||
}
|
||||
go func() {
|
||||
okResp, data := handler(msg.WebsocketCommand)
|
||||
err = as.SendWebsocket(msg.MakeResponse(okResp, data))
|
||||
if err != nil {
|
||||
as.Log.Warnfln("Failed to send response to %s %d: %v", msg.Command, msg.ReqID, err)
|
||||
} else if okResp {
|
||||
as.Log.Debugfln("Sent success response to %s %d", msg.Command, msg.ReqID)
|
||||
} else {
|
||||
as.Log.Debugfln("Sent error response to %s %d", msg.Command, msg.ReqID)
|
||||
}
|
||||
}()
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (as *AppService) StartWebsocket(baseURL string, onConnect func()) error {
|
||||
parsed, err := url.Parse(baseURL)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to parse URL: %w", err)
|
||||
}
|
||||
parsed.Path = filepath.Join(parsed.Path, "_matrix/client/unstable/fi.mau.as_sync")
|
||||
if parsed.Scheme == "http" {
|
||||
parsed.Scheme = "ws"
|
||||
} else if parsed.Scheme == "https" {
|
||||
parsed.Scheme = "wss"
|
||||
}
|
||||
ws, resp, err := websocket.DefaultDialer.Dial(parsed.String(), http.Header{
|
||||
"Authorization": []string{fmt.Sprintf("Bearer %s", as.Registration.AppToken)},
|
||||
"User-Agent": []string{as.BotClient().UserAgent},
|
||||
|
||||
"X-Mautrix-Process-ID": []string{as.ProcessID},
|
||||
"X-Mautrix-Websocket-Version": []string{"3"},
|
||||
})
|
||||
if resp != nil && resp.StatusCode >= 400 {
|
||||
var errResp Error
|
||||
err = json.NewDecoder(resp.Body).Decode(&errResp)
|
||||
if err != nil {
|
||||
return fmt.Errorf("websocket request returned HTTP %d with non-JSON body", resp.StatusCode)
|
||||
} else {
|
||||
return fmt.Errorf("websocket request returned %s (HTTP %d): %s", errResp.ErrorCode, resp.StatusCode, errResp.Message)
|
||||
}
|
||||
} else if err != nil {
|
||||
return fmt.Errorf("failed to open websocket: %w", err)
|
||||
}
|
||||
if as.StopWebsocket != nil {
|
||||
as.StopWebsocket(ErrWebsocketOverridden)
|
||||
}
|
||||
closeChan := make(chan error)
|
||||
closeChanOnce := sync.Once{}
|
||||
stopFunc := func(err error) {
|
||||
closeChanOnce.Do(func() {
|
||||
closeChan <- err
|
||||
})
|
||||
}
|
||||
as.ws = ws
|
||||
as.StopWebsocket = stopFunc
|
||||
as.PrepareWebsocket()
|
||||
as.Log.Debugln("Appservice transaction websocket connected")
|
||||
|
||||
go as.consumeWebsocket(stopFunc, ws)
|
||||
|
||||
if onConnect != nil {
|
||||
onConnect()
|
||||
}
|
||||
|
||||
closeErr := <-closeChan
|
||||
|
||||
if as.ws == ws {
|
||||
as.clearWebsocketResponseWaiters()
|
||||
as.ws = nil
|
||||
}
|
||||
|
||||
_ = ws.SetWriteDeadline(time.Now().Add(3 * time.Second))
|
||||
err = ws.WriteMessage(websocket.CloseMessage, websocket.FormatCloseMessage(websocket.CloseGoingAway, ""))
|
||||
if err != nil && !errors.Is(err, websocket.ErrCloseSent) {
|
||||
as.Log.Warnln("Error writing close message to websocket:", err)
|
||||
}
|
||||
err = ws.Close()
|
||||
if err != nil {
|
||||
as.Log.Warnln("Error closing websocket:", err)
|
||||
}
|
||||
return closeErr
|
||||
}
|
||||
253
vendor/maunium.net/go/mautrix/bridge/bridgeconfig/config.go
generated
vendored
253
vendor/maunium.net/go/mautrix/bridge/bridgeconfig/config.go
generated
vendored
@@ -1,253 +0,0 @@
|
||||
// Copyright (c) 2022 Tulir Asokan
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||
// License, v. 2.0. If a copy of the MPL was not distributed with this
|
||||
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
|
||||
package bridgeconfig
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"regexp"
|
||||
"strings"
|
||||
|
||||
"maunium.net/go/mautrix/appservice"
|
||||
"maunium.net/go/mautrix/id"
|
||||
"maunium.net/go/mautrix/util"
|
||||
up "maunium.net/go/mautrix/util/configupgrade"
|
||||
)
|
||||
|
||||
type HomeserverSoftware string
|
||||
|
||||
const (
|
||||
SoftwareStandard HomeserverSoftware = "standard"
|
||||
SoftwareAsmux HomeserverSoftware = "asmux"
|
||||
SoftwareHungry HomeserverSoftware = "hungry"
|
||||
)
|
||||
|
||||
var AllowedHomeserverSoftware = map[HomeserverSoftware]bool{
|
||||
SoftwareStandard: true,
|
||||
SoftwareAsmux: true,
|
||||
SoftwareHungry: true,
|
||||
}
|
||||
|
||||
type HomeserverConfig struct {
|
||||
Address string `yaml:"address"`
|
||||
Domain string `yaml:"domain"`
|
||||
AsyncMedia bool `yaml:"async_media"`
|
||||
|
||||
Software HomeserverSoftware `yaml:"software"`
|
||||
|
||||
StatusEndpoint string `yaml:"status_endpoint"`
|
||||
MessageSendCheckpointEndpoint string `yaml:"message_send_checkpoint_endpoint"`
|
||||
|
||||
WSProxy string `yaml:"websocket_proxy"`
|
||||
WSPingInterval int `yaml:"ping_interval_seconds"`
|
||||
}
|
||||
|
||||
type AppserviceConfig struct {
|
||||
Address string `yaml:"address"`
|
||||
Hostname string `yaml:"hostname"`
|
||||
Port uint16 `yaml:"port"`
|
||||
|
||||
Database DatabaseConfig `yaml:"database"`
|
||||
|
||||
ID string `yaml:"id"`
|
||||
Bot BotUserConfig `yaml:"bot"`
|
||||
|
||||
ASToken string `yaml:"as_token"`
|
||||
HSToken string `yaml:"hs_token"`
|
||||
|
||||
EphemeralEvents bool `yaml:"ephemeral_events"`
|
||||
AsyncTransactions bool `yaml:"async_transactions"`
|
||||
}
|
||||
|
||||
func (config *BaseConfig) MakeUserIDRegex(matcher string) *regexp.Regexp {
|
||||
usernamePlaceholder := strings.ToLower(util.RandomString(16))
|
||||
usernameTemplate := fmt.Sprintf("@%s:%s",
|
||||
config.Bridge.FormatUsername(usernamePlaceholder),
|
||||
config.Homeserver.Domain)
|
||||
usernameTemplate = regexp.QuoteMeta(usernameTemplate)
|
||||
usernameTemplate = strings.Replace(usernameTemplate, usernamePlaceholder, matcher, 1)
|
||||
usernameTemplate = fmt.Sprintf("^%s$", usernameTemplate)
|
||||
return regexp.MustCompile(usernameTemplate)
|
||||
}
|
||||
|
||||
// GenerateRegistration generates a registration file for the homeserver.
|
||||
func (config *BaseConfig) GenerateRegistration() *appservice.Registration {
|
||||
registration := appservice.CreateRegistration()
|
||||
config.AppService.HSToken = registration.ServerToken
|
||||
config.AppService.ASToken = registration.AppToken
|
||||
config.AppService.copyToRegistration(registration)
|
||||
|
||||
registration.SenderLocalpart = util.RandomString(32)
|
||||
botRegex := regexp.MustCompile(fmt.Sprintf("^@%s:%s$",
|
||||
regexp.QuoteMeta(config.AppService.Bot.Username),
|
||||
regexp.QuoteMeta(config.Homeserver.Domain)))
|
||||
registration.Namespaces.UserIDs.Register(botRegex, true)
|
||||
registration.Namespaces.UserIDs.Register(config.MakeUserIDRegex(".*"), true)
|
||||
|
||||
return registration
|
||||
}
|
||||
|
||||
func (config *BaseConfig) MakeAppService() *appservice.AppService {
|
||||
as := appservice.Create()
|
||||
as.HomeserverDomain = config.Homeserver.Domain
|
||||
as.HomeserverURL = config.Homeserver.Address
|
||||
as.Host.Hostname = config.AppService.Hostname
|
||||
as.Host.Port = config.AppService.Port
|
||||
as.DefaultHTTPRetries = 4
|
||||
as.Registration = config.AppService.GetRegistration()
|
||||
return as
|
||||
}
|
||||
|
||||
// GetRegistration copies the data from the bridge config into an *appservice.Registration struct.
|
||||
// This can't be used with the homeserver, see GenerateRegistration for generating files for the homeserver.
|
||||
func (asc *AppserviceConfig) GetRegistration() *appservice.Registration {
|
||||
reg := &appservice.Registration{}
|
||||
asc.copyToRegistration(reg)
|
||||
reg.SenderLocalpart = asc.Bot.Username
|
||||
reg.ServerToken = asc.HSToken
|
||||
reg.AppToken = asc.ASToken
|
||||
return reg
|
||||
}
|
||||
|
||||
func (asc *AppserviceConfig) copyToRegistration(registration *appservice.Registration) {
|
||||
registration.ID = asc.ID
|
||||
registration.URL = asc.Address
|
||||
falseVal := false
|
||||
registration.RateLimited = &falseVal
|
||||
registration.EphemeralEvents = asc.EphemeralEvents
|
||||
registration.SoruEphemeralEvents = asc.EphemeralEvents
|
||||
}
|
||||
|
||||
type BotUserConfig struct {
|
||||
Username string `yaml:"username"`
|
||||
Displayname string `yaml:"displayname"`
|
||||
Avatar string `yaml:"avatar"`
|
||||
|
||||
ParsedAvatar id.ContentURI `yaml:"-"`
|
||||
}
|
||||
|
||||
type serializableBUC BotUserConfig
|
||||
|
||||
func (buc *BotUserConfig) UnmarshalYAML(unmarshal func(interface{}) error) error {
|
||||
var sbuc serializableBUC
|
||||
err := unmarshal(&sbuc)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
*buc = (BotUserConfig)(sbuc)
|
||||
if buc.Avatar != "" && buc.Avatar != "remove" {
|
||||
buc.ParsedAvatar, err = id.ParseContentURI(buc.Avatar)
|
||||
if err != nil {
|
||||
return fmt.Errorf("%w in bot avatar", err)
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type DatabaseConfig struct {
|
||||
Type string `yaml:"type"`
|
||||
URI string `yaml:"uri"`
|
||||
|
||||
MaxOpenConns int `yaml:"max_open_conns"`
|
||||
MaxIdleConns int `yaml:"max_idle_conns"`
|
||||
|
||||
ConnMaxIdleTime string `yaml:"conn_max_idle_time"`
|
||||
ConnMaxLifetime string `yaml:"conn_max_lifetime"`
|
||||
}
|
||||
|
||||
type BridgeConfig interface {
|
||||
FormatUsername(username string) string
|
||||
GetEncryptionConfig() EncryptionConfig
|
||||
GetCommandPrefix() string
|
||||
GetManagementRoomTexts() ManagementRoomTexts
|
||||
GetResendBridgeInfo() bool
|
||||
EnableMessageStatusEvents() bool
|
||||
EnableMessageErrorNotices() bool
|
||||
Validate() error
|
||||
}
|
||||
|
||||
type EncryptionConfig struct {
|
||||
Allow bool `yaml:"allow"`
|
||||
Default bool `yaml:"default"`
|
||||
Require bool `yaml:"require"`
|
||||
Appservice bool `yaml:"appservice"`
|
||||
|
||||
VerificationLevels struct {
|
||||
Receive id.TrustState `yaml:"receive"`
|
||||
Send id.TrustState `yaml:"send"`
|
||||
Share id.TrustState `yaml:"share"`
|
||||
} `yaml:"verification_levels"`
|
||||
AllowKeySharing bool `yaml:"allow_key_sharing"`
|
||||
|
||||
Rotation struct {
|
||||
EnableCustom bool `yaml:"enable_custom"`
|
||||
Milliseconds int64 `yaml:"milliseconds"`
|
||||
Messages int `yaml:"messages"`
|
||||
} `yaml:"rotation"`
|
||||
}
|
||||
|
||||
type ManagementRoomTexts struct {
|
||||
Welcome string `yaml:"welcome"`
|
||||
WelcomeConnected string `yaml:"welcome_connected"`
|
||||
WelcomeUnconnected string `yaml:"welcome_unconnected"`
|
||||
AdditionalHelp string `yaml:"additional_help"`
|
||||
}
|
||||
|
||||
type BaseConfig struct {
|
||||
Homeserver HomeserverConfig `yaml:"homeserver"`
|
||||
AppService AppserviceConfig `yaml:"appservice"`
|
||||
Bridge BridgeConfig `yaml:"-"`
|
||||
Logging appservice.LogConfig `yaml:"logging"`
|
||||
}
|
||||
|
||||
func doUpgrade(helper *up.Helper) {
|
||||
helper.Copy(up.Str, "homeserver", "address")
|
||||
helper.Copy(up.Str, "homeserver", "domain")
|
||||
if legacyAsmuxFlag, ok := helper.Get(up.Bool, "homeserver", "asmux"); ok && legacyAsmuxFlag == "true" {
|
||||
helper.Set(up.Str, string(SoftwareAsmux), "homeserver", "software")
|
||||
} else {
|
||||
helper.Copy(up.Str, "homeserver", "software")
|
||||
}
|
||||
helper.Copy(up.Str|up.Null, "homeserver", "status_endpoint")
|
||||
helper.Copy(up.Str|up.Null, "homeserver", "message_send_checkpoint_endpoint")
|
||||
helper.Copy(up.Bool, "homeserver", "async_media")
|
||||
helper.Copy(up.Str|up.Null, "homeserver", "websocket_proxy")
|
||||
helper.Copy(up.Int, "homeserver", "ping_interval_seconds")
|
||||
|
||||
helper.Copy(up.Str, "appservice", "address")
|
||||
helper.Copy(up.Str, "appservice", "hostname")
|
||||
helper.Copy(up.Int, "appservice", "port")
|
||||
if dbType, ok := helper.Get(up.Str, "appservice", "database", "type"); ok && dbType == "sqlite3" {
|
||||
helper.Set(up.Str, "sqlite3-fk-wal", "appservice", "database", "type")
|
||||
} else {
|
||||
helper.Copy(up.Str, "appservice", "database", "type")
|
||||
}
|
||||
helper.Copy(up.Str, "appservice", "database", "uri")
|
||||
helper.Copy(up.Int, "appservice", "database", "max_open_conns")
|
||||
helper.Copy(up.Int, "appservice", "database", "max_idle_conns")
|
||||
helper.Copy(up.Str|up.Null, "appservice", "database", "max_conn_idle_time")
|
||||
helper.Copy(up.Str|up.Null, "appservice", "database", "max_conn_lifetime")
|
||||
helper.Copy(up.Str, "appservice", "id")
|
||||
helper.Copy(up.Str, "appservice", "bot", "username")
|
||||
helper.Copy(up.Str, "appservice", "bot", "displayname")
|
||||
helper.Copy(up.Str, "appservice", "bot", "avatar")
|
||||
helper.Copy(up.Bool, "appservice", "ephemeral_events")
|
||||
helper.Copy(up.Bool, "appservice", "async_transactions")
|
||||
helper.Copy(up.Str, "appservice", "as_token")
|
||||
helper.Copy(up.Str, "appservice", "hs_token")
|
||||
|
||||
helper.Copy(up.Str, "logging", "directory")
|
||||
helper.Copy(up.Str|up.Null, "logging", "file_name_format")
|
||||
helper.Copy(up.Str|up.Timestamp, "logging", "file_date_format")
|
||||
helper.Copy(up.Int, "logging", "file_mode")
|
||||
helper.Copy(up.Str|up.Timestamp, "logging", "timestamp_format")
|
||||
helper.Copy(up.Str, "logging", "print_level")
|
||||
helper.Copy(up.Bool, "logging", "print_json")
|
||||
helper.Copy(up.Bool, "logging", "file_json")
|
||||
}
|
||||
|
||||
// Upgrader is a config upgrader that copies the default fields in the homeserver, appservice and logging blocks.
|
||||
var Upgrader = up.SimpleUpgrader(doUpgrade)
|
||||
71
vendor/maunium.net/go/mautrix/bridge/bridgeconfig/permissions.go
generated
vendored
71
vendor/maunium.net/go/mautrix/bridge/bridgeconfig/permissions.go
generated
vendored
@@ -1,71 +0,0 @@
|
||||
// Copyright (c) 2022 Tulir Asokan
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||
// License, v. 2.0. If a copy of the MPL was not distributed with this
|
||||
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
|
||||
package bridgeconfig
|
||||
|
||||
import (
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"maunium.net/go/mautrix/id"
|
||||
)
|
||||
|
||||
type PermissionConfig map[string]PermissionLevel
|
||||
|
||||
type PermissionLevel int
|
||||
|
||||
const (
|
||||
PermissionLevelBlock PermissionLevel = 0
|
||||
PermissionLevelRelay PermissionLevel = 5
|
||||
PermissionLevelUser PermissionLevel = 10
|
||||
PermissionLevelAdmin PermissionLevel = 100
|
||||
)
|
||||
|
||||
var namesToLevels = map[string]PermissionLevel{
|
||||
"block": PermissionLevelBlock,
|
||||
"relay": PermissionLevelRelay,
|
||||
"user": PermissionLevelUser,
|
||||
"admin": PermissionLevelAdmin,
|
||||
}
|
||||
|
||||
func RegisterPermissionLevel(name string, level PermissionLevel) {
|
||||
namesToLevels[name] = level
|
||||
}
|
||||
|
||||
func (pc *PermissionConfig) UnmarshalYAML(unmarshal func(interface{}) error) error {
|
||||
rawPC := make(map[string]string)
|
||||
err := unmarshal(&rawPC)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if *pc == nil {
|
||||
*pc = make(map[string]PermissionLevel)
|
||||
}
|
||||
for key, value := range rawPC {
|
||||
level, ok := namesToLevels[strings.ToLower(value)]
|
||||
if ok {
|
||||
(*pc)[key] = level
|
||||
} else if val, err := strconv.Atoi(value); err == nil {
|
||||
(*pc)[key] = PermissionLevel(val)
|
||||
} else {
|
||||
(*pc)[key] = PermissionLevelBlock
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (pc PermissionConfig) Get(userID id.UserID) PermissionLevel {
|
||||
if level, ok := pc[string(userID)]; ok {
|
||||
return level
|
||||
} else if level, ok = pc[userID.Homeserver()]; len(userID.Homeserver()) > 0 && ok {
|
||||
return level
|
||||
} else if level, ok = pc["*"]; ok {
|
||||
return level
|
||||
} else {
|
||||
return PermissionLevelBlock
|
||||
}
|
||||
}
|
||||
373
vendor/maunium.net/go/mautrix/client.go
generated
vendored
373
vendor/maunium.net/go/mautrix/client.go
generated
vendored
@@ -18,32 +18,33 @@ import (
|
||||
"sync/atomic"
|
||||
"time"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
"maunium.net/go/maulogger/v2/maulogadapt"
|
||||
|
||||
"maunium.net/go/mautrix/event"
|
||||
"maunium.net/go/mautrix/id"
|
||||
"maunium.net/go/mautrix/pushrules"
|
||||
)
|
||||
|
||||
type CryptoHelper interface {
|
||||
Encrypt(id.RoomID, event.Type, any) (*event.EncryptedEventContent, error)
|
||||
Decrypt(*event.Event) (*event.Event, error)
|
||||
WaitForSession(id.RoomID, id.SenderKey, id.SessionID, time.Duration) bool
|
||||
RequestSession(id.RoomID, id.SenderKey, id.SessionID, id.UserID, id.DeviceID)
|
||||
Init() error
|
||||
}
|
||||
|
||||
// Deprecated: switch to zerolog
|
||||
type Logger interface {
|
||||
Debugfln(message string, args ...interface{})
|
||||
}
|
||||
|
||||
// StubLogger is an implementation of Logger that does nothing
|
||||
type StubLogger struct{}
|
||||
|
||||
func (sl *StubLogger) Debugfln(message string, args ...interface{}) {}
|
||||
func (sl *StubLogger) Warnfln(message string, args ...interface{}) {}
|
||||
|
||||
var stubLogger = &StubLogger{}
|
||||
|
||||
// Deprecated: switch to zerolog
|
||||
type WarnLogger interface {
|
||||
Logger
|
||||
Warnfln(message string, args ...interface{})
|
||||
}
|
||||
|
||||
type Stringifiable interface {
|
||||
String() string
|
||||
}
|
||||
|
||||
// Client represents a Matrix client.
|
||||
type Client struct {
|
||||
HomeserverURL *url.URL // The base homeserver URL
|
||||
@@ -53,9 +54,18 @@ type Client struct {
|
||||
UserAgent string // The value for the User-Agent header
|
||||
Client *http.Client // The underlying HTTP client which will be used to make HTTP requests.
|
||||
Syncer Syncer // The thing which can process /sync responses
|
||||
Store Storer // The thing which can store rooms/tokens/ids
|
||||
Logger Logger
|
||||
SyncPresence event.Presence
|
||||
Store SyncStore // The thing which can store tokens/ids
|
||||
StateStore StateStore
|
||||
Crypto CryptoHelper
|
||||
|
||||
Log zerolog.Logger
|
||||
// Deprecated: switch to the zerolog instance in Log
|
||||
Logger Logger
|
||||
|
||||
RequestHook func(req *http.Request)
|
||||
ResponseHook func(req *http.Request, resp *http.Response, duration time.Duration)
|
||||
|
||||
SyncPresence event.Presence
|
||||
|
||||
StreamSyncMinAge time.Duration
|
||||
|
||||
@@ -67,10 +77,9 @@ type Client struct {
|
||||
|
||||
txnID int32
|
||||
|
||||
// The ?user_id= query parameter for application services. This must be set *prior* to calling a method.
|
||||
// If this is empty, no user_id parameter will be sent.
|
||||
// See https://spec.matrix.org/v1.2/application-service-api/#identity-assertion
|
||||
AppServiceUserID id.UserID
|
||||
// Should the ?user_id= query parameter be set in requests?
|
||||
// See https://spec.matrix.org/v1.6/application-service-api/#identity-assertion
|
||||
SetAppServiceUserID bool
|
||||
|
||||
syncingID uint32 // Identifies the current Sync. Only one Sync can be active at any given time.
|
||||
}
|
||||
@@ -180,7 +189,7 @@ func (cli *Client) SyncWithContext(ctx context.Context) error {
|
||||
for {
|
||||
streamResp := false
|
||||
if cli.StreamSyncMinAge > 0 && time.Since(lastSuccessfulSync) > cli.StreamSyncMinAge {
|
||||
cli.Logger.Debugfln("Last sync is old, will stream next response")
|
||||
cli.Log.Debug().Msg("Last sync is old, will stream next response")
|
||||
streamResp = true
|
||||
}
|
||||
resSync, err := cli.FullSyncRequest(ReqSync{
|
||||
@@ -242,40 +251,52 @@ func (cli *Client) StopSync() {
|
||||
cli.incrementSyncingID()
|
||||
}
|
||||
|
||||
const logBodyContextKey = "fi.mau.mautrix.log_body"
|
||||
const logRequestIDContextKey = "fi.mau.mautrix.request_id"
|
||||
type contextKey int
|
||||
|
||||
const (
|
||||
LogBodyContextKey contextKey = iota
|
||||
LogRequestIDContextKey
|
||||
)
|
||||
|
||||
func (cli *Client) LogRequest(req *http.Request) {
|
||||
if cli.Logger == stubLogger {
|
||||
return
|
||||
if cli.RequestHook != nil {
|
||||
cli.RequestHook(req)
|
||||
}
|
||||
body, ok := req.Context().Value(logBodyContextKey).(string)
|
||||
reqID, _ := req.Context().Value(logRequestIDContextKey).(int)
|
||||
if ok && len(body) > 0 {
|
||||
cli.Logger.Debugfln("req #%d: %s %s %s", reqID, req.Method, req.URL.String(), body)
|
||||
} else {
|
||||
cli.Logger.Debugfln("req #%d: %s %s", reqID, req.Method, req.URL.String())
|
||||
evt := zerolog.Ctx(req.Context()).Debug().
|
||||
Str("method", req.Method).
|
||||
Str("url", req.URL.String())
|
||||
body := req.Context().Value(LogBodyContextKey)
|
||||
if body != nil {
|
||||
evt.Interface("body", body)
|
||||
}
|
||||
evt.Msg("Sending request")
|
||||
}
|
||||
|
||||
func (cli *Client) LogRequestDone(req *http.Request, resp *http.Response, handlerErr error, contentLength int, duration time.Duration) {
|
||||
if cli.Logger == stubLogger {
|
||||
return
|
||||
if cli.ResponseHook != nil {
|
||||
cli.ResponseHook(req, resp, duration)
|
||||
}
|
||||
reqID, _ := req.Context().Value(logRequestIDContextKey).(int)
|
||||
mime := resp.Header.Get("Content-Type")
|
||||
var suffix string
|
||||
if handlerErr != nil {
|
||||
suffix = fmt.Sprintf(" (but parsing the body failed)")
|
||||
}
|
||||
length := resp.ContentLength
|
||||
if length == -1 && contentLength > 0 {
|
||||
length = int64(contentLength)
|
||||
}
|
||||
cli.Logger.Debugfln(
|
||||
"req #%d (%s) completed in %s with status %d and %d bytes of %s body%s",
|
||||
reqID, strings.TrimPrefix(req.URL.Path, "/_matrix/client"), duration, resp.StatusCode, length, mime, suffix,
|
||||
)
|
||||
path := strings.TrimPrefix(req.URL.Path, cli.HomeserverURL.Path)
|
||||
path = strings.TrimPrefix(path, "/_matrix/client")
|
||||
evt := zerolog.Ctx(req.Context()).Debug().
|
||||
Str("method", req.Method).
|
||||
Str("path", path).
|
||||
Int("status_code", resp.StatusCode).
|
||||
Int64("response_length", length).
|
||||
Str("response_mime", mime).
|
||||
Dur("duration", duration)
|
||||
if handlerErr != nil {
|
||||
evt.AnErr("body_parse_err", handlerErr)
|
||||
}
|
||||
if serverRequestID := resp.Header.Get("X-Beeper-Request-ID"); serverRequestID != "" {
|
||||
evt.Str("beeper_request_id", serverRequestID)
|
||||
}
|
||||
evt.Msg("Request completed")
|
||||
}
|
||||
|
||||
func (cli *Client) MakeRequest(method string, httpURL string, reqBody interface{}, resBody interface{}) ([]byte, error) {
|
||||
@@ -297,13 +318,14 @@ type FullRequest struct {
|
||||
MaxAttempts int
|
||||
SensitiveContent bool
|
||||
Handler ClientResponseHandler
|
||||
Logger *zerolog.Logger
|
||||
}
|
||||
|
||||
var requestID int32
|
||||
var logSensitiveContent = os.Getenv("MAUTRIX_LOG_SENSITIVE_CONTENT") == "yes"
|
||||
|
||||
func (params *FullRequest) compileRequest() (*http.Request, error) {
|
||||
var logBody string
|
||||
var logBody any
|
||||
reqBody := params.RequestBody
|
||||
if params.Context == nil {
|
||||
params.Context = context.Background()
|
||||
@@ -319,7 +341,7 @@ func (params *FullRequest) compileRequest() (*http.Request, error) {
|
||||
if params.SensitiveContent && !logSensitiveContent {
|
||||
logBody = "<sensitive content omitted>"
|
||||
} else {
|
||||
logBody = string(jsonStr)
|
||||
logBody = params.RequestJSON
|
||||
}
|
||||
reqBody = bytes.NewReader(jsonStr)
|
||||
} else if params.RequestBytes != nil {
|
||||
@@ -330,12 +352,20 @@ func (params *FullRequest) compileRequest() (*http.Request, error) {
|
||||
logBody = fmt.Sprintf("<%d bytes>", params.RequestLength)
|
||||
} else if params.Method != http.MethodGet && params.Method != http.MethodHead {
|
||||
params.RequestJSON = struct{}{}
|
||||
logBody = "<default empty object>"
|
||||
logBody = params.RequestJSON
|
||||
reqBody = bytes.NewReader([]byte("{}"))
|
||||
}
|
||||
ctx := context.WithValue(params.Context, logBodyContextKey, logBody)
|
||||
reqID := atomic.AddInt32(&requestID, 1)
|
||||
ctx = context.WithValue(ctx, logRequestIDContextKey, int(reqID))
|
||||
ctx := params.Context
|
||||
logger := zerolog.Ctx(ctx)
|
||||
if logger.GetLevel() == zerolog.Disabled || logger == zerolog.DefaultContextLogger {
|
||||
logger = params.Logger
|
||||
}
|
||||
ctx = logger.With().
|
||||
Int32("req_id", reqID).
|
||||
Logger().WithContext(ctx)
|
||||
ctx = context.WithValue(ctx, LogBodyContextKey, logBody)
|
||||
ctx = context.WithValue(ctx, LogRequestIDContextKey, int(reqID))
|
||||
req, err := http.NewRequestWithContext(ctx, params.Method, params.URL, reqBody)
|
||||
if err != nil {
|
||||
return nil, HTTPError{
|
||||
@@ -365,6 +395,9 @@ func (cli *Client) MakeFullRequest(params FullRequest) ([]byte, error) {
|
||||
if params.MaxAttempts == 0 {
|
||||
params.MaxAttempts = 1 + cli.DefaultHTTPRetries
|
||||
}
|
||||
if params.Logger == nil {
|
||||
params.Logger = &cli.Log
|
||||
}
|
||||
req, err := params.compileRequest()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -379,30 +412,31 @@ func (cli *Client) MakeFullRequest(params FullRequest) ([]byte, error) {
|
||||
return cli.executeCompiledRequest(req, params.MaxAttempts-1, 4*time.Second, params.ResponseJSON, params.Handler)
|
||||
}
|
||||
|
||||
func (cli *Client) logWarning(format string, args ...interface{}) {
|
||||
warnLogger, ok := cli.Logger.(WarnLogger)
|
||||
if ok {
|
||||
warnLogger.Warnfln(format, args...)
|
||||
} else {
|
||||
cli.Logger.Debugfln(format, args...)
|
||||
func (cli *Client) cliOrContextLog(ctx context.Context) *zerolog.Logger {
|
||||
log := zerolog.Ctx(ctx)
|
||||
if log.GetLevel() == zerolog.Disabled || log == zerolog.DefaultContextLogger {
|
||||
return &cli.Log
|
||||
}
|
||||
return log
|
||||
}
|
||||
|
||||
func (cli *Client) doRetry(req *http.Request, cause error, retries int, backoff time.Duration, responseJSON interface{}, handler ClientResponseHandler) ([]byte, error) {
|
||||
reqID, _ := req.Context().Value(logRequestIDContextKey).(int)
|
||||
log := zerolog.Ctx(req.Context())
|
||||
if req.Body != nil {
|
||||
if req.GetBody == nil {
|
||||
cli.logWarning("Failed to get new body to retry request #%d: GetBody is nil", reqID)
|
||||
log.Warn().Msg("Failed to get new body to retry request: GetBody is nil")
|
||||
return nil, cause
|
||||
}
|
||||
var err error
|
||||
req.Body, err = req.GetBody()
|
||||
if err != nil {
|
||||
cli.logWarning("Failed to get new body to retry request #%d: %v", reqID, err)
|
||||
log.Warn().Err(err).Msg("Failed to get new body to retry request")
|
||||
return nil, cause
|
||||
}
|
||||
}
|
||||
cli.logWarning("Request #%d failed: %v, retrying in %d seconds", reqID, cause, int(backoff.Seconds()))
|
||||
log.Warn().Err(cause).
|
||||
Int("retry_in_seconds", int(backoff.Seconds())).
|
||||
Msg("Request failed, retrying")
|
||||
time.Sleep(backoff)
|
||||
return cli.executeCompiledRequest(req, retries-1, backoff*2, responseJSON, handler)
|
||||
}
|
||||
@@ -421,22 +455,23 @@ func (cli *Client) readRequestBody(req *http.Request, res *http.Response) ([]byt
|
||||
return contents, nil
|
||||
}
|
||||
|
||||
func (cli *Client) closeTemp(file *os.File) {
|
||||
func closeTemp(log *zerolog.Logger, file *os.File) {
|
||||
_ = file.Close()
|
||||
err := os.Remove(file.Name())
|
||||
if err != nil {
|
||||
cli.logWarning("Failed to remove temp file %s: %v", file.Name(), err)
|
||||
log.Warn().Err(err).Str("file_name", file.Name()).Msg("Failed to remove response temp file")
|
||||
}
|
||||
}
|
||||
|
||||
func (cli *Client) streamResponse(req *http.Request, res *http.Response, responseJSON interface{}) ([]byte, error) {
|
||||
log := zerolog.Ctx(req.Context())
|
||||
file, err := os.CreateTemp("", "mautrix-response-")
|
||||
if err != nil {
|
||||
cli.logWarning("Failed to create temporary file: %v", err)
|
||||
log.Warn().Err(err).Msg("Failed to create temporary file for streaming response")
|
||||
_, err = cli.handleNormalResponse(req, res, responseJSON)
|
||||
return nil, err
|
||||
}
|
||||
defer cli.closeTemp(file)
|
||||
defer closeTemp(log, file)
|
||||
if _, err = io.Copy(file, res.Body); err != nil {
|
||||
return nil, fmt.Errorf("failed to copy response to file: %w", err)
|
||||
} else if _, err = file.Seek(0, 0); err != nil {
|
||||
@@ -487,7 +522,7 @@ func (cli *Client) handleResponseError(req *http.Request, res *http.Response) ([
|
||||
|
||||
// parseBackoffFromResponse extracts the backoff time specified in the Retry-After header if present. See
|
||||
// https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Retry-After.
|
||||
func (cli *Client) parseBackoffFromResponse(res *http.Response, now time.Time, fallback time.Duration) time.Duration {
|
||||
func (cli *Client) parseBackoffFromResponse(req *http.Request, res *http.Response, now time.Time, fallback time.Duration) time.Duration {
|
||||
retryAfterHeaderValue := res.Header.Get("Retry-After")
|
||||
if retryAfterHeaderValue == "" {
|
||||
return fallback
|
||||
@@ -501,7 +536,9 @@ func (cli *Client) parseBackoffFromResponse(res *http.Response, now time.Time, f
|
||||
return time.Duration(seconds) * time.Second
|
||||
}
|
||||
|
||||
cli.logWarning(`Failed to parse Retry-After header value "%s"`, retryAfterHeaderValue)
|
||||
zerolog.Ctx(req.Context()).Warn().
|
||||
Str("retry_after", retryAfterHeaderValue).
|
||||
Msg("Failed to parse Retry-After header value")
|
||||
|
||||
return fallback
|
||||
}
|
||||
@@ -536,7 +573,7 @@ func (cli *Client) executeCompiledRequest(req *http.Request, retries int, backof
|
||||
|
||||
if retries > 0 && cli.shouldRetry(res) {
|
||||
if res.StatusCode == http.StatusTooManyRequests {
|
||||
backoff = cli.parseBackoffFromResponse(res, time.Now(), backoff)
|
||||
backoff = cli.parseBackoffFromResponse(req, res, time.Now(), backoff)
|
||||
}
|
||||
return cli.doRetry(req, fmt.Errorf("HTTP %d", res.StatusCode), retries, backoff, responseJSON, handler)
|
||||
}
|
||||
@@ -631,7 +668,11 @@ func (cli *Client) FullSyncRequest(req ReqSync) (resp *RespSync, err error) {
|
||||
buffer = 1 * time.Minute
|
||||
}
|
||||
if err == nil && duration > timeout+buffer {
|
||||
cli.logWarning("Sync request (%s) took %s with timeout %s", req.Since, duration, timeout)
|
||||
cli.cliOrContextLog(fullReq.Context).Warn().
|
||||
Str("since", req.Since).
|
||||
Dur("duration", duration).
|
||||
Dur("timeout", timeout).
|
||||
Msg("Sync request took unusually long")
|
||||
}
|
||||
return
|
||||
}
|
||||
@@ -760,15 +801,24 @@ func (cli *Client) Login(req *ReqLogin) (resp *RespLogin, err error) {
|
||||
cli.DeviceID = resp.DeviceID
|
||||
cli.AccessToken = resp.AccessToken
|
||||
cli.UserID = resp.UserID
|
||||
cli.Logger.Debugfln("Stored credentials for %s/%s after login", cli.UserID, cli.DeviceID)
|
||||
|
||||
cli.Log.Debug().
|
||||
Str("user_id", cli.UserID.String()).
|
||||
Str("device_id", cli.DeviceID.String()).
|
||||
Msg("Stored credentials after login")
|
||||
}
|
||||
if req.StoreHomeserverURL && err == nil && resp.WellKnown != nil && len(resp.WellKnown.Homeserver.BaseURL) > 0 {
|
||||
var urlErr error
|
||||
cli.HomeserverURL, urlErr = url.Parse(resp.WellKnown.Homeserver.BaseURL)
|
||||
if urlErr != nil {
|
||||
cli.logWarning("Failed to parse homeserver URL '%s' in login response: %v", resp.WellKnown.Homeserver.BaseURL, urlErr)
|
||||
cli.Log.Warn().
|
||||
Err(urlErr).
|
||||
Str("homeserver_url", resp.WellKnown.Homeserver.BaseURL).
|
||||
Msg("Failed to parse homeserver URL in login response")
|
||||
} else {
|
||||
cli.Logger.Debugfln("Updated homeserver URL to %s after login", cli.HomeserverURL.String())
|
||||
cli.Log.Debug().
|
||||
Str("homeserver_url", cli.HomeserverURL.String()).
|
||||
Msg("Updated homeserver URL after login")
|
||||
}
|
||||
}
|
||||
return
|
||||
@@ -818,6 +868,9 @@ func (cli *Client) JoinRoom(roomIDorAlias, serverName string, content interface{
|
||||
urlPath = cli.BuildClientURL("v3", "join", roomIDorAlias)
|
||||
}
|
||||
_, err = cli.MakeRequest("POST", urlPath, content, &resp)
|
||||
if err == nil && cli.StateStore != nil {
|
||||
cli.StateStore.SetMembership(resp.RoomID, cli.UserID, event.MembershipJoin)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
@@ -827,6 +880,9 @@ func (cli *Client) JoinRoom(roomIDorAlias, serverName string, content interface{
|
||||
// It's mostly intended for bridges and other things where it's already certain that the server is in the room.
|
||||
func (cli *Client) JoinRoomByID(roomID id.RoomID) (resp *RespJoinRoom, err error) {
|
||||
_, err = cli.MakeRequest("POST", cli.BuildClientURL("v3", "rooms", roomID, "join"), nil, &resp)
|
||||
if err == nil && cli.StateStore != nil {
|
||||
cli.StateStore.SetMembership(resp.RoomID, cli.UserID, event.MembershipJoin)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
@@ -892,6 +948,13 @@ func (cli *Client) SetAvatarURL(url id.ContentURI) (err error) {
|
||||
return nil
|
||||
}
|
||||
|
||||
// BeeperUpdateProfile sets custom fields in the user's profile.
|
||||
func (cli *Client) BeeperUpdateProfile(data map[string]any) (err error) {
|
||||
urlPath := cli.BuildClientURL("v3", "profile", cli.UserID)
|
||||
_, err = cli.MakeRequest("PATCH", urlPath, &data, nil)
|
||||
return
|
||||
}
|
||||
|
||||
// GetAccountData gets the user's account data of this type. See https://spec.matrix.org/v1.2/client-server-api/#get_matrixclientv3useruseridaccount_datatype
|
||||
func (cli *Client) GetAccountData(name string, output interface{}) (err error) {
|
||||
urlPath := cli.BuildClientURL("v3", "user", cli.UserID, "account_data", name)
|
||||
@@ -932,6 +995,8 @@ type ReqSendEvent struct {
|
||||
Timestamp int64
|
||||
TransactionID string
|
||||
|
||||
DontEncrypt bool
|
||||
|
||||
MeowEventID id.EventID
|
||||
}
|
||||
|
||||
@@ -958,8 +1023,16 @@ func (cli *Client) SendMessageEvent(roomID id.RoomID, eventType event.Type, cont
|
||||
queryParams["fi.mau.event_id"] = req.MeowEventID.String()
|
||||
}
|
||||
|
||||
urlData := ClientURLPath{"v3", "rooms", roomID, "send", eventType.String(), txnID}
|
||||
if !req.DontEncrypt && cli.Crypto != nil && eventType != event.EventReaction && eventType != event.EventEncrypted && cli.StateStore.IsEncrypted(roomID) {
|
||||
contentJSON, err = cli.Crypto.Encrypt(roomID, eventType, contentJSON)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("failed to encrypt event: %w", err)
|
||||
return
|
||||
}
|
||||
eventType = event.EventEncrypted
|
||||
}
|
||||
|
||||
urlData := ClientURLPath{"v3", "rooms", roomID, "send", eventType.String(), txnID}
|
||||
urlPath := cli.BuildURLWithQuery(urlData, queryParams)
|
||||
_, err = cli.MakeRequest("PUT", urlPath, contentJSON, &resp)
|
||||
return
|
||||
@@ -970,6 +1043,9 @@ func (cli *Client) SendMessageEvent(roomID id.RoomID, eventType event.Type, cont
|
||||
func (cli *Client) SendStateEvent(roomID id.RoomID, eventType event.Type, stateKey string, contentJSON interface{}) (resp *RespSendEvent, err error) {
|
||||
urlPath := cli.BuildClientURL("v3", "rooms", roomID, "state", eventType.String(), stateKey)
|
||||
_, err = cli.MakeRequest("PUT", urlPath, contentJSON, &resp)
|
||||
if err == nil && cli.StateStore != nil {
|
||||
cli.updateStoreWithOutgoingEvent(roomID, eventType, stateKey, contentJSON)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
@@ -980,6 +1056,9 @@ func (cli *Client) SendMassagedStateEvent(roomID id.RoomID, eventType event.Type
|
||||
"ts": strconv.FormatInt(ts, 10),
|
||||
})
|
||||
_, err = cli.MakeRequest("PUT", urlPath, contentJSON, &resp)
|
||||
if err == nil && cli.StateStore != nil {
|
||||
cli.updateStoreWithOutgoingEvent(roomID, eventType, stateKey, contentJSON)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
@@ -992,30 +1071,6 @@ func (cli *Client) SendText(roomID id.RoomID, text string) (*RespSendEvent, erro
|
||||
})
|
||||
}
|
||||
|
||||
// SendImage sends an m.room.message event into the given room with a msgtype of m.image
|
||||
// See https://spec.matrix.org/v1.2/client-server-api/#mimage
|
||||
//
|
||||
// Deprecated: This does not allow setting image metadata, you should prefer SendMessageEvent with a properly filled &event.MessageEventContent
|
||||
func (cli *Client) SendImage(roomID id.RoomID, body string, url id.ContentURI) (*RespSendEvent, error) {
|
||||
return cli.SendMessageEvent(roomID, event.EventMessage, &event.MessageEventContent{
|
||||
MsgType: event.MsgImage,
|
||||
Body: body,
|
||||
URL: url.CUString(),
|
||||
})
|
||||
}
|
||||
|
||||
// SendVideo sends an m.room.message event into the given room with a msgtype of m.video
|
||||
// See https://spec.matrix.org/v1.2/client-server-api/#mvideo
|
||||
//
|
||||
// Deprecated: This does not allow setting video metadata, you should prefer SendMessageEvent with a properly filled &event.MessageEventContent
|
||||
func (cli *Client) SendVideo(roomID id.RoomID, body string, url id.ContentURI) (*RespSendEvent, error) {
|
||||
return cli.SendMessageEvent(roomID, event.EventMessage, &event.MessageEventContent{
|
||||
MsgType: event.MsgVideo,
|
||||
Body: body,
|
||||
URL: url.CUString(),
|
||||
})
|
||||
}
|
||||
|
||||
// SendNotice sends an m.room.message event into the given room with a msgtype of m.notice
|
||||
// See https://spec.matrix.org/v1.2/client-server-api/#mnotice
|
||||
func (cli *Client) SendNotice(roomID id.RoomID, text string) (*RespSendEvent, error) {
|
||||
@@ -1067,6 +1122,22 @@ func (cli *Client) RedactEvent(roomID id.RoomID, eventID id.EventID, extra ...Re
|
||||
func (cli *Client) CreateRoom(req *ReqCreateRoom) (resp *RespCreateRoom, err error) {
|
||||
urlPath := cli.BuildClientURL("v3", "createRoom")
|
||||
_, err = cli.MakeRequest("POST", urlPath, req, &resp)
|
||||
if err == nil && cli.StateStore != nil {
|
||||
cli.StateStore.SetMembership(resp.RoomID, cli.UserID, event.MembershipJoin)
|
||||
for _, evt := range req.InitialState {
|
||||
UpdateStateStore(cli.StateStore, evt)
|
||||
}
|
||||
inviteMembership := event.MembershipInvite
|
||||
if req.BeeperAutoJoinInvites {
|
||||
inviteMembership = event.MembershipJoin
|
||||
}
|
||||
for _, invitee := range req.Invite {
|
||||
cli.StateStore.SetMembership(resp.RoomID, invitee, inviteMembership)
|
||||
}
|
||||
for _, evt := range req.InitialState {
|
||||
cli.updateStoreWithOutgoingEvent(resp.RoomID, evt.Type, evt.GetStateKey(), &evt.Content)
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
@@ -1080,6 +1151,9 @@ func (cli *Client) LeaveRoom(roomID id.RoomID, optionalReq ...*ReqLeave) (resp *
|
||||
}
|
||||
u := cli.BuildClientURL("v3", "rooms", roomID, "leave")
|
||||
_, err = cli.MakeRequest("POST", u, req, &resp)
|
||||
if err == nil && cli.StateStore != nil {
|
||||
cli.StateStore.SetMembership(roomID, cli.UserID, event.MembershipLeave)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
@@ -1094,6 +1168,9 @@ func (cli *Client) ForgetRoom(roomID id.RoomID) (resp *RespForgetRoom, err error
|
||||
func (cli *Client) InviteUser(roomID id.RoomID, req *ReqInviteUser) (resp *RespInviteUser, err error) {
|
||||
u := cli.BuildClientURL("v3", "rooms", roomID, "invite")
|
||||
_, err = cli.MakeRequest("POST", u, req, &resp)
|
||||
if err == nil && cli.StateStore != nil {
|
||||
cli.StateStore.SetMembership(roomID, req.UserID, event.MembershipInvite)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
@@ -1108,6 +1185,9 @@ func (cli *Client) InviteUserByThirdParty(roomID id.RoomID, req *ReqInvite3PID)
|
||||
func (cli *Client) KickUser(roomID id.RoomID, req *ReqKickUser) (resp *RespKickUser, err error) {
|
||||
u := cli.BuildClientURL("v3", "rooms", roomID, "kick")
|
||||
_, err = cli.MakeRequest("POST", u, req, &resp)
|
||||
if err == nil && cli.StateStore != nil {
|
||||
cli.StateStore.SetMembership(roomID, req.UserID, event.MembershipLeave)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
@@ -1115,6 +1195,9 @@ func (cli *Client) KickUser(roomID id.RoomID, req *ReqKickUser) (resp *RespKickU
|
||||
func (cli *Client) BanUser(roomID id.RoomID, req *ReqBanUser) (resp *RespBanUser, err error) {
|
||||
u := cli.BuildClientURL("v3", "rooms", roomID, "ban")
|
||||
_, err = cli.MakeRequest("POST", u, req, &resp)
|
||||
if err == nil && cli.StateStore != nil {
|
||||
cli.StateStore.SetMembership(roomID, req.UserID, event.MembershipBan)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
@@ -1122,6 +1205,9 @@ func (cli *Client) BanUser(roomID id.RoomID, req *ReqBanUser) (resp *RespBanUser
|
||||
func (cli *Client) UnbanUser(roomID id.RoomID, req *ReqUnbanUser) (resp *RespUnbanUser, err error) {
|
||||
u := cli.BuildClientURL("v3", "rooms", roomID, "unban")
|
||||
_, err = cli.MakeRequest("POST", u, req, &resp)
|
||||
if err == nil && cli.StateStore != nil {
|
||||
cli.StateStore.SetMembership(roomID, req.UserID, event.MembershipLeave)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
@@ -1153,12 +1239,48 @@ func (cli *Client) SetPresence(status event.Presence) (err error) {
|
||||
return
|
||||
}
|
||||
|
||||
func (cli *Client) updateStoreWithOutgoingEvent(roomID id.RoomID, eventType event.Type, stateKey string, contentJSON interface{}) {
|
||||
if cli.StateStore == nil {
|
||||
return
|
||||
}
|
||||
fakeEvt := &event.Event{
|
||||
StateKey: &stateKey,
|
||||
Type: eventType,
|
||||
RoomID: roomID,
|
||||
}
|
||||
var err error
|
||||
fakeEvt.Content.VeryRaw, err = json.Marshal(contentJSON)
|
||||
if err != nil {
|
||||
cli.Log.Warn().Err(err).Msg("Failed to marshal state event content to update state store")
|
||||
return
|
||||
}
|
||||
err = json.Unmarshal(fakeEvt.Content.VeryRaw, &fakeEvt.Content.Raw)
|
||||
if err != nil {
|
||||
cli.Log.Warn().Err(err).Msg("Failed to unmarshal state event content to update state store")
|
||||
return
|
||||
}
|
||||
err = fakeEvt.Content.ParseRaw(fakeEvt.Type)
|
||||
if err != nil {
|
||||
switch fakeEvt.Type {
|
||||
case event.StateMember, event.StatePowerLevels, event.StateEncryption:
|
||||
cli.Log.Warn().Err(err).Msg("Failed to parse state event content to update state store")
|
||||
default:
|
||||
cli.Log.Debug().Err(err).Msg("Failed to parse state event content to update state store")
|
||||
}
|
||||
return
|
||||
}
|
||||
UpdateStateStore(cli.StateStore, fakeEvt)
|
||||
}
|
||||
|
||||
// StateEvent gets a single state event in a room. It will attempt to JSON unmarshal into the given "outContent" struct with
|
||||
// the HTTP response body, or return an error.
|
||||
// See https://spec.matrix.org/v1.2/client-server-api/#get_matrixclientv3roomsroomidstateeventtypestatekey
|
||||
func (cli *Client) StateEvent(roomID id.RoomID, eventType event.Type, stateKey string, outContent interface{}) (err error) {
|
||||
u := cli.BuildClientURL("v3", "rooms", roomID, "state", eventType.String(), stateKey)
|
||||
_, err = cli.MakeRequest("GET", u, nil, outContent)
|
||||
if err == nil && cli.StateStore != nil {
|
||||
cli.updateStoreWithOutgoingEvent(roomID, eventType, stateKey, outContent)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
@@ -1182,6 +1304,7 @@ func parseRoomStateArray(_ *http.Request, res *http.Response, responseJSON inter
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to parse state array item #%d: %v", i, err)
|
||||
}
|
||||
evt.Type.Class = event.StateEventType
|
||||
_ = evt.Content.ParseRaw(evt.Type)
|
||||
subMap, ok := response[evt.Type]
|
||||
if !ok {
|
||||
@@ -1209,6 +1332,14 @@ func (cli *Client) State(roomID id.RoomID) (stateMap RoomStateMap, err error) {
|
||||
ResponseJSON: &stateMap,
|
||||
Handler: parseRoomStateArray,
|
||||
})
|
||||
if err == nil && cli.StateStore != nil {
|
||||
cli.StateStore.ClearCachedMembers(roomID)
|
||||
for _, evts := range stateMap {
|
||||
for _, evt := range evts {
|
||||
UpdateStateStore(cli.StateStore, evt)
|
||||
}
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
@@ -1245,6 +1376,10 @@ func (cli *Client) DownloadContext(ctx context.Context, mxcURL id.ContentURI) (i
|
||||
}
|
||||
|
||||
func (cli *Client) downloadContext(ctx context.Context, mxcURL id.ContentURI) (*http.Request, *http.Response, error) {
|
||||
ctxLog := zerolog.Ctx(ctx)
|
||||
if ctxLog.GetLevel() == zerolog.Disabled || ctxLog == zerolog.DefaultContextLogger {
|
||||
ctx = cli.Log.WithContext(ctx)
|
||||
}
|
||||
req, err := http.NewRequestWithContext(ctx, http.MethodGet, cli.GetDownloadURL(mxcURL), nil)
|
||||
if err != nil {
|
||||
return req, nil, err
|
||||
@@ -1303,7 +1438,7 @@ func (cli *Client) UnstableUploadAsync(req ReqUploadMedia) (*RespCreateMXC, erro
|
||||
go func() {
|
||||
_, err = cli.UploadMedia(req)
|
||||
if err != nil {
|
||||
cli.logWarning("Failed to upload %s: %v", req.UnstableMXC, err)
|
||||
cli.Log.Error().Str("mxc", req.UnstableMXC.String()).Err(err).Msg("Async upload of media failed")
|
||||
}
|
||||
}()
|
||||
return resp, nil
|
||||
@@ -1349,7 +1484,7 @@ type ReqUploadMedia struct {
|
||||
}
|
||||
|
||||
func (cli *Client) tryUploadMediaToURL(url, contentType string, content io.Reader) (*http.Response, error) {
|
||||
cli.Logger.Debugfln("Uploading media to external URL %s", url)
|
||||
cli.Log.Debug().Str("url", url).Msg("Uploading media to external URL")
|
||||
req, err := http.NewRequest(http.MethodPut, url, content)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -1382,10 +1517,10 @@ func (cli *Client) uploadMediaToURL(data ReqUploadMedia) (*RespMediaUpload, erro
|
||||
err = fmt.Errorf("HTTP %d", resp.StatusCode)
|
||||
}
|
||||
if retries <= 0 {
|
||||
cli.logWarning("Error uploading media to %s: %v, not retrying", data.UploadURL, err)
|
||||
cli.Log.Warn().Str("url", data.UploadURL).Err(err).Msg("Error uploading media to external URL, not retrying")
|
||||
return nil, err
|
||||
}
|
||||
cli.Logger.Debugfln("Error uploading media to %s: %v, retrying", data.UploadURL, err)
|
||||
cli.Log.Warn().Str("url", data.UploadURL).Err(err).Msg("Error uploading media to external URL, retrying")
|
||||
retries--
|
||||
}
|
||||
|
||||
@@ -1464,6 +1599,16 @@ func (cli *Client) GetURLPreview(url string) (*RespPreviewURL, error) {
|
||||
func (cli *Client) JoinedMembers(roomID id.RoomID) (resp *RespJoinedMembers, err error) {
|
||||
u := cli.BuildClientURL("v3", "rooms", roomID, "joined_members")
|
||||
_, err = cli.MakeRequest("GET", u, nil, &resp)
|
||||
if err == nil && cli.StateStore != nil {
|
||||
cli.StateStore.ClearCachedMembers(roomID, event.MembershipJoin)
|
||||
for userID, member := range resp.Joined {
|
||||
cli.StateStore.SetMember(roomID, userID, &event.MemberEventContent{
|
||||
Membership: event.MembershipJoin,
|
||||
AvatarURL: id.ContentURIString(member.AvatarURL),
|
||||
Displayname: member.DisplayName,
|
||||
})
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
@@ -1484,6 +1629,18 @@ func (cli *Client) Members(roomID id.RoomID, req ...ReqMembers) (resp *RespMembe
|
||||
}
|
||||
u := cli.BuildURLWithQuery(ClientURLPath{"v3", "rooms", roomID, "members"}, query)
|
||||
_, err = cli.MakeRequest("GET", u, nil, &resp)
|
||||
if err == nil && cli.StateStore != nil {
|
||||
var clearMemberships []event.Membership
|
||||
if extra.Membership != "" {
|
||||
clearMemberships = append(clearMemberships, extra.Membership)
|
||||
}
|
||||
if extra.NotMembership == "" {
|
||||
cli.StateStore.ClearCachedMembers(roomID, clearMemberships...)
|
||||
}
|
||||
for _, evt := range resp.Chunk {
|
||||
UpdateStateStore(cli.StateStore, evt)
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
@@ -1833,6 +1990,18 @@ func (cli *Client) BatchSend(roomID id.RoomID, req *ReqBatchSend) (resp *RespBat
|
||||
return
|
||||
}
|
||||
|
||||
func (cli *Client) AppservicePing(id, txnID string) (resp *RespAppservicePing, err error) {
|
||||
_, err = cli.MakeFullRequest(FullRequest{
|
||||
Method: http.MethodPost,
|
||||
URL: cli.BuildClientURL("v1", "appservice", id, "ping"),
|
||||
RequestJSON: &ReqAppservicePing{TxnID: txnID},
|
||||
ResponseJSON: &resp,
|
||||
// This endpoint intentionally returns 50x, so don't retry
|
||||
MaxAttempts: 1,
|
||||
})
|
||||
return
|
||||
}
|
||||
|
||||
func (cli *Client) BeeperMergeRooms(req *ReqBeeperMergeRoom) (resp *RespBeeperMergeRoom, err error) {
|
||||
urlPath := cli.BuildClientURL("unstable", "com.beeper.chatmerging", "merge")
|
||||
_, err = cli.MakeRequest(http.MethodPost, urlPath, req, &resp)
|
||||
@@ -1859,21 +2028,23 @@ func (cli *Client) TxnID() string {
|
||||
|
||||
// NewClient creates a new Matrix Client ready for syncing
|
||||
func NewClient(homeserverURL string, userID id.UserID, accessToken string) (*Client, error) {
|
||||
hsURL, err := parseAndNormalizeBaseURL(homeserverURL)
|
||||
hsURL, err := ParseAndNormalizeBaseURL(homeserverURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
return &Client{
|
||||
cli := &Client{
|
||||
AccessToken: accessToken,
|
||||
UserAgent: DefaultUserAgent,
|
||||
HomeserverURL: hsURL,
|
||||
UserID: userID,
|
||||
Client: &http.Client{Timeout: 180 * time.Second},
|
||||
Syncer: NewDefaultSyncer(),
|
||||
Logger: stubLogger,
|
||||
Log: zerolog.Nop(),
|
||||
// By default, use an in-memory store which will never save filter ids / next batch tokens to disk.
|
||||
// The client will work with this storer: it just won't remember across restarts.
|
||||
// In practice, a database backend should be used.
|
||||
Store: NewInMemoryStore(),
|
||||
}, nil
|
||||
Store: NewMemorySyncStore(),
|
||||
}
|
||||
cli.Logger = maulogadapt.ZeroAsMau(&cli.Log)
|
||||
return cli, nil
|
||||
}
|
||||
|
||||
15
vendor/maunium.net/go/mautrix/crypto/cross_sign_key.go
generated
vendored
15
vendor/maunium.net/go/mautrix/crypto/cross_sign_key.go
generated
vendored
@@ -1,4 +1,5 @@
|
||||
// Copyright (c) 2020 Nikos Filippakis
|
||||
// Copyright (c) 2023 Tulir Asokan
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||
// License, v. 2.0. If a copy of the MPL was not distributed with this
|
||||
@@ -55,8 +56,11 @@ func (mach *OlmMachine) ImportCrossSigningKeys(keys CrossSigningSeeds) (err erro
|
||||
return
|
||||
}
|
||||
|
||||
mach.Log.Trace("Got cross-signing keys: Master `%v` Self-signing `%v` User-signing `%v`",
|
||||
keysCache.MasterKey.PublicKey, keysCache.SelfSigningKey.PublicKey, keysCache.UserSigningKey.PublicKey)
|
||||
mach.Log.Debug().
|
||||
Str("master", keysCache.MasterKey.PublicKey.String()).
|
||||
Str("self_signing", keysCache.SelfSigningKey.PublicKey.String()).
|
||||
Str("user_signing", keysCache.UserSigningKey.PublicKey.String()).
|
||||
Msg("Imported own cross-signing keys")
|
||||
|
||||
mach.CrossSigningKeys = &keysCache
|
||||
mach.crossSigningPubkeys = keysCache.PublicKeys()
|
||||
@@ -76,8 +80,11 @@ func (mach *OlmMachine) GenerateCrossSigningKeys() (*CrossSigningKeysCache, erro
|
||||
if keysCache.UserSigningKey, err = olm.NewPkSigning(); err != nil {
|
||||
return nil, fmt.Errorf("failed to generate user-signing key: %w", err)
|
||||
}
|
||||
mach.Log.Debug("Generated cross-signing keys: Master: `%v` Self-signing: `%v` User-signing: `%v`",
|
||||
keysCache.MasterKey.PublicKey, keysCache.SelfSigningKey.PublicKey, keysCache.UserSigningKey.PublicKey)
|
||||
mach.Log.Debug().
|
||||
Str("master", keysCache.MasterKey.PublicKey.String()).
|
||||
Str("self_signing", keysCache.SelfSigningKey.PublicKey.String()).
|
||||
Str("user_signing", keysCache.UserSigningKey.PublicKey.String()).
|
||||
Msg("Generated cross-signing keys")
|
||||
return &keysCache, nil
|
||||
}
|
||||
|
||||
|
||||
4
vendor/maunium.net/go/mautrix/crypto/cross_sign_pubkey.go
generated
vendored
4
vendor/maunium.net/go/mautrix/crypto/cross_sign_pubkey.go
generated
vendored
@@ -1,4 +1,4 @@
|
||||
// Copyright (c) 2022 Tulir Asokan
|
||||
// Copyright (c) 2023 Tulir Asokan
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||
// License, v. 2.0. If a copy of the MPL was not distributed with this
|
||||
@@ -32,7 +32,7 @@ func (mach *OlmMachine) GetOwnCrossSigningPublicKeys() *CrossSigningPublicKeysCa
|
||||
}
|
||||
cspk, err := mach.GetCrossSigningPublicKeys(mach.Client.UserID)
|
||||
if err != nil {
|
||||
mach.Log.Error("Failed to get own cross-signing public keys: %v", err)
|
||||
mach.Log.Error().Err(err).Msg("Failed to get own cross-signing public keys")
|
||||
return nil
|
||||
}
|
||||
mach.crossSigningPubkeys = cspk
|
||||
|
||||
18
vendor/maunium.net/go/mautrix/crypto/cross_sign_signing.go
generated
vendored
18
vendor/maunium.net/go/mautrix/crypto/cross_sign_signing.go
generated
vendored
@@ -1,5 +1,5 @@
|
||||
// Copyright (c) 2020 Nikos Filippakis
|
||||
// Copyright (c) 2022 Tulir Asokan
|
||||
// Copyright (c) 2023 Tulir Asokan
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||
// License, v. 2.0. If a copy of the MPL was not distributed with this
|
||||
@@ -79,7 +79,10 @@ func (mach *OlmMachine) SignUser(userID id.UserID, masterKey id.Ed25519) error {
|
||||
return err
|
||||
}
|
||||
|
||||
mach.Log.Trace("Signed master key of %s with user-signing key: `%v`", userID, signature)
|
||||
mach.Log.Debug().
|
||||
Str("user_id", userID.String()).
|
||||
Str("signature", signature).
|
||||
Msg("Signed master key of user with our user-signing key")
|
||||
|
||||
if err := mach.CryptoStore.PutSignature(userID, masterKey, mach.Client.UserID, mach.CrossSigningKeys.UserSigningKey.PublicKey, signature); err != nil {
|
||||
return fmt.Errorf("error storing signature in crypto store: %w", err)
|
||||
@@ -116,7 +119,10 @@ func (mach *OlmMachine) SignOwnMasterKey() error {
|
||||
id.NewKeyID(id.KeyAlgorithmEd25519, deviceID.String()): signature,
|
||||
},
|
||||
}
|
||||
mach.Log.Trace("Signed own master key with device %v: `%v`", deviceID, signature)
|
||||
mach.Log.Debug().
|
||||
Str("device_id", deviceID.String()).
|
||||
Str("signature", signature).
|
||||
Msg("Signed own master key with own device key")
|
||||
|
||||
resp, err := mach.Client.UploadSignatures(&mautrix.ReqUploadSignatures{
|
||||
userID: map[string]mautrix.ReqKeysSignatures{
|
||||
@@ -165,7 +171,11 @@ func (mach *OlmMachine) SignOwnDevice(device *id.Device) error {
|
||||
return err
|
||||
}
|
||||
|
||||
mach.Log.Trace("Signed own device %s with self-signing key: `%v`", device.UserID, device.DeviceID, signature)
|
||||
mach.Log.Debug().
|
||||
Str("user_id", device.UserID.String()).
|
||||
Str("device_id", device.DeviceID.String()).
|
||||
Str("signature", signature).
|
||||
Msg("Signed own device key with self-signing key")
|
||||
|
||||
if err := mach.CryptoStore.PutSignature(device.UserID, device.SigningKey, mach.Client.UserID, mach.CrossSigningKeys.SelfSigningKey.PublicKey, signature); err != nil {
|
||||
return fmt.Errorf("error storing signature in crypto store: %w", err)
|
||||
|
||||
44
vendor/maunium.net/go/mautrix/crypto/cross_sign_store.go
generated
vendored
44
vendor/maunium.net/go/mautrix/crypto/cross_sign_store.go
generated
vendored
@@ -1,5 +1,5 @@
|
||||
// Copyright (c) 2020 Nikos Filippakis
|
||||
// Copyright (c) 2022 Tulir Asokan
|
||||
// Copyright (c) 2023 Tulir Asokan
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||
// License, v. 2.0. If a copy of the MPL was not distributed with this
|
||||
@@ -8,28 +8,36 @@
|
||||
package crypto
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"maunium.net/go/mautrix"
|
||||
"maunium.net/go/mautrix/crypto/olm"
|
||||
"maunium.net/go/mautrix/id"
|
||||
)
|
||||
|
||||
func (mach *OlmMachine) storeCrossSigningKeys(crossSigningKeys map[id.UserID]mautrix.CrossSigningKeys, deviceKeys map[id.UserID]map[id.DeviceID]mautrix.DeviceKeys) {
|
||||
func (mach *OlmMachine) storeCrossSigningKeys(ctx context.Context, crossSigningKeys map[id.UserID]mautrix.CrossSigningKeys, deviceKeys map[id.UserID]map[id.DeviceID]mautrix.DeviceKeys) {
|
||||
log := mach.machOrContextLog(ctx)
|
||||
for userID, userKeys := range crossSigningKeys {
|
||||
log := log.With().Str("user_id", userID.String()).Logger()
|
||||
currentKeys, err := mach.CryptoStore.GetCrossSigningKeys(userID)
|
||||
if err != nil {
|
||||
mach.Log.Error("Error fetching current cross-signing keys of user %v: %v", userID, err)
|
||||
log.Error().Err(err).
|
||||
Msg("Error fetching current cross-signing keys of user")
|
||||
}
|
||||
if currentKeys != nil {
|
||||
for curKeyUsage, curKey := range currentKeys {
|
||||
log := log.With().Str("old_key", curKey.Key.String()).Str("old_key_usage", string(curKeyUsage)).Logger()
|
||||
// got a new key with the same usage as an existing key
|
||||
for _, newKeyUsage := range userKeys.Usage {
|
||||
if newKeyUsage == curKeyUsage {
|
||||
if _, ok := userKeys.Keys[id.NewKeyID(id.KeyAlgorithmEd25519, curKey.Key.String())]; !ok {
|
||||
// old key is not in the new key map, so we drop signatures made by it
|
||||
if count, err := mach.CryptoStore.DropSignaturesByKey(userID, curKey.Key); err != nil {
|
||||
mach.Log.Error("Error deleting old signatures made by %s (%s): %v", curKey, curKeyUsage, err)
|
||||
log.Error().Err(err).Msg("Error deleting old signatures made by user")
|
||||
} else {
|
||||
mach.Log.Debug("Dropped %d signatures made by key %s (%s) as it has been replaced", count, curKey, curKeyUsage)
|
||||
log.Debug().
|
||||
Int64("signature_count", count).
|
||||
Msg("Dropped signatures made by old key as it has been replaced")
|
||||
}
|
||||
}
|
||||
break
|
||||
@@ -39,10 +47,11 @@ func (mach *OlmMachine) storeCrossSigningKeys(crossSigningKeys map[id.UserID]mau
|
||||
}
|
||||
|
||||
for _, key := range userKeys.Keys {
|
||||
log := log.With().Str("key", key.String()).Strs("usages", strishArray(userKeys.Usage)).Logger()
|
||||
for _, usage := range userKeys.Usage {
|
||||
mach.Log.Debug("Storing cross-signing key for %s: %s (type %s)", userID, key, usage)
|
||||
log.Debug().Str("usage", string(usage)).Msg("Storing cross-signing key")
|
||||
if err = mach.CryptoStore.PutCrossSigningKey(userID, usage, key); err != nil {
|
||||
mach.Log.Error("Error storing cross-signing key: %v", err)
|
||||
log.Error().Err(err).Msg("Error storing cross-signing key")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -50,31 +59,38 @@ func (mach *OlmMachine) storeCrossSigningKeys(crossSigningKeys map[id.UserID]mau
|
||||
for signKeyID, signature := range keySigs {
|
||||
_, signKeyName := signKeyID.Parse()
|
||||
signingKey := id.Ed25519(signKeyName)
|
||||
log := log.With().
|
||||
Str("sign_key_id", signKeyID.String()).
|
||||
Str("signer_user_id", signUserID.String()).
|
||||
Str("signing_key", signingKey.String()).
|
||||
Logger()
|
||||
// if the signer is one of this user's own devices, find the key from the key ID
|
||||
if signUserID == userID {
|
||||
ownDeviceID := id.DeviceID(signKeyName)
|
||||
if ownDeviceKeys, ok := deviceKeys[userID][ownDeviceID]; ok {
|
||||
signingKey = ownDeviceKeys.Keys.GetEd25519(ownDeviceID)
|
||||
mach.Log.Trace("Treating %s as the device ID -> signing key %s", signKeyName, signingKey)
|
||||
log.Trace().
|
||||
Str("device_id", signKeyName).
|
||||
Msg("Treating key name as device ID")
|
||||
}
|
||||
}
|
||||
if len(signingKey) != 43 {
|
||||
mach.Log.Trace("Cross-signing key %s/%s/%v has a signature from an unknown key %s", userID, key, userKeys.Usage, signKeyID)
|
||||
log.Debug().Msg("Cross-signing key has a signature from an unknown key")
|
||||
continue
|
||||
}
|
||||
|
||||
mach.Log.Debug("Verifying cross-signing key %s/%s/%v with key %s/%s", userID, key, userKeys.Usage, signUserID, signingKey)
|
||||
log.Debug().Msg("Verifying cross-signing key signature")
|
||||
if verified, err := olm.VerifySignatureJSON(userKeys, signUserID, signKeyName, signingKey); err != nil {
|
||||
mach.Log.Warn("Error while verifying signature from %s for %s: %v", signingKey, key, err)
|
||||
log.Warn().Err(err).Msg("Error verifying cross-signing key signature")
|
||||
} else {
|
||||
if verified {
|
||||
mach.Log.Debug("Signature from %s for %s verified", signingKey, key)
|
||||
log.Debug().Err(err).Msg("Cross-signing key signature verified")
|
||||
err = mach.CryptoStore.PutSignature(userID, key, signUserID, signingKey, signature)
|
||||
if err != nil {
|
||||
mach.Log.Warn("Failed to store signature from %s for %s: %v", signingKey, key, err)
|
||||
log.Error().Err(err).Msg("Error storing cross-signing key signature")
|
||||
}
|
||||
} else {
|
||||
mach.Log.Error("Invalid signature from %s for %s", signingKey, key)
|
||||
log.Warn().Err(err).Msg("Cross-siging key signature is invalid")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
92
vendor/maunium.net/go/mautrix/crypto/cross_sign_validation.go
generated
vendored
92
vendor/maunium.net/go/mautrix/crypto/cross_sign_validation.go
generated
vendored
@@ -1,5 +1,5 @@
|
||||
// Copyright (c) 2020 Nikos Filippakis
|
||||
// Copyright (c) 2022 Tulir Asokan
|
||||
// Copyright (c) 2023 Tulir Asokan
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||
// License, v. 2.0. If a copy of the MPL was not distributed with this
|
||||
@@ -8,52 +8,72 @@
|
||||
package crypto
|
||||
|
||||
import (
|
||||
"context"
|
||||
|
||||
"maunium.net/go/mautrix/id"
|
||||
)
|
||||
|
||||
// ResolveTrust resolves the trust state of the device from cross-signing.
|
||||
func (mach *OlmMachine) ResolveTrust(device *id.Device) id.TrustState {
|
||||
state, _ := mach.ResolveTrustContext(context.Background(), device)
|
||||
return state
|
||||
}
|
||||
|
||||
// ResolveTrustContext resolves the trust state of the device from cross-signing.
|
||||
func (mach *OlmMachine) ResolveTrustContext(ctx context.Context, device *id.Device) (id.TrustState, error) {
|
||||
if device.Trust == id.TrustStateVerified || device.Trust == id.TrustStateBlacklisted {
|
||||
return device.Trust
|
||||
return device.Trust, nil
|
||||
}
|
||||
theirKeys, err := mach.CryptoStore.GetCrossSigningKeys(device.UserID)
|
||||
if err != nil {
|
||||
mach.Log.Error("Error retrieving cross-singing key of user %v from database: %v", device.UserID, err)
|
||||
return id.TrustStateUnset
|
||||
mach.machOrContextLog(ctx).Error().Err(err).
|
||||
Str("user_id", device.UserID.String()).
|
||||
Msg("Error retrieving cross-signing key of user from database")
|
||||
return id.TrustStateUnset, err
|
||||
}
|
||||
theirMSK, ok := theirKeys[id.XSUsageMaster]
|
||||
if !ok {
|
||||
mach.Log.Error("Master key of user %v not found", device.UserID)
|
||||
return id.TrustStateUnset
|
||||
mach.machOrContextLog(ctx).Error().
|
||||
Str("user_id", device.UserID.String()).
|
||||
Msg("Master key of user not found")
|
||||
return id.TrustStateUnset, nil
|
||||
}
|
||||
theirSSK, ok := theirKeys[id.XSUsageSelfSigning]
|
||||
if !ok {
|
||||
mach.Log.Error("Self-signing key of user %v not found", device.UserID)
|
||||
return id.TrustStateUnset
|
||||
mach.machOrContextLog(ctx).Error().
|
||||
Str("user_id", device.UserID.String()).
|
||||
Msg("Self-signing key of user not found")
|
||||
return id.TrustStateUnset, nil
|
||||
}
|
||||
sskSigExists, err := mach.CryptoStore.IsKeySignedBy(device.UserID, theirSSK.Key, device.UserID, theirMSK.Key)
|
||||
if err != nil {
|
||||
mach.Log.Error("Error retrieving cross-singing signatures for master key of user %v from database: %v", device.UserID, err)
|
||||
return id.TrustStateUnset
|
||||
mach.machOrContextLog(ctx).Error().Err(err).
|
||||
Str("user_id", device.UserID.String()).
|
||||
Msg("Error retrieving cross-signing signatures for master key of user from database")
|
||||
return id.TrustStateUnset, err
|
||||
}
|
||||
if !sskSigExists {
|
||||
mach.Log.Warn("Self-signing key of user %v is not signed by their master key", device.UserID)
|
||||
return id.TrustStateUnset
|
||||
mach.machOrContextLog(ctx).Error().
|
||||
Str("user_id", device.UserID.String()).
|
||||
Msg("Self-signing key of user is not signed by their master key")
|
||||
return id.TrustStateUnset, nil
|
||||
}
|
||||
deviceSigExists, err := mach.CryptoStore.IsKeySignedBy(device.UserID, device.SigningKey, device.UserID, theirSSK.Key)
|
||||
if err != nil {
|
||||
mach.Log.Error("Error retrieving cross-singing signatures for master key of user %v from database: %v", device.UserID, err)
|
||||
return id.TrustStateUnset
|
||||
mach.machOrContextLog(ctx).Error().Err(err).
|
||||
Str("user_id", device.UserID.String()).
|
||||
Str("device_key", device.SigningKey.String()).
|
||||
Msg("Error retrieving cross-signing signatures for device from database")
|
||||
return id.TrustStateUnset, err
|
||||
}
|
||||
if deviceSigExists {
|
||||
if mach.IsUserTrusted(device.UserID) {
|
||||
return id.TrustStateCrossSignedVerified
|
||||
if trusted, err := mach.IsUserTrusted(ctx, device.UserID); !trusted {
|
||||
return id.TrustStateCrossSignedVerified, err
|
||||
} else if theirMSK.Key == theirMSK.First {
|
||||
return id.TrustStateCrossSignedTOFU
|
||||
return id.TrustStateCrossSignedTOFU, nil
|
||||
}
|
||||
return id.TrustStateCrossSignedUntrusted
|
||||
return id.TrustStateCrossSignedUntrusted, nil
|
||||
}
|
||||
return id.TrustStateUnset
|
||||
return id.TrustStateUnset, nil
|
||||
}
|
||||
|
||||
// IsDeviceTrusted returns whether a device has been determined to be trusted either through verification or cross-signing.
|
||||
@@ -68,36 +88,42 @@ func (mach *OlmMachine) IsDeviceTrusted(device *id.Device) bool {
|
||||
|
||||
// IsUserTrusted returns whether a user has been determined to be trusted by our user-signing key having signed their master key.
|
||||
// In the case the user ID is our own and we have successfully retrieved our cross-signing keys, we trust our own user.
|
||||
func (mach *OlmMachine) IsUserTrusted(userID id.UserID) bool {
|
||||
func (mach *OlmMachine) IsUserTrusted(ctx context.Context, userID id.UserID) (bool, error) {
|
||||
csPubkeys := mach.GetOwnCrossSigningPublicKeys()
|
||||
if csPubkeys == nil {
|
||||
return false
|
||||
return false, nil
|
||||
}
|
||||
if userID == mach.Client.UserID {
|
||||
return true
|
||||
return true, nil
|
||||
}
|
||||
// first we verify our user-signing key
|
||||
ourUserSigningKeyTrusted, err := mach.CryptoStore.IsKeySignedBy(mach.Client.UserID, csPubkeys.UserSigningKey, mach.Client.UserID, csPubkeys.MasterKey)
|
||||
if err != nil {
|
||||
mach.Log.Error("Error retrieving our self-singing key signatures: %v", err)
|
||||
return false
|
||||
mach.machOrContextLog(ctx).Error().Err(err).Msg("Error retrieving our self-signing key signatures from database")
|
||||
return false, err
|
||||
} else if !ourUserSigningKeyTrusted {
|
||||
return false
|
||||
return false, nil
|
||||
}
|
||||
theirKeys, err := mach.CryptoStore.GetCrossSigningKeys(userID)
|
||||
if err != nil {
|
||||
mach.Log.Error("Error retrieving cross-singing key of user %v from database: %v", userID, err)
|
||||
return false
|
||||
mach.machOrContextLog(ctx).Error().Err(err).
|
||||
Str("user_id", userID.String()).
|
||||
Msg("Error retrieving cross-signing key of user from database")
|
||||
return false, err
|
||||
}
|
||||
theirMskKey, ok := theirKeys[id.XSUsageMaster]
|
||||
if !ok {
|
||||
mach.Log.Error("Master key of user %v not found", userID)
|
||||
return false
|
||||
mach.machOrContextLog(ctx).Error().
|
||||
Str("user_id", userID.String()).
|
||||
Msg("Master key of user not found")
|
||||
return false, nil
|
||||
}
|
||||
sigExists, err := mach.CryptoStore.IsKeySignedBy(userID, theirMskKey.Key, mach.Client.UserID, csPubkeys.UserSigningKey)
|
||||
if err != nil {
|
||||
mach.Log.Error("Error retrieving cross-singing signatures for master key of user %v from database: %v", userID, err)
|
||||
return false
|
||||
mach.machOrContextLog(ctx).Error().Err(err).
|
||||
Str("user_id", userID.String()).
|
||||
Msg("Error retrieving cross-signing signatures for master key of user from database")
|
||||
return false, err
|
||||
}
|
||||
return sigExists
|
||||
return sigExists, nil
|
||||
}
|
||||
|
||||
374
vendor/maunium.net/go/mautrix/crypto/cryptohelper/cryptohelper.go
generated
vendored
Normal file
374
vendor/maunium.net/go/mautrix/crypto/cryptohelper/cryptohelper.go
generated
vendored
Normal file
@@ -0,0 +1,374 @@
|
||||
// Copyright (c) 2023 Tulir Asokan
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||
// License, v. 2.0. If a copy of the MPL was not distributed with this
|
||||
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
|
||||
package cryptohelper
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
|
||||
"maunium.net/go/mautrix"
|
||||
"maunium.net/go/mautrix/crypto"
|
||||
"maunium.net/go/mautrix/event"
|
||||
"maunium.net/go/mautrix/id"
|
||||
"maunium.net/go/mautrix/sqlstatestore"
|
||||
"maunium.net/go/mautrix/util/dbutil"
|
||||
)
|
||||
|
||||
type CryptoHelper struct {
|
||||
client *mautrix.Client
|
||||
mach *crypto.OlmMachine
|
||||
log zerolog.Logger
|
||||
lock sync.RWMutex
|
||||
pickleKey []byte
|
||||
|
||||
managedStateStore *sqlstatestore.SQLStateStore
|
||||
unmanagedCryptoStore crypto.Store
|
||||
dbForManagedStores *dbutil.Database
|
||||
|
||||
DecryptErrorCallback func(*event.Event, error)
|
||||
|
||||
LoginAs *mautrix.ReqLogin
|
||||
|
||||
DBAccountID string
|
||||
}
|
||||
|
||||
var _ mautrix.CryptoHelper = (*CryptoHelper)(nil)
|
||||
|
||||
// NewCryptoHelper creates a struct that helps a mautrix client struct with Matrix e2ee operations.
|
||||
//
|
||||
// The client and pickle key are always required. Additionally, you must either:
|
||||
// - Provide a crypto.Store here and set a StateStore in the client, or
|
||||
// - Provide a dbutil.Database here to automatically create missing stores.
|
||||
// - Provide a string here to use it as a path to a SQLite database, and then automatically create missing stores.
|
||||
//
|
||||
// The same database may be shared across multiple clients, but note that doing that will allow all clients access to
|
||||
// decryption keys received by any one of the clients. For that reason, the pickle key must also be same for all clients
|
||||
// using the same database.
|
||||
func NewCryptoHelper(cli *mautrix.Client, pickleKey []byte, store any) (*CryptoHelper, error) {
|
||||
if len(pickleKey) == 0 {
|
||||
return nil, fmt.Errorf("pickle key must be provided")
|
||||
}
|
||||
_, isExtensible := cli.Syncer.(mautrix.ExtensibleSyncer)
|
||||
if !isExtensible {
|
||||
return nil, fmt.Errorf("the client syncer must implement ExtensibleSyncer")
|
||||
}
|
||||
|
||||
var managedStateStore *sqlstatestore.SQLStateStore
|
||||
var dbForManagedStores *dbutil.Database
|
||||
var unmanagedCryptoStore crypto.Store
|
||||
switch typedStore := store.(type) {
|
||||
case crypto.Store:
|
||||
if cli.StateStore == nil {
|
||||
return nil, fmt.Errorf("when passing a crypto.Store to NewCryptoHelper, the client must have a state store set beforehand")
|
||||
} else if _, isCryptoCompatible := cli.StateStore.(crypto.StateStore); !isCryptoCompatible {
|
||||
return nil, fmt.Errorf("the client state store must implement crypto.StateStore")
|
||||
}
|
||||
unmanagedCryptoStore = typedStore
|
||||
case string:
|
||||
db, err := dbutil.NewWithDialect(typedStore, "sqlite3")
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
dbForManagedStores = db
|
||||
case *dbutil.Database:
|
||||
dbForManagedStores = typedStore
|
||||
default:
|
||||
return nil, fmt.Errorf("you must pass a *dbutil.Database or *crypto.StateStore to NewCryptoHelper")
|
||||
}
|
||||
log := cli.Log.With().Str("component", "crypto").Logger()
|
||||
if cli.StateStore == nil && dbForManagedStores != nil {
|
||||
managedStateStore = sqlstatestore.NewSQLStateStore(dbForManagedStores, dbutil.ZeroLogger(log.With().Str("db_section", "matrix_state").Logger()), false)
|
||||
cli.StateStore = managedStateStore
|
||||
} else if _, isCryptoCompatible := cli.StateStore.(crypto.StateStore); !isCryptoCompatible {
|
||||
return nil, fmt.Errorf("the client state store must implement crypto.StateStore")
|
||||
}
|
||||
|
||||
return &CryptoHelper{
|
||||
client: cli,
|
||||
log: log,
|
||||
pickleKey: pickleKey,
|
||||
|
||||
unmanagedCryptoStore: unmanagedCryptoStore,
|
||||
managedStateStore: managedStateStore,
|
||||
dbForManagedStores: dbForManagedStores,
|
||||
|
||||
DecryptErrorCallback: func(_ *event.Event, _ error) {},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (helper *CryptoHelper) Init() error {
|
||||
if helper == nil {
|
||||
return fmt.Errorf("crypto helper is nil")
|
||||
}
|
||||
syncer, ok := helper.client.Syncer.(mautrix.ExtensibleSyncer)
|
||||
if !ok {
|
||||
return fmt.Errorf("the client syncer must implement ExtensibleSyncer")
|
||||
}
|
||||
|
||||
var stateStore crypto.StateStore
|
||||
if helper.managedStateStore != nil {
|
||||
err := helper.managedStateStore.Upgrade()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to upgrade client state store: %w", err)
|
||||
}
|
||||
stateStore = helper.managedStateStore
|
||||
} else {
|
||||
stateStore = helper.client.StateStore.(crypto.StateStore)
|
||||
}
|
||||
var cryptoStore crypto.Store
|
||||
if helper.unmanagedCryptoStore == nil {
|
||||
managedCryptoStore := crypto.NewSQLCryptoStore(helper.dbForManagedStores, dbutil.ZeroLogger(helper.log.With().Str("db_section", "crypto").Logger()), helper.DBAccountID, helper.client.DeviceID, helper.pickleKey)
|
||||
if helper.client.Store == nil {
|
||||
helper.client.Store = managedCryptoStore
|
||||
} else if _, isMemory := helper.client.Store.(*mautrix.MemorySyncStore); isMemory {
|
||||
helper.client.Store = managedCryptoStore
|
||||
}
|
||||
err := managedCryptoStore.DB.Upgrade()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to upgrade crypto state store: %w", err)
|
||||
}
|
||||
storedDeviceID := managedCryptoStore.FindDeviceID()
|
||||
if helper.LoginAs != nil {
|
||||
if storedDeviceID != "" {
|
||||
helper.LoginAs.DeviceID = storedDeviceID
|
||||
}
|
||||
helper.LoginAs.StoreCredentials = true
|
||||
helper.log.Debug().
|
||||
Str("username", helper.LoginAs.Identifier.User).
|
||||
Str("device_id", helper.LoginAs.DeviceID.String()).
|
||||
Msg("Logging in")
|
||||
_, err = helper.client.Login(helper.LoginAs)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
if storedDeviceID == "" {
|
||||
managedCryptoStore.DeviceID = helper.client.DeviceID
|
||||
}
|
||||
} else if storedDeviceID != "" && storedDeviceID != helper.client.DeviceID {
|
||||
return fmt.Errorf("mismatching device ID in client and crypto store (%q != %q)", storedDeviceID, helper.client.DeviceID)
|
||||
}
|
||||
cryptoStore = managedCryptoStore
|
||||
} else {
|
||||
if helper.LoginAs != nil {
|
||||
return fmt.Errorf("LoginAs can only be used with a managed crypto store")
|
||||
}
|
||||
cryptoStore = helper.unmanagedCryptoStore
|
||||
}
|
||||
if helper.client.DeviceID == "" || helper.client.UserID == "" {
|
||||
return fmt.Errorf("the client must be logged in")
|
||||
}
|
||||
helper.mach = crypto.NewOlmMachine(helper.client, &helper.log, cryptoStore, stateStore)
|
||||
err := helper.mach.Load()
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to load olm account: %w", err)
|
||||
} else if err = helper.verifyDeviceKeysOnServer(); err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
syncer.OnSync(helper.mach.ProcessSyncResponse)
|
||||
syncer.OnEventType(event.StateMember, helper.mach.HandleMemberEvent)
|
||||
if _, ok = helper.client.Syncer.(mautrix.DispatchableSyncer); ok {
|
||||
syncer.OnEventType(event.EventEncrypted, helper.HandleEncrypted)
|
||||
} else {
|
||||
helper.log.Warn().Msg("Client syncer does not implement DispatchableSyncer. Events will not be decrypted automatically.")
|
||||
}
|
||||
if helper.managedStateStore != nil {
|
||||
syncer.OnEvent(helper.client.StateStoreSyncHandler)
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (helper *CryptoHelper) Close() error {
|
||||
if helper != nil && helper.dbForManagedStores != nil {
|
||||
err := helper.dbForManagedStores.RawDB.Close()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (helper *CryptoHelper) Machine() *crypto.OlmMachine {
|
||||
if helper == nil || helper.mach == nil {
|
||||
panic("Machine() called before initing CryptoHelper")
|
||||
}
|
||||
return helper.mach
|
||||
}
|
||||
|
||||
func (helper *CryptoHelper) verifyDeviceKeysOnServer() error {
|
||||
helper.log.Debug().Msg("Making sure our device has the expected keys on the server")
|
||||
resp, err := helper.client.QueryKeys(&mautrix.ReqQueryKeys{
|
||||
DeviceKeys: map[id.UserID]mautrix.DeviceIDList{
|
||||
helper.client.UserID: {helper.client.DeviceID},
|
||||
},
|
||||
})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to query own keys to make sure device is properly configured: %w", err)
|
||||
}
|
||||
ownID := helper.mach.OwnIdentity()
|
||||
isShared := helper.mach.GetAccount().Shared
|
||||
device, ok := resp.DeviceKeys[helper.client.UserID][helper.client.DeviceID]
|
||||
if !ok || len(device.Keys) == 0 {
|
||||
if isShared {
|
||||
return fmt.Errorf("olm account is marked as shared, keys seem to have disappeared from the server")
|
||||
} else {
|
||||
helper.log.Debug().Msg("Olm account not shared and keys not on server, so device is probably fine")
|
||||
return nil
|
||||
}
|
||||
} else if !isShared {
|
||||
return fmt.Errorf("olm account is not marked as shared, but there are keys on the server")
|
||||
} else if ed := device.Keys.GetEd25519(helper.client.DeviceID); ownID.SigningKey != ed {
|
||||
return fmt.Errorf("mismatching identity key on server (%q != %q)", ownID.SigningKey, ed)
|
||||
}
|
||||
if !isShared {
|
||||
helper.log.Debug().Msg("Olm account not marked as shared, but keys on server match?")
|
||||
} else {
|
||||
helper.log.Debug().Msg("Olm account marked as shared and keys on server match, device is fine")
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
var NoSessionFound = crypto.NoSessionFound
|
||||
|
||||
const initialSessionWaitTimeout = 3 * time.Second
|
||||
const extendedSessionWaitTimeout = 22 * time.Second
|
||||
|
||||
func (helper *CryptoHelper) HandleEncrypted(src mautrix.EventSource, evt *event.Event) {
|
||||
if helper == nil {
|
||||
return
|
||||
}
|
||||
content := evt.Content.AsEncrypted()
|
||||
log := helper.log.With().
|
||||
Str("event_id", evt.ID.String()).
|
||||
Str("session_id", content.SessionID.String()).
|
||||
Logger()
|
||||
log.Debug().Msg("Decrypting received event")
|
||||
|
||||
decrypted, err := helper.Decrypt(evt)
|
||||
if errors.Is(err, NoSessionFound) {
|
||||
log.Debug().
|
||||
Int("wait_seconds", int(initialSessionWaitTimeout.Seconds())).
|
||||
Msg("Couldn't find session, waiting for keys to arrive...")
|
||||
if helper.mach.WaitForSession(evt.RoomID, content.SenderKey, content.SessionID, initialSessionWaitTimeout) {
|
||||
log.Debug().Msg("Got keys after waiting, trying to decrypt event again")
|
||||
decrypted, err = helper.Decrypt(evt)
|
||||
} else {
|
||||
go helper.waitLongerForSession(log, src, evt)
|
||||
return
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
log.Warn().Err(err).Msg("Failed to decrypt event")
|
||||
helper.DecryptErrorCallback(evt, err)
|
||||
return
|
||||
}
|
||||
helper.postDecrypt(src, decrypted)
|
||||
}
|
||||
|
||||
func (helper *CryptoHelper) postDecrypt(src mautrix.EventSource, decrypted *event.Event) {
|
||||
helper.client.Syncer.(mautrix.DispatchableSyncer).Dispatch(src|mautrix.EventSourceDecrypted, decrypted)
|
||||
}
|
||||
|
||||
func (helper *CryptoHelper) RequestSession(roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID, userID id.UserID, deviceID id.DeviceID) {
|
||||
if helper == nil {
|
||||
return
|
||||
}
|
||||
helper.lock.RLock()
|
||||
defer helper.lock.RUnlock()
|
||||
if deviceID == "" {
|
||||
deviceID = "*"
|
||||
}
|
||||
// TODO get log from context
|
||||
log := helper.log.With().
|
||||
Str("session_id", sessionID.String()).
|
||||
Str("user_id", userID.String()).
|
||||
Str("device_id", deviceID.String()).
|
||||
Str("room_id", roomID.String()).
|
||||
Logger()
|
||||
err := helper.mach.SendRoomKeyRequest(roomID, senderKey, sessionID, "", map[id.UserID][]id.DeviceID{
|
||||
userID: {deviceID},
|
||||
helper.client.UserID: {"*"},
|
||||
})
|
||||
if err != nil {
|
||||
log.Warn().Err(err).Msg("Failed to send key request")
|
||||
} else {
|
||||
log.Debug().Msg("Sent key request")
|
||||
}
|
||||
}
|
||||
|
||||
func (helper *CryptoHelper) waitLongerForSession(log zerolog.Logger, src mautrix.EventSource, evt *event.Event) {
|
||||
content := evt.Content.AsEncrypted()
|
||||
log.Debug().Int("wait_seconds", int(extendedSessionWaitTimeout.Seconds())).Msg("Couldn't find session, requesting keys and waiting longer...")
|
||||
|
||||
go helper.RequestSession(evt.RoomID, content.SenderKey, content.SessionID, evt.Sender, content.DeviceID)
|
||||
|
||||
if !helper.mach.WaitForSession(evt.RoomID, content.SenderKey, content.SessionID, extendedSessionWaitTimeout) {
|
||||
log.Debug().Msg("Didn't get session, giving up")
|
||||
helper.DecryptErrorCallback(evt, NoSessionFound)
|
||||
return
|
||||
}
|
||||
|
||||
log.Debug().Msg("Got keys after waiting longer, trying to decrypt event again")
|
||||
decrypted, err := helper.Decrypt(evt)
|
||||
if err != nil {
|
||||
log.Error().Err(err).Msg("Failed to decrypt event")
|
||||
helper.DecryptErrorCallback(evt, err)
|
||||
return
|
||||
}
|
||||
|
||||
helper.postDecrypt(src, decrypted)
|
||||
}
|
||||
|
||||
func (helper *CryptoHelper) WaitForSession(roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID, timeout time.Duration) bool {
|
||||
if helper == nil {
|
||||
return false
|
||||
}
|
||||
helper.lock.RLock()
|
||||
defer helper.lock.RUnlock()
|
||||
return helper.mach.WaitForSession(roomID, senderKey, sessionID, timeout)
|
||||
}
|
||||
|
||||
func (helper *CryptoHelper) Decrypt(evt *event.Event) (*event.Event, error) {
|
||||
if helper == nil {
|
||||
return nil, fmt.Errorf("crypto helper is nil")
|
||||
}
|
||||
return helper.mach.DecryptMegolmEvent(context.TODO(), evt)
|
||||
}
|
||||
|
||||
func (helper *CryptoHelper) Encrypt(roomID id.RoomID, evtType event.Type, content any) (encrypted *event.EncryptedEventContent, err error) {
|
||||
if helper == nil {
|
||||
return nil, fmt.Errorf("crypto helper is nil")
|
||||
}
|
||||
helper.lock.RLock()
|
||||
defer helper.lock.RUnlock()
|
||||
ctx := context.TODO()
|
||||
encrypted, err = helper.mach.EncryptMegolmEvent(ctx, roomID, evtType, content)
|
||||
if err != nil {
|
||||
if err != crypto.SessionExpired && err != crypto.SessionNotShared && err != crypto.NoGroupSession {
|
||||
return
|
||||
}
|
||||
helper.log.Debug().
|
||||
Err(err).
|
||||
Str("room_id", roomID.String()).
|
||||
Msg("Got session error while encrypting event, sharing group session and trying again")
|
||||
var users []id.UserID
|
||||
users, err = helper.client.StateStore.GetRoomJoinedOrInvitedMembers(roomID)
|
||||
if err != nil {
|
||||
err = fmt.Errorf("failed to get room member list: %w", err)
|
||||
} else if err = helper.mach.ShareGroupSession(ctx, roomID, users); err != nil {
|
||||
err = fmt.Errorf("failed to share group session: %w", err)
|
||||
} else if encrypted, err = helper.mach.EncryptMegolmEvent(ctx, roomID, evtType, content); err != nil {
|
||||
err = fmt.Errorf("failed to encrypt event after re-sharing group session: %w", err)
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
154
vendor/maunium.net/go/mautrix/crypto/decryptmegolm.go
generated
vendored
154
vendor/maunium.net/go/mautrix/crypto/decryptmegolm.go
generated
vendored
@@ -1,4 +1,4 @@
|
||||
// Copyright (c) 2020 Tulir Asokan
|
||||
// Copyright (c) 2023 Tulir Asokan
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||
// License, v. 2.0. If a copy of the MPL was not distributed with this
|
||||
@@ -7,11 +7,14 @@
|
||||
package crypto
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
|
||||
"maunium.net/go/mautrix/event"
|
||||
"maunium.net/go/mautrix/id"
|
||||
)
|
||||
@@ -23,6 +26,7 @@ var (
|
||||
WrongRoom = errors.New("encrypted megolm event is not intended for this room")
|
||||
DeviceKeyMismatch = errors.New("device keys in event and verified device info do not match")
|
||||
SenderKeyMismatch = errors.New("sender keys in content and megolm session do not match")
|
||||
RatchetError = errors.New("failed to ratchet session after use")
|
||||
)
|
||||
|
||||
type megolmEvent struct {
|
||||
@@ -32,13 +36,21 @@ type megolmEvent struct {
|
||||
}
|
||||
|
||||
// DecryptMegolmEvent decrypts an m.room.encrypted event where the algorithm is m.megolm.v1.aes-sha2
|
||||
func (mach *OlmMachine) DecryptMegolmEvent(evt *event.Event) (*event.Event, error) {
|
||||
func (mach *OlmMachine) DecryptMegolmEvent(ctx context.Context, evt *event.Event) (*event.Event, error) {
|
||||
content, ok := evt.Content.Parsed.(*event.EncryptedEventContent)
|
||||
if !ok {
|
||||
return nil, IncorrectEncryptedContentType
|
||||
} else if content.Algorithm != id.AlgorithmMegolmV1 {
|
||||
return nil, UnsupportedAlgorithm
|
||||
}
|
||||
log := mach.machOrContextLog(ctx).With().
|
||||
Str("action", "decrypt megolm event").
|
||||
Str("event_id", evt.ID.String()).
|
||||
Str("sender", evt.Sender.String()).
|
||||
Str("sender_key", content.SenderKey.String()).
|
||||
Str("session_id", content.SessionID.String()).
|
||||
Logger()
|
||||
ctx = log.WithContext(ctx)
|
||||
encryptionRoomID := evt.RoomID
|
||||
// Allow the server to move encrypted events between rooms if both the real room and target room are on a non-federatable .local domain.
|
||||
// The message index checks to prevent replay attacks still apply and aren't based on the room ID,
|
||||
@@ -46,22 +58,11 @@ func (mach *OlmMachine) DecryptMegolmEvent(evt *event.Event) (*event.Event, erro
|
||||
if origRoomID, ok := evt.Content.Raw["com.beeper.original_room_id"].(string); ok && strings.HasSuffix(origRoomID, ".local") && strings.HasSuffix(evt.RoomID.String(), ".local") {
|
||||
encryptionRoomID = id.RoomID(origRoomID)
|
||||
}
|
||||
sess, err := mach.CryptoStore.GetGroupSession(encryptionRoomID, content.SenderKey, content.SessionID)
|
||||
sess, plaintext, messageIndex, err := mach.actuallyDecryptMegolmEvent(ctx, evt, encryptionRoomID, content)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get group session: %w", err)
|
||||
} else if sess == nil {
|
||||
return nil, fmt.Errorf("%w (ID %s)", NoSessionFound, content.SessionID)
|
||||
} else if content.SenderKey != "" && content.SenderKey != sess.SenderKey {
|
||||
return nil, SenderKeyMismatch
|
||||
}
|
||||
plaintext, messageIndex, err := sess.Internal.Decrypt(content.MegolmCiphertext)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to decrypt megolm event: %w", err)
|
||||
} else if ok, err = mach.CryptoStore.ValidateMessageIndex(sess.SenderKey, content.SessionID, evt.ID, messageIndex, evt.Timestamp); err != nil {
|
||||
return nil, fmt.Errorf("failed to check if message index is duplicate: %w", err)
|
||||
} else if !ok {
|
||||
return nil, DuplicateMessageIndex
|
||||
return nil, err
|
||||
}
|
||||
log = log.With().Uint("message_index", messageIndex).Logger()
|
||||
|
||||
var trustLevel id.TrustState
|
||||
var forwardedKeys bool
|
||||
@@ -70,14 +71,16 @@ func (mach *OlmMachine) DecryptMegolmEvent(evt *event.Event) (*event.Event, erro
|
||||
if sess.SigningKey == ownSigningKey && sess.SenderKey == ownIdentityKey && len(sess.ForwardingChains) == 0 {
|
||||
trustLevel = id.TrustStateVerified
|
||||
} else {
|
||||
device, err = mach.GetOrFetchDeviceByKey(evt.Sender, sess.SenderKey)
|
||||
device, err = mach.GetOrFetchDeviceByKey(ctx, evt.Sender, sess.SenderKey)
|
||||
if err != nil {
|
||||
// We don't want to throw these errors as the message can still be decrypted.
|
||||
mach.Log.Debug("Failed to get device %s/%s to verify session %s: %v", evt.Sender, sess.SenderKey, sess.ID(), err)
|
||||
log.Debug().Err(err).Msg("Failed to get device to verify session")
|
||||
trustLevel = id.TrustStateUnknownDevice
|
||||
} else if len(sess.ForwardingChains) == 0 || (len(sess.ForwardingChains) == 1 && sess.ForwardingChains[0] == sess.SenderKey.String()) {
|
||||
if device == nil {
|
||||
mach.Log.Debug("Couldn't resolve trust level of session %s: sent by unknown device %s/%s", sess.ID(), evt.Sender, sess.SenderKey)
|
||||
log.Debug().Err(err).
|
||||
Str("session_sender_key", sess.SenderKey.String()).
|
||||
Msg("Couldn't resolve trust level of session: sent by unknown device")
|
||||
trustLevel = id.TrustStateUnknownDevice
|
||||
} else if device.SigningKey != sess.SigningKey || device.IdentityKey != sess.SenderKey {
|
||||
return nil, DeviceKeyMismatch
|
||||
@@ -91,7 +94,9 @@ func (mach *OlmMachine) DecryptMegolmEvent(evt *event.Event) (*event.Event, erro
|
||||
if device != nil {
|
||||
trustLevel = mach.ResolveTrust(device)
|
||||
} else {
|
||||
mach.Log.Debug("Couldn't resolve trust level of session %s: forwarding chain ends with unknown device %s", sess.ID(), lastChainItem)
|
||||
log.Debug().
|
||||
Str("forward_last_sender_key", lastChainItem).
|
||||
Msg("Couldn't resolve trust level of session: forwarding chain ends with unknown device")
|
||||
trustLevel = id.TrustStateForwarded
|
||||
}
|
||||
}
|
||||
@@ -105,10 +110,11 @@ func (mach *OlmMachine) DecryptMegolmEvent(evt *event.Event) (*event.Event, erro
|
||||
return nil, WrongRoom
|
||||
}
|
||||
megolmEvt.Type.Class = evt.Type.Class
|
||||
log = log.With().Str("decrypted_event_type", megolmEvt.Type.Repr()).Logger()
|
||||
err = megolmEvt.Content.ParseRaw(megolmEvt.Type)
|
||||
if err != nil {
|
||||
if errors.Is(err, event.ErrUnsupportedContentType) {
|
||||
mach.Log.Warn("Unsupported event type %s in encrypted event %s", megolmEvt.Type.Repr(), evt.ID)
|
||||
log.Warn().Msg("Unsupported event type in encrypted event")
|
||||
} else {
|
||||
return nil, fmt.Errorf("failed to parse content of megolm payload event: %w", err)
|
||||
}
|
||||
@@ -119,13 +125,14 @@ func (mach *OlmMachine) DecryptMegolmEvent(evt *event.Event) (*event.Event, erro
|
||||
if relatable.OptionalGetRelatesTo() == nil {
|
||||
relatable.SetRelatesTo(content.RelatesTo)
|
||||
} else {
|
||||
mach.Log.Trace("Not overriding relation data in %s, as encrypted payload already has it", evt.ID)
|
||||
log.Trace().Msg("Not overriding relation data as encrypted payload already has it")
|
||||
}
|
||||
}
|
||||
if _, hasRelation := megolmEvt.Content.Raw["m.relates_to"]; !hasRelation {
|
||||
megolmEvt.Content.Raw["m.relates_to"] = evt.Content.Raw["m.relates_to"]
|
||||
}
|
||||
}
|
||||
log.Debug().Msg("Event decrypted successfully")
|
||||
megolmEvt.Type.Class = evt.Type.Class
|
||||
return &event.Event{
|
||||
Sender: evt.Sender,
|
||||
@@ -144,3 +151,106 @@ func (mach *OlmMachine) DecryptMegolmEvent(evt *event.Event) (*event.Event, erro
|
||||
},
|
||||
}, nil
|
||||
}
|
||||
|
||||
func removeItem(slice []uint, item uint) ([]uint, bool) {
|
||||
for i, s := range slice {
|
||||
if s == item {
|
||||
return append(slice[:i], slice[i+1:]...), true
|
||||
}
|
||||
}
|
||||
return slice, false
|
||||
}
|
||||
|
||||
const missedIndexCutoff = 10
|
||||
|
||||
func (mach *OlmMachine) actuallyDecryptMegolmEvent(ctx context.Context, evt *event.Event, encryptionRoomID id.RoomID, content *event.EncryptedEventContent) (*InboundGroupSession, []byte, uint, error) {
|
||||
mach.megolmDecryptLock.Lock()
|
||||
defer mach.megolmDecryptLock.Unlock()
|
||||
|
||||
sess, err := mach.CryptoStore.GetGroupSession(encryptionRoomID, content.SenderKey, content.SessionID)
|
||||
if err != nil {
|
||||
return nil, nil, 0, fmt.Errorf("failed to get group session: %w", err)
|
||||
} else if sess == nil {
|
||||
return nil, nil, 0, fmt.Errorf("%w (ID %s)", NoSessionFound, content.SessionID)
|
||||
} else if content.SenderKey != "" && content.SenderKey != sess.SenderKey {
|
||||
return sess, nil, 0, SenderKeyMismatch
|
||||
}
|
||||
plaintext, messageIndex, err := sess.Internal.Decrypt(content.MegolmCiphertext)
|
||||
if err != nil {
|
||||
return sess, nil, 0, fmt.Errorf("failed to decrypt megolm event: %w", err)
|
||||
} else if ok, err := mach.CryptoStore.ValidateMessageIndex(ctx, sess.SenderKey, content.SessionID, evt.ID, messageIndex, evt.Timestamp); err != nil {
|
||||
return sess, nil, messageIndex, fmt.Errorf("failed to check if message index is duplicate: %w", err)
|
||||
} else if !ok {
|
||||
return sess, nil, messageIndex, DuplicateMessageIndex
|
||||
}
|
||||
|
||||
expectedMessageIndex := sess.RatchetSafety.NextIndex
|
||||
didModify := false
|
||||
switch {
|
||||
case messageIndex > expectedMessageIndex:
|
||||
// When the index jumps, add indices in between to the missed indices list.
|
||||
for i := expectedMessageIndex; i < messageIndex; i++ {
|
||||
sess.RatchetSafety.MissedIndices = append(sess.RatchetSafety.MissedIndices, i)
|
||||
}
|
||||
fallthrough
|
||||
case messageIndex == expectedMessageIndex:
|
||||
// When the index moves forward (to the next one or jumping ahead), update the last received index.
|
||||
sess.RatchetSafety.NextIndex = messageIndex + 1
|
||||
didModify = true
|
||||
default:
|
||||
sess.RatchetSafety.MissedIndices, didModify = removeItem(sess.RatchetSafety.MissedIndices, messageIndex)
|
||||
}
|
||||
// Use presence of ReceivedAt as a sign that this is a recent megolm session,
|
||||
// and therefore it's safe to drop missed indices entirely.
|
||||
if !sess.ReceivedAt.IsZero() && len(sess.RatchetSafety.MissedIndices) > 0 && int(sess.RatchetSafety.MissedIndices[0]) < int(sess.RatchetSafety.NextIndex)-missedIndexCutoff {
|
||||
limit := sess.RatchetSafety.NextIndex - missedIndexCutoff
|
||||
var cutoff int
|
||||
for ; cutoff < len(sess.RatchetSafety.MissedIndices) && sess.RatchetSafety.MissedIndices[cutoff] < limit; cutoff++ {
|
||||
}
|
||||
sess.RatchetSafety.LostIndices = append(sess.RatchetSafety.LostIndices, sess.RatchetSafety.MissedIndices[:cutoff]...)
|
||||
sess.RatchetSafety.MissedIndices = sess.RatchetSafety.MissedIndices[cutoff:]
|
||||
didModify = true
|
||||
}
|
||||
ratchetTargetIndex := uint32(sess.RatchetSafety.NextIndex)
|
||||
if len(sess.RatchetSafety.MissedIndices) > 0 {
|
||||
ratchetTargetIndex = uint32(sess.RatchetSafety.MissedIndices[0])
|
||||
}
|
||||
ratchetCurrentIndex := sess.Internal.FirstKnownIndex()
|
||||
log := zerolog.Ctx(ctx).With().
|
||||
Uint32("prev_ratchet_index", ratchetCurrentIndex).
|
||||
Uint32("new_ratchet_index", ratchetTargetIndex).
|
||||
Uint("next_new_index", sess.RatchetSafety.NextIndex).
|
||||
Uints("missed_indices", sess.RatchetSafety.MissedIndices).
|
||||
Uints("lost_indices", sess.RatchetSafety.LostIndices).
|
||||
Int("max_messages", sess.MaxMessages).
|
||||
Logger()
|
||||
if sess.MaxMessages > 0 && int(ratchetTargetIndex) >= sess.MaxMessages && len(sess.RatchetSafety.MissedIndices) == 0 && mach.DeleteFullyUsedKeysOnDecrypt {
|
||||
err = mach.CryptoStore.RedactGroupSession(sess.RoomID, sess.SenderKey, sess.ID(), "maximum messages reached")
|
||||
if err != nil {
|
||||
log.Err(err).Msg("Failed to delete fully used session")
|
||||
return sess, plaintext, messageIndex, RatchetError
|
||||
} else {
|
||||
log.Info().Msg("Deleted fully used session")
|
||||
}
|
||||
} else if ratchetCurrentIndex < ratchetTargetIndex && mach.RatchetKeysOnDecrypt {
|
||||
if err = sess.RatchetTo(ratchetTargetIndex); err != nil {
|
||||
log.Err(err).Msg("Failed to ratchet session")
|
||||
return sess, plaintext, messageIndex, RatchetError
|
||||
} else if err = mach.CryptoStore.PutGroupSession(sess.RoomID, sess.SenderKey, sess.ID(), sess); err != nil {
|
||||
log.Err(err).Msg("Failed to store ratcheted session")
|
||||
return sess, plaintext, messageIndex, RatchetError
|
||||
} else {
|
||||
log.Info().Msg("Ratcheted session forward")
|
||||
}
|
||||
} else if didModify {
|
||||
if err = mach.CryptoStore.PutGroupSession(sess.RoomID, sess.SenderKey, sess.ID(), sess); err != nil {
|
||||
log.Err(err).Msg("Failed to store updated ratchet safety data")
|
||||
return sess, plaintext, messageIndex, RatchetError
|
||||
} else {
|
||||
log.Debug().Msg("Ratchet safety data changed (ratchet state didn't change)")
|
||||
}
|
||||
} else {
|
||||
log.Debug().Msg("Ratchet safety data didn't change")
|
||||
}
|
||||
return sess, plaintext, messageIndex, nil
|
||||
}
|
||||
|
||||
95
vendor/maunium.net/go/mautrix/crypto/decryptolm.go
generated
vendored
95
vendor/maunium.net/go/mautrix/crypto/decryptolm.go
generated
vendored
@@ -1,4 +1,4 @@
|
||||
// Copyright (c) 2021 Tulir Asokan
|
||||
// Copyright (c) 2023 Tulir Asokan
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||
// License, v. 2.0. If a copy of the MPL was not distributed with this
|
||||
@@ -7,11 +7,14 @@
|
||||
package crypto
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
|
||||
"maunium.net/go/mautrix/event"
|
||||
"maunium.net/go/mautrix/id"
|
||||
)
|
||||
@@ -43,7 +46,7 @@ type DecryptedOlmEvent struct {
|
||||
Content event.Content `json:"content"`
|
||||
}
|
||||
|
||||
func (mach *OlmMachine) decryptOlmEvent(evt *event.Event, traceID string) (*DecryptedOlmEvent, error) {
|
||||
func (mach *OlmMachine) decryptOlmEvent(ctx context.Context, evt *event.Event) (*DecryptedOlmEvent, error) {
|
||||
content, ok := evt.Content.Parsed.(*event.EncryptedEventContent)
|
||||
if !ok {
|
||||
return nil, IncorrectEncryptedContentType
|
||||
@@ -54,7 +57,7 @@ func (mach *OlmMachine) decryptOlmEvent(evt *event.Event, traceID string) (*Decr
|
||||
if !ok {
|
||||
return nil, NotEncryptedForMe
|
||||
}
|
||||
decrypted, err := mach.decryptAndParseOlmCiphertext(evt.Sender, content.SenderKey, ownContent.Type, ownContent.Body, traceID)
|
||||
decrypted, err := mach.decryptAndParseOlmCiphertext(ctx, evt.Sender, content.SenderKey, ownContent.Type, ownContent.Body)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -66,19 +69,19 @@ type OlmEventKeys struct {
|
||||
Ed25519 id.Ed25519 `json:"ed25519"`
|
||||
}
|
||||
|
||||
func (mach *OlmMachine) decryptAndParseOlmCiphertext(sender id.UserID, senderKey id.SenderKey, olmType id.OlmMsgType, ciphertext string, traceID string) (*DecryptedOlmEvent, error) {
|
||||
func (mach *OlmMachine) decryptAndParseOlmCiphertext(ctx context.Context, sender id.UserID, senderKey id.SenderKey, olmType id.OlmMsgType, ciphertext string) (*DecryptedOlmEvent, error) {
|
||||
if olmType != id.OlmMsgTypePreKey && olmType != id.OlmMsgTypeMsg {
|
||||
return nil, UnsupportedOlmMessageType
|
||||
}
|
||||
|
||||
endTimeTrace := mach.timeTrace("decrypting olm ciphertext", traceID, 5*time.Second)
|
||||
plaintext, err := mach.tryDecryptOlmCiphertext(sender, senderKey, olmType, ciphertext, traceID)
|
||||
endTimeTrace := mach.timeTrace(ctx, "decrypting olm ciphertext", 5*time.Second)
|
||||
plaintext, err := mach.tryDecryptOlmCiphertext(ctx, sender, senderKey, olmType, ciphertext)
|
||||
endTimeTrace()
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
||||
defer mach.timeTrace("parsing decrypted olm event", traceID, time.Second)()
|
||||
defer mach.timeTrace(ctx, "parsing decrypted olm event", time.Second)()
|
||||
|
||||
var olmEvt DecryptedOlmEvent
|
||||
err = json.Unmarshal(plaintext, &olmEvt)
|
||||
@@ -103,17 +106,18 @@ func (mach *OlmMachine) decryptAndParseOlmCiphertext(sender id.UserID, senderKey
|
||||
return &olmEvt, nil
|
||||
}
|
||||
|
||||
func (mach *OlmMachine) tryDecryptOlmCiphertext(sender id.UserID, senderKey id.SenderKey, olmType id.OlmMsgType, ciphertext string, traceID string) ([]byte, error) {
|
||||
endTimeTrace := mach.timeTrace("waiting for olm lock", traceID, 5*time.Second)
|
||||
func (mach *OlmMachine) tryDecryptOlmCiphertext(ctx context.Context, sender id.UserID, senderKey id.SenderKey, olmType id.OlmMsgType, ciphertext string) ([]byte, error) {
|
||||
log := *zerolog.Ctx(ctx)
|
||||
endTimeTrace := mach.timeTrace(ctx, "waiting for olm lock", 5*time.Second)
|
||||
mach.olmLock.Lock()
|
||||
endTimeTrace()
|
||||
defer mach.olmLock.Unlock()
|
||||
|
||||
plaintext, err := mach.tryDecryptOlmCiphertextWithExistingSession(senderKey, olmType, ciphertext, traceID)
|
||||
plaintext, err := mach.tryDecryptOlmCiphertextWithExistingSession(ctx, senderKey, olmType, ciphertext)
|
||||
if err != nil {
|
||||
if err == DecryptionFailedWithMatchingSession {
|
||||
mach.Log.Warn("Found matching session yet decryption failed for sender %s with key %s", sender, senderKey)
|
||||
go mach.unwedgeDevice(sender, senderKey)
|
||||
log.Warn().Msg("Found matching session, but decryption failed")
|
||||
go mach.unwedgeDevice(log, sender, senderKey)
|
||||
}
|
||||
return nil, fmt.Errorf("failed to decrypt olm event: %w", err)
|
||||
}
|
||||
@@ -128,39 +132,44 @@ func (mach *OlmMachine) tryDecryptOlmCiphertext(sender id.UserID, senderKey id.S
|
||||
// New sessions can only be created if it's a prekey message, we can't decrypt the message
|
||||
// if it isn't one at this point in time anymore, so return early.
|
||||
if olmType != id.OlmMsgTypePreKey {
|
||||
go mach.unwedgeDevice(sender, senderKey)
|
||||
go mach.unwedgeDevice(log, sender, senderKey)
|
||||
return nil, DecryptionFailedForNormalMessage
|
||||
}
|
||||
|
||||
mach.Log.Trace("Trying to create inbound session for %s/%s", sender, senderKey)
|
||||
endTimeTrace = mach.timeTrace("creating inbound olm session", traceID, time.Second)
|
||||
session, err := mach.createInboundSession(senderKey, ciphertext)
|
||||
log.Trace().Msg("Trying to create inbound session")
|
||||
endTimeTrace = mach.timeTrace(ctx, "creating inbound olm session", time.Second)
|
||||
session, err := mach.createInboundSession(ctx, senderKey, ciphertext)
|
||||
endTimeTrace()
|
||||
if err != nil {
|
||||
go mach.unwedgeDevice(sender, senderKey)
|
||||
go mach.unwedgeDevice(log, sender, senderKey)
|
||||
return nil, fmt.Errorf("failed to create new session from prekey message: %w", err)
|
||||
}
|
||||
mach.Log.Debug("Created inbound olm session %s for %s/%s: %s", session.ID(), sender, senderKey, session.Describe())
|
||||
log = log.With().Str("new_olm_session_id", session.ID().String()).Logger()
|
||||
log.Debug().
|
||||
Str("olm_session_description", session.Describe()).
|
||||
Msg("Created inbound olm session")
|
||||
ctx = log.WithContext(ctx)
|
||||
|
||||
endTimeTrace = mach.timeTrace(fmt.Sprintf("decrypting prekey olm message with %s/%s", senderKey, session.ID()), traceID, time.Second)
|
||||
endTimeTrace = mach.timeTrace(ctx, "decrypting prekey olm message", time.Second)
|
||||
plaintext, err = session.Decrypt(ciphertext, olmType)
|
||||
endTimeTrace()
|
||||
if err != nil {
|
||||
go mach.unwedgeDevice(sender, senderKey)
|
||||
go mach.unwedgeDevice(log, sender, senderKey)
|
||||
return nil, fmt.Errorf("failed to decrypt olm event with session created from prekey message: %w", err)
|
||||
}
|
||||
|
||||
endTimeTrace = mach.timeTrace(fmt.Sprintf("updating new session %s/%s in database", senderKey, session.ID()), traceID, time.Second)
|
||||
endTimeTrace = mach.timeTrace(ctx, "updating new session in database", time.Second)
|
||||
err = mach.CryptoStore.UpdateSession(senderKey, session)
|
||||
endTimeTrace()
|
||||
if err != nil {
|
||||
mach.Log.Warn("Failed to update new olm session in crypto store after decrypting: %v", err)
|
||||
log.Warn().Err(err).Msg("Failed to update new olm session in crypto store after decrypting")
|
||||
}
|
||||
return plaintext, nil
|
||||
}
|
||||
|
||||
func (mach *OlmMachine) tryDecryptOlmCiphertextWithExistingSession(senderKey id.SenderKey, olmType id.OlmMsgType, ciphertext string, traceID string) ([]byte, error) {
|
||||
endTimeTrace := mach.timeTrace(fmt.Sprintf("getting sessions with %s", senderKey), traceID, time.Second)
|
||||
func (mach *OlmMachine) tryDecryptOlmCiphertextWithExistingSession(ctx context.Context, senderKey id.SenderKey, olmType id.OlmMsgType, ciphertext string) ([]byte, error) {
|
||||
log := *zerolog.Ctx(ctx)
|
||||
endTimeTrace := mach.timeTrace(ctx, "getting sessions with sender key", time.Second)
|
||||
sessions, err := mach.CryptoStore.GetSessions(senderKey)
|
||||
endTimeTrace()
|
||||
if err != nil {
|
||||
@@ -168,8 +177,10 @@ func (mach *OlmMachine) tryDecryptOlmCiphertextWithExistingSession(senderKey id.
|
||||
}
|
||||
|
||||
for _, session := range sessions {
|
||||
log := log.With().Str("olm_session_id", session.ID().String()).Logger()
|
||||
ctx := log.WithContext(ctx)
|
||||
if olmType == id.OlmMsgTypePreKey {
|
||||
endTimeTrace = mach.timeTrace(fmt.Sprintf("checking if prekey olm message matches session %s/%s", senderKey, session.ID()), traceID, time.Second)
|
||||
endTimeTrace = mach.timeTrace(ctx, "checking if prekey olm message matches session", time.Second)
|
||||
matches, err := session.Internal.MatchesInboundSession(ciphertext)
|
||||
endTimeTrace()
|
||||
if err != nil {
|
||||
@@ -178,8 +189,8 @@ func (mach *OlmMachine) tryDecryptOlmCiphertextWithExistingSession(senderKey id.
|
||||
continue
|
||||
}
|
||||
}
|
||||
mach.Log.Trace("Trying to decrypt olm message from %s with session %s: %s", senderKey, session.ID(), session.Describe())
|
||||
endTimeTrace = mach.timeTrace(fmt.Sprintf("decrypting olm message with %s/%s", senderKey, session.ID()), traceID, time.Second)
|
||||
log.Debug().Str("session_description", session.Describe()).Msg("Trying to decrypt olm message")
|
||||
endTimeTrace = mach.timeTrace(ctx, "decrypting olm message", time.Second)
|
||||
plaintext, err := session.Decrypt(ciphertext, olmType)
|
||||
endTimeTrace()
|
||||
if err != nil {
|
||||
@@ -187,20 +198,20 @@ func (mach *OlmMachine) tryDecryptOlmCiphertextWithExistingSession(senderKey id.
|
||||
return nil, DecryptionFailedWithMatchingSession
|
||||
}
|
||||
} else {
|
||||
endTimeTrace = mach.timeTrace(fmt.Sprintf("updating session %s/%s in database", senderKey, session.ID()), traceID, time.Second)
|
||||
endTimeTrace = mach.timeTrace(ctx, "updating session in database", time.Second)
|
||||
err = mach.CryptoStore.UpdateSession(senderKey, session)
|
||||
endTimeTrace()
|
||||
if err != nil {
|
||||
mach.Log.Warn("Failed to update olm session in crypto store after decrypting: %v", err)
|
||||
log.Warn().Err(err).Msg("Failed to update olm session in crypto store after decrypting")
|
||||
}
|
||||
mach.Log.Trace("Decrypted olm message from %s with session %s", senderKey, session.ID())
|
||||
log.Debug().Msg("Decrypted olm message")
|
||||
return plaintext, nil
|
||||
}
|
||||
}
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (mach *OlmMachine) createInboundSession(senderKey id.SenderKey, ciphertext string) (*OlmSession, error) {
|
||||
func (mach *OlmMachine) createInboundSession(ctx context.Context, senderKey id.SenderKey, ciphertext string) (*OlmSession, error) {
|
||||
session, err := mach.account.NewInboundSessionFrom(senderKey, ciphertext)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -208,40 +219,44 @@ func (mach *OlmMachine) createInboundSession(senderKey id.SenderKey, ciphertext
|
||||
mach.saveAccount()
|
||||
err = mach.CryptoStore.AddSession(senderKey, session)
|
||||
if err != nil {
|
||||
mach.Log.Error("Failed to store created inbound session: %v", err)
|
||||
zerolog.Ctx(ctx).Error().Err(err).Msg("Failed to store created inbound session")
|
||||
}
|
||||
return session, nil
|
||||
}
|
||||
|
||||
const MinUnwedgeInterval = 1 * time.Hour
|
||||
|
||||
func (mach *OlmMachine) unwedgeDevice(sender id.UserID, senderKey id.SenderKey) {
|
||||
func (mach *OlmMachine) unwedgeDevice(log zerolog.Logger, sender id.UserID, senderKey id.SenderKey) {
|
||||
log = log.With().Str("action", "unwedge olm session").Logger()
|
||||
ctx := log.WithContext(context.Background())
|
||||
mach.recentlyUnwedgedLock.Lock()
|
||||
prevUnwedge, ok := mach.recentlyUnwedged[senderKey]
|
||||
delta := time.Now().Sub(prevUnwedge)
|
||||
if ok && delta < MinUnwedgeInterval {
|
||||
mach.Log.Debug("Not creating new Olm session with %s/%s, previous recreation was %s ago", sender, senderKey, delta)
|
||||
log.Debug().
|
||||
Str("previous_recreation", delta.String()).
|
||||
Msg("Not creating new Olm session as it was already recreated recently")
|
||||
mach.recentlyUnwedgedLock.Unlock()
|
||||
return
|
||||
}
|
||||
mach.recentlyUnwedged[senderKey] = time.Now()
|
||||
mach.recentlyUnwedgedLock.Unlock()
|
||||
|
||||
deviceIdentity, err := mach.GetOrFetchDeviceByKey(sender, senderKey)
|
||||
deviceIdentity, err := mach.GetOrFetchDeviceByKey(ctx, sender, senderKey)
|
||||
if err != nil {
|
||||
mach.Log.Error("Failed to find device info by identity key: %v", err)
|
||||
log.Error().Err(err).Msg("Failed to find device info by identity key")
|
||||
return
|
||||
} else if deviceIdentity == nil {
|
||||
mach.Log.Warn("Didn't find identity of %s/%s, can't unwedge session", sender, senderKey)
|
||||
log.Warn().Msg("Didn't find identity for device")
|
||||
return
|
||||
}
|
||||
|
||||
mach.Log.Debug("Creating new Olm session with %s/%s (key: %s)", sender, deviceIdentity.DeviceID, senderKey)
|
||||
log.Debug().Str("device_id", deviceIdentity.DeviceID.String()).Msg("Creating new Olm session")
|
||||
mach.devicesToUnwedgeLock.Lock()
|
||||
mach.devicesToUnwedge[senderKey] = true
|
||||
mach.devicesToUnwedgeLock.Unlock()
|
||||
err = mach.SendEncryptedToDevice(deviceIdentity, event.ToDeviceDummy, event.Content{})
|
||||
err = mach.SendEncryptedToDevice(ctx, deviceIdentity, event.ToDeviceDummy, event.Content{})
|
||||
if err != nil {
|
||||
mach.Log.Error("Failed to send dummy event to unwedge session with %s/%s: %v", sender, senderKey, err)
|
||||
log.Error().Err(err).Msg("Failed to send dummy event to unwedge session")
|
||||
}
|
||||
}
|
||||
|
||||
103
vendor/maunium.net/go/mautrix/crypto/devicelist.go
generated
vendored
103
vendor/maunium.net/go/mautrix/crypto/devicelist.go
generated
vendored
@@ -1,4 +1,4 @@
|
||||
// Copyright (c) 2020 Tulir Asokan
|
||||
// Copyright (c) 2023 Tulir Asokan
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||
// License, v. 2.0. If a copy of the MPL was not distributed with this
|
||||
@@ -7,9 +7,12 @@
|
||||
package crypto
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
|
||||
"maunium.net/go/mautrix"
|
||||
"maunium.net/go/mautrix/crypto/olm"
|
||||
"maunium.net/go/mautrix/id"
|
||||
@@ -25,10 +28,12 @@ var (
|
||||
)
|
||||
|
||||
func (mach *OlmMachine) LoadDevices(user id.UserID) map[id.DeviceID]*id.Device {
|
||||
return mach.fetchKeys([]id.UserID{user}, "", true)[user]
|
||||
// TODO proper context?
|
||||
return mach.fetchKeys(context.TODO(), []id.UserID{user}, "", true)[user]
|
||||
}
|
||||
|
||||
func (mach *OlmMachine) storeDeviceSelfSignatures(userID id.UserID, deviceID id.DeviceID, resp *mautrix.RespQueryKeys) {
|
||||
func (mach *OlmMachine) storeDeviceSelfSignatures(ctx context.Context, userID id.UserID, deviceID id.DeviceID, resp *mautrix.RespQueryKeys) {
|
||||
log := zerolog.Ctx(ctx)
|
||||
deviceKeys := resp.DeviceKeys[userID][deviceID]
|
||||
for signerUserID, signerKeys := range deviceKeys.Signatures {
|
||||
for signerKey, signature := range signerKeys {
|
||||
@@ -43,17 +48,27 @@ func (mach *OlmMachine) storeDeviceSelfSignatures(userID id.UserID, deviceID id.
|
||||
if verified, err := olm.VerifySignatureJSON(deviceKeys, signerUserID, pubKey.String(), pubKey); verified {
|
||||
if signKey, ok := deviceKeys.Keys[id.DeviceKeyID(signerKey)]; ok {
|
||||
signature := deviceKeys.Signatures[signerUserID][id.NewKeyID(id.KeyAlgorithmEd25519, pubKey.String())]
|
||||
mach.Log.Trace("Verified self-signing signature for device %s/%s: %s", signerUserID, deviceID, signature)
|
||||
log.Trace().Err(err).
|
||||
Str("signer_user_id", signerUserID.String()).
|
||||
Str("signed_device_id", deviceID.String()).
|
||||
Str("signature", signature).
|
||||
Msg("Verified self-signing signature")
|
||||
err = mach.CryptoStore.PutSignature(userID, id.Ed25519(signKey), signerUserID, pubKey, signature)
|
||||
if err != nil {
|
||||
mach.Log.Warn("Failed to store self-signing signature for device %s/%s: %v", signerUserID, deviceID, err)
|
||||
log.Warn().Err(err).
|
||||
Str("signer_user_id", signerUserID.String()).
|
||||
Str("signed_device_id", deviceID.String()).
|
||||
Msg("Failed to store self-signing signature")
|
||||
}
|
||||
}
|
||||
} else {
|
||||
if err == nil {
|
||||
err = errors.New("invalid signature")
|
||||
}
|
||||
mach.Log.Warn("Could not verify device self-signing signature for %s/%s: %v", signerUserID, deviceID, err)
|
||||
log.Warn().Err(err).
|
||||
Str("signer_user_id", signerUserID.String()).
|
||||
Str("signed_device_id", deviceID.String()).
|
||||
Msg("Failed to verify self-signing signature")
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -61,25 +76,29 @@ func (mach *OlmMachine) storeDeviceSelfSignatures(userID id.UserID, deviceID id.
|
||||
if signKey, ok := deviceKeys.Keys[id.DeviceKeyID(signerKey)]; ok {
|
||||
err := mach.CryptoStore.PutSignature(userID, id.Ed25519(signKey), signerUserID, id.Ed25519(signKey), signature)
|
||||
if err != nil {
|
||||
mach.Log.Warn("Failed to store self-signing signature for %s/%s: %v", signerUserID, signKey, err)
|
||||
log.Warn().Err(err).
|
||||
Str("signer_user_id", signerUserID.String()).
|
||||
Str("signer_key", signKey).
|
||||
Msg("Failed to store self-signing signature")
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (mach *OlmMachine) fetchKeys(users []id.UserID, sinceToken string, includeUntracked bool) (data map[id.UserID]map[id.DeviceID]*id.Device) {
|
||||
func (mach *OlmMachine) fetchKeys(ctx context.Context, users []id.UserID, sinceToken string, includeUntracked bool) (data map[id.UserID]map[id.DeviceID]*id.Device) {
|
||||
// TODO this function should probably return errors
|
||||
req := &mautrix.ReqQueryKeys{
|
||||
DeviceKeys: mautrix.DeviceKeysRequest{},
|
||||
Timeout: 10 * 1000,
|
||||
Token: sinceToken,
|
||||
}
|
||||
log := mach.machOrContextLog(ctx)
|
||||
if !includeUntracked {
|
||||
var err error
|
||||
users, err = mach.CryptoStore.FilterTrackedUsers(users)
|
||||
if err != nil {
|
||||
mach.Log.Warn("Failed to filter tracked user list: %v", err)
|
||||
log.Warn().Err(err).Msg("Failed to filter tracked user list")
|
||||
}
|
||||
}
|
||||
if len(users) == 0 {
|
||||
@@ -88,62 +107,88 @@ func (mach *OlmMachine) fetchKeys(users []id.UserID, sinceToken string, includeU
|
||||
for _, userID := range users {
|
||||
req.DeviceKeys[userID] = mautrix.DeviceIDList{}
|
||||
}
|
||||
mach.Log.Trace("Querying keys for %v", users)
|
||||
log.Debug().Strs("users", strishArray(users)).Msg("Querying keys for users")
|
||||
resp, err := mach.Client.QueryKeys(req)
|
||||
if err != nil {
|
||||
mach.Log.Warn("Failed to query keys: %v", err)
|
||||
log.Error().Err(err).Msg("Failed to query keys")
|
||||
return
|
||||
}
|
||||
for server, err := range resp.Failures {
|
||||
mach.Log.Warn("Query keys failure for %s: %v", server, err)
|
||||
log.Warn().Interface("query_error", err).Str("server", server).Msg("Query keys failure for server")
|
||||
}
|
||||
mach.Log.Trace("Query key result received with %d users", len(resp.DeviceKeys))
|
||||
log.Trace().Int("user_count", len(resp.DeviceKeys)).Msg("Query key result received")
|
||||
data = make(map[id.UserID]map[id.DeviceID]*id.Device)
|
||||
for userID, devices := range resp.DeviceKeys {
|
||||
log := log.With().Str("user_id", userID.String()).Logger()
|
||||
delete(req.DeviceKeys, userID)
|
||||
|
||||
newDevices := make(map[id.DeviceID]*id.Device)
|
||||
existingDevices, err := mach.CryptoStore.GetDevices(userID)
|
||||
if err != nil {
|
||||
mach.Log.Warn("Failed to get existing devices for %s: %v", userID, err)
|
||||
log.Warn().Err(err).Msg("Failed to get existing devices for user")
|
||||
existingDevices = make(map[id.DeviceID]*id.Device)
|
||||
}
|
||||
mach.Log.Trace("Updating devices for %s, got %d devices, have %d in store", userID, len(devices), len(existingDevices))
|
||||
|
||||
log.Debug().
|
||||
Int("new_device_count", len(devices)).
|
||||
Int("old_device_count", len(existingDevices)).
|
||||
Msg("Updating devices in store")
|
||||
changed := false
|
||||
for deviceID, deviceKeys := range devices {
|
||||
log := log.With().Str("device_id", deviceID.String()).Logger()
|
||||
existing, ok := existingDevices[deviceID]
|
||||
if !ok {
|
||||
// New device
|
||||
changed = true
|
||||
}
|
||||
mach.Log.Trace("Validating device %s of %s", deviceID, userID)
|
||||
log.Trace().Msg("Validating device")
|
||||
newDevice, err := mach.validateDevice(userID, deviceID, deviceKeys, existing)
|
||||
if err != nil {
|
||||
mach.Log.Error("Failed to validate device %s of %s: %v", deviceID, userID, err)
|
||||
log.Error().Err(err).Msg("Failed to validate device")
|
||||
} else if newDevice != nil {
|
||||
newDevices[deviceID] = newDevice
|
||||
mach.storeDeviceSelfSignatures(userID, deviceID, resp)
|
||||
mach.storeDeviceSelfSignatures(ctx, userID, deviceID, resp)
|
||||
}
|
||||
}
|
||||
mach.Log.Trace("Storing new device list for %s containing %d devices", userID, len(newDevices))
|
||||
log.Trace().Int("new_device_count", len(newDevices)).Msg("Storing new device list")
|
||||
err = mach.CryptoStore.PutDevices(userID, newDevices)
|
||||
if err != nil {
|
||||
mach.Log.Warn("Failed to update device list for %s: %v", userID, err)
|
||||
log.Warn().Err(err).Msg("Failed to update device list")
|
||||
}
|
||||
data[userID] = newDevices
|
||||
|
||||
changed = changed || len(newDevices) != len(existingDevices)
|
||||
if changed {
|
||||
if mach.DeleteKeysOnDeviceDelete {
|
||||
for deviceID := range newDevices {
|
||||
delete(existingDevices, deviceID)
|
||||
}
|
||||
for _, device := range existingDevices {
|
||||
log := log.With().
|
||||
Str("device_id", device.DeviceID.String()).
|
||||
Str("identity_key", device.IdentityKey.String()).
|
||||
Str("signing_key", device.SigningKey.String()).
|
||||
Logger()
|
||||
sessionIDs, err := mach.CryptoStore.RedactGroupSessions("", device.IdentityKey, "device removed")
|
||||
if err != nil {
|
||||
log.Err(err).Msg("Failed to redact megolm sessions from deleted device")
|
||||
} else {
|
||||
log.Info().
|
||||
Strs("session_ids", stringifyArray(sessionIDs)).
|
||||
Msg("Redacted megolm sessions from deleted device")
|
||||
}
|
||||
}
|
||||
}
|
||||
mach.OnDevicesChanged(userID)
|
||||
}
|
||||
}
|
||||
for userID := range req.DeviceKeys {
|
||||
mach.Log.Warn("Didn't get any keys for user %s", userID)
|
||||
log.Warn().Str("user_id", userID.String()).Msg("Didn't get any keys for user")
|
||||
}
|
||||
|
||||
mach.storeCrossSigningKeys(resp.MasterKeys, resp.DeviceKeys)
|
||||
mach.storeCrossSigningKeys(resp.SelfSigningKeys, resp.DeviceKeys)
|
||||
mach.storeCrossSigningKeys(resp.UserSigningKeys, resp.DeviceKeys)
|
||||
mach.storeCrossSigningKeys(ctx, resp.MasterKeys, resp.DeviceKeys)
|
||||
mach.storeCrossSigningKeys(ctx, resp.SelfSigningKeys, resp.DeviceKeys)
|
||||
mach.storeCrossSigningKeys(ctx, resp.UserSigningKeys, resp.DeviceKeys)
|
||||
|
||||
return data
|
||||
}
|
||||
@@ -154,10 +199,16 @@ func (mach *OlmMachine) fetchKeys(users []id.UserID, sinceToken string, includeU
|
||||
// not need to be called manually.
|
||||
func (mach *OlmMachine) OnDevicesChanged(userID id.UserID) {
|
||||
for _, roomID := range mach.StateStore.FindSharedRooms(userID) {
|
||||
mach.Log.Debug("Devices of %s changed, invalidating group session for %s", userID, roomID)
|
||||
mach.Log.Debug().
|
||||
Str("user_id", userID.String()).
|
||||
Str("room_id", roomID.String()).
|
||||
Msg("Invalidating group session in room due to device change notification")
|
||||
err := mach.CryptoStore.RemoveOutboundGroupSession(roomID)
|
||||
if err != nil {
|
||||
mach.Log.Warn("Failed to invalidate outbound group session of %s on device change for %s: %v", roomID, userID, err)
|
||||
mach.Log.Warn().Err(err).
|
||||
Str("user_id", userID.String()).
|
||||
Str("room_id", roomID.String()).
|
||||
Msg("Failed to invalidate outbound group session")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
195
vendor/maunium.net/go/mautrix/crypto/encryptmegolm.go
generated
vendored
195
vendor/maunium.net/go/mautrix/crypto/encryptmegolm.go
generated
vendored
@@ -1,4 +1,4 @@
|
||||
// Copyright (c) 2020 Tulir Asokan
|
||||
// Copyright (c) 2023 Tulir Asokan
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||
// License, v. 2.0. If a copy of the MPL was not distributed with this
|
||||
@@ -7,10 +7,15 @@
|
||||
package crypto
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/base64"
|
||||
"encoding/binary"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
|
||||
"maunium.net/go/mautrix"
|
||||
"maunium.net/go/mautrix/event"
|
||||
"maunium.net/go/mautrix/id"
|
||||
@@ -33,6 +38,18 @@ func getRelatesTo(content interface{}) *event.RelatesTo {
|
||||
return nil
|
||||
}
|
||||
|
||||
func getMentions(content interface{}) *event.Mentions {
|
||||
contentStruct, ok := content.(*event.Content)
|
||||
if ok {
|
||||
content = contentStruct.Parsed
|
||||
}
|
||||
message, ok := content.(*event.MessageEventContent)
|
||||
if ok {
|
||||
return message.Mentions
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
type rawMegolmEvent struct {
|
||||
RoomID id.RoomID `json:"room_id"`
|
||||
Type event.Type `json:"type"`
|
||||
@@ -44,12 +61,29 @@ func IsShareError(err error) bool {
|
||||
return err == SessionExpired || err == SessionNotShared || err == NoGroupSession
|
||||
}
|
||||
|
||||
func parseMessageIndex(ciphertext []byte) (uint64, error) {
|
||||
decoded := make([]byte, base64.RawStdEncoding.DecodedLen(len(ciphertext)))
|
||||
var err error
|
||||
_, err = base64.RawStdEncoding.Decode(decoded, ciphertext)
|
||||
if err != nil {
|
||||
return 0, err
|
||||
} else if decoded[0] != 3 || decoded[1] != 8 {
|
||||
return 0, fmt.Errorf("unexpected initial bytes %d and %d", decoded[0], decoded[1])
|
||||
}
|
||||
index, read := binary.Uvarint(decoded[2 : 2+binary.MaxVarintLen64])
|
||||
if read <= 0 {
|
||||
return 0, fmt.Errorf("failed to decode varint, read value %d", read)
|
||||
}
|
||||
return index, nil
|
||||
}
|
||||
|
||||
// EncryptMegolmEvent encrypts data with the m.megolm.v1.aes-sha2 algorithm.
|
||||
//
|
||||
// If you use the event.Content struct, make sure you pass a pointer to the struct,
|
||||
// as JSON serialization will not work correctly otherwise.
|
||||
func (mach *OlmMachine) EncryptMegolmEvent(roomID id.RoomID, evtType event.Type, content interface{}) (*event.EncryptedEventContent, error) {
|
||||
mach.Log.Trace("Encrypting event of type %s for %s", evtType.Type, roomID)
|
||||
func (mach *OlmMachine) EncryptMegolmEvent(ctx context.Context, roomID id.RoomID, evtType event.Type, content interface{}) (*event.EncryptedEventContent, error) {
|
||||
mach.megolmEncryptLock.Lock()
|
||||
defer mach.megolmEncryptLock.Unlock()
|
||||
session, err := mach.CryptoStore.GetOutboundGroupSession(roomID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get outbound group session: %w", err)
|
||||
@@ -64,15 +98,28 @@ func (mach *OlmMachine) EncryptMegolmEvent(roomID id.RoomID, evtType event.Type,
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
log := mach.machOrContextLog(ctx).With().
|
||||
Str("event_type", evtType.Type).
|
||||
Str("room_id", roomID.String()).
|
||||
Str("session_id", session.ID().String()).
|
||||
Logger()
|
||||
log.Trace().Msg("Encrypting event...")
|
||||
ciphertext, err := session.Encrypt(plaintext)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
idx, err := parseMessageIndex(ciphertext)
|
||||
if err != nil {
|
||||
log.Warn().Err(err).Msg("Failed to get megolm message index of encrypted event")
|
||||
} else {
|
||||
log = log.With().Uint64("message_index", idx).Logger()
|
||||
}
|
||||
log.Debug().Msg("Encrypted event successfully")
|
||||
err = mach.CryptoStore.UpdateOutboundGroupSession(session)
|
||||
if err != nil {
|
||||
mach.Log.Warn("Failed to update megolm session in crypto store after encrypting: %v", err)
|
||||
log.Warn().Err(err).Msg("Failed to update megolm session in crypto store after encrypting")
|
||||
}
|
||||
return &event.EncryptedEventContent{
|
||||
encrypted := &event.EncryptedEventContent{
|
||||
Algorithm: id.AlgorithmMegolmV1,
|
||||
SessionID: session.ID(),
|
||||
MegolmCiphertext: ciphertext,
|
||||
@@ -81,13 +128,19 @@ func (mach *OlmMachine) EncryptMegolmEvent(roomID id.RoomID, evtType event.Type,
|
||||
// These are deprecated
|
||||
SenderKey: mach.account.IdentityKey(),
|
||||
DeviceID: mach.Client.DeviceID,
|
||||
}, nil
|
||||
}
|
||||
if mach.PlaintextMentions {
|
||||
encrypted.Mentions = getMentions(content)
|
||||
}
|
||||
return encrypted, nil
|
||||
}
|
||||
|
||||
func (mach *OlmMachine) newOutboundGroupSession(roomID id.RoomID) *OutboundGroupSession {
|
||||
func (mach *OlmMachine) newOutboundGroupSession(ctx context.Context, roomID id.RoomID) *OutboundGroupSession {
|
||||
session := NewOutboundGroupSession(roomID, mach.StateStore.GetEncryptionEvent(roomID))
|
||||
signingKey, idKey := mach.account.Keys()
|
||||
mach.createGroupSession(idKey, signingKey, roomID, session.ID(), session.Internal.Key(), "create")
|
||||
if !mach.DontStoreOutboundKeys {
|
||||
signingKey, idKey := mach.account.Keys()
|
||||
mach.createGroupSession(ctx, idKey, signingKey, roomID, session.ID(), session.Internal.Key(), session.MaxAge, session.MaxMessages, false)
|
||||
}
|
||||
return session
|
||||
}
|
||||
|
||||
@@ -96,21 +149,38 @@ type deviceSessionWrapper struct {
|
||||
identity *id.Device
|
||||
}
|
||||
|
||||
func strishArray[T ~string](arr []T) []string {
|
||||
out := make([]string, len(arr))
|
||||
for i, item := range arr {
|
||||
out[i] = string(item)
|
||||
}
|
||||
return out
|
||||
}
|
||||
|
||||
// ShareGroupSession shares a group session for a specific room with all the devices of the given user list.
|
||||
//
|
||||
// For devices with TrustStateBlacklisted, a m.room_key.withheld event with code=m.blacklisted is sent.
|
||||
// If AllowUnverifiedDevices is false, a similar event with code=m.unverified is sent to devices with TrustStateUnset
|
||||
func (mach *OlmMachine) ShareGroupSession(roomID id.RoomID, users []id.UserID) error {
|
||||
mach.Log.Debug("Sharing group session for room %s to %v", roomID, users)
|
||||
func (mach *OlmMachine) ShareGroupSession(ctx context.Context, roomID id.RoomID, users []id.UserID) error {
|
||||
mach.megolmEncryptLock.Lock()
|
||||
defer mach.megolmEncryptLock.Unlock()
|
||||
session, err := mach.CryptoStore.GetOutboundGroupSession(roomID)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to get previous outbound group session: %w", err)
|
||||
} else if session != nil && session.Shared && !session.Expired() {
|
||||
return AlreadyShared
|
||||
}
|
||||
log := mach.machOrContextLog(ctx).With().
|
||||
Str("room_id", roomID.String()).
|
||||
Str("action", "share megolm session").
|
||||
Logger()
|
||||
ctx = log.WithContext(ctx)
|
||||
if session == nil || session.Expired() {
|
||||
session = mach.newOutboundGroupSession(roomID)
|
||||
session = mach.newOutboundGroupSession(ctx, roomID)
|
||||
}
|
||||
log = log.With().Str("session_id", session.ID().String()).Logger()
|
||||
ctx = log.WithContext(ctx)
|
||||
log.Debug().Strs("users", strishArray(users)).Msg("Sharing group session for room")
|
||||
|
||||
withheldCount := 0
|
||||
toDeviceWithheld := &mautrix.ReqSendToDevice{Messages: make(map[id.UserID]map[id.DeviceID]*event.Content)}
|
||||
@@ -120,20 +190,25 @@ func (mach *OlmMachine) ShareGroupSession(roomID id.RoomID, users []id.UserID) e
|
||||
var fetchKeys []id.UserID
|
||||
|
||||
for _, userID := range users {
|
||||
log := log.With().Str("target_user_id", userID.String()).Logger()
|
||||
devices, err := mach.CryptoStore.GetDevices(userID)
|
||||
if err != nil {
|
||||
mach.Log.Error("Failed to get devices of %s", userID)
|
||||
log.Error().Err(err).Msg("Failed to get devices of user")
|
||||
} else if devices == nil {
|
||||
mach.Log.Trace("GetDevices returned nil for %s, will fetch keys and retry", userID)
|
||||
log.Debug().Msg("GetDevices returned nil, will fetch keys and retry")
|
||||
fetchKeys = append(fetchKeys, userID)
|
||||
} else if len(devices) == 0 {
|
||||
mach.Log.Trace("%s has no devices, skipping", userID)
|
||||
log.Trace().Msg("User has no devices, skipping")
|
||||
} else {
|
||||
mach.Log.Trace("Trying to find olm sessions to encrypt %s for %s", session.ID(), userID)
|
||||
log.Trace().Msg("Trying to find olm session to encrypt megolm session for user")
|
||||
toDeviceWithheld.Messages[userID] = make(map[id.DeviceID]*event.Content)
|
||||
olmSessions[userID] = make(map[id.DeviceID]deviceSessionWrapper)
|
||||
mach.findOlmSessionsForUser(session, userID, devices, olmSessions[userID], toDeviceWithheld.Messages[userID], missingUserSessions)
|
||||
mach.Log.Trace("Found %d sessions, withholding from %d sessions and missing %d sessions to encrypt %s for for %s", len(olmSessions[userID]), len(toDeviceWithheld.Messages[userID]), len(missingUserSessions), session.ID(), userID)
|
||||
mach.findOlmSessionsForUser(ctx, session, userID, devices, olmSessions[userID], toDeviceWithheld.Messages[userID], missingUserSessions)
|
||||
log.Debug().
|
||||
Int("olm_session_count", len(olmSessions[userID])).
|
||||
Int("withheld_count", len(toDeviceWithheld.Messages[userID])).
|
||||
Int("missing_count", len(missingUserSessions)).
|
||||
Msg("Completed first pass of finding olm sessions")
|
||||
withheldCount += len(toDeviceWithheld.Messages[userID])
|
||||
if len(missingUserSessions) > 0 {
|
||||
missingSessions[userID] = missingUserSessions
|
||||
@@ -146,18 +221,21 @@ func (mach *OlmMachine) ShareGroupSession(roomID id.RoomID, users []id.UserID) e
|
||||
}
|
||||
|
||||
if len(fetchKeys) > 0 {
|
||||
mach.Log.Trace("Fetching missing keys for %v", fetchKeys)
|
||||
for userID, devices := range mach.fetchKeys(fetchKeys, "", true) {
|
||||
mach.Log.Trace("Got %d device keys for %s", len(devices), userID)
|
||||
log.Debug().Strs("users", strishArray(fetchKeys)).Msg("Fetching missing keys")
|
||||
for userID, devices := range mach.fetchKeys(ctx, fetchKeys, "", true) {
|
||||
log.Debug().
|
||||
Int("device_count", len(devices)).
|
||||
Str("target_user_id", userID.String()).
|
||||
Msg("Got device keys for user")
|
||||
missingSessions[userID] = devices
|
||||
}
|
||||
}
|
||||
|
||||
if len(missingSessions) > 0 {
|
||||
mach.Log.Trace("Creating missing outbound sessions")
|
||||
err = mach.createOutboundSessions(missingSessions)
|
||||
log.Debug().Msg("Creating missing olm sessions")
|
||||
err = mach.createOutboundSessions(ctx, missingSessions)
|
||||
if err != nil {
|
||||
mach.Log.Error("Failed to create missing outbound sessions: %v", err)
|
||||
log.Error().Err(err).Msg("Failed to create missing olm sessions")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -176,42 +254,51 @@ func (mach *OlmMachine) ShareGroupSession(roomID id.RoomID, users []id.UserID) e
|
||||
withheld = make(map[id.DeviceID]*event.Content)
|
||||
toDeviceWithheld.Messages[userID] = withheld
|
||||
}
|
||||
mach.Log.Trace("Trying to find olm sessions to encrypt %s for %s (post-fetch retry)", session.ID(), userID)
|
||||
mach.findOlmSessionsForUser(session, userID, devices, output, withheld, nil)
|
||||
mach.Log.Trace("Found %d sessions and withholding from %d sessions to encrypt %s for for %s (post-fetch retry)", len(output), len(withheld), session.ID(), userID)
|
||||
|
||||
log := log.With().Str("target_user_id", userID.String()).Logger()
|
||||
log.Trace().Msg("Trying to find olm session to encrypt megolm session for user (post-fetch retry)")
|
||||
mach.findOlmSessionsForUser(ctx, session, userID, devices, output, withheld, nil)
|
||||
log.Debug().
|
||||
Int("olm_session_count", len(output)).
|
||||
Int("withheld_count", len(withheld)).
|
||||
Msg("Completed post-fetch retry of finding olm sessions")
|
||||
withheldCount += len(toDeviceWithheld.Messages[userID])
|
||||
if len(toDeviceWithheld.Messages[userID]) == 0 {
|
||||
delete(toDeviceWithheld.Messages, userID)
|
||||
}
|
||||
}
|
||||
|
||||
err = mach.encryptAndSendGroupSession(session, olmSessions)
|
||||
err = mach.encryptAndSendGroupSession(ctx, session, olmSessions)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to share group session: %w", err)
|
||||
}
|
||||
|
||||
if len(toDeviceWithheld.Messages) > 0 {
|
||||
mach.Log.Trace("Sending to-device messages to %d devices of %d users to report withheld keys in %s", withheldCount, len(toDeviceWithheld.Messages), roomID)
|
||||
log.Debug().
|
||||
Int("device_count", withheldCount).
|
||||
Int("user_count", len(toDeviceWithheld.Messages)).
|
||||
Msg("Sending to-device messages to report withheld key")
|
||||
// TODO remove the next 4 lines once clients support m.room_key.withheld
|
||||
_, err = mach.Client.SendToDevice(event.ToDeviceOrgMatrixRoomKeyWithheld, toDeviceWithheld)
|
||||
if err != nil {
|
||||
mach.Log.Warn("Failed to report withheld keys in %s (legacy event type): %v", roomID, err)
|
||||
log.Warn().Err(err).Msg("Failed to report withheld keys (legacy event type)")
|
||||
}
|
||||
_, err = mach.Client.SendToDevice(event.ToDeviceRoomKeyWithheld, toDeviceWithheld)
|
||||
if err != nil {
|
||||
mach.Log.Warn("Failed to report withheld keys in %s: %v", roomID, err)
|
||||
log.Warn().Err(err).Msg("Failed to report withheld keys")
|
||||
}
|
||||
}
|
||||
|
||||
mach.Log.Debug("Group session %s for %s successfully shared", session.ID(), roomID)
|
||||
log.Debug().Msg("Group session successfully shared")
|
||||
session.Shared = true
|
||||
return mach.CryptoStore.AddOutboundGroupSession(session)
|
||||
}
|
||||
|
||||
func (mach *OlmMachine) encryptAndSendGroupSession(session *OutboundGroupSession, olmSessions map[id.UserID]map[id.DeviceID]deviceSessionWrapper) error {
|
||||
func (mach *OlmMachine) encryptAndSendGroupSession(ctx context.Context, session *OutboundGroupSession, olmSessions map[id.UserID]map[id.DeviceID]deviceSessionWrapper) error {
|
||||
mach.olmLock.Lock()
|
||||
defer mach.olmLock.Unlock()
|
||||
mach.Log.Trace("Encrypting group session %s for all found devices", session.ID())
|
||||
log := zerolog.Ctx(ctx)
|
||||
log.Trace().Msg("Encrypting group session for all found devices")
|
||||
deviceCount := 0
|
||||
toDevice := &mautrix.ReqSendToDevice{Messages: make(map[id.UserID]map[id.DeviceID]*event.Content)}
|
||||
for userID, sessions := range olmSessions {
|
||||
@@ -221,31 +308,41 @@ func (mach *OlmMachine) encryptAndSendGroupSession(session *OutboundGroupSession
|
||||
output := make(map[id.DeviceID]*event.Content)
|
||||
toDevice.Messages[userID] = output
|
||||
for deviceID, device := range sessions {
|
||||
mach.Log.Trace("Encrypting group session %s for %s of %s", session.ID(), deviceID, userID)
|
||||
content := mach.encryptOlmEvent(device.session, device.identity, event.ToDeviceRoomKey, session.ShareContent())
|
||||
log.Trace().
|
||||
Str("target_user_id", userID.String()).
|
||||
Str("target_device_id", deviceID.String()).
|
||||
Msg("Encrypting group session for device")
|
||||
content := mach.encryptOlmEvent(ctx, device.session, device.identity, event.ToDeviceRoomKey, session.ShareContent())
|
||||
output[deviceID] = &event.Content{Parsed: content}
|
||||
deviceCount++
|
||||
mach.Log.Trace("Encrypted group session %s for %s of %s", session.ID(), deviceID, userID)
|
||||
log.Debug().
|
||||
Str("target_user_id", userID.String()).
|
||||
Str("target_device_id", deviceID.String()).
|
||||
Msg("Encrypted group session for device")
|
||||
}
|
||||
}
|
||||
|
||||
mach.Log.Trace("Sending to-device to %d devices of %d users to share group session %s", deviceCount, len(toDevice.Messages), session.ID())
|
||||
log.Debug().
|
||||
Int("device_count", deviceCount).
|
||||
Int("user_count", len(toDevice.Messages)).
|
||||
Msg("Sending to-device messages to share group session")
|
||||
_, err := mach.Client.SendToDevice(event.ToDeviceEncrypted, toDevice)
|
||||
return err
|
||||
}
|
||||
|
||||
func (mach *OlmMachine) findOlmSessionsForUser(session *OutboundGroupSession, userID id.UserID, devices map[id.DeviceID]*id.Device, output map[id.DeviceID]deviceSessionWrapper, withheld map[id.DeviceID]*event.Content, missingOutput map[id.DeviceID]*id.Device) {
|
||||
func (mach *OlmMachine) findOlmSessionsForUser(ctx context.Context, session *OutboundGroupSession, userID id.UserID, devices map[id.DeviceID]*id.Device, output map[id.DeviceID]deviceSessionWrapper, withheld map[id.DeviceID]*event.Content, missingOutput map[id.DeviceID]*id.Device) {
|
||||
for deviceID, device := range devices {
|
||||
log := zerolog.Ctx(ctx).With().
|
||||
Str("target_user_id", userID.String()).
|
||||
Str("target_device_id", deviceID.String()).
|
||||
Logger()
|
||||
userKey := UserDevice{UserID: userID, DeviceID: deviceID}
|
||||
if state := session.Users[userKey]; state != OGSNotShared {
|
||||
continue
|
||||
} else if userID == mach.Client.UserID && deviceID == mach.Client.DeviceID {
|
||||
session.Users[userKey] = OGSIgnored
|
||||
} else if device.Trust == id.TrustStateBlacklisted {
|
||||
mach.Log.Debug(
|
||||
"Not encrypting group session %s for %s of %s: device is blacklisted",
|
||||
session.ID(), deviceID, userID,
|
||||
)
|
||||
log.Debug().Msg("Not encrypting group session for device: device is blacklisted")
|
||||
withheld[deviceID] = &event.Content{Parsed: &event.RoomKeyWithheldEventContent{
|
||||
RoomID: session.RoomID,
|
||||
Algorithm: id.AlgorithmMegolmV1,
|
||||
@@ -256,10 +353,10 @@ func (mach *OlmMachine) findOlmSessionsForUser(session *OutboundGroupSession, us
|
||||
}}
|
||||
session.Users[userKey] = OGSIgnored
|
||||
} else if trustState := mach.ResolveTrust(device); trustState < mach.SendKeysMinTrust {
|
||||
mach.Log.Debug(
|
||||
"Not encrypting group session %s for %s of %s: device is not verified (minimum: %s, device: %s)",
|
||||
session.ID(), deviceID, userID, mach.SendKeysMinTrust, trustState,
|
||||
)
|
||||
log.Debug().
|
||||
Str("min_trust", mach.SendKeysMinTrust.String()).
|
||||
Str("device_trust", trustState.String()).
|
||||
Msg("Not encrypting group session for device: device is not trusted")
|
||||
withheld[deviceID] = &event.Content{Parsed: &event.RoomKeyWithheldEventContent{
|
||||
RoomID: session.RoomID,
|
||||
Algorithm: id.AlgorithmMegolmV1,
|
||||
@@ -270,9 +367,9 @@ func (mach *OlmMachine) findOlmSessionsForUser(session *OutboundGroupSession, us
|
||||
}}
|
||||
session.Users[userKey] = OGSIgnored
|
||||
} else if deviceSession, err := mach.CryptoStore.GetLatestSession(device.IdentityKey); err != nil {
|
||||
mach.Log.Error("Failed to get session for %s of %s: %v", deviceID, userID, err)
|
||||
log.Error().Err(err).Msg("Failed to get olm session to encrypt group session")
|
||||
} else if deviceSession == nil {
|
||||
mach.Log.Warn("Didn't find a session for %s of %s", deviceID, userID)
|
||||
log.Warn().Err(err).Msg("Didn't find olm session to encrypt group session")
|
||||
if missingOutput != nil {
|
||||
missingOutput[deviceID] = device
|
||||
}
|
||||
|
||||
36
vendor/maunium.net/go/mautrix/crypto/encryptolm.go
generated
vendored
36
vendor/maunium.net/go/mautrix/crypto/encryptolm.go
generated
vendored
@@ -1,4 +1,4 @@
|
||||
// Copyright (c) 2020 Tulir Asokan
|
||||
// Copyright (c) 2023 Tulir Asokan
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||
// License, v. 2.0. If a copy of the MPL was not distributed with this
|
||||
@@ -7,6 +7,7 @@
|
||||
package crypto
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
@@ -16,7 +17,7 @@ import (
|
||||
"maunium.net/go/mautrix/id"
|
||||
)
|
||||
|
||||
func (mach *OlmMachine) encryptOlmEvent(session *OlmSession, recipient *id.Device, evtType event.Type, content event.Content) *event.EncryptedEventContent {
|
||||
func (mach *OlmMachine) encryptOlmEvent(ctx context.Context, session *OlmSession, recipient *id.Device, evtType event.Type, content event.Content) *event.EncryptedEventContent {
|
||||
evt := &DecryptedOlmEvent{
|
||||
Sender: mach.Client.UserID,
|
||||
SenderDevice: mach.Client.DeviceID,
|
||||
@@ -30,11 +31,16 @@ func (mach *OlmMachine) encryptOlmEvent(session *OlmSession, recipient *id.Devic
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
mach.Log.Trace("Encrypting olm message for %s with session %s: %s", recipient.IdentityKey, session.ID(), session.Describe())
|
||||
log := mach.machOrContextLog(ctx)
|
||||
log.Debug().
|
||||
Str("recipient_identity_key", recipient.IdentityKey.String()).
|
||||
Str("olm_session_id", session.ID().String()).
|
||||
Str("olm_session_description", session.Describe()).
|
||||
Msg("Encrypting olm message")
|
||||
msgType, ciphertext := session.Encrypt(plaintext)
|
||||
err = mach.CryptoStore.UpdateSession(recipient.IdentityKey, session)
|
||||
if err != nil {
|
||||
mach.Log.Warn("Failed to update olm session in crypto store after encrypting: %v", err)
|
||||
log.Error().Err(err).Msg("Failed to update olm session in crypto store after encrypting")
|
||||
}
|
||||
return &event.EncryptedEventContent{
|
||||
Algorithm: id.AlgorithmOlmV1,
|
||||
@@ -61,7 +67,7 @@ func (mach *OlmMachine) shouldCreateNewSession(identityKey id.IdentityKey) bool
|
||||
return shouldUnwedge
|
||||
}
|
||||
|
||||
func (mach *OlmMachine) createOutboundSessions(input map[id.UserID]map[id.DeviceID]*id.Device) error {
|
||||
func (mach *OlmMachine) createOutboundSessions(ctx context.Context, input map[id.UserID]map[id.DeviceID]*id.Device) error {
|
||||
request := make(mautrix.OneTimeKeysRequest)
|
||||
for userID, devices := range input {
|
||||
request[userID] = make(map[id.DeviceID]id.KeyAlgorithm)
|
||||
@@ -84,6 +90,7 @@ func (mach *OlmMachine) createOutboundSessions(input map[id.UserID]map[id.Device
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to claim keys: %w", err)
|
||||
}
|
||||
log := mach.machOrContextLog(ctx)
|
||||
for userID, user := range resp.OneTimeKeys {
|
||||
for deviceID, oneTimeKeys := range user {
|
||||
var oneTimeKey mautrix.OneTimeKey
|
||||
@@ -91,25 +98,30 @@ func (mach *OlmMachine) createOutboundSessions(input map[id.UserID]map[id.Device
|
||||
for keyID, oneTimeKey = range oneTimeKeys {
|
||||
break
|
||||
}
|
||||
keyAlg, keyIndex := keyID.Parse()
|
||||
log := log.With().
|
||||
Str("peer_user_id", userID.String()).
|
||||
Str("peer_device_id", deviceID.String()).
|
||||
Str("peer_otk_id", keyID.String()).
|
||||
Logger()
|
||||
keyAlg, _ := keyID.Parse()
|
||||
if keyAlg != id.KeyAlgorithmSignedCurve25519 {
|
||||
mach.Log.Warn("Unexpected key ID algorithm in one-time key response for %s of %s: %s", deviceID, userID, keyID)
|
||||
log.Warn().Msg("Unexpected key ID algorithm in one-time key response")
|
||||
continue
|
||||
}
|
||||
identity := input[userID][deviceID]
|
||||
if ok, err := olm.VerifySignatureJSON(oneTimeKey.RawData, userID, deviceID.String(), identity.SigningKey); err != nil {
|
||||
mach.Log.Error("Failed to verify signature for %s of %s: %v", deviceID, userID, err)
|
||||
log.Error().Err(err).Msg("Failed to verify signature of one-time key")
|
||||
} else if !ok {
|
||||
mach.Log.Warn("Invalid signature for %s of %s", deviceID, userID)
|
||||
log.Warn().Msg("One-time key has invalid signature from device")
|
||||
} else if sess, err := mach.account.Internal.NewOutboundSession(identity.IdentityKey, oneTimeKey.Key); err != nil {
|
||||
mach.Log.Error("Failed to create outbound session for %s of %s: %v", deviceID, userID, err)
|
||||
log.Error().Err(err).Msg("Failed to create outbound session with claimed one-time key")
|
||||
} else {
|
||||
wrapped := wrapSession(sess)
|
||||
err = mach.CryptoStore.AddSession(identity.IdentityKey, wrapped)
|
||||
if err != nil {
|
||||
mach.Log.Error("Failed to store created session for %s of %s: %v", deviceID, userID, err)
|
||||
log.Error().Err(err).Msg("Failed to store created outbound session")
|
||||
} else {
|
||||
mach.Log.Debug("Created new Olm session with %s/%s (OTK ID: %s)", userID, deviceID, keyIndex)
|
||||
log.Debug().Msg("Created new Olm session")
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
2
vendor/maunium.net/go/mautrix/crypto/keyexport.go
generated
vendored
2
vendor/maunium.net/go/mautrix/crypto/keyexport.go
generated
vendored
@@ -103,7 +103,7 @@ func exportSessions(sessions []*InboundGroupSession) ([]ExportedSession, error)
|
||||
SenderKey: session.SenderKey,
|
||||
SenderClaimedKeys: SenderClaimedKeys{},
|
||||
SessionID: session.ID(),
|
||||
SessionKey: key,
|
||||
SessionKey: string(key),
|
||||
}
|
||||
}
|
||||
return export, nil
|
||||
|
||||
15
vendor/maunium.net/go/mautrix/crypto/keyimport.go
generated
vendored
15
vendor/maunium.net/go/mautrix/crypto/keyimport.go
generated
vendored
@@ -1,4 +1,4 @@
|
||||
// Copyright (c) 2020 Tulir Asokan
|
||||
// Copyright (c) 2023 Tulir Asokan
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||
// License, v. 2.0. If a copy of the MPL was not distributed with this
|
||||
@@ -17,6 +17,7 @@ import (
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"maunium.net/go/mautrix/crypto/olm"
|
||||
"maunium.net/go/mautrix/id"
|
||||
@@ -108,6 +109,8 @@ func (mach *OlmMachine) importExportedRoomKey(session ExportedSession) (bool, er
|
||||
RoomID: session.RoomID,
|
||||
// TODO should we add something here to mark the signing key as unverified like key requests do?
|
||||
ForwardingChains: session.ForwardingChains,
|
||||
|
||||
ReceivedAt: time.Now().UTC(),
|
||||
}
|
||||
existingIGS, _ := mach.CryptoStore.GetGroupSession(igs.RoomID, igs.SenderKey, igs.ID())
|
||||
if existingIGS != nil && existingIGS.Internal.FirstKnownIndex() <= igs.Internal.FirstKnownIndex() {
|
||||
@@ -136,14 +139,18 @@ func (mach *OlmMachine) ImportKeys(passphrase string, data []byte) (int, int, er
|
||||
|
||||
count := 0
|
||||
for _, session := range sessions {
|
||||
log := mach.Log.With().
|
||||
Str("room_id", session.RoomID.String()).
|
||||
Str("session_id", session.SessionID.String()).
|
||||
Logger()
|
||||
imported, err := mach.importExportedRoomKey(session)
|
||||
if err != nil {
|
||||
mach.Log.Warn("Failed to import Megolm session %s/%s from file: %v", session.RoomID, session.SessionID, err)
|
||||
log.Error().Err(err).Msg("Failed to import Megolm session from file")
|
||||
} else if imported {
|
||||
mach.Log.Debug("Imported Megolm session %s/%s from file", session.RoomID, session.SessionID)
|
||||
log.Debug().Msg("Imported Megolm session from file")
|
||||
count++
|
||||
} else {
|
||||
mach.Log.Debug("Skipped Megolm session %s/%s: already in store", session.RoomID, session.SessionID)
|
||||
log.Debug().Msg("Skipped Megolm session which is already in the store")
|
||||
}
|
||||
}
|
||||
return count, len(sessions), nil
|
||||
|
||||
164
vendor/maunium.net/go/mautrix/crypto/keysharing.go
generated
vendored
164
vendor/maunium.net/go/mautrix/crypto/keysharing.go
generated
vendored
@@ -1,16 +1,18 @@
|
||||
// Copyright (c) 2020 Nikos Filippakis
|
||||
// Copyright (c) 2023 Tulir Asokan
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||
// License, v. 2.0. If a copy of the MPL was not distributed with this
|
||||
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
|
||||
//go:build !nosas
|
||||
// +build !nosas
|
||||
|
||||
package crypto
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"time"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
|
||||
"maunium.net/go/mautrix/crypto/olm"
|
||||
"maunium.net/go/mautrix/id"
|
||||
@@ -56,11 +58,11 @@ func (mach *OlmMachine) RequestRoomKey(ctx context.Context, toUser id.UserID, to
|
||||
select {
|
||||
case <-keyResponseReceived:
|
||||
// key request successful
|
||||
mach.Log.Debug("Key for session %v was received, cancelling other key requests", sessionID)
|
||||
mach.Log.Debug().Msgf("Key for session %v was received, cancelling other key requests", sessionID)
|
||||
resChan <- true
|
||||
case <-ctx.Done():
|
||||
// if the context is done, key request was unsuccessful
|
||||
mach.Log.Debug("Context closed (%v) before forwared key for session %v received, sending key request cancellation", ctx.Err(), sessionID)
|
||||
mach.Log.Debug().Msgf("Context closed (%v) before forwared key for session %v received, sending key request cancellation", ctx.Err(), sessionID)
|
||||
resChan <- false
|
||||
}
|
||||
|
||||
@@ -128,20 +130,41 @@ func (mach *OlmMachine) SendRoomKeyRequest(roomID id.RoomID, senderKey id.Sender
|
||||
return err
|
||||
}
|
||||
|
||||
func (mach *OlmMachine) importForwardedRoomKey(evt *DecryptedOlmEvent, content *event.ForwardedRoomKeyEventContent) bool {
|
||||
func (mach *OlmMachine) importForwardedRoomKey(ctx context.Context, evt *DecryptedOlmEvent, content *event.ForwardedRoomKeyEventContent) bool {
|
||||
log := zerolog.Ctx(ctx).With().
|
||||
Str("session_id", content.SessionID.String()).
|
||||
Str("room_id", content.RoomID.String()).
|
||||
Logger()
|
||||
if content.Algorithm != id.AlgorithmMegolmV1 || evt.Keys.Ed25519 == "" {
|
||||
mach.Log.Debug("Ignoring weird forwarded room key from %s/%s: alg=%s, ed25519=%s, sessionid=%s, roomid=%s", evt.Sender, evt.SenderDevice, content.Algorithm, evt.Keys.Ed25519, content.SessionID, content.RoomID)
|
||||
log.Debug().
|
||||
Str("algorithm", string(content.Algorithm)).
|
||||
Msg("Ignoring weird forwarded room key")
|
||||
return false
|
||||
}
|
||||
|
||||
igsInternal, err := olm.InboundGroupSessionImport([]byte(content.SessionKey))
|
||||
if err != nil {
|
||||
mach.Log.Error("Failed to import inbound group session: %v", err)
|
||||
log.Error().Err(err).Msg("Failed to import inbound group session")
|
||||
return false
|
||||
} else if igsInternal.ID() != content.SessionID {
|
||||
mach.Log.Warn("Mismatched session ID while creating inbound group session")
|
||||
log.Warn().
|
||||
Str("actual_session_id", igsInternal.ID().String()).
|
||||
Msg("Mismatched session ID while creating inbound group session from forward")
|
||||
return false
|
||||
}
|
||||
config := mach.StateStore.GetEncryptionEvent(content.RoomID)
|
||||
var maxAge time.Duration
|
||||
var maxMessages int
|
||||
if config != nil {
|
||||
maxAge = time.Duration(config.RotationPeriodMillis) * time.Millisecond
|
||||
maxMessages = config.RotationPeriodMessages
|
||||
}
|
||||
if content.MaxAge != 0 {
|
||||
maxAge = time.Duration(content.MaxAge) * time.Millisecond
|
||||
}
|
||||
if content.MaxMessages != 0 {
|
||||
maxMessages = content.MaxMessages
|
||||
}
|
||||
igs := &InboundGroupSession{
|
||||
Internal: *igsInternal,
|
||||
SigningKey: evt.Keys.Ed25519,
|
||||
@@ -149,14 +172,19 @@ func (mach *OlmMachine) importForwardedRoomKey(evt *DecryptedOlmEvent, content *
|
||||
RoomID: content.RoomID,
|
||||
ForwardingChains: append(content.ForwardingKeyChain, evt.SenderKey.String()),
|
||||
id: content.SessionID,
|
||||
|
||||
ReceivedAt: time.Now().UTC(),
|
||||
MaxAge: maxAge.Milliseconds(),
|
||||
MaxMessages: maxMessages,
|
||||
IsScheduled: content.IsScheduled,
|
||||
}
|
||||
err = mach.CryptoStore.PutGroupSession(content.RoomID, content.SenderKey, content.SessionID, igs)
|
||||
if err != nil {
|
||||
mach.Log.Error("Failed to store new inbound group session: %v", err)
|
||||
log.Error().Err(err).Msg("Failed to store new inbound group session")
|
||||
return false
|
||||
}
|
||||
mach.markSessionReceived(content.SessionID)
|
||||
mach.Log.Trace("Received forwarded inbound group session %s/%s/%s", content.RoomID, content.SenderKey, content.SessionID)
|
||||
log.Debug().Msg("Received forwarded inbound group session")
|
||||
return true
|
||||
}
|
||||
|
||||
@@ -175,50 +203,72 @@ func (mach *OlmMachine) rejectKeyRequest(rejection KeyShareRejection, device *id
|
||||
}
|
||||
err := mach.sendToOneDevice(device.UserID, device.DeviceID, event.ToDeviceRoomKeyWithheld, &content)
|
||||
if err != nil {
|
||||
mach.Log.Warn("Failed to send key share rejection %s to %s/%s: %v", rejection.Code, device.UserID, device.DeviceID, err)
|
||||
mach.Log.Warn().Err(err).
|
||||
Str("code", string(rejection.Code)).
|
||||
Str("user_id", device.UserID.String()).
|
||||
Str("device_id", device.DeviceID.String()).
|
||||
Msg("Failed to send key share rejection")
|
||||
}
|
||||
err = mach.sendToOneDevice(device.UserID, device.DeviceID, event.ToDeviceOrgMatrixRoomKeyWithheld, &content)
|
||||
if err != nil {
|
||||
mach.Log.Warn("Failed to send key share rejection %s (org.matrix.) to %s/%s: %v", rejection.Code, device.UserID, device.DeviceID, err)
|
||||
mach.Log.Warn().Err(err).
|
||||
Str("code", string(rejection.Code)).
|
||||
Str("user_id", device.UserID.String()).
|
||||
Str("device_id", device.DeviceID.String()).
|
||||
Msg("Failed to send key share rejection (legacy event type)")
|
||||
}
|
||||
}
|
||||
|
||||
func (mach *OlmMachine) defaultAllowKeyShare(device *id.Device, _ event.RequestedKeyInfo) *KeyShareRejection {
|
||||
func (mach *OlmMachine) defaultAllowKeyShare(ctx context.Context, device *id.Device, _ event.RequestedKeyInfo) *KeyShareRejection {
|
||||
log := mach.machOrContextLog(ctx)
|
||||
if mach.Client.UserID != device.UserID {
|
||||
mach.Log.Debug("Ignoring key request from a different user (%s)", device.UserID)
|
||||
log.Debug().Msg("Rejecting key request from a different user")
|
||||
return &KeyShareRejectOtherUser
|
||||
} else if mach.Client.DeviceID == device.DeviceID {
|
||||
mach.Log.Debug("Ignoring key request from ourselves")
|
||||
log.Debug().Msg("Ignoring key request from ourselves")
|
||||
return &KeyShareRejectNoResponse
|
||||
} else if device.Trust == id.TrustStateBlacklisted {
|
||||
mach.Log.Debug("Ignoring key request from blacklisted device %s", device.DeviceID)
|
||||
log.Debug().Msg("Rejecting key request from blacklisted device")
|
||||
return &KeyShareRejectBlacklisted
|
||||
} else if trustState := mach.ResolveTrust(device); trustState >= mach.ShareKeysMinTrust {
|
||||
mach.Log.Debug("Accepting key request from device %s (trust state: %s)", device.DeviceID, trustState)
|
||||
log.Debug().
|
||||
Str("min_trust", mach.SendKeysMinTrust.String()).
|
||||
Str("device_trust", trustState.String()).
|
||||
Msg("Accepting key request from trusted device")
|
||||
return nil
|
||||
} else {
|
||||
mach.Log.Debug("Ignoring key request from unverified device %s (trust state: %s)", device.DeviceID, trustState)
|
||||
log.Debug().
|
||||
Str("min_trust", mach.SendKeysMinTrust.String()).
|
||||
Str("device_trust", trustState.String()).
|
||||
Msg("Rejecting key request from untrusted device")
|
||||
return &KeyShareRejectUnverified
|
||||
}
|
||||
}
|
||||
|
||||
func (mach *OlmMachine) handleRoomKeyRequest(sender id.UserID, content *event.RoomKeyRequestEventContent) {
|
||||
func (mach *OlmMachine) handleRoomKeyRequest(ctx context.Context, sender id.UserID, content *event.RoomKeyRequestEventContent) {
|
||||
log := zerolog.Ctx(ctx).With().
|
||||
Str("request_id", content.RequestID).
|
||||
Str("device_id", content.RequestingDeviceID.String()).
|
||||
Str("room_id", content.Body.RoomID.String()).
|
||||
Str("session_id", content.Body.SessionID.String()).
|
||||
Logger()
|
||||
ctx = log.WithContext(ctx)
|
||||
if content.Action != event.KeyRequestActionRequest {
|
||||
return
|
||||
} else if content.RequestingDeviceID == mach.Client.DeviceID && sender == mach.Client.UserID {
|
||||
mach.Log.Debug("Ignoring key request %s from ourselves", content.RequestID)
|
||||
log.Debug().Msg("Ignoring key request from ourselves")
|
||||
return
|
||||
}
|
||||
|
||||
mach.Log.Debug("Received key request %s for %s from %s/%s", content.RequestID, content.Body.SessionID, sender, content.RequestingDeviceID)
|
||||
log.Debug().Msg("Received key request")
|
||||
|
||||
device, err := mach.GetOrFetchDevice(sender, content.RequestingDeviceID)
|
||||
device, err := mach.GetOrFetchDevice(ctx, sender, content.RequestingDeviceID)
|
||||
if err != nil {
|
||||
mach.Log.Error("Failed to fetch device %s/%s that requested keys: %v", sender, content.RequestingDeviceID, err)
|
||||
log.Error().Err(err).Msg("Failed to fetch device that requested keys")
|
||||
return
|
||||
}
|
||||
|
||||
rejection := mach.AllowKeyShare(device, content.Body)
|
||||
rejection := mach.AllowKeyShare(ctx, device, content.Body)
|
||||
if rejection != nil {
|
||||
mach.rejectKeyRequest(*rejection, device, content.Body)
|
||||
return
|
||||
@@ -226,18 +276,29 @@ func (mach *OlmMachine) handleRoomKeyRequest(sender id.UserID, content *event.Ro
|
||||
|
||||
igs, err := mach.CryptoStore.GetGroupSession(content.Body.RoomID, content.Body.SenderKey, content.Body.SessionID)
|
||||
if err != nil {
|
||||
mach.Log.Error("Failed to fetch group session to forward to %s/%s: %v", device.UserID, device.DeviceID, err)
|
||||
mach.rejectKeyRequest(KeyShareRejectInternalError, device, content.Body)
|
||||
if errors.Is(err, ErrGroupSessionWithheld) {
|
||||
log.Debug().Err(err).Msg("Requested group session not available")
|
||||
mach.rejectKeyRequest(KeyShareRejectUnavailable, device, content.Body)
|
||||
} else {
|
||||
log.Error().Err(err).Msg("Failed to get group session to forward")
|
||||
mach.rejectKeyRequest(KeyShareRejectInternalError, device, content.Body)
|
||||
}
|
||||
return
|
||||
} else if igs == nil {
|
||||
mach.Log.Warn("Didn't find group session %s to forward to %s/%s", content.Body.SessionID, device.UserID, device.DeviceID)
|
||||
log.Error().Msg("Didn't find group session to forward")
|
||||
mach.rejectKeyRequest(KeyShareRejectUnavailable, device, content.Body)
|
||||
return
|
||||
}
|
||||
if internalID := igs.ID(); internalID != content.Body.SessionID {
|
||||
// Should this be an error?
|
||||
log = log.With().Str("unexpected_session_id", internalID.String()).Logger()
|
||||
}
|
||||
|
||||
exportedKey, err := igs.Internal.Export(igs.Internal.FirstKnownIndex())
|
||||
firstKnownIndex := igs.Internal.FirstKnownIndex()
|
||||
log = log.With().Uint32("first_known_index", firstKnownIndex).Logger()
|
||||
exportedKey, err := igs.Internal.Export(firstKnownIndex)
|
||||
if err != nil {
|
||||
mach.Log.Error("Failed to export session %s to forward to %s/%s: %v", igs.ID(), device.UserID, device.DeviceID, err)
|
||||
log.Error().Err(err).Msg("Failed to export group session to forward")
|
||||
mach.rejectKeyRequest(KeyShareRejectInternalError, device, content.Body)
|
||||
return
|
||||
}
|
||||
@@ -248,7 +309,7 @@ func (mach *OlmMachine) handleRoomKeyRequest(sender id.UserID, content *event.Ro
|
||||
Algorithm: id.AlgorithmMegolmV1,
|
||||
RoomID: igs.RoomID,
|
||||
SessionID: igs.ID(),
|
||||
SessionKey: exportedKey,
|
||||
SessionKey: string(exportedKey),
|
||||
},
|
||||
SenderKey: content.Body.SenderKey,
|
||||
ForwardingKeyChain: igs.ForwardingChains,
|
||||
@@ -256,9 +317,42 @@ func (mach *OlmMachine) handleRoomKeyRequest(sender id.UserID, content *event.Ro
|
||||
},
|
||||
}
|
||||
|
||||
if err := mach.SendEncryptedToDevice(device, event.ToDeviceForwardedRoomKey, forwardedRoomKey); err != nil {
|
||||
mach.Log.Error("Failed to send encrypted forwarded key %s to %s/%s: %v", igs.ID(), device.UserID, device.DeviceID, err)
|
||||
if err = mach.SendEncryptedToDevice(ctx, device, event.ToDeviceForwardedRoomKey, forwardedRoomKey); err != nil {
|
||||
log.Error().Err(err).Msg("Failed to encrypt and send group session")
|
||||
} else {
|
||||
log.Debug().Msg("Successfully sent forwarded group session")
|
||||
}
|
||||
}
|
||||
|
||||
func (mach *OlmMachine) handleBeeperRoomKeyAck(ctx context.Context, sender id.UserID, content *event.BeeperRoomKeyAckEventContent) {
|
||||
log := mach.machOrContextLog(ctx).With().
|
||||
Str("room_id", content.RoomID.String()).
|
||||
Str("session_id", content.SessionID.String()).
|
||||
Int("first_message_index", content.FirstMessageIndex).
|
||||
Logger()
|
||||
|
||||
sess, err := mach.CryptoStore.GetGroupSession(content.RoomID, "", content.SessionID)
|
||||
if err != nil {
|
||||
if errors.Is(err, ErrGroupSessionWithheld) {
|
||||
log.Debug().Err(err).Msg("Acked group session was already redacted")
|
||||
} else {
|
||||
log.Err(err).Msg("Failed to get group session to check if it should be redacted")
|
||||
}
|
||||
return
|
||||
}
|
||||
log = log.With().
|
||||
Str("sender_key", sess.SenderKey.String()).
|
||||
Str("own_identity", mach.OwnIdentity().IdentityKey.String()).
|
||||
Logger()
|
||||
|
||||
isInbound := sess.SenderKey == mach.OwnIdentity().IdentityKey
|
||||
if isInbound && mach.DeleteOutboundKeysOnAck && content.FirstMessageIndex == 0 {
|
||||
log.Debug().Msg("Redacting inbound copy of outbound group session after ack")
|
||||
err = mach.CryptoStore.RedactGroupSession(content.RoomID, sess.SenderKey, content.SessionID, "outbound session acked")
|
||||
if err != nil {
|
||||
log.Err(err).Msg("Failed to redact group session")
|
||||
}
|
||||
} else {
|
||||
log.Debug().Bool("inbound", isInbound).Msg("Received room key ack")
|
||||
}
|
||||
|
||||
mach.Log.Debug("Sent encrypted forwarded key to device %s/%s for session %s", device.UserID, device.DeviceID, igs.ID())
|
||||
}
|
||||
|
||||
308
vendor/maunium.net/go/mautrix/crypto/machine.go
generated
vendored
308
vendor/maunium.net/go/mautrix/crypto/machine.go
generated
vendored
@@ -1,4 +1,4 @@
|
||||
// Copyright (c) 2020 Tulir Asokan
|
||||
// Copyright (c) 2023 Tulir Asokan
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||
// License, v. 2.0. If a copy of the MPL was not distributed with this
|
||||
@@ -7,12 +7,14 @@
|
||||
package crypto
|
||||
|
||||
import (
|
||||
"context"
|
||||
"errors"
|
||||
"fmt"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"maunium.net/go/mautrix/appservice"
|
||||
"github.com/rs/zerolog"
|
||||
|
||||
"maunium.net/go/mautrix/crypto/ssss"
|
||||
"maunium.net/go/mautrix/id"
|
||||
|
||||
@@ -20,28 +22,21 @@ import (
|
||||
"maunium.net/go/mautrix/event"
|
||||
)
|
||||
|
||||
// Logger is a simple logging struct for OlmMachine.
|
||||
// Implementations are recommended to use fmt.Sprintf and manually add a newline after the message.
|
||||
type Logger interface {
|
||||
Error(message string, args ...interface{})
|
||||
Warn(message string, args ...interface{})
|
||||
Debug(message string, args ...interface{})
|
||||
Trace(message string, args ...interface{})
|
||||
}
|
||||
|
||||
// OlmMachine is the main struct for handling Matrix end-to-end encryption.
|
||||
type OlmMachine struct {
|
||||
Client *mautrix.Client
|
||||
SSSS *ssss.Machine
|
||||
Log Logger
|
||||
Log *zerolog.Logger
|
||||
|
||||
CryptoStore Store
|
||||
StateStore StateStore
|
||||
|
||||
PlaintextMentions bool
|
||||
|
||||
SendKeysMinTrust id.TrustState
|
||||
ShareKeysMinTrust id.TrustState
|
||||
|
||||
AllowKeyShare func(*id.Device, event.RequestedKeyInfo) *KeyShareRejection
|
||||
AllowKeyShare func(context.Context, *id.Device, event.RequestedKeyInfo) *KeyShareRejection
|
||||
|
||||
DefaultSASTimeout time.Duration
|
||||
// AcceptVerificationFrom determines whether the machine will accept verification requests from this device.
|
||||
@@ -60,7 +55,9 @@ type OlmMachine struct {
|
||||
recentlyUnwedged map[id.IdentityKey]time.Time
|
||||
recentlyUnwedgedLock sync.Mutex
|
||||
|
||||
olmLock sync.Mutex
|
||||
olmLock sync.Mutex
|
||||
megolmEncryptLock sync.Mutex
|
||||
megolmDecryptLock sync.Mutex
|
||||
|
||||
otkUploadLock sync.Mutex
|
||||
lastOTKUpload time.Time
|
||||
@@ -69,6 +66,13 @@ type OlmMachine struct {
|
||||
crossSigningPubkeys *CrossSigningPublicKeysCache
|
||||
|
||||
crossSigningPubkeysFetched bool
|
||||
|
||||
DeleteOutboundKeysOnAck bool
|
||||
DontStoreOutboundKeys bool
|
||||
DeletePreviousKeysOnReceive bool
|
||||
RatchetKeysOnDecrypt bool
|
||||
DeleteFullyUsedKeysOnDecrypt bool
|
||||
DeleteKeysOnDeviceDelete bool
|
||||
}
|
||||
|
||||
// StateStore is used by OlmMachine to get room state information that's needed for encryption.
|
||||
@@ -82,7 +86,11 @@ type StateStore interface {
|
||||
}
|
||||
|
||||
// NewOlmMachine creates an OlmMachine with the given client, logger and stores.
|
||||
func NewOlmMachine(client *mautrix.Client, log Logger, cryptoStore Store, stateStore StateStore) *OlmMachine {
|
||||
func NewOlmMachine(client *mautrix.Client, log *zerolog.Logger, cryptoStore Store, stateStore StateStore) *OlmMachine {
|
||||
if log == nil {
|
||||
logPtr := zerolog.Nop()
|
||||
log = &logPtr
|
||||
}
|
||||
mach := &OlmMachine{
|
||||
Client: client,
|
||||
SSSS: ssss.NewSSSSMachine(client),
|
||||
@@ -111,6 +119,14 @@ func NewOlmMachine(client *mautrix.Client, log Logger, cryptoStore Store, stateS
|
||||
return mach
|
||||
}
|
||||
|
||||
func (mach *OlmMachine) machOrContextLog(ctx context.Context) *zerolog.Logger {
|
||||
log := zerolog.Ctx(ctx)
|
||||
if log.GetLevel() == zerolog.Disabled || log == zerolog.DefaultContextLogger {
|
||||
return mach.Log
|
||||
}
|
||||
return log
|
||||
}
|
||||
|
||||
// Load loads the Olm account information from the crypto store. If there's no olm account, a new one is created.
|
||||
// This must be called before using the machine.
|
||||
func (mach *OlmMachine) Load() (err error) {
|
||||
@@ -127,7 +143,7 @@ func (mach *OlmMachine) Load() (err error) {
|
||||
func (mach *OlmMachine) saveAccount() {
|
||||
err := mach.CryptoStore.PutAccount(mach.account)
|
||||
if err != nil {
|
||||
mach.Log.Error("Failed to save account: %v", err)
|
||||
mach.Log.Error().Err(err).Msg("Failed to save account")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -136,12 +152,15 @@ func (mach *OlmMachine) FlushStore() error {
|
||||
return mach.CryptoStore.Flush()
|
||||
}
|
||||
|
||||
func (mach *OlmMachine) timeTrace(thing, trace string, expectedDuration time.Duration) func() {
|
||||
func (mach *OlmMachine) timeTrace(ctx context.Context, thing string, expectedDuration time.Duration) func() {
|
||||
start := time.Now()
|
||||
return func() {
|
||||
duration := time.Now().Sub(start)
|
||||
if duration > expectedDuration {
|
||||
mach.Log.Warn("%s took %s (trace: %s)", thing, duration, trace)
|
||||
zerolog.Ctx(ctx).Warn().
|
||||
Str("action", thing).
|
||||
Dur("duration", duration).
|
||||
Msg("Executing encryption function took longer than expected")
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -156,7 +175,11 @@ func (mach *OlmMachine) Fingerprint() string {
|
||||
return mach.account.SigningKey().Fingerprint()
|
||||
}
|
||||
|
||||
// OwnIdentity returns this device's DeviceIdentity struct
|
||||
func (mach *OlmMachine) GetAccount() *OlmAccount {
|
||||
return mach.account
|
||||
}
|
||||
|
||||
// OwnIdentity returns this device's id.Device struct
|
||||
func (mach *OlmMachine) OwnIdentity() *id.Device {
|
||||
return &id.Device{
|
||||
UserID: mach.Client.UserID,
|
||||
@@ -168,11 +191,18 @@ func (mach *OlmMachine) OwnIdentity() *id.Device {
|
||||
}
|
||||
}
|
||||
|
||||
func (mach *OlmMachine) AddAppserviceListener(ep *appservice.EventProcessor, az *appservice.AppService) {
|
||||
type asEventProcessor interface {
|
||||
On(evtType event.Type, handler func(evt *event.Event))
|
||||
OnOTK(func(otk *mautrix.OTKCount))
|
||||
OnDeviceList(func(lists *mautrix.DeviceLists, since string))
|
||||
}
|
||||
|
||||
func (mach *OlmMachine) AddAppserviceListener(ep asEventProcessor) {
|
||||
// ToDeviceForwardedRoomKey and ToDeviceRoomKey should only be present inside encrypted to-device events
|
||||
ep.On(event.ToDeviceEncrypted, mach.HandleToDeviceEvent)
|
||||
ep.On(event.ToDeviceRoomKeyRequest, mach.HandleToDeviceEvent)
|
||||
ep.On(event.ToDeviceRoomKeyWithheld, mach.HandleToDeviceEvent)
|
||||
ep.On(event.ToDeviceBeeperRoomKeyAck, mach.HandleToDeviceEvent)
|
||||
ep.On(event.ToDeviceOrgMatrixRoomKeyWithheld, mach.HandleToDeviceEvent)
|
||||
ep.On(event.ToDeviceVerificationRequest, mach.HandleToDeviceEvent)
|
||||
ep.On(event.ToDeviceVerificationStart, mach.HandleToDeviceEvent)
|
||||
@@ -182,34 +212,44 @@ func (mach *OlmMachine) AddAppserviceListener(ep *appservice.EventProcessor, az
|
||||
ep.On(event.ToDeviceVerificationCancel, mach.HandleToDeviceEvent)
|
||||
ep.OnOTK(mach.HandleOTKCounts)
|
||||
ep.OnDeviceList(mach.HandleDeviceLists)
|
||||
mach.Log.Trace("Added listeners for encryption data coming from appservice transactions")
|
||||
mach.Log.Debug().Msg("Added listeners for encryption data coming from appservice transactions")
|
||||
}
|
||||
|
||||
func (mach *OlmMachine) HandleDeviceLists(dl *mautrix.DeviceLists, since string) {
|
||||
if len(dl.Changed) > 0 {
|
||||
traceID := time.Now().Format("15:04:05.000000")
|
||||
mach.Log.Trace("Device list changes in /sync: %v (trace: %s)", dl.Changed, traceID)
|
||||
mach.fetchKeys(dl.Changed, since, false)
|
||||
mach.Log.Trace("Finished handling device list changes (trace: %s)", traceID)
|
||||
mach.Log.Debug().
|
||||
Str("trace_id", traceID).
|
||||
Interface("changes", dl.Changed).
|
||||
Msg("Device list changes in /sync")
|
||||
mach.fetchKeys(context.TODO(), dl.Changed, since, false)
|
||||
mach.Log.Debug().Str("trace_id", traceID).Msg("Finished handling device list changes")
|
||||
}
|
||||
}
|
||||
|
||||
func (mach *OlmMachine) HandleOTKCounts(otkCount *mautrix.OTKCount) {
|
||||
if (len(otkCount.UserID) > 0 && otkCount.UserID != mach.Client.UserID) || (len(otkCount.DeviceID) > 0 && otkCount.DeviceID != mach.Client.DeviceID) {
|
||||
// TODO This log probably needs to be silence-able if someone wants to use encrypted appservices with multiple e2ee sessions
|
||||
mach.Log.Debug("Dropping OTK counts targeted to %s/%s (not us)", otkCount.UserID, otkCount.DeviceID)
|
||||
mach.Log.Warn().
|
||||
Str("target_user_id", otkCount.UserID.String()).
|
||||
Str("target_device_id", otkCount.DeviceID.String()).
|
||||
Msg("Dropping OTK counts targeted to someone else")
|
||||
return
|
||||
}
|
||||
|
||||
minCount := mach.account.Internal.MaxNumberOfOneTimeKeys() / 2
|
||||
if otkCount.SignedCurve25519 < int(minCount) {
|
||||
traceID := time.Now().Format("15:04:05.000000")
|
||||
mach.Log.Debug("Sync response said we have %d signed curve25519 keys left, sharing new ones... (trace: %s)", otkCount.SignedCurve25519, traceID)
|
||||
err := mach.ShareKeys(otkCount.SignedCurve25519)
|
||||
log := mach.Log.With().Str("trace_id", traceID).Logger()
|
||||
ctx := log.WithContext(context.Background())
|
||||
log.Debug().
|
||||
Int("keys_left", otkCount.Curve25519).
|
||||
Msg("Sync response said we have less than 50 signed curve25519 keys left, sharing new ones...")
|
||||
err := mach.ShareKeys(ctx, otkCount.SignedCurve25519)
|
||||
if err != nil {
|
||||
mach.Log.Error("Failed to share keys: %v (trace: %s)", err, traceID)
|
||||
log.Error().Err(err).Msg("Failed to share keys")
|
||||
} else {
|
||||
mach.Log.Debug("Successfully shared keys (trace: %s)", traceID)
|
||||
log.Debug().Msg("Successfully shared keys")
|
||||
}
|
||||
}
|
||||
}
|
||||
@@ -218,7 +258,7 @@ func (mach *OlmMachine) HandleOTKCounts(otkCount *mautrix.OTKCount) {
|
||||
//
|
||||
// This can be easily registered into a mautrix client using .OnSync():
|
||||
//
|
||||
// client.Syncer.(*mautrix.DefaultSyncer).OnSync(c.crypto.ProcessSyncResponse)
|
||||
// client.Syncer.(mautrix.ExtensibleSyncer).OnSync(c.crypto.ProcessSyncResponse)
|
||||
func (mach *OlmMachine) ProcessSyncResponse(resp *mautrix.RespSync, since string) bool {
|
||||
mach.HandleDeviceLists(&resp.DeviceLists, since)
|
||||
|
||||
@@ -226,7 +266,7 @@ func (mach *OlmMachine) ProcessSyncResponse(resp *mautrix.RespSync, since string
|
||||
evt.Type.Class = event.ToDeviceEventType
|
||||
err := evt.Content.ParseRaw(evt.Type)
|
||||
if err != nil {
|
||||
mach.Log.Warn("Failed to parse to-device event of type %s: %v", evt.Type.Type, err)
|
||||
mach.Log.Warn().Str("event_type", evt.Type.Type).Err(err).Msg("Failed to parse to-device event")
|
||||
continue
|
||||
}
|
||||
mach.HandleToDeviceEvent(evt)
|
||||
@@ -240,8 +280,8 @@ func (mach *OlmMachine) ProcessSyncResponse(resp *mautrix.RespSync, since string
|
||||
//
|
||||
// Currently this is not automatically called, so you must add a listener yourself:
|
||||
//
|
||||
// client.Syncer.(*mautrix.DefaultSyncer).OnEventType(event.StateMember, c.crypto.HandleMemberEvent)
|
||||
func (mach *OlmMachine) HandleMemberEvent(evt *event.Event) {
|
||||
// client.Syncer.(mautrix.ExtensibleSyncer).OnEventType(event.StateMember, c.crypto.HandleMemberEvent)
|
||||
func (mach *OlmMachine) HandleMemberEvent(_ mautrix.EventSource, evt *event.Event) {
|
||||
if !mach.StateStore.IsEncrypted(evt.RoomID) {
|
||||
return
|
||||
}
|
||||
@@ -263,10 +303,15 @@ func (mach *OlmMachine) HandleMemberEvent(evt *event.Event) {
|
||||
(prevContent.Membership == event.MembershipLeave && content.Membership == event.MembershipBan) {
|
||||
return
|
||||
}
|
||||
mach.Log.Trace("Got membership state event in %s changing %s from %s to %s, invalidating group session", evt.RoomID, evt.GetStateKey(), prevContent.Membership, content.Membership)
|
||||
mach.Log.Trace().
|
||||
Str("room_id", evt.RoomID.String()).
|
||||
Str("user_id", evt.GetStateKey()).
|
||||
Str("prev_membership", string(prevContent.Membership)).
|
||||
Str("new_membership", string(content.Membership)).
|
||||
Msg("Got membership state change, invalidating group session in room")
|
||||
err := mach.CryptoStore.RemoveOutboundGroupSession(evt.RoomID)
|
||||
if err != nil {
|
||||
mach.Log.Warn("Failed to invalidate outbound group session of %s: %v", evt.RoomID, err)
|
||||
mach.Log.Warn().Str("room_id", evt.RoomID.String()).Msg("Failed to invalidate outbound group session")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -275,43 +320,63 @@ func (mach *OlmMachine) HandleMemberEvent(evt *event.Event) {
|
||||
func (mach *OlmMachine) HandleToDeviceEvent(evt *event.Event) {
|
||||
if len(evt.ToUserID) > 0 && (evt.ToUserID != mach.Client.UserID || evt.ToDeviceID != mach.Client.DeviceID) {
|
||||
// TODO This log probably needs to be silence-able if someone wants to use encrypted appservices with multiple e2ee sessions
|
||||
mach.Log.Debug("Dropping to-device event targeted to %s/%s (not us)", evt.ToUserID, evt.ToDeviceID)
|
||||
mach.Log.Debug().
|
||||
Str("target_user_id", evt.ToUserID.String()).
|
||||
Str("target_device_id", evt.ToDeviceID.String()).
|
||||
Msg("Dropping to-device event targeted to someone else")
|
||||
return
|
||||
}
|
||||
traceID := time.Now().Format("15:04:05.000000")
|
||||
log := mach.Log.With().
|
||||
Str("trace_id", traceID).
|
||||
Str("sender", evt.Sender.String()).
|
||||
Str("type", evt.Type.Type).
|
||||
Logger()
|
||||
ctx := log.WithContext(context.Background())
|
||||
if evt.Type != event.ToDeviceEncrypted {
|
||||
mach.Log.Trace("Starting handling to-device event of type %s from %s (trace: %s)", evt.Type.Type, evt.Sender, traceID)
|
||||
log.Debug().Msg("Starting handling to-device event")
|
||||
}
|
||||
switch content := evt.Content.Parsed.(type) {
|
||||
case *event.EncryptedEventContent:
|
||||
mach.Log.Debug("Handling encrypted to-device event from %s/%s (trace: %s)", evt.Sender, content.SenderKey, traceID)
|
||||
decryptedEvt, err := mach.decryptOlmEvent(evt, traceID)
|
||||
log = log.With().
|
||||
Str("sender_key", content.SenderKey.String()).
|
||||
Logger()
|
||||
log.Debug().Msg("Handling encrypted to-device event")
|
||||
ctx = log.WithContext(context.Background())
|
||||
decryptedEvt, err := mach.decryptOlmEvent(ctx, evt)
|
||||
if err != nil {
|
||||
mach.Log.Error("Failed to decrypt to-device event: %v (trace: %s)", err, traceID)
|
||||
log.Error().Err(err).Msg("Failed to decrypt to-device event")
|
||||
return
|
||||
}
|
||||
mach.Log.Trace("Successfully decrypted to-device from %s/%s into type %s (sender key: %s, trace: %s)", decryptedEvt.Sender, decryptedEvt.SenderDevice, decryptedEvt.Type.String(), decryptedEvt.SenderKey, traceID)
|
||||
log = log.With().
|
||||
Str("decrypted_type", decryptedEvt.Type.Type).
|
||||
Str("sender_device", decryptedEvt.SenderDevice.String()).
|
||||
Str("sender_signing_key", decryptedEvt.Keys.Ed25519.String()).
|
||||
Logger()
|
||||
log.Trace().Msg("Successfully decrypted to-device event")
|
||||
|
||||
switch decryptedContent := decryptedEvt.Content.Parsed.(type) {
|
||||
case *event.RoomKeyEventContent:
|
||||
mach.receiveRoomKey(decryptedEvt, decryptedContent, traceID)
|
||||
mach.Log.Trace("Handled room key event from %s/%s (trace: %s)", decryptedEvt.Sender, decryptedEvt.SenderDevice, traceID)
|
||||
mach.receiveRoomKey(ctx, decryptedEvt, decryptedContent)
|
||||
log.Trace().Msg("Handled room key event")
|
||||
case *event.ForwardedRoomKeyEventContent:
|
||||
if mach.importForwardedRoomKey(decryptedEvt, decryptedContent) {
|
||||
if mach.importForwardedRoomKey(ctx, decryptedEvt, decryptedContent) {
|
||||
if ch, ok := mach.roomKeyRequestFilled.Load(decryptedContent.SessionID); ok {
|
||||
// close channel to notify listener that the key was received
|
||||
close(ch.(chan struct{}))
|
||||
}
|
||||
}
|
||||
mach.Log.Trace("Handled forwarded room key event from %s/%s (trace: %s)", decryptedEvt.Sender, decryptedEvt.SenderDevice, traceID)
|
||||
log.Trace().Msg("Handled forwarded room key event")
|
||||
case *event.DummyEventContent:
|
||||
mach.Log.Debug("Received encrypted dummy event from %s/%s (trace: %s)", decryptedEvt.Sender, decryptedEvt.SenderDevice, traceID)
|
||||
log.Debug().Msg("Received encrypted dummy event")
|
||||
default:
|
||||
mach.Log.Debug("Unhandled encrypted to-device event of type %s from %s/%s (trace: %s)", decryptedEvt.Type.String(), decryptedEvt.Sender, decryptedEvt.SenderDevice, traceID)
|
||||
log.Debug().Msg("Unhandled encrypted to-device event")
|
||||
}
|
||||
return
|
||||
case *event.RoomKeyRequestEventContent:
|
||||
go mach.handleRoomKeyRequest(evt.Sender, content)
|
||||
go mach.handleRoomKeyRequest(ctx, evt.Sender, content)
|
||||
case *event.BeeperRoomKeyAckEventContent:
|
||||
mach.handleBeeperRoomKeyAck(ctx, evt.Sender, content)
|
||||
// verification cases
|
||||
case *event.VerificationStartEventContent:
|
||||
mach.handleVerificationStart(evt.Sender, content, content.TransactionID, 10*time.Minute, "")
|
||||
@@ -326,27 +391,25 @@ func (mach *OlmMachine) HandleToDeviceEvent(evt *event.Event) {
|
||||
case *event.VerificationRequestEventContent:
|
||||
mach.handleVerificationRequest(evt.Sender, content, content.TransactionID, "")
|
||||
case *event.RoomKeyWithheldEventContent:
|
||||
mach.handleRoomKeyWithheld(content)
|
||||
mach.handleRoomKeyWithheld(ctx, content)
|
||||
default:
|
||||
deviceID, _ := evt.Content.Raw["device_id"].(string)
|
||||
mach.Log.Trace("Unhandled to-device event of type %s from %s/%s (trace: %s)", evt.Type.Type, evt.Sender, deviceID, traceID)
|
||||
log.Debug().Str("maybe_device_id", deviceID).Msg("Unhandled to-device event")
|
||||
return
|
||||
}
|
||||
mach.Log.Trace("Finished handling to-device event of type %s from %s (trace: %s)", evt.Type.Type, evt.Sender, traceID)
|
||||
log.Debug().Msg("Finished handling to-device event")
|
||||
}
|
||||
|
||||
// GetOrFetchDevice attempts to retrieve the device identity for the given device from the store
|
||||
// and if it's not found it asks the server for it.
|
||||
func (mach *OlmMachine) GetOrFetchDevice(userID id.UserID, deviceID id.DeviceID) (*id.Device, error) {
|
||||
// get device identity
|
||||
func (mach *OlmMachine) GetOrFetchDevice(ctx context.Context, userID id.UserID, deviceID id.DeviceID) (*id.Device, error) {
|
||||
device, err := mach.CryptoStore.GetDevice(userID, deviceID)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to get sender device from store: %w", err)
|
||||
} else if device != nil {
|
||||
return device, nil
|
||||
}
|
||||
// try to fetch if not found
|
||||
usersToDevices := mach.fetchKeys([]id.UserID{userID}, "", true)
|
||||
usersToDevices := mach.fetchKeys(ctx, []id.UserID{userID}, "", true)
|
||||
if devices, ok := usersToDevices[userID]; ok {
|
||||
if device, ok = devices[deviceID]; ok {
|
||||
return device, nil
|
||||
@@ -359,12 +422,15 @@ func (mach *OlmMachine) GetOrFetchDevice(userID id.UserID, deviceID id.DeviceID)
|
||||
// GetOrFetchDeviceByKey attempts to retrieve the device identity for the device with the given identity key from the
|
||||
// store and if it's not found it asks the server for it. This returns nil if the server doesn't return a device with
|
||||
// the given identity key.
|
||||
func (mach *OlmMachine) GetOrFetchDeviceByKey(userID id.UserID, identityKey id.IdentityKey) (*id.Device, error) {
|
||||
func (mach *OlmMachine) GetOrFetchDeviceByKey(ctx context.Context, userID id.UserID, identityKey id.IdentityKey) (*id.Device, error) {
|
||||
deviceIdentity, err := mach.CryptoStore.FindDeviceByKey(userID, identityKey)
|
||||
if err != nil || deviceIdentity != nil {
|
||||
return deviceIdentity, err
|
||||
}
|
||||
mach.Log.Debug("Didn't find identity of %s/%s in crypto store, fetching from server", userID, identityKey)
|
||||
mach.machOrContextLog(ctx).Debug().
|
||||
Str("user_id", userID.String()).
|
||||
Str("identity_key", identityKey.String()).
|
||||
Msg("Didn't find identity in crypto store, fetching from server")
|
||||
devices := mach.LoadDevices(userID)
|
||||
for _, device := range devices {
|
||||
if device.IdentityKey == identityKey {
|
||||
@@ -375,8 +441,8 @@ func (mach *OlmMachine) GetOrFetchDeviceByKey(userID id.UserID, identityKey id.I
|
||||
}
|
||||
|
||||
// SendEncryptedToDevice sends an Olm-encrypted event to the given user device.
|
||||
func (mach *OlmMachine) SendEncryptedToDevice(device *id.Device, evtType event.Type, content event.Content) error {
|
||||
if err := mach.createOutboundSessions(map[id.UserID]map[id.DeviceID]*id.Device{
|
||||
func (mach *OlmMachine) SendEncryptedToDevice(ctx context.Context, device *id.Device, evtType event.Type, content event.Content) error {
|
||||
if err := mach.createOutboundSessions(ctx, map[id.UserID]map[id.DeviceID]*id.Device{
|
||||
device.UserID: {
|
||||
device.DeviceID: device,
|
||||
},
|
||||
@@ -395,10 +461,16 @@ func (mach *OlmMachine) SendEncryptedToDevice(device *id.Device, evtType event.T
|
||||
return fmt.Errorf("didn't find created outbound session for device %s of %s", device.DeviceID, device.UserID)
|
||||
}
|
||||
|
||||
encrypted := mach.encryptOlmEvent(olmSess, device, evtType, content)
|
||||
encrypted := mach.encryptOlmEvent(ctx, olmSess, device, evtType, content)
|
||||
encryptedContent := &event.Content{Parsed: &encrypted}
|
||||
|
||||
mach.Log.Debug("Sending encrypted to-device event of type %s to %s/%s (identity key: %s, olm session ID: %s)", evtType.Type, device.UserID, device.DeviceID, device.IdentityKey, olmSess.ID())
|
||||
mach.machOrContextLog(ctx).Debug().
|
||||
Str("decrypted_type", evtType.Type).
|
||||
Str("to_user_id", device.UserID.String()).
|
||||
Str("to_device_id", device.DeviceID.String()).
|
||||
Str("to_identity_key", device.IdentityKey.String()).
|
||||
Str("olm_session_id", olmSess.ID().String()).
|
||||
Msg("Sending encrypted to-device event")
|
||||
_, err = mach.Client.SendToDevice(event.ToDeviceEncrypted,
|
||||
&mautrix.ReqSendToDevice{
|
||||
Messages: map[id.UserID]map[id.DeviceID]*event.Content{
|
||||
@@ -412,22 +484,32 @@ func (mach *OlmMachine) SendEncryptedToDevice(device *id.Device, evtType event.T
|
||||
return err
|
||||
}
|
||||
|
||||
func (mach *OlmMachine) createGroupSession(senderKey id.SenderKey, signingKey id.Ed25519, roomID id.RoomID, sessionID id.SessionID, sessionKey string, traceID string) {
|
||||
igs, err := NewInboundGroupSession(senderKey, signingKey, roomID, sessionKey)
|
||||
func (mach *OlmMachine) createGroupSession(ctx context.Context, senderKey id.SenderKey, signingKey id.Ed25519, roomID id.RoomID, sessionID id.SessionID, sessionKey string, maxAge time.Duration, maxMessages int, isScheduled bool) {
|
||||
log := zerolog.Ctx(ctx)
|
||||
igs, err := NewInboundGroupSession(senderKey, signingKey, roomID, sessionKey, maxAge, maxMessages, isScheduled)
|
||||
if err != nil {
|
||||
mach.Log.Error("Failed to create inbound group session: %v", err)
|
||||
log.Error().Err(err).Msg("Failed to create inbound group session")
|
||||
return
|
||||
} else if igs.ID() != sessionID {
|
||||
mach.Log.Warn("Mismatched session ID while creating inbound group session")
|
||||
log.Warn().
|
||||
Str("expected_session_id", sessionID.String()).
|
||||
Str("actual_session_id", igs.ID().String()).
|
||||
Msg("Mismatched session ID while creating inbound group session")
|
||||
return
|
||||
}
|
||||
err = mach.CryptoStore.PutGroupSession(roomID, senderKey, sessionID, igs)
|
||||
if err != nil {
|
||||
mach.Log.Error("Failed to store new inbound group session: %v", err)
|
||||
log.Error().Err(err).Str("session_id", sessionID.String()).Msg("Failed to store new inbound group session")
|
||||
return
|
||||
}
|
||||
mach.markSessionReceived(sessionID)
|
||||
mach.Log.Debug("Received inbound group session %s / %s / %s", roomID, senderKey, sessionID)
|
||||
log.Debug().
|
||||
Str("session_id", sessionID.String()).
|
||||
Str("sender_key", senderKey.String()).
|
||||
Str("max_age", maxAge.String()).
|
||||
Int("max_messages", maxMessages).
|
||||
Bool("is_scheduled", isScheduled).
|
||||
Msg("Received inbound group session")
|
||||
}
|
||||
|
||||
func (mach *OlmMachine) markSessionReceived(id id.SessionID) {
|
||||
@@ -465,24 +547,66 @@ func (mach *OlmMachine) WaitForSession(roomID id.RoomID, senderKey id.SenderKey,
|
||||
}
|
||||
}
|
||||
|
||||
func (mach *OlmMachine) receiveRoomKey(evt *DecryptedOlmEvent, content *event.RoomKeyEventContent, traceID string) {
|
||||
// TODO nio had a comment saying "handle this better" for the case where evt.Keys.Ed25519 is none?
|
||||
func stringifyArray[T ~string](arr []T) []string {
|
||||
strs := make([]string, len(arr))
|
||||
for i, v := range arr {
|
||||
strs[i] = string(v)
|
||||
}
|
||||
return strs
|
||||
}
|
||||
|
||||
func (mach *OlmMachine) receiveRoomKey(ctx context.Context, evt *DecryptedOlmEvent, content *event.RoomKeyEventContent) {
|
||||
log := zerolog.Ctx(ctx).With().
|
||||
Str("algorithm", string(content.Algorithm)).
|
||||
Str("session_id", content.SessionID.String()).
|
||||
Str("room_id", content.RoomID.String()).
|
||||
Logger()
|
||||
if content.Algorithm != id.AlgorithmMegolmV1 || evt.Keys.Ed25519 == "" {
|
||||
mach.Log.Debug("Ignoring weird room key from %s/%s: alg=%s, ed25519=%s, sessionid=%s, roomid=%s", evt.Sender, evt.SenderDevice, content.Algorithm, evt.Keys.Ed25519, content.SessionID, content.RoomID)
|
||||
log.Debug().Msg("Ignoring weird room key")
|
||||
return
|
||||
}
|
||||
|
||||
mach.createGroupSession(evt.SenderKey, evt.Keys.Ed25519, content.RoomID, content.SessionID, content.SessionKey, traceID)
|
||||
config := mach.StateStore.GetEncryptionEvent(content.RoomID)
|
||||
var maxAge time.Duration
|
||||
var maxMessages int
|
||||
if config != nil {
|
||||
maxAge = time.Duration(config.RotationPeriodMillis) * time.Millisecond
|
||||
if maxAge == 0 {
|
||||
maxAge = 7 * 24 * time.Hour
|
||||
}
|
||||
maxMessages = config.RotationPeriodMessages
|
||||
if maxMessages == 0 {
|
||||
maxMessages = 100
|
||||
}
|
||||
}
|
||||
if content.MaxAge != 0 {
|
||||
maxAge = time.Duration(content.MaxAge) * time.Millisecond
|
||||
}
|
||||
if content.MaxMessages != 0 {
|
||||
maxMessages = content.MaxMessages
|
||||
}
|
||||
if mach.DeletePreviousKeysOnReceive && !content.IsScheduled {
|
||||
log.Debug().Msg("Redacting previous megolm sessions from sender in room")
|
||||
sessionIDs, err := mach.CryptoStore.RedactGroupSessions(content.RoomID, evt.SenderKey, "received new key from device")
|
||||
if err != nil {
|
||||
log.Err(err).Msg("Failed to redact previous megolm sessions")
|
||||
} else {
|
||||
log.Info().
|
||||
Strs("session_ids", stringifyArray(sessionIDs)).
|
||||
Msg("Redacted previous megolm sessions")
|
||||
}
|
||||
}
|
||||
mach.createGroupSession(ctx, evt.SenderKey, evt.Keys.Ed25519, content.RoomID, content.SessionID, content.SessionKey, maxAge, maxMessages, content.IsScheduled)
|
||||
}
|
||||
|
||||
func (mach *OlmMachine) handleRoomKeyWithheld(content *event.RoomKeyWithheldEventContent) {
|
||||
func (mach *OlmMachine) handleRoomKeyWithheld(ctx context.Context, content *event.RoomKeyWithheldEventContent) {
|
||||
if content.Algorithm != id.AlgorithmMegolmV1 {
|
||||
mach.Log.Debug("Non-megolm room key withheld event: %+v", content)
|
||||
zerolog.Ctx(ctx).Debug().Interface("content", content).Msg("Non-megolm room key withheld event")
|
||||
return
|
||||
}
|
||||
err := mach.CryptoStore.PutWithheldGroupSession(*content)
|
||||
if err != nil {
|
||||
mach.Log.Error("Failed to save room key withheld event: %v", err)
|
||||
zerolog.Ctx(ctx).Error().Err(err).Msg("Failed to save room key withheld event")
|
||||
}
|
||||
}
|
||||
|
||||
@@ -491,34 +615,38 @@ func (mach *OlmMachine) handleRoomKeyWithheld(content *event.RoomKeyWithheldEven
|
||||
// If the Olm account hasn't been shared, the account keys will be uploaded.
|
||||
// If currentOTKCount is less than half of the limit (100 / 2 = 50), enough one-time keys will be uploaded so exactly
|
||||
// half of the limit is filled.
|
||||
func (mach *OlmMachine) ShareKeys(currentOTKCount int) error {
|
||||
func (mach *OlmMachine) ShareKeys(ctx context.Context, currentOTKCount int) error {
|
||||
log := mach.machOrContextLog(ctx)
|
||||
start := time.Now()
|
||||
mach.otkUploadLock.Lock()
|
||||
defer mach.otkUploadLock.Unlock()
|
||||
if mach.lastOTKUpload.Add(1 * time.Minute).After(start) {
|
||||
mach.Log.Trace("Checking OTK count from server due to suspiciously close share keys requests")
|
||||
log.Debug().Msg("Checking OTK count from server due to suspiciously close share keys requests")
|
||||
resp, err := mach.Client.UploadKeys(&mautrix.ReqUploadKeys{})
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to check current OTK counts: %w", err)
|
||||
}
|
||||
mach.Log.Trace("Fetched current OTK count (%d) from server (input count was %d)", resp.OneTimeKeyCounts.SignedCurve25519, currentOTKCount)
|
||||
log.Debug().
|
||||
Int("input_count", currentOTKCount).
|
||||
Int("server_count", resp.OneTimeKeyCounts.SignedCurve25519).
|
||||
Msg("Fetched current OTK count from server")
|
||||
currentOTKCount = resp.OneTimeKeyCounts.SignedCurve25519
|
||||
}
|
||||
var deviceKeys *mautrix.DeviceKeys
|
||||
if !mach.account.Shared {
|
||||
deviceKeys = mach.account.getInitialKeys(mach.Client.UserID, mach.Client.DeviceID)
|
||||
mach.Log.Trace("Going to upload initial account keys")
|
||||
log.Debug().Msg("Going to upload initial account keys")
|
||||
}
|
||||
oneTimeKeys := mach.account.getOneTimeKeys(mach.Client.UserID, mach.Client.DeviceID, currentOTKCount)
|
||||
if len(oneTimeKeys) == 0 && deviceKeys == nil {
|
||||
mach.Log.Trace("No one-time keys nor device keys got when trying to share keys")
|
||||
log.Debug().Msg("No one-time keys nor device keys got when trying to share keys")
|
||||
return nil
|
||||
}
|
||||
req := &mautrix.ReqUploadKeys{
|
||||
DeviceKeys: deviceKeys,
|
||||
OneTimeKeys: oneTimeKeys,
|
||||
}
|
||||
mach.Log.Trace("Uploading %d one-time keys", len(oneTimeKeys))
|
||||
log.Debug().Int("count", len(oneTimeKeys)).Msg("Uploading one-time keys")
|
||||
_, err := mach.Client.UploadKeys(req)
|
||||
if err != nil {
|
||||
return err
|
||||
@@ -528,3 +656,23 @@ func (mach *OlmMachine) ShareKeys(currentOTKCount int) error {
|
||||
mach.saveAccount()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (mach *OlmMachine) ExpiredKeyDeleteLoop(ctx context.Context) {
|
||||
log := mach.Log.With().Str("action", "redact expired sessions").Logger()
|
||||
for {
|
||||
sessionIDs, err := mach.CryptoStore.RedactExpiredGroupSessions()
|
||||
if err != nil {
|
||||
log.Err(err).Msg("Failed to redact expired megolm sessions")
|
||||
} else if len(sessionIDs) > 0 {
|
||||
log.Info().Strs("session_ids", stringifyArray(sessionIDs)).Msg("Redacted expired megolm sessions")
|
||||
} else {
|
||||
log.Debug().Msg("Didn't find any expired megolm sessions")
|
||||
}
|
||||
select {
|
||||
case <-ctx.Done():
|
||||
log.Debug().Msg("Loop stopped")
|
||||
return
|
||||
case <-time.After(24 * time.Hour):
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
6
vendor/maunium.net/go/mautrix/crypto/olm/inboundgroupsession.go
generated
vendored
6
vendor/maunium.net/go/mautrix/crypto/olm/inboundgroupsession.go
generated
vendored
@@ -288,7 +288,7 @@ func (s *InboundGroupSession) exportLen() uint {
|
||||
// if we do not have a session key corresponding to the given index (ie, it was
|
||||
// sent before the session key was shared with us) the error will be
|
||||
// "OLM_UNKNOWN_MESSAGE_INDEX".
|
||||
func (s *InboundGroupSession) Export(messageIndex uint32) (string, error) {
|
||||
func (s *InboundGroupSession) Export(messageIndex uint32) ([]byte, error) {
|
||||
key := make([]byte, s.exportLen())
|
||||
r := C.olm_export_inbound_group_session(
|
||||
(*C.OlmInboundGroupSession)(s.int),
|
||||
@@ -296,7 +296,7 @@ func (s *InboundGroupSession) Export(messageIndex uint32) (string, error) {
|
||||
C.size_t(len(key)),
|
||||
C.uint32_t(messageIndex))
|
||||
if r == errorVal() {
|
||||
return "", s.lastError()
|
||||
return nil, s.lastError()
|
||||
}
|
||||
return string(key[:r]), nil
|
||||
return key[:r], nil
|
||||
}
|
||||
|
||||
3
vendor/maunium.net/go/mautrix/crypto/olm/verification.go
generated
vendored
3
vendor/maunium.net/go/mautrix/crypto/olm/verification.go
generated
vendored
@@ -1,6 +1,3 @@
|
||||
//go:build !nosas
|
||||
// +build !nosas
|
||||
|
||||
package olm
|
||||
|
||||
// #cgo LDFLAGS: -lolm -lstdc++
|
||||
|
||||
31
vendor/maunium.net/go/mautrix/crypto/sessions.go
generated
vendored
31
vendor/maunium.net/go/mautrix/crypto/sessions.go
generated
vendored
@@ -89,6 +89,12 @@ func (session *OlmSession) Decrypt(ciphertext string, msgType id.OlmMsgType) ([]
|
||||
return msg, err
|
||||
}
|
||||
|
||||
type RatchetSafety struct {
|
||||
NextIndex uint `json:"next_index"`
|
||||
MissedIndices []uint `json:"missed_indices,omitempty"`
|
||||
LostIndices []uint `json:"lost_indices,omitempty"`
|
||||
}
|
||||
|
||||
type InboundGroupSession struct {
|
||||
Internal olm.InboundGroupSession
|
||||
|
||||
@@ -97,11 +103,17 @@ type InboundGroupSession struct {
|
||||
RoomID id.RoomID
|
||||
|
||||
ForwardingChains []string
|
||||
RatchetSafety RatchetSafety
|
||||
|
||||
ReceivedAt time.Time
|
||||
MaxAge int64
|
||||
MaxMessages int
|
||||
IsScheduled bool
|
||||
|
||||
id id.SessionID
|
||||
}
|
||||
|
||||
func NewInboundGroupSession(senderKey id.SenderKey, signingKey id.Ed25519, roomID id.RoomID, sessionKey string) (*InboundGroupSession, error) {
|
||||
func NewInboundGroupSession(senderKey id.SenderKey, signingKey id.Ed25519, roomID id.RoomID, sessionKey string, maxAge time.Duration, maxMessages int, isScheduled bool) (*InboundGroupSession, error) {
|
||||
igs, err := olm.NewInboundGroupSession([]byte(sessionKey))
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -112,6 +124,10 @@ func NewInboundGroupSession(senderKey id.SenderKey, signingKey id.Ed25519, roomI
|
||||
SenderKey: senderKey,
|
||||
RoomID: roomID,
|
||||
ForwardingChains: nil,
|
||||
ReceivedAt: time.Now().UTC(),
|
||||
MaxAge: maxAge.Milliseconds(),
|
||||
MaxMessages: maxMessages,
|
||||
IsScheduled: isScheduled,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -122,6 +138,19 @@ func (igs *InboundGroupSession) ID() id.SessionID {
|
||||
return igs.id
|
||||
}
|
||||
|
||||
func (igs *InboundGroupSession) RatchetTo(index uint32) error {
|
||||
exported, err := igs.Internal.Export(index)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
imported, err := olm.InboundGroupSessionImport(exported)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
igs.Internal = *imported
|
||||
return nil
|
||||
}
|
||||
|
||||
type OGSState int
|
||||
|
||||
const (
|
||||
|
||||
219
vendor/maunium.net/go/mautrix/crypto/sql_store.go
generated
vendored
219
vendor/maunium.net/go/mautrix/crypto/sql_store.go
generated
vendored
@@ -7,13 +7,19 @@
|
||||
package crypto
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"database/sql/driver"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strings"
|
||||
"sync"
|
||||
"time"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
|
||||
"maunium.net/go/mautrix"
|
||||
"maunium.net/go/mautrix/crypto/olm"
|
||||
"maunium.net/go/mautrix/crypto/sql_store_upgrade"
|
||||
"maunium.net/go/mautrix/event"
|
||||
@@ -80,6 +86,35 @@ func (store *SQLCryptoStore) GetNextBatch() (string, error) {
|
||||
return store.SyncToken, nil
|
||||
}
|
||||
|
||||
var _ mautrix.SyncStore = (*SQLCryptoStore)(nil)
|
||||
|
||||
func (store *SQLCryptoStore) SaveFilterID(_ id.UserID, _ string) {}
|
||||
func (store *SQLCryptoStore) LoadFilterID(_ id.UserID) string { return "" }
|
||||
|
||||
func (store *SQLCryptoStore) SaveNextBatch(_ id.UserID, nextBatchToken string) {
|
||||
err := store.PutNextBatch(nextBatchToken)
|
||||
if err != nil {
|
||||
// TODO handle error
|
||||
}
|
||||
}
|
||||
|
||||
func (store *SQLCryptoStore) LoadNextBatch(_ id.UserID) string {
|
||||
nb, err := store.GetNextBatch()
|
||||
if err != nil {
|
||||
// TODO handle error
|
||||
}
|
||||
return nb
|
||||
}
|
||||
|
||||
func (store *SQLCryptoStore) FindDeviceID() (deviceID id.DeviceID) {
|
||||
err := store.DB.QueryRow("SELECT device_id FROM crypto_account WHERE account_id=$1", store.AccountID).Scan(&deviceID)
|
||||
if err != nil && !errors.Is(err, sql.ErrNoRows) {
|
||||
// TODO return error
|
||||
store.DB.Log.Warn("Failed to scan device ID: %v", err)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
// PutAccount stores an OlmAccount in the database.
|
||||
func (store *SQLCryptoStore) PutAccount(account *OlmAccount) error {
|
||||
store.Account = account
|
||||
@@ -220,37 +255,72 @@ func (store *SQLCryptoStore) UpdateSession(_ id.SenderKey, session *OlmSession)
|
||||
return err
|
||||
}
|
||||
|
||||
func intishPtr[T int | int64](i T) *T {
|
||||
if i == 0 {
|
||||
return nil
|
||||
}
|
||||
return &i
|
||||
}
|
||||
|
||||
func datePtr(t time.Time) *time.Time {
|
||||
if t.IsZero() {
|
||||
return nil
|
||||
}
|
||||
return &t
|
||||
}
|
||||
|
||||
// PutGroupSession stores an inbound Megolm group session for a room, sender and session.
|
||||
func (store *SQLCryptoStore) PutGroupSession(roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID, session *InboundGroupSession) error {
|
||||
sessionBytes := session.Internal.Pickle(store.PickleKey)
|
||||
forwardingChains := strings.Join(session.ForwardingChains, ",")
|
||||
_, err := store.DB.Exec(`
|
||||
INSERT INTO crypto_megolm_inbound_session
|
||||
(session_id, sender_key, signing_key, room_id, session, forwarding_chains, account_id)
|
||||
VALUES ($1, $2, $3, $4, $5, $6, $7)
|
||||
ratchetSafety, err := json.Marshal(&session.RatchetSafety)
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to marshal ratchet safety info: %w", err)
|
||||
}
|
||||
_, err = store.DB.Exec(`
|
||||
INSERT INTO crypto_megolm_inbound_session (
|
||||
session_id, sender_key, signing_key, room_id, session, forwarding_chains,
|
||||
ratchet_safety, received_at, max_age, max_messages, is_scheduled, account_id
|
||||
) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12)
|
||||
ON CONFLICT (session_id, account_id) DO UPDATE
|
||||
SET withheld_code=NULL, withheld_reason=NULL, sender_key=excluded.sender_key, signing_key=excluded.signing_key,
|
||||
room_id=excluded.room_id, session=excluded.session, forwarding_chains=excluded.forwarding_chains
|
||||
`, sessionID, senderKey, session.SigningKey, roomID, sessionBytes, forwardingChains, store.AccountID)
|
||||
room_id=excluded.room_id, session=excluded.session, forwarding_chains=excluded.forwarding_chains,
|
||||
ratchet_safety=excluded.ratchet_safety, received_at=excluded.received_at,
|
||||
max_age=excluded.max_age, max_messages=excluded.max_messages, is_scheduled=excluded.is_scheduled
|
||||
`,
|
||||
sessionID, senderKey, session.SigningKey, roomID, sessionBytes, forwardingChains,
|
||||
ratchetSafety, datePtr(session.ReceivedAt), intishPtr(session.MaxAge), intishPtr(session.MaxMessages),
|
||||
session.IsScheduled, store.AccountID,
|
||||
)
|
||||
return err
|
||||
}
|
||||
|
||||
// GetGroupSession retrieves an inbound Megolm group session for a room, sender and session.
|
||||
func (store *SQLCryptoStore) GetGroupSession(roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID) (*InboundGroupSession, error) {
|
||||
var signingKey, forwardingChains, withheldCode sql.NullString
|
||||
var sessionBytes []byte
|
||||
var senderKeyDB, signingKey, forwardingChains, withheldCode, withheldReason sql.NullString
|
||||
var sessionBytes, ratchetSafetyBytes []byte
|
||||
var receivedAt sql.NullTime
|
||||
var maxAge, maxMessages sql.NullInt64
|
||||
var isScheduled bool
|
||||
err := store.DB.QueryRow(`
|
||||
SELECT signing_key, session, forwarding_chains, withheld_code
|
||||
SELECT sender_key, signing_key, session, forwarding_chains, withheld_code, withheld_reason, ratchet_safety, received_at, max_age, max_messages, is_scheduled
|
||||
FROM crypto_megolm_inbound_session
|
||||
WHERE room_id=$1 AND sender_key=$2 AND session_id=$3 AND account_id=$4`,
|
||||
WHERE room_id=$1 AND (sender_key=$2 OR $2 = '') AND session_id=$3 AND account_id=$4`,
|
||||
roomID, senderKey, sessionID, store.AccountID,
|
||||
).Scan(&signingKey, &sessionBytes, &forwardingChains, &withheldCode)
|
||||
).Scan(&senderKeyDB, &signingKey, &sessionBytes, &forwardingChains, &withheldCode, &withheldReason, &ratchetSafetyBytes, &receivedAt, &maxAge, &maxMessages, &isScheduled)
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, nil
|
||||
} else if err != nil {
|
||||
return nil, err
|
||||
} else if withheldCode.Valid {
|
||||
return nil, fmt.Errorf("%w (%s)", ErrGroupSessionWithheld, withheldCode.String)
|
||||
return nil, &event.RoomKeyWithheldEventContent{
|
||||
RoomID: roomID,
|
||||
Algorithm: id.AlgorithmMegolmV1,
|
||||
SessionID: sessionID,
|
||||
SenderKey: senderKey,
|
||||
Code: event.RoomKeyWithheldCode(withheldCode.String),
|
||||
Reason: withheldReason.String,
|
||||
}
|
||||
}
|
||||
igs := olm.NewBlankInboundGroupSession()
|
||||
err = igs.Unpickle(sessionBytes, store.PickleKey)
|
||||
@@ -261,18 +331,96 @@ func (store *SQLCryptoStore) GetGroupSession(roomID id.RoomID, senderKey id.Send
|
||||
if forwardingChains.String != "" {
|
||||
chains = strings.Split(forwardingChains.String, ",")
|
||||
}
|
||||
var rs RatchetSafety
|
||||
if len(ratchetSafetyBytes) > 0 {
|
||||
err = json.Unmarshal(ratchetSafetyBytes, &rs)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to unmarshal ratchet safety info: %w", err)
|
||||
}
|
||||
}
|
||||
if senderKey == "" {
|
||||
senderKey = id.Curve25519(senderKeyDB.String)
|
||||
}
|
||||
return &InboundGroupSession{
|
||||
Internal: *igs,
|
||||
SigningKey: id.Ed25519(signingKey.String),
|
||||
SenderKey: senderKey,
|
||||
RoomID: roomID,
|
||||
ForwardingChains: chains,
|
||||
RatchetSafety: rs,
|
||||
ReceivedAt: receivedAt.Time,
|
||||
MaxAge: maxAge.Int64,
|
||||
MaxMessages: int(maxMessages.Int64),
|
||||
IsScheduled: isScheduled,
|
||||
}, nil
|
||||
}
|
||||
|
||||
func (store *SQLCryptoStore) RedactGroupSession(_ id.RoomID, _ id.SenderKey, sessionID id.SessionID, reason string) error {
|
||||
_, err := store.DB.Exec(`
|
||||
UPDATE crypto_megolm_inbound_session
|
||||
SET withheld_code=$1, withheld_reason=$2, session=NULL, forwarding_chains=NULL
|
||||
WHERE session_id=$3 AND account_id=$4 AND session IS NOT NULL
|
||||
`, event.RoomKeyWithheldBeeperRedacted, "Session redacted: "+reason, sessionID, store.AccountID)
|
||||
return err
|
||||
}
|
||||
|
||||
func (store *SQLCryptoStore) RedactGroupSessions(roomID id.RoomID, senderKey id.SenderKey, reason string) ([]id.SessionID, error) {
|
||||
if roomID == "" && senderKey == "" {
|
||||
return nil, fmt.Errorf("room ID or sender key must be provided for redacting sessions")
|
||||
}
|
||||
res, err := store.DB.Query(`
|
||||
UPDATE crypto_megolm_inbound_session
|
||||
SET withheld_code=$1, withheld_reason=$2, session=NULL, forwarding_chains=NULL
|
||||
WHERE (room_id=$3 OR $3='') AND (sender_key=$4 OR $4='') AND account_id=$5
|
||||
AND session IS NOT NULL AND is_scheduled=false AND received_at IS NOT NULL
|
||||
RETURNING session_id
|
||||
`, event.RoomKeyWithheldBeeperRedacted, "Session redacted: "+reason, roomID, senderKey, store.AccountID)
|
||||
var sessionIDs []id.SessionID
|
||||
for res.Next() {
|
||||
var sessionID id.SessionID
|
||||
_ = res.Scan(&sessionID)
|
||||
sessionIDs = append(sessionIDs, sessionID)
|
||||
}
|
||||
return sessionIDs, err
|
||||
}
|
||||
|
||||
func (store *SQLCryptoStore) RedactExpiredGroupSessions() ([]id.SessionID, error) {
|
||||
var query string
|
||||
switch store.DB.Dialect {
|
||||
case dbutil.Postgres:
|
||||
query = `
|
||||
UPDATE crypto_megolm_inbound_session
|
||||
SET withheld_code=$1, withheld_reason=$2, session=NULL, forwarding_chains=NULL
|
||||
WHERE account_id=$3 AND session IS NOT NULL AND is_scheduled=false
|
||||
AND received_at IS NOT NULL and max_age IS NOT NULL
|
||||
AND received_at + 2 * (max_age * interval '1 millisecond') < now()
|
||||
RETURNING session_id
|
||||
`
|
||||
case dbutil.SQLite:
|
||||
query = `
|
||||
UPDATE crypto_megolm_inbound_session
|
||||
SET withheld_code=$1, withheld_reason=$2, session=NULL, forwarding_chains=NULL
|
||||
WHERE account_id=$3 AND session IS NOT NULL AND is_scheduled=false
|
||||
AND received_at IS NOT NULL and max_age IS NOT NULL
|
||||
AND unixepoch(received_at) + (2 * max_age / 1000) < unixepoch(date('now'))
|
||||
RETURNING session_id
|
||||
`
|
||||
default:
|
||||
return nil, fmt.Errorf("unsupported dialect")
|
||||
}
|
||||
res, err := store.DB.Query(query, event.RoomKeyWithheldBeeperRedacted, "Session redacted: expired", store.AccountID)
|
||||
var sessionIDs []id.SessionID
|
||||
for res.Next() {
|
||||
var sessionID id.SessionID
|
||||
_ = res.Scan(&sessionID)
|
||||
sessionIDs = append(sessionIDs, sessionID)
|
||||
}
|
||||
return sessionIDs, err
|
||||
}
|
||||
|
||||
func (store *SQLCryptoStore) PutWithheldGroupSession(content event.RoomKeyWithheldEventContent) error {
|
||||
_, err := store.DB.Exec("INSERT INTO crypto_megolm_inbound_session (session_id, sender_key, room_id, withheld_code, withheld_reason, account_id) VALUES ($1, $2, $3, $4, $5, $6)",
|
||||
content.SessionID, content.SenderKey, content.RoomID, content.Code, content.Reason, store.AccountID)
|
||||
_, err := store.DB.Exec("INSERT INTO crypto_megolm_inbound_session (session_id, sender_key, room_id, withheld_code, withheld_reason, received_at, account_id) VALUES ($1, $2, $3, $4, $5, $6, $7)",
|
||||
content.SessionID, content.SenderKey, content.RoomID, content.Code, content.Reason, time.Now().UTC(), store.AccountID)
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -302,8 +450,11 @@ func (store *SQLCryptoStore) scanGroupSessionList(rows dbutil.Rows) (result []*I
|
||||
for rows.Next() {
|
||||
var roomID id.RoomID
|
||||
var signingKey, senderKey, forwardingChains sql.NullString
|
||||
var sessionBytes []byte
|
||||
err = rows.Scan(&roomID, &signingKey, &senderKey, &sessionBytes, &forwardingChains)
|
||||
var sessionBytes, ratchetSafetyBytes []byte
|
||||
var receivedAt sql.NullTime
|
||||
var maxAge, maxMessages sql.NullInt64
|
||||
var isScheduled bool
|
||||
err = rows.Scan(&roomID, &signingKey, &senderKey, &sessionBytes, &forwardingChains, &ratchetSafetyBytes, &receivedAt, &maxAge, &maxMessages, &isScheduled)
|
||||
if err != nil {
|
||||
return
|
||||
}
|
||||
@@ -316,12 +467,24 @@ func (store *SQLCryptoStore) scanGroupSessionList(rows dbutil.Rows) (result []*I
|
||||
if forwardingChains.String != "" {
|
||||
chains = strings.Split(forwardingChains.String, ",")
|
||||
}
|
||||
var rs RatchetSafety
|
||||
if len(ratchetSafetyBytes) > 0 {
|
||||
err = json.Unmarshal(ratchetSafetyBytes, &rs)
|
||||
if err != nil {
|
||||
return nil, fmt.Errorf("failed to unmarshal ratchet safety info: %w", err)
|
||||
}
|
||||
}
|
||||
result = append(result, &InboundGroupSession{
|
||||
Internal: *igs,
|
||||
SigningKey: id.Ed25519(signingKey.String),
|
||||
SenderKey: id.Curve25519(senderKey.String),
|
||||
RoomID: roomID,
|
||||
ForwardingChains: chains,
|
||||
RatchetSafety: rs,
|
||||
ReceivedAt: receivedAt.Time,
|
||||
MaxAge: maxAge.Int64,
|
||||
MaxMessages: int(maxMessages.Int64),
|
||||
IsScheduled: isScheduled,
|
||||
})
|
||||
}
|
||||
return
|
||||
@@ -329,7 +492,7 @@ func (store *SQLCryptoStore) scanGroupSessionList(rows dbutil.Rows) (result []*I
|
||||
|
||||
func (store *SQLCryptoStore) GetGroupSessionsForRoom(roomID id.RoomID) ([]*InboundGroupSession, error) {
|
||||
rows, err := store.DB.Query(`
|
||||
SELECT room_id, signing_key, sender_key, session, forwarding_chains
|
||||
SELECT room_id, signing_key, sender_key, session, forwarding_chains, ratchet_safety, received_at, max_age, max_messages, is_scheduled
|
||||
FROM crypto_megolm_inbound_session WHERE room_id=$1 AND account_id=$2 AND session IS NOT NULL`,
|
||||
roomID, store.AccountID,
|
||||
)
|
||||
@@ -343,7 +506,7 @@ func (store *SQLCryptoStore) GetGroupSessionsForRoom(roomID id.RoomID) ([]*Inbou
|
||||
|
||||
func (store *SQLCryptoStore) GetAllGroupSessions() ([]*InboundGroupSession, error) {
|
||||
rows, err := store.DB.Query(`
|
||||
SELECT room_id, signing_key, sender_key, session, forwarding_chains
|
||||
SELECT room_id, signing_key, sender_key, session, forwarding_chains, ratchet_safety, received_at, max_age, max_messages, is_scheduled
|
||||
FROM crypto_megolm_inbound_session WHERE account_id=$2 AND session IS NOT NULL`,
|
||||
store.AccountID,
|
||||
)
|
||||
@@ -367,7 +530,7 @@ func (store *SQLCryptoStore) AddOutboundGroupSession(session *OutboundGroupSessi
|
||||
max_messages=excluded.max_messages, message_count=excluded.message_count, max_age=excluded.max_age,
|
||||
created_at=excluded.created_at, last_used=excluded.last_used, account_id=excluded.account_id
|
||||
`, session.RoomID, session.ID(), sessionBytes, session.Shared, session.MaxMessages, session.MessageCount,
|
||||
session.MaxAge, session.CreationTime, session.LastEncryptedTime, store.AccountID)
|
||||
session.MaxAge.Milliseconds(), session.CreationTime, session.LastEncryptedTime, store.AccountID)
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -383,11 +546,12 @@ func (store *SQLCryptoStore) UpdateOutboundGroupSession(session *OutboundGroupSe
|
||||
func (store *SQLCryptoStore) GetOutboundGroupSession(roomID id.RoomID) (*OutboundGroupSession, error) {
|
||||
var ogs OutboundGroupSession
|
||||
var sessionBytes []byte
|
||||
var maxAgeMS int64
|
||||
err := store.DB.QueryRow(`
|
||||
SELECT session, shared, max_messages, message_count, max_age, created_at, last_used
|
||||
FROM crypto_megolm_outbound_session WHERE room_id=$1 AND account_id=$2`,
|
||||
roomID, store.AccountID,
|
||||
).Scan(&sessionBytes, &ogs.Shared, &ogs.MaxMessages, &ogs.MessageCount, &ogs.MaxAge, &ogs.CreationTime, &ogs.LastEncryptedTime)
|
||||
).Scan(&sessionBytes, &ogs.Shared, &ogs.MaxMessages, &ogs.MessageCount, &maxAgeMS, &ogs.CreationTime, &ogs.LastEncryptedTime)
|
||||
if err == sql.ErrNoRows {
|
||||
return nil, nil
|
||||
} else if err != nil {
|
||||
@@ -400,6 +564,7 @@ func (store *SQLCryptoStore) GetOutboundGroupSession(roomID id.RoomID) (*Outboun
|
||||
}
|
||||
ogs.Internal = *intOGS
|
||||
ogs.RoomID = roomID
|
||||
ogs.MaxAge = time.Duration(maxAgeMS) * time.Millisecond
|
||||
return &ogs, nil
|
||||
}
|
||||
|
||||
@@ -412,7 +577,7 @@ func (store *SQLCryptoStore) RemoveOutboundGroupSession(roomID id.RoomID) error
|
||||
|
||||
// ValidateMessageIndex returns whether the given event information match the ones stored in the database
|
||||
// for the given sender key, session ID and index. If the index hasn't been stored, this will store it.
|
||||
func (store *SQLCryptoStore) ValidateMessageIndex(senderKey id.SenderKey, sessionID id.SessionID, eventID id.EventID, index uint, timestamp int64) (bool, error) {
|
||||
func (store *SQLCryptoStore) ValidateMessageIndex(ctx context.Context, senderKey id.SenderKey, sessionID id.SessionID, eventID id.EventID, index uint, timestamp int64) (bool, error) {
|
||||
const validateQuery = `
|
||||
INSERT INTO crypto_message_index (sender_key, session_id, "index", event_id, timestamp)
|
||||
VALUES ($1, $2, $3, $4, $5)
|
||||
@@ -422,11 +587,19 @@ func (store *SQLCryptoStore) ValidateMessageIndex(senderKey id.SenderKey, sessio
|
||||
`
|
||||
var expectedEventID id.EventID
|
||||
var expectedTimestamp int64
|
||||
err := store.DB.QueryRow(validateQuery, senderKey, sessionID, index, eventID, timestamp).Scan(&expectedEventID, &expectedTimestamp)
|
||||
err := store.DB.QueryRowContext(ctx, validateQuery, senderKey, sessionID, index, eventID, timestamp).Scan(&expectedEventID, &expectedTimestamp)
|
||||
if err != nil {
|
||||
return false, err
|
||||
} else if expectedEventID != eventID || expectedTimestamp != timestamp {
|
||||
zerolog.Ctx(ctx).Debug().
|
||||
Uint("message_index", index).
|
||||
Str("expected_event_id", expectedEventID.String()).
|
||||
Int64("expected_timestamp", expectedTimestamp).
|
||||
Int64("actual_timestamp", timestamp).
|
||||
Msg("Failed to validate that message index wasn't duplicated")
|
||||
return false, nil
|
||||
}
|
||||
return expectedEventID == eventID && expectedTimestamp == timestamp, nil
|
||||
return true, nil
|
||||
}
|
||||
|
||||
// GetDevices returns a map of device IDs to device identities, including the identity and signing keys, for a given user ID.
|
||||
|
||||
7
vendor/maunium.net/go/mautrix/crypto/sql_store_upgrade/00-latest-revision.sql
generated
vendored
7
vendor/maunium.net/go/mautrix/crypto/sql_store_upgrade/00-latest-revision.sql
generated
vendored
@@ -1,4 +1,4 @@
|
||||
-- v0 -> v8: Latest revision
|
||||
-- v0 -> v10: Latest revision
|
||||
CREATE TABLE IF NOT EXISTS crypto_account (
|
||||
account_id TEXT PRIMARY KEY,
|
||||
device_id TEXT NOT NULL,
|
||||
@@ -52,6 +52,11 @@ CREATE TABLE IF NOT EXISTS crypto_megolm_inbound_session (
|
||||
forwarding_chains bytea,
|
||||
withheld_code TEXT,
|
||||
withheld_reason TEXT,
|
||||
ratchet_safety jsonb,
|
||||
received_at timestamp,
|
||||
max_age BIGINT,
|
||||
max_messages INTEGER,
|
||||
is_scheduled BOOLEAN NOT NULL DEFAULT false,
|
||||
PRIMARY KEY (account_id, session_id)
|
||||
);
|
||||
|
||||
|
||||
2
vendor/maunium.net/go/mautrix/crypto/sql_store_upgrade/09-max-age-ms.sql
generated
vendored
Normal file
2
vendor/maunium.net/go/mautrix/crypto/sql_store_upgrade/09-max-age-ms.sql
generated
vendored
Normal file
@@ -0,0 +1,2 @@
|
||||
-- v9: Change outbound megolm session max_age column to milliseconds
|
||||
UPDATE crypto_megolm_outbound_session SET max_age=max_age/1000000;
|
||||
6
vendor/maunium.net/go/mautrix/crypto/sql_store_upgrade/10-mark-ratchetable-keys.sql
generated
vendored
Normal file
6
vendor/maunium.net/go/mautrix/crypto/sql_store_upgrade/10-mark-ratchetable-keys.sql
generated
vendored
Normal file
@@ -0,0 +1,6 @@
|
||||
-- v10: Add metadata for detecting when megolm sessions are safe to delete
|
||||
ALTER TABLE crypto_megolm_inbound_session ADD COLUMN ratchet_safety jsonb;
|
||||
ALTER TABLE crypto_megolm_inbound_session ADD COLUMN received_at timestamp;
|
||||
ALTER TABLE crypto_megolm_inbound_session ADD COLUMN max_age BIGINT;
|
||||
ALTER TABLE crypto_megolm_inbound_session ADD COLUMN max_messages INTEGER;
|
||||
ALTER TABLE crypto_megolm_inbound_session ADD COLUMN is_scheduled BOOLEAN NOT NULL DEFAULT false;
|
||||
2
vendor/maunium.net/go/mautrix/crypto/sql_store_upgrade/upgrade.go
generated
vendored
2
vendor/maunium.net/go/mautrix/crypto/sql_store_upgrade/upgrade.go
generated
vendored
@@ -21,7 +21,7 @@ const VersionTableName = "crypto_version"
|
||||
var fs embed.FS
|
||||
|
||||
func init() {
|
||||
Table.Register(-1, 3, "Unsupported version", false, func(tx dbutil.Execable, database *dbutil.Database) error {
|
||||
Table.Register(-1, 3, 0, "Unsupported version", false, func(tx dbutil.Execable, database *dbutil.Database) error {
|
||||
return fmt.Errorf("upgrading from versions 1 and 2 of the crypto store is no longer supported in mautrix-go v0.12+")
|
||||
})
|
||||
Table.RegisterFS(fs)
|
||||
|
||||
184
vendor/maunium.net/go/mautrix/crypto/store.go
generated
vendored
184
vendor/maunium.net/go/mautrix/crypto/store.go
generated
vendored
@@ -7,10 +7,8 @@
|
||||
package crypto
|
||||
|
||||
import (
|
||||
"encoding/gob"
|
||||
"errors"
|
||||
"context"
|
||||
"fmt"
|
||||
"os"
|
||||
"sort"
|
||||
"sync"
|
||||
|
||||
@@ -18,10 +16,7 @@ import (
|
||||
"maunium.net/go/mautrix/id"
|
||||
)
|
||||
|
||||
// Deprecated: moved to id.Device
|
||||
type DeviceIdentity = id.Device
|
||||
|
||||
var ErrGroupSessionWithheld = errors.New("group session has been withheld")
|
||||
var ErrGroupSessionWithheld error = &event.RoomKeyWithheldEventContent{}
|
||||
|
||||
// Store is used by OlmMachine to store Olm and Megolm sessions, user device lists and message indices.
|
||||
//
|
||||
@@ -58,6 +53,12 @@ type Store interface {
|
||||
// (i.e. a room key withheld event has been saved with PutWithheldGroupSession), this should return the
|
||||
// ErrGroupSessionWithheld error. The caller may use GetWithheldGroupSession to find more details.
|
||||
GetGroupSession(id.RoomID, id.SenderKey, id.SessionID) (*InboundGroupSession, error)
|
||||
// RedactGroupSession removes the session data for the given inbound Megolm session from the store.
|
||||
RedactGroupSession(id.RoomID, id.SenderKey, id.SessionID, string) error
|
||||
// RedactGroupSessions removes the session data for all inbound Megolm sessions from a specific device and/or in a specific room.
|
||||
RedactGroupSessions(id.RoomID, id.SenderKey, string) ([]id.SessionID, error)
|
||||
// RedactExpiredGroupSessions removes the session data for all inbound Megolm sessions that have expired.
|
||||
RedactExpiredGroupSessions() ([]id.SessionID, error)
|
||||
// PutWithheldGroupSession tells the store that a specific Megolm session was withheld.
|
||||
PutWithheldGroupSession(event.RoomKeyWithheldEventContent) error
|
||||
// GetWithheldGroupSession gets the event content that was previously inserted with PutWithheldGroupSession.
|
||||
@@ -90,9 +91,9 @@ type Store interface {
|
||||
// * If the map key doesn't exist, the given values should be stored and this should return true.
|
||||
// * If the map key exists and the stored values match the given values, this should return true.
|
||||
// * If the map key exists, but the stored values do not match the given values, this should return false.
|
||||
ValidateMessageIndex(senderKey id.SenderKey, sessionID id.SessionID, eventID id.EventID, index uint, timestamp int64) (bool, error)
|
||||
ValidateMessageIndex(ctx context.Context, senderKey id.SenderKey, sessionID id.SessionID, eventID id.EventID, index uint, timestamp int64) (bool, error)
|
||||
|
||||
// GetDevices returns a map from device ID to DeviceIdentity containing all devices of a given user.
|
||||
// GetDevices returns a map from device ID to id.Device struct containing all devices of a given user.
|
||||
GetDevices(id.UserID) (map[id.DeviceID]*id.Device, error)
|
||||
// GetDevice returns a specific device of a given user.
|
||||
GetDevice(id.UserID, id.DeviceID) (*id.Device, error)
|
||||
@@ -129,12 +130,12 @@ type messageIndexValue struct {
|
||||
Timestamp int64
|
||||
}
|
||||
|
||||
// GobStore is a simple Store implementation that dumps everything into a .gob file.
|
||||
//
|
||||
// Deprecated: this is not atomic and can lose data. Using SQLCryptoStore or a custom implementation is recommended.
|
||||
type GobStore struct {
|
||||
// MemoryStore is a simple in-memory Store implementation. It can optionally have a callback function for saving data,
|
||||
// but the actual storage must be implemented manually.
|
||||
type MemoryStore struct {
|
||||
lock sync.RWMutex
|
||||
path string
|
||||
|
||||
save func() error
|
||||
|
||||
Account *OlmAccount
|
||||
Sessions map[id.SenderKey]OlmSessionList
|
||||
@@ -147,14 +148,15 @@ type GobStore struct {
|
||||
KeySignatures map[id.UserID]map[id.Ed25519]map[id.UserID]map[id.Ed25519]string
|
||||
}
|
||||
|
||||
var _ Store = (*GobStore)(nil)
|
||||
var _ Store = (*MemoryStore)(nil)
|
||||
|
||||
func NewMemoryStore(saveCallback func() error) *MemoryStore {
|
||||
if saveCallback == nil {
|
||||
saveCallback = func() error { return nil }
|
||||
}
|
||||
return &MemoryStore{
|
||||
save: saveCallback,
|
||||
|
||||
// NewGobStore creates a new GobStore that saves everything to the given file.
|
||||
//
|
||||
// Deprecated: this is not atomic and can lose data. Using SQLCryptoStore or a custom implementation is recommended.
|
||||
func NewGobStore(path string) (*GobStore, error) {
|
||||
gs := &GobStore{
|
||||
path: path,
|
||||
Sessions: make(map[id.SenderKey]OlmSessionList),
|
||||
GroupSessions: make(map[id.RoomID]map[id.SenderKey]map[id.SessionID]*InboundGroupSession),
|
||||
WithheldGroupSessions: make(map[id.RoomID]map[id.SenderKey]map[id.SessionID]*event.RoomKeyWithheldEventContent),
|
||||
@@ -164,44 +166,20 @@ func NewGobStore(path string) (*GobStore, error) {
|
||||
CrossSigningKeys: make(map[id.UserID]map[id.CrossSigningUsage]id.CrossSigningKey),
|
||||
KeySignatures: make(map[id.UserID]map[id.Ed25519]map[id.UserID]map[id.Ed25519]string),
|
||||
}
|
||||
return gs, gs.load()
|
||||
}
|
||||
|
||||
func (gs *GobStore) save() error {
|
||||
file, err := os.OpenFile(gs.path, os.O_CREATE|os.O_WRONLY, 0600)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
err = gob.NewEncoder(file).Encode(gs)
|
||||
_ = file.Close()
|
||||
return err
|
||||
}
|
||||
|
||||
func (gs *GobStore) load() error {
|
||||
file, err := os.OpenFile(gs.path, os.O_RDONLY, 0600)
|
||||
if err != nil {
|
||||
if os.IsNotExist(err) {
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
err = gob.NewDecoder(file).Decode(gs)
|
||||
_ = file.Close()
|
||||
return err
|
||||
}
|
||||
|
||||
func (gs *GobStore) Flush() error {
|
||||
func (gs *MemoryStore) Flush() error {
|
||||
gs.lock.Lock()
|
||||
err := gs.save()
|
||||
gs.lock.Unlock()
|
||||
return err
|
||||
}
|
||||
|
||||
func (gs *GobStore) GetAccount() (*OlmAccount, error) {
|
||||
func (gs *MemoryStore) GetAccount() (*OlmAccount, error) {
|
||||
return gs.Account, nil
|
||||
}
|
||||
|
||||
func (gs *GobStore) PutAccount(account *OlmAccount) error {
|
||||
func (gs *MemoryStore) PutAccount(account *OlmAccount) error {
|
||||
gs.lock.Lock()
|
||||
gs.Account = account
|
||||
err := gs.save()
|
||||
@@ -209,7 +187,7 @@ func (gs *GobStore) PutAccount(account *OlmAccount) error {
|
||||
return err
|
||||
}
|
||||
|
||||
func (gs *GobStore) GetSessions(senderKey id.SenderKey) (OlmSessionList, error) {
|
||||
func (gs *MemoryStore) GetSessions(senderKey id.SenderKey) (OlmSessionList, error) {
|
||||
gs.lock.Lock()
|
||||
sessions, ok := gs.Sessions[senderKey]
|
||||
if !ok {
|
||||
@@ -220,7 +198,7 @@ func (gs *GobStore) GetSessions(senderKey id.SenderKey) (OlmSessionList, error)
|
||||
return sessions, nil
|
||||
}
|
||||
|
||||
func (gs *GobStore) AddSession(senderKey id.SenderKey, session *OlmSession) error {
|
||||
func (gs *MemoryStore) AddSession(senderKey id.SenderKey, session *OlmSession) error {
|
||||
gs.lock.Lock()
|
||||
sessions, _ := gs.Sessions[senderKey]
|
||||
gs.Sessions[senderKey] = append(sessions, session)
|
||||
@@ -230,19 +208,19 @@ func (gs *GobStore) AddSession(senderKey id.SenderKey, session *OlmSession) erro
|
||||
return err
|
||||
}
|
||||
|
||||
func (gs *GobStore) UpdateSession(_ id.SenderKey, _ *OlmSession) error {
|
||||
func (gs *MemoryStore) UpdateSession(_ id.SenderKey, _ *OlmSession) error {
|
||||
// we don't need to do anything here because the session is a pointer and already stored in our map
|
||||
return gs.save()
|
||||
}
|
||||
|
||||
func (gs *GobStore) HasSession(senderKey id.SenderKey) bool {
|
||||
func (gs *MemoryStore) HasSession(senderKey id.SenderKey) bool {
|
||||
gs.lock.RLock()
|
||||
sessions, ok := gs.Sessions[senderKey]
|
||||
gs.lock.RUnlock()
|
||||
return ok && len(sessions) > 0 && !sessions[0].Expired()
|
||||
}
|
||||
|
||||
func (gs *GobStore) GetLatestSession(senderKey id.SenderKey) (*OlmSession, error) {
|
||||
func (gs *MemoryStore) GetLatestSession(senderKey id.SenderKey) (*OlmSession, error) {
|
||||
gs.lock.RLock()
|
||||
sessions, ok := gs.Sessions[senderKey]
|
||||
gs.lock.RUnlock()
|
||||
@@ -252,7 +230,7 @@ func (gs *GobStore) GetLatestSession(senderKey id.SenderKey) (*OlmSession, error
|
||||
return sessions[0], nil
|
||||
}
|
||||
|
||||
func (gs *GobStore) getGroupSessions(roomID id.RoomID, senderKey id.SenderKey) map[id.SessionID]*InboundGroupSession {
|
||||
func (gs *MemoryStore) getGroupSessions(roomID id.RoomID, senderKey id.SenderKey) map[id.SessionID]*InboundGroupSession {
|
||||
room, ok := gs.GroupSessions[roomID]
|
||||
if !ok {
|
||||
room = make(map[id.SenderKey]map[id.SessionID]*InboundGroupSession)
|
||||
@@ -266,7 +244,7 @@ func (gs *GobStore) getGroupSessions(roomID id.RoomID, senderKey id.SenderKey) m
|
||||
return sender
|
||||
}
|
||||
|
||||
func (gs *GobStore) PutGroupSession(roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID, igs *InboundGroupSession) error {
|
||||
func (gs *MemoryStore) PutGroupSession(roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID, igs *InboundGroupSession) error {
|
||||
gs.lock.Lock()
|
||||
gs.getGroupSessions(roomID, senderKey)[sessionID] = igs
|
||||
err := gs.save()
|
||||
@@ -274,7 +252,7 @@ func (gs *GobStore) PutGroupSession(roomID id.RoomID, senderKey id.SenderKey, se
|
||||
return err
|
||||
}
|
||||
|
||||
func (gs *GobStore) GetGroupSession(roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID) (*InboundGroupSession, error) {
|
||||
func (gs *MemoryStore) GetGroupSession(roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID) (*InboundGroupSession, error) {
|
||||
gs.lock.Lock()
|
||||
session, ok := gs.getGroupSessions(roomID, senderKey)[sessionID]
|
||||
if !ok {
|
||||
@@ -289,7 +267,57 @@ func (gs *GobStore) GetGroupSession(roomID id.RoomID, senderKey id.SenderKey, se
|
||||
return session, nil
|
||||
}
|
||||
|
||||
func (gs *GobStore) getWithheldGroupSessions(roomID id.RoomID, senderKey id.SenderKey) map[id.SessionID]*event.RoomKeyWithheldEventContent {
|
||||
func (gs *MemoryStore) RedactGroupSession(roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID, reason string) error {
|
||||
gs.lock.Lock()
|
||||
delete(gs.getGroupSessions(roomID, senderKey), sessionID)
|
||||
err := gs.save()
|
||||
gs.lock.Unlock()
|
||||
return err
|
||||
}
|
||||
|
||||
func (gs *MemoryStore) RedactGroupSessions(roomID id.RoomID, senderKey id.SenderKey, reason string) ([]id.SessionID, error) {
|
||||
gs.lock.Lock()
|
||||
var sessionIDs []id.SessionID
|
||||
if roomID != "" && senderKey != "" {
|
||||
sessions := gs.getGroupSessions(roomID, senderKey)
|
||||
for sessionID := range sessions {
|
||||
sessionIDs = append(sessionIDs, sessionID)
|
||||
delete(sessions, sessionID)
|
||||
}
|
||||
} else if senderKey != "" {
|
||||
for _, room := range gs.GroupSessions {
|
||||
sessions, ok := room[senderKey]
|
||||
if ok {
|
||||
for sessionID := range sessions {
|
||||
sessionIDs = append(sessionIDs, sessionID)
|
||||
}
|
||||
delete(room, senderKey)
|
||||
}
|
||||
}
|
||||
} else if roomID != "" {
|
||||
room, ok := gs.GroupSessions[roomID]
|
||||
if ok {
|
||||
for senderKey := range room {
|
||||
sessions := room[senderKey]
|
||||
for sessionID := range sessions {
|
||||
sessionIDs = append(sessionIDs, sessionID)
|
||||
}
|
||||
}
|
||||
delete(gs.GroupSessions, roomID)
|
||||
}
|
||||
} else {
|
||||
return nil, fmt.Errorf("room ID or sender key must be provided for redacting sessions")
|
||||
}
|
||||
err := gs.save()
|
||||
gs.lock.Unlock()
|
||||
return sessionIDs, err
|
||||
}
|
||||
|
||||
func (gs *MemoryStore) RedactExpiredGroupSessions() ([]id.SessionID, error) {
|
||||
return nil, fmt.Errorf("not implemented")
|
||||
}
|
||||
|
||||
func (gs *MemoryStore) getWithheldGroupSessions(roomID id.RoomID, senderKey id.SenderKey) map[id.SessionID]*event.RoomKeyWithheldEventContent {
|
||||
room, ok := gs.WithheldGroupSessions[roomID]
|
||||
if !ok {
|
||||
room = make(map[id.SenderKey]map[id.SessionID]*event.RoomKeyWithheldEventContent)
|
||||
@@ -303,7 +331,7 @@ func (gs *GobStore) getWithheldGroupSessions(roomID id.RoomID, senderKey id.Send
|
||||
return sender
|
||||
}
|
||||
|
||||
func (gs *GobStore) PutWithheldGroupSession(content event.RoomKeyWithheldEventContent) error {
|
||||
func (gs *MemoryStore) PutWithheldGroupSession(content event.RoomKeyWithheldEventContent) error {
|
||||
gs.lock.Lock()
|
||||
gs.getWithheldGroupSessions(content.RoomID, content.SenderKey)[content.SessionID] = &content
|
||||
err := gs.save()
|
||||
@@ -311,7 +339,7 @@ func (gs *GobStore) PutWithheldGroupSession(content event.RoomKeyWithheldEventCo
|
||||
return err
|
||||
}
|
||||
|
||||
func (gs *GobStore) GetWithheldGroupSession(roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID) (*event.RoomKeyWithheldEventContent, error) {
|
||||
func (gs *MemoryStore) GetWithheldGroupSession(roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID) (*event.RoomKeyWithheldEventContent, error) {
|
||||
gs.lock.Lock()
|
||||
session, ok := gs.getWithheldGroupSessions(roomID, senderKey)[sessionID]
|
||||
gs.lock.Unlock()
|
||||
@@ -321,7 +349,7 @@ func (gs *GobStore) GetWithheldGroupSession(roomID id.RoomID, senderKey id.Sende
|
||||
return session, nil
|
||||
}
|
||||
|
||||
func (gs *GobStore) GetGroupSessionsForRoom(roomID id.RoomID) ([]*InboundGroupSession, error) {
|
||||
func (gs *MemoryStore) GetGroupSessionsForRoom(roomID id.RoomID) ([]*InboundGroupSession, error) {
|
||||
gs.lock.Lock()
|
||||
defer gs.lock.Unlock()
|
||||
room, ok := gs.GroupSessions[roomID]
|
||||
@@ -337,7 +365,7 @@ func (gs *GobStore) GetGroupSessionsForRoom(roomID id.RoomID) ([]*InboundGroupSe
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (gs *GobStore) GetAllGroupSessions() ([]*InboundGroupSession, error) {
|
||||
func (gs *MemoryStore) GetAllGroupSessions() ([]*InboundGroupSession, error) {
|
||||
gs.lock.Lock()
|
||||
var result []*InboundGroupSession
|
||||
for _, room := range gs.GroupSessions {
|
||||
@@ -351,7 +379,7 @@ func (gs *GobStore) GetAllGroupSessions() ([]*InboundGroupSession, error) {
|
||||
return result, nil
|
||||
}
|
||||
|
||||
func (gs *GobStore) AddOutboundGroupSession(session *OutboundGroupSession) error {
|
||||
func (gs *MemoryStore) AddOutboundGroupSession(session *OutboundGroupSession) error {
|
||||
gs.lock.Lock()
|
||||
gs.OutGroupSessions[session.RoomID] = session
|
||||
err := gs.save()
|
||||
@@ -359,12 +387,12 @@ func (gs *GobStore) AddOutboundGroupSession(session *OutboundGroupSession) error
|
||||
return err
|
||||
}
|
||||
|
||||
func (gs *GobStore) UpdateOutboundGroupSession(_ *OutboundGroupSession) error {
|
||||
func (gs *MemoryStore) UpdateOutboundGroupSession(_ *OutboundGroupSession) error {
|
||||
// we don't need to do anything here because the session is a pointer and already stored in our map
|
||||
return gs.save()
|
||||
}
|
||||
|
||||
func (gs *GobStore) GetOutboundGroupSession(roomID id.RoomID) (*OutboundGroupSession, error) {
|
||||
func (gs *MemoryStore) GetOutboundGroupSession(roomID id.RoomID) (*OutboundGroupSession, error) {
|
||||
gs.lock.RLock()
|
||||
session, ok := gs.OutGroupSessions[roomID]
|
||||
gs.lock.RUnlock()
|
||||
@@ -374,7 +402,7 @@ func (gs *GobStore) GetOutboundGroupSession(roomID id.RoomID) (*OutboundGroupSes
|
||||
return session, nil
|
||||
}
|
||||
|
||||
func (gs *GobStore) RemoveOutboundGroupSession(roomID id.RoomID) error {
|
||||
func (gs *MemoryStore) RemoveOutboundGroupSession(roomID id.RoomID) error {
|
||||
gs.lock.Lock()
|
||||
session, ok := gs.OutGroupSessions[roomID]
|
||||
if !ok || session == nil {
|
||||
@@ -386,7 +414,7 @@ func (gs *GobStore) RemoveOutboundGroupSession(roomID id.RoomID) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (gs *GobStore) ValidateMessageIndex(senderKey id.SenderKey, sessionID id.SessionID, eventID id.EventID, index uint, timestamp int64) (bool, error) {
|
||||
func (gs *MemoryStore) ValidateMessageIndex(_ context.Context, senderKey id.SenderKey, sessionID id.SessionID, eventID id.EventID, index uint, timestamp int64) (bool, error) {
|
||||
gs.lock.Lock()
|
||||
defer gs.lock.Unlock()
|
||||
key := messageIndexKey{
|
||||
@@ -409,7 +437,7 @@ func (gs *GobStore) ValidateMessageIndex(senderKey id.SenderKey, sessionID id.Se
|
||||
return true, nil
|
||||
}
|
||||
|
||||
func (gs *GobStore) GetDevices(userID id.UserID) (map[id.DeviceID]*id.Device, error) {
|
||||
func (gs *MemoryStore) GetDevices(userID id.UserID) (map[id.DeviceID]*id.Device, error) {
|
||||
gs.lock.RLock()
|
||||
devices, ok := gs.Devices[userID]
|
||||
if !ok {
|
||||
@@ -419,7 +447,7 @@ func (gs *GobStore) GetDevices(userID id.UserID) (map[id.DeviceID]*id.Device, er
|
||||
return devices, nil
|
||||
}
|
||||
|
||||
func (gs *GobStore) GetDevice(userID id.UserID, deviceID id.DeviceID) (*id.Device, error) {
|
||||
func (gs *MemoryStore) GetDevice(userID id.UserID, deviceID id.DeviceID) (*id.Device, error) {
|
||||
gs.lock.RLock()
|
||||
defer gs.lock.RUnlock()
|
||||
devices, ok := gs.Devices[userID]
|
||||
@@ -433,7 +461,7 @@ func (gs *GobStore) GetDevice(userID id.UserID, deviceID id.DeviceID) (*id.Devic
|
||||
return device, nil
|
||||
}
|
||||
|
||||
func (gs *GobStore) FindDeviceByKey(userID id.UserID, identityKey id.IdentityKey) (*id.Device, error) {
|
||||
func (gs *MemoryStore) FindDeviceByKey(userID id.UserID, identityKey id.IdentityKey) (*id.Device, error) {
|
||||
gs.lock.RLock()
|
||||
defer gs.lock.RUnlock()
|
||||
devices, ok := gs.Devices[userID]
|
||||
@@ -448,7 +476,7 @@ func (gs *GobStore) FindDeviceByKey(userID id.UserID, identityKey id.IdentityKey
|
||||
return nil, nil
|
||||
}
|
||||
|
||||
func (gs *GobStore) PutDevice(userID id.UserID, device *id.Device) error {
|
||||
func (gs *MemoryStore) PutDevice(userID id.UserID, device *id.Device) error {
|
||||
gs.lock.Lock()
|
||||
devices, ok := gs.Devices[userID]
|
||||
if !ok {
|
||||
@@ -461,7 +489,7 @@ func (gs *GobStore) PutDevice(userID id.UserID, device *id.Device) error {
|
||||
return err
|
||||
}
|
||||
|
||||
func (gs *GobStore) PutDevices(userID id.UserID, devices map[id.DeviceID]*id.Device) error {
|
||||
func (gs *MemoryStore) PutDevices(userID id.UserID, devices map[id.DeviceID]*id.Device) error {
|
||||
gs.lock.Lock()
|
||||
gs.Devices[userID] = devices
|
||||
err := gs.save()
|
||||
@@ -469,7 +497,7 @@ func (gs *GobStore) PutDevices(userID id.UserID, devices map[id.DeviceID]*id.Dev
|
||||
return err
|
||||
}
|
||||
|
||||
func (gs *GobStore) FilterTrackedUsers(users []id.UserID) ([]id.UserID, error) {
|
||||
func (gs *MemoryStore) FilterTrackedUsers(users []id.UserID) ([]id.UserID, error) {
|
||||
gs.lock.RLock()
|
||||
var ptr int
|
||||
for _, userID := range users {
|
||||
@@ -483,7 +511,7 @@ func (gs *GobStore) FilterTrackedUsers(users []id.UserID) ([]id.UserID, error) {
|
||||
return users[:ptr], nil
|
||||
}
|
||||
|
||||
func (gs *GobStore) PutCrossSigningKey(userID id.UserID, usage id.CrossSigningUsage, key id.Ed25519) error {
|
||||
func (gs *MemoryStore) PutCrossSigningKey(userID id.UserID, usage id.CrossSigningUsage, key id.Ed25519) error {
|
||||
gs.lock.RLock()
|
||||
userKeys, ok := gs.CrossSigningKeys[userID]
|
||||
if !ok {
|
||||
@@ -505,7 +533,7 @@ func (gs *GobStore) PutCrossSigningKey(userID id.UserID, usage id.CrossSigningUs
|
||||
return err
|
||||
}
|
||||
|
||||
func (gs *GobStore) GetCrossSigningKeys(userID id.UserID) (map[id.CrossSigningUsage]id.CrossSigningKey, error) {
|
||||
func (gs *MemoryStore) GetCrossSigningKeys(userID id.UserID) (map[id.CrossSigningUsage]id.CrossSigningKey, error) {
|
||||
gs.lock.RLock()
|
||||
defer gs.lock.RUnlock()
|
||||
keys, ok := gs.CrossSigningKeys[userID]
|
||||
@@ -515,7 +543,7 @@ func (gs *GobStore) GetCrossSigningKeys(userID id.UserID) (map[id.CrossSigningUs
|
||||
return keys, nil
|
||||
}
|
||||
|
||||
func (gs *GobStore) PutSignature(signedUserID id.UserID, signedKey id.Ed25519, signerUserID id.UserID, signerKey id.Ed25519, signature string) error {
|
||||
func (gs *MemoryStore) PutSignature(signedUserID id.UserID, signedKey id.Ed25519, signerUserID id.UserID, signerKey id.Ed25519, signature string) error {
|
||||
gs.lock.RLock()
|
||||
signedUserSigs, ok := gs.KeySignatures[signedUserID]
|
||||
if !ok {
|
||||
@@ -538,7 +566,7 @@ func (gs *GobStore) PutSignature(signedUserID id.UserID, signedKey id.Ed25519, s
|
||||
return err
|
||||
}
|
||||
|
||||
func (gs *GobStore) GetSignaturesForKeyBy(userID id.UserID, key id.Ed25519, signerID id.UserID) (map[id.Ed25519]string, error) {
|
||||
func (gs *MemoryStore) GetSignaturesForKeyBy(userID id.UserID, key id.Ed25519, signerID id.UserID) (map[id.Ed25519]string, error) {
|
||||
gs.lock.RLock()
|
||||
defer gs.lock.RUnlock()
|
||||
userKeys, ok := gs.KeySignatures[userID]
|
||||
@@ -556,7 +584,7 @@ func (gs *GobStore) GetSignaturesForKeyBy(userID id.UserID, key id.Ed25519, sign
|
||||
return sigsBySigner, nil
|
||||
}
|
||||
|
||||
func (gs *GobStore) IsKeySignedBy(userID id.UserID, key id.Ed25519, signerID id.UserID, signerKey id.Ed25519) (bool, error) {
|
||||
func (gs *MemoryStore) IsKeySignedBy(userID id.UserID, key id.Ed25519, signerID id.UserID, signerKey id.Ed25519) (bool, error) {
|
||||
sigs, err := gs.GetSignaturesForKeyBy(userID, key, signerID)
|
||||
if err != nil {
|
||||
return false, err
|
||||
@@ -565,7 +593,7 @@ func (gs *GobStore) IsKeySignedBy(userID id.UserID, key id.Ed25519, signerID id.
|
||||
return ok, nil
|
||||
}
|
||||
|
||||
func (gs *GobStore) DropSignaturesByKey(userID id.UserID, key id.Ed25519) (int64, error) {
|
||||
func (gs *MemoryStore) DropSignaturesByKey(userID id.UserID, key id.Ed25519) (int64, error) {
|
||||
var count int64
|
||||
gs.lock.RLock()
|
||||
for _, userSigs := range gs.KeySignatures {
|
||||
|
||||
131
vendor/maunium.net/go/mautrix/crypto/verification.go
generated
vendored
131
vendor/maunium.net/go/mautrix/crypto/verification.go
generated
vendored
@@ -4,9 +4,6 @@
|
||||
// License, v. 2.0. If a copy of the MPL was not distributed with this
|
||||
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
|
||||
//go:build !nosas
|
||||
// +build !nosas
|
||||
|
||||
package crypto
|
||||
|
||||
import (
|
||||
@@ -92,13 +89,13 @@ func (mach *OlmMachine) getPKAndKeysMAC(sas *olm.SAS, sendingUser id.UserID, sen
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
mach.Log.Trace("sas.CalculateMAC(\"%s\", \"%s\") -> \"%s\"", signingKey, sasInfo+mainKeyID.String(), string(pubKeyMac))
|
||||
mach.Log.Trace().Msgf("sas.CalculateMAC(\"%s\", \"%s\") -> \"%s\"", signingKey, sasInfo+mainKeyID.String(), string(pubKeyMac))
|
||||
|
||||
keysMac, err := sas.CalculateMAC([]byte(keyIDString), []byte(sasInfo+"KEY_IDS"))
|
||||
if err != nil {
|
||||
return "", "", err
|
||||
}
|
||||
mach.Log.Trace("sas.CalculateMAC(\"%s\", \"%s\") -> \"%s\"", keyIDString, sasInfo+"KEY_IDS", string(keysMac))
|
||||
mach.Log.Trace().Msgf("sas.CalculateMAC(\"%s\", \"%s\") -> \"%s\"", keyIDString, sasInfo+"KEY_IDS", string(keysMac))
|
||||
|
||||
return string(pubKeyMac), string(keysMac), nil
|
||||
}
|
||||
@@ -144,14 +141,14 @@ func (mach *OlmMachine) getTransactionState(transactionID string, userID id.User
|
||||
// handleVerificationStart handles an incoming m.key.verification.start message.
|
||||
// It initializes the state for this SAS verification process and stores it.
|
||||
func (mach *OlmMachine) handleVerificationStart(userID id.UserID, content *event.VerificationStartEventContent, transactionID string, timeout time.Duration, inRoomID id.RoomID) {
|
||||
mach.Log.Debug("Received verification start from %v", content.FromDevice)
|
||||
otherDevice, err := mach.GetOrFetchDevice(userID, content.FromDevice)
|
||||
mach.Log.Debug().Msgf("Received verification start from %v", content.FromDevice)
|
||||
otherDevice, err := mach.GetOrFetchDevice(context.TODO(), userID, content.FromDevice)
|
||||
if err != nil {
|
||||
mach.Log.Error("Could not find device %v of user %v", content.FromDevice, userID)
|
||||
mach.Log.Error().Msgf("Could not find device %v of user %v", content.FromDevice, userID)
|
||||
return
|
||||
}
|
||||
warnAndCancel := func(logReason, cancelReason string) {
|
||||
mach.Log.Warn("Canceling verification transaction %v as it %s", transactionID, logReason)
|
||||
mach.Log.Warn().Msgf("Canceling verification transaction %v as it %s", transactionID, logReason)
|
||||
if inRoomID == "" {
|
||||
_ = mach.SendSASVerificationCancel(otherDevice.UserID, otherDevice.DeviceID, transactionID, cancelReason, event.VerificationCancelUnknownMethod)
|
||||
} else {
|
||||
@@ -179,7 +176,7 @@ func (mach *OlmMachine) actuallyStartVerification(userID id.UserID, content *eve
|
||||
if inRoomID != "" && transactionID != "" {
|
||||
verState, err := mach.getTransactionState(transactionID, userID)
|
||||
if err != nil {
|
||||
mach.Log.Error("Failed to get transaction state for in-room verification %s start: %v", transactionID, err)
|
||||
mach.Log.Error().Msgf("Failed to get transaction state for in-room verification %s start: %v", transactionID, err)
|
||||
_ = mach.SendInRoomSASVerificationCancel(inRoomID, otherDevice.UserID, transactionID, "Internal state error in gomuks :(", "net.maunium.internal_error")
|
||||
return
|
||||
}
|
||||
@@ -187,7 +184,7 @@ func (mach *OlmMachine) actuallyStartVerification(userID id.UserID, content *eve
|
||||
sasMethods := commonSASMethods(verState.hooks, content.ShortAuthenticationString)
|
||||
err = mach.SendInRoomSASVerificationAccept(inRoomID, userID, content, transactionID, verState.sas.GetPubkey(), sasMethods)
|
||||
if err != nil {
|
||||
mach.Log.Error("Error accepting in-room SAS verification: %v", err)
|
||||
mach.Log.Error().Msgf("Error accepting in-room SAS verification: %v", err)
|
||||
}
|
||||
verState.chosenSASMethod = sasMethods[0]
|
||||
verState.verificationStarted = true
|
||||
@@ -197,7 +194,7 @@ func (mach *OlmMachine) actuallyStartVerification(userID id.UserID, content *eve
|
||||
if resp == AcceptRequest {
|
||||
sasMethods := commonSASMethods(hooks, content.ShortAuthenticationString)
|
||||
if len(sasMethods) == 0 {
|
||||
mach.Log.Error("No common SAS methods: %v", content.ShortAuthenticationString)
|
||||
mach.Log.Error().Msgf("No common SAS methods: %v", content.ShortAuthenticationString)
|
||||
if inRoomID == "" {
|
||||
_ = mach.SendSASVerificationCancel(otherDevice.UserID, otherDevice.DeviceID, transactionID, "No common SAS methods", event.VerificationCancelUnknownMethod)
|
||||
} else {
|
||||
@@ -222,7 +219,7 @@ func (mach *OlmMachine) actuallyStartVerification(userID id.UserID, content *eve
|
||||
_, loaded := mach.keyVerificationTransactionState.LoadOrStore(userID.String()+":"+transactionID, verState)
|
||||
if loaded {
|
||||
// transaction already exists
|
||||
mach.Log.Error("Transaction %v already exists, canceling", transactionID)
|
||||
mach.Log.Error().Msgf("Transaction %v already exists, canceling", transactionID)
|
||||
if inRoomID == "" {
|
||||
_ = mach.SendSASVerificationCancel(otherDevice.UserID, otherDevice.DeviceID, transactionID, "Transaction already exists", event.VerificationCancelUnexpectedMessage)
|
||||
} else {
|
||||
@@ -240,10 +237,10 @@ func (mach *OlmMachine) actuallyStartVerification(userID id.UserID, content *eve
|
||||
err = mach.SendInRoomSASVerificationAccept(inRoomID, userID, content, transactionID, verState.sas.GetPubkey(), sasMethods)
|
||||
}
|
||||
if err != nil {
|
||||
mach.Log.Error("Error accepting SAS verification: %v", err)
|
||||
mach.Log.Error().Msgf("Error accepting SAS verification: %v", err)
|
||||
}
|
||||
} else if resp == RejectRequest {
|
||||
mach.Log.Debug("Not accepting SAS verification %v from %v of user %v", transactionID, otherDevice.DeviceID, otherDevice.UserID)
|
||||
mach.Log.Debug().Msgf("Not accepting SAS verification %v from %v of user %v", transactionID, otherDevice.DeviceID, otherDevice.UserID)
|
||||
var err error
|
||||
if inRoomID == "" {
|
||||
err = mach.SendSASVerificationCancel(otherDevice.UserID, otherDevice.DeviceID, transactionID, "Not accepted by user", event.VerificationCancelByUser)
|
||||
@@ -251,10 +248,10 @@ func (mach *OlmMachine) actuallyStartVerification(userID id.UserID, content *eve
|
||||
err = mach.SendInRoomSASVerificationCancel(inRoomID, otherDevice.UserID, transactionID, "Not accepted by user", event.VerificationCancelByUser)
|
||||
}
|
||||
if err != nil {
|
||||
mach.Log.Error("Error canceling SAS verification: %v", err)
|
||||
mach.Log.Error().Msgf("Error canceling SAS verification: %v", err)
|
||||
}
|
||||
} else {
|
||||
mach.Log.Debug("Ignoring SAS verification %v from %v of user %v", transactionID, otherDevice.DeviceID, otherDevice.UserID)
|
||||
mach.Log.Debug().Msgf("Ignoring SAS verification %v from %v of user %v", transactionID, otherDevice.DeviceID, otherDevice.UserID)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -276,12 +273,12 @@ func (mach *OlmMachine) timeoutAfter(verState *verificationState, transactionID
|
||||
// if deadline exceeded cancel due to timeout
|
||||
mach.keyVerificationTransactionState.Delete(mapKey)
|
||||
_ = mach.callbackAndCancelSASVerification(verState, transactionID, "Timed out", event.VerificationCancelByTimeout)
|
||||
mach.Log.Warn("Verification transaction %v is canceled due to timing out", transactionID)
|
||||
mach.Log.Warn().Msgf("Verification transaction %v is canceled due to timing out", transactionID)
|
||||
verState.lock.Unlock()
|
||||
return
|
||||
}
|
||||
// otherwise the cancel func was called, so the timeout is reset
|
||||
mach.Log.Debug("Extending timeout for transaction %v", transactionID)
|
||||
mach.Log.Debug().Msgf("Extending timeout for transaction %v", transactionID)
|
||||
timeoutCtx, timeoutCancel = context.WithTimeout(context.Background(), timeout)
|
||||
verState.extendTimeout = timeoutCancel
|
||||
verState.lock.Unlock()
|
||||
@@ -292,10 +289,10 @@ func (mach *OlmMachine) timeoutAfter(verState *verificationState, transactionID
|
||||
// handleVerificationAccept handles an incoming m.key.verification.accept message.
|
||||
// It continues the SAS verification process by sending the SAS key message to the other device.
|
||||
func (mach *OlmMachine) handleVerificationAccept(userID id.UserID, content *event.VerificationAcceptEventContent, transactionID string) {
|
||||
mach.Log.Debug("Received verification accept for transaction %v", transactionID)
|
||||
mach.Log.Debug().Msgf("Received verification accept for transaction %v", transactionID)
|
||||
verState, err := mach.getTransactionState(transactionID, userID)
|
||||
if err != nil {
|
||||
mach.Log.Error("Error getting transaction state: %v", err)
|
||||
mach.Log.Error().Msgf("Error getting transaction state: %v", err)
|
||||
return
|
||||
}
|
||||
verState.lock.Lock()
|
||||
@@ -304,7 +301,7 @@ func (mach *OlmMachine) handleVerificationAccept(userID id.UserID, content *even
|
||||
|
||||
if !verState.initiatedByUs || verState.verificationStarted {
|
||||
// unexpected accept at this point
|
||||
mach.Log.Warn("Unexpected verification accept message for transaction %v", transactionID)
|
||||
mach.Log.Warn().Msgf("Unexpected verification accept message for transaction %v", transactionID)
|
||||
mach.keyVerificationTransactionState.Delete(userID.String() + ":" + transactionID)
|
||||
_ = mach.callbackAndCancelSASVerification(verState, transactionID, "Unexpected accept message", event.VerificationCancelUnexpectedMessage)
|
||||
return
|
||||
@@ -316,7 +313,7 @@ func (mach *OlmMachine) handleVerificationAccept(userID id.UserID, content *even
|
||||
content.MessageAuthenticationCode != event.HKDFHMACSHA256 ||
|
||||
len(sasMethods) == 0 {
|
||||
|
||||
mach.Log.Warn("Canceling verification transaction %v due to unknown parameter", transactionID)
|
||||
mach.Log.Warn().Msgf("Canceling verification transaction %v due to unknown parameter", transactionID)
|
||||
mach.keyVerificationTransactionState.Delete(userID.String() + ":" + transactionID)
|
||||
_ = mach.callbackAndCancelSASVerification(verState, transactionID, "Verification uses unknown method", event.VerificationCancelUnknownMethod)
|
||||
return
|
||||
@@ -333,7 +330,7 @@ func (mach *OlmMachine) handleVerificationAccept(userID id.UserID, content *even
|
||||
err = mach.SendInRoomSASVerificationKey(verState.inRoomID, userID, transactionID, string(key))
|
||||
}
|
||||
if err != nil {
|
||||
mach.Log.Error("Error sending SAS key to other device: %v", err)
|
||||
mach.Log.Error().Msgf("Error sending SAS key to other device: %v", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
@@ -341,10 +338,10 @@ func (mach *OlmMachine) handleVerificationAccept(userID id.UserID, content *even
|
||||
// handleVerificationKey handles an incoming m.key.verification.key message.
|
||||
// It stores the other device's public key in order to acquire the SAS shared secret.
|
||||
func (mach *OlmMachine) handleVerificationKey(userID id.UserID, content *event.VerificationKeyEventContent, transactionID string) {
|
||||
mach.Log.Debug("Got verification key for transaction %v: %v", transactionID, content.Key)
|
||||
mach.Log.Debug().Msgf("Got verification key for transaction %v: %v", transactionID, content.Key)
|
||||
verState, err := mach.getTransactionState(transactionID, userID)
|
||||
if err != nil {
|
||||
mach.Log.Error("Error getting transaction state: %v", err)
|
||||
mach.Log.Error().Msgf("Error getting transaction state: %v", err)
|
||||
return
|
||||
}
|
||||
verState.lock.Lock()
|
||||
@@ -355,14 +352,14 @@ func (mach *OlmMachine) handleVerificationKey(userID id.UserID, content *event.V
|
||||
|
||||
if !verState.verificationStarted || verState.keyReceived {
|
||||
// unexpected key at this point
|
||||
mach.Log.Warn("Unexpected verification key message for transaction %v", transactionID)
|
||||
mach.Log.Warn().Msgf("Unexpected verification key message for transaction %v", transactionID)
|
||||
mach.keyVerificationTransactionState.Delete(userID.String() + ":" + transactionID)
|
||||
_ = mach.callbackAndCancelSASVerification(verState, transactionID, "Unexpected key message", event.VerificationCancelUnexpectedMessage)
|
||||
return
|
||||
}
|
||||
|
||||
if err := verState.sas.SetTheirKey([]byte(content.Key)); err != nil {
|
||||
mach.Log.Error("Error setting other device's key: %v", err)
|
||||
mach.Log.Error().Msgf("Error setting other device's key: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
@@ -371,9 +368,9 @@ func (mach *OlmMachine) handleVerificationKey(userID id.UserID, content *event.V
|
||||
if verState.initiatedByUs {
|
||||
// verify commitment string from accept message now
|
||||
expectedCommitment := olm.NewUtility().Sha256(content.Key + verState.startEventCanonical)
|
||||
mach.Log.Debug("Received commitment: %v Expected: %v", verState.commitment, expectedCommitment)
|
||||
mach.Log.Debug().Msgf("Received commitment: %v Expected: %v", verState.commitment, expectedCommitment)
|
||||
if expectedCommitment != verState.commitment {
|
||||
mach.Log.Warn("Canceling verification transaction %v due to commitment mismatch", transactionID)
|
||||
mach.Log.Warn().Msgf("Canceling verification transaction %v due to commitment mismatch", transactionID)
|
||||
mach.keyVerificationTransactionState.Delete(userID.String() + ":" + transactionID)
|
||||
_ = mach.callbackAndCancelSASVerification(verState, transactionID, "Commitment mismatch", event.VerificationCancelCommitmentMismatch)
|
||||
return
|
||||
@@ -388,7 +385,7 @@ func (mach *OlmMachine) handleVerificationKey(userID id.UserID, content *event.V
|
||||
err = mach.SendInRoomSASVerificationKey(verState.inRoomID, userID, transactionID, string(key))
|
||||
}
|
||||
if err != nil {
|
||||
mach.Log.Error("Error sending SAS key to other device: %v", err)
|
||||
mach.Log.Error().Msgf("Error sending SAS key to other device: %v", err)
|
||||
return
|
||||
}
|
||||
}
|
||||
@@ -416,10 +413,10 @@ func (mach *OlmMachine) handleVerificationKey(userID id.UserID, content *event.V
|
||||
sasMethod := verState.chosenSASMethod
|
||||
sas, err := sasMethod.GetVerificationSAS(initUserID, initDeviceID, initKey, acceptUserID, acceptDeviceID, acceptKey, transactionID, verState.sas)
|
||||
if err != nil {
|
||||
mach.Log.Error("Error generating SAS (method %v): %v", sasMethod.Type(), err)
|
||||
mach.Log.Error().Msgf("Error generating SAS (method %v): %v", sasMethod.Type(), err)
|
||||
return
|
||||
}
|
||||
mach.Log.Debug("Generated SAS (%v): %v", sasMethod.Type(), sas)
|
||||
mach.Log.Debug().Msgf("Generated SAS (%v): %v", sasMethod.Type(), sas)
|
||||
go func() {
|
||||
result := verState.hooks.VerifySASMatch(device, sas)
|
||||
mach.sasCompared(result, transactionID, verState)
|
||||
@@ -441,7 +438,7 @@ func (mach *OlmMachine) sasCompared(didMatch bool, transactionID string, verStat
|
||||
err = mach.SendInRoomSASVerificationMAC(verState.inRoomID, verState.otherDevice.UserID, verState.otherDevice.DeviceID, transactionID, verState.sas)
|
||||
}
|
||||
if err != nil {
|
||||
mach.Log.Error("Error sending verification MAC to other device: %v", err)
|
||||
mach.Log.Error().Msgf("Error sending verification MAC to other device: %v", err)
|
||||
}
|
||||
} else {
|
||||
verState.sasMatched <- false
|
||||
@@ -451,10 +448,10 @@ func (mach *OlmMachine) sasCompared(didMatch bool, transactionID string, verStat
|
||||
// handleVerificationMAC handles an incoming m.key.verification.mac message.
|
||||
// It verifies the other device's MAC and if the MAC is valid it marks the device as trusted.
|
||||
func (mach *OlmMachine) handleVerificationMAC(userID id.UserID, content *event.VerificationMacEventContent, transactionID string) {
|
||||
mach.Log.Debug("Got MAC for verification %v: %v, MAC for keys: %v", transactionID, content.Mac, content.Keys)
|
||||
mach.Log.Debug().Msgf("Got MAC for verification %v: %v, MAC for keys: %v", transactionID, content.Mac, content.Keys)
|
||||
verState, err := mach.getTransactionState(transactionID, userID)
|
||||
if err != nil {
|
||||
mach.Log.Error("Error getting transaction state: %v", err)
|
||||
mach.Log.Error().Msgf("Error getting transaction state: %v", err)
|
||||
return
|
||||
}
|
||||
verState.lock.Lock()
|
||||
@@ -468,7 +465,7 @@ func (mach *OlmMachine) handleVerificationMAC(userID id.UserID, content *event.V
|
||||
|
||||
if !verState.verificationStarted || !verState.keyReceived {
|
||||
// unexpected MAC at this point
|
||||
mach.Log.Warn("Unexpected MAC message for transaction %v", transactionID)
|
||||
mach.Log.Warn().Msgf("Unexpected MAC message for transaction %v", transactionID)
|
||||
_ = mach.callbackAndCancelSASVerification(verState, transactionID, "Unexpected MAC message", event.VerificationCancelUnexpectedMessage)
|
||||
return
|
||||
}
|
||||
@@ -480,7 +477,7 @@ func (mach *OlmMachine) handleVerificationMAC(userID id.UserID, content *event.V
|
||||
defer verState.lock.Unlock()
|
||||
|
||||
if !matched {
|
||||
mach.Log.Warn("SAS do not match! Canceling transaction %v", transactionID)
|
||||
mach.Log.Warn().Msgf("SAS do not match! Canceling transaction %v", transactionID)
|
||||
_ = mach.callbackAndCancelSASVerification(verState, transactionID, "SAS do not match", event.VerificationCancelSASMismatch)
|
||||
return
|
||||
}
|
||||
@@ -490,20 +487,20 @@ func (mach *OlmMachine) handleVerificationMAC(userID id.UserID, content *event.V
|
||||
expectedPKMAC, expectedKeysMAC, err := mach.getPKAndKeysMAC(verState.sas, device.UserID, device.DeviceID,
|
||||
mach.Client.UserID, mach.Client.DeviceID, transactionID, device.SigningKey, keyID, content.Mac)
|
||||
if err != nil {
|
||||
mach.Log.Error("Error generating MAC to match with received MAC: %v", err)
|
||||
mach.Log.Error().Msgf("Error generating MAC to match with received MAC: %v", err)
|
||||
return
|
||||
}
|
||||
|
||||
mach.Log.Debug("Expected %s keys MAC, got %s", expectedKeysMAC, content.Keys)
|
||||
mach.Log.Debug().Msgf("Expected %s keys MAC, got %s", expectedKeysMAC, content.Keys)
|
||||
if content.Keys != expectedKeysMAC {
|
||||
mach.Log.Warn("Canceling verification transaction %v due to mismatched keys MAC", transactionID)
|
||||
mach.Log.Warn().Msgf("Canceling verification transaction %v due to mismatched keys MAC", transactionID)
|
||||
_ = mach.callbackAndCancelSASVerification(verState, transactionID, "Mismatched keys MACs", event.VerificationCancelKeyMismatch)
|
||||
return
|
||||
}
|
||||
|
||||
mach.Log.Debug("Expected %s PK MAC, got %s", expectedPKMAC, content.Mac[keyID])
|
||||
mach.Log.Debug().Msgf("Expected %s PK MAC, got %s", expectedPKMAC, content.Mac[keyID])
|
||||
if content.Mac[keyID] != expectedPKMAC {
|
||||
mach.Log.Warn("Canceling verification transaction %v due to mismatched PK MAC", transactionID)
|
||||
mach.Log.Warn().Msgf("Canceling verification transaction %v due to mismatched PK MAC", transactionID)
|
||||
_ = mach.callbackAndCancelSASVerification(verState, transactionID, "Mismatched PK MACs", event.VerificationCancelKeyMismatch)
|
||||
return
|
||||
}
|
||||
@@ -512,35 +509,35 @@ func (mach *OlmMachine) handleVerificationMAC(userID id.UserID, content *event.V
|
||||
device.Trust = id.TrustStateVerified
|
||||
err = mach.CryptoStore.PutDevice(device.UserID, device)
|
||||
if err != nil {
|
||||
mach.Log.Warn("Failed to put device after verifying: %v", err)
|
||||
mach.Log.Warn().Msgf("Failed to put device after verifying: %v", err)
|
||||
}
|
||||
|
||||
if mach.CrossSigningKeys != nil {
|
||||
if device.UserID == mach.Client.UserID {
|
||||
err := mach.SignOwnDevice(device)
|
||||
if err != nil {
|
||||
mach.Log.Error("Failed to cross-sign own device %s: %v", device.DeviceID, err)
|
||||
mach.Log.Error().Msgf("Failed to cross-sign own device %s: %v", device.DeviceID, err)
|
||||
} else {
|
||||
mach.Log.Debug("Cross-signed own device %v after SAS verification", device.DeviceID)
|
||||
mach.Log.Debug().Msgf("Cross-signed own device %v after SAS verification", device.DeviceID)
|
||||
}
|
||||
} else {
|
||||
masterKey, err := mach.fetchMasterKey(device, content, verState, transactionID)
|
||||
if err != nil {
|
||||
mach.Log.Warn("Failed to fetch %s's master key: %v", device.UserID, err)
|
||||
mach.Log.Warn().Msgf("Failed to fetch %s's master key: %v", device.UserID, err)
|
||||
} else {
|
||||
if err := mach.SignUser(device.UserID, masterKey); err != nil {
|
||||
mach.Log.Error("Failed to cross-sign master key of %s: %v", device.UserID, err)
|
||||
mach.Log.Error().Msgf("Failed to cross-sign master key of %s: %v", device.UserID, err)
|
||||
} else {
|
||||
mach.Log.Debug("Cross-signed master key of %v after SAS verification", device.UserID)
|
||||
mach.Log.Debug().Msgf("Cross-signed master key of %v after SAS verification", device.UserID)
|
||||
}
|
||||
}
|
||||
}
|
||||
} else {
|
||||
// TODO ask user to unlock cross-signing keys?
|
||||
mach.Log.Debug("Cross-signing keys not cached, not signing %s/%s", device.UserID, device.DeviceID)
|
||||
mach.Log.Debug().Msgf("Cross-signing keys not cached, not signing %s/%s", device.UserID, device.DeviceID)
|
||||
}
|
||||
|
||||
mach.Log.Debug("Device %v of user %v verified successfully!", device.DeviceID, device.UserID)
|
||||
mach.Log.Debug().Msgf("Device %v of user %v verified successfully!", device.DeviceID, device.UserID)
|
||||
|
||||
verState.hooks.OnSuccess()
|
||||
}()
|
||||
@@ -557,20 +554,20 @@ func (mach *OlmMachine) handleVerificationCancel(userID id.UserID, content *even
|
||||
}
|
||||
|
||||
mach.keyVerificationTransactionState.Delete(userID.String() + ":" + transactionID)
|
||||
mach.Log.Warn("SAS verification %v was canceled by %v with reason: %v (%v)",
|
||||
mach.Log.Warn().Msgf("SAS verification %v was canceled by %v with reason: %v (%v)",
|
||||
transactionID, userID, content.Reason, content.Code)
|
||||
}
|
||||
|
||||
// handleVerificationRequest handles an incoming m.key.verification.request message.
|
||||
func (mach *OlmMachine) handleVerificationRequest(userID id.UserID, content *event.VerificationRequestEventContent, transactionID string, inRoomID id.RoomID) {
|
||||
mach.Log.Debug("Received verification request from %v", content.FromDevice)
|
||||
otherDevice, err := mach.GetOrFetchDevice(userID, content.FromDevice)
|
||||
mach.Log.Debug().Msgf("Received verification request from %v", content.FromDevice)
|
||||
otherDevice, err := mach.GetOrFetchDevice(context.TODO(), userID, content.FromDevice)
|
||||
if err != nil {
|
||||
mach.Log.Error("Could not find device %v of user %v", content.FromDevice, userID)
|
||||
mach.Log.Error().Msgf("Could not find device %v of user %v", content.FromDevice, userID)
|
||||
return
|
||||
}
|
||||
if !content.SupportsVerificationMethod(event.VerificationMethodSAS) {
|
||||
mach.Log.Warn("Canceling verification transaction %v as SAS is not supported", transactionID)
|
||||
mach.Log.Warn().Msgf("Canceling verification transaction %v as SAS is not supported", transactionID)
|
||||
if inRoomID == "" {
|
||||
_ = mach.SendSASVerificationCancel(otherDevice.UserID, otherDevice.DeviceID, transactionID, "Only SAS method is supported", event.VerificationCancelUnknownMethod)
|
||||
} else {
|
||||
@@ -580,12 +577,12 @@ func (mach *OlmMachine) handleVerificationRequest(userID id.UserID, content *eve
|
||||
}
|
||||
resp, hooks := mach.AcceptVerificationFrom(transactionID, otherDevice, inRoomID)
|
||||
if resp == AcceptRequest {
|
||||
mach.Log.Debug("Accepting SAS verification %v from %v of user %v", transactionID, otherDevice.DeviceID, otherDevice.UserID)
|
||||
mach.Log.Debug().Msgf("Accepting SAS verification %v from %v of user %v", transactionID, otherDevice.DeviceID, otherDevice.UserID)
|
||||
if inRoomID == "" {
|
||||
_, err = mach.NewSASVerificationWith(otherDevice, hooks, transactionID, mach.DefaultSASTimeout)
|
||||
} else {
|
||||
if err := mach.SendInRoomSASVerificationReady(inRoomID, transactionID); err != nil {
|
||||
mach.Log.Error("Error sending in-room SAS verification ready: %v", err)
|
||||
mach.Log.Error().Msgf("Error sending in-room SAS verification ready: %v", err)
|
||||
}
|
||||
if mach.Client.UserID < otherDevice.UserID {
|
||||
// up to us to send the start message
|
||||
@@ -593,17 +590,17 @@ func (mach *OlmMachine) handleVerificationRequest(userID id.UserID, content *eve
|
||||
}
|
||||
}
|
||||
if err != nil {
|
||||
mach.Log.Error("Error accepting SAS verification request: %v", err)
|
||||
mach.Log.Error().Msgf("Error accepting SAS verification request: %v", err)
|
||||
}
|
||||
} else if resp == RejectRequest {
|
||||
mach.Log.Debug("Rejecting SAS verification %v from %v of user %v", transactionID, otherDevice.DeviceID, otherDevice.UserID)
|
||||
mach.Log.Debug().Msgf("Rejecting SAS verification %v from %v of user %v", transactionID, otherDevice.DeviceID, otherDevice.UserID)
|
||||
if inRoomID == "" {
|
||||
_ = mach.SendSASVerificationCancel(otherDevice.UserID, otherDevice.DeviceID, transactionID, "Not accepted by user", event.VerificationCancelByUser)
|
||||
} else {
|
||||
_ = mach.SendInRoomSASVerificationCancel(inRoomID, otherDevice.UserID, transactionID, "Not accepted by user", event.VerificationCancelByUser)
|
||||
}
|
||||
} else {
|
||||
mach.Log.Debug("Ignoring SAS verification %v from %v of user %v", transactionID, otherDevice.DeviceID, otherDevice.UserID)
|
||||
mach.Log.Debug().Msgf("Ignoring SAS verification %v from %v of user %v", transactionID, otherDevice.DeviceID, otherDevice.UserID)
|
||||
}
|
||||
}
|
||||
|
||||
@@ -620,7 +617,7 @@ func (mach *OlmMachine) NewSASVerificationWith(device *id.Device, hooks Verifica
|
||||
if transactionID == "" {
|
||||
transactionID = strconv.Itoa(rand.Int())
|
||||
}
|
||||
mach.Log.Debug("Starting new verification transaction %v with device %v of user %v", transactionID, device.DeviceID, device.UserID)
|
||||
mach.Log.Debug().Msgf("Starting new verification transaction %v with device %v of user %v", transactionID, device.DeviceID, device.UserID)
|
||||
|
||||
verState := &verificationState{
|
||||
sas: olm.NewSAS(),
|
||||
@@ -669,7 +666,7 @@ func (mach *OlmMachine) CancelSASVerification(userID id.UserID, transactionID, r
|
||||
verState := verStateInterface.(*verificationState)
|
||||
verState.lock.Lock()
|
||||
defer verState.lock.Unlock()
|
||||
mach.Log.Trace("User canceled verification transaction %v with reason: %v", transactionID, reason)
|
||||
mach.Log.Trace().Msgf("User canceled verification transaction %v with reason: %v", transactionID, reason)
|
||||
mach.keyVerificationTransactionState.Delete(mapKey)
|
||||
return mach.callbackAndCancelSASVerification(verState, transactionID, reason, event.VerificationCancelByUser)
|
||||
}
|
||||
@@ -766,9 +763,9 @@ func (mach *OlmMachine) SendSASVerificationMAC(userID id.UserID, deviceID id.Dev
|
||||
masterKeyMAC, _, err := mach.getPKAndKeysMAC(sas, mach.Client.UserID, mach.Client.DeviceID,
|
||||
userID, deviceID, transactionID, masterKey, masterKeyID, keyIDsMap)
|
||||
if err != nil {
|
||||
mach.Log.Error("Error generating master key MAC: %v", err)
|
||||
mach.Log.Error().Msgf("Error generating master key MAC: %v", err)
|
||||
} else {
|
||||
mach.Log.Debug("Generated master key `%v` MAC: %v", masterKey, masterKeyMAC)
|
||||
mach.Log.Debug().Msgf("Generated master key `%v` MAC: %v", masterKey, masterKeyMAC)
|
||||
macMap[masterKeyID] = masterKeyMAC
|
||||
}
|
||||
}
|
||||
@@ -777,8 +774,8 @@ func (mach *OlmMachine) SendSASVerificationMAC(userID id.UserID, deviceID id.Dev
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
mach.Log.Debug("MAC of key %s is: %s", signingKey, pubKeyMac)
|
||||
mach.Log.Debug("MAC of key ID(s) %s is: %s", keyID, keysMac)
|
||||
mach.Log.Debug().Msgf("MAC of key %s is: %s", signingKey, pubKeyMac)
|
||||
mach.Log.Debug().Msgf("MAC of key ID(s) %s is: %s", keyID, keysMac)
|
||||
macMap[keyID] = pubKeyMac
|
||||
|
||||
content := &event.VerificationMacEventContent{
|
||||
|
||||
31
vendor/maunium.net/go/mautrix/crypto/verification_in_room.go
generated
vendored
31
vendor/maunium.net/go/mautrix/crypto/verification_in_room.go
generated
vendored
@@ -7,6 +7,7 @@
|
||||
package crypto
|
||||
|
||||
import (
|
||||
"context"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"time"
|
||||
@@ -80,7 +81,7 @@ func (mach *OlmMachine) SendInRoomSASVerificationCancel(roomID id.RoomID, userID
|
||||
To: userID,
|
||||
}
|
||||
|
||||
encrypted, err := mach.EncryptMegolmEvent(roomID, event.InRoomVerificationCancel, content)
|
||||
encrypted, err := mach.EncryptMegolmEvent(context.TODO(), roomID, event.InRoomVerificationCancel, content)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -97,7 +98,7 @@ func (mach *OlmMachine) SendInRoomSASVerificationRequest(roomID id.RoomID, toUse
|
||||
To: toUserID,
|
||||
}
|
||||
|
||||
encrypted, err := mach.EncryptMegolmEvent(roomID, event.EventMessage, content)
|
||||
encrypted, err := mach.EncryptMegolmEvent(context.TODO(), roomID, event.EventMessage, content)
|
||||
if err != nil {
|
||||
return "", err
|
||||
}
|
||||
@@ -116,7 +117,7 @@ func (mach *OlmMachine) SendInRoomSASVerificationReady(roomID id.RoomID, transac
|
||||
RelatesTo: &event.RelatesTo{Type: event.RelReference, EventID: id.EventID(transactionID)},
|
||||
}
|
||||
|
||||
encrypted, err := mach.EncryptMegolmEvent(roomID, event.InRoomVerificationReady, content)
|
||||
encrypted, err := mach.EncryptMegolmEvent(context.TODO(), roomID, event.InRoomVerificationReady, content)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -141,7 +142,7 @@ func (mach *OlmMachine) SendInRoomSASVerificationStart(roomID id.RoomID, toUserI
|
||||
To: toUserID,
|
||||
}
|
||||
|
||||
encrypted, err := mach.EncryptMegolmEvent(roomID, event.InRoomVerificationStart, content)
|
||||
encrypted, err := mach.EncryptMegolmEvent(context.TODO(), roomID, event.InRoomVerificationStart, content)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -182,7 +183,7 @@ func (mach *OlmMachine) SendInRoomSASVerificationAccept(roomID id.RoomID, fromUs
|
||||
To: fromUser,
|
||||
}
|
||||
|
||||
encrypted, err := mach.EncryptMegolmEvent(roomID, event.InRoomVerificationAccept, content)
|
||||
encrypted, err := mach.EncryptMegolmEvent(context.TODO(), roomID, event.InRoomVerificationAccept, content)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -198,7 +199,7 @@ func (mach *OlmMachine) SendInRoomSASVerificationKey(roomID id.RoomID, userID id
|
||||
To: userID,
|
||||
}
|
||||
|
||||
encrypted, err := mach.EncryptMegolmEvent(roomID, event.InRoomVerificationKey, content)
|
||||
encrypted, err := mach.EncryptMegolmEvent(context.TODO(), roomID, event.InRoomVerificationKey, content)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -222,9 +223,9 @@ func (mach *OlmMachine) SendInRoomSASVerificationMAC(roomID id.RoomID, userID id
|
||||
masterKeyMAC, _, err := mach.getPKAndKeysMAC(sas, mach.Client.UserID, mach.Client.DeviceID,
|
||||
userID, deviceID, transactionID, masterKey, masterKeyID, keyIDsMap)
|
||||
if err != nil {
|
||||
mach.Log.Error("Error generating master key MAC: %v", err)
|
||||
mach.Log.Error().Msgf("Error generating master key MAC: %v", err)
|
||||
} else {
|
||||
mach.Log.Debug("Generated master key `%v` MAC: %v", masterKey, masterKeyMAC)
|
||||
mach.Log.Debug().Msgf("Generated master key `%v` MAC: %v", masterKey, masterKeyMAC)
|
||||
macMap[masterKeyID] = masterKeyMAC
|
||||
}
|
||||
}
|
||||
@@ -233,8 +234,8 @@ func (mach *OlmMachine) SendInRoomSASVerificationMAC(roomID id.RoomID, userID id
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
mach.Log.Debug("MAC of key %s is: %s", signingKey, pubKeyMac)
|
||||
mach.Log.Debug("MAC of key ID(s) %s is: %s", keyID, keysMac)
|
||||
mach.Log.Debug().Msgf("MAC of key %s is: %s", signingKey, pubKeyMac)
|
||||
mach.Log.Debug().Msgf("MAC of key ID(s) %s is: %s", keyID, keysMac)
|
||||
macMap[keyID] = pubKeyMac
|
||||
|
||||
content := &event.VerificationMacEventContent{
|
||||
@@ -244,7 +245,7 @@ func (mach *OlmMachine) SendInRoomSASVerificationMAC(roomID id.RoomID, userID id
|
||||
To: userID,
|
||||
}
|
||||
|
||||
encrypted, err := mach.EncryptMegolmEvent(roomID, event.InRoomVerificationMAC, content)
|
||||
encrypted, err := mach.EncryptMegolmEvent(context.TODO(), roomID, event.InRoomVerificationMAC, content)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
@@ -259,7 +260,7 @@ func (mach *OlmMachine) NewInRoomSASVerificationWith(inRoomID id.RoomID, userID
|
||||
}
|
||||
|
||||
func (mach *OlmMachine) newInRoomSASVerificationWithInner(inRoomID id.RoomID, device *id.Device, hooks VerificationHooks, transactionID string, timeout time.Duration) (string, error) {
|
||||
mach.Log.Debug("Starting new in-room verification transaction user %v", device.UserID)
|
||||
mach.Log.Debug().Msgf("Starting new in-room verification transaction user %v", device.UserID)
|
||||
|
||||
request := transactionID == ""
|
||||
if request {
|
||||
@@ -310,15 +311,15 @@ func (mach *OlmMachine) newInRoomSASVerificationWithInner(inRoomID id.RoomID, de
|
||||
}
|
||||
|
||||
func (mach *OlmMachine) handleInRoomVerificationReady(userID id.UserID, roomID id.RoomID, content *event.VerificationReadyEventContent, transactionID string) {
|
||||
device, err := mach.GetOrFetchDevice(userID, content.FromDevice)
|
||||
device, err := mach.GetOrFetchDevice(context.TODO(), userID, content.FromDevice)
|
||||
if err != nil {
|
||||
mach.Log.Error("Error fetching device %v of user %v: %v", content.FromDevice, userID, err)
|
||||
mach.Log.Error().Msgf("Error fetching device %v of user %v: %v", content.FromDevice, userID, err)
|
||||
return
|
||||
}
|
||||
|
||||
verState, err := mach.getTransactionState(transactionID, userID)
|
||||
if err != nil {
|
||||
mach.Log.Error("Error getting transaction state: %v", err)
|
||||
mach.Log.Error().Msgf("Error getting transaction state: %v", err)
|
||||
return
|
||||
}
|
||||
//mach.keyVerificationTransactionState.Delete(userID.String() + ":" + transactionID)
|
||||
|
||||
3
vendor/maunium.net/go/mautrix/crypto/verification_sas_methods.go
generated
vendored
3
vendor/maunium.net/go/mautrix/crypto/verification_sas_methods.go
generated
vendored
@@ -4,9 +4,6 @@
|
||||
// License, v. 2.0. If a copy of the MPL was not distributed with this
|
||||
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
|
||||
//go:build !nosas
|
||||
// +build !nosas
|
||||
|
||||
package crypto
|
||||
|
||||
import (
|
||||
|
||||
18
vendor/maunium.net/go/mautrix/error.go
generated
vendored
18
vendor/maunium.net/go/mautrix/error.go
generated
vendored
@@ -11,6 +11,8 @@ import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"net/http"
|
||||
|
||||
"golang.org/x/exp/maps"
|
||||
)
|
||||
|
||||
// Common error codes from https://matrix.org/docs/spec/client_server/latest#api-standards
|
||||
@@ -60,6 +62,11 @@ var (
|
||||
MIncompatibleRoomVersion = RespError{ErrCode: "M_INCOMPATIBLE_ROOM_VERSION"}
|
||||
// The client specified a parameter that has the wrong value.
|
||||
MInvalidParam = RespError{ErrCode: "M_INVALID_PARAM"}
|
||||
|
||||
MURLNotSet = RespError{ErrCode: "M_URL_NOT_SET"}
|
||||
MBadStatus = RespError{ErrCode: "M_BAD_STATUS"}
|
||||
MConnectionTimeout = RespError{ErrCode: "M_CONNECTION_TIMEOUT"}
|
||||
MConnectionFailed = RespError{ErrCode: "M_CONNECTION_FAILED"}
|
||||
)
|
||||
|
||||
// HTTPError An HTTP Error response, which may wrap an underlying native Go Error.
|
||||
@@ -124,12 +131,13 @@ func (e *RespError) UnmarshalJSON(data []byte) error {
|
||||
}
|
||||
|
||||
func (e *RespError) MarshalJSON() ([]byte, error) {
|
||||
if e.ExtraData == nil {
|
||||
e.ExtraData = make(map[string]interface{})
|
||||
data := maps.Clone(e.ExtraData)
|
||||
if data == nil {
|
||||
data = make(map[string]any)
|
||||
}
|
||||
e.ExtraData["errcode"] = e.ErrCode
|
||||
e.ExtraData["error"] = e.Err
|
||||
return json.Marshal(&e.ExtraData)
|
||||
data["errcode"] = e.ErrCode
|
||||
data["error"] = e.Err
|
||||
return json.Marshal(data)
|
||||
}
|
||||
|
||||
// Error returns the errcode and error message.
|
||||
|
||||
6
vendor/maunium.net/go/mautrix/event/beeper.go
generated
vendored
6
vendor/maunium.net/go/mautrix/event/beeper.go
generated
vendored
@@ -49,3 +49,9 @@ type BeeperRetryMetadata struct {
|
||||
RetryCount int `json:"retry_count"`
|
||||
// last_retry is also present, but not used by bridges
|
||||
}
|
||||
|
||||
type BeeperRoomKeyAckEventContent struct {
|
||||
RoomID id.RoomID `json:"room_id"`
|
||||
SessionID id.SessionID `json:"session_id"`
|
||||
FirstMessageIndex int `json:"first_message_index"`
|
||||
}
|
||||
|
||||
2
vendor/maunium.net/go/mautrix/event/content.go
generated
vendored
2
vendor/maunium.net/go/mautrix/event/content.go
generated
vendored
@@ -80,6 +80,8 @@ var TypeMap = map[Type]reflect.Type{
|
||||
|
||||
ToDeviceOrgMatrixRoomKeyWithheld: reflect.TypeOf(RoomKeyWithheldEventContent{}),
|
||||
|
||||
ToDeviceBeeperRoomKeyAck: reflect.TypeOf(BeeperRoomKeyAckEventContent{}),
|
||||
|
||||
CallInvite: reflect.TypeOf(CallInviteEventContent{}),
|
||||
CallCandidates: reflect.TypeOf(CallCandidatesEventContent{}),
|
||||
CallAnswer: reflect.TypeOf(CallAnswerEventContent{}),
|
||||
|
||||
31
vendor/maunium.net/go/mautrix/event/encryption.go
generated
vendored
31
vendor/maunium.net/go/mautrix/event/encryption.go
generated
vendored
@@ -8,6 +8,7 @@ package event
|
||||
|
||||
import (
|
||||
"encoding/json"
|
||||
"fmt"
|
||||
|
||||
"maunium.net/go/mautrix/id"
|
||||
)
|
||||
@@ -41,6 +42,7 @@ type EncryptedEventContent struct {
|
||||
OlmCiphertext OlmCiphertexts `json:"-"`
|
||||
|
||||
RelatesTo *RelatesTo `json:"m.relates_to,omitempty"`
|
||||
Mentions *Mentions `json:"m.mentions,omitempty"`
|
||||
}
|
||||
|
||||
type OlmCiphertexts map[id.Curve25519]struct {
|
||||
@@ -92,6 +94,10 @@ type RoomKeyEventContent struct {
|
||||
RoomID id.RoomID `json:"room_id"`
|
||||
SessionID id.SessionID `json:"session_id"`
|
||||
SessionKey string `json:"session_key"`
|
||||
|
||||
MaxAge int64 `json:"com.beeper.max_age_ms"`
|
||||
MaxMessages int `json:"com.beeper.max_messages"`
|
||||
IsScheduled bool `json:"com.beeper.is_scheduled"`
|
||||
}
|
||||
|
||||
// ForwardedRoomKeyEventContent represents the content of a m.forwarded_room_key to_device event.
|
||||
@@ -101,6 +107,10 @@ type ForwardedRoomKeyEventContent struct {
|
||||
SenderKey id.SenderKey `json:"sender_key"`
|
||||
SenderClaimedKey id.Ed25519 `json:"sender_claimed_ed25519_key"`
|
||||
ForwardingKeyChain []string `json:"forwarding_curve25519_key_chain"`
|
||||
|
||||
MaxAge int64 `json:"com.beeper.max_age_ms"`
|
||||
MaxMessages int `json:"com.beeper.max_messages"`
|
||||
IsScheduled bool `json:"com.beeper.is_scheduled"`
|
||||
}
|
||||
|
||||
type KeyRequestAction string
|
||||
@@ -134,6 +144,8 @@ const (
|
||||
RoomKeyWithheldUnauthorized RoomKeyWithheldCode = "m.unauthorised"
|
||||
RoomKeyWithheldUnavailable RoomKeyWithheldCode = "m.unavailable"
|
||||
RoomKeyWithheldNoOlmSession RoomKeyWithheldCode = "m.no_olm"
|
||||
|
||||
RoomKeyWithheldBeeperRedacted RoomKeyWithheldCode = "com.beeper.redacted"
|
||||
)
|
||||
|
||||
type RoomKeyWithheldEventContent struct {
|
||||
@@ -145,4 +157,23 @@ type RoomKeyWithheldEventContent struct {
|
||||
Reason string `json:"reason,omitempty"`
|
||||
}
|
||||
|
||||
const groupSessionWithheldMsg = "group session has been withheld: %s"
|
||||
|
||||
func (withheld *RoomKeyWithheldEventContent) Error() string {
|
||||
switch withheld.Code {
|
||||
case RoomKeyWithheldBlacklisted, RoomKeyWithheldUnverified, RoomKeyWithheldUnauthorized, RoomKeyWithheldUnavailable, RoomKeyWithheldNoOlmSession:
|
||||
return fmt.Sprintf(groupSessionWithheldMsg, withheld.Code)
|
||||
default:
|
||||
return fmt.Sprintf(groupSessionWithheldMsg+" (%s)", withheld.Code, withheld.Reason)
|
||||
}
|
||||
}
|
||||
|
||||
func (withheld *RoomKeyWithheldEventContent) Is(other error) bool {
|
||||
otherWithheld, ok := other.(*RoomKeyWithheldEventContent)
|
||||
if !ok {
|
||||
return false
|
||||
}
|
||||
return withheld.Code == "" || otherWithheld.Code == "" || withheld.Code == otherWithheld.Code
|
||||
}
|
||||
|
||||
type DummyEventContent struct{}
|
||||
|
||||
2
vendor/maunium.net/go/mautrix/event/events.go
generated
vendored
2
vendor/maunium.net/go/mautrix/event/events.go
generated
vendored
@@ -111,6 +111,8 @@ type MautrixInfo struct {
|
||||
TrustSource *id.Device
|
||||
|
||||
ReceivedAt time.Time
|
||||
EditedAt time.Time
|
||||
LastEditID id.EventID
|
||||
DecryptionDuration time.Duration
|
||||
|
||||
CheckpointSent bool
|
||||
|
||||
16
vendor/maunium.net/go/mautrix/event/message.go
generated
vendored
16
vendor/maunium.net/go/mautrix/event/message.go
generated
vendored
@@ -1,4 +1,4 @@
|
||||
// Copyright (c) 2020 Tulir Asokan
|
||||
// Copyright (c) 2023 Tulir Asokan
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||
// License, v. 2.0. If a copy of the MPL was not distributed with this
|
||||
@@ -97,6 +97,9 @@ type MessageEventContent struct {
|
||||
|
||||
FileName string `json:"filename,omitempty"`
|
||||
|
||||
Mentions *Mentions `json:"m.mentions,omitempty"`
|
||||
UnstableMentions *Mentions `json:"org.matrix.msc3952.mentions,omitempty"`
|
||||
|
||||
// Edits and relations
|
||||
NewContent *MessageEventContent `json:"m.new_content,omitempty"`
|
||||
RelatesTo *RelatesTo `json:"m.relates_to,omitempty"`
|
||||
@@ -135,6 +138,12 @@ func (content *MessageEventContent) SetEdit(original id.EventID) {
|
||||
if content.Format == FormatHTML && len(content.FormattedBody) > 0 {
|
||||
content.FormattedBody = "* " + content.FormattedBody
|
||||
}
|
||||
// If the message is long, remove most of the useless edit fallback to avoid event size issues.
|
||||
if len(content.Body) > 10000 {
|
||||
content.FormattedBody = ""
|
||||
content.Format = ""
|
||||
content.Body = content.Body[:50] + "[edit fallback cut…]"
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
@@ -163,6 +172,11 @@ func (content *MessageEventContent) GetInfo() *FileInfo {
|
||||
return content.Info
|
||||
}
|
||||
|
||||
type Mentions struct {
|
||||
UserIDs []id.UserID `json:"user_ids,omitempty"`
|
||||
Room bool `json:"room,omitempty"`
|
||||
}
|
||||
|
||||
type EncryptedFileInfo struct {
|
||||
attachment.EncryptedFile
|
||||
URL id.ContentURIString `json:"url"`
|
||||
|
||||
64
vendor/maunium.net/go/mautrix/event/powerlevels.go
generated
vendored
64
vendor/maunium.net/go/mautrix/event/powerlevels.go
generated
vendored
@@ -13,7 +13,7 @@ import (
|
||||
)
|
||||
|
||||
// PowerLevelsEventContent represents the content of a m.room.power_levels state event content.
|
||||
// https://spec.matrix.org/v1.2/client-server-api/#mroompower_levels
|
||||
// https://spec.matrix.org/v1.5/client-server-api/#mroompower_levels
|
||||
type PowerLevelsEventContent struct {
|
||||
usersLock sync.RWMutex
|
||||
Users map[id.UserID]int `json:"users,omitempty"`
|
||||
@@ -23,6 +23,8 @@ type PowerLevelsEventContent struct {
|
||||
Events map[string]int `json:"events,omitempty"`
|
||||
EventsDefault int `json:"events_default,omitempty"`
|
||||
|
||||
Notifications *NotificationPowerLevels `json:"notifications,omitempty"`
|
||||
|
||||
StateDefaultPtr *int `json:"state_default,omitempty"`
|
||||
|
||||
InvitePtr *int `json:"invite,omitempty"`
|
||||
@@ -32,6 +34,66 @@ type PowerLevelsEventContent struct {
|
||||
HistoricalPtr *int `json:"historical,omitempty"`
|
||||
}
|
||||
|
||||
func copyPtr(ptr *int) *int {
|
||||
if ptr == nil {
|
||||
return nil
|
||||
}
|
||||
val := *ptr
|
||||
return &val
|
||||
}
|
||||
|
||||
func copyMap[Key comparable](m map[Key]int) map[Key]int {
|
||||
if m == nil {
|
||||
return nil
|
||||
}
|
||||
copied := make(map[Key]int, len(m))
|
||||
for k, v := range m {
|
||||
copied[k] = v
|
||||
}
|
||||
return copied
|
||||
}
|
||||
|
||||
func (pl *PowerLevelsEventContent) Clone() *PowerLevelsEventContent {
|
||||
if pl == nil {
|
||||
return nil
|
||||
}
|
||||
return &PowerLevelsEventContent{
|
||||
Users: copyMap(pl.Users),
|
||||
UsersDefault: pl.UsersDefault,
|
||||
Events: copyMap(pl.Events),
|
||||
EventsDefault: pl.EventsDefault,
|
||||
StateDefaultPtr: copyPtr(pl.StateDefaultPtr),
|
||||
|
||||
Notifications: pl.Notifications.Clone(),
|
||||
|
||||
InvitePtr: copyPtr(pl.InvitePtr),
|
||||
KickPtr: copyPtr(pl.KickPtr),
|
||||
BanPtr: copyPtr(pl.BanPtr),
|
||||
RedactPtr: copyPtr(pl.RedactPtr),
|
||||
HistoricalPtr: copyPtr(pl.HistoricalPtr),
|
||||
}
|
||||
}
|
||||
|
||||
type NotificationPowerLevels struct {
|
||||
RoomPtr *int `json:"room,omitempty"`
|
||||
}
|
||||
|
||||
func (npl *NotificationPowerLevels) Clone() *NotificationPowerLevels {
|
||||
if npl == nil {
|
||||
return nil
|
||||
}
|
||||
return &NotificationPowerLevels{
|
||||
RoomPtr: copyPtr(npl.RoomPtr),
|
||||
}
|
||||
}
|
||||
|
||||
func (npl *NotificationPowerLevels) Room() int {
|
||||
if npl != nil && npl.RoomPtr != nil {
|
||||
return *npl.RoomPtr
|
||||
}
|
||||
return 50
|
||||
}
|
||||
|
||||
func (pl *PowerLevelsEventContent) Invite() int {
|
||||
if pl.InvitePtr != nil {
|
||||
return *pl.InvitePtr
|
||||
|
||||
2
vendor/maunium.net/go/mautrix/event/relations.go
generated
vendored
2
vendor/maunium.net/go/mautrix/event/relations.go
generated
vendored
@@ -32,6 +32,8 @@ type RelatesTo struct {
|
||||
|
||||
type InReplyTo struct {
|
||||
EventID id.EventID `json:"event_id,omitempty"`
|
||||
|
||||
UnstableRoomID id.RoomID `json:"room_id,omitempty"`
|
||||
}
|
||||
|
||||
func (rel *RelatesTo) Copy() *RelatesTo {
|
||||
|
||||
5
vendor/maunium.net/go/mautrix/event/type.go
generated
vendored
5
vendor/maunium.net/go/mautrix/event/type.go
generated
vendored
@@ -124,7 +124,8 @@ func (et *Type) GuessClass() TypeClass {
|
||||
CallInvite.Type, CallCandidates.Type, CallAnswer.Type, CallReject.Type, CallSelectAnswer.Type,
|
||||
CallNegotiate.Type, CallHangup.Type, BeeperMessageStatus.Type:
|
||||
return MessageEventType
|
||||
case ToDeviceRoomKey.Type, ToDeviceRoomKeyRequest.Type, ToDeviceForwardedRoomKey.Type, ToDeviceRoomKeyWithheld.Type:
|
||||
case ToDeviceRoomKey.Type, ToDeviceRoomKeyRequest.Type, ToDeviceForwardedRoomKey.Type, ToDeviceRoomKeyWithheld.Type,
|
||||
ToDeviceBeeperRoomKeyAck.Type:
|
||||
return ToDeviceEventType
|
||||
default:
|
||||
return UnknownEventType
|
||||
@@ -253,4 +254,6 @@ var (
|
||||
ToDeviceVerificationCancel = Type{"m.key.verification.cancel", ToDeviceEventType}
|
||||
|
||||
ToDeviceOrgMatrixRoomKeyWithheld = Type{"org.matrix.room_key.withheld", ToDeviceEventType}
|
||||
|
||||
ToDeviceBeeperRoomKeyAck = Type{"com.beeper.room_key.ack", ToDeviceEventType}
|
||||
)
|
||||
|
||||
139
vendor/maunium.net/go/mautrix/format/htmlparser.go
generated
vendored
139
vendor/maunium.net/go/mautrix/format/htmlparser.go
generated
vendored
@@ -17,7 +17,45 @@ import (
|
||||
"maunium.net/go/mautrix/id"
|
||||
)
|
||||
|
||||
type Context map[string]interface{}
|
||||
type TagStack []string
|
||||
|
||||
func (ts TagStack) Index(tag string) int {
|
||||
for i := len(ts) - 1; i >= 0; i-- {
|
||||
if ts[i] == tag {
|
||||
return i
|
||||
}
|
||||
}
|
||||
return -1
|
||||
}
|
||||
|
||||
func (ts TagStack) Has(tag string) bool {
|
||||
return ts.Index(tag) >= 0
|
||||
}
|
||||
|
||||
type Context struct {
|
||||
ReturnData map[string]any
|
||||
TagStack TagStack
|
||||
|
||||
PreserveWhitespace bool
|
||||
}
|
||||
|
||||
func NewContext() Context {
|
||||
return Context{
|
||||
ReturnData: map[string]any{},
|
||||
TagStack: make(TagStack, 0, 4),
|
||||
}
|
||||
}
|
||||
|
||||
func (ctx Context) WithTag(tag string) Context {
|
||||
ctx.TagStack = append(ctx.TagStack, tag)
|
||||
return ctx
|
||||
}
|
||||
|
||||
func (ctx Context) WithWhitespace() Context {
|
||||
ctx.PreserveWhitespace = true
|
||||
return ctx
|
||||
}
|
||||
|
||||
type TextConverter func(string, Context) string
|
||||
type SpoilerConverter func(text, reason string, ctx Context) string
|
||||
type LinkConverter func(text, href string, ctx Context) string
|
||||
@@ -93,9 +131,9 @@ func Digits(num int) int {
|
||||
return int(math.Floor(math.Log10(float64(num))) + 1)
|
||||
}
|
||||
|
||||
func (parser *HTMLParser) listToString(node *html.Node, stripLinebreak bool, ctx Context) string {
|
||||
func (parser *HTMLParser) listToString(node *html.Node, ctx Context) string {
|
||||
ordered := node.Data == "ol"
|
||||
taggedChildren := parser.nodeToTaggedStrings(node.FirstChild, stripLinebreak, ctx)
|
||||
taggedChildren := parser.nodeToTaggedStrings(node.FirstChild, ctx)
|
||||
counter := 1
|
||||
indentLength := 0
|
||||
if ordered {
|
||||
@@ -137,8 +175,27 @@ func (parser *HTMLParser) listToString(node *html.Node, stripLinebreak bool, ctx
|
||||
return strings.Join(children, "\n")
|
||||
}
|
||||
|
||||
func (parser *HTMLParser) basicFormatToString(node *html.Node, stripLinebreak bool, ctx Context) string {
|
||||
str := parser.nodeToTagAwareString(node.FirstChild, stripLinebreak, ctx)
|
||||
func LongestSequence(in string, of rune) int {
|
||||
currentSeq := 0
|
||||
maxSeq := 0
|
||||
for _, chr := range in {
|
||||
if chr == of {
|
||||
currentSeq++
|
||||
} else {
|
||||
if currentSeq > maxSeq {
|
||||
maxSeq = currentSeq
|
||||
}
|
||||
currentSeq = 0
|
||||
}
|
||||
}
|
||||
if currentSeq > maxSeq {
|
||||
maxSeq = currentSeq
|
||||
}
|
||||
return maxSeq
|
||||
}
|
||||
|
||||
func (parser *HTMLParser) basicFormatToString(node *html.Node, ctx Context) string {
|
||||
str := parser.nodeToTagAwareString(node.FirstChild, ctx)
|
||||
switch node.Data {
|
||||
case "b", "strong":
|
||||
if parser.BoldConverter != nil {
|
||||
@@ -163,13 +220,14 @@ func (parser *HTMLParser) basicFormatToString(node *html.Node, stripLinebreak bo
|
||||
if parser.MonospaceConverter != nil {
|
||||
return parser.MonospaceConverter(str, ctx)
|
||||
}
|
||||
return fmt.Sprintf("`%s`", str)
|
||||
surround := strings.Repeat("`", LongestSequence(str, '`')+1)
|
||||
return fmt.Sprintf("%s%s%s", surround, str, surround)
|
||||
}
|
||||
return str
|
||||
}
|
||||
|
||||
func (parser *HTMLParser) spanToString(node *html.Node, stripLinebreak bool, ctx Context) string {
|
||||
str := parser.nodeToTagAwareString(node.FirstChild, stripLinebreak, ctx)
|
||||
func (parser *HTMLParser) spanToString(node *html.Node, ctx Context) string {
|
||||
str := parser.nodeToTagAwareString(node.FirstChild, ctx)
|
||||
if node.Data == "span" {
|
||||
reason, isSpoiler := parser.maybeGetAttribute(node, "data-mx-spoiler")
|
||||
if isSpoiler {
|
||||
@@ -195,15 +253,15 @@ func (parser *HTMLParser) spanToString(node *html.Node, stripLinebreak bool, ctx
|
||||
return str
|
||||
}
|
||||
|
||||
func (parser *HTMLParser) headerToString(node *html.Node, stripLinebreak bool, ctx Context) string {
|
||||
children := parser.nodeToStrings(node.FirstChild, stripLinebreak, ctx)
|
||||
func (parser *HTMLParser) headerToString(node *html.Node, ctx Context) string {
|
||||
children := parser.nodeToStrings(node.FirstChild, ctx)
|
||||
length := int(node.Data[1] - '0')
|
||||
prefix := strings.Repeat("#", length) + " "
|
||||
return prefix + strings.Join(children, "")
|
||||
}
|
||||
|
||||
func (parser *HTMLParser) blockquoteToString(node *html.Node, stripLinebreak bool, ctx Context) string {
|
||||
str := parser.nodeToTagAwareString(node.FirstChild, stripLinebreak, ctx)
|
||||
func (parser *HTMLParser) blockquoteToString(node *html.Node, ctx Context) string {
|
||||
str := parser.nodeToTagAwareString(node.FirstChild, ctx)
|
||||
childrenArr := strings.Split(strings.TrimSpace(str), "\n")
|
||||
// TODO make blockquote prefix configurable
|
||||
for index, child := range childrenArr {
|
||||
@@ -212,8 +270,8 @@ func (parser *HTMLParser) blockquoteToString(node *html.Node, stripLinebreak boo
|
||||
return strings.Join(childrenArr, "\n")
|
||||
}
|
||||
|
||||
func (parser *HTMLParser) linkToString(node *html.Node, stripLinebreak bool, ctx Context) string {
|
||||
str := parser.nodeToTagAwareString(node.FirstChild, stripLinebreak, ctx)
|
||||
func (parser *HTMLParser) linkToString(node *html.Node, ctx Context) string {
|
||||
str := parser.nodeToTagAwareString(node.FirstChild, ctx)
|
||||
href := parser.getAttribute(node, "href")
|
||||
if len(href) == 0 {
|
||||
return str
|
||||
@@ -232,24 +290,25 @@ func (parser *HTMLParser) linkToString(node *html.Node, stripLinebreak bool, ctx
|
||||
return fmt.Sprintf("%s (%s)", str, href)
|
||||
}
|
||||
|
||||
func (parser *HTMLParser) tagToString(node *html.Node, stripLinebreak bool, ctx Context) string {
|
||||
func (parser *HTMLParser) tagToString(node *html.Node, ctx Context) string {
|
||||
ctx = ctx.WithTag(node.Data)
|
||||
switch node.Data {
|
||||
case "blockquote":
|
||||
return parser.blockquoteToString(node, stripLinebreak, ctx)
|
||||
return parser.blockquoteToString(node, ctx)
|
||||
case "ol", "ul":
|
||||
return parser.listToString(node, stripLinebreak, ctx)
|
||||
return parser.listToString(node, ctx)
|
||||
case "h1", "h2", "h3", "h4", "h5", "h6":
|
||||
return parser.headerToString(node, stripLinebreak, ctx)
|
||||
return parser.headerToString(node, ctx)
|
||||
case "br":
|
||||
return parser.Newline
|
||||
case "b", "strong", "i", "em", "s", "strike", "del", "u", "ins", "tt", "code":
|
||||
return parser.basicFormatToString(node, stripLinebreak, ctx)
|
||||
return parser.basicFormatToString(node, ctx)
|
||||
case "span", "font":
|
||||
return parser.spanToString(node, stripLinebreak, ctx)
|
||||
return parser.spanToString(node, ctx)
|
||||
case "a":
|
||||
return parser.linkToString(node, stripLinebreak, ctx)
|
||||
return parser.linkToString(node, ctx)
|
||||
case "p":
|
||||
return parser.nodeToTagAwareString(node.FirstChild, stripLinebreak, ctx)
|
||||
return parser.nodeToTagAwareString(node.FirstChild, ctx)
|
||||
case "hr":
|
||||
return parser.HorizontalLine
|
||||
case "pre":
|
||||
@@ -259,9 +318,9 @@ func (parser *HTMLParser) tagToString(node *html.Node, stripLinebreak bool, ctx
|
||||
if strings.HasPrefix(class, "language-") {
|
||||
language = class[len("language-"):]
|
||||
}
|
||||
preStr = parser.nodeToString(node.FirstChild.FirstChild, false, ctx)
|
||||
preStr = parser.nodeToString(node.FirstChild.FirstChild, ctx.WithWhitespace())
|
||||
} else {
|
||||
preStr = parser.nodeToString(node.FirstChild, false, ctx)
|
||||
preStr = parser.nodeToString(node.FirstChild, ctx.WithWhitespace())
|
||||
}
|
||||
if parser.MonospaceBlockConverter != nil {
|
||||
return parser.MonospaceBlockConverter(preStr, language, ctx)
|
||||
@@ -271,14 +330,14 @@ func (parser *HTMLParser) tagToString(node *html.Node, stripLinebreak bool, ctx
|
||||
}
|
||||
return fmt.Sprintf("```%s\n%s```", language, preStr)
|
||||
default:
|
||||
return parser.nodeToTagAwareString(node.FirstChild, stripLinebreak, ctx)
|
||||
return parser.nodeToTagAwareString(node.FirstChild, ctx)
|
||||
}
|
||||
}
|
||||
|
||||
func (parser *HTMLParser) singleNodeToString(node *html.Node, stripLinebreak bool, ctx Context) TaggedString {
|
||||
func (parser *HTMLParser) singleNodeToString(node *html.Node, ctx Context) TaggedString {
|
||||
switch node.Type {
|
||||
case html.TextNode:
|
||||
if stripLinebreak {
|
||||
if !ctx.PreserveWhitespace {
|
||||
node.Data = strings.Replace(node.Data, "\n", "", -1)
|
||||
}
|
||||
if parser.TextConverter != nil {
|
||||
@@ -286,17 +345,17 @@ func (parser *HTMLParser) singleNodeToString(node *html.Node, stripLinebreak boo
|
||||
}
|
||||
return TaggedString{node.Data, "text"}
|
||||
case html.ElementNode:
|
||||
return TaggedString{parser.tagToString(node, stripLinebreak, ctx), node.Data}
|
||||
return TaggedString{parser.tagToString(node, ctx), node.Data}
|
||||
case html.DocumentNode:
|
||||
return TaggedString{parser.nodeToTagAwareString(node.FirstChild, stripLinebreak, ctx), "html"}
|
||||
return TaggedString{parser.nodeToTagAwareString(node.FirstChild, ctx), "html"}
|
||||
default:
|
||||
return TaggedString{"", "unknown"}
|
||||
}
|
||||
}
|
||||
|
||||
func (parser *HTMLParser) nodeToTaggedStrings(node *html.Node, stripLinebreak bool, ctx Context) (strs []TaggedString) {
|
||||
func (parser *HTMLParser) nodeToTaggedStrings(node *html.Node, ctx Context) (strs []TaggedString) {
|
||||
for ; node != nil; node = node.NextSibling {
|
||||
strs = append(strs, parser.singleNodeToString(node, stripLinebreak, ctx))
|
||||
strs = append(strs, parser.singleNodeToString(node, ctx))
|
||||
}
|
||||
return
|
||||
}
|
||||
@@ -312,8 +371,8 @@ func (parser *HTMLParser) isBlockTag(tag string) bool {
|
||||
return false
|
||||
}
|
||||
|
||||
func (parser *HTMLParser) nodeToTagAwareString(node *html.Node, stripLinebreak bool, ctx Context) string {
|
||||
strs := parser.nodeToTaggedStrings(node, stripLinebreak, ctx)
|
||||
func (parser *HTMLParser) nodeToTagAwareString(node *html.Node, ctx Context) string {
|
||||
strs := parser.nodeToTaggedStrings(node, ctx)
|
||||
var output strings.Builder
|
||||
for _, str := range strs {
|
||||
tstr := str.string
|
||||
@@ -325,15 +384,15 @@ func (parser *HTMLParser) nodeToTagAwareString(node *html.Node, stripLinebreak b
|
||||
return strings.TrimSpace(output.String())
|
||||
}
|
||||
|
||||
func (parser *HTMLParser) nodeToStrings(node *html.Node, stripLinebreak bool, ctx Context) (strs []string) {
|
||||
func (parser *HTMLParser) nodeToStrings(node *html.Node, ctx Context) (strs []string) {
|
||||
for ; node != nil; node = node.NextSibling {
|
||||
strs = append(strs, parser.singleNodeToString(node, stripLinebreak, ctx).string)
|
||||
strs = append(strs, parser.singleNodeToString(node, ctx).string)
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (parser *HTMLParser) nodeToString(node *html.Node, stripLinebreak bool, ctx Context) string {
|
||||
return strings.Join(parser.nodeToStrings(node, stripLinebreak, ctx), "")
|
||||
func (parser *HTMLParser) nodeToString(node *html.Node, ctx Context) string {
|
||||
return strings.Join(parser.nodeToStrings(node, ctx), "")
|
||||
}
|
||||
|
||||
// Parse converts Matrix HTML into text using the settings in this parser.
|
||||
@@ -342,7 +401,7 @@ func (parser *HTMLParser) Parse(htmlData string, ctx Context) string {
|
||||
htmlData = strings.Replace(htmlData, "\t", strings.Repeat(" ", parser.TabsToSpaces), -1)
|
||||
}
|
||||
node, _ := html.Parse(strings.NewReader(htmlData))
|
||||
return parser.nodeToTagAwareString(node, true, ctx)
|
||||
return parser.nodeToTagAwareString(node, ctx)
|
||||
}
|
||||
|
||||
// HTMLToText converts Matrix HTML into text with the default settings.
|
||||
@@ -352,7 +411,7 @@ func HTMLToText(html string) string {
|
||||
Newline: "\n",
|
||||
HorizontalLine: "\n---\n",
|
||||
PillConverter: DefaultPillConverter,
|
||||
}).Parse(html, make(Context))
|
||||
}).Parse(html, NewContext())
|
||||
}
|
||||
|
||||
// HTMLToMarkdown converts Matrix HTML into markdown with the default settings.
|
||||
@@ -370,5 +429,5 @@ func HTMLToMarkdown(html string) string {
|
||||
}
|
||||
return fmt.Sprintf("[%s](%s)", text, href)
|
||||
},
|
||||
}).Parse(html, make(Context))
|
||||
}).Parse(html, NewContext())
|
||||
}
|
||||
|
||||
48
vendor/maunium.net/go/mautrix/format/mdext/filteredparser.go
generated
vendored
Normal file
48
vendor/maunium.net/go/mautrix/format/mdext/filteredparser.go
generated
vendored
Normal file
@@ -0,0 +1,48 @@
|
||||
// Copyright (c) 2023 Tulir Asokan
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||
// License, v. 2.0. If a copy of the MPL was not distributed with this
|
||||
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
|
||||
package mdext
|
||||
|
||||
import (
|
||||
"reflect"
|
||||
|
||||
"github.com/yuin/goldmark/parser"
|
||||
"github.com/yuin/goldmark/util"
|
||||
)
|
||||
|
||||
func filterParsers(list []util.PrioritizedValue, forbidden map[reflect.Type]struct{}) []util.PrioritizedValue {
|
||||
n := 0
|
||||
for _, item := range list {
|
||||
if _, isForbidden := forbidden[reflect.TypeOf(item.Value)]; isForbidden {
|
||||
continue
|
||||
}
|
||||
list[n] = item
|
||||
n++
|
||||
}
|
||||
return list[:n]
|
||||
}
|
||||
|
||||
// ParserWithoutFeatures returns a Goldmark parser with the provided default features removed.
|
||||
//
|
||||
// e.g. to disable lists, use
|
||||
//
|
||||
// markdown := goldmark.New(goldmark.WithParser(
|
||||
// mdext.ParserWithoutFeatures(goldmark.NewListParser(), goldmark.NewListItemParser())
|
||||
// ))
|
||||
func ParserWithoutFeatures(features ...any) parser.Parser {
|
||||
forbiddenTypes := make(map[reflect.Type]struct{}, len(features))
|
||||
for _, feature := range features {
|
||||
forbiddenTypes[reflect.TypeOf(feature)] = struct{}{}
|
||||
}
|
||||
filteredBlockParsers := filterParsers(parser.DefaultBlockParsers(), forbiddenTypes)
|
||||
filteredInlineParsers := filterParsers(parser.DefaultInlineParsers(), forbiddenTypes)
|
||||
filteredParagraphTransformers := filterParsers(parser.DefaultParagraphTransformers(), forbiddenTypes)
|
||||
return parser.NewParser(
|
||||
parser.WithBlockParsers(filteredBlockParsers...),
|
||||
parser.WithInlineParsers(filteredInlineParsers...),
|
||||
parser.WithParagraphTransformers(filteredParagraphTransformers...),
|
||||
)
|
||||
}
|
||||
8
vendor/maunium.net/go/mautrix/id/matrixuri.go
generated
vendored
8
vendor/maunium.net/go/mautrix/id/matrixuri.go
generated
vendored
@@ -67,10 +67,10 @@ func (uri *MatrixURI) getQuery() url.Values {
|
||||
func (uri *MatrixURI) String() string {
|
||||
parts := []string{
|
||||
SigilToPathSegment[uri.Sigil1],
|
||||
uri.MXID1,
|
||||
url.PathEscape(uri.MXID1),
|
||||
}
|
||||
if uri.Sigil2 != 0 {
|
||||
parts = append(parts, SigilToPathSegment[uri.Sigil2], uri.MXID2)
|
||||
parts = append(parts, SigilToPathSegment[uri.Sigil2], url.PathEscape(uri.MXID2))
|
||||
}
|
||||
return (&url.URL{
|
||||
Scheme: "matrix",
|
||||
@@ -81,9 +81,9 @@ func (uri *MatrixURI) String() string {
|
||||
|
||||
// MatrixToURL converts to parsed matrix: URI into a matrix.to URL
|
||||
func (uri *MatrixURI) MatrixToURL() string {
|
||||
fragment := fmt.Sprintf("#/%s", url.QueryEscape(uri.PrimaryIdentifier()))
|
||||
fragment := fmt.Sprintf("#/%s", url.PathEscape(uri.PrimaryIdentifier()))
|
||||
if uri.Sigil2 != 0 {
|
||||
fragment = fmt.Sprintf("%s/%s", fragment, url.QueryEscape(uri.SecondaryIdentifier()))
|
||||
fragment = fmt.Sprintf("%s/%s", fragment, url.PathEscape(uri.SecondaryIdentifier()))
|
||||
}
|
||||
query := uri.getQuery().Encode()
|
||||
if len(query) > 0 {
|
||||
|
||||
9
vendor/maunium.net/go/mautrix/pushrules/action.go
generated
vendored
9
vendor/maunium.net/go/mautrix/pushrules/action.go
generated
vendored
@@ -33,14 +33,15 @@ type PushActionArray []*PushAction
|
||||
|
||||
// PushActionArrayShould contains the important information parsed from a PushActionArray.
|
||||
type PushActionArrayShould struct {
|
||||
// Whether or not the array contained a Notify, DontNotify or Coalesce action type.
|
||||
// Whether the array contained a Notify, DontNotify or Coalesce action type.
|
||||
// Deprecated: an empty array should be treated as no notification, so there's no reason to check this field.
|
||||
NotifySpecified bool
|
||||
// Whether or not the event in question should trigger a notification.
|
||||
// Whether the event in question should trigger a notification.
|
||||
Notify bool
|
||||
// Whether or not the event in question should be highlighted.
|
||||
// Whether the event in question should be highlighted.
|
||||
Highlight bool
|
||||
|
||||
// Whether or not the event in question should trigger a sound alert.
|
||||
// Whether the event in question should trigger a sound alert.
|
||||
PlaySound bool
|
||||
// The name of the sound to play if PlaySound is true.
|
||||
SoundName string
|
||||
|
||||
26
vendor/maunium.net/go/mautrix/pushrules/rule.go
generated
vendored
26
vendor/maunium.net/go/mautrix/pushrules/rule.go
generated
vendored
@@ -20,6 +20,7 @@ func init() {
|
||||
}
|
||||
|
||||
type PushRuleCollection interface {
|
||||
GetMatchingRule(room Room, evt *event.Event) *PushRule
|
||||
GetActions(room Room, evt *event.Event) PushActionArray
|
||||
}
|
||||
|
||||
@@ -32,16 +33,20 @@ func (rules PushRuleArray) SetType(typ PushRuleType) PushRuleArray {
|
||||
return rules
|
||||
}
|
||||
|
||||
func (rules PushRuleArray) GetActions(room Room, evt *event.Event) PushActionArray {
|
||||
func (rules PushRuleArray) GetMatchingRule(room Room, evt *event.Event) *PushRule {
|
||||
for _, rule := range rules {
|
||||
if !rule.Match(room, evt) {
|
||||
continue
|
||||
}
|
||||
return rule.Actions
|
||||
return rule
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (rules PushRuleArray) GetActions(room Room, evt *event.Event) PushActionArray {
|
||||
return rules.GetMatchingRule(room, evt).GetActions()
|
||||
}
|
||||
|
||||
type PushRuleMap struct {
|
||||
Map map[string]*PushRule
|
||||
Type PushRuleType
|
||||
@@ -59,7 +64,7 @@ func (rules PushRuleArray) SetTypeAndMap(typ PushRuleType) PushRuleMap {
|
||||
return data
|
||||
}
|
||||
|
||||
func (ruleMap PushRuleMap) GetActions(room Room, evt *event.Event) PushActionArray {
|
||||
func (ruleMap PushRuleMap) GetMatchingRule(room Room, evt *event.Event) *PushRule {
|
||||
var rule *PushRule
|
||||
var found bool
|
||||
switch ruleMap.Type {
|
||||
@@ -69,11 +74,15 @@ func (ruleMap PushRuleMap) GetActions(room Room, evt *event.Event) PushActionArr
|
||||
rule, found = ruleMap.Map[string(evt.Sender)]
|
||||
}
|
||||
if found && rule.Match(room, evt) {
|
||||
return rule.Actions
|
||||
return rule
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (ruleMap PushRuleMap) GetActions(room Room, evt *event.Event) PushActionArray {
|
||||
return ruleMap.GetMatchingRule(room, evt).GetActions()
|
||||
}
|
||||
|
||||
func (ruleMap PushRuleMap) Unmap() PushRuleArray {
|
||||
array := make(PushRuleArray, len(ruleMap.Map))
|
||||
index := 0
|
||||
@@ -114,8 +123,15 @@ type PushRule struct {
|
||||
Pattern string `json:"pattern,omitempty"`
|
||||
}
|
||||
|
||||
func (rule *PushRule) GetActions() PushActionArray {
|
||||
if rule == nil {
|
||||
return nil
|
||||
}
|
||||
return rule.Actions
|
||||
}
|
||||
|
||||
func (rule *PushRule) Match(room Room, evt *event.Event) bool {
|
||||
if !rule.Enabled {
|
||||
if rule == nil || !rule.Enabled {
|
||||
return false
|
||||
}
|
||||
switch rule.Type {
|
||||
|
||||
23
vendor/maunium.net/go/mautrix/pushrules/ruleset.go
generated
vendored
23
vendor/maunium.net/go/mautrix/pushrules/ruleset.go
generated
vendored
@@ -67,10 +67,7 @@ func (rs *PushRuleset) MarshalJSON() ([]byte, error) {
|
||||
// collections in a Ruleset match the event given to GetActions()
|
||||
var DefaultPushActions = PushActionArray{&PushAction{Action: ActionDontNotify}}
|
||||
|
||||
// GetActions matches the given event against all of the push rule
|
||||
// collections in this push ruleset in the order of priority as
|
||||
// specified in spec section 11.12.1.4.
|
||||
func (rs *PushRuleset) GetActions(room Room, evt *event.Event) (match PushActionArray) {
|
||||
func (rs *PushRuleset) GetMatchingRule(room Room, evt *event.Event) (rule *PushRule) {
|
||||
// Add push rule collections to array in priority order
|
||||
arrays := []PushRuleCollection{rs.Override, rs.Content, rs.Room, rs.Sender, rs.Underride}
|
||||
// Loop until one of the push rule collections matches the room/event combo.
|
||||
@@ -78,11 +75,23 @@ func (rs *PushRuleset) GetActions(room Room, evt *event.Event) (match PushAction
|
||||
if pra == nil {
|
||||
continue
|
||||
}
|
||||
if match = pra.GetActions(room, evt); match != nil {
|
||||
if rule = pra.GetMatchingRule(room, evt); rule != nil {
|
||||
// Match found, return it.
|
||||
return
|
||||
}
|
||||
}
|
||||
// No match found, return default actions.
|
||||
return DefaultPushActions
|
||||
// No match found
|
||||
return nil
|
||||
}
|
||||
|
||||
// GetActions matches the given event against all of the push rule
|
||||
// collections in this push ruleset in the order of priority as
|
||||
// specified in spec section 11.12.1.4.
|
||||
func (rs *PushRuleset) GetActions(room Room, evt *event.Event) (match PushActionArray) {
|
||||
actions := rs.GetMatchingRule(room, evt).GetActions()
|
||||
if actions == nil {
|
||||
// No match found, return default actions.
|
||||
return DefaultPushActions
|
||||
}
|
||||
return actions
|
||||
}
|
||||
|
||||
24
vendor/maunium.net/go/mautrix/requests.go
generated
vendored
24
vendor/maunium.net/go/mautrix/requests.go
generated
vendored
@@ -392,6 +392,10 @@ func (req *ReqHierarchy) Query() map[string]string {
|
||||
return query
|
||||
}
|
||||
|
||||
type ReqAppservicePing struct {
|
||||
TxnID string `json:"transaction_id,omitempty"`
|
||||
}
|
||||
|
||||
type ReqBeeperMergeRoom struct {
|
||||
NewRoom ReqCreateRoom `json:"create"`
|
||||
Key string `json:"key"`
|
||||
@@ -411,3 +415,23 @@ type ReqBeeperSplitRoom struct {
|
||||
Key string `json:"key"`
|
||||
Parts []BeeperSplitRoomPart `json:"parts"`
|
||||
}
|
||||
|
||||
type ReqRoomKeysVersionCreate struct {
|
||||
Algorithm string `json:"algorithm"`
|
||||
AuthData json.RawMessage `json:"auth_data"`
|
||||
}
|
||||
|
||||
type ReqRoomKeysUpdate struct {
|
||||
Rooms map[id.RoomID]ReqRoomKeysRoomUpdate `json:"rooms"`
|
||||
}
|
||||
|
||||
type ReqRoomKeysRoomUpdate struct {
|
||||
Sessions map[id.SessionID]ReqRoomKeysSessionUpdate `json:"sessions"`
|
||||
}
|
||||
|
||||
type ReqRoomKeysSessionUpdate struct {
|
||||
FirstMessageIndex int `json:"first_message_index"`
|
||||
ForwardedCount int `json:"forwarded_count"`
|
||||
IsVerified bool `json:"is_verified"`
|
||||
SessionData json.RawMessage `json:"session_data"`
|
||||
}
|
||||
|
||||
36
vendor/maunium.net/go/mautrix/responses.go
generated
vendored
36
vendor/maunium.net/go/mautrix/responses.go
generated
vendored
@@ -552,6 +552,10 @@ type StrippedStateWithTime struct {
|
||||
Timestamp jsontime.UnixMilli `json:"origin_server_ts"`
|
||||
}
|
||||
|
||||
type RespAppservicePing struct {
|
||||
DurationMS int64 `json:"duration_ms"`
|
||||
}
|
||||
|
||||
type RespBeeperMergeRoom RespCreateRoom
|
||||
|
||||
type RespBeeperSplitRoom struct {
|
||||
@@ -562,3 +566,35 @@ type RespTimestampToEvent struct {
|
||||
EventID id.EventID `json:"event_id"`
|
||||
Timestamp jsontime.UnixMilli `json:"origin_server_ts"`
|
||||
}
|
||||
|
||||
type RespRoomKeysVersionCreate struct {
|
||||
Version string `json:"version"`
|
||||
}
|
||||
|
||||
type RespRoomKeysVersion struct {
|
||||
Algorithm string `json:"algorithm"`
|
||||
AuthData json.RawMessage `json:"auth_data"`
|
||||
Count int `json:"count"`
|
||||
ETag string `json:"etag"`
|
||||
Version string `json:"version"`
|
||||
}
|
||||
|
||||
type RespRoomKeys struct {
|
||||
Rooms map[id.RoomID]RespRoomKeysRoom `json:"rooms"`
|
||||
}
|
||||
|
||||
type RespRoomKeysRoom struct {
|
||||
Sessions map[id.SessionID]RespRoomKeysSession `json:"sessions"`
|
||||
}
|
||||
|
||||
type RespRoomKeysSession struct {
|
||||
FirstMessageIndex int `json:"first_message_index"`
|
||||
ForwardedCount int `json:"forwarded_count"`
|
||||
IsVerified bool `json:"is_verified"`
|
||||
SessionData json.RawMessage `json:"session_data"`
|
||||
}
|
||||
|
||||
type RespRoomKeysUpdate struct {
|
||||
Count int `json:"count"`
|
||||
ETag string `json:"etag"`
|
||||
}
|
||||
|
||||
368
vendor/maunium.net/go/mautrix/sqlstatestore/statestore.go
generated
vendored
Normal file
368
vendor/maunium.net/go/mautrix/sqlstatestore/statestore.go
generated
vendored
Normal file
@@ -0,0 +1,368 @@
|
||||
// Copyright (c) 2022 Tulir Asokan
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||
// License, v. 2.0. If a copy of the MPL was not distributed with this
|
||||
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
|
||||
package sqlstatestore
|
||||
|
||||
import (
|
||||
"database/sql"
|
||||
"embed"
|
||||
"encoding/json"
|
||||
"errors"
|
||||
"fmt"
|
||||
"strconv"
|
||||
"strings"
|
||||
|
||||
"maunium.net/go/mautrix/event"
|
||||
"maunium.net/go/mautrix/id"
|
||||
"maunium.net/go/mautrix/util/dbutil"
|
||||
)
|
||||
|
||||
//go:embed *.sql
|
||||
var rawUpgrades embed.FS
|
||||
|
||||
var UpgradeTable dbutil.UpgradeTable
|
||||
|
||||
func init() {
|
||||
UpgradeTable.RegisterFS(rawUpgrades)
|
||||
}
|
||||
|
||||
const VersionTableName = "mx_version"
|
||||
|
||||
type SQLStateStore struct {
|
||||
*dbutil.Database
|
||||
IsBridge bool
|
||||
}
|
||||
|
||||
func NewSQLStateStore(db *dbutil.Database, log dbutil.DatabaseLogger, isBridge bool) *SQLStateStore {
|
||||
return &SQLStateStore{
|
||||
Database: db.Child(VersionTableName, UpgradeTable, log),
|
||||
IsBridge: isBridge,
|
||||
}
|
||||
}
|
||||
|
||||
func (store *SQLStateStore) IsRegistered(userID id.UserID) bool {
|
||||
var isRegistered bool
|
||||
err := store.
|
||||
QueryRow("SELECT EXISTS(SELECT 1 FROM mx_registrations WHERE user_id=$1)", userID).
|
||||
Scan(&isRegistered)
|
||||
if err != nil {
|
||||
store.Log.Warn("Failed to scan registration existence for %s: %v", userID, err)
|
||||
}
|
||||
return isRegistered
|
||||
}
|
||||
|
||||
func (store *SQLStateStore) MarkRegistered(userID id.UserID) {
|
||||
_, err := store.Exec("INSERT INTO mx_registrations (user_id) VALUES ($1) ON CONFLICT (user_id) DO NOTHING", userID)
|
||||
if err != nil {
|
||||
store.Log.Warn("Failed to mark %s as registered: %v", userID, err)
|
||||
}
|
||||
}
|
||||
|
||||
func (store *SQLStateStore) GetRoomMembers(roomID id.RoomID, memberships ...event.Membership) map[id.UserID]*event.MemberEventContent {
|
||||
members := make(map[id.UserID]*event.MemberEventContent)
|
||||
args := make([]any, len(memberships)+1)
|
||||
args[0] = roomID
|
||||
query := "SELECT user_id, membership, displayname, avatar_url FROM mx_user_profile WHERE room_id=$1"
|
||||
if len(memberships) > 0 {
|
||||
placeholders := make([]string, len(memberships))
|
||||
for i, membership := range memberships {
|
||||
args[i+1] = string(membership)
|
||||
placeholders[i] = fmt.Sprintf("$%d", i+2)
|
||||
}
|
||||
query = fmt.Sprintf("%s AND membership IN (%s)", query, strings.Join(placeholders, ","))
|
||||
}
|
||||
rows, err := store.Query(query, args...)
|
||||
if err != nil {
|
||||
return members
|
||||
}
|
||||
var userID id.UserID
|
||||
var member event.MemberEventContent
|
||||
for rows.Next() {
|
||||
err = rows.Scan(&userID, &member.Membership, &member.Displayname, &member.AvatarURL)
|
||||
if err != nil {
|
||||
store.Log.Warn("Failed to scan member in %s: %v", roomID, err)
|
||||
} else {
|
||||
members[userID] = &member
|
||||
}
|
||||
}
|
||||
return members
|
||||
}
|
||||
|
||||
func (store *SQLStateStore) GetRoomJoinedOrInvitedMembers(roomID id.RoomID) (members []id.UserID, err error) {
|
||||
memberMap := store.GetRoomMembers(roomID, event.MembershipJoin, event.MembershipInvite)
|
||||
members = make([]id.UserID, len(memberMap))
|
||||
i := 0
|
||||
for userID := range memberMap {
|
||||
members[i] = userID
|
||||
i++
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (store *SQLStateStore) GetMembership(roomID id.RoomID, userID id.UserID) event.Membership {
|
||||
membership := event.MembershipLeave
|
||||
err := store.
|
||||
QueryRow("SELECT membership FROM mx_user_profile WHERE room_id=$1 AND user_id=$2", roomID, userID).
|
||||
Scan(&membership)
|
||||
if err != nil && !errors.Is(err, sql.ErrNoRows) {
|
||||
store.Log.Warn("Failed to scan membership of %s in %s: %v", userID, roomID, err)
|
||||
}
|
||||
return membership
|
||||
}
|
||||
|
||||
func (store *SQLStateStore) GetMember(roomID id.RoomID, userID id.UserID) *event.MemberEventContent {
|
||||
member, ok := store.TryGetMember(roomID, userID)
|
||||
if !ok {
|
||||
member.Membership = event.MembershipLeave
|
||||
}
|
||||
return member
|
||||
}
|
||||
|
||||
func (store *SQLStateStore) TryGetMember(roomID id.RoomID, userID id.UserID) (*event.MemberEventContent, bool) {
|
||||
var member event.MemberEventContent
|
||||
err := store.
|
||||
QueryRow("SELECT membership, displayname, avatar_url FROM mx_user_profile WHERE room_id=$1 AND user_id=$2", roomID, userID).
|
||||
Scan(&member.Membership, &member.Displayname, &member.AvatarURL)
|
||||
if err != nil && !errors.Is(err, sql.ErrNoRows) {
|
||||
store.Log.Warn("Failed to scan member info of %s in %s: %v", userID, roomID, err)
|
||||
}
|
||||
return &member, err == nil
|
||||
}
|
||||
|
||||
func (store *SQLStateStore) FindSharedRooms(userID id.UserID) (rooms []id.RoomID) {
|
||||
query := `
|
||||
SELECT room_id FROM mx_user_profile
|
||||
LEFT JOIN portal ON portal.mxid=mx_user_profile.room_id
|
||||
WHERE mx_user_profile.user_id=$1 AND portal.encrypted=true
|
||||
`
|
||||
if !store.IsBridge {
|
||||
query = `
|
||||
SELECT mx_user_profile.room_id FROM mx_user_profile
|
||||
LEFT JOIN mx_room_state ON mx_room_state.room_id=mx_user_profile.room_id
|
||||
WHERE mx_user_profile.user_id=$1 AND mx_room_state.encryption IS NOT NULL
|
||||
`
|
||||
}
|
||||
rows, err := store.Query(query, userID)
|
||||
if err != nil {
|
||||
store.Log.Warn("Failed to query shared rooms with %s: %v", userID, err)
|
||||
return
|
||||
}
|
||||
for rows.Next() {
|
||||
var roomID id.RoomID
|
||||
err = rows.Scan(&roomID)
|
||||
if err != nil {
|
||||
store.Log.Warn("Failed to scan room ID: %v", err)
|
||||
} else {
|
||||
rooms = append(rooms, roomID)
|
||||
}
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (store *SQLStateStore) IsInRoom(roomID id.RoomID, userID id.UserID) bool {
|
||||
return store.IsMembership(roomID, userID, "join")
|
||||
}
|
||||
|
||||
func (store *SQLStateStore) IsInvited(roomID id.RoomID, userID id.UserID) bool {
|
||||
return store.IsMembership(roomID, userID, "join", "invite")
|
||||
}
|
||||
|
||||
func (store *SQLStateStore) IsMembership(roomID id.RoomID, userID id.UserID, allowedMemberships ...event.Membership) bool {
|
||||
membership := store.GetMembership(roomID, userID)
|
||||
for _, allowedMembership := range allowedMemberships {
|
||||
if allowedMembership == membership {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (store *SQLStateStore) SetMembership(roomID id.RoomID, userID id.UserID, membership event.Membership) {
|
||||
_, err := store.Exec(`
|
||||
INSERT INTO mx_user_profile (room_id, user_id, membership, displayname, avatar_url) VALUES ($1, $2, $3, '', '')
|
||||
ON CONFLICT (room_id, user_id) DO UPDATE SET membership=excluded.membership
|
||||
`, roomID, userID, membership)
|
||||
if err != nil {
|
||||
store.Log.Warn("Failed to set membership of %s in %s to %s: %v", userID, roomID, membership, err)
|
||||
}
|
||||
}
|
||||
|
||||
func (store *SQLStateStore) SetMember(roomID id.RoomID, userID id.UserID, member *event.MemberEventContent) {
|
||||
_, err := store.Exec(`
|
||||
INSERT INTO mx_user_profile (room_id, user_id, membership, displayname, avatar_url) VALUES ($1, $2, $3, $4, $5)
|
||||
ON CONFLICT (room_id, user_id) DO UPDATE SET membership=excluded.membership, displayname=excluded.displayname, avatar_url=excluded.avatar_url
|
||||
`, roomID, userID, member.Membership, member.Displayname, member.AvatarURL)
|
||||
if err != nil {
|
||||
store.Log.Warn("Failed to set membership of %s in %s to %s: %v", userID, roomID, member, err)
|
||||
}
|
||||
}
|
||||
|
||||
func (store *SQLStateStore) ClearCachedMembers(roomID id.RoomID, memberships ...event.Membership) {
|
||||
query := "DELETE FROM mx_user_profile WHERE room_id=$1"
|
||||
params := make([]any, len(memberships)+1)
|
||||
params[0] = roomID
|
||||
if len(memberships) > 0 {
|
||||
placeholders := make([]string, len(memberships))
|
||||
for i, membership := range memberships {
|
||||
placeholders[i] = "$" + strconv.Itoa(i+2)
|
||||
params[i+1] = string(membership)
|
||||
}
|
||||
query += fmt.Sprintf(" AND membership IN (%s)", strings.Join(placeholders, ","))
|
||||
}
|
||||
_, err := store.Exec(query, params...)
|
||||
if err != nil {
|
||||
store.Log.Warn("Failed to clear cached members of %s: %v", roomID, err)
|
||||
}
|
||||
}
|
||||
|
||||
func (store *SQLStateStore) SetEncryptionEvent(roomID id.RoomID, content *event.EncryptionEventContent) {
|
||||
contentBytes, err := json.Marshal(content)
|
||||
if err != nil {
|
||||
store.Log.Warn("Failed to marshal encryption config of %s: %v", roomID, err)
|
||||
return
|
||||
}
|
||||
_, err = store.Exec(`
|
||||
INSERT INTO mx_room_state (room_id, encryption) VALUES ($1, $2)
|
||||
ON CONFLICT (room_id) DO UPDATE SET encryption=excluded.encryption
|
||||
`, roomID, contentBytes)
|
||||
if err != nil {
|
||||
store.Log.Warn("Failed to store encryption config of %s: %v", roomID, err)
|
||||
}
|
||||
}
|
||||
|
||||
func (store *SQLStateStore) GetEncryptionEvent(roomID id.RoomID) *event.EncryptionEventContent {
|
||||
var data []byte
|
||||
err := store.
|
||||
QueryRow("SELECT encryption FROM mx_room_state WHERE room_id=$1", roomID).
|
||||
Scan(&data)
|
||||
if err != nil {
|
||||
if !errors.Is(err, sql.ErrNoRows) {
|
||||
store.Log.Warn("Failed to scan encryption config of %s: %v", roomID, err)
|
||||
}
|
||||
return nil
|
||||
} else if data == nil {
|
||||
return nil
|
||||
}
|
||||
content := &event.EncryptionEventContent{}
|
||||
err = json.Unmarshal(data, content)
|
||||
if err != nil {
|
||||
store.Log.Warn("Failed to parse encryption config of %s: %v", roomID, err)
|
||||
return nil
|
||||
}
|
||||
return content
|
||||
}
|
||||
|
||||
func (store *SQLStateStore) IsEncrypted(roomID id.RoomID) bool {
|
||||
cfg := store.GetEncryptionEvent(roomID)
|
||||
return cfg != nil && cfg.Algorithm == id.AlgorithmMegolmV1
|
||||
}
|
||||
|
||||
func (store *SQLStateStore) SetPowerLevels(roomID id.RoomID, levels *event.PowerLevelsEventContent) {
|
||||
levelsBytes, err := json.Marshal(levels)
|
||||
if err != nil {
|
||||
store.Log.Warn("Failed to marshal power levels of %s: %v", roomID, err)
|
||||
return
|
||||
}
|
||||
_, err = store.Exec(`
|
||||
INSERT INTO mx_room_state (room_id, power_levels) VALUES ($1, $2)
|
||||
ON CONFLICT (room_id) DO UPDATE SET power_levels=excluded.power_levels
|
||||
`, roomID, levelsBytes)
|
||||
if err != nil {
|
||||
store.Log.Warn("Failed to store power levels of %s: %v", roomID, err)
|
||||
}
|
||||
}
|
||||
|
||||
func (store *SQLStateStore) GetPowerLevels(roomID id.RoomID) (levels *event.PowerLevelsEventContent) {
|
||||
var data []byte
|
||||
err := store.
|
||||
QueryRow("SELECT power_levels FROM mx_room_state WHERE room_id=$1", roomID).
|
||||
Scan(&data)
|
||||
if err != nil {
|
||||
if !errors.Is(err, sql.ErrNoRows) {
|
||||
store.Log.Warn("Failed to scan power levels of %s: %v", roomID, err)
|
||||
}
|
||||
return
|
||||
} else if data == nil {
|
||||
return
|
||||
}
|
||||
levels = &event.PowerLevelsEventContent{}
|
||||
err = json.Unmarshal(data, levels)
|
||||
if err != nil {
|
||||
store.Log.Warn("Failed to parse power levels of %s: %v", roomID, err)
|
||||
return nil
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (store *SQLStateStore) GetPowerLevel(roomID id.RoomID, userID id.UserID) int {
|
||||
if store.Dialect == dbutil.Postgres {
|
||||
var powerLevel int
|
||||
err := store.
|
||||
QueryRow(`
|
||||
SELECT COALESCE((power_levels->'users'->$2)::int, (power_levels->'users_default')::int, 0)
|
||||
FROM mx_room_state WHERE room_id=$1
|
||||
`, roomID, userID).
|
||||
Scan(&powerLevel)
|
||||
if err != nil && !errors.Is(err, sql.ErrNoRows) {
|
||||
store.Log.Warn("Failed to scan power level of %s in %s: %v", userID, roomID, err)
|
||||
}
|
||||
return powerLevel
|
||||
}
|
||||
return store.GetPowerLevels(roomID).GetUserLevel(userID)
|
||||
}
|
||||
|
||||
func (store *SQLStateStore) GetPowerLevelRequirement(roomID id.RoomID, eventType event.Type) int {
|
||||
if store.Dialect == dbutil.Postgres {
|
||||
defaultType := "events_default"
|
||||
defaultValue := 0
|
||||
if eventType.IsState() {
|
||||
defaultType = "state_default"
|
||||
defaultValue = 50
|
||||
}
|
||||
var powerLevel int
|
||||
err := store.
|
||||
QueryRow(`
|
||||
SELECT COALESCE((power_levels->'events'->$2)::int, (power_levels->'$3')::int, $4)
|
||||
FROM mx_room_state WHERE room_id=$1
|
||||
`, roomID, eventType.Type, defaultType, defaultValue).
|
||||
Scan(&powerLevel)
|
||||
if err != nil {
|
||||
if !errors.Is(err, sql.ErrNoRows) {
|
||||
store.Log.Warn("Failed to scan power level for %s in %s: %v", eventType, roomID, err)
|
||||
}
|
||||
return defaultValue
|
||||
}
|
||||
return powerLevel
|
||||
}
|
||||
return store.GetPowerLevels(roomID).GetEventLevel(eventType)
|
||||
}
|
||||
|
||||
func (store *SQLStateStore) HasPowerLevel(roomID id.RoomID, userID id.UserID, eventType event.Type) bool {
|
||||
if store.Dialect == dbutil.Postgres {
|
||||
defaultType := "events_default"
|
||||
defaultValue := 0
|
||||
if eventType.IsState() {
|
||||
defaultType = "state_default"
|
||||
defaultValue = 50
|
||||
}
|
||||
var hasPower bool
|
||||
err := store.
|
||||
QueryRow(`SELECT
|
||||
COALESCE((power_levels->'users'->$2)::int, (power_levels->'users_default')::int, 0)
|
||||
>=
|
||||
COALESCE((power_levels->'events'->$3)::int, (power_levels->'$4')::int, $5)
|
||||
FROM mx_room_state WHERE room_id=$1`, roomID, userID, eventType.Type, defaultType, defaultValue).
|
||||
Scan(&hasPower)
|
||||
if err != nil {
|
||||
if !errors.Is(err, sql.ErrNoRows) {
|
||||
store.Log.Warn("Failed to scan power level for %s in %s: %v", eventType, roomID, err)
|
||||
}
|
||||
return defaultValue == 0
|
||||
}
|
||||
return hasPower
|
||||
}
|
||||
return store.GetPowerLevel(roomID, userID) >= store.GetPowerLevelRequirement(roomID, eventType)
|
||||
}
|
||||
23
vendor/maunium.net/go/mautrix/sqlstatestore/v00-latest-revision.sql
generated
vendored
Normal file
23
vendor/maunium.net/go/mautrix/sqlstatestore/v00-latest-revision.sql
generated
vendored
Normal file
@@ -0,0 +1,23 @@
|
||||
-- v0 -> v5: Latest revision
|
||||
|
||||
CREATE TABLE mx_registrations (
|
||||
user_id TEXT PRIMARY KEY
|
||||
);
|
||||
|
||||
-- only: postgres
|
||||
CREATE TYPE membership AS ENUM ('join', 'leave', 'invite', 'ban', 'knock');
|
||||
|
||||
CREATE TABLE mx_user_profile (
|
||||
room_id TEXT,
|
||||
user_id TEXT,
|
||||
membership membership NOT NULL,
|
||||
displayname TEXT NOT NULL DEFAULT '',
|
||||
avatar_url TEXT NOT NULL DEFAULT '',
|
||||
PRIMARY KEY (room_id, user_id)
|
||||
);
|
||||
|
||||
CREATE TABLE mx_room_state (
|
||||
room_id TEXT PRIMARY KEY,
|
||||
power_levels jsonb,
|
||||
encryption jsonb
|
||||
);
|
||||
6
vendor/maunium.net/go/mautrix/sqlstatestore/v02-membership-enum.sql
generated
vendored
Normal file
6
vendor/maunium.net/go/mautrix/sqlstatestore/v02-membership-enum.sql
generated
vendored
Normal file
@@ -0,0 +1,6 @@
|
||||
-- v2: Use enum for membership field on Postgres
|
||||
-- only: postgres
|
||||
|
||||
CREATE TYPE membership AS ENUM ('join', 'leave', 'invite', 'ban', 'knock');
|
||||
UPDATE mx_user_profile SET membership='leave' WHERE LOWER(membership) NOT IN ('join', 'leave', 'invite', 'ban', 'knock');
|
||||
ALTER TABLE mx_user_profile ALTER COLUMN membership TYPE membership USING LOWER(membership)::membership;
|
||||
10
vendor/maunium.net/go/mautrix/sqlstatestore/v03-no-null.sql
generated
vendored
Normal file
10
vendor/maunium.net/go/mautrix/sqlstatestore/v03-no-null.sql
generated
vendored
Normal file
@@ -0,0 +1,10 @@
|
||||
-- v3: Disable nulls in mx_user_profile
|
||||
|
||||
UPDATE mx_user_profile SET displayname='' WHERE displayname IS NULL;
|
||||
UPDATE mx_user_profile SET avatar_url='' WHERE avatar_url IS NULL;
|
||||
|
||||
-- only: postgres for next 4 lines
|
||||
ALTER TABLE mx_user_profile ALTER COLUMN displayname SET DEFAULT '';
|
||||
ALTER TABLE mx_user_profile ALTER COLUMN displayname SET NOT NULL;
|
||||
ALTER TABLE mx_user_profile ALTER COLUMN avatar_url SET DEFAULT '';
|
||||
ALTER TABLE mx_user_profile ALTER COLUMN avatar_url SET NOT NULL;
|
||||
2
vendor/maunium.net/go/mautrix/sqlstatestore/v04-encryption-info.sql
generated
vendored
Normal file
2
vendor/maunium.net/go/mautrix/sqlstatestore/v04-encryption-info.sql
generated
vendored
Normal file
@@ -0,0 +1,2 @@
|
||||
-- v4: Store room encryption configuration
|
||||
ALTER TABLE mx_room_state ADD COLUMN encryption jsonb;
|
||||
29
vendor/maunium.net/go/mautrix/sqlstatestore/v05-mark-encryption-state-resync.go
generated
vendored
Normal file
29
vendor/maunium.net/go/mautrix/sqlstatestore/v05-mark-encryption-state-resync.go
generated
vendored
Normal file
@@ -0,0 +1,29 @@
|
||||
package sqlstatestore
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
|
||||
"maunium.net/go/mautrix/util/dbutil"
|
||||
)
|
||||
|
||||
func init() {
|
||||
UpgradeTable.Register(-1, 5, 0, "Mark rooms that need crypto state event resynced", true, func(tx dbutil.Execable, db *dbutil.Database) error {
|
||||
portalExists, err := db.TableExists(tx, "portal")
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to check if portal table exists")
|
||||
}
|
||||
if portalExists {
|
||||
_, err = tx.Exec(`
|
||||
INSERT INTO mx_room_state (room_id, encryption)
|
||||
SELECT portal.mxid, '{"resync":true}' FROM portal WHERE portal.encrypted=true AND portal.mxid IS NOT NULL
|
||||
ON CONFLICT (room_id) DO UPDATE
|
||||
SET encryption=excluded.encryption
|
||||
WHERE mx_room_state.encryption IS NULL
|
||||
`)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
})
|
||||
}
|
||||
251
vendor/maunium.net/go/mautrix/statestore.go
generated
vendored
Normal file
251
vendor/maunium.net/go/mautrix/statestore.go
generated
vendored
Normal file
@@ -0,0 +1,251 @@
|
||||
// Copyright (c) 2023 Tulir Asokan
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||
// License, v. 2.0. If a copy of the MPL was not distributed with this
|
||||
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
|
||||
package mautrix
|
||||
|
||||
import (
|
||||
"sync"
|
||||
|
||||
"maunium.net/go/mautrix/event"
|
||||
"maunium.net/go/mautrix/id"
|
||||
)
|
||||
|
||||
// StateStore is an interface for storing basic room state information.
|
||||
type StateStore interface {
|
||||
IsInRoom(roomID id.RoomID, userID id.UserID) bool
|
||||
IsInvited(roomID id.RoomID, userID id.UserID) bool
|
||||
IsMembership(roomID id.RoomID, userID id.UserID, allowedMemberships ...event.Membership) bool
|
||||
GetMember(roomID id.RoomID, userID id.UserID) *event.MemberEventContent
|
||||
TryGetMember(roomID id.RoomID, userID id.UserID) (*event.MemberEventContent, bool)
|
||||
SetMembership(roomID id.RoomID, userID id.UserID, membership event.Membership)
|
||||
SetMember(roomID id.RoomID, userID id.UserID, member *event.MemberEventContent)
|
||||
ClearCachedMembers(roomID id.RoomID, memberships ...event.Membership)
|
||||
|
||||
SetPowerLevels(roomID id.RoomID, levels *event.PowerLevelsEventContent)
|
||||
GetPowerLevels(roomID id.RoomID) *event.PowerLevelsEventContent
|
||||
|
||||
SetEncryptionEvent(roomID id.RoomID, content *event.EncryptionEventContent)
|
||||
IsEncrypted(roomID id.RoomID) bool
|
||||
|
||||
GetRoomJoinedOrInvitedMembers(roomID id.RoomID) ([]id.UserID, error)
|
||||
}
|
||||
|
||||
func UpdateStateStore(store StateStore, evt *event.Event) {
|
||||
if store == nil || evt == nil || evt.StateKey == nil {
|
||||
return
|
||||
}
|
||||
// We only care about events without a state key (power levels, encryption) or member events with state key
|
||||
if evt.Type != event.StateMember && evt.GetStateKey() != "" {
|
||||
return
|
||||
}
|
||||
switch content := evt.Content.Parsed.(type) {
|
||||
case *event.MemberEventContent:
|
||||
store.SetMember(evt.RoomID, id.UserID(evt.GetStateKey()), content)
|
||||
case *event.PowerLevelsEventContent:
|
||||
store.SetPowerLevels(evt.RoomID, content)
|
||||
case *event.EncryptionEventContent:
|
||||
store.SetEncryptionEvent(evt.RoomID, content)
|
||||
}
|
||||
}
|
||||
|
||||
// StateStoreSyncHandler can be added as an event handler in the syncer to update the state store automatically.
|
||||
//
|
||||
// client.Syncer.(mautrix.ExtensibleSyncer).OnEvent(client.StateStoreSyncHandler)
|
||||
//
|
||||
// DefaultSyncer.ParseEventContent must also be true for this to work (which it is by default).
|
||||
func (cli *Client) StateStoreSyncHandler(_ EventSource, evt *event.Event) {
|
||||
UpdateStateStore(cli.StateStore, evt)
|
||||
}
|
||||
|
||||
type MemoryStateStore struct {
|
||||
Registrations map[id.UserID]bool `json:"registrations"`
|
||||
Members map[id.RoomID]map[id.UserID]*event.MemberEventContent `json:"memberships"`
|
||||
PowerLevels map[id.RoomID]*event.PowerLevelsEventContent `json:"power_levels"`
|
||||
Encryption map[id.RoomID]*event.EncryptionEventContent `json:"encryption"`
|
||||
|
||||
registrationsLock sync.RWMutex
|
||||
membersLock sync.RWMutex
|
||||
powerLevelsLock sync.RWMutex
|
||||
encryptionLock sync.RWMutex
|
||||
}
|
||||
|
||||
func NewMemoryStateStore() StateStore {
|
||||
return &MemoryStateStore{
|
||||
Registrations: make(map[id.UserID]bool),
|
||||
Members: make(map[id.RoomID]map[id.UserID]*event.MemberEventContent),
|
||||
PowerLevels: make(map[id.RoomID]*event.PowerLevelsEventContent),
|
||||
Encryption: make(map[id.RoomID]*event.EncryptionEventContent),
|
||||
}
|
||||
}
|
||||
|
||||
func (store *MemoryStateStore) IsRegistered(userID id.UserID) bool {
|
||||
store.registrationsLock.RLock()
|
||||
defer store.registrationsLock.RUnlock()
|
||||
registered, ok := store.Registrations[userID]
|
||||
return ok && registered
|
||||
}
|
||||
|
||||
func (store *MemoryStateStore) MarkRegistered(userID id.UserID) {
|
||||
store.registrationsLock.Lock()
|
||||
defer store.registrationsLock.Unlock()
|
||||
store.Registrations[userID] = true
|
||||
}
|
||||
|
||||
func (store *MemoryStateStore) GetRoomMembers(roomID id.RoomID) map[id.UserID]*event.MemberEventContent {
|
||||
store.membersLock.RLock()
|
||||
members, ok := store.Members[roomID]
|
||||
store.membersLock.RUnlock()
|
||||
if !ok {
|
||||
members = make(map[id.UserID]*event.MemberEventContent)
|
||||
store.membersLock.Lock()
|
||||
store.Members[roomID] = members
|
||||
store.membersLock.Unlock()
|
||||
}
|
||||
return members
|
||||
}
|
||||
|
||||
func (store *MemoryStateStore) GetRoomJoinedOrInvitedMembers(roomID id.RoomID) ([]id.UserID, error) {
|
||||
members := store.GetRoomMembers(roomID)
|
||||
ids := make([]id.UserID, 0, len(members))
|
||||
for id := range members {
|
||||
ids = append(ids, id)
|
||||
}
|
||||
return ids, nil
|
||||
}
|
||||
|
||||
func (store *MemoryStateStore) GetMembership(roomID id.RoomID, userID id.UserID) event.Membership {
|
||||
return store.GetMember(roomID, userID).Membership
|
||||
}
|
||||
|
||||
func (store *MemoryStateStore) GetMember(roomID id.RoomID, userID id.UserID) *event.MemberEventContent {
|
||||
member, ok := store.TryGetMember(roomID, userID)
|
||||
if !ok {
|
||||
member = &event.MemberEventContent{Membership: event.MembershipLeave}
|
||||
}
|
||||
return member
|
||||
}
|
||||
|
||||
func (store *MemoryStateStore) TryGetMember(roomID id.RoomID, userID id.UserID) (member *event.MemberEventContent, ok bool) {
|
||||
store.membersLock.RLock()
|
||||
defer store.membersLock.RUnlock()
|
||||
members, membersOk := store.Members[roomID]
|
||||
if !membersOk {
|
||||
return
|
||||
}
|
||||
member, ok = members[userID]
|
||||
return
|
||||
}
|
||||
|
||||
func (store *MemoryStateStore) IsInRoom(roomID id.RoomID, userID id.UserID) bool {
|
||||
return store.IsMembership(roomID, userID, "join")
|
||||
}
|
||||
|
||||
func (store *MemoryStateStore) IsInvited(roomID id.RoomID, userID id.UserID) bool {
|
||||
return store.IsMembership(roomID, userID, "join", "invite")
|
||||
}
|
||||
|
||||
func (store *MemoryStateStore) IsMembership(roomID id.RoomID, userID id.UserID, allowedMemberships ...event.Membership) bool {
|
||||
membership := store.GetMembership(roomID, userID)
|
||||
for _, allowedMembership := range allowedMemberships {
|
||||
if allowedMembership == membership {
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (store *MemoryStateStore) SetMembership(roomID id.RoomID, userID id.UserID, membership event.Membership) {
|
||||
store.membersLock.Lock()
|
||||
members, ok := store.Members[roomID]
|
||||
if !ok {
|
||||
members = map[id.UserID]*event.MemberEventContent{
|
||||
userID: {Membership: membership},
|
||||
}
|
||||
} else {
|
||||
member, ok := members[userID]
|
||||
if !ok {
|
||||
members[userID] = &event.MemberEventContent{Membership: membership}
|
||||
} else {
|
||||
member.Membership = membership
|
||||
members[userID] = member
|
||||
}
|
||||
}
|
||||
store.Members[roomID] = members
|
||||
store.membersLock.Unlock()
|
||||
}
|
||||
|
||||
func (store *MemoryStateStore) SetMember(roomID id.RoomID, userID id.UserID, member *event.MemberEventContent) {
|
||||
store.membersLock.Lock()
|
||||
members, ok := store.Members[roomID]
|
||||
if !ok {
|
||||
members = map[id.UserID]*event.MemberEventContent{
|
||||
userID: member,
|
||||
}
|
||||
} else {
|
||||
members[userID] = member
|
||||
}
|
||||
store.Members[roomID] = members
|
||||
store.membersLock.Unlock()
|
||||
}
|
||||
|
||||
func (store *MemoryStateStore) ClearCachedMembers(roomID id.RoomID, memberships ...event.Membership) {
|
||||
store.membersLock.Lock()
|
||||
defer store.membersLock.Unlock()
|
||||
members, ok := store.Members[roomID]
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
for userID, member := range members {
|
||||
for _, membership := range memberships {
|
||||
if membership == member.Membership {
|
||||
delete(members, userID)
|
||||
break
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (store *MemoryStateStore) SetPowerLevels(roomID id.RoomID, levels *event.PowerLevelsEventContent) {
|
||||
store.powerLevelsLock.Lock()
|
||||
store.PowerLevels[roomID] = levels
|
||||
store.powerLevelsLock.Unlock()
|
||||
}
|
||||
|
||||
func (store *MemoryStateStore) GetPowerLevels(roomID id.RoomID) (levels *event.PowerLevelsEventContent) {
|
||||
store.powerLevelsLock.RLock()
|
||||
levels = store.PowerLevels[roomID]
|
||||
store.powerLevelsLock.RUnlock()
|
||||
return
|
||||
}
|
||||
|
||||
func (store *MemoryStateStore) GetPowerLevel(roomID id.RoomID, userID id.UserID) int {
|
||||
return store.GetPowerLevels(roomID).GetUserLevel(userID)
|
||||
}
|
||||
|
||||
func (store *MemoryStateStore) GetPowerLevelRequirement(roomID id.RoomID, eventType event.Type) int {
|
||||
return store.GetPowerLevels(roomID).GetEventLevel(eventType)
|
||||
}
|
||||
|
||||
func (store *MemoryStateStore) HasPowerLevel(roomID id.RoomID, userID id.UserID, eventType event.Type) bool {
|
||||
return store.GetPowerLevel(roomID, userID) >= store.GetPowerLevelRequirement(roomID, eventType)
|
||||
}
|
||||
|
||||
func (store *MemoryStateStore) SetEncryptionEvent(roomID id.RoomID, content *event.EncryptionEventContent) {
|
||||
store.encryptionLock.Lock()
|
||||
store.Encryption[roomID] = content
|
||||
store.encryptionLock.Unlock()
|
||||
}
|
||||
|
||||
func (store *MemoryStateStore) GetEncryptionEvent(roomID id.RoomID) *event.EncryptionEventContent {
|
||||
store.encryptionLock.RLock()
|
||||
defer store.encryptionLock.RUnlock()
|
||||
return store.Encryption[roomID]
|
||||
}
|
||||
|
||||
func (store *MemoryStateStore) IsEncrypted(roomID id.RoomID) bool {
|
||||
cfg := store.GetEncryptionEvent(roomID)
|
||||
return cfg != nil && cfg.Algorithm == id.AlgorithmMegolmV1
|
||||
}
|
||||
161
vendor/maunium.net/go/mautrix/store.go
generated
vendored
161
vendor/maunium.net/go/mautrix/store.go
generated
vendored
@@ -1,161 +0,0 @@
|
||||
package mautrix
|
||||
|
||||
import (
|
||||
"maunium.net/go/mautrix/event"
|
||||
"maunium.net/go/mautrix/id"
|
||||
)
|
||||
|
||||
// Storer is an interface which must be satisfied to store client data.
|
||||
//
|
||||
// You can either write a struct which persists this data to disk, or you can use the
|
||||
// provided "InMemoryStore" which just keeps data around in-memory which is lost on
|
||||
// restarts.
|
||||
type Storer interface {
|
||||
SaveFilterID(userID id.UserID, filterID string)
|
||||
LoadFilterID(userID id.UserID) string
|
||||
SaveNextBatch(userID id.UserID, nextBatchToken string)
|
||||
LoadNextBatch(userID id.UserID) string
|
||||
SaveRoom(room *Room)
|
||||
LoadRoom(roomID id.RoomID) *Room
|
||||
}
|
||||
|
||||
// InMemoryStore implements the Storer interface.
|
||||
//
|
||||
// Everything is persisted in-memory as maps. It is not safe to load/save filter IDs
|
||||
// or next batch tokens on any goroutine other than the syncing goroutine: the one
|
||||
// which called Client.Sync().
|
||||
type InMemoryStore struct {
|
||||
Filters map[id.UserID]string
|
||||
NextBatch map[id.UserID]string
|
||||
Rooms map[id.RoomID]*Room
|
||||
}
|
||||
|
||||
// SaveFilterID to memory.
|
||||
func (s *InMemoryStore) SaveFilterID(userID id.UserID, filterID string) {
|
||||
s.Filters[userID] = filterID
|
||||
}
|
||||
|
||||
// LoadFilterID from memory.
|
||||
func (s *InMemoryStore) LoadFilterID(userID id.UserID) string {
|
||||
return s.Filters[userID]
|
||||
}
|
||||
|
||||
// SaveNextBatch to memory.
|
||||
func (s *InMemoryStore) SaveNextBatch(userID id.UserID, nextBatchToken string) {
|
||||
s.NextBatch[userID] = nextBatchToken
|
||||
}
|
||||
|
||||
// LoadNextBatch from memory.
|
||||
func (s *InMemoryStore) LoadNextBatch(userID id.UserID) string {
|
||||
return s.NextBatch[userID]
|
||||
}
|
||||
|
||||
// SaveRoom to memory.
|
||||
func (s *InMemoryStore) SaveRoom(room *Room) {
|
||||
s.Rooms[room.ID] = room
|
||||
}
|
||||
|
||||
// LoadRoom from memory.
|
||||
func (s *InMemoryStore) LoadRoom(roomID id.RoomID) *Room {
|
||||
return s.Rooms[roomID]
|
||||
}
|
||||
|
||||
// UpdateState stores a state event. This can be passed to DefaultSyncer.OnEvent to keep all room state cached.
|
||||
func (s *InMemoryStore) UpdateState(_ EventSource, evt *event.Event) {
|
||||
if !evt.Type.IsState() {
|
||||
return
|
||||
}
|
||||
room := s.LoadRoom(evt.RoomID)
|
||||
if room == nil {
|
||||
room = NewRoom(evt.RoomID)
|
||||
s.SaveRoom(room)
|
||||
}
|
||||
room.UpdateState(evt)
|
||||
}
|
||||
|
||||
// NewInMemoryStore constructs a new InMemoryStore.
|
||||
func NewInMemoryStore() *InMemoryStore {
|
||||
return &InMemoryStore{
|
||||
Filters: make(map[id.UserID]string),
|
||||
NextBatch: make(map[id.UserID]string),
|
||||
Rooms: make(map[id.RoomID]*Room),
|
||||
}
|
||||
}
|
||||
|
||||
// AccountDataStore uses account data to store the next batch token, and
|
||||
// reuses the InMemoryStore for all other operations.
|
||||
type AccountDataStore struct {
|
||||
*InMemoryStore
|
||||
eventType string
|
||||
client *Client
|
||||
}
|
||||
|
||||
type accountData struct {
|
||||
NextBatch string `json:"next_batch"`
|
||||
}
|
||||
|
||||
// SaveNextBatch to account data.
|
||||
func (s *AccountDataStore) SaveNextBatch(userID id.UserID, nextBatchToken string) {
|
||||
if userID.String() != s.client.UserID.String() {
|
||||
panic("AccountDataStore must only be used with bots")
|
||||
}
|
||||
|
||||
data := accountData{
|
||||
NextBatch: nextBatchToken,
|
||||
}
|
||||
|
||||
err := s.client.SetAccountData(s.eventType, data)
|
||||
if err != nil {
|
||||
if s.client.Logger != nil {
|
||||
s.client.Logger.Debugfln("failed to save next batch token to account data: %s", err.Error())
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// LoadNextBatch from account data.
|
||||
func (s *AccountDataStore) LoadNextBatch(userID id.UserID) string {
|
||||
if userID.String() != s.client.UserID.String() {
|
||||
panic("AccountDataStore must only be used with bots")
|
||||
}
|
||||
|
||||
data := &accountData{}
|
||||
|
||||
err := s.client.GetAccountData(s.eventType, data)
|
||||
if err != nil {
|
||||
if s.client.Logger != nil {
|
||||
s.client.Logger.Debugfln("failed to load next batch token to account data: %s", err.Error())
|
||||
}
|
||||
return ""
|
||||
}
|
||||
|
||||
return data.NextBatch
|
||||
}
|
||||
|
||||
// NewAccountDataStore returns a new AccountDataStore, which stores
|
||||
// the next_batch token as a custom event in account data in the
|
||||
// homeserver.
|
||||
//
|
||||
// AccountDataStore is only appropriate for bots, not appservices.
|
||||
//
|
||||
// eventType should be a reversed DNS name like tld.domain.sub.internal and
|
||||
// must be unique for a client. The data stored in it is considered internal
|
||||
// and must not be modified through outside means. You should also add a filter
|
||||
// for account data changes of this event type, to avoid ending up in a sync
|
||||
// loop:
|
||||
//
|
||||
// mautrix.Filter{
|
||||
// AccountData: mautrix.FilterPart{
|
||||
// Limit: 20,
|
||||
// NotTypes: []event.Type{
|
||||
// event.NewEventType(eventType),
|
||||
// },
|
||||
// },
|
||||
// }
|
||||
// mautrix.Client.CreateFilter(...)
|
||||
func NewAccountDataStore(eventType string, client *Client) *AccountDataStore {
|
||||
return &AccountDataStore{
|
||||
InMemoryStore: NewInMemoryStore(),
|
||||
eventType: eventType,
|
||||
client: client,
|
||||
}
|
||||
}
|
||||
118
vendor/maunium.net/go/mautrix/sync.go
generated
vendored
118
vendor/maunium.net/go/mautrix/sync.go
generated
vendored
@@ -7,6 +7,7 @@
|
||||
package mautrix
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"fmt"
|
||||
"runtime/debug"
|
||||
"time"
|
||||
@@ -28,44 +29,52 @@ const (
|
||||
EventSourceState
|
||||
EventSourceEphemeral
|
||||
EventSourceToDevice
|
||||
EventSourceDecrypted
|
||||
)
|
||||
|
||||
const primaryTypes = EventSourcePresence | EventSourceAccountData | EventSourceToDevice | EventSourceTimeline | EventSourceState
|
||||
const roomSections = EventSourceJoin | EventSourceInvite | EventSourceLeave
|
||||
const roomableTypes = EventSourceAccountData | EventSourceTimeline | EventSourceState
|
||||
const encryptableTypes = roomableTypes | EventSourceToDevice
|
||||
|
||||
func (es EventSource) String() string {
|
||||
switch {
|
||||
case es == EventSourcePresence:
|
||||
return "presence"
|
||||
case es == EventSourceAccountData:
|
||||
return "user account data"
|
||||
case es == EventSourceToDevice:
|
||||
return "to-device"
|
||||
case es&EventSourceJoin != 0:
|
||||
es -= EventSourceJoin
|
||||
switch es {
|
||||
case EventSourceState:
|
||||
return "joined state"
|
||||
case EventSourceTimeline:
|
||||
return "joined timeline"
|
||||
case EventSourceEphemeral:
|
||||
return "room ephemeral (joined)"
|
||||
case EventSourceAccountData:
|
||||
return "room account data (joined)"
|
||||
}
|
||||
case es&EventSourceInvite != 0:
|
||||
es -= EventSourceInvite
|
||||
switch es {
|
||||
case EventSourceState:
|
||||
return "invited state"
|
||||
}
|
||||
case es&EventSourceLeave != 0:
|
||||
es -= EventSourceLeave
|
||||
switch es {
|
||||
case EventSourceState:
|
||||
return "left state"
|
||||
case EventSourceTimeline:
|
||||
return "left timeline"
|
||||
}
|
||||
var typeName string
|
||||
switch es & primaryTypes {
|
||||
case EventSourcePresence:
|
||||
typeName = "presence"
|
||||
case EventSourceAccountData:
|
||||
typeName = "account data"
|
||||
case EventSourceToDevice:
|
||||
typeName = "to-device"
|
||||
case EventSourceTimeline:
|
||||
typeName = "timeline"
|
||||
case EventSourceState:
|
||||
typeName = "state"
|
||||
default:
|
||||
return fmt.Sprintf("unknown (%d)", es)
|
||||
}
|
||||
return fmt.Sprintf("unknown (%d)", es)
|
||||
if es&roomableTypes != 0 {
|
||||
switch es & roomSections {
|
||||
case EventSourceJoin:
|
||||
typeName = "joined room " + typeName
|
||||
case EventSourceInvite:
|
||||
typeName = "invited room " + typeName
|
||||
case EventSourceLeave:
|
||||
typeName = "left room " + typeName
|
||||
default:
|
||||
return fmt.Sprintf("unknown (%d)", es)
|
||||
}
|
||||
es &^= roomableTypes
|
||||
}
|
||||
if es&encryptableTypes != 0 && es&EventSourceDecrypted != 0 {
|
||||
typeName += " (decrypted)"
|
||||
es &^= EventSourceDecrypted
|
||||
}
|
||||
es &^= primaryTypes
|
||||
if es != 0 {
|
||||
return fmt.Sprintf("unknown (%d)", es)
|
||||
}
|
||||
return typeName
|
||||
}
|
||||
|
||||
// EventHandler handles a single event from a sync response.
|
||||
@@ -76,9 +85,8 @@ type SyncHandler func(resp *RespSync, since string) bool
|
||||
|
||||
// Syncer is an interface that must be satisfied in order to do /sync requests on a client.
|
||||
type Syncer interface {
|
||||
// Process the /sync response. The since parameter is the since= value that was used to produce the response.
|
||||
// This is useful for detecting the very first sync (since=""). If an error is return, Syncing will be stopped
|
||||
// permanently.
|
||||
// ProcessResponse processes the /sync response. The since parameter is the since= value that was used to produce the response.
|
||||
// This is useful for detecting the very first sync (since=""). If an error is return, Syncing will be stopped permanently.
|
||||
ProcessResponse(resp *RespSync, since string) error
|
||||
// OnFailedSync returns either the time to wait before retrying or an error to stop syncing permanently.
|
||||
OnFailedSync(res *RespSync, err error) (time.Duration, error)
|
||||
@@ -92,6 +100,10 @@ type ExtensibleSyncer interface {
|
||||
OnEventType(eventType event.Type, callback EventHandler)
|
||||
}
|
||||
|
||||
type DispatchableSyncer interface {
|
||||
Dispatch(source EventSource, evt *event.Event)
|
||||
}
|
||||
|
||||
// DefaultSyncer is the default syncing implementation. You can either write your own syncer, or selectively
|
||||
// replace parts of this default syncer (e.g. the ProcessResponse method). The default syncer uses the observer
|
||||
// pattern to notify callers about incoming events. See DefaultSyncer.OnEventType for more information.
|
||||
@@ -107,6 +119,8 @@ type DefaultSyncer struct {
|
||||
// ParseErrorHandler is called when event.Content.ParseRaw returns an error.
|
||||
// If it returns false, the event will not be forwarded to listeners.
|
||||
ParseErrorHandler func(evt *event.Event, err error) bool
|
||||
// FilterJSON is used when the client starts syncing and doesn't get an existing filter ID from SyncStore's LoadFilterID.
|
||||
FilterJSON *Filter
|
||||
}
|
||||
|
||||
var _ Syncer = (*DefaultSyncer)(nil)
|
||||
@@ -191,10 +205,10 @@ func (s *DefaultSyncer) processSyncEvent(roomID id.RoomID, evt *event.Event, sou
|
||||
}
|
||||
}
|
||||
|
||||
s.notifyListeners(source, evt)
|
||||
s.Dispatch(source, evt)
|
||||
}
|
||||
|
||||
func (s *DefaultSyncer) notifyListeners(source EventSource, evt *event.Event) {
|
||||
func (s *DefaultSyncer) Dispatch(source EventSource, evt *event.Event) {
|
||||
for _, fn := range s.globalListeners {
|
||||
fn(source, evt)
|
||||
}
|
||||
@@ -226,22 +240,34 @@ func (s *DefaultSyncer) OnEvent(callback EventHandler) {
|
||||
|
||||
// OnFailedSync always returns a 10 second wait period between failed /syncs, never a fatal error.
|
||||
func (s *DefaultSyncer) OnFailedSync(res *RespSync, err error) (time.Duration, error) {
|
||||
if errors.Is(err, MUnknownToken) {
|
||||
return 0, err
|
||||
}
|
||||
return 10 * time.Second, nil
|
||||
}
|
||||
|
||||
var defaultFilter = Filter{
|
||||
Room: RoomFilter{
|
||||
Timeline: FilterPart{
|
||||
Limit: 50,
|
||||
},
|
||||
},
|
||||
}
|
||||
|
||||
// GetFilterJSON returns a filter with a timeline limit of 50.
|
||||
func (s *DefaultSyncer) GetFilterJSON(userID id.UserID) *Filter {
|
||||
return &Filter{
|
||||
Room: RoomFilter{
|
||||
Timeline: FilterPart{
|
||||
Limit: 50,
|
||||
},
|
||||
},
|
||||
if s.FilterJSON == nil {
|
||||
defaultFilterCopy := defaultFilter
|
||||
s.FilterJSON = &defaultFilterCopy
|
||||
}
|
||||
return s.FilterJSON
|
||||
}
|
||||
|
||||
// OldEventIgnorer is an utility struct for bots to ignore events from before the bot joined the room.
|
||||
// Create a struct and call Register with your DefaultSyncer to register the sync handler.
|
||||
//
|
||||
// Create a struct and call Register with your DefaultSyncer to register the sync handler, e.g.:
|
||||
//
|
||||
// (&OldEventIgnorer{UserID: cli.UserID}).Register(cli.Syncer.(mautrix.ExtensibleSyncer))
|
||||
type OldEventIgnorer struct {
|
||||
UserID id.UserID
|
||||
}
|
||||
|
||||
166
vendor/maunium.net/go/mautrix/syncstore.go
generated
vendored
Normal file
166
vendor/maunium.net/go/mautrix/syncstore.go
generated
vendored
Normal file
@@ -0,0 +1,166 @@
|
||||
package mautrix
|
||||
|
||||
import (
|
||||
"errors"
|
||||
|
||||
"maunium.net/go/mautrix/id"
|
||||
)
|
||||
|
||||
// SyncStore is an interface which must be satisfied to store client data.
|
||||
//
|
||||
// You can either write a struct which persists this data to disk, or you can use the
|
||||
// provided "MemorySyncStore" which just keeps data around in-memory which is lost on
|
||||
// restarts.
|
||||
type SyncStore interface {
|
||||
SaveFilterID(userID id.UserID, filterID string)
|
||||
LoadFilterID(userID id.UserID) string
|
||||
SaveNextBatch(userID id.UserID, nextBatchToken string)
|
||||
LoadNextBatch(userID id.UserID) string
|
||||
}
|
||||
|
||||
// Deprecated: renamed to SyncStore
|
||||
type Storer = SyncStore
|
||||
|
||||
// MemorySyncStore implements the Storer interface.
|
||||
//
|
||||
// Everything is persisted in-memory as maps. It is not safe to load/save filter IDs
|
||||
// or next batch tokens on any goroutine other than the syncing goroutine: the one
|
||||
// which called Client.Sync().
|
||||
type MemorySyncStore struct {
|
||||
Filters map[id.UserID]string
|
||||
NextBatch map[id.UserID]string
|
||||
}
|
||||
|
||||
// SaveFilterID to memory.
|
||||
func (s *MemorySyncStore) SaveFilterID(userID id.UserID, filterID string) {
|
||||
s.Filters[userID] = filterID
|
||||
}
|
||||
|
||||
// LoadFilterID from memory.
|
||||
func (s *MemorySyncStore) LoadFilterID(userID id.UserID) string {
|
||||
return s.Filters[userID]
|
||||
}
|
||||
|
||||
// SaveNextBatch to memory.
|
||||
func (s *MemorySyncStore) SaveNextBatch(userID id.UserID, nextBatchToken string) {
|
||||
s.NextBatch[userID] = nextBatchToken
|
||||
}
|
||||
|
||||
// LoadNextBatch from memory.
|
||||
func (s *MemorySyncStore) LoadNextBatch(userID id.UserID) string {
|
||||
return s.NextBatch[userID]
|
||||
}
|
||||
|
||||
// NewMemorySyncStore constructs a new MemorySyncStore.
|
||||
func NewMemorySyncStore() *MemorySyncStore {
|
||||
return &MemorySyncStore{
|
||||
Filters: make(map[id.UserID]string),
|
||||
NextBatch: make(map[id.UserID]string),
|
||||
}
|
||||
}
|
||||
|
||||
// AccountDataStore uses account data to store the next batch token, and stores the filter ID in memory
|
||||
// (as filters can be safely recreated every startup).
|
||||
type AccountDataStore struct {
|
||||
FilterID string
|
||||
EventType string
|
||||
client *Client
|
||||
nextBatch string
|
||||
}
|
||||
|
||||
type accountData struct {
|
||||
NextBatch string `json:"next_batch"`
|
||||
}
|
||||
|
||||
func (s *AccountDataStore) SaveFilterID(userID id.UserID, filterID string) {
|
||||
if userID.String() != s.client.UserID.String() {
|
||||
panic("AccountDataStore must only be used with a single account")
|
||||
}
|
||||
s.FilterID = filterID
|
||||
}
|
||||
|
||||
func (s *AccountDataStore) LoadFilterID(userID id.UserID) string {
|
||||
if userID.String() != s.client.UserID.String() {
|
||||
panic("AccountDataStore must only be used with a single account")
|
||||
}
|
||||
return s.FilterID
|
||||
}
|
||||
|
||||
func (s *AccountDataStore) SaveNextBatch(userID id.UserID, nextBatchToken string) {
|
||||
if userID.String() != s.client.UserID.String() {
|
||||
panic("AccountDataStore must only be used with a single account")
|
||||
} else if nextBatchToken == s.nextBatch {
|
||||
return
|
||||
}
|
||||
|
||||
data := accountData{
|
||||
NextBatch: nextBatchToken,
|
||||
}
|
||||
|
||||
err := s.client.SetAccountData(s.EventType, data)
|
||||
if err != nil {
|
||||
s.client.Log.Warn().Err(err).Msg("Failed to save next batch token to account data")
|
||||
} else {
|
||||
s.client.Log.Debug().
|
||||
Str("old_token", s.nextBatch).
|
||||
Str("new_token", nextBatchToken).
|
||||
Msg("Saved next batch token")
|
||||
s.nextBatch = nextBatchToken
|
||||
}
|
||||
}
|
||||
|
||||
func (s *AccountDataStore) LoadNextBatch(userID id.UserID) string {
|
||||
if userID.String() != s.client.UserID.String() {
|
||||
panic("AccountDataStore must only be used with a single account")
|
||||
}
|
||||
|
||||
data := &accountData{}
|
||||
|
||||
err := s.client.GetAccountData(s.EventType, data)
|
||||
if err != nil {
|
||||
if errors.Is(err, MNotFound) {
|
||||
s.client.Log.Debug().Msg("No next batch token found in account data")
|
||||
} else {
|
||||
s.client.Log.Warn().Err(err).Msg("Failed to load next batch token from account data")
|
||||
}
|
||||
return ""
|
||||
}
|
||||
s.nextBatch = data.NextBatch
|
||||
s.client.Log.Debug().Str("next_batch", data.NextBatch).Msg("Loaded next batch token from account data")
|
||||
|
||||
return s.nextBatch
|
||||
}
|
||||
|
||||
// NewAccountDataStore returns a new AccountDataStore, which stores
|
||||
// the next_batch token as a custom event in account data in the
|
||||
// homeserver.
|
||||
//
|
||||
// AccountDataStore is only appropriate for bots, not appservices.
|
||||
//
|
||||
// The event type should be a reversed DNS name like tld.domain.sub.internal and
|
||||
// must be unique for a client. The data stored in it is considered internal
|
||||
// and must not be modified through outside means. You should also add a filter
|
||||
// for account data changes of this event type, to avoid ending up in a sync
|
||||
// loop:
|
||||
//
|
||||
// filter := mautrix.Filter{
|
||||
// AccountData: mautrix.FilterPart{
|
||||
// Limit: 20,
|
||||
// NotTypes: []event.Type{
|
||||
// event.NewEventType(eventType),
|
||||
// },
|
||||
// },
|
||||
// }
|
||||
// // If you use a custom Syncer, set the filter there, not like this
|
||||
// client.Syncer.(*mautrix.DefaultSyncer).FilterJSON = &filter
|
||||
// client.Store = mautrix.NewAccountDataStore("com.example.mybot.store", client)
|
||||
// go func() {
|
||||
// err := client.Sync()
|
||||
// // don't forget to check err
|
||||
// }()
|
||||
func NewAccountDataStore(eventType string, client *Client) *AccountDataStore {
|
||||
return &AccountDataStore{
|
||||
EventType: eventType,
|
||||
client: client,
|
||||
}
|
||||
}
|
||||
8
vendor/maunium.net/go/mautrix/url.go
generated
vendored
8
vendor/maunium.net/go/mautrix/url.go
generated
vendored
@@ -13,7 +13,7 @@ import (
|
||||
"strings"
|
||||
)
|
||||
|
||||
func parseAndNormalizeBaseURL(homeserverURL string) (*url.URL, error) {
|
||||
func ParseAndNormalizeBaseURL(homeserverURL string) (*url.URL, error) {
|
||||
hsURL, err := url.Parse(homeserverURL)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
@@ -43,7 +43,7 @@ func BuildURL(baseURL *url.URL, path ...interface{}) *url.URL {
|
||||
parts[i+1] = casted
|
||||
case int:
|
||||
parts[i+1] = strconv.Itoa(casted)
|
||||
case Stringifiable:
|
||||
case fmt.Stringer:
|
||||
parts[i+1] = casted.String()
|
||||
default:
|
||||
parts[i+1] = fmt.Sprint(casted)
|
||||
@@ -93,8 +93,8 @@ func (mup MediaURLPath) FullPath() []interface{} {
|
||||
func (cli *Client) BuildURLWithQuery(urlPath PrefixableURLPath, urlQuery map[string]string) string {
|
||||
hsURL := *BuildURL(cli.HomeserverURL, urlPath.FullPath()...)
|
||||
query := hsURL.Query()
|
||||
if cli.AppServiceUserID != "" {
|
||||
query.Set("user_id", string(cli.AppServiceUserID))
|
||||
if cli.SetAppServiceUserID {
|
||||
query.Set("user_id", string(cli.UserID))
|
||||
}
|
||||
if urlQuery != nil {
|
||||
for k, v := range urlQuery {
|
||||
|
||||
28
vendor/maunium.net/go/mautrix/util/callermarshal.go
generated
vendored
Normal file
28
vendor/maunium.net/go/mautrix/util/callermarshal.go
generated
vendored
Normal file
@@ -0,0 +1,28 @@
|
||||
// Copyright (c) 2023 Tulir Asokan
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||
// License, v. 2.0. If a copy of the MPL was not distributed with this
|
||||
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
|
||||
package util
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"runtime"
|
||||
"strings"
|
||||
)
|
||||
|
||||
// CallerWithFunctionName is an implementation for zerolog.CallerMarshalFunc that includes the caller function name
|
||||
// in addition to the file and line number.
|
||||
//
|
||||
// Use as
|
||||
//
|
||||
// zerolog.CallerMarshalFunc = util.CallerWithFunctionName
|
||||
func CallerWithFunctionName(pc uintptr, file string, line int) string {
|
||||
files := strings.Split(file, "/")
|
||||
file = files[len(files)-1]
|
||||
name := runtime.FuncForPC(pc).Name()
|
||||
fns := strings.Split(name, ".")
|
||||
name = fns[len(fns)-1]
|
||||
return fmt.Sprintf("%s:%d:%s()", file, line, name)
|
||||
}
|
||||
288
vendor/maunium.net/go/mautrix/util/configupgrade/helper.go
generated
vendored
288
vendor/maunium.net/go/mautrix/util/configupgrade/helper.go
generated
vendored
@@ -1,288 +0,0 @@
|
||||
// Copyright (c) 2022 Tulir Asokan
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||
// License, v. 2.0. If a copy of the MPL was not distributed with this
|
||||
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
|
||||
package configupgrade
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"strings"
|
||||
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
type YAMLMap map[string]YAMLNode
|
||||
type YAMLList []YAMLNode
|
||||
|
||||
type YAMLNode struct {
|
||||
*yaml.Node
|
||||
Map YAMLMap
|
||||
List YAMLList
|
||||
Key *yaml.Node
|
||||
}
|
||||
|
||||
type YAMLType uint32
|
||||
|
||||
const (
|
||||
Null YAMLType = 1 << iota
|
||||
Bool
|
||||
Str
|
||||
Int
|
||||
Float
|
||||
Timestamp
|
||||
List
|
||||
Map
|
||||
Binary
|
||||
)
|
||||
|
||||
func (t YAMLType) String() string {
|
||||
switch t {
|
||||
case Null:
|
||||
return NullTag
|
||||
case Bool:
|
||||
return BoolTag
|
||||
case Str:
|
||||
return StrTag
|
||||
case Int:
|
||||
return IntTag
|
||||
case Float:
|
||||
return FloatTag
|
||||
case Timestamp:
|
||||
return TimestampTag
|
||||
case List:
|
||||
return SeqTag
|
||||
case Map:
|
||||
return MapTag
|
||||
case Binary:
|
||||
return BinaryTag
|
||||
default:
|
||||
panic(fmt.Errorf("can't convert type %d to string", t))
|
||||
}
|
||||
}
|
||||
|
||||
func tagToType(tag string) YAMLType {
|
||||
switch tag {
|
||||
case NullTag:
|
||||
return Null
|
||||
case BoolTag:
|
||||
return Bool
|
||||
case StrTag:
|
||||
return Str
|
||||
case IntTag:
|
||||
return Int
|
||||
case FloatTag:
|
||||
return Float
|
||||
case TimestampTag:
|
||||
return Timestamp
|
||||
case SeqTag:
|
||||
return List
|
||||
case MapTag:
|
||||
return Map
|
||||
case BinaryTag:
|
||||
return Binary
|
||||
default:
|
||||
return 0
|
||||
}
|
||||
}
|
||||
|
||||
const (
|
||||
NullTag = "!!null"
|
||||
BoolTag = "!!bool"
|
||||
StrTag = "!!str"
|
||||
IntTag = "!!int"
|
||||
FloatTag = "!!float"
|
||||
TimestampTag = "!!timestamp"
|
||||
SeqTag = "!!seq"
|
||||
MapTag = "!!map"
|
||||
BinaryTag = "!!binary"
|
||||
)
|
||||
|
||||
func fromNode(node, key *yaml.Node) YAMLNode {
|
||||
switch node.Kind {
|
||||
case yaml.DocumentNode:
|
||||
return fromNode(node.Content[0], nil)
|
||||
case yaml.AliasNode:
|
||||
return fromNode(node.Alias, nil)
|
||||
case yaml.MappingNode:
|
||||
return YAMLNode{
|
||||
Node: node,
|
||||
Map: parseYAMLMap(node),
|
||||
Key: key,
|
||||
}
|
||||
case yaml.SequenceNode:
|
||||
return YAMLNode{
|
||||
Node: node,
|
||||
List: parseYAMLList(node),
|
||||
}
|
||||
default:
|
||||
return YAMLNode{Node: node, Key: key}
|
||||
}
|
||||
}
|
||||
|
||||
func (yn *YAMLNode) toNode() *yaml.Node {
|
||||
yn.UpdateContent()
|
||||
return yn.Node
|
||||
}
|
||||
|
||||
func (yn *YAMLNode) UpdateContent() {
|
||||
switch {
|
||||
case yn.Map != nil && yn.Node.Kind == yaml.MappingNode:
|
||||
yn.Content = yn.Map.toNodes()
|
||||
case yn.List != nil && yn.Node.Kind == yaml.SequenceNode:
|
||||
yn.Content = yn.List.toNodes()
|
||||
}
|
||||
}
|
||||
|
||||
func parseYAMLList(node *yaml.Node) YAMLList {
|
||||
data := make(YAMLList, len(node.Content))
|
||||
for i, item := range node.Content {
|
||||
data[i] = fromNode(item, nil)
|
||||
}
|
||||
return data
|
||||
}
|
||||
|
||||
func (yl YAMLList) toNodes() []*yaml.Node {
|
||||
nodes := make([]*yaml.Node, len(yl))
|
||||
for i, item := range yl {
|
||||
nodes[i] = item.toNode()
|
||||
}
|
||||
return nodes
|
||||
}
|
||||
|
||||
func parseYAMLMap(node *yaml.Node) YAMLMap {
|
||||
if len(node.Content)%2 != 0 {
|
||||
panic(fmt.Errorf("uneven number of items in YAML map (%d)", len(node.Content)))
|
||||
}
|
||||
data := make(YAMLMap, len(node.Content)/2)
|
||||
for i := 0; i < len(node.Content); i += 2 {
|
||||
key := node.Content[i]
|
||||
value := node.Content[i+1]
|
||||
if key.Kind == yaml.ScalarNode {
|
||||
data[key.Value] = fromNode(value, key)
|
||||
}
|
||||
}
|
||||
return data
|
||||
}
|
||||
|
||||
func (ym YAMLMap) toNodes() []*yaml.Node {
|
||||
nodes := make([]*yaml.Node, len(ym)*2)
|
||||
i := 0
|
||||
for key, value := range ym {
|
||||
nodes[i] = makeStringNode(key)
|
||||
nodes[i+1] = value.toNode()
|
||||
i += 2
|
||||
}
|
||||
return nodes
|
||||
}
|
||||
|
||||
func makeStringNode(val string) *yaml.Node {
|
||||
var node yaml.Node
|
||||
node.SetString(val)
|
||||
return &node
|
||||
}
|
||||
|
||||
func StringNode(val string) YAMLNode {
|
||||
return YAMLNode{Node: makeStringNode(val)}
|
||||
}
|
||||
|
||||
type Helper struct {
|
||||
Base YAMLNode
|
||||
Config YAMLNode
|
||||
}
|
||||
|
||||
func NewHelper(base, cfg *yaml.Node) *Helper {
|
||||
return &Helper{
|
||||
Base: fromNode(base, nil),
|
||||
Config: fromNode(cfg, nil),
|
||||
}
|
||||
}
|
||||
|
||||
func (helper *Helper) AddSpaceBeforeComment(path ...string) {
|
||||
node := helper.GetBaseNode(path...)
|
||||
if node == nil || node.Key == nil {
|
||||
panic(fmt.Errorf("didn't find key at %+v", path))
|
||||
}
|
||||
node.Key.HeadComment = "\n" + node.Key.HeadComment
|
||||
}
|
||||
|
||||
func (helper *Helper) Copy(allowedTypes YAMLType, path ...string) {
|
||||
base, cfg := helper.Base, helper.Config
|
||||
var ok bool
|
||||
for _, item := range path {
|
||||
cfg, ok = cfg.Map[item]
|
||||
if !ok {
|
||||
return
|
||||
}
|
||||
base, ok = base.Map[item]
|
||||
if !ok {
|
||||
_, _ = fmt.Fprintf(os.Stderr, "Ignoring config field %s which is missing in base config\n", strings.Join(path, "->"))
|
||||
return
|
||||
}
|
||||
}
|
||||
if allowedTypes&tagToType(cfg.Tag) == 0 {
|
||||
_, _ = fmt.Fprintf(os.Stderr, "Ignoring incorrect config field type %s at %s\n", cfg.Tag, strings.Join(path, "->"))
|
||||
return
|
||||
}
|
||||
base.Tag = cfg.Tag
|
||||
base.Style = cfg.Style
|
||||
switch base.Kind {
|
||||
case yaml.ScalarNode:
|
||||
base.Value = cfg.Value
|
||||
case yaml.SequenceNode, yaml.MappingNode:
|
||||
base.Content = cfg.Content
|
||||
}
|
||||
}
|
||||
|
||||
func getNode(cfg YAMLNode, path []string) *YAMLNode {
|
||||
var ok bool
|
||||
for _, item := range path {
|
||||
cfg, ok = cfg.Map[item]
|
||||
if !ok {
|
||||
return nil
|
||||
}
|
||||
}
|
||||
return &cfg
|
||||
}
|
||||
|
||||
func (helper *Helper) GetNode(path ...string) *YAMLNode {
|
||||
return getNode(helper.Config, path)
|
||||
}
|
||||
|
||||
func (helper *Helper) GetBaseNode(path ...string) *YAMLNode {
|
||||
return getNode(helper.Base, path)
|
||||
}
|
||||
|
||||
func (helper *Helper) Get(tag YAMLType, path ...string) (string, bool) {
|
||||
node := helper.GetNode(path...)
|
||||
if node == nil || node.Kind != yaml.ScalarNode || tag&tagToType(node.Tag) == 0 {
|
||||
return "", false
|
||||
}
|
||||
return node.Value, true
|
||||
}
|
||||
|
||||
func (helper *Helper) GetBase(path ...string) string {
|
||||
return helper.GetBaseNode(path...).Value
|
||||
}
|
||||
|
||||
func (helper *Helper) Set(tag YAMLType, value string, path ...string) {
|
||||
base := helper.Base
|
||||
for _, item := range path {
|
||||
base = base.Map[item]
|
||||
}
|
||||
base.Tag = tag.String()
|
||||
base.Value = value
|
||||
}
|
||||
|
||||
func (helper *Helper) SetMap(value YAMLMap, path ...string) {
|
||||
base := helper.Base
|
||||
for _, item := range path {
|
||||
base = base.Map[item]
|
||||
}
|
||||
if base.Tag != MapTag || base.Kind != yaml.MappingNode {
|
||||
panic(fmt.Errorf("invalid target for SetMap(%+v): tag:%s, kind:%d", path, base.Tag, base.Kind))
|
||||
}
|
||||
base.Content = value.toNodes()
|
||||
}
|
||||
108
vendor/maunium.net/go/mautrix/util/configupgrade/upgrade.go
generated
vendored
108
vendor/maunium.net/go/mautrix/util/configupgrade/upgrade.go
generated
vendored
@@ -1,108 +0,0 @@
|
||||
// Copyright (c) 2022 Tulir Asokan
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||
// License, v. 2.0. If a copy of the MPL was not distributed with this
|
||||
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
|
||||
package configupgrade
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"os"
|
||||
"path"
|
||||
|
||||
"gopkg.in/yaml.v3"
|
||||
)
|
||||
|
||||
type Upgrader interface {
|
||||
DoUpgrade(helper *Helper)
|
||||
}
|
||||
|
||||
type SpacedUpgrader interface {
|
||||
Upgrader
|
||||
SpacedBlocks() [][]string
|
||||
}
|
||||
|
||||
type BaseUpgrader interface {
|
||||
Upgrader
|
||||
GetBase() string
|
||||
}
|
||||
|
||||
type StructUpgrader struct {
|
||||
SimpleUpgrader
|
||||
Blocks [][]string
|
||||
Base string
|
||||
}
|
||||
|
||||
func (su *StructUpgrader) SpacedBlocks() [][]string {
|
||||
return su.Blocks
|
||||
}
|
||||
|
||||
func (su *StructUpgrader) GetBase() string {
|
||||
return su.Base
|
||||
}
|
||||
|
||||
type SimpleUpgrader func(helper *Helper)
|
||||
|
||||
func (su SimpleUpgrader) DoUpgrade(helper *Helper) {
|
||||
su(helper)
|
||||
}
|
||||
|
||||
func (helper *Helper) apply(upgrader Upgrader) {
|
||||
upgrader.DoUpgrade(helper)
|
||||
helper.addSpaces(upgrader)
|
||||
}
|
||||
|
||||
func (helper *Helper) addSpaces(upgrader Upgrader) {
|
||||
spaced, ok := upgrader.(SpacedUpgrader)
|
||||
if ok {
|
||||
for _, spacePath := range spaced.SpacedBlocks() {
|
||||
helper.AddSpaceBeforeComment(spacePath...)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
func Do(configPath string, save bool, upgrader BaseUpgrader, additional ...Upgrader) ([]byte, bool, error) {
|
||||
sourceData, err := os.ReadFile(configPath)
|
||||
if err != nil {
|
||||
return nil, false, fmt.Errorf("failed to read config: %w", err)
|
||||
}
|
||||
var base, cfg yaml.Node
|
||||
err = yaml.Unmarshal([]byte(upgrader.GetBase()), &base)
|
||||
if err != nil {
|
||||
return sourceData, false, fmt.Errorf("failed to unmarshal example config: %w", err)
|
||||
}
|
||||
err = yaml.Unmarshal(sourceData, &cfg)
|
||||
if err != nil {
|
||||
return sourceData, false, fmt.Errorf("failed to unmarshal config: %w", err)
|
||||
}
|
||||
|
||||
helper := NewHelper(&base, &cfg)
|
||||
helper.apply(upgrader)
|
||||
for _, add := range additional {
|
||||
helper.apply(add)
|
||||
}
|
||||
|
||||
output, err := yaml.Marshal(&base)
|
||||
if err != nil {
|
||||
return sourceData, false, fmt.Errorf("failed to marshal updated config: %w", err)
|
||||
}
|
||||
if save {
|
||||
var tempFile *os.File
|
||||
tempFile, err = os.CreateTemp(path.Dir(configPath), "mautrix-config-*.yaml")
|
||||
if err != nil {
|
||||
return output, true, fmt.Errorf("failed to create temp file for writing config: %w", err)
|
||||
}
|
||||
_, err = tempFile.Write(output)
|
||||
if err != nil {
|
||||
_ = os.Remove(tempFile.Name())
|
||||
return output, true, fmt.Errorf("failed to write updated config to temp file: %w", err)
|
||||
}
|
||||
err = os.Rename(tempFile.Name(), configPath)
|
||||
if err != nil {
|
||||
_ = os.Remove(tempFile.Name())
|
||||
return output, true, fmt.Errorf("failed to override current config with temp file: %w", err)
|
||||
}
|
||||
}
|
||||
return output, true, nil
|
||||
}
|
||||
28
vendor/maunium.net/go/mautrix/util/dbutil/connlog.go
generated
vendored
28
vendor/maunium.net/go/mautrix/util/dbutil/connlog.go
generated
vendored
@@ -23,7 +23,7 @@ func (le *LoggingExecable) ExecContext(ctx context.Context, query string, args .
|
||||
start := time.Now()
|
||||
query = le.db.mutateQuery(query)
|
||||
res, err := le.UnderlyingExecable.ExecContext(ctx, query, args...)
|
||||
le.db.Log.QueryTiming(ctx, "Exec", query, args, -1, time.Since(start))
|
||||
le.db.Log.QueryTiming(ctx, "Exec", query, args, -1, time.Since(start), err)
|
||||
return res, err
|
||||
}
|
||||
|
||||
@@ -31,7 +31,7 @@ func (le *LoggingExecable) QueryContext(ctx context.Context, query string, args
|
||||
start := time.Now()
|
||||
query = le.db.mutateQuery(query)
|
||||
rows, err := le.UnderlyingExecable.QueryContext(ctx, query, args...)
|
||||
le.db.Log.QueryTiming(ctx, "Query", query, args, -1, time.Since(start))
|
||||
le.db.Log.QueryTiming(ctx, "Query", query, args, -1, time.Since(start), err)
|
||||
return &LoggingRows{
|
||||
ctx: ctx,
|
||||
db: le.db,
|
||||
@@ -46,7 +46,7 @@ func (le *LoggingExecable) QueryRowContext(ctx context.Context, query string, ar
|
||||
start := time.Now()
|
||||
query = le.db.mutateQuery(query)
|
||||
row := le.UnderlyingExecable.QueryRowContext(ctx, query, args...)
|
||||
le.db.Log.QueryTiming(ctx, "QueryRow", query, args, -1, time.Since(start))
|
||||
le.db.Log.QueryTiming(ctx, "QueryRow", query, args, -1, time.Since(start), nil)
|
||||
return row
|
||||
}
|
||||
|
||||
@@ -73,7 +73,7 @@ type loggingDB struct {
|
||||
func (ld *loggingDB) BeginTx(ctx context.Context, opts *sql.TxOptions) (*LoggingTxn, error) {
|
||||
start := time.Now()
|
||||
tx, err := ld.db.RawDB.BeginTx(ctx, opts)
|
||||
ld.db.Log.QueryTiming(ctx, "Begin", "", nil, -1, time.Since(start))
|
||||
ld.db.Log.QueryTiming(ctx, "Begin", "", nil, -1, time.Since(start), err)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
@@ -81,6 +81,7 @@ func (ld *loggingDB) BeginTx(ctx context.Context, opts *sql.TxOptions) (*Logging
|
||||
LoggingExecable: LoggingExecable{UnderlyingExecable: tx, db: ld.db},
|
||||
UnderlyingTx: tx,
|
||||
ctx: ctx,
|
||||
StartTime: start,
|
||||
}, nil
|
||||
}
|
||||
|
||||
@@ -92,22 +93,35 @@ type LoggingTxn struct {
|
||||
LoggingExecable
|
||||
UnderlyingTx *sql.Tx
|
||||
ctx context.Context
|
||||
|
||||
StartTime time.Time
|
||||
EndTime time.Time
|
||||
noTotalLog bool
|
||||
}
|
||||
|
||||
func (lt *LoggingTxn) Commit() error {
|
||||
start := time.Now()
|
||||
err := lt.UnderlyingTx.Commit()
|
||||
lt.db.Log.QueryTiming(lt.ctx, "Commit", "", nil, -1, time.Since(start))
|
||||
lt.endLog()
|
||||
lt.db.Log.QueryTiming(lt.ctx, "Commit", "", nil, -1, time.Since(start), err)
|
||||
return err
|
||||
}
|
||||
|
||||
func (lt *LoggingTxn) Rollback() error {
|
||||
start := time.Now()
|
||||
err := lt.UnderlyingTx.Rollback()
|
||||
lt.db.Log.QueryTiming(lt.ctx, "Rollback", "", nil, -1, time.Since(start))
|
||||
lt.endLog()
|
||||
lt.db.Log.QueryTiming(lt.ctx, "Rollback", "", nil, -1, time.Since(start), err)
|
||||
return err
|
||||
}
|
||||
|
||||
func (lt *LoggingTxn) endLog() {
|
||||
lt.EndTime = time.Now()
|
||||
if !lt.noTotalLog {
|
||||
lt.db.Log.QueryTiming(lt.ctx, "<Transaction>", "", nil, -1, lt.EndTime.Sub(lt.StartTime), nil)
|
||||
}
|
||||
}
|
||||
|
||||
type LoggingRows struct {
|
||||
ctx context.Context
|
||||
db *Database
|
||||
@@ -120,7 +134,7 @@ type LoggingRows struct {
|
||||
|
||||
func (lrs *LoggingRows) stopTiming() {
|
||||
if !lrs.start.IsZero() {
|
||||
lrs.db.Log.QueryTiming(lrs.ctx, "EndRows", lrs.query, lrs.args, lrs.nrows, time.Since(lrs.start))
|
||||
lrs.db.Log.QueryTiming(lrs.ctx, "EndRows", lrs.query, lrs.args, lrs.nrows, time.Since(lrs.start), lrs.rs.Err())
|
||||
lrs.start = time.Time{}
|
||||
}
|
||||
}
|
||||
|
||||
19
vendor/maunium.net/go/mautrix/util/dbutil/database.go
generated
vendored
19
vendor/maunium.net/go/mautrix/util/dbutil/database.go
generated
vendored
@@ -13,8 +13,6 @@ import (
|
||||
"regexp"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"maunium.net/go/mautrix/bridge/bridgeconfig"
|
||||
)
|
||||
|
||||
type Dialect int
|
||||
@@ -38,7 +36,7 @@ func (dialect Dialect) String() string {
|
||||
|
||||
func ParseDialect(engine string) (Dialect, error) {
|
||||
switch strings.ToLower(engine) {
|
||||
case "postgres", "postgresql":
|
||||
case "postgres", "postgresql", "pgx":
|
||||
return Postgres, nil
|
||||
case "sqlite3", "sqlite", "litestream", "sqlite3-fk-wal":
|
||||
return SQLite, nil
|
||||
@@ -176,7 +174,18 @@ func NewWithDialect(uri, rawDialect string) (*Database, error) {
|
||||
return NewWithDB(db, rawDialect)
|
||||
}
|
||||
|
||||
func (db *Database) Configure(cfg bridgeconfig.DatabaseConfig) error {
|
||||
type Config struct {
|
||||
Type string `yaml:"type"`
|
||||
URI string `yaml:"uri"`
|
||||
|
||||
MaxOpenConns int `yaml:"max_open_conns"`
|
||||
MaxIdleConns int `yaml:"max_idle_conns"`
|
||||
|
||||
ConnMaxIdleTime string `yaml:"conn_max_idle_time"`
|
||||
ConnMaxLifetime string `yaml:"conn_max_lifetime"`
|
||||
}
|
||||
|
||||
func (db *Database) Configure(cfg Config) error {
|
||||
db.RawDB.SetMaxOpenConns(cfg.MaxOpenConns)
|
||||
db.RawDB.SetMaxIdleConns(cfg.MaxIdleConns)
|
||||
if len(cfg.ConnMaxIdleTime) > 0 {
|
||||
@@ -196,7 +205,7 @@ func (db *Database) Configure(cfg bridgeconfig.DatabaseConfig) error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func NewFromConfig(owner string, cfg bridgeconfig.DatabaseConfig, logger DatabaseLogger) (*Database, error) {
|
||||
func NewFromConfig(owner string, cfg Config, logger DatabaseLogger) (*Database, error) {
|
||||
dialect, err := ParseDialect(cfg.Type)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
|
||||
67
vendor/maunium.net/go/mautrix/util/dbutil/log.go
generated
vendored
67
vendor/maunium.net/go/mautrix/util/dbutil/log.go
generated
vendored
@@ -11,9 +11,9 @@ import (
|
||||
)
|
||||
|
||||
type DatabaseLogger interface {
|
||||
QueryTiming(ctx context.Context, method, query string, args []interface{}, nrows int, duration time.Duration)
|
||||
WarnUnsupportedVersion(current, latest int)
|
||||
PrepareUpgrade(current, latest int)
|
||||
QueryTiming(ctx context.Context, method, query string, args []interface{}, nrows int, duration time.Duration, err error)
|
||||
WarnUnsupportedVersion(current, compat, latest int)
|
||||
PrepareUpgrade(current, compat, latest int)
|
||||
DoUpgrade(from, to int, message string, txn bool)
|
||||
// Deprecated: legacy warning method, return errors instead
|
||||
Warn(msg string, args ...interface{})
|
||||
@@ -23,35 +23,36 @@ type noopLogger struct{}
|
||||
|
||||
var NoopLogger DatabaseLogger = &noopLogger{}
|
||||
|
||||
func (n noopLogger) WarnUnsupportedVersion(_, _ int) {}
|
||||
func (n noopLogger) PrepareUpgrade(_, _ int) {}
|
||||
func (n noopLogger) WarnUnsupportedVersion(_, _, _ int) {}
|
||||
func (n noopLogger) PrepareUpgrade(_, _, _ int) {}
|
||||
func (n noopLogger) DoUpgrade(_, _ int, _ string, _ bool) {}
|
||||
func (n noopLogger) Warn(msg string, args ...interface{}) {}
|
||||
|
||||
func (n noopLogger) QueryTiming(_ context.Context, _, _ string, _ []interface{}, _ int, _ time.Duration) {
|
||||
func (n noopLogger) QueryTiming(_ context.Context, _, _ string, _ []interface{}, _ int, _ time.Duration, _ error) {
|
||||
}
|
||||
|
||||
type mauLogger struct {
|
||||
l maulogger.Logger
|
||||
}
|
||||
|
||||
// Deprecated: Use zerolog instead
|
||||
func MauLogger(log maulogger.Logger) DatabaseLogger {
|
||||
return &mauLogger{l: log}
|
||||
}
|
||||
|
||||
func (m mauLogger) WarnUnsupportedVersion(current, latest int) {
|
||||
m.l.Warnfln("Unsupported database schema version: currently on v%d, latest known: v%d - continuing anyway", current, latest)
|
||||
func (m mauLogger) WarnUnsupportedVersion(current, compat, latest int) {
|
||||
m.l.Warnfln("Unsupported database schema version: currently on v%d (compatible down to v%d), latest known: v%d - continuing anyway", current, compat, latest)
|
||||
}
|
||||
|
||||
func (m mauLogger) PrepareUpgrade(current, latest int) {
|
||||
m.l.Infofln("Database currently on v%d, latest: v%d", current, latest)
|
||||
func (m mauLogger) PrepareUpgrade(current, compat, latest int) {
|
||||
m.l.Infofln("Database currently on v%d (compat: v%d), latest known: v%d", current, compat, latest)
|
||||
}
|
||||
|
||||
func (m mauLogger) DoUpgrade(from, to int, message string, _ bool) {
|
||||
m.l.Infofln("Upgrading database from v%d to v%d: %s", from, to, message)
|
||||
}
|
||||
|
||||
func (m mauLogger) QueryTiming(_ context.Context, method, query string, _ []interface{}, _ int, duration time.Duration) {
|
||||
func (m mauLogger) QueryTiming(_ context.Context, method, query string, _ []interface{}, _ int, duration time.Duration, _ error) {
|
||||
if duration > 1*time.Second {
|
||||
m.l.Warnfln("%s(%s) took %.3f seconds", method, query, duration.Seconds())
|
||||
}
|
||||
@@ -63,28 +64,40 @@ func (m mauLogger) Warn(msg string, args ...interface{}) {
|
||||
|
||||
type zeroLogger struct {
|
||||
l *zerolog.Logger
|
||||
ZeroLogSettings
|
||||
}
|
||||
|
||||
func ZeroLogger(log zerolog.Logger) DatabaseLogger {
|
||||
return ZeroLoggerPtr(&log)
|
||||
type ZeroLogSettings struct {
|
||||
CallerSkipFrame int
|
||||
Caller bool
|
||||
}
|
||||
|
||||
func ZeroLoggerPtr(log *zerolog.Logger) DatabaseLogger {
|
||||
return &zeroLogger{l: log}
|
||||
func ZeroLogger(log zerolog.Logger, cfg ...ZeroLogSettings) DatabaseLogger {
|
||||
return ZeroLoggerPtr(&log, cfg...)
|
||||
}
|
||||
|
||||
func (z zeroLogger) WarnUnsupportedVersion(current, latest int) {
|
||||
func ZeroLoggerPtr(log *zerolog.Logger, cfg ...ZeroLogSettings) DatabaseLogger {
|
||||
wrapped := &zeroLogger{l: log}
|
||||
if len(cfg) > 0 {
|
||||
wrapped.ZeroLogSettings = cfg[0]
|
||||
}
|
||||
return wrapped
|
||||
}
|
||||
|
||||
func (z zeroLogger) WarnUnsupportedVersion(current, compat, latest int) {
|
||||
z.l.Warn().
|
||||
Int("current_db_version", current).
|
||||
Int("current_version", current).
|
||||
Int("oldest_compatible_version", compat).
|
||||
Int("latest_known_version", latest).
|
||||
Msg("Unsupported database schema version, continuing anyway")
|
||||
}
|
||||
|
||||
func (z zeroLogger) PrepareUpgrade(current, latest int) {
|
||||
func (z zeroLogger) PrepareUpgrade(current, compat, latest int) {
|
||||
evt := z.l.Info().
|
||||
Int("current_db_version", current).
|
||||
Int("current_version", current).
|
||||
Int("oldest_compatible_version", compat).
|
||||
Int("latest_known_version", latest)
|
||||
if current == latest {
|
||||
if current >= latest {
|
||||
evt.Msg("Database is up to date")
|
||||
} else {
|
||||
evt.Msg("Preparing to update database schema")
|
||||
@@ -102,9 +115,9 @@ func (z zeroLogger) DoUpgrade(from, to int, message string, txn bool) {
|
||||
|
||||
var whitespaceRegex = regexp.MustCompile(`\s+`)
|
||||
|
||||
func (z zeroLogger) QueryTiming(ctx context.Context, method, query string, args []interface{}, nrows int, duration time.Duration) {
|
||||
func (z zeroLogger) QueryTiming(ctx context.Context, method, query string, args []interface{}, nrows int, duration time.Duration, err error) {
|
||||
log := zerolog.Ctx(ctx)
|
||||
if log.GetLevel() == zerolog.Disabled {
|
||||
if log.GetLevel() == zerolog.Disabled || log == zerolog.DefaultContextLogger {
|
||||
log = z.l
|
||||
}
|
||||
if log.GetLevel() != zerolog.TraceLevel && duration < 1*time.Second {
|
||||
@@ -116,17 +129,21 @@ func (z zeroLogger) QueryTiming(ctx context.Context, method, query string, args
|
||||
}
|
||||
query = strings.TrimSpace(whitespaceRegex.ReplaceAllLiteralString(query, " "))
|
||||
log.Trace().
|
||||
Err(err).
|
||||
Int64("duration_µs", duration.Microseconds()).
|
||||
Str("method", method).
|
||||
Str("query", query).
|
||||
Interface("query_args", args).
|
||||
Msg("Query")
|
||||
if duration >= 1*time.Second {
|
||||
log.Warn().
|
||||
evt := log.Warn().
|
||||
Float64("duration_seconds", duration.Seconds()).
|
||||
Str("method", method).
|
||||
Str("query", query).
|
||||
Msg("Query took long")
|
||||
Str("query", query)
|
||||
if z.Caller {
|
||||
evt = evt.Caller(z.CallerSkipFrame)
|
||||
}
|
||||
evt.Msg("Query took long")
|
||||
}
|
||||
}
|
||||
|
||||
|
||||
3
vendor/maunium.net/go/mautrix/util/dbutil/samples/05-compat.sql
generated
vendored
Normal file
3
vendor/maunium.net/go/mautrix/util/dbutil/samples/05-compat.sql
generated
vendored
Normal file
@@ -0,0 +1,3 @@
|
||||
-- v5 (compatible with v3+): Sample backwards-compatible upgrade
|
||||
|
||||
INSERT INTO foo VALUES ('meow 2', '{}');
|
||||
1
vendor/maunium.net/go/mautrix/util/dbutil/samples/output/05-postgres.sql
generated
vendored
Normal file
1
vendor/maunium.net/go/mautrix/util/dbutil/samples/output/05-postgres.sql
generated
vendored
Normal file
@@ -0,0 +1 @@
|
||||
INSERT INTO foo VALUES ('meow 2', '{}');
|
||||
1
vendor/maunium.net/go/mautrix/util/dbutil/samples/output/05-sqlite3.sql
generated
vendored
Normal file
1
vendor/maunium.net/go/mautrix/util/dbutil/samples/output/05-sqlite3.sql
generated
vendored
Normal file
@@ -0,0 +1 @@
|
||||
INSERT INTO foo VALUES ('meow 2', '{}');
|
||||
82
vendor/maunium.net/go/mautrix/util/dbutil/transaction.go
generated
vendored
Normal file
82
vendor/maunium.net/go/mautrix/util/dbutil/transaction.go
generated
vendored
Normal file
@@ -0,0 +1,82 @@
|
||||
// Copyright (c) 2023 Tulir Asokan
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||
// License, v. 2.0. If a copy of the MPL was not distributed with this
|
||||
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
|
||||
package dbutil
|
||||
|
||||
import (
|
||||
"context"
|
||||
"database/sql"
|
||||
"errors"
|
||||
"fmt"
|
||||
"time"
|
||||
|
||||
"github.com/rs/zerolog"
|
||||
|
||||
"maunium.net/go/mautrix/util"
|
||||
)
|
||||
|
||||
var (
|
||||
ErrTxn = errors.New("transaction")
|
||||
ErrTxnBegin = fmt.Errorf("%w: begin", ErrTxn)
|
||||
ErrTxnCommit = fmt.Errorf("%w: commit", ErrTxn)
|
||||
)
|
||||
|
||||
type contextKey int
|
||||
|
||||
const (
|
||||
ContextKeyDatabaseTransaction contextKey = iota
|
||||
ContextKeyDoTxnCallerSkip
|
||||
)
|
||||
|
||||
func (db *Database) DoTxn(ctx context.Context, opts *sql.TxOptions, fn func(ctx context.Context) error) error {
|
||||
if ctx.Value(ContextKeyDatabaseTransaction) != nil {
|
||||
zerolog.Ctx(ctx).Debug().Msg("Already in a transaction, not creating a new one")
|
||||
return fn(ctx)
|
||||
}
|
||||
log := zerolog.Ctx(ctx).With().Str("db_txn_id", util.RandomString(12)).Logger()
|
||||
start := time.Now()
|
||||
defer func() {
|
||||
dur := time.Since(start)
|
||||
if dur > time.Second {
|
||||
val := ctx.Value(ContextKeyDoTxnCallerSkip)
|
||||
callerSkip := 2
|
||||
if val != nil {
|
||||
callerSkip += val.(int)
|
||||
}
|
||||
log.Warn().
|
||||
Float64("duration_seconds", dur.Seconds()).
|
||||
Caller(callerSkip).
|
||||
Msg("Transaction took a long time")
|
||||
}
|
||||
}()
|
||||
tx, err := db.BeginTx(ctx, opts)
|
||||
if err != nil {
|
||||
log.Trace().Err(err).Msg("Failed to begin transaction")
|
||||
return util.NewDualError(ErrTxnBegin, err)
|
||||
}
|
||||
log.Trace().Msg("Transaction started")
|
||||
tx.noTotalLog = true
|
||||
ctx = log.WithContext(ctx)
|
||||
ctx = context.WithValue(ctx, ContextKeyDatabaseTransaction, tx)
|
||||
err = fn(ctx)
|
||||
if err != nil {
|
||||
log.Trace().Err(err).Msg("Database transaction failed, rolling back")
|
||||
rollbackErr := tx.Rollback()
|
||||
if rollbackErr != nil {
|
||||
log.Warn().Err(rollbackErr).Msg("Rollback after transaction error failed")
|
||||
} else {
|
||||
log.Trace().Msg("Rollback successful")
|
||||
}
|
||||
return err
|
||||
}
|
||||
err = tx.Commit()
|
||||
if err != nil {
|
||||
log.Trace().Err(err).Msg("Commit failed")
|
||||
return util.NewDualError(ErrTxnCommit, err)
|
||||
}
|
||||
log.Trace().Msg("Commit successful")
|
||||
return nil
|
||||
}
|
||||
117
vendor/maunium.net/go/mautrix/util/dbutil/upgrades.go
generated
vendored
117
vendor/maunium.net/go/mautrix/util/dbutil/upgrades.go
generated
vendored
@@ -1,4 +1,4 @@
|
||||
// Copyright (c) 2022 Tulir Asokan
|
||||
// Copyright (c) 2023 Tulir Asokan
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||
// License, v. 2.0. If a copy of the MPL was not distributed with this
|
||||
@@ -18,42 +18,97 @@ type upgrade struct {
|
||||
message string
|
||||
fn upgradeFunc
|
||||
|
||||
upgradesTo int
|
||||
transaction bool
|
||||
upgradesTo int
|
||||
compatVersion int
|
||||
transaction bool
|
||||
}
|
||||
|
||||
var ErrUnsupportedDatabaseVersion = fmt.Errorf("unsupported database schema version")
|
||||
var ErrForeignTables = fmt.Errorf("the database contains foreign tables")
|
||||
var ErrNotOwned = fmt.Errorf("the database is owned by")
|
||||
var ErrUnsupportedDatabaseVersion = errors.New("unsupported database schema version")
|
||||
var ErrForeignTables = errors.New("the database contains foreign tables")
|
||||
var ErrNotOwned = errors.New("the database is owned by")
|
||||
var ErrUnsupportedDialect = errors.New("unsupported database dialect")
|
||||
|
||||
func (db *Database) getVersion() (int, error) {
|
||||
_, err := db.Exec(fmt.Sprintf("CREATE TABLE IF NOT EXISTS %s (version INTEGER)", db.VersionTable))
|
||||
if err != nil {
|
||||
return -1, err
|
||||
func (db *Database) upgradeVersionTable() error {
|
||||
if compatColumnExists, err := db.ColumnExists(nil, db.VersionTable, "compat"); err != nil {
|
||||
return fmt.Errorf("failed to check if version table is up to date: %w", err)
|
||||
} else if !compatColumnExists {
|
||||
if tableExists, err := db.TableExists(nil, db.VersionTable); err != nil {
|
||||
return fmt.Errorf("failed to check if version table exists: %w", err)
|
||||
} else if !tableExists {
|
||||
_, err = db.Exec(fmt.Sprintf("CREATE TABLE %s (version INTEGER, compat INTEGER)", db.VersionTable))
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to create version table: %w", err)
|
||||
}
|
||||
} else {
|
||||
_, err = db.Exec(fmt.Sprintf("ALTER TABLE %s ADD COLUMN compat INTEGER", db.VersionTable))
|
||||
if err != nil {
|
||||
return fmt.Errorf("failed to add compat column to version table: %w", err)
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
version := 0
|
||||
err = db.QueryRow(fmt.Sprintf("SELECT version FROM %s LIMIT 1", db.VersionTable)).Scan(&version)
|
||||
if err != nil && !errors.Is(err, sql.ErrNoRows) {
|
||||
return -1, err
|
||||
}
|
||||
return version, nil
|
||||
return nil
|
||||
}
|
||||
|
||||
const tableExistsPostgres = "SELECT EXISTS(SELECT 1 FROM information_schema.tables WHERE table_name=$1)"
|
||||
const tableExistsSQLite = "SELECT EXISTS(SELECT 1 FROM sqlite_master WHERE type='table' AND tbl_name=$1)"
|
||||
func (db *Database) getVersion() (version, compat int, err error) {
|
||||
if err = db.upgradeVersionTable(); err != nil {
|
||||
return
|
||||
}
|
||||
|
||||
func (db *Database) tableExists(table string) (exists bool, err error) {
|
||||
if db.Dialect == SQLite {
|
||||
var compatNull sql.NullInt32
|
||||
err = db.QueryRow(fmt.Sprintf("SELECT version, compat FROM %s LIMIT 1", db.VersionTable)).Scan(&version, &compatNull)
|
||||
if errors.Is(err, sql.ErrNoRows) {
|
||||
err = nil
|
||||
}
|
||||
if compatNull.Valid && compatNull.Int32 != 0 {
|
||||
compat = int(compatNull.Int32)
|
||||
} else {
|
||||
compat = version
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
const (
|
||||
tableExistsPostgres = "SELECT EXISTS(SELECT 1 FROM information_schema.tables WHERE table_name=$1)"
|
||||
tableExistsSQLite = "SELECT EXISTS(SELECT 1 FROM sqlite_master WHERE type='table' AND tbl_name=?1)"
|
||||
)
|
||||
|
||||
func (db *Database) TableExists(tx Execable, table string) (exists bool, err error) {
|
||||
if tx == nil {
|
||||
tx = db
|
||||
}
|
||||
switch db.Dialect {
|
||||
case SQLite:
|
||||
err = db.QueryRow(tableExistsSQLite, table).Scan(&exists)
|
||||
} else if db.Dialect == Postgres {
|
||||
case Postgres:
|
||||
err = db.QueryRow(tableExistsPostgres, table).Scan(&exists)
|
||||
default:
|
||||
err = ErrUnsupportedDialect
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
const (
|
||||
columnExistsPostgres = "SELECT EXISTS(SELECT 1 FROM information_schema.columns WHERE table_name=$1 AND column_name=$2)"
|
||||
columnExistsSQLite = "SELECT EXISTS(SELECT 1 FROM pragma_table_info(?1) WHERE name=?2)"
|
||||
)
|
||||
|
||||
func (db *Database) ColumnExists(tx Execable, table, column string) (exists bool, err error) {
|
||||
if tx == nil {
|
||||
tx = db
|
||||
}
|
||||
switch db.Dialect {
|
||||
case SQLite:
|
||||
err = db.QueryRow(columnExistsSQLite, table, column).Scan(&exists)
|
||||
case Postgres:
|
||||
err = db.QueryRow(columnExistsPostgres, table, column).Scan(&exists)
|
||||
default:
|
||||
err = ErrUnsupportedDialect
|
||||
}
|
||||
return
|
||||
}
|
||||
|
||||
func (db *Database) tableExistsNoError(table string) bool {
|
||||
exists, err := db.tableExists(table)
|
||||
exists, err := db.TableExists(nil, table)
|
||||
if err != nil {
|
||||
panic(fmt.Errorf("failed to check if table exists: %w", err))
|
||||
}
|
||||
@@ -94,12 +149,12 @@ func (db *Database) checkDatabaseOwner() error {
|
||||
return nil
|
||||
}
|
||||
|
||||
func (db *Database) setVersion(tx Execable, version int) error {
|
||||
func (db *Database) setVersion(tx Execable, version, compat int) error {
|
||||
_, err := tx.Exec(fmt.Sprintf("DELETE FROM %s", db.VersionTable))
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
_, err = tx.Exec(fmt.Sprintf("INSERT INTO %s (version) VALUES ($1)", db.VersionTable), version)
|
||||
_, err = tx.Exec(fmt.Sprintf("INSERT INTO %s (version, compat) VALUES ($1, $2)", db.VersionTable), version, compat)
|
||||
return err
|
||||
}
|
||||
|
||||
@@ -109,20 +164,20 @@ func (db *Database) Upgrade() error {
|
||||
return err
|
||||
}
|
||||
|
||||
version, err := db.getVersion()
|
||||
version, compat, err := db.getVersion()
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
if version > len(db.UpgradeTable) {
|
||||
if compat > len(db.UpgradeTable) {
|
||||
if db.IgnoreUnsupportedDatabase {
|
||||
db.Log.WarnUnsupportedVersion(version, len(db.UpgradeTable))
|
||||
db.Log.WarnUnsupportedVersion(version, compat, len(db.UpgradeTable))
|
||||
return nil
|
||||
}
|
||||
return fmt.Errorf("%w: currently on v%d, latest known: v%d", ErrUnsupportedDatabaseVersion, version, len(db.UpgradeTable))
|
||||
return fmt.Errorf("%w: currently on v%d (compatible down to v%d), latest known: v%d", ErrUnsupportedDatabaseVersion, version, compat, len(db.UpgradeTable))
|
||||
}
|
||||
|
||||
db.Log.PrepareUpgrade(version, len(db.UpgradeTable))
|
||||
db.Log.PrepareUpgrade(version, compat, len(db.UpgradeTable))
|
||||
logVersion := version
|
||||
for version < len(db.UpgradeTable) {
|
||||
upgradeItem := db.UpgradeTable[version]
|
||||
@@ -148,7 +203,7 @@ func (db *Database) Upgrade() error {
|
||||
}
|
||||
version = upgradeItem.upgradesTo
|
||||
logVersion = version
|
||||
err = db.setVersion(upgradeConn, version)
|
||||
err = db.setVersion(upgradeConn, version, upgradeItem.compatVersion)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
40
vendor/maunium.net/go/mautrix/util/dbutil/upgradetable.go
generated
vendored
40
vendor/maunium.net/go/mautrix/util/dbutil/upgradetable.go
generated
vendored
@@ -1,4 +1,4 @@
|
||||
// Copyright (c) 2022 Tulir Asokan
|
||||
// Copyright (c) 2023 Tulir Asokan
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||
// License, v. 2.0. If a copy of the MPL was not distributed with this
|
||||
@@ -29,14 +29,17 @@ func (ut *UpgradeTable) extend(toSize int) {
|
||||
}
|
||||
}
|
||||
|
||||
func (ut *UpgradeTable) Register(from, to int, message string, txn bool, fn upgradeFunc) {
|
||||
func (ut *UpgradeTable) Register(from, to, compat int, message string, txn bool, fn upgradeFunc) {
|
||||
if from < 0 {
|
||||
from += to
|
||||
}
|
||||
if from < 0 {
|
||||
panic("invalid from value in UpgradeTable.Register() call")
|
||||
}
|
||||
upg := upgrade{message: message, fn: fn, upgradesTo: to, transaction: txn}
|
||||
if compat <= 0 {
|
||||
compat = to
|
||||
}
|
||||
upg := upgrade{message: message, fn: fn, upgradesTo: to, compatVersion: compat, transaction: txn}
|
||||
if len(*ut) == from {
|
||||
*ut = append(*ut, upg)
|
||||
return
|
||||
@@ -55,7 +58,11 @@ func (ut *UpgradeTable) Register(from, to int, message string, txn bool, fn upgr
|
||||
// or
|
||||
//
|
||||
// -- v1: Message
|
||||
var upgradeHeaderRegex = regexp.MustCompile(`^-- (?:v(\d+) -> )?v(\d+): (.+)$`)
|
||||
//
|
||||
// Both syntaxes may also have a compatibility notice before the colon:
|
||||
//
|
||||
// -- v5 (compatible with v3+): Upgrade with backwards compatibility
|
||||
var upgradeHeaderRegex = regexp.MustCompile(`^-- (?:v(\d+) -> )?v(\d+)(?: \(compatible with v(\d+)\+\))?: (.+)$`)
|
||||
|
||||
// To disable wrapping the upgrade in a single transaction, put `--transaction: off` on the second line.
|
||||
//
|
||||
@@ -64,7 +71,7 @@ var upgradeHeaderRegex = regexp.MustCompile(`^-- (?:v(\d+) -> )?v(\d+): (.+)$`)
|
||||
// // do dangerous stuff
|
||||
var transactionDisableRegex = regexp.MustCompile(`^-- transaction: (\w*)`)
|
||||
|
||||
func parseFileHeader(file []byte) (from, to int, message string, txn bool, lines [][]byte, err error) {
|
||||
func parseFileHeader(file []byte) (from, to, compat int, message string, txn bool, lines [][]byte, err error) {
|
||||
lines = bytes.Split(file, []byte("\n"))
|
||||
if len(lines) < 2 {
|
||||
err = errors.New("upgrade file too short")
|
||||
@@ -75,19 +82,22 @@ func parseFileHeader(file []byte) (from, to int, message string, txn bool, lines
|
||||
lines = lines[1:]
|
||||
if match == nil {
|
||||
err = errors.New("header not found")
|
||||
} else if len(match) != 4 {
|
||||
} else if len(match) != 5 {
|
||||
err = errors.New("unexpected number of items in regex match")
|
||||
} else if maybeFrom, err = strconv.Atoi(string(match[1])); len(match[1]) > 0 && err != nil {
|
||||
err = fmt.Errorf("invalid source version: %w", err)
|
||||
} else if to, err = strconv.Atoi(string(match[2])); err != nil {
|
||||
err = fmt.Errorf("invalid target version: %w", err)
|
||||
} else if compat, err = strconv.Atoi(string(match[3])); len(match[3]) > 0 && err != nil {
|
||||
err = fmt.Errorf("invalid compatible version: %w", err)
|
||||
} else {
|
||||
err = nil
|
||||
if len(match[1]) > 0 {
|
||||
from = maybeFrom
|
||||
} else {
|
||||
from = -1
|
||||
}
|
||||
message = string(match[3])
|
||||
message = string(match[4])
|
||||
txn = true
|
||||
match = transactionDisableRegex.FindSubmatch(lines[0])
|
||||
if match != nil {
|
||||
@@ -205,7 +215,7 @@ func splitSQLUpgradeFunc(sqliteData, postgresData string) upgradeFunc {
|
||||
}
|
||||
}
|
||||
|
||||
func parseSplitSQLUpgrade(name string, fs fullFS, skipNames map[string]struct{}) (from, to int, message string, txn bool, fn upgradeFunc) {
|
||||
func parseSplitSQLUpgrade(name string, fs fullFS, skipNames map[string]struct{}) (from, to, compat int, message string, txn bool, fn upgradeFunc) {
|
||||
postgresName := fmt.Sprintf("%s.postgres.sql", name)
|
||||
sqliteName := fmt.Sprintf("%s.sqlite.sql", name)
|
||||
skipNames[postgresName] = struct{}{}
|
||||
@@ -218,15 +228,15 @@ func parseSplitSQLUpgrade(name string, fs fullFS, skipNames map[string]struct{})
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
from, to, message, txn, _, err = parseFileHeader(postgresData)
|
||||
from, to, compat, message, txn, _, err = parseFileHeader(postgresData)
|
||||
if err != nil {
|
||||
panic(fmt.Errorf("failed to parse header in %s: %w", postgresName, err))
|
||||
}
|
||||
sqliteFrom, sqliteTo, sqliteMessage, sqliteTxn, _, err := parseFileHeader(sqliteData)
|
||||
sqliteFrom, sqliteTo, sqliteCompat, sqliteMessage, sqliteTxn, _, err := parseFileHeader(sqliteData)
|
||||
if err != nil {
|
||||
panic(fmt.Errorf("failed to parse header in %s: %w", sqliteName, err))
|
||||
}
|
||||
if from != sqliteFrom || to != sqliteTo {
|
||||
if from != sqliteFrom || to != sqliteTo || compat != sqliteCompat {
|
||||
panic(fmt.Errorf("mismatching versions in postgres and sqlite versions of %s: %d/%d -> %d/%d", name, from, sqliteFrom, to, sqliteTo))
|
||||
} else if message != sqliteMessage {
|
||||
panic(fmt.Errorf("mismatching message in postgres and sqlite versions of %s: %q != %q", name, message, sqliteMessage))
|
||||
@@ -260,14 +270,14 @@ func (ut *UpgradeTable) RegisterFSPath(fs fullFS, dir string) {
|
||||
} else if _, skip := skipNames[file.Name()]; skip {
|
||||
// also do nothing
|
||||
} else if splitName := splitFileNameRegex.FindStringSubmatch(file.Name()); splitName != nil {
|
||||
from, to, message, txn, fn := parseSplitSQLUpgrade(splitName[1], fs, skipNames)
|
||||
ut.Register(from, to, message, txn, fn)
|
||||
from, to, compat, message, txn, fn := parseSplitSQLUpgrade(splitName[1], fs, skipNames)
|
||||
ut.Register(from, to, compat, message, txn, fn)
|
||||
} else if data, err := fs.ReadFile(filepath.Join(dir, file.Name())); err != nil {
|
||||
panic(err)
|
||||
} else if from, to, message, txn, lines, err := parseFileHeader(data); err != nil {
|
||||
} else if from, to, compat, message, txn, lines, err := parseFileHeader(data); err != nil {
|
||||
panic(fmt.Errorf("failed to parse header in %s: %w", file.Name(), err))
|
||||
} else {
|
||||
ut.Register(from, to, message, txn, sqlUpgradeFunc(file.Name(), lines))
|
||||
ut.Register(from, to, compat, message, txn, sqlUpgradeFunc(file.Name(), lines))
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
23
vendor/maunium.net/go/mautrix/util/returnonce.go
generated
vendored
Normal file
23
vendor/maunium.net/go/mautrix/util/returnonce.go
generated
vendored
Normal file
@@ -0,0 +1,23 @@
|
||||
// Copyright (c) 2023 Tulir Asokan
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||
// License, v. 2.0. If a copy of the MPL was not distributed with this
|
||||
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
|
||||
package util
|
||||
|
||||
import "sync"
|
||||
|
||||
// ReturnableOnce is a wrapper for sync.Once that can return a value
|
||||
type ReturnableOnce[Value any] struct {
|
||||
once sync.Once
|
||||
output Value
|
||||
err error
|
||||
}
|
||||
|
||||
func (ronce *ReturnableOnce[Value]) Do(fn func() (Value, error)) (Value, error) {
|
||||
ronce.once.Do(func() {
|
||||
ronce.output, ronce.err = fn()
|
||||
})
|
||||
return ronce.output, ronce.err
|
||||
}
|
||||
139
vendor/maunium.net/go/mautrix/util/ringbuffer.go
generated
vendored
Normal file
139
vendor/maunium.net/go/mautrix/util/ringbuffer.go
generated
vendored
Normal file
@@ -0,0 +1,139 @@
|
||||
// Copyright (c) 2023 Tulir Asokan
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||
// License, v. 2.0. If a copy of the MPL was not distributed with this
|
||||
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
|
||||
package util
|
||||
|
||||
import (
|
||||
"errors"
|
||||
"sync"
|
||||
)
|
||||
|
||||
type pair[Key comparable, Value any] struct {
|
||||
Set bool
|
||||
Key Key
|
||||
Value Value
|
||||
}
|
||||
|
||||
type RingBuffer[Key comparable, Value any] struct {
|
||||
ptr int
|
||||
data []pair[Key, Value]
|
||||
lock sync.RWMutex
|
||||
size int
|
||||
}
|
||||
|
||||
func NewRingBuffer[Key comparable, Value any](size int) *RingBuffer[Key, Value] {
|
||||
return &RingBuffer[Key, Value]{
|
||||
data: make([]pair[Key, Value], size),
|
||||
}
|
||||
}
|
||||
|
||||
var (
|
||||
// StopIteration can be returned by the RingBuffer.Iter or MapRingBuffer callbacks to stop iteration immediately.
|
||||
StopIteration = errors.New("stop iteration")
|
||||
|
||||
// SkipItem can be returned by the MapRingBuffer callback to skip adding a specific item.
|
||||
SkipItem = errors.New("skip item")
|
||||
)
|
||||
|
||||
func (rb *RingBuffer[Key, Value]) unlockedIter(callback func(key Key, val Value) error) error {
|
||||
end := rb.ptr
|
||||
for i := clamp(end-1, len(rb.data)); i != end; i = clamp(i-1, len(rb.data)) {
|
||||
entry := rb.data[i]
|
||||
if !entry.Set {
|
||||
break
|
||||
}
|
||||
err := callback(entry.Key, entry.Value)
|
||||
if err != nil {
|
||||
if errors.Is(err, StopIteration) {
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (rb *RingBuffer[Key, Value]) Iter(callback func(key Key, val Value) error) error {
|
||||
rb.lock.RLock()
|
||||
defer rb.lock.RUnlock()
|
||||
return rb.unlockedIter(callback)
|
||||
}
|
||||
|
||||
func MapRingBuffer[Key comparable, Value, Output any](rb *RingBuffer[Key, Value], callback func(key Key, val Value) (Output, error)) ([]Output, error) {
|
||||
rb.lock.RLock()
|
||||
defer rb.lock.RUnlock()
|
||||
output := make([]Output, 0, rb.size)
|
||||
err := rb.unlockedIter(func(key Key, val Value) error {
|
||||
item, err := callback(key, val)
|
||||
if err != nil {
|
||||
if errors.Is(err, SkipItem) {
|
||||
return nil
|
||||
}
|
||||
return err
|
||||
}
|
||||
output = append(output, item)
|
||||
return nil
|
||||
})
|
||||
return output, err
|
||||
}
|
||||
|
||||
func (rb *RingBuffer[Key, Value]) Size() int {
|
||||
rb.lock.RLock()
|
||||
defer rb.lock.RUnlock()
|
||||
return rb.size
|
||||
}
|
||||
|
||||
func (rb *RingBuffer[Key, Value]) Contains(val Key) bool {
|
||||
_, ok := rb.Get(val)
|
||||
return ok
|
||||
}
|
||||
|
||||
func (rb *RingBuffer[Key, Value]) Get(key Key) (val Value, found bool) {
|
||||
rb.lock.RLock()
|
||||
end := rb.ptr
|
||||
for i := clamp(end-1, len(rb.data)); i != end; i = clamp(i-1, len(rb.data)) {
|
||||
if rb.data[i].Key == key {
|
||||
val = rb.data[i].Value
|
||||
found = true
|
||||
break
|
||||
}
|
||||
}
|
||||
rb.lock.RUnlock()
|
||||
return
|
||||
}
|
||||
|
||||
func (rb *RingBuffer[Key, Value]) Replace(key Key, val Value) bool {
|
||||
rb.lock.Lock()
|
||||
defer rb.lock.Unlock()
|
||||
end := rb.ptr
|
||||
for i := clamp(end-1, len(rb.data)); i != end; i = clamp(i-1, len(rb.data)) {
|
||||
if rb.data[i].Key == key {
|
||||
rb.data[i].Value = val
|
||||
return true
|
||||
}
|
||||
}
|
||||
return false
|
||||
}
|
||||
|
||||
func (rb *RingBuffer[Key, Value]) Push(key Key, val Value) {
|
||||
rb.lock.Lock()
|
||||
rb.data[rb.ptr] = pair[Key, Value]{Key: key, Value: val, Set: true}
|
||||
rb.ptr = (rb.ptr + 1) % len(rb.data)
|
||||
if rb.size < len(rb.data) {
|
||||
rb.size++
|
||||
}
|
||||
rb.lock.Unlock()
|
||||
}
|
||||
|
||||
func clamp(index, len int) int {
|
||||
if index < 0 {
|
||||
return len + index
|
||||
} else if index >= len {
|
||||
return len - index
|
||||
} else {
|
||||
return index
|
||||
}
|
||||
}
|
||||
94
vendor/maunium.net/go/mautrix/util/syncmap.go
generated
vendored
Normal file
94
vendor/maunium.net/go/mautrix/util/syncmap.go
generated
vendored
Normal file
@@ -0,0 +1,94 @@
|
||||
// Copyright (c) 2023 Tulir Asokan
|
||||
//
|
||||
// This Source Code Form is subject to the terms of the Mozilla Public
|
||||
// License, v. 2.0. If a copy of the MPL was not distributed with this
|
||||
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
|
||||
|
||||
package util
|
||||
|
||||
import "sync"
|
||||
|
||||
// SyncMap is a simple map with a built-in mutex.
|
||||
type SyncMap[Key comparable, Value any] struct {
|
||||
data map[Key]Value
|
||||
lock sync.RWMutex
|
||||
}
|
||||
|
||||
func NewSyncMap[Key comparable, Value any]() *SyncMap[Key, Value] {
|
||||
return &SyncMap[Key, Value]{
|
||||
data: make(map[Key]Value),
|
||||
}
|
||||
}
|
||||
|
||||
// Set stores a value in the map.
|
||||
func (sm *SyncMap[Key, Value]) Set(key Key, value Value) {
|
||||
sm.Swap(key, value)
|
||||
}
|
||||
|
||||
// Swap sets a value in the map and returns the old value.
|
||||
//
|
||||
// The boolean return parameter is true if the value already existed, false if not.
|
||||
func (sm *SyncMap[Key, Value]) Swap(key Key, value Value) (oldValue Value, wasReplaced bool) {
|
||||
sm.lock.Lock()
|
||||
oldValue, wasReplaced = sm.data[key]
|
||||
sm.data[key] = value
|
||||
sm.lock.Unlock()
|
||||
return
|
||||
}
|
||||
|
||||
// Delete removes a key from the map.
|
||||
func (sm *SyncMap[Key, Value]) Delete(key Key) {
|
||||
sm.Pop(key)
|
||||
}
|
||||
|
||||
// Pop removes a key from the map and returns the old value.
|
||||
//
|
||||
// The boolean return parameter is the same as with normal Go map access (true if the key exists, false if not).
|
||||
func (sm *SyncMap[Key, Value]) Pop(key Key) (value Value, ok bool) {
|
||||
sm.lock.Lock()
|
||||
value, ok = sm.data[key]
|
||||
delete(sm.data, key)
|
||||
sm.lock.Unlock()
|
||||
return
|
||||
}
|
||||
|
||||
// Get gets a value in the map.
|
||||
//
|
||||
// The boolean return parameter is the same as with normal Go map access (true if the key exists, false if not).
|
||||
func (sm *SyncMap[Key, Value]) Get(key Key) (value Value, ok bool) {
|
||||
sm.lock.RLock()
|
||||
value, ok = sm.data[key]
|
||||
sm.lock.RUnlock()
|
||||
return
|
||||
}
|
||||
|
||||
// GetOrSet gets a value in the map if the key already exists, otherwise inserts the given value and returns it.
|
||||
//
|
||||
// The boolean return parameter is true if the key already exists, and false if the given value was inserted.
|
||||
func (sm *SyncMap[Key, Value]) GetOrSet(key Key, value Value) (actual Value, wasGet bool) {
|
||||
sm.lock.Lock()
|
||||
defer sm.lock.Unlock()
|
||||
actual, wasGet = sm.data[key]
|
||||
if wasGet {
|
||||
return
|
||||
}
|
||||
sm.data[key] = value
|
||||
actual = value
|
||||
return
|
||||
}
|
||||
|
||||
// Clone returns a copy of the map.
|
||||
func (sm *SyncMap[Key, Value]) Clone() *SyncMap[Key, Value] {
|
||||
return &SyncMap[Key, Value]{data: sm.CopyData()}
|
||||
}
|
||||
|
||||
// CopyData returns a copy of the data in the map as a normal (non-atomic) map.
|
||||
func (sm *SyncMap[Key, Value]) CopyData() map[Key]Value {
|
||||
sm.lock.RLock()
|
||||
copied := make(map[Key]Value, len(sm.data))
|
||||
for key, value := range sm.data {
|
||||
copied[key] = value
|
||||
}
|
||||
sm.lock.RUnlock()
|
||||
return copied
|
||||
}
|
||||
30
vendor/maunium.net/go/mautrix/version.go
generated
vendored
30
vendor/maunium.net/go/mautrix/version.go
generated
vendored
@@ -1,5 +1,31 @@
|
||||
package mautrix
|
||||
|
||||
const Version = "v0.13.0"
|
||||
import (
|
||||
"fmt"
|
||||
"regexp"
|
||||
"runtime"
|
||||
"strings"
|
||||
)
|
||||
|
||||
var DefaultUserAgent = "mautrix-go/" + Version
|
||||
const Version = "v0.15.2"
|
||||
|
||||
var GoModVersion = ""
|
||||
var Commit = ""
|
||||
var VersionWithCommit = Version
|
||||
|
||||
var DefaultUserAgent = "mautrix-go/" + Version + " go/" + strings.TrimPrefix(runtime.Version(), "go")
|
||||
|
||||
var goModVersionRegex = regexp.MustCompile(`v.+\d{14}-([0-9a-f]{12})`)
|
||||
|
||||
func init() {
|
||||
if GoModVersion != "" {
|
||||
match := goModVersionRegex.FindStringSubmatch(GoModVersion)
|
||||
if match != nil {
|
||||
Commit = match[1]
|
||||
}
|
||||
}
|
||||
if Commit != "" {
|
||||
VersionWithCommit = fmt.Sprintf("%s+dev.%s", Version, Commit[:8])
|
||||
DefaultUserAgent = strings.Replace(DefaultUserAgent, "mautrix-go/"+Version, "mautrix-go/"+VersionWithCommit, 1)
|
||||
}
|
||||
}
|
||||
|
||||
2
vendor/maunium.net/go/mautrix/versions.go
generated
vendored
2
vendor/maunium.net/go/mautrix/versions.go
generated
vendored
@@ -71,6 +71,8 @@ var (
|
||||
SpecV13 = MustParseSpecVersion("v1.3")
|
||||
SpecV14 = MustParseSpecVersion("v1.4")
|
||||
SpecV15 = MustParseSpecVersion("v1.5")
|
||||
SpecV16 = MustParseSpecVersion("v1.6")
|
||||
SpecV17 = MustParseSpecVersion("v1.7")
|
||||
)
|
||||
|
||||
func (svf SpecVersionFormat) String() string {
|
||||
|
||||
Reference in New Issue
Block a user