diff --git a/tests/memory/test_udf_project.py b/tests/memory/test_udf_project.py index 7ac1bfee98..c67cda0e28 100644 --- a/tests/memory/test_udf_project.py +++ b/tests/memory/test_udf_project.py @@ -9,6 +9,15 @@ from tests.memory.utils import run_wrapper_build_partitions +def format_bytes(bytes_value): + """Format bytes into human readable string with appropriate unit.""" + for unit in ["B", "KB", "MB", "GB"]: + if bytes_value < 1024: + return f"{bytes_value:.2f} {unit}" + bytes_value /= 1024 + return f"{bytes_value:.2f} GB" + + @daft.udf(return_dtype=str) def to_arrow_identity(s): data = s.to_arrow() @@ -49,10 +58,37 @@ def to_pylist_identity_batched_arrow_return(s): to_pylist_identity_batched_arrow_return, ], ) -def test_string_identity_projection(udf): +def test_short_string_identity_projection(udf): + instructions = [Project(ExpressionsProjection([udf(daft.col("a"))]))] + inputs = [{"a": [str(uuid.uuid4()) for _ in range(62500)]}] + _, memray_file = run_wrapper_build_partitions(inputs, instructions) + stats = compute_statistics(memray_file) + + expected_peak_bytes = 100 + assert stats.peak_memory_allocated < expected_peak_bytes, ( + f"Peak memory ({format_bytes(stats.peak_memory_allocated)}) " + f"exceeded threshold ({format_bytes(expected_peak_bytes)})" + ) + + +@pytest.mark.parametrize( + "udf", + [ + to_arrow_identity, + to_pylist_identity, + to_arrow_identity_batched, + to_pylist_identity_batched, + to_pylist_identity_batched_arrow_return, + ], +) +def test_long_string_identity_projection(udf): instructions = [Project(ExpressionsProjection([udf(daft.col("a"))]))] inputs = [{"a": [str(uuid.uuid4()) for _ in range(625000)]}] _, memray_file = run_wrapper_build_partitions(inputs, instructions) stats = compute_statistics(memray_file) - assert stats.peak_memory_allocated < 100 + expected_peak_bytes = 100 + assert stats.peak_memory_allocated < expected_peak_bytes, ( + f"Peak memory ({format_bytes(stats.peak_memory_allocated)}) " + f"exceeded threshold ({format_bytes(expected_peak_bytes)})" + ) diff --git a/tests/memory/utils.py b/tests/memory/utils.py index bc92c3dc3c..a1f17f706d 100644 --- a/tests/memory/utils.py +++ b/tests/memory/utils.py @@ -1,3 +1,4 @@ +import logging import os import tempfile import uuid @@ -9,11 +10,16 @@ from daft.runners.ray_runner import build_partitions from daft.table import MicroPartition +logger = logging.getLogger(__name__) + def run_wrapper_build_partitions( input_partitions: list[dict], instructions: list[Instruction] ) -> tuple[list[MicroPartition], str]: inputs = [MicroPartition.from_pydict(p) for p in input_partitions] + + logger.info("Input total size: %s", sum(i.size_bytes() for i in inputs)) + tmpdir = tempfile.gettempdir() memray_path = os.path.join(tmpdir, f"memray-{uuid.uuid4()}.bin") with memray.Tracker(memray_path, native_traces=True, follow_fork=True):