Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Allow path to be pattern in list_dir_names() #92

Draft
wants to merge 2 commits into
base: main
Choose a base branch
from
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 25 additions & 9 deletions audeer/core/io.py
Original file line number Diff line number Diff line change
Expand Up @@ -345,7 +345,7 @@ def list_dir_names(
"""List of folder names located inside provided path.

Args:
path: path to directory
path: path to directory or pattern
basenames: if ``True`` return relative path in respect to ``path``
recursive: if ``True`` includes subdirectories
hidden: if ``True`` includes directories starting with a dot (``.``)
Expand All @@ -354,8 +354,10 @@ def list_dir_names(
list of paths to directories

Raises:
NotADirectoryError: if path is not a directory
FileNotFoundError: if path does not exists
NotADirectoryError: if ``os.path.dirname(path)``
is not a directory
FileNotFoundError: if ``os.path.dirname(path)``
does not exists

Examples:
>>> _ = mkdir('path/a/.b/c')
Expand All @@ -380,16 +382,30 @@ def list_dir_names(

"""
path = safe_path(path)
if not os.path.isdir(path):
pattern = os.path.basename(path)
path = os.path.dirname(path)
else:
pattern = None

def helper(p: str, paths: typing.List[str]):
ps = [os.path.join(p, x) for x in os.listdir(p)]
ps = [x for x in ps if os.path.isdir(x)]
folders = [x for x in ps if os.path.isdir(x)]
if pattern:
folders = [
folder for folder in folders
if fnmatch.fnmatch(os.path.basename(folder), f'{pattern}')
]

if not hidden:
ps = [x for x in ps if not os.path.basename(x).startswith('.')]
paths.extend(ps)
if len(ps) > 0 and recursive:
for p in ps:
helper(p, paths)
folders = [
folder for folder in folders
if not os.path.basename(folder).startswith('.')
]
paths.extend(folders)
if len(folders) > 0 and recursive:
for folder in folders:
helper(folder, paths)

paths = []
helper(path, paths)
Expand Down
41 changes: 25 additions & 16 deletions tests/test_io.py
Original file line number Diff line number Diff line change
Expand Up @@ -190,38 +190,47 @@ def test_file_extension(path, extension):


@pytest.mark.parametrize(
'dir_list,expected,recursive,hidden',
'dir_list,path,expected,recursive,hidden',
[
([], [], False, False),
([], [], True, False),
(['a', 'b', 'c'], ['a', 'b', 'c'], False, False),
(['a', 'b', 'c'], ['a', 'b', 'c'], True, False),
(['a'], ['a'], False, False),
(['a'], ['a'], True, False),
([], './', [], False, False),
([], './', [], True, False),
(['a', 'b', 'c'], './', ['a', 'b', 'c'], False, False),
(['a', 'b', 'c'], './', ['a', 'b', 'c'], True, False),
(['a'], './', ['a'], False, False),
(['a'], './', ['a'], True, False),
(
['a', os.path.join('a', 'b'), os.path.join('a', 'b', 'c')],
'./',
['a', os.path.join('a', 'b'), os.path.join('a', 'b', 'c')],
True,
False,
),
# pattern
([], './a', [], True, False),
(['a', 'b', 'c'], './a', [], False, False),
(['a', 'b', 'c'], './a', [], True, False),
(['aa', 'ba', 'ca'], './a*', ['aa'], False, False),
(['aa', 'ba', 'ca'], './ba', [], False, False),
# hidden
(['a', '.b'], ['a'], True, False),
(['a', '.b'], ['.b', 'a'], True, True),
(['a', '.b'], './', ['a'], True, False),
(['a', '.b'], './', ['.b', 'a'], True, True),
(
['a', '.b', os.path.join('a', '.b'), os.path.join('a', '.b', 'c')],
'./',
['a'],
True,
False,
),
(
['a', '.b', os.path.join('a', '.b'), os.path.join('a', '.b', 'c')],
'./',
['.b', 'a', os.path.join('a', '.b'), os.path.join('a', '.b', 'c')],
True,
True,
),
],
)
def test_list_dir_names(tmpdir, dir_list, expected, recursive, hidden):
def test_list_dir_names(tmpdir, dir_list, path, expected, recursive, hidden):

dir_tmp = tmpdir.mkdir('folder')
directories = []
Expand All @@ -232,9 +241,9 @@ def test_list_dir_names(tmpdir, dir_list, expected, recursive, hidden):
for directory in directories:
assert os.path.isdir(directory)

path = os.path.join(str(dir_tmp), '.')
abs_path = os.path.join(str(dir_tmp), path)
dirs = audeer.list_dir_names(
path,
abs_path,
basenames=False,
recursive=recursive,
hidden=hidden,
Expand All @@ -244,7 +253,7 @@ def test_list_dir_names(tmpdir, dir_list, expected, recursive, hidden):

# test basenames
dirs = audeer.list_dir_names(
path,
abs_path,
basenames=True,
recursive=recursive,
hidden=hidden,
Expand All @@ -254,10 +263,10 @@ def test_list_dir_names(tmpdir, dir_list, expected, recursive, hidden):

def test_list_dir_names_errors(tmpdir):
with pytest.raises(NotADirectoryError):
file = audeer.touch(audeer.path(tmpdir, 'file.txt'))
audeer.list_dir_names(file)
file = audeer.touch(audeer.path(tmpdir, 'file'))
audeer.list_dir_names(audeer.path(file, 'pattern'))
with pytest.raises(FileNotFoundError):
audeer.list_dir_names('not-existent')
audeer.list_dir_names('not-existent/pattern')


@pytest.mark.parametrize(
Expand Down