Skip to content

Commit

Permalink
[BUG] Concat Fix when Variable Length Array is sliced (#1750)
Browse files Browse the repository at this point in the history
  • Loading branch information
samster25 authored Dec 20, 2023
1 parent 555aed1 commit 52b5209
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 1 deletion.
3 changes: 2 additions & 1 deletion src/daft-core/src/array/ops/concat.rs
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,8 @@ macro_rules! impl_variable_length_concat {
bitmap.extend_constant(arr.len(), true);
}
}
buffer.extend_from_slice(arr.values().as_slice());
let range = (*arr.offsets().first() as usize)..(*arr.offsets().last() as usize);
buffer.extend_from_slice(&arr.values().as_slice()[range]);
}
let dtype = arrays.first().unwrap().data_type().clone();
#[allow(unused_unsafe)]
Expand Down
21 changes: 21 additions & 0 deletions tests/series/test_concat.py
Original file line number Diff line number Diff line change
Expand Up @@ -42,6 +42,27 @@ def test_series_concat(dtype, chunks) -> None:
counter += 1


@pytest.mark.parametrize(
"dtype, chunks", itertools.product(ARROW_FLOAT_TYPES + ARROW_INT_TYPES + ARROW_STRING_TYPES, [1, 2, 3, 10])
)
def test_series_concat_with_slicing(dtype, chunks) -> None:
series = []
for i in range(chunks):
s = Series.from_pylist([i] * 4).cast(dtype=DataType.from_arrow_type(dtype))
series.append(s.slice(0, 2))

concated = Series.concat(series)

assert concated.datatype() == DataType.from_arrow_type(dtype)
concated_list = concated.to_pylist()

counter = 0
for i in range(chunks):
for _ in range(2):
assert float(concated_list[counter]) == i
counter += 1


@pytest.mark.parametrize("fixed", [False, True])
@pytest.mark.parametrize("chunks", [1, 2, 3, 10])
def test_series_concat_list_array(chunks, fixed) -> None:
Expand Down

0 comments on commit 52b5209

Please sign in to comment.