diff --git a/CHANGES.rst b/CHANGES.rst
index 3fe583ef..691c2133 100644
--- a/CHANGES.rst
+++ b/CHANGES.rst
@@ -22,6 +22,19 @@ Version 3.X.X (2019-09-XX)
 - Fix an issue where an empty dataframe of a partition in a multi-table dataset
   would raise a schema validation exception
 - Remove support for pyarrow < 0.13.0
+- Fix an issue where the `dispatch_by` keyword would disable partition pruning
+- Add additional functions in the `kartothek.serialization` module for dealing with predicates
+  * :func:`~kartothek.serialization.check_predicates`
+  * :func:`~kartothek.serialization.filter_predicates_by_column`
+  * :func:`~kartothek.serialization.columns_in_predicates`
+- Add types for type annotations when dealing with predicates
+  * `~kartothek.serialization.PredicatesType`
+  * `~kartothek.serialization.ConjunctionType`
+  * `~kartothek.serialization.LiteralType`
+
+Internal changes
+^^^^^^^^^^^^^^^^
+- Move the docs module from `io_components` to `core`


 Version 3.4.0 (2019-09-17)
diff --git a/kartothek/core/dataset.py b/kartothek/core/dataset.py
index 26eda225..f223d7ba 100644
--- a/kartothek/core/dataset.py
+++ b/kartothek/core/dataset.py
@@ -15,6 +15,7 @@
 from kartothek.core._compat import load_json
 from kartothek.core._mixins import CopyMixin
 from kartothek.core.common_metadata import SchemaWrapper, read_schema_metadata
+from kartothek.core.docs import default_docs
 from kartothek.core.index import (
     ExplicitSecondaryIndex,
     IndexBase,
@@ -25,9 +26,12 @@
 from kartothek.core.partition import Partition
 from kartothek.core.urlencode import decode_key, quote_indices
 from kartothek.core.utils import verify_metadata_version
+from kartothek.serialization import PredicatesType, columns_in_predicates

 _logger = logging.getLogger(__name__)

+TableMetaType = Dict[str, SchemaWrapper]
+

 def _validate_uuid(uuid: str) -> bool:
     return re.match(r"[a-zA-Z0-9+\-_]+$", uuid) is not None
@@ -340,14 +344,19 @@ def load_partition_indices(self) -> "DatasetMetadataBase":
             partitions=self.partitions,
             table_meta=self.table_meta,
             default_dtype=pa.string() if self.metadata_version == 3 else None,
+            partition_keys=self.partition_keys,
         )
         combined_indices = self.indices.copy()
         combined_indices.update(indices)
         return self.copy(indices=combined_indices)

+    @default_docs
     def get_indices_as_dataframe(
-        self, columns: Optional[List[str]] = None, date_as_object: bool = True
-    ) -> pd.DataFrame:
+        self,
+        columns: Optional[List[str]] = None,
+        date_as_object: bool = True,
+        predicates: PredicatesType = None,
+    ):
         """
         Converts the dataset indices to a pandas dataframe.
@@ -363,18 +372,22 @@
         Parameters
         ----------
-        columns: list of str
-            If provided, the dataframe will only be constructed for the provided columns/indices.
-            If `None` is given, all indices are included.
-        date_as_object: bool, optional
-            Cast dates to objects.
""" if columns is None: columns = sorted(self.indices.keys()) + elif columns == []: + return pd.DataFrame(index=self.partitions) - result = None dfs = [] - for col in columns: + columns_to_scan = columns[:] + if predicates: + predicate_columns = columns_in_predicates(predicates) + # Don't use set logic to preserve order + for col in predicate_columns: + if col not in columns_to_scan and col in self.indices: + columns_to_scan.append(col) + + for col in columns_to_scan: if col not in self.indices: if col in self.partition_keys: raise RuntimeError( @@ -383,19 +396,30 @@ def get_indices_as_dataframe( raise ValueError("Index `{}` unknown.") df = pd.DataFrame( self.indices[col].as_flat_series( - partitions_as_index=True, date_as_object=date_as_object + partitions_as_index=True, + date_as_object=date_as_object, + predicates=predicates, ) ) dfs.append(df) # start joining with the small ones - for df in sorted(dfs, key=lambda df: len(df)): - if result is None: - result = df - continue + sorted_dfs = sorted(dfs, key=lambda df: len(df)) + result = sorted_dfs.pop(0) + for df in sorted_dfs: result = result.merge(df, left_index=True, right_index=True, copy=False) - return result + if predicates: + index_name = result.index.name + result = ( + result.loc[:, columns] + .reset_index() + .drop_duplicates() + .set_index(index_name) + ) + return result + else: + return result class DatasetMetadata(DatasetMetadataBase): @@ -618,13 +642,24 @@ def _get_type_from_meta( ) +def _empty_partition_indices( + partition_keys: List[str], table_meta: TableMetaType, default_dtype: pa.DataType +): + indices = {} + for col in partition_keys: + arrow_type = _get_type_from_meta(table_meta, col, default_dtype) + indices[col] = PartitionIndex(column=col, index_dct={}, dtype=arrow_type) + return indices + + def _construct_dynamic_index_from_partitions( partitions: Dict[str, Partition], - table_meta: Dict[str, SchemaWrapper], + table_meta: TableMetaType, default_dtype: pa.DataType, + partition_keys: List[str], ) -> Dict[str, PartitionIndex]: if len(partitions) == 0: - return {} + return _empty_partition_indices(partition_keys, table_meta, default_dtype) def _get_files(part): if isinstance(part, dict): @@ -638,7 +673,7 @@ def _get_files(part): ) # partitions is NOT empty here, see check above first_partition_files = _get_files(first_partition) if not first_partition_files: - return {} + return _empty_partition_indices(partition_keys, table_meta, default_dtype) key_table = next(iter(first_partition_files.keys())) storage_keys = ( (key, _get_files(part)[key_table]) for key, part in partitions.items() diff --git a/kartothek/io_components/docs.py b/kartothek/core/docs.py similarity index 100% rename from kartothek/io_components/docs.py rename to kartothek/core/docs.py diff --git a/kartothek/core/index.py b/kartothek/core/index.py index 42129ac6..127d836f 100644 --- a/kartothek/core/index.py +++ b/kartothek/core/index.py @@ -15,8 +15,15 @@ from kartothek.core._compat import ARROW_LARGER_EQ_0150 from kartothek.core._mixins import CopyMixin from kartothek.core.common_metadata import normalize_type +from kartothek.core.docs import default_docs from kartothek.core.urlencode import quote -from kartothek.serialization import filter_array_like +from kartothek.serialization import ( + PredicatesType, + check_predicates, + filter_array_like, + filter_df_from_predicates, + filter_predicates_by_column, +) from kartothek.serialization._parquet import _fix_pyarrow_07992_table ValueType = TypeVar("ValueType") @@ -421,25 +428,50 @@ def __eq__(self, 
     def __ne__(self, other) -> bool:
         return not (self == other)

+    @default_docs
     def as_flat_series(
-        self, compact=False, partitions_as_index=False, date_as_object=True
+        self,
+        compact: bool = False,
+        partitions_as_index: bool = False,
+        date_as_object: bool = False,
+        predicates: PredicatesType = None,
     ):
         """
         Convert the Index object to a pandas.Series

         Parameters
         ----------
-        compact: bool, optional
+        compact:
             If True, the index will be unique and the Series.values will be a list of partitions/values
-        partitions_as_index: bool, optional
+        partitions_as_index:
             If True, the relation between index values and partitions will be reverted for the output
-        date_as_object: bool, optional
-            Cast dates to objects.
+        predicates:
+            A list of predicates. If a literal within the provided predicates
+            references a column which is not part of this index, this literal is
+            interpreted as True.
         """
+        check_predicates(predicates)
         table = _index_dct_to_table(
             self.index_dct, column=self.column, dtype=self.dtype
         )
         df = table.to_pandas(date_as_object=date_as_object)
+
+        if predicates is not None:
+            # If there is a conjunction without any reference to the index
+            # column, the entire predicates expression evaluates to True. In
+            # this case we do not need to filter the dataframe anymore.
+            for conjunction in predicates:
+                new_conjunction = filter_predicates_by_column(
+                    [conjunction], [self.column]
+                )
+                if new_conjunction is None:
+                    break
+            else:
+                filtered_predicates = filter_predicates_by_column(
+                    predicates, [self.column]
+                )
+                df = filter_df_from_predicates(df, predicates=filtered_predicates)
+
         result_column = _PARTITION_COLUMN_NAME
         # This is the way the dictionary is directly translated
         # value: [partition]
@@ -451,9 +483,13 @@ def as_flat_series(
         # value: part_2
         # value2: part_1
         if partitions_as_index or not compact:
-            keys = np.concatenate(df[_PARTITION_COLUMN_NAME].values)
+            if len(df) == 0:
+                keys = np.array([], dtype=df[_PARTITION_COLUMN_NAME].values.dtype)
+            else:
+                keys = np.concatenate(df[_PARTITION_COLUMN_NAME].values)
             lengths = df[_PARTITION_COLUMN_NAME].apply(len).values
+            lengths = lengths.astype(int)
             values_index = np.repeat(np.arange(len(df)), lengths)
             values = df[self.column].values[values_index]
diff --git a/kartothek/io/dask/bag.py b/kartothek/io/dask/bag.py
index 6ab86d89..bed2cc90 100644
--- a/kartothek/io/dask/bag.py
+++ b/kartothek/io/dask/bag.py
@@ -5,6 +5,7 @@
 import dask.bag as db

 from kartothek.core import naming
+from kartothek.core.docs import default_docs
 from kartothek.core.factory import _ensure_factory
 from kartothek.core.utils import _check_callable
 from kartothek.core.uuid import gen_uuid
@@ -14,7 +15,6 @@
     _identity,
     _maybe_get_categoricals_from_index,
 )
-from kartothek.io_components.docs import default_docs
 from kartothek.io_components.index import update_indices_from_partitions
 from kartothek.io_components.metapartition import (
     MetaPartition,
diff --git a/kartothek/io/dask/dataframe.py b/kartothek/io/dask/dataframe.py
index 71cf265c..bfc8c30e 100644
--- a/kartothek/io/dask/dataframe.py
+++ b/kartothek/io/dask/dataframe.py
@@ -6,9 +6,9 @@
 import numpy as np

 from kartothek.core.common_metadata import empty_dataframe_from_schema
+from kartothek.core.docs import default_docs
 from kartothek.core.factory import _ensure_factory
 from kartothek.core.naming import DEFAULT_METADATA_VERSION
-from kartothek.io_components.docs import default_docs
 from kartothek.io_components.metapartition import parse_input_to_metapartition
 from kartothek.io_components.update import update_dataset_from_partitions
 from kartothek.io_components.utils import (
diff --git a/kartothek/io/dask/delayed.py b/kartothek/io/dask/delayed.py
index 065d7717..eb684752 100644
--- a/kartothek/io/dask/delayed.py
+++ b/kartothek/io/dask/delayed.py
@@ -8,6 +8,7 @@
 from dask import delayed

 from kartothek.core import naming
+from kartothek.core.docs import default_docs
 from kartothek.core.factory import _ensure_factory
 from kartothek.core.naming import DEFAULT_METADATA_VERSION
 from kartothek.core.utils import _check_callable
@@ -17,7 +18,6 @@
     delete_indices,
     delete_top_level_metadata,
 )
-from kartothek.io_components.docs import default_docs
 from kartothek.io_components.gc import delete_files, dispatch_files_to_gc
 from kartothek.io_components.merge import align_datasets
 from kartothek.io_components.metapartition import (
diff --git a/kartothek/io/eager.py b/kartothek/io/eager.py
index 08755c27..37f3e076 100644
--- a/kartothek/io/eager.py
+++ b/kartothek/io/eager.py
@@ -11,6 +11,7 @@
     store_schema_metadata,
 )
 from kartothek.core.dataset import DatasetMetadataBuilder
+from kartothek.core.docs import default_docs
 from kartothek.core.factory import _ensure_factory
 from kartothek.core.naming import (
     DEFAULT_METADATA_STORAGE_FORMAT,
@@ -24,7 +25,6 @@
     delete_indices,
     delete_top_level_metadata,
 )
-from kartothek.io_components.docs import default_docs
 from kartothek.io_components.gc import delete_files, dispatch_files_to_gc
 from kartothek.io_components.index import update_indices_from_partitions
 from kartothek.io_components.metapartition import (
diff --git a/kartothek/io/iter.py b/kartothek/io/iter.py
index cbee7280..b5890111 100644
--- a/kartothek/io/iter.py
+++ b/kartothek/io/iter.py
@@ -2,13 +2,13 @@
 from functools import partial
 from typing import cast

+from kartothek.core.docs import default_docs
 from kartothek.core.factory import _ensure_factory
 from kartothek.core.naming import (
     DEFAULT_METADATA_STORAGE_FORMAT,
     DEFAULT_METADATA_VERSION,
 )
 from kartothek.core.uuid import gen_uuid
-from kartothek.io_components.docs import default_docs
 from kartothek.io_components.metapartition import (
     MetaPartition,
     parse_input_to_metapartition,
diff --git a/kartothek/io/testing/read.py b/kartothek/io/testing/read.py
index d56c815d..d9a3a5ab 100644
--- a/kartothek/io/testing/read.py
+++ b/kartothek/io/testing/read.py
@@ -372,40 +372,22 @@ def test_read_dataset_as_dataframes_concat_primary(

 @pytest.mark.parametrize("dispatch_by", ["A", "B", "C"])
 def test_read_dataset_as_dataframes_dispatch_by_single_col(
-    store_factory,
+    store_session_factory,
+    dataset_dispatch_by,
     bound_load_dataframes,
     backend_identifier,
     dispatch_by,
     output_type,
     metadata_version,
+    dataset_dispatch_by_uuid,
 ):
     if output_type == "table":
         pytest.skip()
-    cluster1 = pd.DataFrame(
-        {"A": [1, 1], "B": [10, 10], "C": [1, 2], "Content": ["cluster1", "cluster1"]}
-    )
-    cluster2 = pd.DataFrame(
-        {"A": [1, 1], "B": [10, 10], "C": [2, 3], "Content": ["cluster2", "cluster2"]}
-    )
-    cluster3 = pd.DataFrame({"A": [1], "B": [20], "C": [1], "Content": ["cluster3"]})
-    cluster4 = pd.DataFrame(
-        {"A": [2, 2], "B": [10, 10], "C": [1, 2], "Content": ["cluster4", "cluster4"]}
-    )
-    clusters = [cluster1, cluster2, cluster3, cluster4]
-    partitions = [{"data": [("data", c)]} for c in clusters]
-
-    store_dataframes_as_dataset__iter(
-        df_generator=partitions,
-        store=store_factory,
-        dataset_uuid="partitioned_uuid",
-        metadata_version=metadata_version,
-        partition_on=["A", "B"],
-        secondary_indices=["C"],
-    )
-    # Dispatch by primary index "A"
     dispatched_a = bound_load_dataframes(
-        dataset_uuid="partitioned_uuid", store=store_factory, dispatch_by=[dispatch_by]
+        dataset_uuid=dataset_dispatch_by_uuid,
+        store=store_session_factory,
+        dispatch_by=[dispatch_by],
     )

     unique_a = set()
@@ -420,15 +402,17 @@ def test_read_dataset_as_dataframes_dispatch_by_single_col(
         unique_a.add(unique_dispatch[0])


-def test_read_dataset_as_dataframes_dispatch_by_multi_col(
-    store_factory,
-    bound_load_dataframes,
-    backend_identifier,
-    output_type,
-    metadata_version,
+@pytest.fixture(scope="session")
+def dataset_dispatch_by_uuid():
+    import uuid
+
+    return uuid.uuid1().hex
+
+
+@pytest.fixture(scope="session")
+def dataset_dispatch_by(
+    metadata_version, store_session_factory, dataset_dispatch_by_uuid
 ):
-    if output_type == "table":
-        pytest.skip()
     cluster1 = pd.DataFrame(
         {"A": [1, 1], "B": [10, 10], "C": [1, 2], "Content": ["cluster1", "cluster1"]}
     )
@@ -440,20 +424,33 @@
         {"A": [2, 2], "B": [10, 10], "C": [1, 2], "Content": ["cluster4", "cluster4"]}
     )
     clusters = [cluster1, cluster2, cluster3, cluster4]
+    partitions = [{"data": [("data", c)]} for c in clusters]

     store_dataframes_as_dataset__iter(
         df_generator=partitions,
-        store=store_factory,
-        dataset_uuid="partitioned_uuid",
+        store=store_session_factory,
+        dataset_uuid=dataset_dispatch_by_uuid,
         metadata_version=metadata_version,
         partition_on=["A", "B"],
         secondary_indices=["C"],
     )
+    return pd.concat(clusters).sort_values(["A", "B", "C"]).reset_index(drop=True)
+
+
+def test_read_dataset_as_dataframes_dispatch_by_multi_col(
+    store_session_factory,
+    bound_load_dataframes,
+    output_type,
+    dataset_dispatch_by,
+    dataset_dispatch_by_uuid,
+):
+    if output_type == "table":
+        pytest.skip()
     for dispatch_by in permutations(("A", "B", "C"), 2):
         dispatched = bound_load_dataframes(
-            dataset_uuid="partitioned_uuid",
-            store=store_factory,
+            dataset_uuid=dataset_dispatch_by_uuid,
+            store=store_session_factory,
             dispatch_by=dispatch_by,
         )
         uniques = pd.DataFrame(columns=dispatch_by)
@@ -469,6 +466,43 @@
     assert not any(uniques.duplicated())


+@pytest.mark.parametrize(
+    "dispatch_by, predicates, expected_dispatches",
+    [
+        # This should only dispatch one partition since there is only
+        # one file with valid data points
+        (["A"], [[("C", ">", 2)]], 1),
+        # We dispatch and restrict to one value, i.e. one dispatch
+        (["B"], [[("B", "==", 10)]], 1),
+        # The same is true for a non-partition index col
+        (["C"], [[("C", "==", 1)]], 1),
+        # A condition where both primary and secondary indices need to work together
+        (["A", "C"], [[("A", ">", 1), ("C", "<", 3)]], 2),
+    ],
+)
+def test_read_dispatch_by_with_predicates(
+    store_session_factory,
+    dataset_dispatch_by_uuid,
+    bound_load_dataframes,
+    dataset_dispatch_by,
+    dispatch_by,
+    output_type,
+    expected_dispatches,
+    predicates,
+):
+    if output_type == "table":
+        pytest.skip()
+
+    dispatched = bound_load_dataframes(
+        dataset_uuid=dataset_dispatch_by_uuid,
+        store=store_session_factory,
+        dispatch_by=dispatch_by,
+        predicates=predicates,
+    )
+
+    assert len(dispatched) == expected_dispatches, dispatched
+
+
 def test_read_dataset_as_dataframes(
     dataset,
     store_session_factory,
diff --git a/kartothek/io_components/metapartition.py b/kartothek/io_components/metapartition.py
index 8516ee38..52db167d 100644
--- a/kartothek/io_components/metapartition.py
+++ b/kartothek/io_components/metapartition.py
@@ -25,6 +25,7 @@
     validate_compatible,
     validate_shared_columns,
 )
+from kartothek.core.docs import default_docs
 from kartothek.core.index import ExplicitSecondaryIndex, IndexBase
 from kartothek.core.index import merge_indices as merge_indices_algo
 from kartothek.core.naming import get_partition_file_prefix
@@ -32,7 +33,6 @@
 from kartothek.core.urlencode import decode_key, quote_indices
 from kartothek.core.utils import ensure_string_type, verify_metadata_version
 from kartothek.core.uuid import gen_uuid
-from kartothek.io_components.docs import default_docs
 from kartothek.io_components.utils import _instantiate_store, combine_metadata
 from kartothek.serialization import (
     DataFrameSerializer,
diff --git a/kartothek/io_components/read.py b/kartothek/io_components/read.py
index 23e7146f..50b4f9c3 100644
--- a/kartothek/io_components/read.py
+++ b/kartothek/io_components/read.py
@@ -1,5 +1,5 @@
 import warnings
-from typing import Iterator, List, Union, cast
+from typing import Iterator, List, Set, Union, cast

 import pandas as pd

@@ -7,20 +7,7 @@
 from kartothek.core.index import ExplicitSecondaryIndex
 from kartothek.io_components.metapartition import MetaPartition
 from kartothek.io_components.utils import _make_callable
-
-
-def _index_to_dataframe(idx_name, idx, allowed_labels=None):
-    label_col = []
-    value_col = []
-    for value, labels in idx.items():
-        for label in labels:
-            if allowed_labels is not None and label not in allowed_labels:
-                continue
-            label_col.append(label)
-            value_col.append(value)
-    df = pd.DataFrame({idx_name: value_col, "__partition__": label_col})
-
-    return df
+from kartothek.serialization import check_predicates, columns_in_predicates


 def dispatch_metapartitions_from_factory(
@@ -55,13 +42,27 @@
         raise RuntimeError(
             f"Dispatch columns must be indexed.\nRequested index: {dispatch_by} but available index columns: {sorted(dataset_factory.index_columns)}"
         )
+    check_predicates(predicates)

-    if predicates is not None:
-        dataset_factory, allowed_labels = _allowed_labels_by_predicates(
-            predicates, dataset_factory, dispatch_by
-        )
-    else:
-        allowed_labels = None
+    # Determine which indices need to be loaded.
+    index_cols: Set[str] = set()
+    if dispatch_by:
+        index_cols |= set(dispatch_by)
+
+    if predicates:
+        predicate_cols = set(columns_in_predicates(predicates))
+        predicate_index_cols = predicate_cols & set(dataset_factory.index_columns)
+        index_cols |= predicate_index_cols
+
+    for col in index_cols:
+        dataset_factory.load_index(col)
+
+    base_df = dataset_factory.get_indices_as_dataframe(
+        list(index_cols), predicates=predicates
+    )
+
+    if label_filter:
+        base_df = base_df[base_df.index.map(label_filter)]

     indices_to_dispatch = {
         name: ix.copy(index_dct={})
@@ -70,18 +71,6 @@
     }

     if dispatch_by:
-        # Build up a DataFrame that contains per row a Partition and its primary index columns.
-        base_df = None
-        for part_key in dispatch_by:
-            dataset_factory.load_index(part_key)
-            idx = dataset_factory.indices[part_key].index_dct
-            df = _index_to_dataframe(part_key, idx, allowed_labels)
-            if base_df is None:
-                base_df = df
-            else:
-                base_df = base_df.merge(df, on=["__partition__"])
-
-        assert base_df is not None
         base_df = cast(pd.DataFrame, base_df)

         # Group the resulting MetaPartitions by partition keys or a subset of those keys
@@ -95,7 +84,7 @@
             logical_conjunction = list(
                 zip(dispatch_by, ["=="] * len(dispatch_by), group_name)
             )
-            for label in group.__partition__:
+            for label in group.index:
                 mps.append(
                     MetaPartition.from_partition(
                         partition=dataset_factory.partitions[label],
@@ -109,18 +98,7 @@
             )
             yield mps
     else:
-
-        if allowed_labels is not None:
-            partition_labels = allowed_labels
-        else:
-            partition_labels = dataset_factory.partitions.keys()
-
-        for part_label in partition_labels:
-
-            if label_filter is not None:
-                if not label_filter(part_label):
-                    continue
-
+        for part_label in base_df.index:
             part = dataset_factory.partitions[part_label]

             yield MetaPartition.from_partition(
@@ -133,76 +111,6 @@
             )


-def _allowed_labels_by_predicates(predicates, dataset_factory, dispatch_by):
-    if len(predicates) == 0:
-        raise ValueError("The behaviour on an empty list of predicates is undefined")
-
-    dataset_factory = dataset_factory.load_partition_indices()
-
-    # Determine the set of columns that are part of a predicate
-    columns = set()
-    for predicates_inner in predicates:
-        if len(predicates_inner) == 0:
-            raise ValueError("The behaviour on an empty predicate is undefined")
-        for col, _, _ in predicates_inner:
-            columns.add(col)
-
-    # Load the necessary indices
-    for column in columns:
-        if column in dataset_factory.indices:
-            dataset_factory = dataset_factory.load_index(column)
-
-    # Narrow down predicates to the columns that have an index.
-    # The remaining parts of the predicate are filtered during
-    # load_dataframes.
-    filtered_predicates = []
-    for predicate in predicates:
-        new_predicate = []
-        for col, op, val in predicate:
-            if col in dataset_factory.indices:
-                new_predicate.append((col, op, val))
-        filtered_predicates.append(new_predicate)
-
-    # In the case that any of the above filters produced an empty predicate,
-    # we have to load the full dataset as we cannot prefilter on the indices.
-    has_catchall = any(((len(predicate) == 0) for predicate in filtered_predicates))
-
-    # None is a sentinel value for "no predicates"
-    allowed_labels = None
-    if filtered_predicates and not has_catchall:
-        allowed_labels = set()
-        for conjunction in filtered_predicates:
-            allowed_labels |= _allowed_labels_by_conjunction(
-                conjunction, dataset_factory.indices
-            )
-    return dataset_factory, allowed_labels
-
-
-def _allowed_labels_by_conjunction(conjunction, indices):
-    """
-    Returns all partition labels which are allowed by the given conjunction (AND)
-    of literals based on the indices
-
-    Parameters
-    ----------
-    conjunction: list of tuple
-        A list of (column, operator, value) tuples
-    indices: dict
-        A dict column->kartothek.core.index.IndexBase holding the indices to be evaluated
-    Returns
-    -------
-    set: allowed labels
-    """
-    allowed_by_conjunction = None
-    for col, op, val in conjunction:
-        allowed_labels = indices[col].eval_operator(op, val)
-        if allowed_by_conjunction is not None:
-            allowed_by_conjunction &= allowed_labels
-        else:
-            allowed_by_conjunction = allowed_labels
-    return allowed_by_conjunction
-
-
 def dispatch_metapartitions(
     dataset_uuid,
     store,
diff --git a/kartothek/serialization/__init__.py b/kartothek/serialization/__init__.py
index 0876956b..1db43bea 100644
--- a/kartothek/serialization/__init__.py
+++ b/kartothek/serialization/__init__.py
@@ -2,10 +2,17 @@
 from ._csv import CsvSerializer
 from ._generic import (
+    ConjunctionType,
     DataFrameSerializer,
+    LiteralType,
+    LiteralValue,
+    PredicatesType,
+    check_predicates,
+    columns_in_predicates,
     filter_array_like,
     filter_df,
     filter_df_from_predicates,
+    filter_predicates_by_column,
 )
 from ._parquet import ParquetSerializer

@@ -25,11 +32,21 @@ def default_serializer():


 __all__ = [
-    "DataFrameSerializer",
+    # Serializer classes
     "CsvSerializer",
+    "DataFrameSerializer",
     "ParquetSerializer",
     "default_serializer",
-    "filter_df",
+    # functions
+    "check_predicates",
+    "columns_in_predicates",
     "filter_array_like",
     "filter_df_from_predicates",
+    "filter_df",
+    "filter_predicates_by_column",
+    # types
+    "ConjunctionType",
+    "LiteralType",
+    "LiteralValue",
+    "PredicatesType",
 ]
diff --git a/kartothek/serialization/_generic.py b/kartothek/serialization/_generic.py
index ebd8dd5a..66aa8193 100644
--- a/kartothek/serialization/_generic.py
+++ b/kartothek/serialization/_generic.py
@@ -2,9 +2,17 @@
 # -*- coding: utf-8 -*-
 """
 This module contains functionality for persisting/serialising DataFrames.
+
+Available constants
+
+**PredicatesType** - A type describing the format of predicates, which is a list of ConjunctionType
+**ConjunctionType** - A type describing a single conjunction, which is a list of literals
+**LiteralType** - A type for a single literal
+
+**LiteralValue** - A type indicating the value of a predicate literal
 """

-from typing import Dict
+from typing import Dict, List, Optional, Set, Tuple, TypeVar

 import numpy as np
 import pandas as pd
@@ -14,6 +22,13 @@
 from ._util import ensure_unicode_string_type

+LiteralValue = TypeVar("LiteralValue")
+LiteralType = Tuple[str, str, LiteralValue]
+ConjunctionType = List[LiteralType]
+# Optional is part of the actual type since predicates=None
+# is a sentinel for "all values"
+PredicatesType = Optional[List[ConjunctionType]]
+

 class DataFrameSerializer:
     """
@@ -167,7 +182,78 @@ def check_predicates(predicates):
     )


-def filter_df_from_predicates(df, predicates, strict_date_types=False):
+def filter_predicates_by_column(
+    predicates: PredicatesType, columns: List[str]
+) -> Optional[PredicatesType]:
+    """
+    Takes a predicate list and removes all literals which are not referencing one of the given columns
+
+    .. ipython:: python
+
+        from kartothek.serialization import filter_predicates_by_column
+        predicates = [
+            [
+                ("A", "==", 1),
+                ("B", "<", 5)
+            ],
+            [
+                ("C", "==", 4)
+            ]
+        ]
+
+        filter_predicates_by_column(
+            predicates, ["A"]
+        )
+
+    Parameters
+    ----------
+    predicates:
+        A list of predicates to be filtered
+    columns:
+        A list of all columns allowed in the output
+    """
+    if predicates is None:
+        return None
+    check_predicates(predicates)
+    filtered_predicates = []
+    for predicate in predicates:
+        new_conjunction = []
+        for col, op, val in predicate:
+            if col in columns:
+                new_conjunction.append((col, op, val))
+        if new_conjunction:
+            filtered_predicates.append(new_conjunction)
+    if filtered_predicates:
+        return filtered_predicates
+    else:
+        return None
+
+
+def columns_in_predicates(predicates: PredicatesType) -> Set[str]:
+    """
+    Determine all columns which are mentioned in the list of predicates.
+
+    Parameters
+    ----------
+    predicates:
+        The predicates to be scanned.
+    """
+    if predicates is None:
+        return set()
+    check_predicates(predicates)
+    # Determine the set of columns that are part of a predicate
+    columns = set()
+    for predicates_inner in predicates:
+        for col, _, _ in predicates_inner:
+            columns.add(col)
+    return columns
+
+
+def filter_df_from_predicates(
+    df: pd.DataFrame,
+    predicates: Optional[PredicatesType],
+    strict_date_types: bool = False,
+) -> pd.DataFrame:
     """
     Filter a `pandas.DataFrame` based on predicates in disjunctive normal form.

@@ -178,6 +264,7 @@
     predicates: list of lists
         Predicates in disjunctive normal form (DNF). For a thorough documentation, see
        :class:`DataFrameSerializer.restore_dataframe`
+        If None, the df is returned unmodified
     strict_date_types: bool
         If False (default), cast all datelike values to datetime64 for comparison.
@@ -185,6 +272,8 @@
     -------
     pd.DataFrame
     """
+    if predicates is None:
+        return df
     indexer = np.zeros(len(df), dtype=bool)
     for conjunction in predicates:
         inner_indexer = np.ones(len(df), dtype=bool)
diff --git a/tests/io_components/test_docs.py b/tests/core/test_docs.py
similarity index 97%
rename from tests/io_components/test_docs.py
rename to tests/core/test_docs.py
index 08ab62cf..bf329af2 100644
--- a/tests/io_components/test_docs.py
+++ b/tests/core/test_docs.py
@@ -2,6 +2,7 @@

 import pytest

+from kartothek.core.docs import _PARAMETER_MAPPING
 from kartothek.io.dask.bag import (
     build_dataset_indices__bag,
     read_dataset_as_dataframe_bag,
@@ -37,7 +38,6 @@
     store_dataframes_as_dataset__iter,
     update_dataset_from_dataframes__iter,
 )
-from kartothek.io_components.docs import _PARAMETER_MAPPING


 @pytest.mark.parametrize(
diff --git a/tests/io_components/test_dataset.py b/tests/io_components/test_dataset.py
index de486c37..807e0a00 100644
--- a/tests/io_components/test_dataset.py
+++ b/tests/io_components/test_dataset.py
@@ -67,3 +67,53 @@ def test_dataset_get_indices_as_dataframe_duplicates():
     )
     result = ds.get_indices_as_dataframe()
     pdt.assert_frame_equal(result, expected)
+
+
+def test_dataset_get_indices_as_dataframe_predicates():
+    ds = DatasetMetadata(
+        "some_uuid",
+        indices={
+            "l_external_code": ExplicitSecondaryIndex(
+                "l_external_code", {"1": ["part1", "part2"], "2": ["part1", "part2"]}
+            ),
+            "p_external_code": ExplicitSecondaryIndex(
+                "p_external_code", {"1": ["part1"], "2": ["part2"]}
+            ),
+        },
+    )
+    expected = pd.DataFrame(
+        OrderedDict([("p_external_code", ["1"])]),
+        index=pd.Index(["part1"], name="partition"),
+    )
+    result = ds.get_indices_as_dataframe(
+        columns=["p_external_code"], predicates=[[("p_external_code", "==", "1")]]
+    )
+    pdt.assert_frame_equal(result, expected)
+
+    result = ds.get_indices_as_dataframe(
+        columns=["l_external_code"], predicates=[[("l_external_code", "==", "1")]]
+    )
+    expected = pd.DataFrame(
+        OrderedDict([("l_external_code", "1")]),
+        index=pd.Index(["part1", "part2"], name="partition"),
+    )
+    pdt.assert_frame_equal(result, expected)
+
+    result = ds.get_indices_as_dataframe(
+        columns=["l_external_code"],
+        predicates=[[("l_external_code", "==", "1"), ("p_external_code", "==", "1")]],
+    )
+    expected = pd.DataFrame(
+        OrderedDict([("l_external_code", "1")]),
+        index=pd.Index(["part1"], name="partition"),
+    )
+    pdt.assert_frame_equal(result, expected)
+
+    result = ds.get_indices_as_dataframe(
+        columns=["l_external_code"],
+        predicates=[[("l_external_code", "==", "1"), ("p_external_code", "==", "3")]],
+    )
+    expected = pd.DataFrame(
+        columns=["l_external_code"], index=pd.Index([], name="partition")
+    )
+    pdt.assert_frame_equal(result, expected)
diff --git a/tests/io_components/test_read.py b/tests/io_components/test_read.py
index 576fc3fe..948d030c 100644
--- a/tests/io_components/test_read.py
+++ b/tests/io_components/test_read.py
@@ -63,11 +63,10 @@ def test_dispatch_metapartitions_without_dataset_metadata(dataset, store_session

 @pytest.mark.parametrize("predicates", [[], [[]]])
 def test_dispatch_metapartition_undefined_behaviour(dataset, store_session, predicates):
-    with pytest.raises(ValueError) as exc:
+    with pytest.raises(ValueError, match="Malformed predicates"):
         list(
             dispatch_metapartitions(dataset.uuid, store_session, predicates=predicates)
         )
-    assert "The behaviour on an empty" in str(exc.value)


 @pytest.mark.parametrize(
@@ -95,6 +94,7 @@ def test_dispatch_metapartitions_query_partition_on(
 @pytest.mark.parametrize(
     "predicates",
     [
+        # These predicates are OR connected, therefore they need to allow all partitions
         [[("P", "==", 2)], [("TARGET", "==", 500)]],
         [[("P", "in", [2])], [("TARGET", "in", [500])]],
         [[("L", "==", 2)], [("TARGET", "==", 500)]],
@@ -103,7 +103,7 @@ def test_dispatch_metapartitions_query_no_effect(
     dataset_partition_keys, store_session, predicates
 ):
-    # These predicates should still lead to loading the whole set of partitionss
+    # These predicates should still lead to loading the whole set of partitions
     generator = dispatch_metapartitions(
         dataset_partition_keys.uuid, store_session, predicates=predicates
     )
@@ -126,9 +126,13 @@ def test_dispatch_metapartitions_concat_regression(store):
     )
     assert len(mps) == 2

-    mps = list(
-        dispatch_metapartitions(
-            dataset.uuid, store, concat_partitions_on_primary_index=True
+    with pytest.deprecated_call():
+        mps = list(
+            dispatch_metapartitions(
+                dataset.uuid, store, concat_partitions_on_primary_index=True
+            )
         )
-    )
+    assert len(mps) == 1
+
+    mps = list(dispatch_metapartitions(dataset.uuid, store, dispatch_by=["p"]))
     assert len(mps) == 1
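
Example (editor's illustration, not part of the patch): a minimal sketch of how the predicate helpers this diff adds to `kartothek.serialization` fit together. The dataframe, column names and values below are invented; the expected results follow the docstrings and tests above.

    import pandas as pd

    from kartothek.serialization import (
        check_predicates,
        columns_in_predicates,
        filter_df_from_predicates,
        filter_predicates_by_column,
    )

    # Predicates in disjunctive normal form: (A == 1 AND B < 5) OR (C == 4)
    predicates = [[("A", "==", 1), ("B", "<", 5)], [("C", "==", 4)]]

    # Validates the structure; raises ValueError ("Malformed predicates") for
    # empty or malformed input, returns None otherwise.
    check_predicates(predicates)

    # All columns referenced anywhere in the predicates.
    assert columns_in_predicates(predicates) == {"A", "B", "C"}

    # Drop literals referencing other columns; a conjunction that loses all of
    # its literals is removed, and an entirely empty result becomes None.
    assert filter_predicates_by_column(predicates, ["A"]) == [[("A", "==", 1)]]

    # Row-wise filtering with the same DNF semantics; predicates=None returns
    # the dataframe unmodified.
    df = pd.DataFrame({"A": [1, 1, 2], "B": [3, 7, 3], "C": [4, 0, 4]})
    filtered = filter_df_from_predicates(df, predicates)
    assert list(filtered.index) == [0, 2]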
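
A second sketch, mirroring `tests/io_components/test_dataset.py` above, of how `DatasetMetadata.get_indices_as_dataframe` now applies predicates so that dispatching can prune partitions; the index column name `product`, its values and the partition labels are made up for illustration.

    from kartothek.core.dataset import DatasetMetadata
    from kartothek.core.index import ExplicitSecondaryIndex

    # A dataset with one secondary index, built directly from an index dict.
    ds = DatasetMetadata(
        "some_uuid",
        indices={
            "product": ExplicitSecondaryIndex(
                "product", {"A": ["part1"], "B": ["part1", "part2"]}
            )
        },
    )

    # Only index entries matching the predicates survive; the result is indexed
    # by partition label, which is what dispatch_metapartitions_from_factory
    # uses as its pruned set of partitions.
    df = ds.get_indices_as_dataframe(
        columns=["product"], predicates=[[("product", "==", "B")]]
    )
    assert sorted(df.index) == ["part1", "part2"]
    assert set(df["product"]) == {"B"}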