Skip to content

Commit

Permalink
Add GlobFragment abstraction on top of raw Strings to help with escap…
Browse files Browse the repository at this point in the history
…e characters
  • Loading branch information
Jay Chia committed Sep 29, 2023
1 parent 4a9bf7e commit da44cf1
Show file tree
Hide file tree
Showing 2 changed files with 149 additions and 31 deletions.
161 changes: 130 additions & 31 deletions src/daft-io/src/object_io.rs
Original file line number Diff line number Diff line change
@@ -1,3 +1,4 @@
use std::collections::HashSet;
use std::ops::Range;
use std::sync::Arc;

Expand All @@ -6,6 +7,7 @@ use bytes::Bytes;
use futures::stream::{BoxStream, Stream};
use futures::StreamExt;
use globset::GlobBuilder;
use lazy_static::lazy_static;
use tokio::sync::mpsc::Sender;
use tokio::sync::OwnedSemaphorePermit;
use url::Position;
Expand Down Expand Up @@ -115,19 +117,106 @@ pub(crate) trait ObjectSource: Sync + Send {
}
}

/// Checks if a given string contains special glob characters
/// NOTE: we use the `globset` crate which defines the following glob behavior:
/// https://docs.rs/globset/latest/globset/index.html#syntax
fn contains_special_character(s: &str) -> bool {
s.contains('*') || s.contains('?') || s.contains('{') || s.contains('[')
lazy_static! {
    /// Characters that are treated as special glob characters.
    /// NOTE: we use the `globset` crate which defines the following glob behavior:
    /// https://docs.rs/globset/latest/globset/index.html#syntax
    static ref GLOB_SPECIAL_CHARACTERS: HashSet<char> = HashSet::from(['*', '?', '{', '[']);
}

#[derive(Debug, Clone)]
struct GlobFragment {
data: String,
first_wildcard_idx: Option<usize>,
}

impl GlobFragment {
pub fn new(data: &str) -> Self {
let first_wildcard_idx = if data.is_empty() {
None
} else if GLOB_SPECIAL_CHARACTERS.contains(&data.chars().nth(0).unwrap()) {
Some(0)
} else {
let mut idx = None;
for (i, window) in data
.chars()
.collect::<Vec<char>>()
.as_slice()
.windows(2)
.enumerate()
{
let &[c1, c2] = window else {
unreachable!("Window contains 2 elements")
};
if (c1 != '\\') && GLOB_SPECIAL_CHARACTERS.contains(&c2) {
idx = Some(i + 1);
break;
}
}
idx
};
GlobFragment {
data: data.to_string(),
first_wildcard_idx,
}
}

pub fn has_special_character(&self) -> bool {
self.first_wildcard_idx.is_some()
}

pub fn join(fragments: &[GlobFragment], sep: &str) -> Self {
GlobFragment::new(
fragments
.iter()
.map(|frag: &GlobFragment| frag.data.as_str())
.collect::<Vec<&str>>()
.join(sep)
.as_str(),
)
}

pub fn escaped_str(&self) -> String {
// Clean up the string by applying backslash escapes:
// 1. \\ is cleaned up to just \
// 2. \ followed by anything else is just ignored
let mut result = String::new();
let mut ptr = 0;
while ptr < self.data.len() {
let remaining = &self.data.as_str()[ptr..];
match remaining.find("\\\\") {
Some(backslash_idx) => {
result.push_str(&remaining[..backslash_idx].replace('\\', ""));
result.extend(std::iter::once('\\'));
ptr += backslash_idx + 2;
}
None => {
result.push_str(&remaining.replace('\\', ""));
break;
}
}
}
result
}

pub fn raw_str(&self) -> &str {
self.data.as_str()
}
}

/// Parses a glob URL string into "fragments"
/// Fragments are the glob URL string but:
/// 1. Split by delimiter ("/")
/// 2. Non-wildcard fragments are joined and coalesced by delimiter
/// 3. The first fragment is prefixed by "{scheme}://"
fn to_glob_fragments(glob_str: &str) -> Vec<String> {
fn to_glob_fragments(glob_str: &str) -> Vec<GlobFragment> {
let delimiter = "/".to_string();
let glob_url = url::Url::parse(glob_str)
.unwrap_or_else(|_| panic!("Glob string must be able to be parsed as URL: {glob_str}"));
Expand All @@ -137,28 +226,34 @@ fn to_glob_fragments(glob_str: &str) -> Vec<String> {
let mut glob_fragments = glob_url[Position::BeforeUsername..].split(&delimiter).fold(
(vec![], vec![]),
|(mut acc, mut fragments_so_far), current_fragment| {
if contains_special_character(current_fragment) {
let current_fragment = GlobFragment::new(current_fragment);
if current_fragment.has_special_character() {
if !fragments_so_far.is_empty() {
acc.push(fragments_so_far.join(delimiter.as_str()));
acc.push(GlobFragment::join(
fragments_so_far.as_slice(),
delimiter.as_str(),
));
}
acc.push(current_fragment.to_string());
acc.push(current_fragment);
(acc, vec![])
} else {
fragments_so_far.push(current_fragment.to_string());
fragments_so_far.push(current_fragment);
(acc, fragments_so_far)
}
},
);
let mut glob_fragments = if glob_fragments.1.is_empty() {
glob_fragments.0
} else {
let last_fragment = GlobFragment::join(glob_fragments.1.as_slice(), delimiter.as_str());
glob_fragments
.0
.drain(..)
.chain(std::iter::once(glob_fragments.1.join(delimiter.as_str())))
.chain(std::iter::once(last_fragment))
.collect()
};
glob_fragments[0] = format!("{url_scheme}://") + glob_fragments[0].as_str();
glob_fragments[0] =
GlobFragment::new((format!("{url_scheme}://") + glob_fragments[0].raw_str()).as_str());

glob_fragments
}
Expand All @@ -181,22 +276,24 @@ pub(crate) async fn glob(
result_tx: Sender<super::Result<FileMetadata>>,
source: Arc<dyn ObjectSource>,
path: &str,
glob_fragments: (Vec<String>, usize),
glob_fragments: (Vec<GlobFragment>, usize),
) {
let path = path.to_string();
tokio::spawn(async move {
log::debug!(target: "glob", "Visiting '{path}' with glob_fragments: {glob_fragments:?}");
let (glob_fragments, i) = glob_fragments;
let current_fragment = glob_fragments[i].as_str();
let current_fragment = &glob_fragments[i];

// BASE CASE: current_fragment contains a **
// BASE CASE: current_fragment is a **
// We perform a recursive ls and filter on the results for only FileType::File results that match the full glob
if current_fragment.contains("**") {
let glob_matcher = GlobBuilder::new(glob_fragments.join("/").as_str())
.literal_separator(true)
.build()
.expect("Cannot parse glob")
.compile_matcher();
if current_fragment.escaped_str() == "**" {
let glob_matcher =
GlobBuilder::new(GlobFragment::join(glob_fragments.as_slice(), "/").raw_str())
.literal_separator(true)
.backslash_escape(true)
.build()
.expect("Cannot parse glob")
.compile_matcher();

let next_level_file_metadata =
source.iter_dir(path.as_str(), Some("/"), None).await;
Expand Down Expand Up @@ -236,14 +333,16 @@ pub(crate) async fn glob(

// BASE CASE: current fragment is the last fragment in `glob_fragments`
} else if i == glob_fragments.len() - 1 {
let glob_matcher = GlobBuilder::new(glob_fragments.join("/").as_str())
.literal_separator(true)
.build()
.expect("Cannot parse glob")
.compile_matcher();
let glob_matcher =
GlobBuilder::new(GlobFragment::join(glob_fragments.as_slice(), "/").raw_str())
.literal_separator(true)
.backslash_escape(true)
.build()
.expect("Cannot parse glob")
.compile_matcher();

// Last fragment contains a wildcard: we list the last level and match against the full glob
if contains_special_character(current_fragment) {
if current_fragment.has_special_character() {
let next_level_file_metadata =
source.iter_dir(path.as_str(), Some("/"), None).await;

Expand Down Expand Up @@ -271,7 +370,7 @@ pub(crate) async fn glob(
}
// Last fragment does not contain wildcard: we just need to check that the full path exists and is a File
} else {
let full_dir_path = path.to_string() + current_fragment;
let full_dir_path = path.to_string() + current_fragment.escaped_str().as_str();
let single_file_ls = source.ls(full_dir_path.as_str(), Some("/"), None).await;
match single_file_ls {
Ok(mut single_file_ls) => {
Expand All @@ -288,9 +387,9 @@ pub(crate) async fn glob(
}

// RECURSIVE CASE: current_fragment contains a special character (e.g. *)
} else if contains_special_character(current_fragment) {
} else if current_fragment.has_special_character() {
let partial_glob_matcher =
GlobBuilder::new(glob_fragments[..i + 1].join("/").as_str())
GlobBuilder::new(GlobFragment::join(&glob_fragments[..i + 1], "/").raw_str())
.literal_separator(true)
.build()
.expect("Cannot parse glob")
Expand Down Expand Up @@ -324,7 +423,7 @@ pub(crate) async fn glob(

// RECURSIVE CASE: current_fragment contains no special characters, and is a path to a specific File or Directory
} else {
let full_dir_path = path.to_string() + current_fragment;
let full_dir_path = path.to_string() + current_fragment.escaped_str().as_str();
visit(
result_tx.clone(),
source.clone(),
Expand Down
19 changes: 19 additions & 0 deletions tests/integration/io/test_list_files_s3_minio.py
Original file line number Diff line number Diff line change
Expand Up @@ -155,6 +155,25 @@ def test_directory_globbing_fragment_wildcard(minio_io_config, path_expect_pair)
assert sorted(daft_ls_result, key=lambda d: d["path"]) == sorted(expect, key=lambda d: d["path"])


@pytest.mark.integration()
@pytest.mark.parametrize(
    "path_expect_pair",
    [
        (r"s3://bucket/\*.match", [{"type": "File", "path": "s3://bucket/*.match", "size": 0}]),
        ("s3://bucket/\\\\.match", [{"type": "File", "path": r"s3://bucket/\.match", "size": 0}]),
        ("s3://bucket/\\a.match", [{"type": "File", "path": "s3://bucket/a.match", "size": 0}]),
    ],
)
def test_directory_globbing_escape_characters(minio_io_config, path_expect_pair):
    """Glob patterns with backslash escapes list exactly the expected object."""

    def by_path(entry):
        # Sort key so the comparison is independent of listing order.
        return entry["path"]

    glob_pattern, expected_listing = path_expect_pair
    with minio_create_bucket(minio_io_config, bucket_name="bucket") as fs:
        # Create one plain object, one with a literal '*', and one with a literal backslash.
        for object_name in ["a.match", "*.match", r"\.match"]:
            fs.touch(f"bucket/{object_name}")
        actual_listing = io_glob(glob_pattern, io_config=minio_io_config)
        assert sorted(actual_listing, key=by_path) == sorted(expected_listing, key=by_path)


@pytest.mark.integration()
def test_flat_directory_listing(minio_io_config):
bucket_name = "bucket"
Expand Down

0 comments on commit da44cf1

Please sign in to comment.