diff --git a/.github/assets/.template.yaml b/.github/assets/template.yaml similarity index 95% rename from .github/assets/.template.yaml rename to .github/assets/template.yaml index 726db29c85..d631a34152 100644 --- a/.github/assets/.template.yaml +++ b/.github/assets/template.yaml @@ -48,4 +48,4 @@ setup_commands: - uv v - echo "source $HOME/.venv/bin/activate" >> $HOME/.bashrc - source .venv/bin/activate -- uv pip install pip ray[default] py-spy \{{DAFT_INSTALL}} +- uv pip install pip ray[default] py-spy \{{DAFT_INSTALL}} \{{OTHER_INSTALLS}} diff --git a/.github/ci-scripts/read_inline_metadata.py b/.github/ci-scripts/read_inline_metadata.py new file mode 100644 index 0000000000..1ffff3d110 --- /dev/null +++ b/.github/ci-scripts/read_inline_metadata.py @@ -0,0 +1,25 @@ +# /// script +# requires-python = ">=3.12" +# dependencies = [] +# /// + +import re + +import tomllib + +REGEX = r"(?m)^# /// (?P[a-zA-Z0-9-]+)$\s(?P(^#(| .*)$\s)+)^# ///$" + + +def read(script: str) -> dict | None: + name = "script" + matches = list(filter(lambda m: m.group("type") == name, re.finditer(REGEX, script))) + if len(matches) > 1: + raise ValueError(f"Multiple {name} blocks found") + elif len(matches) == 1: + content = "".join( + line[2:] if line.startswith("# ") else line[1:] + for line in matches[0].group("content").splitlines(keepends=True) + ) + return tomllib.loads(content) + else: + return None diff --git a/.github/ci-scripts/templatize_ray_config.py b/.github/ci-scripts/templatize_ray_config.py index 887fe0f786..2a444c7eec 100644 --- a/.github/ci-scripts/templatize_ray_config.py +++ b/.github/ci-scripts/templatize_ray_config.py @@ -1,10 +1,20 @@ +# /// script +# requires-python = ">=3.12" +# dependencies = ['pydantic'] +# /// + import sys from argparse import ArgumentParser from dataclasses import dataclass +from pathlib import Path from typing import Optional +import read_inline_metadata +from pydantic import BaseModel, Field + CLUSTER_NAME_PLACEHOLDER = "\\{{CLUSTER_NAME}}" DAFT_INSTALL_PLACEHOLDER = "\\{{DAFT_INSTALL}}" +OTHER_INSTALL_PLACEHOLDER = "\{{OTHER_INSTALLS}}" PYTHON_VERSION_PLACEHOLDER = "\\{{PYTHON_VERSION}}" CLUSTER_PROFILE__NODE_COUNT = "\\{{CLUSTER_PROFILE/node_count}}" CLUSTER_PROFILE__INSTANCE_TYPE = "\\{{CLUSTER_PROFILE/instance_type}}" @@ -12,6 +22,8 @@ CLUSTER_PROFILE__SSH_USER = "\\{{CLUSTER_PROFILE/ssh_user}}" CLUSTER_PROFILE__VOLUME_MOUNT = "\\{{CLUSTER_PROFILE/volume_mount}}" +NOOP_STEP = "echo 'noop step; skipping'" + @dataclass class Profile: @@ -22,6 +34,11 @@ class Profile: volume_mount: Optional[str] = None +class Metadata(BaseModel, extra="allow"): + dependencies: list[str] = Field(default_factory=list) + env: dict[str, str] = Field(default_factory=dict) + + profiles: dict[str, Optional[Profile]] = { "debug_xs-x86": Profile( instance_type="t3.large", @@ -50,15 +67,16 @@ class Profile: content = sys.stdin.read() parser = ArgumentParser() - parser.add_argument("--cluster-name") + parser.add_argument("--cluster-name", required=True) parser.add_argument("--daft-wheel-url") parser.add_argument("--daft-version") - parser.add_argument("--python-version") - parser.add_argument("--cluster-profile") + parser.add_argument("--python-version", required=True) + parser.add_argument("--cluster-profile", required=True, choices=["debug_xs-x86", "medium-x86"]) + parser.add_argument("--working-dir", required=True) + parser.add_argument("--entrypoint-script", required=True) args = parser.parse_args() - if args.cluster_name: - content = content.replace(CLUSTER_NAME_PLACEHOLDER, args.cluster_name) + content = content.replace(CLUSTER_NAME_PLACEHOLDER, args.cluster_name) if args.daft_wheel_url and args.daft_version: raise ValueError( @@ -72,26 +90,24 @@ class Profile: daft_install = "getdaft" content = content.replace(DAFT_INSTALL_PLACEHOLDER, daft_install) - if args.python_version: - content = content.replace(PYTHON_VERSION_PLACEHOLDER, args.python_version) - - if cluster_profile := args.cluster_profile: - cluster_profile: str - if cluster_profile not in profiles: - raise Exception(f'Cluster profile "{cluster_profile}" not found') - - profile = profiles[cluster_profile] - if profile is None: - raise Exception(f'Cluster profile "{cluster_profile}" not yet implemented') - - assert profile is not None - content = content.replace(CLUSTER_PROFILE__NODE_COUNT, str(profile.node_count)) - content = content.replace(CLUSTER_PROFILE__INSTANCE_TYPE, profile.instance_type) - content = content.replace(CLUSTER_PROFILE__IMAGE_ID, profile.image_id) - content = content.replace(CLUSTER_PROFILE__SSH_USER, profile.ssh_user) - if profile.volume_mount: - content = content.replace(CLUSTER_PROFILE__VOLUME_MOUNT, profile.volume_mount) - else: - content = content.replace(CLUSTER_PROFILE__VOLUME_MOUNT, "echo 'Nothing to mount; skipping'") + content = content.replace(PYTHON_VERSION_PLACEHOLDER, args.python_version) + + profile = profiles[args.cluster_profile] + content = content.replace(CLUSTER_PROFILE__NODE_COUNT, str(profile.node_count)) + content = content.replace(CLUSTER_PROFILE__INSTANCE_TYPE, profile.instance_type) + content = content.replace(CLUSTER_PROFILE__IMAGE_ID, profile.image_id) + content = content.replace(CLUSTER_PROFILE__SSH_USER, profile.ssh_user) + content = content.replace( + CLUSTER_PROFILE__VOLUME_MOUNT, profile.volume_mount if profile.volume_mount else NOOP_STEP + ) + + working_dir = Path(args.working_dir) + assert working_dir.exists() and working_dir.is_dir() + entrypoint_script_fullpath: Path = working_dir / args.entrypoint_script + assert entrypoint_script_fullpath.exists() and entrypoint_script_fullpath.is_file() + with open(entrypoint_script_fullpath) as f: + metadata = Metadata(**read_inline_metadata.read(f.read())) + + content = content.replace(OTHER_INSTALL_PLACEHOLDER, " ".join(metadata.dependencies)) print(content) diff --git a/.github/workflows/run-cluster.yaml b/.github/workflows/run-cluster.yaml index a47f8aec9d..903706b8cc 100644 --- a/.github/workflows/run-cluster.yaml +++ b/.github/workflows/run-cluster.yaml @@ -4,35 +4,40 @@ on: workflow_dispatch: inputs: daft_wheel_url: + description: Daft python-wheel URL type: string - description: A public https url pointing directly to a daft python-wheel to install required: false daft_version: + description: Daft version (errors if both this and "Daft python-wheel URL" are provided) type: string - description: A released version of daft on PyPi to install (errors if both this and `daft_wheel_url` are provided) required: false python_version: + description: Python version type: string - description: The version of python to use required: false default: "3.9" cluster_profile: + description: Cluster profile type: choice options: - medium-x86 - debug_xs-x86 - description: The profile to use for the cluster required: false default: medium-x86 - command: - type: string - description: The command to run on the cluster - required: true working_dir: + description: Working directory type: string - description: The working directory to submit to the cluster required: false default: .github/working-dir + entrypoint_script: + description: Entry-point python script (must be inside of the working directory) + type: string + required: true + entrypoint_args: + description: Entry-point arguments + type: string + required: false + default: "" jobs: run-command: @@ -42,6 +47,8 @@ jobs: id-token: write contents: read steps: + - name: Log workflow inputs + run: echo "${{ toJson(github.event.inputs) }}" - name: Checkout repo uses: actions/checkout@v4 with: @@ -63,13 +70,17 @@ jobs: - name: Dynamically update ray config file run: | source .venv/bin/activate - (cat .github/assets/.template.yaml \ - | python .github/ci-scripts/templatize_ray_config.py \ - --cluster-name "ray-ci-run-${{ github.run_id }}_${{ github.run_attempt }}" \ - --daft-wheel-url '${{ inputs.daft_wheel_url }}' \ - --daft-version '${{ inputs.daft_version }}' \ - --python-version '${{ inputs.python_version }}' \ - --cluster-profile '${{ inputs.cluster_profile }}' + (cat .github/assets/template.yaml | \ + uv run \ + --python 3.12 \ + .github/ci-scripts/templatize_ray_config.py \ + --cluster-name "ray-ci-run-${{ github.run_id }}_${{ github.run_attempt }}" \ + --daft-wheel-url '${{ inputs.daft_wheel_url }}' \ + --daft-version '${{ inputs.daft_version }}' \ + --python-version '${{ inputs.python_version }}' \ + --cluster-profile '${{ inputs.cluster_profile }}' \ + --working-dir '${{ inputs.working_dir }}' \ + --entrypoint-script '${{ inputs.entrypoint_script }}' ) >> .github/assets/ray.yaml cat .github/assets/ray.yaml - name: Download private ssh key @@ -88,7 +99,7 @@ jobs: - name: Submit job to ray cluster run: | source .venv/bin/activate - if [[ -z '${{ inputs.command }}' ]]; then + if [[ -z '${{ inputs.entrypoint_script }}' ]]; then echo 'Invalid command submitted; command cannot be empty' exit 1 fi @@ -96,7 +107,7 @@ jobs: --working-dir ${{ inputs.working_dir }} \ --address http://localhost:8265 \ --runtime-env-json '{"env_vars": {"DAFT_ENABLE_RAY_TRACING": "1"}}' \ - -- ${{ inputs.command }} + -- python ${{ inputs.entrypoint_script }} ${{ inputs.entrypoint_args }} - name: Download log files from ray cluster run: | source .venv/bin/activate diff --git a/.github/working-dir/shuffle_testing.py b/.github/working-dir/shuffle_testing.py new file mode 100644 index 0000000000..18b7c76307 --- /dev/null +++ b/.github/working-dir/shuffle_testing.py @@ -0,0 +1,257 @@ +# /// script +# dependencies = ['numpy'] +# /// + +import argparse +import random +import time +from functools import partial +from typing import Any, Dict + +import numpy as np +import pyarrow as pa + +import daft +import daft.context +from daft.io._generator import read_generator +from daft.table.table import Table + +# Constants +GB = 1 << 30 +MB = 1 << 20 +KB = 1 << 10 +ROW_SIZE = 100 * KB + + +def parse_size(size_str: str) -> int: + """Convert human-readable size string to bytes.""" + units = {"B": 1, "KB": KB, "MB": MB, "GB": GB} + size_str = size_str.upper() + value, unit = size_str.split(" ") + return int(float(value) * units[unit]) + + +def get_skewed_distribution(num_partitions: int, skew_factor: float) -> np.ndarray: + """ + Generate a skewed distribution using a power law. + Higher skew_factor means more skewed distribution. + """ + if skew_factor <= 0: + return np.ones(num_partitions) / num_partitions + + # Generate power law distribution + x = np.arange(1, num_partitions + 1) + weights = 1.0 / (x**skew_factor) + return weights / weights.sum() + + +def get_partition_size(base_size: int, size_variation: float, partition_idx: int) -> int: + """ + Calculate size for a specific partition with variation. + + Args: + base_size: The base partition size in bytes + size_variation: Float between 0 and 1 indicating maximum variation (e.g., 0.2 = ±20%) + partition_idx: Index of the partition (used for random seed) + + Returns: + Adjusted partition size in bytes + """ + if size_variation <= 0: + return base_size + + # Use partition_idx as seed for consistent variation per partition + random.seed(f"size_{partition_idx}") + + # Generate variation factor between (1-variation) and (1+variation) + variation_factor = 1.0 + random.uniform(-size_variation, size_variation) + + # Ensure we don't go below 10% of base size + min_size = base_size * 0.1 + return max(int(base_size * variation_factor), int(min_size)) + + +def generate( + num_partitions: int, + base_partition_size: int, + skew_factor: float, + timing_variation: float, + size_variation: float, + partition_idx: int, +): + """Generate data for a single partition with optional skew, timing and size variations.""" + + # Calculate actual partition size with variation + actual_partition_size = get_partition_size(base_partition_size, size_variation, partition_idx) + num_rows = actual_partition_size // ROW_SIZE + + # Apply skewed distribution if specified + if skew_factor > 0: + weights = get_skewed_distribution(num_partitions, skew_factor) + data = { + "ints": pa.array( + np.random.choice(num_partitions, size=num_rows, p=weights, replace=True).astype(np.uint64), + ), + "bytes": pa.array( + [np.random.bytes(ROW_SIZE) for _ in range(num_rows)], + type=pa.binary(ROW_SIZE), + ), + } + else: + data = { + "ints": pa.array(np.random.randint(0, num_partitions, size=num_rows)), + "bytes": pa.array( + [np.random.bytes(ROW_SIZE) for _ in range(num_rows)], + type=pa.binary(ROW_SIZE), + ), + } + + # Simulate varying processing times if specified + if timing_variation > 0: + random.seed(f"timing_{partition_idx}") + delay = random.uniform(0, timing_variation) + time.sleep(delay) + + yield Table.from_pydict(data) + + +def generator( + num_partitions: int, + partition_size: int, + skew_factor: float, + timing_variation: float, + size_variation: float, +): + """Generate data for all partitions.""" + for i in range(num_partitions): + yield partial( + generate, + num_partitions, + partition_size, + skew_factor, + timing_variation, + size_variation, + i, + ) + + +def setup_daft(shuffle_algorithm: str = None): + """Configure Daft execution settings.""" + daft.context.set_runner_ray() + daft.context.set_execution_config(shuffle_algorithm=shuffle_algorithm, pre_shuffle_merge_threshold=8 * GB) + + +def create_schema(): + """Create the Daft schema for the dataset.""" + return daft.Schema._from_field_name_and_types([("ints", daft.DataType.uint64()), ("bytes", daft.DataType.binary())]) + + +def run_benchmark( + num_partitions: int, + partition_size: int, + skew_factor: float, + timing_variation: float, + size_variation: float, + shuffle_algorithm: str = None, +) -> Dict[str, Any]: + """Run the memory benchmark and return statistics.""" + setup_daft(shuffle_algorithm) + schema = create_schema() + + def benchmark_func(): + return ( + read_generator( + generator( + num_partitions, + partition_size, + skew_factor, + timing_variation, + size_variation, + ), + schema, + ) + .repartition(num_partitions, "ints") + .collect() + ) + + start_time = time.time() + + benchmark_func() + + end_time = time.time() + return end_time - start_time + + +def main(): + parser = argparse.ArgumentParser(description="Run memory benchmark for data processing") + parser.add_argument("--partitions", type=int, default=1000, help="Number of partitions") + parser.add_argument( + "--partition-size", + type=str, + default="100 MB", + help="Base size for each partition (e.g., 300 MB, 1 GB)", + ) + parser.add_argument( + "--skew-factor", + type=float, + default=0.0, + help="Skew factor for partition distribution (0.0 for uniform, higher for more skew)", + ) + parser.add_argument( + "--timing-variation", + type=float, + default=0.0, + help="Maximum random delay in seconds for partition processing", + ) + parser.add_argument( + "--size-variation", + type=float, + default=0.0, + help="Maximum partition size variation as fraction (0.0-1.0, e.g., 0.2 for ±20%%)", + ) + parser.add_argument("--shuffle-algorithm", type=str, default=None, help="Shuffle algorithm to use") + + args = parser.parse_args() + + if not 0 <= args.size_variation <= 1: + parser.error("Size variation must be between 0 and 1") + + partition_size_bytes = parse_size(args.partition_size) + + print("Running benchmark with configuration:") + print(f"Partitions: {args.partitions}") + print(f"Base partition size: {args.partition_size} ({partition_size_bytes} bytes)") + print(f"Size variation: ±{args.size_variation*100:.0f}%") + print(f"Row size: {ROW_SIZE/KB:.0f}KB (fixed)") + print(f"Skew factor: {args.skew_factor}") + print(f"Timing variation: {args.timing_variation}s") + print(f"Shuffle algorithm: {args.shuffle_algorithm or 'default'}") + + try: + timing = run_benchmark( + num_partitions=args.partitions, + partition_size=partition_size_bytes, + skew_factor=args.skew_factor, + timing_variation=args.timing_variation, + size_variation=args.size_variation, + shuffle_algorithm=args.shuffle_algorithm, + ) + + print("\nRan benchmark with configuration:") + print(f"Partitions: {args.partitions}") + print(f"Base partition size: {args.partition_size} ({partition_size_bytes} bytes)") + print(f"Size variation: ±{args.size_variation*100:.0f}%") + print(f"Row size: {ROW_SIZE/KB:.0f}KB (fixed)") + print(f"Skew factor: {args.skew_factor}") + print(f"Timing variation: {args.timing_variation}s") + print(f"Shuffle algorithm: {args.shuffle_algorithm or 'default'}") + print("\nResults:") + print(f"Total time: {timing:.2f}s") + + except Exception as e: + print(f"Error running benchmark: {str(e)}") + raise + + +if __name__ == "__main__": + main() diff --git a/.github/working-dir/uv_run_script_example.py b/.github/working-dir/uv_run_script_example.py new file mode 100644 index 0000000000..0d7611736e --- /dev/null +++ b/.github/working-dir/uv_run_script_example.py @@ -0,0 +1,3 @@ +import daft + +print(daft) diff --git a/src/arrow2/src/io/parquet/read/deserialize/boolean/basic.rs b/src/arrow2/src/io/parquet/read/deserialize/boolean/basic.rs index d12bff3ece..85224f29f2 100644 --- a/src/arrow2/src/io/parquet/read/deserialize/boolean/basic.rs +++ b/src/arrow2/src/io/parquet/read/deserialize/boolean/basic.rs @@ -2,7 +2,7 @@ use std::collections::VecDeque; use parquet2::{ deserialize::SliceFilteredIter, - encoding::Encoding, + encoding::{hybrid_rle, Encoding}, page::{split_buffer, DataPage, DictPage}, schema::Repetition, }; @@ -51,6 +51,25 @@ impl<'a> Required<'a> { } } +#[derive(Debug)] +struct ValuesRle<'a>(hybrid_rle::HybridRleDecoder<'a>); + +impl<'a> ValuesRle<'a> { + pub fn try_new(page: &'a DataPage) -> Result { + let (_, _, indices_buffer) = split_buffer(page)?; + // Skip the u32 length prefix. + let indices_buffer = &indices_buffer[std::mem::size_of::()..]; + let decoder = + hybrid_rle::HybridRleDecoder::try_new( + indices_buffer, + 1_u32, // The bit width for a boolean is 1. + page.num_values() + ) + .map_err(crate::error::Error::from)?; + Ok(Self(decoder)) + } +} + #[derive(Debug)] struct FilteredRequired<'a> { values: SliceFilteredIter>, @@ -79,6 +98,7 @@ impl<'a> FilteredRequired<'a> { enum State<'a> { Optional(OptionalPageValidity<'a>, Values<'a>), Required(Required<'a>), + OptionalRle(OptionalPageValidity<'a>, ValuesRle<'a>), FilteredRequired(FilteredRequired<'a>), FilteredOptional(FilteredOptionalPageValidity<'a>, Values<'a>), } @@ -88,6 +108,7 @@ impl<'a> State<'a> { match self { State::Optional(validity, _) => validity.len(), State::Required(page) => page.length - page.offset, + Self::OptionalRle(validity, _) => validity.len(), State::FilteredRequired(page) => page.len(), State::FilteredOptional(optional, _) => optional.len(), } @@ -125,6 +146,12 @@ impl<'a> Decoder<'a> for BooleanDecoder { Values::try_new(page)?, )), (Encoding::Plain, false, false) => Ok(State::Required(Required::new(page))), + (Encoding::Rle, true, false) => { + Ok(State::OptionalRle( + OptionalPageValidity::try_new(page)?, + ValuesRle::try_new(page)?, + )) + }, (Encoding::Plain, true, true) => Ok(State::FilteredOptional( FilteredOptionalPageValidity::try_new(page)?, Values::try_new(page)?, @@ -163,6 +190,15 @@ impl<'a> Decoder<'a> for BooleanDecoder { values.extend_from_slice(page.values, page.offset, remaining); page.offset += remaining; } + State::OptionalRle(page_validity, page_values) => { + utils::extend_from_decoder( + validity, + page_validity, + Some(remaining), + values, + &mut page_values.0.by_ref().map(|x| x.unwrap()).map(|v| v != 0), + ) + }, State::FilteredRequired(page) => { values.reserve(remaining); for item in page.values.by_ref().take(remaining) { diff --git a/src/parquet2/src/encoding/hybrid_rle/decoder.rs b/src/parquet2/src/encoding/hybrid_rle/decoder.rs index 4417cc55a0..625f9df36e 100644 --- a/src/parquet2/src/encoding/hybrid_rle/decoder.rs +++ b/src/parquet2/src/encoding/hybrid_rle/decoder.rs @@ -141,4 +141,108 @@ mod tests { panic!() }; } + + #[test] + fn test_bool_bitpacked() { + let bit_width = 1usize; + let length = 4; + let values = [ + 2, 0, 0, 0, // Length indicator as u32. + 0b00000011, // Bitpacked indicator with 1 value (1 << 1 | 1). + 0b00001101, // Values (true, false, true, true). + ]; + let expected = &[1, 0, 1, 1]; + + let mut decoder = Decoder::new(&values[4..], bit_width); + + if let Ok(HybridEncoded::Bitpacked(values)) = decoder.next().unwrap() { + assert_eq!(values, &[0b00001101]); + + let result = bitpacked::Decoder::::try_new(values, bit_width, length) + .unwrap() + .collect::>(); + assert_eq!(result, expected); + } else { + panic!("Expected bitpacked encoding"); + } + } + + #[test] + fn test_bool_rle() { + let bit_width = 1usize; + let length = 4; + let values = [ + 2, 0, 0, 0, // Length indicator as u32. + 0b00001000, // RLE indicator (4 << 1 | 0). + true as u8 // Value to repeat. + ]; + + let mut decoder = Decoder::new(&values[4..], bit_width); + + if let Ok(HybridEncoded::Rle(value, run_length)) = decoder.next().unwrap() { + assert_eq!(value, &[1u8]); // true encoded as 1. + assert_eq!(run_length, length); // Repeated 4 times. + } else { + panic!("Expected RLE encoding"); + } + } + + #[test] + fn test_bool_mixed_rle() { + let bit_width = 1usize; + let values = [ + 4, 0, 0, 0, // Length indicator as u32. + 0b00000011, // Bitpacked indicator with 1 value (1 << 1 | 1). + 0b00001101, // Values (true, false, true, true). + 0b00001000, // RLE indicator (4 << 1 | 0) + false as u8 // RLE value + ]; + + let mut decoder = Decoder::new(&values[4..], bit_width); + + // Decode bitpacked values. + if let Ok(HybridEncoded::Bitpacked(values)) = decoder.next().unwrap() { + assert_eq!(values, &[0b00001101]); + } else { + panic!("Expected bitpacked encoding"); + } + + // Decode RLE values. + if let Ok(HybridEncoded::Rle(value, run_length)) = decoder.next().unwrap() { + assert_eq!(value, &[0u8]); // false encoded as 0. + assert_eq!(run_length, 4); + } else { + panic!("Expected RLE encoding"); + } + } + + #[test] + fn test_bool_nothing_encoded() { + let bit_width = 1usize; + let values = [0, 0, 0, 0]; // Length indicator only. + + let mut decoder = Decoder::new(&values[4..], bit_width); + assert!(decoder.next().is_none()); + } + + #[test] + fn test_bool_invalid_encoding() { + let bit_width = 1usize; + let values = [ + 2, 0, 0, 0, // Length indicator as u32. + 0b00000101, // Bitpacked indicator with 1 value (2 << 1 | 1). + true as u8 // Incomplete encoding (should have another u8). + ]; + + let mut decoder = Decoder::new(&values[4..], bit_width); + + if let Ok(HybridEncoded::Bitpacked(values)) = decoder.next().unwrap() { + assert_eq!(values, &[1u8]); + } else { + panic!("Expected bitpacked encoding"); + } + + // Next call should return None since we've exhausted the buffer. + assert!(decoder.next().is_none()); + } } diff --git a/tests/io/test_parquet_roundtrip.py b/tests/io/test_parquet_roundtrip.py index 8804ef6d33..81d5689600 100644 --- a/tests/io/test_parquet_roundtrip.py +++ b/tests/io/test_parquet_roundtrip.py @@ -2,9 +2,11 @@ import datetime import decimal +import random import numpy as np import pyarrow as pa +import pyarrow.parquet as papq import pytest import daft @@ -150,6 +152,22 @@ def test_roundtrip_sparse_tensor_types(tmp_path, fixed_shape): assert before.to_arrow() == after.to_arrow() +@pytest.mark.parametrize("has_none", [True, False]) +def test_roundtrip_boolean_rle(tmp_path, has_none): + file_path = f"{tmp_path}/test.parquet" + if has_none: + # Create an array of random True/False values that are None 10% of the time. + random_bools = random.choices([True, False, None], weights=[45, 45, 10], k=1000_000) + else: + # Create an array of random True/False values. + random_bools = random.choices([True, False], k=1000_000) + pa_original = pa.table({"bools": pa.array(random_bools, type=pa.bool_())}) + # Use data page version 2.0 which uses RLE encoding for booleans. + papq.write_table(pa_original, file_path, data_page_version="2.0") + df_roundtrip = daft.read_parquet(file_path) + assert pa_original == df_roundtrip.to_arrow() + + # TODO: reading/writing: # 1. Embedding type # 2. Image type