Skip to content

Commit

Permalink
Merge pull request #135 from ibm-granite/issue_134
Browse files Browse the repository at this point in the history
Handle exogenous during plotting
  • Loading branch information
wgifford authored Sep 12, 2024
2 parents c4af6f3 + 241d827 commit c889210
Showing 1 changed file with 9 additions and 2 deletions.
11 changes: 9 additions & 2 deletions tsfm_public/toolkit/visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,7 @@
#
"""Utilities for plotting time series data"""

import inspect
import logging
import os
from typing import List, Optional, Union
Expand Down Expand Up @@ -297,9 +298,15 @@ def plot_predictions(
with torch.no_grad():
if indices is None:
indices = np.random.choice(len(dset), size=num_plots, replace=False)
random_samples = torch.stack([dset[i]["past_values"] for i in indices]).to(device=device)

output = model(random_samples)
signature = inspect.signature(model.forward)
signature_keys = list(signature.parameters.keys())
dset_keys = dset[0].keys()
random_samples = {}
for k in dset_keys:
if k in signature_keys:
random_samples[k] = torch.stack([dset[i][k] for i in indices]).to(device=device)
output = model(**random_samples)
predictions_subset = output.prediction_outputs[:, :, channel].squeeze().cpu().numpy()
prediction_length = predictions_subset.shape[1]
using_pipeline = False
Expand Down

0 comments on commit c889210

Please sign in to comment.