providers/oauth2: fix TokenView not having CORS headers set even with proper Origin

and added tests. closes #771

Signed-off-by: Jens Langhammer <jens.langhammer@beryju.org>
This commit is contained in:
Jens Langhammer 2021-04-22 23:48:28 +02:00
parent 392d9bb10b
commit 3282b34431
2 changed files with 68 additions and 2 deletions

View File

@ -153,10 +153,61 @@ class TestViewsToken(TestCase):
"redirect_uri": "http://local.invalid", "redirect_uri": "http://local.invalid",
}, },
HTTP_AUTHORIZATION=f"Basic {header}", HTTP_AUTHORIZATION=f"Basic {header}",
HTTP_ORIGIN="http://local.invalid",
) )
new_token: RefreshToken = ( new_token: RefreshToken = (
RefreshToken.objects.filter(user=user).exclude(pk=token.pk).first() RefreshToken.objects.filter(user=user).exclude(pk=token.pk).first()
) )
self.assertEqual(response["Access-Control-Allow-Credentials"], "true")
self.assertEqual(
response["Access-Control-Allow-Origin"], "http://local.invalid"
)
self.assertJSONEqual(
force_str(response.content),
{
"access_token": new_token.access_token,
"refresh_token": new_token.refresh_token,
"token_type": "bearer",
"expires_in": 600,
"id_token": provider.encode(
new_token.id_token.to_dict(),
),
},
)
def test_refresh_token_view_invalid_origin(self):
"""test request param"""
provider = OAuth2Provider.objects.create(
name="test",
client_id=generate_client_id(),
client_secret=generate_client_secret(),
authorization_flow=Flow.objects.first(),
redirect_uris="http://local.invalid",
)
header = b64encode(
f"{provider.client_id}:{provider.client_secret}".encode()
).decode()
user = User.objects.get(username="akadmin")
token: RefreshToken = RefreshToken.objects.create(
provider=provider,
user=user,
refresh_token=generate_client_id(),
)
response = self.client.post(
reverse("authentik_providers_oauth2:token"),
data={
"grant_type": GRANT_TYPE_REFRESH_TOKEN,
"refresh_token": token.refresh_token,
"redirect_uri": "http://local.invalid",
},
HTTP_AUTHORIZATION=f"Basic {header}",
HTTP_ORIGIN="http://another.invalid",
)
new_token: RefreshToken = (
RefreshToken.objects.filter(user=user).exclude(pk=token.pk).first()
)
self.assertNotIn("Access-Control-Allow-Credentials", response)
self.assertNotIn("Access-Control-Allow-Origin", response)
self.assertJSONEqual( self.assertJSONEqual(
force_str(response.content), force_str(response.content),
{ {

View File

@ -19,7 +19,11 @@ from authentik.providers.oauth2.models import (
OAuth2Provider, OAuth2Provider,
RefreshToken, RefreshToken,
) )
from authentik.providers.oauth2.utils import TokenResponse, extract_client_auth from authentik.providers.oauth2.utils import (
TokenResponse,
cors_allow,
extract_client_auth,
)
LOGGER = get_logger() LOGGER = get_logger()
@ -154,7 +158,18 @@ class TokenParams:
class TokenView(View): class TokenView(View):
"""Generate tokens for clients""" """Generate tokens for clients"""
params: TokenParams params: Optional[TokenParams] = None
def dispatch(self, request: HttpRequest, *args: Any, **kwargs: Any) -> HttpResponse:
response = super().dispatch(request, *args, **kwargs)
allowed_origins = []
if self.params:
allowed_origins = self.params.provider.redirect_uris.split("\n")
cors_allow(self.request, response, *allowed_origins)
return response
def options(self, request: HttpRequest) -> HttpResponse:
return TokenResponse({})
def post(self, request: HttpRequest) -> HttpResponse: def post(self, request: HttpRequest) -> HttpResponse:
"""Generate tokens for clients""" """Generate tokens for clients"""