>;
+    fn metadata(&self) -> PartitionMetadata;
+}
+
+impl<P, PS> PartitionSet<P> for Arc<PS>
+where
+    P: Partition + Clone,
+    PS: PartitionSet<P> + Clone,
+{
+    fn get_merged_partitions(&self) -> DaftResult<PartitionRef> {
+        PS::get_merged_partitions(self)
+    }
+
+    fn get_preview_partitions(&self, num_rows: usize) -> DaftResult<Vec<P>> {
+        PS::get_preview_partitions(self, num_rows)
+    }
+
+    fn num_partitions(&self) -> usize {
+        PS::num_partitions(self)
+    }
+
+    fn len(&self) -> usize {
+        PS::len(self)
+    }
+
+    fn size_bytes(&self) -> DaftResult<usize> {
+        PS::size_bytes(self)
+    }
+
+    fn has_partition(&self, idx: &PartitionId) -> bool {
+        PS::has_partition(self, idx)
+    }
+
+    fn delete_partition(&self, idx: &PartitionId) -> DaftResult<()> {
+        PS::delete_partition(self, idx)
+    }
+
+    fn set_partition(&self, idx: PartitionId, part: &P) -> DaftResult<()> {
+        PS::set_partition(self, idx, part)
+    }
+
+    fn get_partition(&self, idx: &PartitionId) -> DaftResult<P> {
+        PS::get_partition(self, idx)
+    }
+
+    fn to_partition_stream(&self) -> BoxStream<'static, DaftResult<P>> {
+        PS::to_partition_stream(self)
+    }
+
+    fn metadata(&self) -> PartitionMetadata {
+        PS::metadata(self)
+    }
+}
+
+pub type PartitionSetRef<T> = Arc<dyn PartitionSet<T>>;
+
+pub trait PartitionSetCache<P: Partition, PS: PartitionSet<P>>:
+    std::fmt::Debug + Send + Sync
+{
+    fn get_partition_set(&self, key: &str) -> Option<PartitionSetRef<P>>;
+    fn get_all_partition_sets(&self) -> Vec<PartitionSetRef<P>>;
+    fn put_partition_set(&self, key: &str, partition_set: &PS);
+    fn rm_partition_set(&self, key: &str);
+    fn clear(&self);
+}
+
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub enum PartitionCacheEntry {
+    #[serde(
+        serialize_with = "serialize_py_object",
+        deserialize_with = "deserialize_py_object"
+    )]
+    #[cfg(feature = "python")]
+    /// In Python, the partition cache is a weak-value dictionary, so it keeps the entry alive only as long as this reference exists.
+    Python(PyObject),
+
+    Rust {
+        key: String,
+        #[serde(skip)]
+        /// We never actually read this value; we hold it only to keep the partition set alive.
+        ///
+        /// It is wrapped in an `Option` purely to satisfy serde's `Deserialize`: (de)serialization is skipped, but serde still requires an `Option` for a skipped field.
+        value: Option<Arc<dyn std::any::Any + Send + Sync>>,
+    },
+}
+
+impl PartitionCacheEntry {
+    pub fn new_rust(key: String, value: Arc<dyn std::any::Any + Send + Sync>) -> Self {
+        Self::Rust {
+            key,
+            value: Some(value),
+        }
+    }
+}
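
(Reviewer sketch, not part of the patch: the snippet below pieces together how the new cache types above are meant to be used, based only on the calls visible in this diff — MicroPartitionSet::from_tables, put_partition_set, PartitionCacheEntry::new_rust, and LogicalPlanBuilder::in_memory_scan. The helper name, the schema parameter, and the DaftResult return type are illustrative assumptions.)

    use std::sync::Arc;

    // Hypothetical helper: materialize some tables, register them in the per-session
    // cache, and hand the planner an in-memory scan that keeps the set alive.
    fn register_pset_sketch(
        cache: &InMemoryPartitionSetCache,
        plan_id: usize,
        tables: Vec<Table>,
        schema: Arc<Schema>, // assumed schema handle; the diff passes `daft_schema`
    ) -> DaftResult<LogicalPlanBuilder> {
        let pset = MicroPartitionSet::from_tables(plan_id, tables)?;
        let PartitionMetadata { size_bytes, num_rows } = pset.metadata();
        let num_partitions = pset.num_partitions();

        // Key the set by a fresh UUID; the cache plus the returned cache entry keep it alive.
        let key: Arc<str> = uuid::Uuid::new_v4().to_string().into();
        let pset = Arc::new(pset);
        cache.put_partition_set(&key, &pset);

        LogicalPlanBuilder::in_memory_scan(
            &key,
            PartitionCacheEntry::new_rust(key.to_string(), pset),
            schema,
            num_partitions,
            size_bytes,
            num_rows,
        )
    }
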
diff --git a/src/daft-connect/src/lib.rs b/src/daft-connect/src/lib.rs
index 369bfe8e47..1371421396 100644
--- a/src/daft-connect/src/lib.rs
+++ b/src/daft-connect/src/lib.rs
@@ -30,6 +30,7 @@ use crate::session::Session;
mod config;
mod err;
mod op;
+
mod session;
mod translation;
pub mod util;
diff --git a/src/daft-connect/src/op/execute/root.rs b/src/daft-connect/src/op/execute/root.rs
index fb18318708..4a9feb4ed7 100644
--- a/src/daft-connect/src/op/execute/root.rs
+++ b/src/daft-connect/src/op/execute/root.rs
@@ -10,7 +10,6 @@ use crate::{
op::execute::{ExecuteStream, PlanIds},
session::Session,
translation,
- translation::Plan,
};
impl Session {
@@ -30,19 +29,24 @@ impl Session {
let finished = context.finished();
let (tx, rx) = tokio::sync::mpsc::channel::<eyre::Result<ExecutePlanResponse>>(1);
+
+ let pset = self.psets.clone();
+
tokio::spawn(async move {
let execution_fut = async {
- let Plan { builder, psets } = translation::to_logical_plan(command).await?;
+ let translator = translation::SparkAnalyzer::new(&pset);
+ let lp = translator.to_logical_plan(command).await?;
// todo: convert optimize to async (looks like A LOT of work)... it touches a lot of API
// I tried and spent about an hour and gave up ~ Andrew Gazelka 🪦 2024-12-09
- let optimized_plan = tokio::task::spawn_blocking(move || builder.optimize())
+ let optimized_plan = tokio::task::spawn_blocking(move || lp.optimize())
.await
.unwrap()?;
let cfg = Arc::new(DaftExecutionConfig::default());
let native_executor = NativeExecutor::from_logical_plan_builder(&optimized_plan)?;
- let mut result_stream = native_executor.run(psets, cfg, None)?.into_stream();
+
+ let mut result_stream = native_executor.run(&pset, cfg, None)?.into_stream();
while let Some(result) = result_stream.next().await {
let result = result?;
diff --git a/src/daft-connect/src/op/execute/write.rs b/src/daft-connect/src/op/execute/write.rs
index 44696f8164..5db783f5e1 100644
--- a/src/daft-connect/src/op/execute/write.rs
+++ b/src/daft-connect/src/op/execute/write.rs
@@ -32,6 +32,7 @@ impl Session {
};
let finished = context.finished();
+ let pset = self.psets.clone();
let result = async move {
let WriteOperation {
@@ -109,19 +110,19 @@ impl Session {
}
};
- let mut plan = translation::to_logical_plan(input).await?;
+ let translator = translation::SparkAnalyzer::new(&pset);
- plan.builder = plan
- .builder
+ let plan = translator.to_logical_plan(input).await?;
+
+ let plan = plan
.table_write(&path, FileFormat::Parquet, None, None, None)
.wrap_err("Failed to create table write plan")?;
- let optimized_plan = plan.builder.optimize()?;
+ let optimized_plan = plan.optimize()?;
let cfg = DaftExecutionConfig::default();
let native_executor = NativeExecutor::from_logical_plan_builder(&optimized_plan)?;
- let mut result_stream = native_executor
- .run(plan.psets, cfg.into(), None)?
- .into_stream();
+
+ let mut result_stream = native_executor.run(&pset, cfg.into(), None)?.into_stream();
// this is so we make sure the operation is actually done
// before we return
diff --git a/src/daft-connect/src/session.rs b/src/daft-connect/src/session.rs
index 24f7fabe80..30f827ba9e 100644
--- a/src/daft-connect/src/session.rs
+++ b/src/daft-connect/src/session.rs
@@ -1,5 +1,6 @@
use std::collections::BTreeMap;
+use daft_micropartition::partitioning::InMemoryPartitionSetCache;
use uuid::Uuid;
pub struct Session {
@@ -10,6 +11,9 @@ pub struct Session {
id: String,
server_side_session_id: String,
+ /// MicroPartitionSets associated with this session,
+ /// populated as the user runs queries.
+ pub(crate) psets: InMemoryPartitionSetCache,
}
impl Session {
@@ -24,10 +28,12 @@ impl Session {
pub fn new(id: String) -> Self {
let server_side_session_id = Uuid::new_v4();
let server_side_session_id = server_side_session_id.to_string();
+
Self {
config_values: Default::default(),
id,
server_side_session_id,
+ psets: InMemoryPartitionSetCache::empty(),
}
}
diff --git a/src/daft-connect/src/translation.rs b/src/daft-connect/src/translation.rs
index 8b61b93f98..5d9bf89881 100644
--- a/src/daft-connect/src/translation.rs
+++ b/src/daft-connect/src/translation.rs
@@ -9,5 +9,5 @@ mod schema;
pub use datatype::{deser_spark_datatype, to_daft_datatype, to_spark_datatype};
pub use expr::to_daft_expr;
pub use literal::to_daft_literal;
-pub use logical_plan::{to_logical_plan, Plan};
+pub use logical_plan::SparkAnalyzer;
pub use schema::relation_to_schema;
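
(Reviewer sketch, not part of the patch: with the free to_logical_plan/Plan pair replaced by SparkAnalyzer, the execute paths in root.rs and write.rs now follow the shape below. It is assembled from calls visible in this diff; the wrapping function and variable names are illustrative assumptions.)

    // Hypothetical end-to-end flow for executing one Spark Connect relation.
    async fn run_relation_sketch(
        session: &Session,
        relation: spark_connect::Relation,
    ) -> eyre::Result<()> {
        // The session owns the partition-set cache; clone a handle for this query.
        let pset = session.psets.clone();

        // Translate the Spark Connect relation into a Daft logical plan.
        let translator = translation::SparkAnalyzer::new(&pset);
        let plan = translator.to_logical_plan(relation).await?;

        // Optimize and execute against the same cache the translator populated.
        let optimized = plan.optimize()?; // the real code wraps this in spawn_blocking
        let cfg = std::sync::Arc::new(DaftExecutionConfig::default());
        let executor = NativeExecutor::from_logical_plan_builder(&optimized)?;
        let mut results = executor.run(&pset, cfg, None)?.into_stream();

        while let Some(batch) = results.next().await {
            let _batch = batch?;
            // ...convert each batch into a response for the client...
        }
        Ok(())
    }
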
diff --git a/src/daft-connect/src/translation/datatype/codec.rs b/src/daft-connect/src/translation/datatype/codec.rs
index 50f2d94a02..4f554765ba 100644
--- a/src/daft-connect/src/translation/datatype/codec.rs
+++ b/src/daft-connect/src/translation/datatype/codec.rs
@@ -2,7 +2,6 @@ use color_eyre::Help;
use eyre::{bail, ensure, eyre};
use serde_json::Value;
use spark_connect::data_type::Kind;
-use tracing::warn;
#[derive(Debug)]
enum TypeTag {
@@ -211,12 +210,10 @@ fn deser_struct_field(
bail!("expected object");
};
- let Some(metadata) = object.remove("metadata") else {
+ let Some(_metadata) = object.remove("metadata") else {
bail!("missing metadata");
};
- warn!("ignoring metadata: {metadata:?}");
-
let Some(name) = object.remove("name") else {
bail!("missing name");
};
diff --git a/src/daft-connect/src/translation/logical_plan.rs b/src/daft-connect/src/translation/logical_plan.rs
index b6097d17ad..15eb495502 100644
--- a/src/daft-connect/src/translation/logical_plan.rs
+++ b/src/daft-connect/src/translation/logical_plan.rs
@@ -1,14 +1,9 @@
use daft_logical_plan::LogicalPlanBuilder;
-use daft_micropartition::partitioning::InMemoryPartitionSet;
+use daft_micropartition::partitioning::InMemoryPartitionSetCache;
use eyre::{bail, Context};
use spark_connect::{relation::RelType, Limit, Relation};
use tracing::warn;
-use crate::translation::logical_plan::{
- aggregate::aggregate, drop::drop, filter::filter, local_relation::local_relation,
- project::project, range::range, read::read, to_df::to_df, with_columns::with_columns,
-};
-
mod aggregate;
mod drop;
mod filter;
@@ -19,82 +14,86 @@ mod read;
mod to_df;
mod with_columns;
-pub struct Plan {
- pub builder: LogicalPlanBuilder,
- pub psets: InMemoryPartitionSet,
+pub struct SparkAnalyzer<'a> {
+ pub psets: &'a InMemoryPartitionSetCache,
}
-impl Plan {
- pub fn new(builder: LogicalPlanBuilder) -> Self {
- Self {
- builder,
- psets: InMemoryPartitionSet::default(),
- }
+impl SparkAnalyzer<'_> {
+ pub fn new(pset: &InMemoryPartitionSetCache) -> SparkAnalyzer {
+ SparkAnalyzer { psets: pset }
}
-}
-impl From<LogicalPlanBuilder> for Plan {
- fn from(builder: LogicalPlanBuilder) -> Self {
- Self {
- builder,
- psets: InMemoryPartitionSet::default(),
- }
- }
-}
+ pub async fn to_logical_plan(&self, relation: Relation) -> eyre::Result<LogicalPlanBuilder> {
+ let Some(common) = relation.common else {
+ bail!("Common metadata is required");
+ };
-pub async fn to_logical_plan(relation: Relation) -> eyre::Result<Plan> {
- if let Some(common) = relation.common {
if common.origin.is_some() {
warn!("Ignoring common metadata for relation: {common:?}; not yet implemented");
}
- };
- let Some(rel_type) = relation.rel_type else {
- bail!("Relation type is required");
- };
+ let Some(rel_type) = relation.rel_type else {
+ bail!("Relation type is required");
+ };
- match rel_type {
- RelType::Limit(l) => limit(*l)
- .await
- .wrap_err("Failed to apply limit to logical plan"),
- RelType::Range(r) => range(r).wrap_err("Failed to apply range to logical plan"),
- RelType::Project(p) => project(*p)
- .await
- .wrap_err("Failed to apply project to logical plan"),
- RelType::Filter(f) => filter(*f)
- .await
- .wrap_err("Failed to apply filter to logical plan"),
- RelType::Aggregate(a) => aggregate(*a)
- .await
- .wrap_err("Failed to apply aggregate to logical plan"),
- RelType::WithColumns(w) => with_columns(*w)
- .await
- .wrap_err("Failed to apply with_columns to logical plan"),
- RelType::ToDf(t) => to_df(*t)
- .await
- .wrap_err("Failed to apply to_df to logical plan"),
- RelType::LocalRelation(l) => {
- local_relation(l).wrap_err("Failed to apply local_relation to logical plan")
+ match rel_type {
+ RelType::Limit(l) => self
+ .limit(*l)
+ .await
+ .wrap_err("Failed to apply limit to logical plan"),
+ RelType::Range(r) => self
+ .range(r)
+ .wrap_err("Failed to apply range to logical plan"),
+ RelType::Project(p) => self
+ .project(*p)
+ .await
+ .wrap_err("Failed to apply project to logical plan"),
+ RelType::Aggregate(a) => self
+ .aggregate(*a)
+ .await
+ .wrap_err("Failed to apply aggregate to logical plan"),
+ RelType::WithColumns(w) => self
+ .with_columns(*w)
+ .await
+ .wrap_err("Failed to apply with_columns to logical plan"),
+ RelType::ToDf(t) => self
+ .to_df(*t)
+ .await
+ .wrap_err("Failed to apply to_df to logical plan"),
+ RelType::LocalRelation(l) => {
+ let Some(plan_id) = common.plan_id else {
+ bail!("Plan ID is required for LocalRelation");
+ };
+ self.local_relation(plan_id, l)
+ .wrap_err("Failed to apply local_relation to logical plan")
+ }
+ RelType::Read(r) => read::read(r)
+ .await
+ .wrap_err("Failed to apply read to logical plan"),
+ RelType::Drop(d) => self
+ .drop(*d)
+ .await
+ .wrap_err("Failed to apply drop to logical plan"),
+ RelType::Filter(f) => self
+ .filter(*f)
+ .await
+ .wrap_err("Failed to apply filter to logical plan"),
+ plan => bail!("Unsupported relation type: {plan:?}"),
}
- RelType::Read(r) => read(r)
- .await
- .wrap_err("Failed to apply read to logical plan"),
- RelType::Drop(d) => drop(*d)
- .await
- .wrap_err("Failed to apply drop to logical plan"),
- plan => bail!("Unsupported relation type: {plan:?}"),
}
}
-async fn limit(limit: Limit) -> eyre::Result<Plan> {
- let Limit { input, limit } = limit;
+impl SparkAnalyzer<'_> {
+ async fn limit(&self, limit: Limit) -> eyre::Result<LogicalPlanBuilder> {
+ let Limit { input, limit } = limit;
- let Some(input) = input else {
- bail!("input must be set");
- };
+ let Some(input) = input else {
+ bail!("input must be set");
+ };
- let mut plan = Box::pin(to_logical_plan(*input)).await?;
- plan.builder = plan.builder.limit(i64::from(limit), false)?; // todo: eager or no
+ let plan = Box::pin(self.to_logical_plan(*input)).await?;
- Ok(plan)
+ plan.limit(i64::from(limit), false)
+ .wrap_err("Failed to apply limit to logical plan")
+ }
}
diff --git a/src/daft-connect/src/translation/logical_plan/aggregate.rs b/src/daft-connect/src/translation/logical_plan/aggregate.rs
index 3687f191f8..2a46b0cbba 100644
--- a/src/daft-connect/src/translation/logical_plan/aggregate.rs
+++ b/src/daft-connect/src/translation/logical_plan/aggregate.rs
@@ -1,52 +1,59 @@
+use daft_logical_plan::LogicalPlanBuilder;
use eyre::{bail, WrapErr};
use spark_connect::aggregate::GroupType;
-use crate::translation::{logical_plan::Plan, to_daft_expr, to_logical_plan};
+use super::SparkAnalyzer;
+use crate::translation::to_daft_expr;
-pub async fn aggregate(aggregate: spark_connect::Aggregate) -> eyre::Result<Plan> {
- let spark_connect::Aggregate {
- input,
- group_type,
- grouping_expressions,
- aggregate_expressions,
- pivot,
- grouping_sets,
- } = aggregate;
+impl SparkAnalyzer<'_> {
+ pub async fn aggregate(
+ &self,
+ aggregate: spark_connect::Aggregate,
+ ) -> eyre::Result<LogicalPlanBuilder> {
+ let spark_connect::Aggregate {
+ input,
+ group_type,
+ grouping_expressions,
+ aggregate_expressions,
+ pivot,
+ grouping_sets,
+ } = aggregate;
- let Some(input) = input else {
- bail!("input is required");
- };
+ let Some(input) = input else {
+ bail!("input is required");
+ };
- let mut plan = Box::pin(to_logical_plan(*input)).await?;
+ let mut plan = Box::pin(self.to_logical_plan(*input)).await?;
- let group_type = GroupType::try_from(group_type)
- .wrap_err_with(|| format!("Invalid group type: {group_type:?}"))?;
+ let group_type = GroupType::try_from(group_type)
+ .wrap_err_with(|| format!("Invalid group type: {group_type:?}"))?;
- assert_groupby(group_type)?;
+ assert_groupby(group_type)?;
- if let Some(pivot) = pivot {
- bail!("Pivot not yet supported; got {pivot:?}");
- }
+ if let Some(pivot) = pivot {
+ bail!("Pivot not yet supported; got {pivot:?}");
+ }
- if !grouping_sets.is_empty() {
- bail!("Grouping sets not yet supported; got {grouping_sets:?}");
- }
+ if !grouping_sets.is_empty() {
+ bail!("Grouping sets not yet supported; got {grouping_sets:?}");
+ }
- let grouping_expressions: Vec<_> = grouping_expressions
- .iter()
- .map(to_daft_expr)
- .try_collect()?;
+ let grouping_expressions: Vec<_> = grouping_expressions
+ .iter()
+ .map(to_daft_expr)
+ .try_collect()?;
- let aggregate_expressions: Vec<_> = aggregate_expressions
- .iter()
- .map(to_daft_expr)
- .try_collect()?;
+ let aggregate_expressions: Vec<_> = aggregate_expressions
+ .iter()
+ .map(to_daft_expr)
+ .try_collect()?;
- plan.builder = plan.builder
+ plan = plan
.aggregate(aggregate_expressions.clone(), grouping_expressions.clone())
.wrap_err_with(|| format!("Failed to apply aggregate to logical plan aggregate_expressions={aggregate_expressions:?} grouping_expressions={grouping_expressions:?}"))?;
- Ok(plan)
+ Ok(plan)
+ }
}
fn assert_groupby(plan: GroupType) -> eyre::Result<()> {
diff --git a/src/daft-connect/src/translation/logical_plan/drop.rs b/src/daft-connect/src/translation/logical_plan/drop.rs
index 35613add68..b5cac5a41b 100644
--- a/src/daft-connect/src/translation/logical_plan/drop.rs
+++ b/src/daft-connect/src/translation/logical_plan/drop.rs
@@ -1,39 +1,40 @@
+use daft_logical_plan::LogicalPlanBuilder;
use eyre::bail;
-use crate::translation::{to_logical_plan, Plan};
+use super::SparkAnalyzer;
-pub async fn drop(drop: spark_connect::Drop) -> eyre::Result<Plan> {
- let spark_connect::Drop {
- input,
- columns,
- column_names,
- } = drop;
+impl SparkAnalyzer<'_> {
+ pub async fn drop(&self, drop: spark_connect::Drop) -> eyre::Result<LogicalPlanBuilder> {
+ let spark_connect::Drop {
+ input,
+ columns,
+ column_names,
+ } = drop;
- let Some(input) = input else {
- bail!("input is required");
- };
+ let Some(input) = input else {
+ bail!("input is required");
+ };
- if !columns.is_empty() {
- bail!("columns is not supported; use column_names instead");
- }
-
- let mut plan = Box::pin(to_logical_plan(*input)).await?;
+ if !columns.is_empty() {
+ bail!("columns is not supported; use column_names instead");
+ }
- // Get all column names from the schema
- let all_columns = plan.builder.schema().names();
+ let plan = Box::pin(self.to_logical_plan(*input)).await?;
- // Create a set of columns to drop for efficient lookup
- let columns_to_drop: std::collections::HashSet<_> = column_names.iter().collect();
+ // Get all column names from the schema
+ let all_columns = plan.schema().names();
- // Create expressions for all columns except the ones being dropped
- let to_select = all_columns
- .iter()
- .filter(|col_name| !columns_to_drop.contains(*col_name))
- .map(|col_name| daft_dsl::col(col_name.clone()))
- .collect();
+ // Create a set of columns to drop for efficient lookup
+ let columns_to_drop: std::collections::HashSet<_> = column_names.iter().collect();
- // Use select to keep only the columns we want
- plan.builder = plan.builder.select(to_select)?;
+ // Create expressions for all columns except the ones being dropped
+ let to_select = all_columns
+ .iter()
+ .filter(|col_name| !columns_to_drop.contains(*col_name))
+ .map(|col_name| daft_dsl::col(col_name.clone()))
+ .collect();
- Ok(plan)
+ // Use select to keep only the columns we want
+ Ok(plan.select(to_select)?)
+ }
}
diff --git a/src/daft-connect/src/translation/logical_plan/filter.rs b/src/daft-connect/src/translation/logical_plan/filter.rs
index 6879464abc..43ad4c7a52 100644
--- a/src/daft-connect/src/translation/logical_plan/filter.rs
+++ b/src/daft-connect/src/translation/logical_plan/filter.rs
@@ -1,22 +1,24 @@
+use daft_logical_plan::LogicalPlanBuilder;
use eyre::bail;
-use crate::translation::{to_daft_expr, to_logical_plan, Plan};
+use super::SparkAnalyzer;
+use crate::translation::to_daft_expr;
-pub async fn filter(filter: spark_connect::Filter) -> eyre::Result<Plan> {
- let spark_connect::Filter { input, condition } = filter;
+impl SparkAnalyzer<'_> {
+ pub async fn filter(&self, filter: spark_connect::Filter) -> eyre::Result<LogicalPlanBuilder> {
+ let spark_connect::Filter { input, condition } = filter;
- let Some(input) = input else {
- bail!("input is required");
- };
+ let Some(input) = input else {
+ bail!("input is required");
+ };
- let Some(condition) = condition else {
- bail!("condition is required");
- };
+ let Some(condition) = condition else {
+ bail!("condition is required");
+ };
- let condition = to_daft_expr(&condition)?;
+ let condition = to_daft_expr(&condition)?;
- let mut plan = Box::pin(to_logical_plan(*input)).await?;
- plan.builder = plan.builder.filter(condition)?;
-
- Ok(plan)
+ let plan = Box::pin(self.to_logical_plan(*input)).await?;
+ Ok(plan.filter(condition)?)
+ }
}
diff --git a/src/daft-connect/src/translation/logical_plan/local_relation.rs b/src/daft-connect/src/translation/logical_plan/local_relation.rs
index 7244ed09b9..574e35a2fd 100644
--- a/src/daft-connect/src/translation/logical_plan/local_relation.rs
+++ b/src/daft-connect/src/translation/logical_plan/local_relation.rs
@@ -1,32 +1,28 @@
-use std::{collections::HashMap, io::Cursor, sync::Arc};
+use std::{io::Cursor, sync::Arc};
use arrow2::io::ipc::{
read::{StreamMetadata, StreamReader, StreamState, Version},
IpcField, IpcSchema,
};
use daft_core::series::Series;
-use daft_logical_plan::{
- logical_plan::Source, InMemoryInfo, LogicalPlan, LogicalPlanBuilder, PyLogicalPlanBuilder,
- SourceInfo,
+use daft_logical_plan::LogicalPlanBuilder;
+use daft_micropartition::partitioning::{
+ MicroPartitionSet, PartitionCacheEntry, PartitionMetadata, PartitionSet, PartitionSetCache,
};
-use daft_micropartition::partitioning::InMemoryPartitionSet;
use daft_schema::dtype::DaftDataType;
use daft_table::Table;
use eyre::{bail, ensure, WrapErr};
use itertools::Itertools;
-use crate::translation::{deser_spark_datatype, logical_plan::Plan, to_daft_datatype};
+use super::SparkAnalyzer;
+use crate::translation::{deser_spark_datatype, to_daft_datatype};
-pub fn local_relation(plan: spark_connect::LocalRelation) -> eyre::Result<Plan> {
- #[cfg(not(feature = "python"))]
- {
- bail!("LocalRelation plan is only supported in Python mode");
- }
-
- #[cfg(feature = "python")]
- {
- use daft_micropartition::{python::PyMicroPartition, MicroPartition};
- use pyo3::{types::PyAnyMethods, Python};
+impl SparkAnalyzer<'_> {
+ pub fn local_relation(
+ &self,
+ plan_id: i64,
+ plan: spark_connect::LocalRelation,
+ ) -> eyre::Result<LogicalPlanBuilder> {
let spark_connect::LocalRelation { data, schema } = plan;
let Some(data) = data else {
@@ -139,67 +135,33 @@ pub fn local_relation(plan: spark_connect::LocalRelation) -> eyre::Result<Plan> {
"Mismatch in row counts across columns; all columns must have the same number of rows."
);
- let Some(&num_rows) = num_rows.first() else {
- bail!("No columns were found; at least one column is required.")
- };
+ let batch = Table::from_nonempty_columns(columns)?;
- let table = Table::new_with_size(daft_schema.clone(), columns, num_rows)
- .wrap_err("Failed to create Table from columns and schema.")?;
-
- tables.push(table);
+ tables.push(batch);
}
tables
};
- // Note: Verify if the Daft schema used here matches the schema of the table.
- let micro_partition = MicroPartition::new_loaded(daft_schema, Arc::new(tables), None);
- let micro_partition = Arc::new(micro_partition);
-
- let plan = Python::with_gil(|py| {
- // Convert MicroPartition to a logical plan using Python interop.
- let py_micropartition = py
- .import_bound(pyo3::intern!(py, "daft.table"))?
- .getattr(pyo3::intern!(py, "MicroPartition"))?
- .getattr(pyo3::intern!(py, "_from_pymicropartition"))?
- .call1((PyMicroPartition::from(micro_partition.clone()),))?;
-
- // ERROR: 2: AttributeError: 'daft.daft.PySchema' object has no attribute '_schema'
- let py_plan_builder = py
- .import_bound(pyo3::intern!(py, "daft.dataframe.dataframe"))?
- .getattr(pyo3::intern!(py, "to_logical_plan_builder"))?
- .call1((py_micropartition,))?;
-
- let py_plan_builder = py_plan_builder.getattr(pyo3::intern!(py, "_builder"))?;
-
- let plan: PyLogicalPlanBuilder = py_plan_builder.extract()?;
-
- Ok::<_, eyre::Error>(plan.builder)
- })?;
-
- let cache_key = grab_singular_cache_key(&plan)?;
-
- let mut psets = HashMap::new();
- psets.insert(cache_key, vec![micro_partition]);
-
- let plan = Plan {
- builder: plan,
- psets: InMemoryPartitionSet::new(psets),
- };
-
- Ok(plan)
+ let pset = MicroPartitionSet::from_tables(plan_id as usize, tables)?;
+ let PartitionMetadata {
+ size_bytes,
+ num_rows,
+ } = pset.metadata();
+ let num_partitions = pset.num_partitions();
+
+ let partition_key: Arc<str> = uuid::Uuid::new_v4().to_string().into();
+ let pset = Arc::new(pset);
+ self.psets.put_partition_set(&partition_key, &pset);
+
+ let lp = LogicalPlanBuilder::in_memory_scan(
+ &partition_key,
+ PartitionCacheEntry::new_rust(partition_key.to_string(), pset),
+ daft_schema,
+ num_partitions,
+ size_bytes,
+ num_rows,
+ )?;
+
+ Ok(lp)
}
}
-
-fn grab_singular_cache_key(plan: &LogicalPlanBuilder) -> eyre::Result<String> {
- let plan = &*plan.plan;
-
- let LogicalPlan::Source(Source { source_info, .. }) = plan else {
- bail!("Expected a source plan");
- };
-
- let SourceInfo::InMemory(InMemoryInfo { cache_key, .. }) = &**source_info else {
- bail!("Expected an in-memory source");
- };
-
- Ok(cache_key.clone())
-}
diff --git a/src/daft-connect/src/translation/logical_plan/project.rs b/src/daft-connect/src/translation/logical_plan/project.rs
index af03c8dc2e..448242d31d 100644
--- a/src/daft-connect/src/translation/logical_plan/project.rs
+++ b/src/daft-connect/src/translation/logical_plan/project.rs
@@ -3,22 +3,26 @@
//! TL;DR: Project is Spark's equivalent of SQL SELECT - it selects columns, renames them via aliases,
//! and creates new columns from expressions. Example: `df.select(col("id").alias("my_number"))`
+use daft_logical_plan::LogicalPlanBuilder;
use eyre::bail;
use spark_connect::Project;
-use crate::translation::{logical_plan::Plan, to_daft_expr, to_logical_plan};
+use super::SparkAnalyzer;
+use crate::translation::to_daft_expr;
-pub async fn project(project: Project) -> eyre::Result<Plan> {
- let Project { input, expressions } = project;
+impl SparkAnalyzer<'_> {
+ pub async fn project(&self, project: Project) -> eyre::Result<LogicalPlanBuilder> {
+ let Project { input, expressions } = project;
- let Some(input) = input else {
- bail!("Project input is required");
- };
+ let Some(input) = input else {
+ bail!("Project input is required");
+ };
- let mut plan = Box::pin(to_logical_plan(*input)).await?;
+ let mut plan = Box::pin(self.to_logical_plan(*input)).await?;
- let daft_exprs: Vec<_> = expressions.iter().map(to_daft_expr).try_collect()?;
- plan.builder = plan.builder.select(daft_exprs)?;
+ let daft_exprs: Vec<_> = expressions.iter().map(to_daft_expr).try_collect()?;
+ plan = plan.select(daft_exprs)?;
- Ok(plan)
+ Ok(plan)
+ }
}
diff --git a/src/daft-connect/src/translation/logical_plan/range.rs b/src/daft-connect/src/translation/logical_plan/range.rs
index ff15e0cacb..1660bef5bf 100644
--- a/src/daft-connect/src/translation/logical_plan/range.rs
+++ b/src/daft-connect/src/translation/logical_plan/range.rs
@@ -2,56 +2,59 @@ use daft_logical_plan::LogicalPlanBuilder;
use eyre::{ensure, Context};
use spark_connect::Range;
-use crate::translation::logical_plan::Plan;
+use super::SparkAnalyzer;
-pub fn range(range: Range) -> eyre::Result<Plan> {
- #[cfg(not(feature = "python"))]
- {
- use eyre::bail;
- bail!("Range operations require Python feature to be enabled");
- }
+impl SparkAnalyzer<'_> {
+ pub fn range(&self, range: Range) -> eyre::Result<LogicalPlanBuilder> {
+ #[cfg(not(feature = "python"))]
+ {
+ use eyre::bail;
+ bail!("Range operations require Python feature to be enabled");
+ }
- #[cfg(feature = "python")]
- {
- use daft_scan::python::pylib::ScanOperatorHandle;
- use pyo3::prelude::*;
- let Range {
- start,
- end,
- step,
- num_partitions,
- } = range;
+ #[cfg(feature = "python")]
+ {
+ use daft_scan::python::pylib::ScanOperatorHandle;
+ use pyo3::prelude::*;
+ let Range {
+ start,
+ end,
+ step,
+ num_partitions,
+ } = range;
- let partitions = num_partitions.unwrap_or(1);
+ let partitions = num_partitions.unwrap_or(1);
- ensure!(partitions > 0, "num_partitions must be greater than 0");
+ ensure!(partitions > 0, "num_partitions must be greater than 0");
- let start = start.unwrap_or(0);
+ let start = start.unwrap_or(0);
- let step = usize::try_from(step).wrap_err("step must be a positive integer")?;
- ensure!(step > 0, "step must be greater than 0");
+ let step = usize::try_from(step).wrap_err("step must be a positive integer")?;
+ ensure!(step > 0, "step must be greater than 0");
- let plan = Python::with_gil(|py| {
- let range_module = PyModule::import_bound(py, "daft.io._range")
- .wrap_err("Failed to import range module")?;
+ let plan = Python::with_gil(|py| {
+ let range_module = PyModule::import_bound(py, "daft.io._range")
+ .wrap_err("Failed to import range module")?;
- let range = range_module
- .getattr(pyo3::intern!(py, "RangeScanOperator"))
- .wrap_err("Failed to get range function")?;
+ let range = range_module
+ .getattr(pyo3::intern!(py, "RangeScanOperator"))
+ .wrap_err("Failed to get range function")?;
- let range = range
- .call1((start, end, step, partitions))
- .wrap_err("Failed to create range scan operator")?
- .to_object(py);
+ let range = range
+ .call1((start, end, step, partitions))
+ .wrap_err("Failed to create range scan operator")?
+ .to_object(py);
- let scan_operator_handle = ScanOperatorHandle::from_python_scan_operator(range, py)?;
+ let scan_operator_handle =
+ ScanOperatorHandle::from_python_scan_operator(range, py)?;
- let plan = LogicalPlanBuilder::table_scan(scan_operator_handle.into(), None)?;
+ let plan = LogicalPlanBuilder::table_scan(scan_operator_handle.into(), None)?;
- eyre::Result::<_>::Ok(plan)
- })
- .wrap_err("Failed to create range scan")?;
+ eyre::Result::<_>::Ok(plan)
+ })
+ .wrap_err("Failed to create range scan")?;
- Ok(plan.into())
+ Ok(plan)
+ }
}
}
diff --git a/src/daft-connect/src/translation/logical_plan/read.rs b/src/daft-connect/src/translation/logical_plan/read.rs
index fc8a834fbb..9a73783191 100644
--- a/src/daft-connect/src/translation/logical_plan/read.rs
+++ b/src/daft-connect/src/translation/logical_plan/read.rs
@@ -1,12 +1,11 @@
+use daft_logical_plan::LogicalPlanBuilder;
use eyre::{bail, WrapErr};
use spark_connect::read::ReadType;
use tracing::warn;
-use crate::translation::Plan;
-
mod data_source;
-pub async fn read(read: spark_connect::Read) -> eyre::Result<Plan> {
+pub async fn read(read: spark_connect::Read) -> eyre::Result<LogicalPlanBuilder> {
let spark_connect::Read {
is_streaming,
read_type,
@@ -28,5 +27,5 @@ pub async fn read(read: spark_connect::Read) -> eyre::Result<Plan> {
.wrap_err("Failed to create data source"),
}?;
- Ok(Plan::from(builder))
+ Ok(builder)
}
diff --git a/src/daft-connect/src/translation/logical_plan/to_df.rs b/src/daft-connect/src/translation/logical_plan/to_df.rs
index c2a355a1e5..e3d172661b 100644
--- a/src/daft-connect/src/translation/logical_plan/to_df.rs
+++ b/src/daft-connect/src/translation/logical_plan/to_df.rs
@@ -1,30 +1,28 @@
+use daft_logical_plan::LogicalPlanBuilder;
use eyre::{bail, WrapErr};
-use crate::translation::{logical_plan::Plan, to_logical_plan};
+use super::SparkAnalyzer;
+impl SparkAnalyzer<'_> {
+ pub async fn to_df(&self, to_df: spark_connect::ToDf) -> eyre::Result<LogicalPlanBuilder> {
+ let spark_connect::ToDf {
+ input,
+ column_names,
+ } = to_df;
-pub async fn to_df(to_df: spark_connect::ToDf) -> eyre::Result<Plan> {
- let spark_connect::ToDf {
- input,
- column_names,
- } = to_df;
+ let Some(input) = input else {
+ bail!("Input is required");
+ };
- let Some(input) = input else {
- bail!("Input is required");
- };
+ let mut plan = Box::pin(self.to_logical_plan(*input)).await?;
- let mut plan = Box::pin(to_logical_plan(*input))
- .await
- .wrap_err("Failed to translate relation to logical plan")?;
+ let column_names: Vec<_> = column_names
+ .iter()
+ .map(|s| daft_dsl::col(s.as_str()))
+ .collect();
- let column_names: Vec<_> = column_names
- .iter()
- .map(|s| daft_dsl::col(s.as_str()))
- .collect();
-
- plan.builder = plan
- .builder
- .select(column_names)
- .wrap_err("Failed to add columns to logical plan")?;
-
- Ok(plan)
+ plan = plan
+ .select(column_names)
+ .wrap_err("Failed to add columns to logical plan")?;
+ Ok(plan)
+ }
}
diff --git a/src/daft-connect/src/translation/logical_plan/with_columns.rs b/src/daft-connect/src/translation/logical_plan/with_columns.rs
index 08396ecdba..97b3c3d1d1 100644
--- a/src/daft-connect/src/translation/logical_plan/with_columns.rs
+++ b/src/daft-connect/src/translation/logical_plan/with_columns.rs
@@ -1,30 +1,35 @@
+use daft_logical_plan::LogicalPlanBuilder;
use eyre::bail;
use spark_connect::{expression::ExprType, Expression};
-use crate::translation::{to_daft_expr, to_logical_plan, Plan};
+use super::SparkAnalyzer;
+use crate::translation::to_daft_expr;
-pub async fn with_columns(with_columns: spark_connect::WithColumns) -> eyre::Result<Plan> {
- let spark_connect::WithColumns { input, aliases } = with_columns;
+impl SparkAnalyzer<'_> {
+ pub async fn with_columns(
+ &self,
+ with_columns: spark_connect::WithColumns,
+ ) -> eyre::Result<LogicalPlanBuilder> {
+ let spark_connect::WithColumns { input, aliases } = with_columns;
- let Some(input) = input else {
- bail!("input is required");
- };
+ let Some(input) = input else {
+ bail!("input is required");
+ };
- let mut plan = Box::pin(to_logical_plan(*input)).await?;
+ let plan = Box::pin(self.to_logical_plan(*input)).await?;
- let daft_exprs: Vec<_> = aliases
- .into_iter()
- .map(|alias| {
- let expression = Expression {
- common: None,
- expr_type: Some(ExprType::Alias(Box::new(alias))),
- };
+ let daft_exprs: Vec<_> = aliases
+ .into_iter()
+ .map(|alias| {
+ let expression = Expression {
+ common: None,
+ expr_type: Some(ExprType::Alias(Box::new(alias))),
+ };
- to_daft_expr(&expression)
- })
- .try_collect()?;
+ to_daft_expr(&expression)
+ })
+ .try_collect()?;
- plan.builder = plan.builder.with_columns(daft_exprs)?;
-
- Ok(plan)
+ Ok(plan.with_columns(daft_exprs)?)
+ }
}
diff --git a/src/daft-connect/src/translation/schema.rs b/src/daft-connect/src/translation/schema.rs
index 1868eaeb2d..605f1b640a 100644
--- a/src/daft-connect/src/translation/schema.rs
+++ b/src/daft-connect/src/translation/schema.rs
@@ -1,10 +1,12 @@
+use daft_micropartition::partitioning::InMemoryPartitionSetCache;
use spark_connect::{
data_type::{Kind, Struct, StructField},
DataType, Relation,
};
use tracing::warn;
-use crate::translation::{to_logical_plan, to_spark_datatype};
+use super::SparkAnalyzer;
+use crate::translation::to_spark_datatype;
#[tracing::instrument(skip_all)]
pub async fn relation_to_schema(input: Relation) -> eyre::Result<DataType> {
@@ -14,9 +16,12 @@ pub async fn relation_to_schema(input: Relation) -> eyre::Result<DataType> {
}
}
- let plan = Box::pin(to_logical_plan(input)).await?;
+ // We only need the schema here, so an empty, transient partition-set cache is sufficient; it is never read.
+ let pset = InMemoryPartitionSetCache::empty();
+ let translator = SparkAnalyzer::new(&pset);
+ let plan = Box::pin(translator.to_logical_plan(input)).await?;
- let result = plan.builder.schema();
+ let result = plan.schema();
let fields: eyre::Result