diff --git a/src/main/java/org/opensearch/securityanalytics/SecurityAnalyticsPlugin.java b/src/main/java/org/opensearch/securityanalytics/SecurityAnalyticsPlugin.java index ca97abd7e..6bd9c8e5b 100644 --- a/src/main/java/org/opensearch/securityanalytics/SecurityAnalyticsPlugin.java +++ b/src/main/java/org/opensearch/securityanalytics/SecurityAnalyticsPlugin.java @@ -218,7 +218,8 @@ public List getRestHandlers(Settings settings, new RestIndexCustomLogTypeAction(), new RestSearchCustomLogTypeAction(), new RestDeleteCustomLogTypeAction(), - new RestGetCorrelationsAlertsAction() + new RestGetCorrelationsAlertsAction(), + new RestAcknowledgeCorrelationAlertsAction() ); } @@ -340,7 +341,8 @@ public List> getSettings() { new ActionHandler<>(SearchCustomLogTypeAction.INSTANCE, TransportSearchCustomLogTypeAction.class), new ActionHandler<>(DeleteCustomLogTypeAction.INSTANCE, TransportDeleteCustomLogTypeAction.class), new ActionHandler<>(PutTIFJobAction.INSTANCE, TransportPutTIFJobAction.class), - new ActionPlugin.ActionHandler<>(GetCorrelationAlertsAction.INSTANCE, TransportGetCorrelationAlertsAction.class) + new ActionPlugin.ActionHandler<>(GetCorrelationAlertsAction.INSTANCE, TransportGetCorrelationAlertsAction.class), + new ActionPlugin.ActionHandler<>(CorrelationAckAlertsAction.INSTANCE, TransportAckCorrelationAlertsAction.class) ); } diff --git a/src/main/java/org/opensearch/securityanalytics/action/CorrelationAckAlertsAction.java b/src/main/java/org/opensearch/securityanalytics/action/CorrelationAckAlertsAction.java new file mode 100644 index 000000000..7aced9f42 --- /dev/null +++ b/src/main/java/org/opensearch/securityanalytics/action/CorrelationAckAlertsAction.java @@ -0,0 +1,20 @@ +/* + * Copyright OpenSearch Contributors + * SPDX-License-Identifier: Apache-2.0 + */ +package org.opensearch.securityanalytics.action; + +import org.opensearch.action.ActionType; + +/** + * Acknowledge Alert Action + */ +public class CorrelationAckAlertsAction extends ActionType { + public static final String NAME = "cluster:admin/opensearch/securityanalytics/correlationAlerts/ack"; + public static final CorrelationAckAlertsAction INSTANCE = new CorrelationAckAlertsAction(); + + public CorrelationAckAlertsAction() { + super(NAME, CorrelationAckAlertsResponse::new); + } +} + diff --git a/src/main/java/org/opensearch/securityanalytics/action/CorrelationAckAlertsRequest.java b/src/main/java/org/opensearch/securityanalytics/action/CorrelationAckAlertsRequest.java new file mode 100644 index 000000000..b9cf798ae --- /dev/null +++ b/src/main/java/org/opensearch/securityanalytics/action/CorrelationAckAlertsRequest.java @@ -0,0 +1,52 @@ +package org.opensearch.securityanalytics.action; + +import org.opensearch.action.ActionRequest; +import org.opensearch.action.ActionRequestValidationException; +import org.opensearch.action.ValidateActions; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.ToXContent; +import org.opensearch.core.xcontent.XContentBuilder; + +import java.io.IOException; +import java.util.Collections; +import java.util.List; + +public class CorrelationAckAlertsRequest extends ActionRequest { + private final List correlationAlertIds; + + public CorrelationAckAlertsRequest(List correlationAlertIds) { + this.correlationAlertIds = correlationAlertIds; + } + + public CorrelationAckAlertsRequest(StreamInput in) throws IOException { + correlationAlertIds = Collections.unmodifiableList(in.readStringList()); + } + + @Override + public ActionRequestValidationException validate() { + ActionRequestValidationException validationException = null; + if(correlationAlertIds == null || correlationAlertIds.isEmpty()) { + validationException = ValidateActions.addValidationError("alert ids list cannot be empty", validationException); + } + return validationException; + } + + public void writeTo(StreamOutput out) throws IOException { + out.writeStringCollection(this.correlationAlertIds); + } + + public XContentBuilder toXContent(XContentBuilder builder, ToXContent.Params params) throws IOException { + return builder.startObject() + .field("correlation_alert_ids", correlationAlertIds) + .endObject(); + } + + public static AckAlertsRequest readFrom(StreamInput sin) throws IOException { + return new AckAlertsRequest(sin); + } + + public List getCorrelationAlertIds() { + return correlationAlertIds; + } +} diff --git a/src/main/java/org/opensearch/securityanalytics/action/CorrelationAckAlertsResponse.java b/src/main/java/org/opensearch/securityanalytics/action/CorrelationAckAlertsResponse.java new file mode 100644 index 000000000..654f929f5 --- /dev/null +++ b/src/main/java/org/opensearch/securityanalytics/action/CorrelationAckAlertsResponse.java @@ -0,0 +1,52 @@ +package org.opensearch.securityanalytics.action; + +import org.opensearch.commons.alerting.model.CorrelationAlert; +import org.opensearch.core.action.ActionResponse; +import org.opensearch.core.common.io.stream.StreamInput; +import org.opensearch.core.common.io.stream.StreamOutput; +import org.opensearch.core.xcontent.ToXContentObject; +import org.opensearch.core.xcontent.XContentBuilder; + +import java.io.IOException; +import java.util.Collections; +import java.util.List; + +public class CorrelationAckAlertsResponse extends ActionResponse implements ToXContentObject { + + private final List acknowledged; + private final List failed; + + public CorrelationAckAlertsResponse(List acknowledged, List failed) { + this.acknowledged = acknowledged; + this.failed = failed; + } + + public CorrelationAckAlertsResponse(StreamInput sin) throws IOException { + this( + Collections.unmodifiableList(sin.readList(CorrelationAlert::new)), + Collections.unmodifiableList(sin.readList(CorrelationAlert::new)) + ); + } + + @Override + public void writeTo(StreamOutput streamOutput) throws IOException { + streamOutput.writeList(this.acknowledged); + streamOutput.writeList(this.failed); + } + + @Override + public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { + builder.startObject() + .field("acknowledged",this.acknowledged) + .field("failed",this.failed); + return builder.endObject(); + } + + public List getAcknowledged() { + return acknowledged; + } + + public List getFailed() { + return failed; + } +} diff --git a/src/main/java/org/opensearch/securityanalytics/action/GetCorrelationAlertsRequest.java b/src/main/java/org/opensearch/securityanalytics/action/GetCorrelationAlertsRequest.java index 77811f6d1..1eade7bd1 100644 --- a/src/main/java/org/opensearch/securityanalytics/action/GetCorrelationAlertsRequest.java +++ b/src/main/java/org/opensearch/securityanalytics/action/GetCorrelationAlertsRequest.java @@ -58,9 +58,9 @@ public GetCorrelationAlertsRequest(StreamInput sin) throws IOException { @Override public ActionRequestValidationException validate() { ActionRequestValidationException validationException = null; - if ((correlationRuleId == null || correlationRuleId.length() == 0)) { + if ((correlationRuleId != null && correlationRuleId.isEmpty())) { validationException = addValidationError(String.format(Locale.getDefault(), - "At least one of correlation rule id needs to be passed", CORRELATION_RULE_ID), + "Correlation ruleId is empty or not valid", CORRELATION_RULE_ID), validationException); } return validationException; diff --git a/src/main/java/org/opensearch/securityanalytics/action/GetCorrelationAlertsResponse.java b/src/main/java/org/opensearch/securityanalytics/action/GetCorrelationAlertsResponse.java index 52c4ebc96..d883056f8 100644 --- a/src/main/java/org/opensearch/securityanalytics/action/GetCorrelationAlertsResponse.java +++ b/src/main/java/org/opensearch/securityanalytics/action/GetCorrelationAlertsResponse.java @@ -1,18 +1,20 @@ package org.opensearch.securityanalytics.action; +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; import org.opensearch.commons.alerting.model.CorrelationAlert; import org.opensearch.core.action.ActionResponse; import org.opensearch.core.common.io.stream.StreamInput; import org.opensearch.core.common.io.stream.StreamOutput; import org.opensearch.core.xcontent.ToXContentObject; import org.opensearch.core.xcontent.XContentBuilder; - import java.io.IOException; import java.util.Collections; import java.util.List; public class GetCorrelationAlertsResponse extends ActionResponse implements ToXContentObject { + private static final Logger log = LogManager.getLogger(GetCorrelationAlertsResponse.class); private static final String CORRELATION_ALERTS_FIELD = "correlationAlerts"; private static final String TOTAL_ALERTS_FIELD = "total_alerts"; @@ -41,16 +43,8 @@ public void writeTo(StreamOutput out) throws IOException { @Override public XContentBuilder toXContent(XContentBuilder builder, Params params) throws IOException { builder.startObject() - .field(CORRELATION_ALERTS_FIELD, alerts) - .field(TOTAL_ALERTS_FIELD, totalAlerts); + .field(CORRELATION_ALERTS_FIELD, this.alerts) + .field(TOTAL_ALERTS_FIELD, this.totalAlerts); return builder.endObject(); } - - public List getAlerts() { - return this.alerts; - } - - public Integer getTotalAlerts() { - return this.totalAlerts; - } } \ No newline at end of file diff --git a/src/main/java/org/opensearch/securityanalytics/correlation/JoinEngine.java b/src/main/java/org/opensearch/securityanalytics/correlation/JoinEngine.java index 03d4a0b73..20cff273a 100644 --- a/src/main/java/org/opensearch/securityanalytics/correlation/JoinEngine.java +++ b/src/main/java/org/opensearch/securityanalytics/correlation/JoinEngine.java @@ -560,7 +560,6 @@ private void getCorrelatedFindings(String detectorType, Map if (!correlatedFindings.isEmpty()) { CorrelationRuleScheduler correlationRuleScheduler = new CorrelationRuleScheduler(client, correlationAlertService, notificationService); correlationRuleScheduler.schedule(correlationRules, correlatedFindings, request.getFinding().getId(), indexTimeout, user); - correlationRuleScheduler.shutdown(); } for (Map.Entry> autoCorrelation: autoCorrelations.entrySet()) { diff --git a/src/main/java/org/opensearch/securityanalytics/correlation/alert/CorrelationAlertService.java b/src/main/java/org/opensearch/securityanalytics/correlation/alert/CorrelationAlertService.java index bc76176a0..34f369c0a 100644 --- a/src/main/java/org/opensearch/securityanalytics/correlation/alert/CorrelationAlertService.java +++ b/src/main/java/org/opensearch/securityanalytics/correlation/alert/CorrelationAlertService.java @@ -6,16 +6,18 @@ import org.apache.logging.log4j.LogManager; import org.apache.logging.log4j.Logger; +import org.opensearch.action.bulk.BulkItemResponse; +import org.opensearch.action.bulk.BulkRequest; +import org.opensearch.action.bulk.BulkResponse; import org.opensearch.action.index.IndexRequest; import org.opensearch.action.index.IndexResponse; import org.opensearch.action.search.SearchRequest; import org.opensearch.action.search.SearchResponse; +import org.opensearch.action.update.UpdateRequest; import org.opensearch.client.Client; import org.opensearch.common.lucene.uid.Versions; -import org.opensearch.commons.alerting.model.ActionExecutionResult; import org.opensearch.commons.alerting.model.Alert; import org.opensearch.commons.alerting.model.Table; -import org.opensearch.commons.authuser.User; import org.opensearch.common.unit.TimeValue; import org.opensearch.common.xcontent.LoggingDeprecationHandler; import org.opensearch.common.xcontent.XContentFactory; @@ -24,16 +26,18 @@ import org.opensearch.core.xcontent.NamedXContentRegistry; import org.opensearch.core.xcontent.XContentBuilder; import org.opensearch.core.xcontent.XContentParser; -import org.opensearch.core.xcontent.XContentParserUtils; import org.opensearch.index.query.BoolQueryBuilder; import org.opensearch.index.query.QueryBuilders; +import org.opensearch.index.query.TermsQueryBuilder; +import org.opensearch.script.Script; +import org.opensearch.script.ScriptType; import org.opensearch.search.SearchHit; import org.opensearch.search.builder.SearchSourceBuilder; import org.opensearch.commons.alerting.model.CorrelationAlert; import org.opensearch.search.sort.FieldSortBuilder; -import org.opensearch.search.sort.SortBuilder; import org.opensearch.search.sort.SortBuilders; import org.opensearch.search.sort.SortOrder; +import org.opensearch.securityanalytics.action.CorrelationAckAlertsResponse; import org.opensearch.securityanalytics.action.GetCorrelationAlertsResponse; import org.opensearch.securityanalytics.util.CorrelationIndices; import java.io.IOException; @@ -41,6 +45,7 @@ import java.util.List; import java.util.ArrayList; import java.util.Collections; +import java.util.Map; public class CorrelationAlertService { private static final Logger log = LogManager.getLogger(CorrelationAlertService.class); @@ -145,127 +150,17 @@ public void indexCorrelationAlert(CorrelationAlert correlationAlert, TimeValue i } } - public List parseCorrelationAlerts(final SearchResponse response) throws IOException { - List alerts = new ArrayList<>(); - for (SearchHit hit : response.getHits()) { - XContentParser xcp = XContentType.JSON.xContent().createParser( - xContentRegistry, - LoggingDeprecationHandler.INSTANCE, - hit.getSourceAsString() - ); - xcp.nextToken(); - CorrelationAlert correlationAlert = parse(xcp, hit.getId(), hit.getVersion()); - alerts.add(correlationAlert); - } - return alerts; - } - - // logic will be moved to common-utils, once the parsing logic in common-utils is fixed - public static CorrelationAlert parse(XContentParser xcp, String id, long version) throws IOException { - // Parse additional CorrelationAlert-specific fields - List correlatedFindingIds = new ArrayList<>(); - String correlationRuleId = null; - String correlationRuleName = null; - User user = null; - int schemaVersion = 0; - String triggerName = null; - Alert.State state = null; - String errorMessage = null; - String severity = null; - List actionExecutionResults = new ArrayList<>(); - Instant startTime = null; - Instant endTime = null; - Instant acknowledgedTime = null; - - while (xcp.nextToken() != XContentParser.Token.END_OBJECT) { - String fieldName = xcp.currentName(); - xcp.nextToken(); - switch (fieldName) { - case CORRELATED_FINDING_IDS: - XContentParserUtils.ensureExpectedToken(XContentParser.Token.START_ARRAY, xcp.currentToken(), xcp); - while (xcp.nextToken() != XContentParser.Token.END_ARRAY) { - correlatedFindingIds.add(xcp.text()); - } - break; - case CORRELATION_RULE_ID: - correlationRuleId = xcp.text(); - break; - case CORRELATION_RULE_NAME: - correlationRuleName = xcp.text(); - break; - case USER_FIELD: - user = (xcp.currentToken() == XContentParser.Token.VALUE_NULL) ? null : User.parse(xcp); - break; - case ALERT_ID_FIELD: - id = xcp.text(); - break; - case ALERT_VERSION_FIELD: - version = xcp.longValue(); - break; - case SCHEMA_VERSION_FIELD: - schemaVersion = xcp.intValue(); - break; - case TRIGGER_NAME_FIELD: - triggerName = xcp.text(); - break; - case STATE_FIELD: - state = Alert.State.valueOf(xcp.text()); - break; - case ERROR_MESSAGE_FIELD: - errorMessage = xcp.textOrNull(); - break; - case SEVERITY_FIELD: - severity = xcp.text(); - break; - case ACTION_EXECUTION_RESULTS_FIELD: - XContentParserUtils.ensureExpectedToken(XContentParser.Token.START_ARRAY, xcp.currentToken(), xcp); - while (xcp.nextToken() != XContentParser.Token.END_ARRAY) { - actionExecutionResults.add(ActionExecutionResult.parse(xcp)); - } - break; - case START_TIME_FIELD: - startTime = Instant.parse(xcp.text()); - break; - case END_TIME_FIELD: - endTime = Instant.parse(xcp.text()); - break; - case ACKNOWLEDGED_TIME_FIELD: - if (xcp.currentToken() == XContentParser.Token.VALUE_NULL) { - acknowledgedTime = null; - } else { - acknowledgedTime = Instant.parse(xcp.text()); - } - break; - } + public void getAlerts(String ruleId, Table tableProp, ActionListener listener) { + BoolQueryBuilder queryBuilder = QueryBuilders.boolQuery(); + if (ruleId != null) { + queryBuilder = QueryBuilders.boolQuery() + .must(QueryBuilders.termQuery("correlation_rule_id", ruleId)); } - // Create and return CorrelationAlert object - return new CorrelationAlert( - correlatedFindingIds, - correlationRuleId, - correlationRuleName, - id, - version, - schemaVersion, - user, - triggerName, - state, - startTime, - endTime, - acknowledgedTime, - errorMessage, - severity, - actionExecutionResults - ); - } - public void getAlertsByRuleId(String ruleId, Table tableProp, ActionListener listener) { - BoolQueryBuilder queryBuilder = QueryBuilders.boolQuery() - .must(QueryBuilders.termQuery("correlation_rule_id", ruleId)); - FieldSortBuilder sortBuilder = SortBuilders .fieldSort(tableProp.getSortString()) .order(SortOrder.fromString(tableProp.getSortOrder())); - if (!tableProp.getMissing().isEmpty()) { + if (tableProp.getMissing() != null && !tableProp.getMissing().isEmpty()) { sortBuilder.missing(tableProp.getMissing()); } @@ -299,6 +194,100 @@ public void getAlertsByRuleId(String ruleId, Table tableProp, ActionListener alertIds, ActionListener listener) { + BulkRequest bulkRequest = new BulkRequest(); + List acknowledgedAlerts = new ArrayList<>(); + List failedAlerts = new ArrayList<>(); + + TermsQueryBuilder termsQueryBuilder = QueryBuilders.termsQuery("id", alertIds); + SearchSourceBuilder searchSourceBuilder = new SearchSourceBuilder().query(termsQueryBuilder); + SearchRequest searchRequest = new SearchRequest(CorrelationIndices.CORRELATION_ALERT_INDEX) + .source(searchSourceBuilder); + + // Execute the search request + client.search(searchRequest, new ActionListener() { + @Override + public void onResponse(SearchResponse searchResponse) { + // Iterate through the search hits + for (SearchHit hit : searchResponse.getHits().getHits()) { + // Construct a script to update the document with the new state and acknowledgedTime + // Construct a script to update the document with the new state and acknowledgedTime + Script script = new Script(ScriptType.INLINE, "painless", + "ctx._source.state = params.state; ctx._source.acknowledged_time = params.time", + Map.of("state", Alert.State.ACKNOWLEDGED, "time", Instant.now())); + // Create an update request with the script + UpdateRequest updateRequest = new UpdateRequest(CorrelationIndices.CORRELATION_ALERT_INDEX, hit.getId()) + .script(script); + + // Add the update request to the bulk request + bulkRequest.add(updateRequest); + + // Add the current alert to the acknowledged alerts list + try { + acknowledgedAlerts.add(getParsedCorrelationAlert(hit)); + } catch (IOException e) { + log.error("Exception while acknowledging alerts: {}", e.toString()); + } + } + + // Check if there are any update requests in the bulk request + if (!bulkRequest.requests().isEmpty()) { + // Execute the bulk request asynchronously + client.bulk(bulkRequest, new ActionListener() { + @Override + public void onResponse(BulkResponse bulkResponse) { + // Iterate through the bulk response to identify failed updates + for (BulkItemResponse itemResponse : bulkResponse.getItems()) { + if (itemResponse.isFailed()) { + // If an update failed, add the corresponding alert to the failed alerts list + failedAlerts.add(acknowledgedAlerts.get(itemResponse.getItemId())); + } + } + // Create and pass the CorrelationAckAlertsResponse to the listener + listener.onResponse(new CorrelationAckAlertsResponse(acknowledgedAlerts, failedAlerts)); + } + + @Override + public void onFailure(Exception e) { + // Handle failure + listener.onFailure(e); + } + }); + } else { + // If there are no update requests, return an empty response + listener.onResponse(new CorrelationAckAlertsResponse(acknowledgedAlerts, failedAlerts)); + } + } + + @Override + public void onFailure(Exception e) { + // Handle failure + listener.onFailure(e); + } + }); + } + + + public List parseCorrelationAlerts(final SearchResponse response) throws IOException { + List alerts = new ArrayList<>(); + for (SearchHit hit : response.getHits()) { + CorrelationAlert correlationAlert = getParsedCorrelationAlert(hit); + alerts.add(correlationAlert); + } + return alerts; + } + + private CorrelationAlert getParsedCorrelationAlert(SearchHit hit) throws IOException { + XContentParser xcp = XContentType.JSON.xContent().createParser( + xContentRegistry, + LoggingDeprecationHandler.INSTANCE, + hit.getSourceAsString() + ); + xcp.nextToken(); + CorrelationAlert correlationAlert = CorrelationAlertsList.parse(xcp, hit.getId(), hit.getVersion()); + return correlationAlert; + } + } diff --git a/src/main/java/org/opensearch/securityanalytics/correlation/alert/CorrelationAlertsList.java b/src/main/java/org/opensearch/securityanalytics/correlation/alert/CorrelationAlertsList.java index a6cdda9a6..e7e45afe5 100644 --- a/src/main/java/org/opensearch/securityanalytics/correlation/alert/CorrelationAlertsList.java +++ b/src/main/java/org/opensearch/securityanalytics/correlation/alert/CorrelationAlertsList.java @@ -4,8 +4,16 @@ */ package org.opensearch.securityanalytics.correlation.alert; +import org.opensearch.commons.alerting.model.ActionExecutionResult; +import org.opensearch.commons.alerting.model.Alert; import org.opensearch.commons.alerting.model.CorrelationAlert; +import org.opensearch.commons.authuser.User; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.core.xcontent.XContentParserUtils; +import java.io.IOException; +import java.time.Instant; +import java.util.ArrayList; import java.util.List; /** @@ -22,6 +30,105 @@ public CorrelationAlertsList(List correlationAlertList, Intege this.totalAlerts = totalAlerts; } + // logic will be moved to common-utils, once the parsing logic in common-utils is fixed + public static CorrelationAlert parse(XContentParser xcp, String id, long version) throws IOException { + // Parse additional CorrelationAlert-specific fields + List correlatedFindingIds = new ArrayList<>(); + String correlationRuleId = null; + String correlationRuleName = null; + User user = null; + int schemaVersion = 0; + String triggerName = null; + Alert.State state = null; + String errorMessage = null; + String severity = null; + List actionExecutionResults = new ArrayList<>(); + Instant startTime = null; + Instant endTime = null; + Instant acknowledgedTime = null; + + while (xcp.nextToken() != XContentParser.Token.END_OBJECT) { + String fieldName = xcp.currentName(); + xcp.nextToken(); + switch (fieldName) { + case CorrelationAlertService.CORRELATED_FINDING_IDS: + XContentParserUtils.ensureExpectedToken(XContentParser.Token.START_ARRAY, xcp.currentToken(), xcp); + while (xcp.nextToken() != XContentParser.Token.END_ARRAY) { + correlatedFindingIds.add(xcp.text()); + } + break; + case CorrelationAlertService.CORRELATION_RULE_ID: + correlationRuleId = xcp.text(); + break; + case CorrelationAlertService.CORRELATION_RULE_NAME: + correlationRuleName = xcp.text(); + break; + case CorrelationAlertService.USER_FIELD: + user = (xcp.currentToken() == XContentParser.Token.VALUE_NULL) ? null : User.parse(xcp); + break; + case CorrelationAlertService.ALERT_ID_FIELD: + id = xcp.text(); + break; + case CorrelationAlertService.ALERT_VERSION_FIELD: + version = xcp.longValue(); + break; + case CorrelationAlertService.SCHEMA_VERSION_FIELD: + schemaVersion = xcp.intValue(); + break; + case CorrelationAlertService.TRIGGER_NAME_FIELD: + triggerName = xcp.text(); + break; + case CorrelationAlertService.STATE_FIELD: + state = Alert.State.valueOf(xcp.text()); + break; + case CorrelationAlertService.ERROR_MESSAGE_FIELD: + errorMessage = xcp.textOrNull(); + break; + case CorrelationAlertService.SEVERITY_FIELD: + severity = xcp.text(); + break; + case CorrelationAlertService.ACTION_EXECUTION_RESULTS_FIELD: + XContentParserUtils.ensureExpectedToken(XContentParser.Token.START_ARRAY, xcp.currentToken(), xcp); + while (xcp.nextToken() != XContentParser.Token.END_ARRAY) { + actionExecutionResults.add(ActionExecutionResult.parse(xcp)); + } + break; + case CorrelationAlertService.START_TIME_FIELD: + startTime = Instant.parse(xcp.text()); + break; + case CorrelationAlertService.END_TIME_FIELD: + endTime = Instant.parse(xcp.text()); + break; + case CorrelationAlertService.ACKNOWLEDGED_TIME_FIELD: + if (xcp.currentToken() == XContentParser.Token.VALUE_NULL) { + acknowledgedTime = null; + } else { + acknowledgedTime = Instant.parse(xcp.text()); + } + break; + } + } + + // Create and return CorrelationAlert object + return new CorrelationAlert( + correlatedFindingIds, + correlationRuleId, + correlationRuleName, + id, + version, + schemaVersion, + user, + triggerName, + state, + startTime, + endTime, + acknowledgedTime, + errorMessage, + severity, + actionExecutionResults + ); + } + public List getCorrelationAlertList() { return correlationAlertList; } diff --git a/src/main/java/org/opensearch/securityanalytics/correlation/alert/CorrelationRuleScheduler.java b/src/main/java/org/opensearch/securityanalytics/correlation/alert/CorrelationRuleScheduler.java index 945407b15..9c2e23966 100644 --- a/src/main/java/org/opensearch/securityanalytics/correlation/alert/CorrelationRuleScheduler.java +++ b/src/main/java/org/opensearch/securityanalytics/correlation/alert/CorrelationRuleScheduler.java @@ -22,8 +22,6 @@ import java.util.List; import java.util.ArrayList; import java.util.Map; -import java.util.concurrent.ExecutorService; -import java.util.concurrent.Executors; public class CorrelationRuleScheduler { @@ -31,13 +29,11 @@ public class CorrelationRuleScheduler { private final Client client; private final CorrelationAlertService correlationAlertService; private final NotificationService notificationService; - private final ExecutorService executorService; public CorrelationRuleScheduler(Client client, CorrelationAlertService correlationAlertService, NotificationService notificationService) { this.client = client; this.correlationAlertService = correlationAlertService; this.notificationService = notificationService; - this.executorService = Executors.newCachedThreadPool(); } public void schedule(List correlationRules, Map> correlatedFindings, String sourceFinding, TimeValue indexTimeout, User user) { @@ -56,15 +52,12 @@ public void schedule(List correlationRules, Map findingIds, TimeValue indexTimeout, String sourceFindingId, User user) { long startTime = Instant.now().toEpochMilli(); long endTime = startTime + correlationRule.getCorrTimeWindow(); RuleTask ruleTask = new RuleTask(correlationRule, findingIds, startTime, endTime, correlationAlertService, notificationService, indexTimeout, sourceFindingId, user); - executorService.submit(ruleTask); + ruleTask.run(); } private class RuleTask implements Runnable { diff --git a/src/main/java/org/opensearch/securityanalytics/resthandler/RestAcknowledgeCorrelationAlertsAction.java b/src/main/java/org/opensearch/securityanalytics/resthandler/RestAcknowledgeCorrelationAlertsAction.java new file mode 100644 index 000000000..faf36bb7c --- /dev/null +++ b/src/main/java/org/opensearch/securityanalytics/resthandler/RestAcknowledgeCorrelationAlertsAction.java @@ -0,0 +1,68 @@ +package org.opensearch.securityanalytics.resthandler; + +import org.opensearch.client.node.NodeClient; +import org.opensearch.core.xcontent.XContentParser; +import org.opensearch.rest.BaseRestHandler; +import org.opensearch.rest.RestRequest; +import org.opensearch.rest.action.RestToXContentListener; +import org.opensearch.securityanalytics.SecurityAnalyticsPlugin; +import org.opensearch.securityanalytics.action.CorrelationAckAlertsAction; +import org.opensearch.securityanalytics.action.CorrelationAckAlertsRequest; + +import java.io.IOException; +import java.util.ArrayList; +import java.util.Collections; +import java.util.List; +import java.util.Locale; + +import static org.opensearch.core.xcontent.XContentParserUtils.ensureExpectedToken; + + +/** + * Acknowledge list of correlation alerts generated by correlation rules. + */ +public class RestAcknowledgeCorrelationAlertsAction extends BaseRestHandler { + @Override + public String getName() { + return "ack_correlation_alerts_action"; + } + + @Override + public List routes() { + return Collections.singletonList( + new Route(RestRequest.Method.POST, String.format( + Locale.getDefault(), + "%s/_acknowledge/correlationAlerts", + SecurityAnalyticsPlugin.PLUGINS_BASE_URI) + )); + } + + @Override + protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient nodeClient) throws IOException { + List alertIds = getAlertIds(request.contentParser()); + CorrelationAckAlertsRequest CorrelationAckAlertsRequest = new CorrelationAckAlertsRequest(alertIds); + return channel -> nodeClient.execute( + CorrelationAckAlertsAction.INSTANCE, + CorrelationAckAlertsRequest, + new RestToXContentListener<>(channel) + ); + } + + private List getAlertIds(XContentParser xcp) throws IOException { + List ids = new ArrayList<>(); + ensureExpectedToken(XContentParser.Token.START_OBJECT, xcp.nextToken(), xcp); + while (xcp.nextToken() != XContentParser.Token.END_OBJECT) { + String fieldName = xcp.currentName(); + xcp.nextToken(); + if (fieldName.equals("alertIds")) { + ensureExpectedToken(XContentParser.Token.START_ARRAY, xcp.currentToken(), xcp); + while (xcp.nextToken() != XContentParser.Token.END_ARRAY) { + ids.add(xcp.text()); + } + } + + } + return ids; + } +} + diff --git a/src/main/java/org/opensearch/securityanalytics/transport/TransportAckCorrelationAlertsAction.java b/src/main/java/org/opensearch/securityanalytics/transport/TransportAckCorrelationAlertsAction.java new file mode 100644 index 000000000..aee6282f7 --- /dev/null +++ b/src/main/java/org/opensearch/securityanalytics/transport/TransportAckCorrelationAlertsAction.java @@ -0,0 +1,77 @@ +package org.opensearch.securityanalytics.transport; + +import org.apache.logging.log4j.LogManager; +import org.apache.logging.log4j.Logger; +import org.opensearch.OpenSearchStatusException; +import org.opensearch.action.support.ActionFilters; +import org.opensearch.action.support.HandledTransportAction; +import org.opensearch.client.Client; +import org.opensearch.cluster.service.ClusterService; +import org.opensearch.common.inject.Inject; +import org.opensearch.common.settings.Settings; +import org.opensearch.commons.authuser.User; +import org.opensearch.core.action.ActionListener; +import org.opensearch.core.rest.RestStatus; +import org.opensearch.core.xcontent.NamedXContentRegistry; +import org.opensearch.securityanalytics.action.CorrelationAckAlertsAction; +import org.opensearch.securityanalytics.action.GetCorrelationAlertsAction; +import org.opensearch.securityanalytics.action.CorrelationAckAlertsRequest; +import org.opensearch.securityanalytics.action.CorrelationAckAlertsResponse; +import org.opensearch.securityanalytics.correlation.alert.CorrelationAlertService; +import org.opensearch.securityanalytics.settings.SecurityAnalyticsSettings; +import org.opensearch.tasks.Task; +import org.opensearch.threadpool.ThreadPool; +import org.opensearch.transport.TransportService; + +public class TransportAckCorrelationAlertsAction extends HandledTransportAction implements SecureTransportAction { + + private final NamedXContentRegistry xContentRegistry; + + private final ClusterService clusterService; + + private final Settings settings; + + private final ThreadPool threadPool; + + private final CorrelationAlertService correlationAlertService; + + private volatile Boolean filterByEnabled; + + private static final Logger log = LogManager.getLogger(TransportGetCorrelationAlertsAction.class); + + + @Inject + public TransportAckCorrelationAlertsAction(TransportService transportService, ActionFilters actionFilters, ClusterService clusterService, CorrelationAckAlertsAction correlationAckAlertsAction, ThreadPool threadPool, Settings settings, NamedXContentRegistry xContentRegistry, Client client) { + super(correlationAckAlertsAction.NAME, transportService, actionFilters, CorrelationAckAlertsRequest::new); + this.xContentRegistry = xContentRegistry; + this.correlationAlertService = new CorrelationAlertService(client, xContentRegistry); + this.clusterService = clusterService; + this.threadPool = threadPool; + this.settings = settings; + this.filterByEnabled = SecurityAnalyticsSettings.FILTER_BY_BACKEND_ROLES.get(this.settings); + this.clusterService.getClusterSettings().addSettingsUpdateConsumer(SecurityAnalyticsSettings.FILTER_BY_BACKEND_ROLES, this::setFilterByEnabled); + } + + @Override + protected void doExecute(Task task, CorrelationAckAlertsRequest request, ActionListener actionListener) { + + User user = readUserFromThreadContext(this.threadPool); + + String validateBackendRoleMessage = validateUserBackendRoles(user, this.filterByEnabled); + if (!"".equals(validateBackendRoleMessage)) { + actionListener.onFailure(new OpenSearchStatusException("Do not have permissions to resource", RestStatus.FORBIDDEN)); + return; + } + + if (!request.getCorrelationAlertIds().isEmpty()) { + correlationAlertService.acknowledgeAlerts( + request.getCorrelationAlertIds(), + actionListener + ); + } + } + + private void setFilterByEnabled(boolean filterByEnabled) { + this.filterByEnabled = filterByEnabled; + } +} diff --git a/src/main/java/org/opensearch/securityanalytics/transport/TransportGetCorrelationAlertsAction.java b/src/main/java/org/opensearch/securityanalytics/transport/TransportGetCorrelationAlertsAction.java index 2408dbd84..2886deab9 100644 --- a/src/main/java/org/opensearch/securityanalytics/transport/TransportGetCorrelationAlertsAction.java +++ b/src/main/java/org/opensearch/securityanalytics/transport/TransportGetCorrelationAlertsAction.java @@ -61,11 +61,17 @@ protected void doExecute(Task task, GetCorrelationAlertsRequest request, ActionL } if (request.getCorrelationRuleId() != null) { - correlationAlertService.getAlertsByRuleId( + correlationAlertService.getAlerts( request.getCorrelationRuleId(), request.getTable(), actionListener ); + } else { + correlationAlertService.getAlerts( + null, + request.getTable(), + actionListener + ); } }