diff --git a/internal/outpost/ldap/bind/direct/bind.go b/internal/outpost/ldap/bind/direct/bind.go index 0806bff88..67151033e 100644 --- a/internal/outpost/ldap/bind/direct/bind.go +++ b/internal/outpost/ldap/bind/direct/bind.go @@ -36,8 +36,15 @@ func (db *DirectBinder) Bind(username string, req *bind.Request) (ldap.LDAPResul passed, err := fe.Execute() flags := flags.UserFlags{ Session: fe.GetSession(), + UserPk: flags.InvalidUserPK, + } + // only set flags if we don't have flags for this DN yet + // as flags are only checked during the bind, we can remember whether a certain DN + // can search or not, as if they bind correctly first and then use incorrect credentials + // later, they won't get past this step anyways + if db.si.GetFlags(req.BindDN) == nil { + db.si.SetFlags(req.BindDN, &flags) } - db.si.SetFlags(req.BindDN, &flags) if err != nil { metrics.RequestsRejected.With(prometheus.Labels{ "outpost_name": db.si.GetOutpostName(), diff --git a/internal/outpost/ldap/flags/flags.go b/internal/outpost/ldap/flags/flags.go index f02fc1292..60538de2c 100644 --- a/internal/outpost/ldap/flags/flags.go +++ b/internal/outpost/ldap/flags/flags.go @@ -6,6 +6,8 @@ import ( "goauthentik.io/api/v3" ) +const InvalidUserPK = -1 + type UserFlags struct { UserInfo *api.User UserPk int32 diff --git a/internal/outpost/ldap/search/memory/memory.go b/internal/outpost/ldap/search/memory/memory.go index 1706493a1..04a39996a 100644 --- a/internal/outpost/ldap/search/memory/memory.go +++ b/internal/outpost/ldap/search/memory/memory.go @@ -12,6 +12,7 @@ import ( log "github.com/sirupsen/logrus" "goauthentik.io/api/v3" "goauthentik.io/internal/outpost/ldap/constants" + "goauthentik.io/internal/outpost/ldap/flags" "goauthentik.io/internal/outpost/ldap/group" "goauthentik.io/internal/outpost/ldap/metrics" "goauthentik.io/internal/outpost/ldap/search" @@ -73,8 +74,8 @@ func (ms *MemorySearcher) Search(req *search.Request) (ldap.ServerSearchResult, return ldap.ServerSearchResult{ResultCode: ldap.LDAPResultInsufficientAccessRights}, fmt.Errorf("Search Error: BindDN %s not in our BaseDN %s", req.BindDN, ms.si.GetBaseDN()) } - flags := ms.si.GetFlags(req.BindDN) - if flags == nil { + flag := ms.si.GetFlags(req.BindDN) + if flag == nil || (flag.UserInfo == nil && flag.UserPk == flags.InvalidUserPK) { req.Log().Debug("User info not cached") metrics.RequestsRejected.With(prometheus.Labels{ "outpost_name": ms.si.GetOutpostName(), @@ -108,23 +109,23 @@ func (ms *MemorySearcher) Search(req *search.Request) (ldap.ServerSearchResult, var err error if needUsers { - if flags.CanSearch { + if flag.CanSearch { users = &ms.users } else { u := make([]api.User, 1) - if flags.UserInfo == nil { + if flag.UserInfo == nil { for i, u := range ms.users { - if u.Pk == flags.UserPk { - flags.UserInfo = &ms.users[i] + if u.Pk == flag.UserPk { + flag.UserInfo = &ms.users[i] } } - - if flags.UserInfo == nil { - req.Log().WithField("pk", flags.UserPk).Warning("User with pk is not in local cache") + if flag.UserInfo == nil { + req.Log().WithField("pk", flag.UserPk).Warning("User with pk is not in local cache") err = fmt.Errorf("failed to get userinfo") } - } else { - u[0] = *flags.UserInfo + } + if flag.UserInfo != nil { + u[0] = *flag.UserInfo } users = &u } @@ -134,17 +135,17 @@ func (ms *MemorySearcher) Search(req *search.Request) (ldap.ServerSearchResult, groups = make([]*group.LDAPGroup, 0) for _, g := range ms.groups { - if flags.CanSearch { + if flag.CanSearch { groups = append(groups, group.FromAPIGroup(g, ms.si)) } else { // If the user cannot search, we're going to only return // the groups they're in _and_ only return themselves // as a member. for _, u := range g.UsersObj { - if flags.UserPk == u.Pk { + if flag.UserPk == u.Pk { //TODO: Is there a better way to clone this object? fg := api.NewGroup(g.Pk, g.NumPk, g.Name, g.ParentName, []api.GroupMember{u}) - fg.SetUsers([]int32{flags.UserPk}) + fg.SetUsers([]int32{flag.UserPk}) if g.Parent.IsSet() { fg.SetParent(*g.Parent.Get()) }