Skip to content

Commit

Permalink
add string match
Browse files Browse the repository at this point in the history
  • Loading branch information
tcapelle committed Oct 11, 2024
1 parent a22c76b commit 8050e8b
Showing 1 changed file with 32 additions and 2 deletions.
34 changes: 32 additions & 2 deletions weave/flow/scorer/regex_scorer.py
Original file line number Diff line number Diff line change
@@ -1,19 +1,27 @@
import re
from typing import Union
from typing import Union, List, Any

from pydantic import Field

import weave
from weave.flow.scorer.base_scorer import Scorer

class StringScorer(Scorer):
    """Check whether the model output appears in selected columns of a row.

    The check is a case-insensitive substring test (both sides are passed
    through ``str.lower``) against every column named in ``target_columns``.
    """

    # Names of the dataset-row columns whose values are searched.
    target_columns: list[str] = Field(
        default_factory=list,
        description="The names of the columns that are used as input to the scorer",
    )

    def score(self, model_output: Any, dataset_row: dict) -> dict:
        """Score a single model output against one dataset row.

        Args:
            model_output: Value produced by the model. Non-string values are
                converted with ``str`` before matching; ``None`` never
                matches.
            dataset_row: Mapping of column name to column value. Target
                columns that are absent or hold ``None`` are ignored.

        Returns:
            ``{"string_in_input": bool}`` -- ``True`` when the lowercased
            output is a substring of at least one lowercased target column.
            Note that an empty output is a substring of every string, so it
            matches whenever any target column holds a value.
        """
        # A missing output cannot be "found" anywhere; avoid matching the
        # literal text "none" that str(None) would otherwise produce.
        if model_output is None:
            return {"string_in_input": False}

        # Lowercase the needle once instead of once per column.
        needle = str(model_output).lower()

        # Generator (not a list) so any() can stop at the first hit.
        string_in_input = any(
            needle in str(dataset_row[column]).lower()
            for column in self.target_columns
            if dataset_row.get(column) is not None
        )
        return {"string_in_input": string_in_input}

class RegexScorer(Scorer):
patterns: Union[str, list[str]] = Field(
default_factory=list, description="The patterns or keywords to match"
)
ignore_case: bool = True
ignore_whitespace: bool = False
use_regex: bool = False # Use regex patterns if True
match_full_string: bool = False # Match the entire string if True
target_column: str = Field(default="target", description="The class name to match")

Expand Down Expand Up @@ -49,3 +57,25 @@ def score(
)

return {"string_match": match_found}



if __name__ == "__main__":
    # Manual smoke test: score one hand-built row, then run the same scorer
    # through a small Weave evaluation.
    import asyncio

    @weave.op
    def f(col1, col2):
        # Stand-in "model": ignores its inputs and always answers "Hello".
        return "Hello"

    string_scorer = StringScorer(target_columns=["col1", "col2"])

    # --- Direct call to the scorer on a single row -------------------------
    sample_row = {
        "col1": "Hello my name is Morgan",
        "col2": "I am an engineer",
    }
    prediction = f(col1="hello", col2="world")
    print(string_scorer.score(model_output=prediction, dataset_row=sample_row))

    # --- Same scorer driven by weave.Evaluation over a two-row dataset -----
    rows = [
        {
            "col1": "Hello my name is Morgan",
            "col2": "I am an engineer",
            "target": "Morgan",
        },
        {
            "col1": "Hello my name is John",
            "col2": "I am a doctor",
            "target": "John",
        },
    ]
    string_evaluation = weave.Evaluation(dataset=rows, scorers=[string_scorer])

    # evaluate() is a coroutine, so drive it to completion with asyncio.run.
    eval_out = asyncio.run(string_evaluation.evaluate(f))

0 comments on commit 8050e8b

Please sign in to comment.