diff --git a/.eclipseformat.xml b/.eclipseformat.xml new file mode 100644 index 0000000000..6e93f9b22c --- /dev/null +++ b/.eclipseformat.xml @@ -0,0 +1,362 @@ + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + + diff --git a/.github/workflows/CI-workflow.yml b/.github/workflows/CI-workflow.yml index 9bcb5820d3..25421f7ee3 100644 --- a/.github/workflows/CI-workflow.yml +++ b/.github/workflows/CI-workflow.yml @@ -21,11 +21,11 @@ jobs: with: repository: 'opensearch-project/OpenSearch' path: OpenSearch - ref: 1.0.0-beta1 + ref: '1.x' - name: Build OpenSearch working-directory: ./OpenSearch - run: ./gradlew publishToMavenLocal -Dbuild.version_qualifier=beta1 -Dbuild.snapshot=false + run: ./gradlew publishToMavenLocal -Dbuild.version_qualifier=rc1 -Dbuild.snapshot=false - name: Build with Gradle run: ./gradlew build \ No newline at end of file diff --git a/build.gradle b/build.gradle index f37e5ec6ea..4050432337 100644 --- a/build.gradle +++ b/build.gradle @@ -14,7 +14,7 @@ buildscript { ext { ext { opensearch_group = "org.opensearch" - opensearch_version = "1.0.0-beta1" + opensearch_version = "1.0.0-rc1" } } diff --git a/gradle.properties b/gradle.properties index 4dd70a238d..82c9fa77f3 100644 --- a/gradle.properties +++ b/gradle.properties @@ -10,5 +10,5 @@ # # -opensearch_version=1.0.0-beta1 +opensearch_version=1.0.0-rc1 opensearchBaseVersion=1.0.0 \ No newline at end of file diff --git a/plugin/build.gradle b/plugin/build.gradle index d161fbe8f5..9ad7d3e9fe 100644 --- a/plugin/build.gradle +++ b/plugin/build.gradle @@ -15,6 +15,8 @@ plugins { id 'nebula.ospackage' id "io.freefair.lombok" id 'jacoco' + id "com.diffplug.gradle.spotless" version "3.26.1" + id 'checkstyle' } apply plugin: 'opensearch.opensearchplugin' apply plugin: 'opensearch.testclusters' @@ -40,6 +42,7 @@ dependencies { compile("com.fasterxml.jackson.core:jackson-databind:${versions.jackson}") compile group: 'com.google.guava', name: 'guava', version:'29.0-jre' + checkstyle "com.puppycrawl.tools:checkstyle:${project.checkstyle.toolVersion}" } test { @@ -94,18 +97,34 @@ jacocoTestReport { dependsOn test } +List jacocoExclusions = [ + // TODO: add more unit test to meet the minimal test coverage. + 'org.opensearch.ml.action.*', + 'org.opensearch.ml.constant.CommonValue', + 'org.opensearch.ml.indices.MLInputDatasetHandler', + 'org.opensearch.ml.plugin.*', + 'org.opensearch.ml.task.MLTaskRunner', +] + jacocoTestCoverageVerification { violationRules { rule { - limit { - counter = 'LINE' - minimum = 0.7 - } + element = 'CLASS' + excludes = jacocoExclusions limit { counter = 'BRANCH' minimum = 0.8 } } + rule { + element = 'CLASS' + excludes = jacocoExclusions + limit { + counter = 'LINE' + value = 'COVEREDRATIO' + minimum = 0.7 + } + } } dependsOn jacocoTestReport } @@ -115,4 +134,17 @@ configurations.all { resolutionStrategy.force 'junit:junit:4.12' resolutionStrategy.force 'org.apache.commons:commons-lang3:3.10' resolutionStrategy.force 'commons-logging:commons-logging:1.2' +} + +spotless { + java { + removeUnusedImports() + importOrder 'java', 'javax', 'org', 'com' + + eclipse().configFile rootProject.file('.eclipseformat.xml') + } +} + +checkstyle { + toolVersion = '8.29' } \ No newline at end of file diff --git a/plugin/src/main/java/org/opensearch/ml/action/prediction/MLPredictionTaskExecutionAction.java b/plugin/src/main/java/org/opensearch/ml/action/prediction/MLPredictionTaskExecutionAction.java index 8bef5c6728..37f58e44f2 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/prediction/MLPredictionTaskExecutionAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/prediction/MLPredictionTaskExecutionAction.java @@ -12,8 +12,8 @@ package org.opensearch.ml.action.prediction; -import org.opensearch.ml.common.transport.prediction.MLPredictionTaskResponse; import org.opensearch.action.ActionType; +import org.opensearch.ml.common.transport.prediction.MLPredictionTaskResponse; public class MLPredictionTaskExecutionAction extends ActionType { public static MLPredictionTaskExecutionAction INSTANCE = new MLPredictionTaskExecutionAction(); diff --git a/plugin/src/main/java/org/opensearch/ml/action/prediction/MLPredictionTaskExecutionTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/prediction/MLPredictionTaskExecutionTransportAction.java index 21c74156ad..93f2704adc 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/prediction/MLPredictionTaskExecutionTransportAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/prediction/MLPredictionTaskExecutionTransportAction.java @@ -12,13 +12,13 @@ package org.opensearch.ml.action.prediction; -import org.opensearch.ml.task.MLTaskRunner; -import org.opensearch.ml.common.transport.prediction.MLPredictionTaskRequest; -import org.opensearch.ml.common.transport.prediction.MLPredictionTaskResponse; import org.opensearch.action.ActionListener; import org.opensearch.action.support.ActionFilters; import org.opensearch.action.support.HandledTransportAction; import org.opensearch.common.inject.Inject; +import org.opensearch.ml.common.transport.prediction.MLPredictionTaskRequest; +import org.opensearch.ml.common.transport.prediction.MLPredictionTaskResponse; +import org.opensearch.ml.task.MLTaskRunner; import org.opensearch.tasks.Task; import org.opensearch.transport.TransportService; @@ -28,9 +28,9 @@ public class MLPredictionTaskExecutionTransportAction extends HandledTransportAc @Inject public MLPredictionTaskExecutionTransportAction( - ActionFilters actionFilters, - TransportService transportService, - MLTaskRunner mlTaskRunner + ActionFilters actionFilters, + TransportService transportService, + MLTaskRunner mlTaskRunner ) { super(MLPredictionTaskExecutionAction.NAME, transportService, actionFilters, MLPredictionTaskRequest::new); this.mlTaskRunner = mlTaskRunner; diff --git a/plugin/src/main/java/org/opensearch/ml/action/prediction/TransportPredictionTaskAction.java b/plugin/src/main/java/org/opensearch/ml/action/prediction/TransportPredictionTaskAction.java index 01cf966ba9..818eff33b0 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/prediction/TransportPredictionTaskAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/prediction/TransportPredictionTaskAction.java @@ -15,6 +15,7 @@ import lombok.AccessLevel; import lombok.experimental.FieldDefaults; import lombok.extern.log4j.Log4j2; + import org.opensearch.action.ActionListener; import org.opensearch.action.ActionRequest; import org.opensearch.action.support.ActionFilters; @@ -41,8 +42,7 @@ public TransportPredictionTaskAction(TransportService transportService, ActionFi } @Override - protected void doExecute(Task task, ActionRequest request, - ActionListener listener) { + protected void doExecute(Task task, ActionRequest request, ActionListener listener) { MLPredictionTaskRequest mlPredictionTaskRequest = MLPredictionTaskRequest.fromActionRequest(request); mlTaskRunner.runPrediction(mlPredictionTaskRequest, transportService, listener); } diff --git a/plugin/src/main/java/org/opensearch/ml/action/stats/MLStatsNodeRequest.java b/plugin/src/main/java/org/opensearch/ml/action/stats/MLStatsNodeRequest.java index a77589f579..321056e67d 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/stats/MLStatsNodeRequest.java +++ b/plugin/src/main/java/org/opensearch/ml/action/stats/MLStatsNodeRequest.java @@ -10,16 +10,16 @@ * */ - package org.opensearch.ml.action.stats; +import java.io.IOException; + import lombok.Getter; + import org.opensearch.action.support.nodes.BaseNodeRequest; import org.opensearch.common.io.stream.StreamInput; import org.opensearch.common.io.stream.StreamOutput; -import java.io.IOException; - public class MLStatsNodeRequest extends BaseNodeRequest { @Getter private MLStatsNodesRequest mlStatsNodesRequest; diff --git a/plugin/src/main/java/org/opensearch/ml/action/stats/MLStatsNodeResponse.java b/plugin/src/main/java/org/opensearch/ml/action/stats/MLStatsNodeResponse.java index d7e4ed74da..d50654f2bf 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/stats/MLStatsNodeResponse.java +++ b/plugin/src/main/java/org/opensearch/ml/action/stats/MLStatsNodeResponse.java @@ -10,10 +10,13 @@ * */ - package org.opensearch.ml.action.stats; +import java.io.IOException; +import java.util.Map; + import lombok.Getter; + import org.opensearch.action.support.nodes.BaseNodeResponse; import org.opensearch.cluster.node.DiscoveryNode; import org.opensearch.common.io.stream.StreamInput; @@ -21,9 +24,6 @@ import org.opensearch.common.xcontent.ToXContentFragment; import org.opensearch.common.xcontent.XContentBuilder; -import java.io.IOException; -import java.util.Map; - public class MLStatsNodeResponse extends BaseNodeResponse implements ToXContentFragment { @Getter private Map statsMap; diff --git a/plugin/src/main/java/org/opensearch/ml/action/stats/MLStatsNodesAction.java b/plugin/src/main/java/org/opensearch/ml/action/stats/MLStatsNodesAction.java index 445412f2b6..f0a099add4 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/stats/MLStatsNodesAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/stats/MLStatsNodesAction.java @@ -10,11 +10,10 @@ * */ - package org.opensearch.ml.action.stats; -import org.opensearch.ml.constant.CommonValue; import org.opensearch.action.ActionType; +import org.opensearch.ml.constant.CommonValue; public class MLStatsNodesAction extends ActionType { // Internal Action which is not used for public facing RestAPIs. diff --git a/plugin/src/main/java/org/opensearch/ml/action/stats/MLStatsNodesRequest.java b/plugin/src/main/java/org/opensearch/ml/action/stats/MLStatsNodesRequest.java index 83918e1968..9a4559b447 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/stats/MLStatsNodesRequest.java +++ b/plugin/src/main/java/org/opensearch/ml/action/stats/MLStatsNodesRequest.java @@ -10,19 +10,19 @@ * */ - package org.opensearch.ml.action.stats; +import java.io.IOException; +import java.util.HashSet; +import java.util.Set; + import lombok.Getter; + import org.opensearch.action.support.nodes.BaseNodesRequest; import org.opensearch.cluster.node.DiscoveryNode; import org.opensearch.common.io.stream.StreamInput; import org.opensearch.common.io.stream.StreamOutput; -import java.io.IOException; -import java.util.HashSet; -import java.util.Set; - public class MLStatsNodesRequest extends BaseNodesRequest { /** * Key indicating all stats should be retrieved diff --git a/plugin/src/main/java/org/opensearch/ml/action/stats/MLStatsNodesResponse.java b/plugin/src/main/java/org/opensearch/ml/action/stats/MLStatsNodesResponse.java index 7f0e6b633c..696487738c 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/stats/MLStatsNodesResponse.java +++ b/plugin/src/main/java/org/opensearch/ml/action/stats/MLStatsNodesResponse.java @@ -10,9 +10,11 @@ * */ - package org.opensearch.ml.action.stats; +import java.io.IOException; +import java.util.List; + import org.opensearch.action.FailedNodeException; import org.opensearch.action.support.nodes.BaseNodesResponse; import org.opensearch.cluster.ClusterName; @@ -22,9 +24,6 @@ import org.opensearch.common.xcontent.ToXContentObject; import org.opensearch.common.xcontent.XContentBuilder; -import java.io.IOException; -import java.util.List; - public class MLStatsNodesResponse extends BaseNodesResponse implements ToXContentObject { private static final String NODES_KEY = "nodes"; diff --git a/plugin/src/main/java/org/opensearch/ml/action/stats/MLStatsNodesTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/stats/MLStatsNodesTransportAction.java index ca1df81b59..4f3e9dca32 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/stats/MLStatsNodesTransportAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/stats/MLStatsNodesTransportAction.java @@ -10,29 +10,28 @@ * */ - package org.opensearch.ml.action.stats; -import org.opensearch.ml.stats.InternalStatNames; -import org.opensearch.ml.stats.MLStats; +import java.io.IOException; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Set; + import org.opensearch.action.FailedNodeException; import org.opensearch.action.support.ActionFilters; import org.opensearch.action.support.nodes.TransportNodesAction; import org.opensearch.cluster.service.ClusterService; import org.opensearch.common.inject.Inject; import org.opensearch.common.io.stream.StreamInput; +import org.opensearch.ml.stats.InternalStatNames; +import org.opensearch.ml.stats.MLStats; import org.opensearch.monitor.jvm.JvmService; import org.opensearch.threadpool.ThreadPool; import org.opensearch.transport.TransportService; -import java.io.IOException; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import java.util.Set; - public class MLStatsNodesTransportAction extends - TransportNodesAction { + TransportNodesAction { private MLStats mlStats; private final JvmService jvmService; @@ -48,23 +47,23 @@ public class MLStatsNodesTransportAction extends */ @Inject public MLStatsNodesTransportAction( - ThreadPool threadPool, - ClusterService clusterService, - TransportService transportService, - ActionFilters actionFilters, - MLStats mlStats, - JvmService jvmService + ThreadPool threadPool, + ClusterService clusterService, + TransportService transportService, + ActionFilters actionFilters, + MLStats mlStats, + JvmService jvmService ) { super( - MLStatsNodesAction.NAME, - threadPool, - clusterService, - transportService, - actionFilters, - MLStatsNodesRequest::new, - MLStatsNodeRequest::new, - ThreadPool.Names.MANAGEMENT, - MLStatsNodeResponse.class + MLStatsNodesAction.NAME, + threadPool, + clusterService, + transportService, + actionFilters, + MLStatsNodesRequest::new, + MLStatsNodeRequest::new, + ThreadPool.Names.MANAGEMENT, + MLStatsNodeResponse.class ); this.mlStats = mlStats; this.jvmService = jvmService; @@ -72,9 +71,9 @@ public MLStatsNodesTransportAction( @Override protected MLStatsNodesResponse newResponse( - MLStatsNodesRequest request, - List responses, - List failures + MLStatsNodesRequest request, + List responses, + List failures ) { return new MLStatsNodesResponse(clusterService.getClusterName(), responses, failures); } @@ -112,4 +111,3 @@ private MLStatsNodeResponse createMLStatsNodeResponse(MLStatsNodesRequest mlStat return new MLStatsNodeResponse(clusterService.localNode(), statValues); } } - diff --git a/plugin/src/main/java/org/opensearch/ml/action/training/MLTrainingTaskExecutionAction.java b/plugin/src/main/java/org/opensearch/ml/action/training/MLTrainingTaskExecutionAction.java index 2e0f5a8bd0..c89a3e1add 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/training/MLTrainingTaskExecutionAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/training/MLTrainingTaskExecutionAction.java @@ -12,8 +12,8 @@ package org.opensearch.ml.action.training; -import org.opensearch.ml.common.transport.training.MLTrainingTaskResponse; import org.opensearch.action.ActionType; +import org.opensearch.ml.common.transport.training.MLTrainingTaskResponse; public class MLTrainingTaskExecutionAction extends ActionType { diff --git a/plugin/src/main/java/org/opensearch/ml/action/training/MLTrainingTaskExecutionTransportAction.java b/plugin/src/main/java/org/opensearch/ml/action/training/MLTrainingTaskExecutionTransportAction.java index 3e3283705a..13e1780e52 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/training/MLTrainingTaskExecutionTransportAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/training/MLTrainingTaskExecutionTransportAction.java @@ -12,15 +12,16 @@ package org.opensearch.ml.action.training; -import org.opensearch.ml.common.transport.training.MLTrainingTaskRequest; -import org.opensearch.ml.common.transport.training.MLTrainingTaskResponse; import lombok.AccessLevel; import lombok.experimental.FieldDefaults; import lombok.extern.log4j.Log4j2; + import org.opensearch.action.ActionListener; import org.opensearch.action.support.ActionFilters; import org.opensearch.action.support.HandledTransportAction; import org.opensearch.common.inject.Inject; +import org.opensearch.ml.common.transport.training.MLTrainingTaskRequest; +import org.opensearch.ml.common.transport.training.MLTrainingTaskResponse; import org.opensearch.ml.task.MLTaskRunner; import org.opensearch.tasks.Task; import org.opensearch.transport.TransportService; @@ -31,8 +32,11 @@ public class MLTrainingTaskExecutionTransportAction extends HandledTransportActi MLTaskRunner mlTaskRunner; @Inject - public MLTrainingTaskExecutionTransportAction(TransportService transportService, ActionFilters actionFilters, - MLTaskRunner mlTaskRunner) { + public MLTrainingTaskExecutionTransportAction( + TransportService transportService, + ActionFilters actionFilters, + MLTaskRunner mlTaskRunner + ) { super(MLTrainingTaskExecutionAction.NAME, transportService, actionFilters, MLTrainingTaskRequest::new); this.mlTaskRunner = mlTaskRunner; } diff --git a/plugin/src/main/java/org/opensearch/ml/action/training/TransportTrainingTaskAction.java b/plugin/src/main/java/org/opensearch/ml/action/training/TransportTrainingTaskAction.java index a1a53d6189..bbaa320154 100644 --- a/plugin/src/main/java/org/opensearch/ml/action/training/TransportTrainingTaskAction.java +++ b/plugin/src/main/java/org/opensearch/ml/action/training/TransportTrainingTaskAction.java @@ -13,6 +13,7 @@ package org.opensearch.ml.action.training; import lombok.extern.log4j.Log4j2; + import org.opensearch.action.ActionListener; import org.opensearch.action.ActionRequest; import org.opensearch.action.support.ActionFilters; @@ -32,8 +33,7 @@ public class TransportTrainingTaskAction extends HandledTransportAction { - if ( - r == null || - r.getHits() == null || - r.getHits().getTotalHits() == null || - r.getHits().getTotalHits().value == 0 - ) { + if (r == null || r.getHits() == null || r.getHits().getTotalHits() == null || r.getHits().getTotalHits().value == 0) { // todo: add specific exception listener.onFailure(new RuntimeException("No document found")); return; diff --git a/plugin/src/main/java/org/opensearch/ml/model/MLTask.java b/plugin/src/main/java/org/opensearch/ml/model/MLTask.java index e3fee43be9..fe834431ca 100644 --- a/plugin/src/main/java/org/opensearch/ml/model/MLTask.java +++ b/plugin/src/main/java/org/opensearch/ml/model/MLTask.java @@ -12,17 +12,18 @@ package org.opensearch.ml.model; +import java.io.IOException; +import java.time.Instant; + import lombok.Builder; import lombok.EqualsAndHashCode; import lombok.Getter; import lombok.Setter; + import org.opensearch.common.io.stream.StreamInput; import org.opensearch.common.io.stream.StreamOutput; import org.opensearch.common.io.stream.Writeable; -import java.io.IOException; -import java.time.Instant; - @Getter @EqualsAndHashCode public class MLTask implements Writeable { diff --git a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java index bb44d9a213..528a2b83e6 100644 --- a/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java +++ b/plugin/src/main/java/org/opensearch/ml/plugin/MachineLearningPlugin.java @@ -10,10 +10,16 @@ * */ - package org.opensearch.ml.plugin; -import com.google.common.collect.ImmutableMap; +import java.util.Collection; +import java.util.Collections; +import java.util.List; +import java.util.Map; +import java.util.function.Supplier; + +import org.opensearch.action.ActionRequest; +import org.opensearch.action.ActionResponse; import org.opensearch.client.Client; import org.opensearch.cluster.metadata.IndexNameExpressionResolver; import org.opensearch.cluster.node.DiscoveryNodeRole; @@ -45,9 +51,6 @@ import org.opensearch.ml.stats.suppliers.CounterSupplier; import org.opensearch.plugins.ActionPlugin; import org.opensearch.plugins.Plugin; -import org.opensearch.action.ActionRequest; -import org.opensearch.action.ActionResponse; -import com.google.common.collect.ImmutableList; import org.opensearch.repositories.RepositoriesService; import org.opensearch.rest.RestController; import org.opensearch.rest.RestHandler; @@ -57,11 +60,8 @@ import org.opensearch.threadpool.ThreadPool; import org.opensearch.watcher.ResourceWatcherService; -import java.util.Collection; -import java.util.Collections; -import java.util.List; -import java.util.Map; -import java.util.function.Supplier; +import com.google.common.collect.ImmutableList; +import com.google.common.collect.ImmutableMap; public class MachineLearningPlugin extends Plugin implements ActionPlugin { public static final String TASK_THREAD_POOL = "OPENSEARCH_ML_TASK_THREAD_POOL"; @@ -69,8 +69,7 @@ public class MachineLearningPlugin extends Plugin implements ActionPlugin { private MLStats mlStats; - public static final Setting IS_ML_NODE_SETTING = - Setting.boolSetting("node.ml", false, Setting.Property.NodeScope); + public static final Setting IS_ML_NODE_SETTING = Setting.boolSetting("node.ml", false, Setting.Property.NodeScope); public static final DiscoveryNodeRole ML_ROLE = new DiscoveryNodeRole("ml", "l") { @Override @@ -79,64 +78,57 @@ public Setting legacySetting() { } }; - @Override public List> getActions() { - return ImmutableList.of( - new ActionHandler<>(MLStatsNodesAction.INSTANCE, - MLStatsNodesTransportAction.class), + return ImmutableList + .of( + new ActionHandler<>(MLStatsNodesAction.INSTANCE, MLStatsNodesTransportAction.class), new ActionHandler<>(MLPredictionTaskAction.INSTANCE, TransportPredictionTaskAction.class), new ActionHandler<>(MLTrainingTaskAction.INSTANCE, TransportTrainingTaskAction.class), - new ActionHandler<>(MLPredictionTaskExecutionAction.INSTANCE, - MLPredictionTaskExecutionTransportAction.class), + new ActionHandler<>(MLPredictionTaskExecutionAction.INSTANCE, MLPredictionTaskExecutionTransportAction.class), new ActionHandler<>(MLTrainingTaskExecutionAction.INSTANCE, MLTrainingTaskExecutionTransportAction.class) - ); + ); } @Override - public Collection createComponents(Client client, ClusterService clusterService, ThreadPool threadPool, - ResourceWatcherService resourceWatcherService, - ScriptService scriptService, - NamedXContentRegistry xContentRegistry, Environment environment, - NodeEnvironment nodeEnvironment, - NamedWriteableRegistry namedWriteableRegistry, - IndexNameExpressionResolver indexNameExpressionResolver, - Supplier repositoriesServiceSupplier) { + public Collection createComponents( + Client client, + ClusterService clusterService, + ThreadPool threadPool, + ResourceWatcherService resourceWatcherService, + ScriptService scriptService, + NamedXContentRegistry xContentRegistry, + Environment environment, + NodeEnvironment nodeEnvironment, + NamedWriteableRegistry namedWriteableRegistry, + IndexNameExpressionResolver indexNameExpressionResolver, + Supplier repositoriesServiceSupplier + ) { Map> stats = ImmutableMap - .>builder() - .put(StatNames.ML_EXECUTING_TASK_COUNT.getName(), new MLStat<>(false, new CounterSupplier())) - .build(); + .>builder() + .put(StatNames.ML_EXECUTING_TASK_COUNT.getName(), new MLStat<>(false, new CounterSupplier())) + .build(); this.mlStats = new MLStats(stats); return ImmutableList.of(mlStats); } @Override public List getRestHandlers( - Settings settings, - RestController restController, - ClusterSettings clusterSettings, - IndexScopedSettings indexScopedSettings, - SettingsFilter settingsFilter, - IndexNameExpressionResolver indexNameExpressionResolver, - Supplier nodesInCluster + Settings settings, + RestController restController, + ClusterSettings clusterSettings, + IndexScopedSettings indexScopedSettings, + SettingsFilter settingsFilter, + IndexNameExpressionResolver indexNameExpressionResolver, + Supplier nodesInCluster ) { RestStatsMLAction restStatsMLAction = new RestStatsMLAction(mlStats); - return ImmutableList - .of( - restStatsMLAction - ); + return ImmutableList.of(restStatsMLAction); } @Override public List> getExecutorBuilders(Settings settings) { - FixedExecutorBuilder ml = new FixedExecutorBuilder( - settings, - TASK_THREAD_POOL, - 4, - 4, - "ml.task_thread_pool", - false - ); + FixedExecutorBuilder ml = new FixedExecutorBuilder(settings, TASK_THREAD_POOL, 4, 4, "ml.task_thread_pool", false); return Collections.singletonList(ml); } diff --git a/plugin/src/main/java/org/opensearch/ml/plugin/package-info.java b/plugin/src/main/java/org/opensearch/ml/plugin/package-info.java index 97a3ccd044..65dc88ce3d 100644 --- a/plugin/src/main/java/org/opensearch/ml/plugin/package-info.java +++ b/plugin/src/main/java/org/opensearch/ml/plugin/package-info.java @@ -10,4 +10,4 @@ * */ -package org.opensearch.ml.plugin; \ No newline at end of file +package org.opensearch.ml.plugin; diff --git a/plugin/src/main/java/org/opensearch/ml/rest/RestStatsMLAction.java b/plugin/src/main/java/org/opensearch/ml/rest/RestStatsMLAction.java index 7a9a412574..be37d03bc1 100644 --- a/plugin/src/main/java/org/opensearch/ml/rest/RestStatsMLAction.java +++ b/plugin/src/main/java/org/opensearch/ml/rest/RestStatsMLAction.java @@ -12,15 +12,7 @@ package org.opensearch.ml.rest; -import com.google.common.annotations.VisibleForTesting; -import org.opensearch.ml.action.stats.MLStatsNodesAction; -import org.opensearch.ml.action.stats.MLStatsNodesRequest; -import org.opensearch.ml.stats.MLStats; -import com.google.common.collect.ImmutableList; -import org.opensearch.client.node.NodeClient; -import org.opensearch.rest.BaseRestHandler; -import org.opensearch.rest.RestRequest; -import org.opensearch.rest.action.RestToXContentListener; +import static org.opensearch.ml.plugin.MachineLearningPlugin.ML_BASE_URI; import java.util.Arrays; import java.util.Collections; @@ -30,7 +22,16 @@ import java.util.Set; import java.util.stream.Collectors; -import static org.opensearch.ml.plugin.MachineLearningPlugin.ML_BASE_URI; +import org.opensearch.client.node.NodeClient; +import org.opensearch.ml.action.stats.MLStatsNodesAction; +import org.opensearch.ml.action.stats.MLStatsNodesRequest; +import org.opensearch.ml.stats.MLStats; +import org.opensearch.rest.BaseRestHandler; +import org.opensearch.rest.RestRequest; +import org.opensearch.rest.action.RestToXContentListener; + +import com.google.common.annotations.VisibleForTesting; +import com.google.common.collect.ImmutableList; public class RestStatsMLAction extends BaseRestHandler { private static final String STATS_ML_ACTION = "stats_ml"; @@ -50,19 +51,17 @@ public String getName() { return STATS_ML_ACTION; } - @Override public List routes() { return ImmutableList - .of( - new Route(RestRequest.Method.GET, ML_BASE_URI + "/{nodeId}/stats/"), - new Route(RestRequest.Method.GET, ML_BASE_URI + "/{nodeId}/stats/{stat}"), - new Route(RestRequest.Method.GET, ML_BASE_URI + "/stats/"), - new Route(RestRequest.Method.GET, ML_BASE_URI + "/stats/{stat}") - ); + .of( + new Route(RestRequest.Method.GET, ML_BASE_URI + "/{nodeId}/stats/"), + new Route(RestRequest.Method.GET, ML_BASE_URI + "/{nodeId}/stats/{stat}"), + new Route(RestRequest.Method.GET, ML_BASE_URI + "/stats/"), + new Route(RestRequest.Method.GET, ML_BASE_URI + "/stats/{stat}") + ); } - @Override protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient client) { MLStatsNodesRequest mlStatsNodesRequest = getRequest(request); @@ -78,14 +77,10 @@ protected RestChannelConsumer prepareRequest(RestRequest request, NodeClient cli @VisibleForTesting MLStatsNodesRequest getRequest(RestRequest request) { // todo: add logic to triage request based on node type(ML node or data node) - MLStatsNodesRequest mlStatsRequest = new MLStatsNodesRequest( - splitCommaSeparatedParam(request, "nodeId").orElse(null)); + MLStatsNodesRequest mlStatsRequest = new MLStatsNodesRequest(splitCommaSeparatedParam(request, "nodeId").orElse(null)); mlStatsRequest.timeout(request.param("timeout")); - List requestedStats = - splitCommaSeparatedParam(request, "stat") - .map(Arrays::asList) - .orElseGet(Collections::emptyList); + List requestedStats = splitCommaSeparatedParam(request, "stat").map(Arrays::asList).orElseGet(Collections::emptyList); Set validStats = mlStats.getStats().keySet(); if (isAllStatsRequested(requestedStats)) { @@ -98,35 +93,28 @@ MLStatsNodesRequest getRequest(RestRequest request) { } @VisibleForTesting - Set getStatsToBeRetrieved( - RestRequest request, Set validStats, List requestedStats) { + Set getStatsToBeRetrieved(RestRequest request, Set validStats, List requestedStats) { if (requestedStats.contains(MLStatsNodesRequest.ALL_STATS_KEY)) { throw new IllegalArgumentException( - String.format("Request %s contains both %s and individual stats", - request.path(), MLStatsNodesRequest.ALL_STATS_KEY)); + String.format("Request %s contains both %s and individual stats", request.path(), MLStatsNodesRequest.ALL_STATS_KEY) + ); } - Set invalidStats = - requestedStats.stream() - .filter(s -> !validStats.contains(s)) - .collect(Collectors.toSet()); + Set invalidStats = requestedStats.stream().filter(s -> !validStats.contains(s)).collect(Collectors.toSet()); if (!invalidStats.isEmpty()) { - throw new IllegalArgumentException( - unrecognized(request, invalidStats, new HashSet<>(requestedStats), "stat")); + throw new IllegalArgumentException(unrecognized(request, invalidStats, new HashSet<>(requestedStats), "stat")); } return new HashSet<>(requestedStats); } @VisibleForTesting boolean isAllStatsRequested(List requestedStats) { - return requestedStats.isEmpty() - || (requestedStats.size() == 1 && requestedStats.contains(MLStatsNodesRequest.ALL_STATS_KEY)); + return requestedStats.isEmpty() || (requestedStats.size() == 1 && requestedStats.contains(MLStatsNodesRequest.ALL_STATS_KEY)); } @VisibleForTesting Optional splitCommaSeparatedParam(RestRequest request, String paramName) { - return Optional.ofNullable(request.param(paramName)) - .map(s -> s.split(",")); + return Optional.ofNullable(request.param(paramName)).map(s -> s.split(",")); } } diff --git a/plugin/src/main/java/org/opensearch/ml/stats/MLStat.java b/plugin/src/main/java/org/opensearch/ml/stats/MLStat.java index 726a71b0d8..60732f05f1 100644 --- a/plugin/src/main/java/org/opensearch/ml/stats/MLStat.java +++ b/plugin/src/main/java/org/opensearch/ml/stats/MLStat.java @@ -12,12 +12,13 @@ package org.opensearch.ml.stats; +import java.util.function.Supplier; + import lombok.Getter; + import org.opensearch.ml.stats.suppliers.CounterSupplier; import org.opensearch.ml.stats.suppliers.SettableSupplier; -import java.util.function.Supplier; - /** * Class represents a stat the ML plugin keeps track of */ diff --git a/plugin/src/main/java/org/opensearch/ml/stats/MLStats.java b/plugin/src/main/java/org/opensearch/ml/stats/MLStats.java index 3c2c700bb6..91fbf32c9d 100644 --- a/plugin/src/main/java/org/opensearch/ml/stats/MLStats.java +++ b/plugin/src/main/java/org/opensearch/ml/stats/MLStats.java @@ -12,11 +12,11 @@ package org.opensearch.ml.stats; -import lombok.Getter; - import java.util.HashMap; import java.util.Map; +import lombok.Getter; + /** * This class is the main entry-point for access to the stats that the ML plugin keeps track of. */ diff --git a/plugin/src/main/java/org/opensearch/ml/stats/StatNames.java b/plugin/src/main/java/org/opensearch/ml/stats/StatNames.java index 09aa1bc78e..57d0d339e5 100644 --- a/plugin/src/main/java/org/opensearch/ml/stats/StatNames.java +++ b/plugin/src/main/java/org/opensearch/ml/stats/StatNames.java @@ -12,11 +12,11 @@ package org.opensearch.ml.stats; -import lombok.Getter; - import java.util.HashSet; import java.util.Set; +import lombok.Getter; + /** * Enum containing names of all stats */ diff --git a/plugin/src/main/java/org/opensearch/ml/task/MLTaskManager.java b/plugin/src/main/java/org/opensearch/ml/task/MLTaskManager.java index 092deaeb5b..fea6063b15 100644 --- a/plugin/src/main/java/org/opensearch/ml/task/MLTaskManager.java +++ b/plugin/src/main/java/org/opensearch/ml/task/MLTaskManager.java @@ -12,12 +12,12 @@ package org.opensearch.ml.task; -import org.opensearch.ml.model.MLTask; -import org.opensearch.ml.model.MLTaskState; - import java.util.Map; import java.util.concurrent.ConcurrentHashMap; +import org.opensearch.ml.model.MLTask; +import org.opensearch.ml.model.MLTaskState; + /** * MLTaskManager is responsible for managing MLTask. */ @@ -25,6 +25,7 @@ public class MLTaskManager { private final Map taskCaches; // todo make this value configurable in the future public final static int MAX_ML_TASK_PER_NODE = 10; + /** * Constructor to create ML task manager. * diff --git a/plugin/src/main/java/org/opensearch/ml/task/MLTaskRunner.java b/plugin/src/main/java/org/opensearch/ml/task/MLTaskRunner.java index 096e9c9b73..6a0bfd2adc 100644 --- a/plugin/src/main/java/org/opensearch/ml/task/MLTaskRunner.java +++ b/plugin/src/main/java/org/opensearch/ml/task/MLTaskRunner.java @@ -12,8 +12,25 @@ package org.opensearch.ml.task; -import com.google.common.collect.ImmutableSet; +import static org.opensearch.ml.indices.MLIndicesHandler.OS_ML_MODEL_RESULT; +import static org.opensearch.ml.plugin.MachineLearningPlugin.TASK_THREAD_POOL; +import static org.opensearch.ml.stats.InternalStatNames.JVM_HEAP_USAGE; +import static org.opensearch.ml.stats.StatNames.ML_EXECUTING_TASK_COUNT; + +import java.time.Instant; +import java.util.ArrayList; +import java.util.Base64; +import java.util.HashMap; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.UUID; +import java.util.stream.Collectors; + +import javax.naming.LimitExceededException; + import lombok.extern.log4j.Log4j2; + import org.opensearch.action.ActionListener; import org.opensearch.action.ActionListenerResponseHandler; import org.opensearch.action.get.GetResponse; @@ -46,21 +63,7 @@ import org.opensearch.threadpool.ThreadPool; import org.opensearch.transport.TransportService; -import javax.naming.LimitExceededException; -import java.time.Instant; -import java.util.ArrayList; -import java.util.Base64; -import java.util.HashMap; -import java.util.List; -import java.util.Map; -import java.util.Optional; -import java.util.UUID; -import java.util.stream.Collectors; - -import static org.opensearch.ml.plugin.MachineLearningPlugin.TASK_THREAD_POOL; -import static org.opensearch.ml.stats.StatNames.ML_EXECUTING_TASK_COUNT; -import static org.opensearch.ml.stats.InternalStatNames.JVM_HEAP_USAGE; -import static org.opensearch.ml.indices.MLIndicesHandler.OS_ML_MODEL_RESULT; +import com.google.common.collect.ImmutableSet; /** * MLTaskRunner is responsible for dispatching and running predict/training tasks. @@ -79,13 +82,13 @@ public class MLTaskRunner { private volatile Integer maxAdBatchTaskPerNode; public MLTaskRunner( - ThreadPool threadPool, - ClusterService clusterService, - Client client, - MLTaskManager mlTaskManager, - MLStats mlStats, - MLIndicesHandler mlIndicesHandler, - MLInputDatasetHandler mlInputDatasetHandler + ThreadPool threadPool, + ClusterService clusterService, + Client client, + MLTaskManager mlTaskManager, + MLStats mlStats, + MLIndicesHandler mlIndicesHandler, + MLInputDatasetHandler mlInputDatasetHandler ) { this.threadPool = threadPool; this.clusterService = clusterService; @@ -110,15 +113,15 @@ public void dispatchTask(ActionListener listener) { client.execute(MLStatsNodesAction.INSTANCE, MLStatsNodesRequest, ActionListener.wrap(mlStatsResponse -> { // Check JVM pressure List candidateNodeResponse = mlStatsResponse - .getNodes() - .stream() - .filter(stat -> (long) stat.getStatsMap().get(JVM_HEAP_USAGE.getName()) < DEFAULT_JVM_HEAP_USAGE_THRESHOLD) - .collect(Collectors.toList()); + .getNodes() + .stream() + .filter(stat -> (long) stat.getStatsMap().get(JVM_HEAP_USAGE.getName()) < DEFAULT_JVM_HEAP_USAGE_THRESHOLD) + .collect(Collectors.toList()); if (candidateNodeResponse.size() == 0) { String errorMessage = "All nodes' memory usage exceeds limitation" - + DEFAULT_JVM_HEAP_USAGE_THRESHOLD - + ". No eligible node to run ml jobs "; + + DEFAULT_JVM_HEAP_USAGE_THRESHOLD + + ". No eligible node to run ml jobs "; log.warn(errorMessage); listener.onFailure(new LimitExceededException(errorMessage)); return; @@ -126,9 +129,9 @@ public void dispatchTask(ActionListener listener) { // Check # of executing ML task candidateNodeResponse = candidateNodeResponse - .stream() - .filter(stat -> (Long) stat.getStatsMap().get(ML_EXECUTING_TASK_COUNT.getName()) < maxAdBatchTaskPerNode) - .collect(Collectors.toList()); + .stream() + .filter(stat -> (Long) stat.getStatsMap().get(ML_EXECUTING_TASK_COUNT.getName()) < maxAdBatchTaskPerNode) + .collect(Collectors.toList()); if (candidateNodeResponse.size() == 0) { String errorMessage = "All nodes' executing ML task count exceeds limitation."; log.warn(errorMessage); @@ -138,19 +141,19 @@ public void dispatchTask(ActionListener listener) { // sort nodes by JVM usage percentage and # of executing ML task Optional targetNode = candidateNodeResponse - .stream() - .sorted((MLStatsNodeResponse r1, MLStatsNodeResponse r2) -> { - int result = ((Long) r1.getStatsMap().get(ML_EXECUTING_TASK_COUNT.getName())) - .compareTo((Long) r2.getStatsMap().get(ML_EXECUTING_TASK_COUNT.getName())); - if (result == 0) { - // if multiple nodes have same running task count, choose the one with least - // JVM heap usage. - return ((Long) r1.getStatsMap().get(JVM_HEAP_USAGE.getName())) - .compareTo((Long) r2.getStatsMap().get(JVM_HEAP_USAGE.getName())); - } - return result; - }) - .findFirst(); + .stream() + .sorted((MLStatsNodeResponse r1, MLStatsNodeResponse r2) -> { + int result = ((Long) r1.getStatsMap().get(ML_EXECUTING_TASK_COUNT.getName())) + .compareTo((Long) r2.getStatsMap().get(ML_EXECUTING_TASK_COUNT.getName())); + if (result == 0) { + // if multiple nodes have same running task count, choose the one with least + // JVM heap usage. + return ((Long) r1.getStatsMap().get(JVM_HEAP_USAGE.getName())) + .compareTo((Long) r2.getStatsMap().get(JVM_HEAP_USAGE.getName())); + } + return result; + }) + .findFirst(); listener.onResponse(targetNode.get().getNode()); }, exception -> { log.error("Failed to get node's task stats", exception); @@ -164,32 +167,26 @@ public void dispatchTask(ActionListener listener) { * @param transportService transport service * @param listener Action listener */ - public void runPrediction(MLPredictionTaskRequest request, TransportService transportService, ActionListener listener) { + public void runPrediction( + MLPredictionTaskRequest request, + TransportService transportService, + ActionListener listener + ) { dispatchTask(ActionListener.wrap(node -> { if (clusterService.localNode().getId().equals(node.getId())) { // Execute prediction task locally - log - .info( - "execute ML prediction request {} locally on node {}", - request.toString(), - node.getId() - ); + log.info("execute ML prediction request {} locally on node {}", request.toString(), node.getId()); startPredictionTask(request, listener); } else { // Execute batch task remotely - log - .info( - "execute ML prediction request {} remotely on node {}", - request.toString(), - node.getId() - ); + log.info("execute ML prediction request {} remotely on node {}", request.toString(), node.getId()); transportService - .sendRequest( - node, - MLPredictionTaskExecutionAction.NAME, - request, - new ActionListenerResponseHandler<>(listener, MLPredictionTaskResponse::new) - ); + .sendRequest( + node, + MLPredictionTaskExecutionAction.NAME, + request, + new ActionListenerResponseHandler<>(listener, MLPredictionTaskResponse::new) + ); } }, e -> listener.onFailure(e))); } @@ -199,37 +196,40 @@ public void runPrediction(MLPredictionTaskRequest request, TransportService tran * @param request MLPredictionTaskRequest * @param listener Action listener */ - public void startPredictionTask( - MLPredictionTaskRequest request, - ActionListener listener - ) { - MLTask mlTask = MLTask.builder() - .taskId(UUID.randomUUID().toString()) - .taskType(MLTaskType.PREDICTION) - .createTime(Instant.now()) - .state(MLTaskState.CREATED) - .build(); + public void startPredictionTask(MLPredictionTaskRequest request, ActionListener listener) { + MLTask mlTask = MLTask + .builder() + .taskId(UUID.randomUUID().toString()) + .taskType(MLTaskType.PREDICTION) + .createTime(Instant.now()) + .state(MLTaskState.CREATED) + .build(); if (request.getInputDataset().getInputDataType().equals(MLInputDataType.SEARCH_QUERY)) { - ActionListener dataFrameActionListener = ActionListener.wrap(dataFrame -> { - predict(mlTask, dataFrame, request, listener); - }, e -> { - log.error("Failed to generate DataFrame from search query", e); - mlTaskManager.add(mlTask); - mlTaskManager.updateTaskState(mlTask.getTaskId(), MLTaskState.FAILED); - mlTaskManager.updateTaskError(mlTask.getTaskId(), e.getMessage()); - listener.onFailure(e); - }); - mlInputDatasetHandler.parseSearchQueryInput(request.getInputDataset(), - new ThreadedActionListener<>(log, threadPool, TASK_THREAD_POOL, dataFrameActionListener, false)); + ActionListener dataFrameActionListener = ActionListener + .wrap(dataFrame -> { predict(mlTask, dataFrame, request, listener); }, e -> { + log.error("Failed to generate DataFrame from search query", e); + mlTaskManager.add(mlTask); + mlTaskManager.updateTaskState(mlTask.getTaskId(), MLTaskState.FAILED); + mlTaskManager.updateTaskError(mlTask.getTaskId(), e.getMessage()); + listener.onFailure(e); + }); + mlInputDatasetHandler + .parseSearchQueryInput( + request.getInputDataset(), + new ThreadedActionListener<>(log, threadPool, TASK_THREAD_POOL, dataFrameActionListener, false) + ); } else { DataFrame inputDataFrame = mlInputDatasetHandler.parseDataFrameInput(request.getInputDataset()); - threadPool.executor(TASK_THREAD_POOL).execute(() -> { - predict(mlTask, inputDataFrame, request, listener); - }); + threadPool.executor(TASK_THREAD_POOL).execute(() -> { predict(mlTask, inputDataFrame, request, listener); }); } } - private void predict(MLTask mlTask, DataFrame inputDataFrame, MLPredictionTaskRequest request, ActionListener listener) { + private void predict( + MLTask mlTask, + DataFrame inputDataFrame, + MLPredictionTaskRequest request, + ActionListener listener + ) { // track ML task count and add ML task into cache mlStats.getStat(ML_EXECUTING_TASK_COUNT.getName()).increment(); mlTaskManager.add(mlTask); @@ -266,10 +266,12 @@ private void predict(MLTask mlTask, DataFrame inputDataFrame, MLPredictionTaskRe listener.onFailure(e); } - MLPredictionTaskResponse response = MLPredictionTaskResponse.builder() - .taskId(mlTask.getTaskId()) - .status(mlTaskManager.get(mlTask.getTaskId()).getState().name()) - .predictionResult(forecastsResults).build(); + MLPredictionTaskResponse response = MLPredictionTaskResponse + .builder() + .taskId(mlTask.getTaskId()) + .status(mlTaskManager.get(mlTask.getTaskId()).getState().name()) + .predictionResult(forecastsResults) + .build(); listener.onResponse(response); } @@ -279,32 +281,26 @@ private void predict(MLTask mlTask, DataFrame inputDataFrame, MLPredictionTaskRe * @param transportService transport service * @param listener Action listener */ - public void runTraining(MLTrainingTaskRequest request, TransportService transportService, ActionListener listener) { + public void runTraining( + MLTrainingTaskRequest request, + TransportService transportService, + ActionListener listener + ) { dispatchTask(ActionListener.wrap(node -> { if (clusterService.localNode().getId().equals(node.getId())) { // Execute training task locally - log - .info( - "execute ML training request {} locally on node {}", - request.toString(), - node.getId() - ); + log.info("execute ML training request {} locally on node {}", request.toString(), node.getId()); startTrainingTask(request, listener); } else { // Execute batch task remotely - log - .info( - "execute ML training request {} remotely on node {}", - request.toString(), - node.getId() - ); + log.info("execute ML training request {} remotely on node {}", request.toString(), node.getId()); transportService - .sendRequest( - node, - MLTrainingTaskExecutionAction.NAME, - request, - new ActionListenerResponseHandler<>(listener, MLTrainingTaskResponse::new) - ); + .sendRequest( + node, + MLTrainingTaskExecutionAction.NAME, + request, + new ActionListenerResponseHandler<>(listener, MLTrainingTaskResponse::new) + ); } }, e -> listener.onFailure(e))); } @@ -314,36 +310,31 @@ public void runTraining(MLTrainingTaskRequest request, TransportService transpor * @param request MLTrainingTaskRequest * @param listener Action listener */ - public void startTrainingTask( - MLTrainingTaskRequest request, - ActionListener listener - ) { - MLTask mlTask = MLTask.builder() - .taskId(UUID.randomUUID().toString()) - .taskType(MLTaskType.TRAINING) - .createTime(Instant.now()) - .state(MLTaskState.CREATED) - .build(); - listener.onResponse(MLTrainingTaskResponse.builder() - .taskId(mlTask.getTaskId()) - .status(MLTaskState.CREATED.name()) - .build()); + public void startTrainingTask(MLTrainingTaskRequest request, ActionListener listener) { + MLTask mlTask = MLTask + .builder() + .taskId(UUID.randomUUID().toString()) + .taskType(MLTaskType.TRAINING) + .createTime(Instant.now()) + .state(MLTaskState.CREATED) + .build(); + listener.onResponse(MLTrainingTaskResponse.builder().taskId(mlTask.getTaskId()).status(MLTaskState.CREATED.name()).build()); if (request.getInputDataset().getInputDataType().equals(MLInputDataType.SEARCH_QUERY)) { - ActionListener dataFrameActionListener = ActionListener.wrap(dataFrame -> { - train(mlTask, dataFrame, request); - }, e -> { - log.error("Failed to generate DataFrame from search query", e); - mlTaskManager.add(mlTask); - mlTaskManager.updateTaskState(mlTask.getTaskId(), MLTaskState.FAILED); - mlTaskManager.updateTaskError(mlTask.getTaskId(), e.getMessage()); - }); - mlInputDatasetHandler.parseSearchQueryInput(request.getInputDataset(), new ThreadedActionListener<>( - log, threadPool, TASK_THREAD_POOL, dataFrameActionListener, false)); + ActionListener dataFrameActionListener = ActionListener + .wrap(dataFrame -> { train(mlTask, dataFrame, request); }, e -> { + log.error("Failed to generate DataFrame from search query", e); + mlTaskManager.add(mlTask); + mlTaskManager.updateTaskState(mlTask.getTaskId(), MLTaskState.FAILED); + mlTaskManager.updateTaskError(mlTask.getTaskId(), e.getMessage()); + }); + mlInputDatasetHandler + .parseSearchQueryInput( + request.getInputDataset(), + new ThreadedActionListener<>(log, threadPool, TASK_THREAD_POOL, dataFrameActionListener, false) + ); } else { DataFrame inputDataFrame = mlInputDatasetHandler.parseDataFrameInput(request.getInputDataset()); - threadPool.executor(TASK_THREAD_POOL).execute(() -> { - train(mlTask, inputDataFrame, request); - }); + threadPool.executor(TASK_THREAD_POOL).execute(() -> { train(mlTask, inputDataFrame, request); }); } } @@ -390,7 +381,6 @@ private void handleMLTaskComplete(MLTask mlTask) { mlTaskManager.updateTaskState(mlTask.getTaskId(), MLTaskState.COMPLETED); } - private DiscoveryNode[] getEligibleMLNodes() { ClusterState state = this.clusterService.state(); final List eligibleNodes = new ArrayList<>(); diff --git a/plugin/src/main/java/org/opensearch/ml/utils/MLNodeUtils.java b/plugin/src/main/java/org/opensearch/ml/utils/MLNodeUtils.java index cb344be828..6c4302ec8e 100644 --- a/plugin/src/main/java/org/opensearch/ml/utils/MLNodeUtils.java +++ b/plugin/src/main/java/org/opensearch/ml/utils/MLNodeUtils.java @@ -12,9 +12,10 @@ package org.opensearch.ml.utils; -import org.opensearch.ml.plugin.MachineLearningPlugin; import lombok.experimental.UtilityClass; + import org.opensearch.cluster.node.DiscoveryNode; +import org.opensearch.ml.plugin.MachineLearningPlugin; @UtilityClass public class MLNodeUtils { diff --git a/plugin/src/test/java/org/opensearch/ml/action/stats/MLStatsNodeRequestTests.java b/plugin/src/test/java/org/opensearch/ml/action/stats/MLStatsNodeRequestTests.java index 85f033da63..41a80c3061 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/stats/MLStatsNodeRequestTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/stats/MLStatsNodeRequestTests.java @@ -12,15 +12,15 @@ package org.opensearch.ml.action.stats; -import org.junit.Assert; -import org.junit.Test; -import org.opensearch.common.io.stream.BytesStreamOutput; - import java.io.IOException; import java.util.Arrays; import java.util.HashSet; import java.util.Set; +import org.junit.Assert; +import org.junit.Test; +import org.opensearch.common.io.stream.BytesStreamOutput; + public class MLStatsNodeRequestTests { @Test public void testSerializationDeserialization() throws IOException { @@ -37,10 +37,11 @@ public void testSerializationDeserialization() throws IOException { MLStatsNodeRequest request = new MLStatsNodeRequest(mlStatsNodesRequest); request.writeTo(output); MLStatsNodeRequest newRequest = new MLStatsNodeRequest(output.bytes().streamInput()); - Assert.assertEquals( + Assert + .assertEquals( newRequest.getMlStatsNodesRequest().getStatsToBeRetrieved().size(), request.getMlStatsNodesRequest().getStatsToBeRetrieved().size() - ); + ); for (String stat : newRequest.getMlStatsNodesRequest().getStatsToBeRetrieved()) { Assert.assertTrue(request.getMlStatsNodesRequest().getStatsToBeRetrieved().contains(stat)); } diff --git a/plugin/src/test/java/org/opensearch/ml/action/stats/MLStatsNodeResponseTests.java b/plugin/src/test/java/org/opensearch/ml/action/stats/MLStatsNodeResponseTests.java index cb9c8a9bb9..5ea091ab31 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/stats/MLStatsNodeResponseTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/stats/MLStatsNodeResponseTests.java @@ -12,17 +12,17 @@ package org.opensearch.ml.action.stats; -import org.junit.Assert; -import org.junit.Test; -import org.opensearch.Version; -import org.opensearch.cluster.node.DiscoveryNode; -import org.opensearch.common.io.stream.BytesStreamOutput; +import static org.opensearch.test.OpenSearchTestCase.buildNewFakeTransportAddress; import java.io.IOException; import java.util.HashMap; import java.util.Map; -import static org.opensearch.test.OpenSearchTestCase.buildNewFakeTransportAddress; +import org.junit.Assert; +import org.junit.Test; +import org.opensearch.Version; +import org.opensearch.cluster.node.DiscoveryNode; +import org.opensearch.common.io.stream.BytesStreamOutput; public class MLStatsNodeResponseTests { @Test diff --git a/plugin/src/test/java/org/opensearch/ml/action/stats/MLStatsNodesRequestTests.java b/plugin/src/test/java/org/opensearch/ml/action/stats/MLStatsNodesRequestTests.java index 6e80654ce1..e147ee3319 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/stats/MLStatsNodesRequestTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/stats/MLStatsNodesRequestTests.java @@ -12,15 +12,15 @@ package org.opensearch.ml.action.stats; -import org.junit.Assert; -import org.junit.Test; -import org.opensearch.common.io.stream.BytesStreamOutput; - import java.io.IOException; import java.util.Arrays; import java.util.HashSet; import java.util.Set; +import org.junit.Assert; +import org.junit.Test; +import org.opensearch.common.io.stream.BytesStreamOutput; + public class MLStatsNodesRequestTests { @Test public void testSerializationDeserialization() throws IOException { @@ -35,10 +35,7 @@ public void testSerializationDeserialization() throws IOException { BytesStreamOutput output = new BytesStreamOutput(); request.writeTo(output); MLStatsNodesRequest newRequest = new MLStatsNodesRequest(output.bytes().streamInput()); - Assert.assertEquals( - newRequest.getStatsToBeRetrieved().size(), - request.getStatsToBeRetrieved().size() - ); + Assert.assertEquals(newRequest.getStatsToBeRetrieved().size(), request.getStatsToBeRetrieved().size()); for (String stat : newRequest.getStatsToBeRetrieved()) { Assert.assertTrue(request.getStatsToBeRetrieved().contains(stat)); } diff --git a/plugin/src/test/java/org/opensearch/ml/action/stats/MLStatsNodesResponseTests.java b/plugin/src/test/java/org/opensearch/ml/action/stats/MLStatsNodesResponseTests.java index 6ba55baccb..72c1788ae5 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/stats/MLStatsNodesResponseTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/stats/MLStatsNodesResponseTests.java @@ -12,16 +12,16 @@ package org.opensearch.ml.action.stats; +import java.io.IOException; +import java.util.ArrayList; +import java.util.List; + import org.junit.Assert; import org.junit.Test; import org.opensearch.action.FailedNodeException; import org.opensearch.cluster.ClusterName; import org.opensearch.common.io.stream.BytesStreamOutput; -import java.io.IOException; -import java.util.ArrayList; -import java.util.List; - public class MLStatsNodesResponseTests { @Test public void testSerializationDeserialization() throws IOException { @@ -32,9 +32,6 @@ public void testSerializationDeserialization() throws IOException { BytesStreamOutput output = new BytesStreamOutput(); response.writeTo(output); MLStatsNodesResponse newResponse = new MLStatsNodesResponse(output.bytes().streamInput()); - Assert.assertEquals( - newResponse.getNodes().size(), - response.getNodes().size() - ); + Assert.assertEquals(newResponse.getNodes().size(), response.getNodes().size()); } } diff --git a/plugin/src/test/java/org/opensearch/ml/action/stats/MLStatsNodesTransportActionTests.java b/plugin/src/test/java/org/opensearch/ml/action/stats/MLStatsNodesTransportActionTests.java index 6c739b6c18..85d8ac188d 100644 --- a/plugin/src/test/java/org/opensearch/ml/action/stats/MLStatsNodesTransportActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/action/stats/MLStatsNodesTransportActionTests.java @@ -15,7 +15,18 @@ import static org.mockito.Mockito.mock; import static org.mockito.Mockito.when; +import java.io.IOException; +import java.util.Arrays; +import java.util.HashMap; +import java.util.HashSet; +import java.util.Map; +import java.util.Set; + +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; import org.opensearch.Version; +import org.opensearch.action.support.ActionFilters; import org.opensearch.cluster.node.DiscoveryNode; import org.opensearch.common.io.stream.BytesStreamOutput; import org.opensearch.common.io.stream.StreamInput; @@ -24,21 +35,10 @@ import org.opensearch.ml.stats.MLStats; import org.opensearch.ml.stats.suppliers.CounterSupplier; import org.opensearch.ml.stats.suppliers.SettableSupplier; -import org.opensearch.action.support.ActionFilters; import org.opensearch.monitor.jvm.JvmService; import org.opensearch.monitor.jvm.JvmStats; import org.opensearch.test.OpenSearchIntegTestCase; import org.opensearch.transport.TransportService; -import org.junit.Assert; -import org.junit.Before; -import org.junit.Test; - -import java.io.IOException; -import java.util.Arrays; -import java.util.HashMap; -import java.util.HashSet; -import java.util.Map; -import java.util.Set; public class MLStatsNodesTransportActionTests extends OpenSearchIntegTestCase { private MLStatsNodesTransportAction action; @@ -73,12 +73,12 @@ public void setUp() throws Exception { when(mem.getHeapUsedPercent()).thenReturn(randomShort()); action = new MLStatsNodesTransportAction( - client().threadPool(), - clusterService(), - mock(TransportService.class), - mock(ActionFilters.class), - mlStats, - jvmService + client().threadPool(), + clusterService(), + mock(TransportService.class), + mock(ActionFilters.class), + mlStats, + jvmService ); } diff --git a/plugin/src/test/java/org/opensearch/ml/indices/MLIndicesHandlerTests.java b/plugin/src/test/java/org/opensearch/ml/indices/MLIndicesHandlerTests.java index 92746c26d0..b195cdf970 100644 --- a/plugin/src/test/java/org/opensearch/ml/indices/MLIndicesHandlerTests.java +++ b/plugin/src/test/java/org/opensearch/ml/indices/MLIndicesHandlerTests.java @@ -23,6 +23,7 @@ public class MLIndicesHandlerTests extends OpenSearchIntegTestCase { ClusterService clusterService; Client client; MLIndicesHandler mlIndicesHandler; + @Before public void setup() { clusterService = clusterService(); diff --git a/plugin/src/test/java/org/opensearch/ml/indices/MLInputDatasetHandlerTests.java b/plugin/src/test/java/org/opensearch/ml/indices/MLInputDatasetHandlerTests.java index 02efa24d18..5a9d038804 100644 --- a/plugin/src/test/java/org/opensearch/ml/indices/MLInputDatasetHandlerTests.java +++ b/plugin/src/test/java/org/opensearch/ml/indices/MLInputDatasetHandlerTests.java @@ -12,12 +12,28 @@ package org.opensearch.ml.indices; +import static org.mockito.Matchers.any; +import static org.mockito.Mockito.doAnswer; +import static org.mockito.Mockito.mock; +import static org.mockito.Mockito.spy; +import static org.mockito.Mockito.times; +import static org.mockito.Mockito.verify; +import static org.mockito.Mockito.when; + +import java.util.ArrayList; +import java.util.Arrays; +import java.util.Collections; +import java.util.HashMap; +import java.util.List; +import java.util.Map; + import org.apache.lucene.search.TotalHits; import org.junit.Assert; import org.junit.Before; import org.junit.Rule; import org.junit.Test; import org.junit.rules.ExpectedException; +import org.mockito.ArgumentCaptor; import org.opensearch.action.ActionListener; import org.opensearch.action.search.SearchResponse; import org.opensearch.client.Client; @@ -32,24 +48,7 @@ import org.opensearch.search.SearchHits; import org.opensearch.search.builder.SearchSourceBuilder; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.Collections; -import java.util.HashMap; -import java.util.List; -import java.util.Map; - -import static org.mockito.Matchers.any; -import static org.mockito.Mockito.doAnswer; -import static org.mockito.Mockito.mock; -import static org.mockito.Mockito.spy; -import static org.mockito.Mockito.times; -import static org.mockito.Mockito.verify; -import static org.mockito.Mockito.when; -import org.mockito.ArgumentCaptor; - - -public class MLInputDatasetHandlerTests{ +public class MLInputDatasetHandlerTests { Client client; MLInputDatasetHandler mlInputDatasetHandler; ActionListener listener; @@ -85,9 +84,7 @@ public void testDataFrameInputDataset() { put("key1", 2.0D); } })); - DataFrameInputDataset dataFrameInputDataset = DataFrameInputDataset.builder() - .dataFrame(testDataFrame) - .build(); + DataFrameInputDataset dataFrameInputDataset = DataFrameInputDataset.builder().dataFrame(testDataFrame).build(); DataFrame result = mlInputDatasetHandler.parseDataFrameInput(dataFrameInputDataset); Assert.assertEquals(testDataFrame, result); } @@ -96,33 +93,34 @@ public void testDataFrameInputDataset() { public void testDataFrameInputDatasetWrongType() { expectedEx.expect(IllegalArgumentException.class); expectedEx.expectMessage("Input dataset is not DATA_FRAME type."); - SearchQueryInputDataset searchQueryInputDataset = SearchQueryInputDataset.builder() - .indices(Arrays.asList("index1")) - .searchSourceBuilder(new SearchSourceBuilder().query(QueryBuilders.matchAllQuery())) - .build(); + SearchQueryInputDataset searchQueryInputDataset = SearchQueryInputDataset + .builder() + .indices(Arrays.asList("index1")) + .searchSourceBuilder(new SearchSourceBuilder().query(QueryBuilders.matchAllQuery())) + .build(); DataFrame result = mlInputDatasetHandler.parseDataFrameInput(searchQueryInputDataset); } - @Test @SuppressWarnings("unchecked") public void testSearchQueryInputDatasetWithHits() { searchResponse = mock(SearchResponse.class); BytesReference bytesArray = new BytesArray("{\"taskId\":\"111\"}"); - SearchHit hit = new SearchHit( 1 ); + SearchHit hit = new SearchHit(1); hit.sourceRef(bytesArray); - SearchHits hits = new SearchHits(new SearchHit[] {hit}, new TotalHits(1L, TotalHits.Relation.EQUAL_TO), 1f); + SearchHits hits = new SearchHits(new SearchHit[] { hit }, new TotalHits(1L, TotalHits.Relation.EQUAL_TO), 1f); when(searchResponse.getHits()).thenReturn(hits); doAnswer(invocation -> { - ActionListener listener = (ActionListener) invocation.getArguments() [1]; + ActionListener listener = (ActionListener) invocation.getArguments()[1]; listener.onResponse(searchResponse); return null; }).when(client).search(any(), any()); - SearchQueryInputDataset searchQueryInputDataset = SearchQueryInputDataset.builder() - .indices(Arrays.asList("index1")) - .searchSourceBuilder(new SearchSourceBuilder().query(QueryBuilders.matchAllQuery())) - .build(); + SearchQueryInputDataset searchQueryInputDataset = SearchQueryInputDataset + .builder() + .indices(Arrays.asList("index1")) + .searchSourceBuilder(new SearchSourceBuilder().query(QueryBuilders.matchAllQuery())) + .build(); mlInputDatasetHandler.parseSearchQueryInput(searchQueryInputDataset, listener); ArgumentCaptor captor = ArgumentCaptor.forClass(DataFrame.class); verify(listener, times(1)).onResponse(captor.capture()); @@ -136,15 +134,16 @@ public void testSearchQueryInputDatasetWithoutHits() { SearchHits hits = new SearchHits(new SearchHit[0], new TotalHits(1L, TotalHits.Relation.EQUAL_TO), 1f); when(searchResponse.getHits()).thenReturn(hits); doAnswer(invocation -> { - ActionListener listener = (ActionListener) invocation.getArguments() [1]; + ActionListener listener = (ActionListener) invocation.getArguments()[1]; listener.onResponse(searchResponse); return null; }).when(client).search(any(), any()); - SearchQueryInputDataset searchQueryInputDataset = SearchQueryInputDataset.builder() - .indices(Arrays.asList("index1")) - .searchSourceBuilder(new SearchSourceBuilder().query(QueryBuilders.matchAllQuery())) - .build(); + SearchQueryInputDataset searchQueryInputDataset = SearchQueryInputDataset + .builder() + .indices(Arrays.asList("index1")) + .searchSourceBuilder(new SearchSourceBuilder().query(QueryBuilders.matchAllQuery())) + .build(); mlInputDatasetHandler.parseSearchQueryInput(searchQueryInputDataset, listener); verify(listener, times(1)).onFailure(any()); } @@ -158,9 +157,7 @@ public void testSearchQueryInputDatasetWrongType() { put("key1", 2.0D); } })); - DataFrameInputDataset dataFrameInputDataset = DataFrameInputDataset.builder() - .dataFrame(testDataFrame) - .build(); + DataFrameInputDataset dataFrameInputDataset = DataFrameInputDataset.builder().dataFrame(testDataFrame).build(); mlInputDatasetHandler.parseSearchQueryInput(dataFrameInputDataset, listener); } diff --git a/plugin/src/test/java/org/opensearch/ml/model/MLTaskTests.java b/plugin/src/test/java/org/opensearch/ml/model/MLTaskTests.java index 96617a688d..81339ff904 100644 --- a/plugin/src/test/java/org/opensearch/ml/model/MLTaskTests.java +++ b/plugin/src/test/java/org/opensearch/ml/model/MLTaskTests.java @@ -12,26 +12,27 @@ package org.opensearch.ml.model; +import java.io.IOException; +import java.time.Instant; + import org.junit.Assert; import org.junit.Test; import org.opensearch.common.io.stream.BytesStreamOutput; -import java.io.IOException; -import java.time.Instant; - public class MLTaskTests { @Test public void testWriteTo() throws IOException { BytesStreamOutput output = new BytesStreamOutput(); Instant now = Instant.now(); - MLTask task1 = MLTask.builder() - .taskId("dummy taskId") - .taskType(MLTaskType.PREDICTION) - .modelId(null) - .createTime(now) - .state(MLTaskState.RUNNING) - .error(null) - .build(); + MLTask task1 = MLTask + .builder() + .taskId("dummy taskId") + .taskType(MLTaskType.PREDICTION) + .modelId(null) + .createTime(now) + .state(MLTaskState.RUNNING) + .error(null) + .build(); task1.writeTo(output); MLTask task2 = new MLTask(output.bytes().streamInput()); Assert.assertEquals(task1, task2); diff --git a/plugin/src/test/java/org/opensearch/ml/rest/RestStatsMLActionTests.java b/plugin/src/test/java/org/opensearch/ml/rest/RestStatsMLActionTests.java index 4a63a5cbe4..c3aa72ddf1 100644 --- a/plugin/src/test/java/org/opensearch/ml/rest/RestStatsMLActionTests.java +++ b/plugin/src/test/java/org/opensearch/ml/rest/RestStatsMLActionTests.java @@ -12,11 +12,18 @@ package org.opensearch.ml.rest; -import com.google.common.collect.ImmutableMap; +import java.util.ArrayList; +import java.util.Arrays; +import java.util.HashSet; +import java.util.List; +import java.util.Map; +import java.util.Optional; +import java.util.Set; + +import org.junit.Assert; import org.junit.Before; import org.junit.Rule; import org.junit.Test; -import org.junit.Assert; import org.junit.rules.ExpectedException; import org.opensearch.ml.action.stats.MLStatsNodesRequest; import org.opensearch.ml.plugin.MachineLearningPlugin; @@ -25,22 +32,14 @@ import org.opensearch.ml.stats.StatNames; import org.opensearch.ml.stats.suppliers.CounterSupplier; import org.opensearch.rest.RestRequest; - import org.opensearch.test.OpenSearchTestCase; import org.opensearch.test.rest.FakeRestRequest; -import java.util.ArrayList; -import java.util.Arrays; -import java.util.HashSet; -import java.util.List; -import java.util.Map; -import java.util.Optional; -import java.util.Set; - +import com.google.common.collect.ImmutableMap; public class RestStatsMLActionTests extends OpenSearchTestCase { @Rule - public ExpectedException thrown= ExpectedException.none(); + public ExpectedException thrown = ExpectedException.none(); RestStatsMLAction restAction; MLStats mlStats; @@ -48,24 +47,21 @@ public class RestStatsMLActionTests extends OpenSearchTestCase { @Before public void setup() { Map> statMap = ImmutableMap - .>builder() - .put(StatNames.ML_EXECUTING_TASK_COUNT.getName(), new MLStat<>(false, new CounterSupplier())) - .build(); + .>builder() + .put(StatNames.ML_EXECUTING_TASK_COUNT.getName(), new MLStat<>(false, new CounterSupplier())) + .build(); mlStats = new MLStats(statMap); restAction = new RestStatsMLAction(mlStats); } @Test public void testsplitCommaSeparatedParam() { - Map param = ImmutableMap - .builder() - .put("nodeId", "111,222") - .build(); + Map param = ImmutableMap.builder().put("nodeId", "111,222").build(); FakeRestRequest fakeRestRequest = new FakeRestRequest.Builder(xContentRegistry()) - .withMethod(RestRequest.Method.GET) - .withPath(MachineLearningPlugin.ML_BASE_URI + "/{nodeId}/stats/") - .withParams(param) - .build(); + .withMethod(RestRequest.Method.GET) + .withPath(MachineLearningPlugin.ML_BASE_URI + "/{nodeId}/stats/") + .withParams(param) + .build(); Optional nodeId = restAction.splitCommaSeparatedParam(fakeRestRequest, "nodeId"); String[] array = nodeId.get(); Assert.assertEquals(array[0], "111"); @@ -87,13 +83,13 @@ public void testStatsSetContainsAllStatsKey() { thrown.expect(IllegalArgumentException.class); thrown.expectMessage(MLStatsNodesRequest.ALL_STATS_KEY); FakeRestRequest fakeRestRequest = new FakeRestRequest.Builder(xContentRegistry()) - .withMethod(RestRequest.Method.GET) - .withPath(MachineLearningPlugin.ML_BASE_URI + "/{nodeId}/stats/") - .build(); + .withMethod(RestRequest.Method.GET) + .withPath(MachineLearningPlugin.ML_BASE_URI + "/{nodeId}/stats/") + .build(); Set validStats = new HashSet<>(); validStats.add("stat1"); validStats.add("stat2"); - List requestedStats = new ArrayList<>(Arrays.asList("stat1", "stat2",MLStatsNodesRequest.ALL_STATS_KEY)); + List requestedStats = new ArrayList<>(Arrays.asList("stat1", "stat2", MLStatsNodesRequest.ALL_STATS_KEY)); restAction.getStatsToBeRetrieved(fakeRestRequest, validStats, requestedStats); } @@ -102,28 +98,28 @@ public void testStatsSetContainsInvalidStats() { thrown.expect(IllegalArgumentException.class); thrown.expectMessage("unrecognized"); FakeRestRequest fakeRestRequest = new FakeRestRequest.Builder(xContentRegistry()) - .withMethod(RestRequest.Method.GET) - .withPath(MachineLearningPlugin.ML_BASE_URI + "/{nodeId}/stats/") - .build(); + .withMethod(RestRequest.Method.GET) + .withPath(MachineLearningPlugin.ML_BASE_URI + "/{nodeId}/stats/") + .build(); Set validStats = new HashSet<>(); validStats.add("stat1"); validStats.add("stat2"); - List requestedStats = new ArrayList<>(Arrays.asList("stat1", "stat2","invalidStat")); + List requestedStats = new ArrayList<>(Arrays.asList("stat1", "stat2", "invalidStat")); restAction.getStatsToBeRetrieved(fakeRestRequest, validStats, requestedStats); } @Test public void testGetRequestAllStats() { Map param = ImmutableMap - .builder() - .put("nodeId", "111,222") - .put("stat", MLStatsNodesRequest.ALL_STATS_KEY) - .build(); + .builder() + .put("nodeId", "111,222") + .put("stat", MLStatsNodesRequest.ALL_STATS_KEY) + .build(); FakeRestRequest fakeRestRequest = new FakeRestRequest.Builder(xContentRegistry()) - .withMethod(RestRequest.Method.GET) - .withPath(MachineLearningPlugin.ML_BASE_URI + "/{nodeId}/stats/{stat}") - .withParams(param) - .build(); + .withMethod(RestRequest.Method.GET) + .withPath(MachineLearningPlugin.ML_BASE_URI + "/{nodeId}/stats/{stat}") + .withParams(param) + .build(); MLStatsNodesRequest request = restAction.getRequest(fakeRestRequest); Assert.assertEquals(request.getStatsToBeRetrieved().size(), 1); Assert.assertTrue(request.getStatsToBeRetrieved().contains(StatNames.ML_EXECUTING_TASK_COUNT.getName())); @@ -131,15 +127,12 @@ public void testGetRequestAllStats() { @Test public void testGetRequestEmptyStats() { - Map param = ImmutableMap - .builder() - .put("nodeId", "111,222") - .build(); + Map param = ImmutableMap.builder().put("nodeId", "111,222").build(); FakeRestRequest fakeRestRequest = new FakeRestRequest.Builder(xContentRegistry()) - .withMethod(RestRequest.Method.GET) - .withPath(MachineLearningPlugin.ML_BASE_URI + "/{nodeId}/stats/") - .withParams(param) - .build(); + .withMethod(RestRequest.Method.GET) + .withPath(MachineLearningPlugin.ML_BASE_URI + "/{nodeId}/stats/") + .withParams(param) + .build(); MLStatsNodesRequest request = restAction.getRequest(fakeRestRequest); Assert.assertEquals(request.getStatsToBeRetrieved().size(), 1); Assert.assertTrue(request.getStatsToBeRetrieved().contains(StatNames.ML_EXECUTING_TASK_COUNT.getName())); @@ -148,15 +141,15 @@ public void testGetRequestEmptyStats() { @Test public void testGetRequestSpecifyStats() { Map param = ImmutableMap - .builder() - .put("nodeId", "111,222") - .put("stat", StatNames.ML_EXECUTING_TASK_COUNT.getName()) - .build(); + .builder() + .put("nodeId", "111,222") + .put("stat", StatNames.ML_EXECUTING_TASK_COUNT.getName()) + .build(); FakeRestRequest fakeRestRequest = new FakeRestRequest.Builder(xContentRegistry()) - .withMethod(RestRequest.Method.GET) - .withPath(MachineLearningPlugin.ML_BASE_URI + "/{nodeId}/stats/{stat}") - .withParams(param) - .build(); + .withMethod(RestRequest.Method.GET) + .withPath(MachineLearningPlugin.ML_BASE_URI + "/{nodeId}/stats/{stat}") + .withParams(param) + .build(); MLStatsNodesRequest request = restAction.getRequest(fakeRestRequest); Assert.assertEquals(request.getStatsToBeRetrieved().size(), 1); Assert.assertTrue(request.getStatsToBeRetrieved().contains(StatNames.ML_EXECUTING_TASK_COUNT.getName())); diff --git a/plugin/src/test/java/org/opensearch/ml/stats/MLStatTests.java b/plugin/src/test/java/org/opensearch/ml/stats/MLStatTests.java index 3b060d2406..f2f0d586dc 100644 --- a/plugin/src/test/java/org/opensearch/ml/stats/MLStatTests.java +++ b/plugin/src/test/java/org/opensearch/ml/stats/MLStatTests.java @@ -12,13 +12,13 @@ package org.opensearch.ml.stats; +import java.util.function.Supplier; + +import org.junit.Assert; +import org.junit.Test; import org.opensearch.ml.stats.suppliers.CounterSupplier; import org.opensearch.ml.stats.suppliers.SettableSupplier; import org.opensearch.test.OpenSearchTestCase; -import org.junit.Assert; -import org.junit.Test; - -import java.util.function.Supplier; public class MLStatTests extends OpenSearchTestCase { @Test @@ -32,9 +32,9 @@ public void testIsClusterLevel() { @Test public void testSetGetValue() { MLStat stat1 = new MLStat<>(false, new CounterSupplier()); - Assert.assertEquals("GetValue returns the incorrect value", 0L, (long)(stat1.getValue())); + Assert.assertEquals("GetValue returns the incorrect value", 0L, (long) (stat1.getValue())); stat1.setValue(1L); - Assert.assertEquals("GetValue returns the incorrect value", 0L, (long)(stat1.getValue())); + Assert.assertEquals("GetValue returns the incorrect value", 0L, (long) (stat1.getValue())); MLStat stat2 = new MLStat<>(false, new TestSupplier()); Assert.assertEquals("GetValue returns the incorrect value", "test", stat2.getValue()); @@ -42,9 +42,9 @@ public void testSetGetValue() { Assert.assertEquals("GetValue returns the incorrect value", "test", stat2.getValue()); MLStat stat3 = new MLStat<>(false, new SettableSupplier()); - Assert.assertEquals("GetValue returns the incorrect value", 0L, (long)stat3.getValue()); + Assert.assertEquals("GetValue returns the incorrect value", 0L, (long) stat3.getValue()); stat3.setValue(1L); - Assert.assertEquals("GetValue returns the incorrect value", 1L, (long)stat3.getValue()); + Assert.assertEquals("GetValue returns the incorrect value", 1L, (long) stat3.getValue()); } @Test diff --git a/plugin/src/test/java/org/opensearch/ml/stats/MLStatsTests.java b/plugin/src/test/java/org/opensearch/ml/stats/MLStatsTests.java index 6356f8583e..6c6a1f445c 100644 --- a/plugin/src/test/java/org/opensearch/ml/stats/MLStatsTests.java +++ b/plugin/src/test/java/org/opensearch/ml/stats/MLStatsTests.java @@ -12,17 +12,17 @@ package org.opensearch.ml.stats; -import org.opensearch.ml.stats.suppliers.CounterSupplier; -import org.opensearch.test.OpenSearchTestCase; -import org.junit.Assert; -import org.junit.Before; -import org.junit.Test; - import java.util.HashMap; import java.util.HashSet; import java.util.Map; import java.util.Set; +import org.junit.Assert; +import org.junit.Before; +import org.junit.Test; +import org.opensearch.ml.stats.suppliers.CounterSupplier; +import org.opensearch.test.OpenSearchTestCase; + public class MLStatsTests extends OpenSearchTestCase { private Map> statsMap; private MLStats mlStats; @@ -47,21 +47,26 @@ public void setup() { @Test public void testStatNamesGetNames() { - Assert.assertEquals("getNames of StatNames returns the incorrect number of stats", - StatNames.getNames().size(), StatNames.values().length); + Assert + .assertEquals( + "getNames of StatNames returns the incorrect number of stats", + StatNames.getNames().size(), + StatNames.values().length + ); } @Test public void testGetStats() { Map> stats = mlStats.getStats(); - Assert.assertEquals("getStats returns the incorrect number of stats", - stats.size(), statsMap.size()); + Assert.assertEquals("getStats returns the incorrect number of stats", stats.size(), statsMap.size()); for (Map.Entry> stat : stats.entrySet()) { - Assert.assertTrue("getStats returns incorrect stats", - mlStats.getStats().containsKey(stat.getKey()) && - mlStats.getStats().get(stat.getKey()) == stat.getValue()); + Assert + .assertTrue( + "getStats returns incorrect stats", + mlStats.getStats().containsKey(stat.getKey()) && mlStats.getStats().get(stat.getKey()) == stat.getValue() + ); } } @@ -69,9 +74,11 @@ public void testGetStats() { public void testGetStat() { MLStat stat = mlStats.getStat(clusterStatName1); - Assert.assertTrue("getStat returns incorrect stat", - mlStats.getStats().containsKey(clusterStatName1) && - mlStats.getStats().get(clusterStatName1) == stat); + Assert + .assertTrue( + "getStat returns incorrect stat", + mlStats.getStats().containsKey(clusterStatName1) && mlStats.getStats().get(clusterStatName1) == stat + ); } @Test(expected = IllegalArgumentException.class) @@ -85,9 +92,11 @@ public void testGetNodeStats() { Set> nodeStats = new HashSet<>(mlStats.getNodeStats().values()); for (MLStat stat : stats.values()) { - Assert.assertTrue("getNodeStats returns incorrect stat", - (stat.isClusterLevel() && !nodeStats.contains(stat)) || - (!stat.isClusterLevel() && nodeStats.contains(stat))); + Assert + .assertTrue( + "getNodeStats returns incorrect stat", + (stat.isClusterLevel() && !nodeStats.contains(stat)) || (!stat.isClusterLevel() && nodeStats.contains(stat)) + ); } } @@ -97,9 +106,11 @@ public void testGetClusterStats() { Set> clusterStats = new HashSet<>(mlStats.getClusterStats().values()); for (MLStat stat : stats.values()) { - Assert.assertTrue("getClusterStats returns incorrect stat", - (stat.isClusterLevel() && clusterStats.contains(stat)) || - (!stat.isClusterLevel() && !clusterStats.contains(stat))); + Assert + .assertTrue( + "getClusterStats returns incorrect stat", + (stat.isClusterLevel() && clusterStats.contains(stat)) || (!stat.isClusterLevel() && !clusterStats.contains(stat)) + ); } } } diff --git a/plugin/src/test/java/org/opensearch/ml/stats/suppliers/CounterSupplierTests.java b/plugin/src/test/java/org/opensearch/ml/stats/suppliers/CounterSupplierTests.java index f22800c78a..9e8004f2b4 100644 --- a/plugin/src/test/java/org/opensearch/ml/stats/suppliers/CounterSupplierTests.java +++ b/plugin/src/test/java/org/opensearch/ml/stats/suppliers/CounterSupplierTests.java @@ -12,8 +12,8 @@ package org.opensearch.ml.stats.suppliers; -import org.opensearch.test.OpenSearchTestCase; import org.junit.Test; +import org.opensearch.test.OpenSearchTestCase; public class CounterSupplierTests extends OpenSearchTestCase { @Test diff --git a/plugin/src/test/java/org/opensearch/ml/stats/suppliers/SettableSupplierTests.java b/plugin/src/test/java/org/opensearch/ml/stats/suppliers/SettableSupplierTests.java index f62928a662..485f4887ad 100644 --- a/plugin/src/test/java/org/opensearch/ml/stats/suppliers/SettableSupplierTests.java +++ b/plugin/src/test/java/org/opensearch/ml/stats/suppliers/SettableSupplierTests.java @@ -12,8 +12,8 @@ package org.opensearch.ml.stats.suppliers; -import org.opensearch.test.OpenSearchTestCase; import org.junit.Test; +import org.opensearch.test.OpenSearchTestCase; public class SettableSupplierTests extends OpenSearchTestCase { @Test diff --git a/plugin/src/test/java/org/opensearch/ml/task/MLTaskManagerTests.java b/plugin/src/test/java/org/opensearch/ml/task/MLTaskManagerTests.java index 27aef0417a..c0df4ba322 100644 --- a/plugin/src/test/java/org/opensearch/ml/task/MLTaskManagerTests.java +++ b/plugin/src/test/java/org/opensearch/ml/task/MLTaskManagerTests.java @@ -12,6 +12,8 @@ package org.opensearch.ml.task; +import java.time.Instant; + import org.junit.Assert; import org.junit.Before; import org.junit.Rule; @@ -21,8 +23,6 @@ import org.opensearch.ml.model.MLTaskState; import org.opensearch.ml.model.MLTaskType; -import java.time.Instant; - public class MLTaskManagerTests { MLTaskManager mlTaskManager; MLTask mlTask; @@ -33,12 +33,13 @@ public class MLTaskManagerTests { @Before public void setup() { this.mlTaskManager = new MLTaskManager(); - this.mlTask = MLTask.builder() - .taskId("task id") - .taskType(MLTaskType.PREDICTION) - .createTime(Instant.now()) - .state(MLTaskState.CREATED) - .build(); + this.mlTask = MLTask + .builder() + .taskId("task id") + .taskType(MLTaskType.PREDICTION) + .createTime(Instant.now()) + .state(MLTaskState.CREATED) + .build(); } @Test @@ -80,22 +81,10 @@ public void testRemove() { @Test public void testGetRunningTaskCount() { - MLTask task1 = MLTask.builder() - .taskId("1") - .state(MLTaskState.CREATED) - .build(); - MLTask task2 = MLTask.builder() - .taskId("2") - .state(MLTaskState.RUNNING) - .build(); - MLTask task3 = MLTask.builder() - .taskId("3") - .state(MLTaskState.FAILED) - .build(); - MLTask task4 = MLTask.builder() - .taskId("4") - .state(MLTaskState.COMPLETED) - .build(); + MLTask task1 = MLTask.builder().taskId("1").state(MLTaskState.CREATED).build(); + MLTask task2 = MLTask.builder().taskId("2").state(MLTaskState.RUNNING).build(); + MLTask task3 = MLTask.builder().taskId("3").state(MLTaskState.FAILED).build(); + MLTask task4 = MLTask.builder().taskId("4").state(MLTaskState.COMPLETED).build(); mlTaskManager.add(task1); mlTaskManager.add(task2); mlTaskManager.add(task3); @@ -105,22 +94,10 @@ public void testGetRunningTaskCount() { @Test public void testClear() { - MLTask task1 = MLTask.builder() - .taskId("1") - .state(MLTaskState.CREATED) - .build(); - MLTask task2 = MLTask.builder() - .taskId("2") - .state(MLTaskState.RUNNING) - .build(); - MLTask task3 = MLTask.builder() - .taskId("3") - .state(MLTaskState.FAILED) - .build(); - MLTask task4 = MLTask.builder() - .taskId("4") - .state(MLTaskState.COMPLETED) - .build(); + MLTask task1 = MLTask.builder().taskId("1").state(MLTaskState.CREATED).build(); + MLTask task2 = MLTask.builder().taskId("2").state(MLTaskState.RUNNING).build(); + MLTask task3 = MLTask.builder().taskId("3").state(MLTaskState.FAILED).build(); + MLTask task4 = MLTask.builder().taskId("4").state(MLTaskState.COMPLETED).build(); mlTaskManager.add(task1); mlTaskManager.add(task2); mlTaskManager.add(task3); diff --git a/plugin/src/test/java/org/opensearch/ml/utils/MLNodeUtilsTests.java b/plugin/src/test/java/org/opensearch/ml/utils/MLNodeUtilsTests.java index afffd4b5b0..88b8b96474 100644 --- a/plugin/src/test/java/org/opensearch/ml/utils/MLNodeUtilsTests.java +++ b/plugin/src/test/java/org/opensearch/ml/utils/MLNodeUtilsTests.java @@ -12,33 +12,30 @@ package org.opensearch.ml.utils; +import static java.util.Collections.emptyMap; + +import java.util.HashSet; +import java.util.Set; + import org.junit.Assert; import org.junit.Test; import org.opensearch.Version; import org.opensearch.cluster.node.DiscoveryNode; import org.opensearch.cluster.node.DiscoveryNodeRole; -import org.opensearch.common.settings.Setting; import org.opensearch.ml.plugin.MachineLearningPlugin; import org.opensearch.test.OpenSearchTestCase; -import java.util.HashSet; -import java.util.Set; - -import static java.util.Collections.emptyMap; - public class MLNodeUtilsTests extends OpenSearchTestCase { @Test public void testIsMLNode() { Set roleSet = new HashSet<>(); roleSet.add(DiscoveryNodeRole.DATA_ROLE); roleSet.add(DiscoveryNodeRole.INGEST_ROLE); - DiscoveryNode normalNode = new DiscoveryNode("Normal node", buildNewFakeTransportAddress(), emptyMap(), roleSet, - Version.CURRENT); + DiscoveryNode normalNode = new DiscoveryNode("Normal node", buildNewFakeTransportAddress(), emptyMap(), roleSet, Version.CURRENT); Assert.assertFalse(MLNodeUtils.isMLNode(normalNode)); roleSet.add(MachineLearningPlugin.ML_ROLE); - DiscoveryNode mlNode = new DiscoveryNode("ML node", buildNewFakeTransportAddress(), emptyMap(), roleSet, - Version.CURRENT); + DiscoveryNode mlNode = new DiscoveryNode("ML node", buildNewFakeTransportAddress(), emptyMap(), roleSet, Version.CURRENT); Assert.assertTrue(MLNodeUtils.isMLNode(mlNode)); } }