Compare commits

..

1 commit

Author SHA1 Message Date
Jens Langhammer 7d1efd7450
root: replace celery queues with priority
Signed-off-by: Jens Langhammer <jens@goauthentik.io>
2023-10-18 12:50:24 +02:00
324 changed files with 9221 additions and 16186 deletions

View file

@ -1,5 +1,5 @@
[bumpversion] [bumpversion]
current_version = 2023.10.6 current_version = 2023.8.3
tag = True tag = True
commit = True commit = True
parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+) parse = (?P<major>\d+)\.(?P<minor>\d+)\.(?P<patch>\d+)

View file

@ -2,39 +2,36 @@ name: "Setup authentik testing environment"
description: "Setup authentik testing environment" description: "Setup authentik testing environment"
inputs: inputs:
postgresql_version: postgresql_tag:
description: "Optional postgresql image tag" description: "Optional postgresql image tag"
default: "12" default: "12"
runs: runs:
using: "composite" using: "composite"
steps: steps:
- name: Install poetry & deps - name: Install poetry
shell: bash shell: bash
run: | run: |
pipx install poetry || true pipx install poetry || true
sudo apt-get update sudo apt update
sudo apt-get install --no-install-recommends -y libpq-dev openssl libxmlsec1-dev pkg-config gettext sudo apt install -y libpq-dev openssl libxmlsec1-dev pkg-config gettext
- name: Setup python and restore poetry - name: Setup python and restore poetry
uses: actions/setup-python@v4 uses: actions/setup-python@v3
with: with:
python-version-file: 'pyproject.toml' python-version: "3.11"
cache: "poetry" cache: "poetry"
- name: Setup node - name: Setup node
uses: actions/setup-node@v3 uses: actions/setup-node@v3
with: with:
node-version-file: web/package.json node-version: "20"
cache: "npm" cache: "npm"
cache-dependency-path: web/package-lock.json cache-dependency-path: web/package-lock.json
- name: Setup go
uses: actions/setup-go@v4
with:
go-version-file: "go.mod"
- name: Setup dependencies - name: Setup dependencies
shell: bash shell: bash
run: | run: |
export PSQL_TAG=${{ inputs.postgresql_version }} export PSQL_TAG=${{ inputs.postgresql_tag }}
docker-compose -f .github/actions/setup/docker-compose.yml up -d docker-compose -f .github/actions/setup/docker-compose.yml up -d
poetry env use python3.11
poetry install poetry install
cd web && npm ci cd web && npm ci
- name: Generate config - name: Generate config

View file

@ -11,7 +11,6 @@ on:
pull_request: pull_request:
branches: branches:
- main - main
- version-*
env: env:
POSTGRES_DB: authentik POSTGRES_DB: authentik
@ -48,38 +47,25 @@ jobs:
- name: run migrations - name: run migrations
run: poetry run python -m lifecycle.migrate run: poetry run python -m lifecycle.migrate
test-migrations-from-stable: test-migrations-from-stable:
name: test-migrations-from-stable - PostgreSQL ${{ matrix.psql }}
runs-on: ubuntu-latest runs-on: ubuntu-latest
strategy: continue-on-error: true
fail-fast: false
matrix:
psql:
- 12-alpine
- 15-alpine
- 16-alpine
steps: steps:
- uses: actions/checkout@v4 - uses: actions/checkout@v4
with: with:
fetch-depth: 0 fetch-depth: 0
- name: Setup authentik env - name: Setup authentik env
uses: ./.github/actions/setup uses: ./.github/actions/setup
with:
postgresql_version: ${{ matrix.psql }}
- name: checkout stable - name: checkout stable
run: | run: |
# Delete all poetry envs
rm -rf /home/runner/.cache/pypoetry
# Copy current, latest config to local # Copy current, latest config to local
cp authentik/lib/default.yml local.env.yml cp authentik/lib/default.yml local.env.yml
cp -R .github .. cp -R .github ..
cp -R scripts .. cp -R scripts ..
git checkout version/$(python -c "from authentik import __version__; print(__version__)") git checkout $(git describe --tags $(git rev-list --tags --max-count=1))
rm -rf .github/ scripts/ rm -rf .github/ scripts/
mv ../.github ../scripts . mv ../.github ../scripts .
- name: Setup authentik env (ensure stable deps are installed) - name: Setup authentik env (ensure stable deps are installed)
uses: ./.github/actions/setup uses: ./.github/actions/setup
with:
postgresql_version: ${{ matrix.psql }}
- name: run migrations to stable - name: run migrations to stable
run: poetry run python -m lifecycle.migrate run: poetry run python -m lifecycle.migrate
- name: checkout current code - name: checkout current code
@ -89,13 +75,9 @@ jobs:
git reset --hard HEAD git reset --hard HEAD
git clean -d -fx . git clean -d -fx .
git checkout $GITHUB_SHA git checkout $GITHUB_SHA
# Delete previous poetry env
rm -rf $(poetry env info --path)
poetry install poetry install
- name: Setup authentik env (ensure latest deps are installed) - name: Setup authentik env (ensure latest deps are installed)
uses: ./.github/actions/setup uses: ./.github/actions/setup
with:
postgresql_version: ${{ matrix.psql }}
- name: migrate to latest - name: migrate to latest
run: poetry run python -m lifecycle.migrate run: poetry run python -m lifecycle.migrate
test-unittest: test-unittest:
@ -114,7 +96,7 @@ jobs:
- name: Setup authentik env - name: Setup authentik env
uses: ./.github/actions/setup uses: ./.github/actions/setup
with: with:
postgresql_version: ${{ matrix.psql }} postgresql_tag: ${{ matrix.psql }}
- name: run unittest - name: run unittest
run: | run: |
poetry run make test poetry run make test
@ -203,9 +185,6 @@ jobs:
build: build:
needs: ci-core-mark needs: ci-core-mark
runs-on: ubuntu-latest runs-on: ubuntu-latest
permissions:
# Needed to upload contianer images to ghcr.io
packages: write
timeout-minutes: 120 timeout-minutes: 120
steps: steps:
- uses: actions/checkout@v4 - uses: actions/checkout@v4
@ -256,9 +235,6 @@ jobs:
build-arm64: build-arm64:
needs: ci-core-mark needs: ci-core-mark
runs-on: ubuntu-latest runs-on: ubuntu-latest
permissions:
# Needed to upload contianer images to ghcr.io
packages: write
timeout-minutes: 120 timeout-minutes: 120
steps: steps:
- uses: actions/checkout@v4 - uses: actions/checkout@v4

View file

@ -9,7 +9,6 @@ on:
pull_request: pull_request:
branches: branches:
- main - main
- version-*
jobs: jobs:
lint-golint: lint-golint:
@ -66,9 +65,6 @@ jobs:
- ldap - ldap
- radius - radius
runs-on: ubuntu-latest runs-on: ubuntu-latest
permissions:
# Needed to upload contianer images to ghcr.io
packages: write
steps: steps:
- uses: actions/checkout@v4 - uses: actions/checkout@v4
with: with:
@ -128,9 +124,9 @@ jobs:
- uses: actions/setup-go@v4 - uses: actions/setup-go@v4
with: with:
go-version-file: "go.mod" go-version-file: "go.mod"
- uses: actions/setup-node@v4 - uses: actions/setup-node@v3
with: with:
node-version-file: web/package.json node-version: "20"
cache: "npm" cache: "npm"
cache-dependency-path: web/package-lock.json cache-dependency-path: web/package-lock.json
- name: Generate API - name: Generate API

View file

@ -9,7 +9,6 @@ on:
pull_request: pull_request:
branches: branches:
- main - main
- version-*
jobs: jobs:
lint-eslint: lint-eslint:
@ -22,9 +21,9 @@ jobs:
- tests/wdio - tests/wdio
steps: steps:
- uses: actions/checkout@v4 - uses: actions/checkout@v4
- uses: actions/setup-node@v4 - uses: actions/setup-node@v3
with: with:
node-version-file: ${{ matrix.project }}/package.json node-version: "20"
cache: "npm" cache: "npm"
cache-dependency-path: ${{ matrix.project }}/package-lock.json cache-dependency-path: ${{ matrix.project }}/package-lock.json
- working-directory: ${{ matrix.project }}/ - working-directory: ${{ matrix.project }}/
@ -38,9 +37,9 @@ jobs:
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- uses: actions/checkout@v4 - uses: actions/checkout@v4
- uses: actions/setup-node@v4 - uses: actions/setup-node@v3
with: with:
node-version-file: web/package.json node-version: "20"
cache: "npm" cache: "npm"
cache-dependency-path: web/package-lock.json cache-dependency-path: web/package-lock.json
- working-directory: web/ - working-directory: web/
@ -60,9 +59,9 @@ jobs:
- tests/wdio - tests/wdio
steps: steps:
- uses: actions/checkout@v4 - uses: actions/checkout@v4
- uses: actions/setup-node@v4 - uses: actions/setup-node@v3
with: with:
node-version-file: ${{ matrix.project }}/package.json node-version: "20"
cache: "npm" cache: "npm"
cache-dependency-path: ${{ matrix.project }}/package-lock.json cache-dependency-path: ${{ matrix.project }}/package-lock.json
- working-directory: ${{ matrix.project }}/ - working-directory: ${{ matrix.project }}/
@ -76,9 +75,9 @@ jobs:
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- uses: actions/checkout@v4 - uses: actions/checkout@v4
- uses: actions/setup-node@v4 - uses: actions/setup-node@v3
with: with:
node-version-file: web/package.json node-version: "20"
cache: "npm" cache: "npm"
cache-dependency-path: web/package-lock.json cache-dependency-path: web/package-lock.json
- working-directory: web/ - working-directory: web/
@ -108,9 +107,9 @@ jobs:
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- uses: actions/checkout@v4 - uses: actions/checkout@v4
- uses: actions/setup-node@v4 - uses: actions/setup-node@v3
with: with:
node-version-file: web/package.json node-version: "20"
cache: "npm" cache: "npm"
cache-dependency-path: web/package-lock.json cache-dependency-path: web/package-lock.json
- working-directory: web/ - working-directory: web/

View file

@ -9,16 +9,15 @@ on:
pull_request: pull_request:
branches: branches:
- main - main
- version-*
jobs: jobs:
lint-prettier: lint-prettier:
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- uses: actions/checkout@v4 - uses: actions/checkout@v4
- uses: actions/setup-node@v4 - uses: actions/setup-node@v3
with: with:
node-version-file: website/package.json node-version: "20"
cache: "npm" cache: "npm"
cache-dependency-path: website/package-lock.json cache-dependency-path: website/package-lock.json
- working-directory: website/ - working-directory: website/
@ -30,9 +29,9 @@ jobs:
runs-on: ubuntu-latest runs-on: ubuntu-latest
steps: steps:
- uses: actions/checkout@v4 - uses: actions/checkout@v4
- uses: actions/setup-node@v4 - uses: actions/setup-node@v3
with: with:
node-version-file: website/package.json node-version: "20"
cache: "npm" cache: "npm"
cache-dependency-path: website/package-lock.json cache-dependency-path: website/package-lock.json
- working-directory: website/ - working-directory: website/
@ -51,9 +50,9 @@ jobs:
- build-docs-only - build-docs-only
steps: steps:
- uses: actions/checkout@v4 - uses: actions/checkout@v4
- uses: actions/setup-node@v4 - uses: actions/setup-node@v3
with: with:
node-version-file: website/package.json node-version: "20"
cache: "npm" cache: "npm"
cache-dependency-path: website/package-lock.json cache-dependency-path: website/package-lock.json
- working-directory: website/ - working-directory: website/

View file

@ -6,7 +6,6 @@ on:
workflow_dispatch: workflow_dispatch:
permissions: permissions:
# Needed to be able to push to the next branch
contents: write contents: write
jobs: jobs:

View file

@ -7,9 +7,6 @@ on:
jobs: jobs:
build-server: build-server:
runs-on: ubuntu-latest runs-on: ubuntu-latest
permissions:
# Needed to upload contianer images to ghcr.io
packages: write
steps: steps:
- uses: actions/checkout@v4 - uses: actions/checkout@v4
- name: Set up QEMU - name: Set up QEMU
@ -30,10 +27,8 @@ jobs:
registry: ghcr.io registry: ghcr.io
username: ${{ github.repository_owner }} username: ${{ github.repository_owner }}
password: ${{ secrets.GITHUB_TOKEN }} password: ${{ secrets.GITHUB_TOKEN }}
- name: make empty clients - name: make empty ts client
run: | run: mkdir -p ./gen-ts-client
mkdir -p ./gen-ts-api
mkdir -p ./gen-go-api
- name: Build Docker Image - name: Build Docker Image
uses: docker/build-push-action@v5 uses: docker/build-push-action@v5
with: with:
@ -55,9 +50,6 @@ jobs:
VERSION_FAMILY=${{ steps.ev.outputs.versionFamily }} VERSION_FAMILY=${{ steps.ev.outputs.versionFamily }}
build-outpost: build-outpost:
runs-on: ubuntu-latest runs-on: ubuntu-latest
permissions:
# Needed to upload contianer images to ghcr.io
packages: write
strategy: strategy:
fail-fast: false fail-fast: false
matrix: matrix:
@ -77,10 +69,6 @@ jobs:
- name: prepare variables - name: prepare variables
uses: ./.github/actions/docker-push-variables uses: ./.github/actions/docker-push-variables
id: ev id: ev
- name: make empty clients
run: |
mkdir -p ./gen-ts-api
mkdir -p ./gen-go-api
- name: Docker Login Registry - name: Docker Login Registry
uses: docker/login-action@v3 uses: docker/login-action@v3
with: with:
@ -105,16 +93,12 @@ jobs:
ghcr.io/goauthentik/${{ matrix.type }}:latest ghcr.io/goauthentik/${{ matrix.type }}:latest
file: ${{ matrix.type }}.Dockerfile file: ${{ matrix.type }}.Dockerfile
platforms: linux/amd64,linux/arm64 platforms: linux/amd64,linux/arm64
context: .
build-args: | build-args: |
VERSION=${{ steps.ev.outputs.version }} VERSION=${{ steps.ev.outputs.version }}
VERSION_FAMILY=${{ steps.ev.outputs.versionFamily }} VERSION_FAMILY=${{ steps.ev.outputs.versionFamily }}
build-outpost-binary: build-outpost-binary:
timeout-minutes: 120 timeout-minutes: 120
runs-on: ubuntu-latest runs-on: ubuntu-latest
permissions:
# Needed to upload binaries to the release
contents: write
strategy: strategy:
fail-fast: false fail-fast: false
matrix: matrix:
@ -129,9 +113,9 @@ jobs:
- uses: actions/setup-go@v4 - uses: actions/setup-go@v4
with: with:
go-version-file: "go.mod" go-version-file: "go.mod"
- uses: actions/setup-node@v4 - uses: actions/setup-node@v3
with: with:
node-version-file: web/package.json node-version: "20"
cache: "npm" cache: "npm"
cache-dependency-path: web/package-lock.json cache-dependency-path: web/package-lock.json
- name: Build web - name: Build web

View file

@ -16,7 +16,6 @@ jobs:
echo "PG_PASS=$(openssl rand -base64 32)" >> .env echo "PG_PASS=$(openssl rand -base64 32)" >> .env
echo "AUTHENTIK_SECRET_KEY=$(openssl rand -base64 32)" >> .env echo "AUTHENTIK_SECRET_KEY=$(openssl rand -base64 32)" >> .env
docker buildx install docker buildx install
mkdir -p ./gen-ts-api
docker build -t testing:latest . docker build -t testing:latest .
echo "AUTHENTIK_IMAGE=testing" >> .env echo "AUTHENTIK_IMAGE=testing" >> .env
echo "AUTHENTIK_TAG=latest" >> .env echo "AUTHENTIK_TAG=latest" >> .env

View file

@ -6,8 +6,8 @@ on:
workflow_dispatch: workflow_dispatch:
permissions: permissions:
# Needed to update issues and PRs
issues: write issues: write
pull-requests: write
jobs: jobs:
stale: stale:

View file

@ -17,9 +17,9 @@ jobs:
- uses: actions/checkout@v4 - uses: actions/checkout@v4
with: with:
token: ${{ steps.generate_token.outputs.token }} token: ${{ steps.generate_token.outputs.token }}
- uses: actions/setup-node@v4 - uses: actions/setup-node@v3
with: with:
node-version-file: web/package.json node-version: "20"
registry-url: "https://registry.npmjs.org" registry-url: "https://registry.npmjs.org"
- name: Generate API Client - name: Generate API Client
run: make gen-client-ts run: make gen-client-ts

View file

@ -1,7 +1,5 @@
# syntax=docker/dockerfile:1
# Stage 1: Build website # Stage 1: Build website
FROM --platform=${BUILDPLATFORM} docker.io/node:21 as website-builder FROM --platform=${BUILDPLATFORM} docker.io/node:20 as website-builder
ENV NODE_ENV=production ENV NODE_ENV=production
@ -9,7 +7,7 @@ WORKDIR /work/website
RUN --mount=type=bind,target=/work/website/package.json,src=./website/package.json \ RUN --mount=type=bind,target=/work/website/package.json,src=./website/package.json \
--mount=type=bind,target=/work/website/package-lock.json,src=./website/package-lock.json \ --mount=type=bind,target=/work/website/package-lock.json,src=./website/package-lock.json \
--mount=type=cache,id=npm-website,sharing=shared,target=/root/.npm \ --mount=type=cache,target=/root/.npm \
npm ci --include=dev npm ci --include=dev
COPY ./website /work/website/ COPY ./website /work/website/
@ -19,7 +17,7 @@ COPY ./SECURITY.md /work/
RUN npm run build-docs-only RUN npm run build-docs-only
# Stage 2: Build webui # Stage 2: Build webui
FROM --platform=${BUILDPLATFORM} docker.io/node:21 as web-builder FROM --platform=${BUILDPLATFORM} docker.io/node:20 as web-builder
ENV NODE_ENV=production ENV NODE_ENV=production
@ -27,7 +25,7 @@ WORKDIR /work/web
RUN --mount=type=bind,target=/work/web/package.json,src=./web/package.json \ RUN --mount=type=bind,target=/work/web/package.json,src=./web/package.json \
--mount=type=bind,target=/work/web/package-lock.json,src=./web/package-lock.json \ --mount=type=bind,target=/work/web/package-lock.json,src=./web/package-lock.json \
--mount=type=cache,id=npm-web,sharing=shared,target=/root/.npm \ --mount=type=cache,target=/root/.npm \
npm ci --include=dev npm ci --include=dev
COPY ./web /work/web/ COPY ./web /work/web/
@ -37,14 +35,7 @@ COPY ./gen-ts-api /work/web/node_modules/@goauthentik/api
RUN npm run build RUN npm run build
# Stage 3: Build go proxy # Stage 3: Build go proxy
FROM --platform=${BUILDPLATFORM} docker.io/golang:1.21.4-bookworm AS go-builder FROM docker.io/golang:1.21.3-bookworm AS go-builder
ARG TARGETOS
ARG TARGETARCH
ARG TARGETVARIANT
ARG GOOS=$TARGETOS
ARG GOARCH=$TARGETARCH
WORKDIR /go/src/goauthentik.io WORKDIR /go/src/goauthentik.io
@ -64,12 +55,12 @@ COPY ./go.sum /go/src/goauthentik.io/go.sum
ENV CGO_ENABLED=0 ENV CGO_ENABLED=0
RUN --mount=type=cache,sharing=locked,target=/go/pkg/mod \ RUN --mount=type=cache,target=/go/pkg/mod \
--mount=type=cache,id=go-build-$TARGETARCH$TARGETVARIANT,sharing=locked,target=/root/.cache/go-build \ --mount=type=cache,target=/root/.cache/go-build \
GOARM="${TARGETVARIANT#v}" go build -o /go/authentik ./cmd/server go build -o /go/authentik ./cmd/server
# Stage 4: MaxMind GeoIP # Stage 4: MaxMind GeoIP
FROM --platform=${BUILDPLATFORM} ghcr.io/maxmind/geoipupdate:v6.0 as geoip FROM ghcr.io/maxmind/geoipupdate:v6.0 as geoip
ENV GEOIPUPDATE_EDITION_IDS="GeoLite2-City" ENV GEOIPUPDATE_EDITION_IDS="GeoLite2-City"
ENV GEOIPUPDATE_VERBOSE="true" ENV GEOIPUPDATE_VERBOSE="true"
@ -91,9 +82,7 @@ ENV VENV_PATH="/ak-root/venv" \
POETRY_VIRTUALENVS_CREATE=false \ POETRY_VIRTUALENVS_CREATE=false \
PATH="/ak-root/venv/bin:$PATH" PATH="/ak-root/venv/bin:$PATH"
RUN rm -f /etc/apt/apt.conf.d/docker-clean; echo 'Binary::apt::APT::Keep-Downloaded-Packages "true";' > /etc/apt/apt.conf.d/keep-cache RUN --mount=type=cache,target=/var/cache/apt \
RUN --mount=type=cache,id=apt-$TARGETARCH$TARGETVARIANT,sharing=locked,target=/var/cache/apt \
apt-get update && \ apt-get update && \
# Required for installing pip packages # Required for installing pip packages
apt-get install -y --no-install-recommends build-essential pkg-config libxmlsec1-dev zlib1g-dev libpq-dev apt-get install -y --no-install-recommends build-essential pkg-config libxmlsec1-dev zlib1g-dev libpq-dev

View file

@ -56,9 +56,9 @@ test: ## Run the server tests and produce a coverage report (locally)
coverage report coverage report
lint-fix: ## Lint and automatically fix errors in the python source code. Reports spelling errors. lint-fix: ## Lint and automatically fix errors in the python source code. Reports spelling errors.
isort $(PY_SOURCES) isort authentik $(PY_SOURCES)
black $(PY_SOURCES) black authentik $(PY_SOURCES)
ruff $(PY_SOURCES) ruff authentik $(PY_SOURCES)
codespell -w $(CODESPELL_ARGS) codespell -w $(CODESPELL_ARGS)
lint: ## Lint the python and golang sources lint: ## Lint the python and golang sources

View file

@ -2,7 +2,7 @@
from os import environ from os import environ
from typing import Optional from typing import Optional
__version__ = "2023.10.6" __version__ = "2023.8.3"
ENV_GIT_HASH_KEY = "GIT_BUILD_HASH" ENV_GIT_HASH_KEY = "GIT_BUILD_HASH"

View file

@ -1,12 +1,13 @@
"""authentik admin settings""" """authentik admin settings"""
from celery.schedules import crontab from celery.schedules import crontab
from authentik.lib.config import CONFIG
from authentik.lib.utils.time import fqdn_rand from authentik.lib.utils.time import fqdn_rand
CELERY_BEAT_SCHEDULE = { CELERY_BEAT_SCHEDULE = {
"admin_latest_version": { "admin_latest_version": {
"task": "authentik.admin.tasks.update_latest_version", "task": "authentik.admin.tasks.update_latest_version",
"schedule": crontab(minute=fqdn_rand("admin_latest_version"), hour="*"), "schedule": crontab(minute=fqdn_rand("admin_latest_version"), hour="*"),
"options": {"queue": "authentik_scheduled"}, "options": {"priority": CONFIG.get_int("worker.priority.scheduled")},
} }
} }

View file

@ -21,9 +21,7 @@ _other_urls = []
for _authentik_app in get_apps(): for _authentik_app in get_apps():
try: try:
api_urls = import_module(f"{_authentik_app.name}.urls") api_urls = import_module(f"{_authentik_app.name}.urls")
except ModuleNotFoundError: except (ModuleNotFoundError, ImportError) as exc:
continue
except ImportError as exc:
LOGGER.warning("Could not import app's URLs", app_name=_authentik_app.name, exc=exc) LOGGER.warning("Could not import app's URLs", app_name=_authentik_app.name, exc=exc)
continue continue
if not hasattr(api_urls, "api_urlpatterns"): if not hasattr(api_urls, "api_urlpatterns"):

View file

@ -40,7 +40,7 @@ class ManagedAppConfig(AppConfig):
meth() meth()
self._logger.debug("Successfully reconciled", name=name) self._logger.debug("Successfully reconciled", name=name)
except (DatabaseError, ProgrammingError, InternalError) as exc: except (DatabaseError, ProgrammingError, InternalError) as exc:
self._logger.warning("Failed to run reconcile", name=name, exc=exc) self._logger.debug("Failed to run reconcile", name=name, exc=exc)
class AuthentikBlueprintsConfig(ManagedAppConfig): class AuthentikBlueprintsConfig(ManagedAppConfig):

View file

@ -1,17 +1,18 @@
"""blueprint Settings""" """blueprint Settings"""
from celery.schedules import crontab from celery.schedules import crontab
from authentik.lib.config import CONFIG
from authentik.lib.utils.time import fqdn_rand from authentik.lib.utils.time import fqdn_rand
CELERY_BEAT_SCHEDULE = { CELERY_BEAT_SCHEDULE = {
"blueprints_v1_discover": { "blueprints_v1_discover": {
"task": "authentik.blueprints.v1.tasks.blueprints_discovery", "task": "authentik.blueprints.v1.tasks.blueprints_discovery",
"schedule": crontab(minute=fqdn_rand("blueprints_v1_discover"), hour="*"), "schedule": crontab(minute=fqdn_rand("blueprints_v1_discover"), hour="*"),
"options": {"queue": "authentik_scheduled"}, "options": {"priority": CONFIG.get_int("worker.priority.scheduled")},
}, },
"blueprints_v1_cleanup": { "blueprints_v1_cleanup": {
"task": "authentik.blueprints.v1.tasks.clear_failed_blueprints", "task": "authentik.blueprints.v1.tasks.clear_failed_blueprints",
"schedule": crontab(minute=fqdn_rand("blueprints_v1_cleanup"), hour="*"), "schedule": crontab(minute=fqdn_rand("blueprints_v1_cleanup"), hour="*"),
"options": {"queue": "authentik_scheduled"}, "options": {"priority": CONFIG.get_int("worker.priority.scheduled")},
}, },
} }

View file

@ -584,17 +584,12 @@ class EntryInvalidError(SentryIgnoredException):
entry_model: Optional[str] entry_model: Optional[str]
entry_id: Optional[str] entry_id: Optional[str]
validation_error: Optional[ValidationError] validation_error: Optional[ValidationError]
serializer: Optional[Serializer] = None
def __init__( def __init__(self, *args: object, validation_error: Optional[ValidationError] = None) -> None:
self, *args: object, validation_error: Optional[ValidationError] = None, **kwargs
) -> None:
super().__init__(*args) super().__init__(*args)
self.entry_model = None self.entry_model = None
self.entry_id = None self.entry_id = None
self.validation_error = validation_error self.validation_error = validation_error
for key, value in kwargs.items():
setattr(self, key, value)
@staticmethod @staticmethod
def from_entry( def from_entry(

View file

@ -255,10 +255,7 @@ class Importer:
try: try:
full_data = self.__update_pks_for_attrs(entry.get_attrs(self._import)) full_data = self.__update_pks_for_attrs(entry.get_attrs(self._import))
except ValueError as exc: except ValueError as exc:
raise EntryInvalidError.from_entry( raise EntryInvalidError.from_entry(exc, entry) from exc
exc,
entry,
) from exc
always_merger.merge(full_data, updated_identifiers) always_merger.merge(full_data, updated_identifiers)
serializer_kwargs["data"] = full_data serializer_kwargs["data"] = full_data
@ -275,7 +272,6 @@ class Importer:
f"Serializer errors {serializer.errors}", f"Serializer errors {serializer.errors}",
validation_error=exc, validation_error=exc,
entry=entry, entry=entry,
serializer=serializer,
) from exc ) from exc
return serializer return serializer
@ -304,18 +300,16 @@ class Importer:
) )
return False return False
# Validate each single entry # Validate each single entry
serializer = None
try: try:
serializer = self._validate_single(entry) serializer = self._validate_single(entry)
except EntryInvalidError as exc: except EntryInvalidError as exc:
# For deleting objects we don't need the serializer to be valid # For deleting objects we don't need the serializer to be valid
if entry.get_state(self._import) == BlueprintEntryDesiredState.ABSENT: if entry.get_state(self._import) == BlueprintEntryDesiredState.ABSENT:
serializer = exc.serializer continue
else: self.logger.warning(f"entry invalid: {exc}", entry=entry, error=exc)
self.logger.warning(f"entry invalid: {exc}", entry=entry, error=exc) if raise_errors:
if raise_errors: raise exc
raise exc return False
return False
if not serializer: if not serializer:
continue continue

View file

@ -75,14 +75,14 @@ class BlueprintEventHandler(FileSystemEventHandler):
return return
if event.is_directory: if event.is_directory:
return return
root = Path(CONFIG.get("blueprints_dir")).absolute()
path = Path(event.src_path).absolute()
rel_path = str(path.relative_to(root))
if isinstance(event, FileCreatedEvent): if isinstance(event, FileCreatedEvent):
LOGGER.debug("new blueprint file created, starting discovery", path=rel_path) LOGGER.debug("new blueprint file created, starting discovery")
blueprints_discovery.delay(rel_path) blueprints_discovery.delay()
if isinstance(event, FileModifiedEvent): if isinstance(event, FileModifiedEvent):
for instance in BlueprintInstance.objects.filter(path=rel_path, enabled=True): path = Path(event.src_path)
root = Path(CONFIG.get("blueprints_dir")).absolute()
rel_path = str(path.relative_to(root))
for instance in BlueprintInstance.objects.filter(path=rel_path):
LOGGER.debug("modified blueprint file, starting apply", instance=instance) LOGGER.debug("modified blueprint file, starting apply", instance=instance)
apply_blueprint.delay(instance.pk.hex) apply_blueprint.delay(instance.pk.hex)
@ -98,32 +98,39 @@ def blueprints_find_dict():
return blueprints return blueprints
def blueprints_find() -> list[BlueprintFile]: def blueprints_find():
"""Find blueprints and return valid ones""" """Find blueprints and return valid ones"""
blueprints = [] blueprints = []
root = Path(CONFIG.get("blueprints_dir")) root = Path(CONFIG.get("blueprints_dir"))
for path in root.rglob("**/*.yaml"): for path in root.rglob("**/*.yaml"):
rel_path = path.relative_to(root)
# Check if any part in the path starts with a dot and assume a hidden file # Check if any part in the path starts with a dot and assume a hidden file
if any(part for part in path.parts if part.startswith(".")): if any(part for part in path.parts if part.startswith(".")):
continue continue
LOGGER.debug("found blueprint", path=str(path))
with open(path, "r", encoding="utf-8") as blueprint_file: with open(path, "r", encoding="utf-8") as blueprint_file:
try: try:
raw_blueprint = load(blueprint_file.read(), BlueprintLoader) raw_blueprint = load(blueprint_file.read(), BlueprintLoader)
except YAMLError as exc: except YAMLError as exc:
raw_blueprint = None raw_blueprint = None
LOGGER.warning("failed to parse blueprint", exc=exc, path=str(rel_path)) LOGGER.warning("failed to parse blueprint", exc=exc, path=str(path))
if not raw_blueprint: if not raw_blueprint:
continue continue
metadata = raw_blueprint.get("metadata", None) metadata = raw_blueprint.get("metadata", None)
version = raw_blueprint.get("version", 1) version = raw_blueprint.get("version", 1)
if version != 1: if version != 1:
LOGGER.warning("invalid blueprint version", version=version, path=str(rel_path)) LOGGER.warning("invalid blueprint version", version=version, path=str(path))
continue continue
file_hash = sha512(path.read_bytes()).hexdigest() file_hash = sha512(path.read_bytes()).hexdigest()
blueprint = BlueprintFile(str(rel_path), version, file_hash, int(path.stat().st_mtime)) blueprint = BlueprintFile(
str(path.relative_to(root)), version, file_hash, int(path.stat().st_mtime)
)
blueprint.meta = from_dict(BlueprintMetadata, metadata) if metadata else None blueprint.meta = from_dict(BlueprintMetadata, metadata) if metadata else None
blueprints.append(blueprint) blueprints.append(blueprint)
LOGGER.debug(
"parsed & loaded blueprint",
hash=file_hash,
path=str(path),
)
return blueprints return blueprints
@ -131,12 +138,10 @@ def blueprints_find() -> list[BlueprintFile]:
throws=(DatabaseError, ProgrammingError, InternalError), base=MonitoredTask, bind=True throws=(DatabaseError, ProgrammingError, InternalError), base=MonitoredTask, bind=True
) )
@prefill_task @prefill_task
def blueprints_discovery(self: MonitoredTask, path: Optional[str] = None): def blueprints_discovery(self: MonitoredTask):
"""Find blueprints and check if they need to be created in the database""" """Find blueprints and check if they need to be created in the database"""
count = 0 count = 0
for blueprint in blueprints_find(): for blueprint in blueprints_find():
if path and blueprint.path != path:
continue
check_blueprint_v1_file(blueprint) check_blueprint_v1_file(blueprint)
count += 1 count += 1
self.set_status( self.set_status(
@ -166,11 +171,7 @@ def check_blueprint_v1_file(blueprint: BlueprintFile):
metadata={}, metadata={},
) )
instance.save() instance.save()
LOGGER.info(
"Creating new blueprint instance from file", instance=instance, path=instance.path
)
if instance.last_applied_hash != blueprint.hash: if instance.last_applied_hash != blueprint.hash:
LOGGER.info("Applying blueprint due to changed file", instance=instance, path=instance.path)
apply_blueprint.delay(str(instance.pk)) apply_blueprint.delay(str(instance.pk))

View file

@ -98,7 +98,6 @@ class ApplicationSerializer(ModelSerializer):
class ApplicationViewSet(UsedByMixin, ModelViewSet): class ApplicationViewSet(UsedByMixin, ModelViewSet):
"""Application Viewset""" """Application Viewset"""
# pylint: disable=no-member
queryset = Application.objects.all().prefetch_related("provider") queryset = Application.objects.all().prefetch_related("provider")
serializer_class = ApplicationSerializer serializer_class = ApplicationSerializer
search_fields = [ search_fields = [

View file

@ -139,7 +139,6 @@ class UserAccountSerializer(PassiveSerializer):
class GroupViewSet(UsedByMixin, ModelViewSet): class GroupViewSet(UsedByMixin, ModelViewSet):
"""Group Viewset""" """Group Viewset"""
# pylint: disable=no-member
queryset = Group.objects.all().select_related("parent").prefetch_related("users") queryset = Group.objects.all().select_related("parent").prefetch_related("users")
serializer_class = GroupSerializer serializer_class = GroupSerializer
search_fields = ["name", "is_superuser"] search_fields = ["name", "is_superuser"]

View file

@ -38,7 +38,7 @@ class SourceSerializer(ModelSerializer, MetaNameSerializer):
managed = ReadOnlyField() managed = ReadOnlyField()
component = SerializerMethodField() component = SerializerMethodField()
icon = ReadOnlyField(source="icon_url") icon = ReadOnlyField(source="get_icon")
def get_component(self, obj: Source) -> str: def get_component(self, obj: Source) -> str:
"""Get object component so that we know how to edit the object""" """Get object component so that we know how to edit the object"""

View file

@ -171,11 +171,6 @@ class UserSerializer(ModelSerializer):
raise ValidationError("Setting a user to internal service account is not allowed.") raise ValidationError("Setting a user to internal service account is not allowed.")
return user_type return user_type
def validate(self, attrs: dict) -> dict:
if self.instance and self.instance.type == UserTypes.INTERNAL_SERVICE_ACCOUNT:
raise ValidationError("Can't modify internal service account users")
return super().validate(attrs)
class Meta: class Meta:
model = User model = User
fields = [ fields = [

View file

@ -44,7 +44,6 @@ class PropertyMappingEvaluator(BaseEvaluator):
if request: if request:
req.http_request = request req.http_request = request
self._context["request"] = req self._context["request"] = req
req.context.update(**kwargs)
self._context.update(**kwargs) self._context.update(**kwargs)
self.dry_run = dry_run self.dry_run = dry_run

View file

@ -17,15 +17,9 @@ class Command(BaseCommand):
"""Run worker""" """Run worker"""
def add_arguments(self, parser): def add_arguments(self, parser):
parser.add_argument( parser.add_argument("-b", "--beat", action="store_true")
"-b",
"--beat",
action="store_false",
help="When set, this worker will _not_ run Beat (scheduled) tasks",
)
def handle(self, **options): def handle(self, **options):
LOGGER.debug("Celery options", **options)
close_old_connections() close_old_connections()
if CONFIG.get_bool("remote_debug"): if CONFIG.get_bool("remote_debug"):
import debugpy import debugpy
@ -39,7 +33,6 @@ class Command(BaseCommand):
task_events=True, task_events=True,
beat=options.get("beat", True), beat=options.get("beat", True),
schedule_filename=f"{tempdir}/celerybeat-schedule", schedule_filename=f"{tempdir}/celerybeat-schedule",
queues=["authentik", "authentik_scheduled", "authentik_events"],
) )
for task in CELERY_APP.tasks: for task in CELERY_APP.tasks:
LOGGER.debug("Registered task", task=task) LOGGER.debug("Registered task", task=task)

View file

@ -97,7 +97,6 @@ class SourceFlowManager:
if self.request.user.is_authenticated: if self.request.user.is_authenticated:
new_connection.user = self.request.user new_connection.user = self.request.user
new_connection = self.update_connection(new_connection, **kwargs) new_connection = self.update_connection(new_connection, **kwargs)
# pylint: disable=no-member
new_connection.save() new_connection.save()
return Action.LINK, new_connection return Action.LINK, new_connection

View file

@ -13,6 +13,7 @@
{% block head_before %} {% block head_before %}
{% endblock %} {% endblock %}
<link rel="stylesheet" type="text/css" href="{% static 'dist/authentik.css' %}"> <link rel="stylesheet" type="text/css" href="{% static 'dist/authentik.css' %}">
<link rel="stylesheet" type="text/css" href="{% static 'dist/theme-dark.css' %}" media="(prefers-color-scheme: dark)">
<link rel="stylesheet" type="text/css" href="{% static 'dist/custom.css' %}" data-inject> <link rel="stylesheet" type="text/css" href="{% static 'dist/custom.css' %}" data-inject>
<script src="{% static 'dist/poly.js' %}?version={{ version }}" type="module"></script> <script src="{% static 'dist/poly.js' %}?version={{ version }}" type="module"></script>
<script src="{% static 'dist/standalone/loading/index.js' %}?version={{ version }}" type="module"></script> <script src="{% static 'dist/standalone/loading/index.js' %}?version={{ version }}" type="module"></script>

View file

@ -16,8 +16,8 @@ You've logged out of {{ application }}.
{% block card %} {% block card %}
<form method="POST" class="pf-c-form"> <form method="POST" class="pf-c-form">
<p> <p>
{% blocktrans with application=application.name branding_title=tenant.branding_title %} {% blocktrans with application=application.name %}
You've logged out of {{ application }}. You can go back to the overview to launch another application, or log out of your {{ branding_title }} account. You've logged out of {{ application }}. You can go back to the overview to launch another application, or log out of your authentik account.
{% endblocktrans %} {% endblocktrans %}
</p> </p>

View file

@ -6,7 +6,6 @@
{% block head_before %} {% block head_before %}
<link rel="prefetch" href="/static/dist/assets/images/flow_background.jpg" /> <link rel="prefetch" href="/static/dist/assets/images/flow_background.jpg" />
<link rel="stylesheet" type="text/css" href="{% static 'dist/patternfly.min.css' %}"> <link rel="stylesheet" type="text/css" href="{% static 'dist/patternfly.min.css' %}">
<link rel="stylesheet" type="text/css" href="{% static 'dist/theme-dark.css' %}" media="(prefers-color-scheme: dark)">
{% include "base/header_js.html" %} {% include "base/header_js.html" %}
{% endblock %} {% endblock %}

View file

@ -1,10 +1,13 @@
"""authentik crypto app config""" """authentik crypto app config"""
from datetime import datetime from datetime import datetime
from typing import Optional from typing import TYPE_CHECKING, Optional
from authentik.blueprints.apps import ManagedAppConfig from authentik.blueprints.apps import ManagedAppConfig
from authentik.lib.generators import generate_id from authentik.lib.generators import generate_id
if TYPE_CHECKING:
from authentik.crypto.models import CertificateKeyPair
MANAGED_KEY = "goauthentik.io/crypto/jwt-managed" MANAGED_KEY = "goauthentik.io/crypto/jwt-managed"
@ -20,37 +23,33 @@ class AuthentikCryptoConfig(ManagedAppConfig):
"""Load crypto tasks""" """Load crypto tasks"""
self.import_module("authentik.crypto.tasks") self.import_module("authentik.crypto.tasks")
def _create_update_cert(self): def _create_update_cert(self, cert: Optional["CertificateKeyPair"] = None):
from authentik.crypto.builder import CertificateBuilder from authentik.crypto.builder import CertificateBuilder
from authentik.crypto.models import CertificateKeyPair from authentik.crypto.models import CertificateKeyPair
common_name = "authentik Internal JWT Certificate" builder = CertificateBuilder("authentik Internal JWT Certificate")
builder = CertificateBuilder(common_name)
builder.build( builder.build(
subject_alt_names=["goauthentik.io"], subject_alt_names=["goauthentik.io"],
validity_days=360, validity_days=360,
) )
CertificateKeyPair.objects.update_or_create( if not cert:
managed=MANAGED_KEY, cert = CertificateKeyPair()
defaults={ builder.cert = cert
"name": common_name, builder.cert.managed = MANAGED_KEY
"certificate_data": builder.certificate, builder.save()
"key_data": builder.private_key,
},
)
def reconcile_managed_jwt_cert(self): def reconcile_managed_jwt_cert(self):
"""Ensure managed JWT certificate""" """Ensure managed JWT certificate"""
from authentik.crypto.models import CertificateKeyPair from authentik.crypto.models import CertificateKeyPair
cert: Optional[CertificateKeyPair] = CertificateKeyPair.objects.filter( certs = CertificateKeyPair.objects.filter(managed=MANAGED_KEY)
managed=MANAGED_KEY if not certs.exists():
).first()
now = datetime.now()
if not cert or (
now < cert.certificate.not_valid_before or now > cert.certificate.not_valid_after
):
self._create_update_cert() self._create_update_cert()
return
cert: CertificateKeyPair = certs.first()
now = datetime.now()
if now < cert.certificate.not_valid_before or now > cert.certificate.not_valid_after:
self._create_update_cert(cert)
def reconcile_self_signed(self): def reconcile_self_signed(self):
"""Create self-signed keypair""" """Create self-signed keypair"""
@ -62,10 +61,4 @@ class AuthentikCryptoConfig(ManagedAppConfig):
return return
builder = CertificateBuilder(name) builder = CertificateBuilder(name)
builder.build(subject_alt_names=[f"{generate_id()}.self-signed.goauthentik.io"]) builder.build(subject_alt_names=[f"{generate_id()}.self-signed.goauthentik.io"])
CertificateKeyPair.objects.get_or_create( builder.save()
name=name,
defaults={
"certificate_data": builder.certificate,
"key_data": builder.private_key,
},
)

View file

@ -1,12 +1,13 @@
"""Crypto task Settings""" """Crypto task Settings"""
from celery.schedules import crontab from celery.schedules import crontab
from authentik.lib.config import CONFIG
from authentik.lib.utils.time import fqdn_rand from authentik.lib.utils.time import fqdn_rand
CELERY_BEAT_SCHEDULE = { CELERY_BEAT_SCHEDULE = {
"crypto_certificate_discovery": { "crypto_certificate_discovery": {
"task": "authentik.crypto.tasks.certificate_discovery", "task": "authentik.crypto.tasks.certificate_discovery",
"schedule": crontab(minute=fqdn_rand("crypto_certificate_discovery"), hour="*"), "schedule": crontab(minute=fqdn_rand("crypto_certificate_discovery"), hour="*"),
"options": {"queue": "authentik_scheduled"}, "options": {"priority": CONFIG.get_int("worker.priority.scheduled")},
}, },
} }

View file

@ -136,9 +136,6 @@ class LicenseKey:
def record_usage(self): def record_usage(self):
"""Capture the current validity status and metrics and save them""" """Capture the current validity status and metrics and save them"""
threshold = now() - timedelta(hours=8)
if LicenseUsage.objects.filter(record_date__gte=threshold).exists():
return
LicenseUsage.objects.create( LicenseUsage.objects.create(
user_count=self.get_default_user_count(), user_count=self.get_default_user_count(),
external_user_count=self.get_external_user_count(), external_user_count=self.get_external_user_count(),

View file

@ -1,12 +1,13 @@
"""Enterprise additional settings""" """Enterprise additional settings"""
from celery.schedules import crontab from celery.schedules import crontab
from authentik.lib.config import CONFIG
from authentik.lib.utils.time import fqdn_rand from authentik.lib.utils.time import fqdn_rand
CELERY_BEAT_SCHEDULE = { CELERY_BEAT_SCHEDULE = {
"enterprise_calculate_license": { "enterprise_calculate_license": {
"task": "authentik.enterprise.tasks.calculate_license", "task": "authentik.enterprise.tasks.calculate_license",
"schedule": crontab(minute=fqdn_rand("calculate_license"), hour="*/2"), "schedule": crontab(minute=fqdn_rand("calculate_license"), hour="*/8"),
"options": {"queue": "authentik_scheduled"}, "options": {"priority": CONFIG.get_int("worker.priority.scheduled")},
} }
} }

View file

@ -6,4 +6,5 @@ from authentik.root.celery import CELERY_APP
@CELERY_APP.task() @CELERY_APP.task()
def calculate_license(): def calculate_license():
"""Calculate licensing status""" """Calculate licensing status"""
LicenseKey.get_total().record_usage() total = LicenseKey.get_total()
total.record_usage()

View file

@ -27,7 +27,6 @@ from authentik.lib.sentry import before_send
from authentik.lib.utils.errors import exception_to_string from authentik.lib.utils.errors import exception_to_string
from authentik.outposts.models import OutpostServiceConnection from authentik.outposts.models import OutpostServiceConnection
from authentik.policies.models import Policy, PolicyBindingModel from authentik.policies.models import Policy, PolicyBindingModel
from authentik.policies.reputation.models import Reputation
from authentik.providers.oauth2.models import AccessToken, AuthorizationCode, RefreshToken from authentik.providers.oauth2.models import AccessToken, AuthorizationCode, RefreshToken
from authentik.providers.scim.models import SCIMGroup, SCIMUser from authentik.providers.scim.models import SCIMGroup, SCIMUser
from authentik.stages.authenticator_static.models import StaticToken from authentik.stages.authenticator_static.models import StaticToken
@ -53,13 +52,11 @@ IGNORED_MODELS = (
RefreshToken, RefreshToken,
SCIMUser, SCIMUser,
SCIMGroup, SCIMGroup,
Reputation,
) )
def should_log_model(model: Model) -> bool: def should_log_model(model: Model) -> bool:
"""Return true if operation on `model` should be logged""" """Return true if operation on `model` should be logged"""
# Check for silk by string so this comparison doesn't fail when silk isn't installed
if model.__module__.startswith("silk"): if model.__module__.startswith("silk"):
return False return False
return model.__class__ not in IGNORED_MODELS return model.__class__ not in IGNORED_MODELS
@ -96,30 +93,21 @@ class AuditMiddleware:
of models""" of models"""
get_response: Callable[[HttpRequest], HttpResponse] get_response: Callable[[HttpRequest], HttpResponse]
anonymous_user: User = None
def __init__(self, get_response: Callable[[HttpRequest], HttpResponse]): def __init__(self, get_response: Callable[[HttpRequest], HttpResponse]):
self.get_response = get_response self.get_response = get_response
def _ensure_fallback_user(self):
"""Defer fetching anonymous user until we have to"""
if self.anonymous_user:
return
from guardian.shortcuts import get_anonymous_user
self.anonymous_user = get_anonymous_user()
def connect(self, request: HttpRequest): def connect(self, request: HttpRequest):
"""Connect signal for automatic logging""" """Connect signal for automatic logging"""
self._ensure_fallback_user() if not hasattr(request, "user"):
user = getattr(request, "user", self.anonymous_user) return
if not user.is_authenticated: if not getattr(request.user, "is_authenticated", False):
user = self.anonymous_user return
if not hasattr(request, "request_id"): if not hasattr(request, "request_id"):
return return
post_save_handler = partial(self.post_save_handler, user=user, request=request) post_save_handler = partial(self.post_save_handler, user=request.user, request=request)
pre_delete_handler = partial(self.pre_delete_handler, user=user, request=request) pre_delete_handler = partial(self.pre_delete_handler, user=request.user, request=request)
m2m_changed_handler = partial(self.m2m_changed_handler, user=user, request=request) m2m_changed_handler = partial(self.m2m_changed_handler, user=request.user, request=request)
post_save.connect( post_save.connect(
post_save_handler, post_save_handler,
dispatch_uid=request.request_id, dispatch_uid=request.request_id,

View file

@ -217,7 +217,6 @@ class Event(SerializerModel, ExpiringModel):
"path": request.path, "path": request.path,
"method": request.method, "method": request.method,
"args": cleanse_dict(QueryDict(request.META.get("QUERY_STRING", ""))), "args": cleanse_dict(QueryDict(request.META.get("QUERY_STRING", ""))),
"user_agent": request.META.get("HTTP_USER_AGENT", ""),
} }
# Special case for events created during flow execution # Special case for events created during flow execution
# since they keep the http query within a wrapped query # since they keep the http query within a wrapped query

View file

@ -1,12 +1,13 @@
"""Event Settings""" """Event Settings"""
from celery.schedules import crontab from celery.schedules import crontab
from authentik.lib.config import CONFIG
from authentik.lib.utils.time import fqdn_rand from authentik.lib.utils.time import fqdn_rand
CELERY_BEAT_SCHEDULE = { CELERY_BEAT_SCHEDULE = {
"events_notification_cleanup": { "events_notification_cleanup": {
"task": "authentik.events.tasks.notification_cleanup", "task": "authentik.events.tasks.notification_cleanup",
"schedule": crontab(minute=fqdn_rand("notification_cleanup"), hour="*/8"), "schedule": crontab(minute=fqdn_rand("notification_cleanup"), hour="*/8"),
"options": {"queue": "authentik_scheduled"}, "options": {"priority": CONFIG.get_int("worker.priority.scheduled")},
}, },
} }

View file

@ -13,7 +13,6 @@ from authentik.events.tasks import event_notification_handler, gdpr_cleanup
from authentik.flows.models import Stage from authentik.flows.models import Stage
from authentik.flows.planner import PLAN_CONTEXT_SOURCE, FlowPlan from authentik.flows.planner import PLAN_CONTEXT_SOURCE, FlowPlan
from authentik.flows.views.executor import SESSION_KEY_PLAN from authentik.flows.views.executor import SESSION_KEY_PLAN
from authentik.lib.config import CONFIG
from authentik.stages.invitation.models import Invitation from authentik.stages.invitation.models import Invitation
from authentik.stages.invitation.signals import invitation_used from authentik.stages.invitation.signals import invitation_used
from authentik.stages.password.stage import PLAN_CONTEXT_METHOD, PLAN_CONTEXT_METHOD_ARGS from authentik.stages.password.stage import PLAN_CONTEXT_METHOD, PLAN_CONTEXT_METHOD_ARGS
@ -93,5 +92,4 @@ def event_post_save_notification(sender, instance: Event, **_):
@receiver(pre_delete, sender=User) @receiver(pre_delete, sender=User)
def event_user_pre_delete_cleanup(sender, instance: User, **_): def event_user_pre_delete_cleanup(sender, instance: User, **_):
"""If gdpr_compliance is enabled, remove all the user's events""" """If gdpr_compliance is enabled, remove all the user's events"""
if CONFIG.get_bool("gdpr_compliance", True): gdpr_cleanup.delay(instance.pk)
gdpr_cleanup.delay(instance.pk)

View file

@ -20,6 +20,7 @@ from authentik.events.monitored_tasks import (
TaskResultStatus, TaskResultStatus,
prefill_task, prefill_task,
) )
from authentik.lib.config import CONFIG
from authentik.policies.engine import PolicyEngine from authentik.policies.engine import PolicyEngine
from authentik.policies.models import PolicyBinding, PolicyEngineMode from authentik.policies.models import PolicyBinding, PolicyEngineMode
from authentik.root.celery import CELERY_APP from authentik.root.celery import CELERY_APP
@ -89,7 +90,7 @@ def event_trigger_handler(event_uuid: str, trigger_name: str):
user.pk, user.pk,
str(trigger.pk), str(trigger.pk),
], ],
queue="authentik_events", priority=CONFIG.get_int("worker.priority.events"),
) )
if transport.send_once: if transport.send_once:
break break

View file

@ -53,15 +53,7 @@ class TestEvents(TestCase):
"""Test plain from_http""" """Test plain from_http"""
event = Event.new("unittest").from_http(self.factory.get("/")) event = Event.new("unittest").from_http(self.factory.get("/"))
self.assertEqual( self.assertEqual(
event.context, event.context, {"http_request": {"args": {}, "method": "GET", "path": "/"}}
{
"http_request": {
"args": {},
"method": "GET",
"path": "/",
"user_agent": "",
}
},
) )
def test_from_http_clean_querystring(self): def test_from_http_clean_querystring(self):
@ -75,7 +67,6 @@ class TestEvents(TestCase):
"args": {"token": SafeExceptionReporterFilter.cleansed_substitute}, "args": {"token": SafeExceptionReporterFilter.cleansed_substitute},
"method": "GET", "method": "GET",
"path": "/", "path": "/",
"user_agent": "",
} }
}, },
) )
@ -92,7 +83,6 @@ class TestEvents(TestCase):
"args": {"token": SafeExceptionReporterFilter.cleansed_substitute}, "args": {"token": SafeExceptionReporterFilter.cleansed_substitute},
"method": "GET", "method": "GET",
"path": "/", "path": "/",
"user_agent": "",
} }
}, },
) )

View file

@ -5,13 +5,12 @@ from dataclasses import asdict, is_dataclass
from datetime import date, datetime, time, timedelta from datetime import date, datetime, time, timedelta
from enum import Enum from enum import Enum
from pathlib import Path from pathlib import Path
from types import GeneratorType, NoneType from types import GeneratorType
from typing import Any, Optional from typing import Any, Optional
from uuid import UUID from uuid import UUID
from django.contrib.auth.models import AnonymousUser from django.contrib.auth.models import AnonymousUser
from django.core.handlers.wsgi import WSGIRequest from django.core.handlers.wsgi import WSGIRequest
from django.core.serializers.json import DjangoJSONEncoder
from django.db import models from django.db import models
from django.db.models.base import Model from django.db.models.base import Model
from django.http.request import HttpRequest from django.http.request import HttpRequest
@ -154,20 +153,7 @@ def sanitize_item(value: Any) -> Any:
return value.isoformat() return value.isoformat()
if isinstance(value, timedelta): if isinstance(value, timedelta):
return str(value.total_seconds()) return str(value.total_seconds())
if callable(value): return value
return {
"type": "callable",
"name": value.__name__,
"module": value.__module__,
}
# List taken from the stdlib's JSON encoder (_make_iterencode, encoder.py:415)
if isinstance(value, (bool, int, float, NoneType, list, tuple, dict)):
return value
try:
return DjangoJSONEncoder().default(value)
except TypeError:
return str(value)
return str(value)
def sanitize_dict(source: dict[Any, Any]) -> dict[Any, Any]: def sanitize_dict(source: dict[Any, Any]) -> dict[Any, Any]:

View file

@ -1,34 +0,0 @@
# Generated by Django 4.2.6 on 2023-10-28 14:24
from django.apps.registry import Apps
from django.db import migrations
from django.db.backends.base.schema import BaseDatabaseSchemaEditor
def set_oobe_flow_authentication(apps: Apps, schema_editor: BaseDatabaseSchemaEditor):
from guardian.shortcuts import get_anonymous_user
Flow = apps.get_model("authentik_flows", "Flow")
User = apps.get_model("authentik_core", "User")
db_alias = schema_editor.connection.alias
users = User.objects.using(db_alias).exclude(username="akadmin")
try:
users = users.exclude(pk=get_anonymous_user().pk)
# pylint: disable=broad-except
except Exception: # nosec
pass
if users.exists():
Flow.objects.filter(slug="initial-setup").update(authentication="require_superuser")
class Migration(migrations.Migration):
dependencies = [
("authentik_flows", "0026_alter_flow_options"),
]
operations = [
migrations.RunPython(set_oobe_flow_authentication),
]

View file

@ -167,11 +167,7 @@ class ChallengeStageView(StageView):
stage_type=self.__class__.__name__, method="get_challenge" stage_type=self.__class__.__name__, method="get_challenge"
).time(), ).time(),
): ):
try: challenge = self.get_challenge(*args, **kwargs)
challenge = self.get_challenge(*args, **kwargs)
except StageInvalidException as exc:
self.logger.debug("Got StageInvalidException", exc=exc)
return self.executor.stage_invalid()
with Hub.current.start_span( with Hub.current.start_span(
op="authentik.flow.stage._get_challenge", op="authentik.flow.stage._get_challenge",
description=self.__class__.__name__, description=self.__class__.__name__,

View file

@ -114,3 +114,8 @@ web:
worker: worker:
concurrency: 2 concurrency: 2
priority:
default: 4
scheduled: 9
sync: 9
events: 8

View file

@ -18,7 +18,7 @@ from authentik.core.api.used_by import UsedByMixin
from authentik.core.api.utils import PassiveSerializer, is_dict from authentik.core.api.utils import PassiveSerializer, is_dict
from authentik.core.models import Provider from authentik.core.models import Provider
from authentik.outposts.api.service_connections import ServiceConnectionSerializer from authentik.outposts.api.service_connections import ServiceConnectionSerializer
from authentik.outposts.apps import MANAGED_OUTPOST, MANAGED_OUTPOST_NAME from authentik.outposts.apps import MANAGED_OUTPOST
from authentik.outposts.models import ( from authentik.outposts.models import (
Outpost, Outpost,
OutpostConfig, OutpostConfig,
@ -47,16 +47,6 @@ class OutpostSerializer(ModelSerializer):
source="service_connection", read_only=True source="service_connection", read_only=True
) )
def validate_name(self, name: str) -> str:
"""Validate name (especially for embedded outpost)"""
if not self.instance:
return name
if self.instance.managed == MANAGED_OUTPOST and name != MANAGED_OUTPOST_NAME:
raise ValidationError("Embedded outpost's name cannot be changed")
if self.instance.name == MANAGED_OUTPOST_NAME:
self.instance.managed = MANAGED_OUTPOST
return name
def validate_providers(self, providers: list[Provider]) -> list[Provider]: def validate_providers(self, providers: list[Provider]) -> list[Provider]:
"""Check that all providers match the type of the outpost""" """Check that all providers match the type of the outpost"""
type_map = { type_map = {

View file

@ -15,7 +15,6 @@ GAUGE_OUTPOSTS_LAST_UPDATE = Gauge(
["outpost", "uid", "version"], ["outpost", "uid", "version"],
) )
MANAGED_OUTPOST = "goauthentik.io/outposts/embedded" MANAGED_OUTPOST = "goauthentik.io/outposts/embedded"
MANAGED_OUTPOST_NAME = "authentik Embedded Outpost"
class AuthentikOutpostConfig(ManagedAppConfig): class AuthentikOutpostConfig(ManagedAppConfig):
@ -36,17 +35,14 @@ class AuthentikOutpostConfig(ManagedAppConfig):
DockerServiceConnection, DockerServiceConnection,
KubernetesServiceConnection, KubernetesServiceConnection,
Outpost, Outpost,
OutpostConfig,
OutpostType, OutpostType,
) )
if outpost := Outpost.objects.filter(name=MANAGED_OUTPOST_NAME, managed="").first():
outpost.managed = MANAGED_OUTPOST
outpost.save()
return
outpost, updated = Outpost.objects.update_or_create( outpost, updated = Outpost.objects.update_or_create(
defaults={ defaults={
"name": "authentik Embedded Outpost",
"type": OutpostType.PROXY, "type": OutpostType.PROXY,
"name": MANAGED_OUTPOST_NAME,
}, },
managed=MANAGED_OUTPOST, managed=MANAGED_OUTPOST,
) )
@ -55,4 +51,10 @@ class AuthentikOutpostConfig(ManagedAppConfig):
outpost.service_connection = KubernetesServiceConnection.objects.first() outpost.service_connection = KubernetesServiceConnection.objects.first()
elif DockerServiceConnection.objects.exists(): elif DockerServiceConnection.objects.exists():
outpost.service_connection = DockerServiceConnection.objects.first() outpost.service_connection = DockerServiceConnection.objects.first()
outpost.config = OutpostConfig(
kubernetes_disabled_components=[
"deployment",
"secret",
]
)
outpost.save() outpost.save()

View file

@ -43,10 +43,6 @@ class DeploymentReconciler(KubernetesObjectReconciler[V1Deployment]):
self.api = AppsV1Api(controller.client) self.api = AppsV1Api(controller.client)
self.outpost = self.controller.outpost self.outpost = self.controller.outpost
@property
def noop(self) -> bool:
return self.is_embedded
@staticmethod @staticmethod
def reconciler_name() -> str: def reconciler_name() -> str:
return "deployment" return "deployment"

View file

@ -24,10 +24,6 @@ class SecretReconciler(KubernetesObjectReconciler[V1Secret]):
super().__init__(controller) super().__init__(controller)
self.api = CoreV1Api(controller.client) self.api = CoreV1Api(controller.client)
@property
def noop(self) -> bool:
return self.is_embedded
@staticmethod @staticmethod
def reconciler_name() -> str: def reconciler_name() -> str:
return "secret" return "secret"

View file

@ -77,10 +77,7 @@ class PrometheusServiceMonitorReconciler(KubernetesObjectReconciler[PrometheusSe
@property @property
def noop(self) -> bool: def noop(self) -> bool:
if not self._crd_exists(): return (not self._crd_exists()) or (self.is_embedded)
self.logger.debug("CRD doesn't exist")
return True
return self.is_embedded
def _crd_exists(self) -> bool: def _crd_exists(self) -> bool:
"""Check if the Prometheus ServiceMonitor exists""" """Check if the Prometheus ServiceMonitor exists"""

View file

@ -344,22 +344,12 @@ class Outpost(SerializerModel, ManagedModel):
user_created = False user_created = False
if not user: if not user:
user: User = User.objects.create(username=self.user_identifier) user: User = User.objects.create(username=self.user_identifier)
user_created = True
attrs = {
"type": UserTypes.INTERNAL_SERVICE_ACCOUNT,
"name": f"Outpost {self.name} Service-Account",
"path": USER_PATH_OUTPOSTS,
}
dirty = False
for key, value in attrs.items():
if getattr(user, key) != value:
dirty = True
setattr(user, key, value)
if user.has_usable_password():
user.set_unusable_password() user.set_unusable_password()
dirty = True user_created = True
if dirty: user.type = UserTypes.INTERNAL_SERVICE_ACCOUNT
user.save() user.name = f"Outpost {self.name} Service-Account"
user.path = USER_PATH_OUTPOSTS
user.save()
if user_created: if user_created:
self.build_user_permissions(user) self.build_user_permissions(user)
return user return user

View file

@ -1,27 +1,28 @@
"""Outposts Settings""" """Outposts Settings"""
from celery.schedules import crontab from celery.schedules import crontab
from authentik.lib.config import CONFIG
from authentik.lib.utils.time import fqdn_rand from authentik.lib.utils.time import fqdn_rand
CELERY_BEAT_SCHEDULE = { CELERY_BEAT_SCHEDULE = {
"outposts_controller": { "outposts_controller": {
"task": "authentik.outposts.tasks.outpost_controller_all", "task": "authentik.outposts.tasks.outpost_controller_all",
"schedule": crontab(minute=fqdn_rand("outposts_controller"), hour="*/4"), "schedule": crontab(minute=fqdn_rand("outposts_controller"), hour="*/4"),
"options": {"queue": "authentik_scheduled"}, "options": {"priority": CONFIG.get_int("worker.priority.scheduled")},
}, },
"outposts_service_connection_check": { "outposts_service_connection_check": {
"task": "authentik.outposts.tasks.outpost_service_connection_monitor", "task": "authentik.outposts.tasks.outpost_service_connection_monitor",
"schedule": crontab(minute="3-59/15"), "schedule": crontab(minute="3-59/15"),
"options": {"queue": "authentik_scheduled"}, "options": {"priority": CONFIG.get_int("worker.priority.scheduled")},
}, },
"outpost_token_ensurer": { "outpost_token_ensurer": {
"task": "authentik.outposts.tasks.outpost_token_ensurer", "task": "authentik.outposts.tasks.outpost_token_ensurer",
"schedule": crontab(minute=fqdn_rand("outpost_token_ensurer"), hour="*/8"), "schedule": crontab(minute=fqdn_rand("outpost_token_ensurer"), hour="*/8"),
"options": {"queue": "authentik_scheduled"}, "options": {"priority": CONFIG.get_int("worker.priority.scheduled")},
}, },
"outpost_connection_discovery": { "outpost_connection_discovery": {
"task": "authentik.outposts.tasks.outpost_connection_discovery", "task": "authentik.outposts.tasks.outpost_connection_discovery",
"schedule": crontab(minute=fqdn_rand("outpost_connection_discovery"), hour="*/8"), "schedule": crontab(minute=fqdn_rand("outpost_connection_discovery"), hour="*/8"),
"options": {"queue": "authentik_scheduled"}, "options": {"priority": CONFIG.get_int("worker.priority.scheduled")},
}, },
} }

View file

@ -2,13 +2,11 @@
from django.urls import reverse from django.urls import reverse
from rest_framework.test import APITestCase from rest_framework.test import APITestCase
from authentik.blueprints.tests import reconcile_app
from authentik.core.models import PropertyMapping from authentik.core.models import PropertyMapping
from authentik.core.tests.utils import create_test_admin_user, create_test_flow from authentik.core.tests.utils import create_test_admin_user, create_test_flow
from authentik.lib.generators import generate_id from authentik.lib.generators import generate_id
from authentik.outposts.api.outposts import OutpostSerializer from authentik.outposts.api.outposts import OutpostSerializer
from authentik.outposts.apps import MANAGED_OUTPOST from authentik.outposts.models import OutpostType, default_outpost_config
from authentik.outposts.models import Outpost, OutpostType, default_outpost_config
from authentik.providers.ldap.models import LDAPProvider from authentik.providers.ldap.models import LDAPProvider
from authentik.providers.proxy.models import ProxyProvider from authentik.providers.proxy.models import ProxyProvider
@ -24,36 +22,7 @@ class TestOutpostServiceConnectionsAPI(APITestCase):
self.user = create_test_admin_user() self.user = create_test_admin_user()
self.client.force_login(self.user) self.client.force_login(self.user)
@reconcile_app("authentik_outposts") def test_outpost_validaton(self):
def test_managed_name_change(self):
"""Test name change for embedded outpost"""
embedded_outpost = Outpost.objects.filter(managed=MANAGED_OUTPOST).first()
self.assertIsNotNone(embedded_outpost)
response = self.client.patch(
reverse("authentik_api:outpost-detail", kwargs={"pk": embedded_outpost.pk}),
{"name": "foo"},
)
self.assertEqual(response.status_code, 400)
self.assertJSONEqual(
response.content, {"name": ["Embedded outpost's name cannot be changed"]}
)
@reconcile_app("authentik_outposts")
def test_managed_without_managed(self):
"""Test name change for embedded outpost"""
embedded_outpost = Outpost.objects.filter(managed=MANAGED_OUTPOST).first()
self.assertIsNotNone(embedded_outpost)
embedded_outpost.managed = ""
embedded_outpost.save()
response = self.client.patch(
reverse("authentik_api:outpost-detail", kwargs={"pk": embedded_outpost.pk}),
{"name": "foo"},
)
self.assertEqual(response.status_code, 200)
embedded_outpost.refresh_from_db()
self.assertEqual(embedded_outpost.managed, MANAGED_OUTPOST)
def test_outpost_validation(self):
"""Test Outpost validation""" """Test Outpost validation"""
valid = OutpostSerializer( valid = OutpostSerializer(
data={ data={

View file

@ -1,10 +1,12 @@
"""Reputation Settings""" """Reputation Settings"""
from celery.schedules import crontab from celery.schedules import crontab
from authentik.lib.config import CONFIG
CELERY_BEAT_SCHEDULE = { CELERY_BEAT_SCHEDULE = {
"policies_reputation_save": { "policies_reputation_save": {
"task": "authentik.policies.reputation.tasks.save_reputation", "task": "authentik.policies.reputation.tasks.save_reputation",
"schedule": crontab(minute="1-59/5"), "schedule": crontab(minute="1-59/5"),
"options": {"queue": "authentik_scheduled"}, "options": {"priority": CONFIG.get_int("worker.priority.scheduled")},
}, },
} }

View file

@ -1,27 +0,0 @@
# Generated by Django 5.0 on 2023-12-22 23:20
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
("authentik_providers_oauth2", "0016_alter_refreshtoken_token"),
]
operations = [
migrations.AddField(
model_name="accesstoken",
name="session_id",
field=models.CharField(blank=True, default=""),
),
migrations.AddField(
model_name="authorizationcode",
name="session_id",
field=models.CharField(blank=True, default=""),
),
migrations.AddField(
model_name="refreshtoken",
name="session_id",
field=models.CharField(blank=True, default=""),
),
]

View file

@ -296,7 +296,6 @@ class BaseGrantModel(models.Model):
revoked = models.BooleanField(default=False) revoked = models.BooleanField(default=False)
_scope = models.TextField(default="", verbose_name=_("Scopes")) _scope = models.TextField(default="", verbose_name=_("Scopes"))
auth_time = models.DateTimeField(verbose_name="Authentication time") auth_time = models.DateTimeField(verbose_name="Authentication time")
session_id = models.CharField(default="", blank=True)
@property @property
def scope(self) -> list[str]: def scope(self) -> list[str]:

View file

@ -85,25 +85,6 @@ class TestAuthorize(OAuthTestCase):
) )
OAuthAuthorizationParams.from_request(request) OAuthAuthorizationParams.from_request(request)
def test_blocked_redirect_uri(self):
"""test missing/invalid redirect URI"""
OAuth2Provider.objects.create(
name=generate_id(),
client_id="test",
authorization_flow=create_test_flow(),
redirect_uris="data:local.invalid",
)
with self.assertRaises(RedirectUriError):
request = self.factory.get(
"/",
data={
"response_type": "code",
"client_id": "test",
"redirect_uri": "data:localhost",
},
)
OAuthAuthorizationParams.from_request(request)
def test_invalid_redirect_uri_empty(self): def test_invalid_redirect_uri_empty(self):
"""test missing/invalid redirect URI""" """test missing/invalid redirect URI"""
provider = OAuth2Provider.objects.create( provider = OAuth2Provider.objects.create(

View file

@ -1,187 +0,0 @@
"""Test token view"""
from base64 import b64encode, urlsafe_b64encode
from hashlib import sha256
from django.test import RequestFactory
from django.urls import reverse
from authentik.core.models import Application
from authentik.core.tests.utils import create_test_admin_user, create_test_flow
from authentik.flows.challenge import ChallengeTypes
from authentik.lib.generators import generate_id
from authentik.providers.oauth2.constants import GRANT_TYPE_AUTHORIZATION_CODE
from authentik.providers.oauth2.models import AuthorizationCode, OAuth2Provider
from authentik.providers.oauth2.tests.utils import OAuthTestCase
class TestTokenPKCE(OAuthTestCase):
"""Test token view"""
def setUp(self) -> None:
super().setUp()
self.factory = RequestFactory()
self.app = Application.objects.create(name=generate_id(), slug="test")
def test_pkce_missing_in_token(self):
"""Test full with pkce"""
flow = create_test_flow()
provider = OAuth2Provider.objects.create(
name=generate_id(),
client_id="test",
authorization_flow=flow,
redirect_uris="foo://localhost",
access_code_validity="seconds=100",
)
Application.objects.create(name="app", slug="app", provider=provider)
state = generate_id()
user = create_test_admin_user()
self.client.force_login(user)
challenge = generate_id()
header = b64encode(f"{provider.client_id}:{provider.client_secret}".encode()).decode()
# Step 1, initiate params and get redirect to flow
self.client.get(
reverse("authentik_providers_oauth2:authorize"),
data={
"response_type": "code",
"client_id": "test",
"state": state,
"redirect_uri": "foo://localhost",
"code_challenge": challenge,
"code_challenge_method": "S256",
},
)
response = self.client.get(
reverse("authentik_api:flow-executor", kwargs={"flow_slug": flow.slug}),
)
code: AuthorizationCode = AuthorizationCode.objects.filter(user=user).first()
self.assertJSONEqual(
response.content.decode(),
{
"component": "xak-flow-redirect",
"type": ChallengeTypes.REDIRECT.value,
"to": f"foo://localhost?code={code.code}&state={state}",
},
)
response = self.client.post(
reverse("authentik_providers_oauth2:token"),
data={
"grant_type": GRANT_TYPE_AUTHORIZATION_CODE,
"code": code.code,
# Missing the code_verifier here
"redirect_uri": "foo://localhost",
},
HTTP_AUTHORIZATION=f"Basic {header}",
)
self.assertJSONEqual(
response.content,
{"error": "invalid_request", "error_description": "The request is otherwise malformed"},
)
self.assertEqual(response.status_code, 400)
def test_pkce_correct_s256(self):
"""Test full with pkce"""
flow = create_test_flow()
provider = OAuth2Provider.objects.create(
name=generate_id(),
client_id="test",
authorization_flow=flow,
redirect_uris="foo://localhost",
access_code_validity="seconds=100",
)
Application.objects.create(name="app", slug="app", provider=provider)
state = generate_id()
user = create_test_admin_user()
self.client.force_login(user)
verifier = generate_id()
challenge = (
urlsafe_b64encode(sha256(verifier.encode("ascii")).digest())
.decode("utf-8")
.replace("=", "")
)
header = b64encode(f"{provider.client_id}:{provider.client_secret}".encode()).decode()
# Step 1, initiate params and get redirect to flow
self.client.get(
reverse("authentik_providers_oauth2:authorize"),
data={
"response_type": "code",
"client_id": "test",
"state": state,
"redirect_uri": "foo://localhost",
"code_challenge": challenge,
"code_challenge_method": "S256",
},
)
response = self.client.get(
reverse("authentik_api:flow-executor", kwargs={"flow_slug": flow.slug}),
)
code: AuthorizationCode = AuthorizationCode.objects.filter(user=user).first()
self.assertJSONEqual(
response.content.decode(),
{
"component": "xak-flow-redirect",
"type": ChallengeTypes.REDIRECT.value,
"to": f"foo://localhost?code={code.code}&state={state}",
},
)
response = self.client.post(
reverse("authentik_providers_oauth2:token"),
data={
"grant_type": GRANT_TYPE_AUTHORIZATION_CODE,
"code": code.code,
"code_verifier": verifier,
"redirect_uri": "foo://localhost",
},
HTTP_AUTHORIZATION=f"Basic {header}",
)
self.assertEqual(response.status_code, 200)
def test_pkce_correct_plain(self):
"""Test full with pkce"""
flow = create_test_flow()
provider = OAuth2Provider.objects.create(
name=generate_id(),
client_id="test",
authorization_flow=flow,
redirect_uris="foo://localhost",
access_code_validity="seconds=100",
)
Application.objects.create(name="app", slug="app", provider=provider)
state = generate_id()
user = create_test_admin_user()
self.client.force_login(user)
verifier = generate_id()
header = b64encode(f"{provider.client_id}:{provider.client_secret}".encode()).decode()
# Step 1, initiate params and get redirect to flow
self.client.get(
reverse("authentik_providers_oauth2:authorize"),
data={
"response_type": "code",
"client_id": "test",
"state": state,
"redirect_uri": "foo://localhost",
"code_challenge": verifier,
},
)
response = self.client.get(
reverse("authentik_api:flow-executor", kwargs={"flow_slug": flow.slug}),
)
code: AuthorizationCode = AuthorizationCode.objects.filter(user=user).first()
self.assertJSONEqual(
response.content.decode(),
{
"component": "xak-flow-redirect",
"type": ChallengeTypes.REDIRECT.value,
"to": f"foo://localhost?code={code.code}&state={state}",
},
)
response = self.client.post(
reverse("authentik_providers_oauth2:token"),
data={
"grant_type": GRANT_TYPE_AUTHORIZATION_CODE,
"code": code.code,
"code_verifier": verifier,
"redirect_uri": "foo://localhost",
},
HTTP_AUTHORIZATION=f"Basic {header}",
)
self.assertEqual(response.status_code, 200)

View file

@ -188,7 +188,6 @@ def authenticate_provider(request: HttpRequest) -> Optional[OAuth2Provider]:
if client_id != provider.client_id or client_secret != provider.client_secret: if client_id != provider.client_id or client_secret != provider.client_secret:
LOGGER.debug("(basic) Provider for basic auth does not exist") LOGGER.debug("(basic) Provider for basic auth does not exist")
return None return None
CTX_AUTH_VIA.set("oauth_client_secret")
return provider return provider

View file

@ -1,7 +1,6 @@
"""authentik OAuth2 Authorization views""" """authentik OAuth2 Authorization views"""
from dataclasses import dataclass, field from dataclasses import dataclass, field
from datetime import timedelta from datetime import timedelta
from hashlib import sha256
from json import dumps from json import dumps
from re import error as RegexError from re import error as RegexError
from re import fullmatch from re import fullmatch
@ -75,7 +74,6 @@ PLAN_CONTEXT_PARAMS = "goauthentik.io/providers/oauth2/params"
SESSION_KEY_LAST_LOGIN_UID = "authentik/providers/oauth2/last_login_uid" SESSION_KEY_LAST_LOGIN_UID = "authentik/providers/oauth2/last_login_uid"
ALLOWED_PROMPT_PARAMS = {PROMPT_NONE, PROMPT_CONSENT, PROMPT_LOGIN} ALLOWED_PROMPT_PARAMS = {PROMPT_NONE, PROMPT_CONSENT, PROMPT_LOGIN}
FORBIDDEN_URI_SCHEMES = {"javascript", "data", "vbscript"}
@dataclass(slots=True) @dataclass(slots=True)
@ -176,10 +174,6 @@ class OAuthAuthorizationParams:
self.check_scope() self.check_scope()
self.check_nonce() self.check_nonce()
self.check_code_challenge() self.check_code_challenge()
if self.request:
raise AuthorizeError(
self.redirect_uri, "request_not_supported", self.grant_type, self.state
)
def check_redirect_uri(self): def check_redirect_uri(self):
"""Redirect URI validation.""" """Redirect URI validation."""
@ -217,9 +211,10 @@ class OAuthAuthorizationParams:
expected=allowed_redirect_urls, expected=allowed_redirect_urls,
) )
raise RedirectUriError(self.redirect_uri, allowed_redirect_urls) raise RedirectUriError(self.redirect_uri, allowed_redirect_urls)
# Check against forbidden schemes if self.request:
if urlparse(self.redirect_uri).scheme in FORBIDDEN_URI_SCHEMES: raise AuthorizeError(
raise RedirectUriError(self.redirect_uri, allowed_redirect_urls) self.redirect_uri, "request_not_supported", self.grant_type, self.state
)
def check_scope(self): def check_scope(self):
"""Ensure openid scope is set in Hybrid flows, or when requesting an id_token""" """Ensure openid scope is set in Hybrid flows, or when requesting an id_token"""
@ -287,7 +282,6 @@ class OAuthAuthorizationParams:
expires=now + timedelta_from_string(self.provider.access_code_validity), expires=now + timedelta_from_string(self.provider.access_code_validity),
scope=self.scope, scope=self.scope,
nonce=self.nonce, nonce=self.nonce,
session_id=sha256(request.session.session_key.encode("ascii")).hexdigest(),
) )
if self.code_challenge and self.code_challenge_method: if self.code_challenge and self.code_challenge_method:
@ -575,7 +569,6 @@ class OAuthFulfillmentStage(StageView):
expires=access_token_expiry, expires=access_token_expiry,
provider=self.provider, provider=self.provider,
auth_time=auth_event.created if auth_event else now, auth_time=auth_event.created if auth_event else now,
session_id=sha256(self.request.session.session_key.encode("ascii")).hexdigest(),
) )
id_token = IDToken.new(self.provider, token, self.request) id_token = IDToken.new(self.provider, token, self.request)

View file

@ -6,7 +6,6 @@ from hashlib import sha256
from re import error as RegexError from re import error as RegexError
from re import fullmatch from re import fullmatch
from typing import Any, Optional from typing import Any, Optional
from urllib.parse import urlparse
from django.http import HttpRequest, HttpResponse from django.http import HttpRequest, HttpResponse
from django.utils import timezone from django.utils import timezone
@ -18,7 +17,6 @@ from jwt import PyJWK, PyJWT, PyJWTError, decode
from sentry_sdk.hub import Hub from sentry_sdk.hub import Hub
from structlog.stdlib import get_logger from structlog.stdlib import get_logger
from authentik.core.middleware import CTX_AUTH_VIA
from authentik.core.models import ( from authentik.core.models import (
USER_ATTRIBUTE_EXPIRES, USER_ATTRIBUTE_EXPIRES,
USER_ATTRIBUTE_GENERATED, USER_ATTRIBUTE_GENERATED,
@ -55,7 +53,6 @@ from authentik.providers.oauth2.models import (
RefreshToken, RefreshToken,
) )
from authentik.providers.oauth2.utils import TokenResponse, cors_allow, extract_client_auth from authentik.providers.oauth2.utils import TokenResponse, cors_allow, extract_client_auth
from authentik.providers.oauth2.views.authorize import FORBIDDEN_URI_SCHEMES
from authentik.sources.oauth.models import OAuthSource from authentik.sources.oauth.models import OAuthSource
from authentik.stages.password.stage import PLAN_CONTEXT_METHOD, PLAN_CONTEXT_METHOD_ARGS from authentik.stages.password.stage import PLAN_CONTEXT_METHOD, PLAN_CONTEXT_METHOD_ARGS
@ -207,10 +204,6 @@ class TokenParams:
).from_http(request) ).from_http(request)
raise TokenError("invalid_client") raise TokenError("invalid_client")
# Check against forbidden schemes
if urlparse(self.redirect_uri).scheme in FORBIDDEN_URI_SCHEMES:
raise TokenError("invalid_request")
self.authorization_code = AuthorizationCode.objects.filter(code=raw_code).first() self.authorization_code = AuthorizationCode.objects.filter(code=raw_code).first()
if not self.authorization_code: if not self.authorization_code:
LOGGER.warning("Code does not exist", code=raw_code) LOGGER.warning("Code does not exist", code=raw_code)
@ -228,10 +221,7 @@ class TokenParams:
raise TokenError("invalid_grant") raise TokenError("invalid_grant")
# Validate PKCE parameters. # Validate PKCE parameters.
if self.authorization_code.code_challenge: if self.code_verifier:
# Authorization code had PKCE but we didn't get one
if not self.code_verifier:
raise TokenError("invalid_request")
if self.authorization_code.code_challenge_method == PKCE_METHOD_S256: if self.authorization_code.code_challenge_method == PKCE_METHOD_S256:
new_code_challenge = ( new_code_challenge = (
urlsafe_b64encode(sha256(self.code_verifier.encode("ascii")).digest()) urlsafe_b64encode(sha256(self.code_verifier.encode("ascii")).digest())
@ -458,7 +448,6 @@ class TokenView(View):
if not self.provider: if not self.provider:
LOGGER.warning("OAuth2Provider does not exist", client_id=client_id) LOGGER.warning("OAuth2Provider does not exist", client_id=client_id)
raise TokenError("invalid_client") raise TokenError("invalid_client")
CTX_AUTH_VIA.set("oauth_client_secret")
self.params = TokenParams.parse(request, self.provider, client_id, client_secret) self.params = TokenParams.parse(request, self.provider, client_id, client_secret)
with Hub.current.start_span( with Hub.current.start_span(
@ -493,7 +482,6 @@ class TokenView(View):
# Keep same scopes as previous token # Keep same scopes as previous token
scope=self.params.authorization_code.scope, scope=self.params.authorization_code.scope,
auth_time=self.params.authorization_code.auth_time, auth_time=self.params.authorization_code.auth_time,
session_id=self.params.authorization_code.session_id,
) )
access_token.id_token = IDToken.new( access_token.id_token = IDToken.new(
self.provider, self.provider,
@ -509,7 +497,6 @@ class TokenView(View):
expires=refresh_token_expiry, expires=refresh_token_expiry,
provider=self.provider, provider=self.provider,
auth_time=self.params.authorization_code.auth_time, auth_time=self.params.authorization_code.auth_time,
session_id=self.params.authorization_code.session_id,
) )
id_token = IDToken.new( id_token = IDToken.new(
self.provider, self.provider,
@ -547,7 +534,6 @@ class TokenView(View):
# Keep same scopes as previous token # Keep same scopes as previous token
scope=self.params.refresh_token.scope, scope=self.params.refresh_token.scope,
auth_time=self.params.refresh_token.auth_time, auth_time=self.params.refresh_token.auth_time,
session_id=self.params.refresh_token.session_id,
) )
access_token.id_token = IDToken.new( access_token.id_token = IDToken.new(
self.provider, self.provider,
@ -563,7 +549,6 @@ class TokenView(View):
expires=refresh_token_expiry, expires=refresh_token_expiry,
provider=self.provider, provider=self.provider,
auth_time=self.params.refresh_token.auth_time, auth_time=self.params.refresh_token.auth_time,
session_id=self.params.refresh_token.session_id,
) )
id_token = IDToken.new( id_token = IDToken.new(
self.provider, self.provider,

View file

@ -1,6 +1,4 @@
"""proxy provider tasks""" """proxy provider tasks"""
from hashlib import sha256
from asgiref.sync import async_to_sync from asgiref.sync import async_to_sync
from channels.layers import get_channel_layer from channels.layers import get_channel_layer
from django.db import DatabaseError, InternalError, ProgrammingError from django.db import DatabaseError, InternalError, ProgrammingError
@ -25,7 +23,6 @@ def proxy_set_defaults():
def proxy_on_logout(session_id: str): def proxy_on_logout(session_id: str):
"""Update outpost instances connected to a single outpost""" """Update outpost instances connected to a single outpost"""
layer = get_channel_layer() layer = get_channel_layer()
hashed_session_id = sha256(session_id.encode("ascii")).hexdigest()
for outpost in Outpost.objects.filter(type=OutpostType.PROXY): for outpost in Outpost.objects.filter(type=OutpostType.PROXY):
group = OUTPOST_GROUP % {"outpost_pk": str(outpost.pk)} group = OUTPOST_GROUP % {"outpost_pk": str(outpost.pk)}
async_to_sync(layer.group_send)( async_to_sync(layer.group_send)(
@ -33,6 +30,6 @@ def proxy_on_logout(session_id: str):
{ {
"type": "event.provider.specific", "type": "event.provider.specific",
"sub_type": "logout", "sub_type": "logout",
"session_id": hashed_session_id, "session_id": session_id,
}, },
) )

View file

@ -21,7 +21,6 @@ class RadiusProviderSerializer(ProviderSerializer):
# an admin might have to view it # an admin might have to view it
"shared_secret", "shared_secret",
"outpost_set", "outpost_set",
"mfa_support",
] ]
extra_kwargs = ProviderSerializer.Meta.extra_kwargs extra_kwargs = ProviderSerializer.Meta.extra_kwargs
@ -56,7 +55,6 @@ class RadiusOutpostConfigSerializer(ModelSerializer):
"auth_flow_slug", "auth_flow_slug",
"client_networks", "client_networks",
"shared_secret", "shared_secret",
"mfa_support",
] ]

View file

@ -1,21 +0,0 @@
# Generated by Django 4.2.6 on 2023-10-18 15:09
from django.db import migrations, models
class Migration(migrations.Migration):
dependencies = [
("authentik_providers_radius", "0001_initial"),
]
operations = [
migrations.AddField(
model_name="radiusprovider",
name="mfa_support",
field=models.BooleanField(
default=True,
help_text="When enabled, code-based multi-factor authentication can be used by appending a semicolon and the TOTP code to the password. This should only be enabled if all users that will bind to this provider have a TOTP device configured, as otherwise a password may incorrectly be rejected if it contains a semicolon.",
verbose_name="MFA Support",
),
),
]

View file

@ -27,17 +27,6 @@ class RadiusProvider(OutpostModel, Provider):
), ),
) )
mfa_support = models.BooleanField(
default=True,
verbose_name="MFA Support",
help_text=_(
"When enabled, code-based multi-factor authentication can be used by appending a "
"semicolon and the TOTP code to the password. This should only be enabled if all "
"users that will bind to this provider have a TOTP device configured, as otherwise "
"a password may incorrectly be rejected if it contains a semicolon."
),
)
@property @property
def launch_url(self) -> Optional[str]: def launch_url(self) -> Optional[str]:
"""Radius never has a launch URL""" """Radius never has a launch URL"""

View file

@ -46,9 +46,7 @@ class SCIMGroupClient(SCIMClient[Group, SCIMGroupSchema]):
def to_scim(self, obj: Group) -> SCIMGroupSchema: def to_scim(self, obj: Group) -> SCIMGroupSchema:
"""Convert authentik user into SCIM""" """Convert authentik user into SCIM"""
raw_scim_group = { raw_scim_group = {}
"schemas": ("urn:ietf:params:scim:schemas:core:2.0:Group",),
}
for mapping in ( for mapping in (
self.provider.property_mappings_group.all().order_by("name").select_subclasses() self.provider.property_mappings_group.all().order_by("name").select_subclasses()
): ):

View file

@ -15,14 +15,12 @@ from pydanticscim.user import User as BaseUser
class User(BaseUser): class User(BaseUser):
"""Modified User schema with added externalId field""" """Modified User schema with added externalId field"""
schemas: tuple[str] = ("urn:ietf:params:scim:schemas:core:2.0:User",)
externalId: Optional[str] = None externalId: Optional[str] = None
class Group(BaseGroup): class Group(BaseGroup):
"""Modified Group schema with added externalId field""" """Modified Group schema with added externalId field"""
schemas: tuple[str] = ("urn:ietf:params:scim:schemas:core:2.0:Group",)
externalId: Optional[str] = None externalId: Optional[str] = None

View file

@ -39,9 +39,7 @@ class SCIMUserClient(SCIMClient[User, SCIMUserSchema]):
def to_scim(self, obj: User) -> SCIMUserSchema: def to_scim(self, obj: User) -> SCIMUserSchema:
"""Convert authentik user into SCIM""" """Convert authentik user into SCIM"""
raw_scim_user = { raw_scim_user = {}
"schemas": ("urn:ietf:params:scim:schemas:core:2.0:User",),
}
for mapping in self.provider.property_mappings.all().order_by("name").select_subclasses(): for mapping in self.provider.property_mappings.all().order_by("name").select_subclasses():
if not isinstance(mapping, SCIMMapping): if not isinstance(mapping, SCIMMapping):
continue continue

View file

@ -1,12 +1,13 @@
"""SCIM task Settings""" """SCIM task Settings"""
from celery.schedules import crontab from celery.schedules import crontab
from authentik.lib.config import CONFIG
from authentik.lib.utils.time import fqdn_rand from authentik.lib.utils.time import fqdn_rand
CELERY_BEAT_SCHEDULE = { CELERY_BEAT_SCHEDULE = {
"providers_scim_sync": { "providers_scim_sync": {
"task": "authentik.providers.scim.tasks.scim_sync_all", "task": "authentik.providers.scim.tasks.scim_sync_all",
"schedule": crontab(minute=fqdn_rand("scim_sync_all"), hour="*"), "schedule": crontab(minute=fqdn_rand("scim_sync_all"), hour="*"),
"options": {"queue": "authentik_scheduled"}, "options": {"priority": CONFIG.get_int("worker.priority.scheduled")},
}, },
} }

View file

@ -61,11 +61,7 @@ class SCIMGroupTests(TestCase):
self.assertEqual(mock.request_history[1].method, "POST") self.assertEqual(mock.request_history[1].method, "POST")
self.assertJSONEqual( self.assertJSONEqual(
mock.request_history[1].body, mock.request_history[1].body,
{ {"externalId": str(group.pk), "displayName": group.name},
"schemas": ["urn:ietf:params:scim:schemas:core:2.0:Group"],
"externalId": str(group.pk),
"displayName": group.name,
},
) )
@Mocker() @Mocker()
@ -100,11 +96,7 @@ class SCIMGroupTests(TestCase):
validate(body, loads(schema.read())) validate(body, loads(schema.read()))
self.assertEqual( self.assertEqual(
body, body,
{ {"externalId": str(group.pk), "displayName": group.name},
"schemas": ["urn:ietf:params:scim:schemas:core:2.0:Group"],
"externalId": str(group.pk),
"displayName": group.name,
},
) )
group.save() group.save()
self.assertEqual(mock.call_count, 4) self.assertEqual(mock.call_count, 4)
@ -137,11 +129,7 @@ class SCIMGroupTests(TestCase):
self.assertEqual(mock.request_history[1].method, "POST") self.assertEqual(mock.request_history[1].method, "POST")
self.assertJSONEqual( self.assertJSONEqual(
mock.request_history[1].body, mock.request_history[1].body,
{ {"externalId": str(group.pk), "displayName": group.name},
"schemas": ["urn:ietf:params:scim:schemas:core:2.0:Group"],
"externalId": str(group.pk),
"displayName": group.name,
},
) )
group.delete() group.delete()
self.assertEqual(mock.call_count, 4) self.assertEqual(mock.call_count, 4)

View file

@ -89,22 +89,17 @@ class SCIMMembershipTests(TestCase):
self.assertJSONEqual( self.assertJSONEqual(
mocker.request_history[3].body, mocker.request_history[3].body,
{ {
"schemas": ["urn:ietf:params:scim:schemas:core:2.0:User"],
"emails": [], "emails": [],
"active": True, "active": True,
"externalId": user.uid, "externalId": user.uid,
"name": {"familyName": " ", "formatted": " ", "givenName": ""}, "name": {"familyName": "", "formatted": "", "givenName": ""},
"displayName": "", "displayName": "",
"userName": user.username, "userName": user.username,
}, },
) )
self.assertJSONEqual( self.assertJSONEqual(
mocker.request_history[5].body, mocker.request_history[5].body,
{ {"externalId": str(group.pk), "displayName": group.name},
"schemas": ["urn:ietf:params:scim:schemas:core:2.0:Group"],
"externalId": str(group.pk),
"displayName": group.name,
},
) )
with Mocker() as mocker: with Mocker() as mocker:
@ -123,7 +118,6 @@ class SCIMMembershipTests(TestCase):
self.assertJSONEqual( self.assertJSONEqual(
mocker.request_history[1].body, mocker.request_history[1].body,
{ {
"schemas": ["urn:ietf:params:scim:api:messages:2.0:PatchOp"],
"Operations": [ "Operations": [
{ {
"op": "add", "op": "add",
@ -131,6 +125,7 @@ class SCIMMembershipTests(TestCase):
"value": [{"value": user_scim_id}], "value": [{"value": user_scim_id}],
} }
], ],
"schemas": ["urn:ietf:params:scim:api:messages:2.0:PatchOp"],
}, },
) )
@ -179,22 +174,17 @@ class SCIMMembershipTests(TestCase):
self.assertJSONEqual( self.assertJSONEqual(
mocker.request_history[3].body, mocker.request_history[3].body,
{ {
"schemas": ["urn:ietf:params:scim:schemas:core:2.0:User"],
"active": True, "active": True,
"displayName": "", "displayName": "",
"emails": [], "emails": [],
"externalId": user.uid, "externalId": user.uid,
"name": {"familyName": " ", "formatted": " ", "givenName": ""}, "name": {"familyName": "", "formatted": "", "givenName": ""},
"userName": user.username, "userName": user.username,
}, },
) )
self.assertJSONEqual( self.assertJSONEqual(
mocker.request_history[5].body, mocker.request_history[5].body,
{ {"externalId": str(group.pk), "displayName": group.name},
"schemas": ["urn:ietf:params:scim:schemas:core:2.0:Group"],
"externalId": str(group.pk),
"displayName": group.name,
},
) )
with Mocker() as mocker: with Mocker() as mocker:
@ -213,7 +203,6 @@ class SCIMMembershipTests(TestCase):
self.assertJSONEqual( self.assertJSONEqual(
mocker.request_history[1].body, mocker.request_history[1].body,
{ {
"schemas": ["urn:ietf:params:scim:api:messages:2.0:PatchOp"],
"Operations": [ "Operations": [
{ {
"op": "add", "op": "add",
@ -221,6 +210,7 @@ class SCIMMembershipTests(TestCase):
"value": [{"value": user_scim_id}], "value": [{"value": user_scim_id}],
} }
], ],
"schemas": ["urn:ietf:params:scim:api:messages:2.0:PatchOp"],
}, },
) )
@ -240,7 +230,6 @@ class SCIMMembershipTests(TestCase):
self.assertJSONEqual( self.assertJSONEqual(
mocker.request_history[1].body, mocker.request_history[1].body,
{ {
"schemas": ["urn:ietf:params:scim:api:messages:2.0:PatchOp"],
"Operations": [ "Operations": [
{ {
"op": "remove", "op": "remove",
@ -248,5 +237,6 @@ class SCIMMembershipTests(TestCase):
"value": [{"value": user_scim_id}], "value": [{"value": user_scim_id}],
} }
], ],
"schemas": ["urn:ietf:params:scim:api:messages:2.0:PatchOp"],
}, },
) )

View file

@ -57,7 +57,7 @@ class SCIMUserTests(TestCase):
uid = generate_id() uid = generate_id()
user = User.objects.create( user = User.objects.create(
username=uid, username=uid,
name=f"{uid} {uid}", name=uid,
email=f"{uid}@goauthentik.io", email=f"{uid}@goauthentik.io",
) )
self.assertEqual(mock.call_count, 2) self.assertEqual(mock.call_count, 2)
@ -66,7 +66,6 @@ class SCIMUserTests(TestCase):
self.assertJSONEqual( self.assertJSONEqual(
mock.request_history[1].body, mock.request_history[1].body,
{ {
"schemas": ["urn:ietf:params:scim:schemas:core:2.0:User"],
"active": True, "active": True,
"emails": [ "emails": [
{ {
@ -77,11 +76,11 @@ class SCIMUserTests(TestCase):
], ],
"externalId": user.uid, "externalId": user.uid,
"name": { "name": {
"familyName": uid, "familyName": "",
"formatted": f"{uid} {uid}", "formatted": uid,
"givenName": uid, "givenName": uid,
}, },
"displayName": f"{uid} {uid}", "displayName": uid,
"userName": uid, "userName": uid,
}, },
) )
@ -110,7 +109,7 @@ class SCIMUserTests(TestCase):
uid = generate_id() uid = generate_id()
user = User.objects.create( user = User.objects.create(
username=uid, username=uid,
name=f"{uid} {uid}", name=uid,
email=f"{uid}@goauthentik.io", email=f"{uid}@goauthentik.io",
) )
self.assertEqual(mock.call_count, 2) self.assertEqual(mock.call_count, 2)
@ -122,7 +121,6 @@ class SCIMUserTests(TestCase):
self.assertEqual( self.assertEqual(
body, body,
{ {
"schemas": ["urn:ietf:params:scim:schemas:core:2.0:User"],
"active": True, "active": True,
"emails": [ "emails": [
{ {
@ -131,11 +129,11 @@ class SCIMUserTests(TestCase):
"value": f"{uid}@goauthentik.io", "value": f"{uid}@goauthentik.io",
} }
], ],
"displayName": f"{uid} {uid}", "displayName": uid,
"externalId": user.uid, "externalId": user.uid,
"name": { "name": {
"familyName": uid, "familyName": "",
"formatted": f"{uid} {uid}", "formatted": uid,
"givenName": uid, "givenName": uid,
}, },
"userName": uid, "userName": uid,
@ -166,7 +164,7 @@ class SCIMUserTests(TestCase):
uid = generate_id() uid = generate_id()
user = User.objects.create( user = User.objects.create(
username=uid, username=uid,
name=f"{uid} {uid}", name=uid,
email=f"{uid}@goauthentik.io", email=f"{uid}@goauthentik.io",
) )
self.assertEqual(mock.call_count, 2) self.assertEqual(mock.call_count, 2)
@ -175,7 +173,6 @@ class SCIMUserTests(TestCase):
self.assertJSONEqual( self.assertJSONEqual(
mock.request_history[1].body, mock.request_history[1].body,
{ {
"schemas": ["urn:ietf:params:scim:schemas:core:2.0:User"],
"active": True, "active": True,
"emails": [ "emails": [
{ {
@ -186,11 +183,11 @@ class SCIMUserTests(TestCase):
], ],
"externalId": user.uid, "externalId": user.uid,
"name": { "name": {
"familyName": uid, "familyName": "",
"formatted": f"{uid} {uid}", "formatted": uid,
"givenName": uid, "givenName": uid,
}, },
"displayName": f"{uid} {uid}", "displayName": uid,
"userName": uid, "userName": uid,
}, },
) )
@ -230,7 +227,7 @@ class SCIMUserTests(TestCase):
) )
user = User.objects.create( user = User.objects.create(
username=uid, username=uid,
name=f"{uid} {uid}", name=uid,
email=f"{uid}@goauthentik.io", email=f"{uid}@goauthentik.io",
) )
@ -243,7 +240,6 @@ class SCIMUserTests(TestCase):
self.assertJSONEqual( self.assertJSONEqual(
mock.request_history[1].body, mock.request_history[1].body,
{ {
"schemas": ["urn:ietf:params:scim:schemas:core:2.0:User"],
"active": True, "active": True,
"emails": [ "emails": [
{ {
@ -254,11 +250,11 @@ class SCIMUserTests(TestCase):
], ],
"externalId": user.uid, "externalId": user.uid,
"name": { "name": {
"familyName": uid, "familyName": "",
"formatted": f"{uid} {uid}", "formatted": uid,
"givenName": uid, "givenName": uid,
}, },
"displayName": f"{uid} {uid}", "displayName": uid,
"userName": uid, "userName": uid,
}, },
) )

View file

@ -32,19 +32,13 @@ class PermissionSerializer(ModelSerializer):
def get_app_label_verbose(self, instance: Permission) -> str: def get_app_label_verbose(self, instance: Permission) -> str:
"""Human-readable app label""" """Human-readable app label"""
try: return apps.get_app_config(instance.content_type.app_label).verbose_name
return apps.get_app_config(instance.content_type.app_label).verbose_name
except LookupError:
return f"{instance.content_type.app_label}.{instance.content_type.model}"
def get_model_verbose(self, instance: Permission) -> str: def get_model_verbose(self, instance: Permission) -> str:
"""Human-readable model name""" """Human-readable model name"""
try: return apps.get_model(
return apps.get_model( instance.content_type.app_label, instance.content_type.model
instance.content_type.app_label, instance.content_type.model )._meta.verbose_name
)._meta.verbose_name
except LookupError:
return f"{instance.content_type.app_label}.{instance.content_type.model}"
class Meta: class Meta:
model = Permission model = Permission

View file

@ -24,19 +24,13 @@ class ExtraRoleObjectPermissionSerializer(RoleObjectPermissionSerializer):
def get_app_label_verbose(self, instance: GroupObjectPermission) -> str: def get_app_label_verbose(self, instance: GroupObjectPermission) -> str:
"""Get app label from permission's model""" """Get app label from permission's model"""
try: return apps.get_app_config(instance.content_type.app_label).verbose_name
return apps.get_app_config(instance.content_type.app_label).verbose_name
except LookupError:
return instance.content_type.app_label
def get_model_verbose(self, instance: GroupObjectPermission) -> str: def get_model_verbose(self, instance: GroupObjectPermission) -> str:
"""Get model label from permission's model""" """Get model label from permission's model"""
try: return apps.get_model(
return apps.get_model( instance.content_type.app_label, instance.content_type.model
instance.content_type.app_label, instance.content_type.model )._meta.verbose_name
)._meta.verbose_name
except LookupError:
return f"{instance.content_type.app_label}.{instance.content_type.model}"
def get_object_description(self, instance: GroupObjectPermission) -> Optional[str]: def get_object_description(self, instance: GroupObjectPermission) -> Optional[str]:
"""Get model description from attached model. This operation takes at least """Get model description from attached model. This operation takes at least
@ -44,10 +38,7 @@ class ExtraRoleObjectPermissionSerializer(RoleObjectPermissionSerializer):
view_ permission on the object""" view_ permission on the object"""
app_label = instance.content_type.app_label app_label = instance.content_type.app_label
model = instance.content_type.model model = instance.content_type.model
try: model_class = apps.get_model(app_label, model)
model_class = apps.get_model(app_label, model)
except LookupError:
return None
objects = get_objects_for_group(instance.group, f"{app_label}.view_{model}", model_class) objects = get_objects_for_group(instance.group, f"{app_label}.view_{model}", model_class)
obj = objects.first() obj = objects.first()
if not obj: if not obj:

View file

@ -24,19 +24,13 @@ class ExtraUserObjectPermissionSerializer(UserObjectPermissionSerializer):
def get_app_label_verbose(self, instance: UserObjectPermission) -> str: def get_app_label_verbose(self, instance: UserObjectPermission) -> str:
"""Get app label from permission's model""" """Get app label from permission's model"""
try: return apps.get_app_config(instance.content_type.app_label).verbose_name
return apps.get_app_config(instance.content_type.app_label).verbose_name
except LookupError:
return instance.content_type.app_label
def get_model_verbose(self, instance: UserObjectPermission) -> str: def get_model_verbose(self, instance: UserObjectPermission) -> str:
"""Get model label from permission's model""" """Get model label from permission's model"""
try: return apps.get_model(
return apps.get_model( instance.content_type.app_label, instance.content_type.model
instance.content_type.app_label, instance.content_type.model )._meta.verbose_name
)._meta.verbose_name
except LookupError:
return f"{instance.content_type.app_label}.{instance.content_type.model}"
def get_object_description(self, instance: UserObjectPermission) -> Optional[str]: def get_object_description(self, instance: UserObjectPermission) -> Optional[str]:
"""Get model description from attached model. This operation takes at least """Get model description from attached model. This operation takes at least
@ -44,10 +38,7 @@ class ExtraUserObjectPermissionSerializer(UserObjectPermissionSerializer):
view_ permission on the object""" view_ permission on the object"""
app_label = instance.content_type.app_label app_label = instance.content_type.app_label
model = instance.content_type.model model = instance.content_type.model
try: model_class = apps.get_model(app_label, model)
model_class = apps.get_model(app_label, model)
except LookupError:
return None
objects = get_objects_for_user(instance.user, f"{app_label}.view_{model}", model_class) objects = get_objects_for_user(instance.user, f"{app_label}.view_{model}", model_class)
obj = objects.first() obj = objects.first()
if not obj: if not obj:

View file

@ -339,17 +339,23 @@ CELERY = {
"clean_expired_models": { "clean_expired_models": {
"task": "authentik.core.tasks.clean_expired_models", "task": "authentik.core.tasks.clean_expired_models",
"schedule": crontab(minute="2-59/5"), "schedule": crontab(minute="2-59/5"),
"options": {"queue": "authentik_scheduled"}, "options": {"priority": CONFIG.get_int("worker.priority.scheduled")},
}, },
"user_cleanup": { "user_cleanup": {
"task": "authentik.core.tasks.clean_temporary_users", "task": "authentik.core.tasks.clean_temporary_users",
"schedule": crontab(minute="9-59/5"), "schedule": crontab(minute="9-59/5"),
"options": {"queue": "authentik_scheduled"}, "options": {"priority": CONFIG.get_int("worker.priority.scheduled")},
}, },
}, },
"task_create_missing_queues": True, "task_create_missing_queues": True,
"task_default_queue": "authentik", "task_default_queue": "authentik",
"task_default_priority": 4,
"task_inherit_parent_priority": True,
"broker_url": f"{_redis_url}/{CONFIG.get('redis.db')}{_redis_celery_tls_requirements}", "broker_url": f"{_redis_url}/{CONFIG.get('redis.db')}{_redis_celery_tls_requirements}",
"broker_transport_options": {
"queue_order_strategy": "priority",
"priority_steps": list(range(10)),
},
"result_backend": f"{_redis_url}/{CONFIG.get('redis.db')}{_redis_celery_tls_requirements}", "result_backend": f"{_redis_url}/{CONFIG.get('redis.db')}{_redis_celery_tls_requirements}",
} }

View file

@ -1,12 +1,13 @@
"""LDAP Settings""" """LDAP Settings"""
from celery.schedules import crontab from celery.schedules import crontab
from authentik.lib.config import CONFIG
from authentik.lib.utils.time import fqdn_rand from authentik.lib.utils.time import fqdn_rand
CELERY_BEAT_SCHEDULE = { CELERY_BEAT_SCHEDULE = {
"sources_ldap_sync": { "sources_ldap_sync": {
"task": "authentik.sources.ldap.tasks.ldap_sync_all", "task": "authentik.sources.ldap.tasks.ldap_sync_all",
"schedule": crontab(minute=fqdn_rand("sources_ldap_sync"), hour="*/2"), "schedule": crontab(minute=fqdn_rand("sources_ldap_sync"), hour="*/2"),
"options": {"queue": "authentik_scheduled"}, "options": {"priority": CONFIG.get_int("worker.priority.scheduled")},
} }
} }

View file

@ -58,10 +58,12 @@ def ldap_sync_single(source_pk: str):
group( group(
ldap_sync_paginator(source, UserLDAPSynchronizer) ldap_sync_paginator(source, UserLDAPSynchronizer)
+ ldap_sync_paginator(source, GroupLDAPSynchronizer), + ldap_sync_paginator(source, GroupLDAPSynchronizer),
priority=CONFIG.get_int("worker.priority.sync"),
), ),
# Membership sync needs to run afterwards # Membership sync needs to run afterwards
group( group(
ldap_sync_paginator(source, MembershipLDAPSynchronizer), ldap_sync_paginator(source, MembershipLDAPSynchronizer),
priority=CONFIG.get_int("worker.priority.sync"),
), ),
) )
task() task()

View file

@ -30,8 +30,6 @@ class SourceTypeSerializer(PassiveSerializer):
authorization_url = CharField(read_only=True, allow_null=True) authorization_url = CharField(read_only=True, allow_null=True)
access_token_url = CharField(read_only=True, allow_null=True) access_token_url = CharField(read_only=True, allow_null=True)
profile_url = CharField(read_only=True, allow_null=True) profile_url = CharField(read_only=True, allow_null=True)
oidc_well_known_url = CharField(read_only=True, allow_null=True)
oidc_jwks_url = CharField(read_only=True, allow_null=True)
class OAuthSourceSerializer(SourceSerializer): class OAuthSourceSerializer(SourceSerializer):
@ -54,15 +52,11 @@ class OAuthSourceSerializer(SourceSerializer):
@extend_schema_field(SourceTypeSerializer) @extend_schema_field(SourceTypeSerializer)
def get_type(self, instance: OAuthSource) -> SourceTypeSerializer: def get_type(self, instance: OAuthSource) -> SourceTypeSerializer:
"""Get source's type configuration""" """Get source's type configuration"""
return SourceTypeSerializer(instance.source_type).data return SourceTypeSerializer(instance.type).data
def validate(self, attrs: dict) -> dict: def validate(self, attrs: dict) -> dict:
session = get_http_session() session = get_http_session()
source_type = registry.find_type(attrs["provider_type"]) well_known = attrs.get("oidc_well_known_url")
well_known = attrs.get("oidc_well_known_url") or source_type.oidc_well_known_url
inferred_oidc_jwks_url = None
if well_known and well_known != "": if well_known and well_known != "":
try: try:
well_known_config = session.get(well_known) well_known_config = session.get(well_known)
@ -71,23 +65,24 @@ class OAuthSourceSerializer(SourceSerializer):
text = exc.response.text if exc.response else str(exc) text = exc.response.text if exc.response else str(exc)
raise ValidationError({"oidc_well_known_url": text}) raise ValidationError({"oidc_well_known_url": text})
config = well_known_config.json() config = well_known_config.json()
if "issuer" not in config: try:
raise ValidationError({"oidc_well_known_url": "Invalid well-known configuration"}) attrs["authorization_url"] = config["authorization_endpoint"]
attrs["authorization_url"] = config.get("authorization_endpoint", "") attrs["access_token_url"] = config["token_endpoint"]
attrs["access_token_url"] = config.get("token_endpoint", "") attrs["profile_url"] = config["userinfo_endpoint"]
attrs["profile_url"] = config.get("userinfo_endpoint", "") attrs["oidc_jwks_url"] = config["jwks_uri"]
inferred_oidc_jwks_url = config.get("jwks_uri", "") except (IndexError, KeyError) as exc:
raise ValidationError(
{"oidc_well_known_url": f"Invalid well-known configuration: {exc}"}
)
# Prefer user-entered URL to inferred URL to default URL jwks_url = attrs.get("oidc_jwks_url")
jwks_url = attrs.get("oidc_jwks_url") or inferred_oidc_jwks_url or source_type.oidc_jwks_url
if jwks_url and jwks_url != "": if jwks_url and jwks_url != "":
attrs["oidc_jwks_url"] = jwks_url
try: try:
jwks_config = session.get(jwks_url) jwks_config = session.get(jwks_url)
jwks_config.raise_for_status() jwks_config.raise_for_status()
except RequestException as exc: except RequestException as exc:
text = exc.response.text if exc.response else str(exc) text = exc.response.text if exc.response else str(exc)
raise ValidationError({"oidc_jwks_url": text}) raise ValidationError({"jwks_url": text})
config = jwks_config.json() config = jwks_config.json()
attrs["oidc_jwks"] = config attrs["oidc_jwks"] = config

View file

@ -36,8 +36,8 @@ class BaseOAuthClient:
def get_profile_info(self, token: dict[str, str]) -> Optional[dict[str, Any]]: def get_profile_info(self, token: dict[str, str]) -> Optional[dict[str, Any]]:
"""Fetch user profile information.""" """Fetch user profile information."""
profile_url = self.source.source_type.profile_url or "" profile_url = self.source.type.profile_url or ""
if self.source.source_type.urls_customizable and self.source.profile_url: if self.source.type.urls_customizable and self.source.profile_url:
profile_url = self.source.profile_url profile_url = self.source.profile_url
response = self.do_request("get", profile_url, token=token) response = self.do_request("get", profile_url, token=token)
try: try:
@ -57,8 +57,8 @@ class BaseOAuthClient:
def get_redirect_url(self, parameters=None): def get_redirect_url(self, parameters=None):
"""Build authentication redirect url.""" """Build authentication redirect url."""
authorization_url = self.source.source_type.authorization_url or "" authorization_url = self.source.type.authorization_url or ""
if self.source.source_type.urls_customizable and self.source.authorization_url: if self.source.type.urls_customizable and self.source.authorization_url:
authorization_url = self.source.authorization_url authorization_url = self.source.authorization_url
if authorization_url == "": if authorization_url == "":
Event.new( Event.new(

View file

@ -28,8 +28,8 @@ class OAuthClient(BaseOAuthClient):
if raw_token is not None and verifier is not None: if raw_token is not None and verifier is not None:
token = self.parse_raw_token(raw_token) token = self.parse_raw_token(raw_token)
try: try:
access_token_url = self.source.source_type.access_token_url or "" access_token_url = self.source.type.access_token_url or ""
if self.source.source_type.urls_customizable and self.source.access_token_url: if self.source.type.urls_customizable and self.source.access_token_url:
access_token_url = self.source.access_token_url access_token_url = self.source.access_token_url
response = self.do_request( response = self.do_request(
"post", "post",
@ -54,8 +54,8 @@ class OAuthClient(BaseOAuthClient):
"""Fetch the OAuth request token. Only required for OAuth 1.0.""" """Fetch the OAuth request token. Only required for OAuth 1.0."""
callback = self.request.build_absolute_uri(self.callback) callback = self.request.build_absolute_uri(self.callback)
try: try:
request_token_url = self.source.source_type.request_token_url or "" request_token_url = self.source.type.request_token_url or ""
if self.source.source_type.urls_customizable and self.source.request_token_url: if self.source.type.urls_customizable and self.source.request_token_url:
request_token_url = self.source.request_token_url request_token_url = self.source.request_token_url
response = self.do_request( response = self.do_request(
"post", "post",

View file

@ -76,8 +76,8 @@ class OAuth2Client(BaseOAuthClient):
if SESSION_KEY_OAUTH_PKCE in self.request.session: if SESSION_KEY_OAUTH_PKCE in self.request.session:
args["code_verifier"] = self.request.session[SESSION_KEY_OAUTH_PKCE] args["code_verifier"] = self.request.session[SESSION_KEY_OAUTH_PKCE]
try: try:
access_token_url = self.source.source_type.access_token_url or "" access_token_url = self.source.type.access_token_url or ""
if self.source.source_type.urls_customizable and self.source.access_token_url: if self.source.type.urls_customizable and self.source.access_token_url:
access_token_url = self.source.access_token_url access_token_url = self.source.access_token_url
response = self.session.request( response = self.session.request(
"post", access_token_url, data=args, headers=self._default_headers, **request_kwargs "post", access_token_url, data=args, headers=self._default_headers, **request_kwargs
@ -140,8 +140,8 @@ class UserprofileHeaderAuthClient(OAuth2Client):
def get_profile_info(self, token: dict[str, str]) -> Optional[dict[str, Any]]: def get_profile_info(self, token: dict[str, str]) -> Optional[dict[str, Any]]:
"Fetch user profile information." "Fetch user profile information."
profile_url = self.source.source_type.profile_url or "" profile_url = self.source.type.profile_url or ""
if self.source.source_type.urls_customizable and self.source.profile_url: if self.source.type.urls_customizable and self.source.profile_url:
profile_url = self.source.profile_url profile_url = self.source.profile_url
response = self.session.request( response = self.session.request(
"get", "get",

View file

@ -1,5 +1,5 @@
"""OAuth Client models""" """OAuth Client models"""
from typing import TYPE_CHECKING, Optional from typing import TYPE_CHECKING, Optional, Type
from django.db import models from django.db import models
from django.http.request import HttpRequest from django.http.request import HttpRequest
@ -55,7 +55,7 @@ class OAuthSource(Source):
oidc_jwks = models.JSONField(default=dict, blank=True) oidc_jwks = models.JSONField(default=dict, blank=True)
@property @property
def source_type(self) -> type["SourceType"]: def type(self) -> type["SourceType"]:
"""Return the provider instance for this source""" """Return the provider instance for this source"""
from authentik.sources.oauth.types.registry import registry from authentik.sources.oauth.types.registry import registry
@ -65,14 +65,15 @@ class OAuthSource(Source):
def component(self) -> str: def component(self) -> str:
return "ak-source-oauth-form" return "ak-source-oauth-form"
# we're using Type[] instead of type[] here since type[] interferes with the property above
@property @property
def serializer(self) -> type[Serializer]: def serializer(self) -> Type[Serializer]:
from authentik.sources.oauth.api.source import OAuthSourceSerializer from authentik.sources.oauth.api.source import OAuthSourceSerializer
return OAuthSourceSerializer return OAuthSourceSerializer
def ui_login_button(self, request: HttpRequest) -> UILoginButton: def ui_login_button(self, request: HttpRequest) -> UILoginButton:
provider_type = self.source_type provider_type = self.type
provider = provider_type() provider = provider_type()
icon = self.get_icon icon = self.get_icon
if not icon: if not icon:
@ -84,7 +85,7 @@ class OAuthSource(Source):
) )
def ui_user_settings(self) -> Optional[UserSettingSerializer]: def ui_user_settings(self) -> Optional[UserSettingSerializer]:
provider_type = self.source_type provider_type = self.type
icon = self.get_icon icon = self.get_icon
if not icon: if not icon:
icon = provider_type().icon_url() icon = provider_type().icon_url()

View file

@ -1,12 +0,0 @@
"""OAuth source settings"""
from celery.schedules import crontab
from authentik.lib.utils.time import fqdn_rand
CELERY_BEAT_SCHEDULE = {
"update_oauth_source_oidc_well_known": {
"task": "authentik.sources.oauth.tasks.update_well_known_jwks",
"schedule": crontab(minute=fqdn_rand("update_well_known_jwks"), hour="*/3"),
"options": {"queue": "authentik_scheduled"},
},
}

View file

@ -1,70 +0,0 @@
"""OAuth Source tasks"""
from json import dumps
from requests import RequestException
from structlog.stdlib import get_logger
from authentik.events.monitored_tasks import MonitoredTask, TaskResult, TaskResultStatus
from authentik.lib.utils.http import get_http_session
from authentik.root.celery import CELERY_APP
from authentik.sources.oauth.models import OAuthSource
LOGGER = get_logger()
@CELERY_APP.task(bind=True, base=MonitoredTask)
def update_well_known_jwks(self: MonitoredTask):
"""Update OAuth sources' config from well_known, and JWKS info from the configured URL"""
session = get_http_session()
result = TaskResult(TaskResultStatus.SUCCESSFUL, [])
for source in OAuthSource.objects.all().exclude(oidc_well_known_url=""):
try:
well_known_config = session.get(source.oidc_well_known_url)
well_known_config.raise_for_status()
except RequestException as exc:
text = exc.response.text if exc.response else str(exc)
LOGGER.warning("Failed to update well_known", source=source, exc=exc, text=text)
result.messages.append(f"Failed to update OIDC configuration for {source.slug}")
continue
config = well_known_config.json()
try:
dirty = False
source_attr_key = (
("authorization_url", "authorization_endpoint"),
("access_token_url", "token_endpoint"),
("profile_url", "userinfo_endpoint"),
("oidc_jwks_url", "jwks_uri"),
)
for source_attr, config_key in source_attr_key:
# Check if we're actually changing anything to only
# save when something has changed
if getattr(source, source_attr, "") != config[config_key]:
dirty = True
setattr(source, source_attr, config[config_key])
except (IndexError, KeyError) as exc:
LOGGER.warning(
"Failed to update well_known",
source=source,
exc=exc,
)
result.messages.append(f"Failed to update OIDC configuration for {source.slug}")
continue
if dirty:
LOGGER.info("Updating sources' OpenID Configuration", source=source)
source.save()
for source in OAuthSource.objects.all().exclude(oidc_jwks_url=""):
try:
jwks_config = session.get(source.oidc_jwks_url)
jwks_config.raise_for_status()
except RequestException as exc:
text = exc.response.text if exc.response else str(exc)
LOGGER.warning("Failed to update JWKS", source=source, exc=exc, text=text)
result.messages.append(f"Failed to update JWKS for {source.slug}")
continue
config = jwks_config.json()
if dumps(source.oidc_jwks, sort_keys=True) != dumps(config, sort_keys=True):
source.oidc_jwks = config
LOGGER.info("Updating sources' JWKS", source=source)
source.save()
self.set_status(result)

View file

@ -1,48 +0,0 @@
"""Test OAuth Source tasks"""
from django.test import TestCase
from requests_mock import Mocker
from authentik.sources.oauth.models import OAuthSource
from authentik.sources.oauth.tasks import update_well_known_jwks
class TestOAuthSourceTasks(TestCase):
"""Test OAuth Source tasks"""
def setUp(self) -> None:
self.source = OAuthSource.objects.create(
name="test",
slug="test",
provider_type="openidconnect",
authorization_url="",
profile_url="",
consumer_key="",
)
@Mocker()
def test_well_known_jwks(self, mock: Mocker):
"""Test well_known update"""
self.source.oidc_well_known_url = "http://foo/.well-known/openid-configuration"
self.source.save()
mock.get(
self.source.oidc_well_known_url,
json={
"authorization_endpoint": "foo",
"token_endpoint": "foo",
"userinfo_endpoint": "foo",
"jwks_uri": "http://foo/jwks",
},
)
mock.get("http://foo/jwks", json={"foo": "bar"})
update_well_known_jwks() # pylint: disable=no-value-for-parameter
self.source.refresh_from_db()
self.assertEqual(self.source.authorization_url, "foo")
self.assertEqual(self.source.access_token_url, "foo")
self.assertEqual(self.source.profile_url, "foo")
self.assertEqual(self.source.oidc_jwks_url, "http://foo/jwks")
self.assertEqual(
self.source.oidc_jwks,
{
"foo": "bar",
},
)

View file

@ -50,7 +50,6 @@ class TestOAuthSource(TestCase):
def test_api_validate_openid_connect(self): def test_api_validate_openid_connect(self):
"""Test API validation (with OIDC endpoints)""" """Test API validation (with OIDC endpoints)"""
openid_config = { openid_config = {
"issuer": "foo",
"authorization_endpoint": "http://mock/oauth/authorize", "authorization_endpoint": "http://mock/oauth/authorize",
"token_endpoint": "http://mock/oauth/token", "token_endpoint": "http://mock/oauth/token",
"userinfo_endpoint": "http://mock/oauth/userinfo", "userinfo_endpoint": "http://mock/oauth/userinfo",

View file

@ -4,8 +4,8 @@ from typing import Any
from structlog.stdlib import get_logger from structlog.stdlib import get_logger
from authentik.sources.oauth.clients.oauth2 import UserprofileHeaderAuthClient from authentik.sources.oauth.clients.oauth2 import UserprofileHeaderAuthClient
from authentik.sources.oauth.types.oidc import OpenIDConnectOAuth2Callback
from authentik.sources.oauth.types.registry import SourceType, registry from authentik.sources.oauth.types.registry import SourceType, registry
from authentik.sources.oauth.views.callback import OAuthCallback
from authentik.sources.oauth.views.redirect import OAuthRedirect from authentik.sources.oauth.views.redirect import OAuthRedirect
LOGGER = get_logger() LOGGER = get_logger()
@ -20,7 +20,7 @@ class AzureADOAuthRedirect(OAuthRedirect):
} }
class AzureADOAuthCallback(OpenIDConnectOAuth2Callback): class AzureADOAuthCallback(OAuthCallback):
"""AzureAD OAuth2 Callback""" """AzureAD OAuth2 Callback"""
client_class = UserprofileHeaderAuthClient client_class = UserprofileHeaderAuthClient
@ -50,8 +50,4 @@ class AzureADType(SourceType):
authorization_url = "https://login.microsoftonline.com/common/oauth2/v2.0/authorize" authorization_url = "https://login.microsoftonline.com/common/oauth2/v2.0/authorize"
access_token_url = "https://login.microsoftonline.com/common/oauth2/v2.0/token" # nosec access_token_url = "https://login.microsoftonline.com/common/oauth2/v2.0/token" # nosec
profile_url = "https://login.microsoftonline.com/common/openid/userinfo" profile_url = "https://graph.microsoft.com/v1.0/me"
oidc_well_known_url = (
"https://login.microsoftonline.com/common/.well-known/openid-configuration"
)
oidc_jwks_url = "https://login.microsoftonline.com/common/discovery/keys"

View file

@ -23,8 +23,8 @@ class GitHubOAuth2Client(OAuth2Client):
def get_github_emails(self, token: dict[str, str]) -> list[dict[str, Any]]: def get_github_emails(self, token: dict[str, str]) -> list[dict[str, Any]]:
"""Get Emails from the GitHub API""" """Get Emails from the GitHub API"""
profile_url = self.source.source_type.profile_url or "" profile_url = self.source.type.profile_url or ""
if self.source.source_type.urls_customizable and self.source.profile_url: if self.source.type.urls_customizable and self.source.profile_url:
profile_url = self.source.profile_url profile_url = self.source.profile_url
profile_url += "/emails" profile_url += "/emails"
response = self.do_request("get", profile_url, token=token) response = self.do_request("get", profile_url, token=token)
@ -76,7 +76,3 @@ class GitHubType(SourceType):
authorization_url = "https://github.com/login/oauth/authorize" authorization_url = "https://github.com/login/oauth/authorize"
access_token_url = "https://github.com/login/oauth/access_token" # nosec access_token_url = "https://github.com/login/oauth/access_token" # nosec
profile_url = "https://api.github.com/user" profile_url = "https://api.github.com/user"
oidc_well_known_url = (
"https://token.actions.githubusercontent.com/.well-known/openid-configuration"
)
oidc_jwks_url = "https://token.actions.githubusercontent.com/.well-known/jwks"

View file

@ -40,5 +40,3 @@ class GoogleType(SourceType):
authorization_url = "https://accounts.google.com/o/oauth2/auth" authorization_url = "https://accounts.google.com/o/oauth2/auth"
access_token_url = "https://oauth2.googleapis.com/token" # nosec access_token_url = "https://oauth2.googleapis.com/token" # nosec
profile_url = "https://www.googleapis.com/oauth2/v1/userinfo" profile_url = "https://www.googleapis.com/oauth2/v1/userinfo"
oidc_well_known_url = "https://accounts.google.com/.well-known/openid-configuration"
oidc_jwks_url = "https://www.googleapis.com/oauth2/v3/certs"

View file

@ -26,8 +26,8 @@ class MailcowOAuth2Client(OAuth2Client):
def get_profile_info(self, token: dict[str, str]) -> Optional[dict[str, Any]]: def get_profile_info(self, token: dict[str, str]) -> Optional[dict[str, Any]]:
"Fetch user profile information." "Fetch user profile information."
profile_url = self.source.source_type.profile_url or "" profile_url = self.source.type.profile_url or ""
if self.source.source_type.urls_customizable and self.source.profile_url: if self.source.type.urls_customizable and self.source.profile_url:
profile_url = self.source.profile_url profile_url = self.source.profile_url
response = self.session.request( response = self.session.request(
"get", "get",

View file

@ -23,7 +23,7 @@ class OpenIDConnectOAuth2Callback(OAuthCallback):
client_class = UserprofileHeaderAuthClient client_class = UserprofileHeaderAuthClient
def get_user_id(self, info: dict[str, str]) -> str: def get_user_id(self, info: dict[str, str]) -> str:
return info.get("sub", None) return info.get("sub", "")
def get_user_enroll_context( def get_user_enroll_context(
self, self,

View file

@ -3,8 +3,8 @@ from typing import Any
from authentik.sources.oauth.clients.oauth2 import UserprofileHeaderAuthClient from authentik.sources.oauth.clients.oauth2 import UserprofileHeaderAuthClient
from authentik.sources.oauth.models import OAuthSource from authentik.sources.oauth.models import OAuthSource
from authentik.sources.oauth.types.oidc import OpenIDConnectOAuth2Callback
from authentik.sources.oauth.types.registry import SourceType, registry from authentik.sources.oauth.types.registry import SourceType, registry
from authentik.sources.oauth.views.callback import OAuthCallback
from authentik.sources.oauth.views.redirect import OAuthRedirect from authentik.sources.oauth.views.redirect import OAuthRedirect
@ -17,7 +17,7 @@ class OktaOAuthRedirect(OAuthRedirect):
} }
class OktaOAuth2Callback(OpenIDConnectOAuth2Callback): class OktaOAuth2Callback(OAuthCallback):
"""Okta OAuth2 Callback""" """Okta OAuth2 Callback"""
# Okta has the same quirk as azure and throws an error if the access token # Okta has the same quirk as azure and throws an error if the access token
@ -25,6 +25,9 @@ class OktaOAuth2Callback(OpenIDConnectOAuth2Callback):
# see https://github.com/goauthentik/authentik/issues/1910 # see https://github.com/goauthentik/authentik/issues/1910
client_class = UserprofileHeaderAuthClient client_class = UserprofileHeaderAuthClient
def get_user_id(self, info: dict[str, str]) -> str:
return info.get("sub", "")
def get_user_enroll_context( def get_user_enroll_context(
self, self,
info: dict[str, Any], info: dict[str, Any],

View file

@ -12,9 +12,8 @@ class PatreonOAuthRedirect(OAuthRedirect):
"""Patreon OAuth2 Redirect""" """Patreon OAuth2 Redirect"""
def get_additional_parameters(self, source: OAuthSource): # pragma: no cover def get_additional_parameters(self, source: OAuthSource): # pragma: no cover
# https://docs.patreon.com/#scopes
return { return {
"scope": ["identity", "identity[email]"], "scope": ["openid", "email", "profile"],
} }

View file

@ -36,8 +36,6 @@ class SourceType:
authorization_url: Optional[str] = None authorization_url: Optional[str] = None
access_token_url: Optional[str] = None access_token_url: Optional[str] = None
profile_url: Optional[str] = None profile_url: Optional[str] = None
oidc_well_known_url: Optional[str] = None
oidc_jwks_url: Optional[str] = None
def icon_url(self) -> str: def icon_url(self) -> str:
"""Get Icon URL for login""" """Get Icon URL for login"""

View file

@ -3,8 +3,8 @@ from json import dumps
from typing import Any, Optional from typing import Any, Optional
from authentik.sources.oauth.clients.oauth2 import UserprofileHeaderAuthClient from authentik.sources.oauth.clients.oauth2 import UserprofileHeaderAuthClient
from authentik.sources.oauth.types.oidc import OpenIDConnectOAuth2Callback
from authentik.sources.oauth.types.registry import SourceType, registry from authentik.sources.oauth.types.registry import SourceType, registry
from authentik.sources.oauth.views.callback import OAuthCallback
from authentik.sources.oauth.views.redirect import OAuthRedirect from authentik.sources.oauth.views.redirect import OAuthRedirect
@ -27,11 +27,14 @@ class TwitchOAuthRedirect(OAuthRedirect):
} }
class TwitchOAuth2Callback(OpenIDConnectOAuth2Callback): class TwitchOAuth2Callback(OAuthCallback):
"""Twitch OAuth2 Callback""" """Twitch OAuth2 Callback"""
client_class = TwitchClient client_class = TwitchClient
def get_user_id(self, info: dict[str, str]) -> str:
return info.get("sub", "")
def get_user_enroll_context( def get_user_enroll_context(
self, self,
info: dict[str, Any], info: dict[str, Any],

View file

@ -25,7 +25,7 @@ class OAuthClientMixin:
if self.client_class is not None: if self.client_class is not None:
# pylint: disable=not-callable # pylint: disable=not-callable
return self.client_class(source, self.request, **kwargs) return self.client_class(source, self.request, **kwargs)
if source.source_type.request_token_url or source.request_token_url: if source.type.request_token_url or source.request_token_url:
client = OAuthClient(source, self.request, **kwargs) client = OAuthClient(source, self.request, **kwargs)
else: else:
client = OAuth2Client(source, self.request, **kwargs) client = OAuth2Client(source, self.request, **kwargs)

View file

@ -1,12 +1,13 @@
"""Plex source settings""" """Plex source settings"""
from celery.schedules import crontab from celery.schedules import crontab
from authentik.lib.config import CONFIG
from authentik.lib.utils.time import fqdn_rand from authentik.lib.utils.time import fqdn_rand
CELERY_BEAT_SCHEDULE = { CELERY_BEAT_SCHEDULE = {
"check_plex_token": { "check_plex_token": {
"task": "authentik.sources.plex.tasks.check_plex_token_all", "task": "authentik.sources.plex.tasks.check_plex_token_all",
"schedule": crontab(minute=fqdn_rand("check_plex_token"), hour="*/3"), "schedule": crontab(minute=fqdn_rand("check_plex_token"), hour="*/3"),
"options": {"queue": "authentik_scheduled"}, "options": {"priority": CONFIG.get_int("worker.priority.scheduled")},
}, },
} }

View file

@ -12,6 +12,7 @@ from authentik.flows.challenge import (
Challenge, Challenge,
ChallengeResponse, ChallengeResponse,
ChallengeTypes, ChallengeTypes,
ErrorDetailSerializer,
WithUserInfoChallenge, WithUserInfoChallenge,
) )
from authentik.flows.stage import ChallengeStageView from authentik.flows.stage import ChallengeStageView
@ -23,7 +24,6 @@ from authentik.stages.authenticator_sms.models import (
from authentik.stages.prompt.stage import PLAN_CONTEXT_PROMPT from authentik.stages.prompt.stage import PLAN_CONTEXT_PROMPT
SESSION_KEY_SMS_DEVICE = "authentik/stages/authenticator_sms/sms_device" SESSION_KEY_SMS_DEVICE = "authentik/stages/authenticator_sms/sms_device"
PLAN_CONTEXT_PHONE = "phone"
class AuthenticatorSMSChallenge(WithUserInfoChallenge): class AuthenticatorSMSChallenge(WithUserInfoChallenge):
@ -48,8 +48,6 @@ class AuthenticatorSMSChallengeResponse(ChallengeResponse):
def validate(self, attrs: dict) -> dict: def validate(self, attrs: dict) -> dict:
"""Check""" """Check"""
if "code" not in attrs: if "code" not in attrs:
if "phone_number" not in attrs:
raise ValidationError("phone_number required")
self.device.phone_number = attrs["phone_number"] self.device.phone_number = attrs["phone_number"]
self.stage.validate_and_send(attrs["phone_number"]) self.stage.validate_and_send(attrs["phone_number"])
return super().validate(attrs) return super().validate(attrs)
@ -77,9 +75,9 @@ class AuthenticatorSMSStageView(ChallengeStageView):
def _has_phone_number(self) -> Optional[str]: def _has_phone_number(self) -> Optional[str]:
context = self.executor.plan.context context = self.executor.plan.context
if PLAN_CONTEXT_PHONE in context.get(PLAN_CONTEXT_PROMPT, {}): if "phone" in context.get(PLAN_CONTEXT_PROMPT, {}):
self.logger.debug("got phone number from plan context") self.logger.debug("got phone number from plan context")
return context.get(PLAN_CONTEXT_PROMPT, {}).get(PLAN_CONTEXT_PHONE) return context.get(PLAN_CONTEXT_PROMPT, {}).get("phone")
if SESSION_KEY_SMS_DEVICE in self.request.session: if SESSION_KEY_SMS_DEVICE in self.request.session:
self.logger.debug("got phone number from device in session") self.logger.debug("got phone number from device in session")
device: SMSDevice = self.request.session[SESSION_KEY_SMS_DEVICE] device: SMSDevice = self.request.session[SESSION_KEY_SMS_DEVICE]
@ -115,17 +113,10 @@ class AuthenticatorSMSStageView(ChallengeStageView):
try: try:
self.validate_and_send(phone_number) self.validate_and_send(phone_number)
except ValidationError as exc: except ValidationError as exc:
# We had a phone number given already (at this point only possible from flow response = AuthenticatorSMSChallengeResponse()
# context), but an error occurred while sending a number (most likely) response._errors.setdefault("phone_number", [])
# due to a duplicate device, so delete the number we got given, reset the state response._errors["phone_number"].append(ErrorDetailSerializer(exc.detail))
# (ish) and retry return self.challenge_invalid(response)
device.phone_number = ""
self.executor.plan.context.get(PLAN_CONTEXT_PROMPT, {}).pop(
PLAN_CONTEXT_PHONE, None
)
self.request.session.pop(SESSION_KEY_SMS_DEVICE, None)
self.logger.warning("failed to send SMS message to pre-set number", exc=exc)
return self.get(request, *args, **kwargs)
return super().get(request, *args, **kwargs) return super().get(request, *args, **kwargs)
def challenge_valid(self, response: ChallengeResponse) -> HttpResponse: def challenge_valid(self, response: ChallengeResponse) -> HttpResponse:

Some files were not shown because too many files have changed in this diff Show more