Upgrade Daft to 0.1.17 for improved performance and resource usage #217

Merged: 20 commits, Sep 12, 2023
2 changes: 1 addition & 1 deletion deltacat/__init__.py
@@ -44,7 +44,7 @@

 deltacat.logs.configure_deltacat_logger(logging.getLogger(__name__))
 
-__version__ = "0.1.19"
+__version__ = "0.1.20"
 
 
 __all__ = [
1 change: 1 addition & 0 deletions deltacat/aws/constants.py
@@ -2,6 +2,7 @@

 from deltacat.utils.common import env_integer, env_string
 
+DAFT_MAX_S3_CONNECTIONS_PER_FILE = env_integer("DAFT_MAX_S3_CONNECTIONS_PER_FILE", 8)
 BOTO_MAX_RETRIES = env_integer("BOTO_MAX_RETRIES", 15)
 TIMEOUT_ERROR_CODES: List[str] = ["ReadTimeoutError", "ConnectTimeoutError"]
 AWS_REGION = env_string("AWS_REGION", "us-east-1")
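
The new connection cap is evaluated once at import time, so operators can tune Daft's per-file S3 parallelism without a code change. A minimal sketch of that override, assuming env_integer("NAME", default) returns the integer value of the environment variable when it is set and the default otherwise (as its call signature suggests):

import os

# Assumption: the variable must be set before deltacat.aws.constants is
# first imported, since the constant is computed at import time.
os.environ["DAFT_MAX_S3_CONNECTIONS_PER_FILE"] = "16"

from deltacat.aws.constants import DAFT_MAX_S3_CONNECTIONS_PER_FILE

print(DAFT_MAX_S3_CONNECTIONS_PER_FILE)  # 16 instead of the default 8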
55 changes: 52 additions & 3 deletions deltacat/tests/utils/test_daft.py
@@ -42,18 +42,67 @@ def test_read_from_s3_single_column_via_column_names(self):
         self.assertEqual(table.num_rows, 100)
 
     def test_read_from_s3_single_column_with_schema(self):
-        schema = pa.schema([("a", pa.int64()), ("b", pa.string())])
+        schema = pa.schema([("a", pa.int8()), ("b", pa.string())])
         pa_read_func_kwargs_provider = ReadKwargsProviderPyArrowSchemaOverride(
             schema=schema
         )
         table = daft_s3_file_to_table(
             self.MVP_PATH,
             content_encoding=ContentEncoding.IDENTITY.value,
             content_type=ContentType.PARQUET.value,
-            include_columns=["b"],
+            include_columns=["a"],
             pa_read_func_kwargs_provider=pa_read_func_kwargs_provider,
         )
-        self.assertEqual(table.schema.names, ["b"])
+        self.assertEqual(table.schema.names, ["a"])
+        self.assertEqual(table.schema.field("a").type, pa.int8())
         self.assertEqual(table.num_rows, 100)
 
+    def test_read_from_s3_single_column_with_schema_reverse_order(self):
+        schema = pa.schema([("b", pa.string()), ("a", pa.int8())])
+        pa_read_func_kwargs_provider = ReadKwargsProviderPyArrowSchemaOverride(
+            schema=schema
+        )
+        table = daft_s3_file_to_table(
+            self.MVP_PATH,
+            content_encoding=ContentEncoding.IDENTITY.value,
+            content_type=ContentType.PARQUET.value,
+            pa_read_func_kwargs_provider=pa_read_func_kwargs_provider,
+        )
+        self.assertEqual(table.schema.names, ["b", "a"])
+        self.assertEqual(table.schema.field("a").type, pa.int8())
+        self.assertEqual(table.num_rows, 100)
+
+    def test_read_from_s3_single_column_with_schema_subset_cols(self):
+        schema = pa.schema([("a", pa.int8())])
+        pa_read_func_kwargs_provider = ReadKwargsProviderPyArrowSchemaOverride(
+            schema=schema
+        )
+        table = daft_s3_file_to_table(
+            self.MVP_PATH,
+            content_encoding=ContentEncoding.IDENTITY.value,
+            content_type=ContentType.PARQUET.value,
+            pa_read_func_kwargs_provider=pa_read_func_kwargs_provider,
+        )
+        self.assertEqual(table.schema.names, ["a"])
+        self.assertEqual(table.schema.field("a").type, pa.int8())
+        self.assertEqual(table.num_rows, 100)
+
+    def test_read_from_s3_single_column_with_schema_extra_cols(self):
+        schema = pa.schema([("a", pa.int8()), ("MISSING", pa.string())])
+        pa_read_func_kwargs_provider = ReadKwargsProviderPyArrowSchemaOverride(
+            schema=schema
+        )
+        table = daft_s3_file_to_table(
+            self.MVP_PATH,
+            content_encoding=ContentEncoding.IDENTITY.value,
+            content_type=ContentType.PARQUET.value,
+            pa_read_func_kwargs_provider=pa_read_func_kwargs_provider,
+        )
+        self.assertEqual(
+            table.schema.names, ["a", "MISSING"]
+        )  # NOTE: "MISSING" is padded as a null array
+        self.assertEqual(table.schema.field("a").type, pa.int8())
+        self.assertEqual(table.schema.field("MISSING").type, pa.string())
+        self.assertEqual(table.num_rows, 100)
+
     def test_read_from_s3_single_column_with_row_groups(self):
28 changes: 18 additions & 10 deletions deltacat/utils/daft.py
@@ -1,17 +1,17 @@
 import logging
 from typing import Optional, List
 
-from daft.table import Table
-from daft.logical.schema import Schema
+from daft.table import read_parquet_into_pyarrow
 from daft import TimeUnit
 from daft.io import IOConfig, S3Config
 import pyarrow as pa
 
 from deltacat import logs
 from deltacat.utils.common import ReadKwargsProvider
+from deltacat.utils.schema import coerce_pyarrow_table_to_schema
 
 from deltacat.types.media import ContentType, ContentEncoding
-from deltacat.aws.constants import BOTO_MAX_RETRIES
+from deltacat.aws.constants import BOTO_MAX_RETRIES, DAFT_MAX_S3_CONNECTIONS_PER_FILE
 from deltacat.utils.performance import timed_invocation
 
 from deltacat.types.partial_download import (
@@ -62,11 +62,12 @@ def daft_s3_file_to_table(
             session_token=s3_client_kwargs.get("aws_session_token"),
             retry_mode="adaptive",
             num_tries=BOTO_MAX_RETRIES,
+            max_connections=DAFT_MAX_S3_CONNECTIONS_PER_FILE,
         )
     )
 
-    table, latency = timed_invocation(
-        Table.read_parquet,
+    pa_table, latency = timed_invocation(
+        read_parquet_into_pyarrow,
         path=s3_url,
         columns=include_columns or column_names,
         row_groups=row_groups,
@@ -78,10 +79,17 @@
     logger.debug(f"Time to read S3 object from {s3_url} into daft table: {latency}s")
 
     if kwargs.get("schema") is not None:
-        schema = kwargs["schema"]
+        input_schema = kwargs["schema"]
         if include_columns is not None:
-            schema = pa.schema([schema.field(col) for col in include_columns])
-        daft_schema = Schema.from_pyarrow_schema(schema)
-        return table.cast_to_schema(daft_schema).to_arrow()
+            input_schema = pa.schema(
+                [input_schema.field(col) for col in include_columns],
+                metadata=input_schema.metadata,
+            )
+        elif column_names is not None:
+            input_schema = pa.schema(
+                [input_schema.field(col) for col in column_names],
+                metadata=input_schema.metadata,
+            )
+        return coerce_pyarrow_table_to_schema(pa_table, input_schema)
     else:
-        return table.to_arrow()
+        return pa_table
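
Putting the pieces together: the rewritten function reads the Parquet file straight into a plain pyarrow.Table via read_parquet_into_pyarrow, then coerces it to the caller's schema instead of round-tripping through a Daft table and Schema. A minimal sketch of a call, mirroring the new tests above (the S3 URL is a placeholder, and the import path for ReadKwargsProviderPyArrowSchemaOverride is assumed from deltacat's pyarrow utilities):

import pyarrow as pa

from deltacat.types.media import ContentEncoding, ContentType
from deltacat.utils.daft import daft_s3_file_to_table
from deltacat.utils.pyarrow import ReadKwargsProviderPyArrowSchemaOverride

# Hypothetical object; any Parquet file with columns "a" and "b" works.
s3_url = "s3://my-bucket/my-file.parquet"

# Ask for column "a" coerced to int8 even if the file stores it as int64.
provider = ReadKwargsProviderPyArrowSchemaOverride(
    schema=pa.schema([("a", pa.int8()), ("b", pa.string())])
)
table = daft_s3_file_to_table(
    s3_url,
    content_encoding=ContentEncoding.IDENTITY.value,
    content_type=ContentType.PARQUET.value,
    include_columns=["a"],
    pa_read_func_kwargs_provider=provider,
)
assert table.schema.names == ["a"]
assert table.schema.field("a").type == pa.int8()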
42 changes: 42 additions & 0 deletions deltacat/utils/schema.py
@@ -0,0 +1,42 @@
+import pyarrow as pa
+
+
+def coerce_pyarrow_table_to_schema(
+    pa_table: pa.Table, input_schema: pa.Schema
+) -> pa.Table:
+    """Coerces a PyArrow table to the supplied schema
+
+    1. For each field in `pa_table`, cast it to the field in `input_schema` if one with a matching name
+       is available
+    2. Reorder the fields in the casted table to the supplied schema
+    3. If any fields in the supplied schema are not present, add a null array of the correct type
+
+    Args:
+        pa_table (pa.Table): Table to coerce
+        input_schema (pa.Schema): Schema to coerce to
+
+    Returns:
+        pa.Table: Table with schema == `input_schema`
+    """
+    input_schema_names = set(input_schema.names)
+
+    # Perform casting of types to provided schema's types
+    cast_to_schema = [
+        input_schema.field(inferred_field.name)
+        if inferred_field.name in input_schema_names
+        else inferred_field
+        for inferred_field in pa_table.schema
+    ]
+    casted_table = pa_table.cast(pa.schema(cast_to_schema))
+
+    # Reorder and pad columns with a null column where necessary
+    pa_table_column_names = set(casted_table.column_names)
+    columns = []
+    for name in input_schema.names:
+        if name in pa_table_column_names:
+            columns.append(casted_table[name])
+        else:
+            columns.append(
+                pa.nulls(len(casted_table), type=input_schema.field(name).type)
+            )
+    return pa.table(columns, schema=input_schema)
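
Because the helper is pure PyArrow, its three steps (cast, reorder, null-pad) can be exercised in isolation. A small sketch under those semantics, with hypothetical input data:

import pyarrow as pa

from deltacat.utils.schema import coerce_pyarrow_table_to_schema

# Columns arrive as string/int64 and in the "wrong" order.
inferred = pa.table({"b": ["x", "y"], "a": [1, 2]})

# Target schema: "a" narrowed to int8, "a" first, plus a column the data lacks.
target = pa.schema([("a", pa.int8()), ("b", pa.string()), ("c", pa.float64())])

coerced = coerce_pyarrow_table_to_schema(inferred, target)
assert coerced.schema == target      # types cast, fields reordered
assert coerced["c"].null_count == 2  # missing column padded with nulls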
2 changes: 1 addition & 1 deletion requirements.txt
@@ -1,7 +1,7 @@
 # setup.py install_requires
 # any changes here should also be reflected in setup.py "install_requires"
 boto3 ~= 1.20
-getdaft==0.1.16
+getdaft==0.1.17
 numpy == 1.21.5
 pandas == 1.3.5
 pyarrow == 12.0.1
2 changes: 1 addition & 1 deletion setup.py
@@ -46,7 +46,7 @@ def find_version(*paths):
"typing-extensions == 4.4.0",
"pymemcache == 4.0.0",
"redis == 4.6.0",
"getdaft == 0.1.16",
"getdaft == 0.1.17",
"schedule == 1.2.0",
],
setup_requires=["wheel"],