sources: rewrite onboarding

Signed-off-by: Jens Langhammer <jens.langhammer@beryju.org>
This commit is contained in:
Jens Langhammer 2021-05-03 20:27:52 +02:00
parent e56c3fc54c
commit 35faf269db
26 changed files with 689 additions and 394 deletions

View File

@ -45,6 +45,7 @@ class SourceSerializer(ModelSerializer, MetaNameSerializer):
"verbose_name", "verbose_name",
"verbose_name_plural", "verbose_name_plural",
"policy_engine_mode", "policy_engine_mode",
"user_matching_mode",
] ]

View File

@ -0,0 +1,40 @@
# Generated by Django 3.2 on 2021-05-03 17:06
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
("authentik_core", "0019_source_managed"),
]
operations = [
migrations.AddField(
model_name="source",
name="user_matching_mode",
field=models.TextField(
choices=[
("identifier", "Use the source-specific identifier"),
(
"email_link",
"Link to a user with identical email address. Can have security implications when a source doesn't validate email addresses.",
),
(
"email_deny",
"Use the user's email address, but deny enrollment when the email address already exists.",
),
(
"username_link",
"Link to a user with identical username address. Can have security implications when a username is used with another source.",
),
(
"username_deny",
"Use the user's username, but deny enrollment when the username already exists.",
),
],
default="identifier",
help_text="How the source determines if an existing user should be authenticated or a new user enrolled.",
),
),
]

View File

@ -240,6 +240,30 @@ class Application(PolicyBindingModel):
verbose_name_plural = _("Applications") verbose_name_plural = _("Applications")
class SourceUserMatchingModes(models.TextChoices):
"""Different modes a source can handle new/returning users"""
IDENTIFIER = "identifier", _("Use the source-specific identifier")
EMAIL_LINK = "email_link", _(
(
"Link to a user with identical email address. Can have security implications "
"when a source doesn't validate email addresses."
)
)
EMAIL_DENY = "email_deny", _(
"Use the user's email address, but deny enrollment when the email address already exists."
)
USERNAME_LINK = "username_link", _(
(
"Link to a user with identical username address. Can have security implications "
"when a username is used with another source."
)
)
USERNAME_DENY = "username_deny", _(
"Use the user's username, but deny enrollment when the username already exists."
)
class Source(ManagedModel, SerializerModel, PolicyBindingModel): class Source(ManagedModel, SerializerModel, PolicyBindingModel):
"""Base Authentication source, i.e. an OAuth Provider, SAML Remote or LDAP Server""" """Base Authentication source, i.e. an OAuth Provider, SAML Remote or LDAP Server"""
@ -272,6 +296,17 @@ class Source(ManagedModel, SerializerModel, PolicyBindingModel):
related_name="source_enrollment", related_name="source_enrollment",
) )
user_matching_mode = models.TextField(
choices=SourceUserMatchingModes.choices,
default=SourceUserMatchingModes.IDENTIFIER,
help_text=_(
(
"How the source determines if an existing user should be authenticated or "
"a new user enrolled."
)
),
)
objects = InheritanceManager() objects = InheritanceManager()
@property @property
@ -301,6 +336,8 @@ class UserSourceConnection(CreatedUpdatedModel):
user = models.ForeignKey(User, on_delete=models.CASCADE) user = models.ForeignKey(User, on_delete=models.CASCADE)
source = models.ForeignKey(Source, on_delete=models.CASCADE) source = models.ForeignKey(Source, on_delete=models.CASCADE)
objects = InheritanceManager()
class Meta: class Meta:
unique_together = (("user", "source"),) unique_together = (("user", "source"),)

View File

View File

@ -0,0 +1,261 @@
"""Source decision helper"""
from enum import Enum
from typing import Any, Optional, Type
from django.contrib import messages
from django.db.models.query_utils import Q
from django.http import HttpRequest, HttpResponse, HttpResponseBadRequest
from django.shortcuts import redirect
from django.urls import reverse
from django.utils.translation import gettext as _
from structlog.stdlib import get_logger
from authentik.core.models import (
Source,
SourceUserMatchingModes,
User,
UserSourceConnection,
)
from authentik.core.sources.stage import (
PLAN_CONTEXT_SOURCES_CONNECTION,
PostUserEnrollmentStage,
)
from authentik.events.models import Event, EventAction
from authentik.flows.models import Flow, Stage, in_memory_stage
from authentik.flows.planner import (
PLAN_CONTEXT_PENDING_USER,
PLAN_CONTEXT_REDIRECT,
PLAN_CONTEXT_SOURCE,
PLAN_CONTEXT_SSO,
FlowPlanner,
)
from authentik.flows.views import NEXT_ARG_NAME, SESSION_KEY_GET, SESSION_KEY_PLAN
from authentik.lib.utils.urls import redirect_with_qs
from authentik.policies.utils import delete_none_keys
from authentik.stages.password.stage import PLAN_CONTEXT_AUTHENTICATION_BACKEND
from authentik.stages.prompt.stage import PLAN_CONTEXT_PROMPT
class Action(Enum):
"""Actions that can be decided based on the request
and source settings"""
LINK = "link"
AUTH = "auth"
ENROLL = "enroll"
DENY = "deny"
class SourceFlowManager:
"""Help sources decide what they should do after authorization. Based on source settings and
previous connections, authenticate the user, enroll a new user, link to an existing user
or deny the request."""
source: Source
request: HttpRequest
identifier: str
connection_type: Type[UserSourceConnection] = UserSourceConnection
def __init__(
self,
source: Source,
request: HttpRequest,
identifier: str,
enroll_info: dict[str, Any],
) -> None:
self.source = source
self.request = request
self.identifier = identifier
self.enroll_info = enroll_info
self._logger = get_logger().bind(source=source, identifier=identifier)
# pylint: disable=too-many-return-statements
def get_action(self, **kwargs) -> tuple[Action, Optional[UserSourceConnection]]:
"""decide which action should be taken"""
new_connection = self.connection_type(
source=self.source, identifier=self.identifier
)
# When request is authenticated, always link
if self.request.user.is_authenticated:
new_connection.user = self.request.user
new_connection = self.update_connection(new_connection, **kwargs)
new_connection.save()
return Action.LINK, new_connection
existing_connections = self.connection_type.objects.filter(
source=self.source, identifier=self.identifier
)
if existing_connections.exists():
connection = existing_connections.first()
return Action.AUTH, self.update_connection(connection, **kwargs)
# No connection exists, but we match on identifier, so enroll
if self.source.user_matching_mode == SourceUserMatchingModes.IDENTIFIER:
# We don't save the connection here cause it doesn't have a user assigned yet
return Action.ENROLL, self.update_connection(new_connection, **kwargs)
# Check for existing users with matching attributes
query = Q()
# Either query existing user based on email or username
if self.source.user_matching_mode in [
SourceUserMatchingModes.EMAIL_LINK,
SourceUserMatchingModes.EMAIL_DENY,
]:
if not self.enroll_info.get("email", None):
self._logger.warning("Refusing to use none email", source=self.source)
return Action.DENY, None
query = Q(email__exact=self.enroll_info.get("email", None))
if self.source.user_matching_mode in [
SourceUserMatchingModes.USERNAME_LINK,
SourceUserMatchingModes.USERNAME_DENY,
]:
if not self.enroll_info.get("username", None):
self._logger.warning(
"Refusing to use none username", source=self.source
)
return Action.DENY, None
query = Q(username__exact=self.enroll_info.get("username", None))
matching_users = User.objects.filter(query)
# No matching users, always enroll
if not matching_users.exists():
return Action.ENROLL, self.update_connection(new_connection, **kwargs)
user = matching_users.first()
if self.source.user_matching_mode in [
SourceUserMatchingModes.EMAIL_LINK,
SourceUserMatchingModes.USERNAME_LINK,
]:
new_connection.user = user
new_connection = self.update_connection(new_connection, **kwargs)
new_connection.save()
return Action.LINK, new_connection
if self.source.user_matching_mode in [
SourceUserMatchingModes.EMAIL_DENY,
SourceUserMatchingModes.USERNAME_DENY,
]:
return Action.DENY, None
return Action.DENY, None
def update_connection(
self, connection: UserSourceConnection, **kwargs
) -> UserSourceConnection:
"""Optionally make changes to the connection after it is looked up/created."""
return connection
def get_flow(self, **kwargs) -> HttpResponse:
"""Get the flow response based on user_matching_mode"""
action, connection = self.get_action()
if action == Action.LINK:
self._logger.debug("Linking existing user")
return self.handle_existing_user_link()
if not connection:
return redirect("/")
if action == Action.AUTH:
self._logger.debug("Handling auth user")
return self.handle_auth_user(connection)
if action == Action.ENROLL:
self._logger.debug("Handling enrollment of new user")
return self.handle_enroll(connection)
return redirect("/")
# pylint: disable=unused-argument
def get_stages_to_append(self, flow: Flow) -> list[Stage]:
"""Hook to override stages which are appended to the flow"""
if flow.slug == self.source.enrollment_flow.slug:
return [
in_memory_stage(PostUserEnrollmentStage),
]
return []
def _handle_login_flow(self, flow: Flow, **kwargs) -> HttpResponse:
"""Prepare Authentication Plan, redirect user FlowExecutor"""
# Ensure redirect is carried through when user was trying to
# authorize application
final_redirect = self.request.session.get(SESSION_KEY_GET, {}).get(
NEXT_ARG_NAME, "authentik_core:if-admin"
)
kwargs.update(
{
# Since we authenticate the user by their token, they have no backend set
PLAN_CONTEXT_AUTHENTICATION_BACKEND: "django.contrib.auth.backends.ModelBackend",
PLAN_CONTEXT_SSO: True,
PLAN_CONTEXT_SOURCE: self.source,
PLAN_CONTEXT_REDIRECT: final_redirect,
}
)
if not flow:
return HttpResponseBadRequest()
# We run the Flow planner here so we can pass the Pending user in the context
planner = FlowPlanner(flow)
plan = planner.plan(self.request, kwargs)
for stage in self.get_stages_to_append(flow):
plan.append(stage)
self.request.session[SESSION_KEY_PLAN] = plan
return redirect_with_qs(
"authentik_core:if-flow",
self.request.GET,
flow_slug=flow.slug,
)
# pylint: disable=unused-argument
def handle_auth_user(
self,
connection: UserSourceConnection,
) -> HttpResponse:
"""Login user and redirect."""
messages.success(
self.request,
_(
"Successfully authenticated with %(source)s!"
% {"source": self.source.name}
),
)
flow_kwargs = {PLAN_CONTEXT_PENDING_USER: connection.user}
return self._handle_login_flow(self.source.authentication_flow, **flow_kwargs)
def handle_existing_user_link(
self,
) -> HttpResponse:
"""Handler when the user was already authenticated and linked an external source
to their account."""
Event.new(
EventAction.SOURCE_LINKED,
message="Linked Source",
source=self.source,
).from_http(self.request)
messages.success(
self.request,
_("Successfully linked %(source)s!" % {"source": self.source.name}),
)
return redirect(
reverse(
"authentik_core:if-admin",
)
+ f"#/user;page-{self.source.slug}"
)
def handle_enroll(
self,
connection: UserSourceConnection,
) -> HttpResponse:
"""User was not authenticated and previous request was not authenticated."""
messages.success(
self.request,
_(
"Successfully authenticated with %(source)s!"
% {"source": self.source.name}
),
)
# We run the Flow planner here so we can pass the Pending user in the context
if not self.source.enrollment_flow:
self._logger.warning("source has no enrollment flow")
return HttpResponseBadRequest()
return self._handle_login_flow(
self.source.enrollment_flow,
**{
PLAN_CONTEXT_PROMPT: delete_none_keys(self.enroll_info),
PLAN_CONTEXT_SOURCES_CONNECTION: connection,
},
)

View File

@ -1,32 +1,30 @@
"""OAuth Stages""" """Source flow manager stages"""
from django.http import HttpRequest, HttpResponse from django.http import HttpRequest, HttpResponse
from authentik.core.models import User from authentik.core.models import User, UserSourceConnection
from authentik.events.models import Event, EventAction from authentik.events.models import Event, EventAction
from authentik.flows.planner import PLAN_CONTEXT_PENDING_USER from authentik.flows.planner import PLAN_CONTEXT_PENDING_USER
from authentik.flows.stage import StageView from authentik.flows.stage import StageView
from authentik.sources.oauth.models import UserOAuthSourceConnection
PLAN_CONTEXT_SOURCES_OAUTH_ACCESS = "sources_oauth_access" PLAN_CONTEXT_SOURCES_CONNECTION = "goauthentik.io/sources/connection"
class PostUserEnrollmentStage(StageView): class PostUserEnrollmentStage(StageView):
"""Dynamically injected stage which saves the OAuth Connection after """Dynamically injected stage which saves the Connection after
the user has been enrolled.""" the user has been enrolled."""
# pylint: disable=unused-argument # pylint: disable=unused-argument
def get(self, request: HttpRequest, *args, **kwargs) -> HttpResponse: def get(self, request: HttpRequest, *args, **kwargs) -> HttpResponse:
"""Stage used after the user has been enrolled""" """Stage used after the user has been enrolled"""
access: UserOAuthSourceConnection = self.executor.plan.context[ connection: UserSourceConnection = self.executor.plan.context[
PLAN_CONTEXT_SOURCES_OAUTH_ACCESS PLAN_CONTEXT_SOURCES_CONNECTION
] ]
user: User = self.executor.plan.context[PLAN_CONTEXT_PENDING_USER] user: User = self.executor.plan.context[PLAN_CONTEXT_PENDING_USER]
access.user = user connection.user = user
access.save() connection.save()
UserOAuthSourceConnection.objects.filter(pk=access.pk).update(user=user)
Event.new( Event.new(
EventAction.SOURCE_LINKED, EventAction.SOURCE_LINKED,
message="Linked OAuth Source", message="Linked Source",
source=access.source, source=connection.source,
).from_http(self.request) ).from_http(self.request)
return self.executor.stage_ok() return self.executor.stage_ok()

View File

@ -0,0 +1,84 @@
# Generated by Django 3.2 on 2021-05-02 17:06
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
("authentik_policies_event_matcher", "0012_auto_20210323_1339"),
]
operations = [
migrations.AlterField(
model_name="eventmatcherpolicy",
name="app",
field=models.TextField(
blank=True,
choices=[
("authentik.admin", "authentik Admin"),
("authentik.api", "authentik API"),
("authentik.events", "authentik Events"),
("authentik.crypto", "authentik Crypto"),
("authentik.flows", "authentik Flows"),
("authentik.outposts", "authentik Outpost"),
("authentik.lib", "authentik lib"),
("authentik.policies", "authentik Policies"),
("authentik.policies.dummy", "authentik Policies.Dummy"),
(
"authentik.policies.event_matcher",
"authentik Policies.Event Matcher",
),
("authentik.policies.expiry", "authentik Policies.Expiry"),
("authentik.policies.expression", "authentik Policies.Expression"),
("authentik.policies.hibp", "authentik Policies.HaveIBeenPwned"),
("authentik.policies.password", "authentik Policies.Password"),
("authentik.policies.reputation", "authentik Policies.Reputation"),
("authentik.providers.proxy", "authentik Providers.Proxy"),
("authentik.providers.oauth2", "authentik Providers.OAuth2"),
("authentik.providers.saml", "authentik Providers.SAML"),
("authentik.recovery", "authentik Recovery"),
("authentik.sources.ldap", "authentik Sources.LDAP"),
("authentik.sources.oauth", "authentik Sources.OAuth"),
("authentik.sources.plex", "authentik Sources.Plex"),
("authentik.sources.saml", "authentik Sources.SAML"),
(
"authentik.stages.authenticator_static",
"authentik Stages.Authenticator.Static",
),
(
"authentik.stages.authenticator_totp",
"authentik Stages.Authenticator.TOTP",
),
(
"authentik.stages.authenticator_validate",
"authentik Stages.Authenticator.Validate",
),
(
"authentik.stages.authenticator_webauthn",
"authentik Stages.Authenticator.WebAuthn",
),
("authentik.stages.captcha", "authentik Stages.Captcha"),
("authentik.stages.consent", "authentik Stages.Consent"),
("authentik.stages.deny", "authentik Stages.Deny"),
("authentik.stages.dummy", "authentik Stages.Dummy"),
("authentik.stages.email", "authentik Stages.Email"),
(
"authentik.stages.identification",
"authentik Stages.Identification",
),
("authentik.stages.invitation", "authentik Stages.User Invitation"),
("authentik.stages.password", "authentik Stages.Password"),
("authentik.stages.prompt", "authentik Stages.Prompt"),
("authentik.stages.user_delete", "authentik Stages.User Delete"),
("authentik.stages.user_login", "authentik Stages.User Login"),
("authentik.stages.user_logout", "authentik Stages.User Logout"),
("authentik.stages.user_write", "authentik Stages.User Write"),
("authentik.core", "authentik Core"),
("authentik.managed", "authentik Managed"),
],
default="",
help_text="Match events created by selected application. When left empty, all applications are matched.",
),
),
]

View File

@ -1,23 +0,0 @@
"""authentik oauth_client Authorization backend"""
from typing import Optional
from django.contrib.auth.backends import ModelBackend
from django.http import HttpRequest
from authentik.core.models import User
from authentik.sources.oauth.models import OAuthSource, UserOAuthSourceConnection
class AuthorizedServiceBackend(ModelBackend):
"Authentication backend for users registered with remote OAuth provider."
def authenticate(
self, request: HttpRequest, source: OAuthSource, identifier: str
) -> Optional[User]:
"Fetch user for a given source by id."
access = UserOAuthSourceConnection.objects.filter(
source=source, identifier=identifier
).select_related("user")
if not access.exists():
return None
return access.first().user

View File

@ -1,7 +1,7 @@
"""Discord Type tests""" """Discord Type tests"""
from django.test import TestCase from django.test import TestCase
from authentik.sources.oauth.models import OAuthSource, UserOAuthSourceConnection from authentik.sources.oauth.models import OAuthSource
from authentik.sources.oauth.types.discord import DiscordOAuth2Callback from authentik.sources.oauth.types.discord import DiscordOAuth2Callback
# https://discord.com/developers/docs/resources/user#user-object # https://discord.com/developers/docs/resources/user#user-object
@ -33,9 +33,7 @@ class TestTypeDiscord(TestCase):
def test_enroll_context(self): def test_enroll_context(self):
"""Test discord Enrollment context""" """Test discord Enrollment context"""
ak_context = DiscordOAuth2Callback().get_user_enroll_context( ak_context = DiscordOAuth2Callback().get_user_enroll_context(DISCORD_USER)
self.source, UserOAuthSourceConnection(), DISCORD_USER
)
self.assertEqual(ak_context["username"], DISCORD_USER["username"]) self.assertEqual(ak_context["username"], DISCORD_USER["username"])
self.assertEqual(ak_context["email"], DISCORD_USER["email"]) self.assertEqual(ak_context["email"], DISCORD_USER["email"])
self.assertEqual(ak_context["name"], DISCORD_USER["username"]) self.assertEqual(ak_context["name"], DISCORD_USER["username"])

View File

@ -1,7 +1,7 @@
"""GitHub Type tests""" """GitHub Type tests"""
from django.test import TestCase from django.test import TestCase
from authentik.sources.oauth.models import OAuthSource, UserOAuthSourceConnection from authentik.sources.oauth.models import OAuthSource
from authentik.sources.oauth.types.github import GitHubOAuth2Callback from authentik.sources.oauth.types.github import GitHubOAuth2Callback
# https://developer.github.com/v3/users/#get-the-authenticated-user # https://developer.github.com/v3/users/#get-the-authenticated-user
@ -63,9 +63,7 @@ class TestTypeGitHub(TestCase):
def test_enroll_context(self): def test_enroll_context(self):
"""Test GitHub Enrollment context""" """Test GitHub Enrollment context"""
ak_context = GitHubOAuth2Callback().get_user_enroll_context( ak_context = GitHubOAuth2Callback().get_user_enroll_context(GITHUB_USER)
self.source, UserOAuthSourceConnection(), GITHUB_USER
)
self.assertEqual(ak_context["username"], GITHUB_USER["login"]) self.assertEqual(ak_context["username"], GITHUB_USER["login"])
self.assertEqual(ak_context["email"], GITHUB_USER["email"]) self.assertEqual(ak_context["email"], GITHUB_USER["email"])
self.assertEqual(ak_context["name"], GITHUB_USER["name"]) self.assertEqual(ak_context["name"], GITHUB_USER["name"])

View File

@ -1,7 +1,7 @@
"""google Type tests""" """google Type tests"""
from django.test import TestCase from django.test import TestCase
from authentik.sources.oauth.models import OAuthSource, UserOAuthSourceConnection from authentik.sources.oauth.models import OAuthSource
from authentik.sources.oauth.types.google import GoogleOAuth2Callback from authentik.sources.oauth.types.google import GoogleOAuth2Callback
# https://developers.google.com/identity/protocols/oauth2/openid-connect?hl=en # https://developers.google.com/identity/protocols/oauth2/openid-connect?hl=en
@ -32,9 +32,7 @@ class TestTypeGoogle(TestCase):
def test_enroll_context(self): def test_enroll_context(self):
"""Test Google Enrollment context""" """Test Google Enrollment context"""
ak_context = GoogleOAuth2Callback().get_user_enroll_context( ak_context = GoogleOAuth2Callback().get_user_enroll_context(GOOGLE_USER)
self.source, UserOAuthSourceConnection(), GOOGLE_USER
)
self.assertEqual(ak_context["username"], GOOGLE_USER["email"]) self.assertEqual(ak_context["username"], GOOGLE_USER["email"])
self.assertEqual(ak_context["email"], GOOGLE_USER["email"]) self.assertEqual(ak_context["email"], GOOGLE_USER["email"])
self.assertEqual(ak_context["name"], GOOGLE_USER["name"]) self.assertEqual(ak_context["name"], GOOGLE_USER["name"])

View File

@ -1,7 +1,7 @@
"""Twitter Type tests""" """Twitter Type tests"""
from django.test import Client, TestCase from django.test import Client, TestCase
from authentik.sources.oauth.models import OAuthSource, UserOAuthSourceConnection from authentik.sources.oauth.models import OAuthSource
from authentik.sources.oauth.types.twitter import TwitterOAuthCallback from authentik.sources.oauth.types.twitter import TwitterOAuthCallback
# https://developer.twitter.com/en/docs/twitter-api/v1/accounts-and-users/manage-account-settings/ \ # https://developer.twitter.com/en/docs/twitter-api/v1/accounts-and-users/manage-account-settings/ \
@ -104,9 +104,7 @@ class TestTypeGitHub(TestCase):
def test_enroll_context(self): def test_enroll_context(self):
"""Test Twitter Enrollment context""" """Test Twitter Enrollment context"""
ak_context = TwitterOAuthCallback().get_user_enroll_context( ak_context = TwitterOAuthCallback().get_user_enroll_context(TWITTER_USER)
self.source, UserOAuthSourceConnection(), TWITTER_USER
)
self.assertEqual(ak_context["username"], TWITTER_USER["screen_name"]) self.assertEqual(ak_context["username"], TWITTER_USER["screen_name"])
self.assertEqual(ak_context["email"], TWITTER_USER.get("email", None)) self.assertEqual(ak_context["email"], TWITTER_USER.get("email", None))
self.assertEqual(ak_context["name"], TWITTER_USER["name"]) self.assertEqual(ak_context["name"], TWITTER_USER["name"])

View File

@ -2,7 +2,6 @@
from typing import Any, Optional from typing import Any, Optional
from uuid import UUID from uuid import UUID
from authentik.sources.oauth.models import OAuthSource, UserOAuthSourceConnection
from authentik.sources.oauth.types.manager import MANAGER, SourceType from authentik.sources.oauth.types.manager import MANAGER, SourceType
from authentik.sources.oauth.views.callback import OAuthCallback from authentik.sources.oauth.views.callback import OAuthCallback
@ -10,7 +9,7 @@ from authentik.sources.oauth.views.callback import OAuthCallback
class AzureADOAuthCallback(OAuthCallback): class AzureADOAuthCallback(OAuthCallback):
"""AzureAD OAuth2 Callback""" """AzureAD OAuth2 Callback"""
def get_user_id(self, source: OAuthSource, info: dict[str, Any]) -> Optional[str]: def get_user_id(self, info: dict[str, Any]) -> Optional[str]:
try: try:
return str(UUID(info.get("objectId")).int) return str(UUID(info.get("objectId")).int)
except TypeError: except TypeError:
@ -18,8 +17,6 @@ class AzureADOAuthCallback(OAuthCallback):
def get_user_enroll_context( def get_user_enroll_context(
self, self,
source: OAuthSource,
access: UserOAuthSourceConnection,
info: dict[str, Any], info: dict[str, Any],
) -> dict[str, Any]: ) -> dict[str, Any]:
mail = info.get("mail", None) or info.get("otherMails", [None])[0] mail = info.get("mail", None) or info.get("otherMails", [None])[0]

View File

@ -1,7 +1,6 @@
"""Discord OAuth Views""" """Discord OAuth Views"""
from typing import Any from typing import Any
from authentik.sources.oauth.models import OAuthSource, UserOAuthSourceConnection
from authentik.sources.oauth.types.manager import MANAGER, SourceType from authentik.sources.oauth.types.manager import MANAGER, SourceType
from authentik.sources.oauth.views.callback import OAuthCallback from authentik.sources.oauth.views.callback import OAuthCallback
from authentik.sources.oauth.views.redirect import OAuthRedirect from authentik.sources.oauth.views.redirect import OAuthRedirect
@ -21,8 +20,6 @@ class DiscordOAuth2Callback(OAuthCallback):
def get_user_enroll_context( def get_user_enroll_context(
self, self,
source: OAuthSource,
access: UserOAuthSourceConnection,
info: dict[str, Any], info: dict[str, Any],
) -> dict[str, Any]: ) -> dict[str, Any]:
return { return {

View File

@ -4,7 +4,6 @@ from typing import Any, Optional
from facebook import GraphAPI from facebook import GraphAPI
from authentik.sources.oauth.clients.oauth2 import OAuth2Client from authentik.sources.oauth.clients.oauth2 import OAuth2Client
from authentik.sources.oauth.models import OAuthSource, UserOAuthSourceConnection
from authentik.sources.oauth.types.manager import MANAGER, SourceType from authentik.sources.oauth.types.manager import MANAGER, SourceType
from authentik.sources.oauth.views.callback import OAuthCallback from authentik.sources.oauth.views.callback import OAuthCallback
from authentik.sources.oauth.views.redirect import OAuthRedirect from authentik.sources.oauth.views.redirect import OAuthRedirect
@ -34,8 +33,6 @@ class FacebookOAuth2Callback(OAuthCallback):
def get_user_enroll_context( def get_user_enroll_context(
self, self,
source: OAuthSource,
access: UserOAuthSourceConnection,
info: dict[str, Any], info: dict[str, Any],
) -> dict[str, Any]: ) -> dict[str, Any]:
return { return {

View File

@ -1,7 +1,6 @@
"""GitHub OAuth Views""" """GitHub OAuth Views"""
from typing import Any from typing import Any
from authentik.sources.oauth.models import OAuthSource, UserOAuthSourceConnection
from authentik.sources.oauth.types.manager import MANAGER, SourceType from authentik.sources.oauth.types.manager import MANAGER, SourceType
from authentik.sources.oauth.views.callback import OAuthCallback from authentik.sources.oauth.views.callback import OAuthCallback
@ -11,8 +10,6 @@ class GitHubOAuth2Callback(OAuthCallback):
def get_user_enroll_context( def get_user_enroll_context(
self, self,
source: OAuthSource,
access: UserOAuthSourceConnection,
info: dict[str, Any], info: dict[str, Any],
) -> dict[str, Any]: ) -> dict[str, Any]:
return { return {

View File

@ -1,7 +1,6 @@
"""Google OAuth Views""" """Google OAuth Views"""
from typing import Any from typing import Any
from authentik.sources.oauth.models import OAuthSource, UserOAuthSourceConnection
from authentik.sources.oauth.types.manager import MANAGER, SourceType from authentik.sources.oauth.types.manager import MANAGER, SourceType
from authentik.sources.oauth.views.callback import OAuthCallback from authentik.sources.oauth.views.callback import OAuthCallback
from authentik.sources.oauth.views.redirect import OAuthRedirect from authentik.sources.oauth.views.redirect import OAuthRedirect
@ -21,8 +20,6 @@ class GoogleOAuth2Callback(OAuthCallback):
def get_user_enroll_context( def get_user_enroll_context(
self, self,
source: OAuthSource,
access: UserOAuthSourceConnection,
info: dict[str, Any], info: dict[str, Any],
) -> dict[str, Any]: ) -> dict[str, Any]:
return { return {

View File

@ -1,7 +1,7 @@
"""OpenID Connect OAuth Views""" """OpenID Connect OAuth Views"""
from typing import Any from typing import Any
from authentik.sources.oauth.models import OAuthSource, UserOAuthSourceConnection from authentik.sources.oauth.models import OAuthSource
from authentik.sources.oauth.types.manager import MANAGER, SourceType from authentik.sources.oauth.types.manager import MANAGER, SourceType
from authentik.sources.oauth.views.callback import OAuthCallback from authentik.sources.oauth.views.callback import OAuthCallback
from authentik.sources.oauth.views.redirect import OAuthRedirect from authentik.sources.oauth.views.redirect import OAuthRedirect
@ -19,13 +19,11 @@ class OpenIDConnectOAuthRedirect(OAuthRedirect):
class OpenIDConnectOAuth2Callback(OAuthCallback): class OpenIDConnectOAuth2Callback(OAuthCallback):
"""OpenIDConnect OAuth2 Callback""" """OpenIDConnect OAuth2 Callback"""
def get_user_id(self, source: OAuthSource, info: dict[str, str]) -> str: def get_user_id(self, info: dict[str, str]) -> str:
return info.get("sub", "") return info.get("sub", "")
def get_user_enroll_context( def get_user_enroll_context(
self, self,
source: OAuthSource,
access: UserOAuthSourceConnection,
info: dict[str, Any], info: dict[str, Any],
) -> dict[str, Any]: ) -> dict[str, Any]:
return { return {

View File

@ -4,7 +4,6 @@ from typing import Any
from requests.auth import HTTPBasicAuth from requests.auth import HTTPBasicAuth
from authentik.sources.oauth.clients.oauth2 import OAuth2Client from authentik.sources.oauth.clients.oauth2 import OAuth2Client
from authentik.sources.oauth.models import OAuthSource, UserOAuthSourceConnection
from authentik.sources.oauth.types.manager import MANAGER, SourceType from authentik.sources.oauth.types.manager import MANAGER, SourceType
from authentik.sources.oauth.views.callback import OAuthCallback from authentik.sources.oauth.views.callback import OAuthCallback
from authentik.sources.oauth.views.redirect import OAuthRedirect from authentik.sources.oauth.views.redirect import OAuthRedirect
@ -36,8 +35,6 @@ class RedditOAuth2Callback(OAuthCallback):
def get_user_enroll_context( def get_user_enroll_context(
self, self,
source: OAuthSource,
access: UserOAuthSourceConnection,
info: dict[str, Any], info: dict[str, Any],
) -> dict[str, Any]: ) -> dict[str, Any]:
return { return {

View File

@ -1,7 +1,6 @@
"""Twitter OAuth Views""" """Twitter OAuth Views"""
from typing import Any from typing import Any
from authentik.sources.oauth.models import OAuthSource, UserOAuthSourceConnection
from authentik.sources.oauth.types.manager import MANAGER, SourceType from authentik.sources.oauth.types.manager import MANAGER, SourceType
from authentik.sources.oauth.views.callback import OAuthCallback from authentik.sources.oauth.views.callback import OAuthCallback
@ -11,8 +10,6 @@ class TwitterOAuthCallback(OAuthCallback):
def get_user_enroll_context( def get_user_enroll_context(
self, self,
source: OAuthSource,
access: UserOAuthSourceConnection,
info: dict[str, Any], info: dict[str, Any],
) -> dict[str, Any]: ) -> dict[str, Any]:
return { return {

View File

@ -4,35 +4,14 @@ from typing import Any, Optional
from django.conf import settings from django.conf import settings
from django.contrib import messages from django.contrib import messages
from django.http import Http404, HttpRequest, HttpResponse from django.http import Http404, HttpRequest, HttpResponse
from django.http.response import HttpResponseBadRequest
from django.shortcuts import redirect from django.shortcuts import redirect
from django.urls import reverse
from django.utils.translation import gettext as _ from django.utils.translation import gettext as _
from django.views.generic import View from django.views.generic import View
from structlog.stdlib import get_logger from structlog.stdlib import get_logger
from authentik.core.models import User from authentik.core.sources.flow_manager import SourceFlowManager
from authentik.events.models import Event, EventAction
from authentik.flows.models import Flow, in_memory_stage
from authentik.flows.planner import (
PLAN_CONTEXT_PENDING_USER,
PLAN_CONTEXT_REDIRECT,
PLAN_CONTEXT_SOURCE,
PLAN_CONTEXT_SSO,
FlowPlanner,
)
from authentik.flows.views import NEXT_ARG_NAME, SESSION_KEY_GET, SESSION_KEY_PLAN
from authentik.lib.utils.urls import redirect_with_qs
from authentik.policies.utils import delete_none_keys
from authentik.sources.oauth.auth import AuthorizedServiceBackend
from authentik.sources.oauth.models import OAuthSource, UserOAuthSourceConnection from authentik.sources.oauth.models import OAuthSource, UserOAuthSourceConnection
from authentik.sources.oauth.views.base import OAuthClientMixin from authentik.sources.oauth.views.base import OAuthClientMixin
from authentik.sources.oauth.views.flows import (
PLAN_CONTEXT_SOURCES_OAUTH_ACCESS,
PostUserEnrollmentStage,
)
from authentik.stages.password.stage import PLAN_CONTEXT_AUTHENTICATION_BACKEND
from authentik.stages.prompt.stage import PLAN_CONTEXT_PROMPT
LOGGER = get_logger() LOGGER = get_logger()
@ -40,8 +19,7 @@ LOGGER = get_logger()
class OAuthCallback(OAuthClientMixin, View): class OAuthCallback(OAuthClientMixin, View):
"Base OAuth callback view." "Base OAuth callback view."
source_id = None source: OAuthSource
source = None
# pylint: disable=too-many-return-statements # pylint: disable=too-many-return-statements
def get(self, request: HttpRequest, *_, **kwargs) -> HttpResponse: def get(self, request: HttpRequest, *_, **kwargs) -> HttpResponse:
@ -60,47 +38,27 @@ class OAuthCallback(OAuthClientMixin, View):
# Fetch access token # Fetch access token
token = client.get_access_token() token = client.get_access_token()
if token is None: if token is None:
return self.handle_login_failure(self.source, "Could not retrieve token.") return self.handle_login_failure("Could not retrieve token.")
if "error" in token: if "error" in token:
return self.handle_login_failure(self.source, token["error"]) return self.handle_login_failure(token["error"])
# Fetch profile info # Fetch profile info
info = client.get_profile_info(token) raw_info = client.get_profile_info(token)
if info is None: if raw_info is None:
return self.handle_login_failure(self.source, "Could not retrieve profile.") return self.handle_login_failure("Could not retrieve profile.")
identifier = self.get_user_id(self.source, info) identifier = self.get_user_id(raw_info)
if identifier is None: if identifier is None:
return self.handle_login_failure(self.source, "Could not determine id.") return self.handle_login_failure("Could not determine id.")
# Get or create access record # Get or create access record
defaults = { enroll_info = self.get_user_enroll_context(raw_info)
"access_token": token.get("access_token"), sfm = OAuthSourceFlowManager(
} source=self.source,
existing = UserOAuthSourceConnection.objects.filter( request=self.request,
source=self.source, identifier=identifier identifier=identifier,
enroll_info=enroll_info,
) )
return sfm.get_flow(
if existing.exists(): token=token,
connection = existing.first()
connection.access_token = token.get("access_token")
UserOAuthSourceConnection.objects.filter(pk=connection.pk).update(
**defaults
)
else:
connection = UserOAuthSourceConnection(
source=self.source,
identifier=identifier,
access_token=token.get("access_token"),
)
user = AuthorizedServiceBackend().authenticate(
source=self.source, identifier=identifier, request=request
) )
if user is None:
if self.request.user.is_authenticated:
LOGGER.debug("Linking existing user", source=self.source)
return self.handle_existing_user_link(self.source, connection, info)
LOGGER.debug("Handling enrollment of new user", source=self.source)
return self.handle_enroll(self.source, connection, info)
LOGGER.debug("Handling existing user", source=self.source)
return self.handle_existing_user(self.source, user, connection, info)
# pylint: disable=unused-argument # pylint: disable=unused-argument
def get_callback_url(self, source: OAuthSource) -> str: def get_callback_url(self, source: OAuthSource) -> str:
@ -114,132 +72,34 @@ class OAuthCallback(OAuthClientMixin, View):
def get_user_enroll_context( def get_user_enroll_context(
self, self,
source: OAuthSource,
access: UserOAuthSourceConnection,
info: dict[str, Any], info: dict[str, Any],
) -> dict[str, Any]: ) -> dict[str, Any]:
"""Create a dict of User data""" """Create a dict of User data"""
raise NotImplementedError() raise NotImplementedError()
# pylint: disable=unused-argument # pylint: disable=unused-argument
def get_user_id( def get_user_id(self, info: dict[str, Any]) -> Optional[str]:
self, source: UserOAuthSourceConnection, info: dict[str, Any]
) -> Optional[str]:
"""Return unique identifier from the profile info.""" """Return unique identifier from the profile info."""
if "id" in info: if "id" in info:
return info["id"] return info["id"]
return None return None
def handle_login_failure(self, source: OAuthSource, reason: str) -> HttpResponse: def handle_login_failure(self, reason: str) -> HttpResponse:
"Message user and redirect on error." "Message user and redirect on error."
LOGGER.warning("Authentication Failure", reason=reason) LOGGER.warning("Authentication Failure", reason=reason)
messages.error(self.request, _("Authentication Failed.")) messages.error(self.request, _("Authentication Failed."))
return redirect(self.get_error_redirect(source, reason)) return redirect(self.get_error_redirect(self.source, reason))
def handle_login_flow(
self, flow: Flow, *stages_to_append, **kwargs
) -> HttpResponse:
"""Prepare Authentication Plan, redirect user FlowExecutor"""
# Ensure redirect is carried through when user was trying to
# authorize application
final_redirect = self.request.session.get(SESSION_KEY_GET, {}).get(
NEXT_ARG_NAME, "authentik_core:if-admin"
)
kwargs.update(
{
# Since we authenticate the user by their token, they have no backend set
PLAN_CONTEXT_AUTHENTICATION_BACKEND: "django.contrib.auth.backends.ModelBackend",
PLAN_CONTEXT_SSO: True,
PLAN_CONTEXT_SOURCE: self.source,
PLAN_CONTEXT_REDIRECT: final_redirect,
}
)
if not flow:
return HttpResponseBadRequest()
# We run the Flow planner here so we can pass the Pending user in the context
planner = FlowPlanner(flow)
plan = planner.plan(self.request, kwargs)
for stage in stages_to_append:
plan.append(stage)
self.request.session[SESSION_KEY_PLAN] = plan
return redirect_with_qs(
"authentik_core:if-flow",
self.request.GET,
flow_slug=flow.slug,
)
# pylint: disable=unused-argument class OAuthSourceFlowManager(SourceFlowManager):
def handle_existing_user( """Flow manager for oauth sources"""
self,
source: OAuthSource,
user: User,
access: UserOAuthSourceConnection,
info: dict[str, Any],
) -> HttpResponse:
"Login user and redirect."
messages.success(
self.request,
_(
"Successfully authenticated with %(source)s!"
% {"source": self.source.name}
),
)
flow_kwargs = {PLAN_CONTEXT_PENDING_USER: user}
return self.handle_login_flow(source.authentication_flow, **flow_kwargs)
def handle_existing_user_link( connection_type = UserOAuthSourceConnection
self,
source: OAuthSource,
access: UserOAuthSourceConnection,
info: dict[str, Any],
) -> HttpResponse:
"""Handler when the user was already authenticated and linked an external source
to their account."""
# there's already a user logged in, just link them up
user = self.request.user
access.user = user
access.save()
UserOAuthSourceConnection.objects.filter(pk=access.pk).update(user=user)
Event.new(
EventAction.SOURCE_LINKED, message="Linked OAuth Source", source=source
).from_http(self.request)
messages.success(
self.request,
_("Successfully linked %(source)s!" % {"source": self.source.name}),
)
return redirect(
reverse(
"authentik_core:if-admin",
)
+ f"#/user;page-{self.source.slug}"
)
def handle_enroll( def update_connection(
self, self, connection: UserOAuthSourceConnection, token: dict[str, Any]
source: OAuthSource, ) -> UserOAuthSourceConnection:
access: UserOAuthSourceConnection, """Set the access_token on the connection"""
info: dict[str, Any], connection.access_token = token.get("access_token")
) -> HttpResponse: connection.save()
"""User was not authenticated and previous request was not authenticated.""" return connection
messages.success(
self.request,
_(
"Successfully authenticated with %(source)s!"
% {"source": self.source.name}
),
)
# We run the Flow planner here so we can pass the Pending user in the context
if not source.enrollment_flow:
LOGGER.warning("source has no enrollment flow", source=source)
return HttpResponseBadRequest()
return self.handle_login_flow(
source.enrollment_flow,
in_memory_stage(PostUserEnrollmentStage),
**{
PLAN_CONTEXT_PROMPT: delete_none_keys(
self.get_user_enroll_context(source, access, info)
),
PLAN_CONTEXT_SOURCES_OAUTH_ACCESS: access,
},
)

View File

@ -1,26 +1,22 @@
"""Plex Source Serializer""" """Plex Source Serializer"""
from urllib.parse import urlencode
from django.http import Http404 from django.http import Http404
from django.shortcuts import get_object_or_404 from django.shortcuts import get_object_or_404
from drf_yasg import openapi from drf_yasg import openapi
from drf_yasg.utils import swagger_auto_schema from drf_yasg.utils import swagger_auto_schema
from requests import RequestException, get
from rest_framework.decorators import action from rest_framework.decorators import action
from rest_framework.fields import CharField from rest_framework.fields import CharField
from rest_framework.permissions import AllowAny from rest_framework.permissions import AllowAny
from rest_framework.request import Request from rest_framework.request import Request
from rest_framework.response import Response from rest_framework.response import Response
from rest_framework.viewsets import ModelViewSet from rest_framework.viewsets import ModelViewSet
from structlog.stdlib import get_logger
from authentik.api.decorators import permission_required from authentik.api.decorators import permission_required
from authentik.core.api.sources import SourceSerializer from authentik.core.api.sources import SourceSerializer
from authentik.core.api.utils import PassiveSerializer from authentik.core.api.utils import PassiveSerializer
from authentik.flows.challenge import ChallengeTypes, RedirectChallenge from authentik.flows.challenge import RedirectChallenge
from authentik.flows.views import to_stage_response
from authentik.sources.plex.models import PlexSource from authentik.sources.plex.models import PlexSource
from authentik.sources.plex.plex import PlexAuth
LOGGER = get_logger()
class PlexSourceSerializer(SourceSerializer): class PlexSourceSerializer(SourceSerializer):
@ -72,29 +68,8 @@ class PlexSourceViewSet(ModelViewSet):
plex_token = request.data.get("plex_token", None) plex_token = request.data.get("plex_token", None)
if not plex_token: if not plex_token:
raise Http404 raise Http404
qs = {"X-Plex-Token": plex_token, "X-Plex-Client-Identifier": source.client_id} auth_api = PlexAuth(source, plex_token)
try: if not auth_api.check_server_overlap():
response = get(
f"https://plex.tv/api/v2/resources?{urlencode(qs)}",
headers={"Accept": "application/json"},
)
response.raise_for_status()
except RequestException as exc:
LOGGER.warning("Unable to fetch user resources", exc=exc)
raise Http404 raise Http404
else: response = auth_api.get_user_url(request)
resources: list[dict] = response.json() return to_stage_response(request, response)
for resource in resources:
if resource["provides"] != "server":
continue
if resource["clientIdentifier"] in source.allowed_servers:
LOGGER.info(
"Plex allowed access from server", name=resource["name"]
)
request.session["foo"] = "bar"
break
return Response(
RedirectChallenge(
{"type": ChallengeTypes.REDIRECT.value, "to": ""}
).data
)

View File

@ -0,0 +1,38 @@
# Generated by Django 3.2 on 2021-05-03 17:06
import django.db.models.deletion
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
("authentik_core", "0020_source_user_matching_mode"),
("authentik_sources_plex", "0001_initial"),
]
operations = [
migrations.CreateModel(
name="PlexSourceConnection",
fields=[
(
"usersourceconnection_ptr",
models.OneToOneField(
auto_created=True,
on_delete=django.db.models.deletion.CASCADE,
parent_link=True,
primary_key=True,
serialize=False,
to="authentik_core.usersourceconnection",
),
),
("plex_token", models.TextField()),
("identifier", models.TextField()),
],
options={
"verbose_name": "User Plex Source Connection",
"verbose_name_plural": "User Plex Source Connections",
},
bases=("authentik_core.usersourceconnection",),
),
]

View File

@ -6,7 +6,7 @@ from django.utils.translation import gettext_lazy as _
from rest_framework.fields import CharField from rest_framework.fields import CharField
from rest_framework.serializers import BaseSerializer from rest_framework.serializers import BaseSerializer
from authentik.core.models import Source from authentik.core.models import Source, UserSourceConnection
from authentik.core.types import UILoginButton from authentik.core.types import UILoginButton
from authentik.flows.challenge import Challenge, ChallengeTypes from authentik.flows.challenge import Challenge, ChallengeTypes
@ -53,3 +53,15 @@ class PlexSource(Source):
verbose_name = _("Plex Source") verbose_name = _("Plex Source")
verbose_name_plural = _("Plex Sources") verbose_name_plural = _("Plex Sources")
class PlexSourceConnection(UserSourceConnection):
"""Connect user and plex source"""
plex_token = models.TextField()
identifier = models.TextField()
class Meta:
verbose_name = _("User Plex Source Connection")
verbose_name_plural = _("User Plex Source Connections")

View File

@ -1,136 +1,113 @@
"""Plex OAuth Views""" """Plex Views"""
from typing import Any, Optional
from urllib.parse import urlencode from urllib.parse import urlencode
from django.http.response import Http404 import requests
from requests import post from django.http.request import HttpRequest
from requests.api import get from django.http.response import Http404, HttpResponse
from requests.exceptions import RequestException from requests.exceptions import RequestException
from structlog.stdlib import get_logger from structlog.stdlib import get_logger
from authentik import __version__ from authentik import __version__
from authentik.sources.oauth.clients.oauth2 import OAuth2Client from authentik.core.sources.flow_manager import SourceFlowManager
from authentik.sources.oauth.models import OAuthSource, UserOAuthSourceConnection from authentik.sources.plex.models import PlexSource, PlexSourceConnection
from authentik.sources.oauth.types.manager import MANAGER, SourceType
from authentik.sources.oauth.views.callback import OAuthCallback
from authentik.sources.oauth.views.redirect import OAuthRedirect
LOGGER = get_logger() LOGGER = get_logger()
SESSION_ID_KEY = "PLEX_ID" SESSION_ID_KEY = "PLEX_ID"
SESSION_CODE_KEY = "PLEX_CODE" SESSION_CODE_KEY = "PLEX_CODE"
DEFAULT_PAYLOAD = {
"X-Plex-Product": "authentik",
"X-Plex-Version": __version__,
"X-Plex-Device-Vendor": "BeryJu.org",
}
class PlexRedirect(OAuthRedirect): class PlexAuth:
"""Plex Auth redirect, get a pin then redirect to a URL to claim it""" """Plex authentication utilities"""
headers = {} _source: PlexSource
_token: str
def get_pin(self, **data) -> dict: def __init__(self, source: PlexSource, token: str):
"""Get plex pin that the user will claim self._source = source
https://forums.plex.tv/t/authenticating-with-plex/609370""" self._token = token
return post( self._session = requests.Session()
"https://plex.tv/api/v2/pins.json?strong=true", self._session.headers.update(
data=data, {"Accept": "application/json", "Content-Type": "application/json"}
headers=self.headers,
).json()
def get_redirect_url(self, **kwargs) -> str:
slug = kwargs.get("source_slug", "")
self.headers = {"Origin": self.request.build_absolute_uri("/")}
try:
source: OAuthSource = OAuthSource.objects.get(slug=slug)
except OAuthSource.DoesNotExist:
raise Http404(f"Unknown OAuth source '{slug}'.")
else:
payload = DEFAULT_PAYLOAD.copy()
payload["X-Plex-Client-Identifier"] = source.consumer_key
# Get a pin first
pin = self.get_pin(**payload)
LOGGER.debug("Got pin", **pin)
self.request.session[SESSION_ID_KEY] = pin["id"]
self.request.session[SESSION_CODE_KEY] = pin["code"]
qs = {
"clientID": source.consumer_key,
"code": pin["code"],
"forwardUrl": self.request.build_absolute_uri(
self.get_callback_url(source)
),
}
return f"https://app.plex.tv/auth#!?{urlencode(qs)}"
class PlexOAuthClient(OAuth2Client):
"""Retrive the plex token after authentication, then ask the plex API about user info"""
def check_application_state(self) -> bool:
return SESSION_ID_KEY in self.request.session
def get_access_token(self, **request_kwargs) -> Optional[dict[str, Any]]:
payload = dict(DEFAULT_PAYLOAD)
payload["X-Plex-Client-Identifier"] = self.source.consumer_key
payload["Accept"] = "application/json"
response = get(
f"https://plex.tv/api/v2/pins/{self.request.session[SESSION_ID_KEY]}",
headers=payload,
) )
response.raise_for_status() self._session.headers.update(self.headers)
token = response.json()["authToken"]
return {"plex_token": token}
def get_profile_info(self, token: dict[str, str]) -> Optional[dict[str, Any]]: @property
"Fetch user profile information." def headers(self) -> dict[str, str]:
qs = {"X-Plex-Token": token["plex_token"]} """Get common headers"""
print(token)
try:
response = self.do_request(
"get", f"https://plex.tv/users/account.json?{urlencode(qs)}"
)
response.raise_for_status()
except RequestException as exc:
LOGGER.warning("Unable to fetch user profile", exc=exc)
return None
else:
info = response.json()
return info.get("user", {})
class PlexOAuth2Callback(OAuthCallback):
"""Plex OAuth2 Callback"""
client_class = PlexOAuthClient
def get_user_id(
self, source: UserOAuthSourceConnection, info: dict[str, Any]
) -> Optional[str]:
return info.get("uuid")
def get_user_enroll_context(
self,
source: OAuthSource,
access: UserOAuthSourceConnection,
info: dict[str, Any],
) -> dict[str, Any]:
return { return {
"username": info.get("username"), "X-Plex-Product": "authentik",
"email": info.get("email"), "X-Plex-Version": __version__,
"name": info.get("title"), "X-Plex-Device-Vendor": "BeryJu.org",
} }
def get_resources(self) -> list[dict]:
"""Get all resources the plex-token has access to"""
qs = {
"X-Plex-Token": self._token,
"X-Plex-Client-Identifier": self._source.client_id,
}
response = self._session.get(
f"https://plex.tv/api/v2/resources?{urlencode(qs)}",
)
response.raise_for_status()
return response.json()
@MANAGER.type() def get_user_info(self) -> tuple[dict, int]:
class PlexType(SourceType): """Get user info of the plex token"""
"""Plex Type definition""" qs = {
"X-Plex-Token": self._token,
"X-Plex-Client-Identifier": self._source.client_id,
}
response = self._session.get(
f"https://plex.tv/api/v2/user?{urlencode(qs)}",
)
response.raise_for_status()
raw_user_info = response.json()
return {
"username": raw_user_info.get("username"),
"email": raw_user_info.get("email"),
"name": raw_user_info.get("title"),
}, raw_user_info.get("id")
redirect_view = PlexRedirect def check_server_overlap(self) -> bool:
callback_view = PlexOAuth2Callback """Check if the plex-token has any server overlap with our configured servers"""
name = "Plex" try:
slug = "plex" resources = self.get_resources()
except RequestException as exc:
LOGGER.warning("Unable to fetch user resources", exc=exc)
raise Http404
else:
for resource in resources:
if resource["provides"] != "server":
continue
if resource["clientIdentifier"] in self._source.allowed_servers:
LOGGER.info(
"Plex allowed access from server", name=resource["name"]
)
return True
return False
authorization_url = "" def get_user_url(self, request: HttpRequest) -> HttpResponse:
access_token_url = "" # nosec """Get a URL to a flow executor for either enrollment or authentication"""
profile_url = "" user_info, identifier = self.get_user_info()
sfm = PlexSourceFlowManager(
source=self._source,
request=request,
identifier=str(identifier),
enroll_info=user_info,
)
return sfm.get_flow(plex_token=self._token)
class PlexSourceFlowManager(SourceFlowManager):
"""Flow manager for plex sources"""
connection_type = PlexSourceConnection
def update_connection(
self, connection: PlexSourceConnection, plex_token: str
) -> PlexSourceConnection:
"""Set the access_token on the connection"""
connection.plex_token = plex_token
connection.save()
return connection

View File

@ -17289,6 +17289,17 @@ definitions:
enum: enum:
- all - all
- any - any
user_matching_mode:
title: User matching mode
description: How the source determines if an existing user should be authenticated
or a new user enrolled.
type: string
enum:
- identifier
- email_link
- email_deny
- username_link
- username_deny
UserSetting: UserSetting:
required: required:
- object_uid - object_uid
@ -17369,6 +17380,17 @@ definitions:
enum: enum:
- all - all
- any - any
user_matching_mode:
title: User matching mode
description: How the source determines if an existing user should be authenticated
or a new user enrolled.
type: string
enum:
- identifier
- email_link
- email_deny
- username_link
- username_deny
server_uri: server_uri:
title: Server URI title: Server URI
type: string type: string
@ -17549,6 +17571,17 @@ definitions:
enum: enum:
- all - all
- any - any
user_matching_mode:
title: User matching mode
description: How the source determines if an existing user should be authenticated
or a new user enrolled.
type: string
enum:
- identifier
- email_link
- email_deny
- username_link
- username_deny
provider_type: provider_type:
title: Provider type title: Provider type
type: string type: string
@ -17678,6 +17711,17 @@ definitions:
enum: enum:
- all - all
- any - any
user_matching_mode:
title: User matching mode
description: How the source determines if an existing user should be authenticated
or a new user enrolled.
type: string
enum:
- identifier
- email_link
- email_deny
- username_link
- username_deny
client_id: client_id:
title: Client id title: Client id
type: string type: string
@ -17792,6 +17836,17 @@ definitions:
enum: enum:
- all - all
- any - any
user_matching_mode:
title: User matching mode
description: How the source determines if an existing user should be authenticated
or a new user enrolled.
type: string
enum:
- identifier
- email_link
- email_deny
- username_link
- username_deny
pre_authentication_flow: pre_authentication_flow:
title: Pre authentication flow title: Pre authentication flow
description: Flow used before authentication. description: Flow used before authentication.
@ -18537,6 +18592,17 @@ definitions:
enabled: enabled:
title: Enabled title: Enabled
type: boolean type: boolean
user_matching_mode:
title: User matching mode
description: How the source determines if an existing user should
be authenticated or a new user enrolled.
type: string
enum:
- identifier
- email_link
- email_deny
- username_link
- username_deny
authentication_flow: authentication_flow:
title: Authentication flow title: Authentication flow
description: Flow to use when authenticating existing users. description: Flow to use when authenticating existing users.