Skip to content

Commit

Permalink
update AwsClientParameters validation for verify (#15574)
Browse files Browse the repository at this point in the history
  • Loading branch information
zzstoatzz authored Oct 7, 2024
1 parent 6590d85 commit 53f67a7
Show file tree
Hide file tree
Showing 2 changed files with 32 additions and 9 deletions.
7 changes: 5 additions & 2 deletions src/integrations/prefect-aws/prefect_aws/client_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -48,8 +48,8 @@ class AwsClientParameters(BaseModel):
use_ssl: bool = Field(
default=True, description="Whether or not to use SSL.", title="Use SSL"
)
verify: Union[bool, FilePath] = Field(
default=True, description="Whether or not to verify SSL certificates."
verify: Union[bool, FilePath, None] = Field(
default=None, description="Whether or not to verify SSL certificates."
)
verify_cert_path: Optional[FilePath] = Field(
default=None,
Expand Down Expand Up @@ -154,6 +154,9 @@ def get_params_override(self) -> Dict[str, Any]:
params_override[key].signature_version = UNSIGNED
elif key == "verify_cert_path":
params_override["verify"] = value
elif key == "verify":
if value is not None:
params_override[key] = value
else:
params_override[key] = value
return params_override
34 changes: 27 additions & 7 deletions src/integrations/prefect-aws/tests/test_client_parameters.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,7 +10,7 @@ class TestAwsClientParameters:
@pytest.mark.parametrize(
"params,result",
[
(AwsClientParameters(), {}),
(AwsClientParameters(), {"use_ssl": True}),
(
AwsClientParameters(
use_ssl=False, verify=False, endpoint_url="http://localhost:9000"
Expand All @@ -23,21 +23,17 @@ class TestAwsClientParameters:
),
(
AwsClientParameters(endpoint_url="https://localhost:9000"),
{"endpoint_url": "https://localhost:9000"},
{"use_ssl": True, "endpoint_url": "https://localhost:9000"},
),
(
AwsClientParameters(api_version="1.0.0"),
{"api_version": "1.0.0"},
{"use_ssl": True, "api_version": "1.0.0"},
),
],
)
def test_get_params_override_expected_output(
self, params: AwsClientParameters, result: Dict[str, Any], tmp_path
):
if "use_ssl" not in result:
result["use_ssl"] = True
if "verify" not in result:
result["verify"] = True
assert result == params.get_params_override()

@pytest.mark.parametrize(
Expand Down Expand Up @@ -131,3 +127,27 @@ def test_get_params_override_with_both_cert_path(self, tmp_path):
)
override_params = params.get_params_override()
assert override_params["verify"] == cert_path

def test_get_params_override_with_default_verify(self):
params = AwsClientParameters()
override_params = params.get_params_override()
assert (
"verify" not in override_params
), "verify should not be in params_override when not explicitly set"

def test_get_params_override_with_explicit_verify(self):
params_true = AwsClientParameters(verify=True)
params_false = AwsClientParameters(verify=False)

override_params_true = params_true.get_params_override()
override_params_false = params_false.get_params_override()

assert (
"verify" in override_params_true
), "verify should be in params_override when explicitly set to True"
assert override_params_true["verify"] is True

assert (
"verify" in override_params_false
), "verify should be in params_override when explicitly set to False"
assert override_params_false["verify"] is False

0 comments on commit 53f67a7

Please sign in to comment.