diff --git a/shared/libretime_shared/config/_base.py b/shared/libretime_shared/config/_base.py
index ac9ba09cf..40c0c652a 100644
--- a/shared/libretime_shared/config/_base.py
+++ b/shared/libretime_shared/config/_base.py
@@ -1,16 +1,14 @@
 import sys
-from os import environ
+from itertools import zip_longest
 from pathlib import Path
 from typing import Any, Dict, List, Optional, Union
 
 from loguru import logger
-
-# pylint: disable=no-name-in-module
 from pydantic import BaseModel, ValidationError
-from pydantic.fields import ModelField
-from pydantic.utils import deep_update
 from yaml import YAMLError, safe_load
 
+from ._env import EnvLoader
+
 DEFAULT_ENV_PREFIX = "LIBRETIME"
 DEFAULT_CONFIG_FILEPATH = Path("/etc/libretime/config.yml")
 
@@ -36,54 +34,19 @@ class BaseConfig(BaseModel):
         if filepath is not None:
             filepath = Path(filepath)
 
-        file_values = self._load_file_values(filepath)
-        env_values = self._load_env_values(env_prefix, env_delimiter)
+        env_loader = EnvLoader(self.schema(), env_prefix, env_delimiter)
+
+        values = deep_merge_dict(
+            self._load_file_values(filepath),
+            env_loader.load(),
+        )
 
         try:
-            super().__init__(**deep_update(file_values, env_values))
+            super().__init__(**values)
         except ValidationError as error:
             logger.critical(error)
             sys.exit(1)
 
-    def _load_env_values(self, env_prefix: str, env_delimiter: str) -> Dict[str, Any]:
-        return self._get_fields_from_env(env_prefix, env_delimiter, self.__fields__)
-
-    def _get_fields_from_env(
-        self,
-        env_prefix: str,
-        env_delimiter: str,
-        fields: Dict[str, ModelField],
-    ) -> Dict[str, Any]:
-        result: Dict[str, Any] = {}
-
-        if env_prefix != "":
-            env_prefix += env_delimiter
-
-        for field in fields.values():
-            env_name = (env_prefix + field.name).upper()
-
-            if field.is_complex():
-                children: Union[List[Any], Dict[str, Any]] = []
-
-                if field.sub_fields:
-                    if env_name in environ:
-                        children = [v.strip() for v in environ[env_name].split(",")]
-
-                else:
-                    children = self._get_fields_from_env(
-                        env_name,
-                        env_delimiter,
-                        field.type_.__fields__,
-                    )
-
-                if len(children) != 0:
-                    result[field.name] = children
-            else:
-                if env_name in environ:
-                    result[field.name] = environ[env_name]
-
-        return result
-
     def _load_file_values(
         self,
         filepath: Optional[Path] = None,
@@ -102,3 +65,38 @@ class BaseConfig(BaseModel):
             logger.error(f"config file '{filepath}' is not a valid yaml file: {error}")
 
         return {}
+
+
+def deep_merge_dict(base: Dict[str, Any], next_: Dict[str, Any]) -> Dict[str, Any]:
+    result = base.copy()
+    for key, value in next_.items():
+        if key in result:
+            if isinstance(result[key], dict) and isinstance(value, dict):
+                result[key] = deep_merge_dict(result[key], value)
+                continue
+
+            if isinstance(result[key], list) and isinstance(value, list):
+                result[key] = deep_merge_list(result[key], value)
+                continue
+
+        if value:
+            result[key] = value
+
+    return result
+
+
+def deep_merge_list(base: List[Any], next_: List[Any]) -> List[Any]:
+    result: List[Any] = []
+    for base_item, next_item in zip_longest(base, next_):
+        if isinstance(base_item, list) and isinstance(next_item, list):
+            result.append(deep_merge_list(base_item, next_item))
+            continue
+
+        if isinstance(base_item, dict) and isinstance(next_item, dict):
+            result.append(deep_merge_dict(base_item, next_item))
+            continue
+
+        if next_item:
+            result.append(next_item)
+
+    return result
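For illustration only (not part of the patch): a minimal sketch of the merge precedence implemented by `deep_merge_dict` above, assuming the helper is importable from the module shown; the `database`/`api_key` keys are made up for the example.

```python
# Illustrative sketch only, not part of the patch.
from libretime_shared.config._base import deep_merge_dict

file_values = {"database": {"host": "localhost", "port": 5432}, "api_key": "from-file"}
env_values = {"database": {"host": "db.example.com"}, "api_key": ""}

# Nested dicts are merged recursively, non-empty env values take precedence,
# and falsy env values (here the empty api_key) do not override file values.
assert deep_merge_dict(file_values, env_values) == {
    "database": {"host": "db.example.com", "port": 5432},
    "api_key": "from-file",
}
```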
diff --git a/shared/libretime_shared/config/_env.py b/shared/libretime_shared/config/_env.py
new file mode 100644
index 000000000..d8a42a649
--- /dev/null
+++ b/shared/libretime_shared/config/_env.py
@@ -0,0 +1,266 @@
+from collections import ChainMap
+from functools import reduce
+from operator import getitem
+from os import environ
+from typing import Any, Dict, List, Optional, TypeVar
+
+__all__ = [
+    "EnvLoader",
+]
+
+
+def filter_env(env: Dict[str, str], prefix: str) -> Dict[str, str]:
+    """
+    Filter an environment variables dict by key prefix.
+
+    Args:
+        env: Environment variables dict.
+        prefix: Environment variable key prefix.
+
+    Returns:
+        Environment variables dict.
+    """
+    return {k: v for k, v in env.items() if k.startswith(prefix)}
+
+
+def guess_env_array_indexes(env: Dict[str, str], prefix: str) -> List[int]:
+    """
+    Guess array indexes from the environment variable keys.
+
+    Args:
+        env: Environment variables dict.
+        prefix: Environment variable key prefix for all indexes.
+
+    Returns:
+        A list of indexes.
+    """
+    prefix_len = len(prefix)
+
+    result = []
+    for env_name in filter_env(env, prefix):
+        if not env_name[prefix_len].isdigit():
+            continue
+
+        index_str = env_name[prefix_len:]
+        index_str = index_str.partition("_")[0]
+        result.append(int(index_str))
+
+    return result
+
+
+T = TypeVar("T")
+
+
+def index_dict_to_none_list(base: Dict[int, T]) -> List[Optional[T]]:
+    """
+    Convert a dict to a list by associating the dict keys to the list
+    indexes and filling the missing indexes with None.
+
+    Args:
+        base: Dict to convert.
+
+    Returns:
+        Converted list.
+    """
+    if not base:
+        return []
+
+    result: List[Optional[T]] = [None] * (max(base.keys()) + 1)
+
+    for index, value in base.items():
+        result[index] = value
+
+    return result
+
+
+# pylint: disable=too-few-public-methods
+class EnvLoader:
+    schema: dict
+
+    env_prefix: str
+    env_delimiter: str
+
+    _env: Dict[str, str]
+
+    def __init__(
+        self,
+        schema: dict,
+        env_prefix: Optional[str] = None,
+        env_delimiter: str = "_",
+    ) -> None:
+        self.schema = schema
+        self.env_prefix = env_prefix or ""
+        self.env_delimiter = env_delimiter
+
+        self._env = environ.copy()
+        if self.env_prefix:
+            self._env = filter_env(self._env, self.env_prefix)
+
+    def load(self) -> Dict[str, Any]:
+        if not self._env:
+            return {}
+
+        return self._get(self.env_prefix, self.schema)
+
+    def _resolve_ref(
+        self,
+        path: str,
+    ) -> Dict[str, Any]:
+        _, *parts = path.split("/")
+        return reduce(getitem, parts, self.schema)
+
+    def _get_mapping(
+        self,
+        env_name: str,
+        *schemas: Dict[str, Any],
+    ) -> Dict[str, Any]:
+        """
+        Get a mapping of each subtype to its data.
+
+        This helps resolve conflicts after we have all the data.
+
+        Args:
+            env_name: Environment variable name to get the data from.
+
+        Returns:
+            Mapping of each subtype, with the associated data as value.
+        """
+        mapping: Dict[str, Any] = {}
+
+        for schema in schemas:
+            if "$ref" in schema:
+                schema = self._resolve_ref(schema["$ref"])
+
+            value = self._get(env_name, schema)
+            if not value:
+                continue
+
+            key = "title" if "title" in schema else "type"
+            mapping[schema[key]] = value
+
+        return mapping
+
+    # pylint: disable=too-many-return-statements
+    def _get(
+        self,
+        env_name: str,
+        schema: Dict[str, Any],
+    ) -> Any:
+        """
+        Get a value from the environment.
+
+        Args:
+            env_name: Environment variable name.
+            schema: Schema for the value we are retrieving.
+
+        Returns:
+            Value retrieved from the environment.
+ """ + + if "$ref" in schema: + schema = self._resolve_ref(schema["$ref"]) + + if "type" in schema: + if schema["type"] in ("string", "integer", "boolean"): + return self._env.get(env_name, None) + + if schema["type"] == "object": + return self._get_object(env_name, schema) + + if schema["type"] == "array": + return self._get_array(env_name, schema) + + # Get all the properties as we won't have typing conflicts + if "allOf" in schema: + all_of_mapping = self._get_mapping(env_name, *schema["allOf"]) + # Merging all subtypes data together + return dict(ChainMap(*all_of_mapping.values())) + + # Get all the properties and resolve conflicts after + if "anyOf" in schema: + any_of_mapping = self._get_mapping(env_name, *schema["anyOf"]) + if any_of_mapping: + any_of_values = list(any_of_mapping.values()) + + # If all subtypes are primary types, return the first subtype data + if all(isinstance(value, str) for value in any_of_values): + return any_of_values[0] + + # If all subtypes are dicts, merge the subtypes data in a single dict. + # Do not worry if subtypes share a field name, as the value is from a + # single environment variable and will have the same value. + if all(isinstance(value, dict) for value in any_of_values): + return dict(ChainMap(*any_of_values)) + + return None + + raise ValueError(f"{env_name}: unhandled schema {schema}") + + def _get_object( + self, + env_name: str, + schema: Dict[str, Any], + ) -> Dict[str, Any]: + """ + Get an object from the environment. + + Args: + env_name: Environment variable name. + schema: Schema for the value we are retrieving. + + Returns: + Value retrieved from the environment. + """ + result: Dict[str, Any] = {} + + if env_name != "": + env_name += self.env_delimiter + + for child_key, child_schema in schema["properties"].items(): + child_env_name = (env_name + child_key).upper() + + value = self._get(child_env_name, child_schema) + if value: + result[child_key] = value + + return result + + # pylint: disable=too-many-branches + def _get_array( + self, + env_parent: str, + schema: Dict[str, Any], + ) -> Optional[List[Any]]: + """ + Get an array from the environment. + + Args: + env_name: Environment variable name. + schema: Schema for the value we are retrieving. + + Returns: + Value retrieved from the environment. 
+ """ + result: Dict[int, Any] = {} + + schema_items = schema["items"] + if "$ref" in schema_items: + schema_items = self._resolve_ref(schema_items["$ref"]) + + # Found a environment variable without index suffix, try + # to extract CSV formatted array + if env_parent in self._env: + values = self._get(env_parent, schema_items) + if values: + for index, value in enumerate(values.split(",")): + result[index] = value.strip() + + indexes = guess_env_array_indexes(self._env, env_parent + self.env_delimiter) + if indexes: + for index in indexes: + env_name = env_parent + self.env_delimiter + str(index) + value = self._get(env_name, schema_items) + if value: + result[index] = value + + return index_dict_to_none_list(result) diff --git a/shared/tests/config/env_test.py b/shared/tests/config/env_test.py new file mode 100644 index 000000000..bd5fc62a3 --- /dev/null +++ b/shared/tests/config/env_test.py @@ -0,0 +1,189 @@ +# pylint: disable=protected-access +from os import environ +from typing import List, Union +from unittest import mock + +import pytest +from pydantic import BaseModel + +from libretime_shared.config import BaseConfig +from libretime_shared.config._env import EnvLoader + +ENV_SCHEMA_OBJ_WITH_STR = { + "type": "object", + "properties": {"a_str": {"type": "string"}}, +} + + +@pytest.mark.parametrize( + "env_parent, env, schema, expected", + [ + ( + "PRE", + {"PRE_A_STR": "found"}, + {"a_str": {"type": "string"}}, + {"a_str": "found"}, + ), + ( + "PRE", + {"PRE_OBJ_A_STR": "found"}, + {"obj": ENV_SCHEMA_OBJ_WITH_STR}, + {"obj": {"a_str": "found"}}, + ), + ( + "PRE", + {"PRE_ARR1": "one, two"}, + {"arr1": {"type": "array", "items": {"type": "string"}}}, + {"arr1": ["one", "two"]}, + ), + ( + "PRE", + { + "PRE_ARR2_0_A_STR": "one", + "PRE_ARR2_1_A_STR": "two", + "PRE_ARR2_3_A_STR": "ten", + }, + {"arr2": {"type": "array", "items": ENV_SCHEMA_OBJ_WITH_STR}}, + { + "arr2": [ + {"a_str": "one"}, + {"a_str": "two"}, + None, + {"a_str": "ten"}, + ] + }, + ), + ], +) +def test_env_config_loader_get_object( + env_parent, + env, + schema, + expected, +): + with mock.patch.dict(environ, env): + loader = EnvLoader(schema={}, env_prefix="PRE") + result = loader._get_object(env_parent, {"properties": schema}) + assert result == expected + + +class FirstChildConfig(BaseModel): + a_child_str: str + + +class SecondChildConfig(BaseModel): + a_child_str: str + a_child_int: int + + +# pylint: disable=too-few-public-methods +class FixtureConfig(BaseConfig): + a_str: str + a_list_of_str: List[str] + a_obj: FirstChildConfig + a_obj_with_default: FirstChildConfig = FirstChildConfig(a_child_str="default") + a_list_of_obj: List[FirstChildConfig] + a_union_str_or_int: Union[str, int] + a_union_obj: Union[FirstChildConfig, SecondChildConfig] + a_list_of_union_str_or_int: List[Union[str, int]] + a_list_of_union_obj: List[Union[FirstChildConfig, SecondChildConfig]] + + +ENV_SCHEMA = FixtureConfig.schema() + + +@pytest.mark.parametrize( + "env_name, env, schema, expected", + [ + ( + "PRE_A_STR", + {"PRE_A_STR": "found"}, + ENV_SCHEMA["properties"]["a_str"], + "found", + ), + ( + "PRE_A_LIST_OF_STR", + {"PRE_A_LIST_OF_STR": "one, two"}, + ENV_SCHEMA["properties"]["a_list_of_str"], + ["one", "two"], + ), + ( + "PRE_A_OBJ", + {"PRE_A_OBJ_A_CHILD_STR": "found"}, + ENV_SCHEMA["properties"]["a_obj"], + {"a_child_str": "found"}, + ), + ], +) +def test_env_config_loader_get( + env_name, + env, + schema, + expected, +): + with mock.patch.dict(environ, env): + loader = EnvLoader(schema=ENV_SCHEMA, env_prefix="PRE") + result = 
+        result = loader._get(env_name, schema)
+        assert result == expected
+
+
+def test_env_config_loader_load_empty():
+    with mock.patch.dict(environ, {}):
+        loader = EnvLoader(schema=ENV_SCHEMA, env_prefix="PRE")
+        result = loader.load()
+        assert not result
+
+
+def test_env_config_loader_load():
+    with mock.patch.dict(
+        environ,
+        {
+            "PRE_A_STR": "found",
+            "PRE_A_LIST_OF_STR": "one, two",
+            "PRE_A_OBJ": "invalid",
+            "PRE_A_OBJ_A_CHILD_STR": "found",
+            "PRE_A_OBJ_WITH_DEFAULT_A_CHILD_STR": "found",
+            "PRE_A_LIST_OF_OBJ": "invalid",
+            "PRE_A_LIST_OF_OBJ_0_A_CHILD_STR": "found",
+            "PRE_A_LIST_OF_OBJ_1_A_CHILD_STR": "found",
+            "PRE_A_LIST_OF_OBJ_3_A_CHILD_STR": "found",
+            "PRE_A_LIST_OF_OBJ_INVALID": "invalid",
+            "PRE_A_UNION_STR_OR_INT": "found",
+            "PRE_A_UNION_OBJ_A_CHILD_STR": "found",
+            "PRE_A_UNION_OBJ_A_CHILD_INT": "found",
+            "PRE_A_LIST_OF_UNION_STR_OR_INT": "one, two, 3",
+            "PRE_A_LIST_OF_UNION_STR_OR_INT_3": "4",
+            "PRE_A_LIST_OF_UNION_OBJ": "invalid",
+            "PRE_A_LIST_OF_UNION_OBJ_0_A_CHILD_STR": "found",
+            "PRE_A_LIST_OF_UNION_OBJ_1_A_CHILD_STR": "found",
+            "PRE_A_LIST_OF_UNION_OBJ_1_A_CHILD_INT": "found",
+            "PRE_A_LIST_OF_UNION_OBJ_3_A_CHILD_INT": "found",
+            "PRE_A_LIST_OF_UNION_OBJ_INVALID": "invalid",
+        },
+    ):
+        loader = EnvLoader(schema=ENV_SCHEMA, env_prefix="PRE")
+        result = loader.load()
+        assert result == {
+            "a_str": "found",
+            "a_list_of_str": ["one", "two"],
+            "a_obj": {"a_child_str": "found"},
+            "a_obj_with_default": {"a_child_str": "found"},
+            "a_list_of_obj": [
+                {"a_child_str": "found"},
+                {"a_child_str": "found"},
+                None,
+                {"a_child_str": "found"},
+            ],
+            "a_union_str_or_int": "found",
+            "a_union_obj": {
+                "a_child_str": "found",
+                "a_child_int": "found",
+            },
+            "a_list_of_union_str_or_int": ["one", "two", "3", "4"],
+            "a_list_of_union_obj": [
+                {"a_child_str": "found"},
+                {"a_child_str": "found", "a_child_int": "found"},
+                None,
+                {"a_child_int": "found"},
+            ],
+        }
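For illustration only (not part of the patch): a hedged sketch of how the new loader is meant to be driven end to end, mirroring the tests above. The `MyConfig`/`DatabaseConfig` classes and the `MYAPP_*` variables are made up for the example.

```python
# Illustrative sketch only; class names and environment variables are hypothetical.
from os import environ
from typing import List

from pydantic import BaseModel

from libretime_shared.config import BaseConfig
from libretime_shared.config._env import EnvLoader


class DatabaseConfig(BaseModel):
    host: str
    port: int


class MyConfig(BaseConfig):
    api_key: str
    database: DatabaseConfig
    allowed_hosts: List[str]


# Nested fields map to delimited variable names; arrays accept a CSV value
# (or an index suffix per item), as exercised by the tests above.
environ.update(
    {
        "MYAPP_API_KEY": "secret",
        "MYAPP_DATABASE_HOST": "db.example.com",
        "MYAPP_DATABASE_PORT": "5432",
        "MYAPP_ALLOWED_HOSTS": "localhost, example.com",
    }
)

loader = EnvLoader(MyConfig.schema(), env_prefix="MYAPP")
assert loader.load() == {
    "api_key": "secret",
    "database": {"host": "db.example.com", "port": "5432"},
    "allowed_hosts": ["localhost", "example.com"],
}
```

Note that the loader returns raw strings (for example `"5432"` for the port); coercion to the annotated types is left to pydantic when `BaseConfig.__init__` validates the merged values.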