Skip to content

Commit

Permalink
Fix the risks found by PenTest (opensearch-project#76)
Browse files Browse the repository at this point in the history
* Create JvmService instance on demand.

Signed-off-by: Alex <[email protected]>

* Move the ml_parameters from XContent to the request parameters to avoid
the conflict with search XContent input.

Signed-off-by: Alex <[email protected]>

* Fix the security risks found by PenTest.
1. unhandled 500 server error.
2. Insecure Deserialization

Signed-off-by: Alex <[email protected]>

* Remove unnecessory '*' from the welcome list of model deserializer.

Signed-off-by: Alex <[email protected]>

Co-authored-by: Alex <[email protected]>
  • Loading branch information
2 people authored and jackiehanyang committed Nov 16, 2021
1 parent c2133bc commit 24a4c9c
Show file tree
Hide file tree
Showing 5 changed files with 49 additions and 13 deletions.
2 changes: 2 additions & 0 deletions ml-algorithms/build.gradle
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,9 @@ dependencies {
compile group: 'org.reflections', name: 'reflections', version: '0.9.12'
compile group: 'org.tribuo', name: 'tribuo-clustering-kmeans', version: '4.0.2'
compile group: 'org.tribuo', name: 'tribuo-regression-sgd', version: '4.0.2'
compile group: 'commons-io', name: 'commons-io', version: '2.11.0'
testCompile group: 'junit', name: 'junit', version: '4.12'
testImplementation group: 'org.mockito', name: 'mockito-core', version: '3.9.0'
}

jacocoTestReport {
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -13,16 +13,27 @@
package org.opensearch.ml.engine.utils;

import lombok.experimental.UtilityClass;
import org.apache.commons.io.serialization.ValidatingObjectInputStream;
import org.opensearch.ml.engine.exceptions.ModelSerDeSerException;

import java.io.ByteArrayInputStream;
import java.io.ByteArrayOutputStream;
import java.io.IOException;
import java.io.ObjectInputStream;
import java.io.ObjectOutputStream;

@UtilityClass
public class ModelSerDeSer {
// Welcome list includes OpenSearch ml plugin classes, JDK common classes and Tribuo libraries.
public static final String[] ACCEPT_CLASS_PATTERNS = {
"java.lang.*",
"java.util.*",
"java.time.*",
"org.opensearch.ml.*",
"*org.tribuo.*",
"com.oracle.labs.*",
"[*"
};

public static byte[] serialize(Object model) {
byte[] res = new byte[0];
try {
Expand All @@ -44,9 +55,13 @@ public static Object deserialize(byte[] modelBin) {
Object res;
try {
ByteArrayInputStream inputStream = new ByteArrayInputStream(modelBin);
ObjectInputStream objectInputStream = new ObjectInputStream(inputStream);
res = objectInputStream.readObject();
objectInputStream.close();
ValidatingObjectInputStream validatingObjectInputStream = new ValidatingObjectInputStream(inputStream);

// Validate the model class type to avoid deserialization attack.
validatingObjectInputStream.accept(ACCEPT_CLASS_PATTERNS);

res = validatingObjectInputStream.readObject();
validatingObjectInputStream.close();
inputStream.close();
} catch (IOException | ClassNotFoundException e) {
throw new ModelSerDeSerException("Failed to deserialize model.", e.getCause());
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,21 +12,42 @@

package org.opensearch.ml.engine;

import org.junit.Rule;
import org.junit.Test;
import org.junit.rules.ExpectedException;
import org.opensearch.ml.engine.algorithms.clustering.KMeans;
import org.opensearch.ml.engine.exceptions.ModelSerDeSerException;
import org.opensearch.ml.engine.utils.ModelSerDeSer;
import org.tribuo.clustering.kmeans.KMeansModel;

import static org.junit.Assert.assertTrue;
import java.util.ArrayList;
import java.util.Arrays;

import java.io.IOException;
import static org.junit.Assert.assertFalse;
import static org.junit.Assert.assertTrue;
import static org.opensearch.ml.engine.helper.KMeansHelper.constructKMeansDataFrame;

public class ModelSerDeSerTest {
private final DummyModel dummyModel = new DummyModel();
@Rule
public ExpectedException thrown = ExpectedException.none();

private final Object dummyModel = new Object();

@Test
public void testModelSerDeSer() throws IOException, ClassNotFoundException {
public void testModelSerDeSerBlocklModel() {
thrown.expect(ModelSerDeSerException.class);
byte[] modelBin = ModelSerDeSer.serialize(dummyModel);
DummyModel model = (DummyModel) ModelSerDeSer.deserialize(modelBin);
Object model = ModelSerDeSer.deserialize(modelBin);
assertTrue(model.equals(dummyModel));
}

@Test
public void testModelSerDeSerKMeans() {
KMeans kMeans = new KMeans(new ArrayList<>());
Model model = kMeans.train(constructKMeansDataFrame(100));

KMeansModel kMeansModel = (KMeansModel) ModelSerDeSer.deserialize(model.content);
byte[] serializedModel = ModelSerDeSer.serialize(kMeansModel);
assertFalse(Arrays.equals(serializedModel, model.content));
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -74,8 +74,7 @@ public void parseSearchQueryInput(MLInputDataset mlInputDataset, ActionListener<

client.search(searchRequest, ActionListener.wrap(r -> {
if (r == null || r.getHits() == null || r.getHits().getTotalHits() == null || r.getHits().getTotalHits().value == 0) {
// todo: add specific exception
listener.onFailure(new RuntimeException("No document found"));
listener.onFailure(new IllegalArgumentException("No document found"));
return;
}
SearchHits hits = r.getHits();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -21,7 +21,6 @@
import org.opensearch.ResourceNotFoundException;
import org.opensearch.action.ActionFuture;
import org.opensearch.action.ActionRequestValidationException;
import org.opensearch.common.io.stream.NotSerializableExceptionWrapper;
import org.opensearch.ml.common.dataset.MLInputDataset;
import org.opensearch.ml.common.dataset.SearchQueryInputDataset;
import org.opensearch.ml.common.transport.prediction.MLPredictionTaskAction;
Expand Down Expand Up @@ -99,6 +98,6 @@ public void testPredictionWithEmptyDataset() throws IOException {
emptySearchInputDataset
);
ActionFuture<MLPredictionTaskResponse> predictionFuture = client().execute(MLPredictionTaskAction.INSTANCE, predictionRequest);
expectThrows(NotSerializableExceptionWrapper.class, () -> predictionFuture.actionGet());
expectThrows(IllegalArgumentException.class, () -> predictionFuture.actionGet());
}
}

0 comments on commit 24a4c9c

Please sign in to comment.