56 lines
1.9 KiB
Go
56 lines
1.9 KiB
Go
package ldap
|
|
|
|
import (
|
|
"strings"
|
|
|
|
"beryju.io/ldap"
|
|
goldap "github.com/go-ldap/ldap/v3"
|
|
"goauthentik.io/internal/outpost/ldap/constants"
|
|
"goauthentik.io/internal/outpost/ldap/search"
|
|
)
|
|
|
|
func (ls *LDAPServer) providerForRequest(req *search.Request) *ProviderInstance {
|
|
parsedBaseDN, err := goldap.ParseDN(strings.ToLower(req.BaseDN))
|
|
if err != nil {
|
|
req.Log().WithError(err).Info("failed to parse base DN")
|
|
return nil
|
|
}
|
|
parsedBindDN, err := goldap.ParseDN(strings.ToLower(req.BindDN))
|
|
if err != nil {
|
|
req.Log().WithError(err).Info("failed to parse bind DN")
|
|
return nil
|
|
}
|
|
var selectedProvider *ProviderInstance
|
|
longestMatch := 0
|
|
for _, provider := range ls.providers {
|
|
providerBase, _ := goldap.ParseDN(strings.ToLower(provider.BaseDN))
|
|
// Try to match the provider primarily based on the search request's base DN
|
|
baseDNMatches := providerBase.AncestorOf(parsedBaseDN) || providerBase.Equal(parsedBaseDN)
|
|
// But also try to match the provider based on the bind DN
|
|
bindDNMatches := providerBase.AncestorOf(parsedBindDN) || providerBase.Equal(parsedBindDN)
|
|
if baseDNMatches || bindDNMatches {
|
|
// Only select the provider if it's a more precise match than previously
|
|
if len(provider.BaseDN) > longestMatch {
|
|
req.Log().WithField("provider", provider.BaseDN).Trace("selecting provider for search request")
|
|
selectedProvider = provider
|
|
longestMatch = len(provider.BaseDN)
|
|
}
|
|
}
|
|
}
|
|
return selectedProvider
|
|
}
|
|
|
|
func (ls *LDAPServer) searchRoute(req *search.Request, pi *ProviderInstance) (ldap.ServerSearchResult, error) {
|
|
// Route based on the base DN
|
|
if len(req.BaseDN) == 0 {
|
|
req.Log().Trace("routing to base")
|
|
return pi.searcher.SearchBase(req)
|
|
}
|
|
if strings.EqualFold(req.BaseDN, "cn=subschema") || req.FilterObjectClass == constants.OCSubSchema {
|
|
req.Log().Trace("routing to subschema")
|
|
return pi.searcher.SearchSubschema(req)
|
|
}
|
|
req.Log().Trace("routing to default")
|
|
return pi.searcher.Search(req)
|
|
}
|