Skip to content

Commit

Permalink
Add user provided SSE-C arguments to CompleteMultipartUpload call
Browse files Browse the repository at this point in the history
  • Loading branch information
nateprewitt committed Sep 22, 2023
1 parent 1eef558 commit b0bee85
Show file tree
Hide file tree
Showing 6 changed files with 48 additions and 7 deletions.
5 changes: 5 additions & 0 deletions .changes/next-release/bugfix-SSEC-19962.json
Original file line number Diff line number Diff line change
@@ -0,0 +1,5 @@
{
"type": "bugfix",
"category": "``SSE-C``",
"description": "Pass SSECustomer* arguements to CompleteMultipartUpload for upload operations\""
}
30 changes: 26 additions & 4 deletions s3transfer/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -380,6 +380,12 @@ class MultipartUploader:
'RequestPayer',
]

COMPLETE_MULTIPART_ARGS = [
'SSECustomerKey',
'SSECustomerAlgorithm',
'RequestPayer',
]

def __init__(
self,
client,
Expand All @@ -395,11 +401,25 @@ def __init__(
def _extra_upload_part_args(self, extra_args):
# Only the args in UPLOAD_PART_ARGS actually need to be passed
# onto the upload_part calls.
upload_parts_args = {}
return self._filter_extra_args(
extra_args,
self.UPLOAD_PART_ARGS,
)

def _extra_complete_multipart_upload_args(self, extra_args):
# Only the args in COMPLETE_MULTIPART_ARGS need to be passed
# onto the complete_multipart_upload call.
return self._filter_extra_args(
extra_args,
self.COMPLETE_MULTIPART_ARGS,
)

def _filter_extra_args(self, extra_args, allowed_list):
filtered_args = {}
for key, value in extra_args.items():
if key in self.UPLOAD_PART_ARGS:
upload_parts_args[key] = value
return upload_parts_args
if key in allowed_list:
filtered_args[key] = value
return filtered_args

def upload_file(self, filename, bucket, key, callback, extra_args):
response = self._client.create_multipart_upload(
Expand All @@ -424,11 +444,13 @@ def upload_file(self, filename, bucket, key, callback, extra_args):
filename, '/'.join([bucket, key]), e
)
)
extra_cmu_args = self._extra_complete_multipart_upload_args(extra_args)
self._client.complete_multipart_upload(
Bucket=bucket,
Key=key,
UploadId=upload_id,
MultipartUpload={'Parts': parts},
**extra_cmu_args,
)

def _upload_parts(
Expand Down
7 changes: 6 additions & 1 deletion s3transfer/copies.py
Original file line number Diff line number Diff line change
Expand Up @@ -69,7 +69,12 @@ class CopySubmissionTask(SubmissionTask):
'TaggingDirective',
]

COMPLETE_MULTIPART_ARGS = ['RequestPayer', 'ExpectedBucketOwner']
COMPLETE_MULTIPART_ARGS = [
'RequestPayer',
'ExpectedBucketOwner',
'SSECustomerKey',
'SSECustomerAlgorithm',
]

def _submit(
self, client, config, osutil, request_executor, transfer_future
Expand Down
7 changes: 6 additions & 1 deletion s3transfer/upload.py
Original file line number Diff line number Diff line change
Expand Up @@ -521,7 +521,12 @@ class UploadSubmissionTask(SubmissionTask):
'ExpectedBucketOwner',
]

COMPLETE_MULTIPART_ARGS = ['RequestPayer', 'ExpectedBucketOwner']
COMPLETE_MULTIPART_ARGS = [
'RequestPayer',
'ExpectedBucketOwner',
'SSECustomerKey',
'SSECustomerAlgorithm',
]

def _get_upload_input_manager_cls(self, transfer_future):
"""Retrieves a class for managing input for an upload based on file type
Expand Down
4 changes: 3 additions & 1 deletion tests/functional/test_copy.py
Original file line number Diff line number Diff line change
Expand Up @@ -629,7 +629,9 @@ def test_copy_passes_args_to_create_multipart_and_upload_part(self):
self.add_head_object_response(expected_params=head_params)

self._add_params_to_expected_params(
add_copy_kwargs, ['create_mpu', 'copy'], self.extra_args
add_copy_kwargs,
['create_mpu', 'copy', 'complete_mpu'],
self.extra_args,
)
self.add_successful_copy_responses(**add_copy_kwargs)

Expand Down
2 changes: 2 additions & 0 deletions tests/unit/test_s3transfer.py
Original file line number Diff line number Diff line change
Expand Up @@ -351,6 +351,8 @@ def test_multipart_upload_injects_proper_kwargs(self):
Bucket='bucket',
UploadId='upload_id',
Key='key',
SSECustomerKey='fakekey',
SSECustomerAlgorithm='AES256',
)

def test_multipart_upload_is_aborted_on_error(self):
Expand Down

0 comments on commit b0bee85

Please sign in to comment.