From 655307ac9c4ae52b35103ac334765ed3548f01c3 Mon Sep 17 00:00:00 2001 From: frankzfli Date: Sun, 15 Dec 2024 01:48:55 +0800 Subject: [PATCH] feat: add download_url/upload_url to sql --- src/daft-functions/src/uri/download.rs | 10 +- src/daft-functions/src/uri/mod.rs | 4 +- src/daft-functions/src/uri/upload.rs | 10 +- src/daft-sql/src/functions.rs | 3 +- src/daft-sql/src/modules/mod.rs | 2 + src/daft-sql/src/modules/url.rs | 172 +++++++++++++++++++++++++ tests/io/test_url_download_local.py | 8 ++ tests/io/test_url_upload_local.py | 17 +++ 8 files changed, 213 insertions(+), 13 deletions(-) create mode 100644 src/daft-sql/src/modules/url.rs diff --git a/src/daft-functions/src/uri/download.rs b/src/daft-functions/src/uri/download.rs index 24d3f89d33..59dd8ec649 100644 --- a/src/daft-functions/src/uri/download.rs +++ b/src/daft-functions/src/uri/download.rs @@ -12,11 +12,11 @@ use snafu::prelude::*; use crate::InvalidArgumentSnafu; #[derive(Debug, Clone, Serialize, serde::Deserialize, PartialEq, Eq, Hash)] -pub(super) struct DownloadFunction { - pub(super) max_connections: usize, - pub(super) raise_error_on_failure: bool, - pub(super) multi_thread: bool, - pub(super) config: Arc, +pub struct DownloadFunction { + pub max_connections: usize, + pub raise_error_on_failure: bool, + pub multi_thread: bool, + pub config: Arc, } #[typetag::serde] diff --git a/src/daft-functions/src/uri/mod.rs b/src/daft-functions/src/uri/mod.rs index cb74a6f045..24b540e8ae 100644 --- a/src/daft-functions/src/uri/mod.rs +++ b/src/daft-functions/src/uri/mod.rs @@ -1,5 +1,5 @@ -mod download; -mod upload; +pub mod download; +pub mod upload; use common_io_config::IOConfig; use daft_dsl::{functions::ScalarFunction, ExprRef}; diff --git a/src/daft-functions/src/uri/upload.rs b/src/daft-functions/src/uri/upload.rs index 1ad91b888b..2c5e1b203c 100644 --- a/src/daft-functions/src/uri/upload.rs +++ b/src/daft-functions/src/uri/upload.rs @@ -9,11 +9,11 @@ use futures::{StreamExt, TryStreamExt}; use serde::Serialize; #[derive(Debug, Clone, Serialize, serde::Deserialize, PartialEq, Eq, Hash)] -pub(super) struct UploadFunction { - pub(super) location: String, - pub(super) max_connections: usize, - pub(super) multi_thread: bool, - pub(super) config: Arc, +pub struct UploadFunction { + pub location: String, + pub max_connections: usize, + pub multi_thread: bool, + pub config: Arc, } #[typetag::serde] diff --git a/src/daft-sql/src/functions.rs b/src/daft-sql/src/functions.rs index d75f090072..4e969394a6 100644 --- a/src/daft-sql/src/functions.rs +++ b/src/daft-sql/src/functions.rs @@ -14,7 +14,7 @@ use crate::{ coalesce::SQLCoalesce, hashing, SQLModule, SQLModuleAggs, SQLModuleConfig, SQLModuleFloat, SQLModuleImage, SQLModuleJson, SQLModuleList, SQLModuleMap, SQLModuleNumeric, SQLModulePartitioning, SQLModulePython, SQLModuleSketch, SQLModuleStructs, - SQLModuleTemporal, SQLModuleUtf8, + SQLModuleTemporal, SQLModuleURL, SQLModuleUtf8, }, planner::SQLPlanner, unsupported_sql_err, @@ -37,6 +37,7 @@ pub(crate) static SQL_FUNCTIONS: Lazy = Lazy::new(|| { functions.register::(); functions.register::(); functions.register::(); + functions.register::(); functions.register::(); functions.add_fn("coalesce", SQLCoalesce {}); functions diff --git a/src/daft-sql/src/modules/mod.rs b/src/daft-sql/src/modules/mod.rs index 30195dc52f..c3f654adff 100644 --- a/src/daft-sql/src/modules/mod.rs +++ b/src/daft-sql/src/modules/mod.rs @@ -15,6 +15,7 @@ pub mod python; pub mod sketch; pub mod structs; pub mod temporal; +pub mod url; pub mod utf8; pub use aggs::SQLModuleAggs; @@ -30,6 +31,7 @@ pub use python::SQLModulePython; pub use sketch::SQLModuleSketch; pub use structs::SQLModuleStructs; pub use temporal::SQLModuleTemporal; +pub use url::SQLModuleURL; pub use utf8::SQLModuleUtf8; /// A [SQLModule] is a collection of SQL functions that can be registered with a [SQLFunctions] instance. diff --git a/src/daft-sql/src/modules/url.rs b/src/daft-sql/src/modules/url.rs new file mode 100644 index 0000000000..134fdf11b2 --- /dev/null +++ b/src/daft-sql/src/modules/url.rs @@ -0,0 +1,172 @@ +use std::sync::Arc; + +use daft_dsl::{Expr, ExprRef, LiteralValue}; +use daft_functions::uri::{download, download::DownloadFunction, upload, upload::UploadFunction}; +use sqlparser::ast::FunctionArg; + +use super::SQLModule; +use crate::{ + error::{PlannerError, SQLPlannerResult}, + functions::{SQLFunction, SQLFunctionArguments, SQLFunctions}, + modules::config::expr_to_iocfg, + planner::SQLPlanner, + unsupported_sql_err, +}; + +pub struct SQLModuleURL; + +impl SQLModule for SQLModuleURL { + fn register(parent: &mut SQLFunctions) { + parent.add_fn("url_download", UrlDownload); + parent.add_fn("url_upload", UrlUpload); + } +} + +impl TryFrom for DownloadFunction { + type Error = PlannerError; + + fn try_from(args: SQLFunctionArguments) -> Result { + let max_connections = args.try_get_named("max_connections")?.unwrap_or(32); + let raise_error_on_failure = args + .get_named("on_error") + .map(|arg| match arg.as_ref() { + Expr::Literal(LiteralValue::Utf8(s)) => match s.as_ref() { + "raise" => Ok(true), + "null" => Ok(false), + _ => unsupported_sql_err!("Expected on_error to be 'raise' or 'null'"), + }, + _ => unsupported_sql_err!("Expected on_error to be 'raise' or 'null'"), + }) + .transpose()? + .unwrap_or(true); + + // TODO: choice multi_thread based on the current engine (such as ray) + let multi_thread = args.try_get_named("multi_thread")?.unwrap_or(false); + + let config = Arc::new( + args.get_named("io_config") + .map(expr_to_iocfg) + .transpose()? + .unwrap_or_default(), + ); + + Ok(Self { + max_connections, + raise_error_on_failure, + multi_thread, + config, + }) + } +} + +struct UrlDownload; + +impl SQLFunction for UrlDownload { + fn to_expr( + &self, + inputs: &[sqlparser::ast::FunctionArg], + planner: &crate::planner::SQLPlanner, + ) -> SQLPlannerResult { + match inputs { + [input, args @ ..] => { + let input = planner.plan_function_arg(input)?; + let args: DownloadFunction = planner.plan_function_args( + args, + &["max_connections", "on_error", "multi_thread", "io_config"], + 0, + )?; + + Ok(download( + input, + args.max_connections, + args.raise_error_on_failure, + args.multi_thread, + Arc::try_unwrap(args.config).unwrap_or_default().into(), + )) + } + _ => unsupported_sql_err!("Invalid arguments for url_download: '{inputs:?}'"), + } + } + + fn docstrings(&self, _: &str) -> String { + "download data from the given url".to_string() + } + + fn arg_names(&self) -> &'static [&'static str] { + &[ + "input", + "max_connections", + "on_error", + "multi_thread", + "io_config", + ] + } +} + +impl TryFrom for UploadFunction { + type Error = PlannerError; + + fn try_from(args: SQLFunctionArguments) -> Result { + let location = args.try_get_named("location")?.ok_or_else(|| { + PlannerError::invalid_operation("location is required for url_upload") + })?; + let max_connections = args.try_get_named("max_connections")?.unwrap_or(32); + + // TODO: choice multi_thread based on the current engine (such as ray) + let multi_thread = args.try_get_named("multi_thread")?.unwrap_or(false); + + let config = Arc::new( + args.get_named("io_config") + .map(expr_to_iocfg) + .transpose()? + .unwrap_or_default(), + ); + + Ok(Self { + location, + max_connections, + multi_thread, + config, + }) + } +} + +struct UrlUpload; + +impl SQLFunction for UrlUpload { + fn to_expr(&self, inputs: &[FunctionArg], planner: &SQLPlanner) -> SQLPlannerResult { + match inputs { + [input, args @ ..] => { + let input = planner.plan_function_arg(input)?; + let args: UploadFunction = planner.plan_function_args( + args, + &["location", "max_connections", "multi_thread", "io_config"], + 0, + )?; + + Ok(upload( + input, + args.location.as_str(), + args.max_connections, + args.multi_thread, + Arc::try_unwrap(args.config).unwrap_or_default().into(), + )) + } + _ => unsupported_sql_err!("Invalid arguments for url_upload: '{inputs:?}'"), + } + } + + fn docstrings(&self, _: &str) -> String { + "upload data to the given path".to_string() + } + + fn arg_names(&self) -> &'static [&'static str] { + &[ + "input", + "location", + "max_connections", + "multi_thread", + "io_config", + ] + } +} diff --git a/tests/io/test_url_download_local.py b/tests/io/test_url_download_local.py index 57dcbf6d9e..04edb43049 100644 --- a/tests/io/test_url_download_local.py +++ b/tests/io/test_url_download_local.py @@ -34,6 +34,14 @@ def test_url_download_local(local_image_data_fixture, image_data): assert df.to_pydict() == {**data, "data": [image_data for _ in range(len(local_image_data_fixture))]} +@pytest.mark.integration() +def test_sql_url_download_local(local_image_data_fixture, image_data): + data = {"urls": local_image_data_fixture} + df = daft.from_pydict(data) + df = daft.sql("SELECT urls, url_download(urls) AS data FROM df") + assert df.to_pydict() == {**data, "data": [image_data for _ in range(len(local_image_data_fixture))]} + + @pytest.mark.integration() def test_url_download_local_missing(local_image_data_fixture): data = {"urls": local_image_data_fixture + ["/missing/path/x.jpeg"]} diff --git a/tests/io/test_url_upload_local.py b/tests/io/test_url_upload_local.py index 2f3fcbd5e3..3a7a5daa0c 100644 --- a/tests/io/test_url_upload_local.py +++ b/tests/io/test_url_upload_local.py @@ -18,3 +18,20 @@ def test_upload_local(tmpdir): path = path[len("file://") :] with open(path, "rb") as f: assert f.read() == expected + + +def test_sql_upload_local(tmpdir): + bytes_data = [b"a", b"b", b"c"] + data = {"data": bytes_data} + df = daft.from_pydict(data) + df = daft.sql(f"SELECT data, url_upload(data, location :='{tmpdir!s}') AS files FROM df") + df.collect() + + results = df.to_pydict() + assert results["data"] == bytes_data + assert len(results["files"]) == len(bytes_data) + for path, expected in zip(results["files"], bytes_data): + assert path.startswith("file://") + path = path[len("file://") :] + with open(path, "rb") as f: + assert f.read() == expected