Merge branch 'main' into web/issue-4880-multi-select-limitations
* main: (30 commits) outposts/proxy: better Redis error message (#8044) translate: Updates for file web/xliff/en.xlf in fr (#8046) web: bump the eslint group in /tests/wdio with 2 updates (#8041) web: bump the storybook group in /web with 7 updates (#8042) web: bump the eslint group in /web with 2 updates (#8043) web: bump @types/guacamole-common-js from 1.3.2 to 1.5.2 in /web (#8030) translate: Updates for file web/xliff/en.xlf in zh_CN (#8038) translate: Updates for file web/xliff/en.xlf in zh-Hans (#8039) website: bump clsx from 2.0.0 to 2.1.0 in /website (#8033) core: bump golang from 1.21.3-bookworm to 1.21.5-bookworm (#8027) web: bump the babel group in /web with 4 updates (#8028) web: bump the esbuild group in /web with 2 updates (#8029) web: bump rollup from 4.9.1 to 4.9.2 in /web (#8031) tests/e2e: fix tests to work without docker network_mode host (#8035) website/docs: fix typo (#8015) web: bump API Client version (#8025) enterprise/providers: Add RAC [AUTH-15] (#7291) outposts: disable deployment and secret reconciler for embedded outpost in code instead of in config (#8021) providers/proxy: use access token (#8022) website/integrations: Add custom Group/Role mapping documentation for Grafana (#7453) ...
This commit is contained in:
commit
73bba0498f
|
@ -9,3 +9,4 @@ blueprints/local
|
|||
.git
|
||||
!gen-ts-api/node_modules
|
||||
!gen-ts-api/dist/**
|
||||
!gen-go-api/
|
||||
|
|
1
.github/codespell-words.txt
vendored
1
.github/codespell-words.txt
vendored
|
@ -2,3 +2,4 @@ keypair
|
|||
keypairs
|
||||
hass
|
||||
warmup
|
||||
ontext
|
||||
|
|
29
.github/workflows/ci-main.yml
vendored
29
.github/workflows/ci-main.yml
vendored
|
@ -249,12 +249,6 @@ jobs:
|
|||
VERSION_FAMILY=${{ steps.ev.outputs.versionFamily }}
|
||||
cache-from: type=gha
|
||||
cache-to: type=gha,mode=max
|
||||
- name: Comment on PR
|
||||
if: github.event_name == 'pull_request'
|
||||
continue-on-error: true
|
||||
uses: ./.github/actions/comment-pr-instructions
|
||||
with:
|
||||
tag: gh-${{ steps.ev.outputs.branchNameContainer }}-${{ steps.ev.outputs.timestamp }}-${{ steps.ev.outputs.shortHash }}
|
||||
build-arm64:
|
||||
needs: ci-core-mark
|
||||
runs-on: ubuntu-latest
|
||||
|
@ -303,3 +297,26 @@ jobs:
|
|||
platforms: linux/arm64
|
||||
cache-from: type=gha
|
||||
cache-to: type=gha,mode=max
|
||||
pr-comment:
|
||||
needs:
|
||||
- build
|
||||
- build-arm64
|
||||
runs-on: ubuntu-latest
|
||||
if: ${{ github.event_name == 'pull_request' }}
|
||||
permissions:
|
||||
# Needed to write comments on PRs
|
||||
pull-requests: write
|
||||
timeout-minutes: 120
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
with:
|
||||
ref: ${{ github.event.pull_request.head.sha }}
|
||||
- name: prepare variables
|
||||
uses: ./.github/actions/docker-push-variables
|
||||
id: ev
|
||||
env:
|
||||
DOCKER_USERNAME: ${{ secrets.DOCKER_USERNAME }}
|
||||
- name: Comment on PR
|
||||
uses: ./.github/actions/comment-pr-instructions
|
||||
with:
|
||||
tag: gh-${{ steps.ev.outputs.branchNameContainer }}-${{ steps.ev.outputs.timestamp }}-${{ steps.ev.outputs.shortHash }}
|
||||
|
|
2
.github/workflows/ci-outpost.yml
vendored
2
.github/workflows/ci-outpost.yml
vendored
|
@ -65,6 +65,7 @@ jobs:
|
|||
- proxy
|
||||
- ldap
|
||||
- radius
|
||||
- rac
|
||||
runs-on: ubuntu-latest
|
||||
permissions:
|
||||
# Needed to upload contianer images to ghcr.io
|
||||
|
@ -119,6 +120,7 @@ jobs:
|
|||
- proxy
|
||||
- ldap
|
||||
- radius
|
||||
- rac
|
||||
goos: [linux]
|
||||
goarch: [amd64, arm64]
|
||||
steps:
|
||||
|
|
1
.github/workflows/release-publish.yml
vendored
1
.github/workflows/release-publish.yml
vendored
|
@ -65,6 +65,7 @@ jobs:
|
|||
- proxy
|
||||
- ldap
|
||||
- radius
|
||||
- rac
|
||||
steps:
|
||||
- uses: actions/checkout@v4
|
||||
- uses: actions/setup-go@v5
|
||||
|
|
2
Makefile
2
Makefile
|
@ -58,7 +58,7 @@ test: ## Run the server tests and produce a coverage report (locally)
|
|||
lint-fix: ## Lint and automatically fix errors in the python source code. Reports spelling errors.
|
||||
isort $(PY_SOURCES)
|
||||
black $(PY_SOURCES)
|
||||
ruff $(PY_SOURCES)
|
||||
ruff --fix $(PY_SOURCES)
|
||||
codespell -w $(CODESPELL_ARGS)
|
||||
|
||||
lint: ## Lint the python and golang sources
|
||||
|
|
|
@ -40,7 +40,7 @@ class ManagedAppConfig(AppConfig):
|
|||
meth()
|
||||
self._logger.debug("Successfully reconciled", name=name)
|
||||
except (DatabaseError, ProgrammingError, InternalError) as exc:
|
||||
self._logger.debug("Failed to run reconcile", name=name, exc=exc)
|
||||
self._logger.warning("Failed to run reconcile", name=name, exc=exc)
|
||||
|
||||
|
||||
class AuthentikBlueprintsConfig(ManagedAppConfig):
|
||||
|
|
|
@ -1,22 +1,29 @@
|
|||
"""Channels base classes"""
|
||||
from channels.db import database_sync_to_async
|
||||
from channels.exceptions import DenyConnection
|
||||
from channels.generic.websocket import JsonWebsocketConsumer
|
||||
from rest_framework.exceptions import AuthenticationFailed
|
||||
from structlog.stdlib import get_logger
|
||||
|
||||
from authentik.api.authentication import bearer_auth
|
||||
from authentik.core.models import User
|
||||
|
||||
LOGGER = get_logger()
|
||||
|
||||
|
||||
class AuthJsonConsumer(JsonWebsocketConsumer):
|
||||
class TokenOutpostMiddleware:
|
||||
"""Authorize a client with a token"""
|
||||
|
||||
user: User
|
||||
def __init__(self, inner):
|
||||
self.inner = inner
|
||||
|
||||
def connect(self):
|
||||
headers = dict(self.scope["headers"])
|
||||
async def __call__(self, scope, receive, send):
|
||||
scope = dict(scope)
|
||||
await self.auth(scope)
|
||||
return await self.inner(scope, receive, send)
|
||||
|
||||
@database_sync_to_async
|
||||
def auth(self, scope):
|
||||
"""Authenticate request from header"""
|
||||
headers = dict(scope["headers"])
|
||||
if b"authorization" not in headers:
|
||||
LOGGER.warning("WS Request without authorization header")
|
||||
raise DenyConnection()
|
||||
|
@ -32,4 +39,4 @@ class AuthJsonConsumer(JsonWebsocketConsumer):
|
|||
LOGGER.warning("Failed to authenticate", exc=exc)
|
||||
raise DenyConnection()
|
||||
|
||||
self.user = user
|
||||
scope["user"] = user
|
||||
|
|
|
@ -22,6 +22,7 @@ class InterfaceView(TemplateView):
|
|||
kwargs["version_family"] = f"{LOCAL_VERSION.major}.{LOCAL_VERSION.minor}"
|
||||
kwargs["version_subdomain"] = f"version-{LOCAL_VERSION.major}-{LOCAL_VERSION.minor}"
|
||||
kwargs["build"] = get_build_hash()
|
||||
kwargs["url_kwargs"] = self.kwargs
|
||||
return super().get_context_data(**kwargs)
|
||||
|
||||
|
||||
|
|
|
@ -1,6 +1,8 @@
|
|||
"""Enterprise license policies"""
|
||||
from typing import Optional
|
||||
|
||||
from django.utils.translation import gettext_lazy as _
|
||||
|
||||
from authentik.core.models import User, UserTypes
|
||||
from authentik.enterprise.models import LicenseKey
|
||||
from authentik.policies.types import PolicyRequest, PolicyResult
|
||||
|
@ -13,10 +15,10 @@ class EnterprisePolicyAccessView(PolicyAccessView):
|
|||
def check_license(self):
|
||||
"""Check license"""
|
||||
if not LicenseKey.get_total().is_valid():
|
||||
return False
|
||||
return PolicyResult(False, _("Enterprise required to access this feature."))
|
||||
if self.request.user.type != UserTypes.INTERNAL:
|
||||
return False
|
||||
return True
|
||||
return PolicyResult(False, _("Feature only accessible for internal users."))
|
||||
return PolicyResult(True)
|
||||
|
||||
def user_has_access(self, user: Optional[User] = None) -> PolicyResult:
|
||||
user = user or self.request.user
|
||||
|
@ -24,7 +26,7 @@ class EnterprisePolicyAccessView(PolicyAccessView):
|
|||
request.http_request = self.request
|
||||
result = super().user_has_access(user)
|
||||
enterprise_result = self.check_license()
|
||||
if not enterprise_result:
|
||||
if not enterprise_result.passing:
|
||||
return enterprise_result
|
||||
return result
|
||||
|
||||
|
|
0
authentik/enterprise/providers/__init__.py
Normal file
0
authentik/enterprise/providers/__init__.py
Normal file
0
authentik/enterprise/providers/rac/__init__.py
Normal file
0
authentik/enterprise/providers/rac/__init__.py
Normal file
0
authentik/enterprise/providers/rac/api/__init__.py
Normal file
0
authentik/enterprise/providers/rac/api/__init__.py
Normal file
133
authentik/enterprise/providers/rac/api/endpoints.py
Normal file
133
authentik/enterprise/providers/rac/api/endpoints.py
Normal file
|
@ -0,0 +1,133 @@
|
|||
"""RAC Provider API Views"""
|
||||
from typing import Optional
|
||||
|
||||
from django.core.cache import cache
|
||||
from django.db.models import QuerySet
|
||||
from django.urls import reverse
|
||||
from drf_spectacular.types import OpenApiTypes
|
||||
from drf_spectacular.utils import OpenApiParameter, OpenApiResponse, extend_schema
|
||||
from rest_framework.fields import SerializerMethodField
|
||||
from rest_framework.request import Request
|
||||
from rest_framework.response import Response
|
||||
from rest_framework.serializers import ModelSerializer
|
||||
from rest_framework.viewsets import ModelViewSet
|
||||
from structlog.stdlib import get_logger
|
||||
|
||||
from authentik.core.api.used_by import UsedByMixin
|
||||
from authentik.core.models import Provider
|
||||
from authentik.enterprise.providers.rac.api.providers import RACProviderSerializer
|
||||
from authentik.enterprise.providers.rac.models import Endpoint
|
||||
from authentik.policies.engine import PolicyEngine
|
||||
from authentik.rbac.filters import ObjectFilter
|
||||
|
||||
LOGGER = get_logger()
|
||||
|
||||
|
||||
def user_endpoint_cache_key(user_pk: str) -> str:
|
||||
"""Cache key where endpoint list for user is saved"""
|
||||
return f"goauthentik.io/providers/rac/endpoint_access/{user_pk}"
|
||||
|
||||
|
||||
class EndpointSerializer(ModelSerializer):
|
||||
"""Endpoint Serializer"""
|
||||
|
||||
provider_obj = RACProviderSerializer(source="provider", read_only=True)
|
||||
launch_url = SerializerMethodField()
|
||||
|
||||
def get_launch_url(self, endpoint: Endpoint) -> Optional[str]:
|
||||
"""Build actual launch URL (the provider itself does not have one, just
|
||||
individual endpoints)"""
|
||||
try:
|
||||
# pylint: disable=no-member
|
||||
return reverse(
|
||||
"authentik_providers_rac:start",
|
||||
kwargs={"app": endpoint.provider.application.slug, "endpoint": endpoint.pk},
|
||||
)
|
||||
except Provider.application.RelatedObjectDoesNotExist:
|
||||
return None
|
||||
|
||||
class Meta:
|
||||
model = Endpoint
|
||||
fields = [
|
||||
"pk",
|
||||
"name",
|
||||
"provider",
|
||||
"provider_obj",
|
||||
"protocol",
|
||||
"host",
|
||||
"settings",
|
||||
"property_mappings",
|
||||
"auth_mode",
|
||||
"launch_url",
|
||||
]
|
||||
|
||||
|
||||
class EndpointViewSet(UsedByMixin, ModelViewSet):
|
||||
"""Endpoint Viewset"""
|
||||
|
||||
queryset = Endpoint.objects.all()
|
||||
serializer_class = EndpointSerializer
|
||||
filterset_fields = ["name", "provider"]
|
||||
search_fields = ["name", "protocol"]
|
||||
ordering = ["name", "protocol"]
|
||||
|
||||
def _filter_queryset_for_list(self, queryset: QuerySet) -> QuerySet:
|
||||
"""Custom filter_queryset method which ignores guardian, but still supports sorting"""
|
||||
for backend in list(self.filter_backends):
|
||||
if backend == ObjectFilter:
|
||||
continue
|
||||
queryset = backend().filter_queryset(self.request, queryset, self)
|
||||
return queryset
|
||||
|
||||
def _get_allowed_endpoints(self, queryset: QuerySet) -> list[Endpoint]:
|
||||
endpoints = []
|
||||
for endpoint in queryset:
|
||||
engine = PolicyEngine(endpoint, self.request.user, self.request)
|
||||
engine.build()
|
||||
if engine.passing:
|
||||
endpoints.append(endpoint)
|
||||
return endpoints
|
||||
|
||||
@extend_schema(
|
||||
parameters=[
|
||||
OpenApiParameter(
|
||||
"search",
|
||||
OpenApiTypes.STR,
|
||||
),
|
||||
OpenApiParameter(
|
||||
name="superuser_full_list",
|
||||
location=OpenApiParameter.QUERY,
|
||||
type=OpenApiTypes.BOOL,
|
||||
),
|
||||
],
|
||||
responses={
|
||||
200: EndpointSerializer(many=True),
|
||||
400: OpenApiResponse(description="Bad request"),
|
||||
},
|
||||
)
|
||||
def list(self, request: Request, *args, **kwargs) -> Response:
|
||||
"""List accessible endpoints"""
|
||||
should_cache = request.GET.get("search", "") == ""
|
||||
|
||||
superuser_full_list = str(request.GET.get("superuser_full_list", "false")).lower() == "true"
|
||||
if superuser_full_list and request.user.is_superuser:
|
||||
return super().list(request)
|
||||
|
||||
queryset = self._filter_queryset_for_list(self.get_queryset())
|
||||
self.paginate_queryset(queryset)
|
||||
|
||||
allowed_endpoints = []
|
||||
if not should_cache:
|
||||
allowed_endpoints = self._get_allowed_endpoints(queryset)
|
||||
if should_cache:
|
||||
allowed_endpoints = cache.get(user_endpoint_cache_key(self.request.user.pk))
|
||||
if not allowed_endpoints:
|
||||
LOGGER.debug("Caching allowed endpoint list")
|
||||
allowed_endpoints = self._get_allowed_endpoints(queryset)
|
||||
cache.set(
|
||||
user_endpoint_cache_key(self.request.user.pk),
|
||||
allowed_endpoints,
|
||||
timeout=86400,
|
||||
)
|
||||
serializer = self.get_serializer(allowed_endpoints, many=True)
|
||||
return self.get_paginated_response(serializer.data)
|
35
authentik/enterprise/providers/rac/api/property_mappings.py
Normal file
35
authentik/enterprise/providers/rac/api/property_mappings.py
Normal file
|
@ -0,0 +1,35 @@
|
|||
"""RAC Provider API Views"""
|
||||
from rest_framework.fields import CharField
|
||||
from rest_framework.viewsets import ModelViewSet
|
||||
|
||||
from authentik.core.api.propertymappings import PropertyMappingSerializer
|
||||
from authentik.core.api.used_by import UsedByMixin
|
||||
from authentik.core.api.utils import JSONDictField
|
||||
from authentik.enterprise.providers.rac.models import RACPropertyMapping
|
||||
|
||||
|
||||
class RACPropertyMappingSerializer(PropertyMappingSerializer):
|
||||
"""RACPropertyMapping Serializer"""
|
||||
|
||||
static_settings = JSONDictField()
|
||||
expression = CharField(allow_blank=True, required=False)
|
||||
|
||||
def validate_expression(self, expression: str) -> str:
|
||||
"""Test Syntax"""
|
||||
if expression == "":
|
||||
return expression
|
||||
return super().validate_expression(expression)
|
||||
|
||||
class Meta:
|
||||
model = RACPropertyMapping
|
||||
fields = PropertyMappingSerializer.Meta.fields + ["static_settings"]
|
||||
|
||||
|
||||
class RACPropertyMappingViewSet(UsedByMixin, ModelViewSet):
|
||||
"""RACPropertyMapping Viewset"""
|
||||
|
||||
queryset = RACPropertyMapping.objects.all()
|
||||
serializer_class = RACPropertyMappingSerializer
|
||||
search_fields = ["name"]
|
||||
ordering = ["name"]
|
||||
filterset_fields = ["name", "managed"]
|
31
authentik/enterprise/providers/rac/api/providers.py
Normal file
31
authentik/enterprise/providers/rac/api/providers.py
Normal file
|
@ -0,0 +1,31 @@
|
|||
"""RAC Provider API Views"""
|
||||
from rest_framework.fields import CharField, ListField
|
||||
from rest_framework.viewsets import ModelViewSet
|
||||
|
||||
from authentik.core.api.providers import ProviderSerializer
|
||||
from authentik.core.api.used_by import UsedByMixin
|
||||
from authentik.enterprise.providers.rac.models import RACProvider
|
||||
|
||||
|
||||
class RACProviderSerializer(ProviderSerializer):
|
||||
"""RACProvider Serializer"""
|
||||
|
||||
outpost_set = ListField(child=CharField(), read_only=True, source="outpost_set.all")
|
||||
|
||||
class Meta:
|
||||
model = RACProvider
|
||||
fields = ProviderSerializer.Meta.fields + ["settings", "outpost_set", "connection_expiry"]
|
||||
extra_kwargs = ProviderSerializer.Meta.extra_kwargs
|
||||
|
||||
|
||||
class RACProviderViewSet(UsedByMixin, ModelViewSet):
|
||||
"""RACProvider Viewset"""
|
||||
|
||||
queryset = RACProvider.objects.all()
|
||||
serializer_class = RACProviderSerializer
|
||||
filterset_fields = {
|
||||
"application": ["isnull"],
|
||||
"name": ["iexact"],
|
||||
}
|
||||
search_fields = ["name"]
|
||||
ordering = ["name"]
|
17
authentik/enterprise/providers/rac/apps.py
Normal file
17
authentik/enterprise/providers/rac/apps.py
Normal file
|
@ -0,0 +1,17 @@
|
|||
"""RAC app config"""
|
||||
from authentik.blueprints.apps import ManagedAppConfig
|
||||
|
||||
|
||||
class AuthentikEnterpriseProviderRAC(ManagedAppConfig):
|
||||
"""authentik enterprise rac app config"""
|
||||
|
||||
name = "authentik.enterprise.providers.rac"
|
||||
label = "authentik_providers_rac"
|
||||
verbose_name = "authentik Enterprise.Providers.RAC"
|
||||
default = True
|
||||
mountpoint = ""
|
||||
ws_mountpoint = "authentik.enterprise.providers.rac.urls"
|
||||
|
||||
def reconcile_load_rac_signals(self):
|
||||
"""Load rac signals"""
|
||||
self.import_module("authentik.enterprise.providers.rac.signals")
|
163
authentik/enterprise/providers/rac/consumer_client.py
Normal file
163
authentik/enterprise/providers/rac/consumer_client.py
Normal file
|
@ -0,0 +1,163 @@
|
|||
"""RAC Client consumer"""
|
||||
from asgiref.sync import async_to_sync
|
||||
from channels.db import database_sync_to_async
|
||||
from channels.exceptions import ChannelFull, DenyConnection
|
||||
from channels.generic.websocket import AsyncWebsocketConsumer
|
||||
from django.http.request import QueryDict
|
||||
from structlog.stdlib import BoundLogger, get_logger
|
||||
|
||||
from authentik.enterprise.providers.rac.models import ConnectionToken, RACProvider
|
||||
from authentik.outposts.consumer import OUTPOST_GROUP_INSTANCE
|
||||
from authentik.outposts.models import Outpost, OutpostState, OutpostType
|
||||
|
||||
# Global broadcast group, which messages are sent to when the outpost connects back
|
||||
# to authentik for a specific connection
|
||||
# The `RACClientConsumer` consumer adds itself to this group on connection,
|
||||
# and removes itself once it has been assigned a specific outpost channel
|
||||
RAC_CLIENT_GROUP = "group_enterprise_rac_client"
|
||||
# A group for all connections in a given authentik session ID
|
||||
# A disconnect message is sent to this group when the session expires/is deleted
|
||||
RAC_CLIENT_GROUP_SESSION = "group_enterprise_rac_client_%(session)s"
|
||||
# A group for all connections with a specific token, which in almost all cases
|
||||
# is just one connection, however this is used to disconnect the connection
|
||||
# when the token is deleted
|
||||
RAC_CLIENT_GROUP_TOKEN = "group_enterprise_rac_token_%(token)s" # nosec
|
||||
|
||||
# Step 1: Client connects to this websocket endpoint
|
||||
# Step 2: We prepare all the connection args for Guac
|
||||
# Step 3: Send a websocket message to a single outpost that has this provider assigned
|
||||
# (Currently sending to all of them)
|
||||
# (Should probably do different load balancing algorithms)
|
||||
# Step 4: Outpost creates a websocket connection back to authentik
|
||||
# with /ws/outpost_rac/<our_channel_id>/
|
||||
# Step 5: This consumer transfers data between the two channels
|
||||
|
||||
|
||||
class RACClientConsumer(AsyncWebsocketConsumer):
|
||||
"""RAC client consumer the browser connects to"""
|
||||
|
||||
dest_channel_id: str = ""
|
||||
provider: RACProvider
|
||||
token: ConnectionToken
|
||||
logger: BoundLogger
|
||||
|
||||
async def connect(self):
|
||||
await self.accept("guacamole")
|
||||
await self.channel_layer.group_add(RAC_CLIENT_GROUP, self.channel_name)
|
||||
await self.channel_layer.group_add(
|
||||
RAC_CLIENT_GROUP_SESSION % {"session": self.scope["session"].session_key},
|
||||
self.channel_name,
|
||||
)
|
||||
await self.init_outpost_connection()
|
||||
|
||||
async def disconnect(self, code):
|
||||
self.logger.debug("Disconnecting")
|
||||
# Tell the outpost we're disconnecting
|
||||
await self.channel_layer.send(
|
||||
self.dest_channel_id,
|
||||
{
|
||||
"type": "event.disconnect",
|
||||
},
|
||||
)
|
||||
|
||||
@database_sync_to_async
|
||||
def init_outpost_connection(self):
|
||||
"""Initialize guac connection settings"""
|
||||
self.token = ConnectionToken.filter_not_expired(
|
||||
token=self.scope["url_route"]["kwargs"]["token"]
|
||||
).first()
|
||||
if not self.token:
|
||||
raise DenyConnection()
|
||||
self.provider = self.token.provider
|
||||
params = self.token.get_settings()
|
||||
self.logger = get_logger().bind(
|
||||
endpoint=self.token.endpoint.name, user=self.scope["user"].username
|
||||
)
|
||||
msg = {
|
||||
"type": "event.provider.specific",
|
||||
"sub_type": "init_connection",
|
||||
"dest_channel_id": self.channel_name,
|
||||
"params": params,
|
||||
"protocol": self.token.endpoint.protocol,
|
||||
}
|
||||
query = QueryDict(self.scope["query_string"].decode())
|
||||
for key in ["screen_width", "screen_height", "screen_dpi", "audio"]:
|
||||
value = query.get(key, None)
|
||||
if not value:
|
||||
continue
|
||||
msg[key] = str(value)
|
||||
outposts = Outpost.objects.filter(
|
||||
type=OutpostType.RAC,
|
||||
providers__in=[self.provider],
|
||||
)
|
||||
if not outposts.exists():
|
||||
self.logger.warning("Provider has no outpost")
|
||||
raise DenyConnection()
|
||||
for outpost in outposts:
|
||||
# Sort all states for the outpost by connection count
|
||||
states = sorted(
|
||||
OutpostState.for_outpost(outpost),
|
||||
key=lambda state: int(state.args.get("active_connections", 0)),
|
||||
)
|
||||
if len(states) < 1:
|
||||
continue
|
||||
self.logger.debug("Sending out connection broadcast")
|
||||
async_to_sync(self.channel_layer.group_send)(
|
||||
OUTPOST_GROUP_INSTANCE % {"outpost_pk": str(outpost.pk), "instance": states[0].uid},
|
||||
msg,
|
||||
)
|
||||
|
||||
async def receive(self, text_data=None, bytes_data=None):
|
||||
"""Mirror data received from client to the dest_channel_id
|
||||
which is the channel talking to guacd"""
|
||||
if self.dest_channel_id == "":
|
||||
return
|
||||
if self.token.is_expired:
|
||||
await self.event_disconnect({"reason": "token_expiry"})
|
||||
return
|
||||
try:
|
||||
await self.channel_layer.send(
|
||||
self.dest_channel_id,
|
||||
{
|
||||
"type": "event.send",
|
||||
"text_data": text_data,
|
||||
"bytes_data": bytes_data,
|
||||
},
|
||||
)
|
||||
except ChannelFull:
|
||||
pass
|
||||
|
||||
async def event_outpost_connected(self, event: dict):
|
||||
"""Handle event broadcasted from outpost consumer, and check if they
|
||||
created a connection for us"""
|
||||
outpost_channel = event.get("outpost_channel")
|
||||
if event.get("client_channel") != self.channel_name:
|
||||
return
|
||||
if self.dest_channel_id != "":
|
||||
# We've already selected an outpost channel, so tell the other channel to disconnect
|
||||
# This should never happen since we remove ourselves from the broadcast group
|
||||
await self.channel_layer.send(
|
||||
outpost_channel,
|
||||
{
|
||||
"type": "event.disconnect",
|
||||
},
|
||||
)
|
||||
return
|
||||
self.logger.debug("Connected to a single outpost instance")
|
||||
self.dest_channel_id = outpost_channel
|
||||
# Since we have a specific outpost channel now, we can remove
|
||||
# ourselves from the global broadcast group
|
||||
await self.channel_layer.group_discard(RAC_CLIENT_GROUP, self.channel_name)
|
||||
|
||||
async def event_send(self, event: dict):
|
||||
"""Handler called by outpost websocket that sends data to this specific
|
||||
client connection"""
|
||||
if self.token.is_expired:
|
||||
await self.event_disconnect({"reason": "token_expiry"})
|
||||
return
|
||||
await self.send(text_data=event.get("text_data"), bytes_data=event.get("bytes_data"))
|
||||
|
||||
async def event_disconnect(self, event: dict):
|
||||
"""Disconnect when the session ends"""
|
||||
self.logger.info("Disconnecting RAC connection", reason=event.get("reason"))
|
||||
await self.close()
|
48
authentik/enterprise/providers/rac/consumer_outpost.py
Normal file
48
authentik/enterprise/providers/rac/consumer_outpost.py
Normal file
|
@ -0,0 +1,48 @@
|
|||
"""RAC consumer"""
|
||||
from channels.exceptions import ChannelFull
|
||||
from channels.generic.websocket import AsyncWebsocketConsumer
|
||||
|
||||
from authentik.enterprise.providers.rac.consumer_client import RAC_CLIENT_GROUP
|
||||
|
||||
|
||||
class RACOutpostConsumer(AsyncWebsocketConsumer):
|
||||
"""Consumer the outpost connects to, to send specific data back to a client connection"""
|
||||
|
||||
dest_channel_id: str
|
||||
|
||||
async def connect(self):
|
||||
self.dest_channel_id = self.scope["url_route"]["kwargs"]["channel"]
|
||||
await self.accept()
|
||||
await self.channel_layer.group_send(
|
||||
RAC_CLIENT_GROUP,
|
||||
{
|
||||
"type": "event.outpost.connected",
|
||||
"outpost_channel": self.channel_name,
|
||||
"client_channel": self.dest_channel_id,
|
||||
},
|
||||
)
|
||||
|
||||
async def receive(self, text_data=None, bytes_data=None):
|
||||
"""Mirror data received from guacd running in the outpost
|
||||
to the dest_channel_id which is the channel talking to the browser"""
|
||||
try:
|
||||
await self.channel_layer.send(
|
||||
self.dest_channel_id,
|
||||
{
|
||||
"type": "event.send",
|
||||
"text_data": text_data,
|
||||
"bytes_data": bytes_data,
|
||||
},
|
||||
)
|
||||
except ChannelFull:
|
||||
pass
|
||||
|
||||
async def event_send(self, event: dict):
|
||||
"""Handler called by client websocket that sends data to this specific
|
||||
outpost connection"""
|
||||
await self.send(text_data=event.get("text_data"), bytes_data=event.get("bytes_data"))
|
||||
|
||||
async def event_disconnect(self, event: dict):
|
||||
"""Tell outpost we're about to disconnect"""
|
||||
await self.send(text_data="0.authentik.disconnect")
|
||||
await self.close()
|
11
authentik/enterprise/providers/rac/controllers/docker.py
Normal file
11
authentik/enterprise/providers/rac/controllers/docker.py
Normal file
|
@ -0,0 +1,11 @@
|
|||
"""RAC Provider Docker Controller"""
|
||||
from authentik.outposts.controllers.docker import DockerController
|
||||
from authentik.outposts.models import DockerServiceConnection, Outpost
|
||||
|
||||
|
||||
class RACDockerController(DockerController):
|
||||
"""RAC Provider Docker Controller"""
|
||||
|
||||
def __init__(self, outpost: Outpost, connection: DockerServiceConnection):
|
||||
super().__init__(outpost, connection)
|
||||
self.deployment_ports = []
|
13
authentik/enterprise/providers/rac/controllers/kubernetes.py
Normal file
13
authentik/enterprise/providers/rac/controllers/kubernetes.py
Normal file
|
@ -0,0 +1,13 @@
|
|||
"""RAC Provider Kubernetes Controller"""
|
||||
from authentik.outposts.controllers.k8s.service import ServiceReconciler
|
||||
from authentik.outposts.controllers.kubernetes import KubernetesController
|
||||
from authentik.outposts.models import KubernetesServiceConnection, Outpost
|
||||
|
||||
|
||||
class RACKubernetesController(KubernetesController):
|
||||
"""RAC Provider Kubernetes Controller"""
|
||||
|
||||
def __init__(self, outpost: Outpost, connection: KubernetesServiceConnection):
|
||||
super().__init__(outpost, connection)
|
||||
self.deployment_ports = []
|
||||
del self.reconcilers[ServiceReconciler.reconciler_name()]
|
164
authentik/enterprise/providers/rac/migrations/0001_initial.py
Normal file
164
authentik/enterprise/providers/rac/migrations/0001_initial.py
Normal file
|
@ -0,0 +1,164 @@
|
|||
# Generated by Django 4.2.8 on 2023-12-29 15:58
|
||||
|
||||
import uuid
|
||||
|
||||
import django.db.models.deletion
|
||||
from django.db import migrations, models
|
||||
|
||||
import authentik.core.models
|
||||
import authentik.lib.utils.time
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
initial = True
|
||||
|
||||
dependencies = [
|
||||
("authentik_policies", "0011_policybinding_failure_result_and_more"),
|
||||
("authentik_core", "0032_group_roles"),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.CreateModel(
|
||||
name="RACPropertyMapping",
|
||||
fields=[
|
||||
(
|
||||
"propertymapping_ptr",
|
||||
models.OneToOneField(
|
||||
auto_created=True,
|
||||
on_delete=django.db.models.deletion.CASCADE,
|
||||
parent_link=True,
|
||||
primary_key=True,
|
||||
serialize=False,
|
||||
to="authentik_core.propertymapping",
|
||||
),
|
||||
),
|
||||
("static_settings", models.JSONField(default=dict)),
|
||||
],
|
||||
options={
|
||||
"verbose_name": "RAC Property Mapping",
|
||||
"verbose_name_plural": "RAC Property Mappings",
|
||||
},
|
||||
bases=("authentik_core.propertymapping",),
|
||||
),
|
||||
migrations.CreateModel(
|
||||
name="RACProvider",
|
||||
fields=[
|
||||
(
|
||||
"provider_ptr",
|
||||
models.OneToOneField(
|
||||
auto_created=True,
|
||||
on_delete=django.db.models.deletion.CASCADE,
|
||||
parent_link=True,
|
||||
primary_key=True,
|
||||
serialize=False,
|
||||
to="authentik_core.provider",
|
||||
),
|
||||
),
|
||||
("settings", models.JSONField(default=dict)),
|
||||
(
|
||||
"auth_mode",
|
||||
models.TextField(
|
||||
choices=[("static", "Static"), ("prompt", "Prompt")], default="prompt"
|
||||
),
|
||||
),
|
||||
(
|
||||
"connection_expiry",
|
||||
models.TextField(
|
||||
default="hours=8",
|
||||
help_text="Determines how long a session lasts. Default of 0 means that the sessions lasts until the browser is closed. (Format: hours=-1;minutes=-2;seconds=-3)",
|
||||
validators=[authentik.lib.utils.time.timedelta_string_validator],
|
||||
),
|
||||
),
|
||||
],
|
||||
options={
|
||||
"verbose_name": "RAC Provider",
|
||||
"verbose_name_plural": "RAC Providers",
|
||||
},
|
||||
bases=("authentik_core.provider",),
|
||||
),
|
||||
migrations.CreateModel(
|
||||
name="Endpoint",
|
||||
fields=[
|
||||
(
|
||||
"policybindingmodel_ptr",
|
||||
models.OneToOneField(
|
||||
auto_created=True,
|
||||
on_delete=django.db.models.deletion.CASCADE,
|
||||
parent_link=True,
|
||||
primary_key=True,
|
||||
serialize=False,
|
||||
to="authentik_policies.policybindingmodel",
|
||||
),
|
||||
),
|
||||
("name", models.TextField()),
|
||||
("host", models.TextField()),
|
||||
(
|
||||
"protocol",
|
||||
models.TextField(choices=[("rdp", "Rdp"), ("vnc", "Vnc"), ("ssh", "Ssh")]),
|
||||
),
|
||||
("settings", models.JSONField(default=dict)),
|
||||
(
|
||||
"auth_mode",
|
||||
models.TextField(choices=[("static", "Static"), ("prompt", "Prompt")]),
|
||||
),
|
||||
(
|
||||
"property_mappings",
|
||||
models.ManyToManyField(
|
||||
blank=True, default=None, to="authentik_core.propertymapping"
|
||||
),
|
||||
),
|
||||
(
|
||||
"provider",
|
||||
models.ForeignKey(
|
||||
on_delete=django.db.models.deletion.CASCADE,
|
||||
to="authentik_providers_rac.racprovider",
|
||||
),
|
||||
),
|
||||
],
|
||||
options={
|
||||
"verbose_name": "RAC Endpoint",
|
||||
"verbose_name_plural": "RAC Endpoints",
|
||||
},
|
||||
bases=("authentik_policies.policybindingmodel", models.Model),
|
||||
),
|
||||
migrations.CreateModel(
|
||||
name="ConnectionToken",
|
||||
fields=[
|
||||
(
|
||||
"expires",
|
||||
models.DateTimeField(default=authentik.core.models.default_token_duration),
|
||||
),
|
||||
("expiring", models.BooleanField(default=True)),
|
||||
(
|
||||
"connection_token_uuid",
|
||||
models.UUIDField(default=uuid.uuid4, primary_key=True, serialize=False),
|
||||
),
|
||||
("token", models.TextField(default=authentik.core.models.default_token_key)),
|
||||
("settings", models.JSONField(default=dict)),
|
||||
(
|
||||
"endpoint",
|
||||
models.ForeignKey(
|
||||
on_delete=django.db.models.deletion.CASCADE,
|
||||
to="authentik_providers_rac.endpoint",
|
||||
),
|
||||
),
|
||||
(
|
||||
"provider",
|
||||
models.ForeignKey(
|
||||
on_delete=django.db.models.deletion.CASCADE,
|
||||
to="authentik_providers_rac.racprovider",
|
||||
),
|
||||
),
|
||||
(
|
||||
"session",
|
||||
models.ForeignKey(
|
||||
on_delete=django.db.models.deletion.CASCADE,
|
||||
to="authentik_core.authenticatedsession",
|
||||
),
|
||||
),
|
||||
],
|
||||
options={
|
||||
"abstract": False,
|
||||
},
|
||||
),
|
||||
]
|
191
authentik/enterprise/providers/rac/models.py
Normal file
191
authentik/enterprise/providers/rac/models.py
Normal file
|
@ -0,0 +1,191 @@
|
|||
"""RAC Models"""
|
||||
from typing import Optional
|
||||
from uuid import uuid4
|
||||
|
||||
from deepmerge import always_merger
|
||||
from django.db import models
|
||||
from django.db.models import QuerySet
|
||||
from django.utils.translation import gettext as _
|
||||
from rest_framework.serializers import Serializer
|
||||
from structlog.stdlib import get_logger
|
||||
|
||||
from authentik.core.exceptions import PropertyMappingExpressionException
|
||||
from authentik.core.models import ExpiringModel, PropertyMapping, Provider, default_token_key
|
||||
from authentik.events.models import Event, EventAction
|
||||
from authentik.lib.models import SerializerModel
|
||||
from authentik.lib.utils.time import timedelta_string_validator
|
||||
from authentik.policies.models import PolicyBindingModel
|
||||
|
||||
LOGGER = get_logger()
|
||||
|
||||
|
||||
class Protocols(models.TextChoices):
|
||||
"""Supported protocols"""
|
||||
|
||||
RDP = "rdp"
|
||||
VNC = "vnc"
|
||||
SSH = "ssh"
|
||||
|
||||
|
||||
class AuthenticationMode(models.TextChoices):
|
||||
"""Authentication modes"""
|
||||
|
||||
STATIC = "static"
|
||||
PROMPT = "prompt"
|
||||
|
||||
|
||||
class RACProvider(Provider):
|
||||
"""Remotely access computers/servers"""
|
||||
|
||||
settings = models.JSONField(default=dict)
|
||||
auth_mode = models.TextField(
|
||||
choices=AuthenticationMode.choices, default=AuthenticationMode.PROMPT
|
||||
)
|
||||
connection_expiry = models.TextField(
|
||||
default="hours=8",
|
||||
validators=[timedelta_string_validator],
|
||||
help_text=_(
|
||||
"Determines how long a session lasts. Default of 0 means "
|
||||
"that the sessions lasts until the browser is closed. "
|
||||
"(Format: hours=-1;minutes=-2;seconds=-3)"
|
||||
),
|
||||
)
|
||||
|
||||
@property
|
||||
def launch_url(self) -> Optional[str]:
|
||||
"""URL to this provider and initiate authorization for the user.
|
||||
Can return None for providers that are not URL-based"""
|
||||
return "goauthentik.io://providers/rac/launch"
|
||||
|
||||
@property
|
||||
def component(self) -> str:
|
||||
return "ak-provider-rac-form"
|
||||
|
||||
@property
|
||||
def serializer(self) -> type[Serializer]:
|
||||
from authentik.enterprise.providers.rac.api.providers import RACProviderSerializer
|
||||
|
||||
return RACProviderSerializer
|
||||
|
||||
class Meta:
|
||||
verbose_name = _("RAC Provider")
|
||||
verbose_name_plural = _("RAC Providers")
|
||||
|
||||
|
||||
class Endpoint(SerializerModel, PolicyBindingModel):
|
||||
"""Remote-accessible endpoint"""
|
||||
|
||||
name = models.TextField()
|
||||
host = models.TextField()
|
||||
protocol = models.TextField(choices=Protocols.choices)
|
||||
settings = models.JSONField(default=dict)
|
||||
auth_mode = models.TextField(choices=AuthenticationMode.choices)
|
||||
provider = models.ForeignKey("RACProvider", on_delete=models.CASCADE)
|
||||
|
||||
property_mappings = models.ManyToManyField(
|
||||
"authentik_core.PropertyMapping", default=None, blank=True
|
||||
)
|
||||
|
||||
@property
|
||||
def serializer(self) -> type[Serializer]:
|
||||
from authentik.enterprise.providers.rac.api.endpoints import EndpointSerializer
|
||||
|
||||
return EndpointSerializer
|
||||
|
||||
def __str__(self):
|
||||
return f"RAC Endpoint {self.name}"
|
||||
|
||||
class Meta:
|
||||
verbose_name = _("RAC Endpoint")
|
||||
verbose_name_plural = _("RAC Endpoints")
|
||||
|
||||
|
||||
class RACPropertyMapping(PropertyMapping):
|
||||
"""Configure settings for remote access endpoints."""
|
||||
|
||||
static_settings = models.JSONField(default=dict)
|
||||
|
||||
@property
|
||||
def component(self) -> str:
|
||||
return "ak-property-mapping-rac-form"
|
||||
|
||||
@property
|
||||
def serializer(self) -> type[Serializer]:
|
||||
from authentik.enterprise.providers.rac.api.property_mappings import (
|
||||
RACPropertyMappingSerializer,
|
||||
)
|
||||
|
||||
return RACPropertyMappingSerializer
|
||||
|
||||
class Meta:
|
||||
verbose_name = _("RAC Property Mapping")
|
||||
verbose_name_plural = _("RAC Property Mappings")
|
||||
|
||||
|
||||
class ConnectionToken(ExpiringModel):
|
||||
"""Token for a single connection to a specified endpoint"""
|
||||
|
||||
connection_token_uuid = models.UUIDField(default=uuid4, primary_key=True)
|
||||
provider = models.ForeignKey(RACProvider, on_delete=models.CASCADE)
|
||||
endpoint = models.ForeignKey(Endpoint, on_delete=models.CASCADE)
|
||||
token = models.TextField(default=default_token_key)
|
||||
settings = models.JSONField(default=dict)
|
||||
session = models.ForeignKey("authentik_core.AuthenticatedSession", on_delete=models.CASCADE)
|
||||
|
||||
def get_settings(self) -> dict:
|
||||
"""Get settings"""
|
||||
default_settings = {}
|
||||
if ":" in self.endpoint.host:
|
||||
host, _, port = self.endpoint.host.partition(":")
|
||||
default_settings["hostname"] = host
|
||||
default_settings["port"] = str(port)
|
||||
else:
|
||||
default_settings["hostname"] = self.endpoint.host
|
||||
default_settings["client-name"] = "authentik"
|
||||
# default_settings["enable-drive"] = "true"
|
||||
# default_settings["drive-name"] = "authentik"
|
||||
settings = {}
|
||||
always_merger.merge(settings, default_settings)
|
||||
always_merger.merge(settings, self.endpoint.provider.settings)
|
||||
always_merger.merge(settings, self.endpoint.settings)
|
||||
always_merger.merge(settings, self.settings)
|
||||
|
||||
def mapping_evaluator(mappings: QuerySet):
|
||||
for mapping in mappings:
|
||||
mapping: RACPropertyMapping
|
||||
if len(mapping.static_settings) > 0:
|
||||
always_merger.merge(settings, mapping.static_settings)
|
||||
continue
|
||||
try:
|
||||
mapping_settings = mapping.evaluate(
|
||||
self.session.user, None, endpoint=self.endpoint, provider=self.provider
|
||||
)
|
||||
always_merger.merge(settings, mapping_settings)
|
||||
except PropertyMappingExpressionException as exc:
|
||||
Event.new(
|
||||
EventAction.CONFIGURATION_ERROR,
|
||||
message=f"Failed to evaluate property-mapping: '{mapping.name}'",
|
||||
provider=self.provider,
|
||||
mapping=mapping,
|
||||
).set_user(self.session.user).save()
|
||||
LOGGER.warning("Failed to evaluate property mapping", exc=exc)
|
||||
|
||||
mapping_evaluator(
|
||||
RACPropertyMapping.objects.filter(provider__in=[self.provider]).order_by("name")
|
||||
)
|
||||
mapping_evaluator(
|
||||
RACPropertyMapping.objects.filter(endpoint__in=[self.endpoint]).order_by("name")
|
||||
)
|
||||
|
||||
settings["drive-path"] = f"/tmp/connection/{self.token}" # nosec
|
||||
settings["create-drive-path"] = "true"
|
||||
# Ensure all values of the settings dict are strings
|
||||
for key, value in settings.items():
|
||||
if isinstance(value, str):
|
||||
continue
|
||||
# Special case for bools
|
||||
if isinstance(value, bool):
|
||||
settings[key] = str(value).lower()
|
||||
continue
|
||||
settings[key] = str(value)
|
||||
return settings
|
54
authentik/enterprise/providers/rac/signals.py
Normal file
54
authentik/enterprise/providers/rac/signals.py
Normal file
|
@ -0,0 +1,54 @@
|
|||
"""RAC Signals"""
|
||||
from asgiref.sync import async_to_sync
|
||||
from channels.layers import get_channel_layer
|
||||
from django.contrib.auth.signals import user_logged_out
|
||||
from django.core.cache import cache
|
||||
from django.db.models import Model
|
||||
from django.db.models.signals import post_save, pre_delete
|
||||
from django.dispatch import receiver
|
||||
from django.http import HttpRequest
|
||||
|
||||
from authentik.core.models import User
|
||||
from authentik.enterprise.providers.rac.api.endpoints import user_endpoint_cache_key
|
||||
from authentik.enterprise.providers.rac.consumer_client import (
|
||||
RAC_CLIENT_GROUP_SESSION,
|
||||
RAC_CLIENT_GROUP_TOKEN,
|
||||
)
|
||||
from authentik.enterprise.providers.rac.models import ConnectionToken, Endpoint
|
||||
|
||||
|
||||
@receiver(user_logged_out)
|
||||
def user_logged_out_session(sender, request: HttpRequest, user: User, **_):
|
||||
"""Disconnect any open RAC connections"""
|
||||
layer = get_channel_layer()
|
||||
async_to_sync(layer.group_send)(
|
||||
RAC_CLIENT_GROUP_SESSION
|
||||
% {
|
||||
"session": request.session.session_key,
|
||||
},
|
||||
{"type": "event.disconnect", "reason": "session_logout"},
|
||||
)
|
||||
|
||||
|
||||
@receiver(pre_delete, sender=ConnectionToken)
|
||||
def pre_delete_connection_token_disconnect(sender, instance: ConnectionToken, **_):
|
||||
"""Disconnect session when connection token is deleted"""
|
||||
layer = get_channel_layer()
|
||||
async_to_sync(layer.group_send)(
|
||||
RAC_CLIENT_GROUP_TOKEN
|
||||
% {
|
||||
"token": instance.token,
|
||||
},
|
||||
{"type": "event.disconnect", "reason": "token_delete"},
|
||||
)
|
||||
|
||||
|
||||
@receiver(post_save, sender=Endpoint)
|
||||
def post_save_application(sender: type[Model], instance, created: bool, **_):
|
||||
"""Clear user's application cache upon application creation"""
|
||||
if not created: # pragma: no cover
|
||||
return
|
||||
|
||||
# Delete user endpoint cache
|
||||
keys = cache.keys(user_endpoint_cache_key("*"))
|
||||
cache.delete_many(keys)
|
18
authentik/enterprise/providers/rac/templates/if/rac.html
Normal file
18
authentik/enterprise/providers/rac/templates/if/rac.html
Normal file
|
@ -0,0 +1,18 @@
|
|||
{% extends "base/skeleton.html" %}
|
||||
|
||||
{% load static %}
|
||||
|
||||
{% block head %}
|
||||
<script src="{% static 'dist/enterprise/rac/index.js' %}?version={{ version }}" type="module"></script>
|
||||
<meta name="theme-color" content="#18191a" media="(prefers-color-scheme: dark)">
|
||||
<meta name="theme-color" content="#ffffff" media="(prefers-color-scheme: light)">
|
||||
<link rel="icon" href="{{ tenant.branding_favicon }}">
|
||||
<link rel="shortcut icon" href="{{ tenant.branding_favicon }}">
|
||||
{% include "base/header_js.html" %}
|
||||
{% endblock %}
|
||||
|
||||
{% block body %}
|
||||
<ak-rac token="{{ url_kwargs.token }}" endpointName="{{ token.endpoint.name }}">
|
||||
<ak-loading></ak-loading>
|
||||
</ak-rac>
|
||||
{% endblock %}
|
168
authentik/enterprise/providers/rac/tests/test_endpoints_api.py
Normal file
168
authentik/enterprise/providers/rac/tests/test_endpoints_api.py
Normal file
|
@ -0,0 +1,168 @@
|
|||
"""Test Endpoints API"""
|
||||
|
||||
from django.urls import reverse
|
||||
from rest_framework.test import APITestCase
|
||||
|
||||
from authentik.core.models import Application
|
||||
from authentik.core.tests.utils import create_test_admin_user
|
||||
from authentik.enterprise.providers.rac.models import Endpoint, Protocols, RACProvider
|
||||
from authentik.lib.generators import generate_id
|
||||
from authentik.policies.dummy.models import DummyPolicy
|
||||
from authentik.policies.models import PolicyBinding
|
||||
|
||||
|
||||
class TestEndpointsAPI(APITestCase):
|
||||
"""Test endpoints API"""
|
||||
|
||||
def setUp(self) -> None:
|
||||
self.user = create_test_admin_user()
|
||||
self.provider = RACProvider.objects.create(
|
||||
name=generate_id(),
|
||||
)
|
||||
self.app = Application.objects.create(
|
||||
name=generate_id(),
|
||||
slug=generate_id(),
|
||||
provider=self.provider,
|
||||
)
|
||||
self.allowed = Endpoint.objects.create(
|
||||
name=f"a-{generate_id()}",
|
||||
host=generate_id(),
|
||||
protocol=Protocols.RDP,
|
||||
provider=self.provider,
|
||||
)
|
||||
self.denied = Endpoint.objects.create(
|
||||
name=f"b-{generate_id()}",
|
||||
host=generate_id(),
|
||||
protocol=Protocols.RDP,
|
||||
provider=self.provider,
|
||||
)
|
||||
PolicyBinding.objects.create(
|
||||
target=self.denied,
|
||||
policy=DummyPolicy.objects.create(name="deny", result=False, wait_min=1, wait_max=2),
|
||||
order=0,
|
||||
)
|
||||
|
||||
def test_list(self):
|
||||
"""Test list operation without superuser_full_list"""
|
||||
self.client.force_login(self.user)
|
||||
response = self.client.get(reverse("authentik_api:endpoint-list"))
|
||||
self.assertJSONEqual(
|
||||
response.content.decode(),
|
||||
{
|
||||
"pagination": {
|
||||
"next": 0,
|
||||
"previous": 0,
|
||||
"count": 2,
|
||||
"current": 1,
|
||||
"total_pages": 1,
|
||||
"start_index": 1,
|
||||
"end_index": 2,
|
||||
},
|
||||
"results": [
|
||||
{
|
||||
"pk": str(self.allowed.pk),
|
||||
"name": self.allowed.name,
|
||||
"provider": self.provider.pk,
|
||||
"provider_obj": {
|
||||
"pk": self.provider.pk,
|
||||
"name": self.provider.name,
|
||||
"authentication_flow": None,
|
||||
"authorization_flow": None,
|
||||
"property_mappings": [],
|
||||
"connection_expiry": "hours=8",
|
||||
"component": "ak-provider-rac-form",
|
||||
"assigned_application_slug": self.app.slug,
|
||||
"assigned_application_name": self.app.name,
|
||||
"verbose_name": "RAC Provider",
|
||||
"verbose_name_plural": "RAC Providers",
|
||||
"meta_model_name": "authentik_providers_rac.racprovider",
|
||||
"settings": {},
|
||||
"outpost_set": [],
|
||||
},
|
||||
"protocol": "rdp",
|
||||
"host": self.allowed.host,
|
||||
"settings": {},
|
||||
"property_mappings": [],
|
||||
"auth_mode": "",
|
||||
"launch_url": f"/application/rac/{self.app.slug}/{str(self.allowed.pk)}/",
|
||||
},
|
||||
],
|
||||
},
|
||||
)
|
||||
|
||||
def test_list_superuser_full_list(self):
|
||||
"""Test list operation with superuser_full_list"""
|
||||
self.client.force_login(self.user)
|
||||
response = self.client.get(
|
||||
reverse("authentik_api:endpoint-list") + "?superuser_full_list=true"
|
||||
)
|
||||
self.assertJSONEqual(
|
||||
response.content.decode(),
|
||||
{
|
||||
"pagination": {
|
||||
"next": 0,
|
||||
"previous": 0,
|
||||
"count": 2,
|
||||
"current": 1,
|
||||
"total_pages": 1,
|
||||
"start_index": 1,
|
||||
"end_index": 2,
|
||||
},
|
||||
"results": [
|
||||
{
|
||||
"pk": str(self.allowed.pk),
|
||||
"name": self.allowed.name,
|
||||
"provider": self.provider.pk,
|
||||
"provider_obj": {
|
||||
"pk": self.provider.pk,
|
||||
"name": self.provider.name,
|
||||
"authentication_flow": None,
|
||||
"authorization_flow": None,
|
||||
"property_mappings": [],
|
||||
"component": "ak-provider-rac-form",
|
||||
"assigned_application_slug": self.app.slug,
|
||||
"assigned_application_name": self.app.name,
|
||||
"connection_expiry": "hours=8",
|
||||
"verbose_name": "RAC Provider",
|
||||
"verbose_name_plural": "RAC Providers",
|
||||
"meta_model_name": "authentik_providers_rac.racprovider",
|
||||
"settings": {},
|
||||
"outpost_set": [],
|
||||
},
|
||||
"protocol": "rdp",
|
||||
"host": self.allowed.host,
|
||||
"settings": {},
|
||||
"property_mappings": [],
|
||||
"auth_mode": "",
|
||||
"launch_url": f"/application/rac/{self.app.slug}/{str(self.allowed.pk)}/",
|
||||
},
|
||||
{
|
||||
"pk": str(self.denied.pk),
|
||||
"name": self.denied.name,
|
||||
"provider": self.provider.pk,
|
||||
"provider_obj": {
|
||||
"pk": self.provider.pk,
|
||||
"name": self.provider.name,
|
||||
"authentication_flow": None,
|
||||
"authorization_flow": None,
|
||||
"property_mappings": [],
|
||||
"component": "ak-provider-rac-form",
|
||||
"assigned_application_slug": self.app.slug,
|
||||
"assigned_application_name": self.app.name,
|
||||
"connection_expiry": "hours=8",
|
||||
"verbose_name": "RAC Provider",
|
||||
"verbose_name_plural": "RAC Providers",
|
||||
"meta_model_name": "authentik_providers_rac.racprovider",
|
||||
"settings": {},
|
||||
"outpost_set": [],
|
||||
},
|
||||
"protocol": "rdp",
|
||||
"host": self.denied.host,
|
||||
"settings": {},
|
||||
"property_mappings": [],
|
||||
"auth_mode": "",
|
||||
"launch_url": f"/application/rac/{self.app.slug}/{str(self.denied.pk)}/",
|
||||
},
|
||||
],
|
||||
},
|
||||
)
|
144
authentik/enterprise/providers/rac/tests/test_models.py
Normal file
144
authentik/enterprise/providers/rac/tests/test_models.py
Normal file
|
@ -0,0 +1,144 @@
|
|||
"""Test RAC Models"""
|
||||
from django.test import TransactionTestCase
|
||||
|
||||
from authentik.core.models import Application, AuthenticatedSession
|
||||
from authentik.core.tests.utils import create_test_admin_user
|
||||
from authentik.enterprise.providers.rac.models import (
|
||||
ConnectionToken,
|
||||
Endpoint,
|
||||
Protocols,
|
||||
RACPropertyMapping,
|
||||
RACProvider,
|
||||
)
|
||||
from authentik.lib.generators import generate_id
|
||||
|
||||
|
||||
class TestModels(TransactionTestCase):
|
||||
"""Test RAC Models"""
|
||||
|
||||
def setUp(self):
|
||||
self.user = create_test_admin_user()
|
||||
self.provider = RACProvider.objects.create(
|
||||
name=generate_id(),
|
||||
)
|
||||
self.app = Application.objects.create(
|
||||
name=generate_id(),
|
||||
slug=generate_id(),
|
||||
provider=self.provider,
|
||||
)
|
||||
self.endpoint = Endpoint.objects.create(
|
||||
name=generate_id(),
|
||||
host=f"{generate_id()}:1324",
|
||||
protocol=Protocols.RDP,
|
||||
provider=self.provider,
|
||||
)
|
||||
|
||||
def test_settings_merge(self):
|
||||
"""Test settings merge"""
|
||||
token = ConnectionToken.objects.create(
|
||||
provider=self.provider,
|
||||
endpoint=self.endpoint,
|
||||
session=AuthenticatedSession.objects.create(
|
||||
user=self.user,
|
||||
session_key=generate_id(),
|
||||
),
|
||||
)
|
||||
path = f"/tmp/connection/{token.token}" # nosec
|
||||
self.assertEqual(
|
||||
token.get_settings(),
|
||||
{
|
||||
"hostname": self.endpoint.host.split(":")[0],
|
||||
"port": "1324",
|
||||
"client-name": "authentik",
|
||||
"drive-path": path,
|
||||
"create-drive-path": "true",
|
||||
},
|
||||
)
|
||||
# Set settings in provider
|
||||
self.provider.settings = {"level": "provider"}
|
||||
self.provider.save()
|
||||
self.assertEqual(
|
||||
token.get_settings(),
|
||||
{
|
||||
"hostname": self.endpoint.host.split(":")[0],
|
||||
"port": "1324",
|
||||
"client-name": "authentik",
|
||||
"drive-path": path,
|
||||
"create-drive-path": "true",
|
||||
"level": "provider",
|
||||
},
|
||||
)
|
||||
# Set settings in endpoint
|
||||
self.endpoint.settings = {
|
||||
"level": "endpoint",
|
||||
}
|
||||
self.endpoint.save()
|
||||
self.assertEqual(
|
||||
token.get_settings(),
|
||||
{
|
||||
"hostname": self.endpoint.host.split(":")[0],
|
||||
"port": "1324",
|
||||
"client-name": "authentik",
|
||||
"drive-path": path,
|
||||
"create-drive-path": "true",
|
||||
"level": "endpoint",
|
||||
},
|
||||
)
|
||||
# Set settings in token
|
||||
token.settings = {
|
||||
"level": "token",
|
||||
}
|
||||
token.save()
|
||||
self.assertEqual(
|
||||
token.get_settings(),
|
||||
{
|
||||
"hostname": self.endpoint.host.split(":")[0],
|
||||
"port": "1324",
|
||||
"client-name": "authentik",
|
||||
"drive-path": path,
|
||||
"create-drive-path": "true",
|
||||
"level": "token",
|
||||
},
|
||||
)
|
||||
# Set settings in property mapping (provider)
|
||||
mapping = RACPropertyMapping.objects.create(
|
||||
name=generate_id(),
|
||||
expression="""return {
|
||||
"level": "property_mapping_provider"
|
||||
}""",
|
||||
)
|
||||
self.provider.property_mappings.add(mapping)
|
||||
self.assertEqual(
|
||||
token.get_settings(),
|
||||
{
|
||||
"hostname": self.endpoint.host.split(":")[0],
|
||||
"port": "1324",
|
||||
"client-name": "authentik",
|
||||
"drive-path": path,
|
||||
"create-drive-path": "true",
|
||||
"level": "property_mapping_provider",
|
||||
},
|
||||
)
|
||||
# Set settings in property mapping (endpoint)
|
||||
mapping = RACPropertyMapping.objects.create(
|
||||
name=generate_id(),
|
||||
static_settings={
|
||||
"level": "property_mapping_endpoint",
|
||||
"foo": True,
|
||||
"bar": 6,
|
||||
},
|
||||
)
|
||||
self.endpoint.property_mappings.add(mapping)
|
||||
self.assertEqual(
|
||||
token.get_settings(),
|
||||
{
|
||||
"hostname": self.endpoint.host.split(":")[0],
|
||||
"port": "1324",
|
||||
"client-name": "authentik",
|
||||
"drive-path": path,
|
||||
"create-drive-path": "true",
|
||||
"level": "property_mapping_endpoint",
|
||||
"foo": "true",
|
||||
"bar": "6",
|
||||
},
|
||||
)
|
132
authentik/enterprise/providers/rac/tests/test_views.py
Normal file
132
authentik/enterprise/providers/rac/tests/test_views.py
Normal file
|
@ -0,0 +1,132 @@
|
|||
"""RAC Views tests"""
|
||||
from datetime import timedelta
|
||||
from json import loads
|
||||
from time import mktime
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from django.urls import reverse
|
||||
from django.utils.timezone import now
|
||||
from rest_framework.test import APITestCase
|
||||
|
||||
from authentik.core.models import Application
|
||||
from authentik.core.tests.utils import create_test_admin_user, create_test_flow
|
||||
from authentik.enterprise.models import License, LicenseKey
|
||||
from authentik.enterprise.providers.rac.models import Endpoint, Protocols, RACProvider
|
||||
from authentik.lib.generators import generate_id
|
||||
from authentik.policies.denied import AccessDeniedResponse
|
||||
from authentik.policies.dummy.models import DummyPolicy
|
||||
from authentik.policies.models import PolicyBinding
|
||||
|
||||
|
||||
class TestRACViews(APITestCase):
|
||||
"""RAC Views tests"""
|
||||
|
||||
def setUp(self):
|
||||
self.user = create_test_admin_user()
|
||||
self.flow = create_test_flow()
|
||||
self.provider = RACProvider.objects.create(name=generate_id(), authorization_flow=self.flow)
|
||||
self.app = Application.objects.create(
|
||||
name=generate_id(),
|
||||
slug=generate_id(),
|
||||
provider=self.provider,
|
||||
)
|
||||
self.endpoint = Endpoint.objects.create(
|
||||
name=generate_id(),
|
||||
host=f"{generate_id()}:1324",
|
||||
protocol=Protocols.RDP,
|
||||
provider=self.provider,
|
||||
)
|
||||
|
||||
@patch(
|
||||
"authentik.enterprise.models.LicenseKey.validate",
|
||||
MagicMock(
|
||||
return_value=LicenseKey(
|
||||
aud="",
|
||||
exp=int(mktime((now() + timedelta(days=3000)).timetuple())),
|
||||
name=generate_id(),
|
||||
internal_users=100,
|
||||
external_users=100,
|
||||
)
|
||||
),
|
||||
)
|
||||
def test_no_policy(self):
|
||||
"""Test request"""
|
||||
License.objects.create(key=generate_id())
|
||||
self.client.force_login(self.user)
|
||||
response = self.client.get(
|
||||
reverse(
|
||||
"authentik_providers_rac:start",
|
||||
kwargs={"app": self.app.slug, "endpoint": str(self.endpoint.pk)},
|
||||
)
|
||||
)
|
||||
self.assertEqual(response.status_code, 302)
|
||||
flow_response = self.client.get(
|
||||
reverse("authentik_api:flow-executor", kwargs={"flow_slug": self.flow.slug})
|
||||
)
|
||||
body = loads(flow_response.content)
|
||||
next_url = body["to"]
|
||||
final_response = self.client.get(next_url)
|
||||
self.assertEqual(final_response.status_code, 200)
|
||||
|
||||
@patch(
|
||||
"authentik.enterprise.models.LicenseKey.validate",
|
||||
MagicMock(
|
||||
return_value=LicenseKey(
|
||||
aud="",
|
||||
exp=int(mktime((now() + timedelta(days=3000)).timetuple())),
|
||||
name=generate_id(),
|
||||
internal_users=100,
|
||||
external_users=100,
|
||||
)
|
||||
),
|
||||
)
|
||||
def test_app_deny(self):
|
||||
"""Test request (deny on app level)"""
|
||||
PolicyBinding.objects.create(
|
||||
target=self.app,
|
||||
policy=DummyPolicy.objects.create(name="deny", result=False, wait_min=1, wait_max=2),
|
||||
order=0,
|
||||
)
|
||||
License.objects.create(key=generate_id())
|
||||
self.client.force_login(self.user)
|
||||
response = self.client.get(
|
||||
reverse(
|
||||
"authentik_providers_rac:start",
|
||||
kwargs={"app": self.app.slug, "endpoint": str(self.endpoint.pk)},
|
||||
)
|
||||
)
|
||||
self.assertIsInstance(response, AccessDeniedResponse)
|
||||
|
||||
@patch(
|
||||
"authentik.enterprise.models.LicenseKey.validate",
|
||||
MagicMock(
|
||||
return_value=LicenseKey(
|
||||
aud="",
|
||||
exp=int(mktime((now() + timedelta(days=3000)).timetuple())),
|
||||
name=generate_id(),
|
||||
internal_users=100,
|
||||
external_users=100,
|
||||
)
|
||||
),
|
||||
)
|
||||
def test_endpoint_deny(self):
|
||||
"""Test request (deny on endpoint level)"""
|
||||
PolicyBinding.objects.create(
|
||||
target=self.endpoint,
|
||||
policy=DummyPolicy.objects.create(name="deny", result=False, wait_min=1, wait_max=2),
|
||||
order=0,
|
||||
)
|
||||
License.objects.create(key=generate_id())
|
||||
self.client.force_login(self.user)
|
||||
response = self.client.get(
|
||||
reverse(
|
||||
"authentik_providers_rac:start",
|
||||
kwargs={"app": self.app.slug, "endpoint": str(self.endpoint.pk)},
|
||||
)
|
||||
)
|
||||
self.assertEqual(response.status_code, 302)
|
||||
flow_response = self.client.get(
|
||||
reverse("authentik_api:flow-executor", kwargs={"flow_slug": self.flow.slug})
|
||||
)
|
||||
body = loads(flow_response.content)
|
||||
self.assertEqual(body["component"], "ak-stage-access-denied")
|
47
authentik/enterprise/providers/rac/urls.py
Normal file
47
authentik/enterprise/providers/rac/urls.py
Normal file
|
@ -0,0 +1,47 @@
|
|||
"""rac urls"""
|
||||
from channels.auth import AuthMiddleware
|
||||
from channels.sessions import CookieMiddleware
|
||||
from django.urls import path
|
||||
from django.views.decorators.csrf import ensure_csrf_cookie
|
||||
|
||||
from authentik.core.channels import TokenOutpostMiddleware
|
||||
from authentik.enterprise.providers.rac.api.endpoints import EndpointViewSet
|
||||
from authentik.enterprise.providers.rac.api.property_mappings import RACPropertyMappingViewSet
|
||||
from authentik.enterprise.providers.rac.api.providers import RACProviderViewSet
|
||||
from authentik.enterprise.providers.rac.consumer_client import RACClientConsumer
|
||||
from authentik.enterprise.providers.rac.consumer_outpost import RACOutpostConsumer
|
||||
from authentik.enterprise.providers.rac.views import RACInterface, RACStartView
|
||||
from authentik.root.asgi_middleware import SessionMiddleware
|
||||
from authentik.root.middleware import ChannelsLoggingMiddleware
|
||||
|
||||
urlpatterns = [
|
||||
path(
|
||||
"application/rac/<slug:app>/<uuid:endpoint>/",
|
||||
ensure_csrf_cookie(RACStartView.as_view()),
|
||||
name="start",
|
||||
),
|
||||
path(
|
||||
"if/rac/<str:token>/",
|
||||
ensure_csrf_cookie(RACInterface.as_view()),
|
||||
name="if-rac",
|
||||
),
|
||||
]
|
||||
|
||||
websocket_urlpatterns = [
|
||||
path(
|
||||
"ws/rac/<str:token>/",
|
||||
ChannelsLoggingMiddleware(
|
||||
CookieMiddleware(SessionMiddleware(AuthMiddleware(RACClientConsumer.as_asgi())))
|
||||
),
|
||||
),
|
||||
path(
|
||||
"ws/outpost_rac/<str:channel>/",
|
||||
ChannelsLoggingMiddleware(TokenOutpostMiddleware(RACOutpostConsumer.as_asgi())),
|
||||
),
|
||||
]
|
||||
|
||||
api_urlpatterns = [
|
||||
("providers/rac", RACProviderViewSet),
|
||||
("propertymappings/rac", RACPropertyMappingViewSet),
|
||||
("rac/endpoints", EndpointViewSet),
|
||||
]
|
115
authentik/enterprise/providers/rac/views.py
Normal file
115
authentik/enterprise/providers/rac/views.py
Normal file
|
@ -0,0 +1,115 @@
|
|||
"""RAC Views"""
|
||||
from typing import Any
|
||||
|
||||
from django.http import Http404, HttpRequest, HttpResponse
|
||||
from django.shortcuts import get_object_or_404, redirect
|
||||
from django.urls import reverse
|
||||
from django.utils.timezone import now
|
||||
|
||||
from authentik.core.models import Application, AuthenticatedSession
|
||||
from authentik.core.views.interface import InterfaceView
|
||||
from authentik.enterprise.policy import EnterprisePolicyAccessView
|
||||
from authentik.enterprise.providers.rac.models import ConnectionToken, Endpoint, RACProvider
|
||||
from authentik.flows.challenge import RedirectChallenge
|
||||
from authentik.flows.exceptions import FlowNonApplicableException
|
||||
from authentik.flows.models import in_memory_stage
|
||||
from authentik.flows.planner import FlowPlanner
|
||||
from authentik.flows.stage import RedirectStage
|
||||
from authentik.flows.views.executor import SESSION_KEY_PLAN
|
||||
from authentik.lib.utils.time import timedelta_from_string
|
||||
from authentik.lib.utils.urls import redirect_with_qs
|
||||
from authentik.policies.engine import PolicyEngine
|
||||
|
||||
|
||||
class RACStartView(EnterprisePolicyAccessView):
|
||||
"""Start a RAC connection by checking access and creating a connection token"""
|
||||
|
||||
endpoint: Endpoint
|
||||
|
||||
def resolve_provider_application(self):
|
||||
self.application = get_object_or_404(Application, slug=self.kwargs["app"])
|
||||
# Endpoint permissions are validated in the RACFinalStage below
|
||||
self.endpoint = get_object_or_404(Endpoint, pk=self.kwargs["endpoint"])
|
||||
self.provider = RACProvider.objects.get(application=self.application)
|
||||
|
||||
def get(self, request: HttpRequest, *args, **kwargs) -> HttpResponse:
|
||||
"""Start flow planner for RAC provider"""
|
||||
planner = FlowPlanner(self.provider.authorization_flow)
|
||||
planner.allow_empty_flows = True
|
||||
try:
|
||||
plan = planner.plan(self.request)
|
||||
except FlowNonApplicableException:
|
||||
raise Http404
|
||||
plan.insert_stage(
|
||||
in_memory_stage(
|
||||
RACFinalStage,
|
||||
endpoint=self.endpoint,
|
||||
provider=self.provider,
|
||||
)
|
||||
)
|
||||
request.session[SESSION_KEY_PLAN] = plan
|
||||
return redirect_with_qs(
|
||||
"authentik_core:if-flow",
|
||||
request.GET,
|
||||
flow_slug=self.provider.authorization_flow.slug,
|
||||
)
|
||||
|
||||
|
||||
class RACInterface(InterfaceView):
|
||||
"""Start RAC connection"""
|
||||
|
||||
template_name = "if/rac.html"
|
||||
token: ConnectionToken
|
||||
|
||||
def dispatch(self, request: HttpRequest, *args: Any, **kwargs: Any) -> HttpResponse:
|
||||
# Early sanity check to ensure token still exists
|
||||
token = ConnectionToken.filter_not_expired(token=self.kwargs["token"]).first()
|
||||
if not token:
|
||||
return redirect("authentik_core:if-user")
|
||||
self.token = token
|
||||
return super().dispatch(request, *args, **kwargs)
|
||||
|
||||
def get_context_data(self, **kwargs: Any) -> dict[str, Any]:
|
||||
kwargs["token"] = self.token
|
||||
return super().get_context_data(**kwargs)
|
||||
|
||||
|
||||
class RACFinalStage(RedirectStage):
|
||||
"""RAC Connection final stage, set the connection token in the stage"""
|
||||
|
||||
def dispatch(self, request: HttpRequest, *args: Any, **kwargs: Any) -> HttpResponse:
|
||||
endpoint: Endpoint = self.executor.current_stage.endpoint
|
||||
engine = PolicyEngine(endpoint, self.request.user, self.request)
|
||||
engine.use_cache = False
|
||||
engine.build()
|
||||
passing = engine.result
|
||||
if not passing.passing:
|
||||
return self.executor.stage_invalid(", ".join(passing.messages))
|
||||
return super().dispatch(request, *args, **kwargs)
|
||||
|
||||
def get_challenge(self, *args, **kwargs) -> RedirectChallenge:
|
||||
endpoint: Endpoint = self.executor.current_stage.endpoint
|
||||
provider: RACProvider = self.executor.current_stage.provider
|
||||
token = ConnectionToken.objects.create(
|
||||
provider=provider,
|
||||
endpoint=endpoint,
|
||||
settings=self.executor.plan.context.get("connection_settings", {}),
|
||||
session=AuthenticatedSession.objects.filter(
|
||||
session_key=self.request.session.session_key
|
||||
).first(),
|
||||
expires=now() + timedelta_from_string(provider.connection_expiry),
|
||||
expiring=True,
|
||||
)
|
||||
setattr(
|
||||
self.executor.current_stage,
|
||||
"destination",
|
||||
self.request.build_absolute_uri(
|
||||
reverse(
|
||||
"authentik_providers_rac:if-rac",
|
||||
kwargs={
|
||||
"token": str(token.token),
|
||||
},
|
||||
)
|
||||
),
|
||||
)
|
||||
return super().get_challenge(*args, **kwargs)
|
|
@ -10,3 +10,7 @@ CELERY_BEAT_SCHEDULE = {
|
|||
"options": {"queue": "authentik_scheduled"},
|
||||
}
|
||||
}
|
||||
|
||||
INSTALLED_APPS = [
|
||||
"authentik.enterprise.providers.rac",
|
||||
]
|
||||
|
|
|
@ -6,6 +6,7 @@ import django_filters
|
|||
from django.db.models.aggregates import Count
|
||||
from django.db.models.fields.json import KeyTextTransform, KeyTransform
|
||||
from django.db.models.functions import ExtractDay, ExtractHour
|
||||
from django.db.models.query_utils import Q
|
||||
from drf_spectacular.types import OpenApiTypes
|
||||
from drf_spectacular.utils import OpenApiParameter, extend_schema
|
||||
from guardian.shortcuts import get_objects_for_user
|
||||
|
@ -87,7 +88,12 @@ class EventsFilter(django_filters.FilterSet):
|
|||
we need to remove the dashes that a client may send. We can't use a
|
||||
UUIDField for this, as some models might not have a UUID PK"""
|
||||
value = str(value).replace("-", "")
|
||||
return queryset.filter(context__model__pk=value)
|
||||
query = Q(context__model__pk=value)
|
||||
try:
|
||||
query |= Q(context__model__pk=int(value))
|
||||
except ValueError:
|
||||
pass
|
||||
return queryset.filter(query)
|
||||
|
||||
class Meta:
|
||||
model = Event
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
"""Event API tests"""
|
||||
from json import loads
|
||||
|
||||
from django.urls import reverse
|
||||
from rest_framework.test import APITestCase
|
||||
|
@ -11,6 +12,9 @@ from authentik.events.models import (
|
|||
NotificationSeverity,
|
||||
TransportMode,
|
||||
)
|
||||
from authentik.events.utils import model_to_dict
|
||||
from authentik.lib.generators import generate_id
|
||||
from authentik.providers.oauth2.models import OAuth2Provider
|
||||
|
||||
|
||||
class TestEventsAPI(APITestCase):
|
||||
|
@ -20,6 +24,25 @@ class TestEventsAPI(APITestCase):
|
|||
self.user = create_test_admin_user()
|
||||
self.client.force_login(self.user)
|
||||
|
||||
def test_filter_model_pk_int(self):
|
||||
"""Test event list with context_model_pk and integer PKs"""
|
||||
provider = OAuth2Provider.objects.create(
|
||||
name=generate_id(),
|
||||
)
|
||||
event = Event.new(EventAction.MODEL_CREATED, model=model_to_dict(provider))
|
||||
event.save()
|
||||
response = self.client.get(
|
||||
reverse("authentik_api:event-list"),
|
||||
data={
|
||||
"context_model_pk": provider.pk,
|
||||
"context_model_app": "authentik_providers_oauth2",
|
||||
"context_model_name": "oauth2provider",
|
||||
},
|
||||
)
|
||||
self.assertEqual(response.status_code, 200)
|
||||
body = loads(response.content)
|
||||
self.assertEqual(body["pagination"]["count"], 1)
|
||||
|
||||
def test_top_n(self):
|
||||
"""Test top_per_user"""
|
||||
event = Event.new(EventAction.AUTHORIZE_APPLICATION)
|
||||
|
|
|
@ -17,8 +17,9 @@ from authentik.core.api.providers import ProviderSerializer
|
|||
from authentik.core.api.used_by import UsedByMixin
|
||||
from authentik.core.api.utils import JSONDictField, PassiveSerializer
|
||||
from authentik.core.models import Provider
|
||||
from authentik.enterprise.providers.rac.models import RACProvider
|
||||
from authentik.outposts.api.service_connections import ServiceConnectionSerializer
|
||||
from authentik.outposts.apps import MANAGED_OUTPOST
|
||||
from authentik.outposts.apps import MANAGED_OUTPOST, MANAGED_OUTPOST_NAME
|
||||
from authentik.outposts.models import (
|
||||
Outpost,
|
||||
OutpostConfig,
|
||||
|
@ -47,12 +48,23 @@ class OutpostSerializer(ModelSerializer):
|
|||
source="service_connection", read_only=True
|
||||
)
|
||||
|
||||
def validate_name(self, name: str) -> str:
|
||||
"""Validate name (especially for embedded outpost)"""
|
||||
if not self.instance:
|
||||
return name
|
||||
if self.instance.managed == MANAGED_OUTPOST and name != MANAGED_OUTPOST_NAME:
|
||||
raise ValidationError("Embedded outpost's name cannot be changed")
|
||||
if self.instance.name == MANAGED_OUTPOST_NAME:
|
||||
self.instance.managed = MANAGED_OUTPOST
|
||||
return name
|
||||
|
||||
def validate_providers(self, providers: list[Provider]) -> list[Provider]:
|
||||
"""Check that all providers match the type of the outpost"""
|
||||
type_map = {
|
||||
OutpostType.LDAP: LDAPProvider,
|
||||
OutpostType.PROXY: ProxyProvider,
|
||||
OutpostType.RADIUS: RadiusProvider,
|
||||
OutpostType.RAC: RACProvider,
|
||||
None: Provider,
|
||||
}
|
||||
for provider in providers:
|
||||
|
|
|
@ -15,6 +15,7 @@ GAUGE_OUTPOSTS_LAST_UPDATE = Gauge(
|
|||
["outpost", "uid", "version"],
|
||||
)
|
||||
MANAGED_OUTPOST = "goauthentik.io/outposts/embedded"
|
||||
MANAGED_OUTPOST_NAME = "authentik Embedded Outpost"
|
||||
|
||||
|
||||
class AuthentikOutpostConfig(ManagedAppConfig):
|
||||
|
@ -35,14 +36,17 @@ class AuthentikOutpostConfig(ManagedAppConfig):
|
|||
DockerServiceConnection,
|
||||
KubernetesServiceConnection,
|
||||
Outpost,
|
||||
OutpostConfig,
|
||||
OutpostType,
|
||||
)
|
||||
|
||||
if outpost := Outpost.objects.filter(name=MANAGED_OUTPOST_NAME, managed="").first():
|
||||
outpost.managed = MANAGED_OUTPOST
|
||||
outpost.save()
|
||||
return
|
||||
outpost, updated = Outpost.objects.update_or_create(
|
||||
defaults={
|
||||
"name": "authentik Embedded Outpost",
|
||||
"type": OutpostType.PROXY,
|
||||
"name": MANAGED_OUTPOST_NAME,
|
||||
},
|
||||
managed=MANAGED_OUTPOST,
|
||||
)
|
||||
|
@ -51,10 +55,4 @@ class AuthentikOutpostConfig(ManagedAppConfig):
|
|||
outpost.service_connection = KubernetesServiceConnection.objects.first()
|
||||
elif DockerServiceConnection.objects.exists():
|
||||
outpost.service_connection = DockerServiceConnection.objects.first()
|
||||
outpost.config = OutpostConfig(
|
||||
kubernetes_disabled_components=[
|
||||
"deployment",
|
||||
"secret",
|
||||
]
|
||||
)
|
||||
outpost.save()
|
||||
|
|
|
@ -6,16 +6,18 @@ from typing import Any, Optional
|
|||
|
||||
from asgiref.sync import async_to_sync
|
||||
from channels.exceptions import DenyConnection
|
||||
from channels.generic.websocket import JsonWebsocketConsumer
|
||||
from dacite.core import from_dict
|
||||
from dacite.data import Data
|
||||
from django.http.request import QueryDict
|
||||
from guardian.shortcuts import get_objects_for_user
|
||||
from structlog.stdlib import BoundLogger, get_logger
|
||||
|
||||
from authentik.core.channels import AuthJsonConsumer
|
||||
from authentik.outposts.apps import GAUGE_OUTPOSTS_CONNECTED, GAUGE_OUTPOSTS_LAST_UPDATE
|
||||
from authentik.outposts.models import OUTPOST_HELLO_INTERVAL, Outpost, OutpostState
|
||||
|
||||
OUTPOST_GROUP = "group_outpost_%(outpost_pk)s"
|
||||
OUTPOST_GROUP_INSTANCE = "group_outpost_%(outpost_pk)s_%(instance)s"
|
||||
|
||||
|
||||
class WebsocketMessageInstruction(IntEnum):
|
||||
|
@ -42,25 +44,23 @@ class WebsocketMessage:
|
|||
args: dict[str, Any] = field(default_factory=dict)
|
||||
|
||||
|
||||
class OutpostConsumer(AuthJsonConsumer):
|
||||
class OutpostConsumer(JsonWebsocketConsumer):
|
||||
"""Handler for Outposts that connect over websockets for health checks and live updates"""
|
||||
|
||||
outpost: Optional[Outpost] = None
|
||||
logger: BoundLogger
|
||||
|
||||
last_uid: Optional[str] = None
|
||||
instance_uid: Optional[str] = None
|
||||
|
||||
def __init__(self, *args, **kwargs):
|
||||
super().__init__(*args, **kwargs)
|
||||
self.logger = get_logger()
|
||||
|
||||
def connect(self):
|
||||
super().connect()
|
||||
uuid = self.scope["url_route"]["kwargs"]["pk"]
|
||||
user = self.scope["user"]
|
||||
outpost = (
|
||||
get_objects_for_user(self.user, "authentik_outposts.view_outpost")
|
||||
.filter(pk=uuid)
|
||||
.first()
|
||||
get_objects_for_user(user, "authentik_outposts.view_outpost").filter(pk=uuid).first()
|
||||
)
|
||||
if not outpost:
|
||||
raise DenyConnection()
|
||||
|
@ -71,13 +71,19 @@ class OutpostConsumer(AuthJsonConsumer):
|
|||
self.logger.warning("runtime error during accept", exc=exc)
|
||||
raise DenyConnection()
|
||||
self.outpost = outpost
|
||||
self.last_uid = self.channel_name
|
||||
query = QueryDict(self.scope["query_string"].decode())
|
||||
self.instance_uid = query.get("instance_uuid", self.channel_name)
|
||||
async_to_sync(self.channel_layer.group_add)(
|
||||
OUTPOST_GROUP % {"outpost_pk": str(self.outpost.pk)}, self.channel_name
|
||||
)
|
||||
async_to_sync(self.channel_layer.group_add)(
|
||||
OUTPOST_GROUP_INSTANCE
|
||||
% {"outpost_pk": str(self.outpost.pk), "instance": self.instance_uid},
|
||||
self.channel_name,
|
||||
)
|
||||
GAUGE_OUTPOSTS_CONNECTED.labels(
|
||||
outpost=self.outpost.name,
|
||||
uid=self.last_uid,
|
||||
uid=self.instance_uid,
|
||||
expected=self.outpost.config.kubernetes_replicas,
|
||||
).inc()
|
||||
|
||||
|
@ -86,34 +92,37 @@ class OutpostConsumer(AuthJsonConsumer):
|
|||
async_to_sync(self.channel_layer.group_discard)(
|
||||
OUTPOST_GROUP % {"outpost_pk": str(self.outpost.pk)}, self.channel_name
|
||||
)
|
||||
if self.outpost and self.last_uid:
|
||||
if self.instance_uid:
|
||||
async_to_sync(self.channel_layer.group_discard)(
|
||||
OUTPOST_GROUP_INSTANCE
|
||||
% {"outpost_pk": str(self.outpost.pk), "instance": self.instance_uid},
|
||||
self.channel_name,
|
||||
)
|
||||
if self.outpost and self.instance_uid:
|
||||
GAUGE_OUTPOSTS_CONNECTED.labels(
|
||||
outpost=self.outpost.name,
|
||||
uid=self.last_uid,
|
||||
uid=self.instance_uid,
|
||||
expected=self.outpost.config.kubernetes_replicas,
|
||||
).dec()
|
||||
|
||||
def receive_json(self, content: Data, **kwargs):
|
||||
msg = from_dict(WebsocketMessage, content)
|
||||
uid = msg.args.get("uuid", self.channel_name)
|
||||
self.last_uid = uid
|
||||
|
||||
if not self.outpost:
|
||||
raise DenyConnection()
|
||||
|
||||
state = OutpostState.for_instance_uid(self.outpost, uid)
|
||||
state = OutpostState.for_instance_uid(self.outpost, self.instance_uid)
|
||||
state.last_seen = datetime.now()
|
||||
state.hostname = msg.args.pop("hostname", "")
|
||||
|
||||
if msg.instruction == WebsocketMessageInstruction.HELLO:
|
||||
state.version = msg.args.pop("version", None)
|
||||
state.build_hash = msg.args.pop("buildHash", "")
|
||||
state.args = msg.args
|
||||
state.args.update(msg.args)
|
||||
elif msg.instruction == WebsocketMessageInstruction.ACK:
|
||||
return
|
||||
GAUGE_OUTPOSTS_LAST_UPDATE.labels(
|
||||
outpost=self.outpost.name,
|
||||
uid=self.last_uid or "",
|
||||
uid=self.instance_uid or "",
|
||||
version=state.version or "",
|
||||
).set_to_current_time()
|
||||
state.save(timeout=OUTPOST_HELLO_INTERVAL * 1.5)
|
||||
|
|
|
@ -43,6 +43,10 @@ class DeploymentReconciler(KubernetesObjectReconciler[V1Deployment]):
|
|||
self.api = AppsV1Api(controller.client)
|
||||
self.outpost = self.controller.outpost
|
||||
|
||||
@property
|
||||
def noop(self) -> bool:
|
||||
return self.is_embedded
|
||||
|
||||
@staticmethod
|
||||
def reconciler_name() -> str:
|
||||
return "deployment"
|
||||
|
|
|
@ -24,6 +24,10 @@ class SecretReconciler(KubernetesObjectReconciler[V1Secret]):
|
|||
super().__init__(controller)
|
||||
self.api = CoreV1Api(controller.client)
|
||||
|
||||
@property
|
||||
def noop(self) -> bool:
|
||||
return self.is_embedded
|
||||
|
||||
@staticmethod
|
||||
def reconciler_name() -> str:
|
||||
return "secret"
|
||||
|
|
|
@ -77,7 +77,10 @@ class PrometheusServiceMonitorReconciler(KubernetesObjectReconciler[PrometheusSe
|
|||
|
||||
@property
|
||||
def noop(self) -> bool:
|
||||
return (not self._crd_exists()) or (self.is_embedded)
|
||||
if not self._crd_exists():
|
||||
self.logger.debug("CRD doesn't exist")
|
||||
return True
|
||||
return self.is_embedded
|
||||
|
||||
def _crd_exists(self) -> bool:
|
||||
"""Check if the Prometheus ServiceMonitor exists"""
|
||||
|
|
|
@ -1,5 +1,6 @@
|
|||
"""k8s utils"""
|
||||
from pathlib import Path
|
||||
from typing import Optional
|
||||
|
||||
from kubernetes.client.models.v1_container_port import V1ContainerPort
|
||||
from kubernetes.client.models.v1_service_port import V1ServicePort
|
||||
|
@ -37,9 +38,12 @@ def compare_port(
|
|||
|
||||
|
||||
def compare_ports(
|
||||
current: list[V1ServicePort | V1ContainerPort], reference: list[V1ServicePort | V1ContainerPort]
|
||||
current: Optional[list[V1ServicePort | V1ContainerPort]],
|
||||
reference: Optional[list[V1ServicePort | V1ContainerPort]],
|
||||
):
|
||||
"""Compare ports of a list"""
|
||||
if not current or not reference:
|
||||
raise NeedsRecreate()
|
||||
if len(current) != len(reference):
|
||||
raise NeedsRecreate()
|
||||
for port in reference:
|
||||
|
|
|
@ -81,7 +81,10 @@ class KubernetesController(BaseController):
|
|||
def up(self):
|
||||
try:
|
||||
for reconcile_key in self.reconcile_order:
|
||||
reconciler = self.reconcilers[reconcile_key](self)
|
||||
reconciler_cls = self.reconcilers.get(reconcile_key)
|
||||
if not reconciler_cls:
|
||||
continue
|
||||
reconciler = reconciler_cls(self)
|
||||
reconciler.up()
|
||||
|
||||
except (OpenApiException, HTTPError, ServiceConnectionInvalid) as exc:
|
||||
|
@ -95,7 +98,10 @@ class KubernetesController(BaseController):
|
|||
all_logs += [f"{reconcile_key.title()}: Disabled"]
|
||||
continue
|
||||
with capture_logs() as logs:
|
||||
reconciler = self.reconcilers[reconcile_key](self)
|
||||
reconciler_cls = self.reconcilers.get(reconcile_key)
|
||||
if not reconciler_cls:
|
||||
continue
|
||||
reconciler = reconciler_cls(self)
|
||||
reconciler.up()
|
||||
all_logs += [f"{reconcile_key.title()}: {x['event']}" for x in logs]
|
||||
return all_logs
|
||||
|
@ -105,7 +111,10 @@ class KubernetesController(BaseController):
|
|||
def down(self):
|
||||
try:
|
||||
for reconcile_key in self.reconcile_order:
|
||||
reconciler = self.reconcilers[reconcile_key](self)
|
||||
reconciler_cls = self.reconcilers.get(reconcile_key)
|
||||
if not reconciler_cls:
|
||||
continue
|
||||
reconciler = reconciler_cls(self)
|
||||
self.logger.debug("Tearing down object", name=reconcile_key)
|
||||
reconciler.down()
|
||||
|
||||
|
@ -120,7 +129,10 @@ class KubernetesController(BaseController):
|
|||
all_logs += [f"{reconcile_key.title()}: Disabled"]
|
||||
continue
|
||||
with capture_logs() as logs:
|
||||
reconciler = self.reconcilers[reconcile_key](self)
|
||||
reconciler_cls = self.reconcilers.get(reconcile_key)
|
||||
if not reconciler_cls:
|
||||
continue
|
||||
reconciler = reconciler_cls(self)
|
||||
reconciler.down()
|
||||
all_logs += [f"{reconcile_key.title()}: {x['event']}" for x in logs]
|
||||
return all_logs
|
||||
|
@ -130,7 +142,10 @@ class KubernetesController(BaseController):
|
|||
def get_static_deployment(self) -> str:
|
||||
documents = []
|
||||
for reconcile_key in self.reconcile_order:
|
||||
reconciler = self.reconcilers[reconcile_key](self)
|
||||
reconciler_cls = self.reconcilers.get(reconcile_key)
|
||||
if not reconciler_cls:
|
||||
continue
|
||||
reconciler = reconciler_cls(self)
|
||||
if reconciler.noop:
|
||||
continue
|
||||
documents.append(reconciler.get_reference_object().to_dict())
|
||||
|
|
25
authentik/outposts/migrations/0021_alter_outpost_type.py
Normal file
25
authentik/outposts/migrations/0021_alter_outpost_type.py
Normal file
|
@ -0,0 +1,25 @@
|
|||
# Generated by Django 4.2.6 on 2023-10-14 19:23
|
||||
|
||||
from django.db import migrations, models
|
||||
|
||||
|
||||
class Migration(migrations.Migration):
|
||||
dependencies = [
|
||||
("authentik_outposts", "0020_alter_outpost_type"),
|
||||
]
|
||||
|
||||
operations = [
|
||||
migrations.AlterField(
|
||||
model_name="outpost",
|
||||
name="type",
|
||||
field=models.TextField(
|
||||
choices=[
|
||||
("proxy", "Proxy"),
|
||||
("ldap", "Ldap"),
|
||||
("radius", "Radius"),
|
||||
("rac", "Rac"),
|
||||
],
|
||||
default="proxy",
|
||||
),
|
||||
),
|
||||
]
|
|
@ -90,11 +90,12 @@ class OutpostModel(Model):
|
|||
|
||||
|
||||
class OutpostType(models.TextChoices):
|
||||
"""Outpost types, currently only the reverse proxy is available"""
|
||||
"""Outpost types"""
|
||||
|
||||
PROXY = "proxy"
|
||||
LDAP = "ldap"
|
||||
RADIUS = "radius"
|
||||
RAC = "rac"
|
||||
|
||||
|
||||
def default_outpost_config(host: Optional[str] = None):
|
||||
|
@ -459,7 +460,7 @@ class OutpostState:
|
|||
def for_instance_uid(outpost: Outpost, uid: str) -> "OutpostState":
|
||||
"""Get state for a single instance"""
|
||||
key = f"{outpost.state_cache_prefix}/{uid}"
|
||||
default_data = {"uid": uid, "channel_ids": []}
|
||||
default_data = {"uid": uid}
|
||||
data = cache.get(key, default_data)
|
||||
if isinstance(data, str):
|
||||
cache.delete(key)
|
||||
|
|
|
@ -17,6 +17,8 @@ from kubernetes.config.kube_config import KUBE_CONFIG_DEFAULT_LOCATION
|
|||
from structlog.stdlib import get_logger
|
||||
from yaml import safe_load
|
||||
|
||||
from authentik.enterprise.providers.rac.controllers.docker import RACDockerController
|
||||
from authentik.enterprise.providers.rac.controllers.kubernetes import RACKubernetesController
|
||||
from authentik.events.monitored_tasks import (
|
||||
MonitoredTask,
|
||||
TaskResult,
|
||||
|
@ -71,6 +73,11 @@ def controller_for_outpost(outpost: Outpost) -> Optional[type[BaseController]]:
|
|||
return RadiusDockerController
|
||||
if isinstance(service_connection, KubernetesServiceConnection):
|
||||
return RadiusKubernetesController
|
||||
if outpost.type == OutpostType.RAC:
|
||||
if isinstance(service_connection, DockerServiceConnection):
|
||||
return RACDockerController
|
||||
if isinstance(service_connection, KubernetesServiceConnection):
|
||||
return RACKubernetesController
|
||||
return None
|
||||
|
||||
|
||||
|
|
|
@ -2,11 +2,13 @@
|
|||
from django.urls import reverse
|
||||
from rest_framework.test import APITestCase
|
||||
|
||||
from authentik.blueprints.tests import reconcile_app
|
||||
from authentik.core.models import PropertyMapping
|
||||
from authentik.core.tests.utils import create_test_admin_user, create_test_flow
|
||||
from authentik.lib.generators import generate_id
|
||||
from authentik.outposts.api.outposts import OutpostSerializer
|
||||
from authentik.outposts.models import OutpostType, default_outpost_config
|
||||
from authentik.outposts.apps import MANAGED_OUTPOST
|
||||
from authentik.outposts.models import Outpost, OutpostType, default_outpost_config
|
||||
from authentik.providers.ldap.models import LDAPProvider
|
||||
from authentik.providers.proxy.models import ProxyProvider
|
||||
|
||||
|
@ -22,7 +24,36 @@ class TestOutpostServiceConnectionsAPI(APITestCase):
|
|||
self.user = create_test_admin_user()
|
||||
self.client.force_login(self.user)
|
||||
|
||||
def test_outpost_validaton(self):
|
||||
@reconcile_app("authentik_outposts")
|
||||
def test_managed_name_change(self):
|
||||
"""Test name change for embedded outpost"""
|
||||
embedded_outpost = Outpost.objects.filter(managed=MANAGED_OUTPOST).first()
|
||||
self.assertIsNotNone(embedded_outpost)
|
||||
response = self.client.patch(
|
||||
reverse("authentik_api:outpost-detail", kwargs={"pk": embedded_outpost.pk}),
|
||||
{"name": "foo"},
|
||||
)
|
||||
self.assertEqual(response.status_code, 400)
|
||||
self.assertJSONEqual(
|
||||
response.content, {"name": ["Embedded outpost's name cannot be changed"]}
|
||||
)
|
||||
|
||||
@reconcile_app("authentik_outposts")
|
||||
def test_managed_without_managed(self):
|
||||
"""Test name change for embedded outpost"""
|
||||
embedded_outpost = Outpost.objects.filter(managed=MANAGED_OUTPOST).first()
|
||||
self.assertIsNotNone(embedded_outpost)
|
||||
embedded_outpost.managed = ""
|
||||
embedded_outpost.save()
|
||||
response = self.client.patch(
|
||||
reverse("authentik_api:outpost-detail", kwargs={"pk": embedded_outpost.pk}),
|
||||
{"name": "foo"},
|
||||
)
|
||||
self.assertEqual(response.status_code, 200)
|
||||
embedded_outpost.refresh_from_db()
|
||||
self.assertEqual(embedded_outpost.managed, MANAGED_OUTPOST)
|
||||
|
||||
def test_outpost_validation(self):
|
||||
"""Test Outpost validation"""
|
||||
valid = OutpostSerializer(
|
||||
data={
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
"""Websocket tests"""
|
||||
from dataclasses import asdict
|
||||
|
||||
from channels.exceptions import DenyConnection
|
||||
from channels.routing import URLRouter
|
||||
from channels.testing import WebsocketCommunicator
|
||||
from django.test import TransactionTestCase
|
||||
|
@ -35,6 +36,7 @@ class TestOutpostWS(TransactionTestCase):
|
|||
communicator = WebsocketCommunicator(
|
||||
URLRouter(websocket.websocket_urlpatterns), f"/ws/outpost/{self.outpost.pk}/"
|
||||
)
|
||||
with self.assertRaises(DenyConnection):
|
||||
connected, _ = await communicator.connect()
|
||||
self.assertFalse(connected)
|
||||
|
||||
|
|
|
@ -1,6 +1,7 @@
|
|||
"""Outpost Websocket URLS"""
|
||||
from django.urls import path
|
||||
|
||||
from authentik.core.channels import TokenOutpostMiddleware
|
||||
from authentik.outposts.api.outposts import OutpostViewSet
|
||||
from authentik.outposts.api.service_connections import (
|
||||
DockerServiceConnectionViewSet,
|
||||
|
@ -11,7 +12,10 @@ from authentik.outposts.consumer import OutpostConsumer
|
|||
from authentik.root.middleware import ChannelsLoggingMiddleware
|
||||
|
||||
websocket_urlpatterns = [
|
||||
path("ws/outpost/<uuid:pk>/", ChannelsLoggingMiddleware(OutpostConsumer.as_asgi())),
|
||||
path(
|
||||
"ws/outpost/<uuid:pk>/",
|
||||
ChannelsLoggingMiddleware(TokenOutpostMiddleware(OutpostConsumer.as_asgi())),
|
||||
),
|
||||
]
|
||||
|
||||
api_urlpatterns = [
|
||||
|
|
|
@ -40,10 +40,9 @@ class PytestTestRunner(DiscoverRunner): # pragma: no cover
|
|||
f"ghcr.io/goauthentik/dev-%(type)s:{get_docker_tag()}",
|
||||
)
|
||||
CONFIG.set("error_reporting.sample_rate", 0)
|
||||
sentry_init(
|
||||
environment="testing",
|
||||
send_default_pii=True,
|
||||
)
|
||||
CONFIG.set("error_reporting.environment", "testing")
|
||||
CONFIG.set("error_reporting.send_pii", True)
|
||||
sentry_init()
|
||||
|
||||
@classmethod
|
||||
def add_arguments(cls, parser: ArgumentParser):
|
||||
|
|
|
@ -99,7 +99,9 @@ class OAuthSourceSerializer(SourceSerializer):
|
|||
]:
|
||||
if getattr(provider_type, url, None) is None:
|
||||
if url not in attrs:
|
||||
raise ValidationError(f"{url} is required for provider {provider_type.name}")
|
||||
raise ValidationError(
|
||||
f"{url} is required for provider {provider_type.verbose_name}"
|
||||
)
|
||||
return attrs
|
||||
|
||||
class Meta:
|
||||
|
|
|
@ -104,8 +104,8 @@ class AppleType(SourceType):
|
|||
|
||||
callback_view = AppleOAuth2Callback
|
||||
redirect_view = AppleOAuthRedirect
|
||||
name = "Apple"
|
||||
slug = "apple"
|
||||
verbose_name = "Apple"
|
||||
name = "apple"
|
||||
|
||||
authorization_url = "https://appleid.apple.com/auth/authorize"
|
||||
access_token_url = "https://appleid.apple.com/auth/token" # nosec
|
||||
|
|
|
@ -43,8 +43,8 @@ class AzureADType(SourceType):
|
|||
|
||||
callback_view = AzureADOAuthCallback
|
||||
redirect_view = AzureADOAuthRedirect
|
||||
name = "Azure AD"
|
||||
slug = "azuread"
|
||||
verbose_name = "Azure AD"
|
||||
name = "azuread"
|
||||
|
||||
urls_customizable = True
|
||||
|
||||
|
|
|
@ -36,8 +36,8 @@ class DiscordType(SourceType):
|
|||
|
||||
callback_view = DiscordOAuth2Callback
|
||||
redirect_view = DiscordOAuthRedirect
|
||||
name = "Discord"
|
||||
slug = "discord"
|
||||
verbose_name = "Discord"
|
||||
name = "discord"
|
||||
|
||||
authorization_url = "https://discord.com/api/oauth2/authorize"
|
||||
access_token_url = "https://discord.com/api/oauth2/token" # nosec
|
||||
|
|
|
@ -48,8 +48,8 @@ class FacebookType(SourceType):
|
|||
|
||||
callback_view = FacebookOAuth2Callback
|
||||
redirect_view = FacebookOAuthRedirect
|
||||
name = "Facebook"
|
||||
slug = "facebook"
|
||||
verbose_name = "Facebook"
|
||||
name = "facebook"
|
||||
|
||||
authorization_url = "https://www.facebook.com/v7.0/dialog/oauth"
|
||||
access_token_url = "https://graph.facebook.com/v7.0/oauth/access_token" # nosec
|
||||
|
|
|
@ -68,8 +68,8 @@ class GitHubType(SourceType):
|
|||
|
||||
callback_view = GitHubOAuth2Callback
|
||||
redirect_view = GitHubOAuthRedirect
|
||||
name = "GitHub"
|
||||
slug = "github"
|
||||
verbose_name = "GitHub"
|
||||
name = "github"
|
||||
|
||||
urls_customizable = True
|
||||
|
||||
|
|
|
@ -34,8 +34,8 @@ class GoogleType(SourceType):
|
|||
|
||||
callback_view = GoogleOAuth2Callback
|
||||
redirect_view = GoogleOAuthRedirect
|
||||
name = "Google"
|
||||
slug = "google"
|
||||
verbose_name = "Google"
|
||||
name = "google"
|
||||
|
||||
authorization_url = "https://accounts.google.com/o/oauth2/auth"
|
||||
access_token_url = "https://oauth2.googleapis.com/token" # nosec
|
||||
|
|
|
@ -63,7 +63,7 @@ class MailcowType(SourceType):
|
|||
|
||||
callback_view = MailcowOAuth2Callback
|
||||
redirect_view = MailcowOAuthRedirect
|
||||
name = "Mailcow"
|
||||
slug = "mailcow"
|
||||
verbose_name = "Mailcow"
|
||||
name = "mailcow"
|
||||
|
||||
urls_customizable = True
|
||||
|
|
|
@ -42,7 +42,7 @@ class OpenIDConnectType(SourceType):
|
|||
|
||||
callback_view = OpenIDConnectOAuth2Callback
|
||||
redirect_view = OpenIDConnectOAuthRedirect
|
||||
name = "OpenID Connect"
|
||||
slug = "openidconnect"
|
||||
verbose_name = "OpenID Connect"
|
||||
name = "openidconnect"
|
||||
|
||||
urls_customizable = True
|
||||
|
|
|
@ -42,7 +42,7 @@ class OktaType(SourceType):
|
|||
|
||||
callback_view = OktaOAuth2Callback
|
||||
redirect_view = OktaOAuthRedirect
|
||||
name = "Okta"
|
||||
slug = "okta"
|
||||
verbose_name = "Okta"
|
||||
name = "okta"
|
||||
|
||||
urls_customizable = True
|
||||
|
|
|
@ -43,8 +43,8 @@ class PatreonType(SourceType):
|
|||
|
||||
callback_view = PatreonOAuthCallback
|
||||
redirect_view = PatreonOAuthRedirect
|
||||
name = "Patreon"
|
||||
slug = "patreon"
|
||||
verbose_name = "Patreon"
|
||||
name = "patreon"
|
||||
|
||||
authorization_url = "https://www.patreon.com/oauth2/authorize"
|
||||
access_token_url = "https://www.patreon.com/api/oauth2/token" # nosec
|
||||
|
|
|
@ -51,8 +51,8 @@ class RedditType(SourceType):
|
|||
|
||||
callback_view = RedditOAuth2Callback
|
||||
redirect_view = RedditOAuthRedirect
|
||||
name = "Reddit"
|
||||
slug = "reddit"
|
||||
verbose_name = "Reddit"
|
||||
name = "reddit"
|
||||
|
||||
authorization_url = "https://www.reddit.com/api/v1/authorize"
|
||||
access_token_url = "https://www.reddit.com/api/v1/access_token" # nosec
|
||||
|
|
|
@ -28,7 +28,7 @@ class SourceType:
|
|||
callback_view = OAuthCallback
|
||||
redirect_view = OAuthRedirect
|
||||
name: str = "default"
|
||||
slug: str = "default"
|
||||
verbose_name: str = "Default source type"
|
||||
|
||||
urls_customizable = False
|
||||
|
||||
|
@ -41,7 +41,7 @@ class SourceType:
|
|||
|
||||
def icon_url(self) -> str:
|
||||
"""Get Icon URL for login"""
|
||||
return static(f"authentik/sources/{self.slug}.svg")
|
||||
return static(f"authentik/sources/{self.name}.svg")
|
||||
|
||||
def login_challenge(self, source: OAuthSource, request: HttpRequest) -> Challenge:
|
||||
"""Allow types to return custom challenges"""
|
||||
|
@ -77,20 +77,20 @@ class SourceTypeRegistry:
|
|||
|
||||
def get_name_tuple(self):
|
||||
"""Get list of tuples of all registered names"""
|
||||
return [(x.slug, x.name) for x in self.__sources]
|
||||
return [(x.name, x.verbose_name) for x in self.__sources]
|
||||
|
||||
def find_type(self, type_name: str) -> Type[SourceType]:
|
||||
"""Find type based on source"""
|
||||
found_type = None
|
||||
for src_type in self.__sources:
|
||||
if src_type.slug == type_name:
|
||||
if src_type.name == type_name:
|
||||
return src_type
|
||||
if not found_type:
|
||||
found_type = SourceType
|
||||
LOGGER.warning(
|
||||
"no matching type found, using default",
|
||||
wanted=type_name,
|
||||
have=[x.slug for x in self.__sources],
|
||||
have=[x.name for x in self.__sources],
|
||||
)
|
||||
return found_type
|
||||
|
||||
|
|
|
@ -49,8 +49,8 @@ class TwitchType(SourceType):
|
|||
|
||||
callback_view = TwitchOAuth2Callback
|
||||
redirect_view = TwitchOAuthRedirect
|
||||
name = "Twitch"
|
||||
slug = "twitch"
|
||||
verbose_name = "Twitch"
|
||||
name = "twitch"
|
||||
|
||||
authorization_url = "https://id.twitch.tv/oauth2/authorize"
|
||||
access_token_url = "https://id.twitch.tv/oauth2/token" # nosec
|
||||
|
|
|
@ -66,8 +66,8 @@ class TwitterType(SourceType):
|
|||
|
||||
callback_view = TwitterOAuthCallback
|
||||
redirect_view = TwitterOAuthRedirect
|
||||
name = "Twitter"
|
||||
slug = "twitter"
|
||||
verbose_name = "Twitter"
|
||||
name = "twitter"
|
||||
|
||||
authorization_url = "https://twitter.com/i/oauth2/authorize"
|
||||
access_token_url = "https://api.twitter.com/2/oauth2/token" # nosec
|
||||
|
|
|
@ -2779,6 +2779,117 @@
|
|||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "object",
|
||||
"required": [
|
||||
"model",
|
||||
"identifiers"
|
||||
],
|
||||
"properties": {
|
||||
"model": {
|
||||
"const": "authentik_providers_rac.racprovider"
|
||||
},
|
||||
"id": {
|
||||
"type": "string"
|
||||
},
|
||||
"state": {
|
||||
"type": "string",
|
||||
"enum": [
|
||||
"absent",
|
||||
"present",
|
||||
"created",
|
||||
"must_created"
|
||||
],
|
||||
"default": "present"
|
||||
},
|
||||
"conditions": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "boolean"
|
||||
}
|
||||
},
|
||||
"attrs": {
|
||||
"$ref": "#/$defs/model_authentik_providers_rac.racprovider"
|
||||
},
|
||||
"identifiers": {
|
||||
"$ref": "#/$defs/model_authentik_providers_rac.racprovider"
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "object",
|
||||
"required": [
|
||||
"model",
|
||||
"identifiers"
|
||||
],
|
||||
"properties": {
|
||||
"model": {
|
||||
"const": "authentik_providers_rac.endpoint"
|
||||
},
|
||||
"id": {
|
||||
"type": "string"
|
||||
},
|
||||
"state": {
|
||||
"type": "string",
|
||||
"enum": [
|
||||
"absent",
|
||||
"present",
|
||||
"created",
|
||||
"must_created"
|
||||
],
|
||||
"default": "present"
|
||||
},
|
||||
"conditions": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "boolean"
|
||||
}
|
||||
},
|
||||
"attrs": {
|
||||
"$ref": "#/$defs/model_authentik_providers_rac.endpoint"
|
||||
},
|
||||
"identifiers": {
|
||||
"$ref": "#/$defs/model_authentik_providers_rac.endpoint"
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "object",
|
||||
"required": [
|
||||
"model",
|
||||
"identifiers"
|
||||
],
|
||||
"properties": {
|
||||
"model": {
|
||||
"const": "authentik_providers_rac.racpropertymapping"
|
||||
},
|
||||
"id": {
|
||||
"type": "string"
|
||||
},
|
||||
"state": {
|
||||
"type": "string",
|
||||
"enum": [
|
||||
"absent",
|
||||
"present",
|
||||
"created",
|
||||
"must_created"
|
||||
],
|
||||
"default": "present"
|
||||
},
|
||||
"conditions": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "boolean"
|
||||
}
|
||||
},
|
||||
"attrs": {
|
||||
"$ref": "#/$defs/model_authentik_providers_rac.racpropertymapping"
|
||||
},
|
||||
"identifiers": {
|
||||
"$ref": "#/$defs/model_authentik_providers_rac.racpropertymapping"
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "object",
|
||||
"required": [
|
||||
|
@ -3296,7 +3407,8 @@
|
|||
"enum": [
|
||||
"proxy",
|
||||
"ldap",
|
||||
"radius"
|
||||
"radius",
|
||||
"rac"
|
||||
],
|
||||
"title": "Type"
|
||||
},
|
||||
|
@ -3476,7 +3588,8 @@
|
|||
"authentik.tenants",
|
||||
"authentik.blueprints",
|
||||
"authentik.core",
|
||||
"authentik.enterprise"
|
||||
"authentik.enterprise",
|
||||
"authentik.enterprise.providers.rac"
|
||||
],
|
||||
"title": "App",
|
||||
"description": "Match events created by selected application. When left empty, all applications are matched."
|
||||
|
@ -3561,7 +3674,10 @@
|
|||
"authentik_core.user",
|
||||
"authentik_core.application",
|
||||
"authentik_core.token",
|
||||
"authentik_enterprise.license"
|
||||
"authentik_enterprise.license",
|
||||
"authentik_providers_rac.racprovider",
|
||||
"authentik_providers_rac.endpoint",
|
||||
"authentik_providers_rac.racpropertymapping"
|
||||
],
|
||||
"title": "Model",
|
||||
"description": "Match events created by selected model. When left empty, all models are matched. When an app is selected, all the application's models are matched."
|
||||
|
@ -8758,6 +8874,123 @@
|
|||
},
|
||||
"required": []
|
||||
},
|
||||
"model_authentik_providers_rac.racprovider": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"name": {
|
||||
"type": "string",
|
||||
"minLength": 1,
|
||||
"title": "Name"
|
||||
},
|
||||
"authentication_flow": {
|
||||
"type": "integer",
|
||||
"title": "Authentication flow",
|
||||
"description": "Flow used for authentication when the associated application is accessed by an un-authenticated user."
|
||||
},
|
||||
"authorization_flow": {
|
||||
"type": "integer",
|
||||
"title": "Authorization flow",
|
||||
"description": "Flow used when authorizing this provider."
|
||||
},
|
||||
"property_mappings": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "integer"
|
||||
},
|
||||
"title": "Property mappings"
|
||||
},
|
||||
"settings": {
|
||||
"type": "object",
|
||||
"additionalProperties": true,
|
||||
"title": "Settings"
|
||||
},
|
||||
"connection_expiry": {
|
||||
"type": "string",
|
||||
"minLength": 1,
|
||||
"title": "Connection expiry",
|
||||
"description": "Determines how long a session lasts. Default of 0 means that the sessions lasts until the browser is closed. (Format: hours=-1;minutes=-2;seconds=-3)"
|
||||
}
|
||||
},
|
||||
"required": []
|
||||
},
|
||||
"model_authentik_providers_rac.endpoint": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"name": {
|
||||
"type": "string",
|
||||
"minLength": 1,
|
||||
"title": "Name"
|
||||
},
|
||||
"provider": {
|
||||
"type": "integer",
|
||||
"title": "Provider"
|
||||
},
|
||||
"protocol": {
|
||||
"type": "string",
|
||||
"enum": [
|
||||
"rdp",
|
||||
"vnc",
|
||||
"ssh"
|
||||
],
|
||||
"title": "Protocol"
|
||||
},
|
||||
"host": {
|
||||
"type": "string",
|
||||
"minLength": 1,
|
||||
"title": "Host"
|
||||
},
|
||||
"settings": {
|
||||
"type": "object",
|
||||
"additionalProperties": true,
|
||||
"title": "Settings"
|
||||
},
|
||||
"property_mappings": {
|
||||
"type": "array",
|
||||
"items": {
|
||||
"type": "integer"
|
||||
},
|
||||
"title": "Property mappings"
|
||||
},
|
||||
"auth_mode": {
|
||||
"type": "string",
|
||||
"enum": [
|
||||
"static",
|
||||
"prompt"
|
||||
],
|
||||
"title": "Auth mode"
|
||||
}
|
||||
},
|
||||
"required": []
|
||||
},
|
||||
"model_authentik_providers_rac.racpropertymapping": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
"managed": {
|
||||
"type": [
|
||||
"string",
|
||||
"null"
|
||||
],
|
||||
"minLength": 1,
|
||||
"title": "Managed by authentik",
|
||||
"description": "Objects that are managed by authentik. These objects are created and updated automatically. This flag only indicates that an object can be overwritten by migrations. You can still modify the objects via the API, but expect changes to be overwritten in a later update."
|
||||
},
|
||||
"name": {
|
||||
"type": "string",
|
||||
"minLength": 1,
|
||||
"title": "Name"
|
||||
},
|
||||
"expression": {
|
||||
"type": "string",
|
||||
"title": "Expression"
|
||||
},
|
||||
"static_settings": {
|
||||
"type": "object",
|
||||
"additionalProperties": true,
|
||||
"title": "Static settings"
|
||||
}
|
||||
},
|
||||
"required": []
|
||||
},
|
||||
"model_authentik_blueprints.metaapplyblueprint": {
|
||||
"type": "object",
|
||||
"properties": {
|
||||
|
|
32
blueprints/system/providers-rac.yaml
Normal file
32
blueprints/system/providers-rac.yaml
Normal file
|
@ -0,0 +1,32 @@
|
|||
version: 1
|
||||
metadata:
|
||||
labels:
|
||||
blueprints.goauthentik.io/system: "true"
|
||||
name: System - RAC Provider - Mappings
|
||||
entries:
|
||||
- identifiers:
|
||||
managed: goauthentik.io/providers/rac/rdp-default
|
||||
model: authentik_providers_rac.racpropertymapping
|
||||
attrs:
|
||||
name: "authentik default RAC Mapping: RDP Default settings"
|
||||
static_settings:
|
||||
resize-method: "display-update"
|
||||
enable-wallpaper: "true"
|
||||
enable-font-smoothing: "true"
|
||||
- identifiers:
|
||||
managed: goauthentik.io/providers/rac/rdp-high-fidelity
|
||||
model: authentik_providers_rac.racpropertymapping
|
||||
attrs:
|
||||
name: "authentik default RAC Mapping: RDP High Fidelity"
|
||||
static_settings:
|
||||
enable-theming: "true"
|
||||
enable-full-window-drag: "true"
|
||||
enable-desktop-composition: "true"
|
||||
enable-menu-animations: "true"
|
||||
- identifiers:
|
||||
managed: goauthentik.io/providers/rac/ssh-default
|
||||
model: authentik_providers_rac.racpropertymapping
|
||||
attrs:
|
||||
name: "authentik default RAC Mapping: SSH Default settings"
|
||||
static_settings:
|
||||
terminal-type: "xterm-256color"
|
93
cmd/rac/main.go
Normal file
93
cmd/rac/main.go
Normal file
|
@ -0,0 +1,93 @@
|
|||
package main
|
||||
|
||||
import (
|
||||
"fmt"
|
||||
"net/url"
|
||||
"os"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/spf13/cobra"
|
||||
|
||||
"goauthentik.io/internal/common"
|
||||
"goauthentik.io/internal/debug"
|
||||
"goauthentik.io/internal/outpost/ak"
|
||||
"goauthentik.io/internal/outpost/ak/healthcheck"
|
||||
"goauthentik.io/internal/outpost/rac"
|
||||
)
|
||||
|
||||
const helpMessage = `authentik RAC
|
||||
|
||||
Required environment variables:
|
||||
- AUTHENTIK_HOST: URL to connect to (format "http://authentik.company")
|
||||
- AUTHENTIK_TOKEN: Token to authenticate with
|
||||
- AUTHENTIK_INSECURE: Skip SSL Certificate verification`
|
||||
|
||||
var rootCmd = &cobra.Command{
|
||||
Long: helpMessage,
|
||||
PersistentPreRun: func(cmd *cobra.Command, args []string) {
|
||||
log.SetLevel(log.DebugLevel)
|
||||
log.SetFormatter(&log.JSONFormatter{
|
||||
FieldMap: log.FieldMap{
|
||||
log.FieldKeyMsg: "event",
|
||||
log.FieldKeyTime: "timestamp",
|
||||
},
|
||||
DisableHTMLEscape: true,
|
||||
})
|
||||
},
|
||||
Run: func(cmd *cobra.Command, args []string) {
|
||||
debug.EnableDebugServer()
|
||||
akURL, found := os.LookupEnv("AUTHENTIK_HOST")
|
||||
if !found {
|
||||
fmt.Println("env AUTHENTIK_HOST not set!")
|
||||
fmt.Println(helpMessage)
|
||||
os.Exit(1)
|
||||
}
|
||||
akToken, found := os.LookupEnv("AUTHENTIK_TOKEN")
|
||||
if !found {
|
||||
fmt.Println("env AUTHENTIK_TOKEN not set!")
|
||||
fmt.Println(helpMessage)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
akURLActual, err := url.Parse(akURL)
|
||||
if err != nil {
|
||||
fmt.Println(err)
|
||||
fmt.Println(helpMessage)
|
||||
os.Exit(1)
|
||||
}
|
||||
|
||||
ex := common.Init()
|
||||
defer common.Defer()
|
||||
go func() {
|
||||
for {
|
||||
<-ex
|
||||
os.Exit(0)
|
||||
}
|
||||
}()
|
||||
|
||||
ac := ak.NewAPIController(*akURLActual, akToken)
|
||||
if ac == nil {
|
||||
os.Exit(1)
|
||||
}
|
||||
defer ac.Shutdown()
|
||||
|
||||
ac.Server = rac.NewServer(ac)
|
||||
|
||||
err = ac.Start()
|
||||
if err != nil {
|
||||
log.WithError(err).Panic("Failed to run server")
|
||||
}
|
||||
|
||||
for {
|
||||
<-ex
|
||||
}
|
||||
},
|
||||
}
|
||||
|
||||
func main() {
|
||||
rootCmd.AddCommand(healthcheck.Command)
|
||||
err := rootCmd.Execute()
|
||||
if err != nil {
|
||||
os.Exit(1)
|
||||
}
|
||||
}
|
15
go.mod
15
go.mod
|
@ -10,7 +10,7 @@ require (
|
|||
github.com/go-http-utils/etag v0.0.0-20161124023236-513ea8f21eb1
|
||||
github.com/go-ldap/ldap/v3 v3.4.6
|
||||
github.com/go-openapi/runtime v0.26.2
|
||||
github.com/go-openapi/strfmt v0.21.10
|
||||
github.com/go-openapi/strfmt v0.22.0
|
||||
github.com/golang-jwt/jwt v3.2.2+incompatible
|
||||
github.com/google/uuid v1.5.0
|
||||
github.com/gorilla/handlers v1.5.2
|
||||
|
@ -22,12 +22,13 @@ require (
|
|||
github.com/mitchellh/mapstructure v1.5.0
|
||||
github.com/nmcclain/asn1-ber v0.0.0-20170104154839-2661553a0484
|
||||
github.com/pires/go-proxyproto v0.7.0
|
||||
github.com/prometheus/client_golang v1.17.0
|
||||
github.com/prometheus/client_golang v1.18.0
|
||||
github.com/redis/go-redis/v9 v9.3.1
|
||||
github.com/sirupsen/logrus v1.9.3
|
||||
github.com/spf13/cobra v1.8.0
|
||||
github.com/stretchr/testify v1.8.4
|
||||
goauthentik.io/api/v3 v3.2023105.2
|
||||
github.com/wwt/guac v1.3.2
|
||||
goauthentik.io/api/v3 v3.2023105.3
|
||||
golang.org/x/exp v0.0.0-20230210204819-062eb4c674ab
|
||||
golang.org/x/oauth2 v0.15.0
|
||||
golang.org/x/sync v0.5.0
|
||||
|
@ -60,14 +61,14 @@ require (
|
|||
github.com/inconshreveable/mousetrap v1.1.0 // indirect
|
||||
github.com/josharian/intern v1.0.0 // indirect
|
||||
github.com/mailru/easyjson v0.7.7 // indirect
|
||||
github.com/matttproud/golang_protobuf_extensions v1.0.4 // indirect
|
||||
github.com/matttproud/golang_protobuf_extensions/v2 v2.0.0 // indirect
|
||||
github.com/oklog/ulid v1.3.1 // indirect
|
||||
github.com/opentracing/opentracing-go v1.2.0 // indirect
|
||||
github.com/pmezard/go-difflib v1.0.0 // indirect
|
||||
github.com/pquerna/cachecontrol v0.0.0-20201205024021-ac21108117ac // indirect
|
||||
github.com/prometheus/client_model v0.4.1-0.20230718164431-9a2bf3000d16 // indirect
|
||||
github.com/prometheus/common v0.44.0 // indirect
|
||||
github.com/prometheus/procfs v0.11.1 // indirect
|
||||
github.com/prometheus/client_model v0.5.0 // indirect
|
||||
github.com/prometheus/common v0.45.0 // indirect
|
||||
github.com/prometheus/procfs v0.12.0 // indirect
|
||||
github.com/spf13/pflag v1.0.5 // indirect
|
||||
go.mongodb.org/mongo-driver v1.13.1 // indirect
|
||||
go.opentelemetry.io/otel v1.17.0 // indirect
|
||||
|
|
36
go.sum
36
go.sum
|
@ -116,8 +116,8 @@ github.com/go-openapi/spec v0.20.6/go.mod h1:2OpW+JddWPrpXSCIX8eOx7lZ5iyuWj3RYR6
|
|||
github.com/go-openapi/spec v0.20.11 h1:J/TzFDLTt4Rcl/l1PmyErvkqlJDncGvPTMnCI39I4gY=
|
||||
github.com/go-openapi/spec v0.20.11/go.mod h1:2OpW+JddWPrpXSCIX8eOx7lZ5iyuWj3RYR6VaaBKcWA=
|
||||
github.com/go-openapi/strfmt v0.21.3/go.mod h1:k+RzNO0Da+k3FrrynSNN8F7n/peCmQQqbbXjtDfvmGg=
|
||||
github.com/go-openapi/strfmt v0.21.10 h1:JIsly3KXZB/Qf4UzvzJpg4OELH/0ASDQsyk//TTBDDk=
|
||||
github.com/go-openapi/strfmt v0.21.10/go.mod h1:vNDMwbilnl7xKiO/Ve/8H8Bb2JIInBnH+lqiw6QWgis=
|
||||
github.com/go-openapi/strfmt v0.22.0 h1:Ew9PnEYc246TwrEspvBdDHS4BVKXy/AOVsfqGDgAcaI=
|
||||
github.com/go-openapi/strfmt v0.22.0/go.mod h1:HzJ9kokGIju3/K6ap8jL+OlGAbjpSv27135Yr9OivU4=
|
||||
github.com/go-openapi/swag v0.19.5/go.mod h1:POnQmlKehdgb5mhVOsnJFsivZCEZ/vjK9gh66Z9tfKk=
|
||||
github.com/go-openapi/swag v0.19.15/go.mod h1:QYRuS/SOXUCsnplDa677K7+DxSOj6IPNl/eQntq43wQ=
|
||||
github.com/go-openapi/swag v0.21.1/go.mod h1:QYRuS/SOXUCsnplDa677K7+DxSOj6IPNl/eQntq43wQ=
|
||||
|
@ -195,6 +195,7 @@ github.com/gorilla/securecookie v1.1.2 h1:YCIWL56dvtr73r6715mJs5ZvhtnY73hBvEF8kX
|
|||
github.com/gorilla/securecookie v1.1.2/go.mod h1:NfCASbcHqRSY+3a8tlWJwsQap2VX5pwzwo4h3eOamfo=
|
||||
github.com/gorilla/sessions v1.2.2 h1:lqzMYz6bOfvn2WriPUjNByzeXIlVzURcPmgMczkmTjY=
|
||||
github.com/gorilla/sessions v1.2.2/go.mod h1:ePLdVu+jbEgHH+KWw8I1z2wqd0BAdAQh/8LRvBeoNcQ=
|
||||
github.com/gorilla/websocket v1.4.1/go.mod h1:YR8l580nyteQvAITg2hZ9XVh4b55+EU/adAjf1fMHhE=
|
||||
github.com/gorilla/websocket v1.5.1 h1:gmztn0JnHVt9JZquRuzLw3g4wouNVzKL15iLr/zn/QY=
|
||||
github.com/gorilla/websocket v1.5.1/go.mod h1:x3kM2JMyaluk02fnUJpQuwD2dCS5NDG2ZHL0uE0tcaY=
|
||||
github.com/hashicorp/golang-lru v0.5.0/go.mod h1:/m3WP610KZHVQ1SGc6re/UDhFvYD7pJ4Ao+sR/qLZy8=
|
||||
|
@ -210,6 +211,7 @@ github.com/jstemmer/go-junit-report v0.0.0-20190106144839-af01ea7f8024/go.mod h1
|
|||
github.com/jstemmer/go-junit-report v0.9.1/go.mod h1:Brl9GWCQeLvo8nXZwPNNblvFj/XSXhF0NWZEnDohbsk=
|
||||
github.com/kisielk/gotool v1.0.0/go.mod h1:XhKaO+MFFWcvkIS/tQcRk01m1F5IRFswLeQ+oQHNcck=
|
||||
github.com/klauspost/compress v1.13.6/go.mod h1:/3/Vjq9QcHkK5uEr5lBEmyoZ1iFhe47etQ6QUkpK6sk=
|
||||
github.com/konsorten/go-windows-terminal-sequences v1.0.1/go.mod h1:T0+1ngSBFLxvqU3pZ+m/2kptfBszLMUkC4ZK/EgS/cQ=
|
||||
github.com/kr/pretty v0.1.0/go.mod h1:dAy3ld7l9f0ibDNOQOHHMYYIIbhfbHSm3C4ZsoJORNo=
|
||||
github.com/kr/pretty v0.2.1/go.mod h1:ipq/a2n7PKx3OHsz4KJII5eveXtPO4qwEXGdVfWzfnI=
|
||||
github.com/kr/pretty v0.3.1 h1:flRD4NNwYAUpkphVc1HcthR4KEIFJ65n8Mw5qdRn3LE=
|
||||
|
@ -223,8 +225,8 @@ github.com/mailru/easyjson v0.0.0-20190626092158-b2ccc519800e/go.mod h1:C1wdFJiN
|
|||
github.com/mailru/easyjson v0.7.6/go.mod h1:xzfreul335JAWq5oZzymOObrkdz5UnU4kGfJJLY9Nlc=
|
||||
github.com/mailru/easyjson v0.7.7 h1:UGYAvKxe3sBsEDzO8ZeWOSlIQfWFlxbzLZe7hwFURr0=
|
||||
github.com/mailru/easyjson v0.7.7/go.mod h1:xzfreul335JAWq5oZzymOObrkdz5UnU4kGfJJLY9Nlc=
|
||||
github.com/matttproud/golang_protobuf_extensions v1.0.4 h1:mmDVorXM7PCGKw94cs5zkfA9PSy5pEvNWRP0ET0TIVo=
|
||||
github.com/matttproud/golang_protobuf_extensions v1.0.4/go.mod h1:BSXmuO+STAnVfrANrmjBb36TMTDstsz7MSK+HVaYKv4=
|
||||
github.com/matttproud/golang_protobuf_extensions/v2 v2.0.0 h1:jWpvCLoY8Z/e3VKvlsiIGKtc+UG6U5vzxaoagmhXfyg=
|
||||
github.com/matttproud/golang_protobuf_extensions/v2 v2.0.0/go.mod h1:QUyp042oQthUoa9bqDv0ER0wrtXnBruoNd7aNjkbP+k=
|
||||
github.com/mitchellh/mapstructure v1.3.3/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo=
|
||||
github.com/mitchellh/mapstructure v1.4.1/go.mod h1:bFUtVrKA4DC2yAKiSyO/QUcy7e+RRV2QTWOzhPopBRo=
|
||||
github.com/mitchellh/mapstructure v1.5.0 h1:jeMsZIYE/09sWLaz43PL7Gy6RuMjD2eJVyuac5Z2hdY=
|
||||
|
@ -247,21 +249,22 @@ github.com/pmezard/go-difflib v1.0.0 h1:4DBwDE0NGyQoBHbLQYPwSUPoCMWR5BEzIk/f1lZb
|
|||
github.com/pmezard/go-difflib v1.0.0/go.mod h1:iKH77koFhYxTK1pcRnkKkqfTogsbg7gZNVY4sRDYZ/4=
|
||||
github.com/pquerna/cachecontrol v0.0.0-20201205024021-ac21108117ac h1:jWKYCNlX4J5s8M0nHYkh7Y7c9gRVDEb3mq51j5J0F5M=
|
||||
github.com/pquerna/cachecontrol v0.0.0-20201205024021-ac21108117ac/go.mod h1:hoLfEwdY11HjRfKFH6KqnPsfxlo3BP6bJehpDv8t6sQ=
|
||||
github.com/prometheus/client_golang v1.17.0 h1:rl2sfwZMtSthVU752MqfjQozy7blglC+1SOtjMAMh+Q=
|
||||
github.com/prometheus/client_golang v1.17.0/go.mod h1:VeL+gMmOAxkS2IqfCq0ZmHSL+LjWfWDUmp1mBz9JgUY=
|
||||
github.com/prometheus/client_golang v1.18.0 h1:HzFfmkOzH5Q8L8G+kSJKUx5dtG87sewO+FoDDqP5Tbk=
|
||||
github.com/prometheus/client_golang v1.18.0/go.mod h1:T+GXkCk5wSJyOqMIzVgvvjFDlkOQntgjkJWKrN5txjA=
|
||||
github.com/prometheus/client_model v0.0.0-20190812154241-14fe0d1b01d4/go.mod h1:xMI15A0UPsDsEKsMN9yxemIoYk6Tm2C1GtYGdfGttqA=
|
||||
github.com/prometheus/client_model v0.4.1-0.20230718164431-9a2bf3000d16 h1:v7DLqVdK4VrYkVD5diGdl4sxJurKJEMnODWRJlxV9oM=
|
||||
github.com/prometheus/client_model v0.4.1-0.20230718164431-9a2bf3000d16/go.mod h1:oMQmHW1/JoDwqLtg57MGgP/Fb1CJEYF2imWWhWtMkYU=
|
||||
github.com/prometheus/common v0.44.0 h1:+5BrQJwiBB9xsMygAB3TNvpQKOwlkc25LbISbrdOOfY=
|
||||
github.com/prometheus/common v0.44.0/go.mod h1:ofAIvZbQ1e/nugmZGz4/qCb9Ap1VoSTIO7x0VV9VvuY=
|
||||
github.com/prometheus/procfs v0.11.1 h1:xRC8Iq1yyca5ypa9n1EZnWZkt7dwcoRPQwX/5gwaUuI=
|
||||
github.com/prometheus/procfs v0.11.1/go.mod h1:eesXgaPo1q7lBpVMoMy0ZOFTth9hBn4W/y0/p/ScXhY=
|
||||
github.com/prometheus/client_model v0.5.0 h1:VQw1hfvPvk3Uv6Qf29VrPF32JB6rtbgI6cYPYQjL0Qw=
|
||||
github.com/prometheus/client_model v0.5.0/go.mod h1:dTiFglRmd66nLR9Pv9f0mZi7B7fk5Pm3gvsjB5tr+kI=
|
||||
github.com/prometheus/common v0.45.0 h1:2BGz0eBc2hdMDLnO/8n0jeB3oPrt2D08CekT0lneoxM=
|
||||
github.com/prometheus/common v0.45.0/go.mod h1:YJmSTw9BoKxJplESWWxlbyttQR4uaEcGyv9MZjVOJsY=
|
||||
github.com/prometheus/procfs v0.12.0 h1:jluTpSng7V9hY0O2R9DzzJHYb2xULk9VTR1V1R/k6Bo=
|
||||
github.com/prometheus/procfs v0.12.0/go.mod h1:pcuDEFsWDnvcgNzo4EEweacyhjeA9Zk3cnaOZAZEfOo=
|
||||
github.com/redis/go-redis/v9 v9.3.1 h1:KqdY8U+3X6z+iACvumCNxnoluToB+9Me+TvyFa21Mds=
|
||||
github.com/redis/go-redis/v9 v9.3.1/go.mod h1:hdY0cQFCN4fnSYT6TkisLufl/4W5UIXyv0b/CLO2V2M=
|
||||
github.com/rogpeppe/go-internal v1.3.0/go.mod h1:M8bDsm7K2OlrFYOpmOWEs/qY81heoFRclV5y23lUDJ4=
|
||||
github.com/rogpeppe/go-internal v1.11.0 h1:cWPaGQEPrBb5/AsnsZesgZZ9yb1OQ+GOISoDNXVBh4M=
|
||||
github.com/rogpeppe/go-internal v1.11.0/go.mod h1:ddIwULY96R17DhadqLgMfk9H9tvdUzkipdSkR5nkCZA=
|
||||
github.com/russross/blackfriday/v2 v2.1.0/go.mod h1:+Rmxgy9KzJVeS9/2gXHxylqXiyQDYRxCVz55jmeOWTM=
|
||||
github.com/sirupsen/logrus v1.4.2/go.mod h1:tLMulIdttU9McNUspp0xgXVQah82FyeX6MwdIuYE2rE=
|
||||
github.com/sirupsen/logrus v1.9.3 h1:dueUQJ1C2q9oE3F7wvmSGAaVtTmUizReu6fjN8uqzbQ=
|
||||
github.com/sirupsen/logrus v1.9.3/go.mod h1:naHLuLoDiP4jHNo9R0sCBMtWGeIprob74mVsIT4qYEQ=
|
||||
github.com/spf13/cobra v1.8.0 h1:7aJaZx1B85qltLMc546zn58BxxfZdR/W22ej9CFoEf0=
|
||||
|
@ -269,8 +272,10 @@ github.com/spf13/cobra v1.8.0/go.mod h1:WXLWApfZ71AjXPya3WOlMsY9yMs7YeiHhFVlvLyh
|
|||
github.com/spf13/pflag v1.0.5 h1:iy+VFUOCP1a+8yFto/drg2CJ5u0yRoB7fZw3DKv/JXA=
|
||||
github.com/spf13/pflag v1.0.5/go.mod h1:McXfInJRrz4CZXVZOBLb0bTZqETkiAhM9Iw0y3An2Bg=
|
||||
github.com/stretchr/objx v0.1.0/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||
github.com/stretchr/objx v0.1.1/go.mod h1:HFkY916IF+rwdDfMAkV7OtwuqBVzrE8GR6GFx+wExME=
|
||||
github.com/stretchr/objx v0.4.0/go.mod h1:YvHI0jy2hoMjB+UWwv71VJQ9isScKT/TqJzVSSt89Yw=
|
||||
github.com/stretchr/objx v0.5.0/go.mod h1:Yh+to48EsGEfYuaHDzXPcE3xhTkx73EhmCGUpEOglKo=
|
||||
github.com/stretchr/testify v1.2.2/go.mod h1:a8OnRcib4nhh0OaRAV+Yts87kKdq0PP7pXfy6kDkUVs=
|
||||
github.com/stretchr/testify v1.3.0/go.mod h1:M5WIy9Dh21IEIfnGCwXGc5bZfKNJtfHm1UVUgZn+9EI=
|
||||
github.com/stretchr/testify v1.4.0/go.mod h1:j7eGeouHqKxXV5pUuKE4zz7dFj8WfuZ+81PSLYec5m4=
|
||||
github.com/stretchr/testify v1.6.1/go.mod h1:6Fq8oRcR53rry900zMqJjRRixrwX3KX962/h/Wwjteg=
|
||||
|
@ -281,6 +286,8 @@ github.com/stretchr/testify v1.8.1/go.mod h1:w2LPCIKwWwSfY2zedu0+kehJoqGctiVI29o
|
|||
github.com/stretchr/testify v1.8.4 h1:CcVxjf3Q8PM0mHUKJCdn+eZZtm5yQwehR5yeSVQQcUk=
|
||||
github.com/stretchr/testify v1.8.4/go.mod h1:sz/lmYIOXD/1dqDmKjjqLyZ2RngseejIcXlSw2iwfAo=
|
||||
github.com/tidwall/pretty v1.0.0/go.mod h1:XNkn88O1ChpSDQmQeStsy+sBenx6DDtFZJxhVysOjyk=
|
||||
github.com/wwt/guac v1.3.2 h1:sH6OFGa/1tBs7ieWBVlZe7t6F5JAOWBry/tqQL/Vup4=
|
||||
github.com/wwt/guac v1.3.2/go.mod h1:eKm+NrnK7A88l4UBEcYNpZQGMpZRryYKoz4D/0/n1C0=
|
||||
github.com/xdg-go/pbkdf2 v1.0.0/go.mod h1:jrpuAogTd400dnrH08LKmI/xc1MbPOebTwRqcT5RDeI=
|
||||
github.com/xdg-go/scram v1.1.1/go.mod h1:RaEWvsqvNKKvBPvcKeFjrG2cJqOkHTiyTpzz23ni57g=
|
||||
github.com/xdg-go/scram v1.1.2/go.mod h1:RT/sEzTbU5y00aCK8UOx6R7YryM0iF1N2MOmC3kKLN4=
|
||||
|
@ -309,8 +316,8 @@ go.opentelemetry.io/otel/trace v1.17.0 h1:/SWhSRHmDPOImIAetP1QAeMnZYiQXrTy4fMMYO
|
|||
go.opentelemetry.io/otel/trace v1.17.0/go.mod h1:I/4vKTgFclIsXRVucpH25X0mpFSczM7aHeaz0ZBLWjY=
|
||||
go.uber.org/goleak v1.2.1 h1:NBol2c7O1ZokfZ0LEU9K6Whx/KnwvepVetCUhtKja4A=
|
||||
go.uber.org/goleak v1.2.1/go.mod h1:qlT2yGI9QafXHhZZLxlSuNsMw3FFLxBr+tBRlmO1xH4=
|
||||
goauthentik.io/api/v3 v3.2023105.2 h1:ZUblqN5LidnCSlEZ/L19h7OnwppnAA3m5AGC7wUN0Ew=
|
||||
goauthentik.io/api/v3 v3.2023105.2/go.mod h1:zz+mEZg8rY/7eEjkMGWJ2DnGqk+zqxuybGCGrR2O4Kw=
|
||||
goauthentik.io/api/v3 v3.2023105.3 h1:x0pMJIKkbN198OOssqA94h8bO6ft9gwG8bpZqZL7WVg=
|
||||
goauthentik.io/api/v3 v3.2023105.3/go.mod h1:zz+mEZg8rY/7eEjkMGWJ2DnGqk+zqxuybGCGrR2O4Kw=
|
||||
golang.org/x/crypto v0.0.0-20190308221718-c2843e01d9a2/go.mod h1:djNgcEr1/C05ACkg1iLfiJU5Ep61QUkGW8qpdssI0+w=
|
||||
golang.org/x/crypto v0.0.0-20190510104115-cbcb75029529/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
|
||||
golang.org/x/crypto v0.0.0-20190605123033-f99c8df09eb5/go.mod h1:yigFU9vqHzYiE8UmvKecakEJjdnWj3jj499lnFckfCI=
|
||||
|
@ -414,6 +421,7 @@ golang.org/x/sys v0.0.0-20180830151530-49385e6e1522/go.mod h1:STP8DvDyc/dI5b8T5h
|
|||
golang.org/x/sys v0.0.0-20190215142949-d0b11bdaac8a/go.mod h1:STP8DvDyc/dI5b8T5hshtkjS+E42TnysNCUPdjciGhY=
|
||||
golang.org/x/sys v0.0.0-20190312061237-fead79001313/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20190412213103-97732733099d/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20190422165155-953cdadca894/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20190502145724-3ef323f4f1fd/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20190507160741-ecd444e8653b/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
golang.org/x/sys v0.0.0-20190606165138-5da285871e9c/go.mod h1:h1NjWce9XRLGQEsW7wpKNCjG9DtNlClVuFLEZdDNbEs=
|
||||
|
|
|
@ -159,8 +159,8 @@ func (a *APIController) AddRefreshHandler(handler func()) {
|
|||
a.refreshHandlers = append(a.refreshHandlers, handler)
|
||||
}
|
||||
|
||||
func (a *APIController) AddWSHandler(handler WSHandler) {
|
||||
a.wsHandlers = append(a.wsHandlers, handler)
|
||||
func (a *APIController) Token() string {
|
||||
return a.token
|
||||
}
|
||||
|
||||
func (a *APIController) OnRefresh() error {
|
||||
|
@ -182,7 +182,7 @@ func (a *APIController) OnRefresh() error {
|
|||
return err
|
||||
}
|
||||
|
||||
func (a *APIController) getWebsocketArgs() map[string]interface{} {
|
||||
func (a *APIController) getWebsocketPingArgs() map[string]interface{} {
|
||||
args := map[string]interface{}{
|
||||
"version": constants.VERSION,
|
||||
"buildHash": constants.BUILD("tagged"),
|
||||
|
|
|
@ -18,6 +18,8 @@ import (
|
|||
|
||||
func (ac *APIController) initWS(akURL url.URL, outpostUUID string) error {
|
||||
pathTemplate := "%s://%s/ws/outpost/%s/?%s"
|
||||
query := akURL.Query()
|
||||
query.Set("instance_uuid", ac.instanceUUID.String())
|
||||
scheme := strings.ReplaceAll(akURL.Scheme, "http", "ws")
|
||||
|
||||
authHeader := fmt.Sprintf("Bearer %s", ac.token)
|
||||
|
@ -45,7 +47,7 @@ func (ac *APIController) initWS(akURL url.URL, outpostUUID string) error {
|
|||
// Send hello message with our version
|
||||
msg := websocketMessage{
|
||||
Instruction: WebsocketInstructionHello,
|
||||
Args: ac.getWebsocketArgs(),
|
||||
Args: ac.getWebsocketPingArgs(),
|
||||
}
|
||||
err = ws.WriteJSON(msg)
|
||||
if err != nil {
|
||||
|
@ -53,7 +55,7 @@ func (ac *APIController) initWS(akURL url.URL, outpostUUID string) error {
|
|||
return err
|
||||
}
|
||||
ac.lastWsReconnect = time.Now()
|
||||
ac.logger.WithField("logger", "authentik.outpost.ak-ws").WithField("outpost", outpostUUID).Debug("Successfully connected websocket")
|
||||
ac.logger.WithField("logger", "authentik.outpost.ak-ws").WithField("outpost", outpostUUID).Info("Successfully connected websocket")
|
||||
return nil
|
||||
}
|
||||
|
||||
|
@ -157,23 +159,19 @@ func (ac *APIController) startWSHandler() {
|
|||
func (ac *APIController) startWSHealth() {
|
||||
ticker := time.NewTicker(time.Second * 10)
|
||||
for ; true; <-ticker.C {
|
||||
aliveMsg := websocketMessage{
|
||||
Instruction: WebsocketInstructionHello,
|
||||
Args: ac.getWebsocketArgs(),
|
||||
}
|
||||
if ac.wsConn == nil {
|
||||
go ac.reconnectWS()
|
||||
time.Sleep(time.Second * 5)
|
||||
continue
|
||||
}
|
||||
err := ac.wsConn.WriteJSON(aliveMsg)
|
||||
ac.logger.WithField("loop", "ws-health").Trace("hello'd")
|
||||
err := ac.SendWSHello(map[string]interface{}{})
|
||||
if err != nil {
|
||||
ac.logger.WithField("loop", "ws-health").WithError(err).Warning("ws write error")
|
||||
go ac.reconnectWS()
|
||||
time.Sleep(time.Second * 5)
|
||||
continue
|
||||
} else {
|
||||
ac.logger.WithField("loop", "ws-health").Trace("hello'd")
|
||||
ConnectionStatus.With(prometheus.Labels{
|
||||
"outpost_name": ac.Outpost.Name,
|
||||
"outpost_type": ac.Server.Type(),
|
||||
|
@ -202,3 +200,20 @@ func (ac *APIController) startIntervalUpdater() {
|
|||
}
|
||||
}
|
||||
}
|
||||
|
||||
func (a *APIController) AddWSHandler(handler WSHandler) {
|
||||
a.wsHandlers = append(a.wsHandlers, handler)
|
||||
}
|
||||
|
||||
func (a *APIController) SendWSHello(args map[string]interface{}) error {
|
||||
allArgs := a.getWebsocketPingArgs()
|
||||
for key, value := range args {
|
||||
allArgs[key] = value
|
||||
}
|
||||
aliveMsg := websocketMessage{
|
||||
Instruction: WebsocketInstructionHello,
|
||||
Args: allArgs,
|
||||
}
|
||||
err := a.wsConn.WriteJSON(aliveMsg)
|
||||
return err
|
||||
}
|
||||
|
|
|
@ -6,6 +6,7 @@ import (
|
|||
"strings"
|
||||
|
||||
"beryju.io/ldap"
|
||||
|
||||
"goauthentik.io/api/v3"
|
||||
"goauthentik.io/internal/outpost/ldap/constants"
|
||||
"goauthentik.io/internal/outpost/ldap/utils"
|
||||
|
@ -49,8 +50,8 @@ func (pi *ProviderInstance) UserEntry(u api.User) *ldap.Entry {
|
|||
constants.OCPosixAccount,
|
||||
constants.OCAKUser,
|
||||
},
|
||||
"uidNumber": {pi.GetUidNumber(u)},
|
||||
"gidNumber": {pi.GetUidNumber(u)},
|
||||
"uidNumber": {pi.GetUserUidNumber(u)},
|
||||
"gidNumber": {pi.GetUserGidNumber(u)},
|
||||
"homeDirectory": {fmt.Sprintf("/home/%s", u.Username)},
|
||||
"sn": {u.Name},
|
||||
})
|
||||
|
|
|
@ -4,6 +4,7 @@ import (
|
|||
"strconv"
|
||||
|
||||
"beryju.io/ldap"
|
||||
|
||||
"goauthentik.io/api/v3"
|
||||
"goauthentik.io/internal/outpost/ldap/constants"
|
||||
"goauthentik.io/internal/outpost/ldap/server"
|
||||
|
@ -50,7 +51,7 @@ func FromAPIGroup(g api.Group, si server.LDAPServerInstance) *LDAPGroup {
|
|||
DN: si.GetGroupDN(g.Name),
|
||||
CN: g.Name,
|
||||
Uid: string(g.Pk),
|
||||
GidNumber: si.GetGidNumber(g),
|
||||
GidNumber: si.GetGroupGidNumber(g),
|
||||
Member: si.UsersForGroup(g),
|
||||
IsVirtualGroup: false,
|
||||
IsSuperuser: *g.IsSuperuser,
|
||||
|
@ -63,7 +64,7 @@ func FromAPIUser(u api.User, si server.LDAPServerInstance) *LDAPGroup {
|
|||
DN: si.GetVirtualGroupDN(u.Username),
|
||||
CN: u.Username,
|
||||
Uid: u.Uid,
|
||||
GidNumber: si.GetUidNumber(u),
|
||||
GidNumber: si.GetUserGidNumber(u),
|
||||
Member: []string{si.GetUserDN(u.Username)},
|
||||
IsVirtualGroup: true,
|
||||
IsSuperuser: false,
|
||||
|
|
|
@ -3,6 +3,7 @@ package server
|
|||
import (
|
||||
"beryju.io/ldap"
|
||||
"github.com/go-openapi/strfmt"
|
||||
|
||||
"goauthentik.io/api/v3"
|
||||
"goauthentik.io/internal/outpost/ldap/flags"
|
||||
)
|
||||
|
@ -28,8 +29,9 @@ type LDAPServerInstance interface {
|
|||
GetGroupDN(string) string
|
||||
GetVirtualGroupDN(string) string
|
||||
|
||||
GetUidNumber(api.User) string
|
||||
GetGidNumber(api.Group) string
|
||||
GetUserUidNumber(api.User) string
|
||||
GetUserGidNumber(api.User) string
|
||||
GetGroupGidNumber(api.Group) string
|
||||
|
||||
UsersForGroup(api.Group) []string
|
||||
|
||||
|
|
|
@ -35,7 +35,7 @@ func (pi *ProviderInstance) GetVirtualGroupDN(group string) string {
|
|||
return fmt.Sprintf("cn=%s,%s", group, pi.VirtualGroupDN)
|
||||
}
|
||||
|
||||
func (pi *ProviderInstance) GetUidNumber(user api.User) string {
|
||||
func (pi *ProviderInstance) GetUserUidNumber(user api.User) string {
|
||||
uidNumber, ok := user.GetAttributes()["uidNumber"].(string)
|
||||
|
||||
if ok {
|
||||
|
@ -45,7 +45,17 @@ func (pi *ProviderInstance) GetUidNumber(user api.User) string {
|
|||
return strconv.FormatInt(int64(pi.uidStartNumber+user.Pk), 10)
|
||||
}
|
||||
|
||||
func (pi *ProviderInstance) GetGidNumber(group api.Group) string {
|
||||
func (pi *ProviderInstance) GetUserGidNumber(user api.User) string {
|
||||
gidNumber, ok := user.GetAttributes()["gidNumber"].(string)
|
||||
|
||||
if ok {
|
||||
return gidNumber
|
||||
}
|
||||
|
||||
return pi.GetUserUidNumber(user)
|
||||
}
|
||||
|
||||
func (pi *ProviderInstance) GetGroupGidNumber(group api.Group) string {
|
||||
gidNumber, ok := group.GetAttributes()["gidNumber"].(string)
|
||||
|
||||
if ok {
|
||||
|
|
|
@ -31,16 +31,11 @@ func (a *Application) redeemCallback(savedState string, u *url.URL, c context.Co
|
|||
return nil, err
|
||||
}
|
||||
|
||||
// Extract the ID Token from OAuth2 token.
|
||||
rawIDToken, ok := oauth2Token.Extra("id_token").(string)
|
||||
if !ok {
|
||||
return nil, fmt.Errorf("missing id_token")
|
||||
}
|
||||
|
||||
a.log.WithField("id_token", rawIDToken).Trace("id_token")
|
||||
jwt := oauth2Token.AccessToken
|
||||
a.log.WithField("jwt", jwt).Trace("access_token")
|
||||
|
||||
// Parse and verify ID Token payload.
|
||||
idToken, err := a.tokenVerifier.Verify(ctx, rawIDToken)
|
||||
idToken, err := a.tokenVerifier.Verify(ctx, jwt)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
|
@ -53,6 +48,6 @@ func (a *Application) redeemCallback(savedState string, u *url.URL, c context.Co
|
|||
if claims.Proxy == nil {
|
||||
claims.Proxy = &ProxyClaims{}
|
||||
}
|
||||
claims.RawToken = rawIDToken
|
||||
claims.RawToken = jwt
|
||||
return claims, nil
|
||||
}
|
||||
|
|
|
@ -13,6 +13,7 @@ import (
|
|||
"github.com/gorilla/securecookie"
|
||||
"github.com/gorilla/sessions"
|
||||
"github.com/redis/go-redis/v9"
|
||||
|
||||
"goauthentik.io/api/v3"
|
||||
"goauthentik.io/internal/config"
|
||||
"goauthentik.io/internal/outpost/proxyv2/codecs"
|
||||
|
@ -40,7 +41,7 @@ func (a *Application) getStore(p api.ProxyOutpostConfig, externalHost *url.URL)
|
|||
// New default RedisStore
|
||||
rs, err := redisstore.NewRedisStore(context.Background(), client)
|
||||
if err != nil {
|
||||
panic(err)
|
||||
a.log.WithError(err).Panic("failed to connect to redis")
|
||||
}
|
||||
|
||||
rs.KeyPrefix(RedisKeyPrefix)
|
||||
|
|
124
internal/outpost/rac/connection/connection.go
Normal file
124
internal/outpost/rac/connection/connection.go
Normal file
|
@ -0,0 +1,124 @@
|
|||
package connection
|
||||
|
||||
import (
|
||||
"context"
|
||||
"crypto/tls"
|
||||
"fmt"
|
||||
"net"
|
||||
"net/http"
|
||||
"strings"
|
||||
"time"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/wwt/guac"
|
||||
"goauthentik.io/internal/config"
|
||||
"goauthentik.io/internal/constants"
|
||||
"goauthentik.io/internal/outpost/ak"
|
||||
)
|
||||
|
||||
const guacAddr = "0.0.0.0:4822"
|
||||
|
||||
type Connection struct {
|
||||
log *log.Entry
|
||||
st *guac.SimpleTunnel
|
||||
ac *ak.APIController
|
||||
ws *websocket.Conn
|
||||
ctx context.Context
|
||||
ctxCancel context.CancelFunc
|
||||
OnError func(error)
|
||||
closing bool
|
||||
}
|
||||
|
||||
func NewConnection(ac *ak.APIController, forChannel string, cfg *guac.Config) (*Connection, error) {
|
||||
ctx, canc := context.WithCancel(context.Background())
|
||||
c := &Connection{
|
||||
ac: ac,
|
||||
log: log.WithField("connection", forChannel),
|
||||
ctx: ctx,
|
||||
ctxCancel: canc,
|
||||
OnError: func(err error) {},
|
||||
closing: false,
|
||||
}
|
||||
err := c.initGuac(cfg)
|
||||
if err != nil {
|
||||
return nil, err
|
||||
}
|
||||
err = c.initSocket(forChannel)
|
||||
if err != nil {
|
||||
_ = c.st.Close()
|
||||
return nil, err
|
||||
}
|
||||
c.initMirror()
|
||||
return c, nil
|
||||
}
|
||||
|
||||
func (c *Connection) initSocket(forChannel string) error {
|
||||
pathTemplate := "%s://%s/ws/outpost_rac/%s/"
|
||||
scheme := strings.ReplaceAll(c.ac.Client.GetConfig().Scheme, "http", "ws")
|
||||
|
||||
authHeader := fmt.Sprintf("Bearer %s", c.ac.Token())
|
||||
|
||||
header := http.Header{
|
||||
"Authorization": []string{authHeader},
|
||||
"User-Agent": []string{constants.OutpostUserAgent()},
|
||||
}
|
||||
|
||||
dialer := websocket.Dialer{
|
||||
Proxy: http.ProxyFromEnvironment,
|
||||
HandshakeTimeout: 10 * time.Second,
|
||||
TLSClientConfig: &tls.Config{
|
||||
InsecureSkipVerify: config.Get().AuthentikInsecure,
|
||||
},
|
||||
}
|
||||
|
||||
url := fmt.Sprintf(pathTemplate, scheme, c.ac.Client.GetConfig().Host, forChannel)
|
||||
ws, _, err := dialer.Dial(url, header)
|
||||
if err != nil {
|
||||
c.log.WithError(err).Warning("failed to connect websocket")
|
||||
return err
|
||||
}
|
||||
c.ws = ws
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Connection) initGuac(cfg *guac.Config) error {
|
||||
addr, err := net.ResolveTCPAddr("tcp", guacAddr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
conn, err := net.DialTCP("tcp", nil, addr)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
|
||||
stream := guac.NewStream(conn, guac.SocketTimeout)
|
||||
|
||||
err = stream.Handshake(cfg)
|
||||
if err != nil {
|
||||
return err
|
||||
}
|
||||
st := guac.NewSimpleTunnel(stream)
|
||||
c.st = st
|
||||
return nil
|
||||
}
|
||||
|
||||
func (c *Connection) initMirror() {
|
||||
go c.wsToGuacd()
|
||||
go c.guacdToWs()
|
||||
}
|
||||
|
||||
func (c *Connection) onError(err error) {
|
||||
if c.closing {
|
||||
return
|
||||
}
|
||||
c.closing = true
|
||||
e := c.st.Close()
|
||||
if e != nil {
|
||||
c.log.WithError(e).Warning("failed to close guacd connection")
|
||||
}
|
||||
c.log.WithError(err).Info("removing connection")
|
||||
c.ctxCancel()
|
||||
c.OnError(err)
|
||||
}
|
103
internal/outpost/rac/connection/mirror.go
Normal file
103
internal/outpost/rac/connection/mirror.go
Normal file
|
@ -0,0 +1,103 @@
|
|||
package connection
|
||||
|
||||
import (
|
||||
"bytes"
|
||||
"fmt"
|
||||
|
||||
"github.com/gorilla/websocket"
|
||||
"github.com/wwt/guac"
|
||||
)
|
||||
|
||||
var (
|
||||
internalOpcodeIns = []byte(fmt.Sprint(len(guac.InternalDataOpcode), ".", guac.InternalDataOpcode))
|
||||
authentikOpcode = []byte("0.authentik.")
|
||||
)
|
||||
|
||||
// MessageReader wraps a websocket connection and only permits Reading
|
||||
type MessageReader interface {
|
||||
// ReadMessage should return a single complete message to send to guac
|
||||
ReadMessage() (int, []byte, error)
|
||||
}
|
||||
|
||||
func (c *Connection) wsToGuacd() {
|
||||
w := c.st.AcquireWriter()
|
||||
for {
|
||||
select {
|
||||
default:
|
||||
_, data, e := c.ws.ReadMessage()
|
||||
if e != nil {
|
||||
c.log.WithError(e).Trace("Error reading message from ws")
|
||||
c.onError(e)
|
||||
return
|
||||
}
|
||||
if bytes.HasPrefix(data, internalOpcodeIns) {
|
||||
if bytes.HasPrefix(data, authentikOpcode) {
|
||||
switch string(bytes.Replace(data, authentikOpcode, []byte{}, 1)) {
|
||||
case "disconnect":
|
||||
_, e := w.Write([]byte(guac.NewInstruction("disconnect").String()))
|
||||
c.onError(e)
|
||||
return
|
||||
}
|
||||
}
|
||||
// messages starting with the InternalDataOpcode are never sent to guacd
|
||||
continue
|
||||
}
|
||||
|
||||
if _, e = w.Write(data); e != nil {
|
||||
c.log.WithError(e).Trace("Failed writing to guacd")
|
||||
c.onError(e)
|
||||
return
|
||||
}
|
||||
case <-c.ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
// MessageWriter wraps a websocket connection and only permits Writing
|
||||
type MessageWriter interface {
|
||||
// WriteMessage writes one or more complete guac commands to the websocket
|
||||
WriteMessage(int, []byte) error
|
||||
}
|
||||
|
||||
func (c *Connection) guacdToWs() {
|
||||
r := c.st.AcquireReader()
|
||||
buf := bytes.NewBuffer(make([]byte, 0, guac.MaxGuacMessage*2))
|
||||
for {
|
||||
select {
|
||||
default:
|
||||
ins, e := r.ReadSome()
|
||||
if e != nil {
|
||||
c.log.WithError(e).Trace("Error reading from guacd")
|
||||
c.onError(e)
|
||||
return
|
||||
}
|
||||
|
||||
if bytes.HasPrefix(ins, internalOpcodeIns) {
|
||||
// messages starting with the InternalDataOpcode are never sent to the websocket
|
||||
continue
|
||||
}
|
||||
|
||||
if _, e = buf.Write(ins); e != nil {
|
||||
c.log.WithError(e).Trace("Failed to buffer guacd to ws")
|
||||
c.onError(e)
|
||||
return
|
||||
}
|
||||
|
||||
// if the buffer has more data in it or we've reached the max buffer size, send the data and reset
|
||||
if !r.Available() || buf.Len() >= guac.MaxGuacMessage {
|
||||
if e = c.ws.WriteMessage(1, buf.Bytes()); e != nil {
|
||||
if e == websocket.ErrCloseSent {
|
||||
return
|
||||
}
|
||||
c.log.WithError(e).Trace("Failed sending message to ws")
|
||||
c.onError(e)
|
||||
return
|
||||
}
|
||||
buf.Reset()
|
||||
}
|
||||
case <-c.ctx.Done():
|
||||
return
|
||||
}
|
||||
}
|
||||
}
|
26
internal/outpost/rac/guacd.go
Normal file
26
internal/outpost/rac/guacd.go
Normal file
|
@ -0,0 +1,26 @@
|
|||
package rac
|
||||
|
||||
import (
|
||||
"os"
|
||||
"os/exec"
|
||||
"strings"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"goauthentik.io/internal/outpost/ak"
|
||||
)
|
||||
|
||||
const (
|
||||
guacdPath = "/opt/guacamole/sbin/guacd"
|
||||
guacdDefaultArgs = " -b 0.0.0.0 -f"
|
||||
)
|
||||
|
||||
func (rs *RACServer) startGuac() error {
|
||||
guacdArgs := strings.Split(guacdDefaultArgs, " ")
|
||||
guacdArgs = append(guacdArgs, "-L", rs.ac.Outpost.Config[ak.ConfigLogLevel].(string))
|
||||
rs.guacd = exec.Command(guacdPath, guacdArgs...)
|
||||
rs.guacd.Env = os.Environ()
|
||||
rs.guacd.Stdout = rs.log.WithField("logger", "authentik.outpost.rac.guacd").WriterLevel(log.InfoLevel)
|
||||
rs.guacd.Stderr = rs.log.WithField("logger", "authentik.outpost.rac.guacd").WriterLevel(log.InfoLevel)
|
||||
rs.log.Info("starting guacd")
|
||||
return rs.guacd.Start()
|
||||
}
|
28
internal/outpost/rac/metrics/metrics.go
Normal file
28
internal/outpost/rac/metrics/metrics.go
Normal file
|
@ -0,0 +1,28 @@
|
|||
package metrics
|
||||
|
||||
import (
|
||||
"net/http"
|
||||
|
||||
log "github.com/sirupsen/logrus"
|
||||
"goauthentik.io/internal/config"
|
||||
"goauthentik.io/internal/utils/sentry"
|
||||
|
||||
"github.com/gorilla/mux"
|
||||
"github.com/prometheus/client_golang/prometheus/promhttp"
|
||||
)
|
||||
|
||||
func RunServer() {
|
||||
m := mux.NewRouter()
|
||||
l := log.WithField("logger", "authentik.outpost.metrics")
|
||||
m.Use(sentry.SentryNoSampleMiddleware)
|
||||
m.HandleFunc("/outpost.goauthentik.io/ping", func(rw http.ResponseWriter, r *http.Request) {
|
||||
rw.WriteHeader(204)
|
||||
})
|
||||
m.Path("/metrics").Handler(promhttp.Handler())
|
||||
listen := config.Get().Listen.Metrics
|
||||
l.WithField("listen", listen).Info("Starting Metrics server")
|
||||
err := http.ListenAndServe(listen, m)
|
||||
if err != nil {
|
||||
l.WithError(err).Warning("Failed to start metrics listener")
|
||||
}
|
||||
}
|
126
internal/outpost/rac/rac.go
Normal file
126
internal/outpost/rac/rac.go
Normal file
|
@ -0,0 +1,126 @@
|
|||
package rac
|
||||
|
||||
import (
|
||||
"context"
|
||||
"os/exec"
|
||||
"strconv"
|
||||
"sync"
|
||||
|
||||
"github.com/mitchellh/mapstructure"
|
||||
log "github.com/sirupsen/logrus"
|
||||
"github.com/wwt/guac"
|
||||
|
||||
"goauthentik.io/internal/outpost/ak"
|
||||
"goauthentik.io/internal/outpost/rac/connection"
|
||||
"goauthentik.io/internal/outpost/rac/metrics"
|
||||
)
|
||||
|
||||
type RACServer struct {
|
||||
log *log.Entry
|
||||
ac *ak.APIController
|
||||
guacd *exec.Cmd
|
||||
connm sync.RWMutex
|
||||
conns map[string]connection.Connection
|
||||
}
|
||||
|
||||
func NewServer(ac *ak.APIController) *RACServer {
|
||||
rs := &RACServer{
|
||||
log: log.WithField("logger", "authentik.outpost.rac"),
|
||||
ac: ac,
|
||||
connm: sync.RWMutex{},
|
||||
conns: map[string]connection.Connection{},
|
||||
}
|
||||
ac.AddWSHandler(rs.wsHandler)
|
||||
return rs
|
||||
}
|
||||
|
||||
type WSMessage struct {
|
||||
ConnID string `mapstructure:"conn_id"`
|
||||
DestChannelID string `mapstructure:"dest_channel_id"`
|
||||
Params map[string]string `mapstructure:"params"`
|
||||
Protocol string `mapstructure:"protocol"`
|
||||
OptimalScreenWidth string `mapstructure:"screen_width"`
|
||||
OptimalScreenHeight string `mapstructure:"screen_height"`
|
||||
OptimalScreenDPI string `mapstructure:"screen_dpi"`
|
||||
}
|
||||
|
||||
func parseIntOrZero(input string) int {
|
||||
x, err := strconv.Atoi(input)
|
||||
if err != nil {
|
||||
return 0
|
||||
}
|
||||
return x
|
||||
}
|
||||
|
||||
func (rs *RACServer) wsHandler(ctx context.Context, args map[string]interface{}) {
|
||||
wsm := WSMessage{}
|
||||
err := mapstructure.Decode(args, &wsm)
|
||||
if err != nil {
|
||||
rs.log.WithError(err).Warning("invalid ws message")
|
||||
return
|
||||
}
|
||||
config := guac.NewGuacamoleConfiguration()
|
||||
config.Protocol = wsm.Protocol
|
||||
config.Parameters = wsm.Params
|
||||
config.OptimalScreenWidth = parseIntOrZero(wsm.OptimalScreenWidth)
|
||||
config.OptimalScreenHeight = parseIntOrZero(wsm.OptimalScreenHeight)
|
||||
config.OptimalResolution = parseIntOrZero(wsm.OptimalScreenDPI)
|
||||
config.AudioMimetypes = []string{
|
||||
"audio/L8",
|
||||
"audio/L16",
|
||||
}
|
||||
cc, err := connection.NewConnection(rs.ac, wsm.DestChannelID, config)
|
||||
if err != nil {
|
||||
rs.log.WithError(err).Warning("failed to setup connection")
|
||||
return
|
||||
}
|
||||
cc.OnError = func(err error) {
|
||||
rs.connm.Lock()
|
||||
delete(rs.conns, wsm.ConnID)
|
||||
_ = rs.ac.SendWSHello(map[string]interface{}{
|
||||
"active_connections": len(rs.conns),
|
||||
})
|
||||
rs.connm.Unlock()
|
||||
}
|
||||
rs.connm.Lock()
|
||||
rs.conns[wsm.ConnID] = *cc
|
||||
_ = rs.ac.SendWSHello(map[string]interface{}{
|
||||
"active_connections": len(rs.conns),
|
||||
})
|
||||
rs.connm.Unlock()
|
||||
}
|
||||
|
||||
func (rs *RACServer) Start() error {
|
||||
wg := sync.WaitGroup{}
|
||||
wg.Add(2)
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
metrics.RunServer()
|
||||
}()
|
||||
go func() {
|
||||
defer wg.Done()
|
||||
err := rs.startGuac()
|
||||
if err != nil {
|
||||
panic(err)
|
||||
}
|
||||
}()
|
||||
wg.Wait()
|
||||
return nil
|
||||
}
|
||||
|
||||
func (rs *RACServer) Stop() error {
|
||||
if rs.guacd != nil {
|
||||
return rs.guacd.Process.Kill()
|
||||
}
|
||||
return nil
|
||||
}
|
||||
|
||||
func (rs *RACServer) TimerFlowCacheExpiry(context.Context) {}
|
||||
|
||||
func (rs *RACServer) Type() string {
|
||||
return "rac"
|
||||
}
|
||||
|
||||
func (rs *RACServer) Refresh() error {
|
||||
return nil
|
||||
}
|
|
@ -34,6 +34,11 @@ func (ws *WebServer) configureStatic() {
|
|||
})
|
||||
indexLessRouter.PathPrefix("/if/admin/assets").Handler(http.StripPrefix("/if/admin", distFs))
|
||||
indexLessRouter.PathPrefix("/if/user/assets").Handler(http.StripPrefix("/if/user", distFs))
|
||||
indexLessRouter.PathPrefix("/if/rac/{app_slug}/assets").HandlerFunc(func(rw http.ResponseWriter, r *http.Request) {
|
||||
vars := mux.Vars(r)
|
||||
|
||||
web.DisableIndex(http.StripPrefix(fmt.Sprintf("/if/rac/%s", vars["app_slug"]), distFs)).ServeHTTP(rw, r)
|
||||
})
|
||||
|
||||
indexLessRouter.PathPrefix("/media/").Handler(http.StripPrefix("/media", fs))
|
||||
|
||||
|
|
BIN
locale/it/LC_MESSAGES/django.mo
Normal file
BIN
locale/it/LC_MESSAGES/django.mo
Normal file
Binary file not shown.
Binary file not shown.
1065
poetry.lock
generated
1065
poetry.lock
generated
File diff suppressed because it is too large
Load diff
|
@ -145,7 +145,12 @@ geoip2 = "*"
|
|||
gunicorn = "*"
|
||||
kubernetes = "*"
|
||||
ldap3 = "*"
|
||||
lxml = "*"
|
||||
lxml = [
|
||||
# 5.0.0 works with libxml2 2.11.x, which is standard on brew
|
||||
{ version = "5.0.0", platform = "darwin" },
|
||||
# 4.9.x works with previous libxml2 versions, which is what we get on linux
|
||||
{ version = "4.9.4", platform = "linux" },
|
||||
]
|
||||
opencontainers = { extras = ["reggie"], version = "*" }
|
||||
packaging = "*"
|
||||
paramiko = "*"
|
||||
|
|
38
rac.Dockerfile
Normal file
38
rac.Dockerfile
Normal file
|
@ -0,0 +1,38 @@
|
|||
# syntax=docker/dockerfile:1
|
||||
|
||||
# Stage 1: Build
|
||||
FROM docker.io/golang:1.21.5-bookworm AS builder
|
||||
|
||||
WORKDIR /go/src/goauthentik.io
|
||||
|
||||
RUN --mount=type=bind,target=/go/src/goauthentik.io/go.mod,src=./go.mod \
|
||||
--mount=type=bind,target=/go/src/goauthentik.io/go.sum,src=./go.sum \
|
||||
--mount=type=bind,target=/go/src/goauthentik.io/gen-go-api,src=./gen-go-api \
|
||||
--mount=type=cache,target=/go/pkg/mod \
|
||||
go mod download
|
||||
|
||||
ENV CGO_ENABLED=0
|
||||
COPY . .
|
||||
RUN --mount=type=cache,sharing=locked,target=/go/pkg/mod \
|
||||
--mount=type=cache,id=go-build-$TARGETARCH$TARGETVARIANT,sharing=locked,target=/root/.cache/go-build \
|
||||
go build -o /go/rac ./cmd/rac
|
||||
|
||||
# Stage 2: Run
|
||||
FROM ghcr.io/beryju/guacd:1.5.3
|
||||
|
||||
ARG GIT_BUILD_HASH
|
||||
ENV GIT_BUILD_HASH=$GIT_BUILD_HASH
|
||||
|
||||
LABEL org.opencontainers.image.url https://goauthentik.io
|
||||
LABEL org.opencontainers.image.description goauthentik.io RAC outpost, see https://goauthentik.io for more info.
|
||||
LABEL org.opencontainers.image.source https://github.com/goauthentik/authentik
|
||||
LABEL org.opencontainers.image.version ${VERSION}
|
||||
LABEL org.opencontainers.image.revision ${GIT_BUILD_HASH}
|
||||
|
||||
COPY --from=builder /go/rac /
|
||||
|
||||
HEALTHCHECK --interval=5s --retries=20 --start-period=3s CMD [ "/rac", "healthcheck" ]
|
||||
|
||||
USER 1000
|
||||
|
||||
ENTRYPOINT ["/rac"]
|
1232
schema.yml
1232
schema.yml
File diff suppressed because it is too large
Load diff
|
@ -1,8 +1,6 @@
|
|||
"""LDAP and Outpost e2e tests"""
|
||||
from dataclasses import asdict
|
||||
from sys import platform
|
||||
from time import sleep
|
||||
from unittest.case import skipUnless
|
||||
|
||||
from docker.client import DockerClient, from_env
|
||||
from docker.models.containers import Container
|
||||
|
@ -14,13 +12,13 @@ from authentik.blueprints.tests import apply_blueprint, reconcile_app
|
|||
from authentik.core.models import Application, User
|
||||
from authentik.events.models import Event, EventAction
|
||||
from authentik.flows.models import Flow
|
||||
from authentik.lib.generators import generate_id
|
||||
from authentik.outposts.apps import MANAGED_OUTPOST
|
||||
from authentik.outposts.models import Outpost, OutpostConfig, OutpostType
|
||||
from authentik.providers.ldap.models import APIAccessMode, LDAPProvider
|
||||
from tests.e2e.utils import SeleniumTestCase, retry
|
||||
|
||||
|
||||
@skipUnless(platform.startswith("linux"), "requires local docker")
|
||||
class TestProviderLDAP(SeleniumTestCase):
|
||||
"""LDAP and Outpost e2e tests"""
|
||||
|
||||
|
@ -37,7 +35,10 @@ class TestProviderLDAP(SeleniumTestCase):
|
|||
container = client.containers.run(
|
||||
image=self.get_container_image("ghcr.io/goauthentik/dev-ldap"),
|
||||
detach=True,
|
||||
network_mode="host",
|
||||
ports={
|
||||
"3389": "3389",
|
||||
"6636": "6636",
|
||||
},
|
||||
environment={
|
||||
"AUTHENTIK_HOST": self.live_server_url,
|
||||
"AUTHENTIK_TOKEN": outpost.token.key,
|
||||
|
@ -51,15 +52,15 @@ class TestProviderLDAP(SeleniumTestCase):
|
|||
self.user.save()
|
||||
|
||||
ldap: LDAPProvider = LDAPProvider.objects.create(
|
||||
name="ldap_provider",
|
||||
name=generate_id(),
|
||||
authorization_flow=Flow.objects.get(slug="default-authentication-flow"),
|
||||
search_group=self.user.ak_groups.first(),
|
||||
search_mode=APIAccessMode.CACHED,
|
||||
)
|
||||
# we need to create an application to actually access the ldap
|
||||
Application.objects.create(name="ldap", slug="ldap", provider=ldap)
|
||||
Application.objects.create(name=generate_id(), slug=generate_id(), provider=ldap)
|
||||
outpost: Outpost = Outpost.objects.create(
|
||||
name="ldap_outpost",
|
||||
name=generate_id(),
|
||||
type=OutpostType.LDAP,
|
||||
_config=asdict(OutpostConfig(log_level="debug")),
|
||||
)
|
||||
|
|
|
@ -1,8 +1,6 @@
|
|||
"""test OAuth Provider flow"""
|
||||
from sys import platform
|
||||
from time import sleep
|
||||
from typing import Any, Optional
|
||||
from unittest.case import skipUnless
|
||||
|
||||
from docker.types import Healthcheck
|
||||
from selenium.webdriver.common.by import By
|
||||
|
@ -18,7 +16,6 @@ from authentik.providers.oauth2.models import ClientTypes, OAuth2Provider
|
|||
from tests.e2e.utils import SeleniumTestCase, retry
|
||||
|
||||
|
||||
@skipUnless(platform.startswith("linux"), "requires local docker")
|
||||
class TestProviderOAuth2Github(SeleniumTestCase):
|
||||
"""test OAuth Provider flow"""
|
||||
|
||||
|
@ -32,7 +29,9 @@ class TestProviderOAuth2Github(SeleniumTestCase):
|
|||
return {
|
||||
"image": "grafana/grafana:7.1.0",
|
||||
"detach": True,
|
||||
"network_mode": "host",
|
||||
"ports": {
|
||||
"3000": "3000",
|
||||
},
|
||||
"auto_remove": True,
|
||||
"healthcheck": Healthcheck(
|
||||
test=["CMD", "wget", "--spider", "http://localhost:3000"],
|
||||
|
|
|
@ -1,8 +1,6 @@
|
|||
"""test OAuth2 OpenID Provider flow"""
|
||||
from sys import platform
|
||||
from time import sleep
|
||||
from typing import Any, Optional
|
||||
from unittest.case import skipUnless
|
||||
|
||||
from docker.types import Healthcheck
|
||||
from selenium.webdriver.common.by import By
|
||||
|
@ -24,7 +22,6 @@ from authentik.providers.oauth2.models import ClientTypes, OAuth2Provider, Scope
|
|||
from tests.e2e.utils import SeleniumTestCase, retry
|
||||
|
||||
|
||||
@skipUnless(platform.startswith("linux"), "requires local docker")
|
||||
class TestProviderOAuth2OAuth(SeleniumTestCase):
|
||||
"""test OAuth with OAuth Provider flow"""
|
||||
|
||||
|
@ -38,13 +35,15 @@ class TestProviderOAuth2OAuth(SeleniumTestCase):
|
|||
return {
|
||||
"image": "grafana/grafana:7.1.0",
|
||||
"detach": True,
|
||||
"network_mode": "host",
|
||||
"auto_remove": True,
|
||||
"healthcheck": Healthcheck(
|
||||
test=["CMD", "wget", "--spider", "http://localhost:3000"],
|
||||
interval=5 * 1_000 * 1_000_000,
|
||||
start_period=1 * 1_000 * 1_000_000,
|
||||
),
|
||||
"ports": {
|
||||
"3000": "3000",
|
||||
},
|
||||
"environment": {
|
||||
"GF_AUTH_GENERIC_OAUTH_ENABLED": "true",
|
||||
"GF_AUTH_GENERIC_OAUTH_CLIENT_ID": self.client_id,
|
||||
|
|
|
@ -1,8 +1,6 @@
|
|||
"""test OAuth2 OpenID Provider flow"""
|
||||
from json import loads
|
||||
from sys import platform
|
||||
from time import sleep
|
||||
from unittest.case import skipUnless
|
||||
|
||||
from docker import DockerClient, from_env
|
||||
from docker.models.containers import Container
|
||||
|
@ -25,7 +23,6 @@ from authentik.providers.oauth2.models import ClientTypes, OAuth2Provider, Scope
|
|||
from tests.e2e.utils import SeleniumTestCase, retry
|
||||
|
||||
|
||||
@skipUnless(platform.startswith("linux"), "requires local docker")
|
||||
class TestProviderOAuth2OIDC(SeleniumTestCase):
|
||||
"""test OAuth with OpenID Provider flow"""
|
||||
|
||||
|
@ -36,13 +33,15 @@ class TestProviderOAuth2OIDC(SeleniumTestCase):
|
|||
super().setUp()
|
||||
|
||||
def setup_client(self) -> Container:
|
||||
"""Setup client saml-sp container which we test SAML against"""
|
||||
"""Setup client oidc-test-client container which we test OIDC against"""
|
||||
sleep(1)
|
||||
client: DockerClient = from_env()
|
||||
container = client.containers.run(
|
||||
image="ghcr.io/beryju/oidc-test-client:1.3",
|
||||
detach=True,
|
||||
network_mode="host",
|
||||
ports={
|
||||
"9009": "9009",
|
||||
},
|
||||
environment={
|
||||
"OIDC_CLIENT_ID": self.client_id,
|
||||
"OIDC_CLIENT_SECRET": self.client_secret,
|
||||
|
|
|
@ -1,8 +1,6 @@
|
|||
"""test OAuth2 OpenID Provider flow"""
|
||||
from json import loads
|
||||
from sys import platform
|
||||
from time import sleep
|
||||
from unittest.case import skipUnless
|
||||
|
||||
from docker import DockerClient, from_env
|
||||
from docker.models.containers import Container
|
||||
|
@ -25,7 +23,6 @@ from authentik.providers.oauth2.models import ClientTypes, OAuth2Provider, Scope
|
|||
from tests.e2e.utils import SeleniumTestCase, retry
|
||||
|
||||
|
||||
@skipUnless(platform.startswith("linux"), "requires local docker")
|
||||
class TestProviderOAuth2OIDCImplicit(SeleniumTestCase):
|
||||
"""test OAuth with OpenID Provider flow"""
|
||||
|
||||
|
@ -36,13 +33,15 @@ class TestProviderOAuth2OIDCImplicit(SeleniumTestCase):
|
|||
super().setUp()
|
||||
|
||||
def setup_client(self) -> Container:
|
||||
"""Setup client saml-sp container which we test SAML against"""
|
||||
"""Setup client oidc-test-client container which we test OIDC against"""
|
||||
sleep(1)
|
||||
client: DockerClient = from_env()
|
||||
container = client.containers.run(
|
||||
image="ghcr.io/beryju/oidc-test-client:1.3",
|
||||
detach=True,
|
||||
network_mode="host",
|
||||
ports={
|
||||
"9009": "9009",
|
||||
},
|
||||
environment={
|
||||
"OIDC_CLIENT_ID": self.client_id,
|
||||
"OIDC_CLIENT_SECRET": self.client_secret,
|
||||
|
|
|
@ -21,7 +21,6 @@ from authentik.providers.proxy.models import ProxyProvider
|
|||
from tests.e2e.utils import SeleniumTestCase, retry
|
||||
|
||||
|
||||
@skipUnless(platform.startswith("linux"), "requires local docker")
|
||||
class TestProviderProxy(SeleniumTestCase):
|
||||
"""Proxy and Outpost e2e tests"""
|
||||
|
||||
|
@ -36,7 +35,9 @@ class TestProviderProxy(SeleniumTestCase):
|
|||
return {
|
||||
"image": "traefik/whoami:latest",
|
||||
"detach": True,
|
||||
"network_mode": "host",
|
||||
"ports": {
|
||||
"80": "80",
|
||||
},
|
||||
"auto_remove": True,
|
||||
}
|
||||
|
||||
|
@ -46,7 +47,9 @@ class TestProviderProxy(SeleniumTestCase):
|
|||
container = client.containers.run(
|
||||
image=self.get_container_image("ghcr.io/goauthentik/dev-proxy"),
|
||||
detach=True,
|
||||
network_mode="host",
|
||||
ports={
|
||||
"9000": "9000",
|
||||
},
|
||||
environment={
|
||||
"AUTHENTIK_HOST": self.live_server_url,
|
||||
"AUTHENTIK_TOKEN": outpost.token.key,
|
||||
|
@ -78,7 +81,7 @@ class TestProviderProxy(SeleniumTestCase):
|
|||
authorization_flow=Flow.objects.get(
|
||||
slug="default-provider-authorization-implicit-consent"
|
||||
),
|
||||
internal_host="http://localhost",
|
||||
internal_host=f"http://{self.host}",
|
||||
external_host="http://localhost:9000",
|
||||
)
|
||||
# Ensure OAuth2 Params are set
|
||||
|
@ -145,7 +148,7 @@ class TestProviderProxy(SeleniumTestCase):
|
|||
authorization_flow=Flow.objects.get(
|
||||
slug="default-provider-authorization-implicit-consent"
|
||||
),
|
||||
internal_host="http://localhost",
|
||||
internal_host=f"http://{self.host}",
|
||||
external_host="http://localhost:9000",
|
||||
basic_auth_enabled=True,
|
||||
basic_auth_user_attribute="basic-username",
|
||||
|
|
|
@ -1,8 +1,6 @@
|
|||
"""Radius e2e tests"""
|
||||
from dataclasses import asdict
|
||||
from sys import platform
|
||||
from time import sleep
|
||||
from unittest.case import skipUnless
|
||||
|
||||
from docker.client import DockerClient, from_env
|
||||
from docker.models.containers import Container
|
||||
|
@ -19,7 +17,6 @@ from authentik.providers.radius.models import RadiusProvider
|
|||
from tests.e2e.utils import SeleniumTestCase, retry
|
||||
|
||||
|
||||
@skipUnless(platform.startswith("linux"), "requires local docker")
|
||||
class TestProviderRadius(SeleniumTestCase):
|
||||
"""Radius Outpost e2e tests"""
|
||||
|
||||
|
@ -40,7 +37,7 @@ class TestProviderRadius(SeleniumTestCase):
|
|||
container = client.containers.run(
|
||||
image=self.get_container_image("ghcr.io/goauthentik/dev-radius"),
|
||||
detach=True,
|
||||
network_mode="host",
|
||||
ports={"1812/udp": "1812/udp"},
|
||||
environment={
|
||||
"AUTHENTIK_HOST": self.live_server_url,
|
||||
"AUTHENTIK_TOKEN": outpost.token.key,
|
||||
|
|
|
@ -1,8 +1,6 @@
|
|||
"""test SAML Provider flow"""
|
||||
from json import loads
|
||||
from sys import platform
|
||||
from time import sleep
|
||||
from unittest.case import skipUnless
|
||||
|
||||
from docker import DockerClient, from_env
|
||||
from docker.models.containers import Container
|
||||
|
@ -20,7 +18,6 @@ from authentik.sources.saml.processors.constants import SAML_BINDING_POST
|
|||
from tests.e2e.utils import SeleniumTestCase, retry
|
||||
|
||||
|
||||
@skipUnless(platform.startswith("linux"), "requires local docker")
|
||||
class TestProviderSAML(SeleniumTestCase):
|
||||
"""test SAML Provider flow"""
|
||||
|
||||
|
@ -41,7 +38,9 @@ class TestProviderSAML(SeleniumTestCase):
|
|||
container = client.containers.run(
|
||||
image="ghcr.io/beryju/saml-test-sp:1.1",
|
||||
detach=True,
|
||||
network_mode="host",
|
||||
ports={
|
||||
"9009": "9009",
|
||||
},
|
||||
environment={
|
||||
"SP_ENTITY_ID": provider.issuer,
|
||||
"SP_SSO_BINDING": "urn:oasis:names:tc:SAML:2.0:bindings:HTTP-POST",
|
||||
|
|
141
tests/e2e/test_source_oauth_oauth1.py
Normal file
141
tests/e2e/test_source_oauth_oauth1.py
Normal file
|
@ -0,0 +1,141 @@
|
|||
"""test OAuth Source"""
|
||||
from time import sleep
|
||||
from typing import Any, Optional
|
||||
|
||||
from selenium.webdriver.common.by import By
|
||||
from selenium.webdriver.common.keys import Keys
|
||||
from selenium.webdriver.support import expected_conditions as ec
|
||||
from selenium.webdriver.support.wait import WebDriverWait
|
||||
|
||||
from authentik.blueprints.tests import apply_blueprint
|
||||
from authentik.core.models import User
|
||||
from authentik.flows.models import Flow
|
||||
from authentik.lib.generators import generate_id, generate_key
|
||||
from authentik.sources.oauth.models import OAuthSource
|
||||
from authentik.sources.oauth.types.registry import SourceType, registry
|
||||
from authentik.sources.oauth.views.callback import OAuthCallback
|
||||
from authentik.stages.identification.models import IdentificationStage
|
||||
from tests.e2e.utils import SeleniumTestCase, retry
|
||||
|
||||
|
||||
class OAuth1Callback(OAuthCallback):
|
||||
"""OAuth1 Callback with custom getters"""
|
||||
|
||||
def get_user_id(self, info: dict[str, str]) -> str:
|
||||
return info.get("id")
|
||||
|
||||
def get_user_enroll_context(
|
||||
self,
|
||||
info: dict[str, Any],
|
||||
) -> dict[str, Any]:
|
||||
return {
|
||||
"username": info.get("screen_name"),
|
||||
"email": info.get("email"),
|
||||
"name": info.get("name"),
|
||||
}
|
||||
|
||||
|
||||
@registry.register()
|
||||
class OAUth1Type(SourceType):
|
||||
"""OAuth1 Type definition"""
|
||||
|
||||
callback_view = OAuth1Callback
|
||||
verbose_name = "OAuth1"
|
||||
name = "oauth1"
|
||||
|
||||
request_token_url = "http://localhost:5001/oauth/request_token" # nosec
|
||||
access_token_url = "http://localhost:5001/oauth/access_token" # nosec
|
||||
authorization_url = "http://localhost:5001/oauth/authorize"
|
||||
profile_url = "http://localhost:5001/api/me"
|
||||
urls_customizable = False
|
||||
|
||||
|
||||
class TestSourceOAuth1(SeleniumTestCase):
|
||||
"""Test OAuth1 Source"""
|
||||
|
||||
def setUp(self) -> None:
|
||||
self.client_id = generate_id()
|
||||
self.client_secret = generate_key()
|
||||
self.source_slug = generate_id()
|
||||
super().setUp()
|
||||
|
||||
def get_container_specs(self) -> Optional[dict[str, Any]]:
|
||||
return {
|
||||
"image": "ghcr.io/beryju/oauth1-test-server:v1.1",
|
||||
"detach": True,
|
||||
"ports": {"5000": "5001"},
|
||||
"auto_remove": True,
|
||||
"environment": {
|
||||
"OAUTH1_CLIENT_ID": self.client_id,
|
||||
"OAUTH1_CLIENT_SECRET": self.client_secret,
|
||||
"OAUTH1_REDIRECT_URI": self.url(
|
||||
"authentik_sources_oauth:oauth-client-callback",
|
||||
source_slug=self.source_slug,
|
||||
),
|
||||
},
|
||||
}
|
||||
|
||||
def create_objects(self):
|
||||
"""Create required objects"""
|
||||
# Bootstrap all needed objects
|
||||
authentication_flow = Flow.objects.get(slug="default-source-authentication")
|
||||
enrollment_flow = Flow.objects.get(slug="default-source-enrollment")
|
||||
|
||||
source = OAuthSource.objects.create( # nosec
|
||||
name=generate_id(),
|
||||
slug=self.source_slug,
|
||||
authentication_flow=authentication_flow,
|
||||
enrollment_flow=enrollment_flow,
|
||||
provider_type="oauth1",
|
||||
consumer_key=self.client_id,
|
||||
consumer_secret=self.client_secret,
|
||||
)
|
||||
ident_stage = IdentificationStage.objects.first()
|
||||
ident_stage.sources.set([source])
|
||||
ident_stage.save()
|
||||
|
||||
@retry()
|
||||
@apply_blueprint(
|
||||
"default/flow-default-authentication-flow.yaml",
|
||||
"default/flow-default-invalidation-flow.yaml",
|
||||
)
|
||||
@apply_blueprint(
|
||||
"default/flow-default-source-authentication.yaml",
|
||||
"default/flow-default-source-enrollment.yaml",
|
||||
"default/flow-default-source-pre-authentication.yaml",
|
||||
)
|
||||
def test_oauth_enroll(self):
|
||||
"""test OAuth Source With With OIDC"""
|
||||
self.create_objects()
|
||||
self.driver.get(self.live_server_url)
|
||||
|
||||
flow_executor = self.get_shadow_root("ak-flow-executor")
|
||||
identification_stage = self.get_shadow_root("ak-stage-identification", flow_executor)
|
||||
wait = WebDriverWait(identification_stage, self.wait_timeout)
|
||||
|
||||
wait.until(
|
||||
ec.presence_of_element_located(
|
||||
(By.CSS_SELECTOR, ".pf-c-login__main-footer-links-item > button")
|
||||
)
|
||||
)
|
||||
identification_stage.find_element(
|
||||
By.CSS_SELECTOR, ".pf-c-login__main-footer-links-item > button"
|
||||
).click()
|
||||
|
||||
# Now we should be at the IDP, wait for the login field
|
||||
self.wait.until(ec.presence_of_element_located((By.NAME, "username")))
|
||||
self.driver.find_element(By.NAME, "username").send_keys("example-user")
|
||||
self.driver.find_element(By.NAME, "username").send_keys(Keys.ENTER)
|
||||
sleep(2)
|
||||
|
||||
# Wait until we're logged in
|
||||
self.wait.until(ec.presence_of_element_located((By.CSS_SELECTOR, "[name='confirm']")))
|
||||
self.driver.find_element(By.CSS_SELECTOR, "[name='confirm']").click()
|
||||
|
||||
# Wait until we've loaded the user info page
|
||||
sleep(2)
|
||||
# Wait until we've logged in
|
||||
self.wait_for_url(self.if_user_url("/library"))
|
||||
self.driver.get(self.if_user_url("/settings"))
|
||||
|
||||
self.assert_user(User(username="example-user", name="test name", email="foo@example.com"))
|
Some files were not shown because too many files have changed in this diff Show more
Reference in a new issue