From d9c2b32cbafb96265d4415d3c667d0c889744fdd Mon Sep 17 00:00:00 2001 From: Jens Langhammer Date: Sat, 26 Sep 2020 00:34:57 +0200 Subject: [PATCH] sources/oauth: cleanup clients, add type annotations --- passbook/sources/oauth/clients/base.py | 43 +++++++++++------- passbook/sources/oauth/clients/oauth1.py | 49 ++++++++++---------- passbook/sources/oauth/clients/oauth2.py | 57 +++++++++++------------- passbook/sources/oauth/exceptions.py | 5 +++ passbook/sources/oauth/types/reddit.py | 4 +- passbook/sources/oauth/views/base.py | 12 +++-- passbook/sources/oauth/views/callback.py | 2 +- passbook/sources/oauth/views/redirect.py | 7 +-- 8 files changed, 94 insertions(+), 85 deletions(-) create mode 100644 passbook/sources/oauth/exceptions.py diff --git a/passbook/sources/oauth/clients/base.py b/passbook/sources/oauth/clients/base.py index 285cdbd69..d97b2ab47 100644 --- a/passbook/sources/oauth/clients/base.py +++ b/passbook/sources/oauth/clients/base.py @@ -1,34 +1,41 @@ """OAuth Clients""" -from typing import TYPE_CHECKING, Any, Dict, Optional +from typing import Any, Dict, Optional, Tuple from urllib.parse import urlencode from django.http import HttpRequest from requests import Session from requests.exceptions import RequestException +from requests.models import Response from structlog import get_logger from passbook import __version__ +from passbook.sources.oauth.models import OAuthSource LOGGER = get_logger() -if TYPE_CHECKING: - from passbook.sources.oauth.models import OAuthSource class BaseOAuthClient: """Base OAuth Client""" session: Session - source: "OAuthSource" + source: OAuthSource - def __init__(self, source: "OAuthSource", token=""): # nosec + token: str + + request: HttpRequest + callback: Optional[str] + + def __init__( + self, source: OAuthSource, request: HttpRequest, callback: Optional[str] = None + ): self.source = source - self.token = token + self.token = "" self.session = Session() - self.session.headers.update({"User-Agent": "passbook %s" % __version__}) + self.request = request + self.callback = callback + self.session.headers.update({"User-Agent": f"passbook {__version__}"}) - def get_access_token( - self, request: HttpRequest, callback=None - ) -> Optional[Dict[str, Any]]: + def get_access_token(self, **request_kwargs) -> Optional[Dict[str, Any]]: "Fetch access token from callback request." raise NotImplementedError("Defined in a sub-class") # pragma: no cover @@ -48,24 +55,28 @@ class BaseOAuthClient: else: return response.json() - def get_redirect_args(self, request, callback) -> Dict[str, str]: + def get_redirect_args(self) -> Dict[str, str]: "Get request parameters for redirect url." raise NotImplementedError("Defined in a sub-class") # pragma: no cover - def get_redirect_url(self, request: HttpRequest, callback: str, parameters=None): + def get_redirect_url(self, parameters=None): "Build authentication redirect url." - args = self.get_redirect_args(request, callback=callback) + args = self.get_redirect_args() additional = parameters or {} args.update(additional) params = urlencode(args) LOGGER.info("redirect args", **args) - return "{0}?{1}".format(self.source.authorization_url, params) + return f"{self.source.authorization_url}?{params}" - def parse_raw_token(self, raw_token): + def parse_raw_token(self, raw_token: str) -> Tuple[str, Optional[str]]: "Parse token and secret from raw token response." raise NotImplementedError("Defined in a sub-class") # pragma: no cover + def do_request(self, method: str, url: str, **kwargs) -> Response: + """Wrapper around self.session.request, which can add special headers""" + return self.session.request(method, url, **kwargs) + @property - def session_key(self): + def session_key(self) -> str: """Return Session Key""" raise NotImplementedError("Defined in a sub-class") # pragma: no cover diff --git a/passbook/sources/oauth/clients/oauth1.py b/passbook/sources/oauth/clients/oauth1.py index 338844d8f..a91327759 100644 --- a/passbook/sources/oauth/clients/oauth1.py +++ b/passbook/sources/oauth/clients/oauth1.py @@ -1,15 +1,18 @@ """OAuth Clients""" -from typing import Dict, Optional +from typing import Any, Dict, Optional, Tuple from urllib.parse import parse_qs +from django.db.models.expressions import Value from django.http import HttpRequest from django.utils.encoding import force_str from requests.exceptions import RequestException +from requests.models import Response from requests_oauthlib import OAuth1 from structlog import get_logger from passbook import __version__ from passbook.sources.oauth.clients.base import BaseOAuthClient +from passbook.sources.oauth.exceptions import OAuthSourceException LOGGER = get_logger() @@ -21,16 +24,14 @@ class OAuthClient(BaseOAuthClient): "Accept": "application/json", } - def get_access_token( - self, request: HttpRequest, callback=None - ) -> Optional[Dict[str, str]]: + def get_access_token(self, **request_kwargs) -> Optional[Dict[str, Any]]: "Fetch access token from callback request." - raw_token = request.session.get(self.session_key, None) - verifier = request.GET.get("oauth_verifier", None) + raw_token = self.request.session.get(self.session_key, None) + verifier = self.request.GET.get("oauth_verifier", None) if raw_token is not None and verifier is not None: data = { "oauth_verifier": verifier, - "oauth_callback": callback, + "oauth_callback": self.callback, "token": raw_token, } try: @@ -48,9 +49,9 @@ class OAuthClient(BaseOAuthClient): return response.json() return None - def get_request_token(self, request: HttpRequest, callback): + def get_request_token(self) -> str: "Fetch the OAuth request token. Only required for OAuth 1.0." - callback = request.build_absolute_uri(callback) + callback = self.request.build_absolute_uri(self.callback) try: response = self.session.request( "post", @@ -63,33 +64,29 @@ class OAuthClient(BaseOAuthClient): ) response.raise_for_status() except RequestException as exc: - LOGGER.warning("Unable to fetch request token", exc=exc) - return None + raise OAuthSourceException from exc else: return response.text - def get_redirect_args(self, request: HttpRequest, callback): + def get_redirect_args(self) -> Dict[str, Any]: "Get request parameters for redirect url." - callback = force_str(request.build_absolute_uri(callback)) - raw_token = self.get_request_token(request, callback) - token, secret = self.parse_raw_token(raw_token) - if token is not None and secret is not None: - request.session[self.session_key] = raw_token + callback = self.request.build_absolute_uri(self.callback) + raw_token = self.get_request_token() + token, _ = self.parse_raw_token(raw_token) + self.request.session[self.session_key] = raw_token return { "oauth_token": token, "oauth_callback": callback, } - def parse_raw_token(self, raw_token): + def parse_raw_token(self, raw_token: str) -> Tuple[str, Optional[str]]: "Parse token and secret from raw token response." - if raw_token is None: - return (None, None) query_string = parse_qs(raw_token) - token = query_string.get("oauth_token", [None])[0] - secret = query_string.get("oauth_token_secret", [None])[0] + token = query_string["oauth_token"][0] + secret = query_string["oauth_token_secret"][0] return (token, secret) - def request(self, method, url, **kwargs): + def do_request(self, method: str, url: str, **kwargs) -> Response: "Build remote url request. Constructs necessary auth." user_token = kwargs.pop("token", self.token) token, secret = self.parse_raw_token(user_token) @@ -104,8 +101,8 @@ class OAuthClient(BaseOAuthClient): callback_uri=callback, ) kwargs["auth"] = oauth - return super(OAuthClient, self).session.request(method, url, **kwargs) + return super().do_request(method, url, **kwargs) @property - def session_key(self): - return "oauth-client-{0}-request-token".format(self.source.name) + def session_key(self) -> str: + return f"oauth-client-{self.source.name}-request-token" diff --git a/passbook/sources/oauth/clients/oauth2.py b/passbook/sources/oauth/clients/oauth2.py index dfafdbadc..9562c1cca 100644 --- a/passbook/sources/oauth/clients/oauth2.py +++ b/passbook/sources/oauth/clients/oauth2.py @@ -1,11 +1,12 @@ """OAuth Clients""" import json -from typing import TYPE_CHECKING +from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple from urllib.parse import parse_qs from django.http import HttpRequest from django.utils.crypto import constant_time_compare, get_random_string from requests.exceptions import RequestException +from requests.models import Response from structlog import get_logger from passbook import __version__ @@ -23,11 +24,10 @@ class OAuth2Client(BaseOAuthClient): "Accept": "application/json", } - # pylint: disable=unused-argument - def check_application_state(self, request: HttpRequest, callback: str): + def check_application_state(self) -> bool: "Check optional state parameter." - stored = request.session.get(self.session_key, None) - returned = request.GET.get("state", None) + stored = self.request.session.get(self.session_key, None) + returned = self.request.GET.get("state", None) check = False if stored is not None: if returned is not None: @@ -35,21 +35,25 @@ class OAuth2Client(BaseOAuthClient): else: LOGGER.warning("No state parameter returned by the source.") else: - LOGGER.warning("No state stored in the sesssion.") + LOGGER.warning("No state stored in the session.") return check - def get_access_token(self, request: HttpRequest, callback=None, **request_kwargs): + def get_application_state(self) -> str: + "Generate state optional parameter." + return get_random_string(32) + + def get_access_token(self, **request_kwargs) -> Optional[Dict[str, Any]]: "Fetch access token from callback request." - callback = request.build_absolute_uri(callback or request.path) - if not self.check_application_state(request, callback): + callback = self.request.build_absolute_uri(self.callback or self.request.path) + if not self.check_application_state(): LOGGER.warning("Application state check failed.") return None - if "code" in request.GET: + if "code" in self.request.GET: args = { "client_id": self.source.consumer_key, "redirect_uri": callback, "client_secret": self.source.consumer_secret, - "code": request.GET["code"], + "code": self.request.GET["code"], "grant_type": "authorization_code", } else: @@ -61,7 +65,6 @@ class OAuth2Client(BaseOAuthClient): self.source.access_token_url, data=args, headers=self._default_headers, - **request_kwargs, ) response.raise_for_status() except RequestException as exc: @@ -70,39 +73,33 @@ class OAuth2Client(BaseOAuthClient): else: return response.json() - # pylint: disable=unused-argument - def get_application_state(self, request: HttpRequest, callback): - "Generate state optional parameter." - return get_random_string(32) - - def get_redirect_args(self, request, callback): + def get_redirect_args(self) -> Dict[str, str]: "Get request parameters for redirect url." - callback = request.build_absolute_uri(callback) - args = { - "client_id": self.source.consumer_key, + callback = self.request.build_absolute_uri(self.callback) + client_id: str = self.source.consumer_key + args: Dict[str, str] = { + "client_id": client_id, "redirect_uri": callback, "response_type": "code", } - state = self.get_application_state(request, callback) + state = self.get_application_state() if state is not None: args["state"] = state - request.session[self.session_key] = state + self.request.session[self.session_key] = state return args - def parse_raw_token(self, raw_token): + def parse_raw_token(self, raw_token: str) -> Tuple[str, Optional[str]]: "Parse token and secret from raw token response." - if raw_token is None: - return (None, None) # Load as json first then parse as query string try: token_data = json.loads(raw_token) except ValueError: - token = parse_qs(raw_token).get("access_token", [None])[0] + token = parse_qs(raw_token)["access_token"][0] else: - token = token_data.get("access_token", None) + token = token_data["access_token"] return (token, None) - def request(self, method, url, **kwargs): + def do_request(self, method: str, url: str, **kwargs) -> Response: "Build remote url request. Constructs necessary auth." user_token = kwargs.pop("token", self.token) token, _ = self.parse_raw_token(user_token) @@ -110,7 +107,7 @@ class OAuth2Client(BaseOAuthClient): params = kwargs.get("params", {}) params["access_token"] = token kwargs["params"] = params - return super(OAuth2Client, self).session.request(method, url, **kwargs) + return super().do_request(method, url, **kwargs) @property def session_key(self): diff --git a/passbook/sources/oauth/exceptions.py b/passbook/sources/oauth/exceptions.py new file mode 100644 index 000000000..d89c30eeb --- /dev/null +++ b/passbook/sources/oauth/exceptions.py @@ -0,0 +1,5 @@ +from passbook.lib.sentry import SentryIgnoredException + + +class OAuthSourceException(SentryIgnoredException): + """General Error during OAuth Flow occurred""" diff --git a/passbook/sources/oauth/types/reddit.py b/passbook/sources/oauth/types/reddit.py index d18cd29a3..2d877201e 100644 --- a/passbook/sources/oauth/types/reddit.py +++ b/passbook/sources/oauth/types/reddit.py @@ -27,9 +27,7 @@ class RedditOAuth2Client(OAuth2Client): def get_access_token(self, request, callback=None, **request_kwargs): "Fetch access token from callback request." auth = HTTPBasicAuth(self.source.consumer_key, self.source.consumer_secret) - return super(RedditOAuth2Client, self).get_access_token( - request, callback, auth=auth - ) + return super().get_access_token(auth=auth) @MANAGER.source(kind=RequestKind.callback, name="reddit") diff --git a/passbook/sources/oauth/views/base.py b/passbook/sources/oauth/views/base.py index 32342b24d..b7472624b 100644 --- a/passbook/sources/oauth/views/base.py +++ b/passbook/sources/oauth/views/base.py @@ -1,6 +1,8 @@ """OAuth Base views""" from typing import Optional, Type +from django.http.request import HttpRequest + from passbook.sources.oauth.clients.base import BaseOAuthClient from passbook.sources.oauth.clients.oauth1 import OAuthClient from passbook.sources.oauth.clients.oauth2 import OAuth2Client @@ -11,13 +13,15 @@ from passbook.sources.oauth.models import OAuthSource class OAuthClientMixin: "Mixin for getting OAuth client for a source." + request: HttpRequest # Set by View class + client_class: Optional[Type[BaseOAuthClient]] = None - def get_client(self, source: OAuthSource) -> BaseOAuthClient: + def get_client(self, source: OAuthSource, **kwargs) -> BaseOAuthClient: "Get instance of the OAuth client for this source." if self.client_class is not None: # pylint: disable=not-callable - return self.client_class(source) + return self.client_class(source, self.request, **kwargs) if source.request_token_url: - return OAuthClient(source) - return OAuth2Client(source) + return OAuthClient(source, self.request, **kwargs) + return OAuth2Client(source, self.request, **kwargs) diff --git a/passbook/sources/oauth/views/callback.py b/passbook/sources/oauth/views/callback.py index 911812daa..155f9abbb 100644 --- a/passbook/sources/oauth/views/callback.py +++ b/passbook/sources/oauth/views/callback.py @@ -54,7 +54,7 @@ class OAuthCallback(OAuthClientMixin, View): client = self.get_client(self.source) callback = self.get_callback_url(self.source) # Fetch access token - token = client.get_access_token(self.request, callback=callback) + token = client.get_access_token(callback=callback) if token is None: return self.handle_login_failure(self.source, "Could not retrieve token.") if "error" in token: diff --git a/passbook/sources/oauth/views/redirect.py b/passbook/sources/oauth/views/redirect.py index a4ff16d25..a3557e017 100644 --- a/passbook/sources/oauth/views/redirect.py +++ b/passbook/sources/oauth/views/redirect.py @@ -40,9 +40,6 @@ class OAuthRedirect(OAuthClientMixin, RedirectView): else: if not source.enabled: raise Http404(f"source {slug} is not enabled.") - client = self.get_client(source) - callback = self.get_callback_url(source) + client = self.get_client(source, callback=self.get_callback_url(source)) params = self.get_additional_parameters(source) - return client.get_redirect_url( - self.request, callback=callback, parameters=params - ) + return client.get_redirect_url(params)