Skip to content

Commit

Permalink
refactor ML algorithm package for supporting custom model (opensearch…
Browse files Browse the repository at this point in the history
…-project#474)

Signed-off-by: Yaliang Wu <[email protected]>

Signed-off-by: Yaliang Wu <[email protected]>
Signed-off-by: Sicheng Song <[email protected]>
  • Loading branch information
ylwu-amzn authored and b4sjoo committed Nov 15, 2022
1 parent a2b9cb1 commit 161ccb2
Show file tree
Hide file tree
Showing 40 changed files with 689 additions and 349 deletions.
1 change: 1 addition & 0 deletions build-tools/repositories.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -7,5 +7,6 @@ repositories {
mavenLocal()
maven { url "https://aws.oss.sonatype.org/content/repositories/snapshots" }
mavenCentral()
maven {url 'https://oss.sonatype.org/content/repositories/snapshots/'}
maven { url "https://d1nvenhzbhpy0q.cloudfront.net/snapshots/lucene/" }
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@
import org.opensearch.action.delete.DeleteResponse;
import org.opensearch.action.search.SearchRequest;
import org.opensearch.action.search.SearchResponse;
import org.opensearch.client.node.NodeClient;
import org.opensearch.client.Client;
import org.opensearch.ml.common.input.MLInput;
import org.opensearch.ml.common.MLModel;
import org.opensearch.ml.common.output.MLOutput;
Expand All @@ -38,17 +38,16 @@
@RequiredArgsConstructor
public class MachineLearningNodeClient implements MachineLearningClient {

NodeClient client;
Client client;

@Override
public void predict(String modelId, MLInput mlInput, ActionListener<MLOutput> listener) {
validateMLInput(mlInput, true);

MLPredictionTaskRequest predictionRequest = MLPredictionTaskRequest.builder()
.mlInput(mlInput)
.modelId(modelId)
.build();

.mlInput(mlInput)
.modelId(modelId)
.build();
client.execute(MLPredictionTaskAction.INSTANCE, predictionRequest, getMlPredictionTaskResponseActionListener(listener));
}

Expand Down Expand Up @@ -80,9 +79,18 @@ public void getModel(String modelId, ActionListener<MLModel> listener) {
.modelId(modelId)
.build();

client.execute(MLModelGetAction.INSTANCE, mlModelGetRequest, ActionListener.wrap(response -> {
listener.onResponse(MLModelGetResponse.fromActionResponse(response).getMlModel());
}, listener::onFailure));
client.execute(MLModelGetAction.INSTANCE, mlModelGetRequest, getMlGetModelResponseActionListener(listener));
}

/**
 * Adapts a caller-supplied {@code MLModel} listener into a listener for the get-model response.
 * On success the model carried by the response is handed to {@code listener}; failures are
 * forwarded to {@code listener.onFailure}. The adapted listener is passed through
 * {@code wrapActionListener} together with a converter that re-creates the response via
 * {@code MLModelGetResponse.fromActionResponse}.
 * @param listener receives the model, or the failure
 * @return listener to pass to the get-model action
 */
private ActionListener<MLModelGetResponse> getMlGetModelResponseActionListener(ActionListener<MLModel> listener) {
    ActionListener<MLModelGetResponse> modelForwarder = ActionListener.wrap(
        getResponse -> listener.onResponse(getResponse.getMlModel()),
        listener::onFailure
    );
    return wrapActionListener(modelForwarder, response -> MLModelGetResponse.fromActionResponse(response));
}

@Override
Expand Down
7 changes: 0 additions & 7 deletions common/src/main/java/org/opensearch/ml/common/MLModel.java
Original file line number Diff line number Diff line change
Expand Up @@ -95,13 +95,6 @@ public MLModel(String name, FunctionName algorithm, Integer version, String cont
this.totalChunks = totalChunks;
}

/**
 * Creates an ML model from a trained {@code Model}.
 * Name and version are copied from {@code model}; the model content bytes are stored as a
 * Base64-encoded string.
 * @param algorithm algorithm (function) name assigned to this model
 * @param model source model providing name, version and content
 */
public MLModel(FunctionName algorithm, Model model) {
this.name = model.getName();
this.algorithm = algorithm;
this.version = model.getVersion();
// Content is kept as Base64 text rather than raw bytes.
this.content = Base64.getEncoder().encodeToString(model.getContent());
}

public MLModel(StreamInput input) throws IOException{
name = input.readOptionalString();
algorithm = input.readEnum(FunctionName.class);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ public MLException(Throwable cause) {
}

/**
* Constructor with specified error message adn cause.
* Constructor with specified error message and cause.
* @param message error message
* @param cause exception cause
*/
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -42,11 +42,17 @@ public class MLInput implements Input {
public static final String INPUT_INDEX_FIELD = "input_index";
public static final String INPUT_QUERY_FIELD = "input_query";
public static final String INPUT_DATA_FIELD = "input_data";

// For trained model
// Return bytes in model output
public static final String RETURN_BYTES_FIELD = "return_bytes";
// Return number in model output. This can be used together with return_bytes.
public static final String RETURN_NUMBER_FIELD = "return_number";
// Filter target response with name in model output
public static final String TARGET_RESPONSE_FIELD = "target_response";
// Filter target response with position in model output
public static final String TARGET_RESPONSE_POSITIONS_FIELD = "target_response_positions";
// Input text sentences for text embedding model
public static final String TEXT_DOCS_FIELD = "text_docs";

// Algorithm name
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -66,23 +66,23 @@ public void writeTo(StreamOutput out) throws IOException {

public void filter(ModelResultFilter resultFilter) {
boolean returnBytes = resultFilter.isReturnBytes();
boolean returnNUmber = resultFilter.isReturnNumber();
boolean returnNumber = resultFilter.isReturnNumber();
List<String> targetResponse = resultFilter.getTargetResponse();
List<Integer> targetResponsePositions = resultFilter.getTargetResponsePositions();
if ((targetResponse == null || targetResponse.size() == 0)
&& (targetResponsePositions == null || targetResponsePositions.size() == 0)) {
mlModelTensors.forEach(output -> filter(output, returnBytes, returnNUmber));
mlModelTensors.forEach(output -> filter(output, returnBytes, returnNumber));
return;
}
List<ModelTensor> targetOutput = new ArrayList<>();
if (mlModelTensors != null) {
for (int i = 0 ; i<mlModelTensors.size(); i++) {
ModelTensor output = mlModelTensors.get(i);
if (targetResponse != null && targetResponse.contains(output.getName())) {
filter(output, returnBytes, returnNUmber);
filter(output, returnBytes, returnNumber);
targetOutput.add(output);
} else if (targetResponsePositions != null && targetResponsePositions.contains(i)) {
filter(output, returnBytes, returnNUmber);
filter(output, returnBytes, returnNumber);
targetOutput.add(output);
}
}
Expand Down
34 changes: 25 additions & 9 deletions ml-algorithms/src/main/java/org/opensearch/ml/engine/MLEngine.java
Original file line number Diff line number Diff line change
Expand Up @@ -5,37 +5,47 @@

package org.opensearch.ml.engine;

import org.opensearch.ml.common.MLModel;
import org.opensearch.ml.common.dataframe.DataFrame;
import org.opensearch.ml.common.dataset.DataFrameInputDataset;
import org.opensearch.ml.common.dataset.MLInputDataset;
import org.opensearch.ml.common.input.Input;
import org.opensearch.ml.common.input.parameter.MLAlgoParams;
import org.opensearch.ml.common.input.MLInput;
import org.opensearch.ml.common.output.MLOutput;
import org.opensearch.ml.common.Model;
import org.opensearch.ml.common.output.Output;

import java.util.Map;

/**
* This is the interface to all ml algorithms.
*/
public class MLEngine {

public static Model train(Input input) {
public static MLModel train(Input input) {
validateMLInput(input);
MLInput mlInput = (MLInput) input;
Trainable trainable = MLEngineClassLoader.initInstance(mlInput.getAlgorithm(), mlInput.getParameters(), MLAlgoParams.class);
if (trainable == null) {
throw new IllegalArgumentException("Unsupported algorithm: " + mlInput.getAlgorithm());
}
return trainable.train(mlInput.getDataFrame());
return trainable.train(mlInput.getInputDataset());
}

/**
 * Creates the algorithm implementation for the given model and initializes it with the model
 * content, so it can later be used for prediction.
 * @param mlModel ML model whose algorithm selects the implementation to instantiate
 * @param params other parameters passed to {@code initModel}
 * @return the initialized predictable instance
 * @throws IllegalArgumentException if no implementation is available for the model's algorithm
 */
public static Predictable load(MLModel mlModel, Map<String, Object> params) {
    Predictable predictable = MLEngineClassLoader.initInstance(mlModel.getAlgorithm(), null, MLAlgoParams.class);
    // Same guard as train/predict: fail with a clear message instead of a NullPointerException.
    if (predictable == null) {
        throw new IllegalArgumentException("Unsupported algorithm: " + mlModel.getAlgorithm());
    }
    predictable.initModel(mlModel, params);
    return predictable;
}

public static MLOutput predict(Input input, Model model) {
public static MLOutput predict(Input input, MLModel model) {
validateMLInput(input);
MLInput mlInput = (MLInput) input;
Predictable predictable = MLEngineClassLoader.initInstance(mlInput.getAlgorithm(), mlInput.getParameters(), MLAlgoParams.class);
if (predictable == null) {
throw new IllegalArgumentException("Unsupported algorithm: " + mlInput.getAlgorithm());
}
return predictable.predict(mlInput.getDataFrame(), model);
return predictable.predict(mlInput.getInputDataset(), model);
}

public static MLOutput trainAndPredict(Input input) {
Expand All @@ -45,7 +55,7 @@ public static MLOutput trainAndPredict(Input input) {
if (trainAndPredictable == null) {
throw new IllegalArgumentException("Unsupported algorithm: " + mlInput.getAlgorithm());
}
return trainAndPredictable.trainAndPredict(mlInput.getDataFrame());
return trainAndPredictable.trainAndPredict(mlInput.getInputDataset());
}

public static Output execute(Input input) {
Expand All @@ -63,9 +73,15 @@ private static void validateMLInput(Input input) {
throw new IllegalArgumentException("Input should be MLInput");
}
MLInput mlInput = (MLInput) input;
DataFrame dataFrame = mlInput.getDataFrame();
if (dataFrame == null || dataFrame.size() == 0) {
throw new IllegalArgumentException("Input data frame should not be null or empty");
MLInputDataset inputDataset = mlInput.getInputDataset();
if (inputDataset == null) {
throw new IllegalArgumentException("Input data set should not be null");
}
if (inputDataset instanceof DataFrameInputDataset) {
DataFrame dataFrame = ((DataFrameInputDataset)inputDataset).getDataFrame();
if (dataFrame == null || dataFrame.size() == 0) {
throw new IllegalArgumentException("Input data frame should not be null or empty");
}
}
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -5,21 +5,42 @@

package org.opensearch.ml.engine;

import org.opensearch.ml.common.dataframe.DataFrame;
import org.opensearch.ml.common.MLModel;
import org.opensearch.ml.common.dataset.MLInputDataset;
import org.opensearch.ml.common.output.MLOutput;
import org.opensearch.ml.common.Model;

import java.util.Map;

/**
* This is machine learning algorithms predict interface.
*/
public interface Predictable {

/**
* Predict with given features and model (optional).
* @param dataFrame features data
* Predict with given input data and model (optional).
* Will reload model into memory with model content.
* @param inputDataset input data set
* @param model the java serialized model
* @return predicted results
*/
MLOutput predict(DataFrame dataFrame, Model model);
MLOutput predict(MLInputDataset inputDataset, MLModel model);

/**
* Predict with given input data with loaded model.
* @param inputDataset input data set
* @return predicted results
*/
MLOutput predict(MLInputDataset inputDataset);

/**
* Init model (load model into memory) with ML model content and params.
* @param model ML model
* @param params other parameters
*/
void initModel(MLModel model, Map<String, Object> params);

/**
* Close resources like loaded model.
*/
void close();
}
Original file line number Diff line number Diff line change
Expand Up @@ -5,20 +5,20 @@

package org.opensearch.ml.engine;

import org.opensearch.ml.common.dataframe.DataFrame;
import org.opensearch.ml.common.dataset.MLInputDataset;
import org.opensearch.ml.common.output.MLOutput;


/**
* This is machine learning algorithms train interface.
* This is machine learning algorithms train and predict interface.
*/
public interface TrainAndPredictable extends Trainable, Predictable {

/**
* Train model with given features. Then predict with the same data.
* @param dataFrame training data
* @return the java serialized model
* Train model with given input data. Then predict with the same data.
* @param inputDataset training data
* @return predicted results
*/
MLOutput trainAndPredict(DataFrame dataFrame);
MLOutput trainAndPredict(MLInputDataset inputDataset);

}
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@

package org.opensearch.ml.engine;

import org.opensearch.ml.common.dataframe.DataFrame;
import org.opensearch.ml.common.Model;
import org.opensearch.ml.common.MLModel;
import org.opensearch.ml.common.dataset.MLInputDataset;

/**
* This is machine learning algorithms train interface.
Expand All @@ -15,9 +15,9 @@ public interface Trainable {

/**
* Train model with given features.
* @param dataFrame training data
* @return the java serialized model
* @param inputDataset training data
* @return ML model with serialized model content
*/
Model train(DataFrame dataFrame);
MLModel train(MLInputDataset inputDataset);

}
Loading

0 comments on commit 161ccb2

Please sign in to comment.