flows: ensure all StageViews accept post, add tests

Signed-off-by: Jens Langhammer <jens.langhammer@beryju.org>
This commit is contained in:
Jens Langhammer 2021-09-09 16:29:00 +02:00
parent 7158c9d2ea
commit d0898a3869
7 changed files with 55 additions and 0 deletions

View File

@ -28,3 +28,7 @@ class PostUserEnrollmentStage(StageView):
source=connection.source, source=connection.source,
).from_http(self.request) ).from_http(self.request)
return self.executor.stage_ok() return self.executor.stage_ok()
def post(self, request: HttpRequest) -> HttpResponse:
"""Wrapper for post requests"""
return self.get(request)

View File

@ -0,0 +1,31 @@
"""stage view tests"""
from typing import Callable, Type
from django.test import RequestFactory, TestCase
from authentik.flows.stage import StageView
from authentik.flows.views import FlowExecutorView
from authentik.lib.utils.reflection import all_subclasses
class TestViews(TestCase):
"""Generic model properties tests"""
def setUp(self) -> None:
self.factory = RequestFactory()
self.exec = FlowExecutorView(self.factory.request("/"))
def view_tester_factory(view: Type[StageView]) -> Callable:
"""Test a form"""
def tester(self: TestViews):
model_class = view(self.exec)
self.assertIsNotNone(model_class.post)
self.assertIsNotNone(model_class.get)
return tester
for view in all_subclasses(StageView):
setattr(TestViews, f"test_view_{view.__name__}", view_tester_factory(view))

View File

@ -13,3 +13,7 @@ class DenyStageView(StageView):
def get(self, request: HttpRequest) -> HttpResponse: def get(self, request: HttpRequest) -> HttpResponse:
"""Cancells the current flow""" """Cancells the current flow"""
return self.executor.stage_invalid() return self.executor.stage_invalid()
def post(self, request: HttpRequest) -> HttpResponse:
"""Wrapper for post requests"""
return self.get(request)

View File

@ -23,6 +23,10 @@ INVITATION = "invitation"
class InvitationStageView(StageView): class InvitationStageView(StageView):
"""Finalise Authentication flow by logging the user in""" """Finalise Authentication flow by logging the user in"""
def post(self, request: HttpRequest) -> HttpResponse:
"""Wrapper for post requests"""
return self.get(request)
def get_token(self) -> Optional[str]: def get_token(self) -> Optional[str]:
"""Get token from saved get-arguments or prompt_data""" """Get token from saved get-arguments or prompt_data"""
if INVITATION_TOKEN_KEY in self.request.session.get(SESSION_KEY_GET, {}): if INVITATION_TOKEN_KEY in self.request.session.get(SESSION_KEY_GET, {}):

View File

@ -14,6 +14,10 @@ LOGGER = get_logger()
class UserDeleteStageView(StageView): class UserDeleteStageView(StageView):
"""Finalise unenrollment flow by deleting the user object.""" """Finalise unenrollment flow by deleting the user object."""
def post(self, request: HttpRequest) -> HttpResponse:
"""Wrapper for post requests"""
return self.get(request)
def get(self, request: HttpRequest) -> HttpResponse: def get(self, request: HttpRequest) -> HttpResponse:
"""Delete currently pending user""" """Delete currently pending user"""
if PLAN_CONTEXT_PENDING_USER not in self.executor.plan.context: if PLAN_CONTEXT_PENDING_USER not in self.executor.plan.context:

View File

@ -18,6 +18,10 @@ USER_LOGIN_AUTHENTICATED = "user_login_authenticated"
class UserLoginStageView(StageView): class UserLoginStageView(StageView):
"""Finalise Authentication flow by logging the user in""" """Finalise Authentication flow by logging the user in"""
def post(self, request: HttpRequest) -> HttpResponse:
"""Wrapper for post requests"""
return self.get(request)
def get(self, request: HttpRequest) -> HttpResponse: def get(self, request: HttpRequest) -> HttpResponse:
"""Attach the currently pending user to the current session""" """Attach the currently pending user to the current session"""
if PLAN_CONTEXT_PENDING_USER not in self.executor.plan.context: if PLAN_CONTEXT_PENDING_USER not in self.executor.plan.context:

View File

@ -20,3 +20,7 @@ class UserLogoutStageView(StageView):
) )
logout(self.request) logout(self.request)
return self.executor.stage_ok() return self.executor.stage_ok()
def post(self, request: HttpRequest) -> HttpResponse:
"""Wrapper for post requests"""
return self.get(request)