diff --git a/src/daft-io/src/object_store_glob.rs b/src/daft-io/src/object_store_glob.rs index 5380d9c9d5..1a57ef2978 100644 --- a/src/daft-io/src/object_store_glob.rs +++ b/src/daft-io/src/object_store_glob.rs @@ -334,6 +334,37 @@ fn _should_return(fm: &FileMetadata) -> bool { } } +/// Validates the glob pattern before compiling it. The `globset` crate which we use for globbing is +/// very permissive and does not check for invalid usage of the '**' wildcard. This function ensures +/// that the glob pattern does not contain invalid usage of '**'. +fn verify_glob(glob: &str) -> super::Result<()> { + let re = regex::Regex::new(r"(?P.*?[^\\])\*\*(?P[^/\n].*)").unwrap(); + + if let Some(captures) = re.captures(glob) { + let before = captures.name("before").map_or("", |m| m.as_str()); + let after = captures.name("after").map_or("", |m| m.as_str()); + + // Ensure the 'before' part ends with a delimiter + let corrected_before = if !before.ends_with('/') { + format!("{}/", before) + } else { + before.to_string() + }; + + let corrected_pattern = format!("{corrected_before}**/*{after}"); + + return Err(super::Error::InvalidArgument { + msg: format!( + "Invalid usage of '**' in glob pattern. Found '{before}**{after}'. \ + The '**' wildcard should be used to match directories and must be surrounded by delimiters. \ + Did you perhaps mean: '{corrected_pattern}'?" + ), + }); + } + + Ok(()) +} + /// Globs an ObjectSource for Files /// /// Uses the `globset` crate for matching, and thus supports all the syntax enabled by that crate. @@ -404,6 +435,10 @@ pub async fn glob( }; let glob = glob.as_str(); + // Validate the glob pattern, this is necessary since the `globset` crate is overly permissive and happily compiles patterns + // like "/foo/bar/**.txt" which don't make sense. + verify_glob(glob)?; + let glob_fragments = to_glob_fragments(glob)?; let full_glob_matcher = GlobBuilder::new(glob) .literal_separator(true) @@ -662,3 +697,37 @@ pub async fn glob( Ok(to_rtn_stream.boxed()) } + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_verify_glob() { + // Test valid glob patterns + assert!(verify_glob("valid/pattern.txt").is_ok()); // Normal globbing works ok + assert!(verify_glob("another/valid/pattern/**/blah.txt").is_ok()); // No error if ** used as a segment + assert!(verify_glob("**").is_ok()); // ** by itself is ok + assert!(verify_glob("another/valid/pattern/**").is_ok()); // No trailing slash is ok + assert!(verify_glob("another/valid/pattern/**/").is_ok()); // Trailing slash is ok (should be interpreted as **/*) + assert!(verify_glob("another/valid/pattern/**/\\**.txt").is_ok()); // Escaped ** is ok + assert!(verify_glob("**/wildcard/*.txt").is_ok()); // Wildcard matching not affected + + // Test invalid glob patterns and check error messages + // The '**' wildcard should be used to match directories and must be surrounded by delimiters. + let err = verify_glob("invalid/**.txt").unwrap_err(); + assert!(err.to_string().contains("invalid/**/*.txt")); // Suggests adding a delimiter after '**' + + // '**' should be surrounded by delimiters to match directories, not used directly with file names. + let err = verify_glob("invalid/blahblah**.txt").unwrap_err(); + assert!(err.to_string().contains("invalid/blahblah/**/*.txt")); // Suggests adding a delimiter before '**' + + // Backslash should only escape the first '*', leading to non-escaped '**'. + let err = verify_glob("invalid/\\***.txt").unwrap_err(); + assert!(err.to_string().contains("invalid/\\\\*/**/*.txt")); // Suggests correcting the escape sequence (NOTE: double backslash) + + // Non-escaped '**' should trigger even when there is an escaped '**'. + let err = verify_glob("invalid/\\**blahblah**.txt").unwrap_err(); + assert!(err.to_string().contains("invalid/\\\\**blahblah/**/*.txt")); // Suggests adding delimiters around '**' + } +} diff --git a/tests/integration/io/parquet/test_reads_s3_minio.py b/tests/integration/io/parquet/test_reads_s3_minio.py index 307e9e5dfa..cd5d9ef988 100644 --- a/tests/integration/io/parquet/test_reads_s3_minio.py +++ b/tests/integration/io/parquet/test_reads_s3_minio.py @@ -35,7 +35,11 @@ def test_minio_parquet_read_no_files(minio_io_config): fs.touch("s3://data-engineering-prod/foo/file.txt") with pytest.raises(FileNotFoundError, match="Glob path had no matches:"): - daft.read_parquet("s3://data-engineering-prod/foo/**.parquet", io_config=minio_io_config) + # Need to have a special character within the test path to trigger the matching logic + daft.read_parquet( + "s3://data-engineering-prod/foo/this-should-not-match-anything-and-this-file-should-not-exist-*.parquet", + io_config=minio_io_config, + ) @pytest.mark.integration() diff --git a/tests/io/test_list_files_local.py b/tests/io/test_list_files_local.py index 234da538c9..90a24a1fe9 100644 --- a/tests/io/test_list_files_local.py +++ b/tests/io/test_list_files_local.py @@ -1,11 +1,13 @@ from __future__ import annotations import os +import re import pytest from fsspec.implementations.local import LocalFileSystem from daft.daft import io_glob +from daft.exceptions import DaftCoreException def local_recursive_list(fs, path) -> list: @@ -165,3 +167,44 @@ def test_missing_file_path(tmp_path, include_protocol): p = "file://" + p with pytest.raises(FileNotFoundError, match="/c/cc/ddd not found"): io_glob(p) + + +@pytest.mark.parametrize("include_protocol", [False, True]) +def test_invalid_double_asterisk_usage_local(tmp_path, include_protocol): + d = tmp_path / "dir" + d.mkdir() + + path = str(d) + "/**.pq" + if include_protocol: + path = "file://" + path + + expected_correct_path = str(d) + "/**/*.pq" + if include_protocol: + expected_correct_path = "file://" + expected_correct_path + + # Need to escape these or the regex matcher will complain + expected_correct_path = re.escape(expected_correct_path) + + with pytest.raises(DaftCoreException, match=expected_correct_path): + io_glob(path) + + +@pytest.mark.parametrize("include_protocol", [False, True]) +def test_literal_double_asterisk_file(tmp_path, include_protocol): + d = tmp_path / "dir" + d.mkdir() + file_with_literal_name = d / "*.pq" + file_with_literal_name.touch() + + path = str(d) + "/\**.pq" + if include_protocol: + path = "file://" + path + + fs = LocalFileSystem() + fs_result = fs.ls(str(d), detail=True) + fs_result = [f for f in fs_result if f["name"] == str(file_with_literal_name)] + + daft_ls_result = io_glob(path) + + assert len(daft_ls_result) == 1 + compare_local_result(daft_ls_result, fs_result)