Skip to content

Commit

Permalink
Fix default device setting
Browse files Browse the repository at this point in the history
  • Loading branch information
lukasgarbas committed Nov 30, 2024
1 parent 1b416b3 commit 7edf0e4
Showing 1 changed file with 2 additions and 0 deletions.
2 changes: 2 additions & 0 deletions transformer_ranker/ranker.py
Original file line number Diff line number Diff line change
Expand Up @@ -72,6 +72,8 @@ def run(
"""
self._confirm_ranker_setup(estimator=estimator, layer_aggregator=layer_aggregator)

device = torch.device(device or ("cuda" if torch.cuda.is_available() else "cpu"))

# Load all transformers into hf cache
self._preload_transformers(models, device)

Expand Down

0 comments on commit 7edf0e4

Please sign in to comment.