diff --git a/audeer/core/io.py b/audeer/core/io.py index a73d5a1..bfe34f0 100644 --- a/audeer/core/io.py +++ b/audeer/core/io.py @@ -345,7 +345,7 @@ def list_dir_names( """List of folder names located inside provided path. Args: - path: path to directory + path: path to directory or pattern basenames: if ``True`` return relative path in respect to ``path`` recursive: if ``True`` includes subdirectories hidden: if ``True`` includes directories starting with a dot (``.``) @@ -354,8 +354,10 @@ def list_dir_names( list of paths to directories Raises: - NotADirectoryError: if path is not a directory - FileNotFoundError: if path does not exists + NotADirectoryError: if ``os.path.dirname(path)`` + is not a directory + FileNotFoundError: if ``os.path.dirname(path)`` + does not exists Examples: >>> _ = mkdir('path/a/.b/c') @@ -380,16 +382,30 @@ def list_dir_names( """ path = safe_path(path) + if not os.path.isdir(path): + pattern = os.path.basename(path) + path = os.path.dirname(path) + else: + pattern = None def helper(p: str, paths: typing.List[str]): ps = [os.path.join(p, x) for x in os.listdir(p)] - ps = [x for x in ps if os.path.isdir(x)] + folders = [x for x in ps if os.path.isdir(x)] + if pattern: + folders = [ + folder for folder in folders + if fnmatch.fnmatch(os.path.basename(folder), f'{pattern}') + ] + if not hidden: - ps = [x for x in ps if not os.path.basename(x).startswith('.')] - paths.extend(ps) - if len(ps) > 0 and recursive: - for p in ps: - helper(p, paths) + folders = [ + folder for folder in folders + if not os.path.basename(folder).startswith('.') + ] + paths.extend(folders) + if len(folders) > 0 and recursive: + for folder in folders: + helper(folder, paths) paths = [] helper(path, paths) diff --git a/tests/test_io.py b/tests/test_io.py index 062f314..e310ec1 100644 --- a/tests/test_io.py +++ b/tests/test_io.py @@ -190,38 +190,47 @@ def test_file_extension(path, extension): @pytest.mark.parametrize( - 'dir_list,expected,recursive,hidden', + 'dir_list,path,expected,recursive,hidden', [ - ([], [], False, False), - ([], [], True, False), - (['a', 'b', 'c'], ['a', 'b', 'c'], False, False), - (['a', 'b', 'c'], ['a', 'b', 'c'], True, False), - (['a'], ['a'], False, False), - (['a'], ['a'], True, False), + ([], './', [], False, False), + ([], './', [], True, False), + (['a', 'b', 'c'], './', ['a', 'b', 'c'], False, False), + (['a', 'b', 'c'], './', ['a', 'b', 'c'], True, False), + (['a'], './', ['a'], False, False), + (['a'], './', ['a'], True, False), ( ['a', os.path.join('a', 'b'), os.path.join('a', 'b', 'c')], + './', ['a', os.path.join('a', 'b'), os.path.join('a', 'b', 'c')], True, False, ), + # pattern + ([], './a', [], True, False), + (['a', 'b', 'c'], './a', [], False, False), + (['a', 'b', 'c'], './a', [], True, False), + (['aa', 'ba', 'ca'], './a*', ['aa'], False, False), + (['aa', 'ba', 'ca'], './ba', [], False, False), # hidden - (['a', '.b'], ['a'], True, False), - (['a', '.b'], ['.b', 'a'], True, True), + (['a', '.b'], './', ['a'], True, False), + (['a', '.b'], './', ['.b', 'a'], True, True), ( ['a', '.b', os.path.join('a', '.b'), os.path.join('a', '.b', 'c')], + './', ['a'], True, False, ), ( ['a', '.b', os.path.join('a', '.b'), os.path.join('a', '.b', 'c')], + './', ['.b', 'a', os.path.join('a', '.b'), os.path.join('a', '.b', 'c')], True, True, ), ], ) -def test_list_dir_names(tmpdir, dir_list, expected, recursive, hidden): +def test_list_dir_names(tmpdir, dir_list, path, expected, recursive, hidden): dir_tmp = tmpdir.mkdir('folder') directories = [] @@ -232,9 +241,9 @@ def test_list_dir_names(tmpdir, dir_list, expected, recursive, hidden): for directory in directories: assert os.path.isdir(directory) - path = os.path.join(str(dir_tmp), '.') + abs_path = os.path.join(str(dir_tmp), path) dirs = audeer.list_dir_names( - path, + abs_path, basenames=False, recursive=recursive, hidden=hidden, @@ -244,7 +253,7 @@ def test_list_dir_names(tmpdir, dir_list, expected, recursive, hidden): # test basenames dirs = audeer.list_dir_names( - path, + abs_path, basenames=True, recursive=recursive, hidden=hidden, @@ -254,10 +263,10 @@ def test_list_dir_names(tmpdir, dir_list, expected, recursive, hidden): def test_list_dir_names_errors(tmpdir): with pytest.raises(NotADirectoryError): - file = audeer.touch(audeer.path(tmpdir, 'file.txt')) - audeer.list_dir_names(file) + file = audeer.touch(audeer.path(tmpdir, 'file')) + audeer.list_dir_names(audeer.path(file, 'pattern')) with pytest.raises(FileNotFoundError): - audeer.list_dir_names('not-existent') + audeer.list_dir_names('not-existent/pattern') @pytest.mark.parametrize(