outpost: check for X-Forwarded-Host to switch context

Signed-off-by: Jens Langhammer <jens.langhammer@beryju.org>
This commit is contained in:
Jens Langhammer 2021-04-23 14:07:44 +02:00
parent 5b3941a425
commit 7a49377caf
5 changed files with 21 additions and 8 deletions

View File

@ -15,9 +15,9 @@ import (
"goauthentik.io/outpost/pkg" "goauthentik.io/outpost/pkg"
) )
func (ac *APIController) initWS(pbURL url.URL, outpostUUID strfmt.UUID) { func (ac *APIController) initWS(akURL url.URL, outpostUUID strfmt.UUID) {
pathTemplate := "%s://%s/ws/outpost/%s/" pathTemplate := "%s://%s/ws/outpost/%s/"
scheme := strings.ReplaceAll(pbURL.Scheme, "http", "ws") scheme := strings.ReplaceAll(akURL.Scheme, "http", "ws")
authHeader := fmt.Sprintf("Bearer %s", ac.token) authHeader := fmt.Sprintf("Bearer %s", ac.token)
@ -37,7 +37,7 @@ func (ac *APIController) initWS(pbURL url.URL, outpostUUID strfmt.UUID) {
InsecureSkipVerify: strings.ToLower(value) == "true", InsecureSkipVerify: strings.ToLower(value) == "true",
}, },
} }
ws.Dial(fmt.Sprintf(pathTemplate, scheme, pbURL.Host, outpostUUID.String()), header) ws.Dial(fmt.Sprintf(pathTemplate, scheme, akURL.Host, outpostUUID.String()), header)
ac.logger.WithField("logger", "authentik.outpost.ak-ws").WithField("outpost", outpostUUID.String()).Debug("connecting to authentik") ac.logger.WithField("logger", "authentik.outpost.ak-ws").WithField("outpost", outpostUUID.String()).Debug("connecting to authentik")

View File

@ -107,7 +107,7 @@ func (h loggingHandler) ServeHTTP(w http.ResponseWriter, req *http.Request) {
duration := float64(time.Since(t)) / float64(time.Millisecond) duration := float64(time.Since(t)) / float64(time.Millisecond)
h.logger.WithFields(log.Fields{ h.logger.WithFields(log.Fields{
"host": req.RemoteAddr, "host": req.RemoteAddr,
"vhost": req.Host, "vhost": getHost(req),
"request_protocol": req.Proto, "request_protocol": req.Proto,
"runtime": fmt.Sprintf("%0.3f", duration), "runtime": fmt.Sprintf("%0.3f", duration),
"method": req.Method, "method": req.Method,

View File

@ -161,7 +161,7 @@ func (p *OAuthProxy) OAuthStart(rw http.ResponseWriter, req *http.Request) {
p.ErrorPage(rw, http.StatusInternalServerError, "Internal Server Error", err.Error()) p.ErrorPage(rw, http.StatusInternalServerError, "Internal Server Error", err.Error())
return return
} }
redirectURI := p.GetRedirectURI(req.Host) redirectURI := p.GetRedirectURI(getHost(req))
http.Redirect(rw, req, p.provider.GetLoginURL(redirectURI, fmt.Sprintf("%v:%v", nonce, redirect)), http.StatusFound) http.Redirect(rw, req, p.provider.GetLoginURL(redirectURI, fmt.Sprintf("%v:%v", nonce, redirect)), http.StatusFound)
} }
@ -184,7 +184,7 @@ func (p *OAuthProxy) OAuthCallback(rw http.ResponseWriter, req *http.Request) {
return return
} }
session, err := p.redeemCode(req.Context(), req.Host, req.Form.Get("code")) session, err := p.redeemCode(req.Context(), getHost(req), req.Form.Get("code"))
if err != nil { if err != nil {
p.logger.Errorf("Error redeeming code during OAuth2 callback: %v", err) p.logger.Errorf("Error redeeming code during OAuth2 callback: %v", err)
p.ErrorPage(rw, http.StatusInternalServerError, "Internal Server Error", "Internal Error") p.ErrorPage(rw, http.StatusInternalServerError, "Internal Server Error", "Internal Error")

View File

@ -42,7 +42,8 @@ func (s *Server) handler(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(204) w.WriteHeader(204)
return return
} }
handler, ok := s.Handlers[r.Host] host := getHost(r)
handler, ok := s.Handlers[host]
if !ok { if !ok {
// If we only have one handler, host name switching doesn't matter // If we only have one handler, host name switching doesn't matter
if len(s.Handlers) == 1 { if len(s.Handlers) == 1 {
@ -56,7 +57,7 @@ func (s *Server) handler(w http.ResponseWriter, r *http.Request) {
for k := range s.Handlers { for k := range s.Handlers {
hostKeys = append(hostKeys, k) hostKeys = append(hostKeys, k)
} }
s.logger.WithField("host", r.Host).WithField("known-hosts", strings.Join(hostKeys, ", ")).Debug("Host header does not match any we know of") s.logger.WithField("host", host).WithField("known-hosts", strings.Join(hostKeys, ", ")).Debug("Host header does not match any we know of")
w.WriteHeader(404) w.WriteHeader(404)
return return
} }

View File

@ -0,0 +1,12 @@
package proxy
import "net/http"
var xForwardedHost = http.CanonicalHeaderKey("X-Forwarded-Host")
func getHost(req *http.Request) string {
if req.Header.Get(xForwardedHost) != "" {
return req.Header.Get(xForwardedHost)
}
return req.Host
}