Skip to content

Commit

Permalink
add unit-test for snowflake structure dataset encoder/decoder
Browse files Browse the repository at this point in the history
Signed-off-by: HH <[email protected]>
  • Loading branch information
hhcs9527 committed Sep 10, 2023
1 parent 06aa5a7 commit afe39bc
Show file tree
Hide file tree
Showing 2 changed files with 17 additions and 32 deletions.
34 changes: 10 additions & 24 deletions flytekit/types/structured/snowflake.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,7 +2,6 @@
import typing

import pandas as pd
import pyarrow as pa
import snowflake.connector
from snowflake.connector.pandas_tools import write_pandas

Expand All @@ -20,26 +19,21 @@


def get_private_key() -> bytes:
    """Load the Snowflake RSA private key and return it as DER-encoded bytes.

    The key file is located through the current Flyte context's secrets
    manager (group ``SNOWFLAKE``, key ``rsa_key.p8``). It is parsed as an
    unencrypted PEM private key and re-serialized as unencrypted PKCS#8 DER,
    which is the value passed as ``private_key`` to
    ``snowflake.connector.connect`` elsewhere in this module.

    Returns:
        The private key serialized as DER / PKCS#8 without encryption.
    """
    # Imported lazily so the dependencies are only needed when a key is loaded.
    from cryptography.hazmat.backends import default_backend
    from cryptography.hazmat.primitives import serialization

    import flytekit

    pk_path = flytekit.current_context().secrets.get_secrets_file(SNOWFLAKE, "rsa_key.p8")

    with open(pk_path, "rb") as key:
        p_key = serialization.load_pem_private_key(key.read(), password=None, backend=default_backend())

    return p_key.private_bytes(
        encoding=serialization.Encoding.DER,
        format=serialization.PrivateFormat.PKCS8,
        encryption_algorithm=serialization.NoEncryption(),
    )


def _write_to_sf(structured_dataset: StructuredDataset):
Expand All @@ -51,15 +45,9 @@ def _write_to_sf(structured_dataset: StructuredDataset):
df = structured_dataset.dataframe

conn = snowflake.connector.connect(
user=user,
account=account,
private_key=get_private_key(),
database=database,
schema=schema,
warehouse=warehouse
user=user, account=account, private_key=get_private_key(), database=database, schema=schema, warehouse=warehouse
)

cs = conn.cursor()
write_pandas(conn, df, table)


Expand All @@ -73,18 +61,16 @@ def _read_from_sf(
_, user, account, database, schema, warehouse, table = re.split("\\/|://|:", uri)

conn = snowflake.connector.connect(
user=user,
account=account,
private_key=get_private_key(),
database=database,
schema=schema,
warehouse=warehouse
user=user, account=account, private_key=get_private_key(), database=database, schema=schema, warehouse=warehouse
)

cs = conn.cursor()
cs.execute(f"select * from {table}")

return cs.fetch_pandas_all()
dff = cs.fetch_pandas_all()
print("cs", cs)
print("dff", dff)
return dff


class PandasToSnowflakeEncodingHandlers(StructuredDatasetEncoder):
Expand Down
15 changes: 7 additions & 8 deletions tests/flytekit/unit/types/structured_dataset/test_snowflake.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import mock
import pytest
import pandas as pd
import pytest
from typing_extensions import Annotated

from flytekit import StructuredDataset, kwtypes, task, workflow
Expand All @@ -16,7 +16,9 @@ def gen_df() -> Annotated[pd.DataFrame, my_cols, "parquet"]:

@task
def t1(df: pd.DataFrame) -> Annotated[StructuredDataset, my_cols]:
    """Wrap ``df`` in a StructuredDataset addressed by a dummy ``snowflake://`` URI."""
    return StructuredDataset(
        dataframe=df, uri="snowflake://dummy_user:dummy_account/dummy_database/dummy_schema/dummy_warehouse/dummy_table"
    )


@task
Expand All @@ -34,15 +36,12 @@ def wf() -> pd.DataFrame:
@mock.patch("snowflake.connector.connect")
@pytest.mark.asyncio
async def test_sf_wf(mock_connect):
    """Run ``wf`` against a mocked Snowflake connection and compare with ``pd_df``.

    ``snowflake.connector.connect`` is patched, so the cursor must be
    configured on the object that ``connect()`` returns (its
    ``return_value``), not on the patched function itself. The cursor's
    ``fetch_pandas_all`` is stubbed to hand back a fixed DataFrame.
    """
    fetched = pd.DataFrame({"Name": ["Tom", "Joseph"], "Age": [20, 22]})

    mock_connection = mock_connect.return_value
    mock_cursor = mock_connection.cursor.return_value
    mock_cursor.fetch_pandas_all.return_value = fetched

    # pd_df is defined at module level outside the visible hunk.
    assert wf().equals(pd_df)

0 comments on commit afe39bc

Please sign in to comment.