diff --git a/src/integrationTest/java/org/opensearch/security/DlsTests.java b/src/integrationTest/java/org/opensearch/security/DlsTests.java index 41da011bb7..a13dbbf92e 100644 --- a/src/integrationTest/java/org/opensearch/security/DlsTests.java +++ b/src/integrationTest/java/org/opensearch/security/DlsTests.java @@ -9,50 +9,52 @@ */ package org.opensearch.security; +import static org.hamcrest.MatcherAssert.assertThat; +import static org.hamcrest.Matchers.oneOf; +import static org.opensearch.action.support.WriteRequest.RefreshPolicy.IMMEDIATE; +import static org.opensearch.client.RequestOptions.DEFAULT; +import static org.opensearch.security.Song.SONGS; +import static org.opensearch.test.framework.TestSecurityConfig.AuthcDomain.AUTHC_HTTPBASIC_INTERNAL; +import static org.opensearch.test.framework.TestSecurityConfig.Role.ALL_ACCESS; + import java.io.IOException; import java.util.ArrayList; +import java.util.HashMap; import java.util.List; import java.util.Map; -import java.util.OptionalDouble; +import java.util.Random; +import java.util.UUID; +import java.util.concurrent.Callable; import java.util.concurrent.CompletableFuture; import java.util.concurrent.ExecutorService; import java.util.concurrent.Executors; -import java.util.concurrent.Future; import java.util.stream.Collectors; import java.util.stream.IntStream; -import java.util.concurrent.Callable; - - -import com.carrotsearch.randomizedtesting.annotations.ThreadLeakScope; import org.junit.Before; import org.junit.BeforeClass; import org.junit.ClassRule; import org.junit.Test; import org.junit.runner.RunWith; - -import org.opensearch.action.admin.indices.create.CreateIndexRequest; - +import org.opensearch.action.bulk.BulkRequest; +import org.opensearch.action.index.IndexRequest; import org.opensearch.action.search.SearchRequest; import org.opensearch.action.search.SearchResponse; import org.opensearch.client.Client; import org.opensearch.client.RestHighLevelClient; -import org.opensearch.test.framework.AsyncActions; +import org.opensearch.common.xcontent.XContentFactory; +import org.opensearch.core.xcontent.XContentBuilder; +import org.opensearch.index.query.QueryBuilders; +import org.opensearch.core.xcontent.ToXContent.MapParams; +import org.opensearch.search.aggregations.AggregationBuilders; +import org.opensearch.search.builder.SearchSourceBuilder; import org.opensearch.test.framework.TestSecurityConfig; import org.opensearch.test.framework.cluster.ClusterManager; import org.opensearch.test.framework.cluster.LocalCluster; import org.opensearch.test.framework.cluster.TestRestClient; import org.opensearch.test.framework.cluster.TestRestClient.HttpResponse; -import static org.hamcrest.MatcherAssert.assertThat; -import static org.hamcrest.Matchers.either; -import static org.hamcrest.Matchers.equalTo; -import static org.hamcrest.Matchers.oneOf; -import static org.opensearch.action.support.WriteRequest.RefreshPolicy.IMMEDIATE; -import static org.opensearch.client.RequestOptions.DEFAULT; -import static org.opensearch.security.Song.SONGS; -import static org.opensearch.test.framework.TestSecurityConfig.AuthcDomain.AUTHC_HTTPBASIC_INTERNAL; -import static org.opensearch.test.framework.TestSecurityConfig.Role.ALL_ACCESS; +import com.carrotsearch.randomizedtesting.annotations.ThreadLeakScope; //./gradlew integrationTest --tests org.opensearch.security.DlsTests @RunWith(com.carrotsearch.randomizedtesting.RandomizedRunner.class) @@ -60,11 +62,10 @@ public class DlsTests { private enum TestRoles { - EMPTY_DLS, - DLS_ONLY_ROCK, - DLS_ONLY_JAZZ, - DLS_ONLY_ROCK_AND_JAZZ, - DLS_ONLY_LONG_VALUE; + NO_MASKING, + MASKING_RANDOM_STRING, + MASKING_RANDOM_LONG, + MASKING_LOW_REPEAT_VALUE; } static final String INDEX_NAME_PREFIX = "test-index-"; @@ -90,32 +91,24 @@ private enum TestRoles { .nodeSettings( Map.of("plugins.security.restapi.roles_enabled", List.of("user_" + ADMIN_USER.getName() + "__" + ALL_ACCESS.getName())) ) - .roles(new TestSecurityConfig.Role(TestRoles.EMPTY_DLS.name()) + .roles(new TestSecurityConfig.Role(TestRoles.NO_MASKING.name()) .clusterPermissions("cluster_composite_ops_ro") .indexPermissions("read") - .dls("") .on("*"), - new TestSecurityConfig.Role(TestRoles.DLS_ONLY_ROCK.name()) + new TestSecurityConfig.Role(TestRoles.MASKING_RANDOM_STRING.name()) .clusterPermissions("cluster_composite_ops_ro") .indexPermissions("read") - .dls("{\"bool\":{\"must\":[{\"terms\":{\"genre.keyword\":[\"rock\"]}}]}}") + .maskedFields("guid") .on("*"), - new TestSecurityConfig.Role(TestRoles.DLS_ONLY_JAZZ.name()) + new TestSecurityConfig.Role(TestRoles.MASKING_RANDOM_LONG.name()) .clusterPermissions("cluster_composite_ops_ro") .indexPermissions("read") - .dls("{\"bool\":{\"must\":[{\"terms\":{\"genre.keyword\":[\"jazz\"]}}]}}") + .maskedFields("longId") .on("*"), - new TestSecurityConfig.Role(TestRoles.DLS_ONLY_LONG_VALUE.name()) + new TestSecurityConfig.Role(TestRoles.MASKING_LOW_REPEAT_VALUE.name()) .clusterPermissions("cluster_composite_ops_ro") .indexPermissions("read") - .dls("{\"bool\":{\"must\":[{\"terms\":{\"genre.keyword\":[\"" - + "0123456789".repeat(100) // ==1000 characters - + "\"]}}]}}") - .on("*"), - new TestSecurityConfig.Role(TestRoles.DLS_ONLY_ROCK_AND_JAZZ.name()) - .clusterPermissions("cluster_composite_ops_ro") - .indexPermissions("read") - .dls("{\"bool\":{\"must\":[{\"terms\":{\"genre.keyword\":[\"jazz\"]}},{\"terms\":{\"genre.keyword\":[\"rock\"]}}]}}") + .maskedFields("genre") .on("*")) .authc(AUTHC_HTTPBASIC_INTERNAL) .users( @@ -130,6 +123,14 @@ public static void createTestData() { @Before public void setup() { + removeRolesFromReader(); + + // try (Client client = cluster.getInternalNodeClient()) { + // client.admin().indices() + // } + } + + private void removeRolesFromReader() { try (TestRestClient client = cluster.getRestClient(ADMIN_USER)) { for (TestRoles role : TestRoles.values()) { final String path = "_plugins/_security/api/rolesmapping/" + role.name(); @@ -137,9 +138,6 @@ public void setup() { assertThat(response.getStatusCode(), oneOf(200, 204, 404)); } } - // try (Client client = cluster.getInternalNodeClient()) { - // client.admin().indices() - // } } @Test @@ -149,75 +147,41 @@ public void testBaselinedDlsScenarios() throws Exception { queryAndGetStats(ADMIN_USER); queryAndGetStats(READER); - attachRoleToReader(TestRoles.EMPTY_DLS); + removeRolesFromReader(); + attachRoleToReader(TestRoles.NO_MASKING); queryAndGetStats(READER); - attachRoleToReader(TestRoles.DLS_ONLY_ROCK); + removeRolesFromReader(); + attachRoleToReader(TestRoles.MASKING_LOW_REPEAT_VALUE); queryAndGetStats(READER); - attachRoleToReader(TestRoles.DLS_ONLY_JAZZ); + removeRolesFromReader(); + attachRoleToReader(TestRoles.MASKING_RANDOM_LONG); queryAndGetStats(READER); - final long endMs = System.currentTimeMillis() - startMs; - System.out.println("Finished checks in " + endMs + "ms"); - - return null; - }; - createIndices(5); - check.call(); - setup(); - createIndices(50); - check.call(); - - setup(); - createIndices(100); - check.call(); - } - - @Test - public void testConsolidatedDlsScenarios() throws Exception { - final Callable check = () -> { - final long startMs = System.currentTimeMillis(); - queryAndGetStats(READER); - attachRoleToReader(TestRoles.DLS_ONLY_ROCK_AND_JAZZ); + removeRolesFromReader(); + attachRoleToReader(TestRoles.MASKING_RANDOM_STRING); queryAndGetStats(READER); + final long endMs = System.currentTimeMillis() - startMs; System.out.println("Finished checks in " + endMs + "ms"); + return null; }; - createIndices(5); - check.call(); - setup(); - createIndices(50); + createIndices(1, 50); check.call(); setup(); - createIndices(100); - check.call(); - } - - @Test - public void testDlsLargerQueryScenarios() throws Exception { - final Callable check = () -> { - final long startMs = System.currentTimeMillis(); - queryAndGetStats(READER); - - attachRoleToReader(TestRoles.DLS_ONLY_LONG_VALUE); - queryAndGetStats(READER); - final long endMs = System.currentTimeMillis() - startMs; - System.out.println("Finished checks in " + endMs + "ms"); - return null; - }; - createIndices(5); + createIndices(1, 50 * 100); check.call(); setup(); - createIndices(50); + createIndices(3, 50 * 100); check.call(); setup(); - createIndices(100); + createIndices(3, 50 * 100 * 10); check.call(); } @@ -231,35 +195,78 @@ private void attachRoleToReader(final TestRoles role) { } } - private void createIndices(final int count) throws IOException { - System.out.println("Creating " + count + " indices with 1 document"); + private void createIndices(final int count, final int docCount) throws IOException { + System.out.println("Creating " + count + " indices with " + docCount + " documents"); + final long currentTimeMillis = System.currentTimeMillis(); try (Client client = cluster.getInternalNodeClient()) { final ExecutorService pool = Executors.newFixedThreadPool(25); - final List> futures = IntStream.range(1, count).mapToObj(n -> { + final List> futures = IntStream.range(1, count + 1).mapToObj(n -> { final String indexName = INDEX_NAME_PREFIX + n; - return CompletableFuture.runAsync(() -> client.prepareIndex().setIndex(indexName).setRefreshPolicy(IMMEDIATE).setSource(SONGS[0].asMap()).get(), pool); + final Random random = new Random(); + return CompletableFuture.runAsync(() -> { + var docs = new ArrayList(); + final Map baseDoc = new HashMap(SONGS[0].asMap()); + for (int i = 0; i < docCount - 1; i++) { + var uuid = UUID.randomUUID().toString(); + baseDoc.put("guid", uuid); + baseDoc.put("longId", random.nextLong()); + docs.add(new IndexRequest().index(indexName).id(uuid).source(baseDoc)); + } + + for (int indexReqGroupN = 0; indexReqGroupN < docCount / 250; indexReqGroupN++) { + BulkRequest br = new BulkRequest(); + docs.stream().skip(n * 250).limit(250).forEach(ir -> { + br.add(ir); + }); + + if (br.numberOfActions() != 0) { + client.bulk(br).actionGet(); + } + } + }, pool); }).collect(Collectors.toList()); final CompletableFuture futuresCompleted = CompletableFuture.allOf(futures.toArray(new CompletableFuture[0])); futuresCompleted.join(); } + System.out.println("Creation completed, " + (System.currentTimeMillis() - currentTimeMillis) + "ms"); } private void queryAndGetStats(final TestSecurityConfig.User user) throws IOException { + TestRestClient adminClient = cluster.getRestClient(ADMIN_USER); try (RestHighLevelClient restHighLevelClient = cluster.getRestHighLevelClient(user)) { - final int samplesToIgnore = 5; + try (TestRestClient client = cluster.getRestClient(user)) { + + final int samplesToIgnore = 5; final int samples = 100 + samplesToIgnore; final List results = new ArrayList<>(); for (int i = 0; i < samples; i++) { final long start = System.currentTimeMillis(); - final SearchResponse response = restHighLevelClient.search(new SearchRequest(INDEX_NAME_PREFIX + "*"), DEFAULT); - final long endMs = System.currentTimeMillis() - start; - results.add(endMs); + SearchSourceBuilder ssb = new SearchSourceBuilder(); + ssb.aggregation(AggregationBuilders.filters("my-filter", QueryBuilders.queryStringQuery("last"))); + ssb.aggregation(AggregationBuilders.count("counting").field("genre.keyword")); + ssb.aggregation(AggregationBuilders.avg("averaging").field("longId")); + ssb.size(2); + SearchRequest request = new SearchRequest(INDEX_NAME_PREFIX + "*"); + request.source(ssb); + + final SearchResponse response = restHighLevelClient.search(request, DEFAULT); + if (i == 0) { + try (final XContentBuilder builder = XContentFactory.jsonBuilder()) { + response.toXContent(builder, new MapParams(Map.of("a","b"))); + System.err.println("Response " + builder.toString()); + } + } + + final var statsResponse = adminClient.get("_cluster/stats"); + statsResponse. + + results.add(999L); } // toss out inital samples IntStream.range(0, samplesToIgnore).forEach(n -> results.remove(0)); - System.out.println("User, Count, Avg, Max, Min, Std ms " + + System.out.println("User, Count, Avg, Max, Min, Std ms:" + user.getName() + ", " + results.size() + ", " + results.stream().mapToLong(a -> a).average().getAsDouble() + @@ -267,6 +274,7 @@ private void queryAndGetStats(final TestSecurityConfig.User user) throws IOExcep ", " + results.stream().mapToLong(a -> a).min().getAsLong() + ", " + String.format("%.2f", calcStd(results))); } + } } private static double calcStd(final List numbers) {