outposts/ldap: copy boundUsers map when running refresh instead of using blank map
closes #1651 Signed-off-by: Jens Langhammer <jens.langhammer@beryju.org>
This commit is contained in:
parent
4ce3c2341c
commit
f069cfb643
|
@ -29,6 +29,7 @@ type ProviderInstance struct {
|
|||
tlsServerName *string
|
||||
cert *tls.Certificate
|
||||
outpostName string
|
||||
outpostPk int32
|
||||
searchAllowedGroups []*strfmt.UUID
|
||||
boundUsersMutex sync.RWMutex
|
||||
boundUsers map[string]flags.UserFlags
|
||||
|
|
|
@ -17,6 +17,15 @@ import (
|
|||
memorysearch "goauthentik.io/internal/outpost/ldap/search/memory"
|
||||
)
|
||||
|
||||
func (ls *LDAPServer) getCurrentProvider(pk int32) *ProviderInstance {
|
||||
for _, p := range ls.providers {
|
||||
if p.outpostPk == pk {
|
||||
return p
|
||||
}
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (ls *LDAPServer) Refresh() error {
|
||||
outposts, _, err := ls.ac.Client.OutpostsApi.OutpostsLdapList(context.Background()).Execute()
|
||||
if err != nil {
|
||||
|
@ -31,6 +40,15 @@ func (ls *LDAPServer) Refresh() error {
|
|||
groupDN := strings.ToLower(fmt.Sprintf("ou=%s,%s", constants.OUGroups, *provider.BaseDn))
|
||||
virtualGroupDN := strings.ToLower(fmt.Sprintf("ou=%s,%s", constants.OUVirtualGroups, *provider.BaseDn))
|
||||
logger := log.WithField("logger", "authentik.outpost.ldap").WithField("provider", provider.Name)
|
||||
|
||||
// Get existing instance so we can transfer boundUsers
|
||||
existing := ls.getCurrentProvider(provider.Pk)
|
||||
users := make(map[string]flags.UserFlags)
|
||||
if existing != nil {
|
||||
existing.boundUsersMutex.Unlock()
|
||||
users = existing.boundUsers
|
||||
}
|
||||
|
||||
providers[idx] = &ProviderInstance{
|
||||
BaseDN: *provider.BaseDn,
|
||||
VirtualGroupDN: virtualGroupDN,
|
||||
|
@ -40,13 +58,14 @@ func (ls *LDAPServer) Refresh() error {
|
|||
flowSlug: provider.BindFlowSlug,
|
||||
searchAllowedGroups: []*strfmt.UUID{(*strfmt.UUID)(provider.SearchGroup.Get())},
|
||||
boundUsersMutex: sync.RWMutex{},
|
||||
boundUsers: make(map[string]flags.UserFlags),
|
||||
boundUsers: users,
|
||||
s: ls,
|
||||
log: logger,
|
||||
tlsServerName: provider.TlsServerName,
|
||||
uidStartNumber: *provider.UidStartNumber,
|
||||
gidStartNumber: *provider.GidStartNumber,
|
||||
outpostName: ls.ac.Outpost.Name,
|
||||
outpostPk: provider.Pk,
|
||||
}
|
||||
if provider.Certificate.Get() != nil {
|
||||
kp := provider.Certificate.Get()
|
||||
|
|
Reference in a new issue