diff --git a/porter/interfaces.py b/porter/interfaces.py index 894b876..a42b16e 100644 --- a/porter/interfaces.py +++ b/porter/interfaces.py @@ -93,3 +93,21 @@ def decrypt( ) response_data = {"decryption_results": decrypt_outcome} return response_data + + @attach_schema(schema.BucketSampling) + def bucket_sampling( + self, + quantity: int, + random_seed: Optional[int] = None, + exclude_ursulas: Optional[List[ChecksumAddress]] = None, + timeout: Optional[int] = None, + ) -> Dict: + ursulas_info, block_number = self.implementer.bucket_sampling( + quantity=quantity, + random_seed=random_seed, + exclude_ursulas=exclude_ursulas, + timeout=timeout, + ) + + response_data = {"ursulas": ursulas_info, "block_number": block_number} + return response_data diff --git a/porter/main.py b/porter/main.py index 3beb5a7..f84ddcc 100644 --- a/porter/main.py +++ b/porter/main.py @@ -1,6 +1,9 @@ import os + +from collections import defaultdict from pathlib import Path -from typing import Dict, List, NamedTuple, Optional, Sequence, Union +from random import Random +from typing import Dict, List, NamedTuple, Optional, Sequence, Tuple, Union from constant_sorrow.constants import NO_CONTROL_PROTOCOL from eth_typing import ChecksumAddress @@ -10,7 +13,7 @@ ContractAgency, TACoChildApplicationAgent, ) -from nucypher.blockchain.eth.domains import DEFAULT_DOMAIN, TACoDomain +from nucypher.blockchain.eth.domains import DEFAULT_DOMAIN, MAINNET, TACoDomain from nucypher.blockchain.eth.interfaces import BlockchainInterfaceFactory from nucypher.blockchain.eth.registry import ContractRegistry from nucypher.characters.lawful import Ursula @@ -278,6 +281,130 @@ def _make_reservoir( include_addresses=include_ursulas, ) + def bucket_sampling( + self, + quantity: int, + random_seed: Optional[int] = None, + exclude_ursulas: Optional[Sequence[ChecksumAddress]] = None, + timeout: Optional[int] = None, + ) -> Tuple[List[UrsulaInfo], int]: + timeout = self._configure_timeout( + "sampling", timeout, self.MAX_GET_URSULAS_TIMEOUT + ) + + if self.domain != MAINNET: + raise ValueError("Bucket sampling is only for TACo Mainnet") + + class BucketStakingProvidersReservoir: + def __init__( + self, + staking_provider_map: Dict[ChecksumAddress, int], + seed: Optional[int] = None + ): + self._providers = list(staking_provider_map.keys()) + self._rng = Random(seed) + + def __len__(self): + return len(self._providers) + + def draw(self, _quantity): + if _quantity > len(self._providers): + raise ValueError( + f"Cannot sample {_quantity} out of {len(self._providers)} total staking providers" + ) + return self._rng.sample(self._providers, k=_quantity) + + def __call__(self) -> Optional[ChecksumAddress]: + if len(self._providers) > 0: + return self.draw(1)[0] + else: + return None + + block_number = self.taco_child_application_agent.blockchain.client.block_number + _, sp_map = self.taco_child_application_agent.get_all_active_staking_providers() + for e in exclude_ursulas or []: + if e in sp_map: + del sp_map[e] + + if len(sp_map) < quantity: + raise ValueError( + f"Insufficient nodes ({len(sp_map)}) from which to sample {quantity}" + ) + + reservoir = BucketStakingProvidersReservoir(sp_map, random_seed) + + class BucketPrefetchStrategy: + BUCKET_CAP = 2 + + def __init__(self, _reservoir, need_successes: int): + self.reservoir = _reservoir + self.need_successes = need_successes + self.predefined_buckets = {} + self.bucketed_nodes = defaultdict(list) + + def find_bucket(self, node): + for bucket_name, bucket in self.predefined_buckets.items(): + if node in bucket: + return bucket_name + return None + + def __call__(self, _successes: int) -> Optional[List[ChecksumAddress]]: + batch = [] + batch_size = self.need_successes - _successes + while len(batch) < batch_size: + selected = self.reservoir() + if selected is None: + break + bucket = self.find_bucket(selected) + if bucket: + if len(self.bucketed_nodes[bucket]) >= self.BUCKET_CAP: + continue + self.bucketed_nodes[bucket].append(selected) + batch.append(selected) + if not batch: + return None + return batch + + value_factory = BucketPrefetchStrategy(reservoir, quantity) + + # TODO: same function that in get_ursulas + def get_ursula_info(ursula_address) -> Porter.UrsulaInfo: + if to_checksum_address(ursula_address) not in self.known_nodes: + raise ValueError(f"{ursula_address} is not known") + + ursula_address = to_checksum_address(ursula_address) + ursula = self.known_nodes[ursula_address] + try: + # ensure node is up and reachable + self.network_middleware.ping(ursula) + return Porter.UrsulaInfo(checksum_address=ursula_address, + uri=f"{ursula.rest_interface.formal_uri}", + encrypting_key=ursula.public_keys(DecryptingPower)) + except Exception as e: + self.log.debug(f"Ursula ({ursula_address}) is unreachable: {str(e)}") + raise + + self.block_until_number_of_known_nodes_is( + quantity, timeout=timeout, learn_on_this_thread=True, eager=True + ) + + worker_pool = WorkerPool( + worker=get_ursula_info, + value_factory=value_factory, + target_successes=quantity, + timeout=timeout, + stagger_timeout=1, + ) + worker_pool.start() + try: + successes = worker_pool.block_until_target_successes() + finally: + worker_pool.cancel() + # don't wait for it to stop by "joining" - too slow... + + ursulas_info = successes.values() + return list(ursulas_info), block_number + def make_cli_controller(self, crash_on_error: bool = False): controller = PorterCLIController(app_name=self.APP_NAME, crash_on_error=crash_on_error, @@ -363,4 +490,11 @@ def decrypt() -> Response: response = controller(method_name="decrypt", control_request=request) return response + @porter_flask_control.route('/bucket_sampling', methods=['GET']) + @by_path_counter + def bucket_sampling() -> Response: + """Porter control endpoint for sampling Ursulas with provider caps (a.k.a. bucket sampling)""" + response = controller(method_name='bucket_sampling', control_request=request) + return response + return controller diff --git a/porter/schema.py b/porter/schema.py index cecb9d7..90b21d1 100644 --- a/porter/schema.py +++ b/porter/schema.py @@ -3,7 +3,7 @@ from marshmallow import fields as marshmallow_fields from porter.cli.types import EIP55_CHECKSUM_ADDRESS -from porter.fields.base import JSON, PositiveInteger, StringList +from porter.fields.base import Integer, JSON, PositiveInteger, StringList from porter.fields.exceptions import InvalidArgumentCombo, InvalidInputData from porter.fields.retrieve import CapsuleFrag, RetrievalKit from porter.fields.taco import ( @@ -285,3 +285,55 @@ def check_valid_threshold_and_requests(self, data, **kwargs): raise InvalidArgumentCombo( "Number of provided requests must be >= the expected threshold" ) + + +class BucketSampling(BaseSchema): + quantity = PositiveInteger( + required=True, + load_only=True, + click=click.option( + '--quantity', + '-n', + help="Total number of Ursulas needed", + type=click.INT, required=True)) + + # optional + random_seed = Integer( + required=False, + load_only=True, + click=click.option( + "--seed", + help="Random seed for sampling", + type=click.INT, + required=False, + ), + ) + + exclude_ursulas = StringList( + UrsulaChecksumAddress(), + click=click.option( + '--exclude-ursula', + '-e', + help="Ursula checksum address to exclude from sample", + multiple=True, + type=EIP55_CHECKSUM_ADDRESS, + required=False, + default=[]), + required=False, + load_only=True) + + timeout = PositiveInteger( + required=False, + load_only=True, + click=click.option( + "--timeout", + "-t", + help="Timeout for getting the required quantity of ursulas", + type=click.INT, + required=False, + ), + ) + + # output + ursulas = marshmallow_fields.List(marshmallow_fields.Nested(UrsulaInfoSchema), dump_only=True) + block_number = marshmallow_fields.Int()