outposts/proxy: add support for multiple states, when multiple requests are redirect at once
Signed-off-by: Jens Langhammer <jens.langhammer@beryju.org>
This commit is contained in:
parent
f93f7e635b
commit
410d1b97cd
|
@ -76,6 +76,7 @@ func (a *Application) getTraefikForwardUrl(r *http.Request) *url.URL {
|
||||||
a.log.WithError(err).Warning("Failed to parse URL from traefik")
|
a.log.WithError(err).Warning("Failed to parse URL from traefik")
|
||||||
return r.URL
|
return r.URL
|
||||||
}
|
}
|
||||||
|
a.log.WithField("url", u.String()).Trace("traefik forwarded url")
|
||||||
return u
|
return u
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -10,10 +10,14 @@ import (
|
||||||
)
|
)
|
||||||
|
|
||||||
func (a *Application) handleRedirect(rw http.ResponseWriter, r *http.Request) {
|
func (a *Application) handleRedirect(rw http.ResponseWriter, r *http.Request) {
|
||||||
state := base64.RawStdEncoding.EncodeToString(securecookie.GenerateRandomKey(32))
|
newState := base64.RawStdEncoding.EncodeToString(securecookie.GenerateRandomKey(32))
|
||||||
s, _ := a.sessions.Get(r, constants.SeesionName)
|
s, err := a.sessions.Get(r, constants.SeesionName)
|
||||||
s.Values[constants.SessionOAuthState] = state
|
if err != nil {
|
||||||
err := s.Save(r, rw)
|
s.Values[constants.SessionOAuthState] = []string{}
|
||||||
|
}
|
||||||
|
state := s.Values[constants.SessionOAuthState].([]string)
|
||||||
|
s.Values[constants.SessionOAuthState] = append(state, newState)
|
||||||
|
err = s.Save(r, rw)
|
||||||
if err != nil {
|
if err != nil {
|
||||||
a.log.WithError(err).Warning("failed to save session")
|
a.log.WithError(err).Warning("failed to save session")
|
||||||
}
|
}
|
||||||
|
@ -24,7 +28,7 @@ func (a *Application) handleRedirect(rw http.ResponseWriter, r *http.Request) {
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
}
|
}
|
||||||
http.Redirect(rw, r, a.oauthConfig.AuthCodeURL(state), http.StatusFound)
|
http.Redirect(rw, r, a.oauthConfig.AuthCodeURL(newState), http.StatusFound)
|
||||||
}
|
}
|
||||||
|
|
||||||
func (a *Application) handleCallback(rw http.ResponseWriter, r *http.Request) {
|
func (a *Application) handleCallback(rw http.ResponseWriter, r *http.Request) {
|
||||||
|
@ -35,7 +39,7 @@ func (a *Application) handleCallback(rw http.ResponseWriter, r *http.Request) {
|
||||||
http.Redirect(rw, r, a.proxyConfig.ExternalHost, http.StatusFound)
|
http.Redirect(rw, r, a.proxyConfig.ExternalHost, http.StatusFound)
|
||||||
return
|
return
|
||||||
}
|
}
|
||||||
claims, err := a.redeemCallback(r, state.(string))
|
claims, err := a.redeemCallback(r, state.([]string))
|
||||||
if err != nil {
|
if err != nil {
|
||||||
a.log.WithError(err).Warning("failed to redeem code")
|
a.log.WithError(err).Warning("failed to redeem code")
|
||||||
rw.WriteHeader(400)
|
rw.WriteHeader(400)
|
||||||
|
@ -61,6 +65,7 @@ func (a *Application) handleCallback(rw http.ResponseWriter, r *http.Request) {
|
||||||
redirect := a.proxyConfig.ExternalHost
|
redirect := a.proxyConfig.ExternalHost
|
||||||
redirectR, ok := s.Values[constants.SessionRedirect]
|
redirectR, ok := s.Values[constants.SessionRedirect]
|
||||||
if ok {
|
if ok {
|
||||||
|
a.log.WithField("redirect", redirectR).Trace("got final redirect from session")
|
||||||
redirect = redirectR.(string)
|
redirect = redirectR.(string)
|
||||||
}
|
}
|
||||||
http.Redirect(rw, r, redirect, http.StatusFound)
|
http.Redirect(rw, r, redirect, http.StatusFound)
|
||||||
|
|
|
@ -8,10 +8,19 @@ import (
|
||||||
"golang.org/x/oauth2"
|
"golang.org/x/oauth2"
|
||||||
)
|
)
|
||||||
|
|
||||||
func (a *Application) redeemCallback(r *http.Request, shouldState string) (*Claims, error) {
|
func (a *Application) redeemCallback(r *http.Request, states []string) (*Claims, error) {
|
||||||
state := r.URL.Query().Get("state")
|
state := r.URL.Query().Get("state")
|
||||||
if state == "" || state != shouldState {
|
if len(states) < 1 {
|
||||||
return nil, fmt.Errorf("blank/invalid state")
|
return nil, fmt.Errorf("no states")
|
||||||
|
}
|
||||||
|
found := false
|
||||||
|
for _, fstate := range states {
|
||||||
|
if fstate == state {
|
||||||
|
found = true
|
||||||
|
}
|
||||||
|
}
|
||||||
|
if !found {
|
||||||
|
return nil, fmt.Errorf("invalid state")
|
||||||
}
|
}
|
||||||
|
|
||||||
code := r.URL.Query().Get("code")
|
code := r.URL.Query().Get("code")
|
||||||
|
|
Reference in New Issue