Skip to content

Commit

Permalink
support dataset option
Browse files Browse the repository at this point in the history
  • Loading branch information
jinchengchenghh committed May 14, 2024
1 parent fc7c723 commit 883417a
Show file tree
Hide file tree
Showing 11 changed files with 294 additions and 19 deletions.
53 changes: 46 additions & 7 deletions java/dataset/src/main/cpp/jni_wrapper.cc
Original file line number Diff line number Diff line change
Expand Up @@ -25,6 +25,7 @@
#include "arrow/c/helpers.h"
#include "arrow/dataset/api.h"
#include "arrow/dataset/file_base.h"
#include "arrow/dataset/file_csv.h"
#include "arrow/filesystem/localfs.h"
#include "arrow/filesystem/path_util.h"
#include "arrow/filesystem/s3fs.h"
Expand Down Expand Up @@ -154,6 +155,29 @@ arrow::Result<std::shared_ptr<arrow::dataset::FileFormat>> GetFileFormat(
}
}

arrow::Result<std::shared_ptr<arrow::dataset::FragmentScanOptions>>
GetFragmentScanOptions(jint file_format_id, jlong schema_address) {
switch (file_format_id) {
#ifdef ARROW_CSV
case 3: {
std::shared_ptr<arrow::dataset::CsvFragmentScanOptions> csv_options =
std::make_shared<arrow::dataset::CsvFragmentScanOptions>();
ArrowSchema* cSchema = reinterpret_cast<ArrowSchema*>(schema_address);
auto schema = JniGetOrThrow(arrow::ImportSchema(cSchema));
auto column_types = csv_options->convert_options.column_types;
for (auto field : schema->fields()) {
column_types[field->name()] = field->type();
}
return csv_options;
}
#endif
default:
std::string error_message =
"illegal file format id: " + std::to_string(file_format_id);
return arrow::Status::Invalid(error_message);
}
}

class ReserveFromJava : public arrow::dataset::jni::ReservationListener {
public:
ReserveFromJava(JavaVM* vm, jobject java_reservation_listener)
Expand Down Expand Up @@ -502,12 +526,12 @@ JNIEXPORT void JNICALL Java_org_apache_arrow_dataset_jni_JniWrapper_closeDataset
/*
* Class: org_apache_arrow_dataset_jni_JniWrapper
* Method: createScanner
* Signature: (J[Ljava/lang/String;Ljava/nio/ByteBuffer;Ljava/nio/ByteBuffer;JJ)J
* Signature: (J[Ljava/lang/String;Ljava/nio/ByteBuffer;Ljava/nio/ByteBuffer;JJJJ)J
*/
JNIEXPORT jlong JNICALL Java_org_apache_arrow_dataset_jni_JniWrapper_createScanner(
JNIEnv* env, jobject, jlong dataset_id, jobjectArray columns,
jobject substrait_projection, jobject substrait_filter,
jlong batch_size, jlong memory_pool_id) {
jobject substrait_projection, jobject substrait_filter, jlong file_format_id,
jlong schema_address, jlong batch_size, jlong memory_pool_id) {
JNI_METHOD_START
arrow::MemoryPool* pool = reinterpret_cast<arrow::MemoryPool*>(memory_pool_id);
if (pool == nullptr) {
Expand Down Expand Up @@ -556,6 +580,11 @@ JNIEXPORT jlong JNICALL Java_org_apache_arrow_dataset_jni_JniWrapper_createScann
}
JniAssertOkOrThrow(scanner_builder->Filter(*filter_expr));
}
if (file_format_id != -1 && schema_address != -1) {
std::shared_ptr<arrow::dataset::FragmentScanOptions> options =
JniGetOrThrow(GetFragmentScanOptions(file_format_id, schema_address));
JniAssertOkOrThrow(scanner_builder->FragmentScanOptions(options));
}
JniAssertOkOrThrow(scanner_builder->BatchSize(batch_size));

auto scanner = JniGetOrThrow(scanner_builder->Finish());
Expand Down Expand Up @@ -667,14 +696,19 @@ JNIEXPORT void JNICALL Java_org_apache_arrow_dataset_jni_JniWrapper_ensureS3Fina
/*
* Class: org_apache_arrow_dataset_file_JniWrapper
* Method: makeFileSystemDatasetFactory
* Signature: (Ljava/lang/String;II)J
* Signature: (Ljava/lang/String;IIJ)J
*/
JNIEXPORT jlong JNICALL
Java_org_apache_arrow_dataset_file_JniWrapper_makeFileSystemDatasetFactory__Ljava_lang_String_2I(
JNIEnv* env, jobject, jstring uri, jint file_format_id) {
JNIEnv* env, jobject, jstring uri, jint file_format_id, jlong schema_address) {
JNI_METHOD_START
std::shared_ptr<arrow::dataset::FileFormat> file_format =
JniGetOrThrow(GetFileFormat(file_format_id));
if (schema_address != -1) {
std::shared_ptr<arrow::dataset::FragmentScanOptions> options =
JniGetOrThrow(GetFragmentScanOptions(file_format_id, schema_address));
file_format->default_fragment_scan_options = options;
}
arrow::dataset::FileSystemFactoryOptions options;
std::shared_ptr<arrow::dataset::DatasetFactory> d =
JniGetOrThrow(arrow::dataset::FileSystemDatasetFactory::Make(
Expand All @@ -686,15 +720,20 @@ Java_org_apache_arrow_dataset_file_JniWrapper_makeFileSystemDatasetFactory__Ljav
/*
* Class: org_apache_arrow_dataset_file_JniWrapper
* Method: makeFileSystemDatasetFactory
* Signature: ([Ljava/lang/String;II)J
* Signature: ([Ljava/lang/String;IIJ)J
*/
JNIEXPORT jlong JNICALL
Java_org_apache_arrow_dataset_file_JniWrapper_makeFileSystemDatasetFactory___3Ljava_lang_String_2I(
JNIEnv* env, jobject, jobjectArray uris, jint file_format_id) {
JNIEnv* env, jobject, jobjectArray uris, jint file_format_id, jlong schema_address) {
JNI_METHOD_START

std::shared_ptr<arrow::dataset::FileFormat> file_format =
JniGetOrThrow(GetFileFormat(file_format_id));
if (schema_address != -1) {
std::shared_ptr<arrow::dataset::FragmentScanOptions> options =
JniGetOrThrow(GetFragmentScanOptions(file_format_id, schema_address));
file_format->default_fragment_scan_options = options;
}
arrow::dataset::FileSystemFactoryOptions options;

std::vector<std::string> uri_vec = ToStringVector(env, uris);
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,29 +19,57 @@

import org.apache.arrow.dataset.jni.NativeDatasetFactory;
import org.apache.arrow.dataset.jni.NativeMemoryPool;
import org.apache.arrow.dataset.scanner.FragmentScanOptions;
import org.apache.arrow.dataset.scanner.csv.CsvFragmentScanOptions;
import org.apache.arrow.memory.BufferAllocator;

import java.util.Optional;

/**
* Java binding of the C++ FileSystemDatasetFactory.
*/
public class FileSystemDatasetFactory extends NativeDatasetFactory {

public FileSystemDatasetFactory(BufferAllocator allocator, NativeMemoryPool memoryPool, FileFormat format,
String uri) {
super(allocator, memoryPool, createNative(format, uri));
String uri, Optional<FragmentScanOptions> fragmentScanOptions) {
super(allocator, memoryPool, createNative(format, uri, fragmentScanOptions));
}

public FileSystemDatasetFactory(BufferAllocator allocator, NativeMemoryPool memoryPool, FileFormat format,
String uri) {
super(allocator, memoryPool, createNative(format, uri, Optional.empty()));
}

public FileSystemDatasetFactory(BufferAllocator allocator, NativeMemoryPool memoryPool, FileFormat format,
String[] uris, Optional<FragmentScanOptions> fragmentScanOptions) {
super(allocator, memoryPool, createNative(format, uris, fragmentScanOptions));
}

public FileSystemDatasetFactory(BufferAllocator allocator, NativeMemoryPool memoryPool, FileFormat format,
String[] uris) {
super(allocator, memoryPool, createNative(format, uris));
super(allocator, memoryPool, createNative(format, uris, Optional.empty()));
}

private static long getArrowSchemaAddress(Optional<FragmentScanOptions> fragmentScanOptions) {
if (fragmentScanOptions.isPresent()) {
FragmentScanOptions options = fragmentScanOptions.get();
if (options instanceof CsvFragmentScanOptions) {
return ((CsvFragmentScanOptions) options)
.getConvertOptions().getArrowSchemaAddress();
}
}

return -1;
}

private static long createNative(FileFormat format, String uri) {
return JniWrapper.get().makeFileSystemDatasetFactory(uri, format.id());
private static long createNative(FileFormat format, String uri, Optional<FragmentScanOptions> fragmentScanOptions) {
long cSchemaAddress = getArrowSchemaAddress(fragmentScanOptions);
return JniWrapper.get().makeFileSystemDatasetFactory(uri, format.id(), cSchemaAddress);
}

private static long createNative(FileFormat format, String[] uris) {
return JniWrapper.get().makeFileSystemDatasetFactory(uris, format.id());
private static long createNative(FileFormat format, String[] uris, Optional<FragmentScanOptions> fragmentScanOptions) {
long cSchemaAddress = getArrowSchemaAddress(fragmentScanOptions);
return JniWrapper.get().makeFileSystemDatasetFactory(uris, format.id(), cSchemaAddress);
}

}
Original file line number Diff line number Diff line change
Expand Up @@ -43,7 +43,7 @@ private JniWrapper() {
* @return the native pointer of the arrow::dataset::FileSystemDatasetFactory instance.
* @see FileFormat
*/
public native long makeFileSystemDatasetFactory(String uri, int fileFormat);
public native long makeFileSystemDatasetFactory(String uri, int fileFormat, long schemaAddress);

/**
* Create FileSystemDatasetFactory and return its native pointer. The pointer is pointing to a
Expand All @@ -54,7 +54,7 @@ private JniWrapper() {
* @return the native pointer of the arrow::dataset::FileSystemDatasetFactory instance.
* @see FileFormat
*/
public native long makeFileSystemDatasetFactory(String[] uris, int fileFormat);
public native long makeFileSystemDatasetFactory(String[] uris, int fileFormat, long cSchemaAddress);

/**
* Write the content in a {@link org.apache.arrow.c.ArrowArrayStream} into files. This internally
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -80,7 +80,8 @@ private JniWrapper() {
* @return the native pointer of the arrow::dataset::Scanner instance.
*/
public native long createScanner(long datasetId, String[] columns, ByteBuffer substraitProjection,
ByteBuffer substraitFilter, long batchSize, long memoryPool);
ByteBuffer substraitFilter, long batchSize, long fileFormat,
long schemaAddress, long memoryPool);

/**
* Get a serialized schema from native instance of a Scanner.
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -17,7 +17,9 @@

package org.apache.arrow.dataset.jni;

import org.apache.arrow.dataset.scanner.FragmentScanOptions;
import org.apache.arrow.dataset.scanner.ScanOptions;
import org.apache.arrow.dataset.scanner.csv.CsvFragmentScanOptions;
import org.apache.arrow.dataset.source.Dataset;

/**
Expand All @@ -40,11 +42,20 @@ public synchronized NativeScanner newScan(ScanOptions options) {
if (closed) {
throw new NativeInstanceReleasedException();
}

long cSchemaAddress = -1;
int fileFormat = -1;
if (options.getFragmentScanOptions().isPresent()) {
FragmentScanOptions fragmentScanOptions = options.getFragmentScanOptions().get();
if (fragmentScanOptions instanceof CsvFragmentScanOptions) {
cSchemaAddress = ((CsvFragmentScanOptions) fragmentScanOptions)
.getConvertOptions().getArrowSchemaAddress();
}
fileFormat = fragmentScanOptions.fileFormatId();
}
long scannerId = JniWrapper.get().createScanner(datasetId, options.getColumns().orElse(null),
options.getSubstraitProjection().orElse(null),
options.getSubstraitFilter().orElse(null),
options.getBatchSize(), context.getMemoryPool().getNativeInstanceId());
options.getBatchSize(), fileFormat, cSchemaAddress, context.getMemoryPool().getNativeInstanceId());

return new NativeScanner(context, scannerId);
}
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,24 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.arrow.dataset.scanner;

public interface FragmentScanOptions {
String typeName();

int fileFormatId();
}
Original file line number Diff line number Diff line change
Expand Up @@ -31,6 +31,8 @@ public class ScanOptions {
private final Optional<ByteBuffer> substraitProjection;
private final Optional<ByteBuffer> substraitFilter;

private final Optional<FragmentScanOptions> fragmentScanOptions;

/**
* Constructor.
* @param columns Projected columns. Empty for scanning all columns.
Expand Down Expand Up @@ -61,6 +63,7 @@ public ScanOptions(long batchSize, Optional<String[]> columns) {
this.columns = columns;
this.substraitProjection = Optional.empty();
this.substraitFilter = Optional.empty();
this.fragmentScanOptions = Optional.empty();
}

public ScanOptions(long batchSize) {
Expand All @@ -83,6 +86,10 @@ public Optional<ByteBuffer> getSubstraitFilter() {
return substraitFilter;
}

public Optional<FragmentScanOptions> getFragmentScanOptions() {
return fragmentScanOptions;
}

/**
* Builder for Options used during scanning.
*/
Expand All @@ -91,6 +98,7 @@ public static class Builder {
private Optional<String[]> columns;
private ByteBuffer substraitProjection;
private ByteBuffer substraitFilter;
private FragmentScanOptions fragmentScanOptions;

/**
* Constructor.
Expand Down Expand Up @@ -136,6 +144,12 @@ public Builder substraitFilter(ByteBuffer substraitFilter) {
return this;
}

public Builder fragmentScanOptions(FragmentScanOptions fragmentScanOptions) {
Preconditions.checkNotNull(fragmentScanOptions);
this.fragmentScanOptions = fragmentScanOptions;
return this;
}

public ScanOptions build() {
return new ScanOptions(this);
}
Expand All @@ -146,5 +160,6 @@ private ScanOptions(Builder builder) {
columns = builder.columns;
substraitProjection = Optional.ofNullable(builder.substraitProjection);
substraitFilter = Optional.ofNullable(builder.substraitFilter);
fragmentScanOptions = Optional.ofNullable(builder.fragmentScanOptions);
}
}
Original file line number Diff line number Diff line change
@@ -0,0 +1,47 @@
/*
* Licensed to the Apache Software Foundation (ASF) under one or more
* contributor license agreements. See the NOTICE file distributed with
* this work for additional information regarding copyright ownership.
* The ASF licenses this file to You under the Apache License, Version 2.0
* (the "License"); you may not use this file except in compliance with
* the License. You may obtain a copy of the License at
*
* http://www.apache.org/licenses/LICENSE-2.0
*
* Unless required by applicable law or agreed to in writing, software
* distributed under the License is distributed on an "AS IS" BASIS,
* WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
* See the License for the specific language governing permissions and
* limitations under the License.
*/

package org.apache.arrow.dataset.scanner.csv;

import org.apache.arrow.c.ArrowSchema;

import java.util.Map;
import java.util.Optional;

public class CsvConvertOptions {

private final Map<String, String> configs;

private Optional<ArrowSchema> cSchema = Optional.empty();

public CsvConvertOptions(Map<String, String> configs) {
this.configs = configs;
}

public Optional<ArrowSchema> getArrowSchema() {
return cSchema;
}

public long getArrowSchemaAddress() {
return cSchema.isPresent() ? cSchema.get().memoryAddress() : -1;
}

public void setArrowSchema(ArrowSchema cSchema) {
this.cSchema = Optional.of(cSchema);
}

}
Loading

0 comments on commit 883417a

Please sign in to comment.