From 6d97d446fb80c316f790b401dbf480658fa9519e Mon Sep 17 00:00:00 2001 From: Chengcheng Jin Date: Mon, 20 May 2024 10:51:30 +0000 Subject: [PATCH] change to serialize ExpressionLiteral --- .../engine/substrait/expression_internal.cc | 12 +++++++++++ .../engine/substrait/expression_internal.h | 4 ++++ cpp/src/arrow/engine/substrait/serde.cc | 9 ++++++--- .../dataset/scanner/FragmentScanOptions.java | 9 ++++++--- .../dataset/substrait/util/ConvertUtil.java | 20 ++----------------- 5 files changed, 30 insertions(+), 24 deletions(-) diff --git a/cpp/src/arrow/engine/substrait/expression_internal.cc b/cpp/src/arrow/engine/substrait/expression_internal.cc index 480cf30d3033f..8ec68a42cb96a 100644 --- a/cpp/src/arrow/engine/substrait/expression_internal.cc +++ b/cpp/src/arrow/engine/substrait/expression_internal.cc @@ -1537,5 +1537,17 @@ Result> ToProto( return std::move(out); } +Status FromProto(const substrait::Expression::Literal& literal, + std::unordered_map& out) { + ARROW_RETURN_IF(!literal.has_map(), Status::Invalid("Literal does not have a map.")); + auto literalMap = literal.map(); + auto size = literalMap.key_values_size(); + for (auto i = 0; i < size; i++) { + substrait::Expression_Literal_Map_KeyValue keyValue = literalMap.key_values(i); + out.emplace(keyValue.key().string(), keyValue.value().string()); + } + return Status::OK(); +} + } // namespace engine } // namespace arrow diff --git a/cpp/src/arrow/engine/substrait/expression_internal.h b/cpp/src/arrow/engine/substrait/expression_internal.h index 2ce2ee76af20b..9be81b7ab674e 100644 --- a/cpp/src/arrow/engine/substrait/expression_internal.h +++ b/cpp/src/arrow/engine/substrait/expression_internal.h @@ -61,5 +61,9 @@ ARROW_ENGINE_EXPORT Result FromProto(const substrait::AggregateFunction&, bool is_hash, const ExtensionSet&, const ConversionOptions&); +ARROW_ENGINE_EXPORT +Status FromProto(const substrait::Expression::Literal& literal, + std::unordered_map& out); + } // namespace engine } // namespace arrow diff --git a/cpp/src/arrow/engine/substrait/serde.cc b/cpp/src/arrow/engine/substrait/serde.cc index 8d97410772eb8..c4a3bb72886ad 100644 --- a/cpp/src/arrow/engine/substrait/serde.cc +++ b/cpp/src/arrow/engine/substrait/serde.cc @@ -250,9 +250,12 @@ Result DeserializeExpressions( Status DeserializeMap(const Buffer& buf, std::unordered_map out) { - ARROW_ASSIGN_OR_RAISE(auto advanced_extension, - ParseFromBuffer(buf)); - return FromProto(advanced_extension, out); + // ARROW_ASSIGN_OR_RAISE(auto advanced_extension, + // ParseFromBuffer(buf)); + // return FromProto(advanced_extension, out); + ARROW_ASSIGN_OR_RAISE(auto literal, + ParseFromBuffer(buf)); + return FromProto(literal, out); } namespace { diff --git a/java/dataset/src/main/java/org/apache/arrow/dataset/scanner/FragmentScanOptions.java b/java/dataset/src/main/java/org/apache/arrow/dataset/scanner/FragmentScanOptions.java index 844448411a0c3..bd83f0d7e879f 100644 --- a/java/dataset/src/main/java/org/apache/arrow/dataset/scanner/FragmentScanOptions.java +++ b/java/dataset/src/main/java/org/apache/arrow/dataset/scanner/FragmentScanOptions.java @@ -23,6 +23,7 @@ import org.apache.arrow.dataset.substrait.util.ConvertUtil; import io.substrait.proto.AdvancedExtension; +import io.substrait.proto.Expression; public interface FragmentScanOptions { String typeName(); @@ -42,9 +43,11 @@ default ByteBuffer serializeMap(Map config) { return null; } - AdvancedExtension extension = ConvertUtil.expressionToExtension(ConvertUtil.mapToExpression(config)); - ByteBuffer buf = ByteBuffer.allocateDirect(extension.getSerializedSize()); - buf.put(extension.toByteArray()); + Expression.Literal literal = ConvertUtil.mapToExpressionLiteral(config); + +// AdvancedExtension extension = ConvertUtil.expressionToExtension(ConvertUtil.mapToExpression(config)); + ByteBuffer buf = ByteBuffer.allocateDirect(literal.getSerializedSize()); + buf.put(literal.toByteArray()); return buf; } } diff --git a/java/dataset/src/main/java/org/apache/arrow/dataset/substrait/util/ConvertUtil.java b/java/dataset/src/main/java/org/apache/arrow/dataset/substrait/util/ConvertUtil.java index 658ac19345f93..31a4023af727b 100644 --- a/java/dataset/src/main/java/org/apache/arrow/dataset/substrait/util/ConvertUtil.java +++ b/java/dataset/src/main/java/org/apache/arrow/dataset/substrait/util/ConvertUtil.java @@ -19,9 +19,6 @@ import java.util.Map; -import com.google.protobuf.Any; - -import io.substrait.proto.AdvancedExtension; import io.substrait.proto.Expression; public class ConvertUtil { @@ -31,7 +28,7 @@ public class ConvertUtil { * * @return Substrait Expression */ - public static Expression mapToExpression(Map values) { + public static Expression.Literal mapToExpressionLiteral(Map values) { Expression.Literal.Builder literalBuilder = Expression.Literal.newBuilder(); Expression.Literal.Map.KeyValue.Builder keyValueBuilder = Expression.Literal.Map.KeyValue.newBuilder(); @@ -44,19 +41,6 @@ public static Expression mapToExpression(Map values) { mapBuilder.addKeyValues(keyValueBuilder.build()); } literalBuilder.setMap(mapBuilder.build()); - return Expression.newBuilder().setLiteral(literalBuilder.build()).build(); - } - - /** - * Add substrait expression to AdvancedExtension. - * - * @param expr Substrait Expression. - * @return Substrait AdvancedExtension - */ - public static AdvancedExtension expressionToExtension(Expression expr) { - AdvancedExtension.Builder extensionBuilder = AdvancedExtension.newBuilder(); - Any.Builder builder = extensionBuilder.getEnhancementBuilder(); - builder.setValue(expr.toByteString()); - return extensionBuilder.build(); + return literalBuilder.build(); } }