refactor to mautrix 0.17.x; update deps

This commit is contained in:
Aine
2024-02-11 20:47:04 +02:00
parent 0a9701f4c9
commit dd0ad4c245
237 changed files with 9091 additions and 3317 deletions

View File

@@ -22,19 +22,19 @@ func parseMXIDpatterns(patterns []string, defaultPattern string) ([]*regexp.Rege
return mxidwc.ParsePatterns(patterns)
}
func (b *Bot) allowUsers(actorID id.UserID, targetRoomID id.RoomID) bool {
func (b *Bot) allowUsers(ctx context.Context, actorID id.UserID, targetRoomID id.RoomID) bool {
// first, check if it's an allowed user
if mxidwc.Match(actorID.String(), b.allowedUsers) {
return true
}
// second, check if it's an admin (admin may not fit the allowed users pattern)
if b.allowAdmin(actorID, targetRoomID) {
if b.allowAdmin(ctx, actorID, targetRoomID) {
return true
}
// then, check if it's the owner (same as above)
cfg, err := b.cfg.GetRoom(targetRoomID)
cfg, err := b.cfg.GetRoom(ctx, targetRoomID)
if err == nil && cfg.Owner() == actorID.String() {
return true
}
@@ -42,15 +42,15 @@ func (b *Bot) allowUsers(actorID id.UserID, targetRoomID id.RoomID) bool {
return false
}
func (b *Bot) allowAnyone(_ id.UserID, _ id.RoomID) bool {
func (b *Bot) allowAnyone(_ context.Context, _ id.UserID, _ id.RoomID) bool {
return true
}
func (b *Bot) allowOwner(actorID id.UserID, targetRoomID id.RoomID) bool {
if !b.allowUsers(actorID, targetRoomID) {
func (b *Bot) allowOwner(ctx context.Context, actorID id.UserID, targetRoomID id.RoomID) bool {
if !b.allowUsers(ctx, actorID, targetRoomID) {
return false
}
cfg, err := b.cfg.GetRoom(targetRoomID)
cfg, err := b.cfg.GetRoom(ctx, targetRoomID)
if err != nil {
b.Error(context.Background(), "failed to retrieve settings: %v", err)
return false
@@ -61,19 +61,19 @@ func (b *Bot) allowOwner(actorID id.UserID, targetRoomID id.RoomID) bool {
return true
}
return owner == actorID.String() || b.allowAdmin(actorID, targetRoomID)
return owner == actorID.String() || b.allowAdmin(ctx, actorID, targetRoomID)
}
func (b *Bot) allowAdmin(actorID id.UserID, _ id.RoomID) bool {
func (b *Bot) allowAdmin(_ context.Context, actorID id.UserID, _ id.RoomID) bool {
return mxidwc.Match(actorID.String(), b.allowedAdmins)
}
func (b *Bot) allowSend(actorID id.UserID, targetRoomID id.RoomID) bool {
if !b.allowUsers(actorID, targetRoomID) {
func (b *Bot) allowSend(ctx context.Context, actorID id.UserID, targetRoomID id.RoomID) bool {
if !b.allowUsers(ctx, actorID, targetRoomID) {
return false
}
cfg, err := b.cfg.GetRoom(targetRoomID)
cfg, err := b.cfg.GetRoom(ctx, targetRoomID)
if err != nil {
b.Error(context.Background(), "failed to retrieve settings: %v", err)
return false
@@ -82,14 +82,14 @@ func (b *Bot) allowSend(actorID id.UserID, targetRoomID id.RoomID) bool {
return !cfg.NoSend()
}
func (b *Bot) allowReply(actorID id.UserID, targetRoomID id.RoomID) bool {
if !b.allowUsers(actorID, targetRoomID) {
func (b *Bot) allowReply(ctx context.Context, actorID id.UserID, targetRoomID id.RoomID) bool {
if !b.allowUsers(ctx, actorID, targetRoomID) {
return false
}
cfg, err := b.cfg.GetRoom(targetRoomID)
cfg, err := b.cfg.GetRoom(ctx, targetRoomID)
if err != nil {
b.Error(context.Background(), "failed to retrieve settings: %v", err)
b.Error(ctx, "failed to retrieve settings: %v", err)
return false
}
@@ -106,30 +106,30 @@ func (b *Bot) isReserved(mailbox string) bool {
}
// IsGreylisted checks if host is in greylist
func (b *Bot) IsGreylisted(addr net.Addr) bool {
if b.cfg.GetBot().Greylist() == 0 {
func (b *Bot) IsGreylisted(ctx context.Context, addr net.Addr) bool {
if b.cfg.GetBot(ctx).Greylist() == 0 {
return false
}
greylist := b.cfg.GetGreylist()
greylist := b.cfg.GetGreylist(ctx)
greylistedAt, ok := greylist.Get(addr)
if !ok {
b.log.Debug().Str("addr", addr.String()).Msg("greylisting")
greylist.Add(addr)
err := b.cfg.SetGreylist(greylist)
err := b.cfg.SetGreylist(ctx, greylist)
if err != nil {
b.log.Error().Err(err).Str("addr", addr.String()).Msg("cannot update greylist")
}
return true
}
duration := time.Duration(b.cfg.GetBot().Greylist()) * time.Minute
duration := time.Duration(b.cfg.GetBot(ctx).Greylist()) * time.Minute
return greylistedAt.Add(duration).After(time.Now().UTC())
}
// IsBanned checks if address is banned
func (b *Bot) IsBanned(addr net.Addr) bool {
return b.cfg.GetBanlist().Has(addr)
func (b *Bot) IsBanned(ctx context.Context, addr net.Addr) bool {
return b.cfg.GetBanlist(ctx).Has(addr)
}
// IsTrusted checks if address is a trusted (proxy)
@@ -146,12 +146,12 @@ func (b *Bot) IsTrusted(addr net.Addr) bool {
}
// Ban an address automatically
func (b *Bot) BanAuto(addr net.Addr) {
if !b.cfg.GetBot().BanlistEnabled() {
func (b *Bot) BanAuto(ctx context.Context, addr net.Addr) {
if !b.cfg.GetBot(ctx).BanlistEnabled() {
return
}
if !b.cfg.GetBot().BanlistAuto() {
if !b.cfg.GetBot(ctx).BanlistAuto() {
return
}
@@ -159,21 +159,21 @@ func (b *Bot) BanAuto(addr net.Addr) {
return
}
b.log.Debug().Str("addr", addr.String()).Msg("attempting to automatically ban")
banlist := b.cfg.GetBanlist()
banlist := b.cfg.GetBanlist(ctx)
banlist.Add(addr)
err := b.cfg.SetBanlist(banlist)
err := b.cfg.SetBanlist(ctx, banlist)
if err != nil {
b.log.Error().Err(err).Str("addr", addr.String()).Msg("cannot update banlist")
}
}
// Ban an address for incorrect auth automatically
func (b *Bot) BanAuth(addr net.Addr) {
if !b.cfg.GetBot().BanlistEnabled() {
func (b *Bot) BanAuth(ctx context.Context, addr net.Addr) {
if !b.cfg.GetBot(ctx).BanlistEnabled() {
return
}
if !b.cfg.GetBot().BanlistAuth() {
if !b.cfg.GetBot(ctx).BanlistAuth() {
return
}
@@ -181,33 +181,33 @@ func (b *Bot) BanAuth(addr net.Addr) {
return
}
b.log.Debug().Str("addr", addr.String()).Msg("attempting to automatically ban")
banlist := b.cfg.GetBanlist()
banlist := b.cfg.GetBanlist(ctx)
banlist.Add(addr)
err := b.cfg.SetBanlist(banlist)
err := b.cfg.SetBanlist(ctx, banlist)
if err != nil {
b.log.Error().Err(err).Str("addr", addr.String()).Msg("cannot update banlist")
}
}
// Ban an address manually
func (b *Bot) BanManually(addr net.Addr) {
if !b.cfg.GetBot().BanlistEnabled() {
func (b *Bot) BanManually(ctx context.Context, addr net.Addr) {
if !b.cfg.GetBot(ctx).BanlistEnabled() {
return
}
if b.IsTrusted(addr) {
return
}
b.log.Debug().Str("addr", addr.String()).Msg("attempting to manually ban")
banlist := b.cfg.GetBanlist()
banlist := b.cfg.GetBanlist(ctx)
banlist.Add(addr)
err := b.cfg.SetBanlist(banlist)
err := b.cfg.SetBanlist(ctx, banlist)
if err != nil {
b.log.Error().Err(err).Str("addr", addr.String()).Msg("cannot update banlist")
}
}
// AllowAuth check if SMTP login (email) and password are valid
func (b *Bot) AllowAuth(email, password string) (id.RoomID, bool) {
func (b *Bot) AllowAuth(ctx context.Context, email, password string) (id.RoomID, bool) {
var suffix bool
for _, domain := range b.domains {
if strings.HasSuffix(email, "@"+domain) {
@@ -223,7 +223,7 @@ func (b *Bot) AllowAuth(email, password string) (id.RoomID, bool) {
if !ok {
return "", false
}
cfg, err := b.cfg.GetRoom(roomID)
cfg, err := b.cfg.GetRoom(ctx, roomID)
if err != nil {
b.log.Error().Err(err).Msg("failed to retrieve settings")
return "", false

View File

@@ -1,13 +1,14 @@
package bot
import (
"context"
"fmt"
"maunium.net/go/mautrix/format"
"maunium.net/go/mautrix/id"
)
type activationFlow func(id.UserID, id.RoomID, string) bool
type activationFlow func(context.Context, id.UserID, id.RoomID, string) bool
func (b *Bot) getActivationFlow() activationFlow {
switch b.mbxc.Activation {
@@ -21,19 +22,19 @@ func (b *Bot) getActivationFlow() activationFlow {
}
// ActivateMailbox using the configured flow
func (b *Bot) ActivateMailbox(ownerID id.UserID, roomID id.RoomID, mailbox string) bool {
func (b *Bot) ActivateMailbox(ctx context.Context, ownerID id.UserID, roomID id.RoomID, mailbox string) bool {
flow := b.getActivationFlow()
return flow(ownerID, roomID, mailbox)
return flow(ctx, ownerID, roomID, mailbox)
}
func (b *Bot) activateNone(ownerID id.UserID, roomID id.RoomID, mailbox string) bool {
func (b *Bot) activateNone(_ context.Context, ownerID id.UserID, roomID id.RoomID, mailbox string) bool {
b.log.Debug().Str("mailbox", mailbox).Str("roomID", roomID.String()).Str("ownerID", ownerID.String()).Msg("activating mailbox through the flow 'none'")
b.rooms.Store(mailbox, roomID)
return true
}
func (b *Bot) activateNotify(ownerID id.UserID, roomID id.RoomID, mailbox string) bool {
func (b *Bot) activateNotify(ctx context.Context, ownerID id.UserID, roomID id.RoomID, mailbox string) bool {
b.log.Debug().Str("mailbox", mailbox).Str("roomID", roomID.String()).Str("ownerID", ownerID.String()).Msg("activating mailbox through the flow 'notify'")
b.rooms.Store(mailbox, roomID)
if len(b.adminRooms) == 0 {
@@ -43,7 +44,7 @@ func (b *Bot) activateNotify(ownerID id.UserID, roomID id.RoomID, mailbox string
msg := fmt.Sprintf("Mailbox %q has been registered by %q for the room %q", mailbox, ownerID, roomID)
for _, adminRoom := range b.adminRooms {
content := format.RenderMarkdown(msg, true, true)
_, err := b.lp.Send(adminRoom, &content)
_, err := b.lp.Send(ctx, adminRoom, &content)
if err != nil {
b.log.Info().Str("adminRoom", adminRoom.String()).Msg("cannot send mailbox activation notification to the admin room")
continue

View File

@@ -36,6 +36,7 @@ type Bot struct {
rooms sync.Map
proxies []string
sendmail func(string, string, string) error
psd *utils.PSD
cfg *config.Manager
log *zerolog.Logger
lp *linkpearl.Linkpearl
@@ -50,6 +51,7 @@ func New(
lp *linkpearl.Linkpearl,
log *zerolog.Logger,
cfg *config.Manager,
psd *utils.PSD,
proxies []string,
prefix string,
domains []string,
@@ -63,13 +65,14 @@ func New(
adminRooms: []id.RoomID{},
proxies: proxies,
mbxc: mbxc,
psd: psd,
cfg: cfg,
log: log,
lp: lp,
mu: utils.NewMutex(),
q: q,
}
users, err := b.initBotUsers()
users, err := b.initBotUsers(context.Background())
if err != nil {
return nil, err
}
@@ -105,7 +108,7 @@ func (b *Bot) Error(ctx context.Context, message string, args ...any) {
}
var noThreads bool
cfg, cerr := b.cfg.GetRoom(evt.RoomID)
cfg, cerr := b.cfg.GetRoom(ctx, evt.RoomID)
if cerr == nil {
noThreads = cfg.NoThreads()
}
@@ -115,16 +118,17 @@ func (b *Bot) Error(ctx context.Context, message string, args ...any) {
relatesTo = linkpearl.RelatesTo(threadID, noThreads)
}
b.lp.SendNotice(evt.RoomID, "ERROR: "+err.Error(), relatesTo)
b.lp.SendNotice(ctx, evt.RoomID, "ERROR: "+err.Error(), relatesTo)
}
// Start performs matrix /sync
func (b *Bot) Start(statusMsg string) error {
if err := b.migrateMautrix015(); err != nil {
ctx := context.Background()
if err := b.migrateMautrix015(ctx); err != nil {
return err
}
if err := b.syncRooms(); err != nil {
if err := b.syncRooms(ctx); err != nil {
return err
}
@@ -135,7 +139,8 @@ func (b *Bot) Start(statusMsg string) error {
// Stop the bot
func (b *Bot) Stop() {
err := b.lp.GetClient().SetPresence(event.PresenceOffline)
ctx := context.Background()
err := b.lp.GetClient().SetPresence(ctx, event.PresenceOffline)
if err != nil {
b.log.Error().Err(err).Msg("cannot set presence = offline")
}

View File

@@ -46,7 +46,7 @@ type (
key string
description string
sanitizer func(string) string
allowed func(id.UserID, id.RoomID) bool
allowed func(context.Context, id.UserID, id.RoomID) bool
}
commandList []command
)
@@ -351,7 +351,7 @@ func (b *Bot) initCommands() commandList {
func (b *Bot) handle(ctx context.Context) {
evt := eventFromContext(ctx)
err := b.lp.GetClient().MarkRead(evt.RoomID, evt.ID)
err := b.lp.GetClient().MarkRead(ctx, evt.RoomID, evt.ID)
if err != nil {
b.log.Error().Err(err).Msg("cannot send read receipt")
}
@@ -378,14 +378,14 @@ func (b *Bot) handle(ctx context.Context) {
if cmd == nil {
return
}
_, err = b.lp.GetClient().UserTyping(evt.RoomID, true, 30*time.Second)
_, err = b.lp.GetClient().UserTyping(ctx, evt.RoomID, true, 30*time.Second)
if err != nil {
b.log.Error().Err(err).Msg("cannot send typing notification")
}
defer b.lp.GetClient().UserTyping(evt.RoomID, false, 30*time.Second) //nolint:errcheck // we don't care
defer b.lp.GetClient().UserTyping(ctx, evt.RoomID, false, 30*time.Second) //nolint:errcheck // we don't care
if !cmd.allowed(evt.Sender, evt.RoomID) {
b.lp.SendNotice(evt.RoomID, "not allowed to do that, kupo")
if !cmd.allowed(ctx, evt.Sender, evt.RoomID) {
b.lp.SendNotice(ctx, evt.RoomID, "not allowed to do that, kupo")
return
}
@@ -452,7 +452,7 @@ func (b *Bot) parseCommand(message string, toLower bool) []string {
return strings.Split(strings.TrimSpace(message), " ")
}
func (b *Bot) sendIntroduction(roomID id.RoomID) {
func (b *Bot) sendIntroduction(ctx context.Context, roomID id.RoomID) {
var msg strings.Builder
msg.WriteString("Hello, kupo!\n\n")
@@ -468,7 +468,7 @@ func (b *Bot) sendIntroduction(roomID id.RoomID) {
msg.WriteString(utils.EmailsList("SOME_INBOX", ""))
msg.WriteString("` and have them appear in this room.")
b.lp.SendNotice(roomID, msg.String())
b.lp.SendNotice(ctx, roomID, msg.String())
}
func (b *Bot) getHelpValue(cfg config.Room, cmd command) string {
@@ -497,7 +497,7 @@ func (b *Bot) getHelpValue(cfg config.Room, cmd command) string {
func (b *Bot) sendHelp(ctx context.Context) {
evt := eventFromContext(ctx)
cfg, serr := b.cfg.GetRoom(evt.RoomID)
cfg, serr := b.cfg.GetRoom(ctx, evt.RoomID)
if serr != nil {
b.log.Error().Err(serr).Msg("cannot retrieve settings")
}
@@ -505,7 +505,7 @@ func (b *Bot) sendHelp(ctx context.Context) {
var msg strings.Builder
msg.WriteString("The following commands are supported and accessible to you:\n\n")
for _, cmd := range b.commands {
if !cmd.allowed(evt.Sender, evt.RoomID) {
if !cmd.allowed(ctx, evt.Sender, evt.RoomID) {
continue
}
if cmd.key == "" {
@@ -528,7 +528,7 @@ func (b *Bot) sendHelp(ctx context.Context) {
msg.WriteString("\n")
}
b.lp.SendNotice(evt.RoomID, msg.String(), linkpearl.RelatesTo(evt.ID, cfg.NoThreads()))
b.lp.SendNotice(ctx, evt.RoomID, msg.String(), linkpearl.RelatesTo(evt.ID, cfg.NoThreads()))
}
func (b *Bot) runSend(ctx context.Context) {
@@ -538,7 +538,7 @@ func (b *Bot) runSend(ctx context.Context) {
return
}
cfg, err := b.cfg.GetRoom(evt.RoomID)
cfg, err := b.cfg.GetRoom(ctx, evt.RoomID)
if err != nil {
b.Error(ctx, "failed to retrieve room settings: %v", err)
return
@@ -555,11 +555,11 @@ func (b *Bot) runSend(ctx context.Context) {
func (b *Bot) getSendDetails(ctx context.Context) (to, subject, body string, ok bool) {
evt := eventFromContext(ctx)
if !b.allowSend(evt.Sender, evt.RoomID) {
if !b.allowSend(ctx, evt.Sender, evt.RoomID) {
return "", "", "", false
}
cfg, err := b.cfg.GetRoom(evt.RoomID)
cfg, err := b.cfg.GetRoom(ctx, evt.RoomID)
if err != nil {
b.Error(ctx, "failed to retrieve room settings: %v", err)
return "", "", "", false
@@ -568,7 +568,7 @@ func (b *Bot) getSendDetails(ctx context.Context) (to, subject, body string, ok
commandSlice := b.parseCommand(evt.Content.AsMessage().Body, false)
to, subject, body, err = utils.ParseSend(commandSlice)
if errors.Is(err, utils.ErrInvalidArgs) {
b.lp.SendNotice(evt.RoomID, fmt.Sprintf(
b.lp.SendNotice(ctx, evt.RoomID, fmt.Sprintf(
"Usage:\n"+
"```\n"+
"%s send someone@example.com\n"+
@@ -585,7 +585,7 @@ func (b *Bot) getSendDetails(ctx context.Context) (to, subject, body string, ok
mailbox := cfg.Mailbox()
if mailbox == "" {
b.lp.SendNotice(evt.RoomID, "mailbox is not configured, kupo", linkpearl.RelatesTo(evt.ID, cfg.NoThreads()))
b.lp.SendNotice(ctx, evt.RoomID, "mailbox is not configured, kupo", linkpearl.RelatesTo(evt.ID, cfg.NoThreads()))
return "", "", "", false
}
@@ -608,8 +608,8 @@ func (b *Bot) runSendCommand(ctx context.Context, cfg config.Room, tos []string,
}
}
b.lock(evt.RoomID, evt.ID)
defer b.unlock(evt.RoomID, evt.ID)
b.lock(ctx, evt.RoomID, evt.ID)
defer b.unlock(ctx, evt.RoomID, evt.ID)
domain := utils.SanitizeDomain(cfg.Domain())
from := cfg.Mailbox() + "@" + domain
@@ -617,12 +617,12 @@ func (b *Bot) runSendCommand(ctx context.Context, cfg config.Room, tos []string,
for _, to := range tos {
recipients := []string{to}
eml := email.New(ID, "", " "+ID, subject, from, to, to, "", body, htmlBody, nil, nil)
data := eml.Compose(b.cfg.GetBot().DKIMPrivateKey())
data := eml.Compose(b.cfg.GetBot(ctx).DKIMPrivateKey())
if data == "" {
b.lp.SendNotice(evt.RoomID, "email body is empty", linkpearl.RelatesTo(evt.ID, cfg.NoThreads()))
b.lp.SendNotice(ctx, evt.RoomID, "email body is empty", linkpearl.RelatesTo(evt.ID, cfg.NoThreads()))
return
}
queued, err := b.Sendmail(evt.ID, from, to, data)
queued, err := b.Sendmail(ctx, evt.ID, from, to, data)
if queued {
b.log.Warn().Err(err).Msg("email has been queued")
b.saveSentMetadata(ctx, queued, evt.ID, recipients, eml, cfg)
@@ -635,6 +635,6 @@ func (b *Bot) runSendCommand(ctx context.Context, cfg config.Room, tos []string,
b.saveSentMetadata(ctx, false, evt.ID, recipients, eml, cfg)
}
if len(tos) > 1 {
b.lp.SendNotice(evt.RoomID, "All emails were sent.", linkpearl.RelatesTo(evt.ID, cfg.NoThreads()))
b.lp.SendNotice(ctx, evt.RoomID, "All emails were sent.", linkpearl.RelatesTo(evt.ID, cfg.NoThreads()))
}
}

View File

@@ -37,7 +37,7 @@ func (b *Bot) sendMailboxes(ctx context.Context) {
if !ok {
return true
}
cfg, err := b.cfg.GetRoom(roomID)
cfg, err := b.cfg.GetRoom(ctx, roomID)
if err != nil {
b.log.Error().Err(err).Msg("cannot retrieve settings")
}
@@ -49,7 +49,7 @@ func (b *Bot) sendMailboxes(ctx context.Context) {
sort.Strings(slice)
if len(slice) == 0 {
b.lp.SendNotice(evt.RoomID, "No mailboxes are managed by the bot so far, kupo!", linkpearl.RelatesTo(evt.ID))
b.lp.SendNotice(ctx, evt.RoomID, "No mailboxes are managed by the bot so far, kupo!", linkpearl.RelatesTo(evt.ID))
return
}
@@ -64,20 +64,20 @@ func (b *Bot) sendMailboxes(ctx context.Context) {
msg.WriteString("\n")
}
b.lp.SendNotice(evt.RoomID, msg.String(), linkpearl.RelatesTo(evt.ID))
b.lp.SendNotice(ctx, evt.RoomID, msg.String(), linkpearl.RelatesTo(evt.ID))
}
func (b *Bot) runDelete(ctx context.Context, commandSlice []string) {
evt := eventFromContext(ctx)
if len(commandSlice) < 2 {
b.lp.SendNotice(evt.RoomID, fmt.Sprintf("Usage: `%s delete MAILBOX`", b.prefix), linkpearl.RelatesTo(evt.ID))
b.lp.SendNotice(ctx, evt.RoomID, fmt.Sprintf("Usage: `%s delete MAILBOX`", b.prefix), linkpearl.RelatesTo(evt.ID))
return
}
mailbox := utils.Mailbox(commandSlice[1])
v, ok := b.rooms.Load(mailbox)
if v == nil || !ok {
b.lp.SendNotice(evt.RoomID, "mailbox does not exists, kupo", linkpearl.RelatesTo(evt.ID))
b.lp.SendNotice(ctx, evt.RoomID, "mailbox does not exists, kupo", linkpearl.RelatesTo(evt.ID))
return
}
roomID, ok := v.(id.RoomID)
@@ -86,18 +86,18 @@ func (b *Bot) runDelete(ctx context.Context, commandSlice []string) {
}
b.rooms.Delete(mailbox)
err := b.cfg.SetRoom(roomID, config.Room{})
err := b.cfg.SetRoom(ctx, roomID, config.Room{})
if err != nil {
b.Error(ctx, "cannot update settings: %v", err)
return
}
b.lp.SendNotice(evt.RoomID, "mailbox has been deleted", linkpearl.RelatesTo(evt.ID))
b.lp.SendNotice(ctx, evt.RoomID, "mailbox has been deleted", linkpearl.RelatesTo(evt.ID))
}
func (b *Bot) runUsers(ctx context.Context, commandSlice []string) {
evt := eventFromContext(ctx)
cfg := b.cfg.GetBot()
cfg := b.cfg.GetBot(ctx)
if len(commandSlice) < 2 {
var msg strings.Builder
users := cfg.Users()
@@ -112,35 +112,35 @@ func (b *Bot) runUsers(ctx context.Context, commandSlice []string) {
msg.WriteString("where each pattern is like `@someone:example.com`, ")
msg.WriteString("`@bot.*:example.com`, `@*:another.com`, or `@*:*`\n")
b.lp.SendNotice(evt.RoomID, msg.String(), linkpearl.RelatesTo(evt.ID))
b.lp.SendNotice(ctx, evt.RoomID, msg.String(), linkpearl.RelatesTo(evt.ID))
return
}
_, homeserver, err := b.lp.GetClient().UserID.Parse()
if err != nil {
b.lp.SendNotice(evt.RoomID, fmt.Sprintf("invalid userID: %v", err), linkpearl.RelatesTo(evt.ID))
b.lp.SendNotice(ctx, evt.RoomID, fmt.Sprintf("invalid userID: %v", err), linkpearl.RelatesTo(evt.ID))
}
patterns := commandSlice[1:]
allowedUsers, err := parseMXIDpatterns(patterns, "@*:"+homeserver)
if err != nil {
b.lp.SendNotice(evt.RoomID, fmt.Sprintf("invalid patterns: %v", err), linkpearl.RelatesTo(evt.ID))
b.lp.SendNotice(ctx, evt.RoomID, fmt.Sprintf("invalid patterns: %v", err), linkpearl.RelatesTo(evt.ID))
return
}
cfg.Set(config.BotUsers, strings.Join(patterns, " "))
err = b.cfg.SetBot(cfg)
err = b.cfg.SetBot(ctx, cfg)
if err != nil {
b.Error(ctx, "cannot set bot config: %v", err)
}
b.allowedUsers = allowedUsers
b.lp.SendNotice(evt.RoomID, "allowed users updated", linkpearl.RelatesTo(evt.ID))
b.lp.SendNotice(ctx, evt.RoomID, "allowed users updated", linkpearl.RelatesTo(evt.ID))
}
func (b *Bot) runDKIM(ctx context.Context, commandSlice []string) {
evt := eventFromContext(ctx)
cfg := b.cfg.GetBot()
cfg := b.cfg.GetBot(ctx)
if len(commandSlice) > 1 && commandSlice[1] == "reset" {
cfg.Set(config.BotDKIMPrivateKey, "")
cfg.Set(config.BotDKIMSignature, "")
@@ -157,14 +157,14 @@ func (b *Bot) runDKIM(ctx context.Context, commandSlice []string) {
}
cfg.Set(config.BotDKIMSignature, signature)
cfg.Set(config.BotDKIMPrivateKey, private)
err := b.cfg.SetBot(cfg)
err := b.cfg.SetBot(ctx, cfg)
if err != nil {
b.Error(ctx, "cannot save bot options: %v", err)
return
}
}
b.lp.SendNotice(evt.RoomID, fmt.Sprintf(
b.lp.SendNotice(ctx, evt.RoomID, fmt.Sprintf(
"DKIM signature is: `%s`.\n"+
"You need to add it to DNS records of all domains added to postmoogle (if not already):\n"+
"Add new DNS record with type = `TXT`, key (subdomain/from): `postmoogle._domainkey` and value (to):\n ```\n%s\n```\n"+
@@ -177,7 +177,7 @@ func (b *Bot) runDKIM(ctx context.Context, commandSlice []string) {
func (b *Bot) runCatchAll(ctx context.Context, commandSlice []string) {
evt := eventFromContext(ctx)
cfg := b.cfg.GetBot()
cfg := b.cfg.GetBot(ctx)
if len(commandSlice) < 2 {
var msg strings.Builder
msg.WriteString("Currently: `")
@@ -195,30 +195,30 @@ func (b *Bot) runCatchAll(ctx context.Context, commandSlice []string) {
msg.WriteString(" catch-all MAILBOX`")
msg.WriteString("where mailbox is valid and existing mailbox name\n")
b.lp.SendNotice(evt.RoomID, msg.String(), linkpearl.RelatesTo(evt.ID))
b.lp.SendNotice(ctx, evt.RoomID, msg.String(), linkpearl.RelatesTo(evt.ID))
return
}
mailbox := utils.Mailbox(commandSlice[1])
_, ok := b.GetMapping(mailbox)
_, ok := b.GetMapping(ctx, mailbox)
if !ok {
b.lp.SendNotice(evt.RoomID, "mailbox does not exist, kupo.", linkpearl.RelatesTo(evt.ID))
b.lp.SendNotice(ctx, evt.RoomID, "mailbox does not exist, kupo.", linkpearl.RelatesTo(evt.ID))
return
}
cfg.Set(config.BotCatchAll, mailbox)
err := b.cfg.SetBot(cfg)
err := b.cfg.SetBot(ctx, cfg)
if err != nil {
b.Error(ctx, "cannot save bot options: %v", err)
return
}
b.lp.SendNotice(evt.RoomID, fmt.Sprintf("Catch-all is set to: `%s` (%s).", mailbox, utils.EmailsList(mailbox, "")), linkpearl.RelatesTo(evt.ID))
b.lp.SendNotice(ctx, evt.RoomID, fmt.Sprintf("Catch-all is set to: `%s` (%s).", mailbox, utils.EmailsList(mailbox, "")), linkpearl.RelatesTo(evt.ID))
}
func (b *Bot) runAdminRoom(ctx context.Context, commandSlice []string) {
evt := eventFromContext(ctx)
cfg := b.cfg.GetBot()
cfg := b.cfg.GetBot(ctx)
if len(commandSlice) < 2 {
var msg strings.Builder
msg.WriteString("Currently: `")
@@ -233,13 +233,13 @@ func (b *Bot) runAdminRoom(ctx context.Context, commandSlice []string) {
msg.WriteString(" adminroom ROOM_ID`")
msg.WriteString("where ROOM_ID is valid and existing matrix room id\n")
b.lp.SendNotice(evt.RoomID, msg.String(), linkpearl.RelatesTo(evt.ID))
b.lp.SendNotice(ctx, evt.RoomID, msg.String(), linkpearl.RelatesTo(evt.ID))
return
}
roomID := b.parseCommand(evt.Content.AsMessage().Body, false)[1] // get original value, without forced lower case
cfg.Set(config.BotAdminRoom, roomID)
err := b.cfg.SetBot(cfg)
err := b.cfg.SetBot(ctx, cfg)
if err != nil {
b.Error(ctx, "cannot save bot options: %v", err)
return
@@ -247,12 +247,12 @@ func (b *Bot) runAdminRoom(ctx context.Context, commandSlice []string) {
b.adminRooms = append([]id.RoomID{id.RoomID(roomID)}, b.adminRooms...) // make it the first room in list on the fly
b.lp.SendNotice(evt.RoomID, fmt.Sprintf("Admin Room is set to: `%s`.", roomID), linkpearl.RelatesTo(evt.ID))
b.lp.SendNotice(ctx, evt.RoomID, fmt.Sprintf("Admin Room is set to: `%s`.", roomID), linkpearl.RelatesTo(evt.ID))
}
func (b *Bot) printGreylist(ctx context.Context, roomID id.RoomID) {
cfg := b.cfg.GetBot()
greylist := b.cfg.GetGreylist()
cfg := b.cfg.GetBot(ctx)
greylist := b.cfg.GetGreylist(ctx)
var msg strings.Builder
size := len(greylist)
duration := cfg.Greylist()
@@ -278,7 +278,7 @@ func (b *Bot) printGreylist(ctx context.Context, roomID id.RoomID) {
msg.WriteString("where `MIN` is duration in minutes for automatic greylisting\n")
}
b.lp.SendNotice(roomID, msg.String(), linkpearl.RelatesTo(eventFromContext(ctx).ID))
b.lp.SendNotice(ctx, roomID, msg.String(), linkpearl.RelatesTo(eventFromContext(ctx).ID))
}
func (b *Bot) runGreylist(ctx context.Context, commandSlice []string) {
@@ -287,21 +287,21 @@ func (b *Bot) runGreylist(ctx context.Context, commandSlice []string) {
b.printGreylist(ctx, evt.RoomID)
return
}
cfg := b.cfg.GetBot()
cfg := b.cfg.GetBot(ctx)
value := utils.SanitizeIntString(commandSlice[1])
cfg.Set(config.BotGreylist, value)
err := b.cfg.SetBot(cfg)
err := b.cfg.SetBot(ctx, cfg)
if err != nil {
b.Error(ctx, "cannot set bot config: %v", err)
}
b.lp.SendNotice(evt.RoomID, "greylist duration has been updated", linkpearl.RelatesTo(evt.ID))
b.lp.SendNotice(ctx, evt.RoomID, "greylist duration has been updated", linkpearl.RelatesTo(evt.ID))
}
func (b *Bot) runBanlist(ctx context.Context, commandSlice []string) {
evt := eventFromContext(ctx)
cfg := b.cfg.GetBot()
cfg := b.cfg.GetBot(ctx)
if len(commandSlice) < 2 {
banlist := b.cfg.GetBanlist()
banlist := b.cfg.GetBanlist(ctx)
var msg strings.Builder
size := len(banlist)
if size > 0 {
@@ -322,26 +322,26 @@ func (b *Bot) runBanlist(ctx context.Context, commandSlice []string) {
msg.WriteString("where each ip is IPv4 or IPv6\n\n")
msg.WriteString("You can find current banlist values below:\n")
b.lp.SendNotice(evt.RoomID, msg.String(), linkpearl.RelatesTo(evt.ID))
b.lp.SendNotice(ctx, evt.RoomID, msg.String(), linkpearl.RelatesTo(evt.ID))
b.addBanlistTimeline(ctx, false)
return
}
value := utils.SanitizeBoolString(commandSlice[1])
cfg.Set(config.BotBanlistEnabled, value)
err := b.cfg.SetBot(cfg)
err := b.cfg.SetBot(ctx, cfg)
if err != nil {
b.Error(ctx, "cannot set bot config: %v", err)
}
b.lp.SendNotice(evt.RoomID, "banlist has been updated", linkpearl.RelatesTo(evt.ID))
b.lp.SendNotice(ctx, evt.RoomID, "banlist has been updated", linkpearl.RelatesTo(evt.ID))
}
func (b *Bot) runBanlistTotals(ctx context.Context) {
evt := eventFromContext(ctx)
banlist := b.cfg.GetBanlist()
banlist := b.cfg.GetBanlist(ctx)
var msg strings.Builder
size := len(banlist)
if size == 0 {
b.lp.SendNotice(evt.RoomID, "banlist is empty, kupo.", linkpearl.RelatesTo(evt.ID))
b.lp.SendNotice(ctx, evt.RoomID, "banlist is empty, kupo.", linkpearl.RelatesTo(evt.ID))
return
}
@@ -349,13 +349,13 @@ func (b *Bot) runBanlistTotals(ctx context.Context) {
msg.WriteString(strconv.Itoa(size))
msg.WriteString(" hosts banned\n\n")
msg.WriteString("You can find daily totals below:\n")
b.lp.SendNotice(evt.RoomID, msg.String(), linkpearl.RelatesTo(evt.ID))
b.lp.SendNotice(ctx, evt.RoomID, msg.String(), linkpearl.RelatesTo(evt.ID))
b.addBanlistTimeline(ctx, true)
}
func (b *Bot) runBanlistAuth(ctx context.Context, commandSlice []string) { //nolint:dupl // not in that case
evt := eventFromContext(ctx)
cfg := b.cfg.GetBot()
cfg := b.cfg.GetBot(ctx)
if len(commandSlice) < 2 {
var msg strings.Builder
msg.WriteString("Currently: `")
@@ -368,21 +368,21 @@ func (b *Bot) runBanlistAuth(ctx context.Context, commandSlice []string) { //nol
msg.WriteString(" banlist:auth true` (banlist itself must be enabled!)\n\n")
}
b.lp.SendNotice(evt.RoomID, msg.String(), linkpearl.RelatesTo(evt.ID))
b.lp.SendNotice(ctx, evt.RoomID, msg.String(), linkpearl.RelatesTo(evt.ID))
return
}
value := utils.SanitizeBoolString(commandSlice[1])
cfg.Set(config.BotBanlistAuth, value)
err := b.cfg.SetBot(cfg)
err := b.cfg.SetBot(ctx, cfg)
if err != nil {
b.Error(ctx, "cannot set bot config: %v", err)
}
b.lp.SendNotice(evt.RoomID, "auth banning has been updated", linkpearl.RelatesTo(evt.ID))
b.lp.SendNotice(ctx, evt.RoomID, "auth banning has been updated", linkpearl.RelatesTo(evt.ID))
}
func (b *Bot) runBanlistAuto(ctx context.Context, commandSlice []string) { //nolint:dupl // not in that case
evt := eventFromContext(ctx)
cfg := b.cfg.GetBot()
cfg := b.cfg.GetBot(ctx)
if len(commandSlice) < 2 {
var msg strings.Builder
msg.WriteString("Currently: `")
@@ -395,16 +395,16 @@ func (b *Bot) runBanlistAuto(ctx context.Context, commandSlice []string) { //nol
msg.WriteString(" banlist:auto true` (banlist itself must be enabled!)\n\n")
}
b.lp.SendNotice(evt.RoomID, msg.String(), linkpearl.RelatesTo(evt.ID))
b.lp.SendNotice(ctx, evt.RoomID, msg.String(), linkpearl.RelatesTo(evt.ID))
return
}
value := utils.SanitizeBoolString(commandSlice[1])
cfg.Set(config.BotBanlistAuto, value)
err := b.cfg.SetBot(cfg)
err := b.cfg.SetBot(ctx, cfg)
if err != nil {
b.Error(ctx, "cannot set bot config: %v", err)
}
b.lp.SendNotice(evt.RoomID, "auto banning has been updated", linkpearl.RelatesTo(evt.ID))
b.lp.SendNotice(ctx, evt.RoomID, "auto banning has been updated", linkpearl.RelatesTo(evt.ID))
}
func (b *Bot) runBanlistChange(ctx context.Context, mode string, commandSlice []string) {
@@ -413,11 +413,11 @@ func (b *Bot) runBanlistChange(ctx context.Context, mode string, commandSlice []
b.runBanlist(ctx, commandSlice)
return
}
if !b.cfg.GetBot().BanlistEnabled() {
b.lp.SendNotice(evt.RoomID, "banlist is disabled, you have to enable it first, kupo", linkpearl.RelatesTo(evt.ID))
if !b.cfg.GetBot(ctx).BanlistEnabled() {
b.lp.SendNotice(ctx, evt.RoomID, "banlist is disabled, you have to enable it first, kupo", linkpearl.RelatesTo(evt.ID))
return
}
banlist := b.cfg.GetBanlist()
banlist := b.cfg.GetBanlist(ctx)
var action func(net.Addr)
if mode == "remove" {
@@ -436,18 +436,18 @@ func (b *Bot) runBanlistChange(ctx context.Context, mode string, commandSlice []
action(addr)
}
err := b.cfg.SetBanlist(banlist)
err := b.cfg.SetBanlist(ctx, banlist)
if err != nil {
b.Error(ctx, "cannot set banlist: %v", err)
return
}
b.lp.SendNotice(evt.RoomID, "banlist has been updated, kupo", linkpearl.RelatesTo(evt.ID))
b.lp.SendNotice(ctx, evt.RoomID, "banlist has been updated, kupo", linkpearl.RelatesTo(evt.ID))
}
func (b *Bot) addBanlistTimeline(ctx context.Context, onlyTotals bool) {
evt := eventFromContext(ctx)
banlist := b.cfg.GetBanlist()
banlist := b.cfg.GetBanlist(ctx)
timeline := map[string][]string{}
for ip, ts := range banlist {
key := "???"
@@ -479,22 +479,22 @@ func (b *Bot) addBanlistTimeline(ctx context.Context, onlyTotals bool) {
txt.WriteString(strings.Join(data, "`, `"))
txt.WriteString("`\n")
}
b.lp.SendNotice(evt.RoomID, txt.String(), linkpearl.RelatesTo(evt.ID))
b.lp.SendNotice(ctx, evt.RoomID, txt.String(), linkpearl.RelatesTo(evt.ID))
}
}
func (b *Bot) runBanlistReset(ctx context.Context) {
evt := eventFromContext(ctx)
if !b.cfg.GetBot().BanlistEnabled() {
b.lp.SendNotice(evt.RoomID, "banlist is disabled, you have to enable it first, kupo", linkpearl.RelatesTo(evt.ID))
if !b.cfg.GetBot(ctx).BanlistEnabled() {
b.lp.SendNotice(ctx, evt.RoomID, "banlist is disabled, you have to enable it first, kupo", linkpearl.RelatesTo(evt.ID))
return
}
err := b.cfg.SetBanlist(config.List{})
err := b.cfg.SetBanlist(ctx, config.List{})
if err != nil {
b.Error(ctx, "cannot set banlist: %v", err)
return
}
b.lp.SendNotice(evt.RoomID, "banlist has been reset, kupo", linkpearl.RelatesTo(evt.ID))
b.lp.SendNotice(ctx, evt.RoomID, "banlist has been reset, kupo", linkpearl.RelatesTo(evt.ID))
}

View File

@@ -16,7 +16,7 @@ import (
func (b *Bot) runStop(ctx context.Context) {
evt := eventFromContext(ctx)
cfg, err := b.cfg.GetRoom(evt.RoomID)
cfg, err := b.cfg.GetRoom(ctx, evt.RoomID)
if err != nil {
b.Error(ctx, "failed to retrieve settings: %v", err)
return
@@ -24,19 +24,19 @@ func (b *Bot) runStop(ctx context.Context) {
mailbox := cfg.Get(config.RoomMailbox)
if mailbox == "" {
b.lp.SendNotice(evt.RoomID, "that room is not configured yet", linkpearl.RelatesTo(evt.ID, cfg.NoThreads()))
b.lp.SendNotice(ctx, evt.RoomID, "that room is not configured yet", linkpearl.RelatesTo(evt.ID, cfg.NoThreads()))
return
}
b.rooms.Delete(mailbox)
err = b.cfg.SetRoom(evt.RoomID, config.Room{})
err = b.cfg.SetRoom(ctx, evt.RoomID, config.Room{})
if err != nil {
b.Error(ctx, "cannot update settings: %v", err)
return
}
b.lp.SendNotice(evt.RoomID, "mailbox has been disabled", linkpearl.RelatesTo(evt.ID, cfg.NoThreads()))
b.lp.SendNotice(ctx, evt.RoomID, "mailbox has been disabled", linkpearl.RelatesTo(evt.ID, cfg.NoThreads()))
}
func (b *Bot) handleOption(ctx context.Context, cmd []string) {
@@ -58,7 +58,7 @@ func (b *Bot) handleOption(ctx context.Context, cmd []string) {
func (b *Bot) getOption(ctx context.Context, name string) {
evt := eventFromContext(ctx)
cfg, err := b.cfg.GetRoom(evt.RoomID)
cfg, err := b.cfg.GetRoom(ctx, evt.RoomID)
if err != nil {
b.Error(ctx, "failed to retrieve settings: %v", err)
return
@@ -73,7 +73,7 @@ func (b *Bot) getOption(ctx context.Context, name string) {
msg := fmt.Sprintf("`%s` is not set, kupo.\n"+
"To set it, send a `%s %s VALUE` command.",
name, b.prefix, name)
b.lp.SendNotice(evt.RoomID, msg, linkpearl.RelatesTo(evt.ID, cfg.NoThreads()))
b.lp.SendNotice(ctx, evt.RoomID, msg, linkpearl.RelatesTo(evt.ID, cfg.NoThreads()))
return
}
@@ -91,18 +91,18 @@ func (b *Bot) getOption(ctx context.Context, name string) {
"or just set a new one with `%s %s NEW_PASSWORD`.",
b.prefix, name)
}
b.lp.SendNotice(evt.RoomID, msg, linkpearl.RelatesTo(evt.ID, cfg.NoThreads()))
b.lp.SendNotice(ctx, evt.RoomID, msg, linkpearl.RelatesTo(evt.ID, cfg.NoThreads()))
}
func (b *Bot) setMailbox(ctx context.Context, value string) {
evt := eventFromContext(ctx)
existingID, ok := b.getMapping(value)
if (ok && existingID != "" && existingID != evt.RoomID) || b.isReserved(value) {
b.lp.SendNotice(evt.RoomID, fmt.Sprintf("Mailbox `%s` (%s) already taken, kupo", value, utils.EmailsList(value, "")))
b.lp.SendNotice(ctx, evt.RoomID, fmt.Sprintf("Mailbox `%s` (%s) already taken, kupo", value, utils.EmailsList(value, "")))
return
}
cfg, err := b.cfg.GetRoom(evt.RoomID)
cfg, err := b.cfg.GetRoom(ctx, evt.RoomID)
if err != nil {
b.Error(ctx, "failed to retrieve settings: %v", err)
return
@@ -113,23 +113,23 @@ func (b *Bot) setMailbox(ctx context.Context, value string) {
if old != "" {
b.rooms.Delete(old)
}
active := b.ActivateMailbox(evt.Sender, evt.RoomID, value)
active := b.ActivateMailbox(ctx, evt.Sender, evt.RoomID, value)
cfg.Set(config.RoomActive, strconv.FormatBool(active))
value = fmt.Sprintf("%s@%s", value, utils.SanitizeDomain(cfg.Domain()))
err = b.cfg.SetRoom(evt.RoomID, cfg)
err = b.cfg.SetRoom(ctx, evt.RoomID, cfg)
if err != nil {
b.Error(ctx, "cannot update settings: %v", err)
return
}
msg := fmt.Sprintf("mailbox of this room set to `%s`", value)
b.lp.SendNotice(evt.RoomID, msg, linkpearl.RelatesTo(evt.ID, cfg.NoThreads()))
b.lp.SendNotice(ctx, evt.RoomID, msg, linkpearl.RelatesTo(evt.ID, cfg.NoThreads()))
}
func (b *Bot) setPassword(ctx context.Context) {
evt := eventFromContext(ctx)
cfg, err := b.cfg.GetRoom(evt.RoomID)
cfg, err := b.cfg.GetRoom(ctx, evt.RoomID)
if err != nil {
b.Error(ctx, "failed to retrieve settings: %v", err)
return
@@ -143,13 +143,13 @@ func (b *Bot) setPassword(ctx context.Context) {
}
cfg.Set(config.RoomPassword, value)
err = b.cfg.SetRoom(evt.RoomID, cfg)
err = b.cfg.SetRoom(ctx, evt.RoomID, cfg)
if err != nil {
b.Error(ctx, "cannot update settings: %v", err)
return
}
b.lp.SendNotice(evt.RoomID, "SMTP password has been set", linkpearl.RelatesTo(evt.ID, cfg.NoThreads()))
b.lp.SendNotice(ctx, evt.RoomID, "SMTP password has been set", linkpearl.RelatesTo(evt.ID, cfg.NoThreads()))
}
func (b *Bot) setOption(ctx context.Context, name, value string) {
@@ -159,7 +159,7 @@ func (b *Bot) setOption(ctx context.Context, name, value string) {
}
evt := eventFromContext(ctx)
cfg, err := b.cfg.GetRoom(evt.RoomID)
cfg, err := b.cfg.GetRoom(ctx, evt.RoomID)
if err != nil {
b.Error(ctx, "failed to retrieve settings: %v", err)
return
@@ -176,19 +176,19 @@ func (b *Bot) setOption(ctx context.Context, name, value string) {
old := cfg.Get(name)
if old == value {
b.lp.SendNotice(evt.RoomID, "nothing changed, kupo.", linkpearl.RelatesTo(evt.ID, cfg.NoThreads()))
b.lp.SendNotice(ctx, evt.RoomID, "nothing changed, kupo.", linkpearl.RelatesTo(evt.ID, cfg.NoThreads()))
return
}
cfg.Set(name, value)
err = b.cfg.SetRoom(evt.RoomID, cfg)
err = b.cfg.SetRoom(ctx, evt.RoomID, cfg)
if err != nil {
b.Error(ctx, "cannot update settings: %v", err)
return
}
msg := fmt.Sprintf("`%s` of this room set to:\n```\n%s\n```", name, value)
b.lp.SendNotice(evt.RoomID, msg, linkpearl.RelatesTo(evt.ID, cfg.NoThreads()))
b.lp.SendNotice(ctx, evt.RoomID, msg, linkpearl.RelatesTo(evt.ID, cfg.NoThreads()))
}
func (b *Bot) runSpamlistAdd(ctx context.Context, commandSlice []string) {
@@ -197,7 +197,7 @@ func (b *Bot) runSpamlistAdd(ctx context.Context, commandSlice []string) {
b.getOption(ctx, config.RoomSpamlist)
return
}
cfg, err := b.cfg.GetRoom(evt.RoomID)
cfg, err := b.cfg.GetRoom(ctx, evt.RoomID)
if err != nil {
b.Error(ctx, "cannot get room settings: %v", err)
return
@@ -212,7 +212,7 @@ func (b *Bot) runSpamlistAdd(ctx context.Context, commandSlice []string) {
}
cfg.Set(config.RoomSpamlist, utils.SliceString(spamlist))
err = b.cfg.SetRoom(evt.RoomID, cfg)
err = b.cfg.SetRoom(ctx, evt.RoomID, cfg)
if err != nil {
b.Error(ctx, "cannot store room settings: %v", err)
return
@@ -223,7 +223,7 @@ func (b *Bot) runSpamlistAdd(ctx context.Context, commandSlice []string) {
threadID = evt.ID
}
b.lp.SendNotice(evt.RoomID, "spamlist has been updated, kupo", linkpearl.RelatesTo(threadID, cfg.NoThreads()))
b.lp.SendNotice(ctx, evt.RoomID, "spamlist has been updated, kupo", linkpearl.RelatesTo(threadID, cfg.NoThreads()))
}
func (b *Bot) runSpamlistRemove(ctx context.Context, commandSlice []string) {
@@ -232,7 +232,7 @@ func (b *Bot) runSpamlistRemove(ctx context.Context, commandSlice []string) {
b.getOption(ctx, config.RoomSpamlist)
return
}
cfg, err := b.cfg.GetRoom(evt.RoomID)
cfg, err := b.cfg.GetRoom(ctx, evt.RoomID)
if err != nil {
b.Error(ctx, "cannot get room settings: %v", err)
return
@@ -248,7 +248,7 @@ func (b *Bot) runSpamlistRemove(ctx context.Context, commandSlice []string) {
toRemove[idx] = struct{}{}
}
if len(toRemove) == 0 {
b.lp.SendNotice(evt.RoomID, "nothing new, kupo.", linkpearl.RelatesTo(evt.ID, cfg.NoThreads()))
b.lp.SendNotice(ctx, evt.RoomID, "nothing new, kupo.", linkpearl.RelatesTo(evt.ID, cfg.NoThreads()))
return
}
@@ -261,34 +261,34 @@ func (b *Bot) runSpamlistRemove(ctx context.Context, commandSlice []string) {
}
cfg.Set(config.RoomSpamlist, utils.SliceString(updatedSpamlist))
err = b.cfg.SetRoom(evt.RoomID, cfg)
err = b.cfg.SetRoom(ctx, evt.RoomID, cfg)
if err != nil {
b.Error(ctx, "cannot store room settings: %v", err)
return
}
b.lp.SendNotice(evt.RoomID, "spamlist has been updated, kupo", linkpearl.RelatesTo(evt.ID, cfg.NoThreads()))
b.lp.SendNotice(ctx, evt.RoomID, "spamlist has been updated, kupo", linkpearl.RelatesTo(evt.ID, cfg.NoThreads()))
}
func (b *Bot) runSpamlistReset(ctx context.Context) {
evt := eventFromContext(ctx)
cfg, err := b.cfg.GetRoom(evt.RoomID)
cfg, err := b.cfg.GetRoom(ctx, evt.RoomID)
if err != nil {
b.Error(ctx, "cannot get room settings: %v", err)
return
}
spamlist := utils.StringSlice(cfg[config.RoomSpamlist])
if len(spamlist) == 0 {
b.lp.SendNotice(evt.RoomID, "spamlist is empty, kupo.", linkpearl.RelatesTo(evt.ID, cfg.NoThreads()))
b.lp.SendNotice(ctx, evt.RoomID, "spamlist is empty, kupo.", linkpearl.RelatesTo(evt.ID, cfg.NoThreads()))
return
}
cfg.Set(config.RoomSpamlist, "")
err = b.cfg.SetRoom(evt.RoomID, cfg)
err = b.cfg.SetRoom(ctx, evt.RoomID, cfg)
if err != nil {
b.Error(ctx, "cannot store room settings: %v", err)
return
}
b.lp.SendNotice(evt.RoomID, "spamlist has been reset, kupo.", linkpearl.RelatesTo(evt.ID, cfg.NoThreads()))
b.lp.SendNotice(ctx, evt.RoomID, "spamlist has been reset, kupo.", linkpearl.RelatesTo(evt.ID, cfg.NoThreads()))
}

View File

@@ -1,6 +1,8 @@
package config
import (
"context"
"github.com/rs/zerolog"
"gitlab.com/etke.cc/linkpearl"
"maunium.net/go/mautrix/id"
@@ -27,10 +29,10 @@ func New(lp *linkpearl.Linkpearl, log *zerolog.Logger) *Manager {
}
// GetBot config
func (m *Manager) GetBot() Bot {
func (m *Manager) GetBot(ctx context.Context) Bot {
var err error
var config Bot
config, err = m.lp.GetAccountData(acBotKey)
config, err = m.lp.GetAccountData(ctx, acBotKey)
if err != nil {
m.log.Error().Err(err).Msg("cannot get bot settings")
}
@@ -43,13 +45,13 @@ func (m *Manager) GetBot() Bot {
}
// SetBot config
func (m *Manager) SetBot(cfg Bot) error {
return m.lp.SetAccountData(acBotKey, cfg)
func (m *Manager) SetBot(ctx context.Context, cfg Bot) error {
return m.lp.SetAccountData(ctx, acBotKey, cfg)
}
// GetRoom config
func (m *Manager) GetRoom(roomID id.RoomID) (Room, error) {
config, err := m.lp.GetRoomAccountData(roomID, acRoomKey)
func (m *Manager) GetRoom(ctx context.Context, roomID id.RoomID) (Room, error) {
config, err := m.lp.GetRoomAccountData(ctx, roomID, acRoomKey)
if err != nil {
m.log.Warn().Err(err).Str("room_id", roomID.String()).Msg("cannot get room settings")
}
@@ -61,19 +63,19 @@ func (m *Manager) GetRoom(roomID id.RoomID) (Room, error) {
}
// SetRoom config
func (m *Manager) SetRoom(roomID id.RoomID, cfg Room) error {
return m.lp.SetRoomAccountData(roomID, acRoomKey, cfg)
func (m *Manager) SetRoom(ctx context.Context, roomID id.RoomID, cfg Room) error {
return m.lp.SetRoomAccountData(ctx, roomID, acRoomKey, cfg)
}
// GetBanlist config
func (m *Manager) GetBanlist() List {
if !m.GetBot().BanlistEnabled() {
func (m *Manager) GetBanlist(ctx context.Context) List {
if !m.GetBot(ctx).BanlistEnabled() {
return make(List, 0)
}
m.mu.Lock("banlist")
defer m.mu.Unlock("banlist")
config, err := m.lp.GetAccountData(acBanlistKey)
config, err := m.lp.GetAccountData(ctx, acBanlistKey)
if err != nil {
m.log.Error().Err(err).Msg("cannot get banlist")
}
@@ -85,8 +87,8 @@ func (m *Manager) GetBanlist() List {
}
// SetBanlist config
func (m *Manager) SetBanlist(cfg List) error {
if !m.GetBot().BanlistEnabled() {
func (m *Manager) SetBanlist(ctx context.Context, cfg List) error {
if !m.GetBot(ctx).BanlistEnabled() {
return nil
}
@@ -96,12 +98,12 @@ func (m *Manager) SetBanlist(cfg List) error {
cfg = make(List, 0)
}
return m.lp.SetAccountData(acBanlistKey, cfg)
return m.lp.SetAccountData(ctx, acBanlistKey, cfg)
}
// GetGreylist config
func (m *Manager) GetGreylist() List {
config, err := m.lp.GetAccountData(acGreylistKey)
func (m *Manager) GetGreylist(ctx context.Context) List {
config, err := m.lp.GetAccountData(ctx, acGreylistKey)
if err != nil {
m.log.Error().Err(err).Msg("cannot get banlist")
}
@@ -114,6 +116,6 @@ func (m *Manager) GetGreylist() List {
}
// SetGreylist config
func (m *Manager) SetGreylist(cfg List) error {
return m.lp.SetAccountData(acGreylistKey, cfg)
func (m *Manager) SetGreylist(ctx context.Context, cfg List) error {
return m.lp.SetAccountData(ctx, acGreylistKey, cfg)
}

View File

@@ -15,8 +15,10 @@ const (
ctxThreadID ctxkey = iota
)
func newContext(evt *event.Event) context.Context {
ctx := context.Background()
func newContext(ctx context.Context, evt *event.Event) context.Context {
if ctx == nil {
ctx = context.Background()
}
hub := sentry.CurrentHub().Clone()
ctx = sentry.SetHubOnContext(ctx, hub)
ctx = eventToContext(ctx, evt)

View File

@@ -1,6 +1,7 @@
package bot
import (
"context"
"strconv"
"time"
@@ -9,21 +10,21 @@ import (
"gitlab.com/etke.cc/postmoogle/bot/config"
)
func (b *Bot) syncRooms() error {
func (b *Bot) syncRooms(ctx context.Context) error {
adminRooms := []id.RoomID{}
adminRoom := b.cfg.GetBot().AdminRoom()
adminRoom := b.cfg.GetBot(ctx).AdminRoom()
if adminRoom != "" {
adminRooms = append(adminRooms, adminRoom)
}
resp, err := b.lp.GetClient().JoinedRooms()
resp, err := b.lp.GetClient().JoinedRooms(ctx)
if err != nil {
return err
}
for _, roomID := range resp.JoinedRooms {
b.migrateRoomSettings(roomID)
cfg, serr := b.cfg.GetRoom(roomID)
b.migrateRoomSettings(ctx, roomID)
cfg, serr := b.cfg.GetRoom(ctx, roomID)
if serr != nil {
continue
}
@@ -33,7 +34,7 @@ func (b *Bot) syncRooms() error {
b.rooms.Store(mailbox, roomID)
}
if cfg.Owner() != "" && b.allowAdmin(id.UserID(cfg.Owner()), "") {
if cfg.Owner() != "" && b.allowAdmin(ctx, id.UserID(cfg.Owner()), "") {
adminRooms = append(adminRooms, roomID)
}
}
@@ -42,8 +43,8 @@ func (b *Bot) syncRooms() error {
return nil
}
func (b *Bot) migrateRoomSettings(roomID id.RoomID) {
cfg, err := b.cfg.GetRoom(roomID)
func (b *Bot) migrateRoomSettings(ctx context.Context, roomID id.RoomID) {
cfg, err := b.cfg.GetRoom(ctx, roomID)
if err != nil {
b.log.Error().Err(err).Msg("cannot retrieve room settings")
return
@@ -56,7 +57,7 @@ func (b *Bot) migrateRoomSettings(roomID id.RoomID) {
return
}
cfg.MigrateSpamlistSettings()
err = b.cfg.SetRoom(roomID, cfg)
err = b.cfg.SetRoom(ctx, roomID, cfg)
if err != nil {
b.log.Error().Err(err).Msg("cannot migrate room settings")
}
@@ -68,8 +69,8 @@ func (b *Bot) migrateRoomSettings(roomID id.RoomID) {
// alongside with other database configs to simplify maintenance,
// but with that simplification there is no proper way to migrate
// existing sync token and session info. No data loss, tho.
func (b *Bot) migrateMautrix015() error {
cfg := b.cfg.GetBot()
func (b *Bot) migrateMautrix015(ctx context.Context) error {
cfg := b.cfg.GetBot(ctx)
ts := cfg.Mautrix015Migration()
// already migrated
if ts > 0 {
@@ -82,11 +83,11 @@ func (b *Bot) migrateMautrix015() error {
tss := strconv.FormatInt(ts, 10)
cfg.Set(config.BotMautrix015Migration, tss)
return b.cfg.SetBot(cfg)
return b.cfg.SetBot(ctx, cfg)
}
func (b *Bot) initBotUsers() ([]string, error) {
cfg := b.cfg.GetBot()
func (b *Bot) initBotUsers(ctx context.Context) ([]string, error) {
cfg := b.cfg.GetBot(ctx)
cfgUsers := cfg.Users()
if len(cfgUsers) > 0 {
return cfgUsers, nil
@@ -97,10 +98,10 @@ func (b *Bot) initBotUsers() ([]string, error) {
return nil, err
}
cfg.Set(config.BotUsers, "@*:"+homeserver)
return cfg.Users(), b.cfg.SetBot(cfg)
return cfg.Users(), b.cfg.SetBot(ctx, cfg)
}
// SyncRooms and mailboxes
func (b *Bot) SyncRooms() {
b.syncRooms() //nolint:errcheck // nothing can be done here
b.syncRooms(context.Background()) //nolint:errcheck // nothing can be done here
}

View File

@@ -60,14 +60,14 @@ func (b *Bot) shouldQueue(msg string) bool {
// Sendmail tries to send email immediately, but if it gets 4xx error (greylisting),
// the email will be added to the queue and retried several times after that
func (b *Bot) Sendmail(eventID id.EventID, from, to, data string) (bool, error) {
func (b *Bot) Sendmail(ctx context.Context, eventID id.EventID, from, to, data string) (bool, error) {
log := b.log.With().Str("from", from).Str("to", to).Str("eventID", eventID.String()).Logger()
log.Info().Msg("attempting to deliver email")
err := b.sendmail(from, to, data)
if err != nil {
if b.shouldQueue(err.Error()) {
log.Info().Err(err).Msg("email has been added to the queue")
return true, b.q.Add(eventID.String(), from, to, data)
return true, b.q.Add(ctx, eventID.String(), from, to, data)
}
log.Warn().Err(err).Msg("email delivery failed")
return false, err
@@ -78,8 +78,8 @@ func (b *Bot) Sendmail(eventID id.EventID, from, to, data string) (bool, error)
}
// GetDKIMprivkey returns DKIM private key
func (b *Bot) GetDKIMprivkey() string {
return b.cfg.GetBot().DKIMPrivateKey()
func (b *Bot) GetDKIMprivkey(ctx context.Context) string {
return b.cfg.GetBot(ctx).DKIMPrivateKey()
}
func (b *Bot) getMapping(mailbox string) (id.RoomID, bool) {
@@ -97,10 +97,10 @@ func (b *Bot) getMapping(mailbox string) (id.RoomID, bool) {
}
// GetMapping returns mapping of mailbox = room
func (b *Bot) GetMapping(mailbox string) (id.RoomID, bool) {
func (b *Bot) GetMapping(ctx context.Context, mailbox string) (id.RoomID, bool) {
roomID, ok := b.getMapping(mailbox)
if !ok {
catchAll := b.cfg.GetBot().CatchAll()
catchAll := b.cfg.GetBot(ctx).CatchAll()
if catchAll == "" {
return roomID, ok
}
@@ -111,8 +111,8 @@ func (b *Bot) GetMapping(mailbox string) (id.RoomID, bool) {
}
// GetIFOptions returns incoming email filtering options (room settings)
func (b *Bot) GetIFOptions(roomID id.RoomID) email.IncomingFilteringOptions {
cfg, err := b.cfg.GetRoom(roomID)
func (b *Bot) GetIFOptions(ctx context.Context, roomID id.RoomID) email.IncomingFilteringOptions {
cfg, err := b.cfg.GetRoom(ctx, roomID)
if err != nil {
b.log.Error().Err(err).Msg("cannot retrieve room settings")
}
@@ -124,11 +124,11 @@ func (b *Bot) GetIFOptions(roomID id.RoomID) email.IncomingFilteringOptions {
//
//nolint:gocognit // TODO
func (b *Bot) IncomingEmail(ctx context.Context, eml *email.Email) error {
roomID, ok := b.GetMapping(eml.Mailbox(true))
roomID, ok := b.GetMapping(ctx, eml.Mailbox(true))
if !ok {
return ErrNoRoom
}
cfg, err := b.cfg.GetRoom(roomID)
cfg, err := b.cfg.GetRoom(ctx, roomID)
if err != nil {
b.Error(ctx, "cannot get settings: %v", err)
}
@@ -139,15 +139,15 @@ func (b *Bot) IncomingEmail(ctx context.Context, eml *email.Email) error {
var threadID id.EventID
newThread := true
if eml.InReplyTo != "" || eml.References != "" {
threadID = b.getThreadID(roomID, eml.InReplyTo, eml.References)
threadID = b.getThreadID(ctx, roomID, eml.InReplyTo, eml.References)
if threadID != "" {
newThread = false
ctx = threadIDToContext(ctx, threadID)
b.setThreadID(roomID, eml.MessageID, threadID)
b.setThreadID(ctx, roomID, eml.MessageID, threadID)
}
}
content := eml.Content(threadID, cfg.ContentOptions())
eventID, serr := b.lp.Send(roomID, content)
content := eml.Content(threadID, cfg.ContentOptions(), b.psd)
eventID, serr := b.lp.Send(ctx, roomID, content)
if serr != nil {
if !strings.Contains(serr.Error(), "M_UNKNOWN") { // if it's not an unknown event error
return serr
@@ -160,11 +160,11 @@ func (b *Bot) IncomingEmail(ctx context.Context, eml *email.Email) error {
ctx = threadIDToContext(ctx, threadID)
}
b.setThreadID(roomID, eml.MessageID, threadID)
b.setLastEventID(roomID, threadID, eventID)
b.setThreadID(ctx, roomID, eml.MessageID, threadID)
b.setLastEventID(ctx, roomID, threadID, eventID)
if newThread && cfg.Threadify() {
_, berr := b.lp.Send(roomID, eml.ContentBody(threadID, cfg.ContentOptions()))
_, berr := b.lp.Send(ctx, roomID, eml.ContentBody(threadID, cfg.ContentOptions()))
if berr != nil {
return berr
}
@@ -179,15 +179,15 @@ func (b *Bot) IncomingEmail(ctx context.Context, eml *email.Email) error {
}
if newThread && cfg.Autoreply() != "" {
b.sendAutoreply(roomID, threadID)
b.sendAutoreply(ctx, roomID, threadID)
}
return nil
}
//nolint:gocognit // TODO
func (b *Bot) sendAutoreply(roomID id.RoomID, threadID id.EventID) {
cfg, err := b.cfg.GetRoom(roomID)
func (b *Bot) sendAutoreply(ctx context.Context, roomID id.RoomID, threadID id.EventID) {
cfg, err := b.cfg.GetRoom(ctx, roomID)
if err != nil {
return
}
@@ -197,7 +197,7 @@ func (b *Bot) sendAutoreply(roomID id.RoomID, threadID id.EventID) {
return
}
threadEvt, err := b.lp.GetClient().GetEvent(roomID, threadID)
threadEvt, err := b.lp.GetClient().GetEvent(ctx, roomID, threadID)
if err != nil {
b.log.Error().Err(err).Msg("cannot get thread event for autoreply")
return
@@ -216,7 +216,7 @@ func (b *Bot) sendAutoreply(roomID id.RoomID, threadID id.EventID) {
},
}
meta := b.getParentEmail(evt, cfg.Mailbox())
meta := b.getParentEmail(ctx, evt, cfg.Mailbox())
if meta.To == "" {
return
@@ -246,16 +246,16 @@ func (b *Bot) sendAutoreply(roomID id.RoomID, threadID id.EventID) {
meta.References = meta.References + " " + meta.MessageID
b.log.Info().Any("meta", meta).Msg("sending automatic reply")
eml := email.New(meta.MessageID, meta.InReplyTo, meta.References, meta.Subject, meta.From, meta.To, meta.RcptTo, meta.CC, body, htmlBody, nil, nil)
data := eml.Compose(b.cfg.GetBot().DKIMPrivateKey())
data := eml.Compose(b.cfg.GetBot(ctx).DKIMPrivateKey())
if data == "" {
return
}
var queued bool
ctx := newContext(threadEvt)
ctx = newContext(ctx, threadEvt)
recipients := meta.Recipients
for _, to := range recipients {
queued, err = b.Sendmail(evt.ID, meta.From, to, data)
queued, err = b.Sendmail(ctx, evt.ID, meta.From, to, data)
if queued {
b.log.Info().Err(err).Str("from", meta.From).Str("to", to).Msg("email has been queued")
b.saveSentMetadata(ctx, queued, meta.ThreadID, recipients, eml, cfg, "Autoreply has been sent to "+to+" (queued)")
@@ -273,7 +273,7 @@ func (b *Bot) sendAutoreply(roomID id.RoomID, threadID id.EventID) {
func (b *Bot) canReply(ctx context.Context) bool {
evt := eventFromContext(ctx)
return b.allowSend(evt.Sender, evt.RoomID) && b.allowReply(evt.Sender, evt.RoomID)
return b.allowSend(ctx, evt.Sender, evt.RoomID) && b.allowReply(ctx, evt.Sender, evt.RoomID)
}
// SendEmailReply sends replies from matrix thread to email thread
@@ -284,7 +284,7 @@ func (b *Bot) SendEmailReply(ctx context.Context) {
if !b.canReply(ctx) {
return
}
cfg, err := b.cfg.GetRoom(evt.RoomID)
cfg, err := b.cfg.GetRoom(ctx, evt.RoomID)
if err != nil {
b.Error(ctx, "cannot retrieve room settings: %v", err)
return
@@ -295,10 +295,10 @@ func (b *Bot) SendEmailReply(ctx context.Context) {
return
}
b.lock(evt.RoomID, evt.ID)
defer b.unlock(evt.RoomID, evt.ID)
b.lock(ctx, evt.RoomID, evt.ID)
defer b.unlock(ctx, evt.RoomID, evt.ID)
meta := b.getParentEmail(evt, mailbox)
meta := b.getParentEmail(ctx, evt, mailbox)
if meta.To == "" {
b.Error(ctx, "cannot find parent email and continue the thread. Please, start a new email thread")
@@ -306,7 +306,7 @@ func (b *Bot) SendEmailReply(ctx context.Context) {
}
if meta.ThreadID == "" {
meta.ThreadID = b.getThreadID(evt.RoomID, meta.InReplyTo, meta.References)
meta.ThreadID = b.getThreadID(ctx, evt.RoomID, meta.InReplyTo, meta.References)
ctx = threadIDToContext(ctx, meta.ThreadID)
}
content := evt.Content.AsMessage()
@@ -330,16 +330,16 @@ func (b *Bot) SendEmailReply(ctx context.Context) {
meta.References = meta.References + " " + meta.MessageID
b.log.Info().Any("meta", meta).Msg("sending email reply")
eml := email.New(meta.MessageID, meta.InReplyTo, meta.References, meta.Subject, meta.From, meta.To, meta.RcptTo, meta.CC, body, htmlBody, nil, nil)
data := eml.Compose(b.cfg.GetBot().DKIMPrivateKey())
data := eml.Compose(b.cfg.GetBot(ctx).DKIMPrivateKey())
if data == "" {
b.lp.SendNotice(evt.RoomID, "email body is empty", linkpearl.RelatesTo(meta.ThreadID, cfg.NoThreads()))
b.lp.SendNotice(ctx, evt.RoomID, "email body is empty", linkpearl.RelatesTo(meta.ThreadID, cfg.NoThreads()))
return
}
var queued bool
recipients := meta.Recipients
for _, to := range recipients {
queued, err = b.Sendmail(evt.ID, meta.From, to, data)
queued, err = b.Sendmail(ctx, evt.ID, meta.From, to, data)
if queued {
b.log.Info().Err(err).Str("from", meta.From).Str("to", to).Msg("email has been queued")
b.saveSentMetadata(ctx, queued, meta.ThreadID, recipients, eml, cfg)
@@ -444,7 +444,7 @@ func (e *parentEmail) calculateRecipients(from string, forwardedFrom []string) {
e.Recipients = rcpts
}
func (b *Bot) getParentEvent(evt *event.Event) (id.EventID, *event.Event) {
func (b *Bot) getParentEvent(ctx context.Context, evt *event.Event) (id.EventID, *event.Event) {
content := evt.Content.AsMessage()
threadID := linkpearl.EventParent(evt.ID, content)
b.log.Debug().Str("eventID", evt.ID.String()).Str("threadID", threadID.String()).Msg("looking up for the parent event within thread")
@@ -452,23 +452,23 @@ func (b *Bot) getParentEvent(evt *event.Event) (id.EventID, *event.Event) {
b.log.Debug().Str("eventID", evt.ID.String()).Msg("event is the thread itself")
return threadID, evt
}
lastEventID := b.getLastEventID(evt.RoomID, threadID)
lastEventID := b.getLastEventID(ctx, evt.RoomID, threadID)
b.log.Debug().Str("eventID", evt.ID.String()).Str("threadID", threadID.String()).Str("lastEventID", lastEventID.String()).Msg("the last event of the thread (and parent of the event) has been found")
if lastEventID == evt.ID {
return threadID, evt
}
parentEvt, err := b.lp.GetClient().GetEvent(evt.RoomID, lastEventID)
parentEvt, err := b.lp.GetClient().GetEvent(ctx, evt.RoomID, lastEventID)
if err != nil {
b.log.Error().Err(err).Msg("cannot get parent event")
return threadID, nil
}
linkpearl.ParseContent(parentEvt, b.log)
if !b.lp.GetMachine().StateStore.IsEncrypted(evt.RoomID) {
if ok, _ := b.lp.GetMachine().StateStore.IsEncrypted(ctx, evt.RoomID); !ok { //nolint:errcheck // that's fine
return threadID, parentEvt
}
decrypted, err := b.lp.GetClient().Crypto.Decrypt(parentEvt)
decrypted, err := b.lp.GetClient().Crypto.Decrypt(ctx, parentEvt)
if err != nil {
b.log.Error().Err(err).Msg("cannot decrypt parent event")
return threadID, nil
@@ -477,9 +477,9 @@ func (b *Bot) getParentEvent(evt *event.Event) (id.EventID, *event.Event) {
return threadID, decrypted
}
func (b *Bot) getParentEmail(evt *event.Event, newFromMailbox string) *parentEmail {
func (b *Bot) getParentEmail(ctx context.Context, evt *event.Event, newFromMailbox string) *parentEmail {
parent := &parentEmail{}
threadID, parentEvt := b.getParentEvent(evt)
threadID, parentEvt := b.getParentEvent(ctx, evt)
parent.ThreadID = threadID
if parentEvt == nil {
return parent
@@ -527,7 +527,7 @@ func (b *Bot) saveSentMetadata(ctx context.Context, queued bool, threadID id.Eve
}
evt := eventFromContext(ctx)
content := eml.Content(threadID, cfg.ContentOptions())
content := eml.Content(threadID, cfg.ContentOptions(), b.psd)
notice := format.RenderMarkdown(text, true, true)
msgContent, ok := content.Parsed.(*event.MessageEventContent)
if !ok {
@@ -539,28 +539,28 @@ func (b *Bot) saveSentMetadata(ctx context.Context, queued bool, threadID id.Eve
msgContent.FormattedBody = notice.FormattedBody
msgContent.RelatesTo = linkpearl.RelatesTo(threadID, cfg.NoThreads())
content.Parsed = msgContent
msgID, err := b.lp.Send(evt.RoomID, content)
msgID, err := b.lp.Send(ctx, evt.RoomID, content)
if err != nil {
b.Error(ctx, "cannot send notice: %v", err)
return
}
domain := utils.SanitizeDomain(cfg.Domain())
b.setThreadID(evt.RoomID, email.MessageID(evt.ID, domain), threadID)
b.setThreadID(evt.RoomID, email.MessageID(msgID, domain), threadID)
b.setLastEventID(evt.RoomID, threadID, msgID)
b.setThreadID(ctx, evt.RoomID, email.MessageID(evt.ID, domain), threadID)
b.setThreadID(ctx, evt.RoomID, email.MessageID(msgID, domain), threadID)
b.setLastEventID(ctx, evt.RoomID, threadID, msgID)
}
func (b *Bot) sendFiles(ctx context.Context, roomID id.RoomID, files []*utils.File, noThreads bool, parentID id.EventID) {
for _, file := range files {
req := file.Convert()
err := b.lp.SendFile(roomID, req, file.MsgType, linkpearl.RelatesTo(parentID, noThreads))
err := b.lp.SendFile(ctx, roomID, req, file.MsgType, linkpearl.RelatesTo(parentID, noThreads))
if err != nil {
b.Error(ctx, "cannot upload file %s: %v", req.FileName, err)
}
}
}
func (b *Bot) getThreadID(roomID id.RoomID, messageID, references string) id.EventID {
func (b *Bot) getThreadID(ctx context.Context, roomID id.RoomID, messageID, references string) id.EventID {
refs := []string{messageID}
if references != "" {
refs = append(refs, strings.Split(references, " ")...)
@@ -568,7 +568,7 @@ func (b *Bot) getThreadID(roomID id.RoomID, messageID, references string) id.Eve
for _, refID := range refs {
key := acMessagePrefix + "." + refID
data, err := b.lp.GetRoomAccountData(roomID, key)
data, err := b.lp.GetRoomAccountData(ctx, roomID, key)
if err != nil {
b.log.Error().Err(err).Str("key", key).Msg("cannot retrieve thread ID")
continue
@@ -576,7 +576,7 @@ func (b *Bot) getThreadID(roomID id.RoomID, messageID, references string) id.Eve
if data["eventID"] == "" {
continue
}
resp, err := b.lp.GetClient().GetEvent(roomID, id.EventID(data["eventID"]))
resp, err := b.lp.GetClient().GetEvent(ctx, roomID, id.EventID(data["eventID"]))
if err != nil {
b.log.Warn().Err(err).Str("roomID", roomID.String()).Str("eventID", data["eventID"]).Msg("cannot get event by id (may be removed)")
continue
@@ -587,17 +587,17 @@ func (b *Bot) getThreadID(roomID id.RoomID, messageID, references string) id.Eve
return ""
}
func (b *Bot) setThreadID(roomID id.RoomID, messageID string, eventID id.EventID) {
func (b *Bot) setThreadID(ctx context.Context, roomID id.RoomID, messageID string, eventID id.EventID) {
key := acMessagePrefix + "." + messageID
err := b.lp.SetRoomAccountData(roomID, key, map[string]string{"eventID": eventID.String()})
err := b.lp.SetRoomAccountData(ctx, roomID, key, map[string]string{"eventID": eventID.String()})
if err != nil {
b.log.Error().Err(err).Str("key", key).Msg("cannot save thread ID")
}
}
func (b *Bot) getLastEventID(roomID id.RoomID, threadID id.EventID) id.EventID {
func (b *Bot) getLastEventID(ctx context.Context, roomID id.RoomID, threadID id.EventID) id.EventID {
key := acLastEventPrefix + "." + threadID.String()
data, err := b.lp.GetRoomAccountData(roomID, key)
data, err := b.lp.GetRoomAccountData(ctx, roomID, key)
if err != nil {
b.log.Error().Err(err).Str("key", key).Msg("cannot retrieve last event ID")
return threadID
@@ -609,9 +609,9 @@ func (b *Bot) getLastEventID(roomID id.RoomID, threadID id.EventID) id.EventID {
return threadID
}
func (b *Bot) setLastEventID(roomID id.RoomID, threadID, eventID id.EventID) {
func (b *Bot) setLastEventID(ctx context.Context, roomID id.RoomID, threadID, eventID id.EventID) {
key := acLastEventPrefix + "." + threadID.String()
err := b.lp.SetRoomAccountData(roomID, key, map[string]string{"eventID": eventID.String()})
err := b.lp.SetRoomAccountData(ctx, roomID, key, map[string]string{"eventID": eventID.String()})
if err != nil {
b.log.Error().Err(err).Str("key", key).Msg("cannot save thread ID")
}

View File

@@ -1,29 +1,31 @@
package bot
import (
"context"
"maunium.net/go/mautrix/id"
)
func (b *Bot) lock(roomID id.RoomID, optionalEventID ...id.EventID) {
func (b *Bot) lock(ctx context.Context, roomID id.RoomID, optionalEventID ...id.EventID) {
b.mu.Lock(roomID.String())
if len(optionalEventID) == 0 {
return
}
evtID := optionalEventID[0]
if _, err := b.lp.GetClient().SendReaction(roomID, evtID, "📨"); err != nil {
if _, err := b.lp.GetClient().SendReaction(ctx, roomID, evtID, "📨"); err != nil {
b.log.Error().Err(err).Str("roomID", roomID.String()).Str("eventID", evtID.String()).Msg("cannot send reaction on lock")
}
}
func (b *Bot) unlock(roomID id.RoomID, optionalEventID ...id.EventID) {
func (b *Bot) unlock(ctx context.Context, roomID id.RoomID, optionalEventID ...id.EventID) {
b.mu.Unlock(roomID.String())
if len(optionalEventID) == 0 {
return
}
evtID := optionalEventID[0]
if _, err := b.lp.GetClient().SendReaction(roomID, evtID, "✅"); err != nil {
if _, err := b.lp.GetClient().SendReaction(ctx, roomID, evtID, "✅"); err != nil {
b.log.Error().Err(err).Str("roomID", roomID.String()).Str("eventID", evtID.String()).Msg("cannot send reaction on unlock")
}
}

View File

@@ -1,6 +1,8 @@
package queue
import (
"context"
"github.com/rs/zerolog"
"gitlab.com/etke.cc/linkpearl"
@@ -41,7 +43,8 @@ func (q *Queue) SetSendmail(function func(string, string, string) error) {
// Process queue
func (q *Queue) Process() {
q.log.Debug().Msg("staring queue processing...")
cfg := q.cfg.GetBot()
ctx := context.Background()
cfg := q.cfg.GetBot(ctx)
batchSize := cfg.QueueBatch()
if batchSize == 0 {
@@ -55,7 +58,7 @@ func (q *Queue) Process() {
q.mu.Lock(acQueueKey)
defer q.mu.Unlock(acQueueKey)
index, err := q.lp.GetAccountData(acQueueKey)
index, err := q.lp.GetAccountData(ctx, acQueueKey)
if err != nil {
q.log.Error().Err(err).Msg("cannot get queue index")
}
@@ -66,9 +69,9 @@ func (q *Queue) Process() {
q.log.Debug().Msg("finished re-deliveries from queue")
return
}
if dequeue := q.try(itemkey, maxRetries); dequeue {
if dequeue := q.try(ctx, itemkey, maxRetries); dequeue {
q.log.Info().Str("id", id).Msg("email has been delivered")
err = q.Remove(id)
err = q.Remove(ctx, id)
if err != nil {
q.log.Error().Err(err).Str("id", id).Msg("cannot dequeue email")
}

View File

@@ -1,11 +1,12 @@
package queue
import (
"context"
"strconv"
)
// Add to queue
func (q *Queue) Add(id, from, to, data string) error {
func (q *Queue) Add(ctx context.Context, id, from, to, data string) error {
itemkey := acQueueKey + "." + id
item := map[string]string{
"attempts": "0",
@@ -17,7 +18,7 @@ func (q *Queue) Add(id, from, to, data string) error {
q.mu.Lock(itemkey)
defer q.mu.Unlock(itemkey)
err := q.lp.SetAccountData(itemkey, item)
err := q.lp.SetAccountData(ctx, itemkey, item)
if err != nil {
q.log.Error().Err(err).Str("id", id).Msg("cannot enqueue email")
return err
@@ -25,13 +26,13 @@ func (q *Queue) Add(id, from, to, data string) error {
q.mu.Lock(acQueueKey)
defer q.mu.Unlock(acQueueKey)
queueIndex, err := q.lp.GetAccountData(acQueueKey)
queueIndex, err := q.lp.GetAccountData(ctx, acQueueKey)
if err != nil {
q.log.Error().Err(err).Msg("cannot get queue index")
return err
}
queueIndex[id] = itemkey
err = q.lp.SetAccountData(acQueueKey, queueIndex)
err = q.lp.SetAccountData(ctx, acQueueKey, queueIndex)
if err != nil {
q.log.Error().Err(err).Msg("cannot save queue index")
return err
@@ -41,8 +42,8 @@ func (q *Queue) Add(id, from, to, data string) error {
}
// Remove from queue
func (q *Queue) Remove(id string) error {
index, err := q.lp.GetAccountData(acQueueKey)
func (q *Queue) Remove(ctx context.Context, id string) error {
index, err := q.lp.GetAccountData(ctx, acQueueKey)
if err != nil {
q.log.Error().Err(err).Msg("cannot get queue index")
return err
@@ -52,7 +53,7 @@ func (q *Queue) Remove(id string) error {
itemkey = acQueueKey + "." + id
}
delete(index, id)
err = q.lp.SetAccountData(acQueueKey, index)
err = q.lp.SetAccountData(ctx, acQueueKey, index)
if err != nil {
q.log.Error().Err(err).Msg("cannot update queue index")
return err
@@ -60,15 +61,15 @@ func (q *Queue) Remove(id string) error {
q.mu.Lock(itemkey)
defer q.mu.Unlock(itemkey)
return q.lp.SetAccountData(itemkey, map[string]string{})
return q.lp.SetAccountData(ctx, itemkey, map[string]string{})
}
// try to send email
func (q *Queue) try(itemkey string, maxRetries int) bool {
func (q *Queue) try(ctx context.Context, itemkey string, maxRetries int) bool {
q.mu.Lock(itemkey)
defer q.mu.Unlock(itemkey)
item, err := q.lp.GetAccountData(itemkey)
item, err := q.lp.GetAccountData(ctx, itemkey)
if err != nil {
q.log.Error().Err(err).Str("id", itemkey).Msg("cannot retrieve a queue item")
return false
@@ -92,7 +93,7 @@ func (q *Queue) try(itemkey string, maxRetries int) bool {
q.log.Info().Str("id", itemkey).Str("from", item["from"]).Str("to", item["to"]).Err(err).Msg("attempted to deliver email, but it's not ready yet")
attempts++
item["attempts"] = strconv.Itoa(attempts)
err = q.lp.SetAccountData(itemkey, item)
err = q.lp.SetAccountData(ctx, itemkey, item)
if err != nil {
q.log.Error().Err(err).Str("id", itemkey).Msg("cannot update attempt count on email")
}

View File

@@ -22,14 +22,14 @@ func (b *Bot) handleReaction(ctx context.Context) {
}
srcID := content.GetRelatesTo().EventID
srcEvt, err := b.lp.GetClient().GetEvent(evt.RoomID, srcID)
srcEvt, err := b.lp.GetClient().GetEvent(ctx, evt.RoomID, srcID)
if err != nil {
b.Error(ctx, "cannot find event %s: %v", srcID, err)
return
}
linkpearl.ParseContent(srcEvt, b.log)
if b.lp.GetMachine().StateStore.IsEncrypted(evt.RoomID) {
decrypted, derr := b.lp.GetClient().Crypto.Decrypt(srcEvt)
if ok, _ := b.lp.GetMachine().StateStore.IsEncrypted(ctx, evt.RoomID); ok { //nolint:errcheck // that's ok
decrypted, derr := b.lp.GetClient().Crypto.Decrypt(ctx, srcEvt)
if derr == nil {
srcEvt = decrypted
}

View File

@@ -3,7 +3,6 @@ package bot
import (
"context"
"maunium.net/go/mautrix"
"maunium.net/go/mautrix/event"
)
@@ -12,27 +11,27 @@ func (b *Bot) initSync() {
b.lp.OnEventType(
event.StateMember,
func(_ mautrix.EventSource, evt *event.Event) {
go b.onMembership(evt)
func(ctx context.Context, evt *event.Event) {
go b.onMembership(ctx, evt)
},
)
b.lp.OnEventType(
event.EventMessage,
func(_ mautrix.EventSource, evt *event.Event) {
go b.onMessage(evt)
func(ctx context.Context, evt *event.Event) {
go b.onMessage(ctx, evt)
},
)
b.lp.OnEventType(
event.EventReaction,
func(_ mautrix.EventSource, evt *event.Event) {
go b.onReaction(evt)
func(ctx context.Context, evt *event.Event) {
go b.onReaction(ctx, evt)
},
)
}
// joinPermit is called by linkpearl when processing "invite" events and deciding if rooms should be auto-joined or not
func (b *Bot) joinPermit(evt *event.Event) bool {
if !b.allowUsers(evt.Sender, evt.RoomID) {
func (b *Bot) joinPermit(ctx context.Context, evt *event.Event) bool {
if !b.allowUsers(ctx, evt.Sender, evt.RoomID) {
b.log.Debug().Str("userID", evt.Sender.String()).Msg("Rejecting room invitation from unallowed user")
return false
}
@@ -40,13 +39,13 @@ func (b *Bot) joinPermit(evt *event.Event) bool {
return true
}
func (b *Bot) onMembership(evt *event.Event) {
func (b *Bot) onMembership(ctx context.Context, evt *event.Event) {
// mautrix 0.15.x migration
if b.ignoreBefore >= evt.Timestamp {
return
}
ctx := newContext(evt)
ctx = newContext(ctx, evt)
evtType := evt.Content.AsMember().Membership
if evtType == event.MembershipJoin && evt.Sender == b.lp.GetClient().UserID {
@@ -61,7 +60,7 @@ func (b *Bot) onMembership(evt *event.Event) {
// Potentially handle other membership events in the future
}
func (b *Bot) onMessage(evt *event.Event) {
func (b *Bot) onMessage(ctx context.Context, evt *event.Event) {
// ignore own messages
if evt.Sender == b.lp.GetClient().UserID {
return
@@ -71,11 +70,11 @@ func (b *Bot) onMessage(evt *event.Event) {
return
}
ctx := newContext(evt)
ctx = newContext(ctx, evt)
b.handle(ctx)
}
func (b *Bot) onReaction(evt *event.Event) {
func (b *Bot) onReaction(ctx context.Context, evt *event.Event) {
// ignore own messages
if evt.Sender == b.lp.GetClient().UserID {
return
@@ -85,7 +84,7 @@ func (b *Bot) onReaction(evt *event.Event) {
return
}
ctx := newContext(evt)
ctx = newContext(ctx, evt)
b.handleReaction(ctx)
}
@@ -100,7 +99,7 @@ func (b *Bot) onBotJoin(ctx context.Context) {
return
}
b.sendIntroduction(evt.RoomID)
b.sendIntroduction(ctx, evt.RoomID)
b.sendHelp(ctx)
}
@@ -111,7 +110,7 @@ func (b *Bot) onLeave(ctx context.Context) {
b.log.Info().Str("eventID", evt.ID.String()).Msg("Suppressing already handled event")
return
}
members, err := b.lp.GetClient().StateStore.GetRoomJoinedOrInvitedMembers(evt.RoomID)
members, err := b.lp.GetClient().StateStore.GetRoomJoinedOrInvitedMembers(ctx, evt.RoomID)
if err != nil {
b.log.Error().Err(err).Str("roomID", evt.RoomID.String()).Msg("cannot get joined or invited members")
return
@@ -121,7 +120,7 @@ func (b *Bot) onLeave(ctx context.Context) {
if count == 1 && members[0] == b.lp.GetClient().UserID {
b.log.Info().Str("roomID", evt.RoomID.String()).Msg("no more users left in the room")
b.runStop(ctx)
_, err := b.lp.GetClient().LeaveRoom(evt.RoomID)
_, err := b.lp.GetClient().LeaveRoom(ctx, evt.RoomID)
if err != nil {
b.Error(ctx, "cannot leave empty room: %v", err)
}

View File

@@ -114,9 +114,10 @@ func initMatrix(cfg *config.Config) {
log.Fatal().Err(err).Msg("cannot initialize matrix bot")
}
psd := utils.NewPSD(cfg.PSD.URL, cfg.PSD.Login, cfg.PSD.Password, &log)
mxc = mxconfig.New(lp, &log)
q = queue.New(lp, mxc, &log)
mxb, err = bot.New(q, lp, &log, mxc, cfg.Proxies, cfg.Prefix, cfg.Domains, cfg.Admins, bot.MBXConfig(cfg.Mailboxes))
mxb, err = bot.New(q, lp, &log, mxc, psd, cfg.Proxies, cfg.Prefix, cfg.Domains, cfg.Admins, bot.MBXConfig(cfg.Mailboxes))
if err != nil {
log.Panic().Err(err).Msg("cannot start matrix bot")
}

View File

@@ -48,6 +48,11 @@ func New() *Config {
DSN: env.String("db.dsn", defaultConfig.DB.DSN),
Dialect: env.String("db.dialect", defaultConfig.DB.Dialect),
},
PSD: PSD{
URL: env.String("psd.url"),
Login: env.String("psd.login"),
Password: env.String("psd.password"),
},
Relay: Relay{
Host: env.String("relay.host", defaultConfig.Relay.Host),
Port: env.String("relay.port", defaultConfig.Relay.Port),

View File

@@ -38,6 +38,9 @@ type Config struct {
// DB config
DB DB
// PSD config
PSD PSD
// TLS config
TLS TLS
@@ -78,6 +81,12 @@ type Mailboxes struct {
Activation string
}
type PSD struct {
URL string
Login string
Password string
}
// Relay config
type Relay struct {
Host string

View File

@@ -108,8 +108,9 @@ func (e *Email) Mailbox(incoming bool) string {
return utils.Mailbox(e.From)
}
func (e *Email) contentHeader(threadID id.EventID, text *strings.Builder, options *ContentOptions) {
func (e *Email) contentHeader(threadID id.EventID, text *strings.Builder, options *ContentOptions, psd *utils.PSD) {
if options.Sender {
text.WriteString(psd.Status(e.From))
text.WriteString(e.From)
}
if options.Recipient {
@@ -125,8 +126,12 @@ func (e *Email) contentHeader(threadID id.EventID, text *strings.Builder, option
}
}
if options.CC && len(e.CC) > 0 {
ccs := make([]string, 0, len(e.CC))
for _, addr := range e.CC {
ccs = append(ccs, psd.Status(addr)+addr)
}
text.WriteString("\ncc: ")
text.WriteString(strings.Join(e.CC, ", "))
text.WriteString(strings.Join(ccs, ", "))
}
if options.Sender || options.Recipient || options.CC {
text.WriteString("\n\n")
@@ -146,10 +151,10 @@ func (e *Email) contentHeader(threadID id.EventID, text *strings.Builder, option
}
// Content converts the email object to a Matrix event content
func (e *Email) Content(threadID id.EventID, options *ContentOptions) *event.Content {
func (e *Email) Content(threadID id.EventID, options *ContentOptions, psd *utils.PSD) *event.Content {
var text strings.Builder
e.contentHeader(threadID, &text, options)
e.contentHeader(threadID, &text, options, psd)
if threadID != "" || (threadID == "" && !options.Threadify) {
if e.HTML != "" && options.HTML {

20
go.mod
View File

@@ -18,16 +18,16 @@ require (
github.com/mcnijman/go-emailaddress v1.1.0
github.com/mileusna/crontab v1.2.0
github.com/raja/argon2pw v1.0.2-0.20210910183755-a391af63bd39
github.com/rs/zerolog v1.31.0
gitlab.com/etke.cc/go/env v1.0.0
github.com/rs/zerolog v1.32.0
gitlab.com/etke.cc/go/env v1.1.0
gitlab.com/etke.cc/go/fswatcher v1.0.0
gitlab.com/etke.cc/go/healthchecks v1.0.1
gitlab.com/etke.cc/go/mxidwc v1.0.0
gitlab.com/etke.cc/go/secgen v1.1.1
gitlab.com/etke.cc/go/validator v1.0.6
gitlab.com/etke.cc/linkpearl v0.0.0-20231121221431-72443f33d266
golang.org/x/exp v0.0.0-20231110203233-9a3e6036ecaa
maunium.net/go/mautrix v0.16.2
gitlab.com/etke.cc/linkpearl v0.0.0-20240211143445-bddf907d137a
golang.org/x/exp v0.0.0-20240205201215-2c58cdc269a3
maunium.net/go/mautrix v0.17.0
)
require (
@@ -50,12 +50,12 @@ require (
github.com/tidwall/match v1.1.1 // indirect
github.com/tidwall/pretty v1.2.1 // indirect
github.com/tidwall/sjson v1.2.5 // indirect
github.com/yuin/goldmark v1.6.0 // indirect
github.com/yuin/goldmark v1.7.0 // indirect
gitlab.com/etke.cc/go/trysmtp v1.1.3 // indirect
go.mau.fi/util v0.2.1 // indirect
golang.org/x/crypto v0.15.0 // indirect
golang.org/x/net v0.18.0 // indirect
golang.org/x/sys v0.14.0 // indirect
go.mau.fi/util v0.3.0 // indirect
golang.org/x/crypto v0.19.0 // indirect
golang.org/x/net v0.21.0 // indirect
golang.org/x/sys v0.17.0 // indirect
golang.org/x/text v0.14.0 // indirect
maunium.net/go/maulogger/v2 v2.4.1 // indirect
)

44
go.sum
View File

@@ -1,6 +1,6 @@
blitiri.com.ar/go/spf v1.5.1 h1:CWUEasc44OrANJD8CzceRnRn1Jv0LttY68cYym2/pbE=
blitiri.com.ar/go/spf v1.5.1/go.mod h1:E71N92TfL4+Yyd5lpKuE9CAF2pd4JrUq1xQfkTxoNdk=
github.com/DATA-DOG/go-sqlmock v1.5.0 h1:Shsta01QNfFxHCfpW6YH2STWB0MudeXXEWMr20OEh60=
github.com/DATA-DOG/go-sqlmock v1.5.1 h1:FK6RCIUSfmbnI/imIICmboyQBkOckutaa6R5YYlLZyo=
github.com/archdx/zerolog-sentry v1.2.0 h1:FDFqlo5XvL/jpDAPoAWI15EjJQVFvixn70v3IH//eTM=
github.com/archdx/zerolog-sentry v1.2.0/go.mod h1:3H8gClGFafB90fKMsvfP017bdmkG5MD6UiA+6iPEwGw=
github.com/buger/jsonparser v1.0.0 h1:etJTGF5ESxjI0Ic2UaLQs2LQQpa8G9ykQScukbh4L8A=
@@ -55,8 +55,6 @@ github.com/mattn/go-isatty v0.0.20/go.mod h1:W+V8PltTTMOvKvAeJH7IuucS94S2C6jfK/D
github.com/mattn/go-runewidth v0.0.9/go.mod h1:H031xJmbD/WCDINGzjvQ9THkh0rPKHF+m2gUSrubnMI=
github.com/mattn/go-runewidth v0.0.12 h1:Y41i/hVW3Pgwr8gV+J23B9YEY0zxjptBuCWEaxmAOow=
github.com/mattn/go-runewidth v0.0.12/go.mod h1:RAqKPSqVFrSLVXbA8x7dzmKdmGzieGRCM46jaSJTDAk=
github.com/mattn/go-sqlite3 v1.14.18 h1:JL0eqdCOq6DJVNPSvArO/bIV9/P7fbGrV00LZHc+5aI=
github.com/mattn/go-sqlite3 v1.14.18/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg=
github.com/mattn/go-sqlite3 v1.14.19 h1:fhGleo2h1p8tVChob4I9HpmVFIAkKGpiukdrgQbWfGI=
github.com/mattn/go-sqlite3 v1.14.19/go.mod h1:2eHXhiwb8IkHr+BDWZGa96P6+rkvnG63S2DGjv9HUNg=
github.com/mcnijman/go-emailaddress v1.1.0 h1:7/Uxgn9pXwXmvXsFSgORo6XoRTrttj7AGmmB2yFArAg=
@@ -78,8 +76,8 @@ github.com/rivo/uniseg v0.1.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJ
github.com/rivo/uniseg v0.2.0 h1:S1pD9weZBuJdFmowNwbpi7BJ8TNftyUImj/0WQi72jY=
github.com/rivo/uniseg v0.2.0/go.mod h1:J6wj4VEh+S6ZtnVlnTBMWIodfgj8LQOQFoIToxlJtxc=
github.com/rs/xid v1.5.0/go.mod h1:trrq9SKmegXys3aeAKXMUTdJsYXVwGY3RLcfgqegfbg=
github.com/rs/zerolog v1.31.0 h1:FcTR3NnLWW+NnTwwhFWiJSZr4ECLpqCm6QsEnyvbV4A=
github.com/rs/zerolog v1.31.0/go.mod h1:/7mN4D5sKwJLZQ2b/znpjC3/GQWY/xaDXUM0kKWRHss=
github.com/rs/zerolog v1.32.0 h1:keLypqrlIjaFsbmJOBdB/qvyF8KEtCWHwobLp5l/mQ0=
github.com/rs/zerolog v1.32.0/go.mod h1:/7mN4D5sKwJLZQ2b/znpjC3/GQWY/xaDXUM0kKWRHss=
github.com/ssor/bom v0.0.0-20170718123548-6386211fdfcf h1:pvbZ0lM0XWPBqUKqFU8cmavspvIl9nulOYwdy6IFRRo=
github.com/ssor/bom v0.0.0-20170718123548-6386211fdfcf/go.mod h1:RJID2RhlZKId02nZ62WenDCkgHFerpIOmW0iT7GKmXM=
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
@@ -96,10 +94,12 @@ github.com/tidwall/pretty v1.2.1 h1:qjsOFOWWQl+N3RsoF5/ssm1pHmJJwhjlSbZ51I6wMl4=
github.com/tidwall/pretty v1.2.1/go.mod h1:ITEVvHYasfjBbM0u2Pg8T2nJnzm8xPwvNhhsoaGGjNU=
github.com/tidwall/sjson v1.2.5 h1:kLy8mja+1c9jlljvWTlSazM7cKDRfJuR/bOJhcY5NcY=
github.com/tidwall/sjson v1.2.5/go.mod h1:Fvgq9kS/6ociJEDnK0Fk1cpYF4FIW6ZF7LAe+6jwd28=
github.com/yuin/goldmark v1.6.0 h1:boZcn2GTjpsynOsC0iJHnBWa4Bi0qzfJjthwauItG68=
github.com/yuin/goldmark v1.6.0/go.mod h1:6yULJ656Px+3vBD8DxQVa3kxgyrAnzto9xy5taEt/CY=
github.com/yuin/goldmark v1.7.0 h1:EfOIvIMZIzHdB/R/zVrikYLPPwJlfMcNczJFMs1m6sA=
github.com/yuin/goldmark v1.7.0/go.mod h1:uzxRWxtg69N339t3louHJ7+O03ezfj6PlliRlaOzY1E=
gitlab.com/etke.cc/go/env v1.0.0 h1:J98BwzOuELnjsVPFvz5wa79L7IoRV9CmrS41xLYXtSw=
gitlab.com/etke.cc/go/env v1.0.0/go.mod h1:e1l4RM5MA1sc0R1w/RBDAESWRwgo5cOG9gx8BKUn2C4=
gitlab.com/etke.cc/go/env v1.1.0 h1:nbMhZkMu6C8lysRlb5siIiylWuyVkGAgEvwWEqz/82o=
gitlab.com/etke.cc/go/env v1.1.0/go.mod h1:e1l4RM5MA1sc0R1w/RBDAESWRwgo5cOG9gx8BKUn2C4=
gitlab.com/etke.cc/go/fswatcher v1.0.0 h1:uyiVn+1NVCjOLZrXSZouIDBDZBMwVipS4oYuvAFpPzo=
gitlab.com/etke.cc/go/fswatcher v1.0.0/go.mod h1:MqTOxyhXfvaVZQUL9/Ksbl2ow1PTBVu3eqIldvMq0RE=
gitlab.com/etke.cc/go/healthchecks v1.0.1 h1:IxPB+r4KtEM6wf4K7MeQoH1XnuBITMGUqFaaRIgxeUY=
@@ -112,21 +112,21 @@ gitlab.com/etke.cc/go/trysmtp v1.1.3 h1:e2EHond77onMaecqCg6mWumffTSEf+ycgj88nbee
gitlab.com/etke.cc/go/trysmtp v1.1.3/go.mod h1:lOO7tTdAE0a3ETV3wN3GJ7I1Tqewu7YTpPWaOmTteV0=
gitlab.com/etke.cc/go/validator v1.0.6 h1:w0Muxf9Pqw7xvF7NaaswE6d7r9U3nB2t2l5PnFMrecQ=
gitlab.com/etke.cc/go/validator v1.0.6/go.mod h1:Id0SxRj0J3IPhiKlj0w1plxVLZfHlkwipn7HfRZsDts=
gitlab.com/etke.cc/linkpearl v0.0.0-20231121221431-72443f33d266 h1:mGbLQkdE35WeyinqP38HC0dqUOJ7FItEAumVIOz7Gg8=
gitlab.com/etke.cc/linkpearl v0.0.0-20231121221431-72443f33d266/go.mod h1:wFEvngglb6ZTlE58/2a9gwYYs6V3FTYclYn5Pf5EGyQ=
go.mau.fi/util v0.2.1 h1:eazulhFE/UmjOFtPrGg6zkF5YfAyiDzQb8ihLMbsPWw=
go.mau.fi/util v0.2.1/go.mod h1:MjlzCQEMzJ+G8RsPawHzpLB8rwTo3aPIjG5FzBvQT/c=
gitlab.com/etke.cc/linkpearl v0.0.0-20240211143445-bddf907d137a h1:30WtX+uepGqyFnU7jIockJWxQUeYdljhhk63DCOXLZs=
gitlab.com/etke.cc/linkpearl v0.0.0-20240211143445-bddf907d137a/go.mod h1:3lqQGDDtk52Jm8PD3mZ3qhmIp4JXuq95waWH5vmEacc=
go.mau.fi/util v0.3.0 h1:Lt3lbRXP6ZBqTINK0EieRWor3zEwwwrDT14Z5N8RUCs=
go.mau.fi/util v0.3.0/go.mod h1:9dGsBCCbZJstx16YgnVMVi3O2bOizELoKpugLD4FoGs=
golang.org/x/crypto v0.0.0-20220518034528-6f7dac969898/go.mod h1:IxCIyHEi3zRg3s0A5j5BB6A9Jmi73HwBIUl50j+osU4=
golang.org/x/crypto v0.15.0 h1:frVn1TEaCEaZcn3Tmd7Y2b5KKPaZ+I32Q2OA3kYp5TA=
golang.org/x/crypto v0.15.0/go.mod h1:4ChreQoLWfG3xLDer1WdlH5NdlQ3+mwnQq1YTKY+72g=
golang.org/x/exp v0.0.0-20231110203233-9a3e6036ecaa h1:FRnLl4eNAQl8hwxVVC17teOw8kdjVDVAiFMtgUdTSRQ=
golang.org/x/exp v0.0.0-20231110203233-9a3e6036ecaa/go.mod h1:zk2irFbV9DP96SEBUUAy67IdHUaZuSnrz1n472HUCLE=
golang.org/x/crypto v0.19.0 h1:ENy+Az/9Y1vSrlrvBSyna3PITt4tiZLf7sgCjZBX7Wo=
golang.org/x/crypto v0.19.0/go.mod h1:Iy9bg/ha4yyC70EfRS8jz+B6ybOBKMaSxLj6P6oBDfU=
golang.org/x/exp v0.0.0-20240205201215-2c58cdc269a3 h1:/RIbNt/Zr7rVhIkQhooTxCxFcdWLGIKnZA4IXNFSrvo=
golang.org/x/exp v0.0.0-20240205201215-2c58cdc269a3/go.mod h1:idGWGoKP1toJGkd5/ig9ZLuPcZBC3ewk7SzmH0uou08=
golang.org/x/net v0.0.0-20180911220305-26e67e76b6c3/go.mod h1:mL1N/T3taQHkDXs73rZJwtUhF3w3ftmwwsq0BUmARs4=
golang.org/x/net v0.0.0-20210501142056-aec3718b3fa0/go.mod h1:OJAsFXCWl8Ukc7SiCT/9KSuxbyM7479/AVlXFRxuMCk=
golang.org/x/net v0.0.0-20211112202133-69e39bad7dc2/go.mod h1:9nx3DQGgdP8bBQD5qxJ1jj9UTztislL4KSBs9R2vV5Y=
golang.org/x/net v0.0.0-20220624214902-1bab6f366d9e/go.mod h1:XRhObCWvk6IyKnWLug+ECip1KBveYUHfp+8e9klMJ9c=
golang.org/x/net v0.18.0 h1:mIYleuAkSbHh0tCv7RvjL3F6ZVbLjq4+R7zbOn3Kokg=
golang.org/x/net v0.18.0/go.mod h1:/czyP5RqHAH4odGYxBJ1qz0+CE5WZ+2j1YgoEo8F2jQ=
golang.org/x/net v0.21.0 h1:AQyQV4dYCvJ7vGmJyKki9+PBdyvhkSd8EIx/qb0AYv4=
golang.org/x/net v0.21.0/go.mod h1:bIjVDfnllIU7BJ2DNgfnXvpSvtn8VRwhlsaeUTyUS44=
golang.org/x/sys v0.0.0-20201119102817-f84b799fce68/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210423082822-04245dca01da/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
golang.org/x/sys v0.0.0-20210615035016-665e8c7367d1/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
@@ -135,11 +135,11 @@ golang.org/x/sys v0.0.0-20220811171246-fbc7d0a398ab/go.mod h1:oPkhp1MJrh7nUepCBc
golang.org/x/sys v0.0.0-20220908164124-27713097b956/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.6.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.12.0/go.mod h1:oPkhp1MJrh7nUepCBck5+mAzfO9JrbApNNgaTdGDITg=
golang.org/x/sys v0.14.0 h1:Vz7Qs629MkJkGyHxUlRHizWJRG2j8fbQKjELVSNhy7Q=
golang.org/x/sys v0.14.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/sys v0.17.0 h1:25cE3gD+tdBA7lp7QfhuV+rJiE9YXTcS3VG1SqssI/Y=
golang.org/x/sys v0.17.0/go.mod h1:/VUhepiaJMQUp4+oa/7Zr1D23ma6VTLIYjOOTFZPUcA=
golang.org/x/term v0.0.0-20201126162022-7de9c90e9dd1/go.mod h1:bj7SfCRtBDWHUb9snDiAeCFNEtKQo2Wmx5Cou7ajbmo=
golang.org/x/term v0.0.0-20210927222741-03fcf44c2211/go.mod h1:jbD1KX2456YbFQfuXm/mYQcufACuNUgVhRMnK/tPxf8=
golang.org/x/term v0.14.0 h1:LGK9IlZ8T9jvdy6cTdfKUCltatMFOehAQo9SRC46UQ8=
golang.org/x/term v0.17.0 h1:mkTF7LCd6WGJNL3K1Ad7kwxNfYAW6a8a8QqtMblp/4U=
golang.org/x/text v0.3.2/go.mod h1:bEr9sfX3Q8Zfm5fL9x+3itogRgK3+ptLWKqgva+5dAk=
golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
@@ -152,5 +152,5 @@ gopkg.in/yaml.v3 v3.0.1 h1:fxVm/GzAzEWqLHuvctI91KS9hhNmmWOoWu0XTYJS7CA=
gopkg.in/yaml.v3 v3.0.1/go.mod h1:K4uyk7z7BCEPqu6E+C64Yfv1cQ7kz7rIZviUmN+EgEM=
maunium.net/go/maulogger/v2 v2.4.1 h1:N7zSdd0mZkB2m2JtFUsiGTQQAdP0YeFWT7YMc80yAL8=
maunium.net/go/maulogger/v2 v2.4.1/go.mod h1:omPuYwYBILeVQobz8uO3XC8DIRuEb5rXYlQSuqrbCho=
maunium.net/go/mautrix v0.16.2 h1:a6GUJXNWsTEOO8VE4dROBfCIfPp50mqaqzv7KPzChvg=
maunium.net/go/mautrix v0.16.2/go.mod h1:YL4l4rZB46/vj/ifRMEjcibbvHjgxHftOF1SgmruLu4=
maunium.net/go/mautrix v0.17.0 h1:scc1qlUbzPn+wc+3eAPquyD+3gZwwy/hBANBm+iGKK8=
maunium.net/go/mautrix v0.17.0/go.mod h1:j+puTEQCEydlVxhJ/dQP5chfa26TdvBO7X6F3Ataav8=

View File

@@ -1,6 +1,7 @@
package smtp
import (
"context"
"crypto/tls"
"net"
"sync"
@@ -15,10 +16,10 @@ type Listener struct {
tls *tls.Config
tlsMu sync.Mutex
listener net.Listener
isBanned func(net.Addr) bool
isBanned func(context.Context, net.Addr) bool
}
func NewListener(port string, tlsConfig *tls.Config, isBanned func(net.Addr) bool, log *zerolog.Logger) (*Listener, error) {
func NewListener(port string, tlsConfig *tls.Config, isBanned func(context.Context, net.Addr) bool, log *zerolog.Logger) (*Listener, error) {
actual, err := net.Listen("tcp", ":"+port)
if err != nil {
return nil, err
@@ -52,7 +53,7 @@ func (l *Listener) Accept() (net.Conn, error) {
continue
}
}
if l.isBanned(conn.RemoteAddr()) {
if l.isBanned(context.Background(), conn.RemoteAddr()) {
conn.Close()
l.log.Info().Str("addr", conn.RemoteAddr().String()).Msg("rejected connection (already banned)")
continue

View File

@@ -60,16 +60,16 @@ type Manager struct {
}
type matrixbot interface {
AllowAuth(string, string) (id.RoomID, bool)
IsGreylisted(net.Addr) bool
IsBanned(net.Addr) bool
AllowAuth(context.Context, string, string) (id.RoomID, bool)
IsGreylisted(context.Context, net.Addr) bool
IsBanned(context.Context, net.Addr) bool
IsTrusted(net.Addr) bool
BanAuto(net.Addr)
BanAuth(net.Addr)
GetMapping(string) (id.RoomID, bool)
GetIFOptions(id.RoomID) email.IncomingFilteringOptions
BanAuto(context.Context, net.Addr)
BanAuth(context.Context, net.Addr)
GetMapping(context.Context, string) (id.RoomID, bool)
GetIFOptions(context.Context, id.RoomID) email.IncomingFilteringOptions
IncomingEmail(context.Context, *email.Email) error
GetDKIMprivkey() string
GetDKIMprivkey(context.Context) string
}
// Caller is Sendmail caller

View File

@@ -46,27 +46,28 @@ type mailServer struct {
// Login used for outgoing mail submissions only (when you use postmoogle as smtp server in your scripts)
func (m *mailServer) Login(state *smtp.ConnectionState, username, password string) (smtp.Session, error) {
m.log.Debug().Str("username", username).Any("state", state).Msg("Login")
if m.bot.IsBanned(state.RemoteAddr) {
ctx := context.Background()
if m.bot.IsBanned(ctx, state.RemoteAddr) {
return nil, ErrBanned
}
if !email.AddressValid(username) {
m.log.Debug().Str("address", username).Msg("address is invalid")
m.bot.BanAuth(state.RemoteAddr)
m.bot.BanAuth(ctx, state.RemoteAddr)
return nil, ErrBanned
}
roomID, allow := m.bot.AllowAuth(username, password)
roomID, allow := m.bot.AllowAuth(ctx, username, password)
if !allow {
m.log.Debug().Str("username", username).Msg("username or password is invalid")
m.bot.BanAuth(state.RemoteAddr)
m.bot.BanAuth(ctx, state.RemoteAddr)
return nil, ErrBanned
}
return &outgoingSession{
ctx: sentry.SetHubOnContext(context.Background(), sentry.CurrentHub().Clone()),
sendmail: m.sender.Send,
privkey: m.bot.GetDKIMprivkey(),
privkey: m.bot.GetDKIMprivkey(ctx),
from: username,
log: m.log,
domains: m.domains,
@@ -79,7 +80,8 @@ func (m *mailServer) Login(state *smtp.ConnectionState, username, password strin
// AnonymousLogin used for incoming mail submissions only
func (m *mailServer) AnonymousLogin(state *smtp.ConnectionState) (smtp.Session, error) {
m.log.Debug().Any("state", state).Msg("AnonymousLogin")
if m.bot.IsBanned(state.RemoteAddr) {
ctx := context.Background()
if m.bot.IsBanned(ctx, state.RemoteAddr) {
return nil, ErrBanned
}

View File

@@ -33,12 +33,12 @@ var (
// incomingSession represents an SMTP-submission session receiving emails from remote servers
type incomingSession struct {
log *zerolog.Logger
getRoomID func(string) (id.RoomID, bool)
getFilters func(id.RoomID) email.IncomingFilteringOptions
getRoomID func(context.Context, string) (id.RoomID, bool)
getFilters func(context.Context, id.RoomID) email.IncomingFilteringOptions
receiveEmail func(context.Context, *email.Email) error
greylisted func(net.Addr) bool
greylisted func(context.Context, net.Addr) bool
trusted func(net.Addr) bool
ban func(net.Addr)
ban func(context.Context, net.Addr)
domains []string
roomID id.RoomID
@@ -52,7 +52,7 @@ func (s *incomingSession) Mail(from string, opts smtp.MailOptions) error {
sentry.GetHubFromContext(s.ctx).Scope().SetTag("from", from)
if !email.AddressValid(from) {
s.log.Debug().Str("from", from).Msg("address is invalid")
s.ban(s.addr)
s.ban(s.ctx, s.addr)
return ErrBanned
}
s.from = email.Address(from)
@@ -77,7 +77,7 @@ func (s *incomingSession) Rcpt(to string) error {
}
var ok bool
s.roomID, ok = s.getRoomID(utils.Mailbox(to))
s.roomID, ok = s.getRoomID(s.ctx, utils.Mailbox(to))
if !ok {
s.log.Debug().Str("to", to).Msg("mapping not found")
return ErrNoUser
@@ -126,12 +126,12 @@ func (s *incomingSession) Data(r io.Reader) error {
}
addr := s.getAddr(envelope)
reader.Seek(0, io.SeekStart) //nolint:errcheck // becase we're sure that's ok
validations := s.getFilters(s.roomID)
validations := s.getFilters(s.ctx, s.roomID)
if !validateIncoming(s.from, s.tos[0], addr, s.log, validations) {
s.ban(addr)
s.ban(s.ctx, addr)
return ErrBanned
}
if s.greylisted(addr) {
if s.greylisted(s.ctx, addr) {
return &smtp.SMTPError{
Code: GraylistCode,
EnhancedCode: GraylistEnhancedCode,
@@ -172,7 +172,7 @@ type outgoingSession struct {
sendmail func(string, string, string) error
privkey string
domains []string
getRoomID func(string) (id.RoomID, bool)
getRoomID func(context.Context, string) (id.RoomID, bool)
ctx context.Context //nolint:containedctx // that's session
tos []string
@@ -198,7 +198,7 @@ func (s *outgoingSession) Mail(from string, _ smtp.MailOptions) error {
return ErrNoUser
}
roomID, ok := s.getRoomID(utils.Mailbox(from))
roomID, ok := s.getRoomID(s.ctx, utils.Mailbox(from))
if !ok {
s.log.Debug().Str("from", from).Msg("mapping not found")
return ErrNoUser

112
utils/psd.go Normal file
View File

@@ -0,0 +1,112 @@
package utils
import (
"context"
"encoding/json"
"fmt"
"io"
"net/http"
"net/url"
"sync"
"time"
"github.com/rs/zerolog"
)
//nolint:gocritic // sync.Mutex is intended
type PSD struct {
sync.Mutex
cachedAt time.Time
cache map[string]bool
log *zerolog.Logger
url *url.URL
login string
password string
}
type PSDTarget struct {
Targets []string `json:"targets"`
Labels map[string]string `json:"labels"`
}
func NewPSD(baseURL, login, password string, log *zerolog.Logger) *PSD {
uri, err := url.Parse(baseURL)
if err != nil || login == "" || password == "" {
return &PSD{}
}
return &PSD{url: uri, login: login, password: password, log: log}
}
func (p *PSD) Contains(email string) (bool, error) {
if p.cachedAt.IsZero() || time.Since(p.cachedAt) > 10*time.Minute {
err := p.updateCache()
if err != nil {
return false, err
}
}
p.Lock()
defer p.Unlock()
return p.cache[email], nil
}
func (p *PSD) Status(email string) string {
ok, err := p.Contains(email)
if !ok || err != nil {
return ""
}
return "👤"
}
func (p *PSD) updateCache() error {
p.Lock()
defer p.Unlock()
defer func() {
p.cachedAt = time.Now()
}()
if p.url == nil {
return nil
}
cloned := *p.url
uri := cloned.JoinPath("/emails")
ctx, cancel := context.WithTimeout(context.Background(), 10*time.Second)
defer cancel()
req, err := http.NewRequestWithContext(ctx, http.MethodGet, uri.String(), http.NoBody)
if err != nil {
p.log.Error().Err(err).Msg("failed to create request")
return err
}
req.SetBasicAuth(p.login, p.password)
resp, err := http.DefaultClient.Do(req)
if err != nil {
return err
}
defer resp.Body.Close()
if resp.StatusCode != http.StatusOK {
err = fmt.Errorf("%s", resp.Status) //nolint:goerr113 // no need to wrap
p.log.Error().Err(err).Msg("failed to fetch PSD")
return err
}
datab, err := io.ReadAll(resp.Body)
if err != nil {
p.log.Error().Err(err).Msg("failed to read response")
return err
}
var psd []*PSDTarget
err = json.Unmarshal(datab, &psd)
if err != nil {
p.log.Error().Err(err).Msg("failed to unmarshal response")
return err
}
p.cache = make(map[string]bool)
for _, t := range psd {
for _, email := range t.Targets {
p.cache[email] = true
}
}
return nil
}

View File

@@ -547,7 +547,7 @@ and facilitates the unification of logging and tracing in some systems:
type TracingHook struct{}
func (h TracingHook) Run(e *zerolog.Event, level zerolog.Level, msg string) {
ctx := e.Ctx()
ctx := e.GetCtx()
spanId := getSpanIdFromContext(ctx) // as per your tracing framework
e.Str("span-id", spanId)
}

View File

@@ -76,6 +76,8 @@ type ConsoleWriter struct {
FormatErrFieldValue Formatter
FormatExtra func(map[string]interface{}, *bytes.Buffer) error
FormatPrepare func(map[string]interface{}) error
}
// NewConsoleWriter creates and initializes a new ConsoleWriter.
@@ -124,6 +126,13 @@ func (w ConsoleWriter) Write(p []byte) (n int, err error) {
return n, fmt.Errorf("cannot decode event: %s", err)
}
if w.FormatPrepare != nil {
err = w.FormatPrepare(evt)
if err != nil {
return n, err
}
}
for _, p := range w.PartsOrder {
w.writePart(buf, evt, p)
}
@@ -146,6 +155,15 @@ func (w ConsoleWriter) Write(p []byte) (n int, err error) {
return len(p), err
}
// Call the underlying writer's Close method if it is an io.Closer. Otherwise
// does nothing.
func (w ConsoleWriter) Close() error {
if closer, ok := w.Out.(io.Closer); ok {
return closer.Close()
}
return nil
}
// writeFields appends formatted key-value pairs to buf.
func (w ConsoleWriter) writeFields(evt map[string]interface{}, buf *bytes.Buffer) {
var fields = make([]string, 0, len(evt))
@@ -272,7 +290,7 @@ func (w ConsoleWriter) writePart(buf *bytes.Buffer, evt map[string]interface{},
}
case MessageFieldName:
if w.FormatMessage == nil {
f = consoleDefaultFormatMessage
f = consoleDefaultFormatMessage(w.NoColor, evt[LevelFieldName])
} else {
f = w.FormatMessage
}
@@ -310,10 +328,10 @@ func needsQuote(s string) bool {
return false
}
// colorize returns the string s wrapped in ANSI code c, unless disabled is true.
// colorize returns the string s wrapped in ANSI code c, unless disabled is true or c is 0.
func colorize(s interface{}, c int, disabled bool) string {
e := os.Getenv("NO_COLOR")
if e != "" {
if e != "" || c == 0 {
disabled = true
}
@@ -378,27 +396,16 @@ func consoleDefaultFormatLevel(noColor bool) Formatter {
return func(i interface{}) string {
var l string
if ll, ok := i.(string); ok {
switch ll {
case LevelTraceValue:
l = colorize("TRC", colorMagenta, noColor)
case LevelDebugValue:
l = colorize("DBG", colorYellow, noColor)
case LevelInfoValue:
l = colorize("INF", colorGreen, noColor)
case LevelWarnValue:
l = colorize("WRN", colorRed, noColor)
case LevelErrorValue:
l = colorize(colorize("ERR", colorRed, noColor), colorBold, noColor)
case LevelFatalValue:
l = colorize(colorize("FTL", colorRed, noColor), colorBold, noColor)
case LevelPanicValue:
l = colorize(colorize("PNC", colorRed, noColor), colorBold, noColor)
default:
l = colorize(ll, colorBold, noColor)
level, _ := ParseLevel(ll)
fl, ok := FormattedLevels[level]
if ok {
l = colorize(fl, LevelColors[level], noColor)
} else {
l = strings.ToUpper(ll)[0:3]
}
} else {
if i == nil {
l = colorize("???", colorBold, noColor)
l = "???"
} else {
l = strings.ToUpper(fmt.Sprintf("%s", i))[0:3]
}
@@ -425,11 +432,18 @@ func consoleDefaultFormatCaller(noColor bool) Formatter {
}
}
func consoleDefaultFormatMessage(i interface{}) string {
if i == nil {
return ""
func consoleDefaultFormatMessage(noColor bool, level interface{}) Formatter {
return func(i interface{}) string {
if i == nil || i == "" {
return ""
}
switch level {
case LevelInfoValue, LevelWarnValue, LevelErrorValue, LevelFatalValue, LevelPanicValue:
return colorize(fmt.Sprintf("%s", i), colorBold, noColor)
default:
return fmt.Sprintf("%s", i)
}
}
return fmt.Sprintf("%s", i)
}
func consoleDefaultFormatFieldName(noColor bool) Formatter {
@@ -450,6 +464,6 @@ func consoleDefaultFormatErrFieldName(noColor bool) Formatter {
func consoleDefaultFormatErrFieldValue(noColor bool) Formatter {
return func(i interface{}) string {
return colorize(fmt.Sprintf("%s", i), colorRed, noColor)
return colorize(colorize(fmt.Sprintf("%s", i), colorBold, noColor), colorRed, noColor)
}
}

View File

@@ -3,7 +3,7 @@ package zerolog
import (
"context"
"fmt"
"io/ioutil"
"io"
"math"
"net"
"time"
@@ -23,7 +23,7 @@ func (c Context) Logger() Logger {
// Only map[string]interface{} and []interface{} are accepted. []interface{} must
// alternate string keys and arbitrary values, and extraneous ones are ignored.
func (c Context) Fields(fields interface{}) Context {
c.l.context = appendFields(c.l.context, fields)
c.l.context = appendFields(c.l.context, fields, c.l.stack)
return c
}
@@ -57,7 +57,7 @@ func (c Context) Array(key string, arr LogArrayMarshaler) Context {
// Object marshals an object that implement the LogObjectMarshaler interface.
func (c Context) Object(key string, obj LogObjectMarshaler) Context {
e := newEvent(LevelWriterAdapter{ioutil.Discard}, 0)
e := newEvent(LevelWriterAdapter{io.Discard}, 0)
e.Object(key, obj)
c.l.context = enc.AppendObjectData(c.l.context, e.buf)
putEvent(e)
@@ -66,7 +66,7 @@ func (c Context) Object(key string, obj LogObjectMarshaler) Context {
// EmbedObject marshals and Embeds an object that implement the LogObjectMarshaler interface.
func (c Context) EmbedObject(obj LogObjectMarshaler) Context {
e := newEvent(LevelWriterAdapter{ioutil.Discard}, 0)
e := newEvent(LevelWriterAdapter{io.Discard}, 0)
e.EmbedObject(obj)
c.l.context = enc.AppendObjectData(c.l.context, e.buf)
putEvent(e)
@@ -163,6 +163,22 @@ func (c Context) Errs(key string, errs []error) Context {
// Err adds the field "error" with serialized err to the logger context.
func (c Context) Err(err error) Context {
if c.l.stack && ErrorStackMarshaler != nil {
switch m := ErrorStackMarshaler(err).(type) {
case nil:
case LogObjectMarshaler:
c = c.Object(ErrorStackFieldName, m)
case error:
if m != nil && !isNilValue(m) {
c = c.Str(ErrorStackFieldName, m.Error())
}
case string:
c = c.Str(ErrorStackFieldName, m)
default:
c = c.Interface(ErrorStackFieldName, m)
}
}
return c.AnErr(ErrorFieldName, err)
}
@@ -375,10 +391,19 @@ func (c Context) Durs(key string, d []time.Duration) Context {
// Interface adds the field key with obj marshaled using reflection.
func (c Context) Interface(key string, i interface{}) Context {
if obj, ok := i.(LogObjectMarshaler); ok {
return c.Object(key, obj)
}
c.l.context = enc.AppendInterface(enc.AppendKey(c.l.context, key), i)
return c
}
// Type adds the field key with val's type using reflection.
func (c Context) Type(key string, val interface{}) Context {
c.l.context = enc.AppendType(enc.AppendKey(c.l.context, key), val)
return c
}
// Any is a wrapper around Context.Interface.
func (c Context) Any(key string, i interface{}) Context {
return c.Interface(key, i)

View File

@@ -164,7 +164,7 @@ func (e *Event) Fields(fields interface{}) *Event {
if e == nil {
return e
}
e.buf = appendFields(e.buf, fields)
e.buf = appendFields(e.buf, fields, e.stack)
return e
}

7
vendor/github.com/rs/zerolog/example.jsonl generated vendored Normal file
View File

@@ -0,0 +1,7 @@
{"time":"5:41PM","level":"info","message":"Starting listener","listen":":8080","pid":37556}
{"time":"5:41PM","level":"debug","message":"Access","database":"myapp","host":"localhost:4962","pid":37556}
{"time":"5:41PM","level":"info","message":"Access","method":"GET","path":"/users","pid":37556,"resp_time":23}
{"time":"5:41PM","level":"info","message":"Access","method":"POST","path":"/posts","pid":37556,"resp_time":532}
{"time":"5:41PM","level":"warn","message":"Slow request","method":"POST","path":"/posts","pid":37556,"resp_time":532}
{"time":"5:41PM","level":"info","message":"Access","method":"GET","path":"/users","pid":37556,"resp_time":10}
{"time":"5:41PM","level":"error","message":"Database connection lost","database":"myapp","pid":37556,"error":"connection reset by peer"}

View File

@@ -12,13 +12,13 @@ func isNilValue(i interface{}) bool {
return (*[2]uintptr)(unsafe.Pointer(&i))[1] == 0
}
func appendFields(dst []byte, fields interface{}) []byte {
func appendFields(dst []byte, fields interface{}, stack bool) []byte {
switch fields := fields.(type) {
case []interface{}:
if n := len(fields); n&0x1 == 1 { // odd number
fields = fields[:n-1]
}
dst = appendFieldList(dst, fields)
dst = appendFieldList(dst, fields, stack)
case map[string]interface{}:
keys := make([]string, 0, len(fields))
for key := range fields {
@@ -28,13 +28,13 @@ func appendFields(dst []byte, fields interface{}) []byte {
kv := make([]interface{}, 2)
for _, key := range keys {
kv[0], kv[1] = key, fields[key]
dst = appendFieldList(dst, kv)
dst = appendFieldList(dst, kv, stack)
}
}
return dst
}
func appendFieldList(dst []byte, kvList []interface{}) []byte {
func appendFieldList(dst []byte, kvList []interface{}, stack bool) []byte {
for i, n := 0, len(kvList); i < n; i += 2 {
key, val := kvList[i], kvList[i+1]
if key, ok := key.(string); ok {
@@ -74,6 +74,21 @@ func appendFieldList(dst []byte, kvList []interface{}) []byte {
default:
dst = enc.AppendInterface(dst, m)
}
if stack && ErrorStackMarshaler != nil {
dst = enc.AppendKey(dst, ErrorStackFieldName)
switch m := ErrorStackMarshaler(val).(type) {
case nil:
case error:
if m != nil && !isNilValue(m) {
dst = enc.AppendString(dst, m.Error())
}
case string:
dst = enc.AppendString(dst, m)
default:
dst = enc.AppendInterface(dst, m)
}
}
case []error:
dst = enc.AppendArrayStart(dst)
for i, err := range val {

View File

@@ -108,6 +108,34 @@ var (
// DefaultContextLogger is returned from Ctx() if there is no logger associated
// with the context.
DefaultContextLogger *Logger
// LevelColors are used by ConsoleWriter's consoleDefaultFormatLevel to color
// log levels.
LevelColors = map[Level]int{
TraceLevel: colorBlue,
DebugLevel: 0,
InfoLevel: colorGreen,
WarnLevel: colorYellow,
ErrorLevel: colorRed,
FatalLevel: colorRed,
PanicLevel: colorRed,
}
// FormattedLevels are used by ConsoleWriter's consoleDefaultFormatLevel
// for a short level name.
FormattedLevels = map[Level]string{
TraceLevel: "TRC",
DebugLevel: "DBG",
InfoLevel: "INF",
WarnLevel: "WRN",
ErrorLevel: "ERR",
FatalLevel: "FTL",
PanicLevel: "PNC",
}
// TriggerLevelWriterBufferReuseLimit is a limit in bytes that a buffer is dropped
// from the TriggerLevelWriter buffer pool if the buffer grows above the limit.
TriggerLevelWriterBufferReuseLimit = 64 * 1024
)
var (

32
vendor/github.com/rs/zerolog/log.go generated vendored
View File

@@ -118,7 +118,6 @@ import (
"errors"
"fmt"
"io"
"io/ioutil"
"os"
"strconv"
"strings"
@@ -246,7 +245,7 @@ type Logger struct {
// you may consider using sync wrapper.
func New(w io.Writer) Logger {
if w == nil {
w = ioutil.Discard
w = io.Discard
}
lw, ok := w.(LevelWriter)
if !ok {
@@ -326,10 +325,13 @@ func (l Logger) Sample(s Sampler) Logger {
}
// Hook returns a logger with the h Hook.
func (l Logger) Hook(h Hook) Logger {
newHooks := make([]Hook, len(l.hooks), len(l.hooks)+1)
func (l Logger) Hook(hooks ...Hook) Logger {
if len(hooks) == 0 {
return l
}
newHooks := make([]Hook, len(l.hooks), len(l.hooks)+len(hooks))
copy(newHooks, l.hooks)
l.hooks = append(newHooks, h)
l.hooks = append(newHooks, hooks...)
return l
}
@@ -385,7 +387,14 @@ func (l *Logger) Err(err error) *Event {
//
// You must call Msg on the returned event in order to send the event.
func (l *Logger) Fatal() *Event {
return l.newEvent(FatalLevel, func(msg string) { os.Exit(1) })
return l.newEvent(FatalLevel, func(msg string) {
if closer, ok := l.w.(io.Closer); ok {
// Close the writer to flush any buffered message. Otherwise the message
// will be lost as os.Exit() terminates the program immediately.
closer.Close()
}
os.Exit(1)
})
}
// Panic starts a new message with panic level. The panic() function
@@ -450,6 +459,14 @@ func (l *Logger) Printf(format string, v ...interface{}) {
}
}
// Println sends a log event using debug level and no extra field.
// Arguments are handled in the manner of fmt.Println.
func (l *Logger) Println(v ...interface{}) {
if e := l.Debug(); e.Enabled() {
e.CallerSkipFrame(1).Msg(fmt.Sprintln(v...))
}
}
// Write implements the io.Writer interface. This is useful to set as a writer
// for the standard library log.
func (l Logger) Write(p []byte) (n int, err error) {
@@ -488,6 +505,9 @@ func (l *Logger) newEvent(level Level, done func(string)) *Event {
// should returns true if the log event should be logged.
func (l *Logger) should(lvl Level) bool {
if l.w == nil {
return false
}
if lvl < l.level || lvl < GlobalLevel() {
return false
}

Binary file not shown.

Before

Width:  |  Height:  |  Size: 82 KiB

After

Width:  |  Height:  |  Size: 116 KiB

View File

@@ -78,3 +78,12 @@ func (sw syslogWriter) WriteLevel(level Level, p []byte) (n int, err error) {
n = len(p)
return
}
// Call the underlying writer's Close method if it is an io.Closer. Otherwise
// does nothing.
func (sw syslogWriter) Close() error {
if c, ok := sw.w.(io.Closer); ok {
return c.Close()
}
return nil
}

View File

@@ -27,6 +27,15 @@ func (lw LevelWriterAdapter) WriteLevel(l Level, p []byte) (n int, err error) {
return lw.Write(p)
}
// Call the underlying writer's Close method if it is an io.Closer. Otherwise
// does nothing.
func (lw LevelWriterAdapter) Close() error {
if closer, ok := lw.Writer.(io.Closer); ok {
return closer.Close()
}
return nil
}
type syncWriter struct {
mu sync.Mutex
lw LevelWriter
@@ -57,6 +66,15 @@ func (s *syncWriter) WriteLevel(l Level, p []byte) (n int, err error) {
return s.lw.WriteLevel(l, p)
}
func (s *syncWriter) Close() error {
s.mu.Lock()
defer s.mu.Unlock()
if closer, ok := s.lw.(io.Closer); ok {
return closer.Close()
}
return nil
}
type multiLevelWriter struct {
writers []LevelWriter
}
@@ -89,6 +107,20 @@ func (t multiLevelWriter) WriteLevel(l Level, p []byte) (n int, err error) {
return n, err
}
// Calls close on all the underlying writers that are io.Closers. If any of the
// Close methods return an error, the remainder of the closers are not closed
// and the error is returned.
func (t multiLevelWriter) Close() error {
for _, w := range t.writers {
if closer, ok := w.(io.Closer); ok {
if err := closer.Close(); err != nil {
return err
}
}
}
return nil
}
// MultiLevelWriter creates a writer that duplicates its writes to all the
// provided writers, similar to the Unix tee(1) command. If some writers
// implement LevelWriter, their WriteLevel method will be used instead of Write.
@@ -180,3 +212,135 @@ func (w *FilteredLevelWriter) WriteLevel(level Level, p []byte) (int, error) {
}
return len(p), nil
}
var triggerWriterPool = &sync.Pool{
New: func() interface{} {
return bytes.NewBuffer(make([]byte, 0, 1024))
},
}
// TriggerLevelWriter buffers log lines at the ConditionalLevel or below
// until a trigger level (or higher) line is emitted. Log lines with level
// higher than ConditionalLevel are always written out to the destination
// writer. If trigger never happens, buffered log lines are never written out.
//
// It can be used to configure "log level per request".
type TriggerLevelWriter struct {
// Destination writer. If LevelWriter is provided (usually), its WriteLevel is used
// instead of Write.
io.Writer
// ConditionalLevel is the level (and below) at which lines are buffered until
// a trigger level (or higher) line is emitted. Usually this is set to DebugLevel.
ConditionalLevel Level
// TriggerLevel is the lowest level that triggers the sending of the conditional
// level lines. Usually this is set to ErrorLevel.
TriggerLevel Level
buf *bytes.Buffer
triggered bool
mu sync.Mutex
}
func (w *TriggerLevelWriter) WriteLevel(l Level, p []byte) (n int, err error) {
w.mu.Lock()
defer w.mu.Unlock()
// At first trigger level or above log line, we flush the buffer and change the
// trigger state to triggered.
if !w.triggered && l >= w.TriggerLevel {
err := w.trigger()
if err != nil {
return 0, err
}
}
// Unless triggered, we buffer everything at and below ConditionalLevel.
if !w.triggered && l <= w.ConditionalLevel {
if w.buf == nil {
w.buf = triggerWriterPool.Get().(*bytes.Buffer)
}
// We prefix each log line with a byte with the level.
// Hopefully we will never have a level value which equals a newline
// (which could interfere with reconstruction of log lines in the trigger method).
w.buf.WriteByte(byte(l))
w.buf.Write(p)
return len(p), nil
}
// Anything above ConditionalLevel is always passed through.
// Once triggered, everything is passed through.
if lw, ok := w.Writer.(LevelWriter); ok {
return lw.WriteLevel(l, p)
}
return w.Write(p)
}
// trigger expects lock to be held.
func (w *TriggerLevelWriter) trigger() error {
if w.triggered {
return nil
}
w.triggered = true
if w.buf == nil {
return nil
}
p := w.buf.Bytes()
for len(p) > 0 {
// We do not use bufio.Scanner here because we already have full buffer
// in the memory and we do not want extra copying from the buffer to
// scanner's token slice, nor we want to hit scanner's token size limit,
// and we also want to preserve newlines.
i := bytes.IndexByte(p, '\n')
line := p[0 : i+1]
p = p[i+1:]
// We prefixed each log line with a byte with the level.
level := Level(line[0])
line = line[1:]
var err error
if lw, ok := w.Writer.(LevelWriter); ok {
_, err = lw.WriteLevel(level, line)
} else {
_, err = w.Write(line)
}
if err != nil {
return err
}
}
return nil
}
// Trigger forces flushing the buffer and change the trigger state to
// triggered, if the writer has not already been triggered before.
func (w *TriggerLevelWriter) Trigger() error {
w.mu.Lock()
defer w.mu.Unlock()
return w.trigger()
}
// Close closes the writer and returns the buffer to the pool.
func (w *TriggerLevelWriter) Close() error {
w.mu.Lock()
defer w.mu.Unlock()
if w.buf == nil {
return nil
}
// We return the buffer only if it has not grown above the limit.
// This prevents accumulation of large buffers in the pool just
// because occasionally a large buffer might be needed.
if w.buf.Cap() <= TriggerLevelWriterBufferReuseLimit {
w.buf.Reset()
triggerWriterPool.Put(w.buf)
}
w.buf = nil
return nil
}

View File

@@ -8,7 +8,7 @@ goldmark
> A Markdown parser written in Go. Easy to extend, standards-compliant, well-structured.
goldmark is compliant with CommonMark 0.30.
goldmark is compliant with CommonMark 0.31.2.
Motivation
----------------------
@@ -260,7 +260,7 @@ You can override autolinking patterns via options.
| Functional option | Type | Description |
| ----------------- | ---- | ----------- |
| `extension.WithLinkifyAllowedProtocols` | `[][]byte` | List of allowed protocols such as `[][]byte{ []byte("http:") }` |
| `extension.WithLinkifyAllowedProtocols` | `[][]byte \| []string` | List of allowed protocols such as `[]string{ "http:" }` |
| `extension.WithLinkifyURLRegexp` | `*regexp.Regexp` | Regexp that defines URLs, including protocols |
| `extension.WithLinkifyWWWRegexp` | `*regexp.Regexp` | Regexp that defines URL starting with `www.`. This pattern corresponds to [the extended www autolink](https://github.github.com/gfm/#extended-www-autolink) |
| `extension.WithLinkifyEmailRegexp` | `*regexp.Regexp` | Regexp that defines email addresses` |
@@ -277,9 +277,9 @@ markdown := goldmark.New(
),
goldmark.WithExtensions(
extension.NewLinkify(
extension.WithLinkifyAllowedProtocols([][]byte{
[]byte("http:"),
[]byte("https:"),
extension.WithLinkifyAllowedProtocols([]string{
"http:",
"https:",
}),
extension.WithLinkifyURLRegexp(
xurls.Strict,
@@ -297,13 +297,13 @@ This extension has some options:
| Functional option | Type | Description |
| ----------------- | ---- | ----------- |
| `extension.WithFootnoteIDPrefix` | `[]byte` | a prefix for the id attributes.|
| `extension.WithFootnoteIDPrefix` | `[]byte \| string` | a prefix for the id attributes.|
| `extension.WithFootnoteIDPrefixFunction` | `func(gast.Node) []byte` | a function that determines the id attribute for given Node.|
| `extension.WithFootnoteLinkTitle` | `[]byte` | an optional title attribute for footnote links.|
| `extension.WithFootnoteBacklinkTitle` | `[]byte` | an optional title attribute for footnote backlinks. |
| `extension.WithFootnoteLinkClass` | `[]byte` | a class for footnote links. This defaults to `footnote-ref`. |
| `extension.WithFootnoteBacklinkClass` | `[]byte` | a class for footnote backlinks. This defaults to `footnote-backref`. |
| `extension.WithFootnoteBacklinkHTML` | `[]byte` | a class for footnote backlinks. This defaults to `&#x21a9;&#xfe0e;`. |
| `extension.WithFootnoteLinkTitle` | `[]byte \| string` | an optional title attribute for footnote links.|
| `extension.WithFootnoteBacklinkTitle` | `[]byte \| string` | an optional title attribute for footnote backlinks. |
| `extension.WithFootnoteLinkClass` | `[]byte \| string` | a class for footnote links. This defaults to `footnote-ref`. |
| `extension.WithFootnoteBacklinkClass` | `[]byte \| string` | a class for footnote backlinks. This defaults to `footnote-backref`. |
| `extension.WithFootnoteBacklinkHTML` | `[]byte \| string` | a class for footnote backlinks. This defaults to `&#x21a9;&#xfe0e;`. |
Some options can have special substitutions. Occurrences of “^^” in the string will be replaced by the corresponding footnote number in the HTML output. Occurrences of “%%” will be replaced by a number for the reference (footnotes can have multiple references).
@@ -319,7 +319,7 @@ for _, path := range files {
markdown := goldmark.New(
goldmark.WithExtensions(
NewFootnote(
WithFootnoteIDPrefix([]byte(path)),
WithFootnoteIDPrefix(path),
),
),
)
@@ -379,7 +379,7 @@ This extension provides additional options for CJK users.
| Functional option | Type | Description |
| ----------------- | ---- | ----------- |
| `extension.WithEastAsianLineBreaks` | `...extension.EastAsianLineBreaksStyle` | Soft line breaks are rendered as a newline. Some asian users will see it as an unnecessary space. With this option, soft line breaks between east asian wide characters will be ignored. |
| `extension.WithEastAsianLineBreaks` | `...extension.EastAsianLineBreaksStyle` | Soft line breaks are rendered as a newline. Some asian users will see it as an unnecessary space. With this option, soft line breaks between east asian wide characters will be ignored. This defaults to `EastAsianLineBreaksStyleSimple`. |
| `extension.WithEscapedSpace` | `-` | Without spaces around an emphasis started with east asian punctuations, it is not interpreted as an emphasis(as defined in CommonMark spec). With this option, you can avoid this inconvenient behavior by putting 'not rendered' spaces around an emphasis like `太郎は\ **「こんにちわ」**\ といった`. |
#### Styles of Line Breaking
@@ -467,6 +467,7 @@ As you can see, goldmark's performance is on par with cmark's.
Extensions
--------------------
### List of extensions
- [goldmark-meta](https://github.com/yuin/goldmark-meta): A YAML metadata
extension for the goldmark Markdown parser.
@@ -490,6 +491,13 @@ Extensions
- [goldmark-d2](https://github.com/FurqanSoftware/goldmark-d2): Adds support for [D2](https://d2lang.com/) diagrams.
- [goldmark-katex](https://github.com/FurqanSoftware/goldmark-katex): Adds support for [KaTeX](https://katex.org/) math and equations.
- [goldmark-img64](https://github.com/tenkoh/goldmark-img64): Adds support for embedding images into the document as DataURL (base64 encoded).
- [goldmark-enclave](https://github.com/quail-ink/goldmark-enclave): Adds support for embedding youtube/bilibili video, X's [oembed tweet](https://publish.twitter.com/), [tradingview](https://www.tradingview.com/widget/)'s chart, [quail](https://quail.ink)'s widget into the document.
- [goldmark-wiki-table](https://github.com/movsb/goldmark-wiki-table): Adds support for embedding Wiki Tables.
### Loading extensions at runtime
[goldmark-dynamic](https://github.com/yuin/goldmark-dynamic) allows you to write a goldmark extension in Lua and load it at runtime without re-compilation.
Please refer to [goldmark-dynamic](https://github.com/yuin/goldmark-dynamic) for details.
goldmark internal(for extension developers)

View File

@@ -382,8 +382,8 @@ func (o *withFootnoteIDPrefix) SetFootnoteOption(c *FootnoteConfig) {
}
// WithFootnoteIDPrefix is a functional option that is a prefix for the id attributes generated by footnotes.
func WithFootnoteIDPrefix(a []byte) FootnoteOption {
return &withFootnoteIDPrefix{a}
func WithFootnoteIDPrefix[T []byte | string](a T) FootnoteOption {
return &withFootnoteIDPrefix{[]byte(a)}
}
const optFootnoteIDPrefixFunction renderer.OptionName = "FootnoteIDPrefixFunction"
@@ -420,8 +420,8 @@ func (o *withFootnoteLinkTitle) SetFootnoteOption(c *FootnoteConfig) {
}
// WithFootnoteLinkTitle is a functional option that is an optional title attribute for footnote links.
func WithFootnoteLinkTitle(a []byte) FootnoteOption {
return &withFootnoteLinkTitle{a}
func WithFootnoteLinkTitle[T []byte | string](a T) FootnoteOption {
return &withFootnoteLinkTitle{[]byte(a)}
}
const optFootnoteBacklinkTitle renderer.OptionName = "FootnoteBacklinkTitle"
@@ -439,8 +439,8 @@ func (o *withFootnoteBacklinkTitle) SetFootnoteOption(c *FootnoteConfig) {
}
// WithFootnoteBacklinkTitle is a functional option that is an optional title attribute for footnote backlinks.
func WithFootnoteBacklinkTitle(a []byte) FootnoteOption {
return &withFootnoteBacklinkTitle{a}
func WithFootnoteBacklinkTitle[T []byte | string](a T) FootnoteOption {
return &withFootnoteBacklinkTitle{[]byte(a)}
}
const optFootnoteLinkClass renderer.OptionName = "FootnoteLinkClass"
@@ -458,8 +458,8 @@ func (o *withFootnoteLinkClass) SetFootnoteOption(c *FootnoteConfig) {
}
// WithFootnoteLinkClass is a functional option that is a class for footnote links.
func WithFootnoteLinkClass(a []byte) FootnoteOption {
return &withFootnoteLinkClass{a}
func WithFootnoteLinkClass[T []byte | string](a T) FootnoteOption {
return &withFootnoteLinkClass{[]byte(a)}
}
const optFootnoteBacklinkClass renderer.OptionName = "FootnoteBacklinkClass"
@@ -477,8 +477,8 @@ func (o *withFootnoteBacklinkClass) SetFootnoteOption(c *FootnoteConfig) {
}
// WithFootnoteBacklinkClass is a functional option that is a class for footnote backlinks.
func WithFootnoteBacklinkClass(a []byte) FootnoteOption {
return &withFootnoteBacklinkClass{a}
func WithFootnoteBacklinkClass[T []byte | string](a T) FootnoteOption {
return &withFootnoteBacklinkClass{[]byte(a)}
}
const optFootnoteBacklinkHTML renderer.OptionName = "FootnoteBacklinkHTML"
@@ -496,8 +496,8 @@ func (o *withFootnoteBacklinkHTML) SetFootnoteOption(c *FootnoteConfig) {
}
// WithFootnoteBacklinkHTML is an HTML content for footnote backlinks.
func WithFootnoteBacklinkHTML(a []byte) FootnoteOption {
return &withFootnoteBacklinkHTML{a}
func WithFootnoteBacklinkHTML[T []byte | string](a T) FootnoteOption {
return &withFootnoteBacklinkHTML{[]byte(a)}
}
// FootnoteHTMLRenderer is a renderer.NodeRenderer implementation that

View File

@@ -66,10 +66,12 @@ func (o *withLinkifyAllowedProtocols) SetLinkifyOption(p *LinkifyConfig) {
// WithLinkifyAllowedProtocols is a functional option that specify allowed
// protocols in autolinks. Each protocol must end with ':' like
// 'http:' .
func WithLinkifyAllowedProtocols(value [][]byte) LinkifyOption {
return &withLinkifyAllowedProtocols{
value: value,
func WithLinkifyAllowedProtocols[T []byte | string](value []T) LinkifyOption {
opt := &withLinkifyAllowedProtocols{}
for _, v := range value {
opt.value = append(opt.value, []byte(v))
}
return opt
}
type withLinkifyURLRegexp struct {

View File

@@ -115,10 +115,10 @@ func (o *withTypographicSubstitutions) SetTypographerOption(p *TypographerConfig
// WithTypographicSubstitutions is a functional otpion that specify replacement text
// for punctuations.
func WithTypographicSubstitutions(values map[TypographicPunctuation][]byte) TypographerOption {
func WithTypographicSubstitutions[T []byte | string](values map[TypographicPunctuation]T) TypographerOption {
replacements := newDefaultSubstitutions()
for k, v := range values {
replacements[k] = v
replacements[k] = []byte(v)
}
return &withTypographicSubstitutions{replacements}

View File

@@ -61,8 +61,8 @@ var allowedBlockTags = map[string]bool{
"option": true,
"p": true,
"param": true,
"search": true,
"section": true,
"source": true,
"summary": true,
"table": true,
"tbody": true,

View File

@@ -58,47 +58,38 @@ var closeProcessingInstruction = []byte("?>")
var openCDATA = []byte("<![CDATA[")
var closeCDATA = []byte("]]>")
var closeDecl = []byte(">")
var emptyComment = []byte("<!---->")
var invalidComment1 = []byte("<!-->")
var invalidComment2 = []byte("<!--->")
var emptyComment1 = []byte("<!-->")
var emptyComment2 = []byte("<!--->")
var openComment = []byte("<!--")
var closeComment = []byte("-->")
var doubleHyphen = []byte("--")
func (s *rawHTMLParser) parseComment(block text.Reader, pc Context) ast.Node {
savedLine, savedSegment := block.Position()
node := ast.NewRawHTML()
line, segment := block.PeekLine()
if bytes.HasPrefix(line, emptyComment) {
node.Segments.Append(segment.WithStop(segment.Start + len(emptyComment)))
block.Advance(len(emptyComment))
if bytes.HasPrefix(line, emptyComment1) {
node.Segments.Append(segment.WithStop(segment.Start + len(emptyComment1)))
block.Advance(len(emptyComment1))
return node
}
if bytes.HasPrefix(line, invalidComment1) || bytes.HasPrefix(line, invalidComment2) {
return nil
if bytes.HasPrefix(line, emptyComment2) {
node.Segments.Append(segment.WithStop(segment.Start + len(emptyComment2)))
block.Advance(len(emptyComment2))
return node
}
offset := len(openComment)
line = line[offset:]
for {
hindex := bytes.Index(line, doubleHyphen)
if hindex > -1 {
hindex += offset
}
index := bytes.Index(line, closeComment) + offset
if index > -1 && hindex == index {
if index == 0 || len(line) < 2 || line[index-offset-1] != '-' {
node.Segments.Append(segment.WithStop(segment.Start + index + len(closeComment)))
block.Advance(index + len(closeComment))
return node
}
}
if hindex > 0 {
break
index := bytes.Index(line, closeComment)
if index > -1 {
node.Segments.Append(segment.WithStop(segment.Start + offset + index + len(closeComment)))
block.Advance(offset + index + len(closeComment))
return node
}
offset = 0
node.Segments.Append(segment)
block.AdvanceLine()
line, segment = block.PeekLine()
offset = 0
if line == nil {
break
}

View File

@@ -808,7 +808,7 @@ func IsPunct(c byte) bool {
// IsPunctRune returns true if the given rune is a punctuation, otherwise false.
func IsPunctRune(r rune) bool {
return int32(r) <= 256 && IsPunct(byte(r)) || unicode.IsPunct(r)
return unicode.IsSymbol(r) || unicode.IsPunct(r)
}
// IsSpace returns true if the given character is a space, otherwise false.

View File

@@ -1,661 +1,166 @@
GNU AFFERO GENERAL PUBLIC LICENSE
Version 3, 19 November 2007
GNU LESSER GENERAL PUBLIC LICENSE
Version 3, 29 June 2007
Copyright (C) 2007 Free Software Foundation, Inc. <https://fsf.org/>
Everyone is permitted to copy and distribute verbatim copies
of this license document, but changing it is not allowed.
Preamble
The GNU Affero General Public License is a free, copyleft license for
software and other kinds of works, specifically designed to ensure
cooperation with the community in the case of network server software.
The licenses for most software and other practical works are designed
to take away your freedom to share and change the works. By contrast,
our General Public Licenses are intended to guarantee your freedom to
share and change all versions of a program--to make sure it remains free
software for all its users.
When we speak of free software, we are referring to freedom, not
price. Our General Public Licenses are designed to make sure that you
have the freedom to distribute copies of free software (and charge for
them if you wish), that you receive source code or can get it if you
want it, that you can change the software or use pieces of it in new
free programs, and that you know you can do these things.
Developers that use our General Public Licenses protect your rights
with two steps: (1) assert copyright on the software, and (2) offer
you this License which gives you legal permission to copy, distribute
and/or modify the software.
A secondary benefit of defending all users' freedom is that
improvements made in alternate versions of the program, if they
receive widespread use, become available for other developers to
incorporate. Many developers of free software are heartened and
encouraged by the resulting cooperation. However, in the case of
software used on network servers, this result may fail to come about.
The GNU General Public License permits making a modified version and
letting the public access it on a server without ever releasing its
source code to the public.
The GNU Affero General Public License is designed specifically to
ensure that, in such cases, the modified source code becomes available
to the community. It requires the operator of a network server to
provide the source code of the modified version running there to the
users of that server. Therefore, public use of a modified version, on
a publicly accessible server, gives the public access to the source
code of the modified version.
An older license, called the Affero General Public License and
published by Affero, was designed to accomplish similar goals. This is
a different license, not a version of the Affero GPL, but Affero has
released a new version of the Affero GPL which permits relicensing under
this license.
The precise terms and conditions for copying, distribution and
modification follow.
TERMS AND CONDITIONS
0. Definitions.
"This License" refers to version 3 of the GNU Affero General Public License.
"Copyright" also means copyright-like laws that apply to other kinds of
works, such as semiconductor masks.
"The Program" refers to any copyrightable work licensed under this
License. Each licensee is addressed as "you". "Licensees" and
"recipients" may be individuals or organizations.
To "modify" a work means to copy from or adapt all or part of the work
in a fashion requiring copyright permission, other than the making of an
exact copy. The resulting work is called a "modified version" of the
earlier work or a work "based on" the earlier work.
A "covered work" means either the unmodified Program or a work based
on the Program.
To "propagate" a work means to do anything with it that, without
permission, would make you directly or secondarily liable for
infringement under applicable copyright law, except executing it on a
computer or modifying a private copy. Propagation includes copying,
distribution (with or without modification), making available to the
public, and in some countries other activities as well.
To "convey" a work means any kind of propagation that enables other
parties to make or receive copies. Mere interaction with a user through
a computer network, with no transfer of a copy, is not conveying.
An interactive user interface displays "Appropriate Legal Notices"
to the extent that it includes a convenient and prominently visible
feature that (1) displays an appropriate copyright notice, and (2)
tells the user that there is no warranty for the work (except to the
extent that warranties are provided), that licensees may convey the
work under this License, and how to view a copy of this License. If
the interface presents a list of user commands or options, such as a
menu, a prominent item in the list meets this criterion.
1. Source Code.
The "source code" for a work means the preferred form of the work
for making modifications to it. "Object code" means any non-source
form of a work.
A "Standard Interface" means an interface that either is an official
standard defined by a recognized standards body, or, in the case of
interfaces specified for a particular programming language, one that
is widely used among developers working in that language.
The "System Libraries" of an executable work include anything, other
than the work as a whole, that (a) is included in the normal form of
packaging a Major Component, but which is not part of that Major
Component, and (b) serves only to enable use of the work with that
Major Component, or to implement a Standard Interface for which an
implementation is available to the public in source code form. A
"Major Component", in this context, means a major essential component
(kernel, window system, and so on) of the specific operating system
(if any) on which the executable work runs, or a compiler used to
produce the work, or an object code interpreter used to run it.
The "Corresponding Source" for a work in object code form means all
the source code needed to generate, install, and (for an executable
work) run the object code and to modify the work, including scripts to
control those activities. However, it does not include the work's
System Libraries, or general-purpose tools or generally available free
programs which are used unmodified in performing those activities but
which are not part of the work. For example, Corresponding Source
includes interface definition files associated with source files for
the work, and the source code for shared libraries and dynamically
linked subprograms that the work is specifically designed to require,
such as by intimate data communication or control flow between those
subprograms and other parts of the work.
The Corresponding Source need not include anything that users
can regenerate automatically from other parts of the Corresponding
Source.
The Corresponding Source for a work in source code form is that
same work.
2. Basic Permissions.
All rights granted under this License are granted for the term of
copyright on the Program, and are irrevocable provided the stated
conditions are met. This License explicitly affirms your unlimited
permission to run the unmodified Program. The output from running a
covered work is covered by this License only if the output, given its
content, constitutes a covered work. This License acknowledges your
rights of fair use or other equivalent, as provided by copyright law.
You may make, run and propagate covered works that you do not
convey, without conditions so long as your license otherwise remains
in force. You may convey covered works to others for the sole purpose
of having them make modifications exclusively for you, or provide you
with facilities for running those works, provided that you comply with
the terms of this License in conveying all material for which you do
not control copyright. Those thus making or running the covered works
for you must do so exclusively on your behalf, under your direction
and control, on terms that prohibit them from making any copies of
your copyrighted material outside their relationship with you.
Conveying under any other circumstances is permitted solely under
the conditions stated below. Sublicensing is not allowed; section 10
makes it unnecessary.
3. Protecting Users' Legal Rights From Anti-Circumvention Law.
No covered work shall be deemed part of an effective technological
measure under any applicable law fulfilling obligations under article
11 of the WIPO copyright treaty adopted on 20 December 1996, or
similar laws prohibiting or restricting circumvention of such
measures.
When you convey a covered work, you waive any legal power to forbid
circumvention of technological measures to the extent such circumvention
is effected by exercising rights under this License with respect to
the covered work, and you disclaim any intention to limit operation or
modification of the work as a means of enforcing, against the work's
users, your or third parties' legal rights to forbid circumvention of
technological measures.
4. Conveying Verbatim Copies.
You may convey verbatim copies of the Program's source code as you
receive it, in any medium, provided that you conspicuously and
appropriately publish on each copy an appropriate copyright notice;
keep intact all notices stating that this License and any
non-permissive terms added in accord with section 7 apply to the code;
keep intact all notices of the absence of any warranty; and give all
recipients a copy of this License along with the Program.
You may charge any price or no price for each copy that you convey,
and you may offer support or warranty protection for a fee.
5. Conveying Modified Source Versions.
You may convey a work based on the Program, or the modifications to
produce it from the Program, in the form of source code under the
terms of section 4, provided that you also meet all of these conditions:
a) The work must carry prominent notices stating that you modified
it, and giving a relevant date.
b) The work must carry prominent notices stating that it is
released under this License and any conditions added under section
7. This requirement modifies the requirement in section 4 to
"keep intact all notices".
c) You must license the entire work, as a whole, under this
License to anyone who comes into possession of a copy. This
License will therefore apply, along with any applicable section 7
additional terms, to the whole of the work, and all its parts,
regardless of how they are packaged. This License gives no
permission to license the work in any other way, but it does not
invalidate such permission if you have separately received it.
d) If the work has interactive user interfaces, each must display
Appropriate Legal Notices; however, if the Program has interactive
interfaces that do not display Appropriate Legal Notices, your
work need not make them do so.
A compilation of a covered work with other separate and independent
works, which are not by their nature extensions of the covered work,
and which are not combined with it such as to form a larger program,
in or on a volume of a storage or distribution medium, is called an
"aggregate" if the compilation and its resulting copyright are not
used to limit the access or legal rights of the compilation's users
beyond what the individual works permit. Inclusion of a covered work
in an aggregate does not cause this License to apply to the other
parts of the aggregate.
6. Conveying Non-Source Forms.
You may convey a covered work in object code form under the terms
of sections 4 and 5, provided that you also convey the
machine-readable Corresponding Source under the terms of this License,
in one of these ways:
a) Convey the object code in, or embodied in, a physical product
(including a physical distribution medium), accompanied by the
Corresponding Source fixed on a durable physical medium
customarily used for software interchange.
b) Convey the object code in, or embodied in, a physical product
(including a physical distribution medium), accompanied by a
written offer, valid for at least three years and valid for as
long as you offer spare parts or customer support for that product
model, to give anyone who possesses the object code either (1) a
copy of the Corresponding Source for all the software in the
product that is covered by this License, on a durable physical
medium customarily used for software interchange, for a price no
more than your reasonable cost of physically performing this
conveying of source, or (2) access to copy the
Corresponding Source from a network server at no charge.
c) Convey individual copies of the object code with a copy of the
written offer to provide the Corresponding Source. This
alternative is allowed only occasionally and noncommercially, and
only if you received the object code with such an offer, in accord
with subsection 6b.
d) Convey the object code by offering access from a designated
place (gratis or for a charge), and offer equivalent access to the
Corresponding Source in the same way through the same place at no
further charge. You need not require recipients to copy the
Corresponding Source along with the object code. If the place to
copy the object code is a network server, the Corresponding Source
may be on a different server (operated by you or a third party)
that supports equivalent copying facilities, provided you maintain
clear directions next to the object code saying where to find the
Corresponding Source. Regardless of what server hosts the
Corresponding Source, you remain obligated to ensure that it is
available for as long as needed to satisfy these requirements.
e) Convey the object code using peer-to-peer transmission, provided
you inform other peers where the object code and Corresponding
Source of the work are being offered to the general public at no
charge under subsection 6d.
A separable portion of the object code, whose source code is excluded
from the Corresponding Source as a System Library, need not be
included in conveying the object code work.
A "User Product" is either (1) a "consumer product", which means any
tangible personal property which is normally used for personal, family,
or household purposes, or (2) anything designed or sold for incorporation
into a dwelling. In determining whether a product is a consumer product,
doubtful cases shall be resolved in favor of coverage. For a particular
product received by a particular user, "normally used" refers to a
typical or common use of that class of product, regardless of the status
of the particular user or of the way in which the particular user
actually uses, or expects or is expected to use, the product. A product
is a consumer product regardless of whether the product has substantial
commercial, industrial or non-consumer uses, unless such uses represent
the only significant mode of use of the product.
"Installation Information" for a User Product means any methods,
procedures, authorization keys, or other information required to install
and execute modified versions of a covered work in that User Product from
a modified version of its Corresponding Source. The information must
suffice to ensure that the continued functioning of the modified object
code is in no case prevented or interfered with solely because
modification has been made.
If you convey an object code work under this section in, or with, or
specifically for use in, a User Product, and the conveying occurs as
part of a transaction in which the right of possession and use of the
User Product is transferred to the recipient in perpetuity or for a
fixed term (regardless of how the transaction is characterized), the
Corresponding Source conveyed under this section must be accompanied
by the Installation Information. But this requirement does not apply
if neither you nor any third party retains the ability to install
modified object code on the User Product (for example, the work has
been installed in ROM).
The requirement to provide Installation Information does not include a
requirement to continue to provide support service, warranty, or updates
for a work that has been modified or installed by the recipient, or for
the User Product in which it has been modified or installed. Access to a
network may be denied when the modification itself materially and
adversely affects the operation of the network or violates the rules and
protocols for communication across the network.
Corresponding Source conveyed, and Installation Information provided,
in accord with this section must be in a format that is publicly
documented (and with an implementation available to the public in
source code form), and must require no special password or key for
unpacking, reading or copying.
7. Additional Terms.
"Additional permissions" are terms that supplement the terms of this
License by making exceptions from one or more of its conditions.
Additional permissions that are applicable to the entire Program shall
be treated as though they were included in this License, to the extent
that they are valid under applicable law. If additional permissions
apply only to part of the Program, that part may be used separately
under those permissions, but the entire Program remains governed by
this License without regard to the additional permissions.
When you convey a copy of a covered work, you may at your option
remove any additional permissions from that copy, or from any part of
it. (Additional permissions may be written to require their own
removal in certain cases when you modify the work.) You may place
additional permissions on material, added by you to a covered work,
for which you have or can give appropriate copyright permission.
Notwithstanding any other provision of this License, for material you
add to a covered work, you may (if authorized by the copyright holders of
that material) supplement the terms of this License with terms:
a) Disclaiming warranty or limiting liability differently from the
terms of sections 15 and 16 of this License; or
b) Requiring preservation of specified reasonable legal notices or
author attributions in that material or in the Appropriate Legal
Notices displayed by works containing it; or
c) Prohibiting misrepresentation of the origin of that material, or
requiring that modified versions of such material be marked in
reasonable ways as different from the original version; or
d) Limiting the use for publicity purposes of names of licensors or
authors of the material; or
e) Declining to grant rights under trademark law for use of some
trade names, trademarks, or service marks; or
f) Requiring indemnification of licensors and authors of that
material by anyone who conveys the material (or modified versions of
it) with contractual assumptions of liability to the recipient, for
any liability that these contractual assumptions directly impose on
those licensors and authors.
All other non-permissive additional terms are considered "further
restrictions" within the meaning of section 10. If the Program as you
received it, or any part of it, contains a notice stating that it is
governed by this License along with a term that is a further
restriction, you may remove that term. If a license document contains
a further restriction but permits relicensing or conveying under this
License, you may add to a covered work material governed by the terms
of that license document, provided that the further restriction does
not survive such relicensing or conveying.
If you add terms to a covered work in accord with this section, you
must place, in the relevant source files, a statement of the
additional terms that apply to those files, or a notice indicating
where to find the applicable terms.
Additional terms, permissive or non-permissive, may be stated in the
form of a separately written license, or stated as exceptions;
the above requirements apply either way.
8. Termination.
You may not propagate or modify a covered work except as expressly
provided under this License. Any attempt otherwise to propagate or
modify it is void, and will automatically terminate your rights under
this License (including any patent licenses granted under the third
paragraph of section 11).
However, if you cease all violation of this License, then your
license from a particular copyright holder is reinstated (a)
provisionally, unless and until the copyright holder explicitly and
finally terminates your license, and (b) permanently, if the copyright
holder fails to notify you of the violation by some reasonable means
prior to 60 days after the cessation.
Moreover, your license from a particular copyright holder is
reinstated permanently if the copyright holder notifies you of the
violation by some reasonable means, this is the first time you have
received notice of violation of this License (for any work) from that
copyright holder, and you cure the violation prior to 30 days after
your receipt of the notice.
Termination of your rights under this section does not terminate the
licenses of parties who have received copies or rights from you under
this License. If your rights have been terminated and not permanently
reinstated, you do not qualify to receive new licenses for the same
material under section 10.
9. Acceptance Not Required for Having Copies.
You are not required to accept this License in order to receive or
run a copy of the Program. Ancillary propagation of a covered work
occurring solely as a consequence of using peer-to-peer transmission
to receive a copy likewise does not require acceptance. However,
nothing other than this License grants you permission to propagate or
modify any covered work. These actions infringe copyright if you do
not accept this License. Therefore, by modifying or propagating a
covered work, you indicate your acceptance of this License to do so.
10. Automatic Licensing of Downstream Recipients.
Each time you convey a covered work, the recipient automatically
receives a license from the original licensors, to run, modify and
propagate that work, subject to this License. You are not responsible
for enforcing compliance by third parties with this License.
An "entity transaction" is a transaction transferring control of an
organization, or substantially all assets of one, or subdividing an
organization, or merging organizations. If propagation of a covered
work results from an entity transaction, each party to that
transaction who receives a copy of the work also receives whatever
licenses to the work the party's predecessor in interest had or could
give under the previous paragraph, plus a right to possession of the
Corresponding Source of the work from the predecessor in interest, if
the predecessor has it or can get it with reasonable efforts.
You may not impose any further restrictions on the exercise of the
rights granted or affirmed under this License. For example, you may
not impose a license fee, royalty, or other charge for exercise of
rights granted under this License, and you may not initiate litigation
(including a cross-claim or counterclaim in a lawsuit) alleging that
any patent claim is infringed by making, using, selling, offering for
sale, or importing the Program or any portion of it.
11. Patents.
A "contributor" is a copyright holder who authorizes use under this
License of the Program or a work on which the Program is based. The
work thus licensed is called the contributor's "contributor version".
A contributor's "essential patent claims" are all patent claims
owned or controlled by the contributor, whether already acquired or
hereafter acquired, that would be infringed by some manner, permitted
by this License, of making, using, or selling its contributor version,
but do not include claims that would be infringed only as a
consequence of further modification of the contributor version. For
purposes of this definition, "control" includes the right to grant
patent sublicenses in a manner consistent with the requirements of
this License.
Each contributor grants you a non-exclusive, worldwide, royalty-free
patent license under the contributor's essential patent claims, to
make, use, sell, offer for sale, import and otherwise run, modify and
propagate the contents of its contributor version.
In the following three paragraphs, a "patent license" is any express
agreement or commitment, however denominated, not to enforce a patent
(such as an express permission to practice a patent or covenant not to
sue for patent infringement). To "grant" such a patent license to a
party means to make such an agreement or commitment not to enforce a
patent against the party.
If you convey a covered work, knowingly relying on a patent license,
and the Corresponding Source of the work is not available for anyone
to copy, free of charge and under the terms of this License, through a
publicly available network server or other readily accessible means,
then you must either (1) cause the Corresponding Source to be so
available, or (2) arrange to deprive yourself of the benefit of the
patent license for this particular work, or (3) arrange, in a manner
consistent with the requirements of this License, to extend the patent
license to downstream recipients. "Knowingly relying" means you have
actual knowledge that, but for the patent license, your conveying the
covered work in a country, or your recipient's use of the covered work
in a country, would infringe one or more identifiable patents in that
country that you have reason to believe are valid.
If, pursuant to or in connection with a single transaction or
arrangement, you convey, or propagate by procuring conveyance of, a
covered work, and grant a patent license to some of the parties
receiving the covered work authorizing them to use, propagate, modify
or convey a specific copy of the covered work, then the patent license
you grant is automatically extended to all recipients of the covered
work and works based on it.
A patent license is "discriminatory" if it does not include within
the scope of its coverage, prohibits the exercise of, or is
conditioned on the non-exercise of one or more of the rights that are
specifically granted under this License. You may not convey a covered
work if you are a party to an arrangement with a third party that is
in the business of distributing software, under which you make payment
to the third party based on the extent of your activity of conveying
the work, and under which the third party grants, to any of the
parties who would receive the covered work from you, a discriminatory
patent license (a) in connection with copies of the covered work
conveyed by you (or copies made from those copies), or (b) primarily
for and in connection with specific products or compilations that
contain the covered work, unless you entered into that arrangement,
or that patent license was granted, prior to 28 March 2007.
Nothing in this License shall be construed as excluding or limiting
any implied license or other defenses to infringement that may
otherwise be available to you under applicable patent law.
12. No Surrender of Others' Freedom.
If conditions are imposed on you (whether by court order, agreement or
otherwise) that contradict the conditions of this License, they do not
excuse you from the conditions of this License. If you cannot convey a
covered work so as to satisfy simultaneously your obligations under this
License and any other pertinent obligations, then as a consequence you may
not convey it at all. For example, if you agree to terms that obligate you
to collect a royalty for further conveying from those to whom you convey
the Program, the only way you could satisfy both those terms and this
License would be to refrain entirely from conveying the Program.
13. Remote Network Interaction; Use with the GNU General Public License.
Notwithstanding any other provision of this License, if you modify the
Program, your modified version must prominently offer all users
interacting with it remotely through a computer network (if your version
supports such interaction) an opportunity to receive the Corresponding
Source of your version by providing access to the Corresponding Source
from a network server at no charge, through some standard or customary
means of facilitating copying of software. This Corresponding Source
shall include the Corresponding Source for any work covered by version 3
of the GNU General Public License that is incorporated pursuant to the
following paragraph.
Notwithstanding any other provision of this License, you have
permission to link or combine any covered work with a work licensed
under version 3 of the GNU General Public License into a single
combined work, and to convey the resulting work. The terms of this
License will continue to apply to the part which is the covered work,
but the work with which it is combined will remain governed by version
3 of the GNU General Public License.
14. Revised Versions of this License.
The Free Software Foundation may publish revised and/or new versions of
the GNU Affero General Public License from time to time. Such new versions
will be similar in spirit to the present version, but may differ in detail to
address new problems or concerns.
Each version is given a distinguishing version number. If the
Program specifies that a certain numbered version of the GNU Affero General
Public License "or any later version" applies to it, you have the
option of following the terms and conditions either of that numbered
version or of any later version published by the Free Software
Foundation. If the Program does not specify a version number of the
GNU Affero General Public License, you may choose any version ever published
by the Free Software Foundation.
If the Program specifies that a proxy can decide which future
versions of the GNU Affero General Public License can be used, that proxy's
public statement of acceptance of a version permanently authorizes you
to choose that version for the Program.
Later license versions may give you additional or different
permissions. However, no additional obligations are imposed on any
author or copyright holder as a result of your choosing to follow a
later version.
15. Disclaimer of Warranty.
THERE IS NO WARRANTY FOR THE PROGRAM, TO THE EXTENT PERMITTED BY
APPLICABLE LAW. EXCEPT WHEN OTHERWISE STATED IN WRITING THE COPYRIGHT
HOLDERS AND/OR OTHER PARTIES PROVIDE THE PROGRAM "AS IS" WITHOUT WARRANTY
OF ANY KIND, EITHER EXPRESSED OR IMPLIED, INCLUDING, BUT NOT LIMITED TO,
THE IMPLIED WARRANTIES OF MERCHANTABILITY AND FITNESS FOR A PARTICULAR
PURPOSE. THE ENTIRE RISK AS TO THE QUALITY AND PERFORMANCE OF THE PROGRAM
IS WITH YOU. SHOULD THE PROGRAM PROVE DEFECTIVE, YOU ASSUME THE COST OF
ALL NECESSARY SERVICING, REPAIR OR CORRECTION.
16. Limitation of Liability.
IN NO EVENT UNLESS REQUIRED BY APPLICABLE LAW OR AGREED TO IN WRITING
WILL ANY COPYRIGHT HOLDER, OR ANY OTHER PARTY WHO MODIFIES AND/OR CONVEYS
THE PROGRAM AS PERMITTED ABOVE, BE LIABLE TO YOU FOR DAMAGES, INCLUDING ANY
GENERAL, SPECIAL, INCIDENTAL OR CONSEQUENTIAL DAMAGES ARISING OUT OF THE
USE OR INABILITY TO USE THE PROGRAM (INCLUDING BUT NOT LIMITED TO LOSS OF
DATA OR DATA BEING RENDERED INACCURATE OR LOSSES SUSTAINED BY YOU OR THIRD
PARTIES OR A FAILURE OF THE PROGRAM TO OPERATE WITH ANY OTHER PROGRAMS),
EVEN IF SUCH HOLDER OR OTHER PARTY HAS BEEN ADVISED OF THE POSSIBILITY OF
SUCH DAMAGES.
17. Interpretation of Sections 15 and 16.
If the disclaimer of warranty and limitation of liability provided
above cannot be given local legal effect according to their terms,
reviewing courts shall apply local law that most closely approximates
an absolute waiver of all civil liability in connection with the
Program, unless a warranty or assumption of liability accompanies a
copy of the Program in return for a fee.
END OF TERMS AND CONDITIONS
How to Apply These Terms to Your New Programs
If you develop a new program, and you want it to be of the greatest
possible use to the public, the best way to achieve this is to make it
free software which everyone can redistribute and change under these terms.
To do so, attach the following notices to the program. It is safest
to attach them to the start of each source file to most effectively
state the exclusion of warranty; and each file should have at least
the "copyright" line and a pointer to where the full notice is found.
env
Copyright (C) 2022 etke.cc / Go
This program is free software: you can redistribute it and/or modify
it under the terms of the GNU Affero General Public License as published
by the Free Software Foundation, either version 3 of the License, or
(at your option) any later version.
This program is distributed in the hope that it will be useful,
but WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
GNU Affero General Public License for more details.
You should have received a copy of the GNU Affero General Public License
along with this program. If not, see <https://www.gnu.org/licenses/>.
Also add information on how to contact you by electronic and paper mail.
If your software can interact with users remotely through a computer
network, you should also make sure that it provides a way for users to
get its source. For example, if your program is a web application, its
interface could display a "Source" link that leads users to an archive
of the code. There are many ways you could offer source, and different
solutions will be better for different programs; see section 13 for the
specific requirements.
You should also get your employer (if you work as a programmer) or school,
if any, to sign a "copyright disclaimer" for the program, if necessary.
For more information on this, and how to apply and follow the GNU AGPL, see
<https://www.gnu.org/licenses/>.
This version of the GNU Lesser General Public License incorporates
the terms and conditions of version 3 of the GNU General Public
License, supplemented by the additional permissions listed below.
0. Additional Definitions.
As used herein, "this License" refers to version 3 of the GNU Lesser
General Public License, and the "GNU GPL" refers to version 3 of the GNU
General Public License.
"The Library" refers to a covered work governed by this License,
other than an Application or a Combined Work as defined below.
An "Application" is any work that makes use of an interface provided
by the Library, but which is not otherwise based on the Library.
Defining a subclass of a class defined by the Library is deemed a mode
of using an interface provided by the Library.
A "Combined Work" is a work produced by combining or linking an
Application with the Library. The particular version of the Library
with which the Combined Work was made is also called the "Linked
Version".
The "Minimal Corresponding Source" for a Combined Work means the
Corresponding Source for the Combined Work, excluding any source code
for portions of the Combined Work that, considered in isolation, are
based on the Application, and not on the Linked Version.
The "Corresponding Application Code" for a Combined Work means the
object code and/or source code for the Application, including any data
and utility programs needed for reproducing the Combined Work from the
Application, but excluding the System Libraries of the Combined Work.
1. Exception to Section 3 of the GNU GPL.
You may convey a covered work under sections 3 and 4 of this License
without being bound by section 3 of the GNU GPL.
2. Conveying Modified Versions.
If you modify a copy of the Library, and, in your modifications, a
facility refers to a function or data to be supplied by an Application
that uses the facility (other than as an argument passed when the
facility is invoked), then you may convey a copy of the modified
version:
a) under this License, provided that you make a good faith effort to
ensure that, in the event an Application does not supply the
function or data, the facility still operates, and performs
whatever part of its purpose remains meaningful, or
b) under the GNU GPL, with none of the additional permissions of
this License applicable to that copy.
3. Object Code Incorporating Material from Library Header Files.
The object code form of an Application may incorporate material from
a header file that is part of the Library. You may convey such object
code under terms of your choice, provided that, if the incorporated
material is not limited to numerical parameters, data structure
layouts and accessors, or small macros, inline functions and templates
(ten or fewer lines in length), you do both of the following:
a) Give prominent notice with each copy of the object code that the
Library is used in it and that the Library and its use are
covered by this License.
b) Accompany the object code with a copy of the GNU GPL and this license
document.
4. Combined Works.
You may convey a Combined Work under terms of your choice that,
taken together, effectively do not restrict modification of the
portions of the Library contained in the Combined Work and reverse
engineering for debugging such modifications, if you also do each of
the following:
a) Give prominent notice with each copy of the Combined Work that
the Library is used in it and that the Library and its use are
covered by this License.
b) Accompany the Combined Work with a copy of the GNU GPL and this license
document.
c) For a Combined Work that displays copyright notices during
execution, include the copyright notice for the Library among
these notices, as well as a reference directing the user to the
copies of the GNU GPL and this license document.
d) Do one of the following:
0) Convey the Minimal Corresponding Source under the terms of this
License, and the Corresponding Application Code in a form
suitable for, and under terms that permit, the user to
recombine or relink the Application with a modified version of
the Linked Version to produce a modified Combined Work, in the
manner specified by section 6 of the GNU GPL for conveying
Corresponding Source.
1) Use a suitable shared library mechanism for linking with the
Library. A suitable mechanism is one that (a) uses at run time
a copy of the Library already present on the user's computer
system, and (b) will operate properly with a modified version
of the Library that is interface-compatible with the Linked
Version.
e) Provide Installation Information, but only if you would otherwise
be required to provide such information under section 6 of the
GNU GPL, and only to the extent that such information is
necessary to install and execute a modified version of the
Combined Work produced by recombining or relinking the
Application with a modified version of the Linked Version. (If
you use option 4d0, the Installation Information must accompany
the Minimal Corresponding Source and Corresponding Application
Code. If you use option 4d1, you must provide the Installation
Information in the manner specified by section 6 of the GNU GPL
for conveying Corresponding Source.)
5. Combined Libraries.
You may place library facilities that are a work based on the
Library side by side in a single library together with other library
facilities that are not Applications and are not covered by this
License, and convey such a combined library under terms of your
choice, if you do both of the following:
a) Accompany the combined library with a copy of the same work based
on the Library, uncombined with any other library facilities,
conveyed under the terms of this License.
b) Give prominent notice with the combined library that part of it
is a work based on the Library, and explaining where to find the
accompanying uncombined form of the same work.
6. Revised Versions of the GNU Lesser General Public License.
The Free Software Foundation may publish revised and/or new versions
of the GNU Lesser General Public License from time to time. Such new
versions will be similar in spirit to the present version, but may
differ in detail to address new problems or concerns.
Each version is given a distinguishing version number. If the
Library as you received it specifies that a certain numbered version
of the GNU Lesser General Public License "or any later version"
applies to it, you have the option of following the terms and
conditions either of that published version or of any later version
published by the Free Software Foundation. If the Library as you
received it does not specify a version number of the GNU Lesser
General Public License, you may choose any version of the GNU Lesser
General Public License ever published by the Free Software Foundation.
If the Library as you received it specifies that a proxy can decide
whether future versions of the GNU Lesser General Public License shall
apply, that proxy's public statement of acceptance of any version is
permanent authorization for you to choose that version for the
Library.

View File

@@ -14,26 +14,36 @@ func SetPrefix(prefix string) {
}
// String returns string vars
func String(shortkey string, defaultValue string) string {
func String(shortkey string, defaultValue ...string) string {
var dv string
if len(defaultValue) > 0 {
dv = defaultValue[0]
}
key := strings.ToUpper(envprefix + "_" + strings.ReplaceAll(shortkey, ".", "_"))
value := strings.TrimSpace(os.Getenv(key))
if value == "" {
return defaultValue
return dv
}
return value
}
// Int returns int vars
func Int(shortkey string, defaultValue int) int {
str := String(shortkey, "")
func Int(shortkey string, defaultValue ...int) int {
var dv int
if len(defaultValue) > 0 {
dv = defaultValue[0]
}
str := String(shortkey)
if str == "" {
return defaultValue
return dv
}
val, err := strconv.Atoi(str)
if err != nil {
return defaultValue
return dv
}
return val
@@ -41,7 +51,7 @@ func Int(shortkey string, defaultValue int) int {
// Bool returns boolean vars (1, true, yes)
func Bool(shortkey string) bool {
str := strings.ToLower(String(shortkey, ""))
str := strings.ToLower(String(shortkey))
if str == "" {
return false
}
@@ -50,7 +60,7 @@ func Bool(shortkey string) bool {
// Slice returns slice from space-separated strings, eg: export VAR="one two three" => []string{"one", "two", "three"}
func Slice(shortkey string) []string {
str := String(shortkey, "")
str := String(shortkey)
if str == "" {
return nil
}

View File

@@ -1,13 +1,14 @@
package linkpearl
import (
"context"
"strings"
"maunium.net/go/mautrix/id"
)
// GetAccountData of the user (from cache and API, with encryption support)
func (l *Linkpearl) GetAccountData(name string) (map[string]string, error) {
func (l *Linkpearl) GetAccountData(ctx context.Context, name string) (map[string]string, error) {
cached, ok := l.acc.Get(name)
if ok {
if cached == nil {
@@ -17,7 +18,7 @@ func (l *Linkpearl) GetAccountData(name string) (map[string]string, error) {
}
var data map[string]string
err := l.GetClient().GetAccountData(name, &data)
err := l.GetClient().GetAccountData(ctx, name, &data)
if err != nil {
data = map[string]string{}
if strings.Contains(err.Error(), "M_NOT_FOUND") {
@@ -33,15 +34,15 @@ func (l *Linkpearl) GetAccountData(name string) (map[string]string, error) {
}
// SetAccountData of the user (to cache and API, with encryption support)
func (l *Linkpearl) SetAccountData(name string, data map[string]string) error {
func (l *Linkpearl) SetAccountData(ctx context.Context, name string, data map[string]string) error {
l.acc.Add(name, data)
data = l.encryptAccountData(data)
return UnwrapError(l.GetClient().SetAccountData(name, data))
return UnwrapError(l.GetClient().SetAccountData(ctx, name, data))
}
// GetRoomAccountData of the room (from cache and API, with encryption support)
func (l *Linkpearl) GetRoomAccountData(roomID id.RoomID, name string) (map[string]string, error) {
func (l *Linkpearl) GetRoomAccountData(ctx context.Context, roomID id.RoomID, name string) (map[string]string, error) {
key := roomID.String() + name
cached, ok := l.acc.Get(key)
if ok {
@@ -52,7 +53,7 @@ func (l *Linkpearl) GetRoomAccountData(roomID id.RoomID, name string) (map[strin
}
var data map[string]string
err := l.GetClient().GetRoomAccountData(roomID, name, &data)
err := l.GetClient().GetRoomAccountData(ctx, roomID, name, &data)
if err != nil {
data = map[string]string{}
if strings.Contains(err.Error(), "M_NOT_FOUND") {
@@ -68,12 +69,12 @@ func (l *Linkpearl) GetRoomAccountData(roomID id.RoomID, name string) (map[strin
}
// SetRoomAccountData of the room (to cache and API, with encryption support)
func (l *Linkpearl) SetRoomAccountData(roomID id.RoomID, name string, data map[string]string) error {
func (l *Linkpearl) SetRoomAccountData(ctx context.Context, roomID id.RoomID, name string, data map[string]string) error {
key := roomID.String() + name
l.acc.Add(key, data)
data = l.encryptAccountData(data)
return UnwrapError(l.GetClient().SetRoomAccountData(roomID, name, data))
return UnwrapError(l.GetClient().SetRoomAccountData(ctx, roomID, name, data))
}
func (l *Linkpearl) encryptAccountData(data map[string]string) map[string]string {

View File

@@ -1,6 +1,7 @@
package linkpearl
import (
"context"
"crypto/hmac"
"crypto/sha512"
"database/sql"
@@ -25,7 +26,7 @@ type Config struct {
// JoinPermit is a callback function that tells
// if linkpearl should respond to the given "invite" event
// and join the room
JoinPermit func(*event.Event) bool
JoinPermit func(context.Context, *event.Event) bool
// AutoLeave if true, linkpearl will automatically leave empty rooms
AutoLeave bool

View File

@@ -1,6 +1,7 @@
package linkpearl
import (
"context"
"strconv"
"maunium.net/go/mautrix"
@@ -15,7 +16,7 @@ type RespThreads struct {
}
// Threads endpoint, ref: https://spec.matrix.org/v1.8/client-server-api/#get_matrixclientv1roomsroomidthreads
func (l *Linkpearl) Threads(roomID id.RoomID, fromToken ...string) (*RespThreads, error) {
func (l *Linkpearl) Threads(ctx context.Context, roomID id.RoomID, fromToken ...string) (*RespThreads, error) {
var from string
if len(fromToken) > 0 {
from = fromToken[0]
@@ -28,18 +29,18 @@ func (l *Linkpearl) Threads(roomID id.RoomID, fromToken ...string) (*RespThreads
var resp *RespThreads
urlPath := l.GetClient().BuildURLWithQuery(mautrix.ClientURLPath{"v1", "rooms", roomID, "threads"}, query)
_, err := l.GetClient().MakeRequest("GET", urlPath, nil, &resp)
_, err := l.GetClient().MakeRequest(ctx, "GET", urlPath, nil, &resp)
return resp, UnwrapError(err)
}
// FindThreadBy tries to find thread message event by field and value
func (l *Linkpearl) FindThreadBy(roomID id.RoomID, field, value string, fromToken ...string) *event.Event {
func (l *Linkpearl) FindThreadBy(ctx context.Context, roomID id.RoomID, field, value string, fromToken ...string) *event.Event {
var from string
if len(fromToken) > 0 {
from = fromToken[0]
}
resp, err := l.Threads(roomID, from)
resp, err := l.Threads(ctx, roomID, from)
err = UnwrapError(err)
if err != nil {
l.log.Warn().Err(err).Str("roomID", roomID.String()).Msg("cannot get room threads")
@@ -47,7 +48,7 @@ func (l *Linkpearl) FindThreadBy(roomID id.RoomID, field, value string, fromToke
}
for _, msg := range resp.Chunk {
evt, contains := l.eventContains(msg, field, value)
evt, contains := l.eventContains(ctx, msg, field, value)
if contains {
return evt
}
@@ -57,17 +58,17 @@ func (l *Linkpearl) FindThreadBy(roomID id.RoomID, field, value string, fromToke
return nil
}
return l.FindThreadBy(roomID, field, value, resp.NextBatch)
return l.FindThreadBy(ctx, roomID, field, value, resp.NextBatch)
}
// FindEventBy tries to find message event by field and value
func (l *Linkpearl) FindEventBy(roomID id.RoomID, field, value string, fromToken ...string) *event.Event {
func (l *Linkpearl) FindEventBy(ctx context.Context, roomID id.RoomID, field, value string, fromToken ...string) *event.Event {
var from string
if len(fromToken) > 0 {
from = fromToken[0]
}
resp, err := l.GetClient().Messages(roomID, from, "", mautrix.DirectionBackward, nil, l.eventsLimit)
resp, err := l.GetClient().Messages(ctx, roomID, from, "", mautrix.DirectionBackward, nil, l.eventsLimit)
err = UnwrapError(err)
if err != nil {
l.log.Warn().Err(err).Str("roomID", roomID.String()).Msg("cannot get room events")
@@ -75,7 +76,7 @@ func (l *Linkpearl) FindEventBy(roomID id.RoomID, field, value string, fromToken
}
for _, msg := range resp.Chunk {
evt, contains := l.eventContains(msg, field, value)
evt, contains := l.eventContains(ctx, msg, field, value)
if contains {
return evt
}
@@ -85,13 +86,13 @@ func (l *Linkpearl) FindEventBy(roomID id.RoomID, field, value string, fromToken
return nil
}
return l.FindEventBy(roomID, field, value, resp.End)
return l.FindEventBy(ctx, roomID, field, value, resp.End)
}
func (l *Linkpearl) eventContains(evt *event.Event, field, value string) (*event.Event, bool) {
func (l *Linkpearl) eventContains(ctx context.Context, evt *event.Event, field, value string) (*event.Event, bool) {
if evt.Type == event.EventEncrypted {
ParseContent(evt, &l.log)
decrypted, err := l.GetClient().Crypto.Decrypt(evt)
decrypted, err := l.GetClient().Crypto.Decrypt(ctx, evt)
if err == nil {
evt = decrypted
}

View File

@@ -2,6 +2,7 @@
package linkpearl
import (
"context"
"database/sql"
lru "github.com/hashicorp/golang-lru/v2"
@@ -31,7 +32,7 @@ type Linkpearl struct {
log zerolog.Logger
api *mautrix.Client
joinPermit func(*event.Event) bool
joinPermit func(ctx context.Context, evt *event.Event) bool
autoleave bool
maxretries int
eventsLimit int
@@ -54,7 +55,7 @@ func setDefaults(cfg *Config) {
}
if cfg.JoinPermit == nil {
// By default, we approve all join requests
cfg.JoinPermit = func(*event.Event) bool { return true }
cfg.JoinPermit = func(_ context.Context, _ *event.Event) bool { return true }
}
}
@@ -103,7 +104,7 @@ func New(cfg *Config) (*Linkpearl, error) {
return nil, err
}
lp.ch.LoginAs = cfg.LoginAs()
if err = lp.ch.Init(); err != nil {
if err = lp.ch.Init(context.Background()); err != nil {
return nil, err
}
lp.api.Crypto = lp.ch
@@ -131,16 +132,16 @@ func (l *Linkpearl) GetAccountDataCrypter() *Crypter {
}
// SetPresence (own). See https://spec.matrix.org/v1.3/client-server-api/#put_matrixclientv3presenceuseridstatus
func (l *Linkpearl) SetPresence(presence event.Presence, message string) error {
func (l *Linkpearl) SetPresence(ctx context.Context, presence event.Presence, message string) error {
req := ReqPresence{Presence: presence, StatusMsg: message}
u := l.GetClient().BuildClientURL("v3", "presence", l.GetClient().UserID, "status")
_, err := l.GetClient().MakeRequest("PUT", u, req, nil)
_, err := l.GetClient().MakeRequest(ctx, "PUT", u, req, nil)
return err
}
// SetJoinPermit sets the the join permit callback function
func (l *Linkpearl) SetJoinPermit(value func(*event.Event) bool) {
func (l *Linkpearl) SetJoinPermit(value func(context.Context, *event.Event) bool) {
l.joinPermit = value
}
@@ -152,7 +153,7 @@ func (l *Linkpearl) Start(optionalStatusMsg ...string) error {
statusMsg = optionalStatusMsg[0]
}
err := l.SetPresence(event.PresenceOnline, statusMsg)
err := l.SetPresence(context.Background(), event.PresenceOnline, statusMsg)
if err != nil {
l.log.Error().Err(err).Msg("cannot set presence")
}
@@ -165,7 +166,7 @@ func (l *Linkpearl) Start(optionalStatusMsg ...string) error {
// Stop the client
func (l *Linkpearl) Stop() {
l.log.Debug().Msg("stopping the client")
if err := l.api.SetPresence(event.PresenceOffline); err != nil {
if err := l.api.SetPresence(context.Background(), event.PresenceOffline); err != nil {
l.log.Error().Err(err).Msg("cannot set presence")
}
l.api.StopSync()

View File

@@ -1,6 +1,8 @@
package linkpearl
import (
"context"
"maunium.net/go/mautrix"
"maunium.net/go/mautrix/event"
"maunium.net/go/mautrix/format"
@@ -10,9 +12,9 @@ import (
// Send a message to the roomID and automatically try to encrypt it, if the destination room is encrypted
//
//nolint:unparam // it's public interface
func (l *Linkpearl) Send(roomID id.RoomID, content interface{}) (id.EventID, error) {
func (l *Linkpearl) Send(ctx context.Context, roomID id.RoomID, content interface{}) (id.EventID, error) {
l.log.Debug().Str("roomID", roomID.String()).Any("content", content).Msg("sending event")
resp, err := l.api.SendMessageEvent(roomID, event.EventMessage, content)
resp, err := l.api.SendMessageEvent(ctx, roomID, event.EventMessage, content)
if err != nil {
return "", UnwrapError(err)
}
@@ -20,7 +22,7 @@ func (l *Linkpearl) Send(roomID id.RoomID, content interface{}) (id.EventID, err
}
// SendNotice to a room with optional relations, markdown supported
func (l *Linkpearl) SendNotice(roomID id.RoomID, message string, relates ...*event.RelatesTo) {
func (l *Linkpearl) SendNotice(ctx context.Context, roomID id.RoomID, message string, relates ...*event.RelatesTo) {
var withRelatesTo bool
content := format.RenderMarkdown(message, true, true)
content.MsgType = event.MsgNotice
@@ -29,12 +31,12 @@ func (l *Linkpearl) SendNotice(roomID id.RoomID, message string, relates ...*eve
content.RelatesTo = relates[0]
}
_, err := l.Send(roomID, &content)
_, err := l.Send(ctx, roomID, &content)
if err != nil {
l.log.Error().Err(UnwrapError(err)).Str("roomID", roomID.String()).Str("retries", "1/2").Msg("cannot send a notice into the room")
if withRelatesTo {
content.RelatesTo = nil
_, err = l.Send(roomID, &content)
_, err = l.Send(ctx, roomID, &content)
if err != nil {
l.log.Error().Err(UnwrapError(err)).Str("roomID", roomID.String()).Str("retries", "2/2").Msg("cannot send a notice into the room even without relations")
}
@@ -43,13 +45,13 @@ func (l *Linkpearl) SendNotice(roomID id.RoomID, message string, relates ...*eve
}
// SendFile to a matrix room
func (l *Linkpearl) SendFile(roomID id.RoomID, req *mautrix.ReqUploadMedia, msgtype event.MessageType, relates ...*event.RelatesTo) error {
func (l *Linkpearl) SendFile(ctx context.Context, roomID id.RoomID, req *mautrix.ReqUploadMedia, msgtype event.MessageType, relates ...*event.RelatesTo) error {
var relation *event.RelatesTo
if len(relates) > 0 {
relation = relates[0]
}
resp, err := l.GetClient().UploadMedia(*req)
resp, err := l.GetClient().UploadMedia(ctx, *req)
if err != nil {
err = UnwrapError(err)
l.log.Error().Err(err).Str("file", req.FileName).Msg("cannot upload file")
@@ -62,13 +64,13 @@ func (l *Linkpearl) SendFile(roomID id.RoomID, req *mautrix.ReqUploadMedia, msgt
RelatesTo: relation,
}
_, err = l.Send(roomID, content)
_, err = l.Send(ctx, roomID, content)
err = UnwrapError(err)
if err != nil {
l.log.Error().Err(err).Str("roomID", roomID.String()).Str("retries", "1/2").Msg("cannot send file into the room")
if relation != nil {
content.RelatesTo = nil
_, err = l.Send(roomID, &content)
_, err = l.Send(ctx, roomID, &content)
err = UnwrapError(err)
if err != nil {
l.log.Error().Err(UnwrapError(err)).Str("roomID", roomID.String()).Str("retries", "2/2").Msg("cannot send file into the room even without relations")

View File

@@ -1,6 +1,7 @@
package linkpearl
import (
"context"
"strings"
"time"
@@ -28,54 +29,56 @@ func (l *Linkpearl) OnEvent(callback mautrix.EventHandler) {
func (l *Linkpearl) initSync() {
l.api.Syncer.(mautrix.ExtensibleSyncer).OnEventType(
event.StateEncryption,
func(source mautrix.EventSource, evt *event.Event) {
go l.onEncryption(source, evt)
func(ctx context.Context, evt *event.Event) {
go l.onEncryption(ctx, evt)
},
)
l.api.Syncer.(mautrix.ExtensibleSyncer).OnEventType(
event.StateMember,
func(source mautrix.EventSource, evt *event.Event) {
go l.onMembership(source, evt)
func(ctx context.Context, evt *event.Event) {
go l.onMembership(ctx, evt)
},
)
}
func (l *Linkpearl) onMembership(src mautrix.EventSource, evt *event.Event) {
l.ch.Machine().HandleMemberEvent(src, evt)
l.api.StateStore.SetMembership(evt.RoomID, id.UserID(evt.GetStateKey()), evt.Content.AsMember().Membership)
func (l *Linkpearl) onMembership(ctx context.Context, evt *event.Event) {
l.ch.Machine().HandleMemberEvent(ctx, evt)
if err := l.api.StateStore.SetMembership(ctx, evt.RoomID, id.UserID(evt.GetStateKey()), evt.Content.AsMember().Membership); err != nil {
l.log.Error().Err(err).Str("roomID", evt.RoomID.String()).Str("userID", evt.GetStateKey()).Msg("cannot set membership")
}
// potentially autoaccept invites
l.onInvite(evt)
l.onInvite(ctx, evt)
// autoleave empty rooms
l.onEmpty(evt)
l.onEmpty(ctx, evt)
}
func (l *Linkpearl) onInvite(evt *event.Event) {
func (l *Linkpearl) onInvite(ctx context.Context, evt *event.Event) {
userID := l.api.UserID.String()
invite := evt.Content.AsMember().Membership == event.MembershipInvite
if !invite || evt.GetStateKey() != userID {
return
}
if l.joinPermit(evt) {
l.tryJoin(evt.RoomID, 0)
if l.joinPermit(ctx, evt) {
l.tryJoin(ctx, evt.RoomID, 0)
return
}
l.tryLeave(evt.RoomID, 0)
l.tryLeave(ctx, evt.RoomID, 0)
}
// TODO: https://spec.matrix.org/v1.8/client-server-api/#post_matrixclientv3joinroomidoralias
// endpoint supports server_name param and tells "The servers to attempt to join the room through. One of the servers must be participating in the room.",
// meaning you can specify more than 1 server. It is not clear, what format should be used "example.com,example.org", or "example.com example.org", or whatever else.
// Moreover, it is not clear if the following values can be used together with that field: l.api.UserID.Homeserver() and evt.Sender.Homeserver()
func (l *Linkpearl) tryJoin(roomID id.RoomID, retry int) {
func (l *Linkpearl) tryJoin(ctx context.Context, roomID id.RoomID, retry int) {
if retry >= l.maxretries {
return
}
_, err := l.api.JoinRoom(roomID.String(), "", nil)
_, err := l.api.JoinRoom(ctx, roomID.String(), "", nil)
err = UnwrapError(err)
if err != nil {
l.log.Error().Err(err).Str("roomID", roomID.String()).Msg("cannot join room")
@@ -84,31 +87,31 @@ func (l *Linkpearl) tryJoin(roomID id.RoomID, retry int) {
}
time.Sleep(5 * time.Second)
l.log.Error().Err(err).Str("roomID", roomID.String()).Int("retry", retry+1).Msg("trying to join again")
l.tryJoin(roomID, retry+1)
l.tryJoin(ctx, roomID, retry+1)
}
}
func (l *Linkpearl) tryLeave(roomID id.RoomID, retry int) {
func (l *Linkpearl) tryLeave(ctx context.Context, roomID id.RoomID, retry int) {
if retry >= l.maxretries {
return
}
_, err := l.api.LeaveRoom(roomID)
_, err := l.api.LeaveRoom(ctx, roomID)
err = UnwrapError(err)
if err != nil {
l.log.Error().Err(err).Str("roomID", roomID.String()).Msg("cannot leave room")
time.Sleep(5 * time.Second)
l.log.Error().Err(err).Str("roomID", roomID.String()).Int("retry", retry+1).Msg("trying to leave again")
l.tryLeave(roomID, retry+1)
l.tryLeave(ctx, roomID, retry+1)
}
}
func (l *Linkpearl) onEmpty(evt *event.Event) {
func (l *Linkpearl) onEmpty(ctx context.Context, evt *event.Event) {
if !l.autoleave {
return
}
members, err := l.api.StateStore.GetRoomJoinedOrInvitedMembers(evt.RoomID)
members, err := l.api.StateStore.GetRoomJoinedOrInvitedMembers(ctx, evt.RoomID)
err = UnwrapError(err)
if err != nil {
l.log.Error().Err(err).Str("roomID", evt.RoomID.String()).Msg("cannot get joined or invited members")
@@ -119,9 +122,11 @@ func (l *Linkpearl) onEmpty(evt *event.Event) {
return
}
l.tryLeave(evt.RoomID, 0)
l.tryLeave(ctx, evt.RoomID, 0)
}
func (l *Linkpearl) onEncryption(_ mautrix.EventSource, evt *event.Event) {
l.api.StateStore.SetEncryptionEvent(evt.RoomID, evt.Content.AsEncryption())
func (l *Linkpearl) onEncryption(ctx context.Context, evt *event.Event) {
if err := l.api.StateStore.SetEncryptionEvent(ctx, evt.RoomID, evt.Content.AsEncryption()); err != nil {
l.log.Error().Err(err).Str("roomID", evt.RoomID.String()).Msg("cannot set encryption event")
}
}

View File

@@ -9,6 +9,9 @@ package dbutil
import (
"context"
"database/sql"
"errors"
"strconv"
"strings"
"time"
)
@@ -19,18 +22,61 @@ type LoggingExecable struct {
db *Database
}
func (le *LoggingExecable) ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error) {
type pqError interface {
Get(k byte) string
}
type PQErrorWithLine struct {
Underlying error
Line string
}
func (pqe *PQErrorWithLine) Error() string {
return pqe.Underlying.Error()
}
func (pqe *PQErrorWithLine) Unwrap() error {
return pqe.Underlying
}
func addErrorLine(query string, err error) error {
if err == nil {
return err
}
var pqe pqError
if !errors.As(err, &pqe) {
return err
}
pos, _ := strconv.Atoi(pqe.Get('P'))
pos--
if pos <= 0 {
return err
}
lines := strings.Split(query, "\n")
for _, line := range lines {
lineRunes := []rune(line)
if pos < len(lineRunes)+1 {
return &PQErrorWithLine{Underlying: err, Line: line}
}
pos -= len(lineRunes) + 1
}
return err
}
func (le *LoggingExecable) ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error) {
start := time.Now()
query = le.db.mutateQuery(query)
res, err := le.UnderlyingExecable.ExecContext(ctx, query, args...)
err = addErrorLine(query, err)
le.db.Log.QueryTiming(ctx, "Exec", query, args, -1, time.Since(start), err)
return res, err
}
func (le *LoggingExecable) QueryContext(ctx context.Context, query string, args ...interface{}) (Rows, error) {
func (le *LoggingExecable) QueryContext(ctx context.Context, query string, args ...any) (Rows, error) {
start := time.Now()
query = le.db.mutateQuery(query)
rows, err := le.UnderlyingExecable.QueryContext(ctx, query, args...)
err = addErrorLine(query, err)
le.db.Log.QueryTiming(ctx, "Query", query, args, -1, time.Since(start), err)
return &LoggingRows{
ctx: ctx,
@@ -42,7 +88,7 @@ func (le *LoggingExecable) QueryContext(ctx context.Context, query string, args
}, err
}
func (le *LoggingExecable) QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row {
func (le *LoggingExecable) QueryRowContext(ctx context.Context, query string, args ...any) *sql.Row {
start := time.Now()
query = le.db.mutateQuery(query)
row := le.UnderlyingExecable.QueryRowContext(ctx, query, args...)
@@ -50,18 +96,6 @@ func (le *LoggingExecable) QueryRowContext(ctx context.Context, query string, ar
return row
}
func (le *LoggingExecable) Exec(query string, args ...interface{}) (sql.Result, error) {
return le.ExecContext(context.Background(), query, args...)
}
func (le *LoggingExecable) Query(query string, args ...interface{}) (Rows, error) {
return le.QueryContext(context.Background(), query, args...)
}
func (le *LoggingExecable) QueryRow(query string, args ...interface{}) *sql.Row {
return le.QueryRowContext(context.Background(), query, args...)
}
// loggingDB is a wrapper for LoggingExecable that allows access to BeginTx.
//
// While LoggingExecable has a pointer to the database and could use BeginTx, it's not technically safe since
@@ -89,10 +123,6 @@ func (ld *loggingDB) BeginTx(ctx context.Context, opts *sql.TxOptions) (*Logging
}, nil
}
func (ld *loggingDB) Begin() (*LoggingTxn, error) {
return ld.BeginTx(context.Background(), nil)
}
type LoggingTxn struct {
LoggingExecable
UnderlyingTx *sql.Tx
@@ -129,7 +159,7 @@ type LoggingRows struct {
ctx context.Context
db *Database
query string
args []interface{}
args []any
rs Rows
start time.Time
nrows int

View File

@@ -58,7 +58,7 @@ type Rows interface {
}
type Scannable interface {
Scan(...interface{}) error
Scan(...any) error
}
// Expected implementations of Scannable
@@ -67,30 +67,16 @@ var (
_ Scannable = (Rows)(nil)
)
type UnderlyingContextExecable interface {
ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error)
QueryContext(ctx context.Context, query string, args ...interface{}) (*sql.Rows, error)
QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row
}
type ContextExecable interface {
ExecContext(ctx context.Context, query string, args ...interface{}) (sql.Result, error)
QueryContext(ctx context.Context, query string, args ...interface{}) (Rows, error)
QueryRowContext(ctx context.Context, query string, args ...interface{}) *sql.Row
}
type UnderlyingExecable interface {
UnderlyingContextExecable
Exec(query string, args ...interface{}) (sql.Result, error)
Query(query string, args ...interface{}) (*sql.Rows, error)
QueryRow(query string, args ...interface{}) *sql.Row
ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error)
QueryContext(ctx context.Context, query string, args ...any) (*sql.Rows, error)
QueryRowContext(ctx context.Context, query string, args ...any) *sql.Row
}
type Execable interface {
ContextExecable
Exec(query string, args ...interface{}) (sql.Result, error)
Query(query string, args ...interface{}) (Rows, error)
QueryRow(query string, args ...interface{}) *sql.Row
ExecContext(ctx context.Context, query string, args ...any) (sql.Result, error)
QueryContext(ctx context.Context, query string, args ...any) (Rows, error)
QueryRowContext(ctx context.Context, query string, args ...any) *sql.Row
}
type Transaction interface {
@@ -101,15 +87,15 @@ type Transaction interface {
// Expected implementations of Execable
var (
_ UnderlyingExecable = (*sql.Tx)(nil)
_ UnderlyingExecable = (*sql.DB)(nil)
_ Execable = (*LoggingExecable)(nil)
_ Transaction = (*LoggingTxn)(nil)
_ UnderlyingContextExecable = (*sql.Conn)(nil)
_ UnderlyingExecable = (*sql.Tx)(nil)
_ UnderlyingExecable = (*sql.DB)(nil)
_ UnderlyingExecable = (*sql.Conn)(nil)
_ Execable = (*LoggingExecable)(nil)
_ Transaction = (*LoggingTxn)(nil)
)
type Database struct {
loggingDB
LoggingDB loggingDB
RawDB *sql.DB
ReadOnlyDB *sql.DB
Owner string
@@ -139,7 +125,7 @@ func (db *Database) Child(versionTable string, upgradeTable UpgradeTable, log Da
}
return &Database{
RawDB: db.RawDB,
loggingDB: db.loggingDB,
LoggingDB: db.LoggingDB,
Owner: "",
VersionTable: versionTable,
UpgradeTable: upgradeTable,
@@ -164,8 +150,8 @@ func NewWithDB(db *sql.DB, rawDialect string) (*Database, error) {
IgnoreForeignTables: true,
VersionTable: "version",
}
wrappedDB.loggingDB.UnderlyingExecable = db
wrappedDB.loggingDB.db = wrappedDB
wrappedDB.LoggingDB.UnderlyingExecable = db
wrappedDB.LoggingDB.db = wrappedDB
return wrappedDB, nil
}
@@ -259,7 +245,7 @@ func NewFromConfig(owner string, cfg Config, logger DatabaseLogger) (*Database,
if roUri == "" {
uriParts := strings.Split(cfg.URI, "?")
var qs url.Values
qs := url.Values{}
if len(uriParts) == 2 {
var err error
qs, err = url.ParseQuery(uriParts[1])

72
vendor/go.mau.fi/util/dbutil/iter.go vendored Normal file
View File

@@ -0,0 +1,72 @@
// Copyright (c) 2023 Sumner Evans
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
package dbutil
// RowIter is a wrapper for [Rows] that allows conveniently iterating over rows
// with a predefined scanner function.
type RowIter[T any] interface {
// Iter iterates over the rows and calls the given function for each row.
//
// If the function returns false, the iteration is stopped.
// If the function returns an error, the iteration is stopped and the error is
// returned.
Iter(func(T) (bool, error)) error
// AsList collects all rows into a slice.
AsList() ([]T, error)
}
type rowIterImpl[T any] struct {
Rows
ConvertRow func(Scannable) (T, error)
}
// NewRowIter creates a new RowIter from the given Rows and scanner function.
func NewRowIter[T any](rows Rows, convertFn func(Scannable) (T, error)) RowIter[T] {
return &rowIterImpl[T]{Rows: rows, ConvertRow: convertFn}
}
func ScanSingleColumn[T any](rows Scannable) (val T, err error) {
err = rows.Scan(&val)
return
}
type NewableDataStruct[T any] interface {
DataStruct[T]
New() T
}
func ScanDataStruct[T NewableDataStruct[T]](rows Scannable) (T, error) {
var val T
return val.New().Scan(rows)
}
func (i *rowIterImpl[T]) Iter(fn func(T) (bool, error)) error {
if i == nil || i.Rows == nil {
return nil
}
defer i.Rows.Close()
for i.Rows.Next() {
if item, err := i.ConvertRow(i.Rows); err != nil {
return err
} else if cont, err := fn(item); err != nil {
return err
} else if !cont {
break
}
}
return i.Rows.Err()
}
func (i *rowIterImpl[T]) AsList() (list []T, err error) {
err = i.Iter(func(item T) (bool, error) {
list = append(list, item)
return true, nil
})
return
}

View File

@@ -10,12 +10,12 @@ import (
)
type DatabaseLogger interface {
QueryTiming(ctx context.Context, method, query string, args []interface{}, nrows int, duration time.Duration, err error)
QueryTiming(ctx context.Context, method, query string, args []any, nrows int, duration time.Duration, err error)
WarnUnsupportedVersion(current, compat, latest int)
PrepareUpgrade(current, compat, latest int)
DoUpgrade(from, to int, message string, txn bool)
// Deprecated: legacy warning method, return errors instead
Warn(msg string, args ...interface{})
Warn(msg string, args ...any)
}
type noopLogger struct{}
@@ -25,9 +25,9 @@ var NoopLogger DatabaseLogger = &noopLogger{}
func (n noopLogger) WarnUnsupportedVersion(_, _, _ int) {}
func (n noopLogger) PrepareUpgrade(_, _, _ int) {}
func (n noopLogger) DoUpgrade(_, _ int, _ string, _ bool) {}
func (n noopLogger) Warn(msg string, args ...interface{}) {}
func (n noopLogger) Warn(msg string, args ...any) {}
func (n noopLogger) QueryTiming(_ context.Context, _, _ string, _ []interface{}, _ int, _ time.Duration, _ error) {
func (n noopLogger) QueryTiming(_ context.Context, _, _ string, _ []any, _ int, _ time.Duration, _ error) {
}
type zeroLogger struct {
@@ -92,7 +92,7 @@ func (z zeroLogger) DoUpgrade(from, to int, message string, txn bool) {
var whitespaceRegex = regexp.MustCompile(`\s+`)
func (z zeroLogger) QueryTiming(ctx context.Context, method, query string, args []interface{}, nrows int, duration time.Duration, err error) {
func (z zeroLogger) QueryTiming(ctx context.Context, method, query string, args []any, nrows int, duration time.Duration, err error) {
log := zerolog.Ctx(ctx)
if log.GetLevel() == zerolog.Disabled || log == zerolog.DefaultContextLogger {
log = z.l
@@ -124,6 +124,6 @@ func (z zeroLogger) QueryTiming(ctx context.Context, method, query string, args
}
}
func (z zeroLogger) Warn(msg string, args ...interface{}) {
z.l.Warn().Msgf(msg, args...)
func (z zeroLogger) Warn(msg string, args ...any) {
z.l.Warn().Msgf(msg, args...) // zerolog-allow-msgf
}

View File

@@ -0,0 +1,105 @@
// Copyright (c) 2023 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
package dbutil
import (
"context"
"database/sql"
"errors"
"golang.org/x/exp/constraints"
)
// DataStruct is an interface for structs that represent a single database row.
type DataStruct[T any] interface {
Scan(row Scannable) (T, error)
}
// QueryHelper is a generic helper struct for SQL query execution boilerplate.
//
// After implementing the Scan and Init methods in a data struct, the query
// helper allows writing query functions in a single line.
type QueryHelper[T DataStruct[T]] struct {
db *Database
newFunc func(qh *QueryHelper[T]) T
}
func MakeQueryHelper[T DataStruct[T]](db *Database, new func(qh *QueryHelper[T]) T) *QueryHelper[T] {
return &QueryHelper[T]{db: db, newFunc: new}
}
// ValueOrErr is a helper function that returns the value if err is nil, or
// returns nil and the error if err is not nil. It can be used to avoid
// `if err != nil { return nil, err }` boilerplate in certain cases like
// DataStruct.Scan implementations.
func ValueOrErr[T any](val *T, err error) (*T, error) {
if err != nil {
return nil, err
}
return val, nil
}
// StrPtr returns a pointer to the given string, or nil if the string is empty.
func StrPtr[T ~string](val T) *string {
if val == "" {
return nil
}
strVal := string(val)
return &strVal
}
// NumPtr returns a pointer to the given number, or nil if the number is zero.
func NumPtr[T constraints.Integer | constraints.Float](val T) *T {
if val == 0 {
return nil
}
return &val
}
func (qh *QueryHelper[T]) GetDB() *Database {
return qh.db
}
func (qh *QueryHelper[T]) New() T {
return qh.newFunc(qh)
}
// Exec executes a query with ExecContext and returns the error.
//
// It omits the sql.Result return value, as it is rarely used. When the result
// is wanted, use `qh.GetDB().Exec(...)` instead, which is
// otherwise equivalent.
func (qh *QueryHelper[T]) Exec(ctx context.Context, query string, args ...any) error {
_, err := qh.db.Exec(ctx, query, args...)
return err
}
func (qh *QueryHelper[T]) scanNew(row Scannable) (T, error) {
return qh.New().Scan(row)
}
// QueryOne executes a query with QueryRowContext, uses the associated DataStruct
// to scan it, and returns the value. If the query returns no rows, it returns nil
// and no error.
func (qh *QueryHelper[T]) QueryOne(ctx context.Context, query string, args ...any) (val T, err error) {
val, err = qh.scanNew(qh.db.QueryRow(ctx, query, args...))
if errors.Is(err, sql.ErrNoRows) {
err = nil
}
return val, err
}
// QueryMany executes a query with QueryContext, uses the associated DataStruct
// to scan each row, and returns the values. If the query returns no rows, it
// returns a non-nil zero-length slice and no error.
func (qh *QueryHelper[T]) QueryMany(ctx context.Context, query string, args ...any) ([]T, error) {
rows, err := qh.db.Query(ctx, query, args...)
if err != nil {
return nil, err
}
return NewRowIter(rows, qh.scanNew).AsList()
}

View File

@@ -32,6 +32,22 @@ const (
ContextKeyDoTxnCallerSkip
)
func (db *Database) Exec(ctx context.Context, query string, args ...any) (sql.Result, error) {
return db.Conn(ctx).ExecContext(ctx, query, args...)
}
func (db *Database) Query(ctx context.Context, query string, args ...any) (Rows, error) {
return db.Conn(ctx).QueryContext(ctx, query, args...)
}
func (db *Database) QueryRow(ctx context.Context, query string, args ...any) *sql.Row {
return db.Conn(ctx).QueryRowContext(ctx, query, args...)
}
func (db *Database) BeginTx(ctx context.Context, opts *sql.TxOptions) (*LoggingTxn, error) {
return db.LoggingDB.BeginTx(ctx, opts)
}
func (db *Database) DoTxn(ctx context.Context, opts *sql.TxOptions, fn func(ctx context.Context) error) error {
if ctx.Value(ContextKeyDatabaseTransaction) != nil {
zerolog.Ctx(ctx).Trace().Msg("Already in a transaction, not creating a new one")
@@ -82,13 +98,13 @@ func (db *Database) DoTxn(ctx context.Context, opts *sql.TxOptions, fn func(ctx
return nil
}
func (db *Database) Conn(ctx context.Context) ContextExecable {
func (db *Database) Conn(ctx context.Context) Execable {
if ctx == nil {
return db
return &db.LoggingDB
}
txn, ok := ctx.Value(ContextKeyDatabaseTransaction).(Transaction)
if ok {
return txn
}
return db
return &db.LoggingDB
}

View File

@@ -7,12 +7,13 @@
package dbutil
import (
"context"
"database/sql"
"errors"
"fmt"
)
type upgradeFunc func(Execable, *Database) error
type upgradeFunc func(context.Context, *Database) error
type upgrade struct {
message string
@@ -28,19 +29,19 @@ var ErrForeignTables = errors.New("the database contains foreign tables")
var ErrNotOwned = errors.New("the database is owned by")
var ErrUnsupportedDialect = errors.New("unsupported database dialect")
func (db *Database) upgradeVersionTable() error {
if compatColumnExists, err := db.ColumnExists(nil, db.VersionTable, "compat"); err != nil {
func (db *Database) upgradeVersionTable(ctx context.Context) error {
if compatColumnExists, err := db.ColumnExists(ctx, db.VersionTable, "compat"); err != nil {
return fmt.Errorf("failed to check if version table is up to date: %w", err)
} else if !compatColumnExists {
if tableExists, err := db.TableExists(nil, db.VersionTable); err != nil {
if tableExists, err := db.TableExists(ctx, db.VersionTable); err != nil {
return fmt.Errorf("failed to check if version table exists: %w", err)
} else if !tableExists {
_, err = db.Exec(fmt.Sprintf("CREATE TABLE %s (version INTEGER, compat INTEGER)", db.VersionTable))
_, err = db.Exec(ctx, fmt.Sprintf("CREATE TABLE %s (version INTEGER, compat INTEGER)", db.VersionTable))
if err != nil {
return fmt.Errorf("failed to create version table: %w", err)
}
} else {
_, err = db.Exec(fmt.Sprintf("ALTER TABLE %s ADD COLUMN compat INTEGER", db.VersionTable))
_, err = db.Exec(ctx, fmt.Sprintf("ALTER TABLE %s ADD COLUMN compat INTEGER", db.VersionTable))
if err != nil {
return fmt.Errorf("failed to add compat column to version table: %w", err)
}
@@ -49,13 +50,13 @@ func (db *Database) upgradeVersionTable() error {
return nil
}
func (db *Database) getVersion() (version, compat int, err error) {
if err = db.upgradeVersionTable(); err != nil {
func (db *Database) getVersion(ctx context.Context) (version, compat int, err error) {
if err = db.upgradeVersionTable(ctx); err != nil {
return
}
var compatNull sql.NullInt32
err = db.QueryRow(fmt.Sprintf("SELECT version, compat FROM %s LIMIT 1", db.VersionTable)).Scan(&version, &compatNull)
err = db.QueryRow(ctx, fmt.Sprintf("SELECT version, compat FROM %s LIMIT 1", db.VersionTable)).Scan(&version, &compatNull)
if errors.Is(err, sql.ErrNoRows) {
err = nil
}
@@ -72,15 +73,12 @@ const (
tableExistsSQLite = "SELECT EXISTS(SELECT 1 FROM sqlite_master WHERE type='table' AND tbl_name=?1)"
)
func (db *Database) TableExists(tx Execable, table string) (exists bool, err error) {
if tx == nil {
tx = db
}
func (db *Database) TableExists(ctx context.Context, table string) (exists bool, err error) {
switch db.Dialect {
case SQLite:
err = db.QueryRow(tableExistsSQLite, table).Scan(&exists)
err = db.QueryRow(ctx, tableExistsSQLite, table).Scan(&exists)
case Postgres:
err = db.QueryRow(tableExistsPostgres, table).Scan(&exists)
err = db.QueryRow(ctx, tableExistsPostgres, table).Scan(&exists)
default:
err = ErrUnsupportedDialect
}
@@ -92,23 +90,20 @@ const (
columnExistsSQLite = "SELECT EXISTS(SELECT 1 FROM pragma_table_info(?1) WHERE name=?2)"
)
func (db *Database) ColumnExists(tx Execable, table, column string) (exists bool, err error) {
if tx == nil {
tx = db
}
func (db *Database) ColumnExists(ctx context.Context, table, column string) (exists bool, err error) {
switch db.Dialect {
case SQLite:
err = db.QueryRow(columnExistsSQLite, table, column).Scan(&exists)
err = db.QueryRow(ctx, columnExistsSQLite, table, column).Scan(&exists)
case Postgres:
err = db.QueryRow(columnExistsPostgres, table, column).Scan(&exists)
err = db.QueryRow(ctx, columnExistsPostgres, table, column).Scan(&exists)
default:
err = ErrUnsupportedDialect
}
return
}
func (db *Database) tableExistsNoError(table string) bool {
exists, err := db.TableExists(nil, table)
func (db *Database) tableExistsNoError(ctx context.Context, table string) bool {
exists, err := db.TableExists(ctx, table)
if err != nil {
panic(fmt.Errorf("failed to check if table exists: %w", err))
}
@@ -122,22 +117,22 @@ CREATE TABLE IF NOT EXISTS database_owner (
)
`
func (db *Database) checkDatabaseOwner() error {
func (db *Database) checkDatabaseOwner(ctx context.Context) error {
var owner string
if !db.IgnoreForeignTables {
if db.tableExistsNoError("state_groups_state") {
if db.tableExistsNoError(ctx, "state_groups_state") {
return fmt.Errorf("%w (found state_groups_state, likely belonging to Synapse)", ErrForeignTables)
} else if db.tableExistsNoError("roomserver_rooms") {
} else if db.tableExistsNoError(ctx, "roomserver_rooms") {
return fmt.Errorf("%w (found roomserver_rooms, likely belonging to Dendrite)", ErrForeignTables)
}
}
if db.Owner == "" {
return nil
}
if _, err := db.Exec(createOwnerTable); err != nil {
if _, err := db.Exec(ctx, createOwnerTable); err != nil {
return fmt.Errorf("failed to ensure database owner table exists: %w", err)
} else if err = db.QueryRow("SELECT owner FROM database_owner WHERE key=0").Scan(&owner); errors.Is(err, sql.ErrNoRows) {
_, err = db.Exec("INSERT INTO database_owner (key, owner) VALUES (0, $1)", db.Owner)
} else if err = db.QueryRow(ctx, "SELECT owner FROM database_owner WHERE key=0").Scan(&owner); errors.Is(err, sql.ErrNoRows) {
_, err = db.Exec(ctx, "INSERT INTO database_owner (key, owner) VALUES (0, $1)", db.Owner)
if err != nil {
return fmt.Errorf("failed to insert database owner: %w", err)
}
@@ -149,22 +144,22 @@ func (db *Database) checkDatabaseOwner() error {
return nil
}
func (db *Database) setVersion(tx Execable, version, compat int) error {
_, err := tx.Exec(fmt.Sprintf("DELETE FROM %s", db.VersionTable))
func (db *Database) setVersion(ctx context.Context, version, compat int) error {
_, err := db.Exec(ctx, fmt.Sprintf("DELETE FROM %s", db.VersionTable))
if err != nil {
return err
}
_, err = tx.Exec(fmt.Sprintf("INSERT INTO %s (version, compat) VALUES ($1, $2)", db.VersionTable), version, compat)
_, err = db.Exec(ctx, fmt.Sprintf("INSERT INTO %s (version, compat) VALUES ($1, $2)", db.VersionTable), version, compat)
return err
}
func (db *Database) Upgrade() error {
err := db.checkDatabaseOwner()
func (db *Database) Upgrade(ctx context.Context) error {
err := db.checkDatabaseOwner(ctx)
if err != nil {
return err
}
version, compat, err := db.getVersion()
version, compat, err := db.getVersion(ctx)
if err != nil {
return err
}
@@ -185,34 +180,28 @@ func (db *Database) Upgrade() error {
version++
continue
}
doUpgrade := func(ctx context.Context) error {
err = upgradeItem.fn(ctx, db)
if err != nil {
return fmt.Errorf("failed to run upgrade #%d: %w", version, err)
}
version = upgradeItem.upgradesTo
logVersion = version
err = db.setVersion(ctx, version, upgradeItem.compatVersion)
if err != nil {
return err
}
return nil
}
db.Log.DoUpgrade(logVersion, upgradeItem.upgradesTo, upgradeItem.message, upgradeItem.transaction)
var tx Transaction
var upgradeConn Execable
if upgradeItem.transaction {
tx, err = db.Begin()
if err != nil {
return err
}
upgradeConn = tx
err = db.DoTxn(ctx, nil, doUpgrade)
} else {
upgradeConn = db
err = doUpgrade(ctx)
}
err = upgradeItem.fn(upgradeConn, db)
if err != nil {
return err
}
version = upgradeItem.upgradesTo
logVersion = version
err = db.setVersion(upgradeConn, version, upgradeItem.compatVersion)
if err != nil {
return err
}
if tx != nil {
err = tx.Commit()
if err != nil {
return err
}
}
}
return nil
}

View File

@@ -8,6 +8,7 @@ package dbutil
import (
"bytes"
"context"
"errors"
"fmt"
"io/fs"
@@ -189,27 +190,27 @@ func (db *Database) filterSQLUpgrade(lines [][]byte) (string, error) {
}
func sqlUpgradeFunc(fileName string, lines [][]byte) upgradeFunc {
return func(tx Execable, db *Database) error {
return func(ctx context.Context, db *Database) error {
if skip, err := db.parseDialectFilter(lines[0]); err == nil && skip == skipNextLine {
return nil
} else if upgradeSQL, err := db.filterSQLUpgrade(lines); err != nil {
panic(fmt.Errorf("failed to parse upgrade %s: %w", fileName, err))
} else {
_, err = tx.Exec(upgradeSQL)
_, err = db.Exec(ctx, upgradeSQL)
return err
}
}
}
func splitSQLUpgradeFunc(sqliteData, postgresData string) upgradeFunc {
return func(tx Execable, database *Database) (err error) {
switch database.Dialect {
return func(ctx context.Context, db *Database) (err error) {
switch db.Dialect {
case SQLite:
_, err = tx.Exec(sqliteData)
_, err = db.Exec(ctx, sqliteData)
case Postgres:
_, err = tx.Exec(postgresData)
_, err = db.Exec(ctx, postgresData)
default:
err = fmt.Errorf("unknown dialect %s", database.Dialect)
err = fmt.Errorf("unknown dialect %s", db.Dialect)
}
return
}

23
vendor/go.mau.fi/util/exerrors/must.go vendored Normal file
View File

@@ -0,0 +1,23 @@
// Copyright (c) 2024 Tulir Asokan
//
// This Source Code Form is subject to the terms of the Mozilla Public
// License, v. 2.0. If a copy of the MPL was not distributed with this
// file, You can obtain one at http://mozilla.org/MPL/2.0/.
package exerrors
func Must[T any](val T, err error) T {
PanicIfNotNil(err)
return val
}
func Must2[T any, T2 any](val T, val2 T2, err error) (T, T2) {
PanicIfNotNil(err)
return val, val2
}
func PanicIfNotNil(err error) {
if err != nil {
panic(err)
}
}

View File

@@ -7,10 +7,16 @@
package jsontime
import (
"database/sql"
"database/sql/driver"
"encoding/json"
"errors"
"fmt"
"time"
)
var ErrNotInteger = errors.New("value is not an integer")
func parseTime(data []byte, unixConv func(int64) time.Time, into *time.Time) error {
var val int64
err := json.Unmarshal(data, &val)
@@ -25,6 +31,28 @@ func parseTime(data []byte, unixConv func(int64) time.Time, into *time.Time) err
return nil
}
func anyIntegerToTime(src any, unixConv func(int64) time.Time, into *time.Time) error {
switch v := src.(type) {
case int:
*into = unixConv(int64(v))
case int8:
*into = unixConv(int64(v))
case int16:
*into = unixConv(int64(v))
case int32:
*into = unixConv(int64(v))
case int64:
*into = unixConv(int64(v))
default:
return fmt.Errorf("%w: %T", ErrNotInteger, src)
}
return nil
}
var _ sql.Scanner = &UnixMilli{}
var _ driver.Valuer = UnixMilli{}
type UnixMilli struct {
time.Time
}
@@ -40,6 +68,17 @@ func (um *UnixMilli) UnmarshalJSON(data []byte) error {
return parseTime(data, time.UnixMilli, &um.Time)
}
func (um UnixMilli) Value() (driver.Value, error) {
return um.UnixMilli(), nil
}
func (um *UnixMilli) Scan(src any) error {
return anyIntegerToTime(src, time.UnixMilli, &um.Time)
}
var _ sql.Scanner = &UnixMicro{}
var _ driver.Valuer = UnixMicro{}
type UnixMicro struct {
time.Time
}
@@ -55,6 +94,17 @@ func (um *UnixMicro) UnmarshalJSON(data []byte) error {
return parseTime(data, time.UnixMicro, &um.Time)
}
func (um UnixMicro) Value() (driver.Value, error) {
return um.UnixMicro(), nil
}
func (um *UnixMicro) Scan(src any) error {
return anyIntegerToTime(src, time.UnixMicro, &um.Time)
}
var _ sql.Scanner = &UnixNano{}
var _ driver.Valuer = UnixNano{}
type UnixNano struct {
time.Time
}
@@ -72,6 +122,16 @@ func (un *UnixNano) UnmarshalJSON(data []byte) error {
}, &un.Time)
}
func (un UnixNano) Value() (driver.Value, error) {
return un.UnixNano(), nil
}
func (un *UnixNano) Scan(src any) error {
return anyIntegerToTime(src, func(i int64) time.Time {
return time.Unix(0, i)
}, &un.Time)
}
type Unix struct {
time.Time
}
@@ -83,8 +143,21 @@ func (u Unix) MarshalJSON() ([]byte, error) {
return json.Marshal(u.Unix())
}
var _ sql.Scanner = &Unix{}
var _ driver.Valuer = Unix{}
func (u *Unix) UnmarshalJSON(data []byte) error {
return parseTime(data, func(i int64) time.Time {
return time.Unix(i, 0)
}, &u.Time)
}
func (u Unix) Value() (driver.Value, error) {
return u.Unix(), nil
}
func (u *Unix) Scan(src any) error {
return anyIntegerToTime(src, func(i int64) time.Time {
return time.Unix(i, 0)
}, &u.Time)
}

View File

@@ -199,8 +199,8 @@ TEXT ·mixBlocksSSE2(SB), 4, $0-32
MOVQ out+0(FP), DX
MOVQ a+8(FP), AX
MOVQ b+16(FP), BX
MOVQ a+24(FP), CX
MOVQ $128, BP
MOVQ c+24(FP), CX
MOVQ $128, DI
loop:
MOVOU 0(AX), X0
@@ -213,7 +213,7 @@ loop:
ADDQ $16, BX
ADDQ $16, CX
ADDQ $16, DX
SUBQ $2, BP
SUBQ $2, DI
JA loop
RET
@@ -222,8 +222,8 @@ TEXT ·xorBlocksSSE2(SB), 4, $0-32
MOVQ out+0(FP), DX
MOVQ a+8(FP), AX
MOVQ b+16(FP), BX
MOVQ a+24(FP), CX
MOVQ $128, BP
MOVQ c+24(FP), CX
MOVQ $128, DI
loop:
MOVOU 0(AX), X0
@@ -238,6 +238,6 @@ loop:
ADDQ $16, BX
ADDQ $16, CX
ADDQ $16, DX
SUBQ $2, BP
SUBQ $2, DI
JA loop
RET

View File

@@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build go1.7 && amd64 && gc && !purego
//go:build amd64 && gc && !purego
package blake2b

View File

@@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build go1.7 && amd64 && gc && !purego
//go:build amd64 && gc && !purego
#include "textflag.h"

View File

@@ -1,24 +0,0 @@
// Copyright 2016 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build !go1.7 && amd64 && gc && !purego
package blake2b
import "golang.org/x/sys/cpu"
func init() {
useSSE4 = cpu.X86.HasSSE41
}
//go:noescape
func hashBlocksSSE4(h *[8]uint64, c *[2]uint64, flag uint64, blocks []byte)
func hashBlocks(h *[8]uint64, c *[2]uint64, flag uint64, blocks []byte) {
if useSSE4 {
hashBlocksSSE4(h, c, flag, blocks)
} else {
hashBlocksGeneric(h, c, flag, blocks)
}
}

View File

@@ -2,8 +2,6 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build go1.9
package blake2b
import (

View File

@@ -1,39 +0,0 @@
// Copyright 2019 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build !go1.13
package poly1305
// Generic fallbacks for the math/bits intrinsics, copied from
// src/math/bits/bits.go. They were added in Go 1.12, but Add64 and Sum64 had
// variable time fallbacks until Go 1.13.
func bitsAdd64(x, y, carry uint64) (sum, carryOut uint64) {
sum = x + y + carry
carryOut = ((x & y) | ((x | y) &^ sum)) >> 63
return
}
func bitsSub64(x, y, borrow uint64) (diff, borrowOut uint64) {
diff = x - y - borrow
borrowOut = ((^x & y) | (^(x ^ y) & diff)) >> 63
return
}
func bitsMul64(x, y uint64) (hi, lo uint64) {
const mask32 = 1<<32 - 1
x0 := x & mask32
x1 := x >> 32
y0 := y & mask32
y1 := y >> 32
w0 := x0 * y0
t := x1*y0 + w0>>32
w1 := t & mask32
w2 := t >> 32
w1 += x0 * y1
hi = x1*y1 + w2 + w1>>32
lo = x * y
return
}

View File

@@ -1,21 +0,0 @@
// Copyright 2019 The Go Authors. All rights reserved.
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build go1.13
package poly1305
import "math/bits"
func bitsAdd64(x, y, carry uint64) (sum, carryOut uint64) {
return bits.Add64(x, y, carry)
}
func bitsSub64(x, y, borrow uint64) (diff, borrowOut uint64) {
return bits.Sub64(x, y, borrow)
}
func bitsMul64(x, y uint64) (hi, lo uint64) {
return bits.Mul64(x, y)
}

View File

@@ -7,7 +7,10 @@
package poly1305
import "encoding/binary"
import (
"encoding/binary"
"math/bits"
)
// Poly1305 [RFC 7539] is a relatively simple algorithm: the authentication tag
// for a 64 bytes message is approximately
@@ -114,13 +117,13 @@ type uint128 struct {
}
func mul64(a, b uint64) uint128 {
hi, lo := bitsMul64(a, b)
hi, lo := bits.Mul64(a, b)
return uint128{lo, hi}
}
func add128(a, b uint128) uint128 {
lo, c := bitsAdd64(a.lo, b.lo, 0)
hi, c := bitsAdd64(a.hi, b.hi, c)
lo, c := bits.Add64(a.lo, b.lo, 0)
hi, c := bits.Add64(a.hi, b.hi, c)
if c != 0 {
panic("poly1305: unexpected overflow")
}
@@ -155,8 +158,8 @@ func updateGeneric(state *macState, msg []byte) {
// hide leading zeroes. For full chunks, that's 1 << 128, so we can just
// add 1 to the most significant (2¹²⁸) limb, h2.
if len(msg) >= TagSize {
h0, c = bitsAdd64(h0, binary.LittleEndian.Uint64(msg[0:8]), 0)
h1, c = bitsAdd64(h1, binary.LittleEndian.Uint64(msg[8:16]), c)
h0, c = bits.Add64(h0, binary.LittleEndian.Uint64(msg[0:8]), 0)
h1, c = bits.Add64(h1, binary.LittleEndian.Uint64(msg[8:16]), c)
h2 += c + 1
msg = msg[TagSize:]
@@ -165,8 +168,8 @@ func updateGeneric(state *macState, msg []byte) {
copy(buf[:], msg)
buf[len(msg)] = 1
h0, c = bitsAdd64(h0, binary.LittleEndian.Uint64(buf[0:8]), 0)
h1, c = bitsAdd64(h1, binary.LittleEndian.Uint64(buf[8:16]), c)
h0, c = bits.Add64(h0, binary.LittleEndian.Uint64(buf[0:8]), 0)
h1, c = bits.Add64(h1, binary.LittleEndian.Uint64(buf[8:16]), c)
h2 += c
msg = nil
@@ -219,9 +222,9 @@ func updateGeneric(state *macState, msg []byte) {
m3 := h2r1
t0 := m0.lo
t1, c := bitsAdd64(m1.lo, m0.hi, 0)
t2, c := bitsAdd64(m2.lo, m1.hi, c)
t3, _ := bitsAdd64(m3.lo, m2.hi, c)
t1, c := bits.Add64(m1.lo, m0.hi, 0)
t2, c := bits.Add64(m2.lo, m1.hi, c)
t3, _ := bits.Add64(m3.lo, m2.hi, c)
// Now we have the result as 4 64-bit limbs, and we need to reduce it
// modulo 2¹³⁰ - 5. The special shape of this Crandall prime lets us do
@@ -243,14 +246,14 @@ func updateGeneric(state *macState, msg []byte) {
// To add c * 5 to h, we first add cc = c * 4, and then add (cc >> 2) = c.
h0, c = bitsAdd64(h0, cc.lo, 0)
h1, c = bitsAdd64(h1, cc.hi, c)
h0, c = bits.Add64(h0, cc.lo, 0)
h1, c = bits.Add64(h1, cc.hi, c)
h2 += c
cc = shiftRightBy2(cc)
h0, c = bitsAdd64(h0, cc.lo, 0)
h1, c = bitsAdd64(h1, cc.hi, c)
h0, c = bits.Add64(h0, cc.lo, 0)
h1, c = bits.Add64(h1, cc.hi, c)
h2 += c
// h2 is at most 3 + 1 + 1 = 5, making the whole of h at most
@@ -287,9 +290,9 @@ func finalize(out *[TagSize]byte, h *[3]uint64, s *[2]uint64) {
// in constant time, we compute t = h - (2¹³⁰ - 5), and select h as the
// result if the subtraction underflows, and t otherwise.
hMinusP0, b := bitsSub64(h0, p0, 0)
hMinusP1, b := bitsSub64(h1, p1, b)
_, b = bitsSub64(h2, p2, b)
hMinusP0, b := bits.Sub64(h0, p0, 0)
hMinusP1, b := bits.Sub64(h1, p1, b)
_, b = bits.Sub64(h2, p2, b)
// h = h if h < p else h - p
h0 = select64(b, h0, hMinusP0)
@@ -301,8 +304,8 @@ func finalize(out *[TagSize]byte, h *[3]uint64, s *[2]uint64) {
//
// by just doing a wide addition with the 128 low bits of h and discarding
// the overflow.
h0, c := bitsAdd64(h0, s[0], 0)
h1, _ = bitsAdd64(h1, s[1], c)
h0, c := bits.Add64(h0, s[0], 0)
h1, _ = bits.Add64(h1, s[1], c)
binary.LittleEndian.PutUint64(out[0:8], h0)
binary.LittleEndian.PutUint64(out[8:16], h1)

View File

@@ -187,9 +187,11 @@ type channel struct {
pending *buffer
extPending *buffer
// windowMu protects myWindow, the flow-control window.
windowMu sync.Mutex
myWindow uint32
// windowMu protects myWindow, the flow-control window, and myConsumed,
// the number of bytes consumed since we last increased myWindow
windowMu sync.Mutex
myWindow uint32
myConsumed uint32
// writeMu serializes calls to mux.conn.writePacket() and
// protects sentClose and packetPool. This mutex must be
@@ -332,14 +334,24 @@ func (ch *channel) handleData(packet []byte) error {
return nil
}
func (c *channel) adjustWindow(n uint32) error {
func (c *channel) adjustWindow(adj uint32) error {
c.windowMu.Lock()
// Since myWindow is managed on our side, and can never exceed
// the initial window setting, we don't worry about overflow.
c.myWindow += uint32(n)
// Since myConsumed and myWindow are managed on our side, and can never
// exceed the initial window setting, we don't worry about overflow.
c.myConsumed += adj
var sendAdj uint32
if (channelWindowSize-c.myWindow > 3*c.maxIncomingPayload) ||
(c.myWindow < channelWindowSize/2) {
sendAdj = c.myConsumed
c.myConsumed = 0
c.myWindow += sendAdj
}
c.windowMu.Unlock()
if sendAdj == 0 {
return nil
}
return c.sendMessage(windowAdjustMsg{
AdditionalBytes: uint32(n),
AdditionalBytes: sendAdj,
})
}

View File

@@ -82,7 +82,7 @@ func NewClientConn(c net.Conn, addr string, config *ClientConfig) (Conn, <-chan
if err := conn.clientHandshake(addr, &fullConf); err != nil {
c.Close()
return nil, nil, nil, fmt.Errorf("ssh: handshake failed: %v", err)
return nil, nil, nil, fmt.Errorf("ssh: handshake failed: %w", err)
}
conn.mux = newMux(conn.transport)
return conn, conn.mux.incomingChannels, conn.mux.incomingRequests, nil

View File

@@ -307,7 +307,10 @@ func (cb publicKeyCallback) auth(session []byte, user string, c packetConn, rand
}
var methods []string
var errSigAlgo error
for _, signer := range signers {
origSignersLen := len(signers)
for idx := 0; idx < len(signers); idx++ {
signer := signers[idx]
pub := signer.PublicKey()
as, algo, err := pickSignatureAlgorithm(signer, extensions)
if err != nil && errSigAlgo == nil {
@@ -321,6 +324,21 @@ func (cb publicKeyCallback) auth(session []byte, user string, c packetConn, rand
if err != nil {
return authFailure, nil, err
}
// OpenSSH 7.2-7.7 advertises support for rsa-sha2-256 and rsa-sha2-512
// in the "server-sig-algs" extension but doesn't support these
// algorithms for certificate authentication, so if the server rejects
// the key try to use the obtained algorithm as if "server-sig-algs" had
// not been implemented if supported from the algorithm signer.
if !ok && idx < origSignersLen && isRSACert(algo) && algo != CertAlgoRSAv01 {
if contains(as.Algorithms(), KeyAlgoRSA) {
// We retry using the compat algorithm after all signers have
// been tried normally.
signers = append(signers, &multiAlgorithmSigner{
AlgorithmSigner: as,
supportedAlgorithms: []string{KeyAlgoRSA},
})
}
}
if !ok {
continue
}

View File

@@ -127,6 +127,14 @@ func isRSA(algo string) bool {
return contains(algos, underlyingAlgo(algo))
}
func isRSACert(algo string) bool {
_, ok := certKeyAlgoNames[algo]
if !ok {
return false
}
return isRSA(algo)
}
// supportedPubKeyAuthAlgos specifies the supported client public key
// authentication algorithms. Note that this doesn't include certificate types
// since those use the underlying algorithm. This list is sent to the client if

View File

@@ -35,6 +35,16 @@ type keyingTransport interface {
// direction will be effected if a msgNewKeys message is sent
// or received.
prepareKeyChange(*algorithms, *kexResult) error
// setStrictMode sets the strict KEX mode, notably triggering
// sequence number resets on sending or receiving msgNewKeys.
// If the sequence number is already > 1 when setStrictMode
// is called, an error is returned.
setStrictMode() error
// setInitialKEXDone indicates to the transport that the initial key exchange
// was completed
setInitialKEXDone()
}
// handshakeTransport implements rekeying on top of a keyingTransport
@@ -100,6 +110,10 @@ type handshakeTransport struct {
// The session ID or nil if first kex did not complete yet.
sessionID []byte
// strictMode indicates if the other side of the handshake indicated
// that we should be following the strict KEX protocol restrictions.
strictMode bool
}
type pendingKex struct {
@@ -209,7 +223,10 @@ func (t *handshakeTransport) readLoop() {
close(t.incoming)
break
}
if p[0] == msgIgnore || p[0] == msgDebug {
// If this is the first kex, and strict KEX mode is enabled,
// we don't ignore any messages, as they may be used to manipulate
// the packet sequence numbers.
if !(t.sessionID == nil && t.strictMode) && (p[0] == msgIgnore || p[0] == msgDebug) {
continue
}
t.incoming <- p
@@ -441,6 +458,11 @@ func (t *handshakeTransport) readOnePacket(first bool) ([]byte, error) {
return successPacket, nil
}
const (
kexStrictClient = "kex-strict-c-v00@openssh.com"
kexStrictServer = "kex-strict-s-v00@openssh.com"
)
// sendKexInit sends a key change message.
func (t *handshakeTransport) sendKexInit() error {
t.mu.Lock()
@@ -454,7 +476,6 @@ func (t *handshakeTransport) sendKexInit() error {
}
msg := &kexInitMsg{
KexAlgos: t.config.KeyExchanges,
CiphersClientServer: t.config.Ciphers,
CiphersServerClient: t.config.Ciphers,
MACsClientServer: t.config.MACs,
@@ -464,6 +485,13 @@ func (t *handshakeTransport) sendKexInit() error {
}
io.ReadFull(rand.Reader, msg.Cookie[:])
// We mutate the KexAlgos slice, in order to add the kex-strict extension algorithm,
// and possibly to add the ext-info extension algorithm. Since the slice may be the
// user owned KeyExchanges, we create our own slice in order to avoid using user
// owned memory by mistake.
msg.KexAlgos = make([]string, 0, len(t.config.KeyExchanges)+2) // room for kex-strict and ext-info
msg.KexAlgos = append(msg.KexAlgos, t.config.KeyExchanges...)
isServer := len(t.hostKeys) > 0
if isServer {
for _, k := range t.hostKeys {
@@ -488,17 +516,24 @@ func (t *handshakeTransport) sendKexInit() error {
msg.ServerHostKeyAlgos = append(msg.ServerHostKeyAlgos, keyFormat)
}
}
if t.sessionID == nil {
msg.KexAlgos = append(msg.KexAlgos, kexStrictServer)
}
} else {
msg.ServerHostKeyAlgos = t.hostKeyAlgorithms
// As a client we opt in to receiving SSH_MSG_EXT_INFO so we know what
// algorithms the server supports for public key authentication. See RFC
// 8308, Section 2.1.
//
// We also send the strict KEX mode extension algorithm, in order to opt
// into the strict KEX mode.
if firstKeyExchange := t.sessionID == nil; firstKeyExchange {
msg.KexAlgos = make([]string, 0, len(t.config.KeyExchanges)+1)
msg.KexAlgos = append(msg.KexAlgos, t.config.KeyExchanges...)
msg.KexAlgos = append(msg.KexAlgos, "ext-info-c")
msg.KexAlgos = append(msg.KexAlgos, kexStrictClient)
}
}
packet := Marshal(msg)
@@ -604,6 +639,13 @@ func (t *handshakeTransport) enterKeyExchange(otherInitPacket []byte) error {
return err
}
if t.sessionID == nil && ((isClient && contains(serverInit.KexAlgos, kexStrictServer)) || (!isClient && contains(clientInit.KexAlgos, kexStrictClient))) {
t.strictMode = true
if err := t.conn.setStrictMode(); err != nil {
return err
}
}
// We don't send FirstKexFollows, but we handle receiving it.
//
// RFC 4253 section 7 defines the kex and the agreement method for
@@ -679,6 +721,12 @@ func (t *handshakeTransport) enterKeyExchange(otherInitPacket []byte) error {
return unexpectedMessageError(msgNewKeys, packet[0])
}
if firstKeyExchange {
// Indicates to the transport that the first key exchange is completed
// after receiving SSH_MSG_NEWKEYS.
t.conn.setInitialKEXDone()
}
return nil
}

View File

@@ -213,6 +213,7 @@ func NewServerConn(c net.Conn, config *ServerConfig) (*ServerConn, <-chan NewCha
} else {
for _, algo := range fullConf.PublicKeyAuthAlgorithms {
if !contains(supportedPubKeyAuthAlgos, algo) {
c.Close()
return nil, nil, nil, fmt.Errorf("ssh: unsupported public key authentication algorithm %s", algo)
}
}
@@ -220,6 +221,7 @@ func NewServerConn(c net.Conn, config *ServerConfig) (*ServerConn, <-chan NewCha
// Check if the config contains any unsupported key exchanges
for _, kex := range fullConf.KeyExchanges {
if _, ok := serverForbiddenKexAlgos[kex]; ok {
c.Close()
return nil, nil, nil, fmt.Errorf("ssh: unsupported key exchange %s for server", kex)
}
}
@@ -337,7 +339,7 @@ func checkSourceAddress(addr net.Addr, sourceAddrs string) error {
return fmt.Errorf("ssh: remote address %v is not allowed because of source-address restriction", addr)
}
func gssExchangeToken(gssapiConfig *GSSAPIWithMICConfig, firstToken []byte, s *connection,
func gssExchangeToken(gssapiConfig *GSSAPIWithMICConfig, token []byte, s *connection,
sessionID []byte, userAuthReq userAuthRequestMsg) (authErr error, perms *Permissions, err error) {
gssAPIServer := gssapiConfig.Server
defer gssAPIServer.DeleteSecContext()
@@ -347,7 +349,7 @@ func gssExchangeToken(gssapiConfig *GSSAPIWithMICConfig, firstToken []byte, s *c
outToken []byte
needContinue bool
)
outToken, srcName, needContinue, err = gssAPIServer.AcceptSecContext(firstToken)
outToken, srcName, needContinue, err = gssAPIServer.AcceptSecContext(token)
if err != nil {
return err, nil, nil
}
@@ -369,6 +371,7 @@ func gssExchangeToken(gssapiConfig *GSSAPIWithMICConfig, firstToken []byte, s *c
if err := Unmarshal(packet, userAuthGSSAPITokenReq); err != nil {
return nil, nil, err
}
token = userAuthGSSAPITokenReq.Token
}
packet, err := s.transport.readPacket()
if err != nil {

View File

@@ -5,6 +5,7 @@
package ssh
import (
"context"
"errors"
"fmt"
"io"
@@ -332,6 +333,40 @@ func (l *tcpListener) Addr() net.Addr {
return l.laddr
}
// DialContext initiates a connection to the addr from the remote host.
//
// The provided Context must be non-nil. If the context expires before the
// connection is complete, an error is returned. Once successfully connected,
// any expiration of the context will not affect the connection.
//
// See func Dial for additional information.
func (c *Client) DialContext(ctx context.Context, n, addr string) (net.Conn, error) {
if err := ctx.Err(); err != nil {
return nil, err
}
type connErr struct {
conn net.Conn
err error
}
ch := make(chan connErr)
go func() {
conn, err := c.Dial(n, addr)
select {
case ch <- connErr{conn, err}:
case <-ctx.Done():
if conn != nil {
conn.Close()
}
}
}()
select {
case res := <-ch:
return res.conn, res.err
case <-ctx.Done():
return nil, ctx.Err()
}
}
// Dial initiates a connection to the addr from the remote host.
// The resulting connection has a zero LocalAddr() and RemoteAddr().
func (c *Client) Dial(n, addr string) (net.Conn, error) {

View File

@@ -49,6 +49,9 @@ type transport struct {
rand io.Reader
isClient bool
io.Closer
strictMode bool
initialKEXDone bool
}
// packetCipher represents a combination of SSH encryption/MAC
@@ -74,6 +77,18 @@ type connectionState struct {
pendingKeyChange chan packetCipher
}
func (t *transport) setStrictMode() error {
if t.reader.seqNum != 1 {
return errors.New("ssh: sequence number != 1 when strict KEX mode requested")
}
t.strictMode = true
return nil
}
func (t *transport) setInitialKEXDone() {
t.initialKEXDone = true
}
// prepareKeyChange sets up key material for a keychange. The key changes in
// both directions are triggered by reading and writing a msgNewKey packet
// respectively.
@@ -112,11 +127,12 @@ func (t *transport) printPacket(p []byte, write bool) {
// Read and decrypt next packet.
func (t *transport) readPacket() (p []byte, err error) {
for {
p, err = t.reader.readPacket(t.bufReader)
p, err = t.reader.readPacket(t.bufReader, t.strictMode)
if err != nil {
break
}
if len(p) == 0 || (p[0] != msgIgnore && p[0] != msgDebug) {
// in strict mode we pass through DEBUG and IGNORE packets only during the initial KEX
if len(p) == 0 || (t.strictMode && !t.initialKEXDone) || (p[0] != msgIgnore && p[0] != msgDebug) {
break
}
}
@@ -127,7 +143,7 @@ func (t *transport) readPacket() (p []byte, err error) {
return p, err
}
func (s *connectionState) readPacket(r *bufio.Reader) ([]byte, error) {
func (s *connectionState) readPacket(r *bufio.Reader, strictMode bool) ([]byte, error) {
packet, err := s.packetCipher.readCipherPacket(s.seqNum, r)
s.seqNum++
if err == nil && len(packet) == 0 {
@@ -140,6 +156,9 @@ func (s *connectionState) readPacket(r *bufio.Reader) ([]byte, error) {
select {
case cipher := <-s.pendingKeyChange:
s.packetCipher = cipher
if strictMode {
s.seqNum = 0
}
default:
return nil, errors.New("ssh: got bogus newkeys message")
}
@@ -170,10 +189,10 @@ func (t *transport) writePacket(packet []byte) error {
if debugTransport {
t.printPacket(packet, true)
}
return t.writer.writePacket(t.bufWriter, t.rand, packet)
return t.writer.writePacket(t.bufWriter, t.rand, packet, t.strictMode)
}
func (s *connectionState) writePacket(w *bufio.Writer, rand io.Reader, packet []byte) error {
func (s *connectionState) writePacket(w *bufio.Writer, rand io.Reader, packet []byte, strictMode bool) error {
changeKeys := len(packet) > 0 && packet[0] == msgNewKeys
err := s.packetCipher.writeCipherPacket(s.seqNum, w, rand, packet)
@@ -188,6 +207,9 @@ func (s *connectionState) writePacket(w *bufio.Writer, rand io.Reader, packet []
select {
case cipher := <-s.pendingKeyChange:
s.packetCipher = cipher
if strictMode {
s.seqNum = 0
}
default:
panic("ssh: no key material for msgNewKeys")
}

View File

@@ -209,25 +209,37 @@ func Insert[S ~[]E, E any](s S, i int, v ...E) S {
return s
}
// Delete removes the elements s[i:j] from s, returning the modified slice.
// Delete panics if s[i:j] is not a valid slice of s.
// Delete is O(len(s)-j), so if many items must be deleted, it is better to
// make a single call deleting them all together than to delete one at a time.
// Delete might not modify the elements s[len(s)-(j-i):len(s)]. If those
// elements contain pointers you might consider zeroing those elements so that
// objects they reference can be garbage collected.
func Delete[S ~[]E, E any](s S, i, j int) S {
_ = s[i:j] // bounds check
// clearSlice sets all elements up to the length of s to the zero value of E.
// We may use the builtin clear func instead, and remove clearSlice, when upgrading
// to Go 1.21+.
func clearSlice[S ~[]E, E any](s S) {
var zero E
for i := range s {
s[i] = zero
}
}
return append(s[:i], s[j:]...)
// Delete removes the elements s[i:j] from s, returning the modified slice.
// Delete panics if j > len(s) or s[i:j] is not a valid slice of s.
// Delete is O(len(s)-i), so if many items must be deleted, it is better to
// make a single call deleting them all together than to delete one at a time.
// Delete zeroes the elements s[len(s)-(j-i):len(s)].
func Delete[S ~[]E, E any](s S, i, j int) S {
_ = s[i:j:len(s)] // bounds check
if i == j {
return s
}
oldlen := len(s)
s = append(s[:i], s[j:]...)
clearSlice(s[len(s):oldlen]) // zero/nil out the obsolete elements, for GC
return s
}
// DeleteFunc removes any elements from s for which del returns true,
// returning the modified slice.
// When DeleteFunc removes m elements, it might not modify the elements
// s[len(s)-m:len(s)]. If those elements contain pointers you might consider
// zeroing those elements so that objects they reference can be garbage
// collected.
// DeleteFunc zeroes the elements between the new length and the original length.
func DeleteFunc[S ~[]E, E any](s S, del func(E) bool) S {
i := IndexFunc(s, del)
if i == -1 {
@@ -240,11 +252,13 @@ func DeleteFunc[S ~[]E, E any](s S, del func(E) bool) S {
i++
}
}
clearSlice(s[i:]) // zero/nil out the obsolete elements, for GC
return s[:i]
}
// Replace replaces the elements s[i:j] by the given v, and returns the
// modified slice. Replace panics if s[i:j] is not a valid slice of s.
// When len(v) < (j-i), Replace zeroes the elements between the new length and the original length.
func Replace[S ~[]E, E any](s S, i, j int, v ...E) S {
_ = s[i:j] // verify that i:j is a valid subslice
@@ -272,6 +286,7 @@ func Replace[S ~[]E, E any](s S, i, j int, v ...E) S {
if i+len(v) != j {
copy(r[i+len(v):], s[j:])
}
clearSlice(s[tot:]) // zero/nil out the obsolete elements, for GC
return r
}
@@ -345,9 +360,7 @@ func Clone[S ~[]E, E any](s S) S {
// This is like the uniq command found on Unix.
// Compact modifies the contents of the slice s and returns the modified slice,
// which may have a smaller length.
// When Compact discards m elements in total, it might not modify the elements
// s[len(s)-m:len(s)]. If those elements contain pointers you might consider
// zeroing those elements so that objects they reference can be garbage collected.
// Compact zeroes the elements between the new length and the original length.
func Compact[S ~[]E, E comparable](s S) S {
if len(s) < 2 {
return s
@@ -361,11 +374,13 @@ func Compact[S ~[]E, E comparable](s S) S {
i++
}
}
clearSlice(s[i:]) // zero/nil out the obsolete elements, for GC
return s[:i]
}
// CompactFunc is like [Compact] but uses an equality function to compare elements.
// For runs of elements that compare equal, CompactFunc keeps the first one.
// CompactFunc zeroes the elements between the new length and the original length.
func CompactFunc[S ~[]E, E any](s S, eq func(E, E) bool) S {
if len(s) < 2 {
return s
@@ -379,6 +394,7 @@ func CompactFunc[S ~[]E, E any](s S, eq func(E, E) bool) S {
i++
}
}
clearSlice(s[i:]) // zero/nil out the obsolete elements, for GC
return s[:i]
}

View File

@@ -910,9 +910,6 @@ func (z *Tokenizer) readTagAttrKey() {
return
}
switch c {
case ' ', '\n', '\r', '\t', '\f', '/':
z.pendingAttr[0].end = z.raw.end - 1
return
case '=':
if z.pendingAttr[0].start+1 == z.raw.end {
// WHATWG 13.2.5.32, if we see an equals sign before the attribute name
@@ -920,7 +917,9 @@ func (z *Tokenizer) readTagAttrKey() {
continue
}
fallthrough
case '>':
case ' ', '\n', '\r', '\t', '\f', '/', '>':
// WHATWG 13.2.5.33 Attribute name state
// We need to reconsume the char in the after attribute name state to support the / character
z.raw.end--
z.pendingAttr[0].end = z.raw.end
return
@@ -939,6 +938,11 @@ func (z *Tokenizer) readTagAttrVal() {
if z.err != nil {
return
}
if c == '/' {
// WHATWG 13.2.5.34 After attribute name state
// U+002F SOLIDUS (/) - Switch to the self-closing start tag state.
return
}
if c != '=' {
z.raw.end--
return

View File

@@ -2,7 +2,7 @@
// Use of this source code is governed by a BSD-style
// license that can be found in the LICENSE file.
//go:build dragonfly || freebsd || linux || netbsd || openbsd
//go:build dragonfly || freebsd || linux || netbsd
package unix

View File

@@ -231,3 +231,8 @@ func IoctlLoopGetStatus64(fd int) (*LoopInfo64, error) {
func IoctlLoopSetStatus64(fd int, value *LoopInfo64) error {
return ioctlPtr(fd, LOOP_SET_STATUS64, unsafe.Pointer(value))
}
// IoctlLoopConfigure configures all loop device parameters in a single step
func IoctlLoopConfigure(fd int, value *LoopConfig) error {
return ioctlPtr(fd, LOOP_CONFIGURE, unsafe.Pointer(value))
}

View File

@@ -248,6 +248,7 @@ struct ltchars {
#include <linux/module.h>
#include <linux/mount.h>
#include <linux/netfilter/nfnetlink.h>
#include <linux/netfilter/nf_tables.h>
#include <linux/netlink.h>
#include <linux/net_namespace.h>
#include <linux/nfc.h>
@@ -283,10 +284,6 @@ struct ltchars {
#include <asm/termbits.h>
#endif
#ifndef MSG_FASTOPEN
#define MSG_FASTOPEN 0x20000000
#endif
#ifndef PTRACE_GETREGS
#define PTRACE_GETREGS 0xc
#endif
@@ -295,14 +292,6 @@ struct ltchars {
#define PTRACE_SETREGS 0xd
#endif
#ifndef SOL_NETLINK
#define SOL_NETLINK 270
#endif
#ifndef SOL_SMC
#define SOL_SMC 286
#endif
#ifdef SOL_BLUETOOTH
// SPARC includes this in /usr/include/sparc64-linux-gnu/bits/socket.h
// but it is already in bluetooth_linux.go
@@ -319,10 +308,23 @@ struct ltchars {
#undef TIPC_WAIT_FOREVER
#define TIPC_WAIT_FOREVER 0xffffffff
// Copied from linux/l2tp.h
// Including linux/l2tp.h here causes conflicts between linux/in.h
// and netinet/in.h included via net/route.h above.
#define IPPROTO_L2TP 115
// Copied from linux/netfilter/nf_nat.h
// Including linux/netfilter/nf_nat.h here causes conflicts between linux/in.h
// and netinet/in.h.
#define NF_NAT_RANGE_MAP_IPS (1 << 0)
#define NF_NAT_RANGE_PROTO_SPECIFIED (1 << 1)
#define NF_NAT_RANGE_PROTO_RANDOM (1 << 2)
#define NF_NAT_RANGE_PERSISTENT (1 << 3)
#define NF_NAT_RANGE_PROTO_RANDOM_FULLY (1 << 4)
#define NF_NAT_RANGE_PROTO_OFFSET (1 << 5)
#define NF_NAT_RANGE_NETMAP (1 << 6)
#define NF_NAT_RANGE_PROTO_RANDOM_ALL \
(NF_NAT_RANGE_PROTO_RANDOM | NF_NAT_RANGE_PROTO_RANDOM_FULLY)
#define NF_NAT_RANGE_MASK \
(NF_NAT_RANGE_MAP_IPS | NF_NAT_RANGE_PROTO_SPECIFIED | \
NF_NAT_RANGE_PROTO_RANDOM | NF_NAT_RANGE_PERSISTENT | \
NF_NAT_RANGE_PROTO_RANDOM_FULLY | NF_NAT_RANGE_PROTO_OFFSET | \
NF_NAT_RANGE_NETMAP)
// Copied from linux/hid.h.
// Keep in sync with the size of the referenced fields.
@@ -519,6 +521,7 @@ ccflags="$@"
$2 ~ /^LOCK_(SH|EX|NB|UN)$/ ||
$2 ~ /^LO_(KEY|NAME)_SIZE$/ ||
$2 ~ /^LOOP_(CLR|CTL|GET|SET)_/ ||
$2 == "LOOP_CONFIGURE" ||
$2 ~ /^(AF|SOCK|SO|SOL|IPPROTO|IP|IPV6|TCP|MCAST|EVFILT|NOTE|SHUT|PROT|MAP|MREMAP|MFD|T?PACKET|MSG|SCM|MCL|DT|MADV|PR|LOCAL|TCPOPT|UDP)_/ ||
$2 ~ /^NFC_(GENL|PROTO|COMM|RF|SE|DIRECTION|LLCP|SOCKPROTO)_/ ||
$2 ~ /^NFC_.*_(MAX)?SIZE$/ ||
@@ -560,7 +563,7 @@ ccflags="$@"
$2 ~ /^RLIMIT_(AS|CORE|CPU|DATA|FSIZE|LOCKS|MEMLOCK|MSGQUEUE|NICE|NOFILE|NPROC|RSS|RTPRIO|RTTIME|SIGPENDING|STACK)|RLIM_INFINITY/ ||
$2 ~ /^PRIO_(PROCESS|PGRP|USER)/ ||
$2 ~ /^CLONE_[A-Z_]+/ ||
$2 !~ /^(BPF_TIMEVAL|BPF_FIB_LOOKUP_[A-Z]+)$/ &&
$2 !~ /^(BPF_TIMEVAL|BPF_FIB_LOOKUP_[A-Z]+|BPF_F_LINK)$/ &&
$2 ~ /^(BPF|DLT)_/ ||
$2 ~ /^AUDIT_/ ||
$2 ~ /^(CLOCK|TIMER)_/ ||
@@ -581,7 +584,7 @@ ccflags="$@"
$2 ~ /^KEY_(SPEC|REQKEY_DEFL)_/ ||
$2 ~ /^KEYCTL_/ ||
$2 ~ /^PERF_/ ||
$2 ~ /^SECCOMP_MODE_/ ||
$2 ~ /^SECCOMP_/ ||
$2 ~ /^SEEK_/ ||
$2 ~ /^SCHED_/ ||
$2 ~ /^SPLICE_/ ||
@@ -602,6 +605,9 @@ ccflags="$@"
$2 ~ /^FSOPT_/ ||
$2 ~ /^WDIO[CFS]_/ ||
$2 ~ /^NFN/ ||
$2 !~ /^NFT_META_IIFTYPE/ &&
$2 ~ /^NFT_/ ||
$2 ~ /^NF_NAT_/ ||
$2 ~ /^XDP_/ ||
$2 ~ /^RWF_/ ||
$2 ~ /^(HDIO|WIN|SMART)_/ ||

View File

@@ -316,7 +316,7 @@ func GetsockoptString(fd, level, opt int) (string, error) {
if err != nil {
return "", err
}
return string(buf[:vallen-1]), nil
return ByteSliceToString(buf[:vallen]), nil
}
//sys recvfrom(fd int, p []byte, flags int, from *RawSockaddrAny, fromlen *_Socklen) (n int, err error)

View File

@@ -61,15 +61,23 @@ func FanotifyMark(fd int, flags uint, mask uint64, dirFd int, pathname string) (
}
//sys fchmodat(dirfd int, path string, mode uint32) (err error)
//sys fchmodat2(dirfd int, path string, mode uint32, flags int) (err error)
func Fchmodat(dirfd int, path string, mode uint32, flags int) (err error) {
// Linux fchmodat doesn't support the flags parameter. Mimick glibc's behavior
// and check the flags. Otherwise the mode would be applied to the symlink
// destination which is not what the user expects.
if flags&^AT_SYMLINK_NOFOLLOW != 0 {
return EINVAL
} else if flags&AT_SYMLINK_NOFOLLOW != 0 {
return EOPNOTSUPP
func Fchmodat(dirfd int, path string, mode uint32, flags int) error {
// Linux fchmodat doesn't support the flags parameter, but fchmodat2 does.
// Try fchmodat2 if flags are specified.
if flags != 0 {
err := fchmodat2(dirfd, path, mode, flags)
if err == ENOSYS {
// fchmodat2 isn't available. If the flags are known to be valid,
// return EOPNOTSUPP to indicate that fchmodat doesn't support them.
if flags&^(AT_SYMLINK_NOFOLLOW|AT_EMPTY_PATH) != 0 {
return EINVAL
} else if flags&(AT_SYMLINK_NOFOLLOW|AT_EMPTY_PATH) != 0 {
return EOPNOTSUPP
}
}
return err
}
return fchmodat(dirfd, path, mode)
}
@@ -1302,7 +1310,7 @@ func GetsockoptString(fd, level, opt int) (string, error) {
return "", err
}
}
return string(buf[:vallen-1]), nil
return ByteSliceToString(buf[:vallen]), nil
}
func GetsockoptTpacketStats(fd, level, opt int) (*TpacketStats, error) {

View File

@@ -166,6 +166,20 @@ func Getresgid() (rgid, egid, sgid int) {
//sys sysctl(mib []_C_int, old *byte, oldlen *uintptr, new *byte, newlen uintptr) (err error) = SYS___SYSCTL
//sys fcntl(fd int, cmd int, arg int) (n int, err error)
//sys fcntlPtr(fd int, cmd int, arg unsafe.Pointer) (n int, err error) = SYS_FCNTL
// FcntlInt performs a fcntl syscall on fd with the provided command and argument.
func FcntlInt(fd uintptr, cmd, arg int) (int, error) {
return fcntl(int(fd), cmd, arg)
}
// FcntlFlock performs a fcntl syscall for the F_GETLK, F_SETLK or F_SETLKW command.
func FcntlFlock(fd uintptr, cmd int, lk *Flock_t) error {
_, err := fcntlPtr(int(fd), cmd, unsafe.Pointer(lk))
return err
}
//sys ppoll(fds *PollFd, nfds int, timeout *Timespec, sigmask *Sigset_t) (n int, err error)
func Ppoll(fds []PollFd, timeout *Timespec, sigmask *Sigset_t) (n int, err error) {

View File

@@ -158,7 +158,7 @@ func GetsockoptString(fd, level, opt int) (string, error) {
if err != nil {
return "", err
}
return string(buf[:vallen-1]), nil
return ByteSliceToString(buf[:vallen]), nil
}
const ImplementsGetwd = true

View File

@@ -1104,7 +1104,7 @@ func GetsockoptString(fd, level, opt int) (string, error) {
return "", err
}
return string(buf[:vallen-1]), nil
return ByteSliceToString(buf[:vallen]), nil
}
func Recvmsg(fd int, p, oob []byte, flags int) (n, oobn int, recvflags int, from Sockaddr, err error) {

View File

@@ -486,7 +486,6 @@ const (
BPF_F_ANY_ALIGNMENT = 0x2
BPF_F_BEFORE = 0x8
BPF_F_ID = 0x20
BPF_F_LINK = 0x2000
BPF_F_NETFILTER_IP_DEFRAG = 0x1
BPF_F_QUERY_EFFECTIVE = 0x1
BPF_F_REPLACE = 0x4
@@ -1786,6 +1785,8 @@ const (
LANDLOCK_ACCESS_FS_REMOVE_FILE = 0x20
LANDLOCK_ACCESS_FS_TRUNCATE = 0x4000
LANDLOCK_ACCESS_FS_WRITE_FILE = 0x2
LANDLOCK_ACCESS_NET_BIND_TCP = 0x1
LANDLOCK_ACCESS_NET_CONNECT_TCP = 0x2
LANDLOCK_CREATE_RULESET_VERSION = 0x1
LINUX_REBOOT_CMD_CAD_OFF = 0x0
LINUX_REBOOT_CMD_CAD_ON = 0x89abcdef
@@ -1802,6 +1803,7 @@ const (
LOCK_SH = 0x1
LOCK_UN = 0x8
LOOP_CLR_FD = 0x4c01
LOOP_CONFIGURE = 0x4c0a
LOOP_CTL_ADD = 0x4c80
LOOP_CTL_GET_FREE = 0x4c82
LOOP_CTL_REMOVE = 0x4c81
@@ -2127,6 +2129,60 @@ const (
NFNL_SUBSYS_QUEUE = 0x3
NFNL_SUBSYS_ULOG = 0x4
NFS_SUPER_MAGIC = 0x6969
NFT_CHAIN_FLAGS = 0x7
NFT_CHAIN_MAXNAMELEN = 0x100
NFT_CT_MAX = 0x17
NFT_DATA_RESERVED_MASK = 0xffffff00
NFT_DATA_VALUE_MAXLEN = 0x40
NFT_EXTHDR_OP_MAX = 0x4
NFT_FIB_RESULT_MAX = 0x3
NFT_INNER_MASK = 0xf
NFT_LOGLEVEL_MAX = 0x8
NFT_NAME_MAXLEN = 0x100
NFT_NG_MAX = 0x1
NFT_OBJECT_CONNLIMIT = 0x5
NFT_OBJECT_COUNTER = 0x1
NFT_OBJECT_CT_EXPECT = 0x9
NFT_OBJECT_CT_HELPER = 0x3
NFT_OBJECT_CT_TIMEOUT = 0x7
NFT_OBJECT_LIMIT = 0x4
NFT_OBJECT_MAX = 0xa
NFT_OBJECT_QUOTA = 0x2
NFT_OBJECT_SECMARK = 0x8
NFT_OBJECT_SYNPROXY = 0xa
NFT_OBJECT_TUNNEL = 0x6
NFT_OBJECT_UNSPEC = 0x0
NFT_OBJ_MAXNAMELEN = 0x100
NFT_OSF_MAXGENRELEN = 0x10
NFT_QUEUE_FLAG_BYPASS = 0x1
NFT_QUEUE_FLAG_CPU_FANOUT = 0x2
NFT_QUEUE_FLAG_MASK = 0x3
NFT_REG32_COUNT = 0x10
NFT_REG32_SIZE = 0x4
NFT_REG_MAX = 0x4
NFT_REG_SIZE = 0x10
NFT_REJECT_ICMPX_MAX = 0x3
NFT_RT_MAX = 0x4
NFT_SECMARK_CTX_MAXLEN = 0x100
NFT_SET_MAXNAMELEN = 0x100
NFT_SOCKET_MAX = 0x3
NFT_TABLE_F_MASK = 0x3
NFT_TABLE_MAXNAMELEN = 0x100
NFT_TRACETYPE_MAX = 0x3
NFT_TUNNEL_F_MASK = 0x7
NFT_TUNNEL_MAX = 0x1
NFT_TUNNEL_MODE_MAX = 0x2
NFT_USERDATA_MAXLEN = 0x100
NFT_XFRM_KEY_MAX = 0x6
NF_NAT_RANGE_MAP_IPS = 0x1
NF_NAT_RANGE_MASK = 0x7f
NF_NAT_RANGE_NETMAP = 0x40
NF_NAT_RANGE_PERSISTENT = 0x8
NF_NAT_RANGE_PROTO_OFFSET = 0x20
NF_NAT_RANGE_PROTO_RANDOM = 0x4
NF_NAT_RANGE_PROTO_RANDOM_ALL = 0x14
NF_NAT_RANGE_PROTO_RANDOM_FULLY = 0x10
NF_NAT_RANGE_PROTO_SPECIFIED = 0x2
NILFS_SUPER_MAGIC = 0x3434
NL0 = 0x0
NL1 = 0x100
@@ -2411,6 +2467,7 @@ const (
PR_MCE_KILL_GET = 0x22
PR_MCE_KILL_LATE = 0x0
PR_MCE_KILL_SET = 0x1
PR_MDWE_NO_INHERIT = 0x2
PR_MDWE_REFUSE_EXEC_GAIN = 0x1
PR_MPX_DISABLE_MANAGEMENT = 0x2c
PR_MPX_ENABLE_MANAGEMENT = 0x2b
@@ -2615,8 +2672,9 @@ const (
RTAX_FEATURES = 0xc
RTAX_FEATURE_ALLFRAG = 0x8
RTAX_FEATURE_ECN = 0x1
RTAX_FEATURE_MASK = 0xf
RTAX_FEATURE_MASK = 0x1f
RTAX_FEATURE_SACK = 0x2
RTAX_FEATURE_TCP_USEC_TS = 0x10
RTAX_FEATURE_TIMESTAMP = 0x4
RTAX_HOPLIMIT = 0xa
RTAX_INITCWND = 0xb
@@ -2859,9 +2917,38 @@ const (
SCM_RIGHTS = 0x1
SCM_TIMESTAMP = 0x1d
SC_LOG_FLUSH = 0x100000
SECCOMP_ADDFD_FLAG_SEND = 0x2
SECCOMP_ADDFD_FLAG_SETFD = 0x1
SECCOMP_FILTER_FLAG_LOG = 0x2
SECCOMP_FILTER_FLAG_NEW_LISTENER = 0x8
SECCOMP_FILTER_FLAG_SPEC_ALLOW = 0x4
SECCOMP_FILTER_FLAG_TSYNC = 0x1
SECCOMP_FILTER_FLAG_TSYNC_ESRCH = 0x10
SECCOMP_FILTER_FLAG_WAIT_KILLABLE_RECV = 0x20
SECCOMP_GET_ACTION_AVAIL = 0x2
SECCOMP_GET_NOTIF_SIZES = 0x3
SECCOMP_IOCTL_NOTIF_RECV = 0xc0502100
SECCOMP_IOCTL_NOTIF_SEND = 0xc0182101
SECCOMP_IOC_MAGIC = '!'
SECCOMP_MODE_DISABLED = 0x0
SECCOMP_MODE_FILTER = 0x2
SECCOMP_MODE_STRICT = 0x1
SECCOMP_RET_ACTION = 0x7fff0000
SECCOMP_RET_ACTION_FULL = 0xffff0000
SECCOMP_RET_ALLOW = 0x7fff0000
SECCOMP_RET_DATA = 0xffff
SECCOMP_RET_ERRNO = 0x50000
SECCOMP_RET_KILL = 0x0
SECCOMP_RET_KILL_PROCESS = 0x80000000
SECCOMP_RET_KILL_THREAD = 0x0
SECCOMP_RET_LOG = 0x7ffc0000
SECCOMP_RET_TRACE = 0x7ff00000
SECCOMP_RET_TRAP = 0x30000
SECCOMP_RET_USER_NOTIF = 0x7fc00000
SECCOMP_SET_MODE_FILTER = 0x1
SECCOMP_SET_MODE_STRICT = 0x0
SECCOMP_USER_NOTIF_FD_SYNC_WAKE_UP = 0x1
SECCOMP_USER_NOTIF_FLAG_CONTINUE = 0x1
SECRETMEM_MAGIC = 0x5345434d
SECURITYFS_MAGIC = 0x73636673
SEEK_CUR = 0x1
@@ -3021,6 +3108,7 @@ const (
SOL_TIPC = 0x10f
SOL_TLS = 0x11a
SOL_UDP = 0x11
SOL_VSOCK = 0x11f
SOL_X25 = 0x106
SOL_XDP = 0x11b
SOMAXCONN = 0x1000

View File

@@ -281,6 +281,9 @@ const (
SCM_TIMESTAMPNS = 0x23
SCM_TXTIME = 0x3d
SCM_WIFI_STATUS = 0x29
SECCOMP_IOCTL_NOTIF_ADDFD = 0x40182103
SECCOMP_IOCTL_NOTIF_ID_VALID = 0x40082102
SECCOMP_IOCTL_NOTIF_SET_FLAGS = 0x40082104
SFD_CLOEXEC = 0x80000
SFD_NONBLOCK = 0x800
SIOCATMARK = 0x8905

View File

@@ -282,6 +282,9 @@ const (
SCM_TIMESTAMPNS = 0x23
SCM_TXTIME = 0x3d
SCM_WIFI_STATUS = 0x29
SECCOMP_IOCTL_NOTIF_ADDFD = 0x40182103
SECCOMP_IOCTL_NOTIF_ID_VALID = 0x40082102
SECCOMP_IOCTL_NOTIF_SET_FLAGS = 0x40082104
SFD_CLOEXEC = 0x80000
SFD_NONBLOCK = 0x800
SIOCATMARK = 0x8905

View File

@@ -288,6 +288,9 @@ const (
SCM_TIMESTAMPNS = 0x23
SCM_TXTIME = 0x3d
SCM_WIFI_STATUS = 0x29
SECCOMP_IOCTL_NOTIF_ADDFD = 0x40182103
SECCOMP_IOCTL_NOTIF_ID_VALID = 0x40082102
SECCOMP_IOCTL_NOTIF_SET_FLAGS = 0x40082104
SFD_CLOEXEC = 0x80000
SFD_NONBLOCK = 0x800
SIOCATMARK = 0x8905

View File

@@ -278,6 +278,9 @@ const (
SCM_TIMESTAMPNS = 0x23
SCM_TXTIME = 0x3d
SCM_WIFI_STATUS = 0x29
SECCOMP_IOCTL_NOTIF_ADDFD = 0x40182103
SECCOMP_IOCTL_NOTIF_ID_VALID = 0x40082102
SECCOMP_IOCTL_NOTIF_SET_FLAGS = 0x40082104
SFD_CLOEXEC = 0x80000
SFD_NONBLOCK = 0x800
SIOCATMARK = 0x8905

View File

@@ -275,6 +275,9 @@ const (
SCM_TIMESTAMPNS = 0x23
SCM_TXTIME = 0x3d
SCM_WIFI_STATUS = 0x29
SECCOMP_IOCTL_NOTIF_ADDFD = 0x40182103
SECCOMP_IOCTL_NOTIF_ID_VALID = 0x40082102
SECCOMP_IOCTL_NOTIF_SET_FLAGS = 0x40082104
SFD_CLOEXEC = 0x80000
SFD_NONBLOCK = 0x800
SIOCATMARK = 0x8905

View File

@@ -281,6 +281,9 @@ const (
SCM_TIMESTAMPNS = 0x23
SCM_TXTIME = 0x3d
SCM_WIFI_STATUS = 0x29
SECCOMP_IOCTL_NOTIF_ADDFD = 0x80182103
SECCOMP_IOCTL_NOTIF_ID_VALID = 0x80082102
SECCOMP_IOCTL_NOTIF_SET_FLAGS = 0x80082104
SFD_CLOEXEC = 0x80000
SFD_NONBLOCK = 0x80
SIOCATMARK = 0x40047307

View File

@@ -281,6 +281,9 @@ const (
SCM_TIMESTAMPNS = 0x23
SCM_TXTIME = 0x3d
SCM_WIFI_STATUS = 0x29
SECCOMP_IOCTL_NOTIF_ADDFD = 0x80182103
SECCOMP_IOCTL_NOTIF_ID_VALID = 0x80082102
SECCOMP_IOCTL_NOTIF_SET_FLAGS = 0x80082104
SFD_CLOEXEC = 0x80000
SFD_NONBLOCK = 0x80
SIOCATMARK = 0x40047307

View File

@@ -281,6 +281,9 @@ const (
SCM_TIMESTAMPNS = 0x23
SCM_TXTIME = 0x3d
SCM_WIFI_STATUS = 0x29
SECCOMP_IOCTL_NOTIF_ADDFD = 0x80182103
SECCOMP_IOCTL_NOTIF_ID_VALID = 0x80082102
SECCOMP_IOCTL_NOTIF_SET_FLAGS = 0x80082104
SFD_CLOEXEC = 0x80000
SFD_NONBLOCK = 0x80
SIOCATMARK = 0x40047307

View File

@@ -281,6 +281,9 @@ const (
SCM_TIMESTAMPNS = 0x23
SCM_TXTIME = 0x3d
SCM_WIFI_STATUS = 0x29
SECCOMP_IOCTL_NOTIF_ADDFD = 0x80182103
SECCOMP_IOCTL_NOTIF_ID_VALID = 0x80082102
SECCOMP_IOCTL_NOTIF_SET_FLAGS = 0x80082104
SFD_CLOEXEC = 0x80000
SFD_NONBLOCK = 0x80
SIOCATMARK = 0x40047307

View File

@@ -336,6 +336,9 @@ const (
SCM_TIMESTAMPNS = 0x23
SCM_TXTIME = 0x3d
SCM_WIFI_STATUS = 0x29
SECCOMP_IOCTL_NOTIF_ADDFD = 0x80182103
SECCOMP_IOCTL_NOTIF_ID_VALID = 0x80082102
SECCOMP_IOCTL_NOTIF_SET_FLAGS = 0x80082104
SFD_CLOEXEC = 0x80000
SFD_NONBLOCK = 0x800
SIOCATMARK = 0x8905

View File

@@ -340,6 +340,9 @@ const (
SCM_TIMESTAMPNS = 0x23
SCM_TXTIME = 0x3d
SCM_WIFI_STATUS = 0x29
SECCOMP_IOCTL_NOTIF_ADDFD = 0x80182103
SECCOMP_IOCTL_NOTIF_ID_VALID = 0x80082102
SECCOMP_IOCTL_NOTIF_SET_FLAGS = 0x80082104
SFD_CLOEXEC = 0x80000
SFD_NONBLOCK = 0x800
SIOCATMARK = 0x8905

Some files were not shown because too many files have changed in this diff Show More