Skip to content

Commit

Permalink
Fix position ids logic instantiation of idefics vision part
Browse files Browse the repository at this point in the history
  • Loading branch information
VictorSanh authored Sep 26, 2023
1 parent 2f51645 commit 5a6c572
Showing 1 changed file with 1 addition and 3 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -88,12 +88,10 @@ def __init__(self, prefix, config, weights):

self.num_patches = (self.image_size // self.patch_size) ** 2
self.num_positions = self.num_patches + 1
# self.position_embedding = nn.Embedding(self.num_positions, self.embed_dim)
self.position_embedding = TensorParallelEmbedding(
prefix="model.vision_model.embeddings.position_embedding", weights=weights
)
# self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)), persistent=False)
self.position_ids = weights.get_tensor(f"{prefix}.position_ids")
self.register_buffer("position_ids", torch.arange(self.num_positions).expand((1, -1)), persistent=False)

def forward(self, pixel_values: torch.FloatTensor) -> torch.Tensor:
batch_size = pixel_values.shape[0]
Expand Down

0 comments on commit 5a6c572

Please sign in to comment.