diff --git a/daft/execution/execution_step.py b/daft/execution/execution_step.py index 8461882da1..a41ec19da4 100644 --- a/daft/execution/execution_step.py +++ b/daft/execution/execution_step.py @@ -176,28 +176,29 @@ def set_result(self, result: list[MaterializedResult[PartitionT]]) -> None: def done(self) -> bool: return self._result is not None + def result(self) -> MaterializedResult[PartitionT]: + assert self._result is not None, "Cannot call .result() on a PartitionTask that is not done" + return self._result + def cancel(self) -> None: # Currently only implemented for Ray tasks. - if self._result is not None: - self._result.cancel() + if self.done(): + self.result().cancel() def partition(self) -> PartitionT: """Get the PartitionT resulting from running this PartitionTask.""" - assert self._result is not None - return self._result.partition() + return self.result().partition() def partition_metadata(self) -> PartitionMetadata: """Get the metadata of the result partition. (Avoids retrieving the actual partition itself if possible.) """ - assert self._result is not None - return self._result.metadata() + return self.result().metadata() def vpartition(self) -> Table: """Get the raw vPartition of the result.""" - assert self._result is not None - return self._result.vpartition() + return self.result().vpartition() def __str__(self) -> str: return super().__str__() diff --git a/daft/execution/physical_plan.py b/daft/execution/physical_plan.py index 8027ddf15e..a25f3511b4 100644 --- a/daft/execution/physical_plan.py +++ b/daft/execution/physical_plan.py @@ -741,7 +741,7 @@ def materialize( # Check if any inputs finished executing. while len(materializations) > 0 and materializations[0].done(): done_task = materializations.popleft() - yield done_task._result + yield done_task.result() # Materialize a single dependency. try: