diff --git a/internal/outpost/ldap/instance.go b/internal/outpost/ldap/instance.go index 0bacf428b..21352518f 100644 --- a/internal/outpost/ldap/instance.go +++ b/internal/outpost/ldap/instance.go @@ -29,6 +29,7 @@ type ProviderInstance struct { tlsServerName *string cert *tls.Certificate outpostName string + outpostPk int32 searchAllowedGroups []*strfmt.UUID boundUsersMutex sync.RWMutex boundUsers map[string]flags.UserFlags diff --git a/internal/outpost/ldap/refresh.go b/internal/outpost/ldap/refresh.go index 1ec37ead9..a79417b40 100644 --- a/internal/outpost/ldap/refresh.go +++ b/internal/outpost/ldap/refresh.go @@ -17,6 +17,15 @@ import ( memorysearch "goauthentik.io/internal/outpost/ldap/search/memory" ) +func (ls *LDAPServer) getCurrentProvider(pk int32) *ProviderInstance { + for _, p := range ls.providers { + if p.outpostPk == pk { + return p + } + } + return nil +} + func (ls *LDAPServer) Refresh() error { outposts, _, err := ls.ac.Client.OutpostsApi.OutpostsLdapList(context.Background()).Execute() if err != nil { @@ -31,6 +40,15 @@ func (ls *LDAPServer) Refresh() error { groupDN := strings.ToLower(fmt.Sprintf("ou=%s,%s", constants.OUGroups, *provider.BaseDn)) virtualGroupDN := strings.ToLower(fmt.Sprintf("ou=%s,%s", constants.OUVirtualGroups, *provider.BaseDn)) logger := log.WithField("logger", "authentik.outpost.ldap").WithField("provider", provider.Name) + + // Get existing instance so we can transfer boundUsers + existing := ls.getCurrentProvider(provider.Pk) + users := make(map[string]flags.UserFlags) + if existing != nil { + existing.boundUsersMutex.Unlock() + users = existing.boundUsers + } + providers[idx] = &ProviderInstance{ BaseDN: *provider.BaseDn, VirtualGroupDN: virtualGroupDN, @@ -40,13 +58,14 @@ func (ls *LDAPServer) Refresh() error { flowSlug: provider.BindFlowSlug, searchAllowedGroups: []*strfmt.UUID{(*strfmt.UUID)(provider.SearchGroup.Get())}, boundUsersMutex: sync.RWMutex{}, - boundUsers: make(map[string]flags.UserFlags), + boundUsers: users, s: ls, log: logger, tlsServerName: provider.TlsServerName, uidStartNumber: *provider.UidStartNumber, gidStartNumber: *provider.GidStartNumber, outpostName: ls.ac.Outpost.Name, + outpostPk: provider.Pk, } if provider.Certificate.Get() != nil { kp := provider.Certificate.Get()