providers/ldap: add unbind flow execution (#4484)

add unbind flow execution

Signed-off-by: Jens Langhammer <jens@goauthentik.io>

Signed-off-by: Jens Langhammer <jens@goauthentik.io>
This commit is contained in:
Jens L 2023-01-23 20:36:30 +01:00 committed by GitHub
parent b2d272bf6f
commit a9b32e2f97
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
13 changed files with 267 additions and 132 deletions

View file

@ -1,8 +1,11 @@
package ak
import "context"
type Outpost interface {
Start() error
Stop() error
Refresh() error
TimerFlowCacheExpiry()
TimerFlowCacheExpiry(context.Context)
Type() string
}

View file

@ -1,15 +1,16 @@
package ak
import (
"context"
"time"
)
func (a *APIController) startPeriodicalTasks() {
go a.Server.TimerFlowCacheExpiry()
go func() {
for range time.Tick(time.Duration(a.GlobalConfig.CacheTimeoutFlows) * time.Second) {
a.logger.WithField("timer", "cache-timeout").Debug("Running periodical tasks")
a.Server.TimerFlowCacheExpiry()
}
}()
ctx, canc := context.WithCancel(context.Background())
defer canc()
go a.Server.TimerFlowCacheExpiry(ctx)
for range time.Tick(time.Duration(a.GlobalConfig.CacheTimeoutFlows) * time.Second) {
a.logger.WithField("timer", "cache-timeout").Debug("Running periodical tasks")
a.Server.TimerFlowCacheExpiry(ctx)
}
}

View file

@ -143,6 +143,10 @@ func (fe *FlowExecutor) GetSession() *http.Cookie {
return fe.session
}
func (fe *FlowExecutor) SetSession(s *http.Cookie) {
fe.session = s
}
// WarmUp Ensure authentik's flow cache is warmed up
func (fe *FlowExecutor) WarmUp() error {
gcsp := sentry.StartSpan(fe.Context, "authentik.outposts.flow_executor.get_challenge")

View file

@ -1,9 +1,14 @@
package bind
import "github.com/nmcclain/ldap"
import (
"context"
"github.com/nmcclain/ldap"
)
type Binder interface {
GetUsername(string) (string, error)
Bind(username string, req *Request) (ldap.LDAPResultCode, error)
TimerFlowCacheExpiry()
Unbind(username string, req *Request) (ldap.LDAPResultCode, error)
TimerFlowCacheExpiry(context.Context)
}

View file

@ -0,0 +1,98 @@
package direct
import (
"context"
"github.com/getsentry/sentry-go"
"github.com/nmcclain/ldap"
"github.com/prometheus/client_golang/prometheus"
log "github.com/sirupsen/logrus"
"goauthentik.io/internal/outpost/flow"
"goauthentik.io/internal/outpost/ldap/bind"
"goauthentik.io/internal/outpost/ldap/flags"
"goauthentik.io/internal/outpost/ldap/metrics"
)
func (db *DirectBinder) Bind(username string, req *bind.Request) (ldap.LDAPResultCode, error) {
fe := flow.NewFlowExecutor(req.Context(), db.si.GetAuthenticationFlowSlug(), db.si.GetAPIClient().GetConfig(), log.Fields{
"bindDN": req.BindDN,
"client": req.RemoteAddr(),
"requestId": req.ID(),
})
fe.DelegateClientIP(req.RemoteAddr())
fe.Params.Add("goauthentik.io/outpost/ldap", "true")
fe.Answers[flow.StageIdentification] = username
fe.Answers[flow.StagePassword] = req.BindPW
passed, err := fe.Execute()
flags := flags.UserFlags{
Session: fe.GetSession(),
}
db.si.SetFlags(req.BindDN, &flags)
if err != nil {
metrics.RequestsRejected.With(prometheus.Labels{
"outpost_name": db.si.GetOutpostName(),
"type": "bind",
"reason": "flow_error",
"app": db.si.GetAppSlug(),
}).Inc()
req.Log().WithError(err).Warning("failed to execute flow")
return ldap.LDAPResultInvalidCredentials, nil
}
if !passed {
metrics.RequestsRejected.With(prometheus.Labels{
"outpost_name": db.si.GetOutpostName(),
"type": "bind",
"reason": "invalid_credentials",
"app": db.si.GetAppSlug(),
}).Inc()
req.Log().Info("Invalid credentials")
return ldap.LDAPResultInvalidCredentials, nil
}
access, err := fe.CheckApplicationAccess(db.si.GetAppSlug())
if !access {
req.Log().Info("Access denied for user")
metrics.RequestsRejected.With(prometheus.Labels{
"outpost_name": db.si.GetOutpostName(),
"type": "bind",
"reason": "access_denied",
"app": db.si.GetAppSlug(),
}).Inc()
return ldap.LDAPResultInsufficientAccessRights, nil
}
if err != nil {
metrics.RequestsRejected.With(prometheus.Labels{
"outpost_name": db.si.GetOutpostName(),
"type": "bind",
"reason": "access_check_fail",
"app": db.si.GetAppSlug(),
}).Inc()
req.Log().WithError(err).Warning("failed to check access")
return ldap.LDAPResultOperationsError, nil
}
req.Log().Info("User has access")
uisp := sentry.StartSpan(req.Context(), "authentik.providers.ldap.bind.user_info")
// Get user info to store in context
userInfo, _, err := fe.ApiClient().CoreApi.CoreUsersMeRetrieve(context.Background()).Execute()
if err != nil {
metrics.RequestsRejected.With(prometheus.Labels{
"outpost_name": db.si.GetOutpostName(),
"type": "bind",
"reason": "user_info_fail",
"app": db.si.GetAppSlug(),
}).Inc()
req.Log().WithError(err).Warning("failed to get user info")
return ldap.LDAPResultOperationsError, nil
}
cs := db.SearchAccessCheck(userInfo.User)
flags.UserPk = userInfo.User.Pk
flags.CanSearch = cs != nil
db.si.SetFlags(req.BindDN, &flags)
if flags.CanSearch {
req.Log().WithField("group", cs).Info("Allowed access to search")
}
uisp.Finish()
return ldap.LDAPResultSuccess, nil
}

View file

@ -5,16 +5,10 @@ import (
"errors"
"strings"
"github.com/getsentry/sentry-go"
goldap "github.com/go-ldap/ldap/v3"
"github.com/nmcclain/ldap"
"github.com/prometheus/client_golang/prometheus"
log "github.com/sirupsen/logrus"
"goauthentik.io/api/v3"
"goauthentik.io/internal/outpost/flow"
"goauthentik.io/internal/outpost/ldap/bind"
"goauthentik.io/internal/outpost/ldap/flags"
"goauthentik.io/internal/outpost/ldap/metrics"
"goauthentik.io/internal/outpost/ldap/server"
"goauthentik.io/internal/outpost/ldap/utils"
)
@ -53,90 +47,6 @@ func (db *DirectBinder) GetUsername(dn string) (string, error) {
return "", errors.New("failed to find cn")
}
func (db *DirectBinder) Bind(username string, req *bind.Request) (ldap.LDAPResultCode, error) {
fe := flow.NewFlowExecutor(req.Context(), db.si.GetFlowSlug(), db.si.GetAPIClient().GetConfig(), log.Fields{
"bindDN": req.BindDN,
"client": req.RemoteAddr(),
"requestId": req.ID(),
})
fe.DelegateClientIP(req.RemoteAddr())
fe.Params.Add("goauthentik.io/outpost/ldap", "true")
fe.Answers[flow.StageIdentification] = username
fe.Answers[flow.StagePassword] = req.BindPW
passed, err := fe.Execute()
flags := flags.UserFlags{
Session: fe.GetSession(),
}
db.si.SetFlags(req.BindDN, flags)
if err != nil {
metrics.RequestsRejected.With(prometheus.Labels{
"outpost_name": db.si.GetOutpostName(),
"type": "bind",
"reason": "flow_error",
"app": db.si.GetAppSlug(),
}).Inc()
req.Log().WithError(err).Warning("failed to execute flow")
return ldap.LDAPResultInvalidCredentials, nil
}
if !passed {
metrics.RequestsRejected.With(prometheus.Labels{
"outpost_name": db.si.GetOutpostName(),
"type": "bind",
"reason": "invalid_credentials",
"app": db.si.GetAppSlug(),
}).Inc()
req.Log().Info("Invalid credentials")
return ldap.LDAPResultInvalidCredentials, nil
}
access, err := fe.CheckApplicationAccess(db.si.GetAppSlug())
if !access {
req.Log().Info("Access denied for user")
metrics.RequestsRejected.With(prometheus.Labels{
"outpost_name": db.si.GetOutpostName(),
"type": "bind",
"reason": "access_denied",
"app": db.si.GetAppSlug(),
}).Inc()
return ldap.LDAPResultInsufficientAccessRights, nil
}
if err != nil {
metrics.RequestsRejected.With(prometheus.Labels{
"outpost_name": db.si.GetOutpostName(),
"type": "bind",
"reason": "access_check_fail",
"app": db.si.GetAppSlug(),
}).Inc()
req.Log().WithError(err).Warning("failed to check access")
return ldap.LDAPResultOperationsError, nil
}
req.Log().Info("User has access")
uisp := sentry.StartSpan(req.Context(), "authentik.providers.ldap.bind.user_info")
// Get user info to store in context
userInfo, _, err := fe.ApiClient().CoreApi.CoreUsersMeRetrieve(context.Background()).Execute()
if err != nil {
metrics.RequestsRejected.With(prometheus.Labels{
"outpost_name": db.si.GetOutpostName(),
"type": "bind",
"reason": "user_info_fail",
"app": db.si.GetAppSlug(),
}).Inc()
req.Log().WithError(err).Warning("failed to get user info")
return ldap.LDAPResultOperationsError, nil
}
cs := db.SearchAccessCheck(userInfo.User)
flags.UserPk = userInfo.User.Pk
flags.CanSearch = cs != nil
db.si.SetFlags(req.BindDN, flags)
if flags.CanSearch {
req.Log().WithField("group", cs).Info("Allowed access to search")
}
uisp.Finish()
return ldap.LDAPResultSuccess, nil
}
// SearchAccessCheck Check if the current user is allowed to search
func (db *DirectBinder) SearchAccessCheck(user api.UserSelf) *string {
for _, group := range user.Groups {
@ -153,8 +63,8 @@ func (db *DirectBinder) SearchAccessCheck(user api.UserSelf) *string {
return nil
}
func (db *DirectBinder) TimerFlowCacheExpiry() {
fe := flow.NewFlowExecutor(context.Background(), db.si.GetFlowSlug(), db.si.GetAPIClient().GetConfig(), log.Fields{})
func (db *DirectBinder) TimerFlowCacheExpiry(ctx context.Context) {
fe := flow.NewFlowExecutor(ctx, db.si.GetAuthenticationFlowSlug(), db.si.GetAPIClient().GetConfig(), log.Fields{})
fe.Params.Add("goauthentik.io/outpost/ldap", "true")
fe.Params.Add("goauthentik.io/outpost/ldap-warmup", "true")

View file

@ -0,0 +1,29 @@
package direct
import (
"github.com/nmcclain/ldap"
log "github.com/sirupsen/logrus"
"goauthentik.io/internal/outpost/flow"
"goauthentik.io/internal/outpost/ldap/bind"
)
func (db *DirectBinder) Unbind(username string, req *bind.Request) (ldap.LDAPResultCode, error) {
flags := db.si.GetFlags(req.BindDN)
if flags == nil || flags.Session == nil {
return ldap.LDAPResultSuccess, nil
}
fe := flow.NewFlowExecutor(req.Context(), db.si.GetInvalidationFlowSlug(), db.si.GetAPIClient().GetConfig(), log.Fields{
"boundDN": req.BindDN,
"client": req.RemoteAddr(),
"requestId": req.ID(),
})
fe.SetSession(flags.Session)
fe.DelegateClientIP(req.RemoteAddr())
fe.Params.Add("goauthentik.io/outpost/ldap", "true")
_, err := fe.Execute()
if err != nil {
db.log.WithError(err).Warning("failed to logout user")
}
db.si.SetFlags(req.BindDN, nil)
return ldap.LDAPResultSuccess, nil
}

View file

@ -27,10 +27,11 @@ type ProviderInstance struct {
searcher search.Searcher
binder bind.Binder
appSlug string
flowSlug string
s *LDAPServer
log *log.Entry
appSlug string
authenticationFlowSlug string
invalidationFlowSlug string
s *LDAPServer
log *log.Entry
tlsServerName *string
cert *tls.Certificate
@ -79,9 +80,13 @@ func (pi *ProviderInstance) GetFlags(dn string) *flags.UserFlags {
return flags
}
func (pi *ProviderInstance) SetFlags(dn string, flag flags.UserFlags) {
func (pi *ProviderInstance) SetFlags(dn string, flag *flags.UserFlags) {
pi.boundUsersMutex.Lock()
pi.boundUsers[dn] = &flag
if flag == nil {
delete(pi.boundUsers, dn)
} else {
pi.boundUsers[dn] = flag
}
pi.boundUsersMutex.Unlock()
}
@ -89,8 +94,12 @@ func (pi *ProviderInstance) GetAppSlug() string {
return pi.appSlug
}
func (pi *ProviderInstance) GetFlowSlug() string {
return pi.flowSlug
func (pi *ProviderInstance) GetAuthenticationFlowSlug() string {
return pi.authenticationFlowSlug
}
func (pi *ProviderInstance) GetInvalidationFlowSlug() string {
return pi.invalidationFlowSlug
}
func (pi *ProviderInstance) GetSearchAllowedGroups() []*strfmt.UUID {

View file

@ -1,6 +1,7 @@
package ldap
import (
"context"
"crypto/tls"
"net"
"sync"
@ -40,6 +41,7 @@ func NewServer(ac *ak.APIController) *LDAPServer {
}
ls.defaultCert = &defaultCert
s.BindFunc("", ls)
s.UnbindFunc("", ls)
s.SearchFunc("", ls)
return ls
}
@ -92,9 +94,13 @@ func (ls *LDAPServer) Start() error {
return nil
}
func (ls *LDAPServer) TimerFlowCacheExpiry() {
func (ls *LDAPServer) Stop() error {
return nil
}
func (ls *LDAPServer) TimerFlowCacheExpiry(ctx context.Context) {
for _, p := range ls.providers {
ls.log.WithField("flow", p.flowSlug).Debug("Pre-heating flow cache")
p.binder.TimerFlowCacheExpiry()
ls.log.WithField("flow", p.authenticationFlowSlug).Debug("Pre-heating flow cache")
p.binder.TimerFlowCacheExpiry(ctx)
}
}

View file

@ -28,6 +28,16 @@ func (ls *LDAPServer) getCurrentProvider(pk int32) *ProviderInstance {
return nil
}
func (ls *LDAPServer) getInvalidationFlow() string {
req, _, err := ls.ac.Client.CoreApi.CoreTenantsCurrentRetrieve(context.Background()).Execute()
if err != nil {
ls.log.WithError(err).Warning("failed to fetch tenant config")
return ""
}
flow := req.GetFlowInvalidation()
return flow
}
func (ls *LDAPServer) Refresh() error {
outposts, _, err := ls.ac.Client.OutpostsApi.OutpostsLdapList(context.Background()).Execute()
if err != nil {
@ -37,6 +47,7 @@ func (ls *LDAPServer) Refresh() error {
return errors.New("no ldap provider defined")
}
providers := make([]*ProviderInstance, len(outposts.Results))
invalidationFlow := ls.getInvalidationFlow()
for idx, provider := range outposts.Results {
userDN := strings.ToLower(fmt.Sprintf("ou=%s,%s", constants.OUUsers, *provider.BaseDn))
groupDN := strings.ToLower(fmt.Sprintf("ou=%s,%s", constants.OUGroups, *provider.BaseDn))
@ -53,22 +64,23 @@ func (ls *LDAPServer) Refresh() error {
}
providers[idx] = &ProviderInstance{
BaseDN: *provider.BaseDn,
VirtualGroupDN: virtualGroupDN,
GroupDN: groupDN,
UserDN: userDN,
appSlug: provider.ApplicationSlug,
flowSlug: provider.BindFlowSlug,
searchAllowedGroups: []*strfmt.UUID{(*strfmt.UUID)(provider.SearchGroup.Get())},
boundUsersMutex: sync.RWMutex{},
boundUsers: users,
s: ls,
log: logger,
tlsServerName: provider.TlsServerName,
uidStartNumber: *provider.UidStartNumber,
gidStartNumber: *provider.GidStartNumber,
outpostName: ls.ac.Outpost.Name,
outpostPk: provider.Pk,
BaseDN: *provider.BaseDn,
VirtualGroupDN: virtualGroupDN,
GroupDN: groupDN,
UserDN: userDN,
appSlug: provider.ApplicationSlug,
authenticationFlowSlug: provider.BindFlowSlug,
invalidationFlowSlug: invalidationFlow,
searchAllowedGroups: []*strfmt.UUID{(*strfmt.UUID)(provider.SearchGroup.Get())},
boundUsersMutex: sync.RWMutex{},
boundUsers: users,
s: ls,
log: logger,
tlsServerName: provider.TlsServerName,
uidStartNumber: *provider.UidStartNumber,
gidStartNumber: *provider.GidStartNumber,
outpostName: ls.ac.Outpost.Name,
outpostPk: provider.Pk,
}
if kp := provider.Certificate.Get(); kp != nil {
err := ls.cs.AddKeypair(*kp)

View file

@ -11,7 +11,8 @@ type LDAPServerInstance interface {
GetAPIClient() *api.APIClient
GetOutpostName() string
GetFlowSlug() string
GetAuthenticationFlowSlug() string
GetInvalidationFlowSlug() string
GetAppSlug() string
GetSearchAllowedGroups() []*strfmt.UUID
@ -32,7 +33,7 @@ type LDAPServerInstance interface {
UsersForGroup(api.Group) []string
GetFlags(dn string) *flags.UserFlags
SetFlags(dn string, flags flags.UserFlags)
SetFlags(dn string, flags *flags.UserFlags)
GetBaseEntry() *ldap.Entry
GetNeededObjects(int, string, string) (bool, bool)

View file

@ -0,0 +1,53 @@
package ldap
import (
"net"
"github.com/getsentry/sentry-go"
"github.com/nmcclain/ldap"
"github.com/prometheus/client_golang/prometheus"
log "github.com/sirupsen/logrus"
"goauthentik.io/internal/outpost/ldap/bind"
"goauthentik.io/internal/outpost/ldap/metrics"
)
func (ls *LDAPServer) Unbind(boundDN string, conn net.Conn) (ldap.LDAPResultCode, error) {
req, span := bind.NewRequest(boundDN, "", conn)
selectedApp := ""
defer func() {
span.Finish()
metrics.Requests.With(prometheus.Labels{
"outpost_name": ls.ac.Outpost.Name,
"type": "unbind",
"app": selectedApp,
}).Observe(float64(span.EndTime.Sub(span.StartTime)))
req.Log().WithField("took-ms", span.EndTime.Sub(span.StartTime).Milliseconds()).Info("Unbind request")
}()
defer func() {
err := recover()
if err == nil {
return
}
log.WithError(err.(error)).Error("recover in bind request")
sentry.CaptureException(err.(error))
}()
for _, instance := range ls.providers {
username, err := instance.binder.GetUsername(boundDN)
if err == nil {
selectedApp = instance.GetAppSlug()
return instance.binder.Unbind(username, req)
} else {
req.Log().WithError(err).Debug("Username not for instance")
}
}
req.Log().WithField("request", "unbind").Warning("No provider found for request")
metrics.RequestsRejected.With(prometheus.Labels{
"outpost_name": ls.ac.Outpost.Name,
"type": "unbind",
"reason": "no_provider",
"app": "",
}).Inc()
return ldap.LDAPResultOperationsError, nil
}

View file

@ -81,7 +81,7 @@ func (ps *ProxyServer) Type() string {
return "proxy"
}
func (ps *ProxyServer) TimerFlowCacheExpiry() {}
func (ps *ProxyServer) TimerFlowCacheExpiry(context.Context) {}
func (ps *ProxyServer) GetCertificate(serverName string) *tls.Certificate {
app, ok := ps.apps[serverName]
@ -163,6 +163,10 @@ func (ps *ProxyServer) Start() error {
return nil
}
func (ps *ProxyServer) Stop() error {
return nil
}
func (ps *ProxyServer) serve(listener net.Listener) {
srv := &http.Server{Handler: ps.mux}