Skip to content

Commit

Permalink
periods parameters
Browse files Browse the repository at this point in the history
  • Loading branch information
remyogasawara committed Aug 22, 2023
1 parent 09e31f2 commit 837fc79
Show file tree
Hide file tree
Showing 5 changed files with 74 additions and 96 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -190,34 +190,6 @@ def _detrend_on_fly(X, y):
relative_maxima = _get_rel_max_from_acf(y_detrended)
return relative_maxima

# def set_period(
# self,
# X: pd.DataFrame,
# y: pd.Series,
# acf_threshold: float = 0.01,
# rel_max_order: int = 5,
# ):
# """Function to set the component's seasonal period based on the target's seasonality.

# Args:
# X (pandas.DataFrame): The feature data of the time series problem.
# y (pandas.Series): The target data of a time series problem.
# acf_threshold (float) : The threshold for the autocorrelation function to determine the period. Any values below
# the threshold are considered to be 0 and will not be considered for the period. Defaults to 0.01.
# rel_max_order (int) : The order of the relative maximum to determine the period. Defaults to 5.

# """
# self.periods = {}
# if len(y.columns) == 1:
# self.period = self.determine_periodicity(X, y, acf_threshold, rel_max_order)
# self.update_parameters({"period": self.period})
# self.periods[id] = self.period
# return
# else:
# for id in y.columns:
# self.periods[id] = self.determine_periodicity(X, y[id], acf_threshold, rel_max_order)
# self.update_parameters({"periods": self.periods})

def set_period(
self,
X: pd.DataFrame,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -195,26 +195,27 @@ def fit(
self.seasonalities = {}
self.trends = {}
self.residuals = {}
self.periods = {}

# # Determine the period of the seasonal component
# # Set the period if it is single series and period is given
# if self.period is not None and len(y.columns) == 1:
# self.periods = {0: self.period}
# # Set periods if it is single series and period is
# if self.periods is None or self.period is None:
# self.set_period(X, y)

# if self.period is None:
# self.set_period(X, y)
if self.periods is None:
self.periods = {}

for id in y.columns:
series_y = y[id]

# Determine the period of the seasonal component
if id not in self.periods or self.period is None:
self.set_period(X, series_y)
self.periods[id] = self.period
if id not in self.periods:
period = self.determine_periodicity(
X,
series_y,
acf_threshold=0.01,
rel_max_order=5,
)
if self.period is None and len(y.columns) == 1:
self.period = period
self.update_parameters({"period": self.period})
elif self.period is not None and len(y.columns) == 1:
period = self.period
self.periods[id] = period
self.update_parameters({"periods": self.periods})

stl = STL(
series_y,
Expand Down Expand Up @@ -463,7 +464,8 @@ def get_trend_dataframe(self, X, y):
# in ForecastingHorizon during decomposition.
if not isinstance(y.index, pd.DatetimeIndex):
y = self._set_time_index(X, y)

if not isinstance(X.index, pd.DatetimeIndex):
X.index = y.index
self._check_oos_past(y)

def _decompose_target(X, y, fh, trend, seasonal, residual, period, id):
Expand Down Expand Up @@ -495,13 +497,6 @@ def _decompose_target(X, y, fh, trend, seasonal, residual, period, id):
# Iterate through each series id
for id in y.columns:
result_dfs = []
if not isinstance(X.index, pd.DatetimeIndex):
raise TypeError("Provided X should have datetimes in the index.")
if X.index.freq is None:
raise ValueError(
"Provided DatetimeIndex of X should have an inferred frequency.",
)

if len(y.columns) > 1:
seasonal = self.seasonals[id]
trend = self.trends[id]
Expand Down
44 changes: 0 additions & 44 deletions evalml/tests/component_tests/decomposer_tests/test_decomposer.py
Original file line number Diff line number Diff line change
Expand Up @@ -438,50 +438,6 @@ def test_decomposer_projected_seasonality_integer_and_datetime(
)


@pytest.mark.parametrize(
"decomposer_child_class",
decomposer_list,
)
@pytest.mark.parametrize(
"variateness",
[
"univariate",
"multivariate",
],
)
def test_decomposer_get_trend_dataframe_raises_errors(
decomposer_child_class,
ts_data,
ts_multiseries_data,
variateness,
):
if variateness == "univariate":
X, _, y = ts_data()
elif variateness == "multivariate":
if isinstance(decomposer_child_class(), PolynomialDecomposer):
pytest.skip(
"Skipping Decomposer because multiseries is not implemented for Polynomial Decomposer",
)
X, _, y = ts_multiseries_data()

dec = decomposer_child_class()
dec.fit_transform(X, y)

with pytest.raises(
TypeError,
match="Provided X should have datetimes in the index.",
):
X_int_index = X.reset_index()
dec.get_trend_dataframe(X_int_index, y)

with pytest.raises(
ValueError,
match="Provided DatetimeIndex of X should have an inferred frequency.",
):
X.index.freq = None
dec.get_trend_dataframe(X, y)


@pytest.mark.parametrize(
"decomposer_child_class",
decomposer_list,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -98,3 +98,26 @@ def test_polynomial_decomposer_needs_monotonic_index(ts_data):
decomposer.fit_transform(X, y_shuffled)
expected_errors = ["monotonically", "X must be in an sktime compatible format"]
assert any([error in str(exec_info.value) for error in expected_errors])


def test_polynomial_decomposer_get_trend_dataframe_raises_errors(
ts_data,
):
X, _, y = ts_data()

dec = PolynomialDecomposer()
dec.fit_transform(X, y)

with pytest.raises(
TypeError,
match="Provided X should have datetimes in the index.",
):
X_int_index = X.reset_index()
dec.get_trend_dataframe(X_int_index, y)

with pytest.raises(
ValueError,
match="Provided DatetimeIndex of X should have an inferred frequency.",
):
X.index.freq = None
dec.get_trend_dataframe(X, y)
Original file line number Diff line number Diff line change
Expand Up @@ -447,6 +447,7 @@ def test_unsupported_frequencies(
"""This test exists to highlight that even though the underlying statsmodels STL component won't work
for minute or annual frequencies, we can still run these frequencies with automatic period detection.
"""
# period = 7 if variateness == "univariate" else {}
X, y = generate_seasonal_data(
real_or_synthetic="synthetic",
univariate_or_multivariate=variateness,
Expand All @@ -457,7 +458,38 @@ def test_unsupported_frequencies(

stl = STLDecomposer()
X_t, y_t = stl.fit_transform(X, y)
assert stl.period is not None
if variateness == "univariate":
assert stl.period is not None
else:
assert stl.periods is not None


@pytest.mark.parametrize(
"variateness",
[
"univariate",
"multivariate",
],
)
def test_init_periods(
generate_seasonal_data,
variateness,
):
"""This test exists to highlight that even though the underlying statsmodels STL component won't work
for minute or annual frequencies, we can still run these frequencies with automatic period detection.
"""
period = 7
X, y = generate_seasonal_data(
real_or_synthetic="synthetic",
univariate_or_multivariate=variateness,
)(period)
periods = {id: 8 for id in y.columns} if variateness == "multivariate" else None
stl = STLDecomposer(period=period, periods=periods)
X_t, y_t = stl.fit_transform(X, y)
if variateness == "univariate":
assert stl.period == period
else:
assert stl.periods == periods


@pytest.mark.parametrize(
Expand Down

0 comments on commit 837fc79

Please sign in to comment.