/// Characters that begin special glob syntax.
/// NOTE: we use the `globset` crate which defines the following glob behavior:
/// https://docs.rs/globset/latest/globset/index.html#syntax
const GLOB_SPECIAL_CHARACTERS: [char; 4] = ['*', '?', '{', '['];

/// A piece of a glob pattern, kept in its raw (still backslash-escaped) form
/// together with the location of its first unescaped special character.
#[derive(Debug, Clone)]
struct GlobFragment {
    /// Raw text of the fragment; backslash escapes have NOT been applied.
    data: String,
    /// Character (not byte) index of the first unescaped special glob
    /// character in `data`, or `None` if the fragment is a plain literal.
    first_wildcard_idx: Option<usize>,
}

impl GlobFragment {
    /// Builds a fragment from raw glob text and records where its first
    /// unescaped special character sits.
    ///
    /// A backslash makes the character that follows it literal. Escape state is
    /// tracked left to right, so in `\\*` the first backslash escapes the
    /// second one and the `*` is still a live wildcard, whereas in `\*` the
    /// `*` is literal.
    pub fn new(data: &str) -> Self {
        // `true` while the character being inspected is consumed by a preceding backslash.
        let mut escaped = false;
        let first_wildcard_idx = data.chars().position(|c| {
            if escaped {
                escaped = false;
                false
            } else if c == '\\' {
                escaped = true;
                false
            } else {
                GLOB_SPECIAL_CHARACTERS.contains(&c)
            }
        });
        GlobFragment {
            data: data.to_string(),
            first_wildcard_idx,
        }
    }

    /// Returns `true` if the fragment contains at least one unescaped special
    /// glob character.
    pub fn has_special_character(&self) -> bool {
        self.first_wildcard_idx.is_some()
    }

    /// Concatenates the raw text of `fragments` with `sep` between them and
    /// re-parses the result as a single fragment.
    pub fn join(fragments: &[GlobFragment], sep: &str) -> Self {
        let joined = fragments
            .iter()
            .map(|frag| frag.data.as_str())
            .collect::<Vec<_>>()
            .join(sep);
        GlobFragment::new(&joined)
    }

    /// Returns the fragment text with backslash escapes applied:
    /// 1. `\\` is cleaned up to just `\`
    /// 2. `\` followed by any other character yields that character
    /// 3. a trailing `\` with nothing after it is dropped
    pub fn escaped_str(&self) -> String {
        let mut result = String::with_capacity(self.data.len());
        let mut chars = self.data.chars();
        while let Some(c) = chars.next() {
            if c == '\\' {
                // Emit whatever the backslash escapes; a dangling backslash emits nothing.
                if let Some(literal) = chars.next() {
                    result.push(literal);
                }
            } else {
                result.push(c);
            }
        }
        result
    }

    /// Returns the raw fragment text with escapes left untouched.
    pub fn raw_str(&self) -> &str {
        self.data.as_str()
    }
}
The first fragment is prefixed by "{scheme}://" -fn to_glob_fragments(glob_str: &str) -> Vec { +fn to_glob_fragments(glob_str: &str) -> Vec { let delimiter = "/".to_string(); let glob_url = url::Url::parse(glob_str) .unwrap_or_else(|_| panic!("Glob string must be able to be parsed as URL: {glob_str}")); @@ -137,14 +226,18 @@ fn to_glob_fragments(glob_str: &str) -> Vec { let mut glob_fragments = glob_url[Position::BeforeUsername..].split(&delimiter).fold( (vec![], vec![]), |(mut acc, mut fragments_so_far), current_fragment| { - if contains_special_character(current_fragment) { + let current_fragment = GlobFragment::new(current_fragment); + if current_fragment.has_special_character() { if !fragments_so_far.is_empty() { - acc.push(fragments_so_far.join(delimiter.as_str())); + acc.push(GlobFragment::join( + fragments_so_far.as_slice(), + delimiter.as_str(), + )); } - acc.push(current_fragment.to_string()); + acc.push(current_fragment); (acc, vec![]) } else { - fragments_so_far.push(current_fragment.to_string()); + fragments_so_far.push(current_fragment); (acc, fragments_so_far) } }, @@ -152,13 +245,15 @@ fn to_glob_fragments(glob_str: &str) -> Vec { let mut glob_fragments = if glob_fragments.1.is_empty() { glob_fragments.0 } else { + let last_fragment = GlobFragment::join(glob_fragments.1.as_slice(), delimiter.as_str()); glob_fragments .0 .drain(..) 
- .chain(std::iter::once(glob_fragments.1.join(delimiter.as_str()))) + .chain(std::iter::once(last_fragment)) .collect() }; - glob_fragments[0] = format!("{url_scheme}://") + glob_fragments[0].as_str(); + glob_fragments[0] = + GlobFragment::new((format!("{url_scheme}://") + glob_fragments[0].raw_str()).as_str()); glob_fragments } @@ -181,22 +276,24 @@ pub(crate) async fn glob( result_tx: Sender>, source: Arc, path: &str, - glob_fragments: (Vec, usize), + glob_fragments: (Vec, usize), ) { let path = path.to_string(); tokio::spawn(async move { log::debug!(target: "glob", "Visiting '{path}' with glob_fragments: {glob_fragments:?}"); let (glob_fragments, i) = glob_fragments; - let current_fragment = glob_fragments[i].as_str(); + let current_fragment = &glob_fragments[i]; - // BASE CASE: current_fragment contains a ** + // BASE CASE: current_fragment is a ** // We perform a recursive ls and filter on the results for only FileType::File results that match the full glob - if current_fragment.contains("**") { - let glob_matcher = GlobBuilder::new(glob_fragments.join("/").as_str()) - .literal_separator(true) - .build() - .expect("Cannot parse glob") - .compile_matcher(); + if current_fragment.escaped_str() == "**" { + let glob_matcher = + GlobBuilder::new(GlobFragment::join(glob_fragments.as_slice(), "/").raw_str()) + .literal_separator(true) + .backslash_escape(true) + .build() + .expect("Cannot parse glob") + .compile_matcher(); let next_level_file_metadata = source.iter_dir(path.as_str(), Some("/"), None).await; @@ -236,14 +333,16 @@ pub(crate) async fn glob( // BASE CASE: current fragment is the last fragment in `glob_fragments` } else if i == glob_fragments.len() - 1 { - let glob_matcher = GlobBuilder::new(glob_fragments.join("/").as_str()) - .literal_separator(true) - .build() - .expect("Cannot parse glob") - .compile_matcher(); + let glob_matcher = + GlobBuilder::new(GlobFragment::join(glob_fragments.as_slice(), "/").raw_str()) + .literal_separator(true) + 
.backslash_escape(true) + .build() + .expect("Cannot parse glob") + .compile_matcher(); // Last fragment contains a wildcard: we list the last level and match against the full glob - if contains_special_character(current_fragment) { + if current_fragment.has_special_character() { let next_level_file_metadata = source.iter_dir(path.as_str(), Some("/"), None).await; @@ -271,7 +370,7 @@ pub(crate) async fn glob( } // Last fragment does not contain wildcard: we just need to check that the full path exists and is a File } else { - let full_dir_path = path.to_string() + current_fragment; + let full_dir_path = path.to_string() + current_fragment.escaped_str().as_str(); let single_file_ls = source.ls(full_dir_path.as_str(), Some("/"), None).await; match single_file_ls { Ok(mut single_file_ls) => { @@ -288,9 +387,9 @@ pub(crate) async fn glob( } // RECURSIVE CASE: current_fragment contains a special character (e.g. *) - } else if contains_special_character(current_fragment) { + } else if current_fragment.has_special_character() { let partial_glob_matcher = - GlobBuilder::new(glob_fragments[..i + 1].join("/").as_str()) + GlobBuilder::new(GlobFragment::join(&glob_fragments[..i + 1], "/").raw_str()) .literal_separator(true) .build() .expect("Cannot parse glob") @@ -324,7 +423,7 @@ pub(crate) async fn glob( // RECURSIVE CASE: current_fragment contains no special characters, and is a path to a specific File or Directory } else { - let full_dir_path = path.to_string() + current_fragment; + let full_dir_path = path.to_string() + current_fragment.escaped_str().as_str(); visit( result_tx.clone(), source.clone(), diff --git a/tests/integration/io/test_list_files_s3_minio.py b/tests/integration/io/test_list_files_s3_minio.py index 9c0c4c32ea..380e3207c1 100644 --- a/tests/integration/io/test_list_files_s3_minio.py +++ b/tests/integration/io/test_list_files_s3_minio.py @@ -155,6 +155,25 @@ def test_directory_globbing_fragment_wildcard(minio_io_config, path_expect_pair) assert 
sorted(daft_ls_result, key=lambda d: d["path"]) == sorted(expect, key=lambda d: d["path"]) +@pytest.mark.integration() +@pytest.mark.parametrize( + "path_expect_pair", + [ + (r"s3://bucket/\*.match", [{"type": "File", "path": "s3://bucket/*.match", "size": 0}]), + ("s3://bucket/\\\\.match", [{"type": "File", "path": r"s3://bucket/\.match", "size": 0}]), + ("s3://bucket/\\a.match", [{"type": "File", "path": "s3://bucket/a.match", "size": 0}]), + ], +) +def test_directory_globbing_escape_characters(minio_io_config, path_expect_pair): + globpath, expect = path_expect_pair + with minio_create_bucket(minio_io_config, bucket_name="bucket") as fs: + files = ["a.match", "*.match", r"\.match"] + for name in files: + fs.touch(f"bucket/{name}") + daft_ls_result = io_glob(globpath, io_config=minio_io_config) + assert sorted(daft_ls_result, key=lambda d: d["path"]) == sorted(expect, key=lambda d: d["path"]) + + @pytest.mark.integration() def test_flat_directory_listing(minio_io_config): bucket_name = "bucket"