From f35724e39457c74879542bb316066e1a910aa909 Mon Sep 17 00:00:00 2001 From: Victoria Zotova Date: Fri, 2 Aug 2024 14:00:19 -0400 Subject: [PATCH 01/12] Adds Schema for accepting minimum version as parameter and draft to handle this parameter --- porter/fields/base.py | 18 ++++++++++++++++++ porter/interfaces.py | 4 ++++ porter/main.py | 36 ++++++++++++++++++++++++++++++++++-- porter/schema.py | 25 +++++++++++++++++++++++++ 4 files changed, 81 insertions(+), 2 deletions(-) diff --git a/porter/fields/base.py b/porter/fields/base.py index 68dc27e..f25da86 100644 --- a/porter/fields/base.py +++ b/porter/fields/base.py @@ -1,5 +1,6 @@ import json from base64 import b64decode, b64encode +import re import click from marshmallow import fields @@ -108,3 +109,20 @@ def _deserialize(self, value, attr, data, **kwargs): f"Unexpected object type, {type(result)}; expected {self.expected_type}") return result + + +class VersionString(String): + + def _serialize(self, value, attr, obj, **kwargs) -> str: + if (type(value) is not list or len(value) == 0 or len(value) > 3): + raise InvalidInputData( + f"Unexpected object type, {type(value)}; expected list[3]") + + return ".".join(value) + + def _deserialize(self, value, attr, data, **kwargs) -> list: + pattern = r'(\d+\.)?(\d+\.)?(\d+)' + match = re.findall(pattern, value) + if len(match) != 1: + raise InvalidInputData("Minimum version must have x.x.x format") + return match[0] diff --git a/porter/interfaces.py b/porter/interfaces.py index 0efc634..69aa027 100644 --- a/porter/interfaces.py +++ b/porter/interfaces.py @@ -40,6 +40,7 @@ def get_ursulas( include_ursulas: Optional[List[ChecksumAddress]] = None, timeout: Optional[int] = None, duration: Optional[int] = None, + min_version: Optional[str] = None, ) -> Dict: ursulas_info = self.implementer.get_ursulas( quantity=quantity, @@ -47,6 +48,7 @@ def get_ursulas( include_ursulas=include_ursulas, timeout=timeout, duration=duration, + min_version=min_version, ) response_data = {"ursulas": ursulas_info} # list of UrsulaInfo objects @@ -104,6 +106,7 @@ def bucket_sampling( exclude_ursulas: Optional[List[ChecksumAddress]] = None, timeout: Optional[int] = None, duration: Optional[int] = None, + min_version: Optional[str] = None, ) -> Dict: ursulas, block_number = self.implementer.bucket_sampling( quantity=quantity, @@ -111,6 +114,7 @@ def bucket_sampling( exclude_ursulas=exclude_ursulas, timeout=timeout, duration=duration, + min_version=min_version, ) response_data = {"ursulas": ursulas, "block_number": block_number} diff --git a/porter/main.py b/porter/main.py index 81e19fc..95d8dec 100644 --- a/porter/main.py +++ b/porter/main.py @@ -155,6 +155,26 @@ def _initialize_endpoints(eth_endpoint: str, polygon_endpoint: str): ): BlockchainInterfaceFactory.initialize_interface(endpoint=polygon_endpoint) + @staticmethod + def _parse_version(version: Optional[str] = None) -> list: + if not version: + return [0, 0, 0] + parsed = version.split("\.") + if len(parsed) <= 1: + raise InvalidInputData("Minimum version must have x.x.x format") + return parsed + + @staticmethod + def _is_version_greater_or_equal(min_version: list, version: list) -> bool: + for i in version: + if (version[i] < min_version[i]): + return False + return True + + def _get_ursula_version(self, ursula: Ursula) -> list: + response = self.network_middleware.client.get(node_or_sprout=ursula, path="status/?json=true") + return self._parse_version(response["version"]) + def get_ursulas( self, quantity: int, @@ -162,11 +182,13 @@ def get_ursulas( include_ursulas: Optional[Sequence[ChecksumAddress]] = None, timeout: Optional[int] = None, duration: Optional[int] = None, + min_version: Optional[str] = None, ) -> List[UrsulaInfo]: timeout = self._configure_timeout( "sampling", timeout, self.MAX_GET_URSULAS_TIMEOUT ) duration = duration or 0 + min_version_parsed = self._parse_version(min_version) reservoir = self._make_reservoir(exclude_ursulas, include_ursulas, duration) available_nodes_to_sample = len(reservoir.values) + len(reservoir.reservoir) @@ -185,7 +207,11 @@ def get_ursula_info(ursula_address) -> Porter.UrsulaInfo: ursula = self.known_nodes[ursula_address] try: # ensure node is up and reachable - self.network_middleware.ping(ursula) + # self.network_middleware.ping(ursula) + version = self._get_ursula_version(ursula) + if not self._is_version_greater_or_equal(min_version_parsed, version): + raise ValueError(f"Ursula ({ursula_address}) has too old version ({version})") + return Porter.UrsulaInfo(checksum_address=ursula_address, uri=f"{ursula.rest_interface.formal_uri}", encrypting_key=ursula.public_keys(DecryptingPower)) @@ -299,11 +325,13 @@ def bucket_sampling( exclude_ursulas: Optional[Sequence[ChecksumAddress]] = None, timeout: Optional[int] = None, duration: Optional[int] = None, + min_version: Optional[str] = None, ) -> Tuple[List[ChecksumAddress], int]: timeout = self._configure_timeout( "bucket_sampling", timeout, self.MAX_BUCKET_SAMPLING_TIMEOUT ) duration = duration or 0 + min_version_parsed = self._parse_version(min_version) if self.domain not in self._ALLOWED_DOMAINS_FOR_BUCKET_SAMPLING: raise ValueError("Bucket sampling is only for TACo Mainnet") @@ -418,7 +446,11 @@ def make_sure_ursula_is_online(ursula_address) -> ChecksumAddress: ursula = self.known_nodes[ursula_address] try: # ensure node is up and reachable - self.network_middleware.ping(ursula) + # self.network_middleware.ping(ursula) + version = self._get_ursula_version(ursula) + if not self._is_version_greater_or_equal(min_version_parsed, version): + raise ValueError(f"Ursula ({ursula_address}) has too old version ({version})") + return ursula_address except Exception as e: message = f"Ursula ({ursula_address}) is unreachable: {str(e)}" diff --git a/porter/schema.py b/porter/schema.py index b64c0cc..88e9468 100644 --- a/porter/schema.py +++ b/porter/schema.py @@ -8,6 +8,7 @@ Integer, NonNegativeInteger, PositiveInteger, + VersionString, StringList, ) from porter.fields.exceptions import InvalidArgumentCombo, InvalidInputData @@ -125,6 +126,18 @@ class GetUrsulas(BaseSchema): ), ) + min_version = VersionString( + required=False, + load_only=True, + click=click.option( + "--min-version", + "-mv", + help="Minimum acceptable version of Ursula", + type=click.STRING, + required=False, + ), + ) + # output ursulas = marshmallow_fields.List(marshmallow_fields.Nested(UrsulaInfoSchema), dump_only=True) @@ -369,6 +382,18 @@ class BucketSampling(BaseSchema): ), ) + min_version = VersionString( + required=False, + load_only=True, + click=click.option( + "--min-version", + "-mv", + help="Minimum acceptable version of Ursula", + type=click.STRING, + required=False, + ), + ) + # output ursulas = marshmallow_fields.List(UrsulaChecksumAddress, dump_only=True) block_number = marshmallow_fields.Int(dump_only=True) From 5a0c1b41d3e61060af87e805a800152401a89c3f Mon Sep 17 00:00:00 2001 From: derekpierre Date: Tue, 6 Aug 2024 15:00:55 -0400 Subject: [PATCH 02/12] Use packaging.version.parse for comparing versions. Split request into path and params. --- porter/main.py | 60 ++++++++++++++++++++++++-------------------------- 1 file changed, 29 insertions(+), 31 deletions(-) diff --git a/porter/main.py b/porter/main.py index 95d8dec..90378cc 100644 --- a/porter/main.py +++ b/porter/main.py @@ -32,6 +32,7 @@ TreasureMap, ) from nucypher_core.umbral import PublicKey +from packaging.version import parse from prometheus_flask_exporter import PrometheusMetrics import porter @@ -156,24 +157,14 @@ def _initialize_endpoints(eth_endpoint: str, polygon_endpoint: str): BlockchainInterfaceFactory.initialize_interface(endpoint=polygon_endpoint) @staticmethod - def _parse_version(version: Optional[str] = None) -> list: - if not version: - return [0, 0, 0] - parsed = version.split("\.") - if len(parsed) <= 1: - raise InvalidInputData("Minimum version must have x.x.x format") - return parsed - - @staticmethod - def _is_version_greater_or_equal(min_version: list, version: list) -> bool: - for i in version: - if (version[i] < min_version[i]): - return False - return True - - def _get_ursula_version(self, ursula: Ursula) -> list: - response = self.network_middleware.client.get(node_or_sprout=ursula, path="status/?json=true") - return self._parse_version(response["version"]) + def _is_version_greater_or_equal(min_version: str, version: str) -> bool: + return parse(version) >= parse(min_version) + + def _get_ursula_version(self, ursula: Ursula) -> str: + response = self.network_middleware.client.get( + node_or_sprout=ursula, path="status", params={"json": "true"} + ) + return response["version"] def get_ursulas( self, @@ -188,7 +179,6 @@ def get_ursulas( "sampling", timeout, self.MAX_GET_URSULAS_TIMEOUT ) duration = duration or 0 - min_version_parsed = self._parse_version(min_version) reservoir = self._make_reservoir(exclude_ursulas, include_ursulas, duration) available_nodes_to_sample = len(reservoir.values) + len(reservoir.reservoir) @@ -206,15 +196,20 @@ def get_ursula_info(ursula_address) -> Porter.UrsulaInfo: ursula_address = to_checksum_address(ursula_address) ursula = self.known_nodes[ursula_address] try: - # ensure node is up and reachable - # self.network_middleware.ping(ursula) + # ensure node is up and reachable and check version version = self._get_ursula_version(ursula) - if not self._is_version_greater_or_equal(min_version_parsed, version): - raise ValueError(f"Ursula ({ursula_address}) has too old version ({version})") - - return Porter.UrsulaInfo(checksum_address=ursula_address, - uri=f"{ursula.rest_interface.formal_uri}", - encrypting_key=ursula.public_keys(DecryptingPower)) + if min_version and not self._is_version_greater_or_equal( + min_version, version + ): + raise ValueError( + f"Ursula ({ursula_address}) has too old version ({version})" + ) + + return Porter.UrsulaInfo( + checksum_address=ursula_address, + uri=f"{ursula.rest_interface.formal_uri}", + encrypting_key=ursula.public_keys(DecryptingPower), + ) except Exception as e: self.log.debug(f"Ursula ({ursula_address}) is unreachable: {str(e)}") raise @@ -331,7 +326,6 @@ def bucket_sampling( "bucket_sampling", timeout, self.MAX_BUCKET_SAMPLING_TIMEOUT ) duration = duration or 0 - min_version_parsed = self._parse_version(min_version) if self.domain not in self._ALLOWED_DOMAINS_FOR_BUCKET_SAMPLING: raise ValueError("Bucket sampling is only for TACo Mainnet") @@ -448,9 +442,13 @@ def make_sure_ursula_is_online(ursula_address) -> ChecksumAddress: # ensure node is up and reachable # self.network_middleware.ping(ursula) version = self._get_ursula_version(ursula) - if not self._is_version_greater_or_equal(min_version_parsed, version): - raise ValueError(f"Ursula ({ursula_address}) has too old version ({version})") - + if min_version and not self._is_version_greater_or_equal( + min_version, version + ): + raise ValueError( + f"Ursula ({ursula_address}) has too old version ({version})" + ) + return ursula_address except Exception as e: message = f"Ursula ({ursula_address}) is unreachable: {str(e)}" From 9a4ca14ea11d2bd6e1117c5d2ebddc89e9fd5c3c Mon Sep 17 00:00:00 2001 From: derekpierre Date: Tue, 6 Aug 2024 15:01:48 -0400 Subject: [PATCH 03/12] Mock function call that gets ursula version. --- tests/test_get_ursulas.py | 7 +++++++ 1 file changed, 7 insertions(+) diff --git a/tests/test_get_ursulas.py b/tests/test_get_ursulas.py index bedaa34..c596a5d 100644 --- a/tests/test_get_ursulas.py +++ b/tests/test_get_ursulas.py @@ -195,12 +195,15 @@ def test_get_ursulas_schema(get_random_checksum_address): @pytest.mark.parametrize("timeout", [None, 15, 20]) @pytest.mark.parametrize("duration", [None, 0, 60 * 60 * 24, 60 * 60 * 24 * 365]) def test_get_ursulas_python_interface( + mocker, porter, ursulas, timeout, duration, excluded_staker_address_for_duration_greater_than_0, ): + mocker.patch.object(porter, "_get_ursula_version", return_value="7.4.0") + # simple quantity = 4 ursulas_info = porter.get_ursulas(quantity=quantity) @@ -286,12 +289,16 @@ def test_get_ursulas_python_interface( @pytest.mark.parametrize("timeout", [None, 10, 20]) @pytest.mark.parametrize("duration", [None, 0, 60 * 60 * 24, 60 * 60 * 24 * 365]) def test_get_ursulas_web_interface( + mocker, + porter, porter_web_controller, ursulas, timeout, duration, excluded_staker_address_for_duration_greater_than_0, ): + mocker.patch.object(porter, "_get_ursula_version", return_value="7.4.0") + # Send bad data to assert error return response = porter_web_controller.get( "/get_ursulas", data=json.dumps({"bad": "input"}) From ee9ee4762955f41eb6e0d98069cb988ade6ce6c6 Mon Sep 17 00:00:00 2001 From: Victoria Zotova Date: Tue, 6 Aug 2024 16:16:22 -0400 Subject: [PATCH 04/12] run linter, use Version.parse for schema --- porter/fields/base.py | 16 ++++++---------- porter/schema.py | 2 +- 2 files changed, 7 insertions(+), 11 deletions(-) diff --git a/porter/fields/base.py b/porter/fields/base.py index f25da86..8d1ef47 100644 --- a/porter/fields/base.py +++ b/porter/fields/base.py @@ -1,9 +1,9 @@ import json from base64 import b64decode, b64encode -import re import click from marshmallow import fields +from packaging.version import Version, parse from porter.fields.exceptions import InvalidInputData @@ -114,15 +114,11 @@ def _deserialize(self, value, attr, data, **kwargs): class VersionString(String): def _serialize(self, value, attr, obj, **kwargs) -> str: - if (type(value) is not list or len(value) == 0 or len(value) > 3): + if type(value) is not Version: raise InvalidInputData( - f"Unexpected object type, {type(value)}; expected list[3]") - - return ".".join(value) + f"Unexpected object type, {type(value)}; expected Version" + ) + return str(value) def _deserialize(self, value, attr, data, **kwargs) -> list: - pattern = r'(\d+\.)?(\d+\.)?(\d+)' - match = re.findall(pattern, value) - if len(match) != 1: - raise InvalidInputData("Minimum version must have x.x.x format") - return match[0] + return parse(value) diff --git a/porter/schema.py b/porter/schema.py index 88e9468..17b2eac 100644 --- a/porter/schema.py +++ b/porter/schema.py @@ -8,8 +8,8 @@ Integer, NonNegativeInteger, PositiveInteger, - VersionString, StringList, + VersionString, ) from porter.fields.exceptions import InvalidArgumentCombo, InvalidInputData from porter.fields.retrieve import CapsuleFrag, RetrievalKit From ef31aee2577390eca411a27b070e23cbc8be6598 Mon Sep 17 00:00:00 2001 From: Victoria Zotova Date: Tue, 6 Aug 2024 19:06:46 -0400 Subject: [PATCH 05/12] Move mock of getting ursula version to the porter fixture --- tests/conftest.py | 3 ++- tests/test_get_ursulas.py | 7 ------- 2 files changed, 2 insertions(+), 8 deletions(-) diff --git a/tests/conftest.py b/tests/conftest.py index 1763e9e..fae7cf7 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -247,7 +247,7 @@ def mock_signer(get_random_checksum_address): @pytest.fixture(scope="module") @pytest.mark.usefixtures('testerchain', 'agency') -def porter(ursulas, mock_rest_middleware, test_registry): +def porter(ursulas, mock_rest_middleware, test_registry, module_mocker): porter = Porter( domain=TEMPORARY_DOMAIN, eth_endpoint=MOCK_ETH_PROVIDER_URI, @@ -259,6 +259,7 @@ def porter(ursulas, mock_rest_middleware, test_registry): verify_node_bonding=False, network_middleware=mock_rest_middleware, ) + module_mocker.patch.object(porter, "_get_ursula_version", return_value="7.4.0") yield porter porter.stop_learning_loop() diff --git a/tests/test_get_ursulas.py b/tests/test_get_ursulas.py index c596a5d..bedaa34 100644 --- a/tests/test_get_ursulas.py +++ b/tests/test_get_ursulas.py @@ -195,15 +195,12 @@ def test_get_ursulas_schema(get_random_checksum_address): @pytest.mark.parametrize("timeout", [None, 15, 20]) @pytest.mark.parametrize("duration", [None, 0, 60 * 60 * 24, 60 * 60 * 24 * 365]) def test_get_ursulas_python_interface( - mocker, porter, ursulas, timeout, duration, excluded_staker_address_for_duration_greater_than_0, ): - mocker.patch.object(porter, "_get_ursula_version", return_value="7.4.0") - # simple quantity = 4 ursulas_info = porter.get_ursulas(quantity=quantity) @@ -289,16 +286,12 @@ def test_get_ursulas_python_interface( @pytest.mark.parametrize("timeout", [None, 10, 20]) @pytest.mark.parametrize("duration", [None, 0, 60 * 60 * 24, 60 * 60 * 24 * 365]) def test_get_ursulas_web_interface( - mocker, - porter, porter_web_controller, ursulas, timeout, duration, excluded_staker_address_for_duration_greater_than_0, ): - mocker.patch.object(porter, "_get_ursula_version", return_value="7.4.0") - # Send bad data to assert error return response = porter_web_controller.get( "/get_ursulas", data=json.dumps({"bad": "input"}) From f735076d01c1231d481cf94d1fcdd088fded51b4 Mon Sep 17 00:00:00 2001 From: Victoria Zotova Date: Tue, 6 Aug 2024 19:36:26 -0400 Subject: [PATCH 06/12] Fix parsing versions after manual testing --- porter/main.py | 4 ++-- 1 file changed, 2 insertions(+), 2 deletions(-) diff --git a/porter/main.py b/porter/main.py index 90378cc..3bea7a2 100644 --- a/porter/main.py +++ b/porter/main.py @@ -158,13 +158,13 @@ def _initialize_endpoints(eth_endpoint: str, polygon_endpoint: str): @staticmethod def _is_version_greater_or_equal(min_version: str, version: str) -> bool: - return parse(version) >= parse(min_version) + return parse(version) >= min_version def _get_ursula_version(self, ursula: Ursula) -> str: response = self.network_middleware.client.get( node_or_sprout=ursula, path="status", params={"json": "true"} ) - return response["version"] + return response.json()["version"] def get_ursulas( self, From 619ea9421262060af13ad9ac9f1e32a860191fe8 Mon Sep 17 00:00:00 2001 From: Victoria Zotova Date: Thu, 8 Aug 2024 13:37:13 -0400 Subject: [PATCH 07/12] Changing mocking in tests for RestMiddleware and testing versioning --- porter/fields/base.py | 16 +++----- porter/main.py | 21 ++++++++--- tests/conftest.py | 49 +++++++++++++++++++++++- tests/test_bucket_sampling.py | 67 +++++++++++++++++++++++++++++++++ tests/test_get_ursulas.py | 71 +++++++++++++++++++++++++++++++++++ 5 files changed, 206 insertions(+), 18 deletions(-) diff --git a/porter/fields/base.py b/porter/fields/base.py index 8d1ef47..1d0c65f 100644 --- a/porter/fields/base.py +++ b/porter/fields/base.py @@ -3,7 +3,7 @@ import click from marshmallow import fields -from packaging.version import Version, parse +from packaging.version import parse from porter.fields.exceptions import InvalidInputData @@ -113,12 +113,8 @@ def _deserialize(self, value, attr, data, **kwargs): class VersionString(String): - def _serialize(self, value, attr, obj, **kwargs) -> str: - if type(value) is not Version: - raise InvalidInputData( - f"Unexpected object type, {type(value)}; expected Version" - ) - return str(value) - - def _deserialize(self, value, attr, data, **kwargs) -> list: - return parse(value) + def _validate(self, value): + try: + parse(value) + except Exception: + raise InvalidInputData(f"{self.name} must be a correct version.") diff --git a/porter/main.py b/porter/main.py index 3bea7a2..2a5f4d4 100644 --- a/porter/main.py +++ b/porter/main.py @@ -32,7 +32,7 @@ TreasureMap, ) from nucypher_core.umbral import PublicKey -from packaging.version import parse +from packaging.version import Version, parse from prometheus_flask_exporter import PrometheusMetrics import porter @@ -157,7 +157,7 @@ def _initialize_endpoints(eth_endpoint: str, polygon_endpoint: str): BlockchainInterfaceFactory.initialize_interface(endpoint=polygon_endpoint) @staticmethod - def _is_version_greater_or_equal(min_version: str, version: str) -> bool: + def _is_version_greater_or_equal(min_version: Version, version: str) -> bool: return parse(version) >= min_version def _get_ursula_version(self, ursula: Ursula) -> str: @@ -179,6 +179,7 @@ def get_ursulas( "sampling", timeout, self.MAX_GET_URSULAS_TIMEOUT ) duration = duration or 0 + parse_min_version = parse(min_version) if min_version else None reservoir = self._make_reservoir(exclude_ursulas, include_ursulas, duration) available_nodes_to_sample = len(reservoir.values) + len(reservoir.reservoir) @@ -198,8 +199,8 @@ def get_ursula_info(ursula_address) -> Porter.UrsulaInfo: try: # ensure node is up and reachable and check version version = self._get_ursula_version(ursula) - if min_version and not self._is_version_greater_or_equal( - min_version, version + if parse_min_version and not self._is_version_greater_or_equal( + parse_min_version, version ): raise ValueError( f"Ursula ({ursula_address}) has too old version ({version})" @@ -326,6 +327,7 @@ def bucket_sampling( "bucket_sampling", timeout, self.MAX_BUCKET_SAMPLING_TIMEOUT ) duration = duration or 0 + parse_min_version = parse(min_version) if min_version else None if self.domain not in self._ALLOWED_DOMAINS_FOR_BUCKET_SAMPLING: raise ValueError("Bucket sampling is only for TACo Mainnet") @@ -387,6 +389,7 @@ def __init__(self, _reservoir, need_successes: int): self.need_successes = need_successes self.predefined_buckets = self.read_buckets() self.bucketed_nodes = defaultdict(list) + self.selected_nodes = dict() def read_buckets(self) -> Dict: try: @@ -413,6 +416,10 @@ def find_bucket(self, node): return bucket_name return None + def mark_as_not_successful(self, failure: ChecksumAddress): + bucket = self.selected_nodes[failure] + self.bucketed_nodes[bucket].remove(failure) + def __call__(self, _successes: int) -> Optional[List[ChecksumAddress]]: batch = [] batch_size = self.need_successes - _successes @@ -425,6 +432,7 @@ def __call__(self, _successes: int) -> Optional[List[ChecksumAddress]]: if len(self.bucketed_nodes[bucket]) >= self.BUCKET_CAP: continue self.bucketed_nodes[bucket].append(selected) + self.selected_nodes[selected] = bucket batch.append(selected) if not batch: return None @@ -442,8 +450,8 @@ def make_sure_ursula_is_online(ursula_address) -> ChecksumAddress: # ensure node is up and reachable # self.network_middleware.ping(ursula) version = self._get_ursula_version(ursula) - if min_version and not self._is_version_greater_or_equal( - min_version, version + if parse_min_version and not self._is_version_greater_or_equal( + parse_min_version, version ): raise ValueError( f"Ursula ({ursula_address}) has too old version ({version})" @@ -453,6 +461,7 @@ def make_sure_ursula_is_online(ursula_address) -> ChecksumAddress: except Exception as e: message = f"Ursula ({ursula_address}) is unreachable: {str(e)}" self.log.debug(message) + value_factory.mark_as_not_successful(ursula_address) raise self.block_until_number_of_known_nodes_is( diff --git a/tests/conftest.py b/tests/conftest.py index fae7cf7..82f4c0e 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -34,9 +34,11 @@ from tests.constants import ( MOCK_ETH_PROVIDER_URI, TEMPORARY_DOMAIN, + TEST_ETH_PROVIDER_URI, TESTERCHAIN_CHAIN_ID, ) from tests.mock.interfaces import MockBlockchain +from tests.utils.middleware import MockRestMiddleware, _TestMiddlewareClient from tests.utils.registry import MockRegistrySource, mock_registry_sources # Crash on server error by default @@ -245,9 +247,53 @@ def mock_signer(get_random_checksum_address): return signer +class _MockMiddlewareClient(_TestMiddlewareClient): + class MockResponse: + def __init__(self, json_data, status_code): + self.json_data = json_data + self.status_code = status_code + + def json(self): + return self.json_data + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.ursulas_versions = {} + + def get(self, *args, **kwargs): + if kwargs.get("path") == "status" and kwargs.get("params")["json"]: + node_address = kwargs.get("node_or_sprout").checksum_address + version = self.ursulas_versions.get(node_address, "1.1.1") + return _MockMiddlewareClient.MockResponse({"version": version}, 200) + + real_get = super(_TestMiddlewareClient, self).__getattr__("get") + return real_get(*args, **kwargs) + + +class _MockRestMiddleware(MockRestMiddleware): + """ + Modified middleware to emulate returning status with version. + """ + + def __init__(self, *args, **kwargs): + super().__init__(*args, **kwargs) + self.client = _MockMiddlewareClient(eth_endpoint=TEST_ETH_PROVIDER_URI) + + def set_ursulas_versions(self, ursulas_versions: dict): + self.client.ursulas_versions = ursulas_versions + + def clean_ursulas_versions(self): + self.client.ursulas_versions = {} + + +@pytest.fixture(scope="module") +def mock_rest_middleware(): + return _MockRestMiddleware(eth_endpoint=TEST_ETH_PROVIDER_URI) + + @pytest.fixture(scope="module") @pytest.mark.usefixtures('testerchain', 'agency') -def porter(ursulas, mock_rest_middleware, test_registry, module_mocker): +def porter(ursulas, mock_rest_middleware, test_registry): porter = Porter( domain=TEMPORARY_DOMAIN, eth_endpoint=MOCK_ETH_PROVIDER_URI, @@ -259,7 +305,6 @@ def porter(ursulas, mock_rest_middleware, test_registry, module_mocker): verify_node_bonding=False, network_middleware=mock_rest_middleware, ) - module_mocker.patch.object(porter, "_get_ursula_version", return_value="7.4.0") yield porter porter.stop_learning_loop() diff --git a/tests/test_bucket_sampling.py b/tests/test_bucket_sampling.py index 4d63e44..1296314 100644 --- a/tests/test_bucket_sampling.py +++ b/tests/test_bucket_sampling.py @@ -23,6 +23,18 @@ def json(self): mocker.patch("requests.get", return_value=MockRequestResponse()) +@pytest.fixture(autouse=True) +def mock_worker_pool_sleep(): + + original = WorkerPool._sleep + + def _sleep(worker_pool, timeout): + original(worker_pool, 0.01) + pass + + WorkerPool._sleep = _sleep + + def test_bucket_sampling_schema(get_random_checksum_address): # # Input i.e. load @@ -81,6 +93,11 @@ def test_bucket_sampling_schema(get_random_checksum_address): updated_data["timeout"] = 20 BucketSampling().load(updated_data) + # min version + updated_data = dict(required_data) + updated_data["min_version"] = "1.1.1" + BucketSampling().load(updated_data) + # list input formatted as ',' separated strings updated_data = dict(required_data) updated_data["exclude_ursulas"] = ",".join(exclude_ursulas) @@ -133,6 +150,18 @@ def test_bucket_sampling_schema(get_random_checksum_address): updated_data["duration"] = -1 BucketSampling().load(updated_data) + # invalid min version + with pytest.raises(InvalidInputData): + updated_data = dict(required_data) + updated_data["min_version"] = "v1x1.1" + BucketSampling().load(updated_data) + + # invalid min version + with pytest.raises(InvalidInputData): + updated_data = dict(required_data) + updated_data["min_version"] = "1-1-1" + BucketSampling().load(updated_data) + # # Output i.e. dump # @@ -210,11 +239,22 @@ def test_bucket_sampling_python_interface( with pytest.raises(WorkerPool.OutOfValues): _, _ = porter.bucket_sampling(quantity=5) + # no nodes with specified version + with pytest.raises(WorkerPool.OutOfValues): + _, _ = porter.bucket_sampling(quantity=1, timeout=30, min_version="2.2.2") + porter.network_middleware.set_ursulas_versions({sampled_ursulas[0]: "3.0.0"}) + ursulas_info, _ = porter.bucket_sampling(quantity=1, min_version="2.2.2") + assert ursulas_info[0] == sampled_ursulas[0] + with pytest.raises(WorkerPool.OutOfValues): + porter.bucket_sampling(quantity=2, min_version="2.2.2") + porter.network_middleware.clean_ursulas_versions() + @pytest.mark.parametrize("timeout", [None, 10]) @pytest.mark.parametrize("random_seed", [None, 42]) @pytest.mark.parametrize("duration", [None, 0, 60 * 60 * 24, 60 * 60 * 24 * 365]) def test_bucket_sampling_web_interface( + porter, porter_web_controller, ursulas, timeout, @@ -310,3 +350,30 @@ def test_bucket_sampling_web_interface( ) assert response.status_code == 400 assert "Insufficient nodes" in response.text + + # + # Failure case: no nodes with specified version + # + failed_ursula_params = dict(get_ursulas_params) + failed_ursula_params["quantity"] = 1 + failed_ursula_params["min_version"] = "2.0.0" + response = porter_web_controller.get( + "/bucket_sampling", data=json.dumps(failed_ursula_params) + ) + assert "has too old version (1.1.1)" in response.text + + porter.network_middleware.set_ursulas_versions({sampled_ursulas[0]: "3.0.0"}) + response = porter_web_controller.get( + "/bucket_sampling", data=json.dumps(failed_ursula_params) + ) + assert response.status_code == 200 + response_data = json.loads(response.data) + ursulas_info = response_data["result"]["ursulas"] + assert ursulas_info[0] == sampled_ursulas[0] + + failed_ursula_params["quantity"] = 2 + response = porter_web_controller.get( + "/bucket_sampling", data=json.dumps(failed_ursula_params) + ) + assert "has too old version (1.1.1)" in response.text + porter.network_middleware.clean_ursulas_versions() diff --git a/tests/test_get_ursulas.py b/tests/test_get_ursulas.py index bedaa34..10233a1 100644 --- a/tests/test_get_ursulas.py +++ b/tests/test_get_ursulas.py @@ -1,6 +1,7 @@ import json import pytest +from nucypher.utilities.concurrency import WorkerPool from nucypher_core.umbral import SecretKey from porter.fields.exceptions import InvalidArgumentCombo, InvalidInputData @@ -8,6 +9,18 @@ from porter.schema import GetUrsulas, UrsulaInfoSchema +@pytest.fixture(autouse=True) +def mock_worker_pool_sleep(): + + original = WorkerPool._sleep + + def _sleep(worker_pool, timeout): + original(worker_pool, 0.01) + pass + + WorkerPool._sleep = _sleep + + def test_get_ursulas_schema(get_random_checksum_address): # # Input i.e. load @@ -88,6 +101,11 @@ def test_get_ursulas_schema(get_random_checksum_address): assert data["exclude_ursulas"] == [exclude_ursulas[0]] assert data["include_ursulas"] == [include_ursulas[0]] + # min version + updated_data = dict(required_data) + updated_data["min_version"] = "1.1.1" + GetUrsulas().load(updated_data) + # invalid include entry updated_data = dict(required_data) updated_data["exclude_ursulas"] = exclude_ursulas @@ -171,6 +189,18 @@ def test_get_ursulas_schema(get_random_checksum_address): updated_data["duration"] = -1 GetUrsulas().load(updated_data) + # invalid min version + with pytest.raises(InvalidInputData): + updated_data = dict(required_data) + updated_data["min_version"] = "v1x1.1" + GetUrsulas().load(updated_data) + + # invalid min version + with pytest.raises(InvalidInputData): + updated_data = dict(required_data) + updated_data["min_version"] = "1-1-1" + GetUrsulas().load(updated_data) + # # Output i.e. dump # @@ -282,10 +312,23 @@ def test_get_ursulas_python_interface( with pytest.raises(ValueError, match="Insufficient nodes"): porter.get_ursulas(quantity=len(ursulas) + 1) + # no nodes with specified version + with pytest.raises(WorkerPool.OutOfValues): + porter.get_ursulas(quantity=1, min_version="2.2.2") + porter.network_middleware.set_ursulas_versions( + {ursulas[0].checksum_address: "3.0.0"} + ) + ursulas_info = porter.get_ursulas(quantity=1, min_version="2.2.2") + assert ursulas[0].checksum_address == ursulas_info[0].checksum_address + with pytest.raises(WorkerPool.OutOfValues): + porter.get_ursulas(quantity=2, min_version="2.2.2") + porter.network_middleware.clean_ursulas_versions() + @pytest.mark.parametrize("timeout", [None, 10, 20]) @pytest.mark.parametrize("duration", [None, 0, 60 * 60 * 24, 60 * 60 * 24 * 365]) def test_get_ursulas_web_interface( + porter, porter_web_controller, ursulas, timeout, @@ -388,3 +431,31 @@ def test_get_ursulas_web_interface( ) assert response.status_code == 400 assert "Insufficient nodes" in response.text + + # + # Failure case: no nodes with specified version + # + failed_ursula_params = dict(get_ursulas_params) + failed_ursula_params["quantity"] = 1 + failed_ursula_params["min_version"] = "2.0.0" + del failed_ursula_params["include_ursulas"] + response = porter_web_controller.get( + "/get_ursulas", data=json.dumps(failed_ursula_params) + ) + assert "has too old version (1.1.1)" in response.text + + porter.network_middleware.set_ursulas_versions({include_ursulas[0]: "3.0.0"}) + response = porter_web_controller.get( + "/get_ursulas", data=json.dumps(failed_ursula_params) + ) + assert response.status_code == 200 + response_data = json.loads(response.data) + ursulas_info = response_data["result"]["ursulas"] + assert ursulas_info[0]["checksum_address"] == include_ursulas[0] + + failed_ursula_params["quantity"] = 2 + response = porter_web_controller.get( + "/get_ursulas", data=json.dumps(failed_ursula_params) + ) + assert "has too old version (1.1.1)" in response.text + porter.network_middleware.clean_ursulas_versions() From 2c9a5e924d67ce2b9e3b893d94eb7b88a70fcc1a Mon Sep 17 00:00:00 2001 From: derekpierre Date: Thu, 8 Aug 2024 16:44:41 -0400 Subject: [PATCH 08/12] Copy dict for safety. --- tests/conftest.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/tests/conftest.py b/tests/conftest.py index 82f4c0e..6c47cc2 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -280,7 +280,7 @@ def __init__(self, *args, **kwargs): self.client = _MockMiddlewareClient(eth_endpoint=TEST_ETH_PROVIDER_URI) def set_ursulas_versions(self, ursulas_versions: dict): - self.client.ursulas_versions = ursulas_versions + self.client.ursulas_versions = dict(ursulas_versions) def clean_ursulas_versions(self): self.client.ursulas_versions = {} From 1c2200143fbeb895b5218bb93e7c1a0502982c8d Mon Sep 17 00:00:00 2001 From: derekpierre Date: Thu, 8 Aug 2024 16:45:24 -0400 Subject: [PATCH 09/12] Use monkeypatch for WorkerPool._sleep override. --- tests/test_bucket_sampling.py | 6 ++---- tests/test_get_ursulas.py | 6 ++---- 2 files changed, 4 insertions(+), 8 deletions(-) diff --git a/tests/test_bucket_sampling.py b/tests/test_bucket_sampling.py index 1296314..61b24f5 100644 --- a/tests/test_bucket_sampling.py +++ b/tests/test_bucket_sampling.py @@ -24,15 +24,13 @@ def json(self): @pytest.fixture(autouse=True) -def mock_worker_pool_sleep(): - +def mock_worker_pool_sleep(monkeypatch): original = WorkerPool._sleep def _sleep(worker_pool, timeout): original(worker_pool, 0.01) - pass - WorkerPool._sleep = _sleep + monkeypatch.setattr(WorkerPool, "_sleep", _sleep) def test_bucket_sampling_schema(get_random_checksum_address): diff --git a/tests/test_get_ursulas.py b/tests/test_get_ursulas.py index 10233a1..120eeee 100644 --- a/tests/test_get_ursulas.py +++ b/tests/test_get_ursulas.py @@ -10,15 +10,13 @@ @pytest.fixture(autouse=True) -def mock_worker_pool_sleep(): - +def mock_worker_pool_sleep(monkeypatch): original = WorkerPool._sleep def _sleep(worker_pool, timeout): original(worker_pool, 0.01) - pass - WorkerPool._sleep = _sleep + monkeypatch.setattr(WorkerPool, "_sleep", _sleep) def test_get_ursulas_schema(get_random_checksum_address): From bd537c6acf26cd24c1a49b213a6fea34d306373e Mon Sep 17 00:00:00 2001 From: derekpierre Date: Thu, 8 Aug 2024 16:50:54 -0400 Subject: [PATCH 10/12] Use custom exception for Ursula having a version that is too old. --- porter/main.py | 19 ++++++++++--------- tests/test_bucket_sampling.py | 10 ++++++++-- tests/test_get_ursulas.py | 10 ++++++++-- 3 files changed, 26 insertions(+), 13 deletions(-) diff --git a/porter/main.py b/porter/main.py index 2a5f4d4..22a104f 100644 --- a/porter/main.py +++ b/porter/main.py @@ -101,6 +101,12 @@ class DecryptOutcome(NamedTuple): ] errors: Dict[ChecksumAddress, str] + class UrsulaVersionTooOld(Exception): + def __init__(self, ursula_address: str, version: str, min_version: str): + super().__init__( + f"Ursula ({ursula_address}) version is too old ({version} < {min_version})" + ) + def __init__( self, eth_endpoint: str, @@ -197,14 +203,12 @@ def get_ursula_info(ursula_address) -> Porter.UrsulaInfo: ursula_address = to_checksum_address(ursula_address) ursula = self.known_nodes[ursula_address] try: - # ensure node is up and reachable and check version + # ensure node is up and reachable and possibly check version version = self._get_ursula_version(ursula) if parse_min_version and not self._is_version_greater_or_equal( parse_min_version, version ): - raise ValueError( - f"Ursula ({ursula_address}) has too old version ({version})" - ) + raise self.UrsulaVersionTooOld(ursula_address, version, min_version) return Porter.UrsulaInfo( checksum_address=ursula_address, @@ -447,15 +451,12 @@ def make_sure_ursula_is_online(ursula_address) -> ChecksumAddress: ursula_address = to_checksum_address(ursula_address) ursula = self.known_nodes[ursula_address] try: - # ensure node is up and reachable - # self.network_middleware.ping(ursula) + # ensure node is up and reachable and possibly check version version = self._get_ursula_version(ursula) if parse_min_version and not self._is_version_greater_or_equal( parse_min_version, version ): - raise ValueError( - f"Ursula ({ursula_address}) has too old version ({version})" - ) + raise self.UrsulaVersionTooOld(ursula_address, version, min_version) return ursula_address except Exception as e: diff --git a/tests/test_bucket_sampling.py b/tests/test_bucket_sampling.py index 61b24f5..7b257d2 100644 --- a/tests/test_bucket_sampling.py +++ b/tests/test_bucket_sampling.py @@ -358,7 +358,10 @@ def test_bucket_sampling_web_interface( response = porter_web_controller.get( "/bucket_sampling", data=json.dumps(failed_ursula_params) ) - assert "has too old version (1.1.1)" in response.text + assert ( + f"version is too old (1.1.1 < {failed_ursula_params['min_version']})" + in response.text + ) porter.network_middleware.set_ursulas_versions({sampled_ursulas[0]: "3.0.0"}) response = porter_web_controller.get( @@ -373,5 +376,8 @@ def test_bucket_sampling_web_interface( response = porter_web_controller.get( "/bucket_sampling", data=json.dumps(failed_ursula_params) ) - assert "has too old version (1.1.1)" in response.text + assert ( + f"version is too old (1.1.1 < {failed_ursula_params['min_version']})" + in response.text + ) porter.network_middleware.clean_ursulas_versions() diff --git a/tests/test_get_ursulas.py b/tests/test_get_ursulas.py index 120eeee..294100b 100644 --- a/tests/test_get_ursulas.py +++ b/tests/test_get_ursulas.py @@ -440,7 +440,10 @@ def test_get_ursulas_web_interface( response = porter_web_controller.get( "/get_ursulas", data=json.dumps(failed_ursula_params) ) - assert "has too old version (1.1.1)" in response.text + assert ( + f"version is too old (1.1.1 < {failed_ursula_params['min_version']})" + in response.text + ) porter.network_middleware.set_ursulas_versions({include_ursulas[0]: "3.0.0"}) response = porter_web_controller.get( @@ -455,5 +458,8 @@ def test_get_ursulas_web_interface( response = porter_web_controller.get( "/get_ursulas", data=json.dumps(failed_ursula_params) ) - assert "has too old version (1.1.1)" in response.text + assert ( + f"version is too old (1.1.1 < {failed_ursula_params['min_version']})" + in response.text + ) porter.network_middleware.clean_ursulas_versions() From 78f0d93a270a6731d305025a806aa70e1810fee7 Mon Sep 17 00:00:00 2001 From: derekpierre Date: Fri, 9 Aug 2024 08:18:08 -0400 Subject: [PATCH 11/12] Fix bug where node doesn't have a bucket when we try to mark it as unsuccessful. Annotate what node dictionaries represent. --- porter/main.py | 13 ++++++++----- 1 file changed, 8 insertions(+), 5 deletions(-) diff --git a/porter/main.py b/porter/main.py index 22a104f..767a69a 100644 --- a/porter/main.py +++ b/porter/main.py @@ -392,8 +392,10 @@ def __init__(self, _reservoir, need_successes: int): self.reservoir = _reservoir self.need_successes = need_successes self.predefined_buckets = self.read_buckets() - self.bucketed_nodes = defaultdict(list) - self.selected_nodes = dict() + self.bucketed_nodes = defaultdict( + list + ) # -> + self.selected_nodes = dict() # -> def read_buckets(self) -> Dict: try: @@ -420,9 +422,10 @@ def find_bucket(self, node): return bucket_name return None - def mark_as_not_successful(self, failure: ChecksumAddress): - bucket = self.selected_nodes[failure] - self.bucketed_nodes[bucket].remove(failure) + def mark_as_not_successful(self, unsuccessful_node: ChecksumAddress): + bucket = self.selected_nodes.get(unsuccessful_node) + if bucket: + self.bucketed_nodes[bucket].remove(unsuccessful_node) def __call__(self, _successes: int) -> Optional[List[ChecksumAddress]]: batch = [] From 702334ad7f3777abe3ca97ec62e4eb3fc02f6953 Mon Sep 17 00:00:00 2001 From: Victoria Zotova Date: Fri, 9 Aug 2024 09:57:18 -0400 Subject: [PATCH 12/12] Added min_version parameter for /get_ursulas in README --- README.rst | 3 +++ 1 file changed, 3 insertions(+) diff --git a/README.rst b/README.rst index 27c3f18..a26bf08 100644 --- a/README.rst +++ b/README.rst @@ -339,6 +339,9 @@ Parameters | | | | are greater than this max default value are | | | | | capped at the default value | +----------------------------------+------------------+------------------------------------------------+ +| ``min_version`` | *(Optional)* | | Minimum acceptable version of Ursula. | +| | VersionString | | | ++----------------------------------+------------------+------------------------------------------------+ Returns