From c468243c6a95464a89a500128f3676ab7ea70db9 Mon Sep 17 00:00:00 2001 From: Haonan Zhang Date: Mon, 25 Nov 2024 15:39:35 +0800 Subject: [PATCH] Update multi_round.py --- mmevol/mmevol_sft_data/multi_round.py | 47 +++++++++++---------------- 1 file changed, 19 insertions(+), 28 deletions(-) diff --git a/mmevol/mmevol_sft_data/multi_round.py b/mmevol/mmevol_sft_data/multi_round.py index 13555f0b..bde792a4 100644 --- a/mmevol/mmevol_sft_data/multi_round.py +++ b/mmevol/mmevol_sft_data/multi_round.py @@ -126,17 +126,8 @@ def __init__(self, print('Unknown API Base. ') sys.exit(-1) - self.api_base="http://47.88.8.18:8088/api/ask" - # self.api_base = "http://47.88.8.18:8088/api/ask?tenant=gpt-4o-mini" - - # m6 - # self.key = "eyJ0eXAiOiJqd3QiLCJhbGciOiJIUzI1NiJ9.eyJ1c2VybmFtZSI6ImZhbnpoaWhhby5memhAYWxpYmFiYS1pbmMuY29tIiwicGFzc3dvcmQiOiIwOGU3NTk1ZjgyYTk4ZGY2NDRjMmI0NDM4NzM1Y2Y4Y2U0NDBmMWNjIiwiZXhwIjoyMDA5NjA5NTM4fQ.CmIJOx7fvERV2PP7eQ3sZVLhtO1aRB2B5DU7BIETVC8" - - # coai - # self.key = "eyJ0eXAiOiJqd3QiLCJhbGciOiJIUzI1NiJ9.eyJ1c2VybmFtZSI6IjI1ODczMCIsInBhc3N3b3JkIjoiMjU4NzMwMTIzIiwiZXhwIjoyMDE5NTUwNzAxfQ.JuqnTa7yauGkSzWkBiEig1K_rxvfAYTXS9F9_m-h4q8" - - # norm - self.key = "eyJhbGciOiJIUzI1NiIsInR5cCI6Imp3dCJ9.eyJ1c2VybmFtZSI6IjQ0MzQ1NSIsInBhc3N3b3JkIjoiNDQzNDU1MTIzIiwiZXhwIjoyMDMxNzA1NTA3fQ.7g4a6t9dKcRXVRa7MwQb5m2oirFu1OxjXhWbNM0w50s" + self.api_base = "" + self.key = "" # self.model="gpt-4o-2024-08-06" self.model = "gpt-4o-mini" @@ -489,23 +480,23 @@ def filter_round3(meta_data, conversation_v3_path): gen_save_path = osp.join(round_path, "gen_qa/{}.json") raw_save_path = osp.join(round_path, "evo_path/{}.json") - # patience = 0 - # while True: + patience = 0 + while True: - # # evol - # num_messages = evolution_parallel(seed_data_path, img_path, gen_save_path, raw_save_path=raw_save_path, round_n=round_n, root_path=root_path) - - # # post-process - # data_process.func_4_qa( - # path = round_path, - # data_path = osp.join(round_path, "gen_qa"), - # data_path_corrected = osp.join(round_path, "gen_qa_corrected"), - # round_n = round_n - # ) - # patience += 1 - # if len(num_messages) < 50 or patience >= 5: - # print("Round: {} QA Evo Finished".format(round_n)) - # break + # evol + num_messages = evolution_parallel(seed_data_path, img_path, gen_save_path, raw_save_path=raw_save_path, round_n=round_n, root_path=root_path) + + # post-process + data_process.func_4_qa( + path = round_path, + data_path = osp.join(round_path, "gen_qa"), + data_path_corrected = osp.join(round_path, "gen_qa_corrected"), + round_n = round_n + ) + patience += 1 + if len(num_messages) < 50 or patience >= 5: + print("Round: {} QA Evo Finished".format(round_n)) + break patience = 0 while True: @@ -543,4 +534,4 @@ def filter_round3(meta_data, conversation_v3_path): merged_data.append(data) json.dump(merged_data, open(final_save_path, "w"), indent=4) - print("Saveing file to {}".format(final_save_path)) \ No newline at end of file + print("Saveing file to {}".format(final_save_path))