providers/oauth2: add proper support for non-http schemes as redirect URIs

closes #772

Signed-off-by: Jens Langhammer <jens.langhammer@beryju.org>
This commit is contained in:
Jens Langhammer 2021-04-23 16:34:52 +02:00
parent 5112ef9331
commit d616bdd5d6
3 changed files with 30 additions and 8 deletions

View File

@ -166,7 +166,7 @@ class TestViewsAuthorize(TestCase):
name="test", name="test",
client_id="test", client_id="test",
authorization_flow=flow, authorization_flow=flow,
redirect_uris="http://localhost", redirect_uris="foo://localhost",
) )
Application.objects.create(name="app", slug="app", provider=provider) Application.objects.create(name="app", slug="app", provider=provider)
state = generate_client_id() state = generate_client_id()
@ -179,7 +179,7 @@ class TestViewsAuthorize(TestCase):
"response_type": "code", "response_type": "code",
"client_id": "test", "client_id": "test",
"state": state, "state": state,
"redirect_uri": "http://localhost", "redirect_uri": "foo://localhost",
}, },
) )
response = self.client.get( response = self.client.get(
@ -190,7 +190,7 @@ class TestViewsAuthorize(TestCase):
force_str(response.content), force_str(response.content),
{ {
"type": ChallengeTypes.REDIRECT.value, "type": ChallengeTypes.REDIRECT.value,
"to": f"http://localhost?code={code.code}&state={state}", "to": f"foo://localhost?code={code.code}&state={state}",
}, },
) )

View File

@ -2,10 +2,11 @@
import re import re
from base64 import b64decode from base64 import b64decode
from binascii import Error from binascii import Error
from typing import Optional from typing import Any, Optional
from urllib.parse import urlparse from urllib.parse import urlparse
from django.http import HttpRequest, HttpResponse, JsonResponse from django.http import HttpRequest, HttpResponse, JsonResponse
from django.http.response import HttpResponseRedirect
from django.utils.cache import patch_vary_headers from django.utils.cache import patch_vary_headers
from structlog.stdlib import get_logger from structlog.stdlib import get_logger
@ -161,3 +162,18 @@ def protected_resource_view(scopes: list[str]):
return view_wrapper return view_wrapper
return wrapper return wrapper
class HttpResponseRedirectScheme(HttpResponseRedirect):
"""HTTP Response to redirect, can be to a non-http scheme"""
def __init__(
self,
redirect_to: str,
*args: Any,
allowed_schemes: Optional[list[str]] = None,
**kwargs: Any,
) -> None:
self.allowed_schemes = allowed_schemes or ["http", "https", "ftp"]
# pyright: reportGeneralTypeIssues=false
super().__init__(redirect_to, *args, **kwargs)

View File

@ -2,12 +2,12 @@
from dataclasses import dataclass, field from dataclasses import dataclass, field
from datetime import timedelta from datetime import timedelta
from typing import Optional from typing import Optional
from urllib.parse import parse_qs, urlencode, urlsplit, urlunsplit from urllib.parse import parse_qs, urlencode, urlparse, urlsplit, urlunsplit
from uuid import uuid4 from uuid import uuid4
from django.http import HttpRequest, HttpResponse from django.http import HttpRequest, HttpResponse
from django.http.response import Http404, HttpResponseBadRequest, HttpResponseRedirect from django.http.response import Http404, HttpResponseBadRequest, HttpResponseRedirect
from django.shortcuts import get_object_or_404, redirect from django.shortcuts import get_object_or_404
from django.utils import timezone from django.utils import timezone
from django.utils.translation import gettext as _ from django.utils.translation import gettext as _
from structlog.stdlib import get_logger from structlog.stdlib import get_logger
@ -46,6 +46,7 @@ from authentik.providers.oauth2.models import (
OAuth2Provider, OAuth2Provider,
ResponseTypes, ResponseTypes,
) )
from authentik.providers.oauth2.utils import HttpResponseRedirectScheme
from authentik.providers.oauth2.views.userinfo import UserInfoView from authentik.providers.oauth2.views.userinfo import UserInfoView
from authentik.stages.consent.models import ConsentMode, ConsentStage from authentik.stages.consent.models import ConsentMode, ConsentStage
from authentik.stages.consent.stage import ( from authentik.stages.consent.stage import (
@ -233,6 +234,11 @@ class OAuthFulfillmentStage(StageView):
params: OAuthAuthorizationParams params: OAuthAuthorizationParams
provider: OAuth2Provider provider: OAuth2Provider
def redirect(self, uri: str) -> HttpResponse:
"""Redirect using HttpResponseRedirectScheme, compatible with non-http schemes"""
parsed = urlparse(uri)
return HttpResponseRedirectScheme(uri, allowed_schemes=[parsed.scheme])
# pylint: disable=unused-argument # pylint: disable=unused-argument
def get(self, request: HttpRequest, *args, **kwargs) -> HttpResponse: def get(self, request: HttpRequest, *args, **kwargs) -> HttpResponse:
"""final Stage of an OAuth2 Flow""" """final Stage of an OAuth2 Flow"""
@ -261,7 +267,7 @@ class OAuthFulfillmentStage(StageView):
flow=self.executor.plan.flow_pk, flow=self.executor.plan.flow_pk,
scopes=", ".join(self.params.scope), scopes=", ".join(self.params.scope),
).from_http(self.request) ).from_http(self.request)
return redirect(self.create_response_uri()) return self.redirect(self.create_response_uri())
except (ClientIdError, RedirectUriError) as error: except (ClientIdError, RedirectUriError) as error:
error.to_event(application=application).from_http(request) error.to_event(application=application).from_http(request)
self.executor.stage_invalid() self.executor.stage_invalid()
@ -270,7 +276,7 @@ class OAuthFulfillmentStage(StageView):
except AuthorizeError as error: except AuthorizeError as error:
error.to_event(application=application).from_http(request) error.to_event(application=application).from_http(request)
self.executor.stage_invalid() self.executor.stage_invalid()
return redirect(error.create_uri()) return self.redirect(error.create_uri())
def create_response_uri(self) -> str: def create_response_uri(self) -> str:
"""Create a final Response URI the user is redirected to.""" """Create a final Response URI the user is redirected to."""