From 7893b0c8447ffea1e7eaa494fe8b5553c52d86c6 Mon Sep 17 00:00:00 2001
From: Chengcheng Jin
Date: Tue, 30 Jul 2024 13:59:52 +0000
Subject: [PATCH] GH-28866: [Java] Java Dataset API ScanOptions expansion

---
 java/dataset/src/main/cpp/jni_wrapper.cc      | 135 ++++++++++++++++--
 .../dataset/TestFragmentScanOptions.java      |  41 ++++++
 2 files changed, 161 insertions(+), 15 deletions(-)

diff --git a/java/dataset/src/main/cpp/jni_wrapper.cc b/java/dataset/src/main/cpp/jni_wrapper.cc
index f324f87d6c301..255950c83be50 100644
--- a/java/dataset/src/main/cpp/jni_wrapper.cc
+++ b/java/dataset/src/main/cpp/jni_wrapper.cc
@@ -368,28 +368,133 @@ std::shared_ptr<arrow::Buffer> LoadArrowBufferFromByteBuffer(JNIEnv* env, jobjec
 
 inline bool ParseBool(const std::string& value) { return value == "true" ? true : false; }
 
+/// \brief Parse a config value that must be exactly one character long.
+///
+/// A value of any other length is reported through JniThrow.
+inline char ParseChar(const std::string& value) {
+  if (value.size() != 1) {
+    JniThrow("Csv convert option " + value + " should be a char");
+  }
+  return value.at(0);
+}
+
 /// \brief Construct FragmentScanOptions from config map
 #ifdef ARROW_CSV
+
+/// \brief Apply one config entry to the CSV convert options.
+///
+/// \param[in,out] options the convert options to update
+/// \param[in] key the config name
+/// \param[in] value the config value, parsed according to the key
+/// \return true if the key names a convert option handled here, false otherwise
+bool setCsvConvertOptions(arrow::csv::ConvertOptions& options, const std::string& key,
+                          const std::string& value) {
+  if (key == "column_types") {
+    // The value is the address of an ArrowSchema, encoded as a decimal string.
+    int64_t schema_address = std::stoll(value);
+    ArrowSchema* c_schema = reinterpret_cast<ArrowSchema*>(schema_address);
+    // ARROW_ASSIGN_OR_RAISE returns a Status on failure, which does not fit this
+    // function's bool return type, so an import failure is reported via JniThrow.
+    auto maybe_schema = arrow::ImportSchema(c_schema);
+    if (!maybe_schema.ok()) {
+      JniThrow(maybe_schema.status().ToString());
+    }
+    auto schema = maybe_schema.ValueOrDie();
+    auto& column_types = options.column_types;
+    for (const auto& field : schema->fields()) {
+      column_types[field->name()] = field->type();
+    }
+  } else if (key == "strings_can_be_null") {
+    options.strings_can_be_null = ParseBool(value);
+  } else if (key == "check_utf8") {
+    options.check_utf8 = ParseBool(value);
+  } else if (key == "null_values") {
+    options.null_values = {value};
+  } else if (key == "true_values") {
+    options.true_values = {value};
+  } else if (key == "false_values") {
+    options.false_values = {value};
+  } else if (key == "quoted_strings_can_be_null") {
+    options.quoted_strings_can_be_null = ParseBool(value);
+  } else if (key == "auto_dict_encode") {
+    options.auto_dict_encode = ParseBool(value);
+  } else if (key == "auto_dict_max_cardinality") {
+    options.auto_dict_max_cardinality = std::stoi(value);
+  } else if (key == "decimal_point") {
+    options.decimal_point = ParseChar(value);
+  } else if (key == "include_missing_columns") {
+    options.include_missing_columns = ParseBool(value);
+  } else {
+    return false;
+  }
+  return true;
+}
+
+/// \brief Apply one config entry to the CSV parse options.
+///
+/// \param[in,out] options the parse options to update
+/// \param[in] key the config name
+/// \param[in] value the config value, parsed according to the key
+/// \return true if the key names a parse option handled here, false otherwise
+bool setCsvParseOptions(arrow::csv::ParseOptions& options, const std::string& key,
+                        const std::string& value) {
+  if (key == "delimiter") {
+    options.delimiter = ParseChar(value);
+  } else if (key == "quoting") {
+    options.quoting = ParseBool(value);
+  } else if (key == "quote_char") {
+    options.quote_char = ParseChar(value);
+  } else if (key == "double_quote") {
+    options.double_quote = ParseBool(value);
+  } else if (key == "escaping") {
+    options.escaping = ParseBool(value);
+  } else if (key == "escape_char") {
+    options.escape_char = ParseChar(value);
+  } else if (key == "newlines_in_values") {
+    options.newlines_in_values = ParseBool(value);
+  } else if (key == "ignore_empty_lines") {
+    options.ignore_empty_lines = ParseBool(value);
+  } else {
+    return false;
+  }
+  return true;
+}
+
+/// \brief Apply one config entry to the CSV read options.
+///
+/// \param[in,out] options the read options to update
+/// \param[in] key the config name
+/// \param[in] value the config value, parsed according to the key
+/// \return true if the key names a read option handled here, false otherwise
+bool setCsvReadOptions(arrow::csv::ReadOptions& options, const std::string& key,
+                       const std::string& value) {
+  if (key == "use_threads") {
+    options.use_threads = ParseBool(value);
+  } else if (key == "block_size") {
+    options.block_size = std::stoi(value);
+  } else if (key == "skip_rows") {
+    options.skip_rows = std::stoi(value);
+  } else if (key == "skip_rows_after_names") {
+    options.skip_rows_after_names = std::stoi(value);
+  } else if (key == "autogenerate_column_names") {
+    options.autogenerate_column_names = ParseBool(value);
+  } else {
+    return false;
+  }
+  return true;
+}
+
 arrow::Result<std::shared_ptr<arrow::dataset::FragmentScanOptions>>
 ToCsvFragmentScanOptions(const std::unordered_map<std::string, std::string>& configs) {
   std::shared_ptr<arrow::dataset::CsvFragmentScanOptions> options =
       std::make_shared<arrow::dataset::CsvFragmentScanOptions>();
   for (auto const& [key, value] : configs) {
-    if (key == "delimiter") {
-      options->parse_options.delimiter = value.data()[0];
-    } else if (key == "quoting") {
-      options->parse_options.quoting = ParseBool(value);
-    } else if (key == "column_types") {
-      int64_t schema_address = std::stol(value);
-      ArrowSchema* c_schema = reinterpret_cast<ArrowSchema*>(schema_address);
-      ARROW_ASSIGN_OR_RAISE(auto schema, arrow::ImportSchema(c_schema));
-      auto& column_types = options->convert_options.column_types;
-      for (auto field : schema->fields()) {
-        column_types[field->name()] = field->type();
-      }
-    } else if (key == "strings_can_be_null") {
-      options->convert_options.strings_can_be_null = ParseBool(value);
-    } else {
+    // Each key belongs to at most one option group: try the groups in turn and stop
+    // at the first one that recognizes the key.
+    bool setValid = setCsvParseOptions(options->parse_options, key, value) ||
+                    setCsvConvertOptions(options->convert_options, key, value) ||
+                    setCsvReadOptions(options->read_options, key, value);
+    if (!setValid) {
       return arrow::Status::Invalid("Config " + key + " is not supported.");
     }
   }
diff --git a/java/dataset/src/test/java/org/apache/arrow/dataset/TestFragmentScanOptions.java b/java/dataset/src/test/java/org/apache/arrow/dataset/TestFragmentScanOptions.java
index 9787e8308e73e..1bf0ea4b38a3c 100644
--- a/java/dataset/src/test/java/org/apache/arrow/dataset/TestFragmentScanOptions.java
+++ b/java/dataset/src/test/java/org/apache/arrow/dataset/TestFragmentScanOptions.java
@@ -165,4 +165,45 @@ public void testCsvConvertOptionsNoOption() throws Exception {
       assertEquals(3, rowCount);
     }
   }
+
+  // Passes "skip_rows" (a read option) and "delimiter" (a parse option) to the CSV scan.
+  @Test
+  public void testCsvReadParseAndReadOptions() throws Exception {
+    final Schema schema =
+        new Schema(
+            Collections.singletonList(Field.nullable("Id;Name;Language", new ArrowType.Utf8())),
+            null);
+    String path = "file://" + getClass().getResource("/").getPath() + "/data/student.csv";
+    BufferAllocator allocator = new RootAllocator(Long.MAX_VALUE);
+    CsvFragmentScanOptions fragmentScanOptions =
+        new CsvFragmentScanOptions(
+            new CsvConvertOptions(ImmutableMap.of()),
+            ImmutableMap.of("skip_rows", "1"),
+            ImmutableMap.of("delimiter", ";"));
+    ScanOptions options =
+        new ScanOptions.Builder(/*batchSize*/ 32768)
+            .columns(Optional.empty())
+            .fragmentScanOptions(fragmentScanOptions)
+            .build();
+    try (DatasetFactory datasetFactory =
+            new FileSystemDatasetFactory(
+                allocator, NativeMemoryPool.getDefault(), FileFormat.CSV, path);
+        Dataset dataset = datasetFactory.finish();
+        Scanner scanner = dataset.newScan(options);
+        ArrowReader reader = scanner.scanBatches()) {
+
+      assertEquals(schema.getFields(), reader.getVectorSchemaRoot().getSchema().getFields());
+      int rowCount = 0;
+      while (reader.loadNextBatch()) {
+        final ValueIterableVector<String> idVector =
+            (ValueIterableVector<String>)
+                reader.getVectorSchemaRoot().getVector("Id;Name;Language");
+        assertThat(
+            idVector.getValueIterable(),
+            IsIterableContainingInOrder.contains("2;Peter;Python\n" + "3;Celin;C++"));
+        rowCount += reader.getVectorSchemaRoot().getRowCount();
+      }
+      assertEquals(2, rowCount);
+    }
+  }
 }