providers/ldap: fix inconsistent saving of user flags on failed cached binds (#6096)

* feat: assign invalid pk and check

* fix: only set flags if they don't exist

* fix: userinfo not being set if data is available

* minor cleanup

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

---------

Signed-off-by: Jens Langhammer <jens@goauthentik.io>
Co-authored-by: Jens Langhammer <jens@goauthentik.io>
This commit is contained in:
Daniel 2023-06-29 16:57:11 +02:00 committed by Jens Langhammer
parent 04d0bd7fb7
commit ad81ee2740
No known key found for this signature in database
3 changed files with 25 additions and 15 deletions

View File

@ -36,8 +36,15 @@ func (db *DirectBinder) Bind(username string, req *bind.Request) (ldap.LDAPResul
passed, err := fe.Execute() passed, err := fe.Execute()
flags := flags.UserFlags{ flags := flags.UserFlags{
Session: fe.GetSession(), Session: fe.GetSession(),
UserPk: flags.InvalidUserPK,
} }
// only set flags if we don't have flags for this DN yet
// as flags are only checked during the bind, we can remember whether a certain DN
// can search or not, as if they bind correctly first and then use incorrect credentials
// later, they won't get past this step anyways
if db.si.GetFlags(req.BindDN) == nil {
db.si.SetFlags(req.BindDN, &flags) db.si.SetFlags(req.BindDN, &flags)
}
if err != nil { if err != nil {
metrics.RequestsRejected.With(prometheus.Labels{ metrics.RequestsRejected.With(prometheus.Labels{
"outpost_name": db.si.GetOutpostName(), "outpost_name": db.si.GetOutpostName(),

View File

@ -6,6 +6,8 @@ import (
"goauthentik.io/api/v3" "goauthentik.io/api/v3"
) )
const InvalidUserPK = -1
type UserFlags struct { type UserFlags struct {
UserInfo *api.User UserInfo *api.User
UserPk int32 UserPk int32

View File

@ -12,6 +12,7 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"goauthentik.io/api/v3" "goauthentik.io/api/v3"
"goauthentik.io/internal/outpost/ldap/constants" "goauthentik.io/internal/outpost/ldap/constants"
"goauthentik.io/internal/outpost/ldap/flags"
"goauthentik.io/internal/outpost/ldap/group" "goauthentik.io/internal/outpost/ldap/group"
"goauthentik.io/internal/outpost/ldap/metrics" "goauthentik.io/internal/outpost/ldap/metrics"
"goauthentik.io/internal/outpost/ldap/search" "goauthentik.io/internal/outpost/ldap/search"
@ -73,8 +74,8 @@ func (ms *MemorySearcher) Search(req *search.Request) (ldap.ServerSearchResult,
return ldap.ServerSearchResult{ResultCode: ldap.LDAPResultInsufficientAccessRights}, fmt.Errorf("Search Error: BindDN %s not in our BaseDN %s", req.BindDN, ms.si.GetBaseDN()) return ldap.ServerSearchResult{ResultCode: ldap.LDAPResultInsufficientAccessRights}, fmt.Errorf("Search Error: BindDN %s not in our BaseDN %s", req.BindDN, ms.si.GetBaseDN())
} }
flags := ms.si.GetFlags(req.BindDN) flag := ms.si.GetFlags(req.BindDN)
if flags == nil { if flag == nil || (flag.UserInfo == nil && flag.UserPk == flags.InvalidUserPK) {
req.Log().Debug("User info not cached") req.Log().Debug("User info not cached")
metrics.RequestsRejected.With(prometheus.Labels{ metrics.RequestsRejected.With(prometheus.Labels{
"outpost_name": ms.si.GetOutpostName(), "outpost_name": ms.si.GetOutpostName(),
@ -108,23 +109,23 @@ func (ms *MemorySearcher) Search(req *search.Request) (ldap.ServerSearchResult,
var err error var err error
if needUsers { if needUsers {
if flags.CanSearch { if flag.CanSearch {
users = &ms.users users = &ms.users
} else { } else {
u := make([]api.User, 1) u := make([]api.User, 1)
if flags.UserInfo == nil { if flag.UserInfo == nil {
for i, u := range ms.users { for i, u := range ms.users {
if u.Pk == flags.UserPk { if u.Pk == flag.UserPk {
flags.UserInfo = &ms.users[i] flag.UserInfo = &ms.users[i]
} }
} }
if flag.UserInfo == nil {
if flags.UserInfo == nil { req.Log().WithField("pk", flag.UserPk).Warning("User with pk is not in local cache")
req.Log().WithField("pk", flags.UserPk).Warning("User with pk is not in local cache")
err = fmt.Errorf("failed to get userinfo") err = fmt.Errorf("failed to get userinfo")
} }
} else { }
u[0] = *flags.UserInfo if flag.UserInfo != nil {
u[0] = *flag.UserInfo
} }
users = &u users = &u
} }
@ -134,17 +135,17 @@ func (ms *MemorySearcher) Search(req *search.Request) (ldap.ServerSearchResult,
groups = make([]*group.LDAPGroup, 0) groups = make([]*group.LDAPGroup, 0)
for _, g := range ms.groups { for _, g := range ms.groups {
if flags.CanSearch { if flag.CanSearch {
groups = append(groups, group.FromAPIGroup(g, ms.si)) groups = append(groups, group.FromAPIGroup(g, ms.si))
} else { } else {
// If the user cannot search, we're going to only return // If the user cannot search, we're going to only return
// the groups they're in _and_ only return themselves // the groups they're in _and_ only return themselves
// as a member. // as a member.
for _, u := range g.UsersObj { for _, u := range g.UsersObj {
if flags.UserPk == u.Pk { if flag.UserPk == u.Pk {
//TODO: Is there a better way to clone this object? //TODO: Is there a better way to clone this object?
fg := api.NewGroup(g.Pk, g.NumPk, g.Name, g.ParentName, []api.GroupMember{u}) fg := api.NewGroup(g.Pk, g.NumPk, g.Name, g.ParentName, []api.GroupMember{u})
fg.SetUsers([]int32{flags.UserPk}) fg.SetUsers([]int32{flag.UserPk})
if g.Parent.IsSet() { if g.Parent.IsSet() {
fg.SetParent(*g.Parent.Get()) fg.SetParent(*g.Parent.Get())
} }