Skip to content

Commit

Permalink
apacheGH-28866: [Java] Java Dataset API ScanOptions expansion
Browse files Browse the repository at this point in the history
  • Loading branch information
jinchengchenghh committed Jul 30, 2024
1 parent fd69e5e commit 7893b0c
Show file tree
Hide file tree
Showing 2 changed files with 130 additions and 15 deletions.
105 changes: 90 additions & 15 deletions java/dataset/src/main/cpp/jni_wrapper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -368,28 +368,103 @@ std::shared_ptr<arrow::Buffer> LoadArrowBufferFromByteBuffer(JNIEnv* env, jobjec

/// \brief Interpret a config string as a boolean flag.
///
/// Only the exact, case-sensitive text "true" yields true; any other value
/// (for example "True", "1" or the empty string) yields false.
inline bool ParseBool(const std::string& value) {
  return value == "true";
}

/// \brief Extract the single character carried by a CSV option value.
///
/// \param value option text; expected to be exactly one character long
/// \return the character held in `value`
///
/// A value whose length is not 1 is reported through JniThrow.
inline char ParseChar(const std::string& value) {
  if (value.size() != 1) {
    JniThrow("Csv convert option " + value + " should be a char");
  }
  // The return type must be char: returning bool would collapse every
  // non-NUL character to 1 before it reaches the char-typed option field.
  return value.at(0);
}

/// \brief Construct FragmentScanOptions from config map
#ifdef ARROW_CSV

/// \brief Apply a single key/value config entry to CSV convert options.
///
/// \param options convert options that are updated in place
/// \param key config name, e.g. "column_types" or "check_utf8"
/// \param value config value in its string form
/// \return true if `key` names a convert option handled here, false otherwise
bool setCsvConvertOptions(arrow::csv::ConvertOptions& options, const std::string& key,
                          const std::string& value) {
  if (key == "column_types") {
    // The value is the decimal text of an ArrowSchema address. std::stoll is
    // used so the full 64 bits survive on platforms where `long` is 32 bits.
    const int64_t schema_address = std::stoll(value);
    ArrowSchema* c_schema = reinterpret_cast<ArrowSchema*>(schema_address);
    // This function returns bool, so the Status-returning
    // ARROW_ASSIGN_OR_RAISE macro cannot be used; report a failed import
    // through JniThrow instead. Like ParseChar, this relies on JniThrow not
    // returning normally.
    auto maybe_schema = arrow::ImportSchema(c_schema);
    if (!maybe_schema.ok()) {
      JniThrow(maybe_schema.status().ToString());
    }
    const auto& schema = *maybe_schema;
    auto& column_types = options.column_types;
    // Bind by reference to avoid copying a shared_ptr per field.
    for (const auto& field : schema->fields()) {
      column_types[field->name()] = field->type();
    }
  } else if (key == "strings_can_be_null") {
    options.strings_can_be_null = ParseBool(value);
  } else if (key == "check_utf8") {
    options.check_utf8 = ParseBool(value);
  } else if (key == "null_values") {
    // Replaces the existing list with the single supplied spelling; the same
    // applies to true_values and false_values below.
    options.null_values = {value};
  } else if (key == "true_values") {
    options.true_values = {value};
  } else if (key == "false_values") {
    options.false_values = {value};
  } else if (key == "quoted_strings_can_be_null") {
    options.quoted_strings_can_be_null = ParseBool(value);
  } else if (key == "auto_dict_encode") {
    options.auto_dict_encode = ParseBool(value);
  } else if (key == "auto_dict_max_cardinality") {
    options.auto_dict_max_cardinality = std::stoi(value);
  } else if (key == "decimal_point") {
    options.decimal_point = ParseChar(value);
  } else if (key == "include_missing_columns") {
    options.include_missing_columns = ParseBool(value);
  } else {
    // Not a convert option; the caller may try another option group.
    return false;
  }
  return true;
}

/// \brief Apply a single key/value config entry to CSV parse options.
///
/// \param options parse options that are updated in place
/// \param key config name, e.g. "delimiter" or "quote_char"
/// \param value config value in its string form
/// \return true if `key` names a parse option handled here, false otherwise
bool setCsvParseOptions(arrow::csv::ParseOptions& options, const std::string& key,
                        const std::string& value) {
  if (key == "delimiter") {
    // The helper is spelled ParseChar; `parseChar` does not exist.
    options.delimiter = ParseChar(value);
  } else if (key == "quoting") {
    options.quoting = ParseBool(value);
  } else if (key == "quote_char") {
    options.quote_char = ParseChar(value);
  } else if (key == "double_quote") {
    options.double_quote = ParseBool(value);
  } else if (key == "escaping") {
    options.escaping = ParseBool(value);
  } else if (key == "escape_char") {
    options.escape_char = ParseChar(value);
  } else if (key == "newlines_in_values") {
    options.newlines_in_values = ParseBool(value);
  } else if (key == "ignore_empty_lines") {
    options.ignore_empty_lines = ParseBool(value);
  } else {
    // Not a parse option; the caller may try another option group.
    return false;
  }
  return true;
}

/// \brief Apply a single key/value config entry to CSV read options.
///
/// \param options read options that are updated in place
/// \param key config name, e.g. "use_threads" or "skip_rows"
/// \param value config value in its string form; taken by const reference so
///        entries of a const config map can be passed directly
/// \return true if `key` names a read option handled here, false otherwise
bool setCsvReadOptions(arrow::csv::ReadOptions& options, const std::string& key,
                       const std::string& value) {
  if (key == "use_threads") {
    options.use_threads = ParseBool(value);
  } else if (key == "block_size") {
    options.block_size = std::stoi(value);
  } else if (key == "skip_rows") {
    options.skip_rows = std::stoi(value);
  } else if (key == "skip_rows_after_names") {
    options.skip_rows_after_names = std::stoi(value);
  } else if (key == "autogenerate_column_names") {
    options.autogenerate_column_names = ParseBool(value);
  } else {
    // Not a read option; the caller may try another option group.
    return false;
  }
  return true;
}

arrow::Result<std::shared_ptr<arrow::dataset::FragmentScanOptions>>
ToCsvFragmentScanOptions(const std::unordered_map<std::string, std::string>& configs) {
std::shared_ptr<arrow::dataset::CsvFragmentScanOptions> options =
std::make_shared<arrow::dataset::CsvFragmentScanOptions>();
for (auto const& [key, value] : configs) {
if (key == "delimiter") {
options->parse_options.delimiter = value.data()[0];
} else if (key == "quoting") {
options->parse_options.quoting = ParseBool(value);
} else if (key == "column_types") {
int64_t schema_address = std::stol(value);
ArrowSchema* c_schema = reinterpret_cast<ArrowSchema*>(schema_address);
ARROW_ASSIGN_OR_RAISE(auto schema, arrow::ImportSchema(c_schema));
auto& column_types = options->convert_options.column_types;
for (auto field : schema->fields()) {
column_types[field->name()] = field->type();
}
} else if (key == "strings_can_be_null") {
options->convert_options.strings_can_be_null = ParseBool(value);
} else {
bool setValid = setCsvParseOptions(options->parse_options, key, value) &&
setCsvConvertOptions(options->convert_options, key, value) &&
setCsvReadOptions(options->read_options, key, value);
if (!setValid) {
return arrow::Status::Invalid("Config " + key + " is not supported.");
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -165,4 +165,44 @@ public void testCsvConvertOptionsNoOption() throws Exception {
assertEquals(3, rowCount);
}
}

@Test
public void testCsvReadParseAndReadOptions() throws Exception {
  // Scans data/student.csv with the "skip_rows" = "1" and "delimiter" = ";"
  // option maps and expects one Utf8 column named "Id;Name;Language".
  final String columnName = "Id;Name;Language";
  final Schema expectedSchema =
      new Schema(
          Collections.singletonList(Field.nullable(columnName, new ArrowType.Utf8())), null);
  final String csvUri =
      "file://" + getClass().getResource("/").getPath() + "/data/student.csv";
  final BufferAllocator allocator = new RootAllocator(Long.MAX_VALUE);

  final CsvConvertOptions convertOptions = new CsvConvertOptions(ImmutableMap.of());
  final CsvFragmentScanOptions fragmentScanOptions =
      new CsvFragmentScanOptions(
          convertOptions, ImmutableMap.of("skip_rows", "1"), ImmutableMap.of("delimiter", ";"));
  final ScanOptions options =
      new ScanOptions.Builder(/*batchSize*/ 32768)
          .columns(Optional.empty())
          .fragmentScanOptions(fragmentScanOptions)
          .build();

  try (DatasetFactory datasetFactory =
          new FileSystemDatasetFactory(
              allocator, NativeMemoryPool.getDefault(), FileFormat.CSV, csvUri);
      Dataset dataset = datasetFactory.finish();
      Scanner scanner = dataset.newScan(options);
      ArrowReader reader = scanner.scanBatches()) {

    // The schema reported by the reader must match the expected single field.
    assertEquals(
        expectedSchema.getFields(), reader.getVectorSchemaRoot().getSchema().getFields());

    int totalRows = 0;
    while (reader.loadNextBatch()) {
      @SuppressWarnings("unchecked")
      final ValueIterableVector<String> column =
          (ValueIterableVector<String>) reader.getVectorSchemaRoot().getVector(columnName);
      assertThat(
          column.getValueIterable(),
          IsIterableContainingInOrder.contains("2;Peter;Python\n" + "3;Celin;C++"));
      totalRows += reader.getVectorSchemaRoot().getRowCount();
    }
    assertEquals(2, totalRows);
  }
}
}

0 comments on commit 7893b0c

Please sign in to comment.