Skip to content

Commit

Permalink
[Whisper] patch float type on mps (#35295)
Browse files Browse the repository at this point in the history
* fix float type on mps

* make
  • Loading branch information
eustlb authored Dec 16, 2024
1 parent d5b81e1 commit 9feae5f
Showing 1 changed file with 13 additions and 4 deletions.
17 changes: 13 additions & 4 deletions src/transformers/models/whisper/generation_whisper.py
Original file line number Diff line number Diff line change
Expand Up @@ -632,7 +632,9 @@ def generate(
cur_bsz=cur_bsz,
batch_idx_map=batch_idx_map,
)
time_offset = seek.to(torch.float64) * time_precision / input_stride
time_offset = (
seek.to(torch.float32 if device.type == "mps" else torch.float64) * time_precision / input_stride
)
seek_num_frames = (max_frames - seek).clamp(max=num_segment_frames)

# 6.2 cut out next 30s segment from input features
Expand Down Expand Up @@ -1805,6 +1807,7 @@ def _retrieve_segment(
timestamp_segment_indices = torch.where(timestamp_tokens[:-1] & timestamp_tokens[1:])[0]
timestamp_segment_indices.add_(1)
token_timestamps = seek_outputs[idx]["token_timestamps"] if return_token_timestamps else []
device = seek_sequence.device

# If whisper predicted a "end of segment" via a timestep token, let's go ever each
# "end of segment" prediction and slice the decoding into segments accordingly
Expand All @@ -1828,8 +1831,12 @@ def _retrieve_segment(
end_timestamp_pos = sliced_tokens[idx_sliced_tokens] - timestamp_begin
segments.append(
{
"start": time_offset[prev_idx] + start_timestamp_pos.to(torch.float64) * time_precision,
"end": time_offset[prev_idx] + end_timestamp_pos.to(torch.float64) * time_precision,
"start": time_offset[prev_idx]
+ start_timestamp_pos.to(torch.float32 if device.type == "mps" else torch.float64)
* time_precision,
"end": time_offset[prev_idx]
+ end_timestamp_pos.to(torch.float32 if device.type == "mps" else torch.float64)
* time_precision,
"tokens": sliced_tokens,
"result": seek_outputs[idx],
}
Expand All @@ -1856,7 +1863,9 @@ def _retrieve_segment(
last_timestamp_pos = int(seek_num_frames[prev_idx] * time_precision_features / time_precision)
if timestamps.numel() > 0 and timestamps[-1] != timestamp_begin:
# no consecutive timestamps but it has a timestamp; use the last one.
last_timestamp_pos = (timestamps[-1] - timestamp_begin).to(torch.float64)
last_timestamp_pos = (timestamps[-1] - timestamp_begin).to(
torch.float32 if device.type == "mps" else torch.float64
)
segments = [
{
"start": time_offset[prev_idx],
Expand Down

0 comments on commit 9feae5f

Please sign in to comment.