From 6600da7d98163b2ce8a8d41ec4be8449a76b9161 Mon Sep 17 00:00:00 2001 From: Jens Langhammer Date: Fri, 21 May 2021 23:40:00 +0200 Subject: [PATCH] providers/oauth2: add missing kid header to JWT Tokens Signed-off-by: Jens Langhammer --- authentik/providers/oauth2/models.py | 14 ++++++++------ authentik/providers/oauth2/tests/test_authorize.py | 2 ++ authentik/providers/oauth2/tests/utils.py | 11 +++++++++-- 3 files changed, 19 insertions(+), 8 deletions(-) diff --git a/authentik/providers/oauth2/models.py b/authentik/providers/oauth2/models.py index e40c6f643..fc87a1fb8 100644 --- a/authentik/providers/oauth2/models.py +++ b/authentik/providers/oauth2/models.py @@ -6,11 +6,10 @@ import time from dataclasses import asdict, dataclass, field from datetime import datetime from hashlib import sha256 -from typing import Any, Optional, Type, Union +from typing import Any, Optional, Type from urllib.parse import urlparse from uuid import uuid4 -from cryptography.hazmat.primitives.asymmetric.rsa import RSAPrivateKey from dacite import from_dict from django.db import models from django.http import HttpRequest @@ -238,7 +237,7 @@ class OAuth2Provider(Provider): token.access_token = token.create_access_token(user, request) return token - def get_jwt_keys(self) -> Union[RSAPrivateKey, str]: + def get_jwt_key(self) -> str: """ Takes a provider and returns the set of keys associated with it. Returns a list of keys. @@ -255,7 +254,7 @@ class OAuth2Provider(Provider): self.jwt_alg = JWTAlgorithms.HS256 self.save() else: - return self.rsa_key.private_key + return self.rsa_key.key_data if self.jwt_alg == JWTAlgorithms.HS256: return self.client_secret @@ -299,11 +298,14 @@ class OAuth2Provider(Provider): def encode(self, payload: dict[str, Any]) -> str: """Represent the ID Token as a JSON Web Token (JWT).""" - key = self.get_jwt_keys() + headers = {} + if self.rsa_key: + headers["kid"] = self.rsa_key.kid + key = self.get_jwt_key() # If the provider does not have an RSA Key assigned, it was switched to Symmetric self.refresh_from_db() # pyright: reportGeneralTypeIssues=false - return encode(payload, key, algorithm=self.jwt_alg) + return encode(payload, key, algorithm=self.jwt_alg, headers=headers) class Meta: diff --git a/authentik/providers/oauth2/tests/test_authorize.py b/authentik/providers/oauth2/tests/test_authorize.py index da169ed4e..2c0d03955 100644 --- a/authentik/providers/oauth2/tests/test_authorize.py +++ b/authentik/providers/oauth2/tests/test_authorize.py @@ -4,6 +4,7 @@ from django.urls import reverse from django.utils.encoding import force_str from authentik.core.models import Application, User +from authentik.crypto.models import CertificateKeyPair from authentik.flows.challenge import ChallengeTypes from authentik.flows.models import Flow from authentik.providers.oauth2.errors import ( @@ -207,6 +208,7 @@ class TestAuthorize(OAuthTestCase): client_secret=generate_client_secret(), authorization_flow=flow, redirect_uris="http://localhost", + rsa_key=CertificateKeyPair.objects.first(), ) Application.objects.create(name="app", slug="app", provider=provider) state = generate_client_id() diff --git a/authentik/providers/oauth2/tests/utils.py b/authentik/providers/oauth2/tests/utils.py index 0f1264ebf..adf9c7e7b 100644 --- a/authentik/providers/oauth2/tests/utils.py +++ b/authentik/providers/oauth2/tests/utils.py @@ -2,7 +2,11 @@ from django.test import TestCase from jwt import decode -from authentik.providers.oauth2.models import OAuth2Provider, RefreshToken +from authentik.providers.oauth2.models import ( + JWTAlgorithms, + OAuth2Provider, + RefreshToken, +) class OAuthTestCase(TestCase): @@ -19,9 +23,12 @@ class OAuthTestCase(TestCase): def validate_jwt(self, token: RefreshToken, provider: OAuth2Provider): """Validate that all required fields are set""" + key = provider.client_secret + if provider.jwt_alg == JWTAlgorithms.RS256: + key = provider.rsa_key.public_key jwt = decode( token.access_token, - provider.client_secret, + key, algorithms=[provider.jwt_alg], audience=provider.client_id, )