-
Notifications
You must be signed in to change notification settings - Fork 174
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat: add download_url/upload_url to sql
- Loading branch information
Showing
8 changed files
with
213 additions
and
13 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,172 @@ | ||
use std::sync::Arc; | ||
|
||
use daft_dsl::{Expr, ExprRef, LiteralValue}; | ||
use daft_functions::uri::{download, download::DownloadFunction, upload, upload::UploadFunction}; | ||
use sqlparser::ast::FunctionArg; | ||
|
||
use super::SQLModule; | ||
use crate::{ | ||
error::{PlannerError, SQLPlannerResult}, | ||
functions::{SQLFunction, SQLFunctionArguments, SQLFunctions}, | ||
modules::config::expr_to_iocfg, | ||
planner::SQLPlanner, | ||
unsupported_sql_err, | ||
}; | ||
|
||
pub struct SQLModuleURL; | ||
|
||
impl SQLModule for SQLModuleURL { | ||
fn register(parent: &mut SQLFunctions) { | ||
parent.add_fn("url_download", UrlDownload); | ||
parent.add_fn("url_upload", UrlUpload); | ||
} | ||
} | ||
|
||
impl TryFrom<SQLFunctionArguments> for DownloadFunction { | ||
type Error = PlannerError; | ||
|
||
fn try_from(args: SQLFunctionArguments) -> Result<Self, Self::Error> { | ||
let max_connections = args.try_get_named("max_connections")?.unwrap_or(32); | ||
let raise_error_on_failure = args | ||
.get_named("on_error") | ||
.map(|arg| match arg.as_ref() { | ||
Expr::Literal(LiteralValue::Utf8(s)) => match s.as_ref() { | ||
"raise" => Ok(true), | ||
"null" => Ok(false), | ||
_ => unsupported_sql_err!("Expected on_error to be 'raise' or 'null'"), | ||
}, | ||
_ => unsupported_sql_err!("Expected on_error to be 'raise' or 'null'"), | ||
}) | ||
.transpose()? | ||
.unwrap_or(true); | ||
|
||
// TODO: choice multi_thread based on the current engine (such as ray) | ||
let multi_thread = args.try_get_named("multi_thread")?.unwrap_or(false); | ||
|
||
let config = Arc::new( | ||
args.get_named("io_config") | ||
.map(expr_to_iocfg) | ||
.transpose()? | ||
.unwrap_or_default(), | ||
); | ||
|
||
Ok(Self { | ||
max_connections, | ||
raise_error_on_failure, | ||
multi_thread, | ||
config, | ||
}) | ||
} | ||
} | ||
|
||
struct UrlDownload; | ||
|
||
impl SQLFunction for UrlDownload { | ||
fn to_expr( | ||
&self, | ||
inputs: &[sqlparser::ast::FunctionArg], | ||
planner: &crate::planner::SQLPlanner, | ||
) -> SQLPlannerResult<ExprRef> { | ||
match inputs { | ||
[input, args @ ..] => { | ||
let input = planner.plan_function_arg(input)?; | ||
let args: DownloadFunction = planner.plan_function_args( | ||
args, | ||
&["max_connections", "on_error", "multi_thread", "io_config"], | ||
0, | ||
)?; | ||
|
||
Ok(download( | ||
input, | ||
args.max_connections, | ||
args.raise_error_on_failure, | ||
args.multi_thread, | ||
Arc::try_unwrap(args.config).unwrap_or_default().into(), | ||
)) | ||
} | ||
_ => unsupported_sql_err!("Invalid arguments for url_download: '{inputs:?}'"), | ||
} | ||
} | ||
|
||
fn docstrings(&self, _: &str) -> String { | ||
"download data from the given url".to_string() | ||
} | ||
|
||
fn arg_names(&self) -> &'static [&'static str] { | ||
&[ | ||
"input", | ||
"max_connections", | ||
"on_error", | ||
"multi_thread", | ||
"io_config", | ||
] | ||
} | ||
} | ||
|
||
impl TryFrom<SQLFunctionArguments> for UploadFunction { | ||
type Error = PlannerError; | ||
|
||
fn try_from(args: SQLFunctionArguments) -> Result<Self, Self::Error> { | ||
let location = args.try_get_named("location")?.ok_or_else(|| { | ||
PlannerError::invalid_operation("location is required for url_upload") | ||
})?; | ||
let max_connections = args.try_get_named("max_connections")?.unwrap_or(32); | ||
|
||
// TODO: choice multi_thread based on the current engine (such as ray) | ||
let multi_thread = args.try_get_named("multi_thread")?.unwrap_or(false); | ||
|
||
let config = Arc::new( | ||
args.get_named("io_config") | ||
.map(expr_to_iocfg) | ||
.transpose()? | ||
.unwrap_or_default(), | ||
); | ||
|
||
Ok(Self { | ||
location, | ||
max_connections, | ||
multi_thread, | ||
config, | ||
}) | ||
} | ||
} | ||
|
||
struct UrlUpload; | ||
|
||
impl SQLFunction for UrlUpload { | ||
fn to_expr(&self, inputs: &[FunctionArg], planner: &SQLPlanner) -> SQLPlannerResult<ExprRef> { | ||
match inputs { | ||
[input, args @ ..] => { | ||
let input = planner.plan_function_arg(input)?; | ||
let args: UploadFunction = planner.plan_function_args( | ||
args, | ||
&["location", "max_connections", "multi_thread", "io_config"], | ||
0, | ||
)?; | ||
|
||
Ok(upload( | ||
input, | ||
args.location.as_str(), | ||
args.max_connections, | ||
args.multi_thread, | ||
Arc::try_unwrap(args.config).unwrap_or_default().into(), | ||
)) | ||
} | ||
_ => unsupported_sql_err!("Invalid arguments for url_upload: '{inputs:?}'"), | ||
} | ||
} | ||
|
||
fn docstrings(&self, _: &str) -> String { | ||
"upload data to the given path".to_string() | ||
} | ||
|
||
fn arg_names(&self) -> &'static [&'static str] { | ||
&[ | ||
"input", | ||
"location", | ||
"max_connections", | ||
"multi_thread", | ||
"io_config", | ||
] | ||
} | ||
} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters