Skip to content

Commit

Permalink
Refactor timeout code so that it can be unit tested.
Browse files Browse the repository at this point in the history
Added unit test for timeout adjustment.
  • Loading branch information
derekpierre committed Dec 12, 2023
1 parent 7164d74 commit fab46b5
Show file tree
Hide file tree
Showing 2 changed files with 65 additions and 19 deletions.
40 changes: 21 additions & 19 deletions porter/main.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import os
from pathlib import Path
from typing import Dict, List, NamedTuple, Optional, Sequence
from typing import Dict, List, NamedTuple, Optional, Sequence, Union

from constant_sorrow.constants import NO_CONTROL_PROTOCOL
from eth_typing import ChecksumAddress
Expand Down Expand Up @@ -152,15 +152,9 @@ def get_ursulas(
include_ursulas: Optional[Sequence[ChecksumAddress]] = None,
timeout: Optional[int] = None,
) -> List[UrsulaInfo]:
if timeout and timeout > self.MAX_GET_URSULAS_TIMEOUT:
self.log.warn(
f"Provided sampling timeout ({timeout}s) exceeds "
f"maximum ({self.MAX_GET_URSULAS_TIMEOUT}s); "
f"using {self.MAX_GET_URSULAS_TIMEOUT}s instead"
)
timeout = self.MAX_GET_URSULAS_TIMEOUT
else:
timeout = timeout or self.MAX_GET_URSULAS_TIMEOUT
timeout = self._configure_timeout(
"sampling", timeout, self.MAX_GET_URSULAS_TIMEOUT
)

reservoir = self._make_reservoir(exclude_ursulas, include_ursulas)
available_nodes_to_sample = len(reservoir.values) + len(reservoir.reservoir)
Expand Down Expand Up @@ -244,15 +238,9 @@ def decrypt(
timeout: Optional[int] = None,
) -> DecryptOutcome:
decryption_client = ThresholdDecryptionClient(self)
if timeout and timeout > self.MAX_DECRYPTION_TIMEOUT:
self.log.warn(
f"Provided decryption timeout ({timeout}s) exceeds "
f"maximum ({self.MAX_DECRYPTION_TIMEOUT}s); "
f"using {self.MAX_DECRYPTION_TIMEOUT}s instead"
)
timeout = self.MAX_DECRYPTION_TIMEOUT
else:
timeout = timeout or self.MAX_DECRYPTION_TIMEOUT
timeout = self._configure_timeout(
"decryption", timeout, self.MAX_DECRYPTION_TIMEOUT
)

successes, failures = decryption_client.gather_encrypted_decryption_shares(
encrypted_requests=encrypted_decryption_requests,
Expand All @@ -265,6 +253,20 @@ def decrypt(
)
return decrypt_outcome

def _configure_timeout(
self, operation: str, timeout: Union[int, None], max_timeout: int
):
if timeout and timeout > max_timeout:
self.log.warn(
f"Provided {operation} timeout ({timeout}s) exceeds "
f"maximum ({max_timeout}s); "
f"using {max_timeout}s instead"
)
timeout = max_timeout
else:
timeout = timeout or max_timeout
return timeout

def _make_reservoir(
self,
exclude_ursulas: Optional[Sequence[ChecksumAddress]] = None,
Expand Down
44 changes: 44 additions & 0 deletions tests/test_porter_configure_timeout.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@
import pytest

from porter.main import Porter


@pytest.mark.parametrize(
"timeout_scenarios",
[
(None, 10, 10),
(1, 10, 1),
(5, 10, 5),
(9, 10, 9),
(10, 10, 10),
(11, 10, 10),
(20, 10, 10),
(25, 10, 10),
(
Porter.MAX_GET_URSULAS_TIMEOUT - 1,
Porter.MAX_GET_URSULAS_TIMEOUT,
Porter.MAX_GET_URSULAS_TIMEOUT - 1,
),
(
Porter.MAX_GET_URSULAS_TIMEOUT + 1,
Porter.MAX_GET_URSULAS_TIMEOUT,
Porter.MAX_GET_URSULAS_TIMEOUT,
),
(
Porter.MAX_DECRYPTION_TIMEOUT / 2,
Porter.MAX_DECRYPTION_TIMEOUT,
Porter.MAX_DECRYPTION_TIMEOUT / 2,
),
(
Porter.MAX_DECRYPTION_TIMEOUT * 2,
Porter.MAX_DECRYPTION_TIMEOUT,
Porter.MAX_DECRYPTION_TIMEOUT,
),
],
)
def test_porter_configure_timeout_defined_results(porter, timeout_scenarios):
provided_timeout, max_timeout, expected_timeout = timeout_scenarios
resultant_timeout = porter._configure_timeout(
operation="test", timeout=provided_timeout, max_timeout=max_timeout
)
assert resultant_timeout == expected_timeout

0 comments on commit fab46b5

Please sign in to comment.