diff --git a/internal/outpost/ak/outpost.go b/internal/outpost/ak/outpost.go index 88ae9e4fd..45d2a8daf 100644 --- a/internal/outpost/ak/outpost.go +++ b/internal/outpost/ak/outpost.go @@ -1,8 +1,11 @@ package ak +import "context" + type Outpost interface { Start() error + Stop() error Refresh() error - TimerFlowCacheExpiry() + TimerFlowCacheExpiry(context.Context) Type() string } diff --git a/internal/outpost/ak/periodical.go b/internal/outpost/ak/periodical.go index fb35f446a..83d38ddce 100644 --- a/internal/outpost/ak/periodical.go +++ b/internal/outpost/ak/periodical.go @@ -1,15 +1,16 @@ package ak import ( + "context" "time" ) func (a *APIController) startPeriodicalTasks() { - go a.Server.TimerFlowCacheExpiry() - go func() { - for range time.Tick(time.Duration(a.GlobalConfig.CacheTimeoutFlows) * time.Second) { - a.logger.WithField("timer", "cache-timeout").Debug("Running periodical tasks") - a.Server.TimerFlowCacheExpiry() - } - }() + ctx, canc := context.WithCancel(context.Background()) + defer canc() + go a.Server.TimerFlowCacheExpiry(ctx) + for range time.Tick(time.Duration(a.GlobalConfig.CacheTimeoutFlows) * time.Second) { + a.logger.WithField("timer", "cache-timeout").Debug("Running periodical tasks") + a.Server.TimerFlowCacheExpiry(ctx) + } } diff --git a/internal/outpost/flow/executor.go b/internal/outpost/flow/executor.go index 570e34a06..5bcafd687 100644 --- a/internal/outpost/flow/executor.go +++ b/internal/outpost/flow/executor.go @@ -143,6 +143,10 @@ func (fe *FlowExecutor) GetSession() *http.Cookie { return fe.session } +func (fe *FlowExecutor) SetSession(s *http.Cookie) { + fe.session = s +} + // WarmUp Ensure authentik's flow cache is warmed up func (fe *FlowExecutor) WarmUp() error { gcsp := sentry.StartSpan(fe.Context, "authentik.outposts.flow_executor.get_challenge") diff --git a/internal/outpost/ldap/bind/binder.go b/internal/outpost/ldap/bind/binder.go index d0aa4a618..ddf5414d6 100644 --- a/internal/outpost/ldap/bind/binder.go +++ b/internal/outpost/ldap/bind/binder.go @@ -1,9 +1,14 @@ package bind -import "github.com/nmcclain/ldap" +import ( + "context" + + "github.com/nmcclain/ldap" +) type Binder interface { GetUsername(string) (string, error) Bind(username string, req *Request) (ldap.LDAPResultCode, error) - TimerFlowCacheExpiry() + Unbind(username string, req *Request) (ldap.LDAPResultCode, error) + TimerFlowCacheExpiry(context.Context) } diff --git a/internal/outpost/ldap/bind/direct/bind.go b/internal/outpost/ldap/bind/direct/bind.go new file mode 100644 index 000000000..1b07c0a8f --- /dev/null +++ b/internal/outpost/ldap/bind/direct/bind.go @@ -0,0 +1,98 @@ +package direct + +import ( + "context" + + "github.com/getsentry/sentry-go" + "github.com/nmcclain/ldap" + "github.com/prometheus/client_golang/prometheus" + log "github.com/sirupsen/logrus" + "goauthentik.io/internal/outpost/flow" + "goauthentik.io/internal/outpost/ldap/bind" + "goauthentik.io/internal/outpost/ldap/flags" + "goauthentik.io/internal/outpost/ldap/metrics" +) + +func (db *DirectBinder) Bind(username string, req *bind.Request) (ldap.LDAPResultCode, error) { + fe := flow.NewFlowExecutor(req.Context(), db.si.GetAuthenticationFlowSlug(), db.si.GetAPIClient().GetConfig(), log.Fields{ + "bindDN": req.BindDN, + "client": req.RemoteAddr(), + "requestId": req.ID(), + }) + fe.DelegateClientIP(req.RemoteAddr()) + fe.Params.Add("goauthentik.io/outpost/ldap", "true") + + fe.Answers[flow.StageIdentification] = username + fe.Answers[flow.StagePassword] = req.BindPW + + passed, err := fe.Execute() + flags := flags.UserFlags{ + Session: fe.GetSession(), + } + db.si.SetFlags(req.BindDN, &flags) + if err != nil { + metrics.RequestsRejected.With(prometheus.Labels{ + "outpost_name": db.si.GetOutpostName(), + "type": "bind", + "reason": "flow_error", + "app": db.si.GetAppSlug(), + }).Inc() + req.Log().WithError(err).Warning("failed to execute flow") + return ldap.LDAPResultInvalidCredentials, nil + } + if !passed { + metrics.RequestsRejected.With(prometheus.Labels{ + "outpost_name": db.si.GetOutpostName(), + "type": "bind", + "reason": "invalid_credentials", + "app": db.si.GetAppSlug(), + }).Inc() + req.Log().Info("Invalid credentials") + return ldap.LDAPResultInvalidCredentials, nil + } + + access, err := fe.CheckApplicationAccess(db.si.GetAppSlug()) + if !access { + req.Log().Info("Access denied for user") + metrics.RequestsRejected.With(prometheus.Labels{ + "outpost_name": db.si.GetOutpostName(), + "type": "bind", + "reason": "access_denied", + "app": db.si.GetAppSlug(), + }).Inc() + return ldap.LDAPResultInsufficientAccessRights, nil + } + if err != nil { + metrics.RequestsRejected.With(prometheus.Labels{ + "outpost_name": db.si.GetOutpostName(), + "type": "bind", + "reason": "access_check_fail", + "app": db.si.GetAppSlug(), + }).Inc() + req.Log().WithError(err).Warning("failed to check access") + return ldap.LDAPResultOperationsError, nil + } + req.Log().Info("User has access") + uisp := sentry.StartSpan(req.Context(), "authentik.providers.ldap.bind.user_info") + // Get user info to store in context + userInfo, _, err := fe.ApiClient().CoreApi.CoreUsersMeRetrieve(context.Background()).Execute() + if err != nil { + metrics.RequestsRejected.With(prometheus.Labels{ + "outpost_name": db.si.GetOutpostName(), + "type": "bind", + "reason": "user_info_fail", + "app": db.si.GetAppSlug(), + }).Inc() + req.Log().WithError(err).Warning("failed to get user info") + return ldap.LDAPResultOperationsError, nil + } + cs := db.SearchAccessCheck(userInfo.User) + flags.UserPk = userInfo.User.Pk + flags.CanSearch = cs != nil + db.si.SetFlags(req.BindDN, &flags) + if flags.CanSearch { + req.Log().WithField("group", cs).Info("Allowed access to search") + } + uisp.Finish() + return ldap.LDAPResultSuccess, nil +} diff --git a/internal/outpost/ldap/bind/direct/direct.go b/internal/outpost/ldap/bind/direct/direct.go index db8a7e7b0..cd4349850 100644 --- a/internal/outpost/ldap/bind/direct/direct.go +++ b/internal/outpost/ldap/bind/direct/direct.go @@ -5,16 +5,10 @@ import ( "errors" "strings" - "github.com/getsentry/sentry-go" goldap "github.com/go-ldap/ldap/v3" - "github.com/nmcclain/ldap" - "github.com/prometheus/client_golang/prometheus" log "github.com/sirupsen/logrus" "goauthentik.io/api/v3" "goauthentik.io/internal/outpost/flow" - "goauthentik.io/internal/outpost/ldap/bind" - "goauthentik.io/internal/outpost/ldap/flags" - "goauthentik.io/internal/outpost/ldap/metrics" "goauthentik.io/internal/outpost/ldap/server" "goauthentik.io/internal/outpost/ldap/utils" ) @@ -53,90 +47,6 @@ func (db *DirectBinder) GetUsername(dn string) (string, error) { return "", errors.New("failed to find cn") } -func (db *DirectBinder) Bind(username string, req *bind.Request) (ldap.LDAPResultCode, error) { - fe := flow.NewFlowExecutor(req.Context(), db.si.GetFlowSlug(), db.si.GetAPIClient().GetConfig(), log.Fields{ - "bindDN": req.BindDN, - "client": req.RemoteAddr(), - "requestId": req.ID(), - }) - fe.DelegateClientIP(req.RemoteAddr()) - fe.Params.Add("goauthentik.io/outpost/ldap", "true") - - fe.Answers[flow.StageIdentification] = username - fe.Answers[flow.StagePassword] = req.BindPW - - passed, err := fe.Execute() - flags := flags.UserFlags{ - Session: fe.GetSession(), - } - db.si.SetFlags(req.BindDN, flags) - if err != nil { - metrics.RequestsRejected.With(prometheus.Labels{ - "outpost_name": db.si.GetOutpostName(), - "type": "bind", - "reason": "flow_error", - "app": db.si.GetAppSlug(), - }).Inc() - req.Log().WithError(err).Warning("failed to execute flow") - return ldap.LDAPResultInvalidCredentials, nil - } - if !passed { - metrics.RequestsRejected.With(prometheus.Labels{ - "outpost_name": db.si.GetOutpostName(), - "type": "bind", - "reason": "invalid_credentials", - "app": db.si.GetAppSlug(), - }).Inc() - req.Log().Info("Invalid credentials") - return ldap.LDAPResultInvalidCredentials, nil - } - - access, err := fe.CheckApplicationAccess(db.si.GetAppSlug()) - if !access { - req.Log().Info("Access denied for user") - metrics.RequestsRejected.With(prometheus.Labels{ - "outpost_name": db.si.GetOutpostName(), - "type": "bind", - "reason": "access_denied", - "app": db.si.GetAppSlug(), - }).Inc() - return ldap.LDAPResultInsufficientAccessRights, nil - } - if err != nil { - metrics.RequestsRejected.With(prometheus.Labels{ - "outpost_name": db.si.GetOutpostName(), - "type": "bind", - "reason": "access_check_fail", - "app": db.si.GetAppSlug(), - }).Inc() - req.Log().WithError(err).Warning("failed to check access") - return ldap.LDAPResultOperationsError, nil - } - req.Log().Info("User has access") - uisp := sentry.StartSpan(req.Context(), "authentik.providers.ldap.bind.user_info") - // Get user info to store in context - userInfo, _, err := fe.ApiClient().CoreApi.CoreUsersMeRetrieve(context.Background()).Execute() - if err != nil { - metrics.RequestsRejected.With(prometheus.Labels{ - "outpost_name": db.si.GetOutpostName(), - "type": "bind", - "reason": "user_info_fail", - "app": db.si.GetAppSlug(), - }).Inc() - req.Log().WithError(err).Warning("failed to get user info") - return ldap.LDAPResultOperationsError, nil - } - cs := db.SearchAccessCheck(userInfo.User) - flags.UserPk = userInfo.User.Pk - flags.CanSearch = cs != nil - db.si.SetFlags(req.BindDN, flags) - if flags.CanSearch { - req.Log().WithField("group", cs).Info("Allowed access to search") - } - uisp.Finish() - return ldap.LDAPResultSuccess, nil -} - // SearchAccessCheck Check if the current user is allowed to search func (db *DirectBinder) SearchAccessCheck(user api.UserSelf) *string { for _, group := range user.Groups { @@ -153,8 +63,8 @@ func (db *DirectBinder) SearchAccessCheck(user api.UserSelf) *string { return nil } -func (db *DirectBinder) TimerFlowCacheExpiry() { - fe := flow.NewFlowExecutor(context.Background(), db.si.GetFlowSlug(), db.si.GetAPIClient().GetConfig(), log.Fields{}) +func (db *DirectBinder) TimerFlowCacheExpiry(ctx context.Context) { + fe := flow.NewFlowExecutor(ctx, db.si.GetAuthenticationFlowSlug(), db.si.GetAPIClient().GetConfig(), log.Fields{}) fe.Params.Add("goauthentik.io/outpost/ldap", "true") fe.Params.Add("goauthentik.io/outpost/ldap-warmup", "true") diff --git a/internal/outpost/ldap/bind/direct/unbind.go b/internal/outpost/ldap/bind/direct/unbind.go new file mode 100644 index 000000000..39d64fb79 --- /dev/null +++ b/internal/outpost/ldap/bind/direct/unbind.go @@ -0,0 +1,29 @@ +package direct + +import ( + "github.com/nmcclain/ldap" + log "github.com/sirupsen/logrus" + "goauthentik.io/internal/outpost/flow" + "goauthentik.io/internal/outpost/ldap/bind" +) + +func (db *DirectBinder) Unbind(username string, req *bind.Request) (ldap.LDAPResultCode, error) { + flags := db.si.GetFlags(req.BindDN) + if flags == nil || flags.Session == nil { + return ldap.LDAPResultSuccess, nil + } + fe := flow.NewFlowExecutor(req.Context(), db.si.GetInvalidationFlowSlug(), db.si.GetAPIClient().GetConfig(), log.Fields{ + "boundDN": req.BindDN, + "client": req.RemoteAddr(), + "requestId": req.ID(), + }) + fe.SetSession(flags.Session) + fe.DelegateClientIP(req.RemoteAddr()) + fe.Params.Add("goauthentik.io/outpost/ldap", "true") + _, err := fe.Execute() + if err != nil { + db.log.WithError(err).Warning("failed to logout user") + } + db.si.SetFlags(req.BindDN, nil) + return ldap.LDAPResultSuccess, nil +} diff --git a/internal/outpost/ldap/instance.go b/internal/outpost/ldap/instance.go index 1d98ef1df..b39dc8c51 100644 --- a/internal/outpost/ldap/instance.go +++ b/internal/outpost/ldap/instance.go @@ -27,10 +27,11 @@ type ProviderInstance struct { searcher search.Searcher binder bind.Binder - appSlug string - flowSlug string - s *LDAPServer - log *log.Entry + appSlug string + authenticationFlowSlug string + invalidationFlowSlug string + s *LDAPServer + log *log.Entry tlsServerName *string cert *tls.Certificate @@ -79,9 +80,13 @@ func (pi *ProviderInstance) GetFlags(dn string) *flags.UserFlags { return flags } -func (pi *ProviderInstance) SetFlags(dn string, flag flags.UserFlags) { +func (pi *ProviderInstance) SetFlags(dn string, flag *flags.UserFlags) { pi.boundUsersMutex.Lock() - pi.boundUsers[dn] = &flag + if flag == nil { + delete(pi.boundUsers, dn) + } else { + pi.boundUsers[dn] = flag + } pi.boundUsersMutex.Unlock() } @@ -89,8 +94,12 @@ func (pi *ProviderInstance) GetAppSlug() string { return pi.appSlug } -func (pi *ProviderInstance) GetFlowSlug() string { - return pi.flowSlug +func (pi *ProviderInstance) GetAuthenticationFlowSlug() string { + return pi.authenticationFlowSlug +} + +func (pi *ProviderInstance) GetInvalidationFlowSlug() string { + return pi.invalidationFlowSlug } func (pi *ProviderInstance) GetSearchAllowedGroups() []*strfmt.UUID { diff --git a/internal/outpost/ldap/ldap.go b/internal/outpost/ldap/ldap.go index 2388b62b0..3ab8e4fd6 100644 --- a/internal/outpost/ldap/ldap.go +++ b/internal/outpost/ldap/ldap.go @@ -1,6 +1,7 @@ package ldap import ( + "context" "crypto/tls" "net" "sync" @@ -40,6 +41,7 @@ func NewServer(ac *ak.APIController) *LDAPServer { } ls.defaultCert = &defaultCert s.BindFunc("", ls) + s.UnbindFunc("", ls) s.SearchFunc("", ls) return ls } @@ -92,9 +94,13 @@ func (ls *LDAPServer) Start() error { return nil } -func (ls *LDAPServer) TimerFlowCacheExpiry() { +func (ls *LDAPServer) Stop() error { + return nil +} + +func (ls *LDAPServer) TimerFlowCacheExpiry(ctx context.Context) { for _, p := range ls.providers { - ls.log.WithField("flow", p.flowSlug).Debug("Pre-heating flow cache") - p.binder.TimerFlowCacheExpiry() + ls.log.WithField("flow", p.authenticationFlowSlug).Debug("Pre-heating flow cache") + p.binder.TimerFlowCacheExpiry(ctx) } } diff --git a/internal/outpost/ldap/refresh.go b/internal/outpost/ldap/refresh.go index 34067a876..40b757c92 100644 --- a/internal/outpost/ldap/refresh.go +++ b/internal/outpost/ldap/refresh.go @@ -28,6 +28,16 @@ func (ls *LDAPServer) getCurrentProvider(pk int32) *ProviderInstance { return nil } +func (ls *LDAPServer) getInvalidationFlow() string { + req, _, err := ls.ac.Client.CoreApi.CoreTenantsCurrentRetrieve(context.Background()).Execute() + if err != nil { + ls.log.WithError(err).Warning("failed to fetch tenant config") + return "" + } + flow := req.GetFlowInvalidation() + return flow +} + func (ls *LDAPServer) Refresh() error { outposts, _, err := ls.ac.Client.OutpostsApi.OutpostsLdapList(context.Background()).Execute() if err != nil { @@ -37,6 +47,7 @@ func (ls *LDAPServer) Refresh() error { return errors.New("no ldap provider defined") } providers := make([]*ProviderInstance, len(outposts.Results)) + invalidationFlow := ls.getInvalidationFlow() for idx, provider := range outposts.Results { userDN := strings.ToLower(fmt.Sprintf("ou=%s,%s", constants.OUUsers, *provider.BaseDn)) groupDN := strings.ToLower(fmt.Sprintf("ou=%s,%s", constants.OUGroups, *provider.BaseDn)) @@ -53,22 +64,23 @@ func (ls *LDAPServer) Refresh() error { } providers[idx] = &ProviderInstance{ - BaseDN: *provider.BaseDn, - VirtualGroupDN: virtualGroupDN, - GroupDN: groupDN, - UserDN: userDN, - appSlug: provider.ApplicationSlug, - flowSlug: provider.BindFlowSlug, - searchAllowedGroups: []*strfmt.UUID{(*strfmt.UUID)(provider.SearchGroup.Get())}, - boundUsersMutex: sync.RWMutex{}, - boundUsers: users, - s: ls, - log: logger, - tlsServerName: provider.TlsServerName, - uidStartNumber: *provider.UidStartNumber, - gidStartNumber: *provider.GidStartNumber, - outpostName: ls.ac.Outpost.Name, - outpostPk: provider.Pk, + BaseDN: *provider.BaseDn, + VirtualGroupDN: virtualGroupDN, + GroupDN: groupDN, + UserDN: userDN, + appSlug: provider.ApplicationSlug, + authenticationFlowSlug: provider.BindFlowSlug, + invalidationFlowSlug: invalidationFlow, + searchAllowedGroups: []*strfmt.UUID{(*strfmt.UUID)(provider.SearchGroup.Get())}, + boundUsersMutex: sync.RWMutex{}, + boundUsers: users, + s: ls, + log: logger, + tlsServerName: provider.TlsServerName, + uidStartNumber: *provider.UidStartNumber, + gidStartNumber: *provider.GidStartNumber, + outpostName: ls.ac.Outpost.Name, + outpostPk: provider.Pk, } if kp := provider.Certificate.Get(); kp != nil { err := ls.cs.AddKeypair(*kp) diff --git a/internal/outpost/ldap/server/base.go b/internal/outpost/ldap/server/base.go index 982aa4fc2..cc5dee89e 100644 --- a/internal/outpost/ldap/server/base.go +++ b/internal/outpost/ldap/server/base.go @@ -11,7 +11,8 @@ type LDAPServerInstance interface { GetAPIClient() *api.APIClient GetOutpostName() string - GetFlowSlug() string + GetAuthenticationFlowSlug() string + GetInvalidationFlowSlug() string GetAppSlug() string GetSearchAllowedGroups() []*strfmt.UUID @@ -32,7 +33,7 @@ type LDAPServerInstance interface { UsersForGroup(api.Group) []string GetFlags(dn string) *flags.UserFlags - SetFlags(dn string, flags flags.UserFlags) + SetFlags(dn string, flags *flags.UserFlags) GetBaseEntry() *ldap.Entry GetNeededObjects(int, string, string) (bool, bool) diff --git a/internal/outpost/ldap/unbind.go b/internal/outpost/ldap/unbind.go new file mode 100644 index 000000000..52c64f71e --- /dev/null +++ b/internal/outpost/ldap/unbind.go @@ -0,0 +1,53 @@ +package ldap + +import ( + "net" + + "github.com/getsentry/sentry-go" + "github.com/nmcclain/ldap" + "github.com/prometheus/client_golang/prometheus" + log "github.com/sirupsen/logrus" + "goauthentik.io/internal/outpost/ldap/bind" + "goauthentik.io/internal/outpost/ldap/metrics" +) + +func (ls *LDAPServer) Unbind(boundDN string, conn net.Conn) (ldap.LDAPResultCode, error) { + req, span := bind.NewRequest(boundDN, "", conn) + selectedApp := "" + defer func() { + span.Finish() + metrics.Requests.With(prometheus.Labels{ + "outpost_name": ls.ac.Outpost.Name, + "type": "unbind", + "app": selectedApp, + }).Observe(float64(span.EndTime.Sub(span.StartTime))) + req.Log().WithField("took-ms", span.EndTime.Sub(span.StartTime).Milliseconds()).Info("Unbind request") + }() + + defer func() { + err := recover() + if err == nil { + return + } + log.WithError(err.(error)).Error("recover in bind request") + sentry.CaptureException(err.(error)) + }() + + for _, instance := range ls.providers { + username, err := instance.binder.GetUsername(boundDN) + if err == nil { + selectedApp = instance.GetAppSlug() + return instance.binder.Unbind(username, req) + } else { + req.Log().WithError(err).Debug("Username not for instance") + } + } + req.Log().WithField("request", "unbind").Warning("No provider found for request") + metrics.RequestsRejected.With(prometheus.Labels{ + "outpost_name": ls.ac.Outpost.Name, + "type": "unbind", + "reason": "no_provider", + "app": "", + }).Inc() + return ldap.LDAPResultOperationsError, nil +} diff --git a/internal/outpost/proxyv2/proxyv2.go b/internal/outpost/proxyv2/proxyv2.go index 784bc5860..1271c39bf 100644 --- a/internal/outpost/proxyv2/proxyv2.go +++ b/internal/outpost/proxyv2/proxyv2.go @@ -81,7 +81,7 @@ func (ps *ProxyServer) Type() string { return "proxy" } -func (ps *ProxyServer) TimerFlowCacheExpiry() {} +func (ps *ProxyServer) TimerFlowCacheExpiry(context.Context) {} func (ps *ProxyServer) GetCertificate(serverName string) *tls.Certificate { app, ok := ps.apps[serverName] @@ -163,6 +163,10 @@ func (ps *ProxyServer) Start() error { return nil } +func (ps *ProxyServer) Stop() error { + return nil +} + func (ps *ProxyServer) serve(listener net.Listener) { srv := &http.Server{Handler: ps.mux}