diff --git a/tsfm_public/toolkit/visualization.py b/tsfm_public/toolkit/visualization.py index db1dfe7b..2a42b643 100644 --- a/tsfm_public/toolkit/visualization.py +++ b/tsfm_public/toolkit/visualization.py @@ -282,6 +282,7 @@ def plot_predictions( if indices is None: l = len(predictions_df) + num_plots = min(num_plots, l) indices = np.random.choice(l, size=num_plots, replace=False) predictions_subset = [predictions_df.iloc[i] for i in indices] @@ -297,6 +298,7 @@ def plot_predictions( with torch.no_grad(): if indices is None: + num_plots = min(num_plots, len(dset)) indices = np.random.choice(len(dset), size=num_plots, replace=False) signature = inspect.signature(model.forward)