automatically ignore known forwarded addresses, fixes #64

This commit is contained in:
Aine
2023-09-18 12:35:37 +03:00
parent e90925eceb
commit 60b4386dd8
187 changed files with 4070 additions and 2667 deletions

View File

@@ -14,11 +14,11 @@ import (
"net/url"
"os"
"strconv"
"strings"
"sync/atomic"
"time"
"github.com/rs/zerolog"
"go.mau.fi/util/retryafter"
"maunium.net/go/maulogger/v2/maulogadapt"
"maunium.net/go/mautrix/event"
@@ -258,45 +258,52 @@ const (
LogRequestIDContextKey
)
func (cli *Client) LogRequest(req *http.Request) {
func (cli *Client) RequestStart(req *http.Request) {
if cli.RequestHook != nil {
cli.RequestHook(req)
}
evt := zerolog.Ctx(req.Context()).Debug().
Str("method", req.Method).
Str("url", req.URL.String())
body := req.Context().Value(LogBodyContextKey)
if body != nil {
evt.Interface("body", body)
}
evt.Msg("Sending request")
}
func (cli *Client) LogRequestDone(req *http.Request, resp *http.Response, handlerErr error, contentLength int, duration time.Duration) {
if cli.ResponseHook != nil {
cli.ResponseHook(req, resp, duration)
func (cli *Client) LogRequestDone(req *http.Request, resp *http.Response, err error, handlerErr error, contentLength int, duration time.Duration) {
var evt *zerolog.Event
if err != nil {
evt = zerolog.Ctx(req.Context()).Err(err)
} else if handlerErr != nil {
evt = zerolog.Ctx(req.Context()).Warn().
AnErr("body_parse_err", handlerErr)
} else {
evt = zerolog.Ctx(req.Context()).Debug()
}
mime := resp.Header.Get("Content-Type")
length := resp.ContentLength
if length == -1 && contentLength > 0 {
length = int64(contentLength)
}
path := strings.TrimPrefix(req.URL.Path, cli.HomeserverURL.Path)
path = strings.TrimPrefix(path, "/_matrix/client")
evt := zerolog.Ctx(req.Context()).Debug().
evt = evt.
Str("method", req.Method).
Str("path", path).
Int("status_code", resp.StatusCode).
Int64("response_length", length).
Str("response_mime", mime).
Str("url", req.URL.String()).
Dur("duration", duration)
if handlerErr != nil {
evt.AnErr("body_parse_err", handlerErr)
if resp != nil {
if cli.ResponseHook != nil {
cli.ResponseHook(req, resp, duration)
}
mime := resp.Header.Get("Content-Type")
length := resp.ContentLength
if length == -1 && contentLength > 0 {
length = int64(contentLength)
}
evt = evt.Int("status_code", resp.StatusCode).
Int64("response_length", length).
Str("response_mime", mime)
if serverRequestID := resp.Header.Get("X-Beeper-Request-ID"); serverRequestID != "" {
evt.Str("beeper_request_id", serverRequestID)
}
}
if serverRequestID := resp.Header.Get("X-Beeper-Request-ID"); serverRequestID != "" {
evt.Str("beeper_request_id", serverRequestID)
if body := req.Context().Value(LogBodyContextKey); body != nil {
evt.Interface("req_body", body)
}
if err != nil {
evt.Msg("Request failed")
} else if handlerErr != nil {
evt.Msg("Request parsing failed")
} else {
evt.Msg("Request completed")
}
evt.Msg("Request completed")
}
func (cli *Client) MakeRequest(method string, httpURL string, reqBody interface{}, resBody interface{}) ([]byte, error) {
@@ -520,38 +527,8 @@ func ParseErrorResponse(req *http.Request, res *http.Response) ([]byte, error) {
}
}
// parseBackoffFromResponse extracts the backoff time specified in the Retry-After header if present. See
// https://developer.mozilla.org/en-US/docs/Web/HTTP/Headers/Retry-After.
func parseBackoffFromResponse(req *http.Request, res *http.Response, now time.Time, fallback time.Duration) time.Duration {
retryAfterHeaderValue := res.Header.Get("Retry-After")
if retryAfterHeaderValue == "" {
return fallback
}
if t, err := time.Parse(http.TimeFormat, retryAfterHeaderValue); err == nil {
return t.Sub(now)
}
if seconds, err := strconv.Atoi(retryAfterHeaderValue); err == nil {
return time.Duration(seconds) * time.Second
}
zerolog.Ctx(req.Context()).Warn().
Str("retry_after", retryAfterHeaderValue).
Msg("Failed to parse Retry-After header value")
return fallback
}
func (cli *Client) shouldRetry(res *http.Response) bool {
return res.StatusCode == http.StatusBadGateway ||
res.StatusCode == http.StatusServiceUnavailable ||
res.StatusCode == http.StatusGatewayTimeout ||
(res.StatusCode == http.StatusTooManyRequests && !cli.IgnoreRateLimit)
}
func (cli *Client) executeCompiledRequest(req *http.Request, retries int, backoff time.Duration, responseJSON interface{}, handler ClientResponseHandler) ([]byte, error) {
cli.LogRequest(req)
cli.RequestStart(req)
startTime := time.Now()
res, err := cli.Client.Do(req)
duration := time.Now().Sub(startTime)
@@ -562,29 +539,29 @@ func (cli *Client) executeCompiledRequest(req *http.Request, retries int, backof
if retries > 0 {
return cli.doRetry(req, err, retries, backoff, responseJSON, handler)
}
return nil, HTTPError{
err = HTTPError{
Request: req,
Response: res,
Message: "request error",
WrappedError: err,
}
cli.LogRequestDone(req, res, err, nil, 0, duration)
return nil, err
}
if retries > 0 && cli.shouldRetry(res) {
if res.StatusCode == http.StatusTooManyRequests {
backoff = parseBackoffFromResponse(req, res, time.Now(), backoff)
}
if retries > 0 && retryafter.Should(res.StatusCode, !cli.IgnoreRateLimit) {
backoff = retryafter.Parse(res.Header.Get("Retry-After"), backoff)
return cli.doRetry(req, fmt.Errorf("HTTP %d", res.StatusCode), retries, backoff, responseJSON, handler)
}
var body []byte
if res.StatusCode < 200 || res.StatusCode >= 300 {
body, err = ParseErrorResponse(req, res)
cli.LogRequestDone(req, res, nil, len(body), duration)
cli.LogRequestDone(req, res, nil, nil, len(body), duration)
} else {
body, err = handler(req, res, responseJSON)
cli.LogRequestDone(req, res, err, len(body), duration)
cli.LogRequestDone(req, res, nil, err, len(body), duration)
}
return body, err
}
@@ -1371,26 +1348,80 @@ func (cli *Client) Download(mxcURL id.ContentURI) (io.ReadCloser, error) {
}
func (cli *Client) DownloadContext(ctx context.Context, mxcURL id.ContentURI) (io.ReadCloser, error) {
_, resp, err := cli.downloadContext(ctx, mxcURL)
return resp.Body, err
resp, err := cli.downloadContext(ctx, mxcURL)
if err != nil {
return nil, err
}
return resp.Body, nil
}
func (cli *Client) downloadContext(ctx context.Context, mxcURL id.ContentURI) (*http.Request, *http.Response, error) {
func (cli *Client) doMediaRetry(req *http.Request, cause error, retries int, backoff time.Duration) (*http.Response, error) {
log := zerolog.Ctx(req.Context())
if req.Body != nil {
if req.GetBody == nil {
log.Warn().Msg("Failed to get new body to retry request: GetBody is nil")
return nil, cause
}
var err error
req.Body, err = req.GetBody()
if err != nil {
log.Warn().Err(err).Msg("Failed to get new body to retry request")
return nil, cause
}
}
log.Warn().Err(cause).
Int("retry_in_seconds", int(backoff.Seconds())).
Msg("Request failed, retrying")
time.Sleep(backoff)
return cli.doMediaRequest(req, retries-1, backoff*2)
}
func (cli *Client) doMediaRequest(req *http.Request, retries int, backoff time.Duration) (*http.Response, error) {
cli.RequestStart(req)
startTime := time.Now()
res, err := cli.Client.Do(req)
duration := time.Now().Sub(startTime)
if err != nil {
if retries > 0 {
return cli.doMediaRetry(req, err, retries, backoff)
}
err = HTTPError{
Request: req,
Response: res,
Message: "request error",
WrappedError: err,
}
cli.LogRequestDone(req, res, err, nil, 0, duration)
return nil, err
}
if retries > 0 && retryafter.Should(res.StatusCode, !cli.IgnoreRateLimit) {
backoff = retryafter.Parse(res.Header.Get("Retry-After"), backoff)
return cli.doMediaRetry(req, fmt.Errorf("HTTP %d", res.StatusCode), retries, backoff)
}
if res.StatusCode < 200 || res.StatusCode >= 300 {
var body []byte
body, err = ParseErrorResponse(req, res)
cli.LogRequestDone(req, res, err, nil, len(body), duration)
} else {
cli.LogRequestDone(req, res, nil, nil, -1, duration)
}
return res, err
}
func (cli *Client) downloadContext(ctx context.Context, mxcURL id.ContentURI) (*http.Response, error) {
ctxLog := zerolog.Ctx(ctx)
if ctxLog.GetLevel() == zerolog.Disabled || ctxLog == zerolog.DefaultContextLogger {
ctx = cli.Log.WithContext(ctx)
}
req, err := http.NewRequestWithContext(ctx, http.MethodGet, cli.GetDownloadURL(mxcURL), nil)
if err != nil {
return req, nil, err
return nil, err
}
req.Header.Set("User-Agent", cli.UserAgent+" (media downloader)")
cli.LogRequest(req)
if resp, err := cli.Client.Do(req); err != nil {
return req, nil, err
} else {
return req, resp, nil
}
return cli.doMediaRequest(req, cli.DefaultHTTPRetries, 4*time.Second)
}
func (cli *Client) DownloadBytes(mxcURL id.ContentURI) ([]byte, error) {
@@ -1398,18 +1429,11 @@ func (cli *Client) DownloadBytes(mxcURL id.ContentURI) ([]byte, error) {
}
func (cli *Client) DownloadBytesContext(ctx context.Context, mxcURL id.ContentURI) ([]byte, error) {
req, resp, err := cli.downloadContext(ctx, mxcURL)
resp, err := cli.downloadContext(ctx, mxcURL)
if err != nil {
return nil, err
}
defer resp.Body.Close()
if resp.StatusCode >= 300 || resp.StatusCode < 200 {
respErr := &RespError{}
if _ = json.NewDecoder(resp.Body).Decode(respErr); respErr.ErrCode == "" {
respErr = nil
}
return nil, HTTPError{Request: req, Response: resp, RespError: respErr}
}
return io.ReadAll(resp.Body)
}
@@ -1980,7 +2004,7 @@ func (cli *Client) PutPushRule(scope string, kind pushrules.PushRuleType, ruleID
// BatchSend sends a batch of historical events into a room. This is only available for appservices.
//
// See https://github.com/matrix-org/matrix-doc/pull/2716 for more info.
// Deprecated: MSC2716 has been abandoned, so this is now Beeper-specific. BeeperBatchSend should be used instead.
func (cli *Client) BatchSend(roomID id.RoomID, req *ReqBatchSend) (resp *RespBatchSend, err error) {
path := ClientURLPath{"unstable", "org.matrix.msc2716", "rooms", roomID, "batch_send"}
query := map[string]string{
@@ -2011,6 +2035,12 @@ func (cli *Client) AppservicePing(id, txnID string) (resp *RespAppservicePing, e
return
}
func (cli *Client) BeeperBatchSend(roomID id.RoomID, req *ReqBeeperBatchSend) (resp *RespBeeperBatchSend, err error) {
u := cli.BuildClientURL("unstable", "com.beeper.backfill", "rooms", roomID, "batch_send")
_, err = cli.MakeRequest(http.MethodPost, u, req, &resp)
return
}
func (cli *Client) BeeperMergeRooms(req *ReqBeeperMergeRoom) (resp *RespBeeperMergeRoom, err error) {
urlPath := cli.BuildClientURL("unstable", "com.beeper.chatmerging", "merge")
_, err = cli.MakeRequest(http.MethodPost, urlPath, req, &resp)