Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Prepare to add support for service-level dictionary overrides in the open-autonomy #761

Merged
merged 2 commits into from
Oct 3, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
111 changes: 91 additions & 20 deletions aea/helpers/env_vars.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# -*- coding: utf-8 -*-
# ------------------------------------------------------------------------------
#
# Copyright 2022-2023 Valory AG
# Copyright 2022-2024 Valory AG
# Copyright 2018-2019 Fetch.AI Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
Expand Down Expand Up @@ -34,17 +34,47 @@


ENV_VARIABLE_RE = re.compile(r"^\$\{(([A-Z0-9_]+):?)?([a-z]+)?(:(.+))?}$")
MODELS = "models"
ARGS = "args"
ARGS_LEVEL_FROM_MODELS = 2
ARG_LEVEL_FROM_MODELS = ARGS_LEVEL_FROM_MODELS + 1
RESTRICTION_EXCEPTIONS = frozenset({"setup", "genesis_config"})


def is_env_variable(value: Any) -> bool:
"""Check is variable string with env variable pattern."""
return isinstance(value, str) and bool(ENV_VARIABLE_RE.match(value))


def export_path_to_env_var_string(export_path: List[str]) -> str:
"""Conver export path to environment variable string."""
def restrict_model_args(export_path: List[str]) -> Tuple[List[str], List[str]]:
"""Do not allow more levels than one for a model's argument."""
restricted = []
result = []
for i, current_path in enumerate(export_path):
result.append(current_path)
args_level = i + ARGS_LEVEL_FROM_MODELS
arg_level = i + ARG_LEVEL_FROM_MODELS
if (
current_path == MODELS
and arg_level < len(export_path)
and export_path[args_level] == ARGS
and export_path[arg_level] not in RESTRICTION_EXCEPTIONS
):
# do not allow more levels than one for a model's argument
arg_content_level = arg_level + 1
result.extend(export_path[i + 1 : arg_content_level])
# store the restricted part of the path
for j in range(arg_content_level, len(export_path)):
restricted.append(export_path[j])
break
return restricted, result


def export_path_to_env_var_string(export_path: List[str]) -> Tuple[List[str], str]:
"""Convert export path to environment variable string."""
restricted, export_path = restrict_model_args(export_path)
env_var_string = "_".join(map(str, export_path))
return env_var_string.upper()
return restricted, env_var_string.upper()


NotSet = object()
Expand Down Expand Up @@ -149,7 +179,7 @@ def apply_env_variables(
data,
env_variables,
default_value,
default_var_name=export_path_to_env_var_string(export_path=path),
default_var_name=export_path_to_env_var_string(export_path=path)[1],
)

return data
Expand Down Expand Up @@ -242,35 +272,76 @@ def is_strict_list(data: Union[List, Tuple]) -> bool:
return is_strict


def list_to_nested_dict(lst: list, val: Any) -> dict:
"""Convert a list to a nested dict."""
nested_dict = val
for item in reversed(lst):
nested_dict = {item: nested_dict}
return nested_dict


def ensure_dict(dict_: Dict[str, Union[dict, str]]) -> dict:
"""Return the given dictionary converting any values which are json strings as dicts."""
return {k: json.loads(v) for k, v in dict_.items() if isinstance(v, str)}


def ensure_json_content(dict_: dict) -> dict:
"""Return the given dictionary converting any nested dictionary values as json strings."""
return {k: json.dumps(v) for k, v in dict_.items() if isinstance(v, dict)}


def merge_dicts(a: dict, b: dict) -> dict:
"""Merge two dictionaries."""
# shallow copy of `a`
merged = {**a}
for key, value in b.items():
if key in merged and isinstance(merged[key], dict) and isinstance(value, dict):
# recursively merge nested dictionaries
merged[key] = merge_dicts(merged[key], value)
else:
# if not a nested dictionary, just take the value from `b`
merged[key] = value
return merged


def generate_env_vars_recursively(
data: Union[Dict, List],
export_path: List[str],
) -> Dict:
"""Generate environment variables recursively."""
env_var_dict = {}
env_var_dict: Dict[str, Any] = {}

if isinstance(data, dict):
for key, value in data.items():
env_var_dict.update(
generate_env_vars_recursively(
data=value,
export_path=[*export_path, key],
)
res = generate_env_vars_recursively(
data=value,
export_path=[*export_path, key],
)
if res:
env_var = list(res.keys())[0]
if env_var in env_var_dict:
dicts = (ensure_dict(dict_) for dict_ in (env_var_dict, res))
res = ensure_json_content(merge_dicts(*dicts))
env_var_dict.update(res)
elif isinstance(data, list):
if is_strict_list(data=data):
env_var_dict[
export_path_to_env_var_string(export_path=export_path)
] = json.dumps(data, separators=(",", ":"))
restricted, path = export_path_to_env_var_string(export_path=export_path)
if restricted:
env_var_dict[path] = json.dumps(list_to_nested_dict(restricted, data))
else:
env_var_dict[path] = json.dumps(data, separators=(",", ":"))
else:
for key, value in enumerate(data):
env_var_dict.update(
generate_env_vars_recursively(
data=value,
export_path=[*export_path, key],
)
res = generate_env_vars_recursively(
data=value,
export_path=[*export_path, key],
)
env_var_dict.update(res)
else:
env_var_dict[export_path_to_env_var_string(export_path=export_path)] = data
restricted, path = export_path_to_env_var_string(export_path=export_path)
if restricted:
env_var_dict[path] = json.dumps(list_to_nested_dict(restricted, data))
else:
env_var_dict[path] = data

return env_var_dict
55 changes: 53 additions & 2 deletions docs/api/helpers/env_vars.md
Original file line number Diff line number Diff line change
Expand Up @@ -14,15 +14,26 @@ def is_env_variable(value: Any) -> bool

Check is variable string with env variable pattern.

<a id="aea.helpers.env_vars.restrict_model_args"></a>

#### restrict`_`model`_`args

```python
def restrict_model_args(export_path: List[str]) -> Tuple[List[str], List[str]]
```

Do not allow more levels than one for a model's argument.

<a id="aea.helpers.env_vars.export_path_to_env_var_string"></a>

#### export`_`path`_`to`_`env`_`var`_`string

```python
def export_path_to_env_var_string(export_path: List[str]) -> str
def export_path_to_env_var_string(
export_path: List[str]) -> Tuple[List[str], str]
```

Conver export path to environment variable string.
Convert export path to environment variable string.

<a id="aea.helpers.env_vars.parse_list"></a>

Expand Down Expand Up @@ -115,6 +126,46 @@ parameters:

Boolean specifying whether it's a strict list or not

<a id="aea.helpers.env_vars.list_to_nested_dict"></a>

#### list`_`to`_`nested`_`dict

```python
def list_to_nested_dict(lst: list, val: Any) -> dict
```

Convert a list to a nested dict.

<a id="aea.helpers.env_vars.ensure_dict"></a>

#### ensure`_`dict

```python
def ensure_dict(dict_: Dict[str, Union[dict, str]]) -> dict
```

Return the given dictionary converting any values which are json strings as dicts.

<a id="aea.helpers.env_vars.ensure_json_content"></a>

#### ensure`_`json`_`content

```python
def ensure_json_content(dict_: dict) -> dict
```

Return the given dictionary converting any nested dictionary values as json strings.

<a id="aea.helpers.env_vars.merge_dicts"></a>

#### merge`_`dicts

```python
def merge_dicts(a: dict, b: dict) -> dict
```

Merge two dictionaries.

<a id="aea.helpers.env_vars.generate_env_vars_recursively"></a>

#### generate`_`env`_`vars`_`recursively
Expand Down
4 changes: 2 additions & 2 deletions tests/test_helpers/test_env_vars.py
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
# -*- coding: utf-8 -*-
# ------------------------------------------------------------------------------
#
# Copyright 2022-2023 Valory AG
# Copyright 2022-2024 Valory AG
# Copyright 2018-2019 Fetch.AI Limited
#
# Licensed under the Apache License, Version 2.0 (the "License");
Expand Down Expand Up @@ -118,7 +118,7 @@ def test_convert_value_str_to_type():
def test_env_var_string_generator(export_path: List[str], var_string: str) -> None:
"""Test `export_path_to_env_var_string` method"""

assert export_path_to_env_var_string(export_path=export_path) == var_string
assert export_path_to_env_var_string(export_path=export_path)[1] == var_string


@pytest.mark.parametrize(
Expand Down
Loading