providers/ldap: memory Query (#1681)

* outposts/ldap: modularise ldap outpost, to allow different searchers and binders

Signed-off-by: Jens Langhammer <jens.langhammer@beryju.org>

* outposts/ldap: add basic in-memory searcher

Signed-off-by: Jens Langhammer <jens.langhammer@beryju.org>

* providers/ldap: add search mode field

Signed-off-by: Jens Langhammer <jens.langhammer@beryju.org>

* outpost: add search mode field

Signed-off-by: Jens Langhammer <jens.langhammer@beryju.org>
This commit is contained in:
Jens L 2021-11-05 10:37:30 +01:00 committed by GitHub
parent 8de13d3f67
commit 5a8c66d325
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
37 changed files with 1293 additions and 639 deletions

View File

@ -60,18 +60,19 @@ gen-web:
\cp -rfv web-api/* web/node_modules/@goauthentik/api \cp -rfv web-api/* web/node_modules/@goauthentik/api
gen-outpost: gen-outpost:
wget https://raw.githubusercontent.com/goauthentik/client-go/main/config.yaml -O config.yaml
mkdir -p templates
wget https://raw.githubusercontent.com/goauthentik/client-go/main/templates/README.mustache -O templates/README.mustache
wget https://raw.githubusercontent.com/goauthentik/client-go/main/templates/go.mod.mustache -O templates/go.mod.mustache
docker run \ docker run \
--rm -v ${PWD}:/local \ --rm -v ${PWD}:/local \
--user ${UID}:${GID} \ --user ${UID}:${GID} \
openapitools/openapi-generator-cli generate \ openapitools/openapi-generator-cli generate \
--git-host goauthentik.io \
--git-repo-id outpost \
--git-user-id api \
-i /local/schema.yml \ -i /local/schema.yml \
-g go \ -g go \
-o /local/api \ -o /local/api \
--additional-properties=packageName=api,enumClassPrefix=true,useOneOfDiscriminatorLookup=true,disallowAdditionalPropertiesIfNotPresent=false -c /local/config.yaml
rm -f api/go.mod api/go.sum go mod edit -replace goauthentik.io/api=./api
gen: gen-build gen-clean gen-web gen: gen-build gen-clean gen-web

View File

@ -24,6 +24,7 @@ class LDAPProviderSerializer(ProviderSerializer):
"uid_start_number", "uid_start_number",
"gid_start_number", "gid_start_number",
"outpost_set", "outpost_set",
"search_mode",
] ]
@ -68,6 +69,7 @@ class LDAPOutpostConfigSerializer(ModelSerializer):
"tls_server_name", "tls_server_name",
"uid_start_number", "uid_start_number",
"gid_start_number", "gid_start_number",
"search_mode",
] ]

View File

@ -9,6 +9,12 @@ from authentik.core.models import Group, Provider
from authentik.crypto.models import CertificateKeyPair from authentik.crypto.models import CertificateKeyPair
from authentik.outposts.models import OutpostModel from authentik.outposts.models import OutpostModel
class SearchModes(models.TextChoices):
"""Search modes"""
DIRECT = "direct"
CACHED = "cached"
class LDAPProvider(OutpostModel, Provider): class LDAPProvider(OutpostModel, Provider):
"""Allow applications to authenticate against authentik's users using LDAP.""" """Allow applications to authenticate against authentik's users using LDAP."""
@ -59,6 +65,8 @@ class LDAPProvider(OutpostModel, Provider):
), ),
) )
search_mode = models.TextField(default=SearchModes.DIRECT, choices=SearchModes.choices)
@property @property
def launch_url(self) -> Optional[str]: def launch_url(self) -> Optional[str]:
"""LDAP never has a launch URL""" """LDAP never has a launch URL"""

View File

@ -4,7 +4,6 @@ import (
"context" "context"
"errors" "errors"
"fmt" "fmt"
"net"
"net/http" "net/http"
"net/http/cookiejar" "net/http/cookiejar"
"net/url" "net/url"
@ -18,7 +17,6 @@ import (
"goauthentik.io/api" "goauthentik.io/api"
"goauthentik.io/internal/constants" "goauthentik.io/internal/constants"
"goauthentik.io/internal/outpost/ak" "goauthentik.io/internal/outpost/ak"
"goauthentik.io/internal/utils"
) )
type StageComponent string type StageComponent string
@ -103,8 +101,8 @@ type ChallengeInt interface {
GetResponseErrors() map[string][]api.ErrorDetail GetResponseErrors() map[string][]api.ErrorDetail
} }
func (fe *FlowExecutor) DelegateClientIP(a net.Addr) { func (fe *FlowExecutor) DelegateClientIP(a string) {
fe.cip = utils.GetIP(a) fe.cip = a
fe.api.GetConfig().AddDefaultHeader(HeaderAuthentikRemoteIP, fe.cip) fe.api.GetConfig().AddDefaultHeader(HeaderAuthentikRemoteIP, fe.cip)
} }

View File

@ -1,23 +0,0 @@
package ldap
import "crypto/tls"
func (ls *LDAPServer) getCertificates(info *tls.ClientHelloInfo) (*tls.Certificate, error) {
if len(ls.providers) == 1 {
if ls.providers[0].cert != nil {
ls.log.WithField("server-name", info.ServerName).Debug("We only have a single provider, using their cert")
return ls.providers[0].cert, nil
}
}
for _, provider := range ls.providers {
if provider.tlsServerName == &info.ServerName {
if provider.cert == nil {
ls.log.WithField("server-name", info.ServerName).Debug("Handler does not have a certificate")
return ls.defaultCert, nil
}
return provider.cert, nil
}
}
ls.log.WithField("server-name", info.ServerName).Debug("Fallback to default cert")
return ls.defaultCert, nil
}

View File

@ -1,44 +1,18 @@
package ldap package ldap
import ( import (
"context"
"net" "net"
"strings"
"github.com/getsentry/sentry-go"
"github.com/google/uuid"
"github.com/nmcclain/ldap" "github.com/nmcclain/ldap"
"github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus"
log "github.com/sirupsen/logrus" "goauthentik.io/internal/outpost/ldap/bind"
"goauthentik.io/internal/outpost/ldap/metrics" "goauthentik.io/internal/outpost/ldap/metrics"
"goauthentik.io/internal/utils" "goauthentik.io/internal/utils"
) )
type BindRequest struct {
BindDN string
BindPW string
id string
conn net.Conn
log *log.Entry
ctx context.Context
}
func (ls *LDAPServer) Bind(bindDN string, bindPW string, conn net.Conn) (ldap.LDAPResultCode, error) { func (ls *LDAPServer) Bind(bindDN string, bindPW string, conn net.Conn) (ldap.LDAPResultCode, error) {
span := sentry.StartSpan(context.TODO(), "authentik.providers.ldap.bind", req, span := bind.NewRequest(bindDN, bindPW, conn)
sentry.TransactionName("authentik.providers.ldap.bind"))
rid := uuid.New().String()
span.SetTag("request_uid", rid)
span.SetTag("user.username", bindDN)
bindDN = strings.ToLower(bindDN)
req := BindRequest{
BindDN: bindDN,
BindPW: bindPW,
conn: conn,
log: ls.log.WithField("bindDN", bindDN).WithField("requestId", rid).WithField("client", utils.GetIP(conn.RemoteAddr())),
id: rid,
ctx: span.Context(),
}
defer func() { defer func() {
span.Finish() span.Finish()
metrics.Requests.With(prometheus.Labels{ metrics.Requests.With(prometheus.Labels{
@ -46,19 +20,19 @@ func (ls *LDAPServer) Bind(bindDN string, bindPW string, conn net.Conn) (ldap.LD
"type": "bind", "type": "bind",
"filter": "", "filter": "",
"dn": req.BindDN, "dn": req.BindDN,
"client": utils.GetIP(req.conn.RemoteAddr()), "client": req.RemoteAddr(),
}).Observe(float64(span.EndTime.Sub(span.StartTime))) }).Observe(float64(span.EndTime.Sub(span.StartTime)))
req.log.WithField("took-ms", span.EndTime.Sub(span.StartTime).Milliseconds()).Info("Bind request") req.Log().WithField("took-ms", span.EndTime.Sub(span.StartTime).Milliseconds()).Info("Bind request")
}() }()
for _, instance := range ls.providers { for _, instance := range ls.providers {
username, err := instance.getUsername(bindDN) username, err := instance.binder.GetUsername(bindDN)
if err == nil { if err == nil {
return instance.Bind(username, req) return instance.binder.Bind(username, req)
} else { } else {
req.log.WithError(err).Debug("Username not for instance") req.Log().WithError(err).Debug("Username not for instance")
} }
} }
req.log.WithField("request", "bind").Warning("No provider found for request") req.Log().WithField("request", "bind").Warning("No provider found for request")
metrics.RequestsRejected.With(prometheus.Labels{ metrics.RequestsRejected.With(prometheus.Labels{
"outpost_name": ls.ac.Outpost.Name, "outpost_name": ls.ac.Outpost.Name,
"type": "bind", "type": "bind",
@ -68,10 +42,3 @@ func (ls *LDAPServer) Bind(bindDN string, bindPW string, conn net.Conn) (ldap.LD
}).Inc() }).Inc()
return ldap.LDAPResultOperationsError, nil return ldap.LDAPResultOperationsError, nil
} }
func (ls *LDAPServer) TimerFlowCacheExpiry() {
for _, p := range ls.providers {
ls.log.WithField("flow", p.flowSlug).Debug("Pre-heating flow cache")
p.TimerFlowCacheExpiry()
}
}

View File

@ -0,0 +1,9 @@
package bind
import "github.com/nmcclain/ldap"
type Binder interface {
GetUsername(string) (string, error)
Bind(username string, req *Request) (ldap.LDAPResultCode, error)
TimerFlowCacheExpiry()
}

View File

@ -1,4 +1,4 @@
package ldap package direct
import ( import (
"context" "context"
@ -12,14 +12,30 @@ import (
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"goauthentik.io/api" "goauthentik.io/api"
"goauthentik.io/internal/outpost" "goauthentik.io/internal/outpost"
"goauthentik.io/internal/outpost/ldap/bind"
"goauthentik.io/internal/outpost/ldap/flags"
"goauthentik.io/internal/outpost/ldap/metrics" "goauthentik.io/internal/outpost/ldap/metrics"
"goauthentik.io/internal/utils" "goauthentik.io/internal/outpost/ldap/server"
) )
const ContextUserKey = "ak_user" const ContextUserKey = "ak_user"
func (pi *ProviderInstance) getUsername(dn string) (string, error) { type DirectBinder struct {
if !strings.HasSuffix(strings.ToLower(dn), strings.ToLower(pi.BaseDN)) { si server.LDAPServerInstance
log *log.Entry
}
func NewDirectBinder(si server.LDAPServerInstance) *DirectBinder {
db := &DirectBinder{
si: si,
log: log.WithField("logger", "authentik.outpost.ldap.binder.direct"),
}
db.log.Info("initialised direct binder")
return db
}
func (db *DirectBinder) GetUsername(dn string) (string, error) {
if !strings.HasSuffix(strings.ToLower(dn), strings.ToLower(db.si.GetBaseDN())) {
return "", errors.New("invalid base DN") return "", errors.New("invalid base DN")
} }
dns, err := goldap.ParseDN(dn) dns, err := goldap.ParseDN(dn)
@ -36,13 +52,13 @@ func (pi *ProviderInstance) getUsername(dn string) (string, error) {
return "", errors.New("failed to find cn") return "", errors.New("failed to find cn")
} }
func (pi *ProviderInstance) Bind(username string, req BindRequest) (ldap.LDAPResultCode, error) { func (db *DirectBinder) Bind(username string, req *bind.Request) (ldap.LDAPResultCode, error) {
fe := outpost.NewFlowExecutor(req.ctx, pi.flowSlug, pi.s.ac.Client.GetConfig(), log.Fields{ fe := outpost.NewFlowExecutor(req.Context(), db.si.GetFlowSlug(), db.si.GetAPIClient().GetConfig(), log.Fields{
"bindDN": req.BindDN, "bindDN": req.BindDN,
"client": utils.GetIP(req.conn.RemoteAddr()), "client": req.RemoteAddr(),
"requestId": req.id, "requestId": req.ID(),
}) })
fe.DelegateClientIP(req.conn.RemoteAddr()) fe.DelegateClientIP(req.RemoteAddr())
fe.Params.Add("goauthentik.io/outpost/ldap", "true") fe.Params.Add("goauthentik.io/outpost/ldap", "true")
fe.Answers[outpost.StageIdentification] = username fe.Answers[outpost.StageIdentification] = username
@ -51,83 +67,82 @@ func (pi *ProviderInstance) Bind(username string, req BindRequest) (ldap.LDAPRes
passed, err := fe.Execute() passed, err := fe.Execute()
if !passed { if !passed {
metrics.RequestsRejected.With(prometheus.Labels{ metrics.RequestsRejected.With(prometheus.Labels{
"outpost_name": pi.outpostName, "outpost_name": db.si.GetOutpostName(),
"type": "bind", "type": "bind",
"reason": "invalid_credentials", "reason": "invalid_credentials",
"dn": req.BindDN, "dn": req.BindDN,
"client": utils.GetIP(req.conn.RemoteAddr()), "client": req.RemoteAddr(),
}).Inc() }).Inc()
return ldap.LDAPResultInvalidCredentials, nil return ldap.LDAPResultInvalidCredentials, nil
} }
if err != nil { if err != nil {
metrics.RequestsRejected.With(prometheus.Labels{ metrics.RequestsRejected.With(prometheus.Labels{
"outpost_name": pi.outpostName, "outpost_name": db.si.GetOutpostName(),
"type": "bind", "type": "bind",
"reason": "flow_error", "reason": "flow_error",
"dn": req.BindDN, "dn": req.BindDN,
"client": utils.GetIP(req.conn.RemoteAddr()), "client": req.RemoteAddr(),
}).Inc() }).Inc()
req.log.WithError(err).Warning("failed to execute flow") req.Log().WithError(err).Warning("failed to execute flow")
return ldap.LDAPResultOperationsError, nil return ldap.LDAPResultOperationsError, nil
} }
access, err := fe.CheckApplicationAccess(pi.appSlug) access, err := fe.CheckApplicationAccess(db.si.GetAppSlug())
if !access { if !access {
req.log.Info("Access denied for user") req.Log().Info("Access denied for user")
metrics.RequestsRejected.With(prometheus.Labels{ metrics.RequestsRejected.With(prometheus.Labels{
"outpost_name": pi.outpostName, "outpost_name": db.si.GetOutpostName(),
"type": "bind", "type": "bind",
"reason": "access_denied", "reason": "access_denied",
"dn": req.BindDN, "dn": req.BindDN,
"client": utils.GetIP(req.conn.RemoteAddr()), "client": req.RemoteAddr(),
}).Inc() }).Inc()
return ldap.LDAPResultInsufficientAccessRights, nil return ldap.LDAPResultInsufficientAccessRights, nil
} }
if err != nil { if err != nil {
metrics.RequestsRejected.With(prometheus.Labels{ metrics.RequestsRejected.With(prometheus.Labels{
"outpost_name": pi.outpostName, "outpost_name": db.si.GetOutpostName(),
"type": "bind", "type": "bind",
"reason": "access_check_fail", "reason": "access_check_fail",
"dn": req.BindDN, "dn": req.BindDN,
"client": utils.GetIP(req.conn.RemoteAddr()), "client": req.RemoteAddr(),
}).Inc() }).Inc()
req.log.WithError(err).Warning("failed to check access") req.Log().WithError(err).Warning("failed to check access")
return ldap.LDAPResultOperationsError, nil return ldap.LDAPResultOperationsError, nil
} }
req.log.Info("User has access") req.Log().Info("User has access")
uisp := sentry.StartSpan(req.ctx, "authentik.providers.ldap.bind.user_info") uisp := sentry.StartSpan(req.Context(), "authentik.providers.ldap.bind.user_info")
// Get user info to store in context // Get user info to store in context
userInfo, _, err := fe.ApiClient().CoreApi.CoreUsersMeRetrieve(context.Background()).Execute() userInfo, _, err := fe.ApiClient().CoreApi.CoreUsersMeRetrieve(context.Background()).Execute()
if err != nil { if err != nil {
metrics.RequestsRejected.With(prometheus.Labels{ metrics.RequestsRejected.With(prometheus.Labels{
"outpost_name": pi.outpostName, "outpost_name": db.si.GetOutpostName(),
"type": "bind", "type": "bind",
"reason": "user_info_fail", "reason": "user_info_fail",
"dn": req.BindDN, "dn": req.BindDN,
"client": utils.GetIP(req.conn.RemoteAddr()), "client": req.RemoteAddr(),
}).Inc() }).Inc()
req.log.WithError(err).Warning("failed to get user info") req.Log().WithError(err).Warning("failed to get user info")
return ldap.LDAPResultOperationsError, nil return ldap.LDAPResultOperationsError, nil
} }
pi.boundUsersMutex.Lock() cs := db.SearchAccessCheck(userInfo.User)
cs := pi.SearchAccessCheck(userInfo.User) flags := flags.UserFlags{
pi.boundUsers[req.BindDN] = UserFlags{
UserPk: userInfo.User.Pk, UserPk: userInfo.User.Pk,
CanSearch: cs != nil, CanSearch: cs != nil,
} }
if pi.boundUsers[req.BindDN].CanSearch { db.si.SetFlags(req.BindDN, flags)
req.log.WithField("group", cs).Info("Allowed access to search") if flags.CanSearch {
req.Log().WithField("group", cs).Info("Allowed access to search")
} }
uisp.Finish() uisp.Finish()
defer pi.boundUsersMutex.Unlock()
return ldap.LDAPResultSuccess, nil return ldap.LDAPResultSuccess, nil
} }
// SearchAccessCheck Check if the current user is allowed to search // SearchAccessCheck Check if the current user is allowed to search
func (pi *ProviderInstance) SearchAccessCheck(user api.UserSelf) *string { func (db *DirectBinder) SearchAccessCheck(user api.UserSelf) *string {
for _, group := range user.Groups { for _, group := range user.Groups {
for _, allowedGroup := range pi.searchAllowedGroups { for _, allowedGroup := range db.si.GetSearchAllowedGroups() {
pi.log.WithField("userGroup", group.Pk).WithField("allowedGroup", allowedGroup).Trace("Checking search access") db.log.WithField("userGroup", group.Pk).WithField("allowedGroup", allowedGroup).Trace("Checking search access")
if group.Pk == allowedGroup.String() { if group.Pk == allowedGroup.String() {
return &group.Name return &group.Name
} }
@ -136,13 +151,13 @@ func (pi *ProviderInstance) SearchAccessCheck(user api.UserSelf) *string {
return nil return nil
} }
func (pi *ProviderInstance) TimerFlowCacheExpiry() { func (db *DirectBinder) TimerFlowCacheExpiry() {
fe := outpost.NewFlowExecutor(context.Background(), pi.flowSlug, pi.s.ac.Client.GetConfig(), log.Fields{}) fe := outpost.NewFlowExecutor(context.Background(), db.si.GetFlowSlug(), db.si.GetAPIClient().GetConfig(), log.Fields{})
fe.Params.Add("goauthentik.io/outpost/ldap", "true") fe.Params.Add("goauthentik.io/outpost/ldap", "true")
fe.Params.Add("goauthentik.io/outpost/ldap-warmup", "true") fe.Params.Add("goauthentik.io/outpost/ldap-warmup", "true")
err := fe.WarmUp() err := fe.WarmUp()
if err != nil { if err != nil {
pi.log.WithError(err).Warning("failed to warm up flow cache") db.log.WithError(err).Warning("failed to warm up flow cache")
} }
} }

View File

@ -0,0 +1,55 @@
package bind
import (
"context"
"net"
"strings"
"github.com/getsentry/sentry-go"
"github.com/google/uuid"
log "github.com/sirupsen/logrus"
"goauthentik.io/internal/utils"
)
type Request struct {
BindDN string
BindPW string
id string
conn net.Conn
log *log.Entry
ctx context.Context
}
func NewRequest(bindDN string, bindPW string, conn net.Conn) (*Request, *sentry.Span) {
span := sentry.StartSpan(context.TODO(), "authentik.providers.ldap.bind",
sentry.TransactionName("authentik.providers.ldap.bind"))
rid := uuid.New().String()
span.SetTag("request_uid", rid)
span.SetTag("user.username", bindDN)
bindDN = strings.ToLower(bindDN)
return &Request{
BindDN: bindDN,
BindPW: bindPW,
conn: conn,
log: log.WithField("bindDN", bindDN).WithField("requestId", rid).WithField("client", utils.GetIP(conn.RemoteAddr())),
id: rid,
ctx: span.Context(),
}, span
}
func (r *Request) Context() context.Context {
return r.ctx
}
func (r *Request) Log() *log.Entry {
return r.log
}
func (r *Request) RemoteAddr() string {
return utils.GetIP(r.conn.RemoteAddr())
}
func (r *Request) ID() string {
return r.id
}

View File

@ -0,0 +1,21 @@
package constants
const (
OCGroup = "group"
OCGroupOfUniqueNames = "groupOfUniqueNames"
OCAKGroup = "goauthentik.io/ldap/group"
OCAKVirtualGroup = "goauthentik.io/ldap/virtual-group"
)
const (
OCUser = "user"
OCOrgPerson = "organizationalPerson"
OCInetOrgPerson = "inetOrgPerson"
OCAKUser = "goauthentik.io/ldap/user"
)
const (
OUUsers = "users"
OUGroups = "groups"
OUVirtualGroups = "virtual-groups"
)

View File

@ -0,0 +1,39 @@
package ldap
import (
"github.com/nmcclain/ldap"
"goauthentik.io/api"
"goauthentik.io/internal/outpost/ldap/constants"
"goauthentik.io/internal/outpost/ldap/group"
"goauthentik.io/internal/outpost/ldap/utils"
)
func (pi *ProviderInstance) UserEntry(u api.User) *ldap.Entry {
dn := pi.GetUserDN(u.Username)
attrs := utils.AKAttrsToLDAP(u.Attributes)
attrs = utils.EnsureAttributes(attrs, map[string][]string{
"memberOf": pi.GroupsForUser(u),
// Old fields for backwards compatibility
"accountStatus": {utils.BoolToString(*u.IsActive)},
"superuser": {utils.BoolToString(u.IsSuperuser)},
// End old fields
"goauthentik.io/ldap/active": {utils.BoolToString(*u.IsActive)},
"goauthentik.io/ldap/superuser": {utils.BoolToString(u.IsSuperuser)},
"cn": {u.Username},
"sAMAccountName": {u.Username},
"uid": {u.Uid},
"name": {u.Name},
"displayName": {u.Name},
"mail": {*u.Email},
"objectClass": {constants.OCUser, constants.OCOrgPerson, constants.OCInetOrgPerson, constants.OCAKUser},
"uidNumber": {pi.GetUidNumber(u)},
"gidNumber": {pi.GetUidNumber(u)},
})
return &ldap.Entry{DN: dn, Attributes: attrs}
}
func (pi *ProviderInstance) GroupEntry(g group.LDAPGroup) *ldap.Entry {
// TODO: Remove
return g.Entry()
}

View File

@ -0,0 +1,9 @@
package flags
import "goauthentik.io/api"
type UserFlags struct {
UserInfo *api.User
UserPk int32
CanSearch bool
}

View File

@ -0,0 +1,66 @@
package group
import (
"github.com/nmcclain/ldap"
"goauthentik.io/api"
"goauthentik.io/internal/outpost/ldap/constants"
"goauthentik.io/internal/outpost/ldap/server"
"goauthentik.io/internal/outpost/ldap/utils"
)
type LDAPGroup struct {
DN string
CN string
Uid string
GidNumber string
Member []string
IsSuperuser bool
IsVirtualGroup bool
AKAttributes interface{}
}
func (lg *LDAPGroup) Entry() *ldap.Entry {
attrs := utils.AKAttrsToLDAP(lg.AKAttributes)
objectClass := []string{constants.OCGroup, constants.OCGroupOfUniqueNames, constants.OCAKGroup}
if lg.IsVirtualGroup {
objectClass = append(objectClass, constants.OCAKVirtualGroup)
}
attrs = utils.EnsureAttributes(attrs, map[string][]string{
"objectClass": objectClass,
"member": lg.Member,
"goauthentik.io/ldap/superuser": {utils.BoolToString(lg.IsSuperuser)},
"cn": {lg.CN},
"uid": {lg.Uid},
"sAMAccountName": {lg.CN},
"gidNumber": {lg.GidNumber},
})
return &ldap.Entry{DN: lg.DN, Attributes: attrs}
}
func FromAPIGroup(g api.Group, si server.LDAPServerInstance) *LDAPGroup {
return &LDAPGroup{
DN: si.GetGroupDN(g.Name),
CN: g.Name,
Uid: string(g.Pk),
GidNumber: si.GetGidNumber(g),
Member: si.UsersForGroup(g),
IsVirtualGroup: false,
IsSuperuser: *g.IsSuperuser,
AKAttributes: g.Attributes,
}
}
func FromAPIUser(u api.User, si server.LDAPServerInstance) *LDAPGroup {
return &LDAPGroup{
DN: si.GetVirtualGroupDN(u.Username),
CN: u.Username,
Uid: u.Uid,
GidNumber: si.GetUidNumber(u),
Member: []string{si.GetUserDN(u.Username)},
IsVirtualGroup: true,
IsSuperuser: false,
AKAttributes: nil,
}
}

View File

@ -0,0 +1,4 @@
package handler
type Handler interface {
}

View File

@ -0,0 +1,83 @@
package ldap
import (
"crypto/tls"
"sync"
"github.com/go-openapi/strfmt"
log "github.com/sirupsen/logrus"
"goauthentik.io/api"
"goauthentik.io/internal/outpost/ldap/bind"
"goauthentik.io/internal/outpost/ldap/flags"
"goauthentik.io/internal/outpost/ldap/search"
)
type ProviderInstance struct {
BaseDN string
UserDN string
VirtualGroupDN string
GroupDN string
searcher search.Searcher
binder bind.Binder
appSlug string
flowSlug string
s *LDAPServer
log *log.Entry
tlsServerName *string
cert *tls.Certificate
outpostName string
searchAllowedGroups []*strfmt.UUID
boundUsersMutex sync.RWMutex
boundUsers map[string]flags.UserFlags
uidStartNumber int32
gidStartNumber int32
}
func (pi *ProviderInstance) GetAPIClient() *api.APIClient {
return pi.s.ac.Client
}
func (pi *ProviderInstance) GetBaseDN() string {
return pi.BaseDN
}
func (pi *ProviderInstance) GetBaseGroupDN() string {
return pi.GroupDN
}
func (pi *ProviderInstance) GetBaseUserDN() string {
return pi.UserDN
}
func (pi *ProviderInstance) GetOutpostName() string {
return pi.outpostName
}
func (pi *ProviderInstance) GetFlags(dn string) (flags.UserFlags, bool) {
pi.boundUsersMutex.RLock()
flags, ok := pi.boundUsers[dn]
pi.boundUsersMutex.RUnlock()
return flags, ok
}
func (pi *ProviderInstance) SetFlags(dn string, flag flags.UserFlags) {
pi.boundUsersMutex.Lock()
pi.boundUsers[dn] = flag
pi.boundUsersMutex.Unlock()
}
func (pi *ProviderInstance) GetAppSlug() string {
return pi.appSlug
}
func (pi *ProviderInstance) GetFlowSlug() string {
return pi.flowSlug
}
func (pi *ProviderInstance) GetSearchAllowedGroups() []*strfmt.UUID {
return pi.searchAllowedGroups
}

View File

@ -1,244 +0,0 @@
package ldap
import (
"errors"
"fmt"
"strings"
"sync"
"github.com/getsentry/sentry-go"
"github.com/nmcclain/ldap"
"github.com/prometheus/client_golang/prometheus"
"goauthentik.io/api"
"goauthentik.io/internal/outpost/ldap/metrics"
"goauthentik.io/internal/utils"
)
func (pi *ProviderInstance) SearchMe(req SearchRequest, f UserFlags) (ldap.ServerSearchResult, error) {
if f.UserInfo == nil {
u, _, err := pi.s.ac.Client.CoreApi.CoreUsersRetrieve(req.ctx, f.UserPk).Execute()
if err != nil {
req.log.WithError(err).Warning("Failed to get user info")
return ldap.ServerSearchResult{ResultCode: ldap.LDAPResultOperationsError}, fmt.Errorf("failed to get userinfo")
}
f.UserInfo = &u
}
entries := make([]*ldap.Entry, 1)
entries[0] = pi.UserEntry(*f.UserInfo)
return ldap.ServerSearchResult{Entries: entries, Referrals: []string{}, Controls: []ldap.Control{}, ResultCode: ldap.LDAPResultSuccess}, nil
}
func (pi *ProviderInstance) Search(req SearchRequest) (ldap.ServerSearchResult, error) {
accsp := sentry.StartSpan(req.ctx, "authentik.providers.ldap.search.check_access")
baseDN := strings.ToLower("," + pi.BaseDN)
entries := []*ldap.Entry{}
filterEntity, err := ldap.GetFilterObjectClass(req.Filter)
if err != nil {
metrics.RequestsRejected.With(prometheus.Labels{
"outpost_name": pi.outpostName,
"type": "search",
"reason": "filter_parse_fail",
"dn": req.BindDN,
"client": utils.GetIP(req.conn.RemoteAddr()),
}).Inc()
return ldap.ServerSearchResult{ResultCode: ldap.LDAPResultOperationsError}, fmt.Errorf("Search Error: error parsing filter: %s", req.Filter)
}
if len(req.BindDN) < 1 {
metrics.RequestsRejected.With(prometheus.Labels{
"outpost_name": pi.outpostName,
"type": "search",
"reason": "empty_bind_dn",
"dn": req.BindDN,
"client": utils.GetIP(req.conn.RemoteAddr()),
}).Inc()
return ldap.ServerSearchResult{ResultCode: ldap.LDAPResultInsufficientAccessRights}, fmt.Errorf("Search Error: Anonymous BindDN not allowed %s", req.BindDN)
}
if !strings.HasSuffix(req.BindDN, baseDN) {
metrics.RequestsRejected.With(prometheus.Labels{
"outpost_name": pi.outpostName,
"type": "search",
"reason": "invalid_bind_dn",
"dn": req.BindDN,
"client": utils.GetIP(req.conn.RemoteAddr()),
}).Inc()
return ldap.ServerSearchResult{ResultCode: ldap.LDAPResultInsufficientAccessRights}, fmt.Errorf("Search Error: BindDN %s not in our BaseDN %s", req.BindDN, pi.BaseDN)
}
pi.boundUsersMutex.RLock()
flags, ok := pi.boundUsers[req.BindDN]
pi.boundUsersMutex.RUnlock()
if !ok {
pi.log.Debug("User info not cached")
metrics.RequestsRejected.With(prometheus.Labels{
"outpost_name": pi.outpostName,
"type": "search",
"reason": "user_info_not_cached",
"dn": req.BindDN,
"client": utils.GetIP(req.conn.RemoteAddr()),
}).Inc()
return ldap.ServerSearchResult{ResultCode: ldap.LDAPResultInsufficientAccessRights}, errors.New("access denied")
}
if req.SearchRequest.Scope == ldap.ScopeBaseObject {
pi.log.Debug("base scope, showing domain info")
return pi.SearchBase(req, flags.CanSearch)
}
if !flags.CanSearch {
pi.log.Debug("User can't search, showing info about user")
return pi.SearchMe(req, flags)
}
accsp.Finish()
parsedFilter, err := ldap.CompileFilter(req.Filter)
if err != nil {
metrics.RequestsRejected.With(prometheus.Labels{
"outpost_name": pi.outpostName,
"type": "search",
"reason": "filter_parse_fail",
"dn": req.BindDN,
"client": utils.GetIP(req.conn.RemoteAddr()),
}).Inc()
return ldap.ServerSearchResult{ResultCode: ldap.LDAPResultOperationsError}, fmt.Errorf("Search Error: error parsing filter: %s", req.Filter)
}
// Create a custom client to set additional headers
c := api.NewAPIClient(pi.s.ac.Client.GetConfig())
c.GetConfig().AddDefaultHeader("X-authentik-outpost-ldap-query", req.Filter)
switch filterEntity {
default:
metrics.RequestsRejected.With(prometheus.Labels{
"outpost_name": pi.outpostName,
"type": "search",
"reason": "unhandled_filter_type",
"dn": req.BindDN,
"client": utils.GetIP(req.conn.RemoteAddr()),
}).Inc()
return ldap.ServerSearchResult{ResultCode: ldap.LDAPResultOperationsError}, fmt.Errorf("Search Error: unhandled filter type: %s [%s]", filterEntity, req.Filter)
case "groupOfUniqueNames":
fallthrough
case "goauthentik.io/ldap/group":
fallthrough
case "goauthentik.io/ldap/virtual-group":
fallthrough
case GroupObjectClass:
wg := sync.WaitGroup{}
wg.Add(2)
gEntries := make([]*ldap.Entry, 0)
uEntries := make([]*ldap.Entry, 0)
go func() {
defer wg.Done()
gapisp := sentry.StartSpan(req.ctx, "authentik.providers.ldap.search.api_group")
searchReq, skip := parseFilterForGroup(c.CoreApi.CoreGroupsList(gapisp.Context()), parsedFilter, false)
if skip {
pi.log.Trace("Skip backend request")
return
}
groups, _, err := searchReq.Execute()
gapisp.Finish()
if err != nil {
req.log.WithError(err).Warning("failed to get groups")
return
}
pi.log.WithField("count", len(groups.Results)).Trace("Got results from API")
for _, g := range groups.Results {
gEntries = append(gEntries, pi.GroupEntry(pi.APIGroupToLDAPGroup(g)))
}
}()
go func() {
defer wg.Done()
uapisp := sentry.StartSpan(req.ctx, "authentik.providers.ldap.search.api_user")
searchReq, skip := parseFilterForUser(c.CoreApi.CoreUsersList(uapisp.Context()), parsedFilter, false)
if skip {
pi.log.Trace("Skip backend request")
return
}
users, _, err := searchReq.Execute()
uapisp.Finish()
if err != nil {
req.log.WithError(err).Warning("failed to get users")
return
}
for _, u := range users.Results {
uEntries = append(uEntries, pi.GroupEntry(pi.APIUserToLDAPGroup(u)))
}
}()
wg.Wait()
entries = append(gEntries, uEntries...)
case "":
fallthrough
case "organizationalPerson":
fallthrough
case "inetOrgPerson":
fallthrough
case "goauthentik.io/ldap/user":
fallthrough
case UserObjectClass:
uapisp := sentry.StartSpan(req.ctx, "authentik.providers.ldap.search.api_user")
searchReq, skip := parseFilterForUser(c.CoreApi.CoreUsersList(uapisp.Context()), parsedFilter, false)
if skip {
pi.log.Trace("Skip backend request")
return ldap.ServerSearchResult{Entries: entries, Referrals: []string{}, Controls: []ldap.Control{}, ResultCode: ldap.LDAPResultSuccess}, nil
}
users, _, err := searchReq.Execute()
uapisp.Finish()
if err != nil {
return ldap.ServerSearchResult{ResultCode: ldap.LDAPResultOperationsError}, fmt.Errorf("API Error: %s", err)
}
for _, u := range users.Results {
entries = append(entries, pi.UserEntry(u))
}
}
return ldap.ServerSearchResult{Entries: entries, Referrals: []string{}, Controls: []ldap.Control{}, ResultCode: ldap.LDAPResultSuccess}, nil
}
func (pi *ProviderInstance) UserEntry(u api.User) *ldap.Entry {
dn := pi.GetUserDN(u.Username)
attrs := AKAttrsToLDAP(u.Attributes)
attrs = pi.ensureAttributes(attrs, map[string][]string{
"memberOf": pi.GroupsForUser(u),
// Old fields for backwards compatibility
"accountStatus": {BoolToString(*u.IsActive)},
"superuser": {BoolToString(u.IsSuperuser)},
"goauthentik.io/ldap/active": {BoolToString(*u.IsActive)},
"goauthentik.io/ldap/superuser": {BoolToString(u.IsSuperuser)},
"cn": {u.Username},
"sAMAccountName": {u.Username},
"uid": {u.Uid},
"name": {u.Name},
"displayName": {u.Name},
"mail": {*u.Email},
"objectClass": {UserObjectClass, "organizationalPerson", "inetOrgPerson", "goauthentik.io/ldap/user"},
"uidNumber": {pi.GetUidNumber(u)},
"gidNumber": {pi.GetUidNumber(u)},
})
return &ldap.Entry{DN: dn, Attributes: attrs}
}
func (pi *ProviderInstance) GroupEntry(g LDAPGroup) *ldap.Entry {
attrs := AKAttrsToLDAP(g.akAttributes)
objectClass := []string{GroupObjectClass, "groupOfUniqueNames", "goauthentik.io/ldap/group"}
if g.isVirtualGroup {
objectClass = append(objectClass, "goauthentik.io/ldap/virtual-group")
}
attrs = pi.ensureAttributes(attrs, map[string][]string{
"objectClass": objectClass,
"member": g.member,
"goauthentik.io/ldap/superuser": {BoolToString(g.isSuperuser)},
"cn": {g.cn},
"uid": {g.uid},
"sAMAccountName": {g.cn},
"gidNumber": {g.gidNumber},
})
return &ldap.Entry{DN: g.dn, Attributes: attrs}
}

View File

@ -2,50 +2,18 @@ package ldap
import ( import (
"crypto/tls" "crypto/tls"
"net"
"sync" "sync"
"github.com/go-openapi/strfmt" "github.com/pires/go-proxyproto"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"goauthentik.io/api"
"goauthentik.io/internal/crypto" "goauthentik.io/internal/crypto"
"goauthentik.io/internal/outpost/ak" "goauthentik.io/internal/outpost/ak"
"goauthentik.io/internal/outpost/ldap/metrics"
"github.com/nmcclain/ldap" "github.com/nmcclain/ldap"
) )
const GroupObjectClass = "group"
const UserObjectClass = "user"
type ProviderInstance struct {
BaseDN string
UserDN string
VirtualGroupDN string
GroupDN string
appSlug string
flowSlug string
s *LDAPServer
log *log.Entry
tlsServerName *string
cert *tls.Certificate
outpostName string
searchAllowedGroups []*strfmt.UUID
boundUsersMutex sync.RWMutex
boundUsers map[string]UserFlags
uidStartNumber int32
gidStartNumber int32
}
type UserFlags struct {
UserInfo *api.User
UserPk int32
CanSearch bool
}
type LDAPServer struct { type LDAPServer struct {
s *ldap.Server s *ldap.Server
log *log.Entry log *log.Entry
@ -55,17 +23,6 @@ type LDAPServer struct {
providers []*ProviderInstance providers []*ProviderInstance
} }
type LDAPGroup struct {
dn string
cn string
uid string
gidNumber string
member []string
isSuperuser bool
isVirtualGroup bool
akAttributes interface{}
}
func NewServer(ac *ak.APIController) *LDAPServer { func NewServer(ac *ak.APIController) *LDAPServer {
s := ldap.NewServer() s := ldap.NewServer()
s.EnforceLDAP = true s.EnforceLDAP = true
@ -90,3 +47,54 @@ func NewServer(ac *ak.APIController) *LDAPServer {
func (ls *LDAPServer) Type() string { func (ls *LDAPServer) Type() string {
return "ldap" return "ldap"
} }
func (ls *LDAPServer) StartLDAPServer() error {
listen := "0.0.0.0:3389"
ln, err := net.Listen("tcp", listen)
if err != nil {
ls.log.WithField("listen", listen).WithError(err).Fatalf("FATAL: listen failed")
}
proxyListener := &proxyproto.Listener{Listener: ln}
defer proxyListener.Close()
ls.log.WithField("listen", listen).Info("Starting ldap server")
err = ls.s.Serve(proxyListener)
if err != nil {
return err
}
ls.log.Printf("closing %s", ln.Addr())
return ls.s.ListenAndServe(listen)
}
func (ls *LDAPServer) Start() error {
wg := sync.WaitGroup{}
wg.Add(3)
go func() {
defer wg.Done()
metrics.RunServer()
}()
go func() {
defer wg.Done()
err := ls.StartLDAPServer()
if err != nil {
panic(err)
}
}()
go func() {
defer wg.Done()
err := ls.StartLDAPTLSServer()
if err != nil {
panic(err)
}
}()
wg.Wait()
return nil
}
func (ls *LDAPServer) TimerFlowCacheExpiry() {
for _, p := range ls.providers {
ls.log.WithField("flow", p.flowSlug).Debug("Pre-heating flow cache")
p.binder.TimerFlowCacheExpiry()
}
}

View File

@ -0,0 +1,55 @@
package ldap
import (
"crypto/tls"
"net"
"github.com/pires/go-proxyproto"
)
func (ls *LDAPServer) getCertificates(info *tls.ClientHelloInfo) (*tls.Certificate, error) {
if len(ls.providers) == 1 {
if ls.providers[0].cert != nil {
ls.log.WithField("server-name", info.ServerName).Debug("We only have a single provider, using their cert")
return ls.providers[0].cert, nil
}
}
for _, provider := range ls.providers {
if provider.tlsServerName == &info.ServerName {
if provider.cert == nil {
ls.log.WithField("server-name", info.ServerName).Debug("Handler does not have a certificate")
return ls.defaultCert, nil
}
return provider.cert, nil
}
}
ls.log.WithField("server-name", info.ServerName).Debug("Fallback to default cert")
return ls.defaultCert, nil
}
func (ls *LDAPServer) StartLDAPTLSServer() error {
listen := "0.0.0.0:6636"
tlsConfig := &tls.Config{
MinVersion: tls.VersionTLS12,
MaxVersion: tls.VersionTLS12,
GetCertificate: ls.getCertificates,
}
ln, err := net.Listen("tcp", listen)
if err != nil {
ls.log.WithField("listen", listen).WithError(err).Fatalf("FATAL: listen failed")
}
proxyListener := &proxyproto.Listener{Listener: ln}
defer proxyListener.Close()
tln := tls.NewListener(proxyListener, tlsConfig)
ls.log.WithField("listen", listen).Info("Starting ldap tls server")
err = ls.s.Serve(tln)
if err != nil {
return err
}
ls.log.Printf("closing %s", ln.Addr())
return ls.s.ListenAndServe(listen)
}

View File

@ -2,23 +2,19 @@ package ldap
import ( import (
"context" "context"
"crypto/tls"
"errors" "errors"
"fmt" "fmt"
"net"
"strings" "strings"
"sync" "sync"
"github.com/go-openapi/strfmt" "github.com/go-openapi/strfmt"
"github.com/pires/go-proxyproto"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"goauthentik.io/internal/outpost/ldap/metrics" "goauthentik.io/api"
) directbind "goauthentik.io/internal/outpost/ldap/bind/direct"
"goauthentik.io/internal/outpost/ldap/constants"
const ( "goauthentik.io/internal/outpost/ldap/flags"
UsersOU = "users" directsearch "goauthentik.io/internal/outpost/ldap/search/direct"
GroupsOU = "groups" memorysearch "goauthentik.io/internal/outpost/ldap/search/memory"
VirtualGroupsOU = "virtual-groups"
) )
func (ls *LDAPServer) Refresh() error { func (ls *LDAPServer) Refresh() error {
@ -31,9 +27,9 @@ func (ls *LDAPServer) Refresh() error {
} }
providers := make([]*ProviderInstance, len(outposts.Results)) providers := make([]*ProviderInstance, len(outposts.Results))
for idx, provider := range outposts.Results { for idx, provider := range outposts.Results {
userDN := strings.ToLower(fmt.Sprintf("ou=%s,%s", UsersOU, *provider.BaseDn)) userDN := strings.ToLower(fmt.Sprintf("ou=%s,%s", constants.OUUsers, *provider.BaseDn))
groupDN := strings.ToLower(fmt.Sprintf("ou=%s,%s", GroupsOU, *provider.BaseDn)) groupDN := strings.ToLower(fmt.Sprintf("ou=%s,%s", constants.OUGroups, *provider.BaseDn))
virtualGroupDN := strings.ToLower(fmt.Sprintf("ou=%s,%s", VirtualGroupsOU, *provider.BaseDn)) virtualGroupDN := strings.ToLower(fmt.Sprintf("ou=%s,%s", constants.OUVirtualGroups, *provider.BaseDn))
logger := log.WithField("logger", "authentik.outpost.ldap").WithField("provider", provider.Name) logger := log.WithField("logger", "authentik.outpost.ldap").WithField("provider", provider.Name)
providers[idx] = &ProviderInstance{ providers[idx] = &ProviderInstance{
BaseDN: *provider.BaseDn, BaseDN: *provider.BaseDn,
@ -44,7 +40,7 @@ func (ls *LDAPServer) Refresh() error {
flowSlug: provider.BindFlowSlug, flowSlug: provider.BindFlowSlug,
searchAllowedGroups: []*strfmt.UUID{(*strfmt.UUID)(provider.SearchGroup.Get())}, searchAllowedGroups: []*strfmt.UUID{(*strfmt.UUID)(provider.SearchGroup.Get())},
boundUsersMutex: sync.RWMutex{}, boundUsersMutex: sync.RWMutex{},
boundUsers: make(map[string]UserFlags), boundUsers: make(map[string]flags.UserFlags),
s: ls, s: ls,
log: logger, log: logger,
tlsServerName: provider.TlsServerName, tlsServerName: provider.TlsServerName,
@ -60,79 +56,14 @@ func (ls *LDAPServer) Refresh() error {
} }
providers[idx].cert = ls.cs.Get(*kp) providers[idx].cert = ls.cs.Get(*kp)
} }
if *provider.SearchMode.Ptr() == api.SEARCHMODEENUM_CACHED {
providers[idx].searcher = memorysearch.NewMemorySearcher(providers[idx])
} else if *provider.SearchMode.Ptr() == api.SEARCHMODEENUM_DIRECT {
providers[idx].searcher = directsearch.NewDirectSearcher(providers[idx])
}
providers[idx].binder = directbind.NewDirectBinder(providers[idx])
} }
ls.providers = providers ls.providers = providers
ls.log.Info("Update providers") ls.log.Info("Update providers")
return nil return nil
} }
func (ls *LDAPServer) StartLDAPServer() error {
listen := "0.0.0.0:3389"
ln, err := net.Listen("tcp", listen)
if err != nil {
ls.log.Fatalf("FATAL: listen (%s) failed - %s", listen, err)
}
proxyListener := &proxyproto.Listener{Listener: ln}
defer proxyListener.Close()
ls.log.WithField("listen", listen).Info("Starting ldap server")
err = ls.s.Serve(proxyListener)
if err != nil {
return err
}
ls.log.Printf("closing %s", ln.Addr())
return ls.s.ListenAndServe(listen)
}
func (ls *LDAPServer) StartLDAPTLSServer() error {
listen := "0.0.0.0:6636"
tlsConfig := &tls.Config{
MinVersion: tls.VersionTLS12,
MaxVersion: tls.VersionTLS12,
GetCertificate: ls.getCertificates,
}
ln, err := net.Listen("tcp", listen)
if err != nil {
ls.log.Fatalf("FATAL: listen (%s) failed - %s", listen, err)
}
proxyListener := &proxyproto.Listener{Listener: ln}
defer proxyListener.Close()
tln := tls.NewListener(proxyListener, tlsConfig)
ls.log.WithField("listen", listen).Info("Starting ldap tls server")
err = ls.s.Serve(tln)
if err != nil {
return err
}
ls.log.Printf("closing %s", ln.Addr())
return ls.s.ListenAndServe(listen)
}
func (ls *LDAPServer) Start() error {
wg := sync.WaitGroup{}
wg.Add(3)
go func() {
defer wg.Done()
metrics.RunServer()
}()
go func() {
defer wg.Done()
err := ls.StartLDAPServer()
if err != nil {
panic(err)
}
}()
go func() {
defer wg.Done()
err := ls.StartLDAPTLSServer()
if err != nil {
panic(err)
}
}()
wg.Wait()
return nil
}

View File

@ -1,47 +1,21 @@
package ldap package ldap
import ( import (
"context"
"errors" "errors"
"net" "net"
"strings"
"github.com/getsentry/sentry-go" "github.com/getsentry/sentry-go"
goldap "github.com/go-ldap/ldap/v3" goldap "github.com/go-ldap/ldap/v3"
"github.com/google/uuid"
"github.com/nmcclain/ldap" "github.com/nmcclain/ldap"
"github.com/prometheus/client_golang/prometheus" "github.com/prometheus/client_golang/prometheus"
log "github.com/sirupsen/logrus" log "github.com/sirupsen/logrus"
"goauthentik.io/internal/outpost/ldap/metrics" "goauthentik.io/internal/outpost/ldap/metrics"
"goauthentik.io/internal/outpost/ldap/search"
"goauthentik.io/internal/utils" "goauthentik.io/internal/utils"
) )
type SearchRequest struct {
ldap.SearchRequest
BindDN string
id string
conn net.Conn
log *log.Entry
ctx context.Context
}
func (ls *LDAPServer) Search(bindDN string, searchReq ldap.SearchRequest, conn net.Conn) (ldap.ServerSearchResult, error) { func (ls *LDAPServer) Search(bindDN string, searchReq ldap.SearchRequest, conn net.Conn) (ldap.ServerSearchResult, error) {
span := sentry.StartSpan(context.TODO(), "authentik.providers.ldap.search", sentry.TransactionName("authentik.providers.ldap.search")) req, span := search.NewRequest(bindDN, searchReq, conn)
rid := uuid.New().String()
span.SetTag("request_uid", rid)
span.SetTag("user.username", bindDN)
span.SetTag("ak_filter", searchReq.Filter)
span.SetTag("ak_base_dn", searchReq.BaseDN)
bindDN = strings.ToLower(bindDN)
req := SearchRequest{
SearchRequest: searchReq,
BindDN: bindDN,
conn: conn,
log: ls.log.WithField("bindDN", bindDN).WithField("requestId", rid).WithField("scope", ldap.ScopeMap[searchReq.Scope]).WithField("client", utils.GetIP(conn.RemoteAddr())).WithField("filter", searchReq.Filter).WithField("baseDN", searchReq.BaseDN),
id: rid,
ctx: span.Context(),
}
defer func() { defer func() {
span.Finish() span.Finish()
@ -50,9 +24,9 @@ func (ls *LDAPServer) Search(bindDN string, searchReq ldap.SearchRequest, conn n
"type": "search", "type": "search",
"filter": req.Filter, "filter": req.Filter,
"dn": req.BindDN, "dn": req.BindDN,
"client": utils.GetIP(req.conn.RemoteAddr()), "client": utils.GetIP(conn.RemoteAddr()),
}).Observe(float64(span.EndTime.Sub(span.StartTime))) }).Observe(float64(span.EndTime.Sub(span.StartTime)))
req.log.WithField("took-ms", span.EndTime.Sub(span.StartTime).Milliseconds()).Info("Search request") req.Log().WithField("took-ms", span.EndTime.Sub(span.StartTime).Milliseconds()).Info("Search request")
}() }()
defer func() { defer func() {
@ -69,13 +43,13 @@ func (ls *LDAPServer) Search(bindDN string, searchReq ldap.SearchRequest, conn n
} }
bd, err := goldap.ParseDN(searchReq.BaseDN) bd, err := goldap.ParseDN(searchReq.BaseDN)
if err != nil { if err != nil {
req.log.WithError(err).Info("failed to parse basedn") req.Log().WithError(err).Info("failed to parse basedn")
return ldap.ServerSearchResult{ResultCode: ldap.LDAPResultOperationsError}, errors.New("invalid DN") return ldap.ServerSearchResult{ResultCode: ldap.LDAPResultOperationsError}, errors.New("invalid DN")
} }
for _, provider := range ls.providers { for _, provider := range ls.providers {
providerBase, _ := goldap.ParseDN(provider.BaseDN) providerBase, _ := goldap.ParseDN(provider.BaseDN)
if providerBase.AncestorOf(bd) || providerBase.Equal(bd) { if providerBase.AncestorOf(bd) || providerBase.Equal(bd) {
return provider.Search(req) return provider.searcher.Search(req)
} }
} }
return ldap.ServerSearchResult{ResultCode: ldap.LDAPResultOperationsError}, errors.New("no provider could handle request") return ldap.ServerSearchResult{ResultCode: ldap.LDAPResultOperationsError}, errors.New("no provider could handle request")

View File

@ -1,13 +1,14 @@
package ldap package direct
import ( import (
"fmt" "fmt"
"github.com/nmcclain/ldap" "github.com/nmcclain/ldap"
"goauthentik.io/internal/constants" "goauthentik.io/internal/constants"
"goauthentik.io/internal/outpost/ldap/search"
) )
func (pi *ProviderInstance) SearchBase(req SearchRequest, authz bool) (ldap.ServerSearchResult, error) { func (ds *DirectSearcher) SearchBase(req *search.Request, authz bool) (ldap.ServerSearchResult, error) {
dn := "" dn := ""
if authz { if authz {
dn = req.SearchRequest.BaseDN dn = req.SearchRequest.BaseDN
@ -19,7 +20,7 @@ func (pi *ProviderInstance) SearchBase(req SearchRequest, authz bool) (ldap.Serv
Attributes: []*ldap.EntryAttribute{ Attributes: []*ldap.EntryAttribute{
{ {
Name: "distinguishedName", Name: "distinguishedName",
Values: []string{pi.BaseDN}, Values: []string{ds.si.GetBaseDN()},
}, },
{ {
Name: "objectClass", Name: "objectClass",
@ -32,9 +33,9 @@ func (pi *ProviderInstance) SearchBase(req SearchRequest, authz bool) (ldap.Serv
{ {
Name: "namingContexts", Name: "namingContexts",
Values: []string{ Values: []string{
pi.BaseDN, ds.si.GetBaseDN(),
pi.GroupDN, ds.si.GetBaseUserDN(),
pi.UserDN, ds.si.GetBaseGroupDN(),
}, },
}, },
{ {

View File

@ -0,0 +1,219 @@
package direct
import (
"errors"
"fmt"
"strings"
"sync"
log "github.com/sirupsen/logrus"
"github.com/getsentry/sentry-go"
"github.com/nmcclain/ldap"
"github.com/prometheus/client_golang/prometheus"
"goauthentik.io/api"
"goauthentik.io/internal/outpost/ldap/constants"
"goauthentik.io/internal/outpost/ldap/flags"
"goauthentik.io/internal/outpost/ldap/group"
"goauthentik.io/internal/outpost/ldap/metrics"
"goauthentik.io/internal/outpost/ldap/search"
"goauthentik.io/internal/outpost/ldap/server"
"goauthentik.io/internal/outpost/ldap/utils"
)
type DirectSearcher struct {
si server.LDAPServerInstance
log *log.Entry
}
func NewDirectSearcher(si server.LDAPServerInstance) *DirectSearcher {
ds := &DirectSearcher{
si: si,
log: log.WithField("logger", "authentik.outpost.ldap.searcher.direct"),
}
ds.log.Info("initialised direct searcher")
return ds
}
func (ds *DirectSearcher) SearchMe(req *search.Request, f flags.UserFlags) (ldap.ServerSearchResult, error) {
if f.UserInfo == nil {
u, _, err := ds.si.GetAPIClient().CoreApi.CoreUsersRetrieve(req.Context(), f.UserPk).Execute()
if err != nil {
req.Log().WithError(err).Warning("Failed to get user info")
return ldap.ServerSearchResult{ResultCode: ldap.LDAPResultOperationsError}, fmt.Errorf("failed to get userinfo")
}
f.UserInfo = &u
}
entries := make([]*ldap.Entry, 1)
entries[0] = ds.si.UserEntry(*f.UserInfo)
return ldap.ServerSearchResult{Entries: entries, Referrals: []string{}, Controls: []ldap.Control{}, ResultCode: ldap.LDAPResultSuccess}, nil
}
func (ds *DirectSearcher) Search(req *search.Request) (ldap.ServerSearchResult, error) {
accsp := sentry.StartSpan(req.Context(), "authentik.providers.ldap.search.check_access")
baseDN := strings.ToLower("," + ds.si.GetBaseDN())
entries := []*ldap.Entry{}
filterEntity, err := ldap.GetFilterObjectClass(req.Filter)
if err != nil {
metrics.RequestsRejected.With(prometheus.Labels{
"outpost_name": ds.si.GetOutpostName(),
"type": "search",
"reason": "filter_parse_fail",
"dn": req.BindDN,
"client": req.RemoteAddr(),
}).Inc()
return ldap.ServerSearchResult{ResultCode: ldap.LDAPResultOperationsError}, fmt.Errorf("Search Error: error parsing filter: %s", req.Filter)
}
if len(req.BindDN) < 1 {
metrics.RequestsRejected.With(prometheus.Labels{
"outpost_name": ds.si.GetOutpostName(),
"type": "search",
"reason": "empty_bind_dn",
"dn": req.BindDN,
"client": req.RemoteAddr(),
}).Inc()
return ldap.ServerSearchResult{ResultCode: ldap.LDAPResultInsufficientAccessRights}, fmt.Errorf("Search Error: Anonymous BindDN not allowed %s", req.BindDN)
}
if !strings.HasSuffix(req.BindDN, baseDN) {
metrics.RequestsRejected.With(prometheus.Labels{
"outpost_name": ds.si.GetOutpostName(),
"type": "search",
"reason": "invalid_bind_dn",
"dn": req.BindDN,
"client": req.RemoteAddr(),
}).Inc()
return ldap.ServerSearchResult{ResultCode: ldap.LDAPResultInsufficientAccessRights}, fmt.Errorf("Search Error: BindDN %s not in our BaseDN %s", req.BindDN, ds.si.GetBaseDN())
}
flags, ok := ds.si.GetFlags(req.BindDN)
if !ok {
req.Log().Debug("User info not cached")
metrics.RequestsRejected.With(prometheus.Labels{
"outpost_name": ds.si.GetOutpostName(),
"type": "search",
"reason": "user_info_not_cached",
"dn": req.BindDN,
"client": req.RemoteAddr(),
}).Inc()
return ldap.ServerSearchResult{ResultCode: ldap.LDAPResultInsufficientAccessRights}, errors.New("access denied")
}
if req.Scope == ldap.ScopeBaseObject {
req.Log().Debug("base scope, showing domain info")
return ds.SearchBase(req, flags.CanSearch)
}
if !flags.CanSearch {
req.Log().Debug("User can't search, showing info about user")
return ds.SearchMe(req, flags)
}
accsp.Finish()
parsedFilter, err := ldap.CompileFilter(req.Filter)
if err != nil {
metrics.RequestsRejected.With(prometheus.Labels{
"outpost_name": ds.si.GetOutpostName(),
"type": "search",
"reason": "filter_parse_fail",
"dn": req.BindDN,
"client": req.RemoteAddr(),
}).Inc()
return ldap.ServerSearchResult{ResultCode: ldap.LDAPResultOperationsError}, fmt.Errorf("Search Error: error parsing filter: %s", req.Filter)
}
// Create a custom client to set additional headers
c := api.NewAPIClient(ds.si.GetAPIClient().GetConfig())
c.GetConfig().AddDefaultHeader("X-authentik-outpost-ldap-query", req.Filter)
switch filterEntity {
default:
metrics.RequestsRejected.With(prometheus.Labels{
"outpost_name": ds.si.GetOutpostName(),
"type": "search",
"reason": "unhandled_filter_type",
"dn": req.BindDN,
"client": req.RemoteAddr(),
}).Inc()
return ldap.ServerSearchResult{ResultCode: ldap.LDAPResultOperationsError}, fmt.Errorf("Search Error: unhandled filter type: %s [%s]", filterEntity, req.Filter)
case constants.OCGroupOfUniqueNames:
fallthrough
case constants.OCAKGroup:
fallthrough
case constants.OCAKVirtualGroup:
fallthrough
case constants.OCGroup:
wg := sync.WaitGroup{}
wg.Add(2)
gEntries := make([]*ldap.Entry, 0)
uEntries := make([]*ldap.Entry, 0)
go func() {
defer wg.Done()
gapisp := sentry.StartSpan(req.Context(), "authentik.providers.ldap.search.api_group")
searchReq, skip := utils.ParseFilterForGroup(c.CoreApi.CoreGroupsList(gapisp.Context()), parsedFilter, false)
if skip {
req.Log().Trace("Skip backend request")
return
}
groups, _, err := searchReq.Execute()
gapisp.Finish()
if err != nil {
req.Log().WithError(err).Warning("failed to get groups")
return
}
req.Log().WithField("count", len(groups.Results)).Trace("Got results from API")
for _, g := range groups.Results {
gEntries = append(gEntries, group.FromAPIGroup(g, ds.si).Entry())
}
}()
go func() {
defer wg.Done()
uapisp := sentry.StartSpan(req.Context(), "authentik.providers.ldap.search.api_user")
searchReq, skip := utils.ParseFilterForUser(c.CoreApi.CoreUsersList(uapisp.Context()), parsedFilter, false)
if skip {
req.Log().Trace("Skip backend request")
return
}
users, _, err := searchReq.Execute()
uapisp.Finish()
if err != nil {
req.Log().WithError(err).Warning("failed to get users")
return
}
for _, u := range users.Results {
uEntries = append(uEntries, group.FromAPIUser(u, ds.si).Entry())
}
}()
wg.Wait()
entries = append(gEntries, uEntries...)
case "":
fallthrough
case constants.OCOrgPerson:
fallthrough
case constants.OCInetOrgPerson:
fallthrough
case constants.OCAKUser:
fallthrough
case constants.OCUser:
uapisp := sentry.StartSpan(req.Context(), "authentik.providers.ldap.search.api_user")
searchReq, skip := utils.ParseFilterForUser(c.CoreApi.CoreUsersList(uapisp.Context()), parsedFilter, false)
if skip {
req.Log().Trace("Skip backend request")
return ldap.ServerSearchResult{Entries: entries, Referrals: []string{}, Controls: []ldap.Control{}, ResultCode: ldap.LDAPResultSuccess}, nil
}
users, _, err := searchReq.Execute()
uapisp.Finish()
if err != nil {
return ldap.ServerSearchResult{ResultCode: ldap.LDAPResultOperationsError}, fmt.Errorf("API Error: %s", err)
}
for _, u := range users.Results {
entries = append(entries, ds.si.UserEntry(u))
}
}
return ldap.ServerSearchResult{Entries: entries, Referrals: []string{}, Controls: []ldap.Control{}, ResultCode: ldap.LDAPResultSuccess}, nil
}

View File

@ -0,0 +1,54 @@
package memory
import (
"fmt"
"github.com/nmcclain/ldap"
"goauthentik.io/internal/constants"
"goauthentik.io/internal/outpost/ldap/search"
)
func (ms *MemorySearcher) SearchBase(req *search.Request, authz bool) (ldap.ServerSearchResult, error) {
dn := ""
if authz {
dn = req.SearchRequest.BaseDN
}
return ldap.ServerSearchResult{
Entries: []*ldap.Entry{
{
DN: dn,
Attributes: []*ldap.EntryAttribute{
{
Name: "distinguishedName",
Values: []string{ms.si.GetBaseDN()},
},
{
Name: "objectClass",
Values: []string{"top", "domain"},
},
{
Name: "supportedLDAPVersion",
Values: []string{"3"},
},
{
Name: "namingContexts",
Values: []string{
ms.si.GetBaseDN(),
ms.si.GetBaseUserDN(),
ms.si.GetBaseGroupDN(),
},
},
{
Name: "vendorName",
Values: []string{"goauthentik.io"},
},
{
Name: "vendorVersion",
Values: []string{fmt.Sprintf("authentik LDAP Outpost Version %s (build %s)", constants.VERSION, constants.BUILD())},
},
},
},
},
Referrals: []string{}, Controls: []ldap.Control{}, ResultCode: ldap.LDAPResultSuccess,
}, nil
}

View File

@ -0,0 +1,63 @@
package memory
import (
"context"
"goauthentik.io/api"
)
const pageSize = 100
func (ms *MemorySearcher) FetchUsers() []api.User {
fetchUsersOffset := func(page int) (*api.PaginatedUserList, error) {
users, _, err := ms.si.GetAPIClient().CoreApi.CoreUsersList(context.TODO()).Page(int32(page)).PageSize(pageSize).Execute()
if err != nil {
ms.log.WithError(err).Warning("failed to update users")
return nil, err
}
ms.log.WithField("page", page).Debug("fetched users")
return &users, nil
}
page := 1
users := make([]api.User, 0)
for {
apiUsers, err := fetchUsersOffset(page)
if err != nil {
return users
}
if apiUsers.Pagination.Next > 0 {
page += 1
} else {
break
}
users = append(users, apiUsers.Results...)
}
return users
}
func (ms *MemorySearcher) FetchGroups() []api.Group {
fetchGroupsOffset := func(page int) (*api.PaginatedGroupList, error) {
groups, _, err := ms.si.GetAPIClient().CoreApi.CoreGroupsList(context.TODO()).Page(int32(page)).PageSize(pageSize).Execute()
if err != nil {
ms.log.WithError(err).Warning("failed to update groups")
return nil, err
}
ms.log.WithField("page", page).Debug("fetched groups")
return &groups, nil
}
page := 1
groups := make([]api.Group, 0)
for {
apiGroups, err := fetchGroupsOffset(page)
if err != nil {
return groups
}
if apiGroups.Pagination.Next > 0 {
page += 1
} else {
break
}
groups = append(groups, apiGroups.Results...)
}
return groups
}

View File

@ -0,0 +1,182 @@
package memory
import (
"errors"
"fmt"
"strings"
"sync"
"github.com/getsentry/sentry-go"
"github.com/nmcclain/ldap"
"github.com/prometheus/client_golang/prometheus"
log "github.com/sirupsen/logrus"
"goauthentik.io/api"
"goauthentik.io/internal/outpost/ldap/constants"
"goauthentik.io/internal/outpost/ldap/flags"
"goauthentik.io/internal/outpost/ldap/group"
"goauthentik.io/internal/outpost/ldap/metrics"
"goauthentik.io/internal/outpost/ldap/search"
"goauthentik.io/internal/outpost/ldap/server"
)
type MemorySearcher struct {
si server.LDAPServerInstance
log *log.Entry
users []api.User
groups []api.Group
}
func NewMemorySearcher(si server.LDAPServerInstance) *MemorySearcher {
ms := &MemorySearcher{
si: si,
log: log.WithField("logger", "authentik.outpost.ldap.searcher.memory"),
}
ms.log.Info("initialised memory searcher")
ms.users = ms.FetchUsers()
ms.groups = ms.FetchGroups()
return ms
}
func (ms *MemorySearcher) SearchMe(req *search.Request, f flags.UserFlags) (ldap.ServerSearchResult, error) {
if f.UserInfo == nil {
for _, u := range ms.users {
if u.Pk == f.UserPk {
f.UserInfo = &u
}
}
if f.UserInfo == nil {
req.Log().WithField("pk", f.UserPk).Warning("User with pk is not in local cache")
return ldap.ServerSearchResult{ResultCode: ldap.LDAPResultOperationsError}, fmt.Errorf("failed to get userinfo")
}
}
entries := make([]*ldap.Entry, 1)
entries[0] = ms.si.UserEntry(*f.UserInfo)
return ldap.ServerSearchResult{Entries: entries, Referrals: []string{}, Controls: []ldap.Control{}, ResultCode: ldap.LDAPResultSuccess}, nil
}
func (ms *MemorySearcher) Search(req *search.Request) (ldap.ServerSearchResult, error) {
accsp := sentry.StartSpan(req.Context(), "authentik.providers.ldap.search.check_access")
baseDN := strings.ToLower("," + ms.si.GetBaseDN())
entries := []*ldap.Entry{}
filterEntity, err := ldap.GetFilterObjectClass(req.Filter)
if err != nil {
metrics.RequestsRejected.With(prometheus.Labels{
"outpost_name": ms.si.GetOutpostName(),
"type": "search",
"reason": "filter_parse_fail",
"dn": req.BindDN,
"client": req.RemoteAddr(),
}).Inc()
return ldap.ServerSearchResult{ResultCode: ldap.LDAPResultOperationsError}, fmt.Errorf("Search Error: error parsing filter: %s", req.Filter)
}
if len(req.BindDN) < 1 {
metrics.RequestsRejected.With(prometheus.Labels{
"outpost_name": ms.si.GetOutpostName(),
"type": "search",
"reason": "empty_bind_dn",
"dn": req.BindDN,
"client": req.RemoteAddr(),
}).Inc()
return ldap.ServerSearchResult{ResultCode: ldap.LDAPResultInsufficientAccessRights}, fmt.Errorf("Search Error: Anonymous BindDN not allowed %s", req.BindDN)
}
if !strings.HasSuffix(req.BindDN, baseDN) {
metrics.RequestsRejected.With(prometheus.Labels{
"outpost_name": ms.si.GetOutpostName(),
"type": "search",
"reason": "invalid_bind_dn",
"dn": req.BindDN,
"client": req.RemoteAddr(),
}).Inc()
return ldap.ServerSearchResult{ResultCode: ldap.LDAPResultInsufficientAccessRights}, fmt.Errorf("Search Error: BindDN %s not in our BaseDN %s", req.BindDN, ms.si.GetBaseDN())
}
flags, ok := ms.si.GetFlags(req.BindDN)
if !ok {
req.Log().Debug("User info not cached")
metrics.RequestsRejected.With(prometheus.Labels{
"outpost_name": ms.si.GetOutpostName(),
"type": "search",
"reason": "user_info_not_cached",
"dn": req.BindDN,
"client": req.RemoteAddr(),
}).Inc()
return ldap.ServerSearchResult{ResultCode: ldap.LDAPResultInsufficientAccessRights}, errors.New("access denied")
}
if req.Scope == ldap.ScopeBaseObject {
req.Log().Debug("base scope, showing domain info")
return ms.SearchBase(req, flags.CanSearch)
}
if !flags.CanSearch {
req.Log().Debug("User can't search, showing info about user")
return ms.SearchMe(req, flags)
}
accsp.Finish()
// parsedFilter, err := ldap.CompileFilter(req.Filter)
// if err != nil {
// metrics.RequestsRejected.With(prometheus.Labels{
// "outpost_name": ms.si.GetOutpostName(),
// "type": "search",
// "reason": "filter_parse_fail",
// "dn": req.BindDN,
// "client": req.RemoteAddr(),
// }).Inc()
// return ldap.ServerSearchResult{ResultCode: ldap.LDAPResultOperationsError}, fmt.Errorf("Search Error: error parsing filter: %s", req.Filter)
// }
switch filterEntity {
default:
metrics.RequestsRejected.With(prometheus.Labels{
"outpost_name": ms.si.GetOutpostName(),
"type": "search",
"reason": "unhandled_filter_type",
"dn": req.BindDN,
"client": req.RemoteAddr(),
}).Inc()
return ldap.ServerSearchResult{ResultCode: ldap.LDAPResultOperationsError}, fmt.Errorf("Search Error: unhandled filter type: %s [%s]", filterEntity, req.Filter)
case constants.OCGroupOfUniqueNames:
fallthrough
case constants.OCAKGroup:
fallthrough
case constants.OCAKVirtualGroup:
fallthrough
case constants.OCGroup:
wg := sync.WaitGroup{}
wg.Add(2)
gEntries := make([]*ldap.Entry, 0)
uEntries := make([]*ldap.Entry, 0)
go func() {
defer wg.Done()
for _, g := range ms.groups {
gEntries = append(gEntries, group.FromAPIGroup(g, ms.si).Entry())
}
}()
go func() {
defer wg.Done()
for _, u := range ms.users {
uEntries = append(uEntries, group.FromAPIUser(u, ms.si).Entry())
}
}()
wg.Wait()
entries = append(gEntries, uEntries...)
case "":
fallthrough
case constants.OCOrgPerson:
fallthrough
case constants.OCInetOrgPerson:
fallthrough
case constants.OCAKUser:
fallthrough
case constants.OCUser:
for _, u := range ms.users {
entries = append(entries, ms.si.UserEntry(u))
}
}
return ldap.ServerSearchResult{Entries: entries, Referrals: []string{}, Controls: []ldap.Control{}, ResultCode: ldap.LDAPResultSuccess}, nil
}

View File

@ -0,0 +1,53 @@
package search
import (
"context"
"net"
"strings"
"github.com/getsentry/sentry-go"
"github.com/google/uuid"
"github.com/nmcclain/ldap"
log "github.com/sirupsen/logrus"
"goauthentik.io/internal/utils"
)
type Request struct {
ldap.SearchRequest
BindDN string
log *log.Entry
id string
conn net.Conn
ctx context.Context
}
func NewRequest(bindDN string, searchReq ldap.SearchRequest, conn net.Conn) (*Request, *sentry.Span) {
rid := uuid.New().String()
bindDN = strings.ToLower(bindDN)
span := sentry.StartSpan(context.TODO(), "authentik.providers.ldap.search", sentry.TransactionName("authentik.providers.ldap.search"))
span.SetTag("request_uid", rid)
span.SetTag("user.username", bindDN)
span.SetTag("ak_filter", searchReq.Filter)
span.SetTag("ak_base_dn", searchReq.BaseDN)
return &Request{
SearchRequest: searchReq,
BindDN: bindDN,
conn: conn,
log: log.WithField("bindDN", bindDN).WithField("requestId", rid).WithField("scope", ldap.ScopeMap[searchReq.Scope]).WithField("client", utils.GetIP(conn.RemoteAddr())).WithField("filter", searchReq.Filter).WithField("baseDN", searchReq.BaseDN),
id: rid,
ctx: span.Context(),
}, span
}
func (r *Request) Context() context.Context {
return r.ctx
}
func (r *Request) Log() *log.Entry {
return r.log
}
func (r *Request) RemoteAddr() string {
return utils.GetIP(r.conn.RemoteAddr())
}

View File

@ -0,0 +1,7 @@
package search
import "github.com/nmcclain/ldap"
type Searcher interface {
Search(req *Request) (ldap.ServerSearchResult, error)
}

View File

@ -0,0 +1,35 @@
package server
import (
"github.com/go-openapi/strfmt"
"github.com/nmcclain/ldap"
"goauthentik.io/api"
"goauthentik.io/internal/outpost/ldap/flags"
)
type LDAPServerInstance interface {
GetAPIClient() *api.APIClient
GetOutpostName() string
GetFlowSlug() string
GetAppSlug() string
GetSearchAllowedGroups() []*strfmt.UUID
UserEntry(u api.User) *ldap.Entry
GetBaseDN() string
GetBaseGroupDN() string
GetBaseUserDN() string
GetUserDN(string) string
GetGroupDN(string) string
GetVirtualGroupDN(string) string
GetUidNumber(api.User) string
GetGidNumber(api.Group) string
UsersForGroup(api.Group) []string
GetFlags(string) (flags.UserFlags, bool)
SetFlags(string, flags.UserFlags)
}

View File

@ -3,70 +3,12 @@ package ldap
import ( import (
"fmt" "fmt"
"math/big" "math/big"
"reflect"
"strconv" "strconv"
"strings" "strings"
"github.com/nmcclain/ldap"
log "github.com/sirupsen/logrus"
"goauthentik.io/api" "goauthentik.io/api"
) )
func BoolToString(in bool) string {
if in {
return "true"
}
return "false"
}
func ldapResolveTypeSingle(in interface{}) *string {
switch t := in.(type) {
case string:
return &t
case *string:
return t
case bool:
s := BoolToString(t)
return &s
case *bool:
s := BoolToString(*t)
return &s
default:
log.WithField("type", reflect.TypeOf(in).String()).Warning("Type can't be mapped to LDAP yet")
return nil
}
}
func AKAttrsToLDAP(attrs interface{}) []*ldap.EntryAttribute {
attrList := []*ldap.EntryAttribute{}
if attrs == nil {
return attrList
}
a := attrs.(*map[string]interface{})
for attrKey, attrValue := range *a {
entry := &ldap.EntryAttribute{Name: attrKey}
switch t := attrValue.(type) {
case []string:
entry.Values = t
case *[]string:
entry.Values = *t
case []interface{}:
entry.Values = make([]string, len(t))
for idx, v := range t {
v := ldapResolveTypeSingle(v)
entry.Values[idx] = *v
}
default:
v := ldapResolveTypeSingle(t)
if v != nil {
entry.Values = []string{*v}
}
}
attrList = append(attrList, entry)
}
return attrList
}
func (pi *ProviderInstance) GroupsForUser(user api.User) []string { func (pi *ProviderInstance) GroupsForUser(user api.User) []string {
groups := make([]string, len(user.Groups)) groups := make([]string, len(user.Groups))
for i, group := range user.GroupsObj { for i, group := range user.GroupsObj {
@ -83,32 +25,6 @@ func (pi *ProviderInstance) UsersForGroup(group api.Group) []string {
return users return users
} }
func (pi *ProviderInstance) APIGroupToLDAPGroup(g api.Group) LDAPGroup {
return LDAPGroup{
dn: pi.GetGroupDN(g.Name),
cn: g.Name,
uid: string(g.Pk),
gidNumber: pi.GetGidNumber(g),
member: pi.UsersForGroup(g),
isVirtualGroup: false,
isSuperuser: *g.IsSuperuser,
akAttributes: g.Attributes,
}
}
func (pi *ProviderInstance) APIUserToLDAPGroup(u api.User) LDAPGroup {
return LDAPGroup{
dn: pi.GetVirtualGroupDN(u.Username),
cn: u.Username,
uid: u.Uid,
gidNumber: pi.GetUidNumber(u),
member: []string{pi.GetUserDN(u.Username)},
isVirtualGroup: true,
isSuperuser: false,
akAttributes: nil,
}
}
func (pi *ProviderInstance) GetUserDN(user string) string { func (pi *ProviderInstance) GetUserDN(user string) string {
return fmt.Sprintf("cn=%s,%s", user, pi.UserDN) return fmt.Sprintf("cn=%s,%s", user, pi.UserDN)
} }
@ -155,26 +71,3 @@ func (pi *ProviderInstance) GetRIDForGroup(uid string) int32 {
return int32(gid) return int32(gid)
} }
func (pi *ProviderInstance) ensureAttributes(attrs []*ldap.EntryAttribute, shouldHave map[string][]string) []*ldap.EntryAttribute {
for name, values := range shouldHave {
attrs = pi.mustHaveAttribute(attrs, name, values)
}
return attrs
}
func (pi *ProviderInstance) mustHaveAttribute(attrs []*ldap.EntryAttribute, name string, value []string) []*ldap.EntryAttribute {
shouldSet := true
for _, attr := range attrs {
if attr.Name == name {
shouldSet = false
}
}
if shouldSet {
return append(attrs, &ldap.EntryAttribute{
Name: name,
Values: value,
})
}
return attrs
}

View File

@ -0,0 +1,86 @@
package utils
import (
"reflect"
"github.com/nmcclain/ldap"
log "github.com/sirupsen/logrus"
)
func BoolToString(in bool) string {
if in {
return "true"
}
return "false"
}
func ldapResolveTypeSingle(in interface{}) *string {
switch t := in.(type) {
case string:
return &t
case *string:
return t
case bool:
s := BoolToString(t)
return &s
case *bool:
s := BoolToString(*t)
return &s
default:
log.WithField("type", reflect.TypeOf(in).String()).Warning("Type can't be mapped to LDAP yet")
return nil
}
}
func AKAttrsToLDAP(attrs interface{}) []*ldap.EntryAttribute {
attrList := []*ldap.EntryAttribute{}
if attrs == nil {
return attrList
}
a := attrs.(*map[string]interface{})
for attrKey, attrValue := range *a {
entry := &ldap.EntryAttribute{Name: attrKey}
switch t := attrValue.(type) {
case []string:
entry.Values = t
case *[]string:
entry.Values = *t
case []interface{}:
entry.Values = make([]string, len(t))
for idx, v := range t {
v := ldapResolveTypeSingle(v)
entry.Values[idx] = *v
}
default:
v := ldapResolveTypeSingle(t)
if v != nil {
entry.Values = []string{*v}
}
}
attrList = append(attrList, entry)
}
return attrList
}
func EnsureAttributes(attrs []*ldap.EntryAttribute, shouldHave map[string][]string) []*ldap.EntryAttribute {
for name, values := range shouldHave {
attrs = MustHaveAttribute(attrs, name, values)
}
return attrs
}
func MustHaveAttribute(attrs []*ldap.EntryAttribute, name string, value []string) []*ldap.EntryAttribute {
shouldSet := true
for _, attr := range attrs {
if attr.Name == name {
shouldSet = false
}
}
if shouldSet {
return append(attrs, &ldap.EntryAttribute{
Name: name,
Values: value,
})
}
return attrs
}

View File

@ -1,19 +1,20 @@
package ldap package utils
import ( import (
goldap "github.com/go-ldap/ldap/v3" goldap "github.com/go-ldap/ldap/v3"
ber "github.com/nmcclain/asn1-ber" ber "github.com/nmcclain/asn1-ber"
"github.com/nmcclain/ldap" "github.com/nmcclain/ldap"
"goauthentik.io/api" "goauthentik.io/api"
"goauthentik.io/internal/outpost/ldap/constants"
) )
func parseFilterForGroup(req api.ApiCoreGroupsListRequest, f *ber.Packet, skip bool) (api.ApiCoreGroupsListRequest, bool) { func ParseFilterForGroup(req api.ApiCoreGroupsListRequest, f *ber.Packet, skip bool) (api.ApiCoreGroupsListRequest, bool) {
switch f.Tag { switch f.Tag {
case ldap.FilterEqualityMatch: case ldap.FilterEqualityMatch:
return parseFilterForGroupSingle(req, f) return parseFilterForGroupSingle(req, f)
case ldap.FilterAnd: case ldap.FilterAnd:
for _, child := range f.Children { for _, child := range f.Children {
r, s := parseFilterForGroup(req, child, skip) r, s := ParseFilterForGroup(req, child, skip)
skip = skip || s skip = skip || s
req = r req = r
} }
@ -53,7 +54,7 @@ func parseFilterForGroupSingle(req api.ApiCoreGroupsListRequest, f *ber.Packet)
username := userDN.RDNs[0].Attributes[0].Value username := userDN.RDNs[0].Attributes[0].Value
// If the DN's first ou is virtual-groups, ignore this filter // If the DN's first ou is virtual-groups, ignore this filter
if len(userDN.RDNs) > 1 { if len(userDN.RDNs) > 1 {
if userDN.RDNs[1].Attributes[0].Value == VirtualGroupsOU || userDN.RDNs[1].Attributes[0].Value == GroupsOU { if userDN.RDNs[1].Attributes[0].Value == constants.OUVirtualGroups || userDN.RDNs[1].Attributes[0].Value == constants.OUGroups {
// Since we know we're not filtering anything, skip this request // Since we know we're not filtering anything, skip this request
return req, true return req, true
} }

View File

@ -1,19 +1,20 @@
package ldap package utils
import ( import (
goldap "github.com/go-ldap/ldap/v3" goldap "github.com/go-ldap/ldap/v3"
ber "github.com/nmcclain/asn1-ber" ber "github.com/nmcclain/asn1-ber"
"github.com/nmcclain/ldap" "github.com/nmcclain/ldap"
"goauthentik.io/api" "goauthentik.io/api"
"goauthentik.io/internal/outpost/ldap/constants"
) )
func parseFilterForUser(req api.ApiCoreUsersListRequest, f *ber.Packet, skip bool) (api.ApiCoreUsersListRequest, bool) { func ParseFilterForUser(req api.ApiCoreUsersListRequest, f *ber.Packet, skip bool) (api.ApiCoreUsersListRequest, bool) {
switch f.Tag { switch f.Tag {
case ldap.FilterEqualityMatch: case ldap.FilterEqualityMatch:
return parseFilterForUserSingle(req, f) return parseFilterForUserSingle(req, f)
case ldap.FilterAnd: case ldap.FilterAnd:
for _, child := range f.Children { for _, child := range f.Children {
r, s := parseFilterForUser(req, child, skip) r, s := ParseFilterForUser(req, child, skip)
skip = skip || s skip = skip || s
req = r req = r
} }
@ -58,7 +59,7 @@ func parseFilterForUserSingle(req api.ApiCoreUsersListRequest, f *ber.Packet) (a
name := groupDN.RDNs[0].Attributes[0].Value name := groupDN.RDNs[0].Attributes[0].Value
// If the DN's first ou is virtual-groups, ignore this filter // If the DN's first ou is virtual-groups, ignore this filter
if len(groupDN.RDNs) > 1 { if len(groupDN.RDNs) > 1 {
if groupDN.RDNs[1].Attributes[0].Value == UsersOU || groupDN.RDNs[1].Attributes[0].Value == VirtualGroupsOU { if groupDN.RDNs[1].Attributes[0].Value == constants.OUUsers || groupDN.RDNs[1].Attributes[0].Value == constants.OUVirtualGroups {
// Since we know we're not filtering anything, skip this request // Since we know we're not filtering anything, skip this request
return req, true return req, true
} }

View File

@ -22035,6 +22035,8 @@ components:
generated from the group.Pk to make sure that the numbers aren't too low generated from the group.Pk to make sure that the numbers aren't too low
for POSIX groups. Default is 4000 to ensure that we don't collide with for POSIX groups. Default is 4000 to ensure that we don't collide with
local groups or users primary groups gidNumber local groups or users primary groups gidNumber
search_mode:
$ref: '#/components/schemas/SearchModeEnum'
required: required:
- application_slug - application_slug
- bind_flow_slug - bind_flow_slug
@ -22173,6 +22175,8 @@ components:
items: items:
type: string type: string
readOnly: true readOnly: true
search_mode:
$ref: '#/components/schemas/SearchModeEnum'
required: required:
- assigned_application_name - assigned_application_name
- assigned_application_slug - assigned_application_slug
@ -22228,6 +22232,8 @@ components:
generated from the group.Pk to make sure that the numbers aren't too low generated from the group.Pk to make sure that the numbers aren't too low
for POSIX groups. Default is 4000 to ensure that we don't collide with for POSIX groups. Default is 4000 to ensure that we don't collide with
local groups or users primary groups gidNumber local groups or users primary groups gidNumber
search_mode:
$ref: '#/components/schemas/SearchModeEnum'
required: required:
- authorization_flow - authorization_flow
- name - name
@ -26849,6 +26855,8 @@ components:
generated from the group.Pk to make sure that the numbers aren't too low generated from the group.Pk to make sure that the numbers aren't too low
for POSIX groups. Default is 4000 to ensure that we don't collide with for POSIX groups. Default is 4000 to ensure that we don't collide with
local groups or users primary groups gidNumber local groups or users primary groups gidNumber
search_mode:
$ref: '#/components/schemas/SearchModeEnum'
PatchedLDAPSourceRequest: PatchedLDAPSourceRequest:
type: object type: object
description: LDAP Source Serializer description: LDAP Source Serializer
@ -29410,6 +29418,11 @@ components:
- expression - expression
- name - name
- scope_name - scope_name
SearchModeEnum:
enum:
- direct
- cached
type: string
ServiceConnection: ServiceConnection:
type: object type: object
description: ServiceConnection Serializer description: ServiceConnection Serializer

View File

@ -625,6 +625,10 @@ msgstr "Cached flows"
msgid "Cached policies" msgid "Cached policies"
msgstr "Cached policies" msgstr "Cached policies"
#: src/pages/providers/ldap/LDAPProviderForm.ts
msgid "Cached querying, the outpost holds all users and groups in-memory and will refresh every 5 Minutes."
msgstr "Cached querying, the outpost holds all users and groups in-memory and will refresh every 5 Minutes."
#: src/pages/sources/oauth/OAuthSourceViewPage.ts #: src/pages/sources/oauth/OAuthSourceViewPage.ts
msgid "Callback URL" msgid "Callback URL"
msgstr "Callback URL" msgstr "Callback URL"
@ -913,6 +917,10 @@ msgstr "Configure how the flow executor should handle an invalid response to a c
msgid "Configure how the issuer field of the ID Token should be filled." msgid "Configure how the issuer field of the ID Token should be filled."
msgstr "Configure how the issuer field of the ID Token should be filled." msgstr "Configure how the issuer field of the ID Token should be filled."
#: src/pages/providers/ldap/LDAPProviderForm.ts
msgid "Configure how the outpost queries the core authentik server's users."
msgstr "Configure how the outpost queries the core authentik server's users."
#: #:
#: #:
#~ msgid "Configure settings relevant to your user profile." #~ msgid "Configure settings relevant to your user profile."
@ -1416,6 +1424,10 @@ msgstr "Digest algorithm"
msgid "Digits" msgid "Digits"
msgstr "Digits" msgstr "Digits"
#: src/pages/providers/ldap/LDAPProviderForm.ts
msgid "Direct querying, always returns the latest data, but slower than cached querying."
msgstr "Direct querying, always returns the latest data, but slower than cached querying."
#: #:
#: #:
#~ msgid "Disable" #~ msgid "Disable"
@ -3802,6 +3814,10 @@ msgstr "Score"
msgid "Search group" msgid "Search group"
msgstr "Search group" msgstr "Search group"
#: src/pages/providers/ldap/LDAPProviderForm.ts
msgid "Search mode"
msgstr "Search mode"
#: src/elements/table/TableSearch.ts #: src/elements/table/TableSearch.ts
#: src/user/LibraryPage.ts #: src/user/LibraryPage.ts
msgid "Search..." msgid "Search..."

View File

@ -627,6 +627,10 @@ msgstr "Flux mis en cache"
msgid "Cached policies" msgid "Cached policies"
msgstr "Politiques mises en cache" msgstr "Politiques mises en cache"
#: src/pages/providers/ldap/LDAPProviderForm.ts
msgid "Cached querying, the outpost holds all users and groups in-memory and will refresh every 5 Minutes."
msgstr ""
#: src/pages/sources/oauth/OAuthSourceViewPage.ts #: src/pages/sources/oauth/OAuthSourceViewPage.ts
msgid "Callback URL" msgid "Callback URL"
msgstr "URL de rappel" msgstr "URL de rappel"
@ -913,6 +917,10 @@ msgstr "Configure comment l'exécuteur de flux gère une réponse invalide à un
msgid "Configure how the issuer field of the ID Token should be filled." msgid "Configure how the issuer field of the ID Token should be filled."
msgstr "Configure comment le champ émetteur du jeton ID sera rempli." msgstr "Configure comment le champ émetteur du jeton ID sera rempli."
#: src/pages/providers/ldap/LDAPProviderForm.ts
msgid "Configure how the outpost queries the core authentik server's users."
msgstr ""
#~ msgid "Configure settings relevant to your user profile." #~ msgid "Configure settings relevant to your user profile."
#~ msgstr "Configure les paramètre applicable à votre profil." #~ msgstr "Configure les paramètre applicable à votre profil."
@ -1406,6 +1414,10 @@ msgstr "Algorithme d'empreinte"
msgid "Digits" msgid "Digits"
msgstr "Chiffres" msgstr "Chiffres"
#: src/pages/providers/ldap/LDAPProviderForm.ts
msgid "Direct querying, always returns the latest data, but slower than cached querying."
msgstr ""
#~ msgid "Disable" #~ msgid "Disable"
#~ msgstr "Désactiver" #~ msgstr "Désactiver"
@ -3770,6 +3782,10 @@ msgstr "Note"
msgid "Search group" msgid "Search group"
msgstr "Rechercher un groupe" msgstr "Rechercher un groupe"
#: src/pages/providers/ldap/LDAPProviderForm.ts
msgid "Search mode"
msgstr ""
#: src/elements/table/TableSearch.ts #: src/elements/table/TableSearch.ts
#: src/user/LibraryPage.ts #: src/user/LibraryPage.ts
msgid "Search..." msgid "Search..."

View File

@ -621,6 +621,10 @@ msgstr ""
msgid "Cached policies" msgid "Cached policies"
msgstr "" msgstr ""
#: src/pages/providers/ldap/LDAPProviderForm.ts
msgid "Cached querying, the outpost holds all users and groups in-memory and will refresh every 5 Minutes."
msgstr ""
#: src/pages/sources/oauth/OAuthSourceViewPage.ts #: src/pages/sources/oauth/OAuthSourceViewPage.ts
msgid "Callback URL" msgid "Callback URL"
msgstr "" msgstr ""
@ -907,6 +911,10 @@ msgstr ""
msgid "Configure how the issuer field of the ID Token should be filled." msgid "Configure how the issuer field of the ID Token should be filled."
msgstr "" msgstr ""
#: src/pages/providers/ldap/LDAPProviderForm.ts
msgid "Configure how the outpost queries the core authentik server's users."
msgstr ""
#: #:
#: #:
#~ msgid "Configure settings relevant to your user profile." #~ msgid "Configure settings relevant to your user profile."
@ -1408,6 +1416,10 @@ msgstr ""
msgid "Digits" msgid "Digits"
msgstr "" msgstr ""
#: src/pages/providers/ldap/LDAPProviderForm.ts
msgid "Direct querying, always returns the latest data, but slower than cached querying."
msgstr ""
#: #:
#: #:
#~ msgid "Disable" #~ msgid "Disable"
@ -3794,6 +3806,10 @@ msgstr ""
msgid "Search group" msgid "Search group"
msgstr "" msgstr ""
#: src/pages/providers/ldap/LDAPProviderForm.ts
msgid "Search mode"
msgstr ""
#: src/elements/table/TableSearch.ts #: src/elements/table/TableSearch.ts
#: src/user/LibraryPage.ts #: src/user/LibraryPage.ts
msgid "Search..." msgid "Search..."

View File

@ -12,6 +12,7 @@ import {
FlowsInstancesListDesignationEnum, FlowsInstancesListDesignationEnum,
LDAPProvider, LDAPProvider,
ProvidersApi, ProvidersApi,
SearchModeEnum,
} from "@goauthentik/api"; } from "@goauthentik/api";
import { DEFAULT_CONFIG, tenant } from "../../../api/Config"; import { DEFAULT_CONFIG, tenant } from "../../../api/Config";
@ -118,6 +119,25 @@ export class LDAPProviderFormPage extends ModelForm<LDAPProvider, number> {
${t`Users in the selected group can do search queries. If no group is selected, no LDAP Searches are allowed.`} ${t`Users in the selected group can do search queries. If no group is selected, no LDAP Searches are allowed.`}
</p> </p>
</ak-form-element-horizontal> </ak-form-element-horizontal>
<ak-form-element-horizontal label=${t`Search mode`} name="searchMode">
<select class="pf-c-form-control">
<option
value=""
?selected=${this.instance?.searchMode === SearchModeEnum.Cached}
>
${t`Cached querying, the outpost holds all users and groups in-memory and will refresh every 5 Minutes.`}
</option>
<option
value=""
?selected=${this.instance?.searchMode === SearchModeEnum.Direct}
>
${t`Direct querying, always returns the latest data, but slower than cached querying.`}
</option>
</select>
<p class="pf-c-form__helper-text">
${t`Configure how the outpost queries the core authentik server's users.`}
</p>
</ak-form-element-horizontal>
<ak-form-group .expanded=${true}> <ak-form-group .expanded=${true}>
<span slot="header"> ${t`Protocol settings`} </span> <span slot="header"> ${t`Protocol settings`} </span>