Skip to content

Commit

Permalink
Optimized data exchange between Python and Java
Browse files Browse the repository at this point in the history
  • Loading branch information
vruusmann committed May 5, 2023
1 parent f6522bf commit 1cb8ebd
Show file tree
Hide file tree
Showing 3 changed files with 128 additions and 15 deletions.
13 changes: 8 additions & 5 deletions jpmml_evaluator/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,14 +162,17 @@ def evaluate(self, arguments, nan_as_missing = True):

def evaluateAll(self, arguments_df, nan_as_missing = True):
arguments_df = _canonicalizeAll(arguments_df, nan_as_missing = nan_as_missing)
argument_records = arguments_df.to_dict(orient = "records")
argument_records = self.backend.dumps(argument_records)
arguments_dict = {
"columns" : arguments_df.columns.tolist(),
"data" : arguments_df.values.tolist()
}
arguments = self.backend.dumps(arguments_dict)
try:
result_records = self.backend.staticInvoke("org.jpmml.evaluator.python.PythonUtil", "evaluateAll", self.javaEvaluator, argument_records)
results = self.backend.staticInvoke("org.jpmml.evaluator.python.PythonUtil", "evaluateAll", self.javaEvaluator, arguments)
except Exception as e:
raise self.backend.toJavaError(e)
result_records = self.backend.loads(result_records)
results_df = DataFrame.from_records(result_records)
results_dict = self.backend.loads(results)
results_df = DataFrame(data = results_dict["data"], columns = results_dict["columns"])
if hasattr(self, "dropColumns"):
for dropColumn in self.dropColumns:
results_df.drop(str(dropColumn), axis = 1, inplace = True)
Expand Down
Binary file not shown.
130 changes: 120 additions & 10 deletions src/main/java/org/jpmml/evaluator/python/PythonUtil.java
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,12 @@
import java.util.AbstractMap;
import java.util.ArrayList;
import java.util.Collection;
import java.util.HashMap;
import java.util.LinkedHashMap;
import java.util.List;
import java.util.Map;
import java.util.Set;
import java.util.stream.Collectors;

import net.razorvine.pickle.Pickler;
import net.razorvine.pickle.Unpickler;
Expand All @@ -42,25 +44,63 @@ private PythonUtil(){
}

static
public byte[] evaluateAll(Evaluator evaluator, byte[] listOfDictsBytes) throws IOException {
List<Map<String, ?>> batchedArguments = (List)unpickle(listOfDictsBytes);
public byte[] evaluateAll(Evaluator evaluator, byte[] dictBytes) throws IOException {
Map<String, ?> argumentsDict = (Map)unpickle(dictBytes);

List<Map<String, ?>> batchedResults = evaluateAll(evaluator, batchedArguments);
Map<String, ?> resultsDict = evaluateAll(evaluator, argumentsDict);

return pickle(batchedResults);
return pickle(resultsDict);
}

static
public List<Map<String, ?>> evaluateAll(Evaluator evaluator, List<Map<String, ?>> batchedArguments){
List<Map<String, ?>> batchedResults = new ArrayList<>();
public Map<String, ?> evaluateAll(Evaluator evaluator, Map<String, ?> argumentsDict){
ColumnMapper argumentMapper = new ColumnMapper((List<String>)argumentsDict.get("columns"));
List<List<?>> argumentData = (List<List<?>>)argumentsDict.get("data");

for(Map<String, ?> arguments : batchedArguments){
Map<String, ?> results = evaluate(evaluator, arguments);
ColumnMapper resultMapper = new ColumnMapper();
List<List<?>> resultData = new ArrayList<>();

batchedResults.add(results);
TabularArguments arguments = new TabularArguments(argumentMapper);

for(int i = 0; i < argumentData.size(); i++){
List<?> argumentRow = argumentData.get(i);

arguments.setRow(argumentRow);

Map<String, ?> results = evaluator.evaluate(arguments);

int numberOfColumns = resultMapper.size();

List<Object> resultRow = new ArrayList<>(numberOfColumns);
for(int j = 0; j < numberOfColumns; j++){
resultRow.add(null);
}

Collection<? extends Map.Entry<String, ?>> entries = results.entrySet();
for(Map.Entry<String, ?> entry : entries){
String name = entry.getKey();
Object value = EvaluatorUtil.decode(entry.getValue());

Integer index = resultMapper.putIfAbsent(name, numberOfColumns);
if(index == null){
resultRow.add(value);

numberOfColumns++;
} else

{
resultRow.set(index, value);
}
}

resultData.add(resultRow);
}

return batchedResults;
Map<String, Object> resultsDict = new HashMap<>();
resultsDict.put("columns", new ArrayList<>(resultMapper.getColumns()));
resultsDict.put("data", resultData);

return resultsDict;
}

static
Expand Down Expand Up @@ -161,6 +201,76 @@ private byte[] pickle(Object object) throws IOException {
return pickler.dumps(object);
}

static
private class ColumnMapper extends HashMap<String, Integer> {

public ColumnMapper(){
}

public ColumnMapper(List<String> columns){

for(int i = 0; i < columns.size(); i++){
String column = columns.get(i);

putIfAbsent(column, size());
}
}

public List<String> getColumns(){
Collection<Map.Entry<String, Integer>> entries = entrySet();

return entries.stream()
.sorted((left, right) -> {
return (left.getValue()).compareTo(right.getValue());
})
.map(entry -> entry.getKey())
.collect(Collectors.toList());
}
}

static
private class TabularArguments extends AbstractMap<String, Object> {

private ColumnMapper mapper = null;

private List<?> row = null;


public TabularArguments(ColumnMapper mapper){
this.mapper = mapper;
}

@Override
public Object get(Object key){
Integer index = this.mapper.get(key);

if(index != null){
return getValue(index);
}

return null;
}

@Override
public Set<Entry<String, Object>> entrySet(){
throw new UnsupportedOperationException();
}

public Object getValue(int index){
List<?> row = getRow();

return row.get(index);
}

public List<?> getRow(){
return this.row;
}

public void setRow(List<?> row){
this.row = row;
}
}

static {
ClassLoader clazzLoader = PythonUtil.class.getClassLoader();

Expand Down

0 comments on commit 1cb8ebd

Please sign in to comment.