Skip to content

Commit

Permalink
fix tests
Browse files (browse the repository at this point in the history)
  • Loading branch information
samster25 committed Jan 23, 2024
1 parent 83be7b2 commit 067fbf4
Show file tree
Hide file tree
Showing 3 changed files with 7 additions and 11 deletions.
9 changes: 4 additions & 5 deletions daft/table/table_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -388,7 +388,6 @@ def write_tabular(
part_keys_postfix_per_table = [None]

visited_paths = []
visited_sizes = []
partition_idx = []

execution_config = get_context().daft_execution_config
Expand Down Expand Up @@ -424,20 +423,20 @@ def write_tabular(
target_num_files = max(math.ceil(size_bytes / target_file_size / inflation_factor), 1)
num_rows = len(arrow_table)

rows_per_file = math.ceil(num_rows / target_num_files)
rows_per_file = max(math.ceil(num_rows / target_num_files), 1)

target_row_groups = max(math.ceil(size_bytes / TARGET_ROW_GROUP_SIZE / inflation_factor), 1)
rows_per_row_group = min(math.ceil(num_rows / target_row_groups), rows_per_file)
rows_per_row_group = max(min(math.ceil(num_rows / target_row_groups), rows_per_file), 1)

def file_visitor(written_file):
visited_paths.append(written_file.path)
visited_sizes.append(written_file.size)
partition_idx.append(i)

Check warning on line 433 in daft/table/table_io.py

View check run for this annotation

Codecov / codecov/patch

daft/table/table_io.py#L432-L433

Added lines #L432 - L433 were not covered by tests

kwargs = dict()

if ARROW_VERSION >= (7, 0, 0):
kwargs["max_rows_per_file"] = rows_per_file
print(rows_per_row_group, rows_per_file)
kwargs["min_rows_per_group"] = rows_per_row_group
kwargs["max_rows_per_group"] = rows_per_row_group

Expand All @@ -455,7 +454,7 @@ def file_visitor(written_file):
**kwargs,
)

data_dict: dict[str, Any] = {schema.column_names()[0]: visited_paths, schema.column_names()[1]: visited_sizes}
data_dict: dict[str, Any] = {schema.column_names()[0]: visited_paths}

if partition_values is not None:
partition_idx_series = Series.from_pylist(partition_idx).cast(DataType.int64())
Expand Down
5 changes: 1 addition & 4 deletions src/daft-plan/src/logical_ops/sink.rs
Original file line number Diff line number Diff line change
Expand Up @@ -22,10 +22,7 @@ pub struct Sink {

impl Sink {
pub(crate) fn try_new(input: Arc<LogicalPlan>, sink_info: Arc<SinkInfo>) -> DaftResult<Self> {
let mut fields = vec![
Field::new("path", daft_core::DataType::Utf8),
Field::new("file_size", daft_core::DataType::UInt64),
];
let mut fields = vec![Field::new("path", daft_core::DataType::Utf8)];
let schema = input.schema();

match sink_info.as_ref() {
Expand Down
4 changes: 2 additions & 2 deletions tests/cookbook/test_write.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,8 +38,8 @@ def test_empty_parquet_write_without_partitioning(tmp_path):
df = daft.read_csv(COOKBOOK_DATA_CSV)
df = df.where(daft.lit(False))
output_files = df.write_parquet(tmp_path)
assert len(output_files) == 1
assert len(output_files._preview.preview_partition) == 1
assert len(output_files) == 0
assert len(output_files._preview.preview_partition) == 0


def test_empty_parquet_write_with_partitioning(tmp_path):
Expand Down

0 comments on commit 067fbf4

Please sign in to comment.