BREAKING: update mautrix to 0.15.x

This commit is contained in:
Aine
2023-06-01 14:32:20 +00:00
parent a6b20a75ab
commit 2bdb8ca635
222 changed files with 7851 additions and 23986 deletions

View File

@@ -1,11 +0,0 @@
lint:
image: registry.gitlab.com/etke.cc/base
script:
- golangci-lint run ./...
unit:
image: registry.gitlab.com/etke.cc/base
script:
- go test -coverprofile=cover.out ./...
- go tool cover -func=cover.out
- rm -f cover.out

View File

@@ -1,165 +0,0 @@
GNU LESSER GENERAL PUBLIC LICENSE
Version 3, 29 June 2007
Copyright (C) 2007 Free Software Foundation, Inc. <https://fsf.org/>
Everyone is permitted to copy and distribute verbatim copies
of this license document, but changing it is not allowed.
This version of the GNU Lesser General Public License incorporates
the terms and conditions of version 3 of the GNU General Public
License, supplemented by the additional permissions listed below.
0. Additional Definitions.
As used herein, "this License" refers to version 3 of the GNU Lesser
General Public License, and the "GNU GPL" refers to version 3 of the GNU
General Public License.
"The Library" refers to a covered work governed by this License,
other than an Application or a Combined Work as defined below.
An "Application" is any work that makes use of an interface provided
by the Library, but which is not otherwise based on the Library.
Defining a subclass of a class defined by the Library is deemed a mode
of using an interface provided by the Library.
A "Combined Work" is a work produced by combining or linking an
Application with the Library. The particular version of the Library
with which the Combined Work was made is also called the "Linked
Version".
The "Minimal Corresponding Source" for a Combined Work means the
Corresponding Source for the Combined Work, excluding any source code
for portions of the Combined Work that, considered in isolation, are
based on the Application, and not on the Linked Version.
The "Corresponding Application Code" for a Combined Work means the
object code and/or source code for the Application, including any data
and utility programs needed for reproducing the Combined Work from the
Application, but excluding the System Libraries of the Combined Work.
1. Exception to Section 3 of the GNU GPL.
You may convey a covered work under sections 3 and 4 of this License
without being bound by section 3 of the GNU GPL.
2. Conveying Modified Versions.
If you modify a copy of the Library, and, in your modifications, a
facility refers to a function or data to be supplied by an Application
that uses the facility (other than as an argument passed when the
facility is invoked), then you may convey a copy of the modified
version:
a) under this License, provided that you make a good faith effort to
ensure that, in the event an Application does not supply the
function or data, the facility still operates, and performs
whatever part of its purpose remains meaningful, or
b) under the GNU GPL, with none of the additional permissions of
this License applicable to that copy.
3. Object Code Incorporating Material from Library Header Files.
The object code form of an Application may incorporate material from
a header file that is part of the Library. You may convey such object
code under terms of your choice, provided that, if the incorporated
material is not limited to numerical parameters, data structure
layouts and accessors, or small macros, inline functions and templates
(ten or fewer lines in length), you do both of the following:
a) Give prominent notice with each copy of the object code that the
Library is used in it and that the Library and its use are
covered by this License.
b) Accompany the object code with a copy of the GNU GPL and this license
document.
4. Combined Works.
You may convey a Combined Work under terms of your choice that,
taken together, effectively do not restrict modification of the
portions of the Library contained in the Combined Work and reverse
engineering for debugging such modifications, if you also do each of
the following:
a) Give prominent notice with each copy of the Combined Work that
the Library is used in it and that the Library and its use are
covered by this License.
b) Accompany the Combined Work with a copy of the GNU GPL and this license
document.
c) For a Combined Work that displays copyright notices during
execution, include the copyright notice for the Library among
these notices, as well as a reference directing the user to the
copies of the GNU GPL and this license document.
d) Do one of the following:
0) Convey the Minimal Corresponding Source under the terms of this
License, and the Corresponding Application Code in a form
suitable for, and under terms that permit, the user to
recombine or relink the Application with a modified version of
the Linked Version to produce a modified Combined Work, in the
manner specified by section 6 of the GNU GPL for conveying
Corresponding Source.
1) Use a suitable shared library mechanism for linking with the
Library. A suitable mechanism is one that (a) uses at run time
a copy of the Library already present on the user's computer
system, and (b) will operate properly with a modified version
of the Library that is interface-compatible with the Linked
Version.
e) Provide Installation Information, but only if you would otherwise
be required to provide such information under section 6 of the
GNU GPL, and only to the extent that such information is
necessary to install and execute a modified version of the
Combined Work produced by recombining or relinking the
Application with a modified version of the Linked Version. (If
you use option 4d0, the Installation Information must accompany
the Minimal Corresponding Source and Corresponding Application
Code. If you use option 4d1, you must provide the Installation
Information in the manner specified by section 6 of the GNU GPL
for conveying Corresponding Source.)
5. Combined Libraries.
You may place library facilities that are a work based on the
Library side by side in a single library together with other library
facilities that are not Applications and are not covered by this
License, and convey such a combined library under terms of your
choice, if you do both of the following:
a) Accompany the combined library with a copy of the same work based
on the Library, uncombined with any other library facilities,
conveyed under the terms of this License.
b) Give prominent notice with the combined library that part of it
is a work based on the Library, and explaining where to find the
accompanying uncombined form of the same work.
6. Revised Versions of the GNU Lesser General Public License.
The Free Software Foundation may publish revised and/or new versions
of the GNU Lesser General Public License from time to time. Such new
versions will be similar in spirit to the present version, but may
differ in detail to address new problems or concerns.
Each version is given a distinguishing version number. If the
Library as you received it specifies that a certain numbered version
of the GNU Lesser General Public License "or any later version"
applies to it, you have the option of following the terms and
conditions either of that published version or of any later version
published by the Free Software Foundation. If the Library as you
received it does not specify a version number of the GNU Lesser
General Public License, you may choose any version of the GNU Lesser
General Public License ever published by the Free Software Foundation.
If the Library as you received it specifies that a proxy can decide
whether future versions of the GNU Lesser General Public License shall
apply, that proxy's public statement of acceptance of any version is
permanent authorization for you to choose that version for the
Library.

View File

@@ -1,6 +0,0 @@
# logger
Simple go logger, based on [log](https://pkg.go.dev/log) with following features:
* implements mautrix-go [Logger](https://pkg.go.dev/maunium.net/go/mautrix#Logger), [WarnLogger](https://pkg.go.dev/maunium.net/go/mautrix#WarnLogger), [crypto Logger](https://pkg.go.dev/maunium.net/go/mautrix/crypto#Logger)
* integrated with [sentry](https://sentry.io) - automatically add breadcrumbs for any log entry

View File

@@ -1,188 +0,0 @@
package logger
import (
"fmt"
"log"
"os"
"strings"
"github.com/getsentry/sentry-go"
)
// Logger struct
type Logger struct {
log *log.Logger
hub *sentry.Hub
level int
}
const (
// TRACE level
TRACE int = iota
// DEBUG level
DEBUG
// INFO level
INFO
// WARNING level
WARNING
// ERROR level
ERROR
// FATAL level
FATAL
)
var (
txtLevelMap = map[string]int{
"TRACE": TRACE,
"DEBUG": DEBUG,
"INFO": INFO,
"WARNING": WARNING,
"ERROR": ERROR,
"FATAL": FATAL,
}
levelMap = map[int]string{
TRACE: "TRACE",
DEBUG: "DEBUG",
INFO: "INFO",
WARNING: "WARNING",
ERROR: "ERROR",
FATAL: "FATAL",
}
sentryLevelMap = map[int]sentry.Level{
TRACE: sentry.LevelDebug,
DEBUG: sentry.LevelDebug,
INFO: sentry.LevelInfo,
WARNING: sentry.LevelWarning,
ERROR: sentry.LevelError,
FATAL: sentry.LevelFatal,
}
)
// New creates new Logger object
func New(prefix string, level string, sentryHub ...*sentry.Hub) *Logger {
levelID, ok := txtLevelMap[strings.ToUpper(level)]
if !ok {
levelID = INFO
}
var hub *sentry.Hub
if len(sentryHub) > 0 {
hub = sentryHub[0]
}
return &Logger{log: log.New(os.Stdout, prefix, 0), level: levelID, hub: hub}
}
// GetHub returns sentry hub (either attached to the logger when called New() or current sentry hub)
func (l *Logger) GetHub() *sentry.Hub {
if l.hub == nil {
return sentry.CurrentHub()
}
return l.hub
}
// GetLog returns underlying Logger object, useful in cases where log.Logger required
func (l *Logger) GetLog() *log.Logger {
return l.log
}
// GetLevel (current)
func (l *Logger) GetLevel() string {
return levelMap[l.level]
}
// Fatal log and exit
func (l *Logger) Fatal(message string, args ...interface{}) {
l.log.Panicln("FATAL", fmt.Sprintf(message, args...))
}
// Error log
func (l *Logger) Error(message string, args ...interface{}) {
// do not recover
if strings.HasPrefix(message, "recovery()") {
return
}
message = fmt.Sprintf(message, args...)
l.GetHub().AddBreadcrumb(&sentry.Breadcrumb{
Category: l.log.Prefix(),
Message: message,
Level: sentryLevelMap[ERROR],
}, nil)
if l.level > ERROR {
return
}
l.log.Println("ERROR", message)
}
// Warn log
func (l *Logger) Warn(message string, args ...interface{}) {
message = fmt.Sprintf(message, args...)
l.GetHub().AddBreadcrumb(&sentry.Breadcrumb{
Category: l.log.Prefix(),
Message: message,
Level: sentryLevelMap[WARNING],
}, nil)
if l.level > WARNING {
return
}
l.log.Println("WARNING", message)
}
// Warnfln for mautrix.Logger
func (l *Logger) Warnfln(message string, args ...interface{}) {
l.Warn(message, args...)
}
// Info log
func (l *Logger) Info(message string, args ...interface{}) {
message = fmt.Sprintf(message, args...)
l.GetHub().AddBreadcrumb(&sentry.Breadcrumb{
Category: l.log.Prefix(),
Message: message,
Level: sentryLevelMap[INFO],
}, nil)
if l.level > INFO {
return
}
l.log.Println("INFO", message)
}
// Debug log
func (l *Logger) Debug(message string, args ...interface{}) {
message = fmt.Sprintf(message, args...)
l.GetHub().AddBreadcrumb(&sentry.Breadcrumb{
Category: l.log.Prefix(),
Message: message,
Level: sentryLevelMap[DEBUG],
}, nil)
if l.level > DEBUG {
return
}
l.log.Println("DEBUG", message)
}
// Debugfln for mautrix.Logger
func (l *Logger) Debugfln(message string, args ...interface{}) {
l.Debug(message, args...)
}
// Trace log
func (l *Logger) Trace(message string, args ...interface{}) {
message = fmt.Sprintf(message, args...)
l.GetHub().AddBreadcrumb(&sentry.Breadcrumb{
Category: l.log.Prefix(),
Message: message,
Level: sentryLevelMap[TRACE],
}, nil)
if l.level > TRACE {
return
}
l.log.Println("TRACE", message)
}

View File

@@ -1,26 +0,0 @@
# update go dependencies
update:
go get .
go get -u maunium.net/go/mautrix
go mod tidy
mock:
-@rm -rf mocks
@mockery --all
# run linter
lint:
golangci-lint run ./...
# run linter and fix issues if possible
lintfix:
golangci-lint run --fix ./...
vuln:
govulncheck ./...
# run unit tests
test:
@go test ${BUILDFLAGS} -coverprofile=cover.out ./...
@go tool cover -func=cover.out
-@rm -f cover.out

View File

@@ -10,7 +10,6 @@ import (
func (l *Linkpearl) GetAccountData(name string) (map[string]string, error) {
cached, ok := l.acc.Get(name)
if ok {
l.logAccountData(l.log.Debug, "GetAccountData(%q) cached:", cached, name)
if cached == nil {
return map[string]string{}, nil
}
@@ -20,7 +19,6 @@ func (l *Linkpearl) GetAccountData(name string) (map[string]string, error) {
var data map[string]string
err := l.GetClient().GetAccountData(name, &data)
if err != nil {
l.logAccountData(l.log.Debug, "GetAccountData(%q) error: %v", nil, name, err)
data = map[string]string{}
if strings.Contains(err.Error(), "M_NOT_FOUND") {
l.acc.Add(name, data)
@@ -29,7 +27,6 @@ func (l *Linkpearl) GetAccountData(name string) (map[string]string, error) {
return data, err
}
data = l.decryptAccountData(data)
l.logAccountData(l.log.Debug, "GetAccountData(%q):", data, name)
l.acc.Add(name, data)
return data, err
@@ -39,7 +36,6 @@ func (l *Linkpearl) GetAccountData(name string) (map[string]string, error) {
func (l *Linkpearl) SetAccountData(name string, data map[string]string) error {
l.acc.Add(name, data)
l.logAccountData(l.log.Debug, "SetAccountData(%q):", data, name)
data = l.encryptAccountData(data)
return l.GetClient().SetAccountData(name, data)
}
@@ -49,7 +45,6 @@ func (l *Linkpearl) GetRoomAccountData(roomID id.RoomID, name string) (map[strin
key := roomID.String() + name
cached, ok := l.acc.Get(key)
if ok {
l.logAccountData(l.log.Debug, "GetRoomAccountData(%q, %q) cached:", cached, roomID, name)
if cached == nil {
return map[string]string{}, nil
}
@@ -59,7 +54,6 @@ func (l *Linkpearl) GetRoomAccountData(roomID id.RoomID, name string) (map[strin
var data map[string]string
err := l.GetClient().GetRoomAccountData(roomID, name, &data)
if err != nil {
l.logAccountData(l.log.Debug, "GetRoomAccountData(%q, %q) error: %v", nil, roomID, name, err)
data = map[string]string{}
if strings.Contains(err.Error(), "M_NOT_FOUND") {
l.acc.Add(key, data)
@@ -68,7 +62,6 @@ func (l *Linkpearl) GetRoomAccountData(roomID id.RoomID, name string) (map[strin
return data, err
}
data = l.decryptAccountData(data)
l.logAccountData(l.log.Debug, "GetRoomAccountData(%q, %q):", data, roomID, name)
l.acc.Add(key, data)
return data, err
@@ -79,7 +72,6 @@ func (l *Linkpearl) SetRoomAccountData(roomID id.RoomID, name string, data map[s
key := roomID.String() + name
l.acc.Add(key, data)
l.logAccountData(l.log.Debug, "SetRoomAccountData(%q, %q):", data, roomID, name)
data = l.encryptAccountData(data)
return l.GetClient().SetRoomAccountData(roomID, name, data)
}
@@ -93,11 +85,11 @@ func (l *Linkpearl) encryptAccountData(data map[string]string) map[string]string
for k, v := range data {
ek, err := l.acr.Encrypt(k)
if err != nil {
l.log.Error("cannot encrypt account data (key=%q): %v", k, err)
l.log.Error().Err(err).Str("key", k).Msg("cannot encrypt account data")
}
ev, err := l.acr.Encrypt(v)
if err != nil {
l.log.Error("cannot encrypt account data (key=%q): %v", k, err)
l.log.Error().Err(err).Str("key", k).Msg("cannot encrypt account data")
}
encrypted[ek] = ev // worst case: plaintext value
}
@@ -114,35 +106,14 @@ func (l *Linkpearl) decryptAccountData(data map[string]string) map[string]string
for ek, ev := range data {
k, err := l.acr.Decrypt(ek)
if err != nil {
l.log.Error("cannot decrypt account data (key=%q): %v", k, err)
l.log.Error().Err(err).Str("key", k).Msg("cannot decrypt account data")
}
v, err := l.acr.Decrypt(ev)
if err != nil {
l.log.Error("cannot decrypt account data (key=%q): %v", k, err)
l.log.Error().Err(err).Str("key", k).Msg("cannot decrypt account data")
}
decrypted[k] = v // worst case: encrypted value, usual case: migration from plaintext to encrypted account data
}
return decrypted
}
func (l *Linkpearl) logAccountData(method func(string, ...any), message string, data map[string]string, args ...any) {
if len(data) == 0 {
method(message, args...)
return
}
safeData := make(map[string]string, len(data))
for k, v := range data {
sv, ok := l.aclr[k]
if ok {
safeData[k] = sv
continue
}
safeData[k] = v
}
args = append(args, safeData)
method(message+" %+v", args...)
}

View File

@@ -1,74 +0,0 @@
package linkpearl
import (
"errors"
"maunium.net/go/mautrix"
"maunium.net/go/mautrix/id"
)
func (l *Linkpearl) login(username string, password string) error {
if err := l.restoreSession(); err == nil {
l.log.Debug("session restored successfully")
return nil
}
l.log.Debug("auth using login and password...")
_, err := l.api.Login(&mautrix.ReqLogin{
Type: "m.login.password",
Identifier: mautrix.UserIdentifier{
Type: mautrix.IdentifierTypeUser,
User: username,
},
Password: password,
StoreCredentials: true,
})
if err != nil {
l.log.Error("cannot authorize using login and password: %v", err)
return err
}
l.store.SaveSession(l.api.UserID, l.api.DeviceID, l.api.AccessToken)
return nil
}
// restoreSession tries to load previous active session token from db (if any)
func (l *Linkpearl) restoreSession() error {
l.log.Debug("restoring previous session...")
userID, deviceID, token := l.store.LoadSession()
if userID == "" || deviceID == "" || token == "" {
return errors.New("cannot restore session from db")
}
if !l.validateSession(userID, deviceID, token) {
return errors.New("restored session is invalid")
}
l.api.AccessToken = token
l.api.UserID = userID
l.api.DeviceID = deviceID
return nil
}
func (l *Linkpearl) validateSession(userID id.UserID, deviceID id.DeviceID, token string) bool {
valid := true
// preserve current values
currentToken := l.api.AccessToken
currentUserID := l.api.UserID
currentDeviceID := l.api.DeviceID
// set new values
l.api.AccessToken = token
l.api.UserID = userID
l.api.DeviceID = deviceID
if _, err := l.api.GetOwnPresence(); err != nil {
l.log.Debug("previous session token was not found or invalid: %v", err)
valid = false
}
// restore original values
l.api.AccessToken = currentToken
l.api.UserID = currentUserID
l.api.DeviceID = currentDeviceID
return valid
}

View File

@@ -1,11 +1,10 @@
// Package config was added to store cross-package structs and interfaces.
package config
package linkpearl
import (
"database/sql"
"github.com/rs/zerolog"
"maunium.net/go/mautrix"
"maunium.net/go/mautrix/crypto"
"maunium.net/go/mautrix/event"
)
@@ -32,25 +31,11 @@ type Config struct {
// AccountDataSecret (Password) for encryption
AccountDataSecret string
// AccountDataLogReplace contains map of field name => value
// that will be used to replace mentioned account data fields with provided values
// when printing in logs (DEBUG, TRACE)
AccountDataLogReplace map[string]string
// MaxRetries for operations like auto join
MaxRetries int
// NoEncryption disabled encryption support
NoEncryption bool
// LPLogger used for linkpearl's glue code
LPLogger Logger
// APILogger used for matrix CS API calls
APILogger Logger
// StoreLogger used for persistent store
StoreLogger Logger
// CryptoLogger used for OLM machine
CryptoLogger Logger
// Logger
Logger zerolog.Logger
// DB object
DB *sql.DB
@@ -58,10 +43,16 @@ type Config struct {
Dialect string
}
// Logger implementation of crypto.Logger and mautrix.Logger
type Logger interface {
crypto.Logger
mautrix.WarnLogger
Info(message string, args ...interface{})
// LoginAs for cryptohelper
func (cfg *Config) LoginAs() *mautrix.ReqLogin {
return &mautrix.ReqLogin{
Type: mautrix.AuthTypePassword,
Identifier: mautrix.UserIdentifier{
Type: mautrix.IdentifierTypeUser,
User: cfg.Login,
},
Password: cfg.Password,
StoreCredentials: true,
StoreHomeserverURL: true,
}
}

26
vendor/gitlab.com/etke.cc/linkpearl/justfile generated vendored Normal file
View File

@@ -0,0 +1,26 @@
# show help by default
default:
@just --list --justfile {{ justfile() }}
# update go deps
update:
go get .
go get -u maunium.net/go/mautrix
go mod tidy
# run linter
lint:
golangci-lint run ./...
# automatically fix liter issues
lintfix:
golangci-lint run --fix ./...
vuln:
govulncheck ./...
# run unit tests
test:
@go test ${BUILDFLAGS} -coverprofile=cover.out ./...
@go tool cover -func=cover.out
-@rm -f cover.out

View File

@@ -5,12 +5,12 @@ import (
"database/sql"
lru "github.com/hashicorp/golang-lru/v2"
"github.com/rs/zerolog"
"maunium.net/go/mautrix"
"maunium.net/go/mautrix/crypto"
"maunium.net/go/mautrix/crypto/cryptohelper"
"maunium.net/go/mautrix/event"
"gitlab.com/etke.cc/linkpearl/config"
"gitlab.com/etke.cc/linkpearl/store"
"maunium.net/go/mautrix/util/dbutil"
)
const (
@@ -22,14 +22,12 @@ const (
// Linkpearl object
type Linkpearl struct {
db *sql.DB
acc *lru.Cache[string, map[string]string]
acr *Crypter
aclr map[string]string
log config.Logger
api *mautrix.Client
olm *crypto.OlmMachine
store *store.Store
db *sql.DB
ch *cryptohelper.CryptoHelper
acc *lru.Cache[string, map[string]string]
acr *Crypter
log zerolog.Logger
api *mautrix.Client
joinPermit func(*event.Event) bool
autoleave bool
@@ -41,10 +39,7 @@ type ReqPresence struct {
StatusMsg string `json:"status_msg,omitempty"`
}
func setDefaults(cfg *config.Config) {
if cfg.AccountDataLogReplace == nil {
cfg.AccountDataLogReplace = make(map[string]string)
}
func setDefaults(cfg *Config) {
if cfg.MaxRetries == 0 {
cfg.MaxRetries = DefaultMaxRetries
}
@@ -66,13 +61,13 @@ func initCrypter(secret string) (*Crypter, error) {
}
// New linkpearl
func New(cfg *config.Config) (*Linkpearl, error) {
func New(cfg *Config) (*Linkpearl, error) {
setDefaults(cfg)
api, err := mautrix.NewClient(cfg.Homeserver, "", "")
if err != nil {
return nil, err
}
api.Logger = cfg.APILogger
api.Log = cfg.Logger
acc, _ := lru.New[string, map[string]string](cfg.AccountDataCache) //nolint:errcheck // addressed in setDefaults()
acr, err := initCrypter(cfg.AccountDataSecret)
@@ -84,35 +79,27 @@ func New(cfg *config.Config) (*Linkpearl, error) {
db: cfg.DB,
acc: acc,
acr: acr,
aclr: cfg.AccountDataLogReplace,
api: api,
log: cfg.LPLogger,
log: cfg.Logger,
joinPermit: cfg.JoinPermit,
autoleave: cfg.AutoLeave,
maxretries: cfg.MaxRetries,
}
storer := store.New(cfg.DB, cfg.Dialect, cfg.StoreLogger)
if err = storer.CreateTables(); err != nil {
db, err := dbutil.NewWithDB(cfg.DB, cfg.Dialect)
if err != nil {
return nil, err
}
lp.store = storer
lp.api.Store = storer
if err = lp.login(cfg.Login, cfg.Password); err != nil {
db.Log = dbutil.ZeroLogger(cfg.Logger)
lp.ch, err = cryptohelper.NewCryptoHelper(lp.api, []byte(cfg.Login), db)
if err != nil {
return nil, err
}
if !cfg.NoEncryption {
if err = lp.store.WithCrypto(lp.api.UserID, lp.api.DeviceID, cfg.StoreLogger); err != nil {
return nil, err
}
lp.olm = crypto.NewOlmMachine(lp.api, cfg.CryptoLogger, lp.store, lp.store)
if err = lp.olm.Load(); err != nil {
return nil, err
}
lp.ch.LoginAs = cfg.LoginAs()
if err = lp.ch.Init(); err != nil {
return nil, err
}
lp.api.Crypto = lp.ch
return lp, nil
}
@@ -126,14 +113,9 @@ func (l *Linkpearl) GetDB() *sql.DB {
return l.db
}
// GetStore returns underlying persistent store object, compatible with crypto.Store, crypto.StateStore and mautrix.Storer
func (l *Linkpearl) GetStore() *store.Store {
return l.store
}
// GetMachine returns underlying OLM machine
func (l *Linkpearl) GetMachine() *crypto.OlmMachine {
return l.olm
return l.ch.Machine()
}
// GetAccountDataCrypter returns crypter used for account data (if any)
@@ -165,21 +147,23 @@ func (l *Linkpearl) Start(optionalStatusMsg ...string) error {
err := l.SetPresence(event.PresenceOnline, statusMsg)
if err != nil {
l.log.Error("cannot set presence: %v", err)
l.log.Error().Err(err).Msg("cannot set presence")
}
defer l.Stop()
l.log.Info("client has been started")
l.log.Info().Msg("client has been started")
return l.api.Sync()
}
// Stop the client
func (l *Linkpearl) Stop() {
l.log.Debug("stopping the client")
err := l.api.SetPresence(event.PresenceOffline)
if err != nil {
l.log.Error("cannot set presence: %v", err)
l.log.Debug().Msg("stopping the client")
if err := l.api.SetPresence(event.PresenceOffline); err != nil {
l.log.Error().Err(err).Msg("cannot set presence")
}
l.api.StopSync()
l.log.Info("client has been stopped")
if err := l.ch.Close(); err != nil {
l.log.Error().Err(err).Msg("cannot close crypto helper")
}
l.log.Info().Msg("client has been stopped")
}

View File

@@ -4,27 +4,21 @@ import (
"fmt"
"maunium.net/go/mautrix"
"maunium.net/go/mautrix/crypto"
"maunium.net/go/mautrix/event"
"maunium.net/go/mautrix/format"
"maunium.net/go/mautrix/id"
)
// Send a message to the roomID and automatically try to encrypt it, if the destination room is encrypted
//
//nolint:unparam // it's public interface
func (l *Linkpearl) Send(roomID id.RoomID, content interface{}) (id.EventID, error) {
if !l.store.IsEncrypted(roomID) {
l.log.Debug("room %q is not encrypted", roomID)
return l.SendPlaintext(roomID, content)
}
l.log.Debug("room %q is encrypted", roomID)
encrypted, err := l.EncryptEvent(roomID, content)
l.log.Debug().Str("roomID", roomID.String()).Any("content", content).Msg("sending event")
resp, err := l.api.SendMessageEvent(roomID, event.EventMessage, content)
if err != nil {
l.log.Error("cannot encrypt message: %v, sending plaintext...", roomID, err)
return l.SendPlaintext(roomID, content)
return "", err
}
return l.SendEncrypted(roomID, encrypted)
return resp.EventID, nil
}
// SendNotice to a room with optional thread relation
@@ -39,7 +33,7 @@ func (l *Linkpearl) SendNotice(roomID id.RoomID, threadID id.EventID, message st
_, err := l.Send(roomID, &content)
if err != nil {
l.log.Error("cannot send a notice into room %q: %v", roomID, err)
l.log.Error().Err(err).Str("roomID", roomID.String()).Msg("cannot send a notice int the room")
}
}
@@ -47,7 +41,7 @@ func (l *Linkpearl) SendNotice(roomID id.RoomID, threadID id.EventID, message st
func (l *Linkpearl) SendFile(roomID id.RoomID, req *mautrix.ReqUploadMedia, msgtype event.MessageType, relation *event.RelatesTo) error {
resp, err := l.GetClient().UploadMedia(*req)
if err != nil {
l.log.Error("cannot upload file %q: %v", req.FileName, err)
l.log.Error().Err(err).Str("file", req.FileName).Msg("cannot upload file")
return err
}
_, err = l.Send(roomID, &event.Content{
@@ -59,43 +53,8 @@ func (l *Linkpearl) SendFile(roomID id.RoomID, req *mautrix.ReqUploadMedia, msgt
},
})
if err != nil {
l.log.Error("cannot send uploaded file: %q: %v", req.FileName, err)
l.log.Error().Err(err).Str("file", req.FileName).Msg("cannot send uploaded file")
}
return err
}
// SendPlaintext sends plaintext event only
func (l *Linkpearl) SendPlaintext(roomID id.RoomID, content interface{}) (id.EventID, error) {
l.log.Debug("sending plaintext event to %q: %+v", roomID, content)
resp, err := l.api.SendMessageEvent(roomID, event.EventMessage, content)
if err != nil {
return "", err
}
return resp.EventID, nil
}
// SendEncrypted sends encrypted event only
func (l *Linkpearl) SendEncrypted(roomID id.RoomID, content interface{}) (id.EventID, error) {
l.log.Debug("sending encrypted event to %q: %+v", roomID, content)
resp, err := l.api.SendMessageEvent(roomID, event.EventEncrypted, content)
if err != nil {
return "", err
}
return resp.EventID, nil
}
// EncryptEvent before sending
func (l *Linkpearl) EncryptEvent(roomID id.RoomID, content interface{}) (*event.EncryptedEventContent, error) {
l.log.Debug("encrypting event %+v", content)
encrypted, err := l.olm.EncryptMegolmEvent(roomID, event.EventMessage, content)
if crypto.IsShareError(err) {
err = l.olm.ShareGroupSession(roomID, l.store.GetRoomMembers(roomID))
if err != nil {
return nil, err
}
encrypted, err = l.olm.EncryptMegolmEvent(roomID, event.EventMessage, content)
}
return encrypted, err
}

View File

@@ -1,214 +0,0 @@
package store
import (
"maunium.net/go/mautrix/crypto"
"maunium.net/go/mautrix/event"
"maunium.net/go/mautrix/id"
)
// NOTE: functions in that file are for crypto.Store implementation
// ref: https://pkg.go.dev/maunium.net/go/mautrix/crypto#Store
// Flush does nothing for this implementation as data is already persisted in the database.
// nolint // interface cannot be changed
func (s *Store) Flush() error {
s.log.Debug("flushing crypto store")
return nil
}
// PutNextBatch stores the next sync batch token for the current account.
func (s *Store) PutNextBatch(nextBatch string) error {
s.log.Debug("storing next batch token")
return s.s.PutNextBatch(nextBatch)
}
// GetNextBatch retrieves the next sync batch token for the current account.
func (s *Store) GetNextBatch() (string, error) {
s.log.Debug("loading next batch token")
return s.s.GetNextBatch()
}
// PutAccount stores an OlmAccount in the database.
func (s *Store) PutAccount(account *crypto.OlmAccount) error {
s.log.Debug("storing olm account")
return s.s.PutAccount(account)
}
// GetAccount retrieves an OlmAccount from the database.
func (s *Store) GetAccount() (*crypto.OlmAccount, error) {
s.log.Debug("loading olm account")
return s.s.GetAccount()
}
// HasSession returns whether there is an Olm session for the given sender key.
func (s *Store) HasSession(key id.SenderKey) bool {
s.log.Debug("check if olm session exists for the key %q", key)
return s.s.HasSession(key)
}
// GetSessions returns all the known Olm sessions for a sender key.
func (s *Store) GetSessions(key id.SenderKey) (crypto.OlmSessionList, error) {
s.log.Debug("loading olm session for the key %q", key)
return s.s.GetSessions(key)
}
// GetLatestSession retrieves the Olm session for a given sender key from the database that has the largest ID.
func (s *Store) GetLatestSession(key id.SenderKey) (*crypto.OlmSession, error) {
s.log.Debug("loading latest session for the key %q", key)
return s.s.GetLatestSession(key)
}
// AddSession persists an Olm session for a sender in the database.
func (s *Store) AddSession(key id.SenderKey, session *crypto.OlmSession) error {
s.log.Debug("adding new olm session for the key %q", key)
return s.s.AddSession(key, session)
}
// UpdateSession replaces the Olm session for a sender in the database.
func (s *Store) UpdateSession(key id.SenderKey, session *crypto.OlmSession) error {
s.log.Debug("update olm session for the key %q", key)
return s.s.UpdateSession(key, session)
}
// PutGroupSession stores an inbound Megolm group session for a room, sender and session.
func (s *Store) PutGroupSession(roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID, session *crypto.InboundGroupSession) error {
s.log.Debug("storing inbound group session for the room %q", roomID)
return s.s.PutGroupSession(roomID, senderKey, sessionID, session)
}
// GetGroupSession retrieves an inbound Megolm group session for a room, sender and session.
func (s *Store) GetGroupSession(roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID) (*crypto.InboundGroupSession, error) {
s.log.Debug("loading inbound group session for the room %q", roomID)
return s.s.GetGroupSession(roomID, senderKey, sessionID)
}
// PutWithheldGroupSession tells the store that a specific Megolm session was withheld.
// nolint // method is part of interface and cannot be changed
func (s *Store) PutWithheldGroupSession(content event.RoomKeyWithheldEventContent) error {
s.log.Debug("storing withheld group session")
return s.s.PutWithheldGroupSession(content)
}
// GetWithheldGroupSession gets the event content that was previously inserted with PutWithheldGroupSession.
func (s *Store) GetWithheldGroupSession(roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID) (*event.RoomKeyWithheldEventContent, error) {
s.log.Debug("loading withheld group session")
return s.s.GetWithheldGroupSession(roomID, senderKey, sessionID)
}
// GetGroupSessionsForRoom gets all the inbound Megolm sessions for a specific room. This is used for creating key
// export files. Unlike GetGroupSession, this should not return any errors about withheld keys.
func (s *Store) GetGroupSessionsForRoom(roomID id.RoomID) ([]*crypto.InboundGroupSession, error) {
s.log.Debug("loading group session for the room %q", roomID)
return s.s.GetGroupSessionsForRoom(roomID)
}
// GetAllGroupSessions gets all the inbound Megolm sessions in the store. This is used for creating key export
// files. Unlike GetGroupSession, this should not return any errors about withheld keys.
func (s *Store) GetAllGroupSessions() ([]*crypto.InboundGroupSession, error) {
s.log.Debug("loading all group sessions")
return s.s.GetAllGroupSessions()
}
// AddOutboundGroupSession stores an outbound Megolm session, along with the information about the room and involved devices.
func (s *Store) AddOutboundGroupSession(session *crypto.OutboundGroupSession) (err error) {
s.log.Debug("storing outbound group session")
return s.s.AddOutboundGroupSession(session)
}
// UpdateOutboundGroupSession replaces an outbound Megolm session with for same room and session ID.
func (s *Store) UpdateOutboundGroupSession(session *crypto.OutboundGroupSession) error {
s.log.Debug("updating outbound group session")
return s.s.UpdateOutboundGroupSession(session)
}
// GetOutboundGroupSession retrieves the outbound Megolm session for the given room ID.
func (s *Store) GetOutboundGroupSession(roomID id.RoomID) (*crypto.OutboundGroupSession, error) {
s.log.Debug("loading outbound group session")
return s.s.GetOutboundGroupSession(roomID)
}
// RemoveOutboundGroupSession removes the outbound Megolm session for the given room ID.
func (s *Store) RemoveOutboundGroupSession(roomID id.RoomID) error {
s.log.Debug("removing outbound group session")
return s.s.RemoveOutboundGroupSession(roomID)
}
// ValidateMessageIndex returns whether the given event information match the ones stored in the database
// for the given sender key, session ID and index.
// If the event information was not yet stored, it's stored now.
func (s *Store) ValidateMessageIndex(senderKey id.SenderKey, sessionID id.SessionID, eventID id.EventID, index uint, timestamp int64) (bool, error) {
s.log.Debug("validating message index")
return s.s.ValidateMessageIndex(senderKey, sessionID, eventID, index, timestamp)
}
// GetDevices returns a map of device IDs to device identities, including the identity and signing keys, for a given user ID.
func (s *Store) GetDevices(userID id.UserID) (map[id.DeviceID]*id.Device, error) {
s.log.Debug("loading devices of the %q", userID)
return s.s.GetDevices(userID)
}
// GetDevice returns the device dentity for a given user and device ID.
func (s *Store) GetDevice(userID id.UserID, deviceID id.DeviceID) (*id.Device, error) {
s.log.Debug("loading device %q for the %q", deviceID, userID)
return s.s.GetDevice(userID, deviceID)
}
// FindDeviceByKey finds a specific device by its sender key.
func (s *Store) FindDeviceByKey(userID id.UserID, identityKey id.IdentityKey) (*id.Device, error) {
s.log.Debug("loading device of the %q by the key %q", userID, identityKey)
return s.s.FindDeviceByKey(userID, identityKey)
}
// PutDevice stores a single device for a user, replacing it if it exists already.
func (s *Store) PutDevice(userID id.UserID, device *id.Device) error {
s.log.Debug("storing device of the %q", userID)
return s.s.PutDevice(userID, device)
}
// PutDevices stores the device identity information for the given user ID.
func (s *Store) PutDevices(userID id.UserID, devices map[id.DeviceID]*id.Device) error {
s.log.Debug("storing devices of the %q", userID)
return s.s.PutDevices(userID, devices)
}
// FilterTrackedUsers finds all of the user IDs out of the given ones for which the database contains identity information.
func (s *Store) FilterTrackedUsers(users []id.UserID) ([]id.UserID, error) {
s.log.Debug("filtering tracked users")
return s.s.FilterTrackedUsers(users)
}
// PutCrossSigningKey stores a cross-signing key of some user along with its usage.
func (s *Store) PutCrossSigningKey(userID id.UserID, usage id.CrossSigningUsage, key id.Ed25519) error {
s.log.Debug("storing crosssigning key of the %q", userID)
return s.s.PutCrossSigningKey(userID, usage, key)
}
// GetCrossSigningKeys retrieves a user's stored cross-signing keys.
func (s *Store) GetCrossSigningKeys(userID id.UserID) (map[id.CrossSigningUsage]id.CrossSigningKey, error) {
s.log.Debug("loading crosssigning keys of the %q", userID)
return s.s.GetCrossSigningKeys(userID)
}
// PutSignature stores a signature of a cross-signing or device key along with the signer's user ID and key.
func (s *Store) PutSignature(signedUserID id.UserID, signedKey id.Ed25519, signerUserID id.UserID, signerKey id.Ed25519, signature string) error {
s.log.Debug("storing signature")
return s.s.PutSignature(signedUserID, signedKey, signerUserID, signerKey, signature)
}
// GetSignaturesForKeyBy retrieves the stored signatures for a given cross-signing or device key, by the given signer.
func (s *Store) GetSignaturesForKeyBy(userID id.UserID, key id.Ed25519, signerID id.UserID) (map[id.Ed25519]string, error) {
s.log.Debug("loading signatures")
return s.s.GetSignaturesForKeyBy(userID, key, signerID)
}
// IsKeySignedBy returns whether a cross-signing or device key is signed by the given signer.
func (s *Store) IsKeySignedBy(userID id.UserID, key id.Ed25519, signerID id.UserID, signerKey id.Ed25519) (bool, error) {
s.log.Debug("checking if key is signed by")
return s.s.IsKeySignedBy(userID, key, signerID, signerKey)
}
// DropSignaturesByKey deletes the signatures made by the given user and key from the store. It returns the number of signatures deleted.
func (s *Store) DropSignaturesByKey(userID id.UserID, key id.Ed25519) (int64, error) {
s.log.Debug("removing signatures by the %q/%q", userID, key)
return s.s.DropSignaturesByKey(userID, key)
}

View File

@@ -1,209 +0,0 @@
package store
import (
"encoding/json"
"maunium.net/go/mautrix/event"
"maunium.net/go/mautrix/id"
)
var acceptedMembershipTypes = []event.Membership{
event.MembershipJoin,
event.MembershipInvite,
event.MembershipBan,
event.MembershipLeave,
}
// IsEncrypted returns whether a room is encrypted.
func (s *Store) IsEncrypted(roomID id.RoomID) bool {
if !s.encryption {
return false
}
s.log.Debug("checking if room %q is encrypted", roomID)
return s.GetEncryptionEvent(roomID) != nil
}
// SetEncryptionEvent creates or updates room's encryption event info
func (s *Store) SetEncryptionEvent(evt *event.Event) {
if !s.encryption {
return
}
if evt == nil {
return
}
var encryptionEventJSON []byte
encryptionEventJSON, err := json.Marshal(evt)
if err != nil {
s.log.Debug("cannot marshal encryption event: %v", err)
return
}
tx, err := s.db.Begin()
if err != nil {
s.log.Error("cannot begin transaction: %v", err)
return
}
var insert string
switch s.dialect {
case "sqlite3":
insert = "INSERT OR IGNORE INTO rooms VALUES (?, ?)"
case "postgres":
insert = "INSERT INTO rooms VALUES ($1, $2) ON CONFLICT DO NOTHING"
}
update := "UPDATE rooms SET encryption_event = $1 WHERE room_id = $2"
_, err = tx.Exec(update, encryptionEventJSON, evt.RoomID)
if err != nil {
s.log.Error("cannot update encryption event: %v", err)
// nolint // we already have err to return
tx.Rollback()
return
}
_, err = tx.Exec(insert, evt.RoomID, encryptionEventJSON)
if err != nil {
s.log.Error("cannot insert encryption event: %v", err)
// nolint // interface doesn't allow to return error
tx.Rollback()
return
}
err = tx.Commit()
if err != nil {
s.log.Error("cannot commit transaction: %v", err)
}
}
// SetMembership saves room members
func (s *Store) SetMembership(evt *event.Event) {
s.log.Debug("saving membership event for %q", evt.RoomID)
tx, err := s.db.Begin()
if err != nil {
s.log.Error("cannot begin transaction: %v", err)
return
}
var insert string
switch s.dialect {
case "sqlite3":
insert = "INSERT OR IGNORE INTO room_members VALUES (?, ?)"
case "postgres":
insert = "INSERT INTO room_members VALUES ($1, $2) ON CONFLICT DO NOTHING"
}
del := "DELETE FROM room_members WHERE room_id = $1 AND user_id = $2"
membership := evt.Content.AsMember().Membership
if s.shouldIgnoreMembership(membership) {
return
}
if membership.IsInviteOrJoin() {
_, err := tx.Exec(insert, evt.RoomID, evt.GetStateKey())
if err != nil {
s.log.Error("cannot insert membership event: %v", err)
// nolint // interface doesn't allow to return error
tx.Rollback()
return
}
} else {
_, err := tx.Exec(del, evt.RoomID, evt.GetStateKey())
if err != nil {
s.log.Error("cannot delete membership event: %v", err)
// nolint // interface doesn't allow to return error
tx.Rollback()
return
}
}
commitErr := tx.Commit()
if commitErr != nil {
s.log.Error("cannot commit transaction: %v", commitErr)
// nolint // interface doesn't allow to return error
tx.Rollback()
}
}
// GetRoomMembers ...
func (s *Store) GetRoomMembers(roomID id.RoomID) []id.UserID {
s.log.Debug("loading room members of %q", roomID)
query := "SELECT user_id FROM room_members WHERE room_id = $1"
rows, err := s.db.Query(query, roomID)
users := make([]id.UserID, 0)
if err != nil {
s.log.Error("cannot load room members: %v", err)
return users
}
defer rows.Close()
var userID id.UserID
for rows.Next() {
if err := rows.Scan(&userID); err == nil {
users = append(users, userID)
}
}
return users
}
// SaveSession to DB
func (s *Store) SaveSession(userID id.UserID, deviceID id.DeviceID, accessToken string) {
s.log.Debug("saving session credentials of %q/%q", userID, deviceID)
tx, err := s.db.Begin()
if err != nil {
s.log.Error("cannot begin transaction: %v", err)
return
}
var insert string
switch s.dialect {
case "sqlite3":
insert = "INSERT OR IGNORE INTO session VALUES (?, ?, ?)"
case "postgres":
insert = "INSERT INTO session VALUES ($1, $2, $3) ON CONFLICT DO NOTHING"
}
update := "UPDATE session SET access_token = $1, device_id = $2 WHERE user_id = $3"
if _, err = tx.Exec(update, accessToken, deviceID, userID); err != nil {
s.log.Error("cannot update session credentials: %v", err)
// nolint // no need to check error here
tx.Rollback()
return
}
if _, err = tx.Exec(insert, userID, deviceID, accessToken); err != nil {
s.log.Error("cannot insert session credentials: %v", err)
// nolint // no need to check error here
tx.Rollback()
return
}
err = tx.Commit()
if err != nil {
s.log.Error("cannot commit transaction: %v", err)
}
}
// LoadSession from DB (user ID, device ID, access token)
func (s *Store) LoadSession() (id.UserID, id.DeviceID, string) {
s.log.Debug("loading session credentials...")
row := s.db.QueryRow("SELECT * FROM session LIMIT 1")
var userID id.UserID
var deviceID id.DeviceID
var accessToken string
if err := row.Scan(&userID, &deviceID, &accessToken); err != nil {
s.log.Error("cannot load session credentials: %v", err)
return "", "", ""
}
return userID, deviceID, accessToken
}
func (s *Store) shouldIgnoreMembership(membership event.Membership) bool {
for _, mtype := range acceptedMembershipTypes {
if membership == mtype {
return false
}
}
return true
}

View File

@@ -1,66 +0,0 @@
package store
var migrations = []string{
`
CREATE TABLE IF NOT EXISTS user_filter_ids (
user_id VARCHAR(255) PRIMARY KEY,
filter_id VARCHAR(255)
)
`,
`
CREATE TABLE IF NOT EXISTS user_batch_tokens (
user_id VARCHAR(255) PRIMARY KEY,
next_batch_token VARCHAR(255)
)
`,
`
CREATE TABLE IF NOT EXISTS rooms (
room_id VARCHAR(255) PRIMARY KEY,
encryption_event VARCHAR(65535) NULL
)
`,
`
CREATE TABLE IF NOT EXISTS room_members (
room_id VARCHAR(255),
user_id VARCHAR(255),
PRIMARY KEY (room_id, user_id)
)
`,
`
CREATE TABLE IF NOT EXISTS session (
user_id VARCHAR(255),
device_id VARCHAR(255),
access_token VARCHAR(255)
)
`,
}
// CreateTables applies all the pending database migrations.
func (s *Store) CreateTables() error {
s.log.Debug("migrating database...")
tx, beginErr := s.db.Begin()
if beginErr != nil {
s.log.Error("cannot begin transaction: %v", beginErr)
return beginErr
}
for _, query := range migrations {
_, execErr := tx.Exec(query)
if execErr != nil {
s.log.Error("cannot apply migration: %v", execErr)
// nolint // we already have the execErr to return
tx.Rollback()
return execErr
}
}
commitErr := tx.Commit()
if commitErr != nil {
s.log.Error("cannot commit transaction: %v", commitErr)
// nolint // we already have the commitErr to return
tx.Rollback()
return commitErr
}
return nil
}

View File

@@ -1,63 +0,0 @@
package store
import (
"database/sql"
"encoding/json"
"maunium.net/go/mautrix/event"
"maunium.net/go/mautrix/id"
)
// NOTE: functions in that file are for crypto.StateStore implementation
// ref: https://pkg.go.dev/maunium.net/go/mautrix/crypto#StateStore
// GetEncryptionEvent returns the encryption event's content for an encrypted room.
func (s *Store) GetEncryptionEvent(roomID id.RoomID) *event.EncryptionEventContent {
if !s.encryption {
return nil
}
s.log.Debug("finding encryption event of %q", roomID)
query := "SELECT encryption_event FROM rooms WHERE room_id = $1"
row := s.db.QueryRow(query, roomID)
var encryptionEventJSON []byte
err := row.Scan(&encryptionEventJSON)
if err != nil && err != sql.ErrNoRows {
s.log.Error("cannot find encryption event: %v", err)
return nil
}
var encryptionEvent event.EncryptionEventContent
if err := json.Unmarshal(encryptionEventJSON, &encryptionEvent); err != nil {
s.log.Debug("cannot unmarshal encryption event: %q", err)
return nil
}
return &encryptionEvent
}
// FindSharedRooms returns the encrypted rooms that another user is also in for a user ID.
func (s *Store) FindSharedRooms(userID id.UserID) []id.RoomID {
if !s.encryption {
return nil
}
s.log.Debug("loading shared rooms for %q", userID)
query := "SELECT room_id FROM room_members WHERE user_id = $1"
rows, queryErr := s.db.Query(query, userID)
rooms := make([]id.RoomID, 0)
if queryErr != nil {
s.log.Error("cannot load room members: %q", queryErr)
return rooms
}
defer rows.Close()
var roomID id.RoomID
for rows.Next() {
scanErr := rows.Scan(&roomID)
if scanErr != nil {
continue
}
rooms = append(rooms, roomID)
}
return rooms
}

View File

@@ -1,55 +0,0 @@
// Package store implements crypto.Store, crypto.StateStore, mautrix.Storer and some additional "glue methods"
package store
import (
"database/sql"
"maunium.net/go/mautrix/crypto"
"maunium.net/go/mautrix/id"
"maunium.net/go/mautrix/util/dbutil"
"gitlab.com/etke.cc/linkpearl/config"
)
// Store for the matrix
type Store struct {
db *sql.DB
dialect string
log config.Logger
encryption bool
s *crypto.SQLCryptoStore
}
// New store
func New(db *sql.DB, dialect string, log config.Logger) *Store {
return &Store{
db: db,
log: log,
dialect: dialect,
}
}
// WithCrypto adds crypto store support
func (s *Store) WithCrypto(userID id.UserID, deviceID id.DeviceID, logger config.Logger) error {
s.log.Debug("crypto store enabled")
s.encryption = true
db, err := dbutil.NewWithDB(s.db, s.dialect)
if err != nil {
logger.Error("cannot init database: %v", err)
return err
}
s.s = crypto.NewSQLCryptoStore(
db,
dbutil.NoopLogger,
userID.String(),
deviceID,
[]byte(userID),
)
return s.s.DB.Upgrade()
}
// GetDialect returns database dialect
func (s *Store) GetDialect() string {
return s.dialect
}

View File

@@ -1,126 +0,0 @@
package store
import (
"maunium.net/go/mautrix"
"maunium.net/go/mautrix/id"
)
// NOTE: functions in that file are for mautrix.Storer implementation
// ref: https://pkg.go.dev/maunium.net/go/mautrix#Storer
// SaveFilterID to DB
func (s *Store) SaveFilterID(userID id.UserID, filterID string) {
s.log.Debug("saving filter ID %q for %q", filterID, userID)
tx, err := s.db.Begin()
if err != nil {
s.log.Error("cannot begin transaction: %v", err)
return
}
var insert string
switch s.dialect {
case "sqlite3":
insert = "INSERT OR IGNORE INTO user_filter_ids VALUES (?, ?)"
case "postgres":
insert = "INSERT INTO user_filter_ids VALUES ($1, $2) ON CONFLICT DO NOTHING"
}
update := "UPDATE user_filter_ids SET filter_id = $1 WHERE user_id = $2"
_, updateErr := tx.Exec(update, filterID, userID)
if updateErr != nil {
s.log.Error("cannot update filter ID: %v", updateErr)
// nolint // no need to check error here
tx.Rollback()
return
}
_, insertErr := tx.Exec(insert, userID, filterID)
if insertErr != nil {
s.log.Error("cannot create filter ID: %v", insertErr)
// nolint // no need to check error here
tx.Rollback()
return
}
commitErr := tx.Commit()
if commitErr != nil {
s.log.Error("cannot upsert filter ID: %v", commitErr)
// nolint // no need to check error here
tx.Rollback()
}
}
// LoadFilterID from DB
func (s *Store) LoadFilterID(userID id.UserID) string {
s.log.Debug("loading filter ID for %q", userID)
query := "SELECT filter_id FROM user_filter_ids WHERE user_id = $1"
row := s.db.QueryRow(query, userID)
var filterID string
if err := row.Scan(&filterID); err != nil {
s.log.Error("cannot load filter ID: %q", err)
return ""
}
return filterID
}
// SaveNextBatch to DB
func (s *Store) SaveNextBatch(userID id.UserID, nextBatchToken string) {
s.log.Debug("saving next batch token for %q", userID)
tx, err := s.db.Begin()
if err != nil {
s.log.Error("cannot begin transaction: %v", err)
return
}
var insert string
switch s.dialect {
case "sqlite3":
insert = "INSERT OR IGNORE INTO user_batch_tokens VALUES (?, ?)"
case "postgres":
insert = "INSERT INTO user_batch_tokens VALUES ($1, $2) ON CONFLICT DO NOTHING"
}
update := "UPDATE user_batch_tokens SET next_batch_token = $1 WHERE user_id = $2"
if _, err := tx.Exec(update, nextBatchToken, userID); err != nil {
s.log.Error("cannot update next batch token: %v", err)
// nolint // no need to check error here
tx.Rollback()
return
}
if _, err := tx.Exec(insert, userID, nextBatchToken); err != nil {
s.log.Error("cannot insert next batch token: %v", err)
// nolint // no need to check error here
tx.Rollback()
return
}
commitErr := tx.Commit()
if commitErr != nil {
s.log.Error("cannot commit transaction: %v", commitErr)
}
}
// LoadNextBatch from DB
func (s *Store) LoadNextBatch(userID id.UserID) string {
s.log.Debug("loading next batch token for %q", userID)
query := "SELECT next_batch_token FROM user_batch_tokens WHERE user_id = $1"
row := s.db.QueryRow(query, userID)
var batchToken string
if err := row.Scan(&batchToken); err != nil {
s.log.Error("cannot load next batch token: %v", err)
return ""
}
return batchToken
}
// SaveRoom to DB, not implemented
func (s *Store) SaveRoom(room *mautrix.Room) {
s.log.Debug("saving room %q (stub, not implemented)", room.ID)
}
// LoadRoom from DB, not implemented
func (s *Store) LoadRoom(roomID id.RoomID) *mautrix.Room {
s.log.Debug("loading room %q (stub, not implemented)", roomID)
return mautrix.NewRoom(roomID)
}

View File

@@ -11,31 +11,27 @@ import (
// OnEventType allows callers to be notified when there are new events for the given event type.
// There are no duplicate checks.
func (l *Linkpearl) OnEventType(eventType event.Type, callback mautrix.EventHandler) {
l.api.Syncer.(*mautrix.DefaultSyncer).OnEventType(eventType, callback)
l.api.Syncer.(mautrix.ExtensibleSyncer).OnEventType(eventType, callback)
}
// OnSync shortcut to mautrix.DefaultSyncer.OnSync
func (l *Linkpearl) OnSync(callback mautrix.SyncHandler) {
l.api.Syncer.(*mautrix.DefaultSyncer).OnSync(callback)
l.api.Syncer.(mautrix.ExtensibleSyncer).OnSync(callback)
}
// OnEvent shortcut to mautrix.DefaultSyncer.OnEvent
func (l *Linkpearl) OnEvent(callback mautrix.EventHandler) {
l.api.Syncer.(*mautrix.DefaultSyncer).OnEvent(callback)
l.api.Syncer.(mautrix.ExtensibleSyncer).OnEvent(callback)
}
func (l *Linkpearl) initSync() {
if l.olm != nil {
l.api.Syncer.(*mautrix.DefaultSyncer).OnSync(l.olm.ProcessSyncResponse)
l.api.Syncer.(*mautrix.DefaultSyncer).OnEventType(
event.StateEncryption,
func(source mautrix.EventSource, evt *event.Event) {
go l.onEncryption(source, evt)
},
)
}
l.api.Syncer.(*mautrix.DefaultSyncer).OnEventType(
l.api.Syncer.(mautrix.ExtensibleSyncer).OnEventType(
event.StateEncryption,
func(source mautrix.EventSource, evt *event.Event) {
go l.onEncryption(source, evt)
},
)
l.api.Syncer.(mautrix.ExtensibleSyncer).OnEventType(
event.StateMember,
func(source mautrix.EventSource, evt *event.Event) {
go l.onMembership(source, evt)
@@ -43,11 +39,9 @@ func (l *Linkpearl) initSync() {
)
}
func (l *Linkpearl) onMembership(_ mautrix.EventSource, evt *event.Event) {
if l.olm != nil {
l.olm.HandleMemberEvent(evt)
}
l.store.SetMembership(evt)
func (l *Linkpearl) onMembership(src mautrix.EventSource, evt *event.Event) {
l.ch.Machine().HandleMemberEvent(src, evt)
l.api.StateStore.SetMembership(evt.RoomID, id.UserID(evt.GetStateKey()), evt.Content.AsMember().Membership)
// potentially autoaccept invites
l.onInvite(evt)
@@ -78,9 +72,9 @@ func (l *Linkpearl) tryJoin(roomID id.RoomID, retry int) {
_, err := l.api.JoinRoomByID(roomID)
if err != nil {
l.log.Error("cannot join the room %q: %v", roomID, err)
l.log.Error().Err(err).Str("roomID", roomID.String()).Msg("cannot join room")
time.Sleep(5 * time.Second)
l.log.Debug("trying to join again (%d/%d)", retry+1, l.maxretries)
l.log.Error().Err(err).Str("roomID", roomID.String()).Int("retry", retry+1).Msg("trying to join again")
l.tryJoin(roomID, retry+1)
}
}
@@ -92,9 +86,9 @@ func (l *Linkpearl) tryLeave(roomID id.RoomID, retry int) {
_, err := l.api.LeaveRoom(roomID)
if err != nil {
l.log.Error("cannot leave room: %v", err)
l.log.Error().Err(err).Str("roomID", roomID.String()).Msg("cannot leave room")
time.Sleep(5 * time.Second)
l.log.Debug("trying to leave again (%d/%d)", retry+1, l.maxretries)
l.log.Error().Err(err).Str("roomID", roomID.String()).Int("retry", retry+1).Msg("trying to leave again")
l.tryLeave(roomID, retry+1)
}
}
@@ -104,7 +98,12 @@ func (l *Linkpearl) onEmpty(evt *event.Event) {
return
}
members := l.store.GetRoomMembers(evt.RoomID)
members, err := l.api.StateStore.GetRoomJoinedOrInvitedMembers(evt.RoomID)
if err != nil {
l.log.Error().Err(err).Str("roomID", evt.RoomID.String()).Msg("cannot get joined or invited members")
return
}
if len(members) >= 1 && members[0] != l.api.UserID {
return
}
@@ -113,5 +112,5 @@ func (l *Linkpearl) onEmpty(evt *event.Event) {
}
func (l *Linkpearl) onEncryption(_ mautrix.EventSource, evt *event.Event) {
l.store.SetEncryptionEvent(evt)
l.api.StateStore.SetEncryptionEvent(evt.RoomID, evt.Content.AsEncryption())
}