stages/identification: move user validation to serializer

This commit is contained in:
Jens Langhammer 2021-02-20 20:16:20 +01:00
parent 8787dc23d0
commit 33f67140f2
4 changed files with 27 additions and 17 deletions

View file

@ -1,5 +1,6 @@
"""Challenge helpers"""
from enum import Enum
from typing import TYPE_CHECKING, Optional
from django.db.models.base import Model
from django.http import JsonResponse
@ -8,6 +9,9 @@ from rest_framework.serializers import CharField, Serializer
from authentik.flows.transfer.common import DataclassEncoder
if TYPE_CHECKING:
from authentik.flows.stage import StageView
class ChallengeTypes(Enum):
"""Currently defined challenge types"""
@ -36,6 +40,12 @@ class Challenge(Serializer):
class ChallengeResponse(Serializer):
"""Base class for all challenge responses"""
stage: Optional["StageView"]
def __init__(self, instance, data, **kwargs):
self.stage = kwargs.pop("stage", None)
super().__init__(instance=instance, data=data, **kwargs)
def create(self, validated_data: dict) -> Model:
return Model()

View file

@ -3,6 +3,7 @@ from collections import namedtuple
from typing import Any, Type
from django.http import HttpRequest
from django.http.request import QueryDict
from django.http.response import HttpResponse, JsonResponse
from django.utils.translation import gettext_lazy as _
from django.views.generic import TemplateView
@ -56,9 +57,9 @@ class ChallengeStageView(StageView):
response_class = ChallengeResponse
def get_response_class(self) -> Type[ChallengeResponse]:
def get_response_instance(self, data: QueryDict) -> ChallengeResponse:
"""Return the response class type"""
return self.response_class
return self.response_class(None, data=data, stage=self)
def get(self, request: HttpRequest, *args, **kwargs) -> HttpResponse:
challenge = self.get_challenge()
@ -69,7 +70,7 @@ class ChallengeStageView(StageView):
# pylint: disable=unused-argument
def post(self, request: HttpRequest, *args, **kwargs) -> HttpResponse:
"""Handle challenge response"""
challenge: ChallengeResponse = self.get_response_class()(data=request.POST)
challenge: ChallengeResponse = self.get_response_instance(data=request.POST)
if not challenge.is_valid():
return self.challenge_invalid(challenge)
return self.challenge_valid(challenge)

View file

@ -7,6 +7,7 @@ from django.http import HttpResponse
from django.urls import reverse
from django.utils.translation import gettext as _
from rest_framework.fields import CharField
from rest_framework.serializers import ValidationError
from structlog.stdlib import get_logger
from authentik.core.models import Source, User
@ -27,8 +28,16 @@ class IdentificationChallengeResponse(ChallengeResponse):
"""Identification challenge"""
uid_field = CharField()
pre_user: Optional[User] = None
# TODO: Validate here instead of challenge_valid()
def validate_uid_field(self, value: str) -> str:
"""Validate that user exists"""
pre_user = self.stage.get_user(value)
if not pre_user:
LOGGER.debug("invalid_login", identifier=value)
raise ValidationError("Failed to authenticate.")
self.pre_user = pre_user
return value
class IdentificationStageView(ChallengeStageView):
@ -96,18 +105,10 @@ class IdentificationStageView(ChallengeStageView):
def challenge_valid(
self, challenge: IdentificationChallengeResponse
) -> HttpResponse:
user_identifier = challenge.data.get("uid_field")
pre_user = self.get_user(user_identifier)
if not pre_user:
LOGGER.debug("invalid_login", identifier=user_identifier)
messages.error(self.request, _("Failed to authenticate."))
return self.challenge_invalid(challenge)
self.executor.plan.context[PLAN_CONTEXT_PENDING_USER] = pre_user
self.executor.plan.context[PLAN_CONTEXT_PENDING_USER] = challenge.pre_user
current_stage: IdentificationStage = self.executor.current_stage
if not current_stage.show_matched_user:
self.executor.plan.context[
PLAN_CONTEXT_PENDING_USER_IDENTIFIER
] = user_identifier
] = challenge.validated_data.get("uid_field")
return self.executor.stage_ok()

View file

@ -87,9 +87,7 @@ class TestFlowsEnroll(SeleniumTestCase):
FlowStageBinding.objects.create(target=flow, stage=user_login, order=3)
self.driver.get(self.live_server_url)
self.wait.until(
ec.presence_of_element_located((By.CSS_SELECTOR, "#enroll"))
)
self.wait.until(ec.presence_of_element_located((By.CSS_SELECTOR, "#enroll")))
self.driver.find_element(By.CSS_SELECTOR, "#enroll").click()
self.wait.until(ec.presence_of_element_located((By.ID, "id_username")))