Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[BUG] Support creation and reading of StructuredDataset with local or remote uri #2914

25 changes: 25 additions & 0 deletions flytekit/types/structured/structured_dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,6 +30,7 @@
from flytekit.models import types as type_models
from flytekit.models.literals import Binary, Literal, Scalar, StructuredDatasetMetadata
from flytekit.models.types import LiteralType, SchemaType, StructuredDatasetType
from flytekit.utils.asyn import loop_manager

if typing.TYPE_CHECKING:
import pandas as pd
Expand Down Expand Up @@ -176,8 +177,32 @@ def all(self) -> DF: # type: ignore
if self._dataframe_type is None:
raise ValueError("No dataframe type set. Use open() to set the local dataframe type you want to use.")
ctx = FlyteContextManager.current_context()

if self.uri is not None and self.dataframe is None:
expected = TypeEngine.to_literal_type(StructuredDataset)
self._set_literal(ctx, expected)

return flyte_dataset_transformer.open_as(ctx, self.literal, self._dataframe_type, self.metadata)

def _set_literal(self, ctx: FlyteContext, expected: LiteralType) -> None:
"""
Explicitly set the StructuredDataset Literal to handle the following cases:

1. Read a dataframe from a StructuredDataset with an uri, for example:

@task
def return_sd() -> StructuredDataset:
sd = StructuredDataset(uri="s3://my-s3-bucket/s3_flyte_dir/df.parquet", file_format="parquet")
df = sd.open(pd.DataFrame).all()
return df

For details, please refer to this issue: https://github.com/flyteorg/flyte/issues/5954.
"""
Comment on lines +187 to +200
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Very nice comment.

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

It means a lot to hear it from you!

to_literal = loop_manager.synced(flyte_dataset_transformer.async_to_literal)
self._literal_sd = to_literal(ctx, self, StructuredDataset, expected).scalar.structured_dataset
Comment on lines +201 to +202
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Not sure if is here the best way to write it.
cc @wild-endeavor @thomasjpfan

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Thanks for pointing this out. I'm also pondering if this is a good practice...

Will be glad to learn more from you guys!

if self.metadata is None:
self._metadata = self._literal_sd.metadata

def iter(self) -> Generator[DF, None, None]:
if self._dataframe_type is None:
raise ValueError("No dataframe type set. Use open() to set the local dataframe type you want to use.")
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
import tempfile
import typing
from collections import OrderedDict
from pathlib import Path

import google.cloud.bigquery
import pytest
Expand All @@ -21,6 +22,7 @@
from flytekit.models.literals import StructuredDatasetMetadata
from flytekit.models.types import LiteralType, SchemaType, SimpleType, StructuredDatasetType
from flytekit.tools.translator import get_serializable
from flytekit.types.file import FlyteFile
from flytekit.types.structured.structured_dataset import (
PARQUET,
StructuredDataset,
Expand Down Expand Up @@ -59,6 +61,21 @@ def generate_pandas() -> pd.DataFrame:
return pd.DataFrame({"name": ["Tom", "Joseph"], "age": [20, 22]})


@pytest.fixture
def local_tmp_pqt_file():
df = generate_pandas()

# Create a temporary parquet file
with tempfile.NamedTemporaryFile(delete=False, mode="w+b", suffix=".parquet") as pqt_file:
pqt_path = pqt_file.name
df.to_parquet(pqt_path)

yield pqt_path

# Cleanup
Path(pqt_path).unlink(missing_ok=True)


def test_formats_make_sense():
@task
def t1(a: pd.DataFrame) -> pd.DataFrame:
Expand Down Expand Up @@ -643,3 +660,27 @@ def wf_with_input() -> pd.DataFrame:

pd.testing.assert_frame_equal(wf_no_input(), default_val)
pd.testing.assert_frame_equal(wf_with_input(), input_val)



def test_read_sd_from_local_uri(local_tmp_pqt_file):

@task
def read_sd_from_uri(uri: str) -> pd.DataFrame:
sd = StructuredDataset(uri=uri, file_format="parquet")
df = sd.open(pd.DataFrame).all()

return df

@workflow
def read_sd_from_local_uri(uri: str) -> pd.DataFrame:
df = read_sd_from_uri(uri=uri)

return df


df = generate_pandas()

# Read sd from local uri
df_local = read_sd_from_local_uri(uri=local_tmp_pqt_file)
pd.testing.assert_frame_equal(df, df_local)
Loading