feat: add download_url/upload_url to sql
frankliee committed Dec 15, 2024
1 parent 95a61d2 commit cbeb3f9
Showing 8 changed files with 213 additions and 13 deletions.
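
In short, this commit exposes Daft's existing URL expressions to the SQL frontend. Note that despite the commit title, the functions are registered as `url_download` and `url_upload`. A minimal usage sketch, modeled on the tests added in this commit; the column names and file paths are illustrative, not part of the change:

```python
import daft

# Hypothetical input: a column of URLs and a column of raw bytes.
df = daft.from_pydict({"urls": ["file:///tmp/a.jpg"], "data": [b"payload"]})

# url_download fetches each URL's contents into a binary column.
downloaded = daft.sql("SELECT urls, url_download(urls) AS data FROM df")

# url_upload writes each row's bytes under `location` and returns the written paths.
uploaded = daft.sql("SELECT url_upload(data, location := 'file:///tmp/out') AS files FROM df")
```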
10 changes: 5 additions & 5 deletions src/daft-functions/src/uri/download.rs
@@ -12,11 +12,11 @@ use snafu::prelude::*;
use crate::InvalidArgumentSnafu;

#[derive(Debug, Clone, Serialize, serde::Deserialize, PartialEq, Eq, Hash)]
-pub(super) struct DownloadFunction {
-    pub(super) max_connections: usize,
-    pub(super) raise_error_on_failure: bool,
-    pub(super) multi_thread: bool,
-    pub(super) config: Arc<IOConfig>,
+pub struct DownloadFunction {
+    pub max_connections: usize,
+    pub raise_error_on_failure: bool,
+    pub multi_thread: bool,
+    pub config: Arc<IOConfig>,
}

#[typetag::serde]
4 changes: 2 additions & 2 deletions src/daft-functions/src/uri/mod.rs
@@ -1,5 +1,5 @@
-mod download;
-mod upload;
+pub mod download;
+pub mod upload;

use common_io_config::IOConfig;
use daft_dsl::{functions::ScalarFunction, ExprRef};
10 changes: 5 additions & 5 deletions src/daft-functions/src/uri/upload.rs
@@ -9,11 +9,11 @@ use futures::{StreamExt, TryStreamExt};
use serde::Serialize;

#[derive(Debug, Clone, Serialize, serde::Deserialize, PartialEq, Eq, Hash)]
-pub(super) struct UploadFunction {
-    pub(super) location: String,
-    pub(super) max_connections: usize,
-    pub(super) multi_thread: bool,
-    pub(super) config: Arc<IOConfig>,
+pub struct UploadFunction {
+    pub location: String,
+    pub max_connections: usize,
+    pub multi_thread: bool,
+    pub config: Arc<IOConfig>,
}

#[typetag::serde]
3 changes: 2 additions & 1 deletion src/daft-sql/src/functions.rs
@@ -14,7 +14,7 @@ use crate::{
        coalesce::SQLCoalesce, hashing, SQLModule, SQLModuleAggs, SQLModuleConfig, SQLModuleFloat,
        SQLModuleImage, SQLModuleJson, SQLModuleList, SQLModuleMap, SQLModuleNumeric,
        SQLModulePartitioning, SQLModulePython, SQLModuleSketch, SQLModuleStructs,
-        SQLModuleTemporal, SQLModuleUtf8,
+        SQLModuleTemporal, SQLModuleURL, SQLModuleUtf8,
    },
    planner::SQLPlanner,
    unsupported_sql_err,
@@ -37,6 +37,7 @@ pub(crate) static SQL_FUNCTIONS: Lazy<SQLFunctions> = Lazy::new(|| {
    functions.register::<SQLModuleStructs>();
    functions.register::<SQLModuleTemporal>();
    functions.register::<SQLModuleUtf8>();
+    functions.register::<SQLModuleURL>();
    functions.register::<SQLModuleConfig>();
    functions.add_fn("coalesce", SQLCoalesce {});
    functions
2 changes: 2 additions & 0 deletions src/daft-sql/src/modules/mod.rs
@@ -15,6 +15,7 @@ pub mod python;
pub mod sketch;
pub mod structs;
pub mod temporal;
+pub mod url;
pub mod utf8;

pub use aggs::SQLModuleAggs;
@@ -30,6 +31,7 @@ pub use python::SQLModulePython;
pub use sketch::SQLModuleSketch;
pub use structs::SQLModuleStructs;
pub use temporal::SQLModuleTemporal;
+pub use url::SQLModuleURL;
pub use utf8::SQLModuleUtf8;

/// A [SQLModule] is a collection of SQL functions that can be registered with a [SQLFunctions] instance.
172 changes: 172 additions & 0 deletions src/daft-sql/src/modules/url.rs
@@ -0,0 +1,172 @@
use std::sync::Arc;

use daft_dsl::{Expr, ExprRef, LiteralValue};
use daft_functions::uri::{download, download::DownloadFunction, upload, upload::UploadFunction};
use sqlparser::ast::FunctionArg;

use super::SQLModule;
use crate::{
    error::{PlannerError, SQLPlannerResult},
    functions::{SQLFunction, SQLFunctionArguments, SQLFunctions},
    modules::config::expr_to_iocfg,
    planner::SQLPlanner,
    unsupported_sql_err,
};

pub struct SQLModuleURL;

impl SQLModule for SQLModuleURL {
    fn register(parent: &mut SQLFunctions) {
        parent.add_fn("url_download", UrlDownload);
        parent.add_fn("url_upload", UrlUpload);
    }
}

impl TryFrom<SQLFunctionArguments> for DownloadFunction {
    type Error = PlannerError;

    fn try_from(args: SQLFunctionArguments) -> Result<Self, Self::Error> {
        let max_connections = args.try_get_named("max_connections")?.unwrap_or(32);
        let raise_error_on_failure = args
            .get_named("on_error")
            .map(|arg| match arg.as_ref() {
                Expr::Literal(LiteralValue::Utf8(s)) => match s.as_ref() {
                    "raise" => Ok(true),
                    "null" => Ok(false),
                    _ => unsupported_sql_err!("Expected on_error to be 'raise' or 'null'"),
                },
                _ => unsupported_sql_err!("Expected on_error to be 'raise' or 'null'"),
            })
            .transpose()?
            .unwrap_or(true);

        // TODO: choose multi_thread based on the current engine (such as Ray)
        let multi_thread = args.try_get_named("multi_thread")?.unwrap_or(false);

        let config = Arc::new(
            args.get_named("io_config")
                .map(expr_to_iocfg)
                .transpose()?
                .unwrap_or_default(),
        );

        Ok(Self {
            max_connections,
            raise_error_on_failure,
            multi_thread,
            config,
        })
    }
}

struct UrlDownload;

impl SQLFunction for UrlDownload {
    fn to_expr(
        &self,
        inputs: &[sqlparser::ast::FunctionArg],
        planner: &crate::planner::SQLPlanner,
    ) -> SQLPlannerResult<ExprRef> {
        match inputs {
            [input, args @ ..] => {
                let input = planner.plan_function_arg(input)?;
                let args: DownloadFunction = planner.plan_function_args(
                    args,
                    &["max_connections", "on_error", "multi_thread", "io_config"],
                    0,
                )?;

                Ok(download(
                    input,
                    args.max_connections,
                    args.raise_error_on_failure,
                    args.multi_thread,
                    Arc::try_unwrap(args.config).unwrap_or_default().into(),
                ))
            }
            _ => unsupported_sql_err!("Invalid arguments for url_download: '{inputs:?}'"),
        }
    }

    fn docstrings(&self, _: &str) -> String {
        "download data from the given url".to_string()
    }

    fn arg_names(&self) -> &'static [&'static str] {
        &[
            "input",
            "max_connections",
            "on_error",
            "multi_thread",
            "io_config",
        ]
    }
}

impl TryFrom<SQLFunctionArguments> for UploadFunction {
    type Error = PlannerError;

    fn try_from(args: SQLFunctionArguments) -> Result<Self, Self::Error> {
        let location = args.try_get_named("location")?.ok_or_else(|| {
            PlannerError::invalid_operation("location is required for url_upload")
        })?;
        let max_connections = args.try_get_named("max_connections")?.unwrap_or(32);

        // TODO: choose multi_thread based on the current engine (such as Ray)
        let multi_thread = args.try_get_named("multi_thread")?.unwrap_or(false);

        let config = Arc::new(
            args.get_named("io_config")
                .map(expr_to_iocfg)
                .transpose()?
                .unwrap_or_default(),
        );

        Ok(Self {
            location,
            max_connections,
            multi_thread,
            config,
        })
    }
}

struct UrlUpload;

impl SQLFunction for UrlUpload {
    fn to_expr(&self, inputs: &[FunctionArg], planner: &SQLPlanner) -> SQLPlannerResult<ExprRef> {
        match inputs {
            [input, args @ ..] => {
                let input = planner.plan_function_arg(input)?;
                let args: UploadFunction = planner.plan_function_args(
                    args,
                    &["location", "max_connections", "multi_thread", "io_config"],
                    0,
                )?;

                Ok(upload(
                    input,
                    args.location.as_str(),
                    args.max_connections,
                    args.multi_thread,
                    Arc::try_unwrap(args.config).unwrap_or_default().into(),
                ))
            }
            _ => unsupported_sql_err!("Invalid arguments for url_upload: '{inputs:?}'"),
        }
    }

    fn docstrings(&self, _: &str) -> String {
        "upload data to the given path".to_string()
    }

    fn arg_names(&self) -> &'static [&'static str] {
        &[
            "input",
            "location",
            "max_connections",
            "multi_thread",
            "io_config",
        ]
    }
}
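
Taken together, the named SQL arguments mirror the fields of `DownloadFunction`/`UploadFunction` above. A hedged sketch of fuller invocations using the `name := value` syntax from the tests below; the specific option values and paths are illustrative:

```python
import daft

df = daft.from_pydict({"urls": ["file:///tmp/a.bin"], "data": [b"payload"]})

# on_error := 'null' maps to raise_error_on_failure = false above;
# max_connections and multi_thread default to 32 and false respectively.
daft.sql("""
    SELECT url_download(urls, max_connections := 8, on_error := 'null') AS data
    FROM df
""").collect()

# location is required for url_upload and is validated in UploadFunction::try_from.
daft.sql("SELECT url_upload(data, location := 'file:///tmp/out') AS files FROM df").collect()
```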
8 changes: 8 additions & 0 deletions tests/io/test_url_download_local.py
@@ -34,6 +34,14 @@ def test_url_download_local(local_image_data_fixture, image_data):
    assert df.to_pydict() == {**data, "data": [image_data for _ in range(len(local_image_data_fixture))]}


+@pytest.mark.integration()
+def test_sql_url_download_local(local_image_data_fixture, image_data):
+    data = {"urls": local_image_data_fixture}
+    df = daft.from_pydict(data)
+    df = daft.sql("SELECT urls, url_download(urls) AS data FROM df")
+    assert df.to_pydict() == {**data, "data": [image_data for _ in range(len(local_image_data_fixture))]}


@pytest.mark.integration()
def test_url_download_local_missing(local_image_data_fixture):
data = {"urls": local_image_data_fixture + ["/missing/path/x.jpeg"]}
17 changes: 17 additions & 0 deletions tests/io/test_url_upload_local.py
@@ -18,3 +18,20 @@ def test_upload_local(tmpdir):
        path = path[len("file://") :]
        with open(path, "rb") as f:
            assert f.read() == expected


+def test_sql_upload_local(tmpdir):
+    bytes_data = [b"a", b"b", b"c"]
+    data = {"data": bytes_data}
+    df = daft.from_pydict(data)
+    df = daft.sql(f"SELECT data, url_upload(data, location := '{tmpdir!s}') AS files FROM df")
+    df.collect()
+
+    results = df.to_pydict()
+    assert results["data"] == bytes_data
+    assert len(results["files"]) == len(bytes_data)
+    for path, expected in zip(results["files"], bytes_data):
+        assert path.startswith("file://")
+        path = path[len("file://") :]
+        with open(path, "rb") as f:
+            assert f.read() == expected
