Compare commits

...
This repository has been archived on 2024-05-31. You can view files and clone it, but cannot push or open issues or pull requests.

1 commit

Author SHA1 Message Date
Jens Langhammer 3fa987f443
start config file watch
Signed-off-by: Jens Langhammer <jens.langhammer@beryju.org>
Signed-off-by: Jens Langhammer <jens@goauthentik.io>
2023-01-25 21:32:33 +01:00
2 changed files with 78 additions and 12 deletions

View file

@ -5,13 +5,20 @@ from contextlib import contextmanager
from glob import glob from glob import glob
from json import dumps, loads from json import dumps, loads
from json.decoder import JSONDecodeError from json.decoder import JSONDecodeError
from pathlib import Path
from sys import argv, stderr from sys import argv, stderr
from time import time from time import time
from typing import Any from typing import Any, Optional
from urllib.parse import urlparse from urllib.parse import urlparse
import yaml import yaml
from django.conf import ImproperlyConfigured from django.conf import ImproperlyConfigured
from watchdog.events import (
FileModifiedEvent,
FileSystemEvent,
FileSystemEventHandler,
)
from watchdog.observers import Observer
SEARCH_PATHS = ["authentik/lib/default.yml", "/etc/authentik/config.yml", ""] + glob( SEARCH_PATHS = ["authentik/lib/default.yml", "/etc/authentik/config.yml", ""] + glob(
"/etc/authentik/config.d/*.yml", recursive=True "/etc/authentik/config.d/*.yml", recursive=True
@ -38,9 +45,47 @@ class ConfigLoader:
A variable like AUTHENTIK_POSTGRESQL__HOST would translate to postgresql.host""" A variable like AUTHENTIK_POSTGRESQL__HOST would translate to postgresql.host"""
loaded_file = [] loaded_file = []
observer: Observer
class FSObserver(FileSystemEventHandler):
"""File system observer"""
loader: "ConfigLoader"
path: str
container: Optional[dict] = None
key: Optional[str] = None
def __init__(
self,
loader: "ConfigLoader",
path: str,
container: Optional[dict] = None,
key: Optional[str] = None,
) -> None:
super().__init__()
self.loader = loader
self.path = path
self.container = container
self.key = key
def on_any_event(self, event: FileSystemEvent):
if not isinstance(event, FileModifiedEvent):
return
if event.is_directory:
return
if event.src_path != self.path:
return
if self.container and self.key:
with open(self.path, "r", encoding="utf8") as _file:
self.container[self.key] = _file.read()
else:
self.loader.log("info", "Updating from changed file", file=self.path)
self.loader.update_from_file(self.path, watch=False)
def __init__(self): def __init__(self):
super().__init__() super().__init__()
self.observer = Observer()
self.observer.start()
self.__config = {} self.__config = {}
base_dir = os.path.realpath(os.path.join(os.path.dirname(__file__), "../..")) base_dir = os.path.realpath(os.path.join(os.path.dirname(__file__), "../.."))
for path in SEARCH_PATHS: for path in SEARCH_PATHS:
@ -81,11 +126,11 @@ class ConfigLoader:
root[key] = self.update(root.get(key, {}), value) root[key] = self.update(root.get(key, {}), value)
else: else:
if isinstance(value, str): if isinstance(value, str):
value = self.parse_uri(value) value = self.parse_uri(value, root, key)
root[key] = value root[key] = value
return root return root
def parse_uri(self, value: str) -> str: def parse_uri(self, value: str, container: dict[str, Any], key: Optional[str] = None, ) -> str:
"""Parse string values which start with a URI""" """Parse string values which start with a URI"""
url = urlparse(value) url = urlparse(value)
if url.scheme == "env": if url.scheme == "env":
@ -93,13 +138,23 @@ class ConfigLoader:
if url.scheme == "file": if url.scheme == "file":
try: try:
with open(url.path, "r", encoding="utf8") as _file: with open(url.path, "r", encoding="utf8") as _file:
value = _file.read().strip() value = _file.read()
if key:
self.observer.schedule(
ConfigLoader.FSObserver(
self,
url.path,
container,
key,
),
Path(url.path).parent,
)
except OSError as exc: except OSError as exc:
self.log("error", f"Failed to read config value from {url.path}: {exc}") self.log("error", f"Failed to read config value from {url.path}: {exc}")
value = url.query value = url.query
return value return value
def update_from_file(self, path: str): def update_from_file(self, path: str, watch=True):
"""Update config from file contents""" """Update config from file contents"""
try: try:
with open(path, encoding="utf8") as file: with open(path, encoding="utf8") as file:
@ -107,6 +162,8 @@ class ConfigLoader:
self.update(self.__config, yaml.safe_load(file)) self.update(self.__config, yaml.safe_load(file))
self.log("debug", "Loaded config", file=path) self.log("debug", "Loaded config", file=path)
self.loaded_file.append(path) self.loaded_file.append(path)
if watch:
self.observer.schedule(ConfigLoader.FSObserver(self, path), Path(path).parent)
except yaml.YAMLError as exc: except yaml.YAMLError as exc:
raise ImproperlyConfigured from exc raise ImproperlyConfigured from exc
except PermissionError as exc: except PermissionError as exc:
@ -181,13 +238,12 @@ class ConfigLoader:
if comp not in root: if comp not in root:
root[comp] = {} root[comp] = {}
root = root.get(comp, {}) root = root.get(comp, {})
root[path_parts[-1]] = value self.parse_uri(value, root, path_parts[-1])
def y_bool(self, path: str, default=False) -> bool: def y_bool(self, path: str, default=False) -> bool:
"""Wrapper for y that converts value into boolean""" """Wrapper for y that converts value into boolean"""
return str(self.y(path, default)).lower() == "true" return str(self.y(path, default)).lower() == "true"
CONFIG = ConfigLoader() CONFIG = ConfigLoader()
if __name__ == "__main__": if __name__ == "__main__":

View file

@ -5,7 +5,7 @@ from tempfile import mkstemp
from django.conf import ImproperlyConfigured from django.conf import ImproperlyConfigured
from django.test import TestCase from django.test import TestCase
from authentik.lib.config import ENV_PREFIX, ConfigLoader from authentik.lib.config import CONFIG, ENV_PREFIX, ConfigLoader
class TestConfig(TestCase): class TestConfig(TestCase):
@ -31,8 +31,8 @@ class TestConfig(TestCase):
"""Test URI parsing (environment)""" """Test URI parsing (environment)"""
config = ConfigLoader() config = ConfigLoader()
environ["foo"] = "bar" environ["foo"] = "bar"
self.assertEqual(config.parse_uri("env://foo"), "bar") self.assertEqual(config.parse_uri("env://foo", {}), "bar")
self.assertEqual(config.parse_uri("env://foo?bar"), "bar") self.assertEqual(config.parse_uri("env://foo?bar", {}), "bar")
def test_uri_file(self): def test_uri_file(self):
"""Test URI parsing (file load)""" """Test URI parsing (file load)"""
@ -41,8 +41,8 @@ class TestConfig(TestCase):
write(file, "foo".encode()) write(file, "foo".encode())
_, file2_name = mkstemp() _, file2_name = mkstemp()
chmod(file2_name, 0o000) # Remove all permissions so we can't read the file chmod(file2_name, 0o000) # Remove all permissions so we can't read the file
self.assertEqual(config.parse_uri(f"file://{file_name}"), "foo") self.assertEqual(config.parse_uri(f"file://{file_name}", {}), "foo")
self.assertEqual(config.parse_uri(f"file://{file2_name}?def"), "def") self.assertEqual(config.parse_uri(f"file://{file2_name}?def", {}), "def")
unlink(file_name) unlink(file_name)
unlink(file2_name) unlink(file2_name)
@ -59,3 +59,13 @@ class TestConfig(TestCase):
config.update_from_file(file2_name) config.update_from_file(file2_name)
unlink(file_name) unlink(file_name)
unlink(file2_name) unlink(file2_name)
def test_update(self):
"""Test change to file"""
file, file_name = mkstemp()
write(file, b"test")
CONFIG.y_set("test.file", f"file://{file_name}")
self.assertEqual(CONFIG.y("test.file"), "test")
write(file, "test2")
self.assertEqual(CONFIG.y("test.file"), "test2")
unlink(file_name)