diff --git a/donut/util.py b/donut/util.py index b5a0bf7c..7f0d4dde 100755 --- a/donut/util.py +++ b/donut/util.py @@ -137,14 +137,13 @@ def __getitem__(self, idx: int) -> Tuple[torch.Tensor, torch.Tensor, torch.Tenso class JSONParseEvaluator: """ - Calculate n-TED(Normalized Tree Edit Distance) based accuracy and F1 accuracy score + Calculate n-TED(Normalized Tree Edit Distance) based accuracy and F1 accuracy score """ @staticmethod def flatten(data: dict): """ Convert Dictionary into Non-nested Dictionary - Example: input(dict) { @@ -153,13 +152,15 @@ def flatten(data: dict): {"name" : ["juice"], "count" : ["1"]}, ] } - output(dict) - { - "menu.name": ["cake", "juice"], - "menu.count": ["2", "1"], - } + output(list) + [ + ("menu.name", "cake"), + ("menu.count", "2"), + ("menu.name", "juice"), + ("menu.count", "1"), + ] """ - flatten_data = defaultdict(list) + flatten_data = list() def _flatten(value, key=""): if type(value) is dict: @@ -169,10 +170,10 @@ def _flatten(value, key=""): for value_item in value: _flatten(value_item, key) else: - flatten_data[key].append(value) + flatten_data.append((key, value)) _flatten(data) - return dict(flatten_data) + return flatten_data @staticmethod def update_cost(label1: str, label2: str): @@ -225,10 +226,11 @@ def normalize_dict(self, data: Union[Dict, List, Any]): elif isinstance(data, list): if all(isinstance(item, dict) for item in data): new_data = [] - for item in sorted(data, key=lambda x: str(sorted(x.items()))): + for item in data: item = self.normalize_dict(item) if item: new_data.append(item) + new_data = sorted(new_data, key=lambda x: str(x.keys())+str(x.values())) else: new_data = sorted([str(item) for item in data if type(item) in {str, int, float} and str(item)]) else: @@ -243,14 +245,14 @@ def cal_f1(self, preds: List[dict], answers: List[dict]): total_tp, total_fn_or_fp = 0, 0 for pred, answer in zip(preds, answers): pred, answer = self.flatten(self.normalize_dict(pred)), self.flatten(self.normalize_dict(answer)) - for pred_key, pred_values in pred.items(): - for pred_value in pred_values: - if pred_key in answer and pred_value in answer[pred_key]: - answer[pred_key].remove(pred_value) - total_tp += 1 - else: - total_fn_or_fp += 1 - return total_tp / (total_tp + (total_fn_or_fp) / 2) + for field in pred: + if field in answer: + total_tp += 1 + answer.remove(field) + else: + total_fn_or_fp += 1 + total_fn_or_fp += len(answer) + return total_tp / (total_tp + total_fn_or_fp / 2) def construct_tree_from_dict(self, data: Union[Dict, List], node_name: str = None): """