Skip to content

Commit

Permalink
Set filter length to 2 for mistral
Browse files Browse the repository at this point in the history
  • Loading branch information
ojh31 committed Nov 22, 2023
1 parent feb7301 commit ceeea6c
Showing 1 changed file with 60 additions and 40 deletions.
100 changes: 60 additions & 40 deletions utils/prompts.py
Original file line number Diff line number Diff line change
Expand Up @@ -127,14 +127,12 @@ def get(
filter_length: Optional[int] = None,
truncate_length: Optional[int] = None,
drop_duplicates: bool = True,
prepend_space: Optional[bool] = None,
prepend_space: bool = True,
verbose: bool = False,
) -> CircularList:
assert (
filter_length is not None or truncate_length is not None
), "Must specify at least one of filter_length or truncate_length"
if prepend_space is None:
prepend_space = "mistral" not in model.cfg.model_name
words: list = self._prompts_dict[key]
if prepend_space:
words = [" " + word.strip() for word in words]
Expand Down Expand Up @@ -259,38 +257,39 @@ def get_prompts(
neutral_prompts: Union[List[str], None]

# Read lists from config
FILTER_LENGTH = 2 if "mistral" in model.cfg.model_name.lower() else 1
pos_answers: CircularList[str] = prompt_config.get(
"positive_answer_tokens", model, filter_length=1
"positive_answer_tokens", model, filter_length=FILTER_LENGTH
)
neg_answers: CircularList[str] = prompt_config.get(
"negative_answer_tokens", model, filter_length=1
"negative_answer_tokens", model, filter_length=FILTER_LENGTH
)
positive_adjectives: CircularList[str] = prompt_config.get(
"positive_core_adjectives", model, filter_length=1
"positive_core_adjectives", model, filter_length=FILTER_LENGTH
)
negative_adjectives: CircularList[str] = prompt_config.get(
"negative_core_adjectives", model, filter_length=1
"negative_core_adjectives", model, filter_length=FILTER_LENGTH
)
neutral_adjectives: CircularList[str] = prompt_config.get(
"neutral_core_adjectives", model, filter_length=1
"neutral_core_adjectives", model, filter_length=FILTER_LENGTH
)
positive_verbs: CircularList[str] = prompt_config.get(
"positive_verbs", model, filter_length=1
"positive_verbs", model, filter_length=FILTER_LENGTH
)
negative_verbs: CircularList[str] = prompt_config.get(
"negative_verbs", model, filter_length=1
"negative_verbs", model, filter_length=FILTER_LENGTH
)
neutral_verbs: CircularList[str] = prompt_config.get(
"neutral_verbs", model, filter_length=1
"neutral_verbs", model, filter_length=FILTER_LENGTH
)
positive_top_adjectives: CircularList[str] = prompt_config.get(
"positive_top_adjectives", model, filter_length=1
"positive_top_adjectives", model, filter_length=FILTER_LENGTH
)
negative_top_adjectives: CircularList[str] = prompt_config.get(
"negative_top_adjectives", model, filter_length=1
"negative_top_adjectives", model, filter_length=FILTER_LENGTH
)
neutral_top_adjectives: CircularList[str] = prompt_config.get(
"neutral_top_adjectives", model, filter_length=1
"neutral_top_adjectives", model, filter_length=FILTER_LENGTH
)

# Get prompt type/format
Expand All @@ -315,10 +314,10 @@ def get_prompts(
elif prompt_type == PromptType.SIMPLE_TRAIN:
n_prompts = min(len(positive_adjectives), len(negative_adjectives))
positive_adjectives = prompt_config.get(
"positive_adjectives_train", model, filter_length=1
"positive_adjectives_train", model, filter_length=FILTER_LENGTH
)
negative_adjectives = prompt_config.get(
"negative_adjectives_train", model, filter_length=1
"negative_adjectives_train", model, filter_length=FILTER_LENGTH
)
neutral_prompts = None
pos_prompts = [
Expand All @@ -331,17 +330,17 @@ def get_prompts(
]
elif prompt_type == PromptType.SIMPLE_TEST:
positive_adjectives = prompt_config.get(
"positive_adjectives_test", model, filter_length=1
"positive_adjectives_test", model, filter_length=FILTER_LENGTH
)
negative_adjectives = prompt_config.get(
"negative_adjectives_test", model, filter_length=1
"negative_adjectives_test", model, filter_length=FILTER_LENGTH
)
n_prompts = min(len(positive_adjectives), len(negative_adjectives))
positive_adjectives = prompt_config.get(
"positive_adjectives_test", model, filter_length=1
"positive_adjectives_test", model, filter_length=FILTER_LENGTH
)
negative_adjectives = prompt_config.get(
"negative_adjectives_test", model, filter_length=1
"negative_adjectives_test", model, filter_length=FILTER_LENGTH
)
neutral_prompts = None
pos_prompts = [
Expand All @@ -354,10 +353,10 @@ def get_prompts(
]
elif prompt_type == PromptType.SIMPLE_BOOK:
positive_adjectives = prompt_config.get(
"positive_comment_adjectives", model, filter_length=1
"positive_comment_adjectives", model, filter_length=FILTER_LENGTH
)
negative_adjectives = prompt_config.get(
"negative_comment_adjectives", model, filter_length=1
"negative_comment_adjectives", model, filter_length=FILTER_LENGTH
)
n_prompts = min(
len(positive_adjectives),
Expand All @@ -375,19 +374,23 @@ def get_prompts(
]
neutral_prompts = None
pos_answers = prompt_config.get(
"positive_answer_adjectives", model, filter_length=1
"positive_answer_adjectives", model, filter_length=FILTER_LENGTH
)
neg_answers = prompt_config.get(
"negative_answer_adjectives", model, filter_length=1
"negative_answer_adjectives", model, filter_length=FILTER_LENGTH
)
elif prompt_type == PromptType.SIMPLE_RES:
positive_nouns = prompt_config.get("positive_nouns", model, filter_length=1)
negative_nouns = prompt_config.get("negative_nouns", model, filter_length=1)
positive_nouns = prompt_config.get(
"positive_nouns", model, filter_length=FILTER_LENGTH
)
negative_nouns = prompt_config.get(
"negative_nouns", model, filter_length=FILTER_LENGTH
)
positive_infinitives = prompt_config.get(
"positive_infinitives", model, filter_length=1
"positive_infinitives", model, filter_length=FILTER_LENGTH
)
negative_infinitives = prompt_config.get(
"negative_infinitives", model, filter_length=1
"negative_infinitives", model, filter_length=FILTER_LENGTH
)
n_prompts = min(
len(positive_nouns),
Expand All @@ -404,14 +407,18 @@ def get_prompts(
for i in range(n_prompts)
]
neutral_prompts = None
pos_answers = prompt_config.get("positive_very_answers", model, filter_length=1)
neg_answers = prompt_config.get("negative_very_answers", model, filter_length=1)
pos_answers = prompt_config.get(
"positive_very_answers", model, filter_length=FILTER_LENGTH
)
neg_answers = prompt_config.get(
"negative_very_answers", model, filter_length=FILTER_LENGTH
)
elif prompt_type == PromptType.SIMPLE_PRODUCT:
positive_feelings = prompt_config.get(
"positive_feelings", model, filter_length=1
"positive_feelings", model, filter_length=FILTER_LENGTH
)
negative_feelings = prompt_config.get(
"negative_feelings", model, filter_length=1
"negative_feelings", model, filter_length=FILTER_LENGTH
)
n_prompts = min(
len(positive_feelings),
Expand All @@ -428,8 +435,12 @@ def get_prompts(
for i in range(n_prompts)
]
neutral_prompts = None
pos_answers = prompt_config.get("positive_moods", model, filter_length=1)
neg_answers = prompt_config.get("negative_moods", model, filter_length=1)
pos_answers = prompt_config.get(
"positive_moods", model, filter_length=FILTER_LENGTH
)
neg_answers = prompt_config.get(
"negative_moods", model, filter_length=FILTER_LENGTH
)
elif prompt_type == PromptType.SIMPLE_ADVERB:
positive_adverbs = prompt_config.get("positive_adverbs", model, filter_length=2)
negative_adverbs = prompt_config.get("negative_adverbs", model, filter_length=2)
Expand All @@ -441,8 +452,12 @@ def get_prompts(
formatter.format(ADV=negative_adverbs[i]) for i in range(n_prompts)
]
neutral_prompts = None
pos_answers = prompt_config.get("positive_moods", model, filter_length=1)
neg_answers = prompt_config.get("negative_moods", model, filter_length=1)
pos_answers = prompt_config.get(
"positive_moods", model, filter_length=FILTER_LENGTH
)
neg_answers = prompt_config.get(
"negative_moods", model, filter_length=FILTER_LENGTH
)
elif prompt_type == PromptType.SIMPLE_FRENCH:
positive_french_adj = prompt_config.get(
"positive_french_adjectives", model, filter_length=3
Expand Down Expand Up @@ -474,10 +489,10 @@ def get_prompts(
)
elif prompt_type == PromptType.PROPER_NOUNS:
positive_proper = prompt_config.get(
"positive_proper_nouns", model, filter_length=1
"positive_proper_nouns", model, filter_length=FILTER_LENGTH
)
negative_proper = prompt_config.get(
"negative_proper_nouns", model, filter_length=1
"negative_proper_nouns", model, filter_length=FILTER_LENGTH
)
n_prompts = min(len(positive_proper), len(negative_proper))
pos_prompts = [
Expand All @@ -488,8 +503,12 @@ def get_prompts(
]
neutral_prompts = None
elif prompt_type == PromptType.MEDICAL:
positive_medical = prompt_config.get("positive_medical", model, filter_length=1)
negative_medical = prompt_config.get("negative_medical", model, filter_length=1)
positive_medical = prompt_config.get(
"positive_medical", model, filter_length=FILTER_LENGTH
)
negative_medical = prompt_config.get(
"negative_medical", model, filter_length=FILTER_LENGTH
)
n_prompts = min(len(positive_medical), len(negative_medical))
pos_prompts = [
formatter.format(MED=positive_medical[i]) for i in range(n_prompts)
Expand Down Expand Up @@ -641,6 +660,7 @@ def get_prompts(
raise ValueError(f"Invalid prompt type: {prompt_type}")

# check length match
assert len(pos_prompts) > 0, f"Positive prompts is empty: {pos_prompts}"
assert len(pos_prompts) == len(neg_prompts), (
f"Number of positive prompts ({len(pos_prompts)}) "
f"does not match number of negative prompts ({len(neg_prompts)}). "
Expand Down

0 comments on commit ceeea6c

Please sign in to comment.