diff --git a/e2e/send-stress b/e2e/send-stress index e4b5da9..902a9df 100755 --- a/e2e/send-stress +++ b/e2e/send-stress @@ -1,6 +1,6 @@ #!/bin/bash -for i in {0..10..1}; do +for i in {0..100..1}; do echo "#${i}..." ssmtp test@localhost < $1 done diff --git a/smtp/listener.go b/smtp/listener.go new file mode 100644 index 0000000..394f9a7 --- /dev/null +++ b/smtp/listener.go @@ -0,0 +1,50 @@ +package smtp + +import ( + "net" + + "gitlab.com/etke.cc/go/logger" +) + +// Listener that rejects connections from banned hosts +type Listener struct { + log *logger.Logger + listener net.Listener + isBanned func(net.Addr) bool +} + +func NewListener(actual net.Listener, isBanned func(net.Addr) bool, log *logger.Logger) *Listener { + return &Listener{ + log: log, + listener: actual, + isBanned: isBanned, + } +} + +// Accept waits for and returns the next connection to the listener. +func (l *Listener) Accept() (net.Conn, error) { + conn, err := l.listener.Accept() + if err != nil { + return conn, err + } + if l.isBanned(conn.RemoteAddr()) { + conn.Close() + l.log.Info("rejected connection from %q (already banned)", conn.RemoteAddr()) + // Due to go-smtp design, any error returned here will crash whole server, + // thus we have to forge a connection + return &net.TCPConn{}, nil + } + + return conn, nil +} + +// Close closes the listener. +// Any blocked Accept operations will be unblocked and return errors. +func (l *Listener) Close() error { + return l.listener.Close() +} + +// Addr returns the listener's network address. +func (l *Listener) Addr() net.Addr { + return l.listener.Addr() +} diff --git a/smtp/manager.go b/smtp/manager.go index b1b69f3..7b3ca36 100644 --- a/smtp/manager.go +++ b/smtp/manager.go @@ -30,6 +30,7 @@ type Config struct { type Manager struct { log *logger.Logger + bot matrixbot smtp *smtp.Server errs chan error @@ -77,6 +78,7 @@ func NewManager(cfg *Config) *Manager { m := &Manager{ smtp: s, + bot: cfg.Bot, log: log, port: cfg.Port, tlsPort: cfg.TLSPort, @@ -118,10 +120,10 @@ func (m *Manager) listen(port string, tlsCfg *tls.Config) { m.errs <- err return } - + lwrapper := NewListener(l, m.bot.IsBanned, m.log) m.log.Info("Starting SMTP server on port %s", port) - err = m.smtp.Serve(l) + err = m.smtp.Serve(lwrapper) if err != nil { m.log.Error("cannot start SMTP server on %s: %v", port, err) m.errs <- err