From 241d827c102863dfbd95c33efbf92de38dbed2c3 Mon Sep 17 00:00:00 2001 From: Wesley Gifford <79663411+wgifford@users.noreply.github.com> Date: Thu, 12 Sep 2024 14:11:52 -0400 Subject: [PATCH] handle exog --- tsfm_public/toolkit/visualization.py | 11 +++++++++-- 1 file changed, 9 insertions(+), 2 deletions(-) diff --git a/tsfm_public/toolkit/visualization.py b/tsfm_public/toolkit/visualization.py index ac67878e..db1dfe7b 100644 --- a/tsfm_public/toolkit/visualization.py +++ b/tsfm_public/toolkit/visualization.py @@ -2,6 +2,7 @@ # """Utilities for plotting time series data""" +import inspect import logging import os from typing import List, Optional, Union @@ -297,9 +298,15 @@ def plot_predictions( with torch.no_grad(): if indices is None: indices = np.random.choice(len(dset), size=num_plots, replace=False) - random_samples = torch.stack([dset[i]["past_values"] for i in indices]).to(device=device) - output = model(random_samples) + signature = inspect.signature(model.forward) + signature_keys = list(signature.parameters.keys()) + dset_keys = dset[0].keys() + random_samples = {} + for k in dset_keys: + if k in signature_keys: + random_samples[k] = torch.stack([dset[i][k] for i in indices]).to(device=device) + output = model(**random_samples) predictions_subset = output.prediction_outputs[:, :, channel].squeeze().cpu().numpy() prediction_length = predictions_subset.shape[1] using_pipeline = False