Skip to content

Commit

Permalink
fix
Browse files Browse the repository at this point in the history
  • Loading branch information
ydshieh committed Apr 18, 2024
1 parent ce8e64f commit e828a1b
Show file tree
Hide file tree
Showing 2 changed files with 3 additions and 1 deletion.
Original file line number Diff line number Diff line change
Expand Up @@ -124,7 +124,8 @@ def convert_clip_checkpoint(checkpoint_path, pytorch_dump_folder_path, config_pa
copy_vison_model_and_projection(hf_model, pt_model)
hf_model.logit_scale = pt_model.logit_scale

input_ids = torch.arange(0, 77).unsqueeze(0)
# Use an `eos_token` so the example is more meaningful
input_ids = torch.tensor([[config.text_config.bos_token_id] + list(range(3, 77)) + [config.text_config.eos_token_id] + [config.text_config.pad_token_id]])
pixel_values = torch.randn(1, 3, 224, 224)

hf_outputs = hf_model(input_ids=input_ids, pixel_values=pixel_values, return_dict=True)
Expand Down
1 change: 1 addition & 0 deletions src/transformers/models/clip/modeling_clip.py
Original file line number Diff line number Diff line change
Expand Up @@ -734,6 +734,7 @@ def forward(
pooled_output = last_hidden_state[
torch.arange(last_hidden_state.shape[0], device=last_hidden_state.device),
# We need to get the first position of the `eos_token_id` value (`pad_token_id` might be equal to `eos_token_id`)
# Note: we assume the input always has an eos token in each text (i.e. it is always prepared by the CLIP tokenizer)
(input_ids.to(dtype=torch.int, device=last_hidden_state.device) == self.eos_token_id)
.int()
.argmax(dim=-1),
Expand Down

0 comments on commit e828a1b

Please sign in to comment.