-
Notifications
You must be signed in to change notification settings - Fork 1
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
[!96][SIMULTANEOUS] Add agent Wait-k with tags
# Why is the change needed? Currently, the wait-k agent does not support the module for the tag prediction of the parallel model introduced in the paper ["Joint Speech Translation and Named Entity Recognition"](https://arxiv.org/pdf/2210.11987.pdf). # What changes does the patch introduce? Implements the wait-k inference with tags produced by the parallel model. # How was this patch tested? UTs and manual runs
- Loading branch information
Showing
4 changed files
with
223 additions
and
2 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
96 changes: 96 additions & 0 deletions
96
examples/speech_to_text/simultaneous_translation/agents/simul_offline_waitk_tags.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,96 @@ | ||
# Copyright 2023 FBK | ||
|
||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
|
||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
|
||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License | ||
import torch | ||
from examples.speech_to_text.simultaneous_translation.agents.simul_offline_waitk import WaitkAgent | ||
|
||
try: | ||
from simuleval import READ_ACTION, WRITE_ACTION, DEFAULT_EOS | ||
from simuleval.agents import SpeechAgent | ||
from simuleval.states import ListEntry, SpeechStates | ||
except ImportError: | ||
print("Please install simuleval 'pip install simuleval'") | ||
|
||
|
||
class WaitkAgentWithTags(WaitkAgent): | ||
def load_model_vocab(self, args): | ||
super().load_model_vocab(args) | ||
self.tags = self.task.data_cfg.tags | ||
|
||
def initialize_states(self, states): | ||
super().initialize_states(states) | ||
# Store previous output tokens without considering emitted tags | ||
states.prev_toks = [] | ||
states.prev_tag = 0 | ||
|
||
def _get_prefix(self, states): | ||
if states.prev_toks: | ||
prefix_tokens = torch.tensor([states.prev_toks], dtype=torch.int64) | ||
if self.prefix_token_idx is not None: | ||
return torch.cat( | ||
(torch.LongTensor([[self.prefix_token_idx]]), prefix_tokens), dim=1) | ||
return prefix_tokens | ||
else: | ||
if self.prefix_token_idx is not None: | ||
return torch.LongTensor([[self.prefix_token_idx]]) | ||
return None | ||
|
||
def add_tags_to_target(self, states, hypo_tag): | ||
hypo_tok = states.write | ||
states.write = [] | ||
for token, tag in zip(hypo_tok, hypo_tag): | ||
if tag != states.prev_tag: | ||
if states.prev_tag == 0: | ||
states.write.append(torch.tensor( | ||
self.tgtdict.index(f"<{self.tags[tag - 1]}>"), dtype=token.dtype)) | ||
elif tag == 0: | ||
states.write.append(torch.tensor( | ||
self.tgtdict.index(f"</{self.tags[states.prev_tag - 1]}>"), dtype=token.dtype)) | ||
else: | ||
states.write.append(torch.tensor( | ||
self.tgtdict.index(f"</{self.tags[states.prev_tag - 1]}>"), dtype=token.dtype)) | ||
states.write.append(torch.tensor( | ||
self.tgtdict.index(f"<{self.tags[tag - 1]}>"), dtype=token.dtype)) | ||
states.write.append(token) | ||
states.prev_tag = tag | ||
|
||
def new_hypo(self, states): | ||
states.new_segment = False | ||
prefix_tokens = self._get_prefix(states) | ||
prefix_len = self._get_prefix_len(prefix_tokens) | ||
hypo = self.generate_hypothesis(states, prefix_tokens) | ||
hypo_tokens = hypo['tokens'].int().cpu() | ||
new_hypo_tokens = hypo_tokens[prefix_len:] | ||
hypo_tags = hypo['tags'].int().cpu() | ||
new_hypo_tags = hypo_tags[prefix_len:] | ||
return new_hypo_tokens, new_hypo_tags | ||
|
||
def waitk_prediction(self, states): | ||
new_hypo, new_tags = self.new_hypo(states) | ||
selected_n_words = states.n_audio_words - (states.n_predicted_words + self.waitk) | ||
states.n_predicted_words += selected_n_words | ||
states.write = self._select_words(new_hypo, selected_n_words) | ||
if states.write: | ||
states.prev_toks += states.write | ||
new_tags = new_tags[:len(states.write)] | ||
if sum(new_tags != 0) > 0 or states.prev_tag != 0: | ||
self.add_tags_to_target(states, new_tags) | ||
return True | ||
return False | ||
|
||
def _emit_remaining_tokens(self, states): | ||
final_hypo, final_tags = self.new_hypo(states) | ||
states.write = final_hypo | ||
if sum(final_tags != 0) > 0 or states.prev_tag != 0: | ||
self.add_tags_to_target(states, final_tags) | ||
return WRITE_ACTION |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,121 @@ | ||
# Copyright 2023 FBK | ||
|
||
# Licensed under the Apache License, Version 2.0 (the "License"); | ||
# you may not use this file except in compliance with the License. | ||
# You may obtain a copy of the License at | ||
|
||
# http://www.apache.org/licenses/LICENSE-2.0 | ||
|
||
# Unless required by applicable law or agreed to in writing, software | ||
# distributed under the License is distributed on an "AS IS" BASIS, | ||
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. | ||
# See the License for the specific language governing permissions and | ||
# limitations under the License | ||
import unittest | ||
from unittest.mock import patch | ||
import copy | ||
|
||
import torch | ||
|
||
from examples.speech_to_text.simultaneous_translation.agents.base_simulst_agent import BOW_PREFIX | ||
from examples.speech_to_text.simultaneous_translation.agents.simul_offline_waitk_tags import WaitkAgentWithTags | ||
|
||
from fbk_uts.simultaneous.test_base_simulst_agent import BaseSTAgentTestCase | ||
|
||
|
||
class WaitkSimulSTWithTagsTestCase(BaseSTAgentTestCase, unittest.TestCase): | ||
def add_extra_args(self): | ||
self.args.waitk = 0 | ||
self.args.parallel = False | ||
|
||
def create_agent(self): | ||
return WaitkAgentWithTags(self.args) | ||
|
||
@patch('examples.speech_to_text.simultaneous_translation.agents.simul_offline_waitk_tags.' | ||
'WaitkAgentWithTags.load_model_vocab') | ||
@patch('examples.speech_to_text.simultaneous_translation.agents.base_simulst_agent.' | ||
'FairseqSimulSTAgent.__init__') | ||
def setUp(self, mock_load_model_vocab, mock_simulst_agent_init): | ||
mock_simulst_agent_init.return_value = None | ||
mock_load_model_vocab.return_value = None | ||
self.base_init() | ||
self.hypo = BOW_PREFIX + "quokka " + BOW_PREFIX + "is " + BOW_PREFIX + "pretty ." | ||
self.agent.tgtdict.add_symbol(BOW_PREFIX + "is") | ||
self.agent.tgtdict.add_symbol(BOW_PREFIX + "pretty") | ||
self.agent.tgtdict.add_symbol("<PERSON>") | ||
self.agent.tgtdict.add_symbol("</PERSON>") | ||
self.agent.tags = ["", "", "", "", "", "", "", "", "", "", "PERSON"] | ||
self.encoded_hypo = self.agent.tgtdict.encode_line(self.hypo, add_if_not_exist=False) | ||
self.predicted_tags = torch.tensor([self.agent.tgtdict.index("<PERSON>"), 0, 0, 0, 0]) | ||
self.states.n_audio_words = 3 | ||
|
||
@patch('examples.speech_to_text.simultaneous_translation.agents.simul_offline_waitk_tags.' | ||
'WaitkAgentWithTags.new_hypo') | ||
def test_full_hypo(self, mock_new_hypo): | ||
# Full hypothesis emitted | ||
mock_new_hypo.return_value = self.encoded_hypo, self.predicted_tags | ||
self.states.n_predicted_words = 0 | ||
WaitkAgentWithTags.waitk_prediction(self.agent, self.states) | ||
self.assertEqual(self.states.write, [11, 7, 12, 9, 10, 8, 2]) | ||
|
||
@patch('examples.speech_to_text.simultaneous_translation.agents.simul_offline_waitk_tags.' | ||
'WaitkAgentWithTags.new_hypo') | ||
def test_wait_1(self, mock_new_hypo): | ||
# Partial hypothesis emitted | ||
mock_new_hypo.return_value = self.encoded_hypo, self.predicted_tags | ||
self.states.n_predicted_words = 0 | ||
new_agent = copy.deepcopy(self.agent) | ||
new_agent.waitk = 1 | ||
WaitkAgentWithTags.waitk_prediction(new_agent, self.states) | ||
self.assertEqual(self.states.write, [11, 7, 12, 9]) | ||
|
||
@patch('examples.speech_to_text.simultaneous_translation.agents.simul_offline_waitk_tags.' | ||
'WaitkAgentWithTags.new_hypo') | ||
def test_wait_1_predicted_1(self, mock_new_hypo): | ||
# Partial hypothesis emitted considering already predicted words | ||
mock_new_hypo.return_value = self.encoded_hypo[1:], self.predicted_tags[1:] | ||
new_agent = copy.deepcopy(self.agent) | ||
new_agent.waitk = 1 | ||
self.states.n_predicted_words = 1 | ||
WaitkAgentWithTags.waitk_prediction(new_agent, self.states) | ||
self.assertEqual(self.states.write, [9]) | ||
|
||
@patch('examples.speech_to_text.simultaneous_translation.agents.simul_offline_waitk_tags.' | ||
'WaitkAgentWithTags.new_hypo') | ||
def test_wait_3(self, mock_new_hypo): | ||
# No hypothesis emitted | ||
mock_new_hypo.return_value = self.encoded_hypo, self.predicted_tags | ||
new_agent = copy.deepcopy(self.agent) | ||
new_agent.waitk = 3 | ||
self.states.n_predicted_words = 0 | ||
WaitkAgentWithTags.waitk_prediction(new_agent, self.states) | ||
self.assertEqual(self.states.write, []) | ||
|
||
@patch('examples.speech_to_text.simultaneous_translation.agents.simul_offline_waitk_tags.' | ||
'WaitkAgentWithTags.new_hypo') | ||
def test_emit_remaining_tokens_with_tags(self, mock_new_hypo): | ||
mock_new_hypo.return_value = self.encoded_hypo, self.predicted_tags | ||
new_agent = copy.deepcopy(self.agent) | ||
new_agent.waitk = 3 | ||
self.states.n_predicted_words = 0 | ||
WaitkAgentWithTags._emit_remaining_tokens(new_agent, self.states) | ||
self.assertEqual(self.states.write, [11, 7, 12, 9, 10, 8, 2]) | ||
|
||
# Move tag towards the end (last word: "pretty") | ||
mock_new_hypo.return_value = self.encoded_hypo, torch.tensor( | ||
[0, 0, self.agent.tgtdict.index("<PERSON>"), 0, 0]) | ||
new_agent = copy.deepcopy(self.agent) | ||
new_agent.waitk = 3 | ||
self.states.n_predicted_words = 0 | ||
WaitkAgentWithTags._emit_remaining_tokens(new_agent, self.states) | ||
self.assertEqual(self.states.write, [7, 9, 11, 10, 12, 8, 2]) | ||
|
||
@patch('examples.speech_to_text.simultaneous_translation.agents.simul_offline_waitk_tags.' | ||
'WaitkAgentWithTags._emit_remaining_tokens') | ||
def test_finish_read(self, mock_emit_remaining_tokens): | ||
mock_emit_remaining_tokens.return_values = None | ||
super().test_finish_read() | ||
|
||
|
||
if __name__ == '__main__': | ||
unittest.main() |