automatically ignore known forwarded addresses, fixes #64

Aine
2023-09-18 12:35:37 +03:00
parent e90925eceb
commit 60b4386dd8
187 changed files with 4070 additions and 2667 deletions

View File

@@ -1,3 +1,38 @@
## v0.16.1 (2023-09-16)

* **Breaking change *(id)*** Updated user ID localpart encoding to not encode
`+` as per [MSC4009].
* *(bridge)* Added bridge utility to handle double puppeting logins.
* The utility supports automatic logins with all three current methods
(shared secret, legacy appservice, new appservice).
* *(appservice)* Added warning logs and timeout on appservice event handling.
* Defaults to warning after 30 seconds and timeout 15 minutes after that.
* Timeouts can be adjusted or disabled by setting `ExecSync` variables in the
`EventProcessor`.
* *(crypto/olm)* Added `PkDecryption` wrapper.

[MSC4009]: https://github.com/matrix-org/matrix-spec-proposals/pull/4009

## v0.16.0 (2023-08-16)

* Bumped minimum Go version to 1.20.
* **Breaking change *(util)*** Moved package to [go.mau.fi/util](https://go.mau.fi/util/)
* *(event)* Removed MSC2716 `historical` field in the `m.room.power_levels`
event content struct.
* *(bridge)* Added `--version-json` flag to print bridge version info as JSON.
* *(appservice)* Added option to use custom transaction handler for websocket mode.

## v0.15.4 (2023-07-16)

* *(client)* Deprecated MSC2716 methods and added new Beeper-specific batch
send methods, as upstream MSC2716 support has been abandoned.
* *(client)* Added proper error handling and automatic retries to media
downloads.
* *(crypto, bridge)* Added option to remove all keys that were received before
the automatic ratcheting was implemented (in v0.15.1).
* *(dbutil)* Added `JSON` utility for writing/reading arbitrary JSON objects to
the db conveniently without manually de/serializing.

## v0.15.3 (2023-06-16)

* *(synapseadmin)* Added wrappers for some Synapse admin API endpoints.
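
The v0.15.4 entry about automatic retries for media downloads corresponds to the new `doMediaRequest`/`doMediaRetry` path further down in this commit. A minimal usage sketch, assuming standard client setup; the homeserver, user ID, token, and mxc URI are placeholders:

```go
package main

import (
	"fmt"

	"maunium.net/go/mautrix"
	"maunium.net/go/mautrix/id"
)

func main() {
	// Placeholder homeserver, user ID and access token.
	cli, err := mautrix.NewClient("https://matrix.example.org", "@alice:example.org", "syt_placeholder")
	if err != nil {
		panic(err)
	}
	// Media downloads now retry automatically; the retry count comes from
	// DefaultHTTPRetries and rate limits honor the Retry-After header.
	cli.DefaultHTTPRetries = 4

	uri, err := id.ParseContentURI("mxc://example.org/abc123")
	if err != nil {
		panic(err)
	}
	data, err := cli.DownloadBytes(uri)
	if err != nil {
		panic(err)
	}
	fmt.Println("downloaded", len(data), "bytes")
}
```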

View File

@@ -14,11 +14,11 @@ import (
"net/url"
"os"
"strconv"
"strings"
"sync/atomic"
"time"
"github.com/rs/zerolog"
"go.mau.fi/util/retryafter"
"maunium.net/go/maulogger/v2/maulogadapt"
"maunium.net/go/mautrix/event"
@@ -258,45 +258,52 @@ const (
LogRequestIDContextKey
)
func (cli *Client) LogRequest(req *http.Request) {
func (cli *Client) RequestStart(req *http.Request) {
if cli.RequestHook != nil {
cli.RequestHook(req)
}
evt := zerolog.Ctx(req.Context()).Debug().
Str("method", req.Method).
Str("url", req.URL.String())
body := req.Context().Value(LogBodyContextKey)
if body != nil {
evt.Interface("body", body)
}
evt.Msg("Sending request")
}
func (cli *Client) LogRequestDone(req *http.Request, resp *http.Response, handlerErr error, contentLength int, duration time.Duration) {
if cli.ResponseHook != nil {
cli.ResponseHook(req, resp, duration)
func (cli *Client) LogRequestDone(req *http.Request, resp *http.Response, err error, handlerErr error, contentLength int, duration time.Duration) {
var evt *zerolog.Event
if err != nil {
evt = zerolog.Ctx(req.Context()).Err(err)
} else if handlerErr != nil {
evt = zerolog.Ctx(req.Context()).Warn().
AnErr("body_parse_err", handlerErr)
} else {
evt = zerolog.Ctx(req.Context()).Debug()
}
mime := resp.Header.Get("Content-Type")
length := resp.ContentLength
if length == -1 && contentLength > 0 {
length = int64(contentLength)
}
path := strings.TrimPrefix(req.URL.Path, cli.HomeserverURL.Path)
path = strings.TrimPrefix(path, "/_matrix/client")
evt := zerolog.Ctx(req.Context()).Debug().
evt = evt.
Str("method", req.Method).
Str("path", path).
Int("status_code", resp.StatusCode).
Int64("response_length", length).
Str("response_mime", mime).
Str("url", req.URL.String()).
Dur("duration", duration)
if handlerErr != nil {
evt.AnErr("body_parse_err", handlerErr)
if resp != nil {
if cli.ResponseHook != nil {
cli.ResponseHook(req, resp, duration)
}
mime := resp.Header.Get("Content-Type")
length := resp.ContentLength
if length == -1 && contentLength > 0 {
length = int64(contentLength)
}
evt = evt.Int("status_code", resp.StatusCode).
Int64("response_length", length).
Str("response_mime", mime)
if serverRequestID := resp.Header.Get("X-Beeper-Request-ID"); serverRequestID != "" {
evt.Str("beeper_request_id", serverRequestID)
}
}
if serverRequestID := resp.Header.Get("X-Beeper-Request-ID"); serverRequestID != "" {
evt.Str("beeper_request_id", serverRequestID)
if body := req.Context().Value(LogBodyContextKey); body != nil {
evt.Interface("req_body", body)
}
if err != nil {
evt.Msg("Request failed")
} else if handlerErr != nil {
evt.Msg("Request parsing failed")
} else {
evt.Msg("Request completed")
}
evt.Msg("Request completed")
}
func (cli *Client) MakeRequest(method string, httpURL string, reqBody interface{}, resBody interface{}) ([]byte, error) {
@@ -520,38 +527,8 @@ func ParseErrorResponse(req *http.Request, res *http.Response) ([]byte, error) {
}
}
// parseBackoffFromResponse extracts the backoff time specified in the Retry-After header if present. See
// https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Retry-After.
func parseBackoffFromResponse(req *http.Request, res *http.Response, now time.Time, fallback time.Duration) time.Duration {
retryAfterHeaderValue := res.Header.Get("Retry-After")
if retryAfterHeaderValue == "" {
return fallback
}
if t, err := time.Parse(http.TimeFormat, retryAfterHeaderValue); err == nil {
return t.Sub(now)
}
if seconds, err := strconv.Atoi(retryAfterHeaderValue); err == nil {
return time.Duration(seconds) * time.Second
}
zerolog.Ctx(req.Context()).Warn().
Str("retry_after", retryAfterHeaderValue).
Msg("Failed to parse Retry-After header value")
return fallback
}
func (cli *Client) shouldRetry(res *http.Response) bool {
return res.StatusCode == http.StatusBadGateway ||
res.StatusCode == http.StatusServiceUnavailable ||
res.StatusCode == http.StatusGatewayTimeout ||
(res.StatusCode == http.StatusTooManyRequests && !cli.IgnoreRateLimit)
}
func (cli *Client) executeCompiledRequest(req *http.Request, retries int, backoff time.Duration, responseJSON interface{}, handler ClientResponseHandler) ([]byte, error) {
cli.LogRequest(req)
cli.RequestStart(req)
startTime := time.Now()
res, err := cli.Client.Do(req)
duration := time.Now().Sub(startTime)
@@ -562,29 +539,29 @@ func (cli *Client) executeCompiledRequest(req *http.Request, retries int, backof
if retries > 0 {
return cli.doRetry(req, err, retries, backoff, responseJSON, handler)
}
return nil, HTTPError{
err = HTTPError{
Request: req,
Response: res,
Message: "request error",
WrappedError: err,
}
cli.LogRequestDone(req, res, err, nil, 0, duration)
return nil, err
}
if retries > 0 && cli.shouldRetry(res) {
if res.StatusCode == http.StatusTooManyRequests {
backoff = parseBackoffFromResponse(req, res, time.Now(), backoff)
}
if retries > 0 && retryafter.Should(res.StatusCode, !cli.IgnoreRateLimit) {
backoff = retryafter.Parse(res.Header.Get("Retry-After"), backoff)
return cli.doRetry(req, fmt.Errorf("HTTP %d", res.StatusCode), retries, backoff, responseJSON, handler)
}
var body []byte
if res.StatusCode < 200 || res.StatusCode >= 300 {
body, err = ParseErrorResponse(req, res)
cli.LogRequestDone(req, res, nil, len(body), duration)
cli.LogRequestDone(req, res, nil, nil, len(body), duration)
} else {
body, err = handler(req, res, responseJSON)
cli.LogRequestDone(req, res, err, len(body), duration)
cli.LogRequestDone(req, res, nil, err, len(body), duration)
}
return body, err
}
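
The deleted `parseBackoffFromResponse`/`shouldRetry` helpers are replaced by the shared `go.mau.fi/util/retryafter` package. For reference, a standalone sketch of the same Retry-After parsing rules (HTTP-date first, then integer seconds, otherwise the caller's fallback), mirroring the removed helper rather than the new package's internals:

```go
package retrysketch

import (
	"net/http"
	"strconv"
	"time"
)

// parseRetryAfter mirrors the deleted parseBackoffFromResponse helper: it accepts
// either an HTTP-date or a number of seconds and falls back to the caller-provided
// backoff when the header is absent or malformed.
func parseRetryAfter(headerValue string, now time.Time, fallback time.Duration) time.Duration {
	if headerValue == "" {
		return fallback
	}
	if t, err := time.Parse(http.TimeFormat, headerValue); err == nil {
		return t.Sub(now)
	}
	if seconds, err := strconv.Atoi(headerValue); err == nil {
		return time.Duration(seconds) * time.Second
	}
	return fallback
}
```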
@@ -1371,26 +1348,80 @@ func (cli *Client) Download(mxcURL id.ContentURI) (io.ReadCloser, error) {
}
func (cli *Client) DownloadContext(ctx context.Context, mxcURL id.ContentURI) (io.ReadCloser, error) {
_, resp, err := cli.downloadContext(ctx, mxcURL)
return resp.Body, err
resp, err := cli.downloadContext(ctx, mxcURL)
if err != nil {
return nil, err
}
return resp.Body, nil
}
func (cli *Client) downloadContext(ctx context.Context, mxcURL id.ContentURI) (*http.Request, *http.Response, error) {
func (cli *Client) doMediaRetry(req *http.Request, cause error, retries int, backoff time.Duration) (*http.Response, error) {
log := zerolog.Ctx(req.Context())
if req.Body != nil {
if req.GetBody == nil {
log.Warn().Msg("Failed to get new body to retry request: GetBody is nil")
return nil, cause
}
var err error
req.Body, err = req.GetBody()
if err != nil {
log.Warn().Err(err).Msg("Failed to get new body to retry request")
return nil, cause
}
}
log.Warn().Err(cause).
Int("retry_in_seconds", int(backoff.Seconds())).
Msg("Request failed, retrying")
time.Sleep(backoff)
return cli.doMediaRequest(req, retries-1, backoff*2)
}
func (cli *Client) doMediaRequest(req *http.Request, retries int, backoff time.Duration) (*http.Response, error) {
cli.RequestStart(req)
startTime := time.Now()
res, err := cli.Client.Do(req)
duration := time.Now().Sub(startTime)
if err != nil {
if retries > 0 {
return cli.doMediaRetry(req, err, retries, backoff)
}
err = HTTPError{
Request: req,
Response: res,
Message: "request error",
WrappedError: err,
}
cli.LogRequestDone(req, res, err, nil, 0, duration)
return nil, err
}
if retries > 0 && retryafter.Should(res.StatusCode, !cli.IgnoreRateLimit) {
backoff = retryafter.Parse(res.Header.Get("Retry-After"), backoff)
return cli.doMediaRetry(req, fmt.Errorf("HTTP %d", res.StatusCode), retries, backoff)
}
if res.StatusCode < 200 || res.StatusCode >= 300 {
var body []byte
body, err = ParseErrorResponse(req, res)
cli.LogRequestDone(req, res, err, nil, len(body), duration)
} else {
cli.LogRequestDone(req, res, nil, nil, -1, duration)
}
return res, err
}
func (cli *Client) downloadContext(ctx context.Context, mxcURL id.ContentURI) (*http.Response, error) {
ctxLog := zerolog.Ctx(ctx)
if ctxLog.GetLevel() == zerolog.Disabled || ctxLog == zerolog.DefaultContextLogger {
ctx = cli.Log.WithContext(ctx)
}
req, err := http.NewRequestWithContext(ctx, http.MethodGet, cli.GetDownloadURL(mxcURL), nil)
if err != nil {
return req, nil, err
return nil, err
}
req.Header.Set("User-Agent", cli.UserAgent+" (media downloader)")
cli.LogRequest(req)
if resp, err := cli.Client.Do(req); err != nil {
return req, nil, err
} else {
return req, resp, nil
}
return cli.doMediaRequest(req, cli.DefaultHTTPRetries, 4*time.Second)
}
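
Because `downloadContext` now falls back to `cli.Log` only when the context has no usable logger, callers can attach their own zerolog logger to the context to get per-download logs. A rough sketch with placeholder credentials and URI:

```go
package main

import (
	"context"
	"io"
	"os"

	"github.com/rs/zerolog"
	"maunium.net/go/mautrix"
	"maunium.net/go/mautrix/id"
)

func main() {
	cli, err := mautrix.NewClient("https://matrix.example.org", "@alice:example.org", "syt_placeholder")
	if err != nil {
		panic(err)
	}
	// Attach a logger to the context; the download code picks it up via zerolog.Ctx.
	log := zerolog.New(os.Stdout).With().Timestamp().Logger()
	ctx := log.WithContext(context.Background())

	uri, err := id.ParseContentURI("mxc://example.org/abc123")
	if err != nil {
		panic(err)
	}
	body, err := cli.DownloadContext(ctx, uri)
	if err != nil {
		panic(err)
	}
	defer body.Close()
	_, _ = io.Copy(io.Discard, body)
}
```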
func (cli *Client) DownloadBytes(mxcURL id.ContentURI) ([]byte, error) {
@@ -1398,18 +1429,11 @@ func (cli *Client) DownloadBytes(mxcURL id.ContentURI) ([]byte, error) {
}
func (cli *Client) DownloadBytesContext(ctx context.Context, mxcURL id.ContentURI) ([]byte, error) {
req, resp, err := cli.downloadContext(ctx, mxcURL)
resp, err := cli.downloadContext(ctx, mxcURL)
if err != nil {
return nil, err
}
defer resp.Body.Close()
if resp.StatusCode >= 300 || resp.StatusCode < 200 {
respErr := &RespError{}
if _ = json.NewDecoder(resp.Body).Decode(respErr); respErr.ErrCode == "" {
respErr = nil
}
return nil, HTTPError{Request: req, Response: resp, RespError: respErr}
}
return io.ReadAll(resp.Body)
}
@@ -1980,7 +2004,7 @@ func (cli *Client) PutPushRule(scope string, kind pushrules.PushRuleType, ruleID
// BatchSend sends a batch of historical events into a room. This is only available for appservices.
//
// See https://github.com/matrix-org/matrix-doc/pull/2716 for more info.
// Deprecated: MSC2716 has been abandoned, so this is now Beeper-specific. BeeperBatchSend should be used instead.
func (cli *Client) BatchSend(roomID id.RoomID, req *ReqBatchSend) (resp *RespBatchSend, err error) {
path := ClientURLPath{"unstable", "org.matrix.msc2716", "rooms", roomID, "batch_send"}
query := map[string]string{
@@ -2011,6 +2035,12 @@ func (cli *Client) AppservicePing(id, txnID string) (resp *RespAppservicePing, e
return
}
func (cli *Client) BeeperBatchSend(roomID id.RoomID, req *ReqBeeperBatchSend) (resp *RespBeeperBatchSend, err error) {
u := cli.BuildClientURL("unstable", "com.beeper.backfill", "rooms", roomID, "batch_send")
_, err = cli.MakeRequest(http.MethodPost, u, req, &resp)
return
}
func (cli *Client) BeeperMergeRooms(req *ReqBeeperMergeRoom) (resp *RespBeeperMergeRoom, err error) {
urlPath := cli.BuildClientURL("unstable", "com.beeper.chatmerging", "merge")
_, err = cli.MakeRequest(http.MethodPost, urlPath, req, &resp)

View File

@@ -14,13 +14,13 @@ import (
"time"
"github.com/rs/zerolog"
"go.mau.fi/util/dbutil"
"maunium.net/go/mautrix"
"maunium.net/go/mautrix/crypto"
"maunium.net/go/mautrix/event"
"maunium.net/go/mautrix/id"
"maunium.net/go/mautrix/sqlstatestore"
"maunium.net/go/mautrix/util/dbutil"
)
type CryptoHelper struct {

View File

@@ -339,6 +339,9 @@ func (mach *OlmMachine) handleBeeperRoomKeyAck(ctx context.Context, sender id.Us
log.Err(err).Msg("Failed to get group session to check if it should be redacted")
}
return
} else if sess == nil {
log.Warn().Msg("Got key backup ack for unknown session")
return
}
log = log.With().
Str("sender_key", sess.SenderKey.String()).

View File

@@ -622,8 +622,8 @@ func (mach *OlmMachine) ShareKeys(ctx context.Context, currentOTKCount int) erro
start := time.Now()
mach.otkUploadLock.Lock()
defer mach.otkUploadLock.Unlock()
if mach.lastOTKUpload.Add(1 * time.Minute).After(start) {
log.Debug().Msg("Checking OTK count from server due to suspiciously close share keys requests")
if mach.lastOTKUpload.Add(1*time.Minute).After(start) || currentOTKCount < 0 {
log.Debug().Msg("Checking OTK count from server due to suspiciously close share keys requests or negative OTK count")
resp, err := mach.Client.UploadKeys(&mautrix.ReqUploadKeys{})
if err != nil {
return fmt.Errorf("failed to check current OTK counts: %w", err)

View File

@@ -109,3 +109,62 @@ func (p *PkSigning) SignJSON(obj interface{}) (string, error) {
func (p *PkSigning) lastError() error {
return convertError(C.GoString(C.olm_pk_signing_last_error((*C.OlmPkSigning)(p.int))))
}
type PkDecryption struct {
int *C.OlmPkDecryption
mem []byte
PublicKey []byte
}
func pkDecryptionSize() uint {
return uint(C.olm_pk_decryption_size())
}
func pkDecryptionPublicKeySize() uint {
return uint(C.olm_pk_key_length())
}
func NewPkDecryption(privateKey []byte) (*PkDecryption, error) {
memory := make([]byte, pkDecryptionSize())
p := &PkDecryption{
int: C.olm_pk_decryption(unsafe.Pointer(&memory[0])),
mem: memory,
}
p.Clear()
pubKey := make([]byte, pkDecryptionPublicKeySize())
if C.olm_pk_key_from_private((*C.OlmPkDecryption)(p.int),
unsafe.Pointer(&pubKey[0]), C.size_t(len(pubKey)),
unsafe.Pointer(&privateKey[0]), C.size_t(len(privateKey))) == errorVal() {
return nil, p.lastError()
}
p.PublicKey = pubKey
return p, nil
}
func (p *PkDecryption) Decrypt(ephemeralKey []byte, mac []byte, ciphertext []byte) ([]byte, error) {
maxPlaintextLength := uint(C.olm_pk_max_plaintext_length((*C.OlmPkDecryption)(p.int), C.size_t(len(ciphertext))))
plaintext := make([]byte, maxPlaintextLength)
size := C.olm_pk_decrypt((*C.OlmPkDecryption)(p.int),
unsafe.Pointer(&ephemeralKey[0]), C.size_t(len(ephemeralKey)),
unsafe.Pointer(&mac[0]), C.size_t(len(mac)),
unsafe.Pointer(&ciphertext[0]), C.size_t(len(ciphertext)),
unsafe.Pointer(&plaintext[0]), C.size_t(len(plaintext)))
if size == errorVal() {
return nil, p.lastError()
}
return plaintext[:size], nil
}
// Clear clears the underlying memory of a PkDecryption object.
func (p *PkDecryption) Clear() {
C.olm_clear_pk_decryption((*C.OlmPkDecryption)(p.int))
}
// lastError returns the last error that happened in relation to this PkDecryption object.
func (p *PkDecryption) lastError() error {
return convertError(C.GoString(C.olm_pk_decryption_last_error((*C.OlmPkDecryption)(p.int))))
}
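
A minimal sketch of using the new `PkDecryption` wrapper (requires libolm via cgo). The random 32-byte private key is purely illustrative, and this assumes libolm returns the public key as base64 text:

```go
package main

import (
	"crypto/rand"
	"fmt"

	"maunium.net/go/mautrix/crypto/olm"
)

func main() {
	// Illustrative only: a real private key would come from wherever the
	// curve25519 secret is stored, not from crypto/rand at call time.
	privateKey := make([]byte, 32)
	if _, err := rand.Read(privateKey); err != nil {
		panic(err)
	}
	dec, err := olm.NewPkDecryption(privateKey)
	if err != nil {
		panic(err)
	}
	defer dec.Clear()
	// PublicKey is base64 text as produced by libolm.
	fmt.Println("public key:", string(dec.PublicKey))
	// Incoming payloads would then be decrypted with:
	//   plaintext, err := dec.Decrypt(ephemeralKey, mac, ciphertext)
}
```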

View File

@@ -11,10 +11,10 @@ import (
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
"go.mau.fi/util/exgjson"
"maunium.net/go/mautrix/crypto/canonicaljson"
"maunium.net/go/mautrix/id"
"maunium.net/go/mautrix/util"
)
// Utility stores the necessary state to perform hash and signature
@@ -115,7 +115,7 @@ func (u *Utility) VerifySignatureJSON(obj interface{}, userID id.UserID, keyName
return false, err
}
}
sig := gjson.GetBytes(objJSON, util.GJSONPath("signatures", string(userID), fmt.Sprintf("ed25519:%s", keyName)))
sig := gjson.GetBytes(objJSON, exgjson.Path("signatures", string(userID), fmt.Sprintf("ed25519:%s", keyName)))
if !sig.Exists() || sig.Type != gjson.String {
return false, SignatureNotFound
}

View File

@@ -18,13 +18,13 @@ import (
"time"
"github.com/rs/zerolog"
"go.mau.fi/util/dbutil"
"maunium.net/go/mautrix"
"maunium.net/go/mautrix/crypto/olm"
"maunium.net/go/mautrix/crypto/sql_store_upgrade"
"maunium.net/go/mautrix/event"
"maunium.net/go/mautrix/id"
"maunium.net/go/mautrix/util/dbutil"
)
var PostgresArrayWrapper func(interface{}) interface {
@@ -418,6 +418,22 @@ func (store *SQLCryptoStore) RedactExpiredGroupSessions() ([]id.SessionID, error
return sessionIDs, err
}
func (store *SQLCryptoStore) RedactOutdatedGroupSessions() ([]id.SessionID, error) {
res, err := store.DB.Query(`
UPDATE crypto_megolm_inbound_session
SET withheld_code=$1, withheld_reason=$2, session=NULL, forwarding_chains=NULL
WHERE account_id=$3 AND session IS NOT NULL AND received_at IS NULL
RETURNING session_id
`, event.RoomKeyWithheldBeeperRedacted, "Session redacted: outdated", store.AccountID)
var sessionIDs []id.SessionID
for res.Next() {
var sessionID id.SessionID
_ = res.Scan(&sessionID)
sessionIDs = append(sessionIDs, sessionID)
}
return sessionIDs, err
}
func (store *SQLCryptoStore) PutWithheldGroupSession(content event.RoomKeyWithheldEventContent) error {
_, err := store.DB.Exec("INSERT INTO crypto_megolm_inbound_session (session_id, sender_key, room_id, withheld_code, withheld_reason, received_at, account_id) VALUES ($1, $2, $3, $4, $5, $6, $7)",
content.SessionID, content.SenderKey, content.RoomID, content.Code, content.Reason, time.Now().UTC(), store.AccountID)

View File

@@ -10,7 +10,7 @@ import (
"embed"
"fmt"
"maunium.net/go/mautrix/util/dbutil"
"go.mau.fi/util/dbutil"
)
var Table dbutil.UpgradeTable

View File

@@ -59,6 +59,8 @@ type Store interface {
RedactGroupSessions(id.RoomID, id.SenderKey, string) ([]id.SessionID, error)
// RedactExpiredGroupSessions removes the session data for all inbound Megolm sessions that have expired.
RedactExpiredGroupSessions() ([]id.SessionID, error)
// RedactOutdatedGroupSessions removes the session data for all inbound Megolm sessions that are lacking the expiration metadata.
RedactOutdatedGroupSessions() ([]id.SessionID, error)
// PutWithheldGroupSession tells the store that a specific Megolm session was withheld.
PutWithheldGroupSession(event.RoomKeyWithheldEventContent) error
// GetWithheldGroupSession gets the event content that was previously inserted with PutWithheldGroupSession.
@@ -317,6 +319,10 @@ func (gs *MemoryStore) RedactExpiredGroupSessions() ([]id.SessionID, error) {
return nil, fmt.Errorf("not implemented")
}
func (gs *MemoryStore) RedactOutdatedGroupSessions() ([]id.SessionID, error) {
return nil, fmt.Errorf("not implemented")
}
func (gs *MemoryStore) getWithheldGroupSessions(roomID id.RoomID, senderKey id.SenderKey) map[id.SessionID]*event.RoomKeyWithheldEventContent {
room, ok := gs.WithheldGroupSessions[roomID]
if !ok {

View File

@@ -16,10 +16,9 @@ import (
"math/rand"
"strings"
"go.mau.fi/util/base58"
"golang.org/x/crypto/hkdf"
"golang.org/x/crypto/pbkdf2"
"maunium.net/go/mautrix/util/base58"
)
const (

View File

@@ -42,6 +42,12 @@ type BeeperMessageStatusEventContent struct {
LastRetry id.EventID `json:"last_retry,omitempty"`
MutateEventKey string `json:"mutate_event_key,omitempty"`
// Indicates the set of users to whom the event was delivered. If nil, then
// the client should not expect delivered status at any later point. If not
// nil (even if empty), this field indicates which users the event was
// delivered to.
DeliveredToUsers *[]id.UserID `json:"delivered_to_users,omitempty"`
}
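
To illustrate the nil-versus-empty distinction documented above, a hypothetical helper (not part of this commit) might read the field like this:

```go
package main

import (
	"fmt"

	"maunium.net/go/mautrix/event"
	"maunium.net/go/mautrix/id"
)

// describeDelivery is a hypothetical helper showing the nil vs. empty-slice
// semantics of DeliveredToUsers.
func describeDelivery(status *event.BeeperMessageStatusEventContent) string {
	switch {
	case status.DeliveredToUsers == nil:
		return "no delivery information expected"
	case len(*status.DeliveredToUsers) == 0:
		return "delivery tracked, but not delivered to anyone yet"
	default:
		return fmt.Sprintf("delivered to %d user(s)", len(*status.DeliveredToUsers))
	}
}

func main() {
	users := []id.UserID{"@bob:example.org"}
	fmt.Println(describeDelivery(&event.BeeperMessageStatusEventContent{DeliveredToUsers: &users}))
}
```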
type BeeperRetryMetadata struct {

View File

@@ -33,6 +33,8 @@ const (
MsgFile MessageType = "m.file"
MsgVerificationRequest MessageType = "m.key.verification.request"
MsgBeeperGallery MessageType = "com.beeper.gallery"
)
// Format specifies the format of the formatted_body in m.room.message events.
@@ -110,7 +112,10 @@ type MessageEventContent struct {
replyFallbackRemoved bool
MessageSendRetry *BeeperRetryMetadata `json:"com.beeper.message_send_retry,omitempty"`
MessageSendRetry *BeeperRetryMetadata `json:"com.beeper.message_send_retry,omitempty"`
BeeperGalleryImages []*MessageEventContent `json:"com.beeper.gallery.images,omitempty"`
BeeperGalleryCaption string `json:"com.beeper.gallery.caption,omitempty"`
BeeperGalleryCaptionHTML string `json:"com.beeper.gallery.caption_html,omitempty"`
}
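
An illustrative way to assemble a gallery message with the new fields; the content URIs, bodies, and captions are placeholders:

```go
package main

import (
	"fmt"

	"maunium.net/go/mautrix/event"
)

func main() {
	gallery := &event.MessageEventContent{
		MsgType:              event.MsgBeeperGallery,
		Body:                 "Gallery of 2 images",
		BeeperGalleryCaption: "Holiday photos",
		BeeperGalleryImages: []*event.MessageEventContent{
			{MsgType: event.MsgImage, Body: "beach.jpg", URL: "mxc://example.org/aaa"},
			{MsgType: event.MsgImage, Body: "sunset.jpg", URL: "mxc://example.org/bbb"},
		},
	}
	fmt.Println(gallery.Body, "with", len(gallery.BeeperGalleryImages), "images")
}
```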
func (content *MessageEventContent) GetRelatesTo() *RelatesTo {

View File

@@ -27,11 +27,10 @@ type PowerLevelsEventContent struct {
StateDefaultPtr *int `json:"state_default,omitempty"`
InvitePtr *int `json:"invite,omitempty"`
KickPtr *int `json:"kick,omitempty"`
BanPtr *int `json:"ban,omitempty"`
RedactPtr *int `json:"redact,omitempty"`
HistoricalPtr *int `json:"historical,omitempty"`
InvitePtr *int `json:"invite,omitempty"`
KickPtr *int `json:"kick,omitempty"`
BanPtr *int `json:"ban,omitempty"`
RedactPtr *int `json:"redact,omitempty"`
}
func copyPtr(ptr *int) *int {
@@ -66,11 +65,10 @@ func (pl *PowerLevelsEventContent) Clone() *PowerLevelsEventContent {
Notifications: pl.Notifications.Clone(),
InvitePtr: copyPtr(pl.InvitePtr),
KickPtr: copyPtr(pl.KickPtr),
BanPtr: copyPtr(pl.BanPtr),
RedactPtr: copyPtr(pl.RedactPtr),
HistoricalPtr: copyPtr(pl.HistoricalPtr),
InvitePtr: copyPtr(pl.InvitePtr),
KickPtr: copyPtr(pl.KickPtr),
BanPtr: copyPtr(pl.BanPtr),
RedactPtr: copyPtr(pl.RedactPtr),
}
}
@@ -122,13 +120,6 @@ func (pl *PowerLevelsEventContent) Redact() int {
return 50
}
func (pl *PowerLevelsEventContent) Historical() int {
if pl.HistoricalPtr != nil {
return *pl.HistoricalPtr
}
return 100
}
func (pl *PowerLevelsEventContent) StateDefault() int {
if pl.StateDefaultPtr != nil {
return *pl.StateDefaultPtr

View File

@@ -11,8 +11,6 @@ import (
"regexp"
"strings"
"golang.org/x/net/html"
"maunium.net/go/mautrix/id"
)
@@ -23,7 +21,7 @@ func TrimReplyFallbackHTML(html string) string {
}
func TrimReplyFallbackText(text string) string {
if !strings.HasPrefix(text, "> ") || !strings.Contains(text, "\n") {
if (!strings.HasPrefix(text, "> <") && !strings.HasPrefix(text, "> * <")) || !strings.Contains(text, "\n") {
return text
}
@@ -59,7 +57,7 @@ func (evt *Event) GenerateReplyFallbackHTML() string {
parsedContent.RemoveReplyFallback()
body := parsedContent.FormattedBody
if len(body) == 0 {
body = strings.ReplaceAll(html.EscapeString(parsedContent.Body), "\n", "<br/>")
body = TextToHTML(parsedContent.Body)
}
senderDisplayName := evt.Sender

View File

@@ -170,6 +170,7 @@ type ModPolicyContent struct {
Recommendation string `json:"recommendation"`
}
// Deprecated: MSC2716 has been abandoned
type InsertionMarkerContent struct {
InsertionID id.EventID `json:"org.matrix.msc2716.marker.insertion"`
Timestamp int64 `json:"com.beeper.timestamp,omitempty"`

View File

@@ -187,7 +187,9 @@ var (
StateHalfShotBridge = Type{"uk.half-shot.bridge", StateEventType}
StateSpaceChild = Type{"m.space.child", StateEventType}
StateSpaceParent = Type{"m.space.parent", StateEventType}
StateInsertionMarker = Type{"org.matrix.msc2716.marker", StateEventType}
// Deprecated: MSC2716 has been abandoned
StateInsertionMarker = Type{"org.matrix.msc2716.marker", StateEventType}
)
// Message events

View File

@@ -72,7 +72,7 @@ func (userID UserID) URI() *MatrixURI {
}
}
var ValidLocalpartRegex = regexp.MustCompile("^[0-9a-z-.=_/]+$")
var ValidLocalpartRegex = regexp.MustCompile("^[0-9a-z-.=_/+]+$")
// ValidateUserLocalpart validates a Matrix user ID localpart using the grammar
// in https://matrix.org/docs/spec/appendices#user-identifier
@@ -132,7 +132,7 @@ func escape(buf *bytes.Buffer, b byte) {
}
func shouldEncode(b byte) bool {
return b != '-' && b != '.' && b != '_' && !(b >= '0' && b <= '9') && !(b >= 'a' && b <= 'z') && !(b >= 'A' && b <= 'Z')
return b != '-' && b != '.' && b != '_' && b != '+' && !(b >= '0' && b <= '9') && !(b >= 'a' && b <= 'z') && !(b >= 'A' && b <= 'Z')
}
func shouldEscape(b byte) bool {
@@ -140,7 +140,7 @@ func shouldEscape(b byte) bool {
}
func isValidByte(b byte) bool {
return isValidEscapedChar(b) || (b >= '0' && b <= '9') || b == '.' || b == '=' || b == '-'
return isValidEscapedChar(b) || (b >= '0' && b <= '9') || b == '.' || b == '=' || b == '-' || b == '+'
}
func isValidEscapedChar(b byte) bool {
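
A quick check of the relaxed localpart grammar, assuming `ValidateUserLocalpart` returns nil for a valid localpart as its doc comment above implies:

```go
package main

import (
	"fmt"

	"maunium.net/go/mautrix/id"
)

func main() {
	// '+' is now part of the allowed localpart grammar per MSC4009.
	fmt.Println(id.ValidLocalpartRegex.MatchString("alice+sms")) // true
	fmt.Println(id.ValidateUserLocalpart("alice+sms") == nil)    // true if the localpart is valid
}
```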

View File

@@ -23,6 +23,8 @@ const (
AuthTypeAppservice AuthType = "m.login.application_service"
AuthTypeSynapseJWT AuthType = "org.matrix.login.jwt"
AuthTypeDevtureSharedSecret AuthType = "com.devture.shared_secret_auth"
)
type IdentifierType string
@@ -331,6 +333,7 @@ type ReqPutPushRule struct {
Pattern string `json:"pattern"`
}
// Deprecated: MSC2716 was abandoned
type ReqBatchSend struct {
PrevEventID id.EventID `json:"-"`
BatchID id.BatchID `json:"-"`
@@ -342,6 +345,16 @@ type ReqBatchSend struct {
Events []*event.Event `json:"events"`
}
type ReqBeeperBatchSend struct {
// ForwardIfNoMessages should be set to true if the batch should be forward
// backfilled if there are no messages currently in the room.
ForwardIfNoMessages bool `json:"forward_if_no_messages"`
Forward bool `json:"forward"`
SendNotification bool `json:"send_notification"`
MarkReadBy id.UserID `json:"mark_read_by,omitempty"`
Events []*event.Event `json:"events"`
}
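
A hedged sketch of calling the new Beeper batch send endpoint; the client setup, room ID, and event list are placeholders, and real backfill payloads are constructed by the bridge module:

```go
package main

import (
	"fmt"

	"maunium.net/go/mautrix"
	"maunium.net/go/mautrix/event"
	"maunium.net/go/mautrix/id"
)

func main() {
	cli, err := mautrix.NewClient("https://matrix.example.org", "@bridgebot:example.org", "syt_placeholder")
	if err != nil {
		panic(err)
	}
	req := &mautrix.ReqBeeperBatchSend{
		Forward:             true,
		ForwardIfNoMessages: true,
		SendNotification:    false,
		Events:              []*event.Event{}, // historical events would go here
	}
	resp, err := cli.BeeperBatchSend(id.RoomID("!room:example.org"), req)
	if err != nil {
		panic(err)
	}
	fmt.Println("sent", len(resp.EventIDs), "events")
}
```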
type ReqSetReadMarkers struct {
Read id.EventID `json:"m.read,omitempty"`
ReadPrivate id.EventID `json:"m.read.private,omitempty"`

View File

@@ -3,16 +3,17 @@ package mautrix
import (
"bytes"
"encoding/json"
"fmt"
"reflect"
"strconv"
"strings"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
"go.mau.fi/util/jsontime"
"maunium.net/go/mautrix/event"
"maunium.net/go/mautrix/id"
"maunium.net/go/mautrix/util"
"maunium.net/go/mautrix/util/jsontime"
)
// RespWhoami is the JSON response for https://spec.matrix.org/v1.2/client-server-api/#get_matrixclientv3accountwhoami
@@ -270,8 +271,26 @@ type marshalableRespSync RespSync
var syncPathsToDelete = []string{"account_data", "presence", "to_device", "device_lists", "device_one_time_keys_count", "rooms"}
// marshalAndDeleteEmpty marshals a JSON object, then checks the given gjson paths with gjson and deletes any empty objects there with sjson.
func marshalAndDeleteEmpty(marshalable interface{}, paths []string) ([]byte, error) {
data, err := json.Marshal(marshalable)
if err != nil {
return nil, err
}
for _, path := range paths {
res := gjson.GetBytes(data, path)
if res.IsObject() && len(res.Raw) == 2 {
data, err = sjson.DeleteBytes(data, path)
if err != nil {
return nil, fmt.Errorf("failed to delete empty %s: %w", path, err)
}
}
}
return data, nil
}
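
For intuition, the helper's effect on a payload with an empty section can be reproduced with the same gjson/sjson calls; the sample JSON is illustrative:

```go
package main

import (
	"fmt"

	"github.com/tidwall/gjson"
	"github.com/tidwall/sjson"
)

func main() {
	data := []byte(`{"next_batch":"s1","rooms":{},"presence":{"events":[]}}`)
	for _, path := range []string{"rooms", "presence"} {
		res := gjson.GetBytes(data, path)
		// len(res.Raw) == 2 means the value is exactly "{}".
		if res.IsObject() && len(res.Raw) == 2 {
			data, _ = sjson.DeleteBytes(data, path)
		}
	}
	fmt.Println(string(data)) // {"next_batch":"s1","presence":{"events":[]}}
}
```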
func (rs *RespSync) MarshalJSON() ([]byte, error) {
return util.MarshalAndDeleteEmpty((*marshalableRespSync)(rs), syncPathsToDelete)
return marshalAndDeleteEmpty((*marshalableRespSync)(rs), syncPathsToDelete)
}
type DeviceLists struct {
@@ -299,7 +318,7 @@ type marshalableSyncLeftRoom SyncLeftRoom
var syncLeftRoomPathsToDelete = []string{"summary", "state", "timeline"}
func (slr SyncLeftRoom) MarshalJSON() ([]byte, error) {
return util.MarshalAndDeleteEmpty((marshalableSyncLeftRoom)(slr), syncLeftRoomPathsToDelete)
return marshalAndDeleteEmpty((marshalableSyncLeftRoom)(slr), syncLeftRoomPathsToDelete)
}
type SyncJoinedRoom struct {
@@ -324,7 +343,7 @@ type marshalableSyncJoinedRoom SyncJoinedRoom
var syncJoinedRoomPathsToDelete = []string{"summary", "state", "timeline", "ephemeral", "account_data"}
func (sjr SyncJoinedRoom) MarshalJSON() ([]byte, error) {
return util.MarshalAndDeleteEmpty((marshalableSyncJoinedRoom)(sjr), syncJoinedRoomPathsToDelete)
return marshalAndDeleteEmpty((marshalableSyncJoinedRoom)(sjr), syncJoinedRoomPathsToDelete)
}
type SyncInvitedRoom struct {
@@ -337,7 +356,7 @@ type marshalableSyncInvitedRoom SyncInvitedRoom
var syncInvitedRoomPathsToDelete = []string{"summary"}
func (sir SyncInvitedRoom) MarshalJSON() ([]byte, error) {
return util.MarshalAndDeleteEmpty((marshalableSyncInvitedRoom)(sir), syncInvitedRoomPathsToDelete)
return marshalAndDeleteEmpty((marshalableSyncInvitedRoom)(sir), syncInvitedRoomPathsToDelete)
}
type SyncKnockedRoom struct {
@@ -402,6 +421,7 @@ type RespDeviceInfo struct {
LastSeenTS int64 `json:"last_seen_ts"`
}
// Deprecated: MSC2716 was abandoned
type RespBatchSend struct {
StateEventIDs []id.EventID `json:"state_event_ids"`
EventIDs []id.EventID `json:"event_ids"`
@@ -413,6 +433,10 @@ type RespBatchSend struct {
NextBatchID id.BatchID `json:"next_batch_id"`
}
type RespBeeperBatchSend struct {
EventIDs []id.EventID `json:"event_ids"`
}
// RespCapabilities is the JSON response for https://spec.matrix.org/v1.3/client-server-api/#get_matrixclientv3capabilities
type RespCapabilities struct {
RoomVersions *CapRoomVersions `json:"m.room_versions,omitempty"`

View File

@@ -15,9 +15,10 @@ import (
"strconv"
"strings"
"go.mau.fi/util/dbutil"
"maunium.net/go/mautrix/event"
"maunium.net/go/mautrix/id"
"maunium.net/go/mautrix/util/dbutil"
)
//go:embed *.sql

View File

@@ -3,7 +3,7 @@ package sqlstatestore
import (
"fmt"
"maunium.net/go/mautrix/util/dbutil"
"go.mau.fi/util/dbutil"
)
func init() {

View File

@@ -1,9 +0,0 @@
base58
==========
This is a copy of <https://github.com/btcsuite/btcd/tree/master/btcutil/base58>.
## License
Package base58 is licensed under the [copyfree](http://copyfree.org) ISC
License.

View File

@@ -1,49 +0,0 @@
// Copyright (c) 2015 The btcsuite developers
// Use of this source code is governed by an ISC
// license that can be found in the LICENSE file.
// AUTOGENERATED by genalphabet.go; do not edit.
package base58
const (
// alphabet is the modified base58 alphabet used by Bitcoin.
alphabet = "123456789ABCDEFGHJKLMNPQRSTUVWXYZabcdefghijkmnopqrstuvwxyz"
alphabetIdx0 = '1'
)
var b58 = [256]byte{
255, 255, 255, 255, 255, 255, 255, 255,
255, 255, 255, 255, 255, 255, 255, 255,
255, 255, 255, 255, 255, 255, 255, 255,
255, 255, 255, 255, 255, 255, 255, 255,
255, 255, 255, 255, 255, 255, 255, 255,
255, 255, 255, 255, 255, 255, 255, 255,
255, 0, 1, 2, 3, 4, 5, 6,
7, 8, 255, 255, 255, 255, 255, 255,
255, 9, 10, 11, 12, 13, 14, 15,
16, 255, 17, 18, 19, 20, 21, 255,
22, 23, 24, 25, 26, 27, 28, 29,
30, 31, 32, 255, 255, 255, 255, 255,
255, 33, 34, 35, 36, 37, 38, 39,
40, 41, 42, 43, 255, 44, 45, 46,
47, 48, 49, 50, 51, 52, 53, 54,
55, 56, 57, 255, 255, 255, 255, 255,
255, 255, 255, 255, 255, 255, 255, 255,
255, 255, 255, 255, 255, 255, 255, 255,
255, 255, 255, 255, 255, 255, 255, 255,
255, 255, 255, 255, 255, 255, 255, 255,
255, 255, 255, 255, 255, 255, 255, 255,
255, 255, 255, 255, 255, 255, 255, 255,
255, 255, 255, 255, 255, 255, 255, 255,
255, 255, 255, 255, 255, 255, 255, 255,
255, 255, 255, 255, 255, 255, 255, 255,
255, 255, 255, 255, 255, 255, 255, 255,
255, 255, 255, 255, 255, 255, 255, 255,
255, 255, 255, 255, 255, 255, 255, 255,
255, 255, 255, 255, 255, 255, 255, 255,
255, 255, 255, 255, 255, 255, 255, 255,
255, 255, 255, 255, 255, 255, 255, 255,
255, 255, 255, 255, 255, 255, 255, 255,
}

View File

@@ -1,138 +0,0 @@
// Copyright (c) 2013-2015 The btcsuite developers
// Use of this source code is governed by an ISC
// license that can be found in the LICENSE file.
package base58
import (
"math/big"
)
//go:generate go run genalphabet.go
var bigRadix = [...]*big.Int{
big.NewInt(0),
big.NewInt(58),
big.NewInt(58 * 58),
big.NewInt(58 * 58 * 58),
big.NewInt(58 * 58 * 58 * 58),
big.NewInt(58 * 58 * 58 * 58 * 58),
big.NewInt(58 * 58 * 58 * 58 * 58 * 58),
big.NewInt(58 * 58 * 58 * 58 * 58 * 58 * 58),
big.NewInt(58 * 58 * 58 * 58 * 58 * 58 * 58 * 58),
big.NewInt(58 * 58 * 58 * 58 * 58 * 58 * 58 * 58 * 58),
bigRadix10,
}
var bigRadix10 = big.NewInt(58 * 58 * 58 * 58 * 58 * 58 * 58 * 58 * 58 * 58) // 58^10
// Decode decodes a modified base58 string to a byte slice.
func Decode(b string) []byte {
answer := big.NewInt(0)
scratch := new(big.Int)
// Calculating with big.Int is slow for each iteration.
// x += b58[b[i]] * j
// j *= 58
//
// Instead we can try to do as much calculations on int64.
// We can represent a 10 digit base58 number using an int64.
//
// Hence we'll try to convert 10, base58 digits at a time.
// The rough idea is to calculate `t`, such that:
//
// t := b58[b[i+9]] * 58^9 ... + b58[b[i+1]] * 58^1 + b58[b[i]] * 58^0
// x *= 58^10
// x += t
//
// Of course, in addition, we'll need to handle boundary condition when `b` is not multiple of 58^10.
// In that case we'll use the bigRadix[n] lookup for the appropriate power.
for t := b; len(t) > 0; {
n := len(t)
if n > 10 {
n = 10
}
total := uint64(0)
for _, v := range t[:n] {
tmp := b58[v]
if tmp == 255 {
return []byte("")
}
total = total*58 + uint64(tmp)
}
answer.Mul(answer, bigRadix[n])
scratch.SetUint64(total)
answer.Add(answer, scratch)
t = t[n:]
}
tmpval := answer.Bytes()
var numZeros int
for numZeros = 0; numZeros < len(b); numZeros++ {
if b[numZeros] != alphabetIdx0 {
break
}
}
flen := numZeros + len(tmpval)
val := make([]byte, flen)
copy(val[numZeros:], tmpval)
return val
}
// Encode encodes a byte slice to a modified base58 string.
func Encode(b []byte) string {
x := new(big.Int)
x.SetBytes(b)
// maximum length of output is log58(2^(8*len(b))) == len(b) * 8 / log(58)
maxlen := int(float64(len(b))*1.365658237309761) + 1
answer := make([]byte, 0, maxlen)
mod := new(big.Int)
for x.Sign() > 0 {
// Calculating with big.Int is slow for each iteration.
// x, mod = x / 58, x % 58
//
// Instead we can try to do as much calculations on int64.
// x, mod = x / 58^10, x % 58^10
//
// Which will give us mod, which is 10 digit base58 number.
// We'll loop that 10 times to convert to the answer.
x.DivMod(x, bigRadix10, mod)
if x.Sign() == 0 {
// When x = 0, we need to ensure we don't add any extra zeros.
m := mod.Int64()
for m > 0 {
answer = append(answer, alphabet[m%58])
m /= 58
}
} else {
m := mod.Int64()
for i := 0; i < 10; i++ {
answer = append(answer, alphabet[m%58])
m /= 58
}
}
}
// leading zero bytes
for _, i := range b {
if i != 0 {
break
}
answer = append(answer, alphabetIdx0)
}
// reverse
alen := len(answer)
for i := 0; i < alen/2; i++ {
answer[i], answer[alen-1-i] = answer[alen-1-i], answer[i]
}
return string(answer)
}

View File

@@ -1,52 +0,0 @@
// Copyright (c) 2013-2014 The btcsuite developers
// Use of this source code is governed by an ISC
// license that can be found in the LICENSE file.
package base58
import (
"crypto/sha256"
"errors"
)
// ErrChecksum indicates that the checksum of a check-encoded string does not verify against
// the checksum.
var ErrChecksum = errors.New("checksum error")
// ErrInvalidFormat indicates that the check-encoded string has an invalid format.
var ErrInvalidFormat = errors.New("invalid format: version and/or checksum bytes missing")
// checksum: first four bytes of sha256^2
func checksum(input []byte) (cksum [4]byte) {
h := sha256.Sum256(input)
h2 := sha256.Sum256(h[:])
copy(cksum[:], h2[:4])
return
}
// CheckEncode prepends a version byte and appends a four byte checksum.
func CheckEncode(input []byte, version byte) string {
b := make([]byte, 0, 1+len(input)+4)
b = append(b, version)
b = append(b, input...)
cksum := checksum(b)
b = append(b, cksum[:]...)
return Encode(b)
}
// CheckDecode decodes a string that was encoded with CheckEncode and verifies the checksum.
func CheckDecode(input string) (result []byte, version byte, err error) {
decoded := Decode(input)
if len(decoded) < 5 {
return nil, 0, ErrInvalidFormat
}
version = decoded[0]
var cksum [4]byte
copy(cksum[:], decoded[len(decoded)-4:])
if checksum(decoded[:len(decoded)-4]) != cksum {
return nil, 0, ErrChecksum
}
payload := decoded[1 : len(decoded)-4]
result = append(result, payload...)
return
}

View File

@@ -1,29 +0,0 @@
// Copyright (c) 2014 The btcsuite developers
// Use of this source code is governed by an ISC
// license that can be found in the LICENSE file.
/*
Package base58 provides an API for working with modified base58 and Base58Check
encodings.
# Modified Base58 Encoding
Standard base58 encoding is similar to standard base64 encoding except, as the
name implies, it uses a 58 character alphabet which results in an alphanumeric
string and allows some characters which are problematic for humans to be
excluded. Due to this, there can be various base58 alphabets.
The modified base58 alphabet used by Bitcoin, and hence this package, omits the
0, O, I, and l characters that look the same in many fonts and are therefore
hard for humans to distinguish.
# Base58Check Encoding Scheme
The Base58Check encoding scheme is primarily used for Bitcoin addresses at the
time of this writing, however it can be used to generically encode arbitrary
byte arrays into human-readable strings along with a version byte that can be
used to differentiate the same payload. For Bitcoin addresses, the extra
version is used to differentiate the network of otherwise identical public keys
which helps prevent using an address intended for one network on another.
*/
package base58

View File

@@ -1,28 +0,0 @@
// Copyright (c) 2023 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
package util
import (
"fmt"
"runtime"
"strings"
)
// CallerWithFunctionName is an implementation for zerolog.CallerMarshalFunc that includes the caller function name
// in addition to the file and line number.
//
// Use as
//
// zerolog.CallerMarshalFunc = util.CallerWithFunctionName
func CallerWithFunctionName(pc uintptr, file string, line int) string {
files := strings.Split(file, "/")
file = files[len(files)-1]
name := runtime.FuncForPC(pc).Name()
fns := strings.Split(name, ".")
name = fns[len(fns)-1]
return fmt.Sprintf("%s:%d:%s()", file, line, name)
}

View File

@@ -1,189 +0,0 @@
// Copyright (c) 2022 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
package dbutil
import (
"context"
"database/sql"
"time"
)
// LoggingExecable is a wrapper for anything with database Exec methods (i.e. sql.Conn, sql.DB and sql.Tx)
// that can preprocess queries (e.g. replacing $ with ? on SQLite) and log query durations.
type LoggingExecable struct {
UnderlyingExecable UnderlyingExecable
db *Database
}
func (le *LoggingExecable) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) {
start := time.Now()
query = le.db.mutateQuery(query)
res, err := le.UnderlyingExecable.ExecContext(ctx, query, args...)
le.db.Log.QueryTiming(ctx, "Exec", query, args, -1, time.Since(start), err)
return res, err
}
func (le *LoggingExecable) QueryContext(ctx context.Context, query string, args ...interface{}) (Rows, error) {
start := time.Now()
query = le.db.mutateQuery(query)
rows, err := le.UnderlyingExecable.QueryContext(ctx, query, args...)
le.db.Log.QueryTiming(ctx, "Query", query, args, -1, time.Since(start), err)
return &LoggingRows{
ctx: ctx,
db: le.db,
query: query,
args: args,
rs: rows,
start: start,
}, err
}
func (le *LoggingExecable) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row {
start := time.Now()
query = le.db.mutateQuery(query)
row := le.UnderlyingExecable.QueryRowContext(ctx, query, args...)
le.db.Log.QueryTiming(ctx, "QueryRow", query, args, -1, time.Since(start), nil)
return row
}
func (le *LoggingExecable) Exec(query string, args ...interface{}) (sql.Result, error) {
return le.ExecContext(context.Background(), query, args...)
}
func (le *LoggingExecable) Query(query string, args ...interface{}) (Rows, error) {
return le.QueryContext(context.Background(), query, args...)
}
func (le *LoggingExecable) QueryRow(query string, args ...interface{}) *sql.Row {
return le.QueryRowContext(context.Background(), query, args...)
}
// loggingDB is a wrapper for LoggingExecable that allows access to BeginTx.
//
// While LoggingExecable has a pointer to the database and could use BeginTx, it's not technically safe since
// the LoggingExecable could be for a transaction (where BeginTx wouldn't make sense).
type loggingDB struct {
LoggingExecable
}
func (ld *loggingDB) BeginTx(ctx context.Context, opts *sql.TxOptions) (*LoggingTxn, error) {
targetDB := ld.db.RawDB
if opts != nil && opts.ReadOnly && ld.db.ReadOnlyDB != nil {
targetDB = ld.db.ReadOnlyDB
}
start := time.Now()
tx, err := targetDB.BeginTx(ctx, opts)
ld.db.Log.QueryTiming(ctx, "Begin", "", nil, -1, time.Since(start), err)
if err != nil {
return nil, err
}
return &LoggingTxn{
LoggingExecable: LoggingExecable{UnderlyingExecable: tx, db: ld.db},
UnderlyingTx: tx,
ctx: ctx,
StartTime: start,
}, nil
}
func (ld *loggingDB) Begin() (*LoggingTxn, error) {
return ld.BeginTx(context.Background(), nil)
}
type LoggingTxn struct {
LoggingExecable
UnderlyingTx *sql.Tx
ctx context.Context
StartTime time.Time
EndTime time.Time
noTotalLog bool
}
func (lt *LoggingTxn) Commit() error {
start := time.Now()
err := lt.UnderlyingTx.Commit()
lt.EndTime = time.Now()
if !lt.noTotalLog {
lt.db.Log.QueryTiming(lt.ctx, "<Transaction>", "", nil, -1, lt.EndTime.Sub(lt.StartTime), nil)
}
lt.db.Log.QueryTiming(lt.ctx, "Commit", "", nil, -1, time.Since(start), err)
return err
}
func (lt *LoggingTxn) Rollback() error {
start := time.Now()
err := lt.UnderlyingTx.Rollback()
lt.EndTime = time.Now()
if !lt.noTotalLog {
lt.db.Log.QueryTiming(lt.ctx, "<Transaction>", "", nil, -1, lt.EndTime.Sub(lt.StartTime), nil)
}
lt.db.Log.QueryTiming(lt.ctx, "Rollback", "", nil, -1, time.Since(start), err)
return err
}
type LoggingRows struct {
ctx context.Context
db *Database
query string
args []interface{}
rs Rows
start time.Time
nrows int
}
func (lrs *LoggingRows) stopTiming() {
if !lrs.start.IsZero() {
lrs.db.Log.QueryTiming(lrs.ctx, "EndRows", lrs.query, lrs.args, lrs.nrows, time.Since(lrs.start), lrs.rs.Err())
lrs.start = time.Time{}
}
}
func (lrs *LoggingRows) Close() error {
err := lrs.rs.Close()
lrs.stopTiming()
return err
}
func (lrs *LoggingRows) ColumnTypes() ([]*sql.ColumnType, error) {
return lrs.rs.ColumnTypes()
}
func (lrs *LoggingRows) Columns() ([]string, error) {
return lrs.rs.Columns()
}
func (lrs *LoggingRows) Err() error {
return lrs.rs.Err()
}
func (lrs *LoggingRows) Next() bool {
hasNext := lrs.rs.Next()
if !hasNext {
lrs.stopTiming()
} else {
lrs.nrows++
}
return hasNext
}
func (lrs *LoggingRows) NextResultSet() bool {
hasNext := lrs.rs.NextResultSet()
if !hasNext {
lrs.stopTiming()
} else {
lrs.nrows++
}
return hasNext
}
func (lrs *LoggingRows) Scan(dest ...any) error {
return lrs.rs.Scan(dest...)
}

View File

@@ -1,289 +0,0 @@
// Copyright (c) 2022 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
package dbutil
import (
"context"
"database/sql"
"fmt"
"net/url"
"regexp"
"strings"
"time"
)
type Dialect int
const (
DialectUnknown Dialect = iota
Postgres
SQLite
)
func (dialect Dialect) String() string {
switch dialect {
case Postgres:
return "postgres"
case SQLite:
return "sqlite3"
default:
return ""
}
}
func ParseDialect(engine string) (Dialect, error) {
engine = strings.ToLower(engine)
if strings.HasPrefix(engine, "postgres") || engine == "pgx" {
return Postgres, nil
} else if strings.HasPrefix(engine, "sqlite") || strings.HasPrefix(engine, "litestream") {
return SQLite, nil
} else {
return DialectUnknown, fmt.Errorf("unknown dialect '%s'", engine)
}
}
type Rows interface {
Close() error
ColumnTypes() ([]*sql.ColumnType, error)
Columns() ([]string, error)
Err() error
Next() bool
NextResultSet() bool
Scan(...any) error
}
type Scannable interface {
Scan(...interface{}) error
}
// Expected implementations of Scannable
var (
_ Scannable = (*sql.Row)(nil)
_ Scannable = (Rows)(nil)
)
type UnderlyingContextExecable interface {
ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error)
QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error)
QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row
}
type ContextExecable interface {
ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error)
QueryContext(ctx context.Context, query string, args ...interface{}) (Rows, error)
QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row
}
type UnderlyingExecable interface {
UnderlyingContextExecable
Exec(query string, args ...interface{}) (sql.Result, error)
Query(query string, args ...interface{}) (*sql.Rows, error)
QueryRow(query string, args ...interface{}) *sql.Row
}
type Execable interface {
ContextExecable
Exec(query string, args ...interface{}) (sql.Result, error)
Query(query string, args ...interface{}) (Rows, error)
QueryRow(query string, args ...interface{}) *sql.Row
}
type Transaction interface {
Execable
Commit() error
Rollback() error
}
// Expected implementations of Execable
var (
_ UnderlyingExecable = (*sql.Tx)(nil)
_ UnderlyingExecable = (*sql.DB)(nil)
_ Execable = (*LoggingExecable)(nil)
_ Transaction = (*LoggingTxn)(nil)
_ UnderlyingContextExecable = (*sql.Conn)(nil)
)
type Database struct {
loggingDB
RawDB *sql.DB
ReadOnlyDB *sql.DB
Owner string
VersionTable string
Log DatabaseLogger
Dialect Dialect
UpgradeTable UpgradeTable
IgnoreForeignTables bool
IgnoreUnsupportedDatabase bool
}
var positionalParamPattern = regexp.MustCompile(`\$(\d+)`)
func (db *Database) mutateQuery(query string) string {
switch db.Dialect {
case SQLite:
return positionalParamPattern.ReplaceAllString(query, "?$1")
default:
return query
}
}
func (db *Database) Child(versionTable string, upgradeTable UpgradeTable, log DatabaseLogger) *Database {
if log == nil {
log = db.Log
}
return &Database{
RawDB: db.RawDB,
loggingDB: db.loggingDB,
Owner: "",
VersionTable: versionTable,
UpgradeTable: upgradeTable,
Log: log,
Dialect: db.Dialect,
IgnoreForeignTables: true,
IgnoreUnsupportedDatabase: db.IgnoreUnsupportedDatabase,
}
}
func NewWithDB(db *sql.DB, rawDialect string) (*Database, error) {
dialect, err := ParseDialect(rawDialect)
if err != nil {
return nil, err
}
wrappedDB := &Database{
RawDB: db,
Dialect: dialect,
Log: NoopLogger,
IgnoreForeignTables: true,
VersionTable: "version",
}
wrappedDB.loggingDB.UnderlyingExecable = db
wrappedDB.loggingDB.db = wrappedDB
return wrappedDB, nil
}
func NewWithDialect(uri, rawDialect string) (*Database, error) {
db, err := sql.Open(rawDialect, uri)
if err != nil {
return nil, err
}
return NewWithDB(db, rawDialect)
}
type PoolConfig struct {
Type string `yaml:"type"`
URI string `yaml:"uri"`
MaxOpenConns int `yaml:"max_open_conns"`
MaxIdleConns int `yaml:"max_idle_conns"`
ConnMaxIdleTime string `yaml:"conn_max_idle_time"`
ConnMaxLifetime string `yaml:"conn_max_lifetime"`
}
type Config struct {
PoolConfig `yaml:",inline"`
ReadOnlyPool PoolConfig `yaml:"ro_pool"`
}
func (db *Database) Close() error {
err := db.RawDB.Close()
if db.ReadOnlyDB != nil {
err2 := db.ReadOnlyDB.Close()
if err == nil {
err = fmt.Errorf("closing read-only db failed: %w", err)
} else {
err = fmt.Errorf("%w (closing read-only db also failed: %v)", err, err2)
}
}
return err
}
func (db *Database) Configure(cfg Config) error {
if err := db.configure(db.ReadOnlyDB, cfg.ReadOnlyPool); err != nil {
return err
}
return db.configure(db.RawDB, cfg.PoolConfig)
}
func (db *Database) configure(rawDB *sql.DB, cfg PoolConfig) error {
if rawDB == nil {
return nil
}
rawDB.SetMaxOpenConns(cfg.MaxOpenConns)
rawDB.SetMaxIdleConns(cfg.MaxIdleConns)
if len(cfg.ConnMaxIdleTime) > 0 {
maxIdleTimeDuration, err := time.ParseDuration(cfg.ConnMaxIdleTime)
if err != nil {
return fmt.Errorf("failed to parse max_conn_idle_time: %w", err)
}
rawDB.SetConnMaxIdleTime(maxIdleTimeDuration)
}
if len(cfg.ConnMaxLifetime) > 0 {
maxLifetimeDuration, err := time.ParseDuration(cfg.ConnMaxLifetime)
if err != nil {
return fmt.Errorf("failed to parse max_conn_idle_time: %w", err)
}
rawDB.SetConnMaxLifetime(maxLifetimeDuration)
}
return nil
}
func NewFromConfig(owner string, cfg Config, logger DatabaseLogger) (*Database, error) {
wrappedDB, err := NewWithDialect(cfg.URI, cfg.Type)
if err != nil {
return nil, err
}
wrappedDB.Owner = owner
if logger != nil {
wrappedDB.Log = logger
}
if cfg.ReadOnlyPool.MaxOpenConns > 0 {
if cfg.ReadOnlyPool.Type == "" {
cfg.ReadOnlyPool.Type = cfg.Type
}
roUri := cfg.ReadOnlyPool.URI
if roUri == "" {
uriParts := strings.Split(cfg.URI, "?")
var qs url.Values
if len(uriParts) == 2 {
var err error
qs, err = url.ParseQuery(uriParts[1])
if err != nil {
return nil, err
}
qs.Del("_txlock")
}
qs.Set("_query_only", "true")
roUri = uriParts[0] + "?" + qs.Encode()
}
wrappedDB.ReadOnlyDB, err = sql.Open(cfg.ReadOnlyPool.Type, roUri)
if err != nil {
return nil, err
}
}
err = wrappedDB.Configure(cfg)
if err != nil {
return nil, err
}
return wrappedDB, nil
}

View File

@@ -1,157 +0,0 @@
package dbutil
import (
"context"
"regexp"
"strings"
"time"
"github.com/rs/zerolog"
"maunium.net/go/maulogger/v2"
)
type DatabaseLogger interface {
QueryTiming(ctx context.Context, method, query string, args []interface{}, nrows int, duration time.Duration, err error)
WarnUnsupportedVersion(current, compat, latest int)
PrepareUpgrade(current, compat, latest int)
DoUpgrade(from, to int, message string, txn bool)
// Deprecated: legacy warning method, return errors instead
Warn(msg string, args ...interface{})
}
type noopLogger struct{}
var NoopLogger DatabaseLogger = &noopLogger{}
func (n noopLogger) WarnUnsupportedVersion(_, _, _ int) {}
func (n noopLogger) PrepareUpgrade(_, _, _ int) {}
func (n noopLogger) DoUpgrade(_, _ int, _ string, _ bool) {}
func (n noopLogger) Warn(msg string, args ...interface{}) {}
func (n noopLogger) QueryTiming(_ context.Context, _, _ string, _ []interface{}, _ int, _ time.Duration, _ error) {
}
type mauLogger struct {
l maulogger.Logger
}
// Deprecated: Use zerolog instead
func MauLogger(log maulogger.Logger) DatabaseLogger {
return &mauLogger{l: log}
}
func (m mauLogger) WarnUnsupportedVersion(current, compat, latest int) {
m.l.Warnfln("Unsupported database schema version: currently on v%d (compatible down to v%d), latest known: v%d - continuing anyway", current, compat, latest)
}
func (m mauLogger) PrepareUpgrade(current, compat, latest int) {
m.l.Infofln("Database currently on v%d (compat: v%d), latest known: v%d", current, compat, latest)
}
func (m mauLogger) DoUpgrade(from, to int, message string, _ bool) {
m.l.Infofln("Upgrading database from v%d to v%d: %s", from, to, message)
}
func (m mauLogger) QueryTiming(_ context.Context, method, query string, _ []interface{}, _ int, duration time.Duration, _ error) {
if duration > 1*time.Second {
m.l.Warnfln("%s(%s) took %.3f seconds", method, query, duration.Seconds())
}
}
func (m mauLogger) Warn(msg string, args ...interface{}) {
m.l.Warnfln(msg, args...)
}
type zeroLogger struct {
l *zerolog.Logger
ZeroLogSettings
}
type ZeroLogSettings struct {
CallerSkipFrame int
Caller bool
}
func ZeroLogger(log zerolog.Logger, cfg ...ZeroLogSettings) DatabaseLogger {
return ZeroLoggerPtr(&log, cfg...)
}
func ZeroLoggerPtr(log *zerolog.Logger, cfg ...ZeroLogSettings) DatabaseLogger {
wrapped := &zeroLogger{l: log}
if len(cfg) > 0 {
wrapped.ZeroLogSettings = cfg[0]
} else {
wrapped.ZeroLogSettings = ZeroLogSettings{
CallerSkipFrame: 2, // Skip LoggingExecable.ExecContext and zeroLogger.QueryTiming
Caller: true,
}
}
return wrapped
}
func (z zeroLogger) WarnUnsupportedVersion(current, compat, latest int) {
z.l.Warn().
Int("current_version", current).
Int("oldest_compatible_version", compat).
Int("latest_known_version", latest).
Msg("Unsupported database schema version, continuing anyway")
}
func (z zeroLogger) PrepareUpgrade(current, compat, latest int) {
evt := z.l.Info().
Int("current_version", current).
Int("oldest_compatible_version", compat).
Int("latest_known_version", latest)
if current >= latest {
evt.Msg("Database is up to date")
} else {
evt.Msg("Preparing to update database schema")
}
}
func (z zeroLogger) DoUpgrade(from, to int, message string, txn bool) {
z.l.Info().
Int("from", from).
Int("to", to).
Bool("single_txn", txn).
Str("description", message).
Msg("Upgrading database")
}
var whitespaceRegex = regexp.MustCompile(`\s+`)
func (z zeroLogger) QueryTiming(ctx context.Context, method, query string, args []interface{}, nrows int, duration time.Duration, err error) {
log := zerolog.Ctx(ctx)
if log.GetLevel() == zerolog.Disabled || log == zerolog.DefaultContextLogger {
log = z.l
}
if log.GetLevel() != zerolog.TraceLevel && duration < 1*time.Second {
return
}
if nrows > -1 {
rowLog := log.With().Int("rows", nrows).Logger()
log = &rowLog
}
query = strings.TrimSpace(whitespaceRegex.ReplaceAllLiteralString(query, " "))
log.Trace().
Err(err).
Int64("duration_µs", duration.Microseconds()).
Str("method", method).
Str("query", query).
Interface("query_args", args).
Msg("Query")
if duration >= 1*time.Second {
evt := log.Warn().
Float64("duration_seconds", duration.Seconds()).
Str("method", method).
Str("query", query)
if z.Caller {
evt = evt.Caller(z.CallerSkipFrame)
}
evt.Msg("Query took long")
}
}
func (z zeroLogger) Warn(msg string, args ...interface{}) {
z.l.Warn().Msgf(msg, args...)
}

View File

@@ -1,21 +0,0 @@
-- v0 -> v3: Sample revision jump
CREATE TABLE foo (
-- only: postgres
key BIGINT PRIMARY KEY GENERATED BY DEFAULT AS IDENTITY,
-- only: sqlite
key INTEGER PRIMARY KEY,
data JSONB NOT NULL
);
-- only: sqlite until "end only"
CREATE TRIGGER test AFTER INSERT ON foo WHEN NEW.data->>'action' = 'delete' BEGIN
DELETE FROM test WHERE key <= NEW.data->>'index';
END;
-- end only sqlite
-- only: postgres until "end only"
CREATE FUNCTION delete_data() RETURNS TRIGGER LANGUAGE plpgsql AS $$ BEGIN
DELETE FROM test WHERE key <= NEW.data->>'index';
RETURN NEW;
END $$;
-- end only postgres

View File

@@ -1,4 +0,0 @@
-- v4: Sample outside transaction
-- transaction: off
INSERT INTO foo VALUES ('meow', '{}');

View File

@@ -1,3 +0,0 @@
-- v5 (compatible with v3+): Sample backwards-compatible upgrade
INSERT INTO foo VALUES ('meow 2', '{}');

View File

@@ -1,11 +0,0 @@
CREATE TABLE foo (
key BIGINT PRIMARY KEY GENERATED BY DEFAULT AS IDENTITY,
data JSONB NOT NULL
);
CREATE FUNCTION delete_data() RETURNS TRIGGER LANGUAGE plpgsql AS $$ BEGIN
DELETE FROM test WHERE key <= NEW.data->>'index';
RETURN NEW;
END $$;
-- end only postgres

View File

@@ -1,10 +0,0 @@
CREATE TABLE foo (
key INTEGER PRIMARY KEY,
data JSONB NOT NULL
);
CREATE TRIGGER test AFTER INSERT ON foo WHEN NEW.data->>'action' = 'delete' BEGIN
DELETE FROM test WHERE key <= NEW.data->>'index';
END;
-- end only sqlite

View File

@@ -1 +0,0 @@
INSERT INTO foo VALUES ('meow', '{}');

View File

@@ -1 +0,0 @@
INSERT INTO foo VALUES ('meow', '{}');

View File

@@ -1 +0,0 @@
INSERT INTO foo VALUES ('meow 2', '{}');

View File

@@ -1 +0,0 @@
INSERT INTO foo VALUES ('meow 2', '{}');

View File

@@ -1,93 +0,0 @@
// Copyright (c) 2023 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
package dbutil
import (
"context"
"database/sql"
"errors"
"fmt"
"time"
"github.com/rs/zerolog"
"maunium.net/go/mautrix/util"
)
var (
ErrTxn = errors.New("transaction")
ErrTxnBegin = fmt.Errorf("%w: begin", ErrTxn)
ErrTxnCommit = fmt.Errorf("%w: commit", ErrTxn)
)
type contextKey int
const (
ContextKeyDatabaseTransaction contextKey = iota
ContextKeyDoTxnCallerSkip
)
func (db *Database) DoTxn(ctx context.Context, opts *sql.TxOptions, fn func(ctx context.Context) error) error {
if ctx.Value(ContextKeyDatabaseTransaction) != nil {
zerolog.Ctx(ctx).Debug().Msg("Already in a transaction, not creating a new one")
return fn(ctx)
}
log := zerolog.Ctx(ctx).With().Str("db_txn_id", util.RandomString(12)).Logger()
start := time.Now()
defer func() {
dur := time.Since(start)
if dur > time.Second {
val := ctx.Value(ContextKeyDoTxnCallerSkip)
callerSkip := 2
if val != nil {
callerSkip += val.(int)
}
log.Warn().
Float64("duration_seconds", dur.Seconds()).
Caller(callerSkip).
Msg("Transaction took a long time")
}
}()
tx, err := db.BeginTx(ctx, opts)
if err != nil {
log.Trace().Err(err).Msg("Failed to begin transaction")
return util.NewDualError(ErrTxnBegin, err)
}
log.Trace().Msg("Transaction started")
tx.noTotalLog = true
ctx = log.WithContext(ctx)
ctx = context.WithValue(ctx, ContextKeyDatabaseTransaction, tx)
err = fn(ctx)
if err != nil {
log.Trace().Err(err).Msg("Database transaction failed, rolling back")
rollbackErr := tx.Rollback()
if rollbackErr != nil {
log.Warn().Err(rollbackErr).Msg("Rollback after transaction error failed")
} else {
log.Trace().Msg("Rollback successful")
}
return err
}
err = tx.Commit()
if err != nil {
log.Trace().Err(err).Msg("Commit failed")
return util.NewDualError(ErrTxnCommit, err)
}
log.Trace().Msg("Commit successful")
return nil
}
func (db *Database) Conn(ctx context.Context) ContextExecable {
if ctx == nil {
return db
}
txn, ok := ctx.Value(ContextKeyDatabaseTransaction).(Transaction)
if ok {
return txn
}
return db
}
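
To illustrate how `DoTxn` and `Conn` work together, here is a minimal usage sketch. The wrapper function name and SQL are illustrative, and it assumes `ContextExecable` exposes `ExecContext` in the usual `database/sql` style (not shown in this diff):

```go
func insertSample(ctx context.Context, db *dbutil.Database) error {
	// DoTxn begins a transaction, stores it in the context, and commits or
	// rolls back depending on the callback's return value. Conn(ctx) returns
	// that same transaction inside the callback, and nested DoTxn calls reuse
	// it instead of opening a second transaction.
	return db.DoTxn(ctx, nil, func(ctx context.Context) error {
		// Assumption: ContextExecable has ExecContext mirroring database/sql.
		_, err := db.Conn(ctx).ExecContext(ctx, "INSERT INTO foo (data) VALUES ($1)", "{}")
		return err
	})
}
```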

View File

@@ -1,218 +0,0 @@
// Copyright (c) 2023 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
package dbutil
import (
"database/sql"
"errors"
"fmt"
)
type upgradeFunc func(Execable, *Database) error
type upgrade struct {
message string
fn upgradeFunc
upgradesTo int
compatVersion int
transaction bool
}
var ErrUnsupportedDatabaseVersion = errors.New("unsupported database schema version")
var ErrForeignTables = errors.New("the database contains foreign tables")
var ErrNotOwned = errors.New("the database is owned by")
var ErrUnsupportedDialect = errors.New("unsupported database dialect")
func (db *Database) upgradeVersionTable() error {
if compatColumnExists, err := db.ColumnExists(nil, db.VersionTable, "compat"); err != nil {
return fmt.Errorf("failed to check if version table is up to date: %w", err)
} else if !compatColumnExists {
if tableExists, err := db.TableExists(nil, db.VersionTable); err != nil {
return fmt.Errorf("failed to check if version table exists: %w", err)
} else if !tableExists {
_, err = db.Exec(fmt.Sprintf("CREATE TABLE %s (version INTEGER, compat INTEGER)", db.VersionTable))
if err != nil {
return fmt.Errorf("failed to create version table: %w", err)
}
} else {
_, err = db.Exec(fmt.Sprintf("ALTER TABLE %s ADD COLUMN compat INTEGER", db.VersionTable))
if err != nil {
return fmt.Errorf("failed to add compat column to version table: %w", err)
}
}
}
return nil
}
func (db *Database) getVersion() (version, compat int, err error) {
if err = db.upgradeVersionTable(); err != nil {
return
}
var compatNull sql.NullInt32
err = db.QueryRow(fmt.Sprintf("SELECT version, compat FROM %s LIMIT 1", db.VersionTable)).Scan(&version, &compatNull)
if errors.Is(err, sql.ErrNoRows) {
err = nil
}
if compatNull.Valid && compatNull.Int32 != 0 {
compat = int(compatNull.Int32)
} else {
compat = version
}
return
}
const (
tableExistsPostgres = "SELECT EXISTS(SELECT 1 FROM information_schema.tables WHERE table_name=$1)"
tableExistsSQLite = "SELECT EXISTS(SELECT 1 FROM sqlite_master WHERE type='table' AND tbl_name=?1)"
)
func (db *Database) TableExists(tx Execable, table string) (exists bool, err error) {
if tx == nil {
tx = db
}
switch db.Dialect {
case SQLite:
err = db.QueryRow(tableExistsSQLite, table).Scan(&exists)
case Postgres:
err = db.QueryRow(tableExistsPostgres, table).Scan(&exists)
default:
err = ErrUnsupportedDialect
}
return
}
const (
columnExistsPostgres = "SELECT EXISTS(SELECT 1 FROM information_schema.columns WHERE table_name=$1 AND column_name=$2)"
columnExistsSQLite = "SELECT EXISTS(SELECT 1 FROM pragma_table_info(?1) WHERE name=?2)"
)
func (db *Database) ColumnExists(tx Execable, table, column string) (exists bool, err error) {
if tx == nil {
tx = db
}
switch db.Dialect {
case SQLite:
err = db.QueryRow(columnExistsSQLite, table, column).Scan(&exists)
case Postgres:
err = db.QueryRow(columnExistsPostgres, table, column).Scan(&exists)
default:
err = ErrUnsupportedDialect
}
return
}
func (db *Database) tableExistsNoError(table string) bool {
exists, err := db.TableExists(nil, table)
if err != nil {
panic(fmt.Errorf("failed to check if table exists: %w", err))
}
return exists
}
const createOwnerTable = `
CREATE TABLE IF NOT EXISTS database_owner (
key INTEGER PRIMARY KEY DEFAULT 0,
owner TEXT NOT NULL
)
`
func (db *Database) checkDatabaseOwner() error {
var owner string
if !db.IgnoreForeignTables {
if db.tableExistsNoError("state_groups_state") {
return fmt.Errorf("%w (found state_groups_state, likely belonging to Synapse)", ErrForeignTables)
} else if db.tableExistsNoError("roomserver_rooms") {
return fmt.Errorf("%w (found roomserver_rooms, likely belonging to Dendrite)", ErrForeignTables)
}
}
if db.Owner == "" {
return nil
}
if _, err := db.Exec(createOwnerTable); err != nil {
return fmt.Errorf("failed to ensure database owner table exists: %w", err)
} else if err = db.QueryRow("SELECT owner FROM database_owner WHERE key=0").Scan(&owner); errors.Is(err, sql.ErrNoRows) {
_, err = db.Exec("INSERT INTO database_owner (key, owner) VALUES (0, $1)", db.Owner)
if err != nil {
return fmt.Errorf("failed to insert database owner: %w", err)
}
} else if err != nil {
return fmt.Errorf("failed to check database owner: %w", err)
} else if owner != db.Owner {
return fmt.Errorf("%w %s", ErrNotOwned, owner)
}
return nil
}
func (db *Database) setVersion(tx Execable, version, compat int) error {
_, err := tx.Exec(fmt.Sprintf("DELETE FROM %s", db.VersionTable))
if err != nil {
return err
}
_, err = tx.Exec(fmt.Sprintf("INSERT INTO %s (version, compat) VALUES ($1, $2)", db.VersionTable), version, compat)
return err
}
func (db *Database) Upgrade() error {
err := db.checkDatabaseOwner()
if err != nil {
return err
}
version, compat, err := db.getVersion()
if err != nil {
return err
}
if compat > len(db.UpgradeTable) {
if db.IgnoreUnsupportedDatabase {
db.Log.WarnUnsupportedVersion(version, compat, len(db.UpgradeTable))
return nil
}
return fmt.Errorf("%w: currently on v%d (compatible down to v%d), latest known: v%d", ErrUnsupportedDatabaseVersion, version, compat, len(db.UpgradeTable))
}
db.Log.PrepareUpgrade(version, compat, len(db.UpgradeTable))
logVersion := version
for version < len(db.UpgradeTable) {
upgradeItem := db.UpgradeTable[version]
if upgradeItem.fn == nil {
version++
continue
}
db.Log.DoUpgrade(logVersion, upgradeItem.upgradesTo, upgradeItem.message, upgradeItem.transaction)
var tx Transaction
var upgradeConn Execable
if upgradeItem.transaction {
tx, err = db.Begin()
if err != nil {
return err
}
upgradeConn = tx
} else {
upgradeConn = db
}
err = upgradeItem.fn(upgradeConn, db)
if err != nil {
return err
}
version = upgradeItem.upgradesTo
logVersion = version
err = db.setVersion(upgradeConn, version, upgradeItem.compatVersion)
if err != nil {
return err
}
if tx != nil {
err = tx.Commit()
if err != nil {
return err
}
}
}
return nil
}
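
For context, a hedged sketch of how the exported knobs referenced above (`Owner`, `VersionTable`, `UpgradeTable`) are typically wired together before calling `Upgrade`. How the `*dbutil.Database` itself is constructed is outside this diff, and the owner string, table name and `upgrades.Table` value (from the embed sketch earlier) are illustrative:

```go
func upgradeDatabase(db *dbutil.Database) error {
	db.Owner = "example-bridge"         // enables the database_owner check
	db.VersionTable = "example_version" // where version/compat are stored
	db.UpgradeTable = upgrades.Table    // filled by Register/RegisterFS
	// Upgrade refuses to touch foreign databases (ErrForeignTables), databases
	// owned by someone else (ErrNotOwned), and schemas newer than the
	// registered upgrades (ErrUnsupportedDatabaseVersion).
	return db.Upgrade()
}
```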

View File

@@ -1,283 +0,0 @@
// Copyright (c) 2023 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
package dbutil
import (
"bytes"
"errors"
"fmt"
"io/fs"
"path/filepath"
"regexp"
"strconv"
"strings"
)
type UpgradeTable []upgrade
func (ut *UpgradeTable) extend(toSize int) {
if cap(*ut) >= toSize {
*ut = (*ut)[:toSize]
} else {
resized := make([]upgrade, toSize)
copy(resized, *ut)
*ut = resized
}
}
func (ut *UpgradeTable) Register(from, to, compat int, message string, txn bool, fn upgradeFunc) {
if from < 0 {
from += to
}
if from < 0 {
panic("invalid from value in UpgradeTable.Register() call")
}
if compat <= 0 {
compat = to
}
upg := upgrade{message: message, fn: fn, upgradesTo: to, compatVersion: compat, transaction: txn}
if len(*ut) == from {
*ut = append(*ut, upg)
return
} else if len(*ut) < from {
ut.extend(from + 1)
} else if (*ut)[from].fn != nil {
panic(fmt.Errorf("tried to override upgrade at %d ('%s') with '%s'", from, (*ut)[from].message, upg.message))
}
(*ut)[from] = upg
}
// Syntax is either
//
// -- v0 -> v1: Message
//
// or
//
// -- v1: Message
//
// Both syntaxes may also have a compatibility notice before the colon:
//
// -- v5 (compatible with v3+): Upgrade with backwards compatibility
var upgradeHeaderRegex = regexp.MustCompile(`^-- (?:v(\d+) -> )?v(\d+)(?: \(compatible with v(\d+)\+\))?: (.+)$`)
// To disable wrapping the upgrade in a single transaction, put `-- transaction: off` on the second line.
//
// -- v5: Upgrade without transaction
// -- transaction: off
// // do dangerous stuff
var transactionDisableRegex = regexp.MustCompile(`^-- transaction: (\w*)`)
func parseFileHeader(file []byte) (from, to, compat int, message string, txn bool, lines [][]byte, err error) {
lines = bytes.Split(file, []byte("\n"))
if len(lines) < 2 {
err = errors.New("upgrade file too short")
return
}
var maybeFrom int
match := upgradeHeaderRegex.FindSubmatch(lines[0])
lines = lines[1:]
if match == nil {
err = errors.New("header not found")
} else if len(match) != 5 {
err = errors.New("unexpected number of items in regex match")
} else if maybeFrom, err = strconv.Atoi(string(match[1])); len(match[1]) > 0 && err != nil {
err = fmt.Errorf("invalid source version: %w", err)
} else if to, err = strconv.Atoi(string(match[2])); err != nil {
err = fmt.Errorf("invalid target version: %w", err)
} else if compat, err = strconv.Atoi(string(match[3])); len(match[3]) > 0 && err != nil {
err = fmt.Errorf("invalid compatible version: %w", err)
} else {
err = nil
if len(match[1]) > 0 {
from = maybeFrom
} else {
from = -1
}
message = string(match[4])
txn = true
match = transactionDisableRegex.FindSubmatch(lines[0])
if match != nil {
lines = lines[1:]
if string(match[1]) != "off" {
err = fmt.Errorf("invalid value %q for transaction flag", match[1])
}
txn = false
}
}
return
}
// To limit the next line to one dialect:
//
// -- only: postgres
//
// To limit the next N lines:
//
// -- only: sqlite for next 123 lines
//
// If the single-line limit is on the second line of the file, the whole file is limited to that dialect.
var dialectLineFilter = regexp.MustCompile(`^\s*-- only: (postgres|sqlite)(?: for next (\d+) lines| until "(end) only")?`)
// Constants used to make parseDialectFilter clearer
const (
skipUntilEndTag = -1
skipNothing = 0
skipCurrentLine = 1
skipNextLine = 2
)
func (db *Database) parseDialectFilter(line []byte) (int, error) {
match := dialectLineFilter.FindSubmatch(line)
if match == nil {
return skipNothing, nil
}
dialect, err := ParseDialect(string(match[1]))
if err != nil {
return skipNothing, err
} else if dialect == db.Dialect {
// Skip the dialect filter line
return skipCurrentLine, nil
} else if bytes.Equal(match[3], []byte("end")) {
return skipUntilEndTag, nil
} else if len(match[2]) == 0 {
// Skip the dialect filter and the next line
return skipNextLine, nil
} else {
// Parse number of lines to skip, add 1 for current line
lineCount, err := strconv.Atoi(string(match[2]))
if err != nil {
return skipNothing, fmt.Errorf("invalid line count '%s': %w", match[2], err)
}
return skipCurrentLine + lineCount, nil
}
}
var endLineFilter = regexp.MustCompile(`^\s*-- end only (postgres|sqlite)$`)
func (db *Database) filterSQLUpgrade(lines [][]byte) (string, error) {
output := make([][]byte, 0, len(lines))
for i := 0; i < len(lines); i++ {
skipLines, err := db.parseDialectFilter(lines[i])
if err != nil {
return "", err
} else if skipLines > 0 {
// The current line is skipped implicitly by the loop, so subtract one here
i += skipLines - 1
} else if skipLines == skipUntilEndTag {
startedAt := i
startedAtMatch := dialectLineFilter.FindSubmatch(lines[startedAt])
for ; i < len(lines); i++ {
if match := endLineFilter.FindSubmatch(lines[i]); match != nil {
if !bytes.Equal(match[1], startedAtMatch[1]) {
return "", fmt.Errorf(`unexpected end tag %q for %q start at line %d`, string(match[0]), string(startedAtMatch[1]), startedAt)
}
break
}
}
if i == len(lines) {
return "", fmt.Errorf(`didn't get end tag matching start %q at line %d`, string(startedAtMatch[1]), startedAt)
}
} else {
output = append(output, lines[i])
}
}
return string(bytes.Join(output, []byte("\n"))), nil
}
func sqlUpgradeFunc(fileName string, lines [][]byte) upgradeFunc {
return func(tx Execable, db *Database) error {
if skip, err := db.parseDialectFilter(lines[0]); err == nil && skip == skipNextLine {
return nil
} else if upgradeSQL, err := db.filterSQLUpgrade(lines); err != nil {
panic(fmt.Errorf("failed to parse upgrade %s: %w", fileName, err))
} else {
_, err = tx.Exec(upgradeSQL)
return err
}
}
}
func splitSQLUpgradeFunc(sqliteData, postgresData string) upgradeFunc {
return func(tx Execable, database *Database) (err error) {
switch database.Dialect {
case SQLite:
_, err = tx.Exec(sqliteData)
case Postgres:
_, err = tx.Exec(postgresData)
default:
err = fmt.Errorf("unknown dialect %s", database.Dialect)
}
return
}
}
func parseSplitSQLUpgrade(name string, fs fullFS, skipNames map[string]struct{}) (from, to, compat int, message string, txn bool, fn upgradeFunc) {
postgresName := fmt.Sprintf("%s.postgres.sql", name)
sqliteName := fmt.Sprintf("%s.sqlite.sql", name)
skipNames[postgresName] = struct{}{}
skipNames[sqliteName] = struct{}{}
postgresData, err := fs.ReadFile(postgresName)
if err != nil {
panic(err)
}
sqliteData, err := fs.ReadFile(sqliteName)
if err != nil {
panic(err)
}
from, to, compat, message, txn, _, err = parseFileHeader(postgresData)
if err != nil {
panic(fmt.Errorf("failed to parse header in %s: %w", postgresName, err))
}
sqliteFrom, sqliteTo, sqliteCompat, sqliteMessage, sqliteTxn, _, err := parseFileHeader(sqliteData)
if err != nil {
panic(fmt.Errorf("failed to parse header in %s: %w", sqliteName, err))
}
if from != sqliteFrom || to != sqliteTo || compat != sqliteCompat {
panic(fmt.Errorf("mismatching versions in postgres and sqlite versions of %s: %d/%d -> %d/%d", name, from, sqliteFrom, to, sqliteTo))
} else if message != sqliteMessage {
panic(fmt.Errorf("mismatching message in postgres and sqlite versions of %s: %q != %q", name, message, sqliteMessage))
} else if txn != sqliteTxn {
panic(fmt.Errorf("mismatching transaction flag in postgres and sqlite versions of %s: %t != %t", name, txn, sqliteTxn))
}
fn = splitSQLUpgradeFunc(string(sqliteData), string(postgresData))
return
}
type fullFS interface {
fs.ReadFileFS
fs.ReadDirFS
}
var splitFileNameRegex = regexp.MustCompile(`^(.+)\.(postgres|sqlite)\.sql$`)
func (ut *UpgradeTable) RegisterFS(fs fullFS) {
ut.RegisterFSPath(fs, ".")
}
func (ut *UpgradeTable) RegisterFSPath(fs fullFS, dir string) {
files, err := fs.ReadDir(dir)
if err != nil {
panic(err)
}
skipNames := map[string]struct{}{}
for _, file := range files {
if file.IsDir() || !strings.HasSuffix(file.Name(), ".sql") {
// do nothing
} else if _, skip := skipNames[file.Name()]; skip {
// also do nothing
} else if splitName := splitFileNameRegex.FindStringSubmatch(file.Name()); splitName != nil {
from, to, compat, message, txn, fn := parseSplitSQLUpgrade(splitName[1], fs, skipNames)
ut.Register(from, to, compat, message, txn, fn)
} else if data, err := fs.ReadFile(filepath.Join(dir, file.Name())); err != nil {
panic(err)
} else if from, to, compat, message, txn, lines, err := parseFileHeader(data); err != nil {
panic(fmt.Errorf("failed to parse header in %s: %w", file.Name(), err))
} else {
ut.Register(from, to, compat, message, txn, sqlUpgradeFunc(file.Name(), lines))
}
}
}
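
Besides SQL files, upgrades can also be registered directly as Go functions through `Register`. A small hedged sketch; the version numbers, message, table names and the `upgrades.Table` variable are illustrative:

```go
func init() {
	// from = -1 resolves to to-1, so this upgrade runs from v5 to v6.
	// compat <= 0 defaults to the target version, and txn = true wraps the
	// function in a single transaction.
	upgrades.Table.Register(-1, 6, 0, "Backfill foo from the legacy table", true,
		func(tx dbutil.Execable, db *dbutil.Database) error {
			_, err := tx.Exec("INSERT INTO foo (data) SELECT data FROM legacy_foo")
			return err
		})
}
```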

View File

@@ -1,33 +0,0 @@
// Copyright (c) 2022 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
package util
import (
"errors"
"fmt"
)
type DualError struct {
High error
Low error
}
func NewDualError(high, low error) DualError {
return DualError{high, low}
}
func (err DualError) Is(other error) bool {
return errors.Is(other, err.High) || errors.Is(other, err.Low)
}
func (err DualError) Unwrap() error {
return err.Low
}
func (err DualError) Error() string {
return fmt.Sprintf("%v: %v", err.High, err.Low)
}

View File

@@ -1,54 +0,0 @@
// Copyright (c) 2023 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
package util
import (
"errors"
"fmt"
"strings"
"time"
)
var Day = 24 * time.Hour
var Week = 7 * Day
func pluralize(value int, unit string) string {
if value == 1 {
return "1 " + unit
}
return fmt.Sprintf("%d %ss", value, unit)
}
func appendDurationPart(time, unit time.Duration, name string, parts *[]string) (remainder time.Duration) {
if time < unit {
return time
}
value := int(time / unit)
remainder = time % unit
*parts = append(*parts, pluralize(value, name))
return
}
func FormatDuration(d time.Duration) string {
if d < 0 {
panic(errors.New("FormatDuration: negative duration"))
} else if d < time.Second {
return "now"
}
parts := make([]string, 0, 2)
d = appendDurationPart(d, Week, "week", &parts)
d = appendDurationPart(d, Day, "day", &parts)
d = appendDurationPart(d, time.Hour, "hour", &parts)
d = appendDurationPart(d, time.Minute, "minute", &parts)
d = appendDurationPart(d, time.Second, "second", &parts)
if len(parts) > 2 {
parts[0] = strings.Join(parts[:len(parts)-1], ", ")
parts[1] = parts[len(parts)-1]
parts = parts[:2]
}
return strings.Join(parts, " and ")
}
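
For reference, a few example outputs of `FormatDuration` as implemented above (a sketch, outputs derived from the code rather than documented behaviour):

```go
fmt.Println(util.FormatDuration(90 * time.Second))         // "1 minute and 30 seconds"
fmt.Println(util.FormatDuration(8*util.Day + 3*time.Hour)) // "1 week, 1 day and 3 hours"
fmt.Println(util.FormatDuration(500 * time.Millisecond))   // "now" (sub-second durations)
```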

View File

@@ -1,31 +0,0 @@
// Copyright (c) 2022 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
package util
import (
"strings"
)
var GJSONEscaper = strings.NewReplacer(
`\`, `\\`,
".", `\.`,
"|", `\|`,
"#", `\#`,
"@", `\@`,
"*", `\*`,
"?", `\?`)
func GJSONPath(path ...string) string {
var result strings.Builder
for i, part := range path {
_, _ = GJSONEscaper.WriteString(&result, part)
if i < len(path)-1 {
result.WriteRune('.')
}
}
return result.String()
}
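
A quick hedged example of why the escaper matters: Matrix keys like `m.relates_to` contain literal dots, which gjson would otherwise treat as path separators. The `rawEvent` bytes below are an assumed JSON-encoded event:

```go
// Builds the path `content.m\.relates_to.rel_type`, so gjson reads
// "m.relates_to" as one key instead of two nested ones.
path := util.GJSONPath("content", "m.relates_to", "rel_type")
relType := gjson.GetBytes(rawEvent, path).Str
```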

View File

@@ -1,86 +0,0 @@
// Copyright (c) 2022 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
package jsontime
import (
"encoding/json"
"time"
)
func UM(time time.Time) UnixMilli {
return UnixMilli{Time: time}
}
func UMInt(ts int64) UnixMilli {
return UM(time.UnixMilli(ts))
}
func UnixMilliNow() UnixMilli {
return UM(time.Now())
}
type UnixMilli struct {
time.Time
}
func (um UnixMilli) MarshalJSON() ([]byte, error) {
if um.IsZero() {
return []byte{'0'}, nil
}
return json.Marshal(um.UnixMilli())
}
func (um *UnixMilli) UnmarshalJSON(data []byte) error {
var val int64
err := json.Unmarshal(data, &val)
if err != nil {
return err
}
if val == 0 {
um.Time = time.Time{}
} else {
um.Time = time.UnixMilli(val)
}
return nil
}
func U(time time.Time) Unix {
return Unix{Time: time}
}
func UInt(ts int64) Unix {
return U(time.Unix(ts, 0))
}
func UnixNow() Unix {
return U(time.Now())
}
type Unix struct {
time.Time
}
func (u Unix) MarshalJSON() ([]byte, error) {
if u.IsZero() {
return []byte{'0'}, nil
}
return json.Marshal(u.Unix())
}
func (u *Unix) UnmarshalJSON(data []byte) error {
var val int64
err := json.Unmarshal(data, &val)
if err != nil {
return err
}
if val == 0 {
u.Time = time.Time{}
} else {
u.Time = time.Unix(val, 0)
}
return nil
}
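
A minimal sketch of the wrappers in use (the struct is made up): zero times marshal as `0`, and `0` unmarshals back to the zero time.

```go
type ReadReceipt struct {
	Timestamp jsontime.UnixMilli `json:"ts"`
}

data, _ := json.Marshal(ReadReceipt{Timestamp: jsontime.UnixMilliNow()})
// e.g. {"ts":1694952937000}

var parsed ReadReceipt
_ = json.Unmarshal([]byte(`{"ts":0}`), &parsed)
// parsed.Timestamp.IsZero() == true
```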

View File

@@ -1,30 +0,0 @@
package util
import (
"encoding/json"
"fmt"
"github.com/tidwall/gjson"
"github.com/tidwall/sjson"
)
// MarshalAndDeleteEmpty marshals a JSON object, then uses gjson to delete empty objects at the given gjson paths.
//
// This can be used as a convenient way to create a marshaler that omits empty non-pointer structs.
// See mautrix.RespSync for example.
func MarshalAndDeleteEmpty(marshalable interface{}, paths []string) ([]byte, error) {
data, err := json.Marshal(marshalable)
if err != nil {
return nil, err
}
for _, path := range paths {
res := gjson.GetBytes(data, path)
if res.IsObject() && len(res.Raw) == 2 {
data, err = sjson.DeleteBytes(data, path)
if err != nil {
return nil, fmt.Errorf("failed to delete empty %s: %w", path, err)
}
}
}
return data, nil
}
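
A hedged usage sketch (the struct is illustrative): non-pointer nested structs always marshal to `{}`, and the listed gjson paths are pruned afterwards.

```go
type Resp struct {
	AccountData struct {
		Events []json.RawMessage `json:"events,omitempty"`
	} `json:"account_data"`
	NextBatch string `json:"next_batch"`
}

data, _ := util.MarshalAndDeleteEmpty(&Resp{NextBatch: "s123"}, []string{"account_data"})
// {"next_batch":"s123"} instead of {"account_data":{},"next_batch":"s123"}
```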

View File

@@ -1,50 +0,0 @@
// Copyright (c) 2022 Sumner Evans
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
package util
import (
"mime"
"strings"
)
// MimeExtensionSanityOverrides includes extensions for various common mimetypes.
//
// This is necessary because sometimes the OS mimetype database and Go interact in weird ways,
// which causes very obscure extensions to be first in the array for common mimetypes
// (e.g. image/jpeg -> .jpe, text/plain -> ,v).
var MimeExtensionSanityOverrides = map[string]string{
"image/png": ".png",
"image/webp": ".webp",
"image/jpeg": ".jpg",
"image/tiff": ".tiff",
"image/heif": ".heic",
"image/heic": ".heic",
"audio/mpeg": ".mp3",
"audio/ogg": ".ogg",
"audio/webm": ".webm",
"audio/x-caf": ".caf",
"video/mp4": ".mp4",
"video/mpeg": ".mpeg",
"video/webm": ".webm",
"text/plain": ".txt",
"text/html": ".html",
"application/xml": ".xml",
}
func ExtensionFromMimetype(mimetype string) string {
ext, ok := MimeExtensionSanityOverrides[strings.Split(mimetype, ";")[0]]
if !ok {
exts, _ := mime.ExtensionsByType(mimetype)
if len(exts) > 0 {
ext = exts[0]
}
}
return ext
}
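
Example results as a sketch; the last line assumes a typical OS mimetype database for types outside the override table:

```go
util.ExtensionFromMimetype("image/jpeg")                // ".jpg" (override, not ".jpe")
util.ExtensionFromMimetype("text/plain; charset=utf-8") // ".txt" (parameters after ";" are stripped)
util.ExtensionFromMimetype("application/pdf")           // usually ".pdf" via mime.ExtensionsByType
```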

View File

@@ -1,65 +0,0 @@
// Copyright (c) 2022 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
package util
import (
"crypto/rand"
"encoding/base64"
"hash/crc32"
"strings"
"unsafe"
)
func RandomBytes(n int) []byte {
data := make([]byte, n)
_, err := rand.Read(data)
if err != nil {
panic(err)
}
return data
}
var letters = "0123456789ABCDEFGHIJKLMNOPQRSTUVWXYZabcdefghijklmnopqrstuvwxyz"
// RandomString generates a random string of the given length.
func RandomString(n int) string {
if n <= 0 {
return ""
}
base64Len := n
if n%4 != 0 {
base64Len += 4 - (n % 4)
}
decodedLength := base64.RawStdEncoding.DecodedLen(base64Len)
output := make([]byte, base64Len)
base64.RawStdEncoding.Encode(output, RandomBytes(decodedLength))
for i, char := range output {
if char == '+' || char == '/' {
_, err := rand.Read(output[i : i+1])
if err != nil {
panic(err)
}
output[i] = letters[int(output[i])%len(letters)]
}
}
return (*(*string)(unsafe.Pointer(&output)))[:n]
}
func base62Encode(val uint32, minWidth int) string {
var buf strings.Builder
for val > 0 {
buf.WriteByte(letters[val%62])
val /= 62
}
return strings.Repeat("0", minWidth-buf.Len()) + buf.String()
}
func RandomToken(namespace string, randomLength int) string {
token := namespace + "_" + RandomString(randomLength)
checksum := base62Encode(crc32.ChecksumIEEE([]byte(token)), 6)
return token + "_" + checksum
}
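
The resulting token format is `<namespace>_<random>_<checksum>`, where the checksum is the base62-encoded CRC32 of everything before it. A hedged example with a made-up namespace:

```go
token := util.RandomToken("mau", 32)
// token looks like "mau_<32 random alphanumeric chars>_<6-char base62 CRC32 checksum>"
```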

View File

@@ -1,23 +0,0 @@
// Copyright (c) 2023 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
package util
import "sync"
// ReturnableOnce is a wrapper for sync.Once that can return a value
type ReturnableOnce[Value any] struct {
once sync.Once
output Value
err error
}
func (ronce *ReturnableOnce[Value]) Do(fn func() (Value, error)) (Value, error) {
ronce.once.Do(func() {
ronce.output, ronce.err = fn()
})
return ronce.output, ronce.err
}
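
A short usage sketch; `Config` and `loadConfigFromDisk` are hypothetical. The function runs at most once, and every caller gets the same value and error:

```go
var configOnce util.ReturnableOnce[*Config]

func getConfig() (*Config, error) {
	return configOnce.Do(func() (*Config, error) {
		return loadConfigFromDisk() // hypothetical loader, only ever called once
	})
}
```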

View File

@@ -1,139 +0,0 @@
// Copyright (c) 2023 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
package util
import (
"errors"
"sync"
)
type pair[Key comparable, Value any] struct {
Set bool
Key Key
Value Value
}
type RingBuffer[Key comparable, Value any] struct {
ptr int
data []pair[Key, Value]
lock sync.RWMutex
size int
}
func NewRingBuffer[Key comparable, Value any](size int) *RingBuffer[Key, Value] {
return &RingBuffer[Key, Value]{
data: make([]pair[Key, Value], size),
}
}
var (
// StopIteration can be returned by the RingBuffer.Iter or MapRingBuffer callbacks to stop iteration immediately.
StopIteration = errors.New("stop iteration")
// SkipItem can be returned by the MapRingBuffer callback to skip adding a specific item.
SkipItem = errors.New("skip item")
)
func (rb *RingBuffer[Key, Value]) unlockedIter(callback func(key Key, val Value) error) error {
end := rb.ptr
for i := clamp(end-1, len(rb.data)); i != end; i = clamp(i-1, len(rb.data)) {
entry := rb.data[i]
if !entry.Set {
break
}
err := callback(entry.Key, entry.Value)
if err != nil {
if errors.Is(err, StopIteration) {
return nil
}
return err
}
}
return nil
}
func (rb *RingBuffer[Key, Value]) Iter(callback func(key Key, val Value) error) error {
rb.lock.RLock()
defer rb.lock.RUnlock()
return rb.unlockedIter(callback)
}
func MapRingBuffer[Key comparable, Value, Output any](rb *RingBuffer[Key, Value], callback func(key Key, val Value) (Output, error)) ([]Output, error) {
rb.lock.RLock()
defer rb.lock.RUnlock()
output := make([]Output, 0, rb.size)
err := rb.unlockedIter(func(key Key, val Value) error {
item, err := callback(key, val)
if err != nil {
if errors.Is(err, SkipItem) {
return nil
}
return err
}
output = append(output, item)
return nil
})
return output, err
}
func (rb *RingBuffer[Key, Value]) Size() int {
rb.lock.RLock()
defer rb.lock.RUnlock()
return rb.size
}
func (rb *RingBuffer[Key, Value]) Contains(val Key) bool {
_, ok := rb.Get(val)
return ok
}
func (rb *RingBuffer[Key, Value]) Get(key Key) (val Value, found bool) {
rb.lock.RLock()
end := rb.ptr
for i := clamp(end-1, len(rb.data)); i != end; i = clamp(i-1, len(rb.data)) {
if rb.data[i].Key == key {
val = rb.data[i].Value
found = true
break
}
}
rb.lock.RUnlock()
return
}
func (rb *RingBuffer[Key, Value]) Replace(key Key, val Value) bool {
rb.lock.Lock()
defer rb.lock.Unlock()
end := rb.ptr
for i := clamp(end-1, len(rb.data)); i != end; i = clamp(i-1, len(rb.data)) {
if rb.data[i].Key == key {
rb.data[i].Value = val
return true
}
}
return false
}
func (rb *RingBuffer[Key, Value]) Push(key Key, val Value) {
rb.lock.Lock()
rb.data[rb.ptr] = pair[Key, Value]{Key: key, Value: val, Set: true}
rb.ptr = (rb.ptr + 1) % len(rb.data)
if rb.size < len(rb.data) {
rb.size++
}
rb.lock.Unlock()
}
func clamp(index, len int) int {
if index < 0 {
return len + index
} else if index >= len {
return len - index
} else {
return index
}
}
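
A small usage sketch with made-up keys and values: the buffer keeps the last `size` entries keyed for lookup, overwriting the oldest once full, and iterates newest first.

```go
recent := util.NewRingBuffer[string, int](3)
recent.Push("a", 1)
recent.Push("b", 2)

fmt.Println(recent.Contains("a")) // true
val, ok := recent.Get("b")        // 2, true
fmt.Println(val, ok)

_ = recent.Iter(func(key string, val int) error {
	fmt.Println(key, val) // iterates newest first: "b 2", then "a 1"
	return nil
})
```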

View File

@@ -1,94 +0,0 @@
// Copyright (c) 2023 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
package util
import "sync"
// SyncMap is a simple map with a built-in mutex.
type SyncMap[Key comparable, Value any] struct {
data map[Key]Value
lock sync.RWMutex
}
func NewSyncMap[Key comparable, Value any]() *SyncMap[Key, Value] {
return &SyncMap[Key, Value]{
data: make(map[Key]Value),
}
}
// Set stores a value in the map.
func (sm *SyncMap[Key, Value]) Set(key Key, value Value) {
sm.Swap(key, value)
}
// Swap sets a value in the map and returns the old value.
//
// The boolean return parameter is true if the value already existed, false if not.
func (sm *SyncMap[Key, Value]) Swap(key Key, value Value) (oldValue Value, wasReplaced bool) {
sm.lock.Lock()
oldValue, wasReplaced = sm.data[key]
sm.data[key] = value
sm.lock.Unlock()
return
}
// Delete removes a key from the map.
func (sm *SyncMap[Key, Value]) Delete(key Key) {
sm.Pop(key)
}
// Pop removes a key from the map and returns the old value.
//
// The boolean return parameter is the same as with normal Go map access (true if the key exists, false if not).
func (sm *SyncMap[Key, Value]) Pop(key Key) (value Value, ok bool) {
sm.lock.Lock()
value, ok = sm.data[key]
delete(sm.data, key)
sm.lock.Unlock()
return
}
// Get gets a value in the map.
//
// The boolean return parameter is the same as with normal Go map access (true if the key exists, false if not).
func (sm *SyncMap[Key, Value]) Get(key Key) (value Value, ok bool) {
sm.lock.RLock()
value, ok = sm.data[key]
sm.lock.RUnlock()
return
}
// GetOrSet gets a value in the map if the key already exists, otherwise inserts the given value and returns it.
//
// The boolean return parameter is true if the key already exists, and false if the given value was inserted.
func (sm *SyncMap[Key, Value]) GetOrSet(key Key, value Value) (actual Value, wasGet bool) {
sm.lock.Lock()
defer sm.lock.Unlock()
actual, wasGet = sm.data[key]
if wasGet {
return
}
sm.data[key] = value
actual = value
return
}
// Clone returns a copy of the map.
func (sm *SyncMap[Key, Value]) Clone() *SyncMap[Key, Value] {
return &SyncMap[Key, Value]{data: sm.CopyData()}
}
// CopyData returns a copy of the data in the map as a normal (non-atomic) map.
func (sm *SyncMap[Key, Value]) CopyData() map[Key]Value {
sm.lock.RLock()
copied := make(map[Key]Value, len(sm.data))
for key, value := range sm.data {
copied[key] = value
}
sm.lock.RUnlock()
return copied
}
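
A brief usage sketch with illustrative keys and values:

```go
seen := util.NewSyncMap[string, int]()
seen.Set("a", 1)

old, replaced := seen.Swap("a", 2)
fmt.Println(old, replaced) // 1 true

val, ok := seen.Get("a")
fmt.Println(val, ok) // 2 true

actual, existed := seen.GetOrSet("b", 10)
fmt.Println(actual, existed) // 10 false (the value was inserted)
```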

View File

@@ -7,7 +7,7 @@ import (
"strings"
)
const Version = "v0.15.3"
const Version = "v0.16.1"
var GoModVersion = ""
var Commit = ""

View File

@@ -48,6 +48,26 @@ func (versions *RespVersions) GetLatest() (latest SpecVersion) {
return
}
type UnstableFeature struct {
UnstableFlag string
SpecVersion SpecVersion
}
var (
FeatureAppservicePing = UnstableFeature{UnstableFlag: "fi.mau.msc2659.stable", SpecVersion: SpecV17}
BeeperFeatureHungry = UnstableFeature{UnstableFlag: "com.beeper.hungry"}
BeeperFeatureBatchSending = UnstableFeature{UnstableFlag: "com.beeper.batch_sending"}
BeeperFeatureRoomYeeting = UnstableFeature{UnstableFlag: "com.beeper.room_yeeting"}
BeeperFeatureAutojoinInvites = UnstableFeature{UnstableFlag: "com.beeper.room_create_autojoin_invites"}
BeeperFeatureArbitraryProfileMeta = UnstableFeature{UnstableFlag: "com.beeper.arbitrary_profile_meta"}
)
func (versions *RespVersions) Supports(feature UnstableFeature) bool {
return versions.UnstableFeatures[feature.UnstableFlag] ||
(!feature.SpecVersion.IsEmpty() && versions.ContainsGreaterOrEqual(feature.SpecVersion))
}
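
The new `Supports` helper matches a feature either on its unstable flag or, when the feature has a stable spec version, on the advertised spec versions. A hedged sketch; `versions` is assumed to be a `*mautrix.RespVersions` fetched from `/_matrix/client/versions`:

```go
if versions.Supports(mautrix.FeatureAppservicePing) {
	// Either the fi.mau.msc2659.stable unstable flag is advertised,
	// or the server already reports spec v1.7.
}
if versions.Supports(mautrix.BeeperFeatureBatchSending) {
	// Matched purely on the com.beeper.batch_sending unstable flag,
	// since this feature has no stable spec version.
}
```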
type SpecVersionFormat int
const (
@@ -143,6 +163,10 @@ func (sv SpecVersion) String() string {
}
}
func (sv SpecVersion) IsEmpty() bool {
return sv.Format == SpecVersionFormatUnknown && sv.Raw == ""
}
func (sv SpecVersion) LessThan(other SpecVersion) bool {
return sv != other && !sv.GreaterThan(other)
}