Skip to content

Commit

Permalink
BLIP: this is correct now (#35081)
Browse files (browse the repository at this point in the history)
this is correct now
  • Loading branch information
zucchini-nlp authored Dec 5, 2024
1 parent 50189e3 commit e682c17
Show file tree
Hide file tree
Showing 4 changed files with 4 additions and 4 deletions.
2 changes: 1 addition & 1 deletion src/transformers/models/blip_2/modeling_blip_2.py
Original file line number Diff line number Diff line change
Expand Up @@ -2311,7 +2311,7 @@ def generate(
if input_ids is None:
start_tokens = [self.config.text_config.bos_token_id]
if getattr(self.config, "image_token_index", None) is not None:
- start_tokens += [self.config.image_token_index] * self.config.num_query_tokens
+ start_tokens = [self.config.image_token_index] * self.config.num_query_tokens + start_tokens
input_ids = torch.tensor([start_tokens], dtype=torch.long, device=image_embeds.device)
input_ids = input_ids.repeat(batch_size, 1)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1593,7 +1593,7 @@ def generate(
if input_ids is None:
start_tokens = [self.config.text_config.bos_token_id]
if getattr(self.config, "image_token_index", None) is not None:
- start_tokens += [self.config.image_token_index] * self.config.num_query_tokens
+ start_tokens = [self.config.image_token_index] * self.config.num_query_tokens + start_tokens
input_ids = torch.tensor([start_tokens], dtype=torch.long, device=image_embeds.device)
input_ids = input_ids.repeat(batch_size, 1)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -1628,7 +1628,7 @@ def generate(
if input_ids is None:
start_tokens = [self.config.text_config.bos_token_id]
if getattr(self.config, "video_token_index", None) is not None:
- start_tokens += [self.config.video_token_index] * self.config.num_query_tokens * 4
+ start_tokens = [self.config.video_token_index] * self.config.num_query_tokens * 4 + start_tokens
input_ids = torch.tensor([start_tokens], dtype=torch.long, device=image_embeds.device)
input_ids = input_ids.repeat(batch_size, 1)

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -441,7 +441,7 @@ def generate(
if input_ids is None:
start_tokens = [self.config.text_config.bos_token_id]
if getattr(self.config, "video_token_index", None) is not None:
- start_tokens += [self.config.video_token_index] * self.config.num_query_tokens * 4
+ start_tokens = [self.config.video_token_index] * self.config.num_query_tokens * 4 + start_tokens
input_ids = torch.tensor([start_tokens], dtype=torch.long, device=image_embeds.device)
input_ids = input_ids.repeat(batch_size, 1)

Expand Down

0 comments on commit e682c17

Please sign in to comment.