Skip to content

Commit

Permalink
Added 'input_float' transformation option
Browse files Browse the repository at this point in the history
  • Loading branch information
vruusmann committed Apr 26, 2024
1 parent d49d862 commit d2a5d48
Show file tree
Hide file tree
Showing 3 changed files with 52 additions and 4 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -113,6 +113,13 @@ public class Main {
)
private boolean compact = true;

@Parameter (
names = {"--X-" + HasXGBoostOptions.OPTION_INPUT_FLOAT},
description = "Allow field data type updates",
arity = 1
)
private Boolean inputFloat = null;

@Parameter (
names = {"--X-" + HasXGBoostOptions.OPTION_NUMERIC},
description = "Simplify non-numeric split conditions to numeric split conditions",
Expand Down Expand Up @@ -210,6 +217,7 @@ private void run() throws Exception {
Map<String, Object> options = new LinkedHashMap<>();
options.put(HasXGBoostOptions.OPTION_MISSING, this.missing);
options.put(HasXGBoostOptions.OPTION_COMPACT, this.compact);
options.put(HasXGBoostOptions.OPTION_INPUT_FLOAT, this.inputFloat);
options.put(HasXGBoostOptions.OPTION_NUMERIC, this.numeric);
options.put(HasXGBoostOptions.OPTION_PRUNE, this.prune);
options.put(HasXGBoostOptions.OPTION_NTREE_LIMIT, this.ntreeLimit);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ public interface HasXGBoostOptions extends HasOptions, HasNativeConfiguration {

String OPTION_COMPACT = "compact";

String OPTION_INPUT_FLOAT = "input_float";

String OPTION_MISSING = "missing";

String OPTION_NTREE_LIMIT = "ntree_limit";
Expand Down
46 changes: 42 additions & 4 deletions pmml-xgboost/src/main/java/org/jpmml/xgboost/Learner.java
Original file line number Diff line number Diff line change
Expand Up @@ -48,6 +48,8 @@
import org.dmg.pmml.DerivedField;
import org.dmg.pmml.Expression;
import org.dmg.pmml.Field;
import org.dmg.pmml.HasContinuousDomain;
import org.dmg.pmml.Interval;
import org.dmg.pmml.OpType;
import org.dmg.pmml.PMML;
import org.dmg.pmml.PMMLFunctions;
Expand Down Expand Up @@ -400,7 +402,7 @@ public Label encodeLabel(String targetName, List<String> targetCategories, XGBoo
}
}

public Schema toNumericFilteredSchema(Boolean numeric, Schema schema){
public Schema toNumericFilteredSchema(Boolean numeric, Boolean inputFloat, Schema schema){
FeatureTransformer function = new FeatureTransformer(){

private List<? extends Feature> features = schema.getFeatures();
Expand Down Expand Up @@ -441,7 +443,42 @@ public Feature transformNumerical(Feature feature){
case FLOAT:
break;
case DOUBLE:
continuousFeature = continuousFeature.toContinuousFeature(DataType.FLOAT);
{
if(inputFloat != null && inputFloat){
Field<?> field = continuousFeature.getField();

field.setDataType(DataType.FLOAT);

// XXX
if(field instanceof HasContinuousDomain){
HasContinuousDomain<?> hasContinuousDomain = (HasContinuousDomain<?>)field;

if(hasContinuousDomain.hasIntervals()){
List<Interval> intervals = hasContinuousDomain.getIntervals();

for(Interval interval : intervals){
Number leftMargin = interval.getLeftMargin();
Number rightMargin = interval.getRightMargin();

if(leftMargin != null){
interval.setLeftMargin(Math.min(leftMargin.doubleValue(), leftMargin.floatValue()));
} // End if

if(rightMargin != null){
interval.setRightMargin(Math.max(rightMargin.doubleValue(), rightMargin.floatValue()));
}
}
}
}

continuousFeature = new ContinuousFeature(continuousFeature.getEncoder(), field);
} else

{
continuousFeature = continuousFeature
.toContinuousFeature(DataType.FLOAT);
}
}
break;
default:
throw new IllegalArgumentException("Expected integer, float or double data type for continuous feature " + continuousFeature.getName() + ", got " + dataType.value() + " data type");
Expand Down Expand Up @@ -608,14 +645,15 @@ public MiningModel encodeModel(Map<String, ?> options, Schema schema){
}

public Schema configureSchema(Map<String, ?> options, Schema schema){
Boolean numeric = (Boolean)options.get(HasXGBoostOptions.OPTION_NUMERIC);
Number missing = (Number)options.get(HasXGBoostOptions.OPTION_MISSING);
Boolean inputFloat = (Boolean)options.get(HasXGBoostOptions.OPTION_INPUT_FLOAT);
Boolean numeric = (Boolean)options.get(HasXGBoostOptions.OPTION_NUMERIC);

if(numeric == null){
numeric = Boolean.TRUE;
} // End if

schema = toNumericFilteredSchema(numeric, schema);
schema = toNumericFilteredSchema(numeric, inputFloat, schema);

if(missing != null){
schema = toMissingFilteredSchema(missing, schema);
Expand Down

0 comments on commit d2a5d48

Please sign in to comment.