forked from All-Hands-AI/OpenHands
-
Notifications
You must be signed in to change notification settings - Fork 0
/
ast_eval_tf.py
127 lines (115 loc) · 4.35 KB
/
ast_eval_tf.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
# Copyright 2023 https://github.com/ShishirPatil/gorilla
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.
# This file is modified from https://github.com/ShishirPatil/gorilla/blob/main/eval/eval-scripts/ast_eval_tf.py
import functools

from tree_sitter import Language, Parser
# Get all the subtrees given a root_node
def get_all_sub_trees(root_node):
    """Walk the tree under ``root_node`` depth-first and describe each visited node.

    The walk uses an explicit stack. Children are pushed in order and popped
    from the end, so later siblings are visited before earlier ones. Children
    that themselves have no children are never pushed. The root is therefore
    the only node that can be visited while having no children.

    Args:
        root_node: tree node exposing ``child_count``, ``children``, ``sexp()``
            and ``text``.

    Returns:
        A list with one entry per visited node, in visit order. Each entry is a
        list ``[sexp, depth, node, first_child_text]``, where ``depth`` is 1 for
        the root and ``first_child_text`` is ``None`` when the node has no
        children.
    """
    collected = []
    pending = [(root_node, 1)]
    while pending:
        node, level = pending.pop()
        entry = [node.sexp(), level, node, None]
        if node.child_count > 0:
            # Remember the first child's text; ast_check matches API names on it.
            entry[3] = node.children[0].text
        collected.append(entry)
        # Only descend into children that have children of their own.
        pending.extend(
            (child, level + 1) for child in node.children if len(child.children) != 0
        )
    return collected
# Load a tree-sitter grammar once per (library path, grammar name) pair.
@functools.lru_cache(maxsize=None)
def _load_language(library_path, lang):
    """Return the ``Language`` for grammar ``lang`` from ``library_path``, cached.

    Caching avoids re-loading the compiled grammar library on every
    ``ast_parse`` call. Only a handful of (path, lang) pairs are expected, so
    the cache is left unbounded. A failed load raises and is not cached.
    """
    return Language(library_path, lang)


# Parse the program into AST trees
def ast_parse(candidate, lang='python', library_path='evaluation/gorilla/my-languages.so'):
    """Parse ``candidate`` source text and return the root node of its syntax tree.

    Args:
        candidate: source text to parse. It is encoded as UTF-8 before parsing.
        lang: grammar name to load from the compiled grammar library.
        library_path: path to the compiled tree-sitter grammar library. The
            default is a relative path. It presumably only resolves when the
            process runs from the repository root, so pass an explicit path
            otherwise.

    Returns:
        The ``root_node`` of the parsed tree.
    """
    # A fresh Parser per call keeps parser state unshared between calls. Only
    # the (immutable) grammar object is reused.
    parser = Parser()
    parser.set_language(_load_language(library_path, lang))
    return parser.parse(bytes(candidate, 'utf8')).root_node
# Get all the arguments in the ast tree
def get_args(node):
    """Collect the argument texts of the call at the head of a reference tree.

    The argument nodes are taken from
    ``node.children[0].children[0].children[1].children``. For tree-sitter
    Python this is presumably the argument list of the first statement's call
    expression (TODO confirm against the grammar in use).

    For each argument node:
      * if its text contains ``'model='`` or ``'model ='`` (substring match, so
        e.g. ``pretrained_model=...`` also qualifies), its third child's text is
        kept. This is presumably the keyword argument's value.
      * otherwise, unless the text is exactly ``'('``, ``')'`` or ``','``, the
        node's own text is kept.

    Args:
        node: root of a reference syntax tree.

    Returns:
        A list of ``bytes`` argument texts. It is empty when ``node`` has no
        children.
    """
    if node.child_count == 0:
        return []
    call_node = node.children[0].children[0]
    argument_nodes = call_node.children[1].children
    collected = []
    for arg_node in argument_nodes:
        source = arg_node.text.decode()
        if 'model=' in source or 'model =' in source:
            collected.append(arg_node.children[2].text)
        elif source not in ('(', ')', ','):
            # Skip the argument list's punctuation tokens.
            collected.append(arg_node.text)
    return collected
# Check if there is an api match
def ast_check(candidate_subtree_list, base_tree_list):
    """Return the index of the first reference tree that the candidate matches.

    Args:
        candidate_subtree_list: entries shaped like
            ``[sexp, depth, node, first_child_text]``, as built by
            ``get_all_sub_trees`` for the candidate program.
        base_tree_list: reference (database) tree roots, scanned in order.

    Returns:
        The index into ``base_tree_list`` of the first reference whose API name
        selects a candidate subtree, provided every argument extracted from
        the reference (quotes stripped) occurs as a substring of that subtree's
        text. Returns ``-1`` when no reference matches.
    """
    for idx, base_tree in enumerate(base_tree_list):
        # Skip references whose first grandchild has no children: there is no
        # API name to read from it.
        if base_tree.children[0].children[0].child_count == 0:
            continue
        # API name = text of that grandchild's first child (presumably the
        # callee of the first call expression -- TODO confirm).
        api_name = base_tree.children[0].children[0].children[0].text
        for candidate_tree in candidate_subtree_list:
            if candidate_tree[3] == api_name:
                break
        # NOTE(review): this loop has no `else`. When no candidate entry's
        # first-child text equals api_name, candidate_tree is left as the LAST
        # entry of candidate_subtree_list, and the argument check below runs
        # against that unrelated subtree. This is presumably inherited from the
        # upstream Gorilla script this file was adapted from. Confirm before
        # changing it, since it affects reported scores.
        # Now we have a sub-tree
        candidate_tree = candidate_tree[2]
        args_list = get_args(base_tree)
        # A reference with no extractable arguments is never considered a match.
        if len(args_list) == 0:
            continue
        ast_match = True
        for arg in args_list:
            # Plain substring test: the reference argument, with surrounding
            # single quotes stripped, must appear in the candidate subtree text.
            if arg.decode().lstrip("'").rstrip("'") not in candidate_tree.text.decode():
                ast_match = False
                break
        if ast_match:
            return idx
    return -1
def _extract_api_call(response):
    """Return the api_call snippet embedded in a raw model ``response`` string.

    If the text contains no ``'api_call'`` marker, the whole response is
    returned. Otherwise the text between the first and any second
    ``'api_call'`` marker is cut at its first ``'api_provider'``. The returned
    slice starts two characters after the first ``':'``, or at index 2 if there
    is none. It ends at the last ``')'`` inclusive; if there is no ``')'``, the
    final character is dropped instead.
    """
    parts = response.split('api_call')
    if len(parts) == 1:
        return parts[0]
    segment = parts[1].split('api_provider')[0]
    # "+ 2" skips the ':' and, presumably, one following space.
    start = segment.index(':') if ':' in segment else 0
    end = segment.rindex(')') if ')' in segment else -2
    return segment[start + 2 : end + 1]


def ast_eval_tf(api_database, qa_pairs, ast_database, question_id, response):
    """Score one model response against the reference API database.

    Args:
        api_database: reference API records, indexed by the position that
            ``ast_check`` reports in ``ast_database``. Each matched record is
            read for its ``'domain'`` key.
        qa_pairs: question records, looked up at ``question_id - 1``. Each is
            read for its ``'domain'`` key.
        ast_database: reference syntax trees searched by ``ast_check``.
        question_id: id of the question being answered (used 1-based).
        response: raw model output text.

    Returns:
        ``(correct, hallucination)``. ``hallucination`` is True when the parsed
        API call matches no entry of ``ast_database``; ``correct`` is then
        always False. Otherwise ``correct`` is True when the matched
        reference's domain equals the question's domain.
    """
    api_call = _extract_api_call(response)
    # Parse the api_call into an AST and enumerate its subtrees.
    ast_tree = ast_parse(api_call)
    ast_subtree_list = get_all_sub_trees(ast_tree)
    # Check which reference AST matches.
    database_index = ast_check(ast_subtree_list, ast_database)
    if database_index == -1:
        # No reference matched: this is a hallucination and must not be scored.
        # Falling through would index api_database[-1] (the last record) and
        # could mark a hallucinated call as correct.
        return False, True
    ref_api_call = api_database[database_index]
    # Check for functionality.
    correct = False
    if ref_api_call['domain'] == qa_pairs[question_id - 1]['domain']:
        correct = True
    return correct, False