providers/oauth2: add missing kid header to JWT Tokens

Signed-off-by: Jens Langhammer <jens.langhammer@beryju.org>
This commit is contained in:
Jens Langhammer 2021-05-21 23:40:00 +02:00
parent a265dd54cc
commit 6600da7d98
3 changed files with 19 additions and 8 deletions

View File

@ -6,11 +6,10 @@ import time
from dataclasses import asdict, dataclass, field from dataclasses import asdict, dataclass, field
from datetime import datetime from datetime import datetime
from hashlib import sha256 from hashlib import sha256
from typing import Any, Optional, Type, Union from typing import Any, Optional, Type
from urllib.parse import urlparse from urllib.parse import urlparse
from uuid import uuid4 from uuid import uuid4
from cryptography.hazmat.primitives.asymmetric.rsa import RSAPrivateKey
from dacite import from_dict from dacite import from_dict
from django.db import models from django.db import models
from django.http import HttpRequest from django.http import HttpRequest
@ -238,7 +237,7 @@ class OAuth2Provider(Provider):
token.access_token = token.create_access_token(user, request) token.access_token = token.create_access_token(user, request)
return token return token
def get_jwt_keys(self) -> Union[RSAPrivateKey, str]: def get_jwt_key(self) -> str:
""" """
Takes a provider and returns the set of keys associated with it. Takes a provider and returns the set of keys associated with it.
Returns a list of keys. Returns a list of keys.
@ -255,7 +254,7 @@ class OAuth2Provider(Provider):
self.jwt_alg = JWTAlgorithms.HS256 self.jwt_alg = JWTAlgorithms.HS256
self.save() self.save()
else: else:
return self.rsa_key.private_key return self.rsa_key.key_data
if self.jwt_alg == JWTAlgorithms.HS256: if self.jwt_alg == JWTAlgorithms.HS256:
return self.client_secret return self.client_secret
@ -299,11 +298,14 @@ class OAuth2Provider(Provider):
def encode(self, payload: dict[str, Any]) -> str: def encode(self, payload: dict[str, Any]) -> str:
"""Represent the ID Token as a JSON Web Token (JWT).""" """Represent the ID Token as a JSON Web Token (JWT)."""
key = self.get_jwt_keys() headers = {}
if self.rsa_key:
headers["kid"] = self.rsa_key.kid
key = self.get_jwt_key()
# If the provider does not have an RSA Key assigned, it was switched to Symmetric # If the provider does not have an RSA Key assigned, it was switched to Symmetric
self.refresh_from_db() self.refresh_from_db()
# pyright: reportGeneralTypeIssues=false # pyright: reportGeneralTypeIssues=false
return encode(payload, key, algorithm=self.jwt_alg) return encode(payload, key, algorithm=self.jwt_alg, headers=headers)
class Meta: class Meta:

View File

@ -4,6 +4,7 @@ from django.urls import reverse
from django.utils.encoding import force_str from django.utils.encoding import force_str
from authentik.core.models import Application, User from authentik.core.models import Application, User
from authentik.crypto.models import CertificateKeyPair
from authentik.flows.challenge import ChallengeTypes from authentik.flows.challenge import ChallengeTypes
from authentik.flows.models import Flow from authentik.flows.models import Flow
from authentik.providers.oauth2.errors import ( from authentik.providers.oauth2.errors import (
@ -207,6 +208,7 @@ class TestAuthorize(OAuthTestCase):
client_secret=generate_client_secret(), client_secret=generate_client_secret(),
authorization_flow=flow, authorization_flow=flow,
redirect_uris="http://localhost", redirect_uris="http://localhost",
rsa_key=CertificateKeyPair.objects.first(),
) )
Application.objects.create(name="app", slug="app", provider=provider) Application.objects.create(name="app", slug="app", provider=provider)
state = generate_client_id() state = generate_client_id()

View File

@ -2,7 +2,11 @@
from django.test import TestCase from django.test import TestCase
from jwt import decode from jwt import decode
from authentik.providers.oauth2.models import OAuth2Provider, RefreshToken from authentik.providers.oauth2.models import (
JWTAlgorithms,
OAuth2Provider,
RefreshToken,
)
class OAuthTestCase(TestCase): class OAuthTestCase(TestCase):
@ -19,9 +23,12 @@ class OAuthTestCase(TestCase):
def validate_jwt(self, token: RefreshToken, provider: OAuth2Provider): def validate_jwt(self, token: RefreshToken, provider: OAuth2Provider):
"""Validate that all required fields are set""" """Validate that all required fields are set"""
key = provider.client_secret
if provider.jwt_alg == JWTAlgorithms.RS256:
key = provider.rsa_key.public_key
jwt = decode( jwt = decode(
token.access_token, token.access_token,
provider.client_secret, key,
algorithms=[provider.jwt_alg], algorithms=[provider.jwt_alg],
audience=provider.client_id, audience=provider.client_id,
) )