diff --git a/transformer_ranker/ranker.py b/transformer_ranker/ranker.py
index 57d1362..37503a6 100644
--- a/transformer_ranker/ranker.py
+++ b/transformer_ranker/ranker.py
@@ -72,6 +72,8 @@ def run(
         """
         self._confirm_ranker_setup(estimator=estimator, layer_aggregator=layer_aggregator)

+        device = torch.device(device or ("cuda" if torch.cuda.is_available() else "cpu"))
+
         # Load all transformers into hf cache
         self._preload_transformers(models, device)
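
For context, the added line resolves the compute device up front: an explicit `device` argument wins, otherwise CUDA is used when available, with CPU as the fallback. A minimal standalone sketch of the same pattern (`resolve_device` is a hypothetical helper for illustration, not part of the ranker API):

```python
from typing import Optional

import torch


def resolve_device(device: Optional[str] = None) -> torch.device:
    # Hypothetical helper mirroring the line added in run():
    # an explicit device string takes precedence; otherwise pick
    # "cuda" when a GPU is available, else fall back to "cpu".
    return torch.device(device or ("cuda" if torch.cuda.is_available() else "cpu"))


print(resolve_device())        # cuda or cpu, depending on the machine
print(resolve_device("cpu"))   # cpu, an explicit choice always wins
```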