Skip to content

Commit

Permalink
Added 'input_float' model transformation option. Fixes #128
Browse files Browse the repository at this point in the history
  • Loading branch information
vruusmann committed Apr 27, 2023
1 parent 2487206 commit 433ce1c
Show file tree
Hide file tree
Showing 5 changed files with 46 additions and 0 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -23,12 +23,17 @@
import java.io.InputStream;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.function.Function;

import ml.dmlc.xgboost4j.scala.Booster;
import ml.dmlc.xgboost4j.scala.spark.params.GeneralParams;
import org.apache.spark.ml.Model;
import org.apache.spark.ml.param.shared.HasPredictionCol;
import org.dmg.pmml.DataType;
import org.dmg.pmml.Field;
import org.dmg.pmml.mining.MiningModel;
import org.jpmml.converter.ContinuousFeature;
import org.jpmml.converter.Feature;
import org.jpmml.converter.Schema;
import org.jpmml.sparkml.ModelConverter;
import org.jpmml.xgboost.HasXGBoostOptions;
Expand Down Expand Up @@ -60,6 +65,41 @@ public <M extends Model<M> & HasPredictionCol & GeneralParams, C extends ModelCo
throw new RuntimeException(ioe);
}

Boolean inputFloat = (Boolean)converter.getOption(HasSparkMLXGBoostOptions.OPTION_INPUT_FLOAT, null);
if((Boolean.TRUE).equals(inputFloat)){
Function<Feature, Feature> function = new Function<Feature, Feature>(){

@Override
public Feature apply(Feature feature){

if(feature instanceof ContinuousFeature){
ContinuousFeature continuousFeature = (ContinuousFeature)feature;

DataType dataType = continuousFeature.getDataType();
switch(dataType){
case INTEGER:
case FLOAT:
break;
case DOUBLE:
{
Field<?> field = continuousFeature.getField();

field.setDataType(DataType.FLOAT);

return new ContinuousFeature(continuousFeature.getEncoder(), field);
}
default:
break;
}
}

return feature;
}
};

schema = schema.toTransformedSchema(function);
}

Float missing = model.getMissing();
if(missing.isNaN()){
missing = null;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -22,4 +22,6 @@
import org.jpmml.xgboost.HasXGBoostOptions;

public interface HasSparkMLXGBoostOptions extends HasSparkMLOptions, HasXGBoostOptions {

String OPTION_INPUT_FLOAT = "input_float";
}
Original file line number Diff line number Diff line change
Expand Up @@ -35,6 +35,7 @@
import org.jpmml.model.visitors.AbstractVisitor;
import org.jpmml.sparkml.testing.SparkMLEncoderBatch;
import org.jpmml.sparkml.testing.SparkMLEncoderBatchTest;
import org.jpmml.sparkml.xgboost.HasSparkMLXGBoostOptions;
import org.jpmml.xgboost.HasXGBoostOptions;
import org.junit.AfterClass;
import org.junit.BeforeClass;
Expand All @@ -61,7 +62,10 @@ public XGBoostTest getArchiveBatchTest(){
public List<Map<String, Object>> getOptionsMatrix(){
Map<String, Object> options = new LinkedHashMap<>();

options.put(HasSparkMLXGBoostOptions.OPTION_INPUT_FLOAT, new Boolean[]{false, true});

options.put(HasXGBoostOptions.OPTION_COMPACT, new Boolean[]{false, true});
options.put(HasXGBoostOptions.OPTION_PRUNE, false);

return OptionsUtil.generateOptionsMatrix(options);
}
Expand Down
Binary file not shown.
Binary file modified pmml-sparkml-xgboost/src/test/resources/pipeline/XGBoostIris.zip
Binary file not shown.

0 comments on commit 433ce1c

Please sign in to comment.