Merge branch 'master' into stage-challenge

# Conflicts:
#	authentik/stages/authenticator_validate/stage.py
#	authentik/stages/identification/stage.py
This commit is contained in:
Jens Langhammer 2021-02-18 14:04:35 +01:00
commit b229b2f40d
73 changed files with 216 additions and 215 deletions

View file

@ -2,7 +2,6 @@
import time import time
from collections import Counter from collections import Counter
from datetime import timedelta from datetime import timedelta
from typing import Dict, List
from django.db.models import Count, ExpressionWrapper, F, Model from django.db.models import Count, ExpressionWrapper, F, Model
from django.db.models.fields import DurationField from django.db.models.fields import DurationField
@ -19,7 +18,7 @@ from rest_framework.viewsets import ViewSet
from authentik.events.models import Event, EventAction from authentik.events.models import Event, EventAction
def get_events_per_1h(**filter_kwargs) -> List[Dict[str, int]]: def get_events_per_1h(**filter_kwargs) -> list[dict[str, int]]:
"""Get event count by hour in the last day, fill with zeros""" """Get event count by hour in the last day, fill with zeros"""
date_from = now() - timedelta(days=1) date_from = now() - timedelta(days=1)
result = ( result = (

View file

@ -1,6 +1,6 @@
"""authentik Outpost administration""" """authentik Outpost administration"""
from dataclasses import asdict from dataclasses import asdict
from typing import Any, Dict from typing import Any
from django.contrib.auth.mixins import LoginRequiredMixin from django.contrib.auth.mixins import LoginRequiredMixin
from django.contrib.auth.mixins import ( from django.contrib.auth.mixins import (
@ -33,7 +33,7 @@ class OutpostCreateView(
template_name = "generic/create.html" template_name = "generic/create.html"
success_message = _("Successfully created Outpost") success_message = _("Successfully created Outpost")
def get_initial(self) -> Dict[str, Any]: def get_initial(self) -> dict[str, Any]:
return { return {
"_config": asdict( "_config": asdict(
OutpostConfig(authentik_host=self.request.build_absolute_uri("/")) OutpostConfig(authentik_host=self.request.build_absolute_uri("/"))

View file

@ -1,5 +1,5 @@
"""authentik Policy administration""" """authentik Policy administration"""
from typing import Any, Dict from typing import Any
from django.contrib.auth.mixins import LoginRequiredMixin from django.contrib.auth.mixins import LoginRequiredMixin
from django.contrib.auth.mixins import ( from django.contrib.auth.mixins import (
@ -102,7 +102,7 @@ class PolicyTestView(LoginRequiredMixin, DetailView, PermissionRequiredMixin, Fo
Policy.objects.filter(pk=self.kwargs.get("pk")).select_subclasses().first() Policy.objects.filter(pk=self.kwargs.get("pk")).select_subclasses().first()
) )
def get_context_data(self, **kwargs: Any) -> Dict[str, Any]: def get_context_data(self, **kwargs: Any) -> dict[str, Any]:
kwargs["policy"] = self.get_object() kwargs["policy"] = self.get_object()
return super().get_context_data(**kwargs) return super().get_context_data(**kwargs)

View file

@ -1,5 +1,5 @@
"""authentik Tasks List""" """authentik Tasks List"""
from typing import Any, Dict from typing import Any
from django.views.generic.base import TemplateView from django.views.generic.base import TemplateView
@ -12,7 +12,7 @@ class TaskListView(AdminRequiredMixin, TemplateView):
template_name = "administration/task/list.html" template_name = "administration/task/list.html"
def get_context_data(self, **kwargs: Any) -> Dict[str, Any]: def get_context_data(self, **kwargs: Any) -> dict[str, Any]:
kwargs = super().get_context_data(**kwargs) kwargs = super().get_context_data(**kwargs)
kwargs["object_list"] = sorted( kwargs["object_list"] = sorted(
TaskInfo.all().values(), key=lambda x: x.task_name TaskInfo.all().values(), key=lambda x: x.task_name

View file

@ -1,5 +1,5 @@
"""authentik admin util views""" """authentik admin util views"""
from typing import Any, Dict, List, Optional from typing import Any, Optional
from urllib.parse import urlparse from urllib.parse import urlparse
from django.contrib import messages from django.contrib import messages
@ -40,7 +40,7 @@ class SearchListMixin(MultipleObjectMixin):
"""Accept search query using `search` querystring parameter. Requires self.search_fields, """Accept search query using `search` querystring parameter. Requires self.search_fields,
a list of all fields to search. Can contain special lookups like __icontains""" a list of all fields to search. Can contain special lookups like __icontains"""
search_fields: List[str] search_fields: list[str]
def get_queryset(self) -> QuerySet: def get_queryset(self) -> QuerySet:
queryset = super().get_queryset() queryset = super().get_queryset()
@ -69,7 +69,7 @@ class InheritanceCreateView(CreateAssignPermView):
raise Http404 from exc raise Http404 from exc
return model().form return model().form
def get_context_data(self, **kwargs: Any) -> Dict[str, Any]: def get_context_data(self, **kwargs: Any) -> dict[str, Any]:
kwargs = super().get_context_data(**kwargs) kwargs = super().get_context_data(**kwargs)
form_cls = self.get_form_class() form_cls = self.get_form_class()
if hasattr(form_cls, "template_name"): if hasattr(form_cls, "template_name"):
@ -80,7 +80,7 @@ class InheritanceCreateView(CreateAssignPermView):
class InheritanceUpdateView(UpdateView): class InheritanceUpdateView(UpdateView):
"""UpdateView for objects using InheritanceManager""" """UpdateView for objects using InheritanceManager"""
def get_context_data(self, **kwargs: Any) -> Dict[str, Any]: def get_context_data(self, **kwargs: Any) -> dict[str, Any]:
kwargs = super().get_context_data(**kwargs) kwargs = super().get_context_data(**kwargs)
form_cls = self.get_form_class() form_cls = self.get_form_class()
if hasattr(form_cls, "template_name"): if hasattr(form_cls, "template_name"):

View file

@ -1,7 +1,7 @@
"""API Authentication""" """API Authentication"""
from base64 import b64decode from base64 import b64decode
from binascii import Error from binascii import Error
from typing import Any, Optional, Tuple, Union from typing import Any, Optional, Union
from rest_framework.authentication import BaseAuthentication, get_authorization_header from rest_framework.authentication import BaseAuthentication, get_authorization_header
from rest_framework.request import Request from rest_framework.request import Request
@ -44,7 +44,7 @@ def token_from_header(raw_header: bytes) -> Optional[Token]:
class AuthentikTokenAuthentication(BaseAuthentication): class AuthentikTokenAuthentication(BaseAuthentication):
"""Token-based authentication using HTTP Basic authentication""" """Token-based authentication using HTTP Basic authentication"""
def authenticate(self, request: Request) -> Union[Tuple[User, Any], None]: def authenticate(self, request: Request) -> Union[tuple[User, Any], None]:
"""Token-based authentication using HTTP Basic authentication""" """Token-based authentication using HTTP Basic authentication"""
auth = get_authorization_header(request) auth = get_authorization_header(request)

View file

@ -1,7 +1,7 @@
"""authentik core models""" """authentik core models"""
from datetime import timedelta from datetime import timedelta
from hashlib import sha256 from hashlib import sha256
from typing import Any, Dict, Optional, Type from typing import Any, Optional, Type
from uuid import uuid4 from uuid import uuid4
from django.conf import settings from django.conf import settings
@ -96,7 +96,7 @@ class User(GuardianUserMixin, AbstractUser):
objects = UserManager() objects = UserManager()
def group_attributes(self) -> Dict[str, Any]: def group_attributes(self) -> dict[str, Any]:
"""Get a dictionary containing the attributes from all groups the user belongs to, """Get a dictionary containing the attributes from all groups the user belongs to,
including the users attributes""" including the users attributes"""
final_attributes = {} final_attributes = {}

View file

@ -1,5 +1,5 @@
"""authentik core user views""" """authentik core user views"""
from typing import Any, Dict from typing import Any
from django.contrib.auth.mixins import LoginRequiredMixin from django.contrib.auth.mixins import LoginRequiredMixin
from django.contrib.auth.mixins import ( from django.contrib.auth.mixins import (
@ -45,7 +45,7 @@ class UserDetailsView(SuccessMessageMixin, LoginRequiredMixin, UpdateView):
def get_object(self): def get_object(self):
return self.request.user return self.request.user
def get_context_data(self, **kwargs: Any) -> Dict[str, Any]: def get_context_data(self, **kwargs: Any) -> dict[str, Any]:
kwargs = super().get_context_data(**kwargs) kwargs = super().get_context_data(**kwargs)
unenrollment_flow = Flow.with_policy( unenrollment_flow = Flow.with_policy(
self.request, designation=FlowDesignation.UNRENOLLMENT self.request, designation=FlowDesignation.UNRENOLLMENT

View file

@ -3,7 +3,7 @@ from dataclasses import dataclass, field
from datetime import datetime from datetime import datetime
from enum import Enum from enum import Enum
from traceback import format_tb from traceback import format_tb
from typing import Any, Dict, List, Optional from typing import Any, Optional
from celery import Task from celery import Task
from django.core.cache import cache from django.core.cache import cache
@ -26,7 +26,7 @@ class TaskResult:
status: TaskResultStatus status: TaskResultStatus
messages: List[str] = field(default_factory=list) messages: list[str] = field(default_factory=list)
# Optional UID used in cache for tasks that run in different instances # Optional UID used in cache for tasks that run in different instances
uid: Optional[str] = field(default=None) uid: Optional[str] = field(default=None)
@ -49,8 +49,8 @@ class TaskInfo:
task_call_module: str task_call_module: str
task_call_func: str task_call_func: str
task_call_args: List[Any] = field(default_factory=list) task_call_args: list[Any] = field(default_factory=list)
task_call_kwargs: Dict[str, Any] = field(default_factory=dict) task_call_kwargs: dict[str, Any] = field(default_factory=dict)
task_description: Optional[str] = field(default=None) task_description: Optional[str] = field(default=None)
@ -60,7 +60,7 @@ class TaskInfo:
return self.task_name.split("_") return self.task_name.split("_")
@staticmethod @staticmethod
def all() -> Dict[str, "TaskInfo"]: def all() -> dict[str, "TaskInfo"]:
"""Get all TaskInfo objects""" """Get all TaskInfo objects"""
return cache.get_many(cache.keys("task_*")) return cache.get_many(cache.keys("task_*"))
@ -109,7 +109,7 @@ class MonitoredTask(Task):
# pylint: disable=too-many-arguments # pylint: disable=too-many-arguments
def after_return( def after_return(
self, status, retval, task_id, args: List[Any], kwargs: Dict[str, Any], einfo self, status, retval, task_id, args: list[Any], kwargs: dict[str, Any], einfo
): ):
if not self._result.uid: if not self._result.uid:
self._result.uid = self._uid self._result.uid = self._uid

View file

@ -1,6 +1,6 @@
"""authentik events signal listener""" """authentik events signal listener"""
from threading import Thread from threading import Thread
from typing import Any, Dict, Optional from typing import Any, Optional
from django.contrib.auth.signals import ( from django.contrib.auth.signals import (
user_logged_in, user_logged_in,
@ -27,7 +27,7 @@ class EventNewThread(Thread):
action: str action: str
request: HttpRequest request: HttpRequest
kwargs: Dict[str, Any] kwargs: dict[str, Any]
user: Optional[User] = None user: Optional[User] = None
def __init__( def __init__(
@ -69,7 +69,7 @@ def on_user_logged_out(sender, request: HttpRequest, user: User, **_):
@receiver(user_write) @receiver(user_write)
# pylint: disable=unused-argument # pylint: disable=unused-argument
def on_user_write( def on_user_write(
sender, request: HttpRequest, user: User, data: Dict[str, Any], **kwargs sender, request: HttpRequest, user: User, data: dict[str, Any], **kwargs
): ):
"""Log User write""" """Log User write"""
thread = EventNewThread(EventAction.USER_WRITE, request, **data) thread = EventNewThread(EventAction.USER_WRITE, request, **data)
@ -81,7 +81,7 @@ def on_user_write(
@receiver(user_login_failed) @receiver(user_login_failed)
# pylint: disable=unused-argument # pylint: disable=unused-argument
def on_user_login_failed( def on_user_login_failed(
sender, credentials: Dict[str, str], request: HttpRequest, **_ sender, credentials: dict[str, str], request: HttpRequest, **_
): ):
"""Failed Login""" """Failed Login"""
thread = EventNewThread(EventAction.LOGIN_FAILED, request, **credentials) thread = EventNewThread(EventAction.LOGIN_FAILED, request, **credentials)

View file

@ -1,7 +1,7 @@
"""event utilities""" """event utilities"""
import re import re
from dataclasses import asdict, is_dataclass from dataclasses import asdict, is_dataclass
from typing import Any, Dict, Optional from typing import Any, Optional
from uuid import UUID from uuid import UUID
from django.contrib.auth.models import AnonymousUser from django.contrib.auth.models import AnonymousUser
@ -20,7 +20,7 @@ from authentik.policies.types import PolicyRequest
ALLOWED_SPECIAL_KEYS = re.compile("passing", flags=re.I) ALLOWED_SPECIAL_KEYS = re.compile("passing", flags=re.I)
def cleanse_dict(source: Dict[Any, Any]) -> Dict[Any, Any]: def cleanse_dict(source: dict[Any, Any]) -> dict[Any, Any]:
"""Cleanse a dictionary, recursively""" """Cleanse a dictionary, recursively"""
final_dict = {} final_dict = {}
for key, value in source.items(): for key, value in source.items():
@ -38,7 +38,7 @@ def cleanse_dict(source: Dict[Any, Any]) -> Dict[Any, Any]:
return final_dict return final_dict
def model_to_dict(model: Model) -> Dict[str, Any]: def model_to_dict(model: Model) -> dict[str, Any]:
"""Convert model to dict""" """Convert model to dict"""
name = str(model) name = str(model)
if hasattr(model, "name"): if hasattr(model, "name"):
@ -51,7 +51,7 @@ def model_to_dict(model: Model) -> Dict[str, Any]:
} }
def get_user(user: User, original_user: Optional[User] = None) -> Dict[str, Any]: def get_user(user: User, original_user: Optional[User] = None) -> dict[str, Any]:
"""Convert user object to dictionary, optionally including the original user""" """Convert user object to dictionary, optionally including the original user"""
if isinstance(user, AnonymousUser): if isinstance(user, AnonymousUser):
user = get_anonymous_user() user = get_anonymous_user()
@ -67,7 +67,7 @@ def get_user(user: User, original_user: Optional[User] = None) -> Dict[str, Any]
return user_data return user_data
def sanitize_dict(source: Dict[Any, Any]) -> Dict[Any, Any]: def sanitize_dict(source: dict[Any, Any]) -> dict[Any, Any]:
"""clean source of all Models that would interfere with the JSONField. """clean source of all Models that would interfere with the JSONField.
Models are replaced with a dictionary of { Models are replaced with a dictionary of {
app: str, app: str,

View file

@ -1,6 +1,6 @@
"""Flows Planner""" """Flows Planner"""
from dataclasses import dataclass, field from dataclasses import dataclass, field
from typing import Any, Dict, List, Optional from typing import Any, Optional
from django.core.cache import cache from django.core.cache import cache
from django.http import HttpRequest from django.http import HttpRequest
@ -38,9 +38,9 @@ class FlowPlan:
flow_pk: str flow_pk: str
stages: List[Stage] = field(default_factory=list) stages: list[Stage] = field(default_factory=list)
context: Dict[str, Any] = field(default_factory=dict) context: dict[str, Any] = field(default_factory=dict)
markers: List[StageMarker] = field(default_factory=list) markers: list[StageMarker] = field(default_factory=list)
def append(self, stage: Stage, marker: Optional[StageMarker] = None): def append(self, stage: Stage, marker: Optional[StageMarker] = None):
"""Append `stage` to all stages, optionall with stage marker""" """Append `stage` to all stages, optionall with stage marker"""
@ -96,7 +96,7 @@ class FlowPlanner:
self._logger = get_logger().bind(flow=flow) self._logger = get_logger().bind(flow=flow)
def plan( def plan(
self, request: HttpRequest, default_context: Optional[Dict[str, Any]] = None self, request: HttpRequest, default_context: Optional[dict[str, Any]] = None
) -> FlowPlan: ) -> FlowPlan:
"""Check each of the flows' policies, check policies for each stage with PolicyBinding """Check each of the flows' policies, check policies for each stage with PolicyBinding
and return ordered list""" and return ordered list"""
@ -149,7 +149,7 @@ class FlowPlanner:
self, self,
user: User, user: User,
request: HttpRequest, request: HttpRequest,
default_context: Optional[Dict[str, Any]], default_context: Optional[dict[str, Any]],
) -> FlowPlan: ) -> FlowPlan:
"""Build flow plan by checking each stage in their respective """Build flow plan by checking each stage in their respective
order and checking the applied policies""" order and checking the applied policies"""

View file

@ -1,6 +1,6 @@
"""authentik stage Base view""" """authentik stage Base view"""
from collections import namedtuple from collections import namedtuple
from typing import Any, Dict from typing import Any
from django.http import HttpRequest from django.http import HttpRequest
from django.http.response import HttpResponse, JsonResponse from django.http.response import HttpResponse, JsonResponse
@ -32,7 +32,7 @@ class StageView(TemplateView):
def __init__(self, executor: FlowExecutorView): def __init__(self, executor: FlowExecutorView):
self.executor = executor self.executor = executor
def get_context_data(self, **kwargs: Any) -> Dict[str, Any]: def get_context_data(self, **kwargs: Any) -> dict[str, Any]:
kwargs["title"] = self.executor.flow.title kwargs["title"] = self.executor.flow.title
# Either show the matched User object or show what the user entered, # Either show the matched User object or show what the user entered,
# based on what the earlier stage (mostly IdentificationStage) set. # based on what the earlier stage (mostly IdentificationStage) set.

View file

@ -1,6 +1,6 @@
"""transfer common classes""" """transfer common classes"""
from dataclasses import asdict, dataclass, field, is_dataclass from dataclasses import asdict, dataclass, field, is_dataclass
from typing import Any, Dict, List from typing import Any
from uuid import UUID from uuid import UUID
from django.core.serializers.json import DjangoJSONEncoder from django.core.serializers.json import DjangoJSONEncoder
@ -9,7 +9,7 @@ from authentik.lib.models import SerializerModel
from authentik.lib.sentry import SentryIgnoredException from authentik.lib.sentry import SentryIgnoredException
def get_attrs(obj: SerializerModel) -> Dict[str, Any]: def get_attrs(obj: SerializerModel) -> dict[str, Any]:
"""Get object's attributes via their serializer, and covert it to a normal dict""" """Get object's attributes via their serializer, and covert it to a normal dict"""
data = dict(obj.serializer(obj).data) data = dict(obj.serializer(obj).data)
to_remove = ( to_remove = (
@ -33,9 +33,9 @@ def get_attrs(obj: SerializerModel) -> Dict[str, Any]:
class FlowBundleEntry: class FlowBundleEntry:
"""Single entry of a bundle""" """Single entry of a bundle"""
identifiers: Dict[str, Any] identifiers: dict[str, Any]
model: str model: str
attrs: Dict[str, Any] attrs: dict[str, Any]
@staticmethod @staticmethod
def from_model( def from_model(
@ -61,7 +61,7 @@ class FlowBundle:
"""Dataclass used for a full export""" """Dataclass used for a full export"""
version: int = field(default=1) version: int = field(default=1)
entries: List[FlowBundleEntry] = field(default_factory=list) entries: list[FlowBundleEntry] = field(default_factory=list)
class DataclassEncoder(DjangoJSONEncoder): class DataclassEncoder(DjangoJSONEncoder):

View file

@ -1,6 +1,6 @@
"""Flow exporter""" """Flow exporter"""
from json import dumps from json import dumps
from typing import Iterator, List from typing import Iterator
from uuid import UUID from uuid import UUID
from django.db.models import Q from django.db.models import Q
@ -22,7 +22,7 @@ class FlowExporter:
with_policies: bool with_policies: bool
with_stage_prompts: bool with_stage_prompts: bool
pbm_uuids: List[UUID] pbm_uuids: list[UUID]
def __init__(self, flow: Flow): def __init__(self, flow: Flow):
self.flow = flow self.flow = flow

View file

@ -2,7 +2,7 @@
from contextlib import contextmanager from contextlib import contextmanager
from copy import deepcopy from copy import deepcopy
from json import loads from json import loads
from typing import Any, Dict, Type from typing import Any, Type
from dacite import from_dict from dacite import from_dict
from dacite.exceptions import DaciteError from dacite.exceptions import DaciteError
@ -42,7 +42,7 @@ class FlowImporter:
__import: FlowBundle __import: FlowBundle
__pk_map: Dict[Any, Model] __pk_map: dict[Any, Model]
logger: BoundLogger logger: BoundLogger
@ -55,7 +55,7 @@ class FlowImporter:
except DaciteError as exc: except DaciteError as exc:
raise EntryInvalidError from exc raise EntryInvalidError from exc
def __update_pks_for_attrs(self, attrs: Dict[str, Any]) -> Dict[str, Any]: def __update_pks_for_attrs(self, attrs: dict[str, Any]) -> dict[str, Any]:
"""Replace any value if it is a known primary key of an other object""" """Replace any value if it is a known primary key of an other object"""
def updater(value) -> Any: def updater(value) -> Any:
@ -75,7 +75,7 @@ class FlowImporter:
attrs[key] = updater(value) attrs[key] = updater(value)
return attrs return attrs
def __query_from_identifier(self, attrs: Dict[str, Any]) -> Q: def __query_from_identifier(self, attrs: dict[str, Any]) -> Q:
"""Generate an or'd query from all identifiers in an entry""" """Generate an or'd query from all identifiers in an entry"""
# Since identifiers can also be pk-references to other objects (see FlowStageBinding) # Since identifiers can also be pk-references to other objects (see FlowStageBinding)
# we have to ensure those references are also replaced # we have to ensure those references are also replaced

View file

@ -1,6 +1,6 @@
"""authentik multi-stage authentication engine""" """authentik multi-stage authentication engine"""
from traceback import format_tb from traceback import format_tb
from typing import Any, Dict, Optional from typing import Any, Optional
from django.contrib.auth.mixins import LoginRequiredMixin from django.contrib.auth.mixins import LoginRequiredMixin
from django.http import ( from django.http import (
@ -225,8 +225,8 @@ class FlowErrorResponse(TemplateResponse):
self.error = error self.error = error
def resolve_context( def resolve_context(
self, context: Optional[Dict[str, Any]] self, context: Optional[dict[str, Any]]
) -> Optional[Dict[str, Any]]: ) -> Optional[dict[str, Any]]:
if not context: if not context:
context = {} context = {}
context["error"] = self.error context["error"] = self.error
@ -244,7 +244,7 @@ class FlowExecutorShellView(TemplateView):
template_name = "flows/shell.html" template_name = "flows/shell.html"
def get_context_data(self, **kwargs) -> Dict[str, Any]: def get_context_data(self, **kwargs) -> dict[str, Any]:
flow: Flow = get_object_or_404(Flow, slug=self.kwargs.get("flow_slug")) flow: Flow = get_object_or_404(Flow, slug=self.kwargs.get("flow_slug"))
kwargs["background_url"] = flow.background.url kwargs["background_url"] = flow.background.url
kwargs["exec_url"] = reverse("authentik_api:flow-executor", kwargs=self.kwargs) kwargs["exec_url"] = reverse("authentik_api:flow-executor", kwargs=self.kwargs)

View file

@ -1,7 +1,7 @@
"""authentik expression policy evaluator""" """authentik expression policy evaluator"""
import re import re
from textwrap import indent from textwrap import indent
from typing import Any, Dict, Iterable, Optional from typing import Any, Iterable, Optional
from django.core.exceptions import ValidationError from django.core.exceptions import ValidationError
from requests import Session from requests import Session
@ -18,9 +18,9 @@ class BaseEvaluator:
"""Validate and evaluate python-based expressions""" """Validate and evaluate python-based expressions"""
# Globals that can be used by function # Globals that can be used by function
_globals: Dict[str, Any] _globals: dict[str, Any]
# Context passed as locals to exec() # Context passed as locals to exec()
_context: Dict[str, Any] _context: dict[str, Any]
# Filename used for exec # Filename used for exec
_filename: str _filename: str

View file

@ -1,10 +1,10 @@
"""http helpers""" """http helpers"""
from typing import Any, Dict, Optional from typing import Any, Optional
from django.http import HttpRequest from django.http import HttpRequest
def _get_client_ip_from_meta(meta: Dict[str, Any]) -> Optional[str]: def _get_client_ip_from_meta(meta: dict[str, Any]) -> Optional[str]:
"""Attempt to get the client's IP by checking common HTTP Headers. """Attempt to get the client's IP by checking common HTTP Headers.
Returns none if no IP Could be found""" Returns none if no IP Could be found"""
headers = ( headers = (

View file

@ -1,8 +1,8 @@
"""authentik UI utils""" """authentik UI utils"""
from typing import Any, List from typing import Any
def human_list(_list: List[Any]) -> str: def human_list(_list: list[Any]) -> str:
"""Convert a list of items into 'a, b or c'""" """Convert a list of items into 'a, b or c'"""
last_item = _list.pop() last_item = _list.pop()
if len(_list) < 1: if len(_list) < 1:

View file

@ -2,7 +2,7 @@
from dataclasses import asdict, dataclass, field from dataclasses import asdict, dataclass, field
from datetime import datetime from datetime import datetime
from enum import IntEnum from enum import IntEnum
from typing import Any, Dict, Optional from typing import Any, Optional
from channels.exceptions import DenyConnection from channels.exceptions import DenyConnection
from dacite import from_dict from dacite import from_dict
@ -34,7 +34,7 @@ class WebsocketMessage:
"""Complete Websocket Message that is being sent""" """Complete Websocket Message that is being sent"""
instruction: int instruction: int
args: Dict[str, Any] = field(default_factory=dict) args: dict[str, Any] = field(default_factory=dict)
class OutpostConsumer(AuthJsonConsumer): class OutpostConsumer(AuthJsonConsumer):

View file

@ -1,6 +1,5 @@
"""Docker controller""" """Docker controller"""
from time import sleep from time import sleep
from typing import Dict, Tuple
from django.conf import settings from django.conf import settings
from docker import DockerClient from docker import DockerClient
@ -33,10 +32,10 @@ class DockerController(BaseController):
except ServiceConnectionInvalid as exc: except ServiceConnectionInvalid as exc:
raise ControllerException from exc raise ControllerException from exc
def _get_labels(self) -> Dict[str, str]: def _get_labels(self) -> dict[str, str]:
return {} return {}
def _get_env(self) -> Dict[str, str]: def _get_env(self) -> dict[str, str]:
return { return {
"AUTHENTIK_HOST": self.outpost.config.authentik_host, "AUTHENTIK_HOST": self.outpost.config.authentik_host,
"AUTHENTIK_INSECURE": str(self.outpost.config.authentik_host_insecure), "AUTHENTIK_INSECURE": str(self.outpost.config.authentik_host_insecure),
@ -55,7 +54,7 @@ class DockerController(BaseController):
return True return True
return False return False
def _get_container(self) -> Tuple[Container, bool]: def _get_container(self) -> tuple[Container, bool]:
container_name = f"authentik-proxy-{self.outpost.uuid.hex}" container_name = f"authentik-proxy-{self.outpost.uuid.hex}"
try: try:
return self.client.containers.get(container_name), False return self.client.containers.get(container_name), False

View file

@ -1,5 +1,5 @@
"""Kubernetes Deployment Reconciler""" """Kubernetes Deployment Reconciler"""
from typing import TYPE_CHECKING, Dict from typing import TYPE_CHECKING
from kubernetes.client import ( from kubernetes.client import (
AppsV1Api, AppsV1Api,
@ -53,7 +53,7 @@ class DeploymentReconciler(KubernetesObjectReconciler[V1Deployment]):
): ):
raise NeedsUpdate() raise NeedsUpdate()
def get_pod_meta(self) -> Dict[str, str]: def get_pod_meta(self) -> dict[str, str]:
"""Get common object metadata""" """Get common object metadata"""
return { return {
"app.kubernetes.io/name": "authentik-outpost", "app.kubernetes.io/name": "authentik-outpost",

View file

@ -1,6 +1,6 @@
"""Kubernetes deployment controller""" """Kubernetes deployment controller"""
from io import StringIO from io import StringIO
from typing import Dict, List, Type from typing import Type
from kubernetes.client import OpenApiException from kubernetes.client import OpenApiException
from kubernetes.client.api_client import ApiClient from kubernetes.client.api_client import ApiClient
@ -18,8 +18,8 @@ from authentik.outposts.models import KubernetesServiceConnection, Outpost
class KubernetesController(BaseController): class KubernetesController(BaseController):
"""Manage deployment of outpost in kubernetes""" """Manage deployment of outpost in kubernetes"""
reconcilers: Dict[str, Type[KubernetesObjectReconciler]] reconcilers: dict[str, Type[KubernetesObjectReconciler]]
reconcile_order: List[str] reconcile_order: list[str]
client: ApiClient client: ApiClient
connection: KubernetesServiceConnection connection: KubernetesServiceConnection
@ -45,7 +45,7 @@ class KubernetesController(BaseController):
except OpenApiException as exc: except OpenApiException as exc:
raise ControllerException from exc raise ControllerException from exc
def up_with_logs(self) -> List[str]: def up_with_logs(self) -> list[str]:
try: try:
all_logs = [] all_logs = []
for reconcile_key in self.reconcile_order: for reconcile_key in self.reconcile_order:

View file

@ -1,7 +1,7 @@
"""Outpost models""" """Outpost models"""
from dataclasses import asdict, dataclass, field from dataclasses import asdict, dataclass, field
from datetime import datetime from datetime import datetime
from typing import Dict, Iterable, List, Optional, Type, Union from typing import Iterable, Optional, Type, Union
from uuid import uuid4 from uuid import uuid4
from dacite import from_dict from dacite import from_dict
@ -58,7 +58,7 @@ class OutpostConfig:
kubernetes_replicas: int = field(default=1) kubernetes_replicas: int = field(default=1)
kubernetes_namespace: str = field(default="default") kubernetes_namespace: str = field(default="default")
kubernetes_ingress_annotations: Dict[str, str] = field(default_factory=dict) kubernetes_ingress_annotations: dict[str, str] = field(default_factory=dict)
kubernetes_ingress_secret_name: str = field(default="authentik-outpost") kubernetes_ingress_secret_name: str = field(default="authentik-outpost")
@ -315,7 +315,7 @@ class Outpost(models.Model):
return f"outpost_{self.uuid.hex}_state" return f"outpost_{self.uuid.hex}_state"
@property @property
def state(self) -> List["OutpostState"]: def state(self) -> list["OutpostState"]:
"""Get outpost's health status""" """Get outpost's health status"""
return OutpostState.for_outpost(self) return OutpostState.for_outpost(self)
@ -399,7 +399,7 @@ class OutpostState:
return parse(self.version) < OUR_VERSION return parse(self.version) < OUR_VERSION
@staticmethod @staticmethod
def for_outpost(outpost: Outpost) -> List["OutpostState"]: def for_outpost(outpost: Outpost) -> list["OutpostState"]:
"""Get all states for an outpost""" """Get all states for an outpost"""
keys = cache.keys(f"{outpost.state_cache_prefix}_*") keys = cache.keys(f"{outpost.state_cache_prefix}_*")
states = [] states = []

View file

@ -2,7 +2,7 @@
from enum import Enum from enum import Enum
from multiprocessing import Pipe, current_process from multiprocessing import Pipe, current_process
from multiprocessing.connection import Connection from multiprocessing.connection import Connection
from typing import Iterator, List, Optional from typing import Iterator, Optional
from django.core.cache import cache from django.core.cache import cache
from django.http import HttpRequest from django.http import HttpRequest
@ -54,8 +54,8 @@ class PolicyEngine:
empty_result: bool empty_result: bool
__pbm: PolicyBindingModel __pbm: PolicyBindingModel
__cached_policies: List[PolicyResult] __cached_policies: list[PolicyResult]
__processes: List[PolicyProcessInfo] __processes: list[PolicyProcessInfo]
__expected_result_count: int __expected_result_count: int
@ -137,7 +137,7 @@ class PolicyEngine:
@property @property
def result(self) -> PolicyResult: def result(self) -> PolicyResult:
"""Get policy-checking result""" """Get policy-checking result"""
process_results: List[PolicyResult] = [ process_results: list[PolicyResult] = [
x.result for x in self.__processes if x.result x.result for x in self.__processes if x.result
] ]
all_results = list(process_results + self.__cached_policies) all_results = list(process_results + self.__cached_policies)

View file

@ -1,6 +1,6 @@
"""authentik expression policy evaluator""" """authentik expression policy evaluator"""
from ipaddress import ip_address, ip_network from ipaddress import ip_address, ip_network
from typing import TYPE_CHECKING, List, Optional from typing import TYPE_CHECKING, Optional
from django.http import HttpRequest from django.http import HttpRequest
from structlog.stdlib import get_logger from structlog.stdlib import get_logger
@ -19,7 +19,7 @@ if TYPE_CHECKING:
class PolicyEvaluator(BaseEvaluator): class PolicyEvaluator(BaseEvaluator):
"""Validate and evaluate python-based expressions""" """Validate and evaluate python-based expressions"""
_messages: List[str] _messages: list[str]
policy: Optional["ExpressionPolicy"] = None policy: Optional["ExpressionPolicy"] = None

View file

@ -1,5 +1,5 @@
"""policy http response""" """policy http response"""
from typing import Any, Dict, Optional from typing import Any, Optional
from django.http.request import HttpRequest from django.http.request import HttpRequest
from django.template.response import TemplateResponse from django.template.response import TemplateResponse
@ -24,8 +24,8 @@ class AccessDeniedResponse(TemplateResponse):
self.title = _("Access denied") self.title = _("Access denied")
def resolve_context( def resolve_context(
self, context: Optional[Dict[str, Any]] self, context: Optional[dict[str, Any]]
) -> Optional[Dict[str, Any]]: ) -> Optional[dict[str, Any]]:
if not context: if not context:
context = {} context = {}
context["title"] = self.title context["title"] = self.title

View file

@ -1,8 +1,8 @@
"""Policy Utils""" """Policy Utils"""
from typing import Any, Dict from typing import Any
def delete_none_keys(dict_: Dict[Any, Any]) -> Dict[Any, Any]: def delete_none_keys(dict_: dict[Any, Any]) -> dict[Any, Any]:
"""Remove any keys from `dict_` that are None.""" """Remove any keys from `dict_` that are None."""
new_dict = {} new_dict = {}
for key, value in dict_.items(): for key, value in dict_.items():

View file

@ -5,7 +5,7 @@ import json
import time import time
from dataclasses import asdict, dataclass, field from dataclasses import asdict, dataclass, field
from hashlib import sha256 from hashlib import sha256
from typing import Any, Dict, List, Optional, Type from typing import Any, Optional, Type
from urllib.parse import urlparse from urllib.parse import urlparse
from uuid import uuid4 from uuid import uuid4
@ -218,7 +218,7 @@ class OAuth2Provider(Provider):
) )
def create_refresh_token( def create_refresh_token(
self, user: User, scope: List[str], request: HttpRequest self, user: User, scope: list[str], request: HttpRequest
) -> "RefreshToken": ) -> "RefreshToken":
"""Create and populate a RefreshToken object.""" """Create and populate a RefreshToken object."""
token = RefreshToken( token = RefreshToken(
@ -231,7 +231,7 @@ class OAuth2Provider(Provider):
token.access_token = token.create_access_token(user, request) token.access_token = token.create_access_token(user, request)
return token return token
def get_jwt_keys(self) -> List[Key]: def get_jwt_keys(self) -> list[Key]:
""" """
Takes a provider and returns the set of keys associated with it. Takes a provider and returns the set of keys associated with it.
Returns a list of keys. Returns a list of keys.
@ -299,7 +299,7 @@ class OAuth2Provider(Provider):
def __str__(self): def __str__(self):
return f"OAuth2 Provider {self.name}" return f"OAuth2 Provider {self.name}"
def encode(self, payload: Dict[str, Any]) -> str: def encode(self, payload: dict[str, Any]) -> str:
"""Represent the ID Token as a JSON Web Token (JWT).""" """Represent the ID Token as a JSON Web Token (JWT)."""
keys = self.get_jwt_keys() keys = self.get_jwt_keys()
# If the provider does not have an RSA Key assigned, it was switched to Symmetric # If the provider does not have an RSA Key assigned, it was switched to Symmetric
@ -321,7 +321,7 @@ class BaseGrantModel(models.Model):
_scope = models.TextField(default="", verbose_name=_("Scopes")) _scope = models.TextField(default="", verbose_name=_("Scopes"))
@property @property
def scope(self) -> List[str]: def scope(self) -> list[str]:
"""Return scopes as list of strings""" """Return scopes as list of strings"""
return self._scope.split() return self._scope.split()
@ -394,9 +394,9 @@ class IDToken:
nonce: Optional[str] = None nonce: Optional[str] = None
at_hash: Optional[str] = None at_hash: Optional[str] = None
claims: Dict[str, Any] = field(default_factory=dict) claims: dict[str, Any] = field(default_factory=dict)
def to_dict(self) -> Dict[str, Any]: def to_dict(self) -> dict[str, Any]:
"""Convert dataclass to dict, and update with keys from `claims`""" """Convert dataclass to dict, and update with keys from `claims`"""
dic = asdict(self) dic = asdict(self)
dic.pop("claims") dic.pop("claims")

View file

@ -2,7 +2,7 @@
import re import re
from base64 import b64decode from base64 import b64decode
from binascii import Error from binascii import Error
from typing import List, Optional, Tuple from typing import Optional
from django.http import HttpRequest, HttpResponse, JsonResponse from django.http import HttpRequest, HttpResponse, JsonResponse
from django.utils.cache import patch_vary_headers from django.utils.cache import patch_vary_headers
@ -68,7 +68,7 @@ def extract_access_token(request: HttpRequest) -> Optional[str]:
return None return None
def extract_client_auth(request: HttpRequest) -> Tuple[str, str]: def extract_client_auth(request: HttpRequest) -> tuple[str, str]:
""" """
Get client credentials using HTTP Basic Authentication method. Get client credentials using HTTP Basic Authentication method.
Or try getting parameters via POST. Or try getting parameters via POST.
@ -92,7 +92,7 @@ def extract_client_auth(request: HttpRequest) -> Tuple[str, str]:
return (client_id, client_secret) return (client_id, client_secret)
def protected_resource_view(scopes: List[str]): def protected_resource_view(scopes: list[str]):
"""View decorator. The client accesses protected resources by presenting the """View decorator. The client accesses protected resources by presenting the
access token to the resource server. access token to the resource server.

View file

@ -1,7 +1,7 @@
"""authentik OAuth2 Authorization views""" """authentik OAuth2 Authorization views"""
from dataclasses import dataclass, field from dataclasses import dataclass, field
from datetime import timedelta from datetime import timedelta
from typing import List, Optional, Set from typing import Optional
from urllib.parse import parse_qs, urlencode, urlsplit, urlunsplit from urllib.parse import parse_qs, urlencode, urlsplit, urlunsplit
from uuid import uuid4 from uuid import uuid4
@ -69,10 +69,10 @@ class OAuthAuthorizationParams:
client_id: str client_id: str
redirect_uri: str redirect_uri: str
response_type: str response_type: str
scope: List[str] scope: list[str]
state: str state: str
nonce: Optional[str] nonce: Optional[str]
prompt: Set[str] prompt: set[str]
grant_type: str grant_type: str
provider: OAuth2Provider = field(default_factory=OAuth2Provider) provider: OAuth2Provider = field(default_factory=OAuth2Provider)

View file

@ -1,5 +1,5 @@
"""authentik OAuth2 OpenID well-known views""" """authentik OAuth2 OpenID well-known views"""
from typing import Any, Dict from typing import Any
from django.http import HttpRequest, HttpResponse, JsonResponse from django.http import HttpRequest, HttpResponse, JsonResponse
from django.shortcuts import get_object_or_404, reverse from django.shortcuts import get_object_or_404, reverse
@ -29,7 +29,7 @@ PLAN_CONTEXT_SCOPES = "scopes"
class ProviderInfoView(View): class ProviderInfoView(View):
"""OpenID-compliant Provider Info""" """OpenID-compliant Provider Info"""
def get_info(self, provider: OAuth2Provider) -> Dict[str, Any]: def get_info(self, provider: OAuth2Provider) -> dict[str, Any]:
"""Get dictionary for OpenID Connect information""" """Get dictionary for OpenID Connect information"""
scopes = list( scopes = list(
ScopeMapping.objects.filter(provider=provider).values_list( ScopeMapping.objects.filter(provider=provider).values_list(

View file

@ -1,5 +1,5 @@
"""authentik OAuth2 Session Views""" """authentik OAuth2 Session Views"""
from typing import Any, Dict from typing import Any
from django.shortcuts import get_object_or_404 from django.shortcuts import get_object_or_404
from django.views.generic.base import TemplateView from django.views.generic.base import TemplateView
@ -12,7 +12,7 @@ class EndSessionView(TemplateView):
template_name = "providers/oauth2/end_session.html" template_name = "providers/oauth2/end_session.html"
def get_context_data(self, **kwargs: Any) -> Dict[str, Any]: def get_context_data(self, **kwargs: Any) -> dict[str, Any]:
context = super().get_context_data(**kwargs) context = super().get_context_data(**kwargs)
context["application"] = get_object_or_404( context["application"] = get_object_or_404(

View file

@ -2,7 +2,7 @@
from base64 import urlsafe_b64encode from base64 import urlsafe_b64encode
from dataclasses import InitVar, dataclass from dataclasses import InitVar, dataclass
from hashlib import sha256 from hashlib import sha256
from typing import Any, Dict, List, Optional from typing import Any, Optional
from django.http import HttpRequest, HttpResponse from django.http import HttpRequest, HttpResponse
from django.views import View from django.views import View
@ -33,7 +33,7 @@ class TokenParams:
redirect_uri: str redirect_uri: str
grant_type: str grant_type: str
state: str state: str
scope: List[str] scope: list[str]
authorization_code: Optional[AuthorizationCode] = None authorization_code: Optional[AuthorizationCode] = None
refresh_token: Optional[RefreshToken] = None refresh_token: Optional[RefreshToken] = None
@ -171,7 +171,7 @@ class TokenView(View):
except UserAuthError as error: except UserAuthError as error:
return TokenResponse(error.create_dict(), status=403) return TokenResponse(error.create_dict(), status=403)
def create_code_response_dic(self) -> Dict[str, Any]: def create_code_response_dic(self) -> dict[str, Any]:
"""See https://tools.ietf.org/html/rfc6749#section-4.1""" """See https://tools.ietf.org/html/rfc6749#section-4.1"""
refresh_token = self.params.authorization_code.provider.create_refresh_token( refresh_token = self.params.authorization_code.provider.create_refresh_token(
@ -207,7 +207,7 @@ class TokenView(View):
return response_dict return response_dict
def create_refresh_response_dic(self) -> Dict[str, Any]: def create_refresh_response_dic(self) -> dict[str, Any]:
"""See https://tools.ietf.org/html/rfc6749#section-6""" """See https://tools.ietf.org/html/rfc6749#section-6"""
unauthorized_scopes = set(self.params.scope) - set( unauthorized_scopes = set(self.params.scope) - set(

View file

@ -1,5 +1,5 @@
"""authentik OAuth2 OpenID Userinfo views""" """authentik OAuth2 OpenID Userinfo views"""
from typing import Any, Dict, List from typing import Any
from django.http import HttpRequest, HttpResponse from django.http import HttpRequest, HttpResponse
from django.utils.translation import gettext_lazy as _ from django.utils.translation import gettext_lazy as _
@ -22,7 +22,7 @@ class UserInfoView(View):
"""Create a dictionary with all the requested claims about the End-User. """Create a dictionary with all the requested claims about the End-User.
See: http://openid.net/specs/openid-connect-core-1_0.html#UserInfoResponse""" See: http://openid.net/specs/openid-connect-core-1_0.html#UserInfoResponse"""
def get_scope_descriptions(self, scopes: List[str]) -> Dict[str, str]: def get_scope_descriptions(self, scopes: list[str]) -> dict[str, str]:
"""Get a list of all Scopes's descriptions""" """Get a list of all Scopes's descriptions"""
scope_descriptions = {} scope_descriptions = {}
for scope in ScopeMapping.objects.filter(scope_name__in=scopes).order_by( for scope in ScopeMapping.objects.filter(scope_name__in=scopes).order_by(
@ -47,7 +47,7 @@ class UserInfoView(View):
scope_descriptions[scope] = github_scope_map[scope] scope_descriptions[scope] = github_scope_map[scope]
return scope_descriptions return scope_descriptions
def get_claims(self, token: RefreshToken) -> Dict[str, Any]: def get_claims(self, token: RefreshToken) -> dict[str, Any]:
"""Get a dictionary of claims from scopes that the token """Get a dictionary of claims from scopes that the token
requires and are assigned to the provider.""" requires and are assigned to the provider."""

View file

@ -1,5 +1,4 @@
"""Proxy Provider Docker Contoller""" """Proxy Provider Docker Contoller"""
from typing import Dict
from urllib.parse import urlparse from urllib.parse import urlparse
from authentik.outposts.controllers.base import DeploymentPort from authentik.outposts.controllers.base import DeploymentPort
@ -18,7 +17,7 @@ class ProxyDockerController(DockerController):
DeploymentPort(4443, "https", "tcp"), DeploymentPort(4443, "https", "tcp"),
] ]
def _get_labels(self) -> Dict[str, str]: def _get_labels(self) -> dict[str, str]:
hosts = [] hosts = []
for proxy_provider in ProxyProvider.objects.filter(outpost__in=[self.outpost]): for proxy_provider in ProxyProvider.objects.filter(outpost__in=[self.outpost]):
proxy_provider: ProxyProvider proxy_provider: ProxyProvider

View file

@ -1,5 +1,5 @@
"""Kubernetes Ingress Reconciler""" """Kubernetes Ingress Reconciler"""
from typing import TYPE_CHECKING, Dict from typing import TYPE_CHECKING
from urllib.parse import urlparse from urllib.parse import urlparse
from kubernetes.client import ( from kubernetes.client import (
@ -78,7 +78,7 @@ class IngressReconciler(KubernetesObjectReconciler[NetworkingV1beta1Ingress]):
if have_hosts_tls != expected_hosts_tls: if have_hosts_tls != expected_hosts_tls:
raise NeedsUpdate() raise NeedsUpdate()
def get_ingress_annotations(self) -> Dict[str, str]: def get_ingress_annotations(self) -> dict[str, str]:
"""Get ingress annotations""" """Get ingress annotations"""
annotations = { annotations = {
# Ensure that with multiple proxy replicas deployed, the same CSRF request # Ensure that with multiple proxy replicas deployed, the same CSRF request

View file

@ -8,7 +8,7 @@ https://docs.djangoproject.com/en/3.0/howto/deployment/asgi/
""" """
import typing import typing
from time import time from time import time
from typing import Any, ByteString, Dict from typing import Any, ByteString
import django import django
from asgiref.compatibility import guarantee_single_callable from asgiref.compatibility import guarantee_single_callable
@ -64,7 +64,7 @@ class ASGILogger:
app: ASGIApp app: ASGIApp
scope: Scope scope: Scope
headers: Dict[ByteString, Any] headers: dict[ByteString, Any]
status_code: int status_code: int
start: float start: float

View file

@ -1,5 +1,5 @@
"""authentik ldap source signals""" """authentik ldap source signals"""
from typing import Any, Dict from typing import Any
from django.core.exceptions import ValidationError from django.core.exceptions import ValidationError
from django.db.models.signals import post_save from django.db.models.signals import post_save
@ -26,7 +26,7 @@ def sync_ldap_source_on_save(sender, instance: LDAPSource, **_):
@receiver(password_validate) @receiver(password_validate)
# pylint: disable=unused-argument # pylint: disable=unused-argument
def ldap_password_validate(sender, password: str, plan_context: Dict[str, Any], **__): def ldap_password_validate(sender, password: str, plan_context: dict[str, Any], **__):
"""if there's an LDAP Source with enabled password sync, check the password""" """if there's an LDAP Source with enabled password sync, check the password"""
sources = LDAPSource.objects.filter(sync_users_password=True) sources = LDAPSource.objects.filter(sync_users_password=True)
if not sources.exists(): if not sources.exists():

View file

@ -1,5 +1,5 @@
"""OAuth Clients""" """OAuth Clients"""
from typing import Any, Dict, Optional from typing import Any, Optional
from urllib.parse import urlencode from urllib.parse import urlencode
from django.http import HttpRequest from django.http import HttpRequest
@ -33,11 +33,11 @@ class BaseOAuthClient:
self.callback = callback self.callback = callback
self.session.headers.update({"User-Agent": f"authentik {__version__}"}) self.session.headers.update({"User-Agent": f"authentik {__version__}"})
def get_access_token(self, **request_kwargs) -> Optional[Dict[str, Any]]: def get_access_token(self, **request_kwargs) -> Optional[dict[str, Any]]:
"Fetch access token from callback request." "Fetch access token from callback request."
raise NotImplementedError("Defined in a sub-class") # pragma: no cover raise NotImplementedError("Defined in a sub-class") # pragma: no cover
def get_profile_info(self, token: Dict[str, str]) -> Optional[Dict[str, Any]]: def get_profile_info(self, token: dict[str, str]) -> Optional[dict[str, Any]]:
"Fetch user profile information." "Fetch user profile information."
try: try:
response = self.do_request("get", self.source.profile_url, token=token) response = self.do_request("get", self.source.profile_url, token=token)
@ -48,7 +48,7 @@ class BaseOAuthClient:
else: else:
return response.json() return response.json()
def get_redirect_args(self) -> Dict[str, str]: def get_redirect_args(self) -> dict[str, str]:
"Get request parameters for redirect url." "Get request parameters for redirect url."
raise NotImplementedError("Defined in a sub-class") # pragma: no cover raise NotImplementedError("Defined in a sub-class") # pragma: no cover
@ -61,7 +61,7 @@ class BaseOAuthClient:
LOGGER.info("redirect args", **args) LOGGER.info("redirect args", **args)
return f"{self.source.authorization_url}?{params}" return f"{self.source.authorization_url}?{params}"
def parse_raw_token(self, raw_token: str) -> Dict[str, Any]: def parse_raw_token(self, raw_token: str) -> dict[str, Any]:
"Parse token and secret from raw token response." "Parse token and secret from raw token response."
raise NotImplementedError("Defined in a sub-class") # pragma: no cover raise NotImplementedError("Defined in a sub-class") # pragma: no cover

View file

@ -1,5 +1,5 @@
"""OAuth 1 Clients""" """OAuth 1 Clients"""
from typing import Any, Dict, Optional from typing import Any, Optional
from urllib.parse import parse_qsl from urllib.parse import parse_qsl
from requests.exceptions import RequestException from requests.exceptions import RequestException
@ -20,7 +20,7 @@ class OAuthClient(BaseOAuthClient):
"Accept": "application/json", "Accept": "application/json",
} }
def get_access_token(self, **request_kwargs) -> Optional[Dict[str, Any]]: def get_access_token(self, **request_kwargs) -> Optional[dict[str, Any]]:
"Fetch access token from callback request." "Fetch access token from callback request."
raw_token = self.request.session.get(self.session_key, None) raw_token = self.request.session.get(self.session_key, None)
verifier = self.request.GET.get("oauth_verifier", None) verifier = self.request.GET.get("oauth_verifier", None)
@ -60,7 +60,7 @@ class OAuthClient(BaseOAuthClient):
else: else:
return response.text return response.text
def get_redirect_args(self) -> Dict[str, Any]: def get_redirect_args(self) -> dict[str, Any]:
"Get request parameters for redirect url." "Get request parameters for redirect url."
callback = self.request.build_absolute_uri(self.callback) callback = self.request.build_absolute_uri(self.callback)
raw_token = self.get_request_token() raw_token = self.get_request_token()
@ -71,7 +71,7 @@ class OAuthClient(BaseOAuthClient):
"oauth_callback": callback, "oauth_callback": callback,
} }
def parse_raw_token(self, raw_token: str) -> Dict[str, Any]: def parse_raw_token(self, raw_token: str) -> dict[str, Any]:
"Parse token and secret from raw token response." "Parse token and secret from raw token response."
return dict(parse_qsl(raw_token)) return dict(parse_qsl(raw_token))
@ -80,7 +80,7 @@ class OAuthClient(BaseOAuthClient):
resource_owner_key = None resource_owner_key = None
resource_owner_secret = None resource_owner_secret = None
if "token" in kwargs: if "token" in kwargs:
user_token: Dict[str, Any] = kwargs.pop("token") user_token: dict[str, Any] = kwargs.pop("token")
resource_owner_key = user_token["oauth_token"] resource_owner_key = user_token["oauth_token"]
resource_owner_secret = user_token["oauth_token_secret"] resource_owner_secret = user_token["oauth_token_secret"]

View file

@ -1,6 +1,6 @@
"""OAuth 2 Clients""" """OAuth 2 Clients"""
from json import loads from json import loads
from typing import Any, Dict, Optional from typing import Any, Optional
from urllib.parse import parse_qsl from urllib.parse import parse_qsl
from django.utils.crypto import constant_time_compare, get_random_string from django.utils.crypto import constant_time_compare, get_random_string
@ -38,7 +38,7 @@ class OAuth2Client(BaseOAuthClient):
"Generate state optional parameter." "Generate state optional parameter."
return get_random_string(32) return get_random_string(32)
def get_access_token(self, **request_kwargs) -> Optional[Dict[str, Any]]: def get_access_token(self, **request_kwargs) -> Optional[dict[str, Any]]:
"Fetch access token from callback request." "Fetch access token from callback request."
callback = self.request.build_absolute_uri(self.callback or self.request.path) callback = self.request.build_absolute_uri(self.callback or self.request.path)
if not self.check_application_state(): if not self.check_application_state():
@ -69,11 +69,11 @@ class OAuth2Client(BaseOAuthClient):
else: else:
return response.json() return response.json()
def get_redirect_args(self) -> Dict[str, str]: def get_redirect_args(self) -> dict[str, str]:
"Get request parameters for redirect url." "Get request parameters for redirect url."
callback = self.request.build_absolute_uri(self.callback) callback = self.request.build_absolute_uri(self.callback)
client_id: str = self.source.consumer_key client_id: str = self.source.consumer_key
args: Dict[str, str] = { args: dict[str, str] = {
"client_id": client_id, "client_id": client_id,
"redirect_uri": callback, "redirect_uri": callback,
"response_type": "code", "response_type": "code",
@ -84,7 +84,7 @@ class OAuth2Client(BaseOAuthClient):
self.request.session[self.session_key] = state self.request.session[self.session_key] = state
return args return args
def parse_raw_token(self, raw_token: str) -> Dict[str, Any]: def parse_raw_token(self, raw_token: str) -> dict[str, Any]:
"Parse token and secret from raw token response." "Parse token and secret from raw token response."
# Load as json first then parse as query string # Load as json first then parse as query string
try: try:

View file

@ -1,5 +1,5 @@
"""AzureAD OAuth2 Views""" """AzureAD OAuth2 Views"""
from typing import Any, Dict from typing import Any
from uuid import UUID from uuid import UUID
from authentik.sources.oauth.models import OAuthSource, UserOAuthSourceConnection from authentik.sources.oauth.models import OAuthSource, UserOAuthSourceConnection
@ -11,15 +11,15 @@ from authentik.sources.oauth.views.callback import OAuthCallback
class AzureADOAuthCallback(OAuthCallback): class AzureADOAuthCallback(OAuthCallback):
"""AzureAD OAuth2 Callback""" """AzureAD OAuth2 Callback"""
def get_user_id(self, source: OAuthSource, info: Dict[str, Any]) -> str: def get_user_id(self, source: OAuthSource, info: dict[str, Any]) -> str:
return str(UUID(info.get("objectId")).int) return str(UUID(info.get("objectId")).int)
def get_user_enroll_context( def get_user_enroll_context(
self, self,
source: OAuthSource, source: OAuthSource,
access: UserOAuthSourceConnection, access: UserOAuthSourceConnection,
info: Dict[str, Any], info: dict[str, Any],
) -> Dict[str, Any]: ) -> dict[str, Any]:
mail = info.get("mail", None) or info.get("otherMails", [None])[0] mail = info.get("mail", None) or info.get("otherMails", [None])[0]
return { return {
"username": info.get("displayName"), "username": info.get("displayName"),

View file

@ -1,5 +1,5 @@
"""Discord OAuth Views""" """Discord OAuth Views"""
from typing import Any, Dict from typing import Any
from authentik.sources.oauth.models import OAuthSource, UserOAuthSourceConnection from authentik.sources.oauth.models import OAuthSource, UserOAuthSourceConnection
from authentik.sources.oauth.types.manager import MANAGER, RequestKind from authentik.sources.oauth.types.manager import MANAGER, RequestKind
@ -25,8 +25,8 @@ class DiscordOAuth2Callback(OAuthCallback):
self, self,
source: OAuthSource, source: OAuthSource,
access: UserOAuthSourceConnection, access: UserOAuthSourceConnection,
info: Dict[str, Any], info: dict[str, Any],
) -> Dict[str, Any]: ) -> dict[str, Any]:
return { return {
"username": info.get("username"), "username": info.get("username"),
"email": info.get("email", None), "email": info.get("email", None),

View file

@ -1,5 +1,5 @@
"""Facebook OAuth Views""" """Facebook OAuth Views"""
from typing import Any, Dict, Optional from typing import Any, Optional
from facebook import GraphAPI from facebook import GraphAPI
@ -23,7 +23,7 @@ class FacebookOAuthRedirect(OAuthRedirect):
class FacebookOAuth2Client(OAuth2Client): class FacebookOAuth2Client(OAuth2Client):
"""Facebook OAuth2 Client""" """Facebook OAuth2 Client"""
def get_profile_info(self, token: Dict[str, str]) -> Optional[Dict[str, Any]]: def get_profile_info(self, token: dict[str, str]) -> Optional[dict[str, Any]]:
api = GraphAPI(access_token=token["access_token"]) api = GraphAPI(access_token=token["access_token"])
return api.get_object("me", fields="id,name,email") return api.get_object("me", fields="id,name,email")
@ -38,8 +38,8 @@ class FacebookOAuth2Callback(OAuthCallback):
self, self,
source: OAuthSource, source: OAuthSource,
access: UserOAuthSourceConnection, access: UserOAuthSourceConnection,
info: Dict[str, Any], info: dict[str, Any],
) -> Dict[str, Any]: ) -> dict[str, Any]:
return { return {
"username": info.get("name"), "username": info.get("name"),
"email": info.get("email"), "email": info.get("email"),

View file

@ -1,5 +1,5 @@
"""GitHub OAuth Views""" """GitHub OAuth Views"""
from typing import Any, Dict from typing import Any
from authentik.sources.oauth.models import OAuthSource, UserOAuthSourceConnection from authentik.sources.oauth.models import OAuthSource, UserOAuthSourceConnection
from authentik.sources.oauth.types.manager import MANAGER, RequestKind from authentik.sources.oauth.types.manager import MANAGER, RequestKind
@ -14,8 +14,8 @@ class GitHubOAuth2Callback(OAuthCallback):
self, self,
source: OAuthSource, source: OAuthSource,
access: UserOAuthSourceConnection, access: UserOAuthSourceConnection,
info: Dict[str, Any], info: dict[str, Any],
) -> Dict[str, Any]: ) -> dict[str, Any]:
return { return {
"username": info.get("login"), "username": info.get("login"),
"email": info.get("email"), "email": info.get("email"),

View file

@ -1,5 +1,5 @@
"""Google OAuth Views""" """Google OAuth Views"""
from typing import Any, Dict from typing import Any
from authentik.sources.oauth.models import OAuthSource, UserOAuthSourceConnection from authentik.sources.oauth.models import OAuthSource, UserOAuthSourceConnection
from authentik.sources.oauth.types.manager import MANAGER, RequestKind from authentik.sources.oauth.types.manager import MANAGER, RequestKind
@ -25,8 +25,8 @@ class GoogleOAuth2Callback(OAuthCallback):
self, self,
source: OAuthSource, source: OAuthSource,
access: UserOAuthSourceConnection, access: UserOAuthSourceConnection,
info: Dict[str, Any], info: dict[str, Any],
) -> Dict[str, Any]: ) -> dict[str, Any]:
return { return {
"username": info.get("email"), "username": info.get("email"),
"email": info.get("email"), "email": info.get("email"),

View file

@ -1,6 +1,6 @@
"""Source type manager""" """Source type manager"""
from enum import Enum from enum import Enum
from typing import Callable, Dict, List from typing import Callable
from django.utils.text import slugify from django.utils.text import slugify
from structlog.stdlib import get_logger from structlog.stdlib import get_logger
@ -22,8 +22,8 @@ class RequestKind(Enum):
class SourceTypeManager: class SourceTypeManager:
"""Manager to hold all Source types.""" """Manager to hold all Source types."""
__source_types: Dict[RequestKind, Dict[str, Callable]] = {} __source_types: dict[RequestKind, dict[str, Callable]] = {}
__names: List[str] = [] __names: list[str] = []
def source(self, kind: RequestKind, name: str): def source(self, kind: RequestKind, name: str):
"""Class decorator to register classes inline.""" """Class decorator to register classes inline."""

View file

@ -1,5 +1,5 @@
"""OpenID Connect OAuth Views""" """OpenID Connect OAuth Views"""
from typing import Any, Dict from typing import Any
from authentik.sources.oauth.models import OAuthSource, UserOAuthSourceConnection from authentik.sources.oauth.models import OAuthSource, UserOAuthSourceConnection
from authentik.sources.oauth.types.manager import MANAGER, RequestKind from authentik.sources.oauth.types.manager import MANAGER, RequestKind
@ -21,15 +21,15 @@ class OpenIDConnectOAuthRedirect(OAuthRedirect):
class OpenIDConnectOAuth2Callback(OAuthCallback): class OpenIDConnectOAuth2Callback(OAuthCallback):
"""OpenIDConnect OAuth2 Callback""" """OpenIDConnect OAuth2 Callback"""
def get_user_id(self, source: OAuthSource, info: Dict[str, str]) -> str: def get_user_id(self, source: OAuthSource, info: dict[str, str]) -> str:
return info.get("sub", "") return info.get("sub", "")
def get_user_enroll_context( def get_user_enroll_context(
self, self,
source: OAuthSource, source: OAuthSource,
access: UserOAuthSourceConnection, access: UserOAuthSourceConnection,
info: Dict[str, Any], info: dict[str, Any],
) -> Dict[str, Any]: ) -> dict[str, Any]:
return { return {
"username": info.get("nickname"), "username": info.get("nickname"),
"email": info.get("email"), "email": info.get("email"),

View file

@ -1,5 +1,5 @@
"""Reddit OAuth Views""" """Reddit OAuth Views"""
from typing import Any, Dict from typing import Any
from requests.auth import HTTPBasicAuth from requests.auth import HTTPBasicAuth
@ -40,8 +40,8 @@ class RedditOAuth2Callback(OAuthCallback):
self, self,
source: OAuthSource, source: OAuthSource,
access: UserOAuthSourceConnection, access: UserOAuthSourceConnection,
info: Dict[str, Any], info: dict[str, Any],
) -> Dict[str, Any]: ) -> dict[str, Any]:
return { return {
"username": info.get("name"), "username": info.get("name"),
"email": None, "email": None,

View file

@ -1,5 +1,5 @@
"""Twitter OAuth Views""" """Twitter OAuth Views"""
from typing import Any, Dict from typing import Any
from authentik.sources.oauth.models import OAuthSource, UserOAuthSourceConnection from authentik.sources.oauth.models import OAuthSource, UserOAuthSourceConnection
from authentik.sources.oauth.types.manager import MANAGER, RequestKind from authentik.sources.oauth.types.manager import MANAGER, RequestKind
@ -14,8 +14,8 @@ class TwitterOAuthCallback(OAuthCallback):
self, self,
source: OAuthSource, source: OAuthSource,
access: UserOAuthSourceConnection, access: UserOAuthSourceConnection,
info: Dict[str, Any], info: dict[str, Any],
) -> Dict[str, Any]: ) -> dict[str, Any]:
return { return {
"username": info.get("screen_name"), "username": info.get("screen_name"),
"email": info.get("email", None), "email": info.get("email", None),

View file

@ -1,5 +1,5 @@
"""OAuth Callback Views""" """OAuth Callback Views"""
from typing import Any, Dict, Optional from typing import Any, Optional
from django.conf import settings from django.conf import settings
from django.contrib import messages from django.contrib import messages
@ -115,14 +115,14 @@ class OAuthCallback(OAuthClientMixin, View):
self, self,
source: OAuthSource, source: OAuthSource,
access: UserOAuthSourceConnection, access: UserOAuthSourceConnection,
info: Dict[str, Any], info: dict[str, Any],
) -> Dict[str, Any]: ) -> dict[str, Any]:
"""Create a dict of User data""" """Create a dict of User data"""
raise NotImplementedError() raise NotImplementedError()
# pylint: disable=unused-argument # pylint: disable=unused-argument
def get_user_id( def get_user_id(
self, source: UserOAuthSourceConnection, info: Dict[str, Any] self, source: UserOAuthSourceConnection, info: dict[str, Any]
) -> Optional[str]: ) -> Optional[str]:
"""Return unique identifier from the profile info.""" """Return unique identifier from the profile info."""
if "id" in info: if "id" in info:
@ -167,7 +167,7 @@ class OAuthCallback(OAuthClientMixin, View):
source: OAuthSource, source: OAuthSource,
user: User, user: User,
access: UserOAuthSourceConnection, access: UserOAuthSourceConnection,
info: Dict[str, Any], info: dict[str, Any],
) -> HttpResponse: ) -> HttpResponse:
"Login user and redirect." "Login user and redirect."
messages.success( messages.success(
@ -184,7 +184,7 @@ class OAuthCallback(OAuthClientMixin, View):
self, self,
source: OAuthSource, source: OAuthSource,
access: UserOAuthSourceConnection, access: UserOAuthSourceConnection,
info: Dict[str, Any], info: dict[str, Any],
) -> HttpResponse: ) -> HttpResponse:
"""Handler when the user was already authenticated and linked an external source """Handler when the user was already authenticated and linked an external source
to their account.""" to their account."""
@ -211,7 +211,7 @@ class OAuthCallback(OAuthClientMixin, View):
self, self,
source: OAuthSource, source: OAuthSource,
access: UserOAuthSourceConnection, access: UserOAuthSourceConnection,
info: Dict[str, Any], info: dict[str, Any],
) -> HttpResponse: ) -> HttpResponse:
"""User was not authenticated and previous request was not authenticated.""" """User was not authenticated and previous request was not authenticated."""
messages.success( messages.success(

View file

@ -1,5 +1,5 @@
"""OAuth Redirect Views""" """OAuth Redirect Views"""
from typing import Any, Dict from typing import Any
from django.http import Http404 from django.http import Http404
from django.urls import reverse from django.urls import reverse
@ -19,7 +19,7 @@ class OAuthRedirect(OAuthClientMixin, RedirectView):
params = None params = None
# pylint: disable=unused-argument # pylint: disable=unused-argument
def get_additional_parameters(self, source: OAuthSource) -> Dict[str, Any]: def get_additional_parameters(self, source: OAuthSource) -> dict[str, Any]:
"Return additional redirect parameters for this source." "Return additional redirect parameters for this source."
return self.params or {} return self.params or {}

View file

@ -1,6 +1,5 @@
"""SAML AuthnRequest Processor""" """SAML AuthnRequest Processor"""
from base64 import b64encode from base64 import b64encode
from typing import Dict
from urllib.parse import quote_plus from urllib.parse import quote_plus
import xmlsec import xmlsec
@ -125,7 +124,7 @@ class RequestProcessor:
return etree.tostring(auth_n_request).decode() return etree.tostring(auth_n_request).decode()
def build_auth_n_detached(self) -> Dict[str, str]: def build_auth_n_detached(self) -> dict[str, str]:
"""Get Dict AuthN Request for Redirect bindings, with detached """Get Dict AuthN Request for Redirect bindings, with detached
Signature. See https://docs.oasis-open.org/security/saml/v2.0/saml-bindings-2.0-os.pdf""" Signature. See https://docs.oasis-open.org/security/saml/v2.0/saml-bindings-2.0-os.pdf"""
auth_n_request = self.get_auth_n() auth_n_request = self.get_auth_n()

View file

@ -1,6 +1,6 @@
"""authentik saml source processor""" """authentik saml source processor"""
from base64 import b64decode from base64 import b64decode
from typing import TYPE_CHECKING, Any, Dict from typing import TYPE_CHECKING, Any
import xmlsec import xmlsec
from defusedxml.lxml import fromstring from defusedxml.lxml import fromstring
@ -154,7 +154,7 @@ class ResponseProcessor:
raise ValueError("NameID Element not found!") raise ValueError("NameID Element not found!")
return name_id return name_id
def _get_name_id_filter(self) -> Dict[str, str]: def _get_name_id_filter(self) -> dict[str, str]:
"""Returns the subject's NameID as a Filter for the `User`""" """Returns the subject's NameID as a Filter for the `User`"""
name_id_el = self._get_name_id() name_id_el = self._get_name_id()
name_id = name_id_el.text name_id = name_id_el.text

View file

@ -1,5 +1,5 @@
"""Static OTP Setup stage""" """Static OTP Setup stage"""
from typing import Any, Dict from typing import Any
from django.http import HttpRequest, HttpResponse from django.http import HttpRequest, HttpResponse
from django.views.generic import FormView from django.views.generic import FormView
@ -21,7 +21,7 @@ class AuthenticatorStaticStageView(FormView, StageView):
form_class = SetupForm form_class = SetupForm
def get_form_kwargs(self, **kwargs) -> Dict[str, Any]: def get_form_kwargs(self, **kwargs) -> dict[str, Any]:
kwargs = super().get_form_kwargs(**kwargs) kwargs = super().get_form_kwargs(**kwargs)
tokens = self.request.session[SESSION_STATIC_TOKENS] tokens = self.request.session[SESSION_STATIC_TOKENS]
kwargs["tokens"] = tokens kwargs["tokens"] = tokens

View file

@ -1,5 +1,5 @@
"""TOTP Setup stage""" """TOTP Setup stage"""
from typing import Any, Dict from typing import Any
from django.http import HttpRequest, HttpResponse from django.http import HttpRequest, HttpResponse
from django.utils.encoding import force_str from django.utils.encoding import force_str
@ -24,7 +24,7 @@ class AuthenticatorTOTPStageView(FormView, StageView):
form_class = SetupForm form_class = SetupForm
def get_form_kwargs(self, **kwargs) -> Dict[str, Any]: def get_form_kwargs(self, **kwargs) -> dict[str, Any]:
kwargs = super().get_form_kwargs(**kwargs) kwargs = super().get_form_kwargs(**kwargs)
device: TOTPDevice = self.request.session[SESSION_TOTP_DEVICE] device: TOTPDevice = self.request.session[SESSION_TOTP_DEVICE]
kwargs["device"] = device kwargs["device"] = device

View file

@ -1,5 +1,5 @@
"""OTP Validation""" """OTP Validation"""
from typing import Any, Dict from typing import Any
from django.http import HttpRequest, HttpResponse from django.http import HttpRequest, HttpResponse
from django.views.generic import FormView from django.views.generic import FormView
@ -32,6 +32,11 @@ class AuthenticatorValidateStageView(ChallengeStageView):
form_class = ValidationForm form_class = ValidationForm
def get_form_kwargs(self, **kwargs) -> dict[str, Any]:
kwargs = super().get_form_kwargs(**kwargs)
kwargs["user"] = self.executor.plan.context.get(PLAN_CONTEXT_PENDING_USER)
return kwargs
def get(self, request: HttpRequest, *args, **kwargs) -> HttpResponse: def get(self, request: HttpRequest, *args, **kwargs) -> HttpResponse:
"""Check if a user is set, and check if the user has any devices """Check if a user is set, and check if the user has any devices
if not, we can skip this entire stage""" if not, we can skip this entire stage"""

View file

@ -1,5 +1,5 @@
"""authentik consent stage""" """authentik consent stage"""
from typing import Any, Dict, List from typing import Any
from django.http import HttpRequest, HttpResponse from django.http import HttpRequest, HttpResponse
from django.utils.timezone import now from django.utils.timezone import now
@ -19,13 +19,13 @@ class ConsentStageView(FormView, StageView):
form_class = ConsentForm form_class = ConsentForm
def get_context_data(self, **kwargs: Dict[str, Any]) -> Dict[str, Any]: def get_context_data(self, **kwargs: dict[str, Any]) -> dict[str, Any]:
kwargs = super().get_context_data(**kwargs) kwargs = super().get_context_data(**kwargs)
kwargs["current_stage"] = self.executor.current_stage kwargs["current_stage"] = self.executor.current_stage
kwargs["context"] = self.executor.plan.context kwargs["context"] = self.executor.plan.context
return kwargs return kwargs
def get_template_names(self) -> List[str]: def get_template_names(self) -> list[str]:
# PLAN_CONTEXT_CONSENT_TEMPLATE has to be set by a template that calls this stage # PLAN_CONTEXT_CONSENT_TEMPLATE has to be set by a template that calls this stage
if PLAN_CONTEXT_CONSENT_TEMPLATE in self.executor.plan.context: if PLAN_CONTEXT_CONSENT_TEMPLATE in self.executor.plan.context:
template_name = self.executor.plan.context[PLAN_CONTEXT_CONSENT_TEMPLATE] template_name = self.executor.plan.context[PLAN_CONTEXT_CONSENT_TEMPLATE]

View file

@ -1,5 +1,5 @@
"""authentik multi-stage authentication engine""" """authentik multi-stage authentication engine"""
from typing import Any, Dict from typing import Any
from django.http import HttpRequest from django.http import HttpRequest
@ -13,7 +13,7 @@ class DummyStageView(StageView):
"""Just redirect to next stage""" """Just redirect to next stage"""
return self.executor.stage_ok() return self.executor.stage_ok()
def get_context_data(self, **kwargs: Dict[str, Any]) -> Dict[str, Any]: def get_context_data(self, **kwargs: dict[str, Any]) -> dict[str, Any]:
kwargs = super().get_context_data(**kwargs) kwargs = super().get_context_data(**kwargs)
kwargs["title"] = self.executor.current_stage.name kwargs["title"] = self.executor.current_stage.name
return kwargs return kwargs

View file

@ -1,5 +1,5 @@
"""Identification stage logic""" """Identification stage logic"""
from typing import List, Optional from typing import Optional
from django.contrib import messages from django.contrib import messages
from django.db.models import Q from django.db.models import Q
@ -75,7 +75,7 @@ class IdentificationStageView(ChallengeStageView):
# Check all enabled source, add them if they have a UI Login button. # Check all enabled source, add them if they have a UI Login button.
args["sources"] = [] args["sources"] = []
sources: List[Source] = ( sources: list[Source] = (
Source.objects.filter(enabled=True).order_by("name").select_subclasses() Source.objects.filter(enabled=True).order_by("name").select_subclasses()
) )
for source in sources: for source in sources:

View file

@ -1,5 +1,5 @@
"""authentik password stage""" """authentik password stage"""
from typing import Any, Dict, List, Optional from typing import Any, Optional
from django.contrib.auth import _clean_credentials from django.contrib.auth import _clean_credentials
from django.contrib.auth.backends import BaseBackend from django.contrib.auth.backends import BaseBackend
@ -24,7 +24,7 @@ SESSION_INVALID_TRIES = "user_invalid_tries"
def authenticate( def authenticate(
request: HttpRequest, backends: List[str], **credentials: Dict[str, Any] request: HttpRequest, backends: list[str], **credentials: dict[str, Any]
) -> Optional[User]: ) -> Optional[User]:
"""If the given credentials are valid, return a User object. """If the given credentials are valid, return a User object.

View file

@ -1,7 +1,7 @@
"""Prompt forms""" """Prompt forms"""
from email.policy import Policy from email.policy import Policy
from types import MethodType from types import MethodType
from typing import Any, Callable, Iterator, List from typing import Any, Callable, Iterator
from django import forms from django import forms
from django.db.models.query import QuerySet from django.db.models.query import QuerySet
@ -52,10 +52,10 @@ class PromptAdminForm(forms.ModelForm):
class ListPolicyEngine(PolicyEngine): class ListPolicyEngine(PolicyEngine):
"""Slightly modified policy engine, which uses a list instead of a PolicyBindingModel""" """Slightly modified policy engine, which uses a list instead of a PolicyBindingModel"""
__list: List[Policy] __list: list[Policy]
def __init__( def __init__(
self, policies: List[Policy], user: User, request: HttpRequest = None self, policies: list[Policy], user: User, request: HttpRequest = None
) -> None: ) -> None:
super().__init__(PolicyBindingModel(), user, request) super().__init__(PolicyBindingModel(), user, request)
self.__list = policies self.__list = policies

View file

@ -1,5 +1,5 @@
"""authentik prompt stage signals""" """authentik prompt stage signals"""
from django.core.signals import Signal from django.core.signals import Signal
# Arguments: password: str, plan_context: Dict[str, Any] # Arguments: password: str, plan_context: dict[str, Any]
password_validate = Signal() password_validate = Signal()

View file

@ -1,5 +1,5 @@
"""authentik user_write signals""" """authentik user_write signals"""
from django.core.signals import Signal from django.core.signals import Signal
# Arguments: request: HttpRequest, user: User, data: Dict[str, Any], created: bool # Arguments: request: HttpRequest, user: User, data: dict[str, Any], created: bool
user_write = Signal() user_write = Signal()

View file

@ -1,6 +1,6 @@
"""Test Enroll flow""" """Test Enroll flow"""
from sys import platform from sys import platform
from typing import Any, Dict, Optional from typing import Any, Optional
from unittest.case import skipUnless from unittest.case import skipUnless
from django.test import override_settings from django.test import override_settings
@ -22,7 +22,7 @@ from tests.e2e.utils import USER, SeleniumTestCase, retry
class TestFlowsEnroll(SeleniumTestCase): class TestFlowsEnroll(SeleniumTestCase):
"""Test Enroll flow""" """Test Enroll flow"""
def get_container_specs(self) -> Optional[Dict[str, Any]]: def get_container_specs(self) -> Optional[dict[str, Any]]:
return { return {
"image": "mailhog/mailhog:v1.0.1", "image": "mailhog/mailhog:v1.0.1",
"detach": True, "detach": True,

View file

@ -1,7 +1,7 @@
"""test OAuth Provider flow""" """test OAuth Provider flow"""
from sys import platform from sys import platform
from time import sleep from time import sleep
from typing import Any, Dict, Optional from typing import Any, Optional
from unittest.case import skipUnless from unittest.case import skipUnless
from docker.types import Healthcheck from docker.types import Healthcheck
@ -30,7 +30,7 @@ class TestProviderOAuth2Github(SeleniumTestCase):
self.client_secret = generate_client_secret() self.client_secret = generate_client_secret()
super().setUp() super().setUp()
def get_container_specs(self) -> Optional[Dict[str, Any]]: def get_container_specs(self) -> Optional[dict[str, Any]]:
"""Setup client grafana container which we test OAuth against""" """Setup client grafana container which we test OAuth against"""
return { return {
"image": "grafana/grafana:7.1.0", "image": "grafana/grafana:7.1.0",

View file

@ -1,7 +1,7 @@
"""test OAuth2 OpenID Provider flow""" """test OAuth2 OpenID Provider flow"""
from sys import platform from sys import platform
from time import sleep from time import sleep
from typing import Any, Dict, Optional from typing import Any, Optional
from unittest.case import skipUnless from unittest.case import skipUnless
from docker.types import Healthcheck from docker.types import Healthcheck
@ -40,7 +40,7 @@ class TestProviderOAuth2OAuth(SeleniumTestCase):
self.client_secret = generate_client_secret() self.client_secret = generate_client_secret()
super().setUp() super().setUp()
def get_container_specs(self) -> Optional[Dict[str, Any]]: def get_container_specs(self) -> Optional[dict[str, Any]]:
return { return {
"image": "grafana/grafana:7.1.0", "image": "grafana/grafana:7.1.0",
"detach": True, "detach": True,

View file

@ -2,7 +2,7 @@
from dataclasses import asdict from dataclasses import asdict
from sys import platform from sys import platform
from time import sleep from time import sleep
from typing import Any, Dict, Optional from typing import Any, Optional
from unittest.case import skipUnless from unittest.case import skipUnless
from channels.testing import ChannelsLiveServerTestCase from channels.testing import ChannelsLiveServerTestCase
@ -35,7 +35,7 @@ class TestProviderProxy(SeleniumTestCase):
super().tearDown() super().tearDown()
self.proxy_container.kill() self.proxy_container.kill()
def get_container_specs(self) -> Optional[Dict[str, Any]]: def get_container_specs(self) -> Optional[dict[str, Any]]:
return { return {
"image": "traefik/whoami:latest", "image": "traefik/whoami:latest",
"detach": True, "detach": True,

View file

@ -2,7 +2,7 @@
from os.path import abspath from os.path import abspath
from sys import platform from sys import platform
from time import sleep from time import sleep
from typing import Any, Dict, Optional from typing import Any, Optional
from unittest.case import skipUnless from unittest.case import skipUnless
from django.test import override_settings from django.test import override_settings
@ -72,7 +72,7 @@ class TestSourceOAuth2(SeleniumTestCase):
with open(CONFIG_PATH, "w+") as _file: with open(CONFIG_PATH, "w+") as _file:
safe_dump(config, _file) safe_dump(config, _file)
def get_container_specs(self) -> Optional[Dict[str, Any]]: def get_container_specs(self) -> Optional[dict[str, Any]]:
return { return {
"image": "quay.io/dexidp/dex:v2.24.0", "image": "quay.io/dexidp/dex:v2.24.0",
"detach": True, "detach": True,
@ -249,7 +249,7 @@ class TestSourceOAuth1(SeleniumTestCase):
self.source_slug = "oauth1-test" self.source_slug = "oauth1-test"
super().setUp() super().setUp()
def get_container_specs(self) -> Optional[Dict[str, Any]]: def get_container_specs(self) -> Optional[dict[str, Any]]:
return { return {
"image": "beryju/oauth1-test-server", "image": "beryju/oauth1-test-server",
"detach": True, "detach": True,

View file

@ -1,7 +1,7 @@
"""test SAML Source""" """test SAML Source"""
from sys import platform from sys import platform
from time import sleep from time import sleep
from typing import Any, Dict, Optional from typing import Any, Optional
from unittest.case import skipUnless from unittest.case import skipUnless
from docker.types import Healthcheck from docker.types import Healthcheck
@ -73,7 +73,7 @@ Sm75WXsflOxuTn08LbgGc4s=
class TestSourceSAML(SeleniumTestCase): class TestSourceSAML(SeleniumTestCase):
"""test SAML Source flow""" """test SAML Source flow"""
def get_container_specs(self) -> Optional[Dict[str, Any]]: def get_container_specs(self) -> Optional[dict[str, Any]]:
return { return {
"image": "kristophjunge/test-saml-idp:1.15", "image": "kristophjunge/test-saml-idp:1.15",
"detach": True, "detach": True,

View file

@ -6,7 +6,7 @@ from importlib.util import module_from_spec, spec_from_file_location
from inspect import getmembers, isfunction from inspect import getmembers, isfunction
from os import environ, makedirs from os import environ, makedirs
from time import sleep, time from time import sleep, time
from typing import Any, Callable, Dict, Optional from typing import Any, Callable, Optional
from django.apps import apps from django.apps import apps
from django.contrib.staticfiles.testing import StaticLiveServerTestCase from django.contrib.staticfiles.testing import StaticLiveServerTestCase
@ -56,7 +56,7 @@ class SeleniumTestCase(StaticLiveServerTestCase):
if specs := self.get_container_specs(): if specs := self.get_container_specs():
self.container = self._start_container(specs) self.container = self._start_container(specs)
def _start_container(self, specs: Dict[str, Any]) -> Container: def _start_container(self, specs: dict[str, Any]) -> Container:
client: DockerClient = from_env() client: DockerClient = from_env()
client.images.pull(specs["image"]) client.images.pull(specs["image"])
container = client.containers.run(**specs) container = client.containers.run(**specs)
@ -70,7 +70,7 @@ class SeleniumTestCase(StaticLiveServerTestCase):
self.logger.info("Container failed healthcheck") self.logger.info("Container failed healthcheck")
sleep(1) sleep(1)
def get_container_specs(self) -> Optional[Dict[str, Any]]: def get_container_specs(self) -> Optional[dict[str, Any]]:
"""Optionally get container specs which will launched on setup, wait for the container to """Optionally get container specs which will launched on setup, wait for the container to
be healthy, and deleted again on tearDown""" be healthy, and deleted again on tearDown"""
return None return None