From c5adbeb2e04fd12942eef4cfbfbcb9e9787263a6 Mon Sep 17 00:00:00 2001 From: Xiayue Charles Lin Date: Thu, 14 Sep 2023 14:37:01 -0600 Subject: [PATCH] Fix using request::add instead of ::max --- .../rules/push_down_projection.rs | 7 ++++-- src/daft-plan/src/resource_request.rs | 23 ++++++++++--------- 2 files changed, 17 insertions(+), 13 deletions(-) diff --git a/src/daft-plan/src/optimization/rules/push_down_projection.rs b/src/daft-plan/src/optimization/rules/push_down_projection.rs index 86447a5d4a..78b55655e7 100644 --- a/src/daft-plan/src/optimization/rules/push_down_projection.rs +++ b/src/daft-plan/src/optimization/rules/push_down_projection.rs @@ -8,7 +8,7 @@ use indexmap::IndexSet; use crate::{ logical_ops::{Aggregate, Project, Source}, - LogicalPlan, + LogicalPlan, ResourceRequest, }; use super::{ApplyOrder, OptimizerRule, Transformed}; @@ -123,7 +123,10 @@ impl PushDownProjection { let new_plan: LogicalPlan = Project::try_new( upstream_projection.input.clone(), merged_projection, - &upstream_projection.resource_request + &projection.resource_request, + ResourceRequest::max(&[ + &upstream_projection.resource_request, + &projection.resource_request, + ]), )? .into(); let new_plan: Arc = new_plan.into(); diff --git a/src/daft-plan/src/resource_request.rs b/src/daft-plan/src/resource_request.rs index 62df5166ca..b4dbb940d0 100644 --- a/src/daft-plan/src/resource_request.rs +++ b/src/daft-plan/src/resource_request.rs @@ -1,11 +1,8 @@ use daft_core::{impl_bincode_py_state_serialization, utils::hashable_float_wrapper::FloatWrapper}; +#[cfg(feature = "python")] +use pyo3::{pyclass, pyclass::CompareOp, pymethods, types::PyBytes, PyResult, Python}; use std::hash::{Hash, Hasher}; use std::ops::Add; -#[cfg(feature = "python")] -use { - pyo3::{pyclass, pyclass::CompareOp, pymethods, types::PyBytes, PyResult, Python}, - std::cmp::max, -}; use serde::{Deserialize, Serialize}; @@ -29,6 +26,15 @@ impl ResourceRequest { memory_bytes, } } + + pub fn max(resource_requests: &[&Self]) -> Self { + resource_requests.iter().fold(Default::default(), |acc, e| { + let max_num_cpus = lift(float_max, acc.num_cpus, e.num_cpus); + let max_num_gpus = lift(float_max, acc.num_gpus, e.num_gpus); + let max_memory_bytes = lift(std::cmp::max, acc.memory_bytes, e.memory_bytes); + Self::new_internal(max_num_cpus, max_num_gpus, max_memory_bytes) + }) + } } impl Add for &ResourceRequest { @@ -82,12 +88,7 @@ impl ResourceRequest { #[staticmethod] pub fn max_resources(resource_requests: Vec) -> Self { - resource_requests.iter().fold(Default::default(), |acc, e| { - let max_num_cpus = lift(float_max, acc.num_cpus, e.num_cpus); - let max_num_gpus = lift(float_max, acc.num_gpus, e.num_gpus); - let max_memory_bytes = lift(max, acc.memory_bytes, e.memory_bytes); - Self::new_internal(max_num_cpus, max_num_gpus, max_memory_bytes) - }) + Self::max(&resource_requests.iter().collect::>()) } #[getter]