diff --git a/outpost/pkg/ldap/api.go b/outpost/pkg/ldap/api.go index 473958952..d6b6f1657 100644 --- a/outpost/pkg/ldap/api.go +++ b/outpost/pkg/ldap/api.go @@ -3,7 +3,9 @@ package ldap import ( "errors" "fmt" + "net/http" "strings" + "sync" "github.com/go-openapi/strfmt" log "github.com/sirupsen/logrus" @@ -29,6 +31,8 @@ func (ls *LDAPServer) Refresh() error { appSlug: *provider.ApplicationSlug, flowSlug: *provider.BindFlowSlug, searchAllowedGroups: []*strfmt.UUID{provider.SearchGroup}, + boundUsersMutex: sync.RWMutex{}, + boundUsers: make(map[string]UserFlags), s: ls, log: log.WithField("logger", "authentik.outpost.ldap").WithField("provider", provider.Name), } @@ -48,3 +52,17 @@ func (ls *LDAPServer) Start() error { } return nil } + +type transport struct { + headers map[string]string +} + +func (t *transport) RoundTrip(req *http.Request) (*http.Response, error) { + for key, value := range t.headers { + req.Header.Add(key, value) + } + return http.DefaultTransport.RoundTrip(req) +} +func newTransport(headers map[string]string) *transport { + return &transport{headers} +} diff --git a/outpost/pkg/ldap/instance_bind.go b/outpost/pkg/ldap/instance_bind.go index 905f3060c..e9cb527f0 100644 --- a/outpost/pkg/ldap/instance_bind.go +++ b/outpost/pkg/ldap/instance_bind.go @@ -7,6 +7,7 @@ import ( "net" "net/http" "net/http/cookiejar" + "net/url" "strings" "time" @@ -52,10 +53,21 @@ func (pi *ProviderInstance) Bind(username string, bindPW string, conn net.Conn) pi.log.WithError(err).Warning("Failed to create cookiejar") return ldap.LDAPResultOperationsError, nil } + host, _, err := net.SplitHostPort(conn.RemoteAddr().String()) + if err != nil { + pi.log.WithError(err).Warning("Failed to get remote IP") + return ldap.LDAPResultOperationsError, nil + } + // Create new http client that also sets the correct ip client := &http.Client{ Jar: jar, + Transport: newTransport(map[string]string{ + "X-authentik-remote-ip": host, + }), } - passed, err := pi.solveFlowChallenge(username, bindPW, client) + params := url.Values{} + params.Add("goauthentik.io/outpost/ldap", "true") + passed, err := pi.solveFlowChallenge(username, bindPW, client, params.Encode()) if err != nil { pi.log.WithField("boundDN", username).WithError(err).Warning("failed to solve challenge") return ldap.LDAPResultOperationsError, nil @@ -91,7 +103,7 @@ func (pi *ProviderInstance) Bind(username string, bindPW string, conn net.Conn) UserInfo: userInfo.Payload.User, CanSearch: pi.SearchAccessCheck(userInfo.Payload.User), } - pi.boundUsersMutex.Unlock() + defer pi.boundUsersMutex.Unlock() pi.delayDeleteUserInfo(username) return ldap.LDAPResultSuccess, nil } @@ -127,13 +139,13 @@ func (pi *ProviderInstance) delayDeleteUserInfo(dn string) { }() } -func (pi *ProviderInstance) solveFlowChallenge(bindDN string, password string, client *http.Client) (bool, error) { +func (pi *ProviderInstance) solveFlowChallenge(bindDN string, password string, client *http.Client, urlParams string) (bool, error) { challenge, err := pi.s.ac.Client.Flows.FlowsExecutorGet(&flows.FlowsExecutorGetParams{ FlowSlug: pi.flowSlug, - Query: "ldap=true", + Query: urlParams, Context: context.Background(), HTTPClient: client, - }, httptransport.PassThroughAuth) + }, pi.s.ac.Auth) if err != nil { pi.log.WithError(err).Warning("Failed to get challenge") return false, err @@ -141,7 +153,7 @@ func (pi *ProviderInstance) solveFlowChallenge(bindDN string, password string, c pi.log.WithField("component", challenge.Payload.Component).WithField("type", *challenge.Payload.Type).Debug("Got challenge") responseParams := &flows.FlowsExecutorSolveParams{ FlowSlug: pi.flowSlug, - Query: "ldap=true", + Query: urlParams, Context: context.Background(), HTTPClient: client, } @@ -155,7 +167,7 @@ func (pi *ProviderInstance) solveFlowChallenge(bindDN string, password string, c default: return false, fmt.Errorf("unsupported challenge type: %s", challenge.Payload.Component) } - response, err := pi.s.ac.Client.Flows.FlowsExecutorSolve(responseParams, httptransport.PassThroughAuth) + response, err := pi.s.ac.Client.Flows.FlowsExecutorSolve(responseParams, pi.s.ac.Auth) pi.log.WithField("component", response.Payload.Component).WithField("type", *response.Payload.Type).Debug("Got response") if *response.Payload.Type == "redirect" { return true, nil @@ -172,5 +184,5 @@ func (pi *ProviderInstance) solveFlowChallenge(bindDN string, password string, c } } } - return pi.solveFlowChallenge(bindDN, password, client) + return pi.solveFlowChallenge(bindDN, password, client, urlParams) } diff --git a/outpost/pkg/ldap/instance_search.go b/outpost/pkg/ldap/instance_search.go index 4e5665adc..b7102c347 100644 --- a/outpost/pkg/ldap/instance_search.go +++ b/outpost/pkg/ldap/instance_search.go @@ -30,10 +30,10 @@ func (pi *ProviderInstance) Search(bindDN string, searchReq ldap.SearchRequest, defer pi.boundUsersMutex.RUnlock() flags, ok := pi.boundUsers[bindDN] if !ok { - return ldap.ServerSearchResult{ResultCode: ldap.LDAPResultInsufficientAccessRights}, errors.New("Access denied") + return ldap.ServerSearchResult{ResultCode: ldap.LDAPResultInsufficientAccessRights}, errors.New("access denied") } if !flags.CanSearch { - return ldap.ServerSearchResult{ResultCode: ldap.LDAPResultInsufficientAccessRights}, errors.New("Access denied") + return ldap.ServerSearchResult{ResultCode: ldap.LDAPResultInsufficientAccessRights}, errors.New("access denied") } switch filterEntity {