diff --git a/outpost/pkg/ak/api_ws.go b/outpost/pkg/ak/api_ws.go index 851266aa9..235247ec1 100644 --- a/outpost/pkg/ak/api_ws.go +++ b/outpost/pkg/ak/api_ws.go @@ -15,9 +15,9 @@ import ( "goauthentik.io/outpost/pkg" ) -func (ac *APIController) initWS(pbURL url.URL, outpostUUID strfmt.UUID) { +func (ac *APIController) initWS(akURL url.URL, outpostUUID strfmt.UUID) { pathTemplate := "%s://%s/ws/outpost/%s/" - scheme := strings.ReplaceAll(pbURL.Scheme, "http", "ws") + scheme := strings.ReplaceAll(akURL.Scheme, "http", "ws") authHeader := fmt.Sprintf("Bearer %s", ac.token) @@ -37,7 +37,7 @@ func (ac *APIController) initWS(pbURL url.URL, outpostUUID strfmt.UUID) { InsecureSkipVerify: strings.ToLower(value) == "true", }, } - ws.Dial(fmt.Sprintf(pathTemplate, scheme, pbURL.Host, outpostUUID.String()), header) + ws.Dial(fmt.Sprintf(pathTemplate, scheme, akURL.Host, outpostUUID.String()), header) ac.logger.WithField("logger", "authentik.outpost.ak-ws").WithField("outpost", outpostUUID.String()).Debug("connecting to authentik") diff --git a/outpost/pkg/proxy/middleware.go b/outpost/pkg/proxy/middleware.go index 3e9e924fc..7c938ea77 100644 --- a/outpost/pkg/proxy/middleware.go +++ b/outpost/pkg/proxy/middleware.go @@ -107,7 +107,7 @@ func (h loggingHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) { duration := float64(time.Since(t)) / float64(time.Millisecond) h.logger.WithFields(log.Fields{ "host": req.RemoteAddr, - "vhost": req.Host, + "vhost": getHost(req), "request_protocol": req.Proto, "runtime": fmt.Sprintf("%0.3f", duration), "method": req.Method, diff --git a/outpost/pkg/proxy/oauth.go b/outpost/pkg/proxy/oauth.go index 9ee0bcadd..9d1d894fa 100644 --- a/outpost/pkg/proxy/oauth.go +++ b/outpost/pkg/proxy/oauth.go @@ -161,7 +161,7 @@ func (p *OAuthProxy) OAuthStart(rw http.ResponseWriter, req *http.Request) { p.ErrorPage(rw, http.StatusInternalServerError, "Internal Server Error", err.Error()) return } - redirectURI := p.GetRedirectURI(req.Host) + redirectURI := p.GetRedirectURI(getHost(req)) http.Redirect(rw, req, p.provider.GetLoginURL(redirectURI, fmt.Sprintf("%v:%v", nonce, redirect)), http.StatusFound) } @@ -184,7 +184,7 @@ func (p *OAuthProxy) OAuthCallback(rw http.ResponseWriter, req *http.Request) { return } - session, err := p.redeemCode(req.Context(), req.Host, req.Form.Get("code")) + session, err := p.redeemCode(req.Context(), getHost(req), req.Form.Get("code")) if err != nil { p.logger.Errorf("Error redeeming code during OAuth2 callback: %v", err) p.ErrorPage(rw, http.StatusInternalServerError, "Internal Server Error", "Internal Error") diff --git a/outpost/pkg/proxy/server.go b/outpost/pkg/proxy/server.go index e01dbee95..44a355b51 100644 --- a/outpost/pkg/proxy/server.go +++ b/outpost/pkg/proxy/server.go @@ -42,7 +42,8 @@ func (s *Server) handler(w http.ResponseWriter, r *http.Request) { w.WriteHeader(204) return } - handler, ok := s.Handlers[r.Host] + host := getHost(r) + handler, ok := s.Handlers[host] if !ok { // If we only have one handler, host name switching doesn't matter if len(s.Handlers) == 1 { @@ -56,7 +57,7 @@ func (s *Server) handler(w http.ResponseWriter, r *http.Request) { for k := range s.Handlers { hostKeys = append(hostKeys, k) } - s.logger.WithField("host", r.Host).WithField("known-hosts", strings.Join(hostKeys, ", ")).Debug("Host header does not match any we know of") + s.logger.WithField("host", host).WithField("known-hosts", strings.Join(hostKeys, ", ")).Debug("Host header does not match any we know of") w.WriteHeader(404) return } diff --git a/outpost/pkg/proxy/utils.go b/outpost/pkg/proxy/utils.go new file mode 100644 index 000000000..d9e4602a9 --- /dev/null +++ b/outpost/pkg/proxy/utils.go @@ -0,0 +1,12 @@ +package proxy + +import "net/http" + +var xForwardedHost = http.CanonicalHeaderKey("X-Forwarded-Host") + +func getHost(req *http.Request) string { + if req.Header.Get(xForwardedHost) != "" { + return req.Header.Get(xForwardedHost) + } + return req.Host +}