diff --git a/authentik/sources/oauth/types/azure_ad.py b/authentik/sources/oauth/types/azure_ad.py index 7d5dc02fb..879664d33 100644 --- a/authentik/sources/oauth/types/azure_ad.py +++ b/authentik/sources/oauth/types/azure_ad.py @@ -2,13 +2,46 @@ from typing import Any, Optional from uuid import UUID +from requests.exceptions import RequestException +from structlog.stdlib import get_logger + +from authentik.sources.oauth.clients.oauth2 import OAuth2Client from authentik.sources.oauth.types.manager import MANAGER, SourceType from authentik.sources.oauth.views.callback import OAuthCallback +LOGGER = get_logger() + + +class AzureADClient(OAuth2Client): + """Azure AD Oauth client, azure ad doesn't like the ?access_token that is sent by default""" + + def get_profile_info(self, token: dict[str, str]) -> Optional[dict[str, Any]]: + "Fetch user profile information." + profile_url = self.source.type.profile_url or "" + if self.source.type.urls_customizable and self.source.profile_url: + profile_url = self.source.profile_url + try: + response = self.session.request( + "get", + profile_url, + headers={ + "Authorization": f"{token['token_type']} {token['access_token']}" + }, + ) + LOGGER.debug(response.text) + response.raise_for_status() + except RequestException as exc: + LOGGER.warning("Unable to fetch user profile", exc=exc) + return None + else: + return response.json() + class AzureADOAuthCallback(OAuthCallback): """AzureAD OAuth2 Callback""" + client_class = AzureADClient + def get_user_id(self, info: dict[str, Any]) -> Optional[str]: try: return str(UUID(info.get("objectId")).int)