Skip to content

Commit

Permalink
feat: add url_download/url_upload to sql
Browse files Browse the repository at this point in the history
  • Loading branch information
frankliee committed Dec 28, 2024
1 parent e59581c commit 200bd52
Show file tree
Hide file tree
Showing 8 changed files with 273 additions and 60 deletions.
10 changes: 5 additions & 5 deletions src/daft-functions/src/uri/download.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,11 +12,11 @@ use snafu::prelude::*;
use crate::InvalidArgumentSnafu;

#[derive(Debug, Clone, Serialize, serde::Deserialize, PartialEq, Eq, Hash)]
pub(super) struct DownloadFunction {
pub(super) max_connections: usize,
pub(super) raise_error_on_failure: bool,
pub(super) multi_thread: bool,
pub(super) config: Arc<IOConfig>,
pub struct DownloadFunction {
pub max_connections: usize,
pub raise_error_on_failure: bool,
pub multi_thread: bool,
pub config: Arc<IOConfig>,
}

#[typetag::serde]
Expand Down
4 changes: 2 additions & 2 deletions src/daft-functions/src/uri/mod.rs
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
mod download;
mod upload;
pub mod download;
pub mod upload;

use common_io_config::IOConfig;
use daft_dsl::{functions::ScalarFunction, ExprRef};
Expand Down
12 changes: 6 additions & 6 deletions src/daft-functions/src/uri/upload.rs
Original file line number Diff line number Diff line change
Expand Up @@ -9,12 +9,12 @@ use futures::{StreamExt, TryStreamExt};
use serde::Serialize;

#[derive(Debug, Clone, Serialize, serde::Deserialize, PartialEq, Eq, Hash)]
pub(super) struct UploadFunction {
pub(super) max_connections: usize,
pub(super) raise_error_on_failure: bool,
pub(super) multi_thread: bool,
pub(super) is_single_folder: bool,
pub(super) config: Arc<IOConfig>,
pub struct UploadFunction {
pub max_connections: usize,
pub raise_error_on_failure: bool,
pub multi_thread: bool,
pub is_single_folder: bool,
pub config: Arc<IOConfig>,
}

#[typetag::serde]
Expand Down
3 changes: 2 additions & 1 deletion src/daft-sql/src/functions.rs
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@ use crate::{
coalesce::SQLCoalesce, hashing, SQLModule, SQLModuleAggs, SQLModuleConfig, SQLModuleFloat,
SQLModuleImage, SQLModuleJson, SQLModuleList, SQLModuleMap, SQLModuleNumeric,
SQLModulePartitioning, SQLModulePython, SQLModuleSketch, SQLModuleStructs,
SQLModuleTemporal, SQLModuleUtf8,
SQLModuleTemporal, SQLModuleURL, SQLModuleUtf8,
},
planner::SQLPlanner,
unsupported_sql_err,
Expand All @@ -37,6 +37,7 @@ pub(crate) static SQL_FUNCTIONS: Lazy<SQLFunctions> = Lazy::new(|| {
functions.register::<SQLModuleStructs>();
functions.register::<SQLModuleTemporal>();
functions.register::<SQLModuleUtf8>();
functions.register::<SQLModuleURL>();
functions.register::<SQLModuleConfig>();
functions.add_fn("coalesce", SQLCoalesce {});
functions
Expand Down
2 changes: 2 additions & 0 deletions src/daft-sql/src/modules/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -15,6 +15,7 @@ pub mod python;
pub mod sketch;
pub mod structs;
pub mod temporal;
pub mod url;
pub mod utf8;

pub use aggs::SQLModuleAggs;
Expand All @@ -30,6 +31,7 @@ pub use python::SQLModulePython;
pub use sketch::SQLModuleSketch;
pub use structs::SQLModuleStructs;
pub use temporal::SQLModuleTemporal;
pub use url::SQLModuleURL;
pub use utf8::SQLModuleUtf8;

/// A [SQLModule] is a collection of SQL functions that can be registered with a [SQLFunctions] instance.
Expand Down
192 changes: 192 additions & 0 deletions src/daft-sql/src/modules/url.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,192 @@
use std::sync::Arc;

use daft_dsl::{Expr, ExprRef, LiteralValue};
use daft_functions::uri::{download, download::DownloadFunction, upload, upload::UploadFunction};
use sqlparser::ast::FunctionArg;

use super::SQLModule;
use crate::{
error::{PlannerError, SQLPlannerResult},
functions::{SQLFunction, SQLFunctionArguments, SQLFunctions},
modules::config::expr_to_iocfg,
planner::SQLPlanner,
unsupported_sql_err,
};

/// SQL module exposing the URL functions (`url_download`, `url_upload`).
pub struct SQLModuleURL;

impl SQLModule for SQLModuleURL {
    /// Registers the URL functions with the SQL planner's function table.
    fn register(parent: &mut SQLFunctions) {
        // Distinct names; registration order has no runtime significance.
        parent.add_fn("url_upload", UrlUpload);
        parent.add_fn("url_download", UrlDownload);
    }
}

impl TryFrom<SQLFunctionArguments> for DownloadFunction {
type Error = PlannerError;

fn try_from(args: SQLFunctionArguments) -> Result<Self, Self::Error> {
let max_connections = args.try_get_named("max_connections")?.unwrap_or(32);
let raise_error_on_failure = args
.get_named("on_error")
.map(|arg| match arg.as_ref() {
Expr::Literal(LiteralValue::Utf8(s)) => match s.as_ref() {
"raise" => Ok(true),
"null" => Ok(false),
_ => unsupported_sql_err!("Expected on_error to be 'raise' or 'null'"),

Check warning on line 36 in src/daft-sql/src/modules/url.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-sql/src/modules/url.rs#L28-L36

Added lines #L28 - L36 were not covered by tests
},
_ => unsupported_sql_err!("Expected on_error to be 'raise' or 'null'"),
})
.transpose()?
.unwrap_or(true);

Check warning on line 41 in src/daft-sql/src/modules/url.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-sql/src/modules/url.rs#L38-L41

Added lines #L38 - L41 were not covered by tests

// TODO: choice multi_thread based on the current engine (such as ray)
let multi_thread = args.try_get_named("multi_thread")?.unwrap_or(false);

Check warning on line 44 in src/daft-sql/src/modules/url.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-sql/src/modules/url.rs#L44

Added line #L44 was not covered by tests

let config = Arc::new(
args.get_named("io_config")
.map(expr_to_iocfg)
.transpose()?
.unwrap_or_default(),
);

Ok(Self {
max_connections,
raise_error_on_failure,
multi_thread,
config,
})
}

Check warning on line 59 in src/daft-sql/src/modules/url.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-sql/src/modules/url.rs#L46-L59

Added lines #L46 - L59 were not covered by tests
}

struct UrlDownload;

impl SQLFunction for UrlDownload {
fn to_expr(
&self,
inputs: &[sqlparser::ast::FunctionArg],
planner: &crate::planner::SQLPlanner,
) -> SQLPlannerResult<ExprRef> {
match inputs {
[input, args @ ..] => {
let input = planner.plan_function_arg(input)?;
let args: DownloadFunction = planner.plan_function_args(
args,
&["max_connections", "on_error", "multi_thread", "io_config"],
0,
)?;

Check warning on line 77 in src/daft-sql/src/modules/url.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-sql/src/modules/url.rs#L65-L77

Added lines #L65 - L77 were not covered by tests

Ok(download(
input,
args.max_connections,
args.raise_error_on_failure,
args.multi_thread,
Arc::try_unwrap(args.config).unwrap_or_default().into(),
))

Check warning on line 85 in src/daft-sql/src/modules/url.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-sql/src/modules/url.rs#L79-L85

Added lines #L79 - L85 were not covered by tests
}
_ => unsupported_sql_err!("Invalid arguments for url_download: '{inputs:?}'"),

Check warning on line 87 in src/daft-sql/src/modules/url.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-sql/src/modules/url.rs#L87

Added line #L87 was not covered by tests
}
}

Check warning on line 89 in src/daft-sql/src/modules/url.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-sql/src/modules/url.rs#L89

Added line #L89 was not covered by tests

fn docstrings(&self, _: &str) -> String {
"download data from the given url".to_string()
}

fn arg_names(&self) -> &'static [&'static str] {
&[
"input",
"max_connections",
"on_error",
"multi_thread",
"io_config",
]
}
}

impl TryFrom<SQLFunctionArguments> for UploadFunction {
type Error = PlannerError;

fn try_from(args: SQLFunctionArguments) -> Result<Self, Self::Error> {
let max_connections = args.try_get_named("max_connections")?.unwrap_or(32);

let raise_error_on_failure = args
.get_named("on_error")
.map(|arg| match arg.as_ref() {
Expr::Literal(LiteralValue::Utf8(s)) => match s.as_ref() {
"raise" => Ok(true),
"null" => Ok(false),
_ => unsupported_sql_err!("Expected on_error to be 'raise' or 'null'"),

Check warning on line 118 in src/daft-sql/src/modules/url.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-sql/src/modules/url.rs#L115-L118

Added lines #L115 - L118 were not covered by tests
},
_ => unsupported_sql_err!("Expected on_error to be 'raise' or 'null'"),

Check warning on line 120 in src/daft-sql/src/modules/url.rs

View check run for this annotation

Codecov / codecov/patch

src/daft-sql/src/modules/url.rs#L120

Added line #L120 was not covered by tests
})
.transpose()?
.unwrap_or(true);

// TODO: choice multi_thread based on the current engine (such as ray)
let multi_thread = args.try_get_named("multi_thread")?.unwrap_or(false);

// by default use row_specifc_urls
let is_single_folder = false;

let config = Arc::new(
args.get_named("io_config")
.map(expr_to_iocfg)
.transpose()?
.unwrap_or_default(),
);

Ok(Self {
max_connections,
raise_error_on_failure,
multi_thread,
is_single_folder,
config,
})
}
}

/// `url_upload(input, location, [max_connections, on_error, multi_thread, io_config])`:
/// plans an upload of each row of `input` to a URL derived from `location`.
struct UrlUpload;

impl SQLFunction for UrlUpload {
    fn to_expr(&self, inputs: &[FunctionArg], planner: &SQLPlanner) -> SQLPlannerResult<ExprRef> {
        match inputs {
            [input, location, args @ ..] => {
                let input = planner.plan_function_arg(input)?;
                let location = planner.plan_function_arg(location)?;
                let mut args: UploadFunction = planner.plan_function_args(
                    args,
                    &["max_connections", "on_error", "multi_thread", "io_config"],
                    0, // all remaining args must be passed by name
                )?;
                // A literal location means every row uploads into the same
                // folder rather than to a per-row URL.
                if location.as_literal().is_some() {
                    args.is_single_folder = true;
                }

                // Take the IOConfig out of the Arc without ever discarding it:
                // the previous `Arc::try_unwrap(..).unwrap_or_default()` silently
                // fell back to a *default* config whenever the Arc was shared,
                // dropping the user's io_config. Clone on the shared path instead.
                let io_config = Arc::try_unwrap(args.config)
                    .unwrap_or_else(|shared| shared.as_ref().clone());

                Ok(upload(
                    input,
                    location,
                    args.max_connections,
                    args.raise_error_on_failure,
                    args.multi_thread,
                    args.is_single_folder,
                    io_config.into(),
                ))
            }
            _ => unsupported_sql_err!("Invalid arguments for url_upload: '{inputs:?}'"),
        }
    }

    fn docstrings(&self, _: &str) -> String {
        "upload data to the given path".to_string()
    }

    fn arg_names(&self) -> &'static [&'static str] {
        &[
            "input",
            "location",
            "max_connections",
            "on_error",
            "multi_thread",
            "io_config",
        ]
    }
}
17 changes: 12 additions & 5 deletions tests/io/test_url_download_local.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,18 +30,25 @@ def local_image_data_fixture(tmpdir, image_data) -> YieldFixture[list[str]]:
def test_url_download_local(local_image_data_fixture, image_data):
data = {"urls": local_image_data_fixture}
df = daft.from_pydict(data)
df = df.with_column("data", df["urls"].url.download())
assert df.to_pydict() == {**data, "data": [image_data for _ in range(len(local_image_data_fixture))]}

def check_results(df):
assert df.to_pydict() == {**data, "data": [image_data for _ in range(len(local_image_data_fixture))]}

check_results(df.with_column("data", df["urls"].url.download()))
check_results(daft.sql("SELECT urls, url_download(urls) AS data FROM df"))


@pytest.mark.integration()
def test_url_download_local_missing(local_image_data_fixture):
data = {"urls": local_image_data_fixture + ["/missing/path/x.jpeg"]}
df = daft.from_pydict(data)
df = df.with_column("data", df["urls"].url.download(on_error="raise"))

with pytest.raises(FileNotFoundError):
df.collect()
def check_results(df):
with pytest.raises(FileNotFoundError):
df.collect()

check_results(df.with_column("data", df["urls"].url.download(on_error="raise")))
check_results(daft.sql("SELECT urls, url_download(urls, on_error:='raise') AS data FROM df"))


@pytest.mark.integration()
Expand Down
Loading

0 comments on commit 200bd52

Please sign in to comment.