From ced98e818e5966c3484c67ee4c844a8c85f0f7b5 Mon Sep 17 00:00:00 2001 From: Aine Date: Sat, 19 Nov 2022 18:20:57 +0200 Subject: [PATCH] correctly handle TCP connections without forging them for banned hosts --- smtp/listener.go | 34 ++++++++++++++++++++++------------ 1 file changed, 22 insertions(+), 12 deletions(-) diff --git a/smtp/listener.go b/smtp/listener.go index 394f9a7..7d3971e 100644 --- a/smtp/listener.go +++ b/smtp/listener.go @@ -9,6 +9,7 @@ import ( // Listener that rejects connections from banned hosts type Listener struct { log *logger.Logger + done chan struct{} listener net.Listener isBanned func(net.Addr) bool } @@ -16,6 +17,7 @@ type Listener struct { func NewListener(actual net.Listener, isBanned func(net.Addr) bool, log *logger.Logger) *Listener { return &Listener{ log: log, + done: make(chan struct{}, 1), listener: actual, isBanned: isBanned, } @@ -23,24 +25,32 @@ func NewListener(actual net.Listener, isBanned func(net.Addr) bool, log *logger. // Accept waits for and returns the next connection to the listener. func (l *Listener) Accept() (net.Conn, error) { - conn, err := l.listener.Accept() - if err != nil { - return conn, err - } - if l.isBanned(conn.RemoteAddr()) { - conn.Close() - l.log.Info("rejected connection from %q (already banned)", conn.RemoteAddr()) - // Due to go-smtp design, any error returned here will crash whole server, - // thus we have to forge a connection - return &net.TCPConn{}, nil - } + for { + conn, err := l.listener.Accept() + if err != nil { + select { + case <-l.done: + return conn, err + default: + l.log.Warn("cannot accept connection: %v", err) + continue + } + } + if l.isBanned(conn.RemoteAddr()) { + conn.Close() + l.log.Info("rejected connection from %q (already banned)", conn.RemoteAddr()) + continue + } - return conn, nil + l.log.Debug("accepted connection from %q", conn.RemoteAddr()) + return conn, nil + } } // Close closes the listener. // Any blocked Accept operations will be unblocked and return errors. func (l *Listener) Close() error { + close(l.done) return l.listener.Close() }