feat!: change catalog provider and schema provider methods to be asynchronous #13582

Open · wants to merge 4 commits into main
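Since this is a breaking change to the catalog traits, the hunks below imply the following post-change shape for SchemaProvider (a sketch assembled from the visible signatures in datafusion-cli/src/catalog.rs; exact supertraits, bounds, and doc comments in the real trait may differ):

// Sketch only, inferred from this diff; not the authoritative definition.
#[async_trait]
pub trait SchemaProvider: Send + Sync {
    fn as_any(&self) -> &dyn Any;

    // was: fn table_names(&self) -> Vec<String>
    async fn table_names(&self) -> BoxStream<'static, Result<String>>;

    // already async before this change
    async fn table(&self, name: &str) -> Result<Option<Arc<dyn TableProvider>>>;

    // was: fn register_table(...) -> Result<Option<Arc<dyn TableProvider>>>
    async fn register_table(
        &self,
        name: String,
        table: Arc<dyn TableProvider>,
    ) -> Result<Option<Arc<dyn TableProvider>>>;

    // was: fn deregister_table(&self, name: &str) -> Result<Option<Arc<dyn TableProvider>>>
    async fn deregister_table(&self, name: &str) -> Result<Option<Arc<dyn TableProvider>>>;

    // was: fn table_exist(&self, name: &str) -> bool
    async fn table_exist(&self, name: &str) -> bool;
}

CatalogProviderList and CatalogProvider change analogously: every accessor becomes async, lookups return Result<Option<...>>, and the *_names() methods return a fallible stream instead of Vec<String>.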
4 changes: 2 additions & 2 deletions benchmarks/src/bin/dfbench.rs
@@ -53,12 +53,12 @@ pub async fn main() -> Result<()> {
     env_logger::init();
 
     match Options::from_args() {
-        Options::Tpch(opt) => opt.run().await,
+        Options::Tpch(opt) => Box::pin(opt.run()).await,
         Options::TpchConvert(opt) => opt.run().await,
         Options::Clickbench(opt) => opt.run().await,
        Options::ParquetFilter(opt) => opt.run().await,
         Options::Sort(opt) => opt.run().await,
         Options::SortTpch(opt) => opt.run().await,
-        Options::Imdb(opt) => opt.run().await,
+        Options::Imdb(opt) => Box::pin(opt.run()).await,
     }
 }
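The Box::pin here (and in the imdb and tpch binaries below) is presumably needed because the now-async catalog calls inside run() enlarge the generated future; boxing moves the state machine to the heap without changing the result. A minimal sketch of the pattern, with a hypothetical run() standing in for the benchmark's opt.run():

use datafusion::error::Result;

// Hypothetical stand-in for the benchmark's `opt.run()`.
async fn run() -> Result<()> {
    Ok(())
}

async fn caller() -> Result<()> {
    // Behaves exactly like `run().await`; only the future's storage moves
    // from the enclosing future's layout to a heap allocation.
    Box::pin(run()).await
}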
4 changes: 2 additions & 2 deletions benchmarks/src/bin/external_aggr.rs
@@ -245,9 +245,9 @@ impl ExternalAggrConfig {
                 table,
                 start.elapsed().as_millis()
             );
-            ctx.register_table(table, Arc::new(memtable))?;
+            ctx.register_table(table, Arc::new(memtable)).await?;
         } else {
-            ctx.register_table(table, table_provider)?;
+            ctx.register_table(table, table_provider).await?;
         }
     }
     Ok(())
2 changes: 1 addition & 1 deletion benchmarks/src/bin/h2o.rs
@@ -94,7 +94,7 @@ async fn group_by(opt: &GroupBy) -> Result<()> {
         let partition_size = num_cpus::get();
         let memtable =
             MemTable::load(Arc::new(csv), Some(partition_size), &ctx.state()).await?;
-        ctx.register_table("x", Arc::new(memtable))?;
+        ctx.register_table("x", Arc::new(memtable)).await?;
     } else {
         ctx.register_csv("x", path, CsvReadOptions::default().schema(&schema))
             .await?;
2 changes: 1 addition & 1 deletion benchmarks/src/bin/imdb.rs
@@ -53,7 +53,7 @@ pub async fn main() -> Result<()> {
     env_logger::init();
     match ImdbOpt::from_args() {
         ImdbOpt::Benchmark(BenchmarkSubCommandOpt::DataFusionBenchmark(opt)) => {
-            opt.run().await
+            Box::pin(opt.run()).await
         }
         ImdbOpt::Convert(opt) => opt.run().await,
     }
2 changes: 1 addition & 1 deletion benchmarks/src/bin/tpch.rs
@@ -58,7 +58,7 @@ async fn main() -> Result<()> {
     env_logger::init();
     match TpchOpt::from_args() {
         TpchOpt::Benchmark(BenchmarkSubCommandOpt::DataFusionBenchmark(opt)) => {
-            opt.run().await
+            Box::pin(opt.run()).await
         }
         TpchOpt::Convert(opt) => opt.run().await,
     }
4 changes: 2 additions & 2 deletions benchmarks/src/imdb/run.rs
@@ -358,9 +358,9 @@ impl RunOpt {
                 table,
                 start.elapsed().as_millis()
             );
-            ctx.register_table(*table, Arc::new(memtable))?;
+            ctx.register_table(*table, Arc::new(memtable)).await?;
         } else {
-            ctx.register_table(*table, table_provider)?;
+            ctx.register_table(*table, table_provider).await?;
         }
     }
     Ok(())
4 changes: 2 additions & 2 deletions benchmarks/src/sort_tpch.rs
@@ -236,9 +236,9 @@ impl RunOpt {
                 table,
                 start.elapsed().as_millis()
             );
-            ctx.register_table(table, Arc::new(memtable))?;
+            ctx.register_table(table, Arc::new(memtable)).await?;
         } else {
-            ctx.register_table(table, table_provider)?;
+            ctx.register_table(table, table_provider).await?;
         }
     }
     Ok(())
4 changes: 2 additions & 2 deletions benchmarks/src/tpch/run.rs
@@ -182,9 +182,9 @@ impl RunOpt {
                 table,
                 start.elapsed().as_millis()
             );
-            ctx.register_table(*table, Arc::new(memtable))?;
+            ctx.register_table(*table, Arc::new(memtable)).await?;
         } else {
-            ctx.register_table(*table, table_provider)?;
+            ctx.register_table(*table, table_provider).await?;
         }
     }
     Ok(())
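Aside from the two Box::pin call sites, every benchmark change above is the same mechanical migration: SessionContext::register_table now returns a future, so each call gains an .await. A before/after sketch of the call site (assuming a MemTable built elsewhere):

use std::sync::Arc;
use datafusion::datasource::MemTable;
use datafusion::error::Result;
use datafusion::prelude::SessionContext;

// Sketch of the call-site migration only; error handling unchanged.
async fn register(ctx: &SessionContext, provider: MemTable) -> Result<()> {
    // before this PR: ctx.register_table("t", Arc::new(provider))?;
    ctx.register_table("t", Arc::new(provider)).await?;
    Ok(())
}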
1 change: 1 addition & 0 deletions datafusion-cli/Cargo.lock

Some generated files are not rendered by default.

91 changes: 59 additions & 32 deletions datafusion-cli/src/catalog.rs
@@ -31,6 +31,7 @@ use datafusion::execution::session_state::SessionStateBuilder;
 
 use async_trait::async_trait;
 use dirs::home_dir;
+use futures::stream::BoxStream;
 use parking_lot::RwLock;
 
 /// Wraps another catalog, automatically register require object stores for the file locations
@@ -49,28 +50,29 @@ impl DynamicObjectStoreCatalog {
     }
 }
 
+#[async_trait]
 impl CatalogProviderList for DynamicObjectStoreCatalog {
     fn as_any(&self) -> &dyn Any {
         self
     }
 
-    fn register_catalog(
+    async fn register_catalog(
         &self,
         name: String,
         catalog: Arc<dyn CatalogProvider>,
-    ) -> Option<Arc<dyn CatalogProvider>> {
-        self.inner.register_catalog(name, catalog)
+    ) -> Result<Option<Arc<dyn CatalogProvider>>> {
+        self.inner.register_catalog(name, catalog).await
     }
 
-    fn catalog_names(&self) -> Vec<String> {
-        self.inner.catalog_names()
+    async fn catalog_names(&self) -> BoxStream<'static, Result<String>> {
+        self.inner.catalog_names().await
     }
 
-    fn catalog(&self, name: &str) -> Option<Arc<dyn CatalogProvider>> {
+    async fn catalog(&self, name: &str) -> Result<Option<Arc<dyn CatalogProvider>>> {
         let state = self.state.clone();
-        self.inner.catalog(name).map(|catalog| {
+        Ok(self.inner.catalog(name).await?.map(|catalog| {
             Arc::new(DynamicObjectStoreCatalogProvider::new(catalog, state)) as _
-        })
+        }))
     }
 }
 
@@ -90,28 +92,29 @@ impl DynamicObjectStoreCatalogProvider {
     }
 }
 
+#[async_trait]
 impl CatalogProvider for DynamicObjectStoreCatalogProvider {
     fn as_any(&self) -> &dyn Any {
         self
     }
 
-    fn schema_names(&self) -> Vec<String> {
-        self.inner.schema_names()
+    async fn schema_names(&self) -> BoxStream<'static, Result<String>> {
+        self.inner.schema_names().await
     }
 
-    fn schema(&self, name: &str) -> Option<Arc<dyn SchemaProvider>> {
+    async fn schema(&self, name: &str) -> Result<Option<Arc<dyn SchemaProvider>>> {
         let state = self.state.clone();
-        self.inner.schema(name).map(|schema| {
+        Ok(self.inner.schema(name).await?.map(|schema| {
             Arc::new(DynamicObjectStoreSchemaProvider::new(schema, state)) as _
-        })
+        }))
     }
 
-    fn register_schema(
+    async fn register_schema(
        &self,
         name: &str,
         schema: Arc<dyn SchemaProvider>,
     ) -> Result<Option<Arc<dyn SchemaProvider>>> {
-        self.inner.register_schema(name, schema)
+        self.inner.register_schema(name, schema).await
     }
 }
 
@@ -138,16 +141,16 @@ impl SchemaProvider for DynamicObjectStoreSchemaProvider {
         self
     }
 
-    fn table_names(&self) -> Vec<String> {
-        self.inner.table_names()
+    async fn table_names(&self) -> BoxStream<'static, Result<String>> {
+        self.inner.table_names().await
     }
 
-    fn register_table(
+    async fn register_table(
         &self,
         name: String,
         table: Arc<dyn TableProvider>,
     ) -> Result<Option<Arc<dyn TableProvider>>> {
-        self.inner.register_table(name, table)
+        self.inner.register_table(name, table).await
     }
 
     async fn table(&self, name: &str) -> Result<Option<Arc<dyn TableProvider>>> {
@@ -166,7 +169,7 @@ impl SchemaProvider for DynamicObjectStoreSchemaProvider {
             .ok_or_else(|| plan_datafusion_err!("locking error"))?
             .read()
             .clone();
-        let mut builder = SessionStateBuilder::from(state.clone());
+        let mut builder = SessionStateBuilder::new_from_existing(state.clone()).await;
         let optimized_name = substitute_tilde(name.to_owned());
         let table_url = ListingTableUrl::parse(optimized_name.as_str())?;
         let scheme = table_url.scheme();
@@ -194,7 +197,7 @@ impl SchemaProvider for DynamicObjectStoreSchemaProvider {
             }
             _ => {}
         };
-        state = builder.build();
+        state = builder.build().await;
         let store = get_object_store(
             &state,
             table_url.scheme(),
@@ -208,12 +211,15 @@ impl SchemaProvider for DynamicObjectStoreSchemaProvider {
         self.inner.table(name).await
     }
 
-    fn deregister_table(&self, name: &str) -> Result<Option<Arc<dyn TableProvider>>> {
-        self.inner.deregister_table(name)
+    async fn deregister_table(
+        &self,
+        name: &str,
+    ) -> Result<Option<Arc<dyn TableProvider>>> {
+        self.inner.deregister_table(name).await
     }
 
-    fn table_exist(&self, name: &str) -> bool {
-        self.inner.table_exist(name)
+    async fn table_exist(&self, name: &str) -> bool {
+        self.inner.table_exist(name).await
     }
 }
 
@@ -234,8 +240,9 @@ mod tests {

use datafusion::catalog::SchemaProvider;
use datafusion::prelude::SessionContext;
use futures::TryStreamExt;

fn setup_context() -> (SessionContext, Arc<dyn SchemaProvider>) {
async fn setup_context() -> (SessionContext, Arc<dyn SchemaProvider>) {
let ctx = SessionContext::new();
ctx.register_catalog_list(Arc::new(DynamicObjectStoreCatalog::new(
ctx.state().catalog_list().clone(),
Expand All @@ -247,10 +254,30 @@ mod tests {
             ctx.state_weak_ref(),
         ) as &dyn CatalogProviderList;
         let catalog = provider
-            .catalog(provider.catalog_names().first().unwrap())
+            .catalog(
+                &provider
+                    .catalog_names()
+                    .await
+                    .try_next()
+                    .await
+                    .unwrap()
+                    .unwrap(),
+            )
+            .await
+            .unwrap()
             .unwrap();
         let schema = catalog
-            .schema(catalog.schema_names().first().unwrap())
+            .schema(
+                &catalog
+                    .schema_names()
+                    .await
+                    .try_next()
+                    .await
+                    .unwrap()
+                    .unwrap(),
+            )
+            .await
+            .unwrap()
             .unwrap();
         (ctx, schema)
     }
@@ -262,7 +289,7 @@ mod tests {
         let domain = "example.com";
         let location = format!("http://{domain}/file.parquet");
 
-        let (ctx, schema) = setup_context();
+        let (ctx, schema) = setup_context().await;
 
         // That's a non registered table so expecting None here
         let table = schema.table(&location).await?;
@@ -287,7 +314,7 @@ mod tests {
         let bucket = "examples3bucket";
         let location = format!("s3://{bucket}/file.parquet");
 
-        let (ctx, schema) = setup_context();
+        let (ctx, schema) = setup_context().await;
 
         let table = schema.table(&location).await?;
         assert!(table.is_none());
@@ -309,7 +336,7 @@ mod tests {
         let bucket = "examplegsbucket";
         let location = format!("gs://{bucket}/file.parquet");
 
-        let (ctx, schema) = setup_context();
+        let (ctx, schema) = setup_context().await;
 
         let table = schema.table(&location).await?;
         assert!(table.is_none());
@@ -329,7 +356,7 @@ mod tests {
     #[tokio::test]
     async fn query_invalid_location_test() {
         let location = "ts://file.parquet";
-        let (_ctx, schema) = setup_context();
+        let (_ctx, schema) = setup_context().await;
 
         assert!(schema.table(location).await.is_err());
     }
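Because the *_names() methods now hand back a BoxStream<'static, Result<String>> instead of a Vec<String>, callers pull names through futures::TryStreamExt, as the updated test above does with try_next. A minimal sketch of draining the whole stream (assuming the trait shape shown in this diff):

use datafusion::catalog::SchemaProvider;
use datafusion::error::Result;
use futures::TryStreamExt;

// Hypothetical helper: collect every table name, short-circuiting on the
// first error the stream yields.
async fn all_table_names(schema: &dyn SchemaProvider) -> Result<Vec<String>> {
    schema.table_names().await.try_collect().await
}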
3 changes: 2 additions & 1 deletion datafusion-cli/src/main.rs
@@ -176,7 +176,8 @@ async fn main_inner() -> Result<()> {
     // enable dynamic file query
     let ctx =
         SessionContext::new_with_config_rt(session_config.clone(), Arc::new(runtime_env))
-            .enable_url_table();
+            .enable_url_table()
+            .await;
     ctx.refresh_catalogs().await?;
     // install dynamic catalog provider that can register required object stores
     ctx.register_catalog_list(Arc::new(DynamicObjectStoreCatalog::new(
3 changes: 2 additions & 1 deletion datafusion-examples/examples/advanced_parquet_index.rs
@@ -168,7 +168,8 @@ async fn main() -> Result<()> {
     // SessionContext for running queries that has the table provider
     // registered as "index_table"
     let ctx = SessionContext::new();
-    ctx.register_table("index_table", Arc::clone(&provider) as _)?;
+    ctx.register_table("index_table", Arc::clone(&provider) as _)
+        .await?;
 
     // register object store provider for urls like `file://` work
     let url = Url::try_from("file://").unwrap();
6 changes: 3 additions & 3 deletions datafusion-examples/examples/advanced_udaf.rs
@@ -198,7 +198,7 @@ impl Accumulator for GeometricMean {
 }
 
 // create local session context with an in-memory table
-fn create_context() -> Result<SessionContext> {
+async fn create_context() -> Result<SessionContext> {
     use datafusion::datasource::MemTable;
     // define a schema.
     let schema = Arc::new(Schema::new(vec![
@@ -227,7 +227,7 @@ fn create_context() -> Result<SessionContext> {
 
     // declare a table in memory. In spark API, this corresponds to createDataFrame(...).
     let provider = MemTable::try_new(schema, vec![vec![batch1], vec![batch2]])?;
-    ctx.register_table("t", Arc::new(provider))?;
+    ctx.register_table("t", Arc::new(provider)).await?;
     Ok(ctx)
 }
 
@@ -401,7 +401,7 @@ impl GroupsAccumulator for GeometricMeanGroupsAccumulator {
 
 #[tokio::main]
 async fn main() -> Result<()> {
-    let ctx = create_context()?;
+    let ctx = create_context().await?;
 
     // create the AggregateUDF
     let geometric_mean = AggregateUDF::from(GeoMeanUdaf::new());
6 changes: 3 additions & 3 deletions datafusion-examples/examples/advanced_udf.rs
@@ -195,7 +195,7 @@ impl ScalarUDFImpl for PowUdf {
 /// and invoke it via the DataFrame API and SQL
 #[tokio::main]
 async fn main() -> Result<()> {
-    let ctx = create_context()?;
+    let ctx = create_context().await?;
 
     // create the UDF
     let pow = ScalarUDF::from(PowUdf::new());
@@ -234,7 +234,7 @@ async fn main() -> Result<()> {
 /// | 5.1 | 4.0 |
 /// +-----+-----+
 /// ```
-fn create_context() -> Result<SessionContext> {
+async fn create_context() -> Result<SessionContext> {
     // define data.
     let a: ArrayRef = Arc::new(Float32Array::from(vec![2.1, 3.1, 4.1, 5.1]));
     let b: ArrayRef = Arc::new(Float64Array::from(vec![1.0, 2.0, 3.0, 4.0]));
@@ -244,6 +244,6 @@ fn create_context() -> Result<SessionContext> {
     let ctx = SessionContext::new();
 
     // declare a table in memory. In Spark API, this corresponds to createDataFrame(...).
-    ctx.register_batch("t", batch)?;
+    ctx.register_batch("t", batch).await?;
     Ok(ctx)
 }