sources/oauth: rewrite to not directly create user, pre-seed data into flow

This commit is contained in:
Jens Langhammer 2020-07-08 20:36:48 +02:00
parent 0e3e73989d
commit d786fa4b7c
12 changed files with 119 additions and 118 deletions

View file

@ -1,6 +1,7 @@
"""passbook oauth_client Authorization backend"""
from typing import Optional
from django.contrib.auth.backends import ModelBackend
from django.db.models import Q
from django.http import HttpRequest
from passbook.core.models import User
@ -12,12 +13,11 @@ class AuthorizedServiceBackend(ModelBackend):
def authenticate(
self, request: HttpRequest, source: OAuthSource, identifier: str
) -> User:
) -> Optional[User]:
"Fetch user for a given source by id."
source_q = Q(source__name=source)
if isinstance(source, OAuthSource):
source_q = Q(source=source)
access = UserOAuthSourceConnection.objects.filter(
source_q, identifier=identifier
).select_related("user")[0]
return access.user
source=source, identifier=identifier
).select_related("user")
if not access.exists():
return None
return access.first().user

View file

@ -2,10 +2,8 @@
import uuid
from typing import Any, Dict
from passbook.core.models import User
from passbook.sources.oauth.models import OAuthSource, UserOAuthSourceConnection
from passbook.sources.oauth.types.manager import MANAGER, RequestKind
from passbook.sources.oauth.utils import user_get_or_create
from passbook.sources.oauth.views.core import OAuthCallback
@ -16,16 +14,15 @@ class AzureADOAuthCallback(OAuthCallback):
def get_user_id(self, source: OAuthSource, info: Dict[str, Any]) -> str:
return str(uuid.UUID(info.get("objectId")).int)
def get_or_create_user(
def get_user_enroll_context(
self,
source: OAuthSource,
access: UserOAuthSourceConnection,
info: Dict[str, Any],
) -> User:
user_data = {
) -> Dict[str, Any]:
mail = info.get("mail", None) or info.get("otherMails", [None])[0]
return {
"username": info.get("displayName"),
"email": info.get("mail", None) or info.get("otherMails")[0],
"email": mail,
"name": info.get("displayName"),
"password": None,
}
return user_get_or_create(**user_data)

View file

@ -1,6 +1,8 @@
"""Discord OAuth Views"""
from typing import Any, Dict
from passbook.sources.oauth.models import OAuthSource, UserOAuthSourceConnection
from passbook.sources.oauth.types.manager import MANAGER, RequestKind
from passbook.sources.oauth.utils import user_get_or_create
from passbook.sources.oauth.views.core import OAuthCallback, OAuthRedirect
@ -18,12 +20,14 @@ class DiscordOAuthRedirect(OAuthRedirect):
class DiscordOAuth2Callback(OAuthCallback):
"""Discord OAuth2 Callback"""
def get_or_create_user(self, source, access, info):
user_data = {
def get_user_enroll_context(
self,
source: OAuthSource,
access: UserOAuthSourceConnection,
info: Dict[str, Any],
) -> Dict[str, Any]:
return {
"username": info.get("username"),
"email": info.get("email", "None"),
"email": info.get("email", None),
"name": info.get("username"),
"password": None,
}
discord_user = user_get_or_create(**user_data)
return discord_user

View file

@ -4,8 +4,8 @@ from typing import Any, Dict, Optional
from facebook import GraphAPI
from passbook.sources.oauth.clients import OAuth2Client
from passbook.sources.oauth.models import OAuthSource, UserOAuthSourceConnection
from passbook.sources.oauth.types.manager import MANAGER, RequestKind
from passbook.sources.oauth.utils import user_get_or_create
from passbook.sources.oauth.views.core import OAuthCallback, OAuthRedirect
@ -33,12 +33,14 @@ class FacebookOAuth2Callback(OAuthCallback):
client_class = FacebookOAuth2Client
def get_or_create_user(self, source, access, info):
user_data = {
def get_user_enroll_context(
self,
source: OAuthSource,
access: UserOAuthSourceConnection,
info: Dict[str, Any],
) -> Dict[str, Any]:
return {
"username": info.get("name"),
"email": info.get("email", ""),
"email": info.get("email"),
"name": info.get("name"),
"password": None,
}
fb_user = user_get_or_create(**user_data)
return fb_user

View file

@ -1,6 +1,8 @@
"""GitHub OAuth Views"""
from typing import Any, Dict
from passbook.sources.oauth.models import OAuthSource, UserOAuthSourceConnection
from passbook.sources.oauth.types.manager import MANAGER, RequestKind
from passbook.sources.oauth.utils import user_get_or_create
from passbook.sources.oauth.views.core import OAuthCallback
@ -8,12 +10,14 @@ from passbook.sources.oauth.views.core import OAuthCallback
class GitHubOAuth2Callback(OAuthCallback):
"""GitHub OAuth2 Callback"""
def get_or_create_user(self, source, access, info):
user_data = {
def get_user_enroll_context(
self,
source: OAuthSource,
access: UserOAuthSourceConnection,
info: Dict[str, Any],
) -> Dict[str, Any]:
return {
"username": info.get("login"),
"email": info.get("email", ""),
"email": info.get("email"),
"name": info.get("name"),
"password": None,
}
gh_user = user_get_or_create(**user_data)
return gh_user

View file

@ -1,6 +1,8 @@
"""Google OAuth Views"""
from typing import Any, Dict
from passbook.sources.oauth.models import OAuthSource, UserOAuthSourceConnection
from passbook.sources.oauth.types.manager import MANAGER, RequestKind
from passbook.sources.oauth.utils import user_get_or_create
from passbook.sources.oauth.views.core import OAuthCallback, OAuthRedirect
@ -18,12 +20,14 @@ class GoogleOAuthRedirect(OAuthRedirect):
class GoogleOAuth2Callback(OAuthCallback):
"""Google OAuth2 Callback"""
def get_or_create_user(self, source, access, info):
user_data = {
def get_user_enroll_context(
self,
source: OAuthSource,
access: UserOAuthSourceConnection,
info: Dict[str, Any],
) -> Dict[str, Any]:
return {
"username": info.get("email"),
"email": info.get("email", ""),
"email": info.get("email"),
"name": info.get("name"),
"password": None,
}
google_user = user_get_or_create(**user_data)
return google_user

View file

@ -48,7 +48,11 @@ class SourceTypeManager:
if kind.value in self.__source_types:
if source.provider_type in self.__source_types[kind.value]:
return self.__source_types[kind.value][source.provider_type]
LOGGER.warning("no matching type found, using default")
LOGGER.warning(
"no matching type found, using default",
wanted=source.provider_type,
have=self.__source_types[kind.value].keys(),
)
# Return defaults
if kind == RequestKind.callback:
return OAuthCallback

View file

@ -1,9 +1,8 @@
"""OpenID Connect OAuth Views"""
from typing import Dict
from typing import Any, Dict
from passbook.sources.oauth.models import OAuthSource
from passbook.sources.oauth.models import OAuthSource, UserOAuthSourceConnection
from passbook.sources.oauth.types.manager import MANAGER, RequestKind
from passbook.sources.oauth.utils import user_get_or_create
from passbook.sources.oauth.views.core import OAuthCallback, OAuthRedirect
@ -24,11 +23,14 @@ class OpenIDConnectOAuth2Callback(OAuthCallback):
def get_user_id(self, source: OAuthSource, info: Dict[str, str]) -> str:
return info.get("sub", "")
def get_or_create_user(self, source: OAuthSource, access, info: Dict[str, str]):
user_data = {
def get_user_enroll_context(
self,
source: OAuthSource,
access: UserOAuthSourceConnection,
info: Dict[str, Any],
) -> Dict[str, Any]:
return {
"username": info.get("nickname"),
"email": info.get("email"),
"name": info.get("name"),
"password": None,
}
return user_get_or_create(**user_data)

View file

@ -1,9 +1,11 @@
"""Reddit OAuth Views"""
from typing import Any, Dict
from requests.auth import HTTPBasicAuth
from passbook.sources.oauth.clients import OAuth2Client
from passbook.sources.oauth.models import OAuthSource, UserOAuthSourceConnection
from passbook.sources.oauth.types.manager import MANAGER, RequestKind
from passbook.sources.oauth.utils import user_get_or_create
from passbook.sources.oauth.views.core import OAuthCallback, OAuthRedirect
@ -35,12 +37,15 @@ class RedditOAuth2Callback(OAuthCallback):
client_class = RedditOAuth2Client
def get_or_create_user(self, source, access, info):
user_data = {
def get_user_enroll_context(
self,
source: OAuthSource,
access: UserOAuthSourceConnection,
info: Dict[str, Any],
) -> Dict[str, Any]:
return {
"username": info.get("name"),
"email": None,
"name": info.get("name"),
"password": None,
}
reddit_user = user_get_or_create(**user_data)
return reddit_user

View file

@ -1,6 +1,9 @@
"""Twitter OAuth Views"""
from typing import Any, Dict
from passbook.sources.oauth.models import OAuthSource, UserOAuthSourceConnection
# from passbook.sources.oauth.types.manager import MANAGER, RequestKind
from passbook.sources.oauth.utils import user_get_or_create
from passbook.sources.oauth.views.core import OAuthCallback
@ -8,12 +11,14 @@ from passbook.sources.oauth.views.core import OAuthCallback
class TwitterOAuthCallback(OAuthCallback):
"""Twitter OAuth2 Callback"""
def get_or_create_user(self, source, access, info):
user_data = {
def get_user_enroll_context(
self,
source: OAuthSource,
access: UserOAuthSourceConnection,
info: Dict[str, Any],
) -> Dict[str, Any]:
return {
"username": info.get("screen_name"),
"email": info.get("email", ""),
"email": info.get("email"),
"name": info.get("name"),
"password": None,
}
tw_user = user_get_or_create(**user_data)
return tw_user

View file

@ -1,16 +0,0 @@
"""OAuth Client User Creation Utils"""
from django.db.utils import IntegrityError
from passbook.core.models import User
def user_get_or_create(**kwargs: str) -> User:
"""Create user or return existing user"""
try:
new_user = User.objects.create_user(**kwargs)
except IntegrityError:
# At this point we've already checked that there is no existing connection
# to any user. Hence if we can't create the user,
kwargs["username"] = "%s_1" % kwargs["username"]
new_user = User.objects.create_user(**kwargs)
return new_user

View file

@ -67,10 +67,10 @@ class OAuthRedirect(OAuthClientMixin, RedirectView):
try:
source = OAuthSource.objects.get(slug=slug)
except OAuthSource.DoesNotExist:
raise Http404("Unknown OAuth source '%s'." % slug)
raise Http404(f"Unknown OAuth source '{slug}'.")
else:
if not source.enabled:
raise Http404("source %s is not enabled." % slug)
raise Http404(f"source {slug} is not enabled.")
client = self.get_client(source)
callback = self.get_callback_url(source)
params = self.get_additional_parameters(source)
@ -138,8 +138,8 @@ class OAuthCallback(OAuthClientMixin, View):
source=self.source, identifier=identifier, request=request
)
if user is None:
LOGGER.debug("Handling new user", source=self.source)
return self.handle_new_user(self.source, connection, info)
LOGGER.debug("Handling new connection", source=self.source)
return self.handle_new_connection(self.source, connection, info)
LOGGER.debug("Handling existing user", source=self.source)
return self.handle_existing_user(self.source, user, connection, info)
@ -153,13 +153,13 @@ class OAuthCallback(OAuthClientMixin, View):
"Return url to redirect on login failure."
return settings.LOGIN_URL
def get_or_create_user(
def get_user_enroll_context(
self,
source: OAuthSource,
access: UserOAuthSourceConnection,
info: Dict[str, Any],
) -> User:
"Create a shell auth.User."
) -> Dict[str, Any]:
"""Create a dict of User data"""
raise NotImplementedError()
# pylint: disable=unused-argument
@ -171,21 +171,19 @@ class OAuthCallback(OAuthClientMixin, View):
return info["id"]
return None
def handle_login_flow(self, flow: Optional[Flow], user: User) -> HttpResponse:
def handle_login_flow(self, flow: Flow, **kwargs) -> HttpResponse:
"""Prepare Authentication Plan, redirect user FlowExecutor"""
if not flow:
raise Http404
# We run the Flow planner here so we can pass the Pending user in the context
planner = FlowPlanner(flow)
plan = planner.plan(
self.request,
kwargs.update(
{
PLAN_CONTEXT_PENDING_USER: user,
# PLAN_CONTEXT_PENDING_USER: user,
# Since we authenticate the user by their token, they have no backend set
PLAN_CONTEXT_AUTHENTICATION_BACKEND: "django.contrib.auth.backends.ModelBackend",
PLAN_CONTEXT_SSO: True,
},
}
)
# We run the Flow planner here so we can pass the Pending user in the context
planner = FlowPlanner(flow)
plan = planner.plan(self.request, kwargs,)
self.request.session[SESSION_KEY_PLAN] = plan
return redirect_with_qs(
"passbook_flows:flow-executor-shell", self.request.GET, flow_slug=flow.slug,
@ -207,10 +205,8 @@ class OAuthCallback(OAuthClientMixin, View):
% {"source": self.source.name}
),
)
user = AuthorizedServiceBackend().authenticate(
source=access.source, identifier=access.identifier, request=self.request
)
return self.handle_login_flow(source.authentication_flow, user)
flow_kwargs = {PLAN_CONTEXT_PENDING_USER: user}
return self.handle_login_flow(source.authentication_flow, **flow_kwargs)
def handle_login_failure(self, source: OAuthSource, reason: str) -> HttpResponse:
"Message user and redirect on error."
@ -218,27 +214,23 @@ class OAuthCallback(OAuthClientMixin, View):
messages.error(self.request, _("Authentication Failed."))
return redirect(self.get_error_redirect(source, reason))
def handle_new_user(
def handle_new_connection(
self,
source: OAuthSource,
access: UserOAuthSourceConnection,
info: Dict[str, Any],
) -> HttpResponse:
"Create a shell auth.User and redirect."
was_authenticated = False
"""Check if a user exists for the connection and connect them, otherwise
prepare to enroll a new user."""
if self.request.user.is_authenticated:
# there's already a user logged in, just link them up
user = self.request.user
was_authenticated = True
else:
user = self.get_or_create_user(source, access, info)
access.user = user
access.save()
UserOAuthSourceConnection.objects.filter(pk=access.pk).update(user=user)
Event.new(
EventAction.CUSTOM, message="Linked OAuth Source", source=source
).from_http(self.request)
if was_authenticated:
access.user = user
access.save()
UserOAuthSourceConnection.objects.filter(pk=access.pk).update(user=user)
Event.new(
EventAction.CUSTOM, message="Linked OAuth Source", source=source
).from_http(self.request)
messages.success(
self.request,
_("Successfully linked %(source)s!" % {"source": self.source.name}),
@ -249,10 +241,7 @@ class OAuthCallback(OAuthClientMixin, View):
kwargs={"source_slug": self.source.slug},
)
)
# User was not authenticated, new user has been created
user = AuthorizedServiceBackend().authenticate(
source=access.source, identifier=access.identifier, request=self.request
)
# User was not authenticated, new user will be created
messages.success(
self.request,
_(
@ -260,7 +249,8 @@ class OAuthCallback(OAuthClientMixin, View):
% {"source": self.source.name}
),
)
return self.handle_login_flow(source.enrollment_flow, user)
context = self.get_user_enroll_context(source, access, info)
return self.handle_login_flow(source.enrollment_flow, **context)
class DisconnectView(LoginRequiredMixin, View):