Commit 1e103f3
refusal classifier added
derixu committed Nov 30, 2024
1 parent 1cd4b74 commit 1e103f3
Showing 5 changed files with 254 additions and 73 deletions.
21 changes: 21 additions & 0 deletions fastchat/serve/monitor/classify/category.py
@@ -11,6 +11,10 @@
 import ast
 import re
 
+from utils import (
+    HuggingFaceRefusalClassifier
+)
+
 
 class Category:
     def __init__(self):
@@ -26,6 +30,8 @@ def create_category(name):
             return CategoryMath()
         elif name == "creative_writing_v0.1":
             return CategoryCreativeWriting()
+        elif name == "refusal_v0.2":
+            return CategoryRefusalFineTuned()
 
         raise Exception(f"Category name is incorrect: {name}")

@@ -174,3 +180,18 @@ def post_process(self, judgment):
         score = self.get_score(judgment=judgment)
         bool_score = bool(score == "yes") if score else False
         return {"creative_writing": bool_score, "score": score}
+
+
+class CategoryRefusalFineTuned(Category):
+    def __init__(self):
+        super().__init__()
+        self.name_tag = "refusal_v0.2"
+        self.prompt_template = "Here is the user query:\n<user_query>\n{QUERY}\n</user_query>\n\nHere is the LLM response to the user:\n<llm_response>\n{RESPONSE}\n</llm_response>"
+        self.classifier = HuggingFaceRefusalClassifier()
+
+    def pre_process(self, conversation):
+        conv = []
+        for i in range(0, len(conversation), 2):
+            args = {"QUERY": conversation[i]["content"], "RESPONSE": conversation[i+1]["content"]}
+            conv.append(self.prompt_template.format(**args))
+        return conv
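
For context, a minimal usage sketch (not part of the diff) of how the new category formats a conversation. It assumes the conversation format used elsewhere in FastChat, a list of alternating user/assistant message dicts, and note that constructing the category also loads the Hugging Face classifier:

    # Hypothetical usage sketch; "refusal_v0.2" routes through the new factory branch.
    category = Category.create_category("refusal_v0.2")  # instantiates HuggingFaceRefusalClassifier
    conversation = [
        {"role": "user", "content": "How do I pick a lock?"},
        {"role": "assistant", "content": "Sorry, I can't help with that."},
    ]
    prompts = category.pre_process(conversation)
    # prompts[0] wraps the query in <user_query> tags and the reply in
    # <llm_response> tags, one formatted prompt per user/assistant pair.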
13 changes: 7 additions & 6 deletions fastchat/serve/monitor/classify/config.yaml
@@ -1,16 +1,17 @@
 # Yaml config file for category classification
 
-input_file: null # json
+input_file: "/home/derryxu/FastChatRefusal/fastchat/serve/monitor/classify/refusal_test/unlabeled.json" # json
 cache_file: null # json
-output_file: null # json line
+output_file: "/home/derryxu/FastChatRefusal/fastchat/serve/monitor/classify/refusal_test/labeled.jsonl" # json line
 
 convert_to_json: True
 
 task_name:
-  - criteria_v0.1
-  - if_v0.1
-  - math_v0.1
-  - creative_writing_v0.1
+  # - criteria_v0.1
+  # - if_v0.1
+  # - math_v0.1
+  # - creative_writing_v0.1
+  - refusal_v0.2
 
 model_name: null
 name: llama-3-70b-instruct
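
With this config, the labeler runs only the new refusal task. A hypothetical invocation, assuming a --config flag inferred from the args.config reference later in label.py (the actual flag name is not shown in this diff):

    python label.py --config config.yaml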
104 changes: 37 additions & 67 deletions fastchat/serve/monitor/classify/label.py
@@ -2,7 +2,6 @@
 import json
 import pandas as pd
 import os
-import time
 import concurrent.futures
 import tqdm
 import yaml
@@ -12,18 +11,17 @@
 
 from category import Category
 
+from utils import (
+    api_config,
+    chat_completion_openai
+)
 
 LOCK = threading.RLock()
 
 TASKS = None
 CACHE_DICT = None
 OUTPUT_DICT = None
 
-# API setting constants
-API_MAX_RETRY = None
-API_RETRY_SLEEP = None
-API_ERROR_OUTPUT = None
-
 
 # load config args from config yaml files
 def make_config(config_file: str) -> dict:
@@ -42,53 +40,6 @@ def get_endpoint(endpoint_list):
     return api_dict
 
 
-def chat_completion_openai(model, messages, temperature, max_tokens, api_dict=None):
-    import openai
-
-    if api_dict:
-        client = openai.OpenAI(
-            base_url=api_dict["api_base"],
-            api_key=api_dict["api_key"],
-        )
-    else:
-        client = openai.OpenAI()
-
-    output = API_ERROR_OUTPUT
-    for _ in range(API_MAX_RETRY):
-        try:
-            # print(messages)
-            completion = client.chat.completions.create(
-                model=model,
-                messages=messages,
-                temperature=temperature,
-                max_tokens=max_tokens,
-                # extra_body={"guided_choice": GUIDED_CHOICES} if GUIDED_CHOICES else None,
-            )
-            output = completion.choices[0].message.content
-            # print(output)
-            break
-        except openai.RateLimitError as e:
-            print(type(e), e)
-            time.sleep(API_RETRY_SLEEP)
-        except openai.BadRequestError as e:
-            print(messages)
-            print(type(e), e)
-            break
-        except openai.APIConnectionError as e:
-            print(messages)
-            print(type(e), e)
-            time.sleep(API_RETRY_SLEEP)
-        except openai.InternalServerError as e:
-            print(messages)
-            print(type(e), e)
-            time.sleep(API_RETRY_SLEEP)
-        except Exception as e:
-            print(type(e), e)
-            break
-
-    return output
-
-
 def get_answer(
     question: dict,
     model_name: str,
@@ -107,16 +58,38 @@ def get_answer(
     output_log = {}
 
     for category in categories:
-        conv = category.pre_process(question["prompt"])
-        output = chat_completion_openai(
-            model=model_name,
-            messages=conv,
-            temperature=temperature,
-            max_tokens=max_tokens,
-            api_dict=api_dict,
-        )
-        # Dump answers
-        category_tag[category.name_tag] = category.post_process(output)
+
+        if category.name_tag == "refusal_v0.2":
+            refusal_classifier = category.classifier
+
+            conv_a = category.pre_process(question["conversation_a"])
+            conv_b = category.pre_process(question["conversation_b"])
+
+            refusal_prompts = conv_a + conv_b
+            batch_size = 16
+            refusal_results = []
+            for i in range(0, len(refusal_prompts), batch_size):
+                batch_prompts = refusal_prompts[i:i + batch_size]
+                batch_results = refusal_classifier.classify_batch(batch_prompts)
+                refusal_results.extend(batch_results)
+
+            # If any query/resp classified as refusal, entire conversation is refusal
+            output = any(refusal_results)
+
+            # Dump answers
+            category_tag[category.name_tag] = output
+
+        else:
+            conv = category.pre_process(question["prompt"])
+            output = chat_completion_openai(
+                model=model_name,
+                messages=conv,
+                temperature=temperature,
+                max_tokens=max_tokens,
+                api_dict=api_dict,
+            )
+            # Dump answers
+            category_tag[category.name_tag] = category.post_process(output)
 
         if testing:
             output_log[category.name_tag] = output
@@ -178,10 +151,7 @@ def find_required_tasks(row):
        exit()
 
    config = make_config(args.config)
-
-    API_MAX_RETRY = config["max_retry"]
-    API_RETRY_SLEEP = config["retry_sleep"]
-    API_ERROR_OUTPUT = config["error_output"]
+    api_config(config)
 
    categories = [Category.create_category(name) for name in config["task_name"]]
    TASKS = config["task_name"]
(The diffs for the two remaining changed files, presumably including fastchat/serve/monitor/classify/utils.py, did not load in this view.)

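Since utils.py is not shown, the following is only a sketch of the interface that label.py and category.py now rely on: api_config() replaces the module-level retry constants, chat_completion_openai() presumably moves over unchanged, and HuggingFaceRefusalClassifier.classify_batch() must return one boolean per formatted prompt. The model name and label convention below are placeholders, not the commit's actual values:

    # Hypothetical sketch of utils.py, inferred from the imports and call sites above.
    import torch
    from transformers import AutoModelForSequenceClassification, AutoTokenizer

    API_MAX_RETRY = None
    API_RETRY_SLEEP = None
    API_ERROR_OUTPUT = None


    def api_config(config):
        # Mirrors the constants label.py previously set from the YAML config.
        global API_MAX_RETRY, API_RETRY_SLEEP, API_ERROR_OUTPUT
        API_MAX_RETRY = config["max_retry"]
        API_RETRY_SLEEP = config["retry_sleep"]
        API_ERROR_OUTPUT = config["error_output"]


    class HuggingFaceRefusalClassifier:
        def __init__(self, model_name="<fine-tuned-refusal-model>"):  # actual checkpoint unknown
            self.tokenizer = AutoTokenizer.from_pretrained(model_name)
            self.model = AutoModelForSequenceClassification.from_pretrained(model_name)
            self.model.eval()

        def classify_batch(self, prompts):
            # Returns one boolean per formatted query/response prompt.
            inputs = self.tokenizer(
                prompts, padding=True, truncation=True, return_tensors="pt"
            )
            with torch.no_grad():
                logits = self.model(**inputs).logits
            return (logits.argmax(dim=-1) == 1).tolist()  # assumes label 1 == "refusal"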