providers/ldap: correctly use pagination in search results in both modes (#5492)

closes #4292

Signed-off-by: Jens Langhammer <jens@goauthentik.io>
This commit is contained in:
Jens L 2023-05-05 15:51:02 +03:00 committed by GitHub
parent 7f0ccc61dd
commit b7b62ba089
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 77 additions and 83 deletions

View File

@ -18,6 +18,7 @@ import (
"goauthentik.io/internal/outpost/ldap/search" "goauthentik.io/internal/outpost/ldap/search"
"goauthentik.io/internal/outpost/ldap/server" "goauthentik.io/internal/outpost/ldap/server"
"goauthentik.io/internal/outpost/ldap/utils" "goauthentik.io/internal/outpost/ldap/utils"
"goauthentik.io/internal/outpost/ldap/utils/paginator"
) )
type DirectSearcher struct { type DirectSearcher struct {
@ -124,15 +125,10 @@ func (ds *DirectSearcher) Search(req *search.Request) (ldap.ServerSearchResult,
return nil return nil
} }
u, _, err := searchReq.Execute() u := paginator.FetchUsers(searchReq)
uapisp.Finish() uapisp.Finish()
if err != nil { users = &u
req.Log().WithError(err).Warning("failed to get users")
return err
}
users = &u.Results
} else { } else {
if flags.UserInfo == nil { if flags.UserInfo == nil {
uapisp := sentry.StartSpan(errCtx, "authentik.providers.ldap.search.api_user") uapisp := sentry.StartSpan(errCtx, "authentik.providers.ldap.search.api_user")
@ -170,29 +166,24 @@ func (ds *DirectSearcher) Search(req *search.Request) (ldap.ServerSearchResult,
searchReq = searchReq.MembersByPk([]int32{flags.UserPk}) searchReq = searchReq.MembersByPk([]int32{flags.UserPk})
} }
g, _, err := searchReq.Execute() g := paginator.FetchGroups(searchReq)
gapisp.Finish() gapisp.Finish()
if err != nil { req.Log().WithField("count", len(g)).Trace("Got results from API")
req.Log().WithError(err).Warning("failed to get groups")
return err
}
req.Log().WithField("count", len(g.Results)).Trace("Got results from API")
if !flags.CanSearch { if !flags.CanSearch {
for i, results := range g.Results { for i, results := range g {
// If they can't search, remove any users from the group results except the one we're looking for. // If they can't search, remove any users from the group results except the one we're looking for.
g.Results[i].Users = []int32{flags.UserPk} g[i].Users = []int32{flags.UserPk}
for _, u := range results.UsersObj { for _, u := range results.UsersObj {
if u.Pk == flags.UserPk { if u.Pk == flags.UserPk {
g.Results[i].UsersObj = []api.GroupMember{u} g[i].UsersObj = []api.GroupMember{u}
break break
} }
} }
} }
} }
groups = &g.Results groups = &g
return nil return nil
}) })
} }

View File

@ -1,63 +0,0 @@
package memory
import (
"context"
"goauthentik.io/api/v3"
)
const pageSize = 100
func (ms *MemorySearcher) FetchUsers() []api.User {
fetchUsersOffset := func(page int) (*api.PaginatedUserList, error) {
users, _, err := ms.si.GetAPIClient().CoreApi.CoreUsersList(context.TODO()).Page(int32(page)).PageSize(pageSize).Execute()
if err != nil {
ms.log.WithError(err).Warning("failed to update users")
return nil, err
}
ms.log.WithField("page", page).WithField("count", len(users.Results)).Debug("fetched users")
return users, nil
}
page := 1
users := make([]api.User, 0)
for {
apiUsers, err := fetchUsersOffset(page)
if err != nil {
return users
}
users = append(users, apiUsers.Results...)
if apiUsers.Pagination.Next > 0 {
page += 1
} else {
break
}
}
return users
}
func (ms *MemorySearcher) FetchGroups() []api.Group {
fetchGroupsOffset := func(page int) (*api.PaginatedGroupList, error) {
groups, _, err := ms.si.GetAPIClient().CoreApi.CoreGroupsList(context.TODO()).Page(int32(page)).PageSize(pageSize).Execute()
if err != nil {
ms.log.WithError(err).Warning("failed to update groups")
return nil, err
}
ms.log.WithField("page", page).WithField("count", len(groups.Results)).Debug("fetched groups")
return groups, nil
}
page := 1
groups := make([]api.Group, 0)
for {
apiGroups, err := fetchGroupsOffset(page)
if err != nil {
return groups
}
groups = append(groups, apiGroups.Results...)
if apiGroups.Pagination.Next > 0 {
page += 1
} else {
break
}
}
return groups
}

View File

@ -1,6 +1,7 @@
package memory package memory
import ( import (
"context"
"errors" "errors"
"fmt" "fmt"
"strings" "strings"
@ -16,6 +17,7 @@ import (
"goauthentik.io/internal/outpost/ldap/search" "goauthentik.io/internal/outpost/ldap/search"
"goauthentik.io/internal/outpost/ldap/server" "goauthentik.io/internal/outpost/ldap/server"
"goauthentik.io/internal/outpost/ldap/utils" "goauthentik.io/internal/outpost/ldap/utils"
"goauthentik.io/internal/outpost/ldap/utils/paginator"
) )
type MemorySearcher struct { type MemorySearcher struct {
@ -32,8 +34,8 @@ func NewMemorySearcher(si server.LDAPServerInstance) *MemorySearcher {
log: log.WithField("logger", "authentik.outpost.ldap.searcher.memory"), log: log.WithField("logger", "authentik.outpost.ldap.searcher.memory"),
} }
ms.log.Debug("initialised memory searcher") ms.log.Debug("initialised memory searcher")
ms.users = ms.FetchUsers() ms.users = paginator.FetchUsers(ms.si.GetAPIClient().CoreApi.CoreUsersList(context.TODO()))
ms.groups = ms.FetchGroups() ms.groups = paginator.FetchGroups(ms.si.GetAPIClient().CoreApi.CoreGroupsList(context.TODO()))
return ms return ms
} }

View File

@ -0,0 +1,64 @@
package paginator
import (
log "github.com/sirupsen/logrus"
"goauthentik.io/api/v3"
)
const PageSize = 100
func FetchUsers(req api.ApiCoreUsersListRequest) []api.User {
fetchUsersOffset := func(page int) (*api.PaginatedUserList, error) {
users, _, err := req.Page(int32(page)).PageSize(PageSize).Execute()
if err != nil {
log.WithError(err).Warning("failed to update users")
return nil, err
}
log.WithField("page", page).WithField("count", len(users.Results)).Debug("fetched users")
return users, nil
}
page := 1
users := make([]api.User, 0)
for {
apiUsers, err := fetchUsersOffset(page)
if err != nil {
log.WithError(err).WithField("page", page).Warn("Failed to fetch user page")
continue
}
users = append(users, apiUsers.Results...)
if apiUsers.Pagination.Next > 0 {
page += 1
} else {
break
}
}
return users
}
func FetchGroups(req api.ApiCoreGroupsListRequest) []api.Group {
fetchGroupsOffset := func(page int) (*api.PaginatedGroupList, error) {
groups, _, err := req.Page(int32(page)).PageSize(PageSize).Execute()
if err != nil {
log.WithError(err).Warning("failed to update groups")
return nil, err
}
log.WithField("page", page).WithField("count", len(groups.Results)).Debug("fetched groups")
return groups, nil
}
page := 1
groups := make([]api.Group, 0)
for {
apiGroups, err := fetchGroupsOffset(page)
if err != nil {
log.WithError(err).WithField("page", page).Warn("Failed to fetch group page")
continue
}
groups = append(groups, apiGroups.Results...)
if apiGroups.Pagination.Next > 0 {
page += 1
} else {
break
}
}
return groups
}