Skip to content

Commit

Permalink
fix test
Browse files Browse the repository at this point in the history
  • Loading branch information
Colin Ho authored and Colin Ho committed Oct 25, 2024
1 parent f5f44bf commit cffe79a
Showing 1 changed file with 39 additions and 26 deletions.
65 changes: 39 additions & 26 deletions tests/io/test_write_modes.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,10 +2,10 @@
from typing import List, Optional

import pytest
import s3fs

import daft
from daft import context
from tests.integration.io.conftest import minio_create_bucket

pytestmark = pytest.mark.skipif(
context.get_context().daft_execution_config.enable_native_executor is True,
Expand Down Expand Up @@ -53,7 +53,7 @@ def arrange_write_mode_test(existing_data, new_data, path, format, write_mode, p
write(existing_data, path, format, "append", partition_cols, io_config)

# Write some new data
write(new_data, path, format, write_mode, partition_cols, io_config)
print(write(new_data, path, format, write_mode, partition_cols, io_config))

# Read back the data
read_path = path + "/**" if partition_cols is not None else path
Expand Down Expand Up @@ -92,40 +92,53 @@ def test_write_modes_local(tmp_path, write_mode, format, num_partitions, partiti
raise ValueError(f"Unsupported write_mode: {write_mode}")


@pytest.fixture(scope="function")
def bucket(minio_io_config):
    """Yield the name of the S3 bucket used by the write-mode tests.

    An ``s3fs`` filesystem is built from the key id, access key and endpoint
    URL found on ``minio_io_config.s3``. The bucket is created through that
    filesystem only when it does not already exist; nothing is removed after
    the test finishes.

    Args:
        minio_io_config: IO config whose ``s3`` attribute exposes ``key_id``,
            ``access_key`` and ``endpoint_url``.

    Yields:
        The bucket name, ``"write-modes-bucket"``.
    """
    bucket_name = "write-modes-bucket"
    s3_config = minio_io_config.s3

    filesystem = s3fs.S3FileSystem(
        key=s3_config.key_id,
        password=s3_config.access_key,
        client_kwargs={"endpoint_url": s3_config.endpoint_url},
    )

    # Create the bucket only on first use; later tests find it already there.
    bucket_missing = not filesystem.exists(bucket_name)
    if bucket_missing:
        filesystem.mkdir(bucket_name)

    yield bucket_name


@pytest.mark.integration()
@pytest.mark.parametrize("write_mode", ["append", "overwrite"])
@pytest.mark.parametrize("format", ["csv", "parquet"])
@pytest.mark.parametrize("num_partitions", [1, 2])
@pytest.mark.parametrize("partition_cols", [None, ["a"]])
def test_write_modes_s3_minio(
minio_io_config,
bucket,
write_mode,
format,
num_partitions,
partition_cols,
):
# Integration test: writes ten rows of existing data and ten rows of new data
# to an S3 path via arrange_write_mode_test, then checks what is read back for
# each parametrized write_mode / format / num_partitions / partition_cols.
#
# NOTE(review): this text comes from a rendered diff whose +/- markers and
# indentation were lost, so two versions of the body appear back to back
# below. Which lines are removed vs. added is inferred from the
# "39 additions and 26 deletions" summary -- TODO confirm against the repo.
#
# Variant A (presumably the removed, pre-commit body): hard-codes the bucket
# name "my-bucket2" and runs the arrangement inside the minio_create_bucket
# context manager.
bucket_name = "my-bucket2"
path = f"s3://{bucket_name}/write_modes_s3_minio-{uuid.uuid4()}"
with minio_create_bucket(minio_io_config=minio_io_config, bucket_name=bucket_name):
existing_data = {"a": [i for i in range(10)]}
new_data = {
"a": [i for i in range(10, 20)],
}

read_back = arrange_write_mode_test(
daft.from_pydict(existing_data).into_partitions(num_partitions),
daft.from_pydict(new_data).into_partitions(num_partitions),
path,
format,
write_mode,
partition_cols,
minio_io_config,
)
# Variant B (presumably the added, post-commit body): takes the bucket name
# from the `bucket` parameter and writes under a fresh uuid4-named prefix, with
# no context manager around the arrangement.
path = f"s3://{bucket}/{str(uuid.uuid4())}"
existing_data = {"a": [i for i in range(10)]}
new_data = {
"a": [i for i in range(10, 20)],
}

read_back = arrange_write_mode_test(
daft.from_pydict(existing_data).into_partitions(num_partitions),
daft.from_pydict(new_data).into_partitions(num_partitions),
path,
format,
write_mode,
partition_cols,
minio_io_config,
)

# Assertions, first copy (presumably belongs to Variant A): "append" must read
# back existing rows followed by new rows, "overwrite" must read back only the
# new rows; any other write_mode is rejected with ValueError.
# Check the data
if write_mode == "append":
assert read_back["a"] == existing_data["a"] + new_data["a"]
elif write_mode == "overwrite":
assert read_back["a"] == new_data["a"]
else:
raise ValueError(f"Unsupported write_mode: {write_mode}")
# Assertions, second copy (presumably belongs to Variant B): the same three
# branches as the first copy, line for line.
# Check the data
if write_mode == "append":
assert read_back["a"] == existing_data["a"] + new_data["a"]
elif write_mode == "overwrite":
assert read_back["a"] == new_data["a"]
else:
raise ValueError(f"Unsupported write_mode: {write_mode}")

0 comments on commit cffe79a

Please sign in to comment.