Showing 19 changed files with 563 additions and 13 deletions.
@@ -0,0 +1,60 @@

# isort: dont-add-import: from __future__ import annotations

from daft import context
from daft.api_annotations import PublicAPI
from daft.daft import (
    NativeStorageConfig,
    PythonStorageConfig,
    ScanOperatorHandle,
    StorageConfig,
)
from daft.dataframe import DataFrame
from daft.logical.builder import LogicalPlanBuilder
from daft.sql.sql_scan import SQLScanOperator


def native_downloader_available(url: str) -> bool:
    # TODO: We should be able to support native downloads via ConnectorX for compatible databases
    return False


@PublicAPI
def read_sql(
    sql: str,
    url: str,
    use_native_downloader: bool = False,
    # schema_hints: Optional[Dict[str, DataType]] = None,
) -> DataFrame:
    """Creates a DataFrame from the results of a SQL query.

    Example:
        >>> df = daft.read_sql("SELECT * FROM my_table", "sqlite:///example.db")

    Args:
        sql (str): SQL query to execute.
        url (str): URL of the database to read from, e.g. a SQLAlchemy-style connection string.
        use_native_downloader (bool): Whether to use the native downloader. Currently has no effect,
            since native downloads are not yet supported for databases and Python storage is always
            used. Defaults to False.

    Returns:
        DataFrame: DataFrame containing the results of the query.
    """
    io_config = context.get_context().daft_planning_config.default_io_config

    # Disable multithreaded IO on the Ray runner to reduce the number of connections
    # and the amount of thread contention.
    multithreaded_io = not context.get_context().is_ray_runner

    if use_native_downloader and native_downloader_available(url):
        storage_config = StorageConfig.native(NativeStorageConfig(multithreaded_io, io_config))
    else:
        storage_config = StorageConfig.python(PythonStorageConfig(io_config))

    sql_operator = SQLScanOperator(sql, url, storage_config=storage_config)
    handle = ScanOperatorHandle.from_python_scan_operator(sql_operator)
    builder = LogicalPlanBuilder.from_tabular_scan_with_scan_operator(scan_operator=handle)
    return DataFrame(builder)
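
For reference, a minimal usage sketch of the new API, assuming this commit is installed, read_sql is exported as daft.read_sql, and a local SQLite file example.db with a table my_table exists (the file and table names are hypothetical):

import daft

# Hypothetical SQLAlchemy-style connection string pointing at a local SQLite file.
df = daft.read_sql("SELECT * FROM my_table", "sqlite:///example.db")
df.show()

Note that the second argument is a database URL, not a connection factory: SQLScanOperator passes it directly to SQLAlchemy's create_engine.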
Empty file.
@@ -0,0 +1,135 @@

from __future__ import annotations

import logging
import math
from collections.abc import Iterator

import pyarrow as pa
from sqlalchemy import create_engine, text

from daft.daft import (
    DatabaseSourceConfig,
    FileFormatConfig,
    Pushdowns,
    ScanTask,
    StorageConfig,
)
from daft.io.scan import PartitionField, ScanOperator
from daft.logical.schema import Schema

logger = logging.getLogger(__name__)


class SQLScanOperator(ScanOperator):
    MIN_ROWS_PER_SCAN_TASK = 50  # Would be better to have a memory limit instead of a row limit

    def __init__(
        self,
        sql: str,
        url: str,
        storage_config: StorageConfig,
    ) -> None:
        super().__init__()
        self.sql = sql
        self.url = url
        self.storage_config = storage_config
        self._limit_supported = self._check_limit_supported()
        self._schema = self._get_schema()

    def _check_limit_supported(self) -> bool:
        # Probe the database with a LIMIT/OFFSET query; if it fails, fall back to
        # reading the whole query in a single scan task.
        try:
            with create_engine(self.url).connect() as connection:
                connection.execute(text(f"SELECT * FROM ({self.sql}) AS subquery LIMIT 1 OFFSET 0"))
            return True
        except Exception:
            return False

    def _get_schema(self) -> Schema:
        with create_engine(self.url).connect() as connection:
            sql = f"SELECT * FROM ({self.sql}) AS subquery"
            if self._limit_supported:
                sql += " LIMIT 1 OFFSET 0"

            result = connection.execute(text(sql))

            # Fetch the cursor from the result proxy to access column descriptions
            cursor = result.cursor
            rows = cursor.fetchall()
            columns = [column_description[0] for column_description in cursor.description]

            # Infer the schema by round-tripping the sampled rows through a PyArrow table.
            pydict = {column: [row[i] for row in rows] for i, column in enumerate(columns)}
            pa_table = pa.Table.from_pydict(pydict)
            return Schema.from_pyarrow_schema(pa_table.schema)

    def _get_num_rows(self) -> int:
        with create_engine(self.url).connect() as connection:
            result = connection.execute(text(f"SELECT COUNT(*) FROM ({self.sql}) AS subquery"))
            cursor = result.cursor
            return cursor.fetchone()[0]

    def schema(self) -> Schema:
        return self._schema

    def display_name(self) -> str:
        return f"SQLScanOperator({self.sql})"

    def partitioning_keys(self) -> list[PartitionField]:
        return []

    def multiline_display(self) -> list[str]:
        return [
            self.display_name(),
            f"Schema = {self._schema}",
        ]

    def to_scan_tasks(self, pushdowns: Pushdowns) -> Iterator[ScanTask]:
        if not self._limit_supported:
            # LIMIT/OFFSET pagination is unavailable, so emit a single scan task for the whole query.
            file_format_config = FileFormatConfig.from_database_config(DatabaseSourceConfig(self.sql))
            return iter(
                [
                    ScanTask.sql_scan_task(
                        url=self.url,
                        file_format=file_format_config,
                        schema=self._schema._schema,
                        storage_config=self.storage_config,
                        pushdowns=pushdowns,
                    )
                ]
            )

        total_rows = self._get_num_rows()
        num_scan_tasks = math.ceil(total_rows / self.MIN_ROWS_PER_SCAN_TASK)
        # Round up so the tasks cover every row; flooring here would silently drop
        # up to (num_scan_tasks - 1) trailing rows.
        num_rows_per_scan_task = math.ceil(total_rows / num_scan_tasks)

        scan_tasks = []
        offset = 0
        for _ in range(num_scan_tasks):
            limit = min(num_rows_per_scan_task, total_rows - offset)
            file_format_config = FileFormatConfig.from_database_config(
                DatabaseSourceConfig(self.sql, limit=limit, offset=offset)
            )
            scan_tasks.append(
                ScanTask.sql_scan_task(
                    url=self.url,
                    file_format=file_format_config,
                    schema=self._schema._schema,
                    storage_config=self.storage_config,
                    pushdowns=pushdowns,
                )
            )
            offset += limit

        return iter(scan_tasks)

    def can_absorb_filter(self) -> bool:
        return False

    def can_absorb_limit(self) -> bool:
        return False

    def can_absorb_select(self) -> bool:
        return True
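
A standalone sketch of the chunking arithmetic in to_scan_tasks, showing that consecutive LIMIT/OFFSET windows cover every row exactly once (plan_chunks and the example row count are illustrative, not part of this commit):

import math

MIN_ROWS_PER_SCAN_TASK = 50

def plan_chunks(total_rows: int) -> list[tuple[int, int]]:
    # Mirror of the partitioning logic: split total_rows into roughly equal
    # (offset, limit) windows of at most ceil(total_rows / num_scan_tasks) rows.
    num_scan_tasks = math.ceil(total_rows / MIN_ROWS_PER_SCAN_TASK)
    num_rows_per_scan_task = math.ceil(total_rows / num_scan_tasks)
    chunks, offset = [], 0
    for _ in range(num_scan_tasks):
        limit = min(num_rows_per_scan_task, total_rows - offset)
        chunks.append((offset, limit))
        offset += limit
    return chunks

print(plan_chunks(101))  # [(0, 34), (34, 34), (68, 33)]
assert sum(limit for _, limit in plan_chunks(101)) == 101  # every row covered once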