ci: update pyright (#3546)

This commit is contained in:
Jens L 2022-09-07 00:23:25 +02:00 committed by GitHub
parent 03a3f1bd6f
commit 62f93c83d4
No known key found for this signature in database
GPG key ID: 4AEE18F83AFDEB23
41 changed files with 131 additions and 95 deletions

View file

@ -27,7 +27,7 @@ runs:
docker-compose -f .github/actions/setup/docker-compose.yml up -d docker-compose -f .github/actions/setup/docker-compose.yml up -d
poetry env use python3.10 poetry env use python3.10
poetry install poetry install
npm install -g pyright@1.1.136 cd web && npm ci
- name: Generate config - name: Generate config
shell: poetry run python {0} shell: poetry run python {0}
run: | run: |

View file

@ -148,25 +148,25 @@ website-watch:
# These targets are use by GitHub actions to allow usage of matrix # These targets are use by GitHub actions to allow usage of matrix
# which makes the YAML File a lot smaller # which makes the YAML File a lot smaller
PY_SOURCES=authentik tests lifecycle
ci--meta-debug: ci--meta-debug:
python -V python -V
node --version node --version
ci-pylint: ci--meta-debug ci-pylint: ci--meta-debug
pylint authentik tests lifecycle pylint $(PY_SOURCES)
ci-black: ci--meta-debug ci-black: ci--meta-debug
black --check authentik tests lifecycle black --check $(PY_SOURCES)
ci-isort: ci--meta-debug ci-isort: ci--meta-debug
isort --check authentik tests lifecycle isort --check $(PY_SOURCES)
ci-bandit: ci--meta-debug ci-bandit: ci--meta-debug
bandit -r authentik tests lifecycle bandit -r $(PY_SOURCES)
ci-pyright: ci--meta-debug ci-pyright: ci--meta-debug
pyright e2e lifecycle ./web/node_modules/.bin/pyright $(PY_SOURCES)
ci-pending-migrations: ci--meta-debug ci-pending-migrations: ci--meta-debug
ak makemigrations --check ak makemigrations --check

View file

@ -16,7 +16,7 @@ from authentik.providers.oauth2.models import RefreshToken
LOGGER = get_logger() LOGGER = get_logger()
def validate_auth(header: bytes) -> str: def validate_auth(header: bytes) -> Optional[str]:
"""Validate that the header is in a correct format, """Validate that the header is in a correct format,
returns type and credentials""" returns type and credentials"""
auth_credentials = header.decode().strip() auth_credentials = header.decode().strip()

View file

@ -4,7 +4,7 @@ from glob import glob
from pathlib import Path from pathlib import Path
import django.contrib.postgres.fields import django.contrib.postgres.fields
from dacite import from_dict from dacite.core import from_dict
from django.apps.registry import Apps from django.apps.registry import Apps
from django.conf import settings from django.conf import settings
from django.db import migrations, models from django.db import migrations, models

View file

@ -105,9 +105,9 @@ class Blueprint:
version: int = field(default=1) version: int = field(default=1)
entries: list[BlueprintEntry] = field(default_factory=list) entries: list[BlueprintEntry] = field(default_factory=list)
context: dict = field(default_factory=dict)
metadata: Optional[BlueprintMetadata] = field(default=None) metadata: Optional[BlueprintMetadata] = field(default=None)
context: Optional[dict] = field(default_factory=dict)
class YAMLTag: class YAMLTag:

View file

@ -1,5 +1,5 @@
"""Blueprint exporter""" """Blueprint exporter"""
from typing import Iterator from typing import Iterable
from uuid import UUID from uuid import UUID
from django.apps import apps from django.apps import apps
@ -34,7 +34,7 @@ class Exporter:
Event, Event,
] ]
def get_entries(self) -> Iterator[BlueprintEntry]: def get_entries(self) -> Iterable[BlueprintEntry]:
"""Get blueprint entries""" """Get blueprint entries"""
for model in apps.get_models(): for model in apps.get_models():
if not is_model_allowed(model): if not is_model_allowed(model):
@ -96,7 +96,7 @@ class FlowExporter(Exporter):
"pbm_uuid", flat=True "pbm_uuid", flat=True
) )
def walk_stages(self) -> Iterator[BlueprintEntry]: def walk_stages(self) -> Iterable[BlueprintEntry]:
"""Convert all stages attached to self.flow into BlueprintEntry objects""" """Convert all stages attached to self.flow into BlueprintEntry objects"""
stages = Stage.objects.filter(flow=self.flow).select_related().select_subclasses() stages = Stage.objects.filter(flow=self.flow).select_related().select_subclasses()
for stage in stages: for stage in stages:
@ -104,13 +104,13 @@ class FlowExporter(Exporter):
pass pass
yield BlueprintEntry.from_model(stage, "name") yield BlueprintEntry.from_model(stage, "name")
def walk_stage_bindings(self) -> Iterator[BlueprintEntry]: def walk_stage_bindings(self) -> Iterable[BlueprintEntry]:
"""Convert all bindings attached to self.flow into BlueprintEntry objects""" """Convert all bindings attached to self.flow into BlueprintEntry objects"""
bindings = FlowStageBinding.objects.filter(target=self.flow).select_related() bindings = FlowStageBinding.objects.filter(target=self.flow).select_related()
for binding in bindings: for binding in bindings:
yield BlueprintEntry.from_model(binding, "target", "stage", "order") yield BlueprintEntry.from_model(binding, "target", "stage", "order")
def walk_policies(self) -> Iterator[BlueprintEntry]: def walk_policies(self) -> Iterable[BlueprintEntry]:
"""Walk over all policies. This is done at the beginning of the export for stages that have """Walk over all policies. This is done at the beginning of the export for stages that have
a direct foreign key to a policy.""" a direct foreign key to a policy."""
# Special case for PromptStage as that has a direct M2M to policy, we have to ensure # Special case for PromptStage as that has a direct M2M to policy, we have to ensure
@ -121,21 +121,21 @@ class FlowExporter(Exporter):
for policy in policies: for policy in policies:
yield BlueprintEntry.from_model(policy) yield BlueprintEntry.from_model(policy)
def walk_policy_bindings(self) -> Iterator[BlueprintEntry]: def walk_policy_bindings(self) -> Iterable[BlueprintEntry]:
"""Walk over all policybindings relative to us. This is run at the end of the export, as """Walk over all policybindings relative to us. This is run at the end of the export, as
we are sure all objects exist now.""" we are sure all objects exist now."""
bindings = PolicyBinding.objects.filter(target__in=self.pbm_uuids).select_related() bindings = PolicyBinding.objects.filter(target__in=self.pbm_uuids).select_related()
for binding in bindings: for binding in bindings:
yield BlueprintEntry.from_model(binding, "policy", "target", "order") yield BlueprintEntry.from_model(binding, "policy", "target", "order")
def walk_stage_prompts(self) -> Iterator[BlueprintEntry]: def walk_stage_prompts(self) -> Iterable[BlueprintEntry]:
"""Walk over all prompts associated with any PromptStages""" """Walk over all prompts associated with any PromptStages"""
prompt_stages = PromptStage.objects.filter(flow=self.flow) prompt_stages = PromptStage.objects.filter(flow=self.flow)
for stage in prompt_stages: for stage in prompt_stages:
for prompt in stage.fields.all(): for prompt in stage.fields.all():
yield BlueprintEntry.from_model(prompt) yield BlueprintEntry.from_model(prompt)
def get_entries(self) -> Iterator[BlueprintEntry]: def get_entries(self) -> Iterable[BlueprintEntry]:
entries = [] entries = []
entries.append(BlueprintEntry.from_model(self.flow, "slug")) entries.append(BlueprintEntry.from_model(self.flow, "slug"))
if self.with_stage_prompts: if self.with_stage_prompts:

View file

@ -3,7 +3,7 @@ from contextlib import contextmanager
from copy import deepcopy from copy import deepcopy
from typing import Any, Optional from typing import Any, Optional
from dacite import from_dict from dacite.core import from_dict
from dacite.exceptions import DaciteError from dacite.exceptions import DaciteError
from deepmerge import always_merger from deepmerge import always_merger
from django.db import transaction from django.db import transaction
@ -143,7 +143,8 @@ class Importer:
if not is_model_allowed(model): if not is_model_allowed(model):
raise EntryInvalidError(f"Model {model} not allowed") raise EntryInvalidError(f"Model {model} not allowed")
if issubclass(model, BaseMetaModel): if issubclass(model, BaseMetaModel):
serializer = model.serializer()(data=entry.get_attrs(self.__import)) serializer_class: type[Serializer] = model.serializer()
serializer = serializer_class(data=entry.get_attrs(self.__import))
try: try:
serializer.is_valid(raise_exception=True) serializer.is_valid(raise_exception=True)
except ValidationError as exc: except ValidationError as exc:

View file

@ -1,6 +1,4 @@
"""Base models""" """Base models"""
from typing import Optional
from django.apps import apps from django.apps import apps
from django.db.models import Model from django.db.models import Model
from rest_framework.serializers import Serializer from rest_framework.serializers import Serializer
@ -51,7 +49,7 @@ class MetaModelRegistry:
models.append(value) models.append(value)
return models return models
def get_model(self, app_label: str, model_id: str) -> Optional[type[Model]]: def get_model(self, app_label: str, model_id: str) -> type[Model]:
"""Get model checks if any virtual models are registered, and falls back """Get model checks if any virtual models are registered, and falls back
to actual django models""" to actual django models"""
if app_label.lower() == self.virtual_prefix: if app_label.lower() == self.virtual_prefix:

View file

@ -4,7 +4,7 @@ from hashlib import sha512
from pathlib import Path from pathlib import Path
from typing import Optional from typing import Optional
from dacite import from_dict from dacite.core import from_dict
from django.db import DatabaseError, InternalError, ProgrammingError from django.db import DatabaseError, InternalError, ProgrammingError
from django.utils.text import slugify from django.utils.text import slugify
from django.utils.timezone import now from django.utils.timezone import now
@ -77,7 +77,9 @@ def blueprints_find():
LOGGER.warning("invalid blueprint version", version=version, path=str(path)) LOGGER.warning("invalid blueprint version", version=version, path=str(path))
continue continue
file_hash = sha512(path.read_bytes()).hexdigest() file_hash = sha512(path.read_bytes()).hexdigest()
blueprint = BlueprintFile(path.relative_to(root), version, file_hash, path.stat().st_mtime) blueprint = BlueprintFile(
str(path.relative_to(root)), version, file_hash, int(path.stat().st_mtime)
)
blueprint.meta = from_dict(BlueprintMetadata, metadata) if metadata else None blueprint.meta = from_dict(BlueprintMetadata, metadata) if metadata else None
blueprints.append(blueprint) blueprints.append(blueprint)
LOGGER.info( LOGGER.info(
@ -136,6 +138,7 @@ def check_blueprint_v1_file(blueprint: BlueprintFile):
def apply_blueprint(self: MonitoredTask, instance_pk: str): def apply_blueprint(self: MonitoredTask, instance_pk: str):
"""Apply single blueprint""" """Apply single blueprint"""
self.save_on_success = False self.save_on_success = False
instance: Optional[BlueprintInstance] = None
try: try:
instance: BlueprintInstance = BlueprintInstance.objects.filter(pk=instance_pk).first() instance: BlueprintInstance = BlueprintInstance.objects.filter(pk=instance_pk).first()
self.set_uid(slugify(instance.name)) self.set_uid(slugify(instance.name))
@ -170,7 +173,9 @@ def apply_blueprint(self: MonitoredTask, instance_pk: str):
BlueprintRetrievalFailed, BlueprintRetrievalFailed,
EntryInvalidError, EntryInvalidError,
) as exc: ) as exc:
if instance:
instance.status = BlueprintInstanceStatus.ERROR instance.status = BlueprintInstanceStatus.ERROR
self.set_status(TaskResult(TaskResultStatus.ERROR).with_error(exc)) self.set_status(TaskResult(TaskResultStatus.ERROR).with_error(exc))
finally: finally:
if instance:
instance.save() instance.save()

View file

@ -9,7 +9,7 @@ from django.db.models.signals import post_save, pre_delete
from authentik import __version__ from authentik import __version__
from authentik.core.models import User from authentik.core.models import User
from authentik.events.middleware import IGNORED_MODELS from authentik.events.middleware import should_log_model
from authentik.events.models import Event, EventAction from authentik.events.models import Event, EventAction
from authentik.events.utils import model_to_dict from authentik.events.utils import model_to_dict
@ -50,7 +50,7 @@ class Command(BaseCommand):
# pylint: disable=unused-argument # pylint: disable=unused-argument
def post_save_handler(sender, instance: Model, created: bool, **_): def post_save_handler(sender, instance: Model, created: bool, **_):
"""Signal handler for all object's post_save""" """Signal handler for all object's post_save"""
if isinstance(instance, IGNORED_MODELS): if not should_log_model(instance):
return return
action = EventAction.MODEL_CREATED if created else EventAction.MODEL_UPDATED action = EventAction.MODEL_CREATED if created else EventAction.MODEL_UPDATED
@ -66,7 +66,7 @@ class Command(BaseCommand):
# pylint: disable=unused-argument # pylint: disable=unused-argument
def pre_delete_handler(sender, instance: Model, **_): def pre_delete_handler(sender, instance: Model, **_):
"""Signal handler for all object's pre_delete""" """Signal handler for all object's pre_delete"""
if isinstance(instance, IGNORED_MODELS): # pragma: no cover if not should_log_model(instance): # pragma: no cover
return return
Event.new(EventAction.MODEL_DELETED, model=model_to_dict(instance)).set_user( Event.new(EventAction.MODEL_DELETED, model=model_to_dict(instance)).set_user(

View file

@ -1,6 +1,6 @@
"""authentik admin Middleware to impersonate users""" """authentik admin Middleware to impersonate users"""
from contextvars import ContextVar from contextvars import ContextVar
from typing import Callable from typing import Callable, Optional
from uuid import uuid4 from uuid import uuid4
from django.http import HttpRequest, HttpResponse from django.http import HttpRequest, HttpResponse
@ -13,9 +13,9 @@ RESPONSE_HEADER_ID = "X-authentik-id"
KEY_AUTH_VIA = "auth_via" KEY_AUTH_VIA = "auth_via"
KEY_USER = "user" KEY_USER = "user"
CTX_REQUEST_ID = ContextVar(STRUCTLOG_KEY_PREFIX + "request_id", default=None) CTX_REQUEST_ID = ContextVar[Optional[str]](STRUCTLOG_KEY_PREFIX + "request_id", default=None)
CTX_HOST = ContextVar(STRUCTLOG_KEY_PREFIX + "host", default=None) CTX_HOST = ContextVar[Optional[str]](STRUCTLOG_KEY_PREFIX + "host", default=None)
CTX_AUTH_VIA = ContextVar(STRUCTLOG_KEY_PREFIX + KEY_AUTH_VIA, default=None) CTX_AUTH_VIA = ContextVar[Optional[str]](STRUCTLOG_KEY_PREFIX + KEY_AUTH_VIA, default=None)
class ImpersonateMiddleware: class ImpersonateMiddleware:

View file

@ -52,5 +52,5 @@ def create_test_cert() -> CertificateKeyPair:
subject_alt_names=["goauthentik.io"], subject_alt_names=["goauthentik.io"],
validity_days=360, validity_days=360,
) )
builder.name = generate_id() builder.common_name = generate_id()
return builder.save() return builder.save()

View file

@ -26,7 +26,7 @@ class CertificateBuilder:
self.common_name = "authentik Self-signed Certificate" self.common_name = "authentik Self-signed Certificate"
self.cert = CertificateKeyPair() self.cert = CertificateKeyPair()
def save(self) -> Optional[CertificateKeyPair]: def save(self) -> CertificateKeyPair:
"""Save generated certificate as model""" """Save generated certificate as model"""
if not self.__certificate: if not self.__certificate:
raise ValueError("Certificated hasn't been built yet") raise ValueError("Certificated hasn't been built yet")

View file

@ -6,12 +6,7 @@ from uuid import uuid4
from cryptography.hazmat.backends import default_backend from cryptography.hazmat.backends import default_backend
from cryptography.hazmat.primitives import hashes from cryptography.hazmat.primitives import hashes
from cryptography.hazmat.primitives.asymmetric.ec import ( from cryptography.hazmat.primitives.asymmetric.types import PRIVATE_KEY_TYPES, PUBLIC_KEY_TYPES
EllipticCurvePrivateKey,
EllipticCurvePublicKey,
)
from cryptography.hazmat.primitives.asymmetric.ed25519 import Ed25519PrivateKey, Ed25519PublicKey
from cryptography.hazmat.primitives.asymmetric.rsa import RSAPrivateKey, RSAPublicKey
from cryptography.hazmat.primitives.serialization import load_pem_private_key from cryptography.hazmat.primitives.serialization import load_pem_private_key
from cryptography.x509 import Certificate, load_pem_x509_certificate from cryptography.x509 import Certificate, load_pem_x509_certificate
from django.db import models from django.db import models
@ -42,8 +37,8 @@ class CertificateKeyPair(SerializerModel, ManagedModel, CreatedUpdatedModel):
) )
_cert: Optional[Certificate] = None _cert: Optional[Certificate] = None
_private_key: Optional[RSAPrivateKey | EllipticCurvePrivateKey | Ed25519PrivateKey] = None _private_key: Optional[PRIVATE_KEY_TYPES] = None
_public_key: Optional[RSAPublicKey | EllipticCurvePublicKey | Ed25519PublicKey] = None _public_key: Optional[PUBLIC_KEY_TYPES] = None
@property @property
def serializer(self) -> Serializer: def serializer(self) -> Serializer:
@ -61,7 +56,7 @@ class CertificateKeyPair(SerializerModel, ManagedModel, CreatedUpdatedModel):
return self._cert return self._cert
@property @property
def public_key(self) -> Optional[RSAPublicKey | EllipticCurvePublicKey | Ed25519PublicKey]: def public_key(self) -> Optional[PUBLIC_KEY_TYPES]:
"""Get public key of the private key""" """Get public key of the private key"""
if not self._public_key: if not self._public_key:
self._public_key = self.private_key.public_key() self._public_key = self.private_key.public_key()
@ -70,7 +65,7 @@ class CertificateKeyPair(SerializerModel, ManagedModel, CreatedUpdatedModel):
@property @property
def private_key( def private_key(
self, self,
) -> Optional[RSAPrivateKey | EllipticCurvePrivateKey | Ed25519PrivateKey]: ) -> Optional[PRIVATE_KEY_TYPES]:
"""Get python cryptography PrivateKey instance""" """Get python cryptography PrivateKey instance"""
if not self._private_key and self.key_data != "": if not self._private_key and self.key_data != "":
try: try:

View file

@ -19,7 +19,7 @@ from authentik.flows.models import FlowToken
from authentik.lib.sentry import before_send from authentik.lib.sentry import before_send
from authentik.lib.utils.errors import exception_to_string from authentik.lib.utils.errors import exception_to_string
IGNORED_MODELS = [ IGNORED_MODELS = (
Event, Event,
Notification, Notification,
UserObjectPermission, UserObjectPermission,
@ -27,12 +27,14 @@ IGNORED_MODELS = [
StaticToken, StaticToken,
Session, Session,
FlowToken, FlowToken,
] )
if settings.DEBUG:
from silk.models import Request, Response, SQLQuery
IGNORED_MODELS += [Request, Response, SQLQuery]
IGNORED_MODELS = tuple(IGNORED_MODELS) def should_log_model(model: Model) -> bool:
"""Return true if operation on `model` should be logged"""
if model.__module__.startswith("silk"):
return False
return not isinstance(model, IGNORED_MODELS)
class AuditMiddleware: class AuditMiddleware:
@ -109,7 +111,7 @@ class AuditMiddleware:
user: User, request: HttpRequest, sender, instance: Model, created: bool, **_ user: User, request: HttpRequest, sender, instance: Model, created: bool, **_
): ):
"""Signal handler for all object's post_save""" """Signal handler for all object's post_save"""
if isinstance(instance, IGNORED_MODELS): if not should_log_model(instance):
return return
action = EventAction.MODEL_CREATED if created else EventAction.MODEL_UPDATED action = EventAction.MODEL_CREATED if created else EventAction.MODEL_UPDATED
@ -119,7 +121,7 @@ class AuditMiddleware:
# pylint: disable=unused-argument # pylint: disable=unused-argument
def pre_delete_handler(user: User, request: HttpRequest, sender, instance: Model, **_): def pre_delete_handler(user: User, request: HttpRequest, sender, instance: Model, **_):
"""Signal handler for all object's pre_delete""" """Signal handler for all object's pre_delete"""
if isinstance(instance, IGNORED_MODELS): # pragma: no cover if not should_log_model(instance): # pragma: no cover
return return
EventNewThread( EventNewThread(

View file

@ -152,6 +152,7 @@ class FlowExecutorView(APIView):
token: Optional[FlowToken] = FlowToken.filter_not_expired(key=key).first() token: Optional[FlowToken] = FlowToken.filter_not_expired(key=key).first()
if not token: if not token:
return None return None
plan = None
try: try:
plan = token.plan plan = token.plan
except (AttributeError, EOFError, ImportError, IndexError) as exc: except (AttributeError, EOFError, ImportError, IndexError) as exc:

View file

@ -20,7 +20,7 @@ ENV_PREFIX = "AUTHENTIK"
ENVIRONMENT = os.getenv(f"{ENV_PREFIX}_ENV", "local") ENVIRONMENT = os.getenv(f"{ENV_PREFIX}_ENV", "local")
def get_path_from_dict(root: dict, path: str, sep=".", default=None): def get_path_from_dict(root: dict, path: str, sep=".", default=None) -> Any:
"""Recursively walk through `root`, checking each part of `path` split by `sep`. """Recursively walk through `root`, checking each part of `path` split by `sep`.
If at any point a dict does not exist, return default""" If at any point a dict does not exist, return default"""
for comp in path.split(sep): for comp in path.split(sep):
@ -180,7 +180,7 @@ class ConfigLoader:
# pyright: reportGeneralTypeIssues=false # pyright: reportGeneralTypeIssues=false
if comp not in root: if comp not in root:
root[comp] = {} root[comp] = {}
root = root.get(comp) root = root.get(comp, {})
root[path_parts[-1]] = value root[path_parts[-1]] = value
def y_bool(self, path: str, default=False) -> bool: def y_bool(self, path: str, default=False) -> bool:

View file

@ -12,5 +12,4 @@ class TestReflectionUtils(TestCase):
def test_path_to_class(self): def test_path_to_class(self):
"""Test path_to_class""" """Test path_to_class"""
self.assertIsNone(path_to_class(None))
self.assertEqual(path_to_class("datetime.datetime"), datetime) self.assertEqual(path_to_class("datetime.datetime"), datetime)

View file

@ -29,10 +29,8 @@ def class_to_path(cls: type) -> str:
return f"{cls.__module__}.{cls.__name__}" return f"{cls.__module__}.{cls.__name__}"
def path_to_class(path: str | None) -> type | None: def path_to_class(path: str = "") -> type:
"""Import module and return class""" """Import module and return class"""
if not path:
return None
parts = path.split(".") parts = path.split(".")
package = ".".join(parts[:-1]) package = ".".join(parts[:-1])
_class = getattr(import_module(package), parts[-1]) _class = getattr(import_module(package), parts[-1])

View file

@ -5,7 +5,7 @@ from enum import IntEnum
from typing import Any, Optional from typing import Any, Optional
from channels.exceptions import DenyConnection from channels.exceptions import DenyConnection
from dacite import from_dict from dacite.core import from_dict
from dacite.data import Data from dacite.data import Data
from guardian.shortcuts import get_objects_for_user from guardian.shortcuts import get_objects_for_user
from structlog.stdlib import BoundLogger, get_logger from structlog.stdlib import BoundLogger, get_logger

View file

@ -2,7 +2,7 @@
from dataclasses import asdict, dataclass, field from dataclasses import asdict, dataclass, field
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from dacite import from_dict from dacite.core import from_dict
from kubernetes.client import ApiextensionsV1Api, CustomObjectsApi from kubernetes.client import ApiextensionsV1Api, CustomObjectsApi
from authentik.outposts.controllers.base import FIELD_MANAGER from authentik.outposts.controllers.base import FIELD_MANAGER

View file

@ -4,7 +4,7 @@ from datetime import datetime
from typing import Iterable, Optional from typing import Iterable, Optional
from uuid import uuid4 from uuid import uuid4
from dacite import from_dict from dacite.core import from_dict
from django.contrib.auth.models import Permission from django.contrib.auth.models import Permission
from django.core.cache import cache from django.core.cache import cache
from django.db import IntegrityError, models, transaction from django.db import IntegrityError, models, transaction
@ -74,7 +74,7 @@ class OutpostConfig:
kubernetes_ingress_secret_name: str = field(default="authentik-outpost-tls") kubernetes_ingress_secret_name: str = field(default="authentik-outpost-tls")
kubernetes_service_type: str = field(default="ClusterIP") kubernetes_service_type: str = field(default="ClusterIP")
kubernetes_disabled_components: list[str] = field(default_factory=list) kubernetes_disabled_components: list[str] = field(default_factory=list)
kubernetes_image_pull_secrets: Optional[list[str]] = field(default_factory=list) kubernetes_image_pull_secrets: list[str] = field(default_factory=list)
class OutpostModel(Model): class OutpostModel(Model):

View file

@ -74,10 +74,14 @@ def outpost_service_connection_state(connection_pk: Any):
) )
if not connection: if not connection:
return return
cls = None
if isinstance(connection, DockerServiceConnection): if isinstance(connection, DockerServiceConnection):
cls = DockerClient cls = DockerClient
if isinstance(connection, KubernetesServiceConnection): if isinstance(connection, KubernetesServiceConnection):
cls = KubernetesClient cls = KubernetesClient
if not cls:
LOGGER.warning("No class found for service connection", connection=connection)
return
try: try:
with cls(connection) as client: with cls(connection) as client:
state = client.fetch_state() state = client.fetch_state()

View file

@ -11,7 +11,7 @@ from urllib.parse import urlparse, urlunparse
from cryptography.hazmat.primitives.asymmetric.ec import EllipticCurvePrivateKey from cryptography.hazmat.primitives.asymmetric.ec import EllipticCurvePrivateKey
from cryptography.hazmat.primitives.asymmetric.rsa import RSAPrivateKey from cryptography.hazmat.primitives.asymmetric.rsa import RSAPrivateKey
from dacite import from_dict from dacite.core import from_dict
from django.db import models from django.db import models
from django.http import HttpRequest from django.http import HttpRequest
from django.utils import dateformat, timezone from django.utils import dateformat, timezone

View file

@ -2,7 +2,7 @@
from dataclasses import asdict, dataclass, field from dataclasses import asdict, dataclass, field
from typing import TYPE_CHECKING from typing import TYPE_CHECKING
from dacite import from_dict from dacite.core import from_dict
from kubernetes.client import ApiextensionsV1Api, CustomObjectsApi from kubernetes.client import ApiextensionsV1Api, CustomObjectsApi
from authentik.outposts.controllers.base import FIELD_MANAGER from authentik.outposts.controllers.base import FIELD_MANAGER

View file

@ -39,8 +39,8 @@ class BaseOAuthClient:
profile_url = self.source.type.profile_url or "" profile_url = self.source.type.profile_url or ""
if self.source.type.urls_customizable and self.source.profile_url: if self.source.type.urls_customizable and self.source.profile_url:
profile_url = self.source.profile_url profile_url = self.source.profile_url
try:
response = self.do_request("get", profile_url, token=token) response = self.do_request("get", profile_url, token=token)
try:
response.raise_for_status() response.raise_for_status()
except RequestException as exc: except RequestException as exc:
self.logger.warning("Unable to fetch user profile", exc=exc, body=response.text) self.logger.warning("Unable to fetch user profile", exc=exc, body=response.text)

View file

@ -138,12 +138,12 @@ class UserprofileHeaderAuthClient(OAuth2Client):
profile_url = self.source.type.profile_url or "" profile_url = self.source.type.profile_url or ""
if self.source.type.urls_customizable and self.source.profile_url: if self.source.type.urls_customizable and self.source.profile_url:
profile_url = self.source.profile_url profile_url = self.source.profile_url
try:
response = self.session.request( response = self.session.request(
"get", "get",
profile_url, profile_url,
headers={"Authorization": f"{token['token_type']} {token['access_token']}"}, headers={"Authorization": f"{token['token_type']} {token['access_token']}"},
) )
try:
response.raise_for_status() response.raise_for_status()
except RequestException as exc: except RequestException as exc:
LOGGER.warning("Unable to fetch user profile", exc=exc, body=response.text) LOGGER.warning("Unable to fetch user profile", exc=exc, body=response.text)

View file

@ -1,5 +1,5 @@
"""GitHub OAuth Views""" """GitHub OAuth Views"""
from typing import Any, Optional from typing import Any
from requests.exceptions import RequestException from requests.exceptions import RequestException
@ -21,14 +21,14 @@ class GitHubOAuthRedirect(OAuthRedirect):
class GitHubOAuth2Client(OAuth2Client): class GitHubOAuth2Client(OAuth2Client):
"""GitHub OAuth2 Client""" """GitHub OAuth2 Client"""
def get_github_emails(self, token: dict[str, str]) -> Optional[dict[str, Any]]: def get_github_emails(self, token: dict[str, str]) -> list[dict[str, Any]]:
"""Get Emails from the GitHub API""" """Get Emails from the GitHub API"""
profile_url = self.source.type.profile_url or "" profile_url = self.source.type.profile_url or ""
if self.source.type.urls_customizable and self.source.profile_url: if self.source.type.urls_customizable and self.source.profile_url:
profile_url = self.source.profile_url profile_url = self.source.profile_url
profile_url += "/emails" profile_url += "/emails"
try:
response = self.do_request("get", profile_url, token=token) response = self.do_request("get", profile_url, token=token)
try:
response.raise_for_status() response.raise_for_status()
except RequestException as exc: except RequestException as exc:
self.logger.warning("Unable to fetch github emails", exc=exc) self.logger.warning("Unable to fetch github emails", exc=exc)

View file

@ -29,11 +29,11 @@ class MailcowOAuth2Client(OAuth2Client):
profile_url = self.source.type.profile_url or "" profile_url = self.source.type.profile_url or ""
if self.source.type.urls_customizable and self.source.profile_url: if self.source.type.urls_customizable and self.source.profile_url:
profile_url = self.source.profile_url profile_url = self.source.profile_url
try:
response = self.session.request( response = self.session.request(
"get", "get",
f"{profile_url}?access_token={token['access_token']}", f"{profile_url}?access_token={token['access_token']}",
) )
try:
response.raise_for_status() response.raise_for_status()
except RequestException as exc: except RequestException as exc:
LOGGER.warning("Unable to fetch user profile", exc=exc, body=response.text) LOGGER.warning("Unable to fetch user profile", exc=exc, body=response.text)

View file

@ -13,9 +13,11 @@ from django_otp.models import Device
from rest_framework.fields import CharField, JSONField from rest_framework.fields import CharField, JSONField
from rest_framework.serializers import ValidationError from rest_framework.serializers import ValidationError
from structlog.stdlib import get_logger from structlog.stdlib import get_logger
from webauthn import generate_authentication_options, verify_authentication_response from webauthn.authentication.generate_authentication_options import generate_authentication_options
from webauthn.helpers import base64url_to_bytes, options_to_json from webauthn.authentication.verify_authentication_response import verify_authentication_response
from webauthn.helpers.base64url_to_bytes import base64url_to_bytes
from webauthn.helpers.exceptions import InvalidAuthenticationResponse from webauthn.helpers.exceptions import InvalidAuthenticationResponse
from webauthn.helpers.options_to_json import options_to_json
from webauthn.helpers.structs import AuthenticationCredential from webauthn.helpers.structs import AuthenticationCredential
from authentik.core.api.utils import PassiveSerializer from authentik.core.api.utils import PassiveSerializer

View file

@ -4,7 +4,8 @@ from time import sleep
from django.test.client import RequestFactory from django.test.client import RequestFactory
from django.urls.base import reverse from django.urls.base import reverse
from rest_framework.serializers import ValidationError from rest_framework.serializers import ValidationError
from webauthn.helpers import base64url_to_bytes, bytes_to_base64url from webauthn.helpers.base64url_to_bytes import base64url_to_bytes
from webauthn.helpers.bytes_to_base64url import bytes_to_base64url
from authentik.core.tests.utils import create_test_admin_user, create_test_flow from authentik.core.tests.utils import create_test_admin_user, create_test_flow
from authentik.flows.models import Flow, FlowStageBinding, NotConfiguredAction from authentik.flows.models import Flow, FlowStageBinding, NotConfiguredAction

View file

@ -5,15 +5,19 @@ from django.http import HttpRequest, HttpResponse
from django.http.request import QueryDict from django.http.request import QueryDict
from rest_framework.fields import CharField, JSONField from rest_framework.fields import CharField, JSONField
from rest_framework.serializers import ValidationError from rest_framework.serializers import ValidationError
from webauthn import generate_registration_options, options_to_json, verify_registration_response from webauthn.helpers.bytes_to_base64url import bytes_to_base64url
from webauthn.helpers import bytes_to_base64url
from webauthn.helpers.exceptions import InvalidRegistrationResponse from webauthn.helpers.exceptions import InvalidRegistrationResponse
from webauthn.helpers.options_to_json import options_to_json
from webauthn.helpers.structs import ( from webauthn.helpers.structs import (
AuthenticatorSelectionCriteria, AuthenticatorSelectionCriteria,
PublicKeyCredentialCreationOptions, PublicKeyCredentialCreationOptions,
RegistrationCredential, RegistrationCredential,
) )
from webauthn.registration.verify_registration_response import VerifiedRegistration from webauthn.registration.generate_registration_options import generate_registration_options
from webauthn.registration.verify_registration_response import (
VerifiedRegistration,
verify_registration_response,
)
from authentik.core.models import User from authentik.core.models import User
from authentik.flows.challenge import ( from authentik.flows.challenge import (

View file

@ -2,7 +2,7 @@
from base64 import b64decode from base64 import b64decode
from django.urls import reverse from django.urls import reverse
from webauthn.helpers import bytes_to_base64url from webauthn.helpers.bytes_to_base64url import bytes_to_base64url
from authentik.core.tests.utils import create_test_admin_user, create_test_flow from authentik.core.tests.utils import create_test_admin_user, create_test_flow
from authentik.flows.markers import StageMarker from authentik.flows.markers import StageMarker

View file

@ -62,6 +62,8 @@ if __name__ == "__main__":
try: try:
for migration in Path(__file__).parent.absolute().glob("system_migrations/*.py"): for migration in Path(__file__).parent.absolute().glob("system_migrations/*.py"):
spec = spec_from_file_location("lifecycle.system_migrations", migration) spec = spec_from_file_location("lifecycle.system_migrations", migration)
if not spec:
continue
mod = module_from_spec(spec) mod = module_from_spec(spec)
# pyright: reportGeneralTypeIssues=false # pyright: reportGeneralTypeIssues=false
spec.loader.exec_module(mod) spec.loader.exec_module(mod)

View file

@ -3,14 +3,17 @@ ignore = [
"**/migrations/**", "**/migrations/**",
"**/node_modules/**" "**/node_modules/**"
] ]
reportMissingTypeStubs = false reportMissingTypeStubs = false
strictParameterNoneValue = true strictParameterNoneValue = true
strictDictionaryInference = true strictDictionaryInference = true
strictListInference = true strictListInference = true
reportOptionalMemberAccess = false
# Sadly pyright still has issues with enums, and they fall under general type issues
# so we have to disable those for now
reportGeneralTypeIssues = false
verboseOutput = false verboseOutput = false
pythonVersion = "3.9" pythonVersion = "3.10"
pythonPlatform = "Linux" pythonPlatform = "All"
[tool.black] [tool.black]
line-length = 100 line-length = 100

View file

@ -198,7 +198,7 @@ class TestProviderLDAP(SeleniumTestCase):
search_scope=SUBTREE, search_scope=SUBTREE,
attributes=[ALL_ATTRIBUTES, ALL_OPERATIONAL_ATTRIBUTES], attributes=[ALL_ATTRIBUTES, ALL_OPERATIONAL_ATTRIBUTES],
) )
response = _connection.response response: dict = _connection.response
# Remove raw_attributes to make checking easier # Remove raw_attributes to make checking easier
for obj in response: for obj in response:
del obj["raw_attributes"] del obj["raw_attributes"]

View file

@ -26,7 +26,7 @@ from tests.e2e.utils import SeleniumTestCase, retry
CONFIG_PATH = "/tmp/dex.yml" # nosec CONFIG_PATH = "/tmp/dex.yml" # nosec
class OAUth1Callback(OAuthCallback): class OAuth1Callback(OAuthCallback):
"""OAuth1 Callback with custom getters""" """OAuth1 Callback with custom getters"""
def get_user_id(self, info: dict[str, str]) -> str: def get_user_id(self, info: dict[str, str]) -> str:
@ -47,7 +47,7 @@ class OAUth1Callback(OAuthCallback):
class OAUth1Type(SourceType): class OAUth1Type(SourceType):
"""OAuth1 Type definition""" """OAuth1 Type definition"""
callback_view = OAUth1Callback callback_view = OAuth1Callback
name = "OAuth1" name = "OAuth1"
slug = "oauth1" slug = "oauth1"

View file

@ -20,7 +20,7 @@ from selenium.webdriver.common.by import By
from selenium.webdriver.common.keys import Keys from selenium.webdriver.common.keys import Keys
from selenium.webdriver.remote.webdriver import WebDriver from selenium.webdriver.remote.webdriver import WebDriver
from selenium.webdriver.remote.webelement import WebElement from selenium.webdriver.remote.webelement import WebElement
from selenium.webdriver.support.ui import WebDriverWait from selenium.webdriver.support.wait import WebDriverWait
from structlog.stdlib import get_logger from structlog.stdlib import get_logger
from authentik.core.api.users import UserSerializer from authentik.core.api.users import UserSerializer
@ -143,7 +143,9 @@ class SeleniumTestCase(StaticLiveServerTestCase):
"""same as self.url() but show URL in shell""" """same as self.url() but show URL in shell"""
return f"{self.live_server_url}/if/user/#{view}" return f"{self.live_server_url}/if/user/#{view}"
def get_shadow_root(self, selector: str, container: Optional[WebElement] = None) -> WebElement: def get_shadow_root(
self, selector: str, container: Optional[WebElement | WebDriver] = None
) -> WebElement:
"""Get shadow root element's inner shadowRoot""" """Get shadow root element's inner shadowRoot"""
if not container: if not container:
container = self.driver container = self.driver

18
web/package-lock.json generated
View file

@ -62,6 +62,7 @@
"lit": "^2.3.1", "lit": "^2.3.1",
"moment": "^2.29.4", "moment": "^2.29.4",
"prettier": "^2.7.1", "prettier": "^2.7.1",
"pyright": "^1.1.269",
"rapidoc": "^9.3.3", "rapidoc": "^9.3.3",
"rollup": "^2.79.0", "rollup": "^2.79.0",
"rollup-plugin-copy": "^3.4.0", "rollup-plugin-copy": "^3.4.0",
@ -7361,6 +7362,18 @@
"node": ">=6" "node": ">=6"
} }
}, },
"node_modules/pyright": {
"version": "1.1.269",
"resolved": "https://registry.npmjs.org/pyright/-/pyright-1.1.269.tgz",
"integrity": "sha512-n3Q1ccQ4nzMmFGC8B6WUmuoylrkxrknlvpt1ODDbmXUFJlMQSNGLIoZYFZlnP0lt0b4tpO+nDaK1q0lI0nQaxA==",
"bin": {
"pyright": "index.js",
"pyright-langserver": "langserver.index.js"
},
"engines": {
"node": ">=12.0.0"
}
},
"node_modules/qrjs": { "node_modules/qrjs": {
"version": "0.1.2", "version": "0.1.2",
"resolved": "https://registry.npmjs.org/qrjs/-/qrjs-0.1.2.tgz", "resolved": "https://registry.npmjs.org/qrjs/-/qrjs-0.1.2.tgz",
@ -14573,6 +14586,11 @@
"resolved": "https://registry.npmjs.org/punycode/-/punycode-2.1.1.tgz", "resolved": "https://registry.npmjs.org/punycode/-/punycode-2.1.1.tgz",
"integrity": "sha512-XRsRjdf+j5ml+y/6GKHPZbrF/8p2Yga0JPtdqTIY2Xe5ohJPD9saDJJLPvp9+NSBprVvevdXZybnj2cv8OEd0A==" "integrity": "sha512-XRsRjdf+j5ml+y/6GKHPZbrF/8p2Yga0JPtdqTIY2Xe5ohJPD9saDJJLPvp9+NSBprVvevdXZybnj2cv8OEd0A=="
}, },
"pyright": {
"version": "1.1.269",
"resolved": "https://registry.npmjs.org/pyright/-/pyright-1.1.269.tgz",
"integrity": "sha512-n3Q1ccQ4nzMmFGC8B6WUmuoylrkxrknlvpt1ODDbmXUFJlMQSNGLIoZYFZlnP0lt0b4tpO+nDaK1q0lI0nQaxA=="
},
"qrjs": { "qrjs": {
"version": "0.1.2", "version": "0.1.2",
"resolved": "https://registry.npmjs.org/qrjs/-/qrjs-0.1.2.tgz", "resolved": "https://registry.npmjs.org/qrjs/-/qrjs-0.1.2.tgz",

View file

@ -105,6 +105,7 @@
"lit": "^2.3.1", "lit": "^2.3.1",
"moment": "^2.29.4", "moment": "^2.29.4",
"prettier": "^2.7.1", "prettier": "^2.7.1",
"pyright": "^1.1.269",
"rapidoc": "^9.3.3", "rapidoc": "^9.3.3",
"rollup": "^2.79.0", "rollup": "^2.79.0",
"rollup-plugin-copy": "^3.4.0", "rollup-plugin-copy": "^3.4.0",

View file

@ -31,7 +31,7 @@ Generally speaking, authentik is a Django application, ran by gunicorn, proxied
Most functions and classes have type-hints and docstrings, so it is recommended to install a Python Type-checking Extension in your IDE to navigate around the code. Most functions and classes have type-hints and docstrings, so it is recommended to install a Python Type-checking Extension in your IDE to navigate around the code.
Before committing code, run `make lint` to ensure your code is formatted well. This also requires `pyright@1.1.136`, which can be installed with npm. Before committing code, run `make lint` to ensure your code is formatted well. This also requires `pyright`, which is installed in the `web/` folder to make dependency management easier.
Run `make gen` to generate an updated OpenAPI document for any changes you made. Run `make gen` to generate an updated OpenAPI document for any changes you made.