diff --git a/examples/speech_to_text/inference/sequence_generator_tagged.py b/examples/speech_to_text/inference/sequence_generator_tagged.py index c623c248..d7997369 100644 --- a/examples/speech_to_text/inference/sequence_generator_tagged.py +++ b/examples/speech_to_text/inference/sequence_generator_tagged.py @@ -95,6 +95,7 @@ def _generate( prefix_tokens: Optional[Tensor] = None, constraints: Optional[Tensor] = None, bos_token: Optional[int] = None, + pre_computed_encoder_outs: Optional[Tensor] = None, ): incremental_states = torch.jit.annotate( List[Dict[str, Dict[str, Optional[Tensor]]]], @@ -147,7 +148,10 @@ def _generate( self.min_len <= max_len ), "min_len cannot be larger than max_len, please adjust these!" # compute the encoder output for each beam - encoder_outs = self.model.forward_encoder(net_input) + if pre_computed_encoder_outs is not None: + encoder_outs = pre_computed_encoder_outs + else: + encoder_outs = self.model.forward_encoder(net_input) # placeholder of indices for bsz * beam_size to hold tokens and accumulative scores new_order = torch.arange(bsz).view(-1, 1).repeat(1, beam_size).view(-1) diff --git a/examples/speech_to_text/simultaneous_translation/agents/simul_offline_waitk_tags.py b/examples/speech_to_text/simultaneous_translation/agents/simul_offline_waitk_tags.py new file mode 100644 index 00000000..8141a348 --- /dev/null +++ b/examples/speech_to_text/simultaneous_translation/agents/simul_offline_waitk_tags.py @@ -0,0 +1,96 @@ +# Copyright 2023 FBK + +# Licensed under the Apache License, Version 2.0 (the "License"); +# you may not use this file except in compliance with the License. +# You may obtain a copy of the License at + +# http://www.apache.org/licenses/LICENSE-2.0 + +# Unless required by applicable law or agreed to in writing, software +# distributed under the License is distributed on an "AS IS" BASIS, +# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. 
class WaitkAgentWithTags(WaitkAgent):
    """
    Wait-k agent that interleaves tag symbols with the emitted tokens.

    The hypothesis returned by the generator carries, besides ``tokens``,
    a parallel ``tags`` sequence. Tag id ``0`` means "no tag"; a tag id
    ``t > 0`` refers to ``self.tags[t - 1]``. Whenever the tag id changes
    between consecutive emitted tokens, the corresponding ``</name>`` and/or
    ``<name>`` symbols of the target dictionary are inserted into the output.
    """

    def load_model_vocab(self, args):
        """Load model and vocabulary through the parent, then read the tag list
        from the task data configuration."""
        super().load_model_vocab(args)
        self.tags = self.task.data_cfg.tags

    def initialize_states(self, states):
        """Reset the per-utterance state used for tag handling."""
        super().initialize_states(states)
        # Previously emitted tokens, stored WITHOUT the tag symbols that were
        # written out, so that they can be fed back as generation prefix.
        states.prev_toks = []
        # Tag id of the last emitted token (0 = no tag currently open).
        states.prev_tag = 0

    def _get_prefix(self, states):
        """
        Build the prefix forced on the generator.

        Returns a ``1 x len`` tensor made of the optional ``prefix_token_idx``
        followed by the previously emitted tokens, or ``None`` when there is
        neither a prefix token nor any previous token.
        """
        forced = None
        if self.prefix_token_idx is not None:
            forced = torch.LongTensor([[self.prefix_token_idx]])
        if not states.prev_toks:
            return forced
        previous = torch.tensor([states.prev_toks], dtype=torch.int64)
        if forced is None:
            return previous
        return torch.cat((forced, previous), dim=1)

    def _tag_symbol(self, tag, dtype, closing=False):
        """Return the dictionary index of the opening (or closing) symbol of
        tag id *tag* as a 0-dim tensor of the given *dtype*."""
        slash = "/" if closing else ""
        return torch.tensor(
            self.tgtdict.index(f"<{slash}{self.tags[tag - 1]}>"), dtype=dtype)

    @staticmethod
    def _tags_to_emit(states, tags):
        """True when *tags* contains a non-zero tag or a tag is still open."""
        return bool((tags != 0).any()) or states.prev_tag != 0

    def add_tags_to_target(self, states, hypo_tag):
        """
        Rewrite ``states.write`` inserting tag symbols where the tag changes.

        ``hypo_tag`` is aligned with the tokens currently in ``states.write``.
        On each change of tag id the previously open tag (if any) is closed
        and the new tag (if any) is opened before the token is appended.
        ``states.prev_tag`` is updated to the tag of the last processed token.
        """
        hypo_tok = states.write
        states.write = []
        for token, tag in zip(hypo_tok, hypo_tag):
            # Work with plain ints so that comparisons and list indexing do
            # not depend on 0-dim tensor semantics.
            tag = int(tag)
            if tag != states.prev_tag:
                if states.prev_tag != 0:
                    states.write.append(
                        self._tag_symbol(states.prev_tag, token.dtype, closing=True))
                if tag != 0:
                    states.write.append(self._tag_symbol(tag, token.dtype))
            states.write.append(token)
            states.prev_tag = tag

    def new_hypo(self, states):
        """
        Generate a hypothesis and return only its new part.

        Returns the pair ``(tokens, tags)`` with the forced prefix removed
        from both sequences.
        """
        states.new_segment = False
        prefix_tokens = self._get_prefix(states)
        prefix_len = self._get_prefix_len(prefix_tokens)
        hypo = self.generate_hypothesis(states, prefix_tokens)
        new_hypo_tokens = hypo['tokens'].int().cpu()[prefix_len:]
        new_hypo_tags = hypo['tags'].int().cpu()[prefix_len:]
        return new_hypo_tokens, new_hypo_tags

    def waitk_prediction(self, states):
        """
        Apply the wait-k policy to a freshly generated hypothesis.

        Returns ``True`` if something has to be written (``states.write`` is
        then populated, tags included), ``False`` otherwise.
        """
        new_hypo, new_tags = self.new_hypo(states)
        selected_n_words = states.n_audio_words - (states.n_predicted_words + self.waitk)
        states.n_predicted_words += selected_n_words
        states.write = self._select_words(new_hypo, selected_n_words)
        if not states.write:
            return False
        # The history is updated before tags are inserted: tag symbols must
        # not become part of the prefix of the next generation step.
        states.prev_toks += states.write
        new_tags = new_tags[:len(states.write)]
        if self._tags_to_emit(states, new_tags):
            self.add_tags_to_target(states, new_tags)
        return True

    def _emit_remaining_tokens(self, states):
        """Write out the whole remaining hypothesis (with tags) and return
        ``WRITE_ACTION``."""
        final_hypo, final_tags = self.new_hypo(states)
        states.write = final_hypo
        if self._tags_to_emit(states, final_tags):
            self.add_tags_to_target(states, final_tags)
        return WRITE_ACTION
class WaitkSimulSTWithTagsTestCase(BaseSTAgentTestCase, unittest.TestCase):
    """
    Unit tests for the wait-k agent emitting tags.

    The reference hypothesis is "quokka is pretty ." and a single ``PERSON``
    tag (tag id 11, i.e. ``tags[10]``) is attached to one of its tokens; the
    expected outputs contain the ``<PERSON>`` (11) and ``</PERSON>`` (12)
    dictionary symbols around the tagged token.
    """

    def add_extra_args(self):
        self.args.waitk = 0
        self.args.parallel = False

    def create_agent(self):
        return WaitkAgentWithTags(self.args)

    # Stacked patch decorators inject their mocks bottom-up: the mock of the
    # decorator closest to the function is the first positional argument.
    @patch('examples.speech_to_text.simultaneous_translation.agents.simul_offline_waitk_tags.'
           'WaitkAgentWithTags.load_model_vocab')
    @patch('examples.speech_to_text.simultaneous_translation.agents.base_simulst_agent.'
           'FairseqSimulSTAgent.__init__')
    def setUp(self, mock_simulst_agent_init, mock_load_model_vocab):
        mock_simulst_agent_init.return_value = None
        mock_load_model_vocab.return_value = None
        self.base_init()
        self.hypo = BOW_PREFIX + "quokka " + BOW_PREFIX + "is " + BOW_PREFIX + "pretty ."
        self.agent.tgtdict.add_symbol(BOW_PREFIX + "is")
        self.agent.tgtdict.add_symbol(BOW_PREFIX + "pretty")
        # Opening and closing tag symbols must be two distinct dictionary
        # entries (ids 11 and 12 in the assertions below).
        self.agent.tgtdict.add_symbol("<PERSON>")
        self.agent.tgtdict.add_symbol("</PERSON>")
        self.agent.tags = ["", "", "", "", "", "", "", "", "", "", "PERSON"]
        self.encoded_hypo = self.agent.tgtdict.encode_line(self.hypo, add_if_not_exist=False)
        self.predicted_tags = torch.tensor(
            [self.agent.tgtdict.index("<PERSON>"), 0, 0, 0, 0])
        self.states.n_audio_words = 3

    @patch('examples.speech_to_text.simultaneous_translation.agents.simul_offline_waitk_tags.'
           'WaitkAgentWithTags.new_hypo')
    def test_full_hypo(self, mock_new_hypo):
        # Full hypothesis emitted
        mock_new_hypo.return_value = self.encoded_hypo, self.predicted_tags
        self.states.n_predicted_words = 0
        WaitkAgentWithTags.waitk_prediction(self.agent, self.states)
        self.assertEqual(self.states.write, [11, 7, 12, 9, 10, 8, 2])

    @patch('examples.speech_to_text.simultaneous_translation.agents.simul_offline_waitk_tags.'
           'WaitkAgentWithTags.new_hypo')
    def test_wait_1(self, mock_new_hypo):
        # Partial hypothesis emitted
        mock_new_hypo.return_value = self.encoded_hypo, self.predicted_tags
        self.states.n_predicted_words = 0
        new_agent = copy.deepcopy(self.agent)
        new_agent.waitk = 1
        WaitkAgentWithTags.waitk_prediction(new_agent, self.states)
        self.assertEqual(self.states.write, [11, 7, 12, 9])

    @patch('examples.speech_to_text.simultaneous_translation.agents.simul_offline_waitk_tags.'
           'WaitkAgentWithTags.new_hypo')
    def test_wait_1_predicted_1(self, mock_new_hypo):
        # Partial hypothesis emitted considering already predicted words
        mock_new_hypo.return_value = self.encoded_hypo[1:], self.predicted_tags[1:]
        new_agent = copy.deepcopy(self.agent)
        new_agent.waitk = 1
        self.states.n_predicted_words = 1
        WaitkAgentWithTags.waitk_prediction(new_agent, self.states)
        self.assertEqual(self.states.write, [9])

    @patch('examples.speech_to_text.simultaneous_translation.agents.simul_offline_waitk_tags.'
           'WaitkAgentWithTags.new_hypo')
    def test_wait_3(self, mock_new_hypo):
        # No hypothesis emitted
        mock_new_hypo.return_value = self.encoded_hypo, self.predicted_tags
        new_agent = copy.deepcopy(self.agent)
        new_agent.waitk = 3
        self.states.n_predicted_words = 0
        WaitkAgentWithTags.waitk_prediction(new_agent, self.states)
        self.assertEqual(self.states.write, [])

    @patch('examples.speech_to_text.simultaneous_translation.agents.simul_offline_waitk_tags.'
           'WaitkAgentWithTags.new_hypo')
    def test_emit_remaining_tokens_with_tags(self, mock_new_hypo):
        # Tag on the first word: "quokka"
        mock_new_hypo.return_value = self.encoded_hypo, self.predicted_tags
        new_agent = copy.deepcopy(self.agent)
        new_agent.waitk = 3
        self.states.n_predicted_words = 0
        WaitkAgentWithTags._emit_remaining_tokens(new_agent, self.states)
        self.assertEqual(self.states.write, [11, 7, 12, 9, 10, 8, 2])

        # Move tag towards the end (last word: "pretty")
        mock_new_hypo.return_value = self.encoded_hypo, torch.tensor(
            [0, 0, self.agent.tgtdict.index("<PERSON>"), 0, 0])
        new_agent = copy.deepcopy(self.agent)
        new_agent.waitk = 3
        self.states.n_predicted_words = 0
        WaitkAgentWithTags._emit_remaining_tokens(new_agent, self.states)
        self.assertEqual(self.states.write, [7, 9, 11, 10, 12, 8, 2])

    @patch('examples.speech_to_text.simultaneous_translation.agents.simul_offline_waitk_tags.'
           'WaitkAgentWithTags._emit_remaining_tokens')
    def test_finish_read(self, mock_emit_remaining_tokens):
        # `return_value` (singular) is the attribute Mock honours; setting
        # `return_values` would silently configure nothing.
        mock_emit_remaining_tokens.return_value = None
        super().test_finish_read()