From d69024cc9894bd0cb1272aec6ffa0d235c894a1f Mon Sep 17 00:00:00 2001 From: xcharleslin <4212216+xcharleslin@users.noreply.github.com> Date: Mon, 9 Oct 2023 16:50:40 -0700 Subject: [PATCH 01/14] [FEAT] Streaming CSV reads (#1479) We never had streaming CSV reads at the execution layer, so we would be reading a whole CSV file regardless of how many rows needed. Now we use pyarrow CSV stream reader to get record batches, and from those only get the rows we need. Co-authored-by: Xiayue Charles Lin --- daft/table/table.py | 6 ++++++ daft/table/table_io.py | 46 ++++++++++++++++++++++++++++++++++++------ 2 files changed, 46 insertions(+), 6 deletions(-) diff --git a/daft/table/table.py b/daft/table/table.py index 4ab8e490b7..fb7366bfb8 100644 --- a/daft/table/table.py +++ b/daft/table/table.py @@ -112,6 +112,12 @@ def from_arrow(arrow_table: pa.Table) -> Table: pyt = _PyTable.from_arrow_record_batches(arrow_table.to_batches(), schema._schema) return Table._from_pytable(pyt) + @staticmethod + def from_arrow_record_batches(rbs: list[pa.RecordBatch], arrow_schema: pa.Schema) -> Table: + schema = Schema._from_field_name_and_types([(f.name, DataType.from_arrow_type(f.type)) for f in arrow_schema]) + pyt = _PyTable.from_arrow_record_batches(rbs, schema._schema) + return Table._from_pytable(pyt) + @staticmethod def from_pandas(pd_df: pd.DataFrame) -> Table: if not _PANDAS_AVAILABLE: diff --git a/daft/table/table_io.py b/daft/table/table_io.py index 42d32ab477..4cb0276c34 100644 --- a/daft/table/table_io.py +++ b/daft/table/table_io.py @@ -175,6 +175,17 @@ def read_parquet( return _cast_table_to_schema(Table.from_arrow(table), read_options=read_options, schema=schema) +class PACSVStreamHelper: + def __init__(self, stream: pa.CSVStreamReader) -> None: + self.stream = stream + + def __next__(self) -> pa.RecordBatch: + return self.stream.read_next_batch() + + def __iter__(self) -> PACSVStreamHelper: + return self + + def read_csv( file: FileInput, schema: Schema, @@ -219,7 +230,7 @@ def read_csv( fs = None with _open_stream(file, fs) as f: - table = pacsv.read_csv( + pacsv_stream = pacsv.open_csv( f, parse_options=pacsv.ParseOptions( delimiter=csv_options.delimiter, @@ -238,11 +249,34 @@ def read_csv( ), ) - # TODO(jay): Can't limit number of rows with current PyArrow filesystem so we have to shave it off after the read - if read_options.num_rows is not None: - table = table[: read_options.num_rows] - - return _cast_table_to_schema(Table.from_arrow(table), read_options=read_options, schema=schema) + if read_options.num_rows is not None: + rows_left = read_options.num_rows + pa_batches = [] + pa_schema = None + for record_batch in PACSVStreamHelper(pacsv_stream): + if pa_schema is None: + pa_schema = record_batch.schema + if record_batch.num_rows > rows_left: + record_batch = record_batch.slice(0, rows_left) + pa_batches.append(record_batch) + rows_left -= record_batch.num_rows + + # Break needs to be here; always need to process at least one record batch + if rows_left <= 0: + break + + # If source schema isn't determined, then the file was truly empty; set an empty source schema + if pa_schema is None: + pa_schema = pa.schema([]) + + daft_table = Table.from_arrow_record_batches(pa_batches, pa_schema) + assert len(daft_table) <= read_options.num_rows + + else: + pa_table = pacsv_stream.read_all() + daft_table = Table.from_arrow(pa_table) + + return _cast_table_to_schema(daft_table, read_options=read_options, schema=schema) def write_csv( From f2c80110a438db67c8f295c68020d342d94a9e6b Mon Sep 17 00:00:00 
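As a standalone sketch of the streaming approach in the patch above — using only pyarrow, with a hypothetical helper name and a local path rather than Daft's actual `read_csv` API — the row-limited read looks roughly like this:

```python
import pyarrow as pa
import pyarrow.csv as pacsv


def read_csv_head(path: str, limit: int) -> pa.Table:
    """Read at most `limit` rows from a CSV without materializing the whole file."""
    stream = pacsv.open_csv(path)
    batches = []
    rows_left = limit
    while rows_left > 0:
        try:
            batch = stream.read_next_batch()
        except StopIteration:
            break
        # Trim the final batch so we never return more rows than requested.
        if batch.num_rows > rows_left:
            batch = batch.slice(0, rows_left)
        batches.append(batch)
        rows_left -= batch.num_rows
    # An empty batch list still yields a valid (empty) table because we pass the schema.
    return pa.Table.from_batches(batches, schema=stream.schema)
```

As in the patch, the final record batch is sliced rather than dropped, so the result never exceeds the requested row count while the rest of the file is never read.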
2001 From: Sammy Sidhu Date: Mon, 9 Oct 2023 17:33:11 -0700 Subject: [PATCH 02/14] [CHORE] Create SECURITY.md (#1481) --- SECURITY.md | 6 ++++++ 1 file changed, 6 insertions(+) create mode 100644 SECURITY.md diff --git a/SECURITY.md b/SECURITY.md new file mode 100644 index 0000000000..8f152e4373 --- /dev/null +++ b/SECURITY.md @@ -0,0 +1,6 @@ +# Security Policy + +## Reporting a Vulnerability + +Please do not make a Github issue when reporting security issues but email daft-security@eventualcomputing.com. +Thank you! From 439f2bd8239400bb4d7231c53e7d076be5428e98 Mon Sep 17 00:00:00 2001 From: Jay Chia <17691182+jaychia@users.noreply.github.com> Date: Mon, 9 Oct 2023 18:01:18 -0700 Subject: [PATCH 03/14] [PERF] Update number of cores on every iteration (#1480) Updates the number of cores available before/after every batch dispatch This should allow us to take advantage of autoscaling of the Ray cluster better as we will schedule larger batches of tasks + more total inflight tasks as the cluster autoscales. --------- Co-authored-by: Jay Chia --- daft/runners/ray_runner.py | 35 ++++++++++++++++++++++++++++------- 1 file changed, 28 insertions(+), 7 deletions(-) diff --git a/daft/runners/ray_runner.py b/daft/runners/ray_runner.py index a7dd27f355..a0aa7dc31e 100644 --- a/daft/runners/ray_runner.py +++ b/daft/runners/ray_runner.py @@ -1,12 +1,13 @@ from __future__ import annotations import threading +import time import uuid from collections import defaultdict from dataclasses import dataclass from datetime import datetime from queue import Queue -from typing import TYPE_CHECKING, Any, Iterable, Iterator +from typing import TYPE_CHECKING, Any, Generator, Iterable, Iterator import pyarrow as pa from loguru import logger @@ -368,6 +369,28 @@ def get_meta(partition: Table) -> PartitionMetadata: return PartitionMetadata.from_table(partition) +def _ray_num_cpus_provider(ttl_seconds: int = 1) -> Generator[int, None, None]: + """Helper that gets the number of CPUs from Ray + + Used as a generator as it provides a guard against calling ray.cluster_resources() + more than once per `ttl_seconds`. + + Example: + >>> p = _ray_num_cpus_provider() + >>> next(p) + """ + last_checked_time = time.time() + last_num_cpus_queried = int(ray.cluster_resources()["CPU"]) + while True: + currtime = time.time() + if currtime - last_checked_time < ttl_seconds: + yield last_num_cpus_queried + else: + last_checked_time = currtime + last_num_cpus_queried = int(ray.cluster_resources()["CPU"]) + yield last_num_cpus_queried + + class Scheduler: def __init__(self, max_task_backlog: int | None) -> None: """ @@ -434,15 +457,11 @@ def _run_plan( # Get executable tasks from plan scheduler. tasks = plan_scheduler.to_partition_tasks(psets, is_ray_runner=True) - # Note: For autoscaling clusters, we will probably want to query cores dynamically. - # Keep in mind this call takes about 0.3ms. - cores = int(ray.cluster_resources()["CPU"]) - self.reserved_cores - - max_inflight_tasks = cores + self.max_task_backlog - inflight_tasks: dict[str, PartitionTask[ray.ObjectRef]] = dict() inflight_ref_to_task: dict[ray.ObjectRef, str] = dict() + num_cpus_provider = _ray_num_cpus_provider() + start = datetime.now() profile_filename = ( f"profile_RayRunner.run()_" @@ -456,6 +475,8 @@ def _run_plan( while True: # Loop: Dispatch (get tasks -> batch dispatch). 
tasks_to_dispatch: list[PartitionTask] = [] + cores: int = next(num_cpus_provider) - self.reserved_cores + max_inflight_tasks = cores + self.max_task_backlog dispatches_allowed = max_inflight_tasks - len(inflight_tasks) dispatches_allowed = min(cores, dispatches_allowed) From 65749925ac97c0506096c65e85f13b51a6373a97 Mon Sep 17 00:00:00 2001 From: Jay Chia <17691182+jaychia@users.noreply.github.com> Date: Tue, 10 Oct 2023 15:32:06 -0700 Subject: [PATCH 04/14] [CHORE] Add tests and fixes for Azure globbing (#1482) Closes: #1468 Co-authored-by: Jay Chia --- tests/integration/io/test_list_files_azure.py | 71 +++++++++++++++++++ 1 file changed, 71 insertions(+) create mode 100644 tests/integration/io/test_list_files_azure.py diff --git a/tests/integration/io/test_list_files_azure.py b/tests/integration/io/test_list_files_azure.py new file mode 100644 index 0000000000..b14a198e56 --- /dev/null +++ b/tests/integration/io/test_list_files_azure.py @@ -0,0 +1,71 @@ +from __future__ import annotations + +import adlfs +import pytest + +from daft.daft import AzureConfig, IOConfig, io_glob + +STORAGE_ACCOUNT = "dafttestdata" +CONTAINER = "public-anonymous" +DEFAULT_AZURE_CONFIG = AzureConfig(storage_account=STORAGE_ACCOUNT, anonymous=True) + + +def adlfs_recursive_list(fs, path) -> list: + all_results = [] + curr_level_result = fs.ls(path.replace("az://", ""), detail=True) + for item in curr_level_result: + if item["type"] == "directory": + new_path = f'az://{item["name"]}' + all_results.extend(adlfs_recursive_list(fs, new_path)) + item["name"] += "/" + all_results.append(item) + else: + all_results.append(item) + return all_results + + +def compare_az_result(daft_ls_result: list, fsspec_result: list): + daft_files = [(f["path"], f["type"].lower()) for f in daft_ls_result] + azfs_files = [(f"az://{f['name']}", f["type"]) for f in fsspec_result] + + # Remove all directories: our glob utilities don't return dirs + azfs_files = [(path, type_) for path, type_ in azfs_files if type_ == "file"] + + assert len(daft_files) == len(azfs_files) + assert sorted(daft_files) == sorted(azfs_files) + + +@pytest.mark.integration() +@pytest.mark.parametrize( + "path", + [ + f"az://{CONTAINER}", + f"az://{CONTAINER}/", + f"az://{CONTAINER}/test_ls/", + f"az://{CONTAINER}/test_ls//", + ], +) +@pytest.mark.parametrize("recursive", [False, True]) +@pytest.mark.parametrize("fanout_limit", [None, 1]) +def test_az_flat_directory_listing(path, recursive, fanout_limit): + fs = adlfs.AzureBlobFileSystem(account_name=STORAGE_ACCOUNT) + glob_path = path.rstrip("/") + "/**/*.{txt,parquet}" if recursive else path + daft_ls_result = io_glob(glob_path, io_config=IOConfig(azure=DEFAULT_AZURE_CONFIG), fanout_limit=fanout_limit) + fsspec_result = adlfs_recursive_list(fs, path) if recursive else fs.ls(path.replace("az://", ""), detail=True) + compare_az_result(daft_ls_result, fsspec_result) + + +@pytest.mark.integration() +def test_az_single_file_listing(): + path = f"az://{CONTAINER}/mvp.parquet" + fs = adlfs.AzureBlobFileSystem(account_name=STORAGE_ACCOUNT) + daft_ls_result = io_glob(path, io_config=IOConfig(azure=DEFAULT_AZURE_CONFIG)) + fsspec_result = fs.ls(path.replace("az://", ""), detail=True) + compare_az_result(daft_ls_result, fsspec_result) + + +@pytest.mark.integration() +def test_az_notfound(): + path = f"az://{CONTAINER}/test_" + with pytest.raises(FileNotFoundError, match=path): + io_glob(path, io_config=IOConfig(azure=DEFAULT_AZURE_CONFIG)) From a24e918c6f40618ae947274291d4cfc18174cb90 Mon Sep 17 00:00:00 2001 From: Jay Chia 
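The `_ray_num_cpus_provider` generator in the Ray-runner patch above is a small TTL-cache pattern; here is a generic sketch of the same idea (all names below are illustrative, not Daft internals, and the commented dispatch loop assumes a Ray cluster is available):

```python
import time
from typing import Callable, Generator, TypeVar

T = TypeVar("T")


def ttl_cached(fn: Callable[[], T], ttl_seconds: float = 1.0) -> Generator[T, None, None]:
    """Yield fn()'s latest result, re-evaluating it at most once per `ttl_seconds`."""
    last_checked = time.time()
    last_value = fn()
    while True:
        now = time.time()
        if now - last_checked >= ttl_seconds:
            last_checked = now
            last_value = fn()
        yield last_value


# Hypothetical dispatch loop: re-derive the task budget from the current
# cluster size on every iteration instead of sampling it once up front.
# cores_provider = ttl_cached(lambda: int(ray.cluster_resources()["CPU"]))
# while work_remaining:
#     cores = next(cores_provider) - reserved_cores
#     max_inflight_tasks = cores + max_task_backlog
#     ...
```

The TTL guard keeps the per-iteration cost negligible while still letting the scheduler notice when the autoscaler adds or removes nodes.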
<17691182+jaychia@users.noreply.github.com> Date: Wed, 11 Oct 2023 11:26:06 -0700 Subject: [PATCH 05/14] [PERF] Pass-through multithreaded_io flag in read_parquet (#1484) Passes the `multithreaded_io=False` flag through when running on the Ray Runner for read_parquet --------- Co-authored-by: Jay Chia --- daft/daft.pyi | 5 +++++ daft/execution/execution_step.py | 1 + daft/io/_parquet.py | 17 ++++++++++++----- daft/table/table_io.py | 2 ++ src/common/io-config/src/python.rs | 2 +- src/daft-io/src/lib.rs | 20 ++++++++++++++++++-- src/daft-io/src/s3_like.rs | 7 +++++-- src/daft-plan/src/source_info/file_format.rs | 13 ++++++++++--- 8 files changed, 54 insertions(+), 13 deletions(-) diff --git a/daft/daft.pyi b/daft/daft.pyi index d8c843080b..786d00e374 100644 --- a/daft/daft.pyi +++ b/daft/daft.pyi @@ -178,6 +178,11 @@ class ParquetSourceConfig: Configuration of a Parquet data source. """ + # Whether or not to use a multithreaded tokio runtime for processing I/O + multithreaded_io: bool + + def __init__(self, multithreaded_io: bool): ... + class CsvSourceConfig: """ Configuration of a CSV data source. diff --git a/daft/execution/execution_step.py b/daft/execution/execution_step.py index 896c7f3e8b..37052b39c6 100644 --- a/daft/execution/execution_step.py +++ b/daft/execution/execution_step.py @@ -401,6 +401,7 @@ def _handle_tabular_files_scan( schema=self.schema, storage_config=self.storage_config, read_options=read_options, + multithreaded_io=format_config.multithreaded_io, ) for fp in filepaths ] diff --git a/daft/io/_parquet.py b/daft/io/_parquet.py index 1e6558dd9b..0228cc1d0a 100644 --- a/daft/io/_parquet.py +++ b/daft/io/_parquet.py @@ -1,12 +1,14 @@ # isort: dont-add-import: from __future__ import annotations -from typing import TYPE_CHECKING, Dict, List, Optional, Union +from typing import Dict, List, Optional, Union import fsspec +from daft import context from daft.api_annotations import PublicAPI from daft.daft import ( FileFormatConfig, + IOConfig, NativeStorageConfig, ParquetSourceConfig, PythonStorageConfig, @@ -16,9 +18,6 @@ from daft.datatype import DataType from daft.io.common import _get_tabular_files_scan -if TYPE_CHECKING: - from daft.io import IOConfig - @PublicAPI def read_parquet( @@ -53,7 +52,15 @@ def read_parquet( if isinstance(path, list) and len(path) == 0: raise ValueError(f"Cannot read DataFrame from from empty list of Parquet filepaths") - file_format_config = FileFormatConfig.from_parquet_config(ParquetSourceConfig()) + # If running on Ray, we want to limit the amount of concurrency and requests being made. 
+ # This is because each Ray worker process receives its own pool of thread workers and connections + multithreaded_io = not context.get_context().is_ray_runner + + file_format_config = FileFormatConfig.from_parquet_config( + ParquetSourceConfig( + multithreaded_io=multithreaded_io, + ) + ) if use_native_downloader: storage_config = StorageConfig.native(NativeStorageConfig(io_config)) else: diff --git a/daft/table/table_io.py b/daft/table/table_io.py index 4cb0276c34..b12ffb9feb 100644 --- a/daft/table/table_io.py +++ b/daft/table/table_io.py @@ -106,6 +106,7 @@ def read_parquet( storage_config: StorageConfig | None = None, read_options: TableReadOptions = TableReadOptions(), parquet_options: TableParseParquetOptions = TableParseParquetOptions(), + multithreaded_io: bool | None = None, ) -> Table: """Reads a Table from a Parquet file @@ -130,6 +131,7 @@ def read_parquet( num_rows=read_options.num_rows, io_config=config.io_config, coerce_int96_timestamp_unit=parquet_options.coerce_int96_timestamp_unit, + multithreaded_io=multithreaded_io, ) return _cast_table_to_schema(tbl, read_options=read_options, schema=schema) diff --git a/src/common/io-config/src/python.rs b/src/common/io-config/src/python.rs index 943693e410..91aed62b28 100644 --- a/src/common/io-config/src/python.rs +++ b/src/common/io-config/src/python.rs @@ -211,7 +211,7 @@ impl S3Config { Ok(self.config.access_key.clone()) } - /// AWS max connections + /// AWS max connections per IO thread #[getter] pub fn max_connections(&self) -> PyResult { Ok(self.config.max_connections) diff --git a/src/daft-io/src/lib.rs b/src/daft-io/src/lib.rs index 417313f047..3f14834fb3 100644 --- a/src/daft-io/src/lib.rs +++ b/src/daft-io/src/lib.rs @@ -17,6 +17,7 @@ pub use common_io_config::{AzureConfig, IOConfig, S3Config}; pub use object_io::GetResult; #[cfg(feature = "python")] pub use python::register_modules; +use tokio::runtime::RuntimeFlavor; use std::{borrow::Cow, collections::HashMap, hash::Hash, ops::Range, sync::Arc}; @@ -261,16 +262,17 @@ pub fn parse_url(input: &str) -> Result<(SourceType, Cow<'_, str>)> { type CacheKey = (bool, Arc); lazy_static! 
{ static ref NUM_CPUS: usize = std::thread::available_parallelism().unwrap().get(); + static ref THREADED_RUNTIME_NUM_WORKER_THREADS: usize = 8.min(*NUM_CPUS); static ref THREADED_RUNTIME: tokio::sync::RwLock<(Arc, usize)> = tokio::sync::RwLock::new(( Arc::new( tokio::runtime::Builder::new_multi_thread() - .worker_threads(8.min(*NUM_CPUS)) + .worker_threads(*THREADED_RUNTIME_NUM_WORKER_THREADS) .enable_all() .build() .unwrap() ), - 8.min(*NUM_CPUS) + *THREADED_RUNTIME_NUM_WORKER_THREADS, )); static ref CLIENT_CACHE: tokio::sync::RwLock>> = tokio::sync::RwLock::new(HashMap::new()); @@ -332,6 +334,20 @@ pub fn set_io_pool_num_threads(num_threads: usize) -> bool { true } +pub async fn get_io_pool_num_threads() -> Option { + match tokio::runtime::Handle::try_current() { + Ok(handle) => { + match handle.runtime_flavor() { + RuntimeFlavor::CurrentThread => Some(1), + RuntimeFlavor::MultiThread => Some(THREADED_RUNTIME.read().await.1), + // RuntimeFlavor is #non_exhaustive, so we default to 1 here to be conservative + _ => Some(1), + } + } + Err(_) => None, + } +} + pub fn _url_download( array: &Utf8Array, max_connections: usize, diff --git a/src/daft-io/src/s3_like.rs b/src/daft-io/src/s3_like.rs index 3a549d2afb..46b9831712 100644 --- a/src/daft-io/src/s3_like.rs +++ b/src/daft-io/src/s3_like.rs @@ -9,7 +9,7 @@ use s3::operation::list_objects_v2::ListObjectsV2Error; use tokio::sync::{OwnedSemaphorePermit, SemaphorePermit}; use crate::object_io::{FileMetadata, FileType, LSResult}; -use crate::{InvalidArgumentSnafu, SourceType}; +use crate::{get_io_pool_num_threads, InvalidArgumentSnafu, SourceType}; use aws_config::SdkConfig; use aws_credential_types::cache::ProvideCachedCredentials; use aws_credential_types::provider::error::CredentialsError; @@ -311,7 +311,10 @@ async fn build_client(config: &S3Config) -> super::Result { Ok(S3LikeSource { region_to_client_map: tokio::sync::RwLock::new(client_map), connection_pool_sema: Arc::new(tokio::sync::Semaphore::new( - config.max_connections as usize, + (config.max_connections as usize) + * get_io_pool_num_threads() + .await + .expect("Should be running in tokio pool"), )), s3_config: config.clone(), default_region, diff --git a/src/daft-plan/src/source_info/file_format.rs b/src/daft-plan/src/source_info/file_format.rs index fdcb483e61..1e4988f4d5 100644 --- a/src/daft-plan/src/source_info/file_format.rs +++ b/src/daft-plan/src/source_info/file_format.rs @@ -52,15 +52,22 @@ impl FileFormatConfig { /// Configuration for a Parquet data source. #[derive(Clone, Debug, PartialEq, Eq, Serialize, Deserialize, Hash)] #[cfg_attr(feature = "python", pyclass(module = "daft.daft"))] -pub struct ParquetSourceConfig; +pub struct ParquetSourceConfig { + multithreaded_io: bool, +} #[cfg(feature = "python")] #[pymethods] impl ParquetSourceConfig { /// Create a config for a Parquet data source. 
#[new] - fn new() -> Self { - Self {} + fn new(multithreaded_io: bool) -> Self { + Self { multithreaded_io } + } + + #[getter] + fn multithreaded_io(&self) -> PyResult { + Ok(self.multithreaded_io) } } From 8ad7fda98e1c182b200f7184531883e23c213359 Mon Sep 17 00:00:00 2001 From: Jay Chia <17691182+jaychia@users.noreply.github.com> Date: Wed, 11 Oct 2023 15:28:06 -0700 Subject: [PATCH 06/14] [PERF] Update default max_connections 64->8 because it is now per-io-thread (#1485) Updates default max_connections value from 64 to 8 Also renames `max_connections` in internal APIs to `max_connections_per_io_thread` to be more explicit, but keeps naming for external-facing APIs for backwards compatibility Note that the total number of connections being spawned for PyRunner is: `8.min(Num CPUs) * max_connections`, and theses are shared throughout the multithreaded backend The total number of connections being spawned for RayRunner after #1484 is: `num_ray_workers * 1 (sine we run single-threaded) * max_connections` --------- Co-authored-by: Jay Chia --- daft/io/_parquet.py | 6 +++++- src/common/io-config/src/python.rs | 5 +++-- src/common/io-config/src/s3.rs | 6 +++--- src/daft-io/src/s3_like.rs | 2 +- 4 files changed, 12 insertions(+), 7 deletions(-) diff --git a/daft/io/_parquet.py b/daft/io/_parquet.py index 0228cc1d0a..8a824867d0 100644 --- a/daft/io/_parquet.py +++ b/daft/io/_parquet.py @@ -26,6 +26,7 @@ def read_parquet( fs: Optional[fsspec.AbstractFileSystem] = None, io_config: Optional["IOConfig"] = None, use_native_downloader: bool = False, + _multithreaded_io: Optional[bool] = None, ) -> DataFrame: """Creates a DataFrame from Parquet file(s) @@ -44,6 +45,9 @@ def read_parquet( io_config (IOConfig): Config to be used with the native downloader use_native_downloader: Whether to use the native downloader instead of PyArrow for reading Parquet. This is currently experimental. + _multithreaded_io: Whether to use multithreading for IO threads. Setting this to False can be helpful in reducing + the amount of system resources (number of connections and thread contention) when running in the Ray runner. + Defaults to None, which will let Daft decide based on the runner it is currently using. returns: DataFrame: parsed DataFrame @@ -54,7 +58,7 @@ def read_parquet( # If running on Ray, we want to limit the amount of concurrency and requests being made. 
# This is because each Ray worker process receives its own pool of thread workers and connections - multithreaded_io = not context.get_context().is_ray_runner + multithreaded_io = not context.get_context().is_ray_runner if _multithreaded_io is None else _multithreaded_io file_format_config = FileFormatConfig.from_parquet_config( ParquetSourceConfig( diff --git a/src/common/io-config/src/python.rs b/src/common/io-config/src/python.rs index 91aed62b28..ccea7f34b1 100644 --- a/src/common/io-config/src/python.rs +++ b/src/common/io-config/src/python.rs @@ -163,7 +163,8 @@ impl S3Config { key_id: key_id.or(def.key_id), session_token: session_token.or(def.session_token), access_key: access_key.or(def.access_key), - max_connections: max_connections.unwrap_or(def.max_connections), + max_connections_per_io_thread: max_connections + .unwrap_or(def.max_connections_per_io_thread), retry_initial_backoff_ms: retry_initial_backoff_ms .unwrap_or(def.retry_initial_backoff_ms), connect_timeout_ms: connect_timeout_ms.unwrap_or(def.connect_timeout_ms), @@ -214,7 +215,7 @@ impl S3Config { /// AWS max connections per IO thread #[getter] pub fn max_connections(&self) -> PyResult { - Ok(self.config.max_connections) + Ok(self.config.max_connections_per_io_thread) } /// AWS Retry Initial Backoff Time in Milliseconds diff --git a/src/common/io-config/src/s3.rs b/src/common/io-config/src/s3.rs index f7ab25095c..cd35c3f49c 100644 --- a/src/common/io-config/src/s3.rs +++ b/src/common/io-config/src/s3.rs @@ -11,7 +11,7 @@ pub struct S3Config { pub key_id: Option, pub session_token: Option, pub access_key: Option, - pub max_connections: u32, + pub max_connections_per_io_thread: u32, pub retry_initial_backoff_ms: u64, pub connect_timeout_ms: u64, pub read_timeout_ms: u64, @@ -30,7 +30,7 @@ impl Default for S3Config { key_id: None, session_token: None, access_key: None, - max_connections: 64, + max_connections_per_io_thread: 8, retry_initial_backoff_ms: 1000, connect_timeout_ms: 10_000, read_timeout_ms: 10_000, @@ -68,7 +68,7 @@ impl Display for S3Config { self.session_token, self.access_key, self.retry_initial_backoff_ms, - self.max_connections, + self.max_connections_per_io_thread, self.connect_timeout_ms, self.read_timeout_ms, self.num_tries, diff --git a/src/daft-io/src/s3_like.rs b/src/daft-io/src/s3_like.rs index 46b9831712..4638077eb4 100644 --- a/src/daft-io/src/s3_like.rs +++ b/src/daft-io/src/s3_like.rs @@ -311,7 +311,7 @@ async fn build_client(config: &S3Config) -> super::Result { Ok(S3LikeSource { region_to_client_map: tokio::sync::RwLock::new(client_map), connection_pool_sema: Arc::new(tokio::sync::Semaphore::new( - (config.max_connections as usize) + (config.max_connections_per_io_thread as usize) * get_io_pool_num_threads() .await .expect("Should be running in tokio pool"), From 84fcc7f5003c82ddaca43cbc9241ae7f9e44431e Mon Sep 17 00:00:00 2001 From: Jay Chia <17691182+jaychia@users.noreply.github.com> Date: Thu, 12 Oct 2023 15:44:36 -0700 Subject: [PATCH 07/14] [CHORE] Update default num_tries on S3Config to 25 (#1487) We were seeing some issues with timeouts and retries Co-authored-by: Jay Chia --- src/common/io-config/src/s3.rs | 4 +++- 1 file changed, 3 insertions(+), 1 deletion(-) diff --git a/src/common/io-config/src/s3.rs b/src/common/io-config/src/s3.rs index cd35c3f49c..972bc0a761 100644 --- a/src/common/io-config/src/s3.rs +++ b/src/common/io-config/src/s3.rs @@ -34,7 +34,9 @@ impl Default for S3Config { retry_initial_backoff_ms: 1000, connect_timeout_ms: 10_000, read_timeout_ms: 10_000, - 
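To make the connection budget described in the message above concrete, here is a back-of-the-envelope calculation; the machine and cluster sizes are made up for illustration, only the formulas come from the patch:

```python
def total_connections(io_threads: int, max_connections_per_io_thread: int) -> int:
    # Connections are pooled per IO thread, so the total scales with both knobs.
    return io_threads * max_connections_per_io_thread


num_cpus = 16           # example machine
py_runner = total_connections(min(8, num_cpus), 8)       # 8 threads * 8 = 64
num_ray_workers = 16    # example cluster
ray_runner = num_ray_workers * total_connections(1, 8)   # 16 workers * 1 thread * 8 = 128
print(py_runner, ray_runner)
```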
num_tries: 5, + // AWS EMR actually does 100 tries by default for AIMD retries + // (See [Advanced AIMD retry settings]: https://docs.aws.amazon.com/emr/latest/ReleaseGuide/emr-spark-emrfs-retry.html) + num_tries: 25, retry_mode: Some("adaptive".to_string()), anonymous: false, verify_ssl: true, From fac11a4c08e37c5df8364f547cc2557a06c6f969 Mon Sep 17 00:00:00 2001 From: Sammy Sidhu Date: Thu, 12 Oct 2023 22:01:26 -0700 Subject: [PATCH 08/14] [PERF] Use region from system and leverage cached credentials when making new clients (#1490) * Fixes 2 issues: * We always set to default region even when the system can provide us creds * We reran the credential chain even though we can leverage a cache --- src/daft-io/src/lib.rs | 1 + src/daft-io/src/s3_like.rs | 39 +++++++++++++++++++++++++++++++------- 2 files changed, 33 insertions(+), 7 deletions(-) diff --git a/src/daft-io/src/lib.rs b/src/daft-io/src/lib.rs index 3f14834fb3..15179102e9 100644 --- a/src/daft-io/src/lib.rs +++ b/src/daft-io/src/lib.rs @@ -1,5 +1,6 @@ #![feature(async_closure)] #![feature(let_chains)] + mod azure_blob; mod google_cloud; mod http; diff --git a/src/daft-io/src/s3_like.rs b/src/daft-io/src/s3_like.rs index 4638077eb4..7681201fb6 100644 --- a/src/daft-io/src/s3_like.rs +++ b/src/daft-io/src/s3_like.rs @@ -1,4 +1,5 @@ use async_trait::async_trait; +use aws_config::meta::credentials::CredentialsProviderChain; use aws_config::retry::RetryMode; use aws_config::timeout::TimeoutConfig; use aws_smithy_async::rt::sleep::TokioSleep; @@ -11,7 +12,7 @@ use tokio::sync::{OwnedSemaphorePermit, SemaphorePermit}; use crate::object_io::{FileMetadata, FileType, LSResult}; use crate::{get_io_pool_num_threads, InvalidArgumentSnafu, SourceType}; use aws_config::SdkConfig; -use aws_credential_types::cache::ProvideCachedCredentials; +use aws_credential_types::cache::{ProvideCachedCredentials, SharedCredentialsCache}; use aws_credential_types::provider::error::CredentialsError; use aws_sig_auth::signer::SigningRequirements; use common_io_config::S3Config; @@ -211,7 +212,10 @@ fn handle_https_client_settings( Ok(builder) } -async fn build_s3_client(config: &S3Config) -> super::Result<(bool, s3::Client)> { +async fn build_s3_client( + config: &S3Config, + credentials_cache: Option, +) -> super::Result<(bool, s3::Client)> { const DEFAULT_REGION: Region = Region::from_static("us-east-1"); let mut anonymous = config.anonymous; @@ -228,8 +232,6 @@ async fn build_s3_client(config: &S3Config) -> super::Result<(bool, s3::Client)> }; let builder = if let Some(region) = &config.region_name { builder.region(Region::new(region.to_owned())) - } else if conf.region().is_none() && config.region_name.is_none() { - builder.region(DEFAULT_REGION) } else { builder }; @@ -268,7 +270,19 @@ async fn build_s3_client(config: &S3Config) -> super::Result<(bool, s3::Client)> .build(); let builder = builder.timeout_config(timeout_config); - let builder = if config.access_key.is_some() && config.key_id.is_some() { + let cached_creds = if let Some(credentials_cache) = credentials_cache { + let creds = credentials_cache.provide_cached_credentials().await; + creds.ok() + } else { + None + }; + + let builder = if let Some(cached_creds) = cached_creds { + let provider = CredentialsProviderChain::first_try("different_region_cache", cached_creds) + .or_default_provider() + .await; + builder.credentials_provider(provider) + } else if config.access_key.is_some() && config.key_id.is_some() { let creds = Credentials::from_keys( config.key_id.clone().unwrap(), 
config.access_key.clone().unwrap(), @@ -283,6 +297,7 @@ async fn build_s3_client(config: &S3Config) -> super::Result<(bool, s3::Client)> builder }; + let builder_copy = builder.clone(); let s3_conf = builder.build(); if !config.anonymous { use CredentialsError::*; @@ -300,11 +315,16 @@ async fn build_s3_client(config: &S3Config) -> super::Result<(bool, s3::Client)> }.with_context(|_| UnableToLoadCredentialsSnafu {})?; }; + let s3_conf = if s3_conf.region().is_none() { + builder_copy.region(DEFAULT_REGION).build() + } else { + s3_conf + }; Ok((anonymous, s3::Client::from_conf(s3_conf))) } async fn build_client(config: &S3Config) -> super::Result { - let (anonymous, client) = build_s3_client(config).await?; + let (anonymous, client) = build_s3_client(config, None).await?; let mut client_map = HashMap::new(); let default_region = client.conf().region().unwrap().clone(); client_map.insert(default_region.clone(), client.into()); @@ -343,7 +363,12 @@ impl S3LikeSource { let mut new_config = self.s3_config.clone(); new_config.region_name = Some(region.to_string()); - let (_, new_client) = build_s3_client(&new_config).await?; + + let creds_cache = w_handle + .get(&self.default_region) + .map(|current_client| current_client.conf().credentials_cache()); + + let (_, new_client) = build_s3_client(&new_config, creds_cache).await?; if w_handle.get(region).is_none() { w_handle.insert(region.clone(), new_client.into()); From 8b5386b10cf3a2eaf62d3384d87531db643f1545 Mon Sep 17 00:00:00 2001 From: Sammy Sidhu Date: Thu, 12 Oct 2023 23:03:36 -0700 Subject: [PATCH 09/14] [CHORE] Add Workflow to build artifacts and upload to S3 (#1491) --- .github/workflows/build-artifact-s3.yml | 95 +++++++++++++++++++++++++ .github/workflows/python-package.yml | 1 - 2 files changed, 95 insertions(+), 1 deletion(-) create mode 100644 .github/workflows/build-artifact-s3.yml diff --git a/.github/workflows/build-artifact-s3.yml b/.github/workflows/build-artifact-s3.yml new file mode 100644 index 0000000000..2fdf166c7c --- /dev/null +++ b/.github/workflows/build-artifact-s3.yml @@ -0,0 +1,95 @@ +name: daft-build-artifact-s3 + +on: + workflow_dispatch: + inputs: + rust-profile: + description: Profile to compile with + required: true + default: release-lto + type: choice + options: + - release-lto + - release + +env: + PACKAGE_NAME: getdaft + PYTHON_VERSION: 3.8 + +jobs: + build-and-push: + name: platform wheels for ${{ matrix.os }}-${{ matrix.compile_arch }} + runs-on: ${{ matrix.os }}-latest + strategy: + fail-fast: false + matrix: + os: [ubuntu] + compile_arch: [x86_64, aarch64] + # These permissions are needed to interact with GitHub's OIDC Token endpoint. 
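The "resolve credentials once, reuse them for per-region clients" idea from the patch above can be sketched in Python with boto3. This is a conceptual analogue only — Daft's implementation is the Rust code shown, and handing off static keys like this drops credential auto-refresh:

```python
import boto3

session = boto3.session.Session()        # runs the credential provider chain once
credentials = session.get_credentials()  # cached credentials object


def client_for_region(region: str):
    # Reuse the already-resolved credentials instead of re-running the chain.
    return boto3.client(
        "s3",
        region_name=region,
        aws_access_key_id=credentials.access_key,
        aws_secret_access_key=credentials.secret_key,
        aws_session_token=credentials.token,
    )
```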
+ # This is used in the step "Assume GitHub Actions AWS Credentials" + permissions: + id-token: write + contents: read + steps: + - uses: actions/checkout@v4 + with: + submodules: true + fetch-depth: 0 + - name: Assume GitHub Actions AWS Credentials + uses: aws-actions/configure-aws-credentials@v3 + with: + aws-region: us-west-2 + role-to-assume: ${{ secrets.ACTIONS_AWS_ROLE_ARN }} + role-session-name: DaftPythonPackageGitHubWorkflow + - uses: actions/setup-python@v4 + with: + python-version: ${{ env.PYTHON_VERSION }} + architecture: x64 + - run: pip install -U toml + - run: python tools/patch_package_version.py + - name: Build wheels - Linux x86 + if: ${{ (matrix.os == 'ubuntu') && (matrix.compile_arch == 'x86_64') }} + uses: messense/maturin-action@v1 + with: + target: x86_64 + manylinux: auto + args: --profile ${{ inputs.rust-profile }} --out dist + before-script-linux: yum -y install perl-IPC-Cmd + env: + RUSTFLAGS: -C target-feature=+fxsr,+sse,+sse2,+sse3,+ssse3,+sse4.1,+sse4.2,+popcnt,+avx,+fma + - name: Build wheels - Linux aarch64 + if: ${{ (matrix.os == 'ubuntu') && (matrix.compile_arch == 'aarch64') }} + uses: messense/maturin-action@v1 + with: + target: aarch64-unknown-linux-gnu + manylinux: auto + # GCC 4.8.5 in manylinux2014 container doesn't support c11 atomic. This caused issues with the `ring` crate that causes TLS to fail + container: messense/manylinux_2_24-cross:aarch64 + args: --profile ${{ inputs.rust-profile }} --out dist + before-script-linux: export JEMALLOC_SYS_WITH_LG_PAGE=16 + + - name: Copy all files as zip for Glue + run: for foo in dist/*.whl; do cp $foo dist/`basename $foo .whl`.zip; done + + - name: Upload wheels to s3 + run: aws s3 cp dist/* s3://github-actions-artifacts-bucket/daft-build-artifact-s3/${{ github.sha }}/ --no-progress + + + list-wheels: + name: List Wheels and Zip files Published to S3 + runs-on: ubuntu-latest + needs: + - build-and-push + + permissions: + id-token: write + contents: read + steps: + - name: Assume GitHub Actions AWS Credentials + uses: aws-actions/configure-aws-credentials@v3 + with: + aws-region: us-west-2 + role-to-assume: ${{ secrets.ACTIONS_AWS_ROLE_ARN }} + role-session-name: DaftPythonPackageGitHubWorkflow + - name: List Wheels + run: aws s3 ls s3://github-actions-artifacts-bucket/daft-build-artifact-s3/${{ github.sha }}/ diff --git a/.github/workflows/python-package.yml b/.github/workflows/python-package.yml index 62ac913ebe..e7fc18cc5a 100644 --- a/.github/workflows/python-package.yml +++ b/.github/workflows/python-package.yml @@ -8,7 +8,6 @@ on: branches: [main] pull_request: branches: [main] - env: DAFT_ANALYTICS_ENABLED: '0' From 167bb77f372ceea8846a004b36b1cdf25f0bc307 Mon Sep 17 00:00:00 2001 From: Sammy Sidhu Date: Thu, 12 Oct 2023 23:30:12 -0700 Subject: [PATCH 10/14] [BUG] fix script to upload file 1 at a time (#1492) --- .github/workflows/build-artifact-s3.yml | 7 +++---- 1 file changed, 3 insertions(+), 4 deletions(-) diff --git a/.github/workflows/build-artifact-s3.yml b/.github/workflows/build-artifact-s3.yml index 2fdf166c7c..7bb8756a3b 100644 --- a/.github/workflows/build-artifact-s3.yml +++ b/.github/workflows/build-artifact-s3.yml @@ -67,12 +67,11 @@ jobs: container: messense/manylinux_2_24-cross:aarch64 args: --profile ${{ inputs.rust-profile }} --out dist before-script-linux: export JEMALLOC_SYS_WITH_LG_PAGE=16 - - name: Copy all files as zip for Glue - run: for foo in dist/*.whl; do cp $foo dist/`basename $foo .whl`.zip; done + run: for file in dist/*.whl; do cp $file dist/`basename $file .whl`.zip; 
done + - name: Upload files to s3 + run: for file in dist/*; do aws s3 cp $file s3://github-actions-artifacts-bucket/daft-build-artifact-s3/${{ github.sha }}/ --no-progress; done - - name: Upload wheels to s3 - run: aws s3 cp dist/* s3://github-actions-artifacts-bucket/daft-build-artifact-s3/${{ github.sha }}/ --no-progress list-wheels: From e1e5eaf4789aff982767396d9493c9a61f628a04 Mon Sep 17 00:00:00 2001 From: Sammy Sidhu Date: Fri, 13 Oct 2023 12:00:01 -0700 Subject: [PATCH 11/14] [PERF] Use pyarrow table for pickling rather than ChunkedArray (#1488) * Use pyarrow table for pickling rather than ChunkedArray so we can exploit ray's special pickling for arrow tables which doesn't work on ChunkedArray --- daft/series.py | 36 +++++++++++++++++++++++++++++++++++- 1 file changed, 35 insertions(+), 1 deletion(-) diff --git a/daft/series.py b/daft/series.py index 2430dbf31d..ece133c060 100644 --- a/daft/series.py +++ b/daft/series.py @@ -1,5 +1,6 @@ from __future__ import annotations +import sys from typing import TypeVar import pyarrow as pa @@ -490,8 +491,41 @@ def image(self) -> SeriesImageNamespace: def __reduce__(self) -> tuple: if self.datatype()._is_python_type(): return (Series.from_pylist, (self.to_pylist(), self.name(), "force")) - else: + elif sys.platform == "win32": return (Series.from_arrow, (self.to_arrow(), self.name())) + else: + # Ray Special CloudPickling fast path. + # Only run for Linux and Mac, since windows runs slower for some reason + return ( + Series._from_arrow_table_to_series, + self._to_arrow_table_for_serdes(), + ) + + def _to_arrow_table_for_serdes(self) -> tuple[pa.Table, pa.ExtensionType | None]: + array = self.to_arrow() + if len(array) == 0: + # This is a workaround for: + # pyarrow.lib.ArrowIndexError: buffer slice would exceed buffer length + # when we have 0 length arrays + array = pa.array([], type=array.type) + + if isinstance(array.type, pa.BaseExtensionType): + stype = array.type.storage_type + ltype = array.type + storage_array = array.cast(stype) + return (pa.table({self.name(): storage_array}), ltype) + else: + return (pa.table({self.name(): array}), None) + + @classmethod + def _from_arrow_table_to_series(cls, table: pa.Table, extension_type: pa.ExtensionType | None) -> Series: + # So we can exploit ray's special pickling for arrow tables which doesn't work on pyarrow arrays + assert table.num_columns == 1 + [name] = table.column_names + [array] = table.columns + if extension_type is not None: + array = extension_type.wrap_array(array) + return cls.from_arrow(array, name) SomeSeriesNamespace = TypeVar("SomeSeriesNamespace", bound="SeriesNamespace") From 4d9f395b9f3c2854b53f5f42a00c29f40a7f2c46 Mon Sep 17 00:00:00 2001 From: Jay Chia <17691182+jaychia@users.noreply.github.com> Date: Mon, 16 Oct 2023 11:26:10 +0900 Subject: [PATCH 12/14] [BUG] Fix local globbing of current directory (#1494) Co-authored-by: Jay Chia --- src/daft-io/src/local.rs | 25 +++++++++++++++++-------- tests/io/test_list_files_local.py | 26 ++++++++++++++++++++++++++ 2 files changed, 43 insertions(+), 8 deletions(-) diff --git a/src/daft-io/src/local.rs b/src/daft-io/src/local.rs index 67130b2fc0..7d7b1d3f0f 100644 --- a/src/daft-io/src/local.rs +++ b/src/daft-io/src/local.rs @@ -185,15 +185,24 @@ impl ObjectSource for LocalSource { } const LOCAL_PROTOCOL: &str = "file://"; - let Some(uri) = uri.strip_prefix(LOCAL_PROTOCOL) else { + let uri = if uri.is_empty() { + std::borrow::Cow::Owned( + std::env::current_dir() + .with_context(|_| UnableToFetchDirectoryEntriesSnafu { path: 
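The `__reduce__` trick in the Series-pickling patch above generalizes to any single-column Arrow wrapper; a minimal, self-contained sketch (the class and method names are hypothetical, not Daft's):

```python
import pickle

import pyarrow as pa


class ArrowBackedColumn:
    """Toy column wrapper demonstrating table-based pickling."""

    def __init__(self, name: str, array: pa.ChunkedArray):
        self.name = name
        self.array = array

    def __reduce__(self):
        # Serializers with an Arrow fast path (e.g. Ray's pickling hooks) handle
        # Tables better than bare ChunkedArrays, so ship a one-column Table.
        return (ArrowBackedColumn._from_table, (pa.table({self.name: self.array}),))

    @staticmethod
    def _from_table(table: pa.Table) -> "ArrowBackedColumn":
        [name] = table.column_names
        return ArrowBackedColumn(name, table.column(name))


col = ArrowBackedColumn("x", pa.chunked_array([[1, 2, 3]]))
assert pickle.loads(pickle.dumps(col)).array.equals(col.array)
```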
uri })? + .to_string_lossy() + .to_string(), + ) + } else if let Some(uri) = uri.strip_prefix(LOCAL_PROTOCOL) { + std::borrow::Cow::Borrowed(uri) + } else { return Err(Error::InvalidFilePath { path: uri.into() }.into()); }; - let meta = - tokio::fs::metadata(uri) - .await - .with_context(|_| UnableToFetchFileMetadataSnafu { - path: uri.to_string(), - })?; + + let meta = tokio::fs::metadata(uri.as_ref()).await.with_context(|_| { + UnableToFetchFileMetadataSnafu { + path: uri.to_string(), + } + })?; if meta.file_type().is_file() { // Provided uri points to a file, so only return that file. return Ok(futures::stream::iter([Ok(FileMetadata { @@ -203,7 +212,7 @@ impl ObjectSource for LocalSource { })]) .boxed()); } - let dir_entries = tokio::fs::read_dir(uri).await.with_context(|_| { + let dir_entries = tokio::fs::read_dir(uri.as_ref()).await.with_context(|_| { UnableToFetchDirectoryEntriesSnafu { path: uri.to_string(), } diff --git a/tests/io/test_list_files_local.py b/tests/io/test_list_files_local.py index e18f2acaf8..b9aa6d443a 100644 --- a/tests/io/test_list_files_local.py +++ b/tests/io/test_list_files_local.py @@ -1,5 +1,7 @@ from __future__ import annotations +import os + import pytest from fsspec.implementations.local import LocalFileSystem @@ -50,6 +52,30 @@ def test_flat_directory_listing(tmp_path, include_protocol): compare_local_result(daft_ls_result, fs_result) +@pytest.mark.parametrize("include_protocol", [False, True]) +def test_recursive_curr_dir_listing(tmp_path, include_protocol): + d = tmp_path / "dir" + d.mkdir() + files = ["a", "b", "c"] + for name in files: + p = d / name + p.touch() + d = str(d) + "/" + + pwd = os.getcwd() + os.chdir(str(d)) + + try: + path = "file://**" if include_protocol else "**" + + daft_ls_result = io_glob(path) + fs = LocalFileSystem() + fs_result = fs.ls(d, detail=True) + compare_local_result(daft_ls_result, fs_result) + finally: + os.chdir(pwd) + + @pytest.mark.parametrize("include_protocol", [False, True]) def test_recursive_directory_listing(tmp_path, include_protocol): d = tmp_path / "dir" From bf5b598b3611b0be442e79fce9e7ed1cf378e76c Mon Sep 17 00:00:00 2001 From: Sammy Sidhu Date: Sun, 15 Oct 2023 21:58:04 -0700 Subject: [PATCH 13/14] [FEAT] IOStats for Native Reader (#1493) * Implements IOStats via atomics that will track the number of: * GetRequests * HeadRequests * ListRequests * BytesRead When the `IOStatsContext` goes out of scope it will debug log its output: For a parquet read: ``` IOStatsContext: read_parquet: for uri s3://eventual-dev-benchmarking-fixtures/parquet-benchmarking/tpch/200MB-2RG/daft_200MB_lineitem_chunk.RG-2.parquet, Gets: 17, Heads: 1, Lists: 0, BytesRead: 213627535, AvgGetSize: 12566325 ``` For our io_glob: ``` IOStatsContext: io_glob for s3://daft-segment-data/segment-logs/**, Gets: 0, Heads: 0, Lists: 237, BytesRead: 0, AvgGetSize: 0 ``` Currently the context is threaded through our python Apis for `io_glob`, `read_parquet`, and `read_csv` IOStats is currently implemented for: * S3 * GCS * HTTP * Azure --- src/daft-csv/src/metadata.rs | 16 ++- src/daft-csv/src/python.rs | 8 +- src/daft-csv/src/read.rs | 28 ++++- src/daft-dsl/src/functions/uri/download.rs | 1 + src/daft-io/src/azure_blob.rs | 61 +++++++++-- src/daft-io/src/google_cloud.rs | 69 ++++++++++--- src/daft-io/src/http.rs | 46 +++++++-- src/daft-io/src/lib.rs | 23 ++++- src/daft-io/src/local.rs | 49 ++++++--- src/daft-io/src/object_io.rs | 26 +++-- src/daft-io/src/object_store_glob.rs | 36 +++++-- src/daft-io/src/python.rs | 12 ++- 
src/daft-io/src/s3_like.rs | 72 +++++++++++-- src/daft-io/src/stats.rs | 114 +++++++++++++++++++++ src/daft-io/src/stream_utils.rs | 28 +++++ src/daft-parquet/src/file.rs | 22 ++-- src/daft-parquet/src/metadata.rs | 9 +- src/daft-parquet/src/python.rs | 25 ++++- src/daft-parquet/src/read.rs | 46 +++++++-- src/daft-parquet/src/read_planner.rs | 12 ++- 20 files changed, 586 insertions(+), 117 deletions(-) create mode 100644 src/daft-io/src/stats.rs create mode 100644 src/daft-io/src/stream_utils.rs diff --git a/src/daft-csv/src/metadata.rs b/src/daft-csv/src/metadata.rs index 7b4000dee8..f5f572af5c 100644 --- a/src/daft-csv/src/metadata.rs +++ b/src/daft-csv/src/metadata.rs @@ -4,7 +4,7 @@ use arrow2::io::csv::read_async::{infer, infer_schema, AsyncReaderBuilder}; use async_compat::CompatExt; use common_error::DaftResult; use daft_core::schema::Schema; -use daft_io::{get_runtime, GetResult, IOClient}; +use daft_io::{get_runtime, GetResult, IOClient, IOStatsRef}; use futures::{io::Cursor, AsyncRead, AsyncSeek}; use tokio::fs::File; @@ -13,11 +13,13 @@ pub fn read_csv_schema( has_header: bool, delimiter: Option, io_client: Arc, + io_stats: Option, ) -> DaftResult { let runtime_handle = get_runtime(true)?; let _rt_guard = runtime_handle.enter(); - runtime_handle - .block_on(async { read_csv_schema_single(uri, has_header, delimiter, io_client).await }) + runtime_handle.block_on(async { + read_csv_schema_single(uri, has_header, delimiter, io_client, io_stats).await + }) } async fn read_csv_schema_single( @@ -25,8 +27,12 @@ async fn read_csv_schema_single( has_header: bool, delimiter: Option, io_client: Arc, + io_stats: Option, ) -> DaftResult { - match io_client.single_url_get(uri.to_string(), None).await? { + match io_client + .single_url_get(uri.to_string(), None, io_stats) + .await? + { GetResult::File(file) => { read_csv_schema_from_reader( File::open(file.path).await?.compat(), @@ -77,7 +83,7 @@ mod tests { io_config.s3.anonymous = true; let io_client = Arc::new(IOClient::new(io_config.into())?); - let schema = read_csv_schema(file, true, None, io_client.clone())?; + let schema = read_csv_schema(file, true, None, io_client.clone(), None)?; assert_eq!( schema, Schema::new(vec![ diff --git a/src/daft-csv/src/python.rs b/src/daft-csv/src/python.rs index a38d2866b0..f45a03b9b5 100644 --- a/src/daft-csv/src/python.rs +++ b/src/daft-csv/src/python.rs @@ -4,7 +4,7 @@ pub mod pylib { use std::sync::Arc; use daft_core::python::schema::PySchema; - use daft_io::{get_io_client, python::IOConfig}; + use daft_io::{get_io_client, python::IOConfig, IOStatsContext}; use daft_table::python::PyTable; use pyo3::{exceptions::PyValueError, pyfunction, PyResult, Python}; @@ -34,6 +34,8 @@ pub mod pylib { multithreaded_io: Option, ) -> PyResult { py.allow_threads(|| { + let io_stats = IOStatsContext::new(format!("read_csv: for uri {uri}")); + let io_client = get_io_client( multithreaded_io.unwrap_or(true), io_config.unwrap_or_default().config.into(), @@ -46,6 +48,7 @@ pub mod pylib { has_header.unwrap_or(true), str_delimiter_to_byte(delimiter)?, io_client, + Some(io_stats), multithreaded_io.unwrap_or(true), )? 
.into()) @@ -62,6 +65,8 @@ pub mod pylib { multithreaded_io: Option, ) -> PyResult { py.allow_threads(|| { + let io_stats = IOStatsContext::new(format!("read_csv_schema: for uri {uri}")); + let io_client = get_io_client( multithreaded_io.unwrap_or(true), io_config.unwrap_or_default().config.into(), @@ -71,6 +76,7 @@ pub mod pylib { has_header.unwrap_or(true), str_delimiter_to_byte(delimiter)?, io_client, + Some(io_stats), )?) .into()) }) diff --git a/src/daft-csv/src/read.rs b/src/daft-csv/src/read.rs index 5bf766656a..cc8bdcd63c 100644 --- a/src/daft-csv/src/read.rs +++ b/src/daft-csv/src/read.rs @@ -13,7 +13,7 @@ use arrow2::{ use async_compat::CompatExt; use common_error::DaftResult; use daft_core::{schema::Schema, utils::arrow::cast_array_for_daft_if_needed, Series}; -use daft_io::{get_runtime, GetResult, IOClient}; +use daft_io::{get_runtime, GetResult, IOClient, IOStatsRef}; use daft_table::Table; use futures::{io::Cursor, AsyncRead, AsyncSeek}; use tokio::fs::File; @@ -27,6 +27,7 @@ pub fn read_csv( has_header: bool, delimiter: Option, io_client: Arc, + io_stats: Option, multithreaded_io: bool, ) -> DaftResult { let runtime_handle = get_runtime(multithreaded_io)?; @@ -40,11 +41,13 @@ pub fn read_csv( has_header, delimiter, io_client, + io_stats, ) .await }) } +#[allow(clippy::too_many_arguments)] async fn read_csv_single( uri: &str, column_names: Option>, @@ -53,8 +56,12 @@ async fn read_csv_single( has_header: bool, delimiter: Option, io_client: Arc, + io_stats: Option, ) -> DaftResult
{ - match io_client.single_url_get(uri.to_string(), None).await? { + match io_client + .single_url_get(uri.to_string(), None, io_stats) + .await? + { GetResult::File(file) => { read_csv_single_from_reader( File::open(file.path).await?.compat(), @@ -208,7 +215,7 @@ mod tests { let io_client = Arc::new(IOClient::new(io_config.into())?); - let table = read_csv(file, None, None, None, true, None, io_client, true)?; + let table = read_csv(file, None, None, None, true, None, io_client, None, true)?; assert_eq!(table.len(), 100); assert_eq!( table.schema, @@ -231,7 +238,7 @@ mod tests { let io_client = Arc::new(IOClient::new(io_config.into())?); - let table = read_csv(file, None, None, None, true, None, io_client, true)?; + let table = read_csv(file, None, None, None, true, None, io_client, None, true)?; assert_eq!(table.len(), 5000); Ok(()) @@ -246,7 +253,17 @@ mod tests { let io_client = Arc::new(IOClient::new(io_config.into())?); - let table = read_csv(file, None, None, Some(10), true, None, io_client, true)?; + let table = read_csv( + file, + None, + None, + Some(10), + true, + None, + io_client, + None, + true, + )?; assert_eq!(table.len(), 10); assert_eq!( table.schema, @@ -277,6 +294,7 @@ mod tests { true, None, io_client, + None, true, )?; assert_eq!(table.len(), 100); diff --git a/src/daft-dsl/src/functions/uri/download.rs b/src/daft-dsl/src/functions/uri/download.rs index da27349c36..b9d5c3661d 100644 --- a/src/daft-dsl/src/functions/uri/download.rs +++ b/src/daft-dsl/src/functions/uri/download.rs @@ -64,6 +64,7 @@ impl FunctionEvaluator for DownloadEvaluator { *raise_error_on_failure, *multi_thread, config.clone(), + None, ), _ => Err(DaftError::ValueError(format!( "Expected 1 input arg, got {}", diff --git a/src/daft-io/src/azure_blob.rs b/src/daft-io/src/azure_blob.rs index a546a71422..5ad53b4742 100644 --- a/src/daft-io/src/azure_blob.rs +++ b/src/daft-io/src/azure_blob.rs @@ -10,6 +10,8 @@ use std::{ops::Range, sync::Arc}; use crate::{ object_io::{FileMetadata, FileType, LSResult, ObjectSource}, + stats::IOStatsRef, + stream_utils::io_stats_on_bytestream, GetResult, }; use common_io_config::AzureConfig; @@ -139,6 +141,7 @@ impl AzureBlobSource { async fn list_containers_stream( &self, protocol: &str, + io_stats: Option, ) -> BoxStream> { let protocol = protocol.to_string(); @@ -152,7 +155,12 @@ impl AzureBlobSource { // Flatmap each page of results to a single stream of our standardized FileMetadata. responses_stream - .map(move |response| (response, protocol.clone())) + .map(move |response| { + if let Some(is) = io_stats.clone() { + is.mark_list_requests(1) + } + (response, protocol.clone()) + }) .flat_map(move |(response, protocol)| match response { Ok(response) => { let containers = response.containers.into_iter().map(move |container| { @@ -174,6 +182,7 @@ impl AzureBlobSource { container_name: &str, prefix: &str, posix: bool, + io_stats: Option, ) -> BoxStream> { let container_client = self.blob_client.container_client(container_name); @@ -205,6 +214,7 @@ impl AzureBlobSource { &container_name, &prefix_with_delimiter, &posix, + io_stats.clone(), ) .await; @@ -255,6 +265,7 @@ impl AzureBlobSource { &container_name, upper_dir, &posix, + io_stats.clone() ).await; // At this point, we have a stream of Result. @@ -302,6 +313,7 @@ impl AzureBlobSource { container_name: &str, prefix: &str, posix: &bool, + io_stats: Option, ) -> BoxStream> { // Calls Azure list_blobs with the prefix // and returns the result flattened and standardized into FileMetadata. 
@@ -323,7 +335,12 @@ impl AzureBlobSource { // Map each page of results to a page of standardized FileMetadata. responses_stream - .map(move |response| (response, protocol.clone(), container_name.clone())) + .map(move |response| { + if let Some(is) = io_stats.clone() { + is.mark_list_requests(1) + } + (response, protocol.clone(), container_name.clone()) + }) .flat_map(move |(response, protocol, container_name)| match response { Ok(response) => { let paths_data = @@ -381,7 +398,12 @@ impl AzureBlobSource { #[async_trait] impl ObjectSource for AzureBlobSource { - async fn get(&self, uri: &str, range: Option>) -> super::Result { + async fn get( + &self, + uri: &str, + range: Option>, + io_stats: Option, + ) -> super::Result { let parsed = url::Url::parse(uri).with_context(|_| InvalidUrlSnafu { path: uri })?; let container = match parsed.host_str() { Some(s) => Ok(s), @@ -412,10 +434,17 @@ impl ObjectSource for AzureBlobSource { .into_error(e) .into() }); - Ok(GetResult::Stream(stream.boxed(), None, None)) + if let Some(is) = io_stats.as_ref() { + is.mark_get_requests(1) + } + Ok(GetResult::Stream( + io_stats_on_bytestream(Box::pin(stream), io_stats), + None, + None, + )) } - async fn get_size(&self, uri: &str) -> super::Result { + async fn get_size(&self, uri: &str, io_stats: Option) -> super::Result { let parsed = url::Url::parse(uri).with_context(|_| InvalidUrlSnafu { path: uri })?; let container = match parsed.host_str() { Some(s) => Ok(s), @@ -432,6 +461,10 @@ impl ObjectSource for AzureBlobSource { .get_properties() .await .context(UnableToOpenFileSnafu:: { path: uri.into() })?; + if let Some(is) = io_stats.as_ref() { + is.mark_head_requests(1) + } + Ok(metadata.blob.properties.content_length as usize) } @@ -440,13 +473,21 @@ impl ObjectSource for AzureBlobSource { glob_path: &str, fanout_limit: Option, page_size: Option, + io_stats: Option, ) -> super::Result>> { use crate::object_store_glob::glob; // Ensure fanout_limit is not None to prevent runaway concurrency let fanout_limit = fanout_limit.or(Some(DEFAULT_GLOB_FANOUT_LIMIT)); - glob(self, glob_path, fanout_limit, page_size.or(Some(1000))).await + glob( + self, + glob_path, + fanout_limit, + page_size.or(Some(1000)), + io_stats, + ) + .await } async fn iter_dir( @@ -454,6 +495,7 @@ impl ObjectSource for AzureBlobSource { uri: &str, posix: bool, _page_size: Option, + io_stats: Option, ) -> super::Result>> { let uri = url::Url::parse(uri).with_context(|_| InvalidUrlSnafu { path: uri })?; @@ -480,12 +522,12 @@ impl ObjectSource for AzureBlobSource { match container { // List containers. - None => Ok(self.list_containers_stream(protocol).await), + None => Ok(self.list_containers_stream(protocol, io_stats).await), // List a path within a container. Some(container_name) => { let prefix = uri.path(); Ok(self - .list_directory_stream(protocol, container_name, prefix, posix) + .list_directory_stream(protocol, container_name, prefix, posix, io_stats) .await) } } @@ -497,6 +539,7 @@ impl ObjectSource for AzureBlobSource { posix: bool, continuation_token: Option<&str>, _page_size: Option, + io_stats: Option, ) -> super::Result { // It looks like the azure rust library API // does not currently allow using the continuation token: @@ -511,7 +554,7 @@ impl ObjectSource for AzureBlobSource { }?; let files = self - .iter_dir(path, posix, None) + .iter_dir(path, posix, None, io_stats) .await? 
.try_collect::>() .await?; diff --git a/src/daft-io/src/google_cloud.rs b/src/daft-io/src/google_cloud.rs index de2c6127a9..cda1460041 100644 --- a/src/daft-io/src/google_cloud.rs +++ b/src/daft-io/src/google_cloud.rs @@ -2,7 +2,6 @@ use std::ops::Range; use std::sync::Arc; use futures::stream::BoxStream; -use futures::StreamExt; use futures::TryStreamExt; use google_cloud_storage::client::ClientConfig; @@ -21,6 +20,8 @@ use crate::object_io::FileType; use crate::object_io::LSResult; use crate::object_io::ObjectSource; use crate::s3_like; +use crate::stats::IOStatsRef; +use crate::stream_utils::io_stats_on_bytestream; use crate::GetResult; use common_io_config::GCSConfig; @@ -127,7 +128,12 @@ fn parse_uri(uri: &url::Url) -> super::Result<(&str, &str)> { } impl GCSClientWrapper { - async fn get(&self, uri: &str, range: Option>) -> super::Result { + async fn get( + &self, + uri: &str, + range: Option>, + io_stats: Option, + ) -> super::Result { let uri = url::Url::parse(uri).with_context(|_| InvalidUrlSnafu { path: uri })?; let (bucket, key) = parse_uri(&uri)?; match self { @@ -160,16 +166,23 @@ impl GCSClientWrapper { .into_error(e) .into() }); - Ok(GetResult::Stream(response.boxed(), size, None)) + if let Some(is) = io_stats.as_ref() { + is.mark_get_requests(1) + } + Ok(GetResult::Stream( + io_stats_on_bytestream(response, io_stats), + size, + None, + )) } GCSClientWrapper::S3Compat(client) => { let uri = format!("s3://{}/{}", bucket, key); - client.get(&uri, range).await + client.get(&uri, range, io_stats).await } } } - async fn get_size(&self, uri: &str) -> super::Result { + async fn get_size(&self, uri: &str, io_stats: Option) -> super::Result { let uri = url::Url::parse(uri).with_context(|_| InvalidUrlSnafu { path: uri })?; let (bucket, key) = parse_uri(&uri)?; match self { @@ -186,15 +199,18 @@ impl GCSClientWrapper { .context(UnableToOpenFileSnafu { path: uri.to_string(), })?; + if let Some(is) = io_stats.as_ref() { + is.mark_head_requests(1) + } Ok(response.size as usize) } GCSClientWrapper::S3Compat(client) => { let uri = format!("s3://{}/{}", bucket, key); - client.get_size(&uri).await + client.get_size(&uri, io_stats).await } } } - + #[allow(clippy::too_many_arguments)] async fn _ls_impl( &self, client: &Client, @@ -203,6 +219,7 @@ impl GCSClientWrapper { delimiter: Option<&str>, continuation_token: Option<&str>, page_size: Option, + io_stats: Option<&IOStatsRef>, ) -> super::Result { let req = ListObjectsRequest { bucket: bucket.to_string(), @@ -222,6 +239,10 @@ impl GCSClientWrapper { .context(UnableToListObjectsSnafu { path: format!("{GCS_SCHEME}://{}/{}", bucket, key), })?; + if let Some(is) = io_stats.as_ref() { + is.mark_list_requests(1) + } + let response_items = ls_response.items.unwrap_or_default(); let response_prefixes = ls_response.prefixes.unwrap_or_default(); let files = response_items.iter().map(|obj| FileMetadata { @@ -246,6 +267,7 @@ impl GCSClientWrapper { posix: bool, continuation_token: Option<&str>, page_size: Option, + io_stats: Option, ) -> super::Result { let uri = url::Url::parse(path).with_context(|_| InvalidUrlSnafu { path })?; let (bucket, key) = parse_uri(&uri)?; @@ -266,6 +288,7 @@ impl GCSClientWrapper { Some(GCS_DELIMITER), continuation_token, page_size, + io_stats.as_ref(), ) .await?; @@ -280,6 +303,7 @@ impl GCSClientWrapper { Some(GCS_DELIMITER), continuation_token, page_size, + io_stats.as_ref(), ) .await?; @@ -307,12 +331,15 @@ impl GCSClientWrapper { None, // Force a prefix-listing continuation_token, page_size, + io_stats.as_ref(), ) .await 
} } GCSClientWrapper::S3Compat(client) => { - client.ls(path, posix, continuation_token, page_size).await + client + .ls(path, posix, continuation_token, page_size, io_stats) + .await } } } @@ -362,12 +389,17 @@ impl GCSSource { #[async_trait] impl ObjectSource for GCSSource { - async fn get(&self, uri: &str, range: Option>) -> super::Result { - self.client.get(uri, range).await + async fn get( + &self, + uri: &str, + range: Option>, + io_stats: Option, + ) -> super::Result { + self.client.get(uri, range, io_stats).await } - async fn get_size(&self, uri: &str) -> super::Result { - self.client.get_size(uri).await + async fn get_size(&self, uri: &str, io_stats: Option) -> super::Result { + self.client.get_size(uri, io_stats).await } async fn glob( @@ -375,13 +407,21 @@ impl ObjectSource for GCSSource { glob_path: &str, fanout_limit: Option, page_size: Option, + io_stats: Option, ) -> super::Result>> { use crate::object_store_glob::glob; // Ensure fanout_limit is not None to prevent runaway concurrency let fanout_limit = fanout_limit.or(Some(DEFAULT_GLOB_FANOUT_LIMIT)); - glob(self, glob_path, fanout_limit, page_size.or(Some(1000))).await + glob( + self, + glob_path, + fanout_limit, + page_size.or(Some(1000)), + io_stats, + ) + .await } async fn ls( @@ -390,9 +430,10 @@ impl ObjectSource for GCSSource { posix: bool, continuation_token: Option<&str>, page_size: Option, + io_stats: Option, ) -> super::Result { self.client - .ls(path, posix, continuation_token, page_size) + .ls(path, posix, continuation_token, page_size, io_stats) .await } } diff --git a/src/daft-io/src/http.rs b/src/daft-io/src/http.rs index a433c99c80..e9f448256c 100644 --- a/src/daft-io/src/http.rs +++ b/src/daft-io/src/http.rs @@ -1,7 +1,7 @@ use std::{num::ParseIntError, ops::Range, string::FromUtf8Error, sync::Arc}; use async_trait::async_trait; -use futures::{stream::BoxStream, StreamExt, TryStreamExt}; +use futures::{stream::BoxStream, TryStreamExt}; use lazy_static::lazy_static; use regex::Regex; @@ -9,7 +9,11 @@ use reqwest::header::{CONTENT_LENGTH, RANGE}; use snafu::{IntoError, ResultExt, Snafu}; use url::Position; -use crate::object_io::{FileMetadata, FileType, LSResult}; +use crate::{ + object_io::{FileMetadata, FileType, LSResult}, + stats::IOStatsRef, + stream_utils::io_stats_on_bytestream, +}; use super::object_io::{GetResult, ObjectSource}; @@ -169,7 +173,12 @@ impl HttpSource { #[async_trait] impl ObjectSource for HttpSource { - async fn get(&self, uri: &str, range: Option>) -> super::Result { + async fn get( + &self, + uri: &str, + range: Option>, + io_stats: Option, + ) -> super::Result { let request = self.client.get(uri); let request = match range { None => request, @@ -186,6 +195,9 @@ impl ObjectSource for HttpSource { let response = response .error_for_status() .context(UnableToOpenFileSnafu:: { path: uri.into() })?; + if let Some(is) = io_stats.as_ref() { + is.mark_get_requests(1) + } let size_bytes = response.content_length().map(|s| s as usize); let stream = response.bytes_stream(); let owned_string = uri.to_owned(); @@ -196,10 +208,14 @@ impl ObjectSource for HttpSource { .into_error(e) .into() }); - Ok(GetResult::Stream(stream.boxed(), size_bytes, None)) + Ok(GetResult::Stream( + io_stats_on_bytestream(stream, io_stats), + size_bytes, + None, + )) } - async fn get_size(&self, uri: &str) -> super::Result { + async fn get_size(&self, uri: &str, io_stats: Option) -> super::Result { let request = self.client.head(uri); let response = request .send() @@ -209,6 +225,10 @@ impl ObjectSource for HttpSource 
{ .error_for_status() .context(UnableToOpenFileSnafu:: { path: uri.into() })?; + if let Some(is) = io_stats.as_ref() { + is.mark_head_requests(1) + } + let headers = response.headers(); match headers.get(CONTENT_LENGTH) { Some(v) => { @@ -229,6 +249,7 @@ impl ObjectSource for HttpSource { glob_path: &str, _fanout_limit: Option, _page_size: Option, + io_stats: Option, ) -> super::Result>> { use crate::object_store_glob::glob; @@ -236,7 +257,7 @@ impl ObjectSource for HttpSource { let fanout_limit = None; let page_size = None; - glob(self, glob_path, fanout_limit, page_size).await + glob(self, glob_path, fanout_limit, page_size, io_stats).await } async fn ls( @@ -245,6 +266,7 @@ impl ObjectSource for HttpSource { posix: bool, _continuation_token: Option<&str>, _page_size: Option, + io_stats: Option, ) -> super::Result { if !posix { unimplemented!("Prefix-listing is not implemented for HTTP listing"); @@ -257,6 +279,9 @@ impl ObjectSource for HttpSource { .context(UnableToConnectSnafu:: { path: path.into() })? .error_for_status() .with_context(|_| UnableToOpenFileSnafu { path })?; + if let Some(is) = io_stats.as_ref() { + is.mark_list_requests(1) + } // Reconstruct the actual path of the request, which may have been redirected via a 301 // This is important because downstream URL joining logic relies on proper trailing-slashes/index.html @@ -308,14 +333,14 @@ mod tests { let parquet_expected_md5 = "929674747af64a98aceaa6d895863bd3"; let client = HttpSource::get_client().await?; - let parquet_file = client.get(parquet_file_path, None).await?; + let parquet_file = client.get(parquet_file_path, None, None).await?; let bytes = parquet_file.bytes().await?; let all_bytes = bytes.as_ref(); let checksum = format!("{:x}", md5::compute(all_bytes)); assert_eq!(checksum, parquet_expected_md5); let first_bytes = client - .get_range(parquet_file_path, 0..10) + .get_range(parquet_file_path, 0..10, None) .await? .bytes() .await?; @@ -323,7 +348,7 @@ mod tests { assert_eq!(first_bytes.as_ref(), &all_bytes[..10]); let first_bytes = client - .get_range(parquet_file_path, 10..100) + .get_range(parquet_file_path, 10..100, None) .await? .bytes() .await?; @@ -334,6 +359,7 @@ mod tests { .get_range( parquet_file_path, (all_bytes.len() - 10)..(all_bytes.len() + 10), + None, ) .await? 
.bytes() @@ -341,7 +367,7 @@ mod tests { assert_eq!(last_bytes.len(), 10); assert_eq!(last_bytes.as_ref(), &all_bytes[(all_bytes.len() - 10)..]); - let size_from_get_size = client.get_size(parquet_file_path).await?; + let size_from_get_size = client.get_size(parquet_file_path, None).await?; assert_eq!(size_from_get_size, all_bytes.len()); Ok(()) } diff --git a/src/daft-io/src/lib.rs b/src/daft-io/src/lib.rs index 15179102e9..013eb9050c 100644 --- a/src/daft-io/src/lib.rs +++ b/src/daft-io/src/lib.rs @@ -8,6 +8,8 @@ mod local; mod object_io; mod object_store_glob; mod s3_like; +mod stats; +mod stream_utils; use azure_blob::AzureBlobSource; use google_cloud::GCSSource; use lazy_static::lazy_static; @@ -18,6 +20,7 @@ pub use common_io_config::{AzureConfig, IOConfig, S3Config}; pub use object_io::GetResult; #[cfg(feature = "python")] pub use python::register_modules; +pub use stats::{IOStatsContext, IOStatsRef}; use tokio::runtime::RuntimeFlavor; use std::{borrow::Cow, collections::HashMap, hash::Hash, ops::Range, sync::Arc}; @@ -166,16 +169,21 @@ impl IOClient { &self, input: String, range: Option>, + io_stats: Option, ) -> Result { let (scheme, path) = parse_url(&input)?; let source = self.get_source(&scheme).await?; - source.get(path.as_ref(), range).await + source.get(path.as_ref(), range, io_stats).await } - pub async fn single_url_get_size(&self, input: String) -> Result { + pub async fn single_url_get_size( + &self, + input: String, + io_stats: Option, + ) -> Result { let (scheme, path) = parse_url(&input)?; let source = self.get_source(&scheme).await?; - source.get_size(path.as_ref()).await + source.get_size(path.as_ref(), io_stats).await } async fn single_url_download( @@ -183,9 +191,10 @@ impl IOClient { index: usize, input: Option, raise_error_on_failure: bool, + io_stats: Option, ) -> Result> { let value = if let Some(input) = input { - let response = self.single_url_get(input, None).await; + let response = self.single_url_get(input, None, io_stats).await; let res = match response { Ok(res) => res.bytes().await, Err(err) => Err(err), @@ -355,6 +364,7 @@ pub fn _url_download( raise_error_on_failure: bool, multi_thread: bool, config: Arc, + io_stats: Option, ) -> DaftResult { let urls = array.as_arrow().iter(); let name = array.name(); @@ -376,11 +386,12 @@ pub fn _url_download( let fetches = futures::stream::iter(urls.enumerate().map(|(i, url)| { let owned_url = url.map(|s| s.to_string()); let owned_client = io_client.clone(); + let owned_io_stats = io_stats.clone(); tokio::spawn(async move { ( i, owned_client - .single_url_download(i, owned_url, raise_error_on_failure) + .single_url_download(i, owned_url, raise_error_on_failure, owned_io_stats) .await, ) }) @@ -432,6 +443,7 @@ pub fn url_download( raise_error_on_failure: bool, multi_thread: bool, config: Arc, + io_stats: Option, ) -> DaftResult { match series.data_type() { DataType::Utf8 => Ok(_url_download( @@ -440,6 +452,7 @@ pub fn url_download( raise_error_on_failure, multi_thread, config, + io_stats, )? 
.into_series()), dt => Err(DaftError::TypeError(format!( diff --git a/src/daft-io/src/local.rs b/src/daft-io/src/local.rs index 7d7b1d3f0f..fd27d87d27 100644 --- a/src/daft-io/src/local.rs +++ b/src/daft-io/src/local.rs @@ -3,6 +3,7 @@ use std::ops::Range; use std::path::PathBuf; use crate::object_io::{self, FileMetadata, LSResult}; +use crate::stats::IOStatsRef; use super::object_io::{GetResult, ObjectSource}; use super::Result; @@ -111,7 +112,12 @@ pub struct LocalFile { #[async_trait] impl ObjectSource for LocalSource { - async fn get(&self, uri: &str, range: Option>) -> super::Result { + async fn get( + &self, + uri: &str, + range: Option>, + _io_stats: Option, + ) -> super::Result { const LOCAL_PROTOCOL: &str = "file://"; if let Some(uri) = uri.strip_prefix(LOCAL_PROTOCOL) { Ok(GetResult::File(LocalFile { @@ -123,7 +129,7 @@ impl ObjectSource for LocalSource { } } - async fn get_size(&self, uri: &str) -> super::Result { + async fn get_size(&self, uri: &str, _io_stats: Option) -> super::Result { const LOCAL_PROTOCOL: &str = "file://"; let Some(uri) = uri.strip_prefix(LOCAL_PROTOCOL) else { return Err(Error::InvalidFilePath { path: uri.into() }.into()); @@ -141,6 +147,7 @@ impl ObjectSource for LocalSource { glob_path: &str, _fanout_limit: Option, _page_size: Option, + io_stats: Option, ) -> super::Result>> { use crate::object_store_glob::glob; @@ -153,10 +160,10 @@ impl ObjectSource for LocalSource { #[cfg(target_env = "msvc")] { let glob_path = glob_path.replace("\\", "/"); - return glob(self, glob_path.as_str(), fanout_limit, page_size).await; + return glob(self, glob_path.as_str(), fanout_limit, page_size, io_stats).await; } - glob(self, glob_path, fanout_limit, page_size).await + glob(self, glob_path, fanout_limit, page_size, io_stats).await } async fn ls( @@ -165,8 +172,9 @@ impl ObjectSource for LocalSource { posix: bool, _continuation_token: Option<&str>, _page_size: Option, + io_stats: Option, ) -> super::Result { - let s = self.iter_dir(path, posix, None).await?; + let s = self.iter_dir(path, posix, None, io_stats).await?; let files = s.try_collect::>().await?; Ok(LSResult { files, @@ -179,6 +187,7 @@ impl ObjectSource for LocalSource { uri: &str, posix: bool, _page_size: Option, + _io_stats: Option, ) -> super::Result>> { if !posix { unimplemented!("Prefix-listing is not implemented for local."); @@ -317,7 +326,7 @@ mod tests { let parquet_expected_md5 = "929674747af64a98aceaa6d895863bd3"; let client = HttpSource::get_client().await?; - let parquet_file = client.get(parquet_file_path, None).await?; + let parquet_file = client.get(parquet_file_path, None, None).await?; let bytes = parquet_file.bytes().await?; let all_bytes = bytes.as_ref(); let checksum = format!("{:x}", md5::compute(all_bytes)); @@ -335,12 +344,16 @@ mod tests { let parquet_file_path = format!("file://{}", file1.path().to_str().unwrap()); let client = LocalSource::get_client().await?; - let try_all_bytes = client.get(&parquet_file_path, None).await?.bytes().await?; + let try_all_bytes = client + .get(&parquet_file_path, None, None) + .await? + .bytes() + .await?; assert_eq!(try_all_bytes.len(), bytes.len()); assert_eq!(try_all_bytes, bytes); let first_bytes = client - .get_range(&parquet_file_path, 0..10) + .get_range(&parquet_file_path, 0..10, None) .await? .bytes() .await?; @@ -348,7 +361,7 @@ mod tests { assert_eq!(first_bytes.as_ref(), &bytes[..10]); let first_bytes = client - .get_range(&parquet_file_path, 10..100) + .get_range(&parquet_file_path, 10..100, None) .await? 
.bytes() .await?; @@ -356,14 +369,18 @@ mod tests { assert_eq!(first_bytes.as_ref(), &bytes[10..100]); let last_bytes = client - .get_range(&parquet_file_path, (bytes.len() - 10)..(bytes.len() + 10)) + .get_range( + &parquet_file_path, + (bytes.len() - 10)..(bytes.len() + 10), + None, + ) .await? .bytes() .await?; assert_eq!(last_bytes.len(), 10); assert_eq!(last_bytes.as_ref(), &bytes[(bytes.len() - 10)..]); - let size_from_get_size = client.get_size(parquet_file_path.as_str()).await?; + let size_from_get_size = client.get_size(parquet_file_path.as_str(), None).await?; assert_eq!(size_from_get_size, bytes.len()); Ok(()) @@ -378,10 +395,10 @@ mod tests { write_remote_parquet_to_local_file(&mut file2).await?; let mut file3 = tempfile::NamedTempFile::new_in(dir.path()).unwrap(); write_remote_parquet_to_local_file(&mut file3).await?; - let dir_path = format!("file://{}", dir.path().to_string_lossy().replace("\\", "/")); + let dir_path = format!("file://{}", dir.path().to_string_lossy().replace('\\', "/")); let client = LocalSource::get_client().await?; - let ls_result = client.ls(dir_path.as_ref(), true, None, None).await?; + let ls_result = client.ls(dir_path.as_ref(), true, None, None, None).await?; let mut files = ls_result.files.clone(); // Ensure stable sort ordering of file paths before comparing with expected payload. files.sort_by(|a, b| a.filepath.cmp(&b.filepath)); @@ -389,7 +406,7 @@ mod tests { FileMetadata { filepath: format!( "file://{}/{}", - dir.path().to_string_lossy().replace("\\", "/"), + dir.path().to_string_lossy().replace('\\', "/"), file1.path().file_name().unwrap().to_string_lossy(), ), size: Some(file1.as_file().metadata().unwrap().len()), @@ -398,7 +415,7 @@ mod tests { FileMetadata { filepath: format!( "file://{}/{}", - dir.path().to_string_lossy().replace("\\", "/"), + dir.path().to_string_lossy().replace('\\', "/"), file2.path().file_name().unwrap().to_string_lossy(), ), size: Some(file2.as_file().metadata().unwrap().len()), @@ -407,7 +424,7 @@ mod tests { FileMetadata { filepath: format!( "file://{}/{}", - dir.path().to_string_lossy().replace("\\", "/"), + dir.path().to_string_lossy().replace('\\', "/"), file3.path().file_name().unwrap().to_string_lossy(), ), size: Some(file3.as_file().metadata().unwrap().len()), diff --git a/src/daft-io/src/object_io.rs b/src/daft-io/src/object_io.rs index ba8d4d15a6..c0429c687d 100644 --- a/src/daft-io/src/object_io.rs +++ b/src/daft-io/src/object_io.rs @@ -10,6 +10,7 @@ use futures::StreamExt; use tokio::sync::OwnedSemaphorePermit; use crate::local::{collect_file, LocalFile}; +use crate::stats::IOStatsRef; pub enum GetResult { File(LocalFile), @@ -94,17 +95,28 @@ use async_stream::stream; #[async_trait] pub(crate) trait ObjectSource: Sync + Send { - async fn get(&self, uri: &str, range: Option>) -> super::Result; - async fn get_range(&self, uri: &str, range: Range) -> super::Result { - self.get(uri, Some(range)).await + async fn get( + &self, + uri: &str, + range: Option>, + io_stats: Option, + ) -> super::Result; + async fn get_range( + &self, + uri: &str, + range: Range, + io_stats: Option, + ) -> super::Result { + self.get(uri, Some(range), io_stats).await } - async fn get_size(&self, uri: &str) -> super::Result; + async fn get_size(&self, uri: &str, io_stats: Option) -> super::Result; async fn glob( self: Arc, glob_path: &str, fanout_limit: Option, page_size: Option, + io_stats: Option, ) -> super::Result>>; async fn ls( @@ -113,6 +125,7 @@ pub(crate) trait ObjectSource: Sync + Send { posix: bool, continuation_token: 
Option<&str>, page_size: Option, + io_stats: Option, ) -> super::Result; async fn iter_dir( @@ -120,17 +133,18 @@ pub(crate) trait ObjectSource: Sync + Send { uri: &str, posix: bool, page_size: Option, + io_stats: Option, ) -> super::Result>> { let uri = uri.to_string(); let s = stream! { - let lsr = self.ls(&uri, posix, None, page_size).await?; + let lsr = self.ls(&uri, posix, None, page_size, io_stats.clone()).await?; for fm in lsr.files { yield Ok(fm); } let mut continuation_token = lsr.continuation_token.clone(); while continuation_token.is_some() { - let lsr = self.ls(&uri, posix, continuation_token.as_deref(), page_size).await?; + let lsr = self.ls(&uri, posix, continuation_token.as_deref(), page_size, io_stats.clone()).await?; continuation_token = lsr.continuation_token.clone(); for fm in lsr.files { yield Ok(fm); diff --git a/src/daft-io/src/object_store_glob.rs b/src/daft-io/src/object_store_glob.rs index 162462d891..d7de111795 100644 --- a/src/daft-io/src/object_store_glob.rs +++ b/src/daft-io/src/object_store_glob.rs @@ -7,7 +7,10 @@ use tokio::sync::mpsc::Sender; use globset::{GlobBuilder, GlobMatcher}; use lazy_static::lazy_static; -use crate::object_io::{FileMetadata, FileType, ObjectSource}; +use crate::{ + object_io::{FileMetadata, FileType, ObjectSource}, + stats::IOStatsRef, +}; lazy_static! { /// Check if a given char is considered a special glob character @@ -229,15 +232,17 @@ async fn ls_with_prefix_fallback( uri: &str, max_dirs: Option, page_size: Option, + io_stats: Option, ) -> (BoxStream<'static, super::Result>, usize) { // Prefix list function that only returns Files fn prefix_ls( source: Arc, path: String, page_size: Option, + io_stats: Option, ) -> BoxStream<'static, super::Result> { stream! { - match source.iter_dir(&path, false, page_size).await { + match source.iter_dir(&path, false, page_size, io_stats).await { Ok(mut result_stream) => { while let Some(result) = result_stream.next().await { match result { @@ -261,7 +266,7 @@ async fn ls_with_prefix_fallback( let mut results_buffer = vec![]; let mut fm_stream = source - .iter_dir(uri, true, page_size) + .iter_dir(uri, true, page_size, io_stats.clone()) .await .unwrap_or_else(|e| futures::stream::iter([Err(e)]).boxed()); @@ -278,7 +283,10 @@ async fn ls_with_prefix_fallback( .map(|max_dirs| dir_count_so_far > max_dirs) .unwrap_or(false) { - return (prefix_ls(source.clone(), uri.to_string(), page_size), 0); + return ( + prefix_ls(source.clone(), uri.to_string(), page_size, io_stats), + 0, + ); } } } @@ -312,13 +320,14 @@ pub(crate) async fn glob( glob: &str, fanout_limit: Option, page_size: Option, + io_stats: Option, ) -> super::Result>> { // If no special characters, we fall back to ls behavior let full_fragment = GlobFragment::new(glob); if !full_fragment.has_special_character() { let glob = full_fragment.escaped_str().to_string(); return Ok(stream! 
{ - let mut results = source.iter_dir(glob.as_str(), true, page_size).await?; + let mut results = source.iter_dir(glob.as_str(), true, page_size, io_stats).await?; while let Some(result) = results.next().await { match result { Ok(fm) => { @@ -364,6 +373,7 @@ pub(crate) async fn glob( result_tx: Sender>, source: Arc, state: GlobState, + io_stats: Option, ) { tokio::spawn(async move { log::debug!( @@ -383,6 +393,7 @@ pub(crate) async fn glob( .fanout_limit .map(|fanout_limit| fanout_limit / state.current_fanout), state.page_size, + io_stats.clone(), ) .await; @@ -401,6 +412,7 @@ pub(crate) async fn glob( state.current_fragment_idx, stream_dir_count, ), + io_stats.clone(), ); } // Return any Files that match @@ -424,7 +436,7 @@ pub(crate) async fn glob( // Last fragment contains a wildcard: we list the last level and match against the full glob if current_fragment.has_special_character() { let mut results = source - .iter_dir(&state.current_path, true, state.page_size) + .iter_dir(&state.current_path, true, state.page_size, io_stats) .await .unwrap_or_else(|e| futures::stream::iter([Err(e)]).boxed()); @@ -448,7 +460,13 @@ pub(crate) async fn glob( } else { let full_dir_path = state.current_path.clone() + current_fragment.escaped_str(); let single_file_ls = source - .ls(full_dir_path.as_str(), true, None, state.page_size) + .ls( + full_dir_path.as_str(), + true, + None, + state.page_size, + io_stats, + ) .await; match single_file_ls { Ok(mut single_file_ls) => { @@ -488,6 +506,7 @@ pub(crate) async fn glob( .fanout_limit .map(|fanout_limit| fanout_limit / state.current_fanout), state.page_size, + io_stats.clone(), ) .await; @@ -510,6 +529,7 @@ pub(crate) async fn glob( stream_dir_count, ) .with_wildcard_mode(), + io_stats.clone(), ); } FileType::File @@ -536,6 +556,7 @@ pub(crate) async fn glob( state .clone() .advance(full_dir_path, state.current_fragment_idx + 1, 1), + io_stats, ); } }); @@ -554,6 +575,7 @@ pub(crate) async fn glob( fanout_limit, page_size, }, + io_stats, ); let to_rtn_stream = stream! { diff --git a/src/daft-io/src/python.rs b/src/daft-io/src/python.rs index f22d182a19..1dba8af023 100644 --- a/src/daft-io/src/python.rs +++ b/src/daft-io/src/python.rs @@ -2,7 +2,7 @@ pub use common_io_config::python::{AzureConfig, GCSConfig, IOConfig}; pub use py::register_modules; mod py { - use crate::{get_io_client, get_runtime, parse_url}; + use crate::{get_io_client, get_runtime, parse_url, stats::IOStatsContext}; use common_error::DaftResult; use futures::TryStreamExt; use pyo3::{ @@ -20,6 +20,9 @@ mod py { page_size: Option, ) -> PyResult<&PyList> { let multithreaded_io = multithreaded_io.unwrap_or(true); + let io_stats = IOStatsContext::new(format!("io_glob for {path}")); + let io_stats_handle = io_stats.clone(); + let lsr: DaftResult> = py.allow_threads(|| { let io_client = get_io_client( multithreaded_io, @@ -32,7 +35,12 @@ mod py { runtime_handle.block_on(async move { let source = io_client.get_source(&scheme).await?; let files = source - .glob(path.as_ref(), fanout_limit, page_size) + .glob( + path.as_ref(), + fanout_limit, + page_size, + Some(io_stats_handle), + ) .await? 
.try_collect() .await?; diff --git a/src/daft-io/src/s3_like.rs b/src/daft-io/src/s3_like.rs index 7681201fb6..156e4acf18 100644 --- a/src/daft-io/src/s3_like.rs +++ b/src/daft-io/src/s3_like.rs @@ -10,6 +10,8 @@ use s3::operation::list_objects_v2::ListObjectsV2Error; use tokio::sync::{OwnedSemaphorePermit, SemaphorePermit}; use crate::object_io::{FileMetadata, FileType, LSResult}; +use crate::stats::IOStatsRef; +use crate::stream_utils::io_stats_on_bytestream; use crate::{get_io_pool_num_threads, InvalidArgumentSnafu, SourceType}; use aws_config::SdkConfig; use aws_credential_types::cache::{ProvideCachedCredentials, SharedCredentialsCache}; @@ -720,24 +722,51 @@ impl S3LikeSource { #[async_trait] impl ObjectSource for S3LikeSource { - async fn get(&self, uri: &str, range: Option>) -> super::Result { + async fn get( + &self, + uri: &str, + range: Option>, + io_stats: Option, + ) -> super::Result { let permit = self .connection_pool_sema .clone() .acquire_owned() .await .context(UnableToGrabSemaphoreSnafu)?; - self._get_impl(permit, uri, range, &self.default_region) - .await + let get_result = self + ._get_impl(permit, uri, range, &self.default_region) + .await?; + + if io_stats.is_some() { + if let GetResult::Stream(stream, num_bytes, permit) = get_result { + if let Some(is) = io_stats.as_ref() { + is.mark_get_requests(1) + } + Ok(GetResult::Stream( + io_stats_on_bytestream(stream, io_stats), + num_bytes, + permit, + )) + } else { + panic!("This should always be a stream"); + } + } else { + Ok(get_result) + } } - async fn get_size(&self, uri: &str) -> super::Result { + async fn get_size(&self, uri: &str, io_stats: Option) -> super::Result { let permit = self .connection_pool_sema .acquire() .await .context(UnableToGrabSemaphoreSnafu)?; - self._head_impl(permit, uri, &self.default_region).await + let head_result = self._head_impl(permit, uri, &self.default_region).await?; + if let Some(is) = io_stats.as_ref() { + is.mark_head_requests(1) + } + Ok(head_result) } async fn glob( @@ -745,13 +774,21 @@ impl ObjectSource for S3LikeSource { glob_path: &str, fanout_limit: Option, page_size: Option, + io_stats: Option, ) -> super::Result>> { use crate::object_store_glob::glob; // Ensure fanout_limit is not None to prevent runaway concurrency let fanout_limit = fanout_limit.or(Some(DEFAULT_GLOB_FANOUT_LIMIT)); - glob(self, glob_path, fanout_limit, page_size.or(Some(1000))).await + glob( + self, + glob_path, + fanout_limit, + page_size.or(Some(1000)), + io_stats, + ) + .await } async fn ls( @@ -760,6 +797,7 @@ impl ObjectSource for S3LikeSource { posix: bool, continuation_token: Option<&str>, page_size: Option, + io_stats: Option, ) -> super::Result { let parsed = url::Url::parse(path).with_context(|_| InvalidUrlSnafu { path })?; let scheme = parsed.scheme(); @@ -800,6 +838,10 @@ impl ObjectSource for S3LikeSource { ) .await? }; + if let Some(is) = io_stats.as_ref() { + is.mark_list_requests(1) + } + if lsr.files.is_empty() && key.contains(S3_DELIMITER) { let permit = self .connection_pool_sema @@ -820,6 +862,9 @@ impl ObjectSource for S3LikeSource { page_size, ) .await?; + if let Some(is) = io_stats.as_ref() { + is.mark_list_requests(1) + } let target_path = format!("{scheme}://{bucket}/{key}"); lsr.files.retain(|f| f.filepath == target_path); @@ -852,6 +897,10 @@ impl ObjectSource for S3LikeSource { ) .await? 
}; + if let Some(is) = io_stats.as_ref() { + is.mark_list_requests(1) + } + Ok(lsr) } } @@ -875,14 +924,14 @@ mod tests { ..Default::default() }; let client = S3LikeSource::get_client(&config).await?; - let parquet_file = client.get(parquet_file_path, None).await?; + let parquet_file = client.get(parquet_file_path, None, None).await?; let bytes = parquet_file.bytes().await?; let all_bytes = bytes.as_ref(); let checksum = format!("{:x}", md5::compute(all_bytes)); assert_eq!(checksum, parquet_expected_md5); let first_bytes = client - .get_range(parquet_file_path, 0..10) + .get_range(parquet_file_path, 0..10, None) .await? .bytes() .await?; @@ -890,7 +939,7 @@ mod tests { assert_eq!(first_bytes.as_ref(), &all_bytes[..10]); let first_bytes = client - .get_range(parquet_file_path, 10..100) + .get_range(parquet_file_path, 10..100, None) .await? .bytes() .await?; @@ -901,6 +950,7 @@ mod tests { .get_range( parquet_file_path, (all_bytes.len() - 10)..(all_bytes.len() + 10), + None, ) .await? .bytes() @@ -908,7 +958,7 @@ mod tests { assert_eq!(last_bytes.len(), 10); assert_eq!(last_bytes.as_ref(), &all_bytes[(all_bytes.len() - 10)..]); - let size_from_get_size = client.get_size(parquet_file_path).await?; + let size_from_get_size = client.get_size(parquet_file_path, None).await?; assert_eq!(size_from_get_size, all_bytes.len()); Ok(()) @@ -924,7 +974,7 @@ mod tests { }; let client = S3LikeSource::get_client(&config).await?; - client.ls(file_path, true, None, None).await?; + client.ls(file_path, true, None, None, None).await?; Ok(()) } diff --git a/src/daft-io/src/stats.rs b/src/daft-io/src/stats.rs new file mode 100644 index 0000000000..aee57b140e --- /dev/null +++ b/src/daft-io/src/stats.rs @@ -0,0 +1,114 @@ +use std::sync::{ + atomic::{self}, + Arc, +}; + +pub type IOStatsRef = Arc; + +#[derive(Default, Debug)] +pub struct IOStatsContext { + name: String, + num_get_requests: atomic::AtomicUsize, + num_head_requests: atomic::AtomicUsize, + num_list_requests: atomic::AtomicUsize, + bytes_read: atomic::AtomicUsize, +} + +impl Drop for IOStatsContext { + fn drop(&mut self) { + let bytes_read = self.load_bytes_read(); + let num_gets = self.load_get_requests(); + let mean_size = (bytes_read as f64) / (num_gets as f64); + log::debug!( + "IOStatsContext: {}, Gets: {}, Heads: {}, Lists: {}, BytesRead: {}, AvgGetSize: {}", + self.name, + num_gets, + self.load_head_requests(), + self.load_list_requests(), + bytes_read, + mean_size as i64 + ); + } +} + +pub(crate) struct IOStatsByteStreamContextHandle { + // do not enable Copy or Clone on this struct + bytes_read: usize, + inner: IOStatsRef, +} + +impl IOStatsContext { + pub fn new(name: String) -> IOStatsRef { + Arc::new(IOStatsContext { + name, + num_get_requests: atomic::AtomicUsize::new(0), + num_head_requests: atomic::AtomicUsize::new(0), + num_list_requests: atomic::AtomicUsize::new(0), + bytes_read: atomic::AtomicUsize::new(0), + }) + } + + #[inline] + pub(crate) fn mark_get_requests(&self, num_requests: usize) { + self.num_get_requests + .fetch_add(num_requests, atomic::Ordering::Relaxed); + } + + #[inline] + pub(crate) fn mark_head_requests(&self, num_requests: usize) { + self.num_head_requests + .fetch_add(num_requests, atomic::Ordering::Relaxed); + } + + #[inline] + pub(crate) fn mark_list_requests(&self, num_requests: usize) { + self.num_list_requests + .fetch_add(num_requests, atomic::Ordering::Relaxed); + } + + #[inline] + pub fn load_get_requests(&self) -> usize { + self.num_get_requests.load(atomic::Ordering::Acquire) + } + + #[inline] + pub fn 
load_head_requests(&self) -> usize { + self.num_head_requests.load(atomic::Ordering::Acquire) + } + + #[inline] + pub fn load_list_requests(&self) -> usize { + self.num_list_requests.load(atomic::Ordering::Acquire) + } + + #[inline] + pub(crate) fn mark_bytes_read(&self, bytes_read: usize) { + self.bytes_read + .fetch_add(bytes_read, atomic::Ordering::Relaxed); + } + + #[inline] + pub fn load_bytes_read(&self) -> usize { + self.bytes_read.load(atomic::Ordering::Acquire) + } +} + +impl IOStatsByteStreamContextHandle { + pub fn new(io_stats: IOStatsRef) -> Self { + Self { + bytes_read: 0, + inner: io_stats, + } + } + + #[inline] + pub fn mark_bytes_read(&mut self, bytes_read: usize) { + self.bytes_read += bytes_read; + } +} + +impl Drop for IOStatsByteStreamContextHandle { + fn drop(&mut self) { + self.inner.mark_bytes_read(self.bytes_read); + } +} diff --git a/src/daft-io/src/stream_utils.rs b/src/daft-io/src/stream_utils.rs new file mode 100644 index 0000000000..1460d3ebe3 --- /dev/null +++ b/src/daft-io/src/stream_utils.rs @@ -0,0 +1,28 @@ +use bytes::Bytes; + +use crate::stats::{IOStatsByteStreamContextHandle, IOStatsRef}; + +use futures::{stream::BoxStream, StreamExt}; + +pub(crate) fn io_stats_on_bytestream( + mut s: impl futures::stream::Stream> + + Unpin + + std::marker::Send + + 'static, + io_stats: Option, +) -> BoxStream<'static, super::Result> { + if let Some(io_stats) = io_stats { + let mut context = IOStatsByteStreamContextHandle::new(io_stats); + async_stream::stream! { + while let Some(val) = s.next().await { + if let Ok(ref val) = val { + context.mark_bytes_read(val.len()); + } + yield val + } + } + .boxed() + } else { + s.boxed() + } +} diff --git a/src/daft-parquet/src/file.rs b/src/daft-parquet/src/file.rs index b30e8acf93..22d0c3984b 100644 --- a/src/daft-parquet/src/file.rs +++ b/src/daft-parquet/src/file.rs @@ -3,7 +3,7 @@ use std::{collections::HashSet, sync::Arc}; use arrow2::io::parquet::read::schema::infer_schema_with_options; use common_error::DaftResult; use daft_core::{utils::arrow::cast_array_for_daft_if_needed, Series}; -use daft_io::IOClient; +use daft_io::{IOClient, IOStatsRef}; use daft_table::Table; use futures::{future::try_join_all, StreamExt}; use parquet2::{ @@ -142,10 +142,16 @@ pub(crate) fn build_row_ranges( } impl ParquetReaderBuilder { - pub async fn from_uri(uri: &str, io_client: Arc) -> super::Result { + pub async fn from_uri( + uri: &str, + io_client: Arc, + io_stats: Option, + ) -> super::Result { // TODO(sammy): We actually don't need this since we can do negative offsets when reading the metadata - let size = io_client.single_url_get_size(uri.into()).await?; - let metadata = read_parquet_metadata(uri, size, io_client).await?; + let size = io_client + .single_url_get_size(uri.into(), io_stats.clone()) + .await?; + let metadata = read_parquet_metadata(uri, size, io_client, io_stats).await?; let num_rows = metadata.num_rows; Ok(ParquetReaderBuilder { uri: uri.into(), @@ -301,7 +307,11 @@ impl ParquetFileReader { Ok(read_planner) } - pub fn prebuffer_ranges(&self, io_client: Arc) -> DaftResult> { + pub fn prebuffer_ranges( + &self, + io_client: Arc, + io_stats: Option, + ) -> DaftResult> { let mut read_planner = self.naive_read_plan()?; // TODO(sammy) these values should be populated by io_client read_planner.add_pass(Box::new(SplitLargeRequestPass { @@ -315,7 +325,7 @@ impl ParquetFileReader { })); read_planner.run_passes()?; - read_planner.collect(io_client) + read_planner.collect(io_client, io_stats) } pub async fn 
read_from_ranges_into_table( diff --git a/src/daft-parquet/src/metadata.rs b/src/daft-parquet/src/metadata.rs index 974c678fba..92969294da 100644 --- a/src/daft-parquet/src/metadata.rs +++ b/src/daft-parquet/src/metadata.rs @@ -1,6 +1,6 @@ use std::sync::Arc; -use daft_io::IOClient; +use daft_io::{IOClient, IOStatsRef}; use parquet2::{metadata::FileMetaData, read::deserialize_metadata}; use snafu::ResultExt; @@ -15,6 +15,7 @@ pub async fn read_parquet_metadata( uri: &str, size: usize, io_client: Arc, + io_stats: Option, ) -> super::Result { const FOOTER_SIZE: usize = 8; const PARQUET_MAGIC: [u8; 4] = [b'P', b'A', b'R', b'1']; @@ -25,7 +26,7 @@ pub async fn read_parquet_metadata( let start = size.saturating_sub(default_end_len); let mut data = io_client - .single_url_get(uri.into(), Some(start..size)) + .single_url_get(uri.into(), Some(start..size), io_stats.clone()) .await? .bytes() .await?; @@ -57,7 +58,7 @@ pub async fn read_parquet_metadata( let start = size.saturating_sub(footer_len); data = io_client - .single_url_get(uri.into(), Some(start..size)) + .single_url_get(uri.into(), Some(start..size), io_stats) .await? .bytes() .await?; @@ -103,7 +104,7 @@ mod tests { io_config.s3.anonymous = true; let io_client = Arc::new(IOClient::new(io_config.into())?); - let metadata = read_parquet_metadata(file, size, io_client.clone()).await?; + let metadata = read_parquet_metadata(file, size, io_client.clone(), None).await?; assert_eq!(metadata.num_rows, 100); Ok(()) diff --git a/src/daft-parquet/src/python.rs b/src/daft-parquet/src/python.rs index 51244c8b24..c49c0fb7e6 100644 --- a/src/daft-parquet/src/python.rs +++ b/src/daft-parquet/src/python.rs @@ -5,7 +5,7 @@ pub mod pylib { ffi::field_to_py, python::{datatype::PyTimeUnit, schema::PySchema, PySeries}, }; - use daft_io::{get_io_client, python::IOConfig}; + use daft_io::{get_io_client, python::IOConfig, IOStatsContext}; use daft_table::python::PyTable; use pyo3::{pyfunction, types::PyModule, PyResult, Python}; use std::{collections::BTreeMap, sync::Arc}; @@ -26,6 +26,8 @@ pub mod pylib { coerce_int96_timestamp_unit: Option, ) -> PyResult { py.allow_threads(|| { + let io_stats = IOStatsContext::new(format!("read_parquet: for uri {uri}")); + let io_client = get_io_client( multithreaded_io.unwrap_or(true), io_config.unwrap_or_default().config.into(), @@ -33,17 +35,19 @@ pub mod pylib { let schema_infer_options = ParquetSchemaInferenceOptions::new( coerce_int96_timestamp_unit.map(|tu| tu.timeunit), ); - Ok(crate::read::read_parquet( + let result = crate::read::read_parquet( uri, columns.as_deref(), start_offset, num_rows, row_groups, io_client, + Some(io_stats.clone()), multithreaded_io.unwrap_or(true), schema_infer_options, )? 
- .into()) + .into(); + Ok(result) }) } type PyArrowChunks = Vec>; @@ -100,6 +104,7 @@ pub mod pylib { num_rows, row_groups, io_client, + None, multithreaded_io.unwrap_or(true), schema_infer_options, ) @@ -123,6 +128,8 @@ pub mod pylib { coerce_int96_timestamp_unit: Option, ) -> PyResult> { py.allow_threads(|| { + let io_stats = IOStatsContext::new("read_parquet_bulk".to_string()); + let io_client = get_io_client( multithreaded_io.unwrap_or(true), io_config.unwrap_or_default().config.into(), @@ -138,6 +145,7 @@ pub mod pylib { num_rows, row_groups, io_client, + Some(io_stats), num_parallel_tasks.unwrap_or(128) as usize, multithreaded_io.unwrap_or(true), &schema_infer_options, @@ -178,6 +186,7 @@ pub mod pylib { num_rows, row_groups, io_client, + None, num_parallel_tasks.unwrap_or(128) as usize, multithreaded_io.unwrap_or(true), schema_infer_options, @@ -201,6 +210,8 @@ pub mod pylib { coerce_int96_timestamp_unit: Option, ) -> PyResult { py.allow_threads(|| { + let io_stats = IOStatsContext::new(format!("read_parquet_schema: for uri {uri}")); + let schema_infer_options = ParquetSchemaInferenceOptions::new( coerce_int96_timestamp_unit.map(|tu| tu.timeunit), ); @@ -211,6 +222,7 @@ pub mod pylib { Ok(Arc::new(crate::read::read_parquet_schema( uri, io_client, + Some(io_stats), schema_infer_options, )?) .into()) @@ -225,11 +237,16 @@ pub mod pylib { multithreaded_io: Option, ) -> PyResult { py.allow_threads(|| { + let io_stats = IOStatsContext::new("read_parquet_statistics".to_string()); + let io_client = get_io_client( multithreaded_io.unwrap_or(true), io_config.unwrap_or_default().config.into(), )?; - Ok(crate::read::read_parquet_statistics(&uris.series, io_client)?.into()) + Ok( + crate::read::read_parquet_statistics(&uris.series, io_client, Some(io_stats))? + .into(), + ) }) } } diff --git a/src/daft-parquet/src/read.rs b/src/daft-parquet/src/read.rs index 1815a99a54..a0ae8918b6 100644 --- a/src/daft-parquet/src/read.rs +++ b/src/daft-parquet/src/read.rs @@ -7,7 +7,7 @@ use daft_core::{ schema::Schema, DataType, IntoSeries, Series, }; -use daft_io::{get_runtime, parse_url, IOClient, SourceType}; +use daft_io::{get_runtime, parse_url, IOClient, IOStatsRef, SourceType}; use daft_table::Table; use futures::{future::join_all, StreamExt, TryStreamExt}; use itertools::Itertools; @@ -49,6 +49,7 @@ impl From } } +#[allow(clippy::too_many_arguments)] async fn read_parquet_single( uri: &str, columns: Option<&[&str]>, @@ -56,6 +57,7 @@ async fn read_parquet_single( num_rows: Option, row_groups: Option>, io_client: Arc, + io_stats: Option, schema_infer_options: ParquetSchemaInferenceOptions, ) -> DaftResult
{ let (source_type, fixed_uri) = parse_url(uri)?; @@ -70,7 +72,8 @@ async fn read_parquet_single( ) .await } else { - let builder = ParquetReaderBuilder::from_uri(uri, io_client.clone()).await?; + let builder = + ParquetReaderBuilder::from_uri(uri, io_client.clone(), io_stats.clone()).await?; let builder = builder.set_infer_schema_options(schema_infer_options); let builder = if let Some(columns) = columns { @@ -92,7 +95,7 @@ async fn read_parquet_single( }; let parquet_reader = builder.build()?; - let ranges = parquet_reader.prebuffer_ranges(io_client)?; + let ranges = parquet_reader.prebuffer_ranges(io_client, io_stats)?; Ok(( metadata, parquet_reader.read_from_ranges_into_table(ranges).await?, @@ -165,6 +168,7 @@ async fn read_parquet_single( Ok(table) } +#[allow(clippy::too_many_arguments)] async fn read_parquet_single_into_arrow( uri: &str, columns: Option<&[&str]>, @@ -172,6 +176,7 @@ async fn read_parquet_single_into_arrow( num_rows: Option, row_groups: Option>, io_client: Arc, + io_stats: Option, schema_infer_options: ParquetSchemaInferenceOptions, ) -> DaftResult<(arrow2::datatypes::SchemaRef, Vec)> { let (source_type, fixed_uri) = parse_url(uri)?; @@ -188,7 +193,8 @@ async fn read_parquet_single_into_arrow( .await?; (metadata, Arc::new(schema), all_arrays) } else { - let builder = ParquetReaderBuilder::from_uri(uri, io_client.clone()).await?; + let builder = + ParquetReaderBuilder::from_uri(uri, io_client.clone(), io_stats.clone()).await?; let builder = builder.set_infer_schema_options(schema_infer_options); let builder = if let Some(columns) = columns { @@ -212,7 +218,7 @@ async fn read_parquet_single_into_arrow( let parquet_reader = builder.build()?; let schema = parquet_reader.arrow_schema().clone(); - let ranges = parquet_reader.prebuffer_ranges(io_client)?; + let ranges = parquet_reader.prebuffer_ranges(io_client, io_stats)?; let all_arrays = parquet_reader .read_from_ranges_into_arrow_arrays(ranges) .await?; @@ -306,6 +312,7 @@ pub fn read_parquet( num_rows: Option, row_groups: Option>, io_client: Arc, + io_stats: Option, multithreaded_io: bool, schema_infer_options: ParquetSchemaInferenceOptions, ) -> DaftResult
{ @@ -319,6 +326,7 @@ pub fn read_parquet( num_rows, row_groups, io_client, + io_stats, schema_infer_options, ) .await @@ -334,6 +342,7 @@ pub fn read_parquet_into_pyarrow( num_rows: Option, row_groups: Option>, io_client: Arc, + io_stats: Option, multithreaded_io: bool, schema_infer_options: ParquetSchemaInferenceOptions, ) -> DaftResult { @@ -347,6 +356,7 @@ pub fn read_parquet_into_pyarrow( num_rows, row_groups, io_client, + io_stats, schema_infer_options, ) .await @@ -361,6 +371,7 @@ pub fn read_parquet_bulk( num_rows: Option, row_groups: Option>>, io_client: Arc, + io_stats: Option, num_parallel_tasks: usize, multithreaded_io: bool, schema_infer_options: &ParquetSchemaInferenceOptions, @@ -388,6 +399,8 @@ pub fn read_parquet_bulk( }; let io_client = io_client.clone(); + let io_stats = io_stats.clone(); + let schema_infer_options = schema_infer_options.clone(); tokio::task::spawn(async move { let columns = owned_columns @@ -402,6 +415,7 @@ pub fn read_parquet_bulk( num_rows, owned_row_group, io_client, + io_stats, schema_infer_options, ) .await?, @@ -428,6 +442,7 @@ pub fn read_parquet_into_pyarrow_bulk( num_rows: Option, row_groups: Option>>, io_client: Arc, + io_stats: Option, num_parallel_tasks: usize, multithreaded_io: bool, schema_infer_options: ParquetSchemaInferenceOptions, @@ -456,6 +471,8 @@ pub fn read_parquet_into_pyarrow_bulk( }; let io_client = io_client.clone(); + let io_stats = io_stats.clone(); + let schema_infer_options = schema_infer_options.clone(); tokio::task::spawn(async move { let columns = owned_columns @@ -470,6 +487,7 @@ pub fn read_parquet_into_pyarrow_bulk( num_rows, owned_row_group, io_client, + io_stats, schema_infer_options, ) .await?, @@ -489,18 +507,24 @@ pub fn read_parquet_into_pyarrow_bulk( pub fn read_parquet_schema( uri: &str, io_client: Arc, + io_stats: Option, schema_inference_options: ParquetSchemaInferenceOptions, ) -> DaftResult { let runtime_handle = get_runtime(true)?; let _rt_guard = runtime_handle.enter(); - let builder = runtime_handle - .block_on(async { ParquetReaderBuilder::from_uri(uri, io_client.clone()).await })?; + let builder = runtime_handle.block_on(async { + ParquetReaderBuilder::from_uri(uri, io_client.clone(), io_stats).await + })?; let builder = builder.set_infer_schema_options(schema_inference_options); Schema::try_from(builder.build()?.arrow_schema().as_ref()) } -pub fn read_parquet_statistics(uris: &Series, io_client: Arc) -> DaftResult
{ +pub fn read_parquet_statistics( + uris: &Series, + io_client: Arc, + io_stats: Option, +) -> DaftResult
{ let runtime_handle = get_runtime(true)?; let _rt_guard = runtime_handle.enter(); @@ -518,9 +542,12 @@ pub fn read_parquet_statistics(uris: &Series, io_client: Arc) -> DaftR let handles_iter = values.iter().map(|uri| { let owned_string = uri.map(|v| v.to_string()); let owned_client = io_client.clone(); + let io_stats = io_stats.clone(); + tokio::spawn(async move { if let Some(owned_string) = owned_string { - let builder = ParquetReaderBuilder::from_uri(&owned_string, owned_client).await?; + let builder = + ParquetReaderBuilder::from_uri(&owned_string, owned_client, io_stats).await?; let num_rows = builder.metadata().num_rows; let num_row_groups = builder.metadata().row_groups.len(); let version_num = builder.metadata().version; @@ -596,6 +623,7 @@ mod tests { None, None, io_client, + None, true, Default::default(), )?; diff --git a/src/daft-parquet/src/read_planner.rs b/src/daft-parquet/src/read_planner.rs index 79133c84cf..f9615024c7 100644 --- a/src/daft-parquet/src/read_planner.rs +++ b/src/daft-parquet/src/read_planner.rs @@ -2,7 +2,7 @@ use std::{fmt::Display, ops::Range, sync::Arc}; use bytes::Bytes; use common_error::DaftResult; -use daft_io::IOClient; +use daft_io::{IOClient, IOStatsRef}; use futures::StreamExt; use tokio::task::JoinHandle; @@ -148,16 +148,22 @@ impl ReadPlanner { Ok(()) } - pub fn collect(self, io_client: Arc) -> DaftResult> { + pub fn collect( + self, + io_client: Arc, + io_stats: Option, + ) -> DaftResult> { let mut entries = Vec::with_capacity(self.ranges.len()); for range in self.ranges { let owned_io_client = io_client.clone(); let owned_url = self.source.clone(); + let owned_io_stats = io_stats.clone(); + let start = range.start; let end = range.end; let join_handle = tokio::spawn(async move { let get_result = owned_io_client - .single_url_get(owned_url, Some(range.clone())) + .single_url_get(owned_url, Some(range.clone()), owned_io_stats) .await?; get_result.bytes().await }); From 9d20890e40079a065f9b77f8fe6716c7cc44249d Mon Sep 17 00:00:00 2001 From: Jay Chia <17691182+jaychia@users.noreply.github.com> Date: Mon, 16 Oct 2023 15:19:12 +0900 Subject: [PATCH 14/14] [CHORE] Refactor logging (#1489)

1. Removes `loguru` as a dependency
2. Removes custom handler that forwards logs to loguru
3. Removes all custom formatting of our logs
4. A warning about which runner is being used is now only emitted when the PyRunner is used with an existing Ray connection

I did (3) because Daft is a library, and custom formatting of logs should be done by the application. E.g. if a user were building a webserver with their own custom logging setup, then Daft mangling log formatting on the root logger would be very annoying.

Enabling more verbose logs at a higher level (e.g. INFO logs) is performed by the user/embedding that uses Daft, e.g.

```python
import logging

logging.basicConfig(
    format='%(asctime)s,%(msecs)03d %(levelname)-8s [%(pathname)s:%(lineno)d] %(message)s',
    datefmt='%Y-%m-%d:%H:%M:%S',
    level=logging.INFO,
)
# Outputs logs that look like:
# 2023-10-16:11:25:46,195 INFO [/Users/jaychia/.cargo/registry/src/index.crates.io-6f17d22bba15001f/aws-config-0.55.3/src/meta/region.rs:43] load_region; provider=EnvironmentVariableRegionProvider { env: Env(Real) }
```

Daft now respects normal Python logging and does not rely on loguru at all for any of this configuration. This lets us play much nicer with applications that use normal Python logging.
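As a further illustration (a hedged sketch, not part of this patch, assuming only that Daft's modules keep obtaining loggers via `logging.getLogger(__name__)` under the `daft` package as shown in this diff), an embedding application can raise verbosity for Daft alone without reconfiguring its own root logger:

```python
import logging

# Keep the application-wide default at WARNING.
logging.basicConfig(level=logging.WARNING)

# Every Daft module above logs via logging.getLogger(__name__), so its loggers
# are all children of the "daft" logger; one setLevel call turns on DEBUG logs
# for Daft only, while records still flow to the handler installed by
# basicConfig.
logging.getLogger("daft").setLevel(logging.DEBUG)
```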
--------- Co-authored-by: Jay Chia --- benchmarking/tpch/__main__.py | 4 +- benchmarking/tpch/data_generation.py | 5 ++- .../tpch/pipelined_data_generation.py | 5 ++- daft/__init__.py | 8 ---- daft/context.py | 20 ++++++++-- daft/dataframe/to_torch.py | 3 +- daft/execution/physical_plan.py | 37 ++++++++---------- daft/filesystem.py | 4 +- daft/internal/rule_runner.py | 5 ++- daft/internal/treenode.py | 5 ++- daft/logging.py | 39 ------------------- daft/logical/optimizer.py | 4 +- daft/runners/profiler.py | 6 ++- daft/runners/pyrunner.py | 15 ++++--- daft/runners/ray_runner.py | 10 ++--- daft/table/table.py | 11 ++---- daft/udf_library/url_udfs.py | 6 ++- pyproject.toml | 1 - src/daft-table/src/lib.rs | 9 ++++- 19 files changed, 86 insertions(+), 111 deletions(-) delete mode 100644 daft/logging.py diff --git a/benchmarking/tpch/__main__.py b/benchmarking/tpch/__main__.py index 4be8e871e1..002d173c6b 100644 --- a/benchmarking/tpch/__main__.py +++ b/benchmarking/tpch/__main__.py @@ -3,6 +3,7 @@ import argparse import contextlib import csv +import logging import math import os import platform @@ -13,7 +14,6 @@ from typing import Any, Callable import ray -from loguru import logger import daft from benchmarking.tpch import answers, data_generation @@ -21,6 +21,8 @@ from daft.context import get_context from daft.runners.profiler import profiler +logger = logging.getLogger(__name__) + ALL_TABLES = [ "part", "supplier", diff --git a/benchmarking/tpch/data_generation.py b/benchmarking/tpch/data_generation.py index d38af0c5c1..bc7f3ac576 100644 --- a/benchmarking/tpch/data_generation.py +++ b/benchmarking/tpch/data_generation.py @@ -1,6 +1,7 @@ from __future__ import annotations import argparse +import logging import math import os import shlex @@ -8,10 +9,10 @@ import subprocess from glob import glob -from loguru import logger - import daft +logger = logging.getLogger(__name__) + SCHEMA = { "part": [ "P_PARTKEY", diff --git a/benchmarking/tpch/pipelined_data_generation.py b/benchmarking/tpch/pipelined_data_generation.py index 1e08ee7a1a..3fa3f2b701 100644 --- a/benchmarking/tpch/pipelined_data_generation.py +++ b/benchmarking/tpch/pipelined_data_generation.py @@ -15,6 +15,7 @@ import argparse import glob +import logging import os import pathlib import shlex @@ -22,10 +23,10 @@ import subprocess from multiprocessing import Pool -from loguru import logger - from benchmarking.tpch.data_generation import gen_parquet +logger = logging.getLogger(__name__) + STATIC_TABLES = ["nation", "region"] diff --git a/daft/__init__.py b/daft/__init__.py index 00b6a8ed4c..173173adb9 100644 --- a/daft/__init__.py +++ b/daft/__init__.py @@ -2,8 +2,6 @@ import os -from daft.logging import setup_logger - ### # Set up code coverage for when running code coverage with ray ### @@ -20,12 +18,6 @@ "Environ: {!r} " "Exception: {!r}\n".format({k: v for k, v in os.environ.items() if k.startswith("COV_CORE")}, exc) ) -### -# Setup logging -### - - -setup_logger() ### # Get build constants from Rust .so diff --git a/daft/context.py b/daft/context.py index 848b0d9612..b8273c7c0f 100644 --- a/daft/context.py +++ b/daft/context.py @@ -1,16 +1,17 @@ from __future__ import annotations import dataclasses +import logging import os import warnings from typing import TYPE_CHECKING, ClassVar -from loguru import logger - if TYPE_CHECKING: from daft.logical.builder import LogicalPlanBuilder from daft.runners.runner import Runner +logger = logging.getLogger(__name__) + class _RunnerConfig: name = ClassVar[str] @@ -75,7 +76,6 @@ def 
runner(self) -> Runner: if self.runner_config.name == "ray": from daft.runners.ray_runner import RayRunner - logger.info("Using RayRunner") assert isinstance(self.runner_config, _RayRunnerConfig) _RUNNER = RayRunner( address=self.runner_config.address, @@ -84,7 +84,19 @@ def runner(self) -> Runner: elif self.runner_config.name == "py": from daft.runners.pyrunner import PyRunner - logger.info("Using PyRunner") + try: + import ray + + if ray.is_initialized(): + logger.warning( + "WARNING: Daft is NOT using Ray for execution!\n" + "Daft is using the PyRunner but we detected an active Ray connection. " + "If you intended to use the Daft RayRunner, please first run `daft.context.set_runner_ray()` " + "before executing Daft queries." + ) + except ImportError: + pass + assert isinstance(self.runner_config, _PyRunnerConfig) _RUNNER = PyRunner(use_thread_pool=self.runner_config.use_thread_pool) diff --git a/daft/dataframe/to_torch.py b/daft/dataframe/to_torch.py index f3b30b888f..1d637e83a4 100644 --- a/daft/dataframe/to_torch.py +++ b/daft/dataframe/to_torch.py @@ -1,8 +1,9 @@ from __future__ import annotations +import logging from typing import Any, Iterable, Iterator -from loguru import logger +logger = logging.getLogger(__name__) try: # When available, subclass from the newer torchdata DataPipes instead of torch Datasets. diff --git a/daft/execution/physical_plan.py b/daft/execution/physical_plan.py index 1ff6de696c..e4bc572bc6 100644 --- a/daft/execution/physical_plan.py +++ b/daft/execution/physical_plan.py @@ -13,13 +13,12 @@ from __future__ import annotations +import logging import math import pathlib from collections import deque from typing import Generator, Iterator, TypeVar, Union -from loguru import logger - from daft.daft import ( FileFormat, FileFormatConfig, @@ -40,6 +39,8 @@ from daft.logical.schema import Schema from daft.runners.partitioning import PartialPartitionMetadata +logger = logging.getLogger(__name__) + PartitionT = TypeVar("PartitionT") T = TypeVar("T") @@ -123,7 +124,7 @@ def file_read( except StopIteration: if len(materializations) > 0: - logger.debug("file_read blocked on completion of first source in: {sources}", sources=materializations) + logger.debug(f"file_read blocked on completion of first source in: {materializations}") yield None else: return @@ -231,10 +232,8 @@ def join( if len(left_requests) + len(right_requests) > 0: logger.debug( "join blocked on completion of sources.\n" - "Left sources: {left_requests}\n" - "Right sources: {right_requests}", - left_requests=left_requests, - right_requests=right_requests, + f"Left sources: {left_requests}\n" + f"Right sources: {right_requests}", ) yield None @@ -339,7 +338,7 @@ def global_limit( # (Optimization. If we are doing limit(0) and already have a partition executing to use for it, just wait.) 
if remaining_rows == 0 and len(materializations) > 0: - logger.debug("global_limit blocked on completion of: {source}", source=materializations[0]) + logger.debug(f"global_limit blocked on completion of: {materializations[0]}") yield None continue @@ -364,9 +363,7 @@ def global_limit( except StopIteration: if len(materializations) > 0: - logger.debug( - "global_limit blocked on completion of first source in: {sources}", sources=materializations - ) + logger.debug(f"global_limit blocked on completion of first source in: {materializations}") yield None else: return @@ -396,9 +393,7 @@ def flatten_plan(child_plan: InProgressPhysicalPlan[PartitionT]) -> InProgressPh except StopIteration: if len(materializations) > 0: - logger.debug( - "flatten_plan blocked on completion of first source in: {sources}", sources=materializations - ) + logger.debug(f"flatten_plan blocked on completion of first source in: {materializations}") yield None else: return @@ -427,7 +422,7 @@ def split( yield step while any(not _.done() for _ in materializations): - logger.debug("split_to blocked on completion of all sources: {sources}", sources=materializations) + logger.debug(f"split_to blocked on completion of all sources: {materializations}") yield None splits_per_partition = deque([1 for _ in materializations]) @@ -517,7 +512,7 @@ def coalesce( except StopIteration: if len(materializations) > 0: - logger.debug("coalesce blocked on completion of a task in: {sources}", sources=materializations) + logger.debug(f"coalesce blocked on completion of a task in: {materializations}") yield None else: return @@ -547,7 +542,7 @@ def reduce( # All fanouts dispatched. Wait for all of them to materialize # (since we need all of them to emit even a single reduce). while any(not _.done() for _ in materializations): - logger.debug("reduce blocked on completion of all sources in: {sources}", sources=materializations) + logger.debug(f"reduce blocked on completion of all sources in: {materializations}") yield None inputs_to_reduce = [deque(_.partitions()) for _ in materializations] @@ -587,7 +582,7 @@ def sort( sample_materializations: deque[SingleOutputPartitionTask[PartitionT]] = deque() for source in source_materializations: while not source.done(): - logger.debug("sort blocked on completion of source: {source}", source=source) + logger.debug(f"sort blocked on completion of source: {source}") yield None sample = ( @@ -606,7 +601,7 @@ def sort( # Wait for samples to materialize. while any(not _.done() for _ in sample_materializations): - logger.debug("sort blocked on completion of all samples: {samples}", samples=sample_materializations) + logger.debug(f"sort blocked on completion of all samples: {sample_materializations}") yield None # Reduce the samples to get sort boundaries. @@ -628,7 +623,7 @@ def sort( # Wait for boundaries to materialize. while not boundaries.done(): - logger.debug("sort blocked on completion of boundary partition: {boundaries}", boundaries=boundaries) + logger.debug(f"sort blocked on completion of boundary partition: {boundaries}") yield None # Create a range fanout plan. 
@@ -699,7 +694,7 @@ def materialize( except StopIteration: if len(materializations) > 0: - logger.debug("materialize blocked on completion of all sources: {sources}", sources=materializations) + logger.debug(f"materialize blocked on completion of all sources: {materializations}") yield None else: return diff --git a/daft/filesystem.py b/daft/filesystem.py index b675c45ee4..d23747fb88 100644 --- a/daft/filesystem.py +++ b/daft/filesystem.py @@ -12,12 +12,12 @@ else: from typing import Literal +import logging from typing import Any import fsspec import pyarrow as pa from fsspec.registry import get_filesystem_class -from loguru import logger from pyarrow.fs import ( FileSystem, FSSpecHandler, @@ -28,6 +28,8 @@ from daft.daft import FileFormat, FileInfos, NativeStorageConfig, StorageConfig from daft.table import Table +logger = logging.getLogger(__name__) + _CACHED_FSES: dict[str, FileSystem] = {} diff --git a/daft/internal/rule_runner.py b/daft/internal/rule_runner.py index 22688790d0..0e0cb32249 100644 --- a/daft/internal/rule_runner.py +++ b/daft/internal/rule_runner.py @@ -1,13 +1,14 @@ from __future__ import annotations +import logging from dataclasses import dataclass from typing import Generic, TypeVar -from loguru import logger - from daft.internal.rule import Rule from daft.internal.treenode import TreeNode +logger = logging.getLogger(__name__) + TreeNodeType = TypeVar("TreeNodeType", bound="TreeNode") diff --git a/daft/internal/treenode.py b/daft/internal/treenode.py index 54fb4e30a3..9de36ac300 100644 --- a/daft/internal/treenode.py +++ b/daft/internal/treenode.py @@ -1,14 +1,15 @@ from __future__ import annotations +import logging import os import typing from typing import TYPE_CHECKING, Generic, List, TypeVar, cast -from loguru import logger - if TYPE_CHECKING: from daft.internal.rule import Rule +logger = logging.getLogger(__name__) + TreeNodeType = TypeVar("TreeNodeType", bound="TreeNode") diff --git a/daft/logging.py b/daft/logging.py deleted file mode 100644 index 971849930b..0000000000 --- a/daft/logging.py +++ /dev/null @@ -1,39 +0,0 @@ -from __future__ import annotations - -import sys - - -def setup_logger() -> None: - import inspect - import logging - - from loguru import logger - from loguru._defaults import env - - logger.remove() - LOGURU_LEVEL = env("LOGURU_LEVEL", str, "INFO") - logger.add(sys.stderr, level=LOGURU_LEVEL) - - class InterceptHandler(logging.Handler): - def filter(self, record: logging.LogRecord) -> bool: - parent = super().filter(record) - return parent or record.pathname.startswith("src/") - - def emit(self, record: logging.LogRecord) -> None: - # Get corresponding Loguru level if it exists. - level: str | int - try: - level = logger.level(record.levelname).name - except ValueError: - level = record.levelno - - # Find caller from where originated the logged message. 
- frame, depth = inspect.currentframe(), 0 - while frame and (depth == 0 or frame.f_code.co_filename == logging.__file__): - frame = frame.f_back - depth += 1 - - logger.opt(depth=depth - 1, exception=record.exc_info).log(level, record.getMessage()) - - logging.getLogger().setLevel(logger.level(LOGURU_LEVEL).no) - logging.getLogger().addHandler(InterceptHandler()) diff --git a/daft/logical/optimizer.py b/daft/logical/optimizer.py index 4320d834a8..6b4db7537a 100644 --- a/daft/logical/optimizer.py +++ b/daft/logical/optimizer.py @@ -1,6 +1,6 @@ from __future__ import annotations -from loguru import logger +import logging from daft.daft import PartitionScheme, ResourceRequest from daft.expressions import ExpressionsProjection, col @@ -21,6 +21,8 @@ UnaryNode, ) +logger = logging.getLogger(__name__) + class PushDownPredicates(Rule[LogicalPlan]): """Push Filter nodes down through its children when possible - run filters early to reduce amount of data processed""" diff --git a/daft/runners/profiler.py b/daft/runners/profiler.py index 095e6425f3..0336b29de2 100644 --- a/daft/runners/profiler.py +++ b/daft/runners/profiler.py @@ -1,17 +1,19 @@ from __future__ import annotations +import logging import os from contextlib import contextmanager from typing import TYPE_CHECKING -from loguru import logger - if TYPE_CHECKING: from viztracer import VizTracer ACTIVE = False +logger = logging.getLogger(__name__) + + @contextmanager def profiler(filename: str) -> VizTracer: if int(os.environ.get("DAFT_PROFILING", 0)) == 1: diff --git a/daft/runners/pyrunner.py b/daft/runners/pyrunner.py index f291727b13..2c34ae6b1f 100644 --- a/daft/runners/pyrunner.py +++ b/daft/runners/pyrunner.py @@ -1,12 +1,12 @@ from __future__ import annotations +import logging import multiprocessing from concurrent import futures from dataclasses import dataclass from typing import TYPE_CHECKING, Iterable, Iterator import psutil -from loguru import logger from daft.daft import ( FileFormatConfig, @@ -36,6 +36,9 @@ import fsspec +logger = logging.getLogger(__name__) + + @dataclass class LocalPartitionSet(PartitionSet[Table]): _partitions: dict[PartID, Table] @@ -200,15 +203,13 @@ def _physical_plan_to_partitions(self, plan: physical_plan.MaterializedPhysicalP and next_step.resource_request.num_gpus > 0 ) ): - logger.debug( - "Running task synchronously in main thread: {next_step}", next_step=next_step - ) + logger.debug(f"Running task synchronously in main thread: {next_step}") partitions = self.build_partitions(next_step.instructions, *next_step.inputs) next_step.set_result([PyMaterializedResult(partition) for partition in partitions]) else: # Submit the task for execution. 
- logger.debug("Submitting task for execution: {next_step}", next_step=next_step) + logger.debug(f"Submitting task for execution: {next_step}") future = thread_pool.submit( self.build_partitions, next_step.instructions, *next_step.inputs ) @@ -230,9 +231,7 @@ def _physical_plan_to_partitions(self, plan: physical_plan.MaterializedPhysicalP done_task = inflight_tasks.pop(done_id) partitions = done_future.result() - logger.debug( - "Task completed: {done_id} -> {partitions}", done_id=done_id, partitions=partitions - ) + logger.debug(f"Task completed: {done_id} -> {partitions}") done_task.set_result([PyMaterializedResult(partition) for partition in partitions]) if next_step is None: diff --git a/daft/runners/ray_runner.py b/daft/runners/ray_runner.py index a0aa7dc31e..c4b97a1738 100644 --- a/daft/runners/ray_runner.py +++ b/daft/runners/ray_runner.py @@ -1,5 +1,6 @@ from __future__ import annotations +import logging import threading import time import uuid @@ -10,11 +11,12 @@ from typing import TYPE_CHECKING, Any, Generator, Iterable, Iterator import pyarrow as pa -from loguru import logger from daft.logical.builder import LogicalPlanBuilder from daft.planner import PhysicalPlanScheduler +logger = logging.getLogger(__name__) + try: import ray except ImportError: @@ -452,8 +454,6 @@ def _run_plan( psets: dict[str, ray.ObjectRef], result_uuid: str, ) -> None: - from loguru import logger - # Get executable tasks from plan scheduler. tasks = plan_scheduler.to_partition_tasks(psets, is_ray_runner=True) @@ -495,9 +495,7 @@ def _run_plan( # If it is a no-op task, just run it locally immediately. elif len(next_step.instructions) == 0: - logger.debug( - "Running task synchronously in main thread: {next_step}", next_step=next_step - ) + logger.debug(f"Running task synchronously in main thread: {next_step}") assert isinstance(next_step, SingleOutputPartitionTask) next_step.set_result( [RayMaterializedResult(partition) for partition in next_step.inputs] diff --git a/daft/table/table.py b/daft/table/table.py index fb7366bfb8..ec865bd05d 100644 --- a/daft/table/table.py +++ b/daft/table/table.py @@ -1,10 +1,9 @@ from __future__ import annotations -import sys +import logging from typing import TYPE_CHECKING, Any import pyarrow as pa -from loguru import logger from daft.arrow_utils import ensure_table from daft.daft import JoinType @@ -20,11 +19,6 @@ from daft.logical.schema import Schema from daft.series import Series -if sys.version_info < (3, 8): - pass -else: - pass - _NUMPY_AVAILABLE = True try: import numpy as np @@ -45,6 +39,9 @@ from daft.io import IOConfig +logger = logging.getLogger(__name__) + + class Table: _table: _PyTable diff --git a/daft/udf_library/url_udfs.py b/daft/udf_library/url_udfs.py index 989b87f016..b5a9c682d7 100644 --- a/daft/udf_library/url_udfs.py +++ b/daft/udf_library/url_udfs.py @@ -1,5 +1,6 @@ from __future__ import annotations +import logging import sys import threading from concurrent.futures import ThreadPoolExecutor, as_completed @@ -19,6 +20,9 @@ thread_local = threading.local() +logger = logging.getLogger(__name__) + + def _worker_thread_initializer() -> None: """Initializes per-thread local state""" thread_local.filesystems_cache = {} @@ -27,8 +31,6 @@ def _worker_thread_initializer() -> None: def _download( path: str | None, on_error: Literal["raise"] | Literal["null"], fs: fsspec.AbstractFileSystem | None ) -> bytes | None: - from loguru import logger - if path is None: return None protocol = filesystem.get_protocol_from_path(path) diff --git a/pyproject.toml 
b/pyproject.toml
index 56f5c604ef..8ddc1e9fcc 100644
--- a/pyproject.toml
+++ b/pyproject.toml
@@ -7,7 +7,6 @@ authors = [{name = "Eventual Inc", email = "daft@eventualcomputing.com"}]
 dependencies = [
   "pyarrow >= 6.0.1",
   "fsspec[http]",
-  "loguru",
   "psutil",
   "typing-extensions >= 4.0.0; python_version < '3.8'",
   "pickle5 >= 0.0.12; python_version < '3.8'"
diff --git a/src/daft-table/src/lib.rs b/src/daft-table/src/lib.rs
index 98f2e9da54..1f4df77e47 100644
--- a/src/daft-table/src/lib.rs
+++ b/src/daft-table/src/lib.rs
@@ -498,7 +498,14 @@ impl Table {
             let mut str_val = s.str_value(i).unwrap();
             if let Some(max_col_width) = max_col_width {
                 if str_val.len() > max_col_width {
-                    str_val = format!("{}...", &str_val[..max_col_width - 3]);
+                    str_val = format!(
+                        "{}...",
+                        &str_val
+                            .char_indices()
+                            .take(max_col_width - 3)
+                            .map(|(_, c)| c)
+                            .collect::<String>()
+                    );
                 }
             }
             str_val