[PERF] Lazily import heavy modules to speed up import times (#2826)
Introduce lazy imports for heavy modules that are not needed as
top-level imports. For example, `ray` does not need to be a top-level
import: it should only be imported when using the Ray runner or when
specific Ray Data extension types are needed. Another example is
`UnityCatalogTable`, which is a relatively heavy import despite only
being needed when working with Delta Lake.
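
The deferral goes through a small `daft.lazy_import.LazyImport` proxy. As a rough sketch of the pattern (an illustration, not necessarily the exact implementation in `daft.lazy_import`), such a proxy can hold the module name and only run the real import on first attribute access:

```python
from __future__ import annotations

import importlib
from types import ModuleType
from typing import Any


class LazyImport:
    """Defer importing `module_name` until an attribute is first accessed."""

    def __init__(self, module_name: str) -> None:
        self._module_name = module_name
        self._module: ModuleType | None = None

    def _load(self) -> ModuleType:
        # Import exactly once, then cache the module object.
        if self._module is None:
            self._module = importlib.import_module(self._module_name)
        return self._module

    def __getattr__(self, name: str) -> Any:
        # Called only for attributes not found on the proxy itself,
        # i.e. anything the caller expects from the real module.
        return getattr(self._load(), name)


pa = LazyImport("pyarrow")
# No pyarrow import has happened yet; this first attribute access triggers it.
arr = pa.array([1, 2, 3])
```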

Modules to import lazily were chosen based on their share of total import
time, as reported by `importtime-output-wrapper -c 'import daft' --format
waterfall --depth 25`.

The modules that are now imported lazily are:
- `daft.unity_catalog`
- `fsspec`
- `numpy`
- `pandas`
- `PIL.Image`
- `pyarrow`
- `pyarrow.csv` 
- `pyarrow.dataset`
- `pyarrow.fs`
- `pyarrow.json`
- `pyarrow.parquet` 
- `ray`
- `ray.data.extensions`
- `xml.etree.ElementTree` 

Uses #2836 to defer the import of `pyarrow`.

Additionally, we move all type-checking-only module imports into
`TYPE_CHECKING` blocks.
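
The pattern looks like the following sketch (illustrative; `pandas` here just stands in for any type-checking-only dependency). The import runs only under static analysis, and with `from __future__ import annotations` the annotations are never evaluated at runtime:

```python
from __future__ import annotations

from typing import TYPE_CHECKING

if TYPE_CHECKING:
    # Seen by mypy/pyright, but never executed at runtime.
    import pandas as pd


def column_names(df: pd.DataFrame) -> list[str]:
    # pandas is only needed once a caller actually passes a DataFrame in.
    return list(df.columns)
```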

With these changes, import times go from roughly 0.6-0.7s to ~0.045s
(~13-15x faster).
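
A quick way to sanity-check numbers like these (a generic approach, not necessarily how the figures above were produced) is to time the import in a fresh interpreter:

```python
import time

start = time.perf_counter()
import daft  # noqa: E402 -- placed here deliberately so the import itself is timed

print(f"import daft took {(time.perf_counter() - start) * 1000:.1f} ms")
```

Because modules are cached in `sys.modules`, this only measures a cold import the first time it runs in a given process.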

---------

Co-authored-by: Sammy Sidhu <[email protected]>
desmondcheongzx and samster25 authored Sep 19, 2024
1 parent dba931f commit 78a92a2
Showing 52 changed files with 401 additions and 311 deletions.
13 changes: 13 additions & 0 deletions daft/.ruff.toml
@@ -0,0 +1,13 @@
extend = "../.ruff.toml"

[lint]
extend-select = [
"TID253", # banned-module-level-imports, derived from flake8-tidy-imports
"TCH" # flake8-type-checking
]

[lint.flake8-tidy-imports]
# Ban certain modules from being imported at module level, instead requiring
# that they're imported lazily (e.g., within a function definition,
# with daft.lazy_import.LazyImport, or with TYPE_CHECKING).
banned-module-level-imports = ["daft.unity_catalog", "fsspec", "numpy", "pandas", "PIL", "pyarrow", "ray", "xml"]
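
As an illustrative example of what this enforces (a hypothetical snippet, not a file from this commit): a module-level import of a banned module is flagged by `TID253`, while routing through `daft.dependencies` or importing inside the function that needs the module is fine.

```python
# Flagged by TID253: a module-level import of a banned heavy module.
# import pyarrow as pa

# Allowed: route the dependency through the lazy proxy module...
from daft.dependencies import pa


def to_arrow(values: list[int]) -> "pa.Array":
    return pa.array(values)


# ...or defer an import to the function that actually needs it.
def parquet_schema(path: str):
    import pyarrow.parquet as pq

    return pq.read_schema(path)
```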
30 changes: 19 additions & 11 deletions daft/arrow_utils.py
@@ -2,7 +2,7 @@

import sys

import pyarrow as pa
from daft.dependencies import pa


def ensure_array(arr: pa.Array) -> pa.Array:
@@ -34,13 +34,21 @@ class _FixEmptyStructArrays:
Python layer before going through ffi into Rust.
"""

EMPTY_STRUCT_TYPE = pa.struct([])
SINGLE_FIELD_STRUCT_TYPE = pa.struct({"": pa.null()})
SINGLE_FIELD_STRUCT_VALUE = {"": None}
@staticmethod
def get_empty_struct_type():
return pa.struct([])

@staticmethod
def get_single_field_struct_type():
return pa.struct({"": pa.null()})

@staticmethod
def get_single_field_struct_value():
return {"": None}

def ensure_table(table: pa.Table) -> pa.Table:
empty_struct_fields = [
(i, f) for (i, f) in enumerate(table.schema) if f.type == _FixEmptyStructArrays.EMPTY_STRUCT_TYPE
(i, f) for (i, f) in enumerate(table.schema) if f.type == _FixEmptyStructArrays.get_empty_struct_type()
]
if not empty_struct_fields:
return table
@@ -49,19 +57,19 @@ def ensure_table(table: pa.Table) -> pa.Table:
return table

def ensure_chunked_array(arr: pa.ChunkedArray) -> pa.ChunkedArray:
if arr.type != _FixEmptyStructArrays.EMPTY_STRUCT_TYPE:
if arr.type != _FixEmptyStructArrays.get_empty_struct_type():
return arr
return pa.chunked_array([_FixEmptyStructArrays.ensure_array(chunk) for chunk in arr.chunks])

def ensure_array(arr: pa.Array) -> pa.Array:
"""Recursively converts empty struct arrays to single-field struct arrays"""
if arr.type == _FixEmptyStructArrays.EMPTY_STRUCT_TYPE:
if arr.type == _FixEmptyStructArrays.get_empty_struct_type():
return pa.array(
[
_FixEmptyStructArrays.SINGLE_FIELD_STRUCT_VALUE if valid.as_py() else None
_FixEmptyStructArrays.get_single_field_struct_value() if valid.as_py() else None
for valid in arr.is_valid()
],
type=_FixEmptyStructArrays.SINGLE_FIELD_STRUCT_TYPE,
type=_FixEmptyStructArrays.get_single_field_struct_type(),
)

elif isinstance(arr, pa.StructArray):
@@ -77,10 +85,10 @@ def ensure_array(arr: pa.Array) -> pa.Array:

def remove_empty_struct_placeholders(arr: pa.Array):
"""Recursively removes the empty struct placeholders placed by _FixEmptyStructArrays.ensure_array"""
if arr.type == _FixEmptyStructArrays.SINGLE_FIELD_STRUCT_TYPE:
if arr.type == _FixEmptyStructArrays.get_single_field_struct_type():
return pa.array(
[{} if valid.as_py() else None for valid in arr.is_valid()],
type=_FixEmptyStructArrays.EMPTY_STRUCT_TYPE,
type=_FixEmptyStructArrays.get_empty_struct_type(),
)

elif isinstance(arr, pa.StructArray):
14 changes: 6 additions & 8 deletions daft/daft/__init__.pyi
@@ -3,8 +3,6 @@ import datetime
from enum import Enum
from typing import TYPE_CHECKING, Any, Callable, Iterator

import pyarrow

from daft.dataframe.display import MermaidOptions
from daft.execution import physical_plan
from daft.io.scan import ScanOperator
@@ -994,7 +992,7 @@ class PyDataType:
def tensor(dtype: PyDataType, shape: tuple[int, ...] | None = None) -> PyDataType: ...
@staticmethod
def python() -> PyDataType: ...
def to_arrow(self, cast_tensor_type_for_ray: builtins.bool | None = None) -> pyarrow.DataType: ...
def to_arrow(self, cast_tensor_type_for_ray: builtins.bool | None = None) -> pa.DataType: ...
def is_numeric(self) -> builtins.bool: ...
def is_image(self) -> builtins.bool: ...
def is_fixed_shape_image(self) -> builtins.bool: ...
@@ -1271,11 +1269,11 @@ class PyCatalog:

class PySeries:
@staticmethod
def from_arrow(name: str, pyarrow_array: pyarrow.Array) -> PySeries: ...
def from_arrow(name: str, pyarrow_array: pa.Array) -> PySeries: ...
@staticmethod
def from_pylist(name: str, pylist: list[Any], pyobj: str) -> PySeries: ...
def to_pylist(self) -> list[Any]: ...
def to_arrow(self) -> pyarrow.Array: ...
def to_arrow(self) -> pa.Array: ...
def __abs__(self) -> PySeries: ...
def __add__(self, other: PySeries) -> PySeries: ...
def __sub__(self, other: PySeries) -> PySeries: ...
@@ -1456,10 +1454,10 @@ class PyTable:
def concat(tables: list[PyTable]) -> PyTable: ...
def slice(self, start: int, end: int) -> PyTable: ...
@staticmethod
def from_arrow_record_batches(record_batches: list[pyarrow.RecordBatch], schema: PySchema) -> PyTable: ...
def from_arrow_record_batches(record_batches: list[pa.RecordBatch], schema: PySchema) -> PyTable: ...
@staticmethod
def from_pylist_series(dict: dict[str, PySeries]) -> PyTable: ...
def to_arrow_record_batch(self) -> pyarrow.RecordBatch: ...
def to_arrow_record_batch(self) -> pa.RecordBatch: ...
@staticmethod
def empty(schema: PySchema | None = None) -> PyTable: ...

@@ -1476,7 +1474,7 @@ class PyMicroPartition:
@staticmethod
def from_tables(tables: list[PyTable]) -> PyMicroPartition: ...
@staticmethod
def from_arrow_record_batches(record_batches: list[pyarrow.RecordBatch], schema: PySchema) -> PyMicroPartition: ...
def from_arrow_record_batches(record_batches: list[pa.RecordBatch], schema: PySchema) -> PyMicroPartition: ...
@staticmethod
def concat(tables: list[PyMicroPartition]) -> PyMicroPartition: ...
def slice(self, start: int, end: int) -> PyMicroPartition: ...
4 changes: 3 additions & 1 deletion daft/dataframe/preview.py
@@ -1,8 +1,10 @@
from __future__ import annotations

from dataclasses import dataclass
from typing import TYPE_CHECKING

from daft.table import MicroPartition
if TYPE_CHECKING:
from daft.table import MicroPartition


@dataclass(frozen=True)
53 changes: 34 additions & 19 deletions daft/datatype.py
@@ -1,14 +1,14 @@
from __future__ import annotations

import builtins
from typing import TYPE_CHECKING

import pyarrow as pa

from daft.context import get_context
from daft.daft import ImageMode, PyDataType, PyTimeUnit
from daft.dependencies import pa

if TYPE_CHECKING:
import builtins

import numpy as np


@@ -501,25 +501,40 @@ def __hash__(self) -> int:
return self._dtype.__hash__()


class DaftExtension(pa.ExtensionType):
def __init__(self, dtype, metadata=b""):
# attributes need to be set first before calling
# super init (as that calls serialize)
self._metadata = metadata
super().__init__(dtype, "daft.super_extension")
_EXT_TYPE_REGISTERED = False
_STATIC_DAFT_EXTENSION = None

def __reduce__(self):
return type(self).__arrow_ext_deserialize__, (self.storage_type, self.__arrow_ext_serialize__())

def __arrow_ext_serialize__(self):
return self._metadata
def _ensure_registered_super_ext_type():
global _EXT_TYPE_REGISTERED
global _STATIC_DAFT_EXTENSION
if not _EXT_TYPE_REGISTERED:

@classmethod
def __arrow_ext_deserialize__(cls, storage_type, serialized):
return cls(storage_type, serialized)
class DaftExtension(pa.ExtensionType):
def __init__(self, dtype, metadata=b""):
# attributes need to be set first before calling
# super init (as that calls serialize)
self._metadata = metadata
super().__init__(dtype, "daft.super_extension")

def __reduce__(self):
return type(self).__arrow_ext_deserialize__, (self.storage_type, self.__arrow_ext_serialize__())

def __arrow_ext_serialize__(self):
return self._metadata

@classmethod
def __arrow_ext_deserialize__(cls, storage_type, serialized):
return cls(storage_type, serialized)

_STATIC_DAFT_EXTENSION = DaftExtension
pa.register_extension_type(DaftExtension(pa.null()))
import atexit

atexit.register(lambda: pa.unregister_extension_type("daft.super_extension"))
_EXT_TYPE_REGISTERED = True

pa.register_extension_type(DaftExtension(pa.null()))
import atexit

atexit.register(lambda: pa.unregister_extension_type("daft.super_extension"))
def get_super_ext_type():
_ensure_registered_super_ext_type()
return _STATIC_DAFT_EXTENSION
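
With this change, importing `daft.datatype` no longer pulls in pyarrow eagerly; both the pyarrow import and the `daft.super_extension` registration happen on the first call to `get_super_ext_type()`. A brief usage sketch:

```python
from daft.datatype import get_super_ext_type

# First call imports pyarrow (through the lazy proxy) and registers the
# "daft.super_extension" extension type; later calls return the cached class.
DaftExtension = get_super_ext_type()
```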
5 changes: 4 additions & 1 deletion daft/delta_lake/delta_lake_scan.py
@@ -2,7 +2,7 @@

import logging
import os
from collections.abc import Iterator
from typing import TYPE_CHECKING

from deltalake.table import DeltaTable

@@ -20,6 +20,9 @@
from daft.io.scan import PartitionField, ScanOperator
from daft.logical.schema import Schema

if TYPE_CHECKING:
from collections.abc import Iterator

logger = logging.getLogger(__name__)


32 changes: 32 additions & 0 deletions daft/dependencies.py
@@ -0,0 +1,32 @@
from typing import TYPE_CHECKING

from daft.lazy_import import LazyImport

if TYPE_CHECKING:
import xml.etree.ElementTree as ET

import fsspec
import numpy as np
import pandas as pd
import PIL.Image as pil_image
import pyarrow as pa
import pyarrow.csv as pacsv
import pyarrow.dataset as pads
import pyarrow.fs as pafs
import pyarrow.json as pajson
import pyarrow.parquet as pq
else:
ET = LazyImport("xml.etree.ElementTree")

fsspec = LazyImport("fsspec")
np = LazyImport("numpy")
pd = LazyImport("pandas")
pil_image = LazyImport("PIL.Image")
pa = LazyImport("pyarrow")
pacsv = LazyImport("pyarrow.csv")
pads = LazyImport("pyarrow.dataset")
pafs = LazyImport("pyarrow.fs")
pajson = LazyImport("pyarrow.json")
pq = LazyImport("pyarrow.parquet")

unity_catalog = LazyImport("daft.unity_catalog")
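
Downstream modules then import the proxies from `daft.dependencies` instead of the heavy modules themselves (as `daft/arrow_utils.py` does above with `from daft.dependencies import pa`). A minimal sketch of the calling pattern, assuming the proxies resolve the real module on first attribute access:

```python
from daft.dependencies import np, pa


def to_arrow(values: list[int]) -> "pa.Array":
    # numpy and pyarrow are imported here, on first attribute access,
    # not when daft (or daft.dependencies) is imported.
    return pa.array(np.asarray(values))
```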
11 changes: 7 additions & 4 deletions daft/execution/execution_step.py
@@ -1,15 +1,12 @@
from __future__ import annotations

import itertools
import pathlib
from dataclasses import dataclass, field
from typing import TYPE_CHECKING, Generic, Protocol

from daft.context import get_context
from daft.daft import FileFormat, IOConfig, JoinType, ResourceRequest, ScanTask
from daft.daft import ResourceRequest
from daft.expressions import Expression, ExpressionsProjection, col
from daft.logical.map_partition_ops import MapPartitionOp
from daft.logical.schema import Schema
from daft.runners.partitioning import (
Boundaries,
MaterializedResult,
@@ -20,9 +17,15 @@
from daft.table import MicroPartition, table_io

if TYPE_CHECKING:
import pathlib

from pyiceberg.schema import Schema as IcebergSchema
from pyiceberg.table import TableProperties as IcebergTableProperties

from daft.daft import FileFormat, IOConfig, JoinType, ScanTask
from daft.logical.map_partition_ops import MapPartitionOp
from daft.logical.schema import Schema


ID_GEN = itertools.count()

10 changes: 5 additions & 5 deletions daft/execution/native_executor.py
@@ -6,14 +6,14 @@
NativeExecutor as _NativeExecutor,
)
from daft.daft import PyDaftExecutionConfig
from daft.logical.builder import LogicalPlanBuilder
from daft.runners.partitioning import (
MaterializedResult,
PartitionT,
)
from daft.table import MicroPartition

if TYPE_CHECKING:
from daft.logical.builder import LogicalPlanBuilder
from daft.runners.partitioning import (
MaterializedResult,
PartitionT,
)
from daft.runners.pyrunner import PyMaterializedResult


9 changes: 6 additions & 3 deletions daft/execution/physical_plan.py
@@ -17,7 +17,6 @@
import itertools
import logging
import math
import pathlib
from collections import deque
from typing import (
TYPE_CHECKING,
@@ -30,7 +29,7 @@
)

from daft.context import get_context
from daft.daft import FileFormat, IOConfig, JoinType, ResourceRequest
from daft.daft import ResourceRequest
from daft.execution import execution_step
from daft.execution.execution_step import (
Instruction,
@@ -41,7 +40,6 @@
SingleOutputPartitionTask,
)
from daft.expressions import ExpressionsProjection
from daft.logical.schema import Schema
from daft.runners.partitioning import (
MaterializedResult,
PartitionT,
@@ -53,9 +51,14 @@
T = TypeVar("T")

if TYPE_CHECKING:
import pathlib

from pyiceberg.schema import Schema as IcebergSchema
from pyiceberg.table import TableProperties as IcebergTableProperties

from daft.daft import FileFormat, IOConfig, JoinType
from daft.logical.schema import Schema


# A PhysicalPlan that is still being built - may yield both PartitionTaskBuilders and PartitionTasks.
InProgressPhysicalPlan = Iterator[Union[None, PartitionTask[PartitionT], PartitionTaskBuilder[PartitionT]]]
3 changes: 2 additions & 1 deletion daft/execution/rust_physical_plan_shim.py
@@ -17,12 +17,13 @@
from daft.logical.map_partition_ops import MapPartitionOp
from daft.logical.schema import Schema
from daft.runners.partitioning import PartitionT
from daft.table import MicroPartition

if TYPE_CHECKING:
from pyiceberg.schema import Schema as IcebergSchema
from pyiceberg.table import TableProperties as IcebergTableProperties

from daft.table import MicroPartition


def scan_with_tasks(
scan_tasks: list[ScanTask],