diff --git a/Makefile b/Makefile index 331d73d59..df4992b79 100644 --- a/Makefile +++ b/Makefile @@ -60,18 +60,19 @@ gen-web: \cp -rfv web-api/* web/node_modules/@goauthentik/api gen-outpost: + wget https://raw.githubusercontent.com/goauthentik/client-go/main/config.yaml -O config.yaml + mkdir -p templates + wget https://raw.githubusercontent.com/goauthentik/client-go/main/templates/README.mustache -O templates/README.mustache + wget https://raw.githubusercontent.com/goauthentik/client-go/main/templates/go.mod.mustache -O templates/go.mod.mustache docker run \ --rm -v ${PWD}:/local \ --user ${UID}:${GID} \ openapitools/openapi-generator-cli generate \ - --git-host goauthentik.io \ - --git-repo-id outpost \ - --git-user-id api \ -i /local/schema.yml \ -g go \ -o /local/api \ - --additional-properties=packageName=api,enumClassPrefix=true,useOneOfDiscriminatorLookup=true,disallowAdditionalPropertiesIfNotPresent=false - rm -f api/go.mod api/go.sum + -c /local/config.yaml + go mod edit -replace goauthentik.io/api=./api gen: gen-build gen-clean gen-web diff --git a/authentik/providers/ldap/api.py b/authentik/providers/ldap/api.py index b4c7a861f..2532fac56 100644 --- a/authentik/providers/ldap/api.py +++ b/authentik/providers/ldap/api.py @@ -24,6 +24,7 @@ class LDAPProviderSerializer(ProviderSerializer): "uid_start_number", "gid_start_number", "outpost_set", + "search_mode", ] @@ -68,6 +69,7 @@ class LDAPOutpostConfigSerializer(ModelSerializer): "tls_server_name", "uid_start_number", "gid_start_number", + "search_mode", ] diff --git a/authentik/providers/ldap/models.py b/authentik/providers/ldap/models.py index d3be0883d..350fd17c3 100644 --- a/authentik/providers/ldap/models.py +++ b/authentik/providers/ldap/models.py @@ -9,6 +9,12 @@ from authentik.core.models import Group, Provider from authentik.crypto.models import CertificateKeyPair from authentik.outposts.models import OutpostModel +class SearchModes(models.TextChoices): + """Search modes""" + + DIRECT = "direct" + CACHED = "cached" + class LDAPProvider(OutpostModel, Provider): """Allow applications to authenticate against authentik's users using LDAP.""" @@ -59,6 +65,8 @@ class LDAPProvider(OutpostModel, Provider): ), ) + search_mode = models.TextField(default=SearchModes.DIRECT, choices=SearchModes.choices) + @property def launch_url(self) -> Optional[str]: """LDAP never has a launch URL""" diff --git a/internal/outpost/flow.go b/internal/outpost/flow.go index bc8e8e904..0c353c3ff 100644 --- a/internal/outpost/flow.go +++ b/internal/outpost/flow.go @@ -4,7 +4,6 @@ import ( "context" "errors" "fmt" - "net" "net/http" "net/http/cookiejar" "net/url" @@ -18,7 +17,6 @@ import ( "goauthentik.io/api" "goauthentik.io/internal/constants" "goauthentik.io/internal/outpost/ak" - "goauthentik.io/internal/utils" ) type StageComponent string @@ -103,8 +101,8 @@ type ChallengeInt interface { GetResponseErrors() map[string][]api.ErrorDetail } -func (fe *FlowExecutor) DelegateClientIP(a net.Addr) { - fe.cip = utils.GetIP(a) +func (fe *FlowExecutor) DelegateClientIP(a string) { + fe.cip = a fe.api.GetConfig().AddDefaultHeader(HeaderAuthentikRemoteIP, fe.cip) } diff --git a/internal/outpost/ldap/api_tls.go b/internal/outpost/ldap/api_tls.go deleted file mode 100644 index 0bcbdf7d2..000000000 --- a/internal/outpost/ldap/api_tls.go +++ /dev/null @@ -1,23 +0,0 @@ -package ldap - -import "crypto/tls" - -func (ls *LDAPServer) getCertificates(info *tls.ClientHelloInfo) (*tls.Certificate, error) { - if len(ls.providers) == 1 { - if ls.providers[0].cert != nil { - ls.log.WithField("server-name", info.ServerName).Debug("We only have a single provider, using their cert") - return ls.providers[0].cert, nil - } - } - for _, provider := range ls.providers { - if provider.tlsServerName == &info.ServerName { - if provider.cert == nil { - ls.log.WithField("server-name", info.ServerName).Debug("Handler does not have a certificate") - return ls.defaultCert, nil - } - return provider.cert, nil - } - } - ls.log.WithField("server-name", info.ServerName).Debug("Fallback to default cert") - return ls.defaultCert, nil -} diff --git a/internal/outpost/ldap/bind.go b/internal/outpost/ldap/bind.go index 3588f2b33..fedab284b 100644 --- a/internal/outpost/ldap/bind.go +++ b/internal/outpost/ldap/bind.go @@ -1,44 +1,18 @@ package ldap import ( - "context" "net" - "strings" - "github.com/getsentry/sentry-go" - "github.com/google/uuid" "github.com/nmcclain/ldap" "github.com/prometheus/client_golang/prometheus" - log "github.com/sirupsen/logrus" + "goauthentik.io/internal/outpost/ldap/bind" "goauthentik.io/internal/outpost/ldap/metrics" "goauthentik.io/internal/utils" ) -type BindRequest struct { - BindDN string - BindPW string - id string - conn net.Conn - log *log.Entry - ctx context.Context -} - func (ls *LDAPServer) Bind(bindDN string, bindPW string, conn net.Conn) (ldap.LDAPResultCode, error) { - span := sentry.StartSpan(context.TODO(), "authentik.providers.ldap.bind", - sentry.TransactionName("authentik.providers.ldap.bind")) - rid := uuid.New().String() - span.SetTag("request_uid", rid) - span.SetTag("user.username", bindDN) + req, span := bind.NewRequest(bindDN, bindPW, conn) - bindDN = strings.ToLower(bindDN) - req := BindRequest{ - BindDN: bindDN, - BindPW: bindPW, - conn: conn, - log: ls.log.WithField("bindDN", bindDN).WithField("requestId", rid).WithField("client", utils.GetIP(conn.RemoteAddr())), - id: rid, - ctx: span.Context(), - } defer func() { span.Finish() metrics.Requests.With(prometheus.Labels{ @@ -46,19 +20,19 @@ func (ls *LDAPServer) Bind(bindDN string, bindPW string, conn net.Conn) (ldap.LD "type": "bind", "filter": "", "dn": req.BindDN, - "client": utils.GetIP(req.conn.RemoteAddr()), + "client": req.RemoteAddr(), }).Observe(float64(span.EndTime.Sub(span.StartTime))) - req.log.WithField("took-ms", span.EndTime.Sub(span.StartTime).Milliseconds()).Info("Bind request") + req.Log().WithField("took-ms", span.EndTime.Sub(span.StartTime).Milliseconds()).Info("Bind request") }() for _, instance := range ls.providers { - username, err := instance.getUsername(bindDN) + username, err := instance.binder.GetUsername(bindDN) if err == nil { - return instance.Bind(username, req) + return instance.binder.Bind(username, req) } else { - req.log.WithError(err).Debug("Username not for instance") + req.Log().WithError(err).Debug("Username not for instance") } } - req.log.WithField("request", "bind").Warning("No provider found for request") + req.Log().WithField("request", "bind").Warning("No provider found for request") metrics.RequestsRejected.With(prometheus.Labels{ "outpost_name": ls.ac.Outpost.Name, "type": "bind", @@ -68,10 +42,3 @@ func (ls *LDAPServer) Bind(bindDN string, bindPW string, conn net.Conn) (ldap.LD }).Inc() return ldap.LDAPResultOperationsError, nil } - -func (ls *LDAPServer) TimerFlowCacheExpiry() { - for _, p := range ls.providers { - ls.log.WithField("flow", p.flowSlug).Debug("Pre-heating flow cache") - p.TimerFlowCacheExpiry() - } -} diff --git a/internal/outpost/ldap/bind/binder.go b/internal/outpost/ldap/bind/binder.go new file mode 100644 index 000000000..d0aa4a618 --- /dev/null +++ b/internal/outpost/ldap/bind/binder.go @@ -0,0 +1,9 @@ +package bind + +import "github.com/nmcclain/ldap" + +type Binder interface { + GetUsername(string) (string, error) + Bind(username string, req *Request) (ldap.LDAPResultCode, error) + TimerFlowCacheExpiry() +} diff --git a/internal/outpost/ldap/instance_bind.go b/internal/outpost/ldap/bind/direct/direct.go similarity index 53% rename from internal/outpost/ldap/instance_bind.go rename to internal/outpost/ldap/bind/direct/direct.go index f1bce89c0..ba5e91a3f 100644 --- a/internal/outpost/ldap/instance_bind.go +++ b/internal/outpost/ldap/bind/direct/direct.go @@ -1,4 +1,4 @@ -package ldap +package direct import ( "context" @@ -12,14 +12,30 @@ import ( log "github.com/sirupsen/logrus" "goauthentik.io/api" "goauthentik.io/internal/outpost" + "goauthentik.io/internal/outpost/ldap/bind" + "goauthentik.io/internal/outpost/ldap/flags" "goauthentik.io/internal/outpost/ldap/metrics" - "goauthentik.io/internal/utils" + "goauthentik.io/internal/outpost/ldap/server" ) const ContextUserKey = "ak_user" -func (pi *ProviderInstance) getUsername(dn string) (string, error) { - if !strings.HasSuffix(strings.ToLower(dn), strings.ToLower(pi.BaseDN)) { +type DirectBinder struct { + si server.LDAPServerInstance + log *log.Entry +} + +func NewDirectBinder(si server.LDAPServerInstance) *DirectBinder { + db := &DirectBinder{ + si: si, + log: log.WithField("logger", "authentik.outpost.ldap.binder.direct"), + } + db.log.Info("initialised direct binder") + return db +} + +func (db *DirectBinder) GetUsername(dn string) (string, error) { + if !strings.HasSuffix(strings.ToLower(dn), strings.ToLower(db.si.GetBaseDN())) { return "", errors.New("invalid base DN") } dns, err := goldap.ParseDN(dn) @@ -36,13 +52,13 @@ func (pi *ProviderInstance) getUsername(dn string) (string, error) { return "", errors.New("failed to find cn") } -func (pi *ProviderInstance) Bind(username string, req BindRequest) (ldap.LDAPResultCode, error) { - fe := outpost.NewFlowExecutor(req.ctx, pi.flowSlug, pi.s.ac.Client.GetConfig(), log.Fields{ +func (db *DirectBinder) Bind(username string, req *bind.Request) (ldap.LDAPResultCode, error) { + fe := outpost.NewFlowExecutor(req.Context(), db.si.GetFlowSlug(), db.si.GetAPIClient().GetConfig(), log.Fields{ "bindDN": req.BindDN, - "client": utils.GetIP(req.conn.RemoteAddr()), - "requestId": req.id, + "client": req.RemoteAddr(), + "requestId": req.ID(), }) - fe.DelegateClientIP(req.conn.RemoteAddr()) + fe.DelegateClientIP(req.RemoteAddr()) fe.Params.Add("goauthentik.io/outpost/ldap", "true") fe.Answers[outpost.StageIdentification] = username @@ -51,83 +67,82 @@ func (pi *ProviderInstance) Bind(username string, req BindRequest) (ldap.LDAPRes passed, err := fe.Execute() if !passed { metrics.RequestsRejected.With(prometheus.Labels{ - "outpost_name": pi.outpostName, + "outpost_name": db.si.GetOutpostName(), "type": "bind", "reason": "invalid_credentials", "dn": req.BindDN, - "client": utils.GetIP(req.conn.RemoteAddr()), + "client": req.RemoteAddr(), }).Inc() return ldap.LDAPResultInvalidCredentials, nil } if err != nil { metrics.RequestsRejected.With(prometheus.Labels{ - "outpost_name": pi.outpostName, + "outpost_name": db.si.GetOutpostName(), "type": "bind", "reason": "flow_error", "dn": req.BindDN, - "client": utils.GetIP(req.conn.RemoteAddr()), + "client": req.RemoteAddr(), }).Inc() - req.log.WithError(err).Warning("failed to execute flow") + req.Log().WithError(err).Warning("failed to execute flow") return ldap.LDAPResultOperationsError, nil } - access, err := fe.CheckApplicationAccess(pi.appSlug) + access, err := fe.CheckApplicationAccess(db.si.GetAppSlug()) if !access { - req.log.Info("Access denied for user") + req.Log().Info("Access denied for user") metrics.RequestsRejected.With(prometheus.Labels{ - "outpost_name": pi.outpostName, + "outpost_name": db.si.GetOutpostName(), "type": "bind", "reason": "access_denied", "dn": req.BindDN, - "client": utils.GetIP(req.conn.RemoteAddr()), + "client": req.RemoteAddr(), }).Inc() return ldap.LDAPResultInsufficientAccessRights, nil } if err != nil { metrics.RequestsRejected.With(prometheus.Labels{ - "outpost_name": pi.outpostName, + "outpost_name": db.si.GetOutpostName(), "type": "bind", "reason": "access_check_fail", "dn": req.BindDN, - "client": utils.GetIP(req.conn.RemoteAddr()), + "client": req.RemoteAddr(), }).Inc() - req.log.WithError(err).Warning("failed to check access") + req.Log().WithError(err).Warning("failed to check access") return ldap.LDAPResultOperationsError, nil } - req.log.Info("User has access") - uisp := sentry.StartSpan(req.ctx, "authentik.providers.ldap.bind.user_info") + req.Log().Info("User has access") + uisp := sentry.StartSpan(req.Context(), "authentik.providers.ldap.bind.user_info") // Get user info to store in context userInfo, _, err := fe.ApiClient().CoreApi.CoreUsersMeRetrieve(context.Background()).Execute() if err != nil { metrics.RequestsRejected.With(prometheus.Labels{ - "outpost_name": pi.outpostName, + "outpost_name": db.si.GetOutpostName(), "type": "bind", "reason": "user_info_fail", "dn": req.BindDN, - "client": utils.GetIP(req.conn.RemoteAddr()), + "client": req.RemoteAddr(), }).Inc() - req.log.WithError(err).Warning("failed to get user info") + req.Log().WithError(err).Warning("failed to get user info") return ldap.LDAPResultOperationsError, nil } - pi.boundUsersMutex.Lock() - cs := pi.SearchAccessCheck(userInfo.User) - pi.boundUsers[req.BindDN] = UserFlags{ + cs := db.SearchAccessCheck(userInfo.User) + flags := flags.UserFlags{ UserPk: userInfo.User.Pk, CanSearch: cs != nil, } - if pi.boundUsers[req.BindDN].CanSearch { - req.log.WithField("group", cs).Info("Allowed access to search") + db.si.SetFlags(req.BindDN, flags) + if flags.CanSearch { + req.Log().WithField("group", cs).Info("Allowed access to search") } uisp.Finish() - defer pi.boundUsersMutex.Unlock() return ldap.LDAPResultSuccess, nil } // SearchAccessCheck Check if the current user is allowed to search -func (pi *ProviderInstance) SearchAccessCheck(user api.UserSelf) *string { +func (db *DirectBinder) SearchAccessCheck(user api.UserSelf) *string { for _, group := range user.Groups { - for _, allowedGroup := range pi.searchAllowedGroups { - pi.log.WithField("userGroup", group.Pk).WithField("allowedGroup", allowedGroup).Trace("Checking search access") + for _, allowedGroup := range db.si.GetSearchAllowedGroups() { + db.log.WithField("userGroup", group.Pk).WithField("allowedGroup", allowedGroup).Trace("Checking search access") if group.Pk == allowedGroup.String() { return &group.Name } @@ -136,13 +151,13 @@ func (pi *ProviderInstance) SearchAccessCheck(user api.UserSelf) *string { return nil } -func (pi *ProviderInstance) TimerFlowCacheExpiry() { - fe := outpost.NewFlowExecutor(context.Background(), pi.flowSlug, pi.s.ac.Client.GetConfig(), log.Fields{}) +func (db *DirectBinder) TimerFlowCacheExpiry() { + fe := outpost.NewFlowExecutor(context.Background(), db.si.GetFlowSlug(), db.si.GetAPIClient().GetConfig(), log.Fields{}) fe.Params.Add("goauthentik.io/outpost/ldap", "true") fe.Params.Add("goauthentik.io/outpost/ldap-warmup", "true") err := fe.WarmUp() if err != nil { - pi.log.WithError(err).Warning("failed to warm up flow cache") + db.log.WithError(err).Warning("failed to warm up flow cache") } } diff --git a/internal/outpost/ldap/bind/request.go b/internal/outpost/ldap/bind/request.go new file mode 100644 index 000000000..43d379282 --- /dev/null +++ b/internal/outpost/ldap/bind/request.go @@ -0,0 +1,55 @@ +package bind + +import ( + "context" + "net" + "strings" + + "github.com/getsentry/sentry-go" + "github.com/google/uuid" + log "github.com/sirupsen/logrus" + "goauthentik.io/internal/utils" +) + +type Request struct { + BindDN string + BindPW string + id string + conn net.Conn + log *log.Entry + ctx context.Context +} + +func NewRequest(bindDN string, bindPW string, conn net.Conn) (*Request, *sentry.Span) { + span := sentry.StartSpan(context.TODO(), "authentik.providers.ldap.bind", + sentry.TransactionName("authentik.providers.ldap.bind")) + rid := uuid.New().String() + span.SetTag("request_uid", rid) + span.SetTag("user.username", bindDN) + + bindDN = strings.ToLower(bindDN) + return &Request{ + BindDN: bindDN, + BindPW: bindPW, + conn: conn, + log: log.WithField("bindDN", bindDN).WithField("requestId", rid).WithField("client", utils.GetIP(conn.RemoteAddr())), + id: rid, + ctx: span.Context(), + }, span +} + +func (r *Request) Context() context.Context { + return r.ctx +} + +func (r *Request) Log() *log.Entry { + return r.log +} + +func (r *Request) RemoteAddr() string { + return utils.GetIP(r.conn.RemoteAddr()) +} + +func (r *Request) ID() string { + return r.id +} diff --git a/internal/outpost/ldap/constants/constants.go b/internal/outpost/ldap/constants/constants.go new file mode 100644 index 000000000..d791544d9 --- /dev/null +++ b/internal/outpost/ldap/constants/constants.go @@ -0,0 +1,21 @@ +package constants + +const ( + OCGroup = "group" + OCGroupOfUniqueNames = "groupOfUniqueNames" + OCAKGroup = "goauthentik.io/ldap/group" + OCAKVirtualGroup = "goauthentik.io/ldap/virtual-group" +) + +const ( + OCUser = "user" + OCOrgPerson = "organizationalPerson" + OCInetOrgPerson = "inetOrgPerson" + OCAKUser = "goauthentik.io/ldap/user" +) + +const ( + OUUsers = "users" + OUGroups = "groups" + OUVirtualGroups = "virtual-groups" +) diff --git a/internal/outpost/ldap/entries.go b/internal/outpost/ldap/entries.go new file mode 100644 index 000000000..3cd93f809 --- /dev/null +++ b/internal/outpost/ldap/entries.go @@ -0,0 +1,39 @@ +package ldap + +import ( + "github.com/nmcclain/ldap" + "goauthentik.io/api" + "goauthentik.io/internal/outpost/ldap/constants" + "goauthentik.io/internal/outpost/ldap/group" + "goauthentik.io/internal/outpost/ldap/utils" +) + +func (pi *ProviderInstance) UserEntry(u api.User) *ldap.Entry { + dn := pi.GetUserDN(u.Username) + attrs := utils.AKAttrsToLDAP(u.Attributes) + + attrs = utils.EnsureAttributes(attrs, map[string][]string{ + "memberOf": pi.GroupsForUser(u), + // Old fields for backwards compatibility + "accountStatus": {utils.BoolToString(*u.IsActive)}, + "superuser": {utils.BoolToString(u.IsSuperuser)}, + // End old fields + "goauthentik.io/ldap/active": {utils.BoolToString(*u.IsActive)}, + "goauthentik.io/ldap/superuser": {utils.BoolToString(u.IsSuperuser)}, + "cn": {u.Username}, + "sAMAccountName": {u.Username}, + "uid": {u.Uid}, + "name": {u.Name}, + "displayName": {u.Name}, + "mail": {*u.Email}, + "objectClass": {constants.OCUser, constants.OCOrgPerson, constants.OCInetOrgPerson, constants.OCAKUser}, + "uidNumber": {pi.GetUidNumber(u)}, + "gidNumber": {pi.GetUidNumber(u)}, + }) + return &ldap.Entry{DN: dn, Attributes: attrs} +} + +func (pi *ProviderInstance) GroupEntry(g group.LDAPGroup) *ldap.Entry { + // TODO: Remove + return g.Entry() +} diff --git a/internal/outpost/ldap/flags/flags.go b/internal/outpost/ldap/flags/flags.go new file mode 100644 index 000000000..8a774d19f --- /dev/null +++ b/internal/outpost/ldap/flags/flags.go @@ -0,0 +1,9 @@ +package flags + +import "goauthentik.io/api" + +type UserFlags struct { + UserInfo *api.User + UserPk int32 + CanSearch bool +} diff --git a/internal/outpost/ldap/group/group.go b/internal/outpost/ldap/group/group.go new file mode 100644 index 000000000..b7f02a079 --- /dev/null +++ b/internal/outpost/ldap/group/group.go @@ -0,0 +1,66 @@ +package group + +import ( + "github.com/nmcclain/ldap" + "goauthentik.io/api" + "goauthentik.io/internal/outpost/ldap/constants" + "goauthentik.io/internal/outpost/ldap/server" + "goauthentik.io/internal/outpost/ldap/utils" +) + +type LDAPGroup struct { + DN string + CN string + Uid string + GidNumber string + Member []string + IsSuperuser bool + IsVirtualGroup bool + AKAttributes interface{} +} + +func (lg *LDAPGroup) Entry() *ldap.Entry { + attrs := utils.AKAttrsToLDAP(lg.AKAttributes) + + objectClass := []string{constants.OCGroup, constants.OCGroupOfUniqueNames, constants.OCAKGroup} + if lg.IsVirtualGroup { + objectClass = append(objectClass, constants.OCAKVirtualGroup) + } + + attrs = utils.EnsureAttributes(attrs, map[string][]string{ + "objectClass": objectClass, + "member": lg.Member, + "goauthentik.io/ldap/superuser": {utils.BoolToString(lg.IsSuperuser)}, + "cn": {lg.CN}, + "uid": {lg.Uid}, + "sAMAccountName": {lg.CN}, + "gidNumber": {lg.GidNumber}, + }) + return &ldap.Entry{DN: lg.DN, Attributes: attrs} +} + +func FromAPIGroup(g api.Group, si server.LDAPServerInstance) *LDAPGroup { + return &LDAPGroup{ + DN: si.GetGroupDN(g.Name), + CN: g.Name, + Uid: string(g.Pk), + GidNumber: si.GetGidNumber(g), + Member: si.UsersForGroup(g), + IsVirtualGroup: false, + IsSuperuser: *g.IsSuperuser, + AKAttributes: g.Attributes, + } +} + +func FromAPIUser(u api.User, si server.LDAPServerInstance) *LDAPGroup { + return &LDAPGroup{ + DN: si.GetVirtualGroupDN(u.Username), + CN: u.Username, + Uid: u.Uid, + GidNumber: si.GetUidNumber(u), + Member: []string{si.GetUserDN(u.Username)}, + IsVirtualGroup: true, + IsSuperuser: false, + AKAttributes: nil, + } +} diff --git a/internal/outpost/ldap/handler/handler.go b/internal/outpost/ldap/handler/handler.go new file mode 100644 index 000000000..6d1c380cf --- /dev/null +++ b/internal/outpost/ldap/handler/handler.go @@ -0,0 +1,4 @@ +package handler + +type Handler interface { +} diff --git a/internal/outpost/ldap/instance.go b/internal/outpost/ldap/instance.go new file mode 100644 index 000000000..0bacf428b --- /dev/null +++ b/internal/outpost/ldap/instance.go @@ -0,0 +1,83 @@ +package ldap + +import ( + "crypto/tls" + "sync" + + "github.com/go-openapi/strfmt" + log "github.com/sirupsen/logrus" + "goauthentik.io/api" + "goauthentik.io/internal/outpost/ldap/bind" + "goauthentik.io/internal/outpost/ldap/flags" + "goauthentik.io/internal/outpost/ldap/search" +) + +type ProviderInstance struct { + BaseDN string + UserDN string + VirtualGroupDN string + GroupDN string + + searcher search.Searcher + binder bind.Binder + + appSlug string + flowSlug string + s *LDAPServer + log *log.Entry + + tlsServerName *string + cert *tls.Certificate + outpostName string + searchAllowedGroups []*strfmt.UUID + boundUsersMutex sync.RWMutex + boundUsers map[string]flags.UserFlags + + uidStartNumber int32 + gidStartNumber int32 +} + +func (pi *ProviderInstance) GetAPIClient() *api.APIClient { + return pi.s.ac.Client +} + +func (pi *ProviderInstance) GetBaseDN() string { + return pi.BaseDN +} + +func (pi *ProviderInstance) GetBaseGroupDN() string { + return pi.GroupDN +} + +func (pi *ProviderInstance) GetBaseUserDN() string { + return pi.UserDN +} + +func (pi *ProviderInstance) GetOutpostName() string { + return pi.outpostName +} + +func (pi *ProviderInstance) GetFlags(dn string) (flags.UserFlags, bool) { + pi.boundUsersMutex.RLock() + flags, ok := pi.boundUsers[dn] + pi.boundUsersMutex.RUnlock() + return flags, ok +} + +func (pi *ProviderInstance) SetFlags(dn string, flag flags.UserFlags) { + pi.boundUsersMutex.Lock() + pi.boundUsers[dn] = flag + pi.boundUsersMutex.Unlock() +} + +func (pi *ProviderInstance) GetAppSlug() string { + return pi.appSlug +} + +func (pi *ProviderInstance) GetFlowSlug() string { + return pi.flowSlug +} + +func (pi *ProviderInstance) GetSearchAllowedGroups() []*strfmt.UUID { + return pi.searchAllowedGroups +} diff --git a/internal/outpost/ldap/instance_search.go b/internal/outpost/ldap/instance_search.go deleted file mode 100644 index 7047f76b4..000000000 --- a/internal/outpost/ldap/instance_search.go +++ /dev/null @@ -1,244 +0,0 @@ -package ldap - -import ( - "errors" - "fmt" - "strings" - "sync" - - "github.com/getsentry/sentry-go" - "github.com/nmcclain/ldap" - "github.com/prometheus/client_golang/prometheus" - "goauthentik.io/api" - "goauthentik.io/internal/outpost/ldap/metrics" - "goauthentik.io/internal/utils" -) - -func (pi *ProviderInstance) SearchMe(req SearchRequest, f UserFlags) (ldap.ServerSearchResult, error) { - if f.UserInfo == nil { - u, _, err := pi.s.ac.Client.CoreApi.CoreUsersRetrieve(req.ctx, f.UserPk).Execute() - if err != nil { - req.log.WithError(err).Warning("Failed to get user info") - return ldap.ServerSearchResult{ResultCode: ldap.LDAPResultOperationsError}, fmt.Errorf("failed to get userinfo") - } - f.UserInfo = &u - } - entries := make([]*ldap.Entry, 1) - entries[0] = pi.UserEntry(*f.UserInfo) - return ldap.ServerSearchResult{Entries: entries, Referrals: []string{}, Controls: []ldap.Control{}, ResultCode: ldap.LDAPResultSuccess}, nil -} - -func (pi *ProviderInstance) Search(req SearchRequest) (ldap.ServerSearchResult, error) { - accsp := sentry.StartSpan(req.ctx, "authentik.providers.ldap.search.check_access") - baseDN := strings.ToLower("," + pi.BaseDN) - - entries := []*ldap.Entry{} - filterEntity, err := ldap.GetFilterObjectClass(req.Filter) - if err != nil { - metrics.RequestsRejected.With(prometheus.Labels{ - "outpost_name": pi.outpostName, - "type": "search", - "reason": "filter_parse_fail", - "dn": req.BindDN, - "client": utils.GetIP(req.conn.RemoteAddr()), - }).Inc() - return ldap.ServerSearchResult{ResultCode: ldap.LDAPResultOperationsError}, fmt.Errorf("Search Error: error parsing filter: %s", req.Filter) - } - if len(req.BindDN) < 1 { - metrics.RequestsRejected.With(prometheus.Labels{ - "outpost_name": pi.outpostName, - "type": "search", - "reason": "empty_bind_dn", - "dn": req.BindDN, - "client": utils.GetIP(req.conn.RemoteAddr()), - }).Inc() - return ldap.ServerSearchResult{ResultCode: ldap.LDAPResultInsufficientAccessRights}, fmt.Errorf("Search Error: Anonymous BindDN not allowed %s", req.BindDN) - } - if !strings.HasSuffix(req.BindDN, baseDN) { - metrics.RequestsRejected.With(prometheus.Labels{ - "outpost_name": pi.outpostName, - "type": "search", - "reason": "invalid_bind_dn", - "dn": req.BindDN, - "client": utils.GetIP(req.conn.RemoteAddr()), - }).Inc() - return ldap.ServerSearchResult{ResultCode: ldap.LDAPResultInsufficientAccessRights}, fmt.Errorf("Search Error: BindDN %s not in our BaseDN %s", req.BindDN, pi.BaseDN) - } - - pi.boundUsersMutex.RLock() - flags, ok := pi.boundUsers[req.BindDN] - pi.boundUsersMutex.RUnlock() - if !ok { - pi.log.Debug("User info not cached") - metrics.RequestsRejected.With(prometheus.Labels{ - "outpost_name": pi.outpostName, - "type": "search", - "reason": "user_info_not_cached", - "dn": req.BindDN, - "client": utils.GetIP(req.conn.RemoteAddr()), - }).Inc() - return ldap.ServerSearchResult{ResultCode: ldap.LDAPResultInsufficientAccessRights}, errors.New("access denied") - } - - if req.SearchRequest.Scope == ldap.ScopeBaseObject { - pi.log.Debug("base scope, showing domain info") - return pi.SearchBase(req, flags.CanSearch) - } - if !flags.CanSearch { - pi.log.Debug("User can't search, showing info about user") - return pi.SearchMe(req, flags) - } - accsp.Finish() - - parsedFilter, err := ldap.CompileFilter(req.Filter) - if err != nil { - metrics.RequestsRejected.With(prometheus.Labels{ - "outpost_name": pi.outpostName, - "type": "search", - "reason": "filter_parse_fail", - "dn": req.BindDN, - "client": utils.GetIP(req.conn.RemoteAddr()), - }).Inc() - return ldap.ServerSearchResult{ResultCode: ldap.LDAPResultOperationsError}, fmt.Errorf("Search Error: error parsing filter: %s", req.Filter) - } - - // Create a custom client to set additional headers - c := api.NewAPIClient(pi.s.ac.Client.GetConfig()) - c.GetConfig().AddDefaultHeader("X-authentik-outpost-ldap-query", req.Filter) - - switch filterEntity { - default: - metrics.RequestsRejected.With(prometheus.Labels{ - "outpost_name": pi.outpostName, - "type": "search", - "reason": "unhandled_filter_type", - "dn": req.BindDN, - "client": utils.GetIP(req.conn.RemoteAddr()), - }).Inc() - return ldap.ServerSearchResult{ResultCode: ldap.LDAPResultOperationsError}, fmt.Errorf("Search Error: unhandled filter type: %s [%s]", filterEntity, req.Filter) - case "groupOfUniqueNames": - fallthrough - case "goauthentik.io/ldap/group": - fallthrough - case "goauthentik.io/ldap/virtual-group": - fallthrough - case GroupObjectClass: - wg := sync.WaitGroup{} - wg.Add(2) - - gEntries := make([]*ldap.Entry, 0) - uEntries := make([]*ldap.Entry, 0) - - go func() { - defer wg.Done() - gapisp := sentry.StartSpan(req.ctx, "authentik.providers.ldap.search.api_group") - searchReq, skip := parseFilterForGroup(c.CoreApi.CoreGroupsList(gapisp.Context()), parsedFilter, false) - if skip { - pi.log.Trace("Skip backend request") - return - } - groups, _, err := searchReq.Execute() - gapisp.Finish() - if err != nil { - req.log.WithError(err).Warning("failed to get groups") - return - } - pi.log.WithField("count", len(groups.Results)).Trace("Got results from API") - - for _, g := range groups.Results { - gEntries = append(gEntries, pi.GroupEntry(pi.APIGroupToLDAPGroup(g))) - } - }() - - go func() { - defer wg.Done() - uapisp := sentry.StartSpan(req.ctx, "authentik.providers.ldap.search.api_user") - searchReq, skip := parseFilterForUser(c.CoreApi.CoreUsersList(uapisp.Context()), parsedFilter, false) - if skip { - pi.log.Trace("Skip backend request") - return - } - users, _, err := searchReq.Execute() - uapisp.Finish() - if err != nil { - req.log.WithError(err).Warning("failed to get users") - return - } - - for _, u := range users.Results { - uEntries = append(uEntries, pi.GroupEntry(pi.APIUserToLDAPGroup(u))) - } - }() - wg.Wait() - entries = append(gEntries, uEntries...) - case "": - fallthrough - case "organizationalPerson": - fallthrough - case "inetOrgPerson": - fallthrough - case "goauthentik.io/ldap/user": - fallthrough - case UserObjectClass: - uapisp := sentry.StartSpan(req.ctx, "authentik.providers.ldap.search.api_user") - searchReq, skip := parseFilterForUser(c.CoreApi.CoreUsersList(uapisp.Context()), parsedFilter, false) - if skip { - pi.log.Trace("Skip backend request") - return ldap.ServerSearchResult{Entries: entries, Referrals: []string{}, Controls: []ldap.Control{}, ResultCode: ldap.LDAPResultSuccess}, nil - } - users, _, err := searchReq.Execute() - uapisp.Finish() - - if err != nil { - return ldap.ServerSearchResult{ResultCode: ldap.LDAPResultOperationsError}, fmt.Errorf("API Error: %s", err) - } - for _, u := range users.Results { - entries = append(entries, pi.UserEntry(u)) - } - } - return ldap.ServerSearchResult{Entries: entries, Referrals: []string{}, Controls: []ldap.Control{}, ResultCode: ldap.LDAPResultSuccess}, nil -} - -func (pi *ProviderInstance) UserEntry(u api.User) *ldap.Entry { - dn := pi.GetUserDN(u.Username) - attrs := AKAttrsToLDAP(u.Attributes) - - attrs = pi.ensureAttributes(attrs, map[string][]string{ - "memberOf": pi.GroupsForUser(u), - // Old fields for backwards compatibility - "accountStatus": {BoolToString(*u.IsActive)}, - "superuser": {BoolToString(u.IsSuperuser)}, - "goauthentik.io/ldap/active": {BoolToString(*u.IsActive)}, - "goauthentik.io/ldap/superuser": {BoolToString(u.IsSuperuser)}, - "cn": {u.Username}, - "sAMAccountName": {u.Username}, - "uid": {u.Uid}, - "name": {u.Name}, - "displayName": {u.Name}, - "mail": {*u.Email}, - "objectClass": {UserObjectClass, "organizationalPerson", "inetOrgPerson", "goauthentik.io/ldap/user"}, - "uidNumber": {pi.GetUidNumber(u)}, - "gidNumber": {pi.GetUidNumber(u)}, - }) - return &ldap.Entry{DN: dn, Attributes: attrs} -} - -func (pi *ProviderInstance) GroupEntry(g LDAPGroup) *ldap.Entry { - attrs := AKAttrsToLDAP(g.akAttributes) - - objectClass := []string{GroupObjectClass, "groupOfUniqueNames", "goauthentik.io/ldap/group"} - if g.isVirtualGroup { - objectClass = append(objectClass, "goauthentik.io/ldap/virtual-group") - } - - attrs = pi.ensureAttributes(attrs, map[string][]string{ - "objectClass": objectClass, - "member": g.member, - "goauthentik.io/ldap/superuser": {BoolToString(g.isSuperuser)}, - "cn": {g.cn}, - "uid": {g.uid}, - "sAMAccountName": {g.cn}, - "gidNumber": {g.gidNumber}, - }) - return &ldap.Entry{DN: g.dn, Attributes: attrs} -} diff --git a/internal/outpost/ldap/ldap.go b/internal/outpost/ldap/ldap.go index 15a044972..96d2df84d 100644 --- a/internal/outpost/ldap/ldap.go +++ b/internal/outpost/ldap/ldap.go @@ -2,50 +2,18 @@ package ldap import ( "crypto/tls" + "net" "sync" - "github.com/go-openapi/strfmt" + "github.com/pires/go-proxyproto" log "github.com/sirupsen/logrus" - "goauthentik.io/api" "goauthentik.io/internal/crypto" "goauthentik.io/internal/outpost/ak" + "goauthentik.io/internal/outpost/ldap/metrics" "github.com/nmcclain/ldap" ) -const GroupObjectClass = "group" -const UserObjectClass = "user" - -type ProviderInstance struct { - BaseDN string - - UserDN string - - VirtualGroupDN string - GroupDN string - - appSlug string - flowSlug string - s *LDAPServer - log *log.Entry - - tlsServerName *string - cert *tls.Certificate - outpostName string - searchAllowedGroups []*strfmt.UUID - boundUsersMutex sync.RWMutex - boundUsers map[string]UserFlags - - uidStartNumber int32 - gidStartNumber int32 -} - -type UserFlags struct { - UserInfo *api.User - UserPk int32 - CanSearch bool -} - type LDAPServer struct { s *ldap.Server log *log.Entry @@ -55,17 +23,6 @@ type LDAPServer struct { providers []*ProviderInstance } -type LDAPGroup struct { - dn string - cn string - uid string - gidNumber string - member []string - isSuperuser bool - isVirtualGroup bool - akAttributes interface{} -} - func NewServer(ac *ak.APIController) *LDAPServer { s := ldap.NewServer() s.EnforceLDAP = true @@ -90,3 +47,54 @@ func NewServer(ac *ak.APIController) *LDAPServer { func (ls *LDAPServer) Type() string { return "ldap" } + +func (ls *LDAPServer) StartLDAPServer() error { + listen := "0.0.0.0:3389" + + ln, err := net.Listen("tcp", listen) + if err != nil { + ls.log.WithField("listen", listen).WithError(err).Fatalf("FATAL: listen failed") + } + proxyListener := &proxyproto.Listener{Listener: ln} + defer proxyListener.Close() + + ls.log.WithField("listen", listen).Info("Starting ldap server") + err = ls.s.Serve(proxyListener) + if err != nil { + return err + } + ls.log.Printf("closing %s", ln.Addr()) + return ls.s.ListenAndServe(listen) +} + +func (ls *LDAPServer) Start() error { + wg := sync.WaitGroup{} + wg.Add(3) + go func() { + defer wg.Done() + metrics.RunServer() + }() + go func() { + defer wg.Done() + err := ls.StartLDAPServer() + if err != nil { + panic(err) + } + }() + go func() { + defer wg.Done() + err := ls.StartLDAPTLSServer() + if err != nil { + panic(err) + } + }() + wg.Wait() + return nil +} + +func (ls *LDAPServer) TimerFlowCacheExpiry() { + for _, p := range ls.providers { + ls.log.WithField("flow", p.flowSlug).Debug("Pre-heating flow cache") + p.binder.TimerFlowCacheExpiry() + } +} diff --git a/internal/outpost/ldap/ldap_tls.go b/internal/outpost/ldap/ldap_tls.go new file mode 100644 index 000000000..f42d76a67 --- /dev/null +++ b/internal/outpost/ldap/ldap_tls.go @@ -0,0 +1,55 @@ +package ldap + +import ( + "crypto/tls" + "net" + + "github.com/pires/go-proxyproto" +) + +func (ls *LDAPServer) getCertificates(info *tls.ClientHelloInfo) (*tls.Certificate, error) { + if len(ls.providers) == 1 { + if ls.providers[0].cert != nil { + ls.log.WithField("server-name", info.ServerName).Debug("We only have a single provider, using their cert") + return ls.providers[0].cert, nil + } + } + for _, provider := range ls.providers { + if provider.tlsServerName == &info.ServerName { + if provider.cert == nil { + ls.log.WithField("server-name", info.ServerName).Debug("Handler does not have a certificate") + return ls.defaultCert, nil + } + return provider.cert, nil + } + } + ls.log.WithField("server-name", info.ServerName).Debug("Fallback to default cert") + return ls.defaultCert, nil +} + +func (ls *LDAPServer) StartLDAPTLSServer() error { + listen := "0.0.0.0:6636" + tlsConfig := &tls.Config{ + MinVersion: tls.VersionTLS12, + MaxVersion: tls.VersionTLS12, + GetCertificate: ls.getCertificates, + } + + ln, err := net.Listen("tcp", listen) + if err != nil { + ls.log.WithField("listen", listen).WithError(err).Fatalf("FATAL: listen failed") + } + + proxyListener := &proxyproto.Listener{Listener: ln} + defer proxyListener.Close() + + tln := tls.NewListener(proxyListener, tlsConfig) + + ls.log.WithField("listen", listen).Info("Starting ldap tls server") + err = ls.s.Serve(tln) + if err != nil { + return err + } + ls.log.Printf("closing %s", ln.Addr()) + return ls.s.ListenAndServe(listen) +} diff --git a/internal/outpost/ldap/refresh.go b/internal/outpost/ldap/refresh.go index 5aee26985..1ec37ead9 100644 --- a/internal/outpost/ldap/refresh.go +++ b/internal/outpost/ldap/refresh.go @@ -2,23 +2,19 @@ package ldap import ( "context" - "crypto/tls" "errors" "fmt" - "net" "strings" "sync" "github.com/go-openapi/strfmt" - "github.com/pires/go-proxyproto" log "github.com/sirupsen/logrus" - "goauthentik.io/internal/outpost/ldap/metrics" -) - -const ( - UsersOU = "users" - GroupsOU = "groups" - VirtualGroupsOU = "virtual-groups" + "goauthentik.io/api" + directbind "goauthentik.io/internal/outpost/ldap/bind/direct" + "goauthentik.io/internal/outpost/ldap/constants" + "goauthentik.io/internal/outpost/ldap/flags" + directsearch "goauthentik.io/internal/outpost/ldap/search/direct" + memorysearch "goauthentik.io/internal/outpost/ldap/search/memory" ) func (ls *LDAPServer) Refresh() error { @@ -31,9 +27,9 @@ func (ls *LDAPServer) Refresh() error { } providers := make([]*ProviderInstance, len(outposts.Results)) for idx, provider := range outposts.Results { - userDN := strings.ToLower(fmt.Sprintf("ou=%s,%s", UsersOU, *provider.BaseDn)) - groupDN := strings.ToLower(fmt.Sprintf("ou=%s,%s", GroupsOU, *provider.BaseDn)) - virtualGroupDN := strings.ToLower(fmt.Sprintf("ou=%s,%s", VirtualGroupsOU, *provider.BaseDn)) + userDN := strings.ToLower(fmt.Sprintf("ou=%s,%s", constants.OUUsers, *provider.BaseDn)) + groupDN := strings.ToLower(fmt.Sprintf("ou=%s,%s", constants.OUGroups, *provider.BaseDn)) + virtualGroupDN := strings.ToLower(fmt.Sprintf("ou=%s,%s", constants.OUVirtualGroups, *provider.BaseDn)) logger := log.WithField("logger", "authentik.outpost.ldap").WithField("provider", provider.Name) providers[idx] = &ProviderInstance{ BaseDN: *provider.BaseDn, @@ -44,7 +40,7 @@ func (ls *LDAPServer) Refresh() error { flowSlug: provider.BindFlowSlug, searchAllowedGroups: []*strfmt.UUID{(*strfmt.UUID)(provider.SearchGroup.Get())}, boundUsersMutex: sync.RWMutex{}, - boundUsers: make(map[string]UserFlags), + boundUsers: make(map[string]flags.UserFlags), s: ls, log: logger, tlsServerName: provider.TlsServerName, @@ -60,79 +56,14 @@ func (ls *LDAPServer) Refresh() error { } providers[idx].cert = ls.cs.Get(*kp) } + if *provider.SearchMode.Ptr() == api.SEARCHMODEENUM_CACHED { + providers[idx].searcher = memorysearch.NewMemorySearcher(providers[idx]) + } else if *provider.SearchMode.Ptr() == api.SEARCHMODEENUM_DIRECT { + providers[idx].searcher = directsearch.NewDirectSearcher(providers[idx]) + } + providers[idx].binder = directbind.NewDirectBinder(providers[idx]) } ls.providers = providers ls.log.Info("Update providers") return nil } - -func (ls *LDAPServer) StartLDAPServer() error { - listen := "0.0.0.0:3389" - - ln, err := net.Listen("tcp", listen) - if err != nil { - ls.log.Fatalf("FATAL: listen (%s) failed - %s", listen, err) - } - proxyListener := &proxyproto.Listener{Listener: ln} - defer proxyListener.Close() - - ls.log.WithField("listen", listen).Info("Starting ldap server") - err = ls.s.Serve(proxyListener) - if err != nil { - return err - } - ls.log.Printf("closing %s", ln.Addr()) - return ls.s.ListenAndServe(listen) -} - -func (ls *LDAPServer) StartLDAPTLSServer() error { - listen := "0.0.0.0:6636" - tlsConfig := &tls.Config{ - MinVersion: tls.VersionTLS12, - MaxVersion: tls.VersionTLS12, - GetCertificate: ls.getCertificates, - } - - ln, err := net.Listen("tcp", listen) - if err != nil { - ls.log.Fatalf("FATAL: listen (%s) failed - %s", listen, err) - } - - proxyListener := &proxyproto.Listener{Listener: ln} - defer proxyListener.Close() - - tln := tls.NewListener(proxyListener, tlsConfig) - - ls.log.WithField("listen", listen).Info("Starting ldap tls server") - err = ls.s.Serve(tln) - if err != nil { - return err - } - ls.log.Printf("closing %s", ln.Addr()) - return ls.s.ListenAndServe(listen) -} - -func (ls *LDAPServer) Start() error { - wg := sync.WaitGroup{} - wg.Add(3) - go func() { - defer wg.Done() - metrics.RunServer() - }() - go func() { - defer wg.Done() - err := ls.StartLDAPServer() - if err != nil { - panic(err) - } - }() - go func() { - defer wg.Done() - err := ls.StartLDAPTLSServer() - if err != nil { - panic(err) - } - }() - wg.Wait() - return nil -} diff --git a/internal/outpost/ldap/search.go b/internal/outpost/ldap/search.go index 8e31512b6..99ad7b142 100644 --- a/internal/outpost/ldap/search.go +++ b/internal/outpost/ldap/search.go @@ -1,47 +1,21 @@ package ldap import ( - "context" "errors" "net" - "strings" "github.com/getsentry/sentry-go" goldap "github.com/go-ldap/ldap/v3" - "github.com/google/uuid" "github.com/nmcclain/ldap" "github.com/prometheus/client_golang/prometheus" log "github.com/sirupsen/logrus" "goauthentik.io/internal/outpost/ldap/metrics" + "goauthentik.io/internal/outpost/ldap/search" "goauthentik.io/internal/utils" ) -type SearchRequest struct { - ldap.SearchRequest - BindDN string - id string - conn net.Conn - log *log.Entry - ctx context.Context -} - func (ls *LDAPServer) Search(bindDN string, searchReq ldap.SearchRequest, conn net.Conn) (ldap.ServerSearchResult, error) { - span := sentry.StartSpan(context.TODO(), "authentik.providers.ldap.search", sentry.TransactionName("authentik.providers.ldap.search")) - rid := uuid.New().String() - span.SetTag("request_uid", rid) - span.SetTag("user.username", bindDN) - span.SetTag("ak_filter", searchReq.Filter) - span.SetTag("ak_base_dn", searchReq.BaseDN) - - bindDN = strings.ToLower(bindDN) - req := SearchRequest{ - SearchRequest: searchReq, - BindDN: bindDN, - conn: conn, - log: ls.log.WithField("bindDN", bindDN).WithField("requestId", rid).WithField("scope", ldap.ScopeMap[searchReq.Scope]).WithField("client", utils.GetIP(conn.RemoteAddr())).WithField("filter", searchReq.Filter).WithField("baseDN", searchReq.BaseDN), - id: rid, - ctx: span.Context(), - } + req, span := search.NewRequest(bindDN, searchReq, conn) defer func() { span.Finish() @@ -50,9 +24,9 @@ func (ls *LDAPServer) Search(bindDN string, searchReq ldap.SearchRequest, conn n "type": "search", "filter": req.Filter, "dn": req.BindDN, - "client": utils.GetIP(req.conn.RemoteAddr()), + "client": utils.GetIP(conn.RemoteAddr()), }).Observe(float64(span.EndTime.Sub(span.StartTime))) - req.log.WithField("took-ms", span.EndTime.Sub(span.StartTime).Milliseconds()).Info("Search request") + req.Log().WithField("took-ms", span.EndTime.Sub(span.StartTime).Milliseconds()).Info("Search request") }() defer func() { @@ -69,13 +43,13 @@ func (ls *LDAPServer) Search(bindDN string, searchReq ldap.SearchRequest, conn n } bd, err := goldap.ParseDN(searchReq.BaseDN) if err != nil { - req.log.WithError(err).Info("failed to parse basedn") + req.Log().WithError(err).Info("failed to parse basedn") return ldap.ServerSearchResult{ResultCode: ldap.LDAPResultOperationsError}, errors.New("invalid DN") } for _, provider := range ls.providers { providerBase, _ := goldap.ParseDN(provider.BaseDN) if providerBase.AncestorOf(bd) || providerBase.Equal(bd) { - return provider.Search(req) + return provider.searcher.Search(req) } } return ldap.ServerSearchResult{ResultCode: ldap.LDAPResultOperationsError}, errors.New("no provider could handle request") diff --git a/internal/outpost/ldap/instance_search_base.go b/internal/outpost/ldap/search/direct/base.go similarity index 79% rename from internal/outpost/ldap/instance_search_base.go rename to internal/outpost/ldap/search/direct/base.go index 9411a314c..5688a69d8 100644 --- a/internal/outpost/ldap/instance_search_base.go +++ b/internal/outpost/ldap/search/direct/base.go @@ -1,13 +1,14 @@ -package ldap +package direct import ( "fmt" "github.com/nmcclain/ldap" "goauthentik.io/internal/constants" + "goauthentik.io/internal/outpost/ldap/search" ) -func (pi *ProviderInstance) SearchBase(req SearchRequest, authz bool) (ldap.ServerSearchResult, error) { +func (ds *DirectSearcher) SearchBase(req *search.Request, authz bool) (ldap.ServerSearchResult, error) { dn := "" if authz { dn = req.SearchRequest.BaseDN @@ -19,7 +20,7 @@ func (pi *ProviderInstance) SearchBase(req SearchRequest, authz bool) (ldap.Serv Attributes: []*ldap.EntryAttribute{ { Name: "distinguishedName", - Values: []string{pi.BaseDN}, + Values: []string{ds.si.GetBaseDN()}, }, { Name: "objectClass", @@ -32,9 +33,9 @@ func (pi *ProviderInstance) SearchBase(req SearchRequest, authz bool) (ldap.Serv { Name: "namingContexts", Values: []string{ - pi.BaseDN, - pi.GroupDN, - pi.UserDN, + ds.si.GetBaseDN(), + ds.si.GetBaseUserDN(), + ds.si.GetBaseGroupDN(), }, }, { diff --git a/internal/outpost/ldap/search/direct/direct.go b/internal/outpost/ldap/search/direct/direct.go new file mode 100644 index 000000000..eda11b2bb --- /dev/null +++ b/internal/outpost/ldap/search/direct/direct.go @@ -0,0 +1,219 @@ +package direct + +import ( + "errors" + "fmt" + "strings" + "sync" + + log "github.com/sirupsen/logrus" + + "github.com/getsentry/sentry-go" + "github.com/nmcclain/ldap" + "github.com/prometheus/client_golang/prometheus" + "goauthentik.io/api" + "goauthentik.io/internal/outpost/ldap/constants" + "goauthentik.io/internal/outpost/ldap/flags" + "goauthentik.io/internal/outpost/ldap/group" + "goauthentik.io/internal/outpost/ldap/metrics" + "goauthentik.io/internal/outpost/ldap/search" + "goauthentik.io/internal/outpost/ldap/server" + "goauthentik.io/internal/outpost/ldap/utils" +) + +type DirectSearcher struct { + si server.LDAPServerInstance + log *log.Entry +} + +func NewDirectSearcher(si server.LDAPServerInstance) *DirectSearcher { + ds := &DirectSearcher{ + si: si, + log: log.WithField("logger", "authentik.outpost.ldap.searcher.direct"), + } + ds.log.Info("initialised direct searcher") + return ds +} + +func (ds *DirectSearcher) SearchMe(req *search.Request, f flags.UserFlags) (ldap.ServerSearchResult, error) { + if f.UserInfo == nil { + u, _, err := ds.si.GetAPIClient().CoreApi.CoreUsersRetrieve(req.Context(), f.UserPk).Execute() + if err != nil { + req.Log().WithError(err).Warning("Failed to get user info") + return ldap.ServerSearchResult{ResultCode: ldap.LDAPResultOperationsError}, fmt.Errorf("failed to get userinfo") + } + f.UserInfo = &u + } + entries := make([]*ldap.Entry, 1) + entries[0] = ds.si.UserEntry(*f.UserInfo) + return ldap.ServerSearchResult{Entries: entries, Referrals: []string{}, Controls: []ldap.Control{}, ResultCode: ldap.LDAPResultSuccess}, nil +} + +func (ds *DirectSearcher) Search(req *search.Request) (ldap.ServerSearchResult, error) { + accsp := sentry.StartSpan(req.Context(), "authentik.providers.ldap.search.check_access") + baseDN := strings.ToLower("," + ds.si.GetBaseDN()) + + entries := []*ldap.Entry{} + filterEntity, err := ldap.GetFilterObjectClass(req.Filter) + if err != nil { + metrics.RequestsRejected.With(prometheus.Labels{ + "outpost_name": ds.si.GetOutpostName(), + "type": "search", + "reason": "filter_parse_fail", + "dn": req.BindDN, + "client": req.RemoteAddr(), + }).Inc() + return ldap.ServerSearchResult{ResultCode: ldap.LDAPResultOperationsError}, fmt.Errorf("Search Error: error parsing filter: %s", req.Filter) + } + if len(req.BindDN) < 1 { + metrics.RequestsRejected.With(prometheus.Labels{ + "outpost_name": ds.si.GetOutpostName(), + "type": "search", + "reason": "empty_bind_dn", + "dn": req.BindDN, + "client": req.RemoteAddr(), + }).Inc() + return ldap.ServerSearchResult{ResultCode: ldap.LDAPResultInsufficientAccessRights}, fmt.Errorf("Search Error: Anonymous BindDN not allowed %s", req.BindDN) + } + if !strings.HasSuffix(req.BindDN, baseDN) { + metrics.RequestsRejected.With(prometheus.Labels{ + "outpost_name": ds.si.GetOutpostName(), + "type": "search", + "reason": "invalid_bind_dn", + "dn": req.BindDN, + "client": req.RemoteAddr(), + }).Inc() + return ldap.ServerSearchResult{ResultCode: ldap.LDAPResultInsufficientAccessRights}, fmt.Errorf("Search Error: BindDN %s not in our BaseDN %s", req.BindDN, ds.si.GetBaseDN()) + } + + flags, ok := ds.si.GetFlags(req.BindDN) + if !ok { + req.Log().Debug("User info not cached") + metrics.RequestsRejected.With(prometheus.Labels{ + "outpost_name": ds.si.GetOutpostName(), + "type": "search", + "reason": "user_info_not_cached", + "dn": req.BindDN, + "client": req.RemoteAddr(), + }).Inc() + return ldap.ServerSearchResult{ResultCode: ldap.LDAPResultInsufficientAccessRights}, errors.New("access denied") + } + + if req.Scope == ldap.ScopeBaseObject { + req.Log().Debug("base scope, showing domain info") + return ds.SearchBase(req, flags.CanSearch) + } + if !flags.CanSearch { + req.Log().Debug("User can't search, showing info about user") + return ds.SearchMe(req, flags) + } + accsp.Finish() + + parsedFilter, err := ldap.CompileFilter(req.Filter) + if err != nil { + metrics.RequestsRejected.With(prometheus.Labels{ + "outpost_name": ds.si.GetOutpostName(), + "type": "search", + "reason": "filter_parse_fail", + "dn": req.BindDN, + "client": req.RemoteAddr(), + }).Inc() + return ldap.ServerSearchResult{ResultCode: ldap.LDAPResultOperationsError}, fmt.Errorf("Search Error: error parsing filter: %s", req.Filter) + } + + // Create a custom client to set additional headers + c := api.NewAPIClient(ds.si.GetAPIClient().GetConfig()) + c.GetConfig().AddDefaultHeader("X-authentik-outpost-ldap-query", req.Filter) + + switch filterEntity { + default: + metrics.RequestsRejected.With(prometheus.Labels{ + "outpost_name": ds.si.GetOutpostName(), + "type": "search", + "reason": "unhandled_filter_type", + "dn": req.BindDN, + "client": req.RemoteAddr(), + }).Inc() + return ldap.ServerSearchResult{ResultCode: ldap.LDAPResultOperationsError}, fmt.Errorf("Search Error: unhandled filter type: %s [%s]", filterEntity, req.Filter) + case constants.OCGroupOfUniqueNames: + fallthrough + case constants.OCAKGroup: + fallthrough + case constants.OCAKVirtualGroup: + fallthrough + case constants.OCGroup: + wg := sync.WaitGroup{} + wg.Add(2) + + gEntries := make([]*ldap.Entry, 0) + uEntries := make([]*ldap.Entry, 0) + + go func() { + defer wg.Done() + gapisp := sentry.StartSpan(req.Context(), "authentik.providers.ldap.search.api_group") + searchReq, skip := utils.ParseFilterForGroup(c.CoreApi.CoreGroupsList(gapisp.Context()), parsedFilter, false) + if skip { + req.Log().Trace("Skip backend request") + return + } + groups, _, err := searchReq.Execute() + gapisp.Finish() + if err != nil { + req.Log().WithError(err).Warning("failed to get groups") + return + } + req.Log().WithField("count", len(groups.Results)).Trace("Got results from API") + + for _, g := range groups.Results { + gEntries = append(gEntries, group.FromAPIGroup(g, ds.si).Entry()) + } + }() + + go func() { + defer wg.Done() + uapisp := sentry.StartSpan(req.Context(), "authentik.providers.ldap.search.api_user") + searchReq, skip := utils.ParseFilterForUser(c.CoreApi.CoreUsersList(uapisp.Context()), parsedFilter, false) + if skip { + req.Log().Trace("Skip backend request") + return + } + users, _, err := searchReq.Execute() + uapisp.Finish() + if err != nil { + req.Log().WithError(err).Warning("failed to get users") + return + } + + for _, u := range users.Results { + uEntries = append(uEntries, group.FromAPIUser(u, ds.si).Entry()) + } + }() + wg.Wait() + entries = append(gEntries, uEntries...) + case "": + fallthrough + case constants.OCOrgPerson: + fallthrough + case constants.OCInetOrgPerson: + fallthrough + case constants.OCAKUser: + fallthrough + case constants.OCUser: + uapisp := sentry.StartSpan(req.Context(), "authentik.providers.ldap.search.api_user") + searchReq, skip := utils.ParseFilterForUser(c.CoreApi.CoreUsersList(uapisp.Context()), parsedFilter, false) + if skip { + req.Log().Trace("Skip backend request") + return ldap.ServerSearchResult{Entries: entries, Referrals: []string{}, Controls: []ldap.Control{}, ResultCode: ldap.LDAPResultSuccess}, nil + } + users, _, err := searchReq.Execute() + uapisp.Finish() + + if err != nil { + return ldap.ServerSearchResult{ResultCode: ldap.LDAPResultOperationsError}, fmt.Errorf("API Error: %s", err) + } + for _, u := range users.Results { + entries = append(entries, ds.si.UserEntry(u)) + } + } + return ldap.ServerSearchResult{Entries: entries, Referrals: []string{}, Controls: []ldap.Control{}, ResultCode: ldap.LDAPResultSuccess}, nil +} diff --git a/internal/outpost/ldap/search/memory/base.go b/internal/outpost/ldap/search/memory/base.go new file mode 100644 index 000000000..123d4a7ad --- /dev/null +++ b/internal/outpost/ldap/search/memory/base.go @@ -0,0 +1,54 @@ +package memory + +import ( + "fmt" + + "github.com/nmcclain/ldap" + "goauthentik.io/internal/constants" + "goauthentik.io/internal/outpost/ldap/search" +) + +func (ms *MemorySearcher) SearchBase(req *search.Request, authz bool) (ldap.ServerSearchResult, error) { + dn := "" + if authz { + dn = req.SearchRequest.BaseDN + } + return ldap.ServerSearchResult{ + Entries: []*ldap.Entry{ + { + DN: dn, + Attributes: []*ldap.EntryAttribute{ + { + Name: "distinguishedName", + Values: []string{ms.si.GetBaseDN()}, + }, + { + Name: "objectClass", + Values: []string{"top", "domain"}, + }, + { + Name: "supportedLDAPVersion", + Values: []string{"3"}, + }, + { + Name: "namingContexts", + Values: []string{ + ms.si.GetBaseDN(), + ms.si.GetBaseUserDN(), + ms.si.GetBaseGroupDN(), + }, + }, + { + Name: "vendorName", + Values: []string{"goauthentik.io"}, + }, + { + Name: "vendorVersion", + Values: []string{fmt.Sprintf("authentik LDAP Outpost Version %s (build %s)", constants.VERSION, constants.BUILD())}, + }, + }, + }, + }, + Referrals: []string{}, Controls: []ldap.Control{}, ResultCode: ldap.LDAPResultSuccess, + }, nil +} diff --git a/internal/outpost/ldap/search/memory/fetch.go b/internal/outpost/ldap/search/memory/fetch.go new file mode 100644 index 000000000..c908848a6 --- /dev/null +++ b/internal/outpost/ldap/search/memory/fetch.go @@ -0,0 +1,63 @@ +package memory + +import ( + "context" + + "goauthentik.io/api" +) + +const pageSize = 100 + +func (ms *MemorySearcher) FetchUsers() []api.User { + fetchUsersOffset := func(page int) (*api.PaginatedUserList, error) { + users, _, err := ms.si.GetAPIClient().CoreApi.CoreUsersList(context.TODO()).Page(int32(page)).PageSize(pageSize).Execute() + if err != nil { + ms.log.WithError(err).Warning("failed to update users") + return nil, err + } + ms.log.WithField("page", page).Debug("fetched users") + return &users, nil + } + page := 1 + users := make([]api.User, 0) + for { + apiUsers, err := fetchUsersOffset(page) + if err != nil { + return users + } + if apiUsers.Pagination.Next > 0 { + page += 1 + } else { + break + } + users = append(users, apiUsers.Results...) + } + return users +} + +func (ms *MemorySearcher) FetchGroups() []api.Group { + fetchGroupsOffset := func(page int) (*api.PaginatedGroupList, error) { + groups, _, err := ms.si.GetAPIClient().CoreApi.CoreGroupsList(context.TODO()).Page(int32(page)).PageSize(pageSize).Execute() + if err != nil { + ms.log.WithError(err).Warning("failed to update groups") + return nil, err + } + ms.log.WithField("page", page).Debug("fetched groups") + return &groups, nil + } + page := 1 + groups := make([]api.Group, 0) + for { + apiGroups, err := fetchGroupsOffset(page) + if err != nil { + return groups + } + if apiGroups.Pagination.Next > 0 { + page += 1 + } else { + break + } + groups = append(groups, apiGroups.Results...) + } + return groups +} diff --git a/internal/outpost/ldap/search/memory/memory.go b/internal/outpost/ldap/search/memory/memory.go new file mode 100644 index 000000000..687e2c207 --- /dev/null +++ b/internal/outpost/ldap/search/memory/memory.go @@ -0,0 +1,182 @@ +package memory + +import ( + "errors" + "fmt" + "strings" + "sync" + + "github.com/getsentry/sentry-go" + "github.com/nmcclain/ldap" + "github.com/prometheus/client_golang/prometheus" + log "github.com/sirupsen/logrus" + "goauthentik.io/api" + "goauthentik.io/internal/outpost/ldap/constants" + "goauthentik.io/internal/outpost/ldap/flags" + "goauthentik.io/internal/outpost/ldap/group" + "goauthentik.io/internal/outpost/ldap/metrics" + "goauthentik.io/internal/outpost/ldap/search" + "goauthentik.io/internal/outpost/ldap/server" +) + +type MemorySearcher struct { + si server.LDAPServerInstance + log *log.Entry + + users []api.User + groups []api.Group +} + +func NewMemorySearcher(si server.LDAPServerInstance) *MemorySearcher { + ms := &MemorySearcher{ + si: si, + log: log.WithField("logger", "authentik.outpost.ldap.searcher.memory"), + } + ms.log.Info("initialised memory searcher") + ms.users = ms.FetchUsers() + ms.groups = ms.FetchGroups() + return ms +} + +func (ms *MemorySearcher) SearchMe(req *search.Request, f flags.UserFlags) (ldap.ServerSearchResult, error) { + if f.UserInfo == nil { + for _, u := range ms.users { + if u.Pk == f.UserPk { + f.UserInfo = &u + } + } + if f.UserInfo == nil { + req.Log().WithField("pk", f.UserPk).Warning("User with pk is not in local cache") + return ldap.ServerSearchResult{ResultCode: ldap.LDAPResultOperationsError}, fmt.Errorf("failed to get userinfo") + } + } + entries := make([]*ldap.Entry, 1) + entries[0] = ms.si.UserEntry(*f.UserInfo) + return ldap.ServerSearchResult{Entries: entries, Referrals: []string{}, Controls: []ldap.Control{}, ResultCode: ldap.LDAPResultSuccess}, nil +} + +func (ms *MemorySearcher) Search(req *search.Request) (ldap.ServerSearchResult, error) { + accsp := sentry.StartSpan(req.Context(), "authentik.providers.ldap.search.check_access") + baseDN := strings.ToLower("," + ms.si.GetBaseDN()) + + entries := []*ldap.Entry{} + filterEntity, err := ldap.GetFilterObjectClass(req.Filter) + if err != nil { + metrics.RequestsRejected.With(prometheus.Labels{ + "outpost_name": ms.si.GetOutpostName(), + "type": "search", + "reason": "filter_parse_fail", + "dn": req.BindDN, + "client": req.RemoteAddr(), + }).Inc() + return ldap.ServerSearchResult{ResultCode: ldap.LDAPResultOperationsError}, fmt.Errorf("Search Error: error parsing filter: %s", req.Filter) + } + if len(req.BindDN) < 1 { + metrics.RequestsRejected.With(prometheus.Labels{ + "outpost_name": ms.si.GetOutpostName(), + "type": "search", + "reason": "empty_bind_dn", + "dn": req.BindDN, + "client": req.RemoteAddr(), + }).Inc() + return ldap.ServerSearchResult{ResultCode: ldap.LDAPResultInsufficientAccessRights}, fmt.Errorf("Search Error: Anonymous BindDN not allowed %s", req.BindDN) + } + if !strings.HasSuffix(req.BindDN, baseDN) { + metrics.RequestsRejected.With(prometheus.Labels{ + "outpost_name": ms.si.GetOutpostName(), + "type": "search", + "reason": "invalid_bind_dn", + "dn": req.BindDN, + "client": req.RemoteAddr(), + }).Inc() + return ldap.ServerSearchResult{ResultCode: ldap.LDAPResultInsufficientAccessRights}, fmt.Errorf("Search Error: BindDN %s not in our BaseDN %s", req.BindDN, ms.si.GetBaseDN()) + } + + flags, ok := ms.si.GetFlags(req.BindDN) + if !ok { + req.Log().Debug("User info not cached") + metrics.RequestsRejected.With(prometheus.Labels{ + "outpost_name": ms.si.GetOutpostName(), + "type": "search", + "reason": "user_info_not_cached", + "dn": req.BindDN, + "client": req.RemoteAddr(), + }).Inc() + return ldap.ServerSearchResult{ResultCode: ldap.LDAPResultInsufficientAccessRights}, errors.New("access denied") + } + + if req.Scope == ldap.ScopeBaseObject { + req.Log().Debug("base scope, showing domain info") + return ms.SearchBase(req, flags.CanSearch) + } + if !flags.CanSearch { + req.Log().Debug("User can't search, showing info about user") + return ms.SearchMe(req, flags) + } + accsp.Finish() + + // parsedFilter, err := ldap.CompileFilter(req.Filter) + // if err != nil { + // metrics.RequestsRejected.With(prometheus.Labels{ + // "outpost_name": ms.si.GetOutpostName(), + // "type": "search", + // "reason": "filter_parse_fail", + // "dn": req.BindDN, + // "client": req.RemoteAddr(), + // }).Inc() + // return ldap.ServerSearchResult{ResultCode: ldap.LDAPResultOperationsError}, fmt.Errorf("Search Error: error parsing filter: %s", req.Filter) + // } + + switch filterEntity { + default: + metrics.RequestsRejected.With(prometheus.Labels{ + "outpost_name": ms.si.GetOutpostName(), + "type": "search", + "reason": "unhandled_filter_type", + "dn": req.BindDN, + "client": req.RemoteAddr(), + }).Inc() + return ldap.ServerSearchResult{ResultCode: ldap.LDAPResultOperationsError}, fmt.Errorf("Search Error: unhandled filter type: %s [%s]", filterEntity, req.Filter) + case constants.OCGroupOfUniqueNames: + fallthrough + case constants.OCAKGroup: + fallthrough + case constants.OCAKVirtualGroup: + fallthrough + case constants.OCGroup: + wg := sync.WaitGroup{} + wg.Add(2) + + gEntries := make([]*ldap.Entry, 0) + uEntries := make([]*ldap.Entry, 0) + + go func() { + defer wg.Done() + for _, g := range ms.groups { + gEntries = append(gEntries, group.FromAPIGroup(g, ms.si).Entry()) + } + }() + + go func() { + defer wg.Done() + for _, u := range ms.users { + uEntries = append(uEntries, group.FromAPIUser(u, ms.si).Entry()) + } + }() + wg.Wait() + entries = append(gEntries, uEntries...) + case "": + fallthrough + case constants.OCOrgPerson: + fallthrough + case constants.OCInetOrgPerson: + fallthrough + case constants.OCAKUser: + fallthrough + case constants.OCUser: + for _, u := range ms.users { + entries = append(entries, ms.si.UserEntry(u)) + } + } + return ldap.ServerSearchResult{Entries: entries, Referrals: []string{}, Controls: []ldap.Control{}, ResultCode: ldap.LDAPResultSuccess}, nil +} diff --git a/internal/outpost/ldap/search/request.go b/internal/outpost/ldap/search/request.go new file mode 100644 index 000000000..4ba67282c --- /dev/null +++ b/internal/outpost/ldap/search/request.go @@ -0,0 +1,53 @@ +package search + +import ( + "context" + "net" + "strings" + + "github.com/getsentry/sentry-go" + "github.com/google/uuid" + "github.com/nmcclain/ldap" + log "github.com/sirupsen/logrus" + "goauthentik.io/internal/utils" +) + +type Request struct { + ldap.SearchRequest + BindDN string + log *log.Entry + + id string + conn net.Conn + ctx context.Context +} + +func NewRequest(bindDN string, searchReq ldap.SearchRequest, conn net.Conn) (*Request, *sentry.Span) { + rid := uuid.New().String() + bindDN = strings.ToLower(bindDN) + span := sentry.StartSpan(context.TODO(), "authentik.providers.ldap.search", sentry.TransactionName("authentik.providers.ldap.search")) + span.SetTag("request_uid", rid) + span.SetTag("user.username", bindDN) + span.SetTag("ak_filter", searchReq.Filter) + span.SetTag("ak_base_dn", searchReq.BaseDN) + return &Request{ + SearchRequest: searchReq, + BindDN: bindDN, + conn: conn, + log: log.WithField("bindDN", bindDN).WithField("requestId", rid).WithField("scope", ldap.ScopeMap[searchReq.Scope]).WithField("client", utils.GetIP(conn.RemoteAddr())).WithField("filter", searchReq.Filter).WithField("baseDN", searchReq.BaseDN), + id: rid, + ctx: span.Context(), + }, span +} + +func (r *Request) Context() context.Context { + return r.ctx +} + +func (r *Request) Log() *log.Entry { + return r.log +} + +func (r *Request) RemoteAddr() string { + return utils.GetIP(r.conn.RemoteAddr()) +} diff --git a/internal/outpost/ldap/search/searcher.go b/internal/outpost/ldap/search/searcher.go new file mode 100644 index 000000000..5adb6d2f5 --- /dev/null +++ b/internal/outpost/ldap/search/searcher.go @@ -0,0 +1,7 @@ +package search + +import "github.com/nmcclain/ldap" + +type Searcher interface { + Search(req *Request) (ldap.ServerSearchResult, error) +} diff --git a/internal/outpost/ldap/server/base.go b/internal/outpost/ldap/server/base.go new file mode 100644 index 000000000..623796441 --- /dev/null +++ b/internal/outpost/ldap/server/base.go @@ -0,0 +1,35 @@ +package server + +import ( + "github.com/go-openapi/strfmt" + "github.com/nmcclain/ldap" + "goauthentik.io/api" + "goauthentik.io/internal/outpost/ldap/flags" +) + +type LDAPServerInstance interface { + GetAPIClient() *api.APIClient + GetOutpostName() string + + GetFlowSlug() string + GetAppSlug() string + GetSearchAllowedGroups() []*strfmt.UUID + + UserEntry(u api.User) *ldap.Entry + + GetBaseDN() string + GetBaseGroupDN() string + GetBaseUserDN() string + + GetUserDN(string) string + GetGroupDN(string) string + GetVirtualGroupDN(string) string + + GetUidNumber(api.User) string + GetGidNumber(api.Group) string + + UsersForGroup(api.Group) []string + + GetFlags(string) (flags.UserFlags, bool) + SetFlags(string, flags.UserFlags) +} diff --git a/internal/outpost/ldap/utils.go b/internal/outpost/ldap/utils.go index b39c8e140..e66a0182e 100644 --- a/internal/outpost/ldap/utils.go +++ b/internal/outpost/ldap/utils.go @@ -3,70 +3,12 @@ package ldap import ( "fmt" "math/big" - "reflect" "strconv" "strings" - "github.com/nmcclain/ldap" - log "github.com/sirupsen/logrus" "goauthentik.io/api" ) -func BoolToString(in bool) string { - if in { - return "true" - } - return "false" -} - -func ldapResolveTypeSingle(in interface{}) *string { - switch t := in.(type) { - case string: - return &t - case *string: - return t - case bool: - s := BoolToString(t) - return &s - case *bool: - s := BoolToString(*t) - return &s - default: - log.WithField("type", reflect.TypeOf(in).String()).Warning("Type can't be mapped to LDAP yet") - return nil - } -} - -func AKAttrsToLDAP(attrs interface{}) []*ldap.EntryAttribute { - attrList := []*ldap.EntryAttribute{} - if attrs == nil { - return attrList - } - a := attrs.(*map[string]interface{}) - for attrKey, attrValue := range *a { - entry := &ldap.EntryAttribute{Name: attrKey} - switch t := attrValue.(type) { - case []string: - entry.Values = t - case *[]string: - entry.Values = *t - case []interface{}: - entry.Values = make([]string, len(t)) - for idx, v := range t { - v := ldapResolveTypeSingle(v) - entry.Values[idx] = *v - } - default: - v := ldapResolveTypeSingle(t) - if v != nil { - entry.Values = []string{*v} - } - } - attrList = append(attrList, entry) - } - return attrList -} - func (pi *ProviderInstance) GroupsForUser(user api.User) []string { groups := make([]string, len(user.Groups)) for i, group := range user.GroupsObj { @@ -83,32 +25,6 @@ func (pi *ProviderInstance) UsersForGroup(group api.Group) []string { return users } -func (pi *ProviderInstance) APIGroupToLDAPGroup(g api.Group) LDAPGroup { - return LDAPGroup{ - dn: pi.GetGroupDN(g.Name), - cn: g.Name, - uid: string(g.Pk), - gidNumber: pi.GetGidNumber(g), - member: pi.UsersForGroup(g), - isVirtualGroup: false, - isSuperuser: *g.IsSuperuser, - akAttributes: g.Attributes, - } -} - -func (pi *ProviderInstance) APIUserToLDAPGroup(u api.User) LDAPGroup { - return LDAPGroup{ - dn: pi.GetVirtualGroupDN(u.Username), - cn: u.Username, - uid: u.Uid, - gidNumber: pi.GetUidNumber(u), - member: []string{pi.GetUserDN(u.Username)}, - isVirtualGroup: true, - isSuperuser: false, - akAttributes: nil, - } -} - func (pi *ProviderInstance) GetUserDN(user string) string { return fmt.Sprintf("cn=%s,%s", user, pi.UserDN) } @@ -155,26 +71,3 @@ func (pi *ProviderInstance) GetRIDForGroup(uid string) int32 { return int32(gid) } - -func (pi *ProviderInstance) ensureAttributes(attrs []*ldap.EntryAttribute, shouldHave map[string][]string) []*ldap.EntryAttribute { - for name, values := range shouldHave { - attrs = pi.mustHaveAttribute(attrs, name, values) - } - return attrs -} - -func (pi *ProviderInstance) mustHaveAttribute(attrs []*ldap.EntryAttribute, name string, value []string) []*ldap.EntryAttribute { - shouldSet := true - for _, attr := range attrs { - if attr.Name == name { - shouldSet = false - } - } - if shouldSet { - return append(attrs, &ldap.EntryAttribute{ - Name: name, - Values: value, - }) - } - return attrs -} diff --git a/internal/outpost/ldap/utils/utils.go b/internal/outpost/ldap/utils/utils.go new file mode 100644 index 000000000..ad725b42f --- /dev/null +++ b/internal/outpost/ldap/utils/utils.go @@ -0,0 +1,86 @@ +package utils + +import ( + "reflect" + + "github.com/nmcclain/ldap" + log "github.com/sirupsen/logrus" +) + +func BoolToString(in bool) string { + if in { + return "true" + } + return "false" +} + +func ldapResolveTypeSingle(in interface{}) *string { + switch t := in.(type) { + case string: + return &t + case *string: + return t + case bool: + s := BoolToString(t) + return &s + case *bool: + s := BoolToString(*t) + return &s + default: + log.WithField("type", reflect.TypeOf(in).String()).Warning("Type can't be mapped to LDAP yet") + return nil + } +} + +func AKAttrsToLDAP(attrs interface{}) []*ldap.EntryAttribute { + attrList := []*ldap.EntryAttribute{} + if attrs == nil { + return attrList + } + a := attrs.(*map[string]interface{}) + for attrKey, attrValue := range *a { + entry := &ldap.EntryAttribute{Name: attrKey} + switch t := attrValue.(type) { + case []string: + entry.Values = t + case *[]string: + entry.Values = *t + case []interface{}: + entry.Values = make([]string, len(t)) + for idx, v := range t { + v := ldapResolveTypeSingle(v) + entry.Values[idx] = *v + } + default: + v := ldapResolveTypeSingle(t) + if v != nil { + entry.Values = []string{*v} + } + } + attrList = append(attrList, entry) + } + return attrList +} + +func EnsureAttributes(attrs []*ldap.EntryAttribute, shouldHave map[string][]string) []*ldap.EntryAttribute { + for name, values := range shouldHave { + attrs = MustHaveAttribute(attrs, name, values) + } + return attrs +} + +func MustHaveAttribute(attrs []*ldap.EntryAttribute, name string, value []string) []*ldap.EntryAttribute { + shouldSet := true + for _, attr := range attrs { + if attr.Name == name { + shouldSet = false + } + } + if shouldSet { + return append(attrs, &ldap.EntryAttribute{ + Name: name, + Values: value, + }) + } + return attrs +} diff --git a/internal/outpost/ldap/instance_search_group.go b/internal/outpost/ldap/utils/utils_group.go similarity index 83% rename from internal/outpost/ldap/instance_search_group.go rename to internal/outpost/ldap/utils/utils_group.go index 7bbf14da5..2599c9f76 100644 --- a/internal/outpost/ldap/instance_search_group.go +++ b/internal/outpost/ldap/utils/utils_group.go @@ -1,19 +1,20 @@ -package ldap +package utils import ( goldap "github.com/go-ldap/ldap/v3" ber "github.com/nmcclain/asn1-ber" "github.com/nmcclain/ldap" "goauthentik.io/api" + "goauthentik.io/internal/outpost/ldap/constants" ) -func parseFilterForGroup(req api.ApiCoreGroupsListRequest, f *ber.Packet, skip bool) (api.ApiCoreGroupsListRequest, bool) { +func ParseFilterForGroup(req api.ApiCoreGroupsListRequest, f *ber.Packet, skip bool) (api.ApiCoreGroupsListRequest, bool) { switch f.Tag { case ldap.FilterEqualityMatch: return parseFilterForGroupSingle(req, f) case ldap.FilterAnd: for _, child := range f.Children { - r, s := parseFilterForGroup(req, child, skip) + r, s := ParseFilterForGroup(req, child, skip) skip = skip || s req = r } @@ -53,7 +54,7 @@ func parseFilterForGroupSingle(req api.ApiCoreGroupsListRequest, f *ber.Packet) username := userDN.RDNs[0].Attributes[0].Value // If the DN's first ou is virtual-groups, ignore this filter if len(userDN.RDNs) > 1 { - if userDN.RDNs[1].Attributes[0].Value == VirtualGroupsOU || userDN.RDNs[1].Attributes[0].Value == GroupsOU { + if userDN.RDNs[1].Attributes[0].Value == constants.OUVirtualGroups || userDN.RDNs[1].Attributes[0].Value == constants.OUGroups { // Since we know we're not filtering anything, skip this request return req, true } diff --git a/internal/outpost/ldap/instance_search_user.go b/internal/outpost/ldap/utils/utils_user.go similarity index 84% rename from internal/outpost/ldap/instance_search_user.go rename to internal/outpost/ldap/utils/utils_user.go index 842151a32..b51c4fe0a 100644 --- a/internal/outpost/ldap/instance_search_user.go +++ b/internal/outpost/ldap/utils/utils_user.go @@ -1,19 +1,20 @@ -package ldap +package utils import ( goldap "github.com/go-ldap/ldap/v3" ber "github.com/nmcclain/asn1-ber" "github.com/nmcclain/ldap" "goauthentik.io/api" + "goauthentik.io/internal/outpost/ldap/constants" ) -func parseFilterForUser(req api.ApiCoreUsersListRequest, f *ber.Packet, skip bool) (api.ApiCoreUsersListRequest, bool) { +func ParseFilterForUser(req api.ApiCoreUsersListRequest, f *ber.Packet, skip bool) (api.ApiCoreUsersListRequest, bool) { switch f.Tag { case ldap.FilterEqualityMatch: return parseFilterForUserSingle(req, f) case ldap.FilterAnd: for _, child := range f.Children { - r, s := parseFilterForUser(req, child, skip) + r, s := ParseFilterForUser(req, child, skip) skip = skip || s req = r } @@ -58,7 +59,7 @@ func parseFilterForUserSingle(req api.ApiCoreUsersListRequest, f *ber.Packet) (a name := groupDN.RDNs[0].Attributes[0].Value // If the DN's first ou is virtual-groups, ignore this filter if len(groupDN.RDNs) > 1 { - if groupDN.RDNs[1].Attributes[0].Value == UsersOU || groupDN.RDNs[1].Attributes[0].Value == VirtualGroupsOU { + if groupDN.RDNs[1].Attributes[0].Value == constants.OUUsers || groupDN.RDNs[1].Attributes[0].Value == constants.OUVirtualGroups { // Since we know we're not filtering anything, skip this request return req, true } diff --git a/schema.yml b/schema.yml index 0b58de6cc..cf57b1451 100644 --- a/schema.yml +++ b/schema.yml @@ -22035,6 +22035,8 @@ components: generated from the group.Pk to make sure that the numbers aren't too low for POSIX groups. Default is 4000 to ensure that we don't collide with local groups or users primary groups gidNumber + search_mode: + $ref: '#/components/schemas/SearchModeEnum' required: - application_slug - bind_flow_slug @@ -22173,6 +22175,8 @@ components: items: type: string readOnly: true + search_mode: + $ref: '#/components/schemas/SearchModeEnum' required: - assigned_application_name - assigned_application_slug @@ -22228,6 +22232,8 @@ components: generated from the group.Pk to make sure that the numbers aren't too low for POSIX groups. Default is 4000 to ensure that we don't collide with local groups or users primary groups gidNumber + search_mode: + $ref: '#/components/schemas/SearchModeEnum' required: - authorization_flow - name @@ -26849,6 +26855,8 @@ components: generated from the group.Pk to make sure that the numbers aren't too low for POSIX groups. Default is 4000 to ensure that we don't collide with local groups or users primary groups gidNumber + search_mode: + $ref: '#/components/schemas/SearchModeEnum' PatchedLDAPSourceRequest: type: object description: LDAP Source Serializer @@ -29410,6 +29418,11 @@ components: - expression - name - scope_name + SearchModeEnum: + enum: + - direct + - cached + type: string ServiceConnection: type: object description: ServiceConnection Serializer diff --git a/web/src/locales/en.po b/web/src/locales/en.po index a86c256e7..c9cd5e641 100644 --- a/web/src/locales/en.po +++ b/web/src/locales/en.po @@ -625,6 +625,10 @@ msgstr "Cached flows" msgid "Cached policies" msgstr "Cached policies" +#: src/pages/providers/ldap/LDAPProviderForm.ts +msgid "Cached querying, the outpost holds all users and groups in-memory and will refresh every 5 Minutes." +msgstr "Cached querying, the outpost holds all users and groups in-memory and will refresh every 5 Minutes." + #: src/pages/sources/oauth/OAuthSourceViewPage.ts msgid "Callback URL" msgstr "Callback URL" @@ -913,6 +917,10 @@ msgstr "Configure how the flow executor should handle an invalid response to a c msgid "Configure how the issuer field of the ID Token should be filled." msgstr "Configure how the issuer field of the ID Token should be filled." +#: src/pages/providers/ldap/LDAPProviderForm.ts +msgid "Configure how the outpost queries the core authentik server's users." +msgstr "Configure how the outpost queries the core authentik server's users." + #: #: #~ msgid "Configure settings relevant to your user profile." @@ -1416,6 +1424,10 @@ msgstr "Digest algorithm" msgid "Digits" msgstr "Digits" +#: src/pages/providers/ldap/LDAPProviderForm.ts +msgid "Direct querying, always returns the latest data, but slower than cached querying." +msgstr "Direct querying, always returns the latest data, but slower than cached querying." + #: #: #~ msgid "Disable" @@ -3802,6 +3814,10 @@ msgstr "Score" msgid "Search group" msgstr "Search group" +#: src/pages/providers/ldap/LDAPProviderForm.ts +msgid "Search mode" +msgstr "Search mode" + #: src/elements/table/TableSearch.ts #: src/user/LibraryPage.ts msgid "Search..." diff --git a/web/src/locales/fr_FR.po b/web/src/locales/fr_FR.po index 4e3cd71f4..a1129094d 100644 --- a/web/src/locales/fr_FR.po +++ b/web/src/locales/fr_FR.po @@ -627,6 +627,10 @@ msgstr "Flux mis en cache" msgid "Cached policies" msgstr "Politiques mises en cache" +#: src/pages/providers/ldap/LDAPProviderForm.ts +msgid "Cached querying, the outpost holds all users and groups in-memory and will refresh every 5 Minutes." +msgstr "" + #: src/pages/sources/oauth/OAuthSourceViewPage.ts msgid "Callback URL" msgstr "URL de rappel" @@ -913,6 +917,10 @@ msgstr "Configure comment l'exécuteur de flux gère une réponse invalide à un msgid "Configure how the issuer field of the ID Token should be filled." msgstr "Configure comment le champ émetteur du jeton ID sera rempli." +#: src/pages/providers/ldap/LDAPProviderForm.ts +msgid "Configure how the outpost queries the core authentik server's users." +msgstr "" + #~ msgid "Configure settings relevant to your user profile." #~ msgstr "Configure les paramètre applicable à votre profil." @@ -1406,6 +1414,10 @@ msgstr "Algorithme d'empreinte" msgid "Digits" msgstr "Chiffres" +#: src/pages/providers/ldap/LDAPProviderForm.ts +msgid "Direct querying, always returns the latest data, but slower than cached querying." +msgstr "" + #~ msgid "Disable" #~ msgstr "Désactiver" @@ -3770,6 +3782,10 @@ msgstr "Note" msgid "Search group" msgstr "Rechercher un groupe" +#: src/pages/providers/ldap/LDAPProviderForm.ts +msgid "Search mode" +msgstr "" + #: src/elements/table/TableSearch.ts #: src/user/LibraryPage.ts msgid "Search..." diff --git a/web/src/locales/pseudo-LOCALE.po b/web/src/locales/pseudo-LOCALE.po index 3b5bd9682..02b6ac128 100644 --- a/web/src/locales/pseudo-LOCALE.po +++ b/web/src/locales/pseudo-LOCALE.po @@ -621,6 +621,10 @@ msgstr "" msgid "Cached policies" msgstr "" +#: src/pages/providers/ldap/LDAPProviderForm.ts +msgid "Cached querying, the outpost holds all users and groups in-memory and will refresh every 5 Minutes." +msgstr "" + #: src/pages/sources/oauth/OAuthSourceViewPage.ts msgid "Callback URL" msgstr "" @@ -907,6 +911,10 @@ msgstr "" msgid "Configure how the issuer field of the ID Token should be filled." msgstr "" +#: src/pages/providers/ldap/LDAPProviderForm.ts +msgid "Configure how the outpost queries the core authentik server's users." +msgstr "" + #: #: #~ msgid "Configure settings relevant to your user profile." @@ -1408,6 +1416,10 @@ msgstr "" msgid "Digits" msgstr "" +#: src/pages/providers/ldap/LDAPProviderForm.ts +msgid "Direct querying, always returns the latest data, but slower than cached querying." +msgstr "" + #: #: #~ msgid "Disable" @@ -3794,6 +3806,10 @@ msgstr "" msgid "Search group" msgstr "" +#: src/pages/providers/ldap/LDAPProviderForm.ts +msgid "Search mode" +msgstr "" + #: src/elements/table/TableSearch.ts #: src/user/LibraryPage.ts msgid "Search..." diff --git a/web/src/pages/providers/ldap/LDAPProviderForm.ts b/web/src/pages/providers/ldap/LDAPProviderForm.ts index 2c18fa905..1587b8425 100644 --- a/web/src/pages/providers/ldap/LDAPProviderForm.ts +++ b/web/src/pages/providers/ldap/LDAPProviderForm.ts @@ -12,6 +12,7 @@ import { FlowsInstancesListDesignationEnum, LDAPProvider, ProvidersApi, + SearchModeEnum, } from "@goauthentik/api"; import { DEFAULT_CONFIG, tenant } from "../../../api/Config"; @@ -118,6 +119,25 @@ export class LDAPProviderFormPage extends ModelForm { ${t`Users in the selected group can do search queries. If no group is selected, no LDAP Searches are allowed.`}

+ + +

+ ${t`Configure how the outpost queries the core authentik server's users.`} +

+
${t`Protocol settings`}