Skip to content

Commit

Permalink
Add CPU test for Whisper(large_v3_turbo) Model (#752)
Browse files Browse the repository at this point in the history
  • Loading branch information
pdeviTT authored Dec 16, 2024
1 parent 492308b commit a08b563
Showing 1 changed file with 61 additions and 0 deletions.
61 changes: 61 additions & 0 deletions forge/test/models/pytorch/audio/whisper/test_whisper_3.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
# SPDX-FileCopyrightText: © 2024 Tenstorrent AI ULC

# SPDX-License-Identifier: Apache-2.0
# Whisper Large V3 turbo - Automatic speech recognition and speech translation
# Model link : https://huggingface.co/openai/whisper-large-v3-turbo
# By default, Transformers uses the sequential algorithm.
# To enable the chunked algorithm, pass the chunk_length_s parameter to the pipeline.
# For large-v3, a chunk length of 30 seconds is optimal. To activate batching over long audio files, pass the batch_size argument.

import pytest
import torch
from transformers import WhisperConfig, WhisperForConditionalGeneration, WhisperProcessor
import forge


class Wrapper(torch.nn.Module):
    """Expose only the decoder and output projection of a wrapped model.

    The wrapped ``model`` must provide ``model.model.decoder`` (called with
    decoder input ids, an attention mask and encoder hidden states) and
    ``model.proj_out`` (applied to the first element of the decoder output).
    """

    def __init__(self, model):
        """Store ``model`` and create the fixed decoder attention mask."""
        super().__init__()
        self.model = model
        # Constant all-ones mask of shape (1, 1), reused on every forward call.
        self.decoder_attention_mask = torch.ones((1, 1))

    def forward(self, decoder_input_ids: torch.Tensor, encoder_hidden_states: torch.Tensor) -> torch.Tensor:
        """Run the decoder, then project its first output with ``proj_out``.

        Args:
            decoder_input_ids: Token ids passed as the decoder's first argument.
            encoder_hidden_states: Encoder output passed as the decoder's third
                argument.

        Returns:
            The result of ``proj_out`` applied to the first element of the
            decoder output.
        """
        # Arguments are passed positionally: ids, attention mask, encoder states.
        dec_out = self.model.model.decoder(
            decoder_input_ids,
            self.decoder_attention_mask,
            encoder_hidden_states,
        )
        # Only the first element of the decoder output is projected.
        lin_out = self.model.proj_out(dec_out[0])
        return lin_out


@pytest.mark.nightly
@pytest.mark.model_analysis
@pytest.mark.xfail(
    reason='RuntimeError: TT_ASSERT @ /tt-forge-fe/forge/csrc/passes/commute_utils.cpp:1103: reshape->op_name() == "reshape"'
)
@pytest.mark.parametrize("variant", ["openai/whisper-large-v3-turbo"])
def test_whisper_large_v3_speech_translation(variant):
    """Compile the Whisper decoder + output projection with Forge and verify it.

    Loads the processor, model and config for ``variant``, builds decoder
    inputs from a stored audio sample, compiles the wrapped model with
    ``forge.compile`` and passes the result to ``verify``.

    Args:
        variant: Hugging Face model identifier supplied by ``parametrize``.
    """
    # `verify` and `VerifyConfig` are not imported at module level, so the
    # final call would raise NameError without these imports.
    # NOTE(review): module paths assumed from the forge package layout - confirm.
    from forge.verify.config import VerifyConfig
    from forge.verify.verify import verify

    processor = WhisperProcessor.from_pretrained(variant)
    framework_model = WhisperForConditionalGeneration.from_pretrained(variant)
    model_config = WhisperConfig.from_pretrained(variant)
    model = Wrapper(framework_model)

    # Load the stored audio sample and turn it into encoder input features.
    sample = torch.load("forge/test/models/files/samples/audio/1272-128104-0000.pt")
    sample_audio = sample["audio"]["array"]
    inputs = processor(sample_audio, return_tensors="pt", sampling_rate=16000)
    input_features = inputs.input_features

    # Get decoder inputs: two copies of the decoder start token, as int32.
    decoder_input_ids = torch.tensor([[1, 1]]) * model_config.decoder_start_token_id
    decoder_input_ids = decoder_input_ids.to(torch.int32)
    # Run the encoder once up front; only its first output is used, detached
    # and cast to float32, so the compiled module covers the decoder path only.
    encoder_outputs = model.model.model.encoder(input_features)[0].detach()
    encoder_outputs = encoder_outputs.to(torch.float32)
    data_input = [decoder_input_ids, encoder_outputs]

    # Compiler test
    compiled_model = forge.compile(
        model, sample_inputs=data_input, module_name="pt_" + str(variant.split("/")[-1].replace("-", "_"))
    )

    verify(data_input, model, compiled_model, VerifyConfig(verify_allclose=False))

0 comments on commit a08b563

Please sign in to comment.