diff --git a/src/main/java/org/opensearch/securityanalytics/model/DetectorTrigger.java b/src/main/java/org/opensearch/securityanalytics/model/DetectorTrigger.java index f4cdd6f06..b74a71048 100644 --- a/src/main/java/org/opensearch/securityanalytics/model/DetectorTrigger.java +++ b/src/main/java/org/opensearch/securityanalytics/model/DetectorTrigger.java @@ -309,6 +309,22 @@ public String getSeverity() { return severity; } + public List getRuleTypes() { + return ruleTypes; + } + + public List getRuleIds() { + return ruleIds; + } + + public List getRuleSeverityLevels() { + return ruleSeverityLevels; + } + + public List getTags() { + return tags; + } + public List getActions() { List transformedActions = new ArrayList<>(); diff --git a/src/main/java/org/opensearch/securityanalytics/transport/TransportIndexDetectorAction.java b/src/main/java/org/opensearch/securityanalytics/transport/TransportIndexDetectorAction.java index 08a00c86e..98ac07585 100644 --- a/src/main/java/org/opensearch/securityanalytics/transport/TransportIndexDetectorAction.java +++ b/src/main/java/org/opensearch/securityanalytics/transport/TransportIndexDetectorAction.java @@ -96,6 +96,7 @@ import org.opensearch.securityanalytics.rules.exceptions.SigmaError; import org.opensearch.securityanalytics.settings.SecurityAnalyticsSettings; import org.opensearch.securityanalytics.util.DetectorIndices; +import org.opensearch.securityanalytics.util.DetectorUtils; import org.opensearch.securityanalytics.util.IndexUtils; import org.opensearch.securityanalytics.util.MonitorService; import org.opensearch.securityanalytics.util.RuleIndices; @@ -114,6 +115,7 @@ import java.util.List; import java.util.Locale; import java.util.Map; +import java.util.UUID; import java.util.concurrent.atomic.AtomicBoolean; import java.util.concurrent.atomic.AtomicReference; import java.util.stream.Collectors; @@ -155,7 +157,7 @@ public class TransportIndexDetectorAction extends HandledTransportAction> rulesById, Detect StepListener> indexMonitorsStep = new StepListener<>(); indexMonitorsStep.whenComplete( - indexMonitorResponses -> saveWorkflow(detector, indexMonitorResponses, refreshPolicy, listener), + indexMonitorResponses -> saveWorkflow(rulesById, detector, indexMonitorResponses, refreshPolicy, listener), e -> { log.error("Failed to index the workflow", e); listener.onFailure(e); @@ -283,7 +285,7 @@ private void createMonitorFromQueries(List> rulesById, Detect int numberOfUnprocessedResponses = monitorRequests.size() - 1; if (numberOfUnprocessedResponses == 0) { - saveWorkflow(detector, monitorResponses, refreshPolicy, listener); + saveWorkflow(rulesById, detector, monitorResponses, refreshPolicy, listener); } else { // Saves the rest of the monitors and saves the workflow if supported saveMonitors( @@ -312,7 +314,7 @@ private void createMonitorFromQueries(List> rulesById, Detect AlertingPluginInterface.INSTANCE.indexMonitor((NodeClient) client, monitorRequests.get(0), namedWriteableRegistry, indexDocLevelMonitorStep); indexDocLevelMonitorStep.whenComplete(addedFirstMonitorResponse -> { monitorResponses.add(addedFirstMonitorResponse); - saveWorkflow(detector, monitorResponses, refreshPolicy, listener); + saveWorkflow(rulesById, detector, monitorResponses, refreshPolicy, listener); }, listener::onFailure ); @@ -346,19 +348,22 @@ public void onFailure(Exception e) { /** * If the workflow is enabled, saves the workflow, updates the detector and returns the saved monitors * if not, returns the saved monitors + * + * @param rulesById * @param detector * @param monitorResponses * @param refreshPolicy * @param actionListener */ private void saveWorkflow( - Detector detector, - List monitorResponses, - RefreshPolicy refreshPolicy, - ActionListener> actionListener + List> rulesById, Detector detector, + List monitorResponses, + RefreshPolicy refreshPolicy, + ActionListener> actionListener ) { if (enabledWorkflowUsage) { workflowService.upsertWorkflow( + rulesById, monitorResponses, null, detector, @@ -446,7 +451,7 @@ public void onResponse(Map> ruleFieldMappings) { monitorIdsToBeDeleted.removeAll(monitorsToBeUpdated.stream().map(IndexMonitorRequest::getMonitorId).collect( Collectors.toList())); - updateAlertingMonitors(detector, monitorsToBeAdded, monitorsToBeUpdated, monitorIdsToBeDeleted, refreshPolicy, listener); + updateAlertingMonitors(rulesById, detector, monitorsToBeAdded, monitorsToBeUpdated, monitorIdsToBeDeleted, refreshPolicy, listener); } catch (IOException | SigmaError ex) { listener.onFailure(ex); } @@ -474,7 +479,7 @@ public void onFailure(Exception e) { monitorIdsToBeDeleted.removeAll(monitorsToBeUpdated.stream().map(IndexMonitorRequest::getMonitorId).collect( Collectors.toList())); - updateAlertingMonitors(detector, monitorsToBeAdded, monitorsToBeUpdated, monitorIdsToBeDeleted, refreshPolicy, listener); + updateAlertingMonitors(rulesById, detector, monitorsToBeAdded, monitorsToBeUpdated, monitorIdsToBeDeleted, refreshPolicy, listener); } } @@ -493,6 +498,7 @@ public void onFailure(Exception e) { * @param listener Listener that accepts the list of updated monitors if the action was successful */ private void updateAlertingMonitors( + List> rulesById, Detector detector, List monitorsToBeAdded, List monitorsToBeUpdated, @@ -519,6 +525,7 @@ private void updateAlertingMonitors( } if (detector.isWorkflowSupported() && enabledWorkflowUsage) { updateWorkflowStep( + rulesById, detector, monitorsToBeDeleted, refreshPolicy, @@ -560,6 +567,7 @@ public void onFailure(Exception e) { } private void updateWorkflowStep( + List> rulesById, Detector detector, List monitorsToBeDeleted, RefreshPolicy refreshPolicy, @@ -596,6 +604,7 @@ public void onFailure(Exception e) { } else { // Update workflow and delete the monitors workflowService.upsertWorkflow( + rulesById, addNewMonitorsResponse, updateMonitorResponse, detector, @@ -749,8 +758,8 @@ public void onResponse(Map> ruleFieldMappings) { queryBackendMap.get(rule.getCategory()))); } } - // if workflow usage enabled, add chained findings monitor request since there are bucket level requests - if(enabledWorkflowUsage && false == monitorRequests.isEmpty()) { + // if workflow usage enabled, add chained findings monitor request if there are bucket level requests and if the detector triggers have any group by rules configured to trigger + if (enabledWorkflowUsage && !monitorRequests.isEmpty() && !DetectorUtils.getAggRuleIdsConfiguredToTrigger(detector, queries).isEmpty()) { monitorRequests.add(createDocLevelMonitorMatchAllRequest(detector, RefreshPolicy.IMMEDIATE, detector.getId()+"_chained_findings", Method.POST)); } listener.onResponse(monitorRequests); @@ -841,7 +850,7 @@ private IndexMonitorRequest createBucketLevelMonitorRequest( triggers.add(bucketLevelTrigger1); } **/ - Monitor monitor = new Monitor(monitorId, Monitor.NO_VERSION, detector.getName(), false, detector.getSchedule(), detector.getLastUpdateTime(), null, + Monitor monitor = new Monitor(monitorId, Monitor.NO_VERSION, detector.getName() + UUID.randomUUID(), false, detector.getSchedule(), detector.getLastUpdateTime(), null, MonitorType.BUCKET_LEVEL_MONITOR, detector.getUser(), 1, bucketLevelMonitorInputs, triggers, Map.of(), new DataSources(detector.getRuleIndex(), detector.getFindingsIndex(), diff --git a/src/main/java/org/opensearch/securityanalytics/util/DetectorUtils.java b/src/main/java/org/opensearch/securityanalytics/util/DetectorUtils.java index 5e9d25c38..28e316e06 100644 --- a/src/main/java/org/opensearch/securityanalytics/util/DetectorUtils.java +++ b/src/main/java/org/opensearch/securityanalytics/util/DetectorUtils.java @@ -4,8 +4,11 @@ */ package org.opensearch.securityanalytics.util; +import org.apache.commons.lang3.tuple.Pair; import org.apache.lucene.search.TotalHits; import org.opensearch.cluster.routing.Preference; +import org.opensearch.commons.alerting.action.IndexMonitorResponse; +import org.opensearch.commons.alerting.model.Monitor; import org.opensearch.core.action.ActionListener; import org.opensearch.action.search.SearchRequest; import org.opensearch.action.search.SearchResponse; @@ -25,6 +28,7 @@ import org.opensearch.search.suggest.Suggest; import org.opensearch.securityanalytics.model.Detector; import org.opensearch.securityanalytics.model.DetectorInput; +import org.opensearch.securityanalytics.model.Rule; import java.io.IOException; import java.util.Collections; @@ -32,6 +36,7 @@ import java.util.LinkedList; import java.util.List; import java.util.Set; +import java.util.stream.Collectors; public class DetectorUtils { @@ -95,4 +100,36 @@ public void onFailure(Exception e) { } }); } + + public static List getBucketLevelMonitorIdsWhoseRulesAreConfiguredToTrigger( + Detector detector, + List> rulesById, + List monitorResponses + ) { + List aggRuleIdsConfiguredToTrigger = getAggRuleIdsConfiguredToTrigger(detector, rulesById); + return monitorResponses.stream().filter( + // In the case of bucket level monitors rule id is trigger id + it -> Monitor.MonitorType.BUCKET_LEVEL_MONITOR == it.getMonitor().getMonitorType() + && !it.getMonitor().getTriggers().isEmpty() + && aggRuleIdsConfiguredToTrigger.contains(it.getMonitor().getTriggers().get(0).getId()) + ).map(IndexMonitorResponse::getId).collect(Collectors.toList()); + } + public static List getAggRuleIdsConfiguredToTrigger(Detector detector, List> rulesById) { + Set ruleIdsConfiguredToTrigger = detector.getTriggers().stream().flatMap(t -> t.getRuleIds().stream()).collect(Collectors.toSet()); + Set tagsConfiguredToTrigger = detector.getTriggers().stream().flatMap(t -> t.getTags().stream()).collect(Collectors.toSet()); + return rulesById.stream() + .filter(it -> checkIfRuleIsAggAndTriggerable( it.getRight(), ruleIdsConfiguredToTrigger, tagsConfiguredToTrigger)) + .map(stringRulePair -> stringRulePair.getRight().getId()) + .collect(Collectors.toList()); + } + + private static boolean checkIfRuleIsAggAndTriggerable(Rule rule, Set ruleIdsConfiguredToTrigger, Set tagsConfiguredToTrigger) { + if (rule.isAggregationRule()) { + return ruleIdsConfiguredToTrigger.contains(rule.getId()) + || rule.getTags().stream().anyMatch(tag -> tagsConfiguredToTrigger.contains(tag.getValue())); + } + return false; + } + + } \ No newline at end of file diff --git a/src/main/java/org/opensearch/securityanalytics/util/WorkflowService.java b/src/main/java/org/opensearch/securityanalytics/util/WorkflowService.java index 21a0013c7..5ce495b98 100644 --- a/src/main/java/org/opensearch/securityanalytics/util/WorkflowService.java +++ b/src/main/java/org/opensearch/securityanalytics/util/WorkflowService.java @@ -4,6 +4,7 @@ */ package org.opensearch.securityanalytics.util; +import org.apache.commons.lang3.tuple.Pair; import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; import org.opensearch.OpenSearchException; @@ -28,6 +29,7 @@ import org.opensearch.index.seqno.SequenceNumbers; import org.opensearch.rest.RestRequest.Method; import org.opensearch.securityanalytics.model.Detector; +import org.opensearch.securityanalytics.model.Rule; import java.util.ArrayList; import java.util.Collections; @@ -37,6 +39,8 @@ import java.util.concurrent.atomic.AtomicInteger; import java.util.stream.Collectors; +import static org.opensearch.securityanalytics.util.DetectorUtils.getBucketLevelMonitorIdsWhoseRulesAreConfiguredToTrigger; + /** * Alerting common clas used for workflow manipulation */ @@ -67,6 +71,7 @@ public WorkflowService(Client client, MonitorService monitorService) { * @param listener */ public void upsertWorkflow( + List> rulesById, List addedMonitorResponses, List updatedMonitorResponses, Detector detector, @@ -90,13 +95,13 @@ public void upsertWorkflow( } ChainedMonitorFindings chainedMonitorFindings = null; String cmfMonitorId = null; - if(addedMonitorResponses.stream().anyMatch(res -> (detector.getName() + "_chained_findings").equals(res.getMonitor().getName()))) { - List bucketMonitorIds = addedMonitorResponses.stream().filter(res -> res.getMonitor().getMonitorType().equals(MonitorType.BUCKET_LEVEL_MONITOR)).map(IndexMonitorResponse::getId).collect(Collectors.toList()); - if(!updatedMonitors.isEmpty()) { - bucketMonitorIds.addAll(updatedMonitorResponses.stream().filter(res -> res.getMonitor().getMonitorType().equals(MonitorType.BUCKET_LEVEL_MONITOR)).map(IndexMonitorResponse::getId).collect(Collectors.toList())); + if (addedMonitorResponses.stream().anyMatch(res -> (detector.getName() + "_chained_findings").equals(res.getMonitor().getName()))) { + List monitorResponses = new ArrayList<>(addedMonitorResponses); + if (updatedMonitorResponses != null) { + monitorResponses.addAll(updatedMonitorResponses); } cmfMonitorId = addedMonitorResponses.stream().filter(res -> (detector.getName() + "_chained_findings").equals(res.getMonitor().getName())).findFirst().get().getId(); - chainedMonitorFindings = new ChainedMonitorFindings(null, bucketMonitorIds); + chainedMonitorFindings = new ChainedMonitorFindings(null, getBucketLevelMonitorIdsWhoseRulesAreConfiguredToTrigger(detector, rulesById, monitorResponses)); } IndexWorkflowRequest indexWorkflowRequest = createWorkflowRequest(monitorIds, @@ -154,7 +159,7 @@ private IndexWorkflowRequest createWorkflowRequest(List monitorIds, Dete return delegate; } ).collect(Collectors.toList()); - + Sequence sequence = new Sequence(delegates); CompositeInput compositeInput = new CompositeInput(sequence); @@ -185,21 +190,5 @@ private IndexWorkflowRequest createWorkflowRequest(List monitorIds, Dete null ); } - - private Map mapMonitorIds(List monitorResponses) { - return monitorResponses.stream().collect( - Collectors.toMap( - // In the case of bucket level monitors rule id is trigger id - it -> { - if (MonitorType.BUCKET_LEVEL_MONITOR == it.getMonitor().getMonitorType()) { - return it.getMonitor().getTriggers().get(0).getId(); - } else { - return Detector.DOC_LEVEL_MONITOR; - } - }, - IndexMonitorResponse::getId - ) - ); - } } diff --git a/src/test/java/org/opensearch/securityanalytics/SecurityAnalyticsRestTestCase.java b/src/test/java/org/opensearch/securityanalytics/SecurityAnalyticsRestTestCase.java index 5f03b4e5d..522260a0a 100644 --- a/src/test/java/org/opensearch/securityanalytics/SecurityAnalyticsRestTestCase.java +++ b/src/test/java/org/opensearch/securityanalytics/SecurityAnalyticsRestTestCase.java @@ -127,7 +127,6 @@ protected void createRuleTopicIndex(String detectorType, String additionalMappin assertEquals(RestStatus.OK, restStatus(response)); } } - protected void verifyWorkflow(Map detectorMap, List monitorIds, int expectedDelegatesNum) throws IOException{ String workflowId = ((List) detectorMap.get("workflow_ids")).get(0); @@ -431,6 +430,11 @@ protected boolean alertingMonitorExists(String monitorId) throws IOException { return alertingMonitorExists(client(), monitorId); } + protected Response getAlertingMonitor(RestClient client, String monitorId) throws IOException { + Response response = makeRequest(client, "GET", String.format(Locale.getDefault(), "/_plugins/_alerting/monitors/%s", monitorId), Collections.emptyMap(), null); + return response; + } + protected boolean alertingMonitorExists(RestClient client, String monitorId) throws IOException { try { Response response = makeRequest(client, "GET", String.format(Locale.getDefault(), "/_plugins/_alerting/monitors/%s", monitorId), Collections.emptyMap(), null); diff --git a/src/test/java/org/opensearch/securityanalytics/TestHelpers.java b/src/test/java/org/opensearch/securityanalytics/TestHelpers.java index b98a6e641..a361c5394 100644 --- a/src/test/java/org/opensearch/securityanalytics/TestHelpers.java +++ b/src/test/java/org/opensearch/securityanalytics/TestHelpers.java @@ -65,6 +65,10 @@ public static Detector randomDetector(List rules, String detectorType) { public static Detector randomDetectorWithInputs(List inputs) { return randomDetector(null, null, null, inputs, List.of(), null, null, null, null); } + + public static Detector randomDetectorWithInputsAndTriggers(List inputs, List triggers) { + return randomDetector(null, null, null, inputs, triggers, null, null, null, null); + } public static Detector randomDetectorWithInputs(List inputs, String detectorType) { return randomDetector(null, detectorType, null, inputs, List.of(), null, null, null, null); } @@ -84,9 +88,6 @@ public static Detector randomDetectorWithTriggers(List rules, List inputs, List triggers) { - return randomDetector(null, null, null, inputs, triggers, null, null, null, null); - } public static Detector randomDetectorWithTriggers(List rules, List triggers, String detectorType, DetectorInput input) { return randomDetector(null, detectorType, null, List.of(input), triggers, null, null, null, null); diff --git a/src/test/java/org/opensearch/securityanalytics/resthandler/DetectorMonitorRestApiIT.java b/src/test/java/org/opensearch/securityanalytics/resthandler/DetectorMonitorRestApiIT.java index 95d8ff6cb..36595d07c 100644 --- a/src/test/java/org/opensearch/securityanalytics/resthandler/DetectorMonitorRestApiIT.java +++ b/src/test/java/org/opensearch/securityanalytics/resthandler/DetectorMonitorRestApiIT.java @@ -4,27 +4,6 @@ */ package org.opensearch.securityanalytics.resthandler; -import static org.opensearch.securityanalytics.TestHelpers.randomAggregationRule; -import static org.opensearch.securityanalytics.TestHelpers.randomDetector; -import static org.opensearch.securityanalytics.TestHelpers.randomDetectorType; -import static org.opensearch.securityanalytics.TestHelpers.randomDetectorWithInputs; -import static org.opensearch.securityanalytics.TestHelpers.randomDoc; -import static org.opensearch.securityanalytics.TestHelpers.randomIndex; -import static org.opensearch.securityanalytics.TestHelpers.randomRule; -import static org.opensearch.securityanalytics.TestHelpers.windowsIndexMapping; -import static org.opensearch.securityanalytics.settings.SecurityAnalyticsSettings.ENABLE_WORKFLOW_USAGE; - -import java.io.IOException; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.Collections; -import java.util.HashMap; -import java.util.HashSet; -import java.util.List; -import java.util.Map; -import java.util.Set; -import java.util.stream.Collectors; - import org.apache.hc.core5.http.HttpStatus; import org.junit.Assert; import org.opensearch.action.search.SearchResponse; @@ -39,8 +18,31 @@ import org.opensearch.securityanalytics.model.Detector; import org.opensearch.securityanalytics.model.DetectorInput; import org.opensearch.securityanalytics.model.DetectorRule; +import org.opensearch.securityanalytics.model.DetectorTrigger; import org.opensearch.securityanalytics.model.Rule; +import java.io.IOException; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Set; +import java.util.stream.Collectors; + +import static org.opensearch.securityanalytics.TestHelpers.randomAggregationRule; +import static org.opensearch.securityanalytics.TestHelpers.randomDetector; +import static org.opensearch.securityanalytics.TestHelpers.randomDetectorType; +import static org.opensearch.securityanalytics.TestHelpers.randomDetectorWithInputs; +import static org.opensearch.securityanalytics.TestHelpers.randomDetectorWithInputsAndTriggers; +import static org.opensearch.securityanalytics.TestHelpers.randomDoc; +import static org.opensearch.securityanalytics.TestHelpers.randomIndex; +import static org.opensearch.securityanalytics.TestHelpers.randomRule; +import static org.opensearch.securityanalytics.TestHelpers.windowsIndexMapping; +import static org.opensearch.securityanalytics.settings.SecurityAnalyticsSettings.ENABLE_WORKFLOW_USAGE; + public class DetectorMonitorRestApiIT extends SecurityAnalyticsRestTestCase { /** * 1. Creates detector with 5 doc prepackaged level rules and one doc level monitor based on the given rules @@ -56,10 +58,10 @@ public void testRemoveDocLevelRuleAddAggregationRules_verifyFindings_success() t Request createMappingRequest = new Request("POST", SecurityAnalyticsPlugin.MAPPER_BASE_URI); // both req params and req body are supported createMappingRequest.setJsonEntity( - "{ \"index_name\":\"" + index + "\"," + - " \"rule_topic\":\"" + randomDetectorType() + "\", " + - " \"partial\":true" + - "}" + "{ \"index_name\":\"" + index + "\"," + + " \"rule_topic\":\"" + randomDetectorType() + "\", " + + " \"partial\":true" + + "}" ); Response createMappingResponse = client().performRequest(createMappingRequest); @@ -73,11 +75,11 @@ public void testRemoveDocLevelRuleAddAggregationRules_verifyFindings_success() t assertEquals("Create detector failed", RestStatus.CREATED, restStatus(createResponse)); String request = "{\n" + - " \"query\" : {\n" + - " \"match_all\":{\n" + - " }\n" + - " }\n" + - "}"; + " \"query\" : {\n" + + " \"match_all\":{\n" + + " }\n" + + " }\n" + + "}"; SearchResponse response = executeSearchAndGetResponse(DetectorMonitorConfig.getRuleIndex(randomDetectorType()), request, true); assertEquals(5, response.getHits().getTotalHits().value); @@ -85,12 +87,12 @@ public void testRemoveDocLevelRuleAddAggregationRules_verifyFindings_success() t Map responseBody = asMap(createResponse); String detectorId = responseBody.get("_id").toString(); request = "{\n" + - " \"query\" : {\n" + - " \"match\":{\n" + - " \"_id\": \"" + detectorId + "\"\n" + - " }\n" + - " }\n" + - "}"; + " \"query\" : {\n" + + " \"match\":{\n" + + " \"_id\": \"" + detectorId + "\"\n" + + " }\n" + + " }\n" + + "}"; // Verify that one document level monitor is created List hits = executeSearch(Detector.DETECTORS_INDEX, request); @@ -110,7 +112,7 @@ public void testRemoveDocLevelRuleAddAggregationRules_verifyFindings_success() t String avgTermRuleId = createRule(randomAggregationRule( "avg", " > 1")); // Update detector and empty doc level rules so detector contains only one aggregation rule DetectorInput input = new DetectorInput("windows detector for security analytics", List.of("windows"), List.of(new DetectorRule(sumRuleId), new DetectorRule(avgTermRuleId)), - Collections.emptyList()); + Collections.emptyList()); Detector updatedDetector = randomDetectorWithInputs(List.of(input)); Response updateResponse = makeRequest(client(), "PUT", SecurityAnalyticsPlugin.DETECTOR_BASE_URI + "/" + detectorId, Collections.emptyMap(), toHttpEntity(updatedDetector)); @@ -150,7 +152,7 @@ public void testRemoveDocLevelRuleAddAggregationRules_verifyFindings_success() t List> findings = (List)getFindingsBody.get("findings"); for(Map finding : findings) { Set aggRulesFinding = ((List>)finding.get("queries")).stream().map(it -> it.get("id").toString()).collect( - Collectors.toSet()); + Collectors.toSet()); // Bucket monitor finding will have one rule String aggRuleId = aggRulesFinding.iterator().next(); @@ -182,10 +184,10 @@ public void testReplaceAggregationRuleWithDocRule_verifyFindings_success() throw Request createMappingRequest = new Request("POST", SecurityAnalyticsPlugin.MAPPER_BASE_URI); // both req params and req body are supported createMappingRequest.setJsonEntity( - "{ \"index_name\":\"" + index + "\"," + - " \"rule_topic\":\"" + randomDetectorType() + "\", " + - " \"partial\":true" + - "}" + "{ \"index_name\":\"" + index + "\"," + + " \"rule_topic\":\"" + randomDetectorType() + "\", " + + " \"partial\":true" + + "}" ); Response createMappingResponse = client().performRequest(createMappingRequest); @@ -195,7 +197,7 @@ public void testReplaceAggregationRuleWithDocRule_verifyFindings_success() throw String maxRuleId = createRule(randomAggregationRule( "max", " > 2")); List detectorRules = List.of(new DetectorRule(maxRuleId)); DetectorInput input = new DetectorInput("windows detector for security analytics", List.of("windows"), detectorRules, - Collections.emptyList()); + Collections.emptyList()); Detector detector = randomDetectorWithInputs(List.of(input)); Response createResponse = makeRequest(client(), "POST", SecurityAnalyticsPlugin.DETECTOR_BASE_URI, Collections.emptyMap(), toHttpEntity(detector)); @@ -204,22 +206,22 @@ public void testReplaceAggregationRuleWithDocRule_verifyFindings_success() throw Map responseBody = asMap(createResponse); String detectorId = responseBody.get("_id").toString(); String request = "{\n" + - " \"query\" : {\n" + - " \"match_all\":{\n" + - " }\n" + - " }\n" + - "}"; + " \"query\" : {\n" + + " \"match_all\":{\n" + + " }\n" + + " }\n" + + "}"; SearchResponse response = executeSearchAndGetResponse(Rule.CUSTOM_RULES_INDEX, request, true); assertEquals(1, response.getHits().getTotalHits().value); request = "{\n" + - " \"query\" : {\n" + - " \"match\":{\n" + - " \"_id\": \"" + detectorId + "\"\n" + - " }\n" + - " }\n" + - "}"; + " \"query\" : {\n" + + " \"match\":{\n" + + " \"_id\": \"" + detectorId + "\"\n" + + " }\n" + + " }\n" + + "}"; // Verify that one bucket level monitor is created List hits = executeSearch(Detector.DETECTORS_INDEX, request); SearchHit hit = hits.get(0); @@ -234,7 +236,7 @@ public void testReplaceAggregationRuleWithDocRule_verifyFindings_success() throw String randomDocRuleId = createRule(randomRule()); List prepackagedRules = getRandomPrePackagedRules(); input = new DetectorInput("windows detector for security analytics", List.of("windows"), List.of(new DetectorRule(randomDocRuleId)), - prepackagedRules.stream().map(DetectorRule::new).collect(Collectors.toList())); + prepackagedRules.stream().map(DetectorRule::new).collect(Collectors.toList())); Detector updatedDetector = randomDetectorWithInputs(List.of(input)); Response updateResponse = makeRequest(client(), "PUT", SecurityAnalyticsPlugin.DETECTOR_BASE_URI + "/" + detectorId, Collections.emptyMap(), toHttpEntity(updatedDetector)); @@ -259,11 +261,11 @@ public void testReplaceAggregationRuleWithDocRule_verifyFindings_success() throw // Verify rules request = "{\n" + - " \"query\" : {\n" + - " \"match_all\":{\n" + - " }\n" + - " }\n" + - "}"; + " \"query\" : {\n" + + " \"match_all\":{\n" + + " }\n" + + " }\n" + + "}"; response = executeSearchAndGetResponse(DetectorMonitorConfig.getRuleIndex(randomDetectorType()), request, true); assertEquals(6, response.getHits().getTotalHits().value); @@ -294,7 +296,7 @@ public void testReplaceAggregationRuleWithDocRule_verifyFindings_success() throw List foundDocIds = new ArrayList<>(); for(Map finding : findings) { Set aggRulesFinding = ((List>)finding.get("queries")).stream().map(it -> it.get("id").toString()).collect( - Collectors.toSet()); + Collectors.toSet()); assertTrue(docRuleIds.containsAll(aggRulesFinding)); @@ -320,10 +322,10 @@ public void testRemoveAllRulesAndUpdateDetector_success() throws IOException { Request createMappingRequest = new Request("POST", SecurityAnalyticsPlugin.MAPPER_BASE_URI); // both req params and req body are supported createMappingRequest.setJsonEntity( - "{ \"index_name\":\"" + index + "\"," + - " \"rule_topic\":\"" + randomDetectorType() + "\", " + - " \"partial\":true" + - "}" + "{ \"index_name\":\"" + index + "\"," + + " \"rule_topic\":\"" + randomDetectorType() + "\", " + + " \"partial\":true" + + "}" ); Response createMappingResponse = client().performRequest(createMappingRequest); @@ -338,22 +340,22 @@ public void testRemoveAllRulesAndUpdateDetector_success() throws IOException { Map responseBody = asMap(createResponse); String detectorId = responseBody.get("_id").toString(); String request = "{\n" + - " \"query\" : {\n" + - " \"match_all\":{\n" + - " }\n" + - " }\n" + - "}"; + " \"query\" : {\n" + + " \"match_all\":{\n" + + " }\n" + + " }\n" + + "}"; SearchResponse response = executeSearchAndGetResponse(DetectorMonitorConfig.getRuleIndex(randomDetectorType()), request, true); assertEquals(randomPrepackagedRules.size(), response.getHits().getTotalHits().value); request = "{\n" + - " \"query\" : {\n" + - " \"match\":{\n" + - " \"_id\": \"" + detectorId + "\"\n" + - " }\n" + - " }\n" + - "}"; + " \"query\" : {\n" + + " \"match\":{\n" + + " \"_id\": \"" + detectorId + "\"\n" + + " }\n" + + " }\n" + + "}"; // Verify that one doc level monitor is created List hits = executeSearch(Detector.DETECTORS_INDEX, request); SearchHit hit = hits.get(0); @@ -398,10 +400,10 @@ public void testAddNewAggregationRule_verifyFindings_success() throws IOExceptio Request createMappingRequest = new Request("POST", SecurityAnalyticsPlugin.MAPPER_BASE_URI); // both req params and req body are supported createMappingRequest.setJsonEntity( - "{ \"index_name\":\"" + index + "\"," + - " \"rule_topic\":\"" + randomDetectorType() + "\", " + - " \"partial\":true" + - "}" + "{ \"index_name\":\"" + index + "\"," + + " \"rule_topic\":\"" + randomDetectorType() + "\", " + + " \"partial\":true" + + "}" ); Response createMappingResponse = client().performRequest(createMappingRequest); @@ -411,7 +413,7 @@ public void testAddNewAggregationRule_verifyFindings_success() throws IOExceptio String sumRuleId = createRule(randomAggregationRule("sum", " > 1")); List detectorRules = List.of(new DetectorRule(sumRuleId)); DetectorInput input = new DetectorInput("windows detector for security analytics", List.of("windows"), detectorRules, - Collections.emptyList()); + Collections.emptyList()); Detector detector = randomDetectorWithInputs(List.of(input)); Response createResponse = makeRequest(client(), "POST", SecurityAnalyticsPlugin.DETECTOR_BASE_URI, Collections.emptyMap(), toHttpEntity(detector)); @@ -420,12 +422,12 @@ public void testAddNewAggregationRule_verifyFindings_success() throws IOExceptio Map responseBody = asMap(createResponse); String detectorId = responseBody.get("_id").toString(); String request = "{\n" + - " \"query\" : {\n" + - " \"match\":{\n" + - " \"_id\": \"" + detectorId + "\"\n" + - " }\n" + - " }\n" + - "}"; + " \"query\" : {\n" + + " \"match\":{\n" + + " \"_id\": \"" + detectorId + "\"\n" + + " }\n" + + " }\n" + + "}"; List hits = executeSearch(Detector.DETECTORS_INDEX, request); SearchHit hit = hits.get(0); @@ -437,7 +439,7 @@ public void testAddNewAggregationRule_verifyFindings_success() throws IOExceptio // Test adding the new max monitor and updating the existing sum monitor String maxRuleId = createRule(randomAggregationRule("max", " > 3")); DetectorInput newInput = new DetectorInput("windows detector for security analytics", List.of("windows"), List.of(new DetectorRule(maxRuleId), new DetectorRule(sumRuleId)), - Collections.emptyList()); + Collections.emptyList()); Detector updatedDetector = randomDetectorWithInputs(List.of(newInput)); Response updateResponse = makeRequest(client(), "PUT", SecurityAnalyticsPlugin.DETECTOR_BASE_URI + "/" + detectorId, Collections.emptyMap(), toHttpEntity(updatedDetector)); @@ -475,7 +477,7 @@ public void testAddNewAggregationRule_verifyFindings_success() throws IOExceptio Map finding = ((List) getFindingsBody.get("findings")).get(0); Set aggRulesFinding = ((List>) finding.get("queries")).stream().map(it -> it.get("id").toString()).collect( - Collectors.toSet()); + Collectors.toSet()); assertEquals(sumRuleId, aggRulesFinding.iterator().next()); @@ -505,10 +507,10 @@ public void testDeleteAggregationRule_verifyFindings_success() throws IOExceptio Request createMappingRequest = new Request("POST", SecurityAnalyticsPlugin.MAPPER_BASE_URI); // both req params and req body are supported createMappingRequest.setJsonEntity( - "{ \"index_name\":\"" + index + "\"," + - " \"rule_topic\":\"" + randomDetectorType() + "\", " + - " \"partial\":true" + - "}" + "{ \"index_name\":\"" + index + "\"," + + " \"rule_topic\":\"" + randomDetectorType() + "\", " + + " \"partial\":true" + + "}" ); Response createMappingResponse = client().performRequest(createMappingRequest); @@ -523,7 +525,7 @@ public void testDeleteAggregationRule_verifyFindings_success() throws IOExceptio List detectorRules = aggRuleIds.stream().map(DetectorRule::new).collect(Collectors.toList()); DetectorInput input = new DetectorInput("windows detector for security analytics", List.of("windows"), detectorRules, - Collections.emptyList()); + Collections.emptyList()); Detector detector = randomDetectorWithInputs(List.of(input)); Response createResponse = makeRequest(client(), "POST", SecurityAnalyticsPlugin.DETECTOR_BASE_URI, Collections.emptyMap(), toHttpEntity(detector)); @@ -532,12 +534,12 @@ public void testDeleteAggregationRule_verifyFindings_success() throws IOExceptio Map responseBody = asMap(createResponse); String detectorId = responseBody.get("_id").toString(); String request = "{\n" + - " \"query\" : {\n" + - " \"match\":{\n" + - " \"_id\": \"" + detectorId + "\"\n" + - " }\n" + - " }\n" + - "}"; + " \"query\" : {\n" + + " \"match\":{\n" + + " \"_id\": \"" + detectorId + "\"\n" + + " }\n" + + " }\n" + + "}"; List hits = executeSearch(Detector.DETECTORS_INDEX, request); SearchHit hit = hits.get(0); @@ -548,7 +550,7 @@ public void testDeleteAggregationRule_verifyFindings_success() throws IOExceptio // Test deleting the aggregation rule DetectorInput newInput = new DetectorInput("windows detector for security analytics", List.of("windows"), List.of(new DetectorRule(avgRuleId)), - Collections.emptyList()); + Collections.emptyList()); detector = randomDetectorWithInputs(List.of(newInput)); Response updateResponse = makeRequest(client(), "PUT", SecurityAnalyticsPlugin.DETECTOR_BASE_URI + "/" + detectorId, Collections.emptyMap(), toHttpEntity(detector)); @@ -590,7 +592,7 @@ public void testDeleteAggregationRule_verifyFindings_success() throws IOExceptio Map finding = ((List) getFindingsBody.get("findings")).get(0); Set aggRulesFinding = ((List>) finding.get("queries")).stream().map(it -> it.get("id").toString()).collect( - Collectors.toSet()); + Collectors.toSet()); assertEquals(avgRuleId, aggRulesFinding.iterator().next()); @@ -620,10 +622,10 @@ public void testReplaceAggregationRule_verifyFindings_success() throws IOExcepti Request createMappingRequest = new Request("POST", SecurityAnalyticsPlugin.MAPPER_BASE_URI); // both req params and req body are supported createMappingRequest.setJsonEntity( - "{ \"index_name\":\"" + index + "\"," + - " \"rule_topic\":\"" + randomDetectorType() + "\", " + - " \"partial\":true" + - "}" + "{ \"index_name\":\"" + index + "\"," + + " \"rule_topic\":\"" + randomDetectorType() + "\", " + + " \"partial\":true" + + "}" ); Response createMappingResponse = client().performRequest(createMappingRequest); @@ -639,7 +641,7 @@ public void testReplaceAggregationRule_verifyFindings_success() throws IOExcepti List prepackagedDocRules = getRandomPrePackagedRules(); DetectorInput input = new DetectorInput("windows detector for security analytics", List.of("windows"), detectorRules, - prepackagedDocRules.stream().map(DetectorRule::new).collect(Collectors.toList())); + prepackagedDocRules.stream().map(DetectorRule::new).collect(Collectors.toList())); Detector detector = randomDetectorWithInputs(List.of(input)); Response createResponse = makeRequest(client(), "POST", SecurityAnalyticsPlugin.DETECTOR_BASE_URI, Collections.emptyMap(), toHttpEntity(detector)); @@ -648,12 +650,12 @@ public void testReplaceAggregationRule_verifyFindings_success() throws IOExcepti Map responseBody = asMap(createResponse); String detectorId = responseBody.get("_id").toString(); String request = "{\n" + - " \"query\" : {\n" + - " \"match\":{\n" + - " \"_id\": \"" + detectorId + "\"\n" + - " }\n" + - " }\n" + - "}"; + " \"query\" : {\n" + + " \"match\":{\n" + + " \"_id\": \"" + detectorId + "\"\n" + + " }\n" + + " }\n" + + "}"; List hits = executeSearch(Detector.DETECTORS_INDEX, request); SearchHit hit = hits.get(0); @@ -664,8 +666,8 @@ public void testReplaceAggregationRule_verifyFindings_success() throws IOExcepti String maxRuleId = createRule(randomAggregationRule("max", " > 2")); DetectorInput newInput = new DetectorInput("windows detector for security analytics", List.of("windows"), - List.of(new DetectorRule(avgRuleId), new DetectorRule(maxRuleId)), - getRandomPrePackagedRules().stream().map(DetectorRule::new).collect(Collectors.toList())); + List.of(new DetectorRule(avgRuleId), new DetectorRule(maxRuleId)), + getRandomPrePackagedRules().stream().map(DetectorRule::new).collect(Collectors.toList())); detector = randomDetectorWithInputs(List.of(newInput)); createResponse = makeRequest(client(), "PUT", SecurityAnalyticsPlugin.DETECTOR_BASE_URI + "/" + detectorId, Collections.emptyMap(), toHttpEntity(detector)); @@ -739,10 +741,10 @@ public void testMinAggregationRule_findingSuccess() throws IOException { Request createMappingRequest = new Request("POST", SecurityAnalyticsPlugin.MAPPER_BASE_URI); // both req params and req body are supported createMappingRequest.setJsonEntity( - "{ \"index_name\":\"" + index + "\"," + - " \"rule_topic\":\"" + randomDetectorType() + "\", " + - " \"partial\":true" + - "}" + "{ \"index_name\":\"" + index + "\"," + + " \"rule_topic\":\"" + randomDetectorType() + "\", " + + " \"partial\":true" + + "}" ); Response createMappingResponse = client().performRequest(createMappingRequest); @@ -753,7 +755,7 @@ public void testMinAggregationRule_findingSuccess() throws IOException { aggRuleIds.add(createRule(randomAggregationRule("min", " > 3", testOpCode))); List detectorRules = aggRuleIds.stream().map(id -> new DetectorRule(id)).collect(Collectors.toList()); DetectorInput input = new DetectorInput("windows detector for security analytics", List.of("windows"), detectorRules, - Collections.emptyList()); + Collections.emptyList()); Detector detector = randomDetectorWithInputs(List.of(input)); Response createResponse = makeRequest(client(), "POST", SecurityAnalyticsPlugin.DETECTOR_BASE_URI, Collections.emptyMap(), toHttpEntity(detector)); @@ -762,12 +764,12 @@ public void testMinAggregationRule_findingSuccess() throws IOException { Map responseBody = asMap(createResponse); String detectorId = responseBody.get("_id").toString(); String request = "{\n" + - " \"query\" : {\n" + - " \"match\":{\n" + - " \"_id\": \"" + detectorId + "\"\n" + - " }\n" + - " }\n" + - "}"; + " \"query\" : {\n" + + " \"match\":{\n" + + " \"_id\": \"" + detectorId + "\"\n" + + " }\n" + + " }\n" + + "}"; List hits = executeSearch(Detector.DETECTORS_INDEX, request); SearchHit hit = hits.get(0); @@ -826,10 +828,10 @@ public void testMultipleAggregationAndDocRules_findingSuccess() throws IOExcepti Request createMappingRequest = new Request("POST", SecurityAnalyticsPlugin.MAPPER_BASE_URI); // both req params and req body are supported createMappingRequest.setJsonEntity( - "{ \"index_name\":\"" + index + "\"," + - " \"rule_topic\":\"" + randomDetectorType() + "\", " + - " \"partial\":true" + - "}" + "{ \"index_name\":\"" + index + "\"," + + " \"rule_topic\":\"" + randomDetectorType() + "\", " + + " \"partial\":true" + + "}" ); Response createMappingResponse = client().performRequest(createMappingRequest); @@ -850,10 +852,10 @@ public void testMultipleAggregationAndDocRules_findingSuccess() throws IOExcepti List prepackagedRules = getRandomPrePackagedRules(); List detectorRules = List.of(new DetectorRule(sumRuleId), new DetectorRule(maxRuleId), new DetectorRule(minRuleId), - new DetectorRule(avgRuleId), new DetectorRule(cntRuleId), new DetectorRule(randomDocRuleId)); + new DetectorRule(avgRuleId), new DetectorRule(cntRuleId), new DetectorRule(randomDocRuleId)); DetectorInput input = new DetectorInput("windows detector for security analytics", List.of("windows"), detectorRules, - prepackagedRules.stream().map(DetectorRule::new).collect(Collectors.toList())); + prepackagedRules.stream().map(DetectorRule::new).collect(Collectors.toList())); Detector detector = randomDetectorWithInputs(List.of(input)); Response createResponse = makeRequest(client(), "POST", SecurityAnalyticsPlugin.DETECTOR_BASE_URI, Collections.emptyMap(), toHttpEntity(detector)); @@ -861,25 +863,25 @@ public void testMultipleAggregationAndDocRules_findingSuccess() throws IOExcepti String request = "{\n" + - " \"query\" : {\n" + - " \"match_all\":{\n" + - " }\n" + - " }\n" + - "}"; + " \"query\" : {\n" + + " \"match_all\":{\n" + + " }\n" + + " }\n" + + "}"; SearchResponse response = executeSearchAndGetResponse(DetectorMonitorConfig.getRuleIndex(randomDetectorType()), request, true); - assertEquals(7, response.getHits().getTotalHits().value); + assertEquals(6, response.getHits().getTotalHits().value); assertEquals("Create detector failed", RestStatus.CREATED, restStatus(createResponse)); Map responseBody = asMap(createResponse); String detectorId = responseBody.get("_id").toString(); request = "{\n" + - " \"query\" : {\n" + - " \"match\":{\n" + - " \"_id\": \"" + detectorId + "\"\n" + - " }\n" + - " }\n" + - "}"; + " \"query\" : {\n" + + " \"match\":{\n" + + " \"_id\": \"" + detectorId + "\"\n" + + " }\n" + + " }\n" + + "}"; List hits = executeSearch(Detector.DETECTORS_INDEX, request); SearchHit hit = hits.get(0); Map updatedDetectorMap = (HashMap)(hit.getSourceAsMap().get("detector")); @@ -889,7 +891,7 @@ public void testMultipleAggregationAndDocRules_findingSuccess() throws IOExcepti List monitorIds = ((List) (updatedDetectorMap).get("monitor_id")); - assertEquals(7, monitorIds.size()); + assertEquals(6, monitorIds.size()); indexDoc(index, "1", randomDoc(2, 4, infoOpCode)); indexDoc(index, "2", randomDoc(3, 4, infoOpCode)); @@ -902,7 +904,7 @@ public void testMultipleAggregationAndDocRules_findingSuccess() throws IOExcepti Map numberOfMonitorTypes = new HashMap<>(); - for (String monitorId: monitorIds) { + for (String monitorId: monitorIds) { Map monitor = (Map)(entityAsMap(client().performRequest(new Request("GET", "/_plugins/_alerting/monitors/" + monitorId)))).get("monitor"); numberOfMonitorTypes.merge(monitor.get("monitor_type"), 1, Integer::sum); Response executeResponse = executeAlertingMonitor(monitorId, Collections.emptyMap()); @@ -931,7 +933,7 @@ else if (ruleId == minRuleId) { } assertEquals(5, numberOfMonitorTypes.get(MonitorType.BUCKET_LEVEL_MONITOR.getValue()).intValue()); - assertEquals(2, numberOfMonitorTypes.get(MonitorType.DOC_LEVEL_MONITOR.getValue()).intValue()); + assertEquals(1, numberOfMonitorTypes.get(MonitorType.DOC_LEVEL_MONITOR.getValue()).intValue()); Map params = new HashMap<>(); params.put("detector_id", detectorId); @@ -941,7 +943,7 @@ else if (ruleId == minRuleId) { // Assert findings assertNotNull(getFindingsBody); // 8 findings from doc level rules, and 3 findings for aggregation (sum, max and min) - assertEquals(19, getFindingsBody.get("total_findings")); + assertEquals(11, getFindingsBody.get("total_findings")); String findingDetectorId = ((Map)((List)getFindingsBody.get("findings")).get(0)).get("detectorId").toString(); assertEquals(detectorId, findingDetectorId); @@ -979,7 +981,7 @@ else if (ruleId == minRuleId) { assertTrue(Arrays.asList("1", "2", "3", "4", "5", "6", "7", "8").containsAll(docLevelFinding)); } - public void testCreateDetector_verifyWorkflowCreation_success() throws IOException { + public void testCreateDetector_verifyWorkflowCreation_success_WithoutGroupByRulesInTrigger() throws IOException { updateClusterSetting(ENABLE_WORKFLOW_USAGE.getKey(), "true"); String index = createTestIndex(randomIndex(), windowsIndexMapping()); @@ -987,10 +989,10 @@ public void testCreateDetector_verifyWorkflowCreation_success() throws IOExcepti Request createMappingRequest = new Request("POST", SecurityAnalyticsPlugin.MAPPER_BASE_URI); // both req params and req body are supported createMappingRequest.setJsonEntity( - "{ \"index_name\":\"" + index + "\"," + - " \"rule_topic\":\"" + randomDetectorType() + "\", " + - " \"partial\":true" + - "}" + "{ \"index_name\":\"" + index + "\"," + + " \"rule_topic\":\"" + randomDetectorType() + "\", " + + " \"partial\":true" + + "}" ); Response createMappingResponse = client().performRequest(createMappingRequest); @@ -1003,32 +1005,102 @@ public void testCreateDetector_verifyWorkflowCreation_success() throws IOExcepti String randomDocRuleId = createRule(randomRule()); List detectorRules = List.of(new DetectorRule(maxRuleId), new DetectorRule(randomDocRuleId)); DetectorInput input = new DetectorInput("windows detector for security analytics", List.of("windows"), detectorRules, - Collections.emptyList()); + Collections.emptyList()); Detector detector = randomDetectorWithInputs(List.of(input)); Response createResponse = makeRequest(client(), "POST", SecurityAnalyticsPlugin.DETECTOR_BASE_URI, Collections.emptyMap(), toHttpEntity(detector)); String request = "{\n" + - " \"query\" : {\n" + - " \"match_all\":{\n" + - " }\n" + - " }\n" + - "}"; + " \"query\" : {\n" + + " \"match_all\":{\n" + + " }\n" + + " }\n" + + "}"; + SearchResponse response = executeSearchAndGetResponse(DetectorMonitorConfig.getRuleIndex(randomDetectorType()), request, true); + + assertEquals(1, response.getHits().getTotalHits().value); + + assertEquals("Create detector failed", RestStatus.CREATED, restStatus(createResponse)); + Map responseBody = asMap(createResponse); + + String detectorId = responseBody.get("_id").toString(); + request = "{\n" + + " \"query\" : {\n" + + " \"match\":{\n" + + " \"_id\": \"" + detectorId + "\"\n" + + " }\n" + + " }\n" + + "}"; + List hits = executeSearch(Detector.DETECTORS_INDEX, request); + SearchHit hit = hits.get(0); + Map detectorMap = (HashMap)(hit.getSourceAsMap().get("detector")); + List inputArr = (List) detectorMap.get("inputs"); + + assertEquals(2, ((Map>) inputArr.get(0)).get("detector_input").get("custom_rules").size()); + + List monitorIds = ((List) (detectorMap).get("monitor_id")); + assertEquals(2, monitorIds.size()); + + assertNotNull("Workflow not created", detectorMap.get("workflow_ids")); + assertEquals("Number of workflows not correct", 1, ((List) detectorMap.get("workflow_ids")).size()); + + // Verify workflow + verifyWorkflow(detectorMap, monitorIds, 2); + } + + + + public void testCreateDetector_verifyWorkflowCreation_success_WithGroupByRulesInTrigger() throws IOException { + updateClusterSetting(ENABLE_WORKFLOW_USAGE.getKey(), "true"); + String index = createTestIndex(randomIndex(), windowsIndexMapping()); + + // Execute CreateMappingsAction to add alias mapping for index + Request createMappingRequest = new Request("POST", SecurityAnalyticsPlugin.MAPPER_BASE_URI); + // both req params and req body are supported + createMappingRequest.setJsonEntity( + "{ \"index_name\":\"" + index + "\"," + + " \"rule_topic\":\"" + randomDetectorType() + "\", " + + " \"partial\":true" + + "}" + ); + + Response createMappingResponse = client().performRequest(createMappingRequest); + + assertEquals(HttpStatus.SC_OK, createMappingResponse.getStatusLine().getStatusCode()); + + String testOpCode = "Test"; + + String maxRuleId = createRule(randomAggregationRule("max", " > 3", testOpCode)); + String randomDocRuleId = createRule(randomRule()); + List detectorRules = List.of(new DetectorRule(maxRuleId), new DetectorRule(randomDocRuleId)); + DetectorInput input = new DetectorInput("windows detector for security analytics", List.of("windows"), detectorRules, + Collections.emptyList()); + DetectorTrigger t1 = new DetectorTrigger(null, "test-trigger", "1", List.of(), List.of(maxRuleId), List.of(), List.of(), List.of()); + Detector detector = randomDetectorWithInputsAndTriggers(List.of(input), List.of(t1)); + + Response createResponse = makeRequest(client(), "POST", SecurityAnalyticsPlugin.DETECTOR_BASE_URI, Collections.emptyMap(), toHttpEntity(detector)); + + String request = "{\n" + + " \"query\" : {\n" + + " \"match_all\":{\n" + + " }\n" + + " }\n" + + "}"; SearchResponse response = executeSearchAndGetResponse(DetectorMonitorConfig.getRuleIndex(randomDetectorType()), request, true); - assertEquals(2, response.getHits().getTotalHits().value); + assertEquals(2, response.getHits().getTotalHits().value); assertEquals("Create detector failed", RestStatus.CREATED, restStatus(createResponse)); Map responseBody = asMap(createResponse); String detectorId = responseBody.get("_id").toString(); request = "{\n" + - " \"query\" : {\n" + - " \"match\":{\n" + - " \"_id\": \"" + detectorId + "\"\n" + - " }\n" + - " }\n" + - "}"; + " \"query\" : {\n" + + " \"match\":{\n" + + " \"_id\": \"" + detectorId + "\"\n" + + " }\n" + + " }\n" + + "}"; List hits = executeSearch(Detector.DETECTORS_INDEX, request); SearchHit hit = hits.get(0); Map detectorMap = (HashMap)(hit.getSourceAsMap().get("detector")); @@ -1055,10 +1127,10 @@ public void testUpdateDetector_disabledWorkflowUsage_verifyWorkflowNotCreated_su Request createMappingRequest = new Request("POST", SecurityAnalyticsPlugin.MAPPER_BASE_URI); // both req params and req body are supported createMappingRequest.setJsonEntity( - "{ \"index_name\":\"" + index + "\"," + - " \"rule_topic\":\"" + randomDetectorType() + "\", " + - " \"partial\":true" + - "}" + "{ \"index_name\":\"" + index + "\"," + + " \"rule_topic\":\"" + randomDetectorType() + "\", " + + " \"partial\":true" + + "}" ); Response createMappingResponse = client().performRequest(createMappingRequest); @@ -1069,17 +1141,17 @@ public void testUpdateDetector_disabledWorkflowUsage_verifyWorkflowNotCreated_su List detectorRules = List.of(new DetectorRule(randomDocRuleId)); DetectorInput input = new DetectorInput("windows detector for security analytics", List.of("windows"), detectorRules, - Collections.emptyList()); + Collections.emptyList()); Detector detector = randomDetectorWithInputs(List.of(input)); Response createResponse = makeRequest(client(), "POST", SecurityAnalyticsPlugin.DETECTOR_BASE_URI, Collections.emptyMap(), toHttpEntity(detector)); String request = "{\n" + - " \"query\" : {\n" + - " \"match_all\":{\n" + - " }\n" + - " }\n" + - "}"; + " \"query\" : {\n" + + " \"match_all\":{\n" + + " }\n" + + " }\n" + + "}"; SearchResponse response = executeSearchAndGetResponse(DetectorMonitorConfig.getRuleIndex(randomDetectorType()), request, true); assertEquals(1, response.getHits().getTotalHits().value); @@ -1089,12 +1161,12 @@ public void testUpdateDetector_disabledWorkflowUsage_verifyWorkflowNotCreated_su String detectorId = responseBody.get("_id").toString(); request = "{\n" + - " \"query\" : {\n" + - " \"match\":{\n" + - " \"_id\": \"" + detectorId + "\"\n" + - " }\n" + - " }\n" + - "}"; + " \"query\" : {\n" + + " \"match\":{\n" + + " \"_id\": \"" + detectorId + "\"\n" + + " }\n" + + " }\n" + + "}"; List hits = executeSearch(Detector.DETECTORS_INDEX, request); SearchHit hit = hits.get(0); Map detectorMap = (HashMap)(hit.getSourceAsMap().get("detector")); @@ -1128,10 +1200,10 @@ public void testUpdateDetector_removeRule_verifyWorkflowUpdate_success() throws Request createMappingRequest = new Request("POST", SecurityAnalyticsPlugin.MAPPER_BASE_URI); // both req params and req body are supported createMappingRequest.setJsonEntity( - "{ \"index_name\":\"" + index + "\"," + - " \"rule_topic\":\"" + randomDetectorType() + "\", " + - " \"partial\":true" + - "}" + "{ \"index_name\":\"" + index + "\"," + + " \"rule_topic\":\"" + randomDetectorType() + "\", " + + " \"partial\":true" + + "}" ); Response createMappingResponse = client().performRequest(createMappingRequest); @@ -1142,21 +1214,21 @@ public void testUpdateDetector_removeRule_verifyWorkflowUpdate_success() throws String maxRuleId = createRule(randomAggregationRule("max", " > 3", testOpCode)); String randomDocRuleId = createRule(randomRule()); - List detectorRules = List.of(new DetectorRule(maxRuleId), new DetectorRule(randomDocRuleId)); + DetectorTrigger t1 = new DetectorTrigger(null, "test-trigger", "1", List.of(), List.of(randomDocRuleId, maxRuleId), List.of(), List.of(), List.of()); DetectorInput input = new DetectorInput("windows detector for security analytics", List.of("windows"), detectorRules, - Collections.emptyList()); - Detector detector = randomDetectorWithInputs(List.of(input)); + Collections.emptyList()); + Detector detector = randomDetectorWithInputsAndTriggers(List.of(input), List.of(t1)); Response createResponse = makeRequest(client(), "POST", SecurityAnalyticsPlugin.DETECTOR_BASE_URI, Collections.emptyMap(), toHttpEntity(detector)); String request = "{\n" + - " \"query\" : {\n" + - " \"match_all\":{\n" + - " }\n" + - " }\n" + - "}"; + " \"query\" : {\n" + + " \"match_all\":{\n" + + " }\n" + + " }\n" + + "}"; SearchResponse response = executeSearchAndGetResponse(DetectorMonitorConfig.getRuleIndex(randomDetectorType()), request, true); assertEquals(2, response.getHits().getTotalHits().value); @@ -1166,12 +1238,12 @@ public void testUpdateDetector_removeRule_verifyWorkflowUpdate_success() throws String detectorId = responseBody.get("_id").toString(); request = "{\n" + - " \"query\" : {\n" + - " \"match\":{\n" + - " \"_id\": \"" + detectorId + "\"\n" + - " }\n" + - " }\n" + - "}"; + " \"query\" : {\n" + + " \"match\":{\n" + + " \"_id\": \"" + detectorId + "\"\n" + + " }\n" + + " }\n" + + "}"; List hits = executeSearch(Detector.DETECTORS_INDEX, request); SearchHit hit = hits.get(0); Map detectorMap = (HashMap)(hit.getSourceAsMap().get("detector")); @@ -1252,10 +1324,10 @@ public void testCreateDetector_workflowWithDuplicateMonitor_failure() throws IOE Request createMappingRequest = new Request("POST", SecurityAnalyticsPlugin.MAPPER_BASE_URI); // both req params and req body are supported createMappingRequest.setJsonEntity( - "{ \"index_name\":\"" + index + "\"," + - " \"rule_topic\":\"" + randomDetectorType() + "\", " + - " \"partial\":true" + - "}" + "{ \"index_name\":\"" + index + "\"," + + " \"rule_topic\":\"" + randomDetectorType() + "\", " + + " \"partial\":true" + + "}" ); Response createMappingResponse = client().performRequest(createMappingRequest); @@ -1268,19 +1340,19 @@ public void testCreateDetector_workflowWithDuplicateMonitor_failure() throws IOE String randomDocRuleId = createRule(randomRule()); List detectorRules = List.of(new DetectorRule(maxRuleId), new DetectorRule(randomDocRuleId)); - + DetectorTrigger t1 = new DetectorTrigger(null, "test-trigger", "1", List.of(), List.of(randomDocRuleId, maxRuleId), List.of(), List.of(), List.of()); DetectorInput input = new DetectorInput("windows detector for security analytics", List.of("windows"), detectorRules, - Collections.emptyList()); - Detector detector = randomDetectorWithInputs(List.of(input)); + Collections.emptyList()); + Detector detector = randomDetectorWithInputsAndTriggers(List.of(input), List.of(t1)); Response createResponse = makeRequest(client(), "POST", SecurityAnalyticsPlugin.DETECTOR_BASE_URI, Collections.emptyMap(), toHttpEntity(detector)); String request = "{\n" + - " \"query\" : {\n" + - " \"match_all\":{\n" + - " }\n" + - " }\n" + - "}"; + " \"query\" : {\n" + + " \"match_all\":{\n" + + " }\n" + + " }\n" + + "}"; SearchResponse response = executeSearchAndGetResponse(DetectorMonitorConfig.getRuleIndex(randomDetectorType()), request, true); assertEquals(2, response.getHits().getTotalHits().value); @@ -1290,12 +1362,12 @@ public void testCreateDetector_workflowWithDuplicateMonitor_failure() throws IOE String detectorId = responseBody.get("_id").toString(); request = "{\n" + - " \"query\" : {\n" + - " \"match\":{\n" + - " \"_id\": \"" + detectorId + "\"\n" + - " }\n" + - " }\n" + - "}"; + " \"query\" : {\n" + + " \"match\":{\n" + + " \"_id\": \"" + detectorId + "\"\n" + + " }\n" + + " }\n" + + "}"; List hits = executeSearch(Detector.DETECTORS_INDEX, request); SearchHit hit = hits.get(0); Map detectorMap = (HashMap)(hit.getSourceAsMap().get("detector")); @@ -1321,10 +1393,10 @@ public void testCreateDetector_verifyWorkflowExecutionBucketLevelDocLevelMonitor Request createMappingRequest = new Request("POST", SecurityAnalyticsPlugin.MAPPER_BASE_URI); // both req params and req body are supported createMappingRequest.setJsonEntity( - "{ \"index_name\":\"" + index + "\"," + - " \"rule_topic\":\"" + randomDetectorType() + "\", " + - " \"partial\":true" + - "}" + "{ \"index_name\":\"" + index + "\"," + + " \"rule_topic\":\"" + randomDetectorType() + "\", " + + " \"partial\":true" + + "}" ); Response createMappingResponse = client().performRequest(createMappingRequest); @@ -1337,19 +1409,20 @@ public void testCreateDetector_verifyWorkflowExecutionBucketLevelDocLevelMonitor String randomDocRuleId = createRule(randomRule()); List detectorRules = List.of(new DetectorRule(maxRuleId), new DetectorRule(randomDocRuleId)); - + DetectorTrigger t1, t2; + t1 = new DetectorTrigger(null, "test-trigger", "1", List.of(), List.of(randomDocRuleId, maxRuleId), List.of(), List.of(), List.of()); DetectorInput input = new DetectorInput("windows detector for security analytics", List.of("windows"), detectorRules, - Collections.emptyList()); - Detector detector = randomDetectorWithInputs(List.of(input)); + Collections.emptyList()); + Detector detector = randomDetectorWithInputsAndTriggers(List.of(input), List.of(t1)); Response createResponse = makeRequest(client(), "POST", SecurityAnalyticsPlugin.DETECTOR_BASE_URI, Collections.emptyMap(), toHttpEntity(detector)); String request = "{\n" + - " \"query\" : {\n" + - " \"match_all\":{\n" + - " }\n" + - " }\n" + - "}"; + " \"query\" : {\n" + + " \"match_all\":{\n" + + " }\n" + + " }\n" + + "}"; SearchResponse response = executeSearchAndGetResponse(DetectorMonitorConfig.getRuleIndex(randomDetectorType()), request, true); assertEquals(2, response.getHits().getTotalHits().value); @@ -1359,12 +1432,12 @@ public void testCreateDetector_verifyWorkflowExecutionBucketLevelDocLevelMonitor String detectorId = responseBody.get("_id").toString(); request = "{\n" + - " \"query\" : {\n" + - " \"match\":{\n" + - " \"_id\": \"" + detectorId + "\"\n" + - " }\n" + - " }\n" + - "}"; + " \"query\" : {\n" + + " \"match\":{\n" + + " \"_id\": \"" + detectorId + "\"\n" + + " }\n" + + " }\n" + + "}"; List hits = executeSearch(Detector.DETECTORS_INDEX, request); SearchHit hit = hits.get(0); Map detectorMap = (HashMap)(hit.getSourceAsMap().get("detector")); @@ -1445,6 +1518,143 @@ public void testCreateDetector_verifyWorkflowExecutionBucketLevelDocLevelMonitor assertTrue(Arrays.asList("1", "2", "3", "4", "5").containsAll(docLevelFinding)); } + public void testCreateDetector_verifyWorkflowExecutionMultipleBucketLevelDocLevelMonitors_success_WithBucketLevelTriggersOnRuleIds() throws IOException { + String index = createTestIndex(randomIndex(), windowsIndexMapping()); + + // Execute CreateMappingsAction to add alias mapping for index + Request createMappingRequest = new Request("POST", SecurityAnalyticsPlugin.MAPPER_BASE_URI); + // both req params and req body are supported + createMappingRequest.setJsonEntity( + "{ \"index_name\":\"" + index + "\"," + + " \"rule_topic\":\"" + randomDetectorType() + "\", " + + " \"partial\":true" + + "}" + ); + + Response createMappingResponse = client().performRequest(createMappingRequest); + + assertEquals(HttpStatus.SC_OK, createMappingResponse.getStatusLine().getStatusCode()); + + String infoOpCode = "Info"; + String testOpCode = "Test"; + + // 5 custom aggregation rules + String sumRuleId = createRule(randomAggregationRule("sum", " > 1", infoOpCode)); + String maxRuleId = createRule(randomAggregationRule("max", " > 3", testOpCode)); + String minRuleId = createRule(randomAggregationRule("min", " > 3", testOpCode)); + String avgRuleId = createRule(randomAggregationRule("avg", " > 3", infoOpCode)); + String cntRuleId = createRule(randomAggregationRule("count", " > 3", "randomTestCode")); + String randomDocRuleId = createRule(randomRule()); + List prepackagedRules = getRandomPrePackagedRules(); + + List detectorRules = List.of(new DetectorRule(sumRuleId), new DetectorRule(maxRuleId), new DetectorRule(minRuleId), + new DetectorRule(avgRuleId), new DetectorRule(cntRuleId), new DetectorRule(randomDocRuleId)); + + DetectorInput input = new DetectorInput("windows detector for security analytics", List.of("windows"), detectorRules, + prepackagedRules.stream().map(DetectorRule::new).collect(Collectors.toList())); + DetectorTrigger t1, t2; + t1 = new DetectorTrigger(null, "test-trigger", "1", List.of(), List.of(sumRuleId, maxRuleId), List.of(), List.of(), List.of()); + t2 = new DetectorTrigger(null, "test-trigger", "1", List.of(), List.of(minRuleId, avgRuleId, cntRuleId), List.of(), List.of(), List.of()); + Detector detector = randomDetectorWithInputsAndTriggers(List.of(input), List.of(t1, t2)); + + Response createResponse = makeRequest(client(), "POST", SecurityAnalyticsPlugin.DETECTOR_BASE_URI, Collections.emptyMap(), toHttpEntity(detector)); + + String request = "{\n" + + " \"query\" : {\n" + + " \"match_all\":{\n" + + " }\n" + + " }\n" + + "}"; + SearchResponse response = executeSearchAndGetResponse(DetectorMonitorConfig.getRuleIndex(randomDetectorType()), request, true); + + assertEquals(7, response.getHits().getTotalHits().value); + + assertEquals("Create detector failed", RestStatus.CREATED, restStatus(createResponse)); + Map responseBody = asMap(createResponse); + + String detectorId = responseBody.get("_id").toString(); + request = "{\n" + + " \"query\" : {\n" + + " \"match\":{\n" + + " \"_id\": \"" + detectorId + "\"\n" + + " }\n" + + " }\n" + + "}"; + List hits = executeSearch(Detector.DETECTORS_INDEX, request); + SearchHit hit = hits.get(0); + Map detectorMap = (HashMap)(hit.getSourceAsMap().get("detector")); + List inputArr = (List) detectorMap.get("inputs"); + + assertEquals(6, ((Map>) inputArr.get(0)).get("detector_input").get("custom_rules").size()); + + List monitorIds = ((List) (detectorMap).get("monitor_id")); + assertEquals(7, monitorIds.size()); + + assertNotNull("Workflow not created", detectorMap.get("workflow_ids")); + assertEquals("Number of workflows not correct", 1, ((List) detectorMap.get("workflow_ids")).size()); + + indexDoc(index, "1", randomDoc(2, 4, infoOpCode)); + indexDoc(index, "2", randomDoc(3, 4, infoOpCode)); + indexDoc(index, "3", randomDoc(1, 4, infoOpCode)); + indexDoc(index, "4", randomDoc(5, 3, testOpCode)); + indexDoc(index, "5", randomDoc(2, 3, testOpCode)); + indexDoc(index, "6", randomDoc(4, 3, testOpCode)); + indexDoc(index, "7", randomDoc(6, 2, testOpCode)); + indexDoc(index, "8", randomDoc(1, 1, testOpCode)); + // Verify workflow + verifyWorkflow(detectorMap, monitorIds, 7); + + String workflowId = ((List) detectorMap.get("workflow_ids")).get(0); + + HashMap bucketMonitorsToRuleMap = (HashMap) detectorMap.get("bucket_monitor_id_rule_id"); + String docMonitorId = bucketMonitorsToRuleMap.get("-1"); + String chainedFindingsMonitorId = bucketMonitorsToRuleMap.get("chained_findings_monitor"); + Map monitorNameToIdMap = new HashMap<>(); + for (Map.Entry entry : bucketMonitorsToRuleMap.entrySet()) { + Response getMonitorRes = getAlertingMonitor(client(), entry.getValue()); + Map resMap = asMap(getMonitorRes); + Map stringObjectMap = (Map) resMap.get("monitor"); + String name = stringObjectMap.get("name").toString(); + monitorNameToIdMap.put(name, entry.getValue()); + } + + + Response executeResponse = executeAlertingWorkflow(workflowId, Collections.emptyMap()); + + Map executeWorkflowResponseMap = entityAsMap(executeResponse); + List> monitorRunResults = (List>) executeWorkflowResponseMap.get("monitor_run_results"); + + for (Map runResult : monitorRunResults) { + String monitorName = runResult.get("monitor_name").toString(); + String monitorId = monitorNameToIdMap.get(monitorName); + if(monitorId.equals(docMonitorId)){ + int noOfSigmaRuleMatches = ((List>) ((Map) runResult.get("input_results")).get("results")).get(0).size(); + // 5 prepackaged and 1 custom doc level rule + assertEquals(6, noOfSigmaRuleMatches); + } else if(monitorId.equals(chainedFindingsMonitorId)) { + + } else { + Map trigger_results = (Map) runResult.get("trigger_results"); + if (trigger_results.containsKey(maxRuleId)) { + assertRuleMonitorFinding(runResult, maxRuleId, 5, List.of("2", "3")); + } else if( trigger_results.containsKey(sumRuleId)) { + assertRuleMonitorFinding(runResult, sumRuleId, 3, List.of("4")); + } else if( trigger_results.containsKey(minRuleId)) { + assertRuleMonitorFinding(runResult, minRuleId, 5, List.of("2")); + } + } + } + + Map params = new HashMap<>(); + params.put("detector_id", detectorId); + Response getFindingsResponse = makeRequest(client(), "GET", SecurityAnalyticsPlugin.FINDINGS_BASE_URI + "/_search", params, null); + Map getFindingsBody = entityAsMap(getFindingsResponse); + + // Assert findings + assertNotNull(getFindingsBody); + assertEquals(33, getFindingsBody.get("total_findings")); + } + private static void assertRuleMonitorFinding(Map executeResults, String ruleId, int expectedDocCount, List expectedTriggerResult) { List> buckets = ((List>)(((Map)((Map)((Map)((List)((Map) executeResults.get("input_results")).get("results")).get(0)).get("aggregations")).get("result_agg")).get("buckets"))); @@ -1454,4 +1664,4 @@ private static void assertRuleMonitorFinding(Map executeResults, List triggerResultBucketKeys = ((Map)((Map) ((Map)executeResults.get("trigger_results")).get(ruleId)).get("agg_result_buckets")).keySet().stream().collect(Collectors.toList()); Assert.assertEquals(expectedTriggerResult, triggerResultBucketKeys); } -} +} \ No newline at end of file