Skip to content

Commit

Permalink
fix: adding tests, fixing logic
Browse files Browse the repository at this point in the history
  • Loading branch information
madtoinou committed Nov 17, 2024
1 parent 30a1811 commit e2bad2c
Show file tree
Hide file tree
Showing 8 changed files with 269 additions and 53 deletions.
3 changes: 2 additions & 1 deletion CHANGELOG.md
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,8 @@ but cannot always guarantee backwards compatibility. Changes that may **break co
**Improved**

- Improvements to `ForecastingModel`: Improved `start` handling for historical forecasts, backtest, residuals, and gridsearch. If `start` is not within the trainable / forecastable points, uses the closest valid start point that is a round multiple of `stride` ahead of start. Raises a ValueError, if no valid start point exists. This guarantees that all historical forecasts are `n * stride` points away from start, and will simplify many downstream tasks. [#2560](https://github.com/unit8co/darts/issues/2560) by [Dennis Bader](https://github.com/dennisbader).
- Added `data_transformers` argument to `historical_forecasts`, `backtest` and `gridsearch` that allows scaling of the series without data-leakage. [#2529](https://github.com/unit8co/darts/pull/2529) by [Antoine Madrona](https://github.com/madtoinou) and [Jan Fidor](https://github.com/JanFidor)
- Added `data_transformers` argument to `historical_forecasts`, `backtest`, `residuals`, and `gridsearch` that allows scaling of the series without data-leakage. [#2529](https://github.com/unit8co/darts/pull/2529) by [Antoine Madrona](https://github.com/madtoinou) and [Jan Fidor](https://github.com/JanFidor)
- Added `idx_params` argument to `DataTransformer` that allows users to use only a subset of the transformers when `global_fit=False` and severals series are used. [#2529](https://github.com/unit8co/darts/pull/2529) by [Antoine Madrona](https://github.com/madtoinou)

**Fixed**

Expand Down
17 changes: 11 additions & 6 deletions darts/dataprocessing/pipeline.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,7 +5,7 @@

from collections.abc import Iterator, Sequence
from copy import deepcopy
from typing import Union
from typing import Optional, Union

from darts import TimeSeries
from darts.dataprocessing.transformers import (
Expand Down Expand Up @@ -158,7 +158,9 @@ def fit_transform(
return data

def transform(
self, data: Union[TimeSeries, Sequence[TimeSeries]]
self,
data: Union[TimeSeries, Sequence[TimeSeries]],
idx_params: Optional[Union[int, Sequence[int]]] = None,
) -> Union[TimeSeries, Sequence[TimeSeries]]:
"""
For each data transformer in pipeline transform data. Then transformed data is passed to next transformer.
Expand All @@ -174,11 +176,14 @@ def transform(
Transformed data.
"""
for transformer in self._transformers:
data = transformer.transform(data)
data = transformer.transform(data, idx_params=idx_params)
return data

def inverse_transform(
self, data: Union[TimeSeries, Sequence[TimeSeries]], partial: bool = False
self,
data: Union[TimeSeries, Sequence[TimeSeries]],
partial: bool = False,
idx_params: Optional[Union[int, Sequence[int]]] = None,
) -> Union[TimeSeries, Sequence[TimeSeries]]:
"""
For each data transformer in the pipeline, inverse-transform data. Then inverse transformed data is passed to
Expand Down Expand Up @@ -207,12 +212,12 @@ def inverse_transform(
)

for transformer in reversed(self._transformers):
data = transformer.inverse_transform(data)
data = transformer.inverse_transform(data, idx_params=idx_params)
return data
else:
for transformer in reversed(self._transformers):
if isinstance(transformer, InvertibleDataTransformer):
data = transformer.inverse_transform(data)
data = transformer.inverse_transform(data, idx_params=idx_params)
return data

@property
Expand Down
5 changes: 4 additions & 1 deletion darts/dataprocessing/transformers/base_data_transformer.py
Original file line number Diff line number Diff line change
Expand Up @@ -363,7 +363,10 @@ def transform(
# Take note of original input for unmasking purposes:
if isinstance(series, TimeSeries):
data = [series]
transformer_selector = [0]
if idx_params:
transformer_selector = self._check_idx_params(idx_params)
else:
transformer_selector = [0]
else:
data = series
if idx_params:
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -271,7 +271,8 @@ def fit(
transformer_selector = range(len(series))

params_iterator = self._get_params(
transformer_selector=transformer_selector, calling_fit=True
transformer_selector=transformer_selector,
calling_fit=True,
)
fit_iterator = (
zip(data, params_iterator)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -327,7 +327,10 @@ def inverse_transform(
called_with_sequence_series = False
if isinstance(series, TimeSeries):
data = [series]
transformer_selector = [0]
if idx_params:
transformer_selector = self._check_idx_params(idx_params)
else:
transformer_selector = [0]
called_with_single_series = True
elif isinstance(series[0], TimeSeries): # Sequence[TimeSeries]
data = series
Expand All @@ -346,7 +349,6 @@ def inverse_transform(
for idx, series_list in iterator_:
data.extend(series_list)
transformer_selector += [idx] * len(series_list)

input_iterator = _build_tqdm_iterator(
zip(data, self._get_params(transformer_selector=transformer_selector)),
verbose=self._verbose,
Expand Down
7 changes: 5 additions & 2 deletions darts/models/forecasting/forecasting_model.py
Original file line number Diff line number Diff line change
Expand Up @@ -911,8 +911,10 @@ def retrain_func(
data_transformers=data_transformers, copy=True
)

# data transformer must already be fitted and can be directly applied to all the series
using_prefitted_transformers = False
# data transformer already fitted and can be directly applied to all the series
if data_transformers and not retrain:
using_prefitted_transformers = True
series, past_covariates, future_covariates = _apply_data_transformers(
series=series,
past_covariates=past_covariates,
Expand Down Expand Up @@ -1097,6 +1099,7 @@ def retrain_func(
train_series = train_series[-train_length_:]

# when `retrain=True`, data transformers are also retrained between iterations to avoid data-leakage
# using a single series
if data_transformers and retrain:
train_series, past_covariates_, future_covariates_ = (
_apply_data_transformers(
Expand Down Expand Up @@ -1187,11 +1190,11 @@ def retrain_func(
**predict_kwargs,
)

# target transformer is either already fitted or fitted during the retraining
forecast = _apply_inverse_data_transformers(
series=train_series,
forecasts=forecast,
data_transformers=data_transformers,
idx_transformer=idx if using_prefitted_transformers else None,
)

show_predict_warnings = False
Expand Down
Loading

0 comments on commit e2bad2c

Please sign in to comment.