Skip to content

Commit

Permalink
support dataset option
Browse files Browse the repository at this point in the history
  • Loading branch information
jinchengchenghh committed May 17, 2024
1 parent e1de9c5 commit 1223f36
Show file tree
Hide file tree
Showing 21 changed files with 577 additions and 23 deletions.
30 changes: 30 additions & 0 deletions cpp/src/arrow/dataset/file_csv.cc
Original file line number Diff line number Diff line change
Expand Up @@ -24,6 +24,7 @@
#include <unordered_set>
#include <utility>

#include "arrow/c/bridge.h"
#include "arrow/csv/options.h"
#include "arrow/csv/parser.h"
#include "arrow/csv/reader.h"
Expand Down Expand Up @@ -52,6 +53,9 @@ using internal::Executor;
using internal::SerialExecutor;

namespace dataset {
namespace {
inline bool parseBool(const std::string& value) { return value == "true" ? true : false; }
} // namespace

struct CsvInspectedFragment : public InspectedFragment {
CsvInspectedFragment(std::vector<std::string> column_names,
Expand Down Expand Up @@ -503,5 +507,31 @@ Future<> CsvFileWriter::FinishInternal() {
return Status::OK();
}

Result<std::shared_ptr<FragmentScanOptions>> CsvFragmentScanOptions::from(
const std::unordered_map<std::string, std::string>& configs) {
std::shared_ptr<CsvFragmentScanOptions> options =
std::make_shared<CsvFragmentScanOptions>();
for (auto const& it : configs) {
auto& key = it.first;
auto& value = it.second;
if (key == "delimiter") {
options->parse_options.delimiter = value.data()[0];
} else if (key == "quoting") {
options->parse_options.quoting = parseBool(value);
} else if (key == "column_type") {
int64_t schema_address = std::stol(value);
ArrowSchema* cSchema = reinterpret_cast<ArrowSchema*>(schema_address);
ARROW_ASSIGN_OR_RAISE(auto schema, arrow::ImportSchema(cSchema));
auto column_types = options->convert_options.column_types;
for (auto field : schema->fields()) {
column_types[field->name()] = field->type();
}
} else {
return Status::Invalid("Not support this config " + it.first);
}
}
return options;
}

} // namespace dataset
} // namespace arrow
3 changes: 3 additions & 0 deletions cpp/src/arrow/dataset/file_csv.h
Original file line number Diff line number Diff line change
Expand Up @@ -85,6 +85,9 @@ class ARROW_DS_EXPORT CsvFileFormat : public FileFormat {
struct ARROW_DS_EXPORT CsvFragmentScanOptions : public FragmentScanOptions {
std::string type_name() const override { return kCsvTypeName; }

static Result<std::shared_ptr<FragmentScanOptions>> from(
const std::unordered_map<std::string, std::string>& configs);

using StreamWrapFunc = std::function<Result<std::shared_ptr<io::InputStream>>(
std::shared_ptr<io::InputStream>)>;

Expand Down
1 change: 1 addition & 0 deletions cpp/src/arrow/engine/CMakeLists.txt
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ arrow_install_all_headers("arrow/engine")
set(ARROW_SUBSTRAIT_SRCS
substrait/expression_internal.cc
substrait/extended_expression_internal.cc
substrait/extension_internal.cc
substrait/extension_set.cc
substrait/extension_types.cc
substrait/options.cc
Expand Down
51 changes: 51 additions & 0 deletions cpp/src/arrow/engine/substrait/extension_internal.cc
Original file line number Diff line number Diff line change
@@ -0,0 +1,51 @@
// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

// This API is EXPERIMENTAL.

#include "arrow/engine/substrait/extension_internal.h"

#include "substrait/algebra.pb.h"

namespace arrow {
namespace engine {

Status FromProto(const substrait::extensions::AdvancedExtension& extension,
std::unordered_map<std::string, std::string>& out) {
if (!extension.has_enhancement()) {
return Status::Invalid("AdvancedExtension does not have enhancement");
}
const auto& enhancement = extension.enhancement();
substrait::Expression_Literal literal;

if (!enhancement.UnpackTo(&literal)) {
return Status::Invalid("Unpack the literal failed");
}

if (!literal.has_map()) {
return Status::Invalid("Literal does not have map");
}
auto literalMap = literal.map();
auto size = literalMap.key_values_size();
for (auto i = 0; i < size; i++) {
substrait::Expression_Literal_Map_KeyValue keyValue = literalMap.key_values(i);
out.emplace(keyValue.key().string(), keyValue.value().string());
}
return Status::OK();
}
} // namespace engine
} // namespace arrow
44 changes: 44 additions & 0 deletions cpp/src/arrow/engine/substrait/extension_internal.h
Original file line number Diff line number Diff line change
@@ -0,0 +1,44 @@

// Licensed to the Apache Software Foundation (ASF) under one
// or more contributor license agreements. See the NOTICE file
// distributed with this work for additional information
// regarding copyright ownership. The ASF licenses this file
// to you under the Apache License, Version 2.0 (the
// "License"); you may not use this file except in compliance
// with the License. You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing,
// software distributed under the License is distributed on an
// "AS IS" BASIS, WITHOUT WARRANTIES OR CONDITIONS OF ANY
// KIND, either express or implied. See the License for the
// specific language governing permissions and limitations
// under the License.

// This API is EXPERIMENTAL.

#pragma once

#include <memory>

#include "arrow/compute/type_fwd.h"
#include "arrow/engine/substrait/extension_set.h"
#include "arrow/engine/substrait/options.h"
#include "arrow/engine/substrait/relation.h"
#include "arrow/engine/substrait/visibility.h"
#include "arrow/result.h"
#include "arrow/status.h"

#include "substrait/extensions/extensions.pb.h" // IWYU pragma: export

namespace arrow {
namespace engine {

/// Convert a Substrait ExtendedExpression to a vector of expressions and output names
ARROW_ENGINE_EXPORT
Status FromProto(const substrait::extensions::AdvancedExtension& extension,
std::unordered_map<std::string, std::string>& out);

} // namespace engine
} // namespace arrow
8 changes: 8 additions & 0 deletions cpp/src/arrow/engine/substrait/serde.cc
Original file line number Diff line number Diff line change
Expand Up @@ -36,6 +36,7 @@
#include "arrow/dataset/file_base.h"
#include "arrow/engine/substrait/expression_internal.h"
#include "arrow/engine/substrait/extended_expression_internal.h"
#include "arrow/engine/substrait/extension_internal.h"
#include "arrow/engine/substrait/extension_set.h"
#include "arrow/engine/substrait/plan_internal.h"
#include "arrow/engine/substrait/relation.h"
Expand Down Expand Up @@ -247,6 +248,13 @@ Result<BoundExpressions> DeserializeExpressions(
return FromProto(extended_expression, ext_set_out, conversion_options, registry);
}

Status DeserializeMap(const Buffer& buf,
std::unordered_map<std::string, std::string> out) {
ARROW_ASSIGN_OR_RAISE(auto advanced_extension,
ParseFromBuffer<substrait::extensions::AdvancedExtension>(buf));
return FromProto(advanced_extension, out);
}

namespace {

Result<std::shared_ptr<acero::ExecPlan>> MakeSingleDeclarationPlan(
Expand Down
4 changes: 4 additions & 0 deletions cpp/src/arrow/engine/substrait/serde.h
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@
#include <memory>
#include <string>
#include <string_view>
#include <unordered_map>
#include <vector>

#include "arrow/compute/type_fwd.h"
Expand Down Expand Up @@ -183,6 +184,9 @@ ARROW_ENGINE_EXPORT Result<BoundExpressions> DeserializeExpressions(
const ConversionOptions& conversion_options = {},
ExtensionSet* ext_set_out = NULLPTR);

ARROW_ENGINE_EXPORT Status
DeserializeMap(const Buffer& buf, std::unordered_map<std::string, std::string> out);

/// \brief Deserializes a Substrait Type message to the corresponding Arrow type
///
/// \param[in] buf a buffer containing the protobuf serialization of a Substrait Type
Expand Down
2 changes: 1 addition & 1 deletion cpp/thirdparty/versions.txt
Original file line number Diff line number Diff line change
Expand Up @@ -108,7 +108,7 @@ ARROW_SUBSTRAIT_BUILD_SHA256_CHECKSUM=f989a862f694e7dbb695925ddb7c4ce06aa6c51aca
ARROW_S2N_TLS_BUILD_VERSION=v1.3.35
ARROW_S2N_TLS_BUILD_SHA256_CHECKSUM=9d32b26e6bfcc058d98248bf8fc231537e347395dd89cf62bb432b55c5da990d
ARROW_THRIFT_BUILD_VERSION=0.16.0
ARROW_THRIFT_BUILD_SHA256_CHECKSUM=f460b5c1ca30d8918ff95ea3eb6291b3951cf518553566088f3f2be8981f6209
ARROW_THRIFT_BUILD_SHA256_CHECKSUM=df2931de646a366c2e5962af679018bca2395d586e00ba82d09c0379f14f8e7b
ARROW_UCX_BUILD_VERSION=1.12.1
ARROW_UCX_BUILD_SHA256_CHECKSUM=9bef31aed0e28bf1973d28d74d9ac4f8926c43ca3b7010bd22a084e164e31b71
ARROW_UTF8PROC_BUILD_VERSION=v2.7.0
Expand Down
14 changes: 14 additions & 0 deletions java/dataset/pom.xml
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,8 @@
<arrow.cpp.build.dir>../../../cpp/release-build/</arrow.cpp.build.dir>
<parquet.version>1.13.1</parquet.version>
<avro.version>1.11.3</avro.version>
<substrait.version>0.31.0</substrait.version>
<protobuf.version>3.25.3</protobuf.version>
</properties>

<dependencies>
Expand All @@ -48,6 +50,18 @@
<groupId>org.immutables</groupId>
<artifactId>value</artifactId>
</dependency>
<dependency>
<groupId>io.substrait</groupId>
<artifactId>core</artifactId>
<version>${substrait.version}</version>
<scope>compile</scope>
</dependency>
<dependency>
<groupId>com.google.protobuf</groupId>
<artifactId>protobuf-java</artifactId>
<version>${protobuf.version}</version>
<scope>compile</scope>
</dependency>
<dependency>
<groupId>org.apache.arrow</groupId>
<artifactId>arrow-memory-netty</artifactId>
Expand Down
61 changes: 51 additions & 10 deletions java/dataset/src/main/cpp/jni_wrapper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
#include "arrow/c/helpers.h"
#include "arrow/dataset/api.h"
#include "arrow/dataset/file_base.h"
#include "arrow/dataset/file_csv.h"
#include "arrow/filesystem/localfs.h"
#include "arrow/filesystem/path_util.h"
#include "arrow/filesystem/s3fs.h"
Expand Down Expand Up @@ -154,6 +155,21 @@ arrow::Result<std::shared_ptr<arrow::dataset::FileFormat>> GetFileFormat(
}
}

arrow::Result<std::shared_ptr<arrow::dataset::FragmentScanOptions>>
GetFragmentScanOptions(jint file_format_id,
const std::unordered_map<std::string, std::string>& configs) {
switch (file_format_id) {
#ifdef ARROW_CSV
case 3:
return arrow::dataset::CsvFragmentScanOptions::from(configs);
#endif
default:
std::string error_message =
"illegal file format id: " + std::to_string(file_format_id);
return arrow::Status::Invalid(error_message);
}
}

class ReserveFromJava : public arrow::dataset::jni::ReservationListener {
public:
ReserveFromJava(JavaVM* vm, jobject java_reservation_listener)
Expand Down Expand Up @@ -502,12 +518,13 @@ JNIEXPORT void JNICALL Java_org_apache_arrow_dataset_jni_JniWrapper_closeDataset
/*
* Class: org_apache_arrow_dataset_jni_JniWrapper
* Method: createScanner
* Signature: (J[Ljava/lang/String;Ljava/nio/ByteBuffer;Ljava/nio/ByteBuffer;JJ)J
* Signature:
* (J[Ljava/lang/String;Ljava/nio/ByteBuffer;Ljava/nio/ByteBuffer;JJ;Ljava/nio/ByteBuffer;J)J
*/
JNIEXPORT jlong JNICALL Java_org_apache_arrow_dataset_jni_JniWrapper_createScanner(
JNIEnv* env, jobject, jlong dataset_id, jobjectArray columns,
jobject substrait_projection, jobject substrait_filter,
jlong batch_size, jlong memory_pool_id) {
jobject substrait_projection, jobject substrait_filter, jlong batch_size,
jlong file_format_id, jobject options, jlong memory_pool_id) {
JNI_METHOD_START
arrow::MemoryPool* pool = reinterpret_cast<arrow::MemoryPool*>(memory_pool_id);
if (pool == nullptr) {
Expand Down Expand Up @@ -556,6 +573,14 @@ JNIEXPORT jlong JNICALL Java_org_apache_arrow_dataset_jni_JniWrapper_createScann
}
JniAssertOkOrThrow(scanner_builder->Filter(*filter_expr));
}
if (file_format_id != -1 && options != nullptr) {
std::unordered_map<std::string, std::string> option_map;
std::shared_ptr<arrow::Buffer> buffer = LoadArrowBufferFromByteBuffer(env, options);
JniAssertOkOrThrow(arrow::engine::DeserializeMap(*buffer, option_map));
std::shared_ptr<arrow::dataset::FragmentScanOptions> scan_options =
JniGetOrThrow(GetFragmentScanOptions(file_format_id, option_map));
JniAssertOkOrThrow(scanner_builder->FragmentScanOptions(scan_options));
}
JniAssertOkOrThrow(scanner_builder->BatchSize(batch_size));

auto scanner = JniGetOrThrow(scanner_builder->Finish());
Expand Down Expand Up @@ -667,14 +692,22 @@ JNIEXPORT void JNICALL Java_org_apache_arrow_dataset_jni_JniWrapper_ensureS3Fina
/*
* Class: org_apache_arrow_dataset_file_JniWrapper
* Method: makeFileSystemDatasetFactory
* Signature: (Ljava/lang/String;II)J
* Signature: (Ljava/lang/String;IILjava/lang/String;Ljava/nio/ByteBuffer)J
*/
JNIEXPORT jlong JNICALL
Java_org_apache_arrow_dataset_file_JniWrapper_makeFileSystemDatasetFactory__Ljava_lang_String_2I(
JNIEnv* env, jobject, jstring uri, jint file_format_id) {
Java_org_apache_arrow_dataset_file_JniWrapper_makeFileSystemDatasetFactory(
JNIEnv* env, jobject, jstring uri, jint file_format_id, jobject options) {
JNI_METHOD_START
std::shared_ptr<arrow::dataset::FileFormat> file_format =
JniGetOrThrow(GetFileFormat(file_format_id));
if (options != nullptr) {
std::unordered_map<std::string, std::string> option_map;
std::shared_ptr<arrow::Buffer> buffer = LoadArrowBufferFromByteBuffer(env, options);
JniAssertOkOrThrow(arrow::engine::DeserializeMap(*buffer, option_map));
std::shared_ptr<arrow::dataset::FragmentScanOptions> scan_options =
JniGetOrThrow(GetFragmentScanOptions(file_format_id, option_map));
file_format->default_fragment_scan_options = scan_options;
}
arrow::dataset::FileSystemFactoryOptions options;
std::shared_ptr<arrow::dataset::DatasetFactory> d =
JniGetOrThrow(arrow::dataset::FileSystemDatasetFactory::Make(
Expand All @@ -685,16 +718,24 @@ Java_org_apache_arrow_dataset_file_JniWrapper_makeFileSystemDatasetFactory__Ljav

/*
* Class: org_apache_arrow_dataset_file_JniWrapper
* Method: makeFileSystemDatasetFactory
* Signature: ([Ljava/lang/String;II)J
* Method: makeFileSystemDatasetFactoryWithFiles
* Signature: ([Ljava/lang/String;IIJ;Ljava/nio/ByteBuffer)J
*/
JNIEXPORT jlong JNICALL
Java_org_apache_arrow_dataset_file_JniWrapper_makeFileSystemDatasetFactory___3Ljava_lang_String_2I(
JNIEnv* env, jobject, jobjectArray uris, jint file_format_id) {
Java_org_apache_arrow_dataset_file_JniWrapper_makeFileSystemDatasetFactoryWithFiles(
JNIEnv* env, jobject, jobjectArray uris, jint file_format_id, jobject options) {
JNI_METHOD_START

std::shared_ptr<arrow::dataset::FileFormat> file_format =
JniGetOrThrow(GetFileFormat(file_format_id));
if (options != nullptr) {
std::unordered_map<std::string, std::string> option_map;
std::shared_ptr<arrow::Buffer> buffer = LoadArrowBufferFromByteBuffer(env, options);
JniAssertOkOrThrow(arrow::engine::DeserializeMap(*buffer, option_map));
std::shared_ptr<arrow::dataset::FragmentScanOptions> scan_options =
JniGetOrThrow(GetFragmentScanOptions(file_format_id, option_map));
file_format->default_fragment_scan_options = scan_options;
}
arrow::dataset::FileSystemFactoryOptions options;

std::vector<std::string> uri_vec = ToStringVector(env, uris);
Expand Down
Loading

0 comments on commit 1223f36

Please sign in to comment.