diff --git a/java/vector/src/main/codegen/templates/UnionVector.java b/java/vector/src/main/codegen/templates/UnionVector.java index ea79c5c2fba76..7724750a83d69 100644 --- a/java/vector/src/main/codegen/templates/UnionVector.java +++ b/java/vector/src/main/codegen/templates/UnionVector.java @@ -23,6 +23,7 @@ import org.apache.arrow.util.Preconditions; import org.apache.arrow.vector.BaseValueVector; import org.apache.arrow.vector.BitVectorHelper; +import org.apache.arrow.vector.DecimalVector; import org.apache.arrow.vector.FieldVector; import org.apache.arrow.vector.ValueVector; import org.apache.arrow.vector.complex.AbstractStructVector; @@ -278,7 +279,10 @@ public StructVector getStruct() { <#if minor.class?starts_with("Decimal") || is_timestamp_tz(minor.class) || minor.class == "Duration" || minor.class == "FixedSizeBinary"> public ${name}Vector get${name}Vector() { if (${uncappedName}Vector == null) { - throw new IllegalArgumentException("No ${name} present. Provide ArrowType argument to create a new vector"); + ${uncappedName}Vector = internalStruct.getChild(fieldName(MinorType.${name?upper_case}), ${name}Vector.class); + if (${uncappedName}Vector == null) { + throw new IllegalArgumentException("No ${name} present. Provide ArrowType argument to create a new vector"); + } } return ${uncappedName}Vector; } diff --git a/java/vector/src/test/java/org/apache/arrow/vector/complex/impl/TestPromotableWriter.java b/java/vector/src/test/java/org/apache/arrow/vector/complex/impl/TestPromotableWriter.java index 1068f7c030eb5..a68fde2b1ad4b 100644 --- a/java/vector/src/test/java/org/apache/arrow/vector/complex/impl/TestPromotableWriter.java +++ b/java/vector/src/test/java/org/apache/arrow/vector/complex/impl/TestPromotableWriter.java @@ -22,11 +22,13 @@ import static org.junit.Assert.assertNull; import static org.junit.jupiter.api.Assertions.assertThrows; +import java.math.BigDecimal; import java.nio.ByteBuffer; import java.nio.ByteOrder; import org.apache.arrow.memory.ArrowBuf; import org.apache.arrow.memory.BufferAllocator; +import org.apache.arrow.vector.DecimalVector; import org.apache.arrow.vector.DirtyRootAllocator; import org.apache.arrow.vector.complex.ListVector; import org.apache.arrow.vector.complex.NonNullableStructVector; @@ -35,14 +37,18 @@ import org.apache.arrow.vector.complex.writer.BaseWriter.StructWriter; import org.apache.arrow.vector.holders.DurationHolder; import org.apache.arrow.vector.holders.FixedSizeBinaryHolder; +import org.apache.arrow.vector.holders.NullableDecimalHolder; +import org.apache.arrow.vector.holders.NullableIntHolder; import org.apache.arrow.vector.holders.NullableTimeStampMilliTZHolder; import org.apache.arrow.vector.holders.TimeStampMilliTZHolder; +import org.apache.arrow.vector.holders.UnionHolder; import org.apache.arrow.vector.types.TimeUnit; import org.apache.arrow.vector.types.Types; import org.apache.arrow.vector.types.pojo.ArrowType; import org.apache.arrow.vector.types.pojo.ArrowType.ArrowTypeID; import org.apache.arrow.vector.types.pojo.Field; import org.apache.arrow.vector.types.pojo.FieldType; +import org.apache.arrow.vector.util.DecimalUtility; import org.junit.After; import org.junit.Before; import org.junit.Test; @@ -392,4 +398,41 @@ public void testNoPromoteFixedSizeBinaryToUnionWithNull() throws Exception { buf.close(); } } + + @Test + public void testPromoteToUnionFromDecimal() throws Exception { + try (final NonNullableStructVector container = NonNullableStructVector.empty(EMPTY_SCHEMA_PATH, allocator); + final DecimalVector v = container.addOrGet("dec", + FieldType.nullable(new ArrowType.Decimal(38, 1, 128)), DecimalVector.class); + final PromotableWriter writer = new PromotableWriter(v, container)) { + + container.allocateNew(); + container.setValueCount(1); + + writer.setPosition(0); + writer.writeDecimal(new BigDecimal("0.1")); + writer.setPosition(1); + writer.writeInt(1); + + container.setValueCount(3); + + UnionVector unionVector = (UnionVector) container.getChild("dec"); + UnionHolder holder = new UnionHolder(); + + unionVector.get(0, holder); + NullableDecimalHolder decimalHolder = new NullableDecimalHolder(); + holder.reader.read(decimalHolder); + + assertEquals(1, decimalHolder.isSet); + assertEquals(new BigDecimal("0.1"), + DecimalUtility.getBigDecimalFromArrowBuf(decimalHolder.buffer, 0, decimalHolder.scale, 128)); + + unionVector.get(1, holder); + NullableIntHolder intHolder = new NullableIntHolder(); + holder.reader.read(intHolder); + + assertEquals(1, intHolder.isSet); + assertEquals(1, intHolder.value); + } + } }