-
Notifications
You must be signed in to change notification settings - Fork 67
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
feat(weave): fixes tests, summarization scorer re-write, re-names flo…
…w/scorer dir, create weave/scorers dir
- Loading branch information
1 parent
2114c4f
commit 0b2bbf2
Showing
26 changed files
with
505 additions
and
178 deletions.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,37 +1,63 @@ | ||
import pytest | ||
from openai import OpenAI | ||
|
||
from weave.flow.scorer.hallucination_scorer import ( | ||
from weave.flow.scorers.hallucination_scorer import ( | ||
HallucinationReasoning, | ||
HallucinationResponse, | ||
) | ||
from weave.scorers import ( | ||
HallucinationScorer, | ||
) | ||
|
||
|
||
# mock the create function | ||
@pytest.fixture | ||
def mock_create(monkeypatch): | ||
def _mock_create(*args, **kwargs): | ||
return HallucinationResponse( | ||
chain_of_thought="The output is consistent with the input data.", | ||
is_hallucination=False | ||
hallucination_reasonings=[ | ||
HallucinationReasoning( | ||
observation="My observation for this is that the output is consistent with the input data.", | ||
hallucination_type="No Hallucination", | ||
) | ||
], | ||
conclusion="The output is consistent with the input data.", | ||
is_hallucination=False, | ||
) | ||
monkeypatch.setattr('weave.flow.scorer.hallucination_scorer.create', _mock_create) | ||
|
||
monkeypatch.setattr("weave.flow.scorers.hallucination_scorer.create", _mock_create) | ||
|
||
|
||
@pytest.fixture | ||
def hallucination_scorer(mock_create): | ||
return HallucinationScorer(client=OpenAI(api_key="DUMMY_API_KEY"), model_id="gpt-4o", temperature=0.7, max_tokens=4096) | ||
return HallucinationScorer( | ||
client=OpenAI(api_key="DUMMY_API_KEY"), | ||
model_id="gpt-4o", | ||
temperature=0.7, | ||
max_tokens=4096, | ||
) | ||
|
||
|
||
def test_hallucination_scorer_initialization(hallucination_scorer): | ||
assert isinstance(hallucination_scorer, HallucinationScorer) | ||
assert hallucination_scorer.model_id == "gpt-4o" | ||
assert hallucination_scorer.temperature == 0.7 | ||
assert hallucination_scorer.max_tokens == 4096 | ||
|
||
|
||
def test_hallucination_scorer_score(hallucination_scorer, mock_create): | ||
output = "John's favorite cheese is cheddar." | ||
context = "John likes various types of cheese." | ||
result = hallucination_scorer.score(output, context) | ||
result = hallucination_scorer.score(output=output, context=context) | ||
assert isinstance(result, HallucinationResponse) | ||
assert not result.is_hallucination | ||
assert "The output is consistent with the input data." == result.chain_of_thought | ||
|
||
# Add more tests as needed | ||
assert isinstance(result.hallucination_reasonings, list) | ||
assert isinstance(result.hallucination_reasonings[0], HallucinationReasoning) | ||
assert result.chain_of_thought == "The output is consistent with the input data." | ||
assert ( | ||
result.hallucination_reasonings[0].observation | ||
== "My observation for this is that the output is consistent with the input data." | ||
) | ||
assert result.conclusion == "The output is consistent with the input data." | ||
assert result.hallucination_reasonings[0].hallucination_type == "No Hallucination" |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,44 +1,50 @@ | ||
from weave.flow.scorer.json_scorer import JSONScorer | ||
from weave.scorers import ValidJSONScorer | ||
|
||
|
||
def test_json_scorer_valid_json(): | ||
scorer = JSONScorer() | ||
scorer = ValidJSONScorer() | ||
output = '{"city": "San Francisco", "country": "USA"}' | ||
result = scorer.score(output) | ||
assert result["json_valid"] is True | ||
|
||
|
||
def test_json_scorer_invalid_json(): | ||
scorer = JSONScorer() | ||
scorer = ValidJSONScorer() | ||
output = '{"city": "San Francisco", "country": "USA"' | ||
result = scorer.score(output) | ||
assert result["json_valid"] is False | ||
|
||
|
||
def test_json_scorer_non_json_string(): | ||
scorer = JSONScorer() | ||
scorer = ValidJSONScorer() | ||
output = "Just a plain string." | ||
result = scorer.score(output) | ||
assert result["json_valid"] is False | ||
|
||
|
||
def test_json_scorer_valid_json_list(): | ||
scorer = JSONScorer() | ||
output = '[1, 2, 3, 4, 5]' | ||
scorer = ValidJSONScorer() | ||
output = "[1, 2, 3, 4, 5]" | ||
result = scorer.score(output) | ||
assert result["json_valid"] is True | ||
|
||
|
||
def test_json_scorer_nested_json(): | ||
scorer = JSONScorer() | ||
scorer = ValidJSONScorer() | ||
output = '{"person": {"name": "John", "age": 30}, "city": "New York"}' | ||
result = scorer.score(output) | ||
assert result["json_valid"] is True | ||
|
||
|
||
def test_json_scorer_empty_object(): | ||
scorer = JSONScorer() | ||
output = '{}' | ||
scorer = ValidJSONScorer() | ||
output = "{}" | ||
result = scorer.score(output) | ||
assert result["json_valid"] is True | ||
|
||
|
||
def test_json_scorer_empty_list(): | ||
scorer = JSONScorer() | ||
output = '[]' | ||
scorer = ValidJSONScorer() | ||
output = "[]" | ||
result = scorer.score(output) | ||
assert result["json_valid"] is True |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -1,46 +1,55 @@ | ||
import pytest | ||
from pydantic import BaseModel | ||
|
||
from weave.flow.scorer.pydantic_scorer import PydanticScorer | ||
from weave.scorers import PydanticScorer | ||
|
||
|
||
class User(BaseModel): | ||
name: str | ||
age: int | ||
|
||
|
||
@pytest.fixture | ||
def user_scorer(): | ||
return PydanticScorer(model=User) | ||
|
||
|
||
def test_pydantic_scorer_initialization(): | ||
scorer = PydanticScorer(model=User) | ||
assert isinstance(scorer, PydanticScorer) | ||
assert scorer.model == User | ||
|
||
|
||
def test_pydantic_scorer_valid_json_string(user_scorer): | ||
valid_json = '{"name": "John", "age": 30}' | ||
assert user_scorer.score(valid_json) == {"valid_pydantic": True} | ||
|
||
|
||
def test_pydantic_scorer_valid_dict(user_scorer): | ||
valid_dict = {"name": "John", "age": 30} | ||
assert user_scorer.score(valid_dict) == {"valid_pydantic": True} | ||
|
||
|
||
def test_pydantic_scorer_invalid_json_string(user_scorer): | ||
invalid_json = '{"name": "John", "age": "thirty"}' | ||
assert user_scorer.score(invalid_json) == {"valid_pydantic": False} | ||
|
||
|
||
def test_pydantic_scorer_invalid_dict(user_scorer): | ||
invalid_dict = {"name": "John", "age": "thirty"} | ||
assert user_scorer.score(invalid_dict) == {"valid_pydantic": False} | ||
|
||
|
||
def test_pydantic_scorer_missing_field(user_scorer): | ||
missing_field = '{"name": "John"}' | ||
assert user_scorer.score(missing_field) == {"valid_pydantic": False} | ||
|
||
|
||
def test_pydantic_scorer_extra_field(user_scorer): | ||
extra_field = '{"name": "John", "age": 30, "city": "New York"}' | ||
assert user_scorer.score(extra_field) == {"valid_pydantic": True} | ||
|
||
|
||
def test_pydantic_scorer_invalid_input_type(user_scorer): | ||
invalid_input = 123 # Neither a string nor a dict | ||
assert user_scorer.score(invalid_input) == {"valid_pydantic": False} | ||
assert user_scorer.score(invalid_input) == {"valid_pydantic": False} |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Oops, something went wrong.