diff --git a/jpmml_evaluator/__init__.py b/jpmml_evaluator/__init__.py index c66ed4d..01e4e11 100644 --- a/jpmml_evaluator/__init__.py +++ b/jpmml_evaluator/__init__.py @@ -162,14 +162,17 @@ def evaluate(self, arguments, nan_as_missing = True): def evaluateAll(self, arguments_df, nan_as_missing = True): arguments_df = _canonicalizeAll(arguments_df, nan_as_missing = nan_as_missing) - argument_records = arguments_df.to_dict(orient = "records") - argument_records = self.backend.dumps(argument_records) + arguments_dict = { + "columns" : arguments_df.columns.tolist(), + "data" : arguments_df.values.tolist() + } + arguments = self.backend.dumps(arguments_dict) try: - result_records = self.backend.staticInvoke("org.jpmml.evaluator.python.PythonUtil", "evaluateAll", self.javaEvaluator, argument_records) + results = self.backend.staticInvoke("org.jpmml.evaluator.python.PythonUtil", "evaluateAll", self.javaEvaluator, arguments) except Exception as e: raise self.backend.toJavaError(e) - result_records = self.backend.loads(result_records) - results_df = DataFrame.from_records(result_records) + results_dict = self.backend.loads(results) + results_df = DataFrame(data = results_dict["data"], columns = results_dict["columns"]) if hasattr(self, "dropColumns"): for dropColumn in self.dropColumns: results_df.drop(str(dropColumn), axis = 1, inplace = True) diff --git a/jpmml_evaluator/resources/jpmml-evaluator-python-1.3-SNAPSHOT.jar b/jpmml_evaluator/resources/jpmml-evaluator-python-1.3-SNAPSHOT.jar index 6e25c3f..e6e32d8 100644 Binary files a/jpmml_evaluator/resources/jpmml-evaluator-python-1.3-SNAPSHOT.jar and b/jpmml_evaluator/resources/jpmml-evaluator-python-1.3-SNAPSHOT.jar differ diff --git a/src/main/java/org/jpmml/evaluator/python/PythonUtil.java b/src/main/java/org/jpmml/evaluator/python/PythonUtil.java index 9310ef3..1c08af4 100644 --- a/src/main/java/org/jpmml/evaluator/python/PythonUtil.java +++ b/src/main/java/org/jpmml/evaluator/python/PythonUtil.java @@ -22,10 +22,12 @@ import java.util.AbstractMap; import java.util.ArrayList; import java.util.Collection; +import java.util.HashMap; import java.util.LinkedHashMap; import java.util.List; import java.util.Map; import java.util.Set; +import java.util.stream.Collectors; import net.razorvine.pickle.Pickler; import net.razorvine.pickle.Unpickler; @@ -42,25 +44,63 @@ private PythonUtil(){ } static - public byte[] evaluateAll(Evaluator evaluator, byte[] listOfDictsBytes) throws IOException { - List> batchedArguments = (List)unpickle(listOfDictsBytes); + public byte[] evaluateAll(Evaluator evaluator, byte[] dictBytes) throws IOException { + Map argumentsDict = (Map)unpickle(dictBytes); - List> batchedResults = evaluateAll(evaluator, batchedArguments); + Map resultsDict = evaluateAll(evaluator, argumentsDict); - return pickle(batchedResults); + return pickle(resultsDict); } static - public List> evaluateAll(Evaluator evaluator, List> batchedArguments){ - List> batchedResults = new ArrayList<>(); + public Map evaluateAll(Evaluator evaluator, Map argumentsDict){ + ColumnMapper argumentMapper = new ColumnMapper((List)argumentsDict.get("columns")); + List> argumentData = (List>)argumentsDict.get("data"); - for(Map arguments : batchedArguments){ - Map results = evaluate(evaluator, arguments); + ColumnMapper resultMapper = new ColumnMapper(); + List> resultData = new ArrayList<>(); - batchedResults.add(results); + TabularArguments arguments = new TabularArguments(argumentMapper); + + for(int i = 0; i < argumentData.size(); i++){ + List argumentRow = argumentData.get(i); + + arguments.setRow(argumentRow); + + Map results = evaluator.evaluate(arguments); + + int numberOfColumns = resultMapper.size(); + + List resultRow = new ArrayList<>(numberOfColumns); + for(int j = 0; j < numberOfColumns; j++){ + resultRow.add(null); + } + + Collection> entries = results.entrySet(); + for(Map.Entry entry : entries){ + String name = entry.getKey(); + Object value = EvaluatorUtil.decode(entry.getValue()); + + Integer index = resultMapper.putIfAbsent(name, numberOfColumns); + if(index == null){ + resultRow.add(value); + + numberOfColumns++; + } else + + { + resultRow.set(index, value); + } + } + + resultData.add(resultRow); } - return batchedResults; + Map resultsDict = new HashMap<>(); + resultsDict.put("columns", new ArrayList<>(resultMapper.getColumns())); + resultsDict.put("data", resultData); + + return resultsDict; } static @@ -161,6 +201,76 @@ private byte[] pickle(Object object) throws IOException { return pickler.dumps(object); } + static + private class ColumnMapper extends HashMap { + + public ColumnMapper(){ + } + + public ColumnMapper(List columns){ + + for(int i = 0; i < columns.size(); i++){ + String column = columns.get(i); + + putIfAbsent(column, size()); + } + } + + public List getColumns(){ + Collection> entries = entrySet(); + + return entries.stream() + .sorted((left, right) -> { + return (left.getValue()).compareTo(right.getValue()); + }) + .map(entry -> entry.getKey()) + .collect(Collectors.toList()); + } + } + + static + private class TabularArguments extends AbstractMap { + + private ColumnMapper mapper = null; + + private List row = null; + + + public TabularArguments(ColumnMapper mapper){ + this.mapper = mapper; + } + + @Override + public Object get(Object key){ + Integer index = this.mapper.get(key); + + if(index != null){ + return getValue(index); + } + + return null; + } + + @Override + public Set> entrySet(){ + throw new UnsupportedOperationException(); + } + + public Object getValue(int index){ + List row = getRow(); + + return row.get(index); + } + + public List getRow(){ + return this.row; + } + + public void setRow(List row){ + this.row = row; + } + } + static { ClassLoader clazzLoader = PythonUtil.class.getClassLoader();