Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

add sliced string cast for hash column generator #249

Merged
merged 4 commits into from
Dec 1, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
3 changes: 2 additions & 1 deletion deltacat/compute/compactor_v2/utils/primary_key_index.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,7 @@
from deltacat.compute.compactor.utils import system_columns as sc
from deltacat.io.object_store import IObjectStore
from deltacat.utils.performance import timed_invocation
from deltacat.utils.pyarrow import sliced_string_cast

logger = logs.configure_deltacat_logger(logging.getLogger(__name__))

Expand Down Expand Up @@ -182,7 +183,7 @@ def generate_pk_hash_column(
def _generate_pk_hash(table: pa.Table) -> pa.Array:
pk_columns = []
for pk_name in primary_keys:
pk_columns.append(pc.cast(table[pk_name], pa.string()))
pk_columns.append(sliced_string_cast(table[pk_name]))

pk_columns.append(PK_DELIMITER)
hash_column = pc.binary_join_element_wise(*pk_columns)
Expand Down
69 changes: 69 additions & 0 deletions deltacat/utils/pyarrow.py
Original file line number Diff line number Diff line change
Expand Up @@ -11,6 +11,8 @@
from deltacat.exceptions import ValidationError

import pyarrow as pa
import numpy as np
import pyarrow.compute as pc
from fsspec import AbstractFileSystem
from pyarrow import csv as pacsv
from pyarrow import feather as paf
Expand Down Expand Up @@ -38,6 +40,7 @@
sanitize_kwargs_to_callable,
sanitize_kwargs_by_supported_kwargs,
)
from functools import lru_cache

logger = logs.configure_deltacat_logger(logging.getLogger(__name__))

Expand Down Expand Up @@ -738,3 +741,69 @@ def clear_remaining(self) -> None:
"""
self._remaining_tables.clear()
self._remaining_record_count = 0


@lru_cache(maxsize=1)
def _int_max_string_len() -> int:
PA_UINT64_MAX_STR_BYTES = pc.binary_length(
pc.cast(pa.scalar(2**64 - 1, type=pa.uint64()), pa.string())
).as_py()
PA_INT64_MAX_STR_BYTES = pc.binary_length(
pc.cast(pa.scalar(-(2**63), type=pa.int64()), pa.string())
).as_py()
return max(PA_UINT64_MAX_STR_BYTES, PA_INT64_MAX_STR_BYTES)


@lru_cache(maxsize=1)
def _float_max_string_len() -> int:
PA_POS_FLOAT64_MAX_STR_BYTES = pc.binary_length(
pc.cast(pa.scalar(np.finfo(np.float64).max, type=pa.float64()), pa.string())
).as_py()
PA_NEG_FLOAT64_MAX_STR_BYTES = pc.binary_length(
pc.cast(pa.scalar(np.finfo(np.float64).min, type=pa.float64()), pa.string())
).as_py()
return max(PA_POS_FLOAT64_MAX_STR_BYTES, PA_NEG_FLOAT64_MAX_STR_BYTES)


def _max_decimal128_string_len():
return 40 # "-" + 38 digits + decimal


def _max_decimal256_string_len():
return 78 # "-" + 76 digits + decimal


def sliced_string_cast(array: pa.ChunkedArray) -> pa.ChunkedArray:
rkenmi marked this conversation as resolved.
Show resolved Hide resolved
"""performs slicing of a pyarrow array prior casting to a string.
This prevents a pyarrow from allocating too large of an array causing a failure.
Issue: https://github.com/apache/arrow/issues/38835
rkenmi marked this conversation as resolved.
Show resolved Hide resolved
TODO: deprecate this function when pyarrow performs proper ChunkedArray -> ChunkedArray casting
"""
dtype = array.type
MAX_BYTES = 2147483646
max_str_len = None
if pa.types.is_integer(dtype):
max_str_len = _int_max_string_len()
elif pa.types.is_floating(dtype):
max_str_len = _float_max_string_len()
elif pa.types.is_decimal128(dtype):
max_str_len = _max_decimal128_string_len()
elif pa.types.is_decimal256(dtype):
max_str_len = _max_decimal256_string_len()

if max_str_len is not None:
max_elems_per_chunk = MAX_BYTES // (2 * max_str_len) # safety factor of 2
all_chunks = []
for chunk in array.chunks:
if len(chunk) < max_elems_per_chunk:
all_chunks.append(chunk)
else:
curr_pos = 0
total_len = len(chunk)
while curr_pos < total_len:
sliced = chunk.slice(curr_pos, max_elems_per_chunk)
curr_pos += len(sliced)
all_chunks.append(sliced)
array = pa.chunked_array(all_chunks, type=dtype)

return pc.cast(array, pa.string())
Loading