From d1d0faba9a908db9db62cc85ea4319b8305ca4f6 Mon Sep 17 00:00:00 2001 From: Desmond Cheong Date: Wed, 4 Dec 2024 13:42:38 -0800 Subject: [PATCH] [FEAT] Support parquet RLE decoding for booleans (#3477) https://github.com/Eventual-Inc/Daft/issues/3329 shows that we do not currently support reading boolean values from parquet files when they are RLE-encoded. This PR adds support for this. --- .../parquet/read/deserialize/boolean/basic.rs | 38 ++++++- .../src/encoding/hybrid_rle/decoder.rs | 104 ++++++++++++++++++ tests/io/test_parquet_roundtrip.py | 18 +++ 3 files changed, 159 insertions(+), 1 deletion(-) diff --git a/src/arrow2/src/io/parquet/read/deserialize/boolean/basic.rs b/src/arrow2/src/io/parquet/read/deserialize/boolean/basic.rs index d12bff3ece..85224f29f2 100644 --- a/src/arrow2/src/io/parquet/read/deserialize/boolean/basic.rs +++ b/src/arrow2/src/io/parquet/read/deserialize/boolean/basic.rs @@ -2,7 +2,7 @@ use std::collections::VecDeque; use parquet2::{ deserialize::SliceFilteredIter, - encoding::Encoding, + encoding::{hybrid_rle, Encoding}, page::{split_buffer, DataPage, DictPage}, schema::Repetition, }; @@ -51,6 +51,25 @@ impl<'a> Required<'a> { } } +#[derive(Debug)] +struct ValuesRle<'a>(hybrid_rle::HybridRleDecoder<'a>); + +impl<'a> ValuesRle<'a> { + pub fn try_new(page: &'a DataPage) -> Result { + let (_, _, indices_buffer) = split_buffer(page)?; + // Skip the u32 length prefix. + let indices_buffer = &indices_buffer[std::mem::size_of::()..]; + let decoder = + hybrid_rle::HybridRleDecoder::try_new( + indices_buffer, + 1_u32, // The bit width for a boolean is 1. 
+ page.num_values() + ) + .map_err(crate::error::Error::from)?; + Ok(Self(decoder)) + } +} + #[derive(Debug)] struct FilteredRequired<'a> { values: SliceFilteredIter>, @@ -79,6 +98,7 @@ impl<'a> FilteredRequired<'a> { enum State<'a> { Optional(OptionalPageValidity<'a>, Values<'a>), Required(Required<'a>), + OptionalRle(OptionalPageValidity<'a>, ValuesRle<'a>), FilteredRequired(FilteredRequired<'a>), FilteredOptional(FilteredOptionalPageValidity<'a>, Values<'a>), } @@ -88,6 +108,7 @@ impl<'a> State<'a> { match self { State::Optional(validity, _) => validity.len(), State::Required(page) => page.length - page.offset, + Self::OptionalRle(validity, _) => validity.len(), State::FilteredRequired(page) => page.len(), State::FilteredOptional(optional, _) => optional.len(), } @@ -125,6 +146,12 @@ impl<'a> Decoder<'a> for BooleanDecoder { Values::try_new(page)?, )), (Encoding::Plain, false, false) => Ok(State::Required(Required::new(page))), + (Encoding::Rle, true, false) => { + Ok(State::OptionalRle( + OptionalPageValidity::try_new(page)?, + ValuesRle::try_new(page)?, + )) + }, (Encoding::Plain, true, true) => Ok(State::FilteredOptional( FilteredOptionalPageValidity::try_new(page)?, Values::try_new(page)?, @@ -163,6 +190,15 @@ impl<'a> Decoder<'a> for BooleanDecoder { values.extend_from_slice(page.values, page.offset, remaining); page.offset += remaining; } + State::OptionalRle(page_validity, page_values) => { + utils::extend_from_decoder( + validity, + page_validity, + Some(remaining), + values, + &mut page_values.0.by_ref().map(|x| x.unwrap()).map(|v| v != 0), + ) + }, State::FilteredRequired(page) => { values.reserve(remaining); for item in page.values.by_ref().take(remaining) { diff --git a/src/parquet2/src/encoding/hybrid_rle/decoder.rs b/src/parquet2/src/encoding/hybrid_rle/decoder.rs index 4417cc55a0..625f9df36e 100644 --- a/src/parquet2/src/encoding/hybrid_rle/decoder.rs +++ b/src/parquet2/src/encoding/hybrid_rle/decoder.rs @@ -141,4 +141,108 @@ mod tests { panic!() 
}; } + + #[test] + fn test_bool_bitpacked() { + let bit_width = 1usize; + let length = 4; + let values = [ + 2, 0, 0, 0, // Byte length of the encoded data, as a little-endian u32. + 0b00000011, // Bitpacked indicator: 1 group of 8 values (1 << 1 | 1). + 0b00001101, // Values, LSB first (true, false, true, true); only the first `length` bits are decoded. + ]; + let expected = &[1, 0, 1, 1]; + + let mut decoder = Decoder::new(&values[4..], bit_width); + + if let Ok(HybridEncoded::Bitpacked(values)) = decoder.next().unwrap() { + assert_eq!(values, &[0b00001101]); + + let result = bitpacked::Decoder::::try_new(values, bit_width, length) + .unwrap() + .collect::>(); + assert_eq!(result, expected); + } else { + panic!("Expected bitpacked encoding"); + } + } + + #[test] + fn test_bool_rle() { + let bit_width = 1usize; + let length = 4; + let values = [ + 2, 0, 0, 0, // Byte length of the encoded data, as a little-endian u32. + 0b00001000, // RLE indicator: run length 4 (4 << 1 | 0). + true as u8 // Value to repeat. + ]; + + let mut decoder = Decoder::new(&values[4..], bit_width); + + if let Ok(HybridEncoded::Rle(value, run_length)) = decoder.next().unwrap() { + assert_eq!(value, &[1u8]); // true encoded as 1. + assert_eq!(run_length, length); // Repeated 4 times. + } else { + panic!("Expected RLE encoding"); + } + } + + #[test] + fn test_bool_mixed_rle() { + let bit_width = 1usize; + let values = [ + 4, 0, 0, 0, // Byte length of the encoded data, as a little-endian u32. + 0b00000011, // Bitpacked indicator: 1 group of 8 values (1 << 1 | 1). + 0b00001101, // Values, LSB first (true, false, true, true, then four false bits). + 0b00001000, // RLE indicator: run length 4 (4 << 1 | 0). + false as u8 // Value to repeat in the RLE run. + ]; + + let mut decoder = Decoder::new(&values[4..], bit_width); + + // Decode bitpacked values. 
+ if let Ok(HybridEncoded::Bitpacked(values)) = decoder.next().unwrap() { + assert_eq!(values, &[0b00001101]); + } else { + panic!("Expected bitpacked encoding"); + } + + // Decode RLE values. + if let Ok(HybridEncoded::Rle(value, run_length)) = decoder.next().unwrap() { + assert_eq!(value, &[0u8]); // false encoded as 0. 
+ assert_eq!(run_length, 4); + } else { + panic!("Expected RLE encoding"); + } + } + + #[test] + fn test_bool_nothing_encoded() { + let bit_width = 1usize; + let values = [0, 0, 0, 0]; // Length indicator only: zero bytes of encoded data follow. + + let mut decoder = Decoder::new(&values[4..], bit_width); + assert!(decoder.next().is_none()); + } + + #[test] + fn test_bool_invalid_encoding() { + let bit_width = 1usize; + let values = [ + 2, 0, 0, 0, // Byte length of the encoded data, as a little-endian u32. + 0b00000101, // Bitpacked indicator: 2 groups of 8 values (2 << 1 | 1). + true as u8 // Incomplete encoding: the header announces 2 bytes of values but only 1 is present. + ]; + + let mut decoder = Decoder::new(&values[4..], bit_width); + + if let Ok(HybridEncoded::Bitpacked(values)) = decoder.next().unwrap() { + assert_eq!(values, &[1u8]); + } else { + panic!("Expected bitpacked encoding"); + } + + // Next call should return None since we've exhausted the buffer. + assert!(decoder.next().is_none()); + } } diff --git a/tests/io/test_parquet_roundtrip.py b/tests/io/test_parquet_roundtrip.py index 8804ef6d33..81d5689600 100644 --- a/tests/io/test_parquet_roundtrip.py +++ b/tests/io/test_parquet_roundtrip.py @@ -2,9 +2,11 @@ import datetime import decimal +import random import numpy as np import pyarrow as pa +import pyarrow.parquet as papq import pytest import daft @@ -150,6 +152,22 @@ def test_roundtrip_sparse_tensor_types(tmp_path, fixed_shape): assert before.to_arrow() == after.to_arrow() +@pytest.mark.parametrize("has_none", [True, False]) +def test_roundtrip_boolean_rle(tmp_path, has_none): + file_path = f"{tmp_path}/test.parquet" + if has_none: + # Create an array of random True/False values that are None 10% of the time. + random_bools = random.choices([True, False, None], weights=[45, 45, 10], k=1000_000) + else: + # Create an array of random True/False values. + random_bools = random.choices([True, False], k=1000_000) + pa_original = pa.table({"bools": pa.array(random_bools, type=pa.bool_())}) + # Use data page version 2.0 which uses RLE encoding for booleans. 
+ papq.write_table(pa_original, file_path, data_page_version="2.0") + df_roundtrip = daft.read_parquet(file_path) + assert pa_original == df_roundtrip.to_arrow() + + # TODO: reading/writing: # 1. Embedding type # 2. Image type