diff --git a/fastchat/serve/monitor/classify/category.py b/fastchat/serve/monitor/classify/category.py
index b60e964ba..76287b0d3 100644
--- a/fastchat/serve/monitor/classify/category.py
+++ b/fastchat/serve/monitor/classify/category.py
@@ -22,27 +22,22 @@ from utils import HuggingFaceClassifier, chat_completion_openai
 
-class Category:
-    def __init__(self):
-        pass
-
-    @staticmethod
-    def create_category(name):
-        if name == "criteria_v0.1":
-            return CategoryHardPrompt()
-        elif name == "if_v0.1":
-            return CategoryIF()
-        elif name == "math_v0.1":
-            return CategoryMath()
-        elif name == "creative_writing_v0.1":
-            return CategoryCreativeWriting()
-        elif name == "refusal_v0.2":
-            return CategoryRefusalHF()
-
-        raise Exception(f"Category name is incorrect: {name}")
-
-
-class CategoryAPI(Category):
+def create_category(name):
+    if name == "criteria_v0.1":
+        return CategoryHardPrompt()
+    elif name == "if_v0.1":
+        return CategoryIF()
+    elif name == "math_v0.1":
+        return CategoryMath()
+    elif name == "creative_writing_v0.1":
+        return CategoryCreativeWriting()
+    elif name == "refusal_v0.2":
+        return CategoryRefusalHF()
+
+    raise Exception(f"Category name is incorrect: {name}")
+
+
+class CategoryAPI:
     def __init__(self):
         self.batch_size = 1
         self.is_parallel = True
@@ -96,10 +91,10 @@ def post_process(self, judgement, uid):
         pass
 
 
-class CategoryHF(Category):
+class CategoryHF:
     def __init__(self):
         self.batch_size = 1
-        self.is_paraellel = False
+        self.is_parallel = False
 
     def get_answer(self, batch, model_name, max_tokens, temperature, api_dict):
         to_label, to_label_uids = self.pre_process(batch)
@@ -337,11 +332,11 @@ def pre_process(self, batch):
         return to_label, to_label_uids
 
     def post_process(self, labels, to_label_uids):
-        outputs = defaultdict(lambda: {"label": False})
+        outputs = defaultdict(lambda: {"refusal": False})
 
         query_refusals = np.where(labels)[0]
 
        for i in query_refusals:
-            outputs[to_label_uids[i]] = {"label": True}
+            outputs[to_label_uids[i]] = {"refusal": True}
 
         return outputs, defaultdict(
             lambda: None
diff --git a/fastchat/serve/monitor/classify/label.py b/fastchat/serve/monitor/classify/label.py
index a996d943c..89dcca6d7 100644
--- a/fastchat/serve/monitor/classify/label.py
+++ b/fastchat/serve/monitor/classify/label.py
@@ -13,22 +13,50 @@
 import orjson
 from collections import defaultdict
 
-from category import Category
+from category import create_category
 from utils import api_config
 
-# Tracks which category tasks have been completed for each row/battle - key: uid, value: set of category tags
-TASKS_REQUIRED = defaultdict(lambda: set())
-
-# Tracks in progress labels - key: uid, value: current incomplete category labels
-TASK_TRACKER = defaultdict(lambda: {})
-
-# Tracks in progress raw outputs for debugging - key: uid, value: current incomplete raw outputs
-LOGS_TRACKER = defaultdict(lambda: {})
-
 LOCK = threading.RLock()
 
 TASKS = None
+
+"""
+CACHE_DICT (dict): Cached labels
+- uid (str): UID for the battle that has been cached
+    - category_tag
+        - criteria_v0.1
+            - specificity
+            - ...
+        - math_v0.1
+            - math
+        - if_v0.1
+            - if
+            - score
+        - creative_writing_v0.1
+            - creative_writing
+            - score
+        - refusal_v0.2
+            - refusal
+"""
 CACHE_DICT = None
+
+"""
+OUTPUT_DICT (dict): Previously outputted labels
+- uid (str): UID for the battle that has already been labeled
+    - criteria_v0.1
+        - specificity
+        - ...
+    - math_v0.1
+        - math
+    - if_v0.1
+        - if
+        - score
+    - creative_writing_v0.1
+        - creative_writing
+        - score
+    - refusal_v0.2
+        - refusal
+"""
 OUTPUT_DICT = None
@@ -63,8 +91,6 @@ def get_answer(
     for _, row in batch.iterrows():
         uid = row["uid"]
         uid_to_row[uid] = row
-        if "category_tag" in row:
-            TASK_TRACKER[uid].update(row["category_tag"])
 
     outputs, raw_outputs = category.get_answer(
         batch, model_name, max_tokens, temperature, api_dict
@@ -72,33 +98,40 @@
     for uid in uid_to_row:
         output = outputs[uid]
-        TASKS_REQUIRED[uid]["required_tasks"].remove(category.name_tag)
-        TASK_TRACKER[uid][category.name_tag] = output
+        line = {"uid": uid, "category_tag": {category.name_tag: output}}
 
         if testing:
             raw_output = raw_outputs[uid]
-            LOGS_TRACKER[uid][category.name_tag] = raw_output
+            line["raw_output"] = raw_output
+
+        with LOCK:
+            with open(answer_file, "a") as fout:
+                fout.write(json.dumps(line) + "\n")
+
+
+def category_merge_helper(series):
+    """
+    Given a series of dictionaries of category labels for a single battle, merge them into one dict
 
-        row = uid_to_row[uid]
-        if not TASKS_REQUIRED[uid][
-            "required_tasks"
-        ]:  # Check if all required tasks completed
-            row["category_tag"] = TASK_TRACKER[uid]
+    Args:
+        series (pd.Series[Dict[str, Dict]]): series of dictionaries of category labels
 
-            if testing:
-                row["output_log"] = LOGS_TRACKER[uid]
+    Returns:
+        category_label (Dict[str, Dict]): Dictionary of all labeled categories for one battle
+    """
+    merged = {}
+    for dct in series:
+        merged.update(dct)
 
-            row.drop(["prompt", "uid", "required_tasks"], inplace=True)
-            with LOCK:
-                with open(answer_file, "a") as fout:
-                    fout.write(json.dumps(row.to_dict()) + "\n")
+    # Pandas automatically turns top-level keys into an index (not good),
+    # so we create a dummy key which we remove later
+    return {"dummy": merged}
 
 
 def category_merge(row):
     id = row["uid"]
     input_category = row["category_tag"] if "category_tag" in row else {}
     cache_category = CACHE_DICT[id]["category_tag"] if id in CACHE_DICT else {}
-    output_category = OUTPUT_DICT[id]["category_tag"] if id in OUTPUT_DICT else {}
+    output_category = OUTPUT_DICT[id] if id in OUTPUT_DICT else {}
 
     # tries to fill in missing categories using cache first, then output
     for name in TASKS:
@@ -116,7 +149,7 @@ def find_required_tasks(row):
     id = row["uid"]
     input_category = row["category_tag"] if "category_tag" in row else {}
     cache_category = CACHE_DICT[id]["category_tag"] if id in CACHE_DICT else {}
-    output_category = OUTPUT_DICT[id]["category_tag"] if id in OUTPUT_DICT else {}
+    output_category = OUTPUT_DICT[id] if id in OUTPUT_DICT else {}
 
     return set(
         [
@@ -147,7 +180,7 @@ def find_required_tasks(row):
     api_config(config)
 
     # Divide categories into parallelized + non-parallel. Non-parallel for HF models - automatically parallelized
-    categories = [Category.create_category(name) for name in config["task_name"]]
+    categories = [create_category(name) for name in config["task_name"]]
     parallel_categories = [category for category in categories if category.is_parallel]
     not_parallel_categories = [
         category for category in categories if not category.is_parallel
@@ -182,27 +215,24 @@
         cache_dict = cache_data[["uid", "category_tag"]].set_index("uid")
         print("finalizing cache_dict (should take less than 30 sec)")
         CACHE_DICT = cache_dict.to_dict("index")
-        TASK_TRACKER.update(CACHE_DICT)
     else:
         CACHE_DICT = {}
 
     if os.path.isfile(config["output_file"]):
         print("loading existing output")
         output_data = pd.read_json(config["output_file"], lines=True)
-        output_data["uid"] = output_data.question_id.map(str) + output_data.tstamp.map(
-            str
-        )
-        assert len(output_data) == len(output_data.uid.unique())
-        print(f"{len(output_data)}# of existing output just loaded")
 
         assert "category_tag" in output_data.columns
-        output_dict = output_data[["uid", "category_tag"]].set_index("uid")
+        assert "uid" in output_data.columns
+        print("finalizing output_dict (should take less than 30 sec)")
 
-        OUTPUT_DICT = output_dict.to_dict("index")
-        TASK_TRACKER.update(
-            OUTPUT_DICT
-        )  # note: this will override rows from cache dict if uids overlap
+        OUTPUT_DICT = (
+            output_data.groupby("uid")["category_tag"]
+            .apply(category_merge_helper)
+            .reset_index(level=1, drop=True)  # get rid of dummy key/index
+            .to_dict()
+        )
     else:
         OUTPUT_DICT = {}
@@ -211,12 +241,6 @@
     )
     input_data["required_tasks"] = input_data.apply(find_required_tasks, axis=1)
 
-    # Update task completion tracker with already completed tasks
-    required_tasks_dict = (
-        input_data[["uid", "required_tasks"]].set_index("uid").to_dict("index")
-    )
-    TASKS_REQUIRED.update(required_tasks_dict)
-
     not_labeled = input_data[input_data.required_tasks.map(lambda x: len(x) > 0)].copy()
     print(f"{len(not_labeled)} # of conversations needs to be labeled")
@@ -294,13 +318,17 @@
         assert os.path.isfile(config["output_file"])
         print("reading output file...")
         temp = pd.read_json(config["output_file"], lines=True)
-        temp["uid"] = temp.question_id.map(str) + temp.tstamp.map(str)
-        assert len(temp) == len(temp.uid.unique())
 
         assert "category_tag" in temp.columns
-        output_dict = temp[["uid", "category_tag"]].set_index("uid")
+        assert "uid" in temp.columns
+        print("finalizing output_dict (should take less than 30 sec)")
 
-        OUTPUT_DICT = output_dict.to_dict("index")
+        OUTPUT_DICT = (
+            temp.groupby("uid")["category_tag"]
+            .apply(category_merge_helper)
+            .reset_index(level=1, drop=True)  # get rid of dummy key/index
+            .to_dict()
+        )
 
         print("begin merging (should take around 1 minute or less on large dataset)")
         input_data["category_tag"] = input_data.apply(category_merge, axis=1)
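
The merge logic above is the subtle part of this change: get_answer() now appends one JSON line per (uid, category) to answer_file, e.g. {"uid": "u1", "category_tag": {"math_v0.1": {"math": true}}}, and the output file is reassembled per uid at load time. category_merge_helper wraps its merged dict in a throwaway "dummy" key because, when a function passed to groupby(...).apply returns a plain dict, pandas unpacks the dict's top-level keys into an extra index level instead of keeping the dict as a single value; reset_index(level=1, drop=True) then strips that level back out. A minimal, self-contained sketch of the pattern (the helper and the groupby/reset_index chain mirror the patch; the uids and category tags are made-up toy data):

    import pandas as pd

    # Stand-in for the per-category JSONL rows that get_answer() now appends:
    # one row per (uid, category) rather than one merged row per battle.
    output_data = pd.DataFrame(
        {
            "uid": ["u1", "u1", "u2"],
            "category_tag": [
                {"math_v0.1": {"math": True}},
                {"if_v0.1": {"if": False, "score": 2}},
                {"refusal_v0.2": {"refusal": False}},
            ],
        }
    )

    def category_merge_helper(series):
        merged = {}
        for dct in series:
            merged.update(dct)
        # Returning `merged` directly would make apply() explode its top-level
        # keys into a second index level; the dummy key keeps the dict whole.
        return {"dummy": merged}

    OUTPUT_DICT = (
        output_data.groupby("uid")["category_tag"]
        .apply(category_merge_helper)
        .reset_index(level=1, drop=True)  # drop the dummy index level
        .to_dict()
    )

    assert OUTPUT_DICT == {
        "u1": {"math_v0.1": {"math": True}, "if_v0.1": {"if": False, "score": 2}},
        "u2": {"refusal_v0.2": {"refusal": False}},
    }

Because merged.update(dct) lets later rows win, a uid that was labeled twice for the same category keeps the label from the later line in the output file, i.e. the most recent run.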