stages/identification: move user validation to serializer

This commit is contained in:
Jens Langhammer 2021-02-20 20:16:20 +01:00
parent 8787dc23d0
commit 33f67140f2
4 changed files with 27 additions and 17 deletions

View file

@ -1,5 +1,6 @@
"""Challenge helpers""" """Challenge helpers"""
from enum import Enum from enum import Enum
from typing import TYPE_CHECKING, Optional
from django.db.models.base import Model from django.db.models.base import Model
from django.http import JsonResponse from django.http import JsonResponse
@ -8,6 +9,9 @@ from rest_framework.serializers import CharField, Serializer
from authentik.flows.transfer.common import DataclassEncoder from authentik.flows.transfer.common import DataclassEncoder
if TYPE_CHECKING:
from authentik.flows.stage import StageView
class ChallengeTypes(Enum): class ChallengeTypes(Enum):
"""Currently defined challenge types""" """Currently defined challenge types"""
@ -36,6 +40,12 @@ class Challenge(Serializer):
class ChallengeResponse(Serializer): class ChallengeResponse(Serializer):
"""Base class for all challenge responses""" """Base class for all challenge responses"""
stage: Optional["StageView"]
def __init__(self, instance, data, **kwargs):
self.stage = kwargs.pop("stage", None)
super().__init__(instance=instance, data=data, **kwargs)
def create(self, validated_data: dict) -> Model: def create(self, validated_data: dict) -> Model:
return Model() return Model()

View file

@ -3,6 +3,7 @@ from collections import namedtuple
from typing import Any, Type from typing import Any, Type
from django.http import HttpRequest from django.http import HttpRequest
from django.http.request import QueryDict
from django.http.response import HttpResponse, JsonResponse from django.http.response import HttpResponse, JsonResponse
from django.utils.translation import gettext_lazy as _ from django.utils.translation import gettext_lazy as _
from django.views.generic import TemplateView from django.views.generic import TemplateView
@ -56,9 +57,9 @@ class ChallengeStageView(StageView):
response_class = ChallengeResponse response_class = ChallengeResponse
def get_response_class(self) -> Type[ChallengeResponse]: def get_response_instance(self, data: QueryDict) -> ChallengeResponse:
"""Return the response class type""" """Return the response class type"""
return self.response_class return self.response_class(None, data=data, stage=self)
def get(self, request: HttpRequest, *args, **kwargs) -> HttpResponse: def get(self, request: HttpRequest, *args, **kwargs) -> HttpResponse:
challenge = self.get_challenge() challenge = self.get_challenge()
@ -69,7 +70,7 @@ class ChallengeStageView(StageView):
# pylint: disable=unused-argument # pylint: disable=unused-argument
def post(self, request: HttpRequest, *args, **kwargs) -> HttpResponse: def post(self, request: HttpRequest, *args, **kwargs) -> HttpResponse:
"""Handle challenge response""" """Handle challenge response"""
challenge: ChallengeResponse = self.get_response_class()(data=request.POST) challenge: ChallengeResponse = self.get_response_instance(data=request.POST)
if not challenge.is_valid(): if not challenge.is_valid():
return self.challenge_invalid(challenge) return self.challenge_invalid(challenge)
return self.challenge_valid(challenge) return self.challenge_valid(challenge)

View file

@ -7,6 +7,7 @@ from django.http import HttpResponse
from django.urls import reverse from django.urls import reverse
from django.utils.translation import gettext as _ from django.utils.translation import gettext as _
from rest_framework.fields import CharField from rest_framework.fields import CharField
from rest_framework.serializers import ValidationError
from structlog.stdlib import get_logger from structlog.stdlib import get_logger
from authentik.core.models import Source, User from authentik.core.models import Source, User
@ -27,8 +28,16 @@ class IdentificationChallengeResponse(ChallengeResponse):
"""Identification challenge""" """Identification challenge"""
uid_field = CharField() uid_field = CharField()
pre_user: Optional[User] = None
# TODO: Validate here instead of challenge_valid() def validate_uid_field(self, value: str) -> str:
"""Validate that user exists"""
pre_user = self.stage.get_user(value)
if not pre_user:
LOGGER.debug("invalid_login", identifier=value)
raise ValidationError("Failed to authenticate.")
self.pre_user = pre_user
return value
class IdentificationStageView(ChallengeStageView): class IdentificationStageView(ChallengeStageView):
@ -96,18 +105,10 @@ class IdentificationStageView(ChallengeStageView):
def challenge_valid( def challenge_valid(
self, challenge: IdentificationChallengeResponse self, challenge: IdentificationChallengeResponse
) -> HttpResponse: ) -> HttpResponse:
user_identifier = challenge.data.get("uid_field") self.executor.plan.context[PLAN_CONTEXT_PENDING_USER] = challenge.pre_user
pre_user = self.get_user(user_identifier)
if not pre_user:
LOGGER.debug("invalid_login", identifier=user_identifier)
messages.error(self.request, _("Failed to authenticate."))
return self.challenge_invalid(challenge)
self.executor.plan.context[PLAN_CONTEXT_PENDING_USER] = pre_user
current_stage: IdentificationStage = self.executor.current_stage current_stage: IdentificationStage = self.executor.current_stage
if not current_stage.show_matched_user: if not current_stage.show_matched_user:
self.executor.plan.context[ self.executor.plan.context[
PLAN_CONTEXT_PENDING_USER_IDENTIFIER PLAN_CONTEXT_PENDING_USER_IDENTIFIER
] = user_identifier ] = challenge.validated_data.get("uid_field")
return self.executor.stage_ok() return self.executor.stage_ok()

View file

@ -87,9 +87,7 @@ class TestFlowsEnroll(SeleniumTestCase):
FlowStageBinding.objects.create(target=flow, stage=user_login, order=3) FlowStageBinding.objects.create(target=flow, stage=user_login, order=3)
self.driver.get(self.live_server_url) self.driver.get(self.live_server_url)
self.wait.until( self.wait.until(ec.presence_of_element_located((By.CSS_SELECTOR, "#enroll")))
ec.presence_of_element_located((By.CSS_SELECTOR, "#enroll"))
)
self.driver.find_element(By.CSS_SELECTOR, "#enroll").click() self.driver.find_element(By.CSS_SELECTOR, "#enroll").click()
self.wait.until(ec.presence_of_element_located((By.ID, "id_username"))) self.wait.until(ec.presence_of_element_located((By.ID, "id_username")))