diff --git a/internal/outpost/ldap/search/direct/direct.go b/internal/outpost/ldap/search/direct/direct.go index b864ec3f9..ed67a6a33 100644 --- a/internal/outpost/ldap/search/direct/direct.go +++ b/internal/outpost/ldap/search/direct/direct.go @@ -18,6 +18,7 @@ import ( "goauthentik.io/internal/outpost/ldap/search" "goauthentik.io/internal/outpost/ldap/server" "goauthentik.io/internal/outpost/ldap/utils" + "goauthentik.io/internal/outpost/ldap/utils/paginator" ) type DirectSearcher struct { @@ -124,15 +125,10 @@ func (ds *DirectSearcher) Search(req *search.Request) (ldap.ServerSearchResult, return nil } - u, _, err := searchReq.Execute() + u := paginator.FetchUsers(searchReq) uapisp.Finish() - if err != nil { - req.Log().WithError(err).Warning("failed to get users") - return err - } - - users = &u.Results + users = &u } else { if flags.UserInfo == nil { uapisp := sentry.StartSpan(errCtx, "authentik.providers.ldap.search.api_user") @@ -170,29 +166,24 @@ func (ds *DirectSearcher) Search(req *search.Request) (ldap.ServerSearchResult, searchReq = searchReq.MembersByPk([]int32{flags.UserPk}) } - g, _, err := searchReq.Execute() + g := paginator.FetchGroups(searchReq) gapisp.Finish() - if err != nil { - req.Log().WithError(err).Warning("failed to get groups") - return err - } - req.Log().WithField("count", len(g.Results)).Trace("Got results from API") + req.Log().WithField("count", len(g)).Trace("Got results from API") if !flags.CanSearch { - for i, results := range g.Results { + for i, results := range g { // If they can't search, remove any users from the group results except the one we're looking for. - g.Results[i].Users = []int32{flags.UserPk} + g[i].Users = []int32{flags.UserPk} for _, u := range results.UsersObj { if u.Pk == flags.UserPk { - g.Results[i].UsersObj = []api.GroupMember{u} + g[i].UsersObj = []api.GroupMember{u} break } } } } - groups = &g.Results - + groups = &g return nil }) } diff --git a/internal/outpost/ldap/search/memory/fetch.go b/internal/outpost/ldap/search/memory/fetch.go deleted file mode 100644 index a18dfbcb7..000000000 --- a/internal/outpost/ldap/search/memory/fetch.go +++ /dev/null @@ -1,63 +0,0 @@ -package memory - -import ( - "context" - - "goauthentik.io/api/v3" -) - -const pageSize = 100 - -func (ms *MemorySearcher) FetchUsers() []api.User { - fetchUsersOffset := func(page int) (*api.PaginatedUserList, error) { - users, _, err := ms.si.GetAPIClient().CoreApi.CoreUsersList(context.TODO()).Page(int32(page)).PageSize(pageSize).Execute() - if err != nil { - ms.log.WithError(err).Warning("failed to update users") - return nil, err - } - ms.log.WithField("page", page).WithField("count", len(users.Results)).Debug("fetched users") - return users, nil - } - page := 1 - users := make([]api.User, 0) - for { - apiUsers, err := fetchUsersOffset(page) - if err != nil { - return users - } - users = append(users, apiUsers.Results...) - if apiUsers.Pagination.Next > 0 { - page += 1 - } else { - break - } - } - return users -} - -func (ms *MemorySearcher) FetchGroups() []api.Group { - fetchGroupsOffset := func(page int) (*api.PaginatedGroupList, error) { - groups, _, err := ms.si.GetAPIClient().CoreApi.CoreGroupsList(context.TODO()).Page(int32(page)).PageSize(pageSize).Execute() - if err != nil { - ms.log.WithError(err).Warning("failed to update groups") - return nil, err - } - ms.log.WithField("page", page).WithField("count", len(groups.Results)).Debug("fetched groups") - return groups, nil - } - page := 1 - groups := make([]api.Group, 0) - for { - apiGroups, err := fetchGroupsOffset(page) - if err != nil { - return groups - } - groups = append(groups, apiGroups.Results...) - if apiGroups.Pagination.Next > 0 { - page += 1 - } else { - break - } - } - return groups -} diff --git a/internal/outpost/ldap/search/memory/memory.go b/internal/outpost/ldap/search/memory/memory.go index 02885204c..164ea1925 100644 --- a/internal/outpost/ldap/search/memory/memory.go +++ b/internal/outpost/ldap/search/memory/memory.go @@ -1,6 +1,7 @@ package memory import ( + "context" "errors" "fmt" "strings" @@ -16,6 +17,7 @@ import ( "goauthentik.io/internal/outpost/ldap/search" "goauthentik.io/internal/outpost/ldap/server" "goauthentik.io/internal/outpost/ldap/utils" + "goauthentik.io/internal/outpost/ldap/utils/paginator" ) type MemorySearcher struct { @@ -32,8 +34,8 @@ func NewMemorySearcher(si server.LDAPServerInstance) *MemorySearcher { log: log.WithField("logger", "authentik.outpost.ldap.searcher.memory"), } ms.log.Debug("initialised memory searcher") - ms.users = ms.FetchUsers() - ms.groups = ms.FetchGroups() + ms.users = paginator.FetchUsers(ms.si.GetAPIClient().CoreApi.CoreUsersList(context.TODO())) + ms.groups = paginator.FetchGroups(ms.si.GetAPIClient().CoreApi.CoreGroupsList(context.TODO())) return ms } diff --git a/internal/outpost/ldap/utils/paginator/paginator.go b/internal/outpost/ldap/utils/paginator/paginator.go new file mode 100644 index 000000000..f6793b7e5 --- /dev/null +++ b/internal/outpost/ldap/utils/paginator/paginator.go @@ -0,0 +1,64 @@ +package paginator + +import ( + log "github.com/sirupsen/logrus" + "goauthentik.io/api/v3" +) + +const PageSize = 100 + +func FetchUsers(req api.ApiCoreUsersListRequest) []api.User { + fetchUsersOffset := func(page int) (*api.PaginatedUserList, error) { + users, _, err := req.Page(int32(page)).PageSize(PageSize).Execute() + if err != nil { + log.WithError(err).Warning("failed to update users") + return nil, err + } + log.WithField("page", page).WithField("count", len(users.Results)).Debug("fetched users") + return users, nil + } + page := 1 + users := make([]api.User, 0) + for { + apiUsers, err := fetchUsersOffset(page) + if err != nil { + log.WithError(err).WithField("page", page).Warn("Failed to fetch user page") + continue + } + users = append(users, apiUsers.Results...) + if apiUsers.Pagination.Next > 0 { + page += 1 + } else { + break + } + } + return users +} + +func FetchGroups(req api.ApiCoreGroupsListRequest) []api.Group { + fetchGroupsOffset := func(page int) (*api.PaginatedGroupList, error) { + groups, _, err := req.Page(int32(page)).PageSize(PageSize).Execute() + if err != nil { + log.WithError(err).Warning("failed to update groups") + return nil, err + } + log.WithField("page", page).WithField("count", len(groups.Results)).Debug("fetched groups") + return groups, nil + } + page := 1 + groups := make([]api.Group, 0) + for { + apiGroups, err := fetchGroupsOffset(page) + if err != nil { + log.WithError(err).WithField("page", page).Warn("Failed to fetch group page") + continue + } + groups = append(groups, apiGroups.Results...) + if apiGroups.Pagination.Next > 0 { + page += 1 + } else { + break + } + } + return groups +}