Skip to content

Commit

Permalink
Set device in torch.arange call for .embed()
Browse files Browse the repository at this point in the history
  • Loading branch information
ojh31 committed Oct 12, 2023
1 parent 49fe73b commit 3c12d0a
Showing 1 changed file with 3 additions and 1 deletion.
4 changes: 3 additions & 1 deletion utils/residual_stream.py
Original file line number Diff line number Diff line change
Expand Up @@ -161,7 +161,9 @@ def embed(
names_filter = lambda name: hook == name, batch_size=batch_size
)
out: Float[Tensor, "batch pos d_model"] = cache[hook]
return out[torch.arange(len(out)), self.position, :].detach().cpu()
return out[
torch.arange(len(out), device=out.device), self.position, :
].detach().cpu()

@classmethod
def get_dataset(
Expand Down

0 comments on commit 3c12d0a

Please sign in to comment.