From b04bc8dfaf78b35d4990792d83f591acb9674359 Mon Sep 17 00:00:00 2001
From: =?UTF-8?q?=E6=B1=9F=E5=AE=B6=E7=91=8B?= <36886416+JiangJiaWei1103@users.noreply.github.com>
Date: Tue, 19 Nov 2024 14:38:57 +0100
Subject: [PATCH] [BUG] Support creation and reading of StructuredDataset with local or remote uri (#2914)

* Manually fill sd literal and metadata

Signed-off-by: JiaWei Jiang

* Use StructuredDatasetTransformerEngine to set sd literal

Signed-off-by: JiaWei Jiang

* Add unit test for reading sd from uri

Signed-off-by: JiaWei Jiang

* Put tasks into wf to mimic real-world use cases

Signed-off-by: JiaWei Jiang

* add env

Signed-off-by: Future-Outlier

* env again

Signed-off-by: Future-Outlier

* Modify python ff reading logic to a flyte task

Signed-off-by: JiaWei Jiang

* Use task param instead of global path const

Signed-off-by: JiaWei Jiang

* Remove unit tests that need to interact with s3

Signed-off-by: JiaWei Jiang

---------

Signed-off-by: JiaWei Jiang
Signed-off-by: Future-Outlier
Co-authored-by: Future-Outlier
---
 .../types/structured/structured_dataset.py    | 25 +++++++++++
 .../test_structured_dataset.py                | 41 +++++++++++++++++++
 2 files changed, 66 insertions(+)

diff --git a/flytekit/types/structured/structured_dataset.py b/flytekit/types/structured/structured_dataset.py
index 89d088c264..f4a2194749 100644
--- a/flytekit/types/structured/structured_dataset.py
+++ b/flytekit/types/structured/structured_dataset.py
@@ -30,6 +30,7 @@
 from flytekit.models import types as type_models
 from flytekit.models.literals import Binary, Literal, Scalar, StructuredDatasetMetadata
 from flytekit.models.types import LiteralType, SchemaType, StructuredDatasetType
+from flytekit.utils.asyn import loop_manager
 
 if typing.TYPE_CHECKING:
     import pandas as pd
@@ -176,8 +177,32 @@ def all(self) -> DF:  # type: ignore
         if self._dataframe_type is None:
             raise ValueError("No dataframe type set. Use open() to set the local dataframe type you want to use.")
         ctx = FlyteContextManager.current_context()
+
+        if self.uri is not None and self.dataframe is None:
+            expected = TypeEngine.to_literal_type(StructuredDataset)
+            self._set_literal(ctx, expected)
+
         return flyte_dataset_transformer.open_as(ctx, self.literal, self._dataframe_type, self.metadata)
 
+    def _set_literal(self, ctx: FlyteContext, expected: LiteralType) -> None:
+        """
+        Explicitly set the StructuredDataset Literal to handle the following cases:
+
+        1. Read a dataframe from a StructuredDataset with an uri, for example:
+
+            @task
+            def return_sd() -> StructuredDataset:
+                sd = StructuredDataset(uri="s3://my-s3-bucket/s3_flyte_dir/df.parquet", file_format="parquet")
+                df = sd.open(pd.DataFrame).all()
+                return df
+
+        For details, please refer to this issue: https://github.com/flyteorg/flyte/issues/5954.
+        """
+        to_literal = loop_manager.synced(flyte_dataset_transformer.async_to_literal)
+        self._literal_sd = to_literal(ctx, self, StructuredDataset, expected).scalar.structured_dataset
+        if self.metadata is None:
+            self._metadata = self._literal_sd.metadata
+
     def iter(self) -> Generator[DF, None, None]:
         if self._dataframe_type is None:
             raise ValueError("No dataframe type set. Use open() to set the local dataframe type you want to use.")
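
Usage note: the hunk above makes StructuredDataset usable when it is constructed
directly from a uri rather than produced by task IO. A minimal sketch of the two
paths the patch title refers to follows (the task names and the s3 uri are
illustrative placeholders echoing the docstring example, not part of the patch):

    import pandas as pd

    from flytekit import task
    from flytekit.types.structured.structured_dataset import StructuredDataset


    @task
    def create_sd_from_existing_uri() -> StructuredDataset:
        # "Creation": wrap an already-materialized parquet file without loading it.
        return StructuredDataset(uri="s3://my-s3-bucket/s3_flyte_dir/df.parquet", file_format="parquet")


    @task
    def read_sd_from_uri(uri: str) -> pd.DataFrame:
        # "Reading": all() now fills in the missing literal from the uri (via
        # _set_literal) before handing off to the transformer's open_as().
        sd = StructuredDataset(uri=uri, file_format="parquet")
        return sd.open(pd.DataFrame).all()

Before this change, calling all() on such a hand-constructed StructuredDataset had no
literal to read from; _set_literal synthesizes it with the same transformer engine used
for regular task IO, so local and remote uris are handled consistently.
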
diff --git a/tests/flytekit/unit/types/structured_dataset/test_structured_dataset.py b/tests/flytekit/unit/types/structured_dataset/test_structured_dataset.py
index 3cc19f295c..5433b79a9c 100644
--- a/tests/flytekit/unit/types/structured_dataset/test_structured_dataset.py
+++ b/tests/flytekit/unit/types/structured_dataset/test_structured_dataset.py
@@ -2,6 +2,7 @@
 import tempfile
 import typing
 from collections import OrderedDict
+from pathlib import Path
 
 import google.cloud.bigquery
 import pytest
@@ -21,6 +22,7 @@
 from flytekit.models.literals import StructuredDatasetMetadata
 from flytekit.models.types import LiteralType, SchemaType, SimpleType, StructuredDatasetType
 from flytekit.tools.translator import get_serializable
+from flytekit.types.file import FlyteFile
 from flytekit.types.structured.structured_dataset import (
     PARQUET,
     StructuredDataset,
@@ -59,6 +61,21 @@ def generate_pandas() -> pd.DataFrame:
     return pd.DataFrame({"name": ["Tom", "Joseph"], "age": [20, 22]})
 
 
+@pytest.fixture
+def local_tmp_pqt_file():
+    df = generate_pandas()
+
+    # Create a temporary parquet file
+    with tempfile.NamedTemporaryFile(delete=False, mode="w+b", suffix=".parquet") as pqt_file:
+        pqt_path = pqt_file.name
+        df.to_parquet(pqt_path)
+
+    yield pqt_path
+
+    # Cleanup
+    Path(pqt_path).unlink(missing_ok=True)
+
+
 def test_formats_make_sense():
     @task
     def t1(a: pd.DataFrame) -> pd.DataFrame:
@@ -643,3 +660,27 @@ def wf_with_input() -> pd.DataFrame:
 
     pd.testing.assert_frame_equal(wf_no_input(), default_val)
     pd.testing.assert_frame_equal(wf_with_input(), input_val)
+
+
+
+def test_read_sd_from_local_uri(local_tmp_pqt_file):
+
+    @task
+    def read_sd_from_uri(uri: str) -> pd.DataFrame:
+        sd = StructuredDataset(uri=uri, file_format="parquet")
+        df = sd.open(pd.DataFrame).all()
+
+        return df
+
+    @workflow
+    def read_sd_from_local_uri(uri: str) -> pd.DataFrame:
+        df = read_sd_from_uri(uri=uri)
+
+        return df
+
+
+    df = generate_pandas()
+
+    # Read sd from local uri
+    df_local = read_sd_from_local_uri(uri=local_tmp_pqt_file)
+    pd.testing.assert_frame_equal(df, df_local)
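
To exercise just the new test locally, something like the following should work from
the flytekit repo root (a sketch; it assumes pytest and the pandas/pyarrow dependencies
already used by this test module are installed):

    import pytest

    # Test node id taken from the diff above; -q keeps the output short.
    pytest.main([
        "tests/flytekit/unit/types/structured_dataset/test_structured_dataset.py::test_read_sd_from_local_uri",
        "-q",
    ])

The new test stays fully local: the fixture writes a temporary parquet file and the
workflow runs in local execution mode, consistent with the commit message's removal of
unit tests that need to interact with s3.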