providers/oauth2: exclude at_hash claim if not set instead of being null

closes #3739

Signed-off-by: Jens Langhammer <jens.langhammer@beryju.org>
This commit is contained in:
Jens Langhammer 2022-10-07 10:10:53 +03:00
parent f60f4c6fc7
commit ce085a029d
5 changed files with 12 additions and 23 deletions

View file

@ -1,5 +1,4 @@
"""OAuth2Provider API Views""" """OAuth2Provider API Views"""
from dataclasses import asdict
from json import dumps from json import dumps
from django_filters.rest_framework import DjangoFilterBackend from django_filters.rest_framework import DjangoFilterBackend
@ -38,7 +37,7 @@ class RefreshTokenModelSerializer(ExpiringBaseGrantModelSerializer):
def get_id_token(self, instance: RefreshToken) -> str: def get_id_token(self, instance: RefreshToken) -> str:
"""Get the token's id_token as JSON String""" """Get the token's id_token as JSON String"""
return dumps(asdict(instance.id_token), indent=4) return dumps(instance.id_token.to_dict(), indent=4)
class Meta: class Meta:

View file

@ -399,10 +399,13 @@ class IDToken:
def to_dict(self) -> dict[str, Any]: def to_dict(self) -> dict[str, Any]:
"""Convert dataclass to dict, and update with keys from `claims`""" """Convert dataclass to dict, and update with keys from `claims`"""
dic = asdict(self) id_dict = asdict(self)
dic.pop("claims") # at_hash should be omitted when not set instead of retuning a null claim
dic.update(self.claims) if not self.at_hash:
return dic id_dict.pop("at_hash")
id_dict.pop("claims")
id_dict.update(self.claims)
return id_dict
class RefreshToken(SerializerModel, ExpiringModel, BaseGrantModel): class RefreshToken(SerializerModel, ExpiringModel, BaseGrantModel):
@ -432,7 +435,7 @@ class RefreshToken(SerializerModel, ExpiringModel, BaseGrantModel):
@id_token.setter @id_token.setter
def id_token(self, value: IDToken): def id_token(self, value: IDToken):
self._id_token = json.dumps(asdict(value)) self._id_token = json.dumps(value.to_dict())
def __str__(self): def __str__(self):
return f"Refresh Token for {self.provider} for user {self.user}" return f"Refresh Token for {self.provider} for user {self.user}"

View file

@ -1,7 +1,6 @@
"""Test introspect view""" """Test introspect view"""
import json import json
from base64 import b64encode from base64 import b64encode
from dataclasses import asdict
from django.urls import reverse from django.urls import reverse
@ -37,9 +36,7 @@ class TesOAuth2Introspection(OAuthTestCase):
refresh_token=generate_id(), refresh_token=generate_id(),
_scope="openid user profile", _scope="openid user profile",
_id_token=json.dumps( _id_token=json.dumps(
asdict( IDToken("foo", "bar").to_dict(),
IDToken("foo", "bar"),
)
), ),
) )
self.auth = b64encode( self.auth = b64encode(

View file

@ -1,7 +1,6 @@
"""Test revoke view""" """Test revoke view"""
import json import json
from base64 import b64encode from base64 import b64encode
from dataclasses import asdict
from django.urls import reverse from django.urls import reverse
@ -36,11 +35,7 @@ class TesOAuth2Revoke(OAuthTestCase):
access_token=generate_id(), access_token=generate_id(),
refresh_token=generate_id(), refresh_token=generate_id(),
_scope="openid user profile", _scope="openid user profile",
_id_token=json.dumps( _id_token=json.dumps(IDToken("foo", "bar").to_dict()),
asdict(
IDToken("foo", "bar"),
)
),
) )
self.auth = b64encode( self.auth = b64encode(
f"{self.provider.client_id}:{self.provider.client_secret}".encode() f"{self.provider.client_id}:{self.provider.client_secret}".encode()

View file

@ -1,6 +1,5 @@
"""Test userinfo view""" """Test userinfo view"""
import json import json
from dataclasses import asdict
from django.urls import reverse from django.urls import reverse
@ -39,11 +38,7 @@ class TestUserinfo(OAuthTestCase):
access_token=generate_id(), access_token=generate_id(),
refresh_token=generate_id(), refresh_token=generate_id(),
_scope="openid user profile", _scope="openid user profile",
_id_token=json.dumps( _id_token=json.dumps(IDToken("foo", "bar").to_dict()),
asdict(
IDToken("foo", "bar"),
)
),
) )
def test_userinfo_normal(self): def test_userinfo_normal(self):