diff --git a/passbook/sources/oauth/auth.py b/passbook/sources/oauth/auth.py index fb2445d89..0428365d8 100644 --- a/passbook/sources/oauth/auth.py +++ b/passbook/sources/oauth/auth.py @@ -1,6 +1,7 @@ """passbook oauth_client Authorization backend""" +from typing import Optional + from django.contrib.auth.backends import ModelBackend -from django.db.models import Q from django.http import HttpRequest from passbook.core.models import User @@ -12,12 +13,11 @@ class AuthorizedServiceBackend(ModelBackend): def authenticate( self, request: HttpRequest, source: OAuthSource, identifier: str - ) -> User: + ) -> Optional[User]: "Fetch user for a given source by id." - source_q = Q(source__name=source) - if isinstance(source, OAuthSource): - source_q = Q(source=source) access = UserOAuthSourceConnection.objects.filter( - source_q, identifier=identifier - ).select_related("user")[0] - return access.user + source=source, identifier=identifier + ).select_related("user") + if not access.exists(): + return None + return access.first().user diff --git a/passbook/sources/oauth/types/azure_ad.py b/passbook/sources/oauth/types/azure_ad.py index ce4044435..936c1708c 100644 --- a/passbook/sources/oauth/types/azure_ad.py +++ b/passbook/sources/oauth/types/azure_ad.py @@ -2,10 +2,8 @@ import uuid from typing import Any, Dict -from passbook.core.models import User from passbook.sources.oauth.models import OAuthSource, UserOAuthSourceConnection from passbook.sources.oauth.types.manager import MANAGER, RequestKind -from passbook.sources.oauth.utils import user_get_or_create from passbook.sources.oauth.views.core import OAuthCallback @@ -16,16 +14,15 @@ class AzureADOAuthCallback(OAuthCallback): def get_user_id(self, source: OAuthSource, info: Dict[str, Any]) -> str: return str(uuid.UUID(info.get("objectId")).int) - def get_or_create_user( + def get_user_enroll_context( self, source: OAuthSource, access: UserOAuthSourceConnection, info: Dict[str, Any], - ) -> User: - user_data = { + ) -> Dict[str, Any]: + mail = info.get("mail", None) or info.get("otherMails", [None])[0] + return { "username": info.get("displayName"), - "email": info.get("mail", None) or info.get("otherMails")[0], + "email": mail, "name": info.get("displayName"), - "password": None, } - return user_get_or_create(**user_data) diff --git a/passbook/sources/oauth/types/discord.py b/passbook/sources/oauth/types/discord.py index f82089cf5..ba5b61ce8 100644 --- a/passbook/sources/oauth/types/discord.py +++ b/passbook/sources/oauth/types/discord.py @@ -1,6 +1,8 @@ """Discord OAuth Views""" +from typing import Any, Dict + +from passbook.sources.oauth.models import OAuthSource, UserOAuthSourceConnection from passbook.sources.oauth.types.manager import MANAGER, RequestKind -from passbook.sources.oauth.utils import user_get_or_create from passbook.sources.oauth.views.core import OAuthCallback, OAuthRedirect @@ -18,12 +20,14 @@ class DiscordOAuthRedirect(OAuthRedirect): class DiscordOAuth2Callback(OAuthCallback): """Discord OAuth2 Callback""" - def get_or_create_user(self, source, access, info): - user_data = { + def get_user_enroll_context( + self, + source: OAuthSource, + access: UserOAuthSourceConnection, + info: Dict[str, Any], + ) -> Dict[str, Any]: + return { "username": info.get("username"), - "email": info.get("email", "None"), + "email": info.get("email", None), "name": info.get("username"), - "password": None, } - discord_user = user_get_or_create(**user_data) - return discord_user diff --git a/passbook/sources/oauth/types/facebook.py b/passbook/sources/oauth/types/facebook.py index 397b4adb0..c46557652 100644 --- a/passbook/sources/oauth/types/facebook.py +++ b/passbook/sources/oauth/types/facebook.py @@ -4,8 +4,8 @@ from typing import Any, Dict, Optional from facebook import GraphAPI from passbook.sources.oauth.clients import OAuth2Client +from passbook.sources.oauth.models import OAuthSource, UserOAuthSourceConnection from passbook.sources.oauth.types.manager import MANAGER, RequestKind -from passbook.sources.oauth.utils import user_get_or_create from passbook.sources.oauth.views.core import OAuthCallback, OAuthRedirect @@ -33,12 +33,14 @@ class FacebookOAuth2Callback(OAuthCallback): client_class = FacebookOAuth2Client - def get_or_create_user(self, source, access, info): - user_data = { + def get_user_enroll_context( + self, + source: OAuthSource, + access: UserOAuthSourceConnection, + info: Dict[str, Any], + ) -> Dict[str, Any]: + return { "username": info.get("name"), - "email": info.get("email", ""), + "email": info.get("email"), "name": info.get("name"), - "password": None, } - fb_user = user_get_or_create(**user_data) - return fb_user diff --git a/passbook/sources/oauth/types/github.py b/passbook/sources/oauth/types/github.py index 174fe046f..2d7a7181c 100644 --- a/passbook/sources/oauth/types/github.py +++ b/passbook/sources/oauth/types/github.py @@ -1,6 +1,8 @@ """GitHub OAuth Views""" +from typing import Any, Dict + +from passbook.sources.oauth.models import OAuthSource, UserOAuthSourceConnection from passbook.sources.oauth.types.manager import MANAGER, RequestKind -from passbook.sources.oauth.utils import user_get_or_create from passbook.sources.oauth.views.core import OAuthCallback @@ -8,12 +10,14 @@ from passbook.sources.oauth.views.core import OAuthCallback class GitHubOAuth2Callback(OAuthCallback): """GitHub OAuth2 Callback""" - def get_or_create_user(self, source, access, info): - user_data = { + def get_user_enroll_context( + self, + source: OAuthSource, + access: UserOAuthSourceConnection, + info: Dict[str, Any], + ) -> Dict[str, Any]: + return { "username": info.get("login"), - "email": info.get("email", ""), + "email": info.get("email"), "name": info.get("name"), - "password": None, } - gh_user = user_get_or_create(**user_data) - return gh_user diff --git a/passbook/sources/oauth/types/google.py b/passbook/sources/oauth/types/google.py index 4c721f438..5e79c6819 100644 --- a/passbook/sources/oauth/types/google.py +++ b/passbook/sources/oauth/types/google.py @@ -1,6 +1,8 @@ """Google OAuth Views""" +from typing import Any, Dict + +from passbook.sources.oauth.models import OAuthSource, UserOAuthSourceConnection from passbook.sources.oauth.types.manager import MANAGER, RequestKind -from passbook.sources.oauth.utils import user_get_or_create from passbook.sources.oauth.views.core import OAuthCallback, OAuthRedirect @@ -18,12 +20,14 @@ class GoogleOAuthRedirect(OAuthRedirect): class GoogleOAuth2Callback(OAuthCallback): """Google OAuth2 Callback""" - def get_or_create_user(self, source, access, info): - user_data = { + def get_user_enroll_context( + self, + source: OAuthSource, + access: UserOAuthSourceConnection, + info: Dict[str, Any], + ) -> Dict[str, Any]: + return { "username": info.get("email"), - "email": info.get("email", ""), + "email": info.get("email"), "name": info.get("name"), - "password": None, } - google_user = user_get_or_create(**user_data) - return google_user diff --git a/passbook/sources/oauth/types/manager.py b/passbook/sources/oauth/types/manager.py index 7200130b8..a84a65f37 100644 --- a/passbook/sources/oauth/types/manager.py +++ b/passbook/sources/oauth/types/manager.py @@ -48,7 +48,11 @@ class SourceTypeManager: if kind.value in self.__source_types: if source.provider_type in self.__source_types[kind.value]: return self.__source_types[kind.value][source.provider_type] - LOGGER.warning("no matching type found, using default") + LOGGER.warning( + "no matching type found, using default", + wanted=source.provider_type, + have=self.__source_types[kind.value].keys(), + ) # Return defaults if kind == RequestKind.callback: return OAuthCallback diff --git a/passbook/sources/oauth/types/oidc.py b/passbook/sources/oauth/types/oidc.py index 639d706c7..99dfe1522 100644 --- a/passbook/sources/oauth/types/oidc.py +++ b/passbook/sources/oauth/types/oidc.py @@ -1,9 +1,8 @@ """OpenID Connect OAuth Views""" -from typing import Dict +from typing import Any, Dict -from passbook.sources.oauth.models import OAuthSource +from passbook.sources.oauth.models import OAuthSource, UserOAuthSourceConnection from passbook.sources.oauth.types.manager import MANAGER, RequestKind -from passbook.sources.oauth.utils import user_get_or_create from passbook.sources.oauth.views.core import OAuthCallback, OAuthRedirect @@ -24,11 +23,14 @@ class OpenIDConnectOAuth2Callback(OAuthCallback): def get_user_id(self, source: OAuthSource, info: Dict[str, str]) -> str: return info.get("sub", "") - def get_or_create_user(self, source: OAuthSource, access, info: Dict[str, str]): - user_data = { + def get_user_enroll_context( + self, + source: OAuthSource, + access: UserOAuthSourceConnection, + info: Dict[str, Any], + ) -> Dict[str, Any]: + return { "username": info.get("nickname"), "email": info.get("email"), "name": info.get("name"), - "password": None, } - return user_get_or_create(**user_data) diff --git a/passbook/sources/oauth/types/reddit.py b/passbook/sources/oauth/types/reddit.py index 8384d6435..6ee13376f 100644 --- a/passbook/sources/oauth/types/reddit.py +++ b/passbook/sources/oauth/types/reddit.py @@ -1,9 +1,11 @@ """Reddit OAuth Views""" +from typing import Any, Dict + from requests.auth import HTTPBasicAuth from passbook.sources.oauth.clients import OAuth2Client +from passbook.sources.oauth.models import OAuthSource, UserOAuthSourceConnection from passbook.sources.oauth.types.manager import MANAGER, RequestKind -from passbook.sources.oauth.utils import user_get_or_create from passbook.sources.oauth.views.core import OAuthCallback, OAuthRedirect @@ -35,12 +37,15 @@ class RedditOAuth2Callback(OAuthCallback): client_class = RedditOAuth2Client - def get_or_create_user(self, source, access, info): - user_data = { + def get_user_enroll_context( + self, + source: OAuthSource, + access: UserOAuthSourceConnection, + info: Dict[str, Any], + ) -> Dict[str, Any]: + return { "username": info.get("name"), "email": None, "name": info.get("name"), "password": None, } - reddit_user = user_get_or_create(**user_data) - return reddit_user diff --git a/passbook/sources/oauth/types/twitter.py b/passbook/sources/oauth/types/twitter.py index 248f6c996..15934835f 100644 --- a/passbook/sources/oauth/types/twitter.py +++ b/passbook/sources/oauth/types/twitter.py @@ -1,6 +1,9 @@ """Twitter OAuth Views""" +from typing import Any, Dict + +from passbook.sources.oauth.models import OAuthSource, UserOAuthSourceConnection + # from passbook.sources.oauth.types.manager import MANAGER, RequestKind -from passbook.sources.oauth.utils import user_get_or_create from passbook.sources.oauth.views.core import OAuthCallback @@ -8,12 +11,14 @@ from passbook.sources.oauth.views.core import OAuthCallback class TwitterOAuthCallback(OAuthCallback): """Twitter OAuth2 Callback""" - def get_or_create_user(self, source, access, info): - user_data = { + def get_user_enroll_context( + self, + source: OAuthSource, + access: UserOAuthSourceConnection, + info: Dict[str, Any], + ) -> Dict[str, Any]: + return { "username": info.get("screen_name"), - "email": info.get("email", ""), + "email": info.get("email"), "name": info.get("name"), - "password": None, } - tw_user = user_get_or_create(**user_data) - return tw_user diff --git a/passbook/sources/oauth/utils.py b/passbook/sources/oauth/utils.py deleted file mode 100644 index a418368c3..000000000 --- a/passbook/sources/oauth/utils.py +++ /dev/null @@ -1,16 +0,0 @@ -"""OAuth Client User Creation Utils""" -from django.db.utils import IntegrityError - -from passbook.core.models import User - - -def user_get_or_create(**kwargs: str) -> User: - """Create user or return existing user""" - try: - new_user = User.objects.create_user(**kwargs) - except IntegrityError: - # At this point we've already checked that there is no existing connection - # to any user. Hence if we can't create the user, - kwargs["username"] = "%s_1" % kwargs["username"] - new_user = User.objects.create_user(**kwargs) - return new_user diff --git a/passbook/sources/oauth/views/core.py b/passbook/sources/oauth/views/core.py index 4cebf7431..23883de01 100644 --- a/passbook/sources/oauth/views/core.py +++ b/passbook/sources/oauth/views/core.py @@ -67,10 +67,10 @@ class OAuthRedirect(OAuthClientMixin, RedirectView): try: source = OAuthSource.objects.get(slug=slug) except OAuthSource.DoesNotExist: - raise Http404("Unknown OAuth source '%s'." % slug) + raise Http404(f"Unknown OAuth source '{slug}'.") else: if not source.enabled: - raise Http404("source %s is not enabled." % slug) + raise Http404(f"source {slug} is not enabled.") client = self.get_client(source) callback = self.get_callback_url(source) params = self.get_additional_parameters(source) @@ -138,8 +138,8 @@ class OAuthCallback(OAuthClientMixin, View): source=self.source, identifier=identifier, request=request ) if user is None: - LOGGER.debug("Handling new user", source=self.source) - return self.handle_new_user(self.source, connection, info) + LOGGER.debug("Handling new connection", source=self.source) + return self.handle_new_connection(self.source, connection, info) LOGGER.debug("Handling existing user", source=self.source) return self.handle_existing_user(self.source, user, connection, info) @@ -153,13 +153,13 @@ class OAuthCallback(OAuthClientMixin, View): "Return url to redirect on login failure." return settings.LOGIN_URL - def get_or_create_user( + def get_user_enroll_context( self, source: OAuthSource, access: UserOAuthSourceConnection, info: Dict[str, Any], - ) -> User: - "Create a shell auth.User." + ) -> Dict[str, Any]: + """Create a dict of User data""" raise NotImplementedError() # pylint: disable=unused-argument @@ -171,21 +171,19 @@ class OAuthCallback(OAuthClientMixin, View): return info["id"] return None - def handle_login_flow(self, flow: Optional[Flow], user: User) -> HttpResponse: + def handle_login_flow(self, flow: Flow, **kwargs) -> HttpResponse: """Prepare Authentication Plan, redirect user FlowExecutor""" - if not flow: - raise Http404 - # We run the Flow planner here so we can pass the Pending user in the context - planner = FlowPlanner(flow) - plan = planner.plan( - self.request, + kwargs.update( { - PLAN_CONTEXT_PENDING_USER: user, + # PLAN_CONTEXT_PENDING_USER: user, # Since we authenticate the user by their token, they have no backend set PLAN_CONTEXT_AUTHENTICATION_BACKEND: "django.contrib.auth.backends.ModelBackend", PLAN_CONTEXT_SSO: True, - }, + } ) + # We run the Flow planner here so we can pass the Pending user in the context + planner = FlowPlanner(flow) + plan = planner.plan(self.request, kwargs,) self.request.session[SESSION_KEY_PLAN] = plan return redirect_with_qs( "passbook_flows:flow-executor-shell", self.request.GET, flow_slug=flow.slug, @@ -207,10 +205,8 @@ class OAuthCallback(OAuthClientMixin, View): % {"source": self.source.name} ), ) - user = AuthorizedServiceBackend().authenticate( - source=access.source, identifier=access.identifier, request=self.request - ) - return self.handle_login_flow(source.authentication_flow, user) + flow_kwargs = {PLAN_CONTEXT_PENDING_USER: user} + return self.handle_login_flow(source.authentication_flow, **flow_kwargs) def handle_login_failure(self, source: OAuthSource, reason: str) -> HttpResponse: "Message user and redirect on error." @@ -218,27 +214,23 @@ class OAuthCallback(OAuthClientMixin, View): messages.error(self.request, _("Authentication Failed.")) return redirect(self.get_error_redirect(source, reason)) - def handle_new_user( + def handle_new_connection( self, source: OAuthSource, access: UserOAuthSourceConnection, info: Dict[str, Any], ) -> HttpResponse: - "Create a shell auth.User and redirect." - was_authenticated = False + """Check if a user exists for the connection and connect them, otherwise + prepare to enroll a new user.""" if self.request.user.is_authenticated: # there's already a user logged in, just link them up user = self.request.user - was_authenticated = True - else: - user = self.get_or_create_user(source, access, info) - access.user = user - access.save() - UserOAuthSourceConnection.objects.filter(pk=access.pk).update(user=user) - Event.new( - EventAction.CUSTOM, message="Linked OAuth Source", source=source - ).from_http(self.request) - if was_authenticated: + access.user = user + access.save() + UserOAuthSourceConnection.objects.filter(pk=access.pk).update(user=user) + Event.new( + EventAction.CUSTOM, message="Linked OAuth Source", source=source + ).from_http(self.request) messages.success( self.request, _("Successfully linked %(source)s!" % {"source": self.source.name}), @@ -249,10 +241,7 @@ class OAuthCallback(OAuthClientMixin, View): kwargs={"source_slug": self.source.slug}, ) ) - # User was not authenticated, new user has been created - user = AuthorizedServiceBackend().authenticate( - source=access.source, identifier=access.identifier, request=self.request - ) + # User was not authenticated, new user will be created messages.success( self.request, _( @@ -260,7 +249,8 @@ class OAuthCallback(OAuthClientMixin, View): % {"source": self.source.name} ), ) - return self.handle_login_flow(source.enrollment_flow, user) + context = self.get_user_enroll_context(source, access, info) + return self.handle_login_flow(source.enrollment_flow, **context) class DisconnectView(LoginRequiredMixin, View):