Skip to content

Commit

Permalink
Fix and re-enable conversationalpipeline tests
Browse files Browse the repository at this point in the history
  • Loading branch information
Rocketknight1 committed Oct 18, 2023
1 parent de55ead commit 1d51daf
Show file tree
Hide file tree
Showing 2 changed files with 5 additions and 6 deletions.
8 changes: 4 additions & 4 deletions tests/pipelines/test_pipelines_conversational.py
Original file line number Diff line number Diff line change
Expand Up @@ -77,14 +77,14 @@ def get_test_pipeline(self, model, tokenizer, processor):

def run_pipeline_test(self, conversation_agent, _):
# Simple
outputs = conversation_agent(Conversation("Hi there!"))
outputs = conversation_agent(Conversation("Hi there!"), max_new_tokens=20)
self.assertEqual(
outputs,
Conversation([{"role": "user", "content": "Hi there!"}, {"role": "assistant", "content": ANY(str)}]),
)

# Single list
outputs = conversation_agent([Conversation("Hi there!")])
outputs = conversation_agent([Conversation("Hi there!")], max_new_tokens=20)
self.assertEqual(
outputs,
Conversation([{"role": "user", "content": "Hi there!"}, {"role": "assistant", "content": ANY(str)}]),
Expand All @@ -96,7 +96,7 @@ def run_pipeline_test(self, conversation_agent, _):
self.assertEqual(len(conversation_1), 1)
self.assertEqual(len(conversation_2), 1)

outputs = conversation_agent([conversation_1, conversation_2])
outputs = conversation_agent([conversation_1, conversation_2], max_new_tokens=20)
self.assertEqual(outputs, [conversation_1, conversation_2])
self.assertEqual(
outputs,
Expand All @@ -118,7 +118,7 @@ def run_pipeline_test(self, conversation_agent, _):

# One conversation with history
conversation_2.add_message({"role": "user", "content": "Why do you recommend it?"})
outputs = conversation_agent(conversation_2)
outputs = conversation_agent(conversation_2, max_new_tokens=20)
self.assertEqual(outputs, conversation_2)
self.assertEqual(
outputs,
Expand Down
3 changes: 1 addition & 2 deletions tests/test_pipeline_mixin.py
Original file line number Diff line number Diff line change
Expand Up @@ -312,7 +312,7 @@ def data(n):
yield copy.deepcopy(random.choice(examples))

out = []
for item in pipeline(data(10), batch_size=4):
for item in pipeline(data(10), batch_size=4, max_new_tokens=20):
out.append(item)
self.assertEqual(len(out), 10)

Expand All @@ -327,7 +327,6 @@ def test_pipeline_automatic_speech_recognition(self):
self.run_task_tests(task="automatic-speech-recognition")

@is_pipeline_test
@unittest.skip("Conversational tests are currently broken for several models, will fix ASAP - Matt")
def test_pipeline_conversational(self):
self.run_task_tests(task="conversational")

Expand Down

0 comments on commit 1d51daf

Please sign in to comment.