diff --git a/README.md b/README.md index 3a5ecf1..42f6c68 100644 --- a/README.md +++ b/README.md @@ -44,6 +44,7 @@ env vars * **POSTMOOGLE_DB_DSN** - database connection string * **POSTMOOGLE_DB_DIALECT** - database dialect (postgres, sqlite3) * **POSTMOOGLE_MAXSIZE** - max email size (including attachments) in megabytes +* **POSTMOOGLE_USERS** - a space-separated list of whitelisted users allowed to use the bridge. If not defined, everyone is allowed. Example rule: `@someone:example.com @another:example.com @bot.*:example.com @*:another.com` You can find default values in [config/defaults.go](config/defaults.go) diff --git a/bot/bot.go b/bot/bot.go index 83826be..db3a7e3 100644 --- a/bot/bot.go +++ b/bot/bot.go @@ -3,6 +3,7 @@ package bot import ( "context" "fmt" + "regexp" "sync" "github.com/getsentry/sentry-go" @@ -19,6 +20,7 @@ type Bot struct { federation bool prefix string domain string + allowedUsers []*regexp.Regexp rooms sync.Map log *logger.Logger lp *linkpearl.Linkpearl @@ -26,15 +28,16 @@ type Bot struct { } // New creates a new matrix bot -func New(lp *linkpearl.Linkpearl, log *logger.Logger, prefix, domain string, noowner, federation bool) *Bot { +func New(lp *linkpearl.Linkpearl, log *logger.Logger, prefix, domain string, noowner, federation bool, allowedUsers []*regexp.Regexp) *Bot { return &Bot{ - noowner: noowner, - federation: federation, - prefix: prefix, - domain: domain, - rooms: sync.Map{}, - log: log, - lp: lp, + noowner: noowner, + federation: federation, + prefix: prefix, + domain: domain, + allowedUsers: allowedUsers, + rooms: sync.Map{}, + log: log, + lp: lp, } } diff --git a/cmd/cmd.go b/cmd/cmd.go index dcd3ec7..f8d8cd1 100644 --- a/cmd/cmd.go +++ b/cmd/cmd.go @@ -26,7 +26,13 @@ var ( func main() { quit := make(chan struct{}) - cfg := config.New() + + cfg, err := config.New() + if err != nil { + log = logger.New("postmoogle.", "info") + log.Fatal("%s", err) + } + log = logger.New("postmoogle.", cfg.LogLevel) log.Info("#############################") @@ -81,7 +87,7 @@ func initBot(cfg *config.Config) { // nolint // Fatal = panic, not os.Exit() log.Fatal("cannot initialize matrix bot: %v", err) } - mxb = bot.New(lp, mxlog, cfg.Prefix, cfg.Domain, cfg.NoOwner, cfg.Federation) + mxb = bot.New(lp, mxlog, cfg.Prefix, cfg.Domain, cfg.NoOwner, cfg.Federation, cfg.Users) log.Debug("bot has been created") } diff --git a/config/config.go b/config/config.go index 4735b55..3223265 100644 --- a/config/config.go +++ b/config/config.go @@ -1,14 +1,28 @@ package config import ( + "fmt" + "gitlab.com/etke.cc/go/env" + "gitlab.com/etke.cc/postmoogle/utils" ) const prefix = "postmoogle" // New config -func New() *Config { +func New() (*Config, error) { env.SetPrefix(prefix) + + wildCardUserPatterns := env.Slice("users") + regexUserPatterns, err := utils.WildcardUserPatternsToRegexPatterns(wildCardUserPatterns) + if err != nil { + return nil, fmt.Errorf( + "failed to convert wildcard user patterns (`%s`) to regular expression: %s", + wildCardUserPatterns, + err, + ) + } + cfg := &Config{ Homeserver: env.String("homeserver", defaultConfig.Homeserver), Login: env.String("login", defaultConfig.Login), @@ -21,6 +35,7 @@ func New() *Config { Federation: env.Bool("federation"), MaxSize: env.Int("maxsize", defaultConfig.MaxSize), StatusMsg: env.String("statusmsg", defaultConfig.StatusMsg), + Users: *regexUserPatterns, Sentry: Sentry{ DSN: env.String("sentry.dsn", defaultConfig.Sentry.DSN), }, @@ -31,5 +46,5 @@ func New() *Config { }, } - return cfg + return cfg, nil } diff --git a/config/types.go b/config/types.go index 5dc8985..611c1ff 100644 --- a/config/types.go +++ b/config/types.go @@ -1,5 +1,7 @@ package config +import "regexp" + // Config of Postmoogle type Config struct { // Homeserver url @@ -26,6 +28,11 @@ type Config struct { MaxSize int // StatusMsg of the bot StatusMsg string + // Users holds regular expression patterns of users that are allowed to use the bridge. + // The regular expression patterns are compiled from wildcard patterns like: + // `@someone:example.com`, `@*:example.com`, `@bot.*:example.com`, `@someone:*`, `@someone:*.example.com` + // An empty list means that "everyone is allowed". + Users []*regexp.Regexp // DB config DB DB diff --git a/utils/user.go b/utils/user.go new file mode 100644 index 0000000..5f7f4e9 --- /dev/null +++ b/utils/user.go @@ -0,0 +1,98 @@ +package utils + +import ( + "fmt" + "regexp" + "strings" +) + +// WildcardUserPatternsToRegexPatterns converts a list of wildcard patterns to a list of regular expressions +func WildcardUserPatternsToRegexPatterns(wildCardPatterns []string) (*[]*regexp.Regexp, error) { + regexPatterns := make([]*regexp.Regexp, len(wildCardPatterns)) + + for idx, wildCardPattern := range wildCardPatterns { + regex, err := parseAllowedUserRule(wildCardPattern) + if err != nil { + return nil, fmt.Errorf("failed to parse allowed user rule `%s`: %s", wildCardPattern, err) + } + regexPatterns[idx] = regex + } + + return ®exPatterns, nil +} + +// MatchUserWithAllowedRegexes tells if the given user id is allowed to use the bot, according to the given whitelist +// An empty whitelist means "everyone is allowed" +func MatchUserWithAllowedRegexes(userID string, allowed []*regexp.Regexp) (bool, error) { + // No whitelisted users means everyone is whitelisted + if len(allowed) == 0 { + return true, nil + } + + for _, regex := range allowed { + if regex.MatchString(userID) { + return true, nil + } + } + + return false, nil +} + +// parseAllowedUserRule parses a user whitelisting rule and returns a regular expression which corresponds to it +// Example conversion: `@bot.*.something:*.example.com` -> `^bot\.([^:@]*)\.something:([^:@]*)\.example.com$` +// Example of recognized wildcard patterns: `@someone:example.com`, `@*:example.com`, `@bot.*:example.com`, `@someone:*`, `@someone:*.example.com` +func parseAllowedUserRule(wildCardRule string) (*regexp.Regexp, error) { + if !strings.HasPrefix(wildCardRule, "@") { + return nil, fmt.Errorf("rules need to be fully-qualified, starting with a @") + } + + remainingRule := wildCardRule[1:] + if strings.Contains(remainingRule, "@") { + return nil, fmt.Errorf("rules cannot contain more than one @") + } + + parts := strings.Split(remainingRule, ":") + if len(parts) != 2 { + return nil, fmt.Errorf("expected exactly 2 parts in the rule, separated by `:`") + } + + getRegexPatternForPart := func(part string) (string, error) { + if part == "" { + return "", fmt.Errorf("rejecting empty part") + } + + var pattern strings.Builder + for _, rune := range part { + if rune == '*' { + // We match everything except for `:` and `@`, because that would be an invalid MXID anyway + pattern.WriteString("([^:@]*)") + continue + } + + pattern.WriteString(regexp.QuoteMeta(string(rune))) + } + + return pattern.String(), nil + } + + localPart := parts[0] + localPartPattern, err := getRegexPatternForPart(localPart) + if err != nil { + return nil, fmt.Errorf("failed to convert local part `%s` to regex: %s", localPart, err) + } + + domainPart := parts[1] + domainPartPattern, err := getRegexPatternForPart(domainPart) + if err != nil { + return nil, fmt.Errorf("failed to convert domain part `%s` to regex: %s", domainPart, err) + } + + finalPattern := fmt.Sprintf("^@%s:%s$", localPartPattern, domainPartPattern) + + regex, err := regexp.Compile(finalPattern) + if err != nil { + return nil, fmt.Errorf("failed to compile regex `%s`: %s", finalPattern, err) + } + + return regex, nil +} diff --git a/utils/user_test.go b/utils/user_test.go new file mode 100644 index 0000000..e764d6d --- /dev/null +++ b/utils/user_test.go @@ -0,0 +1,194 @@ +package utils + +import "testing" + +func TestRuleToRegex(t *testing.T) { + type testDataDefinition struct { + name string + checkedValue string + expectedResult string + expectedError bool + } + + tests := []testDataDefinition{ + { + name: "simple pattern without wildcards succeeds", + checkedValue: "@someone:example.com", + expectedResult: `^@someone:example\.com$`, + expectedError: false, + }, + { + name: "pattern with wildcard as the whole local part succeeds", + checkedValue: "@*:example.com", + expectedResult: `^@([^:@]*):example\.com$`, + expectedError: false, + }, + { + name: "pattern with wildcard within the local part succeeds", + checkedValue: "@bot.*.something:example.com", + expectedResult: `^@bot\.([^:@]*)\.something:example\.com$`, + expectedError: false, + }, + { + name: "pattern with wildcard as the whole domain part succeeds", + checkedValue: "@someone:*", + expectedResult: `^@someone:([^:@]*)$`, + expectedError: false, + }, + { + name: "pattern with wildcard within the domain part succeeds", + checkedValue: "@someone:*.organization.com", + expectedResult: `^@someone:([^:@]*)\.organization\.com$`, + expectedError: false, + }, + { + name: "pattern with wildcard in both parts succeeds", + checkedValue: "@*:*", + expectedResult: `^@([^:@]*):([^:@]*)$`, + expectedError: false, + }, + { + name: "pattern that does not appear fully-qualified fails", + checkedValue: "someone:example.com", + expectedResult: ``, + expectedError: true, + }, + { + name: "pattern that does not appear fully-qualified fails", + checkedValue: "@someone", + expectedResult: ``, + expectedError: true, + }, + { + name: "pattern with empty domain part fails", + checkedValue: "@someone:", + expectedResult: ``, + expectedError: true, + }, + { + name: "pattern with empty local part fails", + checkedValue: "@:example.com", + expectedResult: ``, + expectedError: true, + }, + { + name: "pattern with multiple @ fails", + checkedValue: "@someone@someone:example.com", + expectedResult: ``, + expectedError: true, + }, + { + name: "pattern with multiple : fails", + checkedValue: "@someone:someone:example.com", + expectedResult: ``, + expectedError: true, + }, + } + + for _, testData := range tests { + func(testData testDataDefinition) { + t.Run(testData.name, func(t *testing.T) { + actualResult, err := parseAllowedUserRule(testData.checkedValue) + + if testData.expectedError { + if err != nil { + return + } + + t.Errorf("expected an error, but did not get one") + } + + if err != nil { + t.Errorf("did not expect an error, but got one: %s", err) + } + + if actualResult.String() == testData.expectedResult { + return + } + + t.Errorf( + "Expected `%s` to yield `%s`, not `%s`", + testData.checkedValue, + testData.expectedResult, + actualResult.String(), + ) + }) + }(testData) + } +} + +func TestMatch(t *testing.T) { + type testDataDefinition struct { + name string + checkedValue string + allowedUsers []string + expectedResult bool + } + + tests := []testDataDefinition{ + { + name: "Empty allowed users allows anyone", + checkedValue: "@someone:example.com", + allowedUsers: []string{}, + expectedResult: true, + }, + { + name: "Direct full mxid match is allowed", + checkedValue: "@someone:example.com", + allowedUsers: []string{"@someone:example.com"}, + expectedResult: true, + }, + { + name: "Direct full mxid match later on is allowed", + checkedValue: "@someone:example.com", + allowedUsers: []string{"@another:example.com", "@someone:example.com"}, + expectedResult: true, + }, + { + name: "No mxid match is not allowed", + checkedValue: "@someone:example.com", + allowedUsers: []string{"@another:example.com"}, + expectedResult: false, + }, + { + name: "mxid localpart wildcard match is allowed", + checkedValue: "@someone:example.com", + allowedUsers: []string{"@*:example.com"}, + expectedResult: true, + }, + { + name: "mxid localpart wildcard for another domain is not allowed", + checkedValue: "@someone:example.com", + allowedUsers: []string{"@*:another.com"}, + expectedResult: false, + }, + } + + for _, testData := range tests { + func(testData testDataDefinition) { + t.Run(testData.name, func(t *testing.T) { + allowedUserRegexes, err := WildcardUserPatternsToRegexPatterns(testData.allowedUsers) + if err != nil { + t.Error(err) + } + + actualResult, err := MatchUserWithAllowedRegexes(testData.checkedValue, *allowedUserRegexes) + if err != nil { + t.Error(err) + } + + if actualResult == testData.expectedResult { + return + } + + t.Errorf( + "Expected `%s` compared against `%v` to yield `%v`, not `%v`", + testData.checkedValue, + testData.allowedUsers, + testData.expectedResult, + actualResult, + ) + }) + }(testData) + } +}