diff --git a/tests/test_raillabel/filter/test_filter.py b/tests/test_raillabel/filter/test_filter.py index db12ac6..de1dc37 100644 --- a/tests/test_raillabel/filter/test_filter.py +++ b/tests/test_raillabel/filter/test_filter.py @@ -29,13 +29,18 @@ def delete_sensor_from_data(data: dict, sensor_id: str) -> dict: return data +def metadata() -> raillabel.format.Metadata: + return raillabel.format.Metadata(schema_version="1.0.0") + +def build_frames(frames: list) -> dict: + return {frame.uid: frame for frame in frames} + @pytest.fixture def loader(): return raillabel.load_.loader_classes.LoaderRailLabel() def test_filter_unexpected_kwarg(json_paths): - # Loads scene scene = raillabel.load(json_paths["openlabel_v1_short"], validate=False) with pytest.raises(TypeError): @@ -43,34 +48,35 @@ def test_filter_unexpected_kwarg(json_paths): def test_mutual_exclusivity(json_paths): - # Loads scene scene = raillabel.load(json_paths["openlabel_v1_short"], validate=False) with pytest.raises(ValueError): raillabel.filter(scene, include_frames=[0], exclude_frames=[1, 2]) -def test_filter_frames(json_paths, json_data, loader): - data = json_data["openlabel_v1_short"] - - # Loads scene - scene = raillabel.load(json_paths["openlabel_v1_short"], validate=False) - - # Deletes the excluded data - del data["openlabel"]["frames"]["1"] - del data["openlabel"]["objects"]["6fe55546-0dd7-4e40-b6b4-bb7ea3445772"] - data = delete_sensor_from_data(data, "radar") - - # Loads the ground truth filtered data - scene_filtered_ground_truth = loader.load(data) +def test_filter_frames(): + scene = raillabel.Scene( + metadata=metadata, + frames=build_frames([ + raillabel.format.Frame(0), + raillabel.format.Frame(1), + raillabel.format.Frame(2), + ]) + ) - # Tests for include filter - scene_filtered = raillabel.filter(scene, include_frames=[0]) - assert scene_filtered == scene_filtered_ground_truth + assert raillabel.filter(scene, include_frames=[0]) == raillabel.Scene( + metadata=metadata, + frames=build_frames([ + raillabel.format.Frame(0), + ]) + ) - # Tests for exclude filter - scene_filtered = raillabel.filter(scene, exclude_frames=[1]) - assert scene_filtered == scene_filtered_ground_truth + assert raillabel.filter(scene, exclude_frames=[1, 2]) == raillabel.Scene( + metadata=metadata, + frames=build_frames([ + raillabel.format.Frame(0), + ]) + ) def test_filter_start(json_paths, json_data, loader):