diff --git a/server/text_generation_server/models/custom_modeling/idefics_vision.py b/server/text_generation_server/models/custom_modeling/idefics_vision.py index 0418281231e..30f07095e71 100644 --- a/server/text_generation_server/models/custom_modeling/idefics_vision.py +++ b/server/text_generation_server/models/custom_modeling/idefics_vision.py @@ -91,7 +91,7 @@ def __init__(self, prefix, config, weights): self.position_embedding = TensorParallelEmbedding( prefix="model.vision_model.embeddings.position_embedding", weights=weights ) - self.position_ids = torch.arange(self.num_positions).expand((1, -1)) + self.position_ids = torch.arange(self.num_positions).expand((1, -1)).to(device=weights.device) def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor: batch_size = pixel_values.shape[0]