From f4109c6c9640af6709eb7b88d282742e91f4276b Mon Sep 17 00:00:00 2001 From: Wilson Beebe Date: Wed, 17 Jan 2024 11:06:14 -0800 Subject: [PATCH] Fix EnsembleFrame.repartition (#349) * Add EnsembleFrame.repartition * Repartition source frame with update_ensemble * lint fix --- src/tape/ensemble.py | 2 +- src/tape/ensemble_frame.py | 75 +++++++++++++++++++++++++ tests/tape_tests/test_ensemble_frame.py | 7 +++ 3 files changed, 83 insertions(+), 1 deletion(-) diff --git a/src/tape/ensemble.py b/src/tape/ensemble.py index 46959645..08822cdf 100644 --- a/src/tape/ensemble.py +++ b/src/tape/ensemble.py @@ -342,7 +342,7 @@ def insert_sources( if all(prev_div): self.update_frame(self.source.repartition(divisions=prev_div)) elif self.source.npartitions != prev_num: - self.source = self.source.repartition(npartitions=prev_num) + self.update_frame(self.source.repartition(npartitions=prev_num)) return self diff --git a/src/tape/ensemble_frame.py b/src/tape/ensemble_frame.py index b1005fdf..510099be 100644 --- a/src/tape/ensemble_frame.py +++ b/src/tape/ensemble_frame.py @@ -666,6 +666,81 @@ def compute(self, **kwargs): self.ensemble._lazy_sync_tables_from_frame(self) return super().compute(**kwargs) + def repartition( + self, + divisions=None, + npartitions=None, + partition_size=None, + freq=None, + force=False, + ): + """Repartition dataframe along new divisions + + Doc string below derived from dask.dataframe.DataFrame + + Parameters + ---------- + divisions : list, optional + The "dividing lines" used to split the dataframe into partitions. + For ``divisions=[0, 10, 50, 100]``, there would be three output partitions, + where the new index contained [0, 10), [10, 50), and [50, 100), respectively. + See https://docs.dask.org/en/latest/dataframe-design.html#partitions. + Only used if npartitions and partition_size isn't specified. + For convenience if given an integer this will defer to npartitions + and if given a string it will defer to partition_size (see below) + npartitions : int, optional + Approximate number of partitions of output. Only used if partition_size + isn't specified. The number of partitions used may be slightly + lower than npartitions depending on data distribution, but will never be + higher. + partition_size: int or string, optional + Max number of bytes of memory for each partition. Use numbers or + strings like 5MB. If specified npartitions and divisions will be + ignored. Note that the size reflects the number of bytes used as + computed by ``pandas.DataFrame.memory_usage``, which will not + necessarily match the size when storing to disk. + + .. warning:: + + This keyword argument triggers computation to determine + the memory size of each partition, which may be expensive. + + freq : str, pd.Timedelta + A period on which to partition timeseries data like ``'7D'`` or + ``'12h'`` or ``pd.Timedelta(hours=12)``. Assumes a datetime index. + force : bool, default False + Allows the expansion of the existing divisions. + If False then the new divisions' lower and upper bounds must be + the same as the old divisions'. + + Notes + ----- + Exactly one of `divisions`, `npartitions`, `partition_size`, or `freq` + should be specified. A ``ValueError`` will be raised when that is + not the case. + + Also note that ``len(divisons)`` is equal to ``npartitions + 1``. This is because ``divisions`` + represents the upper and lower bounds of each partition. The first item is the + lower bound of the first partition, the second item is the lower bound of the + second partition and the upper bound of the first partition, and so on. + The second-to-last item is the lower bound of the last partition, and the last + (extra) item is the upper bound of the last partition. + + Examples + -------- + >>> df = df.repartition(npartitions=10) # doctest: +SKIP + >>> df = df.repartition(divisions=[0, 5, 10, 20]) # doctest: +SKIP + >>> df = df.repartition(freq='7d') # doctest: +SKIP + """ + result = super().repartition( + divisions=divisions, + npartitions=npartitions, + partition_size=partition_size, + freq=freq, + force=force, + ) + return self._propagate_metadata(result) + class TapeSeries(pd.Series): """A barebones extension of a Pandas series to be used for underlying Ensemble data. diff --git a/tests/tape_tests/test_ensemble_frame.py b/tests/tape_tests/test_ensemble_frame.py index 962f9b2c..5ed01488 100644 --- a/tests/tape_tests/test_ensemble_frame.py +++ b/tests/tape_tests/test_ensemble_frame.py @@ -143,6 +143,13 @@ def test_ensemble_frame_propagation(data_fixture, request): assert merged_frame.ensemble == ens assert merged_frame.is_dirty() + # Test that frame metadata is preserved after repartitioning + repartitioned_frame = ens_frame.copy().repartition(npartitions=10) + assert isinstance(repartitioned_frame, EnsembleFrame) + assert repartitioned_frame.label == TEST_LABEL + assert repartitioned_frame.ensemble == ens + assert repartitioned_frame.is_dirty() + # Test that head returns a subset of the underlying TapeFrame. h = ens_frame.head(5) assert isinstance(h, TapeFrame)