flows: remove need for post() wrapper by using dispatch (#6765)

Signed-off-by: Jens Langhammer <jens@goauthentik.io>
This commit is contained in:
Jens L 2023-09-05 22:15:03 +02:00 committed by GitHub
parent 7cbce1bb3d
commit e373bae189
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
10 changed files with 15 additions and 48 deletions

View file

@ -48,7 +48,7 @@ class Action(Enum):
class MessageStage(StageView):
"""Show a pre-configured message after the flow is done"""
def get(self, request: HttpRequest, *args, **kwargs) -> HttpResponse:
def dispatch(self, request: HttpRequest, *args, **kwargs) -> HttpResponse:
"""Show a pre-configured message after the flow is done"""
message = getattr(self.executor.current_stage, "message", "")
level = getattr(self.executor.current_stage, "level", messages.SUCCESS)
@ -59,10 +59,6 @@ class MessageStage(StageView):
)
return self.executor.stage_ok()
def post(self, request: HttpRequest) -> HttpResponse:
"""Wrapper for post requests"""
return self.get(request)
class SourceFlowManager:
"""Help sources decide what they should do after authorization. Based on source settings and

View file

@ -13,7 +13,7 @@ class PostUserEnrollmentStage(StageView):
"""Dynamically injected stage which saves the Connection after
the user has been enrolled."""
def get(self, request: HttpRequest, *args, **kwargs) -> HttpResponse:
def dispatch(self, request: HttpRequest) -> HttpResponse:
"""Stage used after the user has been enrolled"""
connection: UserSourceConnection = self.executor.plan.context[
PLAN_CONTEXT_SOURCES_CONNECTION
@ -27,7 +27,3 @@ class PostUserEnrollmentStage(StageView):
source=connection.source,
).from_http(self.request)
return self.executor.stage_ok()
def post(self, request: HttpRequest) -> HttpResponse:
"""Wrapper for post requests"""
return self.get(request)

View file

@ -21,8 +21,9 @@ def view_tester_factory(view_class: type[StageView]) -> Callable:
def tester(self: TestViews):
model_class = view_class(self.exec)
self.assertIsNotNone(model_class.post)
self.assertIsNotNone(model_class.get)
if not hasattr(model_class, "dispatch"):
self.assertIsNotNone(model_class.post)
self.assertIsNotNone(model_class.get)
return tester

View file

@ -295,7 +295,7 @@ class FlowExecutorView(APIView):
span.set_data("Method", "GET")
span.set_data("authentik Stage", self.current_stage_view)
span.set_data("authentik Flow", self.flow.slug)
stage_response = self.current_stage_view.get(request, *args, **kwargs)
stage_response = self.current_stage_view.dispatch(request)
return to_stage_response(request, stage_response)
except Exception as exc: # pylint: disable=broad-except
return self.handle_exception(exc)
@ -339,7 +339,7 @@ class FlowExecutorView(APIView):
span.set_data("Method", "POST")
span.set_data("authentik Stage", self.current_stage_view)
span.set_data("authentik Flow", self.flow.slug)
stage_response = self.current_stage_view.post(request, *args, **kwargs)
stage_response = self.current_stage_view.dispatch(request)
return to_stage_response(request, stage_response)
except Exception as exc: # pylint: disable=broad-except
return self.handle_exception(exc)

View file

@ -7,10 +7,6 @@ from authentik.flows.stage import StageView
class DenyStageView(StageView):
"""Cancels the current flow"""
def get(self, request: HttpRequest) -> HttpResponse:
def dispatch(self, request: HttpRequest) -> HttpResponse:
"""Cancels the current flow"""
return self.executor.stage_invalid()
def post(self, request: HttpRequest) -> HttpResponse:
"""Wrapper for post requests"""
return self.get(request)

View file

@ -21,10 +21,6 @@ INVITATION = "invitation"
class InvitationStageView(StageView):
"""Finalise Authentication flow by logging the user in"""
def post(self, request: HttpRequest) -> HttpResponse:
"""Wrapper for post requests"""
return self.get(request)
def get_token(self) -> Optional[str]:
"""Get token from saved get-arguments or prompt_data"""
# Check for ?token= and ?itoken=
@ -55,7 +51,7 @@ class InvitationStageView(StageView):
return None
return invite
def get(self, request: HttpRequest) -> HttpResponse:
def dispatch(self, request: HttpRequest) -> HttpResponse:
"""Apply data to the current flow based on a URL"""
stage: InvitationStage = self.executor.current_stage

View file

@ -11,11 +11,7 @@ from authentik.flows.stage import StageView
class UserDeleteStageView(StageView):
"""Finalise unenrollment flow by deleting the user object."""
def post(self, request: HttpRequest) -> HttpResponse:
"""Wrapper for post requests"""
return self.get(request)
def get(self, request: HttpRequest) -> HttpResponse:
def dispatch(self, request: HttpRequest) -> HttpResponse:
"""Delete currently pending user"""
user = self.get_pending_user()
if not user.is_authenticated:

View file

@ -41,17 +41,11 @@ class UserLoginStageView(ChallengeStageView):
}
)
def post(self, request: HttpRequest, *args, **kwargs) -> HttpResponse:
"""Wrapper for post requests"""
def dispatch(self, request: HttpRequest) -> HttpResponse:
"""Check for remember_me, and do login"""
stage: UserLoginStage = self.executor.current_stage
if timedelta_from_string(stage.remember_me_offset).total_seconds() > 0:
return super().post(request, *args, **kwargs)
return self.do_login(request)
def get(self, request: HttpRequest, *args, **kwargs) -> HttpResponse:
stage: UserLoginStage = self.executor.current_stage
if timedelta_from_string(stage.remember_me_offset).total_seconds() > 0:
return super().get(request, *args, **kwargs)
return super().dispatch(request)
return self.do_login(request)
def challenge_valid(self, response: UserLoginChallengeResponse) -> HttpResponse:

View file

@ -8,7 +8,7 @@ from authentik.flows.stage import StageView
class UserLogoutStageView(StageView):
"""Finalise Authentication flow by logging the user in"""
def get(self, request: HttpRequest) -> HttpResponse:
def dispatch(self, request: HttpRequest) -> HttpResponse:
"""Remove the user from the current session"""
self.logger.debug(
"Logged out",
@ -17,7 +17,3 @@ class UserLogoutStageView(StageView):
)
logout(self.request)
return self.executor.stage_ok()
def post(self, request: HttpRequest) -> HttpResponse:
"""Wrapper for post requests"""
return self.get(request)

View file

@ -51,10 +51,6 @@ class UserWriteStageView(StageView):
attrs = attrs.get(comp)
attrs[parts[-1]] = value
def post(self, request: HttpRequest) -> HttpResponse:
"""Wrapper for post requests"""
return self.get(request)
def ensure_user(self) -> tuple[Optional[User], bool]:
"""Ensure a user exists"""
user_created = False
@ -127,7 +123,7 @@ class UserWriteStageView(StageView):
if connection.source.name not in user.attributes[USER_ATTRIBUTE_SOURCES]:
user.attributes[USER_ATTRIBUTE_SOURCES].append(connection.source.name)
def get(self, request: HttpRequest) -> HttpResponse:
def dispatch(self, request: HttpRequest) -> HttpResponse:
"""Save data in the current flow to the currently pending user. If no user is pending,
a new user is created."""
if PLAN_CONTEXT_PROMPT not in self.executor.plan.context: