Skip to content

Commit

Permalink
Corrected FDT, SDT metrics in case of empty outputs (#733)
Browse files Browse the repository at this point in the history
  • Loading branch information
ljaljushkin authored Aug 8, 2024
1 parent ab57b24 commit e505352
Showing 1 changed file with 52 additions and 35 deletions.
87 changes: 52 additions & 35 deletions llm_bench/python/who_what_benchmark/whowhatbench/whowhat_metrics.py
Original file line number Diff line number Diff line change
Expand Up @@ -30,44 +30,61 @@ def evaluate_divergency(tokenizer, data_gold, data_prediction):

DEBUG = False
# NOTE: a - reference answers, b - answers to evaluate
fdt_list, sdt_list, sdtn_list, fdt_max = [], [], [], []

fdt_list = [] # each value = the position of first divergent (different) token.
sdt_list = [] # each value = number of tokens to correct in the prediction.
sdtn_list = [] # each value = share of tokens to correct in the prediction
fdt_max = [] # each value = total number of tokens in the reference
for a_answer, b_answer in zip(answers_gold, answers_prediction):
a_indexes = tokenizer.encode(a_answer, return_tensors="pt").squeeze().tolist()
b_indexes = tokenizer.encode(b_answer, return_tensors="pt").squeeze().tolist()
if isinstance(a_indexes, int):
a_indexes = list([a_indexes])
if isinstance(b_indexes, int):
b_indexes = list([b_indexes])
fdt_max.append(len(a_indexes))

matcher = SequenceMatcher(None, a_indexes, b_indexes)
blocks = matcher.get_matching_blocks()
a, b, size = blocks[0]
fdt = 0
if a == 0 and b == 0:
fdt = blocks[0].size
fdt_list.append(fdt)

num_matched = sum(block.size for block in blocks)
sdt = (
len(b_indexes) - num_matched
) # how many tokens to correct in the prediction
sdt_list.append(sdt)
sdt_norm = sdt / len(b_indexes) # share of tokens to correct in the prediction
sdtn_list.append(sdt_norm)

if DEBUG:
print(blocks)
for block in blocks:
a, b, size = block
matched = a_indexes[a : a + size + 1]
print(matched)
print(tokenizer.decode(matched))
matched = b_indexes[b : b + size + 1]
print(matched)
print(tokenizer.decode(matched))

if not a_indexes and not b_indexes:
sdt_list.append(0)
fdt_list.append(0)
sdtn_list.append(0)
fdt_max.append(0)
elif a_indexes and not b_indexes:
sdt_list.append(len(a_indexes))
fdt_list.append(0)
sdtn_list.append(1)
fdt_max.append(len(a_indexes))
elif not a_indexes and b_indexes:
sdt_list.append(len(b_indexes))
fdt_list.append(0)
sdtn_list.append(1)
fdt_max.append(0)
else:
if isinstance(a_indexes, int):
a_indexes = list([a_indexes])
if isinstance(b_indexes, int):
b_indexes = list([b_indexes])
fdt_max.append(len(a_indexes))

matcher = SequenceMatcher(None, a_indexes, b_indexes)
blocks = matcher.get_matching_blocks()
a, b, size = blocks[0]
fdt = 0
if a == 0 and b == 0:
fdt = blocks[0].size
fdt_list.append(fdt)

num_matched = sum(block.size for block in blocks)
sdt = (
len(b_indexes) - num_matched
)
sdt_list.append(sdt)
sdt_norm = sdt / len(b_indexes)
sdtn_list.append(sdt_norm)

if DEBUG:
print(blocks)
for block in blocks:
a, b, size = block
matched = a_indexes[a : a + size + 1]
print(matched)
print(tokenizer.decode(matched))
matched = b_indexes[b : b + size + 1]
print(matched)
print(tokenizer.decode(matched))
fdt_max = np.average(fdt_max)
metric_per_question = {
"FDT": fdt_list,
Expand Down

0 comments on commit e505352

Please sign in to comment.