From 0ba8af3d63687bdfe5712e3004aca4ddd4436010 Mon Sep 17 00:00:00 2001
From: Tom Aarsen <37621491+tomaarsen@users.noreply.github.com>
Date: Wed, 13 Dec 2023 22:18:38 +0100
Subject: [PATCH] Set the Linear device equal to the main model device in SoftmaxLoss (#2378)

---
 sentence_transformers/losses/SoftmaxLoss.py | 2 +-
 1 file changed, 1 insertion(+), 1 deletion(-)

diff --git a/sentence_transformers/losses/SoftmaxLoss.py b/sentence_transformers/losses/SoftmaxLoss.py
index aa3d87ae9..8b35d56ac 100644
--- a/sentence_transformers/losses/SoftmaxLoss.py
+++ b/sentence_transformers/losses/SoftmaxLoss.py
@@ -55,7 +55,7 @@ def __init__(self,
         if concatenation_sent_multiplication:
             num_vectors_concatenated += 1
         logger.info("Softmax loss: #Vectors concatenated: {}".format(num_vectors_concatenated))
-        self.classifier = nn.Linear(num_vectors_concatenated * sentence_embedding_dimension, num_labels)
+        self.classifier = nn.Linear(num_vectors_concatenated * sentence_embedding_dimension, num_labels, device=model.device)
         self.loss_fct = loss_fct

     def forward(self, sentence_features: Iterable[Dict[str, Tensor]], labels: Tensor):
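
Note: a minimal sketch of the effect of this patch, assuming a CUDA device is available. Previously the classifier head was created with the nn.Linear default placement (CPU), so it had to be moved to the model's device separately; with device=model.device it starts out on the same device as the SentenceTransformer. The model name and label count below are purely illustrative.

from sentence_transformers import SentenceTransformer
from sentence_transformers.losses import SoftmaxLoss

# Illustrative model and label count; any SentenceTransformer behaves the same way.
model = SentenceTransformer("all-MiniLM-L6-v2", device="cuda")
loss = SoftmaxLoss(
    model=model,
    sentence_embedding_dimension=model.get_sentence_embedding_dimension(),
    num_labels=3,
)

# With this patch the classifier is created on the model's device:
print(loss.classifier.weight.device)  # cuda:0 (previously cpu)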