diff --git a/aea/helpers/env_vars.py b/aea/helpers/env_vars.py index 43f62f4b5b..d1b865a8ae 100644 --- a/aea/helpers/env_vars.py +++ b/aea/helpers/env_vars.py @@ -1,7 +1,7 @@ # -*- coding: utf-8 -*- # ------------------------------------------------------------------------------ # -# Copyright 2022-2023 Valory AG +# Copyright 2022-2024 Valory AG # Copyright 2018-2019 Fetch.AI Limited # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -34,6 +34,11 @@ ENV_VARIABLE_RE = re.compile(r"^\$\{(([A-Z0-9_]+):?)?([a-z]+)?(:(.+))?}$") +MODELS = "models" +ARGS = "args" +ARGS_LEVEL_FROM_MODELS = 2 +ARG_LEVEL_FROM_MODELS = ARGS_LEVEL_FROM_MODELS + 1 +RESTRICTION_EXCEPTIONS = frozenset({"setup", "genesis_config"}) def is_env_variable(value: Any) -> bool: @@ -41,10 +46,35 @@ def is_env_variable(value: Any) -> bool: return isinstance(value, str) and bool(ENV_VARIABLE_RE.match(value)) -def export_path_to_env_var_string(export_path: List[str]) -> str: - """Conver export path to environment variable string.""" +def restrict_model_args(export_path: List[str]) -> Tuple[List[str], List[str]]: + """Do not allow more levels than one for a model's argument.""" + restricted = [] + result = [] + for i, current_path in enumerate(export_path): + result.append(current_path) + args_level = i + ARGS_LEVEL_FROM_MODELS + arg_level = i + ARG_LEVEL_FROM_MODELS + if ( + current_path == MODELS + and arg_level < len(export_path) + and export_path[args_level] == ARGS + and export_path[arg_level] not in RESTRICTION_EXCEPTIONS + ): + # do not allow more levels than one for a model's argument + arg_content_level = arg_level + 1 + result.extend(export_path[i + 1 : arg_content_level]) + # store the restricted part of the path + for j in range(arg_content_level, len(export_path)): + restricted.append(export_path[j]) + break + return restricted, result + + +def export_path_to_env_var_string(export_path: List[str]) -> Tuple[List[str], str]: + """Convert export path to environment variable string.""" + restricted, export_path = restrict_model_args(export_path) env_var_string = "_".join(map(str, export_path)) - return env_var_string.upper() + return restricted, env_var_string.upper() NotSet = object() @@ -149,7 +179,7 @@ def apply_env_variables( data, env_variables, default_value, - default_var_name=export_path_to_env_var_string(export_path=path), + default_var_name=export_path_to_env_var_string(export_path=path)[1], ) return data @@ -242,35 +272,76 @@ def is_strict_list(data: Union[List, Tuple]) -> bool: return is_strict +def list_to_nested_dict(lst: list, val: Any) -> dict: + """Convert a list to a nested dict.""" + nested_dict = val + for item in reversed(lst): + nested_dict = {item: nested_dict} + return nested_dict + + +def ensure_dict(dict_: Dict[str, Union[dict, str]]) -> dict: + """Return the given dictionary converting any values which are json strings as dicts.""" + return {k: json.loads(v) for k, v in dict_.items() if isinstance(v, str)} + + +def ensure_json_content(dict_: dict) -> dict: + """Return the given dictionary converting any nested dictionary values as json strings.""" + return {k: json.dumps(v) for k, v in dict_.items() if isinstance(v, dict)} + + +def merge_dicts(a: dict, b: dict) -> dict: + """Merge two dictionaries.""" + # shallow copy of `a` + merged = {**a} + for key, value in b.items(): + if key in merged and isinstance(merged[key], dict) and isinstance(value, dict): + # recursively merge nested dictionaries + merged[key] = merge_dicts(merged[key], value) + else: + # if not a nested dictionary, just take the value from `b` + merged[key] = value + return merged + + def generate_env_vars_recursively( data: Union[Dict, List], export_path: List[str], ) -> Dict: """Generate environment variables recursively.""" - env_var_dict = {} + env_var_dict: Dict[str, Any] = {} if isinstance(data, dict): for key, value in data.items(): - env_var_dict.update( - generate_env_vars_recursively( - data=value, - export_path=[*export_path, key], - ) + res = generate_env_vars_recursively( + data=value, + export_path=[*export_path, key], ) + if res: + env_var = list(res.keys())[0] + if env_var in env_var_dict: + dicts = (ensure_dict(dict_) for dict_ in (env_var_dict, res)) + res = ensure_json_content(merge_dicts(*dicts)) + env_var_dict.update(res) elif isinstance(data, list): if is_strict_list(data=data): - env_var_dict[ - export_path_to_env_var_string(export_path=export_path) - ] = json.dumps(data, separators=(",", ":")) + restricted, path = export_path_to_env_var_string(export_path=export_path) + if restricted: + env_var_dict[path] = json.dumps(list_to_nested_dict(restricted, data)) + else: + env_var_dict[path] = json.dumps(data, separators=(",", ":")) else: for key, value in enumerate(data): - env_var_dict.update( - generate_env_vars_recursively( - data=value, - export_path=[*export_path, key], - ) + res = generate_env_vars_recursively( + data=value, + export_path=[*export_path, key], ) + env_var_dict.update(res) else: - env_var_dict[export_path_to_env_var_string(export_path=export_path)] = data + restricted, path = export_path_to_env_var_string(export_path=export_path) + if restricted: + env_var_dict[path] = json.dumps(list_to_nested_dict(restricted, data)) + else: + env_var_dict[path] = data return env_var_dict diff --git a/docs/api/helpers/env_vars.md b/docs/api/helpers/env_vars.md index d34e6adb8e..ff0705f7db 100644 --- a/docs/api/helpers/env_vars.md +++ b/docs/api/helpers/env_vars.md @@ -14,15 +14,26 @@ def is_env_variable(value: Any) -> bool Check is variable string with env variable pattern. + + +#### restrict`_`model`_`args + +```python +def restrict_model_args(export_path: List[str]) -> Tuple[List[str], List[str]] +``` + +Do not allow more levels than one for a model's argument. + #### export`_`path`_`to`_`env`_`var`_`string ```python -def export_path_to_env_var_string(export_path: List[str]) -> str +def export_path_to_env_var_string( + export_path: List[str]) -> Tuple[List[str], str] ``` -Conver export path to environment variable string. +Convert export path to environment variable string. @@ -115,6 +126,46 @@ parameters: Boolean specifying whether it's a strict list or not + + +#### list`_`to`_`nested`_`dict + +```python +def list_to_nested_dict(lst: list, val: Any) -> dict +``` + +Convert a list to a nested dict. + + + +#### ensure`_`dict + +```python +def ensure_dict(dict_: Dict[str, Union[dict, str]]) -> dict +``` + +Return the given dictionary converting any values which are json strings as dicts. + + + +#### ensure`_`json`_`content + +```python +def ensure_json_content(dict_: dict) -> dict +``` + +Return the given dictionary converting any nested dictionary values as json strings. + + + +#### merge`_`dicts + +```python +def merge_dicts(a: dict, b: dict) -> dict +``` + +Merge two dictionaries. + #### generate`_`env`_`vars`_`recursively diff --git a/tests/test_helpers/test_env_vars.py b/tests/test_helpers/test_env_vars.py index e0fde93864..70a57dd884 100644 --- a/tests/test_helpers/test_env_vars.py +++ b/tests/test_helpers/test_env_vars.py @@ -1,7 +1,7 @@ # -*- coding: utf-8 -*- # ------------------------------------------------------------------------------ # -# Copyright 2022-2023 Valory AG +# Copyright 2022-2024 Valory AG # Copyright 2018-2019 Fetch.AI Limited # # Licensed under the Apache License, Version 2.0 (the "License"); @@ -118,7 +118,7 @@ def test_convert_value_str_to_type(): def test_env_var_string_generator(export_path: List[str], var_string: str) -> None: """Test `export_path_to_env_var_string` method""" - assert export_path_to_env_var_string(export_path=export_path) == var_string + assert export_path_to_env_var_string(export_path=export_path)[1] == var_string @pytest.mark.parametrize(