diff --git a/authentik/stages/authenticator_sms/models.py b/authentik/stages/authenticator_sms/models.py index 7d7fa08a3..222aecb90 100644 --- a/authentik/stages/authenticator_sms/models.py +++ b/authentik/stages/authenticator_sms/models.py @@ -76,13 +76,17 @@ class AuthenticatorSMSStage(ConfigurableStage, Stage): return self.send_generic(token, device) raise ValueError(f"invalid provider {self.provider}") + def get_message(self, token: str) -> str: + """Get SMS message""" + return _("Use this code to authenticate in authentik: %(token)s" % {"token": token}) + def send_twilio(self, token: str, device: "SMSDevice"): """send sms via twilio provider""" client = Client(self.account_sid, self.auth) try: message = client.messages.create( - to=device.phone_number, from_=self.from_number, body=token + to=device.phone_number, from_=self.from_number, body=self.get_message(token) ) LOGGER.debug("Sent SMS", to=device, message=message.sid) except TwilioRestException as exc: @@ -95,6 +99,7 @@ class AuthenticatorSMSStage(ConfigurableStage, Stage): "From": self.from_number, "To": device.phone_number, "Body": token, + "Message": self.get_message(token), } if self.mapping: diff --git a/authentik/stages/authenticator_sms/stage.py b/authentik/stages/authenticator_sms/stage.py index 0192a3d8f..c4e0e16b7 100644 --- a/authentik/stages/authenticator_sms/stage.py +++ b/authentik/stages/authenticator_sms/stage.py @@ -12,6 +12,7 @@ from authentik.flows.challenge import ( Challenge, ChallengeResponse, ChallengeTypes, + ErrorDetailSerializer, WithUserInfoChallenge, ) from authentik.flows.stage import ChallengeStageView @@ -46,15 +47,9 @@ class AuthenticatorSMSChallengeResponse(ChallengeResponse): def validate(self, attrs: dict) -> dict: """Check""" - stage: AuthenticatorSMSStage = self.device.stage if "code" not in attrs: self.device.phone_number = attrs["phone_number"] - hashed_number = hash_phone_number(self.device.phone_number) - query = Q(phone_number=hashed_number) | Q(phone_number=self.device.phone_number) - if SMSDevice.objects.filter(query, stage=self.stage.executor.current_stage.pk).exists(): - raise ValidationError(_("Invalid phone number")) - # No code yet, but we have a phone number, so send a verification message - stage.send(self.device.token, self.device) + self.stage.validate_and_send(attrs["phone_number"]) return super().validate(attrs) if not self.device.verify_token(str(attrs["code"])): raise ValidationError(_("Code does not match")) @@ -67,6 +62,17 @@ class AuthenticatorSMSStageView(ChallengeStageView): response_class = AuthenticatorSMSChallengeResponse + def validate_and_send(self, phone_number: str): + """Validate phone number and send message""" + stage: AuthenticatorSMSStage = self.executor.current_stage + hashed_number = hash_phone_number(phone_number) + query = Q(phone_number=hashed_number) | Q(phone_number=phone_number) + if SMSDevice.objects.filter(query, stage=stage.pk).exists(): + raise ValidationError(_("Invalid phone number")) + # No code yet, but we have a phone number, so send a verification message + device: SMSDevice = self.request.session[SESSION_KEY_SMS_DEVICE] + stage.send(device.token, device) + def _has_phone_number(self) -> Optional[str]: context = self.executor.plan.context if "phone" in context.get(PLAN_CONTEXT_PROMPT, {}): @@ -96,19 +102,21 @@ class AuthenticatorSMSStageView(ChallengeStageView): def get(self, request: HttpRequest, *args, **kwargs) -> HttpResponse: user = self.get_pending_user() - # Currently, this stage only supports one device per user. If the user already - # has a device, just skip to the next stage - if SMSDevice.objects.filter(user=user).exists(): - return self.executor.stage_ok() - stage: AuthenticatorSMSStage = self.executor.current_stage if SESSION_KEY_SMS_DEVICE not in self.request.session: device = SMSDevice(user=user, confirmed=False, stage=stage, name="SMS Device") device.generate_token(commit=False) + self.request.session[SESSION_KEY_SMS_DEVICE] = device if phone_number := self._has_phone_number(): device.phone_number = phone_number - self.request.session[SESSION_KEY_SMS_DEVICE] = device + try: + self.validate_and_send(phone_number) + except ValidationError as exc: + response = AuthenticatorSMSChallengeResponse() + response._errors.setdefault("phone_number", []) + response._errors["phone_number"].append(ErrorDetailSerializer(exc.detail)) + return self.challenge_invalid(response) return super().get(request, *args, **kwargs) def challenge_valid(self, response: ChallengeResponse) -> HttpResponse: diff --git a/authentik/stages/authenticator_sms/tests.py b/authentik/stages/authenticator_sms/tests.py index d985fea08..79bfc4bf6 100644 --- a/authentik/stages/authenticator_sms/tests.py +++ b/authentik/stages/authenticator_sms/tests.py @@ -80,6 +80,39 @@ class AuthenticatorSMSStageTests(FlowTestCase): phone_number_required=False, ) + def test_stage_context_data(self): + """test stage context data""" + self.client.get( + reverse("authentik_flows:configure", kwargs={"stage_uuid": self.stage.stage_uuid}), + ) + sms_send_mock = MagicMock() + with ( + patch( + ( + "authentik.stages.authenticator_sms.stage." + "AuthenticatorSMSStageView._has_phone_number" + ), + MagicMock( + return_value="1234", + ), + ), + patch( + "authentik.stages.authenticator_sms.models.AuthenticatorSMSStage.send", + sms_send_mock, + ), + ): + response = self.client.get( + reverse("authentik_api:flow-executor", kwargs={"flow_slug": self.flow.slug}), + ) + sms_send_mock.assert_called_once() + self.assertStageResponse( + response, + self.flow, + self.user, + component="ak-stage-authenticator-sms", + phone_number_required=False, + ) + def test_stage_submit_full(self): """test stage (submit)""" self.client.get(