Commit

init
colin-ho committed Feb 22, 2024
1 parent b702e4d commit ace5b93
Showing 19 changed files with 563 additions and 13 deletions.
11 changes: 9 additions & 2 deletions daft/__init__.py
@@ -67,7 +67,14 @@ def get_build_type() -> str:
from daft.dataframe import DataFrame
from daft.datatype import DataType, TimeUnit
from daft.expressions import Expression, col, lit
from daft.io import from_glob_path, read_csv, read_iceberg, read_json, read_parquet
from daft.io import (
    from_glob_path,
    read_csv,
    read_iceberg,
    read_json,
    read_parquet,
    read_sql,
)
from daft.series import Series
from daft.udf import udf
from daft.viz import register_viz_hook
@@ -84,7 +91,7 @@ def get_build_type() -> str:
    "read_json",
    "read_parquet",
    "read_iceberg",
    "DataFrame",
    "read_sql",
    "DataFrame",
    "Expression",
    "col",
    "DataType",
31 changes: 30 additions & 1 deletion daft/daft.pyi
@@ -267,12 +267,23 @@ class JsonSourceConfig:
        chunk_size: int | None = None,
    ): ...

class DatabaseSourceConfig:
    """
    Configuration of a database data source.
    """

    sql: str
    limit: int | None
    offset: int | None

    def __init__(self, sql: str, limit: int | None = None, offset: int | None = None): ...

class FileFormatConfig:
    """
    Configuration for parsing a particular file format (Parquet, CSV, JSON).
    """

    config: ParquetSourceConfig | CsvSourceConfig | JsonSourceConfig
    config: ParquetSourceConfig | CsvSourceConfig | JsonSourceConfig | DatabaseSourceConfig

    @staticmethod
    def from_parquet_config(config: ParquetSourceConfig) -> FileFormatConfig:
@@ -292,6 +303,12 @@ class FileFormatConfig:
        Create a JSON file format config.
        """
        ...
    @staticmethod
    def from_database_config(config: DatabaseSourceConfig) -> FileFormatConfig:
        """
        Create a database file format config.
        """
        ...
    def file_format(self) -> FileFormat:
        """
        Get the file format for this config.
@@ -608,6 +625,18 @@ class ScanTask:
        Create a Catalog Scan Task
        """
        ...
    @staticmethod
    def sql_scan_task(
        url: str,
        file_format: FileFormatConfig,
        schema: PySchema,
        storage_config: StorageConfig,
        pushdowns: Pushdowns | None,
    ) -> ScanTask:
        """
        Create a SQL Scan Task
        """
        ...

class ScanOperatorHandle:
    """
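The new daft.pyi stubs compose in a simple chain: a DatabaseSourceConfig holds the SQL text plus optional limit/offset, FileFormatConfig.from_database_config wraps it, and ScanTask.sql_scan_task carries the wrapped config together with a schema and storage config. A minimal sketch of the first two steps, assuming a daft build that includes this commit (the query, limit, and offset are placeholders):

# Sketch only: placeholder query and pagination values; requires a daft build with this commit.
from daft.daft import DatabaseSourceConfig, FileFormatConfig

db_config = DatabaseSourceConfig("SELECT * FROM my_table", limit=50, offset=0)
file_format_config = FileFormatConfig.from_database_config(db_config)

# file_format() reports which FileFormat variant the config maps to.
print(file_format_config.file_format())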
1 change: 1 addition & 0 deletions daft/iceberg/iceberg_scan.py
@@ -169,6 +169,7 @@ def to_scan_tasks(self, pushdowns: Pushdowns) -> Iterator[ScanTask]:
                continue
            rows_left -= record_count
            scan_tasks.append(st)

        return iter(scan_tasks)

    def can_absorb_filter(self) -> bool:
2 changes: 2 additions & 0 deletions daft/io/__init__.py
@@ -13,6 +13,7 @@
from daft.io._iceberg import read_iceberg
from daft.io._json import read_json
from daft.io._parquet import read_parquet
from daft.io._sql import read_sql
from daft.io.file_path import from_glob_path


@@ -36,6 +37,7 @@ def _set_linux_cert_paths():
    "from_glob_path",
    "read_parquet",
    "read_iceberg",
    "read_sql",
    "IOConfig",
    "S3Config",
    "AzureConfig",
60 changes: 60 additions & 0 deletions daft/io/_sql.py
@@ -0,0 +1,60 @@
# isort: dont-add-import: from __future__ import annotations


from daft import context
from daft.api_annotations import PublicAPI
from daft.daft import (
    NativeStorageConfig,
    PythonStorageConfig,
    ScanOperatorHandle,
    StorageConfig,
)
from daft.dataframe import DataFrame
from daft.logical.builder import LogicalPlanBuilder
from daft.sql.sql_scan import SQLScanOperator


def native_downloader_available(url: str) -> bool:
    # TODO: We should be able to support native downloads via ConnectorX for compatible databases
    return False


@PublicAPI
def read_sql(
    sql: str,
    url: str,
    use_native_downloader: bool = False,
    # schema_hints: Optional[Dict[str, DataType]] = None,
) -> DataFrame:
    """Creates a DataFrame from a SQL query.

    Example:
        >>> df = daft.read_sql("SELECT * FROM my_table", "sqlite:///example.db")

    Args:
        sql (str): SQL query to execute.
        url (str): URL of the database to connect to, in SQLAlchemy connection-string format
            (e.g. "sqlite:///example.db"), since connections are created with sqlalchemy.create_engine.
        use_native_downloader (bool): Whether to use a native downloader for compatible databases.
            Native downloads are not supported yet, so this flag currently has no effect and
            defaults to False.

    Returns:
        DataFrame: DataFrame containing the results of the query.
    """

    io_config = context.get_context().daft_planning_config.default_io_config

    multithreaded_io = not context.get_context().is_ray_runner

    if use_native_downloader and native_downloader_available(url):
        storage_config = StorageConfig.native(NativeStorageConfig(multithreaded_io, io_config))
    else:
        storage_config = StorageConfig.python(PythonStorageConfig(io_config))

    sql_operator = SQLScanOperator(sql, url, storage_config=storage_config)

    handle = ScanOperatorHandle.from_python_scan_operator(sql_operator)
    builder = LogicalPlanBuilder.from_tabular_scan_with_scan_operator(scan_operator=handle)
    return DataFrame(builder)
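End to end, the new reader can be exercised against a local SQLite database. A hedged usage sketch follows; the table name, file name, and rows are placeholders, and the SQLAlchemy-style connection URL is an assumption based on SQLScanOperator building its connections with sqlalchemy.create_engine(url).

# Usage sketch (placeholder table and data; assumes a daft build with this commit installed).
import sqlite3

import daft

# Seed a throwaway SQLite database.
conn = sqlite3.connect("example.db")
conn.execute("CREATE TABLE IF NOT EXISTS my_table (id INTEGER, name TEXT)")
conn.executemany("INSERT INTO my_table VALUES (?, ?)", [(1, "a"), (2, "b")])
conn.commit()
conn.close()

# read_sql takes the query and a SQLAlchemy-style connection URL.
df = daft.read_sql("SELECT * FROM my_table", "sqlite:///example.db")
df.show()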
13 changes: 13 additions & 0 deletions daft/runners/partitioning.py
@@ -72,6 +72,13 @@ class TableParseParquetOptions:
    coerce_int96_timestamp_unit: TimeUnit = TimeUnit.ns()


@dataclass(frozen=True)
class TableParseSQLOptions:
    """Options for parsing SQL tables

    Args:
        limit: Number of rows to read, or None to read all rows
        offset: Number of rows to skip before reading
    """

    limit: int | None = None
    offset: int | None = None


@dataclass(frozen=True)
class PartialPartitionMetadata:
    num_rows: None | int
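For completeness, the new options dataclass is constructed like any other frozen dataclass; a trivial sketch with placeholder values:

# Placeholder values; TableParseSQLOptions is the dataclass added above.
from daft.runners.partitioning import TableParseSQLOptions

opts = TableParseSQLOptions(limit=1000, offset=0)
print(opts)  # TableParseSQLOptions(limit=1000, offset=0)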
Empty file added daft/sql/__init__.py
135 changes: 135 additions & 0 deletions daft/sql/sql_scan.py
@@ -0,0 +1,135 @@
from __future__ import annotations

import logging
import math
from collections.abc import Iterator

import pyarrow as pa
from sqlalchemy import create_engine, text

from daft.daft import (
    DatabaseSourceConfig,
    FileFormatConfig,
    Pushdowns,
    ScanTask,
    StorageConfig,
)
from daft.io.scan import PartitionField, ScanOperator
from daft.logical.schema import Schema

logger = logging.getLogger(__name__)


class SQLScanOperator(ScanOperator):
    MIN_ROWS_PER_SCAN_TASK = 50  # Would be better to have a memory limit instead of a row limit

    def __init__(
        self,
        sql: str,
        url: str,
        storage_config: StorageConfig,
    ) -> None:
        super().__init__()
        self.sql = sql
        self.url = url
        self.storage_config = storage_config
        self._limit_supported = self._check_limit_supported()
        self._schema = self._get_schema()

    def _check_limit_supported(self) -> bool:
        try:
            with create_engine(self.url).connect() as connection:
                connection.execute(text(f"SELECT * FROM ({self.sql}) AS subquery LIMIT 1 OFFSET 0"))

            return True
        except Exception:
            return False

    def _get_schema(self) -> Schema:
        with create_engine(self.url).connect() as connection:
            sql = f"SELECT * FROM ({self.sql}) AS subquery"
            if self._limit_supported:
                sql += " LIMIT 1 OFFSET 0"

            result = connection.execute(text(sql))

            # Fetch the cursor from the result proxy to access column descriptions
            cursor = result.cursor

            rows = cursor.fetchall()
            columns = [column_description[0] for column_description in cursor.description]

            pydict = {column: [row[i] for row in rows] for i, column in enumerate(columns)}
            pa_table = pa.Table.from_pydict(pydict)

            return Schema.from_pyarrow_schema(pa_table.schema)

    def _get_num_rows(self) -> int:
        with create_engine(self.url).connect() as connection:
            result = connection.execute(text(f"SELECT COUNT(*) FROM ({self.sql}) AS subquery"))
            cursor = result.cursor
            return cursor.fetchone()[0]

    def schema(self) -> Schema:
        return self._schema

    def display_name(self) -> str:
        return f"SQLScanOperator({self.sql})"

    def partitioning_keys(self) -> list[PartitionField]:
        return []

    def multiline_display(self) -> list[str]:
        return [
            self.display_name(),
            f"Schema = {self._schema}",
        ]

    def to_scan_tasks(self, pushdowns: Pushdowns) -> Iterator[ScanTask]:
        if not self._limit_supported:
            # The database does not support LIMIT/OFFSET on the subquery, so fall back to a
            # single scan task that runs the full query.
            file_format_config = FileFormatConfig.from_database_config(DatabaseSourceConfig(self.sql))
            return iter(
                [
                    ScanTask.sql_scan_task(
                        url=self.url,
                        file_format=file_format_config,
                        schema=self._schema._schema,
                        storage_config=self.storage_config,
                        pushdowns=pushdowns,
                    )
                ]
            )

        total_rows = self._get_num_rows()
        num_scan_tasks = math.ceil(total_rows / self.MIN_ROWS_PER_SCAN_TASK)
        # Use ceiling division so the final scan task also picks up any remainder rows.
        num_rows_per_scan_task = math.ceil(total_rows / num_scan_tasks)

        # Emit one scan task per LIMIT/OFFSET page of the query.
        scan_tasks = []
        offset = 0
        for _ in range(num_scan_tasks):
            limit = min(num_rows_per_scan_task, total_rows - offset)
            file_format_config = FileFormatConfig.from_database_config(
                DatabaseSourceConfig(self.sql, limit=limit, offset=offset)
            )

            scan_tasks.append(
                ScanTask.sql_scan_task(
                    url=self.url,
                    file_format=file_format_config,
                    schema=self._schema._schema,
                    storage_config=self.storage_config,
                    pushdowns=pushdowns,
                )
            )
            offset += limit

        return iter(scan_tasks)

    def can_absorb_filter(self) -> bool:
        return False

    def can_absorb_limit(self) -> bool:
        return False

    def can_absorb_select(self) -> bool:
        return True
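The chunking in to_scan_tasks is plain LIMIT/OFFSET paging: count the rows, split them into pages of roughly MIN_ROWS_PER_SCAN_TASK rows, and emit one scan task per page. A standalone sketch of just that arithmetic, with a made-up row count standing in for the COUNT(*) result:

# Standalone sketch of the paging arithmetic; total_rows is a made-up example value.
import math

MIN_ROWS_PER_SCAN_TASK = 50
total_rows = 130

num_scan_tasks = math.ceil(total_rows / MIN_ROWS_PER_SCAN_TASK)  # 3 tasks
num_rows_per_scan_task = math.ceil(total_rows / num_scan_tasks)  # 44 rows per task

pages = []
offset = 0
for _ in range(num_scan_tasks):
    limit = min(num_rows_per_scan_task, total_rows - offset)
    pages.append((limit, offset))
    offset += limit

print(pages)  # [(44, 0), (44, 44), (42, 88)] -- every row is covered exactly once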
