diff --git a/tests/test_output.py b/tests/test_output.py index 8a0c7c0e..6aa629fd 100644 --- a/tests/test_output.py +++ b/tests/test_output.py @@ -153,3 +153,41 @@ def test_save(): assert helpers.allclose(0.01, save.data["porosity"].mean()) assert "userx" not in save.data + + +@pytest.mark.parametrize( + "output_ref, islice", + [ + (helpers.output_eleme[0], 0), + (helpers.output_eleme[0], [0, 2]), + (helpers.output_eleme[0], "AAA00"), + (helpers.output_eleme[0], ["AAA00", "AAA02"]), + (helpers.output_conne[0], 0), + (helpers.output_conne[0], [0, 2]), + (helpers.output_conne[0], "AAA00"), + ], +) +def test_getitem(output_ref, islice): + output = output_ref[islice] + + idx = [islice] if isinstance(islice, (int, str)) else islice + idx = [i if isinstance(i, int) else int(i[-1]) for i in idx] + + if not isinstance(output, dict): + assert np.allclose(output.time, output_ref.time) + assert len(idx) == output.n_data + + for i, iref in enumerate(idx): + if isinstance(output.labels[i], str): + assert output.labels[i] == output_ref.labels[iref] + + else: + for label, label_ref in zip(output.labels[i], output_ref.labels[iref]): + assert label == label_ref + + for k, v in output.data.items(): + assert np.allclose(v[i], output_ref.data[k][iref]) + + else: + for k, v in output.items(): + assert np.allclose(v, output_ref.data[k][idx[0]])