feat!: change catalog provider and schema provider methods to be asynchronous #13582

Open · wants to merge 4 commits into main · Changes from 1 commit
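
This is a breaking change (note the feat! prefix): registration and lookup methods on SessionContext and on the CatalogProviderList, CatalogProvider, and SchemaProvider traits now return futures, so every call site in the diff below gains an .await. A minimal caller-side sketch of the new pattern (the table setup here is illustrative; only the .await on register_table comes from this PR):

    use std::sync::Arc;

    use datafusion::arrow::array::Int32Array;
    use datafusion::arrow::datatypes::{DataType, Field, Schema};
    use datafusion::arrow::record_batch::RecordBatch;
    use datafusion::datasource::MemTable;
    use datafusion::error::Result;
    use datafusion::prelude::SessionContext;

    #[tokio::main]
    async fn main() -> Result<()> {
        // Build a small single-column in-memory table (illustrative setup).
        let schema = Arc::new(Schema::new(vec![Field::new("a", DataType::Int32, false)]));
        let batch = RecordBatch::try_new(
            schema.clone(),
            vec![Arc::new(Int32Array::from(vec![1, 2, 3]))],
        )?;
        let table = MemTable::try_new(schema, vec![vec![batch]])?;

        let ctx = SessionContext::new();
        // Before this PR: ctx.register_table("t", Arc::new(table))?;
        // After this PR, registration is awaited:
        ctx.register_table("t", Arc::new(table)).await?;

        ctx.sql("SELECT a FROM t").await?.show().await?;
        Ok(())
    }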
4 changes: 2 additions & 2 deletions benchmarks/src/bin/external_aggr.rs
@@ -245,9 +245,9 @@ impl ExternalAggrConfig {
                     table,
                     start.elapsed().as_millis()
                 );
-                ctx.register_table(table, Arc::new(memtable))?;
+                ctx.register_table(table, Arc::new(memtable)).await?;
             } else {
-                ctx.register_table(table, table_provider)?;
+                ctx.register_table(table, table_provider).await?;
             }
         }
         Ok(())
2 changes: 1 addition & 1 deletion benchmarks/src/bin/h2o.rs
@@ -94,7 +94,7 @@ async fn group_by(opt: &GroupBy) -> Result<()> {
         let partition_size = num_cpus::get();
         let memtable =
             MemTable::load(Arc::new(csv), Some(partition_size), &ctx.state()).await?;
-        ctx.register_table("x", Arc::new(memtable))?;
+        ctx.register_table("x", Arc::new(memtable)).await?;
     } else {
         ctx.register_csv("x", path, CsvReadOptions::default().schema(&schema))
             .await?;
4 changes: 2 additions & 2 deletions benchmarks/src/imdb/run.rs
@@ -358,9 +358,9 @@ impl RunOpt {
                     table,
                     start.elapsed().as_millis()
                 );
-                ctx.register_table(*table, Arc::new(memtable))?;
+                ctx.register_table(*table, Arc::new(memtable)).await?;
             } else {
-                ctx.register_table(*table, table_provider)?;
+                ctx.register_table(*table, table_provider).await?;
             }
         }
         Ok(())
4 changes: 2 additions & 2 deletions benchmarks/src/sort_tpch.rs
@@ -236,9 +236,9 @@ impl RunOpt {
                     table,
                     start.elapsed().as_millis()
                 );
-                ctx.register_table(table, Arc::new(memtable))?;
+                ctx.register_table(table, Arc::new(memtable)).await?;
             } else {
-                ctx.register_table(table, table_provider)?;
+                ctx.register_table(table, table_provider).await?;
             }
         }
         Ok(())
4 changes: 2 additions & 2 deletions benchmarks/src/tpch/run.rs
@@ -182,9 +182,9 @@ impl RunOpt {
                     table,
                     start.elapsed().as_millis()
                 );
-                ctx.register_table(*table, Arc::new(memtable))?;
+                ctx.register_table(*table, Arc::new(memtable)).await?;
             } else {
-                ctx.register_table(*table, table_provider)?;
+                ctx.register_table(*table, table_provider).await?;
             }
         }
         Ok(())
3 changes: 2 additions & 1 deletion datafusion-examples/examples/advanced_parquet_index.rs
@@ -168,7 +168,8 @@ async fn main() -> Result<()> {
     // SessionContext for running queries that has the table provider
     // registered as "index_table"
     let ctx = SessionContext::new();
-    ctx.register_table("index_table", Arc::clone(&provider) as _)?;
+    ctx.register_table("index_table", Arc::clone(&provider) as _)
+        .await?;

     // register object store provider for urls like `file://` work
     let url = Url::try_from("file://").unwrap();
6 changes: 3 additions & 3 deletions datafusion-examples/examples/advanced_udaf.rs
@@ -198,7 +198,7 @@ impl Accumulator for GeometricMean {
 }

 // create local session context with an in-memory table
-fn create_context() -> Result<SessionContext> {
+async fn create_context() -> Result<SessionContext> {
     use datafusion::datasource::MemTable;
     // define a schema.
     let schema = Arc::new(Schema::new(vec![
@@ -227,7 +227,7 @@ fn create_context() -> Result<SessionContext> {

     // declare a table in memory. In spark API, this corresponds to createDataFrame(...).
     let provider = MemTable::try_new(schema, vec![vec![batch1], vec![batch2]])?;
-    ctx.register_table("t", Arc::new(provider))?;
+    ctx.register_table("t", Arc::new(provider)).await?;
     Ok(ctx)
 }

@@ -401,7 +401,7 @@ impl GroupsAccumulator for GeometricMeanGroupsAccumulator {

 #[tokio::main]
 async fn main() -> Result<()> {
-    let ctx = create_context()?;
+    let ctx = create_context().await?;

     // create the AggregateUDF
     let geometric_mean = AggregateUDF::from(GeoMeanUdaf::new());
6 changes: 3 additions & 3 deletions datafusion-examples/examples/advanced_udf.rs
@@ -195,7 +195,7 @@ impl ScalarUDFImpl for PowUdf {
 /// and invoke it via the DataFrame API and SQL
 #[tokio::main]
 async fn main() -> Result<()> {
-    let ctx = create_context()?;
+    let ctx = create_context().await?;

     // create the UDF
     let pow = ScalarUDF::from(PowUdf::new());
@@ -234,7 +234,7 @@ async fn main() -> Result<()> {
 /// | 5.1 | 4.0 |
 /// +-----+-----+
 /// ```
-fn create_context() -> Result<SessionContext> {
+async fn create_context() -> Result<SessionContext> {
     // define data.
     let a: ArrayRef = Arc::new(Float32Array::from(vec![2.1, 3.1, 4.1, 5.1]));
     let b: ArrayRef = Arc::new(Float64Array::from(vec![1.0, 2.0, 3.0, 4.0]));
@@ -244,6 +244,6 @@ fn create_context() -> Result<SessionContext> {
     let ctx = SessionContext::new();

     // declare a table in memory. In Spark API, this corresponds to createDataFrame(...).
-    ctx.register_batch("t", batch)?;
+    ctx.register_batch("t", batch).await?;
     Ok(ctx)
 }
2 changes: 1 addition & 1 deletion datafusion-examples/examples/analyzer_rule.rs
@@ -46,7 +46,7 @@ pub async fn main() -> Result<()> {
     let ctx = SessionContext::new();
     ctx.add_analyzer_rule(Arc::clone(&rule) as _);

-    ctx.register_batch("employee", employee_batch())?;
+    ctx.register_batch("employee", employee_batch()).await?;

     // Now, planning any SQL statement also invokes the AnalyzerRule
     let plan = ctx
37 changes: 24 additions & 13 deletions datafusion-examples/examples/catalog.rs
@@ -74,11 +74,15 @@ async fn main() -> Result<()> {
         .await?;

     // register schemas into catalog
-    catalog.register_schema("schema_a", schema_a.clone())?;
-    catalog.register_schema("schema_b", schema_b.clone())?;
+    catalog
+        .register_schema("schema_a", schema_a.clone())
+        .await?;
+    catalog
+        .register_schema("schema_b", schema_b.clone())
+        .await?;

     // register our catalog in the context
-    ctx.register_catalog("dircat", Arc::new(catalog));
+    ctx.register_catalog("dircat", Arc::new(catalog)).await;
     {
         // catalog was passed down into our custom catalog list since we override the ctx's default
         let catalogs = cataloglist.catalogs.read().unwrap();
@@ -184,7 +188,7 @@ impl SchemaProvider for DirSchema {
         self
     }

-    fn table_names(&self) -> Vec<String> {
+    async fn table_names(&self) -> Vec<String> {
         let tables = self.tables.read().unwrap();
         tables.keys().cloned().collect::<Vec<_>>()
     }
@@ -194,11 +198,11 @@ impl SchemaProvider for DirSchema {
         Ok(tables.get(name).cloned())
     }

-    fn table_exist(&self, name: &str) -> bool {
+    async fn table_exist(&self, name: &str) -> bool {
         let tables = self.tables.read().unwrap();
         tables.contains_key(name)
     }
-    fn register_table(
+    async fn register_table(
         &self,
         name: String,
         table: Arc<dyn TableProvider>,
@@ -212,7 +216,10 @@ impl SchemaProvider for DirSchema {
     /// If supported by the implementation, removes an existing table from this schema and returns it.
     /// If no table of that name exists, returns Ok(None).
     #[allow(unused_variables)]
-    fn deregister_table(&self, name: &str) -> Result<Option<Arc<dyn TableProvider>>> {
+    async fn deregister_table(
+        &self,
+        name: &str,
+    ) -> Result<Option<Arc<dyn TableProvider>>> {
         let mut tables = self.tables.write().unwrap();
         log::info!("dropping table {name}");
         Ok(tables.remove(name))
@@ -230,11 +237,13 @@ impl DirCatalog {
         }
     }
 }
+
+#[async_trait]
 impl CatalogProvider for DirCatalog {
     fn as_any(&self) -> &dyn Any {
         self
     }
-    fn register_schema(
+    async fn register_schema(
         &self,
         name: &str,
         schema: Arc<dyn SchemaProvider>,
@@ -244,12 +253,12 @@ impl CatalogProvider for DirCatalog {
         Ok(Some(schema))
     }

-    fn schema_names(&self) -> Vec<String> {
+    async fn schema_names(&self) -> Vec<String> {
         let schemas = self.schemas.read().unwrap();
         schemas.keys().cloned().collect()
     }

-    fn schema(&self, name: &str) -> Option<Arc<dyn SchemaProvider>> {
+    async fn schema(&self, name: &str) -> Option<Arc<dyn SchemaProvider>> {
         let schemas = self.schemas.read().unwrap();
         let maybe_schema = schemas.get(name);
         if let Some(schema) = maybe_schema {
@@ -272,11 +281,13 @@ impl CustomCatalogProviderList {
         }
     }
 }
+
+#[async_trait]
 impl CatalogProviderList for CustomCatalogProviderList {
     fn as_any(&self) -> &dyn Any {
         self
     }
-    fn register_catalog(
+    async fn register_catalog(
         &self,
         name: String,
         catalog: Arc<dyn CatalogProvider>,
@@ -287,13 +298,13 @@ impl CatalogProviderList for CustomCatalogProviderList {
     }

     /// Retrieves the list of available catalog names
-    fn catalog_names(&self) -> Vec<String> {
+    async fn catalog_names(&self) -> Vec<String> {
         let cats = self.catalogs.read().unwrap();
         cats.keys().cloned().collect()
     }

     /// Retrieves a specific catalog by name, provided it exists.
-    fn catalog(&self, name: &str) -> Option<Arc<dyn CatalogProvider>> {
+    async fn catalog(&self, name: &str) -> Option<Arc<dyn CatalogProvider>> {
         let cats = self.catalogs.read().unwrap();
         cats.get(name).cloned()
     }
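
Taken together, the catalog.rs changes show what a custom provider looks like under the new traits: the method bodies stay synchronous, but the signatures become async behind #[async_trait]. Below is a condensed sketch of an in-memory SchemaProvider written against the signatures visible in this diff; the module paths and the parts of the trait this PR does not touch are assumptions, not confirmed by the patch:

    use std::any::Any;
    use std::collections::HashMap;
    use std::sync::{Arc, RwLock};

    use async_trait::async_trait;
    use datafusion::catalog::SchemaProvider;
    use datafusion::datasource::TableProvider;
    use datafusion::error::Result;

    #[derive(Default)]
    struct MemorySchema {
        // Tables guarded by an in-process lock, as in the catalog.rs example.
        tables: RwLock<HashMap<String, Arc<dyn TableProvider>>>,
    }

    #[async_trait]
    impl SchemaProvider for MemorySchema {
        // as_any stays synchronous; only the catalog operations change.
        fn as_any(&self) -> &dyn Any {
            self
        }

        async fn table_names(&self) -> Vec<String> {
            self.tables.read().unwrap().keys().cloned().collect()
        }

        async fn table(&self, name: &str) -> Result<Option<Arc<dyn TableProvider>>> {
            Ok(self.tables.read().unwrap().get(name).cloned())
        }

        async fn table_exist(&self, name: &str) -> bool {
            self.tables.read().unwrap().contains_key(name)
        }

        async fn register_table(
            &self,
            name: String,
            table: Arc<dyn TableProvider>,
        ) -> Result<Option<Arc<dyn TableProvider>>> {
            Ok(self.tables.write().unwrap().insert(name, table))
        }

        async fn deregister_table(
            &self,
            name: &str,
        ) -> Result<Option<Arc<dyn TableProvider>>> {
            Ok(self.tables.write().unwrap().remove(name))
        }
    }

The async keywords look redundant for an in-memory provider like this one, but the point of the signature change is presumably the opposite case: a provider backed by a remote catalog can now perform network I/O inside these methods instead of blocking the caller's runtime.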
7 changes: 5 additions & 2 deletions datafusion-examples/examples/custom_file_format.rs
@@ -179,7 +179,10 @@ impl GetExt for TSVFileFactory {
 #[tokio::main]
 async fn main() -> Result<()> {
     // Create a new context with the default configuration
-    let mut state = SessionStateBuilder::new().with_default_features().build();
+    let mut state = SessionStateBuilder::new()
+        .with_default_features()
+        .build()
+        .await;

     // Register the custom file format
     let file_format = Arc::new(TSVFileFactory::new());
@@ -189,7 +192,7 @@ async fn main() -> Result<()> {
     let ctx = SessionContext::new_with_state(state);

     let mem_table = create_mem_table();
-    ctx.register_table("mem_table", mem_table).unwrap();
+    ctx.register_table("mem_table", mem_table).await.unwrap();

     let temp_dir = tempdir().unwrap();
     let table_save_path = temp_dir.path().join("mem_table.tsv");
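
Note that the breaking change reaches beyond registration methods: SessionStateBuilder::build() itself becomes awaitable in this commit, so code that assembles a SessionState up front, as this example does before calling SessionContext::new_with_state, must also run in an async context.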
4 changes: 2 additions & 2 deletions datafusion-examples/examples/dataframe.rs
@@ -109,7 +109,7 @@ async fn read_csv(ctx: &SessionContext) -> Result<()> {

     // You can also create DataFrames from the result of sql queries
     // and using the `enable_url_table` refer to local files directly
-    let dyn_ctx = ctx.clone().enable_url_table();
+    let dyn_ctx = ctx.clone().enable_url_table().await;
     let csv_df = dyn_ctx
         .sql(&format!("SELECT rating, unixtime FROM '{}'", file_path))
         .await?;
@@ -127,7 +127,7 @@ async fn read_memory(ctx: &SessionContext) -> Result<()> {
     let batch = RecordBatch::try_from_iter(vec![("a", a), ("b", b)])?;

     // declare a table in memory. In Apache Spark API, this corresponds to createDataFrame(...).
-    ctx.register_batch("t", batch)?;
+    ctx.register_batch("t", batch).await?;
     let df = ctx.table("t").await?;

     // construct an expression corresponding to "SELECT a, b FROM t WHERE b = 10" in SQL
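
enable_url_table() changes the same way here and in the next example: it must now be awaited, presumably because installing the URL-table support goes through the now-asynchronous registration path.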
@@ -64,7 +64,7 @@ async fn main() -> Result<()> {
     df.show().await?;

     // dynamic query by the file path
-    let ctx = ctx.enable_url_table();
+    let ctx = ctx.enable_url_table().await;
     let df = ctx
         .sql(format!(r#"SELECT * FROM '{}' LIMIT 10"#, &path).as_str())
         .await?;
@@ -55,7 +55,8 @@ async fn main() -> Result<()> {
     let ctx = SessionContext::new();

     // Display the data to show the full cycle works.
-    ctx.register_table("external_table", Arc::new(foreign_table_provider))?;
+    ctx.register_table("external_table", Arc::new(foreign_table_provider))
+        .await?;
     let df = ctx.table("external_table").await?;
     df.show().await?;

2 changes: 1 addition & 1 deletion datafusion-examples/examples/file_stream_provider.rs
@@ -160,7 +160,7 @@ mod non_windows {
         let order = vec![vec![datafusion_expr::col("a1").sort(true, false)]];

         let provider = fifo_table(schema.clone(), fifo_path, order.clone());
-        ctx.register_table("fifo", provider)?;
+        ctx.register_table("fifo", provider).await?;

         let df = ctx.sql("SELECT * FROM fifo").await.unwrap();
         let mut stream = df.execute_stream().await.unwrap();
10 changes: 5 additions & 5 deletions datafusion-examples/examples/flight/flight_sql_server.rs
@@ -170,11 +170,11 @@ impl FlightSqlServiceImpl {
         let mut schemas = vec![];
         let mut names = vec![];
         let mut types = vec![];
-        for catalog in ctx.catalog_names() {
-            let catalog_provider = ctx.catalog(&catalog).unwrap();
-            for schema in catalog_provider.schema_names() {
-                let schema_provider = catalog_provider.schema(&schema).unwrap();
-                for table in schema_provider.table_names() {
+        for catalog in ctx.catalog_names().await {
+            let catalog_provider = ctx.catalog(&catalog).await.unwrap();
+            for schema in catalog_provider.schema_names().await {
+                let schema_provider = catalog_provider.schema(&schema).await.unwrap();
+                for table in schema_provider.table_names().await {
                     let table_provider =
                         schema_provider.table(&table).await.unwrap().unwrap();
                     catalogs.push(catalog.clone());
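
The FlightSQL server shows the read side of the change: catalog_names(), catalog(), schema_names(), schema(), and table_names() are all awaited now, so code that walks the whole catalog tree, like this metadata handler, has to be async end to end.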
2 changes: 1 addition & 1 deletion datafusion-examples/examples/make_date.rs
@@ -49,7 +49,7 @@ async fn main() -> Result<()> {
     let ctx = SessionContext::new();

     // declare a table in memory. In spark API, this corresponds to createDataFrame(...).
-    ctx.register_batch("t", batch)?;
+    ctx.register_batch("t", batch).await?;
     let df = ctx.table("t").await?;

     // use make_date function to convert col 'y', 'm' & 'd' to a date
2 changes: 1 addition & 1 deletion datafusion-examples/examples/memtable.rs
@@ -34,7 +34,7 @@ async fn main() -> Result<()> {
     let ctx = SessionContext::new();

     // Register the in-memory table containing the data
-    ctx.register_table("users", Arc::new(mem_table))?;
+    ctx.register_table("users", Arc::new(mem_table)).await?;

     let dataframe = ctx.sql("SELECT * FROM users;").await?;

2 changes: 1 addition & 1 deletion datafusion-examples/examples/optimizer_rule.rs
@@ -48,7 +48,7 @@ pub async fn main() -> Result<()> {
     ctx.add_optimizer_rule(Arc::new(MyOptimizerRule {}));

     // Now, let's plan and run queries with the new rule
-    ctx.register_batch("person", person_batch())?;
+    ctx.register_batch("person", person_batch()).await?;
     let sql = "SELECT * FROM person WHERE age = 22";
     let plan = ctx.sql(sql).await?.into_optimized_plan()?;

3 changes: 2 additions & 1 deletion datafusion-examples/examples/parquet_index.rs
@@ -121,7 +121,8 @@ async fn main() -> Result<()> {
     // Create a SessionContext for running queries that has the table provider
     // registered as "index_table"
     let ctx = SessionContext::new();
-    ctx.register_table("index_table", Arc::clone(&provider) as _)?;
+    ctx.register_table("index_table", Arc::clone(&provider) as _)
+        .await?;

     // register object store provider for urls like `file://` work
     let url = Url::try_from("file://").unwrap();
6 changes: 3 additions & 3 deletions datafusion-examples/examples/simple_udaf.rs
@@ -26,7 +26,7 @@ use datafusion_common::cast::as_float64_array;
 use std::sync::Arc;

 // create local session context with an in-memory table
-fn create_context() -> Result<SessionContext> {
+async fn create_context() -> Result<SessionContext> {
     use datafusion::arrow::datatypes::{Field, Schema};
     use datafusion::datasource::MemTable;
     // define a schema.
@@ -47,7 +47,7 @@ fn create_context() -> Result<SessionContext> {

     // declare a table in memory. In spark API, this corresponds to createDataFrame(...).
     let provider = MemTable::try_new(schema, vec![vec![batch1], vec![batch2]])?;
-    ctx.register_table("t", Arc::new(provider))?;
+    ctx.register_table("t", Arc::new(provider)).await?;
     Ok(ctx)
 }


Expand Down Expand Up @@ -137,7 +137,7 @@ impl Accumulator for GeometricMean {

#[tokio::main]
async fn main() -> Result<()> {
let ctx = create_context()?;
let ctx = create_context().await?;

// here is where we define the UDAF. We also declare its signature:
let geometric_mean = create_udaf(
Expand Down