diff --git a/authentik/admin/tasks.py b/authentik/admin/tasks.py index 8289d4d1d..c5b0ebf61 100644 --- a/authentik/admin/tasks.py +++ b/authentik/admin/tasks.py @@ -58,7 +58,7 @@ def clear_update_notifications(): @prefill_task def update_latest_version(self: MonitoredTask): """Update latest version info""" - if CONFIG.y_bool("disable_update_check"): + if CONFIG.get_bool("disable_update_check"): cache.set(VERSION_CACHE_KEY, "0.0.0", VERSION_CACHE_TIMEOUT) self.set_status(TaskResult(TaskResultStatus.WARNING, messages=["Version check disabled."])) return diff --git a/authentik/api/v3/config.py b/authentik/api/v3/config.py index 856ae94a7..dc9bc37df 100644 --- a/authentik/api/v3/config.py +++ b/authentik/api/v3/config.py @@ -70,7 +70,7 @@ class ConfigView(APIView): caps.append(Capabilities.CAN_SAVE_MEDIA) if GEOIP_READER.enabled: caps.append(Capabilities.CAN_GEO_IP) - if CONFIG.y_bool("impersonation"): + if CONFIG.get_bool("impersonation"): caps.append(Capabilities.CAN_IMPERSONATE) if settings.DEBUG: # pragma: no cover caps.append(Capabilities.CAN_DEBUG) @@ -86,17 +86,17 @@ class ConfigView(APIView): return ConfigSerializer( { "error_reporting": { - "enabled": CONFIG.y("error_reporting.enabled"), - "sentry_dsn": CONFIG.y("error_reporting.sentry_dsn"), - "environment": CONFIG.y("error_reporting.environment"), - "send_pii": CONFIG.y("error_reporting.send_pii"), - "traces_sample_rate": float(CONFIG.y("error_reporting.sample_rate", 0.4)), + "enabled": CONFIG.get("error_reporting.enabled"), + "sentry_dsn": CONFIG.get("error_reporting.sentry_dsn"), + "environment": CONFIG.get("error_reporting.environment"), + "send_pii": CONFIG.get("error_reporting.send_pii"), + "traces_sample_rate": float(CONFIG.get("error_reporting.sample_rate", 0.4)), }, "capabilities": self.get_capabilities(), - "cache_timeout": int(CONFIG.y("redis.cache_timeout")), - "cache_timeout_flows": int(CONFIG.y("redis.cache_timeout_flows")), - "cache_timeout_policies": int(CONFIG.y("redis.cache_timeout_policies")), - "cache_timeout_reputation": int(CONFIG.y("redis.cache_timeout_reputation")), + "cache_timeout": int(CONFIG.get("redis.cache_timeout")), + "cache_timeout_flows": int(CONFIG.get("redis.cache_timeout_flows")), + "cache_timeout_policies": int(CONFIG.get("redis.cache_timeout_policies")), + "cache_timeout_reputation": int(CONFIG.get("redis.cache_timeout_reputation")), } ) diff --git a/authentik/blueprints/migrations/0001_initial.py b/authentik/blueprints/migrations/0001_initial.py index 53e831f24..8f6fb1a0f 100644 --- a/authentik/blueprints/migrations/0001_initial.py +++ b/authentik/blueprints/migrations/0001_initial.py @@ -30,7 +30,7 @@ def check_blueprint_v1_file(BlueprintInstance: type, path: Path): return blueprint_file.seek(0) instance: BlueprintInstance = BlueprintInstance.objects.filter(path=path).first() - rel_path = path.relative_to(Path(CONFIG.y("blueprints_dir"))) + rel_path = path.relative_to(Path(CONFIG.get("blueprints_dir"))) meta = None if metadata: meta = from_dict(BlueprintMetadata, metadata) @@ -55,7 +55,7 @@ def migration_blueprint_import(apps: Apps, schema_editor: BaseDatabaseSchemaEdit Flow = apps.get_model("authentik_flows", "Flow") db_alias = schema_editor.connection.alias - for file in glob(f"{CONFIG.y('blueprints_dir')}/**/*.yaml", recursive=True): + for file in glob(f"{CONFIG.get('blueprints_dir')}/**/*.yaml", recursive=True): check_blueprint_v1_file(BlueprintInstance, Path(file)) for blueprint in BlueprintInstance.objects.using(db_alias).all(): diff --git a/authentik/blueprints/models.py b/authentik/blueprints/models.py index e1099d2de..2551beacc 100644 --- a/authentik/blueprints/models.py +++ b/authentik/blueprints/models.py @@ -82,7 +82,7 @@ class BlueprintInstance(SerializerModel, ManagedModel, CreatedUpdatedModel): def retrieve_file(self) -> str: """Get blueprint from path""" try: - base = Path(CONFIG.y("blueprints_dir")) + base = Path(CONFIG.get("blueprints_dir")) full_path = base.joinpath(Path(self.path)).resolve() if not str(full_path).startswith(str(base.resolve())): raise BlueprintRetrievalFailed("Invalid blueprint path") diff --git a/authentik/blueprints/v1/tasks.py b/authentik/blueprints/v1/tasks.py index 1a9092a52..b63c0b144 100644 --- a/authentik/blueprints/v1/tasks.py +++ b/authentik/blueprints/v1/tasks.py @@ -62,7 +62,7 @@ def start_blueprint_watcher(): if _file_watcher_started: return observer = Observer() - observer.schedule(BlueprintEventHandler(), CONFIG.y("blueprints_dir"), recursive=True) + observer.schedule(BlueprintEventHandler(), CONFIG.get("blueprints_dir"), recursive=True) observer.start() _file_watcher_started = True @@ -80,7 +80,7 @@ class BlueprintEventHandler(FileSystemEventHandler): blueprints_discovery.delay() if isinstance(event, FileModifiedEvent): path = Path(event.src_path) - root = Path(CONFIG.y("blueprints_dir")).absolute() + root = Path(CONFIG.get("blueprints_dir")).absolute() rel_path = str(path.relative_to(root)) for instance in BlueprintInstance.objects.filter(path=rel_path): LOGGER.debug("modified blueprint file, starting apply", instance=instance) @@ -101,7 +101,7 @@ def blueprints_find_dict(): def blueprints_find(): """Find blueprints and return valid ones""" blueprints = [] - root = Path(CONFIG.y("blueprints_dir")) + root = Path(CONFIG.get("blueprints_dir")) for path in root.rglob("**/*.yaml"): # Check if any part in the path starts with a dot and assume a hidden file if any(part for part in path.parts if part.startswith(".")): diff --git a/authentik/core/api/users.py b/authentik/core/api/users.py index 4c2aae74d..c0046412e 100644 --- a/authentik/core/api/users.py +++ b/authentik/core/api/users.py @@ -596,7 +596,7 @@ class UserViewSet(UsedByMixin, ModelViewSet): @action(detail=True, methods=["POST"]) def impersonate(self, request: Request, pk: int) -> Response: """Impersonate a user""" - if not CONFIG.y_bool("impersonation"): + if not CONFIG.get_bool("impersonation"): LOGGER.debug("User attempted to impersonate", user=request.user) return Response(status=401) if not request.user.has_perm("impersonate"): diff --git a/authentik/core/management/commands/worker.py b/authentik/core/management/commands/worker.py index ad9fbe5c9..cb03c5846 100644 --- a/authentik/core/management/commands/worker.py +++ b/authentik/core/management/commands/worker.py @@ -18,7 +18,7 @@ class Command(BaseCommand): def handle(self, **options): close_old_connections() - if CONFIG.y_bool("remote_debug"): + if CONFIG.get_bool("remote_debug"): import debugpy debugpy.listen(("0.0.0.0", 6900)) # nosec diff --git a/authentik/core/models.py b/authentik/core/models.py index 383c0f2f7..719d4e66f 100644 --- a/authentik/core/models.py +++ b/authentik/core/models.py @@ -60,7 +60,7 @@ def default_token_key(): """Default token key""" # We use generate_id since the chars in the key should be easy # to use in Emails (for verification) and URLs (for recovery) - return generate_id(int(CONFIG.y("default_token_length"))) + return generate_id(int(CONFIG.get("default_token_length"))) class UserTypes(models.TextChoices): diff --git a/authentik/crypto/tasks.py b/authentik/crypto/tasks.py index 6b9cfe392..4a660ee81 100644 --- a/authentik/crypto/tasks.py +++ b/authentik/crypto/tasks.py @@ -46,7 +46,7 @@ def certificate_discovery(self: MonitoredTask): certs = {} private_keys = {} discovered = 0 - for file in glob(CONFIG.y("cert_discovery_dir") + "/**", recursive=True): + for file in glob(CONFIG.get("cert_discovery_dir") + "/**", recursive=True): path = Path(file) if not path.exists(): continue diff --git a/authentik/events/geo.py b/authentik/events/geo.py index fd38a873d..95a28539c 100644 --- a/authentik/events/geo.py +++ b/authentik/events/geo.py @@ -33,7 +33,7 @@ class GeoIPReader: def __open(self): """Get GeoIP Reader, if configured, otherwise none""" - path = CONFIG.y("geoip") + path = CONFIG.get("geoip") if path == "" or not path: return try: @@ -46,7 +46,7 @@ class GeoIPReader: def __check_expired(self): """Check if the modification date of the GeoIP database has changed, and reload it if so""" - path = CONFIG.y("geoip") + path = CONFIG.get("geoip") try: mtime = stat(path).st_mtime diff = self.__last_mtime < mtime diff --git a/authentik/flows/planner.py b/authentik/flows/planner.py index 64dfdaba0..962273b6a 100644 --- a/authentik/flows/planner.py +++ b/authentik/flows/planner.py @@ -33,7 +33,7 @@ PLAN_CONTEXT_SOURCE = "source" # Is set by the Flow Planner when a FlowToken was used, and the currently active flow plan # was restored. PLAN_CONTEXT_IS_RESTORED = "is_restored" -CACHE_TIMEOUT = int(CONFIG.y("redis.cache_timeout_flows")) +CACHE_TIMEOUT = int(CONFIG.get("redis.cache_timeout_flows")) CACHE_PREFIX = "goauthentik.io/flows/planner/" diff --git a/authentik/flows/tests/test_executor.py b/authentik/flows/tests/test_executor.py index a5c9a0444..4d5fb5c8b 100644 --- a/authentik/flows/tests/test_executor.py +++ b/authentik/flows/tests/test_executor.py @@ -18,7 +18,6 @@ from authentik.flows.planner import FlowPlan, FlowPlanner from authentik.flows.stage import PLAN_CONTEXT_PENDING_USER_IDENTIFIER, StageView from authentik.flows.tests import FlowTestCase from authentik.flows.views.executor import NEXT_ARG_NAME, SESSION_KEY_PLAN, FlowExecutorView -from authentik.lib.config import CONFIG from authentik.lib.generators import generate_id from authentik.policies.dummy.models import DummyPolicy from authentik.policies.models import PolicyBinding @@ -85,7 +84,6 @@ class TestFlowExecutor(FlowTestCase): FlowDesignation.AUTHENTICATION, ) - CONFIG.update_from_dict({"domain": "testserver"}) response = self.client.get( reverse("authentik_api:flow-executor", kwargs={"flow_slug": flow.slug}), ) @@ -111,7 +109,6 @@ class TestFlowExecutor(FlowTestCase): denied_action=FlowDeniedAction.CONTINUE, ) - CONFIG.update_from_dict({"domain": "testserver"}) response = self.client.get( reverse("authentik_api:flow-executor", kwargs={"flow_slug": flow.slug}), ) @@ -128,7 +125,6 @@ class TestFlowExecutor(FlowTestCase): FlowDesignation.AUTHENTICATION, ) - CONFIG.update_from_dict({"domain": "testserver"}) dest = "/unique-string" url = reverse("authentik_api:flow-executor", kwargs={"flow_slug": flow.slug}) response = self.client.get(url + f"?{NEXT_ARG_NAME}={dest}") @@ -145,7 +141,6 @@ class TestFlowExecutor(FlowTestCase): FlowDesignation.AUTHENTICATION, ) - CONFIG.update_from_dict({"domain": "testserver"}) response = self.client.get( reverse("authentik_api:flow-executor", kwargs={"flow_slug": flow.slug}), ) diff --git a/authentik/lib/avatars.py b/authentik/lib/avatars.py index 25e3e3763..8a6e2b9c1 100644 --- a/authentik/lib/avatars.py +++ b/authentik/lib/avatars.py @@ -175,7 +175,7 @@ def get_avatar(user: "User") -> str: "initials": avatar_mode_generated, "gravatar": avatar_mode_gravatar, } - modes: str = CONFIG.y("avatars", "none") + modes: str = CONFIG.get("avatars", "none") for mode in modes.split(","): avatar = None if mode in mode_map: diff --git a/authentik/lib/config.py b/authentik/lib/config.py index 2f9d10002..3e1a27cf3 100644 --- a/authentik/lib/config.py +++ b/authentik/lib/config.py @@ -2,13 +2,15 @@ import os from collections.abc import Mapping from contextlib import contextmanager +from dataclasses import dataclass, field +from enum import Enum from glob import glob -from json import dumps, loads +from json import JSONEncoder, dumps, loads from json.decoder import JSONDecodeError from pathlib import Path from sys import argv, stderr from time import time -from typing import Any +from typing import Any, Optional from urllib.parse import urlparse import yaml @@ -32,15 +34,44 @@ def get_path_from_dict(root: dict, path: str, sep=".", default=None) -> Any: return root +@dataclass +class Attr: + """Single configuration attribute""" + + class Source(Enum): + """Sources a configuration attribute can come from, determines what should be done with + Attr.source (and if it's set at all)""" + + UNSPECIFIED = "unspecified" + ENV = "env" + CONFIG_FILE = "config_file" + URI = "uri" + + value: Any + + source_type: Source = field(default=Source.UNSPECIFIED) + + # depending on source_type, might contain the environment variable or the path + # to the config file containing this change or the file containing this value + source: Optional[str] = field(default=None) + + +class AttrEncoder(JSONEncoder): + """JSON encoder that can deal with `Attr` classes""" + + def default(self, o: Any) -> Any: + if isinstance(o, Attr): + return o.value + return super().default(o) + + class ConfigLoader: """Search through SEARCH_PATHS and load configuration. Environment variables starting with `ENV_PREFIX` are also applied. A variable like AUTHENTIK_POSTGRESQL__HOST would translate to postgresql.host""" - loaded_file = [] - - def __init__(self): + def __init__(self, **kwargs): super().__init__() self.__config = {} base_dir = Path(__file__).parent.joinpath(Path("../..")).resolve() @@ -65,6 +96,7 @@ class ConfigLoader: # Update config with env file self.update_from_file(env_file) self.update_from_env() + self.update(self.__config, kwargs) def log(self, level: str, message: str, **kwargs): """Custom Log method, we want to ensure ConfigLoader always logs JSON even when @@ -86,22 +118,32 @@ class ConfigLoader: else: if isinstance(value, str): value = self.parse_uri(value) + elif not isinstance(value, Attr): + value = Attr(value) root[key] = value return root - def parse_uri(self, value: str) -> str: + def refresh(self, key: str): + """Update a single value""" + attr: Attr = get_path_from_dict(self.raw, key) + if attr.source_type != Attr.Source.URI: + return + attr.value = self.parse_uri(attr.source).value + + def parse_uri(self, value: str) -> Attr: """Parse string values which start with a URI""" url = urlparse(value) + parsed_value = value if url.scheme == "env": - value = os.getenv(url.netloc, url.query) + parsed_value = os.getenv(url.netloc, url.query) if url.scheme == "file": try: with open(url.path, "r", encoding="utf8") as _file: - value = _file.read().strip() + parsed_value = _file.read().strip() except OSError as exc: self.log("error", f"Failed to read config value from {url.path}: {exc}") - value = url.query - return value + parsed_value = url.query + return Attr(parsed_value, Attr.Source.URI, value) def update_from_file(self, path: Path): """Update config from file contents""" @@ -110,7 +152,6 @@ class ConfigLoader: try: self.update(self.__config, yaml.safe_load(file)) self.log("debug", "Loaded config", file=str(path)) - self.loaded_file.append(path) except yaml.YAMLError as exc: raise ImproperlyConfigured from exc except PermissionError as exc: @@ -121,10 +162,6 @@ class ConfigLoader: error=str(exc), ) - def update_from_dict(self, update: dict): - """Update config from dict""" - self.__config.update(update) - def update_from_env(self): """Check environment variables""" outer = {} @@ -145,7 +182,7 @@ class ConfigLoader: value = loads(value) except JSONDecodeError: pass - current_obj[dot_parts[-1]] = value + current_obj[dot_parts[-1]] = Attr(value, Attr.Source.ENV, key) idx += 1 if idx > 0: self.log("debug", "Loaded environment variables", count=idx) @@ -154,28 +191,32 @@ class ConfigLoader: @contextmanager def patch(self, path: str, value: Any): """Context manager for unittests to patch a value""" - original_value = self.y(path) - self.y_set(path, value) + original_value = self.get(path) + self.set(path, value) try: yield finally: - self.y_set(path, original_value) + self.set(path, original_value) @property def raw(self) -> dict: """Get raw config dictionary""" return self.__config - # pylint: disable=invalid-name - def y(self, path: str, default=None, sep=".") -> Any: + def get(self, path: str, default=None, sep=".") -> Any: """Access attribute by using yaml path""" # Walk sub_dicts before parsing path root = self.raw # Walk each component of the path - return get_path_from_dict(root, path, sep=sep, default=default) + attr: Attr = get_path_from_dict(root, path, sep=sep, default=Attr(default)) + return attr.value - def y_set(self, path: str, value: Any, sep="."): - """Set value using same syntax as y()""" + def get_bool(self, path: str, default=False) -> bool: + """Wrapper for get that converts value into boolean""" + return str(self.get(path, default)).lower() == "true" + + def set(self, path: str, value: Any, sep="."): + """Set value using same syntax as get()""" # Walk sub_dicts before parsing path root = self.raw # Walk each component of the path @@ -184,17 +225,14 @@ class ConfigLoader: if comp not in root: root[comp] = {} root = root.get(comp, {}) - root[path_parts[-1]] = value - - def y_bool(self, path: str, default=False) -> bool: - """Wrapper for y that converts value into boolean""" - return str(self.y(path, default)).lower() == "true" + root[path_parts[-1]] = Attr(value) CONFIG = ConfigLoader() + if __name__ == "__main__": if len(argv) < 2: - print(dumps(CONFIG.raw, indent=4)) + print(dumps(CONFIG.raw, indent=4, cls=AttrEncoder)) else: - print(CONFIG.y(argv[1])) + print(CONFIG.get(argv[1])) diff --git a/authentik/lib/sentry.py b/authentik/lib/sentry.py index 70d7246ca..572990999 100644 --- a/authentik/lib/sentry.py +++ b/authentik/lib/sentry.py @@ -51,18 +51,18 @@ class SentryTransport(HttpTransport): def sentry_init(**sentry_init_kwargs): """Configure sentry SDK""" - sentry_env = CONFIG.y("error_reporting.environment", "customer") + sentry_env = CONFIG.get("error_reporting.environment", "customer") kwargs = { "environment": sentry_env, - "send_default_pii": CONFIG.y_bool("error_reporting.send_pii", False), + "send_default_pii": CONFIG.get_bool("error_reporting.send_pii", False), "_experiments": { - "profiles_sample_rate": float(CONFIG.y("error_reporting.sample_rate", 0.1)), + "profiles_sample_rate": float(CONFIG.get("error_reporting.sample_rate", 0.1)), }, } kwargs.update(**sentry_init_kwargs) # pylint: disable=abstract-class-instantiated sentry_sdk_init( - dsn=CONFIG.y("error_reporting.sentry_dsn"), + dsn=CONFIG.get("error_reporting.sentry_dsn"), integrations=[ ArgvIntegration(), StdlibIntegration(), @@ -92,7 +92,7 @@ def traces_sampler(sampling_context: dict) -> float: return 0 if _type == "websocket": return 0 - return float(CONFIG.y("error_reporting.sample_rate", 0.1)) + return float(CONFIG.get("error_reporting.sample_rate", 0.1)) def before_send(event: dict, hint: dict) -> Optional[dict]: diff --git a/authentik/lib/tests/test_config.py b/authentik/lib/tests/test_config.py index 8a67203c1..9bb7254a5 100644 --- a/authentik/lib/tests/test_config.py +++ b/authentik/lib/tests/test_config.py @@ -16,23 +16,23 @@ class TestConfig(TestCase): config = ConfigLoader() environ[ENV_PREFIX + "_test__test"] = "bar" config.update_from_env() - self.assertEqual(config.y("test.test"), "bar") + self.assertEqual(config.get("test.test"), "bar") def test_patch(self): """Test patch decorator""" config = ConfigLoader() - config.y_set("foo.bar", "bar") - self.assertEqual(config.y("foo.bar"), "bar") + config.set("foo.bar", "bar") + self.assertEqual(config.get("foo.bar"), "bar") with config.patch("foo.bar", "baz"): - self.assertEqual(config.y("foo.bar"), "baz") - self.assertEqual(config.y("foo.bar"), "bar") + self.assertEqual(config.get("foo.bar"), "baz") + self.assertEqual(config.get("foo.bar"), "bar") def test_uri_env(self): """Test URI parsing (environment)""" config = ConfigLoader() environ["foo"] = "bar" - self.assertEqual(config.parse_uri("env://foo"), "bar") - self.assertEqual(config.parse_uri("env://foo?bar"), "bar") + self.assertEqual(config.parse_uri("env://foo").value, "bar") + self.assertEqual(config.parse_uri("env://foo?bar").value, "bar") def test_uri_file(self): """Test URI parsing (file load)""" @@ -41,11 +41,25 @@ class TestConfig(TestCase): write(file, "foo".encode()) _, file2_name = mkstemp() chmod(file2_name, 0o000) # Remove all permissions so we can't read the file - self.assertEqual(config.parse_uri(f"file://{file_name}"), "foo") - self.assertEqual(config.parse_uri(f"file://{file2_name}?def"), "def") + self.assertEqual(config.parse_uri(f"file://{file_name}").value, "foo") + self.assertEqual(config.parse_uri(f"file://{file2_name}?def").value, "def") unlink(file_name) unlink(file2_name) + def test_uri_file_update(self): + """Test URI parsing (file load and update)""" + file, file_name = mkstemp() + write(file, "foo".encode()) + config = ConfigLoader(file_test=f"file://{file_name}") + self.assertEqual(config.get("file_test"), "foo") + + # Update config file + write(file, "bar".encode()) + config.refresh("file_test") + self.assertEqual(config.get("file_test"), "foobar") + + unlink(file_name) + def test_file_update(self): """Test update_from_file""" config = ConfigLoader() diff --git a/authentik/lib/utils/reflection.py b/authentik/lib/utils/reflection.py index 4e35688f9..c7dda7414 100644 --- a/authentik/lib/utils/reflection.py +++ b/authentik/lib/utils/reflection.py @@ -50,7 +50,7 @@ def get_env() -> str: """Get environment in which authentik is currently running""" if "CI" in os.environ: return "ci" - if CONFIG.y_bool("debug"): + if CONFIG.get_bool("debug"): return "dev" if SERVICE_HOST_ENV_NAME in os.environ: return "kubernetes" diff --git a/authentik/outposts/controllers/base.py b/authentik/outposts/controllers/base.py index a4ede702a..a3c0cb7d6 100644 --- a/authentik/outposts/controllers/base.py +++ b/authentik/outposts/controllers/base.py @@ -97,7 +97,7 @@ class BaseController: if self.outpost.config.container_image is not None: return self.outpost.config.container_image - image_name_template: str = CONFIG.y("outposts.container_image_base") + image_name_template: str = CONFIG.get("outposts.container_image_base") return image_name_template % { "type": self.outpost.type, "version": __version__, diff --git a/authentik/outposts/models.py b/authentik/outposts/models.py index c0217a0f4..ac05962c2 100644 --- a/authentik/outposts/models.py +++ b/authentik/outposts/models.py @@ -58,7 +58,7 @@ class OutpostConfig: authentik_host_insecure: bool = False authentik_host_browser: str = "" - log_level: str = CONFIG.y("log_level") + log_level: str = CONFIG.get("log_level") object_naming_template: str = field(default="ak-outpost-%(name)s") container_image: Optional[str] = field(default=None) diff --git a/authentik/outposts/tasks.py b/authentik/outposts/tasks.py index e10cb3026..227127352 100644 --- a/authentik/outposts/tasks.py +++ b/authentik/outposts/tasks.py @@ -256,7 +256,7 @@ def _outpost_single_update(outpost: Outpost, layer=None): def outpost_connection_discovery(self: MonitoredTask): """Checks the local environment and create Service connections.""" status = TaskResult(TaskResultStatus.SUCCESSFUL) - if not CONFIG.y_bool("outposts.discover"): + if not CONFIG.get_bool("outposts.discover"): status.messages.append("Outpost integration discovery is disabled") self.set_status(status) return diff --git a/authentik/policies/process.py b/authentik/policies/process.py index 6340162ae..fe05d8571 100644 --- a/authentik/policies/process.py +++ b/authentik/policies/process.py @@ -19,7 +19,7 @@ from authentik.policies.types import CACHE_PREFIX, PolicyRequest, PolicyResult LOGGER = get_logger() FORK_CTX = get_context("fork") -CACHE_TIMEOUT = int(CONFIG.y("redis.cache_timeout_policies")) +CACHE_TIMEOUT = int(CONFIG.get("redis.cache_timeout_policies")) PROCESS_CLASS = FORK_CTX.Process diff --git a/authentik/policies/reputation/signals.py b/authentik/policies/reputation/signals.py index 693b348fb..af78e6109 100644 --- a/authentik/policies/reputation/signals.py +++ b/authentik/policies/reputation/signals.py @@ -13,7 +13,7 @@ from authentik.policies.reputation.tasks import save_reputation from authentik.stages.identification.signals import identification_failed LOGGER = get_logger() -CACHE_TIMEOUT = int(CONFIG.y("redis.cache_timeout_reputation")) +CACHE_TIMEOUT = int(CONFIG.get("redis.cache_timeout_reputation")) def update_score(request: HttpRequest, identifier: str, amount: int): diff --git a/authentik/providers/oauth2/views/device_backchannel.py b/authentik/providers/oauth2/views/device_backchannel.py index a8b511f96..79f723a73 100644 --- a/authentik/providers/oauth2/views/device_backchannel.py +++ b/authentik/providers/oauth2/views/device_backchannel.py @@ -46,7 +46,7 @@ class DeviceView(View): def dispatch(self, request: HttpRequest, *args, **kwargs) -> HttpResponse: throttle = AnonRateThrottle() - throttle.rate = CONFIG.y("throttle.providers.oauth2.device", "20/hour") + throttle.rate = CONFIG.get("throttle.providers.oauth2.device", "20/hour") throttle.num_requests, throttle.duration = throttle.parse_rate(throttle.rate) if not throttle.allow_request(request, self): return HttpResponse(status=429) diff --git a/authentik/root/db/__init__.py b/authentik/root/db/__init__.py new file mode 100644 index 000000000..e69de29bb diff --git a/authentik/root/db/base.py b/authentik/root/db/base.py new file mode 100644 index 000000000..1e9f80f70 --- /dev/null +++ b/authentik/root/db/base.py @@ -0,0 +1,15 @@ +"""authentik database backend""" +from django_prometheus.db.backends.postgresql.base import DatabaseWrapper as BaseDatabaseWrapper + +from authentik.lib.config import CONFIG + + +class DatabaseWrapper(BaseDatabaseWrapper): + """database backend which supports rotating credentials""" + + def get_connection_params(self): + CONFIG.refresh("postgresql.password") + conn_params = super().get_connection_params() + conn_params["user"] = CONFIG.get("postgresql.user") + conn_params["password"] = CONFIG.get("postgresql.password") + return conn_params diff --git a/authentik/root/install_id.py b/authentik/root/install_id.py index a77e7b9ac..a2b4fc544 100644 --- a/authentik/root/install_id.py +++ b/authentik/root/install_id.py @@ -26,15 +26,15 @@ def get_install_id_raw(): """Get install_id without django loaded, this is required for the startup when we get the install_id but django isn't loaded yet and we can't use the function above.""" conn = connect( - dbname=CONFIG.y("postgresql.name"), - user=CONFIG.y("postgresql.user"), - password=CONFIG.y("postgresql.password"), - host=CONFIG.y("postgresql.host"), - port=int(CONFIG.y("postgresql.port")), - sslmode=CONFIG.y("postgresql.sslmode"), - sslrootcert=CONFIG.y("postgresql.sslrootcert"), - sslcert=CONFIG.y("postgresql.sslcert"), - sslkey=CONFIG.y("postgresql.sslkey"), + dbname=CONFIG.get("postgresql.name"), + user=CONFIG.get("postgresql.user"), + password=CONFIG.get("postgresql.password"), + host=CONFIG.get("postgresql.host"), + port=int(CONFIG.get("postgresql.port")), + sslmode=CONFIG.get("postgresql.sslmode"), + sslrootcert=CONFIG.get("postgresql.sslrootcert"), + sslcert=CONFIG.get("postgresql.sslcert"), + sslkey=CONFIG.get("postgresql.sslkey"), ) cursor = conn.cursor() cursor.execute("SELECT id FROM authentik_install_id LIMIT 1;") diff --git a/authentik/root/settings.py b/authentik/root/settings.py index 7a4659d24..838cfddf2 100644 --- a/authentik/root/settings.py +++ b/authentik/root/settings.py @@ -24,8 +24,8 @@ BASE_DIR = Path(__file__).absolute().parent.parent.parent STATICFILES_DIRS = [BASE_DIR / Path("web")] MEDIA_ROOT = BASE_DIR / Path("media") -DEBUG = CONFIG.y_bool("debug") -SECRET_KEY = CONFIG.y("secret_key") +DEBUG = CONFIG.get_bool("debug") +SECRET_KEY = CONFIG.get("secret_key") INTERNAL_IPS = ["127.0.0.1"] ALLOWED_HOSTS = ["*"] @@ -40,7 +40,7 @@ CSRF_COOKIE_NAME = "authentik_csrf" CSRF_HEADER_NAME = "HTTP_X_AUTHENTIK_CSRF" LANGUAGE_COOKIE_NAME = "authentik_language" SESSION_COOKIE_NAME = "authentik_session" -SESSION_COOKIE_DOMAIN = CONFIG.y("cookie_domain", None) +SESSION_COOKIE_DOMAIN = CONFIG.get("cookie_domain", None) AUTHENTICATION_BACKENDS = [ "django.contrib.auth.backends.ModelBackend", @@ -179,26 +179,26 @@ REST_FRAMEWORK = { "TEST_REQUEST_DEFAULT_FORMAT": "json", "DEFAULT_THROTTLE_CLASSES": ["rest_framework.throttling.AnonRateThrottle"], "DEFAULT_THROTTLE_RATES": { - "anon": CONFIG.y("throttle.default"), + "anon": CONFIG.get("throttle.default"), }, } _redis_protocol_prefix = "redis://" _redis_celery_tls_requirements = "" -if CONFIG.y_bool("redis.tls", False): +if CONFIG.get_bool("redis.tls", False): _redis_protocol_prefix = "rediss://" - _redis_celery_tls_requirements = f"?ssl_cert_reqs={CONFIG.y('redis.tls_reqs')}" + _redis_celery_tls_requirements = f"?ssl_cert_reqs={CONFIG.get('redis.tls_reqs')}" _redis_url = ( f"{_redis_protocol_prefix}:" - f"{quote_plus(CONFIG.y('redis.password'))}@{quote_plus(CONFIG.y('redis.host'))}:" - f"{int(CONFIG.y('redis.port'))}" + f"{quote_plus(CONFIG.get('redis.password'))}@{quote_plus(CONFIG.get('redis.host'))}:" + f"{int(CONFIG.get('redis.port'))}" ) CACHES = { "default": { "BACKEND": "django_redis.cache.RedisCache", - "LOCATION": f"{_redis_url}/{CONFIG.y('redis.db')}", - "TIMEOUT": int(CONFIG.y("redis.cache_timeout", 300)), + "LOCATION": f"{_redis_url}/{CONFIG.get('redis.db')}", + "TIMEOUT": int(CONFIG.get("redis.cache_timeout", 300)), "OPTIONS": {"CLIENT_CLASS": "django_redis.client.DefaultClient"}, "KEY_PREFIX": "authentik_cache", } @@ -238,7 +238,7 @@ ROOT_URLCONF = "authentik.root.urls" TEMPLATES = [ { "BACKEND": "django.template.backends.django.DjangoTemplates", - "DIRS": [CONFIG.y("email.template_dir")], + "DIRS": [CONFIG.get("email.template_dir")], "APP_DIRS": True, "OPTIONS": { "context_processors": [ @@ -258,7 +258,7 @@ CHANNEL_LAYERS = { "default": { "BACKEND": "channels_redis.core.RedisChannelLayer", "CONFIG": { - "hosts": [f"{_redis_url}/{CONFIG.y('redis.db')}"], + "hosts": [f"{_redis_url}/{CONFIG.get('redis.db')}"], "prefix": "authentik_channels", }, }, @@ -270,34 +270,37 @@ CHANNEL_LAYERS = { DATABASES = { "default": { - "ENGINE": "django_prometheus.db.backends.postgresql", - "HOST": CONFIG.y("postgresql.host"), - "NAME": CONFIG.y("postgresql.name"), - "USER": CONFIG.y("postgresql.user"), - "PASSWORD": CONFIG.y("postgresql.password"), - "PORT": int(CONFIG.y("postgresql.port")), - "SSLMODE": CONFIG.y("postgresql.sslmode"), - "SSLROOTCERT": CONFIG.y("postgresql.sslrootcert"), - "SSLCERT": CONFIG.y("postgresql.sslcert"), - "SSLKEY": CONFIG.y("postgresql.sslkey"), + "ENGINE": "authentik.root.db", + "HOST": CONFIG.get("postgresql.host"), + "NAME": CONFIG.get("postgresql.name"), + "USER": CONFIG.get("postgresql.user"), + "PASSWORD": CONFIG.get("postgresql.password"), + "PORT": int(CONFIG.get("postgresql.port")), + "SSLMODE": CONFIG.get("postgresql.sslmode"), + "SSLROOTCERT": CONFIG.get("postgresql.sslrootcert"), + "SSLCERT": CONFIG.get("postgresql.sslcert"), + "SSLKEY": CONFIG.get("postgresql.sslkey"), } } -if CONFIG.y_bool("postgresql.use_pgbouncer", False): +if CONFIG.get_bool("postgresql.use_pgbouncer", False): # https://docs.djangoproject.com/en/4.0/ref/databases/#transaction-pooling-server-side-cursors DATABASES["default"]["DISABLE_SERVER_SIDE_CURSORS"] = True # https://docs.djangoproject.com/en/4.0/ref/databases/#persistent-connections DATABASES["default"]["CONN_MAX_AGE"] = None # persistent # Email -EMAIL_HOST = CONFIG.y("email.host") -EMAIL_PORT = int(CONFIG.y("email.port")) -EMAIL_HOST_USER = CONFIG.y("email.username") -EMAIL_HOST_PASSWORD = CONFIG.y("email.password") -EMAIL_USE_TLS = CONFIG.y_bool("email.use_tls", False) -EMAIL_USE_SSL = CONFIG.y_bool("email.use_ssl", False) -EMAIL_TIMEOUT = int(CONFIG.y("email.timeout")) -DEFAULT_FROM_EMAIL = CONFIG.y("email.from") +# These values should never actually be used, emails are only sent from email stages, which +# loads the config directly from CONFIG +# See authentik/stages/email/models.py, line 105 +EMAIL_HOST = CONFIG.get("email.host") +EMAIL_PORT = int(CONFIG.get("email.port")) +EMAIL_HOST_USER = CONFIG.get("email.username") +EMAIL_HOST_PASSWORD = CONFIG.get("email.password") +EMAIL_USE_TLS = CONFIG.get_bool("email.use_tls", False) +EMAIL_USE_SSL = CONFIG.get_bool("email.use_ssl", False) +EMAIL_TIMEOUT = int(CONFIG.get("email.timeout")) +DEFAULT_FROM_EMAIL = CONFIG.get("email.from") SERVER_EMAIL = DEFAULT_FROM_EMAIL EMAIL_SUBJECT_PREFIX = "[authentik] " @@ -345,15 +348,15 @@ CELERY = { }, "task_create_missing_queues": True, "task_default_queue": "authentik", - "broker_url": f"{_redis_url}/{CONFIG.y('redis.db')}{_redis_celery_tls_requirements}", - "result_backend": f"{_redis_url}/{CONFIG.y('redis.db')}{_redis_celery_tls_requirements}", + "broker_url": f"{_redis_url}/{CONFIG.get('redis.db')}{_redis_celery_tls_requirements}", + "result_backend": f"{_redis_url}/{CONFIG.get('redis.db')}{_redis_celery_tls_requirements}", } # Sentry integration env = get_env() -_ERROR_REPORTING = CONFIG.y_bool("error_reporting.enabled", False) +_ERROR_REPORTING = CONFIG.get_bool("error_reporting.enabled", False) if _ERROR_REPORTING: - sentry_env = CONFIG.y("error_reporting.environment", "customer") + sentry_env = CONFIG.get("error_reporting.environment", "customer") sentry_init() set_tag("authentik.uuid", sha512(str(SECRET_KEY).encode("ascii")).hexdigest()[:16]) @@ -367,7 +370,7 @@ MEDIA_URL = "/media/" TEST = False TEST_RUNNER = "authentik.root.test_runner.PytestTestRunner" # We can't check TEST here as its set later by the test runner -LOG_LEVEL = CONFIG.y("log_level").upper() if "TF_BUILD" not in os.environ else "DEBUG" +LOG_LEVEL = CONFIG.get("log_level").upper() if "TF_BUILD" not in os.environ else "DEBUG" # We could add a custom level to stdlib logging and structlog, but it's not easy or clean # https://stackoverflow.com/questions/54505487/custom-log-level-not-working-with-structlog # Additionally, the entire code uses debug as highest level so that would have to be re-written too diff --git a/authentik/root/test_runner.py b/authentik/root/test_runner.py index 7a420c157..febcedf9c 100644 --- a/authentik/root/test_runner.py +++ b/authentik/root/test_runner.py @@ -31,14 +31,14 @@ class PytestTestRunner: # pragma: no cover settings.TEST = True settings.CELERY["task_always_eager"] = True - CONFIG.y_set("avatars", "none") - CONFIG.y_set("geoip", "tests/GeoLite2-City-Test.mmdb") - CONFIG.y_set("blueprints_dir", "./blueprints") - CONFIG.y_set( + CONFIG.set("avatars", "none") + CONFIG.set("geoip", "tests/GeoLite2-City-Test.mmdb") + CONFIG.set("blueprints_dir", "./blueprints") + CONFIG.set( "outposts.container_image_base", f"ghcr.io/goauthentik/dev-%(type)s:{get_docker_tag()}", ) - CONFIG.y_set("error_reporting.sample_rate", 0) + CONFIG.set("error_reporting.sample_rate", 0) sentry_init( environment="testing", send_default_pii=True, diff --git a/authentik/sources/ldap/models.py b/authentik/sources/ldap/models.py index 9c5040d2d..ac7f32aca 100644 --- a/authentik/sources/ldap/models.py +++ b/authentik/sources/ldap/models.py @@ -136,7 +136,7 @@ class LDAPSource(Source): chmod(private_key_file, 0o600) tls_kwargs["local_private_key_file"] = private_key_file tls_kwargs["local_certificate_file"] = certificate_file - if ciphers := CONFIG.y("ldap.tls.ciphers", None): + if ciphers := CONFIG.get("ldap.tls.ciphers", None): tls_kwargs["ciphers"] = ciphers.strip() if self.sni: tls_kwargs["sni"] = self.server_uri.split(",", maxsplit=1)[0].strip() diff --git a/authentik/sources/ldap/sync/base.py b/authentik/sources/ldap/sync/base.py index 97b1c381a..b544d70c9 100644 --- a/authentik/sources/ldap/sync/base.py +++ b/authentik/sources/ldap/sync/base.py @@ -93,7 +93,7 @@ class BaseLDAPSynchronizer: types_only=False, get_operational_attributes=False, controls=None, - paged_size=int(CONFIG.y("ldap.page_size", 50)), + paged_size=int(CONFIG.get("ldap.page_size", 50)), paged_criticality=False, ): """Search in pages, returns each page""" diff --git a/authentik/sources/ldap/tasks.py b/authentik/sources/ldap/tasks.py index 174b97e0d..0b0a975f5 100644 --- a/authentik/sources/ldap/tasks.py +++ b/authentik/sources/ldap/tasks.py @@ -68,12 +68,12 @@ def ldap_sync_paginator(source: LDAPSource, sync: type[BaseLDAPSynchronizer]) -> @CELERY_APP.task( bind=True, base=MonitoredTask, - soft_time_limit=60 * 60 * int(CONFIG.y("ldap.task_timeout_hours")), - task_time_limit=60 * 60 * int(CONFIG.y("ldap.task_timeout_hours")), + soft_time_limit=60 * 60 * int(CONFIG.get("ldap.task_timeout_hours")), + task_time_limit=60 * 60 * int(CONFIG.get("ldap.task_timeout_hours")), ) def ldap_sync(self: MonitoredTask, source_pk: str, sync_class: str, page_cache_key: str): """Synchronization of an LDAP Source""" - self.result_timeout_hours = int(CONFIG.y("ldap.task_timeout_hours")) + self.result_timeout_hours = int(CONFIG.get("ldap.task_timeout_hours")) source: LDAPSource = LDAPSource.objects.filter(pk=source_pk).first() if not source: # Because the source couldn't be found, we don't have a UID diff --git a/authentik/stages/email/models.py b/authentik/stages/email/models.py index 156762ac5..1beff50b6 100644 --- a/authentik/stages/email/models.py +++ b/authentik/stages/email/models.py @@ -13,6 +13,7 @@ from rest_framework.serializers import BaseSerializer from structlog.stdlib import get_logger from authentik.flows.models import Stage +from authentik.lib.config import CONFIG LOGGER = get_logger() @@ -104,7 +105,16 @@ class EmailStage(Stage): def backend(self) -> BaseEmailBackend: """Get fully configured Email Backend instance""" if self.use_global_settings: - return self.backend_class() + CONFIG.refresh("email.password") + return self.backend_class( + host=CONFIG.get("email.host"), + port=int(CONFIG.get("email.port")), + username=CONFIG.get("email.username"), + password=CONFIG.get("email.password"), + use_tls=CONFIG.get_bool("email.use_tls", False), + use_ssl=CONFIG.get_bool("email.use_ssl", False), + timeout=int(CONFIG.get("email.timeout")), + ) return self.backend_class( host=self.host, port=self.port, diff --git a/authentik/stages/email/tests/test_stage.py b/authentik/stages/email/tests/test_stage.py index 32141ba45..e4b362b2a 100644 --- a/authentik/stages/email/tests/test_stage.py +++ b/authentik/stages/email/tests/test_stage.py @@ -13,6 +13,7 @@ from authentik.flows.models import FlowDesignation, FlowStageBinding, FlowToken from authentik.flows.planner import PLAN_CONTEXT_PENDING_USER, FlowPlan from authentik.flows.tests import FlowTestCase from authentik.flows.views.executor import QS_KEY_TOKEN, SESSION_KEY_PLAN +from authentik.lib.config import CONFIG from authentik.stages.email.models import EmailStage from authentik.stages.email.stage import PLAN_CONTEXT_EMAIL_OVERRIDE @@ -120,7 +121,7 @@ class TestEmailStage(FlowTestCase): def test_use_global_settings(self): """Test use_global_settings""" host = "some-unique-string" - with self.settings(EMAIL_HOST=host): + with CONFIG.patch("email.host", host): self.assertEqual(EmailStage(use_global_settings=True).backend.host, host) def test_token(self): diff --git a/authentik/tenants/api.py b/authentik/tenants/api.py index 42b2985c0..a8c438907 100644 --- a/authentik/tenants/api.py +++ b/authentik/tenants/api.py @@ -78,7 +78,7 @@ class CurrentTenantSerializer(PassiveSerializer): ui_footer_links = ListField( child=FooterLinkSerializer(), read_only=True, - default=CONFIG.y("footer_links", []), + default=CONFIG.get("footer_links", []), ) ui_theme = ChoiceField( choices=Themes.choices, diff --git a/authentik/tenants/tests.py b/authentik/tenants/tests.py index 73bd8f7d3..a42326dc9 100644 --- a/authentik/tenants/tests.py +++ b/authentik/tenants/tests.py @@ -24,7 +24,7 @@ class TestTenants(APITestCase): "branding_favicon": "/static/dist/assets/icons/icon.png", "branding_title": "authentik", "matched_domain": tenant.domain, - "ui_footer_links": CONFIG.y("footer_links"), + "ui_footer_links": CONFIG.get("footer_links"), "ui_theme": Themes.AUTOMATIC, "default_locale": "", }, @@ -43,7 +43,7 @@ class TestTenants(APITestCase): "branding_favicon": "/static/dist/assets/icons/icon.png", "branding_title": "custom", "matched_domain": "bar.baz", - "ui_footer_links": CONFIG.y("footer_links"), + "ui_footer_links": CONFIG.get("footer_links"), "ui_theme": Themes.AUTOMATIC, "default_locale": "", }, @@ -59,7 +59,7 @@ class TestTenants(APITestCase): "branding_favicon": "/static/dist/assets/icons/icon.png", "branding_title": "authentik", "matched_domain": "fallback", - "ui_footer_links": CONFIG.y("footer_links"), + "ui_footer_links": CONFIG.get("footer_links"), "ui_theme": Themes.AUTOMATIC, "default_locale": "", }, diff --git a/authentik/tenants/utils.py b/authentik/tenants/utils.py index 87da6c8f6..d7f981b7e 100644 --- a/authentik/tenants/utils.py +++ b/authentik/tenants/utils.py @@ -36,7 +36,7 @@ def context_processor(request: HttpRequest) -> dict[str, Any]: trace = span.to_traceparent() return { "tenant": tenant, - "footer_links": CONFIG.y("footer_links"), + "footer_links": CONFIG.get("footer_links"), "sentry_trace": trace, "version": get_full_version(), } diff --git a/blueprints/default/flow-default-user-settings-flow.yaml b/blueprints/default/flow-default-user-settings-flow.yaml index 0640a643b..01fb733d5 100644 --- a/blueprints/default/flow-default-user-settings-flow.yaml +++ b/blueprints/default/flow-default-user-settings-flow.yaml @@ -94,21 +94,21 @@ entries: prompt_data = request.context.get("prompt_data") if not request.user.group_attributes(request.http_request).get( - USER_ATTRIBUTE_CHANGE_EMAIL, CONFIG.y_bool("default_user_change_email", True) + USER_ATTRIBUTE_CHANGE_EMAIL, CONFIG.get_bool("default_user_change_email", True) ): if prompt_data.get("email") != request.user.email: ak_message("Not allowed to change email address.") return False if not request.user.group_attributes(request.http_request).get( - USER_ATTRIBUTE_CHANGE_NAME, CONFIG.y_bool("default_user_change_name", True) + USER_ATTRIBUTE_CHANGE_NAME, CONFIG.get_bool("default_user_change_name", True) ): if prompt_data.get("name") != request.user.name: ak_message("Not allowed to change name.") return False if not request.user.group_attributes(request.http_request).get( - USER_ATTRIBUTE_CHANGE_USERNAME, CONFIG.y_bool("default_user_change_username", True) + USER_ATTRIBUTE_CHANGE_USERNAME, CONFIG.get_bool("default_user_change_username", True) ): if prompt_data.get("username") != request.user.username: ak_message("Not allowed to change username.") diff --git a/lifecycle/gunicorn.conf.py b/lifecycle/gunicorn.conf.py index f9c21ebdd..89dcb784f 100644 --- a/lifecycle/gunicorn.conf.py +++ b/lifecycle/gunicorn.conf.py @@ -37,7 +37,7 @@ makedirs(prometheus_tmp_dir, exist_ok=True) max_requests = 1000 max_requests_jitter = 50 -_debug = CONFIG.y_bool("DEBUG", False) +_debug = CONFIG.get_bool("DEBUG", False) logconfig_dict = { "version": 1, @@ -80,8 +80,8 @@ if SERVICE_HOST_ENV_NAME in os.environ: else: default_workers = max(cpu_count() * 0.25, 1) + 1 # Minimum of 2 workers -workers = int(CONFIG.y("web.workers", default_workers)) -threads = int(CONFIG.y("web.threads", 4)) +workers = int(CONFIG.get("web.workers", default_workers)) +threads = int(CONFIG.get("web.threads", 4)) def post_fork(server: "Arbiter", worker: DjangoUvicornWorker): @@ -133,7 +133,7 @@ def pre_fork(server: "Arbiter", worker: DjangoUvicornWorker): worker._worker_id = _next_worker_id(server) -if not CONFIG.y_bool("disable_startup_analytics", False): +if not CONFIG.get_bool("disable_startup_analytics", False): env = get_env() should_send = env not in ["dev", "ci"] if should_send: @@ -158,7 +158,7 @@ if not CONFIG.y_bool("disable_startup_analytics", False): except Exception: # nosec pass -if CONFIG.y_bool("remote_debug"): +if CONFIG.get_bool("remote_debug"): import debugpy debugpy.listen(("0.0.0.0", 6800)) # nosec diff --git a/lifecycle/migrate.py b/lifecycle/migrate.py index 077ca149c..fe4957652 100755 --- a/lifecycle/migrate.py +++ b/lifecycle/migrate.py @@ -52,15 +52,15 @@ def release_lock(): if __name__ == "__main__": conn = connect( - dbname=CONFIG.y("postgresql.name"), - user=CONFIG.y("postgresql.user"), - password=CONFIG.y("postgresql.password"), - host=CONFIG.y("postgresql.host"), - port=int(CONFIG.y("postgresql.port")), - sslmode=CONFIG.y("postgresql.sslmode"), - sslrootcert=CONFIG.y("postgresql.sslrootcert"), - sslcert=CONFIG.y("postgresql.sslcert"), - sslkey=CONFIG.y("postgresql.sslkey"), + dbname=CONFIG.get("postgresql.name"), + user=CONFIG.get("postgresql.user"), + password=CONFIG.get("postgresql.password"), + host=CONFIG.get("postgresql.host"), + port=int(CONFIG.get("postgresql.port")), + sslmode=CONFIG.get("postgresql.sslmode"), + sslrootcert=CONFIG.get("postgresql.sslrootcert"), + sslcert=CONFIG.get("postgresql.sslcert"), + sslkey=CONFIG.get("postgresql.sslkey"), ) curr = conn.cursor() try: diff --git a/lifecycle/system_migrations/install_id.py b/lifecycle/system_migrations/install_id.py index e00430478..28867d8fa 100644 --- a/lifecycle/system_migrations/install_id.py +++ b/lifecycle/system_migrations/install_id.py @@ -25,7 +25,7 @@ class Migration(BaseMigration): # If we already have migrations in the database, assume we're upgrading an existing install # and set the install id to the secret key self.cur.execute( - "INSERT INTO authentik_install_id (id) VALUES (%s)", (CONFIG.y("secret_key"),) + "INSERT INTO authentik_install_id (id) VALUES (%s)", (CONFIG.get("secret_key"),) ) else: # Otherwise assume a new install, generate an install ID based on a UUID diff --git a/lifecycle/system_migrations/to_0_13_authentik.py b/lifecycle/system_migrations/to_0_13_authentik.py index c566515e3..ff2088349 100644 --- a/lifecycle/system_migrations/to_0_13_authentik.py +++ b/lifecycle/system_migrations/to_0_13_authentik.py @@ -108,14 +108,14 @@ class Migration(BaseMigration): self.con.commit() # We also need to clean the cache to make sure no pickeled objects still exist for db in [ - CONFIG.y("redis.message_queue_db"), - CONFIG.y("redis.cache_db"), - CONFIG.y("redis.ws_db"), + CONFIG.get("redis.message_queue_db"), + CONFIG.get("redis.cache_db"), + CONFIG.get("redis.ws_db"), ]: redis = Redis( - host=CONFIG.y("redis.host"), + host=CONFIG.get("redis.host"), port=6379, db=db, - password=CONFIG.y("redis.password"), + password=CONFIG.get("redis.password"), ) redis.flushall() diff --git a/lifecycle/wait_for_db.py b/lifecycle/wait_for_db.py index f8d7deb36..fd5c5c848 100755 --- a/lifecycle/wait_for_db.py +++ b/lifecycle/wait_for_db.py @@ -14,7 +14,7 @@ from authentik.lib.config import CONFIG CONFIG.log("info", "Starting authentik bootstrap") # Sanity check, ensure SECRET_KEY is set before we even check for database connectivity -if CONFIG.y("secret_key") is None or len(CONFIG.y("secret_key")) == 0: +if CONFIG.get("secret_key") is None or len(CONFIG.get("secret_key")) == 0: CONFIG.log("info", "----------------------------------------------------------------------") CONFIG.log("info", "Secret key missing, check https://goauthentik.io/docs/installation/.") CONFIG.log("info", "----------------------------------------------------------------------") @@ -24,15 +24,15 @@ if CONFIG.y("secret_key") is None or len(CONFIG.y("secret_key")) == 0: while True: try: conn = connect( - dbname=CONFIG.y("postgresql.name"), - user=CONFIG.y("postgresql.user"), - password=CONFIG.y("postgresql.password"), - host=CONFIG.y("postgresql.host"), - port=int(CONFIG.y("postgresql.port")), - sslmode=CONFIG.y("postgresql.sslmode"), - sslrootcert=CONFIG.y("postgresql.sslrootcert"), - sslcert=CONFIG.y("postgresql.sslcert"), - sslkey=CONFIG.y("postgresql.sslkey"), + dbname=CONFIG.get("postgresql.name"), + user=CONFIG.get("postgresql.user"), + password=CONFIG.get("postgresql.password"), + host=CONFIG.get("postgresql.host"), + port=int(CONFIG.get("postgresql.port")), + sslmode=CONFIG.get("postgresql.sslmode"), + sslrootcert=CONFIG.get("postgresql.sslrootcert"), + sslcert=CONFIG.get("postgresql.sslcert"), + sslkey=CONFIG.get("postgresql.sslkey"), ) conn.cursor() break @@ -42,12 +42,12 @@ while True: CONFIG.log("info", "PostgreSQL connection successful") REDIS_PROTOCOL_PREFIX = "redis://" -if CONFIG.y_bool("redis.tls", False): +if CONFIG.get_bool("redis.tls", False): REDIS_PROTOCOL_PREFIX = "rediss://" REDIS_URL = ( f"{REDIS_PROTOCOL_PREFIX}:" - f"{quote_plus(CONFIG.y('redis.password'))}@{quote_plus(CONFIG.y('redis.host'))}:" - f"{int(CONFIG.y('redis.port'))}/{CONFIG.y('redis.db')}" + f"{quote_plus(CONFIG.get('redis.password'))}@{quote_plus(CONFIG.get('redis.host'))}:" + f"{int(CONFIG.get('redis.port'))}/{CONFIG.get('redis.db')}" ) while True: try: diff --git a/tests/e2e/test_flows_enroll.py b/tests/e2e/test_flows_enroll.py index 5d4cc1d48..ddf9959cc 100644 --- a/tests/e2e/test_flows_enroll.py +++ b/tests/e2e/test_flows_enroll.py @@ -1,7 +1,6 @@ """Test Enroll flow""" from time import sleep -from django.test import override_settings from selenium.webdriver.common.by import By from selenium.webdriver.support import expected_conditions as ec from selenium.webdriver.support.wait import WebDriverWait @@ -9,6 +8,7 @@ from selenium.webdriver.support.wait import WebDriverWait from authentik.blueprints.tests import apply_blueprint from authentik.core.models import User from authentik.flows.models import Flow +from authentik.lib.config import CONFIG from authentik.stages.identification.models import IdentificationStage from tests.e2e.utils import SeleniumTestCase, retry @@ -56,7 +56,7 @@ class TestFlowsEnroll(SeleniumTestCase): @apply_blueprint( "example/flows-enrollment-email-verification.yaml", ) - @override_settings(EMAIL_PORT=1025) + @CONFIG.patch("email.port", 1025) def test_enroll_email(self): """Test enroll with Email verification""" # Attach enrollment flow to identification stage diff --git a/tests/e2e/test_flows_recovery.py b/tests/e2e/test_flows_recovery.py index 51fb64975..ddddd9140 100644 --- a/tests/e2e/test_flows_recovery.py +++ b/tests/e2e/test_flows_recovery.py @@ -1,7 +1,6 @@ """Test recovery flow""" from time import sleep -from django.test import override_settings from selenium.webdriver.common.by import By from selenium.webdriver.support import expected_conditions as ec from selenium.webdriver.support.wait import WebDriverWait @@ -10,6 +9,7 @@ from authentik.blueprints.tests import apply_blueprint from authentik.core.models import User from authentik.core.tests.utils import create_test_admin_user from authentik.flows.models import Flow +from authentik.lib.config import CONFIG from authentik.lib.generators import generate_id from authentik.stages.identification.models import IdentificationStage from tests.e2e.utils import SeleniumTestCase, retry @@ -47,7 +47,7 @@ class TestFlowsRecovery(SeleniumTestCase): @apply_blueprint( "example/flows-recovery-email-verification.yaml", ) - @override_settings(EMAIL_PORT=1025) + @CONFIG.patch("email.port", 1025) def test_recover_email(self): """Test recovery with Email verification""" # Attach recovery flow to identification stage