diff --git a/deltacat/compute/compactor_v2/utils/primary_key_index.py b/deltacat/compute/compactor_v2/utils/primary_key_index.py
index be560cac..f80eeb27 100644
--- a/deltacat/compute/compactor_v2/utils/primary_key_index.py
+++ b/deltacat/compute/compactor_v2/utils/primary_key_index.py
@@ -17,6 +17,7 @@
 from deltacat.compute.compactor.utils import system_columns as sc
 from deltacat.io.object_store import IObjectStore
 from deltacat.utils.performance import timed_invocation
+from deltacat.utils.pyarrow import sliced_string_cast
 
 logger = logs.configure_deltacat_logger(logging.getLogger(__name__))
 
@@ -182,7 +183,7 @@ def generate_pk_hash_column(
     def _generate_pk_hash(table: pa.Table) -> pa.Array:
         pk_columns = []
         for pk_name in primary_keys:
-            pk_columns.append(pc.cast(table[pk_name], pa.string()))
+            pk_columns.append(sliced_string_cast(table[pk_name]))
             pk_columns.append(PK_DELIMITER)
 
         hash_column = pc.binary_join_element_wise(*pk_columns)
diff --git a/deltacat/utils/pyarrow.py b/deltacat/utils/pyarrow.py
index 10773bf7..c7babdf2 100644
--- a/deltacat/utils/pyarrow.py
+++ b/deltacat/utils/pyarrow.py
@@ -11,6 +11,8 @@
 from deltacat.exceptions import ValidationError
 
 import pyarrow as pa
+import numpy as np
+import pyarrow.compute as pc
 from fsspec import AbstractFileSystem
 from pyarrow import csv as pacsv
 from pyarrow import feather as paf
@@ -38,6 +40,7 @@
     sanitize_kwargs_to_callable,
     sanitize_kwargs_by_supported_kwargs,
 )
+from functools import lru_cache
 
 logger = logs.configure_deltacat_logger(logging.getLogger(__name__))
 
@@ -738,3 +741,69 @@ def clear_remaining(self) -> None:
         """
         self._remaining_tables.clear()
         self._remaining_record_count = 0
+
+
+@lru_cache(maxsize=1)
+def _int_max_string_len() -> int:
+    PA_UINT64_MAX_STR_BYTES = pc.binary_length(
+        pc.cast(pa.scalar(2**64 - 1, type=pa.uint64()), pa.string())
+    ).as_py()
+    PA_INT64_MAX_STR_BYTES = pc.binary_length(
+        pc.cast(pa.scalar(-(2**63), type=pa.int64()), pa.string())
+    ).as_py()
+    return max(PA_UINT64_MAX_STR_BYTES, PA_INT64_MAX_STR_BYTES)
+
+
+@lru_cache(maxsize=1)
+def _float_max_string_len() -> int:
+    PA_POS_FLOAT64_MAX_STR_BYTES = pc.binary_length(
+        pc.cast(pa.scalar(np.finfo(np.float64).max, type=pa.float64()), pa.string())
+    ).as_py()
+    PA_NEG_FLOAT64_MAX_STR_BYTES = pc.binary_length(
+        pc.cast(pa.scalar(np.finfo(np.float64).min, type=pa.float64()), pa.string())
+    ).as_py()
+    return max(PA_POS_FLOAT64_MAX_STR_BYTES, PA_NEG_FLOAT64_MAX_STR_BYTES)
+
+
+def _max_decimal128_string_len() -> int:
+    return 40  # "-" + 38 digits + decimal point
+
+
+def _max_decimal256_string_len() -> int:
+    return 78  # "-" + 76 digits + decimal point
+
+
+def sliced_string_cast(array: pa.ChunkedArray) -> pa.ChunkedArray:
+    """Slices a pyarrow ChunkedArray into smaller chunks prior to casting to a string.
+    This prevents pyarrow from allocating a string array too large for its 32-bit offsets.
+    Issue: https://github.com/apache/arrow/issues/38835
+    TODO: deprecate this function when pyarrow performs proper ChunkedArray -> ChunkedArray casting
+    """
+    dtype = array.type
+    MAX_BYTES = 2147483646
+    max_str_len = None
+    if pa.types.is_integer(dtype):
+        max_str_len = _int_max_string_len()
+    elif pa.types.is_floating(dtype):
+        max_str_len = _float_max_string_len()
+    elif pa.types.is_decimal128(dtype):
+        max_str_len = _max_decimal128_string_len()
+    elif pa.types.is_decimal256(dtype):
+        max_str_len = _max_decimal256_string_len()
+
+    if max_str_len is not None:
+        max_elems_per_chunk = MAX_BYTES // (2 * max_str_len)  # safety factor of 2
+        all_chunks = []
+        for chunk in array.chunks:
+            if len(chunk) < max_elems_per_chunk:
+                all_chunks.append(chunk)
+            else:
+                curr_pos = 0
+                total_len = len(chunk)
+                while curr_pos < total_len:
+                    sliced = chunk.slice(curr_pos, max_elems_per_chunk)
+                    curr_pos += len(sliced)
+                    all_chunks.append(sliced)
+        array = pa.chunked_array(all_chunks, type=dtype)
+
+    return pc.cast(array, pa.string())
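
Note for reviewers: a minimal sketch of how the new helper behaves once this patch is applied. The toy array and the worst-case-width arithmetic in the comments are illustrative only, not part of the patch.

```python
import pyarrow as pa
import pyarrow.compute as pc

# Assumes the patched deltacat.utils.pyarrow module is importable.
from deltacat.utils.pyarrow import sliced_string_cast

# A small int64 column; in compaction, chunks may hold far more rows.
arr = pa.chunked_array([pa.array(range(10), type=pa.int64())])

# The cast values match a plain string cast; only the chunk layout may differ.
assert sliced_string_cast(arr).equals(pc.cast(arr, pa.string()))

# Worst-case int64/uint64 string width is 20 bytes ("-9223372036854775808"),
# so each input chunk is capped at 2147483646 // (2 * 20), roughly 53.7M
# elements, keeping every output chunk under the 2 GiB string-offset limit.
```

The factor of 2 in `max_elems_per_chunk` appears to be a safety margin on the worst-case width estimate; halving the cap keeps each output chunk comfortably below `MAX_BYTES` even when every element casts to its maximum string length.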