Fix RayRunnerIO to use object refs
Jay Chia committed Dec 8, 2023
1 parent 018748e commit 413b991
Showing 3 changed files with 26 additions and 31 deletions.
9 changes: 9 additions & 0 deletions daft/context.py
@@ -115,6 +115,15 @@ def get_context() -> DaftContext:
return _DaftContext


def set_context(ctx: DaftContext) -> DaftContext:
global _DaftContext

pop_context()
_DaftContext = ctx

return _DaftContext


def pop_context() -> DaftContext:
"""Helper used in tests and test fixtures to clear the global runner and allow for re-setting of configs."""
global _DaftContext
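The new set_context helper pairs with the existing pop_context to swap the global context in and out. A minimal sketch of how a caller might save and restore the context around a temporary config, mirroring the test change at the bottom of this commit (the specific set_config keyword is illustrative):

import daft

# Save the current global context, install a temporary config, then restore.
old_ctx = daft.context.pop_context()
try:
    daft.context.set_config(merge_scan_tasks_min_size_bytes=0)  # illustrative knob
    ...  # code that should run under the temporary config
finally:
    daft.context.pop_context()
    daft.context.set_context(old_ctx)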
37 changes: 15 additions & 22 deletions daft/runners/ray_runner.py
@@ -11,7 +11,7 @@

import pyarrow as pa

from daft.context import get_context, set_config
from daft.context import set_config
from daft.logical.builder import LogicalPlanBuilder
from daft.plan_scheduler import PhysicalPlanScheduler
from daft.runners.progress_bar import ProgressBar
@@ -223,8 +223,8 @@ def wait(self) -> None:


class RayRunnerIO(runner_io.RunnerIO):
def __init__(self, daft_config: PyDaftConfig, *args, **kwargs):
self.daft_config = daft_config
def __init__(self, daft_config_objref: ray.ObjectRef, *args, **kwargs):
self.daft_config_objref = daft_config_objref
super().__init__(*args, **kwargs)

def glob_paths_details(
@@ -237,7 +237,7 @@ def glob_paths_details(
return FileInfos.from_table(
ray.get(
_glob_path_into_file_infos.remote(
self.daft_config, source_paths, file_format_config, io_config=io_config
self.daft_config_objref, source_paths, file_format_config, io_config=io_config
)
)
.to_table()
@@ -256,7 +256,7 @@ def get_schema_from_first_filepath(
first_path = file_infos[0].file_path
return ray.get(
sample_schema_from_filepath.remote(
self.daft_config,
self.daft_config_objref,
first_path,
file_format_config,
storage_config,
@@ -267,7 +267,6 @@ def partition_set_from_ray_dataset(
self,
ds: RayDataset,
) -> tuple[RayPartitionSet, Schema]:
daft_config_objref = ray.put(self.daft_config)
arrow_schema = ds.schema(fetch_if_missing=True)
if not isinstance(arrow_schema, pa.Schema):
# Convert Dataset to an Arrow dataset.
@@ -295,15 +294,15 @@ def partition_set_from_ray_dataset(
# NOTE: This materializes the entire Ray Dataset - we could make this more intelligent by creating a new RayDatasetScan node
# which can iterate on Ray Dataset blocks and materialize as-needed
daft_vpartitions = [
_make_daft_partition_from_ray_dataset_blocks.remote(daft_config_objref, block, daft_schema)
_make_daft_partition_from_ray_dataset_blocks.remote(self.daft_config_objref, block, daft_schema)
for block in block_refs
]

return (
RayPartitionSet(
_daft_config_objref=daft_config_objref,
_daft_config_objref=self.daft_config_objref,
_results={
i: RayMaterializedResult(obj, _daft_config_objref=daft_config_objref)
i: RayMaterializedResult(obj, _daft_config_objref=self.daft_config_objref)
for i, obj in enumerate(daft_vpartitions)
},
),
@@ -317,15 +316,13 @@ def partition_set_from_dask_dataframe(
import dask
from ray.util.dask import ray_dask_get

daft_config_objref = ray.put(self.daft_config)

partitions = ddf.to_delayed()
if not partitions:
raise ValueError("Can't convert an empty Dask DataFrame (with no partitions) to a Daft DataFrame.")
persisted_partitions = dask.persist(*partitions, scheduler=ray_dask_get)
parts = [_to_pandas_ref(next(iter(part.dask.values()))) for part in persisted_partitions]
daft_vpartitions, schemas = zip(
*(_make_daft_partition_from_dask_dataframe_partitions.remote(daft_config_objref, p) for p in parts)
*(_make_daft_partition_from_dask_dataframe_partitions.remote(self.daft_config_objref, p) for p in parts)
)
schemas = ray.get(list(schemas))
# Dask shouldn't allow inconsistent schemas across partitions, but we double-check here.
@@ -336,9 +333,9 @@
)
return (
RayPartitionSet(
_daft_config_objref=daft_config_objref,
_daft_config_objref=self.daft_config_objref,
_results={
i: RayMaterializedResult(obj, _daft_config_objref=daft_config_objref)
i: RayMaterializedResult(obj, _daft_config_objref=self.daft_config_objref)
for i, obj in enumerate(daft_vpartitions)
},
),
@@ -391,36 +388,32 @@ def fanout_pipeline(
daft_config: PyDaftConfig, instruction_stack: list[Instruction], *inputs: MicroPartition
) -> list[list[PartitionMetadata] | MicroPartition]:
set_config(daft_config)

return build_partitions(instruction_stack, *inputs)


@ray.remote(scheduling_strategy="SPREAD")
def reduce_pipeline(
daft_config: PyDaftConfig, instruction_stack: list[Instruction], inputs: list
) -> list[list[PartitionMetadata] | MicroPartition]:
set_config(daft_config)

import ray

set_config(daft_config)
return build_partitions(instruction_stack, *ray.get(inputs))


@ray.remote(scheduling_strategy="SPREAD")
def reduce_and_fanout(
daft_config: PyDaftConfig, instruction_stack: list[Instruction], inputs: list
) -> list[list[PartitionMetadata] | MicroPartition]:
set_config(daft_config)

import ray

set_config(daft_config)
return build_partitions(instruction_stack, *ray.get(inputs))


@ray.remote
def get_meta(daft_config: PyDaftConfig, partition: MicroPartition) -> PartitionMetadata:
set_config(daft_config)

return PartitionMetadata.from_table(partition)


@@ -839,7 +832,7 @@ def put_partition_set_into_cache(self, pset: PartitionSet) -> PartitionCacheEntr
return self._part_set_cache.put_partition_set(pset=pset)

def runner_io(self) -> RayRunnerIO:
return RayRunnerIO(daft_config=self.daft_config)
return RayRunnerIO(daft_config_objref=self.daft_config_objref)


@dataclass(frozen=True)
@@ -859,7 +852,7 @@ def metadata(self) -> PartitionMetadata:
if self._metadatas is not None and self._metadata_index is not None:
return self._metadatas.get_index(self._metadata_index)
else:
return ray.get(get_meta.remote(get_context().daft_config, self._partition))
return ray.get(get_meta.remote(self._daft_config_objref, self._partition))

def cancel(self) -> None:
return ray.cancel(self._partition)
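The pattern behind this change: the runner calls ray.put on the config once and hands the resulting ObjectRef to RayRunnerIO and every remote task, instead of re-putting the config inside each method. A standalone sketch of that Ray pattern (generic code, not Daft's actual classes):

import ray

ray.init(ignore_reinit_error=True)

@ray.remote
def task_using_config(config: dict, x: int) -> int:
    # Ray resolves a top-level ObjectRef argument to its value before the task body runs.
    return x + config["offset"]

# Put the config into the object store once and reuse the same ObjectRef for every
# remote call, rather than re-serializing and re-putting the object per call.
config_ref = ray.put({"offset": 10})

print(ray.get([task_using_config.remote(config_ref, i) for i in range(3)]))  # [10, 11, 12]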
11 changes: 2 additions & 9 deletions tests/io/test_merge_scan_tasks.py
@@ -10,23 +10,16 @@

@contextlib.contextmanager
def override_merge_scan_tasks_configs(merge_scan_tasks_min_size_bytes: int, merge_scan_tasks_max_size_bytes: int):
config = daft.context.get_context().daft_config
original_merge_scan_tasks_min_size_bytes = config.merge_scan_tasks_min_size_bytes
original_merge_scan_tasks_max_size_bytes = config.merge_scan_tasks_max_size_bytes
old_context = daft.context.pop_context()

try:
daft.context.pop_context()
daft.context.set_config(
merge_scan_tasks_min_size_bytes=merge_scan_tasks_min_size_bytes,
merge_scan_tasks_max_size_bytes=merge_scan_tasks_max_size_bytes,
)
yield
finally:
daft.context.pop_context()
daft.context.set_config(
merge_scan_tasks_min_size_bytes=original_merge_scan_tasks_min_size_bytes,
merge_scan_tasks_max_size_bytes=original_merge_scan_tasks_max_size_bytes,
)
daft.context.set_context(old_context)


@pytest.fixture(scope="function")
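A hypothetical test showing how the simplified context manager above might be used; the threshold values and the assertion are illustrative only and not part of this commit:

import daft

def test_override_merge_scan_tasks_configs_applies():
    with override_merge_scan_tasks_configs(
        merge_scan_tasks_min_size_bytes=0,
        merge_scan_tasks_max_size_bytes=1,
    ):
        config = daft.context.get_context().daft_config
        assert config.merge_scan_tasks_max_size_bytes == 1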
