updated deps; updated healthchecks.io integration

This commit is contained in:
Aine
2024-04-07 14:42:12 +03:00
parent 271a4a0e31
commit 15d61f174e
122 changed files with 3432 additions and 4613 deletions

View File

@@ -65,6 +65,7 @@ env vars
* **POSTMOOGLE_STATUSMSG** - presence status message * **POSTMOOGLE_STATUSMSG** - presence status message
* **POSTMOOGLE_MONITORING_SENTRY_DSN** - sentry DSN * **POSTMOOGLE_MONITORING_SENTRY_DSN** - sentry DSN
* **POSTMOOGLE_MONITORING_SENTRY_RATE** - sentry sample rate, from 0 to 100 (default: 20) * **POSTMOOGLE_MONITORING_SENTRY_RATE** - sentry sample rate, from 0 to 100 (default: 20)
* **POSTMOOGLE_MONITORING_HEALTHCHECKS_URL** - healthchecks.io url, default: `https://hc-ping.com`
* **POSTMOOGLE_MONITORING_HEALTHCHECKS_UUID** - healthchecks.io UUID * **POSTMOOGLE_MONITORING_HEALTHCHECKS_UUID** - healthchecks.io UUID
* **POSTMOOGLE_MONITORING_HEALTHCHECKS_DURATION** - heathchecks.io duration between pings in secods (default: 5) * **POSTMOOGLE_MONITORING_HEALTHCHECKS_DURATION** - heathchecks.io duration between pings in secods (default: 5)
* **POSTMOOGLE_LOGLEVEL** - log level * **POSTMOOGLE_LOGLEVEL** - log level

View File

@@ -2,6 +2,7 @@ package main
import ( import (
"database/sql" "database/sql"
"fmt"
"io" "io"
"os" "os"
"os/signal" "os/signal"
@@ -15,7 +16,7 @@ import (
_ "github.com/mattn/go-sqlite3" _ "github.com/mattn/go-sqlite3"
"github.com/mileusna/crontab" "github.com/mileusna/crontab"
"github.com/rs/zerolog" "github.com/rs/zerolog"
"gitlab.com/etke.cc/go/healthchecks" "gitlab.com/etke.cc/go/healthchecks/v2"
"gitlab.com/etke.cc/go/psd" "gitlab.com/etke.cc/go/psd"
"gitlab.com/etke.cc/linkpearl" "gitlab.com/etke.cc/linkpearl"
@@ -85,14 +86,18 @@ func initLog(cfg *config.Config) {
} }
func initHealthchecks(cfg *config.Config) { func initHealthchecks(cfg *config.Config) {
if cfg.Monitoring.HealchecksUUID == "" { if cfg.Monitoring.HealthchecksUUID == "" {
return return
} }
hc = healthchecks.New(cfg.Monitoring.HealchecksUUID, func(operation string, err error) { hc = healthchecks.New(
healthchecks.WithBaseURL(cfg.Monitoring.HealthchecksURL),
healthchecks.WithCheckUUID(cfg.Monitoring.HealthchecksUUID),
healthchecks.WithErrLog(func(operation string, err error) {
log.Error().Err(err).Str("operation", operation).Msg("healthchecks operation failed") log.Error().Err(err).Str("operation", operation).Msg("healthchecks operation failed")
}) }),
)
hc.Start(strings.NewReader("starting postmoogle")) hc.Start(strings.NewReader("starting postmoogle"))
go hc.Auto(cfg.Monitoring.HealthechsDuration) go hc.Auto(cfg.Monitoring.HealthchecksDuration)
} }
func initMatrix(cfg *config.Config) { func initMatrix(cfg *config.Config) {
@@ -200,6 +205,9 @@ func recovery() {
defer shutdown() defer shutdown()
err := recover() err := recover()
if err != nil { if err != nil {
if hc != nil {
hc.ExitStatus(1, strings.NewReader(fmt.Sprintf("panic: %v", err)))
}
sentry.CurrentHub().Recover(err) sentry.CurrentHub().Recover(err)
} }
} }

View File

@@ -40,8 +40,9 @@ func New() *Config {
Monitoring: Monitoring{ Monitoring: Monitoring{
SentryDSN: env.String("monitoring.sentry.dsn", env.String("sentry.dsn", "")), SentryDSN: env.String("monitoring.sentry.dsn", env.String("sentry.dsn", "")),
SentrySampleRate: env.Int("monitoring.sentry.rate", env.Int("sentry.rate", 0)), SentrySampleRate: env.Int("monitoring.sentry.rate", env.Int("sentry.rate", 0)),
HealchecksUUID: env.String("monitoring.healthchecks.uuid", ""), HealthchecksURL: env.String("monitoring.healthchecks.url", defaultConfig.Monitoring.HealthchecksURL),
HealthechsDuration: time.Duration(env.Int("monitoring.healthchecks.duration", int(defaultConfig.Monitoring.HealthechsDuration))) * time.Second, HealthchecksUUID: env.String("monitoring.healthchecks.uuid"),
HealthchecksDuration: time.Duration(env.Int("monitoring.healthchecks.duration", int(defaultConfig.Monitoring.HealthchecksDuration))) * time.Second,
}, },
LogLevel: env.String("loglevel", defaultConfig.LogLevel), LogLevel: env.String("loglevel", defaultConfig.LogLevel),
DB: DB{ DB: DB{

View File

@@ -16,7 +16,8 @@ var defaultConfig = &Config{
}, },
Monitoring: Monitoring{ Monitoring: Monitoring{
SentrySampleRate: 20, SentrySampleRate: 20,
HealthechsDuration: 5, HealthchecksURL: "https://hc-ping.com",
HealthchecksDuration: 60,
}, },
TLS: TLS{ TLS: TLS{
Port: "587", Port: "587",

View File

@@ -70,8 +70,9 @@ type TLS struct {
type Monitoring struct { type Monitoring struct {
SentryDSN string SentryDSN string
SentrySampleRate int SentrySampleRate int
HealchecksUUID string HealthchecksURL string
HealthechsDuration time.Duration HealthchecksUUID string
HealthchecksDuration time.Duration
} }
// Mailboxes config // Mailboxes config

View File

@@ -1,3 +1,3 @@
#!/bin/bash #!/bin/bash
ssmtp -v test+sub@localhost < $1 ssmtp -v aine@gelato.casa < $1

21
go.mod
View File

@@ -10,7 +10,7 @@ require (
github.com/archdx/zerolog-sentry v1.8.2 github.com/archdx/zerolog-sentry v1.8.2
github.com/emersion/go-msgauth v0.6.8 github.com/emersion/go-msgauth v0.6.8
github.com/emersion/go-sasl v0.0.0-20231106173351-e73c9f7bad43 github.com/emersion/go-sasl v0.0.0-20231106173351-e73c9f7bad43
github.com/emersion/go-smtp v0.20.2 github.com/emersion/go-smtp v0.21.0
github.com/fsnotify/fsnotify v1.7.0 github.com/fsnotify/fsnotify v1.7.0
github.com/gabriel-vasile/mimetype v1.4.3 github.com/gabriel-vasile/mimetype v1.4.3
github.com/getsentry/sentry-go v0.27.0 github.com/getsentry/sentry-go v0.27.0
@@ -24,14 +24,14 @@ require (
github.com/rs/zerolog v1.32.0 github.com/rs/zerolog v1.32.0
gitlab.com/etke.cc/go/env v1.1.0 gitlab.com/etke.cc/go/env v1.1.0
gitlab.com/etke.cc/go/fswatcher v1.0.0 gitlab.com/etke.cc/go/fswatcher v1.0.0
gitlab.com/etke.cc/go/healthchecks v1.0.1 gitlab.com/etke.cc/go/healthchecks/v2 v2.0.0
gitlab.com/etke.cc/go/mxidwc v1.0.0 gitlab.com/etke.cc/go/mxidwc v1.0.0
gitlab.com/etke.cc/go/psd v1.1.1 gitlab.com/etke.cc/go/psd v1.1.1
gitlab.com/etke.cc/go/secgen v1.2.0 gitlab.com/etke.cc/go/secgen v1.2.0
gitlab.com/etke.cc/go/validator v1.0.7 gitlab.com/etke.cc/go/validator v1.0.7
gitlab.com/etke.cc/linkpearl v0.0.0-20240211143445-bddf907d137a gitlab.com/etke.cc/linkpearl v0.0.0-20240316115913-106577b88942
golang.org/x/exp v0.0.0-20240222234643-814bf88cf225 golang.org/x/exp v0.0.0-20240404231335-c0f41cb1a7a0
maunium.net/go/mautrix v0.17.0 maunium.net/go/mautrix v0.18.0
) )
require ( require (
@@ -54,12 +54,11 @@ require (
github.com/tidwall/match v1.1.1 // indirect github.com/tidwall/match v1.1.1 // indirect
github.com/tidwall/pretty v1.2.1 // indirect github.com/tidwall/pretty v1.2.1 // indirect
github.com/tidwall/sjson v1.2.5 // indirect github.com/tidwall/sjson v1.2.5 // indirect
github.com/yuin/goldmark v1.7.0 // indirect github.com/yuin/goldmark v1.7.1 // indirect
gitlab.com/etke.cc/go/trysmtp v1.1.3 // indirect gitlab.com/etke.cc/go/trysmtp v1.1.3 // indirect
go.mau.fi/util v0.4.0 // indirect go.mau.fi/util v0.4.1 // indirect
golang.org/x/crypto v0.19.0 // indirect golang.org/x/crypto v0.22.0 // indirect
golang.org/x/net v0.21.0 // indirect golang.org/x/net v0.24.0 // indirect
golang.org/x/sys v0.17.0 // indirect golang.org/x/sys v0.19.0 // indirect
golang.org/x/text v0.14.0 // indirect golang.org/x/text v0.14.0 // indirect
maunium.net/go/maulogger/v2 v2.4.1 // indirect
) )

50
go.sum
View File

@@ -16,8 +16,8 @@ github.com/emersion/go-msgauth v0.6.8/go.mod h1:YDwuyTCUHu9xxmAeVj0eW4INnwB6NNZo
github.com/emersion/go-sasl v0.0.0-20200509203442-7bfe0ed36a21/go.mod h1:iL2twTeMvZnrg54ZoPDNfJaJaqy0xIQFuBdrLsmspwQ= github.com/emersion/go-sasl v0.0.0-20200509203442-7bfe0ed36a21/go.mod h1:iL2twTeMvZnrg54ZoPDNfJaJaqy0xIQFuBdrLsmspwQ=
github.com/emersion/go-sasl v0.0.0-20231106173351-e73c9f7bad43 h1:hH4PQfOndHDlpzYfLAAfl63E8Le6F2+EL/cdhlkyRJY= github.com/emersion/go-sasl v0.0.0-20231106173351-e73c9f7bad43 h1:hH4PQfOndHDlpzYfLAAfl63E8Le6F2+EL/cdhlkyRJY=
github.com/emersion/go-sasl v0.0.0-20231106173351-e73c9f7bad43/go.mod h1:iL2twTeMvZnrg54ZoPDNfJaJaqy0xIQFuBdrLsmspwQ= github.com/emersion/go-sasl v0.0.0-20231106173351-e73c9f7bad43/go.mod h1:iL2twTeMvZnrg54ZoPDNfJaJaqy0xIQFuBdrLsmspwQ=
github.com/emersion/go-smtp v0.20.2 h1:peX42Qnh5Q0q3vrAnRy43R/JwTnnv75AebxbkTL7Ia4= github.com/emersion/go-smtp v0.21.0 h1:ZDZmX9aFUuPlD1lpoT0nC/nozZuIkSCyQIyxdijjCy0=
github.com/emersion/go-smtp v0.20.2/go.mod h1:qm27SGYgoIPRot6ubfQ/GpiPy/g3PaZAVRxiO/sDUgQ= github.com/emersion/go-smtp v0.21.0/go.mod h1:qm27SGYgoIPRot6ubfQ/GpiPy/g3PaZAVRxiO/sDUgQ=
github.com/fsnotify/fsnotify v1.7.0 h1:8JEhPFa5W2WU7YfeZzPNqzMP6Lwt7L2715Ggo0nosvA= github.com/fsnotify/fsnotify v1.7.0 h1:8JEhPFa5W2WU7YfeZzPNqzMP6Lwt7L2715Ggo0nosvA=
github.com/fsnotify/fsnotify v1.7.0/go.mod h1:40Bi/Hjc2AVfZrqy+aj+yEI+/bRxZnMJyTJwOpGvigM= github.com/fsnotify/fsnotify v1.7.0/go.mod h1:40Bi/Hjc2AVfZrqy+aj+yEI+/bRxZnMJyTJwOpGvigM=
github.com/gabriel-vasile/mimetype v1.4.3 h1:in2uUcidCuFcDKtdcBxlR0rJ1+fsokWf+uqxgUFjbI0= github.com/gabriel-vasile/mimetype v1.4.3 h1:in2uUcidCuFcDKtdcBxlR0rJ1+fsokWf+uqxgUFjbI0=
@@ -80,8 +80,8 @@ github.com/rs/zerolog v1.32.0 h1:keLypqrlIjaFsbmJOBdB/qvyF8KEtCWHwobLp5l/mQ0=
github.com/rs/zerolog v1.32.0/go.mod h1:/7mN4D5sKwJLZQ2b/znpjC3/GQWY/xaDXUM0kKWRHss= github.com/rs/zerolog v1.32.0/go.mod h1:/7mN4D5sKwJLZQ2b/znpjC3/GQWY/xaDXUM0kKWRHss=
github.com/ssor/bom v0.0.0-20170718123548-6386211fdfcf h1:pvbZ0lM0XWPBqUKqFU8cmavspvIl9nulOYwdy6IFRRo= github.com/ssor/bom v0.0.0-20170718123548-6386211fdfcf h1:pvbZ0lM0XWPBqUKqFU8cmavspvIl9nulOYwdy6IFRRo=
github.com/ssor/bom v0.0.0-20170718123548-6386211fdfcf/go.mod h1:RJID2RhlZKId02nZ62WenDCkgHFerpIOmW0iT7GKmXM= github.com/ssor/bom v0.0.0-20170718123548-6386211fdfcf/go.mod h1:RJID2RhlZKId02nZ62WenDCkgHFerpIOmW0iT7GKmXM=
github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk= github.com/stretchr/testify v1.9.0 h1:HtqpIVDClZ4nwg75+f6Lvsy/wHu+3BoSGCbBAcpTsTg=
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo= github.com/stretchr/testify v1.9.0/go.mod h1:r2ic/lqez/lEtzL7wO/rwa5dbSLXVDPFyf8C91i36aY=
github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= github.com/tidwall/gjson v1.14.2/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
github.com/tidwall/gjson v1.17.1 h1:wlYEnwqAHgzmhNUFfw7Xalt2JzQvsMx2Se4PcoFCT/U= github.com/tidwall/gjson v1.17.1 h1:wlYEnwqAHgzmhNUFfw7Xalt2JzQvsMx2Se4PcoFCT/U=
github.com/tidwall/gjson v1.17.1/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk= github.com/tidwall/gjson v1.17.1/go.mod h1:/wbyibRr2FHMks5tjHJ5F8dMZh3AcwJEMf5vlfC0lxk=
@@ -92,14 +92,14 @@ github.com/tidwall/pretty v1.2.1 h1:qjsOFOWWQl+N3RsoF5/ssm1pHmJJwhjlSbZ51I6wMl4=
github.com/tidwall/pretty v1.2.1/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= github.com/tidwall/pretty v1.2.1/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY= github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY=
github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28= github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28=
github.com/yuin/goldmark v1.7.0 h1:EfOIvIMZIzHdB/R/zVrikYLPPwJlfMcNczJFMs1m6sA= github.com/yuin/goldmark v1.7.1 h1:3bajkSilaCbjdKVsKdZjZCLBNPL9pYzrCakKaf4U49U=
github.com/yuin/goldmark v1.7.0/go.mod h1:uzxRWxtg69N339t3louHJ7+O03ezfj6PlliRlaOzY1E= github.com/yuin/goldmark v1.7.1/go.mod h1:uzxRWxtg69N339t3louHJ7+O03ezfj6PlliRlaOzY1E=
gitlab.com/etke.cc/go/env v1.1.0 h1:nbMhZkMu6C8lysRlb5siIiylWuyVkGAgEvwWEqz/82o= gitlab.com/etke.cc/go/env v1.1.0 h1:nbMhZkMu6C8lysRlb5siIiylWuyVkGAgEvwWEqz/82o=
gitlab.com/etke.cc/go/env v1.1.0/go.mod h1:e1l4RM5MA1sc0R1w/RBDAESWRwgo5cOG9gx8BKUn2C4= gitlab.com/etke.cc/go/env v1.1.0/go.mod h1:e1l4RM5MA1sc0R1w/RBDAESWRwgo5cOG9gx8BKUn2C4=
gitlab.com/etke.cc/go/fswatcher v1.0.0 h1:uyiVn+1NVCjOLZrXSZouIDBDZBMwVipS4oYuvAFpPzo= gitlab.com/etke.cc/go/fswatcher v1.0.0 h1:uyiVn+1NVCjOLZrXSZouIDBDZBMwVipS4oYuvAFpPzo=
gitlab.com/etke.cc/go/fswatcher v1.0.0/go.mod h1:MqTOxyhXfvaVZQUL9/Ksbl2ow1PTBVu3eqIldvMq0RE= gitlab.com/etke.cc/go/fswatcher v1.0.0/go.mod h1:MqTOxyhXfvaVZQUL9/Ksbl2ow1PTBVu3eqIldvMq0RE=
gitlab.com/etke.cc/go/healthchecks v1.0.1 h1:IxPB+r4KtEM6wf4K7MeQoH1XnuBITMGUqFaaRIgxeUY= gitlab.com/etke.cc/go/healthchecks/v2 v2.0.0 h1:/VX2V/I0kH0Yah546EHcOZkuxbEj+8FBmsnb5uOXGUw=
gitlab.com/etke.cc/go/healthchecks v1.0.1/go.mod h1:EzQjwSawh8tQEX43Ls0dI9mND6iWd5NHtmapdO24fMI= gitlab.com/etke.cc/go/healthchecks/v2 v2.0.0/go.mod h1:DdNc1ESc1cAgOdsIwxxV+RUWgn6ewCpfFKzLuF0kSfc=
gitlab.com/etke.cc/go/mxidwc v1.0.0 h1:6EAlJXvs3nU4RaMegYq6iFlyVvLw7JZYnZmNCGMYQP0= gitlab.com/etke.cc/go/mxidwc v1.0.0 h1:6EAlJXvs3nU4RaMegYq6iFlyVvLw7JZYnZmNCGMYQP0=
gitlab.com/etke.cc/go/mxidwc v1.0.0/go.mod h1:E/0kh45SAN9+ntTG0cwkAEKdaPxzvxVmnjwivm9nmz8= gitlab.com/etke.cc/go/mxidwc v1.0.0/go.mod h1:E/0kh45SAN9+ntTG0cwkAEKdaPxzvxVmnjwivm9nmz8=
gitlab.com/etke.cc/go/psd v1.1.1 h1:UIL0X+thvYaeBTX8/G6lilqAToGCypihujGu5gtK5zQ= gitlab.com/etke.cc/go/psd v1.1.1 h1:UIL0X+thvYaeBTX8/G6lilqAToGCypihujGu5gtK5zQ=
@@ -110,30 +110,28 @@ gitlab.com/etke.cc/go/trysmtp v1.1.3 h1:e2EHond77onMaecqCg6mWumffTSEf+ycgj88nbee
gitlab.com/etke.cc/go/trysmtp v1.1.3/go.mod h1:lOO7tTdAE0a3ETV3wN3GJ7I1Tqewu7YTpPWaOmTteV0= gitlab.com/etke.cc/go/trysmtp v1.1.3/go.mod h1:lOO7tTdAE0a3ETV3wN3GJ7I1Tqewu7YTpPWaOmTteV0=
gitlab.com/etke.cc/go/validator v1.0.7 h1:4BGDTa9x68vJhbyn7m8W2yX+2Nb5im9+JLRrgoLUlF4= gitlab.com/etke.cc/go/validator v1.0.7 h1:4BGDTa9x68vJhbyn7m8W2yX+2Nb5im9+JLRrgoLUlF4=
gitlab.com/etke.cc/go/validator v1.0.7/go.mod h1:Id0SxRj0J3IPhiKlj0w1plxVLZfHlkwipn7HfRZsDts= gitlab.com/etke.cc/go/validator v1.0.7/go.mod h1:Id0SxRj0J3IPhiKlj0w1plxVLZfHlkwipn7HfRZsDts=
gitlab.com/etke.cc/linkpearl v0.0.0-20240211143445-bddf907d137a h1:30WtX+uepGqyFnU7jIockJWxQUeYdljhhk63DCOXLZs= gitlab.com/etke.cc/linkpearl v0.0.0-20240316115913-106577b88942 h1:hhDXBsDcYgAit9gwfvawnPXMIwHNKL9DL1kCCyyzB8A=
gitlab.com/etke.cc/linkpearl v0.0.0-20240211143445-bddf907d137a/go.mod h1:3lqQGDDtk52Jm8PD3mZ3qhmIp4JXuq95waWH5vmEacc= gitlab.com/etke.cc/linkpearl v0.0.0-20240316115913-106577b88942/go.mod h1:0AIH2o0fi4WoZhMw+tW63rrcI5aERH9c34RVHQXn1Q0=
go.mau.fi/util v0.4.0 h1:S2X3qU4pUcb/vxBRfAuZjbrR9xVMAXSjQojNBLPBbhs= go.mau.fi/util v0.4.1 h1:3EC9KxIXo5+h869zDGf5OOZklRd/FjeVnimTwtm3owg=
go.mau.fi/util v0.4.0/go.mod h1:leeiHtgVBuN+W9aDii3deAXnfC563iN3WK6BF8/AjNw= go.mau.fi/util v0.4.1/go.mod h1:GjkTEBsehYZbSh2LlE6cWEn+6ZIZTGrTMM/5DMNlmFY=
golang.org/x/crypto v0.19.0 h1:ENy+Az/9Y1vSrlrvBSyna3PITt4tiZLf7sgCjZBX7Wo= golang.org/x/crypto v0.22.0 h1:g1v0xeRhjcugydODzvb3mEM9SQ0HGp9s/nh3COQ/C30=
golang.org/x/crypto v0.19.0/go.mod h1:Iy9bg/ha4yyC70EfRS8jz+B6ybOBKMaSxLj6P6oBDfU= golang.org/x/crypto v0.22.0/go.mod h1:vr6Su+7cTlO45qkww3VDJlzDn0ctJvRgYbC2NvXHt+M=
golang.org/x/exp v0.0.0-20240222234643-814bf88cf225 h1:LfspQV/FYTatPTr/3HzIcmiUFH7PGP+OQ6mgDYo3yuQ= golang.org/x/exp v0.0.0-20240404231335-c0f41cb1a7a0 h1:985EYyeCOxTpcgOTJpflJUwOeEz0CQOdPt73OzpE9F8=
golang.org/x/exp v0.0.0-20240222234643-814bf88cf225/go.mod h1:CxmFvTBINI24O/j8iY7H1xHzx2i4OsyguNBmN/uPtqc= golang.org/x/exp v0.0.0-20240404231335-c0f41cb1a7a0/go.mod h1:/lliqkxwWAhPjf5oSOIJup2XcqJaw8RGS6k3TGEc7GI=
golang.org/x/net v0.0.0-20180911220305-26e67e76b6c3/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180911220305-26e67e76b6c3/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.21.0 h1:AQyQV4dYCvJ7vGmJyKki9+PBdyvhkSd8EIx/qb0AYv4= golang.org/x/net v0.24.0 h1:1PcaxkF854Fu3+lvBIx5SYn9wRlBzzcnHZSiaFFAb0w=
golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44= golang.org/x/net v0.24.0/go.mod h1:2Q7sJY5mzlzWjKtYUEXSlBWCdyaioyXzRB2RtU8KVE8=
golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.17.0 h1:25cE3gD+tdBA7lp7QfhuV+rJiE9YXTcS3VG1SqssI/Y= golang.org/x/sys v0.19.0 h1:q5f1RH2jigJ1MoAWp2KTp3gm5zAGFUTarQZ5U386+4o=
golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.19.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/term v0.17.0 h1:mkTF7LCd6WGJNL3K1Ad7kwxNfYAW6a8a8QqtMblp/4U= golang.org/x/term v0.19.0 h1:+ThwsDv+tYfnJFhF4L8jITxu1tdTWRTZpdsWgEgjL6Q=
golang.org/x/term v0.17.0/go.mod h1:lLRBjIVuehSbZlaOtGMbcMncT+aqLLLmKrsjNrUguwk= golang.org/x/term v0.19.0/go.mod h1:2CuTdWZ7KHSQwUzKva0cbMg6q2DMI3Mmxp+gKJbskEk=
golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ= golang.org/x/text v0.14.0 h1:ScX5w1eTa3QqT8oi6+ziP7dTV1S2+ALU0bI+0zXKWiQ=
golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU= golang.org/x/text v0.14.0/go.mod h1:18ZOQIKpY8NJVqYksKHtTdi31H5itFRjB5/qKTNYzSU=
gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0= gopkg.in/check.v1 v0.0.0-20161208181325-20d25e280405/go.mod h1:Co6ibVJAznAaIkqp8huTwlJQCZ016jof/cbN4VW5Yz0=
gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA= gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
maunium.net/go/maulogger/v2 v2.4.1 h1:N7zSdd0mZkB2m2JtFUsiGTQQAdP0YeFWT7YMc80yAL8= maunium.net/go/mautrix v0.18.0 h1:sNsApeSWB8x0hLjGcdmi5JqO6Tvp2PVkiSStz+Yas6k=
maunium.net/go/maulogger/v2 v2.4.1/go.mod h1:omPuYwYBILeVQobz8uO3XC8DIRuEb5rXYlQSuqrbCho= maunium.net/go/mautrix v0.18.0/go.mod h1:STwJZ+6CAeiEQs7fYCkd5aC12XR5DXANE6Swy/PBKGo=
maunium.net/go/mautrix v0.17.0 h1:scc1qlUbzPn+wc+3eAPquyD+3gZwwy/hBANBm+iGKK8=
maunium.net/go/mautrix v0.17.0/go.mod h1:j+puTEQCEydlVxhJ/dQP5chfa26TdvBO7X6F3Ataav8=

View File

@@ -2,6 +2,8 @@ package smtp
import ( import (
"io" "io"
"github.com/emersion/go-sasl"
) )
var ( var (
@@ -20,6 +22,11 @@ var (
EnhancedCode: EnhancedCode{5, 7, 0}, EnhancedCode: EnhancedCode{5, 7, 0},
Message: "Authentication not supported", Message: "Authentication not supported",
} }
ErrAuthUnknownMechanism = &SMTPError{
Code: 504,
EnhancedCode: EnhancedCode{5, 7, 4},
Message: "Unsupported authentication mechanism",
}
) )
// A SMTP server backend. // A SMTP server backend.
@@ -27,6 +34,17 @@ type Backend interface {
NewSession(c *Conn) (Session, error) NewSession(c *Conn) (Session, error)
} }
// BackendFunc is an adapter to allow the use of an ordinary function as a
// Backend.
type BackendFunc func(c *Conn) (Session, error)
var _ Backend = (BackendFunc)(nil)
// NewSession calls f(c).
func (f BackendFunc) NewSession(c *Conn) (Session, error) {
return f(c)
}
// Session is used by servers to respond to an SMTP client. // Session is used by servers to respond to an SMTP client.
// //
// The methods are called when the remote client issues the matching command. // The methods are called when the remote client issues the matching command.
@@ -37,9 +55,6 @@ type Session interface {
// Free all resources associated with session. // Free all resources associated with session.
Logout() error Logout() error
// Authenticate the user using SASL PLAIN.
AuthPlain(username, password string) error
// Set return path for currently processed message. // Set return path for currently processed message.
Mail(from string, opts *MailOptions) error Mail(from string, opts *MailOptions) error
// Add recipient for currently processed message. // Add recipient for currently processed message.
@@ -76,3 +91,12 @@ type LMTPSession interface {
type StatusCollector interface { type StatusCollector interface {
SetStatus(rcptTo string, err error) SetStatus(rcptTo string, err error)
} }
// AuthSession is an add-on interface for Session. It provides support for the
// AUTH extension.
type AuthSession interface {
Session
AuthMechanisms() []string
Auth(mech string) (sasl.Server, error)
}

View File

@@ -27,10 +27,7 @@ type Client struct {
text *textproto.Conn text *textproto.Conn
serverName string serverName string
lmtp bool lmtp bool
// map of supported extensions ext map[string]string // supported extensions
ext map[string]string
// supported auth mechanisms
auth []string
localName string // the name to use in HELO/EHLO/LHLO localName string // the name to use in HELO/EHLO/LHLO
didHello bool // whether we've said HELO/EHLO/LHLO didHello bool // whether we've said HELO/EHLO/LHLO
helloError error // the error from the hello helloError error // the error from the hello
@@ -54,7 +51,8 @@ var defaultDialer = net.Dialer{Timeout: defaultTimeout}
// Dial returns a new Client connected to an SMTP server at addr. The addr must // Dial returns a new Client connected to an SMTP server at addr. The addr must
// include a port, as in "mail.example.com:smtp". // include a port, as in "mail.example.com:smtp".
// //
// This function returns a plaintext connection. To enable TLS, use StartTLS. // This function returns a plaintext connection. To enable TLS, use
// DialStartTLS.
func Dial(addr string) (*Client, error) { func Dial(addr string) (*Client, error) {
conn, err := defaultDialer.Dial("tcp", addr) conn, err := defaultDialer.Dial("tcp", addr)
if err != nil { if err != nil {
@@ -83,6 +81,22 @@ func DialTLS(addr string, tlsConfig *tls.Config) (*Client, error) {
return client, nil return client, nil
} }
// DialStartTLS retruns a new Client connected to an SMTP server via STARTTLS
// at addr. The addr must include a port, as in "mail.example.com:smtp".
//
// A nil tlsConfig is equivalent to a zero tls.Config.
func DialStartTLS(addr string, tlsConfig *tls.Config) (*Client, error) {
c, err := Dial(addr)
if err != nil {
return nil, err
}
if err := initStartTLS(c, tlsConfig); err != nil {
c.Close()
return nil, err
}
return c, nil
}
// NewClient returns a new Client using an existing connection and host as a // NewClient returns a new Client using an existing connection and host as a
// server name to be used when authenticating. // server name to be used when authenticating.
func NewClient(conn net.Conn) *Client { func NewClient(conn net.Conn) *Client {
@@ -102,6 +116,29 @@ func NewClient(conn net.Conn) *Client {
return c return c
} }
// NewClientStartTLS creates a new Client and performs a STARTTLS command.
func NewClientStartTLS(conn net.Conn, tlsConfig *tls.Config) (*Client, error) {
c := NewClient(conn)
if err := initStartTLS(c, tlsConfig); err != nil {
c.Close()
return nil, err
}
return c, nil
}
func initStartTLS(c *Client, tlsConfig *tls.Config) error {
if err := c.hello(); err != nil {
return err
}
if ok, _ := c.Extension("STARTTLS"); !ok {
return errors.New("smtp: server doesn't support STARTTLS")
}
if err := c.startTLS(tlsConfig); err != nil {
return err
}
return nil
}
// NewClientLMTP returns a new LMTP Client (as defined in RFC 2033) using an // NewClientLMTP returns a new LMTP Client (as defined in RFC 2033) using an
// existing connection and host as a server name to be used when authenticating. // existing connection and host as a server name to be used when authenticating.
func NewClientLMTP(conn net.Conn) *Client { func NewClientLMTP(conn net.Conn) *Client {
@@ -247,20 +284,17 @@ func (c *Client) ehlo() error {
} }
} }
} }
if mechs, ok := ext["AUTH"]; ok {
c.auth = strings.Split(mechs, " ")
}
c.ext = ext c.ext = ext
return err return err
} }
// StartTLS sends the STARTTLS command and encrypts all further communication. // startTLS sends the STARTTLS command and encrypts all further communication.
// Only servers that advertise the STARTTLS extension support this function. // Only servers that advertise the STARTTLS extension support this function.
// //
// A nil config is equivalent to a zero tls.Config. // A nil config is equivalent to a zero tls.Config.
// //
// If server returns an error, it will be of type *SMTPError. // If server returns an error, it will be of type *SMTPError.
func (c *Client) StartTLS(config *tls.Config) error { func (c *Client) startTLS(config *tls.Config) error {
if err := c.hello(); err != nil { if err := c.hello(); err != nil {
return err return err
} }
@@ -284,7 +318,7 @@ func (c *Client) StartTLS(config *tls.Config) error {
} }
// TLSConnectionState returns the client's TLS connection state. // TLSConnectionState returns the client's TLS connection state.
// The return values are their zero values if StartTLS did // The return values are their zero values if STARTTLS did
// not succeed. // not succeed.
func (c *Client) TLSConnectionState() (state tls.ConnectionState, ok bool) { func (c *Client) TLSConnectionState() (state tls.ConnectionState, ok bool) {
tc, ok := c.conn.(*tls.Conn) tc, ok := c.conn.(*tls.Conn)
@@ -572,7 +606,7 @@ func (c *Client) LMTPData(statusCb func(rcpt string, status *SMTPError)) (io.Wri
// address from, to addresses to, with message r. // address from, to addresses to, with message r.
// //
// This function does not start TLS, nor does it perform authentication. Use // This function does not start TLS, nor does it perform authentication. Use
// StartTLS and Auth before-hand if desirable. // DialStartTLS and Auth before-hand if desirable.
// //
// The addresses in the to parameter are the SMTP RCPT addresses. // The addresses in the to parameter are the SMTP RCPT addresses.
// //
@@ -606,6 +640,46 @@ func (c *Client) SendMail(from string, to []string, r io.Reader) error {
var testHookStartTLS func(*tls.Config) // nil, except for tests var testHookStartTLS func(*tls.Config) // nil, except for tests
func sendMail(addr string, implicitTLS bool, a sasl.Client, from string, to []string, r io.Reader) error {
if err := validateLine(from); err != nil {
return err
}
for _, recp := range to {
if err := validateLine(recp); err != nil {
return err
}
}
var (
c *Client
err error
)
if implicitTLS {
c, err = DialTLS(addr, nil)
} else {
c, err = DialStartTLS(addr, nil)
}
if err != nil {
return err
}
defer c.Close()
if a != nil {
if ok, _ := c.Extension("AUTH"); !ok {
return errors.New("smtp: server doesn't support AUTH")
}
if err = c.Auth(a); err != nil {
return err
}
}
if err := c.SendMail(from, to, r); err != nil {
return err
}
return c.Quit()
}
// SendMail connects to the server at addr, switches to TLS, authenticates with // SendMail connects to the server at addr, switches to TLS, authenticates with
// the optional SASL client, and then sends an email from address from, to // the optional SASL client, and then sends an email from address from, to
// addresses to, with message r. The addr must include a port, as in // addresses to, with message r. The addr must include a port, as in
@@ -628,76 +702,12 @@ var testHookStartTLS func(*tls.Config) // nil, except for tests
// attachments (see the mime/multipart package or the go-message package), or // attachments (see the mime/multipart package or the go-message package), or
// other mail functionality. // other mail functionality.
func SendMail(addr string, a sasl.Client, from string, to []string, r io.Reader) error { func SendMail(addr string, a sasl.Client, from string, to []string, r io.Reader) error {
if err := validateLine(from); err != nil { return sendMail(addr, false, a, from, to, r)
return err
}
for _, recp := range to {
if err := validateLine(recp); err != nil {
return err
}
}
c, err := Dial(addr)
if err != nil {
return err
}
defer c.Close()
if err = c.hello(); err != nil {
return err
}
if ok, _ := c.Extension("STARTTLS"); !ok {
return errors.New("smtp: server doesn't support STARTTLS")
}
if err = c.StartTLS(nil); err != nil {
return err
}
if a != nil {
if ok, _ := c.Extension("AUTH"); !ok {
return errors.New("smtp: server doesn't support AUTH")
}
if err = c.Auth(a); err != nil {
return err
}
}
if err := c.SendMail(from, to, r); err != nil {
return err
}
return c.Quit()
} }
// SendMailTLS works like SendMail, but with implicit TLS. // SendMailTLS works like SendMail, but with implicit TLS.
func SendMailTLS(addr string, a sasl.Client, from string, to []string, r io.Reader) error { func SendMailTLS(addr string, a sasl.Client, from string, to []string, r io.Reader) error {
if err := validateLine(from); err != nil { return sendMail(addr, true, a, from, to, r)
return err
}
for _, recp := range to {
if err := validateLine(recp); err != nil {
return err
}
}
c, err := DialTLS(addr, nil)
if err != nil {
return err
}
defer c.Close()
if err = c.hello(); err != nil {
return err
}
if a != nil {
if ok, _ := c.Extension("AUTH"); !ok {
return errors.New("smtp: server doesn't support AUTH")
}
if err = c.Auth(a); err != nil {
return err
}
}
if err := c.SendMail(from, to, r); err != nil {
return err
}
return c.Quit()
} }
// Extension reports whether an extension is support by the server. // Extension reports whether an extension is support by the server.
@@ -708,14 +718,47 @@ func (c *Client) Extension(ext string) (bool, string) {
if err := c.hello(); err != nil { if err := c.hello(); err != nil {
return false, "" return false, ""
} }
if c.ext == nil {
return false, ""
}
ext = strings.ToUpper(ext) ext = strings.ToUpper(ext)
param, ok := c.ext[ext] param, ok := c.ext[ext]
return ok, param return ok, param
} }
// SupportsAuth checks whether an authentication mechanism is supported.
func (c *Client) SupportsAuth(mech string) bool {
if err := c.hello(); err != nil {
return false
}
mechs, ok := c.ext["AUTH"]
if !ok {
return false
}
for _, m := range strings.Split(mechs, " ") {
if strings.EqualFold(m, mech) {
return true
}
}
return false
}
// MaxMessageSize returns the maximum message size accepted by the server.
// 0 means unlimited.
//
// If the server doesn't convey this information, ok = false is returned.
func (c *Client) MaxMessageSize() (size int, ok bool) {
if err := c.hello(); err != nil {
return 0, false
}
v := c.ext["SIZE"]
if v == "" {
return 0, false
}
size, err := strconv.Atoi(v)
if err != nil || size < 0 {
return 0, false
}
return size, true
}
// Reset sends the RSET command to the server, aborting the current mail // Reset sends the RSET command to the server, aborting the current mail
// transaction. // transaction.
func (c *Client) Reset() error { func (c *Client) Reset() error {

View File

@@ -15,6 +15,8 @@ import (
"strings" "strings"
"sync" "sync"
"time" "time"
"github.com/emersion/go-sasl"
) )
// Number of errors we'll tolerate per connection before closing. Defaults to 3. // Number of errors we'll tolerate per connection before closing. Defaults to 3.
@@ -139,11 +141,7 @@ func (c *Conn) handle(cmd string, arg string) {
c.writeResponse(221, EnhancedCode{2, 0, 0}, "Bye") c.writeResponse(221, EnhancedCode{2, 0, 0}, "Bye")
c.Close() c.Close()
case "AUTH": case "AUTH":
if c.server.AuthDisabled {
c.protocolError(500, EnhancedCode{5, 5, 2}, "Syntax error, AUTH command unrecognized")
} else {
c.handleAuth(arg) c.handleAuth(arg)
}
case "STARTTLS": case "STARTTLS":
c.handleStartTLS() c.handleStartTLS()
default: default:
@@ -205,7 +203,7 @@ func (c *Conn) Conn() net.Conn {
func (c *Conn) authAllowed() bool { func (c *Conn) authAllowed() bool {
_, isTLS := c.TLSConnectionState() _, isTLS := c.TLSConnectionState()
return !c.server.AuthDisabled && (isTLS || c.server.AllowInsecureAuth) return isTLS || c.server.AllowInsecureAuth
} }
// protocolError writes errors responses and closes the connection once too many // protocolError writes errors responses and closes the connection once too many
@@ -250,19 +248,27 @@ func (c *Conn) handleGreet(enhanced bool, arg string) {
return return
} }
caps := []string{} caps := []string{
caps = append(caps, c.server.caps...) "PIPELINING",
"8BITMIME",
"ENHANCEDSTATUSCODES",
"CHUNKING",
}
if _, isTLS := c.TLSConnectionState(); c.server.TLSConfig != nil && !isTLS { if _, isTLS := c.TLSConnectionState(); c.server.TLSConfig != nil && !isTLS {
caps = append(caps, "STARTTLS") caps = append(caps, "STARTTLS")
} }
if c.authAllowed() { if c.authAllowed() {
mechs := c.authMechanisms()
authCap := "AUTH" authCap := "AUTH"
for name := range c.server.auths { for _, name := range mechs {
authCap += " " + name authCap += " " + name
} }
if len(mechs) > 0 {
caps = append(caps, authCap) caps = append(caps, authCap)
} }
}
if c.server.EnableSMTPUTF8 { if c.server.EnableSMTPUTF8 {
caps = append(caps, "SMTPUTF8") caps = append(caps, "SMTPUTF8")
} }
@@ -280,6 +286,9 @@ func (c *Conn) handleGreet(enhanced bool, arg string) {
} else { } else {
caps = append(caps, "SIZE") caps = append(caps, "SIZE")
} }
if c.server.MaxRecipients > 0 {
caps = append(caps, fmt.Sprintf("LIMITS RCPTMAX=%v", c.server.MaxRecipients))
}
args := []string{"Hello " + domain} args := []string{"Hello " + domain}
args = append(args, caps...) args = append(args, caps...)
@@ -348,16 +357,18 @@ func (c *Conn) handleMail(arg string) {
} }
opts.RequireTLS = true opts.RequireTLS = true
case "BODY": case "BODY":
switch value { value = strings.ToUpper(value)
case "BINARYMIME": switch BodyType(value) {
case BodyBinaryMIME:
if !c.server.EnableBINARYMIME { if !c.server.EnableBINARYMIME {
c.writeResponse(504, EnhancedCode{5, 5, 4}, "BINARYMIME is not implemented") c.writeResponse(504, EnhancedCode{5, 5, 4}, "BINARYMIME is not implemented")
return return
} }
c.binarymime = true c.binarymime = true
case "7BIT", "8BITMIME": case Body7Bit, Body8BitMIME:
// This space is intentionally left blank
default: default:
c.writeResponse(500, EnhancedCode{5, 5, 4}, "Unknown BODY value") c.writeResponse(501, EnhancedCode{5, 5, 4}, "Unknown BODY value")
return return
} }
opts.Body = BodyType(value) opts.Body = BodyType(value)
@@ -765,7 +776,7 @@ func (c *Conn) handleAuth(arg string) {
return return
} }
if _, isTLS := c.TLSConnectionState(); !isTLS && !c.server.AllowInsecureAuth { if !c.authAllowed() {
c.writeResponse(523, EnhancedCode{5, 7, 10}, "TLS is required") c.writeResponse(523, EnhancedCode{5, 7, 10}, "TLS is required")
return return
} }
@@ -778,18 +789,21 @@ func (c *Conn) handleAuth(arg string) {
var err error var err error
ir, err = base64.StdEncoding.DecodeString(parts[1]) ir, err = base64.StdEncoding.DecodeString(parts[1])
if err != nil { if err != nil {
c.writeResponse(454, EnhancedCode{4, 7, 0}, "Invalid base64 data")
return return
} }
} }
newSasl, ok := c.server.auths[mechanism] sasl, err := c.auth(mechanism)
if !ok { if err != nil {
c.writeResponse(504, EnhancedCode{5, 7, 4}, "Unsupported authentication mechanism") if smtpErr, ok := err.(*SMTPError); ok {
c.writeResponse(smtpErr.Code, smtpErr.EnhancedCode, smtpErr.Message)
} else {
c.writeResponse(454, EnhancedCode{4, 7, 0}, err.Error())
}
return return
} }
sasl := newSasl(c)
response := ir response := ir
for { for {
challenge, done, err := sasl.Next(response) challenge, done, err := sasl.Next(response)
@@ -834,6 +848,20 @@ func (c *Conn) handleAuth(arg string) {
c.didAuth = true c.didAuth = true
} }
func (c *Conn) authMechanisms() []string {
if authSession, ok := c.Session().(AuthSession); ok {
return authSession.AuthMechanisms()
}
return nil
}
func (c *Conn) auth(mech string) (sasl.Server, error) {
if authSession, ok := c.Session().(AuthSession); ok {
return authSession.Auth(mech)
}
return nil, ErrAuthUnknownMechanism
}
func (c *Conn) handleStartTLS() { func (c *Conn) handleStartTLS() {
if _, isTLS := c.TLSConnectionState(); isTLS { if _, isTLS := c.TLSConnectionState(); isTLS {
c.writeResponse(502, EnhancedCode{5, 5, 1}, "Already running in TLS") c.writeResponse(502, EnhancedCode{5, 5, 1}, "Already running in TLS")

View File

@@ -10,17 +10,12 @@ import (
"os" "os"
"sync" "sync"
"time" "time"
"github.com/emersion/go-sasl"
) )
var ( var (
ErrServerClosed = errors.New("smtp: server already closed") ErrServerClosed = errors.New("smtp: server already closed")
) )
// A function that creates SASL servers.
type SaslServerFactory func(conn *Conn) sasl.Server
// Logger interface is used by Server to report unexpected internal errors. // Logger interface is used by Server to report unexpected internal errors.
type Logger interface { type Logger interface {
Printf(format string, v ...interface{}) Printf(format string, v ...interface{})
@@ -64,17 +59,10 @@ type Server struct {
// Should be used only if backend supports it. // Should be used only if backend supports it.
EnableDSN bool EnableDSN bool
// If set, the AUTH command will not be advertised and authentication
// attempts will be rejected. This setting overrides AllowInsecureAuth.
AuthDisabled bool
// The server backend. // The server backend.
Backend Backend Backend Backend
wg sync.WaitGroup wg sync.WaitGroup
caps []string
auths map[string]SaslServerFactory
done chan struct{} done chan struct{}
locker sync.Mutex locker sync.Mutex
@@ -91,23 +79,6 @@ func NewServer(be Backend) *Server {
Backend: be, Backend: be,
done: make(chan struct{}, 1), done: make(chan struct{}, 1),
ErrorLog: log.New(os.Stderr, "smtp/server ", log.LstdFlags), ErrorLog: log.New(os.Stderr, "smtp/server ", log.LstdFlags),
caps: []string{"PIPELINING", "8BITMIME", "ENHANCEDSTATUSCODES", "CHUNKING"},
auths: map[string]SaslServerFactory{
sasl.Plain: func(conn *Conn) sasl.Server {
return sasl.NewPlainServer(func(identity, username, password string) error {
if identity != "" && identity != username {
return errors.New("identities not supported")
}
sess := conn.Session()
if sess == nil {
panic("No session when AUTH is called")
}
return sess.AuthPlain(username, password)
})
},
},
conns: make(map[*Conn]struct{}), conns: make(map[*Conn]struct{}),
} }
} }
@@ -329,11 +300,3 @@ func (s *Server) Shutdown(ctx context.Context) error {
return err return err
} }
} }
// EnableAuth enables an authentication mechanism on this server.
//
// This function should not be called directly, it must only be used by
// libraries implementing extensions of the SMTP protocol.
func (s *Server) EnableAuth(name string, f SaslServerFactory) {
s.auths[name] = f
}

View File

@@ -10,6 +10,8 @@ goldmark
goldmark is compliant with CommonMark 0.31.2. goldmark is compliant with CommonMark 0.31.2.
- [goldmark playground](https://yuin.github.io/goldmark/playground/) : Try goldmark online. This playground is built with WASM(5-10MB).
Motivation Motivation
---------------------- ----------------------
I needed a Markdown parser for Go that satisfies the following requirements: I needed a Markdown parser for Go that satisfies the following requirements:
@@ -282,7 +284,7 @@ markdown := goldmark.New(
"https:", "https:",
}), }),
extension.WithLinkifyURLRegexp( extension.WithLinkifyURLRegexp(
xurls.Strict, xurls.Strict(),
), ),
), ),
), ),
@@ -493,6 +495,7 @@ Extensions
- [goldmark-img64](https://github.com/tenkoh/goldmark-img64): Adds support for embedding images into the document as DataURL (base64 encoded). - [goldmark-img64](https://github.com/tenkoh/goldmark-img64): Adds support for embedding images into the document as DataURL (base64 encoded).
- [goldmark-enclave](https://github.com/quail-ink/goldmark-enclave): Adds support for embedding youtube/bilibili video, X's [oembed tweet](https://publish.twitter.com/), [tradingview](https://www.tradingview.com/widget/)'s chart, [quail](https://quail.ink)'s widget into the document. - [goldmark-enclave](https://github.com/quail-ink/goldmark-enclave): Adds support for embedding youtube/bilibili video, X's [oembed tweet](https://publish.twitter.com/), [tradingview](https://www.tradingview.com/widget/)'s chart, [quail](https://quail.ink)'s widget into the document.
- [goldmark-wiki-table](https://github.com/movsb/goldmark-wiki-table): Adds support for embedding Wiki Tables. - [goldmark-wiki-table](https://github.com/movsb/goldmark-wiki-table): Adds support for embedding Wiki Tables.
- [goldmark-tgmd](https://github.com/Mad-Pixels/goldmark-tgmd): A Telegram markdown renderer that can be passed to `goldmark.WithRenderer()`.
### Loading extensions at runtime ### Loading extensions at runtime
[goldmark-dynamic](https://github.com/yuin/goldmark-dynamic) allows you to write a goldmark extension in Lua and load it at runtime without re-compilation. [goldmark-dynamic](https://github.com/yuin/goldmark-dynamic) allows you to write a goldmark extension in Lua and load it at runtime without re-compilation.

View File

@@ -786,7 +786,14 @@ func RenderAttributes(w util.BufWriter, node ast.Node, filter util.BytesFilter)
_, _ = w.Write(attr.Name) _, _ = w.Write(attr.Name)
_, _ = w.WriteString(`="`) _, _ = w.WriteString(`="`)
// TODO: convert numeric values to strings // TODO: convert numeric values to strings
_, _ = w.Write(util.EscapeHTML(attr.Value.([]byte))) var value []byte
switch typed := attr.Value.(type) {
case []byte:
value = typed
case string:
value = util.StringToReadOnlyBytes(typed)
}
_, _ = w.Write(util.EscapeHTML(value))
_ = w.WriteByte('"') _ = w.WriteByte('"')
} }
} }

View File

@@ -1,11 +1,13 @@
# healthchecks # healthchecks
A [healthchecks.io](https://github.com/healthchecks/healthchecks) wrapper A [healthchecks.io](https://github.com/healthchecks/healthchecks) client
check the godoc for information check the godoc for information
```go ```go
hc := healthchecks.New("your-uuid") hc := healthchecks.New(
healthchecks.WithCheckUUID("your-uuid"),
)
go hc.Auto() go hc.Auto()
hc.Log(strings.NewReader("optional body you can attach to any action")) hc.Log(strings.NewReader("optional body you can attach to any action"))

View File

@@ -5,26 +5,58 @@ import (
"io" "io"
"net/http" "net/http"
"strconv" "strconv"
"time"
"github.com/google/uuid"
) )
// Client for healthchecks // Client for healthchecks
type Client struct { type Client struct {
HTTP *http.Client http *http.Client
log func(string, error) log func(string, error)
baseURL string baseURL string
uuid string uuid string
rid string rid string
create bool
done chan bool done chan bool
} }
// init client
func (c *Client) init(options ...Option) {
for _, option := range options {
option(c)
}
if c.log == nil {
c.log = DefaultErrLog
}
if c.baseURL == "" {
c.baseURL = DefaultAPI
}
if c.http == nil {
c.http = &http.Client{Timeout: 10 * time.Second}
}
if c.done == nil {
c.done = make(chan bool, 1)
}
if c.uuid == "" {
randomUUID, _ := uuid.NewRandom()
c.uuid = randomUUID.String()
c.create = true
c.log("uuid", fmt.Errorf("check UUID is not provided, using random %q with auto provision", c.uuid))
}
}
func (c *Client) call(operation, endpoint string, body ...io.Reader) { func (c *Client) call(operation, endpoint string, body ...io.Reader) {
var err error var err error
var resp *http.Response var resp *http.Response
targetURL := fmt.Sprintf("%s/%s%s?rid=%s", c.baseURL, c.uuid, endpoint, c.rid) targetURL := fmt.Sprintf("%s/%s%s?rid=%s", c.baseURL, c.uuid, endpoint, c.rid)
if c.create {
targetURL += "&create=1"
}
if len(body) > 0 { if len(body) > 0 {
resp, err = c.HTTP.Post(targetURL, "text/plain; charset=utf-8", body[0]) resp, err = c.http.Post(targetURL, "text/plain; charset=utf-8", body[0])
} else { } else {
resp, err = c.HTTP.Head(targetURL) resp, err = c.http.Head(targetURL)
} }
if err != nil { if err != nil {
c.log(operation, err) c.log(operation, err)
@@ -32,7 +64,7 @@ func (c *Client) call(operation, endpoint string, body ...io.Reader) {
} }
defer resp.Body.Close() defer resp.Body.Close()
if resp.StatusCode != http.StatusOK { if resp.StatusCode != http.StatusOK && resp.StatusCode != http.StatusCreated {
respb, rerr := io.ReadAll(resp.Body) respb, rerr := io.ReadAll(resp.Body)
if rerr != nil { if rerr != nil {
c.log(operation+":response", rerr) c.log(operation+":response", rerr)

View File

@@ -2,8 +2,6 @@ package healthchecks
import ( import (
"fmt" "fmt"
"net/http"
"time"
"github.com/google/uuid" "github.com/google/uuid"
) )
@@ -20,21 +18,12 @@ var DefaultErrLog = func(operation string, err error) {
} }
// New healthchecks client // New healthchecks client
func New(hcUUID string, errlog ...ErrLog) *Client { func New(options ...Option) *Client {
rid, _ := uuid.NewRandom() rid, _ := uuid.NewRandom()
c := &Client{ c := &Client{
baseURL: DefaultAPI,
uuid: hcUUID,
rid: rid.String(), rid: rid.String(),
done: make(chan bool, 1),
}
c.HTTP = &http.Client{
Timeout: 10 * time.Second,
}
c.log = DefaultErrLog
if len(errlog) > 0 {
c.log = errlog[0]
} }
c.init(options...)
return c return c
} }

22
vendor/gitlab.com/etke.cc/go/healthchecks/v2/justfile generated vendored Normal file
View File

@@ -0,0 +1,22 @@
# show help by default
default:
@just --list --justfile {{ justfile() }}
# update go deps
update *flags:
go get {{flags}} .
go mod tidy
# run linter
lint:
golangci-lint run ./...
# automatically fix liter issues
lintfix:
golangci-lint run --fix ./...
# run unit tests
test:
@go test -cover -coverprofile=cover.out -coverpkg=./... -covermode=set ./...
@go tool cover -func=cover.out
-@rm -f cover.out

View File

@@ -0,0 +1,47 @@
package healthchecks
import "net/http"
type Option func(*Client)
// WithHTTPClient sets the http client
func WithHTTPClient(httpClient *http.Client) Option {
return func(c *Client) {
c.http = httpClient
}
}
// WithBaseURL sets the base url
func WithBaseURL(baseURL string) Option {
return func(c *Client) {
c.baseURL = baseURL
}
}
// WithErrLog sets the error log
func WithErrLog(errLog ErrLog) Option {
return func(c *Client) {
c.log = errLog
}
}
// WithCheckUUID sets the check UUID
func WithCheckUUID(uuid string) Option {
return func(c *Client) {
c.uuid = uuid
}
}
// WithAutoProvision enables auto provision
func WithAutoProvision() Option {
return func(c *Client) {
c.create = true
}
}
// WithDone sets the done channel
func WithDone(done chan bool) Option {
return func(c *Client) {
c.done = done
}
}

View File

@@ -31,3 +31,16 @@ func (j JSON) Value() (driver.Value, error) {
v, err := json.Marshal(j.Data) v, err := json.Marshal(j.Data)
return string(v), err return string(v), err
} }
// JSONPtr is a convenience function for wrapping a pointer to a value in the JSON utility, but removing typed nils
// (i.e. preventing nils from turning into the string "null" in the database).
func JSONPtr[T any](val *T) JSON {
return JSON{Data: UntypedNil(val)}
}
func UntypedNil[T any](val *T) any {
if val == nil {
return nil
}
return val
}

View File

@@ -10,6 +10,7 @@ import (
"context" "context"
"database/sql" "database/sql"
"errors" "errors"
"time"
"golang.org/x/exp/constraints" "golang.org/x/exp/constraints"
) )
@@ -60,6 +61,34 @@ func NumPtr[T constraints.Integer | constraints.Float](val T) *T {
return &val return &val
} }
// UnixPtr returns a pointer to the given time as unix seconds, or nil if the time is zero.
func UnixPtr(val time.Time) *int64 {
return ConvertedPtr(val, time.Time.Unix)
}
// UnixMilliPtr returns a pointer to the given time as unix milliseconds, or nil if the time is zero.
func UnixMilliPtr(val time.Time) *int64 {
return ConvertedPtr(val, time.Time.UnixMilli)
}
type Zeroable interface {
IsZero() bool
}
// ConvertedPtr returns a pointer to the converted version of the given value, or nil if the input is zero.
//
// This is primarily meant for time.Time, but it can be used with any type that has implements `IsZero() bool`.
//
// yourTime := time.Now()
// unixMSPtr := dbutil.TimePtr(yourTime, time.Time.UnixMilli)
func ConvertedPtr[Input Zeroable, Output any](val Input, converter func(Input) Output) *Output {
if val.IsZero() {
return nil
}
converted := converter(val)
return &converted
}
func (qh *QueryHelper[T]) GetDB() *Database { func (qh *QueryHelper[T]) GetDB() *Database {
return qh.db return qh.db
} }

View File

@@ -11,6 +11,7 @@ import (
"database/sql" "database/sql"
"errors" "errors"
"fmt" "fmt"
"runtime"
"time" "time"
"github.com/rs/zerolog" "github.com/rs/zerolog"
@@ -45,27 +46,55 @@ func (db *Database) QueryRow(ctx context.Context, query string, args ...any) *sq
} }
func (db *Database) BeginTx(ctx context.Context, opts *sql.TxOptions) (*LoggingTxn, error) { func (db *Database) BeginTx(ctx context.Context, opts *sql.TxOptions) (*LoggingTxn, error) {
if ctx == nil {
panic("BeginTx() called with nil ctx")
}
return db.LoggingDB.BeginTx(ctx, opts) return db.LoggingDB.BeginTx(ctx, opts)
} }
func (db *Database) DoTxn(ctx context.Context, opts *sql.TxOptions, fn func(ctx context.Context) error) error { func (db *Database) DoTxn(ctx context.Context, opts *sql.TxOptions, fn func(ctx context.Context) error) error {
if ctx == nil {
panic("DoTxn() called with nil ctx")
}
if ctx.Value(ContextKeyDatabaseTransaction) != nil { if ctx.Value(ContextKeyDatabaseTransaction) != nil {
zerolog.Ctx(ctx).Trace().Msg("Already in a transaction, not creating a new one") zerolog.Ctx(ctx).Trace().Msg("Already in a transaction, not creating a new one")
return fn(ctx) return fn(ctx)
} }
log := zerolog.Ctx(ctx).With().Str("db_txn_id", random.String(12)).Logger() log := zerolog.Ctx(ctx).With().Str("db_txn_id", random.String(12)).Logger()
start := time.Now() slowLog := log
defer func() {
dur := time.Since(start) callerSkip := 1
if dur > time.Second { if val := ctx.Value(ContextKeyDoTxnCallerSkip); val != nil {
val := ctx.Value(ContextKeyDoTxnCallerSkip)
callerSkip := 2
if val != nil {
callerSkip += val.(int) callerSkip += val.(int)
} }
log.Warn(). if pc, file, line, ok := runtime.Caller(callerSkip); ok {
slowLog = log.With().Str(zerolog.CallerFieldName, zerolog.CallerMarshalFunc(pc, file, line)).Logger()
}
start := time.Now()
deadlockCh := make(chan struct{})
go func() {
ticker := time.NewTicker(5 * time.Second)
defer ticker.Stop()
for {
select {
case <-ticker.C:
slowLog.Warn().
Dur("duration_seconds", time.Since(start)).
Msg("Transaction still running")
case <-deadlockCh:
return
}
}
}()
defer func() {
close(deadlockCh)
dur := time.Since(start)
if dur > time.Second {
slowLog.Warn().
Float64("duration_seconds", dur.Seconds()). Float64("duration_seconds", dur.Seconds()).
Caller(callerSkip).
Msg("Transaction took long") Msg("Transaction took long")
} }
}() }()
@@ -100,7 +129,7 @@ func (db *Database) DoTxn(ctx context.Context, opts *sql.TxOptions, fn func(ctx
func (db *Database) Conn(ctx context.Context) Execable { func (db *Database) Conn(ctx context.Context) Execable {
if ctx == nil { if ctx == nil {
return &db.LoggingDB panic("Conn() called with nil ctx")
} }
txn, ok := ctx.Value(ContextKeyDatabaseTransaction).(Transaction) txn, ok := ctx.Value(ContextKeyDatabaseTransaction).(Transaction)
if ok { if ok {

View File

@@ -0,0 +1,30 @@
// Copyright (c) 2024 Sumner Evans
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
package jsonbytes
import (
"encoding/base64"
"encoding/json"
)
// UnpaddedBytes is a byte slice that is encoded and decoded using
// [base64.RawStdEncoding] instead of the default padded base64.
type UnpaddedBytes []byte
func (b UnpaddedBytes) MarshalJSON() ([]byte, error) {
return json.Marshal(base64.RawStdEncoding.EncodeToString(b))
}
func (b *UnpaddedBytes) UnmarshalJSON(data []byte) error {
var b64str string
err := json.Unmarshal(data, &b64str)
if err != nil {
return err
}
*b, err = base64.RawStdEncoding.DecodeString(b64str)
return err
}

View File

@@ -33,6 +33,9 @@
#define CONSTBASE R16 #define CONSTBASE R16
#define BLOCKS R17 #define BLOCKS R17
// for VPERMXOR
#define MASK R18
DATA consts<>+0x00(SB)/8, $0x3320646e61707865 DATA consts<>+0x00(SB)/8, $0x3320646e61707865
DATA consts<>+0x08(SB)/8, $0x6b20657479622d32 DATA consts<>+0x08(SB)/8, $0x6b20657479622d32
DATA consts<>+0x10(SB)/8, $0x0000000000000001 DATA consts<>+0x10(SB)/8, $0x0000000000000001
@@ -53,7 +56,11 @@ DATA consts<>+0x80(SB)/8, $0x6b2065746b206574
DATA consts<>+0x88(SB)/8, $0x6b2065746b206574 DATA consts<>+0x88(SB)/8, $0x6b2065746b206574
DATA consts<>+0x90(SB)/8, $0x0000000100000000 DATA consts<>+0x90(SB)/8, $0x0000000100000000
DATA consts<>+0x98(SB)/8, $0x0000000300000002 DATA consts<>+0x98(SB)/8, $0x0000000300000002
GLOBL consts<>(SB), RODATA, $0xa0 DATA consts<>+0xa0(SB)/8, $0x5566774411223300
DATA consts<>+0xa8(SB)/8, $0xddeeffcc99aabb88
DATA consts<>+0xb0(SB)/8, $0x6677445522330011
DATA consts<>+0xb8(SB)/8, $0xeeffccddaabb8899
GLOBL consts<>(SB), RODATA, $0xc0
//func chaCha20_ctr32_vsx(out, inp *byte, len int, key *[8]uint32, counter *uint32) //func chaCha20_ctr32_vsx(out, inp *byte, len int, key *[8]uint32, counter *uint32)
TEXT ·chaCha20_ctr32_vsx(SB),NOSPLIT,$64-40 TEXT ·chaCha20_ctr32_vsx(SB),NOSPLIT,$64-40
@@ -70,6 +77,9 @@ TEXT ·chaCha20_ctr32_vsx(SB),NOSPLIT,$64-40
MOVD $48, R10 MOVD $48, R10
MOVD $64, R11 MOVD $64, R11
SRD $6, LEN, BLOCKS SRD $6, LEN, BLOCKS
// for VPERMXOR
MOVD $consts<>+0xa0(SB), MASK
MOVD $16, R20
// V16 // V16
LXVW4X (CONSTBASE)(R0), VS48 LXVW4X (CONSTBASE)(R0), VS48
ADD $80,CONSTBASE ADD $80,CONSTBASE
@@ -87,6 +97,10 @@ TEXT ·chaCha20_ctr32_vsx(SB),NOSPLIT,$64-40
// V28 // V28
LXVW4X (CONSTBASE)(R11), VS60 LXVW4X (CONSTBASE)(R11), VS60
// Load mask constants for VPERMXOR
LXVW4X (MASK)(R0), V20
LXVW4X (MASK)(R20), V21
// splat slot from V19 -> V26 // splat slot from V19 -> V26
VSPLTW $0, V19, V26 VSPLTW $0, V19, V26
@@ -97,7 +111,7 @@ TEXT ·chaCha20_ctr32_vsx(SB),NOSPLIT,$64-40
MOVD $10, R14 MOVD $10, R14
MOVD R14, CTR MOVD R14, CTR
PCALIGN $16
loop_outer_vsx: loop_outer_vsx:
// V0, V1, V2, V3 // V0, V1, V2, V3
LXVW4X (R0)(CONSTBASE), VS32 LXVW4X (R0)(CONSTBASE), VS32
@@ -128,22 +142,17 @@ loop_outer_vsx:
VSPLTISW $12, V28 VSPLTISW $12, V28
VSPLTISW $8, V29 VSPLTISW $8, V29
VSPLTISW $7, V30 VSPLTISW $7, V30
PCALIGN $16
loop_vsx: loop_vsx:
VADDUWM V0, V4, V0 VADDUWM V0, V4, V0
VADDUWM V1, V5, V1 VADDUWM V1, V5, V1
VADDUWM V2, V6, V2 VADDUWM V2, V6, V2
VADDUWM V3, V7, V3 VADDUWM V3, V7, V3
VXOR V12, V0, V12 VPERMXOR V12, V0, V21, V12
VXOR V13, V1, V13 VPERMXOR V13, V1, V21, V13
VXOR V14, V2, V14 VPERMXOR V14, V2, V21, V14
VXOR V15, V3, V15 VPERMXOR V15, V3, V21, V15
VRLW V12, V27, V12
VRLW V13, V27, V13
VRLW V14, V27, V14
VRLW V15, V27, V15
VADDUWM V8, V12, V8 VADDUWM V8, V12, V8
VADDUWM V9, V13, V9 VADDUWM V9, V13, V9
@@ -165,15 +174,10 @@ loop_vsx:
VADDUWM V2, V6, V2 VADDUWM V2, V6, V2
VADDUWM V3, V7, V3 VADDUWM V3, V7, V3
VXOR V12, V0, V12 VPERMXOR V12, V0, V20, V12
VXOR V13, V1, V13 VPERMXOR V13, V1, V20, V13
VXOR V14, V2, V14 VPERMXOR V14, V2, V20, V14
VXOR V15, V3, V15 VPERMXOR V15, V3, V20, V15
VRLW V12, V29, V12
VRLW V13, V29, V13
VRLW V14, V29, V14
VRLW V15, V29, V15
VADDUWM V8, V12, V8 VADDUWM V8, V12, V8
VADDUWM V9, V13, V9 VADDUWM V9, V13, V9
@@ -195,15 +199,10 @@ loop_vsx:
VADDUWM V2, V7, V2 VADDUWM V2, V7, V2
VADDUWM V3, V4, V3 VADDUWM V3, V4, V3
VXOR V15, V0, V15 VPERMXOR V15, V0, V21, V15
VXOR V12, V1, V12 VPERMXOR V12, V1, V21, V12
VXOR V13, V2, V13 VPERMXOR V13, V2, V21, V13
VXOR V14, V3, V14 VPERMXOR V14, V3, V21, V14
VRLW V15, V27, V15
VRLW V12, V27, V12
VRLW V13, V27, V13
VRLW V14, V27, V14
VADDUWM V10, V15, V10 VADDUWM V10, V15, V10
VADDUWM V11, V12, V11 VADDUWM V11, V12, V11
@@ -225,15 +224,10 @@ loop_vsx:
VADDUWM V2, V7, V2 VADDUWM V2, V7, V2
VADDUWM V3, V4, V3 VADDUWM V3, V4, V3
VXOR V15, V0, V15 VPERMXOR V15, V0, V20, V15
VXOR V12, V1, V12 VPERMXOR V12, V1, V20, V12
VXOR V13, V2, V13 VPERMXOR V13, V2, V20, V13
VXOR V14, V3, V14 VPERMXOR V14, V3, V20, V14
VRLW V15, V29, V15
VRLW V12, V29, V12
VRLW V13, V29, V13
VRLW V14, V29, V14
VADDUWM V10, V15, V10 VADDUWM V10, V15, V10
VADDUWM V11, V12, V11 VADDUWM V11, V12, V11
@@ -249,48 +243,48 @@ loop_vsx:
VRLW V6, V30, V6 VRLW V6, V30, V6
VRLW V7, V30, V7 VRLW V7, V30, V7
VRLW V4, V30, V4 VRLW V4, V30, V4
BC 16, LT, loop_vsx BDNZ loop_vsx
VADDUWM V12, V26, V12 VADDUWM V12, V26, V12
WORD $0x13600F8C // VMRGEW V0, V1, V27 VMRGEW V0, V1, V27
WORD $0x13821F8C // VMRGEW V2, V3, V28 VMRGEW V2, V3, V28
WORD $0x10000E8C // VMRGOW V0, V1, V0 VMRGOW V0, V1, V0
WORD $0x10421E8C // VMRGOW V2, V3, V2 VMRGOW V2, V3, V2
WORD $0x13A42F8C // VMRGEW V4, V5, V29 VMRGEW V4, V5, V29
WORD $0x13C63F8C // VMRGEW V6, V7, V30 VMRGEW V6, V7, V30
XXPERMDI VS32, VS34, $0, VS33 XXPERMDI VS32, VS34, $0, VS33
XXPERMDI VS32, VS34, $3, VS35 XXPERMDI VS32, VS34, $3, VS35
XXPERMDI VS59, VS60, $0, VS32 XXPERMDI VS59, VS60, $0, VS32
XXPERMDI VS59, VS60, $3, VS34 XXPERMDI VS59, VS60, $3, VS34
WORD $0x10842E8C // VMRGOW V4, V5, V4 VMRGOW V4, V5, V4
WORD $0x10C63E8C // VMRGOW V6, V7, V6 VMRGOW V6, V7, V6
WORD $0x13684F8C // VMRGEW V8, V9, V27 VMRGEW V8, V9, V27
WORD $0x138A5F8C // VMRGEW V10, V11, V28 VMRGEW V10, V11, V28
XXPERMDI VS36, VS38, $0, VS37 XXPERMDI VS36, VS38, $0, VS37
XXPERMDI VS36, VS38, $3, VS39 XXPERMDI VS36, VS38, $3, VS39
XXPERMDI VS61, VS62, $0, VS36 XXPERMDI VS61, VS62, $0, VS36
XXPERMDI VS61, VS62, $3, VS38 XXPERMDI VS61, VS62, $3, VS38
WORD $0x11084E8C // VMRGOW V8, V9, V8 VMRGOW V8, V9, V8
WORD $0x114A5E8C // VMRGOW V10, V11, V10 VMRGOW V10, V11, V10
WORD $0x13AC6F8C // VMRGEW V12, V13, V29 VMRGEW V12, V13, V29
WORD $0x13CE7F8C // VMRGEW V14, V15, V30 VMRGEW V14, V15, V30
XXPERMDI VS40, VS42, $0, VS41 XXPERMDI VS40, VS42, $0, VS41
XXPERMDI VS40, VS42, $3, VS43 XXPERMDI VS40, VS42, $3, VS43
XXPERMDI VS59, VS60, $0, VS40 XXPERMDI VS59, VS60, $0, VS40
XXPERMDI VS59, VS60, $3, VS42 XXPERMDI VS59, VS60, $3, VS42
WORD $0x118C6E8C // VMRGOW V12, V13, V12 VMRGOW V12, V13, V12
WORD $0x11CE7E8C // VMRGOW V14, V15, V14 VMRGOW V14, V15, V14
VSPLTISW $4, V27 VSPLTISW $4, V27
VADDUWM V26, V27, V26 VADDUWM V26, V27, V26
@@ -431,7 +425,7 @@ tail_vsx:
ADD $-1, R11, R12 ADD $-1, R11, R12
ADD $-1, INP ADD $-1, INP
ADD $-1, OUT ADD $-1, OUT
PCALIGN $16
looptail_vsx: looptail_vsx:
// Copying the result to OUT // Copying the result to OUT
// in bytes. // in bytes.
@@ -439,7 +433,7 @@ looptail_vsx:
MOVBZU 1(INP), TMP MOVBZU 1(INP), TMP
XOR KEY, TMP, KEY XOR KEY, TMP, KEY
MOVBU KEY, 1(OUT) MOVBU KEY, 1(OUT)
BC 16, LT, looptail_vsx BDNZ looptail_vsx
// Clear the stack values // Clear the stack values
STXVW4X VS48, (R11)(R0) STXVW4X VS48, (R11)(R0)

View File

@@ -19,15 +19,14 @@
#define POLY1305_MUL(h0, h1, h2, r0, r1, t0, t1, t2, t3, t4, t5) \ #define POLY1305_MUL(h0, h1, h2, r0, r1, t0, t1, t2, t3, t4, t5) \
MULLD r0, h0, t0; \ MULLD r0, h0, t0; \
MULLD r0, h1, t4; \
MULHDU r0, h0, t1; \ MULHDU r0, h0, t1; \
MULLD r0, h1, t4; \
MULHDU r0, h1, t5; \ MULHDU r0, h1, t5; \
ADDC t4, t1, t1; \ ADDC t4, t1, t1; \
MULLD r0, h2, t2; \ MULLD r0, h2, t2; \
ADDZE t5; \
MULHDU r1, h0, t4; \ MULHDU r1, h0, t4; \
MULLD r1, h0, h0; \ MULLD r1, h0, h0; \
ADD t5, t2, t2; \ ADDE t5, t2, t2; \
ADDC h0, t1, t1; \ ADDC h0, t1, t1; \
MULLD h2, r1, t3; \ MULLD h2, r1, t3; \
ADDZE t4, h0; \ ADDZE t4, h0; \
@@ -37,13 +36,11 @@
ADDE t5, t3, t3; \ ADDE t5, t3, t3; \
ADDC h0, t2, t2; \ ADDC h0, t2, t2; \
MOVD $-4, t4; \ MOVD $-4, t4; \
MOVD t0, h0; \
MOVD t1, h1; \
ADDZE t3; \ ADDZE t3; \
ANDCC $3, t2, h2; \ RLDICL $0, t2, $62, h2; \
AND t2, t4, t0; \ AND t2, t4, h0; \
ADDC t0, h0, h0; \ ADDC t0, h0, h0; \
ADDE t3, h1, h1; \ ADDE t3, t1, h1; \
SLD $62, t3, t4; \ SLD $62, t3, t4; \
SRD $2, t2; \ SRD $2, t2; \
ADDZE h2; \ ADDZE h2; \
@@ -75,6 +72,7 @@ TEXT ·update(SB), $0-32
loop: loop:
POLY1305_ADD(R4, R8, R9, R10, R20, R21, R22) POLY1305_ADD(R4, R8, R9, R10, R20, R21, R22)
PCALIGN $16
multiply: multiply:
POLY1305_MUL(R8, R9, R10, R11, R12, R16, R17, R18, R14, R20, R21) POLY1305_MUL(R8, R9, R10, R11, R12, R16, R17, R18, R14, R20, R21)
ADD $-16, R5 ADD $-16, R5

View File

@@ -426,6 +426,35 @@ func (l ServerAuthError) Error() string {
return "[" + strings.Join(errs, ", ") + "]" return "[" + strings.Join(errs, ", ") + "]"
} }
// ServerAuthCallbacks defines server-side authentication callbacks.
type ServerAuthCallbacks struct {
// PasswordCallback behaves like [ServerConfig.PasswordCallback].
PasswordCallback func(conn ConnMetadata, password []byte) (*Permissions, error)
// PublicKeyCallback behaves like [ServerConfig.PublicKeyCallback].
PublicKeyCallback func(conn ConnMetadata, key PublicKey) (*Permissions, error)
// KeyboardInteractiveCallback behaves like [ServerConfig.KeyboardInteractiveCallback].
KeyboardInteractiveCallback func(conn ConnMetadata, client KeyboardInteractiveChallenge) (*Permissions, error)
// GSSAPIWithMICConfig behaves like [ServerConfig.GSSAPIWithMICConfig].
GSSAPIWithMICConfig *GSSAPIWithMICConfig
}
// PartialSuccessError can be returned by any of the [ServerConfig]
// authentication callbacks to indicate to the client that authentication has
// partially succeeded, but further steps are required.
type PartialSuccessError struct {
// Next defines the authentication callbacks to apply to further steps. The
// available methods communicated to the client are based on the non-nil
// ServerAuthCallbacks fields.
Next ServerAuthCallbacks
}
func (p *PartialSuccessError) Error() string {
return "ssh: authenticated with partial success"
}
// ErrNoAuth is the error value returned if no // ErrNoAuth is the error value returned if no
// authentication method has been passed yet. This happens as a normal // authentication method has been passed yet. This happens as a normal
// part of the authentication loop, since the client first tries // part of the authentication loop, since the client first tries
@@ -439,8 +468,18 @@ func (s *connection) serverAuthenticate(config *ServerConfig) (*Permissions, err
var perms *Permissions var perms *Permissions
authFailures := 0 authFailures := 0
noneAuthCount := 0
var authErrs []error var authErrs []error
var displayedBanner bool var displayedBanner bool
partialSuccessReturned := false
// Set the initial authentication callbacks from the config. They can be
// changed if a PartialSuccessError is returned.
authConfig := ServerAuthCallbacks{
PasswordCallback: config.PasswordCallback,
PublicKeyCallback: config.PublicKeyCallback,
KeyboardInteractiveCallback: config.KeyboardInteractiveCallback,
GSSAPIWithMICConfig: config.GSSAPIWithMICConfig,
}
userAuthLoop: userAuthLoop:
for { for {
@@ -471,6 +510,11 @@ userAuthLoop:
return nil, errors.New("ssh: client attempted to negotiate for unknown service: " + userAuthReq.Service) return nil, errors.New("ssh: client attempted to negotiate for unknown service: " + userAuthReq.Service)
} }
if s.user != userAuthReq.User && partialSuccessReturned {
return nil, fmt.Errorf("ssh: client changed the user after a partial success authentication, previous user %q, current user %q",
s.user, userAuthReq.User)
}
s.user = userAuthReq.User s.user = userAuthReq.User
if !displayedBanner && config.BannerCallback != nil { if !displayedBanner && config.BannerCallback != nil {
@@ -491,20 +535,18 @@ userAuthLoop:
switch userAuthReq.Method { switch userAuthReq.Method {
case "none": case "none":
if config.NoClientAuth { noneAuthCount++
// We don't allow none authentication after a partial success
// response.
if config.NoClientAuth && !partialSuccessReturned {
if config.NoClientAuthCallback != nil { if config.NoClientAuthCallback != nil {
perms, authErr = config.NoClientAuthCallback(s) perms, authErr = config.NoClientAuthCallback(s)
} else { } else {
authErr = nil authErr = nil
} }
} }
// allow initial attempt of 'none' without penalty
if authFailures == 0 {
authFailures--
}
case "password": case "password":
if config.PasswordCallback == nil { if authConfig.PasswordCallback == nil {
authErr = errors.New("ssh: password auth not configured") authErr = errors.New("ssh: password auth not configured")
break break
} }
@@ -518,17 +560,17 @@ userAuthLoop:
return nil, parseError(msgUserAuthRequest) return nil, parseError(msgUserAuthRequest)
} }
perms, authErr = config.PasswordCallback(s, password) perms, authErr = authConfig.PasswordCallback(s, password)
case "keyboard-interactive": case "keyboard-interactive":
if config.KeyboardInteractiveCallback == nil { if authConfig.KeyboardInteractiveCallback == nil {
authErr = errors.New("ssh: keyboard-interactive auth not configured") authErr = errors.New("ssh: keyboard-interactive auth not configured")
break break
} }
prompter := &sshClientKeyboardInteractive{s} prompter := &sshClientKeyboardInteractive{s}
perms, authErr = config.KeyboardInteractiveCallback(s, prompter.Challenge) perms, authErr = authConfig.KeyboardInteractiveCallback(s, prompter.Challenge)
case "publickey": case "publickey":
if config.PublicKeyCallback == nil { if authConfig.PublicKeyCallback == nil {
authErr = errors.New("ssh: publickey auth not configured") authErr = errors.New("ssh: publickey auth not configured")
break break
} }
@@ -562,11 +604,18 @@ userAuthLoop:
if !ok { if !ok {
candidate.user = s.user candidate.user = s.user
candidate.pubKeyData = pubKeyData candidate.pubKeyData = pubKeyData
candidate.perms, candidate.result = config.PublicKeyCallback(s, pubKey) candidate.perms, candidate.result = authConfig.PublicKeyCallback(s, pubKey)
if candidate.result == nil && candidate.perms != nil && candidate.perms.CriticalOptions != nil && candidate.perms.CriticalOptions[sourceAddressCriticalOption] != "" { _, isPartialSuccessError := candidate.result.(*PartialSuccessError)
candidate.result = checkSourceAddress(
if (candidate.result == nil || isPartialSuccessError) &&
candidate.perms != nil &&
candidate.perms.CriticalOptions != nil &&
candidate.perms.CriticalOptions[sourceAddressCriticalOption] != "" {
if err := checkSourceAddress(
s.RemoteAddr(), s.RemoteAddr(),
candidate.perms.CriticalOptions[sourceAddressCriticalOption]) candidate.perms.CriticalOptions[sourceAddressCriticalOption]); err != nil {
candidate.result = err
}
} }
cache.add(candidate) cache.add(candidate)
} }
@@ -578,8 +627,8 @@ userAuthLoop:
if len(payload) > 0 { if len(payload) > 0 {
return nil, parseError(msgUserAuthRequest) return nil, parseError(msgUserAuthRequest)
} }
_, isPartialSuccessError := candidate.result.(*PartialSuccessError)
if candidate.result == nil { if candidate.result == nil || isPartialSuccessError {
okMsg := userAuthPubKeyOkMsg{ okMsg := userAuthPubKeyOkMsg{
Algo: algo, Algo: algo,
PubKey: pubKeyData, PubKey: pubKeyData,
@@ -629,11 +678,11 @@ userAuthLoop:
perms = candidate.perms perms = candidate.perms
} }
case "gssapi-with-mic": case "gssapi-with-mic":
if config.GSSAPIWithMICConfig == nil { if authConfig.GSSAPIWithMICConfig == nil {
authErr = errors.New("ssh: gssapi-with-mic auth not configured") authErr = errors.New("ssh: gssapi-with-mic auth not configured")
break break
} }
gssapiConfig := config.GSSAPIWithMICConfig gssapiConfig := authConfig.GSSAPIWithMICConfig
userAuthRequestGSSAPI, err := parseGSSAPIPayload(userAuthReq.Payload) userAuthRequestGSSAPI, err := parseGSSAPIPayload(userAuthReq.Payload)
if err != nil { if err != nil {
return nil, parseError(msgUserAuthRequest) return nil, parseError(msgUserAuthRequest)
@@ -689,7 +738,28 @@ userAuthLoop:
break userAuthLoop break userAuthLoop
} }
var failureMsg userAuthFailureMsg
if partialSuccess, ok := authErr.(*PartialSuccessError); ok {
// After a partial success error we don't allow changing the user
// name and execute the NoClientAuthCallback.
partialSuccessReturned = true
// In case a partial success is returned, the server may send
// a new set of authentication methods.
authConfig = partialSuccess.Next
// Reset pubkey cache, as the new PublicKeyCallback might
// accept a different set of public keys.
cache = pubKeyCache{}
// Send back a partial success message to the user.
failureMsg.PartialSuccess = true
} else {
// Allow initial attempt of 'none' without penalty.
if authFailures > 0 || userAuthReq.Method != "none" || noneAuthCount != 1 {
authFailures++ authFailures++
}
if config.MaxAuthTries > 0 && authFailures >= config.MaxAuthTries { if config.MaxAuthTries > 0 && authFailures >= config.MaxAuthTries {
// If we have hit the max attempts, don't bother sending the // If we have hit the max attempts, don't bother sending the
// final SSH_MSG_USERAUTH_FAILURE message, since there are // final SSH_MSG_USERAUTH_FAILURE message, since there are
@@ -709,29 +779,29 @@ userAuthLoop:
// disconnect, should we only send that message.) // disconnect, should we only send that message.)
// //
// Either way, OpenSSH disconnects immediately after the last // Either way, OpenSSH disconnects immediately after the last
// failed authnetication attempt, and given they are typically // failed authentication attempt, and given they are typically
// considered the golden implementation it seems reasonable // considered the golden implementation it seems reasonable
// to match that behavior. // to match that behavior.
continue continue
} }
}
var failureMsg userAuthFailureMsg if authConfig.PasswordCallback != nil {
if config.PasswordCallback != nil {
failureMsg.Methods = append(failureMsg.Methods, "password") failureMsg.Methods = append(failureMsg.Methods, "password")
} }
if config.PublicKeyCallback != nil { if authConfig.PublicKeyCallback != nil {
failureMsg.Methods = append(failureMsg.Methods, "publickey") failureMsg.Methods = append(failureMsg.Methods, "publickey")
} }
if config.KeyboardInteractiveCallback != nil { if authConfig.KeyboardInteractiveCallback != nil {
failureMsg.Methods = append(failureMsg.Methods, "keyboard-interactive") failureMsg.Methods = append(failureMsg.Methods, "keyboard-interactive")
} }
if config.GSSAPIWithMICConfig != nil && config.GSSAPIWithMICConfig.Server != nil && if authConfig.GSSAPIWithMICConfig != nil && authConfig.GSSAPIWithMICConfig.Server != nil &&
config.GSSAPIWithMICConfig.AllowLogin != nil { authConfig.GSSAPIWithMICConfig.AllowLogin != nil {
failureMsg.Methods = append(failureMsg.Methods, "gssapi-with-mic") failureMsg.Methods = append(failureMsg.Methods, "gssapi-with-mic")
} }
if len(failureMsg.Methods) == 0 { if len(failureMsg.Methods) == 0 {
return nil, errors.New("ssh: no authentication methods configured but NoClientAuth is also false") return nil, errors.New("ssh: no authentication methods available")
} }
if err := s.transport.writePacket(Marshal(&failureMsg)); err != nil { if err := s.transport.writePacket(Marshal(&failureMsg)); err != nil {

View File

@@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style // Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
//go:build (aix || darwin || dragonfly || freebsd || linux || netbsd || openbsd || solaris || zos) && go1.9 //go:build aix || darwin || dragonfly || freebsd || linux || netbsd || openbsd || solaris || zos
package unix package unix

View File

@@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style // Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
//go:build aix || darwin || dragonfly || freebsd || openbsd || solaris //go:build aix || darwin || dragonfly || freebsd || openbsd || solaris || zos
package unix package unix

View File

@@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style // Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
//go:build darwin && go1.12 //go:build darwin
package unix package unix

View File

@@ -13,6 +13,7 @@
package unix package unix
import ( import (
"errors"
"sync" "sync"
"unsafe" "unsafe"
) )
@@ -169,25 +170,26 @@ func Getfsstat(buf []Statfs_t, flags int) (n int, err error) {
func Uname(uname *Utsname) error { func Uname(uname *Utsname) error {
mib := []_C_int{CTL_KERN, KERN_OSTYPE} mib := []_C_int{CTL_KERN, KERN_OSTYPE}
n := unsafe.Sizeof(uname.Sysname) n := unsafe.Sizeof(uname.Sysname)
if err := sysctl(mib, &uname.Sysname[0], &n, nil, 0); err != nil { // Suppress ENOMEM errors to be compatible with the C library __xuname() implementation.
if err := sysctl(mib, &uname.Sysname[0], &n, nil, 0); err != nil && !errors.Is(err, ENOMEM) {
return err return err
} }
mib = []_C_int{CTL_KERN, KERN_HOSTNAME} mib = []_C_int{CTL_KERN, KERN_HOSTNAME}
n = unsafe.Sizeof(uname.Nodename) n = unsafe.Sizeof(uname.Nodename)
if err := sysctl(mib, &uname.Nodename[0], &n, nil, 0); err != nil { if err := sysctl(mib, &uname.Nodename[0], &n, nil, 0); err != nil && !errors.Is(err, ENOMEM) {
return err return err
} }
mib = []_C_int{CTL_KERN, KERN_OSRELEASE} mib = []_C_int{CTL_KERN, KERN_OSRELEASE}
n = unsafe.Sizeof(uname.Release) n = unsafe.Sizeof(uname.Release)
if err := sysctl(mib, &uname.Release[0], &n, nil, 0); err != nil { if err := sysctl(mib, &uname.Release[0], &n, nil, 0); err != nil && !errors.Is(err, ENOMEM) {
return err return err
} }
mib = []_C_int{CTL_KERN, KERN_VERSION} mib = []_C_int{CTL_KERN, KERN_VERSION}
n = unsafe.Sizeof(uname.Version) n = unsafe.Sizeof(uname.Version)
if err := sysctl(mib, &uname.Version[0], &n, nil, 0); err != nil { if err := sysctl(mib, &uname.Version[0], &n, nil, 0); err != nil && !errors.Is(err, ENOMEM) {
return err return err
} }
@@ -205,7 +207,7 @@ func Uname(uname *Utsname) error {
mib = []_C_int{CTL_HW, HW_MACHINE} mib = []_C_int{CTL_HW, HW_MACHINE}
n = unsafe.Sizeof(uname.Machine) n = unsafe.Sizeof(uname.Machine)
if err := sysctl(mib, &uname.Machine[0], &n, nil, 0); err != nil { if err := sysctl(mib, &uname.Machine[0], &n, nil, 0); err != nil && !errors.Is(err, ENOMEM) {
return err return err
} }

View File

@@ -1849,6 +1849,105 @@ func Dup2(oldfd, newfd int) error {
//sys Fsmount(fd int, flags int, mountAttrs int) (fsfd int, err error) //sys Fsmount(fd int, flags int, mountAttrs int) (fsfd int, err error)
//sys Fsopen(fsName string, flags int) (fd int, err error) //sys Fsopen(fsName string, flags int) (fd int, err error)
//sys Fspick(dirfd int, pathName string, flags int) (fd int, err error) //sys Fspick(dirfd int, pathName string, flags int) (fd int, err error)
//sys fsconfig(fd int, cmd uint, key *byte, value *byte, aux int) (err error)
func fsconfigCommon(fd int, cmd uint, key string, value *byte, aux int) (err error) {
var keyp *byte
if keyp, err = BytePtrFromString(key); err != nil {
return
}
return fsconfig(fd, cmd, keyp, value, aux)
}
// FsconfigSetFlag is equivalent to fsconfig(2) called
// with cmd == FSCONFIG_SET_FLAG.
//
// fd is the filesystem context to act upon.
// key the parameter key to set.
func FsconfigSetFlag(fd int, key string) (err error) {
return fsconfigCommon(fd, FSCONFIG_SET_FLAG, key, nil, 0)
}
// FsconfigSetString is equivalent to fsconfig(2) called
// with cmd == FSCONFIG_SET_STRING.
//
// fd is the filesystem context to act upon.
// key the parameter key to set.
// value is the parameter value to set.
func FsconfigSetString(fd int, key string, value string) (err error) {
var valuep *byte
if valuep, err = BytePtrFromString(value); err != nil {
return
}
return fsconfigCommon(fd, FSCONFIG_SET_STRING, key, valuep, 0)
}
// FsconfigSetBinary is equivalent to fsconfig(2) called
// with cmd == FSCONFIG_SET_BINARY.
//
// fd is the filesystem context to act upon.
// key the parameter key to set.
// value is the parameter value to set.
func FsconfigSetBinary(fd int, key string, value []byte) (err error) {
if len(value) == 0 {
return EINVAL
}
return fsconfigCommon(fd, FSCONFIG_SET_BINARY, key, &value[0], len(value))
}
// FsconfigSetPath is equivalent to fsconfig(2) called
// with cmd == FSCONFIG_SET_PATH.
//
// fd is the filesystem context to act upon.
// key the parameter key to set.
// path is a non-empty path for specified key.
// atfd is a file descriptor at which to start lookup from or AT_FDCWD.
func FsconfigSetPath(fd int, key string, path string, atfd int) (err error) {
var valuep *byte
if valuep, err = BytePtrFromString(path); err != nil {
return
}
return fsconfigCommon(fd, FSCONFIG_SET_PATH, key, valuep, atfd)
}
// FsconfigSetPathEmpty is equivalent to fsconfig(2) called
// with cmd == FSCONFIG_SET_PATH_EMPTY. The same as
// FconfigSetPath but with AT_PATH_EMPTY implied.
func FsconfigSetPathEmpty(fd int, key string, path string, atfd int) (err error) {
var valuep *byte
if valuep, err = BytePtrFromString(path); err != nil {
return
}
return fsconfigCommon(fd, FSCONFIG_SET_PATH_EMPTY, key, valuep, atfd)
}
// FsconfigSetFd is equivalent to fsconfig(2) called
// with cmd == FSCONFIG_SET_FD.
//
// fd is the filesystem context to act upon.
// key the parameter key to set.
// value is a file descriptor to be assigned to specified key.
func FsconfigSetFd(fd int, key string, value int) (err error) {
return fsconfigCommon(fd, FSCONFIG_SET_FD, key, nil, value)
}
// FsconfigCreate is equivalent to fsconfig(2) called
// with cmd == FSCONFIG_CMD_CREATE.
//
// fd is the filesystem context to act upon.
func FsconfigCreate(fd int) (err error) {
return fsconfig(fd, FSCONFIG_CMD_CREATE, nil, nil, 0)
}
// FsconfigReconfigure is equivalent to fsconfig(2) called
// with cmd == FSCONFIG_CMD_RECONFIGURE.
//
// fd is the filesystem context to act upon.
func FsconfigReconfigure(fd int) (err error) {
return fsconfig(fd, FSCONFIG_CMD_RECONFIGURE, nil, nil, 0)
}
//sys Getdents(fd int, buf []byte) (n int, err error) = SYS_GETDENTS64 //sys Getdents(fd int, buf []byte) (n int, err error) = SYS_GETDENTS64
//sysnb Getpgid(pid int) (pgid int, err error) //sysnb Getpgid(pid int) (pgid int, err error)

View File

@@ -1520,6 +1520,14 @@ func (m *mmapper) Munmap(data []byte) (err error) {
return nil return nil
} }
func Mmap(fd int, offset int64, length int, prot int, flags int) (data []byte, err error) {
return mapper.Mmap(fd, offset, length, prot, flags)
}
func Munmap(b []byte) (err error) {
return mapper.Munmap(b)
}
func Read(fd int, p []byte) (n int, err error) { func Read(fd int, p []byte) (n int, err error) {
n, err = read(fd, p) n, err = read(fd, p)
if raceenabled { if raceenabled {

View File

@@ -906,6 +906,16 @@ func Fspick(dirfd int, pathName string, flags int) (fd int, err error) {
// THIS FILE IS GENERATED BY THE COMMAND AT THE TOP; DO NOT EDIT // THIS FILE IS GENERATED BY THE COMMAND AT THE TOP; DO NOT EDIT
func fsconfig(fd int, cmd uint, key *byte, value *byte, aux int) (err error) {
_, _, e1 := Syscall6(SYS_FSCONFIG, uintptr(fd), uintptr(cmd), uintptr(unsafe.Pointer(key)), uintptr(unsafe.Pointer(value)), uintptr(aux), 0)
if e1 != 0 {
err = errnoErr(e1)
}
return
}
// THIS FILE IS GENERATED BY THE COMMAND AT THE TOP; DO NOT EDIT
func Getdents(fd int, buf []byte) (n int, err error) { func Getdents(fd int, buf []byte) (n int, err error) {
var _p0 unsafe.Pointer var _p0 unsafe.Pointer
if len(buf) > 0 { if len(buf) > 0 {

View File

@@ -836,6 +836,15 @@ const (
FSPICK_EMPTY_PATH = 0x8 FSPICK_EMPTY_PATH = 0x8
FSMOUNT_CLOEXEC = 0x1 FSMOUNT_CLOEXEC = 0x1
FSCONFIG_SET_FLAG = 0x0
FSCONFIG_SET_STRING = 0x1
FSCONFIG_SET_BINARY = 0x2
FSCONFIG_SET_PATH = 0x3
FSCONFIG_SET_PATH_EMPTY = 0x4
FSCONFIG_SET_FD = 0x5
FSCONFIG_CMD_CREATE = 0x6
FSCONFIG_CMD_RECONFIGURE = 0x7
) )
type OpenHow struct { type OpenHow struct {
@@ -1550,6 +1559,7 @@ const (
IFLA_DEVLINK_PORT = 0x3e IFLA_DEVLINK_PORT = 0x3e
IFLA_GSO_IPV4_MAX_SIZE = 0x3f IFLA_GSO_IPV4_MAX_SIZE = 0x3f
IFLA_GRO_IPV4_MAX_SIZE = 0x40 IFLA_GRO_IPV4_MAX_SIZE = 0x40
IFLA_DPLL_PIN = 0x41
IFLA_PROTO_DOWN_REASON_UNSPEC = 0x0 IFLA_PROTO_DOWN_REASON_UNSPEC = 0x0
IFLA_PROTO_DOWN_REASON_MASK = 0x1 IFLA_PROTO_DOWN_REASON_MASK = 0x1
IFLA_PROTO_DOWN_REASON_VALUE = 0x2 IFLA_PROTO_DOWN_REASON_VALUE = 0x2
@@ -1565,6 +1575,7 @@ const (
IFLA_INET6_ICMP6STATS = 0x6 IFLA_INET6_ICMP6STATS = 0x6
IFLA_INET6_TOKEN = 0x7 IFLA_INET6_TOKEN = 0x7
IFLA_INET6_ADDR_GEN_MODE = 0x8 IFLA_INET6_ADDR_GEN_MODE = 0x8
IFLA_INET6_RA_MTU = 0x9
IFLA_BR_UNSPEC = 0x0 IFLA_BR_UNSPEC = 0x0
IFLA_BR_FORWARD_DELAY = 0x1 IFLA_BR_FORWARD_DELAY = 0x1
IFLA_BR_HELLO_TIME = 0x2 IFLA_BR_HELLO_TIME = 0x2
@@ -1612,6 +1623,9 @@ const (
IFLA_BR_MCAST_MLD_VERSION = 0x2c IFLA_BR_MCAST_MLD_VERSION = 0x2c
IFLA_BR_VLAN_STATS_PER_PORT = 0x2d IFLA_BR_VLAN_STATS_PER_PORT = 0x2d
IFLA_BR_MULTI_BOOLOPT = 0x2e IFLA_BR_MULTI_BOOLOPT = 0x2e
IFLA_BR_MCAST_QUERIER_STATE = 0x2f
IFLA_BR_FDB_N_LEARNED = 0x30
IFLA_BR_FDB_MAX_LEARNED = 0x31
IFLA_BRPORT_UNSPEC = 0x0 IFLA_BRPORT_UNSPEC = 0x0
IFLA_BRPORT_STATE = 0x1 IFLA_BRPORT_STATE = 0x1
IFLA_BRPORT_PRIORITY = 0x2 IFLA_BRPORT_PRIORITY = 0x2
@@ -1649,6 +1663,14 @@ const (
IFLA_BRPORT_BACKUP_PORT = 0x22 IFLA_BRPORT_BACKUP_PORT = 0x22
IFLA_BRPORT_MRP_RING_OPEN = 0x23 IFLA_BRPORT_MRP_RING_OPEN = 0x23
IFLA_BRPORT_MRP_IN_OPEN = 0x24 IFLA_BRPORT_MRP_IN_OPEN = 0x24
IFLA_BRPORT_MCAST_EHT_HOSTS_LIMIT = 0x25
IFLA_BRPORT_MCAST_EHT_HOSTS_CNT = 0x26
IFLA_BRPORT_LOCKED = 0x27
IFLA_BRPORT_MAB = 0x28
IFLA_BRPORT_MCAST_N_GROUPS = 0x29
IFLA_BRPORT_MCAST_MAX_GROUPS = 0x2a
IFLA_BRPORT_NEIGH_VLAN_SUPPRESS = 0x2b
IFLA_BRPORT_BACKUP_NHID = 0x2c
IFLA_INFO_UNSPEC = 0x0 IFLA_INFO_UNSPEC = 0x0
IFLA_INFO_KIND = 0x1 IFLA_INFO_KIND = 0x1
IFLA_INFO_DATA = 0x2 IFLA_INFO_DATA = 0x2
@@ -1670,6 +1692,9 @@ const (
IFLA_MACVLAN_MACADDR = 0x4 IFLA_MACVLAN_MACADDR = 0x4
IFLA_MACVLAN_MACADDR_DATA = 0x5 IFLA_MACVLAN_MACADDR_DATA = 0x5
IFLA_MACVLAN_MACADDR_COUNT = 0x6 IFLA_MACVLAN_MACADDR_COUNT = 0x6
IFLA_MACVLAN_BC_QUEUE_LEN = 0x7
IFLA_MACVLAN_BC_QUEUE_LEN_USED = 0x8
IFLA_MACVLAN_BC_CUTOFF = 0x9
IFLA_VRF_UNSPEC = 0x0 IFLA_VRF_UNSPEC = 0x0
IFLA_VRF_TABLE = 0x1 IFLA_VRF_TABLE = 0x1
IFLA_VRF_PORT_UNSPEC = 0x0 IFLA_VRF_PORT_UNSPEC = 0x0
@@ -1693,9 +1718,22 @@ const (
IFLA_XFRM_UNSPEC = 0x0 IFLA_XFRM_UNSPEC = 0x0
IFLA_XFRM_LINK = 0x1 IFLA_XFRM_LINK = 0x1
IFLA_XFRM_IF_ID = 0x2 IFLA_XFRM_IF_ID = 0x2
IFLA_XFRM_COLLECT_METADATA = 0x3
IFLA_IPVLAN_UNSPEC = 0x0 IFLA_IPVLAN_UNSPEC = 0x0
IFLA_IPVLAN_MODE = 0x1 IFLA_IPVLAN_MODE = 0x1
IFLA_IPVLAN_FLAGS = 0x2 IFLA_IPVLAN_FLAGS = 0x2
NETKIT_NEXT = -0x1
NETKIT_PASS = 0x0
NETKIT_DROP = 0x2
NETKIT_REDIRECT = 0x7
NETKIT_L2 = 0x0
NETKIT_L3 = 0x1
IFLA_NETKIT_UNSPEC = 0x0
IFLA_NETKIT_PEER_INFO = 0x1
IFLA_NETKIT_PRIMARY = 0x2
IFLA_NETKIT_POLICY = 0x3
IFLA_NETKIT_PEER_POLICY = 0x4
IFLA_NETKIT_MODE = 0x5
IFLA_VXLAN_UNSPEC = 0x0 IFLA_VXLAN_UNSPEC = 0x0
IFLA_VXLAN_ID = 0x1 IFLA_VXLAN_ID = 0x1
IFLA_VXLAN_GROUP = 0x2 IFLA_VXLAN_GROUP = 0x2
@@ -1726,6 +1764,8 @@ const (
IFLA_VXLAN_GPE = 0x1b IFLA_VXLAN_GPE = 0x1b
IFLA_VXLAN_TTL_INHERIT = 0x1c IFLA_VXLAN_TTL_INHERIT = 0x1c
IFLA_VXLAN_DF = 0x1d IFLA_VXLAN_DF = 0x1d
IFLA_VXLAN_VNIFILTER = 0x1e
IFLA_VXLAN_LOCALBYPASS = 0x1f
IFLA_GENEVE_UNSPEC = 0x0 IFLA_GENEVE_UNSPEC = 0x0
IFLA_GENEVE_ID = 0x1 IFLA_GENEVE_ID = 0x1
IFLA_GENEVE_REMOTE = 0x2 IFLA_GENEVE_REMOTE = 0x2
@@ -1740,6 +1780,7 @@ const (
IFLA_GENEVE_LABEL = 0xb IFLA_GENEVE_LABEL = 0xb
IFLA_GENEVE_TTL_INHERIT = 0xc IFLA_GENEVE_TTL_INHERIT = 0xc
IFLA_GENEVE_DF = 0xd IFLA_GENEVE_DF = 0xd
IFLA_GENEVE_INNER_PROTO_INHERIT = 0xe
IFLA_BAREUDP_UNSPEC = 0x0 IFLA_BAREUDP_UNSPEC = 0x0
IFLA_BAREUDP_PORT = 0x1 IFLA_BAREUDP_PORT = 0x1
IFLA_BAREUDP_ETHERTYPE = 0x2 IFLA_BAREUDP_ETHERTYPE = 0x2
@@ -1752,6 +1793,8 @@ const (
IFLA_GTP_FD1 = 0x2 IFLA_GTP_FD1 = 0x2
IFLA_GTP_PDP_HASHSIZE = 0x3 IFLA_GTP_PDP_HASHSIZE = 0x3
IFLA_GTP_ROLE = 0x4 IFLA_GTP_ROLE = 0x4
IFLA_GTP_CREATE_SOCKETS = 0x5
IFLA_GTP_RESTART_COUNT = 0x6
IFLA_BOND_UNSPEC = 0x0 IFLA_BOND_UNSPEC = 0x0
IFLA_BOND_MODE = 0x1 IFLA_BOND_MODE = 0x1
IFLA_BOND_ACTIVE_SLAVE = 0x2 IFLA_BOND_ACTIVE_SLAVE = 0x2
@@ -1781,6 +1824,9 @@ const (
IFLA_BOND_AD_ACTOR_SYSTEM = 0x1a IFLA_BOND_AD_ACTOR_SYSTEM = 0x1a
IFLA_BOND_TLB_DYNAMIC_LB = 0x1b IFLA_BOND_TLB_DYNAMIC_LB = 0x1b
IFLA_BOND_PEER_NOTIF_DELAY = 0x1c IFLA_BOND_PEER_NOTIF_DELAY = 0x1c
IFLA_BOND_AD_LACP_ACTIVE = 0x1d
IFLA_BOND_MISSED_MAX = 0x1e
IFLA_BOND_NS_IP6_TARGET = 0x1f
IFLA_BOND_AD_INFO_UNSPEC = 0x0 IFLA_BOND_AD_INFO_UNSPEC = 0x0
IFLA_BOND_AD_INFO_AGGREGATOR = 0x1 IFLA_BOND_AD_INFO_AGGREGATOR = 0x1
IFLA_BOND_AD_INFO_NUM_PORTS = 0x2 IFLA_BOND_AD_INFO_NUM_PORTS = 0x2
@@ -1796,6 +1842,7 @@ const (
IFLA_BOND_SLAVE_AD_AGGREGATOR_ID = 0x6 IFLA_BOND_SLAVE_AD_AGGREGATOR_ID = 0x6
IFLA_BOND_SLAVE_AD_ACTOR_OPER_PORT_STATE = 0x7 IFLA_BOND_SLAVE_AD_ACTOR_OPER_PORT_STATE = 0x7
IFLA_BOND_SLAVE_AD_PARTNER_OPER_PORT_STATE = 0x8 IFLA_BOND_SLAVE_AD_PARTNER_OPER_PORT_STATE = 0x8
IFLA_BOND_SLAVE_PRIO = 0x9
IFLA_VF_INFO_UNSPEC = 0x0 IFLA_VF_INFO_UNSPEC = 0x0
IFLA_VF_INFO = 0x1 IFLA_VF_INFO = 0x1
IFLA_VF_UNSPEC = 0x0 IFLA_VF_UNSPEC = 0x0
@@ -1854,8 +1901,16 @@ const (
IFLA_STATS_LINK_XSTATS_SLAVE = 0x3 IFLA_STATS_LINK_XSTATS_SLAVE = 0x3
IFLA_STATS_LINK_OFFLOAD_XSTATS = 0x4 IFLA_STATS_LINK_OFFLOAD_XSTATS = 0x4
IFLA_STATS_AF_SPEC = 0x5 IFLA_STATS_AF_SPEC = 0x5
IFLA_STATS_GETSET_UNSPEC = 0x0
IFLA_STATS_GET_FILTERS = 0x1
IFLA_STATS_SET_OFFLOAD_XSTATS_L3_STATS = 0x2
IFLA_OFFLOAD_XSTATS_UNSPEC = 0x0 IFLA_OFFLOAD_XSTATS_UNSPEC = 0x0
IFLA_OFFLOAD_XSTATS_CPU_HIT = 0x1 IFLA_OFFLOAD_XSTATS_CPU_HIT = 0x1
IFLA_OFFLOAD_XSTATS_HW_S_INFO = 0x2
IFLA_OFFLOAD_XSTATS_L3_STATS = 0x3
IFLA_OFFLOAD_XSTATS_HW_S_INFO_UNSPEC = 0x0
IFLA_OFFLOAD_XSTATS_HW_S_INFO_REQUEST = 0x1
IFLA_OFFLOAD_XSTATS_HW_S_INFO_USED = 0x2
IFLA_XDP_UNSPEC = 0x0 IFLA_XDP_UNSPEC = 0x0
IFLA_XDP_FD = 0x1 IFLA_XDP_FD = 0x1
IFLA_XDP_ATTACHED = 0x2 IFLA_XDP_ATTACHED = 0x2
@@ -1885,6 +1940,11 @@ const (
IFLA_RMNET_UNSPEC = 0x0 IFLA_RMNET_UNSPEC = 0x0
IFLA_RMNET_MUX_ID = 0x1 IFLA_RMNET_MUX_ID = 0x1
IFLA_RMNET_FLAGS = 0x2 IFLA_RMNET_FLAGS = 0x2
IFLA_MCTP_UNSPEC = 0x0
IFLA_MCTP_NET = 0x1
IFLA_DSA_UNSPEC = 0x0
IFLA_DSA_CONDUIT = 0x1
IFLA_DSA_MASTER = 0x1
) )
const ( const (

View File

@@ -165,6 +165,7 @@ func NewCallbackCDecl(fn interface{}) uintptr {
//sys CreateFile(name *uint16, access uint32, mode uint32, sa *SecurityAttributes, createmode uint32, attrs uint32, templatefile Handle) (handle Handle, err error) [failretval==InvalidHandle] = CreateFileW //sys CreateFile(name *uint16, access uint32, mode uint32, sa *SecurityAttributes, createmode uint32, attrs uint32, templatefile Handle) (handle Handle, err error) [failretval==InvalidHandle] = CreateFileW
//sys CreateNamedPipe(name *uint16, flags uint32, pipeMode uint32, maxInstances uint32, outSize uint32, inSize uint32, defaultTimeout uint32, sa *SecurityAttributes) (handle Handle, err error) [failretval==InvalidHandle] = CreateNamedPipeW //sys CreateNamedPipe(name *uint16, flags uint32, pipeMode uint32, maxInstances uint32, outSize uint32, inSize uint32, defaultTimeout uint32, sa *SecurityAttributes) (handle Handle, err error) [failretval==InvalidHandle] = CreateNamedPipeW
//sys ConnectNamedPipe(pipe Handle, overlapped *Overlapped) (err error) //sys ConnectNamedPipe(pipe Handle, overlapped *Overlapped) (err error)
//sys DisconnectNamedPipe(pipe Handle) (err error)
//sys GetNamedPipeInfo(pipe Handle, flags *uint32, outSize *uint32, inSize *uint32, maxInstances *uint32) (err error) //sys GetNamedPipeInfo(pipe Handle, flags *uint32, outSize *uint32, inSize *uint32, maxInstances *uint32) (err error)
//sys GetNamedPipeHandleState(pipe Handle, state *uint32, curInstances *uint32, maxCollectionCount *uint32, collectDataTimeout *uint32, userName *uint16, maxUserNameSize uint32) (err error) = GetNamedPipeHandleStateW //sys GetNamedPipeHandleState(pipe Handle, state *uint32, curInstances *uint32, maxCollectionCount *uint32, collectDataTimeout *uint32, userName *uint16, maxUserNameSize uint32) (err error) = GetNamedPipeHandleStateW
//sys SetNamedPipeHandleState(pipe Handle, state *uint32, maxCollectionCount *uint32, collectDataTimeout *uint32) (err error) = SetNamedPipeHandleState //sys SetNamedPipeHandleState(pipe Handle, state *uint32, maxCollectionCount *uint32, collectDataTimeout *uint32) (err error) = SetNamedPipeHandleState
@@ -348,8 +349,19 @@ func NewCallbackCDecl(fn interface{}) uintptr {
//sys SetProcessPriorityBoost(process Handle, disable bool) (err error) = kernel32.SetProcessPriorityBoost //sys SetProcessPriorityBoost(process Handle, disable bool) (err error) = kernel32.SetProcessPriorityBoost
//sys GetProcessWorkingSetSizeEx(hProcess Handle, lpMinimumWorkingSetSize *uintptr, lpMaximumWorkingSetSize *uintptr, flags *uint32) //sys GetProcessWorkingSetSizeEx(hProcess Handle, lpMinimumWorkingSetSize *uintptr, lpMaximumWorkingSetSize *uintptr, flags *uint32)
//sys SetProcessWorkingSetSizeEx(hProcess Handle, dwMinimumWorkingSetSize uintptr, dwMaximumWorkingSetSize uintptr, flags uint32) (err error) //sys SetProcessWorkingSetSizeEx(hProcess Handle, dwMinimumWorkingSetSize uintptr, dwMaximumWorkingSetSize uintptr, flags uint32) (err error)
//sys ClearCommBreak(handle Handle) (err error)
//sys ClearCommError(handle Handle, lpErrors *uint32, lpStat *ComStat) (err error)
//sys EscapeCommFunction(handle Handle, dwFunc uint32) (err error)
//sys GetCommState(handle Handle, lpDCB *DCB) (err error)
//sys GetCommModemStatus(handle Handle, lpModemStat *uint32) (err error)
//sys GetCommTimeouts(handle Handle, timeouts *CommTimeouts) (err error) //sys GetCommTimeouts(handle Handle, timeouts *CommTimeouts) (err error)
//sys PurgeComm(handle Handle, dwFlags uint32) (err error)
//sys SetCommBreak(handle Handle) (err error)
//sys SetCommMask(handle Handle, dwEvtMask uint32) (err error)
//sys SetCommState(handle Handle, lpDCB *DCB) (err error)
//sys SetCommTimeouts(handle Handle, timeouts *CommTimeouts) (err error) //sys SetCommTimeouts(handle Handle, timeouts *CommTimeouts) (err error)
//sys SetupComm(handle Handle, dwInQueue uint32, dwOutQueue uint32) (err error)
//sys WaitCommEvent(handle Handle, lpEvtMask *uint32, lpOverlapped *Overlapped) (err error)
//sys GetActiveProcessorCount(groupNumber uint16) (ret uint32) //sys GetActiveProcessorCount(groupNumber uint16) (ret uint32)
//sys GetMaximumProcessorCount(groupNumber uint16) (ret uint32) //sys GetMaximumProcessorCount(groupNumber uint16) (ret uint32)
//sys EnumWindows(enumFunc uintptr, param unsafe.Pointer) (err error) = user32.EnumWindows //sys EnumWindows(enumFunc uintptr, param unsafe.Pointer) (err error) = user32.EnumWindows
@@ -1834,3 +1846,73 @@ func ResizePseudoConsole(pconsole Handle, size Coord) error {
// accept arguments that can be casted to uintptr, and Coord can't. // accept arguments that can be casted to uintptr, and Coord can't.
return resizePseudoConsole(pconsole, *((*uint32)(unsafe.Pointer(&size)))) return resizePseudoConsole(pconsole, *((*uint32)(unsafe.Pointer(&size))))
} }
// DCB constants. See https://learn.microsoft.com/en-us/windows/win32/api/winbase/ns-winbase-dcb.
const (
CBR_110 = 110
CBR_300 = 300
CBR_600 = 600
CBR_1200 = 1200
CBR_2400 = 2400
CBR_4800 = 4800
CBR_9600 = 9600
CBR_14400 = 14400
CBR_19200 = 19200
CBR_38400 = 38400
CBR_57600 = 57600
CBR_115200 = 115200
CBR_128000 = 128000
CBR_256000 = 256000
DTR_CONTROL_DISABLE = 0x00000000
DTR_CONTROL_ENABLE = 0x00000010
DTR_CONTROL_HANDSHAKE = 0x00000020
RTS_CONTROL_DISABLE = 0x00000000
RTS_CONTROL_ENABLE = 0x00001000
RTS_CONTROL_HANDSHAKE = 0x00002000
RTS_CONTROL_TOGGLE = 0x00003000
NOPARITY = 0
ODDPARITY = 1
EVENPARITY = 2
MARKPARITY = 3
SPACEPARITY = 4
ONESTOPBIT = 0
ONE5STOPBITS = 1
TWOSTOPBITS = 2
)
// EscapeCommFunction constants. See https://learn.microsoft.com/en-us/windows/win32/api/winbase/nf-winbase-escapecommfunction.
const (
SETXOFF = 1
SETXON = 2
SETRTS = 3
CLRRTS = 4
SETDTR = 5
CLRDTR = 6
SETBREAK = 8
CLRBREAK = 9
)
// PurgeComm constants. See https://learn.microsoft.com/en-us/windows/win32/api/winbase/nf-winbase-purgecomm.
const (
PURGE_TXABORT = 0x0001
PURGE_RXABORT = 0x0002
PURGE_TXCLEAR = 0x0004
PURGE_RXCLEAR = 0x0008
)
// SetCommMask constants. See https://learn.microsoft.com/en-us/windows/win32/api/winbase/nf-winbase-setcommmask.
const (
EV_RXCHAR = 0x0001
EV_RXFLAG = 0x0002
EV_TXEMPTY = 0x0004
EV_CTS = 0x0008
EV_DSR = 0x0010
EV_RLSD = 0x0020
EV_BREAK = 0x0040
EV_ERR = 0x0080
EV_RING = 0x0100
)

View File

@@ -3380,3 +3380,27 @@ type BLOB struct {
Size uint32 Size uint32
BlobData *byte BlobData *byte
} }
type ComStat struct {
Flags uint32
CBInQue uint32
CBOutQue uint32
}
type DCB struct {
DCBlength uint32
BaudRate uint32
Flags uint32
wReserved uint16
XonLim uint16
XoffLim uint16
ByteSize uint8
Parity uint8
StopBits uint8
XonChar byte
XoffChar byte
ErrorChar byte
EofChar byte
EvtChar byte
wReserved1 uint16
}

View File

@@ -188,6 +188,8 @@ var (
procAssignProcessToJobObject = modkernel32.NewProc("AssignProcessToJobObject") procAssignProcessToJobObject = modkernel32.NewProc("AssignProcessToJobObject")
procCancelIo = modkernel32.NewProc("CancelIo") procCancelIo = modkernel32.NewProc("CancelIo")
procCancelIoEx = modkernel32.NewProc("CancelIoEx") procCancelIoEx = modkernel32.NewProc("CancelIoEx")
procClearCommBreak = modkernel32.NewProc("ClearCommBreak")
procClearCommError = modkernel32.NewProc("ClearCommError")
procCloseHandle = modkernel32.NewProc("CloseHandle") procCloseHandle = modkernel32.NewProc("CloseHandle")
procClosePseudoConsole = modkernel32.NewProc("ClosePseudoConsole") procClosePseudoConsole = modkernel32.NewProc("ClosePseudoConsole")
procConnectNamedPipe = modkernel32.NewProc("ConnectNamedPipe") procConnectNamedPipe = modkernel32.NewProc("ConnectNamedPipe")
@@ -212,7 +214,9 @@ var (
procDeleteProcThreadAttributeList = modkernel32.NewProc("DeleteProcThreadAttributeList") procDeleteProcThreadAttributeList = modkernel32.NewProc("DeleteProcThreadAttributeList")
procDeleteVolumeMountPointW = modkernel32.NewProc("DeleteVolumeMountPointW") procDeleteVolumeMountPointW = modkernel32.NewProc("DeleteVolumeMountPointW")
procDeviceIoControl = modkernel32.NewProc("DeviceIoControl") procDeviceIoControl = modkernel32.NewProc("DeviceIoControl")
procDisconnectNamedPipe = modkernel32.NewProc("DisconnectNamedPipe")
procDuplicateHandle = modkernel32.NewProc("DuplicateHandle") procDuplicateHandle = modkernel32.NewProc("DuplicateHandle")
procEscapeCommFunction = modkernel32.NewProc("EscapeCommFunction")
procExitProcess = modkernel32.NewProc("ExitProcess") procExitProcess = modkernel32.NewProc("ExitProcess")
procExpandEnvironmentStringsW = modkernel32.NewProc("ExpandEnvironmentStringsW") procExpandEnvironmentStringsW = modkernel32.NewProc("ExpandEnvironmentStringsW")
procFindClose = modkernel32.NewProc("FindClose") procFindClose = modkernel32.NewProc("FindClose")
@@ -236,6 +240,8 @@ var (
procGenerateConsoleCtrlEvent = modkernel32.NewProc("GenerateConsoleCtrlEvent") procGenerateConsoleCtrlEvent = modkernel32.NewProc("GenerateConsoleCtrlEvent")
procGetACP = modkernel32.NewProc("GetACP") procGetACP = modkernel32.NewProc("GetACP")
procGetActiveProcessorCount = modkernel32.NewProc("GetActiveProcessorCount") procGetActiveProcessorCount = modkernel32.NewProc("GetActiveProcessorCount")
procGetCommModemStatus = modkernel32.NewProc("GetCommModemStatus")
procGetCommState = modkernel32.NewProc("GetCommState")
procGetCommTimeouts = modkernel32.NewProc("GetCommTimeouts") procGetCommTimeouts = modkernel32.NewProc("GetCommTimeouts")
procGetCommandLineW = modkernel32.NewProc("GetCommandLineW") procGetCommandLineW = modkernel32.NewProc("GetCommandLineW")
procGetComputerNameExW = modkernel32.NewProc("GetComputerNameExW") procGetComputerNameExW = modkernel32.NewProc("GetComputerNameExW")
@@ -322,6 +328,7 @@ var (
procProcess32NextW = modkernel32.NewProc("Process32NextW") procProcess32NextW = modkernel32.NewProc("Process32NextW")
procProcessIdToSessionId = modkernel32.NewProc("ProcessIdToSessionId") procProcessIdToSessionId = modkernel32.NewProc("ProcessIdToSessionId")
procPulseEvent = modkernel32.NewProc("PulseEvent") procPulseEvent = modkernel32.NewProc("PulseEvent")
procPurgeComm = modkernel32.NewProc("PurgeComm")
procQueryDosDeviceW = modkernel32.NewProc("QueryDosDeviceW") procQueryDosDeviceW = modkernel32.NewProc("QueryDosDeviceW")
procQueryFullProcessImageNameW = modkernel32.NewProc("QueryFullProcessImageNameW") procQueryFullProcessImageNameW = modkernel32.NewProc("QueryFullProcessImageNameW")
procQueryInformationJobObject = modkernel32.NewProc("QueryInformationJobObject") procQueryInformationJobObject = modkernel32.NewProc("QueryInformationJobObject")
@@ -335,6 +342,9 @@ var (
procResetEvent = modkernel32.NewProc("ResetEvent") procResetEvent = modkernel32.NewProc("ResetEvent")
procResizePseudoConsole = modkernel32.NewProc("ResizePseudoConsole") procResizePseudoConsole = modkernel32.NewProc("ResizePseudoConsole")
procResumeThread = modkernel32.NewProc("ResumeThread") procResumeThread = modkernel32.NewProc("ResumeThread")
procSetCommBreak = modkernel32.NewProc("SetCommBreak")
procSetCommMask = modkernel32.NewProc("SetCommMask")
procSetCommState = modkernel32.NewProc("SetCommState")
procSetCommTimeouts = modkernel32.NewProc("SetCommTimeouts") procSetCommTimeouts = modkernel32.NewProc("SetCommTimeouts")
procSetConsoleCursorPosition = modkernel32.NewProc("SetConsoleCursorPosition") procSetConsoleCursorPosition = modkernel32.NewProc("SetConsoleCursorPosition")
procSetConsoleMode = modkernel32.NewProc("SetConsoleMode") procSetConsoleMode = modkernel32.NewProc("SetConsoleMode")
@@ -342,7 +352,6 @@ var (
procSetDefaultDllDirectories = modkernel32.NewProc("SetDefaultDllDirectories") procSetDefaultDllDirectories = modkernel32.NewProc("SetDefaultDllDirectories")
procSetDllDirectoryW = modkernel32.NewProc("SetDllDirectoryW") procSetDllDirectoryW = modkernel32.NewProc("SetDllDirectoryW")
procSetEndOfFile = modkernel32.NewProc("SetEndOfFile") procSetEndOfFile = modkernel32.NewProc("SetEndOfFile")
procSetFileValidData = modkernel32.NewProc("SetFileValidData")
procSetEnvironmentVariableW = modkernel32.NewProc("SetEnvironmentVariableW") procSetEnvironmentVariableW = modkernel32.NewProc("SetEnvironmentVariableW")
procSetErrorMode = modkernel32.NewProc("SetErrorMode") procSetErrorMode = modkernel32.NewProc("SetErrorMode")
procSetEvent = modkernel32.NewProc("SetEvent") procSetEvent = modkernel32.NewProc("SetEvent")
@@ -351,6 +360,7 @@ var (
procSetFileInformationByHandle = modkernel32.NewProc("SetFileInformationByHandle") procSetFileInformationByHandle = modkernel32.NewProc("SetFileInformationByHandle")
procSetFilePointer = modkernel32.NewProc("SetFilePointer") procSetFilePointer = modkernel32.NewProc("SetFilePointer")
procSetFileTime = modkernel32.NewProc("SetFileTime") procSetFileTime = modkernel32.NewProc("SetFileTime")
procSetFileValidData = modkernel32.NewProc("SetFileValidData")
procSetHandleInformation = modkernel32.NewProc("SetHandleInformation") procSetHandleInformation = modkernel32.NewProc("SetHandleInformation")
procSetInformationJobObject = modkernel32.NewProc("SetInformationJobObject") procSetInformationJobObject = modkernel32.NewProc("SetInformationJobObject")
procSetNamedPipeHandleState = modkernel32.NewProc("SetNamedPipeHandleState") procSetNamedPipeHandleState = modkernel32.NewProc("SetNamedPipeHandleState")
@@ -361,6 +371,7 @@ var (
procSetStdHandle = modkernel32.NewProc("SetStdHandle") procSetStdHandle = modkernel32.NewProc("SetStdHandle")
procSetVolumeLabelW = modkernel32.NewProc("SetVolumeLabelW") procSetVolumeLabelW = modkernel32.NewProc("SetVolumeLabelW")
procSetVolumeMountPointW = modkernel32.NewProc("SetVolumeMountPointW") procSetVolumeMountPointW = modkernel32.NewProc("SetVolumeMountPointW")
procSetupComm = modkernel32.NewProc("SetupComm")
procSizeofResource = modkernel32.NewProc("SizeofResource") procSizeofResource = modkernel32.NewProc("SizeofResource")
procSleepEx = modkernel32.NewProc("SleepEx") procSleepEx = modkernel32.NewProc("SleepEx")
procTerminateJobObject = modkernel32.NewProc("TerminateJobObject") procTerminateJobObject = modkernel32.NewProc("TerminateJobObject")
@@ -379,6 +390,7 @@ var (
procVirtualQueryEx = modkernel32.NewProc("VirtualQueryEx") procVirtualQueryEx = modkernel32.NewProc("VirtualQueryEx")
procVirtualUnlock = modkernel32.NewProc("VirtualUnlock") procVirtualUnlock = modkernel32.NewProc("VirtualUnlock")
procWTSGetActiveConsoleSessionId = modkernel32.NewProc("WTSGetActiveConsoleSessionId") procWTSGetActiveConsoleSessionId = modkernel32.NewProc("WTSGetActiveConsoleSessionId")
procWaitCommEvent = modkernel32.NewProc("WaitCommEvent")
procWaitForMultipleObjects = modkernel32.NewProc("WaitForMultipleObjects") procWaitForMultipleObjects = modkernel32.NewProc("WaitForMultipleObjects")
procWaitForSingleObject = modkernel32.NewProc("WaitForSingleObject") procWaitForSingleObject = modkernel32.NewProc("WaitForSingleObject")
procWriteConsoleW = modkernel32.NewProc("WriteConsoleW") procWriteConsoleW = modkernel32.NewProc("WriteConsoleW")
@@ -1641,6 +1653,22 @@ func CancelIoEx(s Handle, o *Overlapped) (err error) {
return return
} }
func ClearCommBreak(handle Handle) (err error) {
r1, _, e1 := syscall.Syscall(procClearCommBreak.Addr(), 1, uintptr(handle), 0, 0)
if r1 == 0 {
err = errnoErr(e1)
}
return
}
func ClearCommError(handle Handle, lpErrors *uint32, lpStat *ComStat) (err error) {
r1, _, e1 := syscall.Syscall(procClearCommError.Addr(), 3, uintptr(handle), uintptr(unsafe.Pointer(lpErrors)), uintptr(unsafe.Pointer(lpStat)))
if r1 == 0 {
err = errnoErr(e1)
}
return
}
func CloseHandle(handle Handle) (err error) { func CloseHandle(handle Handle) (err error) {
r1, _, e1 := syscall.Syscall(procCloseHandle.Addr(), 1, uintptr(handle), 0, 0) r1, _, e1 := syscall.Syscall(procCloseHandle.Addr(), 1, uintptr(handle), 0, 0)
if r1 == 0 { if r1 == 0 {
@@ -1845,6 +1873,14 @@ func DeviceIoControl(handle Handle, ioControlCode uint32, inBuffer *byte, inBuff
return return
} }
func DisconnectNamedPipe(pipe Handle) (err error) {
r1, _, e1 := syscall.Syscall(procDisconnectNamedPipe.Addr(), 1, uintptr(pipe), 0, 0)
if r1 == 0 {
err = errnoErr(e1)
}
return
}
func DuplicateHandle(hSourceProcessHandle Handle, hSourceHandle Handle, hTargetProcessHandle Handle, lpTargetHandle *Handle, dwDesiredAccess uint32, bInheritHandle bool, dwOptions uint32) (err error) { func DuplicateHandle(hSourceProcessHandle Handle, hSourceHandle Handle, hTargetProcessHandle Handle, lpTargetHandle *Handle, dwDesiredAccess uint32, bInheritHandle bool, dwOptions uint32) (err error) {
var _p0 uint32 var _p0 uint32
if bInheritHandle { if bInheritHandle {
@@ -1857,6 +1893,14 @@ func DuplicateHandle(hSourceProcessHandle Handle, hSourceHandle Handle, hTargetP
return return
} }
func EscapeCommFunction(handle Handle, dwFunc uint32) (err error) {
r1, _, e1 := syscall.Syscall(procEscapeCommFunction.Addr(), 2, uintptr(handle), uintptr(dwFunc), 0)
if r1 == 0 {
err = errnoErr(e1)
}
return
}
func ExitProcess(exitcode uint32) { func ExitProcess(exitcode uint32) {
syscall.Syscall(procExitProcess.Addr(), 1, uintptr(exitcode), 0, 0) syscall.Syscall(procExitProcess.Addr(), 1, uintptr(exitcode), 0, 0)
return return
@@ -2058,6 +2102,22 @@ func GetActiveProcessorCount(groupNumber uint16) (ret uint32) {
return return
} }
func GetCommModemStatus(handle Handle, lpModemStat *uint32) (err error) {
r1, _, e1 := syscall.Syscall(procGetCommModemStatus.Addr(), 2, uintptr(handle), uintptr(unsafe.Pointer(lpModemStat)), 0)
if r1 == 0 {
err = errnoErr(e1)
}
return
}
func GetCommState(handle Handle, lpDCB *DCB) (err error) {
r1, _, e1 := syscall.Syscall(procGetCommState.Addr(), 2, uintptr(handle), uintptr(unsafe.Pointer(lpDCB)), 0)
if r1 == 0 {
err = errnoErr(e1)
}
return
}
func GetCommTimeouts(handle Handle, timeouts *CommTimeouts) (err error) { func GetCommTimeouts(handle Handle, timeouts *CommTimeouts) (err error) {
r1, _, e1 := syscall.Syscall(procGetCommTimeouts.Addr(), 2, uintptr(handle), uintptr(unsafe.Pointer(timeouts)), 0) r1, _, e1 := syscall.Syscall(procGetCommTimeouts.Addr(), 2, uintptr(handle), uintptr(unsafe.Pointer(timeouts)), 0)
if r1 == 0 { if r1 == 0 {
@@ -2810,6 +2870,14 @@ func PulseEvent(event Handle) (err error) {
return return
} }
func PurgeComm(handle Handle, dwFlags uint32) (err error) {
r1, _, e1 := syscall.Syscall(procPurgeComm.Addr(), 2, uintptr(handle), uintptr(dwFlags), 0)
if r1 == 0 {
err = errnoErr(e1)
}
return
}
func QueryDosDevice(deviceName *uint16, targetPath *uint16, max uint32) (n uint32, err error) { func QueryDosDevice(deviceName *uint16, targetPath *uint16, max uint32) (n uint32, err error) {
r0, _, e1 := syscall.Syscall(procQueryDosDeviceW.Addr(), 3, uintptr(unsafe.Pointer(deviceName)), uintptr(unsafe.Pointer(targetPath)), uintptr(max)) r0, _, e1 := syscall.Syscall(procQueryDosDeviceW.Addr(), 3, uintptr(unsafe.Pointer(deviceName)), uintptr(unsafe.Pointer(targetPath)), uintptr(max))
n = uint32(r0) n = uint32(r0)
@@ -2924,6 +2992,30 @@ func ResumeThread(thread Handle) (ret uint32, err error) {
return return
} }
func SetCommBreak(handle Handle) (err error) {
r1, _, e1 := syscall.Syscall(procSetCommBreak.Addr(), 1, uintptr(handle), 0, 0)
if r1 == 0 {
err = errnoErr(e1)
}
return
}
func SetCommMask(handle Handle, dwEvtMask uint32) (err error) {
r1, _, e1 := syscall.Syscall(procSetCommMask.Addr(), 2, uintptr(handle), uintptr(dwEvtMask), 0)
if r1 == 0 {
err = errnoErr(e1)
}
return
}
func SetCommState(handle Handle, lpDCB *DCB) (err error) {
r1, _, e1 := syscall.Syscall(procSetCommState.Addr(), 2, uintptr(handle), uintptr(unsafe.Pointer(lpDCB)), 0)
if r1 == 0 {
err = errnoErr(e1)
}
return
}
func SetCommTimeouts(handle Handle, timeouts *CommTimeouts) (err error) { func SetCommTimeouts(handle Handle, timeouts *CommTimeouts) (err error) {
r1, _, e1 := syscall.Syscall(procSetCommTimeouts.Addr(), 2, uintptr(handle), uintptr(unsafe.Pointer(timeouts)), 0) r1, _, e1 := syscall.Syscall(procSetCommTimeouts.Addr(), 2, uintptr(handle), uintptr(unsafe.Pointer(timeouts)), 0)
if r1 == 0 { if r1 == 0 {
@@ -2989,14 +3081,6 @@ func SetEndOfFile(handle Handle) (err error) {
return return
} }
func SetFileValidData(handle Handle, validDataLength int64) (err error) {
r1, _, e1 := syscall.Syscall(procSetFileValidData.Addr(), 2, uintptr(handle), uintptr(validDataLength), 0)
if r1 == 0 {
err = errnoErr(e1)
}
return
}
func SetEnvironmentVariable(name *uint16, value *uint16) (err error) { func SetEnvironmentVariable(name *uint16, value *uint16) (err error) {
r1, _, e1 := syscall.Syscall(procSetEnvironmentVariableW.Addr(), 2, uintptr(unsafe.Pointer(name)), uintptr(unsafe.Pointer(value)), 0) r1, _, e1 := syscall.Syscall(procSetEnvironmentVariableW.Addr(), 2, uintptr(unsafe.Pointer(name)), uintptr(unsafe.Pointer(value)), 0)
if r1 == 0 { if r1 == 0 {
@@ -3060,6 +3144,14 @@ func SetFileTime(handle Handle, ctime *Filetime, atime *Filetime, wtime *Filetim
return return
} }
func SetFileValidData(handle Handle, validDataLength int64) (err error) {
r1, _, e1 := syscall.Syscall(procSetFileValidData.Addr(), 2, uintptr(handle), uintptr(validDataLength), 0)
if r1 == 0 {
err = errnoErr(e1)
}
return
}
func SetHandleInformation(handle Handle, mask uint32, flags uint32) (err error) { func SetHandleInformation(handle Handle, mask uint32, flags uint32) (err error) {
r1, _, e1 := syscall.Syscall(procSetHandleInformation.Addr(), 3, uintptr(handle), uintptr(mask), uintptr(flags)) r1, _, e1 := syscall.Syscall(procSetHandleInformation.Addr(), 3, uintptr(handle), uintptr(mask), uintptr(flags))
if r1 == 0 { if r1 == 0 {
@@ -3145,6 +3237,14 @@ func SetVolumeMountPoint(volumeMountPoint *uint16, volumeName *uint16) (err erro
return return
} }
func SetupComm(handle Handle, dwInQueue uint32, dwOutQueue uint32) (err error) {
r1, _, e1 := syscall.Syscall(procSetupComm.Addr(), 3, uintptr(handle), uintptr(dwInQueue), uintptr(dwOutQueue))
if r1 == 0 {
err = errnoErr(e1)
}
return
}
func SizeofResource(module Handle, resInfo Handle) (size uint32, err error) { func SizeofResource(module Handle, resInfo Handle) (size uint32, err error) {
r0, _, e1 := syscall.Syscall(procSizeofResource.Addr(), 2, uintptr(module), uintptr(resInfo), 0) r0, _, e1 := syscall.Syscall(procSizeofResource.Addr(), 2, uintptr(module), uintptr(resInfo), 0)
size = uint32(r0) size = uint32(r0)
@@ -3291,6 +3391,14 @@ func WTSGetActiveConsoleSessionId() (sessionID uint32) {
return return
} }
func WaitCommEvent(handle Handle, lpEvtMask *uint32, lpOverlapped *Overlapped) (err error) {
r1, _, e1 := syscall.Syscall(procWaitCommEvent.Addr(), 3, uintptr(handle), uintptr(unsafe.Pointer(lpEvtMask)), uintptr(unsafe.Pointer(lpOverlapped)))
if r1 == 0 {
err = errnoErr(e1)
}
return
}
func waitForMultipleObjects(count uint32, handles uintptr, waitAll bool, waitMilliseconds uint32) (event uint32, err error) { func waitForMultipleObjects(count uint32, handles uintptr, waitAll bool, waitMilliseconds uint32) (event uint32, err error) {
var _p0 uint32 var _p0 uint32
if waitAll { if waitAll {

View File

@@ -1,374 +0,0 @@
Mozilla Public License Version 2.0
==================================
1. Definitions
--------------
1.1. "Contributor"
means each individual or legal entity that creates, contributes to
the creation of, or owns Covered Software.
1.2. "Contributor Version"
means the combination of the Contributions of others (if any) used
by a Contributor and that particular Contributor's Contribution.
1.3. "Contribution"
means Covered Software of a particular Contributor.
1.4. "Covered Software"
means Source Code Form to which the initial Contributor has attached
the notice in Exhibit A, the Executable Form of such Source Code
Form, and Modifications of such Source Code Form, in each case
including portions thereof.
1.5. "Incompatible With Secondary Licenses"
means
(a) that the initial Contributor has attached the notice described
in Exhibit B to the Covered Software; or
(b) that the Covered Software was made available under the terms of
version 1.1 or earlier of the License, but not also under the
terms of a Secondary License.
1.6. "Executable Form"
means any form of the work other than Source Code Form.
1.7. "Larger Work"
means a work that combines Covered Software with other material, in
a separate file or files, that is not Covered Software.
1.8. "License"
means this document.
1.9. "Licensable"
means having the right to grant, to the maximum extent possible,
whether at the time of the initial grant or subsequently, any and
all of the rights conveyed by this License.
1.10. "Modifications"
means any of the following:
(a) any file in Source Code Form that results from an addition to,
deletion from, or modification of the contents of Covered
Software; or
(b) any new file in Source Code Form that contains any Covered
Software.
1.11. "Patent Claims" of a Contributor
means any patent claim(s), including without limitation, method,
process, and apparatus claims, in any patent Licensable by such
Contributor that would be infringed, but for the grant of the
License, by the making, using, selling, offering for sale, having
made, import, or transfer of either its Contributions or its
Contributor Version.
1.12. "Secondary License"
means either the GNU General Public License, Version 2.0, the GNU
Lesser General Public License, Version 2.1, the GNU Affero General
Public License, Version 3.0, or any later versions of those
licenses.
1.13. "Source Code Form"
means the form of the work preferred for making modifications.
1.14. "You" (or "Your")
means an individual or a legal entity exercising rights under this
License. For legal entities, "You" includes any entity that
controls, is controlled by, or is under common control with You. For
purposes of this definition, "control" means (a) the power, direct
or indirect, to cause the direction or management of such entity,
whether by contract or otherwise, or (b) ownership of more than
fifty percent (50%) of the outstanding shares or beneficial
ownership of such entity.
2. License Grants and Conditions
--------------------------------
2.1. Grants
Each Contributor hereby grants You a world-wide, royalty-free,
non-exclusive license:
(a) under intellectual property rights (other than patent or trademark)
Licensable by such Contributor to use, reproduce, make available,
modify, display, perform, distribute, and otherwise exploit its
Contributions, either on an unmodified basis, with Modifications, or
as part of a Larger Work; and
(b) under Patent Claims of such Contributor to make, use, sell, offer
for sale, have made, import, and otherwise transfer either its
Contributions or its Contributor Version.
2.2. Effective Date
The licenses granted in Section 2.1 with respect to any Contribution
become effective for each Contribution on the date the Contributor first
distributes such Contribution.
2.3. Limitations on Grant Scope
The licenses granted in this Section 2 are the only rights granted under
this License. No additional rights or licenses will be implied from the
distribution or licensing of Covered Software under this License.
Notwithstanding Section 2.1(b) above, no patent license is granted by a
Contributor:
(a) for any code that a Contributor has removed from Covered Software;
or
(b) for infringements caused by: (i) Your and any other third party's
modifications of Covered Software, or (ii) the combination of its
Contributions with other software (except as part of its Contributor
Version); or
(c) under Patent Claims infringed by Covered Software in the absence of
its Contributions.
This License does not grant any rights in the trademarks, service marks,
or logos of any Contributor (except as may be necessary to comply with
the notice requirements in Section 3.4).
2.4. Subsequent Licenses
No Contributor makes additional grants as a result of Your choice to
distribute the Covered Software under a subsequent version of this
License (see Section 10.2) or under the terms of a Secondary License (if
permitted under the terms of Section 3.3).
2.5. Representation
Each Contributor represents that the Contributor believes its
Contributions are its original creation(s) or it has sufficient rights
to grant the rights to its Contributions conveyed by this License.
2.6. Fair Use
This License is not intended to limit any rights You have under
applicable copyright doctrines of fair use, fair dealing, or other
equivalents.
2.7. Conditions
Sections 3.1, 3.2, 3.3, and 3.4 are conditions of the licenses granted
in Section 2.1.
3. Responsibilities
-------------------
3.1. Distribution of Source Form
All distribution of Covered Software in Source Code Form, including any
Modifications that You create or to which You contribute, must be under
the terms of this License. You must inform recipients that the Source
Code Form of the Covered Software is governed by the terms of this
License, and how they can obtain a copy of this License. You may not
attempt to alter or restrict the recipients' rights in the Source Code
Form.
3.2. Distribution of Executable Form
If You distribute Covered Software in Executable Form then:
(a) such Covered Software must also be made available in Source Code
Form, as described in Section 3.1, and You must inform recipients of
the Executable Form how they can obtain a copy of such Source Code
Form by reasonable means in a timely manner, at a charge no more
than the cost of distribution to the recipient; and
(b) You may distribute such Executable Form under the terms of this
License, or sublicense it under different terms, provided that the
license for the Executable Form does not attempt to limit or alter
the recipients' rights in the Source Code Form under this License.
3.3. Distribution of a Larger Work
You may create and distribute a Larger Work under terms of Your choice,
provided that You also comply with the requirements of this License for
the Covered Software. If the Larger Work is a combination of Covered
Software with a work governed by one or more Secondary Licenses, and the
Covered Software is not Incompatible With Secondary Licenses, this
License permits You to additionally distribute such Covered Software
under the terms of such Secondary License(s), so that the recipient of
the Larger Work may, at their option, further distribute the Covered
Software under the terms of either this License or such Secondary
License(s).
3.4. Notices
You may not remove or alter the substance of any license notices
(including copyright notices, patent notices, disclaimers of warranty,
or limitations of liability) contained within the Source Code Form of
the Covered Software, except that You may alter any license notices to
the extent required to remedy known factual inaccuracies.
3.5. Application of Additional Terms
You may choose to offer, and to charge a fee for, warranty, support,
indemnity or liability obligations to one or more recipients of Covered
Software. However, You may do so only on Your own behalf, and not on
behalf of any Contributor. You must make it absolutely clear that any
such warranty, support, indemnity, or liability obligation is offered by
You alone, and You hereby agree to indemnify every Contributor for any
liability incurred by such Contributor as a result of warranty, support,
indemnity or liability terms You offer. You may include additional
disclaimers of warranty and limitations of liability specific to any
jurisdiction.
4. Inability to Comply Due to Statute or Regulation
---------------------------------------------------
If it is impossible for You to comply with any of the terms of this
License with respect to some or all of the Covered Software due to
statute, judicial order, or regulation then You must: (a) comply with
the terms of this License to the maximum extent possible; and (b)
describe the limitations and the code they affect. Such description must
be placed in a text file included with all distributions of the Covered
Software under this License. Except to the extent prohibited by statute
or regulation, such description must be sufficiently detailed for a
recipient of ordinary skill to be able to understand it.
5. Termination
--------------
5.1. The rights granted under this License will terminate automatically
if You fail to comply with any of its terms. However, if You become
compliant, then the rights granted under this License from a particular
Contributor are reinstated (a) provisionally, unless and until such
Contributor explicitly and finally terminates Your grants, and (b) on an
ongoing basis, if such Contributor fails to notify You of the
non-compliance by some reasonable means prior to 60 days after You have
come back into compliance. Moreover, Your grants from a particular
Contributor are reinstated on an ongoing basis if such Contributor
notifies You of the non-compliance by some reasonable means, this is the
first time You have received notice of non-compliance with this License
from such Contributor, and You become compliant prior to 30 days after
Your receipt of the notice.
5.2. If You initiate litigation against any entity by asserting a patent
infringement claim (excluding declaratory judgment actions,
counter-claims, and cross-claims) alleging that a Contributor Version
directly or indirectly infringes any patent, then the rights granted to
You by any and all Contributors for the Covered Software under Section
2.1 of this License shall terminate.
5.3. In the event of termination under Sections 5.1 or 5.2 above, all
end user license agreements (excluding distributors and resellers) which
have been validly granted by You or Your distributors under this License
prior to termination shall survive termination.
************************************************************************
* *
* 6. Disclaimer of Warranty *
* ------------------------- *
* *
* Covered Software is provided under this License on an "as is" *
* basis, without warranty of any kind, either expressed, implied, or *
* statutory, including, without limitation, warranties that the *
* Covered Software is free of defects, merchantable, fit for a *
* particular purpose or non-infringing. The entire risk as to the *
* quality and performance of the Covered Software is with You. *
* Should any Covered Software prove defective in any respect, You *
* (not any Contributor) assume the cost of any necessary servicing, *
* repair, or correction. This disclaimer of warranty constitutes an *
* essential part of this License. No use of any Covered Software is *
* authorized under this License except under this disclaimer. *
* *
************************************************************************
************************************************************************
* *
* 7. Limitation of Liability *
* -------------------------- *
* *
* Under no circumstances and under no legal theory, whether tort *
* (including negligence), contract, or otherwise, shall any *
* Contributor, or anyone who distributes Covered Software as *
* permitted above, be liable to You for any direct, indirect, *
* special, incidental, or consequential damages of any character *
* including, without limitation, damages for lost profits, loss of *
* goodwill, work stoppage, computer failure or malfunction, or any *
* and all other commercial damages or losses, even if such party *
* shall have been informed of the possibility of such damages. This *
* limitation of liability shall not apply to liability for death or *
* personal injury resulting from such party's negligence to the *
* extent applicable law prohibits such limitation. Some *
* jurisdictions do not allow the exclusion or limitation of *
* incidental or consequential damages, so this exclusion and *
* limitation may not apply to You. *
* *
************************************************************************
8. Litigation
-------------
Any litigation relating to this License may be brought only in the
courts of a jurisdiction where the defendant maintains its principal
place of business and such litigation shall be governed by laws of that
jurisdiction, without reference to its conflict-of-law provisions.
Nothing in this Section shall prevent a party's ability to bring
cross-claims or counter-claims.
9. Miscellaneous
----------------
This License represents the complete agreement concerning the subject
matter hereof. If any provision of this License is held to be
unenforceable, such provision shall be reformed only to the extent
necessary to make it enforceable. Any law or regulation which provides
that the language of a contract shall be construed against the drafter
shall not be used to construe this License against a Contributor.
10. Versions of the License
---------------------------
10.1. New Versions
Mozilla Foundation is the license steward. Except as provided in Section
10.3, no one other than the license steward has the right to modify or
publish new versions of this License. Each version will be given a
distinguishing version number.
10.2. Effect of New Versions
You may distribute the Covered Software under the terms of the version
of the License under which You originally received the Covered Software,
or under the terms of any subsequent version published by the license
steward.
10.3. Modified Versions
If you create software not governed by this License, and you want to
create a new license for such software, you may create and use a
modified version of this License if you rename the license and remove
any references to the name of the license steward (except to note that
such modified license differs from this License).
10.4. Distributing Source Code Form that is Incompatible With Secondary
Licenses
If You choose to distribute Source Code Form that is Incompatible With
Secondary Licenses under the terms of this version of the License, the
notice described in Exhibit B of this License must be attached.
Exhibit A - Source Code Form License Notice
-------------------------------------------
This Source Code Form is subject to the terms of the Mozilla Public
License, v. 2.0. If a copy of the MPL was not distributed with this
file, You can obtain one at http://mozilla.org/MPL/2.0/.
If it is not possible or desirable to put the notice in a particular
file, then You may include the notice in a location (such as a LICENSE
file in a relevant directory) where a recipient would be likely to look
for such a notice.
You may add additional accurate notices of copyright ownership.
Exhibit B - "Incompatible With Secondary Licenses" Notice
---------------------------------------------------------
This Source Code Form is "Incompatible With Secondary Licenses", as
defined by the Mozilla Public License, v. 2.0.

View File

@@ -1,6 +0,0 @@
# maulogger
A logger in Go. Deprecated in favor of [zerolog](https://github.com/rs/zerolog).
Utilities for migrating gracefully can be found in the maulogadapt package,
it includes both wrapping a zerolog in the maulogger interface, and wrapping a
maulogger as a zerolog output writer.

View File

@@ -1,284 +0,0 @@
// mauLogger - A logger for Go programs
// Copyright (c) 2016-2021 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
package maulogger
import (
"os"
)
// DefaultLogger ...
var DefaultLogger = Create().(*BasicLogger)
// SetWriter formats the given parts with fmt.Sprint and logs the result with the SetWriter level
func SetWriter(w *os.File) {
DefaultLogger.SetWriter(w)
}
// OpenFile formats the given parts with fmt.Sprint and logs the result with the OpenFile level
func OpenFile() error {
return DefaultLogger.OpenFile()
}
// Close formats the given parts with fmt.Sprint and logs the result with the Close level
func Close() error {
return DefaultLogger.Close()
}
// Sub creates a Sublogger
func Sub(module string) Logger {
return DefaultLogger.Sub(module)
}
// Raw formats the given parts with fmt.Sprint and logs the result with the Raw level
func Rawm(level Level, metadata map[string]interface{}, module, message string) {
DefaultLogger.Raw(level, metadata, module, message)
}
func Raw(level Level, module, message string) {
DefaultLogger.Raw(level, map[string]interface{}{}, module, message)
}
// Log formats the given parts with fmt.Sprint and logs the result with the given level
func Log(level Level, parts ...interface{}) {
DefaultLogger.DefaultSub.Log(level, parts...)
}
// Logln formats the given parts with fmt.Sprintln and logs the result with the given level
func Logln(level Level, parts ...interface{}) {
DefaultLogger.DefaultSub.Logln(level, parts...)
}
// Logf formats the given message and args with fmt.Sprintf and logs the result with the given level
func Logf(level Level, message string, args ...interface{}) {
DefaultLogger.DefaultSub.Logf(level, message, args...)
}
// Logfln formats the given message and args with fmt.Sprintf, appends a newline and logs the result with the given level
func Logfln(level Level, message string, args ...interface{}) {
DefaultLogger.DefaultSub.Logfln(level, message, args...)
}
// Debug formats the given parts with fmt.Sprint and logs the result with the Debug level
func Debug(parts ...interface{}) {
DefaultLogger.DefaultSub.Debug(parts...)
}
// Debugln formats the given parts with fmt.Sprintln and logs the result with the Debug level
func Debugln(parts ...interface{}) {
DefaultLogger.DefaultSub.Debugln(parts...)
}
// Debugf formats the given message and args with fmt.Sprintf and logs the result with the Debug level
func Debugf(message string, args ...interface{}) {
DefaultLogger.DefaultSub.Debugf(message, args...)
}
// Debugfln formats the given message and args with fmt.Sprintf, appends a newline and logs the result with the Debug level
func Debugfln(message string, args ...interface{}) {
DefaultLogger.DefaultSub.Debugfln(message, args...)
}
// Info formats the given parts with fmt.Sprint and logs the result with the Info level
func Info(parts ...interface{}) {
DefaultLogger.DefaultSub.Info(parts...)
}
// Infoln formats the given parts with fmt.Sprintln and logs the result with the Info level
func Infoln(parts ...interface{}) {
DefaultLogger.DefaultSub.Infoln(parts...)
}
// Infof formats the given message and args with fmt.Sprintf and logs the result with the Info level
func Infof(message string, args ...interface{}) {
DefaultLogger.DefaultSub.Infof(message, args...)
}
// Infofln formats the given message and args with fmt.Sprintf, appends a newline and logs the result with the Info level
func Infofln(message string, args ...interface{}) {
DefaultLogger.DefaultSub.Infofln(message, args...)
}
// Warn formats the given parts with fmt.Sprint and logs the result with the Warn level
func Warn(parts ...interface{}) {
DefaultLogger.DefaultSub.Warn(parts...)
}
// Warnln formats the given parts with fmt.Sprintln and logs the result with the Warn level
func Warnln(parts ...interface{}) {
DefaultLogger.DefaultSub.Warnln(parts...)
}
// Warnf formats the given message and args with fmt.Sprintf and logs the result with the Warn level
func Warnf(message string, args ...interface{}) {
DefaultLogger.DefaultSub.Warnf(message, args...)
}
// Warnfln formats the given message and args with fmt.Sprintf, appends a newline and logs the result with the Warn level
func Warnfln(message string, args ...interface{}) {
DefaultLogger.DefaultSub.Warnfln(message, args...)
}
// Error formats the given parts with fmt.Sprint and logs the result with the Error level
func Error(parts ...interface{}) {
DefaultLogger.DefaultSub.Error(parts...)
}
// Errorln formats the given parts with fmt.Sprintln and logs the result with the Error level
func Errorln(parts ...interface{}) {
DefaultLogger.DefaultSub.Errorln(parts...)
}
// Errorf formats the given message and args with fmt.Sprintf and logs the result with the Error level
func Errorf(message string, args ...interface{}) {
DefaultLogger.DefaultSub.Errorf(message, args...)
}
// Errorfln formats the given message and args with fmt.Sprintf, appends a newline and logs the result with the Error level
func Errorfln(message string, args ...interface{}) {
DefaultLogger.DefaultSub.Errorfln(message, args...)
}
// Fatal formats the given parts with fmt.Sprint and logs the result with the Fatal level
func Fatal(parts ...interface{}) {
DefaultLogger.DefaultSub.Fatal(parts...)
}
// Fatalln formats the given parts with fmt.Sprintln and logs the result with the Fatal level
func Fatalln(parts ...interface{}) {
DefaultLogger.DefaultSub.Fatalln(parts...)
}
// Fatalf formats the given message and args with fmt.Sprintf and logs the result with the Fatal level
func Fatalf(message string, args ...interface{}) {
DefaultLogger.DefaultSub.Fatalf(message, args...)
}
// Fatalfln formats the given message and args with fmt.Sprintf, appends a newline and logs the result with the Fatal level
func Fatalfln(message string, args ...interface{}) {
DefaultLogger.DefaultSub.Fatalfln(message, args...)
}
// Log formats the given parts with fmt.Sprint and logs the result with the given level
func (log *BasicLogger) Log(level Level, parts ...interface{}) {
log.DefaultSub.Log(level, parts...)
}
// Logln formats the given parts with fmt.Sprintln and logs the result with the given level
func (log *BasicLogger) Logln(level Level, parts ...interface{}) {
log.DefaultSub.Logln(level, parts...)
}
// Logf formats the given message and args with fmt.Sprintf and logs the result with the given level
func (log *BasicLogger) Logf(level Level, message string, args ...interface{}) {
log.DefaultSub.Logf(level, message, args...)
}
// Logfln formats the given message and args with fmt.Sprintf, appends a newline and logs the result with the given level
func (log *BasicLogger) Logfln(level Level, message string, args ...interface{}) {
log.DefaultSub.Logfln(level, message, args...)
}
// Debug formats the given parts with fmt.Sprint and logs the result with the Debug level
func (log *BasicLogger) Debug(parts ...interface{}) {
log.DefaultSub.Debug(parts...)
}
// Debugln formats the given parts with fmt.Sprintln and logs the result with the Debug level
func (log *BasicLogger) Debugln(parts ...interface{}) {
log.DefaultSub.Debugln(parts...)
}
// Debugf formats the given message and args with fmt.Sprintf and logs the result with the Debug level
func (log *BasicLogger) Debugf(message string, args ...interface{}) {
log.DefaultSub.Debugf(message, args...)
}
// Debugfln formats the given message and args with fmt.Sprintf, appends a newline and logs the result with the Debug level
func (log *BasicLogger) Debugfln(message string, args ...interface{}) {
log.DefaultSub.Debugfln(message, args...)
}
// Info formats the given parts with fmt.Sprint and logs the result with the Info level
func (log *BasicLogger) Info(parts ...interface{}) {
log.DefaultSub.Info(parts...)
}
// Infoln formats the given parts with fmt.Sprintln and logs the result with the Info level
func (log *BasicLogger) Infoln(parts ...interface{}) {
log.DefaultSub.Infoln(parts...)
}
// Infofln formats the given message and args with fmt.Sprintf, appends a newline and logs the result with the Info level
func (log *BasicLogger) Infofln(message string, args ...interface{}) {
log.DefaultSub.Infofln(message, args...)
}
// Infof formats the given message and args with fmt.Sprintf and logs the result with the Info level
func (log *BasicLogger) Infof(message string, args ...interface{}) {
log.DefaultSub.Infof(message, args...)
}
// Warn formats the given parts with fmt.Sprint and logs the result with the Warn level
func (log *BasicLogger) Warn(parts ...interface{}) {
log.DefaultSub.Warn(parts...)
}
// Warnln formats the given parts with fmt.Sprintln and logs the result with the Warn level
func (log *BasicLogger) Warnln(parts ...interface{}) {
log.DefaultSub.Warnln(parts...)
}
// Warnfln formats the given message and args with fmt.Sprintf, appends a newline and logs the result with the Warn level
func (log *BasicLogger) Warnfln(message string, args ...interface{}) {
log.DefaultSub.Warnfln(message, args...)
}
// Warnf formats the given message and args with fmt.Sprintf and logs the result with the Warn level
func (log *BasicLogger) Warnf(message string, args ...interface{}) {
log.DefaultSub.Warnf(message, args...)
}
// Error formats the given parts with fmt.Sprint and logs the result with the Error level
func (log *BasicLogger) Error(parts ...interface{}) {
log.DefaultSub.Error(parts...)
}
// Errorln formats the given parts with fmt.Sprintln and logs the result with the Error level
func (log *BasicLogger) Errorln(parts ...interface{}) {
log.DefaultSub.Errorln(parts...)
}
// Errorf formats the given message and args with fmt.Sprintf and logs the result with the Error level
func (log *BasicLogger) Errorf(message string, args ...interface{}) {
log.DefaultSub.Errorf(message, args...)
}
// Errorfln formats the given message and args with fmt.Sprintf, appends a newline and logs the result with the Error level
func (log *BasicLogger) Errorfln(message string, args ...interface{}) {
log.DefaultSub.Errorfln(message, args...)
}
// Fatal formats the given parts with fmt.Sprint and logs the result with the Fatal level
func (log *BasicLogger) Fatal(parts ...interface{}) {
log.DefaultSub.Fatal(parts...)
}
// Fatalln formats the given parts with fmt.Sprintln and logs the result with the Fatal level
func (log *BasicLogger) Fatalln(parts ...interface{}) {
log.DefaultSub.Fatalln(parts...)
}
// Fatalf formats the given message and args with fmt.Sprintf and logs the result with the Fatal level
func (log *BasicLogger) Fatalf(message string, args ...interface{}) {
log.DefaultSub.Fatalf(message, args...)
}
// Fatalfln formats the given message and args with fmt.Sprintf, appends a newline and logs the result with the Fatal level
func (log *BasicLogger) Fatalfln(message string, args ...interface{}) {
log.DefaultSub.Fatalfln(message, args...)
}

View File

@@ -1,47 +0,0 @@
// mauLogger - A logger for Go programs
// Copyright (c) 2016-2021 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
package maulogger
import (
"fmt"
)
// Level is the severity level of a log entry.
type Level struct {
Name string
Severity, Color int
}
var (
// LevelDebug is the level for debug messages.
LevelDebug = Level{Name: "DEBUG", Color: -1, Severity: 0}
// LevelInfo is the level for basic log messages.
LevelInfo = Level{Name: "INFO", Color: 36, Severity: 10}
// LevelWarn is the level saying that something went wrong, but the program will continue operating mostly normally.
LevelWarn = Level{Name: "WARN", Color: 33, Severity: 50}
// LevelError is the level saying that something went wrong and the program may not operate as expected, but will still continue.
LevelError = Level{Name: "ERROR", Color: 31, Severity: 100}
// LevelFatal is the level saying that something went wrong and the program will not operate normally.
LevelFatal = Level{Name: "FATAL", Color: 35, Severity: 9001}
)
// GetColor gets the ANSI escape color code for the log level.
func (lvl Level) GetColor() string {
if lvl.Color < 0 {
return "\x1b[0m"
}
return fmt.Sprintf("\x1b[%dm", lvl.Color)
}
// GetReset gets the ANSI escape reset code.
func (lvl Level) GetReset() string {
if lvl.Color < 0 {
return ""
}
return "\x1b[0m"
}

View File

@@ -1,224 +0,0 @@
// mauLogger - A logger for Go programs
// Copyright (c) 2016-2021 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
package maulogger
import (
"encoding/json"
"fmt"
"io"
"os"
"strings"
"sync"
"time"
)
// LoggerFileFormat ...
type LoggerFileFormat func(now string, i int) string
type BasicLogger struct {
PrintLevel int
FlushLineThreshold int
FileTimeFormat string
FileFormat LoggerFileFormat
TimeFormat string
FileMode os.FileMode
DefaultSub Logger
JSONFile bool
JSONStdout bool
stdoutEncoder *json.Encoder
fileEncoder *json.Encoder
writer *os.File
writerLock sync.Mutex
StdoutLock sync.Mutex
StderrLock sync.Mutex
lines int
metadata map[string]interface{}
}
// Logger contains advanced logging functions
type Logger interface {
Sub(module string) Logger
Subm(module string, metadata map[string]interface{}) Logger
WithDefaultLevel(level Level) Logger
GetParent() Logger
Writer(level Level) io.WriteCloser
Log(level Level, parts ...interface{})
Logln(level Level, parts ...interface{})
Logf(level Level, message string, args ...interface{})
Logfln(level Level, message string, args ...interface{})
Debug(parts ...interface{})
Debugln(parts ...interface{})
Debugf(message string, args ...interface{})
Debugfln(message string, args ...interface{})
Info(parts ...interface{})
Infoln(parts ...interface{})
Infof(message string, args ...interface{})
Infofln(message string, args ...interface{})
Warn(parts ...interface{})
Warnln(parts ...interface{})
Warnf(message string, args ...interface{})
Warnfln(message string, args ...interface{})
Error(parts ...interface{})
Errorln(parts ...interface{})
Errorf(message string, args ...interface{})
Errorfln(message string, args ...interface{})
Fatal(parts ...interface{})
Fatalln(parts ...interface{})
Fatalf(message string, args ...interface{})
Fatalfln(message string, args ...interface{})
}
// Create a Logger
func Createm(metadata map[string]interface{}) Logger {
var log = &BasicLogger{
PrintLevel: 10,
FileTimeFormat: "2006-01-02",
FileFormat: func(now string, i int) string { return fmt.Sprintf("%[1]s-%02[2]d.log", now, i) },
TimeFormat: "15:04:05 02.01.2006",
FileMode: 0600,
FlushLineThreshold: 5,
lines: 0,
metadata: metadata,
}
log.DefaultSub = log.Sub("")
return log
}
func Create() Logger {
return Createm(map[string]interface{}{})
}
func (log *BasicLogger) EnableJSONStdout() {
log.JSONStdout = true
log.stdoutEncoder = json.NewEncoder(os.Stdout)
}
func (log *BasicLogger) GetParent() Logger {
return nil
}
// SetWriter formats the given parts with fmt.Sprint and logs the result with the SetWriter level
func (log *BasicLogger) SetWriter(w *os.File) {
log.writer = w
if log.JSONFile {
log.fileEncoder = json.NewEncoder(w)
}
}
// OpenFile formats the given parts with fmt.Sprint and logs the result with the OpenFile level
func (log *BasicLogger) OpenFile() error {
now := time.Now().Format(log.FileTimeFormat)
i := 1
for ; ; i++ {
if _, err := os.Stat(log.FileFormat(now, i)); os.IsNotExist(err) {
break
}
}
writer, err := os.OpenFile(log.FileFormat(now, i), os.O_WRONLY|os.O_CREATE|os.O_APPEND, log.FileMode)
if err != nil {
return err
} else if writer == nil {
return os.ErrInvalid
}
log.SetWriter(writer)
return nil
}
// Close formats the given parts with fmt.Sprint and logs the result with the Close level
func (log *BasicLogger) Close() error {
if log.writer != nil {
return log.writer.Close()
}
return nil
}
type logLine struct {
log *BasicLogger
Command string `json:"command"`
Time time.Time `json:"time"`
Level string `json:"level"`
Module string `json:"module"`
Message string `json:"message"`
Metadata map[string]interface{} `json:"metadata"`
}
func (ll logLine) String() string {
if len(ll.Module) == 0 {
return fmt.Sprintf("[%s] [%s] %s", ll.Time.Format(ll.log.TimeFormat), ll.Level, ll.Message)
} else {
return fmt.Sprintf("[%s] [%s/%s] %s", ll.Time.Format(ll.log.TimeFormat), ll.Module, ll.Level, ll.Message)
}
}
func reduceItem(m1, m2 map[string]interface{}) map[string]interface{} {
m3 := map[string]interface{}{}
_merge := func(m map[string]interface{}) {
for ia, va := range m {
m3[ia] = va
}
}
_merge(m1)
_merge(m2)
return m3
}
// Raw formats the given parts with fmt.Sprint and logs the result with the Raw level
func (log *BasicLogger) Raw(level Level, extraMetadata map[string]interface{}, module, origMessage string) {
message := logLine{log, "log", time.Now(), level.Name, module, strings.TrimSpace(origMessage), reduceItem(log.metadata, extraMetadata)}
if log.writer != nil {
log.writerLock.Lock()
var err error
if log.JSONFile {
err = log.fileEncoder.Encode(&message)
} else {
_, err = log.writer.WriteString(message.String())
_, _ = log.writer.WriteString("\n")
}
log.writerLock.Unlock()
if err != nil {
log.StderrLock.Lock()
_, _ = os.Stderr.WriteString("Failed to write to log file:")
_, _ = os.Stderr.WriteString(err.Error())
log.StderrLock.Unlock()
}
}
if level.Severity >= log.PrintLevel {
if log.JSONStdout {
log.StdoutLock.Lock()
_ = log.stdoutEncoder.Encode(&message)
log.StdoutLock.Unlock()
} else if level.Severity >= LevelError.Severity {
log.StderrLock.Lock()
_, _ = os.Stderr.WriteString(level.GetColor())
_, _ = os.Stderr.WriteString(message.String())
_, _ = os.Stderr.WriteString(level.GetReset())
_, _ = os.Stderr.WriteString("\n")
log.StderrLock.Unlock()
} else {
log.StdoutLock.Lock()
_, _ = os.Stdout.WriteString(level.GetColor())
_, _ = os.Stdout.WriteString(message.String())
_, _ = os.Stdout.WriteString(level.GetReset())
_, _ = os.Stdout.WriteString("\n")
log.StdoutLock.Unlock()
}
}
}

View File

@@ -1,185 +0,0 @@
// Copyright (c) 2023 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
package maulogadapt
import (
"fmt"
"io"
"strings"
"github.com/rs/zerolog"
"maunium.net/go/maulogger/v2"
)
type MauZeroLog struct {
*zerolog.Logger
orig *zerolog.Logger
mod string
}
func ZeroAsMau(log *zerolog.Logger) maulogger.Logger {
return MauZeroLog{log, log, ""}
}
var _ maulogger.Logger = (*MauZeroLog)(nil)
func (m MauZeroLog) Sub(module string) maulogger.Logger {
return m.Subm(module, map[string]interface{}{})
}
func (m MauZeroLog) Subm(module string, metadata map[string]interface{}) maulogger.Logger {
if m.mod != "" {
module = fmt.Sprintf("%s/%s", m.mod, module)
}
var orig zerolog.Logger
if m.orig != nil {
orig = *m.orig
} else {
orig = *m.Logger
}
if len(metadata) > 0 {
with := m.orig.With()
for key, value := range metadata {
with = with.Interface(key, value)
}
orig = with.Logger()
}
log := orig.With().Str("module", module).Logger()
return MauZeroLog{&log, &orig, module}
}
func (m MauZeroLog) WithDefaultLevel(_ maulogger.Level) maulogger.Logger {
return m
}
func (m MauZeroLog) GetParent() maulogger.Logger {
return nil
}
type nopWriteCloser struct {
io.Writer
}
func (nopWriteCloser) Close() error { return nil }
func (m MauZeroLog) Writer(level maulogger.Level) io.WriteCloser {
return nopWriteCloser{m.Logger.With().Str(zerolog.LevelFieldName, zerolog.LevelFieldMarshalFunc(mauToZeroLevel(level))).Logger()}
}
func mauToZeroLevel(level maulogger.Level) zerolog.Level {
switch level {
case maulogger.LevelDebug:
return zerolog.DebugLevel
case maulogger.LevelInfo:
return zerolog.InfoLevel
case maulogger.LevelWarn:
return zerolog.WarnLevel
case maulogger.LevelError:
return zerolog.ErrorLevel
case maulogger.LevelFatal:
return zerolog.FatalLevel
default:
return zerolog.TraceLevel
}
}
func (m MauZeroLog) Log(level maulogger.Level, parts ...interface{}) {
m.Logger.WithLevel(mauToZeroLevel(level)).Msg(fmt.Sprint(parts...))
}
func (m MauZeroLog) Logln(level maulogger.Level, parts ...interface{}) {
m.Logger.WithLevel(mauToZeroLevel(level)).Msg(strings.TrimSuffix(fmt.Sprintln(parts...), "\n"))
}
func (m MauZeroLog) Logf(level maulogger.Level, message string, args ...interface{}) {
m.Logger.WithLevel(mauToZeroLevel(level)).Msg(fmt.Sprintf(message, args...))
}
func (m MauZeroLog) Logfln(level maulogger.Level, message string, args ...interface{}) {
m.Logger.WithLevel(mauToZeroLevel(level)).Msg(fmt.Sprintf(message, args...))
}
func (m MauZeroLog) Debug(parts ...interface{}) {
m.Logger.Debug().Msg(fmt.Sprint(parts...))
}
func (m MauZeroLog) Debugln(parts ...interface{}) {
m.Logger.Debug().Msg(strings.TrimSuffix(fmt.Sprintln(parts...), "\n"))
}
func (m MauZeroLog) Debugf(message string, args ...interface{}) {
m.Logger.Debug().Msg(fmt.Sprintf(message, args...))
}
func (m MauZeroLog) Debugfln(message string, args ...interface{}) {
m.Logger.Debug().Msg(fmt.Sprintf(message, args...))
}
func (m MauZeroLog) Info(parts ...interface{}) {
m.Logger.Info().Msg(fmt.Sprint(parts...))
}
func (m MauZeroLog) Infoln(parts ...interface{}) {
m.Logger.Info().Msg(strings.TrimSuffix(fmt.Sprintln(parts...), "\n"))
}
func (m MauZeroLog) Infof(message string, args ...interface{}) {
m.Logger.Info().Msg(fmt.Sprintf(message, args...))
}
func (m MauZeroLog) Infofln(message string, args ...interface{}) {
m.Logger.Info().Msg(fmt.Sprintf(message, args...))
}
func (m MauZeroLog) Warn(parts ...interface{}) {
m.Logger.Warn().Msg(fmt.Sprint(parts...))
}
func (m MauZeroLog) Warnln(parts ...interface{}) {
m.Logger.Warn().Msg(strings.TrimSuffix(fmt.Sprintln(parts...), "\n"))
}
func (m MauZeroLog) Warnf(message string, args ...interface{}) {
m.Logger.Warn().Msg(fmt.Sprintf(message, args...))
}
func (m MauZeroLog) Warnfln(message string, args ...interface{}) {
m.Logger.Warn().Msg(fmt.Sprintf(message, args...))
}
func (m MauZeroLog) Error(parts ...interface{}) {
m.Logger.Error().Msg(fmt.Sprint(parts...))
}
func (m MauZeroLog) Errorln(parts ...interface{}) {
m.Logger.Error().Msg(strings.TrimSuffix(fmt.Sprintln(parts...), "\n"))
}
func (m MauZeroLog) Errorf(message string, args ...interface{}) {
m.Logger.Error().Msg(fmt.Sprintf(message, args...))
}
func (m MauZeroLog) Errorfln(message string, args ...interface{}) {
m.Logger.Error().Msg(fmt.Sprintf(message, args...))
}
func (m MauZeroLog) Fatal(parts ...interface{}) {
m.Logger.WithLevel(zerolog.FatalLevel).Msg(fmt.Sprint(parts...))
}
func (m MauZeroLog) Fatalln(parts ...interface{}) {
m.Logger.WithLevel(zerolog.FatalLevel).Msg(strings.TrimSuffix(fmt.Sprintln(parts...), "\n"))
}
func (m MauZeroLog) Fatalf(message string, args ...interface{}) {
m.Logger.WithLevel(zerolog.FatalLevel).Msg(fmt.Sprintf(message, args...))
}
func (m MauZeroLog) Fatalfln(message string, args ...interface{}) {
m.Logger.WithLevel(zerolog.FatalLevel).Msg(fmt.Sprintf(message, args...))
}

View File

@@ -1,73 +0,0 @@
// Copyright (c) 2023 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
package maulogadapt
import (
"bytes"
"github.com/rs/zerolog"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
"maunium.net/go/maulogger/v2"
)
// ZeroMauLog is a simple wrapper for a maulogger that can be set as the output writer for zerolog.
type ZeroMauLog struct {
maulogger.Logger
}
func MauAsZero(log maulogger.Logger) *zerolog.Logger {
zero := zerolog.New(&ZeroMauLog{log})
return &zero
}
var _ zerolog.LevelWriter = (*ZeroMauLog)(nil)
func (z *ZeroMauLog) Write(p []byte) (n int, err error) {
return 0, nil
}
func (z *ZeroMauLog) WriteLevel(level zerolog.Level, p []byte) (n int, err error) {
var mauLevel maulogger.Level
switch level {
case zerolog.DebugLevel:
mauLevel = maulogger.LevelDebug
case zerolog.InfoLevel, zerolog.NoLevel:
mauLevel = maulogger.LevelInfo
case zerolog.WarnLevel:
mauLevel = maulogger.LevelWarn
case zerolog.ErrorLevel:
mauLevel = maulogger.LevelError
case zerolog.FatalLevel, zerolog.PanicLevel:
mauLevel = maulogger.LevelFatal
case zerolog.Disabled, zerolog.TraceLevel:
fallthrough
default:
return 0, nil
}
p = bytes.TrimSuffix(p, []byte{'\n'})
msg := gjson.GetBytes(p, zerolog.MessageFieldName).Str
p, err = sjson.DeleteBytes(p, zerolog.MessageFieldName)
if err != nil {
return
}
p, err = sjson.DeleteBytes(p, zerolog.LevelFieldName)
if err != nil {
return
}
p, err = sjson.DeleteBytes(p, zerolog.TimestampFieldName)
if err != nil {
return
}
if len(p) > 2 {
msg += " " + string(p)
}
z.Log(mauLevel, msg)
return len(p), nil
}

View File

@@ -1,216 +0,0 @@
// mauLogger - A logger for Go programs
// Copyright (c) 2016-2021 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
package maulogger
import (
"fmt"
)
type Sublogger struct {
topLevel *BasicLogger
parent Logger
Module string
DefaultLevel Level
metadata map[string]interface{}
}
// Subm creates a Sublogger
func (log *BasicLogger) Subm(module string, metadata map[string]interface{}) Logger {
return &Sublogger{
topLevel: log,
parent: log,
Module: module,
DefaultLevel: LevelInfo,
metadata: metadata,
}
}
func (log *BasicLogger) Sub(module string) Logger {
return log.Subm(module, map[string]interface{}{})
}
// WithDefaultLevel creates a Sublogger with the same Module but different DefaultLevel
func (log *BasicLogger) WithDefaultLevel(lvl Level) Logger {
return log.DefaultSub.WithDefaultLevel(lvl)
}
func (log *Sublogger) GetParent() Logger {
return log.parent
}
// Sub creates a Sublogger
func (log *Sublogger) Subm(module string, metadata map[string]interface{}) Logger {
if len(module) > 0 {
module = fmt.Sprintf("%s/%s", log.Module, module)
} else {
module = log.Module
}
return &Sublogger{
topLevel: log.topLevel,
parent: log,
Module: module,
DefaultLevel: log.DefaultLevel,
metadata: metadata,
}
}
func (log *Sublogger) Sub(module string) Logger {
return log.Subm(module, map[string]interface{}{})
}
// WithDefaultLevel creates a Sublogger with the same Module but different DefaultLevel
func (log *Sublogger) WithDefaultLevel(lvl Level) Logger {
return &Sublogger{
topLevel: log.topLevel,
parent: log.parent,
Module: log.Module,
DefaultLevel: lvl,
}
}
// SetModule changes the module name of this Sublogger
func (log *Sublogger) SetModule(mod string) {
log.Module = mod
}
// SetDefaultLevel changes the default logging level of this Sublogger
func (log *Sublogger) SetDefaultLevel(lvl Level) {
log.DefaultLevel = lvl
}
// SetParent changes the parent of this Sublogger
func (log *Sublogger) SetParent(parent *BasicLogger) {
log.topLevel = parent
}
//Write ...
func (log *Sublogger) Write(p []byte) (n int, err error) {
log.topLevel.Raw(log.DefaultLevel, log.metadata, log.Module, string(p))
return len(p), nil
}
// Log formats the given parts with fmt.Sprint and logs the result with the given level
func (log *Sublogger) Log(level Level, parts ...interface{}) {
log.topLevel.Raw(level, log.metadata, log.Module, fmt.Sprint(parts...))
}
// Logln formats the given parts with fmt.Sprintln and logs the result with the given level
func (log *Sublogger) Logln(level Level, parts ...interface{}) {
log.topLevel.Raw(level, log.metadata, log.Module, fmt.Sprintln(parts...))
}
// Logf formats the given message and args with fmt.Sprintf and logs the result with the given level
func (log *Sublogger) Logf(level Level, message string, args ...interface{}) {
log.topLevel.Raw(level, log.metadata, log.Module, fmt.Sprintf(message, args...))
}
// Logfln formats the given message and args with fmt.Sprintf, appends a newline and logs the result with the given level
func (log *Sublogger) Logfln(level Level, message string, args ...interface{}) {
log.topLevel.Raw(level, log.metadata, log.Module, fmt.Sprintf(message+"\n", args...))
}
// Debug formats the given parts with fmt.Sprint and logs the result with the Debug level
func (log *Sublogger) Debug(parts ...interface{}) {
log.topLevel.Raw(LevelDebug, log.metadata, log.Module, fmt.Sprint(parts...))
}
// Debugln formats the given parts with fmt.Sprintln and logs the result with the Debug level
func (log *Sublogger) Debugln(parts ...interface{}) {
log.topLevel.Raw(LevelDebug, log.metadata, log.Module, fmt.Sprintln(parts...))
}
// Debugf formats the given message and args with fmt.Sprintf and logs the result with the Debug level
func (log *Sublogger) Debugf(message string, args ...interface{}) {
log.topLevel.Raw(LevelDebug, log.metadata, log.Module, fmt.Sprintf(message, args...))
}
// Debugfln formats the given message and args with fmt.Sprintf, appends a newline and logs the result with the Debug level
func (log *Sublogger) Debugfln(message string, args ...interface{}) {
log.topLevel.Raw(LevelDebug, log.metadata, log.Module, fmt.Sprintf(message+"\n", args...))
}
// Info formats the given parts with fmt.Sprint and logs the result with the Info level
func (log *Sublogger) Info(parts ...interface{}) {
log.topLevel.Raw(LevelInfo, log.metadata, log.Module, fmt.Sprint(parts...))
}
// Infoln formats the given parts with fmt.Sprintln and logs the result with the Info level
func (log *Sublogger) Infoln(parts ...interface{}) {
log.topLevel.Raw(LevelInfo, log.metadata, log.Module, fmt.Sprintln(parts...))
}
// Infof formats the given message and args with fmt.Sprintf and logs the result with the Info level
func (log *Sublogger) Infof(message string, args ...interface{}) {
log.topLevel.Raw(LevelInfo, log.metadata, log.Module, fmt.Sprintf(message, args...))
}
// Infofln formats the given message and args with fmt.Sprintf, appends a newline and logs the result with the Info level
func (log *Sublogger) Infofln(message string, args ...interface{}) {
log.topLevel.Raw(LevelInfo, log.metadata, log.Module, fmt.Sprintf(message+"\n", args...))
}
// Warn formats the given parts with fmt.Sprint and logs the result with the Warn level
func (log *Sublogger) Warn(parts ...interface{}) {
log.topLevel.Raw(LevelWarn, log.metadata, log.Module, fmt.Sprint(parts...))
}
// Warnln formats the given parts with fmt.Sprintln and logs the result with the Warn level
func (log *Sublogger) Warnln(parts ...interface{}) {
log.topLevel.Raw(LevelWarn, log.metadata, log.Module, fmt.Sprintln(parts...))
}
// Warnf formats the given message and args with fmt.Sprintf and logs the result with the Warn level
func (log *Sublogger) Warnf(message string, args ...interface{}) {
log.topLevel.Raw(LevelWarn, log.metadata, log.Module, fmt.Sprintf(message, args...))
}
// Warnfln formats the given message and args with fmt.Sprintf, appends a newline and logs the result with the Warn level
func (log *Sublogger) Warnfln(message string, args ...interface{}) {
log.topLevel.Raw(LevelWarn, log.metadata, log.Module, fmt.Sprintf(message+"\n", args...))
}
// Error formats the given parts with fmt.Sprint and logs the result with the Error level
func (log *Sublogger) Error(parts ...interface{}) {
log.topLevel.Raw(LevelError, log.metadata, log.Module, fmt.Sprint(parts...))
}
// Errorln formats the given parts with fmt.Sprintln and logs the result with the Error level
func (log *Sublogger) Errorln(parts ...interface{}) {
log.topLevel.Raw(LevelError, log.metadata, log.Module, fmt.Sprintln(parts...))
}
// Errorf formats the given message and args with fmt.Sprintf and logs the result with the Error level
func (log *Sublogger) Errorf(message string, args ...interface{}) {
log.topLevel.Raw(LevelError, log.metadata, log.Module, fmt.Sprintf(message, args...))
}
// Errorfln formats the given message and args with fmt.Sprintf, appends a newline and logs the result with the Error level
func (log *Sublogger) Errorfln(message string, args ...interface{}) {
log.topLevel.Raw(LevelError, log.metadata, log.Module, fmt.Sprintf(message+"\n", args...))
}
// Fatal formats the given parts with fmt.Sprint and logs the result with the Fatal level
func (log *Sublogger) Fatal(parts ...interface{}) {
log.topLevel.Raw(LevelFatal, log.metadata, log.Module, fmt.Sprint(parts...))
}
// Fatalln formats the given parts with fmt.Sprintln and logs the result with the Fatal level
func (log *Sublogger) Fatalln(parts ...interface{}) {
log.topLevel.Raw(LevelFatal, log.metadata, log.Module, fmt.Sprintln(parts...))
}
// Fatalf formats the given message and args with fmt.Sprintf and logs the result with the Fatal level
func (log *Sublogger) Fatalf(message string, args ...interface{}) {
log.topLevel.Raw(LevelFatal, log.metadata, log.Module, fmt.Sprintf(message, args...))
}
// Fatalfln formats the given message and args with fmt.Sprintf, appends a newline and logs the result with the Fatal level
func (log *Sublogger) Fatalfln(message string, args ...interface{}) {
log.topLevel.Raw(LevelFatal, log.metadata, log.Module, fmt.Sprintf(message+"\n", args...))
}

View File

@@ -1,78 +0,0 @@
// mauLogger - A logger for Go programs
// Copyright (c) 2021 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
package maulogger
import (
"bytes"
"io"
"sync"
)
// LogWriter is a buffered io.Writer that writes lines to a Logger.
type LogWriter struct {
log Logger
lock sync.Mutex
level Level
buf bytes.Buffer
}
func (log *BasicLogger) Writer(level Level) io.WriteCloser {
return &LogWriter{
log: log,
level: level,
}
}
func (log *Sublogger) Writer(level Level) io.WriteCloser {
return &LogWriter{
log: log,
level: level,
}
}
func (lw *LogWriter) writeLine(data []byte) {
if lw.buf.Len() == 0 {
if len(data) == 0 {
return
}
lw.log.Logln(lw.level, string(data))
} else {
lw.buf.Write(data)
lw.log.Logln(lw.level, lw.buf.String())
lw.buf.Reset()
}
}
// Write will write lines from the given data to the buffer. If the data doesn't end with a line break,
// everything after the last line break will be buffered until the next Write or Close call.
func (lw *LogWriter) Write(data []byte) (int, error) {
lw.lock.Lock()
newline := bytes.IndexByte(data, '\n')
if newline == len(data)-1 {
lw.writeLine(data[:len(data)-1])
} else if newline < 0 {
lw.buf.Write(data)
} else {
lines := bytes.Split(data, []byte("\n"))
for _, line := range lines[:len(lines)-1] {
lw.writeLine(line)
}
lw.buf.Write(lines[len(lines)-1])
}
lw.lock.Unlock()
return len(data), nil
}
// Close will flush remaining data in the buffer into the logger.
func (lw *LogWriter) Close() error {
lw.lock.Lock()
lw.log.Logln(lw.level, lw.buf.String())
lw.buf.Reset()
lw.lock.Unlock()
return nil
}

View File

@@ -17,3 +17,8 @@ repos:
- "maunium.net/go/mautrix" - "maunium.net/go/mautrix"
- "-w" - "-w"
- id: go-vet-repo-mod - id: go-vet-repo-mod
- repo: https://github.com/beeper/pre-commit-go
rev: v0.3.1
hooks:
- id: prevent-literal-http-methods

View File

@@ -1,9 +1,59 @@
## v0.18.0 (2024-03-16)
* **Breaking change *(client, bridge, appservice)*** Dropped support for
maulogger. Only zerolog loggers are now provided by default.
* *(bridge)* Fixed upload size limit not having a default if the server
returned no value.
* *(synapseadmin)* Added wrappers for some room and user admin APIs.
(thanks to [@grvn-ht] in [#181]).
* *(crypto/verificationhelper)* Fixed bugs.
* *(crypto)* Fixed key backup uploading doing too much base64.
* *(crypto)* Changed `EncryptMegolmEvent` to return an error if persisting the
megolm session fails. This ensures that database errors won't cause messages
to be sent with duplicate indexes.
* *(crypto)* Changed `GetOrRequestSecret` to use a callback instead of returning
the value directly. This allows validating the value in order to ignore
invalid secrets.
* *(id)* Added `ParseCommonIdentifier` function to parse any Matrix identifier
in the [Common Identifier Format].
* *(federation)* Added simple key server that passes the federation tester.
[@grvn-ht]: https://github.com/grvn-ht
[#181]: https://github.com/mautrix/go/pull/181
[Common Identifier Format]: https://spec.matrix.org/v1.9/appendices/#common-identifier-format
### beta.1 (2024-02-16)
* Bumped minimum Go version to 1.21.
* *(bridge)* Bumped minimum Matrix spec version to v1.4.
* **Breaking change *(crypto)*** Deleted old half-broken interactive
verification code and replaced it with a new `verificationhelper`.
* The new verification helper is still experimental.
* Both QR and emoji verification are supported (in theory).
* *(crypto)* Added support for server-side key backup.
* *(crypto)* Added support for receiving and sending secrets like cross-signing
private keys via secret sharing.
* *(crypto)* Added support for tracking which devices megolm sessions were
initially shared to, and allowing re-sharing the keys to those sessions.
* *(client)* Changed cross-signing key upload method to accept a callback for
user-interactive auth instead of only hardcoding password support.
* *(appservice)* Dropped support for legacy non-prefixed appservice paths
(e.g. `/transactions` instead of `/_matrix/app/v1/transactions`).
* *(appservice)* Dropped support for legacy `access_token` authorization in
appservice endpoints.
* *(bridge)* Fixed `RawArgs` field in command events of command state callbacks.
* *(appservice)* Added `CreateFull` helper function for creating an `AppService`
instance with all the mandatory fields set.
## v0.17.0 (2024-01-16) ## v0.17.0 (2024-01-16)
* **Breaking change *(bridge)*** Added raw event to portal membership handling * **Breaking change *(bridge)*** Added raw event to portal membership handling
functions. functions.
* **Breaking change *(everything)*** Added context parameters to all functions * **Breaking change *(everything)*** Added context parameters to all functions
(started by [@recht] in [#144]). (started by [@recht] in [#144]).
* **Breaking change *(client)*** Moved event source from sync event handler
function parameters to the `Mautrix.EventSource` field inside the event
struct.
* **Breaking change *(client)*** Moved `EventSource` to `event.Source`. * **Breaking change *(client)*** Moved `EventSource` to `event.Source`.
* *(client)* Removed deprecated `OldEventIgnorer`. The non-deprecated version * *(client)* Removed deprecated `OldEventIgnorer`. The non-deprecated version
(`Client.DontProcessOldEvents`) is still available. (`Client.DontProcessOldEvents`) is still available.

View File

@@ -19,8 +19,8 @@ import (
"github.com/rs/zerolog" "github.com/rs/zerolog"
"go.mau.fi/util/retryafter" "go.mau.fi/util/retryafter"
"maunium.net/go/maulogger/v2/maulogadapt"
"maunium.net/go/mautrix/crypto/backup"
"maunium.net/go/mautrix/event" "maunium.net/go/mautrix/event"
"maunium.net/go/mautrix/id" "maunium.net/go/mautrix/id"
"maunium.net/go/mautrix/pushrules" "maunium.net/go/mautrix/pushrules"
@@ -34,15 +34,18 @@ type CryptoHelper interface {
Init(context.Context) error Init(context.Context) error
} }
// Deprecated: switch to zerolog type VerificationHelper interface {
type Logger interface { Init(context.Context) error
Debugfln(message string, args ...interface{}) StartVerification(ctx context.Context, to id.UserID) (id.VerificationTransactionID, error)
} StartInRoomVerification(ctx context.Context, roomID id.RoomID, to id.UserID) (id.VerificationTransactionID, error)
AcceptVerification(ctx context.Context, txnID id.VerificationTransactionID) error
CancelVerification(ctx context.Context, txnID id.VerificationTransactionID, code event.VerificationCancelCode, reason string) error
// Deprecated: switch to zerolog HandleScannedQRData(ctx context.Context, data []byte) error
type WarnLogger interface { ConfirmQRCodeScanned(ctx context.Context, txnID id.VerificationTransactionID) error
Logger
Warnfln(message string, args ...interface{}) StartSAS(ctx context.Context, txnID id.VerificationTransactionID) error
ConfirmSAS(ctx context.Context, txnID id.VerificationTransactionID) error
} }
// Client represents a Matrix client. // Client represents a Matrix client.
@@ -57,10 +60,9 @@ type Client struct {
Store SyncStore // The thing which can store tokens/ids Store SyncStore // The thing which can store tokens/ids
StateStore StateStore StateStore StateStore
Crypto CryptoHelper Crypto CryptoHelper
Verification VerificationHelper
Log zerolog.Logger Log zerolog.Logger
// Deprecated: switch to the zerolog instance in Log
Logger Logger
RequestHook func(req *http.Request) RequestHook func(req *http.Request)
ResponseHook func(req *http.Request, resp *http.Response, duration time.Duration) ResponseHook func(req *http.Request, resp *http.Response, duration time.Duration)
@@ -107,7 +109,7 @@ func DiscoverClientAPI(ctx context.Context, serverName string) (*ClientWellKnown
Path: "/.well-known/matrix/client", Path: "/.well-known/matrix/client",
} }
req, err := http.NewRequestWithContext(ctx, "GET", wellKnownURL.String(), nil) req, err := http.NewRequestWithContext(ctx, http.MethodGet, wellKnownURL.String(), nil)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -576,14 +578,14 @@ func (cli *Client) executeCompiledRequest(req *http.Request, retries int, backof
func (cli *Client) Whoami(ctx context.Context) (resp *RespWhoami, err error) { func (cli *Client) Whoami(ctx context.Context) (resp *RespWhoami, err error) {
urlPath := cli.BuildClientURL("v3", "account", "whoami") urlPath := cli.BuildClientURL("v3", "account", "whoami")
_, err = cli.MakeRequest(ctx, "GET", urlPath, nil, &resp) _, err = cli.MakeRequest(ctx, http.MethodGet, urlPath, nil, &resp)
return return
} }
// CreateFilter makes an HTTP request according to https://spec.matrix.org/v1.2/client-server-api/#post_matrixclientv3useruseridfilter // CreateFilter makes an HTTP request according to https://spec.matrix.org/v1.2/client-server-api/#post_matrixclientv3useruseridfilter
func (cli *Client) CreateFilter(ctx context.Context, filter *Filter) (resp *RespCreateFilter, err error) { func (cli *Client) CreateFilter(ctx context.Context, filter *Filter) (resp *RespCreateFilter, err error) {
urlPath := cli.BuildClientURL("v3", "user", cli.UserID, "filter") urlPath := cli.BuildClientURL("v3", "user", cli.UserID, "filter")
_, err = cli.MakeRequest(ctx, "POST", urlPath, filter, &resp) _, err = cli.MakeRequest(ctx, http.MethodPost, urlPath, filter, &resp)
return return
} }
@@ -764,7 +766,7 @@ func (cli *Client) RegisterDummy(ctx context.Context, req *ReqRegister) (*RespRe
// GetLoginFlows fetches the login flows that the homeserver supports using https://spec.matrix.org/v1.2/client-server-api/#get_matrixclientv3login // GetLoginFlows fetches the login flows that the homeserver supports using https://spec.matrix.org/v1.2/client-server-api/#get_matrixclientv3login
func (cli *Client) GetLoginFlows(ctx context.Context) (resp *RespLoginFlows, err error) { func (cli *Client) GetLoginFlows(ctx context.Context) (resp *RespLoginFlows, err error) {
urlPath := cli.BuildClientURL("v3", "login") urlPath := cli.BuildClientURL("v3", "login")
_, err = cli.MakeRequest(ctx, "GET", urlPath, nil, &resp) _, err = cli.MakeRequest(ctx, http.MethodGet, urlPath, nil, &resp)
return return
} }
@@ -808,7 +810,7 @@ func (cli *Client) Login(ctx context.Context, req *ReqLogin) (resp *RespLogin, e
// This does not clear the credentials from the client instance. See ClearCredentials() instead. // This does not clear the credentials from the client instance. See ClearCredentials() instead.
func (cli *Client) Logout(ctx context.Context) (resp *RespLogout, err error) { func (cli *Client) Logout(ctx context.Context) (resp *RespLogout, err error) {
urlPath := cli.BuildClientURL("v3", "logout") urlPath := cli.BuildClientURL("v3", "logout")
_, err = cli.MakeRequest(ctx, "POST", urlPath, nil, &resp) _, err = cli.MakeRequest(ctx, http.MethodPost, urlPath, nil, &resp)
return return
} }
@@ -816,21 +818,21 @@ func (cli *Client) Logout(ctx context.Context) (resp *RespLogout, err error) {
// This does not clear the credentials from the client instance. See ClearCredentials() instead. // This does not clear the credentials from the client instance. See ClearCredentials() instead.
func (cli *Client) LogoutAll(ctx context.Context) (resp *RespLogout, err error) { func (cli *Client) LogoutAll(ctx context.Context) (resp *RespLogout, err error) {
urlPath := cli.BuildClientURL("v3", "logout", "all") urlPath := cli.BuildClientURL("v3", "logout", "all")
_, err = cli.MakeRequest(ctx, "POST", urlPath, nil, &resp) _, err = cli.MakeRequest(ctx, http.MethodPost, urlPath, nil, &resp)
return return
} }
// Versions returns the list of supported Matrix versions on this homeserver. See https://spec.matrix.org/v1.2/client-server-api/#get_matrixclientversions // Versions returns the list of supported Matrix versions on this homeserver. See https://spec.matrix.org/v1.2/client-server-api/#get_matrixclientversions
func (cli *Client) Versions(ctx context.Context) (resp *RespVersions, err error) { func (cli *Client) Versions(ctx context.Context) (resp *RespVersions, err error) {
urlPath := cli.BuildClientURL("versions") urlPath := cli.BuildClientURL("versions")
_, err = cli.MakeRequest(ctx, "GET", urlPath, nil, &resp) _, err = cli.MakeRequest(ctx, http.MethodGet, urlPath, nil, &resp)
return return
} }
// Capabilities returns capabilities on this homeserver. See https://spec.matrix.org/v1.3/client-server-api/#capabilities-negotiation // Capabilities returns capabilities on this homeserver. See https://spec.matrix.org/v1.3/client-server-api/#capabilities-negotiation
func (cli *Client) Capabilities(ctx context.Context) (resp *RespCapabilities, err error) { func (cli *Client) Capabilities(ctx context.Context) (resp *RespCapabilities, err error) {
urlPath := cli.BuildClientURL("v3", "capabilities") urlPath := cli.BuildClientURL("v3", "capabilities")
_, err = cli.MakeRequest(ctx, "GET", urlPath, nil, &resp) _, err = cli.MakeRequest(ctx, http.MethodGet, urlPath, nil, &resp)
return return
} }
@@ -847,7 +849,7 @@ func (cli *Client) JoinRoom(ctx context.Context, roomIDorAlias, serverName strin
} else { } else {
urlPath = cli.BuildClientURL("v3", "join", roomIDorAlias) urlPath = cli.BuildClientURL("v3", "join", roomIDorAlias)
} }
_, err = cli.MakeRequest(ctx, "POST", urlPath, content, &resp) _, err = cli.MakeRequest(ctx, http.MethodPost, urlPath, content, &resp)
if err == nil && cli.StateStore != nil { if err == nil && cli.StateStore != nil {
err = cli.StateStore.SetMembership(ctx, resp.RoomID, cli.UserID, event.MembershipJoin) err = cli.StateStore.SetMembership(ctx, resp.RoomID, cli.UserID, event.MembershipJoin)
if err != nil { if err != nil {
@@ -862,7 +864,7 @@ func (cli *Client) JoinRoom(ctx context.Context, roomIDorAlias, serverName strin
// Unlike JoinRoom, this method can only be used to join rooms that the server already knows about. // Unlike JoinRoom, this method can only be used to join rooms that the server already knows about.
// It's mostly intended for bridges and other things where it's already certain that the server is in the room. // It's mostly intended for bridges and other things where it's already certain that the server is in the room.
func (cli *Client) JoinRoomByID(ctx context.Context, roomID id.RoomID) (resp *RespJoinRoom, err error) { func (cli *Client) JoinRoomByID(ctx context.Context, roomID id.RoomID) (resp *RespJoinRoom, err error) {
_, err = cli.MakeRequest(ctx, "POST", cli.BuildClientURL("v3", "rooms", roomID, "join"), nil, &resp) _, err = cli.MakeRequest(ctx, http.MethodPost, cli.BuildClientURL("v3", "rooms", roomID, "join"), nil, &resp)
if err == nil && cli.StateStore != nil { if err == nil && cli.StateStore != nil {
err = cli.StateStore.SetMembership(ctx, resp.RoomID, cli.UserID, event.MembershipJoin) err = cli.StateStore.SetMembership(ctx, resp.RoomID, cli.UserID, event.MembershipJoin)
if err != nil { if err != nil {
@@ -874,14 +876,14 @@ func (cli *Client) JoinRoomByID(ctx context.Context, roomID id.RoomID) (resp *Re
func (cli *Client) GetProfile(ctx context.Context, mxid id.UserID) (resp *RespUserProfile, err error) { func (cli *Client) GetProfile(ctx context.Context, mxid id.UserID) (resp *RespUserProfile, err error) {
urlPath := cli.BuildClientURL("v3", "profile", mxid) urlPath := cli.BuildClientURL("v3", "profile", mxid)
_, err = cli.MakeRequest(ctx, "GET", urlPath, nil, &resp) _, err = cli.MakeRequest(ctx, http.MethodGet, urlPath, nil, &resp)
return return
} }
// GetDisplayName returns the display name of the user with the specified MXID. See https://spec.matrix.org/v1.2/client-server-api/#get_matrixclientv3profileuseriddisplayname // GetDisplayName returns the display name of the user with the specified MXID. See https://spec.matrix.org/v1.2/client-server-api/#get_matrixclientv3profileuseriddisplayname
func (cli *Client) GetDisplayName(ctx context.Context, mxid id.UserID) (resp *RespUserDisplayName, err error) { func (cli *Client) GetDisplayName(ctx context.Context, mxid id.UserID) (resp *RespUserDisplayName, err error) {
urlPath := cli.BuildClientURL("v3", "profile", mxid, "displayname") urlPath := cli.BuildClientURL("v3", "profile", mxid, "displayname")
_, err = cli.MakeRequest(ctx, "GET", urlPath, nil, &resp) _, err = cli.MakeRequest(ctx, http.MethodGet, urlPath, nil, &resp)
return return
} }
@@ -896,7 +898,7 @@ func (cli *Client) SetDisplayName(ctx context.Context, displayName string) (err
s := struct { s := struct {
DisplayName string `json:"displayname"` DisplayName string `json:"displayname"`
}{displayName} }{displayName}
_, err = cli.MakeRequest(ctx, "PUT", urlPath, &s, nil) _, err = cli.MakeRequest(ctx, http.MethodPut, urlPath, &s, nil)
return return
} }
@@ -907,7 +909,7 @@ func (cli *Client) GetAvatarURL(ctx context.Context, mxid id.UserID) (url id.Con
AvatarURL id.ContentURI `json:"avatar_url"` AvatarURL id.ContentURI `json:"avatar_url"`
}{} }{}
_, err = cli.MakeRequest(ctx, "GET", urlPath, nil, &s) _, err = cli.MakeRequest(ctx, http.MethodGet, urlPath, nil, &s)
if err != nil { if err != nil {
return return
} }
@@ -926,7 +928,7 @@ func (cli *Client) SetAvatarURL(ctx context.Context, url id.ContentURI) (err err
s := struct { s := struct {
AvatarURL string `json:"avatar_url"` AvatarURL string `json:"avatar_url"`
}{url.String()} }{url.String()}
_, err = cli.MakeRequest(ctx, "PUT", urlPath, &s, nil) _, err = cli.MakeRequest(ctx, http.MethodPut, urlPath, &s, nil)
if err != nil { if err != nil {
return err return err
} }
@@ -937,21 +939,21 @@ func (cli *Client) SetAvatarURL(ctx context.Context, url id.ContentURI) (err err
// BeeperUpdateProfile sets custom fields in the user's profile. // BeeperUpdateProfile sets custom fields in the user's profile.
func (cli *Client) BeeperUpdateProfile(ctx context.Context, data map[string]any) (err error) { func (cli *Client) BeeperUpdateProfile(ctx context.Context, data map[string]any) (err error) {
urlPath := cli.BuildClientURL("v3", "profile", cli.UserID) urlPath := cli.BuildClientURL("v3", "profile", cli.UserID)
_, err = cli.MakeRequest(ctx, "PATCH", urlPath, &data, nil) _, err = cli.MakeRequest(ctx, http.MethodPatch, urlPath, &data, nil)
return return
} }
// GetAccountData gets the user's account data of this type. See https://spec.matrix.org/v1.2/client-server-api/#get_matrixclientv3useruseridaccount_datatype // GetAccountData gets the user's account data of this type. See https://spec.matrix.org/v1.2/client-server-api/#get_matrixclientv3useruseridaccount_datatype
func (cli *Client) GetAccountData(ctx context.Context, name string, output interface{}) (err error) { func (cli *Client) GetAccountData(ctx context.Context, name string, output interface{}) (err error) {
urlPath := cli.BuildClientURL("v3", "user", cli.UserID, "account_data", name) urlPath := cli.BuildClientURL("v3", "user", cli.UserID, "account_data", name)
_, err = cli.MakeRequest(ctx, "GET", urlPath, nil, output) _, err = cli.MakeRequest(ctx, http.MethodGet, urlPath, nil, output)
return return
} }
// SetAccountData sets the user's account data of this type. See https://spec.matrix.org/v1.2/client-server-api/#put_matrixclientv3useruseridaccount_datatype // SetAccountData sets the user's account data of this type. See https://spec.matrix.org/v1.2/client-server-api/#put_matrixclientv3useruseridaccount_datatype
func (cli *Client) SetAccountData(ctx context.Context, name string, data interface{}) (err error) { func (cli *Client) SetAccountData(ctx context.Context, name string, data interface{}) (err error) {
urlPath := cli.BuildClientURL("v3", "user", cli.UserID, "account_data", name) urlPath := cli.BuildClientURL("v3", "user", cli.UserID, "account_data", name)
_, err = cli.MakeRequest(ctx, "PUT", urlPath, data, nil) _, err = cli.MakeRequest(ctx, http.MethodPut, urlPath, data, nil)
if err != nil { if err != nil {
return err return err
} }
@@ -962,14 +964,14 @@ func (cli *Client) SetAccountData(ctx context.Context, name string, data interfa
// GetRoomAccountData gets the user's account data of this type in a specific room. See https://spec.matrix.org/v1.2/client-server-api/#put_matrixclientv3useruseridaccount_datatype // GetRoomAccountData gets the user's account data of this type in a specific room. See https://spec.matrix.org/v1.2/client-server-api/#put_matrixclientv3useruseridaccount_datatype
func (cli *Client) GetRoomAccountData(ctx context.Context, roomID id.RoomID, name string, output interface{}) (err error) { func (cli *Client) GetRoomAccountData(ctx context.Context, roomID id.RoomID, name string, output interface{}) (err error) {
urlPath := cli.BuildClientURL("v3", "user", cli.UserID, "rooms", roomID, "account_data", name) urlPath := cli.BuildClientURL("v3", "user", cli.UserID, "rooms", roomID, "account_data", name)
_, err = cli.MakeRequest(ctx, "GET", urlPath, nil, output) _, err = cli.MakeRequest(ctx, http.MethodGet, urlPath, nil, output)
return return
} }
// SetRoomAccountData sets the user's account data of this type in a specific room. See https://spec.matrix.org/v1.2/client-server-api/#put_matrixclientv3useruseridroomsroomidaccount_datatype // SetRoomAccountData sets the user's account data of this type in a specific room. See https://spec.matrix.org/v1.2/client-server-api/#put_matrixclientv3useruseridroomsroomidaccount_datatype
func (cli *Client) SetRoomAccountData(ctx context.Context, roomID id.RoomID, name string, data interface{}) (err error) { func (cli *Client) SetRoomAccountData(ctx context.Context, roomID id.RoomID, name string, data interface{}) (err error) {
urlPath := cli.BuildClientURL("v3", "user", cli.UserID, "rooms", roomID, "account_data", name) urlPath := cli.BuildClientURL("v3", "user", cli.UserID, "rooms", roomID, "account_data", name)
_, err = cli.MakeRequest(ctx, "PUT", urlPath, data, nil) _, err = cli.MakeRequest(ctx, http.MethodPut, urlPath, data, nil)
if err != nil { if err != nil {
return err return err
} }
@@ -1027,7 +1029,7 @@ func (cli *Client) SendMessageEvent(ctx context.Context, roomID id.RoomID, event
urlData := ClientURLPath{"v3", "rooms", roomID, "send", eventType.String(), txnID} urlData := ClientURLPath{"v3", "rooms", roomID, "send", eventType.String(), txnID}
urlPath := cli.BuildURLWithQuery(urlData, queryParams) urlPath := cli.BuildURLWithQuery(urlData, queryParams)
_, err = cli.MakeRequest(ctx, "PUT", urlPath, contentJSON, &resp) _, err = cli.MakeRequest(ctx, http.MethodPut, urlPath, contentJSON, &resp)
return return
} }
@@ -1035,7 +1037,7 @@ func (cli *Client) SendMessageEvent(ctx context.Context, roomID id.RoomID, event
// contentJSON should be a pointer to something that can be encoded as JSON using json.Marshal. // contentJSON should be a pointer to something that can be encoded as JSON using json.Marshal.
func (cli *Client) SendStateEvent(ctx context.Context, roomID id.RoomID, eventType event.Type, stateKey string, contentJSON interface{}) (resp *RespSendEvent, err error) { func (cli *Client) SendStateEvent(ctx context.Context, roomID id.RoomID, eventType event.Type, stateKey string, contentJSON interface{}) (resp *RespSendEvent, err error) {
urlPath := cli.BuildClientURL("v3", "rooms", roomID, "state", eventType.String(), stateKey) urlPath := cli.BuildClientURL("v3", "rooms", roomID, "state", eventType.String(), stateKey)
_, err = cli.MakeRequest(ctx, "PUT", urlPath, contentJSON, &resp) _, err = cli.MakeRequest(ctx, http.MethodPut, urlPath, contentJSON, &resp)
if err == nil && cli.StateStore != nil { if err == nil && cli.StateStore != nil {
cli.updateStoreWithOutgoingEvent(ctx, roomID, eventType, stateKey, contentJSON) cli.updateStoreWithOutgoingEvent(ctx, roomID, eventType, stateKey, contentJSON)
} }
@@ -1048,7 +1050,7 @@ func (cli *Client) SendMassagedStateEvent(ctx context.Context, roomID id.RoomID,
urlPath := cli.BuildURLWithQuery(ClientURLPath{"v3", "rooms", roomID, "state", eventType.String(), stateKey}, map[string]string{ urlPath := cli.BuildURLWithQuery(ClientURLPath{"v3", "rooms", roomID, "state", eventType.String(), stateKey}, map[string]string{
"ts": strconv.FormatInt(ts, 10), "ts": strconv.FormatInt(ts, 10),
}) })
_, err = cli.MakeRequest(ctx, "PUT", urlPath, contentJSON, &resp) _, err = cli.MakeRequest(ctx, http.MethodPut, urlPath, contentJSON, &resp)
if err == nil && cli.StateStore != nil { if err == nil && cli.StateStore != nil {
cli.updateStoreWithOutgoingEvent(ctx, roomID, eventType, stateKey, contentJSON) cli.updateStoreWithOutgoingEvent(ctx, roomID, eventType, stateKey, contentJSON)
} }
@@ -1102,7 +1104,7 @@ func (cli *Client) RedactEvent(ctx context.Context, roomID id.RoomID, eventID id
txnID = cli.TxnID() txnID = cli.TxnID()
} }
urlPath := cli.BuildClientURL("v3", "rooms", roomID, "redact", eventID, txnID) urlPath := cli.BuildClientURL("v3", "rooms", roomID, "redact", eventID, txnID)
_, err = cli.MakeRequest(ctx, "PUT", urlPath, req.Extra, &resp) _, err = cli.MakeRequest(ctx, http.MethodPut, urlPath, req.Extra, &resp)
return return
} }
@@ -1114,7 +1116,7 @@ func (cli *Client) RedactEvent(ctx context.Context, roomID id.RoomID, eventID id
// fmt.Println("Room:", resp.RoomID) // fmt.Println("Room:", resp.RoomID)
func (cli *Client) CreateRoom(ctx context.Context, req *ReqCreateRoom) (resp *RespCreateRoom, err error) { func (cli *Client) CreateRoom(ctx context.Context, req *ReqCreateRoom) (resp *RespCreateRoom, err error) {
urlPath := cli.BuildClientURL("v3", "createRoom") urlPath := cli.BuildClientURL("v3", "createRoom")
_, err = cli.MakeRequest(ctx, "POST", urlPath, req, &resp) _, err = cli.MakeRequest(ctx, http.MethodPost, urlPath, req, &resp)
if err == nil && cli.StateStore != nil { if err == nil && cli.StateStore != nil {
storeErr := cli.StateStore.SetMembership(ctx, resp.RoomID, cli.UserID, event.MembershipJoin) storeErr := cli.StateStore.SetMembership(ctx, resp.RoomID, cli.UserID, event.MembershipJoin)
if storeErr != nil { if storeErr != nil {
@@ -1153,7 +1155,7 @@ func (cli *Client) LeaveRoom(ctx context.Context, roomID id.RoomID, optionalReq
panic("invalid number of arguments to LeaveRoom") panic("invalid number of arguments to LeaveRoom")
} }
u := cli.BuildClientURL("v3", "rooms", roomID, "leave") u := cli.BuildClientURL("v3", "rooms", roomID, "leave")
_, err = cli.MakeRequest(ctx, "POST", u, req, &resp) _, err = cli.MakeRequest(ctx, http.MethodPost, u, req, &resp)
if err == nil && cli.StateStore != nil { if err == nil && cli.StateStore != nil {
err = cli.StateStore.SetMembership(ctx, roomID, cli.UserID, event.MembershipLeave) err = cli.StateStore.SetMembership(ctx, roomID, cli.UserID, event.MembershipLeave)
if err != nil { if err != nil {
@@ -1166,14 +1168,14 @@ func (cli *Client) LeaveRoom(ctx context.Context, roomID id.RoomID, optionalReq
// ForgetRoom forgets a room entirely. See https://spec.matrix.org/v1.2/client-server-api/#post_matrixclientv3roomsroomidforget // ForgetRoom forgets a room entirely. See https://spec.matrix.org/v1.2/client-server-api/#post_matrixclientv3roomsroomidforget
func (cli *Client) ForgetRoom(ctx context.Context, roomID id.RoomID) (resp *RespForgetRoom, err error) { func (cli *Client) ForgetRoom(ctx context.Context, roomID id.RoomID) (resp *RespForgetRoom, err error) {
u := cli.BuildClientURL("v3", "rooms", roomID, "forget") u := cli.BuildClientURL("v3", "rooms", roomID, "forget")
_, err = cli.MakeRequest(ctx, "POST", u, struct{}{}, &resp) _, err = cli.MakeRequest(ctx, http.MethodPost, u, struct{}{}, &resp)
return return
} }
// InviteUser invites a user to a room. See https://spec.matrix.org/v1.2/client-server-api/#post_matrixclientv3roomsroomidinvite // InviteUser invites a user to a room. See https://spec.matrix.org/v1.2/client-server-api/#post_matrixclientv3roomsroomidinvite
func (cli *Client) InviteUser(ctx context.Context, roomID id.RoomID, req *ReqInviteUser) (resp *RespInviteUser, err error) { func (cli *Client) InviteUser(ctx context.Context, roomID id.RoomID, req *ReqInviteUser) (resp *RespInviteUser, err error) {
u := cli.BuildClientURL("v3", "rooms", roomID, "invite") u := cli.BuildClientURL("v3", "rooms", roomID, "invite")
_, err = cli.MakeRequest(ctx, "POST", u, req, &resp) _, err = cli.MakeRequest(ctx, http.MethodPost, u, req, &resp)
if err == nil && cli.StateStore != nil { if err == nil && cli.StateStore != nil {
err = cli.StateStore.SetMembership(ctx, roomID, req.UserID, event.MembershipInvite) err = cli.StateStore.SetMembership(ctx, roomID, req.UserID, event.MembershipInvite)
if err != nil { if err != nil {
@@ -1186,14 +1188,14 @@ func (cli *Client) InviteUser(ctx context.Context, roomID id.RoomID, req *ReqInv
// InviteUserByThirdParty invites a third-party identifier to a room. See https://spec.matrix.org/v1.2/client-server-api/#post_matrixclientv3roomsroomidinvite-1 // InviteUserByThirdParty invites a third-party identifier to a room. See https://spec.matrix.org/v1.2/client-server-api/#post_matrixclientv3roomsroomidinvite-1
func (cli *Client) InviteUserByThirdParty(ctx context.Context, roomID id.RoomID, req *ReqInvite3PID) (resp *RespInviteUser, err error) { func (cli *Client) InviteUserByThirdParty(ctx context.Context, roomID id.RoomID, req *ReqInvite3PID) (resp *RespInviteUser, err error) {
u := cli.BuildClientURL("v3", "rooms", roomID, "invite") u := cli.BuildClientURL("v3", "rooms", roomID, "invite")
_, err = cli.MakeRequest(ctx, "POST", u, req, &resp) _, err = cli.MakeRequest(ctx, http.MethodPost, u, req, &resp)
return return
} }
// KickUser kicks a user from a room. See https://spec.matrix.org/v1.2/client-server-api/#post_matrixclientv3roomsroomidkick // KickUser kicks a user from a room. See https://spec.matrix.org/v1.2/client-server-api/#post_matrixclientv3roomsroomidkick
func (cli *Client) KickUser(ctx context.Context, roomID id.RoomID, req *ReqKickUser) (resp *RespKickUser, err error) { func (cli *Client) KickUser(ctx context.Context, roomID id.RoomID, req *ReqKickUser) (resp *RespKickUser, err error) {
u := cli.BuildClientURL("v3", "rooms", roomID, "kick") u := cli.BuildClientURL("v3", "rooms", roomID, "kick")
_, err = cli.MakeRequest(ctx, "POST", u, req, &resp) _, err = cli.MakeRequest(ctx, http.MethodPost, u, req, &resp)
if err == nil && cli.StateStore != nil { if err == nil && cli.StateStore != nil {
err = cli.StateStore.SetMembership(ctx, roomID, req.UserID, event.MembershipLeave) err = cli.StateStore.SetMembership(ctx, roomID, req.UserID, event.MembershipLeave)
if err != nil { if err != nil {
@@ -1206,7 +1208,7 @@ func (cli *Client) KickUser(ctx context.Context, roomID id.RoomID, req *ReqKickU
// BanUser bans a user from a room. See https://spec.matrix.org/v1.2/client-server-api/#post_matrixclientv3roomsroomidban // BanUser bans a user from a room. See https://spec.matrix.org/v1.2/client-server-api/#post_matrixclientv3roomsroomidban
func (cli *Client) BanUser(ctx context.Context, roomID id.RoomID, req *ReqBanUser) (resp *RespBanUser, err error) { func (cli *Client) BanUser(ctx context.Context, roomID id.RoomID, req *ReqBanUser) (resp *RespBanUser, err error) {
u := cli.BuildClientURL("v3", "rooms", roomID, "ban") u := cli.BuildClientURL("v3", "rooms", roomID, "ban")
_, err = cli.MakeRequest(ctx, "POST", u, req, &resp) _, err = cli.MakeRequest(ctx, http.MethodPost, u, req, &resp)
if err == nil && cli.StateStore != nil { if err == nil && cli.StateStore != nil {
err = cli.StateStore.SetMembership(ctx, roomID, req.UserID, event.MembershipBan) err = cli.StateStore.SetMembership(ctx, roomID, req.UserID, event.MembershipBan)
if err != nil { if err != nil {
@@ -1219,7 +1221,7 @@ func (cli *Client) BanUser(ctx context.Context, roomID id.RoomID, req *ReqBanUse
// UnbanUser unbans a user from a room. See https://spec.matrix.org/v1.2/client-server-api/#post_matrixclientv3roomsroomidunban // UnbanUser unbans a user from a room. See https://spec.matrix.org/v1.2/client-server-api/#post_matrixclientv3roomsroomidunban
func (cli *Client) UnbanUser(ctx context.Context, roomID id.RoomID, req *ReqUnbanUser) (resp *RespUnbanUser, err error) { func (cli *Client) UnbanUser(ctx context.Context, roomID id.RoomID, req *ReqUnbanUser) (resp *RespUnbanUser, err error) {
u := cli.BuildClientURL("v3", "rooms", roomID, "unban") u := cli.BuildClientURL("v3", "rooms", roomID, "unban")
_, err = cli.MakeRequest(ctx, "POST", u, req, &resp) _, err = cli.MakeRequest(ctx, http.MethodPost, u, req, &resp)
if err == nil && cli.StateStore != nil { if err == nil && cli.StateStore != nil {
err = cli.StateStore.SetMembership(ctx, roomID, req.UserID, event.MembershipLeave) err = cli.StateStore.SetMembership(ctx, roomID, req.UserID, event.MembershipLeave)
if err != nil { if err != nil {
@@ -1233,7 +1235,7 @@ func (cli *Client) UnbanUser(ctx context.Context, roomID id.RoomID, req *ReqUnba
func (cli *Client) UserTyping(ctx context.Context, roomID id.RoomID, typing bool, timeout time.Duration) (resp *RespTyping, err error) { func (cli *Client) UserTyping(ctx context.Context, roomID id.RoomID, typing bool, timeout time.Duration) (resp *RespTyping, err error) {
req := ReqTyping{Typing: typing, Timeout: timeout.Milliseconds()} req := ReqTyping{Typing: typing, Timeout: timeout.Milliseconds()}
u := cli.BuildClientURL("v3", "rooms", roomID, "typing", cli.UserID) u := cli.BuildClientURL("v3", "rooms", roomID, "typing", cli.UserID)
_, err = cli.MakeRequest(ctx, "PUT", u, req, &resp) _, err = cli.MakeRequest(ctx, http.MethodPut, u, req, &resp)
return return
} }
@@ -1241,7 +1243,7 @@ func (cli *Client) UserTyping(ctx context.Context, roomID id.RoomID, typing bool
func (cli *Client) GetPresence(ctx context.Context, userID id.UserID) (resp *RespPresence, err error) { func (cli *Client) GetPresence(ctx context.Context, userID id.UserID) (resp *RespPresence, err error) {
resp = new(RespPresence) resp = new(RespPresence)
u := cli.BuildClientURL("v3", "presence", userID, "status") u := cli.BuildClientURL("v3", "presence", userID, "status")
_, err = cli.MakeRequest(ctx, "GET", u, nil, resp) _, err = cli.MakeRequest(ctx, http.MethodGet, u, nil, resp)
return return
} }
@@ -1253,7 +1255,7 @@ func (cli *Client) GetOwnPresence(ctx context.Context) (resp *RespPresence, err
func (cli *Client) SetPresence(ctx context.Context, status event.Presence) (err error) { func (cli *Client) SetPresence(ctx context.Context, status event.Presence) (err error) {
req := ReqPresence{Presence: status} req := ReqPresence{Presence: status}
u := cli.BuildClientURL("v3", "presence", cli.UserID, "status") u := cli.BuildClientURL("v3", "presence", cli.UserID, "status")
_, err = cli.MakeRequest(ctx, "PUT", u, req, nil) _, err = cli.MakeRequest(ctx, http.MethodPut, u, req, nil)
return return
} }
@@ -1295,7 +1297,7 @@ func (cli *Client) updateStoreWithOutgoingEvent(ctx context.Context, roomID id.R
// See https://spec.matrix.org/v1.2/client-server-api/#get_matrixclientv3roomsroomidstateeventtypestatekey // See https://spec.matrix.org/v1.2/client-server-api/#get_matrixclientv3roomsroomidstateeventtypestatekey
func (cli *Client) StateEvent(ctx context.Context, roomID id.RoomID, eventType event.Type, stateKey string, outContent interface{}) (err error) { func (cli *Client) StateEvent(ctx context.Context, roomID id.RoomID, eventType event.Type, stateKey string, outContent interface{}) (err error) {
u := cli.BuildClientURL("v3", "rooms", roomID, "state", eventType.String(), stateKey) u := cli.BuildClientURL("v3", "rooms", roomID, "state", eventType.String(), stateKey)
_, err = cli.MakeRequest(ctx, "GET", u, nil, outContent) _, err = cli.MakeRequest(ctx, http.MethodGet, u, nil, outContent)
if err == nil && cli.StateStore != nil { if err == nil && cli.StateStore != nil {
cli.updateStoreWithOutgoingEvent(ctx, roomID, eventType, stateKey, outContent) cli.updateStoreWithOutgoingEvent(ctx, roomID, eventType, stateKey, outContent)
} }
@@ -1367,13 +1369,13 @@ func (cli *Client) State(ctx context.Context, roomID id.RoomID) (stateMap RoomSt
// GetMediaConfig fetches the configuration of the content repository, such as upload limitations. // GetMediaConfig fetches the configuration of the content repository, such as upload limitations.
func (cli *Client) GetMediaConfig(ctx context.Context) (resp *RespMediaConfig, err error) { func (cli *Client) GetMediaConfig(ctx context.Context) (resp *RespMediaConfig, err error) {
u := cli.BuildURL(MediaURLPath{"v3", "config"}) u := cli.BuildURL(MediaURLPath{"v3", "config"})
_, err = cli.MakeRequest(ctx, "GET", u, nil, &resp) _, err = cli.MakeRequest(ctx, http.MethodGet, u, nil, &resp)
return return
} }
// UploadLink uploads an HTTP URL and then returns an MXC URI. // UploadLink uploads an HTTP URL and then returns an MXC URI.
func (cli *Client) UploadLink(ctx context.Context, link string) (*RespMediaUpload, error) { func (cli *Client) UploadLink(ctx context.Context, link string) (*RespMediaUpload, error) {
req, err := http.NewRequestWithContext(ctx, "GET", link, nil) req, err := http.NewRequestWithContext(ctx, http.MethodGet, link, nil)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -1672,12 +1674,14 @@ func (cli *Client) GetURLPreview(ctx context.Context, url string) (*RespPreviewU
// This API is primarily designed for application services which may want to efficiently look up joined members in a room. // This API is primarily designed for application services which may want to efficiently look up joined members in a room.
func (cli *Client) JoinedMembers(ctx context.Context, roomID id.RoomID) (resp *RespJoinedMembers, err error) { func (cli *Client) JoinedMembers(ctx context.Context, roomID id.RoomID) (resp *RespJoinedMembers, err error) {
u := cli.BuildClientURL("v3", "rooms", roomID, "joined_members") u := cli.BuildClientURL("v3", "rooms", roomID, "joined_members")
_, err = cli.MakeRequest(ctx, "GET", u, nil, &resp) _, err = cli.MakeRequest(ctx, http.MethodGet, u, nil, &resp)
if err == nil && cli.StateStore != nil { if err == nil && cli.StateStore != nil {
clearErr := cli.StateStore.ClearCachedMembers(ctx, roomID, event.MembershipJoin) clearErr := cli.StateStore.ClearCachedMembers(ctx, roomID, event.MembershipJoin)
if clearErr != nil {
cli.cliOrContextLog(ctx).Warn().Err(clearErr). cli.cliOrContextLog(ctx).Warn().Err(clearErr).
Stringer("room_id", roomID). Stringer("room_id", roomID).
Msg("Failed to clear cached member list after fetching joined members") Msg("Failed to clear cached member list after fetching joined members")
}
for userID, member := range resp.Joined { for userID, member := range resp.Joined {
updateErr := cli.StateStore.SetMember(ctx, roomID, userID, &event.MemberEventContent{ updateErr := cli.StateStore.SetMember(ctx, roomID, userID, &event.MemberEventContent{
Membership: event.MembershipJoin, Membership: event.MembershipJoin,
@@ -1685,7 +1689,7 @@ func (cli *Client) JoinedMembers(ctx context.Context, roomID id.RoomID) (resp *R
Displayname: member.DisplayName, Displayname: member.DisplayName,
}) })
if updateErr != nil { if updateErr != nil {
cli.cliOrContextLog(ctx).Warn().Err(clearErr). cli.cliOrContextLog(ctx).Warn().Err(updateErr).
Stringer("room_id", roomID). Stringer("room_id", roomID).
Stringer("user_id", userID). Stringer("user_id", userID).
Msg("Failed to update membership in state store after fetching joined members") Msg("Failed to update membership in state store after fetching joined members")
@@ -1711,7 +1715,7 @@ func (cli *Client) Members(ctx context.Context, roomID id.RoomID, req ...ReqMemb
query["not_membership"] = string(extra.NotMembership) query["not_membership"] = string(extra.NotMembership)
} }
u := cli.BuildURLWithQuery(ClientURLPath{"v3", "rooms", roomID, "members"}, query) u := cli.BuildURLWithQuery(ClientURLPath{"v3", "rooms", roomID, "members"}, query)
_, err = cli.MakeRequest(ctx, "GET", u, nil, &resp) _, err = cli.MakeRequest(ctx, http.MethodGet, u, nil, &resp)
if err == nil && cli.StateStore != nil { if err == nil && cli.StateStore != nil {
var clearMemberships []event.Membership var clearMemberships []event.Membership
if extra.Membership != "" { if extra.Membership != "" {
@@ -1719,10 +1723,12 @@ func (cli *Client) Members(ctx context.Context, roomID id.RoomID, req ...ReqMemb
} }
if extra.NotMembership == "" { if extra.NotMembership == "" {
clearErr := cli.StateStore.ClearCachedMembers(ctx, roomID, clearMemberships...) clearErr := cli.StateStore.ClearCachedMembers(ctx, roomID, clearMemberships...)
if clearErr != nil {
cli.cliOrContextLog(ctx).Warn().Err(clearErr). cli.cliOrContextLog(ctx).Warn().Err(clearErr).
Stringer("room_id", roomID). Stringer("room_id", roomID).
Msg("Failed to clear cached member list after fetching joined members") Msg("Failed to clear cached member list after fetching joined members")
} }
}
for _, evt := range resp.Chunk { for _, evt := range resp.Chunk {
UpdateStateStore(ctx, cli.StateStore, evt) UpdateStateStore(ctx, cli.StateStore, evt)
} }
@@ -1736,7 +1742,7 @@ func (cli *Client) Members(ctx context.Context, roomID id.RoomID, req ...ReqMemb
// This API is primarily designed for application services which may want to efficiently look up joined rooms. // This API is primarily designed for application services which may want to efficiently look up joined rooms.
func (cli *Client) JoinedRooms(ctx context.Context) (resp *RespJoinedRooms, err error) { func (cli *Client) JoinedRooms(ctx context.Context) (resp *RespJoinedRooms, err error) {
u := cli.BuildClientURL("v3", "joined_rooms") u := cli.BuildClientURL("v3", "joined_rooms")
_, err = cli.MakeRequest(ctx, "GET", u, nil, &resp) _, err = cli.MakeRequest(ctx, http.MethodGet, u, nil, &resp)
return return
} }
@@ -1775,7 +1781,7 @@ func (cli *Client) Messages(ctx context.Context, roomID id.RoomID, from, to stri
} }
urlPath := cli.BuildURLWithQuery(ClientURLPath{"v3", "rooms", roomID, "messages"}, query) urlPath := cli.BuildURLWithQuery(ClientURLPath{"v3", "rooms", roomID, "messages"}, query)
_, err = cli.MakeRequest(ctx, "GET", urlPath, nil, &resp) _, err = cli.MakeRequest(ctx, http.MethodGet, urlPath, nil, &resp)
return return
} }
@@ -1810,13 +1816,13 @@ func (cli *Client) Context(ctx context.Context, roomID id.RoomID, eventID id.Eve
} }
urlPath := cli.BuildURLWithQuery(ClientURLPath{"v3", "rooms", roomID, "context", eventID}, query) urlPath := cli.BuildURLWithQuery(ClientURLPath{"v3", "rooms", roomID, "context", eventID}, query)
_, err = cli.MakeRequest(ctx, "GET", urlPath, nil, &resp) _, err = cli.MakeRequest(ctx, http.MethodGet, urlPath, nil, &resp)
return return
} }
func (cli *Client) GetEvent(ctx context.Context, roomID id.RoomID, eventID id.EventID) (resp *event.Event, err error) { func (cli *Client) GetEvent(ctx context.Context, roomID id.RoomID, eventID id.EventID) (resp *event.Event, err error) {
urlPath := cli.BuildClientURL("v3", "rooms", roomID, "event", eventID) urlPath := cli.BuildClientURL("v3", "rooms", roomID, "event", eventID)
_, err = cli.MakeRequest(ctx, "GET", urlPath, nil, &resp) _, err = cli.MakeRequest(ctx, http.MethodGet, urlPath, nil, &resp)
return return
} }
@@ -1837,13 +1843,13 @@ func (cli *Client) MarkReadWithContent(ctx context.Context, roomID id.RoomID, ev
// To mark a message in a specific thread as read, use pass a ReqSendReceipt as the content. // To mark a message in a specific thread as read, use pass a ReqSendReceipt as the content.
func (cli *Client) SendReceipt(ctx context.Context, roomID id.RoomID, eventID id.EventID, receiptType event.ReceiptType, content interface{}) (err error) { func (cli *Client) SendReceipt(ctx context.Context, roomID id.RoomID, eventID id.EventID, receiptType event.ReceiptType, content interface{}) (err error) {
urlPath := cli.BuildClientURL("v3", "rooms", roomID, "receipt", receiptType, eventID) urlPath := cli.BuildClientURL("v3", "rooms", roomID, "receipt", receiptType, eventID)
_, err = cli.MakeRequest(ctx, "POST", urlPath, content, nil) _, err = cli.MakeRequest(ctx, http.MethodPost, urlPath, content, nil)
return return
} }
func (cli *Client) SetReadMarkers(ctx context.Context, roomID id.RoomID, content interface{}) (err error) { func (cli *Client) SetReadMarkers(ctx context.Context, roomID id.RoomID, content interface{}) (err error) {
urlPath := cli.BuildClientURL("v3", "rooms", roomID, "read_markers") urlPath := cli.BuildClientURL("v3", "rooms", roomID, "read_markers")
_, err = cli.MakeRequest(ctx, "POST", urlPath, content, nil) _, err = cli.MakeRequest(ctx, http.MethodPost, urlPath, content, nil)
return return
} }
@@ -1857,7 +1863,7 @@ func (cli *Client) AddTag(ctx context.Context, roomID id.RoomID, tag string, ord
func (cli *Client) AddTagWithCustomData(ctx context.Context, roomID id.RoomID, tag string, data interface{}) (err error) { func (cli *Client) AddTagWithCustomData(ctx context.Context, roomID id.RoomID, tag string, data interface{}) (err error) {
urlPath := cli.BuildClientURL("v3", "user", cli.UserID, "rooms", roomID, "tags", tag) urlPath := cli.BuildClientURL("v3", "user", cli.UserID, "rooms", roomID, "tags", tag)
_, err = cli.MakeRequest(ctx, "PUT", urlPath, data, nil) _, err = cli.MakeRequest(ctx, http.MethodPut, urlPath, data, nil)
return return
} }
@@ -1868,13 +1874,13 @@ func (cli *Client) GetTags(ctx context.Context, roomID id.RoomID) (tags event.Ta
func (cli *Client) GetTagsWithCustomData(ctx context.Context, roomID id.RoomID, resp interface{}) (err error) { func (cli *Client) GetTagsWithCustomData(ctx context.Context, roomID id.RoomID, resp interface{}) (err error) {
urlPath := cli.BuildClientURL("v3", "user", cli.UserID, "rooms", roomID, "tags") urlPath := cli.BuildClientURL("v3", "user", cli.UserID, "rooms", roomID, "tags")
_, err = cli.MakeRequest(ctx, "GET", urlPath, nil, &resp) _, err = cli.MakeRequest(ctx, http.MethodGet, urlPath, nil, &resp)
return return
} }
func (cli *Client) RemoveTag(ctx context.Context, roomID id.RoomID, tag string) (err error) { func (cli *Client) RemoveTag(ctx context.Context, roomID id.RoomID, tag string) (err error) {
urlPath := cli.BuildClientURL("v3", "user", cli.UserID, "rooms", roomID, "tags", tag) urlPath := cli.BuildClientURL("v3", "user", cli.UserID, "rooms", roomID, "tags", tag)
_, err = cli.MakeRequest(ctx, "DELETE", urlPath, nil, nil) _, err = cli.MakeRequest(ctx, http.MethodDelete, urlPath, nil, nil)
return return
} }
@@ -1889,49 +1895,49 @@ func (cli *Client) SetTags(ctx context.Context, roomID id.RoomID, tags event.Tag
// See https://spec.matrix.org/v1.2/client-server-api/#get_matrixclientv3voipturnserver // See https://spec.matrix.org/v1.2/client-server-api/#get_matrixclientv3voipturnserver
func (cli *Client) TurnServer(ctx context.Context) (resp *RespTurnServer, err error) { func (cli *Client) TurnServer(ctx context.Context) (resp *RespTurnServer, err error) {
urlPath := cli.BuildClientURL("v3", "voip", "turnServer") urlPath := cli.BuildClientURL("v3", "voip", "turnServer")
_, err = cli.MakeRequest(ctx, "GET", urlPath, nil, &resp) _, err = cli.MakeRequest(ctx, http.MethodGet, urlPath, nil, &resp)
return return
} }
func (cli *Client) CreateAlias(ctx context.Context, alias id.RoomAlias, roomID id.RoomID) (resp *RespAliasCreate, err error) { func (cli *Client) CreateAlias(ctx context.Context, alias id.RoomAlias, roomID id.RoomID) (resp *RespAliasCreate, err error) {
urlPath := cli.BuildClientURL("v3", "directory", "room", alias) urlPath := cli.BuildClientURL("v3", "directory", "room", alias)
_, err = cli.MakeRequest(ctx, "PUT", urlPath, &ReqAliasCreate{RoomID: roomID}, &resp) _, err = cli.MakeRequest(ctx, http.MethodPut, urlPath, &ReqAliasCreate{RoomID: roomID}, &resp)
return return
} }
func (cli *Client) ResolveAlias(ctx context.Context, alias id.RoomAlias) (resp *RespAliasResolve, err error) { func (cli *Client) ResolveAlias(ctx context.Context, alias id.RoomAlias) (resp *RespAliasResolve, err error) {
urlPath := cli.BuildClientURL("v3", "directory", "room", alias) urlPath := cli.BuildClientURL("v3", "directory", "room", alias)
_, err = cli.MakeRequest(ctx, "GET", urlPath, nil, &resp) _, err = cli.MakeRequest(ctx, http.MethodGet, urlPath, nil, &resp)
return return
} }
func (cli *Client) DeleteAlias(ctx context.Context, alias id.RoomAlias) (resp *RespAliasDelete, err error) { func (cli *Client) DeleteAlias(ctx context.Context, alias id.RoomAlias) (resp *RespAliasDelete, err error) {
urlPath := cli.BuildClientURL("v3", "directory", "room", alias) urlPath := cli.BuildClientURL("v3", "directory", "room", alias)
_, err = cli.MakeRequest(ctx, "DELETE", urlPath, nil, &resp) _, err = cli.MakeRequest(ctx, http.MethodDelete, urlPath, nil, &resp)
return return
} }
func (cli *Client) GetAliases(ctx context.Context, roomID id.RoomID) (resp *RespAliasList, err error) { func (cli *Client) GetAliases(ctx context.Context, roomID id.RoomID) (resp *RespAliasList, err error) {
urlPath := cli.BuildClientURL("v3", "rooms", roomID, "aliases") urlPath := cli.BuildClientURL("v3", "rooms", roomID, "aliases")
_, err = cli.MakeRequest(ctx, "GET", urlPath, nil, &resp) _, err = cli.MakeRequest(ctx, http.MethodGet, urlPath, nil, &resp)
return return
} }
func (cli *Client) UploadKeys(ctx context.Context, req *ReqUploadKeys) (resp *RespUploadKeys, err error) { func (cli *Client) UploadKeys(ctx context.Context, req *ReqUploadKeys) (resp *RespUploadKeys, err error) {
urlPath := cli.BuildClientURL("v3", "keys", "upload") urlPath := cli.BuildClientURL("v3", "keys", "upload")
_, err = cli.MakeRequest(ctx, "POST", urlPath, req, &resp) _, err = cli.MakeRequest(ctx, http.MethodPost, urlPath, req, &resp)
return return
} }
func (cli *Client) QueryKeys(ctx context.Context, req *ReqQueryKeys) (resp *RespQueryKeys, err error) { func (cli *Client) QueryKeys(ctx context.Context, req *ReqQueryKeys) (resp *RespQueryKeys, err error) {
urlPath := cli.BuildClientURL("v3", "keys", "query") urlPath := cli.BuildClientURL("v3", "keys", "query")
_, err = cli.MakeRequest(ctx, "POST", urlPath, req, &resp) _, err = cli.MakeRequest(ctx, http.MethodPost, urlPath, req, &resp)
return return
} }
func (cli *Client) ClaimKeys(ctx context.Context, req *ReqClaimKeys) (resp *RespClaimKeys, err error) { func (cli *Client) ClaimKeys(ctx context.Context, req *ReqClaimKeys) (resp *RespClaimKeys, err error) {
urlPath := cli.BuildClientURL("v3", "keys", "claim") urlPath := cli.BuildClientURL("v3", "keys", "claim")
_, err = cli.MakeRequest(ctx, "POST", urlPath, req, &resp) _, err = cli.MakeRequest(ctx, http.MethodPost, urlPath, req, &resp)
return return
} }
@@ -1940,43 +1946,195 @@ func (cli *Client) GetKeyChanges(ctx context.Context, from, to string) (resp *Re
"from": from, "from": from,
"to": to, "to": to,
}) })
_, err = cli.MakeRequest(ctx, "POST", urlPath, nil, &resp) _, err = cli.MakeRequest(ctx, http.MethodPost, urlPath, nil, &resp)
return return
} }
// GetKeyBackup retrieves the keys from the backup.
//
// See: https://spec.matrix.org/v1.9/client-server-api/#get_matrixclientv3room_keyskeys
func (cli *Client) GetKeyBackup(ctx context.Context, version id.KeyBackupVersion) (resp *RespRoomKeys[backup.EncryptedSessionData[backup.MegolmSessionData]], err error) {
urlPath := cli.BuildURLWithQuery(ClientURLPath{"v3", "room_keys", "keys"}, map[string]string{
"version": string(version),
})
_, err = cli.MakeRequest(ctx, http.MethodGet, urlPath, nil, &resp)
return
}
// PutKeysInBackup stores several keys in the backup.
//
// See: https://spec.matrix.org/v1.9/client-server-api/#put_matrixclientv3room_keyskeys
func (cli *Client) PutKeysInBackup(ctx context.Context, version id.KeyBackupVersion, req *ReqKeyBackup) (resp *RespRoomKeysUpdate, err error) {
urlPath := cli.BuildURLWithQuery(ClientURLPath{"v3", "room_keys", "keys"}, map[string]string{
"version": string(version),
})
_, err = cli.MakeRequest(ctx, http.MethodPut, urlPath, req, &resp)
return
}
// DeleteKeyBackup deletes all keys from the backup.
//
// See: https://spec.matrix.org/v1.9/client-server-api/#delete_matrixclientv3room_keyskeys
func (cli *Client) DeleteKeyBackup(ctx context.Context, version id.KeyBackupVersion) (resp *RespRoomKeysUpdate, err error) {
urlPath := cli.BuildURLWithQuery(ClientURLPath{"v3", "room_keys", "keys"}, map[string]string{
"version": string(version),
})
_, err = cli.MakeRequest(ctx, http.MethodDelete, urlPath, nil, &resp)
return
}
// GetKeyBackupForRoom retrieves the keys from the backup for the given room.
//
// See: https://spec.matrix.org/v1.9/client-server-api/#get_matrixclientv3room_keyskeysroomid
func (cli *Client) GetKeyBackupForRoom(
ctx context.Context, version id.KeyBackupVersion, roomID id.RoomID,
) (resp *RespRoomKeyBackup[backup.EncryptedSessionData[backup.MegolmSessionData]], err error) {
urlPath := cli.BuildURLWithQuery(ClientURLPath{"v3", "room_keys", "keys", roomID.String()}, map[string]string{
"version": string(version),
})
_, err = cli.MakeRequest(ctx, http.MethodGet, urlPath, nil, &resp)
return
}
// PutKeysInBackupForRoom stores several keys in the backup for the given room.
//
// See: https://spec.matrix.org/v1.9/client-server-api/#put_matrixclientv3room_keyskeysroomid
func (cli *Client) PutKeysInBackupForRoom(ctx context.Context, version id.KeyBackupVersion, roomID id.RoomID, req *ReqRoomKeyBackup) (resp *RespRoomKeysUpdate, err error) {
urlPath := cli.BuildURLWithQuery(ClientURLPath{"v3", "room_keys", "keys", roomID.String()}, map[string]string{
"version": string(version),
})
_, err = cli.MakeRequest(ctx, http.MethodPut, urlPath, req, &resp)
return
}
// DeleteKeysFromBackupForRoom deletes all the keys in the backup for the given
// room.
//
// See: https://spec.matrix.org/v1.9/client-server-api/#delete_matrixclientv3room_keyskeysroomid
func (cli *Client) DeleteKeysFromBackupForRoom(ctx context.Context, version id.KeyBackupVersion, roomID id.RoomID) (resp *RespRoomKeysUpdate, err error) {
urlPath := cli.BuildURLWithQuery(ClientURLPath{"v3", "room_keys", "keys", roomID.String()}, map[string]string{
"version": string(version),
})
_, err = cli.MakeRequest(ctx, http.MethodDelete, urlPath, nil, &resp)
return
}
// GetKeyBackupForRoomAndSession retrieves a key from the backup.
//
// See: https://spec.matrix.org/v1.9/client-server-api/#get_matrixclientv3room_keyskeysroomidsessionid
func (cli *Client) GetKeyBackupForRoomAndSession(
ctx context.Context, version id.KeyBackupVersion, roomID id.RoomID, sessionID id.SessionID,
) (resp *RespKeyBackupData[backup.EncryptedSessionData[backup.MegolmSessionData]], err error) {
urlPath := cli.BuildURLWithQuery(ClientURLPath{"v3", "room_keys", "keys", roomID.String(), sessionID.String()}, map[string]string{
"version": string(version),
})
_, err = cli.MakeRequest(ctx, http.MethodGet, urlPath, nil, &resp)
return
}
// PutKeysInBackupForRoomAndSession stores a key in the backup.
//
// See: https://spec.matrix.org/v1.9/client-server-api/#put_matrixclientv3room_keyskeysroomidsessionid
func (cli *Client) PutKeysInBackupForRoomAndSession(ctx context.Context, version id.KeyBackupVersion, roomID id.RoomID, sessionID id.SessionID, req *ReqKeyBackupData) (resp *RespRoomKeysUpdate, err error) {
urlPath := cli.BuildURLWithQuery(ClientURLPath{"v3", "room_keys", "keys", roomID.String(), sessionID.String()}, map[string]string{
"version": string(version),
})
_, err = cli.MakeRequest(ctx, http.MethodPut, urlPath, req, &resp)
return
}
// DeleteKeysInBackupForRoomAndSession deletes a key from the backup.
//
// See: https://spec.matrix.org/v1.9/client-server-api/#delete_matrixclientv3room_keyskeysroomidsessionid
func (cli *Client) DeleteKeysInBackupForRoomAndSession(ctx context.Context, version id.KeyBackupVersion, roomID id.RoomID, sessionID id.SessionID) (resp *RespRoomKeysUpdate, err error) {
urlPath := cli.BuildURLWithQuery(ClientURLPath{"v3", "room_keys", "keys", roomID.String(), sessionID.String()}, map[string]string{
"version": string(version),
})
_, err = cli.MakeRequest(ctx, http.MethodDelete, urlPath, nil, &resp)
return
}
// GetKeyBackupLatestVersion returns information about the latest backup version.
//
// See: https://spec.matrix.org/v1.9/client-server-api/#get_matrixclientv3room_keysversion
func (cli *Client) GetKeyBackupLatestVersion(ctx context.Context) (resp *RespRoomKeysVersion[backup.MegolmAuthData], err error) {
urlPath := cli.BuildClientURL("v3", "room_keys", "version")
_, err = cli.MakeRequest(ctx, http.MethodGet, urlPath, nil, &resp)
return
}
// CreateKeyBackupVersion creates a new key backup.
//
// See: https://spec.matrix.org/v1.9/client-server-api/#post_matrixclientv3room_keysversion
func (cli *Client) CreateKeyBackupVersion(ctx context.Context, req *ReqRoomKeysVersionCreate[backup.MegolmAuthData]) (resp *RespRoomKeysVersionCreate, err error) {
urlPath := cli.BuildClientURL("v3", "room_keys", "version")
_, err = cli.MakeRequest(ctx, http.MethodPost, urlPath, req, &resp)
return
}
// GetKeyBackupVersion returns information about an existing key backup.
//
// See: https://spec.matrix.org/v1.9/client-server-api/#get_matrixclientv3room_keysversionversion
func (cli *Client) GetKeyBackupVersion(ctx context.Context, version id.KeyBackupVersion) (resp *RespRoomKeysVersion[backup.MegolmAuthData], err error) {
urlPath := cli.BuildClientURL("v3", "room_keys", "version", version)
_, err = cli.MakeRequest(ctx, http.MethodGet, urlPath, nil, &resp)
return
}
// UpdateKeyBackupVersion updates information about an existing key backup. Only
// the auth_data can be modified.
//
// See: https://spec.matrix.org/v1.9/client-server-api/#put_matrixclientv3room_keysversionversion
func (cli *Client) UpdateKeyBackupVersion(ctx context.Context, version id.KeyBackupVersion, req *ReqRoomKeysVersionUpdate[backup.MegolmAuthData]) error {
urlPath := cli.BuildClientURL("v3", "room_keys", "version", version)
_, err := cli.MakeRequest(ctx, http.MethodPut, urlPath, nil, nil)
return err
}
// DeleteKeyBackupVersion deletes an existing key backup. Both the information
// about the backup, as well as all key data related to the backup will be
// deleted.
//
// See: https://spec.matrix.org/v1.1/client-server-api/#delete_matrixclientv3room_keysversionversion
func (cli *Client) DeleteKeyBackupVersion(ctx context.Context, version id.KeyBackupVersion) error {
urlPath := cli.BuildClientURL("v3", "room_keys", "version", version)
_, err := cli.MakeRequest(ctx, http.MethodDelete, urlPath, nil, nil)
return err
}
func (cli *Client) SendToDevice(ctx context.Context, eventType event.Type, req *ReqSendToDevice) (resp *RespSendToDevice, err error) { func (cli *Client) SendToDevice(ctx context.Context, eventType event.Type, req *ReqSendToDevice) (resp *RespSendToDevice, err error) {
urlPath := cli.BuildClientURL("v3", "sendToDevice", eventType.String(), cli.TxnID()) urlPath := cli.BuildClientURL("v3", "sendToDevice", eventType.String(), cli.TxnID())
_, err = cli.MakeRequest(ctx, "PUT", urlPath, req, &resp) _, err = cli.MakeRequest(ctx, http.MethodPut, urlPath, req, &resp)
return return
} }
func (cli *Client) GetDevicesInfo(ctx context.Context) (resp *RespDevicesInfo, err error) { func (cli *Client) GetDevicesInfo(ctx context.Context) (resp *RespDevicesInfo, err error) {
urlPath := cli.BuildClientURL("v3", "devices") urlPath := cli.BuildClientURL("v3", "devices")
_, err = cli.MakeRequest(ctx, "GET", urlPath, nil, &resp) _, err = cli.MakeRequest(ctx, http.MethodGet, urlPath, nil, &resp)
return return
} }
func (cli *Client) GetDeviceInfo(ctx context.Context, deviceID id.DeviceID) (resp *RespDeviceInfo, err error) { func (cli *Client) GetDeviceInfo(ctx context.Context, deviceID id.DeviceID) (resp *RespDeviceInfo, err error) {
urlPath := cli.BuildClientURL("v3", "devices", deviceID) urlPath := cli.BuildClientURL("v3", "devices", deviceID)
_, err = cli.MakeRequest(ctx, "GET", urlPath, nil, &resp) _, err = cli.MakeRequest(ctx, http.MethodGet, urlPath, nil, &resp)
return return
} }
func (cli *Client) SetDeviceInfo(ctx context.Context, deviceID id.DeviceID, req *ReqDeviceInfo) error { func (cli *Client) SetDeviceInfo(ctx context.Context, deviceID id.DeviceID, req *ReqDeviceInfo) error {
urlPath := cli.BuildClientURL("v3", "devices", deviceID) urlPath := cli.BuildClientURL("v3", "devices", deviceID)
_, err := cli.MakeRequest(ctx, "PUT", urlPath, req, nil) _, err := cli.MakeRequest(ctx, http.MethodPut, urlPath, req, nil)
return err return err
} }
func (cli *Client) DeleteDevice(ctx context.Context, deviceID id.DeviceID, req *ReqDeleteDevice) error { func (cli *Client) DeleteDevice(ctx context.Context, deviceID id.DeviceID, req *ReqDeleteDevice) error {
urlPath := cli.BuildClientURL("v3", "devices", deviceID) urlPath := cli.BuildClientURL("v3", "devices", deviceID)
_, err := cli.MakeRequest(ctx, "DELETE", urlPath, req, nil) _, err := cli.MakeRequest(ctx, http.MethodDelete, urlPath, req, nil)
return err return err
} }
func (cli *Client) DeleteDevices(ctx context.Context, req *ReqDeleteDevices) error { func (cli *Client) DeleteDevices(ctx context.Context, req *ReqDeleteDevices) error {
urlPath := cli.BuildClientURL("v3", "delete_devices") urlPath := cli.BuildClientURL("v3", "delete_devices")
_, err := cli.MakeRequest(ctx, "DELETE", urlPath, req, nil) _, err := cli.MakeRequest(ctx, http.MethodDelete, urlPath, req, nil)
return err return err
} }
@@ -1992,7 +2150,7 @@ func (cli *Client) UploadCrossSigningKeys(ctx context.Context, keys *UploadCross
RequestJSON: keys, RequestJSON: keys,
SensitiveContent: keys.Auth != nil, SensitiveContent: keys.Auth != nil,
}) })
if respErr, ok := err.(HTTPError); ok && respErr.IsStatus(http.StatusUnauthorized) { if respErr, ok := err.(HTTPError); ok && respErr.IsStatus(http.StatusUnauthorized) && uiaCallback != nil {
// try again with UI auth // try again with UI auth
var uiAuthResp RespUserInteractive var uiAuthResp RespUserInteractive
if err := json.Unmarshal(content, &uiAuthResp); err != nil { if err := json.Unmarshal(content, &uiAuthResp); err != nil {
@@ -2001,7 +2159,7 @@ func (cli *Client) UploadCrossSigningKeys(ctx context.Context, keys *UploadCross
auth := uiaCallback(&uiAuthResp) auth := uiaCallback(&uiAuthResp)
if auth != nil { if auth != nil {
keys.Auth = auth keys.Auth = auth
return cli.UploadCrossSigningKeys(ctx, keys, uiaCallback) return cli.UploadCrossSigningKeys(ctx, keys, nil)
} }
} }
return err return err
@@ -2009,7 +2167,7 @@ func (cli *Client) UploadCrossSigningKeys(ctx context.Context, keys *UploadCross
func (cli *Client) UploadSignatures(ctx context.Context, req *ReqUploadSignatures) (resp *RespUploadSignatures, err error) { func (cli *Client) UploadSignatures(ctx context.Context, req *ReqUploadSignatures) (resp *RespUploadSignatures, err error) {
urlPath := cli.BuildClientURL("v3", "keys", "signatures", "upload") urlPath := cli.BuildClientURL("v3", "keys", "signatures", "upload")
_, err = cli.MakeRequest(ctx, "POST", urlPath, req, &resp) _, err = cli.MakeRequest(ctx, http.MethodPost, urlPath, req, &resp)
return return
} }
@@ -2023,13 +2181,13 @@ func (cli *Client) GetScopedPushRules(ctx context.Context, scope string) (resp *
u, _ := url.Parse(cli.BuildClientURL("v3", "pushrules", scope)) u, _ := url.Parse(cli.BuildClientURL("v3", "pushrules", scope))
// client.BuildURL returns the URL without a trailing slash, but the pushrules endpoint requires the slash. // client.BuildURL returns the URL without a trailing slash, but the pushrules endpoint requires the slash.
u.Path += "/" u.Path += "/"
_, err = cli.MakeRequest(ctx, "GET", u.String(), nil, &resp) _, err = cli.MakeRequest(ctx, http.MethodGet, u.String(), nil, &resp)
return return
} }
func (cli *Client) GetPushRule(ctx context.Context, scope string, kind pushrules.PushRuleType, ruleID string) (resp *pushrules.PushRule, err error) { func (cli *Client) GetPushRule(ctx context.Context, scope string, kind pushrules.PushRuleType, ruleID string) (resp *pushrules.PushRule, err error) {
urlPath := cli.BuildClientURL("v3", "pushrules", scope, kind, ruleID) urlPath := cli.BuildClientURL("v3", "pushrules", scope, kind, ruleID)
_, err = cli.MakeRequest(ctx, "GET", urlPath, nil, &resp) _, err = cli.MakeRequest(ctx, http.MethodGet, urlPath, nil, &resp)
if resp != nil { if resp != nil {
resp.Type = kind resp.Type = kind
} }
@@ -2038,7 +2196,7 @@ func (cli *Client) GetPushRule(ctx context.Context, scope string, kind pushrules
func (cli *Client) DeletePushRule(ctx context.Context, scope string, kind pushrules.PushRuleType, ruleID string) error { func (cli *Client) DeletePushRule(ctx context.Context, scope string, kind pushrules.PushRuleType, ruleID string) error {
urlPath := cli.BuildClientURL("v3", "pushrules", scope, kind, ruleID) urlPath := cli.BuildClientURL("v3", "pushrules", scope, kind, ruleID)
_, err := cli.MakeRequest(ctx, "DELETE", urlPath, nil, nil) _, err := cli.MakeRequest(ctx, http.MethodDelete, urlPath, nil, nil)
return err return err
} }
@@ -2051,7 +2209,7 @@ func (cli *Client) PutPushRule(ctx context.Context, scope string, kind pushrules
query["before"] = req.Before query["before"] = req.Before
} }
urlPath := cli.BuildURLWithQuery(ClientURLPath{"v3", "pushrules", scope, kind, ruleID}, query) urlPath := cli.BuildURLWithQuery(ClientURLPath{"v3", "pushrules", scope, kind, ruleID}, query)
_, err := cli.MakeRequest(ctx, "PUT", urlPath, req, nil) _, err := cli.MakeRequest(ctx, http.MethodPut, urlPath, req, nil)
return err return err
} }
@@ -2072,7 +2230,7 @@ func (cli *Client) BatchSend(ctx context.Context, roomID id.RoomID, req *ReqBatc
if len(req.BatchID) > 0 { if len(req.BatchID) > 0 {
query["batch_id"] = req.BatchID.String() query["batch_id"] = req.BatchID.String()
} }
_, err = cli.MakeRequest(ctx, "POST", cli.BuildURLWithQuery(path, query), req, &resp) _, err = cli.MakeRequest(ctx, http.MethodPost, cli.BuildURLWithQuery(path, query), req, &resp)
return return
} }
@@ -2123,7 +2281,7 @@ func NewClient(homeserverURL string, userID id.UserID, accessToken string) (*Cli
if err != nil { if err != nil {
return nil, err return nil, err
} }
cli := &Client{ return &Client{
AccessToken: accessToken, AccessToken: accessToken,
UserAgent: DefaultUserAgent, UserAgent: DefaultUserAgent,
HomeserverURL: hsURL, HomeserverURL: hsURL,
@@ -2135,7 +2293,5 @@ func NewClient(homeserverURL string, userID id.UserID, accessToken string) (*Cli
// The client will work with this storer: it just won't remember across restarts. // The client will work with this storer: it just won't remember across restarts.
// In practice, a database backend should be used. // In practice, a database backend should be used.
Store: NewMemorySyncStore(), Store: NewMemorySyncStore(),
} }, nil
cli.Logger = maulogadapt.ZeroAsMau(&cli.Log)
return cli, nil
} }

View File

@@ -9,6 +9,7 @@ package crypto
import ( import (
"maunium.net/go/mautrix" "maunium.net/go/mautrix"
"maunium.net/go/mautrix/crypto/olm" "maunium.net/go/mautrix/crypto/olm"
"maunium.net/go/mautrix/crypto/signatures"
"maunium.net/go/mautrix/id" "maunium.net/go/mautrix/id"
) )
@@ -17,6 +18,7 @@ type OlmAccount struct {
signingKey id.SigningKey signingKey id.SigningKey
identityKey id.IdentityKey identityKey id.IdentityKey
Shared bool Shared bool
KeyBackupVersion id.KeyBackupVersion
} }
func NewOlmAccount() *OlmAccount { func NewOlmAccount() *OlmAccount {
@@ -62,11 +64,7 @@ func (account *OlmAccount) getInitialKeys(userID id.UserID, deviceID id.DeviceID
panic(err) panic(err)
} }
deviceKeys.Signatures = mautrix.Signatures{ deviceKeys.Signatures = signatures.NewSingleSignature(userID, id.KeyAlgorithmEd25519, deviceID.String(), signature)
userID: {
id.NewKeyID(id.KeyAlgorithmEd25519, deviceID.String()): signature,
},
}
return deviceKeys return deviceKeys
} }
@@ -79,11 +77,7 @@ func (account *OlmAccount) getOneTimeKeys(userID id.UserID, deviceID id.DeviceID
for keyID, key := range account.Internal.OneTimeKeys() { for keyID, key := range account.Internal.OneTimeKeys() {
key := mautrix.OneTimeKey{Key: key} key := mautrix.OneTimeKey{Key: key}
signature, _ := account.Internal.SignJSON(key) signature, _ := account.Internal.SignJSON(key)
key.Signatures = mautrix.Signatures{ key.Signatures = signatures.NewSingleSignature(userID, id.KeyAlgorithmEd25519, deviceID.String(), signature)
userID: {
id.NewKeyID(id.KeyAlgorithmEd25519, deviceID.String()): signature,
},
}
key.IsSigned = true key.IsSigned = true
oneTimeKeys[id.NewKeyID(id.KeyAlgorithmSignedCurve25519, keyID)] = key oneTimeKeys[id.NewKeyID(id.KeyAlgorithmSignedCurve25519, keyID)] = key
} }

60
vendor/maunium.net/go/mautrix/crypto/aescbc/aes_cbc.go generated vendored Normal file
View File

@@ -0,0 +1,60 @@
// Copyright (c) 2024 Sumner Evans
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
package aescbc
import (
"crypto/aes"
"crypto/cipher"
"maunium.net/go/mautrix/crypto/pkcs7"
)
// Encrypt encrypts the plaintext with the key and IV. The IV length must be
// equal to the AES block size.
//
// This function might mutate the plaintext.
func Encrypt(key, iv, plaintext []byte) ([]byte, error) {
if len(key) == 0 {
return nil, ErrNoKeyProvided
}
if len(iv) != aes.BlockSize {
return nil, ErrIVNotBlockSize
}
plaintext = pkcs7.Pad(plaintext, aes.BlockSize)
block, err := aes.NewCipher(key)
if err != nil {
return nil, err
}
cipher.NewCBCEncrypter(block, iv).CryptBlocks(plaintext, plaintext)
return plaintext, nil
}
// Decrypt decrypts the ciphertext with the key and IV. The IV length must be
// equal to the block size.
//
// This function mutates the ciphertext.
func Decrypt(key, iv, ciphertext []byte) ([]byte, error) {
if len(key) == 0 {
return nil, ErrNoKeyProvided
}
if len(iv) != aes.BlockSize {
return nil, ErrIVNotBlockSize
}
block, err := aes.NewCipher(key)
if err != nil {
return nil, err
}
if len(ciphertext) < aes.BlockSize {
return nil, ErrNotMultipleBlockSize
}
cipher.NewCBCDecrypter(block, iv).CryptBlocks(ciphertext, ciphertext)
return pkcs7.Unpad(ciphertext), nil
}

15
vendor/maunium.net/go/mautrix/crypto/aescbc/errors.go generated vendored Normal file
View File

@@ -0,0 +1,15 @@
// Copyright (c) 2024 Sumner Evans
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
package aescbc
import "errors"
var (
ErrNoKeyProvided = errors.New("no key")
ErrIVNotBlockSize = errors.New("IV length does not match AES block size")
ErrNotMultipleBlockSize = errors.New("ciphertext length is not a multiple of the AES block size")
)

View File

@@ -0,0 +1,137 @@
// Copyright (c) 2024 Sumner Evans
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
package backup
import (
"bytes"
"crypto/ecdh"
"crypto/hmac"
"crypto/rand"
"crypto/sha256"
"encoding/json"
"errors"
"go.mau.fi/util/jsonbytes"
"golang.org/x/crypto/hkdf"
"maunium.net/go/mautrix/crypto/aescbc"
)
var ErrInvalidMAC = errors.New("invalid MAC")
// EncryptedSessionData is the encrypted session_data field of a key backup as
// defined in [Section 11.12.3.2.2 of the Spec].
//
// The type parameter T represents the format of the session data contained in
// the encrypted payload.
//
// [Section 11.12.3.2.2 of the Spec]: https://spec.matrix.org/v1.9/client-server-api/#backup-algorithm-mmegolm_backupv1curve25519-aes-sha2
type EncryptedSessionData[T any] struct {
Ciphertext jsonbytes.UnpaddedBytes `json:"ciphertext"`
Ephemeral EphemeralKey `json:"ephemeral"`
MAC jsonbytes.UnpaddedBytes `json:"mac"`
}
func calculateEncryptionParameters(sharedSecret []byte) (key, macKey, iv []byte, err error) {
hkdfReader := hkdf.New(sha256.New, sharedSecret, nil, nil)
encryptionParams := make([]byte, 80)
_, err = hkdfReader.Read(encryptionParams)
if err != nil {
return nil, nil, nil, err
}
return encryptionParams[:32], encryptionParams[32:64], encryptionParams[64:], nil
}
// calculateCompatMAC calculates the MAC for compatibility with Olm and
// Vodozemac which do not actually write the ciphertext when computing the MAC.
//
// Deprecated: Use [calculateMAC] instead.
func calculateCompatMAC(macKey []byte) []byte {
hash := hmac.New(sha256.New, macKey)
return hash.Sum(nil)[:8]
}
// calculateMAC calculates the MAC as described in step 5 of according to
// [Section 11.12.3.2.2] of the Spec.
//
// [Section 11.12.3.2.2]: https://spec.matrix.org/v1.9/client-server-api/#backup-algorithm-mmegolm_backupv1curve25519-aes-sha2
func calculateMAC(macKey, ciphertext []byte) []byte {
hash := hmac.New(sha256.New, macKey)
_, err := hash.Write(ciphertext)
if err != nil {
panic(err)
}
return hash.Sum(nil)[:8]
}
// EncryptSessionData encrypts the given session data with the given recovery
// key as defined in [Section 11.12.3.2.2 of the Spec].
//
// [Section 11.12.3.2.2 of the Spec]: https://spec.matrix.org/v1.9/client-server-api/#backup-algorithm-mmegolm_backupv1curve25519-aes-sha2
func EncryptSessionData[T any](backupKey *MegolmBackupKey, sessionData T) (*EncryptedSessionData[T], error) {
sessionJSON, err := json.Marshal(sessionData)
if err != nil {
return nil, err
}
ephemeralKey, err := ecdh.X25519().GenerateKey(rand.Reader)
if err != nil {
return nil, err
}
sharedSecret, err := ephemeralKey.ECDH(backupKey.PublicKey())
if err != nil {
return nil, err
}
key, macKey, iv, err := calculateEncryptionParameters(sharedSecret)
if err != nil {
return nil, err
}
ciphertext, err := aescbc.Encrypt(key, iv, sessionJSON)
if err != nil {
return nil, err
}
return &EncryptedSessionData[T]{
Ciphertext: ciphertext,
Ephemeral: EphemeralKey{ephemeralKey.PublicKey()},
MAC: calculateCompatMAC(macKey),
}, nil
}
// Decrypt decrypts the [EncryptedSessionData] into a *T using the recovery key
// by reversing the process described in [Section 11.12.3.2.2 of the Spec].
//
// [Section 11.12.3.2.2 of the Spec]: https://spec.matrix.org/v1.9/client-server-api/#backup-algorithm-mmegolm_backupv1curve25519-aes-sha2
func (esd *EncryptedSessionData[T]) Decrypt(backupKey *MegolmBackupKey) (*T, error) {
sharedSecret, err := backupKey.ECDH(esd.Ephemeral.PublicKey)
if err != nil {
return nil, err
}
key, macKey, iv, err := calculateEncryptionParameters(sharedSecret)
if err != nil {
return nil, err
}
// Verify the MAC before decrypting.
if !bytes.Equal(calculateCompatMAC(macKey), esd.MAC) {
return nil, ErrInvalidMAC
}
plaintext, err := aescbc.Decrypt(key, iv, esd.Ciphertext)
if err != nil {
return nil, err
}
var sessionData T
err = json.Unmarshal(plaintext, &sessionData)
return &sessionData, err
}

View File

@@ -0,0 +1,41 @@
// Copyright (c) 2024 Sumner Evans
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
package backup
import (
"crypto/ecdh"
"encoding/base64"
"encoding/json"
)
// EphemeralKey is a wrapper around an ECDH X25519 public key that implements
// JSON marshalling and unmarshalling.
type EphemeralKey struct {
*ecdh.PublicKey
}
func (k *EphemeralKey) MarshalJSON() ([]byte, error) {
if k == nil || k.PublicKey == nil {
return json.Marshal(nil)
}
return json.Marshal(base64.RawStdEncoding.EncodeToString(k.Bytes()))
}
func (k *EphemeralKey) UnmarshalJSON(data []byte) error {
var keyStr string
err := json.Unmarshal(data, &keyStr)
if err != nil {
return err
}
keyBytes, err := base64.RawStdEncoding.DecodeString(keyStr)
if err != nil {
return err
}
k.PublicKey, err = ecdh.X25519().NewPublicKey(keyBytes)
return err
}

View File

@@ -0,0 +1,39 @@
// Copyright (c) 2024 Sumner Evans
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
package backup
import (
"maunium.net/go/mautrix/crypto/signatures"
"maunium.net/go/mautrix/id"
)
// MegolmAuthData is the auth_data when the key backup is created with
// the [id.KeyBackupAlgorithmMegolmBackupV1] algorithm as defined in
// [Section 11.12.3.2.2 of the Spec].
//
// [Section 11.12.3.2.2 of the Spec]: https://spec.matrix.org/v1.9/client-server-api/#backup-algorithm-mmegolm_backupv1curve25519-aes-sha2
type MegolmAuthData struct {
PublicKey id.Ed25519 `json:"public_key"`
Signatures signatures.Signatures `json:"signatures"`
}
type SenderClaimedKeys struct {
Ed25519 id.Ed25519 `json:"ed25519"`
}
// MegolmSessionData is the decrypted session_data when the key backup is created
// with the [id.KeyBackupAlgorithmMegolmBackupV1] algorithm as defined in
// [Section 11.12.3.2.2 of the Spec].
//
// [Section 11.12.3.2.2 of the Spec]: https://spec.matrix.org/v1.9/client-server-api/#backup-algorithm-mmegolm_backupv1curve25519-aes-sha2
type MegolmSessionData struct {
Algorithm id.Algorithm `json:"algorithm"`
ForwardingKeyChain []string `json:"forwarding_curve25519_key_chain"`
SenderClaimedKeys SenderClaimedKeys `json:"sender_claimed_keys"`
SenderKey id.SenderKey `json:"sender_key"`
SessionKey string `json:"session_key"`
}

View File

@@ -0,0 +1,34 @@
// Copyright (c) 2024 Sumner Evans
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
package backup
import (
"crypto/ecdh"
"crypto/rand"
)
// MegolmBackupKey is a wrapper around an ECDH X25519 private key that is used
// to decrypt a megolm key backup.
type MegolmBackupKey struct {
*ecdh.PrivateKey
}
func NewMegolmBackupKey() (*MegolmBackupKey, error) {
key, err := ecdh.X25519().GenerateKey(rand.Reader)
if err != nil {
return nil, err
}
return &MegolmBackupKey{key}, nil
}
func MegolmBackupKeyFromBytes(bytes []byte) (*MegolmBackupKey, error) {
key, err := ecdh.X25519().NewPrivateKey(bytes)
if err != nil {
return nil, err
}
return &MegolmBackupKey{key}, nil
}

View File

@@ -13,21 +13,22 @@ import (
"maunium.net/go/mautrix" "maunium.net/go/mautrix"
"maunium.net/go/mautrix/crypto/olm" "maunium.net/go/mautrix/crypto/olm"
"maunium.net/go/mautrix/crypto/signatures"
"maunium.net/go/mautrix/id" "maunium.net/go/mautrix/id"
) )
// CrossSigningKeysCache holds the three cross-signing keys for the current user. // CrossSigningKeysCache holds the three cross-signing keys for the current user.
type CrossSigningKeysCache struct { type CrossSigningKeysCache struct {
MasterKey *olm.PkSigning MasterKey olm.PKSigning
SelfSigningKey *olm.PkSigning SelfSigningKey olm.PKSigning
UserSigningKey *olm.PkSigning UserSigningKey olm.PKSigning
} }
func (cskc *CrossSigningKeysCache) PublicKeys() *CrossSigningPublicKeysCache { func (cskc *CrossSigningKeysCache) PublicKeys() *CrossSigningPublicKeysCache {
return &CrossSigningPublicKeysCache{ return &CrossSigningPublicKeysCache{
MasterKey: cskc.MasterKey.PublicKey, MasterKey: cskc.MasterKey.PublicKey(),
SelfSigningKey: cskc.SelfSigningKey.PublicKey, SelfSigningKey: cskc.SelfSigningKey.PublicKey(),
UserSigningKey: cskc.UserSigningKey.PublicKey, UserSigningKey: cskc.UserSigningKey.PublicKey(),
} }
} }
@@ -39,28 +40,28 @@ type CrossSigningSeeds struct {
func (mach *OlmMachine) ExportCrossSigningKeys() CrossSigningSeeds { func (mach *OlmMachine) ExportCrossSigningKeys() CrossSigningSeeds {
return CrossSigningSeeds{ return CrossSigningSeeds{
MasterKey: mach.CrossSigningKeys.MasterKey.Seed, MasterKey: mach.CrossSigningKeys.MasterKey.Seed(),
SelfSigningKey: mach.CrossSigningKeys.SelfSigningKey.Seed, SelfSigningKey: mach.CrossSigningKeys.SelfSigningKey.Seed(),
UserSigningKey: mach.CrossSigningKeys.UserSigningKey.Seed, UserSigningKey: mach.CrossSigningKeys.UserSigningKey.Seed(),
} }
} }
func (mach *OlmMachine) ImportCrossSigningKeys(keys CrossSigningSeeds) (err error) { func (mach *OlmMachine) ImportCrossSigningKeys(keys CrossSigningSeeds) (err error) {
var keysCache CrossSigningKeysCache var keysCache CrossSigningKeysCache
if keysCache.MasterKey, err = olm.NewPkSigningFromSeed(keys.MasterKey); err != nil { if keysCache.MasterKey, err = olm.NewPKSigningFromSeed(keys.MasterKey); err != nil {
return return
} }
if keysCache.SelfSigningKey, err = olm.NewPkSigningFromSeed(keys.SelfSigningKey); err != nil { if keysCache.SelfSigningKey, err = olm.NewPKSigningFromSeed(keys.SelfSigningKey); err != nil {
return return
} }
if keysCache.UserSigningKey, err = olm.NewPkSigningFromSeed(keys.UserSigningKey); err != nil { if keysCache.UserSigningKey, err = olm.NewPKSigningFromSeed(keys.UserSigningKey); err != nil {
return return
} }
mach.Log.Debug(). mach.Log.Debug().
Str("master", keysCache.MasterKey.PublicKey.String()). Str("master", keysCache.MasterKey.PublicKey().String()).
Str("self_signing", keysCache.SelfSigningKey.PublicKey.String()). Str("self_signing", keysCache.SelfSigningKey.PublicKey().String()).
Str("user_signing", keysCache.UserSigningKey.PublicKey.String()). Str("user_signing", keysCache.UserSigningKey.PublicKey().String()).
Msg("Imported own cross-signing keys") Msg("Imported own cross-signing keys")
mach.CrossSigningKeys = &keysCache mach.CrossSigningKeys = &keysCache
@@ -72,19 +73,19 @@ func (mach *OlmMachine) ImportCrossSigningKeys(keys CrossSigningSeeds) (err erro
func (mach *OlmMachine) GenerateCrossSigningKeys() (*CrossSigningKeysCache, error) { func (mach *OlmMachine) GenerateCrossSigningKeys() (*CrossSigningKeysCache, error) {
var keysCache CrossSigningKeysCache var keysCache CrossSigningKeysCache
var err error var err error
if keysCache.MasterKey, err = olm.NewPkSigning(); err != nil { if keysCache.MasterKey, err = olm.NewPKSigning(); err != nil {
return nil, fmt.Errorf("failed to generate master key: %w", err) return nil, fmt.Errorf("failed to generate master key: %w", err)
} }
if keysCache.SelfSigningKey, err = olm.NewPkSigning(); err != nil { if keysCache.SelfSigningKey, err = olm.NewPKSigning(); err != nil {
return nil, fmt.Errorf("failed to generate self-signing key: %w", err) return nil, fmt.Errorf("failed to generate self-signing key: %w", err)
} }
if keysCache.UserSigningKey, err = olm.NewPkSigning(); err != nil { if keysCache.UserSigningKey, err = olm.NewPKSigning(); err != nil {
return nil, fmt.Errorf("failed to generate user-signing key: %w", err) return nil, fmt.Errorf("failed to generate user-signing key: %w", err)
} }
mach.Log.Debug(). mach.Log.Debug().
Str("master", keysCache.MasterKey.PublicKey.String()). Str("master", keysCache.MasterKey.PublicKey().String()).
Str("self_signing", keysCache.SelfSigningKey.PublicKey.String()). Str("self_signing", keysCache.SelfSigningKey.PublicKey().String()).
Str("user_signing", keysCache.UserSigningKey.PublicKey.String()). Str("user_signing", keysCache.UserSigningKey.PublicKey().String()).
Msg("Generated cross-signing keys") Msg("Generated cross-signing keys")
return &keysCache, nil return &keysCache, nil
} }
@@ -92,48 +93,45 @@ func (mach *OlmMachine) GenerateCrossSigningKeys() (*CrossSigningKeysCache, erro
// PublishCrossSigningKeys signs and uploads the public keys of the given cross-signing keys to the server. // PublishCrossSigningKeys signs and uploads the public keys of the given cross-signing keys to the server.
func (mach *OlmMachine) PublishCrossSigningKeys(ctx context.Context, keys *CrossSigningKeysCache, uiaCallback mautrix.UIACallback) error { func (mach *OlmMachine) PublishCrossSigningKeys(ctx context.Context, keys *CrossSigningKeysCache, uiaCallback mautrix.UIACallback) error {
userID := mach.Client.UserID userID := mach.Client.UserID
masterKeyID := id.NewKeyID(id.KeyAlgorithmEd25519, keys.MasterKey.PublicKey.String()) masterKeyID := id.NewKeyID(id.KeyAlgorithmEd25519, keys.MasterKey.PublicKey().String())
masterKey := mautrix.CrossSigningKeys{ masterKey := mautrix.CrossSigningKeys{
UserID: userID, UserID: userID,
Usage: []id.CrossSigningUsage{id.XSUsageMaster}, Usage: []id.CrossSigningUsage{id.XSUsageMaster},
Keys: map[id.KeyID]id.Ed25519{ Keys: map[id.KeyID]id.Ed25519{
masterKeyID: keys.MasterKey.PublicKey, masterKeyID: keys.MasterKey.PublicKey(),
}, },
} }
masterSig, err := mach.account.Internal.SignJSON(masterKey)
if err != nil {
return fmt.Errorf("failed to sign master key: %w", err)
}
masterKey.Signatures = signatures.NewSingleSignature(userID, id.KeyAlgorithmEd25519, mach.Client.DeviceID.String(), masterSig)
selfKey := mautrix.CrossSigningKeys{ selfKey := mautrix.CrossSigningKeys{
UserID: userID, UserID: userID,
Usage: []id.CrossSigningUsage{id.XSUsageSelfSigning}, Usage: []id.CrossSigningUsage{id.XSUsageSelfSigning},
Keys: map[id.KeyID]id.Ed25519{ Keys: map[id.KeyID]id.Ed25519{
id.NewKeyID(id.KeyAlgorithmEd25519, keys.SelfSigningKey.PublicKey.String()): keys.SelfSigningKey.PublicKey, id.NewKeyID(id.KeyAlgorithmEd25519, keys.SelfSigningKey.PublicKey().String()): keys.SelfSigningKey.PublicKey(),
}, },
} }
selfSig, err := keys.MasterKey.SignJSON(selfKey) selfSig, err := keys.MasterKey.SignJSON(selfKey)
if err != nil { if err != nil {
return fmt.Errorf("failed to sign self-signing key: %w", err) return fmt.Errorf("failed to sign self-signing key: %w", err)
} }
selfKey.Signatures = map[id.UserID]map[id.KeyID]string{ selfKey.Signatures = signatures.NewSingleSignature(userID, id.KeyAlgorithmEd25519, keys.MasterKey.PublicKey().String(), selfSig)
userID: {
masterKeyID: selfSig,
},
}
userKey := mautrix.CrossSigningKeys{ userKey := mautrix.CrossSigningKeys{
UserID: userID, UserID: userID,
Usage: []id.CrossSigningUsage{id.XSUsageUserSigning}, Usage: []id.CrossSigningUsage{id.XSUsageUserSigning},
Keys: map[id.KeyID]id.Ed25519{ Keys: map[id.KeyID]id.Ed25519{
id.NewKeyID(id.KeyAlgorithmEd25519, keys.UserSigningKey.PublicKey.String()): keys.UserSigningKey.PublicKey, id.NewKeyID(id.KeyAlgorithmEd25519, keys.UserSigningKey.PublicKey().String()): keys.UserSigningKey.PublicKey(),
}, },
} }
userSig, err := keys.MasterKey.SignJSON(userKey) userSig, err := keys.MasterKey.SignJSON(userKey)
if err != nil { if err != nil {
return fmt.Errorf("failed to sign user-signing key: %w", err) return fmt.Errorf("failed to sign user-signing key: %w", err)
} }
userKey.Signatures = map[id.UserID]map[id.KeyID]string{ userKey.Signatures = signatures.NewSingleSignature(userID, id.KeyAlgorithmEd25519, keys.MasterKey.PublicKey().String(), userSig)
userID: {
masterKeyID: userSig,
},
}
err = mach.Client.UploadCrossSigningKeys(ctx, &mautrix.UploadCrossSigningKeysReq{ err = mach.Client.UploadCrossSigningKeys(ctx, &mautrix.UploadCrossSigningKeysReq{
Master: masterKey, Master: masterKey,

View File

@@ -14,7 +14,7 @@ import (
"maunium.net/go/mautrix" "maunium.net/go/mautrix"
"maunium.net/go/mautrix/crypto/olm" "maunium.net/go/mautrix/crypto/olm"
"maunium.net/go/mautrix/event" "maunium.net/go/mautrix/crypto/signatures"
"maunium.net/go/mautrix/id" "maunium.net/go/mautrix/id"
) )
@@ -34,31 +34,6 @@ var (
ErrMismatchingMasterKeyMAC = errors.New("mismatching cross-signing master key MAC") ErrMismatchingMasterKeyMAC = errors.New("mismatching cross-signing master key MAC")
) )
func (mach *OlmMachine) fetchMasterKey(ctx context.Context, device *id.Device, content *event.VerificationMacEventContent, verState *verificationState, transactionID string) (id.Ed25519, error) {
crossSignKeys, err := mach.CryptoStore.GetCrossSigningKeys(ctx, device.UserID)
if err != nil {
return "", fmt.Errorf("failed to fetch cross-signing keys: %w", err)
}
masterKey, ok := crossSignKeys[id.XSUsageMaster]
if !ok {
return "", ErrCrossSigningMasterKeyNotFound
}
masterKeyID := id.NewKeyID(id.KeyAlgorithmEd25519, masterKey.Key.String())
masterKeyMAC, ok := content.Mac[masterKeyID]
if !ok {
return masterKey.Key, ErrMasterKeyMACNotFound
}
expectedMasterKeyMAC, _, err := mach.getPKAndKeysMAC(verState.sas, device.UserID, device.DeviceID,
mach.Client.UserID, mach.Client.DeviceID, transactionID, masterKey.Key, masterKeyID, content.Mac)
if err != nil {
return masterKey.Key, fmt.Errorf("failed to calculate expected MAC for master key: %w", err)
}
if masterKeyMAC != expectedMasterKeyMAC {
err = fmt.Errorf("%w: expected %s, got %s", ErrMismatchingMasterKeyMAC, expectedMasterKeyMAC, masterKeyMAC)
}
return masterKey.Key, err
}
// SignUser creates a cross-signing signature for a user, stores it and uploads it to the server. // SignUser creates a cross-signing signature for a user, stores it and uploads it to the server.
func (mach *OlmMachine) SignUser(ctx context.Context, userID id.UserID, masterKey id.Ed25519) error { func (mach *OlmMachine) SignUser(ctx context.Context, userID id.UserID, masterKey id.Ed25519) error {
if userID == mach.Client.UserID { if userID == mach.Client.UserID {
@@ -85,7 +60,7 @@ func (mach *OlmMachine) SignUser(ctx context.Context, userID id.UserID, masterKe
Str("signature", signature). Str("signature", signature).
Msg("Signed master key of user with our user-signing key") Msg("Signed master key of user with our user-signing key")
if err := mach.CryptoStore.PutSignature(ctx, userID, masterKey, mach.Client.UserID, mach.CrossSigningKeys.UserSigningKey.PublicKey, signature); err != nil { if err := mach.CryptoStore.PutSignature(ctx, userID, masterKey, mach.Client.UserID, mach.CrossSigningKeys.UserSigningKey.PublicKey(), signature); err != nil {
return fmt.Errorf("error storing signature in crypto store: %w", err) return fmt.Errorf("error storing signature in crypto store: %w", err)
} }
@@ -102,7 +77,7 @@ func (mach *OlmMachine) SignOwnMasterKey(ctx context.Context) error {
userID := mach.Client.UserID userID := mach.Client.UserID
deviceID := mach.Client.DeviceID deviceID := mach.Client.DeviceID
masterKey := mach.CrossSigningKeys.MasterKey.PublicKey masterKey := mach.CrossSigningKeys.MasterKey.PublicKey()
masterKeyObj := mautrix.ReqKeysSignatures{ masterKeyObj := mautrix.ReqKeysSignatures{
UserID: userID, UserID: userID,
@@ -115,11 +90,7 @@ func (mach *OlmMachine) SignOwnMasterKey(ctx context.Context) error {
if err != nil { if err != nil {
return fmt.Errorf("failed to sign JSON: %w", err) return fmt.Errorf("failed to sign JSON: %w", err)
} }
masterKeyObj.Signatures = mautrix.Signatures{ masterKeyObj.Signatures = signatures.NewSingleSignature(userID, id.KeyAlgorithmEd25519, deviceID.String(), signature)
userID: map[id.KeyID]string{
id.NewKeyID(id.KeyAlgorithmEd25519, deviceID.String()): signature,
},
}
mach.Log.Debug(). mach.Log.Debug().
Str("device_id", deviceID.String()). Str("device_id", deviceID.String()).
Str("signature", signature). Str("signature", signature).
@@ -178,7 +149,7 @@ func (mach *OlmMachine) SignOwnDevice(ctx context.Context, device *id.Device) er
Str("signature", signature). Str("signature", signature).
Msg("Signed own device key with self-signing key") Msg("Signed own device key with self-signing key")
if err := mach.CryptoStore.PutSignature(ctx, device.UserID, device.SigningKey, mach.Client.UserID, mach.CrossSigningKeys.SelfSigningKey.PublicKey, signature); err != nil { if err := mach.CryptoStore.PutSignature(ctx, device.UserID, device.SigningKey, mach.Client.UserID, mach.CrossSigningKeys.SelfSigningKey.PublicKey(), signature); err != nil {
return fmt.Errorf("error storing signature in crypto store: %w", err) return fmt.Errorf("error storing signature in crypto store: %w", err)
} }
@@ -209,16 +180,12 @@ func (mach *OlmMachine) getFullDeviceKeys(ctx context.Context, device *id.Device
} }
// signAndUpload signs the given key signatures object and uploads it to the server. // signAndUpload signs the given key signatures object and uploads it to the server.
func (mach *OlmMachine) signAndUpload(ctx context.Context, req mautrix.ReqKeysSignatures, userID id.UserID, signedThing string, key *olm.PkSigning) (string, error) { func (mach *OlmMachine) signAndUpload(ctx context.Context, req mautrix.ReqKeysSignatures, userID id.UserID, signedThing string, key olm.PKSigning) (string, error) {
signature, err := key.SignJSON(req) signature, err := key.SignJSON(req)
if err != nil { if err != nil {
return "", fmt.Errorf("failed to sign JSON: %w", err) return "", fmt.Errorf("failed to sign JSON: %w", err)
} }
req.Signatures = mautrix.Signatures{ req.Signatures = signatures.NewSingleSignature(mach.Client.UserID, id.KeyAlgorithmEd25519, key.PublicKey().String(), signature)
mach.Client.UserID: map[id.KeyID]string{
id.NewKeyID(id.KeyAlgorithmEd25519, key.PublicKey.String()): signature,
},
}
resp, err := mach.Client.UploadSignatures(ctx, &mautrix.ReqUploadSignatures{ resp, err := mach.Client.UploadSignatures(ctx, &mautrix.ReqUploadSignatures{
userID: map[string]mautrix.ReqKeysSignatures{ userID: map[string]mautrix.ReqKeysSignatures{

View File

@@ -14,6 +14,7 @@ import (
"maunium.net/go/mautrix/crypto/ssss" "maunium.net/go/mautrix/crypto/ssss"
"maunium.net/go/mautrix/crypto/utils" "maunium.net/go/mautrix/crypto/utils"
"maunium.net/go/mautrix/event" "maunium.net/go/mautrix/event"
"maunium.net/go/mautrix/id"
) )
// FetchCrossSigningKeysFromSSSS fetches all the cross-signing keys from SSSS, decrypts them using the given key and stores them in the olm machine. // FetchCrossSigningKeysFromSSSS fetches all the cross-signing keys from SSSS, decrypts them using the given key and stores them in the olm machine.
@@ -57,33 +58,8 @@ func (mach *OlmMachine) retrieveDecryptXSigningKey(ctx context.Context, keyName
return decryptedKey, nil return decryptedKey, nil
} }
// GenerateAndUploadCrossSigningKeys generates a new key with all corresponding cross-signing keys. func (mach *OlmMachine) GenerateAndUploadCrossSigningKeysWithPassword(ctx context.Context, userPassword, passphrase string) (string, *CrossSigningKeysCache, error) {
// return mach.GenerateAndUploadCrossSigningKeys(ctx, func(uiResp *mautrix.RespUserInteractive) interface{} {
// A passphrase can be provided to generate the SSSS key. If the passphrase is empty, a random key
// is used. The base58-formatted recovery key is the first return parameter.
//
// The account password of the user is required for uploading keys to the server.
func (mach *OlmMachine) GenerateAndUploadCrossSigningKeys(ctx context.Context, userPassword, passphrase string) (string, error) {
key, err := mach.SSSS.GenerateAndUploadKey(ctx, passphrase)
if err != nil {
return "", fmt.Errorf("failed to generate and upload SSSS key: %w", err)
}
// generate the three cross-signing keys
keysCache, err := mach.GenerateCrossSigningKeys()
if err != nil {
return "", err
}
recoveryKey := key.RecoveryKey()
// Store the private keys in SSSS
if err := mach.UploadCrossSigningKeysToSSSS(ctx, key, keysCache); err != nil {
return recoveryKey, fmt.Errorf("failed to upload cross-signing keys to SSSS: %w", err)
}
// Publish cross-signing keys
err = mach.PublishCrossSigningKeys(ctx, keysCache, func(uiResp *mautrix.RespUserInteractive) interface{} {
return &mautrix.ReqUIAuthLogin{ return &mautrix.ReqUIAuthLogin{
BaseAuthData: mautrix.BaseAuthData{ BaseAuthData: mautrix.BaseAuthData{
Type: mautrix.AuthTypePassword, Type: mautrix.AuthTypePassword,
@@ -92,29 +68,68 @@ func (mach *OlmMachine) GenerateAndUploadCrossSigningKeys(ctx context.Context, u
User: mach.Client.UserID.String(), User: mach.Client.UserID.String(),
Password: userPassword, Password: userPassword,
} }
}) }, passphrase)
}
// GenerateAndUploadCrossSigningKeys generates a new key with all corresponding cross-signing keys.
//
// A passphrase can be provided to generate the SSSS key. If the passphrase is empty, a random key
// is used. The base58-formatted recovery key is the first return parameter.
//
// The account password of the user is required for uploading keys to the server.
func (mach *OlmMachine) GenerateAndUploadCrossSigningKeys(ctx context.Context, uiaCallback mautrix.UIACallback, passphrase string) (string, *CrossSigningKeysCache, error) {
key, err := mach.SSSS.GenerateAndUploadKey(ctx, passphrase)
if err != nil { if err != nil {
return recoveryKey, fmt.Errorf("failed to publish cross-signing keys: %w", err) return "", nil, fmt.Errorf("failed to generate and upload SSSS key: %w", err)
}
// generate the three cross-signing keys
keysCache, err := mach.GenerateCrossSigningKeys()
if err != nil {
return "", nil, err
}
// Store the private keys in SSSS
if err := mach.UploadCrossSigningKeysToSSSS(ctx, key, keysCache); err != nil {
return "", nil, fmt.Errorf("failed to upload cross-signing keys to SSSS: %w", err)
}
// Publish cross-signing keys
err = mach.PublishCrossSigningKeys(ctx, keysCache, uiaCallback)
if err != nil {
return "", nil, fmt.Errorf("failed to publish cross-signing keys: %w", err)
} }
err = mach.SSSS.SetDefaultKeyID(ctx, key.ID) err = mach.SSSS.SetDefaultKeyID(ctx, key.ID)
if err != nil { if err != nil {
return recoveryKey, fmt.Errorf("failed to mark %s as the default key: %w", key.ID, err) return "", nil, fmt.Errorf("failed to mark %s as the default key: %w", key.ID, err)
} }
return recoveryKey, nil return key.RecoveryKey(), keysCache, nil
} }
// UploadCrossSigningKeysToSSSS stores the given cross-signing keys on the server encrypted with the given key. // UploadCrossSigningKeysToSSSS stores the given cross-signing keys on the server encrypted with the given key.
func (mach *OlmMachine) UploadCrossSigningKeysToSSSS(ctx context.Context, key *ssss.Key, keys *CrossSigningKeysCache) error { func (mach *OlmMachine) UploadCrossSigningKeysToSSSS(ctx context.Context, key *ssss.Key, keys *CrossSigningKeysCache) error {
if err := mach.SSSS.SetEncryptedAccountData(ctx, event.AccountDataCrossSigningMaster, keys.MasterKey.Seed, key); err != nil { if err := mach.SSSS.SetEncryptedAccountData(ctx, event.AccountDataCrossSigningMaster, keys.MasterKey.Seed(), key); err != nil {
return err return err
} }
if err := mach.SSSS.SetEncryptedAccountData(ctx, event.AccountDataCrossSigningSelf, keys.SelfSigningKey.Seed, key); err != nil { if err := mach.SSSS.SetEncryptedAccountData(ctx, event.AccountDataCrossSigningSelf, keys.SelfSigningKey.Seed(), key); err != nil {
return err return err
} }
if err := mach.SSSS.SetEncryptedAccountData(ctx, event.AccountDataCrossSigningUser, keys.UserSigningKey.Seed, key); err != nil { if err := mach.SSSS.SetEncryptedAccountData(ctx, event.AccountDataCrossSigningUser, keys.UserSigningKey.Seed(), key); err != nil {
return err return err
} }
// Also store these locally
if err := mach.CryptoStore.PutCrossSigningKey(ctx, mach.Client.UserID, id.XSUsageMaster, keys.MasterKey.PublicKey()); err != nil {
return err
}
if err := mach.CryptoStore.PutCrossSigningKey(ctx, mach.Client.UserID, id.XSUsageSelfSigning, keys.SelfSigningKey.PublicKey()); err != nil {
return err
}
if err := mach.CryptoStore.PutCrossSigningKey(ctx, mach.Client.UserID, id.XSUsageUserSigning, keys.UserSigningKey.PublicKey()); err != nil {
return err
}
return nil return nil
} }

View File

@@ -11,7 +11,7 @@ import (
"context" "context"
"maunium.net/go/mautrix" "maunium.net/go/mautrix"
"maunium.net/go/mautrix/crypto/olm" "maunium.net/go/mautrix/crypto/signatures"
"maunium.net/go/mautrix/id" "maunium.net/go/mautrix/id"
) )
@@ -80,7 +80,7 @@ func (mach *OlmMachine) storeCrossSigningKeys(ctx context.Context, crossSigningK
} }
log.Debug().Msg("Verifying cross-signing key signature") log.Debug().Msg("Verifying cross-signing key signature")
if verified, err := olm.VerifySignatureJSON(userKeys, signUserID, signKeyName, signingKey); err != nil { if verified, err := signatures.VerifySignatureJSON(userKeys, signUserID, signKeyName, signingKey); err != nil {
log.Warn().Err(err).Msg("Error verifying cross-signing key signature") log.Warn().Err(err).Msg("Error verifying cross-signing key signature")
} else { } else {
if verified { if verified {

View File

@@ -71,8 +71,12 @@ func (mach *OlmMachine) DecryptMegolmEvent(ctx context.Context, evt *event.Event
ownSigningKey, ownIdentityKey := mach.account.Keys() ownSigningKey, ownIdentityKey := mach.account.Keys()
if sess.SigningKey == ownSigningKey && sess.SenderKey == ownIdentityKey && len(sess.ForwardingChains) == 0 { if sess.SigningKey == ownSigningKey && sess.SenderKey == ownIdentityKey && len(sess.ForwardingChains) == 0 {
trustLevel = id.TrustStateVerified trustLevel = id.TrustStateVerified
} else {
if mach.DisableDecryptKeyFetching {
device, err = mach.CryptoStore.FindDeviceByKey(ctx, evt.Sender, sess.SenderKey)
} else { } else {
device, err = mach.GetOrFetchDeviceByKey(ctx, evt.Sender, sess.SenderKey) device, err = mach.GetOrFetchDeviceByKey(ctx, evt.Sender, sess.SenderKey)
}
if err != nil { if err != nil {
// We don't want to throw these errors as the message can still be decrypted. // We don't want to throw these errors as the message can still be decrypted.
log.Debug().Err(err).Msg("Failed to get device to verify session") log.Debug().Err(err).Msg("Failed to get device to verify session")

View File

@@ -57,7 +57,7 @@ func (mach *OlmMachine) decryptOlmEvent(ctx context.Context, evt *event.Event) (
if !ok { if !ok {
return nil, NotEncryptedForMe return nil, NotEncryptedForMe
} }
decrypted, err := mach.decryptAndParseOlmCiphertext(ctx, evt.Sender, content.SenderKey, ownContent.Type, ownContent.Body) decrypted, err := mach.decryptAndParseOlmCiphertext(ctx, evt, content.SenderKey, ownContent.Type, ownContent.Body)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -69,13 +69,13 @@ type OlmEventKeys struct {
Ed25519 id.Ed25519 `json:"ed25519"` Ed25519 id.Ed25519 `json:"ed25519"`
} }
func (mach *OlmMachine) decryptAndParseOlmCiphertext(ctx context.Context, sender id.UserID, senderKey id.SenderKey, olmType id.OlmMsgType, ciphertext string) (*DecryptedOlmEvent, error) { func (mach *OlmMachine) decryptAndParseOlmCiphertext(ctx context.Context, evt *event.Event, senderKey id.SenderKey, olmType id.OlmMsgType, ciphertext string) (*DecryptedOlmEvent, error) {
if olmType != id.OlmMsgTypePreKey && olmType != id.OlmMsgTypeMsg { if olmType != id.OlmMsgTypePreKey && olmType != id.OlmMsgTypeMsg {
return nil, UnsupportedOlmMessageType return nil, UnsupportedOlmMessageType
} }
endTimeTrace := mach.timeTrace(ctx, "decrypting olm ciphertext", 5*time.Second) endTimeTrace := mach.timeTrace(ctx, "decrypting olm ciphertext", 5*time.Second)
plaintext, err := mach.tryDecryptOlmCiphertext(ctx, sender, senderKey, olmType, ciphertext) plaintext, err := mach.tryDecryptOlmCiphertext(ctx, evt.Sender, senderKey, olmType, ciphertext)
endTimeTrace() endTimeTrace()
if err != nil { if err != nil {
return nil, err return nil, err
@@ -88,7 +88,8 @@ func (mach *OlmMachine) decryptAndParseOlmCiphertext(ctx context.Context, sender
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to parse olm payload: %w", err) return nil, fmt.Errorf("failed to parse olm payload: %w", err)
} }
if sender != olmEvt.Sender { olmEvt.Type.Class = evt.Type.Class
if evt.Sender != olmEvt.Sender {
return nil, SenderMismatch return nil, SenderMismatch
} else if mach.Client.UserID != olmEvt.Recipient { } else if mach.Client.UserID != olmEvt.Recipient {
return nil, RecipientMismatch return nil, RecipientMismatch

View File

@@ -14,7 +14,7 @@ import (
"github.com/rs/zerolog" "github.com/rs/zerolog"
"maunium.net/go/mautrix" "maunium.net/go/mautrix"
"maunium.net/go/mautrix/crypto/olm" "maunium.net/go/mautrix/crypto/signatures"
"maunium.net/go/mautrix/id" "maunium.net/go/mautrix/id"
) )
@@ -52,7 +52,7 @@ func (mach *OlmMachine) storeDeviceSelfSignatures(ctx context.Context, userID id
} else if _, ok := selfSigs[id.NewKeyID(id.KeyAlgorithmEd25519, pubKey.String())]; !ok { } else if _, ok := selfSigs[id.NewKeyID(id.KeyAlgorithmEd25519, pubKey.String())]; !ok {
continue continue
} }
if verified, err := olm.VerifySignatureJSON(deviceKeys, signerUserID, pubKey.String(), pubKey); verified { if verified, err := signatures.VerifySignatureJSON(deviceKeys, signerUserID, pubKey.String(), pubKey); verified {
if signKey, ok := deviceKeys.Keys[id.DeviceKeyID(signerKey)]; ok { if signKey, ok := deviceKeys.Keys[id.DeviceKeyID(signerKey)]; ok {
signature := deviceKeys.Signatures[signerUserID][id.NewKeyID(id.KeyAlgorithmEd25519, pubKey.String())] signature := deviceKeys.Signatures[signerUserID][id.NewKeyID(id.KeyAlgorithmEd25519, pubKey.String())]
log.Trace().Err(err). log.Trace().Err(err).
@@ -245,7 +245,7 @@ func (mach *OlmMachine) validateDevice(userID id.UserID, deviceID id.DeviceID, d
return existing, fmt.Errorf("%w (expected %s, got %s)", MismatchingSigningKey, existing.SigningKey, signingKey) return existing, fmt.Errorf("%w (expected %s, got %s)", MismatchingSigningKey, existing.SigningKey, signingKey)
} }
ok, err := olm.VerifySignatureJSON(deviceKeys, userID, deviceID.String(), signingKey) ok, err := signatures.VerifySignatureJSON(deviceKeys, userID, deviceID.String(), signingKey)
if err != nil { if err != nil {
return existing, fmt.Errorf("failed to verify signature: %w", err) return existing, fmt.Errorf("failed to verify signature: %w", err)
} else if !ok { } else if !ok {

View File

@@ -118,7 +118,7 @@ func (mach *OlmMachine) EncryptMegolmEvent(ctx context.Context, roomID id.RoomID
log.Debug().Msg("Encrypted event successfully") log.Debug().Msg("Encrypted event successfully")
err = mach.CryptoStore.UpdateOutboundGroupSession(ctx, session) err = mach.CryptoStore.UpdateOutboundGroupSession(ctx, session)
if err != nil { if err != nil {
log.Warn().Err(err).Msg("Failed to update megolm session in crypto store after encrypting") return nil, fmt.Errorf("failed to update outbound group session after encrypting: %w", err)
} }
encrypted := &event.EncryptedEventContent{ encrypted := &event.EncryptedEventContent{
Algorithm: id.AlgorithmMegolmV1, Algorithm: id.AlgorithmMegolmV1,
@@ -330,6 +330,17 @@ func (mach *OlmMachine) encryptAndSendGroupSession(ctx context.Context, session
Str("target_user_id", userID.String()). Str("target_user_id", userID.String()).
Str("target_device_id", deviceID.String()). Str("target_device_id", deviceID.String()).
Msg("Encrypted group session for device") Msg("Encrypted group session for device")
if !mach.DisableSharedGroupSessionTracking {
err := mach.CryptoStore.MarkOutboundGroupSessionShared(ctx, userID, device.identity.IdentityKey, session.id)
if err != nil {
log.Warn().
Err(err).
Str("target_user_id", userID.String()).
Str("target_device_id", deviceID.String()).
Stringer("target_session_id", session.id).
Msg("Failed to mark outbound group session shared")
}
}
} }
} }

View File

@@ -12,7 +12,7 @@ import (
"fmt" "fmt"
"maunium.net/go/mautrix" "maunium.net/go/mautrix"
"maunium.net/go/mautrix/crypto/olm" "maunium.net/go/mautrix/crypto/signatures"
"maunium.net/go/mautrix/event" "maunium.net/go/mautrix/event"
"maunium.net/go/mautrix/id" "maunium.net/go/mautrix/id"
) )
@@ -109,7 +109,7 @@ func (mach *OlmMachine) createOutboundSessions(ctx context.Context, input map[id
continue continue
} }
identity := input[userID][deviceID] identity := input[userID][deviceID]
if ok, err := olm.VerifySignatureJSON(oneTimeKey.RawData, userID, deviceID.String(), identity.SigningKey); err != nil { if ok, err := signatures.VerifySignatureJSON(oneTimeKey.RawData, userID, deviceID.String(), identity.SigningKey); err != nil {
log.Error().Err(err).Msg("Failed to verify signature of one-time key") log.Error().Err(err).Msg("Failed to verify signature of one-time key")
} else if !ok { } else if !ok {
log.Warn().Msg("One-time key has invalid signature from device") log.Warn().Msg("One-time key has invalid signature from device")

View File

@@ -110,12 +110,13 @@ func (a Account) IdentityKeys() (id.Ed25519, id.Curve25519) {
return ed25519, curve25519 return ed25519, curve25519
} }
// Sign returns the signature of a message using the Ed25519 key for this Account. // Sign returns the base64-encoded signature of a message using the Ed25519 key
// for this Account.
func (a Account) Sign(message []byte) ([]byte, error) { func (a Account) Sign(message []byte) ([]byte, error) {
if len(message) == 0 { if len(message) == 0 {
return nil, fmt.Errorf("sign: %w", goolm.ErrEmptyInput) return nil, fmt.Errorf("sign: %w", goolm.ErrEmptyInput)
} }
return goolm.Base64Encode(a.IdKeys.Ed25519.Sign(message)), nil return []byte(base64.RawStdEncoding.EncodeToString(a.IdKeys.Ed25519.Sign(message))), nil
} }
// OneTimeKeys returns the public parts of the unpublished one time keys of the Account. // OneTimeKeys returns the public parts of the unpublished one time keys of the Account.

View File

@@ -2,8 +2,10 @@ package cipher
import ( import (
"bytes" "bytes"
"crypto/aes"
"io" "io"
"maunium.net/go/mautrix/crypto/aescbc"
"maunium.net/go/mautrix/crypto/goolm/crypto" "maunium.net/go/mautrix/crypto/goolm/crypto"
) )
@@ -36,7 +38,7 @@ func deriveAESKeys(kdfInfo []byte, key []byte) (*derivedAESKeys, error) {
// AESSha512BlockSize resturns the blocksize of the cipher AESSHA256. // AESSha512BlockSize resturns the blocksize of the cipher AESSHA256.
func AESSha512BlockSize() int { func AESSha512BlockSize() int {
return crypto.AESCBCBlocksize() return aes.BlockSize
} }
// AESSHA256 is a valid cipher using AES with CBC and HKDFSha256. // AESSHA256 is a valid cipher using AES with CBC and HKDFSha256.
@@ -57,7 +59,7 @@ func (c AESSHA256) Encrypt(key, plaintext []byte) (ciphertext []byte, err error)
if err != nil { if err != nil {
return nil, err return nil, err
} }
ciphertext, err = crypto.AESCBCEncrypt(keys.key, keys.iv, plaintext) ciphertext, err = aescbc.Encrypt(keys.key, keys.iv, plaintext)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -70,7 +72,7 @@ func (c AESSHA256) Decrypt(key, ciphertext []byte) (plaintext []byte, err error)
if err != nil { if err != nil {
return nil, err return nil, err
} }
plaintext, err = crypto.AESCBCDecrypt(keys.key, keys.iv, ciphertext) plaintext, err = aescbc.Decrypt(keys.key, keys.iv, ciphertext)
if err != nil { if err != nil {
return nil, err return nil, err
} }

View File

@@ -1,4 +1,5 @@
// cipher provides the methods and structs to do encryptions for olm/megolm. // Package cipher provides the methods and structs to do encryptions for
// olm/megolm.
package cipher package cipher
// Cipher defines a valid cipher. // Cipher defines a valid cipher.

View File

@@ -1,75 +0,0 @@
package crypto
import (
"bytes"
"crypto/aes"
"crypto/cipher"
"fmt"
"maunium.net/go/mautrix/crypto/goolm"
)
// AESCBCBlocksize returns the blocksize of the encryption method
func AESCBCBlocksize() int {
return aes.BlockSize
}
// AESCBCEncrypt encrypts the plaintext with the key and iv. len(iv) must be equal to the blocksize!
func AESCBCEncrypt(key, iv, plaintext []byte) ([]byte, error) {
if len(key) == 0 {
return nil, fmt.Errorf("AESCBCEncrypt: %w", goolm.ErrNoKeyProvided)
}
if len(iv) != AESCBCBlocksize() {
return nil, fmt.Errorf("iv: %w", goolm.ErrNotBlocksize)
}
var cipherText []byte
plaintext = pkcs5Padding(plaintext, AESCBCBlocksize())
if len(plaintext)%AESCBCBlocksize() != 0 {
return nil, fmt.Errorf("message: %w", goolm.ErrNotMultipleBlocksize)
}
block, err := aes.NewCipher(key)
if err != nil {
return nil, err
}
cipherText = make([]byte, len(plaintext))
cbc := cipher.NewCBCEncrypter(block, iv)
cbc.CryptBlocks(cipherText, plaintext)
return cipherText, nil
}
// AESCBCDecrypt decrypts the ciphertext with the key and iv. len(iv) must be equal to the blocksize!
func AESCBCDecrypt(key, iv, ciphertext []byte) ([]byte, error) {
if len(key) == 0 {
return nil, fmt.Errorf("AESCBCEncrypt: %w", goolm.ErrNoKeyProvided)
}
if len(iv) != AESCBCBlocksize() {
return nil, fmt.Errorf("iv: %w", goolm.ErrNotBlocksize)
}
var block cipher.Block
var err error
block, err = aes.NewCipher(key)
if err != nil {
return nil, err
}
if len(ciphertext) < AESCBCBlocksize() {
return nil, fmt.Errorf("ciphertext: %w", goolm.ErrNotMultipleBlocksize)
}
cbc := cipher.NewCBCDecrypter(block, iv)
cbc.CryptBlocks(ciphertext, ciphertext)
return pkcs5Unpadding(ciphertext), nil
}
// pkcs5Padding paddes the plaintext to be used in the AESCBC encryption.
func pkcs5Padding(plaintext []byte, blockSize int) []byte {
padding := (blockSize - len(plaintext)%blockSize)
padtext := bytes.Repeat([]byte{byte(padding)}, padding)
return append(plaintext, padtext...)
}
// pkcs5Unpadding undoes the padding to the plaintext after AESCBC decryption.
func pkcs5Unpadding(plaintext []byte) []byte {
length := len(plaintext)
unpadding := int(plaintext[length-1])
return plaintext[:(length - unpadding)]
}

View File

@@ -0,0 +1,2 @@
// Package crpyto provides the nessesary encryption methods for olm/megolm
package crypto

View File

@@ -1,2 +0,0 @@
// crpyto provides the nessesary encryption methods for olm/megolm
package crypto

View File

@@ -21,8 +21,6 @@ var (
ErrChainTooHigh = errors.New("chain index too high") ErrChainTooHigh = errors.New("chain index too high")
ErrBadInput = errors.New("bad input") ErrBadInput = errors.New("bad input")
ErrBadVersion = errors.New("wrong version") ErrBadVersion = errors.New("wrong version")
ErrNotBlocksize = errors.New("length != blocksize")
ErrNotMultipleBlocksize = errors.New("length not a multiple of the blocksize")
ErrWrongPickleVersion = errors.New("wrong pickle version") ErrWrongPickleVersion = errors.New("wrong pickle version")
ErrValueTooShort = errors.New("value too short") ErrValueTooShort = errors.New("value too short")
ErrInputToSmall = errors.New("input too small (truncated?)") ErrInputToSmall = errors.New("input too small (truncated?)")

View File

@@ -45,8 +45,8 @@ func NewDecryptionFromPrivate(privateKey crypto.Curve25519PrivateKey) (*Decrypti
return s, nil return s, nil
} }
// PubKey returns the public key base 64 encoded. // PublicKey returns the public key base 64 encoded.
func (s Decryption) PubKey() id.Curve25519 { func (s Decryption) PublicKey() id.Curve25519 {
return s.KeyPair.B64Encoded() return s.KeyPair.B64Encoded()
} }

View File

@@ -2,7 +2,11 @@ package pk
import ( import (
"crypto/rand" "crypto/rand"
"encoding/json"
"github.com/tidwall/sjson"
"maunium.net/go/mautrix/crypto/canonicaljson"
"maunium.net/go/mautrix/crypto/goolm" "maunium.net/go/mautrix/crypto/goolm"
"maunium.net/go/mautrix/crypto/goolm/crypto" "maunium.net/go/mautrix/crypto/goolm/crypto"
"maunium.net/go/mautrix/id" "maunium.net/go/mautrix/id"
@@ -10,15 +14,15 @@ import (
// Signing is used for signing a pk // Signing is used for signing a pk
type Signing struct { type Signing struct {
KeyPair crypto.Ed25519KeyPair `json:"key_pair"` keyPair crypto.Ed25519KeyPair
Seed []byte `json:"seed"` seed []byte
} }
// NewSigningFromSeed constructs a new Signing based on a seed. // NewSigningFromSeed constructs a new Signing based on a seed.
func NewSigningFromSeed(seed []byte) (*Signing, error) { func NewSigningFromSeed(seed []byte) (*Signing, error) {
s := &Signing{} s := &Signing{}
s.Seed = seed s.seed = seed
s.KeyPair = crypto.Ed25519GenerateFromSeed(seed) s.keyPair = crypto.Ed25519GenerateFromSeed(seed)
return s, nil return s, nil
} }
@@ -32,13 +36,34 @@ func NewSigning() (*Signing, error) {
return NewSigningFromSeed(seed) return NewSigningFromSeed(seed)
} }
// Sign returns the signature of the message base64 encoded. // Seed returns the seed of the key pair.
func (s Signing) Sign(message []byte) []byte { func (s Signing) Seed() []byte {
signature := s.KeyPair.Sign(message) return s.seed
return goolm.Base64Encode(signature)
} }
// PublicKey returns the public key of the key pair base 64 encoded. // PublicKey returns the public key of the key pair base 64 encoded.
func (s Signing) PublicKey() id.Ed25519 { func (s Signing) PublicKey() id.Ed25519 {
return s.KeyPair.B64Encoded() return s.keyPair.B64Encoded()
}
// Sign returns the signature of the message base64 encoded.
func (s Signing) Sign(message []byte) ([]byte, error) {
signature := s.keyPair.Sign(message)
return goolm.Base64Encode(signature), nil
}
// SignJSON creates a signature for the given object after encoding it to
// canonical JSON.
func (s Signing) SignJSON(obj any) (string, error) {
objJSON, err := json.Marshal(obj)
if err != nil {
return "", err
}
objJSON, _ = sjson.DeleteBytes(objJSON, "unsigned")
objJSON, _ = sjson.DeleteBytes(objJSON, "signatures")
signature, err := s.Sign(canonicaljson.CanonicalJSONAssumeValid(objJSON))
if err != nil {
return "", err
}
return string(signature), nil
} }

View File

@@ -1,76 +0,0 @@
// sas provides the means to do SAS between keys
package sas
import (
"io"
"maunium.net/go/mautrix/crypto/goolm"
"maunium.net/go/mautrix/crypto/goolm/crypto"
)
// SAS contains the key pair and secret for SAS.
type SAS struct {
KeyPair crypto.Curve25519KeyPair
Secret []byte
}
// New creates a new SAS with a new key pair.
func New() (*SAS, error) {
kp, err := crypto.Curve25519GenerateKey(nil)
if err != nil {
return nil, err
}
s := &SAS{
KeyPair: kp,
}
return s, nil
}
// GetPubkey returns the public key of the key pair base64 encoded
func (s SAS) GetPubkey() []byte {
return goolm.Base64Encode(s.KeyPair.PublicKey)
}
// SetTheirKey sets the key of the other party and computes the shared secret.
func (s *SAS) SetTheirKey(key []byte) error {
keyDecoded, err := goolm.Base64Decode(key)
if err != nil {
return err
}
sharedSecret, err := s.KeyPair.SharedSecret(keyDecoded)
if err != nil {
return err
}
s.Secret = sharedSecret
return nil
}
// GenerateBytes creates length bytes from the shared secret and info.
func (s SAS) GenerateBytes(info []byte, length uint) ([]byte, error) {
byteReader := crypto.HKDFSHA256(s.Secret, nil, info)
output := make([]byte, length)
if _, err := io.ReadFull(byteReader, output); err != nil {
return nil, err
}
return output, nil
}
// calculateMAC returns a base64 encoded MAC of input.
func (s *SAS) calculateMAC(input, info []byte, length uint) ([]byte, error) {
key, err := s.GenerateBytes(info, length)
if err != nil {
return nil, err
}
mac := crypto.HMACSHA256(key, input)
return goolm.Base64Encode(mac), nil
}
// CalculateMACFixes returns a base64 encoded, 32 byte long MAC of input.
func (s SAS) CalculateMAC(input, info []byte) ([]byte, error) {
return s.calculateMAC(input, info, 32)
}
// CalculateMACLongKDF returns a base64 encoded, 256 byte long MAC of input.
func (s SAS) CalculateMACLongKDF(input, info []byte) ([]byte, error) {
return s.calculateMAC(input, info, 256)
}

View File

@@ -0,0 +1,3 @@
// Package session provides the different types of sessions for en/decrypting
// of messages
package session

View File

@@ -1,2 +0,0 @@
// session provides the different types of sessions for en/decrypting of messages
package session

View File

@@ -1,23 +0,0 @@
package utilities
import (
"encoding/base64"
"maunium.net/go/mautrix/crypto/goolm"
"maunium.net/go/mautrix/crypto/goolm/crypto"
"maunium.net/go/mautrix/id"
)
// VerifySignature verifies an ed25519 signature.
func VerifySignature(message []byte, key id.Ed25519, signature []byte) (ok bool, err error) {
keyDecoded, err := base64.RawStdEncoding.DecodeString(string(key))
if err != nil {
return false, err
}
signatureDecoded, err := goolm.Base64Decode(signature)
if err != nil {
return false, err
}
publicKey := crypto.Ed25519PublicKey(keyDecoded)
return publicKey.Verify(message, signatureDecoded), nil
}

184
vendor/maunium.net/go/mautrix/crypto/keybackup.go generated vendored Normal file
View File

@@ -0,0 +1,184 @@
package crypto
import (
"context"
"fmt"
"time"
"github.com/rs/zerolog"
"maunium.net/go/mautrix"
"maunium.net/go/mautrix/crypto/backup"
"maunium.net/go/mautrix/crypto/olm"
"maunium.net/go/mautrix/crypto/signatures"
"maunium.net/go/mautrix/id"
)
func (mach *OlmMachine) DownloadAndStoreLatestKeyBackup(ctx context.Context, megolmBackupKey *backup.MegolmBackupKey) (id.KeyBackupVersion, error) {
log := mach.machOrContextLog(ctx).With().
Str("action", "download and store latest key backup").
Logger()
ctx = log.WithContext(ctx)
versionInfo, err := mach.GetAndVerifyLatestKeyBackupVersion(ctx)
if err != nil {
return "", err
} else if versionInfo == nil {
return "", nil
}
err = mach.GetAndStoreKeyBackup(ctx, versionInfo.Version, megolmBackupKey)
return versionInfo.Version, err
}
func (mach *OlmMachine) GetAndVerifyLatestKeyBackupVersion(ctx context.Context) (*mautrix.RespRoomKeysVersion[backup.MegolmAuthData], error) {
versionInfo, err := mach.Client.GetKeyBackupLatestVersion(ctx)
if err != nil {
return nil, err
}
if versionInfo.Algorithm != id.KeyBackupAlgorithmMegolmBackupV1 {
return nil, fmt.Errorf("unsupported key backup algorithm: %s", versionInfo.Algorithm)
}
log := mach.machOrContextLog(ctx).With().
Int("count", versionInfo.Count).
Str("etag", versionInfo.ETag).
Stringer("key_backup_version", versionInfo.Version).
Logger()
userSignatures, ok := versionInfo.AuthData.Signatures[mach.Client.UserID]
if !ok {
return nil, fmt.Errorf("no signature from user %s found in key backup", mach.Client.UserID)
}
crossSigningPubkeys := mach.GetOwnCrossSigningPublicKeys(ctx)
signatureVerified := false
for keyID := range userSignatures {
keyAlg, keyName := keyID.Parse()
if keyAlg != id.KeyAlgorithmEd25519 {
continue
}
log := log.With().Str("key_name", keyName).Logger()
var key id.Ed25519
if keyName == crossSigningPubkeys.MasterKey.String() {
key = crossSigningPubkeys.MasterKey
} else if device, err := mach.GetOrFetchDevice(ctx, mach.Client.UserID, id.DeviceID(keyName)); err != nil {
log.Warn().Err(err).Msg("Failed to fetch device")
continue
} else if !mach.IsDeviceTrusted(device) {
log.Warn().Err(err).Msg("Device is not trusted")
continue
} else {
key = device.SigningKey
}
ok, err = signatures.VerifySignatureJSON(versionInfo.AuthData, mach.Client.UserID, keyName, key)
if err != nil || !ok {
log.Warn().Err(err).Stringer("key_id", keyID).Msg("Signature verification failed")
continue
} else {
// One of the signatures is valid, break from the loop.
signatureVerified = true
break
}
}
if !signatureVerified {
return nil, fmt.Errorf("no valid signature from user %s found in key backup", mach.Client.UserID)
}
return versionInfo, nil
}
func (mach *OlmMachine) GetAndStoreKeyBackup(ctx context.Context, version id.KeyBackupVersion, megolmBackupKey *backup.MegolmBackupKey) error {
keys, err := mach.Client.GetKeyBackup(ctx, version)
if err != nil {
return err
}
log := zerolog.Ctx(ctx)
var count, failedCount int
for roomID, backup := range keys.Rooms {
for sessionID, keyBackupData := range backup.Sessions {
sessionData, err := keyBackupData.SessionData.Decrypt(megolmBackupKey)
if err != nil {
log.Warn().Err(err).Msg("Failed to decrypt session data")
failedCount++
continue
}
err = mach.ImportRoomKeyFromBackup(ctx, version, roomID, sessionID, sessionData)
if err != nil {
log.Warn().Err(err).Msg("Failed to import room key from backup")
failedCount++
continue
}
count++
}
}
log.Info().
Int("count", count).
Int("failed_count", failedCount).
Msg("successfully imported sessions from backup")
return nil
}
func (mach *OlmMachine) ImportRoomKeyFromBackup(ctx context.Context, version id.KeyBackupVersion, roomID id.RoomID, sessionID id.SessionID, keyBackupData *backup.MegolmSessionData) error {
log := zerolog.Ctx(ctx).With().
Str("room_id", roomID.String()).
Str("session_id", sessionID.String()).
Logger()
if keyBackupData.Algorithm != id.AlgorithmMegolmV1 {
return fmt.Errorf("ignoring room key in backup with weird algorithm %s", keyBackupData.Algorithm)
}
igsInternal, err := olm.InboundGroupSessionImport([]byte(keyBackupData.SessionKey))
if err != nil {
return fmt.Errorf("failed to import inbound group session: %w", err)
} else if igsInternal.ID() != sessionID {
log.Warn().
Stringer("actual_session_id", igsInternal.ID()).
Msg("Mismatched session ID while creating inbound group session from key backup")
return fmt.Errorf("mismatched session ID while creating inbound group session from key backup")
}
var maxAge time.Duration
var maxMessages int
if config, err := mach.StateStore.GetEncryptionEvent(ctx, roomID); err != nil {
log.Error().Err(err).Msg("Failed to get encryption event for room")
} else if config != nil {
maxAge = time.Duration(config.RotationPeriodMillis) * time.Millisecond
maxMessages = config.RotationPeriodMessages
}
if firstKnownIndex := igsInternal.FirstKnownIndex(); firstKnownIndex > 0 {
log.Warn().Uint32("first_known_index", firstKnownIndex).Msg("Importing partial session")
}
igs := &InboundGroupSession{
Internal: *igsInternal,
SigningKey: keyBackupData.SenderClaimedKeys.Ed25519,
SenderKey: keyBackupData.SenderKey,
RoomID: roomID,
ForwardingChains: append(keyBackupData.ForwardingKeyChain, keyBackupData.SenderKey.String()),
id: sessionID,
ReceivedAt: time.Now().UTC(),
MaxAge: maxAge.Milliseconds(),
MaxMessages: maxMessages,
KeyBackupVersion: version,
}
err = mach.CryptoStore.PutGroupSession(ctx, roomID, keyBackupData.SenderKey, sessionID, igs)
if err != nil {
return fmt.Errorf("failed to store new inbound group session: %w", err)
}
mach.markSessionReceived(ctx, sessionID)
return nil
}

View File

@@ -122,7 +122,7 @@ func (mach *OlmMachine) importExportedRoomKey(ctx context.Context, session Expor
if err != nil { if err != nil {
return false, fmt.Errorf("failed to store imported session: %w", err) return false, fmt.Errorf("failed to store imported session: %w", err)
} }
mach.markSessionReceived(igs.ID()) mach.markSessionReceived(ctx, igs.ID())
return true, nil return true, nil
} }

View File

@@ -168,6 +168,9 @@ func (mach *OlmMachine) importForwardedRoomKey(ctx context.Context, evt *Decrypt
if content.MaxMessages != 0 { if content.MaxMessages != 0 {
maxMessages = content.MaxMessages maxMessages = content.MaxMessages
} }
if firstKnownIndex := igsInternal.FirstKnownIndex(); firstKnownIndex > 0 {
log.Warn().Uint32("first_known_index", firstKnownIndex).Msg("Importing partial session")
}
igs := &InboundGroupSession{ igs := &InboundGroupSession{
Internal: *igsInternal, Internal: *igsInternal,
SigningKey: evt.Keys.Ed25519, SigningKey: evt.Keys.Ed25519,
@@ -186,7 +189,7 @@ func (mach *OlmMachine) importForwardedRoomKey(ctx context.Context, evt *Decrypt
log.Error().Err(err).Msg("Failed to store new inbound group session") log.Error().Err(err).Msg("Failed to store new inbound group session")
return false return false
} }
mach.markSessionReceived(content.SessionID) mach.markSessionReceived(ctx, content.SessionID)
log.Debug().Msg("Received forwarded inbound group session") log.Debug().Msg("Received forwarded inbound group session")
return true return true
} }
@@ -222,11 +225,34 @@ func (mach *OlmMachine) rejectKeyRequest(ctx context.Context, rejection KeyShare
} }
} }
func (mach *OlmMachine) defaultAllowKeyShare(ctx context.Context, device *id.Device, _ event.RequestedKeyInfo) *KeyShareRejection { // sendToOneDevice sends a to-device event to a single device.
func (mach *OlmMachine) sendToOneDevice(ctx context.Context, userID id.UserID, deviceID id.DeviceID, eventType event.Type, content interface{}) error {
_, err := mach.Client.SendToDevice(ctx, eventType, &mautrix.ReqSendToDevice{
Messages: map[id.UserID]map[id.DeviceID]*event.Content{
userID: {
deviceID: {
Parsed: content,
},
},
},
})
return err
}
func (mach *OlmMachine) defaultAllowKeyShare(ctx context.Context, device *id.Device, evt event.RequestedKeyInfo) *KeyShareRejection {
log := mach.machOrContextLog(ctx) log := mach.machOrContextLog(ctx)
if mach.Client.UserID != device.UserID { if mach.Client.UserID != device.UserID {
log.Debug().Msg("Rejecting key request from a different user") isShared, err := mach.CryptoStore.IsOutboundGroupSessionShared(ctx, device.UserID, device.IdentityKey, evt.SessionID)
if err != nil {
log.Err(err).Msg("Rejecting key request due to internal error when checking session sharing")
return &KeyShareRejectNoResponse
} else if !isShared {
log.Debug().Msg("Rejecting key request for unshared session")
return &KeyShareRejectOtherUser return &KeyShareRejectOtherUser
}
log.Debug().Msg("Accepting key request for shared session")
return nil
} else if mach.Client.DeviceID == device.DeviceID { } else if mach.Client.DeviceID == device.DeviceID {
log.Debug().Msg("Ignoring key request from ourselves") log.Debug().Msg("Ignoring key request from ourselves")
return &KeyShareRejectNoResponse return &KeyShareRejectNoResponse
@@ -248,7 +274,7 @@ func (mach *OlmMachine) defaultAllowKeyShare(ctx context.Context, device *id.Dev
} }
} }
func (mach *OlmMachine) handleRoomKeyRequest(ctx context.Context, sender id.UserID, content *event.RoomKeyRequestEventContent) { func (mach *OlmMachine) HandleRoomKeyRequest(ctx context.Context, sender id.UserID, content *event.RoomKeyRequestEventContent) {
log := zerolog.Ctx(ctx).With(). log := zerolog.Ctx(ctx).With().
Str("request_id", content.RequestID). Str("request_id", content.RequestID).
Str("device_id", content.RequestingDeviceID.String()). Str("device_id", content.RequestingDeviceID.String()).
@@ -327,7 +353,7 @@ func (mach *OlmMachine) handleRoomKeyRequest(ctx context.Context, sender id.User
} }
} }
func (mach *OlmMachine) handleBeeperRoomKeyAck(ctx context.Context, sender id.UserID, content *event.BeeperRoomKeyAckEventContent) { func (mach *OlmMachine) HandleBeeperRoomKeyAck(ctx context.Context, sender id.UserID, content *event.BeeperRoomKeyAckEventContent) {
log := mach.machOrContextLog(ctx).With(). log := mach.machOrContextLog(ctx).With().
Str("room_id", content.RoomID.String()). Str("room_id", content.RoomID.String()).
Str("session_id", content.SessionID.String()). Str("session_id", content.SessionID.String()).

View File

@@ -33,18 +33,17 @@ type OlmMachine struct {
PlaintextMentions bool PlaintextMentions bool
// Never ask the server for keys automatically as a side effect. // Never ask the server for keys automatically as a side effect during Megolm decryption.
DisableKeyFetching bool DisableDecryptKeyFetching bool
// Don't mark outbound Olm sessions as shared for devices they were initially sent to.
DisableSharedGroupSessionTracking bool
SendKeysMinTrust id.TrustState SendKeysMinTrust id.TrustState
ShareKeysMinTrust id.TrustState ShareKeysMinTrust id.TrustState
AllowKeyShare func(context.Context, *id.Device, event.RequestedKeyInfo) *KeyShareRejection AllowKeyShare func(context.Context, *id.Device, event.RequestedKeyInfo) *KeyShareRejection
DefaultSASTimeout time.Duration
// AcceptVerificationFrom determines whether the machine will accept verification requests from this device.
AcceptVerificationFrom func(string, *id.Device, id.RoomID) (VerificationRequestResponse, VerificationHooks)
account *OlmAccount account *OlmAccount
roomKeyRequestFilled *sync.Map roomKeyRequestFilled *sync.Map
@@ -53,6 +52,9 @@ type OlmMachine struct {
keyWaiters map[id.SessionID]chan struct{} keyWaiters map[id.SessionID]chan struct{}
keyWaitersLock sync.Mutex keyWaitersLock sync.Mutex
// Optional callback which is called when we save a session to store
SessionReceived func(context.Context, id.SessionID)
devicesToUnwedge map[id.IdentityKey]bool devicesToUnwedge map[id.IdentityKey]bool
devicesToUnwedgeLock sync.Mutex devicesToUnwedgeLock sync.Mutex
recentlyUnwedged map[id.IdentityKey]time.Time recentlyUnwedged map[id.IdentityKey]time.Time
@@ -78,6 +80,9 @@ type OlmMachine struct {
DeleteKeysOnDeviceDelete bool DeleteKeysOnDeviceDelete bool
DisableDeviceChangeKeyRotation bool DisableDeviceChangeKeyRotation bool
secretLock sync.Mutex
secretListeners map[string]chan<- string
} }
// StateStore is used by OlmMachine to get room state information that's needed for encryption. // StateStore is used by OlmMachine to get room state information that's needed for encryption.
@@ -106,12 +111,6 @@ func NewOlmMachine(client *mautrix.Client, log *zerolog.Logger, cryptoStore Stor
SendKeysMinTrust: id.TrustStateUnset, SendKeysMinTrust: id.TrustStateUnset,
ShareKeysMinTrust: id.TrustStateCrossSignedTOFU, ShareKeysMinTrust: id.TrustStateCrossSignedTOFU,
DefaultSASTimeout: 10 * time.Minute,
AcceptVerificationFrom: func(string, *id.Device, id.RoomID) (VerificationRequestResponse, VerificationHooks) {
// Reject requests by default. Users need to override this to return appropriate verification hooks.
return RejectRequest, nil
},
roomKeyRequestFilled: &sync.Map{}, roomKeyRequestFilled: &sync.Map{},
keyVerificationTransactionState: &sync.Map{}, keyVerificationTransactionState: &sync.Map{},
@@ -119,6 +118,7 @@ func NewOlmMachine(client *mautrix.Client, log *zerolog.Logger, cryptoStore Stor
devicesToUnwedge: make(map[id.IdentityKey]bool), devicesToUnwedge: make(map[id.IdentityKey]bool),
recentlyUnwedged: make(map[id.IdentityKey]time.Time), recentlyUnwedged: make(map[id.IdentityKey]time.Time),
secretListeners: make(map[string]chan<- string),
} }
mach.AllowKeyShare = mach.defaultAllowKeyShare mach.AllowKeyShare = mach.defaultAllowKeyShare
return mach return mach
@@ -145,11 +145,21 @@ func (mach *OlmMachine) Load(ctx context.Context) (err error) {
return nil return nil
} }
func (mach *OlmMachine) saveAccount(ctx context.Context) { func (mach *OlmMachine) saveAccount(ctx context.Context) error {
err := mach.CryptoStore.PutAccount(ctx, mach.account) err := mach.CryptoStore.PutAccount(ctx, mach.account)
if err != nil { if err != nil {
mach.Log.Error().Err(err).Msg("Failed to save account") mach.Log.Error().Err(err).Msg("Failed to save account")
} }
return err
}
func (mach *OlmMachine) KeyBackupVersion() id.KeyBackupVersion {
return mach.account.KeyBackupVersion
}
func (mach *OlmMachine) SetKeyBackupVersion(ctx context.Context, version id.KeyBackupVersion) error {
mach.account.KeyBackupVersion = version
return mach.saveAccount(ctx)
} }
// FlushStore calls the Flush method of the CryptoStore. // FlushStore calls the Flush method of the CryptoStore.
@@ -227,11 +237,7 @@ func (mach *OlmMachine) HandleDeviceLists(ctx context.Context, dl *mautrix.Devic
Str("trace_id", traceID). Str("trace_id", traceID).
Interface("changes", dl.Changed). Interface("changes", dl.Changed).
Msg("Device list changes in /sync") Msg("Device list changes in /sync")
if mach.DisableKeyFetching {
mach.CryptoStore.MarkTrackedUsersOutdated(ctx, dl.Changed)
} else {
mach.FetchKeys(ctx, dl.Changed, false) mach.FetchKeys(ctx, dl.Changed, false)
}
mach.Log.Debug().Str("trace_id", traceID).Msg("Finished handling device list changes") mach.Log.Debug().Str("trace_id", traceID).Msg("Finished handling device list changes")
} }
} }
@@ -328,6 +334,47 @@ func (mach *OlmMachine) HandleMemberEvent(ctx context.Context, evt *event.Event)
} }
} }
func (mach *OlmMachine) HandleEncryptedEvent(ctx context.Context, evt *event.Event) {
if _, ok := evt.Content.Parsed.(*event.EncryptedEventContent); !ok {
mach.machOrContextLog(ctx).Warn().Msg("Passed invalid event to encrypted handler")
return
}
decryptedEvt, err := mach.decryptOlmEvent(ctx, evt)
if err != nil {
mach.machOrContextLog(ctx).Error().Err(err).Msg("Failed to decrypt to-device event")
return
}
log := mach.machOrContextLog(ctx).With().
Str("decrypted_type", decryptedEvt.Type.Type).
Str("sender_device", decryptedEvt.SenderDevice.String()).
Str("sender_signing_key", decryptedEvt.Keys.Ed25519.String()).
Logger()
log.Trace().Msg("Successfully decrypted to-device event")
switch decryptedContent := decryptedEvt.Content.Parsed.(type) {
case *event.RoomKeyEventContent:
mach.receiveRoomKey(ctx, decryptedEvt, decryptedContent)
log.Trace().Msg("Handled room key event")
case *event.ForwardedRoomKeyEventContent:
if mach.importForwardedRoomKey(ctx, decryptedEvt, decryptedContent) {
if ch, ok := mach.roomKeyRequestFilled.Load(decryptedContent.SessionID); ok {
// close channel to notify listener that the key was received
close(ch.(chan struct{}))
}
}
log.Trace().Msg("Handled forwarded room key event")
case *event.DummyEventContent:
log.Debug().Msg("Received encrypted dummy event")
case *event.SecretSendEventContent:
mach.receiveSecret(ctx, decryptedEvt, decryptedContent)
log.Trace().Msg("Handled secret send event")
default:
log.Debug().Msg("Unhandled encrypted to-device event")
}
}
// HandleToDeviceEvent handles a single to-device event. This is automatically called by ProcessSyncResponse, so you // HandleToDeviceEvent handles a single to-device event. This is automatically called by ProcessSyncResponse, so you
// don't need to add any custom handlers if you use that method. // don't need to add any custom handlers if you use that method.
func (mach *OlmMachine) HandleToDeviceEvent(ctx context.Context, evt *event.Event) { func (mach *OlmMachine) HandleToDeviceEvent(ctx context.Context, evt *event.Event) {
@@ -352,60 +399,19 @@ func (mach *OlmMachine) HandleToDeviceEvent(ctx context.Context, evt *event.Even
} }
switch content := evt.Content.Parsed.(type) { switch content := evt.Content.Parsed.(type) {
case *event.EncryptedEventContent: case *event.EncryptedEventContent:
log = log.With(). mach.HandleEncryptedEvent(ctx, evt)
Str("sender_key", content.SenderKey.String()).
Logger()
log.Debug().Msg("Handling encrypted to-device event")
ctx = log.WithContext(ctx)
decryptedEvt, err := mach.decryptOlmEvent(ctx, evt)
if err != nil {
log.Error().Err(err).Msg("Failed to decrypt to-device event")
return
}
log = log.With().
Str("decrypted_type", decryptedEvt.Type.Type).
Str("sender_device", decryptedEvt.SenderDevice.String()).
Str("sender_signing_key", decryptedEvt.Keys.Ed25519.String()).
Logger()
log.Trace().Msg("Successfully decrypted to-device event")
switch decryptedContent := decryptedEvt.Content.Parsed.(type) {
case *event.RoomKeyEventContent:
mach.receiveRoomKey(ctx, decryptedEvt, decryptedContent)
log.Trace().Msg("Handled room key event")
case *event.ForwardedRoomKeyEventContent:
if mach.importForwardedRoomKey(ctx, decryptedEvt, decryptedContent) {
if ch, ok := mach.roomKeyRequestFilled.Load(decryptedContent.SessionID); ok {
// close channel to notify listener that the key was received
close(ch.(chan struct{}))
}
}
log.Trace().Msg("Handled forwarded room key event")
case *event.DummyEventContent:
log.Debug().Msg("Received encrypted dummy event")
default:
log.Debug().Msg("Unhandled encrypted to-device event")
}
return return
case *event.RoomKeyRequestEventContent: case *event.RoomKeyRequestEventContent:
go mach.handleRoomKeyRequest(ctx, evt.Sender, content) go mach.HandleRoomKeyRequest(ctx, evt.Sender, content)
case *event.BeeperRoomKeyAckEventContent: case *event.BeeperRoomKeyAckEventContent:
mach.handleBeeperRoomKeyAck(ctx, evt.Sender, content) mach.HandleBeeperRoomKeyAck(ctx, evt.Sender, content)
// verification cases
case *event.VerificationStartEventContent:
mach.handleVerificationStart(ctx, evt.Sender, content, content.TransactionID, 10*time.Minute, "")
case *event.VerificationAcceptEventContent:
mach.handleVerificationAccept(ctx, evt.Sender, content, content.TransactionID)
case *event.VerificationKeyEventContent:
mach.handleVerificationKey(ctx, evt.Sender, content, content.TransactionID)
case *event.VerificationMacEventContent:
mach.handleVerificationMAC(ctx, evt.Sender, content, content.TransactionID)
case *event.VerificationCancelEventContent:
mach.handleVerificationCancel(evt.Sender, content, content.TransactionID)
case *event.VerificationRequestEventContent:
mach.handleVerificationRequest(ctx, evt.Sender, content, content.TransactionID, "")
case *event.RoomKeyWithheldEventContent: case *event.RoomKeyWithheldEventContent:
mach.handleRoomKeyWithheld(ctx, content) mach.HandleRoomKeyWithheld(ctx, content)
case *event.SecretRequestEventContent:
if content.Action == event.SecretRequestRequest {
mach.HandleSecretRequest(ctx, evt.Sender, content)
log.Trace().Msg("Handled secret request event")
}
default: default:
deviceID, _ := evt.Content.Raw["device_id"].(string) deviceID, _ := evt.Content.Raw["device_id"].(string)
log.Debug().Str("maybe_device_id", deviceID).Msg("Unhandled to-device event") log.Debug().Str("maybe_device_id", deviceID).Msg("Unhandled to-device event")
@@ -420,7 +426,7 @@ func (mach *OlmMachine) GetOrFetchDevice(ctx context.Context, userID id.UserID,
device, err := mach.CryptoStore.GetDevice(ctx, userID, deviceID) device, err := mach.CryptoStore.GetDevice(ctx, userID, deviceID)
if err != nil { if err != nil {
return nil, fmt.Errorf("failed to get sender device from store: %w", err) return nil, fmt.Errorf("failed to get sender device from store: %w", err)
} else if device != nil || mach.DisableKeyFetching { } else if device != nil {
return device, nil return device, nil
} }
if usersToDevices, err := mach.FetchKeys(ctx, []id.UserID{userID}, true); err != nil { if usersToDevices, err := mach.FetchKeys(ctx, []id.UserID{userID}, true); err != nil {
@@ -439,7 +445,7 @@ func (mach *OlmMachine) GetOrFetchDevice(ctx context.Context, userID id.UserID,
// the given identity key. // the given identity key.
func (mach *OlmMachine) GetOrFetchDeviceByKey(ctx context.Context, userID id.UserID, identityKey id.IdentityKey) (*id.Device, error) { func (mach *OlmMachine) GetOrFetchDeviceByKey(ctx context.Context, userID id.UserID, identityKey id.IdentityKey) (*id.Device, error) {
deviceIdentity, err := mach.CryptoStore.FindDeviceByKey(ctx, userID, identityKey) deviceIdentity, err := mach.CryptoStore.FindDeviceByKey(ctx, userID, identityKey)
if err != nil || deviceIdentity != nil || mach.DisableKeyFetching { if err != nil || deviceIdentity != nil {
return deviceIdentity, err return deviceIdentity, err
} }
mach.machOrContextLog(ctx).Debug(). mach.machOrContextLog(ctx).Debug().
@@ -517,7 +523,7 @@ func (mach *OlmMachine) createGroupSession(ctx context.Context, senderKey id.Sen
log.Error().Err(err).Str("session_id", sessionID.String()).Msg("Failed to store new inbound group session") log.Error().Err(err).Str("session_id", sessionID.String()).Msg("Failed to store new inbound group session")
return return
} }
mach.markSessionReceived(sessionID) mach.markSessionReceived(ctx, sessionID)
log.Debug(). log.Debug().
Str("session_id", sessionID.String()). Str("session_id", sessionID.String()).
Str("sender_key", senderKey.String()). Str("sender_key", senderKey.String()).
@@ -527,7 +533,11 @@ func (mach *OlmMachine) createGroupSession(ctx context.Context, senderKey id.Sen
Msg("Received inbound group session") Msg("Received inbound group session")
} }
func (mach *OlmMachine) markSessionReceived(id id.SessionID) { func (mach *OlmMachine) markSessionReceived(ctx context.Context, id id.SessionID) {
if mach.SessionReceived != nil {
mach.SessionReceived(ctx, id)
}
mach.keyWaitersLock.Lock() mach.keyWaitersLock.Lock()
ch, ok := mach.keyWaiters[id] ch, ok := mach.keyWaiters[id]
if ok { if ok {
@@ -619,7 +629,7 @@ func (mach *OlmMachine) receiveRoomKey(ctx context.Context, evt *DecryptedOlmEve
mach.createGroupSession(ctx, evt.SenderKey, evt.Keys.Ed25519, content.RoomID, content.SessionID, content.SessionKey, maxAge, maxMessages, content.IsScheduled) mach.createGroupSession(ctx, evt.SenderKey, evt.Keys.Ed25519, content.RoomID, content.SessionID, content.SessionKey, maxAge, maxMessages, content.IsScheduled)
} }
func (mach *OlmMachine) handleRoomKeyWithheld(ctx context.Context, content *event.RoomKeyWithheldEventContent) { func (mach *OlmMachine) HandleRoomKeyWithheld(ctx context.Context, content *event.RoomKeyWithheldEventContent) {
if content.Algorithm != id.AlgorithmMegolmV1 { if content.Algorithm != id.AlgorithmMegolmV1 {
zerolog.Ctx(ctx).Debug().Interface("content", content).Msg("Non-megolm room key withheld event") zerolog.Ctx(ctx).Debug().Interface("content", content).Msg("Non-megolm room key withheld event")
return return
@@ -682,8 +692,7 @@ func (mach *OlmMachine) ShareKeys(ctx context.Context, currentOTKCount int) erro
} }
mach.lastOTKUpload = time.Now() mach.lastOTKUpload = time.Now()
mach.account.Shared = true mach.account.Shared = true
mach.saveAccount(ctx) return mach.saveAccount(ctx)
return nil
} }
func (mach *OlmMachine) ExpiredKeyDeleteLoop(ctx context.Context) { func (mach *OlmMachine) ExpiredKeyDeleteLoop(ctx context.Context) {

View File

@@ -1,71 +1,29 @@
// Copyright (c) 2024 Sumner Evans
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
// When the goolm build flag is enabled, this file will make [PKSigning]
// constructors use the goolm constuctors.
//go:build goolm //go:build goolm
package olm package olm
import ( import "maunium.net/go/mautrix/crypto/goolm/pk"
"encoding/json"
"github.com/tidwall/sjson" // NewPKSigningFromSeed creates a new PKSigning object using the given seed.
func NewPKSigningFromSeed(seed []byte) (PKSigning, error) {
"maunium.net/go/mautrix/crypto/canonicaljson" return pk.NewSigningFromSeed(seed)
"maunium.net/go/mautrix/crypto/goolm/pk"
"maunium.net/go/mautrix/id"
)
// PkSigning stores a key pair for signing messages.
type PkSigning struct {
pk.Signing
PublicKey id.Ed25519
Seed []byte
} }
// Clear clears the underlying memory of a PkSigning object. // NewPKSigning creates a new [PKSigning] object, containing a key pair for
func (p *PkSigning) Clear() { // signing messages.
p.Signing = pk.Signing{} func NewPKSigning() (PKSigning, error) {
return pk.NewSigning()
} }
// NewPkSigningFromSeed creates a new PkSigning object using the given seed. func NewPKDecryption(privateKey []byte) (PKDecryption, error) {
func NewPkSigningFromSeed(seed []byte) (*PkSigning, error) { return pk.NewDecryption()
p := &PkSigning{}
signing, err := pk.NewSigningFromSeed(seed)
if err != nil {
return nil, err
}
p.Signing = *signing
p.Seed = seed
p.PublicKey = p.Signing.PublicKey()
return p, nil
}
// NewPkSigning creates a new PkSigning object, containing a key pair for signing messages.
func NewPkSigning() (*PkSigning, error) {
p := &PkSigning{}
signing, err := pk.NewSigning()
if err != nil {
return nil, err
}
p.Signing = *signing
p.Seed = signing.Seed
p.PublicKey = p.Signing.PublicKey()
return p, err
}
// Sign creates a signature for the given message using this key.
func (p *PkSigning) Sign(message []byte) ([]byte, error) {
return p.Signing.Sign(message), nil
}
// SignJSON creates a signature for the given object after encoding it to canonical JSON.
func (p *PkSigning) SignJSON(obj interface{}) (string, error) {
objJSON, err := json.Marshal(obj)
if err != nil {
return "", err
}
objJSON, _ = sjson.DeleteBytes(objJSON, "unsigned")
objJSON, _ = sjson.DeleteBytes(objJSON, "signatures")
signature, err := p.Sign(canonicaljson.CanonicalJSONAssumeValid(objJSON))
if err != nil {
return "", err
}
return string(signature), nil
} }

View File

@@ -0,0 +1,41 @@
// Copyright (c) 2024 Sumner Evans
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
package olm
import (
"maunium.net/go/mautrix/crypto/goolm/pk"
"maunium.net/go/mautrix/id"
)
// PKSigning is an interface for signing messages.
type PKSigning interface {
// Seed returns the seed of the key.
Seed() []byte
// PublicKey returns the public key.
PublicKey() id.Ed25519
// Sign creates a signature for the given message using this key.
Sign(message []byte) ([]byte, error)
// SignJSON creates a signature for the given object after encoding it to
// canonical JSON.
SignJSON(obj any) (string, error)
}
var _ PKSigning = (*pk.Signing)(nil)
// PKDecryption is an interface for decrypting messages.
type PKDecryption interface {
// PublicKey returns the public key.
PublicKey() id.Curve25519
// Decrypt verifies and decrypts the given message.
Decrypt(ciphertext, mac []byte, key id.Curve25519) ([]byte, error)
}
var _ PKDecryption = (*pk.Decryption)(nil)

View File

@@ -1,3 +1,9 @@
// Copyright (c) 2024 Sumner Evans
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
//go:build !goolm //go:build !goolm
package olm package olm
@@ -18,14 +24,17 @@ import (
"maunium.net/go/mautrix/id" "maunium.net/go/mautrix/id"
) )
// PkSigning stores a key pair for signing messages. // LibOlmPKSigning stores a key pair for signing messages.
type PkSigning struct { type LibOlmPKSigning struct {
int *C.OlmPkSigning int *C.OlmPkSigning
mem []byte mem []byte
PublicKey id.Ed25519 publicKey id.Ed25519
Seed []byte seed []byte
} }
// Ensure that LibOlmPKSigning implements PKSigning.
var _ PKSigning = (*LibOlmPKSigning)(nil)
func pkSigningSize() uint { func pkSigningSize() uint {
return uint(C.olm_pk_signing_size()) return uint(C.olm_pk_signing_size())
} }
@@ -42,48 +51,57 @@ func pkSigningSignatureLength() uint {
return uint(C.olm_pk_signature_length()) return uint(C.olm_pk_signature_length())
} }
func NewBlankPkSigning() *PkSigning { func newBlankPKSigning() *LibOlmPKSigning {
memory := make([]byte, pkSigningSize()) memory := make([]byte, pkSigningSize())
return &PkSigning{ return &LibOlmPKSigning{
int: C.olm_pk_signing(unsafe.Pointer(&memory[0])), int: C.olm_pk_signing(unsafe.Pointer(&memory[0])),
mem: memory, mem: memory,
} }
} }
// Clear clears the underlying memory of a PkSigning object. // NewPKSigningFromSeed creates a new [PKSigning] object using the given seed.
func (p *PkSigning) Clear() { func NewPKSigningFromSeed(seed []byte) (PKSigning, error) {
C.olm_clear_pk_signing((*C.OlmPkSigning)(p.int)) p := newBlankPKSigning()
} p.clear()
// NewPkSigningFromSeed creates a new PkSigning object using the given seed.
func NewPkSigningFromSeed(seed []byte) (*PkSigning, error) {
p := NewBlankPkSigning()
p.Clear()
pubKey := make([]byte, pkSigningPublicKeyLength()) pubKey := make([]byte, pkSigningPublicKeyLength())
if C.olm_pk_signing_key_from_seed((*C.OlmPkSigning)(p.int), if C.olm_pk_signing_key_from_seed((*C.OlmPkSigning)(p.int),
unsafe.Pointer(&pubKey[0]), C.size_t(len(pubKey)), unsafe.Pointer(&pubKey[0]), C.size_t(len(pubKey)),
unsafe.Pointer(&seed[0]), C.size_t(len(seed))) == errorVal() { unsafe.Pointer(&seed[0]), C.size_t(len(seed))) == errorVal() {
return nil, p.lastError() return nil, p.lastError()
} }
p.PublicKey = id.Ed25519(pubKey) p.publicKey = id.Ed25519(pubKey)
p.Seed = seed p.seed = seed
return p, nil return p, nil
} }
// NewPkSigning creates a new PkSigning object, containing a key pair for signing messages. // NewPKSigning creates a new LibOlmPKSigning object, containing a key pair for
func NewPkSigning() (*PkSigning, error) { // signing messages.
func NewPKSigning() (PKSigning, error) {
// Generate the seed // Generate the seed
seed := make([]byte, pkSigningSeedLength()) seed := make([]byte, pkSigningSeedLength())
_, err := rand.Read(seed) _, err := rand.Read(seed)
if err != nil { if err != nil {
panic(NotEnoughGoRandom) panic(NotEnoughGoRandom)
} }
pk, err := NewPkSigningFromSeed(seed) pk, err := NewPKSigningFromSeed(seed)
return pk, err return pk, err
} }
func (p *LibOlmPKSigning) PublicKey() id.Ed25519 {
return p.publicKey
}
func (p *LibOlmPKSigning) Seed() []byte {
return p.seed
}
// clear clears the underlying memory of a LibOlmPKSigning object.
func (p *LibOlmPKSigning) clear() {
C.olm_clear_pk_signing((*C.OlmPkSigning)(p.int))
}
// Sign creates a signature for the given message using this key. // Sign creates a signature for the given message using this key.
func (p *PkSigning) Sign(message []byte) ([]byte, error) { func (p *LibOlmPKSigning) Sign(message []byte) ([]byte, error) {
signature := make([]byte, pkSigningSignatureLength()) signature := make([]byte, pkSigningSignatureLength())
if C.olm_pk_sign((*C.OlmPkSigning)(p.int), (*C.uint8_t)(unsafe.Pointer(&message[0])), C.size_t(len(message)), if C.olm_pk_sign((*C.OlmPkSigning)(p.int), (*C.uint8_t)(unsafe.Pointer(&message[0])), C.size_t(len(message)),
(*C.uint8_t)(unsafe.Pointer(&signature[0])), C.size_t(len(signature))) == errorVal() { (*C.uint8_t)(unsafe.Pointer(&signature[0])), C.size_t(len(signature))) == errorVal() {
@@ -93,7 +111,7 @@ func (p *PkSigning) Sign(message []byte) ([]byte, error) {
} }
// SignJSON creates a signature for the given object after encoding it to canonical JSON. // SignJSON creates a signature for the given object after encoding it to canonical JSON.
func (p *PkSigning) SignJSON(obj interface{}) (string, error) { func (p *LibOlmPKSigning) SignJSON(obj interface{}) (string, error) {
objJSON, err := json.Marshal(obj) objJSON, err := json.Marshal(obj)
if err != nil { if err != nil {
return "", err return "", err
@@ -107,12 +125,13 @@ func (p *PkSigning) SignJSON(obj interface{}) (string, error) {
return string(signature), nil return string(signature), nil
} }
// lastError returns the last error that happened in relation to this PkSigning object. // lastError returns the last error that happened in relation to this
func (p *PkSigning) lastError() error { // LibOlmPKSigning object.
func (p *LibOlmPKSigning) lastError() error {
return convertError(C.GoString(C.olm_pk_signing_last_error((*C.OlmPkSigning)(p.int)))) return convertError(C.GoString(C.olm_pk_signing_last_error((*C.OlmPkSigning)(p.int))))
} }
type PkDecryption struct { type LibOlmPKDecryption struct {
int *C.OlmPkDecryption int *C.OlmPkDecryption
mem []byte mem []byte
PublicKey []byte PublicKey []byte
@@ -126,13 +145,13 @@ func pkDecryptionPublicKeySize() uint {
return uint(C.olm_pk_key_length()) return uint(C.olm_pk_key_length())
} }
func NewPkDecryption(privateKey []byte) (*PkDecryption, error) { func NewPkDecryption(privateKey []byte) (*LibOlmPKDecryption, error) {
memory := make([]byte, pkDecryptionSize()) memory := make([]byte, pkDecryptionSize())
p := &PkDecryption{ p := &LibOlmPKDecryption{
int: C.olm_pk_decryption(unsafe.Pointer(&memory[0])), int: C.olm_pk_decryption(unsafe.Pointer(&memory[0])),
mem: memory, mem: memory,
} }
p.Clear() p.clear()
pubKey := make([]byte, pkDecryptionPublicKeySize()) pubKey := make([]byte, pkDecryptionPublicKeySize())
if C.olm_pk_key_from_private((*C.OlmPkDecryption)(p.int), if C.olm_pk_key_from_private((*C.OlmPkDecryption)(p.int),
@@ -145,7 +164,7 @@ func NewPkDecryption(privateKey []byte) (*PkDecryption, error) {
return p, nil return p, nil
} }
func (p *PkDecryption) Decrypt(ephemeralKey []byte, mac []byte, ciphertext []byte) ([]byte, error) { func (p *LibOlmPKDecryption) Decrypt(ephemeralKey []byte, mac []byte, ciphertext []byte) ([]byte, error) {
maxPlaintextLength := uint(C.olm_pk_max_plaintext_length((*C.OlmPkDecryption)(p.int), C.size_t(len(ciphertext)))) maxPlaintextLength := uint(C.olm_pk_max_plaintext_length((*C.OlmPkDecryption)(p.int), C.size_t(len(ciphertext))))
plaintext := make([]byte, maxPlaintextLength) plaintext := make([]byte, maxPlaintextLength)
@@ -162,11 +181,12 @@ func (p *PkDecryption) Decrypt(ephemeralKey []byte, mac []byte, ciphertext []byt
} }
// Clear clears the underlying memory of a PkDecryption object. // Clear clears the underlying memory of a PkDecryption object.
func (p *PkDecryption) Clear() { func (p *LibOlmPKDecryption) clear() {
C.olm_clear_pk_decryption((*C.OlmPkDecryption)(p.int)) C.olm_clear_pk_decryption((*C.OlmPkDecryption)(p.int))
} }
// lastError returns the last error that happened in relation to this PkDecryption object. // lastError returns the last error that happened in relation to this
func (p *PkDecryption) lastError() error { // LibOlmPKDecryption object.
func (p *LibOlmPKDecryption) lastError() error {
return convertError(C.GoString(C.olm_pk_decryption_last_error((*C.OlmPkDecryption)(p.int)))) return convertError(C.GoString(C.olm_pk_decryption_last_error((*C.OlmPkDecryption)(p.int))))
} }

View File

@@ -1,146 +0,0 @@
//go:build !goolm
package olm
// #cgo LDFLAGS: -lolm -lstdc++
// #include <olm/olm.h>
import "C"
import (
"encoding/json"
"fmt"
"unsafe"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
"go.mau.fi/util/exgjson"
"maunium.net/go/mautrix/crypto/canonicaljson"
"maunium.net/go/mautrix/id"
)
// Utility stores the necessary state to perform hash and signature
// verification operations.
type Utility struct {
int *C.OlmUtility
mem []byte
}
// utilitySize returns the size of a utility object in bytes.
func utilitySize() uint {
return uint(C.olm_utility_size())
}
// sha256Len returns the length of the buffer needed to hold the SHA-256 hash.
func (u *Utility) sha256Len() uint {
return uint(C.olm_sha256_length((*C.OlmUtility)(u.int)))
}
// lastError returns an error describing the most recent error to happen to a
// utility.
func (u *Utility) lastError() error {
return convertError(C.GoString(C.olm_utility_last_error((*C.OlmUtility)(u.int))))
}
// Clear clears the memory used to back this utility.
func (u *Utility) Clear() error {
r := C.olm_clear_utility((*C.OlmUtility)(u.int))
if r == errorVal() {
return u.lastError()
}
return nil
}
// NewUtility creates a new utility.
func NewUtility() *Utility {
memory := make([]byte, utilitySize())
return &Utility{
int: C.olm_utility(unsafe.Pointer(&memory[0])),
mem: memory,
}
}
// Sha256 calculates the SHA-256 hash of the input and encodes it as base64.
func (u *Utility) Sha256(input string) string {
if len(input) == 0 {
panic(EmptyInput)
}
output := make([]byte, u.sha256Len())
r := C.olm_sha256(
(*C.OlmUtility)(u.int),
unsafe.Pointer(&([]byte(input)[0])),
C.size_t(len(input)),
unsafe.Pointer(&(output[0])),
C.size_t(len(output)))
if r == errorVal() {
panic(u.lastError())
}
return string(output)
}
// VerifySignature verifies an ed25519 signature. Returns true if the verification
// suceeds or false otherwise. Returns error on failure. If the key was too
// small then the error will be "INVALID_BASE64".
func (u *Utility) VerifySignature(message string, key id.Ed25519, signature string) (ok bool, err error) {
if len(message) == 0 || len(key) == 0 || len(signature) == 0 {
return false, EmptyInput
}
r := C.olm_ed25519_verify(
(*C.OlmUtility)(u.int),
unsafe.Pointer(&([]byte(key)[0])),
C.size_t(len(key)),
unsafe.Pointer(&([]byte(message)[0])),
C.size_t(len(message)),
unsafe.Pointer(&([]byte(signature)[0])),
C.size_t(len(signature)))
if r == errorVal() {
err = u.lastError()
if err == BadMessageMAC {
err = nil
}
} else {
ok = true
}
return ok, err
}
// VerifySignatureJSON verifies the signature in the JSON object _obj following
// the Matrix specification:
// https://matrix.org/speculator/spec/drafts%2Fe2e/appendices.html#signing-json
// If the _obj is a struct, the `json` tags will be honored.
func (u *Utility) VerifySignatureJSON(obj interface{}, userID id.UserID, keyName string, key id.Ed25519) (bool, error) {
var err error
objJSON, ok := obj.(json.RawMessage)
if !ok {
objJSON, err = json.Marshal(obj)
if err != nil {
return false, err
}
}
sig := gjson.GetBytes(objJSON, exgjson.Path("signatures", string(userID), fmt.Sprintf("ed25519:%s", keyName)))
if !sig.Exists() || sig.Type != gjson.String {
return false, SignatureNotFound
}
objJSON, err = sjson.DeleteBytes(objJSON, "unsigned")
if err != nil {
return false, err
}
objJSON, err = sjson.DeleteBytes(objJSON, "signatures")
if err != nil {
return false, err
}
objJSONString := string(canonicaljson.CanonicalJSONAssumeValid(objJSON))
return u.VerifySignature(objJSONString, key, sig.Str)
}
// VerifySignatureJSON verifies the signature in the JSON object _obj following
// the Matrix specification:
// https://matrix.org/speculator/spec/drafts%2Fe2e/appendices.html#signing-json
// This function is a wrapper over Utility.VerifySignatureJSON that creates and
// destroys the Utility object transparently.
// If the _obj is a struct, the `json` tags will be honored.
func VerifySignatureJSON(obj interface{}, userID id.UserID, keyName string, key id.Ed25519) (bool, error) {
u := NewUtility()
defer u.Clear()
return u.VerifySignatureJSON(obj, userID, keyName, key)
}

View File

@@ -1,92 +0,0 @@
//go:build goolm
package olm
import (
"crypto/sha256"
"encoding/base64"
"encoding/json"
"fmt"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
"go.mau.fi/util/exgjson"
"maunium.net/go/mautrix/crypto/canonicaljson"
"maunium.net/go/mautrix/crypto/goolm/utilities"
"maunium.net/go/mautrix/id"
)
// Utility stores the necessary state to perform hash and signature
// verification operations.
type Utility struct{}
// Clear clears the memory used to back this utility.
func (u *Utility) Clear() error {
return nil
}
// NewUtility creates a new utility.
func NewUtility() *Utility {
return &Utility{}
}
// Sha256 calculates the SHA-256 hash of the input and encodes it as base64.
func (u *Utility) Sha256(input string) string {
if len(input) == 0 {
panic(EmptyInput)
}
hash := sha256.Sum256([]byte(input))
return base64.RawStdEncoding.EncodeToString(hash[:])
}
// VerifySignature verifies an ed25519 signature. Returns true if the verification
// suceeds or false otherwise. Returns error on failure. If the key was too
// small then the error will be "INVALID_BASE64".
func (u *Utility) VerifySignature(message string, key id.Ed25519, signature string) (ok bool, err error) {
if len(message) == 0 || len(key) == 0 || len(signature) == 0 {
return false, EmptyInput
}
return utilities.VerifySignature([]byte(message), key, []byte(signature))
}
// VerifySignatureJSON verifies the signature in the JSON object _obj following
// the Matrix specification:
// https://matrix.org/speculator/spec/drafts%2Fe2e/appendices.html#signing-json
// If the _obj is a struct, the `json` tags will be honored.
func (u *Utility) VerifySignatureJSON(obj interface{}, userID id.UserID, keyName string, key id.Ed25519) (bool, error) {
var err error
objJSON, ok := obj.(json.RawMessage)
if !ok {
objJSON, err = json.Marshal(obj)
if err != nil {
return false, err
}
}
sig := gjson.GetBytes(objJSON, exgjson.Path("signatures", string(userID), fmt.Sprintf("ed25519:%s", keyName)))
if !sig.Exists() || sig.Type != gjson.String {
return false, SignatureNotFound
}
objJSON, err = sjson.DeleteBytes(objJSON, "unsigned")
if err != nil {
return false, err
}
objJSON, err = sjson.DeleteBytes(objJSON, "signatures")
if err != nil {
return false, err
}
objJSONString := string(canonicaljson.CanonicalJSONAssumeValid(objJSON))
return u.VerifySignature(objJSONString, key, sig.Str)
}
// VerifySignatureJSON verifies the signature in the JSON object _obj following
// the Matrix specification:
// https://matrix.org/speculator/spec/drafts%2Fe2e/appendices.html#signing-json
// This function is a wrapper over Utility.VerifySignatureJSON that creates and
// destroys the Utility object transparently.
// If the _obj is a struct, the `json` tags will be honored.
func VerifySignatureJSON(obj interface{}, userID id.UserID, keyName string, key id.Ed25519) (bool, error) {
u := NewUtility()
defer u.Clear()
return u.VerifySignatureJSON(obj, userID, keyName, key)
}

View File

@@ -1,142 +0,0 @@
//go:build !nosas && !goolm
package olm
// #cgo LDFLAGS: -lolm -lstdc++
// #include <olm/olm.h>
// #include <olm/sas.h>
import "C"
import (
"crypto/rand"
"unsafe"
)
// SAS stores an Olm Short Authentication String (SAS) object.
type SAS struct {
int *C.OlmSAS
mem []byte
}
// NewBlankSAS initializes an empty SAS object.
func NewBlankSAS() *SAS {
memory := make([]byte, sasSize())
return &SAS{
int: C.olm_sas(unsafe.Pointer(&memory[0])),
mem: memory,
}
}
// sasSize is the size of a SAS object in bytes.
func sasSize() uint {
return uint(C.olm_sas_size())
}
// sasRandomLength is the number of random bytes needed to create an SAS object.
func (sas *SAS) sasRandomLength() uint {
return uint(C.olm_create_sas_random_length(sas.int))
}
// NewSAS creates a new SAS object.
func NewSAS() *SAS {
sas := NewBlankSAS()
random := make([]byte, sas.sasRandomLength()+1)
_, err := rand.Read(random)
if err != nil {
panic(NotEnoughGoRandom)
}
r := C.olm_create_sas(
(*C.OlmSAS)(sas.int),
unsafe.Pointer(&random[0]),
C.size_t(len(random)))
if r == errorVal() {
panic(sas.lastError())
} else {
return sas
}
}
// clear clears the memory used to back an SAS object.
func (sas *SAS) clear() uint {
return uint(C.olm_clear_sas(sas.int))
}
// lastError returns the most recent error to happen to an SAS object.
func (sas *SAS) lastError() error {
return convertError(C.GoString(C.olm_sas_last_error(sas.int)))
}
// pubkeyLength is the size of a public key in bytes.
func (sas *SAS) pubkeyLength() uint {
return uint(C.olm_sas_pubkey_length((*C.OlmSAS)(sas.int)))
}
// GetPubkey gets the public key for the SAS object.
func (sas *SAS) GetPubkey() []byte {
pubkey := make([]byte, sas.pubkeyLength())
r := C.olm_sas_get_pubkey(
(*C.OlmSAS)(sas.int),
unsafe.Pointer(&pubkey[0]),
C.size_t(len(pubkey)))
if r == errorVal() {
panic(sas.lastError())
}
return pubkey
}
// SetTheirKey sets the public key of the other user.
func (sas *SAS) SetTheirKey(theirKey []byte) error {
theirKeyCopy := make([]byte, len(theirKey))
copy(theirKeyCopy, theirKey)
r := C.olm_sas_set_their_key(
(*C.OlmSAS)(sas.int),
unsafe.Pointer(&theirKeyCopy[0]),
C.size_t(len(theirKeyCopy)))
if r == errorVal() {
return sas.lastError()
}
return nil
}
// GenerateBytes generates bytes to use for the short authentication string.
func (sas *SAS) GenerateBytes(info []byte, count uint) ([]byte, error) {
infoCopy := make([]byte, len(info))
copy(infoCopy, info)
output := make([]byte, count)
r := C.olm_sas_generate_bytes(
(*C.OlmSAS)(sas.int),
unsafe.Pointer(&infoCopy[0]),
C.size_t(len(infoCopy)),
unsafe.Pointer(&output[0]),
C.size_t(len(output)))
if r == errorVal() {
return nil, sas.lastError()
}
return output, nil
}
// macLength is the size of a message authentication code generated by olm_sas_calculate_mac.
func (sas *SAS) macLength() uint {
return uint(C.olm_sas_mac_length((*C.OlmSAS)(sas.int)))
}
// CalculateMAC generates a message authentication code (MAC) based on the shared secret.
func (sas *SAS) CalculateMAC(input []byte, info []byte) ([]byte, error) {
inputCopy := make([]byte, len(input))
copy(inputCopy, input)
infoCopy := make([]byte, len(info))
copy(infoCopy, info)
mac := make([]byte, sas.macLength())
r := C.olm_sas_calculate_mac(
(*C.OlmSAS)(sas.int),
unsafe.Pointer(&inputCopy[0]),
C.size_t(len(inputCopy)),
unsafe.Pointer(&infoCopy[0]),
C.size_t(len(infoCopy)),
unsafe.Pointer(&mac[0]),
C.size_t(len(mac)))
if r == errorVal() {
return nil, sas.lastError()
}
return mac, nil
}

View File

@@ -1,23 +0,0 @@
//go:build !nosas && goolm
package olm
import (
"maunium.net/go/mautrix/crypto/goolm/sas"
)
// SAS stores an Olm Short Authentication String (SAS) object.
type SAS struct {
sas.SAS
}
// NewSAS creates a new SAS object.
func NewSAS() *SAS {
newSAS, err := sas.New()
if err != nil {
panic(err)
}
return &SAS{
SAS: *newSAS,
}
}

30
vendor/maunium.net/go/mautrix/crypto/pkcs7/pkcs7.go generated vendored Normal file
View File

@@ -0,0 +1,30 @@
// Copyright (c) 2024 Sumner Evans
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
package pkcs7
import "bytes"
// Pad implements PKCS#7 padding as defined in [RFC2315]. It pads the plaintext
// to the given blockSize in the range [1, 255]. This is normally used in
// AES-CBC encryption.
//
// [RFC2315]: https://www.ietf.org/rfc/rfc2315.txt
func Pad(plaintext []byte, blockSize int) []byte {
padding := blockSize - len(plaintext)%blockSize
return append(plaintext, bytes.Repeat([]byte{byte(padding)}, padding)...)
}
// Unpad implements PKCS#7 unpadding as defined in [RFC2315]. It unpads the
// plaintext by reading the padding amount from the last byte of the plaintext.
// This is normally used in AES-CBC decryption.
//
// [RFC2315]: https://www.ietf.org/rfc/rfc2315.txt
func Unpad(plaintext []byte) []byte {
length := len(plaintext)
unpadding := int(plaintext[length-1])
return plaintext[:length-unpadding]
}

View File

@@ -109,6 +109,7 @@ type InboundGroupSession struct {
MaxAge int64 MaxAge int64
MaxMessages int MaxMessages int
IsScheduled bool IsScheduled bool
KeyBackupVersion id.KeyBackupVersion
id id.SessionID id id.SessionID
} }

191
vendor/maunium.net/go/mautrix/crypto/sharing.go generated vendored Normal file
View File

@@ -0,0 +1,191 @@
// Copyright (c) 2024 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
package crypto
import (
"context"
"time"
"go.mau.fi/util/random"
"maunium.net/go/mautrix/event"
"maunium.net/go/mautrix/id"
)
// Callback function to process a received secret.
//
// Returning true or an error will immediately return from the wait loop, returning false will continue waiting for new responses.
type SecretReceiverFunc func(string) (bool, error)
func (mach *OlmMachine) GetOrRequestSecret(ctx context.Context, name id.Secret, receiver SecretReceiverFunc, timeout time.Duration) (err error) {
ctx, cancel := context.WithTimeout(ctx, timeout)
defer cancel()
// always offer our stored secret first, if any
secret, err := mach.CryptoStore.GetSecret(ctx, name)
if err != nil {
return err
} else if secret != "" {
if ok, err := receiver(secret); ok || err != nil {
return err
}
}
requestID, secretChan := random.String(64), make(chan string, 5)
mach.secretLock.Lock()
mach.secretListeners[requestID] = secretChan
mach.secretLock.Unlock()
defer func() {
mach.secretLock.Lock()
delete(mach.secretListeners, requestID)
mach.secretLock.Unlock()
}()
// request secret from any device
err = mach.sendToOneDevice(ctx, mach.Client.UserID, id.DeviceID("*"), event.ToDeviceSecretRequest, &event.SecretRequestEventContent{
Action: event.SecretRequestRequest,
RequestID: requestID,
Name: name,
RequestingDeviceID: mach.Client.DeviceID,
})
if err != nil {
return
}
// best effort cancel request from all devices when returning
defer func() {
go mach.sendToOneDevice(context.Background(), mach.Client.UserID, id.DeviceID("*"), event.ToDeviceSecretRequest, &event.SecretRequestEventContent{
Action: event.SecretRequestCancellation,
RequestID: requestID,
RequestingDeviceID: mach.Client.DeviceID,
})
}()
for {
select {
case <-ctx.Done():
return ctx.Err()
case secret = <-secretChan:
if ok, err := receiver(secret); err != nil {
return err
} else if ok {
return mach.CryptoStore.PutSecret(ctx, name, secret)
}
}
}
}
func (mach *OlmMachine) HandleSecretRequest(ctx context.Context, userID id.UserID, content *event.SecretRequestEventContent) {
log := mach.machOrContextLog(ctx).With().
Stringer("user_id", userID).
Stringer("requesting_device_id", content.RequestingDeviceID).
Stringer("action", content.Action).
Str("request_id", content.RequestID).
Stringer("secret", content.Name).
Logger()
log.Trace().Msg("Handling secret request")
if content.Action == event.SecretRequestCancellation {
log.Trace().Msg("Secret request cancellation is unimplemented, ignoring")
return
} else if content.Action != event.SecretRequestRequest {
log.Warn().Msg("Ignoring unknown secret request action")
return
}
// immediately ignore requests from other users
if userID != mach.Client.UserID || content.RequestingDeviceID == "" {
log.Debug().Msg("Secret request was not from our own device, ignoring")
return
}
if content.RequestingDeviceID == mach.Client.DeviceID {
log.Debug().Msg("Secret request was from this device, ignoring")
return
}
keys, err := mach.CryptoStore.GetCrossSigningKeys(ctx, mach.Client.UserID)
if err != nil {
log.Err(err).Msg("Failed to get cross signing keys from crypto store")
return
}
crossSigningKey, ok := keys[id.XSUsageSelfSigning]
if !ok {
log.Warn().Msg("Couldn't find self signing key to verify requesting device")
return
}
device, err := mach.GetOrFetchDevice(ctx, mach.Client.UserID, content.RequestingDeviceID)
if err != nil {
log.Err(err).Msg("Failed to get or fetch requesting device")
return
}
verified, err := mach.CryptoStore.IsKeySignedBy(ctx, mach.Client.UserID, device.SigningKey, mach.Client.UserID, crossSigningKey.Key)
if err != nil {
log.Err(err).Msg("Failed to check if requesting device is verified")
return
}
if !verified {
log.Warn().Msg("Requesting device is not verified, ignoring request")
return
}
secret, err := mach.CryptoStore.GetSecret(ctx, content.Name)
if err != nil {
log.Err(err).Msg("Failed to get secret from store")
return
} else if secret != "" {
log.Debug().Msg("Responding to secret request")
mach.SendEncryptedToDevice(ctx, device, event.ToDeviceSecretSend, event.Content{
Parsed: event.SecretSendEventContent{
RequestID: content.RequestID,
Secret: secret,
},
})
} else {
log.Debug().Msg("No stored secret found, secret request ignored")
}
}
func (mach *OlmMachine) receiveSecret(ctx context.Context, evt *DecryptedOlmEvent, content *event.SecretSendEventContent) {
log := mach.machOrContextLog(ctx).With().
Stringer("sender", evt.Sender).
Stringer("sender_device", evt.SenderDevice).
Str("request_id", content.RequestID).
Logger()
log.Trace().Msg("Handling secret send request")
// immediately ignore secrets from other users
if evt.Sender != mach.Client.UserID {
log.Warn().Msg("Secret send was not from our own device")
return
} else if content.Secret == "" {
log.Warn().Msg("We were sent an empty secret")
return
}
mach.secretLock.Lock()
secretChan := mach.secretListeners[content.RequestID]
mach.secretLock.Unlock()
if secretChan == nil {
log.Warn().Msg("We were sent a secret we didn't request")
return
}
// secret channel is buffered and we don't want to block
// at worst we drop _some_ of the responses
select {
case secretChan <- content.Secret:
default:
}
}

View File

@@ -0,0 +1,94 @@
// Copyright (c) 2024 Sumner Evans
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
package signatures
import (
"encoding/base64"
"encoding/json"
"errors"
"fmt"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
"go.mau.fi/util/exgjson"
"maunium.net/go/mautrix/crypto/canonicaljson"
"maunium.net/go/mautrix/crypto/goolm/crypto"
"maunium.net/go/mautrix/id"
)
var (
ErrEmptyInput = errors.New("empty input")
ErrSignatureNotFound = errors.New("input JSON doesn't contain signature from specified device")
)
// Signatures represents a set of signatures for some data from multiple users
// and keys.
type Signatures map[id.UserID]map[id.KeyID]string
// NewSingleSignature creates a new [Signatures] object with a single
// signature.
func NewSingleSignature(userID id.UserID, algorithm id.KeyAlgorithm, keyID string, signature string) Signatures {
return Signatures{
userID: {
id.NewKeyID(algorithm, keyID): signature,
},
}
}
// VerifySignature verifies an Ed25519 signature.
func VerifySignature(message []byte, key id.Ed25519, signature []byte) (ok bool, err error) {
if len(message) == 0 || len(key) == 0 || len(signature) == 0 {
return false, ErrEmptyInput
}
keyDecoded, err := base64.RawStdEncoding.DecodeString(key.String())
if err != nil {
return false, err
}
publicKey := crypto.Ed25519PublicKey(keyDecoded)
return publicKey.Verify(message, signature), nil
}
// VerifySignatureJSON verifies the signature in the given JSON object "obj"
// as described in [Appendix 3] of the Matrix Spec.
//
// This function is a wrapper over [Utility.VerifySignatureJSON] that creates
// and destroys the [Utility] object transparently.
//
// If the "obj" is not already a [json.RawMessage], it will re-encoded as JSON
// for the verification, so "json" tags will be honored.
//
// [Appendix 3]: https://spec.matrix.org/v1.9/appendices/#signing-json
func VerifySignatureJSON(obj any, userID id.UserID, keyName string, key id.Ed25519) (bool, error) {
var err error
objJSON, ok := obj.(json.RawMessage)
if !ok {
objJSON, err = json.Marshal(obj)
if err != nil {
return false, err
}
}
sig := gjson.GetBytes(objJSON, exgjson.Path("signatures", string(userID), fmt.Sprintf("ed25519:%s", keyName)))
if !sig.Exists() || sig.Type != gjson.String {
return false, ErrSignatureNotFound
}
objJSON, err = sjson.DeleteBytes(objJSON, "unsigned")
if err != nil {
return false, err
}
objJSON, err = sjson.DeleteBytes(objJSON, "signatures")
if err != nil {
return false, err
}
objJSONString := canonicaljson.CanonicalJSONAssumeValid(objJSON)
sigBytes, err := base64.RawStdEncoding.DecodeString(sig.Str)
if err != nil {
return false, err
}
return VerifySignature(objJSONString, key, sigBytes)
}

View File

@@ -21,6 +21,7 @@ import (
"go.mau.fi/util/dbutil" "go.mau.fi/util/dbutil"
"maunium.net/go/mautrix" "maunium.net/go/mautrix"
"maunium.net/go/mautrix/crypto/goolm/cipher"
"maunium.net/go/mautrix/crypto/olm" "maunium.net/go/mautrix/crypto/olm"
"maunium.net/go/mautrix/crypto/sql_store_upgrade" "maunium.net/go/mautrix/crypto/sql_store_upgrade"
"maunium.net/go/mautrix/event" "maunium.net/go/mautrix/event"
@@ -124,20 +125,21 @@ func (store *SQLCryptoStore) PutAccount(ctx context.Context, account *OlmAccount
store.Account = account store.Account = account
bytes := account.Internal.Pickle(store.PickleKey) bytes := account.Internal.Pickle(store.PickleKey)
_, err := store.DB.Exec(ctx, ` _, err := store.DB.Exec(ctx, `
INSERT INTO crypto_account (device_id, shared, sync_token, account, account_id) VALUES ($1, $2, $3, $4, $5) INSERT INTO crypto_account (device_id, shared, sync_token, account, account_id, key_backup_version) VALUES ($1, $2, $3, $4, $5, $6)
ON CONFLICT (account_id) DO UPDATE SET shared=excluded.shared, sync_token=excluded.sync_token, ON CONFLICT (account_id) DO UPDATE SET shared=excluded.shared, sync_token=excluded.sync_token,
account=excluded.account, account_id=excluded.account_id account=excluded.account, account_id=excluded.account_id,
`, store.DeviceID, account.Shared, store.SyncToken, bytes, store.AccountID) key_backup_version=excluded.key_backup_version
`, store.DeviceID, account.Shared, store.SyncToken, bytes, store.AccountID, account.KeyBackupVersion)
return err return err
} }
// GetAccount retrieves an OlmAccount from the database. // GetAccount retrieves an OlmAccount from the database.
func (store *SQLCryptoStore) GetAccount(ctx context.Context) (*OlmAccount, error) { func (store *SQLCryptoStore) GetAccount(ctx context.Context) (*OlmAccount, error) {
if store.Account == nil { if store.Account == nil {
row := store.DB.QueryRow(ctx, "SELECT shared, sync_token, account FROM crypto_account WHERE account_id=$1", store.AccountID) row := store.DB.QueryRow(ctx, "SELECT shared, sync_token, account, key_backup_version FROM crypto_account WHERE account_id=$1", store.AccountID)
acc := &OlmAccount{Internal: *olm.NewBlankAccount()} acc := &OlmAccount{Internal: *olm.NewBlankAccount()}
var accountBytes []byte var accountBytes []byte
err := row.Scan(&acc.Shared, &store.SyncToken, &accountBytes) err := row.Scan(&acc.Shared, &store.SyncToken, &accountBytes, &acc.KeyBackupVersion)
if err == sql.ErrNoRows { if err == sql.ErrNoRows {
return nil, nil return nil, nil
} else if err != nil { } else if err != nil {
@@ -284,17 +286,18 @@ func (store *SQLCryptoStore) PutGroupSession(ctx context.Context, roomID id.Room
_, err = store.DB.Exec(ctx, ` _, err = store.DB.Exec(ctx, `
INSERT INTO crypto_megolm_inbound_session ( INSERT INTO crypto_megolm_inbound_session (
session_id, sender_key, signing_key, room_id, session, forwarding_chains, session_id, sender_key, signing_key, room_id, session, forwarding_chains,
ratchet_safety, received_at, max_age, max_messages, is_scheduled, account_id ratchet_safety, received_at, max_age, max_messages, is_scheduled, key_backup_version, account_id
) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12) ) VALUES ($1, $2, $3, $4, $5, $6, $7, $8, $9, $10, $11, $12, $13)
ON CONFLICT (session_id, account_id) DO UPDATE ON CONFLICT (session_id, account_id) DO UPDATE
SET withheld_code=NULL, withheld_reason=NULL, sender_key=excluded.sender_key, signing_key=excluded.signing_key, SET withheld_code=NULL, withheld_reason=NULL, sender_key=excluded.sender_key, signing_key=excluded.signing_key,
room_id=excluded.room_id, session=excluded.session, forwarding_chains=excluded.forwarding_chains, room_id=excluded.room_id, session=excluded.session, forwarding_chains=excluded.forwarding_chains,
ratchet_safety=excluded.ratchet_safety, received_at=excluded.received_at, ratchet_safety=excluded.ratchet_safety, received_at=excluded.received_at,
max_age=excluded.max_age, max_messages=excluded.max_messages, is_scheduled=excluded.is_scheduled max_age=excluded.max_age, max_messages=excluded.max_messages, is_scheduled=excluded.is_scheduled,
key_backup_version=excluded.key_backup_version
`, `,
sessionID, senderKey, session.SigningKey, roomID, sessionBytes, forwardingChains, sessionID, senderKey, session.SigningKey, roomID, sessionBytes, forwardingChains,
ratchetSafety, datePtr(session.ReceivedAt), intishPtr(session.MaxAge), intishPtr(session.MaxMessages), ratchetSafety, datePtr(session.ReceivedAt), intishPtr(session.MaxAge), intishPtr(session.MaxMessages),
session.IsScheduled, store.AccountID, session.IsScheduled, session.KeyBackupVersion, store.AccountID,
) )
return err return err
} }
@@ -306,12 +309,13 @@ func (store *SQLCryptoStore) GetGroupSession(ctx context.Context, roomID id.Room
var receivedAt sql.NullTime var receivedAt sql.NullTime
var maxAge, maxMessages sql.NullInt64 var maxAge, maxMessages sql.NullInt64
var isScheduled bool var isScheduled bool
var version id.KeyBackupVersion
err := store.DB.QueryRow(ctx, ` err := store.DB.QueryRow(ctx, `
SELECT sender_key, signing_key, session, forwarding_chains, withheld_code, withheld_reason, ratchet_safety, received_at, max_age, max_messages, is_scheduled SELECT sender_key, signing_key, session, forwarding_chains, withheld_code, withheld_reason, ratchet_safety, received_at, max_age, max_messages, is_scheduled, key_backup_version
FROM crypto_megolm_inbound_session FROM crypto_megolm_inbound_session
WHERE room_id=$1 AND (sender_key=$2 OR $2 = '') AND session_id=$3 AND account_id=$4`, WHERE room_id=$1 AND (sender_key=$2 OR $2 = '') AND session_id=$3 AND account_id=$4`,
roomID, senderKey, sessionID, store.AccountID, roomID, senderKey, sessionID, store.AccountID,
).Scan(&senderKeyDB, &signingKey, &sessionBytes, &forwardingChains, &withheldCode, &withheldReason, &ratchetSafetyBytes, &receivedAt, &maxAge, &maxMessages, &isScheduled) ).Scan(&senderKeyDB, &signingKey, &sessionBytes, &forwardingChains, &withheldCode, &withheldReason, &ratchetSafetyBytes, &receivedAt, &maxAge, &maxMessages, &isScheduled, &version)
if errors.Is(err, sql.ErrNoRows) { if errors.Is(err, sql.ErrNoRows) {
return nil, nil return nil, nil
} else if err != nil { } else if err != nil {
@@ -341,6 +345,7 @@ func (store *SQLCryptoStore) GetGroupSession(ctx context.Context, roomID id.Room
MaxAge: maxAge.Int64, MaxAge: maxAge.Int64,
MaxMessages: int(maxMessages.Int64), MaxMessages: int(maxMessages.Int64),
IsScheduled: isScheduled, IsScheduled: isScheduled,
KeyBackupVersion: version,
}, nil }, nil
} }
@@ -468,7 +473,8 @@ func (store *SQLCryptoStore) scanInboundGroupSession(rows dbutil.Scannable) (*In
var receivedAt sql.NullTime var receivedAt sql.NullTime
var maxAge, maxMessages sql.NullInt64 var maxAge, maxMessages sql.NullInt64
var isScheduled bool var isScheduled bool
err := rows.Scan(&roomID, &senderKey, &signingKey, &sessionBytes, &forwardingChains, &ratchetSafetyBytes, &receivedAt, &maxAge, &maxMessages, &isScheduled) var version id.KeyBackupVersion
err := rows.Scan(&roomID, &senderKey, &signingKey, &sessionBytes, &forwardingChains, &ratchetSafetyBytes, &receivedAt, &maxAge, &maxMessages, &isScheduled, &version)
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -484,31 +490,35 @@ func (store *SQLCryptoStore) scanInboundGroupSession(rows dbutil.Scannable) (*In
MaxAge: maxAge.Int64, MaxAge: maxAge.Int64,
MaxMessages: int(maxMessages.Int64), MaxMessages: int(maxMessages.Int64),
IsScheduled: isScheduled, IsScheduled: isScheduled,
KeyBackupVersion: version,
}, nil }, nil
} }
func (store *SQLCryptoStore) GetGroupSessionsForRoom(ctx context.Context, roomID id.RoomID) ([]*InboundGroupSession, error) { func (store *SQLCryptoStore) GetGroupSessionsForRoom(ctx context.Context, roomID id.RoomID) dbutil.RowIter[*InboundGroupSession] {
rows, err := store.DB.Query(ctx, ` rows, err := store.DB.Query(ctx, `
SELECT room_id, sender_key, signing_key, session, forwarding_chains, withheld_code, withheld_reason, ratchet_safety, received_at, max_age, max_messages, is_scheduled SELECT room_id, sender_key, signing_key, session, forwarding_chains, ratchet_safety, received_at, max_age, max_messages, is_scheduled, key_backup_version
FROM crypto_megolm_inbound_session WHERE room_id=$1 AND account_id=$2 AND session IS NOT NULL`, FROM crypto_megolm_inbound_session WHERE room_id=$1 AND account_id=$2 AND session IS NOT NULL`,
roomID, store.AccountID, roomID, store.AccountID,
) )
if err != nil { return dbutil.NewRowIterWithError(rows, store.scanInboundGroupSession, err)
return nil, err
}
return dbutil.NewRowIter(rows, store.scanInboundGroupSession).AsList()
} }
func (store *SQLCryptoStore) GetAllGroupSessions(ctx context.Context) ([]*InboundGroupSession, error) { func (store *SQLCryptoStore) GetAllGroupSessions(ctx context.Context) dbutil.RowIter[*InboundGroupSession] {
rows, err := store.DB.Query(ctx, ` rows, err := store.DB.Query(ctx, `
SELECT room_id, sender_key, signing_key, session, forwarding_chains, withheld_code, withheld_reason, ratchet_safety, received_at, max_age, max_messages, is_scheduled SELECT room_id, sender_key, signing_key, session, forwarding_chains, ratchet_safety, received_at, max_age, max_messages, is_scheduled, key_backup_version
FROM crypto_megolm_inbound_session WHERE account_id=$2 AND session IS NOT NULL`, FROM crypto_megolm_inbound_session WHERE account_id=$1 AND session IS NOT NULL`,
store.AccountID, store.AccountID,
) )
if err != nil { return dbutil.NewRowIterWithError(rows, store.scanInboundGroupSession, err)
return nil, err
} }
return dbutil.NewRowIter(rows, store.scanInboundGroupSession).AsList()
func (store *SQLCryptoStore) GetGroupSessionsWithoutKeyBackupVersion(ctx context.Context, version id.KeyBackupVersion) dbutil.RowIter[*InboundGroupSession] {
rows, err := store.DB.Query(ctx, `
SELECT room_id, sender_key, signing_key, session, forwarding_chains, ratchet_safety, received_at, max_age, max_messages, is_scheduled, key_backup_version
FROM crypto_megolm_inbound_session WHERE account_id=$1 AND session IS NOT NULL AND key_backup_version != $2`,
store.AccountID, version,
)
return dbutil.NewRowIterWithError(rows, store.scanInboundGroupSession, err)
} }
// AddOutboundGroupSession stores an outbound Megolm session, along with the information about the room and involved devices. // AddOutboundGroupSession stores an outbound Megolm session, along with the information about the room and involved devices.
@@ -568,6 +578,20 @@ func (store *SQLCryptoStore) RemoveOutboundGroupSession(ctx context.Context, roo
return err return err
} }
func (store *SQLCryptoStore) MarkOutboundGroupSessionShared(ctx context.Context, userID id.UserID, identityKey id.IdentityKey, sessionID id.SessionID) error {
_, err := store.DB.Exec(ctx, "INSERT INTO crypto_megolm_outbound_session_shared (user_id, identity_key, session_id) VALUES ($1, $2, $3)", userID, identityKey, sessionID)
return err
}
func (store *SQLCryptoStore) IsOutboundGroupSessionShared(ctx context.Context, userID id.UserID, identityKey id.IdentityKey, sessionID id.SessionID) (shared bool, err error) {
err = store.DB.QueryRow(ctx, `SELECT TRUE FROM crypto_megolm_outbound_session_shared WHERE user_id=$1 AND identity_key=$2 AND session_id=$3`,
userID, identityKey, sessionID).Scan(&shared)
if errors.Is(err, sql.ErrNoRows) {
err = nil
}
return
}
// ValidateMessageIndex returns whether the given event information match the ones stored in the database // ValidateMessageIndex returns whether the given event information match the ones stored in the database
// for the given sender key, session ID and index. If the index hasn't been stored, this will store it. // for the given sender key, session ID and index. If the index hasn't been stored, this will store it.
func (store *SQLCryptoStore) ValidateMessageIndex(ctx context.Context, senderKey id.SenderKey, sessionID id.SessionID, eventID id.EventID, index uint, timestamp int64) (bool, error) { func (store *SQLCryptoStore) ValidateMessageIndex(ctx context.Context, senderKey id.SenderKey, sessionID id.SessionID, eventID id.EventID, index uint, timestamp int64) (bool, error) {
@@ -845,3 +869,32 @@ func (store *SQLCryptoStore) DropSignaturesByKey(ctx context.Context, userID id.
} }
return count, nil return count, nil
} }
func (store *SQLCryptoStore) PutSecret(ctx context.Context, name id.Secret, value string) error {
bytes, err := cipher.Pickle(store.PickleKey, []byte(value))
if err != nil {
return err
}
_, err = store.DB.Exec(ctx, `
INSERT INTO crypto_secrets (name, secret) VALUES ($1, $2)
ON CONFLICT (name) DO UPDATE SET secret=excluded.secret
`, name, bytes)
return err
}
func (store *SQLCryptoStore) GetSecret(ctx context.Context, name id.Secret) (value string, err error) {
var bytes []byte
err = store.DB.QueryRow(ctx, `SELECT secret FROM crypto_secrets WHERE name=$1`, name).Scan(&bytes)
if errors.Is(err, sql.ErrNoRows) {
return "", nil
} else if err != nil {
return "", err
}
bytes, err = cipher.Unpickle(store.PickleKey, bytes)
return string(bytes), err
}
func (store *SQLCryptoStore) DeleteSecret(ctx context.Context, name id.Secret) (err error) {
_, err = store.DB.Exec(ctx, "DELETE FROM crypto_secrets WHERE name=$1", name)
return
}

View File

@@ -1,10 +1,11 @@
-- v0 -> v11: Latest revision -- v0 -> v14 (compatible with v9+): Latest revision
CREATE TABLE IF NOT EXISTS crypto_account ( CREATE TABLE IF NOT EXISTS crypto_account (
account_id TEXT PRIMARY KEY, account_id TEXT PRIMARY KEY,
device_id TEXT NOT NULL, device_id TEXT NOT NULL,
shared BOOLEAN NOT NULL, shared BOOLEAN NOT NULL,
sync_token TEXT NOT NULL, sync_token TEXT NOT NULL,
account bytea NOT NULL account bytea NOT NULL,
key_backup_version TEXT NOT NULL DEFAULT ''
); );
CREATE TABLE IF NOT EXISTS crypto_message_index ( CREATE TABLE IF NOT EXISTS crypto_message_index (
@@ -58,6 +59,7 @@ CREATE TABLE IF NOT EXISTS crypto_megolm_inbound_session (
max_age BIGINT, max_age BIGINT,
max_messages INTEGER, max_messages INTEGER,
is_scheduled BOOLEAN NOT NULL DEFAULT false, is_scheduled BOOLEAN NOT NULL DEFAULT false,
key_backup_version TEXT NOT NULL DEFAULT '',
PRIMARY KEY (account_id, session_id) PRIMARY KEY (account_id, session_id)
); );
@@ -75,6 +77,14 @@ CREATE TABLE IF NOT EXISTS crypto_megolm_outbound_session (
PRIMARY KEY (account_id, room_id) PRIMARY KEY (account_id, room_id)
); );
CREATE TABLE IF NOT EXISTS crypto_megolm_outbound_session_shared (
user_id TEXT NOT NULL,
identity_key CHAR(43) NOT NULL,
session_id CHAR(43) NOT NULL,
PRIMARY KEY (user_id, identity_key, session_id)
);
CREATE TABLE IF NOT EXISTS crypto_cross_signing_keys ( CREATE TABLE IF NOT EXISTS crypto_cross_signing_keys (
user_id TEXT, user_id TEXT,
usage TEXT, usage TEXT,
@@ -93,3 +103,8 @@ CREATE TABLE IF NOT EXISTS crypto_cross_signing_signatures (
signature CHAR(88) NOT NULL, signature CHAR(88) NOT NULL,
PRIMARY KEY (signed_user_id, signed_key, signer_user_id, signer_key) PRIMARY KEY (signed_user_id, signed_key, signer_user_id, signer_key)
); );
CREATE TABLE IF NOT EXISTS crypto_secrets (
name TEXT PRIMARY KEY NOT NULL,
secret bytea NOT NULL
);

View File

@@ -0,0 +1,5 @@
-- v12 (compatible with v9+): Add crypto_secrets table
CREATE TABLE IF NOT EXISTS crypto_secrets (
name TEXT PRIMARY KEY NOT NULL,
secret bytea NOT NULL
);

View File

@@ -0,0 +1,9 @@
-- v13 (compatible with v9+): Add crypto_megolm_outbound_session_shared table
CREATE TABLE IF NOT EXISTS crypto_megolm_outbound_session_shared (
user_id TEXT NOT NULL,
identity_key CHAR(43) NOT NULL,
session_id CHAR(43) NOT NULL,
PRIMARY KEY (user_id, identity_key, session_id)
);

Some files were not shown because too many files have changed in this diff Show More