From 1d51daf726c81a5b854ff6c9a68e71698b030ce8 Mon Sep 17 00:00:00 2001
From: Matt
Date: Wed, 18 Oct 2023 15:55:17 +0100
Subject: [PATCH 1/2] Fix and re-enable conversational pipeline tests

---
 tests/pipelines/test_pipelines_conversational.py | 8 ++++----
 tests/test_pipeline_mixin.py                      | 3 +--
 2 files changed, 5 insertions(+), 6 deletions(-)

diff --git a/tests/pipelines/test_pipelines_conversational.py b/tests/pipelines/test_pipelines_conversational.py
index 2f6ba61340f667..c85eb04c1957ef 100644
--- a/tests/pipelines/test_pipelines_conversational.py
+++ b/tests/pipelines/test_pipelines_conversational.py
@@ -77,14 +77,14 @@ def get_test_pipeline(self, model, tokenizer, processor):
 
     def run_pipeline_test(self, conversation_agent, _):
         # Simple
-        outputs = conversation_agent(Conversation("Hi there!"))
+        outputs = conversation_agent(Conversation("Hi there!"), max_new_tokens=20)
         self.assertEqual(
             outputs,
             Conversation([{"role": "user", "content": "Hi there!"}, {"role": "assistant", "content": ANY(str)}]),
         )
 
         # Single list
-        outputs = conversation_agent([Conversation("Hi there!")])
+        outputs = conversation_agent([Conversation("Hi there!")], max_new_tokens=20)
         self.assertEqual(
             outputs,
             Conversation([{"role": "user", "content": "Hi there!"}, {"role": "assistant", "content": ANY(str)}]),
@@ -96,7 +96,7 @@ def run_pipeline_test(self, conversation_agent, _):
         self.assertEqual(len(conversation_1), 1)
         self.assertEqual(len(conversation_2), 1)
 
-        outputs = conversation_agent([conversation_1, conversation_2])
+        outputs = conversation_agent([conversation_1, conversation_2], max_new_tokens=20)
         self.assertEqual(outputs, [conversation_1, conversation_2])
         self.assertEqual(
             outputs,
@@ -118,7 +118,7 @@ def run_pipeline_test(self, conversation_agent, _):
 
         # One conversation with history
         conversation_2.add_message({"role": "user", "content": "Why do you recommend it?"})
-        outputs = conversation_agent(conversation_2)
+        outputs = conversation_agent(conversation_2, max_new_tokens=20)
         self.assertEqual(outputs, conversation_2)
         self.assertEqual(
             outputs,
diff --git a/tests/test_pipeline_mixin.py b/tests/test_pipeline_mixin.py
index bf01d29a92a0b6..3cb3e49d6e2912 100644
--- a/tests/test_pipeline_mixin.py
+++ b/tests/test_pipeline_mixin.py
@@ -312,7 +312,7 @@ def data(n):
                 yield copy.deepcopy(random.choice(examples))
 
             out = []
-            for item in pipeline(data(10), batch_size=4):
+            for item in pipeline(data(10), batch_size=4, max_new_tokens=20):
                 out.append(item)
             self.assertEqual(len(out), 10)
 
@@ -327,7 +327,6 @@ def test_pipeline_automatic_speech_recognition(self):
         self.run_task_tests(task="automatic-speech-recognition")
 
     @is_pipeline_test
-    @unittest.skip("Conversational tests are currently broken for several models, will fix ASAP - Matt")
     def test_pipeline_conversational(self):
         self.run_task_tests(task="conversational")
 

From 2114343696890d88d3b5945e7ab6e62faf477480 Mon Sep 17 00:00:00 2001
From: Matt
Date: Wed, 18 Oct 2023 16:16:35 +0100
Subject: [PATCH 2/2] Fix the batch test so the change only applies to conversational pipeline

---
 tests/test_pipeline_mixin.py | 8 ++++++--
 1 file changed, 6 insertions(+), 2 deletions(-)

diff --git a/tests/test_pipeline_mixin.py b/tests/test_pipeline_mixin.py
index 3cb3e49d6e2912..0c07248ab065d1 100644
--- a/tests/test_pipeline_mixin.py
+++ b/tests/test_pipeline_mixin.py
@@ -312,8 +312,12 @@ def data(n):
                 yield copy.deepcopy(random.choice(examples))
 
             out = []
-            for item in pipeline(data(10), batch_size=4, max_new_tokens=20):
-                out.append(item)
+            if task == "conversational":
+                for item in pipeline(data(10), batch_size=4, max_new_tokens=20):
+                    out.append(item)
+            else:
+                for item in pipeline(data(10), batch_size=4):
+                    out.append(item)
             self.assertEqual(len(out), 10)
 
             run_batch_test(pipeline, examples)
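
For reference, the change caps generation length when the conversational pipeline is exercised, so the test calls above return quickly. Below is a minimal sketch of the call pattern being tested, assuming a small public chat checkpoint (microsoft/DialoGPT-small is an illustrative choice, not something this patch specifies; the tests themselves run against tiny test models):

# Minimal sketch of the call pattern the re-enabled tests exercise.
# The checkpoint name is an illustrative assumption, not part of this patch.
from transformers import Conversation, pipeline

conversation_agent = pipeline("conversational", model="microsoft/DialoGPT-small")

conversation = Conversation("Hi there!")
# max_new_tokens is forwarded to generate(), capping the reply length so a run
# does not spend time producing long completions.
conversation = conversation_agent(conversation, max_new_tokens=20)

print(conversation.messages[-1]["content"])  # the assistant's reply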