diff --git a/forge/test/models/pytorch/audio/stereo/__init__.py b/forge/test/models/pytorch/audio/stereo/__init__.py
new file mode 100644
index 000000000..e69de29bb
diff --git a/forge/test/models/pytorch/audio/stereo/test_stereo.py b/forge/test/models/pytorch/audio/stereo/test_stereo.py
new file mode 100644
index 000000000..340e4d070
--- /dev/null
+++ b/forge/test/models/pytorch/audio/stereo/test_stereo.py
@@ -0,0 +1,35 @@
+# SPDX-FileCopyrightText: © 2024 Tenstorrent AI ULC
+
+# SPDX-License-Identifier: Apache-2.0
+
+
+import pytest
+
+import forge
+from forge.verify.verify import verify
+
+from .utils import load_inputs, load_model
+
+
+variants = [
+    "facebook/musicgen-small",
+    "facebook/musicgen-medium",
+    "facebook/musicgen-large",
+]
+
+
+@pytest.mark.nightly
+@pytest.mark.model_analysis
+@pytest.mark.parametrize("variant", variants)
+@pytest.mark.xfail(reason="[optimized_graph] Trying to access element outside of dimensions: 3")
+def test_stereo(variant):
+    # Issue: https://github.com/tenstorrent/tt-forge-fe/issues/615
+
+    framework_model, processor = load_model(variant)
+
+    input_ids, attn_mask, decoder_input_ids = load_inputs(framework_model, processor)
+    inputs = [input_ids, attn_mask, decoder_input_ids]
+
+    compiled_model = forge.compile(framework_model, sample_inputs=inputs)
+
+    verify(inputs, framework_model, compiled_model)
diff --git a/forge/test/models/pytorch/audio/stereo/utils/__init__.py b/forge/test/models/pytorch/audio/stereo/utils/__init__.py
new file mode 100644
index 000000000..f51d90ac6
--- /dev/null
+++ b/forge/test/models/pytorch/audio/stereo/utils/__init__.py
@@ -0,0 +1,4 @@
+# SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
+#
+# SPDX-License-Identifier: Apache-2.0
+from .utils import load_inputs, load_model
diff --git a/forge/test/models/pytorch/audio/stereo/utils/utils.py b/forge/test/models/pytorch/audio/stereo/utils/utils.py
new file mode 100644
index 000000000..a77fb59ed
--- /dev/null
+++ b/forge/test/models/pytorch/audio/stereo/utils/utils.py
@@ -0,0 +1,31 @@
+# SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
+#
+# SPDX-License-Identifier: Apache-2.0
+import torch
+from transformers import AutoProcessor, MusicgenForConditionalGeneration
+
+from .wrapper import Wrapper
+
+
+def load_model(variant):
+    processor = AutoProcessor.from_pretrained(variant)
+    model = MusicgenForConditionalGeneration.from_pretrained(variant)
+    model = Wrapper(model)
+    return model, processor
+
+
+def load_inputs(model, processor):
+    inputs = processor(
+        text=["80s pop track with bassy drums and synth", "90s rock song with loud guitars and heavy drums"],
+        padding=True,
+        return_tensors="pt",
+    )
+    input_ids = inputs["input_ids"]
+    attn_mask = inputs["attention_mask"]
+
+    pad_token_id = model.model.generation_config.pad_token_id
+    decoder_input_ids = (
+        torch.ones((inputs.input_ids.shape[0] * model.model.decoder.num_codebooks, 1), dtype=torch.long) * pad_token_id
+    )
+
+    return input_ids, attn_mask, decoder_input_ids
diff --git a/forge/test/models/pytorch/audio/stereo/utils/wrapper.py b/forge/test/models/pytorch/audio/stereo/utils/wrapper.py
new file mode 100644
index 000000000..c549a13ea
--- /dev/null
+++ b/forge/test/models/pytorch/audio/stereo/utils/wrapper.py
@@ -0,0 +1,15 @@
+# SPDX-FileCopyrightText: (c) 2024 Tenstorrent AI ULC
+#
+# SPDX-License-Identifier: Apache-2.0
+import torch
+
+
+class Wrapper(torch.nn.Module):
+    def __init__(self, model):
+        super().__init__()
+        self.model = model
+
+    def forward(self, input_ids, attention_mask, decoder_input_ids):
+        inputs = {"input_ids": input_ids, "attention_mask": attention_mask, "decoder_input_ids": decoder_input_ids}
+        output = self.model(**inputs)
+        return output.logits