From da57595e5b0679ff1139595c367d0d1c9b9e598f Mon Sep 17 00:00:00 2001 From: Carl Baillargeon Date: Tue, 13 Aug 2024 14:18:46 -0400 Subject: [PATCH] Update Ansible anta_workflow plugin --- .../arista/avd/plugins/action/anta_runner.py | 287 ------------- .../avd/plugins/action/anta_workflow.py | 391 ++++++++++++++++++ python-avd/pyavd/_anta/utils/logging_utils.py | 5 +- python-avd/pyavd/_anta/utils/models.py | 55 ++- .../utils/{test_specs.py => test_index.py} | 2 +- python-avd/pyavd/_anta/utils/test_loader.py | 8 +- python-avd/pyavd/_errors/__init__.py | 4 + python-avd/pyavd/get_device_anta_catalog.py | 17 +- 8 files changed, 452 insertions(+), 317 deletions(-) delete mode 100644 ansible_collections/arista/avd/plugins/action/anta_runner.py create mode 100644 ansible_collections/arista/avd/plugins/action/anta_workflow.py rename python-avd/pyavd/_anta/utils/{test_specs.py => test_index.py} (98%) diff --git a/ansible_collections/arista/avd/plugins/action/anta_runner.py b/ansible_collections/arista/avd/plugins/action/anta_runner.py deleted file mode 100644 index bf13755e213..00000000000 --- a/ansible_collections/arista/avd/plugins/action/anta_runner.py +++ /dev/null @@ -1,287 +0,0 @@ -# Copyright (c) 2023-2024 Arista Networks, Inc. -# Use of this source code is governed by the Apache License 2.0 -# that can be found in the LICENSE file. -from __future__ import annotations - -import json -import logging -from asyncio import run -from collections import defaultdict -from pathlib import Path -from typing import TYPE_CHECKING, Any - -import yaml -from ansible.errors import AnsibleActionFail -from ansible.plugins.action import ActionBase, display - -from ansible_collections.arista.avd.plugins.plugin_utils.utils import PythonToAnsibleHandler - -if TYPE_CHECKING: - from collections.abc import Mapping - -PLUGIN_NAME = "arista.avd.anta_runner" - -try: - from pyavd._anta.lib import AntaCatalog, AntaInventory, AsyncEOSDevice, ResultManager, anta_runner, setup_logging - from pyavd._utils import get, strip_empties_from_dict - from pyavd.get_device_anta_catalog import get_device_anta_catalog, get_fabric_data - - HAS_PYAVD = True -except ImportError: - HAS_PYAVD = False - -LOGGER = logging.getLogger("ansible_collections.arista.avd") -# ANTA currently add some RichHandler to the root logger so need to disable propagation -LOGGER.propagate = False -LOGGING_LEVELS = ["DEBUG", "INFO", "ERROR", "WARNING", "CRITICAL"] - - -ARGUMENT_SPEC = { - "structured_config_dir": {"type": "str", "required": True}, - "structured_config_suffix": {"type": "str", "default": "yml", "choices": ["yml", "yaml", "json"]}, - "device_list": {"type": "list", "elements": "str", "required": True}, - "custom_anta_catalog": { - "type": "dict", - "options": { - "directory": {"type": "str", "required": True}, - "overwrite": {"type": "bool", "default": False}, - }, - }, - "skip_tests": { - "type": "dict", - "default": {}, - "options": { - "all": { - "type": "list", - "elements": "str", - }, - "device_specific": { - "type": "list", - "elements": "dict", - "options": { - "device": {"type": "str", "required": True}, - "tests": {"type": "list", "elements": "str", "required": True}, - }, - }, - }, - }, - "anta_log_settings": { - "type": "dict", - "options": { - "log_level": {"type": "str", "default": "WARNING", "choices": LOGGING_LEVELS}, - "log_file": {"type": "str"}, - }, - }, - "anta_global_settings": { - "type": "dict", - "options": { - "timeout": {"type": "float", "default": 30.0}, - "disable_cache": {"type": "bool", "default": False}, - }, - }, -} - - -class ActionModule(ActionBase): - def run(self, tmp: Any = None, task_vars: dict | None = None) -> dict: - self._supports_check_mode = False - - if task_vars is None: - task_vars = {} - - result = super().run(tmp, task_vars) - del tmp # tmp no longer has any effect - - if not HAS_PYAVD: - msg = f"The {PLUGIN_NAME} plugin requires the 'pyavd' Python library. Got import error" - raise AnsibleActionFail(msg) - - # Setup module logging - setup_module_logging(result) - - # Setup variables - hostvars = task_vars["hostvars"] - - # Get task arguments and validate them - validated_args = strip_empties_from_dict(self._task.args) - validation_result, validated_args = self.validate_argument_spec(ARGUMENT_SPEC) - validated_args = strip_empties_from_dict(validated_args) - - # Converting to json and back to remove any AnsibeUnsafe types - validated_args = json.loads(json.dumps(validated_args)) - - # Launch the run_anta coroutine to run everything - return run(self.run_anta(validated_args, hostvars, result)) - - async def run_anta(self, validated_args: dict, hostvars: Mapping, result: dict) -> dict: - # Setup logging for ANTA - log_file = get(validated_args, "anta_log_settings.log_file") - if log_file: - log_file = Path(log_file) - log_level = get(validated_args, "anta_log_settings.log_level") - setup_logging(level=log_level, file=log_file) - - # Build the required ANTA objects - result_manager, inventory, catalog = self.build_objects( - device_list=get(validated_args, "device_list"), - hostvars=hostvars, - structured_config_dir=get(validated_args, "structured_config_dir"), - structured_config_suffix=get(validated_args, "structured_config_suffix"), - custom_anta_catalog=get(validated_args, "custom_anta_catalog"), - skip_tests=get(validated_args, "skip_tests"), - anta_global_settings=get(validated_args, "anta_global_settings"), - ) - - await anta_runner(result_manager, inventory, catalog) - - # TODO: Do something useful with the results - LOGGER.info("ANTA run completed; total tests: %s", len(result_manager.results)) - LOGGER.info("ANTA run results: %s", result_manager.json) - - return result - - def build_objects( - self, - device_list: list[str], - hostvars: Mapping, - structured_config_dir: str, - structured_config_suffix: str, - custom_anta_catalog: dict | None, - skip_tests: dict | None, - anta_global_settings: dict, - ) -> tuple[ResultManager, AntaInventory, AntaCatalog]: - # Initialize the ANTA objects - final_inventory = AntaInventory() - final_catalog = AntaCatalog() - result_manager = ResultManager() - - # Check if the default catalog should be overwritten by the custom ANTA catalogs - overwrite_default_catalog = custom_anta_catalog and custom_anta_catalog["overwrite"] - - # Load custom ANTA catalogs if provided - if custom_anta_catalog: - final_catalog = final_catalog.merge(self.load_custom_anta_catalogs(custom_anta_catalog["directory"])) - - # Parse the skip tests once for all device. Not required when overwriting the default catalog - parsed_skip_tests = self.parse_skip_tests(device_list, skip_tests) if skip_tests and not overwrite_default_catalog else {} - - # Load the connection settings for each device. Structured configs are also loaded when not overwriting the default catalog - device_connection_settings, device_structured_configs = self.load_devices_data( - device_list, hostvars, structured_config_dir, structured_config_suffix, connection_settings_only=overwrite_default_catalog - ) - - # Create the fabric data object. When overwriting the default catalog, `fabric_data` will be empty and not used - fabric_data = get_fabric_data(device_structured_configs, logger=LOGGER) - - # Update the ANTA inventory and catalog for each device - for device, connection_settings in device_connection_settings.items(): - # Add global settings to the connection settings - connection_settings.update(anta_global_settings) - anta_device = AsyncEOSDevice(name=device, **connection_settings) - final_inventory.add_device(anta_device) - - # Skip adding the device catalog if the default catalog should be overwritten - if overwrite_default_catalog: - continue - - device_skip_tests = get(parsed_skip_tests, device) - device_catalog = get_device_anta_catalog(device, fabric_data, skip_tests=device_skip_tests, logger=LOGGER) - final_catalog = final_catalog.merge(device_catalog) - - return result_manager, final_inventory, final_catalog - - def parse_skip_tests(self, device_list: list[str], skip_tests: dict) -> dict: - parsed_skip_tests = defaultdict(set) - - # Handle tests to skip for all devices - skip_for_all = set(get(skip_tests, "all", default=[])) - for device in device_list: - parsed_skip_tests[device].update(skip_for_all) - - # Handle device-specific tests to skip - for item in get(skip_tests, "device_specific", default=[]): - device = item["device"] - if device in device_list: - parsed_skip_tests[device].update(item["tests"]) - - return dict(parsed_skip_tests) - - def load_custom_anta_catalogs(self, custom_anta_catalogs_dir: str) -> AntaCatalog: - # Tests from custom catalogs tagged by the device name will be honored, i.e. they will run only on the device with the same name - # TODO: Other tags will be ignored since we don't have a way (yet) to add these tags to the devices in the Ansible inventory - final_catalog = AntaCatalog() - for path_obj in Path(custom_anta_catalogs_dir).iterdir(): - if path_obj.is_file() and path_obj.suffix.lower() in {".yml", ".yaml"}: - LOGGER.info("Loading custom ANTA catalog from %s", path_obj) - catalog = AntaCatalog.parse(path_obj) - final_catalog = final_catalog.merge(catalog) - return final_catalog - - def load_devices_data( - self, device_list: list[str], hostvars: Mapping, structured_config_dir: str, structured_config_suffix: str, *, connection_settings_only: bool = False - ) -> tuple[dict, dict]: - structured_configs = {} - connection_settings = {} - - for device in device_list: - # Load the connection settings for each device from Ansible hostvars - if device not in hostvars: - LOGGER.warning("Device %s not found in inventory. Skipping device", device) - continue - - device_connection_settings = self.get_connection_settings(device, hostvars[device]) - if device_connection_settings: - connection_settings[device] = device_connection_settings - - # Skip loading the structured config if only connection settings are needed - if connection_settings_only: - continue - - # Load the structured config for each device - structured_config = self.load_structured_config(device, structured_config_dir, structured_config_suffix) - - # Skip devices that are not deployed - if structured_config.get("is_deployed", True) is False: - LOGGER.info("Device %s `is_deployed` key is set to False. Skipping device", device) - continue - - structured_configs[device] = structured_config - - return (connection_settings, structured_configs) - - def load_structured_config(self, device: str, structured_config_dir: str, structured_config_suffix: str) -> dict: - config_path = Path(structured_config_dir, f"{device}.{structured_config_suffix}") - with config_path.open(mode="r", encoding="UTF-8") as stream: - if structured_config_suffix in {"yml", "yaml"}: - return yaml.load(stream, Loader=yaml.CSafeLoader) - return json.load(stream) - - def get_connection_settings(self, device: str, host_hostvars: dict) -> dict: - # Following Ansible HTTPAPI connection plugin settings - # https://docs.ansible.com/ansible/latest/collections/ansible/netcommon/httpapi_connection.html - if get(host_hostvars, "ansible_connection") != "httpapi": - LOGGER.warning("Device %s is not using httpapi connection plugin, can't guarantee the connection settings. Skipping device", device) - return {} - - return { - "host": get(host_hostvars, "ansible_host", default=get(host_hostvars, "inventory_hostname")), - "username": get(host_hostvars, "ansible_user"), - "password": get(host_hostvars, "ansible_password"), - "enable": get(host_hostvars, "ansible_become", default=False), - "enable_password": get(host_hostvars, "ansible_become_password"), - "port": get(host_hostvars, "ansible_httpapi_port", default=(80 if get(host_hostvars, "ansible_httpapi_use_ssl", default=False) is False else 443)), - } - - -def setup_module_logging(result: dict) -> None: - """ - Add a Handler to copy the logs from the plugin into Ansible output based on their level. - - Parameters: - result: The dictionary used for the Ansible module results - """ - python_to_ansible_handler = PythonToAnsibleHandler(result, display) - LOGGER.addHandler(python_to_ansible_handler) - - # Set level to DEBUG to be able to see logs with `-v` and `-vvv` - LOGGER.setLevel(logging.DEBUG) diff --git a/ansible_collections/arista/avd/plugins/action/anta_workflow.py b/ansible_collections/arista/avd/plugins/action/anta_workflow.py new file mode 100644 index 00000000000..8abcfd58f0e --- /dev/null +++ b/ansible_collections/arista/avd/plugins/action/anta_workflow.py @@ -0,0 +1,391 @@ +# Copyright (c) 2023-2024 Arista Networks, Inc. +# Use of this source code is governed by the Apache License 2.0 +# that can be found in the LICENSE file. +from __future__ import annotations + +import json +import logging +from asyncio import run +from collections import defaultdict +from pathlib import Path +from typing import TYPE_CHECKING, Any + +import yaml +from ansible.errors import AnsibleActionFail +from ansible.plugins.action import ActionBase, display + +from ansible_collections.arista.avd.plugins.plugin_utils.utils import PythonToAnsibleHandler + +if TYPE_CHECKING: + from collections.abc import Mapping + +PLUGIN_NAME = "arista.avd.anta_workflow" + +try: + from pyavd._anta.lib import AntaCatalog, AntaInventory, AsyncEOSDevice, ResultManager, anta_runner, setup_logging + from pyavd._utils import default, get, strip_empties_from_dict + from pyavd.get_device_anta_catalog import get_device_anta_catalog, get_fabric_data + + HAS_PYAVD = True +except ImportError: + HAS_PYAVD = False + +LOGGER = logging.getLogger("ansible_collections.arista.avd") +# ANTA uses RichHandler at the root logger. Disabling propagation to avoid duplicate logs +LOGGER.propagate = False +LOGGING_LEVELS = ["DEBUG", "INFO", "ERROR", "WARNING", "CRITICAL"] + +ANSIBLE_HTTPAPI_CONNECTION_DOC = "https://docs.ansible.com/ansible/latest/collections/ansible/netcommon/httpapi_connection.html" + +ARGUMENT_SPEC = { + "structured_config_dir": {"type": "str", "required": True}, + "structured_config_suffix": {"type": "str", "default": "yml", "choices": ["yml", "yaml", "json"]}, + "device_list": {"type": "list", "elements": "str", "required": True}, + "custom_anta_catalogs": { + "type": "dict", + "options": { + "directory": {"type": "str", "required": True}, + "overwrite_default": {"type": "bool", "default": False}, + }, + }, + "skip_tests": { + "type": "dict", + "options": { + "all_devices": { + "type": "list", + "elements": "str", + }, + "per_device": { + "type": "list", + "elements": "dict", + "options": { + "device": {"type": "str", "required": True}, + "tests": {"type": "list", "elements": "str", "required": True}, + }, + }, + }, + }, + "anta_logging": { + "type": "dict", + "options": { + "log_level": {"type": "str", "default": "WARNING", "choices": LOGGING_LEVELS}, + "log_file": {"type": "str"}, + }, + }, + "anta_global_settings": { + "type": "dict", + "options": { + "timeout": {"type": "float", "default": 30.0}, + "disable_cache": {"type": "bool", "default": False}, + }, + }, +} + + +class ActionModule(ActionBase): + def run(self, tmp: Any = None, task_vars: dict | None = None) -> dict: + self._supports_check_mode = False + + if task_vars is None: + task_vars = {} + + result = super().run(tmp, task_vars) + del tmp # tmp no longer has any effect + + if not HAS_PYAVD: + msg = f"The {PLUGIN_NAME} plugin requires the 'pyavd' Python library. Got import error" + raise AnsibleActionFail(msg) + + # Setup module logging + setup_module_logging(result) + + # Setup variables + hostvars = task_vars["hostvars"] + + # Get task arguments and validate them + validated_args = strip_empties_from_dict(self._task.args) + validation_result, validated_args = self.validate_argument_spec(ARGUMENT_SPEC) + validated_args = strip_empties_from_dict(validated_args) + + # Converting to json and back to remove any AnsibeUnsafe types + validated_args = json.loads(json.dumps(validated_args)) + + # Launch the run_anta coroutine to run everything + return run(self.run_anta(validated_args, hostvars, result)) + + async def run_anta(self, validated_args: dict, hostvars: Mapping, result: dict) -> dict: + """Main coroutine to run the ANTA workflow. + + Parameters + ---------- + validated_args: The validated plugin arguments. + hostvars: The Ansible hostvars object containing all variables of each device. + result: The dictionary used for the Ansible module results. + + Returns: + ------- + dict: The updated Ansible module result dictionary. + """ + # Setup logging for ANTA + log_file = get(validated_args, "anta_logging.log_file") + if log_file: + log_file = Path(log_file) + log_level = get(validated_args, "anta_logging.log_level") + setup_logging(level=log_level, file=log_file) + + # Build the connection settings for each device using Ansible hostvars + device_list = get(validated_args, "device_list") + global_settings = get(validated_args, "anta_global_settings") + connection_settings = self.build_connection_settings(device_list, hostvars, global_settings) + + # Build the required ANTA objects + result_manager, inventory, catalog = self.build_objects( + device_list=device_list, + connection_settings=connection_settings, + structured_config_dir=get(validated_args, "structured_config_dir"), + structured_config_suffix=get(validated_args, "structured_config_suffix"), + custom_anta_catalogs=get(validated_args, "custom_anta_catalogs"), + skip_tests=get(validated_args, "skip_tests"), + ) + + await anta_runner(result_manager, inventory, catalog) + + # TODO: Do something useful with the results (reporting, etc.) + LOGGER.info("ANTA run completed; total tests: %s", len(result_manager.results)) + LOGGER.info("ANTA run results: %s", result_manager.json) + + return result + + def build_objects( + self, + device_list: list[str], + connection_settings: dict, + structured_config_dir: str, + structured_config_suffix: str, + custom_anta_catalogs: dict | None, + skip_tests: dict | None, + ) -> tuple[ResultManager, AntaInventory, AntaCatalog]: + """Build the objects required to run the ANTA. + + Parameters + ---------- + device_list: The list of device names. + connection_settings: The connection settings for each device to create the `AsyncEOSDevice` objects. + structured_config_dir: The directory where the structured configurations are stored. + structured_config_suffix: The suffix of the structured configuration files (yml, yaml, json). + custom_anta_catalogs: Optional custom ANTA catalogs input dictionary containing the directory and overwrite flag. + skip_tests: Optional skip_tests input dictionary containing tests to skip for all devices and per device. + + Returns: + ------- + tuple: A tuple containing the ResultManager, AntaInventory, and AntaCatalog ANTA objects. + + # NOTE: Tests from custom catalogs tagged by the device name will be honored, i.e. they will run only on the device with the same name + # TODO: Other tags will be ignored since we don't have a way (yet) to add these tags to the devices in the Ansible inventory + """ + # Initialize the ANTA objects + final_inventory = AntaInventory() + final_catalog = AntaCatalog() + result_manager = ResultManager() + + # Determine if we're overwriting the default catalog + overwrite_default_catalog = custom_anta_catalogs and custom_anta_catalogs["overwrite_default"] + + # Load custom ANTA catalogs if provided + if custom_anta_catalogs: + final_catalog = final_catalog.merge(self.load_custom_anta_catalogs(custom_anta_catalogs["directory"])) + + # Only process default catalog data if not overwriting the default catalog + if not overwrite_default_catalog: + parsed_skip_tests = self.parse_skip_tests(device_list, skip_tests) if skip_tests else {} + structured_configs = self.load_structured_configs(device_list, structured_config_dir, structured_config_suffix) + else: + LOGGER.info("Overwriting the default catalog with custom catalogs") + parsed_skip_tests = {} + structured_configs = {} + + # Create the fabric data from the structured configs. Will be empty and not used if overwriting the default catalog + fabric_data = get_fabric_data(structured_configs, logger=LOGGER) + + # Update the ANTA inventory and catalog for each device + for device, settings in connection_settings.items(): + anta_device = AsyncEOSDevice(name=device, **settings) + final_inventory.add_device(anta_device) + + # Skip adding the device catalog if the default catalog should be overwritten + if not overwrite_default_catalog: + device_skip_tests = get(parsed_skip_tests, device) + device_catalog = get_device_anta_catalog(device, fabric_data, skip_tests=device_skip_tests, logger=LOGGER) + final_catalog = final_catalog.merge(device_catalog) + + return result_manager, final_inventory, final_catalog + + def build_connection_settings(self, device_list: list[str], hostvars: Mapping, global_settings: dict) -> dict: + """Build the connection settings for each device using the Ansible hostvars. + + Parameters + ---------- + device_list: The list of device names. + hostvars: The Ansible hostvars object containing all variables of each device. + global_settings: The global settings dictionary from the plugin arguments. + + Returns: + ------- + dict: A dictionary with device names as keys and connection settings as values to create the `AsyncEOSDevice` objects. + """ + connection_settings = {} + + # Required settings to create the ANTA device object + required_settings = ["host", "username", "password"] + + for device in device_list: + if device not in hostvars: + LOGGER.warning("Device %s not found in Ansible inventory. Skipping device", device) + continue + + device_vars = hostvars[device] + device_connection_settings = { + "host": get(device_vars, "ansible_host", default=get(device_vars, "inventory_hostname")), + "username": get(device_vars, "ansible_user"), + "password": default( + get(device_vars, "ansible_password"), get(device_vars, "ansible_httpapi_pass"), get(device_vars, "ansible_httpapi_password") + ), + "enable": get(device_vars, "ansible_become", default=False), + "enable_password": get(device_vars, "ansible_become_password"), + "port": get(device_vars, "ansible_httpapi_port", default=(80 if get(device_vars, "ansible_httpapi_use_ssl", default=False) is False else 443)), + "timeout": get(global_settings, "timeout"), + "disable_cache": get(global_settings, "disable_cache"), + } + + # Make sure we found all required connection settings. Other settings have defaults in the ANTA device object + if any(value is None for key, value in device_connection_settings.items() if key in required_settings): + msg = ( + f"Device {device} is missing required connection settings. Skipping device. " + f"Please make sure all required connection variables are defined in the Ansible inventory, " + f"following the Ansible HTTPAPI connection plugin settings: {ANSIBLE_HTTPAPI_CONNECTION_DOC}" + ) + LOGGER.warning(msg) + continue + + connection_settings[device] = device_connection_settings + + return connection_settings + + def parse_skip_tests(self, device_list: list[str], skip_tests: dict) -> dict: + """Parse the skip_tests input dictionary. + + Parameters + ---------- + device_list: The list of device names. + skip_tests: The skip_tests input dictionary from the plugin arguments. + + Returns: + ------- + dict: A dictionary with device names as keys and the tests to skip as values. + """ + parsed_skip_tests = defaultdict(set) + + # Handle tests to skip for all devices + skip_for_all = set(get(skip_tests, "all_devices", default=[])) + for device in device_list: + parsed_skip_tests[device].update(skip_for_all) + + # Handle device-specific tests to skip + for item in get(skip_tests, "per_device", default=[]): + device = item["device"] + if device in device_list: + parsed_skip_tests[device].update(item["tests"]) + + return dict(parsed_skip_tests) + + def load_custom_anta_catalogs(self, custom_anta_catalogs_dir: str) -> AntaCatalog: + """Load custom ANTA catalogs from the provided directory. + + Supports YAML files only. + + Parameters + ---------- + custom_anta_catalogs_dir: The directory where the custom ANTA catalogs are stored. + + Returns: + ------- + AntaCatalog: Instance of the merged custom ANTA catalogs. + """ + custom_catalog = AntaCatalog() + for path_obj in Path(custom_anta_catalogs_dir).iterdir(): + if path_obj.is_file() and path_obj.suffix.lower() in {".yml", ".yaml"}: + # Error handling is done in ANTA + LOGGER.info("Loading custom ANTA catalog from %s", path_obj) + catalog = AntaCatalog.parse(path_obj) + custom_catalog = custom_catalog.merge(catalog) + + return custom_catalog + + def load_structured_configs(self, device_list: list[str], structured_config_dir: str, structured_config_suffix: str) -> dict: + """Load the structured configurations for the devices in the provided list from the given directory. + + Parameters + ---------- + device_list: The list of device names. + structured_config_dir: The directory where the structured configurations are stored. + structured_config_suffix: The suffix of the structured configuration files (yml, yaml, json). + + Returns: + ------- + dict: A dictionary with the device names as keys and the structured configurations as values. + """ + structured_configs = {} + for device in device_list: + try: + structured_config = self.load_device_structured_config(device, structured_config_dir, structured_config_suffix) + except FileNotFoundError: + LOGGER.warning("Structured configuration file for device %s not found. Skipping device", device) + continue + except (OSError, yaml.YAMLError, json.JSONDecodeError) as exc: + LOGGER.warning("Error loading structured configuration for device %s: %s. Skipping device", device, str(exc)) + continue + + # Skip devices that are not deployed + if structured_config.get("is_deployed", True) is False: + LOGGER.info("Device %s `is_deployed` key is set to False. Skipping device", device) + continue + + structured_configs[device] = structured_config + + return structured_configs + + def load_device_structured_config(self, device: str, structured_config_dir: str, structured_config_suffix: str) -> dict: + """Load the structured configuration for a device from the provided directory. + + Parameters + ---------- + device: The name of the device. + structured_config_dir: The directory where the structured configurations are stored. + structured_config_suffix: The suffix of the structured configuration files (yml, yaml, json). + + Returns: + ------- + dict: The structured configuration for the device. + """ + config_path = Path(structured_config_dir, f"{device}.{structured_config_suffix}") + with config_path.open(mode="r", encoding="UTF-8") as stream: + if structured_config_suffix in {"yml", "yaml"}: + return yaml.load(stream, Loader=yaml.CSafeLoader) + return json.load(stream) + + +def setup_module_logging(result: dict) -> None: + """Add a Handler to copy the logs from the plugin into Ansible output based on their level. + + Parameters + ---------- + result: The dictionary used for the Ansible module results. + """ + python_to_ansible_handler = PythonToAnsibleHandler(result, display) + LOGGER.addHandler(python_to_ansible_handler) + + # Set the logging level based on the Ansible verbosity level + if display.verbosity >= 3: + LOGGER.setLevel(logging.DEBUG) + elif display.verbosity >= 1: + LOGGER.setLevel(logging.INFO) diff --git a/python-avd/pyavd/_anta/utils/logging_utils.py b/python-avd/pyavd/_anta/utils/logging_utils.py index c7cbc22d5ca..51bd0caa449 100644 --- a/python-avd/pyavd/_anta/utils/logging_utils.py +++ b/python-avd/pyavd/_anta/utils/logging_utils.py @@ -28,7 +28,10 @@ class TestLoggerAdapter(logging.LoggerAdapter): """ def process(self, msg: object, kwargs: dict) -> tuple[str, dict]: - """Process the message and kwargs before logging.""" + """Process the message and kwargs before logging. + + TODO: Simplify this. + """ # Keep the extra dict in kwargs to pass it to the formatter if needed (following the standard LoggerAdapter behavior) kwargs["extra"] = self.extra diff --git a/python-avd/pyavd/_anta/utils/models.py b/python-avd/pyavd/_anta/utils/models.py index a0bf09b33da..6a37bd324ce 100644 --- a/python-avd/pyavd/_anta/utils/models.py +++ b/python-avd/pyavd/_anta/utils/models.py @@ -4,8 +4,7 @@ from __future__ import annotations from collections import defaultdict -from dataclasses import dataclass, field -from ipaddress import IPv4Address, ip_interface +from ipaddress import ip_interface from typing import TYPE_CHECKING, Self from anta.catalog import AntaTestDefinition @@ -26,11 +25,10 @@ from .logging_utils import TestLoggerAdapter -@dataclass class FabricData: - """FabricData data class. + """FabricData class. - Data class to store the structured configs and mappings for the fabric devices. Used to generate the test inputs. + Class to store the structured configs and mappings for the fabric devices. Used to generate the test inputs. Attributes: ---------- @@ -47,25 +45,34 @@ class FabricData: The logger object to use for logging messages. """ - structured_configs: dict[str, dict] - loopback0_mapping: dict[str, IPv4Address] = field(default_factory=dict, init=False) - vtep_mapping: dict[str, IPv4Address] = field(default_factory=dict, init=False) - combined_mapping: defaultdict[str, list[IPv4Address]] = field(default_factory=lambda: defaultdict(list), init=False) - logger: Logger + def __init__(self, structured_configs: dict[str, dict], logger: Logger) -> None: + """Initialize the FabricData instance.""" + self.structured_configs = structured_configs + self.loopback0_mapping = {} + self.vtep_mapping = {} + self.combined_mapping = defaultdict(list) + self.logger = logger - def __post_init__(self) -> None: - """Post init method to generate the mappings.""" + # Generate the mappings and populate the attributes self._generate_mappings() def _generate_mappings(self) -> None: - """Generate the class mappings.""" + """Generate the mappings.""" for device, config in self.structured_configs.items(): self._process_loopback0(device, config) self._process_vtep(device, config) - def _process_loopback0(self, device: str, config: dict) -> None: - """Process the loopback0 mapping.""" - loopback_interfaces = get(config, "loopback_interfaces", default=[]) + def _process_loopback0(self, device: str, structured_config: dict) -> None: + """Process the loopback0 mapping. + + Populates the loopback0_mapping and combined_mapping attributes. + + Parameters + ---------- + device: The device name. + structured_config: The structured configuration of the device. + """ + loopback_interfaces = get(structured_config, "loopback_interfaces", default=[]) if (loopback0 := get_item(loopback_interfaces, "name", "Loopback0")) is not None and (loopback_ip := get(loopback0, "ip_address")) is not None: ip_obj = ip_interface(loopback_ip).ip self.loopback0_mapping[device] = ip_obj @@ -73,10 +80,18 @@ def _process_loopback0(self, device: str, config: dict) -> None: else: self.logger.warning("<%s>: Loopback0 or IP missing. Some tests will be skipped.", device) - def _process_vtep(self, device: str, config: dict) -> None: - """Process the vtep mapping.""" - loopback_interfaces = get(config, "loopback_interfaces", default=[]) - vtep_interface = get(config, "vxlan_interface.Vxlan1.vxlan.source_interface") + def _process_vtep(self, device: str, structured_config: dict) -> None: + """Process the vtep mapping. + + Populates the vtep_mapping and combined_mapping attributes. + + Parameters + ---------- + device: The device name. + structured_config: The structured configuration of the device. + """ + loopback_interfaces = get(structured_config, "loopback_interfaces", default=[]) + vtep_interface = get(structured_config, "vxlan_interface.Vxlan1.vxlan.source_interface") # NOTE: For now we exclude WAN VTEPs from the vtep_mapping if vtep_interface is not None and "Dps" not in vtep_interface: diff --git a/python-avd/pyavd/_anta/utils/test_specs.py b/python-avd/pyavd/_anta/utils/test_index.py similarity index 98% rename from python-avd/pyavd/_anta/utils/test_specs.py rename to python-avd/pyavd/_anta/utils/test_index.py index bf3a930b200..1a6bc060941 100644 --- a/python-avd/pyavd/_anta/utils/test_specs.py +++ b/python-avd/pyavd/_anta/utils/test_index.py @@ -16,7 +16,7 @@ from .constants import StructuredConfigKey from .models import TestSpec -PYAVD_TEST_SPECS: list[TestSpec] = [ +PYAVD_TEST_INDEX: list[TestSpec] = [ TestSpec( name="VerifyAPIHttpsSSL", conditional_keys=[StructuredConfigKey.HTTPS_SSL_PROFILE], diff --git a/python-avd/pyavd/_anta/utils/test_loader.py b/python-avd/pyavd/_anta/utils/test_loader.py index 5248adf3204..9668320d968 100644 --- a/python-avd/pyavd/_anta/utils/test_loader.py +++ b/python-avd/pyavd/_anta/utils/test_loader.py @@ -9,12 +9,12 @@ from pyavd._utils import load_classes -from .test_specs import PYAVD_TEST_SPECS +from .test_index import PYAVD_TEST_INDEX if TYPE_CHECKING: from anta.models import AntaTest - from .test_specs import TestSpec + from .test_index import TestSpec LOGGER = logging.getLogger("pyavd") @@ -44,9 +44,9 @@ def update_test_spec(test_spec: TestSpec, anta_available_tests: dict[str, type[A return True -PYAVD_TEST_SPECS = list( +PYAVD_TEST_INDEX = list( filter( partial(update_test_spec, anta_available_tests=ANTA_AVAILABLE_TESTS), - PYAVD_TEST_SPECS, + PYAVD_TEST_INDEX, ), ) diff --git a/python-avd/pyavd/_errors/__init__.py b/python-avd/pyavd/_errors/__init__.py index c3c178bd6b5..e509c0c8535 100644 --- a/python-avd/pyavd/_errors/__init__.py +++ b/python-avd/pyavd/_errors/__init__.py @@ -26,6 +26,10 @@ class AristaAvdMissingVariableError(AristaAvdError): pass +class AvdConfigLoadError(AristaAvdError): + pass + + class AvdSchemaError(AristaAvdError): def __init__(self, message: str = "Schema Error", error: jsonschema.ValidationError | None = None) -> None: if isinstance(error, jsonschema.SchemaError): diff --git a/python-avd/pyavd/get_device_anta_catalog.py b/python-avd/pyavd/get_device_anta_catalog.py index c69629c4426..cd4a7756e44 100644 --- a/python-avd/pyavd/get_device_anta_catalog.py +++ b/python-avd/pyavd/get_device_anta_catalog.py @@ -24,6 +24,15 @@ def get_device_anta_catalog( ) -> AntaCatalog: """Generate an ANTA catalog for a single device. + By default, the ANTA catalog will be generated from all tests specified in the PyAVD test index, + located in the `pyavd._anta.utils.test_loader` module. The user can optionally provide a list of + custom TestSpec to be added to the default PyAVD test index and a set of test names to skip. + + When creating test definitions for the catalog, PyAVD will use the FabricData instance containing + the structured configurations of all devices in the fabric. Test definitions can be omitted from + the catalog if the required data is not available for a specific device. You can pass a custom + logger and set the log level to DEBUG to see which test definitions are skipped and the reason why. + Parameters ---------- hostname : str @@ -33,9 +42,9 @@ def get_device_anta_catalog( of all devices in the fabric to generate the catalog. The instance must be created using the `get_fabric_data` function of this module. custom_test_specs : list[TestSpec] - Optional user-defined list of TestSpec to be added to the default PyAVD test specs. + Optional user-defined list of TestSpec to be added to the default PyAVD test index. skip_tests : set[str] - Optional set of test names to skip from the default PyAVD test specs. + Optional set of test names to skip from the default PyAVD test index. logger : logging.Logger Optional logger to use for logging messages. If not provided, the `pyavd` logger will be used. @@ -47,7 +56,7 @@ def get_device_anta_catalog( logger = logger or LOGGER from ._anta.utils import ConfigManager, create_catalog - from ._anta.utils.test_loader import PYAVD_TEST_SPECS + from ._anta.utils.test_loader import PYAVD_TEST_INDEX custom_test_specs = custom_test_specs or [] skip_tests = skip_tests or set() @@ -56,7 +65,7 @@ def get_device_anta_catalog( config_manager = ConfigManager(hostname, fabric_data) # Filter out skipped tests and add custom test specs - filtered_test_specs = [test for test in PYAVD_TEST_SPECS if test.name not in skip_tests] + filtered_test_specs = [test for test in PYAVD_TEST_INDEX if test.name not in skip_tests] filtered_test_specs.extend([test for test in custom_test_specs if test not in filtered_test_specs]) return create_catalog(config_manager, filtered_test_specs, logger=logger)