sources/oauth: use GitHub's dedicated email API when no public email address is configured
closes #3472 Signed-off-by: Jens Langhammer <jens.langhammer@beryju.org>
This commit is contained in:
parent
2868331976
commit
83eaac375d
|
@ -12,8 +12,6 @@ from authentik.events.models import Event, EventAction
|
||||||
from authentik.lib.utils.http import get_http_session
|
from authentik.lib.utils.http import get_http_session
|
||||||
from authentik.sources.oauth.models import OAuthSource
|
from authentik.sources.oauth.models import OAuthSource
|
||||||
|
|
||||||
LOGGER = get_logger()
|
|
||||||
|
|
||||||
|
|
||||||
class BaseOAuthClient:
|
class BaseOAuthClient:
|
||||||
"""Base OAuth Client"""
|
"""Base OAuth Client"""
|
||||||
|
@ -30,6 +28,7 @@ class BaseOAuthClient:
|
||||||
self.session = get_http_session()
|
self.session = get_http_session()
|
||||||
self.request = request
|
self.request = request
|
||||||
self.callback = callback
|
self.callback = callback
|
||||||
|
self.logger = get_logger().bind(source=source.slug)
|
||||||
|
|
||||||
def get_access_token(self, **request_kwargs) -> Optional[dict[str, Any]]:
|
def get_access_token(self, **request_kwargs) -> Optional[dict[str, Any]]:
|
||||||
"""Fetch access token from callback request."""
|
"""Fetch access token from callback request."""
|
||||||
|
@ -44,7 +43,7 @@ class BaseOAuthClient:
|
||||||
response = self.do_request("get", profile_url, token=token)
|
response = self.do_request("get", profile_url, token=token)
|
||||||
response.raise_for_status()
|
response.raise_for_status()
|
||||||
except RequestException as exc:
|
except RequestException as exc:
|
||||||
LOGGER.warning("Unable to fetch user profile", exc=exc, body=response.text)
|
self.logger.warning("Unable to fetch user profile", exc=exc, body=response.text)
|
||||||
return None
|
return None
|
||||||
else:
|
else:
|
||||||
return response.json()
|
return response.json()
|
||||||
|
@ -73,7 +72,7 @@ class BaseOAuthClient:
|
||||||
# to make additional scopes easier
|
# to make additional scopes easier
|
||||||
args["scope"] = " ".join(sorted(set(args["scope"])))
|
args["scope"] = " ".join(sorted(set(args["scope"])))
|
||||||
params = urlencode(args, quote_via=quote, doseq=True)
|
params = urlencode(args, quote_via=quote, doseq=True)
|
||||||
LOGGER.info("redirect args", **args)
|
self.logger.info("redirect args", **args)
|
||||||
return urlunparse(parsed_url._replace(query=params))
|
return urlunparse(parsed_url._replace(query=params))
|
||||||
|
|
||||||
def parse_raw_token(self, raw_token: str) -> dict[str, Any]:
|
def parse_raw_token(self, raw_token: str) -> dict[str, Any]:
|
||||||
|
|
|
@ -1,6 +1,10 @@
|
||||||
"""GitHub Type tests"""
|
"""GitHub Type tests"""
|
||||||
from django.test import TestCase
|
from copy import copy
|
||||||
|
|
||||||
|
from django.test import RequestFactory, TestCase
|
||||||
|
from requests_mock import Mocker
|
||||||
|
|
||||||
|
from authentik.lib.generators import generate_id
|
||||||
from authentik.sources.oauth.models import OAuthSource
|
from authentik.sources.oauth.models import OAuthSource
|
||||||
from authentik.sources.oauth.types.github import GitHubOAuth2Callback
|
from authentik.sources.oauth.types.github import GitHubOAuth2Callback
|
||||||
|
|
||||||
|
@ -55,11 +59,9 @@ class TestTypeGitHub(TestCase):
|
||||||
self.source = OAuthSource.objects.create(
|
self.source = OAuthSource.objects.create(
|
||||||
name="test",
|
name="test",
|
||||||
slug="test",
|
slug="test",
|
||||||
provider_type="openidconnect",
|
provider_type="github",
|
||||||
authorization_url="",
|
|
||||||
profile_url="",
|
|
||||||
consumer_key="",
|
|
||||||
)
|
)
|
||||||
|
self.factory = RequestFactory()
|
||||||
|
|
||||||
def test_enroll_context(self):
|
def test_enroll_context(self):
|
||||||
"""Test GitHub Enrollment context"""
|
"""Test GitHub Enrollment context"""
|
||||||
|
@ -67,3 +69,30 @@ class TestTypeGitHub(TestCase):
|
||||||
self.assertEqual(ak_context["username"], GITHUB_USER["login"])
|
self.assertEqual(ak_context["username"], GITHUB_USER["login"])
|
||||||
self.assertEqual(ak_context["email"], GITHUB_USER["email"])
|
self.assertEqual(ak_context["email"], GITHUB_USER["email"])
|
||||||
self.assertEqual(ak_context["name"], GITHUB_USER["name"])
|
self.assertEqual(ak_context["name"], GITHUB_USER["name"])
|
||||||
|
|
||||||
|
def test_enroll_context_email(self):
|
||||||
|
"""Test GitHub Enrollment context"""
|
||||||
|
email = generate_id()
|
||||||
|
user = copy(GITHUB_USER)
|
||||||
|
del user["email"]
|
||||||
|
with Mocker() as mocker:
|
||||||
|
mocker.get(
|
||||||
|
"https://api.github.com/user/emails",
|
||||||
|
json=[
|
||||||
|
{
|
||||||
|
"primary": True,
|
||||||
|
"email": email,
|
||||||
|
}
|
||||||
|
],
|
||||||
|
)
|
||||||
|
ak_context = GitHubOAuth2Callback(
|
||||||
|
source=self.source,
|
||||||
|
request=self.factory.get("/"),
|
||||||
|
token={
|
||||||
|
"access_token": generate_id(),
|
||||||
|
"token_type": generate_id(),
|
||||||
|
},
|
||||||
|
).get_user_enroll_context(user)
|
||||||
|
self.assertEqual(ak_context["username"], GITHUB_USER["login"])
|
||||||
|
self.assertEqual(ak_context["email"], email)
|
||||||
|
self.assertEqual(ak_context["name"], GITHUB_USER["name"])
|
||||||
|
|
|
@ -1,6 +1,9 @@
|
||||||
"""GitHub OAuth Views"""
|
"""GitHub OAuth Views"""
|
||||||
from typing import Any
|
from typing import Any, Optional
|
||||||
|
|
||||||
|
from requests.exceptions import RequestException
|
||||||
|
|
||||||
|
from authentik.sources.oauth.clients.oauth2 import OAuth2Client
|
||||||
from authentik.sources.oauth.types.manager import MANAGER, SourceType
|
from authentik.sources.oauth.types.manager import MANAGER, SourceType
|
||||||
from authentik.sources.oauth.views.callback import OAuthCallback
|
from authentik.sources.oauth.views.callback import OAuthCallback
|
||||||
from authentik.sources.oauth.views.redirect import OAuthRedirect
|
from authentik.sources.oauth.views.redirect import OAuthRedirect
|
||||||
|
@ -15,16 +18,47 @@ class GitHubOAuthRedirect(OAuthRedirect):
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
|
class GitHubOAuth2Client(OAuth2Client):
|
||||||
|
"""GitHub OAuth2 Client"""
|
||||||
|
|
||||||
|
def get_github_emails(self, token: dict[str, str]) -> Optional[dict[str, Any]]:
|
||||||
|
"""Get Emails from the GitHub API"""
|
||||||
|
profile_url = self.source.type.profile_url or ""
|
||||||
|
if self.source.type.urls_customizable and self.source.profile_url:
|
||||||
|
profile_url = self.source.profile_url
|
||||||
|
profile_url += "/emails"
|
||||||
|
try:
|
||||||
|
response = self.do_request("get", profile_url, token=token)
|
||||||
|
response.raise_for_status()
|
||||||
|
except RequestException as exc:
|
||||||
|
self.logger.warning("Unable to fetch github emails", exc=exc)
|
||||||
|
return []
|
||||||
|
else:
|
||||||
|
return response.json()
|
||||||
|
|
||||||
|
|
||||||
class GitHubOAuth2Callback(OAuthCallback):
|
class GitHubOAuth2Callback(OAuthCallback):
|
||||||
"""GitHub OAuth2 Callback"""
|
"""GitHub OAuth2 Callback"""
|
||||||
|
|
||||||
|
client_class = GitHubOAuth2Client
|
||||||
|
|
||||||
def get_user_enroll_context(
|
def get_user_enroll_context(
|
||||||
self,
|
self,
|
||||||
info: dict[str, Any],
|
info: dict[str, Any],
|
||||||
) -> dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
|
chosen_email = info.get("email")
|
||||||
|
if not chosen_email:
|
||||||
|
# The GitHub Userprofile API only returns an email address if the profile
|
||||||
|
# has a public email address set (despite us asking for user:email, this behaviour
|
||||||
|
# doesn't change.). So we fetch all the user's email addresses
|
||||||
|
client: GitHubOAuth2Client = self.get_client(self.source)
|
||||||
|
emails = client.get_github_emails(self.token)
|
||||||
|
for email in emails:
|
||||||
|
if email.get("primary", False):
|
||||||
|
chosen_email = email.get("email", None)
|
||||||
return {
|
return {
|
||||||
"username": info.get("login"),
|
"username": info.get("login"),
|
||||||
"email": info.get("email"),
|
"email": chosen_email,
|
||||||
"name": info.get("name"),
|
"name": info.get("name"),
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
|
@ -22,6 +22,7 @@ class OAuthCallback(OAuthClientMixin, View):
|
||||||
"Base OAuth callback view."
|
"Base OAuth callback view."
|
||||||
|
|
||||||
source: OAuthSource
|
source: OAuthSource
|
||||||
|
token: Optional[dict] = None
|
||||||
|
|
||||||
# pylint: disable=too-many-return-statements
|
# pylint: disable=too-many-return-statements
|
||||||
def dispatch(self, request: HttpRequest, *_, **kwargs) -> HttpResponse:
|
def dispatch(self, request: HttpRequest, *_, **kwargs) -> HttpResponse:
|
||||||
|
@ -36,14 +37,14 @@ class OAuthCallback(OAuthClientMixin, View):
|
||||||
raise Http404(f"Source {slug} is not enabled.")
|
raise Http404(f"Source {slug} is not enabled.")
|
||||||
client = self.get_client(self.source, callback=self.get_callback_url(self.source))
|
client = self.get_client(self.source, callback=self.get_callback_url(self.source))
|
||||||
# Fetch access token
|
# Fetch access token
|
||||||
token = client.get_access_token()
|
self.token = client.get_access_token()
|
||||||
if token is None:
|
if self.token is None:
|
||||||
return self.handle_login_failure("Could not retrieve token.")
|
return self.handle_login_failure("Could not retrieve token.")
|
||||||
if "error" in token:
|
if "error" in self.token:
|
||||||
return self.handle_login_failure(token["error"])
|
return self.handle_login_failure(self.token["error"])
|
||||||
# Fetch profile info
|
# Fetch profile info
|
||||||
try:
|
try:
|
||||||
raw_info = client.get_profile_info(token)
|
raw_info = client.get_profile_info(self.token)
|
||||||
if raw_info is None:
|
if raw_info is None:
|
||||||
return self.handle_login_failure("Could not retrieve profile.")
|
return self.handle_login_failure("Could not retrieve profile.")
|
||||||
except JSONDecodeError as exc:
|
except JSONDecodeError as exc:
|
||||||
|
@ -66,7 +67,7 @@ class OAuthCallback(OAuthClientMixin, View):
|
||||||
)
|
)
|
||||||
sfm.policy_context = {"oauth_userinfo": raw_info}
|
sfm.policy_context = {"oauth_userinfo": raw_info}
|
||||||
return sfm.get_flow(
|
return sfm.get_flow(
|
||||||
access_token=token.get("access_token"),
|
access_token=self.token.get("access_token"),
|
||||||
)
|
)
|
||||||
|
|
||||||
# pylint: disable=unused-argument
|
# pylint: disable=unused-argument
|
||||||
|
|
Reference in New Issue