Skip to content

Commit

Permalink
Improved the filtering of result fields
Browse files Browse the repository at this point in the history
  • Loading branch information
vruusmann committed May 10, 2020
1 parent f9b291e commit d99ad4d
Show file tree
Hide file tree
Showing 7 changed files with 57 additions and 27 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
import org.jpmml.evaluator.EvaluatorBuilder;
import org.jpmml.evaluator.ModelEvaluatorBuilder;
import org.jpmml.evaluator.ModelEvaluatorFactory;
import org.jpmml.evaluator.ResultField;
import org.jpmml.evaluator.testing.Batch;
import org.jpmml.evaluator.testing.BatchUtil;
import org.jpmml.evaluator.testing.Conflict;
Expand Down Expand Up @@ -154,14 +155,14 @@ public void execute() throws Exception {

List<? extends Map<FieldName, ?>> outputRecords = BatchUtil.parseRecords(outputTable, cellParser);

Predicate<FieldName> predicate;
Predicate<ResultField> predicate;

if(this.ignoredFields != null && !this.ignoredFields.isEmpty()){
predicate = (FieldName name) -> !this.ignoredFields.contains(name);
predicate = (ResultField resultField) -> !this.ignoredFields.contains(resultField.getName());
} else

{
predicate = (FieldName name) -> true;
predicate = (ResultField resultField) -> true;
}

Equivalence<Object> equivalence = new PMMLEquivalence(this.precision, this.zeroThreshold);
Expand All @@ -178,7 +179,7 @@ public void execute() throws Exception {
}

static
private Batch createBatch(Evaluator evaluator, List<? extends Map<FieldName, ?>> input, List<? extends Map<FieldName, ?>> output, Predicate<FieldName> predicate, Equivalence<Object> equivalence){
private Batch createBatch(Evaluator evaluator, List<? extends Map<FieldName, ?>> input, List<? extends Map<FieldName, ?>> output, Predicate<ResultField> predicate, Equivalence<Object> equivalence){
Batch batch = new Batch(){

@Override
Expand All @@ -197,7 +198,7 @@ public Evaluator getEvaluator(){
}

@Override
public Predicate<FieldName> getPredicate(){
public Predicate<ResultField> getPredicate(){
return predicate;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -34,6 +34,7 @@
import org.jpmml.evaluator.FieldNameSet;
import org.jpmml.evaluator.FunctionNameStack;
import org.jpmml.evaluator.ModelEvaluatorBuilder;
import org.jpmml.evaluator.ResultField;
import org.jpmml.evaluator.visitors.DefaultModelEvaluatorBattery;
import org.jpmml.model.PMMLUtil;
import org.jpmml.model.visitors.VisitorBattery;
Expand All @@ -45,12 +46,12 @@ public class ArchiveBatch implements Batch {

private String dataset = null;

private Predicate<FieldName> predicate = null;
private Predicate<ResultField> predicate = null;

private Equivalence<Object> equivalence = null;


public ArchiveBatch(String name, String dataset, Predicate<FieldName> predicate, Equivalence<Object> equivalence){
public ArchiveBatch(String name, String dataset, Predicate<ResultField> predicate, Equivalence<Object> equivalence){
setName(Objects.requireNonNull(name));
setDataset(Objects.requireNonNull(dataset));
setPredicate(Objects.requireNonNull(predicate));
Expand Down Expand Up @@ -149,11 +150,11 @@ private void setDataset(String dataset){
}

@Override
public Predicate<FieldName> getPredicate(){
public Predicate<ResultField> getPredicate(){
return this.predicate;
}

private void setPredicate(Predicate<FieldName> predicate){
private void setPredicate(Predicate<ResultField> predicate){
this.predicate = predicate;
}

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import com.google.common.base.Equivalence;
import org.dmg.pmml.FieldName;
import org.jpmml.evaluator.Evaluator;
import org.jpmml.evaluator.ResultField;

public interface Batch extends AutoCloseable {

Expand Down Expand Up @@ -55,7 +56,7 @@ public interface Batch extends AutoCloseable {
* (between expected and actual output data records).
* </p>
*/
Predicate<FieldName> getPredicate();
Predicate<ResultField> getPredicate();

/**
* <p>
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,6 @@
import java.util.LinkedHashSet;
import java.util.List;
import java.util.Map;
import java.util.Objects;
import java.util.Set;
import java.util.function.Function;
import java.util.function.Predicate;
Expand All @@ -36,6 +35,9 @@
import org.jpmml.evaluator.Evaluator;
import org.jpmml.evaluator.EvaluatorUtil;
import org.jpmml.evaluator.HasGroupFields;
import org.jpmml.evaluator.OutputField;
import org.jpmml.evaluator.ResultField;
import org.jpmml.evaluator.TargetField;

public class BatchUtil {

Expand All @@ -59,7 +61,29 @@ public List<Conflict> evaluate(Batch batch) throws Exception {
throw new IllegalArgumentException("Expected the same number of data rows, got " + input.size() + " input data rows and " + output.size() + " expected output data rows");
}

Predicate<FieldName> predicate = (batch.getPredicate()).and(name -> !Objects.equals(Evaluator.DEFAULT_TARGET_NAME, name));
Predicate<ResultField> predicate = batch.getPredicate();

Set<FieldName> names = new LinkedHashSet<>();

List<TargetField> targetFields = evaluator.getTargetFields();
for(TargetField targetField : targetFields){

if(targetField.isSynthetic()){
continue;
} // End if

if(predicate.test(targetField)){
names.add(targetField.getName());
}
}

List<OutputField> outputFields = evaluator.getOutputFields();
for(OutputField outputField : outputFields){

if(predicate.test(outputField)){
names.add(outputField.getName());
}
}

Equivalence<Object> equivalence = batch.getEquivalence();

Expand All @@ -69,11 +93,11 @@ public List<Conflict> evaluate(Batch batch) throws Exception {
Map<FieldName, ?> arguments = input.get(i);

Map<FieldName, ?> expectedResults = output.get(i);
expectedResults = Maps.filterKeys(expectedResults, predicate::test);
expectedResults = Maps.filterKeys(expectedResults, names::contains);

try {
Map<FieldName, ?> actualResults = evaluator.evaluate(arguments);
actualResults = Maps.filterKeys(actualResults, predicate::test);
actualResults = Maps.filterKeys(actualResults, names::contains);

MapDifference<FieldName, ?> difference = Maps.<FieldName, Object>difference(expectedResults, actualResults, equivalence);
if(!difference.areEqual()){
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
import com.google.common.base.Equivalence;
import org.dmg.pmml.FieldName;
import org.jpmml.evaluator.Evaluator;
import org.jpmml.evaluator.ResultField;

public class FilterBatch implements Batch {

Expand Down Expand Up @@ -57,7 +58,7 @@ public Evaluator getEvaluator() throws Exception {
}

@Override
public Predicate<FieldName> getPredicate(){
public Predicate<ResultField> getPredicate(){
Batch batch = getBatch();

return batch.getPredicate();
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -20,11 +20,13 @@

import java.util.Arrays;
import java.util.Collection;
import java.util.LinkedHashSet;
import java.util.Objects;
import java.util.function.Predicate;

import com.google.common.base.Equivalence;
import org.dmg.pmml.FieldName;
import org.jpmml.evaluator.ResultField;

abstract
public class IntegrationTest extends BatchTest {
Expand All @@ -40,18 +42,18 @@ public void evaluate(String name, String dataset) throws Exception {
evaluate(name, dataset, null, null);
}

public void evaluate(String name, String dataset, Predicate<FieldName> predicate) throws Exception {
public void evaluate(String name, String dataset, Predicate<ResultField> predicate) throws Exception {
evaluate(name, dataset, predicate, null);
}

public void evaluate(String name, String dataset, Equivalence<Object> equivalence) throws Exception {
evaluate(name, dataset, null, equivalence);
}

public void evaluate(String name, String dataset, Predicate<FieldName> predicate, Equivalence<Object> equivalence) throws Exception {
public void evaluate(String name, String dataset, Predicate<ResultField> predicate, Equivalence<Object> equivalence) throws Exception {

if(predicate == null){
predicate = (x -> true);
predicate = (resultField -> true);
} // End if

if(equivalence == null){
Expand All @@ -63,7 +65,7 @@ public void evaluate(String name, String dataset, Predicate<FieldName> predicate
}
}

protected Batch createBatch(String name, String dataset, Predicate<FieldName> predicate, Equivalence<Object> equivalence){
protected Batch createBatch(String name, String dataset, Predicate<ResultField> predicate, Equivalence<Object> equivalence){
Batch result = new IntegrationTestBatch(name, dataset, predicate, equivalence){

@Override
Expand All @@ -84,17 +86,17 @@ private void setEquivalence(Equivalence<Object> equivalence){
}

static
public Predicate<FieldName> excludeFields(FieldName... names){
return excludeFields(Arrays.asList(names));
public Predicate<ResultField> excludeFields(FieldName... names){
return excludeFields(new LinkedHashSet<>(Arrays.asList(names)));
}

static
public Predicate<FieldName> excludeFields(Collection<FieldName> names){
Predicate<FieldName> predicate = new Predicate<FieldName>(){
public Predicate<ResultField> excludeFields(Collection<FieldName> names){
Predicate<ResultField> predicate = new Predicate<ResultField>(){

@Override
public boolean test(FieldName name){
return !names.contains(name);
public boolean test(ResultField resultField){
return !names.contains(resultField.getName());
}
};

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,6 @@

import com.google.common.base.Equivalence;
import org.dmg.pmml.Application;
import org.dmg.pmml.FieldName;
import org.dmg.pmml.MiningSchema;
import org.dmg.pmml.PMML;
import org.dmg.pmml.Visitor;
Expand All @@ -35,6 +34,7 @@
import org.jpmml.evaluator.EvaluatorBuilder;
import org.jpmml.evaluator.ModelEvaluatorBuilder;
import org.jpmml.evaluator.OutputFilters;
import org.jpmml.evaluator.ResultField;
import org.jpmml.evaluator.visitors.InvalidMarkupInspector;
import org.jpmml.evaluator.visitors.UnsupportedMarkupInspector;
import org.jpmml.model.SerializationUtil;
Expand All @@ -46,7 +46,7 @@ public class IntegrationTestBatch extends ArchiveBatch {
private Evaluator evaluator = null;


public IntegrationTestBatch(String name, String dataset, Predicate<FieldName> predicate, Equivalence<Object> equivalence){
public IntegrationTestBatch(String name, String dataset, Predicate<ResultField> predicate, Equivalence<Object> equivalence){
super(name, dataset, predicate, equivalence);
}

Expand Down

0 comments on commit d99ad4d

Please sign in to comment.