providers/saml: some more cleanup, fix get_time_string when called without argument
This commit is contained in:
parent
e36d7928e4
commit
571373866e
|
@ -194,10 +194,6 @@ class Processor:
|
|||
self._logger.info(msg)
|
||||
raise CannotHandleAssertion(msg)
|
||||
|
||||
def _validate_user(self):
|
||||
"""Validates the User. Sub-classes should override this and
|
||||
throw an CannotHandleAssertion Exception if the validation does not succeed."""
|
||||
|
||||
def can_handle(self, request: HttpRequest) -> bool:
|
||||
"""Returns true if this processor can handle this request."""
|
||||
self._http_request = request
|
||||
|
@ -224,7 +220,6 @@ class Processor:
|
|||
if not self.is_idp_initiated:
|
||||
self.can_handle(self._http_request)
|
||||
|
||||
self._validate_user()
|
||||
self._build_assertion()
|
||||
self._format_assertion()
|
||||
self._build_response()
|
||||
|
|
|
@ -40,6 +40,8 @@ def timedelta_from_string(expr: str) -> datetime.timedelta:
|
|||
|
||||
def get_time_string(delta: datetime.timedelta = None) -> str:
|
||||
"""Get Data formatted in SAML format"""
|
||||
if delta is None:
|
||||
delta = datetime.timedelta()
|
||||
now = datetime.datetime.now()
|
||||
final = now + delta
|
||||
return final.strftime("%Y-%m-%dT%H:%M:%SZ")
|
||||
|
|
|
@ -1,9 +1,11 @@
|
|||
"""passbook SAML IDP Views"""
|
||||
from typing import Optional
|
||||
|
||||
from django.contrib.auth import logout
|
||||
from django.contrib.auth.mixins import AccessMixin
|
||||
from django.core.exceptions import ValidationError
|
||||
from django.core.validators import URLValidator
|
||||
from django.http import HttpResponse, HttpResponseBadRequest
|
||||
from django.http import HttpResponse, HttpResponseBadRequest, HttpRequest
|
||||
from django.shortcuts import get_object_or_404, redirect, render, reverse
|
||||
from django.utils.datastructures import MultiValueDictKeyError
|
||||
from django.utils.decorators import method_decorator
|
||||
|
@ -25,7 +27,7 @@ LOGGER = get_logger()
|
|||
URL_VALIDATOR = URLValidator(schemes=("http", "https"))
|
||||
|
||||
|
||||
def _generate_response(request, provider: SAMLProvider):
|
||||
def _generate_response(request: HttpRequest, provider: SAMLProvider):
|
||||
"""Generate a SAML response using processor_instance and return it in the proper Django
|
||||
response."""
|
||||
try:
|
||||
|
@ -42,7 +44,7 @@ def _generate_response(request, provider: SAMLProvider):
|
|||
class AccessRequiredView(AccessMixin, View):
|
||||
"""Mixin class for Views using a provider instance"""
|
||||
|
||||
_provider: SAMLProvider
|
||||
_provider: Optional[SAMLProvider] = None
|
||||
|
||||
@property
|
||||
def provider(self) -> SAMLProvider:
|
||||
|
@ -54,7 +56,7 @@ class AccessRequiredView(AccessMixin, View):
|
|||
self._provider = get_object_or_404(SAMLProvider, pk=application.provider_id)
|
||||
return self._provider
|
||||
|
||||
def _has_access(self):
|
||||
def _has_access(self) -> bool:
|
||||
"""Check if user has access to application"""
|
||||
policy_engine = PolicyEngine(
|
||||
self.provider.application.policies.all(), self.request.user, self.request
|
||||
|
@ -87,8 +89,8 @@ class LoginBeginView(AccessRequiredView):
|
|||
source = request.POST
|
||||
else:
|
||||
source = request.GET
|
||||
# Store these values now, because Django's login cycle won't preserve them.
|
||||
|
||||
# Store these values now, because Django's login cycle won't preserve them.
|
||||
try:
|
||||
request.session["SAMLRequest"] = source["SAMLRequest"]
|
||||
except (KeyError, MultiValueDictKeyError):
|
||||
|
@ -123,10 +125,9 @@ class LoginProcessView(AccessRequiredView):
|
|||
Presents a SAML 2.0 Assertion for POSTing back to the Service Provider."""
|
||||
|
||||
# pylint: disable=unused-argument
|
||||
def get(self, request, application):
|
||||
def get(self, request: HttpRequest, application: str) -> HttpResponse:
|
||||
"""Handle get request, i.e. render form"""
|
||||
LOGGER.debug("SAMLLoginProcessView", request=request, method="get")
|
||||
# Check if user has access
|
||||
# User access gets checked in dispatch
|
||||
if self.provider.application.skip_authorization:
|
||||
ctx = self.provider.processor.generate_response()
|
||||
# Log Application Authorization
|
||||
|
@ -148,10 +149,9 @@ class LoginProcessView(AccessRequiredView):
|
|||
return HttpResponseBadRequest()
|
||||
|
||||
# pylint: disable=unused-argument
|
||||
def post(self, request, application):
|
||||
def post(self, request, application: str) -> HttpResponse:
|
||||
"""Handle post request, return back to ACS"""
|
||||
LOGGER.debug("SAMLLoginProcessView", request=request, method="post")
|
||||
# Check if user has access
|
||||
# User access gets checked in dispatch
|
||||
if request.POST.get("ACSUrl", None):
|
||||
# User accepted request
|
||||
Event.new(
|
||||
|
@ -166,10 +166,10 @@ class LoginProcessView(AccessRequiredView):
|
|||
relay_state=request.POST.get("RelayState"),
|
||||
)
|
||||
try:
|
||||
full_res = _generate_response(request, self.provider)
|
||||
return full_res
|
||||
return _generate_response(request, self.provider)
|
||||
except exceptions.CannotHandleAssertion as exc:
|
||||
LOGGER.debug(exc)
|
||||
return HttpResponseBadRequest()
|
||||
|
||||
|
||||
class LogoutView(CSRFExemptMixin, AccessRequiredView):
|
||||
|
|
Reference in a new issue