diff --git a/authentik/providers/oauth2/tests/test_views_token.py b/authentik/providers/oauth2/tests/test_views_token.py index 5e4dcae64..d021b1a5e 100644 --- a/authentik/providers/oauth2/tests/test_views_token.py +++ b/authentik/providers/oauth2/tests/test_views_token.py @@ -153,10 +153,61 @@ class TestViewsToken(TestCase): "redirect_uri": "http://local.invalid", }, HTTP_AUTHORIZATION=f"Basic {header}", + HTTP_ORIGIN="http://local.invalid", ) new_token: RefreshToken = ( RefreshToken.objects.filter(user=user).exclude(pk=token.pk).first() ) + self.assertEqual(response["Access-Control-Allow-Credentials"], "true") + self.assertEqual( + response["Access-Control-Allow-Origin"], "http://local.invalid" + ) + self.assertJSONEqual( + force_str(response.content), + { + "access_token": new_token.access_token, + "refresh_token": new_token.refresh_token, + "token_type": "bearer", + "expires_in": 600, + "id_token": provider.encode( + new_token.id_token.to_dict(), + ), + }, + ) + + def test_refresh_token_view_invalid_origin(self): + """test request param""" + provider = OAuth2Provider.objects.create( + name="test", + client_id=generate_client_id(), + client_secret=generate_client_secret(), + authorization_flow=Flow.objects.first(), + redirect_uris="http://local.invalid", + ) + header = b64encode( + f"{provider.client_id}:{provider.client_secret}".encode() + ).decode() + user = User.objects.get(username="akadmin") + token: RefreshToken = RefreshToken.objects.create( + provider=provider, + user=user, + refresh_token=generate_client_id(), + ) + response = self.client.post( + reverse("authentik_providers_oauth2:token"), + data={ + "grant_type": GRANT_TYPE_REFRESH_TOKEN, + "refresh_token": token.refresh_token, + "redirect_uri": "http://local.invalid", + }, + HTTP_AUTHORIZATION=f"Basic {header}", + HTTP_ORIGIN="http://another.invalid", + ) + new_token: RefreshToken = ( + RefreshToken.objects.filter(user=user).exclude(pk=token.pk).first() + ) + self.assertNotIn("Access-Control-Allow-Credentials", response) + self.assertNotIn("Access-Control-Allow-Origin", response) self.assertJSONEqual( force_str(response.content), { diff --git a/authentik/providers/oauth2/views/token.py b/authentik/providers/oauth2/views/token.py index 5360c0f9e..dc71beec9 100644 --- a/authentik/providers/oauth2/views/token.py +++ b/authentik/providers/oauth2/views/token.py @@ -19,7 +19,11 @@ from authentik.providers.oauth2.models import ( OAuth2Provider, RefreshToken, ) -from authentik.providers.oauth2.utils import TokenResponse, extract_client_auth +from authentik.providers.oauth2.utils import ( + TokenResponse, + cors_allow, + extract_client_auth, +) LOGGER = get_logger() @@ -154,7 +158,18 @@ class TokenParams: class TokenView(View): """Generate tokens for clients""" - params: TokenParams + params: Optional[TokenParams] = None + + def dispatch(self, request: HttpRequest, *args: Any, **kwargs: Any) -> HttpResponse: + response = super().dispatch(request, *args, **kwargs) + allowed_origins = [] + if self.params: + allowed_origins = self.params.provider.redirect_uris.split("\n") + cors_allow(self.request, response, *allowed_origins) + return response + + def options(self, request: HttpRequest) -> HttpResponse: + return TokenResponse({}) def post(self, request: HttpRequest) -> HttpResponse: """Generate tokens for clients"""