Skip to content

Commit

Permalink
First pass at bucket sampling on Porter
Browse files Browse the repository at this point in the history
  • Loading branch information
cygnusv committed Jan 17, 2024
1 parent 3b51fd6 commit 2d02642
Show file tree
Hide file tree
Showing 3 changed files with 207 additions and 3 deletions.
18 changes: 18 additions & 0 deletions porter/interfaces.py
Original file line number Diff line number Diff line change
Expand Up @@ -93,3 +93,21 @@ def decrypt(
)
response_data = {"decryption_results": decrypt_outcome}
return response_data

@attach_schema(schema.BucketSampling)
def bucket_sampling(
self,
quantity: int,
random_seed: Optional[int] = None,
exclude_ursulas: Optional[List[ChecksumAddress]] = None,
timeout: Optional[int] = None,
) -> Dict:
ursulas_info, block_number = self.implementer.bucket_sampling(
quantity=quantity,
random_seed=random_seed,
exclude_ursulas=exclude_ursulas,
timeout=timeout,
)

response_data = {"ursulas": ursulas_info, "block_number": block_number}
return response_data
138 changes: 136 additions & 2 deletions porter/main.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,9 @@
import os

from collections import defaultdict
from pathlib import Path
from typing import Dict, List, NamedTuple, Optional, Sequence, Union
from random import Random
from typing import Dict, List, NamedTuple, Optional, Sequence, Tuple, Union

from constant_sorrow.constants import NO_CONTROL_PROTOCOL
from eth_typing import ChecksumAddress
Expand All @@ -10,7 +13,7 @@
ContractAgency,
TACoChildApplicationAgent,
)
from nucypher.blockchain.eth.domains import DEFAULT_DOMAIN, TACoDomain
from nucypher.blockchain.eth.domains import DEFAULT_DOMAIN, MAINNET, TACoDomain
from nucypher.blockchain.eth.interfaces import BlockchainInterfaceFactory
from nucypher.blockchain.eth.registry import ContractRegistry
from nucypher.characters.lawful import Ursula
Expand Down Expand Up @@ -278,6 +281,130 @@ def _make_reservoir(
include_addresses=include_ursulas,
)

def bucket_sampling(
self,
quantity: int,
random_seed: Optional[int] = None,
exclude_ursulas: Optional[Sequence[ChecksumAddress]] = None,
timeout: Optional[int] = None,
) -> Tuple[List[UrsulaInfo], int]:
timeout = self._configure_timeout(
"sampling", timeout, self.MAX_GET_URSULAS_TIMEOUT
)

if self.domain != MAINNET:
raise ValueError("Bucket sampling is only for TACo Mainnet")

class BucketStakingProvidersReservoir:
def __init__(
self,
staking_provider_map: Dict[ChecksumAddress, int],
seed: Optional[int] = None
):
self._providers = list(staking_provider_map.keys())
self._rng = Random(seed)

def __len__(self):
return len(self._providers)

def draw(self, _quantity):
if _quantity > len(self._providers):
raise ValueError(
f"Cannot sample {_quantity} out of {len(self._providers)} total staking providers"
)
return self._rng.sample(self._providers, k=_quantity)

def __call__(self) -> Optional[ChecksumAddress]:
if len(self._providers) > 0:
return self.draw(1)[0]
else:
return None

block_number = self.taco_child_application_agent.blockchain.client.block_number
_, sp_map = self.taco_child_application_agent.get_all_active_staking_providers()
for e in exclude_ursulas or []:
if e in sp_map:
del sp_map[e]

if len(sp_map) < quantity:
raise ValueError(
f"Insufficient nodes ({len(sp_map)}) from which to sample {quantity}"
)

reservoir = BucketStakingProvidersReservoir(sp_map, random_seed)

class BucketPrefetchStrategy:
BUCKET_CAP = 2

def __init__(self, _reservoir, need_successes: int):
self.reservoir = _reservoir
self.need_successes = need_successes
self.predefined_buckets = {}
self.bucketed_nodes = defaultdict(list)

def find_bucket(self, node):
for bucket_name, bucket in self.predefined_buckets.items():
if node in bucket:
return bucket_name
return None

def __call__(self, _successes: int) -> Optional[List[ChecksumAddress]]:
batch = []
batch_size = self.need_successes - _successes
while len(batch) < batch_size:
selected = self.reservoir()
if selected is None:
break
bucket = self.find_bucket(selected)
if bucket:
if len(self.bucketed_nodes[bucket]) >= self.BUCKET_CAP:
continue
self.bucketed_nodes[bucket].append(selected)
batch.append(selected)
if not batch:
return None
return batch

value_factory = BucketPrefetchStrategy(reservoir, quantity)

# TODO: same function that in get_ursulas
def get_ursula_info(ursula_address) -> Porter.UrsulaInfo:
if to_checksum_address(ursula_address) not in self.known_nodes:
raise ValueError(f"{ursula_address} is not known")

ursula_address = to_checksum_address(ursula_address)
ursula = self.known_nodes[ursula_address]
try:
# ensure node is up and reachable
self.network_middleware.ping(ursula)
return Porter.UrsulaInfo(checksum_address=ursula_address,
uri=f"{ursula.rest_interface.formal_uri}",
encrypting_key=ursula.public_keys(DecryptingPower))
except Exception as e:
self.log.debug(f"Ursula ({ursula_address}) is unreachable: {str(e)}")
raise

self.block_until_number_of_known_nodes_is(
quantity, timeout=timeout, learn_on_this_thread=True, eager=True
)

worker_pool = WorkerPool(
worker=get_ursula_info,
value_factory=value_factory,
target_successes=quantity,
timeout=timeout,
stagger_timeout=1,
)
worker_pool.start()
try:
successes = worker_pool.block_until_target_successes()
finally:
worker_pool.cancel()
# don't wait for it to stop by "joining" - too slow...

ursulas_info = successes.values()
return list(ursulas_info), block_number

def make_cli_controller(self, crash_on_error: bool = False):
controller = PorterCLIController(app_name=self.APP_NAME,
crash_on_error=crash_on_error,
Expand Down Expand Up @@ -363,4 +490,11 @@ def decrypt() -> Response:
response = controller(method_name="decrypt", control_request=request)
return response

@porter_flask_control.route('/bucket_sampling', methods=['GET'])
@by_path_counter
def bucket_sampling() -> Response:
"""Porter control endpoint for sampling Ursulas with provider caps (a.k.a. bucket sampling)"""
response = controller(method_name='bucket_sampling', control_request=request)
return response

return controller
54 changes: 53 additions & 1 deletion porter/schema.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@
from marshmallow import fields as marshmallow_fields

from porter.cli.types import EIP55_CHECKSUM_ADDRESS
from porter.fields.base import JSON, PositiveInteger, StringList
from porter.fields.base import Integer, JSON, PositiveInteger, StringList
from porter.fields.exceptions import InvalidArgumentCombo, InvalidInputData
from porter.fields.retrieve import CapsuleFrag, RetrievalKit
from porter.fields.taco import (
Expand Down Expand Up @@ -285,3 +285,55 @@ def check_valid_threshold_and_requests(self, data, **kwargs):
raise InvalidArgumentCombo(
"Number of provided requests must be >= the expected threshold"
)


class BucketSampling(BaseSchema):
quantity = PositiveInteger(
required=True,
load_only=True,
click=click.option(
'--quantity',
'-n',
help="Total number of Ursulas needed",
type=click.INT, required=True))

# optional
random_seed = Integer(
required=False,
load_only=True,
click=click.option(
"--seed",
help="Random seed for sampling",
type=click.INT,
required=False,
),
)

exclude_ursulas = StringList(
UrsulaChecksumAddress(),
click=click.option(
'--exclude-ursula',
'-e',
help="Ursula checksum address to exclude from sample",
multiple=True,
type=EIP55_CHECKSUM_ADDRESS,
required=False,
default=[]),
required=False,
load_only=True)

timeout = PositiveInteger(
required=False,
load_only=True,
click=click.option(
"--timeout",
"-t",
help="Timeout for getting the required quantity of ursulas",
type=click.INT,
required=False,
),
)

# output
ursulas = marshmallow_fields.List(marshmallow_fields.Nested(UrsulaInfoSchema), dump_only=True)
block_number = marshmallow_fields.Int()

0 comments on commit 2d02642

Please sign in to comment.