-
Notifications
You must be signed in to change notification settings - Fork 401
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #766 from k-ivey/prompt_augmentation
Add support for prompt augmentation
- Loading branch information
Showing
16 changed files
with
320 additions
and
1 deletion.
There are no files selected for viewing
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,19 @@ | ||
textattack.llms package | ||
========================= | ||
|
||
.. automodule:: textattack.llms | ||
:members: | ||
:undoc-members: | ||
:show-inheritance: | ||
|
||
|
||
.. automodule:: textattack.llms.huggingface_llm_wrapper | ||
:members: | ||
:undoc-members: | ||
:show-inheritance: | ||
|
||
|
||
.. automodule:: textattack.llms.chat_gpt_wrapper | ||
:members: | ||
:undoc-members: | ||
:show-inheritance: |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,13 @@ | ||
textattack.prompt_augmentation package | ||
======================================= | ||
|
||
.. automodule:: textattack.prompt_augmentation | ||
:members: | ||
:undoc-members: | ||
:show-inheritance: | ||
|
||
|
||
.. automodule:: textattack.prompt_augmentation.prompt_augmentation_pipeline | ||
:members: | ||
:undoc-members: | ||
:show-inheritance: |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,25 @@ | ||
def test_prompt_augmentation_pipeline():
    """End-to-end check of the prompt augmentation pipeline.

    Augments a sentiment-classification prompt with the CheckList recipe
    while protecting selected word indices, sends the augmented prompt to a
    small FLAN-T5 model, and checks both the augmented prompt and the
    model's answer.
    """
    from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

    from textattack.augmentation.recipes import CheckListAugmenter
    from textattack.constraints.pre_transformation import UnmodifiableIndices
    from textattack.llms import HuggingFaceLLMWrapper
    from textattack.prompt_augmentation import PromptAugmentationPipeline

    checkpoint = "google/flan-t5-small"
    llm = HuggingFaceLLMWrapper(
        AutoModelForSeq2SeqLM.from_pretrained(checkpoint),
        AutoTokenizer.from_pretrained(checkpoint),
    )
    pipeline = PromptAugmentationPipeline(CheckListAugmenter(), llm)

    prompt = "As a sentiment classifier, determine whether the following text is 'positive' or 'negative'. Please classify: Poor Ben Bratt couldn't find stardom if MapQuest emailed him point-to-point driving directions."
    # Word positions in the prompt that the augmenter is not allowed to touch.
    protected_words = UnmodifiableIndices([2, 3, 10, 12, 14])

    results = pipeline(prompt, [protected_words])

    # Exactly one (augmented prompt, response) pair is expected.
    assert len(results) == 1
    pair = results[0]
    assert len(pair) == 2
    # The augmented prompt should contain the expanded contraction.
    assert "could not" in pair[0]
    assert "negative" in pair[1]
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
24 changes: 24 additions & 0 deletions
24
textattack/constraints/pre_transformation/unmodifiable_indices.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,24 @@ | ||
from textattack.constraints import PreTransformationConstraint | ||
|
||
|
||
class UnmodifiableIndices(PreTransformationConstraint):
    """A constraint that prevents the modification of certain words at specific
    indices.

    Args:
        indices (list(int)): A list of indices which are unmodifiable
    """

    def __init__(self, indices):
        self.unmodifiable_indices = indices

    def _get_modifiable_indices(self, current_text):
        """Return the set of word indices of ``current_text`` that may be
        modified, i.e. every word index except the protected ones."""
        # The stored indices are passed through ``convert_from_original_idxs``
        # so they line up with the words of ``current_text``; the result is
        # materialised as a real set so the difference below is a set operation.
        unmodifiable_set = set(
            current_text.convert_from_original_idxs(self.unmodifiable_indices)
        )
        return set(range(len(current_text.words))) - unmodifiable_set

    def extra_repr_keys(self):
        # Must name an attribute that actually exists on the instance; the
        # constructor stores the indices as ``unmodifiable_indices``.
        return ["unmodifiable_indices"]
33 changes: 33 additions & 0 deletions
33
textattack/constraints/pre_transformation/unmodifiable_phrases.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,33 @@ | ||
from collections import defaultdict | ||
|
||
from textattack.constraints import PreTransformationConstraint | ||
|
||
|
||
class UnmodifablePhrases(PreTransformationConstraint):
    """A constraint that prevents the modification of specified phrases or
    words. Matching is case-insensitive.

    Args:
        phrases (list(str)): A list of strings that cannot be modified
    """

    def __init__(self, phrases):
        # Kept so that ``extra_repr_keys`` refers to an existing attribute.
        self.phrases = list(phrases)
        # Maps "number of words in the phrase" -> set of normalised phrases,
        # so each window size is only scanned when a phrase of that size exists.
        self.length_to_phrases = defaultdict(set)
        for phrase in self.phrases:
            phrase_words = phrase.lower().split()
            if not phrase_words:
                # An empty / whitespace-only phrase can never protect a word.
                continue
            # Re-join with single spaces so the stored form matches the
            # " ".join(...) of text words used during lookup.
            self.length_to_phrases[len(phrase_words)].add(" ".join(phrase_words))

    def _get_modifiable_indices(self, current_text):
        """Return the set of word indices of ``current_text`` that are not
        part of any protected phrase."""
        # Stored phrases are lower-cased, so the text words must be too.
        words = [word.lower() for word in current_text.words]
        phrase_indices = set()

        for phrase_length, protected in self.length_to_phrases.items():
            # Slide a window of ``phrase_length`` words across the text.
            for start in range(len(words) - phrase_length + 1):
                end = start + phrase_length
                if " ".join(words[start:end]) in protected:
                    phrase_indices.update(range(start, end))

        return set(range(len(words))) - phrase_indices

    def extra_repr_keys(self):
        return ["phrases"]
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,16 @@ | ||
""" | ||
Large Language Models | ||
====================== | ||
TextAttack can generate responses to prompts using LLMs, which take in a list of strings and output a list of responses. | ||
We've provided an implementation around two common LLM patterns: | ||
1. `HuggingFaceLLMWrapper` for LLMs in HuggingFace | ||
2. `ChatGptWrapper` for OpenAI's ChatGPT model | ||
""" | ||
|
||
from .chat_gpt_wrapper import ChatGptWrapper | ||
from .huggingface_llm_wrapper import HuggingFaceLLMWrapper |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,37 @@ | ||
import os | ||
|
||
from textattack.models.wrappers import ModelWrapper | ||
|
||
|
||
class ChatGptWrapper(ModelWrapper):
    """A wrapper around OpenAI's ChatGPT model. Note that you must provide your
    own API key to use this wrapper.

    Args:
        model_name (:obj:`str`): The name of the GPT model to use. See the OpenAI documentation
            for a list of latest model names
        key_environment_variable (:obj:`str`, `optional`, defaults to :obj:`OPENAI_API_KEY`):
            The environment variable that the API key is set to
    """

    def __init__(
        self, model_name="gpt-3.5-turbo", key_environment_variable="OPENAI_API_KEY"
    ):
        # Imported here so that ``openai`` is only needed when this wrapper
        # is actually instantiated.
        from openai import OpenAI

        self.model_name = model_name
        self.client = OpenAI(api_key=os.getenv(key_environment_variable))

    def __call__(self, text_input_list):
        """Returns a list of responses to the given input list.

        Args:
            text_input_list (:obj:`str` or list(:obj:`str`)): one prompt or a
                list of prompts; a single string is treated as a list of one.

        Returns:
            A list with one response text per prompt, in input order.
        """
        if isinstance(text_input_list, str):
            text_input_list = [text_input_list]

        outputs = []
        # One request is issued per prompt, each as a single user message.
        for text in text_input_list:
            completion = self.client.chat.completions.create(
                model=self.model_name, messages=[{"role": "user", "content": text}]
            )
            # Keep the text of the first choice rather than the message
            # object, so callers receive plain response strings.
            outputs.append(completion.choices[0].message.content)

        return outputs
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,29 @@ | ||
from textattack.models.wrappers import ModelWrapper | ||
|
||
|
||
class HuggingFaceLLMWrapper(ModelWrapper):
    """A wrapper around HuggingFace for LLMs.

    Args:
        model: A HuggingFace pretrained LLM
        tokenizer: A HuggingFace pretrained tokenizer
    """

    def __init__(self, model, tokenizer):
        self.model = model
        self.tokenizer = tokenizer

    def __call__(self, text_input_list):
        """Returns a list of responses to the given input list.

        When ``len(text_input_list) == 1`` the single decoded response is
        returned directly instead of a one-element list.
        """
        # NOTE(review): a bare string is handed to the tokenizer unchanged, and
        # the ``len(...) == 1`` check below then counts characters rather than
        # prompts -- confirm whether string inputs should be wrapped in a list.
        model_device = next(self.model.parameters()).device
        input_ids = self.tokenizer(text_input_list, return_tensors="pt").input_ids
        # ``Tensor.to`` returns a new tensor rather than moving in place, so
        # the result must be kept for the inputs to land on the model's device.
        input_ids = input_ids.to(model_device)

        outputs = self.model.generate(
            input_ids, max_new_tokens=512, pad_token_id=self.tokenizer.eos_token_id
        )

        responses = self.tokenizer.batch_decode(outputs, skip_special_tokens=True)
        if len(text_input_list) == 1:
            return responses[0]
        return responses
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,9 @@ | ||
""" | ||
Prompt Augmentation | ||
===================== | ||
This package includes functions used to augment a prompt for an LLM | ||
""" | ||
|
||
from .prompt_augmentation_pipeline import PromptAugmentationPipeline |
46 changes: 46 additions & 0 deletions
46
textattack/prompt_augmentation/prompt_augmentation_pipeline.py
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,46 @@ | ||
from textattack.constraints import PreTransformationConstraint | ||
|
||
|
||
class PromptAugmentationPipeline:
    """A prompt augmentation pipeline to augment a prompt and obtain the
    responses from an LLM on the augmented prompts.

    Args:
        augmenter (textattack.Augmenter): the augmenter to use to
            augment the prompt
        llm (textattack.ModelWrapper): the LLM to generate responses
            to the augmented data
    """

    def __init__(self, augmenter, llm):
        self.augmenter = augmenter
        self.llm = llm

    def __call__(self, prompt, prompt_constraints=None):
        """Augments the given prompt using the augmenter and generates
        responses using the LLM.

        Args:
            prompt (:obj:`str`): the prompt to augment and generate responses
            prompt_constraints (List(textattack.constraints.PreTransformationConstraint)): a list of pretransformation
                constraints to apply to the given prompt. Defaults to no
                extra constraints.

        Returns a list of tuples, where the first item in the pair is the augmented prompt and the second
            is the response to the augmented prompt from the LLM

        Raises:
            ValueError: if any entry of ``prompt_constraints`` is not a
                ``PreTransformationConstraint``.
        """
        extra_constraints = [] if prompt_constraints is None else list(prompt_constraints)

        # Validate everything up front so that a bad entry cannot leave the
        # augmenter with a partially extended constraint list.
        for constraint in extra_constraints:
            if not isinstance(constraint, PreTransformationConstraint):
                raise ValueError(
                    "Prompt constraints must be of type PreTransformationConstraint"
                )

        active_constraints = self.augmenter.pre_transformation_constraints
        active_constraints.extend(extra_constraints)
        try:
            augmented_prompts = self.augmenter.augment(prompt)
        finally:
            # The extra constraints apply to this call only; remove them even
            # when augmentation raises so the augmenter is left as it was.
            if extra_constraints:
                del active_constraints[-len(extra_constraints):]

        return [
            (augmented_prompt, self.llm(augmented_prompt))
            for augmented_prompt in augmented_prompts
        ]
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters