Skip to content

Commit

Permalink
fix generation test
Browse files Browse the repository at this point in the history
  • Loading branch information
patrickvonplaten committed Oct 11, 2023
1 parent 6aa60fe commit d59bcc6
Showing 1 changed file with 13 additions and 1 deletion.
14 changes: 13 additions & 1 deletion tests/generation/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2953,7 +2953,8 @@ def forward(self, input_ids, foo=False, **kwargs):

return outs

def prepare_inputs_for_generation(self, *args, foo=False, **kwargs):
def prepare_inputs_for_generation(self, *args, foo=False, encoder_outputs=None, **kwargs):
kwargs["encoder_outputs"] = encoder_outputs
inputs = super().prepare_inputs_for_generation(*args, **kwargs)

inputs["foo"] = foo
Expand Down Expand Up @@ -2992,3 +2993,14 @@ def prepare_inputs_for_generation(self, *args, foo=False, **kwargs):
assistant_model=assistant,
)
self.assertListEqual(outputs_assisted.tolist(), outputs_foo.tolist())

# Check that passing encoder_outputs directly also works as expected
encoder_outputs = assistant.get_encoder()(input_ids)

outputs_assisted = model.generate(
foo=True,
assistant_model=assistant,
encoder_outputs=encoder_outputs,
assistant_encoder_outputs=encoder_outputs,
)
self.assertListEqual(outputs_assisted.tolist(), outputs_foo.tolist())

0 comments on commit d59bcc6

Please sign in to comment.