# openai_api_labels.py
#
# Labels the sentiment (Positive/Neutral/Negative) of tokens stored in
# per-direction sample CSVs by sending one OpenAI chat-completion request per
# row, then writes the result to a "labelled_<file>" CSV.
# %%
import pandas as pd
import openai
from tqdm.auto import tqdm
from utils.store import get_csv, to_csv, is_file
# %%
# Script-wide settings.
# When False, inputs whose "labelled_" output already exists are skipped.
OVERWRITE = False
# Model identifier handed to the utils.store helpers (get_csv / to_csv / is_file).
MODEL = "stablelm-base-alpha-3b"
# Show up to 200 characters per cell so long text columns stay readable.
pd.options.display.max_colwidth = 200
# %%
# File-name suffix appended to each direction name to get its sample CSV.
SUFFIX = "_bin_samples.csv"

# Base names of the sample files to label, processed in this order.
# NOTE(review): each prefix presumably names the method used to find the
# direction (k-means, PCA, mean difference, logistic regression, DAS) - confirm.
DIRECTIONS = [
    "kmeans_simple_train_ADJ_layer1",
    "pca_simple_train_ADJ_layer1",
    "mean_diff_simple_train_ADJ_layer1",
    "logistic_regression_simple_train_ADJ_layer1",
    "das_simple_train_ADJ_layer1",
]
# %%
# Read the OpenAI API key from api_key.txt (resolved against the current
# working directory).
with open("api_key.txt", "r", encoding="utf-8") as f:
    # Strip surrounding whitespace: text files usually end with a newline, and
    # a key carrying that newline is not the key that was issued.
    openai.api_key = f.read().strip()
# %%
def classify_tokens(file_name: str, max_rows: int = 1_000) -> pd.DataFrame:
    """Label the sentiment of every token in a sample CSV via the OpenAI chat API.

    Loads ``file_name`` with ``get_csv`` for the module-level ``MODEL``, sends
    one chat-completion request per row asking for a Positive/Neutral/Negative
    label of the row's ``token`` given its ``text`` context, and saves the
    result with ``to_csv`` as ``labelled_<file_name>``.

    Args:
        file_name: Name of the CSV to load; it must provide ``token`` and
            ``text`` columns.
        max_rows: Exclusive upper bound on the number of rows. Every row costs
            one API request, so larger inputs are rejected before any request
            is made.

    Returns:
        A copy of the loaded frame with an added ``sentiment`` column holding
        the raw text of each model reply.

    Raises:
        ValueError: If the CSV has ``max_rows`` rows or more.
    """
    csv_df = get_csv(file_name, MODEL)

    # Explicit check instead of `assert`, which is stripped under `python -O`
    # and would then let an oversized file through to the API.
    if len(csv_df) >= max_rows:
        raise ValueError(
            f"{file_name} has {len(csv_df)} rows; expected fewer than {max_rows}"
        )

    prefix = "Your job is to classify the sentiment of a given token (i.e. word or word fragment) into Positive/Neutral/Negative."
    sentiment_data = []
    # The index label is deliberately unused: the size guard above already
    # bounds the number of requests, and labels need not equal row positions.
    for _, row in tqdm(csv_df.iterrows(), total=len(csv_df)):
        token = row["token"]
        context = row["text"]
        chat_completion = openai.ChatCompletion.create(
            model="gpt-4",
            messages=[
                {
                    "role": "user",
                    "content": f"{prefix} Token: '{token}'. Context: '{context}'. Sentiment: ",
                }
            ],
        )
        # Keep the reply verbatim; no normalisation of the label is attempted.
        sentiment_data.append(chat_completion.choices[0].message.content)

    # Every row was processed, so the replies line up one-to-one with the rows.
    out_df = csv_df.copy()
    out_df["sentiment"] = sentiment_data
    to_csv(out_df, f"labelled_{file_name}", MODEL)
    return out_df
# %%
# Label each direction's sample file, reusing an existing labelled file
# unless OVERWRITE is set.
bar = tqdm(DIRECTIONS)
for direction in bar:
    file_name = f"{direction}{SUFFIX}"
    labelled_file = f"labelled_{file_name}"
    bar.set_description(f"Classifying {file_name}")
    # Existence is always checked first; OVERWRITE then forces a re-run.
    if not is_file(labelled_file, MODEL) or OVERWRITE:
        bar.set_description(f"Classifying {file_name} by calling OpenAI API...")
        classify_tokens(file_name)
    else:
        print(f"Skipping {labelled_file}")
# %%
# %%