Skip to content

Commit

Permalink
Fix: Remove gradients in prediction
Browse files Browse the repository at this point in the history
  • Loading branch information
bentaculum committed Oct 3, 2024
1 parent 225e46d commit 0f1783b
Showing 1 changed file with 9 additions and 7 deletions.
16 changes: 9 additions & 7 deletions trackastra/model/predict.py
Original file line number Diff line number Diff line change
Expand Up @@ -35,15 +35,17 @@ def predict(batch, model):

# Concat timepoints to coordinates
coords = torch.cat((timepoints.unsqueeze(2).float(), coords), dim=2)
A = model(coords, features=feats)
A = model.normalize_output(A, timepoints, coords)
with torch.no_grad():
A = model(coords, features=feats)
A = model.normalize_output(A, timepoints, coords)

# # Spatially far entries should not influence the causal normalization
# dist = torch.cdist(coords[0, :, 1:], coords[0, :, 1:])
# invalid = dist > model.config["spatial_pos_cutoff"]
# A[invalid] = -torch.inf
# # Spatially far entries should not influence the causal normalization
# dist = torch.cdist(coords[0, :, 1:], coords[0, :, 1:])
# invalid = dist > model.config["spatial_pos_cutoff"]
# A[invalid] = -torch.inf

A = A.squeeze(0).detach().cpu().numpy()

A = A.squeeze(0).detach().cpu().numpy()
return A


Expand Down

0 comments on commit 0f1783b

Please sign in to comment.