From 619ea9421262060af13ad9ac9f1e32a860191fe8 Mon Sep 17 00:00:00 2001 From: Victoria Zotova Date: Thu, 8 Aug 2024 13:37:13 -0400 Subject: [PATCH] Changing mocking in tests for RestMiddleware and testing versioning --- porter/fields/base.py | 16 +++----- porter/main.py | 21 ++++++++--- tests/conftest.py | 49 +++++++++++++++++++++++- tests/test_bucket_sampling.py | 67 +++++++++++++++++++++++++++++++++ tests/test_get_ursulas.py | 71 +++++++++++++++++++++++++++++++++++ 5 files changed, 206 insertions(+), 18 deletions(-) diff --git a/porter/fields/base.py b/porter/fields/base.py index 8d1ef47..1d0c65f 100644 --- a/porter/fields/base.py +++ b/porter/fields/base.py @@ -3,7 +3,7 @@ import click from marshmallow import fields -from packaging.version import Version, parse +from packaging.version import parse from porter.fields.exceptions import InvalidInputData @@ -113,12 +113,8 @@ def _deserialize(self, value, attr, data, **kwargs): class VersionString(String): - def _serialize(self, value, attr, obj, **kwargs) -> str: - if type(value) is not Version: - raise InvalidInputData( - f"Unexpected object type, {type(value)}; expected Version" - ) - return str(value) - - def _deserialize(self, value, attr, data, **kwargs) -> list: - return parse(value) + def _validate(self, value): + try: + parse(value) + except Exception: + raise InvalidInputData(f"{self.name} must be a correct version.") diff --git a/porter/main.py b/porter/main.py index 3bea7a2..2a5f4d4 100644 --- a/porter/main.py +++ b/porter/main.py @@ -32,7 +32,7 @@ TreasureMap, ) from nucypher_core.umbral import PublicKey -from packaging.version import parse +from packaging.version import Version, parse from prometheus_flask_exporter import PrometheusMetrics import porter @@ -157,7 +157,7 @@ def _initialize_endpoints(eth_endpoint: str, polygon_endpoint: str): BlockchainInterfaceFactory.initialize_interface(endpoint=polygon_endpoint) @staticmethod - def _is_version_greater_or_equal(min_version: str, version: str) -> bool: + def _is_version_greater_or_equal(min_version: Version, version: str) -> bool: return parse(version) >= min_version def _get_ursula_version(self, ursula: Ursula) -> str: @@ -179,6 +179,7 @@ def get_ursulas( "sampling", timeout, self.MAX_GET_URSULAS_TIMEOUT ) duration = duration or 0 + parse_min_version = parse(min_version) if min_version else None reservoir = self._make_reservoir(exclude_ursulas, include_ursulas, duration) available_nodes_to_sample = len(reservoir.values) + len(reservoir.reservoir) @@ -198,8 +199,8 @@ def get_ursula_info(ursula_address) -> Porter.UrsulaInfo: try: # ensure node is up and reachable and check version version = self._get_ursula_version(ursula) - if min_version and not self._is_version_greater_or_equal( - min_version, version + if parse_min_version and not self._is_version_greater_or_equal( + parse_min_version, version ): raise ValueError( f"Ursula ({ursula_address}) has too old version ({version})" @@ -326,6 +327,7 @@ def bucket_sampling( "bucket_sampling", timeout, self.MAX_BUCKET_SAMPLING_TIMEOUT ) duration = duration or 0 + parse_min_version = parse(min_version) if min_version else None if self.domain not in self._ALLOWED_DOMAINS_FOR_BUCKET_SAMPLING: raise ValueError("Bucket sampling is only for TACo Mainnet") @@ -387,6 +389,7 @@ def __init__(self, _reservoir, need_successes: int): self.need_successes = need_successes self.predefined_buckets = self.read_buckets() self.bucketed_nodes = defaultdict(list) + self.selected_nodes = dict() def read_buckets(self) -> Dict: try: @@ -413,6 +416,10 @@ def find_bucket(self, node): return bucket_name return None + def mark_as_not_successful(self, failure: ChecksumAddress): + bucket = self.selected_nodes[failure] + self.bucketed_nodes[bucket].remove(failure) + def __call__(self, _successes: int) -> Optional[List[ChecksumAddress]]: batch = [] batch_size = self.need_successes - _successes @@ -425,6 +432,7 @@ def __call__(self, _successes: int) -> Optional[List[ChecksumAddress]]: if len(self.bucketed_nodes[bucket]) >= self.BUCKET_CAP: continue self.bucketed_nodes[bucket].append(selected) + self.selected_nodes[selected] = bucket batch.append(selected) if not batch: return None @@ -442,8 +450,8 @@ def make_sure_ursula_is_online(ursula_address) -> ChecksumAddress: # ensure node is up and reachable # self.network_middleware.ping(ursula) version = self._get_ursula_version(ursula) - if min_version and not self._is_version_greater_or_equal( - min_version, version + if parse_min_version and not self._is_version_greater_or_equal( + parse_min_version, version ): raise ValueError( f"Ursula ({ursula_address}) has too old version ({version})" @@ -453,6 +461,7 @@ def make_sure_ursula_is_online(ursula_address) -> ChecksumAddress: except Exception as e: message = f"Ursula ({ursula_address}) is unreachable: {str(e)}" self.log.debug(message) + value_factory.mark_as_not_successful(ursula_address) raise self.block_until_number_of_known_nodes_is( diff --git a/tests/conftest.py b/tests/conftest.py index fae7cf7..82f4c0e 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -34,9 +34,11 @@ from tests.constants import ( MOCK_ETH_PROVIDER_URI, TEMPORARY_DOMAIN, + TEST_ETH_PROVIDER_URI, TESTERCHAIN_CHAIN_ID, ) from tests.mock.interfaces import MockBlockchain +from tests.utils.middleware import MockRestMiddleware, _TestMiddlewareClient from tests.utils.registry import MockRegistrySource, mock_registry_sources # Crash on server error by default @@ -245,9 +247,53 @@ def mock_signer(get_random_checksum_address): return signer +class _MockMiddlewareClient(_TestMiddlewareClient): + class MockResponse: + def __init__(self, json_data, status_code): + self.json_data = json_data + self.status_code = status_code + + def json(self): + return self.json_data + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.ursulas_versions = {} + + def get(self, *args, **kwargs): + if kwargs.get("path") == "status" and kwargs.get("params")["json"]: + node_address = kwargs.get("node_or_sprout").checksum_address + version = self.ursulas_versions.get(node_address, "1.1.1") + return _MockMiddlewareClient.MockResponse({"version": version}, 200) + + real_get = super(_TestMiddlewareClient, self).__getattr__("get") + return real_get(*args, **kwargs) + + +class _MockRestMiddleware(MockRestMiddleware): + """ + Modified middleware to emulate returning status with version. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.client = _MockMiddlewareClient(eth_endpoint=TEST_ETH_PROVIDER_URI) + + def set_ursulas_versions(self, ursulas_versions: dict): + self.client.ursulas_versions = ursulas_versions + + def clean_ursulas_versions(self): + self.client.ursulas_versions = {} + + +@pytest.fixture(scope="module") +def mock_rest_middleware(): + return _MockRestMiddleware(eth_endpoint=TEST_ETH_PROVIDER_URI) + + @pytest.fixture(scope="module") @pytest.mark.usefixtures('testerchain', 'agency') -def porter(ursulas, mock_rest_middleware, test_registry, module_mocker): +def porter(ursulas, mock_rest_middleware, test_registry): porter = Porter( domain=TEMPORARY_DOMAIN, eth_endpoint=MOCK_ETH_PROVIDER_URI, @@ -259,7 +305,6 @@ def porter(ursulas, mock_rest_middleware, test_registry, module_mocker): verify_node_bonding=False, network_middleware=mock_rest_middleware, ) - module_mocker.patch.object(porter, "_get_ursula_version", return_value="7.4.0") yield porter porter.stop_learning_loop() diff --git a/tests/test_bucket_sampling.py b/tests/test_bucket_sampling.py index 4d63e44..1296314 100644 --- a/tests/test_bucket_sampling.py +++ b/tests/test_bucket_sampling.py @@ -23,6 +23,18 @@ def json(self): mocker.patch("requests.get", return_value=MockRequestResponse()) +@pytest.fixture(autouse=True) +def mock_worker_pool_sleep(): + + original = WorkerPool._sleep + + def _sleep(worker_pool, timeout): + original(worker_pool, 0.01) + pass + + WorkerPool._sleep = _sleep + + def test_bucket_sampling_schema(get_random_checksum_address): # # Input i.e. load @@ -81,6 +93,11 @@ def test_bucket_sampling_schema(get_random_checksum_address): updated_data["timeout"] = 20 BucketSampling().load(updated_data) + # min version + updated_data = dict(required_data) + updated_data["min_version"] = "1.1.1" + BucketSampling().load(updated_data) + # list input formatted as ',' separated strings updated_data = dict(required_data) updated_data["exclude_ursulas"] = ",".join(exclude_ursulas) @@ -133,6 +150,18 @@ def test_bucket_sampling_schema(get_random_checksum_address): updated_data["duration"] = -1 BucketSampling().load(updated_data) + # invalid min version + with pytest.raises(InvalidInputData): + updated_data = dict(required_data) + updated_data["min_version"] = "v1x1.1" + BucketSampling().load(updated_data) + + # invalid min version + with pytest.raises(InvalidInputData): + updated_data = dict(required_data) + updated_data["min_version"] = "1-1-1" + BucketSampling().load(updated_data) + # # Output i.e. dump # @@ -210,11 +239,22 @@ def test_bucket_sampling_python_interface( with pytest.raises(WorkerPool.OutOfValues): _, _ = porter.bucket_sampling(quantity=5) + # no nodes with specified version + with pytest.raises(WorkerPool.OutOfValues): + _, _ = porter.bucket_sampling(quantity=1, timeout=30, min_version="2.2.2") + porter.network_middleware.set_ursulas_versions({sampled_ursulas[0]: "3.0.0"}) + ursulas_info, _ = porter.bucket_sampling(quantity=1, min_version="2.2.2") + assert ursulas_info[0] == sampled_ursulas[0] + with pytest.raises(WorkerPool.OutOfValues): + porter.bucket_sampling(quantity=2, min_version="2.2.2") + porter.network_middleware.clean_ursulas_versions() + @pytest.mark.parametrize("timeout", [None, 10]) @pytest.mark.parametrize("random_seed", [None, 42]) @pytest.mark.parametrize("duration", [None, 0, 60 * 60 * 24, 60 * 60 * 24 * 365]) def test_bucket_sampling_web_interface( + porter, porter_web_controller, ursulas, timeout, @@ -310,3 +350,30 @@ def test_bucket_sampling_web_interface( ) assert response.status_code == 400 assert "Insufficient nodes" in response.text + + # + # Failure case: no nodes with specified version + # + failed_ursula_params = dict(get_ursulas_params) + failed_ursula_params["quantity"] = 1 + failed_ursula_params["min_version"] = "2.0.0" + response = porter_web_controller.get( + "/bucket_sampling", data=json.dumps(failed_ursula_params) + ) + assert "has too old version (1.1.1)" in response.text + + porter.network_middleware.set_ursulas_versions({sampled_ursulas[0]: "3.0.0"}) + response = porter_web_controller.get( + "/bucket_sampling", data=json.dumps(failed_ursula_params) + ) + assert response.status_code == 200 + response_data = json.loads(response.data) + ursulas_info = response_data["result"]["ursulas"] + assert ursulas_info[0] == sampled_ursulas[0] + + failed_ursula_params["quantity"] = 2 + response = porter_web_controller.get( + "/bucket_sampling", data=json.dumps(failed_ursula_params) + ) + assert "has too old version (1.1.1)" in response.text + porter.network_middleware.clean_ursulas_versions() diff --git a/tests/test_get_ursulas.py b/tests/test_get_ursulas.py index bedaa34..10233a1 100644 --- a/tests/test_get_ursulas.py +++ b/tests/test_get_ursulas.py @@ -1,6 +1,7 @@ import json import pytest +from nucypher.utilities.concurrency import WorkerPool from nucypher_core.umbral import SecretKey from porter.fields.exceptions import InvalidArgumentCombo, InvalidInputData @@ -8,6 +9,18 @@ from porter.schema import GetUrsulas, UrsulaInfoSchema +@pytest.fixture(autouse=True) +def mock_worker_pool_sleep(): + + original = WorkerPool._sleep + + def _sleep(worker_pool, timeout): + original(worker_pool, 0.01) + pass + + WorkerPool._sleep = _sleep + + def test_get_ursulas_schema(get_random_checksum_address): # # Input i.e. load @@ -88,6 +101,11 @@ def test_get_ursulas_schema(get_random_checksum_address): assert data["exclude_ursulas"] == [exclude_ursulas[0]] assert data["include_ursulas"] == [include_ursulas[0]] + # min version + updated_data = dict(required_data) + updated_data["min_version"] = "1.1.1" + GetUrsulas().load(updated_data) + # invalid include entry updated_data = dict(required_data) updated_data["exclude_ursulas"] = exclude_ursulas @@ -171,6 +189,18 @@ def test_get_ursulas_schema(get_random_checksum_address): updated_data["duration"] = -1 GetUrsulas().load(updated_data) + # invalid min version + with pytest.raises(InvalidInputData): + updated_data = dict(required_data) + updated_data["min_version"] = "v1x1.1" + GetUrsulas().load(updated_data) + + # invalid min version + with pytest.raises(InvalidInputData): + updated_data = dict(required_data) + updated_data["min_version"] = "1-1-1" + GetUrsulas().load(updated_data) + # # Output i.e. dump # @@ -282,10 +312,23 @@ def test_get_ursulas_python_interface( with pytest.raises(ValueError, match="Insufficient nodes"): porter.get_ursulas(quantity=len(ursulas) + 1) + # no nodes with specified version + with pytest.raises(WorkerPool.OutOfValues): + porter.get_ursulas(quantity=1, min_version="2.2.2") + porter.network_middleware.set_ursulas_versions( + {ursulas[0].checksum_address: "3.0.0"} + ) + ursulas_info = porter.get_ursulas(quantity=1, min_version="2.2.2") + assert ursulas[0].checksum_address == ursulas_info[0].checksum_address + with pytest.raises(WorkerPool.OutOfValues): + porter.get_ursulas(quantity=2, min_version="2.2.2") + porter.network_middleware.clean_ursulas_versions() + @pytest.mark.parametrize("timeout", [None, 10, 20]) @pytest.mark.parametrize("duration", [None, 0, 60 * 60 * 24, 60 * 60 * 24 * 365]) def test_get_ursulas_web_interface( + porter, porter_web_controller, ursulas, timeout, @@ -388,3 +431,31 @@ def test_get_ursulas_web_interface( ) assert response.status_code == 400 assert "Insufficient nodes" in response.text + + # + # Failure case: no nodes with specified version + # + failed_ursula_params = dict(get_ursulas_params) + failed_ursula_params["quantity"] = 1 + failed_ursula_params["min_version"] = "2.0.0" + del failed_ursula_params["include_ursulas"] + response = porter_web_controller.get( + "/get_ursulas", data=json.dumps(failed_ursula_params) + ) + assert "has too old version (1.1.1)" in response.text + + porter.network_middleware.set_ursulas_versions({include_ursulas[0]: "3.0.0"}) + response = porter_web_controller.get( + "/get_ursulas", data=json.dumps(failed_ursula_params) + ) + assert response.status_code == 200 + response_data = json.loads(response.data) + ursulas_info = response_data["result"]["ursulas"] + assert ursulas_info[0]["checksum_address"] == include_ursulas[0] + + failed_ursula_params["quantity"] = 2 + response = porter_web_controller.get( + "/get_ursulas", data=json.dumps(failed_ursula_params) + ) + assert "has too old version (1.1.1)" in response.text + porter.network_middleware.clean_ursulas_versions()