From 16edc9837d8f54386949a8c599994dc8bdb475ef Mon Sep 17 00:00:00 2001
From: Sammy Sidhu
Date: Tue, 21 Nov 2023 10:52:20 -0800
Subject: [PATCH 1/4] add sliced string cast for hash column generator

---
 .../compactor_v2/utils/primary_key_index.py | 36 ++++++++++++++++++-
 1 file changed, 35 insertions(+), 1 deletion(-)

diff --git a/deltacat/compute/compactor_v2/utils/primary_key_index.py b/deltacat/compute/compactor_v2/utils/primary_key_index.py
index be560cac..828eb737 100644
--- a/deltacat/compute/compactor_v2/utils/primary_key_index.py
+++ b/deltacat/compute/compactor_v2/utils/primary_key_index.py
@@ -167,6 +167,40 @@ def group_by_pk_hash_bucket(
     return result
 
 
+def _sliced_string_cast(array: pa.ChunkedArray) -> pa.ChunkedArray:
+    """Performs slicing of a pyarrow array prior to casting it to a string.
+    This prevents pyarrow from allocating too large of an array, which would cause a failure.
+    """
+    dtype = array.type
+    MAX_BYTES = 2147483646
+    max_str_len = None
+    if pa.types.is_integer(dtype):
+        max_str_len = 21  # -INT_MAX
+    elif pa.types.is_floating(dtype):
+        max_str_len = 24
+    elif pa.types.is_decimal128(dtype):
+        max_str_len = 39
+    elif pa.types.is_decimal256(dtype):
+        max_str_len = 77
+
+    if max_str_len is not None:
+        max_elems_per_chunk = MAX_BYTES // (2 * max_str_len)  # safety factor of 2
+        all_chunks = []
+        for chunk in array.chunks:
+            if len(chunk) < max_elems_per_chunk:
+                all_chunks.append(chunk)
+            else:
+                curr_pos = 0
+                total_len = len(chunk)
+                while curr_pos < total_len:
+                    sliced = chunk.slice(curr_pos, max_elems_per_chunk)
+                    curr_pos += len(sliced)
+                    all_chunks.append(sliced)
+        array = pa.chunked_array(all_chunks, type=dtype)
+
+    return pc.cast(array, pa.string())
+
+
 def generate_pk_hash_column(
     tables: List[pa.Table],
     primary_keys: Optional[List[str]] = None,
@@ -182,7 +216,7 @@ def generate_pk_hash_column(
     def _generate_pk_hash(table: pa.Table) -> pa.Array:
         pk_columns = []
         for pk_name in primary_keys:
-            pk_columns.append(pc.cast(table[pk_name], pa.string()))
+            pk_columns.append(_sliced_string_cast(table[pk_name]))
 
         pk_columns.append(PK_DELIMITER)
         hash_column = pc.binary_join_element_wise(*pk_columns)

From 60efc046d050cd7beae755827d5a9c2ecfda40f4 Mon Sep 17 00:00:00 2001
From: Sammy Sidhu
Date: Tue, 21 Nov 2023 14:30:10 -0800
Subject: [PATCH 2/4] factor calculations out and move to pyarrow util

---
 .../compactor_v2/utils/primary_key_index.py | 37 +---------
 deltacat/utils/pyarrow.py                   | 68 +++++++++++++++++++
 2 files changed, 70 insertions(+), 35 deletions(-)

diff --git a/deltacat/compute/compactor_v2/utils/primary_key_index.py b/deltacat/compute/compactor_v2/utils/primary_key_index.py
index 828eb737..f80eeb27 100644
--- a/deltacat/compute/compactor_v2/utils/primary_key_index.py
+++ b/deltacat/compute/compactor_v2/utils/primary_key_index.py
@@ -17,6 +17,7 @@
 from deltacat.compute.compactor.utils import system_columns as sc
 from deltacat.io.object_store import IObjectStore
 from deltacat.utils.performance import timed_invocation
+from deltacat.utils.pyarrow import sliced_string_cast
 
 logger = logs.configure_deltacat_logger(logging.getLogger(__name__))
 
@@ -167,40 +168,6 @@ def group_by_pk_hash_bucket(
     return result
 
 
-def _sliced_string_cast(array: pa.ChunkedArray) -> pa.ChunkedArray:
-    """Performs slicing of a pyarrow array prior to casting it to a string.
-    This prevents pyarrow from allocating too large of an array, which would cause a failure.
-    """
-    dtype = array.type
-    MAX_BYTES = 2147483646
-    max_str_len = None
-    if pa.types.is_integer(dtype):
-        max_str_len = 21  # -INT_MAX
-    elif pa.types.is_floating(dtype):
-        max_str_len = 24
-    elif pa.types.is_decimal128(dtype):
-        max_str_len = 39
-    elif pa.types.is_decimal256(dtype):
-        max_str_len = 77
-
-    if max_str_len is not None:
-        max_elems_per_chunk = MAX_BYTES // (2 * max_str_len)  # safety factor of 2
-        all_chunks = []
-        for chunk in array.chunks:
-            if len(chunk) < max_elems_per_chunk:
-                all_chunks.append(chunk)
-            else:
-                curr_pos = 0
-                total_len = len(chunk)
-                while curr_pos < total_len:
-                    sliced = chunk.slice(curr_pos, max_elems_per_chunk)
-                    curr_pos += len(sliced)
-                    all_chunks.append(sliced)
-        array = pa.chunked_array(all_chunks, type=dtype)
-
-    return pc.cast(array, pa.string())
-
-
 def generate_pk_hash_column(
     tables: List[pa.Table],
     primary_keys: Optional[List[str]] = None,
@@ -216,7 +183,7 @@ def generate_pk_hash_column(
     def _generate_pk_hash(table: pa.Table) -> pa.Array:
         pk_columns = []
         for pk_name in primary_keys:
-            pk_columns.append(_sliced_string_cast(table[pk_name]))
+            pk_columns.append(sliced_string_cast(table[pk_name]))
 
         pk_columns.append(PK_DELIMITER)
         hash_column = pc.binary_join_element_wise(*pk_columns)

diff --git a/deltacat/utils/pyarrow.py b/deltacat/utils/pyarrow.py
index 10773bf7..a43d75e6 100644
--- a/deltacat/utils/pyarrow.py
+++ b/deltacat/utils/pyarrow.py
@@ -11,6 +11,8 @@
 from deltacat.exceptions import ValidationError
 
 import pyarrow as pa
+import numpy as np
+import pyarrow.compute as pc
 from fsspec import AbstractFileSystem
 from pyarrow import csv as pacsv
 from pyarrow import feather as paf
@@ -38,6 +40,7 @@
     sanitize_kwargs_to_callable,
     sanitize_kwargs_by_supported_kwargs,
 )
+from functools import lru_cache
 
 logger = logs.configure_deltacat_logger(logging.getLogger(__name__))
 
@@ -738,3 +741,68 @@ def clear_remaining(self) -> None:
         """
         self._remaining_tables.clear()
         self._remaining_record_count = 0
+
+
+@lru_cache
+def _int_max_string_len() -> int:
+    PA_UINT64_MAX_STR_BYTES = pc.binary_length(
+        pc.cast(pa.scalar(2**64 - 1, type=pa.uint64()), pa.string())
+    ).as_py()
+    PA_INT64_MAX_STR_BYTES = pc.binary_length(
+        pc.cast(pa.scalar(-(2**63), type=pa.int64()), pa.string())
+    ).as_py()
+    return max(PA_UINT64_MAX_STR_BYTES, PA_INT64_MAX_STR_BYTES)
+
+
+@lru_cache
+def _float_max_string_len() -> int:
+    PA_POS_FLOAT64_MAX_STR_BYTES = pc.binary_length(
+        pc.cast(pa.scalar(np.finfo(np.float64).max, type=pa.float64()), pa.string())
+    ).as_py()
+    PA_NEG_FLOAT64_MAX_STR_BYTES = pc.binary_length(
+        pc.cast(pa.scalar(np.finfo(np.float64).min, type=pa.float64()), pa.string())
+    ).as_py()
+    return max(PA_POS_FLOAT64_MAX_STR_BYTES, PA_NEG_FLOAT64_MAX_STR_BYTES)
+
+
+def _max_decimal128_string_len():
+    return 40  # "-" + 38 digits + decimal
+
+
+def _max_decimal256_string_len():
+    return 78  # "-" + 76 digits + decimal
+
+
+def sliced_string_cast(array: pa.ChunkedArray) -> pa.ChunkedArray:
+    """Performs slicing of a pyarrow array prior to casting it to a string.
+    This prevents pyarrow from allocating too large of an array, which would cause a failure.
+    Issue: https://github.com/apache/arrow/issues/38835
+    """
+    dtype = array.type
+    MAX_BYTES = 2147483646
+    max_str_len = None
+    if pa.types.is_integer(dtype):
+        max_str_len = _int_max_string_len()
+    elif pa.types.is_floating(dtype):
+        max_str_len = _float_max_string_len()
+    elif pa.types.is_decimal128(dtype):
+        max_str_len = _max_decimal128_string_len()
+    elif pa.types.is_decimal256(dtype):
+        max_str_len = _max_decimal256_string_len()
+
+    if max_str_len is not None:
+        max_elems_per_chunk = MAX_BYTES // (2 * max_str_len)  # safety factor of 2
+        all_chunks = []
+        for chunk in array.chunks:
+            if len(chunk) < max_elems_per_chunk:
+                all_chunks.append(chunk)
+            else:
+                curr_pos = 0
+                total_len = len(chunk)
+                while curr_pos < total_len:
+                    sliced = chunk.slice(curr_pos, max_elems_per_chunk)
+                    curr_pos += len(sliced)
+                    all_chunks.append(sliced)
+        array = pa.chunked_array(all_chunks, type=dtype)
+
+    return pc.cast(array, pa.string())

From 636702eb7f2fdcc4259c0011bf003cbf74441d64 Mon Sep 17 00:00:00 2001
From: Sammy Sidhu
Date: Tue, 21 Nov 2023 14:40:22 -0800
Subject: [PATCH 3/4] set max size for lru

---
 deltacat/utils/pyarrow.py | 4 ++--
 1 file changed, 2 insertions(+), 2 deletions(-)

diff --git a/deltacat/utils/pyarrow.py b/deltacat/utils/pyarrow.py
index a43d75e6..f0852a1f 100644
--- a/deltacat/utils/pyarrow.py
+++ b/deltacat/utils/pyarrow.py
@@ -743,7 +743,7 @@ def clear_remaining(self) -> None:
         self._remaining_record_count = 0
 
 
-@lru_cache
+@lru_cache(maxsize=1)
 def _int_max_string_len() -> int:
     PA_UINT64_MAX_STR_BYTES = pc.binary_length(
         pc.cast(pa.scalar(2**64 - 1, type=pa.uint64()), pa.string())
@@ -754,7 +754,7 @@ def _int_max_string_len() -> int:
     return max(PA_UINT64_MAX_STR_BYTES, PA_INT64_MAX_STR_BYTES)
 
 
-@lru_cache
+@lru_cache(maxsize=1)
 def _float_max_string_len() -> int:
     PA_POS_FLOAT64_MAX_STR_BYTES = pc.binary_length(
         pc.cast(pa.scalar(np.finfo(np.float64).max, type=pa.float64()), pa.string())

From d00b2c63841bf674d0d1774c1176c89c0e5a62bf Mon Sep 17 00:00:00 2001
From: Sammy Sidhu
Date: Tue, 21 Nov 2023 14:42:46 -0800
Subject: [PATCH 4/4] add todo

---
 deltacat/utils/pyarrow.py | 1 +
 1 file changed, 1 insertion(+)

diff --git a/deltacat/utils/pyarrow.py b/deltacat/utils/pyarrow.py
index f0852a1f..c7babdf2 100644
--- a/deltacat/utils/pyarrow.py
+++ b/deltacat/utils/pyarrow.py
@@ -777,6 +777,7 @@ def sliced_string_cast(array: pa.ChunkedArray) -> pa.ChunkedArray:
     """Performs slicing of a pyarrow array prior to casting it to a string.
     This prevents pyarrow from allocating too large of an array, which would cause a failure.
     Issue: https://github.com/apache/arrow/issues/38835
+    TODO: deprecate this function when pyarrow performs proper ChunkedArray -> ChunkedArray casting
     """
     dtype = array.type
     MAX_BYTES = 2147483646
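
Reviewer note (not part of the patch series): a minimal usage sketch of the new helper, assuming deltacat is installed with these patches applied and that the import path deltacat.utils.pyarrow matches the diff above. The chunk size here is illustrative only; slicing actually triggers only when a single chunk exceeds MAX_BYTES // (2 * max_str_len) elements, on the order of 50 million int64 values with the final patch's constants.

    import pyarrow as pa

    from deltacat.utils.pyarrow import sliced_string_cast

    # A single-chunk int64 ChunkedArray; small enough that no slicing occurs,
    # but it exercises the same code path used for primary key hashing.
    array = pa.chunked_array([pa.array(range(1_000_000), type=pa.int64())])

    result = sliced_string_cast(array)
    assert result.type == pa.string()
    assert len(result) == len(array)
    # Oversized chunks would be split before casting, so the output may carry
    # more chunks than the input.
    print(result.num_chunks, result[0])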
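A back-of-the-envelope check (again, not part of the patches) of the per-chunk element cap computed in patch 2: Arrow string arrays use 32-bit offsets, so the casted output of any single chunk must stay under roughly 2 GiB of character data, and dividing that budget by twice the worst-case string length gives the cap. The constant names below mirror the patch but are redefined here purely for illustration.

    # Worst-case decimal string lengths for 64-bit integers, matching what
    # _int_max_string_len() computes via pyarrow in the patch.
    MAX_BYTES = 2147483646  # just under the 2**31 - 1 offset limit
    int64_min_len = len(str(-(2**63)))    # "-9223372036854775808" -> 20 chars
    uint64_max_len = len(str(2**64 - 1))  # "18446744073709551615" -> 20 chars
    max_str_len = max(int64_min_len, uint64_max_len)

    # Safety factor of 2, as in sliced_string_cast: ~53.7 million elements per chunk.
    max_elems_per_chunk = MAX_BYTES // (2 * max_str_len)
    print(max_elems_per_chunk)  # 53687091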