Skip to content

Commit

Permalink
Adds KMS encryption_context for KMS encryption in the Kafka buffer. M…
Browse files Browse the repository at this point in the history
…oves the kms_key_id into a new kms section along with encryption_context. Resolves opensearch-project#3484 (opensearch-project#3486)

Signed-off-by: David Venable <[email protected]>
  • Loading branch information
dlvenable authored Oct 12, 2023
1 parent 7474480 commit b4b4a98
Show file tree
Hide file tree
Showing 4 changed files with 91 additions and 11 deletions.
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
package org.opensearch.dataprepper.plugins.kafka.common.key;

import org.opensearch.dataprepper.plugins.kafka.common.aws.AwsContext;
import org.opensearch.dataprepper.plugins.kafka.configuration.KmsConfig;
import org.opensearch.dataprepper.plugins.kafka.configuration.TopicConfig;
import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider;
import software.amazon.awssdk.core.SdkBytes;
Expand All @@ -18,12 +19,13 @@ public KmsKeyProvider(AwsContext awsContext) {

@Override
public boolean supportsConfiguration(TopicConfig topicConfig) {
return topicConfig.getKmsKeyId() != null;
return topicConfig.getKmsConfig() != null && topicConfig.getKmsConfig().getKeyId() != null;
}

@Override
public byte[] apply(TopicConfig topicConfig) {
String kmsKeyId = topicConfig.getKmsKeyId();
KmsConfig kmsConfig = topicConfig.getKmsConfig();
String kmsKeyId = kmsConfig.getKeyId();

AwsCredentialsProvider awsCredentialsProvider = awsContext.get();

Expand All @@ -36,6 +38,7 @@ public byte[] apply(TopicConfig topicConfig) {
DecryptResponse decryptResponse = kmsClient.decrypt(builder -> builder
.keyId(kmsKeyId)
.ciphertextBlob(SdkBytes.fromByteArray(decodedEncryptionKey))
.encryptionContext(kmsConfig.getEncryptionContext())
);

return decryptResponse.plaintext().asByteArray();
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,21 @@
package org.opensearch.dataprepper.plugins.kafka.configuration;

import com.fasterxml.jackson.annotation.JsonProperty;

import java.util.Map;

public class KmsConfig {
@JsonProperty("key_id")
private String keyId;

@JsonProperty("encryption_context")
private Map<String, String> encryptionContext;

public String getKeyId() {
return keyId;
}

public Map<String, String> getEncryptionContext() {
return encryptionContext;
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -128,8 +128,8 @@ public class TopicConfig {
@JsonProperty("encryption_key")
private String encryptionKey;

@JsonProperty("kms_key_id")
private String kmsKeyId;
@JsonProperty("kms")
private KmsConfig kmsConfig;

public Long getRetentionPeriod() {
return retentionPeriod;
Expand All @@ -151,8 +151,8 @@ public String getEncryptionKey() {
return encryptionKey;
}

public String getKmsKeyId() {
return kmsKeyId;
public KmsConfig getKmsConfig() {
return kmsConfig;
}

public Boolean getAutoCommit() {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@
import org.mockito.MockedStatic;
import org.mockito.junit.jupiter.MockitoExtension;
import org.opensearch.dataprepper.plugins.kafka.common.aws.AwsContext;
import org.opensearch.dataprepper.plugins.kafka.configuration.KmsConfig;
import org.opensearch.dataprepper.plugins.kafka.configuration.TopicConfig;
import software.amazon.awssdk.auth.credentials.AwsCredentialsProvider;
import software.amazon.awssdk.core.SdkBytes;
Expand All @@ -19,13 +20,15 @@
import software.amazon.awssdk.services.kms.model.DecryptResponse;

import java.util.Base64;
import java.util.Map;
import java.util.UUID;
import java.util.function.Consumer;

import static org.hamcrest.CoreMatchers.equalTo;
import static org.hamcrest.MatcherAssert.assertThat;
import static org.mockito.ArgumentMatchers.any;
import static org.mockito.ArgumentMatchers.anyString;
import static org.mockito.ArgumentMatchers.isNull;
import static org.mockito.Mockito.mock;
import static org.mockito.Mockito.mockStatic;
import static org.mockito.Mockito.verify;
Expand All @@ -39,19 +42,28 @@ class KmsKeyProviderTest {
private AwsCredentialsProvider awsCredentialsProvider;
@Mock
private TopicConfig topicConfig;
@Mock
private KmsConfig kmsConfig;

private KmsKeyProvider createObjectUnderTest() {
return new KmsKeyProvider(awsContext);
}

@Test
void supportsConfiguration_returns_false_if_kmsKeyId_is_null() {
void supportsConfiguration_returns_false_if_kms_config_is_null() {
assertThat(createObjectUnderTest().supportsConfiguration(topicConfig), equalTo(false));
}

@Test
void supportsConfiguration_returns_false_if_kms_keyId_is_null() {
when(topicConfig.getKmsConfig()).thenReturn(kmsConfig);
assertThat(createObjectUnderTest().supportsConfiguration(topicConfig), equalTo(false));
}

@Test
void supportsConfiguration_returns_true_if_kmsKeyId_is_present() {
when(topicConfig.getKmsKeyId()).thenReturn(UUID.randomUUID().toString());
void supportsConfiguration_returns_true_if_kms_keyId_is_present() {
when(topicConfig.getKmsConfig()).thenReturn(kmsConfig);
when(kmsConfig.getKeyId()).thenReturn(UUID.randomUUID().toString());
assertThat(createObjectUnderTest().supportsConfiguration(topicConfig), equalTo(true));
}

Expand All @@ -77,7 +89,8 @@ void setUp() {
encryptionKey = UUID.randomUUID().toString();
String base64EncryptionKey = Base64.getEncoder().encodeToString(encryptionKey.getBytes());
when(topicConfig.getEncryptionKey()).thenReturn(base64EncryptionKey);
when(topicConfig.getKmsKeyId()).thenReturn(kmsKeyId);
when(topicConfig.getKmsConfig()).thenReturn(kmsConfig);
when(kmsConfig.getKeyId()).thenReturn(kmsKeyId);

kmsClient = mock(KmsClient.class);
DecryptResponse decryptResponse = mock(DecryptResponse.class);
Expand All @@ -104,8 +117,48 @@ void apply_returns_plaintext_from_decrypt_request() {
}

@Test
void apply_calls_decrypt_with_correct_values() {
void apply_calls_decrypt_with_correct_values_when_encryption_context_is_null() {
KmsKeyProvider objectUnderTest = createObjectUnderTest();

when(kmsConfig.getEncryptionContext()).thenReturn(null);

try (MockedStatic<KmsClient> kmsClientMockedStatic = mockStatic(KmsClient.class)) {
kmsClientMockedStatic.when(() -> KmsClient.builder()).thenReturn(kmsClientBuilder);
objectUnderTest.apply(topicConfig);
}

ArgumentCaptor<Consumer<DecryptRequest.Builder>> consumerArgumentCaptor = ArgumentCaptor.forClass(Consumer.class);
verify(kmsClient).decrypt(consumerArgumentCaptor.capture());

Consumer<DecryptRequest.Builder> actualConsumer = consumerArgumentCaptor.getValue();

DecryptRequest.Builder builder = mock(DecryptRequest.Builder.class);
when(builder.keyId(anyString())).thenReturn(builder);
when(builder.ciphertextBlob(any())).thenReturn(builder);
when(builder.encryptionContext(any())).thenReturn(builder);
actualConsumer.accept(builder);

verify(builder).keyId(kmsKeyId);
ArgumentCaptor<SdkBytes> actualBytesCaptor = ArgumentCaptor.forClass(SdkBytes.class);
verify(builder).ciphertextBlob(actualBytesCaptor.capture());

SdkBytes actualSdkBytes = actualBytesCaptor.getValue();
assertThat(actualSdkBytes.asByteArray(), equalTo(encryptionKey.getBytes()));

verify(builder).encryptionContext(isNull());
}

@Test
void apply_calls_decrypt_with_correct_values_when_encryption_context_is_present() {
Map<String, String> encryptionContext = Map.of(
UUID.randomUUID().toString(), UUID.randomUUID().toString(),
UUID.randomUUID().toString(), UUID.randomUUID().toString(),
UUID.randomUUID().toString(), UUID.randomUUID().toString()
);
KmsKeyProvider objectUnderTest = createObjectUnderTest();

when(kmsConfig.getEncryptionContext()).thenReturn(encryptionContext);

try (MockedStatic<KmsClient> kmsClientMockedStatic = mockStatic(KmsClient.class)) {
kmsClientMockedStatic.when(() -> KmsClient.builder()).thenReturn(kmsClientBuilder);
objectUnderTest.apply(topicConfig);
Expand All @@ -119,6 +172,7 @@ void apply_calls_decrypt_with_correct_values() {
DecryptRequest.Builder builder = mock(DecryptRequest.Builder.class);
when(builder.keyId(anyString())).thenReturn(builder);
when(builder.ciphertextBlob(any())).thenReturn(builder);
when(builder.encryptionContext(any())).thenReturn(builder);
actualConsumer.accept(builder);

verify(builder).keyId(kmsKeyId);
Expand All @@ -127,6 +181,8 @@ void apply_calls_decrypt_with_correct_values() {

SdkBytes actualSdkBytes = actualBytesCaptor.getValue();
assertThat(actualSdkBytes.asByteArray(), equalTo(encryptionKey.getBytes()));

verify(builder).encryptionContext(encryptionContext);
}
}
}

0 comments on commit b4b4a98

Please sign in to comment.