providers/oauth2: remember session_id from initial token (#7976)
* providers/oauth2: remember session_id original token was created with for future access/refresh tokens Signed-off-by: Jens Langhammer <jens@goauthentik.io> * providers/proxy: use hashed session as `sid` Signed-off-by: Jens Langhammer <jens@goauthentik.io> --------- Signed-off-by: Jens Langhammer <jens@goauthentik.io>
This commit is contained in:
parent
23c03d454e
commit
9a261c52d1
|
@ -0,0 +1,27 @@
|
|||
# Generated by Django 5.0 on 2023-12-22 23:20
|
||||
|
||||
from django.db import migrations, models
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
dependencies = [
|
||||
("authentik_providers_oauth2", "0016_alter_refreshtoken_token"),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.AddField(
|
||||
model_name="accesstoken",
|
||||
name="session_id",
|
||||
field=models.CharField(blank=True, default=""),
|
||||
),
|
||||
migrations.AddField(
|
||||
model_name="authorizationcode",
|
||||
name="session_id",
|
||||
field=models.CharField(blank=True, default=""),
|
||||
),
|
||||
migrations.AddField(
|
||||
model_name="refreshtoken",
|
||||
name="session_id",
|
||||
field=models.CharField(blank=True, default=""),
|
||||
),
|
||||
]
|
|
@ -296,6 +296,7 @@ class BaseGrantModel(models.Model):
|
|||
revoked = models.BooleanField(default=False)
|
||||
_scope = models.TextField(default="", verbose_name=_("Scopes"))
|
||||
auth_time = models.DateTimeField(verbose_name="Authentication time")
|
||||
session_id = models.CharField(default="", blank=True)
|
||||
|
||||
@property
|
||||
def scope(self) -> list[str]:
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
"""authentik OAuth2 Authorization views"""
|
||||
from dataclasses import dataclass, field
|
||||
from datetime import timedelta
|
||||
from hashlib import sha256
|
||||
from json import dumps
|
||||
from re import error as RegexError
|
||||
from re import fullmatch
|
||||
|
@ -282,6 +283,7 @@ class OAuthAuthorizationParams:
|
|||
expires=now + timedelta_from_string(self.provider.access_code_validity),
|
||||
scope=self.scope,
|
||||
nonce=self.nonce,
|
||||
session_id=sha256(request.session.session_key.encode("ascii")).hexdigest(),
|
||||
)
|
||||
|
||||
if self.code_challenge and self.code_challenge_method:
|
||||
|
@ -569,6 +571,7 @@ class OAuthFulfillmentStage(StageView):
|
|||
expires=access_token_expiry,
|
||||
provider=self.provider,
|
||||
auth_time=auth_event.created if auth_event else now,
|
||||
session_id=sha256(self.request.session.session_key.encode("ascii")).hexdigest(),
|
||||
)
|
||||
|
||||
id_token = IDToken.new(self.provider, token, self.request)
|
||||
|
|
|
@ -487,6 +487,7 @@ class TokenView(View):
|
|||
# Keep same scopes as previous token
|
||||
scope=self.params.authorization_code.scope,
|
||||
auth_time=self.params.authorization_code.auth_time,
|
||||
session_id=self.params.authorization_code.session_id,
|
||||
)
|
||||
access_token.id_token = IDToken.new(
|
||||
self.provider,
|
||||
|
@ -502,6 +503,7 @@ class TokenView(View):
|
|||
expires=refresh_token_expiry,
|
||||
provider=self.provider,
|
||||
auth_time=self.params.authorization_code.auth_time,
|
||||
session_id=self.params.authorization_code.session_id,
|
||||
)
|
||||
id_token = IDToken.new(
|
||||
self.provider,
|
||||
|
@ -539,6 +541,7 @@ class TokenView(View):
|
|||
# Keep same scopes as previous token
|
||||
scope=self.params.refresh_token.scope,
|
||||
auth_time=self.params.refresh_token.auth_time,
|
||||
session_id=self.params.refresh_token.session_id,
|
||||
)
|
||||
access_token.id_token = IDToken.new(
|
||||
self.provider,
|
||||
|
@ -554,6 +557,7 @@ class TokenView(View):
|
|||
expires=refresh_token_expiry,
|
||||
provider=self.provider,
|
||||
auth_time=self.params.refresh_token.auth_time,
|
||||
session_id=self.params.refresh_token.session_id,
|
||||
)
|
||||
id_token = IDToken.new(
|
||||
self.provider,
|
||||
|
|
|
@ -1,4 +1,6 @@
|
|||
"""proxy provider tasks"""
|
||||
from hashlib import sha256
|
||||
|
||||
from asgiref.sync import async_to_sync
|
||||
from channels.layers import get_channel_layer
|
||||
from django.db import DatabaseError, InternalError, ProgrammingError
|
||||
|
@ -23,6 +25,7 @@ def proxy_set_defaults():
|
|||
def proxy_on_logout(session_id: str):
|
||||
"""Update outpost instances connected to a single outpost"""
|
||||
layer = get_channel_layer()
|
||||
hashed_session_id = sha256(session_id.encode("ascii")).hexdigest()
|
||||
for outpost in Outpost.objects.filter(type=OutpostType.PROXY):
|
||||
group = OUTPOST_GROUP % {"outpost_pk": str(outpost.pk)}
|
||||
async_to_sync(layer.group_send)(
|
||||
|
@ -30,6 +33,6 @@ def proxy_on_logout(session_id: str):
|
|||
{
|
||||
"type": "event.provider.specific",
|
||||
"sub_type": "logout",
|
||||
"session_id": session_id,
|
||||
"session_id": hashed_session_id,
|
||||
},
|
||||
)
|
||||
|
|
|
@ -15,7 +15,7 @@ entries:
|
|||
# This mapping is used by the authentik proxy. It passes extra user attributes,
|
||||
# which are used for example for the HTTP-Basic Authentication mapping.
|
||||
return {
|
||||
"sid": request.http_request.session.session_key,
|
||||
"sid": token.session_id,
|
||||
"ak_proxy": {
|
||||
"user_attributes": request.user.group_attributes(request),
|
||||
"is_superuser": request.user.is_superuser,
|
||||
|
|
|
@ -36,6 +36,7 @@ func (ps *ProxyServer) handleWSMessage(ctx context.Context, args map[string]inte
|
|||
switch msg.SubType {
|
||||
case WSProviderSubTypeLogout:
|
||||
for _, p := range ps.apps {
|
||||
ps.log.WithField("provider", p.Host).Debug("Logging out")
|
||||
err := p.Logout(ctx, func(c application.Claims) bool {
|
||||
return c.Sid == msg.SessionID
|
||||
})
|
||||
|
|
Reference in New Issue