sources/oauth: use GitHub's dedicated email API when no public email address is configured
closes #3472 Signed-off-by: Jens Langhammer <jens.langhammer@beryju.org>
This commit is contained in:
parent
2868331976
commit
83eaac375d
|
@ -12,8 +12,6 @@ from authentik.events.models import Event, EventAction
|
|||
from authentik.lib.utils.http import get_http_session
|
||||
from authentik.sources.oauth.models import OAuthSource
|
||||
|
||||
LOGGER = get_logger()
|
||||
|
||||
|
||||
class BaseOAuthClient:
|
||||
"""Base OAuth Client"""
|
||||
|
@ -30,6 +28,7 @@ class BaseOAuthClient:
|
|||
self.session = get_http_session()
|
||||
self.request = request
|
||||
self.callback = callback
|
||||
self.logger = get_logger().bind(source=source.slug)
|
||||
|
||||
def get_access_token(self, **request_kwargs) -> Optional[dict[str, Any]]:
|
||||
"""Fetch access token from callback request."""
|
||||
|
@ -44,7 +43,7 @@ class BaseOAuthClient:
|
|||
response = self.do_request("get", profile_url, token=token)
|
||||
response.raise_for_status()
|
||||
except RequestException as exc:
|
||||
LOGGER.warning("Unable to fetch user profile", exc=exc, body=response.text)
|
||||
self.logger.warning("Unable to fetch user profile", exc=exc, body=response.text)
|
||||
return None
|
||||
else:
|
||||
return response.json()
|
||||
|
@ -73,7 +72,7 @@ class BaseOAuthClient:
|
|||
# to make additional scopes easier
|
||||
args["scope"] = " ".join(sorted(set(args["scope"])))
|
||||
params = urlencode(args, quote_via=quote, doseq=True)
|
||||
LOGGER.info("redirect args", **args)
|
||||
self.logger.info("redirect args", **args)
|
||||
return urlunparse(parsed_url._replace(query=params))
|
||||
|
||||
def parse_raw_token(self, raw_token: str) -> dict[str, Any]:
|
||||
|
|
|
@ -1,6 +1,10 @@
|
|||
"""GitHub Type tests"""
|
||||
from django.test import TestCase
|
||||
from copy import copy
|
||||
|
||||
from django.test import RequestFactory, TestCase
|
||||
from requests_mock import Mocker
|
||||
|
||||
from authentik.lib.generators import generate_id
|
||||
from authentik.sources.oauth.models import OAuthSource
|
||||
from authentik.sources.oauth.types.github import GitHubOAuth2Callback
|
||||
|
||||
|
@ -55,11 +59,9 @@ class TestTypeGitHub(TestCase):
|
|||
self.source = OAuthSource.objects.create(
|
||||
name="test",
|
||||
slug="test",
|
||||
provider_type="openidconnect",
|
||||
authorization_url="",
|
||||
profile_url="",
|
||||
consumer_key="",
|
||||
provider_type="github",
|
||||
)
|
||||
self.factory = RequestFactory()
|
||||
|
||||
def test_enroll_context(self):
|
||||
"""Test GitHub Enrollment context"""
|
||||
|
@ -67,3 +69,30 @@ class TestTypeGitHub(TestCase):
|
|||
self.assertEqual(ak_context["username"], GITHUB_USER["login"])
|
||||
self.assertEqual(ak_context["email"], GITHUB_USER["email"])
|
||||
self.assertEqual(ak_context["name"], GITHUB_USER["name"])
|
||||
|
||||
def test_enroll_context_email(self):
|
||||
"""Test GitHub Enrollment context"""
|
||||
email = generate_id()
|
||||
user = copy(GITHUB_USER)
|
||||
del user["email"]
|
||||
with Mocker() as mocker:
|
||||
mocker.get(
|
||||
"https://api.github.com/user/emails",
|
||||
json=[
|
||||
{
|
||||
"primary": True,
|
||||
"email": email,
|
||||
}
|
||||
],
|
||||
)
|
||||
ak_context = GitHubOAuth2Callback(
|
||||
source=self.source,
|
||||
request=self.factory.get("/"),
|
||||
token={
|
||||
"access_token": generate_id(),
|
||||
"token_type": generate_id(),
|
||||
},
|
||||
).get_user_enroll_context(user)
|
||||
self.assertEqual(ak_context["username"], GITHUB_USER["login"])
|
||||
self.assertEqual(ak_context["email"], email)
|
||||
self.assertEqual(ak_context["name"], GITHUB_USER["name"])
|
||||
|
|
|
@ -1,6 +1,9 @@
|
|||
"""GitHub OAuth Views"""
|
||||
from typing import Any
|
||||
from typing import Any, Optional
|
||||
|
||||
from requests.exceptions import RequestException
|
||||
|
||||
from authentik.sources.oauth.clients.oauth2 import OAuth2Client
|
||||
from authentik.sources.oauth.types.manager import MANAGER, SourceType
|
||||
from authentik.sources.oauth.views.callback import OAuthCallback
|
||||
from authentik.sources.oauth.views.redirect import OAuthRedirect
|
||||
|
@ -15,16 +18,47 @@ class GitHubOAuthRedirect(OAuthRedirect):
|
|||
}
|
||||
|
||||
|
||||
class GitHubOAuth2Client(OAuth2Client):
|
||||
"""GitHub OAuth2 Client"""
|
||||
|
||||
def get_github_emails(self, token: dict[str, str]) -> Optional[dict[str, Any]]:
|
||||
"""Get Emails from the GitHub API"""
|
||||
profile_url = self.source.type.profile_url or ""
|
||||
if self.source.type.urls_customizable and self.source.profile_url:
|
||||
profile_url = self.source.profile_url
|
||||
profile_url += "/emails"
|
||||
try:
|
||||
response = self.do_request("get", profile_url, token=token)
|
||||
response.raise_for_status()
|
||||
except RequestException as exc:
|
||||
self.logger.warning("Unable to fetch github emails", exc=exc)
|
||||
return []
|
||||
else:
|
||||
return response.json()
|
||||
|
||||
|
||||
class GitHubOAuth2Callback(OAuthCallback):
|
||||
"""GitHub OAuth2 Callback"""
|
||||
|
||||
client_class = GitHubOAuth2Client
|
||||
|
||||
def get_user_enroll_context(
|
||||
self,
|
||||
info: dict[str, Any],
|
||||
) -> dict[str, Any]:
|
||||
chosen_email = info.get("email")
|
||||
if not chosen_email:
|
||||
# The GitHub Userprofile API only returns an email address if the profile
|
||||
# has a public email address set (despite us asking for user:email, this behaviour
|
||||
# doesn't change.). So we fetch all the user's email addresses
|
||||
client: GitHubOAuth2Client = self.get_client(self.source)
|
||||
emails = client.get_github_emails(self.token)
|
||||
for email in emails:
|
||||
if email.get("primary", False):
|
||||
chosen_email = email.get("email", None)
|
||||
return {
|
||||
"username": info.get("login"),
|
||||
"email": info.get("email"),
|
||||
"email": chosen_email,
|
||||
"name": info.get("name"),
|
||||
}
|
||||
|
||||
|
|
|
@ -22,6 +22,7 @@ class OAuthCallback(OAuthClientMixin, View):
|
|||
"Base OAuth callback view."
|
||||
|
||||
source: OAuthSource
|
||||
token: Optional[dict] = None
|
||||
|
||||
# pylint: disable=too-many-return-statements
|
||||
def dispatch(self, request: HttpRequest, *_, **kwargs) -> HttpResponse:
|
||||
|
@ -36,14 +37,14 @@ class OAuthCallback(OAuthClientMixin, View):
|
|||
raise Http404(f"Source {slug} is not enabled.")
|
||||
client = self.get_client(self.source, callback=self.get_callback_url(self.source))
|
||||
# Fetch access token
|
||||
token = client.get_access_token()
|
||||
if token is None:
|
||||
self.token = client.get_access_token()
|
||||
if self.token is None:
|
||||
return self.handle_login_failure("Could not retrieve token.")
|
||||
if "error" in token:
|
||||
return self.handle_login_failure(token["error"])
|
||||
if "error" in self.token:
|
||||
return self.handle_login_failure(self.token["error"])
|
||||
# Fetch profile info
|
||||
try:
|
||||
raw_info = client.get_profile_info(token)
|
||||
raw_info = client.get_profile_info(self.token)
|
||||
if raw_info is None:
|
||||
return self.handle_login_failure("Could not retrieve profile.")
|
||||
except JSONDecodeError as exc:
|
||||
|
@ -66,7 +67,7 @@ class OAuthCallback(OAuthClientMixin, View):
|
|||
)
|
||||
sfm.policy_context = {"oauth_userinfo": raw_info}
|
||||
return sfm.get_flow(
|
||||
access_token=token.get("access_token"),
|
||||
access_token=self.token.get("access_token"),
|
||||
)
|
||||
|
||||
# pylint: disable=unused-argument
|
||||
|
|
Reference in New Issue