Skip to content

Commit

Permalink
[FEAT]: Throw error for invalid ** usage outside folder segments (e.g…
Browse files Browse the repository at this point in the history
…. /tmp/**.csv) (#3100)

Closes #1820.

The main issue seems to be that the `globset` crate is permissive about the
patterns it builds: no error is thrown when we try to build a pattern
for `/tmp/**.csv`, for instance, so we have to check for any such
patterns ourselves.
  • Loading branch information
conradsoon authored Oct 31, 2024
1 parent 301cd48 commit 073ec37
Show file tree
Hide file tree
Showing 3 changed files with 117 additions and 1 deletion.
69 changes: 69 additions & 0 deletions src/daft-io/src/object_store_glob.rs
Original file line number Diff line number Diff line change
Expand Up @@ -334,6 +334,37 @@ fn _should_return(fm: &FileMetadata) -> bool {
}
}

/// Validates a glob pattern before it is compiled.
///
/// The `globset` crate which we use for globbing is very permissive and does not reject
/// invalid usage of the `**` wildcard (for example `/tmp/**.csv`). This function flags any
/// `**` that is not preceded by a backslash and is immediately followed by a character
/// other than `/`, including a `**` at the very start of the pattern (e.g. `**.csv`).
///
/// Only the text *after* the `**` is inspected, so a pattern such as `foo**/bar` is not
/// rejected here.
///
/// # Errors
///
/// Returns `super::Error::InvalidArgument` when such a `**` is found. The message quotes the
/// offending pattern and suggests a corrected one in which `**` forms its own path segment.
fn verify_glob(glob: &str) -> super::Result<()> {
    // The pattern is a constant, so compile it once and reuse it across calls instead of
    // paying for regex compilation on every glob.
    static INVALID_DOUBLE_STAR: std::sync::OnceLock<regex::Regex> = std::sync::OnceLock::new();
    let re = INVALID_DOUBLE_STAR.get_or_init(|| {
        // `before` is either empty (the pattern starts with `**`) or ends with a character
        // that is not a backslash, so that an escaped `\**` is not treated as a wildcard.
        // `after` must start with something other than `/`, which is what makes the usage invalid.
        regex::Regex::new(r"(?P<before>^|.*?[^\\])\*\*(?P<after>[^/\n].*)")
            .expect("hard-coded glob validation regex must compile")
    });

    if let Some(captures) = re.captures(glob) {
        let before = captures.name("before").map_or("", |m| m.as_str());
        let after = captures.name("after").map_or("", |m| m.as_str());

        // Ensure the 'before' part ends with a delimiter. An empty prefix is left alone so
        // that a relative pattern is not turned into an absolute one in the suggestion.
        let corrected_before = if before.is_empty() || before.ends_with('/') {
            before.to_string()
        } else {
            format!("{before}/")
        };

        let corrected_pattern = format!("{corrected_before}**/*{after}");

        return Err(super::Error::InvalidArgument {
            msg: format!(
                "Invalid usage of '**' in glob pattern. Found '{before}**{after}'. \
                The '**' wildcard should be used to match directories and must be surrounded by delimiters. \
                Did you perhaps mean: '{corrected_pattern}'?"
            ),
        });
    }

    Ok(())
}

/// Globs an ObjectSource for Files
///
/// Uses the `globset` crate for matching, and thus supports all the syntax enabled by that crate.
Expand Down Expand Up @@ -404,6 +435,10 @@ pub async fn glob(
};
let glob = glob.as_str();

// Validate the glob pattern, this is necessary since the `globset` crate is overly permissive and happily compiles patterns
// like "/foo/bar/**.txt" which don't make sense.
verify_glob(glob)?;

let glob_fragments = to_glob_fragments(glob)?;
let full_glob_matcher = GlobBuilder::new(glob)
.literal_separator(true)
Expand Down Expand Up @@ -662,3 +697,37 @@ pub async fn glob(

Ok(to_rtn_stream.boxed())
}

#[cfg(test)]
mod tests {
    use super::*;

    /// Checks `verify_glob` against a table of patterns it must accept and a table of
    /// patterns it must reject, together with the suggestion expected in each error.
    #[test]
    fn test_verify_glob() {
        // Patterns that must pass validation.
        let accepted = [
            // Plain path without any wildcard.
            "valid/pattern.txt",
            // `**` used as a full path segment.
            "another/valid/pattern/**/blah.txt",
            // `**` on its own.
            "**",
            // `**` as the final segment, without a trailing slash.
            "another/valid/pattern/**",
            // `**` as the final segment, with a trailing slash.
            "another/valid/pattern/**/",
            // The first `*` is backslash-escaped, so this is not a bare `**`.
            "another/valid/pattern/**/\\**.txt",
            // Single-`*` wildcards are unaffected.
            "**/wildcard/*.txt",
        ];
        for pattern in accepted {
            assert!(verify_glob(pattern).is_ok());
        }

        // Patterns that must be rejected, paired with the corrected pattern that the
        // error text is expected to contain.
        let rejected = [
            // Missing delimiter after `**`.
            ("invalid/**.txt", "invalid/**/*.txt"),
            // Missing delimiters on both sides of `**`.
            ("invalid/blahblah**.txt", "invalid/blahblah/**/*.txt"),
            // The backslash escapes only the first `*`, leaving an unescaped `**` behind it.
            // NOTE(review): the expected text contains a doubled backslash; presumably the
            // error's string form escapes the message -- confirm against the Error definition.
            ("invalid/\\***.txt", "invalid/\\\\*/**/*.txt"),
            // An unescaped `**` is still caught when an escaped one precedes it.
            ("invalid/\\**blahblah**.txt", "invalid/\\\\**blahblah/**/*.txt"),
        ];
        for (pattern, suggestion) in rejected {
            let err = verify_glob(pattern).unwrap_err();
            assert!(err.to_string().contains(suggestion));
        }
    }
}
6 changes: 5 additions & 1 deletion tests/integration/io/parquet/test_reads_s3_minio.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,7 +35,11 @@ def test_minio_parquet_read_no_files(minio_io_config):
fs.touch("s3://data-engineering-prod/foo/file.txt")

with pytest.raises(FileNotFoundError, match="Glob path had no matches:"):
daft.read_parquet("s3://data-engineering-prod/foo/**.parquet", io_config=minio_io_config)
# Need to have a special character within the test path to trigger the matching logic
daft.read_parquet(
"s3://data-engineering-prod/foo/this-should-not-match-anything-and-this-file-should-not-exist-*.parquet",
io_config=minio_io_config,
)


@pytest.mark.integration()
Expand Down
43 changes: 43 additions & 0 deletions tests/io/test_list_files_local.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,13 @@
from __future__ import annotations

import os
import re

import pytest
from fsspec.implementations.local import LocalFileSystem

from daft.daft import io_glob
from daft.exceptions import DaftCoreException


def local_recursive_list(fs, path) -> list:
Expand Down Expand Up @@ -165,3 +167,44 @@ def test_missing_file_path(tmp_path, include_protocol):
p = "file://" + p
with pytest.raises(FileNotFoundError, match="/c/cc/ddd not found"):
io_glob(p)


@pytest.mark.parametrize("include_protocol", [False, True])
def test_invalid_double_asterisk_usage_local(tmp_path, include_protocol):
    """Globbing `<dir>/**.pq` must fail with an error that suggests `<dir>/**/*.pq`."""
    base_dir = tmp_path / "dir"
    base_dir.mkdir()

    # The same optional scheme prefix applies to both the input and the suggested pattern.
    prefix = "file://" if include_protocol else ""
    invalid_glob = prefix + str(base_dir) + "/**.pq"
    suggested_glob = prefix + str(base_dir) + "/**/*.pq"

    # pytest treats `match` as a regex, so the glob metacharacters have to be escaped.
    with pytest.raises(DaftCoreException, match=re.escape(suggested_glob)):
        io_glob(invalid_glob)


@pytest.mark.parametrize("include_protocol", [False, True])
def test_literal_double_asterisk_file(tmp_path, include_protocol):
    """A glob whose first `*` is backslash-escaped is accepted and finds the file named `*.pq`."""
    d = tmp_path / "dir"
    d.mkdir()
    file_with_literal_name = d / "*.pq"
    file_with_literal_name.touch()

    # The glob is `\**.pq`: a backslash-escaped `*` followed by an ordinary `*` wildcard.
    # The backslash is written as "\\" because "\*" is not a valid Python escape sequence
    # and triggers a warning; the resulting string is unchanged.
    path = str(d) + "/\\**.pq"
    if include_protocol:
        path = "file://" + path

    # Build the expected listing from the local filesystem, keeping only the target file.
    fs = LocalFileSystem()
    fs_result = fs.ls(str(d), detail=True)
    fs_result = [f for f in fs_result if f["name"] == str(file_with_literal_name)]

    daft_ls_result = io_glob(path)

    assert len(daft_ls_result) == 1
    compare_local_result(daft_ls_result, fs_result)

0 comments on commit 073ec37

Please sign in to comment.