BREAKING: update mautrix to 0.15.x

This commit is contained in:
Aine
2023-06-01 14:32:20 +00:00
parent a6b20a75ab
commit 2bdb8ca635
222 changed files with 7851 additions and 23986 deletions

View File

@@ -1,26 +0,0 @@
# update go dependencies
update:
go get .
go get -u maunium.net/go/mautrix
go mod tidy
mock:
-@rm -rf mocks
@mockery --all
# run linter
lint:
golangci-lint run ./...
# run linter and fix issues if possible
lintfix:
golangci-lint run --fix ./...
vuln:
govulncheck ./...
# run unit tests
test:
@go test ${BUILDFLAGS} -coverprofile=cover.out ./...
@go tool cover -func=cover.out
-@rm -f cover.out

View File

@@ -10,7 +10,6 @@ import (
func (l *Linkpearl) GetAccountData(name string) (map[string]string, error) {
cached, ok := l.acc.Get(name)
if ok {
l.logAccountData(l.log.Debug, "GetAccountData(%q) cached:", cached, name)
if cached == nil {
return map[string]string{}, nil
}
@@ -20,7 +19,6 @@ func (l *Linkpearl) GetAccountData(name string) (map[string]string, error) {
var data map[string]string
err := l.GetClient().GetAccountData(name, &data)
if err != nil {
l.logAccountData(l.log.Debug, "GetAccountData(%q) error: %v", nil, name, err)
data = map[string]string{}
if strings.Contains(err.Error(), "M_NOT_FOUND") {
l.acc.Add(name, data)
@@ -29,7 +27,6 @@ func (l *Linkpearl) GetAccountData(name string) (map[string]string, error) {
return data, err
}
data = l.decryptAccountData(data)
l.logAccountData(l.log.Debug, "GetAccountData(%q):", data, name)
l.acc.Add(name, data)
return data, err
@@ -39,7 +36,6 @@ func (l *Linkpearl) GetAccountData(name string) (map[string]string, error) {
func (l *Linkpearl) SetAccountData(name string, data map[string]string) error {
l.acc.Add(name, data)
l.logAccountData(l.log.Debug, "SetAccountData(%q):", data, name)
data = l.encryptAccountData(data)
return l.GetClient().SetAccountData(name, data)
}
@@ -49,7 +45,6 @@ func (l *Linkpearl) GetRoomAccountData(roomID id.RoomID, name string) (map[strin
key := roomID.String() + name
cached, ok := l.acc.Get(key)
if ok {
l.logAccountData(l.log.Debug, "GetRoomAccountData(%q, %q) cached:", cached, roomID, name)
if cached == nil {
return map[string]string{}, nil
}
@@ -59,7 +54,6 @@ func (l *Linkpearl) GetRoomAccountData(roomID id.RoomID, name string) (map[strin
var data map[string]string
err := l.GetClient().GetRoomAccountData(roomID, name, &data)
if err != nil {
l.logAccountData(l.log.Debug, "GetRoomAccountData(%q, %q) error: %v", nil, roomID, name, err)
data = map[string]string{}
if strings.Contains(err.Error(), "M_NOT_FOUND") {
l.acc.Add(key, data)
@@ -68,7 +62,6 @@ func (l *Linkpearl) GetRoomAccountData(roomID id.RoomID, name string) (map[strin
return data, err
}
data = l.decryptAccountData(data)
l.logAccountData(l.log.Debug, "GetRoomAccountData(%q, %q):", data, roomID, name)
l.acc.Add(key, data)
return data, err
@@ -79,7 +72,6 @@ func (l *Linkpearl) SetRoomAccountData(roomID id.RoomID, name string, data map[s
key := roomID.String() + name
l.acc.Add(key, data)
l.logAccountData(l.log.Debug, "SetRoomAccountData(%q, %q):", data, roomID, name)
data = l.encryptAccountData(data)
return l.GetClient().SetRoomAccountData(roomID, name, data)
}
@@ -93,11 +85,11 @@ func (l *Linkpearl) encryptAccountData(data map[string]string) map[string]string
for k, v := range data {
ek, err := l.acr.Encrypt(k)
if err != nil {
l.log.Error("cannot encrypt account data (key=%q): %v", k, err)
l.log.Error().Err(err).Str("key", k).Msg("cannot encrypt account data")
}
ev, err := l.acr.Encrypt(v)
if err != nil {
l.log.Error("cannot encrypt account data (key=%q): %v", k, err)
l.log.Error().Err(err).Str("key", k).Msg("cannot encrypt account data")
}
encrypted[ek] = ev // worst case: plaintext value
}
@@ -114,35 +106,14 @@ func (l *Linkpearl) decryptAccountData(data map[string]string) map[string]string
for ek, ev := range data {
k, err := l.acr.Decrypt(ek)
if err != nil {
l.log.Error("cannot decrypt account data (key=%q): %v", k, err)
l.log.Error().Err(err).Str("key", k).Msg("cannot decrypt account data")
}
v, err := l.acr.Decrypt(ev)
if err != nil {
l.log.Error("cannot decrypt account data (key=%q): %v", k, err)
l.log.Error().Err(err).Str("key", k).Msg("cannot decrypt account data")
}
decrypted[k] = v // worst case: encrypted value, usual case: migration from plaintext to encrypted account data
}
return decrypted
}
func (l *Linkpearl) logAccountData(method func(string, ...any), message string, data map[string]string, args ...any) {
if len(data) == 0 {
method(message, args...)
return
}
safeData := make(map[string]string, len(data))
for k, v := range data {
sv, ok := l.aclr[k]
if ok {
safeData[k] = sv
continue
}
safeData[k] = v
}
args = append(args, safeData)
method(message+" %+v", args...)
}

View File

@@ -1,74 +0,0 @@
package linkpearl
import (
"errors"
"maunium.net/go/mautrix"
"maunium.net/go/mautrix/id"
)
func (l *Linkpearl) login(username string, password string) error {
if err := l.restoreSession(); err == nil {
l.log.Debug("session restored successfully")
return nil
}
l.log.Debug("auth using login and password...")
_, err := l.api.Login(&mautrix.ReqLogin{
Type: "m.login.password",
Identifier: mautrix.UserIdentifier{
Type: mautrix.IdentifierTypeUser,
User: username,
},
Password: password,
StoreCredentials: true,
})
if err != nil {
l.log.Error("cannot authorize using login and password: %v", err)
return err
}
l.store.SaveSession(l.api.UserID, l.api.DeviceID, l.api.AccessToken)
return nil
}
// restoreSession tries to load previous active session token from db (if any)
func (l *Linkpearl) restoreSession() error {
l.log.Debug("restoring previous session...")
userID, deviceID, token := l.store.LoadSession()
if userID == "" || deviceID == "" || token == "" {
return errors.New("cannot restore session from db")
}
if !l.validateSession(userID, deviceID, token) {
return errors.New("restored session is invalid")
}
l.api.AccessToken = token
l.api.UserID = userID
l.api.DeviceID = deviceID
return nil
}
func (l *Linkpearl) validateSession(userID id.UserID, deviceID id.DeviceID, token string) bool {
valid := true
// preserve current values
currentToken := l.api.AccessToken
currentUserID := l.api.UserID
currentDeviceID := l.api.DeviceID
// set new values
l.api.AccessToken = token
l.api.UserID = userID
l.api.DeviceID = deviceID
if _, err := l.api.GetOwnPresence(); err != nil {
l.log.Debug("previous session token was not found or invalid: %v", err)
valid = false
}
// restore original values
l.api.AccessToken = currentToken
l.api.UserID = currentUserID
l.api.DeviceID = currentDeviceID
return valid
}

View File

@@ -1,11 +1,10 @@
// Package config was added to store cross-package structs and interfaces.
package config
package linkpearl
import (
"database/sql"
"github.com/rs/zerolog"
"maunium.net/go/mautrix"
"maunium.net/go/mautrix/crypto"
"maunium.net/go/mautrix/event"
)
@@ -32,25 +31,11 @@ type Config struct {
// AccountDataSecret (Password) for encryption
AccountDataSecret string
// AccountDataLogReplace contains map of field name => value
// that will be used to replace mentioned account data fields with provided values
// when printing in logs (DEBUG, TRACE)
AccountDataLogReplace map[string]string
// MaxRetries for operations like auto join
MaxRetries int
// NoEncryption disabled encryption support
NoEncryption bool
// LPLogger used for linkpearl's glue code
LPLogger Logger
// APILogger used for matrix CS API calls
APILogger Logger
// StoreLogger used for persistent store
StoreLogger Logger
// CryptoLogger used for OLM machine
CryptoLogger Logger
// Logger
Logger zerolog.Logger
// DB object
DB *sql.DB
@@ -58,10 +43,16 @@ type Config struct {
Dialect string
}
// Logger implementation of crypto.Logger and mautrix.Logger
type Logger interface {
crypto.Logger
mautrix.WarnLogger
Info(message string, args ...interface{})
// LoginAs for cryptohelper
func (cfg *Config) LoginAs() *mautrix.ReqLogin {
return &mautrix.ReqLogin{
Type: mautrix.AuthTypePassword,
Identifier: mautrix.UserIdentifier{
Type: mautrix.IdentifierTypeUser,
User: cfg.Login,
},
Password: cfg.Password,
StoreCredentials: true,
StoreHomeserverURL: true,
}
}

26
vendor/gitlab.com/etke.cc/linkpearl/justfile generated vendored Normal file
View File

@@ -0,0 +1,26 @@
# show help by default
default:
@just --list --justfile {{ justfile() }}
# update go deps
update:
go get .
go get -u maunium.net/go/mautrix
go mod tidy
# run linter
lint:
golangci-lint run ./...
# automatically fix liter issues
lintfix:
golangci-lint run --fix ./...
vuln:
govulncheck ./...
# run unit tests
test:
@go test ${BUILDFLAGS} -coverprofile=cover.out ./...
@go tool cover -func=cover.out
-@rm -f cover.out

View File

@@ -5,12 +5,12 @@ import (
"database/sql"
lru "github.com/hashicorp/golang-lru/v2"
"github.com/rs/zerolog"
"maunium.net/go/mautrix"
"maunium.net/go/mautrix/crypto"
"maunium.net/go/mautrix/crypto/cryptohelper"
"maunium.net/go/mautrix/event"
"gitlab.com/etke.cc/linkpearl/config"
"gitlab.com/etke.cc/linkpearl/store"
"maunium.net/go/mautrix/util/dbutil"
)
const (
@@ -22,14 +22,12 @@ const (
// Linkpearl object
type Linkpearl struct {
db *sql.DB
acc *lru.Cache[string, map[string]string]
acr *Crypter
aclr map[string]string
log config.Logger
api *mautrix.Client
olm *crypto.OlmMachine
store *store.Store
db *sql.DB
ch *cryptohelper.CryptoHelper
acc *lru.Cache[string, map[string]string]
acr *Crypter
log zerolog.Logger
api *mautrix.Client
joinPermit func(*event.Event) bool
autoleave bool
@@ -41,10 +39,7 @@ type ReqPresence struct {
StatusMsg string `json:"status_msg,omitempty"`
}
func setDefaults(cfg *config.Config) {
if cfg.AccountDataLogReplace == nil {
cfg.AccountDataLogReplace = make(map[string]string)
}
func setDefaults(cfg *Config) {
if cfg.MaxRetries == 0 {
cfg.MaxRetries = DefaultMaxRetries
}
@@ -66,13 +61,13 @@ func initCrypter(secret string) (*Crypter, error) {
}
// New linkpearl
func New(cfg *config.Config) (*Linkpearl, error) {
func New(cfg *Config) (*Linkpearl, error) {
setDefaults(cfg)
api, err := mautrix.NewClient(cfg.Homeserver, "", "")
if err != nil {
return nil, err
}
api.Logger = cfg.APILogger
api.Log = cfg.Logger
acc, _ := lru.New[string, map[string]string](cfg.AccountDataCache) //nolint:errcheck // addressed in setDefaults()
acr, err := initCrypter(cfg.AccountDataSecret)
@@ -84,35 +79,27 @@ func New(cfg *config.Config) (*Linkpearl, error) {
db: cfg.DB,
acc: acc,
acr: acr,
aclr: cfg.AccountDataLogReplace,
api: api,
log: cfg.LPLogger,
log: cfg.Logger,
joinPermit: cfg.JoinPermit,
autoleave: cfg.AutoLeave,
maxretries: cfg.MaxRetries,
}
storer := store.New(cfg.DB, cfg.Dialect, cfg.StoreLogger)
if err = storer.CreateTables(); err != nil {
db, err := dbutil.NewWithDB(cfg.DB, cfg.Dialect)
if err != nil {
return nil, err
}
lp.store = storer
lp.api.Store = storer
if err = lp.login(cfg.Login, cfg.Password); err != nil {
db.Log = dbutil.ZeroLogger(cfg.Logger)
lp.ch, err = cryptohelper.NewCryptoHelper(lp.api, []byte(cfg.Login), db)
if err != nil {
return nil, err
}
if !cfg.NoEncryption {
if err = lp.store.WithCrypto(lp.api.UserID, lp.api.DeviceID, cfg.StoreLogger); err != nil {
return nil, err
}
lp.olm = crypto.NewOlmMachine(lp.api, cfg.CryptoLogger, lp.store, lp.store)
if err = lp.olm.Load(); err != nil {
return nil, err
}
lp.ch.LoginAs = cfg.LoginAs()
if err = lp.ch.Init(); err != nil {
return nil, err
}
lp.api.Crypto = lp.ch
return lp, nil
}
@@ -126,14 +113,9 @@ func (l *Linkpearl) GetDB() *sql.DB {
return l.db
}
// GetStore returns underlying persistent store object, compatible with crypto.Store, crypto.StateStore and mautrix.Storer
func (l *Linkpearl) GetStore() *store.Store {
return l.store
}
// GetMachine returns underlying OLM machine
func (l *Linkpearl) GetMachine() *crypto.OlmMachine {
return l.olm
return l.ch.Machine()
}
// GetAccountDataCrypter returns crypter used for account data (if any)
@@ -165,21 +147,23 @@ func (l *Linkpearl) Start(optionalStatusMsg ...string) error {
err := l.SetPresence(event.PresenceOnline, statusMsg)
if err != nil {
l.log.Error("cannot set presence: %v", err)
l.log.Error().Err(err).Msg("cannot set presence")
}
defer l.Stop()
l.log.Info("client has been started")
l.log.Info().Msg("client has been started")
return l.api.Sync()
}
// Stop the client
func (l *Linkpearl) Stop() {
l.log.Debug("stopping the client")
err := l.api.SetPresence(event.PresenceOffline)
if err != nil {
l.log.Error("cannot set presence: %v", err)
l.log.Debug().Msg("stopping the client")
if err := l.api.SetPresence(event.PresenceOffline); err != nil {
l.log.Error().Err(err).Msg("cannot set presence")
}
l.api.StopSync()
l.log.Info("client has been stopped")
if err := l.ch.Close(); err != nil {
l.log.Error().Err(err).Msg("cannot close crypto helper")
}
l.log.Info().Msg("client has been stopped")
}

View File

@@ -4,27 +4,21 @@ import (
"fmt"
"maunium.net/go/mautrix"
"maunium.net/go/mautrix/crypto"
"maunium.net/go/mautrix/event"
"maunium.net/go/mautrix/format"
"maunium.net/go/mautrix/id"
)
// Send a message to the roomID and automatically try to encrypt it, if the destination room is encrypted
//
//nolint:unparam // it's public interface
func (l *Linkpearl) Send(roomID id.RoomID, content interface{}) (id.EventID, error) {
if !l.store.IsEncrypted(roomID) {
l.log.Debug("room %q is not encrypted", roomID)
return l.SendPlaintext(roomID, content)
}
l.log.Debug("room %q is encrypted", roomID)
encrypted, err := l.EncryptEvent(roomID, content)
l.log.Debug().Str("roomID", roomID.String()).Any("content", content).Msg("sending event")
resp, err := l.api.SendMessageEvent(roomID, event.EventMessage, content)
if err != nil {
l.log.Error("cannot encrypt message: %v, sending plaintext...", roomID, err)
return l.SendPlaintext(roomID, content)
return "", err
}
return l.SendEncrypted(roomID, encrypted)
return resp.EventID, nil
}
// SendNotice to a room with optional thread relation
@@ -39,7 +33,7 @@ func (l *Linkpearl) SendNotice(roomID id.RoomID, threadID id.EventID, message st
_, err := l.Send(roomID, &content)
if err != nil {
l.log.Error("cannot send a notice into room %q: %v", roomID, err)
l.log.Error().Err(err).Str("roomID", roomID.String()).Msg("cannot send a notice int the room")
}
}
@@ -47,7 +41,7 @@ func (l *Linkpearl) SendNotice(roomID id.RoomID, threadID id.EventID, message st
func (l *Linkpearl) SendFile(roomID id.RoomID, req *mautrix.ReqUploadMedia, msgtype event.MessageType, relation *event.RelatesTo) error {
resp, err := l.GetClient().UploadMedia(*req)
if err != nil {
l.log.Error("cannot upload file %q: %v", req.FileName, err)
l.log.Error().Err(err).Str("file", req.FileName).Msg("cannot upload file")
return err
}
_, err = l.Send(roomID, &event.Content{
@@ -59,43 +53,8 @@ func (l *Linkpearl) SendFile(roomID id.RoomID, req *mautrix.ReqUploadMedia, msgt
},
})
if err != nil {
l.log.Error("cannot send uploaded file: %q: %v", req.FileName, err)
l.log.Error().Err(err).Str("file", req.FileName).Msg("cannot send uploaded file")
}
return err
}
// SendPlaintext sends plaintext event only
func (l *Linkpearl) SendPlaintext(roomID id.RoomID, content interface{}) (id.EventID, error) {
l.log.Debug("sending plaintext event to %q: %+v", roomID, content)
resp, err := l.api.SendMessageEvent(roomID, event.EventMessage, content)
if err != nil {
return "", err
}
return resp.EventID, nil
}
// SendEncrypted sends encrypted event only
func (l *Linkpearl) SendEncrypted(roomID id.RoomID, content interface{}) (id.EventID, error) {
l.log.Debug("sending encrypted event to %q: %+v", roomID, content)
resp, err := l.api.SendMessageEvent(roomID, event.EventEncrypted, content)
if err != nil {
return "", err
}
return resp.EventID, nil
}
// EncryptEvent before sending
func (l *Linkpearl) EncryptEvent(roomID id.RoomID, content interface{}) (*event.EncryptedEventContent, error) {
l.log.Debug("encrypting event %+v", content)
encrypted, err := l.olm.EncryptMegolmEvent(roomID, event.EventMessage, content)
if crypto.IsShareError(err) {
err = l.olm.ShareGroupSession(roomID, l.store.GetRoomMembers(roomID))
if err != nil {
return nil, err
}
encrypted, err = l.olm.EncryptMegolmEvent(roomID, event.EventMessage, content)
}
return encrypted, err
}

View File

@@ -1,214 +0,0 @@
package store
import (
"maunium.net/go/mautrix/crypto"
"maunium.net/go/mautrix/event"
"maunium.net/go/mautrix/id"
)
// NOTE: functions in that file are for crypto.Store implementation
// ref: https://pkg.go.dev/maunium.net/go/mautrix/crypto#Store
// Flush does nothing for this implementation as data is already persisted in the database.
// nolint // interface cannot be changed
func (s *Store) Flush() error {
s.log.Debug("flushing crypto store")
return nil
}
// PutNextBatch stores the next sync batch token for the current account.
func (s *Store) PutNextBatch(nextBatch string) error {
s.log.Debug("storing next batch token")
return s.s.PutNextBatch(nextBatch)
}
// GetNextBatch retrieves the next sync batch token for the current account.
func (s *Store) GetNextBatch() (string, error) {
s.log.Debug("loading next batch token")
return s.s.GetNextBatch()
}
// PutAccount stores an OlmAccount in the database.
func (s *Store) PutAccount(account *crypto.OlmAccount) error {
s.log.Debug("storing olm account")
return s.s.PutAccount(account)
}
// GetAccount retrieves an OlmAccount from the database.
func (s *Store) GetAccount() (*crypto.OlmAccount, error) {
s.log.Debug("loading olm account")
return s.s.GetAccount()
}
// HasSession returns whether there is an Olm session for the given sender key.
func (s *Store) HasSession(key id.SenderKey) bool {
s.log.Debug("check if olm session exists for the key %q", key)
return s.s.HasSession(key)
}
// GetSessions returns all the known Olm sessions for a sender key.
func (s *Store) GetSessions(key id.SenderKey) (crypto.OlmSessionList, error) {
s.log.Debug("loading olm session for the key %q", key)
return s.s.GetSessions(key)
}
// GetLatestSession retrieves the Olm session for a given sender key from the database that has the largest ID.
func (s *Store) GetLatestSession(key id.SenderKey) (*crypto.OlmSession, error) {
s.log.Debug("loading latest session for the key %q", key)
return s.s.GetLatestSession(key)
}
// AddSession persists an Olm session for a sender in the database.
func (s *Store) AddSession(key id.SenderKey, session *crypto.OlmSession) error {
s.log.Debug("adding new olm session for the key %q", key)
return s.s.AddSession(key, session)
}
// UpdateSession replaces the Olm session for a sender in the database.
func (s *Store) UpdateSession(key id.SenderKey, session *crypto.OlmSession) error {
s.log.Debug("update olm session for the key %q", key)
return s.s.UpdateSession(key, session)
}
// PutGroupSession stores an inbound Megolm group session for a room, sender and session.
func (s *Store) PutGroupSession(roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID, session *crypto.InboundGroupSession) error {
s.log.Debug("storing inbound group session for the room %q", roomID)
return s.s.PutGroupSession(roomID, senderKey, sessionID, session)
}
// GetGroupSession retrieves an inbound Megolm group session for a room, sender and session.
func (s *Store) GetGroupSession(roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID) (*crypto.InboundGroupSession, error) {
s.log.Debug("loading inbound group session for the room %q", roomID)
return s.s.GetGroupSession(roomID, senderKey, sessionID)
}
// PutWithheldGroupSession tells the store that a specific Megolm session was withheld.
// nolint // method is part of interface and cannot be changed
func (s *Store) PutWithheldGroupSession(content event.RoomKeyWithheldEventContent) error {
s.log.Debug("storing withheld group session")
return s.s.PutWithheldGroupSession(content)
}
// GetWithheldGroupSession gets the event content that was previously inserted with PutWithheldGroupSession.
func (s *Store) GetWithheldGroupSession(roomID id.RoomID, senderKey id.SenderKey, sessionID id.SessionID) (*event.RoomKeyWithheldEventContent, error) {
s.log.Debug("loading withheld group session")
return s.s.GetWithheldGroupSession(roomID, senderKey, sessionID)
}
// GetGroupSessionsForRoom gets all the inbound Megolm sessions for a specific room. This is used for creating key
// export files. Unlike GetGroupSession, this should not return any errors about withheld keys.
func (s *Store) GetGroupSessionsForRoom(roomID id.RoomID) ([]*crypto.InboundGroupSession, error) {
s.log.Debug("loading group session for the room %q", roomID)
return s.s.GetGroupSessionsForRoom(roomID)
}
// GetAllGroupSessions gets all the inbound Megolm sessions in the store. This is used for creating key export
// files. Unlike GetGroupSession, this should not return any errors about withheld keys.
func (s *Store) GetAllGroupSessions() ([]*crypto.InboundGroupSession, error) {
s.log.Debug("loading all group sessions")
return s.s.GetAllGroupSessions()
}
// AddOutboundGroupSession stores an outbound Megolm session, along with the information about the room and involved devices.
func (s *Store) AddOutboundGroupSession(session *crypto.OutboundGroupSession) (err error) {
s.log.Debug("storing outbound group session")
return s.s.AddOutboundGroupSession(session)
}
// UpdateOutboundGroupSession replaces an outbound Megolm session with for same room and session ID.
func (s *Store) UpdateOutboundGroupSession(session *crypto.OutboundGroupSession) error {
s.log.Debug("updating outbound group session")
return s.s.UpdateOutboundGroupSession(session)
}
// GetOutboundGroupSession retrieves the outbound Megolm session for the given room ID.
func (s *Store) GetOutboundGroupSession(roomID id.RoomID) (*crypto.OutboundGroupSession, error) {
s.log.Debug("loading outbound group session")
return s.s.GetOutboundGroupSession(roomID)
}
// RemoveOutboundGroupSession removes the outbound Megolm session for the given room ID.
func (s *Store) RemoveOutboundGroupSession(roomID id.RoomID) error {
s.log.Debug("removing outbound group session")
return s.s.RemoveOutboundGroupSession(roomID)
}
// ValidateMessageIndex returns whether the given event information match the ones stored in the database
// for the given sender key, session ID and index.
// If the event information was not yet stored, it's stored now.
func (s *Store) ValidateMessageIndex(senderKey id.SenderKey, sessionID id.SessionID, eventID id.EventID, index uint, timestamp int64) (bool, error) {
s.log.Debug("validating message index")
return s.s.ValidateMessageIndex(senderKey, sessionID, eventID, index, timestamp)
}
// GetDevices returns a map of device IDs to device identities, including the identity and signing keys, for a given user ID.
func (s *Store) GetDevices(userID id.UserID) (map[id.DeviceID]*id.Device, error) {
s.log.Debug("loading devices of the %q", userID)
return s.s.GetDevices(userID)
}
// GetDevice returns the device dentity for a given user and device ID.
func (s *Store) GetDevice(userID id.UserID, deviceID id.DeviceID) (*id.Device, error) {
s.log.Debug("loading device %q for the %q", deviceID, userID)
return s.s.GetDevice(userID, deviceID)
}
// FindDeviceByKey finds a specific device by its sender key.
func (s *Store) FindDeviceByKey(userID id.UserID, identityKey id.IdentityKey) (*id.Device, error) {
s.log.Debug("loading device of the %q by the key %q", userID, identityKey)
return s.s.FindDeviceByKey(userID, identityKey)
}
// PutDevice stores a single device for a user, replacing it if it exists already.
func (s *Store) PutDevice(userID id.UserID, device *id.Device) error {
s.log.Debug("storing device of the %q", userID)
return s.s.PutDevice(userID, device)
}
// PutDevices stores the device identity information for the given user ID.
func (s *Store) PutDevices(userID id.UserID, devices map[id.DeviceID]*id.Device) error {
s.log.Debug("storing devices of the %q", userID)
return s.s.PutDevices(userID, devices)
}
// FilterTrackedUsers finds all of the user IDs out of the given ones for which the database contains identity information.
func (s *Store) FilterTrackedUsers(users []id.UserID) ([]id.UserID, error) {
s.log.Debug("filtering tracked users")
return s.s.FilterTrackedUsers(users)
}
// PutCrossSigningKey stores a cross-signing key of some user along with its usage.
func (s *Store) PutCrossSigningKey(userID id.UserID, usage id.CrossSigningUsage, key id.Ed25519) error {
s.log.Debug("storing crosssigning key of the %q", userID)
return s.s.PutCrossSigningKey(userID, usage, key)
}
// GetCrossSigningKeys retrieves a user's stored cross-signing keys.
func (s *Store) GetCrossSigningKeys(userID id.UserID) (map[id.CrossSigningUsage]id.CrossSigningKey, error) {
s.log.Debug("loading crosssigning keys of the %q", userID)
return s.s.GetCrossSigningKeys(userID)
}
// PutSignature stores a signature of a cross-signing or device key along with the signer's user ID and key.
func (s *Store) PutSignature(signedUserID id.UserID, signedKey id.Ed25519, signerUserID id.UserID, signerKey id.Ed25519, signature string) error {
s.log.Debug("storing signature")
return s.s.PutSignature(signedUserID, signedKey, signerUserID, signerKey, signature)
}
// GetSignaturesForKeyBy retrieves the stored signatures for a given cross-signing or device key, by the given signer.
func (s *Store) GetSignaturesForKeyBy(userID id.UserID, key id.Ed25519, signerID id.UserID) (map[id.Ed25519]string, error) {
s.log.Debug("loading signatures")
return s.s.GetSignaturesForKeyBy(userID, key, signerID)
}
// IsKeySignedBy returns whether a cross-signing or device key is signed by the given signer.
func (s *Store) IsKeySignedBy(userID id.UserID, key id.Ed25519, signerID id.UserID, signerKey id.Ed25519) (bool, error) {
s.log.Debug("checking if key is signed by")
return s.s.IsKeySignedBy(userID, key, signerID, signerKey)
}
// DropSignaturesByKey deletes the signatures made by the given user and key from the store. It returns the number of signatures deleted.
func (s *Store) DropSignaturesByKey(userID id.UserID, key id.Ed25519) (int64, error) {
s.log.Debug("removing signatures by the %q/%q", userID, key)
return s.s.DropSignaturesByKey(userID, key)
}

View File

@@ -1,209 +0,0 @@
package store
import (
"encoding/json"
"maunium.net/go/mautrix/event"
"maunium.net/go/mautrix/id"
)
var acceptedMembershipTypes = []event.Membership{
event.MembershipJoin,
event.MembershipInvite,
event.MembershipBan,
event.MembershipLeave,
}
// IsEncrypted returns whether a room is encrypted.
func (s *Store) IsEncrypted(roomID id.RoomID) bool {
if !s.encryption {
return false
}
s.log.Debug("checking if room %q is encrypted", roomID)
return s.GetEncryptionEvent(roomID) != nil
}
// SetEncryptionEvent creates or updates room's encryption event info
func (s *Store) SetEncryptionEvent(evt *event.Event) {
if !s.encryption {
return
}
if evt == nil {
return
}
var encryptionEventJSON []byte
encryptionEventJSON, err := json.Marshal(evt)
if err != nil {
s.log.Debug("cannot marshal encryption event: %v", err)
return
}
tx, err := s.db.Begin()
if err != nil {
s.log.Error("cannot begin transaction: %v", err)
return
}
var insert string
switch s.dialect {
case "sqlite3":
insert = "INSERT OR IGNORE INTO rooms VALUES (?, ?)"
case "postgres":
insert = "INSERT INTO rooms VALUES ($1, $2) ON CONFLICT DO NOTHING"
}
update := "UPDATE rooms SET encryption_event = $1 WHERE room_id = $2"
_, err = tx.Exec(update, encryptionEventJSON, evt.RoomID)
if err != nil {
s.log.Error("cannot update encryption event: %v", err)
// nolint // we already have err to return
tx.Rollback()
return
}
_, err = tx.Exec(insert, evt.RoomID, encryptionEventJSON)
if err != nil {
s.log.Error("cannot insert encryption event: %v", err)
// nolint // interface doesn't allow to return error
tx.Rollback()
return
}
err = tx.Commit()
if err != nil {
s.log.Error("cannot commit transaction: %v", err)
}
}
// SetMembership saves room members
func (s *Store) SetMembership(evt *event.Event) {
s.log.Debug("saving membership event for %q", evt.RoomID)
tx, err := s.db.Begin()
if err != nil {
s.log.Error("cannot begin transaction: %v", err)
return
}
var insert string
switch s.dialect {
case "sqlite3":
insert = "INSERT OR IGNORE INTO room_members VALUES (?, ?)"
case "postgres":
insert = "INSERT INTO room_members VALUES ($1, $2) ON CONFLICT DO NOTHING"
}
del := "DELETE FROM room_members WHERE room_id = $1 AND user_id = $2"
membership := evt.Content.AsMember().Membership
if s.shouldIgnoreMembership(membership) {
return
}
if membership.IsInviteOrJoin() {
_, err := tx.Exec(insert, evt.RoomID, evt.GetStateKey())
if err != nil {
s.log.Error("cannot insert membership event: %v", err)
// nolint // interface doesn't allow to return error
tx.Rollback()
return
}
} else {
_, err := tx.Exec(del, evt.RoomID, evt.GetStateKey())
if err != nil {
s.log.Error("cannot delete membership event: %v", err)
// nolint // interface doesn't allow to return error
tx.Rollback()
return
}
}
commitErr := tx.Commit()
if commitErr != nil {
s.log.Error("cannot commit transaction: %v", commitErr)
// nolint // interface doesn't allow to return error
tx.Rollback()
}
}
// GetRoomMembers ...
func (s *Store) GetRoomMembers(roomID id.RoomID) []id.UserID {
s.log.Debug("loading room members of %q", roomID)
query := "SELECT user_id FROM room_members WHERE room_id = $1"
rows, err := s.db.Query(query, roomID)
users := make([]id.UserID, 0)
if err != nil {
s.log.Error("cannot load room members: %v", err)
return users
}
defer rows.Close()
var userID id.UserID
for rows.Next() {
if err := rows.Scan(&userID); err == nil {
users = append(users, userID)
}
}
return users
}
// SaveSession to DB
func (s *Store) SaveSession(userID id.UserID, deviceID id.DeviceID, accessToken string) {
s.log.Debug("saving session credentials of %q/%q", userID, deviceID)
tx, err := s.db.Begin()
if err != nil {
s.log.Error("cannot begin transaction: %v", err)
return
}
var insert string
switch s.dialect {
case "sqlite3":
insert = "INSERT OR IGNORE INTO session VALUES (?, ?, ?)"
case "postgres":
insert = "INSERT INTO session VALUES ($1, $2, $3) ON CONFLICT DO NOTHING"
}
update := "UPDATE session SET access_token = $1, device_id = $2 WHERE user_id = $3"
if _, err = tx.Exec(update, accessToken, deviceID, userID); err != nil {
s.log.Error("cannot update session credentials: %v", err)
// nolint // no need to check error here
tx.Rollback()
return
}
if _, err = tx.Exec(insert, userID, deviceID, accessToken); err != nil {
s.log.Error("cannot insert session credentials: %v", err)
// nolint // no need to check error here
tx.Rollback()
return
}
err = tx.Commit()
if err != nil {
s.log.Error("cannot commit transaction: %v", err)
}
}
// LoadSession from DB (user ID, device ID, access token)
func (s *Store) LoadSession() (id.UserID, id.DeviceID, string) {
s.log.Debug("loading session credentials...")
row := s.db.QueryRow("SELECT * FROM session LIMIT 1")
var userID id.UserID
var deviceID id.DeviceID
var accessToken string
if err := row.Scan(&userID, &deviceID, &accessToken); err != nil {
s.log.Error("cannot load session credentials: %v", err)
return "", "", ""
}
return userID, deviceID, accessToken
}
func (s *Store) shouldIgnoreMembership(membership event.Membership) bool {
for _, mtype := range acceptedMembershipTypes {
if membership == mtype {
return false
}
}
return true
}

View File

@@ -1,66 +0,0 @@
package store
var migrations = []string{
`
CREATE TABLE IF NOT EXISTS user_filter_ids (
user_id VARCHAR(255) PRIMARY KEY,
filter_id VARCHAR(255)
)
`,
`
CREATE TABLE IF NOT EXISTS user_batch_tokens (
user_id VARCHAR(255) PRIMARY KEY,
next_batch_token VARCHAR(255)
)
`,
`
CREATE TABLE IF NOT EXISTS rooms (
room_id VARCHAR(255) PRIMARY KEY,
encryption_event VARCHAR(65535) NULL
)
`,
`
CREATE TABLE IF NOT EXISTS room_members (
room_id VARCHAR(255),
user_id VARCHAR(255),
PRIMARY KEY (room_id, user_id)
)
`,
`
CREATE TABLE IF NOT EXISTS session (
user_id VARCHAR(255),
device_id VARCHAR(255),
access_token VARCHAR(255)
)
`,
}
// CreateTables applies all the pending database migrations.
func (s *Store) CreateTables() error {
s.log.Debug("migrating database...")
tx, beginErr := s.db.Begin()
if beginErr != nil {
s.log.Error("cannot begin transaction: %v", beginErr)
return beginErr
}
for _, query := range migrations {
_, execErr := tx.Exec(query)
if execErr != nil {
s.log.Error("cannot apply migration: %v", execErr)
// nolint // we already have the execErr to return
tx.Rollback()
return execErr
}
}
commitErr := tx.Commit()
if commitErr != nil {
s.log.Error("cannot commit transaction: %v", commitErr)
// nolint // we already have the commitErr to return
tx.Rollback()
return commitErr
}
return nil
}

View File

@@ -1,63 +0,0 @@
package store
import (
"database/sql"
"encoding/json"
"maunium.net/go/mautrix/event"
"maunium.net/go/mautrix/id"
)
// NOTE: functions in that file are for crypto.StateStore implementation
// ref: https://pkg.go.dev/maunium.net/go/mautrix/crypto#StateStore
// GetEncryptionEvent returns the encryption event's content for an encrypted room.
func (s *Store) GetEncryptionEvent(roomID id.RoomID) *event.EncryptionEventContent {
if !s.encryption {
return nil
}
s.log.Debug("finding encryption event of %q", roomID)
query := "SELECT encryption_event FROM rooms WHERE room_id = $1"
row := s.db.QueryRow(query, roomID)
var encryptionEventJSON []byte
err := row.Scan(&encryptionEventJSON)
if err != nil && err != sql.ErrNoRows {
s.log.Error("cannot find encryption event: %v", err)
return nil
}
var encryptionEvent event.EncryptionEventContent
if err := json.Unmarshal(encryptionEventJSON, &encryptionEvent); err != nil {
s.log.Debug("cannot unmarshal encryption event: %q", err)
return nil
}
return &encryptionEvent
}
// FindSharedRooms returns the encrypted rooms that another user is also in for a user ID.
func (s *Store) FindSharedRooms(userID id.UserID) []id.RoomID {
if !s.encryption {
return nil
}
s.log.Debug("loading shared rooms for %q", userID)
query := "SELECT room_id FROM room_members WHERE user_id = $1"
rows, queryErr := s.db.Query(query, userID)
rooms := make([]id.RoomID, 0)
if queryErr != nil {
s.log.Error("cannot load room members: %q", queryErr)
return rooms
}
defer rows.Close()
var roomID id.RoomID
for rows.Next() {
scanErr := rows.Scan(&roomID)
if scanErr != nil {
continue
}
rooms = append(rooms, roomID)
}
return rooms
}

View File

@@ -1,55 +0,0 @@
// Package store implements crypto.Store, crypto.StateStore, mautrix.Storer and some additional "glue methods"
package store
import (
"database/sql"
"maunium.net/go/mautrix/crypto"
"maunium.net/go/mautrix/id"
"maunium.net/go/mautrix/util/dbutil"
"gitlab.com/etke.cc/linkpearl/config"
)
// Store for the matrix
type Store struct {
db *sql.DB
dialect string
log config.Logger
encryption bool
s *crypto.SQLCryptoStore
}
// New store
func New(db *sql.DB, dialect string, log config.Logger) *Store {
return &Store{
db: db,
log: log,
dialect: dialect,
}
}
// WithCrypto adds crypto store support
func (s *Store) WithCrypto(userID id.UserID, deviceID id.DeviceID, logger config.Logger) error {
s.log.Debug("crypto store enabled")
s.encryption = true
db, err := dbutil.NewWithDB(s.db, s.dialect)
if err != nil {
logger.Error("cannot init database: %v", err)
return err
}
s.s = crypto.NewSQLCryptoStore(
db,
dbutil.NoopLogger,
userID.String(),
deviceID,
[]byte(userID),
)
return s.s.DB.Upgrade()
}
// GetDialect returns database dialect
func (s *Store) GetDialect() string {
return s.dialect
}

View File

@@ -1,126 +0,0 @@
package store
import (
"maunium.net/go/mautrix"
"maunium.net/go/mautrix/id"
)
// NOTE: functions in that file are for mautrix.Storer implementation
// ref: https://pkg.go.dev/maunium.net/go/mautrix#Storer
// SaveFilterID to DB
func (s *Store) SaveFilterID(userID id.UserID, filterID string) {
s.log.Debug("saving filter ID %q for %q", filterID, userID)
tx, err := s.db.Begin()
if err != nil {
s.log.Error("cannot begin transaction: %v", err)
return
}
var insert string
switch s.dialect {
case "sqlite3":
insert = "INSERT OR IGNORE INTO user_filter_ids VALUES (?, ?)"
case "postgres":
insert = "INSERT INTO user_filter_ids VALUES ($1, $2) ON CONFLICT DO NOTHING"
}
update := "UPDATE user_filter_ids SET filter_id = $1 WHERE user_id = $2"
_, updateErr := tx.Exec(update, filterID, userID)
if updateErr != nil {
s.log.Error("cannot update filter ID: %v", updateErr)
// nolint // no need to check error here
tx.Rollback()
return
}
_, insertErr := tx.Exec(insert, userID, filterID)
if insertErr != nil {
s.log.Error("cannot create filter ID: %v", insertErr)
// nolint // no need to check error here
tx.Rollback()
return
}
commitErr := tx.Commit()
if commitErr != nil {
s.log.Error("cannot upsert filter ID: %v", commitErr)
// nolint // no need to check error here
tx.Rollback()
}
}
// LoadFilterID from DB
func (s *Store) LoadFilterID(userID id.UserID) string {
s.log.Debug("loading filter ID for %q", userID)
query := "SELECT filter_id FROM user_filter_ids WHERE user_id = $1"
row := s.db.QueryRow(query, userID)
var filterID string
if err := row.Scan(&filterID); err != nil {
s.log.Error("cannot load filter ID: %q", err)
return ""
}
return filterID
}
// SaveNextBatch to DB
func (s *Store) SaveNextBatch(userID id.UserID, nextBatchToken string) {
s.log.Debug("saving next batch token for %q", userID)
tx, err := s.db.Begin()
if err != nil {
s.log.Error("cannot begin transaction: %v", err)
return
}
var insert string
switch s.dialect {
case "sqlite3":
insert = "INSERT OR IGNORE INTO user_batch_tokens VALUES (?, ?)"
case "postgres":
insert = "INSERT INTO user_batch_tokens VALUES ($1, $2) ON CONFLICT DO NOTHING"
}
update := "UPDATE user_batch_tokens SET next_batch_token = $1 WHERE user_id = $2"
if _, err := tx.Exec(update, nextBatchToken, userID); err != nil {
s.log.Error("cannot update next batch token: %v", err)
// nolint // no need to check error here
tx.Rollback()
return
}
if _, err := tx.Exec(insert, userID, nextBatchToken); err != nil {
s.log.Error("cannot insert next batch token: %v", err)
// nolint // no need to check error here
tx.Rollback()
return
}
commitErr := tx.Commit()
if commitErr != nil {
s.log.Error("cannot commit transaction: %v", commitErr)
}
}
// LoadNextBatch from DB
func (s *Store) LoadNextBatch(userID id.UserID) string {
s.log.Debug("loading next batch token for %q", userID)
query := "SELECT next_batch_token FROM user_batch_tokens WHERE user_id = $1"
row := s.db.QueryRow(query, userID)
var batchToken string
if err := row.Scan(&batchToken); err != nil {
s.log.Error("cannot load next batch token: %v", err)
return ""
}
return batchToken
}
// SaveRoom to DB, not implemented
func (s *Store) SaveRoom(room *mautrix.Room) {
s.log.Debug("saving room %q (stub, not implemented)", room.ID)
}
// LoadRoom from DB, not implemented
func (s *Store) LoadRoom(roomID id.RoomID) *mautrix.Room {
s.log.Debug("loading room %q (stub, not implemented)", roomID)
return mautrix.NewRoom(roomID)
}

View File

@@ -11,31 +11,27 @@ import (
// OnEventType allows callers to be notified when there are new events for the given event type.
// There are no duplicate checks.
func (l *Linkpearl) OnEventType(eventType event.Type, callback mautrix.EventHandler) {
l.api.Syncer.(*mautrix.DefaultSyncer).OnEventType(eventType, callback)
l.api.Syncer.(mautrix.ExtensibleSyncer).OnEventType(eventType, callback)
}
// OnSync shortcut to mautrix.DefaultSyncer.OnSync
func (l *Linkpearl) OnSync(callback mautrix.SyncHandler) {
l.api.Syncer.(*mautrix.DefaultSyncer).OnSync(callback)
l.api.Syncer.(mautrix.ExtensibleSyncer).OnSync(callback)
}
// OnEvent shortcut to mautrix.DefaultSyncer.OnEvent
func (l *Linkpearl) OnEvent(callback mautrix.EventHandler) {
l.api.Syncer.(*mautrix.DefaultSyncer).OnEvent(callback)
l.api.Syncer.(mautrix.ExtensibleSyncer).OnEvent(callback)
}
func (l *Linkpearl) initSync() {
if l.olm != nil {
l.api.Syncer.(*mautrix.DefaultSyncer).OnSync(l.olm.ProcessSyncResponse)
l.api.Syncer.(*mautrix.DefaultSyncer).OnEventType(
event.StateEncryption,
func(source mautrix.EventSource, evt *event.Event) {
go l.onEncryption(source, evt)
},
)
}
l.api.Syncer.(*mautrix.DefaultSyncer).OnEventType(
l.api.Syncer.(mautrix.ExtensibleSyncer).OnEventType(
event.StateEncryption,
func(source mautrix.EventSource, evt *event.Event) {
go l.onEncryption(source, evt)
},
)
l.api.Syncer.(mautrix.ExtensibleSyncer).OnEventType(
event.StateMember,
func(source mautrix.EventSource, evt *event.Event) {
go l.onMembership(source, evt)
@@ -43,11 +39,9 @@ func (l *Linkpearl) initSync() {
)
}
func (l *Linkpearl) onMembership(_ mautrix.EventSource, evt *event.Event) {
if l.olm != nil {
l.olm.HandleMemberEvent(evt)
}
l.store.SetMembership(evt)
func (l *Linkpearl) onMembership(src mautrix.EventSource, evt *event.Event) {
l.ch.Machine().HandleMemberEvent(src, evt)
l.api.StateStore.SetMembership(evt.RoomID, id.UserID(evt.GetStateKey()), evt.Content.AsMember().Membership)
// potentially autoaccept invites
l.onInvite(evt)
@@ -78,9 +72,9 @@ func (l *Linkpearl) tryJoin(roomID id.RoomID, retry int) {
_, err := l.api.JoinRoomByID(roomID)
if err != nil {
l.log.Error("cannot join the room %q: %v", roomID, err)
l.log.Error().Err(err).Str("roomID", roomID.String()).Msg("cannot join room")
time.Sleep(5 * time.Second)
l.log.Debug("trying to join again (%d/%d)", retry+1, l.maxretries)
l.log.Error().Err(err).Str("roomID", roomID.String()).Int("retry", retry+1).Msg("trying to join again")
l.tryJoin(roomID, retry+1)
}
}
@@ -92,9 +86,9 @@ func (l *Linkpearl) tryLeave(roomID id.RoomID, retry int) {
_, err := l.api.LeaveRoom(roomID)
if err != nil {
l.log.Error("cannot leave room: %v", err)
l.log.Error().Err(err).Str("roomID", roomID.String()).Msg("cannot leave room")
time.Sleep(5 * time.Second)
l.log.Debug("trying to leave again (%d/%d)", retry+1, l.maxretries)
l.log.Error().Err(err).Str("roomID", roomID.String()).Int("retry", retry+1).Msg("trying to leave again")
l.tryLeave(roomID, retry+1)
}
}
@@ -104,7 +98,12 @@ func (l *Linkpearl) onEmpty(evt *event.Event) {
return
}
members := l.store.GetRoomMembers(evt.RoomID)
members, err := l.api.StateStore.GetRoomJoinedOrInvitedMembers(evt.RoomID)
if err != nil {
l.log.Error().Err(err).Str("roomID", evt.RoomID.String()).Msg("cannot get joined or invited members")
return
}
if len(members) >= 1 && members[0] != l.api.UserID {
return
}
@@ -113,5 +112,5 @@ func (l *Linkpearl) onEmpty(evt *event.Event) {
}
func (l *Linkpearl) onEncryption(_ mautrix.EventSource, evt *event.Event) {
l.store.SetEncryptionEvent(evt)
l.api.StateStore.SetEncryptionEvent(evt.RoomID, evt.Content.AsEncryption())
}