diff --git a/java/vector/src/main/java/org/apache/arrow/vector/DecimalVector.java b/java/vector/src/main/java/org/apache/arrow/vector/DecimalVector.java index 13f25774e218b..db1ad0b31b3a8 100644 --- a/java/vector/src/main/java/org/apache/arrow/vector/DecimalVector.java +++ b/java/vector/src/main/java/org/apache/arrow/vector/DecimalVector.java @@ -212,7 +212,6 @@ public void set(int index, ArrowBuf buffer) { * @param value array of bytes containing decimal in big endian byte order. */ public void setBigEndian(int index, byte[] value) { - assert value.length <= TYPE_WIDTH; BitVectorHelper.setValidityBitToOne(validityBuffer, index); final int length = value.length; int startIndex = index * TYPE_WIDTH; @@ -224,13 +223,32 @@ public void setBigEndian(int index, byte[] value) { valueBuffer.setByte(startIndex + 3, value[i-3]); startIndex += 4; } - } else { + + return; + } + + if (length == 0) { + valueBuffer.setZero(startIndex, TYPE_WIDTH); + return; + } + + if (length < 16) { for (int i = length - 1; i >= 0; i--) { valueBuffer.setByte(startIndex, value[i]); startIndex++; } - valueBuffer.setZero(startIndex, TYPE_WIDTH - length); + + final byte pad = (byte) (value[0] < 0 ? 0xFF : 0x00); + final int maxStartIndex = (index + 1) * TYPE_WIDTH; + while (startIndex < maxStartIndex) { + valueBuffer.setByte(startIndex, pad); + startIndex++; + } + + return; } + + throw new IllegalArgumentException("Invalid decimal value length. Valid length in [1 - 16], got " + length); } /** @@ -468,4 +486,4 @@ public void copyValueSafe(int fromIndex, int toIndex) { to.copyFromSafe(fromIndex, toIndex, DecimalVector.this); } } -} \ No newline at end of file +} diff --git a/java/vector/src/test/java/org/apache/arrow/vector/TestDecimalVector.java b/java/vector/src/test/java/org/apache/arrow/vector/TestDecimalVector.java index 8c86452fcc3bf..15c56ae2bc382 100644 --- a/java/vector/src/test/java/org/apache/arrow/vector/TestDecimalVector.java +++ b/java/vector/src/test/java/org/apache/arrow/vector/TestDecimalVector.java @@ -20,6 +20,7 @@ import static org.junit.Assert.assertEquals; import static org.junit.Assert.assertTrue; +import static org.junit.Assert.fail; import java.math.BigDecimal; import java.math.BigInteger; @@ -190,4 +191,57 @@ public void testBigDecimalReadWrite() { assertEquals(decimal8, decimalVector.getObject(7)); } } + + /** + * Test {@link DecimalVector#setBigEndian(int, byte[])} which takes BE layout input and stores in LE layout. + * Cases to cover: value given in byte array in different lengths in range [1-16] and negative values. + */ + @Test + public void decimalBE2LE() { + try (DecimalVector decimalVector = TestUtils.newVector(DecimalVector.class, "decimal", new ArrowType.Decimal(21, 2), allocator);) { + decimalVector.allocateNew(); + + BigInteger[] testBigInts = new BigInteger[] { + new BigInteger("0"), + new BigInteger("-1"), + new BigInteger("23"), + new BigInteger("234234"), + new BigInteger("-234234234"), + new BigInteger("234234234234"), + new BigInteger("-56345345345345"), + new BigInteger("29823462983462893462934679234653456345"), // converts to 16 byte array + new BigInteger("-3894572983475982374598324598234346536"), // converts to 16 byte array + new BigInteger("-345345"), + new BigInteger("754533") + }; + + int insertionIdx = 0; + insertionIdx++; // insert a null + for (BigInteger val : testBigInts) { + decimalVector.setBigEndian(insertionIdx++, val.toByteArray()); + } + insertionIdx++; // insert a null + // insert a zero length buffer + decimalVector.setBigEndian(insertionIdx++, new byte[0]); + + // Try inserting a buffer larger than 16bytes and expect a failure + try { + decimalVector.setBigEndian(insertionIdx, new byte[17]); + fail("above statement should have failed"); + } catch (IllegalArgumentException ex) { + assertTrue(ex.getMessage().equals("Invalid decimal value length. Valid length in [1 - 16], got 17")); + } + decimalVector.setValueCount(insertionIdx); + + // retrieve values and check if they are correct + int outputIdx = 0; + assertTrue(decimalVector.isNull(outputIdx++)); + for (BigInteger expected : testBigInts) { + final BigDecimal actual = decimalVector.getObject(outputIdx++); + assertEquals(expected, actual.unscaledValue()); + } + assertTrue(decimalVector.isNull(outputIdx++)); + assertEquals(BigInteger.valueOf(0), decimalVector.getObject(outputIdx).unscaledValue()); + } + } }