stages/identification: move user validation to serializer
This commit is contained in:
parent
8787dc23d0
commit
33f67140f2
|
@ -1,5 +1,6 @@
|
||||||
"""Challenge helpers"""
|
"""Challenge helpers"""
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
|
from typing import TYPE_CHECKING, Optional
|
||||||
|
|
||||||
from django.db.models.base import Model
|
from django.db.models.base import Model
|
||||||
from django.http import JsonResponse
|
from django.http import JsonResponse
|
||||||
|
@ -8,6 +9,9 @@ from rest_framework.serializers import CharField, Serializer
|
||||||
|
|
||||||
from authentik.flows.transfer.common import DataclassEncoder
|
from authentik.flows.transfer.common import DataclassEncoder
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from authentik.flows.stage import StageView
|
||||||
|
|
||||||
|
|
||||||
class ChallengeTypes(Enum):
|
class ChallengeTypes(Enum):
|
||||||
"""Currently defined challenge types"""
|
"""Currently defined challenge types"""
|
||||||
|
@ -36,6 +40,12 @@ class Challenge(Serializer):
|
||||||
class ChallengeResponse(Serializer):
|
class ChallengeResponse(Serializer):
|
||||||
"""Base class for all challenge responses"""
|
"""Base class for all challenge responses"""
|
||||||
|
|
||||||
|
stage: Optional["StageView"]
|
||||||
|
|
||||||
|
def __init__(self, instance, data, **kwargs):
|
||||||
|
self.stage = kwargs.pop("stage", None)
|
||||||
|
super().__init__(instance=instance, data=data, **kwargs)
|
||||||
|
|
||||||
def create(self, validated_data: dict) -> Model:
|
def create(self, validated_data: dict) -> Model:
|
||||||
return Model()
|
return Model()
|
||||||
|
|
||||||
|
|
|
@ -3,6 +3,7 @@ from collections import namedtuple
|
||||||
from typing import Any, Type
|
from typing import Any, Type
|
||||||
|
|
||||||
from django.http import HttpRequest
|
from django.http import HttpRequest
|
||||||
|
from django.http.request import QueryDict
|
||||||
from django.http.response import HttpResponse, JsonResponse
|
from django.http.response import HttpResponse, JsonResponse
|
||||||
from django.utils.translation import gettext_lazy as _
|
from django.utils.translation import gettext_lazy as _
|
||||||
from django.views.generic import TemplateView
|
from django.views.generic import TemplateView
|
||||||
|
@ -56,9 +57,9 @@ class ChallengeStageView(StageView):
|
||||||
|
|
||||||
response_class = ChallengeResponse
|
response_class = ChallengeResponse
|
||||||
|
|
||||||
def get_response_class(self) -> Type[ChallengeResponse]:
|
def get_response_instance(self, data: QueryDict) -> ChallengeResponse:
|
||||||
"""Return the response class type"""
|
"""Return the response class type"""
|
||||||
return self.response_class
|
return self.response_class(None, data=data, stage=self)
|
||||||
|
|
||||||
def get(self, request: HttpRequest, *args, **kwargs) -> HttpResponse:
|
def get(self, request: HttpRequest, *args, **kwargs) -> HttpResponse:
|
||||||
challenge = self.get_challenge()
|
challenge = self.get_challenge()
|
||||||
|
@ -69,7 +70,7 @@ class ChallengeStageView(StageView):
|
||||||
# pylint: disable=unused-argument
|
# pylint: disable=unused-argument
|
||||||
def post(self, request: HttpRequest, *args, **kwargs) -> HttpResponse:
|
def post(self, request: HttpRequest, *args, **kwargs) -> HttpResponse:
|
||||||
"""Handle challenge response"""
|
"""Handle challenge response"""
|
||||||
challenge: ChallengeResponse = self.get_response_class()(data=request.POST)
|
challenge: ChallengeResponse = self.get_response_instance(data=request.POST)
|
||||||
if not challenge.is_valid():
|
if not challenge.is_valid():
|
||||||
return self.challenge_invalid(challenge)
|
return self.challenge_invalid(challenge)
|
||||||
return self.challenge_valid(challenge)
|
return self.challenge_valid(challenge)
|
||||||
|
|
|
@ -7,6 +7,7 @@ from django.http import HttpResponse
|
||||||
from django.urls import reverse
|
from django.urls import reverse
|
||||||
from django.utils.translation import gettext as _
|
from django.utils.translation import gettext as _
|
||||||
from rest_framework.fields import CharField
|
from rest_framework.fields import CharField
|
||||||
|
from rest_framework.serializers import ValidationError
|
||||||
from structlog.stdlib import get_logger
|
from structlog.stdlib import get_logger
|
||||||
|
|
||||||
from authentik.core.models import Source, User
|
from authentik.core.models import Source, User
|
||||||
|
@ -27,8 +28,16 @@ class IdentificationChallengeResponse(ChallengeResponse):
|
||||||
"""Identification challenge"""
|
"""Identification challenge"""
|
||||||
|
|
||||||
uid_field = CharField()
|
uid_field = CharField()
|
||||||
|
pre_user: Optional[User] = None
|
||||||
|
|
||||||
# TODO: Validate here instead of challenge_valid()
|
def validate_uid_field(self, value: str) -> str:
|
||||||
|
"""Validate that user exists"""
|
||||||
|
pre_user = self.stage.get_user(value)
|
||||||
|
if not pre_user:
|
||||||
|
LOGGER.debug("invalid_login", identifier=value)
|
||||||
|
raise ValidationError("Failed to authenticate.")
|
||||||
|
self.pre_user = pre_user
|
||||||
|
return value
|
||||||
|
|
||||||
|
|
||||||
class IdentificationStageView(ChallengeStageView):
|
class IdentificationStageView(ChallengeStageView):
|
||||||
|
@ -96,18 +105,10 @@ class IdentificationStageView(ChallengeStageView):
|
||||||
def challenge_valid(
|
def challenge_valid(
|
||||||
self, challenge: IdentificationChallengeResponse
|
self, challenge: IdentificationChallengeResponse
|
||||||
) -> HttpResponse:
|
) -> HttpResponse:
|
||||||
user_identifier = challenge.data.get("uid_field")
|
self.executor.plan.context[PLAN_CONTEXT_PENDING_USER] = challenge.pre_user
|
||||||
pre_user = self.get_user(user_identifier)
|
|
||||||
if not pre_user:
|
|
||||||
LOGGER.debug("invalid_login", identifier=user_identifier)
|
|
||||||
messages.error(self.request, _("Failed to authenticate."))
|
|
||||||
return self.challenge_invalid(challenge)
|
|
||||||
self.executor.plan.context[PLAN_CONTEXT_PENDING_USER] = pre_user
|
|
||||||
|
|
||||||
current_stage: IdentificationStage = self.executor.current_stage
|
current_stage: IdentificationStage = self.executor.current_stage
|
||||||
if not current_stage.show_matched_user:
|
if not current_stage.show_matched_user:
|
||||||
self.executor.plan.context[
|
self.executor.plan.context[
|
||||||
PLAN_CONTEXT_PENDING_USER_IDENTIFIER
|
PLAN_CONTEXT_PENDING_USER_IDENTIFIER
|
||||||
] = user_identifier
|
] = challenge.validated_data.get("uid_field")
|
||||||
|
|
||||||
return self.executor.stage_ok()
|
return self.executor.stage_ok()
|
||||||
|
|
|
@ -87,9 +87,7 @@ class TestFlowsEnroll(SeleniumTestCase):
|
||||||
FlowStageBinding.objects.create(target=flow, stage=user_login, order=3)
|
FlowStageBinding.objects.create(target=flow, stage=user_login, order=3)
|
||||||
|
|
||||||
self.driver.get(self.live_server_url)
|
self.driver.get(self.live_server_url)
|
||||||
self.wait.until(
|
self.wait.until(ec.presence_of_element_located((By.CSS_SELECTOR, "#enroll")))
|
||||||
ec.presence_of_element_located((By.CSS_SELECTOR, "#enroll"))
|
|
||||||
)
|
|
||||||
self.driver.find_element(By.CSS_SELECTOR, "#enroll").click()
|
self.driver.find_element(By.CSS_SELECTOR, "#enroll").click()
|
||||||
|
|
||||||
self.wait.until(ec.presence_of_element_located((By.ID, "id_username")))
|
self.wait.until(ec.presence_of_element_located((By.ID, "id_username")))
|
||||||
|
|
Reference in a new issue