Skip to content

Commit

Permalink
Changing mocking in tests for RestMiddleware and testing versioning
Browse files Browse the repository at this point in the history
  • Loading branch information
vzotova committed Aug 8, 2024
1 parent f735076 commit 619ea94
Show file tree
Hide file tree
Showing 5 changed files with 206 additions and 18 deletions.
16 changes: 6 additions & 10 deletions porter/fields/base.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@

import click
from marshmallow import fields
from packaging.version import Version, parse
from packaging.version import parse

from porter.fields.exceptions import InvalidInputData

Expand Down Expand Up @@ -113,12 +113,8 @@ def _deserialize(self, value, attr, data, **kwargs):

class VersionString(String):

def _serialize(self, value, attr, obj, **kwargs) -> str:
if type(value) is not Version:
raise InvalidInputData(
f"Unexpected object type, {type(value)}; expected Version"
)
return str(value)

def _deserialize(self, value, attr, data, **kwargs) -> list:
return parse(value)
def _validate(self, value):
try:
parse(value)
except Exception:
raise InvalidInputData(f"{self.name} must be a correct version.")
21 changes: 15 additions & 6 deletions porter/main.py
Original file line number Diff line number Diff line change
Expand Up @@ -32,7 +32,7 @@
TreasureMap,
)
from nucypher_core.umbral import PublicKey
from packaging.version import parse
from packaging.version import Version, parse
from prometheus_flask_exporter import PrometheusMetrics

import porter
Expand Down Expand Up @@ -157,7 +157,7 @@ def _initialize_endpoints(eth_endpoint: str, polygon_endpoint: str):
BlockchainInterfaceFactory.initialize_interface(endpoint=polygon_endpoint)

@staticmethod
def _is_version_greater_or_equal(min_version: str, version: str) -> bool:
def _is_version_greater_or_equal(min_version: Version, version: str) -> bool:
return parse(version) >= min_version

def _get_ursula_version(self, ursula: Ursula) -> str:
Expand All @@ -179,6 +179,7 @@ def get_ursulas(
"sampling", timeout, self.MAX_GET_URSULAS_TIMEOUT
)
duration = duration or 0
parse_min_version = parse(min_version) if min_version else None

reservoir = self._make_reservoir(exclude_ursulas, include_ursulas, duration)
available_nodes_to_sample = len(reservoir.values) + len(reservoir.reservoir)
Expand All @@ -198,8 +199,8 @@ def get_ursula_info(ursula_address) -> Porter.UrsulaInfo:
try:
# ensure node is up and reachable and check version
version = self._get_ursula_version(ursula)
if min_version and not self._is_version_greater_or_equal(
min_version, version
if parse_min_version and not self._is_version_greater_or_equal(
parse_min_version, version
):
raise ValueError(
f"Ursula ({ursula_address}) has too old version ({version})"
Expand Down Expand Up @@ -326,6 +327,7 @@ def bucket_sampling(
"bucket_sampling", timeout, self.MAX_BUCKET_SAMPLING_TIMEOUT
)
duration = duration or 0
parse_min_version = parse(min_version) if min_version else None

if self.domain not in self._ALLOWED_DOMAINS_FOR_BUCKET_SAMPLING:
raise ValueError("Bucket sampling is only for TACo Mainnet")
Expand Down Expand Up @@ -387,6 +389,7 @@ def __init__(self, _reservoir, need_successes: int):
self.need_successes = need_successes
self.predefined_buckets = self.read_buckets()
self.bucketed_nodes = defaultdict(list)
self.selected_nodes = dict()

def read_buckets(self) -> Dict:
try:
Expand All @@ -413,6 +416,10 @@ def find_bucket(self, node):
return bucket_name
return None

def mark_as_not_successful(self, failure: ChecksumAddress):
bucket = self.selected_nodes[failure]
self.bucketed_nodes[bucket].remove(failure)

def __call__(self, _successes: int) -> Optional[List[ChecksumAddress]]:
batch = []
batch_size = self.need_successes - _successes
Expand All @@ -425,6 +432,7 @@ def __call__(self, _successes: int) -> Optional[List[ChecksumAddress]]:
if len(self.bucketed_nodes[bucket]) >= self.BUCKET_CAP:
continue
self.bucketed_nodes[bucket].append(selected)
self.selected_nodes[selected] = bucket
batch.append(selected)
if not batch:
return None
Expand All @@ -442,8 +450,8 @@ def make_sure_ursula_is_online(ursula_address) -> ChecksumAddress:
# ensure node is up and reachable
# self.network_middleware.ping(ursula)
version = self._get_ursula_version(ursula)
if min_version and not self._is_version_greater_or_equal(
min_version, version
if parse_min_version and not self._is_version_greater_or_equal(
parse_min_version, version
):
raise ValueError(
f"Ursula ({ursula_address}) has too old version ({version})"
Expand All @@ -453,6 +461,7 @@ def make_sure_ursula_is_online(ursula_address) -> ChecksumAddress:
except Exception as e:
message = f"Ursula ({ursula_address}) is unreachable: {str(e)}"
self.log.debug(message)
value_factory.mark_as_not_successful(ursula_address)
raise

self.block_until_number_of_known_nodes_is(
Expand Down
49 changes: 47 additions & 2 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -34,9 +34,11 @@
from tests.constants import (
MOCK_ETH_PROVIDER_URI,
TEMPORARY_DOMAIN,
TEST_ETH_PROVIDER_URI,
TESTERCHAIN_CHAIN_ID,
)
from tests.mock.interfaces import MockBlockchain
from tests.utils.middleware import MockRestMiddleware, _TestMiddlewareClient
from tests.utils.registry import MockRegistrySource, mock_registry_sources

# Crash on server error by default
Expand Down Expand Up @@ -245,9 +247,53 @@ def mock_signer(get_random_checksum_address):
return signer


class _MockMiddlewareClient(_TestMiddlewareClient):
class MockResponse:
def __init__(self, json_data, status_code):
self.json_data = json_data
self.status_code = status_code

def json(self):
return self.json_data

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.ursulas_versions = {}

def get(self, *args, **kwargs):
if kwargs.get("path") == "status" and kwargs.get("params")["json"]:
node_address = kwargs.get("node_or_sprout").checksum_address
version = self.ursulas_versions.get(node_address, "1.1.1")
return _MockMiddlewareClient.MockResponse({"version": version}, 200)

real_get = super(_TestMiddlewareClient, self).__getattr__("get")
return real_get(*args, **kwargs)


class _MockRestMiddleware(MockRestMiddleware):
"""
Modified middleware to emulate returning status with version.
"""

def __init__(self, *args, **kwargs):
super().__init__(*args, **kwargs)
self.client = _MockMiddlewareClient(eth_endpoint=TEST_ETH_PROVIDER_URI)

def set_ursulas_versions(self, ursulas_versions: dict):
self.client.ursulas_versions = ursulas_versions

def clean_ursulas_versions(self):
self.client.ursulas_versions = {}


@pytest.fixture(scope="module")
def mock_rest_middleware():
return _MockRestMiddleware(eth_endpoint=TEST_ETH_PROVIDER_URI)


@pytest.fixture(scope="module")
@pytest.mark.usefixtures('testerchain', 'agency')
def porter(ursulas, mock_rest_middleware, test_registry, module_mocker):
def porter(ursulas, mock_rest_middleware, test_registry):
porter = Porter(
domain=TEMPORARY_DOMAIN,
eth_endpoint=MOCK_ETH_PROVIDER_URI,
Expand All @@ -259,7 +305,6 @@ def porter(ursulas, mock_rest_middleware, test_registry, module_mocker):
verify_node_bonding=False,
network_middleware=mock_rest_middleware,
)
module_mocker.patch.object(porter, "_get_ursula_version", return_value="7.4.0")
yield porter
porter.stop_learning_loop()

Expand Down
67 changes: 67 additions & 0 deletions tests/test_bucket_sampling.py
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,18 @@ def json(self):
mocker.patch("requests.get", return_value=MockRequestResponse())


@pytest.fixture(autouse=True)
def mock_worker_pool_sleep():

original = WorkerPool._sleep

def _sleep(worker_pool, timeout):
original(worker_pool, 0.01)
pass

WorkerPool._sleep = _sleep


def test_bucket_sampling_schema(get_random_checksum_address):
#
# Input i.e. load
Expand Down Expand Up @@ -81,6 +93,11 @@ def test_bucket_sampling_schema(get_random_checksum_address):
updated_data["timeout"] = 20
BucketSampling().load(updated_data)

# min version
updated_data = dict(required_data)
updated_data["min_version"] = "1.1.1"
BucketSampling().load(updated_data)

# list input formatted as ',' separated strings
updated_data = dict(required_data)
updated_data["exclude_ursulas"] = ",".join(exclude_ursulas)
Expand Down Expand Up @@ -133,6 +150,18 @@ def test_bucket_sampling_schema(get_random_checksum_address):
updated_data["duration"] = -1
BucketSampling().load(updated_data)

# invalid min version
with pytest.raises(InvalidInputData):
updated_data = dict(required_data)
updated_data["min_version"] = "v1x1.1"
BucketSampling().load(updated_data)

# invalid min version
with pytest.raises(InvalidInputData):
updated_data = dict(required_data)
updated_data["min_version"] = "1-1-1"
BucketSampling().load(updated_data)

#
# Output i.e. dump
#
Expand Down Expand Up @@ -210,11 +239,22 @@ def test_bucket_sampling_python_interface(
with pytest.raises(WorkerPool.OutOfValues):
_, _ = porter.bucket_sampling(quantity=5)

# no nodes with specified version
with pytest.raises(WorkerPool.OutOfValues):
_, _ = porter.bucket_sampling(quantity=1, timeout=30, min_version="2.2.2")
porter.network_middleware.set_ursulas_versions({sampled_ursulas[0]: "3.0.0"})
ursulas_info, _ = porter.bucket_sampling(quantity=1, min_version="2.2.2")
assert ursulas_info[0] == sampled_ursulas[0]
with pytest.raises(WorkerPool.OutOfValues):
porter.bucket_sampling(quantity=2, min_version="2.2.2")
porter.network_middleware.clean_ursulas_versions()


@pytest.mark.parametrize("timeout", [None, 10])
@pytest.mark.parametrize("random_seed", [None, 42])
@pytest.mark.parametrize("duration", [None, 0, 60 * 60 * 24, 60 * 60 * 24 * 365])
def test_bucket_sampling_web_interface(
porter,
porter_web_controller,
ursulas,
timeout,
Expand Down Expand Up @@ -310,3 +350,30 @@ def test_bucket_sampling_web_interface(
)
assert response.status_code == 400
assert "Insufficient nodes" in response.text

#
# Failure case: no nodes with specified version
#
failed_ursula_params = dict(get_ursulas_params)
failed_ursula_params["quantity"] = 1
failed_ursula_params["min_version"] = "2.0.0"
response = porter_web_controller.get(
"/bucket_sampling", data=json.dumps(failed_ursula_params)
)
assert "has too old version (1.1.1)" in response.text

porter.network_middleware.set_ursulas_versions({sampled_ursulas[0]: "3.0.0"})
response = porter_web_controller.get(
"/bucket_sampling", data=json.dumps(failed_ursula_params)
)
assert response.status_code == 200
response_data = json.loads(response.data)
ursulas_info = response_data["result"]["ursulas"]
assert ursulas_info[0] == sampled_ursulas[0]

failed_ursula_params["quantity"] = 2
response = porter_web_controller.get(
"/bucket_sampling", data=json.dumps(failed_ursula_params)
)
assert "has too old version (1.1.1)" in response.text
porter.network_middleware.clean_ursulas_versions()
Loading

0 comments on commit 619ea94

Please sign in to comment.