diff --git a/.pre-commit-config.yaml b/.pre-commit-config.yaml index fbf8c3fd2..b13d8ba60 100644 --- a/.pre-commit-config.yaml +++ b/.pre-commit-config.yaml @@ -69,6 +69,7 @@ repos: - types-pyOpenSSL - pylint_pydantic - pytest + - respx - repo: https://github.com/codespell-project/codespell rev: v2.3.0 diff --git a/pyproject.toml b/pyproject.toml index 7202d4839..00f37a4d1 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -72,6 +72,7 @@ dev = [ "pytest-httpx>=0.30.0", "pytest-metadata>=3.0.0", "pytest>=7.4.0", + "respx", "ruff>=0.5.4,<0.7.0", "tox>=4.10.0,<5.0.0", "types-PyYAML", diff --git a/tests/benchmark/__init__.py b/tests/benchmark/__init__.py new file mode 100644 index 000000000..23f788698 --- /dev/null +++ b/tests/benchmark/__init__.py @@ -0,0 +1,4 @@ +# Copyright (c) 2023-2024 Arista Networks, Inc. +# Use of this source code is governed by the Apache License 2.0 +# that can be found in the LICENSE file. +"""Benchmark package for ANTA.""" diff --git a/tests/benchmark/data.py b/tests/benchmark/data.py new file mode 100644 index 000000000..4fdc4367c --- /dev/null +++ b/tests/benchmark/data.py @@ -0,0 +1,61 @@ +# Copyright (c) 2023-2024 Arista Networks, Inc. +# Use of this source code is governed by the Apache License 2.0 +# that can be found in the LICENSE file. +"""Loader of DATA from all tests/units/anta_tests modules.""" + +import importlib +import pkgutil +from collections.abc import Generator +from pathlib import Path +from types import ModuleType +from typing import Any + +from anta.catalog import AntaCatalog + +DATA_DIR: Path = Path(__file__).parent.parent.resolve() / "data" + + +def import_test_modules(package_name: str) -> Generator[ModuleType, None, None]: + """Yield all test modules from the given package.""" + package = importlib.import_module(package_name) + prefix = package.__name__ + "." + for _, module_name, is_pkg in pkgutil.walk_packages(package.__path__, prefix): + if not is_pkg and module_name.split(".")[-1].startswith("test_"): + module = importlib.import_module(module_name) + if hasattr(module, "DATA"): + yield module + + +def collect_outputs() -> dict[str, Any]: + """Collect DATA from all unit test modules and return a dictionary of outputs per test.""" + outputs = {} + for module in import_test_modules("tests.units.anta_tests"): + for test_data in module.DATA: + test = test_data["test"].__name__ + if test not in outputs: + outputs[test] = test_data["eos_data"][0] + + return outputs + + +def load_catalog(filename: Path) -> AntaCatalog: + """Load a catalog from a Path.""" + catalog = AntaCatalog.parse(filename) + + # Removing filters for testing purposes + for test in catalog.tests: + test.inputs.filters = None + return catalog + + +def load_catalogs() -> dict[str, AntaCatalog]: + """Load catalogs from the data directory.""" + return { + "small": load_catalog(DATA_DIR / "test_catalog.yml"), + "medium": load_catalog(DATA_DIR / "test_catalog_medium.yml"), + "large": load_catalog(DATA_DIR / "test_catalog_large.yml"), + } + + +OUTPUTS = collect_outputs() +CATALOGS = load_catalogs() diff --git a/tests/benchmark/patched_objects.py b/tests/benchmark/patched_objects.py new file mode 100644 index 000000000..ea8fd67a5 --- /dev/null +++ b/tests/benchmark/patched_objects.py @@ -0,0 +1,13 @@ +# Copyright (c) 2023-2024 Arista Networks, Inc. +# Use of this source code is governed by the Apache License 2.0 +# that can be found in the LICENSE file. +"""Patched objects for the ANTA benchmark tests.""" + +from anta.device import AsyncEOSDevice + + +async def mock_refresh(self: AsyncEOSDevice) -> None: + """Mock the refresh method for the AsyncEOSDevice object.""" + self.hw_model = "cEOSLab" + self.established = True + self.is_online = True diff --git a/tests/benchmark/test_runner.py b/tests/benchmark/test_runner.py new file mode 100644 index 000000000..9914c5060 --- /dev/null +++ b/tests/benchmark/test_runner.py @@ -0,0 +1,52 @@ +# Copyright (c) 2023-2024 Arista Networks, Inc. +# Use of this source code is governed by the Apache License 2.0 +# that can be found in the LICENSE file. +"""Benchmark ANTA runner.""" + +from typing import Literal +from unittest.mock import patch + +import pytest +import respx + +from anta.device import AsyncEOSDevice +from anta.result_manager import ResultManager +from anta.runner import main as anta_runner + +from .patched_objects import mock_refresh +from .utils import generate_inventory, generate_response, get_catalog + + +# Parametrize the test to run with different inventory sizes, and test multipliers if needed +@pytest.mark.asyncio +@pytest.mark.respx(assert_all_mocked=True, assert_all_called=True) +@pytest.mark.parametrize( + ("inventory_size", "catalog_size"), + [ + (10, "small"), + (10, "medium"), + (10, "large"), + ], + ids=["small_run", "medium_run", "large_run"], +) +async def test_runner(respx_mock: respx.MockRouter, inventory_size: int, catalog_size: Literal["small", "medium", "large"]) -> None: + """Test the ANTA runner.""" + # We mock all POST requests to eAPI + route = respx_mock.route(path="/command-api", method="POST") + + # We also mock all responses using data from the unit tests + route.side_effect = generate_response + + # Create the required ANTA objects + inventory = generate_inventory(inventory_size) + catalog = get_catalog(catalog_size) + manager = ResultManager() + + # Apply the patches for the run + with patch.object(AsyncEOSDevice, "refresh", mock_refresh): + # Run ANTA + await anta_runner(manager, inventory, catalog) + + # NOTE: See if we want to generate a report and benchmark + + assert respx_mock.calls.called diff --git a/tests/benchmark/utils.py b/tests/benchmark/utils.py new file mode 100644 index 000000000..9ba74c344 --- /dev/null +++ b/tests/benchmark/utils.py @@ -0,0 +1,66 @@ +# Copyright (c) 2023-2024 Arista Networks, Inc. +# Use of this source code is governed by the Apache License 2.0 +# that can be found in the LICENSE file. +"""Utils for the ANTA benchmark tests.""" + +import json + +import httpx + +from anta.catalog import AntaCatalog +from anta.device import AsyncEOSDevice +from anta.inventory import AntaInventory + +from .data import CATALOGS, OUTPUTS + + +def get_catalog(size: str) -> AntaCatalog: + """Return the catalog for the given size.""" + return CATALOGS[size] + + +def generate_response(request: httpx.Request) -> httpx.Response: + """Generate a response for the eAPI request.""" + jsonrpc = json.loads(request.content) + req_id = jsonrpc["id"] + ofmt = jsonrpc["params"]["format"] + + # Extract the test name from the request ID + test_name = req_id.split("-")[1] + + # This should never happen, but better be safe than sorry + if test_name not in OUTPUTS: + msg = f"Error while generating a mock response for test {test_name}: test not found in unit tests data." + raise RuntimeError(msg) + + output = OUTPUTS[test_name] + + result = {"output": output} if ofmt == "text" else output + + return httpx.Response( + status_code=200, + json={ + "jsonrpc": "2.0", + "id": req_id, + "result": [result], + }, + ) + + +def generate_inventory(size: int = 10) -> AntaInventory: + """Generate an ANTA inventory with fake devices.""" + inventory = AntaInventory() + for i in range(size): + inventory.add_device( + AsyncEOSDevice( + host=f"device-{i}.example.com", + username="admin", + password="admin", # noqa: S106 + name=f"device-{i}", + enable_password="admin", # noqa: S106 + enable=True, + disable_cache=True, + ) + ) + + return inventory diff --git a/tests/conftest.py b/tests/conftest.py index e31533840..8e70fc463 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -12,10 +12,7 @@ # Load fixtures from dedicated file tests/lib/fixture.py # As well as pytest_asyncio plugin to test co-routines -pytest_plugins = [ - "tests.lib.fixture", - "pytest_asyncio", -] +pytest_plugins = ["tests.lib.fixture", "pytest_asyncio"] # Enable nice assert messages # https://docs.pytest.org/en/7.1.x/how-to/writing_plugins.html#assertion-rewriting