diff --git a/.github/workflows/code-testing.yml b/.github/workflows/code-testing.yml index 4a63f5677..3a66c5cf3 100644 --- a/.github/workflows/code-testing.yml +++ b/.github/workflows/code-testing.yml @@ -133,3 +133,20 @@ jobs: run: pip install .[doc] - name: "Build mkdocs documentation offline" run: mkdocs build + benchmarks: + name: Benchmark ANTA for Python 3.12 + runs-on: ubuntu-latest + needs: [test-python] + steps: + - uses: actions/checkout@v4 + - name: Setup Python + uses: actions/setup-python@v5 + with: + python-version: "3.12" + - name: Install dependencies + run: pip install .[dev] + - name: Run benchmarks + uses: CodSpeedHQ/action@v3 + with: + token: ${{ secrets.CODSPEED_TOKEN }} + run: pytest --codspeed --no-cov --log-cli-level INFO tests/benchmark \ No newline at end of file diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index ec89d26e5..ba1e0d8a2 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -69,6 +69,7 @@ repos: - types-pyOpenSSL - pylint_pydantic - pytest + - pytest-codspeed - respx - repo: https://github.com/codespell-project/codespell diff --git a/anta/catalog.py b/anta/catalog.py index ee56639f7..b5a77ad25 100644 --- a/anta/catalog.py +++ b/anta/catalog.py @@ -25,8 +25,14 @@ from anta.models import AntaTest if TYPE_CHECKING: + import sys from types import ModuleType + if sys.version_info >= (3, 11): + from typing import Self + else: + from typing_extensions import Self + logger = logging.getLogger(__name__) # { : [ { : }, ... ] } @@ -123,7 +129,7 @@ def instantiate_inputs( raise ValueError(msg) @model_validator(mode="after") - def check_inputs(self) -> AntaTestDefinition: + def check_inputs(self) -> Self: """Check the `inputs` field typing. The `inputs` class attribute needs to be an instance of the AntaTest.Input subclass defined in the class `test`. diff --git a/anta/tests/field_notices.py b/anta/tests/field_notices.py index 71a11749f..6f98a2c9a 100644 --- a/anta/tests/field_notices.py +++ b/anta/tests/field_notices.py @@ -196,4 +196,4 @@ def test(self) -> None: self.result.is_success("FN72 is mitigated") return # We should never hit this point - self.result.is_error("Error in running test - FixedSystemvrm1 not found") + self.result.is_failure("Error in running test - Component FixedSystemvrm1 not found in 'show version'") diff --git a/anta/tests/interfaces.py b/anta/tests/interfaces.py index 9ff1cf357..32b85d493 100644 --- a/anta/tests/interfaces.py +++ b/anta/tests/interfaces.py @@ -71,7 +71,7 @@ def test(self) -> None: if ((duplex := (interface := interfaces["interfaces"][intf]).get("duplex", None)) is not None and duplex != duplex_full) or ( (members := interface.get("memberInterfaces", None)) is not None and any(stats["duplex"] != duplex_full for stats in members.values()) ): - self.result.is_error(f"Interface {intf} or one of its member interfaces is not Full-Duplex. VerifyInterfaceUtilization has not been implemented.") + self.result.is_failure(f"Interface {intf} or one of its member interfaces is not Full-Duplex. VerifyInterfaceUtilization has not been implemented.") return if (bandwidth := interfaces["interfaces"][intf]["bandwidth"]) == 0: @@ -705,7 +705,7 @@ def test(self) -> None: input_interface_detail = interface break else: - self.result.is_error(f"Could not find `{intf}` in the input interfaces. {GITHUB_SUGGESTION}") + self.result.is_failure(f"Could not find `{intf}` in the input interfaces. {GITHUB_SUGGESTION}") continue input_primary_ip = str(input_interface_detail.primary_ip) diff --git a/anta/tests/mlag.py b/anta/tests/mlag.py index 1d17ab642..c894b98b6 100644 --- a/anta/tests/mlag.py +++ b/anta/tests/mlag.py @@ -123,10 +123,7 @@ class VerifyMlagConfigSanity(AntaTest): def test(self) -> None: """Main test function for VerifyMlagConfigSanity.""" command_output = self.instance_commands[0].json_output - if (mlag_status := get_value(command_output, "mlagActive")) is None: - self.result.is_error(message="Incorrect JSON response - 'mlagActive' state was not found") - return - if mlag_status is False: + if command_output["mlagActive"] is False: self.result.is_skipped("MLAG is disabled") return keys_to_verify = ["globalConfiguration", "interfaceConfiguration"] diff --git a/anta/tests/routing/bgp.py b/anta/tests/routing/bgp.py index 97f919876..a37328608 100644 --- a/anta/tests/routing/bgp.py +++ b/anta/tests/routing/bgp.py @@ -8,7 +8,7 @@ from __future__ import annotations from ipaddress import IPv4Address, IPv4Network, IPv6Address -from typing import Any, ClassVar +from typing import TYPE_CHECKING, Any, ClassVar from pydantic import BaseModel, Field, PositiveInt, model_validator from pydantic.v1.utils import deep_update @@ -18,6 +18,14 @@ from anta.models import AntaCommand, AntaTemplate, AntaTest from anta.tools import get_item, get_value +if TYPE_CHECKING: + import sys + + if sys.version_info >= (3, 11): + from typing import Self + else: + from typing_extensions import Self + def _add_bgp_failures(failures: dict[tuple[str, str | None], dict[str, Any]], afi: Afi, safi: Safi | None, vrf: str, issue: str | dict[str, Any]) -> None: """Add a BGP failure entry to the given `failures` dictionary. @@ -235,7 +243,7 @@ class BgpAfi(BaseModel): """Number of expected BGP peer(s).""" @model_validator(mode="after") - def validate_inputs(self: BaseModel) -> BaseModel: + def validate_inputs(self) -> Self: """Validate the inputs provided to the BgpAfi class. If afi is either ipv4 or ipv6, safi must be provided. @@ -375,7 +383,7 @@ class BgpAfi(BaseModel): """ @model_validator(mode="after") - def validate_inputs(self: BaseModel) -> BaseModel: + def validate_inputs(self) -> Self: """Validate the inputs provided to the BgpAfi class. If afi is either ipv4 or ipv6, safi must be provided. @@ -522,7 +530,7 @@ class BgpAfi(BaseModel): """List of BGP IPv4 or IPv6 peer.""" @model_validator(mode="after") - def validate_inputs(self: BaseModel) -> BaseModel: + def validate_inputs(self) -> Self: """Validate the inputs provided to the BgpAfi class. If afi is either ipv4 or ipv6, safi must be provided and vrf must NOT be all. @@ -1485,7 +1493,7 @@ class BgpPeer(BaseModel): """Outbound route map applied, defaults to None.""" @model_validator(mode="after") - def validate_inputs(self: BaseModel) -> BaseModel: + def validate_inputs(self) -> Self: """Validate the inputs provided to the BgpPeer class. At least one of 'inbound' or 'outbound' route-map must be provided. diff --git a/anta/tests/routing/generic.py b/anta/tests/routing/generic.py index cd9cf0d24..d1322a50d 100644 --- a/anta/tests/routing/generic.py +++ b/anta/tests/routing/generic.py @@ -9,12 +9,21 @@ from functools import cache from ipaddress import IPv4Address, IPv4Interface -from typing import ClassVar, Literal +from typing import TYPE_CHECKING, ClassVar, Literal from pydantic import model_validator +from anta.custom_types import PositiveInteger from anta.models import AntaCommand, AntaTemplate, AntaTest +if TYPE_CHECKING: + import sys + + if sys.version_info >= (3, 11): + from typing import Self + else: + from typing_extensions import Self + class VerifyRoutingProtocolModel(AntaTest): """Verifies the configured routing protocol model is the one we expect. @@ -84,13 +93,13 @@ class VerifyRoutingTableSize(AntaTest): class Input(AntaTest.Input): """Input model for the VerifyRoutingTableSize test.""" - minimum: int + minimum: PositiveInteger """Expected minimum routing table size.""" - maximum: int + maximum: PositiveInteger """Expected maximum routing table size.""" - @model_validator(mode="after") # type: ignore[misc] - def check_min_max(self) -> AntaTest.Input: + @model_validator(mode="after") + def check_min_max(self) -> Self: """Validate that maximum is greater than minimum.""" if self.minimum > self.maximum: msg = f"Minimum {self.minimum} is greater than maximum {self.maximum}" diff --git a/anta/tests/security.py b/anta/tests/security.py index ae5b9bebd..007022dc5 100644 --- a/anta/tests/security.py +++ b/anta/tests/security.py @@ -9,7 +9,7 @@ # mypy: disable-error-code=attr-defined from datetime import datetime, timezone from ipaddress import IPv4Address -from typing import ClassVar +from typing import TYPE_CHECKING, ClassVar, get_args from pydantic import BaseModel, Field, model_validator @@ -17,6 +17,14 @@ from anta.models import AntaCommand, AntaTemplate, AntaTest from anta.tools import get_failed_logs, get_item, get_value +if TYPE_CHECKING: + import sys + + if sys.version_info >= (3, 11): + from typing import Self + else: + from typing_extensions import Self + class VerifySSHStatus(AntaTest): """Verifies if the SSHD agent is disabled in the default VRF. @@ -47,7 +55,7 @@ def test(self) -> None: try: line = next(line for line in command_output.split("\n") if line.startswith("SSHD status")) except StopIteration: - self.result.is_error("Could not find SSH status in returned output.") + self.result.is_failure("Could not find SSH status in returned output.") return status = line.split("is ")[1] @@ -416,19 +424,19 @@ class APISSLCertificate(BaseModel): """The encryption algorithm key size of the certificate.""" @model_validator(mode="after") - def validate_inputs(self: BaseModel) -> BaseModel: + def validate_inputs(self) -> Self: """Validate the key size provided to the APISSLCertificates class. If encryption_algorithm is RSA then key_size should be in {2048, 3072, 4096}. If encryption_algorithm is ECDSA then key_size should be in {256, 384, 521}. """ - if self.encryption_algorithm == "RSA" and self.key_size not in RsaKeySize.__args__: - msg = f"`{self.certificate_name}` key size {self.key_size} is invalid for RSA encryption. Allowed sizes are {RsaKeySize.__args__}." + if self.encryption_algorithm == "RSA" and self.key_size not in get_args(RsaKeySize): + msg = f"`{self.certificate_name}` key size {self.key_size} is invalid for RSA encryption. Allowed sizes are {get_args(RsaKeySize)}." raise ValueError(msg) - if self.encryption_algorithm == "ECDSA" and self.key_size not in EcdsaKeySize.__args__: - msg = f"`{self.certificate_name}` key size {self.key_size} is invalid for ECDSA encryption. Allowed sizes are {EcdsaKeySize.__args__}." + if self.encryption_algorithm == "ECDSA" and self.key_size not in get_args(EcdsaKeySize): + msg = f"`{self.certificate_name}` key size {self.key_size} is invalid for ECDSA encryption. Allowed sizes are {get_args(EcdsaKeySize)}." raise ValueError(msg) return self diff --git a/anta/tests/system.py b/anta/tests/system.py index 486e5e1ed..d620d533b 100644 --- a/anta/tests/system.py +++ b/anta/tests/system.py @@ -89,9 +89,6 @@ class VerifyReloadCause(AntaTest): def test(self) -> None: """Main test function for VerifyReloadCause.""" command_output = self.instance_commands[0].json_output - if "resetCauses" not in command_output: - self.result.is_error(message="No reload causes available") - return if len(command_output["resetCauses"]) == 0: # No reload causes self.result.is_success() diff --git a/pyproject.toml b/pyproject.toml index d874b4edb..80a59e9ea 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -68,6 +68,8 @@ dev = [ "pytest-asyncio>=0.21.1", "pytest-cov>=4.1.0", "pytest-dependency", + "pytest-codspeed>=2.2.0", + "respx", "pytest-html>=3.2.0", "pytest-httpx>=0.30.0", "pytest-metadata>=3.0.0", @@ -171,6 +173,7 @@ render_collapsed = true testpaths = ["tests"] asyncio_mode = "auto" asyncio_default_fixture_loop_scope = "function" +norecursedirs = ["tests/benchmark"] # Do not run performance testing outside of Codspeed filterwarnings = [ # cvprac is raising the next warning "default:pkg_resources is deprecated:DeprecationWarning", @@ -450,13 +453,17 @@ disable = [ # Any rule listed here can be disabled: https://github.com/astral-sh "keyword-arg-before-vararg", "protected-access", "too-many-arguments", - "too-many-positional-arguments", # New in pylint 3.3.0 + "too-many-positional-arguments", "wrong-import-position", "pointless-statement", "broad-exception-caught", "line-too-long", "unused-variable", "redefined-builtin", + "global-statement", + "reimported", + "wrong-import-order", + "wrong-import-position", "abstract-class-instantiated", # Overlap with https://mypy.readthedocs.io/en/stable/error_code_list.html#check-instantiation-of-abstract-classes-abstract "unexpected-keyword-arg", # Overlap with https://mypy.readthedocs.io/en/stable/error_code_list.html#check-arguments-in-calls-call-arg and other rules "no-value-for-parameter" # Overlap with https://mypy.readthedocs.io/en/stable/error_code_list.html#check-arguments-in-calls-call-arg diff --git a/tests/benchmark/__init__.py b/tests/benchmark/__init__.py new file mode 100644 index 000000000..7714c95e7 --- /dev/null +++ b/tests/benchmark/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) 2023-2024 Arista Networks, Inc. +# Use of this source code is governed by the Apache License 2.0 +# that can be found in the LICENSE file. +"""Benchmark tests for ANTA.""" diff --git a/tests/benchmark/conftest.py b/tests/benchmark/conftest.py new file mode 100644 index 000000000..c07cc99c2 --- /dev/null +++ b/tests/benchmark/conftest.py @@ -0,0 +1,40 @@ +# Copyright (c) 2023-2024 Arista Networks, Inc. +# Use of this source code is governed by the Apache License 2.0 +# that can be found in the LICENSE file. +"""Fixtures for benchmarking ANTA.""" + +import logging + +import pytest +import respx +from _pytest.terminal import TerminalReporter + +from anta.catalog import AntaCatalog + +from .utils import AntaMockEnvironment + +logger = logging.getLogger(__name__) + +TEST_CASE_COUNT = None + + +@pytest.fixture(name="anta_mock_env", scope="session") # We want this fixture to have a scope set to session to avoid reparsing all the unit tests data. +def anta_mock_env_fixture() -> AntaMockEnvironment: + """Return an AntaMockEnvironment for this test session. Also configure respx to mock eAPI responses.""" + global TEST_CASE_COUNT # noqa: PLW0603 + eapi_route = respx.post(path="/command-api", headers={"Content-Type": "application/json-rpc"}) + env = AntaMockEnvironment() + TEST_CASE_COUNT = env.tests_count + eapi_route.side_effect = env.eapi_response + return env + + +@pytest.fixture # This fixture should have a scope set to function as the indexing result is stored in this object +def catalog(anta_mock_env: AntaMockEnvironment) -> AntaCatalog: + """Fixture that return an ANTA catalog from the AntaMockEnvironment of this test session.""" + return anta_mock_env.catalog + + +def pytest_terminal_summary(terminalreporter: TerminalReporter) -> None: + """Display the total number of ANTA unit test cases used to benchmark.""" + terminalreporter.write_sep("=", f"{TEST_CASE_COUNT} ANTA test cases") diff --git a/tests/benchmark/test_anta.py b/tests/benchmark/test_anta.py new file mode 100644 index 000000000..82d08cf6e --- /dev/null +++ b/tests/benchmark/test_anta.py @@ -0,0 +1,110 @@ +# Copyright (c) 2023-2024 Arista Networks, Inc. +# Use of this source code is governed by the Apache License 2.0 +# that can be found in the LICENSE file. +"""Benchmark tests for ANTA.""" + +import asyncio +import logging +from unittest.mock import patch + +import pytest +import respx +from pytest_codspeed import BenchmarkFixture + +from anta.catalog import AntaCatalog +from anta.inventory import AntaInventory +from anta.result_manager import ResultManager +from anta.result_manager.models import AntaTestStatus +from anta.runner import main + +from .utils import collect, collect_commands + +logger = logging.getLogger(__name__) + + +@pytest.mark.parametrize( + "inventory", + [ + pytest.param({"count": 1, "disable_cache": True, "reachable": False}, id="1 device"), + pytest.param({"count": 2, "disable_cache": True, "reachable": False}, id="2 devices"), + ], + indirect=True, +) +def test_anta_dry_run(benchmark: BenchmarkFixture, catalog: AntaCatalog, inventory: AntaInventory) -> None: + """Test and benchmark ANTA in Dry-Run Mode.""" + # Disable logging during ANTA execution to avoid having these function time in benchmarks + logging.disable() + + def bench() -> ResultManager: + """Need to wrap the ANTA Runner to instantiate a new ResultManger for each benchmark run.""" + manager = ResultManager() + asyncio.run(main(manager, inventory, catalog, dry_run=True)) + return manager + + manager = benchmark(bench) + + logging.disable(logging.NOTSET) + if len(manager.results) != 0: + pytest.fail("ANTA Dry-Run mode should not return any result", pytrace=False) + if catalog.final_tests_count != len(inventory) * len(catalog.tests): + pytest.fail(f"Expected {len(inventory) * len(catalog.tests)} selected tests but got {catalog.final_tests_count}", pytrace=False) + bench_info = ( + "\n--- ANTA NRFU Dry-Run Benchmark Information ---\n" f"Selected tests: {catalog.final_tests_count}\n" "-----------------------------------------------" + ) + logger.info(bench_info) + + +@pytest.mark.parametrize( + "inventory", + [ + pytest.param({"count": 1, "disable_cache": True}, id="1 device"), + pytest.param({"count": 2, "disable_cache": True}, id="2 devices"), + ], + indirect=True, +) +@patch("anta.models.AntaTest.collect", collect) +@patch("anta.device.AntaDevice.collect_commands", collect_commands) +@respx.mock # Mock eAPI responses +def test_anta(benchmark: BenchmarkFixture, catalog: AntaCatalog, inventory: AntaInventory) -> None: + """Test and benchmark ANTA. Mock eAPI responses.""" + # Disable logging during ANTA execution to avoid having these function time in benchmarks + logging.disable() + + def bench() -> ResultManager: + """Need to wrap the ANTA Runner to instantiate a new ResultManger for each benchmark run.""" + manager = ResultManager() + asyncio.run(main(manager, inventory, catalog)) + return manager + + manager = benchmark(bench) + + logging.disable(logging.NOTSET) + + if len(catalog.tests) * len(inventory) != len(manager.results): + # This could mean duplicates exist. + # TODO: consider removing this code and refactor unit test data as a dictionary with tuple keys instead of a list + seen = set() + dupes = [] + for test in catalog.tests: + if test in seen: + dupes.append(test) + else: + seen.add(test) + if dupes: + for test in dupes: + msg = f"Found duplicate in test catalog: {test}" + logger.error(msg) + pytest.fail(f"Expected {len(catalog.tests) * len(inventory)} test results but got {len(manager.results)}", pytrace=False) + bench_info = ( + "\n--- ANTA NRFU Benchmark Information ---\n" + f"Test results: {len(manager.results)}\n" + f"Success: {manager.get_total_results({AntaTestStatus.SUCCESS})}\n" + f"Failure: {manager.get_total_results({AntaTestStatus.FAILURE})}\n" + f"Skipped: {manager.get_total_results({AntaTestStatus.SKIPPED})}\n" + f"Error: {manager.get_total_results({AntaTestStatus.ERROR})}\n" + f"Unset: {manager.get_total_results({AntaTestStatus.UNSET})}\n" + "---------------------------------------" + ) + logger.info(bench_info) + assert manager.get_total_results({AntaTestStatus.ERROR}) == 0 + assert manager.get_total_results({AntaTestStatus.UNSET}) == 0 diff --git a/tests/benchmark/utils.py b/tests/benchmark/utils.py new file mode 100644 index 000000000..1017cfe0a --- /dev/null +++ b/tests/benchmark/utils.py @@ -0,0 +1,164 @@ +# Copyright (c) 2023-2024 Arista Networks, Inc. +# Use of this source code is governed by the Apache License 2.0 +# that can be found in the LICENSE file. +"""Utils for the ANTA benchmark tests.""" + +from __future__ import annotations + +import asyncio +import copy +import importlib +import json +import pkgutil +from typing import TYPE_CHECKING, Any + +import httpx + +from anta.catalog import AntaCatalog, AntaTestDefinition +from anta.models import AntaCommand, AntaTest + +if TYPE_CHECKING: + from collections.abc import Generator + from types import ModuleType + + from anta.device import AntaDevice + + +async def collect(self: AntaTest) -> None: + """Patched anta.models.AntaTest.collect() method. + + When generating the catalog, we inject a unit test case name in the custom_field input to be able to retrieve the eos_data for this specific test. + We use this unit test case name in the eAPI request ID. + """ + if self.inputs.result_overwrite is None or self.inputs.result_overwrite.custom_field is None: + msg = f"The custom_field input is not present for test {self.name}" + raise RuntimeError(msg) + await self.device.collect_commands(self.instance_commands, collection_id=f"{self.name}:{self.inputs.result_overwrite.custom_field}") + + +async def collect_commands(self: AntaDevice, commands: list[AntaCommand], collection_id: str) -> None: + """Patched anta.device.AntaDevice.collect_commands() method. + + For the same reason as above, we inject the command index of the test to the eAPI request ID. + """ + await asyncio.gather(*(self.collect(command=command, collection_id=f"{collection_id}:{idx}") for idx, command in enumerate(commands))) + + +class AntaMockEnvironment: # pylint: disable=too-few-public-methods + """Generate an ANTA test catalog from the unit tests data. It can be accessed using the `catalog` attribute of this class instance. + + Also provide the attribute 'eos_data_catalog` with the output of all the commands used in the test catalog. + + Each module in `tests.units.anta_tests` has a `DATA` constant. + The `DATA` structure is a list of dictionaries used to parametrize the test. The list elements have the following keys: + - `name` (str): Test name as displayed by Pytest. + - `test` (AntaTest): An AntaTest subclass imported in the test module - e.g. VerifyUptime. + - `eos_data` (list[dict]): List of data mocking EOS returned data to be passed to the test. + - `inputs` (dict): Dictionary to instantiate the `test` inputs as defined in the class from `test`. + + The keys of `eos_data_catalog` is the tuple (DATA['test'], DATA['name']). The values are `eos_data`. + """ + + def __init__(self) -> None: + self._catalog, self.eos_data_catalog = self._generate_catalog() + self.tests_count = len(self._catalog.tests) + + @property + def catalog(self) -> AntaCatalog: + """AntaMockEnvironment object will always return a new AntaCatalog object based on the initial parsing. + + This is because AntaCatalog objects store indexes when tests are run and we want a new object each time a test is run. + """ + return copy.deepcopy(self._catalog) + + def _generate_catalog(self) -> tuple[AntaCatalog, dict[tuple[str, str], list[dict[str, Any]]]]: + """Generate the `catalog` and `eos_data_catalog` attributes.""" + + def import_test_modules() -> Generator[ModuleType, None, None]: + """Yield all test modules from the given package.""" + package = importlib.import_module("tests.units.anta_tests") + prefix = package.__name__ + "." + for _, module_name, is_pkg in pkgutil.walk_packages(package.__path__, prefix): + if not is_pkg and module_name.split(".")[-1].startswith("test_"): + module = importlib.import_module(module_name) + if hasattr(module, "DATA"): + yield module + + test_definitions = [] + eos_data_catalog = {} + for module in import_test_modules(): + for test_data in module.DATA: + test = test_data["test"] + result_overwrite = AntaTest.Input.ResultOverwrite(custom_field=test_data["name"]) + if test_data["inputs"] is None: + inputs = test.Input(result_overwrite=result_overwrite) + else: + inputs = test.Input(**test_data["inputs"], result_overwrite=result_overwrite) + test_definition = AntaTestDefinition( + test=test, + inputs=inputs, + ) + eos_data_catalog[(test.__name__, test_data["name"])] = test_data["eos_data"] + test_definitions.append(test_definition) + + return (AntaCatalog(tests=test_definitions), eos_data_catalog) + + def eapi_response(self, request: httpx.Request) -> httpx.Response: + """Mock eAPI response. + + If the eAPI request ID has the format `ANTA-{test name}:{unit test name}:{command index}-{command ID}`, + the function will return the eos_data from the unit test case. + + Otherwise, it will mock 'show version' command or raise an Exception. + """ + words_count = 3 + + def parse_req_id(req_id: str) -> tuple[str, str, int] | None: + """Parse the patched request ID from the eAPI request.""" + req_id = req_id.removeprefix("ANTA-").rpartition("-")[0] + words = req_id.split(":", words_count) + if len(words) == words_count: + test_name, unit_test_name, command_index = words + return test_name, unit_test_name, int(command_index) + return None + + jsonrpc = json.loads(request.content) + assert jsonrpc["method"] == "runCmds" + commands = jsonrpc["params"]["cmds"] + ofmt = jsonrpc["params"]["format"] + req_id: str = jsonrpc["id"] + result = None + + # Extract the test name, unit test name, and command index from the request ID + if (words := parse_req_id(req_id)) is not None: + test_name, unit_test_name, idx = words + + # This should never happen, but better be safe than sorry + if (test_name, unit_test_name) not in self.eos_data_catalog: + msg = f"Error while generating a mock response for unit test {unit_test_name} of test {test_name}: eos_data not found" + raise RuntimeError(msg) + + eos_data = self.eos_data_catalog[(test_name, unit_test_name)] + + # This could happen if the unit test data is not correctly defined + if idx >= len(eos_data): + msg = f"Error while generating a mock response for unit test {unit_test_name} of test {test_name}: missing test case in eos_data" + raise RuntimeError(msg) + result = {"output": eos_data[idx]} if ofmt == "text" else eos_data[idx] + elif {"cmd": "show version"} in commands and ofmt == "json": + # Mock 'show version' request performed during inventory refresh. + result = { + "modelName": "pytest", + } + + if result is not None: + return httpx.Response( + status_code=200, + json={ + "jsonrpc": "2.0", + "id": req_id, + "result": [result], + }, + ) + msg = f"The following eAPI Request has not been mocked: {jsonrpc}" + raise NotImplementedError(msg) diff --git a/tests/units/anta_tests/routing/test_generic.py b/tests/units/anta_tests/routing/test_generic.py index 0ac43f3c5..20f83b92b 100644 --- a/tests/units/anta_tests/routing/test_generic.py +++ b/tests/units/anta_tests/routing/test_generic.py @@ -5,8 +5,12 @@ from __future__ import annotations +import sys from typing import Any +import pytest +from pydantic import ValidationError + from anta.tests.routing.generic import VerifyRoutingProtocolModel, VerifyRoutingTableEntry, VerifyRoutingTableSize from tests.units.anta_tests import test @@ -66,16 +70,6 @@ "inputs": {"minimum": 42, "maximum": 666}, "expected": {"result": "failure", "messages": ["routing-table has 1000 routes and not between min (42) and maximum (666)"]}, }, - { - "name": "error-max-smaller-than-min", - "test": VerifyRoutingTableSize, - "eos_data": [{}], - "inputs": {"minimum": 666, "maximum": 42}, - "expected": { - "result": "error", - "messages": ["Minimum 666 is greater than maximum 42"], - }, - }, { "name": "success", "test": VerifyRoutingTableEntry, @@ -310,11 +304,33 @@ "inputs": {"vrf": "default", "routes": ["10.1.0.1", "10.1.0.2"], "collect": "all"}, "expected": {"result": "failure", "messages": ["The following route(s) are missing from the routing table of VRF default: ['10.1.0.2']"]}, }, - { - "name": "collect-input-error", - "test": VerifyRoutingTableEntry, - "eos_data": {}, - "inputs": {"vrf": "default", "routes": ["10.1.0.1", "10.1.0.2"], "collect": "not-valid"}, - "expected": {"result": "error", "messages": ["Inputs are not valid"]}, - }, ] + + +class TestVerifyRoutingTableSizeInputs: + """Test anta.tests.routing.generic.VerifyRoutingTableSize.Input.""" + + @pytest.mark.parametrize( + ("minimum", "maximum"), + [ + pytest.param(0, 0, id="zero"), + pytest.param(1, 2, id="1<2"), + pytest.param(0, sys.maxsize, id="max"), + ], + ) + def test_valid(self, minimum: int, maximum: int) -> None: + """Test VerifyRoutingTableSize valid inputs.""" + VerifyRoutingTableSize.Input(minimum=minimum, maximum=maximum) + + @pytest.mark.parametrize( + ("minimum", "maximum"), + [ + pytest.param(-2, -1, id="negative"), + pytest.param(2, 1, id="2<1"), + pytest.param(sys.maxsize, 0, id="max"), + ], + ) + def test_invalid(self, minimum: int, maximum: int) -> None: + """Test VerifyRoutingTableSize invalid inputs.""" + with pytest.raises(ValidationError): + VerifyRoutingTableSize.Input(minimum=minimum, maximum=maximum) diff --git a/tests/units/anta_tests/test_configuration.py b/tests/units/anta_tests/test_configuration.py index dbe22d365..d8f86beaa 100644 --- a/tests/units/anta_tests/test_configuration.py +++ b/tests/units/anta_tests/test_configuration.py @@ -60,14 +60,4 @@ "inputs": {"regex_patterns": ["bla", "bleh"]}, "expected": {"result": "failure", "messages": ["Following patterns were not found: 'bla','bleh'"]}, }, - { - "name": "failure-invalid-regex", - "test": VerifyRunningConfigLines, - "eos_data": ["enable password something\nsome other line"], - "inputs": {"regex_patterns": ["["]}, - "expected": { - "result": "error", - "messages": ["1 validation error for Input\nregex_patterns.0\n Value error, Invalid regex: unterminated character set at position 0"], - }, - }, ] diff --git a/tests/units/anta_tests/test_field_notices.py b/tests/units/anta_tests/test_field_notices.py index a30604b8b..8e7c9d8b3 100644 --- a/tests/units/anta_tests/test_field_notices.py +++ b/tests/units/anta_tests/test_field_notices.py @@ -358,8 +358,8 @@ ], "inputs": None, "expected": { - "result": "error", - "messages": ["Error in running test - FixedSystemvrm1 not found"], + "result": "failure", + "messages": ["Error in running test - Component FixedSystemvrm1 not found in 'show version'"], }, }, ] diff --git a/tests/units/anta_tests/test_interfaces.py b/tests/units/anta_tests/test_interfaces.py index 73ef6c6aa..ea8106e84 100644 --- a/tests/units/anta_tests/test_interfaces.py +++ b/tests/units/anta_tests/test_interfaces.py @@ -652,7 +652,7 @@ ], "inputs": {"threshold": 70.0}, "expected": { - "result": "error", + "result": "failure", "messages": ["Interface Ethernet1/1 or one of its member interfaces is not Full-Duplex. VerifyInterfaceUtilization has not been implemented."], }, }, @@ -797,7 +797,7 @@ ], "inputs": {"threshold": 70.0}, "expected": { - "result": "error", + "result": "failure", "messages": ["Interface Port-Channel31 or one of its member interfaces is not Full-Duplex. VerifyInterfaceUtilization has not been implemented."], }, }, diff --git a/tests/units/anta_tests/test_mlag.py b/tests/units/anta_tests/test_mlag.py index 1ef547259..193d69c2d 100644 --- a/tests/units/anta_tests/test_mlag.py +++ b/tests/units/anta_tests/test_mlag.py @@ -110,17 +110,6 @@ "inputs": None, "expected": {"result": "skipped", "messages": ["MLAG is disabled"]}, }, - { - "name": "error", - "test": VerifyMlagConfigSanity, - "eos_data": [ - { - "dummy": False, - }, - ], - "inputs": None, - "expected": {"result": "error", "messages": ["Incorrect JSON response - 'mlagActive' state was not found"]}, - }, { "name": "failure-global", "test": VerifyMlagConfigSanity, diff --git a/tests/units/anta_tests/test_security.py b/tests/units/anta_tests/test_security.py index 792b06595..549890ad5 100644 --- a/tests/units/anta_tests/test_security.py +++ b/tests/units/anta_tests/test_security.py @@ -7,6 +7,9 @@ from typing import Any +import pytest +from pydantic import ValidationError + from anta.tests.security import ( VerifyAPIHttpsSSL, VerifyAPIHttpStatus, @@ -39,7 +42,7 @@ "test": VerifySSHStatus, "eos_data": ["SSH per host connection limit is 20\nFIPS status: disabled\n\n"], "inputs": None, - "expected": {"result": "error", "messages": ["Could not find SSH status in returned output."]}, + "expected": {"result": "failure", "messages": ["Could not find SSH status in returned output."]}, }, { "name": "failure-ssh-disabled", @@ -581,40 +584,6 @@ ], }, }, - { - "name": "error-wrong-input-rsa", - "test": VerifyAPISSLCertificate, - "eos_data": [], - "inputs": { - "certificates": [ - { - "certificate_name": "ARISTA_ROOT_CA.crt", - "expiry_threshold": 30, - "common_name": "Arista Networks Internal IT Root Cert Authority", - "encryption_algorithm": "RSA", - "key_size": 256, - }, - ] - }, - "expected": {"result": "error", "messages": ["Allowed sizes are (2048, 3072, 4096)."]}, - }, - { - "name": "error-wrong-input-ecdsa", - "test": VerifyAPISSLCertificate, - "eos_data": [], - "inputs": { - "certificates": [ - { - "certificate_name": "ARISTA_SIGNING_CA.crt", - "expiry_threshold": 30, - "common_name": "AristaIT-ICA ECDSA Issuing Cert Authority", - "encryption_algorithm": "ECDSA", - "key_size": 2048, - }, - ] - }, - "expected": {"result": "error", "messages": ["Allowed sizes are (256, 384, 512)."]}, - }, { "name": "success", "test": VerifyBannerLogin, @@ -1229,3 +1198,69 @@ "expected": {"result": "failure", "messages": ["Hardware entropy generation is disabled."]}, }, ] + + +class TestAPISSLCertificate: + """Test anta.tests.security.VerifyAPISSLCertificate.Input.APISSLCertificate.""" + + @pytest.mark.parametrize( + ("model_params", "error"), + [ + pytest.param( + { + "certificate_name": "ARISTA_ROOT_CA.crt", + "expiry_threshold": 30, + "common_name": "Arista Networks Internal IT Root Cert Authority", + "encryption_algorithm": "RSA", + "key_size": 256, + }, + "Value error, `ARISTA_ROOT_CA.crt` key size 256 is invalid for RSA encryption. Allowed sizes are (2048, 3072, 4096).", + id="RSA_wrong_size", + ), + pytest.param( + { + "certificate_name": "ARISTA_SIGNING_CA.crt", + "expiry_threshold": 30, + "common_name": "AristaIT-ICA ECDSA Issuing Cert Authority", + "encryption_algorithm": "ECDSA", + "key_size": 2048, + }, + "Value error, `ARISTA_SIGNING_CA.crt` key size 2048 is invalid for ECDSA encryption. Allowed sizes are (256, 384, 512).", + id="ECDSA_wrong_size", + ), + ], + ) + def test_invalid(self, model_params: dict[str, Any], error: str) -> None: + """Test invalid inputs for anta.tests.security.VerifyAPISSLCertificate.Input.APISSLCertificate.""" + with pytest.raises(ValidationError) as exec_info: + VerifyAPISSLCertificate.Input.APISSLCertificate.model_validate(model_params) + assert error == exec_info.value.errors()[0]["msg"] + + @pytest.mark.parametrize( + "model_params", + [ + pytest.param( + { + "certificate_name": "ARISTA_SIGNING_CA.crt", + "expiry_threshold": 30, + "common_name": "AristaIT-ICA ECDSA Issuing Cert Authority", + "encryption_algorithm": "ECDSA", + "key_size": 256, + }, + id="ECDSA", + ), + pytest.param( + { + "certificate_name": "ARISTA_ROOT_CA.crt", + "expiry_threshold": 30, + "common_name": "Arista Networks Internal IT Root Cert Authority", + "encryption_algorithm": "RSA", + "key_size": 4096, + }, + id="RSA", + ), + ], + ) + def test_valid(self, model_params: dict[str, Any]) -> None: + """Test valid inputs for anta.tests.security.VerifyAPISSLCertificate.Input.APISSLCertificate.""" + VerifyAPISSLCertificate.Input.APISSLCertificate.model_validate(model_params) diff --git a/tests/units/anta_tests/test_system.py b/tests/units/anta_tests/test_system.py index 22b9787b2..1eda8a1d5 100644 --- a/tests/units/anta_tests/test_system.py +++ b/tests/units/anta_tests/test_system.py @@ -76,13 +76,6 @@ "inputs": None, "expected": {"result": "failure", "messages": ["Reload cause is: 'Reload after crash.'"]}, }, - { - "name": "error", - "test": VerifyReloadCause, - "eos_data": [{}], - "inputs": None, - "expected": {"result": "error", "messages": ["No reload causes available"]}, - }, { "name": "success-without-minidump", "test": VerifyCoredump, diff --git a/tests/units/test_custom_types.py b/tests/units/test_custom_types.py index e3dc09d25..697017105 100644 --- a/tests/units/test_custom_types.py +++ b/tests/units/test_custom_types.py @@ -30,6 +30,7 @@ bgp_multiprotocol_capabilities_abbreviations, interface_autocomplete, interface_case_sensitivity, + validate_regex, ) # ------------------------------------------------------------------------------ @@ -281,3 +282,36 @@ def test_interface_case_sensitivity_uppercase() -> None: assert interface_case_sensitivity("ETHERNET") == "ETHERNET" assert interface_case_sensitivity("VLAN") == "VLAN" assert interface_case_sensitivity("LOOPBACK") == "LOOPBACK" + + +@pytest.mark.parametrize( + "str_input", + [ + REGEX_BGP_IPV4_MPLS_VPN, + REGEX_BGP_IPV4_UNICAST, + REGEX_TYPE_PORTCHANNEL, + REGEXP_BGP_IPV4_MPLS_LABELS, + REGEXP_BGP_L2VPN_AFI, + REGEXP_INTERFACE_ID, + REGEXP_PATH_MARKERS, + REGEXP_TYPE_EOS_INTERFACE, + REGEXP_TYPE_HOSTNAME, + REGEXP_TYPE_VXLAN_SRC_INTERFACE, + ], +) +def test_validate_regex_valid(str_input: str) -> None: + """Test validate_regex with valid regex.""" + assert validate_regex(str_input) == str_input + + +@pytest.mark.parametrize( + ("str_input", "error"), + [ + pytest.param("[", "Invalid regex: unterminated character set at position 0", id="unterminated character"), + pytest.param("\\", r"Invalid regex: bad escape \(end of pattern\) at position 0", id="bad escape"), + ], +) +def test_validate_regex_invalid(str_input: str, error: str) -> None: + """Test validate_regex with invalid regex.""" + with pytest.raises(ValueError, match=error): + validate_regex(str_input)