sources/oauth: cleanup clients, add type annotations

This commit is contained in:
Jens Langhammer 2020-09-26 00:34:57 +02:00
parent 6e4ce8dbaa
commit d9c2b32cba
8 changed files with 94 additions and 85 deletions

View File

@ -1,34 +1,41 @@
"""OAuth Clients""" """OAuth Clients"""
from typing import TYPE_CHECKING, Any, Dict, Optional from typing import Any, Dict, Optional, Tuple
from urllib.parse import urlencode from urllib.parse import urlencode
from django.http import HttpRequest from django.http import HttpRequest
from requests import Session from requests import Session
from requests.exceptions import RequestException from requests.exceptions import RequestException
from requests.models import Response
from structlog import get_logger from structlog import get_logger
from passbook import __version__ from passbook import __version__
from passbook.sources.oauth.models import OAuthSource
LOGGER = get_logger() LOGGER = get_logger()
if TYPE_CHECKING:
from passbook.sources.oauth.models import OAuthSource
class BaseOAuthClient: class BaseOAuthClient:
"""Base OAuth Client""" """Base OAuth Client"""
session: Session session: Session
source: "OAuthSource" source: OAuthSource
def __init__(self, source: "OAuthSource", token=""): # nosec token: str
request: HttpRequest
callback: Optional[str]
def __init__(
self, source: OAuthSource, request: HttpRequest, callback: Optional[str] = None
):
self.source = source self.source = source
self.token = token self.token = ""
self.session = Session() self.session = Session()
self.session.headers.update({"User-Agent": "passbook %s" % __version__}) self.request = request
self.callback = callback
self.session.headers.update({"User-Agent": f"passbook {__version__}"})
def get_access_token( def get_access_token(self, **request_kwargs) -> Optional[Dict[str, Any]]:
self, request: HttpRequest, callback=None
) -> Optional[Dict[str, Any]]:
"Fetch access token from callback request." "Fetch access token from callback request."
raise NotImplementedError("Defined in a sub-class") # pragma: no cover raise NotImplementedError("Defined in a sub-class") # pragma: no cover
@ -48,24 +55,28 @@ class BaseOAuthClient:
else: else:
return response.json() return response.json()
def get_redirect_args(self, request, callback) -> Dict[str, str]: def get_redirect_args(self) -> Dict[str, str]:
"Get request parameters for redirect url." "Get request parameters for redirect url."
raise NotImplementedError("Defined in a sub-class") # pragma: no cover raise NotImplementedError("Defined in a sub-class") # pragma: no cover
def get_redirect_url(self, request: HttpRequest, callback: str, parameters=None): def get_redirect_url(self, parameters=None):
"Build authentication redirect url." "Build authentication redirect url."
args = self.get_redirect_args(request, callback=callback) args = self.get_redirect_args()
additional = parameters or {} additional = parameters or {}
args.update(additional) args.update(additional)
params = urlencode(args) params = urlencode(args)
LOGGER.info("redirect args", **args) LOGGER.info("redirect args", **args)
return "{0}?{1}".format(self.source.authorization_url, params) return f"{self.source.authorization_url}?{params}"
def parse_raw_token(self, raw_token): def parse_raw_token(self, raw_token: str) -> Tuple[str, Optional[str]]:
"Parse token and secret from raw token response." "Parse token and secret from raw token response."
raise NotImplementedError("Defined in a sub-class") # pragma: no cover raise NotImplementedError("Defined in a sub-class") # pragma: no cover
def do_request(self, method: str, url: str, **kwargs) -> Response:
"""Wrapper around self.session.request, which can add special headers"""
return self.session.request(method, url, **kwargs)
@property @property
def session_key(self): def session_key(self) -> str:
"""Return Session Key""" """Return Session Key"""
raise NotImplementedError("Defined in a sub-class") # pragma: no cover raise NotImplementedError("Defined in a sub-class") # pragma: no cover

View File

@ -1,15 +1,18 @@
"""OAuth Clients""" """OAuth Clients"""
from typing import Dict, Optional from typing import Any, Dict, Optional, Tuple
from urllib.parse import parse_qs from urllib.parse import parse_qs
from django.db.models.expressions import Value
from django.http import HttpRequest from django.http import HttpRequest
from django.utils.encoding import force_str from django.utils.encoding import force_str
from requests.exceptions import RequestException from requests.exceptions import RequestException
from requests.models import Response
from requests_oauthlib import OAuth1 from requests_oauthlib import OAuth1
from structlog import get_logger from structlog import get_logger
from passbook import __version__ from passbook import __version__
from passbook.sources.oauth.clients.base import BaseOAuthClient from passbook.sources.oauth.clients.base import BaseOAuthClient
from passbook.sources.oauth.exceptions import OAuthSourceException
LOGGER = get_logger() LOGGER = get_logger()
@ -21,16 +24,14 @@ class OAuthClient(BaseOAuthClient):
"Accept": "application/json", "Accept": "application/json",
} }
def get_access_token( def get_access_token(self, **request_kwargs) -> Optional[Dict[str, Any]]:
self, request: HttpRequest, callback=None
) -> Optional[Dict[str, str]]:
"Fetch access token from callback request." "Fetch access token from callback request."
raw_token = request.session.get(self.session_key, None) raw_token = self.request.session.get(self.session_key, None)
verifier = request.GET.get("oauth_verifier", None) verifier = self.request.GET.get("oauth_verifier", None)
if raw_token is not None and verifier is not None: if raw_token is not None and verifier is not None:
data = { data = {
"oauth_verifier": verifier, "oauth_verifier": verifier,
"oauth_callback": callback, "oauth_callback": self.callback,
"token": raw_token, "token": raw_token,
} }
try: try:
@ -48,9 +49,9 @@ class OAuthClient(BaseOAuthClient):
return response.json() return response.json()
return None return None
def get_request_token(self, request: HttpRequest, callback): def get_request_token(self) -> str:
"Fetch the OAuth request token. Only required for OAuth 1.0." "Fetch the OAuth request token. Only required for OAuth 1.0."
callback = request.build_absolute_uri(callback) callback = self.request.build_absolute_uri(self.callback)
try: try:
response = self.session.request( response = self.session.request(
"post", "post",
@ -63,33 +64,29 @@ class OAuthClient(BaseOAuthClient):
) )
response.raise_for_status() response.raise_for_status()
except RequestException as exc: except RequestException as exc:
LOGGER.warning("Unable to fetch request token", exc=exc) raise OAuthSourceException from exc
return None
else: else:
return response.text return response.text
def get_redirect_args(self, request: HttpRequest, callback): def get_redirect_args(self) -> Dict[str, Any]:
"Get request parameters for redirect url." "Get request parameters for redirect url."
callback = force_str(request.build_absolute_uri(callback)) callback = self.request.build_absolute_uri(self.callback)
raw_token = self.get_request_token(request, callback) raw_token = self.get_request_token()
token, secret = self.parse_raw_token(raw_token) token, _ = self.parse_raw_token(raw_token)
if token is not None and secret is not None: self.request.session[self.session_key] = raw_token
request.session[self.session_key] = raw_token
return { return {
"oauth_token": token, "oauth_token": token,
"oauth_callback": callback, "oauth_callback": callback,
} }
def parse_raw_token(self, raw_token): def parse_raw_token(self, raw_token: str) -> Tuple[str, Optional[str]]:
"Parse token and secret from raw token response." "Parse token and secret from raw token response."
if raw_token is None:
return (None, None)
query_string = parse_qs(raw_token) query_string = parse_qs(raw_token)
token = query_string.get("oauth_token", [None])[0] token = query_string["oauth_token"][0]
secret = query_string.get("oauth_token_secret", [None])[0] secret = query_string["oauth_token_secret"][0]
return (token, secret) return (token, secret)
def request(self, method, url, **kwargs): def do_request(self, method: str, url: str, **kwargs) -> Response:
"Build remote url request. Constructs necessary auth." "Build remote url request. Constructs necessary auth."
user_token = kwargs.pop("token", self.token) user_token = kwargs.pop("token", self.token)
token, secret = self.parse_raw_token(user_token) token, secret = self.parse_raw_token(user_token)
@ -104,8 +101,8 @@ class OAuthClient(BaseOAuthClient):
callback_uri=callback, callback_uri=callback,
) )
kwargs["auth"] = oauth kwargs["auth"] = oauth
return super(OAuthClient, self).session.request(method, url, **kwargs) return super().do_request(method, url, **kwargs)
@property @property
def session_key(self): def session_key(self) -> str:
return "oauth-client-{0}-request-token".format(self.source.name) return f"oauth-client-{self.source.name}-request-token"

View File

@ -1,11 +1,12 @@
"""OAuth Clients""" """OAuth Clients"""
import json import json
from typing import TYPE_CHECKING from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple
from urllib.parse import parse_qs from urllib.parse import parse_qs
from django.http import HttpRequest from django.http import HttpRequest
from django.utils.crypto import constant_time_compare, get_random_string from django.utils.crypto import constant_time_compare, get_random_string
from requests.exceptions import RequestException from requests.exceptions import RequestException
from requests.models import Response
from structlog import get_logger from structlog import get_logger
from passbook import __version__ from passbook import __version__
@ -23,11 +24,10 @@ class OAuth2Client(BaseOAuthClient):
"Accept": "application/json", "Accept": "application/json",
} }
# pylint: disable=unused-argument def check_application_state(self) -> bool:
def check_application_state(self, request: HttpRequest, callback: str):
"Check optional state parameter." "Check optional state parameter."
stored = request.session.get(self.session_key, None) stored = self.request.session.get(self.session_key, None)
returned = request.GET.get("state", None) returned = self.request.GET.get("state", None)
check = False check = False
if stored is not None: if stored is not None:
if returned is not None: if returned is not None:
@ -35,21 +35,25 @@ class OAuth2Client(BaseOAuthClient):
else: else:
LOGGER.warning("No state parameter returned by the source.") LOGGER.warning("No state parameter returned by the source.")
else: else:
LOGGER.warning("No state stored in the sesssion.") LOGGER.warning("No state stored in the session.")
return check return check
def get_access_token(self, request: HttpRequest, callback=None, **request_kwargs): def get_application_state(self) -> str:
"Generate state optional parameter."
return get_random_string(32)
def get_access_token(self, **request_kwargs) -> Optional[Dict[str, Any]]:
"Fetch access token from callback request." "Fetch access token from callback request."
callback = request.build_absolute_uri(callback or request.path) callback = self.request.build_absolute_uri(self.callback or self.request.path)
if not self.check_application_state(request, callback): if not self.check_application_state():
LOGGER.warning("Application state check failed.") LOGGER.warning("Application state check failed.")
return None return None
if "code" in request.GET: if "code" in self.request.GET:
args = { args = {
"client_id": self.source.consumer_key, "client_id": self.source.consumer_key,
"redirect_uri": callback, "redirect_uri": callback,
"client_secret": self.source.consumer_secret, "client_secret": self.source.consumer_secret,
"code": request.GET["code"], "code": self.request.GET["code"],
"grant_type": "authorization_code", "grant_type": "authorization_code",
} }
else: else:
@ -61,7 +65,6 @@ class OAuth2Client(BaseOAuthClient):
self.source.access_token_url, self.source.access_token_url,
data=args, data=args,
headers=self._default_headers, headers=self._default_headers,
**request_kwargs,
) )
response.raise_for_status() response.raise_for_status()
except RequestException as exc: except RequestException as exc:
@ -70,39 +73,33 @@ class OAuth2Client(BaseOAuthClient):
else: else:
return response.json() return response.json()
# pylint: disable=unused-argument def get_redirect_args(self) -> Dict[str, str]:
def get_application_state(self, request: HttpRequest, callback):
"Generate state optional parameter."
return get_random_string(32)
def get_redirect_args(self, request, callback):
"Get request parameters for redirect url." "Get request parameters for redirect url."
callback = request.build_absolute_uri(callback) callback = self.request.build_absolute_uri(self.callback)
args = { client_id: str = self.source.consumer_key
"client_id": self.source.consumer_key, args: Dict[str, str] = {
"client_id": client_id,
"redirect_uri": callback, "redirect_uri": callback,
"response_type": "code", "response_type": "code",
} }
state = self.get_application_state(request, callback) state = self.get_application_state()
if state is not None: if state is not None:
args["state"] = state args["state"] = state
request.session[self.session_key] = state self.request.session[self.session_key] = state
return args return args
def parse_raw_token(self, raw_token): def parse_raw_token(self, raw_token: str) -> Tuple[str, Optional[str]]:
"Parse token and secret from raw token response." "Parse token and secret from raw token response."
if raw_token is None:
return (None, None)
# Load as json first then parse as query string # Load as json first then parse as query string
try: try:
token_data = json.loads(raw_token) token_data = json.loads(raw_token)
except ValueError: except ValueError:
token = parse_qs(raw_token).get("access_token", [None])[0] token = parse_qs(raw_token)["access_token"][0]
else: else:
token = token_data.get("access_token", None) token = token_data["access_token"]
return (token, None) return (token, None)
def request(self, method, url, **kwargs): def do_request(self, method: str, url: str, **kwargs) -> Response:
"Build remote url request. Constructs necessary auth." "Build remote url request. Constructs necessary auth."
user_token = kwargs.pop("token", self.token) user_token = kwargs.pop("token", self.token)
token, _ = self.parse_raw_token(user_token) token, _ = self.parse_raw_token(user_token)
@ -110,7 +107,7 @@ class OAuth2Client(BaseOAuthClient):
params = kwargs.get("params", {}) params = kwargs.get("params", {})
params["access_token"] = token params["access_token"] = token
kwargs["params"] = params kwargs["params"] = params
return super(OAuth2Client, self).session.request(method, url, **kwargs) return super().do_request(method, url, **kwargs)
@property @property
def session_key(self): def session_key(self):

View File

@ -0,0 +1,5 @@
from passbook.lib.sentry import SentryIgnoredException
class OAuthSourceException(SentryIgnoredException):
"""General Error during OAuth Flow occurred"""

View File

@ -27,9 +27,7 @@ class RedditOAuth2Client(OAuth2Client):
def get_access_token(self, request, callback=None, **request_kwargs): def get_access_token(self, request, callback=None, **request_kwargs):
"Fetch access token from callback request." "Fetch access token from callback request."
auth = HTTPBasicAuth(self.source.consumer_key, self.source.consumer_secret) auth = HTTPBasicAuth(self.source.consumer_key, self.source.consumer_secret)
return super(RedditOAuth2Client, self).get_access_token( return super().get_access_token(auth=auth)
request, callback, auth=auth
)
@MANAGER.source(kind=RequestKind.callback, name="reddit") @MANAGER.source(kind=RequestKind.callback, name="reddit")

View File

@ -1,6 +1,8 @@
"""OAuth Base views""" """OAuth Base views"""
from typing import Optional, Type from typing import Optional, Type
from django.http.request import HttpRequest
from passbook.sources.oauth.clients.base import BaseOAuthClient from passbook.sources.oauth.clients.base import BaseOAuthClient
from passbook.sources.oauth.clients.oauth1 import OAuthClient from passbook.sources.oauth.clients.oauth1 import OAuthClient
from passbook.sources.oauth.clients.oauth2 import OAuth2Client from passbook.sources.oauth.clients.oauth2 import OAuth2Client
@ -11,13 +13,15 @@ from passbook.sources.oauth.models import OAuthSource
class OAuthClientMixin: class OAuthClientMixin:
"Mixin for getting OAuth client for a source." "Mixin for getting OAuth client for a source."
request: HttpRequest # Set by View class
client_class: Optional[Type[BaseOAuthClient]] = None client_class: Optional[Type[BaseOAuthClient]] = None
def get_client(self, source: OAuthSource) -> BaseOAuthClient: def get_client(self, source: OAuthSource, **kwargs) -> BaseOAuthClient:
"Get instance of the OAuth client for this source." "Get instance of the OAuth client for this source."
if self.client_class is not None: if self.client_class is not None:
# pylint: disable=not-callable # pylint: disable=not-callable
return self.client_class(source) return self.client_class(source, self.request, **kwargs)
if source.request_token_url: if source.request_token_url:
return OAuthClient(source) return OAuthClient(source, self.request, **kwargs)
return OAuth2Client(source) return OAuth2Client(source, self.request, **kwargs)

View File

@ -54,7 +54,7 @@ class OAuthCallback(OAuthClientMixin, View):
client = self.get_client(self.source) client = self.get_client(self.source)
callback = self.get_callback_url(self.source) callback = self.get_callback_url(self.source)
# Fetch access token # Fetch access token
token = client.get_access_token(self.request, callback=callback) token = client.get_access_token(callback=callback)
if token is None: if token is None:
return self.handle_login_failure(self.source, "Could not retrieve token.") return self.handle_login_failure(self.source, "Could not retrieve token.")
if "error" in token: if "error" in token:

View File

@ -40,9 +40,6 @@ class OAuthRedirect(OAuthClientMixin, RedirectView):
else: else:
if not source.enabled: if not source.enabled:
raise Http404(f"source {slug} is not enabled.") raise Http404(f"source {slug} is not enabled.")
client = self.get_client(source) client = self.get_client(source, callback=self.get_callback_url(source))
callback = self.get_callback_url(source)
params = self.get_additional_parameters(source) params = self.get_additional_parameters(source)
return client.get_redirect_url( return client.get_redirect_url(params)
self.request, callback=callback, parameters=params
)