test(weave): Add large input tests for scorers
devin-ai-integration[bot] and morganmcg1 committed Dec 13, 2024
1 parent 60bf3e0 commit 68c1b5e
Showing 5 changed files with 247 additions and 2 deletions.
21 changes: 21 additions & 0 deletions tests/scorers/test_coherence_scorer.py
@@ -2,6 +2,7 @@

import weave
from weave.scorers.coherence_scorer import CoherenceScorer
from tests.scorers.test_utils import generate_large_text


@pytest.fixture
@@ -93,3 +94,23 @@ def model(input: str):
    assert "coherent" in result["CoherenceScorer"]
    assert result["CoherenceScorer"]["coherent"]["true_count"] == 1
    assert result["CoherenceScorer"]["coherent"]["true_fraction"] == pytest.approx(0.5)


@pytest.mark.asyncio
async def test_coherence_scorer_large_input(coherence_scorer):
    large_text = generate_large_text()

    result = await coherence_scorer.score(
        input="What is the story about?",
        output=large_text
    )

    assert "coherent" in result
    assert "coherence" in result
    assert "coherence_score" in result


@pytest.mark.asyncio
async def test_coherence_scorer_error_handling(coherence_scorer):
    with pytest.raises(ValueError):
        await coherence_scorer.score(input="", output="")
61 changes: 61 additions & 0 deletions tests/scorers/test_context_relevance_scorer.py
@@ -0,0 +1,61 @@
"""Tests for the Context Relevance Scorer."""
import pytest
from weave.scorers.context_relevance_scorer import ContextRelevanceScorer
from tests.scorers.test_utils import generate_large_text, generate_context_and_output


@pytest.fixture
def context_relevance_scorer():
    """Create a context relevance scorer for testing."""
    return ContextRelevanceScorer()


@pytest.mark.asyncio
async def test_context_relevance_scorer_basic(context_relevance_scorer):
    """Test basic functionality of the context relevance scorer."""
    query = "What is the capital of France?"
    context = "Paris is the capital of France. It is known for the Eiffel Tower."
    output = "The capital of France is Paris."

    result = await context_relevance_scorer.score(
        query=query,
        context=context,
        output=output,
        verbose=True
    )

    assert "flagged" in result
    assert "extras" in result
    assert "score" in result["extras"]
    assert "all_spans" in result["extras"]


@pytest.mark.asyncio
async def test_context_relevance_scorer_large_input(context_relevance_scorer):
    """Test the context relevance scorer with large inputs."""
    query = "What is the story about?"
    context, output = generate_context_and_output(100_000, context_ratio=0.8)

    result = await context_relevance_scorer.score(
        query=query,
        context=context,
        output=output,
        verbose=True
    )

    assert "flagged" in result
    assert "extras" in result
    assert "score" in result["extras"]
    assert "all_spans" in result["extras"]


@pytest.mark.asyncio
async def test_context_relevance_scorer_error_handling(context_relevance_scorer):
    """Test error handling in the context relevance scorer."""
    with pytest.raises(ValueError):
        await context_relevance_scorer.score(
            query="",
            context="",
            output="",
            verbose=True
        )
58 changes: 57 additions & 1 deletion tests/scorers/test_hallucination_scorer.py
@@ -4,11 +4,14 @@
import weave
from weave.scorers import (
    HallucinationFreeScorer,
    HallucinationScorer,
    FaithfulnessScorer,
)
from weave.scorers.hallucination_scorer import (
    HallucinationReasoning,
    HallucinationResponse,
)
from tests.scorers.test_utils import generate_large_text, generate_context_and_output


# mock the create function
@@ -40,11 +43,20 @@ def hallucination_scorer(mock_create):
    )


@pytest.fixture
def hallucination_scorer_v2(mock_create):
    return HallucinationScorer()


@pytest.fixture
def faithfulness_scorer(mock_create):
    return FaithfulnessScorer()


def test_hallucination_scorer_score(hallucination_scorer, mock_create):
    output = "John's favorite cheese is cheddar."
    context = "John likes various types of cheese."
    result = hallucination_scorer.score(output=output, context=context)
    # we should be able to do this validation
    _ = HallucinationResponse.model_validate(result)

    assert result["has_hallucination"] == True
@@ -103,3 +115,47 @@ def model(input):
    assert (
        result["HallucinationFreeScorer"]["has_hallucination"]["true_fraction"] == 1.0
    )


@pytest.mark.asyncio
async def test_hallucination_scorer_large_input(hallucination_scorer_v2, mock_create):
    query = "What is the story about?"
    context, output = generate_context_and_output(100_000, context_ratio=0.8)

    result = await hallucination_scorer_v2.score(
        query=query,
        context=context,
        output=output
    )

    assert "flagged" in result
    assert "extras" in result
    assert "score" in result["extras"]


@pytest.mark.asyncio
async def test_faithfulness_scorer_large_input(faithfulness_scorer, mock_create):
    query = "What is the story about?"
    context, output = generate_context_and_output(100_000, context_ratio=0.8)

    result = await faithfulness_scorer.score(
        query=query,
        context=context,
        output=output
    )

    assert "flagged" in result
    assert "extras" in result
    assert "score" in result["extras"]


@pytest.mark.asyncio
async def test_hallucination_scorer_error_handling(hallucination_scorer_v2):
    with pytest.raises(ValueError):
        await hallucination_scorer_v2.score(query="", context="", output="")


@pytest.mark.asyncio
async def test_faithfulness_scorer_error_handling(faithfulness_scorer):
    with pytest.raises(ValueError):
        await faithfulness_scorer.score(query="", context="", output="")
47 changes: 46 additions & 1 deletion tests/scorers/test_moderation_scorer.py
@@ -5,7 +5,8 @@
import torch
from torch import Tensor

from weave.scorers.moderation_scorer import RollingWindowScorer
from weave.scorers.moderation_scorer import RollingWindowScorer, ToxicityScorer, BiasScorer
from tests.scorers.test_utils import generate_large_text


# Define a concrete subclass for testing since RollingWindowScorer is abstract
@@ -99,3 +100,47 @@ async def test_tokenize_input_without_truncation(scorer):
    scorer._tokenizer.assert_called_with(prompt, return_tensors="pt", truncation=False)
    # Assert the tokenized input is as expected
    assert torch.equal(result, expected_tensor.to(scorer.device))


@pytest.fixture
def toxicity_scorer():
    return ToxicityScorer()


@pytest.fixture
def bias_scorer():
    return BiasScorer()


@pytest.mark.asyncio
async def test_toxicity_scorer_large_input(toxicity_scorer):
    large_text = generate_large_text()

    result = await toxicity_scorer.score(large_text)

    assert "extras" in result
    assert all(cat in result["extras"] for cat in [
        "Race/Origin", "Gender/Sex", "Religion", "Ability", "Violence"
    ])


@pytest.mark.asyncio
async def test_bias_scorer_large_input(bias_scorer):
    large_text = generate_large_text()

    result = await bias_scorer.score(large_text)

    assert "extras" in result
    assert all(cat in result["extras"] for cat in ["gender_bias", "racial_bias"])


@pytest.mark.asyncio
async def test_toxicity_scorer_error_handling(toxicity_scorer):
    with pytest.raises(ValueError):
        await toxicity_scorer.score("")


@pytest.mark.asyncio
async def test_bias_scorer_error_handling(bias_scorer):
    with pytest.raises(ValueError):
        await bias_scorer.score("")
62 changes: 62 additions & 0 deletions tests/scorers/test_utils.py
@@ -1,8 +1,70 @@
from typing import Any, Optional
from weave.scorers.utils import stringify


def generate_large_text(tokens: int = 100_000, pattern: Optional[str] = None) -> str:
    if pattern is None:
        pattern = (
            "The quick brown fox jumps over the lazy dog. "
            "A wizard's job is to vex chumps quickly in fog. "
            "Pack my box with five dozen liquor jugs. "
        )

    # Rough heuristic: assume ~1.5 tokens per whitespace-separated word, then
    # repeat the pattern enough times to reach the requested token count.
    words_per_pattern = len(pattern.split())
    tokens_per_pattern = words_per_pattern * 1.5
    multiplier = int(tokens / tokens_per_pattern)

    text = pattern * max(1, multiplier)

    return text


def generate_context_and_output(
    total_tokens: int = 100_000,
    context_ratio: float = 0.5
) -> tuple[str, str]:
    # Split the requested token budget between context and output.
    context_tokens = int(total_tokens * context_ratio)
    output_tokens = total_tokens - context_tokens

    context = generate_large_text(context_tokens)
    output = generate_large_text(output_tokens)

    return context, output


def test_stringify():
    assert stringify("Hello, world!") == "Hello, world!"
    assert stringify(123) == "123"
    assert stringify([1, 2, 3]) == "[\n 1,\n 2,\n 3\n]"
    assert stringify({"a": 1, "b": 2}) == '{\n "a": 1,\n "b": 2\n}'


def test_generate_large_text():
    text = generate_large_text()
    assert len(text) > 0
    words = text.split()
    assert len(words) > 60000

    small_text = generate_large_text(1000)
    assert len(small_text) > 0
    small_words = small_text.split()
    assert len(small_words) > 600

    custom_text = generate_large_text(1000, pattern="Test pattern. ")
    assert len(custom_text) > 0
    assert "Test pattern" in custom_text


def test_generate_context_and_output():
    context, output = generate_context_and_output()
    assert len(context) > 0
    assert len(output) > 0
    context_words = context.split()
    output_words = output.split()
    assert len(context_words) > 30000
    assert len(output_words) > 30000

    context, output = generate_context_and_output(10000, context_ratio=0.8)
    context_words = context.split()
    output_words = output.split()
    assert len(context_words) > len(output_words)
