Skip to content

Commit

Permalink
Updated JPMML-XGBoost dependency
Browse files Browse the repository at this point in the history
See #128
  • Loading branch information
vruusmann committed Apr 27, 2024
1 parent ef6cf30 commit 49a5531
Show file tree
Hide file tree
Showing 6 changed files with 10 additions and 77 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -23,18 +23,13 @@
import java.io.InputStream;
import java.util.LinkedHashMap;
import java.util.Map;
import java.util.function.Function;

import com.google.common.io.MoreFiles;
import ml.dmlc.xgboost4j.scala.Booster;
import ml.dmlc.xgboost4j.scala.spark.params.GeneralParams;
import org.apache.spark.ml.Model;
import org.apache.spark.ml.param.shared.HasPredictionCol;
import org.dmg.pmml.DataType;
import org.dmg.pmml.Field;
import org.dmg.pmml.mining.MiningModel;
import org.jpmml.converter.ContinuousFeature;
import org.jpmml.converter.Feature;
import org.jpmml.converter.Schema;
import org.jpmml.sparkml.ModelConverter;
import org.jpmml.xgboost.HasXGBoostOptions;
Expand All @@ -47,7 +42,7 @@ private BoosterUtil(){
}

static
public <M extends Model<M> & HasPredictionCol & GeneralParams, C extends ModelConverter<M> & HasSparkMLXGBoostOptions> MiningModel encodeBooster(C converter, Booster booster, Schema schema){
public <M extends Model<M> & HasPredictionCol & GeneralParams, C extends ModelConverter<M>> MiningModel encodeBooster(C converter, Booster booster, Schema schema){
M model = converter.getModel();

Learner learner;
Expand All @@ -66,41 +61,6 @@ public <M extends Model<M> & HasPredictionCol & GeneralParams, C extends ModelCo
throw new RuntimeException(e);
}

Boolean inputFloat = (Boolean)converter.getOption(HasSparkMLXGBoostOptions.OPTION_INPUT_FLOAT, null);
if((Boolean.TRUE).equals(inputFloat)){
Function<Feature, Feature> function = new Function<Feature, Feature>(){

@Override
public Feature apply(Feature feature){

if(feature instanceof ContinuousFeature){
ContinuousFeature continuousFeature = (ContinuousFeature)feature;

DataType dataType = continuousFeature.getDataType();
switch(dataType){
case INTEGER:
case FLOAT:
break;
case DOUBLE:
{
Field<?> field = continuousFeature.getField();

field.setDataType(DataType.FLOAT);

return new ContinuousFeature(continuousFeature.getEncoder(), field);
}
default:
break;
}
}

return feature;
}
};

schema = schema.toTransformedSchema(function);
}

Float missing = model.getMissing();
if(missing.isNaN()){
missing = null;
Expand All @@ -113,10 +73,12 @@ public Feature apply(Feature feature){
options.put(HasXGBoostOptions.OPTION_PRUNE, converter.getOption(HasXGBoostOptions.OPTION_PRUNE, false));
options.put(HasXGBoostOptions.OPTION_NTREE_LIMIT, converter.getOption(HasXGBoostOptions.OPTION_NTREE_LIMIT, null));

Boolean numeric = (Boolean)options.get(HasXGBoostOptions.OPTION_NUMERIC);
Schema xgbSchema = learner.configureSchema(options, schema);

MiningModel miningModel = learner.encodeModel(options, xgbSchema);

Schema xgbSchema = learner.toXGBoostSchema(numeric, schema);
miningModel = learner.configureModel(options, miningModel);

return learner.encodeMiningModel(options, xgbSchema);
return miningModel;
}
}

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -26,7 +26,7 @@
import org.jpmml.converter.mining.MiningModelUtil;
import org.jpmml.sparkml.ProbabilisticClassificationModelConverter;

public class XGBoostClassificationModelConverter extends ProbabilisticClassificationModelConverter<XGBoostClassificationModel> implements HasSparkMLXGBoostOptions {
public class XGBoostClassificationModelConverter extends ProbabilisticClassificationModelConverter<XGBoostClassificationModel> {

public XGBoostClassificationModelConverter(XGBoostClassificationModel model){
super(model);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@
import org.jpmml.sparkml.SparkMLEncoder;
import org.jpmml.sparkml.model.HasPredictionModelOptions;

public class XGBoostRegressionModelConverter extends PredictionModelConverter<XGBoostRegressionModel> implements HasSparkMLXGBoostOptions {
public class XGBoostRegressionModelConverter extends PredictionModelConverter<XGBoostRegressionModel> {

public XGBoostRegressionModelConverter(XGBoostRegressionModel model){
super(model);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -51,7 +51,6 @@
import org.jpmml.sparkml.PipelineModelUtil;
import org.jpmml.sparkml.testing.SparkMLEncoderBatch;
import org.jpmml.sparkml.testing.SparkMLEncoderBatchTest;
import org.jpmml.sparkml.xgboost.HasSparkMLXGBoostOptions;
import org.jpmml.xgboost.HasXGBoostOptions;
import org.junit.AfterClass;
import org.junit.BeforeClass;
Expand All @@ -78,9 +77,8 @@ public XGBoostTest getArchiveBatchTest(){
public List<Map<String, Object>> getOptionsMatrix(){
Map<String, Object> options = new LinkedHashMap<>();

options.put(HasSparkMLXGBoostOptions.OPTION_INPUT_FLOAT, new Boolean[]{false, true});

options.put(HasXGBoostOptions.OPTION_COMPACT, new Boolean[]{false, true});
options.put(HasXGBoostOptions.OPTION_INPUT_FLOAT, new Boolean[]{false, true});
options.put(HasXGBoostOptions.OPTION_PRUNE, false);

return OptionsUtil.generateOptionsMatrix(options);
Expand Down
2 changes: 1 addition & 1 deletion pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -113,7 +113,7 @@
<dependency>
<groupId>org.jpmml</groupId>
<artifactId>pmml-xgboost</artifactId>
<version>1.8.3</version>
<version>1.8.4</version>
</dependency>

<dependency>
Expand Down

0 comments on commit 49a5531

Please sign in to comment.