244 lines
8.0 KiB
Go
244 lines
8.0 KiB
Go
package proxy
|
|
|
|
import (
|
|
"context"
|
|
"errors"
|
|
"fmt"
|
|
"net/http"
|
|
"net/url"
|
|
"strings"
|
|
|
|
sessionsapi "github.com/oauth2-proxy/oauth2-proxy/pkg/apis/sessions"
|
|
"github.com/oauth2-proxy/oauth2-proxy/pkg/encryption"
|
|
"github.com/oauth2-proxy/oauth2-proxy/pkg/ip"
|
|
)
|
|
|
|
// GetRedirectURI returns the redirectURL that the upstream OAuth Provider will
|
|
// redirect clients to once authenticated
|
|
func (p *OAuthProxy) GetRedirectURI(host string) string {
|
|
// default to the request Host if not set
|
|
if p.redirectURL.Host != "" {
|
|
return p.redirectURL.String()
|
|
}
|
|
u := *p.redirectURL
|
|
if u.Scheme == "" {
|
|
if p.CookieSecure {
|
|
u.Scheme = httpsScheme
|
|
} else {
|
|
u.Scheme = httpScheme
|
|
}
|
|
}
|
|
u.Host = host
|
|
return u.String()
|
|
}
|
|
|
|
// HTTPClient is the key used with context.WithValue to attach an
// *http.Client to a context; redeemCode in this file stores p.client under
// it before calling the provider's Redeem.
var HTTPClient ContextKey

// ContextKey is the type of the HTTPClient context key. A dedicated type
// keeps the key from colliding with string or built-in-typed keys used by
// other packages.
//
// ContextKey has no fields, so every ContextKey{} value is equal to
// HTTPClient. Because the type is exported, code outside this package can
// construct the same key; it is not a private/unforgeable key.
type ContextKey struct{}
|
|
|
|
func (p *OAuthProxy) redeemCode(ctx context.Context, host, code string) (s *sessionsapi.SessionState, err error) {
|
|
if code == "" {
|
|
return nil, errors.New("missing code")
|
|
}
|
|
redirectURI := p.GetRedirectURI(host)
|
|
redeemCtx := context.WithValue(ctx, HTTPClient, p.client)
|
|
s, err = p.provider.Redeem(redeemCtx, redirectURI, code)
|
|
if err != nil {
|
|
return
|
|
}
|
|
|
|
if s.Email == "" {
|
|
s.Email, err = p.provider.GetEmailAddress(ctx, s)
|
|
}
|
|
|
|
if s.PreferredUsername == "" {
|
|
s.PreferredUsername, err = p.provider.GetPreferredUsername(ctx, s)
|
|
if err != nil && err.Error() == "not implemented" {
|
|
err = nil
|
|
}
|
|
}
|
|
|
|
if s.User == "" {
|
|
s.User, err = p.provider.GetUserName(ctx, s)
|
|
if err != nil && err.Error() == "not implemented" {
|
|
err = nil
|
|
}
|
|
}
|
|
return
|
|
}
|
|
|
|
// GetRedirect reads the query parameter to get the URL to redirect clients to
|
|
// once authenticated with the OAuthProxy
|
|
func (p *OAuthProxy) GetRedirect(req *http.Request) (redirect string, err error) {
|
|
err = req.ParseForm()
|
|
if err != nil {
|
|
return
|
|
}
|
|
|
|
redirect = req.Header.Get("X-Auth-Request-Redirect")
|
|
if req.Form.Get("rd") != "" {
|
|
redirect = req.Form.Get("rd")
|
|
}
|
|
if !p.IsValidRedirect(redirect) {
|
|
// Use RequestURI to preserve ?query
|
|
redirect = req.URL.RequestURI()
|
|
if strings.HasPrefix(redirect, p.ProxyPrefix) {
|
|
redirect = "/"
|
|
}
|
|
}
|
|
|
|
return
|
|
}
|
|
|
|
// IsValidRedirect checks whether the redirect URL is whitelisted
|
|
func (p *OAuthProxy) IsValidRedirect(redirect string) bool {
|
|
switch {
|
|
case redirect == "":
|
|
// The user didn't specify a redirect, should fallback to `/`
|
|
return false
|
|
case strings.HasPrefix(redirect, "/") && !strings.HasPrefix(redirect, "//") && !invalidRedirectRegex.MatchString(redirect):
|
|
return true
|
|
case strings.HasPrefix(redirect, "http://") || strings.HasPrefix(redirect, "https://"):
|
|
redirectURL, err := url.Parse(redirect)
|
|
if err != nil {
|
|
p.logger.Printf("Rejecting invalid redirect %q: scheme unsupported or missing", redirect)
|
|
return false
|
|
}
|
|
redirectHostname := redirectURL.Hostname()
|
|
|
|
for _, domain := range p.whitelistDomains {
|
|
domainHostname, domainPort := splitHostPort(strings.TrimLeft(domain, "."))
|
|
if domainHostname == "" {
|
|
continue
|
|
}
|
|
|
|
if (redirectHostname == domainHostname) || (strings.HasPrefix(domain, ".") && strings.HasSuffix(redirectHostname, domainHostname)) {
|
|
// the domain names match, now validate the ports
|
|
// if the whitelisted domain's port is '*', allow all ports
|
|
// if the whitelisted domain contains a specific port, only allow that port
|
|
// if the whitelisted domain doesn't contain a port at all, only allow empty redirect ports ie http and https
|
|
redirectPort := redirectURL.Port()
|
|
if (domainPort == "*") ||
|
|
(domainPort == redirectPort) ||
|
|
(domainPort == "" && redirectPort == "") {
|
|
return true
|
|
}
|
|
}
|
|
}
|
|
|
|
p.logger.Printf("Rejecting invalid redirect %q: domain / port not in whitelist", redirect)
|
|
return false
|
|
default:
|
|
p.logger.Printf("Rejecting invalid redirect %q: not an absolute or relative URL", redirect)
|
|
return false
|
|
}
|
|
}
|
|
|
|
// IsWhitelistedRequest is used to check if auth should be skipped for this request
|
|
func (p *OAuthProxy) IsWhitelistedRequest(req *http.Request) bool {
|
|
isPreflightRequestAllowed := p.skipAuthPreflight && req.Method == "OPTIONS"
|
|
return isPreflightRequestAllowed || p.IsWhitelistedPath(req.URL.Path)
|
|
}
|
|
|
|
// IsWhitelistedPath is used to check if the request path is allowed without auth
|
|
func (p *OAuthProxy) IsWhitelistedPath(path string) bool {
|
|
for _, u := range p.compiledRegex {
|
|
if u.MatchString(path) {
|
|
return true
|
|
}
|
|
}
|
|
return false
|
|
}
|
|
|
|
// OAuthStart starts the OAuth2 authentication flow
|
|
func (p *OAuthProxy) OAuthStart(rw http.ResponseWriter, req *http.Request) {
|
|
prepareNoCache(rw)
|
|
nonce, err := encryption.Nonce()
|
|
if err != nil {
|
|
p.logger.Errorf("Error obtaining nonce: %v", err)
|
|
p.ErrorPage(rw, http.StatusInternalServerError, "Internal Server Error", err.Error())
|
|
return
|
|
}
|
|
p.SetCSRFCookie(rw, req, nonce)
|
|
redirect, err := p.GetRedirect(req)
|
|
if err != nil {
|
|
p.logger.Errorf("Error obtaining redirect: %v", err)
|
|
p.ErrorPage(rw, http.StatusInternalServerError, "Internal Server Error", err.Error())
|
|
return
|
|
}
|
|
redirectURI := p.GetRedirectURI(req.Host)
|
|
http.Redirect(rw, req, p.provider.GetLoginURL(redirectURI, fmt.Sprintf("%v:%v", nonce, redirect)), http.StatusFound)
|
|
}
|
|
|
|
// OAuthCallback is the OAuth2 authentication flow callback that finishes the
|
|
// OAuth2 authentication flow
|
|
func (p *OAuthProxy) OAuthCallback(rw http.ResponseWriter, req *http.Request) {
|
|
remoteAddr := ip.GetClientString(p.realClientIPParser, req, true)
|
|
|
|
// finish the oauth cycle
|
|
err := req.ParseForm()
|
|
if err != nil {
|
|
p.logger.Errorf("Error while parsing OAuth2 callback: %v", err)
|
|
p.ErrorPage(rw, http.StatusInternalServerError, "Internal Server Error", err.Error())
|
|
return
|
|
}
|
|
errorString := req.Form.Get("error")
|
|
if errorString != "" {
|
|
p.logger.Errorf("Error while parsing OAuth2 callback: %s", errorString)
|
|
p.ErrorPage(rw, http.StatusForbidden, "Permission Denied", errorString)
|
|
return
|
|
}
|
|
|
|
session, err := p.redeemCode(req.Context(), req.Host, req.Form.Get("code"))
|
|
if err != nil {
|
|
p.logger.Errorf("Error redeeming code during OAuth2 callback: %v", err)
|
|
p.ErrorPage(rw, http.StatusInternalServerError, "Internal Server Error", "Internal Error")
|
|
return
|
|
}
|
|
|
|
s := strings.SplitN(req.Form.Get("state"), ":", 2)
|
|
if len(s) != 2 {
|
|
p.logger.Error("Error while parsing OAuth2 state: invalid length")
|
|
p.ErrorPage(rw, http.StatusInternalServerError, "Internal Server Error", "Invalid State")
|
|
return
|
|
}
|
|
nonce := s[0]
|
|
redirect := s[1]
|
|
c, err := req.Cookie(p.CSRFCookieName)
|
|
if err != nil {
|
|
p.logger.WithField("user", session.Email).WithField("status", "AuthFailure").Info("Invalid authentication via OAuth2: unable to obtain CSRF cookie")
|
|
p.ErrorPage(rw, http.StatusForbidden, "Permission Denied", err.Error())
|
|
return
|
|
}
|
|
p.ClearCSRFCookie(rw, req)
|
|
if c.Value != nonce {
|
|
p.logger.WithField("is", c.Value).WithField("should", nonce).WithField("user", session.Email).WithField("status", "AuthFailure").Info("Invalid authentication via OAuth2: CSRF token mismatch, potential attack")
|
|
p.ErrorPage(rw, http.StatusForbidden, "Permission Denied", "CSRF Failed")
|
|
return
|
|
}
|
|
|
|
if !p.IsValidRedirect(redirect) {
|
|
redirect = "/"
|
|
}
|
|
|
|
// set cookie, or deny
|
|
if p.provider.ValidateGroup(session.Email) {
|
|
p.logger.WithField("user", session.Email).WithField("status", "AuthFailure").Infof("Authenticated via OAuth2: %s", session)
|
|
err := p.SaveSession(rw, req, session)
|
|
if err != nil {
|
|
p.logger.Printf("Error saving session state for %s: %v", remoteAddr, err)
|
|
p.ErrorPage(rw, http.StatusInternalServerError, "Internal Server Error", err.Error())
|
|
return
|
|
}
|
|
http.Redirect(rw, req, redirect, http.StatusFound)
|
|
} else {
|
|
p.logger.WithField("user", session.Email).WithField("status", "AuthFailure").Info("Invalid authentication via OAuth2: unauthorized")
|
|
p.ErrorPage(rw, http.StatusForbidden, "Permission Denied", "Invalid Account")
|
|
}
|
|
}
|