diff --git a/passbook/providers/saml/views.py b/passbook/providers/saml/views.py index ebd3041b5..61c7ca8fb 100644 --- a/passbook/providers/saml/views.py +++ b/passbook/providers/saml/views.py @@ -83,30 +83,29 @@ class LoginBeginView(AccessRequiredView): """Receives a SAML 2.0 AuthnRequest from a Service Provider and stores it in the session prior to enforcing login.""" - @method_decorator(csrf_exempt) - def dispatch(self, request: HttpRequest, application: str) -> HttpResponse: - if request.method == "POST": - source = request.POST - else: - source = request.GET - + def handler(self, source, application: str) -> HttpResponse: + """Handle SAML Request whether its a POST or a Redirect binding""" # Store these values now, because Django's login cycle won't preserve them. try: - request.session[SESSION_KEY_SAML_REQUEST] = source[SESSION_KEY_SAML_REQUEST] + self.request.session[SESSION_KEY_SAML_REQUEST] = source[ + SESSION_KEY_SAML_REQUEST + ] except (KeyError, MultiValueDictKeyError): - return bad_request_message(request, "The SAML request payload is missing.") + return bad_request_message( + self.request, "The SAML request payload is missing." + ) - request.session[SESSION_KEY_RELAY_STATE] = source.get( + self.request.session[SESSION_KEY_RELAY_STATE] = source.get( SESSION_KEY_RELAY_STATE, "" ) try: - self.provider.processor.can_handle(request) + self.provider.processor.can_handle(self.request) params = self.provider.processor.generate_response() - request.session[SESSION_KEY_PARAMS] = params + self.request.session[SESSION_KEY_PARAMS] = params except CannotHandleAssertion as exc: LOGGER.info(exc) - did_you_mean_link = request.build_absolute_uri( + did_you_mean_link = self.request.build_absolute_uri( reverse( "passbook_providers_saml:saml-login-initiate", kwargs={"application": application}, @@ -116,7 +115,7 @@ class LoginBeginView(AccessRequiredView): f" Did you mean to go here?" ) return bad_request_message( - request, mark_safe(str(exc) + did_you_mean_message) + self.request, mark_safe(str(exc) + did_you_mean_message) ) return redirect( @@ -126,6 +125,16 @@ class LoginBeginView(AccessRequiredView): ) ) + @method_decorator(csrf_exempt) + def get(self, request: HttpRequest, application: str) -> HttpResponse: + """Handle REDIRECT bindings""" + return self.handler(request.GET, application) + + @method_decorator(csrf_exempt) + def post(self, request: HttpRequest, application: str) -> HttpResponse: + """Handle POST Bindings""" + return self.handler(request.POST, application) + class InitiateLoginView(AccessRequiredView): """IdP-initiated Login"""