sources: rewrite onboarding

Signed-off-by: Jens Langhammer <jens.langhammer@beryju.org>
This commit is contained in:
Jens Langhammer 2021-05-03 20:27:52 +02:00
parent e56c3fc54c
commit 35faf269db
26 changed files with 689 additions and 394 deletions

View file

@ -45,6 +45,7 @@ class SourceSerializer(ModelSerializer, MetaNameSerializer):
"verbose_name",
"verbose_name_plural",
"policy_engine_mode",
"user_matching_mode",
]

View file

@ -0,0 +1,40 @@
# Generated by Django 3.2 on 2021-05-03 17:06
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
("authentik_core", "0019_source_managed"),
]
operations = [
migrations.AddField(
model_name="source",
name="user_matching_mode",
field=models.TextField(
choices=[
("identifier", "Use the source-specific identifier"),
(
"email_link",
"Link to a user with identical email address. Can have security implications when a source doesn't validate email addresses.",
),
(
"email_deny",
"Use the user's email address, but deny enrollment when the email address already exists.",
),
(
"username_link",
"Link to a user with identical username address. Can have security implications when a username is used with another source.",
),
(
"username_deny",
"Use the user's username, but deny enrollment when the username already exists.",
),
],
default="identifier",
help_text="How the source determines if an existing user should be authenticated or a new user enrolled.",
),
),
]

View file

@ -240,6 +240,30 @@ class Application(PolicyBindingModel):
verbose_name_plural = _("Applications")
class SourceUserMatchingModes(models.TextChoices):
"""Different modes a source can handle new/returning users"""
IDENTIFIER = "identifier", _("Use the source-specific identifier")
EMAIL_LINK = "email_link", _(
(
"Link to a user with identical email address. Can have security implications "
"when a source doesn't validate email addresses."
)
)
EMAIL_DENY = "email_deny", _(
"Use the user's email address, but deny enrollment when the email address already exists."
)
USERNAME_LINK = "username_link", _(
(
"Link to a user with identical username address. Can have security implications "
"when a username is used with another source."
)
)
USERNAME_DENY = "username_deny", _(
"Use the user's username, but deny enrollment when the username already exists."
)
class Source(ManagedModel, SerializerModel, PolicyBindingModel):
"""Base Authentication source, i.e. an OAuth Provider, SAML Remote or LDAP Server"""
@ -272,6 +296,17 @@ class Source(ManagedModel, SerializerModel, PolicyBindingModel):
related_name="source_enrollment",
)
user_matching_mode = models.TextField(
choices=SourceUserMatchingModes.choices,
default=SourceUserMatchingModes.IDENTIFIER,
help_text=_(
(
"How the source determines if an existing user should be authenticated or "
"a new user enrolled."
)
),
)
objects = InheritanceManager()
@property
@ -301,6 +336,8 @@ class UserSourceConnection(CreatedUpdatedModel):
user = models.ForeignKey(User, on_delete=models.CASCADE)
source = models.ForeignKey(Source, on_delete=models.CASCADE)
objects = InheritanceManager()
class Meta:
unique_together = (("user", "source"),)

View file

View file

@ -0,0 +1,261 @@
"""Source decision helper"""
from enum import Enum
from typing import Any, Optional, Type
from django.contrib import messages
from django.db.models.query_utils import Q
from django.http import HttpRequest, HttpResponse, HttpResponseBadRequest
from django.shortcuts import redirect
from django.urls import reverse
from django.utils.translation import gettext as _
from structlog.stdlib import get_logger
from authentik.core.models import (
Source,
SourceUserMatchingModes,
User,
UserSourceConnection,
)
from authentik.core.sources.stage import (
PLAN_CONTEXT_SOURCES_CONNECTION,
PostUserEnrollmentStage,
)
from authentik.events.models import Event, EventAction
from authentik.flows.models import Flow, Stage, in_memory_stage
from authentik.flows.planner import (
PLAN_CONTEXT_PENDING_USER,
PLAN_CONTEXT_REDIRECT,
PLAN_CONTEXT_SOURCE,
PLAN_CONTEXT_SSO,
FlowPlanner,
)
from authentik.flows.views import NEXT_ARG_NAME, SESSION_KEY_GET, SESSION_KEY_PLAN
from authentik.lib.utils.urls import redirect_with_qs
from authentik.policies.utils import delete_none_keys
from authentik.stages.password.stage import PLAN_CONTEXT_AUTHENTICATION_BACKEND
from authentik.stages.prompt.stage import PLAN_CONTEXT_PROMPT
class Action(Enum):
"""Actions that can be decided based on the request
and source settings"""
LINK = "link"
AUTH = "auth"
ENROLL = "enroll"
DENY = "deny"
class SourceFlowManager:
"""Help sources decide what they should do after authorization. Based on source settings and
previous connections, authenticate the user, enroll a new user, link to an existing user
or deny the request."""
source: Source
request: HttpRequest
identifier: str
connection_type: Type[UserSourceConnection] = UserSourceConnection
def __init__(
self,
source: Source,
request: HttpRequest,
identifier: str,
enroll_info: dict[str, Any],
) -> None:
self.source = source
self.request = request
self.identifier = identifier
self.enroll_info = enroll_info
self._logger = get_logger().bind(source=source, identifier=identifier)
# pylint: disable=too-many-return-statements
def get_action(self, **kwargs) -> tuple[Action, Optional[UserSourceConnection]]:
"""decide which action should be taken"""
new_connection = self.connection_type(
source=self.source, identifier=self.identifier
)
# When request is authenticated, always link
if self.request.user.is_authenticated:
new_connection.user = self.request.user
new_connection = self.update_connection(new_connection, **kwargs)
new_connection.save()
return Action.LINK, new_connection
existing_connections = self.connection_type.objects.filter(
source=self.source, identifier=self.identifier
)
if existing_connections.exists():
connection = existing_connections.first()
return Action.AUTH, self.update_connection(connection, **kwargs)
# No connection exists, but we match on identifier, so enroll
if self.source.user_matching_mode == SourceUserMatchingModes.IDENTIFIER:
# We don't save the connection here cause it doesn't have a user assigned yet
return Action.ENROLL, self.update_connection(new_connection, **kwargs)
# Check for existing users with matching attributes
query = Q()
# Either query existing user based on email or username
if self.source.user_matching_mode in [
SourceUserMatchingModes.EMAIL_LINK,
SourceUserMatchingModes.EMAIL_DENY,
]:
if not self.enroll_info.get("email", None):
self._logger.warning("Refusing to use none email", source=self.source)
return Action.DENY, None
query = Q(email__exact=self.enroll_info.get("email", None))
if self.source.user_matching_mode in [
SourceUserMatchingModes.USERNAME_LINK,
SourceUserMatchingModes.USERNAME_DENY,
]:
if not self.enroll_info.get("username", None):
self._logger.warning(
"Refusing to use none username", source=self.source
)
return Action.DENY, None
query = Q(username__exact=self.enroll_info.get("username", None))
matching_users = User.objects.filter(query)
# No matching users, always enroll
if not matching_users.exists():
return Action.ENROLL, self.update_connection(new_connection, **kwargs)
user = matching_users.first()
if self.source.user_matching_mode in [
SourceUserMatchingModes.EMAIL_LINK,
SourceUserMatchingModes.USERNAME_LINK,
]:
new_connection.user = user
new_connection = self.update_connection(new_connection, **kwargs)
new_connection.save()
return Action.LINK, new_connection
if self.source.user_matching_mode in [
SourceUserMatchingModes.EMAIL_DENY,
SourceUserMatchingModes.USERNAME_DENY,
]:
return Action.DENY, None
return Action.DENY, None
def update_connection(
self, connection: UserSourceConnection, **kwargs
) -> UserSourceConnection:
"""Optionally make changes to the connection after it is looked up/created."""
return connection
def get_flow(self, **kwargs) -> HttpResponse:
"""Get the flow response based on user_matching_mode"""
action, connection = self.get_action()
if action == Action.LINK:
self._logger.debug("Linking existing user")
return self.handle_existing_user_link()
if not connection:
return redirect("/")
if action == Action.AUTH:
self._logger.debug("Handling auth user")
return self.handle_auth_user(connection)
if action == Action.ENROLL:
self._logger.debug("Handling enrollment of new user")
return self.handle_enroll(connection)
return redirect("/")
# pylint: disable=unused-argument
def get_stages_to_append(self, flow: Flow) -> list[Stage]:
"""Hook to override stages which are appended to the flow"""
if flow.slug == self.source.enrollment_flow.slug:
return [
in_memory_stage(PostUserEnrollmentStage),
]
return []
def _handle_login_flow(self, flow: Flow, **kwargs) -> HttpResponse:
"""Prepare Authentication Plan, redirect user FlowExecutor"""
# Ensure redirect is carried through when user was trying to
# authorize application
final_redirect = self.request.session.get(SESSION_KEY_GET, {}).get(
NEXT_ARG_NAME, "authentik_core:if-admin"
)
kwargs.update(
{
# Since we authenticate the user by their token, they have no backend set
PLAN_CONTEXT_AUTHENTICATION_BACKEND: "django.contrib.auth.backends.ModelBackend",
PLAN_CONTEXT_SSO: True,
PLAN_CONTEXT_SOURCE: self.source,
PLAN_CONTEXT_REDIRECT: final_redirect,
}
)
if not flow:
return HttpResponseBadRequest()
# We run the Flow planner here so we can pass the Pending user in the context
planner = FlowPlanner(flow)
plan = planner.plan(self.request, kwargs)
for stage in self.get_stages_to_append(flow):
plan.append(stage)
self.request.session[SESSION_KEY_PLAN] = plan
return redirect_with_qs(
"authentik_core:if-flow",
self.request.GET,
flow_slug=flow.slug,
)
# pylint: disable=unused-argument
def handle_auth_user(
self,
connection: UserSourceConnection,
) -> HttpResponse:
"""Login user and redirect."""
messages.success(
self.request,
_(
"Successfully authenticated with %(source)s!"
% {"source": self.source.name}
),
)
flow_kwargs = {PLAN_CONTEXT_PENDING_USER: connection.user}
return self._handle_login_flow(self.source.authentication_flow, **flow_kwargs)
def handle_existing_user_link(
self,
) -> HttpResponse:
"""Handler when the user was already authenticated and linked an external source
to their account."""
Event.new(
EventAction.SOURCE_LINKED,
message="Linked Source",
source=self.source,
).from_http(self.request)
messages.success(
self.request,
_("Successfully linked %(source)s!" % {"source": self.source.name}),
)
return redirect(
reverse(
"authentik_core:if-admin",
)
+ f"#/user;page-{self.source.slug}"
)
def handle_enroll(
self,
connection: UserSourceConnection,
) -> HttpResponse:
"""User was not authenticated and previous request was not authenticated."""
messages.success(
self.request,
_(
"Successfully authenticated with %(source)s!"
% {"source": self.source.name}
),
)
# We run the Flow planner here so we can pass the Pending user in the context
if not self.source.enrollment_flow:
self._logger.warning("source has no enrollment flow")
return HttpResponseBadRequest()
return self._handle_login_flow(
self.source.enrollment_flow,
**{
PLAN_CONTEXT_PROMPT: delete_none_keys(self.enroll_info),
PLAN_CONTEXT_SOURCES_CONNECTION: connection,
},
)

View file

@ -1,32 +1,30 @@
"""OAuth Stages"""
"""Source flow manager stages"""
from django.http import HttpRequest, HttpResponse
from authentik.core.models import User
from authentik.core.models import User, UserSourceConnection
from authentik.events.models import Event, EventAction
from authentik.flows.planner import PLAN_CONTEXT_PENDING_USER
from authentik.flows.stage import StageView
from authentik.sources.oauth.models import UserOAuthSourceConnection
PLAN_CONTEXT_SOURCES_OAUTH_ACCESS = "sources_oauth_access"
PLAN_CONTEXT_SOURCES_CONNECTION = "goauthentik.io/sources/connection"
class PostUserEnrollmentStage(StageView):
"""Dynamically injected stage which saves the OAuth Connection after
"""Dynamically injected stage which saves the Connection after
the user has been enrolled."""
# pylint: disable=unused-argument
def get(self, request: HttpRequest, *args, **kwargs) -> HttpResponse:
"""Stage used after the user has been enrolled"""
access: UserOAuthSourceConnection = self.executor.plan.context[
PLAN_CONTEXT_SOURCES_OAUTH_ACCESS
connection: UserSourceConnection = self.executor.plan.context[
PLAN_CONTEXT_SOURCES_CONNECTION
]
user: User = self.executor.plan.context[PLAN_CONTEXT_PENDING_USER]
access.user = user
access.save()
UserOAuthSourceConnection.objects.filter(pk=access.pk).update(user=user)
connection.user = user
connection.save()
Event.new(
EventAction.SOURCE_LINKED,
message="Linked OAuth Source",
source=access.source,
message="Linked Source",
source=connection.source,
).from_http(self.request)
return self.executor.stage_ok()

View file

@ -0,0 +1,84 @@
# Generated by Django 3.2 on 2021-05-02 17:06
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
("authentik_policies_event_matcher", "0012_auto_20210323_1339"),
]
operations = [
migrations.AlterField(
model_name="eventmatcherpolicy",
name="app",
field=models.TextField(
blank=True,
choices=[
("authentik.admin", "authentik Admin"),
("authentik.api", "authentik API"),
("authentik.events", "authentik Events"),
("authentik.crypto", "authentik Crypto"),
("authentik.flows", "authentik Flows"),
("authentik.outposts", "authentik Outpost"),
("authentik.lib", "authentik lib"),
("authentik.policies", "authentik Policies"),
("authentik.policies.dummy", "authentik Policies.Dummy"),
(
"authentik.policies.event_matcher",
"authentik Policies.Event Matcher",
),
("authentik.policies.expiry", "authentik Policies.Expiry"),
("authentik.policies.expression", "authentik Policies.Expression"),
("authentik.policies.hibp", "authentik Policies.HaveIBeenPwned"),
("authentik.policies.password", "authentik Policies.Password"),
("authentik.policies.reputation", "authentik Policies.Reputation"),
("authentik.providers.proxy", "authentik Providers.Proxy"),
("authentik.providers.oauth2", "authentik Providers.OAuth2"),
("authentik.providers.saml", "authentik Providers.SAML"),
("authentik.recovery", "authentik Recovery"),
("authentik.sources.ldap", "authentik Sources.LDAP"),
("authentik.sources.oauth", "authentik Sources.OAuth"),
("authentik.sources.plex", "authentik Sources.Plex"),
("authentik.sources.saml", "authentik Sources.SAML"),
(
"authentik.stages.authenticator_static",
"authentik Stages.Authenticator.Static",
),
(
"authentik.stages.authenticator_totp",
"authentik Stages.Authenticator.TOTP",
),
(
"authentik.stages.authenticator_validate",
"authentik Stages.Authenticator.Validate",
),
(
"authentik.stages.authenticator_webauthn",
"authentik Stages.Authenticator.WebAuthn",
),
("authentik.stages.captcha", "authentik Stages.Captcha"),
("authentik.stages.consent", "authentik Stages.Consent"),
("authentik.stages.deny", "authentik Stages.Deny"),
("authentik.stages.dummy", "authentik Stages.Dummy"),
("authentik.stages.email", "authentik Stages.Email"),
(
"authentik.stages.identification",
"authentik Stages.Identification",
),
("authentik.stages.invitation", "authentik Stages.User Invitation"),
("authentik.stages.password", "authentik Stages.Password"),
("authentik.stages.prompt", "authentik Stages.Prompt"),
("authentik.stages.user_delete", "authentik Stages.User Delete"),
("authentik.stages.user_login", "authentik Stages.User Login"),
("authentik.stages.user_logout", "authentik Stages.User Logout"),
("authentik.stages.user_write", "authentik Stages.User Write"),
("authentik.core", "authentik Core"),
("authentik.managed", "authentik Managed"),
],
default="",
help_text="Match events created by selected application. When left empty, all applications are matched.",
),
),
]

View file

@ -1,23 +0,0 @@
"""authentik oauth_client Authorization backend"""
from typing import Optional
from django.contrib.auth.backends import ModelBackend
from django.http import HttpRequest
from authentik.core.models import User
from authentik.sources.oauth.models import OAuthSource, UserOAuthSourceConnection
class AuthorizedServiceBackend(ModelBackend):
"Authentication backend for users registered with remote OAuth provider."
def authenticate(
self, request: HttpRequest, source: OAuthSource, identifier: str
) -> Optional[User]:
"Fetch user for a given source by id."
access = UserOAuthSourceConnection.objects.filter(
source=source, identifier=identifier
).select_related("user")
if not access.exists():
return None
return access.first().user

View file

@ -1,7 +1,7 @@
"""Discord Type tests"""
from django.test import TestCase
from authentik.sources.oauth.models import OAuthSource, UserOAuthSourceConnection
from authentik.sources.oauth.models import OAuthSource
from authentik.sources.oauth.types.discord import DiscordOAuth2Callback
# https://discord.com/developers/docs/resources/user#user-object
@ -33,9 +33,7 @@ class TestTypeDiscord(TestCase):
def test_enroll_context(self):
"""Test discord Enrollment context"""
ak_context = DiscordOAuth2Callback().get_user_enroll_context(
self.source, UserOAuthSourceConnection(), DISCORD_USER
)
ak_context = DiscordOAuth2Callback().get_user_enroll_context(DISCORD_USER)
self.assertEqual(ak_context["username"], DISCORD_USER["username"])
self.assertEqual(ak_context["email"], DISCORD_USER["email"])
self.assertEqual(ak_context["name"], DISCORD_USER["username"])

View file

@ -1,7 +1,7 @@
"""GitHub Type tests"""
from django.test import TestCase
from authentik.sources.oauth.models import OAuthSource, UserOAuthSourceConnection
from authentik.sources.oauth.models import OAuthSource
from authentik.sources.oauth.types.github import GitHubOAuth2Callback
# https://developer.github.com/v3/users/#get-the-authenticated-user
@ -63,9 +63,7 @@ class TestTypeGitHub(TestCase):
def test_enroll_context(self):
"""Test GitHub Enrollment context"""
ak_context = GitHubOAuth2Callback().get_user_enroll_context(
self.source, UserOAuthSourceConnection(), GITHUB_USER
)
ak_context = GitHubOAuth2Callback().get_user_enroll_context(GITHUB_USER)
self.assertEqual(ak_context["username"], GITHUB_USER["login"])
self.assertEqual(ak_context["email"], GITHUB_USER["email"])
self.assertEqual(ak_context["name"], GITHUB_USER["name"])

View file

@ -1,7 +1,7 @@
"""google Type tests"""
from django.test import TestCase
from authentik.sources.oauth.models import OAuthSource, UserOAuthSourceConnection
from authentik.sources.oauth.models import OAuthSource
from authentik.sources.oauth.types.google import GoogleOAuth2Callback
# https://developers.google.com/identity/protocols/oauth2/openid-connect?hl=en
@ -32,9 +32,7 @@ class TestTypeGoogle(TestCase):
def test_enroll_context(self):
"""Test Google Enrollment context"""
ak_context = GoogleOAuth2Callback().get_user_enroll_context(
self.source, UserOAuthSourceConnection(), GOOGLE_USER
)
ak_context = GoogleOAuth2Callback().get_user_enroll_context(GOOGLE_USER)
self.assertEqual(ak_context["username"], GOOGLE_USER["email"])
self.assertEqual(ak_context["email"], GOOGLE_USER["email"])
self.assertEqual(ak_context["name"], GOOGLE_USER["name"])

View file

@ -1,7 +1,7 @@
"""Twitter Type tests"""
from django.test import Client, TestCase
from authentik.sources.oauth.models import OAuthSource, UserOAuthSourceConnection
from authentik.sources.oauth.models import OAuthSource
from authentik.sources.oauth.types.twitter import TwitterOAuthCallback
# https://developer.twitter.com/en/docs/twitter-api/v1/accounts-and-users/manage-account-settings/ \
@ -104,9 +104,7 @@ class TestTypeGitHub(TestCase):
def test_enroll_context(self):
"""Test Twitter Enrollment context"""
ak_context = TwitterOAuthCallback().get_user_enroll_context(
self.source, UserOAuthSourceConnection(), TWITTER_USER
)
ak_context = TwitterOAuthCallback().get_user_enroll_context(TWITTER_USER)
self.assertEqual(ak_context["username"], TWITTER_USER["screen_name"])
self.assertEqual(ak_context["email"], TWITTER_USER.get("email", None))
self.assertEqual(ak_context["name"], TWITTER_USER["name"])

View file

@ -2,7 +2,6 @@
from typing import Any, Optional
from uuid import UUID
from authentik.sources.oauth.models import OAuthSource, UserOAuthSourceConnection
from authentik.sources.oauth.types.manager import MANAGER, SourceType
from authentik.sources.oauth.views.callback import OAuthCallback
@ -10,7 +9,7 @@ from authentik.sources.oauth.views.callback import OAuthCallback
class AzureADOAuthCallback(OAuthCallback):
"""AzureAD OAuth2 Callback"""
def get_user_id(self, source: OAuthSource, info: dict[str, Any]) -> Optional[str]:
def get_user_id(self, info: dict[str, Any]) -> Optional[str]:
try:
return str(UUID(info.get("objectId")).int)
except TypeError:
@ -18,8 +17,6 @@ class AzureADOAuthCallback(OAuthCallback):
def get_user_enroll_context(
self,
source: OAuthSource,
access: UserOAuthSourceConnection,
info: dict[str, Any],
) -> dict[str, Any]:
mail = info.get("mail", None) or info.get("otherMails", [None])[0]

View file

@ -1,7 +1,6 @@
"""Discord OAuth Views"""
from typing import Any
from authentik.sources.oauth.models import OAuthSource, UserOAuthSourceConnection
from authentik.sources.oauth.types.manager import MANAGER, SourceType
from authentik.sources.oauth.views.callback import OAuthCallback
from authentik.sources.oauth.views.redirect import OAuthRedirect
@ -21,8 +20,6 @@ class DiscordOAuth2Callback(OAuthCallback):
def get_user_enroll_context(
self,
source: OAuthSource,
access: UserOAuthSourceConnection,
info: dict[str, Any],
) -> dict[str, Any]:
return {

View file

@ -4,7 +4,6 @@ from typing import Any, Optional
from facebook import GraphAPI
from authentik.sources.oauth.clients.oauth2 import OAuth2Client
from authentik.sources.oauth.models import OAuthSource, UserOAuthSourceConnection
from authentik.sources.oauth.types.manager import MANAGER, SourceType
from authentik.sources.oauth.views.callback import OAuthCallback
from authentik.sources.oauth.views.redirect import OAuthRedirect
@ -34,8 +33,6 @@ class FacebookOAuth2Callback(OAuthCallback):
def get_user_enroll_context(
self,
source: OAuthSource,
access: UserOAuthSourceConnection,
info: dict[str, Any],
) -> dict[str, Any]:
return {

View file

@ -1,7 +1,6 @@
"""GitHub OAuth Views"""
from typing import Any
from authentik.sources.oauth.models import OAuthSource, UserOAuthSourceConnection
from authentik.sources.oauth.types.manager import MANAGER, SourceType
from authentik.sources.oauth.views.callback import OAuthCallback
@ -11,8 +10,6 @@ class GitHubOAuth2Callback(OAuthCallback):
def get_user_enroll_context(
self,
source: OAuthSource,
access: UserOAuthSourceConnection,
info: dict[str, Any],
) -> dict[str, Any]:
return {

View file

@ -1,7 +1,6 @@
"""Google OAuth Views"""
from typing import Any
from authentik.sources.oauth.models import OAuthSource, UserOAuthSourceConnection
from authentik.sources.oauth.types.manager import MANAGER, SourceType
from authentik.sources.oauth.views.callback import OAuthCallback
from authentik.sources.oauth.views.redirect import OAuthRedirect
@ -21,8 +20,6 @@ class GoogleOAuth2Callback(OAuthCallback):
def get_user_enroll_context(
self,
source: OAuthSource,
access: UserOAuthSourceConnection,
info: dict[str, Any],
) -> dict[str, Any]:
return {

View file

@ -1,7 +1,7 @@
"""OpenID Connect OAuth Views"""
from typing import Any
from authentik.sources.oauth.models import OAuthSource, UserOAuthSourceConnection
from authentik.sources.oauth.models import OAuthSource
from authentik.sources.oauth.types.manager import MANAGER, SourceType
from authentik.sources.oauth.views.callback import OAuthCallback
from authentik.sources.oauth.views.redirect import OAuthRedirect
@ -19,13 +19,11 @@ class OpenIDConnectOAuthRedirect(OAuthRedirect):
class OpenIDConnectOAuth2Callback(OAuthCallback):
"""OpenIDConnect OAuth2 Callback"""
def get_user_id(self, source: OAuthSource, info: dict[str, str]) -> str:
def get_user_id(self, info: dict[str, str]) -> str:
return info.get("sub", "")
def get_user_enroll_context(
self,
source: OAuthSource,
access: UserOAuthSourceConnection,
info: dict[str, Any],
) -> dict[str, Any]:
return {

View file

@ -4,7 +4,6 @@ from typing import Any
from requests.auth import HTTPBasicAuth
from authentik.sources.oauth.clients.oauth2 import OAuth2Client
from authentik.sources.oauth.models import OAuthSource, UserOAuthSourceConnection
from authentik.sources.oauth.types.manager import MANAGER, SourceType
from authentik.sources.oauth.views.callback import OAuthCallback
from authentik.sources.oauth.views.redirect import OAuthRedirect
@ -36,8 +35,6 @@ class RedditOAuth2Callback(OAuthCallback):
def get_user_enroll_context(
self,
source: OAuthSource,
access: UserOAuthSourceConnection,
info: dict[str, Any],
) -> dict[str, Any]:
return {

View file

@ -1,7 +1,6 @@
"""Twitter OAuth Views"""
from typing import Any
from authentik.sources.oauth.models import OAuthSource, UserOAuthSourceConnection
from authentik.sources.oauth.types.manager import MANAGER, SourceType
from authentik.sources.oauth.views.callback import OAuthCallback
@ -11,8 +10,6 @@ class TwitterOAuthCallback(OAuthCallback):
def get_user_enroll_context(
self,
source: OAuthSource,
access: UserOAuthSourceConnection,
info: dict[str, Any],
) -> dict[str, Any]:
return {

View file

@ -4,35 +4,14 @@ from typing import Any, Optional
from django.conf import settings
from django.contrib import messages
from django.http import Http404, HttpRequest, HttpResponse
from django.http.response import HttpResponseBadRequest
from django.shortcuts import redirect
from django.urls import reverse
from django.utils.translation import gettext as _
from django.views.generic import View
from structlog.stdlib import get_logger
from authentik.core.models import User
from authentik.events.models import Event, EventAction
from authentik.flows.models import Flow, in_memory_stage
from authentik.flows.planner import (
PLAN_CONTEXT_PENDING_USER,
PLAN_CONTEXT_REDIRECT,
PLAN_CONTEXT_SOURCE,
PLAN_CONTEXT_SSO,
FlowPlanner,
)
from authentik.flows.views import NEXT_ARG_NAME, SESSION_KEY_GET, SESSION_KEY_PLAN
from authentik.lib.utils.urls import redirect_with_qs
from authentik.policies.utils import delete_none_keys
from authentik.sources.oauth.auth import AuthorizedServiceBackend
from authentik.core.sources.flow_manager import SourceFlowManager
from authentik.sources.oauth.models import OAuthSource, UserOAuthSourceConnection
from authentik.sources.oauth.views.base import OAuthClientMixin
from authentik.sources.oauth.views.flows import (
PLAN_CONTEXT_SOURCES_OAUTH_ACCESS,
PostUserEnrollmentStage,
)
from authentik.stages.password.stage import PLAN_CONTEXT_AUTHENTICATION_BACKEND
from authentik.stages.prompt.stage import PLAN_CONTEXT_PROMPT
LOGGER = get_logger()
@ -40,8 +19,7 @@ LOGGER = get_logger()
class OAuthCallback(OAuthClientMixin, View):
"Base OAuth callback view."
source_id = None
source = None
source: OAuthSource
# pylint: disable=too-many-return-statements
def get(self, request: HttpRequest, *_, **kwargs) -> HttpResponse:
@ -60,47 +38,27 @@ class OAuthCallback(OAuthClientMixin, View):
# Fetch access token
token = client.get_access_token()
if token is None:
return self.handle_login_failure(self.source, "Could not retrieve token.")
return self.handle_login_failure("Could not retrieve token.")
if "error" in token:
return self.handle_login_failure(self.source, token["error"])
return self.handle_login_failure(token["error"])
# Fetch profile info
info = client.get_profile_info(token)
if info is None:
return self.handle_login_failure(self.source, "Could not retrieve profile.")
identifier = self.get_user_id(self.source, info)
raw_info = client.get_profile_info(token)
if raw_info is None:
return self.handle_login_failure("Could not retrieve profile.")
identifier = self.get_user_id(raw_info)
if identifier is None:
return self.handle_login_failure(self.source, "Could not determine id.")
return self.handle_login_failure("Could not determine id.")
# Get or create access record
defaults = {
"access_token": token.get("access_token"),
}
existing = UserOAuthSourceConnection.objects.filter(
source=self.source, identifier=identifier
enroll_info = self.get_user_enroll_context(raw_info)
sfm = OAuthSourceFlowManager(
source=self.source,
request=self.request,
identifier=identifier,
enroll_info=enroll_info,
)
if existing.exists():
connection = existing.first()
connection.access_token = token.get("access_token")
UserOAuthSourceConnection.objects.filter(pk=connection.pk).update(
**defaults
)
else:
connection = UserOAuthSourceConnection(
source=self.source,
identifier=identifier,
access_token=token.get("access_token"),
)
user = AuthorizedServiceBackend().authenticate(
source=self.source, identifier=identifier, request=request
return sfm.get_flow(
token=token,
)
if user is None:
if self.request.user.is_authenticated:
LOGGER.debug("Linking existing user", source=self.source)
return self.handle_existing_user_link(self.source, connection, info)
LOGGER.debug("Handling enrollment of new user", source=self.source)
return self.handle_enroll(self.source, connection, info)
LOGGER.debug("Handling existing user", source=self.source)
return self.handle_existing_user(self.source, user, connection, info)
# pylint: disable=unused-argument
def get_callback_url(self, source: OAuthSource) -> str:
@ -114,132 +72,34 @@ class OAuthCallback(OAuthClientMixin, View):
def get_user_enroll_context(
self,
source: OAuthSource,
access: UserOAuthSourceConnection,
info: dict[str, Any],
) -> dict[str, Any]:
"""Create a dict of User data"""
raise NotImplementedError()
# pylint: disable=unused-argument
def get_user_id(
self, source: UserOAuthSourceConnection, info: dict[str, Any]
) -> Optional[str]:
def get_user_id(self, info: dict[str, Any]) -> Optional[str]:
"""Return unique identifier from the profile info."""
if "id" in info:
return info["id"]
return None
def handle_login_failure(self, source: OAuthSource, reason: str) -> HttpResponse:
def handle_login_failure(self, reason: str) -> HttpResponse:
"Message user and redirect on error."
LOGGER.warning("Authentication Failure", reason=reason)
messages.error(self.request, _("Authentication Failed."))
return redirect(self.get_error_redirect(source, reason))
return redirect(self.get_error_redirect(self.source, reason))
def handle_login_flow(
self, flow: Flow, *stages_to_append, **kwargs
) -> HttpResponse:
"""Prepare Authentication Plan, redirect user FlowExecutor"""
# Ensure redirect is carried through when user was trying to
# authorize application
final_redirect = self.request.session.get(SESSION_KEY_GET, {}).get(
NEXT_ARG_NAME, "authentik_core:if-admin"
)
kwargs.update(
{
# Since we authenticate the user by their token, they have no backend set
PLAN_CONTEXT_AUTHENTICATION_BACKEND: "django.contrib.auth.backends.ModelBackend",
PLAN_CONTEXT_SSO: True,
PLAN_CONTEXT_SOURCE: self.source,
PLAN_CONTEXT_REDIRECT: final_redirect,
}
)
if not flow:
return HttpResponseBadRequest()
# We run the Flow planner here so we can pass the Pending user in the context
planner = FlowPlanner(flow)
plan = planner.plan(self.request, kwargs)
for stage in stages_to_append:
plan.append(stage)
self.request.session[SESSION_KEY_PLAN] = plan
return redirect_with_qs(
"authentik_core:if-flow",
self.request.GET,
flow_slug=flow.slug,
)
# pylint: disable=unused-argument
def handle_existing_user(
self,
source: OAuthSource,
user: User,
access: UserOAuthSourceConnection,
info: dict[str, Any],
) -> HttpResponse:
"Login user and redirect."
messages.success(
self.request,
_(
"Successfully authenticated with %(source)s!"
% {"source": self.source.name}
),
)
flow_kwargs = {PLAN_CONTEXT_PENDING_USER: user}
return self.handle_login_flow(source.authentication_flow, **flow_kwargs)
class OAuthSourceFlowManager(SourceFlowManager):
"""Flow manager for oauth sources"""
def handle_existing_user_link(
self,
source: OAuthSource,
access: UserOAuthSourceConnection,
info: dict[str, Any],
) -> HttpResponse:
"""Handler when the user was already authenticated and linked an external source
to their account."""
# there's already a user logged in, just link them up
user = self.request.user
access.user = user
access.save()
UserOAuthSourceConnection.objects.filter(pk=access.pk).update(user=user)
Event.new(
EventAction.SOURCE_LINKED, message="Linked OAuth Source", source=source
).from_http(self.request)
messages.success(
self.request,
_("Successfully linked %(source)s!" % {"source": self.source.name}),
)
return redirect(
reverse(
"authentik_core:if-admin",
)
+ f"#/user;page-{self.source.slug}"
)
connection_type = UserOAuthSourceConnection
def handle_enroll(
self,
source: OAuthSource,
access: UserOAuthSourceConnection,
info: dict[str, Any],
) -> HttpResponse:
"""User was not authenticated and previous request was not authenticated."""
messages.success(
self.request,
_(
"Successfully authenticated with %(source)s!"
% {"source": self.source.name}
),
)
# We run the Flow planner here so we can pass the Pending user in the context
if not source.enrollment_flow:
LOGGER.warning("source has no enrollment flow", source=source)
return HttpResponseBadRequest()
return self.handle_login_flow(
source.enrollment_flow,
in_memory_stage(PostUserEnrollmentStage),
**{
PLAN_CONTEXT_PROMPT: delete_none_keys(
self.get_user_enroll_context(source, access, info)
),
PLAN_CONTEXT_SOURCES_OAUTH_ACCESS: access,
},
)
def update_connection(
self, connection: UserOAuthSourceConnection, token: dict[str, Any]
) -> UserOAuthSourceConnection:
"""Set the access_token on the connection"""
connection.access_token = token.get("access_token")
connection.save()
return connection

View file

@ -1,26 +1,22 @@
"""Plex Source Serializer"""
from urllib.parse import urlencode
from django.http import Http404
from django.shortcuts import get_object_or_404
from drf_yasg import openapi
from drf_yasg.utils import swagger_auto_schema
from requests import RequestException, get
from rest_framework.decorators import action
from rest_framework.fields import CharField
from rest_framework.permissions import AllowAny
from rest_framework.request import Request
from rest_framework.response import Response
from rest_framework.viewsets import ModelViewSet
from structlog.stdlib import get_logger
from authentik.api.decorators import permission_required
from authentik.core.api.sources import SourceSerializer
from authentik.core.api.utils import PassiveSerializer
from authentik.flows.challenge import ChallengeTypes, RedirectChallenge
from authentik.flows.challenge import RedirectChallenge
from authentik.flows.views import to_stage_response
from authentik.sources.plex.models import PlexSource
LOGGER = get_logger()
from authentik.sources.plex.plex import PlexAuth
class PlexSourceSerializer(SourceSerializer):
@ -72,29 +68,8 @@ class PlexSourceViewSet(ModelViewSet):
plex_token = request.data.get("plex_token", None)
if not plex_token:
raise Http404
qs = {"X-Plex-Token": plex_token, "X-Plex-Client-Identifier": source.client_id}
try:
response = get(
f"https://plex.tv/api/v2/resources?{urlencode(qs)}",
headers={"Accept": "application/json"},
)
response.raise_for_status()
except RequestException as exc:
LOGGER.warning("Unable to fetch user resources", exc=exc)
auth_api = PlexAuth(source, plex_token)
if not auth_api.check_server_overlap():
raise Http404
else:
resources: list[dict] = response.json()
for resource in resources:
if resource["provides"] != "server":
continue
if resource["clientIdentifier"] in source.allowed_servers:
LOGGER.info(
"Plex allowed access from server", name=resource["name"]
)
request.session["foo"] = "bar"
break
return Response(
RedirectChallenge(
{"type": ChallengeTypes.REDIRECT.value, "to": ""}
).data
)
response = auth_api.get_user_url(request)
return to_stage_response(request, response)

View file

@ -0,0 +1,38 @@
# Generated by Django 3.2 on 2021-05-03 17:06
import django.db.models.deletion
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
("authentik_core", "0020_source_user_matching_mode"),
("authentik_sources_plex", "0001_initial"),
]
operations = [
migrations.CreateModel(
name="PlexSourceConnection",
fields=[
(
"usersourceconnection_ptr",
models.OneToOneField(
auto_created=True,
on_delete=django.db.models.deletion.CASCADE,
parent_link=True,
primary_key=True,
serialize=False,
to="authentik_core.usersourceconnection",
),
),
("plex_token", models.TextField()),
("identifier", models.TextField()),
],
options={
"verbose_name": "User Plex Source Connection",
"verbose_name_plural": "User Plex Source Connections",
},
bases=("authentik_core.usersourceconnection",),
),
]

View file

@ -6,7 +6,7 @@ from django.utils.translation import gettext_lazy as _
from rest_framework.fields import CharField
from rest_framework.serializers import BaseSerializer
from authentik.core.models import Source
from authentik.core.models import Source, UserSourceConnection
from authentik.core.types import UILoginButton
from authentik.flows.challenge import Challenge, ChallengeTypes
@ -53,3 +53,15 @@ class PlexSource(Source):
verbose_name = _("Plex Source")
verbose_name_plural = _("Plex Sources")
class PlexSourceConnection(UserSourceConnection):
"""Connect user and plex source"""
plex_token = models.TextField()
identifier = models.TextField()
class Meta:
verbose_name = _("User Plex Source Connection")
verbose_name_plural = _("User Plex Source Connections")

View file

@ -1,136 +1,113 @@
"""Plex OAuth Views"""
from typing import Any, Optional
"""Plex Views"""
from urllib.parse import urlencode
from django.http.response import Http404
from requests import post
from requests.api import get
import requests
from django.http.request import HttpRequest
from django.http.response import Http404, HttpResponse
from requests.exceptions import RequestException
from structlog.stdlib import get_logger
from authentik import __version__
from authentik.sources.oauth.clients.oauth2 import OAuth2Client
from authentik.sources.oauth.models import OAuthSource, UserOAuthSourceConnection
from authentik.sources.oauth.types.manager import MANAGER, SourceType
from authentik.sources.oauth.views.callback import OAuthCallback
from authentik.sources.oauth.views.redirect import OAuthRedirect
from authentik.core.sources.flow_manager import SourceFlowManager
from authentik.sources.plex.models import PlexSource, PlexSourceConnection
LOGGER = get_logger()
SESSION_ID_KEY = "PLEX_ID"
SESSION_CODE_KEY = "PLEX_CODE"
DEFAULT_PAYLOAD = {
"X-Plex-Product": "authentik",
"X-Plex-Version": __version__,
"X-Plex-Device-Vendor": "BeryJu.org",
}
class PlexRedirect(OAuthRedirect):
"""Plex Auth redirect, get a pin then redirect to a URL to claim it"""
class PlexAuth:
"""Plex authentication utilities"""
headers = {}
_source: PlexSource
_token: str
def get_pin(self, **data) -> dict:
"""Get plex pin that the user will claim
https://forums.plex.tv/t/authenticating-with-plex/609370"""
return post(
"https://plex.tv/api/v2/pins.json?strong=true",
data=data,
headers=self.headers,
).json()
def get_redirect_url(self, **kwargs) -> str:
slug = kwargs.get("source_slug", "")
self.headers = {"Origin": self.request.build_absolute_uri("/")}
try:
source: OAuthSource = OAuthSource.objects.get(slug=slug)
except OAuthSource.DoesNotExist:
raise Http404(f"Unknown OAuth source '{slug}'.")
else:
payload = DEFAULT_PAYLOAD.copy()
payload["X-Plex-Client-Identifier"] = source.consumer_key
# Get a pin first
pin = self.get_pin(**payload)
LOGGER.debug("Got pin", **pin)
self.request.session[SESSION_ID_KEY] = pin["id"]
self.request.session[SESSION_CODE_KEY] = pin["code"]
qs = {
"clientID": source.consumer_key,
"code": pin["code"],
"forwardUrl": self.request.build_absolute_uri(
self.get_callback_url(source)
),
}
return f"https://app.plex.tv/auth#!?{urlencode(qs)}"
class PlexOAuthClient(OAuth2Client):
"""Retrive the plex token after authentication, then ask the plex API about user info"""
def check_application_state(self) -> bool:
return SESSION_ID_KEY in self.request.session
def get_access_token(self, **request_kwargs) -> Optional[dict[str, Any]]:
payload = dict(DEFAULT_PAYLOAD)
payload["X-Plex-Client-Identifier"] = self.source.consumer_key
payload["Accept"] = "application/json"
response = get(
f"https://plex.tv/api/v2/pins/{self.request.session[SESSION_ID_KEY]}",
headers=payload,
def __init__(self, source: PlexSource, token: str):
self._source = source
self._token = token
self._session = requests.Session()
self._session.headers.update(
{"Accept": "application/json", "Content-Type": "application/json"}
)
response.raise_for_status()
token = response.json()["authToken"]
return {"plex_token": token}
self._session.headers.update(self.headers)
def get_profile_info(self, token: dict[str, str]) -> Optional[dict[str, Any]]:
"Fetch user profile information."
qs = {"X-Plex-Token": token["plex_token"]}
print(token)
try:
response = self.do_request(
"get", f"https://plex.tv/users/account.json?{urlencode(qs)}"
)
response.raise_for_status()
except RequestException as exc:
LOGGER.warning("Unable to fetch user profile", exc=exc)
return None
else:
info = response.json()
return info.get("user", {})
class PlexOAuth2Callback(OAuthCallback):
"""Plex OAuth2 Callback"""
client_class = PlexOAuthClient
def get_user_id(
self, source: UserOAuthSourceConnection, info: dict[str, Any]
) -> Optional[str]:
return info.get("uuid")
def get_user_enroll_context(
self,
source: OAuthSource,
access: UserOAuthSourceConnection,
info: dict[str, Any],
) -> dict[str, Any]:
@property
def headers(self) -> dict[str, str]:
"""Get common headers"""
return {
"username": info.get("username"),
"email": info.get("email"),
"name": info.get("title"),
"X-Plex-Product": "authentik",
"X-Plex-Version": __version__,
"X-Plex-Device-Vendor": "BeryJu.org",
}
def get_resources(self) -> list[dict]:
"""Get all resources the plex-token has access to"""
qs = {
"X-Plex-Token": self._token,
"X-Plex-Client-Identifier": self._source.client_id,
}
response = self._session.get(
f"https://plex.tv/api/v2/resources?{urlencode(qs)}",
)
response.raise_for_status()
return response.json()
@MANAGER.type()
class PlexType(SourceType):
"""Plex Type definition"""
def get_user_info(self) -> tuple[dict, int]:
"""Get user info of the plex token"""
qs = {
"X-Plex-Token": self._token,
"X-Plex-Client-Identifier": self._source.client_id,
}
response = self._session.get(
f"https://plex.tv/api/v2/user?{urlencode(qs)}",
)
response.raise_for_status()
raw_user_info = response.json()
return {
"username": raw_user_info.get("username"),
"email": raw_user_info.get("email"),
"name": raw_user_info.get("title"),
}, raw_user_info.get("id")
redirect_view = PlexRedirect
callback_view = PlexOAuth2Callback
name = "Plex"
slug = "plex"
def check_server_overlap(self) -> bool:
"""Check if the plex-token has any server overlap with our configured servers"""
try:
resources = self.get_resources()
except RequestException as exc:
LOGGER.warning("Unable to fetch user resources", exc=exc)
raise Http404
else:
for resource in resources:
if resource["provides"] != "server":
continue
if resource["clientIdentifier"] in self._source.allowed_servers:
LOGGER.info(
"Plex allowed access from server", name=resource["name"]
)
return True
return False
authorization_url = ""
access_token_url = "" # nosec
profile_url = ""
def get_user_url(self, request: HttpRequest) -> HttpResponse:
"""Get a URL to a flow executor for either enrollment or authentication"""
user_info, identifier = self.get_user_info()
sfm = PlexSourceFlowManager(
source=self._source,
request=request,
identifier=str(identifier),
enroll_info=user_info,
)
return sfm.get_flow(plex_token=self._token)
class PlexSourceFlowManager(SourceFlowManager):
"""Flow manager for plex sources"""
connection_type = PlexSourceConnection
def update_connection(
self, connection: PlexSourceConnection, plex_token: str
) -> PlexSourceConnection:
"""Set the access_token on the connection"""
connection.plex_token = plex_token
connection.save()
return connection

View file

@ -17289,6 +17289,17 @@ definitions:
enum:
- all
- any
user_matching_mode:
title: User matching mode
description: How the source determines if an existing user should be authenticated
or a new user enrolled.
type: string
enum:
- identifier
- email_link
- email_deny
- username_link
- username_deny
UserSetting:
required:
- object_uid
@ -17369,6 +17380,17 @@ definitions:
enum:
- all
- any
user_matching_mode:
title: User matching mode
description: How the source determines if an existing user should be authenticated
or a new user enrolled.
type: string
enum:
- identifier
- email_link
- email_deny
- username_link
- username_deny
server_uri:
title: Server URI
type: string
@ -17549,6 +17571,17 @@ definitions:
enum:
- all
- any
user_matching_mode:
title: User matching mode
description: How the source determines if an existing user should be authenticated
or a new user enrolled.
type: string
enum:
- identifier
- email_link
- email_deny
- username_link
- username_deny
provider_type:
title: Provider type
type: string
@ -17678,6 +17711,17 @@ definitions:
enum:
- all
- any
user_matching_mode:
title: User matching mode
description: How the source determines if an existing user should be authenticated
or a new user enrolled.
type: string
enum:
- identifier
- email_link
- email_deny
- username_link
- username_deny
client_id:
title: Client id
type: string
@ -17792,6 +17836,17 @@ definitions:
enum:
- all
- any
user_matching_mode:
title: User matching mode
description: How the source determines if an existing user should be authenticated
or a new user enrolled.
type: string
enum:
- identifier
- email_link
- email_deny
- username_link
- username_deny
pre_authentication_flow:
title: Pre authentication flow
description: Flow used before authentication.
@ -18537,6 +18592,17 @@ definitions:
enabled:
title: Enabled
type: boolean
user_matching_mode:
title: User matching mode
description: How the source determines if an existing user should
be authenticated or a new user enrolled.
type: string
enum:
- identifier
- email_link
- email_deny
- username_link
- username_deny
authentication_flow:
title: Authentication flow
description: Flow to use when authenticating existing users.