internal: route traffic to proxy providers based on cookie domain when multiple domain-level providers exist

Signed-off-by: Jens Langhammer <jens.langhammer@beryju.org>

#2079
This commit is contained in:
Jens Langhammer 2022-01-18 23:19:43 +01:00
parent 525976a81b
commit 14c7d8c4f4
4 changed files with 50 additions and 10 deletions

View file

@ -187,6 +187,10 @@ func (a *Application) Mode() api.ProxyMode {
return *a.proxyConfig.Mode
}
func (a *Application) ProxyConfig() api.ProxyOutpostConfig {
return a.proxyConfig
}
func (a *Application) ServeHTTP(rw http.ResponseWriter, r *http.Request) {
a.mux.ServeHTTP(rw, r)
}

View file

@ -8,6 +8,7 @@ import (
"time"
"github.com/prometheus/client_golang/prometheus"
"goauthentik.io/internal/outpost/proxyv2/application"
"goauthentik.io/internal/outpost/proxyv2/metrics"
"goauthentik.io/internal/utils/web"
staticWeb "goauthentik.io/web"
@ -43,6 +44,42 @@ func (ps *ProxyServer) HandleStatic(rw http.ResponseWriter, r *http.Request) {
}).Observe(float64(after))
}
func (ps *ProxyServer) lookupApp(r *http.Request) (*application.Application, string) {
host := web.GetHost(r)
// Try to find application by directly looking up host first (proxy, forward_auth_single)
a, ok := ps.apps[host]
if ok {
ps.log.WithField("host", host).WithField("app", a).Debug("Found app based direct host match")
return a, host
}
// For forward_auth_domain, we don't have a direct app to domain relationship
// Check through all apps, and check how much of their cookie domain matches the host
// Return the application that has the longest match
var longestMatch *application.Application
longestMatchLength := 0
for _, app := range ps.apps {
// Check if the cookie domain has a leading period for a wildcard
// This will decrease the weight of a wildcard domain, but a request to example.com
// with the cookie domain set to example.com will still be routed correctly.
cd := strings.TrimPrefix(*app.ProxyConfig().CookieDomain, ".")
if !strings.HasSuffix(host, cd) {
continue
}
if len(cd) < longestMatchLength {
continue
}
longestMatch = app
longestMatchLength = len(cd)
}
// Check if our longes match is 0, in which case we didn't match, so we
// manually return no app
if longestMatchLength == 0 {
return nil, host
}
ps.log.WithField("host", host).WithField("app", longestMatch).Debug("Found app based on cookie domain")
return longestMatch, host
}
func (ps *ProxyServer) Handle(rw http.ResponseWriter, r *http.Request) {
if strings.HasPrefix(r.URL.Path, "/akprox/static") {
ps.HandleStatic(rw, r)
@ -52,9 +89,8 @@ func (ps *ProxyServer) Handle(rw http.ResponseWriter, r *http.Request) {
ps.HandlePing(rw, r)
return
}
host := web.GetHost(r)
a, ok := ps.apps[host]
if !ok {
a, host := ps.lookupApp(r)
if a == nil {
// If we only have one handler, host name switching doesn't matter
if len(ps.apps) == 1 {
ps.log.WithField("host", host).Trace("passing to single app mux")

View file

@ -70,11 +70,12 @@ func NewProxyServer(ac *ak.APIController, portOffset int) *ProxyServer {
return s
}
func (ps *ProxyServer) HandleHost(host string, rw http.ResponseWriter, r *http.Request) bool {
if app, ok := ps.apps[host]; ok {
if app.Mode() == api.PROXYMODE_PROXY {
func (ps *ProxyServer) HandleHost(rw http.ResponseWriter, r *http.Request) bool {
a, host := ps.lookupApp(r)
if a != nil {
if a.Mode() == api.PROXYMODE_PROXY {
ps.log.WithField("host", host).Trace("routing to proxy outpost")
app.ServeHTTP(rw, r)
a.ServeHTTP(rw, r)
return true
}
}

View file

@ -47,10 +47,9 @@ func (ws *WebServer) configureProxy() {
ws.proxyErrorHandler(rw, r, fmt.Errorf("authentik core not running yet"))
return
}
host := web.GetHost(r)
before := time.Now()
if ws.ProxyServer != nil {
if ws.ProxyServer.HandleHost(host, rw, r) {
if ws.ProxyServer.HandleHost(rw, r) {
Requests.With(prometheus.Labels{
"dest": "embedded_outpost",
}).Observe(float64(time.Since(before)))
@ -60,7 +59,7 @@ func (ws *WebServer) configureProxy() {
Requests.With(prometheus.Labels{
"dest": "py",
}).Observe(float64(time.Since(before)))
ws.log.WithField("host", host).Trace("routing to application server")
ws.log.WithField("host", web.GetHost(r)).Trace("routing to application server")
rp.ServeHTTP(rw, r)
})
}