Skip to content

Commit

Permalink
draft of unit integration test
Browse files Browse the repository at this point in the history
  • Loading branch information
ErikKaum committed Jul 31, 2024
1 parent 951430a commit ffdefa1
Showing 1 changed file with 57 additions and 0 deletions.
57 changes: 57 additions & 0 deletions integration-tests/models/test_no_repeat_ngram.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,57 @@
from typing import List
import pytest
import requests

def bloom_560_handle(launcher):
with launcher("bigscience/bloom-560m") as handle:
yield handle


@pytest.fixture(scope="module")
async def bloom_560(bloom_560_handle):
await bloom_560_handle.health(240)
return bloom_560_handle.client

@pytest.mark.release
@pytest.mark.asyncio
async def test_bloom_560m(bloom_560):

base_url = bloom_560.base_url
prompt = "The cat sat on the mat. The cat"

repeated_2grams_control = await call_model(base_url, prompt, 0)
assert len(repeated_2grams_control) > 0, "Expected to find repeated bi-grams in control case"

repeated_2grams_test = await call_model(base_url, prompt, 2)
assert len(repeated_2grams_test) == 0, f"Expected no repeated bi-grams, but found: {repeated_2grams_test}"


async def call_model(base_url, prompt, n_grams):
data = {
"inputs": prompt,
"parameters" : {
"max_new_tokens": 20,
"seed": 42,
"no_repeat_ngram_size": n_grams,
"details": True
}
}
res = requests.post(f"{base_url}/generate", json=data)
res = res.json()

tokens = res['details']['tokens']
token_texts = [token['text'] for token in tokens]

# find repeated 2grams
ngrams = [tuple(token_texts[i:i+2]) for i in range(len(token_texts)-2+1)]
ngram_counts = {}
for ngram in ngrams:
if ngram in ngram_counts:
ngram_counts[ngram] += 1
else:
ngram_counts[ngram] = 1

repeated = [list(ngram) for ngram, count in ngram_counts.items() if count > 1]

return repeated

0 comments on commit ffdefa1

Please sign in to comment.