Skip to content

Commit

Permalink
Refactored code structure.
Browse files: browse the repository at this point in the history
  • Loading branch information
christopherbunn committed Aug 18, 2023
1 parent ecfd078 commit 4c65b13
Show file tree
Hide file tree
Showing 2 changed files with 23 additions and 24 deletions.
45 changes: 22 additions & 23 deletions evalml/pipelines/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1381,7 +1381,6 @@ def unstack_multiseries(
# Perform the unstacking
X_unstacked_cols = []
y_unstacked_cols = []
new_time_index = None
for s_id in series_id_unique:
single_series = full_dataset[full_dataset[series_id] == s_id]

Expand Down Expand Up @@ -1478,36 +1477,22 @@ def stack_X(X, series_id_name, time_index, starting_index=None, series_id_values
"""
original_columns = set()
series_ids = series_id_values or set()
for col in X.columns:
if col == time_index:
continue
separated_name = col.split("_")
original_columns.add("_".join(separated_name[:-1]))
if series_id_values is None:
if series_id_values is None:
for col in X.columns:
if col == time_index:
continue
separated_name = col.split("_")
original_columns.add("_".join(separated_name[:-1]))
series_ids.add(separated_name[-1])

Check warning on line 1486 in evalml/pipelines/utils.py

View check run for this annotation

Codecov / codecov/patch

evalml/pipelines/utils.py#L1479-L1486

Added lines #L1479 - L1486 were not covered by tests

restacked_X = []

if len(series_ids) == 0:
raise ValueError(

Check warning on line 1489 in evalml/pipelines/utils.py

View check run for this annotation

Codecov / codecov/patch

evalml/pipelines/utils.py#L1488-L1489

Added lines #L1488 - L1489 were not covered by tests
"Unable to stack X as X had no exogenous variables and `series_id_values` is None.",
"X has no exogenous variables and `series_id_values` is None.",
)

for i, original_col in enumerate(original_columns):
# Only include the series id once (for the first column)
include_series_id = i == 0
subset_X = [col for col in X.columns if original_col in col]
restacked_X.append(
stack_data(
X[subset_X],
include_series_id=include_series_id,
series_id_name=series_id_name,
starting_index=starting_index,
),
)
time_index_col = X[time_index].repeat(len(series_ids)).reset_index(drop=True)

if len(restacked_X) == 0:
if len(original_columns) == 0:
start_index = starting_index or X.index[0]
stacked_index = pd.RangeIndex(

Check warning on line 1497 in evalml/pipelines/utils.py

View check run for this annotation

Codecov / codecov/patch

evalml/pipelines/utils.py#L1495-L1497

Added lines #L1495 - L1497 were not covered by tests
start=start_index,
Expand All @@ -1522,6 +1507,20 @@ def stack_X(X, series_id_name, time_index, starting_index=None, series_id_values
index=stacked_index,
)
else:
restacked_X = []
for i, original_col in enumerate(original_columns):

Check warning on line 1511 in evalml/pipelines/utils.py

View check run for this annotation

Codecov / codecov/patch

evalml/pipelines/utils.py#L1510-L1511

Added lines #L1510 - L1511 were not covered by tests
# Only include the series id once (for the first column)
include_series_id = i == 0
subset_X = [col for col in X.columns if original_col in col]
restacked_X.append(

Check warning on line 1515 in evalml/pipelines/utils.py

View check run for this annotation

Codecov / codecov/patch

evalml/pipelines/utils.py#L1513-L1515

Added lines #L1513 - L1515 were not covered by tests
stack_data(
X[subset_X],
include_series_id=include_series_id,
series_id_name=series_id_name,
starting_index=starting_index,
),
)

restacked_X = pd.concat(restacked_X, axis=1)
time_index_col.index = restacked_X.index
restacked_X[time_index] = time_index_col

Check warning on line 1526 in evalml/pipelines/utils.py

View check run for this annotation

Codecov / codecov/patch

evalml/pipelines/utils.py#L1524-L1526

Added lines #L1524 - L1526 were not covered by tests
Expand Down
2 changes: 1 addition & 1 deletion evalml/tests/pipeline_tests/test_pipeline_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -1478,7 +1478,7 @@ def test_stack_X(

with pytest.raises(

Check warning on line 1479 in evalml/tests/pipeline_tests/test_pipeline_utils.py

View check run for this annotation

Codecov / codecov/patch

evalml/tests/pipeline_tests/test_pipeline_utils.py#L1479

Added line #L1479 was not covered by tests
ValueError,
match="Unable to stack X as X had no exogenous variables and `series_id_values` is None.",
match="X has no exogenous variables and `series_id_values` is None.",
):
stack_X(X, "series_id", "date", starting_index=starting_index)

Check warning on line 1483 in evalml/tests/pipeline_tests/test_pipeline_utils.py

View check run for this annotation

Codecov / codecov/patch

evalml/tests/pipeline_tests/test_pipeline_utils.py#L1483

Added line #L1483 was not covered by tests

Expand Down

0 comments on commit 4c65b13

Please sign in to comment.