From 1223f364e0ac3d17eb2d88e38e1dde4348f2f48c Mon Sep 17 00:00:00 2001 From: Chengcheng Jin Date: Fri, 17 May 2024 15:41:54 +0000 Subject: [PATCH] support dataset option --- cpp/src/arrow/dataset/file_csv.cc | 30 ++++++ cpp/src/arrow/dataset/file_csv.h | 3 + cpp/src/arrow/engine/CMakeLists.txt | 1 + .../engine/substrait/extension_internal.cc | 51 ++++++++++ .../engine/substrait/extension_internal.h | 44 +++++++++ cpp/src/arrow/engine/substrait/serde.cc | 8 ++ cpp/src/arrow/engine/substrait/serde.h | 4 + cpp/thirdparty/versions.txt | 2 +- java/dataset/pom.xml | 14 +++ java/dataset/src/main/cpp/jni_wrapper.cc | 61 ++++++++++-- .../file/FileSystemDatasetFactory.java | 30 ++++-- .../apache/arrow/dataset/file/JniWrapper.java | 8 +- .../apache/arrow/dataset/jni/JniWrapper.java | 3 +- .../arrow/dataset/jni/NativeDataset.java | 14 ++- .../dataset/scanner/FragmentScanOptions.java | 55 +++++++++++ .../arrow/dataset/scanner/ScanOptions.java | 21 +++++ .../scanner/csv/CsvConvertOptions.java | 55 +++++++++++ .../scanner/csv/CsvFragmentScanOptions.java | 94 +++++++++++++++++++ .../dataset/substrait/StringMapNode.java | 53 +++++++++++ .../substrait/TestAceroSubstraitConsumer.java | 45 +++++++++ .../src/test/resources/data/student.csv | 4 + 21 files changed, 577 insertions(+), 23 deletions(-) create mode 100644 cpp/src/arrow/engine/substrait/extension_internal.cc create mode 100644 cpp/src/arrow/engine/substrait/extension_internal.h create mode 100644 java/dataset/src/main/java/org/apache/arrow/dataset/scanner/FragmentScanOptions.java create mode 100644 java/dataset/src/main/java/org/apache/arrow/dataset/scanner/csv/CsvConvertOptions.java create mode 100644 java/dataset/src/main/java/org/apache/arrow/dataset/scanner/csv/CsvFragmentScanOptions.java create mode 100644 java/dataset/src/main/java/org/apache/arrow/dataset/substrait/StringMapNode.java create mode 100644 java/dataset/src/test/resources/data/student.csv diff --git a/cpp/src/arrow/dataset/file_csv.cc 
b/cpp/src/arrow/dataset/file_csv.cc index 09ab775727c98..3f43c65c205c4 100644 --- a/cpp/src/arrow/dataset/file_csv.cc +++ b/cpp/src/arrow/dataset/file_csv.cc @@ -24,6 +24,7 @@ #include #include +#include "arrow/c/bridge.h" #include "arrow/csv/options.h" #include "arrow/csv/parser.h" #include "arrow/csv/reader.h" @@ -52,6 +53,9 @@ using internal::Executor; using internal::SerialExecutor; namespace dataset { +namespace { +inline bool parseBool(const std::string& value) { return value == "true" ? true : false; } +} // namespace struct CsvInspectedFragment : public InspectedFragment { CsvInspectedFragment(std::vector column_names, @@ -503,5 +507,31 @@ Future<> CsvFileWriter::FinishInternal() { return Status::OK(); } +Result> CsvFragmentScanOptions::from( + const std::unordered_map& configs) { + std::shared_ptr options = + std::make_shared(); + for (auto const& it : configs) { + auto& key = it.first; + auto& value = it.second; + if (key == "delimiter") { + options->parse_options.delimiter = value.data()[0]; + } else if (key == "quoting") { + options->parse_options.quoting = parseBool(value); + } else if (key == "column_type") { + int64_t schema_address = std::stol(value); + ArrowSchema* cSchema = reinterpret_cast(schema_address); + ARROW_ASSIGN_OR_RAISE(auto schema, arrow::ImportSchema(cSchema)); + auto column_types = options->convert_options.column_types; + for (auto field : schema->fields()) { + column_types[field->name()] = field->type(); + } + } else { + return Status::Invalid("Not support this config " + it.first); + } + } + return options; +} + } // namespace dataset } // namespace arrow diff --git a/cpp/src/arrow/dataset/file_csv.h b/cpp/src/arrow/dataset/file_csv.h index 42e3fd7246988..4d2825183fb71 100644 --- a/cpp/src/arrow/dataset/file_csv.h +++ b/cpp/src/arrow/dataset/file_csv.h @@ -85,6 +85,9 @@ class ARROW_DS_EXPORT CsvFileFormat : public FileFormat { struct ARROW_DS_EXPORT CsvFragmentScanOptions : public FragmentScanOptions { std::string type_name() const 
override { return kCsvTypeName; } + static Result> from( + const std::unordered_map& configs); + using StreamWrapFunc = std::function>( std::shared_ptr)>; diff --git a/cpp/src/arrow/engine/CMakeLists.txt b/cpp/src/arrow/engine/CMakeLists.txt index fcaa242b11487..946425edb8cd5 100644 --- a/cpp/src/arrow/engine/CMakeLists.txt +++ b/cpp/src/arrow/engine/CMakeLists.txt @@ -22,6 +22,7 @@ arrow_install_all_headers("arrow/engine") set(ARROW_SUBSTRAIT_SRCS substrait/expression_internal.cc substrait/extended_expression_internal.cc + substrait/extension_internal.cc substrait/extension_set.cc substrait/extension_types.cc substrait/options.cc diff --git a/cpp/src/arrow/engine/substrait/extension_internal.cc b/cpp/src/arrow/engine/substrait/extension_internal.cc new file mode 100644 index 0000000000000..857b388e4211e --- /dev/null +++ b/cpp/src/arrow/engine/substrait/extension_internal.cc @@ -0,0 +1,51 @@ +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. See the License for the +// specific language governing permissions and limitations +// under the License. + +// This API is EXPERIMENTAL. 
+ +#include "arrow/engine/substrait/extension_internal.h" + +#include "substrait/algebra.pb.h" + +namespace arrow { +namespace engine { + +Status FromProto(const substrait::extensions::AdvancedExtension& extension, + std::unordered_map& out) { + if (!extension.has_enhancement()) { + return Status::Invalid("AdvancedExtension does not have enhancement"); + } + const auto& enhancement = extension.enhancement(); + substrait::Expression_Literal literal; + + if (!enhancement.UnpackTo(&literal)) { + return Status::Invalid("Unpack the literal failed"); + } + + if (!literal.has_map()) { + return Status::Invalid("Literal does not have map"); + } + auto literalMap = literal.map(); + auto size = literalMap.key_values_size(); + for (auto i = 0; i < size; i++) { + substrait::Expression_Literal_Map_KeyValue keyValue = literalMap.key_values(i); + out.emplace(keyValue.key().string(), keyValue.value().string()); + } + return Status::OK(); +} +} // namespace engine +} // namespace arrow diff --git a/cpp/src/arrow/engine/substrait/extension_internal.h b/cpp/src/arrow/engine/substrait/extension_internal.h new file mode 100644 index 0000000000000..84ae57c3de03b --- /dev/null +++ b/cpp/src/arrow/engine/substrait/extension_internal.h @@ -0,0 +1,44 @@ + +// Licensed to the Apache Software Foundation (ASF) under one +// or more contributor license agreements. See the NOTICE file +// distributed with this work for additional information +// regarding copyright ownership. The ASF licenses this file +// to you under the Apache License, Version 2.0 (the +// "License"); you may not use this file except in compliance +// with the License. You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, +// software distributed under the License is distributed on an +// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY +// KIND, either express or implied. 
See the License for the +// specific language governing permissions and limitations +// under the License. + +// This API is EXPERIMENTAL. + +#pragma once + +#include + +#include "arrow/compute/type_fwd.h" +#include "arrow/engine/substrait/extension_set.h" +#include "arrow/engine/substrait/options.h" +#include "arrow/engine/substrait/relation.h" +#include "arrow/engine/substrait/visibility.h" +#include "arrow/result.h" +#include "arrow/status.h" + +#include "substrait/extensions/extensions.pb.h" // IWYU pragma: export + +namespace arrow { +namespace engine { + +/// Convert a Substrait ExtendedExpression to a vector of expressions and output names +ARROW_ENGINE_EXPORT +Status FromProto(const substrait::extensions::AdvancedExtension& extension, + std::unordered_map& out); + +} // namespace engine +} // namespace arrow diff --git a/cpp/src/arrow/engine/substrait/serde.cc b/cpp/src/arrow/engine/substrait/serde.cc index 9e670f121778e..8d97410772eb8 100644 --- a/cpp/src/arrow/engine/substrait/serde.cc +++ b/cpp/src/arrow/engine/substrait/serde.cc @@ -36,6 +36,7 @@ #include "arrow/dataset/file_base.h" #include "arrow/engine/substrait/expression_internal.h" #include "arrow/engine/substrait/extended_expression_internal.h" +#include "arrow/engine/substrait/extension_internal.h" #include "arrow/engine/substrait/extension_set.h" #include "arrow/engine/substrait/plan_internal.h" #include "arrow/engine/substrait/relation.h" @@ -247,6 +248,13 @@ Result DeserializeExpressions( return FromProto(extended_expression, ext_set_out, conversion_options, registry); } +Status DeserializeMap(const Buffer& buf, + std::unordered_map out) { + ARROW_ASSIGN_OR_RAISE(auto advanced_extension, + ParseFromBuffer(buf)); + return FromProto(advanced_extension, out); +} + namespace { Result> MakeSingleDeclarationPlan( diff --git a/cpp/src/arrow/engine/substrait/serde.h b/cpp/src/arrow/engine/substrait/serde.h index ab749f4a64b05..1f6d261c2671b 100644 --- a/cpp/src/arrow/engine/substrait/serde.h +++ 
b/cpp/src/arrow/engine/substrait/serde.h @@ -23,6 +23,7 @@ #include #include #include +#include #include #include "arrow/compute/type_fwd.h" @@ -183,6 +184,9 @@ ARROW_ENGINE_EXPORT Result DeserializeExpressions( const ConversionOptions& conversion_options = {}, ExtensionSet* ext_set_out = NULLPTR); +ARROW_ENGINE_EXPORT Status +DeserializeMap(const Buffer& buf, std::unordered_map out); + /// \brief Deserializes a Substrait Type message to the corresponding Arrow type /// /// \param[in] buf a buffer containing the protobuf serialization of a Substrait Type diff --git a/cpp/thirdparty/versions.txt b/cpp/thirdparty/versions.txt index 4983f3cee2c2d..893ba83d85d6a 100644 --- a/cpp/thirdparty/versions.txt +++ b/cpp/thirdparty/versions.txt @@ -108,7 +108,7 @@ ARROW_SUBSTRAIT_BUILD_SHA256_CHECKSUM=f989a862f694e7dbb695925ddb7c4ce06aa6c51aca ARROW_S2N_TLS_BUILD_VERSION=v1.3.35 ARROW_S2N_TLS_BUILD_SHA256_CHECKSUM=9d32b26e6bfcc058d98248bf8fc231537e347395dd89cf62bb432b55c5da990d ARROW_THRIFT_BUILD_VERSION=0.16.0 -ARROW_THRIFT_BUILD_SHA256_CHECKSUM=f460b5c1ca30d8918ff95ea3eb6291b3951cf518553566088f3f2be8981f6209 +ARROW_THRIFT_BUILD_SHA256_CHECKSUM=df2931de646a366c2e5962af679018bca2395d586e00ba82d09c0379f14f8e7b ARROW_UCX_BUILD_VERSION=1.12.1 ARROW_UCX_BUILD_SHA256_CHECKSUM=9bef31aed0e28bf1973d28d74d9ac4f8926c43ca3b7010bd22a084e164e31b71 ARROW_UTF8PROC_BUILD_VERSION=v2.7.0 diff --git a/java/dataset/pom.xml b/java/dataset/pom.xml index 2121119af398e..23ee45f2fd5a5 100644 --- a/java/dataset/pom.xml +++ b/java/dataset/pom.xml @@ -25,6 +25,8 @@ ../../../cpp/release-build/ 1.13.1 1.11.3 + 0.31.0 + 3.25.3 @@ -48,6 +50,18 @@ org.immutables value + + io.substrait + core + ${substrait.version} + compile + + + com.google.protobuf + protobuf-java + ${protobuf.version} + compile + org.apache.arrow arrow-memory-netty diff --git a/java/dataset/src/main/cpp/jni_wrapper.cc b/java/dataset/src/main/cpp/jni_wrapper.cc index 19a43c8d2fa41..011fe4bac3530 100644 --- 
a/java/dataset/src/main/cpp/jni_wrapper.cc +++ b/java/dataset/src/main/cpp/jni_wrapper.cc @@ -25,6 +25,7 @@ #include "arrow/c/helpers.h" #include "arrow/dataset/api.h" #include "arrow/dataset/file_base.h" +#include "arrow/dataset/file_csv.h" #include "arrow/filesystem/localfs.h" #include "arrow/filesystem/path_util.h" #include "arrow/filesystem/s3fs.h" @@ -154,6 +155,21 @@ arrow::Result> GetFileFormat( } } +arrow::Result> +GetFragmentScanOptions(jint file_format_id, + const std::unordered_map& configs) { + switch (file_format_id) { +#ifdef ARROW_CSV + case 3: + return arrow::dataset::CsvFragmentScanOptions::from(configs); +#endif + default: + std::string error_message = + "illegal file format id: " + std::to_string(file_format_id); + return arrow::Status::Invalid(error_message); + } +} + class ReserveFromJava : public arrow::dataset::jni::ReservationListener { public: ReserveFromJava(JavaVM* vm, jobject java_reservation_listener) @@ -502,12 +518,13 @@ JNIEXPORT void JNICALL Java_org_apache_arrow_dataset_jni_JniWrapper_closeDataset /* * Class: org_apache_arrow_dataset_jni_JniWrapper * Method: createScanner - * Signature: (J[Ljava/lang/String;Ljava/nio/ByteBuffer;Ljava/nio/ByteBuffer;JJ)J + * Signature: + * (J[Ljava/lang/String;Ljava/nio/ByteBuffer;Ljava/nio/ByteBuffer;JJ;Ljava/nio/ByteBuffer;J)J */ JNIEXPORT jlong JNICALL Java_org_apache_arrow_dataset_jni_JniWrapper_createScanner( JNIEnv* env, jobject, jlong dataset_id, jobjectArray columns, - jobject substrait_projection, jobject substrait_filter, - jlong batch_size, jlong memory_pool_id) { + jobject substrait_projection, jobject substrait_filter, jlong batch_size, + jlong file_format_id, jobject options, jlong memory_pool_id) { JNI_METHOD_START arrow::MemoryPool* pool = reinterpret_cast(memory_pool_id); if (pool == nullptr) { @@ -556,6 +573,14 @@ JNIEXPORT jlong JNICALL Java_org_apache_arrow_dataset_jni_JniWrapper_createScann } JniAssertOkOrThrow(scanner_builder->Filter(*filter_expr)); } + if (file_format_id != -1 
&& options != nullptr) { + std::unordered_map option_map; + std::shared_ptr buffer = LoadArrowBufferFromByteBuffer(env, options); + JniAssertOkOrThrow(arrow::engine::DeserializeMap(*buffer, option_map)); + std::shared_ptr scan_options = + JniGetOrThrow(GetFragmentScanOptions(file_format_id, option_map)); + JniAssertOkOrThrow(scanner_builder->FragmentScanOptions(scan_options)); + } JniAssertOkOrThrow(scanner_builder->BatchSize(batch_size)); auto scanner = JniGetOrThrow(scanner_builder->Finish()); @@ -667,14 +692,22 @@ JNIEXPORT void JNICALL Java_org_apache_arrow_dataset_jni_JniWrapper_ensureS3Fina /* * Class: org_apache_arrow_dataset_file_JniWrapper * Method: makeFileSystemDatasetFactory - * Signature: (Ljava/lang/String;II)J + * Signature: (Ljava/lang/String;IILjava/lang/String;Ljava/nio/ByteBuffer)J */ JNIEXPORT jlong JNICALL -Java_org_apache_arrow_dataset_file_JniWrapper_makeFileSystemDatasetFactory__Ljava_lang_String_2I( - JNIEnv* env, jobject, jstring uri, jint file_format_id) { +Java_org_apache_arrow_dataset_file_JniWrapper_makeFileSystemDatasetFactory( + JNIEnv* env, jobject, jstring uri, jint file_format_id, jobject options) { JNI_METHOD_START std::shared_ptr file_format = JniGetOrThrow(GetFileFormat(file_format_id)); + if (options != nullptr) { + std::unordered_map option_map; + std::shared_ptr buffer = LoadArrowBufferFromByteBuffer(env, options); + JniAssertOkOrThrow(arrow::engine::DeserializeMap(*buffer, option_map)); + std::shared_ptr scan_options = + JniGetOrThrow(GetFragmentScanOptions(file_format_id, option_map)); + file_format->default_fragment_scan_options = scan_options; + } arrow::dataset::FileSystemFactoryOptions options; std::shared_ptr d = JniGetOrThrow(arrow::dataset::FileSystemDatasetFactory::Make( @@ -685,16 +718,24 @@ Java_org_apache_arrow_dataset_file_JniWrapper_makeFileSystemDatasetFactory__Ljav /* * Class: org_apache_arrow_dataset_file_JniWrapper - * Method: makeFileSystemDatasetFactory - * Signature: ([Ljava/lang/String;II)J + * Method: 
makeFileSystemDatasetFactoryWithFiles + * Signature: ([Ljava/lang/String;IIJ;Ljava/nio/ByteBuffer)J */ JNIEXPORT jlong JNICALL -Java_org_apache_arrow_dataset_file_JniWrapper_makeFileSystemDatasetFactory___3Ljava_lang_String_2I( - JNIEnv* env, jobject, jobjectArray uris, jint file_format_id) { +Java_org_apache_arrow_dataset_file_JniWrapper_makeFileSystemDatasetFactoryWithFiles( + JNIEnv* env, jobject, jobjectArray uris, jint file_format_id, jobject options) { JNI_METHOD_START std::shared_ptr file_format = JniGetOrThrow(GetFileFormat(file_format_id)); + if (options != nullptr) { + std::unordered_map option_map; + std::shared_ptr buffer = LoadArrowBufferFromByteBuffer(env, options); + JniAssertOkOrThrow(arrow::engine::DeserializeMap(*buffer, option_map)); + std::shared_ptr scan_options = + JniGetOrThrow(GetFragmentScanOptions(file_format_id, option_map)); + file_format->default_fragment_scan_options = scan_options; + } arrow::dataset::FileSystemFactoryOptions options; std::vector uri_vec = ToStringVector(env, uris); diff --git a/java/dataset/src/main/java/org/apache/arrow/dataset/file/FileSystemDatasetFactory.java b/java/dataset/src/main/java/org/apache/arrow/dataset/file/FileSystemDatasetFactory.java index aa315690592ee..a0b6fb168eca9 100644 --- a/java/dataset/src/main/java/org/apache/arrow/dataset/file/FileSystemDatasetFactory.java +++ b/java/dataset/src/main/java/org/apache/arrow/dataset/file/FileSystemDatasetFactory.java @@ -17,8 +17,11 @@ package org.apache.arrow.dataset.file; +import java.util.Optional; + import org.apache.arrow.dataset.jni.NativeDatasetFactory; import org.apache.arrow.dataset.jni.NativeMemoryPool; +import org.apache.arrow.dataset.scanner.FragmentScanOptions; import org.apache.arrow.memory.BufferAllocator; /** @@ -27,21 +30,34 @@ public class FileSystemDatasetFactory extends NativeDatasetFactory { public FileSystemDatasetFactory(BufferAllocator allocator, NativeMemoryPool memoryPool, FileFormat format, - String uri) { - super(allocator, 
memoryPool, createNative(format, uri)); + String uri, Optional fragmentScanOptions) { + super(allocator, memoryPool, createNative(format, uri, fragmentScanOptions)); + } + + public FileSystemDatasetFactory(BufferAllocator allocator, NativeMemoryPool memoryPool, FileFormat format, + String uri) { + super(allocator, memoryPool, createNative(format, uri, Optional.empty())); + } + + public FileSystemDatasetFactory(BufferAllocator allocator, NativeMemoryPool memoryPool, FileFormat format, + String[] uris, Optional fragmentScanOptions) { + super(allocator, memoryPool, createNative(format, uris, fragmentScanOptions)); } public FileSystemDatasetFactory(BufferAllocator allocator, NativeMemoryPool memoryPool, FileFormat format, String[] uris) { - super(allocator, memoryPool, createNative(format, uris)); + super(allocator, memoryPool, createNative(format, uris, Optional.empty())); } - private static long createNative(FileFormat format, String uri) { - return JniWrapper.get().makeFileSystemDatasetFactory(uri, format.id()); + private static long createNative(FileFormat format, String uri, Optional fragmentScanOptions) { + return JniWrapper.get().makeFileSystemDatasetFactory(uri, format.id(), + fragmentScanOptions.map(FragmentScanOptions::serialize).orElse(null)); } - private static long createNative(FileFormat format, String[] uris) { - return JniWrapper.get().makeFileSystemDatasetFactory(uris, format.id()); + private static long createNative(FileFormat format, String[] uris, + Optional fragmentScanOptions) { + return JniWrapper.get().makeFileSystemDatasetFactoryWithFiles(uris, format.id(), + fragmentScanOptions.map(FragmentScanOptions::serialize).orElse(null)); } } diff --git a/java/dataset/src/main/java/org/apache/arrow/dataset/file/JniWrapper.java b/java/dataset/src/main/java/org/apache/arrow/dataset/file/JniWrapper.java index c3a1a4e58a140..c3f8e12b38ebe 100644 --- a/java/dataset/src/main/java/org/apache/arrow/dataset/file/JniWrapper.java +++ 
b/java/dataset/src/main/java/org/apache/arrow/dataset/file/JniWrapper.java @@ -17,6 +17,8 @@ package org.apache.arrow.dataset.file; +import java.nio.ByteBuffer; + import org.apache.arrow.dataset.jni.JniLoader; /** @@ -43,7 +45,8 @@ private JniWrapper() { * @return the native pointer of the arrow::dataset::FileSystemDatasetFactory instance. * @see FileFormat */ - public native long makeFileSystemDatasetFactory(String uri, int fileFormat); + public native long makeFileSystemDatasetFactory(String uri, int fileFormat, + ByteBuffer serializedFragmentScanOptions); /** * Create FileSystemDatasetFactory and return its native pointer. The pointer is pointing to a @@ -54,7 +57,8 @@ private JniWrapper() { * @return the native pointer of the arrow::dataset::FileSystemDatasetFactory instance. * @see FileFormat */ - public native long makeFileSystemDatasetFactory(String[] uris, int fileFormat); + public native long makeFileSystemDatasetFactoryWithFiles(String[] uris, int fileFormat, + ByteBuffer serializedFragmentScanOptions); /** * Write the content in a {@link org.apache.arrow.c.ArrowArrayStream} into files. This internally diff --git a/java/dataset/src/main/java/org/apache/arrow/dataset/jni/JniWrapper.java b/java/dataset/src/main/java/org/apache/arrow/dataset/jni/JniWrapper.java index 637a3e8f22a9a..6d6309140605b 100644 --- a/java/dataset/src/main/java/org/apache/arrow/dataset/jni/JniWrapper.java +++ b/java/dataset/src/main/java/org/apache/arrow/dataset/jni/JniWrapper.java @@ -80,7 +80,8 @@ private JniWrapper() { * @return the native pointer of the arrow::dataset::Scanner instance. */ public native long createScanner(long datasetId, String[] columns, ByteBuffer substraitProjection, - ByteBuffer substraitFilter, long batchSize, long memoryPool); + ByteBuffer substraitFilter, long batchSize, long fileFormat, + ByteBuffer serializedFragmentScanOptions, long memoryPool); /** * Get a serialized schema from native instance of a Scanner. 
diff --git a/java/dataset/src/main/java/org/apache/arrow/dataset/jni/NativeDataset.java b/java/dataset/src/main/java/org/apache/arrow/dataset/jni/NativeDataset.java index d9abad9971c4e..3a96fe768761c 100644 --- a/java/dataset/src/main/java/org/apache/arrow/dataset/jni/NativeDataset.java +++ b/java/dataset/src/main/java/org/apache/arrow/dataset/jni/NativeDataset.java @@ -17,6 +17,9 @@ package org.apache.arrow.dataset.jni; +import java.nio.ByteBuffer; + +import org.apache.arrow.dataset.scanner.FragmentScanOptions; import org.apache.arrow.dataset.scanner.ScanOptions; import org.apache.arrow.dataset.source.Dataset; @@ -40,11 +43,18 @@ public synchronized NativeScanner newScan(ScanOptions options) { if (closed) { throw new NativeInstanceReleasedException(); } - + int fileFormat = -1; + ByteBuffer serialized = null; + if (options.getFragmentScanOptions().isPresent()) { + FragmentScanOptions fragmentScanOptions = options.getFragmentScanOptions().get(); + fileFormat = fragmentScanOptions.fileFormatId(); + serialized = fragmentScanOptions.serialize(); + } long scannerId = JniWrapper.get().createScanner(datasetId, options.getColumns().orElse(null), options.getSubstraitProjection().orElse(null), options.getSubstraitFilter().orElse(null), - options.getBatchSize(), context.getMemoryPool().getNativeInstanceId()); + options.getBatchSize(), fileFormat, serialized, + context.getMemoryPool().getNativeInstanceId()); return new NativeScanner(context, scannerId); } diff --git a/java/dataset/src/main/java/org/apache/arrow/dataset/scanner/FragmentScanOptions.java b/java/dataset/src/main/java/org/apache/arrow/dataset/scanner/FragmentScanOptions.java new file mode 100644 index 0000000000000..a5c20b1a796cf --- /dev/null +++ b/java/dataset/src/main/java/org/apache/arrow/dataset/scanner/FragmentScanOptions.java @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. 
See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.dataset.scanner; + +import java.nio.ByteBuffer; +import java.util.Map; + +import org.apache.arrow.dataset.substrait.StringMapNode; + +import com.google.protobuf.Any; + +import io.substrait.proto.AdvancedExtension; + +public interface FragmentScanOptions { + String typeName(); + + int fileFormatId(); + + ByteBuffer serialize(); + + /** + * serialize the map. 
+ * + * @param config config map + * @return bufer to jni call argument, should be DirectByteBuffer + */ + default ByteBuffer serializeMap(Map config) { + if (config.isEmpty()) { + return null; + } + StringMapNode stringMapNode = new StringMapNode(config); + AdvancedExtension.Builder extensionBuilder = AdvancedExtension.newBuilder(); + Any.Builder builder = extensionBuilder.getEnhancementBuilder(); + builder.setValue(stringMapNode.toProtobuf().toByteString()); + AdvancedExtension extension = extensionBuilder.build(); + ByteBuffer buf = ByteBuffer.allocateDirect(extension.getSerializedSize()); + buf.put(extension.toByteArray()); + return buf; + } +} diff --git a/java/dataset/src/main/java/org/apache/arrow/dataset/scanner/ScanOptions.java b/java/dataset/src/main/java/org/apache/arrow/dataset/scanner/ScanOptions.java index 995d05ac3b314..aad71930c431b 100644 --- a/java/dataset/src/main/java/org/apache/arrow/dataset/scanner/ScanOptions.java +++ b/java/dataset/src/main/java/org/apache/arrow/dataset/scanner/ScanOptions.java @@ -31,6 +31,8 @@ public class ScanOptions { private final Optional substraitProjection; private final Optional substraitFilter; + private final Optional fragmentScanOptions; + /** * Constructor. * @param columns Projected columns. Empty for scanning all columns. @@ -61,6 +63,7 @@ public ScanOptions(long batchSize, Optional columns) { this.columns = columns; this.substraitProjection = Optional.empty(); this.substraitFilter = Optional.empty(); + this.fragmentScanOptions = Optional.empty(); } public ScanOptions(long batchSize) { @@ -83,6 +86,10 @@ public Optional getSubstraitFilter() { return substraitFilter; } + public Optional getFragmentScanOptions() { + return fragmentScanOptions; + } + /** * Builder for Options used during scanning. 
*/ @@ -91,6 +98,7 @@ public static class Builder { private Optional columns; private ByteBuffer substraitProjection; private ByteBuffer substraitFilter; + private FragmentScanOptions fragmentScanOptions; /** * Constructor. @@ -136,6 +144,18 @@ public Builder substraitFilter(ByteBuffer substraitFilter) { return this; } + /** + * Set the FragmentScanOptions. + * + * @param fragmentScanOptions scan options + * @return the ScanOptions configured. + */ + public Builder fragmentScanOptions(FragmentScanOptions fragmentScanOptions) { + Preconditions.checkNotNull(fragmentScanOptions); + this.fragmentScanOptions = fragmentScanOptions; + return this; + } + public ScanOptions build() { return new ScanOptions(this); } @@ -146,5 +166,6 @@ private ScanOptions(Builder builder) { columns = builder.columns; substraitProjection = Optional.ofNullable(builder.substraitProjection); substraitFilter = Optional.ofNullable(builder.substraitFilter); + fragmentScanOptions = Optional.ofNullable(builder.fragmentScanOptions); } } diff --git a/java/dataset/src/main/java/org/apache/arrow/dataset/scanner/csv/CsvConvertOptions.java b/java/dataset/src/main/java/org/apache/arrow/dataset/scanner/csv/CsvConvertOptions.java new file mode 100644 index 0000000000000..41052836b1c00 --- /dev/null +++ b/java/dataset/src/main/java/org/apache/arrow/dataset/scanner/csv/CsvConvertOptions.java @@ -0,0 +1,55 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.dataset.scanner.csv; + +import java.util.Map; +import java.util.Optional; + +import org.apache.arrow.c.ArrowSchema; + +public class CsvConvertOptions { + + private final Map configs; + + private Optional cSchema = Optional.empty(); + + public CsvConvertOptions(Map configs) { + this.configs = configs; + } + + public Optional getArrowSchema() { + return cSchema; + } + + public long getArrowSchemaAddress() { + return cSchema.isPresent() ? cSchema.get().memoryAddress() : -1; + } + + public Map getConfigs() { + return configs; + } + + public void set(String key, String value) { + configs.put(key, value); + } + + public void setArrowSchema(ArrowSchema cSchema) { + this.cSchema = Optional.of(cSchema); + } + +} diff --git a/java/dataset/src/main/java/org/apache/arrow/dataset/scanner/csv/CsvFragmentScanOptions.java b/java/dataset/src/main/java/org/apache/arrow/dataset/scanner/csv/CsvFragmentScanOptions.java new file mode 100644 index 0000000000000..425af34e9ed10 --- /dev/null +++ b/java/dataset/src/main/java/org/apache/arrow/dataset/scanner/csv/CsvFragmentScanOptions.java @@ -0,0 +1,94 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. 
You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. + */ + +package org.apache.arrow.dataset.scanner.csv; + +import java.io.Serializable; +import java.nio.ByteBuffer; +import java.util.Locale; +import java.util.Map; +import java.util.stream.Collectors; +import java.util.stream.Stream; + +import org.apache.arrow.dataset.file.FileFormat; +import org.apache.arrow.dataset.scanner.FragmentScanOptions; + +public class CsvFragmentScanOptions implements Serializable, FragmentScanOptions { + private final CsvConvertOptions convertOptions; + private final Map readOptions; + private final Map parseOptions; + + + /** + * csv scan options, map to CPP struct CsvFragmentScanOptions. + * + * @param convertOptions same struct in CPP + * @param readOptions same struct in CPP + * @param parseOptions same struct in CPP + */ + public CsvFragmentScanOptions(CsvConvertOptions convertOptions, + Map readOptions, + Map parseOptions) { + this.convertOptions = convertOptions; + this.readOptions = readOptions; + this.parseOptions = parseOptions; + } + + public String typeName() { + return FileFormat.CSV.name().toLowerCase(Locale.ROOT); + } + + /** + * File format id. + * + * @return id + */ + public int fileFormatId() { + return FileFormat.CSV.id(); + } + + /** + * Serialize this class to ByteBuffer and then called by jni call. 
+ * + * @return DirectByteBuffer + */ + public ByteBuffer serialize() { + Map options = Stream.concat(Stream.concat(readOptions.entrySet().stream(), + parseOptions.entrySet().stream()), + convertOptions.getConfigs().entrySet().stream()).collect( + Collectors.toMap(Map.Entry::getKey, Map.Entry::getValue)); + options.put("column_type", Long.toString(convertOptions.getArrowSchemaAddress())); + return serializeMap(options); + } + + public static CsvFragmentScanOptions deserialize(String serialized) { + throw new UnsupportedOperationException("Not implemented now"); + } + + public CsvConvertOptions getConvertOptions() { + return convertOptions; + } + + public Map getReadOptions() { + return readOptions; + } + + public Map getParseOptions() { + return parseOptions; + } + +} diff --git a/java/dataset/src/main/java/org/apache/arrow/dataset/substrait/StringMapNode.java b/java/dataset/src/main/java/org/apache/arrow/dataset/substrait/StringMapNode.java new file mode 100644 index 0000000000000..0873da70b2964 --- /dev/null +++ b/java/dataset/src/main/java/org/apache/arrow/dataset/substrait/StringMapNode.java @@ -0,0 +1,53 @@ +/* + * Licensed to the Apache Software Foundation (ASF) under one or more + * contributor license agreements. See the NOTICE file distributed with + * this work for additional information regarding copyright ownership. + * The ASF licenses this file to You under the Apache License, Version 2.0 + * (the "License"); you may not use this file except in compliance with + * the License. You may obtain a copy of the License at + * + * http://www.apache.org/licenses/LICENSE-2.0 + * + * Unless required by applicable law or agreed to in writing, software + * distributed under the License is distributed on an "AS IS" BASIS, + * WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. + * See the License for the specific language governing permissions and + * limitations under the License. 
package org.apache.arrow.dataset.substrait;

import java.io.Serializable;
import java.util.HashMap;
import java.util.Map;

import io.substrait.proto.Expression;

/**
 * Holds a String-to-String map and renders it as a Substrait map literal,
 * where every key and value is a string literal.
 */
public class StringMapNode implements Serializable {
  private final Map<String, String> values = new HashMap<>();

  public StringMapNode(Map<String, String> values) {
    this.values.putAll(values);
  }

  /**
   * Serialize String map.
   *
   * @return Substrait Literal
   */
  public Expression.Literal toProtobuf() {
    Expression.Literal.Map.Builder mapBuilder = Expression.Literal.Map.newBuilder();
    for (Map.Entry<String, String> entry : values.entrySet()) {
      // Build each key/value with its own fresh builders; produces the same
      // messages as reusing one builder, but is easier to follow.
      Expression.Literal key =
          Expression.Literal.newBuilder().setString(entry.getKey()).build();
      Expression.Literal value =
          Expression.Literal.newBuilder().setString(entry.getValue()).build();
      mapBuilder.addKeyValues(
          Expression.Literal.Map.KeyValue.newBuilder().setKey(key).setValue(value).build());
    }
    // Literal's fields form a protobuf oneof, so the result carries only the map.
    return Expression.Literal.newBuilder().setMap(mapBuilder.build()).build();
  }
}
org.apache.arrow.dataset.scanner.ScanOptions; import org.apache.arrow.dataset.scanner.Scanner; +import org.apache.arrow.dataset.scanner.csv.CsvConvertOptions; +import org.apache.arrow.dataset.scanner.csv.CsvFragmentScanOptions; import org.apache.arrow.dataset.source.Dataset; import org.apache.arrow.dataset.source.DatasetFactory; +import org.apache.arrow.memory.BufferAllocator; import org.apache.arrow.vector.ipc.ArrowReader; import org.apache.arrow.vector.types.pojo.ArrowType; import org.apache.arrow.vector.types.pojo.Field; @@ -49,6 +55,8 @@ import org.junit.Test; import org.junit.rules.TemporaryFolder; +import com.google.common.collect.ImmutableMap; + public class TestAceroSubstraitConsumer extends TestDataset { @ClassRule @@ -457,4 +465,41 @@ private static ByteBuffer getByteBuffer(String base64EncodedSubstrait) { substraitExpression.put(decodedSubstrait); return substraitExpression; } + + @Test + public void testCsvConvertOptions() throws Exception { + final Schema schema = new Schema(Arrays.asList( + Field.nullable("Id", new ArrowType.Int(32, true)), + Field.nullable("Name", new ArrowType.Utf8()), + Field.nullable("Language", new ArrowType.Utf8()) + ), null); + String path = "file://" + getClass().getResource("/").getPath() + "/data/student.csv"; + BufferAllocator allocator = rootAllocator(); + try (ArrowSchema cSchema = ArrowSchema.allocateNew(allocator)) { + Data.exportSchema(allocator, schema, new CDataDictionaryProvider(), cSchema); + CsvConvertOptions convertOptions = new CsvConvertOptions(ImmutableMap.of("delimiter", ";")); + convertOptions.setArrowSchema(cSchema); + CsvFragmentScanOptions fragmentScanOptions = new CsvFragmentScanOptions( + convertOptions, ImmutableMap.of(), ImmutableMap.of()); + ScanOptions options = new ScanOptions.Builder(/*batchSize*/ 32768) + .columns(Optional.empty()) + .fragmentScanOptions(fragmentScanOptions) + .build(); + try ( + DatasetFactory datasetFactory = new FileSystemDatasetFactory(allocator, 
NativeMemoryPool.getDefault(), + FileFormat.CSV, path); + Dataset dataset = datasetFactory.finish(); + Scanner scanner = dataset.newScan(options); + ArrowReader reader = scanner.scanBatches() + ) { + assertEquals(schema.getFields(), reader.getVectorSchemaRoot().getSchema().getFields()); + int rowCount = 0; + while (reader.loadNextBatch()) { + assertEquals("[1, 2, 3]", reader.getVectorSchemaRoot().getVector("Id").toString()); + rowCount += reader.getVectorSchemaRoot().getRowCount(); + } + assertEquals(3, rowCount); + } + } + } } diff --git a/java/dataset/src/test/resources/data/student.csv b/java/dataset/src/test/resources/data/student.csv new file mode 100644 index 0000000000000..3291946092156 --- /dev/null +++ b/java/dataset/src/test/resources/data/student.csv @@ -0,0 +1,4 @@ +Id;Name;Language +1;Juno;Java +2;Peter;Python +3;Celin;C++