From 8235c8d50e7c283530707d67b7028dd975c4152b Mon Sep 17 00:00:00 2001 From: AliceJoubert Date: Thu, 12 Sep 2024 10:16:58 +0200 Subject: [PATCH] Testing wip --- clinica/utils/inputs.py | 9 ++-- test/unittests/utils/test_utils_inputs.py | 50 +++++++++++++++++++++-- 2 files changed, 52 insertions(+), 7 deletions(-) diff --git a/clinica/utils/inputs.py b/clinica/utils/inputs.py index 959903c86..c984b7114 100644 --- a/clinica/utils/inputs.py +++ b/clinica/utils/inputs.py @@ -629,9 +629,12 @@ def _remove_sub_ses_from_list( session_indexes = [ i for i, session in enumerate(list_sessions) if session == ses ] - to_remove = list(set(sub_indexes) & set(session_indexes))[0] - list_subjects.pop(to_remove) - list_sessions.pop(to_remove) + to_remove = list(set(sub_indexes) & set(session_indexes)) + if to_remove: + if len(to_remove) > 1: + raise ValueError("") + list_subjects.pop(to_remove[0]) + list_sessions.pop(to_remove[0]) def clinica_file_reader( diff --git a/test/unittests/utils/test_utils_inputs.py b/test/unittests/utils/test_utils_inputs.py index 4659ebe02..5689087e5 100644 --- a/test/unittests/utils/test_utils_inputs.py +++ b/test/unittests/utils/test_utils_inputs.py @@ -14,6 +14,48 @@ ) +@pytest.mark.parametrize( + "input_subjects, input_sessions, to_remove, expected_subjects, expected_sessions", + ( + [ + ["sub1", "sub1", "sub2"], + ["ses1", "ses2", "ses1"], + [("sub1", "ses1"), ("sub3", "ses1")], + ["sub1", "sub2"], + ["ses2", "ses1"], + ], + [ + [ + "sub1", + ], + [ + "ses1", + ], + [], + [ + "sub1", + ], + [ + "ses1", + ], + ], + ), +) +def test_remove_sub_ses_from_list_success( + input_subjects, input_sessions, to_remove, expected_subjects, expected_sessions +): + from clinica.utils.inputs import _remove_sub_ses_from_list + + _remove_sub_ses_from_list(input_subjects, input_sessions, to_remove) + assert input_subjects == expected_subjects + assert input_sessions == expected_sessions + + +def test_remove_sub_ses_from_list_error(): + # todo + pass + + def test_get_parent_path(tmp_path): from clinica.utils.inputs import _get_parent_path @@ -389,7 +431,7 @@ def test_find_sub_ses_pattern_path_error_no_file(tmp_path): assert len(results) == 0 assert len(errors) == 1 - assert errors[0] == "\t* (sub-01 | ses-M00): No file found\n" + assert errors[0] == ("sub-01", "ses-M00") def test_find_sub_ses_pattern_path_error_more_than_one_file(tmp_path): @@ -410,7 +452,7 @@ def test_find_sub_ses_pattern_path_error_more_than_one_file(tmp_path): assert len(results) == 0 assert len(errors) == 1 - assert "\t* (sub-01 | ses-M00): More than 1 file found:" in errors[0] + assert errors[0] == ("sub-01", "ses-M00") def test_find_sub_ses_pattern_path(tmp_path): @@ -554,7 +596,7 @@ def test_clinica_file_reader_bids_directory(tmp_path, data_type): ) assert clinica_file_reader( [], [], tmp_path, information, raise_exception=True, n_procs=1 - ) == ([], "") + ) == ([], []) results, error_msg = clinica_file_reader( ["sub-01"], ["ses-M00"], tmp_path, information, raise_exception=True, n_procs=1 ) @@ -640,7 +682,7 @@ def test_clinica_file_reader_caps_directory(tmp_path): assert clinica_file_reader( [], [], tmp_path, information, raise_exception=True, n_procs=1 - ) == ([], "") + ) == ([], []) results, error_msg = clinica_file_reader( ["sub-01"], ["ses-M00"], tmp_path, information, raise_exception=True, n_procs=1