Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Support for requestPays in java-storage call flow #995

Open
wants to merge 2 commits into
base: master
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,7 @@
@VisibleForTesting
public class GoogleCloudStorageClientImpl extends ForwardingGoogleCloudStorage {
private static final GoogleLogger logger = GoogleLogger.forEnclosingClass();

private static GoogleCloudStorageImpl delegate;
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

make this final?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why is this static?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Having the base class and derived class having same field is an anti pattern

private final GoogleCloudStorageOptions storageOptions;
private final Storage storage;

Expand All @@ -65,6 +65,7 @@ public class GoogleCloudStorageClientImpl extends ForwardingGoogleCloudStorage {
.setNameFormat("gcsio-storage-client-write-channel-pool-%d")
.setDaemon(true)
.build());

/**
* Having an instance of gscImpl to redirect calls to Json client while new client implementation
* is in WIP.
Expand All @@ -78,13 +79,14 @@ public class GoogleCloudStorageClientImpl extends ForwardingGoogleCloudStorage {
@Nullable Function<List<AccessBoundary>, String> downscopedAccessTokenFn)
throws IOException {
super(
GoogleCloudStorageImpl.builder()
.setOptions(options)
.setCredentials(credentials)
.setHttpTransport(httpTransport)
.setHttpRequestInitializer(httpRequestInitializer)
.setDownscopedAccessTokenFn(downscopedAccessTokenFn)
.build());
delegate =
GoogleCloudStorageImpl.builder()
.setOptions(options)
.setCredentials(credentials)
.setHttpTransport(httpTransport)
.setHttpRequestInitializer(httpRequestInitializer)
.setDownscopedAccessTokenFn(downscopedAccessTokenFn)
.build());
this.storageOptions = options;
this.storage =
clientLibraryStorage == null ? createStorage(credentials, options) : clientLibraryStorage;
Expand All @@ -108,7 +110,12 @@ public WritableByteChannel create(StorageResourceId resourceId, CreateObjectOpti

GoogleCloudStorageClientWriteChannel channel =
new GoogleCloudStorageClientWriteChannel(
storage, storageOptions, resourceIdWithGeneration, options, backgroundTasksThreadPool);
storage,
storageOptions,
resourceIdWithGeneration,
options,
requesterShouldPay(resourceIdWithGeneration.getBucketName()),
backgroundTasksThreadPool);
channel.initialize();
return channel;
}
Expand Down Expand Up @@ -144,7 +151,8 @@ private SeekableByteChannel open(
itemInfo == null ? getItemInfo(resourceId) : itemInfo,
readOptions,
errorExtractor,
storageOptions);
storageOptions,
requesterShouldPay(resourceId.getBucketName()));
}

@Override
Expand All @@ -160,6 +168,11 @@ public void close() {
}
}

@Override
public GoogleCloudStorageImpl getDelegate() {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

this is overriding the same method as base class?

return delegate;
}

/**
* Gets the object generation for a write operation
*
Expand Down Expand Up @@ -195,6 +208,10 @@ private static Storage createStorage(
.getService();
}

private boolean requesterShouldPay(String bucketName) throws IOException {
return getDelegate().requesterShouldPay(bucketName);
}

public static Builder builder() {
return new AutoBuilder_GoogleCloudStorageClientImpl_Builder();
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -58,6 +58,7 @@ class GoogleCloudStorageClientReadChannel implements SeekableByteChannel {
// The size of this object generation, in bytes.
private final long objectSize;
private final ErrorTypeExtractor errorExtractor;
private final boolean requesterPays;
private ContentReadChannel contentReadChannel;

private boolean open = true;
Expand All @@ -71,16 +72,18 @@ public GoogleCloudStorageClientReadChannel(
GoogleCloudStorageItemInfo itemInfo,
GoogleCloudStorageReadOptions readOptions,
ErrorTypeExtractor errorExtractor,
GoogleCloudStorageOptions storageOptions)
GoogleCloudStorageOptions storageOptions,
boolean requesterPays)
throws IOException {
validate(itemInfo);
this.storage = storage;
this.errorExtractor = errorExtractor;
this.resourceId =
new StorageResourceId(
itemInfo.getBucketName(), itemInfo.getObjectName(), itemInfo.getContentGeneration());
this.readOptions = readOptions;
this.requesterPays = requesterPays;
this.storageOptions = storageOptions;
this.readOptions = readOptions;
this.objectSize = itemInfo.getSize();
this.contentReadChannel = new ContentReadChannel(readOptions, resourceId);
}
Expand Down Expand Up @@ -502,10 +505,16 @@ private BlobSourceOption[] generateReadOptions(BlobId blobId) {
if (blobId.getGeneration() != null) {
blobReadOptions.add(BlobSourceOption.generationMatch(blobId.getGeneration()));
}

if (storageOptions.getEncryptionKey() != null) {
blobReadOptions.add(
BlobSourceOption.decryptionKey(storageOptions.getEncryptionKey().value()));
}

if (requesterPays) {
blobReadOptions.add(
BlobSourceOption.userProject(storageOptions.getRequesterPaysOptions().getProjectId()));
}
return blobReadOptions.toArray(new BlobSourceOption[blobReadOptions.size()]);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@

import com.google.cloud.WriteChannel;
import com.google.cloud.hadoop.util.AbstractGoogleAsyncWriteChannel;
import com.google.cloud.hadoop.util.AsyncWriteChannelOptions;
import com.google.cloud.storage.BlobId;
import com.google.cloud.storage.BlobInfo;
import com.google.cloud.storage.Storage;
Expand All @@ -46,18 +45,18 @@ class GoogleCloudStorageClientWriteChannel extends AbstractGoogleAsyncWriteChann
private final StorageResourceId resourceId;
private WriteChannel writeChannel;
private boolean uploadSucceeded = false;
// TODO: not supported as of now
// private final String requesterPaysProject;

public GoogleCloudStorageClientWriteChannel(
Storage storage,
GoogleCloudStorageOptions storageOptions,
StorageResourceId resourceId,
CreateObjectOptions createOptions,
boolean requesterPays,
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

wont this break the public API? Maybe add a new constructor?

Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

consider renaming this to something like requesterPaysEnabled or something similar

ExecutorService uploadThreadPool) {
super(uploadThreadPool, storageOptions.getWriteChannelOptions());
this.resourceId = resourceId;
this.writeChannel = getClientWriteChannel(storage, resourceId, createOptions, storageOptions);
this.writeChannel =
getClientWriteChannel(storage, resourceId, createOptions, storageOptions, requesterPays);
}

@Override
Expand Down Expand Up @@ -90,14 +89,13 @@ private static WriteChannel getClientWriteChannel(
Storage storage,
StorageResourceId resourceId,
CreateObjectOptions createOptions,
GoogleCloudStorageOptions storageOptions) {
AsyncWriteChannelOptions channelOptions = storageOptions.getWriteChannelOptions();
GoogleCloudStorageOptions storageOptions,
boolean requesterPays) {
WriteChannel writeChannel =
storage.writer(
getBlobInfo(resourceId, createOptions),
generateWriteOptions(createOptions, storageOptions));
writeChannel.setChunkSize(channelOptions.getUploadChunkSize());

generateWriteOptions(createOptions, storageOptions, requesterPays));
writeChannel.setChunkSize(storageOptions.getWriteChannelOptions().getUploadChunkSize());
return writeChannel;
}

Expand Down Expand Up @@ -156,9 +154,10 @@ public Boolean call() throws Exception {
}

private static BlobWriteOption[] generateWriteOptions(
CreateObjectOptions createOptions, GoogleCloudStorageOptions storageOptions) {
CreateObjectOptions createOptions,
GoogleCloudStorageOptions storageOptions,
boolean requesterPays) {
List<BlobWriteOption> blobWriteOptions = new ArrayList<>();

blobWriteOptions.add(BlobWriteOption.disableGzipContent());
blobWriteOptions.add(BlobWriteOption.generationMatch());
if (createOptions.getKmsKeyName() != null) {
Expand All @@ -171,6 +170,10 @@ private static BlobWriteOption[] generateWriteOptions(
blobWriteOptions.add(
BlobWriteOption.encryptionKey(storageOptions.getEncryptionKey().value()));
}
if (requesterPays) {
blobWriteOptions.add(
BlobWriteOption.userProject(storageOptions.getRequesterPaysOptions().getProjectId()));
}
return blobWriteOptions.toArray(new BlobWriteOption[blobWriteOptions.size()]);
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2260,7 +2260,7 @@ private <RequestT extends StorageRequest<?>> void setRequesterPaysProject(
}
}

private boolean requesterShouldPay(String bucketName) {
protected boolean requesterShouldPay(String bucketName) {
if (bucketName == null) {
return false;
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,13 +33,17 @@
import com.google.cloud.hadoop.gcsio.GoogleCloudStorageReadOptions.Fadvise;
import com.google.cloud.hadoop.gcsio.integration.GoogleCloudStorageTestHelper;
import com.google.cloud.hadoop.util.GrpcErrorTypeExtractor;
import com.google.cloud.hadoop.util.RequesterPaysOptions;
import com.google.cloud.storage.BlobId;
import com.google.cloud.storage.Storage;
import com.google.cloud.storage.Storage.BlobSourceOption;
import com.google.common.collect.ImmutableList;
import com.google.protobuf.ByteString;
import java.io.EOFException;
import java.io.IOException;
import java.nio.ByteBuffer;
import java.nio.channels.ClosedChannelException;
import java.util.List;
import java.util.Random;
import org.junit.Before;
import org.junit.Test;
Expand All @@ -60,6 +64,10 @@ public class GoogleCloudStorageClientReadChannelTest {
private static final ByteString CONTENT =
GoogleCloudStorageTestHelper.createTestData(OBJECT_SIZE);

private ArgumentCaptor<BlobId> blobIdCaptor = ArgumentCaptor.forClass(BlobId.class);
private ArgumentCaptor<BlobSourceOption> blobSourceOptionCaptor =
ArgumentCaptor.forClass(BlobSourceOption.class);

private static final GoogleCloudStorageReadOptions DEFAULT_READ_OPTION =
GoogleCloudStorageReadOptions.builder()
.setFadvise(Fadvise.RANDOM)
Expand Down Expand Up @@ -88,11 +96,48 @@ public class GoogleCloudStorageClientReadChannelTest {

@Before
public void setUp() throws IOException {

fakeReadChannel = spy(new FakeReadChannel(CONTENT));
when(mockedStorage.reader(any(), any())).thenReturn(fakeReadChannel);
when(mockedStorage.reader(blobIdCaptor.capture(), blobSourceOptionCaptor.capture()))
.thenReturn(fakeReadChannel);
readChannel = getJavaStorageChannel(DEFAULT_ITEM_INFO, DEFAULT_READ_OPTION);
}

@Test
public void verifyRequesterPaysOption() throws IOException {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Do we have any integration test for requesterPays?

String dummyProjectId = "dummyProjectId";
int readBytes = 100;
fakeReadChannel = spy(new FakeReadChannel(CONTENT));
when(mockedStorage.reader(blobIdCaptor.capture(), blobSourceOptionCaptor.capture()))
.thenReturn(fakeReadChannel);
readChannel =
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

what are we really testing here?

new GoogleCloudStorageClientReadChannel(
mockedStorage,
DEFAULT_ITEM_INFO,
DEFAULT_READ_OPTION,
GrpcErrorTypeExtractor.INSTANCE,
GoogleCloudStorageOptions.DEFAULT.toBuilder()
.setRequesterPaysOptions(
RequesterPaysOptions.DEFAULT.toBuilder().setProjectId(dummyProjectId).build())
.build(), /*requesterPays*/
true);
getJavaStorageChannel(DEFAULT_ITEM_INFO, DEFAULT_READ_OPTION);
int startPosition = 0;
readChannel.position(startPosition);
assertThat(readChannel.position()).isEqualTo(startPosition);

ByteBuffer buffer = ByteBuffer.allocate(readBytes);
readChannel.read(buffer);
verifyContent(buffer, startPosition, readBytes);
List<BlobSourceOption> optionsList = blobSourceOptionCaptor.getAllValues();
assertThat(optionsList).contains(BlobSourceOption.userProject(dummyProjectId));
verify(fakeReadChannel, times(1)).seek(anyLong());
verify(fakeReadChannel, times(1)).limit(anyLong());
verify(fakeReadChannel, times(1)).read(any());

verifyNoMoreInteractions(fakeReadChannel);
}

@Test
public void inValidSeekPositions() {
int seekPosition = -1;
Expand Down Expand Up @@ -535,6 +580,7 @@ private GoogleCloudStorageClientReadChannel getJavaStorageChannel(
objectInfo,
readOptions,
GrpcErrorTypeExtractor.INSTANCE,
GoogleCloudStorageOptions.DEFAULT.toBuilder().build());
GoogleCloudStorageOptions.DEFAULT.toBuilder().build(), /*requesterPays*/
false);
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
import com.google.cloud.WriteChannel;
import com.google.cloud.hadoop.gcsio.integration.GoogleCloudStorageTestHelper;
import com.google.cloud.hadoop.util.AsyncWriteChannelOptions;
import com.google.cloud.hadoop.util.RequesterPaysOptions;
import com.google.cloud.hadoop.util.RetryHttpInitializer;
import com.google.cloud.hadoop.util.RetryHttpInitializerOptions;
import com.google.cloud.hadoop.util.testing.FakeCredentials;
Expand Down Expand Up @@ -102,6 +103,47 @@ public static void cleanUp() {
}
}

@Test
public void verifyRequesterPaysOption() throws IOException {
String dummyProjectId = "dummyProjectId";
int numberOfChunks = 1;
writeChannel =
new GoogleCloudStorageClientWriteChannel(
mockedStorage,
GoogleCloudStorageOptions.DEFAULT.toBuilder()
.setWriteChannelOptions(
AsyncWriteChannelOptions.DEFAULT.toBuilder()
.setGrpcChecksumsEnabled(true)
.build())
.setRequesterPaysOptions(
RequesterPaysOptions.DEFAULT.toBuilder().setProjectId(dummyProjectId).build())
.build(),
resourceId,
CreateObjectOptions.DEFAULT_NO_OVERWRITE.toBuilder()
.setContentType(CONTENT_TYPE)
.setContentEncoding(CONTENT_ENCODING)
.setMetadata(GoogleCloudStorageTestHelper.getDecodedMetadata(metadata))
.setKmsKeyName(KMS_KEY)
.build(),
/*requesterPays=*/ true,
EXECUTOR_SERVICE);
writeChannel.initialize();

ByteString data =
GoogleCloudStorageTestHelper.createTestData(
MAX_WRITE_CHUNK_BYTES.getNumber() * numberOfChunks - 1);
writeChannel.write(data.asReadOnlyByteBuffer());
writeChannel.close();

List<BlobWriteOption> optionsList = blobWriteOptionsCapture.getAllValues();
assertThat(optionsList).contains(BlobWriteOption.userProject(dummyProjectId));
// Fake writer only writes half the buffer at a time
verify(fakeWriteChannel, times(numberOfChunks * 2)).write(any());
verify(fakeWriteChannel, times(1)).close();
verifyBlobInfoProperties(blobInfoCapture, resourceId);
assertThat(writeChannel.isUploadSuccessful()).isTrue();
}

@Test
public void writeMultipleChunksSuccess() throws IOException {
int numberOfChunks = 10;
Expand Down Expand Up @@ -232,6 +274,7 @@ private GoogleCloudStorageClientWriteChannel getJavaStorageChannel() {
.setMetadata(GoogleCloudStorageTestHelper.getDecodedMetadata(metadata))
.setKmsKeyName(KMS_KEY)
.build(),
/*requesterPays=*/ false,
EXECUTOR_SERVICE);
}

Expand Down