Skip to content

Commit

Permalink
[FEAT] Support parquet RLE decoding for booleans (#3477)
Browse files Browse the repository at this point in the history
#3329 shows that we do not
currently support reading boolean values from parquet files when they
are RLE-encoded.

This PR adds support for this.
  • Loading branch information
desmondcheongzx authored Dec 4, 2024
1 parent de4fe50 commit d1d0fab
Show file tree
Hide file tree
Showing 3 changed files with 159 additions and 1 deletion.
38 changes: 37 additions & 1 deletion src/arrow2/src/io/parquet/read/deserialize/boolean/basic.rs
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,7 @@ use std::collections::VecDeque;

use parquet2::{
deserialize::SliceFilteredIter,
encoding::Encoding,
encoding::{hybrid_rle, Encoding},
page::{split_buffer, DataPage, DictPage},
schema::Repetition,
};
Expand Down Expand Up @@ -51,6 +51,25 @@ impl<'a> Required<'a> {
}
}

#[derive(Debug)]
struct ValuesRle<'a>(hybrid_rle::HybridRleDecoder<'a>);

impl<'a> ValuesRle<'a> {
pub fn try_new(page: &'a DataPage) -> Result<Self> {
let (_, _, indices_buffer) = split_buffer(page)?;
// Skip the u32 length prefix.
let indices_buffer = &indices_buffer[std::mem::size_of::<u32>()..];
let decoder =
hybrid_rle::HybridRleDecoder::try_new(
indices_buffer,
1_u32, // The bit width for a boolean is 1.
page.num_values()
)
.map_err(crate::error::Error::from)?;
Ok(Self(decoder))
}
}

#[derive(Debug)]
struct FilteredRequired<'a> {
values: SliceFilteredIter<BitmapIter<'a>>,
Expand Down Expand Up @@ -79,6 +98,7 @@ impl<'a> FilteredRequired<'a> {
/// Per-page decoding state for boolean columns.
///
/// Each variant pairs the page's validity information (when present) with an
/// iterator over its values. NOTE(review): the match that builds these keys on
/// a tuple that is presumably `(encoding, is_optional, is_filtered)` — confirm
/// against the full `build_state` implementation.
enum State<'a> {
/// Page validity plus `Values`. NOTE(review): presumably a plain-encoded
/// optional page without a row filter — the constructing arm is only
/// partially visible here.
Optional(OptionalPageValidity<'a>, Values<'a>),
/// Built for `(Encoding::Plain, false, false)` pages.
Required(Required<'a>),
/// Built for `(Encoding::Rle, true, false)` pages: page validity plus
/// hybrid-RLE encoded values.
OptionalRle(OptionalPageValidity<'a>, ValuesRle<'a>),
/// Wraps a `FilteredRequired`, whose values are a slice-filtered bitmap iterator.
FilteredRequired(FilteredRequired<'a>),
/// Built for `(Encoding::Plain, true, true)` pages: filtered page validity
/// plus `Values`.
FilteredOptional(FilteredOptionalPageValidity<'a>, Values<'a>),
}
Expand All @@ -88,6 +108,7 @@ impl<'a> State<'a> {
match self {
State::Optional(validity, _) => validity.len(),
State::Required(page) => page.length - page.offset,
Self::OptionalRle(validity, _) => validity.len(),
State::FilteredRequired(page) => page.len(),
State::FilteredOptional(optional, _) => optional.len(),
}
Expand Down Expand Up @@ -125,6 +146,12 @@ impl<'a> Decoder<'a> for BooleanDecoder {
Values::try_new(page)?,
)),
(Encoding::Plain, false, false) => Ok(State::Required(Required::new(page))),
(Encoding::Rle, true, false) => {
Ok(State::OptionalRle(
OptionalPageValidity::try_new(page)?,
ValuesRle::try_new(page)?,
))
},
(Encoding::Plain, true, true) => Ok(State::FilteredOptional(
FilteredOptionalPageValidity::try_new(page)?,
Values::try_new(page)?,
Expand Down Expand Up @@ -163,6 +190,15 @@ impl<'a> Decoder<'a> for BooleanDecoder {
values.extend_from_slice(page.values, page.offset, remaining);
page.offset += remaining;
}
State::OptionalRle(page_validity, page_values) => {
utils::extend_from_decoder(
validity,
page_validity,
Some(remaining),
values,
&mut page_values.0.by_ref().map(|x| x.unwrap()).map(|v| v != 0),
)
},
State::FilteredRequired(page) => {
values.reserve(remaining);
for item in page.values.by_ref().take(remaining) {
Expand Down
104 changes: 104 additions & 0 deletions src/parquet2/src/encoding/hybrid_rle/decoder.rs
Original file line number Diff line number Diff line change
Expand Up @@ -141,4 +141,108 @@ mod tests {
panic!()
};
}

/// A single bit-packed run of booleans is surfaced as `HybridEncoded::Bitpacked`
/// and unpacks to the expected 0/1 values.
#[test]
fn test_bool_bitpacked() {
    let values = [
        2, 0, 0, 0, // Length indicator as u32.
        0b00000011, // Header (1 << 1 | 1): the set low bit marks a bit-packed run.
        0b00001101, // Packed bits, least-significant first: true, false, true, true.
    ];
    let bit_width = 1usize;
    let length = 4;
    let expected = &[1, 0, 1, 1];

    // The decoder is handed the payload only, i.e. everything after the length prefix.
    let mut decoder = Decoder::new(&values[4..], bit_width);

    match decoder.next().unwrap() {
        Ok(HybridEncoded::Bitpacked(packed)) => {
            assert_eq!(packed, &[0b00001101]);

            let unpacked: Vec<_> = bitpacked::Decoder::<u8>::try_new(packed, bit_width, length)
                .unwrap()
                .collect();
            assert_eq!(unpacked, expected);
        }
        _ => panic!("Expected bitpacked encoding"),
    }
}

/// A single RLE run of booleans is surfaced as `HybridEncoded::Rle` carrying the
/// repeated value and the run length.
#[test]
fn test_bool_rle() {
    let values = [
        2, 0, 0, 0, // Length indicator as u32.
        0b00001000, // Header (4 << 1 | 0): the clear low bit marks an RLE run of 4.
        true as u8, // Value to repeat.
    ];
    let bit_width = 1usize;
    let length = 4;

    // Skip the length prefix; the decoder only sees the encoded runs.
    let mut decoder = Decoder::new(&values[4..], bit_width);

    match decoder.next().unwrap() {
        Ok(HybridEncoded::Rle(value, run_length)) => {
            assert_eq!(value, &[1u8]); // true encoded as 1.
            assert_eq!(run_length, length); // Repeated 4 times.
        }
        _ => panic!("Expected RLE encoding"),
    }
}

/// A bit-packed run followed by an RLE run is yielded as two consecutive
/// `HybridEncoded` items, in that order.
#[test]
fn test_bool_mixed_rle() {
    let values = [
        4, 0, 0, 0, // Length indicator as u32.
        0b00000011, // Header (1 << 1 | 1): bit-packed run.
        0b00001101, // Packed bits, least-significant first: true, false, true, true.
        0b00001000, // Header (4 << 1 | 0): RLE run of 4.
        false as u8, // RLE value.
    ];
    let bit_width = 1usize;

    let mut decoder = Decoder::new(&values[4..], bit_width);

    // First item: the bit-packed run.
    match decoder.next().unwrap() {
        Ok(HybridEncoded::Bitpacked(packed)) => assert_eq!(packed, &[0b00001101]),
        _ => panic!("Expected bitpacked encoding"),
    }

    // Second item: the RLE run.
    match decoder.next().unwrap() {
        Ok(HybridEncoded::Rle(value, run_length)) => {
            assert_eq!(value, &[0u8]); // false encoded as 0.
            assert_eq!(run_length, 4);
        }
        _ => panic!("Expected RLE encoding"),
    }
}

/// With nothing after the length prefix, the decoder yields no runs at all.
#[test]
fn test_bool_nothing_encoded() {
    // Length indicator only; the payload after it is empty.
    let values = [0, 0, 0, 0];
    let bit_width = 1usize;

    let payload = &values[4..];
    let mut decoder = Decoder::new(payload, bit_width);

    assert!(decoder.next().is_none());
}

/// A bit-packed header that announces more data than the buffer holds yields
/// the bytes that are available, after which the decoder is exhausted.
#[test]
fn test_bool_invalid_encoding() {
    let values = [
        2, 0, 0, 0, // Length indicator as u32.
        0b00000101, // Header (2 << 1 | 1): bit-packed run.
        true as u8, // Incomplete encoding (should have another u8).
    ];
    let bit_width = 1usize;

    let mut decoder = Decoder::new(&values[4..], bit_width);

    // Only the single available byte is handed back.
    match decoder.next().unwrap() {
        Ok(HybridEncoded::Bitpacked(packed)) => assert_eq!(packed, &[1u8]),
        _ => panic!("Expected bitpacked encoding"),
    }

    // Next call should return None since we've exhausted the buffer.
    assert!(decoder.next().is_none());
}
}
18 changes: 18 additions & 0 deletions tests/io/test_parquet_roundtrip.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,9 +2,11 @@

import datetime
import decimal
import random

import numpy as np
import pyarrow as pa
import pyarrow.parquet as papq
import pytest

import daft
Expand Down Expand Up @@ -150,6 +152,22 @@ def test_roundtrip_sparse_tensor_types(tmp_path, fixed_shape):
assert before.to_arrow() == after.to_arrow()


@pytest.mark.parametrize("has_none", [True, False])
def test_roundtrip_boolean_rle(tmp_path, has_none):
    """Round-trip a large random boolean column through a data-page-v2 Parquet file.

    The table is written with pyarrow and read back with daft. When ``has_none``
    is set, None is drawn with weight 10 out of 100.
    """
    if has_none:
        # True and False are equally likely; None shows up 10% of the time.
        population, weights = [True, False, None], [45, 45, 10]
    else:
        # Uniformly random True/False values.
        population, weights = [True, False], None
    random_bools = random.choices(population, weights=weights, k=1000_000)

    pa_original = pa.table({"bools": pa.array(random_bools, type=pa.bool_())})
    file_path = f"{tmp_path}/test.parquet"
    # Use data page version 2.0 which uses RLE encoding for booleans.
    papq.write_table(pa_original, file_path, data_page_version="2.0")

    roundtripped = daft.read_parquet(file_path).to_arrow()
    assert pa_original == roundtripped


# TODO: reading/writing:
# 1. Embedding type
# 2. Image type
Expand Down

0 comments on commit d1d0fab

Please sign in to comment.