diff --git a/build-tools/repositories.gradle b/build-tools/repositories.gradle index 8993a5d367..e72fc8a8cc 100644 --- a/build-tools/repositories.gradle +++ b/build-tools/repositories.gradle @@ -7,5 +7,6 @@ repositories { mavenLocal() maven { url "https://aws.oss.sonatype.org/content/repositories/snapshots" } mavenCentral() + maven {url 'https://oss.sonatype.org/content/repositories/snapshots/'} maven { url "https://d1nvenhzbhpy0q.cloudfront.net/snapshots/lucene/" } } \ No newline at end of file diff --git a/client/src/main/java/org/opensearch/ml/client/MachineLearningNodeClient.java b/client/src/main/java/org/opensearch/ml/client/MachineLearningNodeClient.java index f08abb85aa..3538797478 100644 --- a/client/src/main/java/org/opensearch/ml/client/MachineLearningNodeClient.java +++ b/client/src/main/java/org/opensearch/ml/client/MachineLearningNodeClient.java @@ -13,7 +13,7 @@ import org.opensearch.action.delete.DeleteResponse; import org.opensearch.action.search.SearchRequest; import org.opensearch.action.search.SearchResponse; -import org.opensearch.client.node.NodeClient; +import org.opensearch.client.Client; import org.opensearch.ml.common.input.MLInput; import org.opensearch.ml.common.MLModel; import org.opensearch.ml.common.output.MLOutput; @@ -38,17 +38,16 @@ @RequiredArgsConstructor public class MachineLearningNodeClient implements MachineLearningClient { - NodeClient client; + Client client; @Override public void predict(String modelId, MLInput mlInput, ActionListener listener) { validateMLInput(mlInput, true); MLPredictionTaskRequest predictionRequest = MLPredictionTaskRequest.builder() - .mlInput(mlInput) - .modelId(modelId) - .build(); - + .mlInput(mlInput) + .modelId(modelId) + .build(); client.execute(MLPredictionTaskAction.INSTANCE, predictionRequest, getMlPredictionTaskResponseActionListener(listener)); } @@ -80,9 +79,18 @@ public void getModel(String modelId, ActionListener listener) { .modelId(modelId) .build(); - 
client.execute(MLModelGetAction.INSTANCE, mlModelGetRequest, ActionListener.wrap(response -> { - listener.onResponse(MLModelGetResponse.fromActionResponse(response).getMlModel()); - }, listener::onFailure)); + client.execute(MLModelGetAction.INSTANCE, mlModelGetRequest, getMlGetModelResponseActionListener(listener)); + } + + private ActionListener getMlGetModelResponseActionListener(ActionListener listener) { + ActionListener internalListener = ActionListener.wrap(predictionResponse -> { + listener.onResponse(predictionResponse.getMlModel()); + }, listener::onFailure); + ActionListener actionListener = wrapActionListener(internalListener, res -> { + MLModelGetResponse getResponse = MLModelGetResponse.fromActionResponse(res); + return getResponse; + }); + return actionListener; } @Override diff --git a/common/src/main/java/org/opensearch/ml/common/MLModel.java b/common/src/main/java/org/opensearch/ml/common/MLModel.java index 5ecf2afa2b..0549a5b43d 100644 --- a/common/src/main/java/org/opensearch/ml/common/MLModel.java +++ b/common/src/main/java/org/opensearch/ml/common/MLModel.java @@ -95,13 +95,6 @@ public MLModel(String name, FunctionName algorithm, Integer version, String cont this.totalChunks = totalChunks; } - public MLModel(FunctionName algorithm, Model model) { - this.name = model.getName(); - this.algorithm = algorithm; - this.version = model.getVersion(); - this.content = Base64.getEncoder().encodeToString(model.getContent()); - } - public MLModel(StreamInput input) throws IOException{ name = input.readOptionalString(); algorithm = input.readEnum(FunctionName.class); diff --git a/common/src/main/java/org/opensearch/ml/common/exception/MLException.java b/common/src/main/java/org/opensearch/ml/common/exception/MLException.java index 8994678401..fea47dc826 100644 --- a/common/src/main/java/org/opensearch/ml/common/exception/MLException.java +++ b/common/src/main/java/org/opensearch/ml/common/exception/MLException.java @@ -33,7 +33,7 @@ public 
MLException(Throwable cause) { } /** - * Constructor with specified error message adn cause. + * Constructor with specified error message and cause. * @param message error message * @param cause exception cause */ diff --git a/common/src/main/java/org/opensearch/ml/common/input/MLInput.java b/common/src/main/java/org/opensearch/ml/common/input/MLInput.java index 47437f91eb..a58fa09f97 100644 --- a/common/src/main/java/org/opensearch/ml/common/input/MLInput.java +++ b/common/src/main/java/org/opensearch/ml/common/input/MLInput.java @@ -42,11 +42,17 @@ public class MLInput implements Input { public static final String INPUT_INDEX_FIELD = "input_index"; public static final String INPUT_QUERY_FIELD = "input_query"; public static final String INPUT_DATA_FIELD = "input_data"; + // For trained model + // Return bytes in model output public static final String RETURN_BYTES_FIELD = "return_bytes"; + // Return number in model output. This can be used together with return_bytes. public static final String RETURN_NUMBER_FIELD = "return_number"; + // Filter target response with name in model output public static final String TARGET_RESPONSE_FIELD = "target_response"; + // Filter target response with position in model output public static final String TARGET_RESPONSE_POSITIONS_FIELD = "target_response_positions"; + // Input text sentences for text embedding model public static final String TEXT_DOCS_FIELD = "text_docs"; // Algorithm name diff --git a/common/src/main/java/org/opensearch/ml/common/output/model/ModelTensors.java b/common/src/main/java/org/opensearch/ml/common/output/model/ModelTensors.java index 06899dd05f..6e57f46b6c 100644 --- a/common/src/main/java/org/opensearch/ml/common/output/model/ModelTensors.java +++ b/common/src/main/java/org/opensearch/ml/common/output/model/ModelTensors.java @@ -66,12 +66,12 @@ public void writeTo(StreamOutput out) throws IOException { public void filter(ModelResultFilter resultFilter) { boolean returnBytes = 
resultFilter.isReturnBytes(); - boolean returnNUmber = resultFilter.isReturnNumber(); + boolean returnNumber = resultFilter.isReturnNumber(); List targetResponse = resultFilter.getTargetResponse(); List targetResponsePositions = resultFilter.getTargetResponsePositions(); if ((targetResponse == null || targetResponse.size() == 0) && (targetResponsePositions == null || targetResponsePositions.size() == 0)) { - mlModelTensors.forEach(output -> filter(output, returnBytes, returnNUmber)); + mlModelTensors.forEach(output -> filter(output, returnBytes, returnNumber)); return; } List targetOutput = new ArrayList<>(); @@ -79,10 +79,10 @@ public void filter(ModelResultFilter resultFilter) { for (int i = 0 ; i params) { + Predictable predictable = MLEngineClassLoader.initInstance(mlModel.getAlgorithm(), null, MLAlgoParams.class); + predictable.initModel(mlModel, params); + return predictable; } - public static MLOutput predict(Input input, Model model) { + public static MLOutput predict(Input input, MLModel model) { validateMLInput(input); MLInput mlInput = (MLInput) input; Predictable predictable = MLEngineClassLoader.initInstance(mlInput.getAlgorithm(), mlInput.getParameters(), MLAlgoParams.class); if (predictable == null) { throw new IllegalArgumentException("Unsupported algorithm: " + mlInput.getAlgorithm()); } - return predictable.predict(mlInput.getDataFrame(), model); + return predictable.predict(mlInput.getInputDataset(), model); } public static MLOutput trainAndPredict(Input input) { @@ -45,7 +55,7 @@ public static MLOutput trainAndPredict(Input input) { if (trainAndPredictable == null) { throw new IllegalArgumentException("Unsupported algorithm: " + mlInput.getAlgorithm()); } - return trainAndPredictable.trainAndPredict(mlInput.getDataFrame()); + return trainAndPredictable.trainAndPredict(mlInput.getInputDataset()); } public static Output execute(Input input) { @@ -63,9 +73,15 @@ private static void validateMLInput(Input input) { throw new 
IllegalArgumentException("Input should be MLInput"); } MLInput mlInput = (MLInput) input; - DataFrame dataFrame = mlInput.getDataFrame(); - if (dataFrame == null || dataFrame.size() == 0) { - throw new IllegalArgumentException("Input data frame should not be null or empty"); + MLInputDataset inputDataset = mlInput.getInputDataset(); + if (inputDataset == null) { + throw new IllegalArgumentException("Input data set should not be null"); + } + if (inputDataset instanceof DataFrameInputDataset) { + DataFrame dataFrame = ((DataFrameInputDataset)inputDataset).getDataFrame(); + if (dataFrame == null || dataFrame.size() == 0) { + throw new IllegalArgumentException("Input data frame should not be null or empty"); + } } } diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/Predictable.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/Predictable.java index a40b791adc..df19bf5823 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/Predictable.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/Predictable.java @@ -5,9 +5,11 @@ package org.opensearch.ml.engine; -import org.opensearch.ml.common.dataframe.DataFrame; +import org.opensearch.ml.common.MLModel; +import org.opensearch.ml.common.dataset.MLInputDataset; import org.opensearch.ml.common.output.MLOutput; -import org.opensearch.ml.common.Model; + +import java.util.Map; /** * This is machine learning algorithms predict interface. @@ -15,11 +17,30 @@ public interface Predictable { /** - * Predict with given features and model (optional). - * @param dataFrame features data + * Predict with given input data and model (optional). + * Will reload model into memory with model content. + * @param inputDataset input data set * @param model the java serialized model * @return predicted results */ - MLOutput predict(DataFrame dataFrame, Model model); + MLOutput predict(MLInputDataset inputDataset, MLModel model); + + /** + * Predict with given input data with loaded model. 
+ * @param inputDataset input data set + * @return predicted results + */ + MLOutput predict(MLInputDataset inputDataset); + /** + * Init model (load model into memory) with ML model content and params. + * @param model ML model + * @param params other parameters + */ + void initModel(MLModel model, Map params); + + /** + * Close resources like loaded model. + */ + void close(); } diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/TrainAndPredictable.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/TrainAndPredictable.java index d67d685b2e..f6d761bb9b 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/TrainAndPredictable.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/TrainAndPredictable.java @@ -5,20 +5,20 @@ package org.opensearch.ml.engine; -import org.opensearch.ml.common.dataframe.DataFrame; +import org.opensearch.ml.common.dataset.MLInputDataset; import org.opensearch.ml.common.output.MLOutput; /** - * This is machine learning algorithms train interface. + * This is machine learning algorithms train and predict interface. */ public interface TrainAndPredictable extends Trainable, Predictable { /** - * Train model with given features. Then predict with the same data. - * @param dataFrame training data - * @return the java serialized model + * Train model with given input data. Then predict with the same data. 
+ * @param inputDataset training data + * @return predicted results */ - MLOutput trainAndPredict(DataFrame dataFrame); + MLOutput trainAndPredict(MLInputDataset inputDataset); } diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/Trainable.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/Trainable.java index 755b365144..397b64c19e 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/Trainable.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/Trainable.java @@ -5,8 +5,8 @@ package org.opensearch.ml.engine; -import org.opensearch.ml.common.dataframe.DataFrame; -import org.opensearch.ml.common.Model; +import org.opensearch.ml.common.MLModel; +import org.opensearch.ml.common.dataset.MLInputDataset; /** * This is machine learning algorithms train interface. @@ -15,9 +15,9 @@ public interface Trainable { /** * Train model with given features. - * @param dataFrame training data - * @return the java serialized model + * @param inputDataset training data + * @return ML model with serialized model content */ - Model train(DataFrame dataFrame); + MLModel train(MLInputDataset inputDataset); } diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/ad/AnomalyDetectionLibSVM.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/ad/AnomalyDetectionLibSVM.java index d145c5d1aa..2c6c6c44be 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/ad/AnomalyDetectionLibSVM.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/ad/AnomalyDetectionLibSVM.java @@ -5,14 +5,16 @@ package org.opensearch.ml.engine.algorithms.ad; +import org.opensearch.ml.common.MLModel; import org.opensearch.ml.common.dataframe.DataFrame; import org.opensearch.ml.common.dataframe.DataFrameBuilder; import org.opensearch.ml.common.FunctionName; +import org.opensearch.ml.common.dataset.DataFrameInputDataset; +import 
org.opensearch.ml.common.dataset.MLInputDataset; import org.opensearch.ml.common.input.parameter.ad.AnomalyDetectionLibSVMParams; import org.opensearch.ml.common.input.parameter.MLAlgoParams; import org.opensearch.ml.common.output.MLOutput; import org.opensearch.ml.common.output.MLPredictionOutput; -import org.opensearch.ml.common.Model; import org.opensearch.ml.engine.Predictable; import org.opensearch.ml.engine.Trainable; import org.opensearch.ml.engine.annotation.Function; @@ -48,6 +50,7 @@ public class AnomalyDetectionLibSVM implements Trainable, Predictable { private static KernelType DEFAULT_KERNEL_TYPE = KernelType.RBF; private AnomalyDetectionLibSVMParams parameters; + private LibSVMModel libSVMAnomalyModel = null; public AnomalyDetectionLibSVM() {} @@ -69,15 +72,24 @@ private void validateParameters() { } @Override - public MLOutput predict(DataFrame dataFrame, Model model) { - if (model == null) { - throw new IllegalArgumentException("No model found for KMeans prediction."); - } + public void initModel(MLModel model, Map params) { + this.libSVMAnomalyModel = (LibSVMModel) ModelSerDeSer.deserialize(model); + } + @Override + public void close() { + this.libSVMAnomalyModel = null; + } + + @Override + public MLOutput predict(MLInputDataset inputDataset) { + DataFrame dataFrame = ((DataFrameInputDataset)inputDataset).getDataFrame(); + if (libSVMAnomalyModel == null) { + throw new IllegalArgumentException("model not loaded"); + } List> predictions; MutableDataset predictionDataset = TribuoUtil.generateDataset(dataFrame, new AnomalyFactory(), "Anomaly detection LibSVM prediction data from OpenSearch", TribuoOutputType.ANOMALY_DETECTION_LIBSVM); - LibSVMModel libSVMAnomalyModel = (LibSVMModel) ModelSerDeSer.deserialize(model.getContent()); predictions = libSVMAnomalyModel.predict(predictionDataset); List> adResults = new ArrayList<>(); @@ -92,7 +104,18 @@ public MLOutput predict(DataFrame dataFrame, Model model) { } @Override - public Model train(DataFrame 
dataFrame) { + public MLOutput predict(MLInputDataset inputDataset, MLModel model) { + if (model == null) { + throw new IllegalArgumentException("No model found for KMeans prediction."); + } + + libSVMAnomalyModel = (LibSVMModel) ModelSerDeSer.deserialize(model); + return predict(inputDataset); + } + + @Override + public MLModel train(MLInputDataset inputDataset) { + DataFrame dataFrame = ((DataFrameInputDataset)inputDataset).getDataFrame(); KernelType kernelType = parseKernelType(); SVMParameters params = new SVMParameters<>(new SVMAnomalyType(SVMAnomalyType.SVMMode.ONE_CLASS), kernelType); Double gamma = Optional.ofNullable(parameters.getGamma()).orElse(DEFAULT_GAMMA); @@ -118,10 +141,13 @@ public Model train(DataFrame dataFrame) { LibSVMModel libSVMModel = trainer.train(data); ((LibSVMAnomalyModel)libSVMModel).getNumberOfSupportVectors(); - Model model = new Model(); - model.setName(FunctionName.AD_LIBSVM.name()); - model.setVersion(VERSION); - model.setContent(ModelSerDeSer.serialize(libSVMModel)); + + MLModel model = MLModel.builder() + .name(FunctionName.AD_LIBSVM.name()) + .algorithm(FunctionName.AD_LIBSVM) + .version(VERSION) + .content(ModelSerDeSer.serializeToBase64(libSVMModel)) + .build(); return model; } diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/clustering/KMeans.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/clustering/KMeans.java index f789c4078f..9c7de99fa5 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/clustering/KMeans.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/clustering/KMeans.java @@ -5,14 +5,16 @@ package org.opensearch.ml.engine.algorithms.clustering; +import org.opensearch.ml.common.MLModel; import org.opensearch.ml.common.dataframe.DataFrame; import org.opensearch.ml.common.dataframe.DataFrameBuilder; +import org.opensearch.ml.common.dataset.DataFrameInputDataset; +import org.opensearch.ml.common.dataset.MLInputDataset; 
import org.opensearch.ml.common.input.parameter.clustering.KMeansParams; import org.opensearch.ml.common.FunctionName; import org.opensearch.ml.common.input.parameter.MLAlgoParams; import org.opensearch.ml.common.output.MLOutput; import org.opensearch.ml.common.output.MLPredictionOutput; -import org.opensearch.ml.common.Model; import org.opensearch.ml.engine.TrainAndPredictable; import org.opensearch.ml.engine.annotation.Function; import org.opensearch.ml.engine.utils.ModelSerDeSer; @@ -42,11 +44,12 @@ public class KMeans implements TrainAndPredictable { //The number of threads. private int numThreads = Math.max(Runtime.getRuntime().availableProcessors() / 2, 1); //Assume cpu-bound. - + //The random seed. private long seed = System.currentTimeMillis(); private KMeansTrainer.Distance distance; + private KMeansModel kMeansModel; public KMeans() {} public KMeans(MLAlgoParams parameters) { @@ -83,17 +86,21 @@ private void createDistance() { } @Override - public MLOutput predict(DataFrame dataFrame, Model model) { - if (model == null) { - throw new IllegalArgumentException("No model found for KMeans prediction."); - } + public void initModel(MLModel model, Map params) { + this.kMeansModel = (KMeansModel) ModelSerDeSer.deserialize(model); + } + + @Override + public void close() { + this.kMeansModel = null; + } - List> predictions; + @Override + public MLOutput predict(MLInputDataset inputDataset) { + DataFrame dataFrame = ((DataFrameInputDataset)inputDataset).getDataFrame(); MutableDataset predictionDataset = TribuoUtil.generateDataset(dataFrame, new ClusteringFactory(), "KMeans prediction data from opensearch", TribuoOutputType.CLUSTERID); - KMeansModel kMeansModel = (KMeansModel) ModelSerDeSer.deserialize(model.getContent()); - predictions = kMeansModel.predict(predictionDataset); - + List> predictions = kMeansModel.predict(predictionDataset); List> listClusterID = new ArrayList<>(); predictions.forEach(e -> listClusterID.add(Collections.singletonMap("ClusterID", 
e.getOutput().getID()))); @@ -101,23 +108,36 @@ public MLOutput predict(DataFrame dataFrame, Model model) { } @Override - public Model train(DataFrame dataFrame) { + public MLOutput predict(MLInputDataset inputDataset, MLModel model) { + if (model == null) { + throw new IllegalArgumentException("No model found for KMeans prediction."); + } + this.kMeansModel = (KMeansModel) ModelSerDeSer.deserialize(model); + return predict(inputDataset); + } + + @Override + public MLModel train(MLInputDataset inputDataset) { + DataFrame dataFrame = ((DataFrameInputDataset)inputDataset).getDataFrame(); MutableDataset trainDataset = TribuoUtil.generateDataset(dataFrame, new ClusteringFactory(), "KMeans training data from opensearch", TribuoOutputType.CLUSTERID); Integer centroids = Optional.ofNullable(parameters.getCentroids()).orElse(DEFAULT_CENTROIDS); Integer iterations = Optional.ofNullable(parameters.getIterations()).orElse(DEFAULT_ITERATIONS); KMeansTrainer trainer = new KMeansTrainer(centroids, iterations, distance, numThreads, seed); KMeansModel kMeansModel = trainer.train(trainDataset); - Model model = new Model(); - model.setName(FunctionName.KMEANS.name()); - model.setVersion(1); - model.setContent(ModelSerDeSer.serialize(kMeansModel)); + MLModel model = MLModel.builder() + .name(FunctionName.KMEANS.name()) + .algorithm(FunctionName.KMEANS) + .version(1) + .content(ModelSerDeSer.serializeToBase64(kMeansModel)) + .build(); return model; } @Override - public MLOutput trainAndPredict(DataFrame dataFrame) { + public MLOutput trainAndPredict(MLInputDataset inputDataset) { + DataFrame dataFrame = ((DataFrameInputDataset)inputDataset).getDataFrame(); MutableDataset trainDataset = TribuoUtil.generateDataset(dataFrame, new ClusteringFactory(), "KMeans training and predicting data from opensearch", TribuoOutputType.CLUSTERID); Integer centroids = Optional.ofNullable(parameters.getCentroids()).orElse(DEFAULT_CENTROIDS); diff --git 
a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/clustering/RCFSummarize.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/clustering/RCFSummarize.java index 43e28c777b..48810db810 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/clustering/RCFSummarize.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/clustering/RCFSummarize.java @@ -5,15 +5,17 @@ package org.opensearch.ml.engine.algorithms.clustering; +import org.opensearch.ml.common.MLModel; import org.opensearch.ml.common.dataframe.DataFrame; import org.opensearch.ml.common.dataframe.DataFrameBuilder; +import org.opensearch.ml.common.dataset.DataFrameInputDataset; +import org.opensearch.ml.common.dataset.MLInputDataset; import org.opensearch.ml.common.input.parameter.clustering.RCFSummarizeParams; import org.opensearch.common.collect.Tuple; import org.opensearch.ml.common.FunctionName; import org.opensearch.ml.common.input.parameter.MLAlgoParams; import org.opensearch.ml.common.output.MLOutput; import org.opensearch.ml.common.output.MLPredictionOutput; -import org.opensearch.ml.common.Model; import org.opensearch.ml.engine.TrainAndPredictable; import org.opensearch.ml.engine.annotation.Function; import org.opensearch.ml.engine.utils.MathUtil; @@ -21,7 +23,6 @@ import org.opensearch.ml.engine.utils.TribuoUtil; import com.amazon.randomcutforest.returntypes.SampleSummary; import com.amazon.randomcutforest.summarization.Summarizer; -import org.opensearch.ml.engine.algorithms.clustering.SerializableSummary; import java.util.ArrayList; import java.util.Arrays; @@ -43,6 +44,7 @@ public class RCFSummarize implements TrainAndPredictable { // Parameters private RCFSummarizeParams parameters; private BiFunction distance; + private SampleSummary summary; public RCFSummarize() {} @@ -109,32 +111,40 @@ private void createDistance() { } @Override - public Model train(DataFrame dataFrame) { + public MLModel train(MLInputDataset 
inputDataset) { + DataFrame dataFrame = ((DataFrameInputDataset)inputDataset).getDataFrame(); Tuple featureNamesValues = TribuoUtil.transformDataFrameFloat(dataFrame); - SampleSummary summary = Summarizer.summarize(featureNamesValues.v2(), - parameters.getMaxK(), - parameters.getInitialK(), - parameters.getPhase1Reassign(), - distance, - rnd.nextLong(), - parameters.getParallel()); - - Model model = new Model(); - model.setName(FunctionName.RCF_SUMMARIZE.name()); - model.setVersion(1); - model.setContent(ModelSerDeSer.serialize(new SerializableSummary(summary))); - + SampleSummary summary = Summarizer.summarize(featureNamesValues.v2(), + parameters.getMaxK(), + parameters.getInitialK(), + parameters.getPhase1Reassign(), + distance, + rnd.nextLong(), + parameters.getParallel()); + + MLModel model = MLModel.builder() + .name(FunctionName.RCF_SUMMARIZE.name()) + .algorithm(FunctionName.RCF_SUMMARIZE) + .version(1) + .content(ModelSerDeSer.serializeToBase64(new SerializableSummary(summary))) + .build(); return model; } @Override - public MLOutput predict(DataFrame dataFrame, Model model) { - if (model == null) { - throw new IllegalArgumentException("No model found for RCFSummarize prediction."); - } + public void initModel(MLModel model, Map params) { + this.summary = ((SerializableSummary)ModelSerDeSer.deserialize(model)).getSummary(); + } - SampleSummary summary = ((SerializableSummary)ModelSerDeSer.deserialize(model.getContent())).getSummary(); + @Override + public void close() { + this.summary = null; + } + + @Override + public MLOutput predict(MLInputDataset inputDataset) { Iterable centroidsLst = Arrays.asList(summary.summaryPoints); + DataFrame dataFrame = ((DataFrameInputDataset)inputDataset).getDataFrame(); Tuple featureNamesValues = TribuoUtil.transformDataFrameFloat(dataFrame); List predictions = new ArrayList<>(); Arrays.stream(featureNamesValues.v2()).forEach(e->predictions.add(MathUtil.findNearest(e, centroidsLst, distance))); @@ -146,16 +156,27 @@ public 
MLOutput predict(DataFrame dataFrame, Model model) { } @Override - public MLOutput trainAndPredict(DataFrame dataFrame) { + public MLOutput predict(MLInputDataset inputDataset, MLModel model) { + if (model == null) { + throw new IllegalArgumentException("No model found for RCFSummarize prediction."); + } + + summary = ((SerializableSummary)ModelSerDeSer.deserialize(model)).getSummary(); + return predict(inputDataset); + } + + @Override + public MLOutput trainAndPredict(MLInputDataset inputDataset) { + DataFrame dataFrame = ((DataFrameInputDataset)inputDataset).getDataFrame(); Tuple featureNamesValues = TribuoUtil.transformDataFrameFloat(dataFrame); - SampleSummary summary = Summarizer.summarize(featureNamesValues.v2(), - parameters.getMaxK(), - parameters.getInitialK(), - parameters.getPhase1Reassign(), - distance, - rnd.nextLong(), - parameters.getParallel()); - + SampleSummary summary = Summarizer.summarize(featureNamesValues.v2(), + parameters.getMaxK(), + parameters.getInitialK(), + parameters.getPhase1Reassign(), + distance, + rnd.nextLong(), + parameters.getParallel()); + Iterable centroidsLst = Arrays.asList(summary.summaryPoints); List predictions = new ArrayList<>(); Arrays.stream(featureNamesValues.v2()).forEach(e->predictions.add(MathUtil.findNearest(e, centroidsLst, distance))); diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/rcf/BatchRandomCutForest.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/rcf/BatchRandomCutForest.java index 4446e816fd..ba9e8f432c 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/rcf/BatchRandomCutForest.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/rcf/BatchRandomCutForest.java @@ -10,12 +10,14 @@ import com.amazon.randomcutforest.state.RandomCutForestState; import lombok.extern.log4j.Log4j2; import org.opensearch.ml.common.FunctionName; -import org.opensearch.ml.common.Model; +import org.opensearch.ml.common.MLModel; 
import org.opensearch.ml.common.dataframe.ColumnMeta; import org.opensearch.ml.common.dataframe.ColumnValue; import org.opensearch.ml.common.dataframe.DataFrame; import org.opensearch.ml.common.dataframe.DataFrameBuilder; import org.opensearch.ml.common.dataframe.Row; +import org.opensearch.ml.common.dataset.DataFrameInputDataset; +import org.opensearch.ml.common.dataset.MLInputDataset; import org.opensearch.ml.common.input.parameter.MLAlgoParams; import org.opensearch.ml.common.input.parameter.rcf.BatchRCFParams; import org.opensearch.ml.common.output.MLOutput; @@ -29,6 +31,9 @@ import java.util.Map; import java.util.Optional; +import static org.opensearch.ml.engine.utils.ModelSerDeSer.decodeBase64; +import static org.opensearch.ml.engine.utils.ModelSerDeSer.encodeBase64; + /** * Use RCF to detect non-time-series data. */ @@ -48,6 +53,8 @@ public class BatchRandomCutForest implements TrainAndPredictable { private static final RandomCutForestMapper rcfMapper = new RandomCutForestMapper(); + private RandomCutForest forest; + public BatchRandomCutForest(){} public BatchRandomCutForest(MLAlgoParams parameters) { @@ -63,31 +70,53 @@ public BatchRandomCutForest(MLAlgoParams parameters) { } @Override - public MLOutput predict(DataFrame dataFrame, Model model) { + public void initModel(MLModel model, Map params) { + RandomCutForestState state = RCFModelSerDeSer.deserializeRCF(model); + forest = rcfMapper.toModel(state); + } + + @Override + public void close() { + forest = null; + } + + @Override + public MLOutput predict(MLInputDataset inputDataset) { + DataFrame dataFrame = ((DataFrameInputDataset)inputDataset).getDataFrame(); + List> predictResult = process(dataFrame, forest, 0); + return MLPredictionOutput.builder().predictionResult(DataFrameBuilder.load(predictResult)).build(); + } + + @Override + public MLOutput predict(MLInputDataset inputDataset, MLModel model) { if (model == null) { throw new IllegalArgumentException("No model found for batch RCF prediction."); } 
- RandomCutForestState state = RCFModelSerDeSer.deserializeRCF(model.getContent()); - RandomCutForest forest = rcfMapper.toModel(state); - List> predictResult = process(dataFrame, forest, 0); - return MLPredictionOutput.builder().predictionResult(DataFrameBuilder.load(predictResult)).build(); + RandomCutForestState state = RCFModelSerDeSer.deserializeRCF(model); + forest = rcfMapper.toModel(state); + return predict(inputDataset); } @Override - public Model train(DataFrame dataFrame) { + public MLModel train(MLInputDataset inputDataset) { + DataFrame dataFrame = ((DataFrameInputDataset)inputDataset).getDataFrame(); RandomCutForest forest = createRandomCutForest(dataFrame); Integer actualTrainingDataSize = trainingDataSize == null ? dataFrame.size() : trainingDataSize; process(dataFrame, forest, actualTrainingDataSize); - Model model = new Model(); - model.setName(FunctionName.BATCH_RCF.name()); - model.setVersion(1); + RandomCutForestState state = rcfMapper.toState(forest); - model.setContent(RCFModelSerDeSer.serializeRCF(state)); + MLModel model = MLModel.builder() + .name(FunctionName.BATCH_RCF.name()) + .algorithm(FunctionName.BATCH_RCF) + .version(1) + .content(encodeBase64(RCFModelSerDeSer.serializeRCF(state))) + .build(); return model; } @Override - public MLOutput trainAndPredict(DataFrame dataFrame) { + public MLOutput trainAndPredict(MLInputDataset inputDataset) { + DataFrame dataFrame = ((DataFrameInputDataset)inputDataset).getDataFrame(); RandomCutForest forest = createRandomCutForest(dataFrame); Integer actualTrainingDataSize = trainingDataSize == null ? 
dataFrame.size() : trainingDataSize; List> predictResult = process(dataFrame, forest, actualTrainingDataSize); diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/rcf/FixedInTimeRandomCutForest.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/rcf/FixedInTimeRandomCutForest.java index bd881e7beb..c2754d6eee 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/rcf/FixedInTimeRandomCutForest.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/rcf/FixedInTimeRandomCutForest.java @@ -13,13 +13,15 @@ import com.amazon.randomcutforest.parkservices.state.ThresholdedRandomCutForestState; import lombok.extern.log4j.Log4j2; import org.opensearch.ml.common.FunctionName; -import org.opensearch.ml.common.Model; +import org.opensearch.ml.common.MLModel; import org.opensearch.ml.common.dataframe.ColumnMeta; import org.opensearch.ml.common.dataframe.ColumnType; import org.opensearch.ml.common.dataframe.ColumnValue; import org.opensearch.ml.common.dataframe.DataFrame; import org.opensearch.ml.common.dataframe.DataFrameBuilder; import org.opensearch.ml.common.dataframe.Row; +import org.opensearch.ml.common.dataset.DataFrameInputDataset; +import org.opensearch.ml.common.dataset.MLInputDataset; import org.opensearch.ml.common.exception.MLValidationException; import org.opensearch.ml.common.input.parameter.MLAlgoParams; import org.opensearch.ml.common.input.parameter.rcf.FitRCFParams; @@ -38,6 +40,8 @@ import java.util.Optional; import java.util.TimeZone; +import static org.opensearch.ml.engine.utils.ModelSerDeSer.encodeBase64; + /** * MLCommons doesn't support update trained model. So the trained RCF model in MLCommons * will be fixed in some time rather than updated by prediction data. 
We call it FIT(fixed @@ -67,6 +71,8 @@ public class FixedInTimeRandomCutForest implements TrainAndPredictable { private DateFormat simpleDateFormat; private static final ThresholdedRandomCutForestMapper trcfMapper = new ThresholdedRandomCutForestMapper(); + private ThresholdedRandomCutForest forest; + public FixedInTimeRandomCutForest(){} public FixedInTimeRandomCutForest(MLAlgoParams parameters) { @@ -93,31 +99,54 @@ public FixedInTimeRandomCutForest(MLAlgoParams parameters) { } } + + @Override + public void initModel(MLModel model, Map params) { + ThresholdedRandomCutForestState state = RCFModelSerDeSer.deserializeTRCF(model); + this.forest = trcfMapper.toModel(state); + } + + @Override + public void close() { + this.forest = null; + } + @Override - public MLOutput predict(DataFrame dataFrame, Model model) { + public MLOutput predict(MLInputDataset inputDataset) { + DataFrame dataFrame = ((DataFrameInputDataset)inputDataset).getDataFrame(); + List> predictResult = process(dataFrame, forest); + return MLPredictionOutput.builder().predictionResult(DataFrameBuilder.load(predictResult)).build(); + } + + @Override + public MLOutput predict(MLInputDataset inputDataset, MLModel model) { if (model == null) { throw new IllegalArgumentException("No model found for FIT RCF prediction."); } - ThresholdedRandomCutForestState state = RCFModelSerDeSer.deserializeTRCF(model.getContent()); - ThresholdedRandomCutForest forest = trcfMapper.toModel(state); - List> predictResult = process(dataFrame, forest); - return MLPredictionOutput.builder().predictionResult(DataFrameBuilder.load(predictResult)).build(); + ThresholdedRandomCutForestState state = RCFModelSerDeSer.deserializeTRCF(model); + forest = trcfMapper.toModel(state); + return predict(inputDataset); } @Override - public Model train(DataFrame dataFrame) { + public MLModel train(MLInputDataset inputDataset) { + DataFrame dataFrame = ((DataFrameInputDataset)inputDataset).getDataFrame(); ThresholdedRandomCutForest forest = 
createThresholdedRandomCutForest(dataFrame); process(dataFrame, forest); - Model model = new Model(); - model.setName(FunctionName.FIT_RCF.name()); - model.setVersion(1); + ThresholdedRandomCutForestState state = trcfMapper.toState(forest); - model.setContent(RCFModelSerDeSer.serializeTRCF(state)); + MLModel model = MLModel.builder() + .name(FunctionName.FIT_RCF.name()) + .algorithm(FunctionName.FIT_RCF) + .version(1) + .content(encodeBase64(RCFModelSerDeSer.serializeTRCF(state))) + .build(); return model; } @Override - public MLOutput trainAndPredict(DataFrame dataFrame) { + public MLOutput trainAndPredict(MLInputDataset inputDataset) { + DataFrame dataFrame = ((DataFrameInputDataset)inputDataset).getDataFrame(); ThresholdedRandomCutForest forest = createThresholdedRandomCutForest(dataFrame); List> predictResult = process(dataFrame, forest); return MLPredictionOutput.builder().predictionResult(DataFrameBuilder.load(predictResult)).build(); diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/rcf/RCFModelSerDeSer.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/rcf/RCFModelSerDeSer.java index 8079268297..2e0d6dfc9c 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/rcf/RCFModelSerDeSer.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/rcf/RCFModelSerDeSer.java @@ -12,10 +12,13 @@ import io.protostuff.Schema; import io.protostuff.runtime.RuntimeSchema; import lombok.experimental.UtilityClass; +import org.opensearch.ml.common.MLModel; import java.security.AccessController; import java.security.PrivilegedAction; +import static org.opensearch.ml.engine.utils.ModelSerDeSer.decodeBase64; + @UtilityClass public class RCFModelSerDeSer { private static final int SERIALIZATION_BUFFER_BYTES = 512; @@ -34,10 +37,18 @@ public static byte[] serializeTRCF(ThresholdedRandomCutForestState model) { return serialize(model, trcfSchema); } + public static RandomCutForestState 
deserializeRCF(MLModel model) { + return deserializeRCF(decodeBase64(model.getContent())); + } + public static RandomCutForestState deserializeRCF(byte[] bytes) { return deserialize(bytes, rcfSchema); } + public static ThresholdedRandomCutForestState deserializeTRCF(MLModel model) { + return deserializeTRCF(decodeBase64(model.getContent())); + } + public static ThresholdedRandomCutForestState deserializeTRCF(byte[] bytes) { return deserialize(bytes, trcfSchema); } diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/regression/LinearRegression.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/regression/LinearRegression.java index db524e5b63..6d06174911 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/regression/LinearRegression.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/regression/LinearRegression.java @@ -5,14 +5,16 @@ package org.opensearch.ml.engine.algorithms.regression; +import org.opensearch.ml.common.MLModel; import org.opensearch.ml.common.dataframe.DataFrame; import org.opensearch.ml.common.dataframe.DataFrameBuilder; +import org.opensearch.ml.common.dataset.DataFrameInputDataset; +import org.opensearch.ml.common.dataset.MLInputDataset; import org.opensearch.ml.common.input.parameter.regression.LinearRegressionParams; import org.opensearch.ml.common.FunctionName; import org.opensearch.ml.common.input.parameter.MLAlgoParams; import org.opensearch.ml.common.output.MLOutput; import org.opensearch.ml.common.output.MLPredictionOutput; -import org.opensearch.ml.common.Model; import org.opensearch.ml.engine.Predictable; import org.opensearch.ml.engine.Trainable; import org.opensearch.ml.engine.annotation.Function; @@ -42,6 +44,8 @@ import java.util.Map; import java.util.Optional; +import static org.opensearch.ml.engine.utils.ModelSerDeSer.serializeToBase64; + @Function(FunctionName.LINEAR_REGRESSION) public class LinearRegression implements Trainable, 
Predictable { @@ -71,6 +75,7 @@ public class LinearRegression implements Trainable, Predictable { private int loggingInterval; private int minibatchSize; private long seed; + private org.tribuo.Model regressionModel; public LinearRegression() {} @@ -191,13 +196,23 @@ private void validateParameters() { seed = Optional.ofNullable(parameters.getSeed()).orElse(DEFAULT_SEED); } + @Override - public MLOutput predict(DataFrame dataFrame, Model model) { - if (model == null) { - throw new IllegalArgumentException("No model found for linear regression prediction."); - } + public void initModel(MLModel model, Map params) { + this.regressionModel = (org.tribuo.Model) ModelSerDeSer.deserialize(model); + } + + @Override + public void close() { + this.regressionModel = null; + } - org.tribuo.Model regressionModel = (org.tribuo.Model) ModelSerDeSer.deserialize(model.getContent()); + @Override + public MLOutput predict(MLInputDataset inputDataset) { + if (regressionModel == null) { + throw new IllegalArgumentException("model not loaded"); + } + DataFrame dataFrame = ((DataFrameInputDataset)inputDataset).getDataFrame(); MutableDataset predictionDataset = TribuoUtil.generateDataset(dataFrame, new RegressionFactory(), "Linear regression prediction data from opensearch", TribuoOutputType.REGRESSOR); List> predictions = regressionModel.predict(predictionDataset); @@ -208,16 +223,29 @@ public MLOutput predict(DataFrame dataFrame, Model model) { } @Override - public Model train(DataFrame dataFrame) { + public MLOutput predict(MLInputDataset inputDataset, MLModel model) { + if (model == null) { + throw new IllegalArgumentException("No model found for linear regression prediction."); + } + + regressionModel = (org.tribuo.Model) ModelSerDeSer.deserialize(model); + return predict(inputDataset); + } + + @Override + public MLModel train(MLInputDataset inputDataset) { + DataFrame dataFrame = ((DataFrameInputDataset)inputDataset).getDataFrame(); MutableDataset trainDataset = 
TribuoUtil.generateDatasetWithTarget(dataFrame, new RegressionFactory(), "Linear regression training data from opensearch", TribuoOutputType.REGRESSOR, parameters.getTarget()); Integer epochs = Optional.ofNullable(parameters.getEpochs()).orElse(DEFAULT_EPOCHS); LinearSGDTrainer linearSGDTrainer = new LinearSGDTrainer(objective, optimiser, epochs, loggingInterval, minibatchSize, seed); org.tribuo.Model regressionModel = linearSGDTrainer.train(trainDataset); - Model model = new Model(); - model.setName(FunctionName.LINEAR_REGRESSION.name()); - model.setVersion(1); - model.setContent(ModelSerDeSer.serialize(regressionModel)); + MLModel model = MLModel.builder() + .name(FunctionName.LINEAR_REGRESSION.name()) + .algorithm(FunctionName.LINEAR_REGRESSION) + .version(1) + .content(serializeToBase64(regressionModel)) + .build(); return model; } diff --git a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/regression/LogisticRegression.java b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/regression/LogisticRegression.java index de6faed9c0..e7c766af08 100644 --- a/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/regression/LogisticRegression.java +++ b/ml-algorithms/src/main/java/org/opensearch/ml/engine/algorithms/regression/LogisticRegression.java @@ -5,10 +5,12 @@ package org.opensearch.ml.engine.algorithms.regression; -import org.opensearch.ml.common.Model; +import org.opensearch.ml.common.MLModel; import org.opensearch.ml.common.FunctionName; import org.opensearch.ml.common.dataframe.DataFrame; import org.opensearch.ml.common.dataframe.DataFrameBuilder; +import org.opensearch.ml.common.dataset.DataFrameInputDataset; +import org.opensearch.ml.common.dataset.MLInputDataset; import org.opensearch.ml.common.input.parameter.MLAlgoParams; import org.opensearch.ml.common.input.parameter.regression.LogisticRegressionParams; import org.opensearch.ml.common.output.MLOutput; @@ -41,6 +43,8 @@ import java.util.Map; import 
java.util.Optional; +import static org.opensearch.ml.engine.utils.ModelSerDeSer.serializeToBase64; + @Function(FunctionName.LOGISTIC_REGRESSION) public class LogisticRegression implements Trainable, Predictable { @@ -69,6 +73,8 @@ public class LogisticRegression implements Trainable, Predictable { private LogisticRegressionParams parameters; private StochasticGradientOptimiser optimiser; private LabelObjective objective; + private org.tribuo.Model