diff --git a/pkg/api/authn.go b/pkg/api/authn.go index 15f6f1573..6da0cedcc 100644 --- a/pkg/api/authn.go +++ b/pkg/api/authn.go @@ -262,7 +262,8 @@ func (amw *AuthnMiddleware) tryAuthnHandlers(ctlr *Controller) mux.MiddlewareFun // ldap and htpasswd based authN if ctlr.Config.IsLdapAuthEnabled() { ldapConfig := ctlr.Config.HTTP.Auth.LDAP - amw.ldapClient = &LDAPClient{ + + ctlr.LDAPClient = &LDAPClient{ Host: ldapConfig.Address, Port: ldapConfig.Port, UseSSL: !ldapConfig.Insecure, @@ -278,6 +279,8 @@ func (amw *AuthnMiddleware) tryAuthnHandlers(ctlr *Controller) mux.MiddlewareFun SubtreeSearch: ldapConfig.SubtreeSearch, } + amw.ldapClient = ctlr.LDAPClient + if ctlr.Config.HTTP.Auth.LDAP.CACert != "" { caCert, err := os.ReadFile(ctlr.Config.HTTP.Auth.LDAP.CACert) if err != nil { diff --git a/pkg/api/controller.go b/pkg/api/controller.go index 582411b6e..6f7ff763c 100644 --- a/pkg/api/controller.go +++ b/pkg/api/controller.go @@ -48,6 +48,7 @@ type Controller struct { SyncOnDemand SyncOnDemand RelyingParties map[string]rp.RelyingParty CookieStore *CookieStore + LDAPClient *LDAPClient taskScheduler *scheduler.Scheduler // runtime params chosenPort int // kernel-chosen port @@ -313,6 +314,17 @@ func (c *Controller) LoadNewConfig(newConfig *config.Config) { // reload access control config c.Config.HTTP.AccessControl = newConfig.HTTP.AccessControl + if c.Config.HTTP.Auth != nil { + c.Config.HTTP.Auth.LDAP = newConfig.HTTP.Auth.LDAP + + if c.LDAPClient != nil { + c.LDAPClient.lock.Lock() + c.LDAPClient.BindDN = newConfig.HTTP.Auth.LDAP.BindDN() + c.LDAPClient.BindPassword = newConfig.HTTP.Auth.LDAP.BindPassword() + c.LDAPClient.lock.Unlock() + } + } + // reload periodical gc config c.Config.Storage.GC = newConfig.Storage.GC c.Config.Storage.Dedupe = newConfig.Storage.Dedupe diff --git a/pkg/api/controller_test.go b/pkg/api/controller_test.go index bb2457663..a38cc89fb 100644 --- a/pkg/api/controller_test.go +++ b/pkg/api/controller_test.go @@ -2331,6 +2331,183 @@ func TestBasicAuthWithLDAP(t *testing.T) { }) } +func TestBasicAuthWithReloadedCredentials(t *testing.T) { + Convey("Start server with bad credentials", t, func() { + l := newTestLDAPServer() + ldapPort, err := strconv.Atoi(test.GetFreePort()) + So(err, ShouldBeNil) + l.Start(ldapPort) + defer l.Stop() + + tempDir := t.TempDir() + ldapConfigContent := fmt.Sprintf(`{"BindDN": "%s", "BindPassword": "%s"}`, LDAPBindDN, LDAPBindPassword) + ldapConfigPath := filepath.Join(tempDir, "ldap.json") + + err = os.WriteFile(ldapConfigPath, []byte(ldapConfigContent), 0o600) + So(err, ShouldBeNil) + + port := test.GetFreePort() + baseURL := test.GetBaseURL(port) + + configTemplate := ` + { + "Storage": { + "rootDirectory": "%s" + }, + "HTTP": { + "Address": "%s", + "Port": "%s", + "Auth": { + "LDAP": { + "CredentialsFile": "%s", + "UserAttribute": "uid", + "BaseDN": "LDAPBaseDN", + "UserGroupAttribute": "memberOf", + "Insecure": true, + "Address": "%v", + "Port": %v + } + } + } + } + ` + + configStr := fmt.Sprintf(configTemplate, + tempDir, "127.0.0.1", port, ldapConfigPath, LDAPAddress, ldapPort) + + configPath := filepath.Join(tempDir, "config.json") + err = os.WriteFile(configPath, []byte(configStr), 0o600) + So(err, ShouldBeNil) + + conf := config.New() + err = server.LoadConfiguration(conf, configPath) + So(err, ShouldBeNil) + + ctlr := api.NewController(conf) + ctlrManager := test.NewControllerManager(ctlr) + + hotReloader, err := server.NewHotReloader(ctlr, configPath, ldapConfigPath) + So(err, ShouldBeNil) + + hotReloader.Start() + + ctlrManager.StartAndWait(port) + defer ctlrManager.StopServer() + time.Sleep(time.Second * 2) + + // test if the credentials work + resp, _ := resty.R().SetBasicAuth(username, password).Get(baseURL + "/v2/") + So(resp, ShouldNotBeNil) + So(resp.StatusCode(), ShouldEqual, http.StatusOK) + + newLdapConfigContent := `{"BindDN": "NewBindDB", "BindPassword": "NewBindPassword"}` + + err = os.WriteFile(ldapConfigPath, []byte(newLdapConfigContent), 0o600) + So(err, ShouldBeNil) + + for i := 0; i < 10; i++ { + // test if the credentials don't work + resp, _ = resty.R().SetBasicAuth(username, password).Get(baseURL + "/v2/") + So(resp, ShouldNotBeNil) + + if resp.StatusCode() != http.StatusOK { + break + } + + time.Sleep(100 * time.Millisecond) + } + + So(resp.StatusCode(), ShouldNotEqual, http.StatusOK) + + err = os.WriteFile(ldapConfigPath, []byte(ldapConfigContent), 0o600) + So(err, ShouldBeNil) + + for i := 0; i < 10; i++ { + // test if the credentials don't work + resp, _ = resty.R().SetBasicAuth(username, password).Get(baseURL + "/v2/") + So(resp, ShouldNotBeNil) + if resp.StatusCode() == http.StatusOK { + break + } + + time.Sleep(100 * time.Millisecond) + } + + So(resp.StatusCode(), ShouldEqual, http.StatusOK) + + // change the file + changedDir := t.TempDir() + changedLdapConfigPath := filepath.Join(changedDir, "ldap.json") + + err = os.WriteFile(changedLdapConfigPath, []byte(newLdapConfigContent), 0o600) + So(err, ShouldBeNil) + + configStr = fmt.Sprintf(configTemplate, + tempDir, "127.0.0.1", port, changedLdapConfigPath, LDAPAddress, ldapPort) + + err = os.WriteFile(configPath, []byte(configStr), 0o600) + So(err, ShouldBeNil) + + for i := 0; i < 10; i++ { + // test if the credentials don't work + resp, _ = resty.R().SetBasicAuth(username, password).Get(baseURL + "/v2/") + So(resp, ShouldNotBeNil) + + if resp.StatusCode() != http.StatusOK { + break + } + + time.Sleep(100 * time.Millisecond) + } + + So(resp.StatusCode(), ShouldNotEqual, http.StatusOK) + + err = os.WriteFile(changedLdapConfigPath, []byte(ldapConfigContent), 0o600) + So(err, ShouldBeNil) + + for i := 0; i < 10; i++ { + // test if the credentials don't work + resp, _ = resty.R().SetBasicAuth(username, password).Get(baseURL + "/v2/") + So(resp, ShouldNotBeNil) + if resp.StatusCode() == http.StatusOK { + break + } + + time.Sleep(100 * time.Millisecond) + } + + So(resp.StatusCode(), ShouldEqual, http.StatusOK) + + // make it panic + badLDAPFilePath := filepath.Join(changedDir, "ldap222.json") + + err = os.WriteFile(badLDAPFilePath, []byte(newLdapConfigContent), 0o600) + So(err, ShouldBeNil) + + configStr = fmt.Sprintf(configTemplate, + tempDir, "127.0.0.1", port, changedLdapConfigPath, LDAPAddress, ldapPort) + + err = os.WriteFile(configPath, []byte(configStr), 0o600) + So(err, ShouldBeNil) + + // Loading the config should fail because the file doesn't exist so the old credentials + // are still up and working fine. + + for i := 0; i < 10; i++ { + // test if the credentials don't work + resp, _ = resty.R().SetBasicAuth(username, password).Get(baseURL + "/v2/") + So(resp, ShouldNotBeNil) + if resp.StatusCode() == http.StatusOK { + break + } + + time.Sleep(100 * time.Millisecond) + } + + So(resp.StatusCode(), ShouldEqual, http.StatusOK) + }) +} + func TestLDAPWithoutCreds(t *testing.T) { Convey("Make a new LDAP server", t, func() { l := newTestLDAPServer() diff --git a/pkg/cli/server/config_reloader.go b/pkg/cli/server/config_reloader.go index bf57ffa77..9409d67c4 100644 --- a/pkg/cli/server/config_reloader.go +++ b/pkg/cli/server/config_reloader.go @@ -1,6 +1,7 @@ package server import ( + "errors" "os" "os/signal" "syscall" @@ -13,12 +14,13 @@ import ( ) type HotReloader struct { - watcher *fsnotify.Watcher - filePath string - ctlr *api.Controller + watcher *fsnotify.Watcher + configPath string + ldapCredentialsPath string + ctlr *api.Controller } -func NewHotReloader(ctlr *api.Controller, filePath string) (*HotReloader, error) { +func NewHotReloader(ctlr *api.Controller, filePath, ldapCredentialsPath string) (*HotReloader, error) { // creates a new file watcher watcher, err := fsnotify.NewWatcher() if err != nil { @@ -26,9 +28,10 @@ func NewHotReloader(ctlr *api.Controller, filePath string) (*HotReloader, error) } hotReloader := &HotReloader{ - watcher: watcher, - filePath: filePath, - ctlr: ctlr, + watcher: watcher, + configPath: filePath, + ldapCredentialsPath: ldapCredentialsPath, + ctlr: ctlr, } return hotReloader, nil @@ -73,13 +76,27 @@ func (hr *HotReloader) Start() { newConfig := config.New() - err := LoadConfiguration(newConfig, hr.filePath) + err := LoadConfiguration(newConfig, hr.configPath) if err != nil { log.Error().Err(err).Msg("failed to reload config, retry writing it.") continue } + if hr.ctlr.Config.HTTP.Auth != nil && hr.ctlr.Config.HTTP.Auth.LDAP != nil && + hr.ctlr.Config.HTTP.Auth.LDAP.CredentialsFile != newConfig.HTTP.Auth.LDAP.CredentialsFile { + err = hr.watcher.Remove(hr.ctlr.Config.HTTP.Auth.LDAP.CredentialsFile) + if err != nil && !errors.Is(err, fsnotify.ErrNonExistentWatch) { + log.Error().Err(err).Msg("failed to remove old watch for the credentials file") + } + + err = hr.watcher.Add(newConfig.HTTP.Auth.LDAP.CredentialsFile) + if err != nil { + log.Panic().Err(err).Str("ldap-credentials-file", newConfig.HTTP.Auth.LDAP.CredentialsFile). + Msg("failed to watch ldap credentials file") + } + } + // stop background tasks gracefully hr.ctlr.StopBackgroundTasks() @@ -91,13 +108,20 @@ func (hr *HotReloader) Start() { } // watch for errors case err := <-hr.watcher.Errors: - log.Panic().Err(err).Str("config", hr.filePath).Msg("fsnotfy error while watching config") + log.Panic().Err(err).Str("config", hr.configPath).Msg("fsnotfy error while watching config") } } }() - if err := hr.watcher.Add(hr.filePath); err != nil { - log.Panic().Err(err).Str("config", hr.filePath).Msg("failed to add config file to fsnotity watcher") + if err := hr.watcher.Add(hr.configPath); err != nil { + log.Panic().Err(err).Str("config", hr.configPath).Msg("failed to add config file to fsnotity watcher") + } + + if hr.ldapCredentialsPath != "" { + if err := hr.watcher.Add(hr.ldapCredentialsPath); err != nil { + log.Panic().Err(err).Str("ldap-credentials", hr.ldapCredentialsPath). + Msg("failed to add ldap-credentials to fsnotity watcher") + } } <-done diff --git a/pkg/cli/server/root.go b/pkg/cli/server/root.go index 4060d58d5..0af847499 100644 --- a/pkg/cli/server/root.go +++ b/pkg/cli/server/root.go @@ -55,8 +55,13 @@ func newServeCmd(conf *config.Config) *cobra.Command { ctlr := api.NewController(conf) + ldapCredentials := "" + + if conf.HTTP.Auth != nil && conf.HTTP.Auth.LDAP != nil { + ldapCredentials = conf.HTTP.Auth.LDAP.CredentialsFile + } // config reloader - hotReloader, err := NewHotReloader(ctlr, args[0]) + hotReloader, err := NewHotReloader(ctlr, args[0], ldapCredentials) if err != nil { ctlr.Log.Error().Err(err).Msg("failed to create a new hot reloader") diff --git a/pkg/extensions/sync/sync_test.go b/pkg/extensions/sync/sync_test.go index 344ad7395..06ee1fcb8 100644 --- a/pkg/extensions/sync/sync_test.go +++ b/pkg/extensions/sync/sync_test.go @@ -1898,7 +1898,7 @@ func TestConfigReloader(t *testing.T) { _, err = cfgfile.WriteString(content) So(err, ShouldBeNil) - hotReloader, err := cli.NewHotReloader(dctlr, cfgfile.Name()) + hotReloader, err := cli.NewHotReloader(dctlr, cfgfile.Name(), "") So(err, ShouldBeNil) hotReloader.Start() @@ -2048,7 +2048,7 @@ func TestConfigReloader(t *testing.T) { _, err = cfgfile.WriteString(content) So(err, ShouldBeNil) - hotReloader, err := cli.NewHotReloader(dctlr, cfgfile.Name()) + hotReloader, err := cli.NewHotReloader(dctlr, cfgfile.Name(), "") So(err, ShouldBeNil) hotReloader.Start()