providers/ldap: correctly use pagination in search results in both modes (#5492)
closes #4292 Signed-off-by: Jens Langhammer <jens@goauthentik.io>
This commit is contained in:
parent
7f0ccc61dd
commit
b7b62ba089
|
@ -18,6 +18,7 @@ import (
|
||||||
"goauthentik.io/internal/outpost/ldap/search"
|
"goauthentik.io/internal/outpost/ldap/search"
|
||||||
"goauthentik.io/internal/outpost/ldap/server"
|
"goauthentik.io/internal/outpost/ldap/server"
|
||||||
"goauthentik.io/internal/outpost/ldap/utils"
|
"goauthentik.io/internal/outpost/ldap/utils"
|
||||||
|
"goauthentik.io/internal/outpost/ldap/utils/paginator"
|
||||||
)
|
)
|
||||||
|
|
||||||
type DirectSearcher struct {
|
type DirectSearcher struct {
|
||||||
|
@ -124,15 +125,10 @@ func (ds *DirectSearcher) Search(req *search.Request) (ldap.ServerSearchResult,
|
||||||
return nil
|
return nil
|
||||||
}
|
}
|
||||||
|
|
||||||
u, _, err := searchReq.Execute()
|
u := paginator.FetchUsers(searchReq)
|
||||||
uapisp.Finish()
|
uapisp.Finish()
|
||||||
|
|
||||||
if err != nil {
|
users = &u
|
||||||
req.Log().WithError(err).Warning("failed to get users")
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
|
|
||||||
users = &u.Results
|
|
||||||
} else {
|
} else {
|
||||||
if flags.UserInfo == nil {
|
if flags.UserInfo == nil {
|
||||||
uapisp := sentry.StartSpan(errCtx, "authentik.providers.ldap.search.api_user")
|
uapisp := sentry.StartSpan(errCtx, "authentik.providers.ldap.search.api_user")
|
||||||
|
@ -170,29 +166,24 @@ func (ds *DirectSearcher) Search(req *search.Request) (ldap.ServerSearchResult,
|
||||||
searchReq = searchReq.MembersByPk([]int32{flags.UserPk})
|
searchReq = searchReq.MembersByPk([]int32{flags.UserPk})
|
||||||
}
|
}
|
||||||
|
|
||||||
g, _, err := searchReq.Execute()
|
g := paginator.FetchGroups(searchReq)
|
||||||
gapisp.Finish()
|
gapisp.Finish()
|
||||||
if err != nil {
|
req.Log().WithField("count", len(g)).Trace("Got results from API")
|
||||||
req.Log().WithError(err).Warning("failed to get groups")
|
|
||||||
return err
|
|
||||||
}
|
|
||||||
req.Log().WithField("count", len(g.Results)).Trace("Got results from API")
|
|
||||||
|
|
||||||
if !flags.CanSearch {
|
if !flags.CanSearch {
|
||||||
for i, results := range g.Results {
|
for i, results := range g {
|
||||||
// If they can't search, remove any users from the group results except the one we're looking for.
|
// If they can't search, remove any users from the group results except the one we're looking for.
|
||||||
g.Results[i].Users = []int32{flags.UserPk}
|
g[i].Users = []int32{flags.UserPk}
|
||||||
for _, u := range results.UsersObj {
|
for _, u := range results.UsersObj {
|
||||||
if u.Pk == flags.UserPk {
|
if u.Pk == flags.UserPk {
|
||||||
g.Results[i].UsersObj = []api.GroupMember{u}
|
g[i].UsersObj = []api.GroupMember{u}
|
||||||
break
|
break
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
|
|
||||||
groups = &g.Results
|
groups = &g
|
||||||
|
|
||||||
return nil
|
return nil
|
||||||
})
|
})
|
||||||
}
|
}
|
||||||
|
|
|
@ -1,63 +0,0 @@
|
||||||
package memory
|
|
||||||
|
|
||||||
import (
|
|
||||||
"context"
|
|
||||||
|
|
||||||
"goauthentik.io/api/v3"
|
|
||||||
)
|
|
||||||
|
|
||||||
const pageSize = 100
|
|
||||||
|
|
||||||
func (ms *MemorySearcher) FetchUsers() []api.User {
|
|
||||||
fetchUsersOffset := func(page int) (*api.PaginatedUserList, error) {
|
|
||||||
users, _, err := ms.si.GetAPIClient().CoreApi.CoreUsersList(context.TODO()).Page(int32(page)).PageSize(pageSize).Execute()
|
|
||||||
if err != nil {
|
|
||||||
ms.log.WithError(err).Warning("failed to update users")
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
ms.log.WithField("page", page).WithField("count", len(users.Results)).Debug("fetched users")
|
|
||||||
return users, nil
|
|
||||||
}
|
|
||||||
page := 1
|
|
||||||
users := make([]api.User, 0)
|
|
||||||
for {
|
|
||||||
apiUsers, err := fetchUsersOffset(page)
|
|
||||||
if err != nil {
|
|
||||||
return users
|
|
||||||
}
|
|
||||||
users = append(users, apiUsers.Results...)
|
|
||||||
if apiUsers.Pagination.Next > 0 {
|
|
||||||
page += 1
|
|
||||||
} else {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return users
|
|
||||||
}
|
|
||||||
|
|
||||||
func (ms *MemorySearcher) FetchGroups() []api.Group {
|
|
||||||
fetchGroupsOffset := func(page int) (*api.PaginatedGroupList, error) {
|
|
||||||
groups, _, err := ms.si.GetAPIClient().CoreApi.CoreGroupsList(context.TODO()).Page(int32(page)).PageSize(pageSize).Execute()
|
|
||||||
if err != nil {
|
|
||||||
ms.log.WithError(err).Warning("failed to update groups")
|
|
||||||
return nil, err
|
|
||||||
}
|
|
||||||
ms.log.WithField("page", page).WithField("count", len(groups.Results)).Debug("fetched groups")
|
|
||||||
return groups, nil
|
|
||||||
}
|
|
||||||
page := 1
|
|
||||||
groups := make([]api.Group, 0)
|
|
||||||
for {
|
|
||||||
apiGroups, err := fetchGroupsOffset(page)
|
|
||||||
if err != nil {
|
|
||||||
return groups
|
|
||||||
}
|
|
||||||
groups = append(groups, apiGroups.Results...)
|
|
||||||
if apiGroups.Pagination.Next > 0 {
|
|
||||||
page += 1
|
|
||||||
} else {
|
|
||||||
break
|
|
||||||
}
|
|
||||||
}
|
|
||||||
return groups
|
|
||||||
}
|
|
|
@ -1,6 +1,7 @@
|
||||||
package memory
|
package memory
|
||||||
|
|
||||||
import (
|
import (
|
||||||
|
"context"
|
||||||
"errors"
|
"errors"
|
||||||
"fmt"
|
"fmt"
|
||||||
"strings"
|
"strings"
|
||||||
|
@ -16,6 +17,7 @@ import (
|
||||||
"goauthentik.io/internal/outpost/ldap/search"
|
"goauthentik.io/internal/outpost/ldap/search"
|
||||||
"goauthentik.io/internal/outpost/ldap/server"
|
"goauthentik.io/internal/outpost/ldap/server"
|
||||||
"goauthentik.io/internal/outpost/ldap/utils"
|
"goauthentik.io/internal/outpost/ldap/utils"
|
||||||
|
"goauthentik.io/internal/outpost/ldap/utils/paginator"
|
||||||
)
|
)
|
||||||
|
|
||||||
type MemorySearcher struct {
|
type MemorySearcher struct {
|
||||||
|
@ -32,8 +34,8 @@ func NewMemorySearcher(si server.LDAPServerInstance) *MemorySearcher {
|
||||||
log: log.WithField("logger", "authentik.outpost.ldap.searcher.memory"),
|
log: log.WithField("logger", "authentik.outpost.ldap.searcher.memory"),
|
||||||
}
|
}
|
||||||
ms.log.Debug("initialised memory searcher")
|
ms.log.Debug("initialised memory searcher")
|
||||||
ms.users = ms.FetchUsers()
|
ms.users = paginator.FetchUsers(ms.si.GetAPIClient().CoreApi.CoreUsersList(context.TODO()))
|
||||||
ms.groups = ms.FetchGroups()
|
ms.groups = paginator.FetchGroups(ms.si.GetAPIClient().CoreApi.CoreGroupsList(context.TODO()))
|
||||||
return ms
|
return ms
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -0,0 +1,64 @@
|
||||||
|
package paginator
|
||||||
|
|
||||||
|
import (
|
||||||
|
log "github.com/sirupsen/logrus"
|
||||||
|
"goauthentik.io/api/v3"
|
||||||
|
)
|
||||||
|
|
||||||
|
const PageSize = 100
|
||||||
|
|
||||||
|
func FetchUsers(req api.ApiCoreUsersListRequest) []api.User {
|
||||||
|
fetchUsersOffset := func(page int) (*api.PaginatedUserList, error) {
|
||||||
|
users, _, err := req.Page(int32(page)).PageSize(PageSize).Execute()
|
||||||
|
if err != nil {
|
||||||
|
log.WithError(err).Warning("failed to update users")
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
log.WithField("page", page).WithField("count", len(users.Results)).Debug("fetched users")
|
||||||
|
return users, nil
|
||||||
|
}
|
||||||
|
page := 1
|
||||||
|
users := make([]api.User, 0)
|
||||||
|
for {
|
||||||
|
apiUsers, err := fetchUsersOffset(page)
|
||||||
|
if err != nil {
|
||||||
|
log.WithError(err).WithField("page", page).Warn("Failed to fetch user page")
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
users = append(users, apiUsers.Results...)
|
||||||
|
if apiUsers.Pagination.Next > 0 {
|
||||||
|
page += 1
|
||||||
|
} else {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return users
|
||||||
|
}
|
||||||
|
|
||||||
|
func FetchGroups(req api.ApiCoreGroupsListRequest) []api.Group {
|
||||||
|
fetchGroupsOffset := func(page int) (*api.PaginatedGroupList, error) {
|
||||||
|
groups, _, err := req.Page(int32(page)).PageSize(PageSize).Execute()
|
||||||
|
if err != nil {
|
||||||
|
log.WithError(err).Warning("failed to update groups")
|
||||||
|
return nil, err
|
||||||
|
}
|
||||||
|
log.WithField("page", page).WithField("count", len(groups.Results)).Debug("fetched groups")
|
||||||
|
return groups, nil
|
||||||
|
}
|
||||||
|
page := 1
|
||||||
|
groups := make([]api.Group, 0)
|
||||||
|
for {
|
||||||
|
apiGroups, err := fetchGroupsOffset(page)
|
||||||
|
if err != nil {
|
||||||
|
log.WithError(err).WithField("page", page).Warn("Failed to fetch group page")
|
||||||
|
continue
|
||||||
|
}
|
||||||
|
groups = append(groups, apiGroups.Results...)
|
||||||
|
if apiGroups.Pagination.Next > 0 {
|
||||||
|
page += 1
|
||||||
|
} else {
|
||||||
|
break
|
||||||
|
}
|
||||||
|
}
|
||||||
|
return groups
|
||||||
|
}
|
Reference in New Issue