From d3d452d20c637856d7235ec6368b84e06b93c21a Mon Sep 17 00:00:00 2001 From: Matt Date: Thu, 19 Oct 2023 12:04:25 +0100 Subject: [PATCH] Fix and re-enable ConversationalPipeline tests (#26907) * Fix and re-enable conversationalpipeline tests * Fix the batch test so the change only applies to conversational pipeline --- tests/pipelines/test_pipelines_conversational.py | 8 ++++---- tests/test_pipeline_mixin.py | 9 ++++++--- 2 files changed, 10 insertions(+), 7 deletions(-) diff --git a/tests/pipelines/test_pipelines_conversational.py b/tests/pipelines/test_pipelines_conversational.py index 2f6ba61340f667..c85eb04c1957ef 100644 --- a/tests/pipelines/test_pipelines_conversational.py +++ b/tests/pipelines/test_pipelines_conversational.py @@ -77,14 +77,14 @@ def get_test_pipeline(self, model, tokenizer, processor): def run_pipeline_test(self, conversation_agent, _): # Simple - outputs = conversation_agent(Conversation("Hi there!")) + outputs = conversation_agent(Conversation("Hi there!"), max_new_tokens=20) self.assertEqual( outputs, Conversation([{"role": "user", "content": "Hi there!"}, {"role": "assistant", "content": ANY(str)}]), ) # Single list - outputs = conversation_agent([Conversation("Hi there!")]) + outputs = conversation_agent([Conversation("Hi there!")], max_new_tokens=20) self.assertEqual( outputs, Conversation([{"role": "user", "content": "Hi there!"}, {"role": "assistant", "content": ANY(str)}]), @@ -96,7 +96,7 @@ def run_pipeline_test(self, conversation_agent, _): self.assertEqual(len(conversation_1), 1) self.assertEqual(len(conversation_2), 1) - outputs = conversation_agent([conversation_1, conversation_2]) + outputs = conversation_agent([conversation_1, conversation_2], max_new_tokens=20) self.assertEqual(outputs, [conversation_1, conversation_2]) self.assertEqual( outputs, @@ -118,7 +118,7 @@ def run_pipeline_test(self, conversation_agent, _): # One conversation with history conversation_2.add_message({"role": "user", "content": "Why do you recommend it?"}) - outputs = conversation_agent(conversation_2) + outputs = conversation_agent(conversation_2, max_new_tokens=20) self.assertEqual(outputs, conversation_2) self.assertEqual( outputs, diff --git a/tests/test_pipeline_mixin.py b/tests/test_pipeline_mixin.py index bf01d29a92a0b6..0c07248ab065d1 100644 --- a/tests/test_pipeline_mixin.py +++ b/tests/test_pipeline_mixin.py @@ -312,8 +312,12 @@ def data(n): yield copy.deepcopy(random.choice(examples)) out = [] - for item in pipeline(data(10), batch_size=4): - out.append(item) + if task == "conversational": + for item in pipeline(data(10), batch_size=4, max_new_tokens=20): + out.append(item) + else: + for item in pipeline(data(10), batch_size=4): + out.append(item) self.assertEqual(len(out), 10) run_batch_test(pipeline, examples) @@ -327,7 +331,6 @@ def test_pipeline_automatic_speech_recognition(self): self.run_task_tests(task="automatic-speech-recognition") @is_pipeline_test - @unittest.skip("Conversational tests are currently broken for several models, will fix ASAP - Matt") def test_pipeline_conversational(self): self.run_task_tests(task="conversational")