From 4c65b13617c45034df1a998fc49f3b30382262ca Mon Sep 17 00:00:00 2001 From: christopherbunn Date: Thu, 17 Aug 2023 17:18:52 -0400 Subject: [PATCH] Refactored code structure. --- evalml/pipelines/utils.py | 45 +++++++++---------- .../pipeline_tests/test_pipeline_utils.py | 2 +- 2 files changed, 23 insertions(+), 24 deletions(-) diff --git a/evalml/pipelines/utils.py b/evalml/pipelines/utils.py index 7719b020e4..6346e11b1c 100644 --- a/evalml/pipelines/utils.py +++ b/evalml/pipelines/utils.py @@ -1381,7 +1381,6 @@ def unstack_multiseries( # Perform the unstacking X_unstacked_cols = [] y_unstacked_cols = [] - new_time_index = None for s_id in series_id_unique: single_series = full_dataset[full_dataset[series_id] == s_id] @@ -1478,36 +1477,22 @@ def stack_X(X, series_id_name, time_index, starting_index=None, series_id_values """ original_columns = set() series_ids = series_id_values or set() - for col in X.columns: - if col == time_index: - continue - separated_name = col.split("_") - original_columns.add("_".join(separated_name[:-1])) - if series_id_values is None: + if series_id_values is None: + for col in X.columns: + if col == time_index: + continue + separated_name = col.split("_") + original_columns.add("_".join(separated_name[:-1])) series_ids.add(separated_name[-1]) - restacked_X = [] - if len(series_ids) == 0: raise ValueError( - "Unable to stack X as X had no exogenous variables and `series_id_values` is None.", + "X has no exogenous variables and `series_id_values` is None.", ) - for i, original_col in enumerate(original_columns): - # Only include the series id once (for the first column) - include_series_id = i == 0 - subset_X = [col for col in X.columns if original_col in col] - restacked_X.append( - stack_data( - X[subset_X], - include_series_id=include_series_id, - series_id_name=series_id_name, - starting_index=starting_index, - ), - ) time_index_col = X[time_index].repeat(len(series_ids)).reset_index(drop=True) - if len(restacked_X) == 0: + if len(original_columns) == 0: start_index = starting_index or X.index[0] stacked_index = pd.RangeIndex( start=start_index, @@ -1522,6 +1507,20 @@ def stack_X(X, series_id_name, time_index, starting_index=None, series_id_values index=stacked_index, ) else: + restacked_X = [] + for i, original_col in enumerate(original_columns): + # Only include the series id once (for the first column) + include_series_id = i == 0 + subset_X = [col for col in X.columns if original_col in col] + restacked_X.append( + stack_data( + X[subset_X], + include_series_id=include_series_id, + series_id_name=series_id_name, + starting_index=starting_index, + ), + ) + restacked_X = pd.concat(restacked_X, axis=1) time_index_col.index = restacked_X.index restacked_X[time_index] = time_index_col diff --git a/evalml/tests/pipeline_tests/test_pipeline_utils.py b/evalml/tests/pipeline_tests/test_pipeline_utils.py index c396b57788..d8c1c6a821 100644 --- a/evalml/tests/pipeline_tests/test_pipeline_utils.py +++ b/evalml/tests/pipeline_tests/test_pipeline_utils.py @@ -1478,7 +1478,7 @@ def test_stack_X( with pytest.raises( ValueError, - match="Unable to stack X as X had no exogenous variables and `series_id_values` is None.", + match="X has no exogenous variables and `series_id_values` is None.", ): stack_X(X, "series_id", "date", starting_index=starting_index)