sources/oauth: cleanup clients, add type annotations

This commit is contained in:
Jens Langhammer 2020-09-26 00:34:57 +02:00
parent 6e4ce8dbaa
commit d9c2b32cba
8 changed files with 94 additions and 85 deletions

View File

@ -1,34 +1,41 @@
"""OAuth Clients"""
from typing import TYPE_CHECKING, Any, Dict, Optional
from typing import Any, Dict, Optional, Tuple
from urllib.parse import urlencode
from django.http import HttpRequest
from requests import Session
from requests.exceptions import RequestException
from requests.models import Response
from structlog import get_logger
from passbook import __version__
from passbook.sources.oauth.models import OAuthSource
LOGGER = get_logger()
if TYPE_CHECKING:
from passbook.sources.oauth.models import OAuthSource
class BaseOAuthClient:
"""Base OAuth Client"""
session: Session
source: "OAuthSource"
source: OAuthSource
def __init__(self, source: "OAuthSource", token=""): # nosec
token: str
request: HttpRequest
callback: Optional[str]
def __init__(
self, source: OAuthSource, request: HttpRequest, callback: Optional[str] = None
):
self.source = source
self.token = token
self.token = ""
self.session = Session()
self.session.headers.update({"User-Agent": "passbook %s" % __version__})
self.request = request
self.callback = callback
self.session.headers.update({"User-Agent": f"passbook {__version__}"})
def get_access_token(
self, request: HttpRequest, callback=None
) -> Optional[Dict[str, Any]]:
def get_access_token(self, **request_kwargs) -> Optional[Dict[str, Any]]:
"Fetch access token from callback request."
raise NotImplementedError("Defined in a sub-class") # pragma: no cover
@ -48,24 +55,28 @@ class BaseOAuthClient:
else:
return response.json()
def get_redirect_args(self, request, callback) -> Dict[str, str]:
def get_redirect_args(self) -> Dict[str, str]:
"Get request parameters for redirect url."
raise NotImplementedError("Defined in a sub-class") # pragma: no cover
def get_redirect_url(self, request: HttpRequest, callback: str, parameters=None):
def get_redirect_url(self, parameters=None):
"Build authentication redirect url."
args = self.get_redirect_args(request, callback=callback)
args = self.get_redirect_args()
additional = parameters or {}
args.update(additional)
params = urlencode(args)
LOGGER.info("redirect args", **args)
return "{0}?{1}".format(self.source.authorization_url, params)
return f"{self.source.authorization_url}?{params}"
def parse_raw_token(self, raw_token):
def parse_raw_token(self, raw_token: str) -> Tuple[str, Optional[str]]:
"Parse token and secret from raw token response."
raise NotImplementedError("Defined in a sub-class") # pragma: no cover
def do_request(self, method: str, url: str, **kwargs) -> Response:
"""Wrapper around self.session.request, which can add special headers"""
return self.session.request(method, url, **kwargs)
@property
def session_key(self):
def session_key(self) -> str:
"""Return Session Key"""
raise NotImplementedError("Defined in a sub-class") # pragma: no cover

View File

@ -1,15 +1,18 @@
"""OAuth Clients"""
from typing import Dict, Optional
from typing import Any, Dict, Optional, Tuple
from urllib.parse import parse_qs
from django.db.models.expressions import Value
from django.http import HttpRequest
from django.utils.encoding import force_str
from requests.exceptions import RequestException
from requests.models import Response
from requests_oauthlib import OAuth1
from structlog import get_logger
from passbook import __version__
from passbook.sources.oauth.clients.base import BaseOAuthClient
from passbook.sources.oauth.exceptions import OAuthSourceException
LOGGER = get_logger()
@ -21,16 +24,14 @@ class OAuthClient(BaseOAuthClient):
"Accept": "application/json",
}
def get_access_token(
self, request: HttpRequest, callback=None
) -> Optional[Dict[str, str]]:
def get_access_token(self, **request_kwargs) -> Optional[Dict[str, Any]]:
"Fetch access token from callback request."
raw_token = request.session.get(self.session_key, None)
verifier = request.GET.get("oauth_verifier", None)
raw_token = self.request.session.get(self.session_key, None)
verifier = self.request.GET.get("oauth_verifier", None)
if raw_token is not None and verifier is not None:
data = {
"oauth_verifier": verifier,
"oauth_callback": callback,
"oauth_callback": self.callback,
"token": raw_token,
}
try:
@ -48,9 +49,9 @@ class OAuthClient(BaseOAuthClient):
return response.json()
return None
def get_request_token(self, request: HttpRequest, callback):
def get_request_token(self) -> str:
"Fetch the OAuth request token. Only required for OAuth 1.0."
callback = request.build_absolute_uri(callback)
callback = self.request.build_absolute_uri(self.callback)
try:
response = self.session.request(
"post",
@ -63,33 +64,29 @@ class OAuthClient(BaseOAuthClient):
)
response.raise_for_status()
except RequestException as exc:
LOGGER.warning("Unable to fetch request token", exc=exc)
return None
raise OAuthSourceException from exc
else:
return response.text
def get_redirect_args(self, request: HttpRequest, callback):
def get_redirect_args(self) -> Dict[str, Any]:
"Get request parameters for redirect url."
callback = force_str(request.build_absolute_uri(callback))
raw_token = self.get_request_token(request, callback)
token, secret = self.parse_raw_token(raw_token)
if token is not None and secret is not None:
request.session[self.session_key] = raw_token
callback = self.request.build_absolute_uri(self.callback)
raw_token = self.get_request_token()
token, _ = self.parse_raw_token(raw_token)
self.request.session[self.session_key] = raw_token
return {
"oauth_token": token,
"oauth_callback": callback,
}
def parse_raw_token(self, raw_token):
def parse_raw_token(self, raw_token: str) -> Tuple[str, Optional[str]]:
"Parse token and secret from raw token response."
if raw_token is None:
return (None, None)
query_string = parse_qs(raw_token)
token = query_string.get("oauth_token", [None])[0]
secret = query_string.get("oauth_token_secret", [None])[0]
token = query_string["oauth_token"][0]
secret = query_string["oauth_token_secret"][0]
return (token, secret)
def request(self, method, url, **kwargs):
def do_request(self, method: str, url: str, **kwargs) -> Response:
"Build remote url request. Constructs necessary auth."
user_token = kwargs.pop("token", self.token)
token, secret = self.parse_raw_token(user_token)
@ -104,8 +101,8 @@ class OAuthClient(BaseOAuthClient):
callback_uri=callback,
)
kwargs["auth"] = oauth
return super(OAuthClient, self).session.request(method, url, **kwargs)
return super().do_request(method, url, **kwargs)
@property
def session_key(self):
return "oauth-client-{0}-request-token".format(self.source.name)
def session_key(self) -> str:
return f"oauth-client-{self.source.name}-request-token"

View File

@ -1,11 +1,12 @@
"""OAuth Clients"""
import json
from typing import TYPE_CHECKING
from typing import TYPE_CHECKING, Any, Dict, Optional, Tuple
from urllib.parse import parse_qs
from django.http import HttpRequest
from django.utils.crypto import constant_time_compare, get_random_string
from requests.exceptions import RequestException
from requests.models import Response
from structlog import get_logger
from passbook import __version__
@ -23,11 +24,10 @@ class OAuth2Client(BaseOAuthClient):
"Accept": "application/json",
}
# pylint: disable=unused-argument
def check_application_state(self, request: HttpRequest, callback: str):
def check_application_state(self) -> bool:
"Check optional state parameter."
stored = request.session.get(self.session_key, None)
returned = request.GET.get("state", None)
stored = self.request.session.get(self.session_key, None)
returned = self.request.GET.get("state", None)
check = False
if stored is not None:
if returned is not None:
@ -35,21 +35,25 @@ class OAuth2Client(BaseOAuthClient):
else:
LOGGER.warning("No state parameter returned by the source.")
else:
LOGGER.warning("No state stored in the sesssion.")
LOGGER.warning("No state stored in the session.")
return check
def get_access_token(self, request: HttpRequest, callback=None, **request_kwargs):
def get_application_state(self) -> str:
"Generate state optional parameter."
return get_random_string(32)
def get_access_token(self, **request_kwargs) -> Optional[Dict[str, Any]]:
"Fetch access token from callback request."
callback = request.build_absolute_uri(callback or request.path)
if not self.check_application_state(request, callback):
callback = self.request.build_absolute_uri(self.callback or self.request.path)
if not self.check_application_state():
LOGGER.warning("Application state check failed.")
return None
if "code" in request.GET:
if "code" in self.request.GET:
args = {
"client_id": self.source.consumer_key,
"redirect_uri": callback,
"client_secret": self.source.consumer_secret,
"code": request.GET["code"],
"code": self.request.GET["code"],
"grant_type": "authorization_code",
}
else:
@ -61,7 +65,6 @@ class OAuth2Client(BaseOAuthClient):
self.source.access_token_url,
data=args,
headers=self._default_headers,
**request_kwargs,
)
response.raise_for_status()
except RequestException as exc:
@ -70,39 +73,33 @@ class OAuth2Client(BaseOAuthClient):
else:
return response.json()
# pylint: disable=unused-argument
def get_application_state(self, request: HttpRequest, callback):
"Generate state optional parameter."
return get_random_string(32)
def get_redirect_args(self, request, callback):
def get_redirect_args(self) -> Dict[str, str]:
"Get request parameters for redirect url."
callback = request.build_absolute_uri(callback)
args = {
"client_id": self.source.consumer_key,
callback = self.request.build_absolute_uri(self.callback)
client_id: str = self.source.consumer_key
args: Dict[str, str] = {
"client_id": client_id,
"redirect_uri": callback,
"response_type": "code",
}
state = self.get_application_state(request, callback)
state = self.get_application_state()
if state is not None:
args["state"] = state
request.session[self.session_key] = state
self.request.session[self.session_key] = state
return args
def parse_raw_token(self, raw_token):
def parse_raw_token(self, raw_token: str) -> Tuple[str, Optional[str]]:
"Parse token and secret from raw token response."
if raw_token is None:
return (None, None)
# Load as json first then parse as query string
try:
token_data = json.loads(raw_token)
except ValueError:
token = parse_qs(raw_token).get("access_token", [None])[0]
token = parse_qs(raw_token)["access_token"][0]
else:
token = token_data.get("access_token", None)
token = token_data["access_token"]
return (token, None)
def request(self, method, url, **kwargs):
def do_request(self, method: str, url: str, **kwargs) -> Response:
"Build remote url request. Constructs necessary auth."
user_token = kwargs.pop("token", self.token)
token, _ = self.parse_raw_token(user_token)
@ -110,7 +107,7 @@ class OAuth2Client(BaseOAuthClient):
params = kwargs.get("params", {})
params["access_token"] = token
kwargs["params"] = params
return super(OAuth2Client, self).session.request(method, url, **kwargs)
return super().do_request(method, url, **kwargs)
@property
def session_key(self):

View File

@ -0,0 +1,5 @@
from passbook.lib.sentry import SentryIgnoredException
class OAuthSourceException(SentryIgnoredException):
"""General Error during OAuth Flow occurred"""

View File

@ -27,9 +27,7 @@ class RedditOAuth2Client(OAuth2Client):
def get_access_token(self, request, callback=None, **request_kwargs):
"Fetch access token from callback request."
auth = HTTPBasicAuth(self.source.consumer_key, self.source.consumer_secret)
return super(RedditOAuth2Client, self).get_access_token(
request, callback, auth=auth
)
return super().get_access_token(auth=auth)
@MANAGER.source(kind=RequestKind.callback, name="reddit")

View File

@ -1,6 +1,8 @@
"""OAuth Base views"""
from typing import Optional, Type
from django.http.request import HttpRequest
from passbook.sources.oauth.clients.base import BaseOAuthClient
from passbook.sources.oauth.clients.oauth1 import OAuthClient
from passbook.sources.oauth.clients.oauth2 import OAuth2Client
@ -11,13 +13,15 @@ from passbook.sources.oauth.models import OAuthSource
class OAuthClientMixin:
"Mixin for getting OAuth client for a source."
request: HttpRequest # Set by View class
client_class: Optional[Type[BaseOAuthClient]] = None
def get_client(self, source: OAuthSource) -> BaseOAuthClient:
def get_client(self, source: OAuthSource, **kwargs) -> BaseOAuthClient:
"Get instance of the OAuth client for this source."
if self.client_class is not None:
# pylint: disable=not-callable
return self.client_class(source)
return self.client_class(source, self.request, **kwargs)
if source.request_token_url:
return OAuthClient(source)
return OAuth2Client(source)
return OAuthClient(source, self.request, **kwargs)
return OAuth2Client(source, self.request, **kwargs)

View File

@ -54,7 +54,7 @@ class OAuthCallback(OAuthClientMixin, View):
client = self.get_client(self.source)
callback = self.get_callback_url(self.source)
# Fetch access token
token = client.get_access_token(self.request, callback=callback)
token = client.get_access_token(callback=callback)
if token is None:
return self.handle_login_failure(self.source, "Could not retrieve token.")
if "error" in token:

View File

@ -40,9 +40,6 @@ class OAuthRedirect(OAuthClientMixin, RedirectView):
else:
if not source.enabled:
raise Http404(f"source {slug} is not enabled.")
client = self.get_client(source)
callback = self.get_callback_url(source)
client = self.get_client(source, callback=self.get_callback_url(source))
params = self.get_additional_parameters(source)
return client.get_redirect_url(
self.request, callback=callback, parameters=params
)
return client.get_redirect_url(params)