providers/oauth2: fix TokenView not having CORS headers set even with proper Origin
and added tests. closes #771 Signed-off-by: Jens Langhammer <jens.langhammer@beryju.org>
This commit is contained in:
parent
392d9bb10b
commit
3282b34431
|
@ -153,10 +153,61 @@ class TestViewsToken(TestCase):
|
||||||
"redirect_uri": "http://local.invalid",
|
"redirect_uri": "http://local.invalid",
|
||||||
},
|
},
|
||||||
HTTP_AUTHORIZATION=f"Basic {header}",
|
HTTP_AUTHORIZATION=f"Basic {header}",
|
||||||
|
HTTP_ORIGIN="http://local.invalid",
|
||||||
)
|
)
|
||||||
new_token: RefreshToken = (
|
new_token: RefreshToken = (
|
||||||
RefreshToken.objects.filter(user=user).exclude(pk=token.pk).first()
|
RefreshToken.objects.filter(user=user).exclude(pk=token.pk).first()
|
||||||
)
|
)
|
||||||
|
self.assertEqual(response["Access-Control-Allow-Credentials"], "true")
|
||||||
|
self.assertEqual(
|
||||||
|
response["Access-Control-Allow-Origin"], "http://local.invalid"
|
||||||
|
)
|
||||||
|
self.assertJSONEqual(
|
||||||
|
force_str(response.content),
|
||||||
|
{
|
||||||
|
"access_token": new_token.access_token,
|
||||||
|
"refresh_token": new_token.refresh_token,
|
||||||
|
"token_type": "bearer",
|
||||||
|
"expires_in": 600,
|
||||||
|
"id_token": provider.encode(
|
||||||
|
new_token.id_token.to_dict(),
|
||||||
|
),
|
||||||
|
},
|
||||||
|
)
|
||||||
|
|
||||||
|
def test_refresh_token_view_invalid_origin(self):
|
||||||
|
"""test request param"""
|
||||||
|
provider = OAuth2Provider.objects.create(
|
||||||
|
name="test",
|
||||||
|
client_id=generate_client_id(),
|
||||||
|
client_secret=generate_client_secret(),
|
||||||
|
authorization_flow=Flow.objects.first(),
|
||||||
|
redirect_uris="http://local.invalid",
|
||||||
|
)
|
||||||
|
header = b64encode(
|
||||||
|
f"{provider.client_id}:{provider.client_secret}".encode()
|
||||||
|
).decode()
|
||||||
|
user = User.objects.get(username="akadmin")
|
||||||
|
token: RefreshToken = RefreshToken.objects.create(
|
||||||
|
provider=provider,
|
||||||
|
user=user,
|
||||||
|
refresh_token=generate_client_id(),
|
||||||
|
)
|
||||||
|
response = self.client.post(
|
||||||
|
reverse("authentik_providers_oauth2:token"),
|
||||||
|
data={
|
||||||
|
"grant_type": GRANT_TYPE_REFRESH_TOKEN,
|
||||||
|
"refresh_token": token.refresh_token,
|
||||||
|
"redirect_uri": "http://local.invalid",
|
||||||
|
},
|
||||||
|
HTTP_AUTHORIZATION=f"Basic {header}",
|
||||||
|
HTTP_ORIGIN="http://another.invalid",
|
||||||
|
)
|
||||||
|
new_token: RefreshToken = (
|
||||||
|
RefreshToken.objects.filter(user=user).exclude(pk=token.pk).first()
|
||||||
|
)
|
||||||
|
self.assertNotIn("Access-Control-Allow-Credentials", response)
|
||||||
|
self.assertNotIn("Access-Control-Allow-Origin", response)
|
||||||
self.assertJSONEqual(
|
self.assertJSONEqual(
|
||||||
force_str(response.content),
|
force_str(response.content),
|
||||||
{
|
{
|
||||||
|
|
|
@ -19,7 +19,11 @@ from authentik.providers.oauth2.models import (
|
||||||
OAuth2Provider,
|
OAuth2Provider,
|
||||||
RefreshToken,
|
RefreshToken,
|
||||||
)
|
)
|
||||||
from authentik.providers.oauth2.utils import TokenResponse, extract_client_auth
|
from authentik.providers.oauth2.utils import (
|
||||||
|
TokenResponse,
|
||||||
|
cors_allow,
|
||||||
|
extract_client_auth,
|
||||||
|
)
|
||||||
|
|
||||||
LOGGER = get_logger()
|
LOGGER = get_logger()
|
||||||
|
|
||||||
|
@ -154,7 +158,18 @@ class TokenParams:
|
||||||
class TokenView(View):
|
class TokenView(View):
|
||||||
"""Generate tokens for clients"""
|
"""Generate tokens for clients"""
|
||||||
|
|
||||||
params: TokenParams
|
params: Optional[TokenParams] = None
|
||||||
|
|
||||||
|
def dispatch(self, request: HttpRequest, *args: Any, **kwargs: Any) -> HttpResponse:
|
||||||
|
response = super().dispatch(request, *args, **kwargs)
|
||||||
|
allowed_origins = []
|
||||||
|
if self.params:
|
||||||
|
allowed_origins = self.params.provider.redirect_uris.split("\n")
|
||||||
|
cors_allow(self.request, response, *allowed_origins)
|
||||||
|
return response
|
||||||
|
|
||||||
|
def options(self, request: HttpRequest) -> HttpResponse:
|
||||||
|
return TokenResponse({})
|
||||||
|
|
||||||
def post(self, request: HttpRequest) -> HttpResponse:
|
def post(self, request: HttpRequest) -> HttpResponse:
|
||||||
"""Generate tokens for clients"""
|
"""Generate tokens for clients"""
|
||||||
|
|
Reference in New Issue