Skip to content

Commit

Permalink
Watermark: fix tests (huggingface#30961)
Browse files Browse the repository at this point in the history
* fix tests

* style

* Update tests/generation/test_utils.py

Co-authored-by: amyeroberts <[email protected]>

---------

Co-authored-by: amyeroberts <[email protected]>
  • Loading branch information
2 people authored and vasqu committed Jun 1, 2024
1 parent f2a7f7c commit 5237955
Showing 1 changed file with 3 additions and 9 deletions.
12 changes: 3 additions & 9 deletions tests/generation/test_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -2148,6 +2148,8 @@ def test_watermark_generation(self):
watermark_config = WatermarkingConfig(bias=2.5, seeding_scheme="selfhash")
_ = model.generate(**model_inputs, watermarking_config=watermark_config, do_sample=False, max_length=15)

# We will not check watermarked text, since we check it in `logits_processors` tests
# Checking if generated ids are as expected fails on different hardware
args = {
"bias": 2.0,
"context_width": 1,
Expand All @@ -2158,19 +2160,11 @@ def test_watermark_generation(self):
output = model.generate(**model_inputs, do_sample=False, max_length=15)
output_selfhash = model.generate(**model_inputs, watermarking_config=args, do_sample=False, max_length=15)

# check that the watermarked text is generating what is should
self.assertListEqual(
output.tolist(), [[40, 481, 307, 262, 717, 284, 9159, 326, 314, 716, 407, 257, 4336, 286, 262]]
)
self.assertListEqual(
output_selfhash.tolist(), [[40, 481, 307, 2263, 616, 640, 284, 651, 616, 1621, 503, 612, 553, 531, 367]]
)

# Check that the detector is detecting watermarked text
detector = WatermarkDetector(model_config=model.config, device=torch_device, watermarking_config=args)
detection_out_watermarked = detector(output_selfhash[:, input_len:], return_dict=True)
detection_out = detector(output[:, input_len:], return_dict=True)

# check that the detector is detecting watermarked text
self.assertListEqual(detection_out_watermarked.prediction.tolist(), [True])
self.assertListEqual(detection_out.prediction.tolist(), [False])

Expand Down

0 comments on commit 5237955

Please sign in to comment.