sources/oauth: rewrite to not directly create user, pre-seed data into flow

This commit is contained in:
Jens Langhammer 2020-07-08 20:36:48 +02:00
parent 0e3e73989d
commit d786fa4b7c
12 changed files with 119 additions and 118 deletions

View File

@ -1,6 +1,7 @@
"""passbook oauth_client Authorization backend""" """passbook oauth_client Authorization backend"""
from typing import Optional
from django.contrib.auth.backends import ModelBackend from django.contrib.auth.backends import ModelBackend
from django.db.models import Q
from django.http import HttpRequest from django.http import HttpRequest
from passbook.core.models import User from passbook.core.models import User
@ -12,12 +13,11 @@ class AuthorizedServiceBackend(ModelBackend):
def authenticate( def authenticate(
self, request: HttpRequest, source: OAuthSource, identifier: str self, request: HttpRequest, source: OAuthSource, identifier: str
) -> User: ) -> Optional[User]:
"Fetch user for a given source by id." "Fetch user for a given source by id."
source_q = Q(source__name=source)
if isinstance(source, OAuthSource):
source_q = Q(source=source)
access = UserOAuthSourceConnection.objects.filter( access = UserOAuthSourceConnection.objects.filter(
source_q, identifier=identifier source=source, identifier=identifier
).select_related("user")[0] ).select_related("user")
return access.user if not access.exists():
return None
return access.first().user

View File

@ -2,10 +2,8 @@
import uuid import uuid
from typing import Any, Dict from typing import Any, Dict
from passbook.core.models import User
from passbook.sources.oauth.models import OAuthSource, UserOAuthSourceConnection from passbook.sources.oauth.models import OAuthSource, UserOAuthSourceConnection
from passbook.sources.oauth.types.manager import MANAGER, RequestKind from passbook.sources.oauth.types.manager import MANAGER, RequestKind
from passbook.sources.oauth.utils import user_get_or_create
from passbook.sources.oauth.views.core import OAuthCallback from passbook.sources.oauth.views.core import OAuthCallback
@ -16,16 +14,15 @@ class AzureADOAuthCallback(OAuthCallback):
def get_user_id(self, source: OAuthSource, info: Dict[str, Any]) -> str: def get_user_id(self, source: OAuthSource, info: Dict[str, Any]) -> str:
return str(uuid.UUID(info.get("objectId")).int) return str(uuid.UUID(info.get("objectId")).int)
def get_or_create_user( def get_user_enroll_context(
self, self,
source: OAuthSource, source: OAuthSource,
access: UserOAuthSourceConnection, access: UserOAuthSourceConnection,
info: Dict[str, Any], info: Dict[str, Any],
) -> User: ) -> Dict[str, Any]:
user_data = { mail = info.get("mail", None) or info.get("otherMails", [None])[0]
return {
"username": info.get("displayName"), "username": info.get("displayName"),
"email": info.get("mail", None) or info.get("otherMails")[0], "email": mail,
"name": info.get("displayName"), "name": info.get("displayName"),
"password": None,
} }
return user_get_or_create(**user_data)

View File

@ -1,6 +1,8 @@
"""Discord OAuth Views""" """Discord OAuth Views"""
from typing import Any, Dict
from passbook.sources.oauth.models import OAuthSource, UserOAuthSourceConnection
from passbook.sources.oauth.types.manager import MANAGER, RequestKind from passbook.sources.oauth.types.manager import MANAGER, RequestKind
from passbook.sources.oauth.utils import user_get_or_create
from passbook.sources.oauth.views.core import OAuthCallback, OAuthRedirect from passbook.sources.oauth.views.core import OAuthCallback, OAuthRedirect
@ -18,12 +20,14 @@ class DiscordOAuthRedirect(OAuthRedirect):
class DiscordOAuth2Callback(OAuthCallback): class DiscordOAuth2Callback(OAuthCallback):
"""Discord OAuth2 Callback""" """Discord OAuth2 Callback"""
def get_or_create_user(self, source, access, info): def get_user_enroll_context(
user_data = { self,
source: OAuthSource,
access: UserOAuthSourceConnection,
info: Dict[str, Any],
) -> Dict[str, Any]:
return {
"username": info.get("username"), "username": info.get("username"),
"email": info.get("email", "None"), "email": info.get("email", None),
"name": info.get("username"), "name": info.get("username"),
"password": None,
} }
discord_user = user_get_or_create(**user_data)
return discord_user

View File

@ -4,8 +4,8 @@ from typing import Any, Dict, Optional
from facebook import GraphAPI from facebook import GraphAPI
from passbook.sources.oauth.clients import OAuth2Client from passbook.sources.oauth.clients import OAuth2Client
from passbook.sources.oauth.models import OAuthSource, UserOAuthSourceConnection
from passbook.sources.oauth.types.manager import MANAGER, RequestKind from passbook.sources.oauth.types.manager import MANAGER, RequestKind
from passbook.sources.oauth.utils import user_get_or_create
from passbook.sources.oauth.views.core import OAuthCallback, OAuthRedirect from passbook.sources.oauth.views.core import OAuthCallback, OAuthRedirect
@ -33,12 +33,14 @@ class FacebookOAuth2Callback(OAuthCallback):
client_class = FacebookOAuth2Client client_class = FacebookOAuth2Client
def get_or_create_user(self, source, access, info): def get_user_enroll_context(
user_data = { self,
source: OAuthSource,
access: UserOAuthSourceConnection,
info: Dict[str, Any],
) -> Dict[str, Any]:
return {
"username": info.get("name"), "username": info.get("name"),
"email": info.get("email", ""), "email": info.get("email"),
"name": info.get("name"), "name": info.get("name"),
"password": None,
} }
fb_user = user_get_or_create(**user_data)
return fb_user

View File

@ -1,6 +1,8 @@
"""GitHub OAuth Views""" """GitHub OAuth Views"""
from typing import Any, Dict
from passbook.sources.oauth.models import OAuthSource, UserOAuthSourceConnection
from passbook.sources.oauth.types.manager import MANAGER, RequestKind from passbook.sources.oauth.types.manager import MANAGER, RequestKind
from passbook.sources.oauth.utils import user_get_or_create
from passbook.sources.oauth.views.core import OAuthCallback from passbook.sources.oauth.views.core import OAuthCallback
@ -8,12 +10,14 @@ from passbook.sources.oauth.views.core import OAuthCallback
class GitHubOAuth2Callback(OAuthCallback): class GitHubOAuth2Callback(OAuthCallback):
"""GitHub OAuth2 Callback""" """GitHub OAuth2 Callback"""
def get_or_create_user(self, source, access, info): def get_user_enroll_context(
user_data = { self,
source: OAuthSource,
access: UserOAuthSourceConnection,
info: Dict[str, Any],
) -> Dict[str, Any]:
return {
"username": info.get("login"), "username": info.get("login"),
"email": info.get("email", ""), "email": info.get("email"),
"name": info.get("name"), "name": info.get("name"),
"password": None,
} }
gh_user = user_get_or_create(**user_data)
return gh_user

View File

@ -1,6 +1,8 @@
"""Google OAuth Views""" """Google OAuth Views"""
from typing import Any, Dict
from passbook.sources.oauth.models import OAuthSource, UserOAuthSourceConnection
from passbook.sources.oauth.types.manager import MANAGER, RequestKind from passbook.sources.oauth.types.manager import MANAGER, RequestKind
from passbook.sources.oauth.utils import user_get_or_create
from passbook.sources.oauth.views.core import OAuthCallback, OAuthRedirect from passbook.sources.oauth.views.core import OAuthCallback, OAuthRedirect
@ -18,12 +20,14 @@ class GoogleOAuthRedirect(OAuthRedirect):
class GoogleOAuth2Callback(OAuthCallback): class GoogleOAuth2Callback(OAuthCallback):
"""Google OAuth2 Callback""" """Google OAuth2 Callback"""
def get_or_create_user(self, source, access, info): def get_user_enroll_context(
user_data = { self,
source: OAuthSource,
access: UserOAuthSourceConnection,
info: Dict[str, Any],
) -> Dict[str, Any]:
return {
"username": info.get("email"), "username": info.get("email"),
"email": info.get("email", ""), "email": info.get("email"),
"name": info.get("name"), "name": info.get("name"),
"password": None,
} }
google_user = user_get_or_create(**user_data)
return google_user

View File

@ -48,7 +48,11 @@ class SourceTypeManager:
if kind.value in self.__source_types: if kind.value in self.__source_types:
if source.provider_type in self.__source_types[kind.value]: if source.provider_type in self.__source_types[kind.value]:
return self.__source_types[kind.value][source.provider_type] return self.__source_types[kind.value][source.provider_type]
LOGGER.warning("no matching type found, using default") LOGGER.warning(
"no matching type found, using default",
wanted=source.provider_type,
have=self.__source_types[kind.value].keys(),
)
# Return defaults # Return defaults
if kind == RequestKind.callback: if kind == RequestKind.callback:
return OAuthCallback return OAuthCallback

View File

@ -1,9 +1,8 @@
"""OpenID Connect OAuth Views""" """OpenID Connect OAuth Views"""
from typing import Dict from typing import Any, Dict
from passbook.sources.oauth.models import OAuthSource from passbook.sources.oauth.models import OAuthSource, UserOAuthSourceConnection
from passbook.sources.oauth.types.manager import MANAGER, RequestKind from passbook.sources.oauth.types.manager import MANAGER, RequestKind
from passbook.sources.oauth.utils import user_get_or_create
from passbook.sources.oauth.views.core import OAuthCallback, OAuthRedirect from passbook.sources.oauth.views.core import OAuthCallback, OAuthRedirect
@ -24,11 +23,14 @@ class OpenIDConnectOAuth2Callback(OAuthCallback):
def get_user_id(self, source: OAuthSource, info: Dict[str, str]) -> str: def get_user_id(self, source: OAuthSource, info: Dict[str, str]) -> str:
return info.get("sub", "") return info.get("sub", "")
def get_or_create_user(self, source: OAuthSource, access, info: Dict[str, str]): def get_user_enroll_context(
user_data = { self,
source: OAuthSource,
access: UserOAuthSourceConnection,
info: Dict[str, Any],
) -> Dict[str, Any]:
return {
"username": info.get("nickname"), "username": info.get("nickname"),
"email": info.get("email"), "email": info.get("email"),
"name": info.get("name"), "name": info.get("name"),
"password": None,
} }
return user_get_or_create(**user_data)

View File

@ -1,9 +1,11 @@
"""Reddit OAuth Views""" """Reddit OAuth Views"""
from typing import Any, Dict
from requests.auth import HTTPBasicAuth from requests.auth import HTTPBasicAuth
from passbook.sources.oauth.clients import OAuth2Client from passbook.sources.oauth.clients import OAuth2Client
from passbook.sources.oauth.models import OAuthSource, UserOAuthSourceConnection
from passbook.sources.oauth.types.manager import MANAGER, RequestKind from passbook.sources.oauth.types.manager import MANAGER, RequestKind
from passbook.sources.oauth.utils import user_get_or_create
from passbook.sources.oauth.views.core import OAuthCallback, OAuthRedirect from passbook.sources.oauth.views.core import OAuthCallback, OAuthRedirect
@ -35,12 +37,15 @@ class RedditOAuth2Callback(OAuthCallback):
client_class = RedditOAuth2Client client_class = RedditOAuth2Client
def get_or_create_user(self, source, access, info): def get_user_enroll_context(
user_data = { self,
source: OAuthSource,
access: UserOAuthSourceConnection,
info: Dict[str, Any],
) -> Dict[str, Any]:
return {
"username": info.get("name"), "username": info.get("name"),
"email": None, "email": None,
"name": info.get("name"), "name": info.get("name"),
"password": None, "password": None,
} }
reddit_user = user_get_or_create(**user_data)
return reddit_user

View File

@ -1,6 +1,9 @@
"""Twitter OAuth Views""" """Twitter OAuth Views"""
from typing import Any, Dict
from passbook.sources.oauth.models import OAuthSource, UserOAuthSourceConnection
# from passbook.sources.oauth.types.manager import MANAGER, RequestKind # from passbook.sources.oauth.types.manager import MANAGER, RequestKind
from passbook.sources.oauth.utils import user_get_or_create
from passbook.sources.oauth.views.core import OAuthCallback from passbook.sources.oauth.views.core import OAuthCallback
@ -8,12 +11,14 @@ from passbook.sources.oauth.views.core import OAuthCallback
class TwitterOAuthCallback(OAuthCallback): class TwitterOAuthCallback(OAuthCallback):
"""Twitter OAuth2 Callback""" """Twitter OAuth2 Callback"""
def get_or_create_user(self, source, access, info): def get_user_enroll_context(
user_data = { self,
source: OAuthSource,
access: UserOAuthSourceConnection,
info: Dict[str, Any],
) -> Dict[str, Any]:
return {
"username": info.get("screen_name"), "username": info.get("screen_name"),
"email": info.get("email", ""), "email": info.get("email"),
"name": info.get("name"), "name": info.get("name"),
"password": None,
} }
tw_user = user_get_or_create(**user_data)
return tw_user

View File

@ -1,16 +0,0 @@
"""OAuth Client User Creation Utils"""
from django.db.utils import IntegrityError
from passbook.core.models import User
def user_get_or_create(**kwargs: str) -> User:
"""Create user or return existing user"""
try:
new_user = User.objects.create_user(**kwargs)
except IntegrityError:
# At this point we've already checked that there is no existing connection
# to any user. Hence if we can't create the user,
kwargs["username"] = "%s_1" % kwargs["username"]
new_user = User.objects.create_user(**kwargs)
return new_user

View File

@ -67,10 +67,10 @@ class OAuthRedirect(OAuthClientMixin, RedirectView):
try: try:
source = OAuthSource.objects.get(slug=slug) source = OAuthSource.objects.get(slug=slug)
except OAuthSource.DoesNotExist: except OAuthSource.DoesNotExist:
raise Http404("Unknown OAuth source '%s'." % slug) raise Http404(f"Unknown OAuth source '{slug}'.")
else: else:
if not source.enabled: if not source.enabled:
raise Http404("source %s is not enabled." % slug) raise Http404(f"source {slug} is not enabled.")
client = self.get_client(source) client = self.get_client(source)
callback = self.get_callback_url(source) callback = self.get_callback_url(source)
params = self.get_additional_parameters(source) params = self.get_additional_parameters(source)
@ -138,8 +138,8 @@ class OAuthCallback(OAuthClientMixin, View):
source=self.source, identifier=identifier, request=request source=self.source, identifier=identifier, request=request
) )
if user is None: if user is None:
LOGGER.debug("Handling new user", source=self.source) LOGGER.debug("Handling new connection", source=self.source)
return self.handle_new_user(self.source, connection, info) return self.handle_new_connection(self.source, connection, info)
LOGGER.debug("Handling existing user", source=self.source) LOGGER.debug("Handling existing user", source=self.source)
return self.handle_existing_user(self.source, user, connection, info) return self.handle_existing_user(self.source, user, connection, info)
@ -153,13 +153,13 @@ class OAuthCallback(OAuthClientMixin, View):
"Return url to redirect on login failure." "Return url to redirect on login failure."
return settings.LOGIN_URL return settings.LOGIN_URL
def get_or_create_user( def get_user_enroll_context(
self, self,
source: OAuthSource, source: OAuthSource,
access: UserOAuthSourceConnection, access: UserOAuthSourceConnection,
info: Dict[str, Any], info: Dict[str, Any],
) -> User: ) -> Dict[str, Any]:
"Create a shell auth.User." """Create a dict of User data"""
raise NotImplementedError() raise NotImplementedError()
# pylint: disable=unused-argument # pylint: disable=unused-argument
@ -171,21 +171,19 @@ class OAuthCallback(OAuthClientMixin, View):
return info["id"] return info["id"]
return None return None
def handle_login_flow(self, flow: Optional[Flow], user: User) -> HttpResponse: def handle_login_flow(self, flow: Flow, **kwargs) -> HttpResponse:
"""Prepare Authentication Plan, redirect user FlowExecutor""" """Prepare Authentication Plan, redirect user FlowExecutor"""
if not flow: kwargs.update(
raise Http404
# We run the Flow planner here so we can pass the Pending user in the context
planner = FlowPlanner(flow)
plan = planner.plan(
self.request,
{ {
PLAN_CONTEXT_PENDING_USER: user, # PLAN_CONTEXT_PENDING_USER: user,
# Since we authenticate the user by their token, they have no backend set # Since we authenticate the user by their token, they have no backend set
PLAN_CONTEXT_AUTHENTICATION_BACKEND: "django.contrib.auth.backends.ModelBackend", PLAN_CONTEXT_AUTHENTICATION_BACKEND: "django.contrib.auth.backends.ModelBackend",
PLAN_CONTEXT_SSO: True, PLAN_CONTEXT_SSO: True,
}, }
) )
# We run the Flow planner here so we can pass the Pending user in the context
planner = FlowPlanner(flow)
plan = planner.plan(self.request, kwargs,)
self.request.session[SESSION_KEY_PLAN] = plan self.request.session[SESSION_KEY_PLAN] = plan
return redirect_with_qs( return redirect_with_qs(
"passbook_flows:flow-executor-shell", self.request.GET, flow_slug=flow.slug, "passbook_flows:flow-executor-shell", self.request.GET, flow_slug=flow.slug,
@ -207,10 +205,8 @@ class OAuthCallback(OAuthClientMixin, View):
% {"source": self.source.name} % {"source": self.source.name}
), ),
) )
user = AuthorizedServiceBackend().authenticate( flow_kwargs = {PLAN_CONTEXT_PENDING_USER: user}
source=access.source, identifier=access.identifier, request=self.request return self.handle_login_flow(source.authentication_flow, **flow_kwargs)
)
return self.handle_login_flow(source.authentication_flow, user)
def handle_login_failure(self, source: OAuthSource, reason: str) -> HttpResponse: def handle_login_failure(self, source: OAuthSource, reason: str) -> HttpResponse:
"Message user and redirect on error." "Message user and redirect on error."
@ -218,27 +214,23 @@ class OAuthCallback(OAuthClientMixin, View):
messages.error(self.request, _("Authentication Failed.")) messages.error(self.request, _("Authentication Failed."))
return redirect(self.get_error_redirect(source, reason)) return redirect(self.get_error_redirect(source, reason))
def handle_new_user( def handle_new_connection(
self, self,
source: OAuthSource, source: OAuthSource,
access: UserOAuthSourceConnection, access: UserOAuthSourceConnection,
info: Dict[str, Any], info: Dict[str, Any],
) -> HttpResponse: ) -> HttpResponse:
"Create a shell auth.User and redirect." """Check if a user exists for the connection and connect them, otherwise
was_authenticated = False prepare to enroll a new user."""
if self.request.user.is_authenticated: if self.request.user.is_authenticated:
# there's already a user logged in, just link them up # there's already a user logged in, just link them up
user = self.request.user user = self.request.user
was_authenticated = True
else:
user = self.get_or_create_user(source, access, info)
access.user = user access.user = user
access.save() access.save()
UserOAuthSourceConnection.objects.filter(pk=access.pk).update(user=user) UserOAuthSourceConnection.objects.filter(pk=access.pk).update(user=user)
Event.new( Event.new(
EventAction.CUSTOM, message="Linked OAuth Source", source=source EventAction.CUSTOM, message="Linked OAuth Source", source=source
).from_http(self.request) ).from_http(self.request)
if was_authenticated:
messages.success( messages.success(
self.request, self.request,
_("Successfully linked %(source)s!" % {"source": self.source.name}), _("Successfully linked %(source)s!" % {"source": self.source.name}),
@ -249,10 +241,7 @@ class OAuthCallback(OAuthClientMixin, View):
kwargs={"source_slug": self.source.slug}, kwargs={"source_slug": self.source.slug},
) )
) )
# User was not authenticated, new user has been created # User was not authenticated, new user will be created
user = AuthorizedServiceBackend().authenticate(
source=access.source, identifier=access.identifier, request=self.request
)
messages.success( messages.success(
self.request, self.request,
_( _(
@ -260,7 +249,8 @@ class OAuthCallback(OAuthClientMixin, View):
% {"source": self.source.name} % {"source": self.source.name}
), ),
) )
return self.handle_login_flow(source.enrollment_flow, user) context = self.get_user_enroll_context(source, access, info)
return self.handle_login_flow(source.enrollment_flow, **context)
class DisconnectView(LoginRequiredMixin, View): class DisconnectView(LoginRequiredMixin, View):