providers/oauth2: add missing kid header to JWT Tokens
Signed-off-by: Jens Langhammer <jens.langhammer@beryju.org>
This commit is contained in:
parent
a265dd54cc
commit
6600da7d98
|
@ -6,11 +6,10 @@ import time
|
||||||
from dataclasses import asdict, dataclass, field
|
from dataclasses import asdict, dataclass, field
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from hashlib import sha256
|
from hashlib import sha256
|
||||||
from typing import Any, Optional, Type, Union
|
from typing import Any, Optional, Type
|
||||||
from urllib.parse import urlparse
|
from urllib.parse import urlparse
|
||||||
from uuid import uuid4
|
from uuid import uuid4
|
||||||
|
|
||||||
from cryptography.hazmat.primitives.asymmetric.rsa import RSAPrivateKey
|
|
||||||
from dacite import from_dict
|
from dacite import from_dict
|
||||||
from django.db import models
|
from django.db import models
|
||||||
from django.http import HttpRequest
|
from django.http import HttpRequest
|
||||||
|
@ -238,7 +237,7 @@ class OAuth2Provider(Provider):
|
||||||
token.access_token = token.create_access_token(user, request)
|
token.access_token = token.create_access_token(user, request)
|
||||||
return token
|
return token
|
||||||
|
|
||||||
def get_jwt_keys(self) -> Union[RSAPrivateKey, str]:
|
def get_jwt_key(self) -> str:
|
||||||
"""
|
"""
|
||||||
Takes a provider and returns the set of keys associated with it.
|
Takes a provider and returns the set of keys associated with it.
|
||||||
Returns a list of keys.
|
Returns a list of keys.
|
||||||
|
@ -255,7 +254,7 @@ class OAuth2Provider(Provider):
|
||||||
self.jwt_alg = JWTAlgorithms.HS256
|
self.jwt_alg = JWTAlgorithms.HS256
|
||||||
self.save()
|
self.save()
|
||||||
else:
|
else:
|
||||||
return self.rsa_key.private_key
|
return self.rsa_key.key_data
|
||||||
|
|
||||||
if self.jwt_alg == JWTAlgorithms.HS256:
|
if self.jwt_alg == JWTAlgorithms.HS256:
|
||||||
return self.client_secret
|
return self.client_secret
|
||||||
|
@ -299,11 +298,14 @@ class OAuth2Provider(Provider):
|
||||||
|
|
||||||
def encode(self, payload: dict[str, Any]) -> str:
|
def encode(self, payload: dict[str, Any]) -> str:
|
||||||
"""Represent the ID Token as a JSON Web Token (JWT)."""
|
"""Represent the ID Token as a JSON Web Token (JWT)."""
|
||||||
key = self.get_jwt_keys()
|
headers = {}
|
||||||
|
if self.rsa_key:
|
||||||
|
headers["kid"] = self.rsa_key.kid
|
||||||
|
key = self.get_jwt_key()
|
||||||
# If the provider does not have an RSA Key assigned, it was switched to Symmetric
|
# If the provider does not have an RSA Key assigned, it was switched to Symmetric
|
||||||
self.refresh_from_db()
|
self.refresh_from_db()
|
||||||
# pyright: reportGeneralTypeIssues=false
|
# pyright: reportGeneralTypeIssues=false
|
||||||
return encode(payload, key, algorithm=self.jwt_alg)
|
return encode(payload, key, algorithm=self.jwt_alg, headers=headers)
|
||||||
|
|
||||||
class Meta:
|
class Meta:
|
||||||
|
|
||||||
|
|
|
@ -4,6 +4,7 @@ from django.urls import reverse
|
||||||
from django.utils.encoding import force_str
|
from django.utils.encoding import force_str
|
||||||
|
|
||||||
from authentik.core.models import Application, User
|
from authentik.core.models import Application, User
|
||||||
|
from authentik.crypto.models import CertificateKeyPair
|
||||||
from authentik.flows.challenge import ChallengeTypes
|
from authentik.flows.challenge import ChallengeTypes
|
||||||
from authentik.flows.models import Flow
|
from authentik.flows.models import Flow
|
||||||
from authentik.providers.oauth2.errors import (
|
from authentik.providers.oauth2.errors import (
|
||||||
|
@ -207,6 +208,7 @@ class TestAuthorize(OAuthTestCase):
|
||||||
client_secret=generate_client_secret(),
|
client_secret=generate_client_secret(),
|
||||||
authorization_flow=flow,
|
authorization_flow=flow,
|
||||||
redirect_uris="http://localhost",
|
redirect_uris="http://localhost",
|
||||||
|
rsa_key=CertificateKeyPair.objects.first(),
|
||||||
)
|
)
|
||||||
Application.objects.create(name="app", slug="app", provider=provider)
|
Application.objects.create(name="app", slug="app", provider=provider)
|
||||||
state = generate_client_id()
|
state = generate_client_id()
|
||||||
|
|
|
@ -2,7 +2,11 @@
|
||||||
from django.test import TestCase
|
from django.test import TestCase
|
||||||
from jwt import decode
|
from jwt import decode
|
||||||
|
|
||||||
from authentik.providers.oauth2.models import OAuth2Provider, RefreshToken
|
from authentik.providers.oauth2.models import (
|
||||||
|
JWTAlgorithms,
|
||||||
|
OAuth2Provider,
|
||||||
|
RefreshToken,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
class OAuthTestCase(TestCase):
|
class OAuthTestCase(TestCase):
|
||||||
|
@ -19,9 +23,12 @@ class OAuthTestCase(TestCase):
|
||||||
|
|
||||||
def validate_jwt(self, token: RefreshToken, provider: OAuth2Provider):
|
def validate_jwt(self, token: RefreshToken, provider: OAuth2Provider):
|
||||||
"""Validate that all required fields are set"""
|
"""Validate that all required fields are set"""
|
||||||
|
key = provider.client_secret
|
||||||
|
if provider.jwt_alg == JWTAlgorithms.RS256:
|
||||||
|
key = provider.rsa_key.public_key
|
||||||
jwt = decode(
|
jwt = decode(
|
||||||
token.access_token,
|
token.access_token,
|
||||||
provider.client_secret,
|
key,
|
||||||
algorithms=[provider.jwt_alg],
|
algorithms=[provider.jwt_alg],
|
||||||
audience=provider.client_id,
|
audience=provider.client_id,
|
||||||
)
|
)
|
||||||
|
|
Reference in New Issue