Skip to content

Commit

Permalink
catch some edge cases
Browse files Browse the repository at this point in the history
  • Loading branch information
wgifford committed Sep 18, 2024
1 parent b3ae7cf commit e94eaa4
Showing 1 changed file with 2 additions and 0 deletions.
2 changes: 2 additions & 0 deletions tsfm_public/toolkit/visualization.py
Original file line number Diff line number Diff line change
Expand Up @@ -282,6 +282,7 @@ def plot_predictions(

if indices is None:
l = len(predictions_df)
num_plots = min(num_plots, l)
indices = np.random.choice(l, size=num_plots, replace=False)
predictions_subset = [predictions_df.iloc[i] for i in indices]

Expand All @@ -297,6 +298,7 @@ def plot_predictions(

with torch.no_grad():
if indices is None:
num_plots = min(num_plots, len(dset))
indices = np.random.choice(len(dset), size=num_plots, replace=False)

signature = inspect.signature(model.forward)
Expand Down

0 comments on commit e94eaa4

Please sign in to comment.