outposts/ldap: copy boundUsers map when running refresh instead of using blank map

closes #1651

Signed-off-by: Jens Langhammer <jens.langhammer@beryju.org>
This commit is contained in:
Jens Langhammer 2021-11-13 00:26:01 +01:00
parent 4ce3c2341c
commit f069cfb643
2 changed files with 21 additions and 1 deletions

View file

@ -29,6 +29,7 @@ type ProviderInstance struct {
tlsServerName *string
cert *tls.Certificate
outpostName string
outpostPk int32
searchAllowedGroups []*strfmt.UUID
boundUsersMutex sync.RWMutex
boundUsers map[string]flags.UserFlags

View file

@ -17,6 +17,15 @@ import (
memorysearch "goauthentik.io/internal/outpost/ldap/search/memory"
)
func (ls *LDAPServer) getCurrentProvider(pk int32) *ProviderInstance {
for _, p := range ls.providers {
if p.outpostPk == pk {
return p
}
}
return nil
}
func (ls *LDAPServer) Refresh() error {
outposts, _, err := ls.ac.Client.OutpostsApi.OutpostsLdapList(context.Background()).Execute()
if err != nil {
@ -31,6 +40,15 @@ func (ls *LDAPServer) Refresh() error {
groupDN := strings.ToLower(fmt.Sprintf("ou=%s,%s", constants.OUGroups, *provider.BaseDn))
virtualGroupDN := strings.ToLower(fmt.Sprintf("ou=%s,%s", constants.OUVirtualGroups, *provider.BaseDn))
logger := log.WithField("logger", "authentik.outpost.ldap").WithField("provider", provider.Name)
// Get existing instance so we can transfer boundUsers
existing := ls.getCurrentProvider(provider.Pk)
users := make(map[string]flags.UserFlags)
if existing != nil {
existing.boundUsersMutex.Unlock()
users = existing.boundUsers
}
providers[idx] = &ProviderInstance{
BaseDN: *provider.BaseDn,
VirtualGroupDN: virtualGroupDN,
@ -40,13 +58,14 @@ func (ls *LDAPServer) Refresh() error {
flowSlug: provider.BindFlowSlug,
searchAllowedGroups: []*strfmt.UUID{(*strfmt.UUID)(provider.SearchGroup.Get())},
boundUsersMutex: sync.RWMutex{},
boundUsers: make(map[string]flags.UserFlags),
boundUsers: users,
s: ls,
log: logger,
tlsServerName: provider.TlsServerName,
uidStartNumber: *provider.UidStartNumber,
gidStartNumber: *provider.GidStartNumber,
outpostName: ls.ac.Outpost.Name,
outpostPk: provider.Pk,
}
if provider.Certificate.Get() != nil {
kp := provider.Certificate.Get()