refactor to mautrix 0.17.x; update deps

This commit is contained in:
Aine
2024-02-11 20:47:04 +02:00
parent 0a9701f4c9
commit dd0ad4c245
237 changed files with 9091 additions and 3317 deletions

View File

@@ -22,19 +22,19 @@ func parseMXIDpatterns(patterns []string, defaultPattern string) ([]*regexp.Rege
return mxidwc.ParsePatterns(patterns) return mxidwc.ParsePatterns(patterns)
} }
func (b *Bot) allowUsers(actorID id.UserID, targetRoomID id.RoomID) bool { func (b *Bot) allowUsers(ctx context.Context, actorID id.UserID, targetRoomID id.RoomID) bool {
// first, check if it's an allowed user // first, check if it's an allowed user
if mxidwc.Match(actorID.String(), b.allowedUsers) { if mxidwc.Match(actorID.String(), b.allowedUsers) {
return true return true
} }
// second, check if it's an admin (admin may not fit the allowed users pattern) // second, check if it's an admin (admin may not fit the allowed users pattern)
if b.allowAdmin(actorID, targetRoomID) { if b.allowAdmin(ctx, actorID, targetRoomID) {
return true return true
} }
// then, check if it's the owner (same as above) // then, check if it's the owner (same as above)
cfg, err := b.cfg.GetRoom(targetRoomID) cfg, err := b.cfg.GetRoom(ctx, targetRoomID)
if err == nil && cfg.Owner() == actorID.String() { if err == nil && cfg.Owner() == actorID.String() {
return true return true
} }
@@ -42,15 +42,15 @@ func (b *Bot) allowUsers(actorID id.UserID, targetRoomID id.RoomID) bool {
return false return false
} }
func (b *Bot) allowAnyone(_ id.UserID, _ id.RoomID) bool { func (b *Bot) allowAnyone(_ context.Context, _ id.UserID, _ id.RoomID) bool {
return true return true
} }
func (b *Bot) allowOwner(actorID id.UserID, targetRoomID id.RoomID) bool { func (b *Bot) allowOwner(ctx context.Context, actorID id.UserID, targetRoomID id.RoomID) bool {
if !b.allowUsers(actorID, targetRoomID) { if !b.allowUsers(ctx, actorID, targetRoomID) {
return false return false
} }
cfg, err := b.cfg.GetRoom(targetRoomID) cfg, err := b.cfg.GetRoom(ctx, targetRoomID)
if err != nil { if err != nil {
b.Error(context.Background(), "failed to retrieve settings: %v", err) b.Error(context.Background(), "failed to retrieve settings: %v", err)
return false return false
@@ -61,19 +61,19 @@ func (b *Bot) allowOwner(actorID id.UserID, targetRoomID id.RoomID) bool {
return true return true
} }
return owner == actorID.String() || b.allowAdmin(actorID, targetRoomID) return owner == actorID.String() || b.allowAdmin(ctx, actorID, targetRoomID)
} }
func (b *Bot) allowAdmin(actorID id.UserID, _ id.RoomID) bool { func (b *Bot) allowAdmin(_ context.Context, actorID id.UserID, _ id.RoomID) bool {
return mxidwc.Match(actorID.String(), b.allowedAdmins) return mxidwc.Match(actorID.String(), b.allowedAdmins)
} }
func (b *Bot) allowSend(actorID id.UserID, targetRoomID id.RoomID) bool { func (b *Bot) allowSend(ctx context.Context, actorID id.UserID, targetRoomID id.RoomID) bool {
if !b.allowUsers(actorID, targetRoomID) { if !b.allowUsers(ctx, actorID, targetRoomID) {
return false return false
} }
cfg, err := b.cfg.GetRoom(targetRoomID) cfg, err := b.cfg.GetRoom(ctx, targetRoomID)
if err != nil { if err != nil {
b.Error(context.Background(), "failed to retrieve settings: %v", err) b.Error(context.Background(), "failed to retrieve settings: %v", err)
return false return false
@@ -82,14 +82,14 @@ func (b *Bot) allowSend(actorID id.UserID, targetRoomID id.RoomID) bool {
return !cfg.NoSend() return !cfg.NoSend()
} }
func (b *Bot) allowReply(actorID id.UserID, targetRoomID id.RoomID) bool { func (b *Bot) allowReply(ctx context.Context, actorID id.UserID, targetRoomID id.RoomID) bool {
if !b.allowUsers(actorID, targetRoomID) { if !b.allowUsers(ctx, actorID, targetRoomID) {
return false return false
} }
cfg, err := b.cfg.GetRoom(targetRoomID) cfg, err := b.cfg.GetRoom(ctx, targetRoomID)
if err != nil { if err != nil {
b.Error(context.Background(), "failed to retrieve settings: %v", err) b.Error(ctx, "failed to retrieve settings: %v", err)
return false return false
} }
@@ -106,30 +106,30 @@ func (b *Bot) isReserved(mailbox string) bool {
} }
// IsGreylisted checks if host is in greylist // IsGreylisted checks if host is in greylist
func (b *Bot) IsGreylisted(addr net.Addr) bool { func (b *Bot) IsGreylisted(ctx context.Context, addr net.Addr) bool {
if b.cfg.GetBot().Greylist() == 0 { if b.cfg.GetBot(ctx).Greylist() == 0 {
return false return false
} }
greylist := b.cfg.GetGreylist() greylist := b.cfg.GetGreylist(ctx)
greylistedAt, ok := greylist.Get(addr) greylistedAt, ok := greylist.Get(addr)
if !ok { if !ok {
b.log.Debug().Str("addr", addr.String()).Msg("greylisting") b.log.Debug().Str("addr", addr.String()).Msg("greylisting")
greylist.Add(addr) greylist.Add(addr)
err := b.cfg.SetGreylist(greylist) err := b.cfg.SetGreylist(ctx, greylist)
if err != nil { if err != nil {
b.log.Error().Err(err).Str("addr", addr.String()).Msg("cannot update greylist") b.log.Error().Err(err).Str("addr", addr.String()).Msg("cannot update greylist")
} }
return true return true
} }
duration := time.Duration(b.cfg.GetBot().Greylist()) * time.Minute duration := time.Duration(b.cfg.GetBot(ctx).Greylist()) * time.Minute
return greylistedAt.Add(duration).After(time.Now().UTC()) return greylistedAt.Add(duration).After(time.Now().UTC())
} }
// IsBanned checks if address is banned // IsBanned checks if address is banned
func (b *Bot) IsBanned(addr net.Addr) bool { func (b *Bot) IsBanned(ctx context.Context, addr net.Addr) bool {
return b.cfg.GetBanlist().Has(addr) return b.cfg.GetBanlist(ctx).Has(addr)
} }
// IsTrusted checks if address is a trusted (proxy) // IsTrusted checks if address is a trusted (proxy)
@@ -146,12 +146,12 @@ func (b *Bot) IsTrusted(addr net.Addr) bool {
} }
// Ban an address automatically // Ban an address automatically
func (b *Bot) BanAuto(addr net.Addr) { func (b *Bot) BanAuto(ctx context.Context, addr net.Addr) {
if !b.cfg.GetBot().BanlistEnabled() { if !b.cfg.GetBot(ctx).BanlistEnabled() {
return return
} }
if !b.cfg.GetBot().BanlistAuto() { if !b.cfg.GetBot(ctx).BanlistAuto() {
return return
} }
@@ -159,21 +159,21 @@ func (b *Bot) BanAuto(addr net.Addr) {
return return
} }
b.log.Debug().Str("addr", addr.String()).Msg("attempting to automatically ban") b.log.Debug().Str("addr", addr.String()).Msg("attempting to automatically ban")
banlist := b.cfg.GetBanlist() banlist := b.cfg.GetBanlist(ctx)
banlist.Add(addr) banlist.Add(addr)
err := b.cfg.SetBanlist(banlist) err := b.cfg.SetBanlist(ctx, banlist)
if err != nil { if err != nil {
b.log.Error().Err(err).Str("addr", addr.String()).Msg("cannot update banlist") b.log.Error().Err(err).Str("addr", addr.String()).Msg("cannot update banlist")
} }
} }
// Ban an address for incorrect auth automatically // Ban an address for incorrect auth automatically
func (b *Bot) BanAuth(addr net.Addr) { func (b *Bot) BanAuth(ctx context.Context, addr net.Addr) {
if !b.cfg.GetBot().BanlistEnabled() { if !b.cfg.GetBot(ctx).BanlistEnabled() {
return return
} }
if !b.cfg.GetBot().BanlistAuth() { if !b.cfg.GetBot(ctx).BanlistAuth() {
return return
} }
@@ -181,33 +181,33 @@ func (b *Bot) BanAuth(addr net.Addr) {
return return
} }
b.log.Debug().Str("addr", addr.String()).Msg("attempting to automatically ban") b.log.Debug().Str("addr", addr.String()).Msg("attempting to automatically ban")
banlist := b.cfg.GetBanlist() banlist := b.cfg.GetBanlist(ctx)
banlist.Add(addr) banlist.Add(addr)
err := b.cfg.SetBanlist(banlist) err := b.cfg.SetBanlist(ctx, banlist)
if err != nil { if err != nil {
b.log.Error().Err(err).Str("addr", addr.String()).Msg("cannot update banlist") b.log.Error().Err(err).Str("addr", addr.String()).Msg("cannot update banlist")
} }
} }
// Ban an address manually // Ban an address manually
func (b *Bot) BanManually(addr net.Addr) { func (b *Bot) BanManually(ctx context.Context, addr net.Addr) {
if !b.cfg.GetBot().BanlistEnabled() { if !b.cfg.GetBot(ctx).BanlistEnabled() {
return return
} }
if b.IsTrusted(addr) { if b.IsTrusted(addr) {
return return
} }
b.log.Debug().Str("addr", addr.String()).Msg("attempting to manually ban") b.log.Debug().Str("addr", addr.String()).Msg("attempting to manually ban")
banlist := b.cfg.GetBanlist() banlist := b.cfg.GetBanlist(ctx)
banlist.Add(addr) banlist.Add(addr)
err := b.cfg.SetBanlist(banlist) err := b.cfg.SetBanlist(ctx, banlist)
if err != nil { if err != nil {
b.log.Error().Err(err).Str("addr", addr.String()).Msg("cannot update banlist") b.log.Error().Err(err).Str("addr", addr.String()).Msg("cannot update banlist")
} }
} }
// AllowAuth check if SMTP login (email) and password are valid // AllowAuth check if SMTP login (email) and password are valid
func (b *Bot) AllowAuth(email, password string) (id.RoomID, bool) { func (b *Bot) AllowAuth(ctx context.Context, email, password string) (id.RoomID, bool) {
var suffix bool var suffix bool
for _, domain := range b.domains { for _, domain := range b.domains {
if strings.HasSuffix(email, "@"+domain) { if strings.HasSuffix(email, "@"+domain) {
@@ -223,7 +223,7 @@ func (b *Bot) AllowAuth(email, password string) (id.RoomID, bool) {
if !ok { if !ok {
return "", false return "", false
} }
cfg, err := b.cfg.GetRoom(roomID) cfg, err := b.cfg.GetRoom(ctx, roomID)
if err != nil { if err != nil {
b.log.Error().Err(err).Msg("failed to retrieve settings") b.log.Error().Err(err).Msg("failed to retrieve settings")
return "", false return "", false

View File

@@ -1,13 +1,14 @@
package bot package bot
import ( import (
"context"
"fmt" "fmt"
"maunium.net/go/mautrix/format" "maunium.net/go/mautrix/format"
"maunium.net/go/mautrix/id" "maunium.net/go/mautrix/id"
) )
type activationFlow func(id.UserID, id.RoomID, string) bool type activationFlow func(context.Context, id.UserID, id.RoomID, string) bool
func (b *Bot) getActivationFlow() activationFlow { func (b *Bot) getActivationFlow() activationFlow {
switch b.mbxc.Activation { switch b.mbxc.Activation {
@@ -21,19 +22,19 @@ func (b *Bot) getActivationFlow() activationFlow {
} }
// ActivateMailbox using the configured flow // ActivateMailbox using the configured flow
func (b *Bot) ActivateMailbox(ownerID id.UserID, roomID id.RoomID, mailbox string) bool { func (b *Bot) ActivateMailbox(ctx context.Context, ownerID id.UserID, roomID id.RoomID, mailbox string) bool {
flow := b.getActivationFlow() flow := b.getActivationFlow()
return flow(ownerID, roomID, mailbox) return flow(ctx, ownerID, roomID, mailbox)
} }
func (b *Bot) activateNone(ownerID id.UserID, roomID id.RoomID, mailbox string) bool { func (b *Bot) activateNone(_ context.Context, ownerID id.UserID, roomID id.RoomID, mailbox string) bool {
b.log.Debug().Str("mailbox", mailbox).Str("roomID", roomID.String()).Str("ownerID", ownerID.String()).Msg("activating mailbox through the flow 'none'") b.log.Debug().Str("mailbox", mailbox).Str("roomID", roomID.String()).Str("ownerID", ownerID.String()).Msg("activating mailbox through the flow 'none'")
b.rooms.Store(mailbox, roomID) b.rooms.Store(mailbox, roomID)
return true return true
} }
func (b *Bot) activateNotify(ownerID id.UserID, roomID id.RoomID, mailbox string) bool { func (b *Bot) activateNotify(ctx context.Context, ownerID id.UserID, roomID id.RoomID, mailbox string) bool {
b.log.Debug().Str("mailbox", mailbox).Str("roomID", roomID.String()).Str("ownerID", ownerID.String()).Msg("activating mailbox through the flow 'notify'") b.log.Debug().Str("mailbox", mailbox).Str("roomID", roomID.String()).Str("ownerID", ownerID.String()).Msg("activating mailbox through the flow 'notify'")
b.rooms.Store(mailbox, roomID) b.rooms.Store(mailbox, roomID)
if len(b.adminRooms) == 0 { if len(b.adminRooms) == 0 {
@@ -43,7 +44,7 @@ func (b *Bot) activateNotify(ownerID id.UserID, roomID id.RoomID, mailbox string
msg := fmt.Sprintf("Mailbox %q has been registered by %q for the room %q", mailbox, ownerID, roomID) msg := fmt.Sprintf("Mailbox %q has been registered by %q for the room %q", mailbox, ownerID, roomID)
for _, adminRoom := range b.adminRooms { for _, adminRoom := range b.adminRooms {
content := format.RenderMarkdown(msg, true, true) content := format.RenderMarkdown(msg, true, true)
_, err := b.lp.Send(adminRoom, &content) _, err := b.lp.Send(ctx, adminRoom, &content)
if err != nil { if err != nil {
b.log.Info().Str("adminRoom", adminRoom.String()).Msg("cannot send mailbox activation notification to the admin room") b.log.Info().Str("adminRoom", adminRoom.String()).Msg("cannot send mailbox activation notification to the admin room")
continue continue

View File

@@ -36,6 +36,7 @@ type Bot struct {
rooms sync.Map rooms sync.Map
proxies []string proxies []string
sendmail func(string, string, string) error sendmail func(string, string, string) error
psd *utils.PSD
cfg *config.Manager cfg *config.Manager
log *zerolog.Logger log *zerolog.Logger
lp *linkpearl.Linkpearl lp *linkpearl.Linkpearl
@@ -50,6 +51,7 @@ func New(
lp *linkpearl.Linkpearl, lp *linkpearl.Linkpearl,
log *zerolog.Logger, log *zerolog.Logger,
cfg *config.Manager, cfg *config.Manager,
psd *utils.PSD,
proxies []string, proxies []string,
prefix string, prefix string,
domains []string, domains []string,
@@ -63,13 +65,14 @@ func New(
adminRooms: []id.RoomID{}, adminRooms: []id.RoomID{},
proxies: proxies, proxies: proxies,
mbxc: mbxc, mbxc: mbxc,
psd: psd,
cfg: cfg, cfg: cfg,
log: log, log: log,
lp: lp, lp: lp,
mu: utils.NewMutex(), mu: utils.NewMutex(),
q: q, q: q,
} }
users, err := b.initBotUsers() users, err := b.initBotUsers(context.Background())
if err != nil { if err != nil {
return nil, err return nil, err
} }
@@ -105,7 +108,7 @@ func (b *Bot) Error(ctx context.Context, message string, args ...any) {
} }
var noThreads bool var noThreads bool
cfg, cerr := b.cfg.GetRoom(evt.RoomID) cfg, cerr := b.cfg.GetRoom(ctx, evt.RoomID)
if cerr == nil { if cerr == nil {
noThreads = cfg.NoThreads() noThreads = cfg.NoThreads()
} }
@@ -115,16 +118,17 @@ func (b *Bot) Error(ctx context.Context, message string, args ...any) {
relatesTo = linkpearl.RelatesTo(threadID, noThreads) relatesTo = linkpearl.RelatesTo(threadID, noThreads)
} }
b.lp.SendNotice(evt.RoomID, "ERROR: "+err.Error(), relatesTo) b.lp.SendNotice(ctx, evt.RoomID, "ERROR: "+err.Error(), relatesTo)
} }
// Start performs matrix /sync // Start performs matrix /sync
func (b *Bot) Start(statusMsg string) error { func (b *Bot) Start(statusMsg string) error {
if err := b.migrateMautrix015(); err != nil { ctx := context.Background()
if err := b.migrateMautrix015(ctx); err != nil {
return err return err
} }
if err := b.syncRooms(); err != nil { if err := b.syncRooms(ctx); err != nil {
return err return err
} }
@@ -135,7 +139,8 @@ func (b *Bot) Start(statusMsg string) error {
// Stop the bot // Stop the bot
func (b *Bot) Stop() { func (b *Bot) Stop() {
err := b.lp.GetClient().SetPresence(event.PresenceOffline) ctx := context.Background()
err := b.lp.GetClient().SetPresence(ctx, event.PresenceOffline)
if err != nil { if err != nil {
b.log.Error().Err(err).Msg("cannot set presence = offline") b.log.Error().Err(err).Msg("cannot set presence = offline")
} }

View File

@@ -46,7 +46,7 @@ type (
key string key string
description string description string
sanitizer func(string) string sanitizer func(string) string
allowed func(id.UserID, id.RoomID) bool allowed func(context.Context, id.UserID, id.RoomID) bool
} }
commandList []command commandList []command
) )
@@ -351,7 +351,7 @@ func (b *Bot) initCommands() commandList {
func (b *Bot) handle(ctx context.Context) { func (b *Bot) handle(ctx context.Context) {
evt := eventFromContext(ctx) evt := eventFromContext(ctx)
err := b.lp.GetClient().MarkRead(evt.RoomID, evt.ID) err := b.lp.GetClient().MarkRead(ctx, evt.RoomID, evt.ID)
if err != nil { if err != nil {
b.log.Error().Err(err).Msg("cannot send read receipt") b.log.Error().Err(err).Msg("cannot send read receipt")
} }
@@ -378,14 +378,14 @@ func (b *Bot) handle(ctx context.Context) {
if cmd == nil { if cmd == nil {
return return
} }
_, err = b.lp.GetClient().UserTyping(evt.RoomID, true, 30*time.Second) _, err = b.lp.GetClient().UserTyping(ctx, evt.RoomID, true, 30*time.Second)
if err != nil { if err != nil {
b.log.Error().Err(err).Msg("cannot send typing notification") b.log.Error().Err(err).Msg("cannot send typing notification")
} }
defer b.lp.GetClient().UserTyping(evt.RoomID, false, 30*time.Second) //nolint:errcheck // we don't care defer b.lp.GetClient().UserTyping(ctx, evt.RoomID, false, 30*time.Second) //nolint:errcheck // we don't care
if !cmd.allowed(evt.Sender, evt.RoomID) { if !cmd.allowed(ctx, evt.Sender, evt.RoomID) {
b.lp.SendNotice(evt.RoomID, "not allowed to do that, kupo") b.lp.SendNotice(ctx, evt.RoomID, "not allowed to do that, kupo")
return return
} }
@@ -452,7 +452,7 @@ func (b *Bot) parseCommand(message string, toLower bool) []string {
return strings.Split(strings.TrimSpace(message), " ") return strings.Split(strings.TrimSpace(message), " ")
} }
func (b *Bot) sendIntroduction(roomID id.RoomID) { func (b *Bot) sendIntroduction(ctx context.Context, roomID id.RoomID) {
var msg strings.Builder var msg strings.Builder
msg.WriteString("Hello, kupo!\n\n") msg.WriteString("Hello, kupo!\n\n")
@@ -468,7 +468,7 @@ func (b *Bot) sendIntroduction(roomID id.RoomID) {
msg.WriteString(utils.EmailsList("SOME_INBOX", "")) msg.WriteString(utils.EmailsList("SOME_INBOX", ""))
msg.WriteString("` and have them appear in this room.") msg.WriteString("` and have them appear in this room.")
b.lp.SendNotice(roomID, msg.String()) b.lp.SendNotice(ctx, roomID, msg.String())
} }
func (b *Bot) getHelpValue(cfg config.Room, cmd command) string { func (b *Bot) getHelpValue(cfg config.Room, cmd command) string {
@@ -497,7 +497,7 @@ func (b *Bot) getHelpValue(cfg config.Room, cmd command) string {
func (b *Bot) sendHelp(ctx context.Context) { func (b *Bot) sendHelp(ctx context.Context) {
evt := eventFromContext(ctx) evt := eventFromContext(ctx)
cfg, serr := b.cfg.GetRoom(evt.RoomID) cfg, serr := b.cfg.GetRoom(ctx, evt.RoomID)
if serr != nil { if serr != nil {
b.log.Error().Err(serr).Msg("cannot retrieve settings") b.log.Error().Err(serr).Msg("cannot retrieve settings")
} }
@@ -505,7 +505,7 @@ func (b *Bot) sendHelp(ctx context.Context) {
var msg strings.Builder var msg strings.Builder
msg.WriteString("The following commands are supported and accessible to you:\n\n") msg.WriteString("The following commands are supported and accessible to you:\n\n")
for _, cmd := range b.commands { for _, cmd := range b.commands {
if !cmd.allowed(evt.Sender, evt.RoomID) { if !cmd.allowed(ctx, evt.Sender, evt.RoomID) {
continue continue
} }
if cmd.key == "" { if cmd.key == "" {
@@ -528,7 +528,7 @@ func (b *Bot) sendHelp(ctx context.Context) {
msg.WriteString("\n") msg.WriteString("\n")
} }
b.lp.SendNotice(evt.RoomID, msg.String(), linkpearl.RelatesTo(evt.ID, cfg.NoThreads())) b.lp.SendNotice(ctx, evt.RoomID, msg.String(), linkpearl.RelatesTo(evt.ID, cfg.NoThreads()))
} }
func (b *Bot) runSend(ctx context.Context) { func (b *Bot) runSend(ctx context.Context) {
@@ -538,7 +538,7 @@ func (b *Bot) runSend(ctx context.Context) {
return return
} }
cfg, err := b.cfg.GetRoom(evt.RoomID) cfg, err := b.cfg.GetRoom(ctx, evt.RoomID)
if err != nil { if err != nil {
b.Error(ctx, "failed to retrieve room settings: %v", err) b.Error(ctx, "failed to retrieve room settings: %v", err)
return return
@@ -555,11 +555,11 @@ func (b *Bot) runSend(ctx context.Context) {
func (b *Bot) getSendDetails(ctx context.Context) (to, subject, body string, ok bool) { func (b *Bot) getSendDetails(ctx context.Context) (to, subject, body string, ok bool) {
evt := eventFromContext(ctx) evt := eventFromContext(ctx)
if !b.allowSend(evt.Sender, evt.RoomID) { if !b.allowSend(ctx, evt.Sender, evt.RoomID) {
return "", "", "", false return "", "", "", false
} }
cfg, err := b.cfg.GetRoom(evt.RoomID) cfg, err := b.cfg.GetRoom(ctx, evt.RoomID)
if err != nil { if err != nil {
b.Error(ctx, "failed to retrieve room settings: %v", err) b.Error(ctx, "failed to retrieve room settings: %v", err)
return "", "", "", false return "", "", "", false
@@ -568,7 +568,7 @@ func (b *Bot) getSendDetails(ctx context.Context) (to, subject, body string, ok
commandSlice := b.parseCommand(evt.Content.AsMessage().Body, false) commandSlice := b.parseCommand(evt.Content.AsMessage().Body, false)
to, subject, body, err = utils.ParseSend(commandSlice) to, subject, body, err = utils.ParseSend(commandSlice)
if errors.Is(err, utils.ErrInvalidArgs) { if errors.Is(err, utils.ErrInvalidArgs) {
b.lp.SendNotice(evt.RoomID, fmt.Sprintf( b.lp.SendNotice(ctx, evt.RoomID, fmt.Sprintf(
"Usage:\n"+ "Usage:\n"+
"```\n"+ "```\n"+
"%s send someone@example.com\n"+ "%s send someone@example.com\n"+
@@ -585,7 +585,7 @@ func (b *Bot) getSendDetails(ctx context.Context) (to, subject, body string, ok
mailbox := cfg.Mailbox() mailbox := cfg.Mailbox()
if mailbox == "" { if mailbox == "" {
b.lp.SendNotice(evt.RoomID, "mailbox is not configured, kupo", linkpearl.RelatesTo(evt.ID, cfg.NoThreads())) b.lp.SendNotice(ctx, evt.RoomID, "mailbox is not configured, kupo", linkpearl.RelatesTo(evt.ID, cfg.NoThreads()))
return "", "", "", false return "", "", "", false
} }
@@ -608,8 +608,8 @@ func (b *Bot) runSendCommand(ctx context.Context, cfg config.Room, tos []string,
} }
} }
b.lock(evt.RoomID, evt.ID) b.lock(ctx, evt.RoomID, evt.ID)
defer b.unlock(evt.RoomID, evt.ID) defer b.unlock(ctx, evt.RoomID, evt.ID)
domain := utils.SanitizeDomain(cfg.Domain()) domain := utils.SanitizeDomain(cfg.Domain())
from := cfg.Mailbox() + "@" + domain from := cfg.Mailbox() + "@" + domain
@@ -617,12 +617,12 @@ func (b *Bot) runSendCommand(ctx context.Context, cfg config.Room, tos []string,
for _, to := range tos { for _, to := range tos {
recipients := []string{to} recipients := []string{to}
eml := email.New(ID, "", " "+ID, subject, from, to, to, "", body, htmlBody, nil, nil) eml := email.New(ID, "", " "+ID, subject, from, to, to, "", body, htmlBody, nil, nil)
data := eml.Compose(b.cfg.GetBot().DKIMPrivateKey()) data := eml.Compose(b.cfg.GetBot(ctx).DKIMPrivateKey())
if data == "" { if data == "" {
b.lp.SendNotice(evt.RoomID, "email body is empty", linkpearl.RelatesTo(evt.ID, cfg.NoThreads())) b.lp.SendNotice(ctx, evt.RoomID, "email body is empty", linkpearl.RelatesTo(evt.ID, cfg.NoThreads()))
return return
} }
queued, err := b.Sendmail(evt.ID, from, to, data) queued, err := b.Sendmail(ctx, evt.ID, from, to, data)
if queued { if queued {
b.log.Warn().Err(err).Msg("email has been queued") b.log.Warn().Err(err).Msg("email has been queued")
b.saveSentMetadata(ctx, queued, evt.ID, recipients, eml, cfg) b.saveSentMetadata(ctx, queued, evt.ID, recipients, eml, cfg)
@@ -635,6 +635,6 @@ func (b *Bot) runSendCommand(ctx context.Context, cfg config.Room, tos []string,
b.saveSentMetadata(ctx, false, evt.ID, recipients, eml, cfg) b.saveSentMetadata(ctx, false, evt.ID, recipients, eml, cfg)
} }
if len(tos) > 1 { if len(tos) > 1 {
b.lp.SendNotice(evt.RoomID, "All emails were sent.", linkpearl.RelatesTo(evt.ID, cfg.NoThreads())) b.lp.SendNotice(ctx, evt.RoomID, "All emails were sent.", linkpearl.RelatesTo(evt.ID, cfg.NoThreads()))
} }
} }

View File

@@ -37,7 +37,7 @@ func (b *Bot) sendMailboxes(ctx context.Context) {
if !ok { if !ok {
return true return true
} }
cfg, err := b.cfg.GetRoom(roomID) cfg, err := b.cfg.GetRoom(ctx, roomID)
if err != nil { if err != nil {
b.log.Error().Err(err).Msg("cannot retrieve settings") b.log.Error().Err(err).Msg("cannot retrieve settings")
} }
@@ -49,7 +49,7 @@ func (b *Bot) sendMailboxes(ctx context.Context) {
sort.Strings(slice) sort.Strings(slice)
if len(slice) == 0 { if len(slice) == 0 {
b.lp.SendNotice(evt.RoomID, "No mailboxes are managed by the bot so far, kupo!", linkpearl.RelatesTo(evt.ID)) b.lp.SendNotice(ctx, evt.RoomID, "No mailboxes are managed by the bot so far, kupo!", linkpearl.RelatesTo(evt.ID))
return return
} }
@@ -64,20 +64,20 @@ func (b *Bot) sendMailboxes(ctx context.Context) {
msg.WriteString("\n") msg.WriteString("\n")
} }
b.lp.SendNotice(evt.RoomID, msg.String(), linkpearl.RelatesTo(evt.ID)) b.lp.SendNotice(ctx, evt.RoomID, msg.String(), linkpearl.RelatesTo(evt.ID))
} }
func (b *Bot) runDelete(ctx context.Context, commandSlice []string) { func (b *Bot) runDelete(ctx context.Context, commandSlice []string) {
evt := eventFromContext(ctx) evt := eventFromContext(ctx)
if len(commandSlice) < 2 { if len(commandSlice) < 2 {
b.lp.SendNotice(evt.RoomID, fmt.Sprintf("Usage: `%s delete MAILBOX`", b.prefix), linkpearl.RelatesTo(evt.ID)) b.lp.SendNotice(ctx, evt.RoomID, fmt.Sprintf("Usage: `%s delete MAILBOX`", b.prefix), linkpearl.RelatesTo(evt.ID))
return return
} }
mailbox := utils.Mailbox(commandSlice[1]) mailbox := utils.Mailbox(commandSlice[1])
v, ok := b.rooms.Load(mailbox) v, ok := b.rooms.Load(mailbox)
if v == nil || !ok { if v == nil || !ok {
b.lp.SendNotice(evt.RoomID, "mailbox does not exists, kupo", linkpearl.RelatesTo(evt.ID)) b.lp.SendNotice(ctx, evt.RoomID, "mailbox does not exists, kupo", linkpearl.RelatesTo(evt.ID))
return return
} }
roomID, ok := v.(id.RoomID) roomID, ok := v.(id.RoomID)
@@ -86,18 +86,18 @@ func (b *Bot) runDelete(ctx context.Context, commandSlice []string) {
} }
b.rooms.Delete(mailbox) b.rooms.Delete(mailbox)
err := b.cfg.SetRoom(roomID, config.Room{}) err := b.cfg.SetRoom(ctx, roomID, config.Room{})
if err != nil { if err != nil {
b.Error(ctx, "cannot update settings: %v", err) b.Error(ctx, "cannot update settings: %v", err)
return return
} }
b.lp.SendNotice(evt.RoomID, "mailbox has been deleted", linkpearl.RelatesTo(evt.ID)) b.lp.SendNotice(ctx, evt.RoomID, "mailbox has been deleted", linkpearl.RelatesTo(evt.ID))
} }
func (b *Bot) runUsers(ctx context.Context, commandSlice []string) { func (b *Bot) runUsers(ctx context.Context, commandSlice []string) {
evt := eventFromContext(ctx) evt := eventFromContext(ctx)
cfg := b.cfg.GetBot() cfg := b.cfg.GetBot(ctx)
if len(commandSlice) < 2 { if len(commandSlice) < 2 {
var msg strings.Builder var msg strings.Builder
users := cfg.Users() users := cfg.Users()
@@ -112,35 +112,35 @@ func (b *Bot) runUsers(ctx context.Context, commandSlice []string) {
msg.WriteString("where each pattern is like `@someone:example.com`, ") msg.WriteString("where each pattern is like `@someone:example.com`, ")
msg.WriteString("`@bot.*:example.com`, `@*:another.com`, or `@*:*`\n") msg.WriteString("`@bot.*:example.com`, `@*:another.com`, or `@*:*`\n")
b.lp.SendNotice(evt.RoomID, msg.String(), linkpearl.RelatesTo(evt.ID)) b.lp.SendNotice(ctx, evt.RoomID, msg.String(), linkpearl.RelatesTo(evt.ID))
return return
} }
_, homeserver, err := b.lp.GetClient().UserID.Parse() _, homeserver, err := b.lp.GetClient().UserID.Parse()
if err != nil { if err != nil {
b.lp.SendNotice(evt.RoomID, fmt.Sprintf("invalid userID: %v", err), linkpearl.RelatesTo(evt.ID)) b.lp.SendNotice(ctx, evt.RoomID, fmt.Sprintf("invalid userID: %v", err), linkpearl.RelatesTo(evt.ID))
} }
patterns := commandSlice[1:] patterns := commandSlice[1:]
allowedUsers, err := parseMXIDpatterns(patterns, "@*:"+homeserver) allowedUsers, err := parseMXIDpatterns(patterns, "@*:"+homeserver)
if err != nil { if err != nil {
b.lp.SendNotice(evt.RoomID, fmt.Sprintf("invalid patterns: %v", err), linkpearl.RelatesTo(evt.ID)) b.lp.SendNotice(ctx, evt.RoomID, fmt.Sprintf("invalid patterns: %v", err), linkpearl.RelatesTo(evt.ID))
return return
} }
cfg.Set(config.BotUsers, strings.Join(patterns, " ")) cfg.Set(config.BotUsers, strings.Join(patterns, " "))
err = b.cfg.SetBot(cfg) err = b.cfg.SetBot(ctx, cfg)
if err != nil { if err != nil {
b.Error(ctx, "cannot set bot config: %v", err) b.Error(ctx, "cannot set bot config: %v", err)
} }
b.allowedUsers = allowedUsers b.allowedUsers = allowedUsers
b.lp.SendNotice(evt.RoomID, "allowed users updated", linkpearl.RelatesTo(evt.ID)) b.lp.SendNotice(ctx, evt.RoomID, "allowed users updated", linkpearl.RelatesTo(evt.ID))
} }
func (b *Bot) runDKIM(ctx context.Context, commandSlice []string) { func (b *Bot) runDKIM(ctx context.Context, commandSlice []string) {
evt := eventFromContext(ctx) evt := eventFromContext(ctx)
cfg := b.cfg.GetBot() cfg := b.cfg.GetBot(ctx)
if len(commandSlice) > 1 && commandSlice[1] == "reset" { if len(commandSlice) > 1 && commandSlice[1] == "reset" {
cfg.Set(config.BotDKIMPrivateKey, "") cfg.Set(config.BotDKIMPrivateKey, "")
cfg.Set(config.BotDKIMSignature, "") cfg.Set(config.BotDKIMSignature, "")
@@ -157,14 +157,14 @@ func (b *Bot) runDKIM(ctx context.Context, commandSlice []string) {
} }
cfg.Set(config.BotDKIMSignature, signature) cfg.Set(config.BotDKIMSignature, signature)
cfg.Set(config.BotDKIMPrivateKey, private) cfg.Set(config.BotDKIMPrivateKey, private)
err := b.cfg.SetBot(cfg) err := b.cfg.SetBot(ctx, cfg)
if err != nil { if err != nil {
b.Error(ctx, "cannot save bot options: %v", err) b.Error(ctx, "cannot save bot options: %v", err)
return return
} }
} }
b.lp.SendNotice(evt.RoomID, fmt.Sprintf( b.lp.SendNotice(ctx, evt.RoomID, fmt.Sprintf(
"DKIM signature is: `%s`.\n"+ "DKIM signature is: `%s`.\n"+
"You need to add it to DNS records of all domains added to postmoogle (if not already):\n"+ "You need to add it to DNS records of all domains added to postmoogle (if not already):\n"+
"Add new DNS record with type = `TXT`, key (subdomain/from): `postmoogle._domainkey` and value (to):\n ```\n%s\n```\n"+ "Add new DNS record with type = `TXT`, key (subdomain/from): `postmoogle._domainkey` and value (to):\n ```\n%s\n```\n"+
@@ -177,7 +177,7 @@ func (b *Bot) runDKIM(ctx context.Context, commandSlice []string) {
func (b *Bot) runCatchAll(ctx context.Context, commandSlice []string) { func (b *Bot) runCatchAll(ctx context.Context, commandSlice []string) {
evt := eventFromContext(ctx) evt := eventFromContext(ctx)
cfg := b.cfg.GetBot() cfg := b.cfg.GetBot(ctx)
if len(commandSlice) < 2 { if len(commandSlice) < 2 {
var msg strings.Builder var msg strings.Builder
msg.WriteString("Currently: `") msg.WriteString("Currently: `")
@@ -195,30 +195,30 @@ func (b *Bot) runCatchAll(ctx context.Context, commandSlice []string) {
msg.WriteString(" catch-all MAILBOX`") msg.WriteString(" catch-all MAILBOX`")
msg.WriteString("where mailbox is valid and existing mailbox name\n") msg.WriteString("where mailbox is valid and existing mailbox name\n")
b.lp.SendNotice(evt.RoomID, msg.String(), linkpearl.RelatesTo(evt.ID)) b.lp.SendNotice(ctx, evt.RoomID, msg.String(), linkpearl.RelatesTo(evt.ID))
return return
} }
mailbox := utils.Mailbox(commandSlice[1]) mailbox := utils.Mailbox(commandSlice[1])
_, ok := b.GetMapping(mailbox) _, ok := b.GetMapping(ctx, mailbox)
if !ok { if !ok {
b.lp.SendNotice(evt.RoomID, "mailbox does not exist, kupo.", linkpearl.RelatesTo(evt.ID)) b.lp.SendNotice(ctx, evt.RoomID, "mailbox does not exist, kupo.", linkpearl.RelatesTo(evt.ID))
return return
} }
cfg.Set(config.BotCatchAll, mailbox) cfg.Set(config.BotCatchAll, mailbox)
err := b.cfg.SetBot(cfg) err := b.cfg.SetBot(ctx, cfg)
if err != nil { if err != nil {
b.Error(ctx, "cannot save bot options: %v", err) b.Error(ctx, "cannot save bot options: %v", err)
return return
} }
b.lp.SendNotice(evt.RoomID, fmt.Sprintf("Catch-all is set to: `%s` (%s).", mailbox, utils.EmailsList(mailbox, "")), linkpearl.RelatesTo(evt.ID)) b.lp.SendNotice(ctx, evt.RoomID, fmt.Sprintf("Catch-all is set to: `%s` (%s).", mailbox, utils.EmailsList(mailbox, "")), linkpearl.RelatesTo(evt.ID))
} }
func (b *Bot) runAdminRoom(ctx context.Context, commandSlice []string) { func (b *Bot) runAdminRoom(ctx context.Context, commandSlice []string) {
evt := eventFromContext(ctx) evt := eventFromContext(ctx)
cfg := b.cfg.GetBot() cfg := b.cfg.GetBot(ctx)
if len(commandSlice) < 2 { if len(commandSlice) < 2 {
var msg strings.Builder var msg strings.Builder
msg.WriteString("Currently: `") msg.WriteString("Currently: `")
@@ -233,13 +233,13 @@ func (b *Bot) runAdminRoom(ctx context.Context, commandSlice []string) {
msg.WriteString(" adminroom ROOM_ID`") msg.WriteString(" adminroom ROOM_ID`")
msg.WriteString("where ROOM_ID is valid and existing matrix room id\n") msg.WriteString("where ROOM_ID is valid and existing matrix room id\n")
b.lp.SendNotice(evt.RoomID, msg.String(), linkpearl.RelatesTo(evt.ID)) b.lp.SendNotice(ctx, evt.RoomID, msg.String(), linkpearl.RelatesTo(evt.ID))
return return
} }
roomID := b.parseCommand(evt.Content.AsMessage().Body, false)[1] // get original value, without forced lower case roomID := b.parseCommand(evt.Content.AsMessage().Body, false)[1] // get original value, without forced lower case
cfg.Set(config.BotAdminRoom, roomID) cfg.Set(config.BotAdminRoom, roomID)
err := b.cfg.SetBot(cfg) err := b.cfg.SetBot(ctx, cfg)
if err != nil { if err != nil {
b.Error(ctx, "cannot save bot options: %v", err) b.Error(ctx, "cannot save bot options: %v", err)
return return
@@ -247,12 +247,12 @@ func (b *Bot) runAdminRoom(ctx context.Context, commandSlice []string) {
b.adminRooms = append([]id.RoomID{id.RoomID(roomID)}, b.adminRooms...) // make it the first room in list on the fly b.adminRooms = append([]id.RoomID{id.RoomID(roomID)}, b.adminRooms...) // make it the first room in list on the fly
b.lp.SendNotice(evt.RoomID, fmt.Sprintf("Admin Room is set to: `%s`.", roomID), linkpearl.RelatesTo(evt.ID)) b.lp.SendNotice(ctx, evt.RoomID, fmt.Sprintf("Admin Room is set to: `%s`.", roomID), linkpearl.RelatesTo(evt.ID))
} }
func (b *Bot) printGreylist(ctx context.Context, roomID id.RoomID) { func (b *Bot) printGreylist(ctx context.Context, roomID id.RoomID) {
cfg := b.cfg.GetBot() cfg := b.cfg.GetBot(ctx)
greylist := b.cfg.GetGreylist() greylist := b.cfg.GetGreylist(ctx)
var msg strings.Builder var msg strings.Builder
size := len(greylist) size := len(greylist)
duration := cfg.Greylist() duration := cfg.Greylist()
@@ -278,7 +278,7 @@ func (b *Bot) printGreylist(ctx context.Context, roomID id.RoomID) {
msg.WriteString("where `MIN` is duration in minutes for automatic greylisting\n") msg.WriteString("where `MIN` is duration in minutes for automatic greylisting\n")
} }
b.lp.SendNotice(roomID, msg.String(), linkpearl.RelatesTo(eventFromContext(ctx).ID)) b.lp.SendNotice(ctx, roomID, msg.String(), linkpearl.RelatesTo(eventFromContext(ctx).ID))
} }
func (b *Bot) runGreylist(ctx context.Context, commandSlice []string) { func (b *Bot) runGreylist(ctx context.Context, commandSlice []string) {
@@ -287,21 +287,21 @@ func (b *Bot) runGreylist(ctx context.Context, commandSlice []string) {
b.printGreylist(ctx, evt.RoomID) b.printGreylist(ctx, evt.RoomID)
return return
} }
cfg := b.cfg.GetBot() cfg := b.cfg.GetBot(ctx)
value := utils.SanitizeIntString(commandSlice[1]) value := utils.SanitizeIntString(commandSlice[1])
cfg.Set(config.BotGreylist, value) cfg.Set(config.BotGreylist, value)
err := b.cfg.SetBot(cfg) err := b.cfg.SetBot(ctx, cfg)
if err != nil { if err != nil {
b.Error(ctx, "cannot set bot config: %v", err) b.Error(ctx, "cannot set bot config: %v", err)
} }
b.lp.SendNotice(evt.RoomID, "greylist duration has been updated", linkpearl.RelatesTo(evt.ID)) b.lp.SendNotice(ctx, evt.RoomID, "greylist duration has been updated", linkpearl.RelatesTo(evt.ID))
} }
func (b *Bot) runBanlist(ctx context.Context, commandSlice []string) { func (b *Bot) runBanlist(ctx context.Context, commandSlice []string) {
evt := eventFromContext(ctx) evt := eventFromContext(ctx)
cfg := b.cfg.GetBot() cfg := b.cfg.GetBot(ctx)
if len(commandSlice) < 2 { if len(commandSlice) < 2 {
banlist := b.cfg.GetBanlist() banlist := b.cfg.GetBanlist(ctx)
var msg strings.Builder var msg strings.Builder
size := len(banlist) size := len(banlist)
if size > 0 { if size > 0 {
@@ -322,26 +322,26 @@ func (b *Bot) runBanlist(ctx context.Context, commandSlice []string) {
msg.WriteString("where each ip is IPv4 or IPv6\n\n") msg.WriteString("where each ip is IPv4 or IPv6\n\n")
msg.WriteString("You can find current banlist values below:\n") msg.WriteString("You can find current banlist values below:\n")
b.lp.SendNotice(evt.RoomID, msg.String(), linkpearl.RelatesTo(evt.ID)) b.lp.SendNotice(ctx, evt.RoomID, msg.String(), linkpearl.RelatesTo(evt.ID))
b.addBanlistTimeline(ctx, false) b.addBanlistTimeline(ctx, false)
return return
} }
value := utils.SanitizeBoolString(commandSlice[1]) value := utils.SanitizeBoolString(commandSlice[1])
cfg.Set(config.BotBanlistEnabled, value) cfg.Set(config.BotBanlistEnabled, value)
err := b.cfg.SetBot(cfg) err := b.cfg.SetBot(ctx, cfg)
if err != nil { if err != nil {
b.Error(ctx, "cannot set bot config: %v", err) b.Error(ctx, "cannot set bot config: %v", err)
} }
b.lp.SendNotice(evt.RoomID, "banlist has been updated", linkpearl.RelatesTo(evt.ID)) b.lp.SendNotice(ctx, evt.RoomID, "banlist has been updated", linkpearl.RelatesTo(evt.ID))
} }
func (b *Bot) runBanlistTotals(ctx context.Context) { func (b *Bot) runBanlistTotals(ctx context.Context) {
evt := eventFromContext(ctx) evt := eventFromContext(ctx)
banlist := b.cfg.GetBanlist() banlist := b.cfg.GetBanlist(ctx)
var msg strings.Builder var msg strings.Builder
size := len(banlist) size := len(banlist)
if size == 0 { if size == 0 {
b.lp.SendNotice(evt.RoomID, "banlist is empty, kupo.", linkpearl.RelatesTo(evt.ID)) b.lp.SendNotice(ctx, evt.RoomID, "banlist is empty, kupo.", linkpearl.RelatesTo(evt.ID))
return return
} }
@@ -349,13 +349,13 @@ func (b *Bot) runBanlistTotals(ctx context.Context) {
msg.WriteString(strconv.Itoa(size)) msg.WriteString(strconv.Itoa(size))
msg.WriteString(" hosts banned\n\n") msg.WriteString(" hosts banned\n\n")
msg.WriteString("You can find daily totals below:\n") msg.WriteString("You can find daily totals below:\n")
b.lp.SendNotice(evt.RoomID, msg.String(), linkpearl.RelatesTo(evt.ID)) b.lp.SendNotice(ctx, evt.RoomID, msg.String(), linkpearl.RelatesTo(evt.ID))
b.addBanlistTimeline(ctx, true) b.addBanlistTimeline(ctx, true)
} }
func (b *Bot) runBanlistAuth(ctx context.Context, commandSlice []string) { //nolint:dupl // not in that case func (b *Bot) runBanlistAuth(ctx context.Context, commandSlice []string) { //nolint:dupl // not in that case
evt := eventFromContext(ctx) evt := eventFromContext(ctx)
cfg := b.cfg.GetBot() cfg := b.cfg.GetBot(ctx)
if len(commandSlice) < 2 { if len(commandSlice) < 2 {
var msg strings.Builder var msg strings.Builder
msg.WriteString("Currently: `") msg.WriteString("Currently: `")
@@ -368,21 +368,21 @@ func (b *Bot) runBanlistAuth(ctx context.Context, commandSlice []string) { //nol
msg.WriteString(" banlist:auth true` (banlist itself must be enabled!)\n\n") msg.WriteString(" banlist:auth true` (banlist itself must be enabled!)\n\n")
} }
b.lp.SendNotice(evt.RoomID, msg.String(), linkpearl.RelatesTo(evt.ID)) b.lp.SendNotice(ctx, evt.RoomID, msg.String(), linkpearl.RelatesTo(evt.ID))
return return
} }
value := utils.SanitizeBoolString(commandSlice[1]) value := utils.SanitizeBoolString(commandSlice[1])
cfg.Set(config.BotBanlistAuth, value) cfg.Set(config.BotBanlistAuth, value)
err := b.cfg.SetBot(cfg) err := b.cfg.SetBot(ctx, cfg)
if err != nil { if err != nil {
b.Error(ctx, "cannot set bot config: %v", err) b.Error(ctx, "cannot set bot config: %v", err)
} }
b.lp.SendNotice(evt.RoomID, "auth banning has been updated", linkpearl.RelatesTo(evt.ID)) b.lp.SendNotice(ctx, evt.RoomID, "auth banning has been updated", linkpearl.RelatesTo(evt.ID))
} }
func (b *Bot) runBanlistAuto(ctx context.Context, commandSlice []string) { //nolint:dupl // not in that case func (b *Bot) runBanlistAuto(ctx context.Context, commandSlice []string) { //nolint:dupl // not in that case
evt := eventFromContext(ctx) evt := eventFromContext(ctx)
cfg := b.cfg.GetBot() cfg := b.cfg.GetBot(ctx)
if len(commandSlice) < 2 { if len(commandSlice) < 2 {
var msg strings.Builder var msg strings.Builder
msg.WriteString("Currently: `") msg.WriteString("Currently: `")
@@ -395,16 +395,16 @@ func (b *Bot) runBanlistAuto(ctx context.Context, commandSlice []string) { //nol
msg.WriteString(" banlist:auto true` (banlist itself must be enabled!)\n\n") msg.WriteString(" banlist:auto true` (banlist itself must be enabled!)\n\n")
} }
b.lp.SendNotice(evt.RoomID, msg.String(), linkpearl.RelatesTo(evt.ID)) b.lp.SendNotice(ctx, evt.RoomID, msg.String(), linkpearl.RelatesTo(evt.ID))
return return
} }
value := utils.SanitizeBoolString(commandSlice[1]) value := utils.SanitizeBoolString(commandSlice[1])
cfg.Set(config.BotBanlistAuto, value) cfg.Set(config.BotBanlistAuto, value)
err := b.cfg.SetBot(cfg) err := b.cfg.SetBot(ctx, cfg)
if err != nil { if err != nil {
b.Error(ctx, "cannot set bot config: %v", err) b.Error(ctx, "cannot set bot config: %v", err)
} }
b.lp.SendNotice(evt.RoomID, "auto banning has been updated", linkpearl.RelatesTo(evt.ID)) b.lp.SendNotice(ctx, evt.RoomID, "auto banning has been updated", linkpearl.RelatesTo(evt.ID))
} }
func (b *Bot) runBanlistChange(ctx context.Context, mode string, commandSlice []string) { func (b *Bot) runBanlistChange(ctx context.Context, mode string, commandSlice []string) {
@@ -413,11 +413,11 @@ func (b *Bot) runBanlistChange(ctx context.Context, mode string, commandSlice []
b.runBanlist(ctx, commandSlice) b.runBanlist(ctx, commandSlice)
return return
} }
if !b.cfg.GetBot().BanlistEnabled() { if !b.cfg.GetBot(ctx).BanlistEnabled() {
b.lp.SendNotice(evt.RoomID, "banlist is disabled, you have to enable it first, kupo", linkpearl.RelatesTo(evt.ID)) b.lp.SendNotice(ctx, evt.RoomID, "banlist is disabled, you have to enable it first, kupo", linkpearl.RelatesTo(evt.ID))
return return
} }
banlist := b.cfg.GetBanlist() banlist := b.cfg.GetBanlist(ctx)
var action func(net.Addr) var action func(net.Addr)
if mode == "remove" { if mode == "remove" {
@@ -436,18 +436,18 @@ func (b *Bot) runBanlistChange(ctx context.Context, mode string, commandSlice []
action(addr) action(addr)
} }
err := b.cfg.SetBanlist(banlist) err := b.cfg.SetBanlist(ctx, banlist)
if err != nil { if err != nil {
b.Error(ctx, "cannot set banlist: %v", err) b.Error(ctx, "cannot set banlist: %v", err)
return return
} }
b.lp.SendNotice(evt.RoomID, "banlist has been updated, kupo", linkpearl.RelatesTo(evt.ID)) b.lp.SendNotice(ctx, evt.RoomID, "banlist has been updated, kupo", linkpearl.RelatesTo(evt.ID))
} }
func (b *Bot) addBanlistTimeline(ctx context.Context, onlyTotals bool) { func (b *Bot) addBanlistTimeline(ctx context.Context, onlyTotals bool) {
evt := eventFromContext(ctx) evt := eventFromContext(ctx)
banlist := b.cfg.GetBanlist() banlist := b.cfg.GetBanlist(ctx)
timeline := map[string][]string{} timeline := map[string][]string{}
for ip, ts := range banlist { for ip, ts := range banlist {
key := "???" key := "???"
@@ -479,22 +479,22 @@ func (b *Bot) addBanlistTimeline(ctx context.Context, onlyTotals bool) {
txt.WriteString(strings.Join(data, "`, `")) txt.WriteString(strings.Join(data, "`, `"))
txt.WriteString("`\n") txt.WriteString("`\n")
} }
b.lp.SendNotice(evt.RoomID, txt.String(), linkpearl.RelatesTo(evt.ID)) b.lp.SendNotice(ctx, evt.RoomID, txt.String(), linkpearl.RelatesTo(evt.ID))
} }
} }
func (b *Bot) runBanlistReset(ctx context.Context) { func (b *Bot) runBanlistReset(ctx context.Context) {
evt := eventFromContext(ctx) evt := eventFromContext(ctx)
if !b.cfg.GetBot().BanlistEnabled() { if !b.cfg.GetBot(ctx).BanlistEnabled() {
b.lp.SendNotice(evt.RoomID, "banlist is disabled, you have to enable it first, kupo", linkpearl.RelatesTo(evt.ID)) b.lp.SendNotice(ctx, evt.RoomID, "banlist is disabled, you have to enable it first, kupo", linkpearl.RelatesTo(evt.ID))
return return
} }
err := b.cfg.SetBanlist(config.List{}) err := b.cfg.SetBanlist(ctx, config.List{})
if err != nil { if err != nil {
b.Error(ctx, "cannot set banlist: %v", err) b.Error(ctx, "cannot set banlist: %v", err)
return return
} }
b.lp.SendNotice(evt.RoomID, "banlist has been reset, kupo", linkpearl.RelatesTo(evt.ID)) b.lp.SendNotice(ctx, evt.RoomID, "banlist has been reset, kupo", linkpearl.RelatesTo(evt.ID))
} }

View File

@@ -16,7 +16,7 @@ import (
func (b *Bot) runStop(ctx context.Context) { func (b *Bot) runStop(ctx context.Context) {
evt := eventFromContext(ctx) evt := eventFromContext(ctx)
cfg, err := b.cfg.GetRoom(evt.RoomID) cfg, err := b.cfg.GetRoom(ctx, evt.RoomID)
if err != nil { if err != nil {
b.Error(ctx, "failed to retrieve settings: %v", err) b.Error(ctx, "failed to retrieve settings: %v", err)
return return
@@ -24,19 +24,19 @@ func (b *Bot) runStop(ctx context.Context) {
mailbox := cfg.Get(config.RoomMailbox) mailbox := cfg.Get(config.RoomMailbox)
if mailbox == "" { if mailbox == "" {
b.lp.SendNotice(evt.RoomID, "that room is not configured yet", linkpearl.RelatesTo(evt.ID, cfg.NoThreads())) b.lp.SendNotice(ctx, evt.RoomID, "that room is not configured yet", linkpearl.RelatesTo(evt.ID, cfg.NoThreads()))
return return
} }
b.rooms.Delete(mailbox) b.rooms.Delete(mailbox)
err = b.cfg.SetRoom(evt.RoomID, config.Room{}) err = b.cfg.SetRoom(ctx, evt.RoomID, config.Room{})
if err != nil { if err != nil {
b.Error(ctx, "cannot update settings: %v", err) b.Error(ctx, "cannot update settings: %v", err)
return return
} }
b.lp.SendNotice(evt.RoomID, "mailbox has been disabled", linkpearl.RelatesTo(evt.ID, cfg.NoThreads())) b.lp.SendNotice(ctx, evt.RoomID, "mailbox has been disabled", linkpearl.RelatesTo(evt.ID, cfg.NoThreads()))
} }
func (b *Bot) handleOption(ctx context.Context, cmd []string) { func (b *Bot) handleOption(ctx context.Context, cmd []string) {
@@ -58,7 +58,7 @@ func (b *Bot) handleOption(ctx context.Context, cmd []string) {
func (b *Bot) getOption(ctx context.Context, name string) { func (b *Bot) getOption(ctx context.Context, name string) {
evt := eventFromContext(ctx) evt := eventFromContext(ctx)
cfg, err := b.cfg.GetRoom(evt.RoomID) cfg, err := b.cfg.GetRoom(ctx, evt.RoomID)
if err != nil { if err != nil {
b.Error(ctx, "failed to retrieve settings: %v", err) b.Error(ctx, "failed to retrieve settings: %v", err)
return return
@@ -73,7 +73,7 @@ func (b *Bot) getOption(ctx context.Context, name string) {
msg := fmt.Sprintf("`%s` is not set, kupo.\n"+ msg := fmt.Sprintf("`%s` is not set, kupo.\n"+
"To set it, send a `%s %s VALUE` command.", "To set it, send a `%s %s VALUE` command.",
name, b.prefix, name) name, b.prefix, name)
b.lp.SendNotice(evt.RoomID, msg, linkpearl.RelatesTo(evt.ID, cfg.NoThreads())) b.lp.SendNotice(ctx, evt.RoomID, msg, linkpearl.RelatesTo(evt.ID, cfg.NoThreads()))
return return
} }
@@ -91,18 +91,18 @@ func (b *Bot) getOption(ctx context.Context, name string) {
"or just set a new one with `%s %s NEW_PASSWORD`.", "or just set a new one with `%s %s NEW_PASSWORD`.",
b.prefix, name) b.prefix, name)
} }
b.lp.SendNotice(evt.RoomID, msg, linkpearl.RelatesTo(evt.ID, cfg.NoThreads())) b.lp.SendNotice(ctx, evt.RoomID, msg, linkpearl.RelatesTo(evt.ID, cfg.NoThreads()))
} }
func (b *Bot) setMailbox(ctx context.Context, value string) { func (b *Bot) setMailbox(ctx context.Context, value string) {
evt := eventFromContext(ctx) evt := eventFromContext(ctx)
existingID, ok := b.getMapping(value) existingID, ok := b.getMapping(value)
if (ok && existingID != "" && existingID != evt.RoomID) || b.isReserved(value) { if (ok && existingID != "" && existingID != evt.RoomID) || b.isReserved(value) {
b.lp.SendNotice(evt.RoomID, fmt.Sprintf("Mailbox `%s` (%s) already taken, kupo", value, utils.EmailsList(value, ""))) b.lp.SendNotice(ctx, evt.RoomID, fmt.Sprintf("Mailbox `%s` (%s) already taken, kupo", value, utils.EmailsList(value, "")))
return return
} }
cfg, err := b.cfg.GetRoom(evt.RoomID) cfg, err := b.cfg.GetRoom(ctx, evt.RoomID)
if err != nil { if err != nil {
b.Error(ctx, "failed to retrieve settings: %v", err) b.Error(ctx, "failed to retrieve settings: %v", err)
return return
@@ -113,23 +113,23 @@ func (b *Bot) setMailbox(ctx context.Context, value string) {
if old != "" { if old != "" {
b.rooms.Delete(old) b.rooms.Delete(old)
} }
active := b.ActivateMailbox(evt.Sender, evt.RoomID, value) active := b.ActivateMailbox(ctx, evt.Sender, evt.RoomID, value)
cfg.Set(config.RoomActive, strconv.FormatBool(active)) cfg.Set(config.RoomActive, strconv.FormatBool(active))
value = fmt.Sprintf("%s@%s", value, utils.SanitizeDomain(cfg.Domain())) value = fmt.Sprintf("%s@%s", value, utils.SanitizeDomain(cfg.Domain()))
err = b.cfg.SetRoom(evt.RoomID, cfg) err = b.cfg.SetRoom(ctx, evt.RoomID, cfg)
if err != nil { if err != nil {
b.Error(ctx, "cannot update settings: %v", err) b.Error(ctx, "cannot update settings: %v", err)
return return
} }
msg := fmt.Sprintf("mailbox of this room set to `%s`", value) msg := fmt.Sprintf("mailbox of this room set to `%s`", value)
b.lp.SendNotice(evt.RoomID, msg, linkpearl.RelatesTo(evt.ID, cfg.NoThreads())) b.lp.SendNotice(ctx, evt.RoomID, msg, linkpearl.RelatesTo(evt.ID, cfg.NoThreads()))
} }
func (b *Bot) setPassword(ctx context.Context) { func (b *Bot) setPassword(ctx context.Context) {
evt := eventFromContext(ctx) evt := eventFromContext(ctx)
cfg, err := b.cfg.GetRoom(evt.RoomID) cfg, err := b.cfg.GetRoom(ctx, evt.RoomID)
if err != nil { if err != nil {
b.Error(ctx, "failed to retrieve settings: %v", err) b.Error(ctx, "failed to retrieve settings: %v", err)
return return
@@ -143,13 +143,13 @@ func (b *Bot) setPassword(ctx context.Context) {
} }
cfg.Set(config.RoomPassword, value) cfg.Set(config.RoomPassword, value)
err = b.cfg.SetRoom(evt.RoomID, cfg) err = b.cfg.SetRoom(ctx, evt.RoomID, cfg)
if err != nil { if err != nil {
b.Error(ctx, "cannot update settings: %v", err) b.Error(ctx, "cannot update settings: %v", err)
return return
} }
b.lp.SendNotice(evt.RoomID, "SMTP password has been set", linkpearl.RelatesTo(evt.ID, cfg.NoThreads())) b.lp.SendNotice(ctx, evt.RoomID, "SMTP password has been set", linkpearl.RelatesTo(evt.ID, cfg.NoThreads()))
} }
func (b *Bot) setOption(ctx context.Context, name, value string) { func (b *Bot) setOption(ctx context.Context, name, value string) {
@@ -159,7 +159,7 @@ func (b *Bot) setOption(ctx context.Context, name, value string) {
} }
evt := eventFromContext(ctx) evt := eventFromContext(ctx)
cfg, err := b.cfg.GetRoom(evt.RoomID) cfg, err := b.cfg.GetRoom(ctx, evt.RoomID)
if err != nil { if err != nil {
b.Error(ctx, "failed to retrieve settings: %v", err) b.Error(ctx, "failed to retrieve settings: %v", err)
return return
@@ -176,19 +176,19 @@ func (b *Bot) setOption(ctx context.Context, name, value string) {
old := cfg.Get(name) old := cfg.Get(name)
if old == value { if old == value {
b.lp.SendNotice(evt.RoomID, "nothing changed, kupo.", linkpearl.RelatesTo(evt.ID, cfg.NoThreads())) b.lp.SendNotice(ctx, evt.RoomID, "nothing changed, kupo.", linkpearl.RelatesTo(evt.ID, cfg.NoThreads()))
return return
} }
cfg.Set(name, value) cfg.Set(name, value)
err = b.cfg.SetRoom(evt.RoomID, cfg) err = b.cfg.SetRoom(ctx, evt.RoomID, cfg)
if err != nil { if err != nil {
b.Error(ctx, "cannot update settings: %v", err) b.Error(ctx, "cannot update settings: %v", err)
return return
} }
msg := fmt.Sprintf("`%s` of this room set to:\n```\n%s\n```", name, value) msg := fmt.Sprintf("`%s` of this room set to:\n```\n%s\n```", name, value)
b.lp.SendNotice(evt.RoomID, msg, linkpearl.RelatesTo(evt.ID, cfg.NoThreads())) b.lp.SendNotice(ctx, evt.RoomID, msg, linkpearl.RelatesTo(evt.ID, cfg.NoThreads()))
} }
func (b *Bot) runSpamlistAdd(ctx context.Context, commandSlice []string) { func (b *Bot) runSpamlistAdd(ctx context.Context, commandSlice []string) {
@@ -197,7 +197,7 @@ func (b *Bot) runSpamlistAdd(ctx context.Context, commandSlice []string) {
b.getOption(ctx, config.RoomSpamlist) b.getOption(ctx, config.RoomSpamlist)
return return
} }
cfg, err := b.cfg.GetRoom(evt.RoomID) cfg, err := b.cfg.GetRoom(ctx, evt.RoomID)
if err != nil { if err != nil {
b.Error(ctx, "cannot get room settings: %v", err) b.Error(ctx, "cannot get room settings: %v", err)
return return
@@ -212,7 +212,7 @@ func (b *Bot) runSpamlistAdd(ctx context.Context, commandSlice []string) {
} }
cfg.Set(config.RoomSpamlist, utils.SliceString(spamlist)) cfg.Set(config.RoomSpamlist, utils.SliceString(spamlist))
err = b.cfg.SetRoom(evt.RoomID, cfg) err = b.cfg.SetRoom(ctx, evt.RoomID, cfg)
if err != nil { if err != nil {
b.Error(ctx, "cannot store room settings: %v", err) b.Error(ctx, "cannot store room settings: %v", err)
return return
@@ -223,7 +223,7 @@ func (b *Bot) runSpamlistAdd(ctx context.Context, commandSlice []string) {
threadID = evt.ID threadID = evt.ID
} }
b.lp.SendNotice(evt.RoomID, "spamlist has been updated, kupo", linkpearl.RelatesTo(threadID, cfg.NoThreads())) b.lp.SendNotice(ctx, evt.RoomID, "spamlist has been updated, kupo", linkpearl.RelatesTo(threadID, cfg.NoThreads()))
} }
func (b *Bot) runSpamlistRemove(ctx context.Context, commandSlice []string) { func (b *Bot) runSpamlistRemove(ctx context.Context, commandSlice []string) {
@@ -232,7 +232,7 @@ func (b *Bot) runSpamlistRemove(ctx context.Context, commandSlice []string) {
b.getOption(ctx, config.RoomSpamlist) b.getOption(ctx, config.RoomSpamlist)
return return
} }
cfg, err := b.cfg.GetRoom(evt.RoomID) cfg, err := b.cfg.GetRoom(ctx, evt.RoomID)
if err != nil { if err != nil {
b.Error(ctx, "cannot get room settings: %v", err) b.Error(ctx, "cannot get room settings: %v", err)
return return
@@ -248,7 +248,7 @@ func (b *Bot) runSpamlistRemove(ctx context.Context, commandSlice []string) {
toRemove[idx] = struct{}{} toRemove[idx] = struct{}{}
} }
if len(toRemove) == 0 { if len(toRemove) == 0 {
b.lp.SendNotice(evt.RoomID, "nothing new, kupo.", linkpearl.RelatesTo(evt.ID, cfg.NoThreads())) b.lp.SendNotice(ctx, evt.RoomID, "nothing new, kupo.", linkpearl.RelatesTo(evt.ID, cfg.NoThreads()))
return return
} }
@@ -261,34 +261,34 @@ func (b *Bot) runSpamlistRemove(ctx context.Context, commandSlice []string) {
} }
cfg.Set(config.RoomSpamlist, utils.SliceString(updatedSpamlist)) cfg.Set(config.RoomSpamlist, utils.SliceString(updatedSpamlist))
err = b.cfg.SetRoom(evt.RoomID, cfg) err = b.cfg.SetRoom(ctx, evt.RoomID, cfg)
if err != nil { if err != nil {
b.Error(ctx, "cannot store room settings: %v", err) b.Error(ctx, "cannot store room settings: %v", err)
return return
} }
b.lp.SendNotice(evt.RoomID, "spamlist has been updated, kupo", linkpearl.RelatesTo(evt.ID, cfg.NoThreads())) b.lp.SendNotice(ctx, evt.RoomID, "spamlist has been updated, kupo", linkpearl.RelatesTo(evt.ID, cfg.NoThreads()))
} }
func (b *Bot) runSpamlistReset(ctx context.Context) { func (b *Bot) runSpamlistReset(ctx context.Context) {
evt := eventFromContext(ctx) evt := eventFromContext(ctx)
cfg, err := b.cfg.GetRoom(evt.RoomID) cfg, err := b.cfg.GetRoom(ctx, evt.RoomID)
if err != nil { if err != nil {
b.Error(ctx, "cannot get room settings: %v", err) b.Error(ctx, "cannot get room settings: %v", err)
return return
} }
spamlist := utils.StringSlice(cfg[config.RoomSpamlist]) spamlist := utils.StringSlice(cfg[config.RoomSpamlist])
if len(spamlist) == 0 { if len(spamlist) == 0 {
b.lp.SendNotice(evt.RoomID, "spamlist is empty, kupo.", linkpearl.RelatesTo(evt.ID, cfg.NoThreads())) b.lp.SendNotice(ctx, evt.RoomID, "spamlist is empty, kupo.", linkpearl.RelatesTo(evt.ID, cfg.NoThreads()))
return return
} }
cfg.Set(config.RoomSpamlist, "") cfg.Set(config.RoomSpamlist, "")
err = b.cfg.SetRoom(evt.RoomID, cfg) err = b.cfg.SetRoom(ctx, evt.RoomID, cfg)
if err != nil { if err != nil {
b.Error(ctx, "cannot store room settings: %v", err) b.Error(ctx, "cannot store room settings: %v", err)
return return
} }
b.lp.SendNotice(evt.RoomID, "spamlist has been reset, kupo.", linkpearl.RelatesTo(evt.ID, cfg.NoThreads())) b.lp.SendNotice(ctx, evt.RoomID, "spamlist has been reset, kupo.", linkpearl.RelatesTo(evt.ID, cfg.NoThreads()))
} }

View File

@@ -1,6 +1,8 @@
package config package config
import ( import (
"context"
"github.com/rs/zerolog" "github.com/rs/zerolog"
"gitlab.com/etke.cc/linkpearl" "gitlab.com/etke.cc/linkpearl"
"maunium.net/go/mautrix/id" "maunium.net/go/mautrix/id"
@@ -27,10 +29,10 @@ func New(lp *linkpearl.Linkpearl, log *zerolog.Logger) *Manager {
} }
// GetBot config // GetBot config
func (m *Manager) GetBot() Bot { func (m *Manager) GetBot(ctx context.Context) Bot {
var err error var err error
var config Bot var config Bot
config, err = m.lp.GetAccountData(acBotKey) config, err = m.lp.GetAccountData(ctx, acBotKey)
if err != nil { if err != nil {
m.log.Error().Err(err).Msg("cannot get bot settings") m.log.Error().Err(err).Msg("cannot get bot settings")
} }
@@ -43,13 +45,13 @@ func (m *Manager) GetBot() Bot {
} }
// SetBot config // SetBot config
func (m *Manager) SetBot(cfg Bot) error { func (m *Manager) SetBot(ctx context.Context, cfg Bot) error {
return m.lp.SetAccountData(acBotKey, cfg) return m.lp.SetAccountData(ctx, acBotKey, cfg)
} }
// GetRoom config // GetRoom config
func (m *Manager) GetRoom(roomID id.RoomID) (Room, error) { func (m *Manager) GetRoom(ctx context.Context, roomID id.RoomID) (Room, error) {
config, err := m.lp.GetRoomAccountData(roomID, acRoomKey) config, err := m.lp.GetRoomAccountData(ctx, roomID, acRoomKey)
if err != nil { if err != nil {
m.log.Warn().Err(err).Str("room_id", roomID.String()).Msg("cannot get room settings") m.log.Warn().Err(err).Str("room_id", roomID.String()).Msg("cannot get room settings")
} }
@@ -61,19 +63,19 @@ func (m *Manager) GetRoom(roomID id.RoomID) (Room, error) {
} }
// SetRoom config // SetRoom config
func (m *Manager) SetRoom(roomID id.RoomID, cfg Room) error { func (m *Manager) SetRoom(ctx context.Context, roomID id.RoomID, cfg Room) error {
return m.lp.SetRoomAccountData(roomID, acRoomKey, cfg) return m.lp.SetRoomAccountData(ctx, roomID, acRoomKey, cfg)
} }
// GetBanlist config // GetBanlist config
func (m *Manager) GetBanlist() List { func (m *Manager) GetBanlist(ctx context.Context) List {
if !m.GetBot().BanlistEnabled() { if !m.GetBot(ctx).BanlistEnabled() {
return make(List, 0) return make(List, 0)
} }
m.mu.Lock("banlist") m.mu.Lock("banlist")
defer m.mu.Unlock("banlist") defer m.mu.Unlock("banlist")
config, err := m.lp.GetAccountData(acBanlistKey) config, err := m.lp.GetAccountData(ctx, acBanlistKey)
if err != nil { if err != nil {
m.log.Error().Err(err).Msg("cannot get banlist") m.log.Error().Err(err).Msg("cannot get banlist")
} }
@@ -85,8 +87,8 @@ func (m *Manager) GetBanlist() List {
} }
// SetBanlist config // SetBanlist config
func (m *Manager) SetBanlist(cfg List) error { func (m *Manager) SetBanlist(ctx context.Context, cfg List) error {
if !m.GetBot().BanlistEnabled() { if !m.GetBot(ctx).BanlistEnabled() {
return nil return nil
} }
@@ -96,12 +98,12 @@ func (m *Manager) SetBanlist(cfg List) error {
cfg = make(List, 0) cfg = make(List, 0)
} }
return m.lp.SetAccountData(acBanlistKey, cfg) return m.lp.SetAccountData(ctx, acBanlistKey, cfg)
} }
// GetGreylist config // GetGreylist config
func (m *Manager) GetGreylist() List { func (m *Manager) GetGreylist(ctx context.Context) List {
config, err := m.lp.GetAccountData(acGreylistKey) config, err := m.lp.GetAccountData(ctx, acGreylistKey)
if err != nil { if err != nil {
m.log.Error().Err(err).Msg("cannot get banlist") m.log.Error().Err(err).Msg("cannot get banlist")
} }
@@ -114,6 +116,6 @@ func (m *Manager) GetGreylist() List {
} }
// SetGreylist config // SetGreylist config
func (m *Manager) SetGreylist(cfg List) error { func (m *Manager) SetGreylist(ctx context.Context, cfg List) error {
return m.lp.SetAccountData(acGreylistKey, cfg) return m.lp.SetAccountData(ctx, acGreylistKey, cfg)
} }

View File

@@ -15,8 +15,10 @@ const (
ctxThreadID ctxkey = iota ctxThreadID ctxkey = iota
) )
func newContext(evt *event.Event) context.Context { func newContext(ctx context.Context, evt *event.Event) context.Context {
ctx := context.Background() if ctx == nil {
ctx = context.Background()
}
hub := sentry.CurrentHub().Clone() hub := sentry.CurrentHub().Clone()
ctx = sentry.SetHubOnContext(ctx, hub) ctx = sentry.SetHubOnContext(ctx, hub)
ctx = eventToContext(ctx, evt) ctx = eventToContext(ctx, evt)

View File

@@ -1,6 +1,7 @@
package bot package bot
import ( import (
"context"
"strconv" "strconv"
"time" "time"
@@ -9,21 +10,21 @@ import (
"gitlab.com/etke.cc/postmoogle/bot/config" "gitlab.com/etke.cc/postmoogle/bot/config"
) )
func (b *Bot) syncRooms() error { func (b *Bot) syncRooms(ctx context.Context) error {
adminRooms := []id.RoomID{} adminRooms := []id.RoomID{}
adminRoom := b.cfg.GetBot().AdminRoom() adminRoom := b.cfg.GetBot(ctx).AdminRoom()
if adminRoom != "" { if adminRoom != "" {
adminRooms = append(adminRooms, adminRoom) adminRooms = append(adminRooms, adminRoom)
} }
resp, err := b.lp.GetClient().JoinedRooms() resp, err := b.lp.GetClient().JoinedRooms(ctx)
if err != nil { if err != nil {
return err return err
} }
for _, roomID := range resp.JoinedRooms { for _, roomID := range resp.JoinedRooms {
b.migrateRoomSettings(roomID) b.migrateRoomSettings(ctx, roomID)
cfg, serr := b.cfg.GetRoom(roomID) cfg, serr := b.cfg.GetRoom(ctx, roomID)
if serr != nil { if serr != nil {
continue continue
} }
@@ -33,7 +34,7 @@ func (b *Bot) syncRooms() error {
b.rooms.Store(mailbox, roomID) b.rooms.Store(mailbox, roomID)
} }
if cfg.Owner() != "" && b.allowAdmin(id.UserID(cfg.Owner()), "") { if cfg.Owner() != "" && b.allowAdmin(ctx, id.UserID(cfg.Owner()), "") {
adminRooms = append(adminRooms, roomID) adminRooms = append(adminRooms, roomID)
} }
} }
@@ -42,8 +43,8 @@ func (b *Bot) syncRooms() error {
return nil return nil
} }
func (b *Bot) migrateRoomSettings(roomID id.RoomID) { func (b *Bot) migrateRoomSettings(ctx context.Context, roomID id.RoomID) {
cfg, err := b.cfg.GetRoom(roomID) cfg, err := b.cfg.GetRoom(ctx, roomID)
if err != nil { if err != nil {
b.log.Error().Err(err).Msg("cannot retrieve room settings") b.log.Error().Err(err).Msg("cannot retrieve room settings")
return return
@@ -56,7 +57,7 @@ func (b *Bot) migrateRoomSettings(roomID id.RoomID) {
return return
} }
cfg.MigrateSpamlistSettings() cfg.MigrateSpamlistSettings()
err = b.cfg.SetRoom(roomID, cfg) err = b.cfg.SetRoom(ctx, roomID, cfg)
if err != nil { if err != nil {
b.log.Error().Err(err).Msg("cannot migrate room settings") b.log.Error().Err(err).Msg("cannot migrate room settings")
} }
@@ -68,8 +69,8 @@ func (b *Bot) migrateRoomSettings(roomID id.RoomID) {
// alongside with other database configs to simplify maintenance, // alongside with other database configs to simplify maintenance,
// but with that simplification there is no proper way to migrate // but with that simplification there is no proper way to migrate
// existing sync token and session info. No data loss, tho. // existing sync token and session info. No data loss, tho.
func (b *Bot) migrateMautrix015() error { func (b *Bot) migrateMautrix015(ctx context.Context) error {
cfg := b.cfg.GetBot() cfg := b.cfg.GetBot(ctx)
ts := cfg.Mautrix015Migration() ts := cfg.Mautrix015Migration()
// already migrated // already migrated
if ts > 0 { if ts > 0 {
@@ -82,11 +83,11 @@ func (b *Bot) migrateMautrix015() error {
tss := strconv.FormatInt(ts, 10) tss := strconv.FormatInt(ts, 10)
cfg.Set(config.BotMautrix015Migration, tss) cfg.Set(config.BotMautrix015Migration, tss)
return b.cfg.SetBot(cfg) return b.cfg.SetBot(ctx, cfg)
} }
func (b *Bot) initBotUsers() ([]string, error) { func (b *Bot) initBotUsers(ctx context.Context) ([]string, error) {
cfg := b.cfg.GetBot() cfg := b.cfg.GetBot(ctx)
cfgUsers := cfg.Users() cfgUsers := cfg.Users()
if len(cfgUsers) > 0 { if len(cfgUsers) > 0 {
return cfgUsers, nil return cfgUsers, nil
@@ -97,10 +98,10 @@ func (b *Bot) initBotUsers() ([]string, error) {
return nil, err return nil, err
} }
cfg.Set(config.BotUsers, "@*:"+homeserver) cfg.Set(config.BotUsers, "@*:"+homeserver)
return cfg.Users(), b.cfg.SetBot(cfg) return cfg.Users(), b.cfg.SetBot(ctx, cfg)
} }
// SyncRooms and mailboxes // SyncRooms and mailboxes
func (b *Bot) SyncRooms() { func (b *Bot) SyncRooms() {
b.syncRooms() //nolint:errcheck // nothing can be done here b.syncRooms(context.Background()) //nolint:errcheck // nothing can be done here
} }

View File

@@ -60,14 +60,14 @@ func (b *Bot) shouldQueue(msg string) bool {
// Sendmail tries to send email immediately, but if it gets 4xx error (greylisting), // Sendmail tries to send email immediately, but if it gets 4xx error (greylisting),
// the email will be added to the queue and retried several times after that // the email will be added to the queue and retried several times after that
func (b *Bot) Sendmail(eventID id.EventID, from, to, data string) (bool, error) { func (b *Bot) Sendmail(ctx context.Context, eventID id.EventID, from, to, data string) (bool, error) {
log := b.log.With().Str("from", from).Str("to", to).Str("eventID", eventID.String()).Logger() log := b.log.With().Str("from", from).Str("to", to).Str("eventID", eventID.String()).Logger()
log.Info().Msg("attempting to deliver email") log.Info().Msg("attempting to deliver email")
err := b.sendmail(from, to, data) err := b.sendmail(from, to, data)
if err != nil { if err != nil {
if b.shouldQueue(err.Error()) { if b.shouldQueue(err.Error()) {
log.Info().Err(err).Msg("email has been added to the queue") log.Info().Err(err).Msg("email has been added to the queue")
return true, b.q.Add(eventID.String(), from, to, data) return true, b.q.Add(ctx, eventID.String(), from, to, data)
} }
log.Warn().Err(err).Msg("email delivery failed") log.Warn().Err(err).Msg("email delivery failed")
return false, err return false, err
@@ -78,8 +78,8 @@ func (b *Bot) Sendmail(eventID id.EventID, from, to, data string) (bool, error)
} }
// GetDKIMprivkey returns DKIM private key // GetDKIMprivkey returns DKIM private key
func (b *Bot) GetDKIMprivkey() string { func (b *Bot) GetDKIMprivkey(ctx context.Context) string {
return b.cfg.GetBot().DKIMPrivateKey() return b.cfg.GetBot(ctx).DKIMPrivateKey()
} }
func (b *Bot) getMapping(mailbox string) (id.RoomID, bool) { func (b *Bot) getMapping(mailbox string) (id.RoomID, bool) {
@@ -97,10 +97,10 @@ func (b *Bot) getMapping(mailbox string) (id.RoomID, bool) {
} }
// GetMapping returns mapping of mailbox = room // GetMapping returns mapping of mailbox = room
func (b *Bot) GetMapping(mailbox string) (id.RoomID, bool) { func (b *Bot) GetMapping(ctx context.Context, mailbox string) (id.RoomID, bool) {
roomID, ok := b.getMapping(mailbox) roomID, ok := b.getMapping(mailbox)
if !ok { if !ok {
catchAll := b.cfg.GetBot().CatchAll() catchAll := b.cfg.GetBot(ctx).CatchAll()
if catchAll == "" { if catchAll == "" {
return roomID, ok return roomID, ok
} }
@@ -111,8 +111,8 @@ func (b *Bot) GetMapping(mailbox string) (id.RoomID, bool) {
} }
// GetIFOptions returns incoming email filtering options (room settings) // GetIFOptions returns incoming email filtering options (room settings)
func (b *Bot) GetIFOptions(roomID id.RoomID) email.IncomingFilteringOptions { func (b *Bot) GetIFOptions(ctx context.Context, roomID id.RoomID) email.IncomingFilteringOptions {
cfg, err := b.cfg.GetRoom(roomID) cfg, err := b.cfg.GetRoom(ctx, roomID)
if err != nil { if err != nil {
b.log.Error().Err(err).Msg("cannot retrieve room settings") b.log.Error().Err(err).Msg("cannot retrieve room settings")
} }
@@ -124,11 +124,11 @@ func (b *Bot) GetIFOptions(roomID id.RoomID) email.IncomingFilteringOptions {
// //
//nolint:gocognit // TODO //nolint:gocognit // TODO
func (b *Bot) IncomingEmail(ctx context.Context, eml *email.Email) error { func (b *Bot) IncomingEmail(ctx context.Context, eml *email.Email) error {
roomID, ok := b.GetMapping(eml.Mailbox(true)) roomID, ok := b.GetMapping(ctx, eml.Mailbox(true))
if !ok { if !ok {
return ErrNoRoom return ErrNoRoom
} }
cfg, err := b.cfg.GetRoom(roomID) cfg, err := b.cfg.GetRoom(ctx, roomID)
if err != nil { if err != nil {
b.Error(ctx, "cannot get settings: %v", err) b.Error(ctx, "cannot get settings: %v", err)
} }
@@ -139,15 +139,15 @@ func (b *Bot) IncomingEmail(ctx context.Context, eml *email.Email) error {
var threadID id.EventID var threadID id.EventID
newThread := true newThread := true
if eml.InReplyTo != "" || eml.References != "" { if eml.InReplyTo != "" || eml.References != "" {
threadID = b.getThreadID(roomID, eml.InReplyTo, eml.References) threadID = b.getThreadID(ctx, roomID, eml.InReplyTo, eml.References)
if threadID != "" { if threadID != "" {
newThread = false newThread = false
ctx = threadIDToContext(ctx, threadID) ctx = threadIDToContext(ctx, threadID)
b.setThreadID(roomID, eml.MessageID, threadID) b.setThreadID(ctx, roomID, eml.MessageID, threadID)
} }
} }
content := eml.Content(threadID, cfg.ContentOptions()) content := eml.Content(threadID, cfg.ContentOptions(), b.psd)
eventID, serr := b.lp.Send(roomID, content) eventID, serr := b.lp.Send(ctx, roomID, content)
if serr != nil { if serr != nil {
if !strings.Contains(serr.Error(), "M_UNKNOWN") { // if it's not an unknown event error if !strings.Contains(serr.Error(), "M_UNKNOWN") { // if it's not an unknown event error
return serr return serr
@@ -160,11 +160,11 @@ func (b *Bot) IncomingEmail(ctx context.Context, eml *email.Email) error {
ctx = threadIDToContext(ctx, threadID) ctx = threadIDToContext(ctx, threadID)
} }
b.setThreadID(roomID, eml.MessageID, threadID) b.setThreadID(ctx, roomID, eml.MessageID, threadID)
b.setLastEventID(roomID, threadID, eventID) b.setLastEventID(ctx, roomID, threadID, eventID)
if newThread && cfg.Threadify() { if newThread && cfg.Threadify() {
_, berr := b.lp.Send(roomID, eml.ContentBody(threadID, cfg.ContentOptions())) _, berr := b.lp.Send(ctx, roomID, eml.ContentBody(threadID, cfg.ContentOptions()))
if berr != nil { if berr != nil {
return berr return berr
} }
@@ -179,15 +179,15 @@ func (b *Bot) IncomingEmail(ctx context.Context, eml *email.Email) error {
} }
if newThread && cfg.Autoreply() != "" { if newThread && cfg.Autoreply() != "" {
b.sendAutoreply(roomID, threadID) b.sendAutoreply(ctx, roomID, threadID)
} }
return nil return nil
} }
//nolint:gocognit // TODO //nolint:gocognit // TODO
func (b *Bot) sendAutoreply(roomID id.RoomID, threadID id.EventID) { func (b *Bot) sendAutoreply(ctx context.Context, roomID id.RoomID, threadID id.EventID) {
cfg, err := b.cfg.GetRoom(roomID) cfg, err := b.cfg.GetRoom(ctx, roomID)
if err != nil { if err != nil {
return return
} }
@@ -197,7 +197,7 @@ func (b *Bot) sendAutoreply(roomID id.RoomID, threadID id.EventID) {
return return
} }
threadEvt, err := b.lp.GetClient().GetEvent(roomID, threadID) threadEvt, err := b.lp.GetClient().GetEvent(ctx, roomID, threadID)
if err != nil { if err != nil {
b.log.Error().Err(err).Msg("cannot get thread event for autoreply") b.log.Error().Err(err).Msg("cannot get thread event for autoreply")
return return
@@ -216,7 +216,7 @@ func (b *Bot) sendAutoreply(roomID id.RoomID, threadID id.EventID) {
}, },
} }
meta := b.getParentEmail(evt, cfg.Mailbox()) meta := b.getParentEmail(ctx, evt, cfg.Mailbox())
if meta.To == "" { if meta.To == "" {
return return
@@ -246,16 +246,16 @@ func (b *Bot) sendAutoreply(roomID id.RoomID, threadID id.EventID) {
meta.References = meta.References + " " + meta.MessageID meta.References = meta.References + " " + meta.MessageID
b.log.Info().Any("meta", meta).Msg("sending automatic reply") b.log.Info().Any("meta", meta).Msg("sending automatic reply")
eml := email.New(meta.MessageID, meta.InReplyTo, meta.References, meta.Subject, meta.From, meta.To, meta.RcptTo, meta.CC, body, htmlBody, nil, nil) eml := email.New(meta.MessageID, meta.InReplyTo, meta.References, meta.Subject, meta.From, meta.To, meta.RcptTo, meta.CC, body, htmlBody, nil, nil)
data := eml.Compose(b.cfg.GetBot().DKIMPrivateKey()) data := eml.Compose(b.cfg.GetBot(ctx).DKIMPrivateKey())
if data == "" { if data == "" {
return return
} }
var queued bool var queued bool
ctx := newContext(threadEvt) ctx = newContext(ctx, threadEvt)
recipients := meta.Recipients recipients := meta.Recipients
for _, to := range recipients { for _, to := range recipients {
queued, err = b.Sendmail(evt.ID, meta.From, to, data) queued, err = b.Sendmail(ctx, evt.ID, meta.From, to, data)
if queued { if queued {
b.log.Info().Err(err).Str("from", meta.From).Str("to", to).Msg("email has been queued") b.log.Info().Err(err).Str("from", meta.From).Str("to", to).Msg("email has been queued")
b.saveSentMetadata(ctx, queued, meta.ThreadID, recipients, eml, cfg, "Autoreply has been sent to "+to+" (queued)") b.saveSentMetadata(ctx, queued, meta.ThreadID, recipients, eml, cfg, "Autoreply has been sent to "+to+" (queued)")
@@ -273,7 +273,7 @@ func (b *Bot) sendAutoreply(roomID id.RoomID, threadID id.EventID) {
func (b *Bot) canReply(ctx context.Context) bool { func (b *Bot) canReply(ctx context.Context) bool {
evt := eventFromContext(ctx) evt := eventFromContext(ctx)
return b.allowSend(evt.Sender, evt.RoomID) && b.allowReply(evt.Sender, evt.RoomID) return b.allowSend(ctx, evt.Sender, evt.RoomID) && b.allowReply(ctx, evt.Sender, evt.RoomID)
} }
// SendEmailReply sends replies from matrix thread to email thread // SendEmailReply sends replies from matrix thread to email thread
@@ -284,7 +284,7 @@ func (b *Bot) SendEmailReply(ctx context.Context) {
if !b.canReply(ctx) { if !b.canReply(ctx) {
return return
} }
cfg, err := b.cfg.GetRoom(evt.RoomID) cfg, err := b.cfg.GetRoom(ctx, evt.RoomID)
if err != nil { if err != nil {
b.Error(ctx, "cannot retrieve room settings: %v", err) b.Error(ctx, "cannot retrieve room settings: %v", err)
return return
@@ -295,10 +295,10 @@ func (b *Bot) SendEmailReply(ctx context.Context) {
return return
} }
b.lock(evt.RoomID, evt.ID) b.lock(ctx, evt.RoomID, evt.ID)
defer b.unlock(evt.RoomID, evt.ID) defer b.unlock(ctx, evt.RoomID, evt.ID)
meta := b.getParentEmail(evt, mailbox) meta := b.getParentEmail(ctx, evt, mailbox)
if meta.To == "" { if meta.To == "" {
b.Error(ctx, "cannot find parent email and continue the thread. Please, start a new email thread") b.Error(ctx, "cannot find parent email and continue the thread. Please, start a new email thread")
@@ -306,7 +306,7 @@ func (b *Bot) SendEmailReply(ctx context.Context) {
} }
if meta.ThreadID == "" { if meta.ThreadID == "" {
meta.ThreadID = b.getThreadID(evt.RoomID, meta.InReplyTo, meta.References) meta.ThreadID = b.getThreadID(ctx, evt.RoomID, meta.InReplyTo, meta.References)
ctx = threadIDToContext(ctx, meta.ThreadID) ctx = threadIDToContext(ctx, meta.ThreadID)
} }
content := evt.Content.AsMessage() content := evt.Content.AsMessage()
@@ -330,16 +330,16 @@ func (b *Bot) SendEmailReply(ctx context.Context) {
meta.References = meta.References + " " + meta.MessageID meta.References = meta.References + " " + meta.MessageID
b.log.Info().Any("meta", meta).Msg("sending email reply") b.log.Info().Any("meta", meta).Msg("sending email reply")
eml := email.New(meta.MessageID, meta.InReplyTo, meta.References, meta.Subject, meta.From, meta.To, meta.RcptTo, meta.CC, body, htmlBody, nil, nil) eml := email.New(meta.MessageID, meta.InReplyTo, meta.References, meta.Subject, meta.From, meta.To, meta.RcptTo, meta.CC, body, htmlBody, nil, nil)
data := eml.Compose(b.cfg.GetBot().DKIMPrivateKey()) data := eml.Compose(b.cfg.GetBot(ctx).DKIMPrivateKey())
if data == "" { if data == "" {
b.lp.SendNotice(evt.RoomID, "email body is empty", linkpearl.RelatesTo(meta.ThreadID, cfg.NoThreads())) b.lp.SendNotice(ctx, evt.RoomID, "email body is empty", linkpearl.RelatesTo(meta.ThreadID, cfg.NoThreads()))
return return
} }
var queued bool var queued bool
recipients := meta.Recipients recipients := meta.Recipients
for _, to := range recipients { for _, to := range recipients {
queued, err = b.Sendmail(evt.ID, meta.From, to, data) queued, err = b.Sendmail(ctx, evt.ID, meta.From, to, data)
if queued { if queued {
b.log.Info().Err(err).Str("from", meta.From).Str("to", to).Msg("email has been queued") b.log.Info().Err(err).Str("from", meta.From).Str("to", to).Msg("email has been queued")
b.saveSentMetadata(ctx, queued, meta.ThreadID, recipients, eml, cfg) b.saveSentMetadata(ctx, queued, meta.ThreadID, recipients, eml, cfg)
@@ -444,7 +444,7 @@ func (e *parentEmail) calculateRecipients(from string, forwardedFrom []string) {
e.Recipients = rcpts e.Recipients = rcpts
} }
func (b *Bot) getParentEvent(evt *event.Event) (id.EventID, *event.Event) { func (b *Bot) getParentEvent(ctx context.Context, evt *event.Event) (id.EventID, *event.Event) {
content := evt.Content.AsMessage() content := evt.Content.AsMessage()
threadID := linkpearl.EventParent(evt.ID, content) threadID := linkpearl.EventParent(evt.ID, content)
b.log.Debug().Str("eventID", evt.ID.String()).Str("threadID", threadID.String()).Msg("looking up for the parent event within thread") b.log.Debug().Str("eventID", evt.ID.String()).Str("threadID", threadID.String()).Msg("looking up for the parent event within thread")
@@ -452,23 +452,23 @@ func (b *Bot) getParentEvent(evt *event.Event) (id.EventID, *event.Event) {
b.log.Debug().Str("eventID", evt.ID.String()).Msg("event is the thread itself") b.log.Debug().Str("eventID", evt.ID.String()).Msg("event is the thread itself")
return threadID, evt return threadID, evt
} }
lastEventID := b.getLastEventID(evt.RoomID, threadID) lastEventID := b.getLastEventID(ctx, evt.RoomID, threadID)
b.log.Debug().Str("eventID", evt.ID.String()).Str("threadID", threadID.String()).Str("lastEventID", lastEventID.String()).Msg("the last event of the thread (and parent of the event) has been found") b.log.Debug().Str("eventID", evt.ID.String()).Str("threadID", threadID.String()).Str("lastEventID", lastEventID.String()).Msg("the last event of the thread (and parent of the event) has been found")
if lastEventID == evt.ID { if lastEventID == evt.ID {
return threadID, evt return threadID, evt
} }
parentEvt, err := b.lp.GetClient().GetEvent(evt.RoomID, lastEventID) parentEvt, err := b.lp.GetClient().GetEvent(ctx, evt.RoomID, lastEventID)
if err != nil { if err != nil {
b.log.Error().Err(err).Msg("cannot get parent event") b.log.Error().Err(err).Msg("cannot get parent event")
return threadID, nil return threadID, nil
} }
linkpearl.ParseContent(parentEvt, b.log) linkpearl.ParseContent(parentEvt, b.log)
if !b.lp.GetMachine().StateStore.IsEncrypted(evt.RoomID) { if ok, _ := b.lp.GetMachine().StateStore.IsEncrypted(ctx, evt.RoomID); !ok { //nolint:errcheck // that's fine
return threadID, parentEvt return threadID, parentEvt
} }
decrypted, err := b.lp.GetClient().Crypto.Decrypt(parentEvt) decrypted, err := b.lp.GetClient().Crypto.Decrypt(ctx, parentEvt)
if err != nil { if err != nil {
b.log.Error().Err(err).Msg("cannot decrypt parent event") b.log.Error().Err(err).Msg("cannot decrypt parent event")
return threadID, nil return threadID, nil
@@ -477,9 +477,9 @@ func (b *Bot) getParentEvent(evt *event.Event) (id.EventID, *event.Event) {
return threadID, decrypted return threadID, decrypted
} }
func (b *Bot) getParentEmail(evt *event.Event, newFromMailbox string) *parentEmail { func (b *Bot) getParentEmail(ctx context.Context, evt *event.Event, newFromMailbox string) *parentEmail {
parent := &parentEmail{} parent := &parentEmail{}
threadID, parentEvt := b.getParentEvent(evt) threadID, parentEvt := b.getParentEvent(ctx, evt)
parent.ThreadID = threadID parent.ThreadID = threadID
if parentEvt == nil { if parentEvt == nil {
return parent return parent
@@ -527,7 +527,7 @@ func (b *Bot) saveSentMetadata(ctx context.Context, queued bool, threadID id.Eve
} }
evt := eventFromContext(ctx) evt := eventFromContext(ctx)
content := eml.Content(threadID, cfg.ContentOptions()) content := eml.Content(threadID, cfg.ContentOptions(), b.psd)
notice := format.RenderMarkdown(text, true, true) notice := format.RenderMarkdown(text, true, true)
msgContent, ok := content.Parsed.(*event.MessageEventContent) msgContent, ok := content.Parsed.(*event.MessageEventContent)
if !ok { if !ok {
@@ -539,28 +539,28 @@ func (b *Bot) saveSentMetadata(ctx context.Context, queued bool, threadID id.Eve
msgContent.FormattedBody = notice.FormattedBody msgContent.FormattedBody = notice.FormattedBody
msgContent.RelatesTo = linkpearl.RelatesTo(threadID, cfg.NoThreads()) msgContent.RelatesTo = linkpearl.RelatesTo(threadID, cfg.NoThreads())
content.Parsed = msgContent content.Parsed = msgContent
msgID, err := b.lp.Send(evt.RoomID, content) msgID, err := b.lp.Send(ctx, evt.RoomID, content)
if err != nil { if err != nil {
b.Error(ctx, "cannot send notice: %v", err) b.Error(ctx, "cannot send notice: %v", err)
return return
} }
domain := utils.SanitizeDomain(cfg.Domain()) domain := utils.SanitizeDomain(cfg.Domain())
b.setThreadID(evt.RoomID, email.MessageID(evt.ID, domain), threadID) b.setThreadID(ctx, evt.RoomID, email.MessageID(evt.ID, domain), threadID)
b.setThreadID(evt.RoomID, email.MessageID(msgID, domain), threadID) b.setThreadID(ctx, evt.RoomID, email.MessageID(msgID, domain), threadID)
b.setLastEventID(evt.RoomID, threadID, msgID) b.setLastEventID(ctx, evt.RoomID, threadID, msgID)
} }
func (b *Bot) sendFiles(ctx context.Context, roomID id.RoomID, files []*utils.File, noThreads bool, parentID id.EventID) { func (b *Bot) sendFiles(ctx context.Context, roomID id.RoomID, files []*utils.File, noThreads bool, parentID id.EventID) {
for _, file := range files { for _, file := range files {
req := file.Convert() req := file.Convert()
err := b.lp.SendFile(roomID, req, file.MsgType, linkpearl.RelatesTo(parentID, noThreads)) err := b.lp.SendFile(ctx, roomID, req, file.MsgType, linkpearl.RelatesTo(parentID, noThreads))
if err != nil { if err != nil {
b.Error(ctx, "cannot upload file %s: %v", req.FileName, err) b.Error(ctx, "cannot upload file %s: %v", req.FileName, err)
} }
} }
} }
func (b *Bot) getThreadID(roomID id.RoomID, messageID, references string) id.EventID { func (b *Bot) getThreadID(ctx context.Context, roomID id.RoomID, messageID, references string) id.EventID {
refs := []string{messageID} refs := []string{messageID}
if references != "" { if references != "" {
refs = append(refs, strings.Split(references, " ")...) refs = append(refs, strings.Split(references, " ")...)
@@ -568,7 +568,7 @@ func (b *Bot) getThreadID(roomID id.RoomID, messageID, references string) id.Eve
for _, refID := range refs { for _, refID := range refs {
key := acMessagePrefix + "." + refID key := acMessagePrefix + "." + refID
data, err := b.lp.GetRoomAccountData(roomID, key) data, err := b.lp.GetRoomAccountData(ctx, roomID, key)
if err != nil { if err != nil {
b.log.Error().Err(err).Str("key", key).Msg("cannot retrieve thread ID") b.log.Error().Err(err).Str("key", key).Msg("cannot retrieve thread ID")
continue continue
@@ -576,7 +576,7 @@ func (b *Bot) getThreadID(roomID id.RoomID, messageID, references string) id.Eve
if data["eventID"] == "" { if data["eventID"] == "" {
continue continue
} }
resp, err := b.lp.GetClient().GetEvent(roomID, id.EventID(data["eventID"])) resp, err := b.lp.GetClient().GetEvent(ctx, roomID, id.EventID(data["eventID"]))
if err != nil { if err != nil {
b.log.Warn().Err(err).Str("roomID", roomID.String()).Str("eventID", data["eventID"]).Msg("cannot get event by id (may be removed)") b.log.Warn().Err(err).Str("roomID", roomID.String()).Str("eventID", data["eventID"]).Msg("cannot get event by id (may be removed)")
continue continue
@@ -587,17 +587,17 @@ func (b *Bot) getThreadID(roomID id.RoomID, messageID, references string) id.Eve
return "" return ""
} }
func (b *Bot) setThreadID(roomID id.RoomID, messageID string, eventID id.EventID) { func (b *Bot) setThreadID(ctx context.Context, roomID id.RoomID, messageID string, eventID id.EventID) {
key := acMessagePrefix + "." + messageID key := acMessagePrefix + "." + messageID
err := b.lp.SetRoomAccountData(roomID, key, map[string]string{"eventID": eventID.String()}) err := b.lp.SetRoomAccountData(ctx, roomID, key, map[string]string{"eventID": eventID.String()})
if err != nil { if err != nil {
b.log.Error().Err(err).Str("key", key).Msg("cannot save thread ID") b.log.Error().Err(err).Str("key", key).Msg("cannot save thread ID")
} }
} }
func (b *Bot) getLastEventID(roomID id.RoomID, threadID id.EventID) id.EventID { func (b *Bot) getLastEventID(ctx context.Context, roomID id.RoomID, threadID id.EventID) id.EventID {
key := acLastEventPrefix + "." + threadID.String() key := acLastEventPrefix + "." + threadID.String()
data, err := b.lp.GetRoomAccountData(roomID, key) data, err := b.lp.GetRoomAccountData(ctx, roomID, key)
if err != nil { if err != nil {
b.log.Error().Err(err).Str("key", key).Msg("cannot retrieve last event ID") b.log.Error().Err(err).Str("key", key).Msg("cannot retrieve last event ID")
return threadID return threadID
@@ -609,9 +609,9 @@ func (b *Bot) getLastEventID(roomID id.RoomID, threadID id.EventID) id.EventID {
return threadID return threadID
} }
func (b *Bot) setLastEventID(roomID id.RoomID, threadID, eventID id.EventID) { func (b *Bot) setLastEventID(ctx context.Context, roomID id.RoomID, threadID, eventID id.EventID) {
key := acLastEventPrefix + "." + threadID.String() key := acLastEventPrefix + "." + threadID.String()
err := b.lp.SetRoomAccountData(roomID, key, map[string]string{"eventID": eventID.String()}) err := b.lp.SetRoomAccountData(ctx, roomID, key, map[string]string{"eventID": eventID.String()})
if err != nil { if err != nil {
b.log.Error().Err(err).Str("key", key).Msg("cannot save thread ID") b.log.Error().Err(err).Str("key", key).Msg("cannot save thread ID")
} }

View File

@@ -1,29 +1,31 @@
package bot package bot
import ( import (
"context"
"maunium.net/go/mautrix/id" "maunium.net/go/mautrix/id"
) )
func (b *Bot) lock(roomID id.RoomID, optionalEventID ...id.EventID) { func (b *Bot) lock(ctx context.Context, roomID id.RoomID, optionalEventID ...id.EventID) {
b.mu.Lock(roomID.String()) b.mu.Lock(roomID.String())
if len(optionalEventID) == 0 { if len(optionalEventID) == 0 {
return return
} }
evtID := optionalEventID[0] evtID := optionalEventID[0]
if _, err := b.lp.GetClient().SendReaction(roomID, evtID, "📨"); err != nil { if _, err := b.lp.GetClient().SendReaction(ctx, roomID, evtID, "📨"); err != nil {
b.log.Error().Err(err).Str("roomID", roomID.String()).Str("eventID", evtID.String()).Msg("cannot send reaction on lock") b.log.Error().Err(err).Str("roomID", roomID.String()).Str("eventID", evtID.String()).Msg("cannot send reaction on lock")
} }
} }
func (b *Bot) unlock(roomID id.RoomID, optionalEventID ...id.EventID) { func (b *Bot) unlock(ctx context.Context, roomID id.RoomID, optionalEventID ...id.EventID) {
b.mu.Unlock(roomID.String()) b.mu.Unlock(roomID.String())
if len(optionalEventID) == 0 { if len(optionalEventID) == 0 {
return return
} }
evtID := optionalEventID[0] evtID := optionalEventID[0]
if _, err := b.lp.GetClient().SendReaction(roomID, evtID, "✅"); err != nil { if _, err := b.lp.GetClient().SendReaction(ctx, roomID, evtID, "✅"); err != nil {
b.log.Error().Err(err).Str("roomID", roomID.String()).Str("eventID", evtID.String()).Msg("cannot send reaction on unlock") b.log.Error().Err(err).Str("roomID", roomID.String()).Str("eventID", evtID.String()).Msg("cannot send reaction on unlock")
} }
} }

View File

@@ -1,6 +1,8 @@
package queue package queue
import ( import (
"context"
"github.com/rs/zerolog" "github.com/rs/zerolog"
"gitlab.com/etke.cc/linkpearl" "gitlab.com/etke.cc/linkpearl"
@@ -41,7 +43,8 @@ func (q *Queue) SetSendmail(function func(string, string, string) error) {
// Process queue // Process queue
func (q *Queue) Process() { func (q *Queue) Process() {
q.log.Debug().Msg("staring queue processing...") q.log.Debug().Msg("staring queue processing...")
cfg := q.cfg.GetBot() ctx := context.Background()
cfg := q.cfg.GetBot(ctx)
batchSize := cfg.QueueBatch() batchSize := cfg.QueueBatch()
if batchSize == 0 { if batchSize == 0 {
@@ -55,7 +58,7 @@ func (q *Queue) Process() {
q.mu.Lock(acQueueKey) q.mu.Lock(acQueueKey)
defer q.mu.Unlock(acQueueKey) defer q.mu.Unlock(acQueueKey)
index, err := q.lp.GetAccountData(acQueueKey) index, err := q.lp.GetAccountData(ctx, acQueueKey)
if err != nil { if err != nil {
q.log.Error().Err(err).Msg("cannot get queue index") q.log.Error().Err(err).Msg("cannot get queue index")
} }
@@ -66,9 +69,9 @@ func (q *Queue) Process() {
q.log.Debug().Msg("finished re-deliveries from queue") q.log.Debug().Msg("finished re-deliveries from queue")
return return
} }
if dequeue := q.try(itemkey, maxRetries); dequeue { if dequeue := q.try(ctx, itemkey, maxRetries); dequeue {
q.log.Info().Str("id", id).Msg("email has been delivered") q.log.Info().Str("id", id).Msg("email has been delivered")
err = q.Remove(id) err = q.Remove(ctx, id)
if err != nil { if err != nil {
q.log.Error().Err(err).Str("id", id).Msg("cannot dequeue email") q.log.Error().Err(err).Str("id", id).Msg("cannot dequeue email")
} }

View File

@@ -1,11 +1,12 @@
package queue package queue
import ( import (
"context"
"strconv" "strconv"
) )
// Add to queue // Add to queue
func (q *Queue) Add(id, from, to, data string) error { func (q *Queue) Add(ctx context.Context, id, from, to, data string) error {
itemkey := acQueueKey + "." + id itemkey := acQueueKey + "." + id
item := map[string]string{ item := map[string]string{
"attempts": "0", "attempts": "0",
@@ -17,7 +18,7 @@ func (q *Queue) Add(id, from, to, data string) error {
q.mu.Lock(itemkey) q.mu.Lock(itemkey)
defer q.mu.Unlock(itemkey) defer q.mu.Unlock(itemkey)
err := q.lp.SetAccountData(itemkey, item) err := q.lp.SetAccountData(ctx, itemkey, item)
if err != nil { if err != nil {
q.log.Error().Err(err).Str("id", id).Msg("cannot enqueue email") q.log.Error().Err(err).Str("id", id).Msg("cannot enqueue email")
return err return err
@@ -25,13 +26,13 @@ func (q *Queue) Add(id, from, to, data string) error {
q.mu.Lock(acQueueKey) q.mu.Lock(acQueueKey)
defer q.mu.Unlock(acQueueKey) defer q.mu.Unlock(acQueueKey)
queueIndex, err := q.lp.GetAccountData(acQueueKey) queueIndex, err := q.lp.GetAccountData(ctx, acQueueKey)
if err != nil { if err != nil {
q.log.Error().Err(err).Msg("cannot get queue index") q.log.Error().Err(err).Msg("cannot get queue index")
return err return err
} }
queueIndex[id] = itemkey queueIndex[id] = itemkey
err = q.lp.SetAccountData(acQueueKey, queueIndex) err = q.lp.SetAccountData(ctx, acQueueKey, queueIndex)
if err != nil { if err != nil {
q.log.Error().Err(err).Msg("cannot save queue index") q.log.Error().Err(err).Msg("cannot save queue index")
return err return err
@@ -41,8 +42,8 @@ func (q *Queue) Add(id, from, to, data string) error {
} }
// Remove from queue // Remove from queue
func (q *Queue) Remove(id string) error { func (q *Queue) Remove(ctx context.Context, id string) error {
index, err := q.lp.GetAccountData(acQueueKey) index, err := q.lp.GetAccountData(ctx, acQueueKey)
if err != nil { if err != nil {
q.log.Error().Err(err).Msg("cannot get queue index") q.log.Error().Err(err).Msg("cannot get queue index")
return err return err
@@ -52,7 +53,7 @@ func (q *Queue) Remove(id string) error {
itemkey = acQueueKey + "." + id itemkey = acQueueKey + "." + id
} }
delete(index, id) delete(index, id)
err = q.lp.SetAccountData(acQueueKey, index) err = q.lp.SetAccountData(ctx, acQueueKey, index)
if err != nil { if err != nil {
q.log.Error().Err(err).Msg("cannot update queue index") q.log.Error().Err(err).Msg("cannot update queue index")
return err return err
@@ -60,15 +61,15 @@ func (q *Queue) Remove(id string) error {
q.mu.Lock(itemkey) q.mu.Lock(itemkey)
defer q.mu.Unlock(itemkey) defer q.mu.Unlock(itemkey)
return q.lp.SetAccountData(itemkey, map[string]string{}) return q.lp.SetAccountData(ctx, itemkey, map[string]string{})
} }
// try to send email // try to send email
func (q *Queue) try(itemkey string, maxRetries int) bool { func (q *Queue) try(ctx context.Context, itemkey string, maxRetries int) bool {
q.mu.Lock(itemkey) q.mu.Lock(itemkey)
defer q.mu.Unlock(itemkey) defer q.mu.Unlock(itemkey)
item, err := q.lp.GetAccountData(itemkey) item, err := q.lp.GetAccountData(ctx, itemkey)
if err != nil { if err != nil {
q.log.Error().Err(err).Str("id", itemkey).Msg("cannot retrieve a queue item") q.log.Error().Err(err).Str("id", itemkey).Msg("cannot retrieve a queue item")
return false return false
@@ -92,7 +93,7 @@ func (q *Queue) try(itemkey string, maxRetries int) bool {
q.log.Info().Str("id", itemkey).Str("from", item["from"]).Str("to", item["to"]).Err(err).Msg("attempted to deliver email, but it's not ready yet") q.log.Info().Str("id", itemkey).Str("from", item["from"]).Str("to", item["to"]).Err(err).Msg("attempted to deliver email, but it's not ready yet")
attempts++ attempts++
item["attempts"] = strconv.Itoa(attempts) item["attempts"] = strconv.Itoa(attempts)
err = q.lp.SetAccountData(itemkey, item) err = q.lp.SetAccountData(ctx, itemkey, item)
if err != nil { if err != nil {
q.log.Error().Err(err).Str("id", itemkey).Msg("cannot update attempt count on email") q.log.Error().Err(err).Str("id", itemkey).Msg("cannot update attempt count on email")
} }

View File

@@ -22,14 +22,14 @@ func (b *Bot) handleReaction(ctx context.Context) {
} }
srcID := content.GetRelatesTo().EventID srcID := content.GetRelatesTo().EventID
srcEvt, err := b.lp.GetClient().GetEvent(evt.RoomID, srcID) srcEvt, err := b.lp.GetClient().GetEvent(ctx, evt.RoomID, srcID)
if err != nil { if err != nil {
b.Error(ctx, "cannot find event %s: %v", srcID, err) b.Error(ctx, "cannot find event %s: %v", srcID, err)
return return
} }
linkpearl.ParseContent(srcEvt, b.log) linkpearl.ParseContent(srcEvt, b.log)
if b.lp.GetMachine().StateStore.IsEncrypted(evt.RoomID) { if ok, _ := b.lp.GetMachine().StateStore.IsEncrypted(ctx, evt.RoomID); ok { //nolint:errcheck // that's ok
decrypted, derr := b.lp.GetClient().Crypto.Decrypt(srcEvt) decrypted, derr := b.lp.GetClient().Crypto.Decrypt(ctx, srcEvt)
if derr == nil { if derr == nil {
srcEvt = decrypted srcEvt = decrypted
} }

View File

@@ -3,7 +3,6 @@ package bot
import ( import (
"context" "context"
"maunium.net/go/mautrix"
"maunium.net/go/mautrix/event" "maunium.net/go/mautrix/event"
) )
@@ -12,27 +11,27 @@ func (b *Bot) initSync() {
b.lp.OnEventType( b.lp.OnEventType(
event.StateMember, event.StateMember,
func(_ mautrix.EventSource, evt *event.Event) { func(ctx context.Context, evt *event.Event) {
go b.onMembership(evt) go b.onMembership(ctx, evt)
}, },
) )
b.lp.OnEventType( b.lp.OnEventType(
event.EventMessage, event.EventMessage,
func(_ mautrix.EventSource, evt *event.Event) { func(ctx context.Context, evt *event.Event) {
go b.onMessage(evt) go b.onMessage(ctx, evt)
}, },
) )
b.lp.OnEventType( b.lp.OnEventType(
event.EventReaction, event.EventReaction,
func(_ mautrix.EventSource, evt *event.Event) { func(ctx context.Context, evt *event.Event) {
go b.onReaction(evt) go b.onReaction(ctx, evt)
}, },
) )
} }
// joinPermit is called by linkpearl when processing "invite" events and deciding if rooms should be auto-joined or not // joinPermit is called by linkpearl when processing "invite" events and deciding if rooms should be auto-joined or not
func (b *Bot) joinPermit(evt *event.Event) bool { func (b *Bot) joinPermit(ctx context.Context, evt *event.Event) bool {
if !b.allowUsers(evt.Sender, evt.RoomID) { if !b.allowUsers(ctx, evt.Sender, evt.RoomID) {
b.log.Debug().Str("userID", evt.Sender.String()).Msg("Rejecting room invitation from unallowed user") b.log.Debug().Str("userID", evt.Sender.String()).Msg("Rejecting room invitation from unallowed user")
return false return false
} }
@@ -40,13 +39,13 @@ func (b *Bot) joinPermit(evt *event.Event) bool {
return true return true
} }
func (b *Bot) onMembership(evt *event.Event) { func (b *Bot) onMembership(ctx context.Context, evt *event.Event) {
// mautrix 0.15.x migration // mautrix 0.15.x migration
if b.ignoreBefore >= evt.Timestamp { if b.ignoreBefore >= evt.Timestamp {
return return
} }
ctx := newContext(evt) ctx = newContext(ctx, evt)
evtType := evt.Content.AsMember().Membership evtType := evt.Content.AsMember().Membership
if evtType == event.MembershipJoin && evt.Sender == b.lp.GetClient().UserID { if evtType == event.MembershipJoin && evt.Sender == b.lp.GetClient().UserID {
@@ -61,7 +60,7 @@ func (b *Bot) onMembership(evt *event.Event) {
// Potentially handle other membership events in the future // Potentially handle other membership events in the future
} }
func (b *Bot) onMessage(evt *event.Event) { func (b *Bot) onMessage(ctx context.Context, evt *event.Event) {
// ignore own messages // ignore own messages
if evt.Sender == b.lp.GetClient().UserID { if evt.Sender == b.lp.GetClient().UserID {
return return
@@ -71,11 +70,11 @@ func (b *Bot) onMessage(evt *event.Event) {
return return
} }
ctx := newContext(evt) ctx = newContext(ctx, evt)
b.handle(ctx) b.handle(ctx)
} }
func (b *Bot) onReaction(evt *event.Event) { func (b *Bot) onReaction(ctx context.Context, evt *event.Event) {
// ignore own messages // ignore own messages
if evt.Sender == b.lp.GetClient().UserID { if evt.Sender == b.lp.GetClient().UserID {
return return
@@ -85,7 +84,7 @@ func (b *Bot) onReaction(evt *event.Event) {
return return
} }
ctx := newContext(evt) ctx = newContext(ctx, evt)
b.handleReaction(ctx) b.handleReaction(ctx)
} }
@@ -100,7 +99,7 @@ func (b *Bot) onBotJoin(ctx context.Context) {
return return
} }
b.sendIntroduction(evt.RoomID) b.sendIntroduction(ctx, evt.RoomID)
b.sendHelp(ctx) b.sendHelp(ctx)
} }
@@ -111,7 +110,7 @@ func (b *Bot) onLeave(ctx context.Context) {
b.log.Info().Str("eventID", evt.ID.String()).Msg("Suppressing already handled event") b.log.Info().Str("eventID", evt.ID.String()).Msg("Suppressing already handled event")
return return
} }
members, err := b.lp.GetClient().StateStore.GetRoomJoinedOrInvitedMembers(evt.RoomID) members, err := b.lp.GetClient().StateStore.GetRoomJoinedOrInvitedMembers(ctx, evt.RoomID)
if err != nil { if err != nil {
b.log.Error().Err(err).Str("roomID", evt.RoomID.String()).Msg("cannot get joined or invited members") b.log.Error().Err(err).Str("roomID", evt.RoomID.String()).Msg("cannot get joined or invited members")
return return
@@ -121,7 +120,7 @@ func (b *Bot) onLeave(ctx context.Context) {
if count == 1 && members[0] == b.lp.GetClient().UserID { if count == 1 && members[0] == b.lp.GetClient().UserID {
b.log.Info().Str("roomID", evt.RoomID.String()).Msg("no more users left in the room") b.log.Info().Str("roomID", evt.RoomID.String()).Msg("no more users left in the room")
b.runStop(ctx) b.runStop(ctx)
_, err := b.lp.GetClient().LeaveRoom(evt.RoomID) _, err := b.lp.GetClient().LeaveRoom(ctx, evt.RoomID)
if err != nil { if err != nil {
b.Error(ctx, "cannot leave empty room: %v", err) b.Error(ctx, "cannot leave empty room: %v", err)
} }

View File

@@ -114,9 +114,10 @@ func initMatrix(cfg *config.Config) {
log.Fatal().Err(err).Msg("cannot initialize matrix bot") log.Fatal().Err(err).Msg("cannot initialize matrix bot")
} }
psd := utils.NewPSD(cfg.PSD.URL, cfg.PSD.Login, cfg.PSD.Password, &log)
mxc = mxconfig.New(lp, &log) mxc = mxconfig.New(lp, &log)
q = queue.New(lp, mxc, &log) q = queue.New(lp, mxc, &log)
mxb, err = bot.New(q, lp, &log, mxc, cfg.Proxies, cfg.Prefix, cfg.Domains, cfg.Admins, bot.MBXConfig(cfg.Mailboxes)) mxb, err = bot.New(q, lp, &log, mxc, psd, cfg.Proxies, cfg.Prefix, cfg.Domains, cfg.Admins, bot.MBXConfig(cfg.Mailboxes))
if err != nil { if err != nil {
log.Panic().Err(err).Msg("cannot start matrix bot") log.Panic().Err(err).Msg("cannot start matrix bot")
} }

View File

@@ -48,6 +48,11 @@ func New() *Config {
DSN: env.String("db.dsn", defaultConfig.DB.DSN), DSN: env.String("db.dsn", defaultConfig.DB.DSN),
Dialect: env.String("db.dialect", defaultConfig.DB.Dialect), Dialect: env.String("db.dialect", defaultConfig.DB.Dialect),
}, },
PSD: PSD{
URL: env.String("psd.url"),
Login: env.String("psd.login"),
Password: env.String("psd.password"),
},
Relay: Relay{ Relay: Relay{
Host: env.String("relay.host", defaultConfig.Relay.Host), Host: env.String("relay.host", defaultConfig.Relay.Host),
Port: env.String("relay.port", defaultConfig.Relay.Port), Port: env.String("relay.port", defaultConfig.Relay.Port),

View File

@@ -38,6 +38,9 @@ type Config struct {
// DB config // DB config
DB DB DB DB
// PSD config
PSD PSD
// TLS config // TLS config
TLS TLS TLS TLS
@@ -78,6 +81,12 @@ type Mailboxes struct {
Activation string Activation string
} }
type PSD struct {
URL string
Login string
Password string
}
// Relay config // Relay config
type Relay struct { type Relay struct {
Host string Host string

View File

@@ -108,8 +108,9 @@ func (e *Email) Mailbox(incoming bool) string {
return utils.Mailbox(e.From) return utils.Mailbox(e.From)
} }
func (e *Email) contentHeader(threadID id.EventID, text *strings.Builder, options *ContentOptions) { func (e *Email) contentHeader(threadID id.EventID, text *strings.Builder, options *ContentOptions, psd *utils.PSD) {
if options.Sender { if options.Sender {
text.WriteString(psd.Status(e.From))
text.WriteString(e.From) text.WriteString(e.From)
} }
if options.Recipient { if options.Recipient {
@@ -125,8 +126,12 @@ func (e *Email) contentHeader(threadID id.EventID, text *strings.Builder, option
} }
} }
if options.CC && len(e.CC) > 0 { if options.CC && len(e.CC) > 0 {
ccs := make([]string, 0, len(e.CC))
for _, addr := range e.CC {
ccs = append(ccs, psd.Status(addr)+addr)
}
text.WriteString("\ncc: ") text.WriteString("\ncc: ")
text.WriteString(strings.Join(e.CC, ", ")) text.WriteString(strings.Join(ccs, ", "))
} }
if options.Sender || options.Recipient || options.CC { if options.Sender || options.Recipient || options.CC {
text.WriteString("\n\n") text.WriteString("\n\n")
@@ -146,10 +151,10 @@ func (e *Email) contentHeader(threadID id.EventID, text *strings.Builder, option
} }
// Content converts the email object to a Matrix event content // Content converts the email object to a Matrix event content
func (e *Email) Content(threadID id.EventID, options *ContentOptions) *event.Content { func (e *Email) Content(threadID id.EventID, options *ContentOptions, psd *utils.PSD) *event.Content {
var text strings.Builder var text strings.Builder
e.contentHeader(threadID, &text, options) e.contentHeader(threadID, &text, options, psd)
if threadID != "" || (threadID == "" && !options.Threadify) { if threadID != "" || (threadID == "" && !options.Threadify) {
if e.HTML != "" && options.HTML { if e.HTML != "" && options.HTML {

20
go.mod
View File

@@ -18,16 +18,16 @@ require (
github.com/mcnijman/go-emailaddress v1.1.0 github.com/mcnijman/go-emailaddress v1.1.0
github.com/mileusna/crontab v1.2.0 github.com/mileusna/crontab v1.2.0
github.com/raja/argon2pw v1.0.2-0.20210910183755-a391af63bd39 github.com/raja/argon2pw v1.0.2-0.20210910183755-a391af63bd39
github.com/rs/zerolog v1.31.0 github.com/rs/zerolog v1.32.0
gitlab.com/etke.cc/go/env v1.0.0 gitlab.com/etke.cc/go/env v1.1.0
gitlab.com/etke.cc/go/fswatcher v1.0.0 gitlab.com/etke.cc/go/fswatcher v1.0.0
gitlab.com/etke.cc/go/healthchecks v1.0.1 gitlab.com/etke.cc/go/healthchecks v1.0.1
gitlab.com/etke.cc/go/mxidwc v1.0.0 gitlab.com/etke.cc/go/mxidwc v1.0.0
gitlab.com/etke.cc/go/secgen v1.1.1 gitlab.com/etke.cc/go/secgen v1.1.1
gitlab.com/etke.cc/go/validator v1.0.6 gitlab.com/etke.cc/go/validator v1.0.6
gitlab.com/etke.cc/linkpearl v0.0.0-20231121221431-72443f33d266 gitlab.com/etke.cc/linkpearl v0.0.0-20240211143445-bddf907d137a
golang.org/x/exp v0.0.0-20231110203233-9a3e6036ecaa golang.org/x/exp v0.0.0-20240205201215-2c58cdc269a3
maunium.net/go/mautrix v0.16.2 maunium.net/go/mautrix v0.17.0
) )
require ( require (
@@ -50,12 +50,12 @@ require (
github.com/tidwall/match v1.1.1 // indirect github.com/tidwall/match v1.1.1 // indirect
github.com/tidwall/pretty v1.2.1 // indirect github.com/tidwall/pretty v1.2.1 // indirect
github.com/tidwall/sjson v1.2.5 // indirect github.com/tidwall/sjson v1.2.5 // indirect
github.com/yuin/goldmark v1.6.0 // indirect github.com/yuin/goldmark v1.7.0 // indirect
gitlab.com/etke.cc/go/trysmtp v1.1.3 // indirect gitlab.com/etke.cc/go/trysmtp v1.1.3 // indirect
go.mau.fi/util v0.2.1 // indirect go.mau.fi/util v0.3.0 // indirect
golang.org/x/crypto v0.15.0 // indirect golang.org/x/crypto v0.19.0 // indirect
golang.org/x/net v0.18.0 // indirect golang.org/x/net v0.21.0 // indirect
golang.org/x/sys v0.14.0 // indirect golang.org/x/sys v0.17.0 // indirect
golang.org/x/text v0.14.0 // indirect golang.org/x/text v0.14.0 // indirect
maunium.net/go/maulogger/v2 v2.4.1 // indirect maunium.net/go/maulogger/v2 v2.4.1 // indirect
) )

44
go.sum
View File

@@ -1,6 +1,6 @@
blitiri.com.ar/go/spf v1.5.1 h1:CWUEasc44OrANJD8CzceRnRn1Jv0LttY68cYym2/pbE= blitiri.com.ar/go/spf v1.5.1 h1:CWUEasc44OrANJD8CzceRnRn1Jv0LttY68cYym2/pbE=
blitiri.com.ar/go/spf v1.5.1/go.mod h1:E71N92TfL4+Yyd5lpKuE9CAF2pd4JrUq1xQfkTxoNdk= blitiri.com.ar/go/spf v1.5.1/go.mod h1:E71N92TfL4+Yyd5lpKuE9CAF2pd4JrUq1xQfkTxoNdk=
github.com/DATA-DOG/go-sqlmock v1.5.0 h1:Shsta01QNfFxHCfpW6YH2STWB0MudeXXEWMr20OEh60= github.com/DATA-DOG/go-sqlmock v1.5.1 h1:FK6RCIUSfmbnI/imIICmboyQBkOckutaa6R5YYlLZyo=
github.com/archdx/zerolog-sentry v1.2.0 h1:FDFqlo5XvL/jpDAPoAWI15EjJQVFvixn70v3IH//eTM= github.com/archdx/zerolog-sentry v1.2.0 h1:FDFqlo5XvL/jpDAPoAWI15EjJQVFvixn70v3IH//eTM=
github.com/archdx/zerolog-sentry v1.2.0/go.mod h1:3H8gClGFafB90fKMsvfP017bdmkG5MD6UiA+6iPEwGw= github.com/archdx/zerolog-sentry v1.2.0/go.mod h1:3H8gClGFafB90fKMsvfP017bdmkG5MD6UiA+6iPEwGw=
github.com/buger/jsonparser v1.0.0 h1:etJTGF5ESxjI0Ic2UaLQs2LQQpa8G9ykQScukbh4L8A= github.com/buger/jsonparser v1.0.0 h1:etJTGF5ESxjI0Ic2UaLQs2LQQpa8G9ykQScukbh4L8A=
@@ -55,8 +55,6 @@ github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D
github.com/mattn/go-runewidth v0.0.9/go.mod h1:H031xJmbD/WCDINGzjvQ9THkh0rPKHF+m2gUSrubnMI= github.com/mattn/go-runewidth v0.0.9/go.mod h1:H031xJmbD/WCDINGzjvQ9THkh0rPKHF+m2gUSrubnMI=
github.com/mattn/go-runewidth v0.0.12 h1:Y41i/hVW3Pgwr8gV+J23B9YEY0zxjptBuCWEaxmAOow= github.com/mattn/go-runewidth v0.0.12 h1:Y41i/hVW3Pgwr8gV+J23B9YEY0zxjptBuCWEaxmAOow=
github.com/mattn/go-runewidth v0.0.12/go.mod h1:RAqKPSqVFrSLVXbA8x7dzmKdmGzieGRCM46jaSJTDAk= github.com/mattn/go-runewidth v0.0.12/go.mod h1:RAqKPSqVFrSLVXbA8x7dzmKdmGzieGRCM46jaSJTDAk=
github.com/mattn/go-sqlite3 v1.14.18 h1:JL0eqdCOq6DJVNPSvArO/bIV9/P7fbGrV00LZHc+5aI=
github.com/mattn/go-sqlite3 v1.14.18/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg=
github.com/mattn/go-sqlite3 v1.14.19 h1:fhGleo2h1p8tVChob4I9HpmVFIAkKGpiukdrgQbWfGI= github.com/mattn/go-sqlite3 v1.14.19 h1:fhGleo2h1p8tVChob4I9HpmVFIAkKGpiukdrgQbWfGI=
github.com/mattn/go-sqlite3 v1.14.19/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg= github.com/mattn/go-sqlite3 v1.14.19/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg=
github.com/mcnijman/go-emailaddress v1.1.0 h1:7/Uxgn9pXwXmvXsFSgORo6XoRTrttj7AGmmB2yFArAg= github.com/mcnijman/go-emailaddress v1.1.0 h1:7/Uxgn9pXwXmvXsFSgORo6XoRTrttj7AGmmB2yFArAg=
@@ -78,8 +76,8 @@ github.com/rivo/uniseg v0.1.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJ
github.com/rivo/uniseg v0.2.0 h1:S1pD9weZBuJdFmowNwbpi7BJ8TNftyUImj/0WQi72jY= github.com/rivo/uniseg v0.2.0 h1:S1pD9weZBuJdFmowNwbpi7BJ8TNftyUImj/0WQi72jY=
github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc= github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc=
github.com/rs/xid v1.5.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg= github.com/rs/xid v1.5.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg=
github.com/rs/zerolog v1.31.0 h1:FcTR3NnLWW+NnTwwhFWiJSZr4ECLpqCm6QsEnyvbV4A= github.com/rs/zerolog v1.32.0 h1:keLypqrlIjaFsbmJOBdB/qvyF8KEtCWHwobLp5l/mQ0=
github.com/rs/zerolog v1.31.0/go.mod h1:/7mN4D5sKwJLZQ2b/znpjC3/GQWY/xaDXUM0kKWRHss= github.com/rs/zerolog v1.32.0/go.mod h1:/7mN4D5sKwJLZQ2b/znpjC3/GQWY/xaDXUM0kKWRHss=
github.com/ssor/bom v0.0.0-20170718123548-6386211fdfcf h1:pvbZ0lM0XWPBqUKqFU8cmavspvIl9nulOYwdy6IFRRo= github.com/ssor/bom v0.0.0-20170718123548-6386211fdfcf h1:pvbZ0lM0XWPBqUKqFU8cmavspvIl9nulOYwdy6IFRRo=
github.com/ssor/bom v0.0.0-20170718123548-6386211fdfcf/go.mod h1:RJID2RhlZKId02nZ62WenDCkgHFerpIOmW0iT7GKmXM= github.com/ssor/bom v0.0.0-20170718123548-6386211fdfcf/go.mod h1:RJID2RhlZKId02nZ62WenDCkgHFerpIOmW0iT7GKmXM=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME= github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
@@ -96,10 +94,12 @@ github.com/tidwall/pretty v1.2.1 h1:qjsOFOWWQl+N3RsoF5/ssm1pHmJJwhjlSbZ51I6wMl4=
github.com/tidwall/pretty v1.2.1/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU= github.com/tidwall/pretty v1.2.1/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY= github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY=
github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28= github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28=
github.com/yuin/goldmark v1.6.0 h1:boZcn2GTjpsynOsC0iJHnBWa4Bi0qzfJjthwauItG68= github.com/yuin/goldmark v1.7.0 h1:EfOIvIMZIzHdB/R/zVrikYLPPwJlfMcNczJFMs1m6sA=
github.com/yuin/goldmark v1.6.0/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY= github.com/yuin/goldmark v1.7.0/go.mod h1:uzxRWxtg69N339t3louHJ7+O03ezfj6PlliRlaOzY1E=
gitlab.com/etke.cc/go/env v1.0.0 h1:J98BwzOuELnjsVPFvz5wa79L7IoRV9CmrS41xLYXtSw= gitlab.com/etke.cc/go/env v1.0.0 h1:J98BwzOuELnjsVPFvz5wa79L7IoRV9CmrS41xLYXtSw=
gitlab.com/etke.cc/go/env v1.0.0/go.mod h1:e1l4RM5MA1sc0R1w/RBDAESWRwgo5cOG9gx8BKUn2C4= gitlab.com/etke.cc/go/env v1.0.0/go.mod h1:e1l4RM5MA1sc0R1w/RBDAESWRwgo5cOG9gx8BKUn2C4=
gitlab.com/etke.cc/go/env v1.1.0 h1:nbMhZkMu6C8lysRlb5siIiylWuyVkGAgEvwWEqz/82o=
gitlab.com/etke.cc/go/env v1.1.0/go.mod h1:e1l4RM5MA1sc0R1w/RBDAESWRwgo5cOG9gx8BKUn2C4=
gitlab.com/etke.cc/go/fswatcher v1.0.0 h1:uyiVn+1NVCjOLZrXSZouIDBDZBMwVipS4oYuvAFpPzo= gitlab.com/etke.cc/go/fswatcher v1.0.0 h1:uyiVn+1NVCjOLZrXSZouIDBDZBMwVipS4oYuvAFpPzo=
gitlab.com/etke.cc/go/fswatcher v1.0.0/go.mod h1:MqTOxyhXfvaVZQUL9/Ksbl2ow1PTBVu3eqIldvMq0RE= gitlab.com/etke.cc/go/fswatcher v1.0.0/go.mod h1:MqTOxyhXfvaVZQUL9/Ksbl2ow1PTBVu3eqIldvMq0RE=
gitlab.com/etke.cc/go/healthchecks v1.0.1 h1:IxPB+r4KtEM6wf4K7MeQoH1XnuBITMGUqFaaRIgxeUY= gitlab.com/etke.cc/go/healthchecks v1.0.1 h1:IxPB+r4KtEM6wf4K7MeQoH1XnuBITMGUqFaaRIgxeUY=
@@ -112,21 +112,21 @@ gitlab.com/etke.cc/go/trysmtp v1.1.3 h1:e2EHond77onMaecqCg6mWumffTSEf+ycgj88nbee
gitlab.com/etke.cc/go/trysmtp v1.1.3/go.mod h1:lOO7tTdAE0a3ETV3wN3GJ7I1Tqewu7YTpPWaOmTteV0= gitlab.com/etke.cc/go/trysmtp v1.1.3/go.mod h1:lOO7tTdAE0a3ETV3wN3GJ7I1Tqewu7YTpPWaOmTteV0=
gitlab.com/etke.cc/go/validator v1.0.6 h1:w0Muxf9Pqw7xvF7NaaswE6d7r9U3nB2t2l5PnFMrecQ= gitlab.com/etke.cc/go/validator v1.0.6 h1:w0Muxf9Pqw7xvF7NaaswE6d7r9U3nB2t2l5PnFMrecQ=
gitlab.com/etke.cc/go/validator v1.0.6/go.mod h1:Id0SxRj0J3IPhiKlj0w1plxVLZfHlkwipn7HfRZsDts= gitlab.com/etke.cc/go/validator v1.0.6/go.mod h1:Id0SxRj0J3IPhiKlj0w1plxVLZfHlkwipn7HfRZsDts=
gitlab.com/etke.cc/linkpearl v0.0.0-20231121221431-72443f33d266 h1:mGbLQkdE35WeyinqP38HC0dqUOJ7FItEAumVIOz7Gg8= gitlab.com/etke.cc/linkpearl v0.0.0-20240211143445-bddf907d137a h1:30WtX+uepGqyFnU7jIockJWxQUeYdljhhk63DCOXLZs=
gitlab.com/etke.cc/linkpearl v0.0.0-20231121221431-72443f33d266/go.mod h1:wFEvngglb6ZTlE58/2a9gwYYs6V3FTYclYn5Pf5EGyQ= gitlab.com/etke.cc/linkpearl v0.0.0-20240211143445-bddf907d137a/go.mod h1:3lqQGDDtk52Jm8PD3mZ3qhmIp4JXuq95waWH5vmEacc=
go.mau.fi/util v0.2.1 h1:eazulhFE/UmjOFtPrGg6zkF5YfAyiDzQb8ihLMbsPWw= go.mau.fi/util v0.3.0 h1:Lt3lbRXP6ZBqTINK0EieRWor3zEwwwrDT14Z5N8RUCs=
go.mau.fi/util v0.2.1/go.mod h1:MjlzCQEMzJ+G8RsPawHzpLB8rwTo3aPIjG5FzBvQT/c= go.mau.fi/util v0.3.0/go.mod h1:9dGsBCCbZJstx16YgnVMVi3O2bOizELoKpugLD4FoGs=
golang.org/x/crypto v0.0.0-20220518034528-6f7dac969898/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4= golang.org/x/crypto v0.0.0-20220518034528-6f7dac969898/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
golang.org/x/crypto v0.15.0 h1:frVn1TEaCEaZcn3Tmd7Y2b5KKPaZ+I32Q2OA3kYp5TA= golang.org/x/crypto v0.19.0 h1:ENy+Az/9Y1vSrlrvBSyna3PITt4tiZLf7sgCjZBX7Wo=
golang.org/x/crypto v0.15.0/go.mod h1:4ChreQoLWfG3xLDer1WdlH5NdlQ3+mwnQq1YTKY+72g= golang.org/x/crypto v0.19.0/go.mod h1:Iy9bg/ha4yyC70EfRS8jz+B6ybOBKMaSxLj6P6oBDfU=
golang.org/x/exp v0.0.0-20231110203233-9a3e6036ecaa h1:FRnLl4eNAQl8hwxVVC17teOw8kdjVDVAiFMtgUdTSRQ= golang.org/x/exp v0.0.0-20240205201215-2c58cdc269a3 h1:/RIbNt/Zr7rVhIkQhooTxCxFcdWLGIKnZA4IXNFSrvo=
golang.org/x/exp v0.0.0-20231110203233-9a3e6036ecaa/go.mod h1:zk2irFbV9DP96SEBUUAy67IdHUaZuSnrz1n472HUCLE= golang.org/x/exp v0.0.0-20240205201215-2c58cdc269a3/go.mod h1:idGWGoKP1toJGkd5/ig9ZLuPcZBC3ewk7SzmH0uou08=
golang.org/x/net v0.0.0-20180911220305-26e67e76b6c3/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4= golang.org/x/net v0.0.0-20180911220305-26e67e76b6c3/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20210501142056-aec3718b3fa0/go.mod h1:OJAsFXCWl8Ukc7SiCT/9KSuxbyM7479/AVlXFRxuMCk= golang.org/x/net v0.0.0-20210501142056-aec3718b3fa0/go.mod h1:OJAsFXCWl8Ukc7SiCT/9KSuxbyM7479/AVlXFRxuMCk=
golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y= golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
golang.org/x/net v0.0.0-20220624214902-1bab6f366d9e/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c= golang.org/x/net v0.0.0-20220624214902-1bab6f366d9e/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c=
golang.org/x/net v0.18.0 h1:mIYleuAkSbHh0tCv7RvjL3F6ZVbLjq4+R7zbOn3Kokg= golang.org/x/net v0.21.0 h1:AQyQV4dYCvJ7vGmJyKki9+PBdyvhkSd8EIx/qb0AYv4=
golang.org/x/net v0.18.0/go.mod h1:/czyP5RqHAH4odGYxBJ1qz0+CE5WZ+2j1YgoEo8F2jQ= golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44=
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs= golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
@@ -135,11 +135,11 @@ golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBc
golang.org/x/sys v0.0.0-20220908164124-27713097b956/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.0.0-20220908164124-27713097b956/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg= golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.14.0 h1:Vz7Qs629MkJkGyHxUlRHizWJRG2j8fbQKjELVSNhy7Q= golang.org/x/sys v0.17.0 h1:25cE3gD+tdBA7lp7QfhuV+rJiE9YXTcS3VG1SqssI/Y=
golang.org/x/sys v0.14.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA= golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo= golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8= golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
golang.org/x/term v0.14.0 h1:LGK9IlZ8T9jvdy6cTdfKUCltatMFOehAQo9SRC46UQ8= golang.org/x/term v0.17.0 h1:mkTF7LCd6WGJNL3K1Ad7kwxNfYAW6a8a8QqtMblp/4U=
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk= golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ= golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ= golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
@@ -152,5 +152,5 @@ gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM= gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
maunium.net/go/maulogger/v2 v2.4.1 h1:N7zSdd0mZkB2m2JtFUsiGTQQAdP0YeFWT7YMc80yAL8= maunium.net/go/maulogger/v2 v2.4.1 h1:N7zSdd0mZkB2m2JtFUsiGTQQAdP0YeFWT7YMc80yAL8=
maunium.net/go/maulogger/v2 v2.4.1/go.mod h1:omPuYwYBILeVQobz8uO3XC8DIRuEb5rXYlQSuqrbCho= maunium.net/go/maulogger/v2 v2.4.1/go.mod h1:omPuYwYBILeVQobz8uO3XC8DIRuEb5rXYlQSuqrbCho=
maunium.net/go/mautrix v0.16.2 h1:a6GUJXNWsTEOO8VE4dROBfCIfPp50mqaqzv7KPzChvg= maunium.net/go/mautrix v0.17.0 h1:scc1qlUbzPn+wc+3eAPquyD+3gZwwy/hBANBm+iGKK8=
maunium.net/go/mautrix v0.16.2/go.mod h1:YL4l4rZB46/vj/ifRMEjcibbvHjgxHftOF1SgmruLu4= maunium.net/go/mautrix v0.17.0/go.mod h1:j+puTEQCEydlVxhJ/dQP5chfa26TdvBO7X6F3Ataav8=

View File

@@ -1,6 +1,7 @@
package smtp package smtp
import ( import (
"context"
"crypto/tls" "crypto/tls"
"net" "net"
"sync" "sync"
@@ -15,10 +16,10 @@ type Listener struct {
tls *tls.Config tls *tls.Config
tlsMu sync.Mutex tlsMu sync.Mutex
listener net.Listener listener net.Listener
isBanned func(net.Addr) bool isBanned func(context.Context, net.Addr) bool
} }
func NewListener(port string, tlsConfig *tls.Config, isBanned func(net.Addr) bool, log *zerolog.Logger) (*Listener, error) { func NewListener(port string, tlsConfig *tls.Config, isBanned func(context.Context, net.Addr) bool, log *zerolog.Logger) (*Listener, error) {
actual, err := net.Listen("tcp", ":"+port) actual, err := net.Listen("tcp", ":"+port)
if err != nil { if err != nil {
return nil, err return nil, err
@@ -52,7 +53,7 @@ func (l *Listener) Accept() (net.Conn, error) {
continue continue
} }
} }
if l.isBanned(conn.RemoteAddr()) { if l.isBanned(context.Background(), conn.RemoteAddr()) {
conn.Close() conn.Close()
l.log.Info().Str("addr", conn.RemoteAddr().String()).Msg("rejected connection (already banned)") l.log.Info().Str("addr", conn.RemoteAddr().String()).Msg("rejected connection (already banned)")
continue continue

View File

@@ -60,16 +60,16 @@ type Manager struct {
} }
type matrixbot interface { type matrixbot interface {
AllowAuth(string, string) (id.RoomID, bool) AllowAuth(context.Context, string, string) (id.RoomID, bool)
IsGreylisted(net.Addr) bool IsGreylisted(context.Context, net.Addr) bool
IsBanned(net.Addr) bool IsBanned(context.Context, net.Addr) bool
IsTrusted(net.Addr) bool IsTrusted(net.Addr) bool
BanAuto(net.Addr) BanAuto(context.Context, net.Addr)
BanAuth(net.Addr) BanAuth(context.Context, net.Addr)
GetMapping(string) (id.RoomID, bool) GetMapping(context.Context, string) (id.RoomID, bool)
GetIFOptions(id.RoomID) email.IncomingFilteringOptions GetIFOptions(context.Context, id.RoomID) email.IncomingFilteringOptions
IncomingEmail(context.Context, *email.Email) error IncomingEmail(context.Context, *email.Email) error
GetDKIMprivkey() string GetDKIMprivkey(context.Context) string
} }
// Caller is Sendmail caller // Caller is Sendmail caller

View File

@@ -46,27 +46,28 @@ type mailServer struct {
// Login used for outgoing mail submissions only (when you use postmoogle as smtp server in your scripts) // Login used for outgoing mail submissions only (when you use postmoogle as smtp server in your scripts)
func (m *mailServer) Login(state *smtp.ConnectionState, username, password string) (smtp.Session, error) { func (m *mailServer) Login(state *smtp.ConnectionState, username, password string) (smtp.Session, error) {
m.log.Debug().Str("username", username).Any("state", state).Msg("Login") m.log.Debug().Str("username", username).Any("state", state).Msg("Login")
if m.bot.IsBanned(state.RemoteAddr) { ctx := context.Background()
if m.bot.IsBanned(ctx, state.RemoteAddr) {
return nil, ErrBanned return nil, ErrBanned
} }
if !email.AddressValid(username) { if !email.AddressValid(username) {
m.log.Debug().Str("address", username).Msg("address is invalid") m.log.Debug().Str("address", username).Msg("address is invalid")
m.bot.BanAuth(state.RemoteAddr) m.bot.BanAuth(ctx, state.RemoteAddr)
return nil, ErrBanned return nil, ErrBanned
} }
roomID, allow := m.bot.AllowAuth(username, password) roomID, allow := m.bot.AllowAuth(ctx, username, password)
if !allow { if !allow {
m.log.Debug().Str("username", username).Msg("username or password is invalid") m.log.Debug().Str("username", username).Msg("username or password is invalid")
m.bot.BanAuth(state.RemoteAddr) m.bot.BanAuth(ctx, state.RemoteAddr)
return nil, ErrBanned return nil, ErrBanned
} }
return &outgoingSession{ return &outgoingSession{
ctx: sentry.SetHubOnContext(context.Background(), sentry.CurrentHub().Clone()), ctx: sentry.SetHubOnContext(context.Background(), sentry.CurrentHub().Clone()),
sendmail: m.sender.Send, sendmail: m.sender.Send,
privkey: m.bot.GetDKIMprivkey(), privkey: m.bot.GetDKIMprivkey(ctx),
from: username, from: username,
log: m.log, log: m.log,
domains: m.domains, domains: m.domains,
@@ -79,7 +80,8 @@ func (m *mailServer) Login(state *smtp.ConnectionState, username, password strin
// AnonymousLogin used for incoming mail submissions only // AnonymousLogin used for incoming mail submissions only
func (m *mailServer) AnonymousLogin(state *smtp.ConnectionState) (smtp.Session, error) { func (m *mailServer) AnonymousLogin(state *smtp.ConnectionState) (smtp.Session, error) {
m.log.Debug().Any("state", state).Msg("AnonymousLogin") m.log.Debug().Any("state", state).Msg("AnonymousLogin")
if m.bot.IsBanned(state.RemoteAddr) { ctx := context.Background()
if m.bot.IsBanned(ctx, state.RemoteAddr) {
return nil, ErrBanned return nil, ErrBanned
} }

View File

@@ -33,12 +33,12 @@ var (
// incomingSession represents an SMTP-submission session receiving emails from remote servers // incomingSession represents an SMTP-submission session receiving emails from remote servers
type incomingSession struct { type incomingSession struct {
log *zerolog.Logger log *zerolog.Logger
getRoomID func(string) (id.RoomID, bool) getRoomID func(context.Context, string) (id.RoomID, bool)
getFilters func(id.RoomID) email.IncomingFilteringOptions getFilters func(context.Context, id.RoomID) email.IncomingFilteringOptions
receiveEmail func(context.Context, *email.Email) error receiveEmail func(context.Context, *email.Email) error
greylisted func(net.Addr) bool greylisted func(context.Context, net.Addr) bool
trusted func(net.Addr) bool trusted func(net.Addr) bool
ban func(net.Addr) ban func(context.Context, net.Addr)
domains []string domains []string
roomID id.RoomID roomID id.RoomID
@@ -52,7 +52,7 @@ func (s *incomingSession) Mail(from string, opts smtp.MailOptions) error {
sentry.GetHubFromContext(s.ctx).Scope().SetTag("from", from) sentry.GetHubFromContext(s.ctx).Scope().SetTag("from", from)
if !email.AddressValid(from) { if !email.AddressValid(from) {
s.log.Debug().Str("from", from).Msg("address is invalid") s.log.Debug().Str("from", from).Msg("address is invalid")
s.ban(s.addr) s.ban(s.ctx, s.addr)
return ErrBanned return ErrBanned
} }
s.from = email.Address(from) s.from = email.Address(from)
@@ -77,7 +77,7 @@ func (s *incomingSession) Rcpt(to string) error {
} }
var ok bool var ok bool
s.roomID, ok = s.getRoomID(utils.Mailbox(to)) s.roomID, ok = s.getRoomID(s.ctx, utils.Mailbox(to))
if !ok { if !ok {
s.log.Debug().Str("to", to).Msg("mapping not found") s.log.Debug().Str("to", to).Msg("mapping not found")
return ErrNoUser return ErrNoUser
@@ -126,12 +126,12 @@ func (s *incomingSession) Data(r io.Reader) error {
} }
addr := s.getAddr(envelope) addr := s.getAddr(envelope)
reader.Seek(0, io.SeekStart) //nolint:errcheck // becase we're sure that's ok reader.Seek(0, io.SeekStart) //nolint:errcheck // becase we're sure that's ok
validations := s.getFilters(s.roomID) validations := s.getFilters(s.ctx, s.roomID)
if !validateIncoming(s.from, s.tos[0], addr, s.log, validations) { if !validateIncoming(s.from, s.tos[0], addr, s.log, validations) {
s.ban(addr) s.ban(s.ctx, addr)
return ErrBanned return ErrBanned
} }
if s.greylisted(addr) { if s.greylisted(s.ctx, addr) {
return &smtp.SMTPError{ return &smtp.SMTPError{
Code: GraylistCode, Code: GraylistCode,
EnhancedCode: GraylistEnhancedCode, EnhancedCode: GraylistEnhancedCode,
@@ -172,7 +172,7 @@ type outgoingSession struct {
sendmail func(string, string, string) error sendmail func(string, string, string) error
privkey string privkey string
domains []string domains []string
getRoomID func(string) (id.RoomID, bool) getRoomID func(context.Context, string) (id.RoomID, bool)
ctx context.Context //nolint:containedctx // that's session ctx context.Context //nolint:containedctx // that's session
tos []string tos []string
@@ -198,7 +198,7 @@ func (s *outgoingSession) Mail(from string, _ smtp.MailOptions) error {
return ErrNoUser return ErrNoUser
} }
roomID, ok := s.getRoomID(utils.Mailbox(from)) roomID, ok := s.getRoomID(s.ctx, utils.Mailbox(from))
if !ok { if !ok {
s.log.Debug().Str("from", from).Msg("mapping not found") s.log.Debug().Str("from", from).Msg("mapping not found")
return ErrNoUser return ErrNoUser

112
utils/psd.go Normal file
View File

@@ -0,0 +1,112 @@
package utils
import (
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"sync"
"time"
"github.com/rs/zerolog"
)
//nolint:gocritic // sync.Mutex is intended
type PSD struct {
sync.Mutex
cachedAt time.Time
cache map[string]bool
log *zerolog.Logger
url *url.URL
login string
password string
}
type PSDTarget struct {
Targets []string `json:"targets"`
Labels map[string]string `json:"labels"`
}
func NewPSD(baseURL, login, password string, log *zerolog.Logger) *PSD {
uri, err := url.Parse(baseURL)
if err != nil || login == "" || password == "" {
return &PSD{}
}
return &PSD{url: uri, login: login, password: password, log: log}
}
func (p *PSD) Contains(email string) (bool, error) {
if p.cachedAt.IsZero() || time.Since(p.cachedAt) > 10*time.Minute {
err := p.updateCache()
if err != nil {
return false, err
}
}
p.Lock()
defer p.Unlock()
return p.cache[email], nil
}
func (p *PSD) Status(email string) string {
ok, err := p.Contains(email)
if !ok || err != nil {
return ""
}
return "👤"
}
func (p *PSD) updateCache() error {
p.Lock()
defer p.Unlock()
defer func() {
p.cachedAt = time.Now()
}()
if p.url == nil {
return nil
}
cloned := *p.url
uri := cloned.JoinPath("/emails")
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
req, err := http.NewRequestWithContext(ctx, http.MethodGet, uri.String(), http.NoBody)
if err != nil {
p.log.Error().Err(err).Msg("failed to create request")
return err
}
req.SetBasicAuth(p.login, p.password)
resp, err := http.DefaultClient.Do(req)
if err != nil {
return err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
err = fmt.Errorf("%s", resp.Status) //nolint:goerr113 // no need to wrap
p.log.Error().Err(err).Msg("failed to fetch PSD")
return err
}
datab, err := io.ReadAll(resp.Body)
if err != nil {
p.log.Error().Err(err).Msg("failed to read response")
return err
}
var psd []*PSDTarget
err = json.Unmarshal(datab, &psd)
if err != nil {
p.log.Error().Err(err).Msg("failed to unmarshal response")
return err
}
p.cache = make(map[string]bool)
for _, t := range psd {
for _, email := range t.Targets {
p.cache[email] = true
}
}
return nil
}

View File

@@ -547,7 +547,7 @@ and facilitates the unification of logging and tracing in some systems:
type TracingHook struct{} type TracingHook struct{}
func (h TracingHook) Run(e *zerolog.Event, level zerolog.Level, msg string) { func (h TracingHook) Run(e *zerolog.Event, level zerolog.Level, msg string) {
ctx := e.Ctx() ctx := e.GetCtx()
spanId := getSpanIdFromContext(ctx) // as per your tracing framework spanId := getSpanIdFromContext(ctx) // as per your tracing framework
e.Str("span-id", spanId) e.Str("span-id", spanId)
} }

View File

@@ -76,6 +76,8 @@ type ConsoleWriter struct {
FormatErrFieldValue Formatter FormatErrFieldValue Formatter
FormatExtra func(map[string]interface{}, *bytes.Buffer) error FormatExtra func(map[string]interface{}, *bytes.Buffer) error
FormatPrepare func(map[string]interface{}) error
} }
// NewConsoleWriter creates and initializes a new ConsoleWriter. // NewConsoleWriter creates and initializes a new ConsoleWriter.
@@ -124,6 +126,13 @@ func (w ConsoleWriter) Write(p []byte) (n int, err error) {
return n, fmt.Errorf("cannot decode event: %s", err) return n, fmt.Errorf("cannot decode event: %s", err)
} }
if w.FormatPrepare != nil {
err = w.FormatPrepare(evt)
if err != nil {
return n, err
}
}
for _, p := range w.PartsOrder { for _, p := range w.PartsOrder {
w.writePart(buf, evt, p) w.writePart(buf, evt, p)
} }
@@ -146,6 +155,15 @@ func (w ConsoleWriter) Write(p []byte) (n int, err error) {
return len(p), err return len(p), err
} }
// Call the underlying writer's Close method if it is an io.Closer. Otherwise
// does nothing.
func (w ConsoleWriter) Close() error {
if closer, ok := w.Out.(io.Closer); ok {
return closer.Close()
}
return nil
}
// writeFields appends formatted key-value pairs to buf. // writeFields appends formatted key-value pairs to buf.
func (w ConsoleWriter) writeFields(evt map[string]interface{}, buf *bytes.Buffer) { func (w ConsoleWriter) writeFields(evt map[string]interface{}, buf *bytes.Buffer) {
var fields = make([]string, 0, len(evt)) var fields = make([]string, 0, len(evt))
@@ -272,7 +290,7 @@ func (w ConsoleWriter) writePart(buf *bytes.Buffer, evt map[string]interface{},
} }
case MessageFieldName: case MessageFieldName:
if w.FormatMessage == nil { if w.FormatMessage == nil {
f = consoleDefaultFormatMessage f = consoleDefaultFormatMessage(w.NoColor, evt[LevelFieldName])
} else { } else {
f = w.FormatMessage f = w.FormatMessage
} }
@@ -310,10 +328,10 @@ func needsQuote(s string) bool {
return false return false
} }
// colorize returns the string s wrapped in ANSI code c, unless disabled is true. // colorize returns the string s wrapped in ANSI code c, unless disabled is true or c is 0.
func colorize(s interface{}, c int, disabled bool) string { func colorize(s interface{}, c int, disabled bool) string {
e := os.Getenv("NO_COLOR") e := os.Getenv("NO_COLOR")
if e != "" { if e != "" || c == 0 {
disabled = true disabled = true
} }
@@ -378,27 +396,16 @@ func consoleDefaultFormatLevel(noColor bool) Formatter {
return func(i interface{}) string { return func(i interface{}) string {
var l string var l string
if ll, ok := i.(string); ok { if ll, ok := i.(string); ok {
switch ll { level, _ := ParseLevel(ll)
case LevelTraceValue: fl, ok := FormattedLevels[level]
l = colorize("TRC", colorMagenta, noColor) if ok {
case LevelDebugValue: l = colorize(fl, LevelColors[level], noColor)
l = colorize("DBG", colorYellow, noColor) } else {
case LevelInfoValue: l = strings.ToUpper(ll)[0:3]
l = colorize("INF", colorGreen, noColor)
case LevelWarnValue:
l = colorize("WRN", colorRed, noColor)
case LevelErrorValue:
l = colorize(colorize("ERR", colorRed, noColor), colorBold, noColor)
case LevelFatalValue:
l = colorize(colorize("FTL", colorRed, noColor), colorBold, noColor)
case LevelPanicValue:
l = colorize(colorize("PNC", colorRed, noColor), colorBold, noColor)
default:
l = colorize(ll, colorBold, noColor)
} }
} else { } else {
if i == nil { if i == nil {
l = colorize("???", colorBold, noColor) l = "???"
} else { } else {
l = strings.ToUpper(fmt.Sprintf("%s", i))[0:3] l = strings.ToUpper(fmt.Sprintf("%s", i))[0:3]
} }
@@ -425,11 +432,18 @@ func consoleDefaultFormatCaller(noColor bool) Formatter {
} }
} }
func consoleDefaultFormatMessage(i interface{}) string { func consoleDefaultFormatMessage(noColor bool, level interface{}) Formatter {
if i == nil { return func(i interface{}) string {
if i == nil || i == "" {
return "" return ""
} }
switch level {
case LevelInfoValue, LevelWarnValue, LevelErrorValue, LevelFatalValue, LevelPanicValue:
return colorize(fmt.Sprintf("%s", i), colorBold, noColor)
default:
return fmt.Sprintf("%s", i) return fmt.Sprintf("%s", i)
}
}
} }
func consoleDefaultFormatFieldName(noColor bool) Formatter { func consoleDefaultFormatFieldName(noColor bool) Formatter {
@@ -450,6 +464,6 @@ func consoleDefaultFormatErrFieldName(noColor bool) Formatter {
func consoleDefaultFormatErrFieldValue(noColor bool) Formatter { func consoleDefaultFormatErrFieldValue(noColor bool) Formatter {
return func(i interface{}) string { return func(i interface{}) string {
return colorize(fmt.Sprintf("%s", i), colorRed, noColor) return colorize(colorize(fmt.Sprintf("%s", i), colorBold, noColor), colorRed, noColor)
} }
} }

View File

@@ -3,7 +3,7 @@ package zerolog
import ( import (
"context" "context"
"fmt" "fmt"
"io/ioutil" "io"
"math" "math"
"net" "net"
"time" "time"
@@ -23,7 +23,7 @@ func (c Context) Logger() Logger {
// Only map[string]interface{} and []interface{} are accepted. []interface{} must // Only map[string]interface{} and []interface{} are accepted. []interface{} must
// alternate string keys and arbitrary values, and extraneous ones are ignored. // alternate string keys and arbitrary values, and extraneous ones are ignored.
func (c Context) Fields(fields interface{}) Context { func (c Context) Fields(fields interface{}) Context {
c.l.context = appendFields(c.l.context, fields) c.l.context = appendFields(c.l.context, fields, c.l.stack)
return c return c
} }
@@ -57,7 +57,7 @@ func (c Context) Array(key string, arr LogArrayMarshaler) Context {
// Object marshals an object that implement the LogObjectMarshaler interface. // Object marshals an object that implement the LogObjectMarshaler interface.
func (c Context) Object(key string, obj LogObjectMarshaler) Context { func (c Context) Object(key string, obj LogObjectMarshaler) Context {
e := newEvent(LevelWriterAdapter{ioutil.Discard}, 0) e := newEvent(LevelWriterAdapter{io.Discard}, 0)
e.Object(key, obj) e.Object(key, obj)
c.l.context = enc.AppendObjectData(c.l.context, e.buf) c.l.context = enc.AppendObjectData(c.l.context, e.buf)
putEvent(e) putEvent(e)
@@ -66,7 +66,7 @@ func (c Context) Object(key string, obj LogObjectMarshaler) Context {
// EmbedObject marshals and Embeds an object that implement the LogObjectMarshaler interface. // EmbedObject marshals and Embeds an object that implement the LogObjectMarshaler interface.
func (c Context) EmbedObject(obj LogObjectMarshaler) Context { func (c Context) EmbedObject(obj LogObjectMarshaler) Context {
e := newEvent(LevelWriterAdapter{ioutil.Discard}, 0) e := newEvent(LevelWriterAdapter{io.Discard}, 0)
e.EmbedObject(obj) e.EmbedObject(obj)
c.l.context = enc.AppendObjectData(c.l.context, e.buf) c.l.context = enc.AppendObjectData(c.l.context, e.buf)
putEvent(e) putEvent(e)
@@ -163,6 +163,22 @@ func (c Context) Errs(key string, errs []error) Context {
// Err adds the field "error" with serialized err to the logger context. // Err adds the field "error" with serialized err to the logger context.
func (c Context) Err(err error) Context { func (c Context) Err(err error) Context {
if c.l.stack && ErrorStackMarshaler != nil {
switch m := ErrorStackMarshaler(err).(type) {
case nil:
case LogObjectMarshaler:
c = c.Object(ErrorStackFieldName, m)
case error:
if m != nil && !isNilValue(m) {
c = c.Str(ErrorStackFieldName, m.Error())
}
case string:
c = c.Str(ErrorStackFieldName, m)
default:
c = c.Interface(ErrorStackFieldName, m)
}
}
return c.AnErr(ErrorFieldName, err) return c.AnErr(ErrorFieldName, err)
} }
@@ -375,10 +391,19 @@ func (c Context) Durs(key string, d []time.Duration) Context {
// Interface adds the field key with obj marshaled using reflection. // Interface adds the field key with obj marshaled using reflection.
func (c Context) Interface(key string, i interface{}) Context { func (c Context) Interface(key string, i interface{}) Context {
if obj, ok := i.(LogObjectMarshaler); ok {
return c.Object(key, obj)
}
c.l.context = enc.AppendInterface(enc.AppendKey(c.l.context, key), i) c.l.context = enc.AppendInterface(enc.AppendKey(c.l.context, key), i)
return c return c
} }
// Type adds the field key with val's type using reflection.
func (c Context) Type(key string, val interface{}) Context {
c.l.context = enc.AppendType(enc.AppendKey(c.l.context, key), val)
return c
}
// Any is a wrapper around Context.Interface. // Any is a wrapper around Context.Interface.
func (c Context) Any(key string, i interface{}) Context { func (c Context) Any(key string, i interface{}) Context {
return c.Interface(key, i) return c.Interface(key, i)

View File

@@ -164,7 +164,7 @@ func (e *Event) Fields(fields interface{}) *Event {
if e == nil { if e == nil {
return e return e
} }
e.buf = appendFields(e.buf, fields) e.buf = appendFields(e.buf, fields, e.stack)
return e return e
} }

7
vendor/github.com/rs/zerolog/example.jsonl generated vendored Normal file
View File

@@ -0,0 +1,7 @@
{"time":"5:41PM","level":"info","message":"Starting listener","listen":":8080","pid":37556}
{"time":"5:41PM","level":"debug","message":"Access","database":"myapp","host":"localhost:4962","pid":37556}
{"time":"5:41PM","level":"info","message":"Access","method":"GET","path":"/users","pid":37556,"resp_time":23}
{"time":"5:41PM","level":"info","message":"Access","method":"POST","path":"/posts","pid":37556,"resp_time":532}
{"time":"5:41PM","level":"warn","message":"Slow request","method":"POST","path":"/posts","pid":37556,"resp_time":532}
{"time":"5:41PM","level":"info","message":"Access","method":"GET","path":"/users","pid":37556,"resp_time":10}
{"time":"5:41PM","level":"error","message":"Database connection lost","database":"myapp","pid":37556,"error":"connection reset by peer"}

View File

@@ -12,13 +12,13 @@ func isNilValue(i interface{}) bool {
return (*[2]uintptr)(unsafe.Pointer(&i))[1] == 0 return (*[2]uintptr)(unsafe.Pointer(&i))[1] == 0
} }
func appendFields(dst []byte, fields interface{}) []byte { func appendFields(dst []byte, fields interface{}, stack bool) []byte {
switch fields := fields.(type) { switch fields := fields.(type) {
case []interface{}: case []interface{}:
if n := len(fields); n&0x1 == 1 { // odd number if n := len(fields); n&0x1 == 1 { // odd number
fields = fields[:n-1] fields = fields[:n-1]
} }
dst = appendFieldList(dst, fields) dst = appendFieldList(dst, fields, stack)
case map[string]interface{}: case map[string]interface{}:
keys := make([]string, 0, len(fields)) keys := make([]string, 0, len(fields))
for key := range fields { for key := range fields {
@@ -28,13 +28,13 @@ func appendFields(dst []byte, fields interface{}) []byte {
kv := make([]interface{}, 2) kv := make([]interface{}, 2)
for _, key := range keys { for _, key := range keys {
kv[0], kv[1] = key, fields[key] kv[0], kv[1] = key, fields[key]
dst = appendFieldList(dst, kv) dst = appendFieldList(dst, kv, stack)
} }
} }
return dst return dst
} }
func appendFieldList(dst []byte, kvList []interface{}) []byte { func appendFieldList(dst []byte, kvList []interface{}, stack bool) []byte {
for i, n := 0, len(kvList); i < n; i += 2 { for i, n := 0, len(kvList); i < n; i += 2 {
key, val := kvList[i], kvList[i+1] key, val := kvList[i], kvList[i+1]
if key, ok := key.(string); ok { if key, ok := key.(string); ok {
@@ -74,6 +74,21 @@ func appendFieldList(dst []byte, kvList []interface{}) []byte {
default: default:
dst = enc.AppendInterface(dst, m) dst = enc.AppendInterface(dst, m)
} }
if stack && ErrorStackMarshaler != nil {
dst = enc.AppendKey(dst, ErrorStackFieldName)
switch m := ErrorStackMarshaler(val).(type) {
case nil:
case error:
if m != nil && !isNilValue(m) {
dst = enc.AppendString(dst, m.Error())
}
case string:
dst = enc.AppendString(dst, m)
default:
dst = enc.AppendInterface(dst, m)
}
}
case []error: case []error:
dst = enc.AppendArrayStart(dst) dst = enc.AppendArrayStart(dst)
for i, err := range val { for i, err := range val {

View File

@@ -108,6 +108,34 @@ var (
// DefaultContextLogger is returned from Ctx() if there is no logger associated // DefaultContextLogger is returned from Ctx() if there is no logger associated
// with the context. // with the context.
DefaultContextLogger *Logger DefaultContextLogger *Logger
// LevelColors are used by ConsoleWriter's consoleDefaultFormatLevel to color
// log levels.
LevelColors = map[Level]int{
TraceLevel: colorBlue,
DebugLevel: 0,
InfoLevel: colorGreen,
WarnLevel: colorYellow,
ErrorLevel: colorRed,
FatalLevel: colorRed,
PanicLevel: colorRed,
}
// FormattedLevels are used by ConsoleWriter's consoleDefaultFormatLevel
// for a short level name.
FormattedLevels = map[Level]string{
TraceLevel: "TRC",
DebugLevel: "DBG",
InfoLevel: "INF",
WarnLevel: "WRN",
ErrorLevel: "ERR",
FatalLevel: "FTL",
PanicLevel: "PNC",
}
// TriggerLevelWriterBufferReuseLimit is a limit in bytes that a buffer is dropped
// from the TriggerLevelWriter buffer pool if the buffer grows above the limit.
TriggerLevelWriterBufferReuseLimit = 64 * 1024
) )
var ( var (

32
vendor/github.com/rs/zerolog/log.go generated vendored
View File

@@ -118,7 +118,6 @@ import (
"errors" "errors"
"fmt" "fmt"
"io" "io"
"io/ioutil"
"os" "os"
"strconv" "strconv"
"strings" "strings"
@@ -246,7 +245,7 @@ type Logger struct {
// you may consider using sync wrapper. // you may consider using sync wrapper.
func New(w io.Writer) Logger { func New(w io.Writer) Logger {
if w == nil { if w == nil {
w = ioutil.Discard w = io.Discard
} }
lw, ok := w.(LevelWriter) lw, ok := w.(LevelWriter)
if !ok { if !ok {
@@ -326,10 +325,13 @@ func (l Logger) Sample(s Sampler) Logger {
} }
// Hook returns a logger with the h Hook. // Hook returns a logger with the h Hook.
func (l Logger) Hook(h Hook) Logger { func (l Logger) Hook(hooks ...Hook) Logger {
newHooks := make([]Hook, len(l.hooks), len(l.hooks)+1) if len(hooks) == 0 {
return l
}
newHooks := make([]Hook, len(l.hooks), len(l.hooks)+len(hooks))
copy(newHooks, l.hooks) copy(newHooks, l.hooks)
l.hooks = append(newHooks, h) l.hooks = append(newHooks, hooks...)
return l return l
} }
@@ -385,7 +387,14 @@ func (l *Logger) Err(err error) *Event {
// //
// You must call Msg on the returned event in order to send the event. // You must call Msg on the returned event in order to send the event.
func (l *Logger) Fatal() *Event { func (l *Logger) Fatal() *Event {
return l.newEvent(FatalLevel, func(msg string) { os.Exit(1) }) return l.newEvent(FatalLevel, func(msg string) {
if closer, ok := l.w.(io.Closer); ok {
// Close the writer to flush any buffered message. Otherwise the message
// will be lost as os.Exit() terminates the program immediately.
closer.Close()
}
os.Exit(1)
})
} }
// Panic starts a new message with panic level. The panic() function // Panic starts a new message with panic level. The panic() function
@@ -450,6 +459,14 @@ func (l *Logger) Printf(format string, v ...interface{}) {
} }
} }
// Println sends a log event using debug level and no extra field.
// Arguments are handled in the manner of fmt.Println.
func (l *Logger) Println(v ...interface{}) {
if e := l.Debug(); e.Enabled() {
e.CallerSkipFrame(1).Msg(fmt.Sprintln(v...))
}
}
// Write implements the io.Writer interface. This is useful to set as a writer // Write implements the io.Writer interface. This is useful to set as a writer
// for the standard library log. // for the standard library log.
func (l Logger) Write(p []byte) (n int, err error) { func (l Logger) Write(p []byte) (n int, err error) {
@@ -488,6 +505,9 @@ func (l *Logger) newEvent(level Level, done func(string)) *Event {
// should returns true if the log event should be logged. // should returns true if the log event should be logged.
func (l *Logger) should(lvl Level) bool { func (l *Logger) should(lvl Level) bool {
if l.w == nil {
return false
}
if lvl < l.level || lvl < GlobalLevel() { if lvl < l.level || lvl < GlobalLevel() {
return false return false
} }

Binary file not shown.

Before

Width:  |  Height:  |  Size: 82 KiB

After

Width:  |  Height:  |  Size: 116 KiB

View File

@@ -78,3 +78,12 @@ func (sw syslogWriter) WriteLevel(level Level, p []byte) (n int, err error) {
n = len(p) n = len(p)
return return
} }
// Call the underlying writer's Close method if it is an io.Closer. Otherwise
// does nothing.
func (sw syslogWriter) Close() error {
if c, ok := sw.w.(io.Closer); ok {
return c.Close()
}
return nil
}

View File

@@ -27,6 +27,15 @@ func (lw LevelWriterAdapter) WriteLevel(l Level, p []byte) (n int, err error) {
return lw.Write(p) return lw.Write(p)
} }
// Call the underlying writer's Close method if it is an io.Closer. Otherwise
// does nothing.
func (lw LevelWriterAdapter) Close() error {
if closer, ok := lw.Writer.(io.Closer); ok {
return closer.Close()
}
return nil
}
type syncWriter struct { type syncWriter struct {
mu sync.Mutex mu sync.Mutex
lw LevelWriter lw LevelWriter
@@ -57,6 +66,15 @@ func (s *syncWriter) WriteLevel(l Level, p []byte) (n int, err error) {
return s.lw.WriteLevel(l, p) return s.lw.WriteLevel(l, p)
} }
func (s *syncWriter) Close() error {
s.mu.Lock()
defer s.mu.Unlock()
if closer, ok := s.lw.(io.Closer); ok {
return closer.Close()
}
return nil
}
type multiLevelWriter struct { type multiLevelWriter struct {
writers []LevelWriter writers []LevelWriter
} }
@@ -89,6 +107,20 @@ func (t multiLevelWriter) WriteLevel(l Level, p []byte) (n int, err error) {
return n, err return n, err
} }
// Calls close on all the underlying writers that are io.Closers. If any of the
// Close methods return an error, the remainder of the closers are not closed
// and the error is returned.
func (t multiLevelWriter) Close() error {
for _, w := range t.writers {
if closer, ok := w.(io.Closer); ok {
if err := closer.Close(); err != nil {
return err
}
}
}
return nil
}
// MultiLevelWriter creates a writer that duplicates its writes to all the // MultiLevelWriter creates a writer that duplicates its writes to all the
// provided writers, similar to the Unix tee(1) command. If some writers // provided writers, similar to the Unix tee(1) command. If some writers
// implement LevelWriter, their WriteLevel method will be used instead of Write. // implement LevelWriter, their WriteLevel method will be used instead of Write.
@@ -180,3 +212,135 @@ func (w *FilteredLevelWriter) WriteLevel(level Level, p []byte) (int, error) {
} }
return len(p), nil return len(p), nil
} }
var triggerWriterPool = &sync.Pool{
New: func() interface{} {
return bytes.NewBuffer(make([]byte, 0, 1024))
},
}
// TriggerLevelWriter buffers log lines at the ConditionalLevel or below
// until a trigger level (or higher) line is emitted. Log lines with level
// higher than ConditionalLevel are always written out to the destination
// writer. If trigger never happens, buffered log lines are never written out.
//
// It can be used to configure "log level per request".
type TriggerLevelWriter struct {
// Destination writer. If LevelWriter is provided (usually), its WriteLevel is used
// instead of Write.
io.Writer
// ConditionalLevel is the level (and below) at which lines are buffered until
// a trigger level (or higher) line is emitted. Usually this is set to DebugLevel.
ConditionalLevel Level
// TriggerLevel is the lowest level that triggers the sending of the conditional
// level lines. Usually this is set to ErrorLevel.
TriggerLevel Level
buf *bytes.Buffer
triggered bool
mu sync.Mutex
}
func (w *TriggerLevelWriter) WriteLevel(l Level, p []byte) (n int, err error) {
w.mu.Lock()
defer w.mu.Unlock()
// At first trigger level or above log line, we flush the buffer and change the
// trigger state to triggered.
if !w.triggered && l >= w.TriggerLevel {
err := w.trigger()
if err != nil {
return 0, err
}
}
// Unless triggered, we buffer everything at and below ConditionalLevel.
if !w.triggered && l <= w.ConditionalLevel {
if w.buf == nil {
w.buf = triggerWriterPool.Get().(*bytes.Buffer)
}
// We prefix each log line with a byte with the level.
// Hopefully we will never have a level value which equals a newline
// (which could interfere with reconstruction of log lines in the trigger method).
w.buf.WriteByte(byte(l))
w.buf.Write(p)
return len(p), nil
}
// Anything above ConditionalLevel is always passed through.
// Once triggered, everything is passed through.
if lw, ok := w.Writer.(LevelWriter); ok {
return lw.WriteLevel(l, p)
}
return w.Write(p)
}
// trigger expects lock to be held.
func (w *TriggerLevelWriter) trigger() error {
if w.triggered {
return nil
}
w.triggered = true
if w.buf == nil {
return nil
}
p := w.buf.Bytes()
for len(p) > 0 {
// We do not use bufio.Scanner here because we already have full buffer
// in the memory and we do not want extra copying from the buffer to
// scanner's token slice, nor we want to hit scanner's token size limit,
// and we also want to preserve newlines.
i := bytes.IndexByte(p, '\n')
line := p[0 : i+1]
p = p[i+1:]
// We prefixed each log line with a byte with the level.
level := Level(line[0])
line = line[1:]
var err error
if lw, ok := w.Writer.(LevelWriter); ok {
_, err = lw.WriteLevel(level, line)
} else {
_, err = w.Write(line)
}
if err != nil {
return err
}
}
return nil
}
// Trigger forces flushing the buffer and change the trigger state to
// triggered, if the writer has not already been triggered before.
func (w *TriggerLevelWriter) Trigger() error {
w.mu.Lock()
defer w.mu.Unlock()
return w.trigger()
}
// Close closes the writer and returns the buffer to the pool.
func (w *TriggerLevelWriter) Close() error {
w.mu.Lock()
defer w.mu.Unlock()
if w.buf == nil {
return nil
}
// We return the buffer only if it has not grown above the limit.
// This prevents accumulation of large buffers in the pool just
// because occasionally a large buffer might be needed.
if w.buf.Cap() <= TriggerLevelWriterBufferReuseLimit {
w.buf.Reset()
triggerWriterPool.Put(w.buf)
}
w.buf = nil
return nil
}

View File

@@ -8,7 +8,7 @@ goldmark
> A Markdown parser written in Go. Easy to extend, standards-compliant, well-structured. > A Markdown parser written in Go. Easy to extend, standards-compliant, well-structured.
goldmark is compliant with CommonMark 0.30. goldmark is compliant with CommonMark 0.31.2.
Motivation Motivation
---------------------- ----------------------
@@ -260,7 +260,7 @@ You can override autolinking patterns via options.
| Functional option | Type | Description | | Functional option | Type | Description |
| ----------------- | ---- | ----------- | | ----------------- | ---- | ----------- |
| `extension.WithLinkifyAllowedProtocols` | `[][]byte` | List of allowed protocols such as `[][]byte{ []byte("http:") }` | | `extension.WithLinkifyAllowedProtocols` | `[][]byte \| []string` | List of allowed protocols such as `[]string{ "http:" }` |
| `extension.WithLinkifyURLRegexp` | `*regexp.Regexp` | Regexp that defines URLs, including protocols | | `extension.WithLinkifyURLRegexp` | `*regexp.Regexp` | Regexp that defines URLs, including protocols |
| `extension.WithLinkifyWWWRegexp` | `*regexp.Regexp` | Regexp that defines URL starting with `www.`. This pattern corresponds to [the extended www autolink](https://github.github.com/gfm/#extended-www-autolink) | | `extension.WithLinkifyWWWRegexp` | `*regexp.Regexp` | Regexp that defines URL starting with `www.`. This pattern corresponds to [the extended www autolink](https://github.github.com/gfm/#extended-www-autolink) |
| `extension.WithLinkifyEmailRegexp` | `*regexp.Regexp` | Regexp that defines email addresses` | | `extension.WithLinkifyEmailRegexp` | `*regexp.Regexp` | Regexp that defines email addresses` |
@@ -277,9 +277,9 @@ markdown := goldmark.New(
), ),
goldmark.WithExtensions( goldmark.WithExtensions(
extension.NewLinkify( extension.NewLinkify(
extension.WithLinkifyAllowedProtocols([][]byte{ extension.WithLinkifyAllowedProtocols([]string{
[]byte("http:"), "http:",
[]byte("https:"), "https:",
}), }),
extension.WithLinkifyURLRegexp( extension.WithLinkifyURLRegexp(
xurls.Strict, xurls.Strict,
@@ -297,13 +297,13 @@ This extension has some options:
| Functional option | Type | Description | | Functional option | Type | Description |
| ----------------- | ---- | ----------- | | ----------------- | ---- | ----------- |
| `extension.WithFootnoteIDPrefix` | `[]byte` | a prefix for the id attributes.| | `extension.WithFootnoteIDPrefix` | `[]byte \| string` | a prefix for the id attributes.|
| `extension.WithFootnoteIDPrefixFunction` | `func(gast.Node) []byte` | a function that determines the id attribute for given Node.| | `extension.WithFootnoteIDPrefixFunction` | `func(gast.Node) []byte` | a function that determines the id attribute for given Node.|
| `extension.WithFootnoteLinkTitle` | `[]byte` | an optional title attribute for footnote links.| | `extension.WithFootnoteLinkTitle` | `[]byte \| string` | an optional title attribute for footnote links.|
| `extension.WithFootnoteBacklinkTitle` | `[]byte` | an optional title attribute for footnote backlinks. | | `extension.WithFootnoteBacklinkTitle` | `[]byte \| string` | an optional title attribute for footnote backlinks. |
| `extension.WithFootnoteLinkClass` | `[]byte` | a class for footnote links. This defaults to `footnote-ref`. | | `extension.WithFootnoteLinkClass` | `[]byte \| string` | a class for footnote links. This defaults to `footnote-ref`. |
| `extension.WithFootnoteBacklinkClass` | `[]byte` | a class for footnote backlinks. This defaults to `footnote-backref`. | | `extension.WithFootnoteBacklinkClass` | `[]byte \| string` | a class for footnote backlinks. This defaults to `footnote-backref`. |
| `extension.WithFootnoteBacklinkHTML` | `[]byte` | a class for footnote backlinks. This defaults to `&#x21a9;&#xfe0e;`. | | `extension.WithFootnoteBacklinkHTML` | `[]byte \| string` | a class for footnote backlinks. This defaults to `&#x21a9;&#xfe0e;`. |
Some options can have special substitutions. Occurrences of “^^” in the string will be replaced by the corresponding footnote number in the HTML output. Occurrences of “%%” will be replaced by a number for the reference (footnotes can have multiple references). Some options can have special substitutions. Occurrences of “^^” in the string will be replaced by the corresponding footnote number in the HTML output. Occurrences of “%%” will be replaced by a number for the reference (footnotes can have multiple references).
@@ -319,7 +319,7 @@ for _, path := range files {
markdown := goldmark.New( markdown := goldmark.New(
goldmark.WithExtensions( goldmark.WithExtensions(
NewFootnote( NewFootnote(
WithFootnoteIDPrefix([]byte(path)), WithFootnoteIDPrefix(path),
), ),
), ),
) )
@@ -379,7 +379,7 @@ This extension provides additional options for CJK users.
| Functional option | Type | Description | | Functional option | Type | Description |
| ----------------- | ---- | ----------- | | ----------------- | ---- | ----------- |
| `extension.WithEastAsianLineBreaks` | `...extension.EastAsianLineBreaksStyle` | Soft line breaks are rendered as a newline. Some asian users will see it as an unnecessary space. With this option, soft line breaks between east asian wide characters will be ignored. | | `extension.WithEastAsianLineBreaks` | `...extension.EastAsianLineBreaksStyle` | Soft line breaks are rendered as a newline. Some asian users will see it as an unnecessary space. With this option, soft line breaks between east asian wide characters will be ignored. This defaults to `EastAsianLineBreaksStyleSimple`. |
| `extension.WithEscapedSpace` | `-` | Without spaces around an emphasis started with east asian punctuations, it is not interpreted as an emphasis(as defined in CommonMark spec). With this option, you can avoid this inconvenient behavior by putting 'not rendered' spaces around an emphasis like `太郎は\ **「こんにちわ」**\ といった`. | | `extension.WithEscapedSpace` | `-` | Without spaces around an emphasis started with east asian punctuations, it is not interpreted as an emphasis(as defined in CommonMark spec). With this option, you can avoid this inconvenient behavior by putting 'not rendered' spaces around an emphasis like `太郎は\ **「こんにちわ」**\ といった`. |
#### Styles of Line Breaking #### Styles of Line Breaking
@@ -467,6 +467,7 @@ As you can see, goldmark's performance is on par with cmark's.
Extensions Extensions
-------------------- --------------------
### List of extensions
- [goldmark-meta](https://github.com/yuin/goldmark-meta): A YAML metadata - [goldmark-meta](https://github.com/yuin/goldmark-meta): A YAML metadata
extension for the goldmark Markdown parser. extension for the goldmark Markdown parser.
@@ -490,6 +491,13 @@ Extensions
- [goldmark-d2](https://github.com/FurqanSoftware/goldmark-d2): Adds support for [D2](https://d2lang.com/) diagrams. - [goldmark-d2](https://github.com/FurqanSoftware/goldmark-d2): Adds support for [D2](https://d2lang.com/) diagrams.
- [goldmark-katex](https://github.com/FurqanSoftware/goldmark-katex): Adds support for [KaTeX](https://katex.org/) math and equations. - [goldmark-katex](https://github.com/FurqanSoftware/goldmark-katex): Adds support for [KaTeX](https://katex.org/) math and equations.
- [goldmark-img64](https://github.com/tenkoh/goldmark-img64): Adds support for embedding images into the document as DataURL (base64 encoded). - [goldmark-img64](https://github.com/tenkoh/goldmark-img64): Adds support for embedding images into the document as DataURL (base64 encoded).
- [goldmark-enclave](https://github.com/quail-ink/goldmark-enclave): Adds support for embedding youtube/bilibili video, X's [oembed tweet](https://publish.twitter.com/), [tradingview](https://www.tradingview.com/widget/)'s chart, [quail](https://quail.ink)'s widget into the document.
- [goldmark-wiki-table](https://github.com/movsb/goldmark-wiki-table): Adds support for embedding Wiki Tables.
### Loading extensions at runtime
[goldmark-dynamic](https://github.com/yuin/goldmark-dynamic) allows you to write a goldmark extension in Lua and load it at runtime without re-compilation.
Please refer to [goldmark-dynamic](https://github.com/yuin/goldmark-dynamic) for details.
goldmark internal(for extension developers) goldmark internal(for extension developers)

View File

@@ -382,8 +382,8 @@ func (o *withFootnoteIDPrefix) SetFootnoteOption(c *FootnoteConfig) {
} }
// WithFootnoteIDPrefix is a functional option that is a prefix for the id attributes generated by footnotes. // WithFootnoteIDPrefix is a functional option that is a prefix for the id attributes generated by footnotes.
func WithFootnoteIDPrefix(a []byte) FootnoteOption { func WithFootnoteIDPrefix[T []byte | string](a T) FootnoteOption {
return &withFootnoteIDPrefix{a} return &withFootnoteIDPrefix{[]byte(a)}
} }
const optFootnoteIDPrefixFunction renderer.OptionName = "FootnoteIDPrefixFunction" const optFootnoteIDPrefixFunction renderer.OptionName = "FootnoteIDPrefixFunction"
@@ -420,8 +420,8 @@ func (o *withFootnoteLinkTitle) SetFootnoteOption(c *FootnoteConfig) {
} }
// WithFootnoteLinkTitle is a functional option that is an optional title attribute for footnote links. // WithFootnoteLinkTitle is a functional option that is an optional title attribute for footnote links.
func WithFootnoteLinkTitle(a []byte) FootnoteOption { func WithFootnoteLinkTitle[T []byte | string](a T) FootnoteOption {
return &withFootnoteLinkTitle{a} return &withFootnoteLinkTitle{[]byte(a)}
} }
const optFootnoteBacklinkTitle renderer.OptionName = "FootnoteBacklinkTitle" const optFootnoteBacklinkTitle renderer.OptionName = "FootnoteBacklinkTitle"
@@ -439,8 +439,8 @@ func (o *withFootnoteBacklinkTitle) SetFootnoteOption(c *FootnoteConfig) {
} }
// WithFootnoteBacklinkTitle is a functional option that is an optional title attribute for footnote backlinks. // WithFootnoteBacklinkTitle is a functional option that is an optional title attribute for footnote backlinks.
func WithFootnoteBacklinkTitle(a []byte) FootnoteOption { func WithFootnoteBacklinkTitle[T []byte | string](a T) FootnoteOption {
return &withFootnoteBacklinkTitle{a} return &withFootnoteBacklinkTitle{[]byte(a)}
} }
const optFootnoteLinkClass renderer.OptionName = "FootnoteLinkClass" const optFootnoteLinkClass renderer.OptionName = "FootnoteLinkClass"
@@ -458,8 +458,8 @@ func (o *withFootnoteLinkClass) SetFootnoteOption(c *FootnoteConfig) {
} }
// WithFootnoteLinkClass is a functional option that is a class for footnote links. // WithFootnoteLinkClass is a functional option that is a class for footnote links.
func WithFootnoteLinkClass(a []byte) FootnoteOption { func WithFootnoteLinkClass[T []byte | string](a T) FootnoteOption {
return &withFootnoteLinkClass{a} return &withFootnoteLinkClass{[]byte(a)}
} }
const optFootnoteBacklinkClass renderer.OptionName = "FootnoteBacklinkClass" const optFootnoteBacklinkClass renderer.OptionName = "FootnoteBacklinkClass"
@@ -477,8 +477,8 @@ func (o *withFootnoteBacklinkClass) SetFootnoteOption(c *FootnoteConfig) {
} }
// WithFootnoteBacklinkClass is a functional option that is a class for footnote backlinks. // WithFootnoteBacklinkClass is a functional option that is a class for footnote backlinks.
func WithFootnoteBacklinkClass(a []byte) FootnoteOption { func WithFootnoteBacklinkClass[T []byte | string](a T) FootnoteOption {
return &withFootnoteBacklinkClass{a} return &withFootnoteBacklinkClass{[]byte(a)}
} }
const optFootnoteBacklinkHTML renderer.OptionName = "FootnoteBacklinkHTML" const optFootnoteBacklinkHTML renderer.OptionName = "FootnoteBacklinkHTML"
@@ -496,8 +496,8 @@ func (o *withFootnoteBacklinkHTML) SetFootnoteOption(c *FootnoteConfig) {
} }
// WithFootnoteBacklinkHTML is an HTML content for footnote backlinks. // WithFootnoteBacklinkHTML is an HTML content for footnote backlinks.
func WithFootnoteBacklinkHTML(a []byte) FootnoteOption { func WithFootnoteBacklinkHTML[T []byte | string](a T) FootnoteOption {
return &withFootnoteBacklinkHTML{a} return &withFootnoteBacklinkHTML{[]byte(a)}
} }
// FootnoteHTMLRenderer is a renderer.NodeRenderer implementation that // FootnoteHTMLRenderer is a renderer.NodeRenderer implementation that

View File

@@ -66,10 +66,12 @@ func (o *withLinkifyAllowedProtocols) SetLinkifyOption(p *LinkifyConfig) {
// WithLinkifyAllowedProtocols is a functional option that specify allowed // WithLinkifyAllowedProtocols is a functional option that specify allowed
// protocols in autolinks. Each protocol must end with ':' like // protocols in autolinks. Each protocol must end with ':' like
// 'http:' . // 'http:' .
func WithLinkifyAllowedProtocols(value [][]byte) LinkifyOption { func WithLinkifyAllowedProtocols[T []byte | string](value []T) LinkifyOption {
return &withLinkifyAllowedProtocols{ opt := &withLinkifyAllowedProtocols{}
value: value, for _, v := range value {
opt.value = append(opt.value, []byte(v))
} }
return opt
} }
type withLinkifyURLRegexp struct { type withLinkifyURLRegexp struct {

View File

@@ -115,10 +115,10 @@ func (o *withTypographicSubstitutions) SetTypographerOption(p *TypographerConfig
// WithTypographicSubstitutions is a functional otpion that specify replacement text // WithTypographicSubstitutions is a functional otpion that specify replacement text
// for punctuations. // for punctuations.
func WithTypographicSubstitutions(values map[TypographicPunctuation][]byte) TypographerOption { func WithTypographicSubstitutions[T []byte | string](values map[TypographicPunctuation]T) TypographerOption {
replacements := newDefaultSubstitutions() replacements := newDefaultSubstitutions()
for k, v := range values { for k, v := range values {
replacements[k] = v replacements[k] = []byte(v)
} }
return &withTypographicSubstitutions{replacements} return &withTypographicSubstitutions{replacements}

View File

@@ -61,8 +61,8 @@ var allowedBlockTags = map[string]bool{
"option": true, "option": true,
"p": true, "p": true,
"param": true, "param": true,
"search": true,
"section": true, "section": true,
"source": true,
"summary": true, "summary": true,
"table": true, "table": true,
"tbody": true, "tbody": true,

View File

@@ -58,47 +58,38 @@ var closeProcessingInstruction = []byte("?>")
var openCDATA = []byte("<![CDATA[") var openCDATA = []byte("<![CDATA[")
var closeCDATA = []byte("]]>") var closeCDATA = []byte("]]>")
var closeDecl = []byte(">") var closeDecl = []byte(">")
var emptyComment = []byte("<!---->") var emptyComment1 = []byte("<!-->")
var invalidComment1 = []byte("<!-->") var emptyComment2 = []byte("<!--->")
var invalidComment2 = []byte("<!--->")
var openComment = []byte("<!--") var openComment = []byte("<!--")
var closeComment = []byte("-->") var closeComment = []byte("-->")
var doubleHyphen = []byte("--")
func (s *rawHTMLParser) parseComment(block text.Reader, pc Context) ast.Node { func (s *rawHTMLParser) parseComment(block text.Reader, pc Context) ast.Node {
savedLine, savedSegment := block.Position() savedLine, savedSegment := block.Position()
node := ast.NewRawHTML() node := ast.NewRawHTML()
line, segment := block.PeekLine() line, segment := block.PeekLine()
if bytes.HasPrefix(line, emptyComment) { if bytes.HasPrefix(line, emptyComment1) {
node.Segments.Append(segment.WithStop(segment.Start + len(emptyComment))) node.Segments.Append(segment.WithStop(segment.Start + len(emptyComment1)))
block.Advance(len(emptyComment)) block.Advance(len(emptyComment1))
return node return node
} }
if bytes.HasPrefix(line, invalidComment1) || bytes.HasPrefix(line, invalidComment2) { if bytes.HasPrefix(line, emptyComment2) {
return nil node.Segments.Append(segment.WithStop(segment.Start + len(emptyComment2)))
block.Advance(len(emptyComment2))
return node
} }
offset := len(openComment) offset := len(openComment)
line = line[offset:] line = line[offset:]
for { for {
hindex := bytes.Index(line, doubleHyphen) index := bytes.Index(line, closeComment)
if hindex > -1 { if index > -1 {
hindex += offset node.Segments.Append(segment.WithStop(segment.Start + offset + index + len(closeComment)))
} block.Advance(offset + index + len(closeComment))
index := bytes.Index(line, closeComment) + offset
if index > -1 && hindex == index {
if index == 0 || len(line) < 2 || line[index-offset-1] != '-' {
node.Segments.Append(segment.WithStop(segment.Start + index + len(closeComment)))
block.Advance(index + len(closeComment))
return node return node
} }
} offset = 0
if hindex > 0 {
break
}
node.Segments.Append(segment) node.Segments.Append(segment)
block.AdvanceLine() block.AdvanceLine()
line, segment = block.PeekLine() line, segment = block.PeekLine()
offset = 0
if line == nil { if line == nil {
break break
} }

View File

@@ -808,7 +808,7 @@ func IsPunct(c byte) bool {
// IsPunctRune returns true if the given rune is a punctuation, otherwise false. // IsPunctRune returns true if the given rune is a punctuation, otherwise false.
func IsPunctRune(r rune) bool { func IsPunctRune(r rune) bool {
return int32(r) <= 256 && IsPunct(byte(r)) || unicode.IsPunct(r) return unicode.IsSymbol(r) || unicode.IsPunct(r)
} }
// IsSpace returns true if the given character is a space, otherwise false. // IsSpace returns true if the given character is a space, otherwise false.

View File

@@ -1,661 +1,166 @@
GNU AFFERO GENERAL PUBLIC LICENSE GNU LESSER GENERAL PUBLIC LICENSE
Version 3, 19 November 2007 Version 3, 29 June 2007
Copyright (C) 2007 Free Software Foundation, Inc. <https://fsf.org/> Copyright (C) 2007 Free Software Foundation, Inc. <https://fsf.org/>
Everyone is permitted to copy and distribute verbatim copies Everyone is permitted to copy and distribute verbatim copies
of this license document, but changing it is not allowed. of this license document, but changing it is not allowed.
Preamble
This version of the GNU Lesser General Public License incorporates
The GNU Affero General Public License is a free, copyleft license for the terms and conditions of version 3 of the GNU General Public
software and other kinds of works, specifically designed to ensure License, supplemented by the additional permissions listed below.
cooperation with the community in the case of network server software.
0. Additional Definitions.
The licenses for most software and other practical works are designed
to take away your freedom to share and change the works. By contrast, As used herein, "this License" refers to version 3 of the GNU Lesser
our General Public Licenses are intended to guarantee your freedom to General Public License, and the "GNU GPL" refers to version 3 of the GNU
share and change all versions of a program--to make sure it remains free General Public License.
software for all its users.
"The Library" refers to a covered work governed by this License,
When we speak of free software, we are referring to freedom, not other than an Application or a Combined Work as defined below.
price. Our General Public Licenses are designed to make sure that you
have the freedom to distribute copies of free software (and charge for An "Application" is any work that makes use of an interface provided
them if you wish), that you receive source code or can get it if you by the Library, but which is not otherwise based on the Library.
want it, that you can change the software or use pieces of it in new Defining a subclass of a class defined by the Library is deemed a mode
free programs, and that you know you can do these things. of using an interface provided by the Library.
Developers that use our General Public Licenses protect your rights A "Combined Work" is a work produced by combining or linking an
with two steps: (1) assert copyright on the software, and (2) offer Application with the Library. The particular version of the Library
you this License which gives you legal permission to copy, distribute with which the Combined Work was made is also called the "Linked
and/or modify the software. Version".
A secondary benefit of defending all users' freedom is that The "Minimal Corresponding Source" for a Combined Work means the
improvements made in alternate versions of the program, if they Corresponding Source for the Combined Work, excluding any source code
receive widespread use, become available for other developers to for portions of the Combined Work that, considered in isolation, are
incorporate. Many developers of free software are heartened and based on the Application, and not on the Linked Version.
encouraged by the resulting cooperation. However, in the case of
software used on network servers, this result may fail to come about. The "Corresponding Application Code" for a Combined Work means the
The GNU General Public License permits making a modified version and object code and/or source code for the Application, including any data
letting the public access it on a server without ever releasing its and utility programs needed for reproducing the Combined Work from the
source code to the public. Application, but excluding the System Libraries of the Combined Work.
The GNU Affero General Public License is designed specifically to 1. Exception to Section 3 of the GNU GPL.
ensure that, in such cases, the modified source code becomes available
to the community. It requires the operator of a network server to You may convey a covered work under sections 3 and 4 of this License
provide the source code of the modified version running there to the without being bound by section 3 of the GNU GPL.
users of that server. Therefore, public use of a modified version, on
a publicly accessible server, gives the public access to the source 2. Conveying Modified Versions.
code of the modified version.
If you modify a copy of the Library, and, in your modifications, a
An older license, called the Affero General Public License and facility refers to a function or data to be supplied by an Application
published by Affero, was designed to accomplish similar goals. This is that uses the facility (other than as an argument passed when the
a different license, not a version of the Affero GPL, but Affero has facility is invoked), then you may convey a copy of the modified
released a new version of the Affero GPL which permits relicensing under version:
this license.
a) under this License, provided that you make a good faith effort to
The precise terms and conditions for copying, distribution and ensure that, in the event an Application does not supply the
modification follow. function or data, the facility still operates, and performs
whatever part of its purpose remains meaningful, or
TERMS AND CONDITIONS
b) under the GNU GPL, with none of the additional permissions of
0. Definitions. this License applicable to that copy.
"This License" refers to version 3 of the GNU Affero General Public License. 3. Object Code Incorporating Material from Library Header Files.
"Copyright" also means copyright-like laws that apply to other kinds of The object code form of an Application may incorporate material from
works, such as semiconductor masks. a header file that is part of the Library. You may convey such object
code under terms of your choice, provided that, if the incorporated
"The Program" refers to any copyrightable work licensed under this material is not limited to numerical parameters, data structure
License. Each licensee is addressed as "you". "Licensees" and layouts and accessors, or small macros, inline functions and templates
"recipients" may be individuals or organizations. (ten or fewer lines in length), you do both of the following:
To "modify" a work means to copy from or adapt all or part of the work a) Give prominent notice with each copy of the object code that the
in a fashion requiring copyright permission, other than the making of an Library is used in it and that the Library and its use are
exact copy. The resulting work is called a "modified version" of the covered by this License.
earlier work or a work "based on" the earlier work.
b) Accompany the object code with a copy of the GNU GPL and this license
A "covered work" means either the unmodified Program or a work based document.
on the Program.
4. Combined Works.
To "propagate" a work means to do anything with it that, without
permission, would make you directly or secondarily liable for You may convey a Combined Work under terms of your choice that,
infringement under applicable copyright law, except executing it on a taken together, effectively do not restrict modification of the
computer or modifying a private copy. Propagation includes copying, portions of the Library contained in the Combined Work and reverse
distribution (with or without modification), making available to the engineering for debugging such modifications, if you also do each of
public, and in some countries other activities as well. the following:
To "convey" a work means any kind of propagation that enables other a) Give prominent notice with each copy of the Combined Work that
parties to make or receive copies. Mere interaction with a user through the Library is used in it and that the Library and its use are
a computer network, with no transfer of a copy, is not conveying. covered by this License.
An interactive user interface displays "Appropriate Legal Notices" b) Accompany the Combined Work with a copy of the GNU GPL and this license
to the extent that it includes a convenient and prominently visible document.
feature that (1) displays an appropriate copyright notice, and (2)
tells the user that there is no warranty for the work (except to the c) For a Combined Work that displays copyright notices during
extent that warranties are provided), that licensees may convey the execution, include the copyright notice for the Library among
work under this License, and how to view a copy of this License. If these notices, as well as a reference directing the user to the
the interface presents a list of user commands or options, such as a copies of the GNU GPL and this license document.
menu, a prominent item in the list meets this criterion.
d) Do one of the following:
1. Source Code.
0) Convey the Minimal Corresponding Source under the terms of this
The "source code" for a work means the preferred form of the work License, and the Corresponding Application Code in a form
for making modifications to it. "Object code" means any non-source suitable for, and under terms that permit, the user to
form of a work. recombine or relink the Application with a modified version of
the Linked Version to produce a modified Combined Work, in the
A "Standard Interface" means an interface that either is an official manner specified by section 6 of the GNU GPL for conveying
standard defined by a recognized standards body, or, in the case of Corresponding Source.
interfaces specified for a particular programming language, one that
is widely used among developers working in that language. 1) Use a suitable shared library mechanism for linking with the
Library. A suitable mechanism is one that (a) uses at run time
The "System Libraries" of an executable work include anything, other a copy of the Library already present on the user's computer
than the work as a whole, that (a) is included in the normal form of system, and (b) will operate properly with a modified version
packaging a Major Component, but which is not part of that Major of the Library that is interface-compatible with the Linked
Component, and (b) serves only to enable use of the work with that Version.
Major Component, or to implement a Standard Interface for which an
implementation is available to the public in source code form. A e) Provide Installation Information, but only if you would otherwise
"Major Component", in this context, means a major essential component be required to provide such information under section 6 of the
(kernel, window system, and so on) of the specific operating system GNU GPL, and only to the extent that such information is
(if any) on which the executable work runs, or a compiler used to necessary to install and execute a modified version of the
produce the work, or an object code interpreter used to run it. Combined Work produced by recombining or relinking the
Application with a modified version of the Linked Version. (If
The "Corresponding Source" for a work in object code form means all you use option 4d0, the Installation Information must accompany
the source code needed to generate, install, and (for an executable the Minimal Corresponding Source and Corresponding Application
work) run the object code and to modify the work, including scripts to Code. If you use option 4d1, you must provide the Installation
control those activities. However, it does not include the work's Information in the manner specified by section 6 of the GNU GPL
System Libraries, or general-purpose tools or generally available free for conveying Corresponding Source.)
programs which are used unmodified in performing those activities but
which are not part of the work. For example, Corresponding Source 5. Combined Libraries.
includes interface definition files associated with source files for
the work, and the source code for shared libraries and dynamically You may place library facilities that are a work based on the
linked subprograms that the work is specifically designed to require, Library side by side in a single library together with other library
such as by intimate data communication or control flow between those facilities that are not Applications and are not covered by this
subprograms and other parts of the work. License, and convey such a combined library under terms of your
choice, if you do both of the following:
The Corresponding Source need not include anything that users
can regenerate automatically from other parts of the Corresponding a) Accompany the combined library with a copy of the same work based
Source. on the Library, uncombined with any other library facilities,
conveyed under the terms of this License.
The Corresponding Source for a work in source code form is that
same work. b) Give prominent notice with the combined library that part of it
is a work based on the Library, and explaining where to find the
2. Basic Permissions. accompanying uncombined form of the same work.
All rights granted under this License are granted for the term of 6. Revised Versions of the GNU Lesser General Public License.
copyright on the Program, and are irrevocable provided the stated
conditions are met. This License explicitly affirms your unlimited The Free Software Foundation may publish revised and/or new versions
permission to run the unmodified Program. The output from running a of the GNU Lesser General Public License from time to time. Such new
covered work is covered by this License only if the output, given its versions will be similar in spirit to the present version, but may
content, constitutes a covered work. This License acknowledges your differ in detail to address new problems or concerns.
rights of fair use or other equivalent, as provided by copyright law.
You may make, run and propagate covered works that you do not
convey, without conditions so long as your license otherwise remains
in force. You may convey covered works to others for the sole purpose
of having them make modifications exclusively for you, or provide you
with facilities for running those works, provided that you comply with
the terms of this License in conveying all material for which you do
not control copyright. Those thus making or running the covered works
for you must do so exclusively on your behalf, under your direction
and control, on terms that prohibit them from making any copies of
your copyrighted material outside their relationship with you.
Conveying under any other circumstances is permitted solely under
the conditions stated below. Sublicensing is not allowed; section 10
makes it unnecessary.
3. Protecting Users' Legal Rights From Anti-Circumvention Law.
No covered work shall be deemed part of an effective technological
measure under any applicable law fulfilling obligations under article
11 of the WIPO copyright treaty adopted on 20 December 1996, or
similar laws prohibiting or restricting circumvention of such
measures.
When you convey a covered work, you waive any legal power to forbid
circumvention of technological measures to the extent such circumvention
is effected by exercising rights under this License with respect to
the covered work, and you disclaim any intention to limit operation or
modification of the work as a means of enforcing, against the work's
users, your or third parties' legal rights to forbid circumvention of
technological measures.
4. Conveying Verbatim Copies.
You may convey verbatim copies of the Program's source code as you
receive it, in any medium, provided that you conspicuously and
appropriately publish on each copy an appropriate copyright notice;
keep intact all notices stating that this License and any
non-permissive terms added in accord with section 7 apply to the code;
keep intact all notices of the absence of any warranty; and give all
recipients a copy of this License along with the Program.
You may charge any price or no price for each copy that you convey,
and you may offer support or warranty protection for a fee.
5. Conveying Modified Source Versions.
You may convey a work based on the Program, or the modifications to
produce it from the Program, in the form of source code under the
terms of section 4, provided that you also meet all of these conditions:
a) The work must carry prominent notices stating that you modified
it, and giving a relevant date.
b) The work must carry prominent notices stating that it is
released under this License and any conditions added under section
7. This requirement modifies the requirement in section 4 to
"keep intact all notices".
c) You must license the entire work, as a whole, under this
License to anyone who comes into possession of a copy. This
License will therefore apply, along with any applicable section 7
additional terms, to the whole of the work, and all its parts,
regardless of how they are packaged. This License gives no
permission to license the work in any other way, but it does not
invalidate such permission if you have separately received it.
d) If the work has interactive user interfaces, each must display
Appropriate Legal Notices; however, if the Program has interactive
interfaces that do not display Appropriate Legal Notices, your
work need not make them do so.
A compilation of a covered work with other separate and independent
works, which are not by their nature extensions of the covered work,
and which are not combined with it such as to form a larger program,
in or on a volume of a storage or distribution medium, is called an
"aggregate" if the compilation and its resulting copyright are not
used to limit the access or legal rights of the compilation's users
beyond what the individual works permit. Inclusion of a covered work
in an aggregate does not cause this License to apply to the other
parts of the aggregate.
6. Conveying Non-Source Forms.
You may convey a covered work in object code form under the terms
of sections 4 and 5, provided that you also convey the
machine-readable Corresponding Source under the terms of this License,
in one of these ways:
a) Convey the object code in, or embodied in, a physical product
(including a physical distribution medium), accompanied by the
Corresponding Source fixed on a durable physical medium
customarily used for software interchange.
b) Convey the object code in, or embodied in, a physical product
(including a physical distribution medium), accompanied by a
written offer, valid for at least three years and valid for as
long as you offer spare parts or customer support for that product
model, to give anyone who possesses the object code either (1) a
copy of the Corresponding Source for all the software in the
product that is covered by this License, on a durable physical
medium customarily used for software interchange, for a price no
more than your reasonable cost of physically performing this
conveying of source, or (2) access to copy the
Corresponding Source from a network server at no charge.
c) Convey individual copies of the object code with a copy of the
written offer to provide the Corresponding Source. This
alternative is allowed only occasionally and noncommercially, and
only if you received the object code with such an offer, in accord
with subsection 6b.
d) Convey the object code by offering access from a designated
place (gratis or for a charge), and offer equivalent access to the
Corresponding Source in the same way through the same place at no
further charge. You need not require recipients to copy the
Corresponding Source along with the object code. If the place to
copy the object code is a network server, the Corresponding Source
may be on a different server (operated by you or a third party)
that supports equivalent copying facilities, provided you maintain
clear directions next to the object code saying where to find the
Corresponding Source. Regardless of what server hosts the
Corresponding Source, you remain obligated to ensure that it is
available for as long as needed to satisfy these requirements.
e) Convey the object code using peer-to-peer transmission, provided
you inform other peers where the object code and Corresponding
Source of the work are being offered to the general public at no
charge under subsection 6d.
A separable portion of the object code, whose source code is excluded
from the Corresponding Source as a System Library, need not be
included in conveying the object code work.
A "User Product" is either (1) a "consumer product", which means any
tangible personal property which is normally used for personal, family,
or household purposes, or (2) anything designed or sold for incorporation
into a dwelling. In determining whether a product is a consumer product,
doubtful cases shall be resolved in favor of coverage. For a particular
product received by a particular user, "normally used" refers to a
typical or common use of that class of product, regardless of the status
of the particular user or of the way in which the particular user
actually uses, or expects or is expected to use, the product. A product
is a consumer product regardless of whether the product has substantial
commercial, industrial or non-consumer uses, unless such uses represent
the only significant mode of use of the product.
"Installation Information" for a User Product means any methods,
procedures, authorization keys, or other information required to install
and execute modified versions of a covered work in that User Product from
a modified version of its Corresponding Source. The information must
suffice to ensure that the continued functioning of the modified object
code is in no case prevented or interfered with solely because
modification has been made.
If you convey an object code work under this section in, or with, or
specifically for use in, a User Product, and the conveying occurs as
part of a transaction in which the right of possession and use of the
User Product is transferred to the recipient in perpetuity or for a
fixed term (regardless of how the transaction is characterized), the
Corresponding Source conveyed under this section must be accompanied
by the Installation Information. But this requirement does not apply
if neither you nor any third party retains the ability to install
modified object code on the User Product (for example, the work has
been installed in ROM).
The requirement to provide Installation Information does not include a
requirement to continue to provide support service, warranty, or updates
for a work that has been modified or installed by the recipient, or for
the User Product in which it has been modified or installed. Access to a
network may be denied when the modification itself materially and
adversely affects the operation of the network or violates the rules and
protocols for communication across the network.
Corresponding Source conveyed, and Installation Information provided,
in accord with this section must be in a format that is publicly
documented (and with an implementation available to the public in
source code form), and must require no special password or key for
unpacking, reading or copying.
7. Additional Terms.
"Additional permissions" are terms that supplement the terms of this
License by making exceptions from one or more of its conditions.
Additional permissions that are applicable to the entire Program shall
be treated as though they were included in this License, to the extent
that they are valid under applicable law. If additional permissions
apply only to part of the Program, that part may be used separately
under those permissions, but the entire Program remains governed by
this License without regard to the additional permissions.
When you convey a copy of a covered work, you may at your option
remove any additional permissions from that copy, or from any part of
it. (Additional permissions may be written to require their own
removal in certain cases when you modify the work.) You may place
additional permissions on material, added by you to a covered work,
for which you have or can give appropriate copyright permission.
Notwithstanding any other provision of this License, for material you
add to a covered work, you may (if authorized by the copyright holders of
that material) supplement the terms of this License with terms:
a) Disclaiming warranty or limiting liability differently from the
terms of sections 15 and 16 of this License; or
b) Requiring preservation of specified reasonable legal notices or
author attributions in that material or in the Appropriate Legal
Notices displayed by works containing it; or
c) Prohibiting misrepresentation of the origin of that material, or
requiring that modified versions of such material be marked in
reasonable ways as different from the original version; or
d) Limiting the use for publicity purposes of names of licensors or
authors of the material; or
e) Declining to grant rights under trademark law for use of some
trade names, trademarks, or service marks; or
f) Requiring indemnification of licensors and authors of that
material by anyone who conveys the material (or modified versions of
it) with contractual assumptions of liability to the recipient, for
any liability that these contractual assumptions directly impose on
those licensors and authors.
All other non-permissive additional terms are considered "further
restrictions" within the meaning of section 10. If the Program as you
received it, or any part of it, contains a notice stating that it is
governed by this License along with a term that is a further
restriction, you may remove that term. If a license document contains
a further restriction but permits relicensing or conveying under this
License, you may add to a covered work material governed by the terms
of that license document, provided that the further restriction does
not survive such relicensing or conveying.
If you add terms to a covered work in accord with this section, you
must place, in the relevant source files, a statement of the
additional terms that apply to those files, or a notice indicating
where to find the applicable terms.
Additional terms, permissive or non-permissive, may be stated in the
form of a separately written license, or stated as exceptions;
the above requirements apply either way.
8. Termination.
You may not propagate or modify a covered work except as expressly
provided under this License. Any attempt otherwise to propagate or
modify it is void, and will automatically terminate your rights under
this License (including any patent licenses granted under the third
paragraph of section 11).
However, if you cease all violation of this License, then your
license from a particular copyright holder is reinstated (a)
provisionally, unless and until the copyright holder explicitly and
finally terminates your license, and (b) permanently, if the copyright
holder fails to notify you of the violation by some reasonable means
prior to 60 days after the cessation.
Moreover, your license from a particular copyright holder is
reinstated permanently if the copyright holder notifies you of the
violation by some reasonable means, this is the first time you have
received notice of violation of this License (for any work) from that
copyright holder, and you cure the violation prior to 30 days after
your receipt of the notice.
Termination of your rights under this section does not terminate the
licenses of parties who have received copies or rights from you under
this License. If your rights have been terminated and not permanently
reinstated, you do not qualify to receive new licenses for the same
material under section 10.
9. Acceptance Not Required for Having Copies.
You are not required to accept this License in order to receive or
run a copy of the Program. Ancillary propagation of a covered work
occurring solely as a consequence of using peer-to-peer transmission
to receive a copy likewise does not require acceptance. However,
nothing other than this License grants you permission to propagate or
modify any covered work. These actions infringe copyright if you do
not accept this License. Therefore, by modifying or propagating a
covered work, you indicate your acceptance of this License to do so.
10. Automatic Licensing of Downstream Recipients.
Each time you convey a covered work, the recipient automatically
receives a license from the original licensors, to run, modify and
propagate that work, subject to this License. You are not responsible
for enforcing compliance by third parties with this License.
An "entity transaction" is a transaction transferring control of an
organization, or substantially all assets of one, or subdividing an
organization, or merging organizations. If propagation of a covered
work results from an entity transaction, each party to that
transaction who receives a copy of the work also receives whatever
licenses to the work the party's predecessor in interest had or could
give under the previous paragraph, plus a right to possession of the
Corresponding Source of the work from the predecessor in interest, if
the predecessor has it or can get it with reasonable efforts.
You may not impose any further restrictions on the exercise of the
rights granted or affirmed under this License. For example, you may
not impose a license fee, royalty, or other charge for exercise of
rights granted under this License, and you may not initiate litigation
(including a cross-claim or counterclaim in a lawsuit) alleging that
any patent claim is infringed by making, using, selling, offering for
sale, or importing the Program or any portion of it.
11. Patents.
A "contributor" is a copyright holder who authorizes use under this
License of the Program or a work on which the Program is based. The
work thus licensed is called the contributor's "contributor version".
A contributor's "essential patent claims" are all patent claims
owned or controlled by the contributor, whether already acquired or
hereafter acquired, that would be infringed by some manner, permitted
by this License, of making, using, or selling its contributor version,
but do not include claims that would be infringed only as a
consequence of further modification of the contributor version. For
purposes of this definition, "control" includes the right to grant
patent sublicenses in a manner consistent with the requirements of
this License.
Each contributor grants you a non-exclusive, worldwide, royalty-free
patent license under the contributor's essential patent claims, to
make, use, sell, offer for sale, import and otherwise run, modify and
propagate the contents of its contributor version.
In the following three paragraphs, a "patent license" is any express
agreement or commitment, however denominated, not to enforce a patent
(such as an express permission to practice a patent or covenant not to
sue for patent infringement). To "grant" such a patent license to a
party means to make such an agreement or commitment not to enforce a
patent against the party.
If you convey a covered work, knowingly relying on a patent license,
and the Corresponding Source of the work is not available for anyone
to copy, free of charge and under the terms of this License, through a
publicly available network server or other readily accessible means,
then you must either (1) cause the Corresponding Source to be so
available, or (2) arrange to deprive yourself of the benefit of the
patent license for this particular work, or (3) arrange, in a manner
consistent with the requirements of this License, to extend the patent
license to downstream recipients. "Knowingly relying" means you have
actual knowledge that, but for the patent license, your conveying the
covered work in a country, or your recipient's use of the covered work
in a country, would infringe one or more identifiable patents in that
country that you have reason to believe are valid.
If, pursuant to or in connection with a single transaction or
arrangement, you convey, or propagate by procuring conveyance of, a
covered work, and grant a patent license to some of the parties
receiving the covered work authorizing them to use, propagate, modify
or convey a specific copy of the covered work, then the patent license
you grant is automatically extended to all recipients of the covered
work and works based on it.
A patent license is "discriminatory" if it does not include within
the scope of its coverage, prohibits the exercise of, or is
conditioned on the non-exercise of one or more of the rights that are
specifically granted under this License. You may not convey a covered
work if you are a party to an arrangement with a third party that is
in the business of distributing software, under which you make payment
to the third party based on the extent of your activity of conveying
the work, and under which the third party grants, to any of the
parties who would receive the covered work from you, a discriminatory
patent license (a) in connection with copies of the covered work
conveyed by you (or copies made from those copies), or (b) primarily
for and in connection with specific products or compilations that
contain the covered work, unless you entered into that arrangement,
or that patent license was granted, prior to 28 March 2007.
Nothing in this License shall be construed as excluding or limiting
any implied license or other defenses to infringement that may
otherwise be available to you under applicable patent law.
12. No Surrender of Others' Freedom.
If conditions are imposed on you (whether by court order, agreement or
otherwise) that contradict the conditions of this License, they do not
excuse you from the conditions of this License. If you cannot convey a
covered work so as to satisfy simultaneously your obligations under this
License and any other pertinent obligations, then as a consequence you may
not convey it at all. For example, if you agree to terms that obligate you
to collect a royalty for further conveying from those to whom you convey
the Program, the only way you could satisfy both those terms and this
License would be to refrain entirely from conveying the Program.
13. Remote Network Interaction; Use with the GNU General Public License.
Notwithstanding any other provision of this License, if you modify the
Program, your modified version must prominently offer all users
interacting with it remotely through a computer network (if your version
supports such interaction) an opportunity to receive the Corresponding
Source of your version by providing access to the Corresponding Source
from a network server at no charge, through some standard or customary
means of facilitating copying of software. This Corresponding Source
shall include the Corresponding Source for any work covered by version 3
of the GNU General Public License that is incorporated pursuant to the
following paragraph.
Notwithstanding any other provision of this License, you have
permission to link or combine any covered work with a work licensed
under version 3 of the GNU General Public License into a single
combined work, and to convey the resulting work. The terms of this
License will continue to apply to the part which is the covered work,
but the work with which it is combined will remain governed by version
3 of the GNU General Public License.
14. Revised Versions of this License.
The Free Software Foundation may publish revised and/or new versions of
the GNU Affero General Public License from time to time. Such new versions
will be similar in spirit to the present version, but may differ in detail to
address new problems or concerns.
Each version is given a distinguishing version number. If the Each version is given a distinguishing version number. If the
Program specifies that a certain numbered version of the GNU Affero General Library as you received it specifies that a certain numbered version
Public License "or any later version" applies to it, you have the of the GNU Lesser General Public License "or any later version"
option of following the terms and conditions either of that numbered applies to it, you have the option of following the terms and
version or of any later version published by the Free Software conditions either of that published version or of any later version
Foundation. If the Program does not specify a version number of the published by the Free Software Foundation. If the Library as you
GNU Affero General Public License, you may choose any version ever published received it does not specify a version number of the GNU Lesser
by the Free Software Foundation. General Public License, you may choose any version of the GNU Lesser
General Public License ever published by the Free Software Foundation.
If the Program specifies that a proxy can decide which future If the Library as you received it specifies that a proxy can decide
versions of the GNU Affero General Public License can be used, that proxy's whether future versions of the GNU Lesser General Public License shall
public statement of acceptance of a version permanently authorizes you apply, that proxy's public statement of acceptance of any version is
to choose that version for the Program. permanent authorization for you to choose that version for the
Library.
Later license versions may give you additional or different
permissions. However, no additional obligations are imposed on any
author or copyright holder as a result of your choosing to follow a
later version.
15. Disclaimer of Warranty.
THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY
APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT
HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY
OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO,
THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM
IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF
ALL NECESSARY SERVICING, REPAIR OR CORRECTION.
16. Limitation of Liability.
IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING
WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS
THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY
GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE
USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF
DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD
PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS),
EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF
SUCH DAMAGES.
17. Interpretation of Sections 15 and 16.
If the disclaimer of warranty and limitation of liability provided
above cannot be given local legal effect according to their terms,
reviewing courts shall apply local law that most closely approximates
an absolute waiver of all civil liability in connection with the
Program, unless a warranty or assumption of liability accompanies a
copy of the Program in return for a fee.
END OF TERMS AND CONDITIONS
How to Apply These Terms to Your New Programs
If you develop a new program, and you want it to be of the greatest
possible use to the public, the best way to achieve this is to make it
free software which everyone can redistribute and change under these terms.
To do so, attach the following notices to the program. It is safest
to attach them to the start of each source file to most effectively
state the exclusion of warranty; and each file should have at least
the "copyright" line and a pointer to where the full notice is found.
env
Copyright (C) 2022 etke.cc / Go
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published
by the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <https://www.gnu.org/licenses/>.
Also add information on how to contact you by electronic and paper mail.
If your software can interact with users remotely through a computer
network, you should also make sure that it provides a way for users to
get its source. For example, if your program is a web application, its
interface could display a "Source" link that leads users to an archive
of the code. There are many ways you could offer source, and different
solutions will be better for different programs; see section 13 for the
specific requirements.
You should also get your employer (if you work as a programmer) or school,
if any, to sign a "copyright disclaimer" for the program, if necessary.
For more information on this, and how to apply and follow the GNU AGPL, see
<https://www.gnu.org/licenses/>.

View File

@@ -14,26 +14,36 @@ func SetPrefix(prefix string) {
} }
// String returns string vars // String returns string vars
func String(shortkey string, defaultValue string) string { func String(shortkey string, defaultValue ...string) string {
var dv string
if len(defaultValue) > 0 {
dv = defaultValue[0]
}
key := strings.ToUpper(envprefix + "_" + strings.ReplaceAll(shortkey, ".", "_")) key := strings.ToUpper(envprefix + "_" + strings.ReplaceAll(shortkey, ".", "_"))
value := strings.TrimSpace(os.Getenv(key)) value := strings.TrimSpace(os.Getenv(key))
if value == "" { if value == "" {
return defaultValue return dv
} }
return value return value
} }
// Int returns int vars // Int returns int vars
func Int(shortkey string, defaultValue int) int { func Int(shortkey string, defaultValue ...int) int {
str := String(shortkey, "") var dv int
if len(defaultValue) > 0 {
dv = defaultValue[0]
}
str := String(shortkey)
if str == "" { if str == "" {
return defaultValue return dv
} }
val, err := strconv.Atoi(str) val, err := strconv.Atoi(str)
if err != nil { if err != nil {
return defaultValue return dv
} }
return val return val
@@ -41,7 +51,7 @@ func Int(shortkey string, defaultValue int) int {
// Bool returns boolean vars (1, true, yes) // Bool returns boolean vars (1, true, yes)
func Bool(shortkey string) bool { func Bool(shortkey string) bool {
str := strings.ToLower(String(shortkey, "")) str := strings.ToLower(String(shortkey))
if str == "" { if str == "" {
return false return false
} }
@@ -50,7 +60,7 @@ func Bool(shortkey string) bool {
// Slice returns slice from space-separated strings, eg: export VAR="one two three" => []string{"one", "two", "three"} // Slice returns slice from space-separated strings, eg: export VAR="one two three" => []string{"one", "two", "three"}
func Slice(shortkey string) []string { func Slice(shortkey string) []string {
str := String(shortkey, "") str := String(shortkey)
if str == "" { if str == "" {
return nil return nil
} }

View File

@@ -1,13 +1,14 @@
package linkpearl package linkpearl
import ( import (
"context"
"strings" "strings"
"maunium.net/go/mautrix/id" "maunium.net/go/mautrix/id"
) )
// GetAccountData of the user (from cache and API, with encryption support) // GetAccountData of the user (from cache and API, with encryption support)
func (l *Linkpearl) GetAccountData(name string) (map[string]string, error) { func (l *Linkpearl) GetAccountData(ctx context.Context, name string) (map[string]string, error) {
cached, ok := l.acc.Get(name) cached, ok := l.acc.Get(name)
if ok { if ok {
if cached == nil { if cached == nil {
@@ -17,7 +18,7 @@ func (l *Linkpearl) GetAccountData(name string) (map[string]string, error) {
} }
var data map[string]string var data map[string]string
err := l.GetClient().GetAccountData(name, &data) err := l.GetClient().GetAccountData(ctx, name, &data)
if err != nil { if err != nil {
data = map[string]string{} data = map[string]string{}
if strings.Contains(err.Error(), "M_NOT_FOUND") { if strings.Contains(err.Error(), "M_NOT_FOUND") {
@@ -33,15 +34,15 @@ func (l *Linkpearl) GetAccountData(name string) (map[string]string, error) {
} }
// SetAccountData of the user (to cache and API, with encryption support) // SetAccountData of the user (to cache and API, with encryption support)
func (l *Linkpearl) SetAccountData(name string, data map[string]string) error { func (l *Linkpearl) SetAccountData(ctx context.Context, name string, data map[string]string) error {
l.acc.Add(name, data) l.acc.Add(name, data)
data = l.encryptAccountData(data) data = l.encryptAccountData(data)
return UnwrapError(l.GetClient().SetAccountData(name, data)) return UnwrapError(l.GetClient().SetAccountData(ctx, name, data))
} }
// GetRoomAccountData of the room (from cache and API, with encryption support) // GetRoomAccountData of the room (from cache and API, with encryption support)
func (l *Linkpearl) GetRoomAccountData(roomID id.RoomID, name string) (map[string]string, error) { func (l *Linkpearl) GetRoomAccountData(ctx context.Context, roomID id.RoomID, name string) (map[string]string, error) {
key := roomID.String() + name key := roomID.String() + name
cached, ok := l.acc.Get(key) cached, ok := l.acc.Get(key)
if ok { if ok {
@@ -52,7 +53,7 @@ func (l *Linkpearl) GetRoomAccountData(roomID id.RoomID, name string) (map[strin
} }
var data map[string]string var data map[string]string
err := l.GetClient().GetRoomAccountData(roomID, name, &data) err := l.GetClient().GetRoomAccountData(ctx, roomID, name, &data)
if err != nil { if err != nil {
data = map[string]string{} data = map[string]string{}
if strings.Contains(err.Error(), "M_NOT_FOUND") { if strings.Contains(err.Error(), "M_NOT_FOUND") {
@@ -68,12 +69,12 @@ func (l *Linkpearl) GetRoomAccountData(roomID id.RoomID, name string) (map[strin
} }
// SetRoomAccountData of the room (to cache and API, with encryption support) // SetRoomAccountData of the room (to cache and API, with encryption support)
func (l *Linkpearl) SetRoomAccountData(roomID id.RoomID, name string, data map[string]string) error { func (l *Linkpearl) SetRoomAccountData(ctx context.Context, roomID id.RoomID, name string, data map[string]string) error {
key := roomID.String() + name key := roomID.String() + name
l.acc.Add(key, data) l.acc.Add(key, data)
data = l.encryptAccountData(data) data = l.encryptAccountData(data)
return UnwrapError(l.GetClient().SetRoomAccountData(roomID, name, data)) return UnwrapError(l.GetClient().SetRoomAccountData(ctx, roomID, name, data))
} }
func (l *Linkpearl) encryptAccountData(data map[string]string) map[string]string { func (l *Linkpearl) encryptAccountData(data map[string]string) map[string]string {

View File

@@ -1,6 +1,7 @@
package linkpearl package linkpearl
import ( import (
"context"
"crypto/hmac" "crypto/hmac"
"crypto/sha512" "crypto/sha512"
"database/sql" "database/sql"
@@ -25,7 +26,7 @@ type Config struct {
// JoinPermit is a callback function that tells // JoinPermit is a callback function that tells
// if linkpearl should respond to the given "invite" event // if linkpearl should respond to the given "invite" event
// and join the room // and join the room
JoinPermit func(*event.Event) bool JoinPermit func(context.Context, *event.Event) bool
// AutoLeave if true, linkpearl will automatically leave empty rooms // AutoLeave if true, linkpearl will automatically leave empty rooms
AutoLeave bool AutoLeave bool

View File

@@ -1,6 +1,7 @@
package linkpearl package linkpearl
import ( import (
"context"
"strconv" "strconv"
"maunium.net/go/mautrix" "maunium.net/go/mautrix"
@@ -15,7 +16,7 @@ type RespThreads struct {
} }
// Threads endpoint, ref: https://spec.matrix.org/v1.8/client-server-api/#get_matrixclientv1roomsroomidthreads // Threads endpoint, ref: https://spec.matrix.org/v1.8/client-server-api/#get_matrixclientv1roomsroomidthreads
func (l *Linkpearl) Threads(roomID id.RoomID, fromToken ...string) (*RespThreads, error) { func (l *Linkpearl) Threads(ctx context.Context, roomID id.RoomID, fromToken ...string) (*RespThreads, error) {
var from string var from string
if len(fromToken) > 0 { if len(fromToken) > 0 {
from = fromToken[0] from = fromToken[0]
@@ -28,18 +29,18 @@ func (l *Linkpearl) Threads(roomID id.RoomID, fromToken ...string) (*RespThreads
var resp *RespThreads var resp *RespThreads
urlPath := l.GetClient().BuildURLWithQuery(mautrix.ClientURLPath{"v1", "rooms", roomID, "threads"}, query) urlPath := l.GetClient().BuildURLWithQuery(mautrix.ClientURLPath{"v1", "rooms", roomID, "threads"}, query)
_, err := l.GetClient().MakeRequest("GET", urlPath, nil, &resp) _, err := l.GetClient().MakeRequest(ctx, "GET", urlPath, nil, &resp)
return resp, UnwrapError(err) return resp, UnwrapError(err)
} }
// FindThreadBy tries to find thread message event by field and value // FindThreadBy tries to find thread message event by field and value
func (l *Linkpearl) FindThreadBy(roomID id.RoomID, field, value string, fromToken ...string) *event.Event { func (l *Linkpearl) FindThreadBy(ctx context.Context, roomID id.RoomID, field, value string, fromToken ...string) *event.Event {
var from string var from string
if len(fromToken) > 0 { if len(fromToken) > 0 {
from = fromToken[0] from = fromToken[0]
} }
resp, err := l.Threads(roomID, from) resp, err := l.Threads(ctx, roomID, from)
err = UnwrapError(err) err = UnwrapError(err)
if err != nil { if err != nil {
l.log.Warn().Err(err).Str("roomID", roomID.String()).Msg("cannot get room threads") l.log.Warn().Err(err).Str("roomID", roomID.String()).Msg("cannot get room threads")
@@ -47,7 +48,7 @@ func (l *Linkpearl) FindThreadBy(roomID id.RoomID, field, value string, fromToke
} }
for _, msg := range resp.Chunk { for _, msg := range resp.Chunk {
evt, contains := l.eventContains(msg, field, value) evt, contains := l.eventContains(ctx, msg, field, value)
if contains { if contains {
return evt return evt
} }
@@ -57,17 +58,17 @@ func (l *Linkpearl) FindThreadBy(roomID id.RoomID, field, value string, fromToke
return nil return nil
} }
return l.FindThreadBy(roomID, field, value, resp.NextBatch) return l.FindThreadBy(ctx, roomID, field, value, resp.NextBatch)
} }
// FindEventBy tries to find message event by field and value // FindEventBy tries to find message event by field and value
func (l *Linkpearl) FindEventBy(roomID id.RoomID, field, value string, fromToken ...string) *event.Event { func (l *Linkpearl) FindEventBy(ctx context.Context, roomID id.RoomID, field, value string, fromToken ...string) *event.Event {
var from string var from string
if len(fromToken) > 0 { if len(fromToken) > 0 {
from = fromToken[0] from = fromToken[0]
} }
resp, err := l.GetClient().Messages(roomID, from, "", mautrix.DirectionBackward, nil, l.eventsLimit) resp, err := l.GetClient().Messages(ctx, roomID, from, "", mautrix.DirectionBackward, nil, l.eventsLimit)
err = UnwrapError(err) err = UnwrapError(err)
if err != nil { if err != nil {
l.log.Warn().Err(err).Str("roomID", roomID.String()).Msg("cannot get room events") l.log.Warn().Err(err).Str("roomID", roomID.String()).Msg("cannot get room events")
@@ -75,7 +76,7 @@ func (l *Linkpearl) FindEventBy(roomID id.RoomID, field, value string, fromToken
} }
for _, msg := range resp.Chunk { for _, msg := range resp.Chunk {
evt, contains := l.eventContains(msg, field, value) evt, contains := l.eventContains(ctx, msg, field, value)
if contains { if contains {
return evt return evt
} }
@@ -85,13 +86,13 @@ func (l *Linkpearl) FindEventBy(roomID id.RoomID, field, value string, fromToken
return nil return nil
} }
return l.FindEventBy(roomID, field, value, resp.End) return l.FindEventBy(ctx, roomID, field, value, resp.End)
} }
func (l *Linkpearl) eventContains(evt *event.Event, field, value string) (*event.Event, bool) { func (l *Linkpearl) eventContains(ctx context.Context, evt *event.Event, field, value string) (*event.Event, bool) {
if evt.Type == event.EventEncrypted { if evt.Type == event.EventEncrypted {
ParseContent(evt, &l.log) ParseContent(evt, &l.log)
decrypted, err := l.GetClient().Crypto.Decrypt(evt) decrypted, err := l.GetClient().Crypto.Decrypt(ctx, evt)
if err == nil { if err == nil {
evt = decrypted evt = decrypted
} }

View File

@@ -2,6 +2,7 @@
package linkpearl package linkpearl
import ( import (
"context"
"database/sql" "database/sql"
lru "github.com/hashicorp/golang-lru/v2" lru "github.com/hashicorp/golang-lru/v2"
@@ -31,7 +32,7 @@ type Linkpearl struct {
log zerolog.Logger log zerolog.Logger
api *mautrix.Client api *mautrix.Client
joinPermit func(*event.Event) bool joinPermit func(ctx context.Context, evt *event.Event) bool
autoleave bool autoleave bool
maxretries int maxretries int
eventsLimit int eventsLimit int
@@ -54,7 +55,7 @@ func setDefaults(cfg *Config) {
} }
if cfg.JoinPermit == nil { if cfg.JoinPermit == nil {
// By default, we approve all join requests // By default, we approve all join requests
cfg.JoinPermit = func(*event.Event) bool { return true } cfg.JoinPermit = func(_ context.Context, _ *event.Event) bool { return true }
} }
} }
@@ -103,7 +104,7 @@ func New(cfg *Config) (*Linkpearl, error) {
return nil, err return nil, err
} }
lp.ch.LoginAs = cfg.LoginAs() lp.ch.LoginAs = cfg.LoginAs()
if err = lp.ch.Init(); err != nil { if err = lp.ch.Init(context.Background()); err != nil {
return nil, err return nil, err
} }
lp.api.Crypto = lp.ch lp.api.Crypto = lp.ch
@@ -131,16 +132,16 @@ func (l *Linkpearl) GetAccountDataCrypter() *Crypter {
} }
// SetPresence (own). See https://spec.matrix.org/v1.3/client-server-api/#put_matrixclientv3presenceuseridstatus // SetPresence (own). See https://spec.matrix.org/v1.3/client-server-api/#put_matrixclientv3presenceuseridstatus
func (l *Linkpearl) SetPresence(presence event.Presence, message string) error { func (l *Linkpearl) SetPresence(ctx context.Context, presence event.Presence, message string) error {
req := ReqPresence{Presence: presence, StatusMsg: message} req := ReqPresence{Presence: presence, StatusMsg: message}
u := l.GetClient().BuildClientURL("v3", "presence", l.GetClient().UserID, "status") u := l.GetClient().BuildClientURL("v3", "presence", l.GetClient().UserID, "status")
_, err := l.GetClient().MakeRequest("PUT", u, req, nil) _, err := l.GetClient().MakeRequest(ctx, "PUT", u, req, nil)
return err return err
} }
// SetJoinPermit sets the the join permit callback function // SetJoinPermit sets the the join permit callback function
func (l *Linkpearl) SetJoinPermit(value func(*event.Event) bool) { func (l *Linkpearl) SetJoinPermit(value func(context.Context, *event.Event) bool) {
l.joinPermit = value l.joinPermit = value
} }
@@ -152,7 +153,7 @@ func (l *Linkpearl) Start(optionalStatusMsg ...string) error {
statusMsg = optionalStatusMsg[0] statusMsg = optionalStatusMsg[0]
} }
err := l.SetPresence(event.PresenceOnline, statusMsg) err := l.SetPresence(context.Background(), event.PresenceOnline, statusMsg)
if err != nil { if err != nil {
l.log.Error().Err(err).Msg("cannot set presence") l.log.Error().Err(err).Msg("cannot set presence")
} }
@@ -165,7 +166,7 @@ func (l *Linkpearl) Start(optionalStatusMsg ...string) error {
// Stop the client // Stop the client
func (l *Linkpearl) Stop() { func (l *Linkpearl) Stop() {
l.log.Debug().Msg("stopping the client") l.log.Debug().Msg("stopping the client")
if err := l.api.SetPresence(event.PresenceOffline); err != nil { if err := l.api.SetPresence(context.Background(), event.PresenceOffline); err != nil {
l.log.Error().Err(err).Msg("cannot set presence") l.log.Error().Err(err).Msg("cannot set presence")
} }
l.api.StopSync() l.api.StopSync()

View File

@@ -1,6 +1,8 @@
package linkpearl package linkpearl
import ( import (
"context"
"maunium.net/go/mautrix" "maunium.net/go/mautrix"
"maunium.net/go/mautrix/event" "maunium.net/go/mautrix/event"
"maunium.net/go/mautrix/format" "maunium.net/go/mautrix/format"
@@ -10,9 +12,9 @@ import (
// Send a message to the roomID and automatically try to encrypt it, if the destination room is encrypted // Send a message to the roomID and automatically try to encrypt it, if the destination room is encrypted
// //
//nolint:unparam // it's public interface //nolint:unparam // it's public interface
func (l *Linkpearl) Send(roomID id.RoomID, content interface{}) (id.EventID, error) { func (l *Linkpearl) Send(ctx context.Context, roomID id.RoomID, content interface{}) (id.EventID, error) {
l.log.Debug().Str("roomID", roomID.String()).Any("content", content).Msg("sending event") l.log.Debug().Str("roomID", roomID.String()).Any("content", content).Msg("sending event")
resp, err := l.api.SendMessageEvent(roomID, event.EventMessage, content) resp, err := l.api.SendMessageEvent(ctx, roomID, event.EventMessage, content)
if err != nil { if err != nil {
return "", UnwrapError(err) return "", UnwrapError(err)
} }
@@ -20,7 +22,7 @@ func (l *Linkpearl) Send(roomID id.RoomID, content interface{}) (id.EventID, err
} }
// SendNotice to a room with optional relations, markdown supported // SendNotice to a room with optional relations, markdown supported
func (l *Linkpearl) SendNotice(roomID id.RoomID, message string, relates ...*event.RelatesTo) { func (l *Linkpearl) SendNotice(ctx context.Context, roomID id.RoomID, message string, relates ...*event.RelatesTo) {
var withRelatesTo bool var withRelatesTo bool
content := format.RenderMarkdown(message, true, true) content := format.RenderMarkdown(message, true, true)
content.MsgType = event.MsgNotice content.MsgType = event.MsgNotice
@@ -29,12 +31,12 @@ func (l *Linkpearl) SendNotice(roomID id.RoomID, message string, relates ...*eve
content.RelatesTo = relates[0] content.RelatesTo = relates[0]
} }
_, err := l.Send(roomID, &content) _, err := l.Send(ctx, roomID, &content)
if err != nil { if err != nil {
l.log.Error().Err(UnwrapError(err)).Str("roomID", roomID.String()).Str("retries", "1/2").Msg("cannot send a notice into the room") l.log.Error().Err(UnwrapError(err)).Str("roomID", roomID.String()).Str("retries", "1/2").Msg("cannot send a notice into the room")
if withRelatesTo { if withRelatesTo {
content.RelatesTo = nil content.RelatesTo = nil
_, err = l.Send(roomID, &content) _, err = l.Send(ctx, roomID, &content)
if err != nil { if err != nil {
l.log.Error().Err(UnwrapError(err)).Str("roomID", roomID.String()).Str("retries", "2/2").Msg("cannot send a notice into the room even without relations") l.log.Error().Err(UnwrapError(err)).Str("roomID", roomID.String()).Str("retries", "2/2").Msg("cannot send a notice into the room even without relations")
} }
@@ -43,13 +45,13 @@ func (l *Linkpearl) SendNotice(roomID id.RoomID, message string, relates ...*eve
} }
// SendFile to a matrix room // SendFile to a matrix room
func (l *Linkpearl) SendFile(roomID id.RoomID, req *mautrix.ReqUploadMedia, msgtype event.MessageType, relates ...*event.RelatesTo) error { func (l *Linkpearl) SendFile(ctx context.Context, roomID id.RoomID, req *mautrix.ReqUploadMedia, msgtype event.MessageType, relates ...*event.RelatesTo) error {
var relation *event.RelatesTo var relation *event.RelatesTo
if len(relates) > 0 { if len(relates) > 0 {
relation = relates[0] relation = relates[0]
} }
resp, err := l.GetClient().UploadMedia(*req) resp, err := l.GetClient().UploadMedia(ctx, *req)
if err != nil { if err != nil {
err = UnwrapError(err) err = UnwrapError(err)
l.log.Error().Err(err).Str("file", req.FileName).Msg("cannot upload file") l.log.Error().Err(err).Str("file", req.FileName).Msg("cannot upload file")
@@ -62,13 +64,13 @@ func (l *Linkpearl) SendFile(roomID id.RoomID, req *mautrix.ReqUploadMedia, msgt
RelatesTo: relation, RelatesTo: relation,
} }
_, err = l.Send(roomID, content) _, err = l.Send(ctx, roomID, content)
err = UnwrapError(err) err = UnwrapError(err)
if err != nil { if err != nil {
l.log.Error().Err(err).Str("roomID", roomID.String()).Str("retries", "1/2").Msg("cannot send file into the room") l.log.Error().Err(err).Str("roomID", roomID.String()).Str("retries", "1/2").Msg("cannot send file into the room")
if relation != nil { if relation != nil {
content.RelatesTo = nil content.RelatesTo = nil
_, err = l.Send(roomID, &content) _, err = l.Send(ctx, roomID, &content)
err = UnwrapError(err) err = UnwrapError(err)
if err != nil { if err != nil {
l.log.Error().Err(UnwrapError(err)).Str("roomID", roomID.String()).Str("retries", "2/2").Msg("cannot send file into the room even without relations") l.log.Error().Err(UnwrapError(err)).Str("roomID", roomID.String()).Str("retries", "2/2").Msg("cannot send file into the room even without relations")

View File

@@ -1,6 +1,7 @@
package linkpearl package linkpearl
import ( import (
"context"
"strings" "strings"
"time" "time"
@@ -28,54 +29,56 @@ func (l *Linkpearl) OnEvent(callback mautrix.EventHandler) {
func (l *Linkpearl) initSync() { func (l *Linkpearl) initSync() {
l.api.Syncer.(mautrix.ExtensibleSyncer).OnEventType( l.api.Syncer.(mautrix.ExtensibleSyncer).OnEventType(
event.StateEncryption, event.StateEncryption,
func(source mautrix.EventSource, evt *event.Event) { func(ctx context.Context, evt *event.Event) {
go l.onEncryption(source, evt) go l.onEncryption(ctx, evt)
}, },
) )
l.api.Syncer.(mautrix.ExtensibleSyncer).OnEventType( l.api.Syncer.(mautrix.ExtensibleSyncer).OnEventType(
event.StateMember, event.StateMember,
func(source mautrix.EventSource, evt *event.Event) { func(ctx context.Context, evt *event.Event) {
go l.onMembership(source, evt) go l.onMembership(ctx, evt)
}, },
) )
} }
func (l *Linkpearl) onMembership(src mautrix.EventSource, evt *event.Event) { func (l *Linkpearl) onMembership(ctx context.Context, evt *event.Event) {
l.ch.Machine().HandleMemberEvent(src, evt) l.ch.Machine().HandleMemberEvent(ctx, evt)
l.api.StateStore.SetMembership(evt.RoomID, id.UserID(evt.GetStateKey()), evt.Content.AsMember().Membership) if err := l.api.StateStore.SetMembership(ctx, evt.RoomID, id.UserID(evt.GetStateKey()), evt.Content.AsMember().Membership); err != nil {
l.log.Error().Err(err).Str("roomID", evt.RoomID.String()).Str("userID", evt.GetStateKey()).Msg("cannot set membership")
}
// potentially autoaccept invites // potentially autoaccept invites
l.onInvite(evt) l.onInvite(ctx, evt)
// autoleave empty rooms // autoleave empty rooms
l.onEmpty(evt) l.onEmpty(ctx, evt)
} }
func (l *Linkpearl) onInvite(evt *event.Event) { func (l *Linkpearl) onInvite(ctx context.Context, evt *event.Event) {
userID := l.api.UserID.String() userID := l.api.UserID.String()
invite := evt.Content.AsMember().Membership == event.MembershipInvite invite := evt.Content.AsMember().Membership == event.MembershipInvite
if !invite || evt.GetStateKey() != userID { if !invite || evt.GetStateKey() != userID {
return return
} }
if l.joinPermit(evt) { if l.joinPermit(ctx, evt) {
l.tryJoin(evt.RoomID, 0) l.tryJoin(ctx, evt.RoomID, 0)
return return
} }
l.tryLeave(evt.RoomID, 0) l.tryLeave(ctx, evt.RoomID, 0)
} }
// TODO: https://spec.matrix.org/v1.8/client-server-api/#post_matrixclientv3joinroomidoralias // TODO: https://spec.matrix.org/v1.8/client-server-api/#post_matrixclientv3joinroomidoralias
// endpoint supports server_name param and tells "The servers to attempt to join the room through. One of the servers must be participating in the room.", // endpoint supports server_name param and tells "The servers to attempt to join the room through. One of the servers must be participating in the room.",
// meaning you can specify more than 1 server. It is not clear, what format should be used "example.com,example.org", or "example.com example.org", or whatever else. // meaning you can specify more than 1 server. It is not clear, what format should be used "example.com,example.org", or "example.com example.org", or whatever else.
// Moreover, it is not clear if the following values can be used together with that field: l.api.UserID.Homeserver() and evt.Sender.Homeserver() // Moreover, it is not clear if the following values can be used together with that field: l.api.UserID.Homeserver() and evt.Sender.Homeserver()
func (l *Linkpearl) tryJoin(roomID id.RoomID, retry int) { func (l *Linkpearl) tryJoin(ctx context.Context, roomID id.RoomID, retry int) {
if retry >= l.maxretries { if retry >= l.maxretries {
return return
} }
_, err := l.api.JoinRoom(roomID.String(), "", nil) _, err := l.api.JoinRoom(ctx, roomID.String(), "", nil)
err = UnwrapError(err) err = UnwrapError(err)
if err != nil { if err != nil {
l.log.Error().Err(err).Str("roomID", roomID.String()).Msg("cannot join room") l.log.Error().Err(err).Str("roomID", roomID.String()).Msg("cannot join room")
@@ -84,31 +87,31 @@ func (l *Linkpearl) tryJoin(roomID id.RoomID, retry int) {
} }
time.Sleep(5 * time.Second) time.Sleep(5 * time.Second)
l.log.Error().Err(err).Str("roomID", roomID.String()).Int("retry", retry+1).Msg("trying to join again") l.log.Error().Err(err).Str("roomID", roomID.String()).Int("retry", retry+1).Msg("trying to join again")
l.tryJoin(roomID, retry+1) l.tryJoin(ctx, roomID, retry+1)
} }
} }
func (l *Linkpearl) tryLeave(roomID id.RoomID, retry int) { func (l *Linkpearl) tryLeave(ctx context.Context, roomID id.RoomID, retry int) {
if retry >= l.maxretries { if retry >= l.maxretries {
return return
} }
_, err := l.api.LeaveRoom(roomID) _, err := l.api.LeaveRoom(ctx, roomID)
err = UnwrapError(err) err = UnwrapError(err)
if err != nil { if err != nil {
l.log.Error().Err(err).Str("roomID", roomID.String()).Msg("cannot leave room") l.log.Error().Err(err).Str("roomID", roomID.String()).Msg("cannot leave room")
time.Sleep(5 * time.Second) time.Sleep(5 * time.Second)
l.log.Error().Err(err).Str("roomID", roomID.String()).Int("retry", retry+1).Msg("trying to leave again") l.log.Error().Err(err).Str("roomID", roomID.String()).Int("retry", retry+1).Msg("trying to leave again")
l.tryLeave(roomID, retry+1) l.tryLeave(ctx, roomID, retry+1)
} }
} }
func (l *Linkpearl) onEmpty(evt *event.Event) { func (l *Linkpearl) onEmpty(ctx context.Context, evt *event.Event) {
if !l.autoleave { if !l.autoleave {
return return
} }
members, err := l.api.StateStore.GetRoomJoinedOrInvitedMembers(evt.RoomID) members, err := l.api.StateStore.GetRoomJoinedOrInvitedMembers(ctx, evt.RoomID)
err = UnwrapError(err) err = UnwrapError(err)
if err != nil { if err != nil {
l.log.Error().Err(err).Str("roomID", evt.RoomID.String()).Msg("cannot get joined or invited members") l.log.Error().Err(err).Str("roomID", evt.RoomID.String()).Msg("cannot get joined or invited members")
@@ -119,9 +122,11 @@ func (l *Linkpearl) onEmpty(evt *event.Event) {
return return
} }
l.tryLeave(evt.RoomID, 0) l.tryLeave(ctx, evt.RoomID, 0)
} }
func (l *Linkpearl) onEncryption(_ mautrix.EventSource, evt *event.Event) { func (l *Linkpearl) onEncryption(ctx context.Context, evt *event.Event) {
l.api.StateStore.SetEncryptionEvent(evt.RoomID, evt.Content.AsEncryption()) if err := l.api.StateStore.SetEncryptionEvent(ctx, evt.RoomID, evt.Content.AsEncryption()); err != nil {
l.log.Error().Err(err).Str("roomID", evt.RoomID.String()).Msg("cannot set encryption event")
}
} }

View File

@@ -9,6 +9,9 @@ package dbutil
import ( import (
"context" "context"
"database/sql" "database/sql"
"errors"
"strconv"
"strings"
"time" "time"
) )
@@ -19,18 +22,61 @@ type LoggingExecable struct {
db *Database db *Database
} }
func (le *LoggingExecable) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) { type pqError interface {
Get(k byte) string
}
type PQErrorWithLine struct {
Underlying error
Line string
}
func (pqe *PQErrorWithLine) Error() string {
return pqe.Underlying.Error()
}
func (pqe *PQErrorWithLine) Unwrap() error {
return pqe.Underlying
}
func addErrorLine(query string, err error) error {
if err == nil {
return err
}
var pqe pqError
if !errors.As(err, &pqe) {
return err
}
pos, _ := strconv.Atoi(pqe.Get('P'))
pos--
if pos <= 0 {
return err
}
lines := strings.Split(query, "\n")
for _, line := range lines {
lineRunes := []rune(line)
if pos < len(lineRunes)+1 {
return &PQErrorWithLine{Underlying: err, Line: line}
}
pos -= len(lineRunes) + 1
}
return err
}
func (le *LoggingExecable) ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error) {
start := time.Now() start := time.Now()
query = le.db.mutateQuery(query) query = le.db.mutateQuery(query)
res, err := le.UnderlyingExecable.ExecContext(ctx, query, args...) res, err := le.UnderlyingExecable.ExecContext(ctx, query, args...)
err = addErrorLine(query, err)
le.db.Log.QueryTiming(ctx, "Exec", query, args, -1, time.Since(start), err) le.db.Log.QueryTiming(ctx, "Exec", query, args, -1, time.Since(start), err)
return res, err return res, err
} }
func (le *LoggingExecable) QueryContext(ctx context.Context, query string, args ...interface{}) (Rows, error) { func (le *LoggingExecable) QueryContext(ctx context.Context, query string, args ...any) (Rows, error) {
start := time.Now() start := time.Now()
query = le.db.mutateQuery(query) query = le.db.mutateQuery(query)
rows, err := le.UnderlyingExecable.QueryContext(ctx, query, args...) rows, err := le.UnderlyingExecable.QueryContext(ctx, query, args...)
err = addErrorLine(query, err)
le.db.Log.QueryTiming(ctx, "Query", query, args, -1, time.Since(start), err) le.db.Log.QueryTiming(ctx, "Query", query, args, -1, time.Since(start), err)
return &LoggingRows{ return &LoggingRows{
ctx: ctx, ctx: ctx,
@@ -42,7 +88,7 @@ func (le *LoggingExecable) QueryContext(ctx context.Context, query string, args
}, err }, err
} }
func (le *LoggingExecable) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row { func (le *LoggingExecable) QueryRowContext(ctx context.Context, query string, args ...any) *sql.Row {
start := time.Now() start := time.Now()
query = le.db.mutateQuery(query) query = le.db.mutateQuery(query)
row := le.UnderlyingExecable.QueryRowContext(ctx, query, args...) row := le.UnderlyingExecable.QueryRowContext(ctx, query, args...)
@@ -50,18 +96,6 @@ func (le *LoggingExecable) QueryRowContext(ctx context.Context, query string, ar
return row return row
} }
func (le *LoggingExecable) Exec(query string, args ...interface{}) (sql.Result, error) {
return le.ExecContext(context.Background(), query, args...)
}
func (le *LoggingExecable) Query(query string, args ...interface{}) (Rows, error) {
return le.QueryContext(context.Background(), query, args...)
}
func (le *LoggingExecable) QueryRow(query string, args ...interface{}) *sql.Row {
return le.QueryRowContext(context.Background(), query, args...)
}
// loggingDB is a wrapper for LoggingExecable that allows access to BeginTx. // loggingDB is a wrapper for LoggingExecable that allows access to BeginTx.
// //
// While LoggingExecable has a pointer to the database and could use BeginTx, it's not technically safe since // While LoggingExecable has a pointer to the database and could use BeginTx, it's not technically safe since
@@ -89,10 +123,6 @@ func (ld *loggingDB) BeginTx(ctx context.Context, opts *sql.TxOptions) (*Logging
}, nil }, nil
} }
func (ld *loggingDB) Begin() (*LoggingTxn, error) {
return ld.BeginTx(context.Background(), nil)
}
type LoggingTxn struct { type LoggingTxn struct {
LoggingExecable LoggingExecable
UnderlyingTx *sql.Tx UnderlyingTx *sql.Tx
@@ -129,7 +159,7 @@ type LoggingRows struct {
ctx context.Context ctx context.Context
db *Database db *Database
query string query string
args []interface{} args []any
rs Rows rs Rows
start time.Time start time.Time
nrows int nrows int

View File

@@ -58,7 +58,7 @@ type Rows interface {
} }
type Scannable interface { type Scannable interface {
Scan(...interface{}) error Scan(...any) error
} }
// Expected implementations of Scannable // Expected implementations of Scannable
@@ -67,30 +67,16 @@ var (
_ Scannable = (Rows)(nil) _ Scannable = (Rows)(nil)
) )
type UnderlyingContextExecable interface {
ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error)
QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error)
QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row
}
type ContextExecable interface {
ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error)
QueryContext(ctx context.Context, query string, args ...interface{}) (Rows, error)
QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row
}
type UnderlyingExecable interface { type UnderlyingExecable interface {
UnderlyingContextExecable ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error)
Exec(query string, args ...interface{}) (sql.Result, error) QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error)
Query(query string, args ...interface{}) (*sql.Rows, error) QueryRowContext(ctx context.Context, query string, args ...any) *sql.Row
QueryRow(query string, args ...interface{}) *sql.Row
} }
type Execable interface { type Execable interface {
ContextExecable ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error)
Exec(query string, args ...interface{}) (sql.Result, error) QueryContext(ctx context.Context, query string, args ...any) (Rows, error)
Query(query string, args ...interface{}) (Rows, error) QueryRowContext(ctx context.Context, query string, args ...any) *sql.Row
QueryRow(query string, args ...interface{}) *sql.Row
} }
type Transaction interface { type Transaction interface {
@@ -103,13 +89,13 @@ type Transaction interface {
var ( var (
_ UnderlyingExecable = (*sql.Tx)(nil) _ UnderlyingExecable = (*sql.Tx)(nil)
_ UnderlyingExecable = (*sql.DB)(nil) _ UnderlyingExecable = (*sql.DB)(nil)
_ UnderlyingExecable = (*sql.Conn)(nil)
_ Execable = (*LoggingExecable)(nil) _ Execable = (*LoggingExecable)(nil)
_ Transaction = (*LoggingTxn)(nil) _ Transaction = (*LoggingTxn)(nil)
_ UnderlyingContextExecable = (*sql.Conn)(nil)
) )
type Database struct { type Database struct {
loggingDB LoggingDB loggingDB
RawDB *sql.DB RawDB *sql.DB
ReadOnlyDB *sql.DB ReadOnlyDB *sql.DB
Owner string Owner string
@@ -139,7 +125,7 @@ func (db *Database) Child(versionTable string, upgradeTable UpgradeTable, log Da
} }
return &Database{ return &Database{
RawDB: db.RawDB, RawDB: db.RawDB,
loggingDB: db.loggingDB, LoggingDB: db.LoggingDB,
Owner: "", Owner: "",
VersionTable: versionTable, VersionTable: versionTable,
UpgradeTable: upgradeTable, UpgradeTable: upgradeTable,
@@ -164,8 +150,8 @@ func NewWithDB(db *sql.DB, rawDialect string) (*Database, error) {
IgnoreForeignTables: true, IgnoreForeignTables: true,
VersionTable: "version", VersionTable: "version",
} }
wrappedDB.loggingDB.UnderlyingExecable = db wrappedDB.LoggingDB.UnderlyingExecable = db
wrappedDB.loggingDB.db = wrappedDB wrappedDB.LoggingDB.db = wrappedDB
return wrappedDB, nil return wrappedDB, nil
} }
@@ -259,7 +245,7 @@ func NewFromConfig(owner string, cfg Config, logger DatabaseLogger) (*Database,
if roUri == "" { if roUri == "" {
uriParts := strings.Split(cfg.URI, "?") uriParts := strings.Split(cfg.URI, "?")
var qs url.Values qs := url.Values{}
if len(uriParts) == 2 { if len(uriParts) == 2 {
var err error var err error
qs, err = url.ParseQuery(uriParts[1]) qs, err = url.ParseQuery(uriParts[1])

72
vendor/go.mau.fi/util/dbutil/iter.go vendored Normal file
View File

@@ -0,0 +1,72 @@
// Copyright (c) 2023 Sumner Evans
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
package dbutil
// RowIter is a wrapper for [Rows] that allows conveniently iterating over rows
// with a predefined scanner function.
type RowIter[T any] interface {
// Iter iterates over the rows and calls the given function for each row.
//
// If the function returns false, the iteration is stopped.
// If the function returns an error, the iteration is stopped and the error is
// returned.
Iter(func(T) (bool, error)) error
// AsList collects all rows into a slice.
AsList() ([]T, error)
}
type rowIterImpl[T any] struct {
Rows
ConvertRow func(Scannable) (T, error)
}
// NewRowIter creates a new RowIter from the given Rows and scanner function.
func NewRowIter[T any](rows Rows, convertFn func(Scannable) (T, error)) RowIter[T] {
return &rowIterImpl[T]{Rows: rows, ConvertRow: convertFn}
}
func ScanSingleColumn[T any](rows Scannable) (val T, err error) {
err = rows.Scan(&val)
return
}
type NewableDataStruct[T any] interface {
DataStruct[T]
New() T
}
func ScanDataStruct[T NewableDataStruct[T]](rows Scannable) (T, error) {
var val T
return val.New().Scan(rows)
}
func (i *rowIterImpl[T]) Iter(fn func(T) (bool, error)) error {
if i == nil || i.Rows == nil {
return nil
}
defer i.Rows.Close()
for i.Rows.Next() {
if item, err := i.ConvertRow(i.Rows); err != nil {
return err
} else if cont, err := fn(item); err != nil {
return err
} else if !cont {
break
}
}
return i.Rows.Err()
}
func (i *rowIterImpl[T]) AsList() (list []T, err error) {
err = i.Iter(func(item T) (bool, error) {
list = append(list, item)
return true, nil
})
return
}

View File

@@ -10,12 +10,12 @@ import (
) )
type DatabaseLogger interface { type DatabaseLogger interface {
QueryTiming(ctx context.Context, method, query string, args []interface{}, nrows int, duration time.Duration, err error) QueryTiming(ctx context.Context, method, query string, args []any, nrows int, duration time.Duration, err error)
WarnUnsupportedVersion(current, compat, latest int) WarnUnsupportedVersion(current, compat, latest int)
PrepareUpgrade(current, compat, latest int) PrepareUpgrade(current, compat, latest int)
DoUpgrade(from, to int, message string, txn bool) DoUpgrade(from, to int, message string, txn bool)
// Deprecated: legacy warning method, return errors instead // Deprecated: legacy warning method, return errors instead
Warn(msg string, args ...interface{}) Warn(msg string, args ...any)
} }
type noopLogger struct{} type noopLogger struct{}
@@ -25,9 +25,9 @@ var NoopLogger DatabaseLogger = &noopLogger{}
func (n noopLogger) WarnUnsupportedVersion(_, _, _ int) {} func (n noopLogger) WarnUnsupportedVersion(_, _, _ int) {}
func (n noopLogger) PrepareUpgrade(_, _, _ int) {} func (n noopLogger) PrepareUpgrade(_, _, _ int) {}
func (n noopLogger) DoUpgrade(_, _ int, _ string, _ bool) {} func (n noopLogger) DoUpgrade(_, _ int, _ string, _ bool) {}
func (n noopLogger) Warn(msg string, args ...interface{}) {} func (n noopLogger) Warn(msg string, args ...any) {}
func (n noopLogger) QueryTiming(_ context.Context, _, _ string, _ []interface{}, _ int, _ time.Duration, _ error) { func (n noopLogger) QueryTiming(_ context.Context, _, _ string, _ []any, _ int, _ time.Duration, _ error) {
} }
type zeroLogger struct { type zeroLogger struct {
@@ -92,7 +92,7 @@ func (z zeroLogger) DoUpgrade(from, to int, message string, txn bool) {
var whitespaceRegex = regexp.MustCompile(`\s+`) var whitespaceRegex = regexp.MustCompile(`\s+`)
func (z zeroLogger) QueryTiming(ctx context.Context, method, query string, args []interface{}, nrows int, duration time.Duration, err error) { func (z zeroLogger) QueryTiming(ctx context.Context, method, query string, args []any, nrows int, duration time.Duration, err error) {
log := zerolog.Ctx(ctx) log := zerolog.Ctx(ctx)
if log.GetLevel() == zerolog.Disabled || log == zerolog.DefaultContextLogger { if log.GetLevel() == zerolog.Disabled || log == zerolog.DefaultContextLogger {
log = z.l log = z.l
@@ -124,6 +124,6 @@ func (z zeroLogger) QueryTiming(ctx context.Context, method, query string, args
} }
} }
func (z zeroLogger) Warn(msg string, args ...interface{}) { func (z zeroLogger) Warn(msg string, args ...any) {
z.l.Warn().Msgf(msg, args...) z.l.Warn().Msgf(msg, args...) // zerolog-allow-msgf
} }

View File

@@ -0,0 +1,105 @@
// Copyright (c) 2023 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
package dbutil
import (
"context"
"database/sql"
"errors"
"golang.org/x/exp/constraints"
)
// DataStruct is an interface for structs that represent a single database row.
type DataStruct[T any] interface {
Scan(row Scannable) (T, error)
}
// QueryHelper is a generic helper struct for SQL query execution boilerplate.
//
// After implementing the Scan and Init methods in a data struct, the query
// helper allows writing query functions in a single line.
type QueryHelper[T DataStruct[T]] struct {
db *Database
newFunc func(qh *QueryHelper[T]) T
}
func MakeQueryHelper[T DataStruct[T]](db *Database, new func(qh *QueryHelper[T]) T) *QueryHelper[T] {
return &QueryHelper[T]{db: db, newFunc: new}
}
// ValueOrErr is a helper function that returns the value if err is nil, or
// returns nil and the error if err is not nil. It can be used to avoid
// `if err != nil { return nil, err }` boilerplate in certain cases like
// DataStruct.Scan implementations.
func ValueOrErr[T any](val *T, err error) (*T, error) {
if err != nil {
return nil, err
}
return val, nil
}
// StrPtr returns a pointer to the given string, or nil if the string is empty.
func StrPtr[T ~string](val T) *string {
if val == "" {
return nil
}
strVal := string(val)
return &strVal
}
// NumPtr returns a pointer to the given number, or nil if the number is zero.
func NumPtr[T constraints.Integer | constraints.Float](val T) *T {
if val == 0 {
return nil
}
return &val
}
func (qh *QueryHelper[T]) GetDB() *Database {
return qh.db
}
func (qh *QueryHelper[T]) New() T {
return qh.newFunc(qh)
}
// Exec executes a query with ExecContext and returns the error.
//
// It omits the sql.Result return value, as it is rarely used. When the result
// is wanted, use `qh.GetDB().Exec(...)` instead, which is
// otherwise equivalent.
func (qh *QueryHelper[T]) Exec(ctx context.Context, query string, args ...any) error {
_, err := qh.db.Exec(ctx, query, args...)
return err
}
func (qh *QueryHelper[T]) scanNew(row Scannable) (T, error) {
return qh.New().Scan(row)
}
// QueryOne executes a query with QueryRowContext, uses the associated DataStruct
// to scan it, and returns the value. If the query returns no rows, it returns nil
// and no error.
func (qh *QueryHelper[T]) QueryOne(ctx context.Context, query string, args ...any) (val T, err error) {
val, err = qh.scanNew(qh.db.QueryRow(ctx, query, args...))
if errors.Is(err, sql.ErrNoRows) {
err = nil
}
return val, err
}
// QueryMany executes a query with QueryContext, uses the associated DataStruct
// to scan each row, and returns the values. If the query returns no rows, it
// returns a non-nil zero-length slice and no error.
func (qh *QueryHelper[T]) QueryMany(ctx context.Context, query string, args ...any) ([]T, error) {
rows, err := qh.db.Query(ctx, query, args...)
if err != nil {
return nil, err
}
return NewRowIter(rows, qh.scanNew).AsList()
}

View File

@@ -32,6 +32,22 @@ const (
ContextKeyDoTxnCallerSkip ContextKeyDoTxnCallerSkip
) )
func (db *Database) Exec(ctx context.Context, query string, args ...any) (sql.Result, error) {
return db.Conn(ctx).ExecContext(ctx, query, args...)
}
func (db *Database) Query(ctx context.Context, query string, args ...any) (Rows, error) {
return db.Conn(ctx).QueryContext(ctx, query, args...)
}
func (db *Database) QueryRow(ctx context.Context, query string, args ...any) *sql.Row {
return db.Conn(ctx).QueryRowContext(ctx, query, args...)
}
func (db *Database) BeginTx(ctx context.Context, opts *sql.TxOptions) (*LoggingTxn, error) {
return db.LoggingDB.BeginTx(ctx, opts)
}
func (db *Database) DoTxn(ctx context.Context, opts *sql.TxOptions, fn func(ctx context.Context) error) error { func (db *Database) DoTxn(ctx context.Context, opts *sql.TxOptions, fn func(ctx context.Context) error) error {
if ctx.Value(ContextKeyDatabaseTransaction) != nil { if ctx.Value(ContextKeyDatabaseTransaction) != nil {
zerolog.Ctx(ctx).Trace().Msg("Already in a transaction, not creating a new one") zerolog.Ctx(ctx).Trace().Msg("Already in a transaction, not creating a new one")
@@ -82,13 +98,13 @@ func (db *Database) DoTxn(ctx context.Context, opts *sql.TxOptions, fn func(ctx
return nil return nil
} }
func (db *Database) Conn(ctx context.Context) ContextExecable { func (db *Database) Conn(ctx context.Context) Execable {
if ctx == nil { if ctx == nil {
return db return &db.LoggingDB
} }
txn, ok := ctx.Value(ContextKeyDatabaseTransaction).(Transaction) txn, ok := ctx.Value(ContextKeyDatabaseTransaction).(Transaction)
if ok { if ok {
return txn return txn
} }
return db return &db.LoggingDB
} }

View File

@@ -7,12 +7,13 @@
package dbutil package dbutil
import ( import (
"context"
"database/sql" "database/sql"
"errors" "errors"
"fmt" "fmt"
) )
type upgradeFunc func(Execable, *Database) error type upgradeFunc func(context.Context, *Database) error
type upgrade struct { type upgrade struct {
message string message string
@@ -28,19 +29,19 @@ var ErrForeignTables = errors.New("the database contains foreign tables")
var ErrNotOwned = errors.New("the database is owned by") var ErrNotOwned = errors.New("the database is owned by")
var ErrUnsupportedDialect = errors.New("unsupported database dialect") var ErrUnsupportedDialect = errors.New("unsupported database dialect")
func (db *Database) upgradeVersionTable() error { func (db *Database) upgradeVersionTable(ctx context.Context) error {
if compatColumnExists, err := db.ColumnExists(nil, db.VersionTable, "compat"); err != nil { if compatColumnExists, err := db.ColumnExists(ctx, db.VersionTable, "compat"); err != nil {
return fmt.Errorf("failed to check if version table is up to date: %w", err) return fmt.Errorf("failed to check if version table is up to date: %w", err)
} else if !compatColumnExists { } else if !compatColumnExists {
if tableExists, err := db.TableExists(nil, db.VersionTable); err != nil { if tableExists, err := db.TableExists(ctx, db.VersionTable); err != nil {
return fmt.Errorf("failed to check if version table exists: %w", err) return fmt.Errorf("failed to check if version table exists: %w", err)
} else if !tableExists { } else if !tableExists {
_, err = db.Exec(fmt.Sprintf("CREATE TABLE %s (version INTEGER, compat INTEGER)", db.VersionTable)) _, err = db.Exec(ctx, fmt.Sprintf("CREATE TABLE %s (version INTEGER, compat INTEGER)", db.VersionTable))
if err != nil { if err != nil {
return fmt.Errorf("failed to create version table: %w", err) return fmt.Errorf("failed to create version table: %w", err)
} }
} else { } else {
_, err = db.Exec(fmt.Sprintf("ALTER TABLE %s ADD COLUMN compat INTEGER", db.VersionTable)) _, err = db.Exec(ctx, fmt.Sprintf("ALTER TABLE %s ADD COLUMN compat INTEGER", db.VersionTable))
if err != nil { if err != nil {
return fmt.Errorf("failed to add compat column to version table: %w", err) return fmt.Errorf("failed to add compat column to version table: %w", err)
} }
@@ -49,13 +50,13 @@ func (db *Database) upgradeVersionTable() error {
return nil return nil
} }
func (db *Database) getVersion() (version, compat int, err error) { func (db *Database) getVersion(ctx context.Context) (version, compat int, err error) {
if err = db.upgradeVersionTable(); err != nil { if err = db.upgradeVersionTable(ctx); err != nil {
return return
} }
var compatNull sql.NullInt32 var compatNull sql.NullInt32
err = db.QueryRow(fmt.Sprintf("SELECT version, compat FROM %s LIMIT 1", db.VersionTable)).Scan(&version, &compatNull) err = db.QueryRow(ctx, fmt.Sprintf("SELECT version, compat FROM %s LIMIT 1", db.VersionTable)).Scan(&version, &compatNull)
if errors.Is(err, sql.ErrNoRows) { if errors.Is(err, sql.ErrNoRows) {
err = nil err = nil
} }
@@ -72,15 +73,12 @@ const (
tableExistsSQLite = "SELECT EXISTS(SELECT 1 FROM sqlite_master WHERE type='table' AND tbl_name=?1)" tableExistsSQLite = "SELECT EXISTS(SELECT 1 FROM sqlite_master WHERE type='table' AND tbl_name=?1)"
) )
func (db *Database) TableExists(tx Execable, table string) (exists bool, err error) { func (db *Database) TableExists(ctx context.Context, table string) (exists bool, err error) {
if tx == nil {
tx = db
}
switch db.Dialect { switch db.Dialect {
case SQLite: case SQLite:
err = db.QueryRow(tableExistsSQLite, table).Scan(&exists) err = db.QueryRow(ctx, tableExistsSQLite, table).Scan(&exists)
case Postgres: case Postgres:
err = db.QueryRow(tableExistsPostgres, table).Scan(&exists) err = db.QueryRow(ctx, tableExistsPostgres, table).Scan(&exists)
default: default:
err = ErrUnsupportedDialect err = ErrUnsupportedDialect
} }
@@ -92,23 +90,20 @@ const (
columnExistsSQLite = "SELECT EXISTS(SELECT 1 FROM pragma_table_info(?1) WHERE name=?2)" columnExistsSQLite = "SELECT EXISTS(SELECT 1 FROM pragma_table_info(?1) WHERE name=?2)"
) )
func (db *Database) ColumnExists(tx Execable, table, column string) (exists bool, err error) { func (db *Database) ColumnExists(ctx context.Context, table, column string) (exists bool, err error) {
if tx == nil {
tx = db
}
switch db.Dialect { switch db.Dialect {
case SQLite: case SQLite:
err = db.QueryRow(columnExistsSQLite, table, column).Scan(&exists) err = db.QueryRow(ctx, columnExistsSQLite, table, column).Scan(&exists)
case Postgres: case Postgres:
err = db.QueryRow(columnExistsPostgres, table, column).Scan(&exists) err = db.QueryRow(ctx, columnExistsPostgres, table, column).Scan(&exists)
default: default:
err = ErrUnsupportedDialect err = ErrUnsupportedDialect
} }
return return
} }
func (db *Database) tableExistsNoError(table string) bool { func (db *Database) tableExistsNoError(ctx context.Context, table string) bool {
exists, err := db.TableExists(nil, table) exists, err := db.TableExists(ctx, table)
if err != nil { if err != nil {
panic(fmt.Errorf("failed to check if table exists: %w", err)) panic(fmt.Errorf("failed to check if table exists: %w", err))
} }
@@ -122,22 +117,22 @@ CREATE TABLE IF NOT EXISTS database_owner (
) )
` `
func (db *Database) checkDatabaseOwner() error { func (db *Database) checkDatabaseOwner(ctx context.Context) error {
var owner string var owner string
if !db.IgnoreForeignTables { if !db.IgnoreForeignTables {
if db.tableExistsNoError("state_groups_state") { if db.tableExistsNoError(ctx, "state_groups_state") {
return fmt.Errorf("%w (found state_groups_state, likely belonging to Synapse)", ErrForeignTables) return fmt.Errorf("%w (found state_groups_state, likely belonging to Synapse)", ErrForeignTables)
} else if db.tableExistsNoError("roomserver_rooms") { } else if db.tableExistsNoError(ctx, "roomserver_rooms") {
return fmt.Errorf("%w (found roomserver_rooms, likely belonging to Dendrite)", ErrForeignTables) return fmt.Errorf("%w (found roomserver_rooms, likely belonging to Dendrite)", ErrForeignTables)
} }
} }
if db.Owner == "" { if db.Owner == "" {
return nil return nil
} }
if _, err := db.Exec(createOwnerTable); err != nil { if _, err := db.Exec(ctx, createOwnerTable); err != nil {
return fmt.Errorf("failed to ensure database owner table exists: %w", err) return fmt.Errorf("failed to ensure database owner table exists: %w", err)
} else if err = db.QueryRow("SELECT owner FROM database_owner WHERE key=0").Scan(&owner); errors.Is(err, sql.ErrNoRows) { } else if err = db.QueryRow(ctx, "SELECT owner FROM database_owner WHERE key=0").Scan(&owner); errors.Is(err, sql.ErrNoRows) {
_, err = db.Exec("INSERT INTO database_owner (key, owner) VALUES (0, $1)", db.Owner) _, err = db.Exec(ctx, "INSERT INTO database_owner (key, owner) VALUES (0, $1)", db.Owner)
if err != nil { if err != nil {
return fmt.Errorf("failed to insert database owner: %w", err) return fmt.Errorf("failed to insert database owner: %w", err)
} }
@@ -149,22 +144,22 @@ func (db *Database) checkDatabaseOwner() error {
return nil return nil
} }
func (db *Database) setVersion(tx Execable, version, compat int) error { func (db *Database) setVersion(ctx context.Context, version, compat int) error {
_, err := tx.Exec(fmt.Sprintf("DELETE FROM %s", db.VersionTable)) _, err := db.Exec(ctx, fmt.Sprintf("DELETE FROM %s", db.VersionTable))
if err != nil { if err != nil {
return err return err
} }
_, err = tx.Exec(fmt.Sprintf("INSERT INTO %s (version, compat) VALUES ($1, $2)", db.VersionTable), version, compat) _, err = db.Exec(ctx, fmt.Sprintf("INSERT INTO %s (version, compat) VALUES ($1, $2)", db.VersionTable), version, compat)
return err return err
} }
func (db *Database) Upgrade() error { func (db *Database) Upgrade(ctx context.Context) error {
err := db.checkDatabaseOwner() err := db.checkDatabaseOwner(ctx)
if err != nil { if err != nil {
return err return err
} }
version, compat, err := db.getVersion() version, compat, err := db.getVersion(ctx)
if err != nil { if err != nil {
return err return err
} }
@@ -185,34 +180,28 @@ func (db *Database) Upgrade() error {
version++ version++
continue continue
} }
db.Log.DoUpgrade(logVersion, upgradeItem.upgradesTo, upgradeItem.message, upgradeItem.transaction) doUpgrade := func(ctx context.Context) error {
var tx Transaction err = upgradeItem.fn(ctx, db)
var upgradeConn Execable
if upgradeItem.transaction {
tx, err = db.Begin()
if err != nil { if err != nil {
return err return fmt.Errorf("failed to run upgrade #%d: %w", version, err)
}
upgradeConn = tx
} else {
upgradeConn = db
}
err = upgradeItem.fn(upgradeConn, db)
if err != nil {
return err
} }
version = upgradeItem.upgradesTo version = upgradeItem.upgradesTo
logVersion = version logVersion = version
err = db.setVersion(upgradeConn, version, upgradeItem.compatVersion) err = db.setVersion(ctx, version, upgradeItem.compatVersion)
if err != nil { if err != nil {
return err return err
} }
if tx != nil { return nil
err = tx.Commit() }
db.Log.DoUpgrade(logVersion, upgradeItem.upgradesTo, upgradeItem.message, upgradeItem.transaction)
if upgradeItem.transaction {
err = db.DoTxn(ctx, nil, doUpgrade)
} else {
err = doUpgrade(ctx)
}
if err != nil { if err != nil {
return err return err
} }
} }
}
return nil return nil
} }

View File

@@ -8,6 +8,7 @@ package dbutil
import ( import (
"bytes" "bytes"
"context"
"errors" "errors"
"fmt" "fmt"
"io/fs" "io/fs"
@@ -189,27 +190,27 @@ func (db *Database) filterSQLUpgrade(lines [][]byte) (string, error) {
} }
func sqlUpgradeFunc(fileName string, lines [][]byte) upgradeFunc { func sqlUpgradeFunc(fileName string, lines [][]byte) upgradeFunc {
return func(tx Execable, db *Database) error { return func(ctx context.Context, db *Database) error {
if skip, err := db.parseDialectFilter(lines[0]); err == nil && skip == skipNextLine { if skip, err := db.parseDialectFilter(lines[0]); err == nil && skip == skipNextLine {
return nil return nil
} else if upgradeSQL, err := db.filterSQLUpgrade(lines); err != nil { } else if upgradeSQL, err := db.filterSQLUpgrade(lines); err != nil {
panic(fmt.Errorf("failed to parse upgrade %s: %w", fileName, err)) panic(fmt.Errorf("failed to parse upgrade %s: %w", fileName, err))
} else { } else {
_, err = tx.Exec(upgradeSQL) _, err = db.Exec(ctx, upgradeSQL)
return err return err
} }
} }
} }
func splitSQLUpgradeFunc(sqliteData, postgresData string) upgradeFunc { func splitSQLUpgradeFunc(sqliteData, postgresData string) upgradeFunc {
return func(tx Execable, database *Database) (err error) { return func(ctx context.Context, db *Database) (err error) {
switch database.Dialect { switch db.Dialect {
case SQLite: case SQLite:
_, err = tx.Exec(sqliteData) _, err = db.Exec(ctx, sqliteData)
case Postgres: case Postgres:
_, err = tx.Exec(postgresData) _, err = db.Exec(ctx, postgresData)
default: default:
err = fmt.Errorf("unknown dialect %s", database.Dialect) err = fmt.Errorf("unknown dialect %s", db.Dialect)
} }
return return
} }

23
vendor/go.mau.fi/util/exerrors/must.go vendored Normal file
View File

@@ -0,0 +1,23 @@
// Copyright (c) 2024 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
package exerrors
func Must[T any](val T, err error) T {
PanicIfNotNil(err)
return val
}
func Must2[T any, T2 any](val T, val2 T2, err error) (T, T2) {
PanicIfNotNil(err)
return val, val2
}
func PanicIfNotNil(err error) {
if err != nil {
panic(err)
}
}

View File

@@ -7,10 +7,16 @@
package jsontime package jsontime
import ( import (
"database/sql"
"database/sql/driver"
"encoding/json" "encoding/json"
"errors"
"fmt"
"time" "time"
) )
var ErrNotInteger = errors.New("value is not an integer")
func parseTime(data []byte, unixConv func(int64) time.Time, into *time.Time) error { func parseTime(data []byte, unixConv func(int64) time.Time, into *time.Time) error {
var val int64 var val int64
err := json.Unmarshal(data, &val) err := json.Unmarshal(data, &val)
@@ -25,6 +31,28 @@ func parseTime(data []byte, unixConv func(int64) time.Time, into *time.Time) err
return nil return nil
} }
func anyIntegerToTime(src any, unixConv func(int64) time.Time, into *time.Time) error {
switch v := src.(type) {
case int:
*into = unixConv(int64(v))
case int8:
*into = unixConv(int64(v))
case int16:
*into = unixConv(int64(v))
case int32:
*into = unixConv(int64(v))
case int64:
*into = unixConv(int64(v))
default:
return fmt.Errorf("%w: %T", ErrNotInteger, src)
}
return nil
}
var _ sql.Scanner = &UnixMilli{}
var _ driver.Valuer = UnixMilli{}
type UnixMilli struct { type UnixMilli struct {
time.Time time.Time
} }
@@ -40,6 +68,17 @@ func (um *UnixMilli) UnmarshalJSON(data []byte) error {
return parseTime(data, time.UnixMilli, &um.Time) return parseTime(data, time.UnixMilli, &um.Time)
} }
func (um UnixMilli) Value() (driver.Value, error) {
return um.UnixMilli(), nil
}
func (um *UnixMilli) Scan(src any) error {
return anyIntegerToTime(src, time.UnixMilli, &um.Time)
}
var _ sql.Scanner = &UnixMicro{}
var _ driver.Valuer = UnixMicro{}
type UnixMicro struct { type UnixMicro struct {
time.Time time.Time
} }
@@ -55,6 +94,17 @@ func (um *UnixMicro) UnmarshalJSON(data []byte) error {
return parseTime(data, time.UnixMicro, &um.Time) return parseTime(data, time.UnixMicro, &um.Time)
} }
func (um UnixMicro) Value() (driver.Value, error) {
return um.UnixMicro(), nil
}
func (um *UnixMicro) Scan(src any) error {
return anyIntegerToTime(src, time.UnixMicro, &um.Time)
}
var _ sql.Scanner = &UnixNano{}
var _ driver.Valuer = UnixNano{}
type UnixNano struct { type UnixNano struct {
time.Time time.Time
} }
@@ -72,6 +122,16 @@ func (un *UnixNano) UnmarshalJSON(data []byte) error {
}, &un.Time) }, &un.Time)
} }
func (un UnixNano) Value() (driver.Value, error) {
return un.UnixNano(), nil
}
func (un *UnixNano) Scan(src any) error {
return anyIntegerToTime(src, func(i int64) time.Time {
return time.Unix(0, i)
}, &un.Time)
}
type Unix struct { type Unix struct {
time.Time time.Time
} }
@@ -83,8 +143,21 @@ func (u Unix) MarshalJSON() ([]byte, error) {
return json.Marshal(u.Unix()) return json.Marshal(u.Unix())
} }
var _ sql.Scanner = &Unix{}
var _ driver.Valuer = Unix{}
func (u *Unix) UnmarshalJSON(data []byte) error { func (u *Unix) UnmarshalJSON(data []byte) error {
return parseTime(data, func(i int64) time.Time { return parseTime(data, func(i int64) time.Time {
return time.Unix(i, 0) return time.Unix(i, 0)
}, &u.Time) }, &u.Time)
} }
func (u Unix) Value() (driver.Value, error) {
return u.Unix(), nil
}
func (u *Unix) Scan(src any) error {
return anyIntegerToTime(src, func(i int64) time.Time {
return time.Unix(i, 0)
}, &u.Time)
}

View File

@@ -199,8 +199,8 @@ TEXT ·mixBlocksSSE2(SB), 4, $0-32
MOVQ out+0(FP), DX MOVQ out+0(FP), DX
MOVQ a+8(FP), AX MOVQ a+8(FP), AX
MOVQ b+16(FP), BX MOVQ b+16(FP), BX
MOVQ a+24(FP), CX MOVQ c+24(FP), CX
MOVQ $128, BP MOVQ $128, DI
loop: loop:
MOVOU 0(AX), X0 MOVOU 0(AX), X0
@@ -213,7 +213,7 @@ loop:
ADDQ $16, BX ADDQ $16, BX
ADDQ $16, CX ADDQ $16, CX
ADDQ $16, DX ADDQ $16, DX
SUBQ $2, BP SUBQ $2, DI
JA loop JA loop
RET RET
@@ -222,8 +222,8 @@ TEXT ·xorBlocksSSE2(SB), 4, $0-32
MOVQ out+0(FP), DX MOVQ out+0(FP), DX
MOVQ a+8(FP), AX MOVQ a+8(FP), AX
MOVQ b+16(FP), BX MOVQ b+16(FP), BX
MOVQ a+24(FP), CX MOVQ c+24(FP), CX
MOVQ $128, BP MOVQ $128, DI
loop: loop:
MOVOU 0(AX), X0 MOVOU 0(AX), X0
@@ -238,6 +238,6 @@ loop:
ADDQ $16, BX ADDQ $16, BX
ADDQ $16, CX ADDQ $16, CX
ADDQ $16, DX ADDQ $16, DX
SUBQ $2, BP SUBQ $2, DI
JA loop JA loop
RET RET

View File

@@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style // Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
//go:build go1.7 && amd64 && gc && !purego //go:build amd64 && gc && !purego
package blake2b package blake2b

View File

@@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style // Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
//go:build go1.7 && amd64 && gc && !purego //go:build amd64 && gc && !purego
#include "textflag.h" #include "textflag.h"

View File

@@ -1,24 +0,0 @@
// Copyright 2016 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build !go1.7 && amd64 && gc && !purego
package blake2b
import "golang.org/x/sys/cpu"
func init() {
useSSE4 = cpu.X86.HasSSE41
}
//go:noescape
func hashBlocksSSE4(h *[8]uint64, c *[2]uint64, flag uint64, blocks []byte)
func hashBlocks(h *[8]uint64, c *[2]uint64, flag uint64, blocks []byte) {
if useSSE4 {
hashBlocksSSE4(h, c, flag, blocks)
} else {
hashBlocksGeneric(h, c, flag, blocks)
}
}

View File

@@ -2,8 +2,6 @@
// Use of this source code is governed by a BSD-style // Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
//go:build go1.9
package blake2b package blake2b
import ( import (

View File

@@ -1,39 +0,0 @@
// Copyright 2019 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build !go1.13
package poly1305
// Generic fallbacks for the math/bits intrinsics, copied from
// src/math/bits/bits.go. They were added in Go 1.12, but Add64 and Sum64 had
// variable time fallbacks until Go 1.13.
func bitsAdd64(x, y, carry uint64) (sum, carryOut uint64) {
sum = x + y + carry
carryOut = ((x & y) | ((x | y) &^ sum)) >> 63
return
}
func bitsSub64(x, y, borrow uint64) (diff, borrowOut uint64) {
diff = x - y - borrow
borrowOut = ((^x & y) | (^(x ^ y) & diff)) >> 63
return
}
func bitsMul64(x, y uint64) (hi, lo uint64) {
const mask32 = 1<<32 - 1
x0 := x & mask32
x1 := x >> 32
y0 := y & mask32
y1 := y >> 32
w0 := x0 * y0
t := x1*y0 + w0>>32
w1 := t & mask32
w2 := t >> 32
w1 += x0 * y1
hi = x1*y1 + w2 + w1>>32
lo = x * y
return
}

View File

@@ -1,21 +0,0 @@
// Copyright 2019 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build go1.13
package poly1305
import "math/bits"
func bitsAdd64(x, y, carry uint64) (sum, carryOut uint64) {
return bits.Add64(x, y, carry)
}
func bitsSub64(x, y, borrow uint64) (diff, borrowOut uint64) {
return bits.Sub64(x, y, borrow)
}
func bitsMul64(x, y uint64) (hi, lo uint64) {
return bits.Mul64(x, y)
}

View File

@@ -7,7 +7,10 @@
package poly1305 package poly1305
import "encoding/binary" import (
"encoding/binary"
"math/bits"
)
// Poly1305 [RFC 7539] is a relatively simple algorithm: the authentication tag // Poly1305 [RFC 7539] is a relatively simple algorithm: the authentication tag
// for a 64 bytes message is approximately // for a 64 bytes message is approximately
@@ -114,13 +117,13 @@ type uint128 struct {
} }
func mul64(a, b uint64) uint128 { func mul64(a, b uint64) uint128 {
hi, lo := bitsMul64(a, b) hi, lo := bits.Mul64(a, b)
return uint128{lo, hi} return uint128{lo, hi}
} }
func add128(a, b uint128) uint128 { func add128(a, b uint128) uint128 {
lo, c := bitsAdd64(a.lo, b.lo, 0) lo, c := bits.Add64(a.lo, b.lo, 0)
hi, c := bitsAdd64(a.hi, b.hi, c) hi, c := bits.Add64(a.hi, b.hi, c)
if c != 0 { if c != 0 {
panic("poly1305: unexpected overflow") panic("poly1305: unexpected overflow")
} }
@@ -155,8 +158,8 @@ func updateGeneric(state *macState, msg []byte) {
// hide leading zeroes. For full chunks, that's 1 << 128, so we can just // hide leading zeroes. For full chunks, that's 1 << 128, so we can just
// add 1 to the most significant (2¹²⁸) limb, h2. // add 1 to the most significant (2¹²⁸) limb, h2.
if len(msg) >= TagSize { if len(msg) >= TagSize {
h0, c = bitsAdd64(h0, binary.LittleEndian.Uint64(msg[0:8]), 0) h0, c = bits.Add64(h0, binary.LittleEndian.Uint64(msg[0:8]), 0)
h1, c = bitsAdd64(h1, binary.LittleEndian.Uint64(msg[8:16]), c) h1, c = bits.Add64(h1, binary.LittleEndian.Uint64(msg[8:16]), c)
h2 += c + 1 h2 += c + 1
msg = msg[TagSize:] msg = msg[TagSize:]
@@ -165,8 +168,8 @@ func updateGeneric(state *macState, msg []byte) {
copy(buf[:], msg) copy(buf[:], msg)
buf[len(msg)] = 1 buf[len(msg)] = 1
h0, c = bitsAdd64(h0, binary.LittleEndian.Uint64(buf[0:8]), 0) h0, c = bits.Add64(h0, binary.LittleEndian.Uint64(buf[0:8]), 0)
h1, c = bitsAdd64(h1, binary.LittleEndian.Uint64(buf[8:16]), c) h1, c = bits.Add64(h1, binary.LittleEndian.Uint64(buf[8:16]), c)
h2 += c h2 += c
msg = nil msg = nil
@@ -219,9 +222,9 @@ func updateGeneric(state *macState, msg []byte) {
m3 := h2r1 m3 := h2r1
t0 := m0.lo t0 := m0.lo
t1, c := bitsAdd64(m1.lo, m0.hi, 0) t1, c := bits.Add64(m1.lo, m0.hi, 0)
t2, c := bitsAdd64(m2.lo, m1.hi, c) t2, c := bits.Add64(m2.lo, m1.hi, c)
t3, _ := bitsAdd64(m3.lo, m2.hi, c) t3, _ := bits.Add64(m3.lo, m2.hi, c)
// Now we have the result as 4 64-bit limbs, and we need to reduce it // Now we have the result as 4 64-bit limbs, and we need to reduce it
// modulo 2¹³⁰ - 5. The special shape of this Crandall prime lets us do // modulo 2¹³⁰ - 5. The special shape of this Crandall prime lets us do
@@ -243,14 +246,14 @@ func updateGeneric(state *macState, msg []byte) {
// To add c * 5 to h, we first add cc = c * 4, and then add (cc >> 2) = c. // To add c * 5 to h, we first add cc = c * 4, and then add (cc >> 2) = c.
h0, c = bitsAdd64(h0, cc.lo, 0) h0, c = bits.Add64(h0, cc.lo, 0)
h1, c = bitsAdd64(h1, cc.hi, c) h1, c = bits.Add64(h1, cc.hi, c)
h2 += c h2 += c
cc = shiftRightBy2(cc) cc = shiftRightBy2(cc)
h0, c = bitsAdd64(h0, cc.lo, 0) h0, c = bits.Add64(h0, cc.lo, 0)
h1, c = bitsAdd64(h1, cc.hi, c) h1, c = bits.Add64(h1, cc.hi, c)
h2 += c h2 += c
// h2 is at most 3 + 1 + 1 = 5, making the whole of h at most // h2 is at most 3 + 1 + 1 = 5, making the whole of h at most
@@ -287,9 +290,9 @@ func finalize(out *[TagSize]byte, h *[3]uint64, s *[2]uint64) {
// in constant time, we compute t = h - (2¹³⁰ - 5), and select h as the // in constant time, we compute t = h - (2¹³⁰ - 5), and select h as the
// result if the subtraction underflows, and t otherwise. // result if the subtraction underflows, and t otherwise.
hMinusP0, b := bitsSub64(h0, p0, 0) hMinusP0, b := bits.Sub64(h0, p0, 0)
hMinusP1, b := bitsSub64(h1, p1, b) hMinusP1, b := bits.Sub64(h1, p1, b)
_, b = bitsSub64(h2, p2, b) _, b = bits.Sub64(h2, p2, b)
// h = h if h < p else h - p // h = h if h < p else h - p
h0 = select64(b, h0, hMinusP0) h0 = select64(b, h0, hMinusP0)
@@ -301,8 +304,8 @@ func finalize(out *[TagSize]byte, h *[3]uint64, s *[2]uint64) {
// //
// by just doing a wide addition with the 128 low bits of h and discarding // by just doing a wide addition with the 128 low bits of h and discarding
// the overflow. // the overflow.
h0, c := bitsAdd64(h0, s[0], 0) h0, c := bits.Add64(h0, s[0], 0)
h1, _ = bitsAdd64(h1, s[1], c) h1, _ = bits.Add64(h1, s[1], c)
binary.LittleEndian.PutUint64(out[0:8], h0) binary.LittleEndian.PutUint64(out[0:8], h0)
binary.LittleEndian.PutUint64(out[8:16], h1) binary.LittleEndian.PutUint64(out[8:16], h1)

View File

@@ -187,9 +187,11 @@ type channel struct {
pending *buffer pending *buffer
extPending *buffer extPending *buffer
// windowMu protects myWindow, the flow-control window. // windowMu protects myWindow, the flow-control window, and myConsumed,
// the number of bytes consumed since we last increased myWindow
windowMu sync.Mutex windowMu sync.Mutex
myWindow uint32 myWindow uint32
myConsumed uint32
// writeMu serializes calls to mux.conn.writePacket() and // writeMu serializes calls to mux.conn.writePacket() and
// protects sentClose and packetPool. This mutex must be // protects sentClose and packetPool. This mutex must be
@@ -332,14 +334,24 @@ func (ch *channel) handleData(packet []byte) error {
return nil return nil
} }
func (c *channel) adjustWindow(n uint32) error { func (c *channel) adjustWindow(adj uint32) error {
c.windowMu.Lock() c.windowMu.Lock()
// Since myWindow is managed on our side, and can never exceed // Since myConsumed and myWindow are managed on our side, and can never
// the initial window setting, we don't worry about overflow. // exceed the initial window setting, we don't worry about overflow.
c.myWindow += uint32(n) c.myConsumed += adj
var sendAdj uint32
if (channelWindowSize-c.myWindow > 3*c.maxIncomingPayload) ||
(c.myWindow < channelWindowSize/2) {
sendAdj = c.myConsumed
c.myConsumed = 0
c.myWindow += sendAdj
}
c.windowMu.Unlock() c.windowMu.Unlock()
if sendAdj == 0 {
return nil
}
return c.sendMessage(windowAdjustMsg{ return c.sendMessage(windowAdjustMsg{
AdditionalBytes: uint32(n), AdditionalBytes: sendAdj,
}) })
} }

View File

@@ -82,7 +82,7 @@ func NewClientConn(c net.Conn, addr string, config *ClientConfig) (Conn, <-chan
if err := conn.clientHandshake(addr, &fullConf); err != nil { if err := conn.clientHandshake(addr, &fullConf); err != nil {
c.Close() c.Close()
return nil, nil, nil, fmt.Errorf("ssh: handshake failed: %v", err) return nil, nil, nil, fmt.Errorf("ssh: handshake failed: %w", err)
} }
conn.mux = newMux(conn.transport) conn.mux = newMux(conn.transport)
return conn, conn.mux.incomingChannels, conn.mux.incomingRequests, nil return conn, conn.mux.incomingChannels, conn.mux.incomingRequests, nil

View File

@@ -307,7 +307,10 @@ func (cb publicKeyCallback) auth(session []byte, user string, c packetConn, rand
} }
var methods []string var methods []string
var errSigAlgo error var errSigAlgo error
for _, signer := range signers {
origSignersLen := len(signers)
for idx := 0; idx < len(signers); idx++ {
signer := signers[idx]
pub := signer.PublicKey() pub := signer.PublicKey()
as, algo, err := pickSignatureAlgorithm(signer, extensions) as, algo, err := pickSignatureAlgorithm(signer, extensions)
if err != nil && errSigAlgo == nil { if err != nil && errSigAlgo == nil {
@@ -321,6 +324,21 @@ func (cb publicKeyCallback) auth(session []byte, user string, c packetConn, rand
if err != nil { if err != nil {
return authFailure, nil, err return authFailure, nil, err
} }
// OpenSSH 7.2-7.7 advertises support for rsa-sha2-256 and rsa-sha2-512
// in the "server-sig-algs" extension but doesn't support these
// algorithms for certificate authentication, so if the server rejects
// the key try to use the obtained algorithm as if "server-sig-algs" had
// not been implemented if supported from the algorithm signer.
if !ok && idx < origSignersLen && isRSACert(algo) && algo != CertAlgoRSAv01 {
if contains(as.Algorithms(), KeyAlgoRSA) {
// We retry using the compat algorithm after all signers have
// been tried normally.
signers = append(signers, &multiAlgorithmSigner{
AlgorithmSigner: as,
supportedAlgorithms: []string{KeyAlgoRSA},
})
}
}
if !ok { if !ok {
continue continue
} }

View File

@@ -127,6 +127,14 @@ func isRSA(algo string) bool {
return contains(algos, underlyingAlgo(algo)) return contains(algos, underlyingAlgo(algo))
} }
func isRSACert(algo string) bool {
_, ok := certKeyAlgoNames[algo]
if !ok {
return false
}
return isRSA(algo)
}
// supportedPubKeyAuthAlgos specifies the supported client public key // supportedPubKeyAuthAlgos specifies the supported client public key
// authentication algorithms. Note that this doesn't include certificate types // authentication algorithms. Note that this doesn't include certificate types
// since those use the underlying algorithm. This list is sent to the client if // since those use the underlying algorithm. This list is sent to the client if

View File

@@ -35,6 +35,16 @@ type keyingTransport interface {
// direction will be effected if a msgNewKeys message is sent // direction will be effected if a msgNewKeys message is sent
// or received. // or received.
prepareKeyChange(*algorithms, *kexResult) error prepareKeyChange(*algorithms, *kexResult) error
// setStrictMode sets the strict KEX mode, notably triggering
// sequence number resets on sending or receiving msgNewKeys.
// If the sequence number is already > 1 when setStrictMode
// is called, an error is returned.
setStrictMode() error
// setInitialKEXDone indicates to the transport that the initial key exchange
// was completed
setInitialKEXDone()
} }
// handshakeTransport implements rekeying on top of a keyingTransport // handshakeTransport implements rekeying on top of a keyingTransport
@@ -100,6 +110,10 @@ type handshakeTransport struct {
// The session ID or nil if first kex did not complete yet. // The session ID or nil if first kex did not complete yet.
sessionID []byte sessionID []byte
// strictMode indicates if the other side of the handshake indicated
// that we should be following the strict KEX protocol restrictions.
strictMode bool
} }
type pendingKex struct { type pendingKex struct {
@@ -209,7 +223,10 @@ func (t *handshakeTransport) readLoop() {
close(t.incoming) close(t.incoming)
break break
} }
if p[0] == msgIgnore || p[0] == msgDebug { // If this is the first kex, and strict KEX mode is enabled,
// we don't ignore any messages, as they may be used to manipulate
// the packet sequence numbers.
if !(t.sessionID == nil && t.strictMode) && (p[0] == msgIgnore || p[0] == msgDebug) {
continue continue
} }
t.incoming <- p t.incoming <- p
@@ -441,6 +458,11 @@ func (t *handshakeTransport) readOnePacket(first bool) ([]byte, error) {
return successPacket, nil return successPacket, nil
} }
const (
kexStrictClient = "kex-strict-c-v00@openssh.com"
kexStrictServer = "kex-strict-s-v00@openssh.com"
)
// sendKexInit sends a key change message. // sendKexInit sends a key change message.
func (t *handshakeTransport) sendKexInit() error { func (t *handshakeTransport) sendKexInit() error {
t.mu.Lock() t.mu.Lock()
@@ -454,7 +476,6 @@ func (t *handshakeTransport) sendKexInit() error {
} }
msg := &kexInitMsg{ msg := &kexInitMsg{
KexAlgos: t.config.KeyExchanges,
CiphersClientServer: t.config.Ciphers, CiphersClientServer: t.config.Ciphers,
CiphersServerClient: t.config.Ciphers, CiphersServerClient: t.config.Ciphers,
MACsClientServer: t.config.MACs, MACsClientServer: t.config.MACs,
@@ -464,6 +485,13 @@ func (t *handshakeTransport) sendKexInit() error {
} }
io.ReadFull(rand.Reader, msg.Cookie[:]) io.ReadFull(rand.Reader, msg.Cookie[:])
// We mutate the KexAlgos slice, in order to add the kex-strict extension algorithm,
// and possibly to add the ext-info extension algorithm. Since the slice may be the
// user owned KeyExchanges, we create our own slice in order to avoid using user
// owned memory by mistake.
msg.KexAlgos = make([]string, 0, len(t.config.KeyExchanges)+2) // room for kex-strict and ext-info
msg.KexAlgos = append(msg.KexAlgos, t.config.KeyExchanges...)
isServer := len(t.hostKeys) > 0 isServer := len(t.hostKeys) > 0
if isServer { if isServer {
for _, k := range t.hostKeys { for _, k := range t.hostKeys {
@@ -488,17 +516,24 @@ func (t *handshakeTransport) sendKexInit() error {
msg.ServerHostKeyAlgos = append(msg.ServerHostKeyAlgos, keyFormat) msg.ServerHostKeyAlgos = append(msg.ServerHostKeyAlgos, keyFormat)
} }
} }
if t.sessionID == nil {
msg.KexAlgos = append(msg.KexAlgos, kexStrictServer)
}
} else { } else {
msg.ServerHostKeyAlgos = t.hostKeyAlgorithms msg.ServerHostKeyAlgos = t.hostKeyAlgorithms
// As a client we opt in to receiving SSH_MSG_EXT_INFO so we know what // As a client we opt in to receiving SSH_MSG_EXT_INFO so we know what
// algorithms the server supports for public key authentication. See RFC // algorithms the server supports for public key authentication. See RFC
// 8308, Section 2.1. // 8308, Section 2.1.
//
// We also send the strict KEX mode extension algorithm, in order to opt
// into the strict KEX mode.
if firstKeyExchange := t.sessionID == nil; firstKeyExchange { if firstKeyExchange := t.sessionID == nil; firstKeyExchange {
msg.KexAlgos = make([]string, 0, len(t.config.KeyExchanges)+1)
msg.KexAlgos = append(msg.KexAlgos, t.config.KeyExchanges...)
msg.KexAlgos = append(msg.KexAlgos, "ext-info-c") msg.KexAlgos = append(msg.KexAlgos, "ext-info-c")
msg.KexAlgos = append(msg.KexAlgos, kexStrictClient)
} }
} }
packet := Marshal(msg) packet := Marshal(msg)
@@ -604,6 +639,13 @@ func (t *handshakeTransport) enterKeyExchange(otherInitPacket []byte) error {
return err return err
} }
if t.sessionID == nil && ((isClient && contains(serverInit.KexAlgos, kexStrictServer)) || (!isClient && contains(clientInit.KexAlgos, kexStrictClient))) {
t.strictMode = true
if err := t.conn.setStrictMode(); err != nil {
return err
}
}
// We don't send FirstKexFollows, but we handle receiving it. // We don't send FirstKexFollows, but we handle receiving it.
// //
// RFC 4253 section 7 defines the kex and the agreement method for // RFC 4253 section 7 defines the kex and the agreement method for
@@ -679,6 +721,12 @@ func (t *handshakeTransport) enterKeyExchange(otherInitPacket []byte) error {
return unexpectedMessageError(msgNewKeys, packet[0]) return unexpectedMessageError(msgNewKeys, packet[0])
} }
if firstKeyExchange {
// Indicates to the transport that the first key exchange is completed
// after receiving SSH_MSG_NEWKEYS.
t.conn.setInitialKEXDone()
}
return nil return nil
} }

View File

@@ -213,6 +213,7 @@ func NewServerConn(c net.Conn, config *ServerConfig) (*ServerConn, <-chan NewCha
} else { } else {
for _, algo := range fullConf.PublicKeyAuthAlgorithms { for _, algo := range fullConf.PublicKeyAuthAlgorithms {
if !contains(supportedPubKeyAuthAlgos, algo) { if !contains(supportedPubKeyAuthAlgos, algo) {
c.Close()
return nil, nil, nil, fmt.Errorf("ssh: unsupported public key authentication algorithm %s", algo) return nil, nil, nil, fmt.Errorf("ssh: unsupported public key authentication algorithm %s", algo)
} }
} }
@@ -220,6 +221,7 @@ func NewServerConn(c net.Conn, config *ServerConfig) (*ServerConn, <-chan NewCha
// Check if the config contains any unsupported key exchanges // Check if the config contains any unsupported key exchanges
for _, kex := range fullConf.KeyExchanges { for _, kex := range fullConf.KeyExchanges {
if _, ok := serverForbiddenKexAlgos[kex]; ok { if _, ok := serverForbiddenKexAlgos[kex]; ok {
c.Close()
return nil, nil, nil, fmt.Errorf("ssh: unsupported key exchange %s for server", kex) return nil, nil, nil, fmt.Errorf("ssh: unsupported key exchange %s for server", kex)
} }
} }
@@ -337,7 +339,7 @@ func checkSourceAddress(addr net.Addr, sourceAddrs string) error {
return fmt.Errorf("ssh: remote address %v is not allowed because of source-address restriction", addr) return fmt.Errorf("ssh: remote address %v is not allowed because of source-address restriction", addr)
} }
func gssExchangeToken(gssapiConfig *GSSAPIWithMICConfig, firstToken []byte, s *connection, func gssExchangeToken(gssapiConfig *GSSAPIWithMICConfig, token []byte, s *connection,
sessionID []byte, userAuthReq userAuthRequestMsg) (authErr error, perms *Permissions, err error) { sessionID []byte, userAuthReq userAuthRequestMsg) (authErr error, perms *Permissions, err error) {
gssAPIServer := gssapiConfig.Server gssAPIServer := gssapiConfig.Server
defer gssAPIServer.DeleteSecContext() defer gssAPIServer.DeleteSecContext()
@@ -347,7 +349,7 @@ func gssExchangeToken(gssapiConfig *GSSAPIWithMICConfig, firstToken []byte, s *c
outToken []byte outToken []byte
needContinue bool needContinue bool
) )
outToken, srcName, needContinue, err = gssAPIServer.AcceptSecContext(firstToken) outToken, srcName, needContinue, err = gssAPIServer.AcceptSecContext(token)
if err != nil { if err != nil {
return err, nil, nil return err, nil, nil
} }
@@ -369,6 +371,7 @@ func gssExchangeToken(gssapiConfig *GSSAPIWithMICConfig, firstToken []byte, s *c
if err := Unmarshal(packet, userAuthGSSAPITokenReq); err != nil { if err := Unmarshal(packet, userAuthGSSAPITokenReq); err != nil {
return nil, nil, err return nil, nil, err
} }
token = userAuthGSSAPITokenReq.Token
} }
packet, err := s.transport.readPacket() packet, err := s.transport.readPacket()
if err != nil { if err != nil {

View File

@@ -5,6 +5,7 @@
package ssh package ssh
import ( import (
"context"
"errors" "errors"
"fmt" "fmt"
"io" "io"
@@ -332,6 +333,40 @@ func (l *tcpListener) Addr() net.Addr {
return l.laddr return l.laddr
} }
// DialContext initiates a connection to the addr from the remote host.
//
// The provided Context must be non-nil. If the context expires before the
// connection is complete, an error is returned. Once successfully connected,
// any expiration of the context will not affect the connection.
//
// See func Dial for additional information.
func (c *Client) DialContext(ctx context.Context, n, addr string) (net.Conn, error) {
if err := ctx.Err(); err != nil {
return nil, err
}
type connErr struct {
conn net.Conn
err error
}
ch := make(chan connErr)
go func() {
conn, err := c.Dial(n, addr)
select {
case ch <- connErr{conn, err}:
case <-ctx.Done():
if conn != nil {
conn.Close()
}
}
}()
select {
case res := <-ch:
return res.conn, res.err
case <-ctx.Done():
return nil, ctx.Err()
}
}
// Dial initiates a connection to the addr from the remote host. // Dial initiates a connection to the addr from the remote host.
// The resulting connection has a zero LocalAddr() and RemoteAddr(). // The resulting connection has a zero LocalAddr() and RemoteAddr().
func (c *Client) Dial(n, addr string) (net.Conn, error) { func (c *Client) Dial(n, addr string) (net.Conn, error) {

View File

@@ -49,6 +49,9 @@ type transport struct {
rand io.Reader rand io.Reader
isClient bool isClient bool
io.Closer io.Closer
strictMode bool
initialKEXDone bool
} }
// packetCipher represents a combination of SSH encryption/MAC // packetCipher represents a combination of SSH encryption/MAC
@@ -74,6 +77,18 @@ type connectionState struct {
pendingKeyChange chan packetCipher pendingKeyChange chan packetCipher
} }
func (t *transport) setStrictMode() error {
if t.reader.seqNum != 1 {
return errors.New("ssh: sequence number != 1 when strict KEX mode requested")
}
t.strictMode = true
return nil
}
func (t *transport) setInitialKEXDone() {
t.initialKEXDone = true
}
// prepareKeyChange sets up key material for a keychange. The key changes in // prepareKeyChange sets up key material for a keychange. The key changes in
// both directions are triggered by reading and writing a msgNewKey packet // both directions are triggered by reading and writing a msgNewKey packet
// respectively. // respectively.
@@ -112,11 +127,12 @@ func (t *transport) printPacket(p []byte, write bool) {
// Read and decrypt next packet. // Read and decrypt next packet.
func (t *transport) readPacket() (p []byte, err error) { func (t *transport) readPacket() (p []byte, err error) {
for { for {
p, err = t.reader.readPacket(t.bufReader) p, err = t.reader.readPacket(t.bufReader, t.strictMode)
if err != nil { if err != nil {
break break
} }
if len(p) == 0 || (p[0] != msgIgnore && p[0] != msgDebug) { // in strict mode we pass through DEBUG and IGNORE packets only during the initial KEX
if len(p) == 0 || (t.strictMode && !t.initialKEXDone) || (p[0] != msgIgnore && p[0] != msgDebug) {
break break
} }
} }
@@ -127,7 +143,7 @@ func (t *transport) readPacket() (p []byte, err error) {
return p, err return p, err
} }
func (s *connectionState) readPacket(r *bufio.Reader) ([]byte, error) { func (s *connectionState) readPacket(r *bufio.Reader, strictMode bool) ([]byte, error) {
packet, err := s.packetCipher.readCipherPacket(s.seqNum, r) packet, err := s.packetCipher.readCipherPacket(s.seqNum, r)
s.seqNum++ s.seqNum++
if err == nil && len(packet) == 0 { if err == nil && len(packet) == 0 {
@@ -140,6 +156,9 @@ func (s *connectionState) readPacket(r *bufio.Reader) ([]byte, error) {
select { select {
case cipher := <-s.pendingKeyChange: case cipher := <-s.pendingKeyChange:
s.packetCipher = cipher s.packetCipher = cipher
if strictMode {
s.seqNum = 0
}
default: default:
return nil, errors.New("ssh: got bogus newkeys message") return nil, errors.New("ssh: got bogus newkeys message")
} }
@@ -170,10 +189,10 @@ func (t *transport) writePacket(packet []byte) error {
if debugTransport { if debugTransport {
t.printPacket(packet, true) t.printPacket(packet, true)
} }
return t.writer.writePacket(t.bufWriter, t.rand, packet) return t.writer.writePacket(t.bufWriter, t.rand, packet, t.strictMode)
} }
func (s *connectionState) writePacket(w *bufio.Writer, rand io.Reader, packet []byte) error { func (s *connectionState) writePacket(w *bufio.Writer, rand io.Reader, packet []byte, strictMode bool) error {
changeKeys := len(packet) > 0 && packet[0] == msgNewKeys changeKeys := len(packet) > 0 && packet[0] == msgNewKeys
err := s.packetCipher.writeCipherPacket(s.seqNum, w, rand, packet) err := s.packetCipher.writeCipherPacket(s.seqNum, w, rand, packet)
@@ -188,6 +207,9 @@ func (s *connectionState) writePacket(w *bufio.Writer, rand io.Reader, packet []
select { select {
case cipher := <-s.pendingKeyChange: case cipher := <-s.pendingKeyChange:
s.packetCipher = cipher s.packetCipher = cipher
if strictMode {
s.seqNum = 0
}
default: default:
panic("ssh: no key material for msgNewKeys") panic("ssh: no key material for msgNewKeys")
} }

View File

@@ -209,25 +209,37 @@ func Insert[S ~[]E, E any](s S, i int, v ...E) S {
return s return s
} }
// Delete removes the elements s[i:j] from s, returning the modified slice. // clearSlice sets all elements up to the length of s to the zero value of E.
// Delete panics if s[i:j] is not a valid slice of s. // We may use the builtin clear func instead, and remove clearSlice, when upgrading
// Delete is O(len(s)-j), so if many items must be deleted, it is better to // to Go 1.21+.
// make a single call deleting them all together than to delete one at a time. func clearSlice[S ~[]E, E any](s S) {
// Delete might not modify the elements s[len(s)-(j-i):len(s)]. If those var zero E
// elements contain pointers you might consider zeroing those elements so that for i := range s {
// objects they reference can be garbage collected. s[i] = zero
func Delete[S ~[]E, E any](s S, i, j int) S { }
_ = s[i:j] // bounds check }
return append(s[:i], s[j:]...) // Delete removes the elements s[i:j] from s, returning the modified slice.
// Delete panics if j > len(s) or s[i:j] is not a valid slice of s.
// Delete is O(len(s)-i), so if many items must be deleted, it is better to
// make a single call deleting them all together than to delete one at a time.
// Delete zeroes the elements s[len(s)-(j-i):len(s)].
func Delete[S ~[]E, E any](s S, i, j int) S {
_ = s[i:j:len(s)] // bounds check
if i == j {
return s
}
oldlen := len(s)
s = append(s[:i], s[j:]...)
clearSlice(s[len(s):oldlen]) // zero/nil out the obsolete elements, for GC
return s
} }
// DeleteFunc removes any elements from s for which del returns true, // DeleteFunc removes any elements from s for which del returns true,
// returning the modified slice. // returning the modified slice.
// When DeleteFunc removes m elements, it might not modify the elements // DeleteFunc zeroes the elements between the new length and the original length.
// s[len(s)-m:len(s)]. If those elements contain pointers you might consider
// zeroing those elements so that objects they reference can be garbage
// collected.
func DeleteFunc[S ~[]E, E any](s S, del func(E) bool) S { func DeleteFunc[S ~[]E, E any](s S, del func(E) bool) S {
i := IndexFunc(s, del) i := IndexFunc(s, del)
if i == -1 { if i == -1 {
@@ -240,11 +252,13 @@ func DeleteFunc[S ~[]E, E any](s S, del func(E) bool) S {
i++ i++
} }
} }
clearSlice(s[i:]) // zero/nil out the obsolete elements, for GC
return s[:i] return s[:i]
} }
// Replace replaces the elements s[i:j] by the given v, and returns the // Replace replaces the elements s[i:j] by the given v, and returns the
// modified slice. Replace panics if s[i:j] is not a valid slice of s. // modified slice. Replace panics if s[i:j] is not a valid slice of s.
// When len(v) < (j-i), Replace zeroes the elements between the new length and the original length.
func Replace[S ~[]E, E any](s S, i, j int, v ...E) S { func Replace[S ~[]E, E any](s S, i, j int, v ...E) S {
_ = s[i:j] // verify that i:j is a valid subslice _ = s[i:j] // verify that i:j is a valid subslice
@@ -272,6 +286,7 @@ func Replace[S ~[]E, E any](s S, i, j int, v ...E) S {
if i+len(v) != j { if i+len(v) != j {
copy(r[i+len(v):], s[j:]) copy(r[i+len(v):], s[j:])
} }
clearSlice(s[tot:]) // zero/nil out the obsolete elements, for GC
return r return r
} }
@@ -345,9 +360,7 @@ func Clone[S ~[]E, E any](s S) S {
// This is like the uniq command found on Unix. // This is like the uniq command found on Unix.
// Compact modifies the contents of the slice s and returns the modified slice, // Compact modifies the contents of the slice s and returns the modified slice,
// which may have a smaller length. // which may have a smaller length.
// When Compact discards m elements in total, it might not modify the elements // Compact zeroes the elements between the new length and the original length.
// s[len(s)-m:len(s)]. If those elements contain pointers you might consider
// zeroing those elements so that objects they reference can be garbage collected.
func Compact[S ~[]E, E comparable](s S) S { func Compact[S ~[]E, E comparable](s S) S {
if len(s) < 2 { if len(s) < 2 {
return s return s
@@ -361,11 +374,13 @@ func Compact[S ~[]E, E comparable](s S) S {
i++ i++
} }
} }
clearSlice(s[i:]) // zero/nil out the obsolete elements, for GC
return s[:i] return s[:i]
} }
// CompactFunc is like [Compact] but uses an equality function to compare elements. // CompactFunc is like [Compact] but uses an equality function to compare elements.
// For runs of elements that compare equal, CompactFunc keeps the first one. // For runs of elements that compare equal, CompactFunc keeps the first one.
// CompactFunc zeroes the elements between the new length and the original length.
func CompactFunc[S ~[]E, E any](s S, eq func(E, E) bool) S { func CompactFunc[S ~[]E, E any](s S, eq func(E, E) bool) S {
if len(s) < 2 { if len(s) < 2 {
return s return s
@@ -379,6 +394,7 @@ func CompactFunc[S ~[]E, E any](s S, eq func(E, E) bool) S {
i++ i++
} }
} }
clearSlice(s[i:]) // zero/nil out the obsolete elements, for GC
return s[:i] return s[:i]
} }

View File

@@ -910,9 +910,6 @@ func (z *Tokenizer) readTagAttrKey() {
return return
} }
switch c { switch c {
case ' ', '\n', '\r', '\t', '\f', '/':
z.pendingAttr[0].end = z.raw.end - 1
return
case '=': case '=':
if z.pendingAttr[0].start+1 == z.raw.end { if z.pendingAttr[0].start+1 == z.raw.end {
// WHATWG 13.2.5.32, if we see an equals sign before the attribute name // WHATWG 13.2.5.32, if we see an equals sign before the attribute name
@@ -920,7 +917,9 @@ func (z *Tokenizer) readTagAttrKey() {
continue continue
} }
fallthrough fallthrough
case '>': case ' ', '\n', '\r', '\t', '\f', '/', '>':
// WHATWG 13.2.5.33 Attribute name state
// We need to reconsume the char in the after attribute name state to support the / character
z.raw.end-- z.raw.end--
z.pendingAttr[0].end = z.raw.end z.pendingAttr[0].end = z.raw.end
return return
@@ -939,6 +938,11 @@ func (z *Tokenizer) readTagAttrVal() {
if z.err != nil { if z.err != nil {
return return
} }
if c == '/' {
// WHATWG 13.2.5.34 After attribute name state
// U+002F SOLIDUS (/) - Switch to the self-closing start tag state.
return
}
if c != '=' { if c != '=' {
z.raw.end-- z.raw.end--
return return

View File

@@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style // Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file. // license that can be found in the LICENSE file.
//go:build dragonfly || freebsd || linux || netbsd || openbsd //go:build dragonfly || freebsd || linux || netbsd
package unix package unix

View File

@@ -231,3 +231,8 @@ func IoctlLoopGetStatus64(fd int) (*LoopInfo64, error) {
func IoctlLoopSetStatus64(fd int, value *LoopInfo64) error { func IoctlLoopSetStatus64(fd int, value *LoopInfo64) error {
return ioctlPtr(fd, LOOP_SET_STATUS64, unsafe.Pointer(value)) return ioctlPtr(fd, LOOP_SET_STATUS64, unsafe.Pointer(value))
} }
// IoctlLoopConfigure configures all loop device parameters in a single step
func IoctlLoopConfigure(fd int, value *LoopConfig) error {
return ioctlPtr(fd, LOOP_CONFIGURE, unsafe.Pointer(value))
}

View File

@@ -248,6 +248,7 @@ struct ltchars {
#include <linux/module.h> #include <linux/module.h>
#include <linux/mount.h> #include <linux/mount.h>
#include <linux/netfilter/nfnetlink.h> #include <linux/netfilter/nfnetlink.h>
#include <linux/netfilter/nf_tables.h>
#include <linux/netlink.h> #include <linux/netlink.h>
#include <linux/net_namespace.h> #include <linux/net_namespace.h>
#include <linux/nfc.h> #include <linux/nfc.h>
@@ -283,10 +284,6 @@ struct ltchars {
#include <asm/termbits.h> #include <asm/termbits.h>
#endif #endif
#ifndef MSG_FASTOPEN
#define MSG_FASTOPEN 0x20000000
#endif
#ifndef PTRACE_GETREGS #ifndef PTRACE_GETREGS
#define PTRACE_GETREGS 0xc #define PTRACE_GETREGS 0xc
#endif #endif
@@ -295,14 +292,6 @@ struct ltchars {
#define PTRACE_SETREGS 0xd #define PTRACE_SETREGS 0xd
#endif #endif
#ifndef SOL_NETLINK
#define SOL_NETLINK 270
#endif
#ifndef SOL_SMC
#define SOL_SMC 286
#endif
#ifdef SOL_BLUETOOTH #ifdef SOL_BLUETOOTH
// SPARC includes this in /usr/include/sparc64-linux-gnu/bits/socket.h // SPARC includes this in /usr/include/sparc64-linux-gnu/bits/socket.h
// but it is already in bluetooth_linux.go // but it is already in bluetooth_linux.go
@@ -319,10 +308,23 @@ struct ltchars {
#undef TIPC_WAIT_FOREVER #undef TIPC_WAIT_FOREVER
#define TIPC_WAIT_FOREVER 0xffffffff #define TIPC_WAIT_FOREVER 0xffffffff
// Copied from linux/l2tp.h // Copied from linux/netfilter/nf_nat.h
// Including linux/l2tp.h here causes conflicts between linux/in.h // Including linux/netfilter/nf_nat.h here causes conflicts between linux/in.h
// and netinet/in.h included via net/route.h above. // and netinet/in.h.
#define IPPROTO_L2TP 115 #define NF_NAT_RANGE_MAP_IPS (1 << 0)
#define NF_NAT_RANGE_PROTO_SPECIFIED (1 << 1)
#define NF_NAT_RANGE_PROTO_RANDOM (1 << 2)
#define NF_NAT_RANGE_PERSISTENT (1 << 3)
#define NF_NAT_RANGE_PROTO_RANDOM_FULLY (1 << 4)
#define NF_NAT_RANGE_PROTO_OFFSET (1 << 5)
#define NF_NAT_RANGE_NETMAP (1 << 6)
#define NF_NAT_RANGE_PROTO_RANDOM_ALL \
(NF_NAT_RANGE_PROTO_RANDOM | NF_NAT_RANGE_PROTO_RANDOM_FULLY)
#define NF_NAT_RANGE_MASK \
(NF_NAT_RANGE_MAP_IPS | NF_NAT_RANGE_PROTO_SPECIFIED | \
NF_NAT_RANGE_PROTO_RANDOM | NF_NAT_RANGE_PERSISTENT | \
NF_NAT_RANGE_PROTO_RANDOM_FULLY | NF_NAT_RANGE_PROTO_OFFSET | \
NF_NAT_RANGE_NETMAP)
// Copied from linux/hid.h. // Copied from linux/hid.h.
// Keep in sync with the size of the referenced fields. // Keep in sync with the size of the referenced fields.
@@ -519,6 +521,7 @@ ccflags="$@"
$2 ~ /^LOCK_(SH|EX|NB|UN)$/ || $2 ~ /^LOCK_(SH|EX|NB|UN)$/ ||
$2 ~ /^LO_(KEY|NAME)_SIZE$/ || $2 ~ /^LO_(KEY|NAME)_SIZE$/ ||
$2 ~ /^LOOP_(CLR|CTL|GET|SET)_/ || $2 ~ /^LOOP_(CLR|CTL|GET|SET)_/ ||
$2 == "LOOP_CONFIGURE" ||
$2 ~ /^(AF|SOCK|SO|SOL|IPPROTO|IP|IPV6|TCP|MCAST|EVFILT|NOTE|SHUT|PROT|MAP|MREMAP|MFD|T?PACKET|MSG|SCM|MCL|DT|MADV|PR|LOCAL|TCPOPT|UDP)_/ || $2 ~ /^(AF|SOCK|SO|SOL|IPPROTO|IP|IPV6|TCP|MCAST|EVFILT|NOTE|SHUT|PROT|MAP|MREMAP|MFD|T?PACKET|MSG|SCM|MCL|DT|MADV|PR|LOCAL|TCPOPT|UDP)_/ ||
$2 ~ /^NFC_(GENL|PROTO|COMM|RF|SE|DIRECTION|LLCP|SOCKPROTO)_/ || $2 ~ /^NFC_(GENL|PROTO|COMM|RF|SE|DIRECTION|LLCP|SOCKPROTO)_/ ||
$2 ~ /^NFC_.*_(MAX)?SIZE$/ || $2 ~ /^NFC_.*_(MAX)?SIZE$/ ||
@@ -560,7 +563,7 @@ ccflags="$@"
$2 ~ /^RLIMIT_(AS|CORE|CPU|DATA|FSIZE|LOCKS|MEMLOCK|MSGQUEUE|NICE|NOFILE|NPROC|RSS|RTPRIO|RTTIME|SIGPENDING|STACK)|RLIM_INFINITY/ || $2 ~ /^RLIMIT_(AS|CORE|CPU|DATA|FSIZE|LOCKS|MEMLOCK|MSGQUEUE|NICE|NOFILE|NPROC|RSS|RTPRIO|RTTIME|SIGPENDING|STACK)|RLIM_INFINITY/ ||
$2 ~ /^PRIO_(PROCESS|PGRP|USER)/ || $2 ~ /^PRIO_(PROCESS|PGRP|USER)/ ||
$2 ~ /^CLONE_[A-Z_]+/ || $2 ~ /^CLONE_[A-Z_]+/ ||
$2 !~ /^(BPF_TIMEVAL|BPF_FIB_LOOKUP_[A-Z]+)$/ && $2 !~ /^(BPF_TIMEVAL|BPF_FIB_LOOKUP_[A-Z]+|BPF_F_LINK)$/ &&
$2 ~ /^(BPF|DLT)_/ || $2 ~ /^(BPF|DLT)_/ ||
$2 ~ /^AUDIT_/ || $2 ~ /^AUDIT_/ ||
$2 ~ /^(CLOCK|TIMER)_/ || $2 ~ /^(CLOCK|TIMER)_/ ||
@@ -581,7 +584,7 @@ ccflags="$@"
$2 ~ /^KEY_(SPEC|REQKEY_DEFL)_/ || $2 ~ /^KEY_(SPEC|REQKEY_DEFL)_/ ||
$2 ~ /^KEYCTL_/ || $2 ~ /^KEYCTL_/ ||
$2 ~ /^PERF_/ || $2 ~ /^PERF_/ ||
$2 ~ /^SECCOMP_MODE_/ || $2 ~ /^SECCOMP_/ ||
$2 ~ /^SEEK_/ || $2 ~ /^SEEK_/ ||
$2 ~ /^SCHED_/ || $2 ~ /^SCHED_/ ||
$2 ~ /^SPLICE_/ || $2 ~ /^SPLICE_/ ||
@@ -602,6 +605,9 @@ ccflags="$@"
$2 ~ /^FSOPT_/ || $2 ~ /^FSOPT_/ ||
$2 ~ /^WDIO[CFS]_/ || $2 ~ /^WDIO[CFS]_/ ||
$2 ~ /^NFN/ || $2 ~ /^NFN/ ||
$2 !~ /^NFT_META_IIFTYPE/ &&
$2 ~ /^NFT_/ ||
$2 ~ /^NF_NAT_/ ||
$2 ~ /^XDP_/ || $2 ~ /^XDP_/ ||
$2 ~ /^RWF_/ || $2 ~ /^RWF_/ ||
$2 ~ /^(HDIO|WIN|SMART)_/ || $2 ~ /^(HDIO|WIN|SMART)_/ ||

View File

@@ -316,7 +316,7 @@ func GetsockoptString(fd, level, opt int) (string, error) {
if err != nil { if err != nil {
return "", err return "", err
} }
return string(buf[:vallen-1]), nil return ByteSliceToString(buf[:vallen]), nil
} }
//sys recvfrom(fd int, p []byte, flags int, from *RawSockaddrAny, fromlen *_Socklen) (n int, err error) //sys recvfrom(fd int, p []byte, flags int, from *RawSockaddrAny, fromlen *_Socklen) (n int, err error)

View File

@@ -61,16 +61,24 @@ func FanotifyMark(fd int, flags uint, mask uint64, dirFd int, pathname string) (
} }
//sys fchmodat(dirfd int, path string, mode uint32) (err error) //sys fchmodat(dirfd int, path string, mode uint32) (err error)
//sys fchmodat2(dirfd int, path string, mode uint32, flags int) (err error)
func Fchmodat(dirfd int, path string, mode uint32, flags int) (err error) { func Fchmodat(dirfd int, path string, mode uint32, flags int) error {
// Linux fchmodat doesn't support the flags parameter. Mimick glibc's behavior // Linux fchmodat doesn't support the flags parameter, but fchmodat2 does.
// and check the flags. Otherwise the mode would be applied to the symlink // Try fchmodat2 if flags are specified.
// destination which is not what the user expects. if flags != 0 {
if flags&^AT_SYMLINK_NOFOLLOW != 0 { err := fchmodat2(dirfd, path, mode, flags)
if err == ENOSYS {
// fchmodat2 isn't available. If the flags are known to be valid,
// return EOPNOTSUPP to indicate that fchmodat doesn't support them.
if flags&^(AT_SYMLINK_NOFOLLOW|AT_EMPTY_PATH) != 0 {
return EINVAL return EINVAL
} else if flags&AT_SYMLINK_NOFOLLOW != 0 { } else if flags&(AT_SYMLINK_NOFOLLOW|AT_EMPTY_PATH) != 0 {
return EOPNOTSUPP return EOPNOTSUPP
} }
}
return err
}
return fchmodat(dirfd, path, mode) return fchmodat(dirfd, path, mode)
} }
@@ -1302,7 +1310,7 @@ func GetsockoptString(fd, level, opt int) (string, error) {
return "", err return "", err
} }
} }
return string(buf[:vallen-1]), nil return ByteSliceToString(buf[:vallen]), nil
} }
func GetsockoptTpacketStats(fd, level, opt int) (*TpacketStats, error) { func GetsockoptTpacketStats(fd, level, opt int) (*TpacketStats, error) {

View File

@@ -166,6 +166,20 @@ func Getresgid() (rgid, egid, sgid int) {
//sys sysctl(mib []_C_int, old *byte, oldlen *uintptr, new *byte, newlen uintptr) (err error) = SYS___SYSCTL //sys sysctl(mib []_C_int, old *byte, oldlen *uintptr, new *byte, newlen uintptr) (err error) = SYS___SYSCTL
//sys fcntl(fd int, cmd int, arg int) (n int, err error)
//sys fcntlPtr(fd int, cmd int, arg unsafe.Pointer) (n int, err error) = SYS_FCNTL
// FcntlInt performs a fcntl syscall on fd with the provided command and argument.
func FcntlInt(fd uintptr, cmd, arg int) (int, error) {
return fcntl(int(fd), cmd, arg)
}
// FcntlFlock performs a fcntl syscall for the F_GETLK, F_SETLK or F_SETLKW command.
func FcntlFlock(fd uintptr, cmd int, lk *Flock_t) error {
_, err := fcntlPtr(int(fd), cmd, unsafe.Pointer(lk))
return err
}
//sys ppoll(fds *PollFd, nfds int, timeout *Timespec, sigmask *Sigset_t) (n int, err error) //sys ppoll(fds *PollFd, nfds int, timeout *Timespec, sigmask *Sigset_t) (n int, err error)
func Ppoll(fds []PollFd, timeout *Timespec, sigmask *Sigset_t) (n int, err error) { func Ppoll(fds []PollFd, timeout *Timespec, sigmask *Sigset_t) (n int, err error) {

View File

@@ -158,7 +158,7 @@ func GetsockoptString(fd, level, opt int) (string, error) {
if err != nil { if err != nil {
return "", err return "", err
} }
return string(buf[:vallen-1]), nil return ByteSliceToString(buf[:vallen]), nil
} }
const ImplementsGetwd = true const ImplementsGetwd = true

View File

@@ -1104,7 +1104,7 @@ func GetsockoptString(fd, level, opt int) (string, error) {
return "", err return "", err
} }
return string(buf[:vallen-1]), nil return ByteSliceToString(buf[:vallen]), nil
} }
func Recvmsg(fd int, p, oob []byte, flags int) (n, oobn int, recvflags int, from Sockaddr, err error) { func Recvmsg(fd int, p, oob []byte, flags int) (n, oobn int, recvflags int, from Sockaddr, err error) {

View File

@@ -486,7 +486,6 @@ const (
BPF_F_ANY_ALIGNMENT = 0x2 BPF_F_ANY_ALIGNMENT = 0x2
BPF_F_BEFORE = 0x8 BPF_F_BEFORE = 0x8
BPF_F_ID = 0x20 BPF_F_ID = 0x20
BPF_F_LINK = 0x2000
BPF_F_NETFILTER_IP_DEFRAG = 0x1 BPF_F_NETFILTER_IP_DEFRAG = 0x1
BPF_F_QUERY_EFFECTIVE = 0x1 BPF_F_QUERY_EFFECTIVE = 0x1
BPF_F_REPLACE = 0x4 BPF_F_REPLACE = 0x4
@@ -1786,6 +1785,8 @@ const (
LANDLOCK_ACCESS_FS_REMOVE_FILE = 0x20 LANDLOCK_ACCESS_FS_REMOVE_FILE = 0x20
LANDLOCK_ACCESS_FS_TRUNCATE = 0x4000 LANDLOCK_ACCESS_FS_TRUNCATE = 0x4000
LANDLOCK_ACCESS_FS_WRITE_FILE = 0x2 LANDLOCK_ACCESS_FS_WRITE_FILE = 0x2
LANDLOCK_ACCESS_NET_BIND_TCP = 0x1
LANDLOCK_ACCESS_NET_CONNECT_TCP = 0x2
LANDLOCK_CREATE_RULESET_VERSION = 0x1 LANDLOCK_CREATE_RULESET_VERSION = 0x1
LINUX_REBOOT_CMD_CAD_OFF = 0x0 LINUX_REBOOT_CMD_CAD_OFF = 0x0
LINUX_REBOOT_CMD_CAD_ON = 0x89abcdef LINUX_REBOOT_CMD_CAD_ON = 0x89abcdef
@@ -1802,6 +1803,7 @@ const (
LOCK_SH = 0x1 LOCK_SH = 0x1
LOCK_UN = 0x8 LOCK_UN = 0x8
LOOP_CLR_FD = 0x4c01 LOOP_CLR_FD = 0x4c01
LOOP_CONFIGURE = 0x4c0a
LOOP_CTL_ADD = 0x4c80 LOOP_CTL_ADD = 0x4c80
LOOP_CTL_GET_FREE = 0x4c82 LOOP_CTL_GET_FREE = 0x4c82
LOOP_CTL_REMOVE = 0x4c81 LOOP_CTL_REMOVE = 0x4c81
@@ -2127,6 +2129,60 @@ const (
NFNL_SUBSYS_QUEUE = 0x3 NFNL_SUBSYS_QUEUE = 0x3
NFNL_SUBSYS_ULOG = 0x4 NFNL_SUBSYS_ULOG = 0x4
NFS_SUPER_MAGIC = 0x6969 NFS_SUPER_MAGIC = 0x6969
NFT_CHAIN_FLAGS = 0x7
NFT_CHAIN_MAXNAMELEN = 0x100
NFT_CT_MAX = 0x17
NFT_DATA_RESERVED_MASK = 0xffffff00
NFT_DATA_VALUE_MAXLEN = 0x40
NFT_EXTHDR_OP_MAX = 0x4
NFT_FIB_RESULT_MAX = 0x3
NFT_INNER_MASK = 0xf
NFT_LOGLEVEL_MAX = 0x8
NFT_NAME_MAXLEN = 0x100
NFT_NG_MAX = 0x1
NFT_OBJECT_CONNLIMIT = 0x5
NFT_OBJECT_COUNTER = 0x1
NFT_OBJECT_CT_EXPECT = 0x9
NFT_OBJECT_CT_HELPER = 0x3
NFT_OBJECT_CT_TIMEOUT = 0x7
NFT_OBJECT_LIMIT = 0x4
NFT_OBJECT_MAX = 0xa
NFT_OBJECT_QUOTA = 0x2
NFT_OBJECT_SECMARK = 0x8
NFT_OBJECT_SYNPROXY = 0xa
NFT_OBJECT_TUNNEL = 0x6
NFT_OBJECT_UNSPEC = 0x0
NFT_OBJ_MAXNAMELEN = 0x100
NFT_OSF_MAXGENRELEN = 0x10
NFT_QUEUE_FLAG_BYPASS = 0x1
NFT_QUEUE_FLAG_CPU_FANOUT = 0x2
NFT_QUEUE_FLAG_MASK = 0x3
NFT_REG32_COUNT = 0x10
NFT_REG32_SIZE = 0x4
NFT_REG_MAX = 0x4
NFT_REG_SIZE = 0x10
NFT_REJECT_ICMPX_MAX = 0x3
NFT_RT_MAX = 0x4
NFT_SECMARK_CTX_MAXLEN = 0x100
NFT_SET_MAXNAMELEN = 0x100
NFT_SOCKET_MAX = 0x3
NFT_TABLE_F_MASK = 0x3
NFT_TABLE_MAXNAMELEN = 0x100
NFT_TRACETYPE_MAX = 0x3
NFT_TUNNEL_F_MASK = 0x7
NFT_TUNNEL_MAX = 0x1
NFT_TUNNEL_MODE_MAX = 0x2
NFT_USERDATA_MAXLEN = 0x100
NFT_XFRM_KEY_MAX = 0x6
NF_NAT_RANGE_MAP_IPS = 0x1
NF_NAT_RANGE_MASK = 0x7f
NF_NAT_RANGE_NETMAP = 0x40
NF_NAT_RANGE_PERSISTENT = 0x8
NF_NAT_RANGE_PROTO_OFFSET = 0x20
NF_NAT_RANGE_PROTO_RANDOM = 0x4
NF_NAT_RANGE_PROTO_RANDOM_ALL = 0x14
NF_NAT_RANGE_PROTO_RANDOM_FULLY = 0x10
NF_NAT_RANGE_PROTO_SPECIFIED = 0x2
NILFS_SUPER_MAGIC = 0x3434 NILFS_SUPER_MAGIC = 0x3434
NL0 = 0x0 NL0 = 0x0
NL1 = 0x100 NL1 = 0x100
@@ -2411,6 +2467,7 @@ const (
PR_MCE_KILL_GET = 0x22 PR_MCE_KILL_GET = 0x22
PR_MCE_KILL_LATE = 0x0 PR_MCE_KILL_LATE = 0x0
PR_MCE_KILL_SET = 0x1 PR_MCE_KILL_SET = 0x1
PR_MDWE_NO_INHERIT = 0x2
PR_MDWE_REFUSE_EXEC_GAIN = 0x1 PR_MDWE_REFUSE_EXEC_GAIN = 0x1
PR_MPX_DISABLE_MANAGEMENT = 0x2c PR_MPX_DISABLE_MANAGEMENT = 0x2c
PR_MPX_ENABLE_MANAGEMENT = 0x2b PR_MPX_ENABLE_MANAGEMENT = 0x2b
@@ -2615,8 +2672,9 @@ const (
RTAX_FEATURES = 0xc RTAX_FEATURES = 0xc
RTAX_FEATURE_ALLFRAG = 0x8 RTAX_FEATURE_ALLFRAG = 0x8
RTAX_FEATURE_ECN = 0x1 RTAX_FEATURE_ECN = 0x1
RTAX_FEATURE_MASK = 0xf RTAX_FEATURE_MASK = 0x1f
RTAX_FEATURE_SACK = 0x2 RTAX_FEATURE_SACK = 0x2
RTAX_FEATURE_TCP_USEC_TS = 0x10
RTAX_FEATURE_TIMESTAMP = 0x4 RTAX_FEATURE_TIMESTAMP = 0x4
RTAX_HOPLIMIT = 0xa RTAX_HOPLIMIT = 0xa
RTAX_INITCWND = 0xb RTAX_INITCWND = 0xb
@@ -2859,9 +2917,38 @@ const (
SCM_RIGHTS = 0x1 SCM_RIGHTS = 0x1
SCM_TIMESTAMP = 0x1d SCM_TIMESTAMP = 0x1d
SC_LOG_FLUSH = 0x100000 SC_LOG_FLUSH = 0x100000
SECCOMP_ADDFD_FLAG_SEND = 0x2
SECCOMP_ADDFD_FLAG_SETFD = 0x1
SECCOMP_FILTER_FLAG_LOG = 0x2
SECCOMP_FILTER_FLAG_NEW_LISTENER = 0x8
SECCOMP_FILTER_FLAG_SPEC_ALLOW = 0x4
SECCOMP_FILTER_FLAG_TSYNC = 0x1
SECCOMP_FILTER_FLAG_TSYNC_ESRCH = 0x10
SECCOMP_FILTER_FLAG_WAIT_KILLABLE_RECV = 0x20
SECCOMP_GET_ACTION_AVAIL = 0x2
SECCOMP_GET_NOTIF_SIZES = 0x3
SECCOMP_IOCTL_NOTIF_RECV = 0xc0502100
SECCOMP_IOCTL_NOTIF_SEND = 0xc0182101
SECCOMP_IOC_MAGIC = '!'
SECCOMP_MODE_DISABLED = 0x0 SECCOMP_MODE_DISABLED = 0x0
SECCOMP_MODE_FILTER = 0x2 SECCOMP_MODE_FILTER = 0x2
SECCOMP_MODE_STRICT = 0x1 SECCOMP_MODE_STRICT = 0x1
SECCOMP_RET_ACTION = 0x7fff0000
SECCOMP_RET_ACTION_FULL = 0xffff0000
SECCOMP_RET_ALLOW = 0x7fff0000
SECCOMP_RET_DATA = 0xffff
SECCOMP_RET_ERRNO = 0x50000
SECCOMP_RET_KILL = 0x0
SECCOMP_RET_KILL_PROCESS = 0x80000000
SECCOMP_RET_KILL_THREAD = 0x0
SECCOMP_RET_LOG = 0x7ffc0000
SECCOMP_RET_TRACE = 0x7ff00000
SECCOMP_RET_TRAP = 0x30000
SECCOMP_RET_USER_NOTIF = 0x7fc00000
SECCOMP_SET_MODE_FILTER = 0x1
SECCOMP_SET_MODE_STRICT = 0x0
SECCOMP_USER_NOTIF_FD_SYNC_WAKE_UP = 0x1
SECCOMP_USER_NOTIF_FLAG_CONTINUE = 0x1
SECRETMEM_MAGIC = 0x5345434d SECRETMEM_MAGIC = 0x5345434d
SECURITYFS_MAGIC = 0x73636673 SECURITYFS_MAGIC = 0x73636673
SEEK_CUR = 0x1 SEEK_CUR = 0x1
@@ -3021,6 +3108,7 @@ const (
SOL_TIPC = 0x10f SOL_TIPC = 0x10f
SOL_TLS = 0x11a SOL_TLS = 0x11a
SOL_UDP = 0x11 SOL_UDP = 0x11
SOL_VSOCK = 0x11f
SOL_X25 = 0x106 SOL_X25 = 0x106
SOL_XDP = 0x11b SOL_XDP = 0x11b
SOMAXCONN = 0x1000 SOMAXCONN = 0x1000

View File

@@ -281,6 +281,9 @@ const (
SCM_TIMESTAMPNS = 0x23 SCM_TIMESTAMPNS = 0x23
SCM_TXTIME = 0x3d SCM_TXTIME = 0x3d
SCM_WIFI_STATUS = 0x29 SCM_WIFI_STATUS = 0x29
SECCOMP_IOCTL_NOTIF_ADDFD = 0x40182103
SECCOMP_IOCTL_NOTIF_ID_VALID = 0x40082102
SECCOMP_IOCTL_NOTIF_SET_FLAGS = 0x40082104
SFD_CLOEXEC = 0x80000 SFD_CLOEXEC = 0x80000
SFD_NONBLOCK = 0x800 SFD_NONBLOCK = 0x800
SIOCATMARK = 0x8905 SIOCATMARK = 0x8905

View File

@@ -282,6 +282,9 @@ const (
SCM_TIMESTAMPNS = 0x23 SCM_TIMESTAMPNS = 0x23
SCM_TXTIME = 0x3d SCM_TXTIME = 0x3d
SCM_WIFI_STATUS = 0x29 SCM_WIFI_STATUS = 0x29
SECCOMP_IOCTL_NOTIF_ADDFD = 0x40182103
SECCOMP_IOCTL_NOTIF_ID_VALID = 0x40082102
SECCOMP_IOCTL_NOTIF_SET_FLAGS = 0x40082104
SFD_CLOEXEC = 0x80000 SFD_CLOEXEC = 0x80000
SFD_NONBLOCK = 0x800 SFD_NONBLOCK = 0x800
SIOCATMARK = 0x8905 SIOCATMARK = 0x8905

View File

@@ -288,6 +288,9 @@ const (
SCM_TIMESTAMPNS = 0x23 SCM_TIMESTAMPNS = 0x23
SCM_TXTIME = 0x3d SCM_TXTIME = 0x3d
SCM_WIFI_STATUS = 0x29 SCM_WIFI_STATUS = 0x29
SECCOMP_IOCTL_NOTIF_ADDFD = 0x40182103
SECCOMP_IOCTL_NOTIF_ID_VALID = 0x40082102
SECCOMP_IOCTL_NOTIF_SET_FLAGS = 0x40082104
SFD_CLOEXEC = 0x80000 SFD_CLOEXEC = 0x80000
SFD_NONBLOCK = 0x800 SFD_NONBLOCK = 0x800
SIOCATMARK = 0x8905 SIOCATMARK = 0x8905

View File

@@ -278,6 +278,9 @@ const (
SCM_TIMESTAMPNS = 0x23 SCM_TIMESTAMPNS = 0x23
SCM_TXTIME = 0x3d SCM_TXTIME = 0x3d
SCM_WIFI_STATUS = 0x29 SCM_WIFI_STATUS = 0x29
SECCOMP_IOCTL_NOTIF_ADDFD = 0x40182103
SECCOMP_IOCTL_NOTIF_ID_VALID = 0x40082102
SECCOMP_IOCTL_NOTIF_SET_FLAGS = 0x40082104
SFD_CLOEXEC = 0x80000 SFD_CLOEXEC = 0x80000
SFD_NONBLOCK = 0x800 SFD_NONBLOCK = 0x800
SIOCATMARK = 0x8905 SIOCATMARK = 0x8905

View File

@@ -275,6 +275,9 @@ const (
SCM_TIMESTAMPNS = 0x23 SCM_TIMESTAMPNS = 0x23
SCM_TXTIME = 0x3d SCM_TXTIME = 0x3d
SCM_WIFI_STATUS = 0x29 SCM_WIFI_STATUS = 0x29
SECCOMP_IOCTL_NOTIF_ADDFD = 0x40182103
SECCOMP_IOCTL_NOTIF_ID_VALID = 0x40082102
SECCOMP_IOCTL_NOTIF_SET_FLAGS = 0x40082104
SFD_CLOEXEC = 0x80000 SFD_CLOEXEC = 0x80000
SFD_NONBLOCK = 0x800 SFD_NONBLOCK = 0x800
SIOCATMARK = 0x8905 SIOCATMARK = 0x8905

View File

@@ -281,6 +281,9 @@ const (
SCM_TIMESTAMPNS = 0x23 SCM_TIMESTAMPNS = 0x23
SCM_TXTIME = 0x3d SCM_TXTIME = 0x3d
SCM_WIFI_STATUS = 0x29 SCM_WIFI_STATUS = 0x29
SECCOMP_IOCTL_NOTIF_ADDFD = 0x80182103
SECCOMP_IOCTL_NOTIF_ID_VALID = 0x80082102
SECCOMP_IOCTL_NOTIF_SET_FLAGS = 0x80082104
SFD_CLOEXEC = 0x80000 SFD_CLOEXEC = 0x80000
SFD_NONBLOCK = 0x80 SFD_NONBLOCK = 0x80
SIOCATMARK = 0x40047307 SIOCATMARK = 0x40047307

View File

@@ -281,6 +281,9 @@ const (
SCM_TIMESTAMPNS = 0x23 SCM_TIMESTAMPNS = 0x23
SCM_TXTIME = 0x3d SCM_TXTIME = 0x3d
SCM_WIFI_STATUS = 0x29 SCM_WIFI_STATUS = 0x29
SECCOMP_IOCTL_NOTIF_ADDFD = 0x80182103
SECCOMP_IOCTL_NOTIF_ID_VALID = 0x80082102
SECCOMP_IOCTL_NOTIF_SET_FLAGS = 0x80082104
SFD_CLOEXEC = 0x80000 SFD_CLOEXEC = 0x80000
SFD_NONBLOCK = 0x80 SFD_NONBLOCK = 0x80
SIOCATMARK = 0x40047307 SIOCATMARK = 0x40047307

View File

@@ -281,6 +281,9 @@ const (
SCM_TIMESTAMPNS = 0x23 SCM_TIMESTAMPNS = 0x23
SCM_TXTIME = 0x3d SCM_TXTIME = 0x3d
SCM_WIFI_STATUS = 0x29 SCM_WIFI_STATUS = 0x29
SECCOMP_IOCTL_NOTIF_ADDFD = 0x80182103
SECCOMP_IOCTL_NOTIF_ID_VALID = 0x80082102
SECCOMP_IOCTL_NOTIF_SET_FLAGS = 0x80082104
SFD_CLOEXEC = 0x80000 SFD_CLOEXEC = 0x80000
SFD_NONBLOCK = 0x80 SFD_NONBLOCK = 0x80
SIOCATMARK = 0x40047307 SIOCATMARK = 0x40047307

View File

@@ -281,6 +281,9 @@ const (
SCM_TIMESTAMPNS = 0x23 SCM_TIMESTAMPNS = 0x23
SCM_TXTIME = 0x3d SCM_TXTIME = 0x3d
SCM_WIFI_STATUS = 0x29 SCM_WIFI_STATUS = 0x29
SECCOMP_IOCTL_NOTIF_ADDFD = 0x80182103
SECCOMP_IOCTL_NOTIF_ID_VALID = 0x80082102
SECCOMP_IOCTL_NOTIF_SET_FLAGS = 0x80082104
SFD_CLOEXEC = 0x80000 SFD_CLOEXEC = 0x80000
SFD_NONBLOCK = 0x80 SFD_NONBLOCK = 0x80
SIOCATMARK = 0x40047307 SIOCATMARK = 0x40047307

View File

@@ -336,6 +336,9 @@ const (
SCM_TIMESTAMPNS = 0x23 SCM_TIMESTAMPNS = 0x23
SCM_TXTIME = 0x3d SCM_TXTIME = 0x3d
SCM_WIFI_STATUS = 0x29 SCM_WIFI_STATUS = 0x29
SECCOMP_IOCTL_NOTIF_ADDFD = 0x80182103
SECCOMP_IOCTL_NOTIF_ID_VALID = 0x80082102
SECCOMP_IOCTL_NOTIF_SET_FLAGS = 0x80082104
SFD_CLOEXEC = 0x80000 SFD_CLOEXEC = 0x80000
SFD_NONBLOCK = 0x800 SFD_NONBLOCK = 0x800
SIOCATMARK = 0x8905 SIOCATMARK = 0x8905

View File

@@ -340,6 +340,9 @@ const (
SCM_TIMESTAMPNS = 0x23 SCM_TIMESTAMPNS = 0x23
SCM_TXTIME = 0x3d SCM_TXTIME = 0x3d
SCM_WIFI_STATUS = 0x29 SCM_WIFI_STATUS = 0x29
SECCOMP_IOCTL_NOTIF_ADDFD = 0x80182103
SECCOMP_IOCTL_NOTIF_ID_VALID = 0x80082102
SECCOMP_IOCTL_NOTIF_SET_FLAGS = 0x80082104
SFD_CLOEXEC = 0x80000 SFD_CLOEXEC = 0x80000
SFD_NONBLOCK = 0x800 SFD_NONBLOCK = 0x800
SIOCATMARK = 0x8905 SIOCATMARK = 0x8905

Some files were not shown because too many files have changed in this diff Show More