providers/oauth2: add c_hash field
This commit is contained in:
parent
ee2e737782
commit
67ca83c228
|
@ -9,6 +9,7 @@ from typing import Any, Dict, List, Optional, Type
|
||||||
from urllib.parse import urlparse
|
from urllib.parse import urlparse
|
||||||
from uuid import uuid4
|
from uuid import uuid4
|
||||||
|
|
||||||
|
from dacite import from_dict
|
||||||
from django.conf import settings
|
from django.conf import settings
|
||||||
from django.db import models
|
from django.db import models
|
||||||
from django.forms import ModelForm
|
from django.forms import ModelForm
|
||||||
|
@ -386,6 +387,18 @@ class AuthorizationCode(ExpiringModel, BaseGrantModel):
|
||||||
max_length=255, null=True, verbose_name=_("Code Challenge Method")
|
max_length=255, null=True, verbose_name=_("Code Challenge Method")
|
||||||
)
|
)
|
||||||
|
|
||||||
|
@property
|
||||||
|
def c_hash(self):
|
||||||
|
"""https://openid.net/specs/openid-connect-core-1_0.html#IDToken"""
|
||||||
|
hashed_code = sha256(self.code.encode("ascii")).hexdigest().encode("ascii")
|
||||||
|
return (
|
||||||
|
base64.urlsafe_b64encode(
|
||||||
|
binascii.unhexlify(hashed_code[: len(hashed_code) // 2])
|
||||||
|
)
|
||||||
|
.rstrip(b"=")
|
||||||
|
.decode("ascii")
|
||||||
|
)
|
||||||
|
|
||||||
class Meta:
|
class Meta:
|
||||||
verbose_name = _("Authorization Code")
|
verbose_name = _("Authorization Code")
|
||||||
verbose_name_plural = _("Authorization Codes")
|
verbose_name_plural = _("Authorization Codes")
|
||||||
|
@ -413,19 +426,13 @@ class IDToken:
|
||||||
auth_time: Optional[int] = None
|
auth_time: Optional[int] = None
|
||||||
acr: Optional[str] = ACR_AUTHENTIK_DEFAULT
|
acr: Optional[str] = ACR_AUTHENTIK_DEFAULT
|
||||||
|
|
||||||
|
c_hash: Optional[str] = None
|
||||||
|
|
||||||
nonce: Optional[str] = None
|
nonce: Optional[str] = None
|
||||||
at_hash: Optional[str] = None
|
at_hash: Optional[str] = None
|
||||||
|
|
||||||
claims: Dict[str, Any] = field(default_factory=dict)
|
claims: Dict[str, Any] = field(default_factory=dict)
|
||||||
|
|
||||||
@staticmethod
|
|
||||||
def from_dict(data: Dict[str, Any]) -> "IDToken":
|
|
||||||
"""Reconstruct ID Token from json dictionary"""
|
|
||||||
token = IDToken()
|
|
||||||
for key, value in data.items():
|
|
||||||
setattr(token, key, value)
|
|
||||||
return token
|
|
||||||
|
|
||||||
def to_dict(self) -> Dict[str, Any]:
|
def to_dict(self) -> Dict[str, Any]:
|
||||||
"""Convert dataclass to dict, and update with keys from `claims`"""
|
"""Convert dataclass to dict, and update with keys from `claims`"""
|
||||||
dic = asdict(self)
|
dic = asdict(self)
|
||||||
|
@ -454,7 +461,7 @@ class RefreshToken(ExpiringModel, BaseGrantModel):
|
||||||
"""Load ID Token from json"""
|
"""Load ID Token from json"""
|
||||||
if self._id_token:
|
if self._id_token:
|
||||||
raw_token = json.loads(self._id_token)
|
raw_token = json.loads(self._id_token)
|
||||||
return IDToken.from_dict(raw_token)
|
return from_dict(IDToken, raw_token)
|
||||||
return IDToken()
|
return IDToken()
|
||||||
|
|
||||||
@id_token.setter
|
@id_token.setter
|
||||||
|
|
|
@ -316,6 +316,12 @@ class OAuthFulfillmentStage(StageView):
|
||||||
if "access_token" in query_fragment:
|
if "access_token" in query_fragment:
|
||||||
id_token.at_hash = token.at_hash
|
id_token.at_hash = token.at_hash
|
||||||
|
|
||||||
|
if self.params.response_type in [
|
||||||
|
ResponseTypes.CODE_ID_TOKEN,
|
||||||
|
ResponseTypes.CODE_ID_TOKEN_TOKEN,
|
||||||
|
]:
|
||||||
|
id_token.c_hash = code.c_hash
|
||||||
|
|
||||||
# Check if response_type must include id_token in the response.
|
# Check if response_type must include id_token in the response.
|
||||||
if self.params.response_type in [
|
if self.params.response_type in [
|
||||||
ResponseTypes.ID_TOKEN,
|
ResponseTypes.ID_TOKEN,
|
||||||
|
|
Reference in New Issue