Skip to content

Commit

Permalink
add input/output for custom model (opensearch-project#473)
Browse files Browse the repository at this point in the history
Signed-off-by: Yaliang Wu <[email protected]>

Signed-off-by: Yaliang Wu <[email protected]>
Signed-off-by: Sicheng Song <[email protected]>
  • Loading branch information
ylwu-amzn authored and b4sjoo committed Nov 15, 2022
1 parent 6df66a6 commit a2b9cb1
Show file tree
Hide file tree
Showing 21 changed files with 909 additions and 15 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -74,11 +74,11 @@ public class CommonValue {
+ MLModel.MODEL_CONFIG_FIELD
+ "\" : {\"properties\":{\""
+ MODEL_TYPE_FIELD + "\":{\"type\":\"keyword\"},\""
+ EMBEDDING_DIMENSION_FIELD + "\":{\"type\":\"long\"},\""
+ EMBEDDING_DIMENSION_FIELD + "\":{\"type\":\"integer\"},\""
+ FRAMEWORK_TYPE_FIELD + "\":{\"type\":\"keyword\"},\""
+ ALL_CONFIG_FIELD + "\":{\"type\":\"text\"}}},\n"
+ " \""
+ MLModel.MODEL_CONTENT_HASH_FIELD
+ MLModel.MODEL_CONTENT_HASH_VALUE_FIELD
+ "\" : {\"type\": \"keyword\"},\n"
+ " \""
+ MLModel.CREATED_TIME_FIELD
Expand Down
7 changes: 4 additions & 3 deletions common/src/main/java/org/opensearch/ml/common/MLModel.java
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,8 @@ public class MLModel implements ToXContentObject {
public static final String MODEL_FORMAT_FIELD = "model_format";
public static final String MODEL_STATE_FIELD = "model_state";
public static final String MODEL_CONTENT_SIZE_IN_BYTES_FIELD = "model_content_size_in_bytes";
public static final String MODEL_CONTENT_HASH_FIELD = "model_content_hash";
//SHA256 hash value of model content.
public static final String MODEL_CONTENT_HASH_VALUE_FIELD = "model_content_hash_value";
public static final String MODEL_CONFIG_FIELD = "model_config";
public static final String CREATED_TIME_FIELD = "created_time";
public static final String LAST_UPLOADED_TIME_FIELD = "last_uploaded_time";
Expand Down Expand Up @@ -206,7 +207,7 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
builder.field(MODEL_CONTENT_SIZE_IN_BYTES_FIELD, modelContentSizeInBytes);
}
if (modelContentHash != null) {
builder.field(MODEL_CONTENT_HASH_FIELD, modelContentHash);
builder.field(MODEL_CONTENT_HASH_VALUE_FIELD, modelContentHash);
}
if (modelConfig != null) {
builder.field(MODEL_CONFIG_FIELD, modelConfig);
Expand Down Expand Up @@ -303,7 +304,7 @@ public static MLModel parse(XContentParser parser) throws IOException {
case MODEL_CONTENT_SIZE_IN_BYTES_FIELD:
modelContentSizeInBytes = parser.longValue();
break;
case MODEL_CONTENT_HASH_FIELD:
case MODEL_CONTENT_HASH_VALUE_FIELD:
modelContentHash = parser.text();
break;
case MODEL_CONFIG_FIELD:
Expand Down
16 changes: 13 additions & 3 deletions common/src/main/java/org/opensearch/ml/common/MLTask.java
Original file line number Diff line number Diff line change
Expand Up @@ -53,7 +53,8 @@ public class MLTask implements ToXContentObject, Writeable {
private final MLInputDataType inputType;
private Float progress;
private final String outputIndex;
private final String workerNode;
@Setter
private String workerNode;
private final Instant createTime;
private Instant lastUpdateTime;
@Setter
Expand Down Expand Up @@ -100,7 +101,11 @@ public MLTask(StreamInput input) throws IOException {
this.taskType = input.readEnum(MLTaskType.class);
this.functionName = input.readEnum(FunctionName.class);
this.state = input.readEnum(MLTaskState.class);
this.inputType = input.readEnum(MLInputDataType.class);
if (input.readBoolean()) {
this.inputType = input.readEnum(MLInputDataType.class);
} else {
this.inputType = null;
}
this.progress = input.readOptionalFloat();
this.outputIndex = input.readOptionalString();
this.workerNode = input.readString();
Expand All @@ -122,7 +127,12 @@ public void writeTo(StreamOutput out) throws IOException {
out.writeEnum(taskType);
out.writeEnum(functionName);
out.writeEnum(state);
out.writeEnum(inputType);
if (inputType != null) {
out.writeBoolean(true);
out.writeEnum(inputType);
} else {
out.writeBoolean(false);
}
out.writeOptionalFloat(progress);
out.writeOptionalString(outputIndex);
out.writeString(workerNode);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -27,5 +27,6 @@ public enum MLTaskState {
RUNNING,
COMPLETED,
FAILED,
CANCELLED
CANCELLED,
COMPLETED_WITH_ERROR
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,5 +9,7 @@ public enum MLTaskType {
TRAINING,
PREDICTION,
TRAINING_AND_PREDICTION,
EXECUTION
EXECUTION,
UPLOAD_MODEL,
LOAD_MODEL
}
Original file line number Diff line number Diff line change
Expand Up @@ -7,5 +7,6 @@

public enum MLInputDataType {
SEARCH_QUERY,
DATA_FRAME
DATA_FRAME,
TEXT_DOCS
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,62 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.common.dataset;

import lombok.AccessLevel;
import lombok.Builder;
import lombok.Getter;
import lombok.experimental.FieldDefaults;
import org.opensearch.common.io.stream.StreamInput;
import org.opensearch.common.io.stream.StreamOutput;
import org.opensearch.ml.common.annotation.InputDataSet;
import org.opensearch.ml.common.output.model.ModelResultFilter;

import java.io.IOException;
import java.util.List;
import java.util.Objects;

@Getter
@FieldDefaults(makeFinal = true, level = AccessLevel.PRIVATE)
@InputDataSet(MLInputDataType.TEXT_DOCS)
public class TextDocsInputDataSet extends MLInputDataset{

private ModelResultFilter resultFilter;

private List<String> docs;

@Builder
public TextDocsInputDataSet(List<String> docs, ModelResultFilter resultFilter) {
super(MLInputDataType.TEXT_DOCS);
this.resultFilter = resultFilter;
Objects.requireNonNull(docs);
if (docs.size() == 0) {
throw new IllegalArgumentException("empty docs");
}
this.docs = docs;
}

public TextDocsInputDataSet(StreamInput streamInput) throws IOException {
super(MLInputDataType.TEXT_DOCS);
docs = streamInput.readStringList();
if (streamInput.readBoolean()) {
resultFilter = new ModelResultFilter(streamInput);
} else {
resultFilter = null;
}
}

@Override
public void writeTo(StreamOutput streamOutput) throws IOException {
super.writeTo(streamOutput);
streamOutput.writeStringCollection(docs);
if (resultFilter != null) {
streamOutput.writeBoolean(true);
resultFilter.writeTo(streamOutput);
} else {
streamOutput.writeBoolean(false);
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,15 @@ public MLException(Throwable cause) {
super(cause);
}

/**
* Constructor with specified error message adn cause.
* @param message error message
* @param cause exception cause
*/
public MLException(String message, Throwable cause) {
super(message, cause);
}

/**
* Returns if the exception should be counted in stats.
*
Expand Down
68 changes: 66 additions & 2 deletions common/src/main/java/org/opensearch/ml/common/input/MLInput.java
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,12 @@
import org.opensearch.ml.common.dataframe.DataFrame;
import org.opensearch.ml.common.dataframe.DefaultDataFrame;
import org.opensearch.ml.common.dataset.DataFrameInputDataset;
import org.opensearch.ml.common.output.model.ModelResultFilter;
import org.opensearch.ml.common.dataset.MLInputDataType;
import org.opensearch.ml.common.dataset.MLInputDataset;
import org.opensearch.ml.common.dataset.SearchQueryInputDataset;
import org.opensearch.ml.common.FunctionName;
import org.opensearch.ml.common.dataset.TextDocsInputDataSet;
import org.opensearch.ml.common.input.parameter.MLAlgoParams;
import org.opensearch.search.builder.SearchSourceBuilder;

Expand All @@ -40,6 +42,12 @@ public class MLInput implements Input {
public static final String INPUT_INDEX_FIELD = "input_index";
public static final String INPUT_QUERY_FIELD = "input_query";
public static final String INPUT_DATA_FIELD = "input_data";
// For trained model
public static final String RETURN_BYTES_FIELD = "return_bytes";
public static final String RETURN_NUMBER_FIELD = "return_number";
public static final String TARGET_RESPONSE_FIELD = "target_response";
public static final String TARGET_RESPONSE_POSITIONS_FIELD = "target_response_positions";
public static final String TEXT_DOCS_FIELD = "text_docs";

// Algorithm name
private FunctionName algorithm;
Expand All @@ -58,7 +66,8 @@ public MLInput(FunctionName algorithm, MLAlgoParams parameters, MLInputDataset i
this.inputDataset = inputDataset;
}

public MLInput(FunctionName algorithm, MLAlgoParams parameters, SearchSourceBuilder searchSourceBuilder, List<String> sourceIndices, DataFrame dataFrame, MLInputDataset inputDataset) {
public MLInput(FunctionName algorithm, MLAlgoParams parameters, SearchSourceBuilder searchSourceBuilder,
List<String> sourceIndices, DataFrame dataFrame, MLInputDataset inputDataset) {
validate(algorithm);
this.algorithm = algorithm;
this.parameters = parameters;
Expand Down Expand Up @@ -123,6 +132,25 @@ public XContentBuilder toXContent(XContentBuilder builder, Params params) throws
((DataFrameInputDataset)inputDataset).getDataFrame().toXContent(builder, EMPTY_PARAMS);
builder.endObject();
break;
case TEXT_DOCS:
TextDocsInputDataSet textInputDataSet = (TextDocsInputDataSet) this.inputDataset;
List<String> docs = textInputDataSet.getDocs();
ModelResultFilter resultFilter = textInputDataSet.getResultFilter();
if (docs != null && docs.size() > 0) {
builder.field(TEXT_DOCS_FIELD, docs.toArray(new String[0]));
}
if (resultFilter != null) {
builder.field(RETURN_BYTES_FIELD, resultFilter.isReturnBytes());
builder.field(RETURN_NUMBER_FIELD, resultFilter.isReturnNumber());
List<String> targetResponse = resultFilter.getTargetResponse();
if (targetResponse != null && targetResponse.size() > 0) {
builder.field(TARGET_RESPONSE_FIELD, targetResponse.toArray(new String[0]));
}
List<Integer> targetPositions = resultFilter.getTargetResponsePositions();
if (targetPositions != null && targetPositions.size() > 0) {
builder.field(TARGET_RESPONSE_POSITIONS_FIELD, targetPositions.toArray(new Integer[0]));
}
}
default:
break;
}
Expand All @@ -140,6 +168,12 @@ public static MLInput parse(XContentParser parser, String inputAlgoName) throws
List<String> sourceIndices = new ArrayList<>();
DataFrame dataFrame = null;

boolean returnBytes = false;
boolean returnNumber = true;
List<String> targetResponse = new ArrayList<>();
List<Integer> targetResponsePositions = new ArrayList<>();
List<String> textDocs = new ArrayList<>();

ensureExpectedToken(XContentParser.Token.START_OBJECT, parser.currentToken(), parser);
while (parser.nextToken() != XContentParser.Token.END_OBJECT) {
String fieldName = parser.currentName();
Expand All @@ -161,12 +195,42 @@ public static MLInput parse(XContentParser parser, String inputAlgoName) throws
break;
case INPUT_DATA_FIELD:
dataFrame = DefaultDataFrame.parse(parser);
break;
case RETURN_BYTES_FIELD:
returnBytes = parser.booleanValue();
break;
case RETURN_NUMBER_FIELD:
returnNumber = parser.booleanValue();
break;
case TARGET_RESPONSE_FIELD:
ensureExpectedToken(XContentParser.Token.START_ARRAY, parser.currentToken(), parser);
while (parser.nextToken() != XContentParser.Token.END_ARRAY) {
targetResponse.add(parser.text());
}
break;
case TARGET_RESPONSE_POSITIONS_FIELD:
ensureExpectedToken(XContentParser.Token.START_ARRAY, parser.currentToken(), parser);
while (parser.nextToken() != XContentParser.Token.END_ARRAY) {
targetResponsePositions.add(parser.intValue());
}
break;
case TEXT_DOCS_FIELD:
ensureExpectedToken(XContentParser.Token.START_ARRAY, parser.currentToken(), parser);
while (parser.nextToken() != XContentParser.Token.END_ARRAY) {
textDocs.add(parser.text());
}
break;
default:
parser.skipChildren();
break;
}
}
return new MLInput(algorithm, mlParameters, searchSourceBuilder, sourceIndices, dataFrame, null);
MLInputDataset inputDataSet = null;
if (algorithm == FunctionName.TEXT_EMBEDDING) {
ModelResultFilter filter = new ModelResultFilter(returnBytes, returnNumber, targetResponse, targetResponsePositions);
inputDataSet = new TextDocsInputDataSet(textDocs, filter);
}
return new MLInput(algorithm, mlParameters, searchSourceBuilder, sourceIndices, dataFrame, inputDataSet);
}

private MLInputDataset createInputDataSet(SearchSourceBuilder searchSourceBuilder, List<String> sourceIndices, DataFrame dataFrame) {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,7 @@ public static MLModelState from(String value) {
try {
return MLModelState.valueOf(value);
} catch (Exception e) {
throw new IllegalArgumentException("Wrong model format");
throw new IllegalArgumentException("Wrong model state");
}
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -9,4 +9,5 @@ public enum MLOutputType {
TRAINING,
PREDICTION,
SAMPLE_ALGO,
MODEL_TENSOR
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,63 @@
/*
* Copyright OpenSearch Contributors
* SPDX-License-Identifier: Apache-2.0
*/

package org.opensearch.ml.common.output.model;

public enum MLResultDataType {
FLOAT32(Format.FLOATING, 4),
FLOAT64(Format.FLOATING, 8),
FLOAT16(Format.FLOATING, 2),
UINT8(Format.UINT, 1),
INT32(Format.INT, 4),
INT8(Format.INT, 1),
INT64(Format.INT, 8),
BOOLEAN(Format.BOOLEAN, 1),
UNKNOWN(Format.UNKNOWN, 0),
STRING(Format.STRING, -1);

/** The general data type format categories. */
public enum Format {
FLOATING,
UINT,
INT,
BOOLEAN,
STRING,
UNKNOWN
}

private Format format;
private int numOfBytes;

MLResultDataType(Format format, int numOfBytes) {
this.format = format;
this.numOfBytes = numOfBytes;
}
/**
* Checks whether it is a floating data type.
*
* @return whether it is a floating data type
*/
public boolean isFloating() {
return format == Format.FLOATING;
}

/**
* Checks whether it is an integer data type.
*
* @return whether it is an integer type
*/
public boolean isInteger() {
return format == Format.UINT || format == Format.INT;
}

/**
* Checks whether it is a boolean data type.
*
* @return whether it is a boolean data type
*/
public boolean isBoolean() {
return format == Format.BOOLEAN;
}
}
Loading

0 comments on commit a2b9cb1

Please sign in to comment.