From 64d7b009abe1a01f1a18cb9682e51f7902171f17 Mon Sep 17 00:00:00 2001 From: Jens Langhammer Date: Sun, 23 Feb 2020 19:42:57 +0100 Subject: [PATCH] sources/oauth: fix invalid headers, fix invalid function signature --- passbook/sources/oauth/clients.py | 69 +++++++++++++++++----------- passbook/sources/oauth/forms.py | 2 +- passbook/sources/oauth/views/core.py | 16 ++++--- 3 files changed, 52 insertions(+), 35 deletions(-) diff --git a/passbook/sources/oauth/clients.py b/passbook/sources/oauth/clients.py index 9eae30afb..f47380a9a 100644 --- a/passbook/sources/oauth/clients.py +++ b/passbook/sources/oauth/clients.py @@ -1,8 +1,9 @@ """OAuth Clients""" import json -from typing import Dict +from typing import Dict, Optional from urllib.parse import parse_qs, urlencode +from django.http import HttpRequest from django.utils.crypto import constant_time_compare, get_random_string from django.utils.encoding import force_text from requests import Session @@ -18,30 +19,26 @@ LOGGER = get_logger() class BaseOAuthClient: """Base OAuth Client""" - _session: Session = None + session: Session = None def __init__(self, source, token=""): # nosec self.source = source self.token = token - self._session = Session() - self._session.headers.update({"User-Agent": "passbook %s" % __version__}) + self.session = Session() + self.session.headers.update({"User-Agent": "passbook %s" % __version__}) def get_access_token(self, request, callback=None): "Fetch access token from callback request." raise NotImplementedError("Defined in a sub-class") # pragma: no cover - def get_profile_info(self, raw_token): + def get_profile_info(self, token: Dict[str, str]): "Fetch user profile information." try: - token = json.loads(raw_token) headers = { "Authorization": f"{token['token_type']} {token['access_token']}" } - response = self.request( - "get", - self.source.profile_url, - token=token["access_token"], - headers=headers, + response = self.session.request( + "get", self.source.profile_url, headers=headers, ) response.raise_for_status() except RequestException as exc: @@ -67,10 +64,6 @@ class BaseOAuthClient: "Parse token and secret from raw token response." raise NotImplementedError("Defined in a sub-class") # pragma: no cover - def request(self, method, url, **kwargs): - "Build remote url request." - return self._session.request(method, url, **kwargs) - @property def session_key(self): """Return Session Key""" @@ -80,36 +73,48 @@ class BaseOAuthClient: class OAuthClient(BaseOAuthClient): """OAuth1 Client""" - def get_access_token(self, request, callback=None): + _default_headers = { + "Accept": "application/json", + } + + def get_access_token( + self, request: HttpRequest, callback=None + ) -> Optional[Dict[str, str]]: "Fetch access token from callback request." raw_token = request.session.get(self.session_key, None) verifier = request.GET.get("oauth_verifier", None) if raw_token is not None and verifier is not None: - data = {"oauth_verifier": verifier} + data = { + "oauth_verifier": verifier, + "oauth_callback": callback, + "token": raw_token, + } callback = request.build_absolute_uri(callback or request.path) callback = force_text(callback) try: - response = self.request( + response = self.session.request( "post", self.source.access_token_url, - token=raw_token, data=data, - oauth_callback=callback, + headers=self._default_headers, ) response.raise_for_status() except RequestException as exc: LOGGER.warning("Unable to fetch access token", exc=exc) return None else: - return response.text + return response.json() return None def get_request_token(self, request, callback): "Fetch the OAuth request token. Only required for OAuth 1.0." callback = force_text(request.build_absolute_uri(callback)) try: - response = self.request( - "post", self.source.request_token_url, oauth_callback=callback + response = self.session.request( + "post", + self.source.request_token_url, + data={"oauth_callback": callback}, + headers=self._default_headers, ) response.raise_for_status() except RequestException as exc: @@ -154,7 +159,7 @@ class OAuthClient(BaseOAuthClient): callback_uri=callback, ) kwargs["auth"] = oauth - return super(OAuthClient, self).request(method, url, **kwargs) + return super(OAuthClient, self).session.request(method, url, **kwargs) @property def session_key(self): @@ -164,6 +169,10 @@ class OAuthClient(BaseOAuthClient): class OAuth2Client(BaseOAuthClient): """OAuth2 Client""" + _default_headers = { + "Accept": "application/json", + } + # pylint: disable=unused-argument def check_application_state(self, request, callback): "Check optional state parameter." @@ -197,15 +206,19 @@ class OAuth2Client(BaseOAuthClient): LOGGER.warning("No code returned by the source") return None try: - response = self.request( - "post", self.source.access_token_url, data=args, **request_kwargs + response = self.session.request( + "post", + self.source.access_token_url, + data=args, + headers=self._default_headers, + **request_kwargs, ) response.raise_for_status() except RequestException as exc: LOGGER.warning("Unable to fetch access token", exc=exc) return None else: - return response.text + return response.json() # pylint: disable=unused-argument def get_application_state(self, request, callback): @@ -247,7 +260,7 @@ class OAuth2Client(BaseOAuthClient): params = kwargs.get("params", {}) params["access_token"] = token kwargs["params"] = params - return super(OAuth2Client, self).request(method, url, **kwargs) + return super(OAuth2Client, self).session.request(method, url, **kwargs) @property def session_key(self): diff --git a/passbook/sources/oauth/forms.py b/passbook/sources/oauth/forms.py index dea11724b..e1e0de903 100644 --- a/passbook/sources/oauth/forms.py +++ b/passbook/sources/oauth/forms.py @@ -116,7 +116,7 @@ class AzureADOAuthSourceForm(OAuthSourceForm): class Meta(OAuthSourceForm.Meta): overrides = { - "provider_type": "azure_ad", + "provider_type": "azure-ad", "request_token_url": "", "authorization_url": "https://login.microsoftonline.com/common/oauth2/authorize", "access_token_url": "https://login.microsoftonline.com/common/oauth2/token", diff --git a/passbook/sources/oauth/views/core.py b/passbook/sources/oauth/views/core.py index 0ac48a28b..5c0508c0a 100644 --- a/passbook/sources/oauth/views/core.py +++ b/passbook/sources/oauth/views/core.py @@ -89,13 +89,15 @@ class OAuthCallback(OAuthClientMixin, View): client = self.get_client(self.source) callback = self.get_callback_url(self.source) # Fetch access token - raw_token = client.get_access_token(self.request, callback=callback) - if raw_token is None: + token = client.get_access_token(self.request, callback=callback) + if token is None: return self.handle_login_failure( self.source, "Could not retrieve token." ) + if "error" in token: + return self.handle_login_failure(self.source, token["error"]) # Fetch profile info - info = client.get_profile_info(raw_token) + info = client.get_profile_info(token) if info is None: return self.handle_login_failure( self.source, "Could not retrieve profile." @@ -105,7 +107,7 @@ class OAuthCallback(OAuthClientMixin, View): return self.handle_login_failure(self.source, "Could not determine id.") # Get or create access record defaults = { - "access_token": raw_token, + "access_token": token.get("access_token"), } existing = UserOAuthSourceConnection.objects.filter( source=self.source, identifier=identifier @@ -113,13 +115,15 @@ class OAuthCallback(OAuthClientMixin, View): if existing.exists(): connection = existing.first() - connection.access_token = raw_token + connection.access_token = token.get("access_token") UserOAuthSourceConnection.objects.filter(pk=connection.pk).update( **defaults ) else: connection = UserOAuthSourceConnection( - source=self.source, identifier=identifier, access_token=raw_token + source=self.source, + identifier=identifier, + access_token=token.get("access_token"), ) user = authenticate( source=self.source, identifier=identifier, request=request