improve checkpointing, fix bugs
derixu committed Dec 27, 2024
1 parent bbcef89 commit 5f78a56
Showing 2 changed files with 98 additions and 75 deletions.
45 changes: 20 additions & 25 deletions fastchat/serve/monitor/classify/category.py
@@ -22,27 +22,22 @@
from utils import HuggingFaceClassifier, chat_completion_openai


class Category:
def __init__(self):
pass

@staticmethod
def create_category(name):
if name == "criteria_v0.1":
return CategoryHardPrompt()
elif name == "if_v0.1":
return CategoryIF()
elif name == "math_v0.1":
return CategoryMath()
elif name == "creative_writing_v0.1":
return CategoryCreativeWriting()
elif name == "refusal_v0.2":
return CategoryRefusalHF()

raise Exception(f"Category name is incorrect: {name}")


class CategoryAPI(Category):
def create_category(name):
if name == "criteria_v0.1":
return CategoryHardPrompt()
elif name == "if_v0.1":
return CategoryIF()
elif name == "math_v0.1":
return CategoryMath()
elif name == "creative_writing_v0.1":
return CategoryCreativeWriting()
elif name == "refusal_v0.2":
return CategoryRefusalHF()

raise Exception(f"Category name is incorrect: {name}")


class CategoryAPI:
def __init__(self):
self.batch_size = 1
self.is_parallel = True
@@ -96,10 +91,10 @@ def post_process(self, judgement, uid):
pass


class CategoryHF(Category):
class CategoryHF:
def __init__(self):
self.batch_size = 1
self.is_paraellel = False
self.is_parallel = False

def get_answer(self, batch, model_name, max_tokens, temperature, api_dict):
to_label, to_label_uids = self.pre_process(batch)
@@ -337,11 +332,11 @@ def pre_process(self, batch):
return to_label, to_label_uids

def post_process(self, labels, to_label_uids):
outputs = defaultdict(lambda: {"label": False})
outputs = defaultdict(lambda: {"refusal": False})
query_refusals = np.where(labels)[0]

for i in query_refusals:
outputs[to_label_uids[i]] = {"label": True}
outputs[to_label_uids[i]] = {"refusal": True}

return outputs, defaultdict(
lambda: None
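
For orientation, a minimal usage sketch of the refactored factory (illustrative only, not part of this commit). It assumes nothing beyond the names visible in the diff above: the module-level create_category(), the is_parallel and batch_size attributes, and the task names the factory accepts.

# Illustrative sketch -- the real call site lives in label.py.
from category import create_category

# Build one handler per configured task name.
categories = [create_category(name) for name in ["math_v0.1", "refusal_v0.2"]]

# API-backed categories can be labeled in parallel; HF-classifier categories
# (e.g. refusal_v0.2) are handled serially in local batches.
parallel_handlers = [c for c in categories if c.is_parallel]
local_handlers = [c for c in categories if not c.is_parallel]
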
128 changes: 78 additions & 50 deletions fastchat/serve/monitor/classify/label.py
@@ -13,22 +13,50 @@
import orjson

from collections import defaultdict
from category import Category
from category import create_category
from utils import api_config

# Tracks which category tasks have been completed for each row/battle - key: uid, value: set of category tags
TASKS_REQUIRED = defaultdict(lambda: set())

# Tracks in progress labels - key: uid, value: current incomplete category labels
TASK_TRACKER = defaultdict(lambda: {})

# Tracks in progress raw outputs for debugging - key: uid, value: current incomplete raw outputs
LOGS_TRACKER = defaultdict(lambda: {})

LOCK = threading.RLock()

TASKS = None

"""
CACHE_DICT (dict): Cached labels
- uid (str): UID for the battle that has been cached
- category_tag
- criteria_v0.1
- specificity
- ...
- math_v0.1
- math
- if_v0.1
- if
- score
- creative_writing_v0.1
- creative_writing
- score
- refusal_v0.2
- refusal
"""
CACHE_DICT = None

"""
OUTPUT_DICT (dict): Previously outputted labels
- uid (str): UID for the battle that has been cached
- criteria_v0.1
- specificity
- ...
- math_v0.1
- math
- if_v0.1
- if
- score
- creative_writing_v0.1
- creative_writing
- score
- refusal_v0.2
- refusal
"""
OUTPUT_DICT = None
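
To make the documented shapes concrete, here are hypothetical entries for each dict (illustrative only; the uid, criterion keys beyond "specificity", and label values are placeholders). Note the asymmetry spelled out in the docstrings: a CACHE_DICT entry keeps the category_tag wrapper, while OUTPUT_DICT maps a uid directly to its per-category labels.

# Hypothetical example entries matching the docstrings above (placeholder values).
example_cache_entry = {
    "category_tag": {
        "criteria_v0.1": {"specificity": True},  # one of several boolean criteria
        "math_v0.1": {"math": False},
    }
}
example_output_entry = {
    "if_v0.1": {"if": True, "score": 4},
    "refusal_v0.2": {"refusal": False},
}
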


@@ -63,42 +91,47 @@ def get_answer(
    for _, row in batch.iterrows():
        uid = row["uid"]
        uid_to_row[uid] = row
        if "category_tag" in row:
            TASK_TRACKER[uid].update(row["category_tag"])

    outputs, raw_outputs = category.get_answer(
        batch, model_name, max_tokens, temperature, api_dict
    )

    for uid in uid_to_row:
        output = outputs[uid]
        TASKS_REQUIRED[uid]["required_tasks"].remove(category.name_tag)
        TASK_TRACKER[uid][category.name_tag] = output
        line = {"uid": uid, "category_tag": {category.name_tag: output}}

        if testing:
            raw_output = raw_outputs[uid]
            LOGS_TRACKER[uid][category.name_tag] = raw_output
            line["raw_output"] = raw_output

        row = uid_to_row[uid]
        if not TASKS_REQUIRED[uid][
            "required_tasks"
        ]:  # Check if all required tasks completed
            row["category_tag"] = TASK_TRACKER[uid]

            if testing:
                row["output_log"] = LOGS_TRACKER[uid]

            row.drop(["prompt", "uid", "required_tasks"], inplace=True)
            with LOCK:
                with open(answer_file, "a") as fout:
                    fout.write(json.dumps(row.to_dict()) + "\n")
        with LOCK:
            with open(answer_file, "a") as fout:
                fout.write(json.dumps(line) + "\n")


def category_merge_helper(series):
    """
    Given a series of dictionaries of category labels for a single battle, merge into one dict

    Args:
        series (pd.Series[Dict[str, Dict]]): series of dictionaries of category labels

    Returns:
        category_label (Dict[str, Dict]): Dictionary of all labeled categories for one battle
    """
    merged = {}
    for dct in series:
        merged.update(dct)

    # Pandas automatically turns top-level keys into index (not good), so we create a dummy key which we remove later
    return {"dummy": merged}
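
A self-contained sketch of how this helper behaves inside the groupby chain used later in the script (made-up uids and labels; assumes only pandas and the helper defined above). Each battle's per-category checkpoint rows collapse into one dict, and the "dummy" wrapper keeps pandas from promoting the merged dict's keys into an extra index level, which reset_index then drops.

import pandas as pd

rows = pd.DataFrame(
    [
        {"uid": "u1", "category_tag": {"math_v0.1": {"math": True}}},
        {"uid": "u1", "category_tag": {"refusal_v0.2": {"refusal": False}}},
    ]
)
merged = (
    rows.groupby("uid")["category_tag"]
    .apply(category_merge_helper)  # each group becomes {"dummy": {merged labels}}
    .reset_index(level=1, drop=True)  # drop the "dummy" index level
    .to_dict()
)
# Expected: {"u1": {"math_v0.1": {"math": True}, "refusal_v0.2": {"refusal": False}}}
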


def category_merge(row):
id = row["uid"]
input_category = row["category_tag"] if "category_tag" in row else {}
cache_category = CACHE_DICT[id]["category_tag"] if id in CACHE_DICT else {}
output_category = OUTPUT_DICT[id]["category_tag"] if id in OUTPUT_DICT else {}
output_category = OUTPUT_DICT[id] if id in OUTPUT_DICT else {}

# tries to fill in missing categories using cache first, then output
for name in TASKS:
@@ -116,7 +149,7 @@ def find_required_tasks(row):
id = row["uid"]
input_category = row["category_tag"] if "category_tag" in row else {}
cache_category = CACHE_DICT[id]["category_tag"] if id in CACHE_DICT else {}
output_category = OUTPUT_DICT[id]["category_tag"] if id in OUTPUT_DICT else {}
output_category = OUTPUT_DICT[id] if id in OUTPUT_DICT else {}

return set(
[
@@ -147,7 +180,7 @@ def find_required_tasks(row):
api_config(config)

# Divide categories into parallelized + non-parallel. Non-parallel for HF models - automatically parallelized
categories = [Category.create_category(name) for name in config["task_name"]]
categories = [create_category(name) for name in config["task_name"]]
parallel_categories = [category for category in categories if category.is_parallel]
not_parallel_categories = [
category for category in categories if not category.is_parallel
@@ -182,27 +215,24 @@ def find_required_tasks(row):
cache_dict = cache_data[["uid", "category_tag"]].set_index("uid")
print("finalizing cache_dict (should take less than 30 sec)")
CACHE_DICT = cache_dict.to_dict("index")
TASK_TRACKER.update(CACHE_DICT)
else:
CACHE_DICT = {}

if os.path.isfile(config["output_file"]):
print("loading existing output")
output_data = pd.read_json(config["output_file"], lines=True)
output_data["uid"] = output_data.question_id.map(str) + output_data.tstamp.map(
str
)
assert len(output_data) == len(output_data.uid.unique())

print(f"{len(output_data)}# of existing output just loaded")

assert "category_tag" in output_data.columns
output_dict = output_data[["uid", "category_tag"]].set_index("uid")
assert "uid" in output_data.columns

print("finalizing output_dict (should take less than 30 sec)")
OUTPUT_DICT = output_dict.to_dict("index")
TASK_TRACKER.update(
OUTPUT_DICT
) # note: this will override rows from cache dict if uids overlap
OUTPUT_DICT = (
output_data.groupby("uid")["category_tag"]
.apply(category_merge_helper)
.reset_index(level=1, drop=True) # get rid of dummy key/index
.to_dict()
)
else:
OUTPUT_DICT = {}
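
To illustrate what the output file now looks like on disk under per-category checkpointing (hypothetical lines; real uids are question_id + tstamp strings): get_answer appends one JSON line per (uid, category), so the same uid can legitimately appear on several lines, which is why loading now groups and merges rows instead of asserting uid uniqueness as before.

import json

# Hypothetical checkpoint lines written by get_answer, one per completed category.
checkpoint_lines = [
    '{"uid": "u1", "category_tag": {"math_v0.1": {"math": true}}}',
    '{"uid": "u1", "category_tag": {"if_v0.1": {"if": false, "score": 1}}}',
]
parsed = [json.loads(line) for line in checkpoint_lines]
# Resuming folds both lines into a single OUTPUT_DICT["u1"] entry via the
# groupby/apply(category_merge_helper) chain above.
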

@@ -211,12 +241,6 @@ def find_required_tasks(row):
)
input_data["required_tasks"] = input_data.apply(find_required_tasks, axis=1)

# Update task completion tracker with already completed tasks
required_tasks_dict = (
input_data[["uid", "required_tasks"]].set_index("uid").to_dict("index")
)
TASKS_REQUIRED.update(required_tasks_dict)

not_labeled = input_data[input_data.required_tasks.map(lambda x: len(x) > 0)].copy()

print(f"{len(not_labeled)} # of conversations needs to be labeled")
@@ -294,13 +318,17 @@ def find_required_tasks(row):
assert os.path.isfile(config["output_file"])
print("reading output file...")
temp = pd.read_json(config["output_file"], lines=True)
temp["uid"] = temp.question_id.map(str) + temp.tstamp.map(str)
assert len(temp) == len(temp.uid.unique())

assert "category_tag" in temp.columns
output_dict = temp[["uid", "category_tag"]].set_index("uid")
assert "uid" in temp.columns

print("finalizing output_dict (should take less than 30 sec)")
OUTPUT_DICT = output_dict.to_dict("index")
OUTPUT_DICT = (
output_data.groupby("uid")["category_tag"]
.apply(category_merge_helper)
.reset_index(level=1, drop=True) # get rid of dummy key/index
.to_dict()
)

print("begin merging (should take around 1 minute or less on large dataset)")
input_data["category_tag"] = input_data.apply(category_merge, axis=1)
