Skip to content

Commit

Permalink
Refactored the parsing of 'pandas_categorical' attribute
Browse files Browse the repository at this point in the history
  • Loading branch information
vruusmann committed Dec 14, 2023
1 parent 1656f58 commit a2eb6ca
Show file tree
Hide file tree
Showing 6 changed files with 93 additions and 27 deletions.
1 change: 1 addition & 0 deletions NOTICE.txt
Original file line number Diff line number Diff line change
@@ -1,4 +1,5 @@
JPMML-LightGBM includes third-party dependencies that are released under the Apache License, Version 2.0:
* Gson - https://github.com/google/gson
* Guava - https://github.com/google/guava
* JCommander - http://jcommander.org

Expand Down
25 changes: 5 additions & 20 deletions pmml-lightgbm/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -39,6 +39,11 @@
<scope>provided</scope>
</dependency>

<dependency>
<groupId>com.google.code.gson</groupId>
<artifactId>gson</artifactId>
</dependency>

<dependency>
<groupId>junit</groupId>
<artifactId>junit</artifactId>
Expand All @@ -65,28 +70,8 @@
<artifactId>maven-javadoc-plugin</artifactId>
<configuration>
<javadocVersion>1.8</javadocVersion>
<sourcepath>${basedir}/src/main/java</sourcepath>
</configuration>
</plugin>
<plugin>
<groupId>org.codehaus.mojo</groupId>
<artifactId>javacc-maven-plugin</artifactId>
<version>3.0.1</version>
<executions>
<execution>
<goals>
<goal>javacc</goal>
</goals>
</execution>
</executions>
<dependencies>
<dependency>
<groupId>net.java.dev.javacc</groupId>
<artifactId>javacc</artifactId>
<version>7.0.13</version>
</dependency>
</dependencies>
</plugin>
</plugins>
</build>
</project>
4 changes: 2 additions & 2 deletions pmml-lightgbm/src/main/java/org/jpmml/lightgbm/GBDT.java
Original file line number Diff line number Diff line change
Expand Up @@ -68,7 +68,7 @@ public class GBDT {

private Map<String, String> feature_importances = Collections.emptyMap();

private List<List<Object>> pandas_categorical = Collections.emptyList();
private List<List<?>> pandas_categorical = Collections.emptyList();


public void load(List<Section> sections){
Expand Down Expand Up @@ -611,7 +611,7 @@ private Map<String, String> loadFeatureSection(Section section){
return result;
}

private List<List<Object>> loadPandasCategorical(Section section){
private List<List<?>> loadPandasCategorical(Section section){
String id = section.id();

try {
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,76 @@
/*
* Copyright (c) 2023 Villu Ruusmann
*
* This file is part of JPMML-LightGBM
*
* JPMML-LightGBM is free software: you can redistribute it and/or modify
* it under the terms of the GNU Affero General Public License as published by
* the Free Software Foundation, either version 3 of the License, or
* (at your option) any later version.
*
* JPMML-LightGBM is distributed in the hope that it will be useful,
* but WITHOUT ANY WARRANTY; without even the implied warranty of
* MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
* GNU Affero General Public License for more details.
*
* You should have received a copy of the GNU Affero General Public License
* along with JPMML-LightGBM. If not, see <http://www.gnu.org/licenses/>.
*/
package org.jpmml.lightgbm;

import java.util.ArrayList;
import java.util.Collections;
import java.util.List;

import com.google.gson.Gson;
import com.google.gson.GsonBuilder;
import com.google.gson.JsonElement;
import com.google.gson.JsonParser;
import com.google.gson.ToNumberPolicy;

/**
 * Parses the textual value of the <code>pandas_categorical</code> attribute into a list of category lists.
 *
 * <p>
 * The input is expected to be of the form <code>pandas_categorical:&lt;JSON&gt;</code>,
 * where the JSON part is either <code>null</code> or an array of arrays.
 * Category values are mapped as follows: JSON strings to {@link String},
 * JSON booleans to {@link Boolean}, and JSON numbers to {@link Long} when they can be parsed as such, or {@link Double} otherwise.
 * </p>
 */
public class PandasCategoricalParser {

	private String string = null;


	/**
	 * @param string The complete attribute line, including the <code>pandas_categorical:</code> prefix.
	 */
	public PandasCategoricalParser(String string){
		setString(string);
	}

	/**
	 * Parses the string that was supplied to the constructor.
	 *
	 * @return The parsed category lists, one per categorical feature.
	 * An empty list if the JSON part is the <code>null</code> literal.
	 *
	 * @throws IllegalArgumentException If the string is <code>null</code>, or does not start with the <code>pandas_categorical:</code> prefix.
	 * @throws com.google.gson.JsonParseException If the part following the prefix is not valid JSON, or is not an array of arrays.
	 */
	public List<List<?>> parsePandasCategorical(){
		String string = getString();

		// Fail with a descriptive message instead of a bare NullPointerException
		if(string == null){
			throw new IllegalArgumentException("Expected a string starting with \"" + PandasCategoricalParser.PREFIX + "\", got null");
		} // End if

		if(!string.startsWith(PandasCategoricalParser.PREFIX)){
			throw new IllegalArgumentException("Expected a string starting with \"" + PandasCategoricalParser.PREFIX + "\", got \"" + string + "\"");
		}

		String json = string.substring(PandasCategoricalParser.PREFIX.length());

		JsonElement element = JsonParser.parseString(json);

		// Gson returns null for the JSON null literal
		List<List<?>> result = PandasCategoricalParser.GSON.fromJson(element, ListOfLists.class);
		if(result == null){
			result = Collections.emptyList();
		}

		return result;
	}

	public String getString(){
		return this.string;
	}

	private void setString(String string){
		this.string = string;
	}

	/**
	 * A concrete type token that tells Gson to bind the top-level JSON array to a list of lists.
	 */
	static
	private class ListOfLists extends ArrayList<List<?>> {
	}

	private static final String PREFIX = "pandas_categorical:";

	// Gson instances are immutable and thread-safe, so a single shared instance suffices.
	// The LONG_OR_DOUBLE policy keeps integer-valued categories as Long instead of widening everything to Double.
	private static final Gson GSON = new GsonBuilder()
		.setObjectToNumberStrategy(ToNumberPolicy.LONG_OR_DOUBLE)
		.create();
}
Original file line number Diff line number Diff line change
Expand Up @@ -22,8 +22,6 @@
import java.util.Collections;
import java.util.List;

import org.jpmml.lightgbm.PandasCategoricalParser;
import org.jpmml.lightgbm.ParseException;
import org.junit.Test;

import static org.junit.Assert.assertEquals;
Expand All @@ -32,17 +30,17 @@ public class PandasCategoricalParserTest {

@Test
public void parse() throws Exception {
List<List<Object>> pandasCategories = parsePandasCategorical("null");
List<List<?>> pandasCategories = parsePandasCategorical("null");

assertEquals(Collections.emptyList(), pandasCategories);

pandasCategories = parsePandasCategorical("[[\"null\", \"A\", \"B, B\", \"C, [C], C\"], [-2, -1, 0, 1, 2], [-2.0, -1.0, 0.0, 1.0, 2.0], [false, true]]");

assertEquals(Arrays.asList(Arrays.asList("null", "A", "B, B", "C, [C], C"), Arrays.asList(-2, -1, 0, 1, 2), Arrays.asList(-2d, -1d, 0d, 1d, 2d), Arrays.asList(Boolean.FALSE, Boolean.TRUE)), pandasCategories);
assertEquals(Arrays.asList(Arrays.asList("null", "A", "B, B", "C, [C], C"), Arrays.asList(-2L, -1L, 0L, 1L, 2L), Arrays.asList(-2d, -1d, 0d, 1d, 2d), Arrays.asList(Boolean.FALSE, Boolean.TRUE)), pandasCategories);
}

static
private List<List<Object>> parsePandasCategorical(String value) throws ParseException {
private List<List<?>> parsePandasCategorical(String value){
PandasCategoricalParser parser = new PandasCategoricalParser("pandas_categorical:" + value);

return parser.parsePandasCategorical();
Expand Down
6 changes: 6 additions & 0 deletions pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,12 @@
<version>1.72</version>
</dependency>

<dependency>
<groupId>com.google.code.gson</groupId>
<artifactId>gson</artifactId>
<!-- Lower bound is 2.8.9, because PandasCategoricalParser uses ToNumberPolicy and GsonBuilder#setObjectToNumberStrategy (since 2.8.9), and JsonParser#parseString (since 2.8.6) -->
<version>[2.8.9, 2.10.1]</version>
</dependency>

<dependency>
<groupId>junit</groupId>
<artifactId>junit</artifactId>
Expand Down

0 comments on commit a2eb6ca

Please sign in to comment.