Merge branch 'master' into stage-challenge
# Conflicts: # authentik/stages/authenticator_validate/stage.py # authentik/stages/identification/stage.py
This commit is contained in:
commit
b229b2f40d
|
@ -2,7 +2,6 @@
|
||||||
import time
|
import time
|
||||||
from collections import Counter
|
from collections import Counter
|
||||||
from datetime import timedelta
|
from datetime import timedelta
|
||||||
from typing import Dict, List
|
|
||||||
|
|
||||||
from django.db.models import Count, ExpressionWrapper, F, Model
|
from django.db.models import Count, ExpressionWrapper, F, Model
|
||||||
from django.db.models.fields import DurationField
|
from django.db.models.fields import DurationField
|
||||||
|
@ -19,7 +18,7 @@ from rest_framework.viewsets import ViewSet
|
||||||
from authentik.events.models import Event, EventAction
|
from authentik.events.models import Event, EventAction
|
||||||
|
|
||||||
|
|
||||||
def get_events_per_1h(**filter_kwargs) -> List[Dict[str, int]]:
|
def get_events_per_1h(**filter_kwargs) -> list[dict[str, int]]:
|
||||||
"""Get event count by hour in the last day, fill with zeros"""
|
"""Get event count by hour in the last day, fill with zeros"""
|
||||||
date_from = now() - timedelta(days=1)
|
date_from = now() - timedelta(days=1)
|
||||||
result = (
|
result = (
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
"""authentik Outpost administration"""
|
"""authentik Outpost administration"""
|
||||||
from dataclasses import asdict
|
from dataclasses import asdict
|
||||||
from typing import Any, Dict
|
from typing import Any
|
||||||
|
|
||||||
from django.contrib.auth.mixins import LoginRequiredMixin
|
from django.contrib.auth.mixins import LoginRequiredMixin
|
||||||
from django.contrib.auth.mixins import (
|
from django.contrib.auth.mixins import (
|
||||||
|
@ -33,7 +33,7 @@ class OutpostCreateView(
|
||||||
template_name = "generic/create.html"
|
template_name = "generic/create.html"
|
||||||
success_message = _("Successfully created Outpost")
|
success_message = _("Successfully created Outpost")
|
||||||
|
|
||||||
def get_initial(self) -> Dict[str, Any]:
|
def get_initial(self) -> dict[str, Any]:
|
||||||
return {
|
return {
|
||||||
"_config": asdict(
|
"_config": asdict(
|
||||||
OutpostConfig(authentik_host=self.request.build_absolute_uri("/"))
|
OutpostConfig(authentik_host=self.request.build_absolute_uri("/"))
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
"""authentik Policy administration"""
|
"""authentik Policy administration"""
|
||||||
from typing import Any, Dict
|
from typing import Any
|
||||||
|
|
||||||
from django.contrib.auth.mixins import LoginRequiredMixin
|
from django.contrib.auth.mixins import LoginRequiredMixin
|
||||||
from django.contrib.auth.mixins import (
|
from django.contrib.auth.mixins import (
|
||||||
|
@ -102,7 +102,7 @@ class PolicyTestView(LoginRequiredMixin, DetailView, PermissionRequiredMixin, Fo
|
||||||
Policy.objects.filter(pk=self.kwargs.get("pk")).select_subclasses().first()
|
Policy.objects.filter(pk=self.kwargs.get("pk")).select_subclasses().first()
|
||||||
)
|
)
|
||||||
|
|
||||||
def get_context_data(self, **kwargs: Any) -> Dict[str, Any]:
|
def get_context_data(self, **kwargs: Any) -> dict[str, Any]:
|
||||||
kwargs["policy"] = self.get_object()
|
kwargs["policy"] = self.get_object()
|
||||||
return super().get_context_data(**kwargs)
|
return super().get_context_data(**kwargs)
|
||||||
|
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
"""authentik Tasks List"""
|
"""authentik Tasks List"""
|
||||||
from typing import Any, Dict
|
from typing import Any
|
||||||
|
|
||||||
from django.views.generic.base import TemplateView
|
from django.views.generic.base import TemplateView
|
||||||
|
|
||||||
|
@ -12,7 +12,7 @@ class TaskListView(AdminRequiredMixin, TemplateView):
|
||||||
|
|
||||||
template_name = "administration/task/list.html"
|
template_name = "administration/task/list.html"
|
||||||
|
|
||||||
def get_context_data(self, **kwargs: Any) -> Dict[str, Any]:
|
def get_context_data(self, **kwargs: Any) -> dict[str, Any]:
|
||||||
kwargs = super().get_context_data(**kwargs)
|
kwargs = super().get_context_data(**kwargs)
|
||||||
kwargs["object_list"] = sorted(
|
kwargs["object_list"] = sorted(
|
||||||
TaskInfo.all().values(), key=lambda x: x.task_name
|
TaskInfo.all().values(), key=lambda x: x.task_name
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
"""authentik admin util views"""
|
"""authentik admin util views"""
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Optional
|
||||||
from urllib.parse import urlparse
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
from django.contrib import messages
|
from django.contrib import messages
|
||||||
|
@ -40,7 +40,7 @@ class SearchListMixin(MultipleObjectMixin):
|
||||||
"""Accept search query using `search` querystring parameter. Requires self.search_fields,
|
"""Accept search query using `search` querystring parameter. Requires self.search_fields,
|
||||||
a list of all fields to search. Can contain special lookups like __icontains"""
|
a list of all fields to search. Can contain special lookups like __icontains"""
|
||||||
|
|
||||||
search_fields: List[str]
|
search_fields: list[str]
|
||||||
|
|
||||||
def get_queryset(self) -> QuerySet:
|
def get_queryset(self) -> QuerySet:
|
||||||
queryset = super().get_queryset()
|
queryset = super().get_queryset()
|
||||||
|
@ -69,7 +69,7 @@ class InheritanceCreateView(CreateAssignPermView):
|
||||||
raise Http404 from exc
|
raise Http404 from exc
|
||||||
return model().form
|
return model().form
|
||||||
|
|
||||||
def get_context_data(self, **kwargs: Any) -> Dict[str, Any]:
|
def get_context_data(self, **kwargs: Any) -> dict[str, Any]:
|
||||||
kwargs = super().get_context_data(**kwargs)
|
kwargs = super().get_context_data(**kwargs)
|
||||||
form_cls = self.get_form_class()
|
form_cls = self.get_form_class()
|
||||||
if hasattr(form_cls, "template_name"):
|
if hasattr(form_cls, "template_name"):
|
||||||
|
@ -80,7 +80,7 @@ class InheritanceCreateView(CreateAssignPermView):
|
||||||
class InheritanceUpdateView(UpdateView):
|
class InheritanceUpdateView(UpdateView):
|
||||||
"""UpdateView for objects using InheritanceManager"""
|
"""UpdateView for objects using InheritanceManager"""
|
||||||
|
|
||||||
def get_context_data(self, **kwargs: Any) -> Dict[str, Any]:
|
def get_context_data(self, **kwargs: Any) -> dict[str, Any]:
|
||||||
kwargs = super().get_context_data(**kwargs)
|
kwargs = super().get_context_data(**kwargs)
|
||||||
form_cls = self.get_form_class()
|
form_cls = self.get_form_class()
|
||||||
if hasattr(form_cls, "template_name"):
|
if hasattr(form_cls, "template_name"):
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
"""API Authentication"""
|
"""API Authentication"""
|
||||||
from base64 import b64decode
|
from base64 import b64decode
|
||||||
from binascii import Error
|
from binascii import Error
|
||||||
from typing import Any, Optional, Tuple, Union
|
from typing import Any, Optional, Union
|
||||||
|
|
||||||
from rest_framework.authentication import BaseAuthentication, get_authorization_header
|
from rest_framework.authentication import BaseAuthentication, get_authorization_header
|
||||||
from rest_framework.request import Request
|
from rest_framework.request import Request
|
||||||
|
@ -44,7 +44,7 @@ def token_from_header(raw_header: bytes) -> Optional[Token]:
|
||||||
class AuthentikTokenAuthentication(BaseAuthentication):
|
class AuthentikTokenAuthentication(BaseAuthentication):
|
||||||
"""Token-based authentication using HTTP Basic authentication"""
|
"""Token-based authentication using HTTP Basic authentication"""
|
||||||
|
|
||||||
def authenticate(self, request: Request) -> Union[Tuple[User, Any], None]:
|
def authenticate(self, request: Request) -> Union[tuple[User, Any], None]:
|
||||||
"""Token-based authentication using HTTP Basic authentication"""
|
"""Token-based authentication using HTTP Basic authentication"""
|
||||||
auth = get_authorization_header(request)
|
auth = get_authorization_header(request)
|
||||||
|
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
"""authentik core models"""
|
"""authentik core models"""
|
||||||
from datetime import timedelta
|
from datetime import timedelta
|
||||||
from hashlib import sha256
|
from hashlib import sha256
|
||||||
from typing import Any, Dict, Optional, Type
|
from typing import Any, Optional, Type
|
||||||
from uuid import uuid4
|
from uuid import uuid4
|
||||||
|
|
||||||
from django.conf import settings
|
from django.conf import settings
|
||||||
|
@ -96,7 +96,7 @@ class User(GuardianUserMixin, AbstractUser):
|
||||||
|
|
||||||
objects = UserManager()
|
objects = UserManager()
|
||||||
|
|
||||||
def group_attributes(self) -> Dict[str, Any]:
|
def group_attributes(self) -> dict[str, Any]:
|
||||||
"""Get a dictionary containing the attributes from all groups the user belongs to,
|
"""Get a dictionary containing the attributes from all groups the user belongs to,
|
||||||
including the users attributes"""
|
including the users attributes"""
|
||||||
final_attributes = {}
|
final_attributes = {}
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
"""authentik core user views"""
|
"""authentik core user views"""
|
||||||
from typing import Any, Dict
|
from typing import Any
|
||||||
|
|
||||||
from django.contrib.auth.mixins import LoginRequiredMixin
|
from django.contrib.auth.mixins import LoginRequiredMixin
|
||||||
from django.contrib.auth.mixins import (
|
from django.contrib.auth.mixins import (
|
||||||
|
@ -45,7 +45,7 @@ class UserDetailsView(SuccessMessageMixin, LoginRequiredMixin, UpdateView):
|
||||||
def get_object(self):
|
def get_object(self):
|
||||||
return self.request.user
|
return self.request.user
|
||||||
|
|
||||||
def get_context_data(self, **kwargs: Any) -> Dict[str, Any]:
|
def get_context_data(self, **kwargs: Any) -> dict[str, Any]:
|
||||||
kwargs = super().get_context_data(**kwargs)
|
kwargs = super().get_context_data(**kwargs)
|
||||||
unenrollment_flow = Flow.with_policy(
|
unenrollment_flow = Flow.with_policy(
|
||||||
self.request, designation=FlowDesignation.UNRENOLLMENT
|
self.request, designation=FlowDesignation.UNRENOLLMENT
|
||||||
|
|
|
@ -3,7 +3,7 @@ from dataclasses import dataclass, field
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from traceback import format_tb
|
from traceback import format_tb
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Optional
|
||||||
|
|
||||||
from celery import Task
|
from celery import Task
|
||||||
from django.core.cache import cache
|
from django.core.cache import cache
|
||||||
|
@ -26,7 +26,7 @@ class TaskResult:
|
||||||
|
|
||||||
status: TaskResultStatus
|
status: TaskResultStatus
|
||||||
|
|
||||||
messages: List[str] = field(default_factory=list)
|
messages: list[str] = field(default_factory=list)
|
||||||
|
|
||||||
# Optional UID used in cache for tasks that run in different instances
|
# Optional UID used in cache for tasks that run in different instances
|
||||||
uid: Optional[str] = field(default=None)
|
uid: Optional[str] = field(default=None)
|
||||||
|
@ -49,8 +49,8 @@ class TaskInfo:
|
||||||
|
|
||||||
task_call_module: str
|
task_call_module: str
|
||||||
task_call_func: str
|
task_call_func: str
|
||||||
task_call_args: List[Any] = field(default_factory=list)
|
task_call_args: list[Any] = field(default_factory=list)
|
||||||
task_call_kwargs: Dict[str, Any] = field(default_factory=dict)
|
task_call_kwargs: dict[str, Any] = field(default_factory=dict)
|
||||||
|
|
||||||
task_description: Optional[str] = field(default=None)
|
task_description: Optional[str] = field(default=None)
|
||||||
|
|
||||||
|
@ -60,7 +60,7 @@ class TaskInfo:
|
||||||
return self.task_name.split("_")
|
return self.task_name.split("_")
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def all() -> Dict[str, "TaskInfo"]:
|
def all() -> dict[str, "TaskInfo"]:
|
||||||
"""Get all TaskInfo objects"""
|
"""Get all TaskInfo objects"""
|
||||||
return cache.get_many(cache.keys("task_*"))
|
return cache.get_many(cache.keys("task_*"))
|
||||||
|
|
||||||
|
@ -109,7 +109,7 @@ class MonitoredTask(Task):
|
||||||
|
|
||||||
# pylint: disable=too-many-arguments
|
# pylint: disable=too-many-arguments
|
||||||
def after_return(
|
def after_return(
|
||||||
self, status, retval, task_id, args: List[Any], kwargs: Dict[str, Any], einfo
|
self, status, retval, task_id, args: list[Any], kwargs: dict[str, Any], einfo
|
||||||
):
|
):
|
||||||
if not self._result.uid:
|
if not self._result.uid:
|
||||||
self._result.uid = self._uid
|
self._result.uid = self._uid
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
"""authentik events signal listener"""
|
"""authentik events signal listener"""
|
||||||
from threading import Thread
|
from threading import Thread
|
||||||
from typing import Any, Dict, Optional
|
from typing import Any, Optional
|
||||||
|
|
||||||
from django.contrib.auth.signals import (
|
from django.contrib.auth.signals import (
|
||||||
user_logged_in,
|
user_logged_in,
|
||||||
|
@ -27,7 +27,7 @@ class EventNewThread(Thread):
|
||||||
|
|
||||||
action: str
|
action: str
|
||||||
request: HttpRequest
|
request: HttpRequest
|
||||||
kwargs: Dict[str, Any]
|
kwargs: dict[str, Any]
|
||||||
user: Optional[User] = None
|
user: Optional[User] = None
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
|
@ -69,7 +69,7 @@ def on_user_logged_out(sender, request: HttpRequest, user: User, **_):
|
||||||
@receiver(user_write)
|
@receiver(user_write)
|
||||||
# pylint: disable=unused-argument
|
# pylint: disable=unused-argument
|
||||||
def on_user_write(
|
def on_user_write(
|
||||||
sender, request: HttpRequest, user: User, data: Dict[str, Any], **kwargs
|
sender, request: HttpRequest, user: User, data: dict[str, Any], **kwargs
|
||||||
):
|
):
|
||||||
"""Log User write"""
|
"""Log User write"""
|
||||||
thread = EventNewThread(EventAction.USER_WRITE, request, **data)
|
thread = EventNewThread(EventAction.USER_WRITE, request, **data)
|
||||||
|
@ -81,7 +81,7 @@ def on_user_write(
|
||||||
@receiver(user_login_failed)
|
@receiver(user_login_failed)
|
||||||
# pylint: disable=unused-argument
|
# pylint: disable=unused-argument
|
||||||
def on_user_login_failed(
|
def on_user_login_failed(
|
||||||
sender, credentials: Dict[str, str], request: HttpRequest, **_
|
sender, credentials: dict[str, str], request: HttpRequest, **_
|
||||||
):
|
):
|
||||||
"""Failed Login"""
|
"""Failed Login"""
|
||||||
thread = EventNewThread(EventAction.LOGIN_FAILED, request, **credentials)
|
thread = EventNewThread(EventAction.LOGIN_FAILED, request, **credentials)
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
"""event utilities"""
|
"""event utilities"""
|
||||||
import re
|
import re
|
||||||
from dataclasses import asdict, is_dataclass
|
from dataclasses import asdict, is_dataclass
|
||||||
from typing import Any, Dict, Optional
|
from typing import Any, Optional
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
|
|
||||||
from django.contrib.auth.models import AnonymousUser
|
from django.contrib.auth.models import AnonymousUser
|
||||||
|
@ -20,7 +20,7 @@ from authentik.policies.types import PolicyRequest
|
||||||
ALLOWED_SPECIAL_KEYS = re.compile("passing", flags=re.I)
|
ALLOWED_SPECIAL_KEYS = re.compile("passing", flags=re.I)
|
||||||
|
|
||||||
|
|
||||||
def cleanse_dict(source: Dict[Any, Any]) -> Dict[Any, Any]:
|
def cleanse_dict(source: dict[Any, Any]) -> dict[Any, Any]:
|
||||||
"""Cleanse a dictionary, recursively"""
|
"""Cleanse a dictionary, recursively"""
|
||||||
final_dict = {}
|
final_dict = {}
|
||||||
for key, value in source.items():
|
for key, value in source.items():
|
||||||
|
@ -38,7 +38,7 @@ def cleanse_dict(source: Dict[Any, Any]) -> Dict[Any, Any]:
|
||||||
return final_dict
|
return final_dict
|
||||||
|
|
||||||
|
|
||||||
def model_to_dict(model: Model) -> Dict[str, Any]:
|
def model_to_dict(model: Model) -> dict[str, Any]:
|
||||||
"""Convert model to dict"""
|
"""Convert model to dict"""
|
||||||
name = str(model)
|
name = str(model)
|
||||||
if hasattr(model, "name"):
|
if hasattr(model, "name"):
|
||||||
|
@ -51,7 +51,7 @@ def model_to_dict(model: Model) -> Dict[str, Any]:
|
||||||
}
|
}
|
||||||
|
|
||||||
|
|
||||||
def get_user(user: User, original_user: Optional[User] = None) -> Dict[str, Any]:
|
def get_user(user: User, original_user: Optional[User] = None) -> dict[str, Any]:
|
||||||
"""Convert user object to dictionary, optionally including the original user"""
|
"""Convert user object to dictionary, optionally including the original user"""
|
||||||
if isinstance(user, AnonymousUser):
|
if isinstance(user, AnonymousUser):
|
||||||
user = get_anonymous_user()
|
user = get_anonymous_user()
|
||||||
|
@ -67,7 +67,7 @@ def get_user(user: User, original_user: Optional[User] = None) -> Dict[str, Any]
|
||||||
return user_data
|
return user_data
|
||||||
|
|
||||||
|
|
||||||
def sanitize_dict(source: Dict[Any, Any]) -> Dict[Any, Any]:
|
def sanitize_dict(source: dict[Any, Any]) -> dict[Any, Any]:
|
||||||
"""clean source of all Models that would interfere with the JSONField.
|
"""clean source of all Models that would interfere with the JSONField.
|
||||||
Models are replaced with a dictionary of {
|
Models are replaced with a dictionary of {
|
||||||
app: str,
|
app: str,
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
"""Flows Planner"""
|
"""Flows Planner"""
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Optional
|
||||||
|
|
||||||
from django.core.cache import cache
|
from django.core.cache import cache
|
||||||
from django.http import HttpRequest
|
from django.http import HttpRequest
|
||||||
|
@ -38,9 +38,9 @@ class FlowPlan:
|
||||||
|
|
||||||
flow_pk: str
|
flow_pk: str
|
||||||
|
|
||||||
stages: List[Stage] = field(default_factory=list)
|
stages: list[Stage] = field(default_factory=list)
|
||||||
context: Dict[str, Any] = field(default_factory=dict)
|
context: dict[str, Any] = field(default_factory=dict)
|
||||||
markers: List[StageMarker] = field(default_factory=list)
|
markers: list[StageMarker] = field(default_factory=list)
|
||||||
|
|
||||||
def append(self, stage: Stage, marker: Optional[StageMarker] = None):
|
def append(self, stage: Stage, marker: Optional[StageMarker] = None):
|
||||||
"""Append `stage` to all stages, optionall with stage marker"""
|
"""Append `stage` to all stages, optionall with stage marker"""
|
||||||
|
@ -96,7 +96,7 @@ class FlowPlanner:
|
||||||
self._logger = get_logger().bind(flow=flow)
|
self._logger = get_logger().bind(flow=flow)
|
||||||
|
|
||||||
def plan(
|
def plan(
|
||||||
self, request: HttpRequest, default_context: Optional[Dict[str, Any]] = None
|
self, request: HttpRequest, default_context: Optional[dict[str, Any]] = None
|
||||||
) -> FlowPlan:
|
) -> FlowPlan:
|
||||||
"""Check each of the flows' policies, check policies for each stage with PolicyBinding
|
"""Check each of the flows' policies, check policies for each stage with PolicyBinding
|
||||||
and return ordered list"""
|
and return ordered list"""
|
||||||
|
@ -149,7 +149,7 @@ class FlowPlanner:
|
||||||
self,
|
self,
|
||||||
user: User,
|
user: User,
|
||||||
request: HttpRequest,
|
request: HttpRequest,
|
||||||
default_context: Optional[Dict[str, Any]],
|
default_context: Optional[dict[str, Any]],
|
||||||
) -> FlowPlan:
|
) -> FlowPlan:
|
||||||
"""Build flow plan by checking each stage in their respective
|
"""Build flow plan by checking each stage in their respective
|
||||||
order and checking the applied policies"""
|
order and checking the applied policies"""
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
"""authentik stage Base view"""
|
"""authentik stage Base view"""
|
||||||
from collections import namedtuple
|
from collections import namedtuple
|
||||||
from typing import Any, Dict
|
from typing import Any
|
||||||
|
|
||||||
from django.http import HttpRequest
|
from django.http import HttpRequest
|
||||||
from django.http.response import HttpResponse, JsonResponse
|
from django.http.response import HttpResponse, JsonResponse
|
||||||
|
@ -32,7 +32,7 @@ class StageView(TemplateView):
|
||||||
def __init__(self, executor: FlowExecutorView):
|
def __init__(self, executor: FlowExecutorView):
|
||||||
self.executor = executor
|
self.executor = executor
|
||||||
|
|
||||||
def get_context_data(self, **kwargs: Any) -> Dict[str, Any]:
|
def get_context_data(self, **kwargs: Any) -> dict[str, Any]:
|
||||||
kwargs["title"] = self.executor.flow.title
|
kwargs["title"] = self.executor.flow.title
|
||||||
# Either show the matched User object or show what the user entered,
|
# Either show the matched User object or show what the user entered,
|
||||||
# based on what the earlier stage (mostly IdentificationStage) set.
|
# based on what the earlier stage (mostly IdentificationStage) set.
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
"""transfer common classes"""
|
"""transfer common classes"""
|
||||||
from dataclasses import asdict, dataclass, field, is_dataclass
|
from dataclasses import asdict, dataclass, field, is_dataclass
|
||||||
from typing import Any, Dict, List
|
from typing import Any
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
|
|
||||||
from django.core.serializers.json import DjangoJSONEncoder
|
from django.core.serializers.json import DjangoJSONEncoder
|
||||||
|
@ -9,7 +9,7 @@ from authentik.lib.models import SerializerModel
|
||||||
from authentik.lib.sentry import SentryIgnoredException
|
from authentik.lib.sentry import SentryIgnoredException
|
||||||
|
|
||||||
|
|
||||||
def get_attrs(obj: SerializerModel) -> Dict[str, Any]:
|
def get_attrs(obj: SerializerModel) -> dict[str, Any]:
|
||||||
"""Get object's attributes via their serializer, and covert it to a normal dict"""
|
"""Get object's attributes via their serializer, and covert it to a normal dict"""
|
||||||
data = dict(obj.serializer(obj).data)
|
data = dict(obj.serializer(obj).data)
|
||||||
to_remove = (
|
to_remove = (
|
||||||
|
@ -33,9 +33,9 @@ def get_attrs(obj: SerializerModel) -> Dict[str, Any]:
|
||||||
class FlowBundleEntry:
|
class FlowBundleEntry:
|
||||||
"""Single entry of a bundle"""
|
"""Single entry of a bundle"""
|
||||||
|
|
||||||
identifiers: Dict[str, Any]
|
identifiers: dict[str, Any]
|
||||||
model: str
|
model: str
|
||||||
attrs: Dict[str, Any]
|
attrs: dict[str, Any]
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def from_model(
|
def from_model(
|
||||||
|
@ -61,7 +61,7 @@ class FlowBundle:
|
||||||
"""Dataclass used for a full export"""
|
"""Dataclass used for a full export"""
|
||||||
|
|
||||||
version: int = field(default=1)
|
version: int = field(default=1)
|
||||||
entries: List[FlowBundleEntry] = field(default_factory=list)
|
entries: list[FlowBundleEntry] = field(default_factory=list)
|
||||||
|
|
||||||
|
|
||||||
class DataclassEncoder(DjangoJSONEncoder):
|
class DataclassEncoder(DjangoJSONEncoder):
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
"""Flow exporter"""
|
"""Flow exporter"""
|
||||||
from json import dumps
|
from json import dumps
|
||||||
from typing import Iterator, List
|
from typing import Iterator
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
|
|
||||||
from django.db.models import Q
|
from django.db.models import Q
|
||||||
|
@ -22,7 +22,7 @@ class FlowExporter:
|
||||||
with_policies: bool
|
with_policies: bool
|
||||||
with_stage_prompts: bool
|
with_stage_prompts: bool
|
||||||
|
|
||||||
pbm_uuids: List[UUID]
|
pbm_uuids: list[UUID]
|
||||||
|
|
||||||
def __init__(self, flow: Flow):
|
def __init__(self, flow: Flow):
|
||||||
self.flow = flow
|
self.flow = flow
|
||||||
|
|
|
@ -2,7 +2,7 @@
|
||||||
from contextlib import contextmanager
|
from contextlib import contextmanager
|
||||||
from copy import deepcopy
|
from copy import deepcopy
|
||||||
from json import loads
|
from json import loads
|
||||||
from typing import Any, Dict, Type
|
from typing import Any, Type
|
||||||
|
|
||||||
from dacite import from_dict
|
from dacite import from_dict
|
||||||
from dacite.exceptions import DaciteError
|
from dacite.exceptions import DaciteError
|
||||||
|
@ -42,7 +42,7 @@ class FlowImporter:
|
||||||
|
|
||||||
__import: FlowBundle
|
__import: FlowBundle
|
||||||
|
|
||||||
__pk_map: Dict[Any, Model]
|
__pk_map: dict[Any, Model]
|
||||||
|
|
||||||
logger: BoundLogger
|
logger: BoundLogger
|
||||||
|
|
||||||
|
@ -55,7 +55,7 @@ class FlowImporter:
|
||||||
except DaciteError as exc:
|
except DaciteError as exc:
|
||||||
raise EntryInvalidError from exc
|
raise EntryInvalidError from exc
|
||||||
|
|
||||||
def __update_pks_for_attrs(self, attrs: Dict[str, Any]) -> Dict[str, Any]:
|
def __update_pks_for_attrs(self, attrs: dict[str, Any]) -> dict[str, Any]:
|
||||||
"""Replace any value if it is a known primary key of an other object"""
|
"""Replace any value if it is a known primary key of an other object"""
|
||||||
|
|
||||||
def updater(value) -> Any:
|
def updater(value) -> Any:
|
||||||
|
@ -75,7 +75,7 @@ class FlowImporter:
|
||||||
attrs[key] = updater(value)
|
attrs[key] = updater(value)
|
||||||
return attrs
|
return attrs
|
||||||
|
|
||||||
def __query_from_identifier(self, attrs: Dict[str, Any]) -> Q:
|
def __query_from_identifier(self, attrs: dict[str, Any]) -> Q:
|
||||||
"""Generate an or'd query from all identifiers in an entry"""
|
"""Generate an or'd query from all identifiers in an entry"""
|
||||||
# Since identifiers can also be pk-references to other objects (see FlowStageBinding)
|
# Since identifiers can also be pk-references to other objects (see FlowStageBinding)
|
||||||
# we have to ensure those references are also replaced
|
# we have to ensure those references are also replaced
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
"""authentik multi-stage authentication engine"""
|
"""authentik multi-stage authentication engine"""
|
||||||
from traceback import format_tb
|
from traceback import format_tb
|
||||||
from typing import Any, Dict, Optional
|
from typing import Any, Optional
|
||||||
|
|
||||||
from django.contrib.auth.mixins import LoginRequiredMixin
|
from django.contrib.auth.mixins import LoginRequiredMixin
|
||||||
from django.http import (
|
from django.http import (
|
||||||
|
@ -225,8 +225,8 @@ class FlowErrorResponse(TemplateResponse):
|
||||||
self.error = error
|
self.error = error
|
||||||
|
|
||||||
def resolve_context(
|
def resolve_context(
|
||||||
self, context: Optional[Dict[str, Any]]
|
self, context: Optional[dict[str, Any]]
|
||||||
) -> Optional[Dict[str, Any]]:
|
) -> Optional[dict[str, Any]]:
|
||||||
if not context:
|
if not context:
|
||||||
context = {}
|
context = {}
|
||||||
context["error"] = self.error
|
context["error"] = self.error
|
||||||
|
@ -244,7 +244,7 @@ class FlowExecutorShellView(TemplateView):
|
||||||
|
|
||||||
template_name = "flows/shell.html"
|
template_name = "flows/shell.html"
|
||||||
|
|
||||||
def get_context_data(self, **kwargs) -> Dict[str, Any]:
|
def get_context_data(self, **kwargs) -> dict[str, Any]:
|
||||||
flow: Flow = get_object_or_404(Flow, slug=self.kwargs.get("flow_slug"))
|
flow: Flow = get_object_or_404(Flow, slug=self.kwargs.get("flow_slug"))
|
||||||
kwargs["background_url"] = flow.background.url
|
kwargs["background_url"] = flow.background.url
|
||||||
kwargs["exec_url"] = reverse("authentik_api:flow-executor", kwargs=self.kwargs)
|
kwargs["exec_url"] = reverse("authentik_api:flow-executor", kwargs=self.kwargs)
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
"""authentik expression policy evaluator"""
|
"""authentik expression policy evaluator"""
|
||||||
import re
|
import re
|
||||||
from textwrap import indent
|
from textwrap import indent
|
||||||
from typing import Any, Dict, Iterable, Optional
|
from typing import Any, Iterable, Optional
|
||||||
|
|
||||||
from django.core.exceptions import ValidationError
|
from django.core.exceptions import ValidationError
|
||||||
from requests import Session
|
from requests import Session
|
||||||
|
@ -18,9 +18,9 @@ class BaseEvaluator:
|
||||||
"""Validate and evaluate python-based expressions"""
|
"""Validate and evaluate python-based expressions"""
|
||||||
|
|
||||||
# Globals that can be used by function
|
# Globals that can be used by function
|
||||||
_globals: Dict[str, Any]
|
_globals: dict[str, Any]
|
||||||
# Context passed as locals to exec()
|
# Context passed as locals to exec()
|
||||||
_context: Dict[str, Any]
|
_context: dict[str, Any]
|
||||||
|
|
||||||
# Filename used for exec
|
# Filename used for exec
|
||||||
_filename: str
|
_filename: str
|
||||||
|
|
|
@ -1,10 +1,10 @@
|
||||||
"""http helpers"""
|
"""http helpers"""
|
||||||
from typing import Any, Dict, Optional
|
from typing import Any, Optional
|
||||||
|
|
||||||
from django.http import HttpRequest
|
from django.http import HttpRequest
|
||||||
|
|
||||||
|
|
||||||
def _get_client_ip_from_meta(meta: Dict[str, Any]) -> Optional[str]:
|
def _get_client_ip_from_meta(meta: dict[str, Any]) -> Optional[str]:
|
||||||
"""Attempt to get the client's IP by checking common HTTP Headers.
|
"""Attempt to get the client's IP by checking common HTTP Headers.
|
||||||
Returns none if no IP Could be found"""
|
Returns none if no IP Could be found"""
|
||||||
headers = (
|
headers = (
|
||||||
|
|
|
@ -1,8 +1,8 @@
|
||||||
"""authentik UI utils"""
|
"""authentik UI utils"""
|
||||||
from typing import Any, List
|
from typing import Any
|
||||||
|
|
||||||
|
|
||||||
def human_list(_list: List[Any]) -> str:
|
def human_list(_list: list[Any]) -> str:
|
||||||
"""Convert a list of items into 'a, b or c'"""
|
"""Convert a list of items into 'a, b or c'"""
|
||||||
last_item = _list.pop()
|
last_item = _list.pop()
|
||||||
if len(_list) < 1:
|
if len(_list) < 1:
|
||||||
|
|
|
@ -2,7 +2,7 @@
|
||||||
from dataclasses import asdict, dataclass, field
|
from dataclasses import asdict, dataclass, field
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from enum import IntEnum
|
from enum import IntEnum
|
||||||
from typing import Any, Dict, Optional
|
from typing import Any, Optional
|
||||||
|
|
||||||
from channels.exceptions import DenyConnection
|
from channels.exceptions import DenyConnection
|
||||||
from dacite import from_dict
|
from dacite import from_dict
|
||||||
|
@ -34,7 +34,7 @@ class WebsocketMessage:
|
||||||
"""Complete Websocket Message that is being sent"""
|
"""Complete Websocket Message that is being sent"""
|
||||||
|
|
||||||
instruction: int
|
instruction: int
|
||||||
args: Dict[str, Any] = field(default_factory=dict)
|
args: dict[str, Any] = field(default_factory=dict)
|
||||||
|
|
||||||
|
|
||||||
class OutpostConsumer(AuthJsonConsumer):
|
class OutpostConsumer(AuthJsonConsumer):
|
||||||
|
|
|
@ -1,6 +1,5 @@
|
||||||
"""Docker controller"""
|
"""Docker controller"""
|
||||||
from time import sleep
|
from time import sleep
|
||||||
from typing import Dict, Tuple
|
|
||||||
|
|
||||||
from django.conf import settings
|
from django.conf import settings
|
||||||
from docker import DockerClient
|
from docker import DockerClient
|
||||||
|
@ -33,10 +32,10 @@ class DockerController(BaseController):
|
||||||
except ServiceConnectionInvalid as exc:
|
except ServiceConnectionInvalid as exc:
|
||||||
raise ControllerException from exc
|
raise ControllerException from exc
|
||||||
|
|
||||||
def _get_labels(self) -> Dict[str, str]:
|
def _get_labels(self) -> dict[str, str]:
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
def _get_env(self) -> Dict[str, str]:
|
def _get_env(self) -> dict[str, str]:
|
||||||
return {
|
return {
|
||||||
"AUTHENTIK_HOST": self.outpost.config.authentik_host,
|
"AUTHENTIK_HOST": self.outpost.config.authentik_host,
|
||||||
"AUTHENTIK_INSECURE": str(self.outpost.config.authentik_host_insecure),
|
"AUTHENTIK_INSECURE": str(self.outpost.config.authentik_host_insecure),
|
||||||
|
@ -55,7 +54,7 @@ class DockerController(BaseController):
|
||||||
return True
|
return True
|
||||||
return False
|
return False
|
||||||
|
|
||||||
def _get_container(self) -> Tuple[Container, bool]:
|
def _get_container(self) -> tuple[Container, bool]:
|
||||||
container_name = f"authentik-proxy-{self.outpost.uuid.hex}"
|
container_name = f"authentik-proxy-{self.outpost.uuid.hex}"
|
||||||
try:
|
try:
|
||||||
return self.client.containers.get(container_name), False
|
return self.client.containers.get(container_name), False
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
"""Kubernetes Deployment Reconciler"""
|
"""Kubernetes Deployment Reconciler"""
|
||||||
from typing import TYPE_CHECKING, Dict
|
from typing import TYPE_CHECKING
|
||||||
|
|
||||||
from kubernetes.client import (
|
from kubernetes.client import (
|
||||||
AppsV1Api,
|
AppsV1Api,
|
||||||
|
@ -53,7 +53,7 @@ class DeploymentReconciler(KubernetesObjectReconciler[V1Deployment]):
|
||||||
):
|
):
|
||||||
raise NeedsUpdate()
|
raise NeedsUpdate()
|
||||||
|
|
||||||
def get_pod_meta(self) -> Dict[str, str]:
|
def get_pod_meta(self) -> dict[str, str]:
|
||||||
"""Get common object metadata"""
|
"""Get common object metadata"""
|
||||||
return {
|
return {
|
||||||
"app.kubernetes.io/name": "authentik-outpost",
|
"app.kubernetes.io/name": "authentik-outpost",
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
"""Kubernetes deployment controller"""
|
"""Kubernetes deployment controller"""
|
||||||
from io import StringIO
|
from io import StringIO
|
||||||
from typing import Dict, List, Type
|
from typing import Type
|
||||||
|
|
||||||
from kubernetes.client import OpenApiException
|
from kubernetes.client import OpenApiException
|
||||||
from kubernetes.client.api_client import ApiClient
|
from kubernetes.client.api_client import ApiClient
|
||||||
|
@ -18,8 +18,8 @@ from authentik.outposts.models import KubernetesServiceConnection, Outpost
|
||||||
class KubernetesController(BaseController):
|
class KubernetesController(BaseController):
|
||||||
"""Manage deployment of outpost in kubernetes"""
|
"""Manage deployment of outpost in kubernetes"""
|
||||||
|
|
||||||
reconcilers: Dict[str, Type[KubernetesObjectReconciler]]
|
reconcilers: dict[str, Type[KubernetesObjectReconciler]]
|
||||||
reconcile_order: List[str]
|
reconcile_order: list[str]
|
||||||
|
|
||||||
client: ApiClient
|
client: ApiClient
|
||||||
connection: KubernetesServiceConnection
|
connection: KubernetesServiceConnection
|
||||||
|
@ -45,7 +45,7 @@ class KubernetesController(BaseController):
|
||||||
except OpenApiException as exc:
|
except OpenApiException as exc:
|
||||||
raise ControllerException from exc
|
raise ControllerException from exc
|
||||||
|
|
||||||
def up_with_logs(self) -> List[str]:
|
def up_with_logs(self) -> list[str]:
|
||||||
try:
|
try:
|
||||||
all_logs = []
|
all_logs = []
|
||||||
for reconcile_key in self.reconcile_order:
|
for reconcile_key in self.reconcile_order:
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
"""Outpost models"""
|
"""Outpost models"""
|
||||||
from dataclasses import asdict, dataclass, field
|
from dataclasses import asdict, dataclass, field
|
||||||
from datetime import datetime
|
from datetime import datetime
|
||||||
from typing import Dict, Iterable, List, Optional, Type, Union
|
from typing import Iterable, Optional, Type, Union
|
||||||
from uuid import uuid4
|
from uuid import uuid4
|
||||||
|
|
||||||
from dacite import from_dict
|
from dacite import from_dict
|
||||||
|
@ -58,7 +58,7 @@ class OutpostConfig:
|
||||||
|
|
||||||
kubernetes_replicas: int = field(default=1)
|
kubernetes_replicas: int = field(default=1)
|
||||||
kubernetes_namespace: str = field(default="default")
|
kubernetes_namespace: str = field(default="default")
|
||||||
kubernetes_ingress_annotations: Dict[str, str] = field(default_factory=dict)
|
kubernetes_ingress_annotations: dict[str, str] = field(default_factory=dict)
|
||||||
kubernetes_ingress_secret_name: str = field(default="authentik-outpost")
|
kubernetes_ingress_secret_name: str = field(default="authentik-outpost")
|
||||||
|
|
||||||
|
|
||||||
|
@ -315,7 +315,7 @@ class Outpost(models.Model):
|
||||||
return f"outpost_{self.uuid.hex}_state"
|
return f"outpost_{self.uuid.hex}_state"
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def state(self) -> List["OutpostState"]:
|
def state(self) -> list["OutpostState"]:
|
||||||
"""Get outpost's health status"""
|
"""Get outpost's health status"""
|
||||||
return OutpostState.for_outpost(self)
|
return OutpostState.for_outpost(self)
|
||||||
|
|
||||||
|
@ -399,7 +399,7 @@ class OutpostState:
|
||||||
return parse(self.version) < OUR_VERSION
|
return parse(self.version) < OUR_VERSION
|
||||||
|
|
||||||
@staticmethod
|
@staticmethod
|
||||||
def for_outpost(outpost: Outpost) -> List["OutpostState"]:
|
def for_outpost(outpost: Outpost) -> list["OutpostState"]:
|
||||||
"""Get all states for an outpost"""
|
"""Get all states for an outpost"""
|
||||||
keys = cache.keys(f"{outpost.state_cache_prefix}_*")
|
keys = cache.keys(f"{outpost.state_cache_prefix}_*")
|
||||||
states = []
|
states = []
|
||||||
|
|
|
@ -2,7 +2,7 @@
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from multiprocessing import Pipe, current_process
|
from multiprocessing import Pipe, current_process
|
||||||
from multiprocessing.connection import Connection
|
from multiprocessing.connection import Connection
|
||||||
from typing import Iterator, List, Optional
|
from typing import Iterator, Optional
|
||||||
|
|
||||||
from django.core.cache import cache
|
from django.core.cache import cache
|
||||||
from django.http import HttpRequest
|
from django.http import HttpRequest
|
||||||
|
@ -54,8 +54,8 @@ class PolicyEngine:
|
||||||
empty_result: bool
|
empty_result: bool
|
||||||
|
|
||||||
__pbm: PolicyBindingModel
|
__pbm: PolicyBindingModel
|
||||||
__cached_policies: List[PolicyResult]
|
__cached_policies: list[PolicyResult]
|
||||||
__processes: List[PolicyProcessInfo]
|
__processes: list[PolicyProcessInfo]
|
||||||
|
|
||||||
__expected_result_count: int
|
__expected_result_count: int
|
||||||
|
|
||||||
|
@ -137,7 +137,7 @@ class PolicyEngine:
|
||||||
@property
|
@property
|
||||||
def result(self) -> PolicyResult:
|
def result(self) -> PolicyResult:
|
||||||
"""Get policy-checking result"""
|
"""Get policy-checking result"""
|
||||||
process_results: List[PolicyResult] = [
|
process_results: list[PolicyResult] = [
|
||||||
x.result for x in self.__processes if x.result
|
x.result for x in self.__processes if x.result
|
||||||
]
|
]
|
||||||
all_results = list(process_results + self.__cached_policies)
|
all_results = list(process_results + self.__cached_policies)
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
"""authentik expression policy evaluator"""
|
"""authentik expression policy evaluator"""
|
||||||
from ipaddress import ip_address, ip_network
|
from ipaddress import ip_address, ip_network
|
||||||
from typing import TYPE_CHECKING, List, Optional
|
from typing import TYPE_CHECKING, Optional
|
||||||
|
|
||||||
from django.http import HttpRequest
|
from django.http import HttpRequest
|
||||||
from structlog.stdlib import get_logger
|
from structlog.stdlib import get_logger
|
||||||
|
@ -19,7 +19,7 @@ if TYPE_CHECKING:
|
||||||
class PolicyEvaluator(BaseEvaluator):
|
class PolicyEvaluator(BaseEvaluator):
|
||||||
"""Validate and evaluate python-based expressions"""
|
"""Validate and evaluate python-based expressions"""
|
||||||
|
|
||||||
_messages: List[str]
|
_messages: list[str]
|
||||||
|
|
||||||
policy: Optional["ExpressionPolicy"] = None
|
policy: Optional["ExpressionPolicy"] = None
|
||||||
|
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
"""policy http response"""
|
"""policy http response"""
|
||||||
from typing import Any, Dict, Optional
|
from typing import Any, Optional
|
||||||
|
|
||||||
from django.http.request import HttpRequest
|
from django.http.request import HttpRequest
|
||||||
from django.template.response import TemplateResponse
|
from django.template.response import TemplateResponse
|
||||||
|
@ -24,8 +24,8 @@ class AccessDeniedResponse(TemplateResponse):
|
||||||
self.title = _("Access denied")
|
self.title = _("Access denied")
|
||||||
|
|
||||||
def resolve_context(
|
def resolve_context(
|
||||||
self, context: Optional[Dict[str, Any]]
|
self, context: Optional[dict[str, Any]]
|
||||||
) -> Optional[Dict[str, Any]]:
|
) -> Optional[dict[str, Any]]:
|
||||||
if not context:
|
if not context:
|
||||||
context = {}
|
context = {}
|
||||||
context["title"] = self.title
|
context["title"] = self.title
|
||||||
|
|
|
@ -1,8 +1,8 @@
|
||||||
"""Policy Utils"""
|
"""Policy Utils"""
|
||||||
from typing import Any, Dict
|
from typing import Any
|
||||||
|
|
||||||
|
|
||||||
def delete_none_keys(dict_: Dict[Any, Any]) -> Dict[Any, Any]:
|
def delete_none_keys(dict_: dict[Any, Any]) -> dict[Any, Any]:
|
||||||
"""Remove any keys from `dict_` that are None."""
|
"""Remove any keys from `dict_` that are None."""
|
||||||
new_dict = {}
|
new_dict = {}
|
||||||
for key, value in dict_.items():
|
for key, value in dict_.items():
|
||||||
|
|
|
@ -5,7 +5,7 @@ import json
|
||||||
import time
|
import time
|
||||||
from dataclasses import asdict, dataclass, field
|
from dataclasses import asdict, dataclass, field
|
||||||
from hashlib import sha256
|
from hashlib import sha256
|
||||||
from typing import Any, Dict, List, Optional, Type
|
from typing import Any, Optional, Type
|
||||||
from urllib.parse import urlparse
|
from urllib.parse import urlparse
|
||||||
from uuid import uuid4
|
from uuid import uuid4
|
||||||
|
|
||||||
|
@ -218,7 +218,7 @@ class OAuth2Provider(Provider):
|
||||||
)
|
)
|
||||||
|
|
||||||
def create_refresh_token(
|
def create_refresh_token(
|
||||||
self, user: User, scope: List[str], request: HttpRequest
|
self, user: User, scope: list[str], request: HttpRequest
|
||||||
) -> "RefreshToken":
|
) -> "RefreshToken":
|
||||||
"""Create and populate a RefreshToken object."""
|
"""Create and populate a RefreshToken object."""
|
||||||
token = RefreshToken(
|
token = RefreshToken(
|
||||||
|
@ -231,7 +231,7 @@ class OAuth2Provider(Provider):
|
||||||
token.access_token = token.create_access_token(user, request)
|
token.access_token = token.create_access_token(user, request)
|
||||||
return token
|
return token
|
||||||
|
|
||||||
def get_jwt_keys(self) -> List[Key]:
|
def get_jwt_keys(self) -> list[Key]:
|
||||||
"""
|
"""
|
||||||
Takes a provider and returns the set of keys associated with it.
|
Takes a provider and returns the set of keys associated with it.
|
||||||
Returns a list of keys.
|
Returns a list of keys.
|
||||||
|
@ -299,7 +299,7 @@ class OAuth2Provider(Provider):
|
||||||
def __str__(self):
|
def __str__(self):
|
||||||
return f"OAuth2 Provider {self.name}"
|
return f"OAuth2 Provider {self.name}"
|
||||||
|
|
||||||
def encode(self, payload: Dict[str, Any]) -> str:
|
def encode(self, payload: dict[str, Any]) -> str:
|
||||||
"""Represent the ID Token as a JSON Web Token (JWT)."""
|
"""Represent the ID Token as a JSON Web Token (JWT)."""
|
||||||
keys = self.get_jwt_keys()
|
keys = self.get_jwt_keys()
|
||||||
# If the provider does not have an RSA Key assigned, it was switched to Symmetric
|
# If the provider does not have an RSA Key assigned, it was switched to Symmetric
|
||||||
|
@ -321,7 +321,7 @@ class BaseGrantModel(models.Model):
|
||||||
_scope = models.TextField(default="", verbose_name=_("Scopes"))
|
_scope = models.TextField(default="", verbose_name=_("Scopes"))
|
||||||
|
|
||||||
@property
|
@property
|
||||||
def scope(self) -> List[str]:
|
def scope(self) -> list[str]:
|
||||||
"""Return scopes as list of strings"""
|
"""Return scopes as list of strings"""
|
||||||
return self._scope.split()
|
return self._scope.split()
|
||||||
|
|
||||||
|
@ -394,9 +394,9 @@ class IDToken:
|
||||||
nonce: Optional[str] = None
|
nonce: Optional[str] = None
|
||||||
at_hash: Optional[str] = None
|
at_hash: Optional[str] = None
|
||||||
|
|
||||||
claims: Dict[str, Any] = field(default_factory=dict)
|
claims: dict[str, Any] = field(default_factory=dict)
|
||||||
|
|
||||||
def to_dict(self) -> Dict[str, Any]:
|
def to_dict(self) -> dict[str, Any]:
|
||||||
"""Convert dataclass to dict, and update with keys from `claims`"""
|
"""Convert dataclass to dict, and update with keys from `claims`"""
|
||||||
dic = asdict(self)
|
dic = asdict(self)
|
||||||
dic.pop("claims")
|
dic.pop("claims")
|
||||||
|
|
|
@ -2,7 +2,7 @@
|
||||||
import re
|
import re
|
||||||
from base64 import b64decode
|
from base64 import b64decode
|
||||||
from binascii import Error
|
from binascii import Error
|
||||||
from typing import List, Optional, Tuple
|
from typing import Optional
|
||||||
|
|
||||||
from django.http import HttpRequest, HttpResponse, JsonResponse
|
from django.http import HttpRequest, HttpResponse, JsonResponse
|
||||||
from django.utils.cache import patch_vary_headers
|
from django.utils.cache import patch_vary_headers
|
||||||
|
@ -68,7 +68,7 @@ def extract_access_token(request: HttpRequest) -> Optional[str]:
|
||||||
return None
|
return None
|
||||||
|
|
||||||
|
|
||||||
def extract_client_auth(request: HttpRequest) -> Tuple[str, str]:
|
def extract_client_auth(request: HttpRequest) -> tuple[str, str]:
|
||||||
"""
|
"""
|
||||||
Get client credentials using HTTP Basic Authentication method.
|
Get client credentials using HTTP Basic Authentication method.
|
||||||
Or try getting parameters via POST.
|
Or try getting parameters via POST.
|
||||||
|
@ -92,7 +92,7 @@ def extract_client_auth(request: HttpRequest) -> Tuple[str, str]:
|
||||||
return (client_id, client_secret)
|
return (client_id, client_secret)
|
||||||
|
|
||||||
|
|
||||||
def protected_resource_view(scopes: List[str]):
|
def protected_resource_view(scopes: list[str]):
|
||||||
"""View decorator. The client accesses protected resources by presenting the
|
"""View decorator. The client accesses protected resources by presenting the
|
||||||
access token to the resource server.
|
access token to the resource server.
|
||||||
|
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
"""authentik OAuth2 Authorization views"""
|
"""authentik OAuth2 Authorization views"""
|
||||||
from dataclasses import dataclass, field
|
from dataclasses import dataclass, field
|
||||||
from datetime import timedelta
|
from datetime import timedelta
|
||||||
from typing import List, Optional, Set
|
from typing import Optional
|
||||||
from urllib.parse import parse_qs, urlencode, urlsplit, urlunsplit
|
from urllib.parse import parse_qs, urlencode, urlsplit, urlunsplit
|
||||||
from uuid import uuid4
|
from uuid import uuid4
|
||||||
|
|
||||||
|
@ -69,10 +69,10 @@ class OAuthAuthorizationParams:
|
||||||
client_id: str
|
client_id: str
|
||||||
redirect_uri: str
|
redirect_uri: str
|
||||||
response_type: str
|
response_type: str
|
||||||
scope: List[str]
|
scope: list[str]
|
||||||
state: str
|
state: str
|
||||||
nonce: Optional[str]
|
nonce: Optional[str]
|
||||||
prompt: Set[str]
|
prompt: set[str]
|
||||||
grant_type: str
|
grant_type: str
|
||||||
|
|
||||||
provider: OAuth2Provider = field(default_factory=OAuth2Provider)
|
provider: OAuth2Provider = field(default_factory=OAuth2Provider)
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
"""authentik OAuth2 OpenID well-known views"""
|
"""authentik OAuth2 OpenID well-known views"""
|
||||||
from typing import Any, Dict
|
from typing import Any
|
||||||
|
|
||||||
from django.http import HttpRequest, HttpResponse, JsonResponse
|
from django.http import HttpRequest, HttpResponse, JsonResponse
|
||||||
from django.shortcuts import get_object_or_404, reverse
|
from django.shortcuts import get_object_or_404, reverse
|
||||||
|
@ -29,7 +29,7 @@ PLAN_CONTEXT_SCOPES = "scopes"
|
||||||
class ProviderInfoView(View):
|
class ProviderInfoView(View):
|
||||||
"""OpenID-compliant Provider Info"""
|
"""OpenID-compliant Provider Info"""
|
||||||
|
|
||||||
def get_info(self, provider: OAuth2Provider) -> Dict[str, Any]:
|
def get_info(self, provider: OAuth2Provider) -> dict[str, Any]:
|
||||||
"""Get dictionary for OpenID Connect information"""
|
"""Get dictionary for OpenID Connect information"""
|
||||||
scopes = list(
|
scopes = list(
|
||||||
ScopeMapping.objects.filter(provider=provider).values_list(
|
ScopeMapping.objects.filter(provider=provider).values_list(
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
"""authentik OAuth2 Session Views"""
|
"""authentik OAuth2 Session Views"""
|
||||||
from typing import Any, Dict
|
from typing import Any
|
||||||
|
|
||||||
from django.shortcuts import get_object_or_404
|
from django.shortcuts import get_object_or_404
|
||||||
from django.views.generic.base import TemplateView
|
from django.views.generic.base import TemplateView
|
||||||
|
@ -12,7 +12,7 @@ class EndSessionView(TemplateView):
|
||||||
|
|
||||||
template_name = "providers/oauth2/end_session.html"
|
template_name = "providers/oauth2/end_session.html"
|
||||||
|
|
||||||
def get_context_data(self, **kwargs: Any) -> Dict[str, Any]:
|
def get_context_data(self, **kwargs: Any) -> dict[str, Any]:
|
||||||
context = super().get_context_data(**kwargs)
|
context = super().get_context_data(**kwargs)
|
||||||
|
|
||||||
context["application"] = get_object_or_404(
|
context["application"] = get_object_or_404(
|
||||||
|
|
|
@ -2,7 +2,7 @@
|
||||||
from base64 import urlsafe_b64encode
|
from base64 import urlsafe_b64encode
|
||||||
from dataclasses import InitVar, dataclass
|
from dataclasses import InitVar, dataclass
|
||||||
from hashlib import sha256
|
from hashlib import sha256
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Optional
|
||||||
|
|
||||||
from django.http import HttpRequest, HttpResponse
|
from django.http import HttpRequest, HttpResponse
|
||||||
from django.views import View
|
from django.views import View
|
||||||
|
@ -33,7 +33,7 @@ class TokenParams:
|
||||||
redirect_uri: str
|
redirect_uri: str
|
||||||
grant_type: str
|
grant_type: str
|
||||||
state: str
|
state: str
|
||||||
scope: List[str]
|
scope: list[str]
|
||||||
|
|
||||||
authorization_code: Optional[AuthorizationCode] = None
|
authorization_code: Optional[AuthorizationCode] = None
|
||||||
refresh_token: Optional[RefreshToken] = None
|
refresh_token: Optional[RefreshToken] = None
|
||||||
|
@ -171,7 +171,7 @@ class TokenView(View):
|
||||||
except UserAuthError as error:
|
except UserAuthError as error:
|
||||||
return TokenResponse(error.create_dict(), status=403)
|
return TokenResponse(error.create_dict(), status=403)
|
||||||
|
|
||||||
def create_code_response_dic(self) -> Dict[str, Any]:
|
def create_code_response_dic(self) -> dict[str, Any]:
|
||||||
"""See https://tools.ietf.org/html/rfc6749#section-4.1"""
|
"""See https://tools.ietf.org/html/rfc6749#section-4.1"""
|
||||||
|
|
||||||
refresh_token = self.params.authorization_code.provider.create_refresh_token(
|
refresh_token = self.params.authorization_code.provider.create_refresh_token(
|
||||||
|
@ -207,7 +207,7 @@ class TokenView(View):
|
||||||
|
|
||||||
return response_dict
|
return response_dict
|
||||||
|
|
||||||
def create_refresh_response_dic(self) -> Dict[str, Any]:
|
def create_refresh_response_dic(self) -> dict[str, Any]:
|
||||||
"""See https://tools.ietf.org/html/rfc6749#section-6"""
|
"""See https://tools.ietf.org/html/rfc6749#section-6"""
|
||||||
|
|
||||||
unauthorized_scopes = set(self.params.scope) - set(
|
unauthorized_scopes = set(self.params.scope) - set(
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
"""authentik OAuth2 OpenID Userinfo views"""
|
"""authentik OAuth2 OpenID Userinfo views"""
|
||||||
from typing import Any, Dict, List
|
from typing import Any
|
||||||
|
|
||||||
from django.http import HttpRequest, HttpResponse
|
from django.http import HttpRequest, HttpResponse
|
||||||
from django.utils.translation import gettext_lazy as _
|
from django.utils.translation import gettext_lazy as _
|
||||||
|
@ -22,7 +22,7 @@ class UserInfoView(View):
|
||||||
"""Create a dictionary with all the requested claims about the End-User.
|
"""Create a dictionary with all the requested claims about the End-User.
|
||||||
See: http://openid.net/specs/openid-connect-core-1_0.html#UserInfoResponse"""
|
See: http://openid.net/specs/openid-connect-core-1_0.html#UserInfoResponse"""
|
||||||
|
|
||||||
def get_scope_descriptions(self, scopes: List[str]) -> Dict[str, str]:
|
def get_scope_descriptions(self, scopes: list[str]) -> dict[str, str]:
|
||||||
"""Get a list of all Scopes's descriptions"""
|
"""Get a list of all Scopes's descriptions"""
|
||||||
scope_descriptions = {}
|
scope_descriptions = {}
|
||||||
for scope in ScopeMapping.objects.filter(scope_name__in=scopes).order_by(
|
for scope in ScopeMapping.objects.filter(scope_name__in=scopes).order_by(
|
||||||
|
@ -47,7 +47,7 @@ class UserInfoView(View):
|
||||||
scope_descriptions[scope] = github_scope_map[scope]
|
scope_descriptions[scope] = github_scope_map[scope]
|
||||||
return scope_descriptions
|
return scope_descriptions
|
||||||
|
|
||||||
def get_claims(self, token: RefreshToken) -> Dict[str, Any]:
|
def get_claims(self, token: RefreshToken) -> dict[str, Any]:
|
||||||
"""Get a dictionary of claims from scopes that the token
|
"""Get a dictionary of claims from scopes that the token
|
||||||
requires and are assigned to the provider."""
|
requires and are assigned to the provider."""
|
||||||
|
|
||||||
|
|
|
@ -1,5 +1,4 @@
|
||||||
"""Proxy Provider Docker Contoller"""
|
"""Proxy Provider Docker Contoller"""
|
||||||
from typing import Dict
|
|
||||||
from urllib.parse import urlparse
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
from authentik.outposts.controllers.base import DeploymentPort
|
from authentik.outposts.controllers.base import DeploymentPort
|
||||||
|
@ -18,7 +17,7 @@ class ProxyDockerController(DockerController):
|
||||||
DeploymentPort(4443, "https", "tcp"),
|
DeploymentPort(4443, "https", "tcp"),
|
||||||
]
|
]
|
||||||
|
|
||||||
def _get_labels(self) -> Dict[str, str]:
|
def _get_labels(self) -> dict[str, str]:
|
||||||
hosts = []
|
hosts = []
|
||||||
for proxy_provider in ProxyProvider.objects.filter(outpost__in=[self.outpost]):
|
for proxy_provider in ProxyProvider.objects.filter(outpost__in=[self.outpost]):
|
||||||
proxy_provider: ProxyProvider
|
proxy_provider: ProxyProvider
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
"""Kubernetes Ingress Reconciler"""
|
"""Kubernetes Ingress Reconciler"""
|
||||||
from typing import TYPE_CHECKING, Dict
|
from typing import TYPE_CHECKING
|
||||||
from urllib.parse import urlparse
|
from urllib.parse import urlparse
|
||||||
|
|
||||||
from kubernetes.client import (
|
from kubernetes.client import (
|
||||||
|
@ -78,7 +78,7 @@ class IngressReconciler(KubernetesObjectReconciler[NetworkingV1beta1Ingress]):
|
||||||
if have_hosts_tls != expected_hosts_tls:
|
if have_hosts_tls != expected_hosts_tls:
|
||||||
raise NeedsUpdate()
|
raise NeedsUpdate()
|
||||||
|
|
||||||
def get_ingress_annotations(self) -> Dict[str, str]:
|
def get_ingress_annotations(self) -> dict[str, str]:
|
||||||
"""Get ingress annotations"""
|
"""Get ingress annotations"""
|
||||||
annotations = {
|
annotations = {
|
||||||
# Ensure that with multiple proxy replicas deployed, the same CSRF request
|
# Ensure that with multiple proxy replicas deployed, the same CSRF request
|
||||||
|
|
|
@ -8,7 +8,7 @@ https://docs.djangoproject.com/en/3.0/howto/deployment/asgi/
|
||||||
"""
|
"""
|
||||||
import typing
|
import typing
|
||||||
from time import time
|
from time import time
|
||||||
from typing import Any, ByteString, Dict
|
from typing import Any, ByteString
|
||||||
|
|
||||||
import django
|
import django
|
||||||
from asgiref.compatibility import guarantee_single_callable
|
from asgiref.compatibility import guarantee_single_callable
|
||||||
|
@ -64,7 +64,7 @@ class ASGILogger:
|
||||||
app: ASGIApp
|
app: ASGIApp
|
||||||
|
|
||||||
scope: Scope
|
scope: Scope
|
||||||
headers: Dict[ByteString, Any]
|
headers: dict[ByteString, Any]
|
||||||
|
|
||||||
status_code: int
|
status_code: int
|
||||||
start: float
|
start: float
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
"""authentik ldap source signals"""
|
"""authentik ldap source signals"""
|
||||||
from typing import Any, Dict
|
from typing import Any
|
||||||
|
|
||||||
from django.core.exceptions import ValidationError
|
from django.core.exceptions import ValidationError
|
||||||
from django.db.models.signals import post_save
|
from django.db.models.signals import post_save
|
||||||
|
@ -26,7 +26,7 @@ def sync_ldap_source_on_save(sender, instance: LDAPSource, **_):
|
||||||
|
|
||||||
@receiver(password_validate)
|
@receiver(password_validate)
|
||||||
# pylint: disable=unused-argument
|
# pylint: disable=unused-argument
|
||||||
def ldap_password_validate(sender, password: str, plan_context: Dict[str, Any], **__):
|
def ldap_password_validate(sender, password: str, plan_context: dict[str, Any], **__):
|
||||||
"""if there's an LDAP Source with enabled password sync, check the password"""
|
"""if there's an LDAP Source with enabled password sync, check the password"""
|
||||||
sources = LDAPSource.objects.filter(sync_users_password=True)
|
sources = LDAPSource.objects.filter(sync_users_password=True)
|
||||||
if not sources.exists():
|
if not sources.exists():
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
"""OAuth Clients"""
|
"""OAuth Clients"""
|
||||||
from typing import Any, Dict, Optional
|
from typing import Any, Optional
|
||||||
from urllib.parse import urlencode
|
from urllib.parse import urlencode
|
||||||
|
|
||||||
from django.http import HttpRequest
|
from django.http import HttpRequest
|
||||||
|
@ -33,11 +33,11 @@ class BaseOAuthClient:
|
||||||
self.callback = callback
|
self.callback = callback
|
||||||
self.session.headers.update({"User-Agent": f"authentik {__version__}"})
|
self.session.headers.update({"User-Agent": f"authentik {__version__}"})
|
||||||
|
|
||||||
def get_access_token(self, **request_kwargs) -> Optional[Dict[str, Any]]:
|
def get_access_token(self, **request_kwargs) -> Optional[dict[str, Any]]:
|
||||||
"Fetch access token from callback request."
|
"Fetch access token from callback request."
|
||||||
raise NotImplementedError("Defined in a sub-class") # pragma: no cover
|
raise NotImplementedError("Defined in a sub-class") # pragma: no cover
|
||||||
|
|
||||||
def get_profile_info(self, token: Dict[str, str]) -> Optional[Dict[str, Any]]:
|
def get_profile_info(self, token: dict[str, str]) -> Optional[dict[str, Any]]:
|
||||||
"Fetch user profile information."
|
"Fetch user profile information."
|
||||||
try:
|
try:
|
||||||
response = self.do_request("get", self.source.profile_url, token=token)
|
response = self.do_request("get", self.source.profile_url, token=token)
|
||||||
|
@ -48,7 +48,7 @@ class BaseOAuthClient:
|
||||||
else:
|
else:
|
||||||
return response.json()
|
return response.json()
|
||||||
|
|
||||||
def get_redirect_args(self) -> Dict[str, str]:
|
def get_redirect_args(self) -> dict[str, str]:
|
||||||
"Get request parameters for redirect url."
|
"Get request parameters for redirect url."
|
||||||
raise NotImplementedError("Defined in a sub-class") # pragma: no cover
|
raise NotImplementedError("Defined in a sub-class") # pragma: no cover
|
||||||
|
|
||||||
|
@ -61,7 +61,7 @@ class BaseOAuthClient:
|
||||||
LOGGER.info("redirect args", **args)
|
LOGGER.info("redirect args", **args)
|
||||||
return f"{self.source.authorization_url}?{params}"
|
return f"{self.source.authorization_url}?{params}"
|
||||||
|
|
||||||
def parse_raw_token(self, raw_token: str) -> Dict[str, Any]:
|
def parse_raw_token(self, raw_token: str) -> dict[str, Any]:
|
||||||
"Parse token and secret from raw token response."
|
"Parse token and secret from raw token response."
|
||||||
raise NotImplementedError("Defined in a sub-class") # pragma: no cover
|
raise NotImplementedError("Defined in a sub-class") # pragma: no cover
|
||||||
|
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
"""OAuth 1 Clients"""
|
"""OAuth 1 Clients"""
|
||||||
from typing import Any, Dict, Optional
|
from typing import Any, Optional
|
||||||
from urllib.parse import parse_qsl
|
from urllib.parse import parse_qsl
|
||||||
|
|
||||||
from requests.exceptions import RequestException
|
from requests.exceptions import RequestException
|
||||||
|
@ -20,7 +20,7 @@ class OAuthClient(BaseOAuthClient):
|
||||||
"Accept": "application/json",
|
"Accept": "application/json",
|
||||||
}
|
}
|
||||||
|
|
||||||
def get_access_token(self, **request_kwargs) -> Optional[Dict[str, Any]]:
|
def get_access_token(self, **request_kwargs) -> Optional[dict[str, Any]]:
|
||||||
"Fetch access token from callback request."
|
"Fetch access token from callback request."
|
||||||
raw_token = self.request.session.get(self.session_key, None)
|
raw_token = self.request.session.get(self.session_key, None)
|
||||||
verifier = self.request.GET.get("oauth_verifier", None)
|
verifier = self.request.GET.get("oauth_verifier", None)
|
||||||
|
@ -60,7 +60,7 @@ class OAuthClient(BaseOAuthClient):
|
||||||
else:
|
else:
|
||||||
return response.text
|
return response.text
|
||||||
|
|
||||||
def get_redirect_args(self) -> Dict[str, Any]:
|
def get_redirect_args(self) -> dict[str, Any]:
|
||||||
"Get request parameters for redirect url."
|
"Get request parameters for redirect url."
|
||||||
callback = self.request.build_absolute_uri(self.callback)
|
callback = self.request.build_absolute_uri(self.callback)
|
||||||
raw_token = self.get_request_token()
|
raw_token = self.get_request_token()
|
||||||
|
@ -71,7 +71,7 @@ class OAuthClient(BaseOAuthClient):
|
||||||
"oauth_callback": callback,
|
"oauth_callback": callback,
|
||||||
}
|
}
|
||||||
|
|
||||||
def parse_raw_token(self, raw_token: str) -> Dict[str, Any]:
|
def parse_raw_token(self, raw_token: str) -> dict[str, Any]:
|
||||||
"Parse token and secret from raw token response."
|
"Parse token and secret from raw token response."
|
||||||
return dict(parse_qsl(raw_token))
|
return dict(parse_qsl(raw_token))
|
||||||
|
|
||||||
|
@ -80,7 +80,7 @@ class OAuthClient(BaseOAuthClient):
|
||||||
resource_owner_key = None
|
resource_owner_key = None
|
||||||
resource_owner_secret = None
|
resource_owner_secret = None
|
||||||
if "token" in kwargs:
|
if "token" in kwargs:
|
||||||
user_token: Dict[str, Any] = kwargs.pop("token")
|
user_token: dict[str, Any] = kwargs.pop("token")
|
||||||
resource_owner_key = user_token["oauth_token"]
|
resource_owner_key = user_token["oauth_token"]
|
||||||
resource_owner_secret = user_token["oauth_token_secret"]
|
resource_owner_secret = user_token["oauth_token_secret"]
|
||||||
|
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
"""OAuth 2 Clients"""
|
"""OAuth 2 Clients"""
|
||||||
from json import loads
|
from json import loads
|
||||||
from typing import Any, Dict, Optional
|
from typing import Any, Optional
|
||||||
from urllib.parse import parse_qsl
|
from urllib.parse import parse_qsl
|
||||||
|
|
||||||
from django.utils.crypto import constant_time_compare, get_random_string
|
from django.utils.crypto import constant_time_compare, get_random_string
|
||||||
|
@ -38,7 +38,7 @@ class OAuth2Client(BaseOAuthClient):
|
||||||
"Generate state optional parameter."
|
"Generate state optional parameter."
|
||||||
return get_random_string(32)
|
return get_random_string(32)
|
||||||
|
|
||||||
def get_access_token(self, **request_kwargs) -> Optional[Dict[str, Any]]:
|
def get_access_token(self, **request_kwargs) -> Optional[dict[str, Any]]:
|
||||||
"Fetch access token from callback request."
|
"Fetch access token from callback request."
|
||||||
callback = self.request.build_absolute_uri(self.callback or self.request.path)
|
callback = self.request.build_absolute_uri(self.callback or self.request.path)
|
||||||
if not self.check_application_state():
|
if not self.check_application_state():
|
||||||
|
@ -69,11 +69,11 @@ class OAuth2Client(BaseOAuthClient):
|
||||||
else:
|
else:
|
||||||
return response.json()
|
return response.json()
|
||||||
|
|
||||||
def get_redirect_args(self) -> Dict[str, str]:
|
def get_redirect_args(self) -> dict[str, str]:
|
||||||
"Get request parameters for redirect url."
|
"Get request parameters for redirect url."
|
||||||
callback = self.request.build_absolute_uri(self.callback)
|
callback = self.request.build_absolute_uri(self.callback)
|
||||||
client_id: str = self.source.consumer_key
|
client_id: str = self.source.consumer_key
|
||||||
args: Dict[str, str] = {
|
args: dict[str, str] = {
|
||||||
"client_id": client_id,
|
"client_id": client_id,
|
||||||
"redirect_uri": callback,
|
"redirect_uri": callback,
|
||||||
"response_type": "code",
|
"response_type": "code",
|
||||||
|
@ -84,7 +84,7 @@ class OAuth2Client(BaseOAuthClient):
|
||||||
self.request.session[self.session_key] = state
|
self.request.session[self.session_key] = state
|
||||||
return args
|
return args
|
||||||
|
|
||||||
def parse_raw_token(self, raw_token: str) -> Dict[str, Any]:
|
def parse_raw_token(self, raw_token: str) -> dict[str, Any]:
|
||||||
"Parse token and secret from raw token response."
|
"Parse token and secret from raw token response."
|
||||||
# Load as json first then parse as query string
|
# Load as json first then parse as query string
|
||||||
try:
|
try:
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
"""AzureAD OAuth2 Views"""
|
"""AzureAD OAuth2 Views"""
|
||||||
from typing import Any, Dict
|
from typing import Any
|
||||||
from uuid import UUID
|
from uuid import UUID
|
||||||
|
|
||||||
from authentik.sources.oauth.models import OAuthSource, UserOAuthSourceConnection
|
from authentik.sources.oauth.models import OAuthSource, UserOAuthSourceConnection
|
||||||
|
@ -11,15 +11,15 @@ from authentik.sources.oauth.views.callback import OAuthCallback
|
||||||
class AzureADOAuthCallback(OAuthCallback):
|
class AzureADOAuthCallback(OAuthCallback):
|
||||||
"""AzureAD OAuth2 Callback"""
|
"""AzureAD OAuth2 Callback"""
|
||||||
|
|
||||||
def get_user_id(self, source: OAuthSource, info: Dict[str, Any]) -> str:
|
def get_user_id(self, source: OAuthSource, info: dict[str, Any]) -> str:
|
||||||
return str(UUID(info.get("objectId")).int)
|
return str(UUID(info.get("objectId")).int)
|
||||||
|
|
||||||
def get_user_enroll_context(
|
def get_user_enroll_context(
|
||||||
self,
|
self,
|
||||||
source: OAuthSource,
|
source: OAuthSource,
|
||||||
access: UserOAuthSourceConnection,
|
access: UserOAuthSourceConnection,
|
||||||
info: Dict[str, Any],
|
info: dict[str, Any],
|
||||||
) -> Dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
mail = info.get("mail", None) or info.get("otherMails", [None])[0]
|
mail = info.get("mail", None) or info.get("otherMails", [None])[0]
|
||||||
return {
|
return {
|
||||||
"username": info.get("displayName"),
|
"username": info.get("displayName"),
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
"""Discord OAuth Views"""
|
"""Discord OAuth Views"""
|
||||||
from typing import Any, Dict
|
from typing import Any
|
||||||
|
|
||||||
from authentik.sources.oauth.models import OAuthSource, UserOAuthSourceConnection
|
from authentik.sources.oauth.models import OAuthSource, UserOAuthSourceConnection
|
||||||
from authentik.sources.oauth.types.manager import MANAGER, RequestKind
|
from authentik.sources.oauth.types.manager import MANAGER, RequestKind
|
||||||
|
@ -25,8 +25,8 @@ class DiscordOAuth2Callback(OAuthCallback):
|
||||||
self,
|
self,
|
||||||
source: OAuthSource,
|
source: OAuthSource,
|
||||||
access: UserOAuthSourceConnection,
|
access: UserOAuthSourceConnection,
|
||||||
info: Dict[str, Any],
|
info: dict[str, Any],
|
||||||
) -> Dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
return {
|
return {
|
||||||
"username": info.get("username"),
|
"username": info.get("username"),
|
||||||
"email": info.get("email", None),
|
"email": info.get("email", None),
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
"""Facebook OAuth Views"""
|
"""Facebook OAuth Views"""
|
||||||
from typing import Any, Dict, Optional
|
from typing import Any, Optional
|
||||||
|
|
||||||
from facebook import GraphAPI
|
from facebook import GraphAPI
|
||||||
|
|
||||||
|
@ -23,7 +23,7 @@ class FacebookOAuthRedirect(OAuthRedirect):
|
||||||
class FacebookOAuth2Client(OAuth2Client):
|
class FacebookOAuth2Client(OAuth2Client):
|
||||||
"""Facebook OAuth2 Client"""
|
"""Facebook OAuth2 Client"""
|
||||||
|
|
||||||
def get_profile_info(self, token: Dict[str, str]) -> Optional[Dict[str, Any]]:
|
def get_profile_info(self, token: dict[str, str]) -> Optional[dict[str, Any]]:
|
||||||
api = GraphAPI(access_token=token["access_token"])
|
api = GraphAPI(access_token=token["access_token"])
|
||||||
return api.get_object("me", fields="id,name,email")
|
return api.get_object("me", fields="id,name,email")
|
||||||
|
|
||||||
|
@ -38,8 +38,8 @@ class FacebookOAuth2Callback(OAuthCallback):
|
||||||
self,
|
self,
|
||||||
source: OAuthSource,
|
source: OAuthSource,
|
||||||
access: UserOAuthSourceConnection,
|
access: UserOAuthSourceConnection,
|
||||||
info: Dict[str, Any],
|
info: dict[str, Any],
|
||||||
) -> Dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
return {
|
return {
|
||||||
"username": info.get("name"),
|
"username": info.get("name"),
|
||||||
"email": info.get("email"),
|
"email": info.get("email"),
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
"""GitHub OAuth Views"""
|
"""GitHub OAuth Views"""
|
||||||
from typing import Any, Dict
|
from typing import Any
|
||||||
|
|
||||||
from authentik.sources.oauth.models import OAuthSource, UserOAuthSourceConnection
|
from authentik.sources.oauth.models import OAuthSource, UserOAuthSourceConnection
|
||||||
from authentik.sources.oauth.types.manager import MANAGER, RequestKind
|
from authentik.sources.oauth.types.manager import MANAGER, RequestKind
|
||||||
|
@ -14,8 +14,8 @@ class GitHubOAuth2Callback(OAuthCallback):
|
||||||
self,
|
self,
|
||||||
source: OAuthSource,
|
source: OAuthSource,
|
||||||
access: UserOAuthSourceConnection,
|
access: UserOAuthSourceConnection,
|
||||||
info: Dict[str, Any],
|
info: dict[str, Any],
|
||||||
) -> Dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
return {
|
return {
|
||||||
"username": info.get("login"),
|
"username": info.get("login"),
|
||||||
"email": info.get("email"),
|
"email": info.get("email"),
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
"""Google OAuth Views"""
|
"""Google OAuth Views"""
|
||||||
from typing import Any, Dict
|
from typing import Any
|
||||||
|
|
||||||
from authentik.sources.oauth.models import OAuthSource, UserOAuthSourceConnection
|
from authentik.sources.oauth.models import OAuthSource, UserOAuthSourceConnection
|
||||||
from authentik.sources.oauth.types.manager import MANAGER, RequestKind
|
from authentik.sources.oauth.types.manager import MANAGER, RequestKind
|
||||||
|
@ -25,8 +25,8 @@ class GoogleOAuth2Callback(OAuthCallback):
|
||||||
self,
|
self,
|
||||||
source: OAuthSource,
|
source: OAuthSource,
|
||||||
access: UserOAuthSourceConnection,
|
access: UserOAuthSourceConnection,
|
||||||
info: Dict[str, Any],
|
info: dict[str, Any],
|
||||||
) -> Dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
return {
|
return {
|
||||||
"username": info.get("email"),
|
"username": info.get("email"),
|
||||||
"email": info.get("email"),
|
"email": info.get("email"),
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
"""Source type manager"""
|
"""Source type manager"""
|
||||||
from enum import Enum
|
from enum import Enum
|
||||||
from typing import Callable, Dict, List
|
from typing import Callable
|
||||||
|
|
||||||
from django.utils.text import slugify
|
from django.utils.text import slugify
|
||||||
from structlog.stdlib import get_logger
|
from structlog.stdlib import get_logger
|
||||||
|
@ -22,8 +22,8 @@ class RequestKind(Enum):
|
||||||
class SourceTypeManager:
|
class SourceTypeManager:
|
||||||
"""Manager to hold all Source types."""
|
"""Manager to hold all Source types."""
|
||||||
|
|
||||||
__source_types: Dict[RequestKind, Dict[str, Callable]] = {}
|
__source_types: dict[RequestKind, dict[str, Callable]] = {}
|
||||||
__names: List[str] = []
|
__names: list[str] = []
|
||||||
|
|
||||||
def source(self, kind: RequestKind, name: str):
|
def source(self, kind: RequestKind, name: str):
|
||||||
"""Class decorator to register classes inline."""
|
"""Class decorator to register classes inline."""
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
"""OpenID Connect OAuth Views"""
|
"""OpenID Connect OAuth Views"""
|
||||||
from typing import Any, Dict
|
from typing import Any
|
||||||
|
|
||||||
from authentik.sources.oauth.models import OAuthSource, UserOAuthSourceConnection
|
from authentik.sources.oauth.models import OAuthSource, UserOAuthSourceConnection
|
||||||
from authentik.sources.oauth.types.manager import MANAGER, RequestKind
|
from authentik.sources.oauth.types.manager import MANAGER, RequestKind
|
||||||
|
@ -21,15 +21,15 @@ class OpenIDConnectOAuthRedirect(OAuthRedirect):
|
||||||
class OpenIDConnectOAuth2Callback(OAuthCallback):
|
class OpenIDConnectOAuth2Callback(OAuthCallback):
|
||||||
"""OpenIDConnect OAuth2 Callback"""
|
"""OpenIDConnect OAuth2 Callback"""
|
||||||
|
|
||||||
def get_user_id(self, source: OAuthSource, info: Dict[str, str]) -> str:
|
def get_user_id(self, source: OAuthSource, info: dict[str, str]) -> str:
|
||||||
return info.get("sub", "")
|
return info.get("sub", "")
|
||||||
|
|
||||||
def get_user_enroll_context(
|
def get_user_enroll_context(
|
||||||
self,
|
self,
|
||||||
source: OAuthSource,
|
source: OAuthSource,
|
||||||
access: UserOAuthSourceConnection,
|
access: UserOAuthSourceConnection,
|
||||||
info: Dict[str, Any],
|
info: dict[str, Any],
|
||||||
) -> Dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
return {
|
return {
|
||||||
"username": info.get("nickname"),
|
"username": info.get("nickname"),
|
||||||
"email": info.get("email"),
|
"email": info.get("email"),
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
"""Reddit OAuth Views"""
|
"""Reddit OAuth Views"""
|
||||||
from typing import Any, Dict
|
from typing import Any
|
||||||
|
|
||||||
from requests.auth import HTTPBasicAuth
|
from requests.auth import HTTPBasicAuth
|
||||||
|
|
||||||
|
@ -40,8 +40,8 @@ class RedditOAuth2Callback(OAuthCallback):
|
||||||
self,
|
self,
|
||||||
source: OAuthSource,
|
source: OAuthSource,
|
||||||
access: UserOAuthSourceConnection,
|
access: UserOAuthSourceConnection,
|
||||||
info: Dict[str, Any],
|
info: dict[str, Any],
|
||||||
) -> Dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
return {
|
return {
|
||||||
"username": info.get("name"),
|
"username": info.get("name"),
|
||||||
"email": None,
|
"email": None,
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
"""Twitter OAuth Views"""
|
"""Twitter OAuth Views"""
|
||||||
from typing import Any, Dict
|
from typing import Any
|
||||||
|
|
||||||
from authentik.sources.oauth.models import OAuthSource, UserOAuthSourceConnection
|
from authentik.sources.oauth.models import OAuthSource, UserOAuthSourceConnection
|
||||||
from authentik.sources.oauth.types.manager import MANAGER, RequestKind
|
from authentik.sources.oauth.types.manager import MANAGER, RequestKind
|
||||||
|
@ -14,8 +14,8 @@ class TwitterOAuthCallback(OAuthCallback):
|
||||||
self,
|
self,
|
||||||
source: OAuthSource,
|
source: OAuthSource,
|
||||||
access: UserOAuthSourceConnection,
|
access: UserOAuthSourceConnection,
|
||||||
info: Dict[str, Any],
|
info: dict[str, Any],
|
||||||
) -> Dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
return {
|
return {
|
||||||
"username": info.get("screen_name"),
|
"username": info.get("screen_name"),
|
||||||
"email": info.get("email", None),
|
"email": info.get("email", None),
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
"""OAuth Callback Views"""
|
"""OAuth Callback Views"""
|
||||||
from typing import Any, Dict, Optional
|
from typing import Any, Optional
|
||||||
|
|
||||||
from django.conf import settings
|
from django.conf import settings
|
||||||
from django.contrib import messages
|
from django.contrib import messages
|
||||||
|
@ -115,14 +115,14 @@ class OAuthCallback(OAuthClientMixin, View):
|
||||||
self,
|
self,
|
||||||
source: OAuthSource,
|
source: OAuthSource,
|
||||||
access: UserOAuthSourceConnection,
|
access: UserOAuthSourceConnection,
|
||||||
info: Dict[str, Any],
|
info: dict[str, Any],
|
||||||
) -> Dict[str, Any]:
|
) -> dict[str, Any]:
|
||||||
"""Create a dict of User data"""
|
"""Create a dict of User data"""
|
||||||
raise NotImplementedError()
|
raise NotImplementedError()
|
||||||
|
|
||||||
# pylint: disable=unused-argument
|
# pylint: disable=unused-argument
|
||||||
def get_user_id(
|
def get_user_id(
|
||||||
self, source: UserOAuthSourceConnection, info: Dict[str, Any]
|
self, source: UserOAuthSourceConnection, info: dict[str, Any]
|
||||||
) -> Optional[str]:
|
) -> Optional[str]:
|
||||||
"""Return unique identifier from the profile info."""
|
"""Return unique identifier from the profile info."""
|
||||||
if "id" in info:
|
if "id" in info:
|
||||||
|
@ -167,7 +167,7 @@ class OAuthCallback(OAuthClientMixin, View):
|
||||||
source: OAuthSource,
|
source: OAuthSource,
|
||||||
user: User,
|
user: User,
|
||||||
access: UserOAuthSourceConnection,
|
access: UserOAuthSourceConnection,
|
||||||
info: Dict[str, Any],
|
info: dict[str, Any],
|
||||||
) -> HttpResponse:
|
) -> HttpResponse:
|
||||||
"Login user and redirect."
|
"Login user and redirect."
|
||||||
messages.success(
|
messages.success(
|
||||||
|
@ -184,7 +184,7 @@ class OAuthCallback(OAuthClientMixin, View):
|
||||||
self,
|
self,
|
||||||
source: OAuthSource,
|
source: OAuthSource,
|
||||||
access: UserOAuthSourceConnection,
|
access: UserOAuthSourceConnection,
|
||||||
info: Dict[str, Any],
|
info: dict[str, Any],
|
||||||
) -> HttpResponse:
|
) -> HttpResponse:
|
||||||
"""Handler when the user was already authenticated and linked an external source
|
"""Handler when the user was already authenticated and linked an external source
|
||||||
to their account."""
|
to their account."""
|
||||||
|
@ -211,7 +211,7 @@ class OAuthCallback(OAuthClientMixin, View):
|
||||||
self,
|
self,
|
||||||
source: OAuthSource,
|
source: OAuthSource,
|
||||||
access: UserOAuthSourceConnection,
|
access: UserOAuthSourceConnection,
|
||||||
info: Dict[str, Any],
|
info: dict[str, Any],
|
||||||
) -> HttpResponse:
|
) -> HttpResponse:
|
||||||
"""User was not authenticated and previous request was not authenticated."""
|
"""User was not authenticated and previous request was not authenticated."""
|
||||||
messages.success(
|
messages.success(
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
"""OAuth Redirect Views"""
|
"""OAuth Redirect Views"""
|
||||||
from typing import Any, Dict
|
from typing import Any
|
||||||
|
|
||||||
from django.http import Http404
|
from django.http import Http404
|
||||||
from django.urls import reverse
|
from django.urls import reverse
|
||||||
|
@ -19,7 +19,7 @@ class OAuthRedirect(OAuthClientMixin, RedirectView):
|
||||||
params = None
|
params = None
|
||||||
|
|
||||||
# pylint: disable=unused-argument
|
# pylint: disable=unused-argument
|
||||||
def get_additional_parameters(self, source: OAuthSource) -> Dict[str, Any]:
|
def get_additional_parameters(self, source: OAuthSource) -> dict[str, Any]:
|
||||||
"Return additional redirect parameters for this source."
|
"Return additional redirect parameters for this source."
|
||||||
return self.params or {}
|
return self.params or {}
|
||||||
|
|
||||||
|
|
|
@ -1,6 +1,5 @@
|
||||||
"""SAML AuthnRequest Processor"""
|
"""SAML AuthnRequest Processor"""
|
||||||
from base64 import b64encode
|
from base64 import b64encode
|
||||||
from typing import Dict
|
|
||||||
from urllib.parse import quote_plus
|
from urllib.parse import quote_plus
|
||||||
|
|
||||||
import xmlsec
|
import xmlsec
|
||||||
|
@ -125,7 +124,7 @@ class RequestProcessor:
|
||||||
|
|
||||||
return etree.tostring(auth_n_request).decode()
|
return etree.tostring(auth_n_request).decode()
|
||||||
|
|
||||||
def build_auth_n_detached(self) -> Dict[str, str]:
|
def build_auth_n_detached(self) -> dict[str, str]:
|
||||||
"""Get Dict AuthN Request for Redirect bindings, with detached
|
"""Get Dict AuthN Request for Redirect bindings, with detached
|
||||||
Signature. See https://docs.oasis-open.org/security/saml/v2.0/saml-bindings-2.0-os.pdf"""
|
Signature. See https://docs.oasis-open.org/security/saml/v2.0/saml-bindings-2.0-os.pdf"""
|
||||||
auth_n_request = self.get_auth_n()
|
auth_n_request = self.get_auth_n()
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
"""authentik saml source processor"""
|
"""authentik saml source processor"""
|
||||||
from base64 import b64decode
|
from base64 import b64decode
|
||||||
from typing import TYPE_CHECKING, Any, Dict
|
from typing import TYPE_CHECKING, Any
|
||||||
|
|
||||||
import xmlsec
|
import xmlsec
|
||||||
from defusedxml.lxml import fromstring
|
from defusedxml.lxml import fromstring
|
||||||
|
@ -154,7 +154,7 @@ class ResponseProcessor:
|
||||||
raise ValueError("NameID Element not found!")
|
raise ValueError("NameID Element not found!")
|
||||||
return name_id
|
return name_id
|
||||||
|
|
||||||
def _get_name_id_filter(self) -> Dict[str, str]:
|
def _get_name_id_filter(self) -> dict[str, str]:
|
||||||
"""Returns the subject's NameID as a Filter for the `User`"""
|
"""Returns the subject's NameID as a Filter for the `User`"""
|
||||||
name_id_el = self._get_name_id()
|
name_id_el = self._get_name_id()
|
||||||
name_id = name_id_el.text
|
name_id = name_id_el.text
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
"""Static OTP Setup stage"""
|
"""Static OTP Setup stage"""
|
||||||
from typing import Any, Dict
|
from typing import Any
|
||||||
|
|
||||||
from django.http import HttpRequest, HttpResponse
|
from django.http import HttpRequest, HttpResponse
|
||||||
from django.views.generic import FormView
|
from django.views.generic import FormView
|
||||||
|
@ -21,7 +21,7 @@ class AuthenticatorStaticStageView(FormView, StageView):
|
||||||
|
|
||||||
form_class = SetupForm
|
form_class = SetupForm
|
||||||
|
|
||||||
def get_form_kwargs(self, **kwargs) -> Dict[str, Any]:
|
def get_form_kwargs(self, **kwargs) -> dict[str, Any]:
|
||||||
kwargs = super().get_form_kwargs(**kwargs)
|
kwargs = super().get_form_kwargs(**kwargs)
|
||||||
tokens = self.request.session[SESSION_STATIC_TOKENS]
|
tokens = self.request.session[SESSION_STATIC_TOKENS]
|
||||||
kwargs["tokens"] = tokens
|
kwargs["tokens"] = tokens
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
"""TOTP Setup stage"""
|
"""TOTP Setup stage"""
|
||||||
from typing import Any, Dict
|
from typing import Any
|
||||||
|
|
||||||
from django.http import HttpRequest, HttpResponse
|
from django.http import HttpRequest, HttpResponse
|
||||||
from django.utils.encoding import force_str
|
from django.utils.encoding import force_str
|
||||||
|
@ -24,7 +24,7 @@ class AuthenticatorTOTPStageView(FormView, StageView):
|
||||||
|
|
||||||
form_class = SetupForm
|
form_class = SetupForm
|
||||||
|
|
||||||
def get_form_kwargs(self, **kwargs) -> Dict[str, Any]:
|
def get_form_kwargs(self, **kwargs) -> dict[str, Any]:
|
||||||
kwargs = super().get_form_kwargs(**kwargs)
|
kwargs = super().get_form_kwargs(**kwargs)
|
||||||
device: TOTPDevice = self.request.session[SESSION_TOTP_DEVICE]
|
device: TOTPDevice = self.request.session[SESSION_TOTP_DEVICE]
|
||||||
kwargs["device"] = device
|
kwargs["device"] = device
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
"""OTP Validation"""
|
"""OTP Validation"""
|
||||||
from typing import Any, Dict
|
from typing import Any
|
||||||
|
|
||||||
from django.http import HttpRequest, HttpResponse
|
from django.http import HttpRequest, HttpResponse
|
||||||
from django.views.generic import FormView
|
from django.views.generic import FormView
|
||||||
|
@ -32,6 +32,11 @@ class AuthenticatorValidateStageView(ChallengeStageView):
|
||||||
|
|
||||||
form_class = ValidationForm
|
form_class = ValidationForm
|
||||||
|
|
||||||
|
def get_form_kwargs(self, **kwargs) -> dict[str, Any]:
|
||||||
|
kwargs = super().get_form_kwargs(**kwargs)
|
||||||
|
kwargs["user"] = self.executor.plan.context.get(PLAN_CONTEXT_PENDING_USER)
|
||||||
|
return kwargs
|
||||||
|
|
||||||
def get(self, request: HttpRequest, *args, **kwargs) -> HttpResponse:
|
def get(self, request: HttpRequest, *args, **kwargs) -> HttpResponse:
|
||||||
"""Check if a user is set, and check if the user has any devices
|
"""Check if a user is set, and check if the user has any devices
|
||||||
if not, we can skip this entire stage"""
|
if not, we can skip this entire stage"""
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
"""authentik consent stage"""
|
"""authentik consent stage"""
|
||||||
from typing import Any, Dict, List
|
from typing import Any
|
||||||
|
|
||||||
from django.http import HttpRequest, HttpResponse
|
from django.http import HttpRequest, HttpResponse
|
||||||
from django.utils.timezone import now
|
from django.utils.timezone import now
|
||||||
|
@ -19,13 +19,13 @@ class ConsentStageView(FormView, StageView):
|
||||||
|
|
||||||
form_class = ConsentForm
|
form_class = ConsentForm
|
||||||
|
|
||||||
def get_context_data(self, **kwargs: Dict[str, Any]) -> Dict[str, Any]:
|
def get_context_data(self, **kwargs: dict[str, Any]) -> dict[str, Any]:
|
||||||
kwargs = super().get_context_data(**kwargs)
|
kwargs = super().get_context_data(**kwargs)
|
||||||
kwargs["current_stage"] = self.executor.current_stage
|
kwargs["current_stage"] = self.executor.current_stage
|
||||||
kwargs["context"] = self.executor.plan.context
|
kwargs["context"] = self.executor.plan.context
|
||||||
return kwargs
|
return kwargs
|
||||||
|
|
||||||
def get_template_names(self) -> List[str]:
|
def get_template_names(self) -> list[str]:
|
||||||
# PLAN_CONTEXT_CONSENT_TEMPLATE has to be set by a template that calls this stage
|
# PLAN_CONTEXT_CONSENT_TEMPLATE has to be set by a template that calls this stage
|
||||||
if PLAN_CONTEXT_CONSENT_TEMPLATE in self.executor.plan.context:
|
if PLAN_CONTEXT_CONSENT_TEMPLATE in self.executor.plan.context:
|
||||||
template_name = self.executor.plan.context[PLAN_CONTEXT_CONSENT_TEMPLATE]
|
template_name = self.executor.plan.context[PLAN_CONTEXT_CONSENT_TEMPLATE]
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
"""authentik multi-stage authentication engine"""
|
"""authentik multi-stage authentication engine"""
|
||||||
from typing import Any, Dict
|
from typing import Any
|
||||||
|
|
||||||
from django.http import HttpRequest
|
from django.http import HttpRequest
|
||||||
|
|
||||||
|
@ -13,7 +13,7 @@ class DummyStageView(StageView):
|
||||||
"""Just redirect to next stage"""
|
"""Just redirect to next stage"""
|
||||||
return self.executor.stage_ok()
|
return self.executor.stage_ok()
|
||||||
|
|
||||||
def get_context_data(self, **kwargs: Dict[str, Any]) -> Dict[str, Any]:
|
def get_context_data(self, **kwargs: dict[str, Any]) -> dict[str, Any]:
|
||||||
kwargs = super().get_context_data(**kwargs)
|
kwargs = super().get_context_data(**kwargs)
|
||||||
kwargs["title"] = self.executor.current_stage.name
|
kwargs["title"] = self.executor.current_stage.name
|
||||||
return kwargs
|
return kwargs
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
"""Identification stage logic"""
|
"""Identification stage logic"""
|
||||||
from typing import List, Optional
|
from typing import Optional
|
||||||
|
|
||||||
from django.contrib import messages
|
from django.contrib import messages
|
||||||
from django.db.models import Q
|
from django.db.models import Q
|
||||||
|
@ -75,7 +75,7 @@ class IdentificationStageView(ChallengeStageView):
|
||||||
|
|
||||||
# Check all enabled source, add them if they have a UI Login button.
|
# Check all enabled source, add them if they have a UI Login button.
|
||||||
args["sources"] = []
|
args["sources"] = []
|
||||||
sources: List[Source] = (
|
sources: list[Source] = (
|
||||||
Source.objects.filter(enabled=True).order_by("name").select_subclasses()
|
Source.objects.filter(enabled=True).order_by("name").select_subclasses()
|
||||||
)
|
)
|
||||||
for source in sources:
|
for source in sources:
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
"""authentik password stage"""
|
"""authentik password stage"""
|
||||||
from typing import Any, Dict, List, Optional
|
from typing import Any, Optional
|
||||||
|
|
||||||
from django.contrib.auth import _clean_credentials
|
from django.contrib.auth import _clean_credentials
|
||||||
from django.contrib.auth.backends import BaseBackend
|
from django.contrib.auth.backends import BaseBackend
|
||||||
|
@ -24,7 +24,7 @@ SESSION_INVALID_TRIES = "user_invalid_tries"
|
||||||
|
|
||||||
|
|
||||||
def authenticate(
|
def authenticate(
|
||||||
request: HttpRequest, backends: List[str], **credentials: Dict[str, Any]
|
request: HttpRequest, backends: list[str], **credentials: dict[str, Any]
|
||||||
) -> Optional[User]:
|
) -> Optional[User]:
|
||||||
"""If the given credentials are valid, return a User object.
|
"""If the given credentials are valid, return a User object.
|
||||||
|
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
"""Prompt forms"""
|
"""Prompt forms"""
|
||||||
from email.policy import Policy
|
from email.policy import Policy
|
||||||
from types import MethodType
|
from types import MethodType
|
||||||
from typing import Any, Callable, Iterator, List
|
from typing import Any, Callable, Iterator
|
||||||
|
|
||||||
from django import forms
|
from django import forms
|
||||||
from django.db.models.query import QuerySet
|
from django.db.models.query import QuerySet
|
||||||
|
@ -52,10 +52,10 @@ class PromptAdminForm(forms.ModelForm):
|
||||||
class ListPolicyEngine(PolicyEngine):
|
class ListPolicyEngine(PolicyEngine):
|
||||||
"""Slightly modified policy engine, which uses a list instead of a PolicyBindingModel"""
|
"""Slightly modified policy engine, which uses a list instead of a PolicyBindingModel"""
|
||||||
|
|
||||||
__list: List[Policy]
|
__list: list[Policy]
|
||||||
|
|
||||||
def __init__(
|
def __init__(
|
||||||
self, policies: List[Policy], user: User, request: HttpRequest = None
|
self, policies: list[Policy], user: User, request: HttpRequest = None
|
||||||
) -> None:
|
) -> None:
|
||||||
super().__init__(PolicyBindingModel(), user, request)
|
super().__init__(PolicyBindingModel(), user, request)
|
||||||
self.__list = policies
|
self.__list = policies
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
"""authentik prompt stage signals"""
|
"""authentik prompt stage signals"""
|
||||||
from django.core.signals import Signal
|
from django.core.signals import Signal
|
||||||
|
|
||||||
# Arguments: password: str, plan_context: Dict[str, Any]
|
# Arguments: password: str, plan_context: dict[str, Any]
|
||||||
password_validate = Signal()
|
password_validate = Signal()
|
||||||
|
|
|
@ -1,5 +1,5 @@
|
||||||
"""authentik user_write signals"""
|
"""authentik user_write signals"""
|
||||||
from django.core.signals import Signal
|
from django.core.signals import Signal
|
||||||
|
|
||||||
# Arguments: request: HttpRequest, user: User, data: Dict[str, Any], created: bool
|
# Arguments: request: HttpRequest, user: User, data: dict[str, Any], created: bool
|
||||||
user_write = Signal()
|
user_write = Signal()
|
||||||
|
|
|
@ -1,6 +1,6 @@
|
||||||
"""Test Enroll flow"""
|
"""Test Enroll flow"""
|
||||||
from sys import platform
|
from sys import platform
|
||||||
from typing import Any, Dict, Optional
|
from typing import Any, Optional
|
||||||
from unittest.case import skipUnless
|
from unittest.case import skipUnless
|
||||||
|
|
||||||
from django.test import override_settings
|
from django.test import override_settings
|
||||||
|
@ -22,7 +22,7 @@ from tests.e2e.utils import USER, SeleniumTestCase, retry
|
||||||
class TestFlowsEnroll(SeleniumTestCase):
|
class TestFlowsEnroll(SeleniumTestCase):
|
||||||
"""Test Enroll flow"""
|
"""Test Enroll flow"""
|
||||||
|
|
||||||
def get_container_specs(self) -> Optional[Dict[str, Any]]:
|
def get_container_specs(self) -> Optional[dict[str, Any]]:
|
||||||
return {
|
return {
|
||||||
"image": "mailhog/mailhog:v1.0.1",
|
"image": "mailhog/mailhog:v1.0.1",
|
||||||
"detach": True,
|
"detach": True,
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
"""test OAuth Provider flow"""
|
"""test OAuth Provider flow"""
|
||||||
from sys import platform
|
from sys import platform
|
||||||
from time import sleep
|
from time import sleep
|
||||||
from typing import Any, Dict, Optional
|
from typing import Any, Optional
|
||||||
from unittest.case import skipUnless
|
from unittest.case import skipUnless
|
||||||
|
|
||||||
from docker.types import Healthcheck
|
from docker.types import Healthcheck
|
||||||
|
@ -30,7 +30,7 @@ class TestProviderOAuth2Github(SeleniumTestCase):
|
||||||
self.client_secret = generate_client_secret()
|
self.client_secret = generate_client_secret()
|
||||||
super().setUp()
|
super().setUp()
|
||||||
|
|
||||||
def get_container_specs(self) -> Optional[Dict[str, Any]]:
|
def get_container_specs(self) -> Optional[dict[str, Any]]:
|
||||||
"""Setup client grafana container which we test OAuth against"""
|
"""Setup client grafana container which we test OAuth against"""
|
||||||
return {
|
return {
|
||||||
"image": "grafana/grafana:7.1.0",
|
"image": "grafana/grafana:7.1.0",
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
"""test OAuth2 OpenID Provider flow"""
|
"""test OAuth2 OpenID Provider flow"""
|
||||||
from sys import platform
|
from sys import platform
|
||||||
from time import sleep
|
from time import sleep
|
||||||
from typing import Any, Dict, Optional
|
from typing import Any, Optional
|
||||||
from unittest.case import skipUnless
|
from unittest.case import skipUnless
|
||||||
|
|
||||||
from docker.types import Healthcheck
|
from docker.types import Healthcheck
|
||||||
|
@ -40,7 +40,7 @@ class TestProviderOAuth2OAuth(SeleniumTestCase):
|
||||||
self.client_secret = generate_client_secret()
|
self.client_secret = generate_client_secret()
|
||||||
super().setUp()
|
super().setUp()
|
||||||
|
|
||||||
def get_container_specs(self) -> Optional[Dict[str, Any]]:
|
def get_container_specs(self) -> Optional[dict[str, Any]]:
|
||||||
return {
|
return {
|
||||||
"image": "grafana/grafana:7.1.0",
|
"image": "grafana/grafana:7.1.0",
|
||||||
"detach": True,
|
"detach": True,
|
||||||
|
|
|
@ -2,7 +2,7 @@
|
||||||
from dataclasses import asdict
|
from dataclasses import asdict
|
||||||
from sys import platform
|
from sys import platform
|
||||||
from time import sleep
|
from time import sleep
|
||||||
from typing import Any, Dict, Optional
|
from typing import Any, Optional
|
||||||
from unittest.case import skipUnless
|
from unittest.case import skipUnless
|
||||||
|
|
||||||
from channels.testing import ChannelsLiveServerTestCase
|
from channels.testing import ChannelsLiveServerTestCase
|
||||||
|
@ -35,7 +35,7 @@ class TestProviderProxy(SeleniumTestCase):
|
||||||
super().tearDown()
|
super().tearDown()
|
||||||
self.proxy_container.kill()
|
self.proxy_container.kill()
|
||||||
|
|
||||||
def get_container_specs(self) -> Optional[Dict[str, Any]]:
|
def get_container_specs(self) -> Optional[dict[str, Any]]:
|
||||||
return {
|
return {
|
||||||
"image": "traefik/whoami:latest",
|
"image": "traefik/whoami:latest",
|
||||||
"detach": True,
|
"detach": True,
|
||||||
|
|
|
@ -2,7 +2,7 @@
|
||||||
from os.path import abspath
|
from os.path import abspath
|
||||||
from sys import platform
|
from sys import platform
|
||||||
from time import sleep
|
from time import sleep
|
||||||
from typing import Any, Dict, Optional
|
from typing import Any, Optional
|
||||||
from unittest.case import skipUnless
|
from unittest.case import skipUnless
|
||||||
|
|
||||||
from django.test import override_settings
|
from django.test import override_settings
|
||||||
|
@ -72,7 +72,7 @@ class TestSourceOAuth2(SeleniumTestCase):
|
||||||
with open(CONFIG_PATH, "w+") as _file:
|
with open(CONFIG_PATH, "w+") as _file:
|
||||||
safe_dump(config, _file)
|
safe_dump(config, _file)
|
||||||
|
|
||||||
def get_container_specs(self) -> Optional[Dict[str, Any]]:
|
def get_container_specs(self) -> Optional[dict[str, Any]]:
|
||||||
return {
|
return {
|
||||||
"image": "quay.io/dexidp/dex:v2.24.0",
|
"image": "quay.io/dexidp/dex:v2.24.0",
|
||||||
"detach": True,
|
"detach": True,
|
||||||
|
@ -249,7 +249,7 @@ class TestSourceOAuth1(SeleniumTestCase):
|
||||||
self.source_slug = "oauth1-test"
|
self.source_slug = "oauth1-test"
|
||||||
super().setUp()
|
super().setUp()
|
||||||
|
|
||||||
def get_container_specs(self) -> Optional[Dict[str, Any]]:
|
def get_container_specs(self) -> Optional[dict[str, Any]]:
|
||||||
return {
|
return {
|
||||||
"image": "beryju/oauth1-test-server",
|
"image": "beryju/oauth1-test-server",
|
||||||
"detach": True,
|
"detach": True,
|
||||||
|
|
|
@ -1,7 +1,7 @@
|
||||||
"""test SAML Source"""
|
"""test SAML Source"""
|
||||||
from sys import platform
|
from sys import platform
|
||||||
from time import sleep
|
from time import sleep
|
||||||
from typing import Any, Dict, Optional
|
from typing import Any, Optional
|
||||||
from unittest.case import skipUnless
|
from unittest.case import skipUnless
|
||||||
|
|
||||||
from docker.types import Healthcheck
|
from docker.types import Healthcheck
|
||||||
|
@ -73,7 +73,7 @@ Sm75WXsflOxuTn08LbgGc4s=
|
||||||
class TestSourceSAML(SeleniumTestCase):
|
class TestSourceSAML(SeleniumTestCase):
|
||||||
"""test SAML Source flow"""
|
"""test SAML Source flow"""
|
||||||
|
|
||||||
def get_container_specs(self) -> Optional[Dict[str, Any]]:
|
def get_container_specs(self) -> Optional[dict[str, Any]]:
|
||||||
return {
|
return {
|
||||||
"image": "kristophjunge/test-saml-idp:1.15",
|
"image": "kristophjunge/test-saml-idp:1.15",
|
||||||
"detach": True,
|
"detach": True,
|
||||||
|
|
|
@ -6,7 +6,7 @@ from importlib.util import module_from_spec, spec_from_file_location
|
||||||
from inspect import getmembers, isfunction
|
from inspect import getmembers, isfunction
|
||||||
from os import environ, makedirs
|
from os import environ, makedirs
|
||||||
from time import sleep, time
|
from time import sleep, time
|
||||||
from typing import Any, Callable, Dict, Optional
|
from typing import Any, Callable, Optional
|
||||||
|
|
||||||
from django.apps import apps
|
from django.apps import apps
|
||||||
from django.contrib.staticfiles.testing import StaticLiveServerTestCase
|
from django.contrib.staticfiles.testing import StaticLiveServerTestCase
|
||||||
|
@ -56,7 +56,7 @@ class SeleniumTestCase(StaticLiveServerTestCase):
|
||||||
if specs := self.get_container_specs():
|
if specs := self.get_container_specs():
|
||||||
self.container = self._start_container(specs)
|
self.container = self._start_container(specs)
|
||||||
|
|
||||||
def _start_container(self, specs: Dict[str, Any]) -> Container:
|
def _start_container(self, specs: dict[str, Any]) -> Container:
|
||||||
client: DockerClient = from_env()
|
client: DockerClient = from_env()
|
||||||
client.images.pull(specs["image"])
|
client.images.pull(specs["image"])
|
||||||
container = client.containers.run(**specs)
|
container = client.containers.run(**specs)
|
||||||
|
@ -70,7 +70,7 @@ class SeleniumTestCase(StaticLiveServerTestCase):
|
||||||
self.logger.info("Container failed healthcheck")
|
self.logger.info("Container failed healthcheck")
|
||||||
sleep(1)
|
sleep(1)
|
||||||
|
|
||||||
def get_container_specs(self) -> Optional[Dict[str, Any]]:
|
def get_container_specs(self) -> Optional[dict[str, Any]]:
|
||||||
"""Optionally get container specs which will launched on setup, wait for the container to
|
"""Optionally get container specs which will launched on setup, wait for the container to
|
||||||
be healthy, and deleted again on tearDown"""
|
be healthy, and deleted again on tearDown"""
|
||||||
return None
|
return None
|
||||||
|
|
Reference in New Issue