Merge pull request #766 from k-ivey/prompt_augmentation
Add support for prompt augmentation
qiyanjun authored Mar 11, 2024
2 parents 29f38b2 + 6469138 commit a8dfcb1
Showing 16 changed files with 320 additions and 1 deletion.
21 changes: 21 additions & 0 deletions README.md
@@ -392,6 +392,27 @@ You can also create your own augmenter from scratch by importing transformations
['What I cannot creae, I do not understand.', 'What I cannot creat, I do not understand.', 'What I cannot create, I do not nderstand.', 'What I cannot create, I do nt understand.', 'Wht I cannot create, I do not understand.']
```

#### Prompt Augmentation
In addition to augmenting regular text, you can augment prompts and then generate responses to
the augmented prompts using a large language model (LLM). The augmentation is performed using the same
`Augmenter` as above. To generate responses, you can use your own LLM, a HuggingFace LLM, or an OpenAI LLM.
Here's an example using a pretrained HuggingFace LLM:

```python
>>> from textattack.augmentation import EmbeddingAugmenter
>>> from transformers import AutoModelForSeq2SeqLM, AutoTokenizer
>>> from textattack.llms import HuggingFaceLLMWrapper
>>> from textattack.prompt_augmentation import PromptAugmentationPipeline
>>> augmenter = EmbeddingAugmenter(transformations_per_example=3)
>>> model = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-small")
>>> tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-small")
>>> model_wrapper = HuggingFaceLLMWrapper(model, tokenizer)
>>> pipeline = PromptAugmentationPipeline(augmenter, model_wrapper)
>>> pipeline("Classify the following piece of text as `positive` or `negative`: This movie is great!")
[('Classify the following piece of text as `positive` or `negative`: This film is great!', ['positive']), ('Classify the following piece of text as `positive` or `negative`: This movie is fabulous!', ['positive']), ('Classify the following piece of text as `positive` or `negative`: This movie is wonderful!', ['positive'])]
```
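
To use an OpenAI model instead of a local HuggingFace one, you can swap in the `ChatGptWrapper`. This is a minimal sketch, assuming the `OPENAI_API_KEY` environment variable is set (the model name shown is the wrapper's default):

```python
>>> from textattack.augmentation import EmbeddingAugmenter
>>> from textattack.llms import ChatGptWrapper
>>> from textattack.prompt_augmentation import PromptAugmentationPipeline
>>> augmenter = EmbeddingAugmenter(transformations_per_example=3)
>>> model_wrapper = ChatGptWrapper(model_name="gpt-3.5-turbo")  # reads the API key from OPENAI_API_KEY
>>> pipeline = PromptAugmentationPipeline(augmenter, model_wrapper)
>>> pipeline("Classify the following piece of text as `positive` or `negative`: This movie is great!")
```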


### Training Models: `textattack train`

Our model training code is available via `textattack train` to help you train LSTMs,
12 changes: 12 additions & 0 deletions docs/apidoc/textattack.constraints.pre_transformation.rst
@@ -43,3 +43,15 @@ textattack.constraints.pre\_transformation package
:members:
:undoc-members:
:show-inheritance:


.. automodule:: textattack.constraints.pre_transformation.unmodifiable_indices
:members:
:undoc-members:
:show-inheritance:


.. automodule:: textattack.constraints.pre_transformation.unmodifiable_phrases
:members:
:undoc-members:
:show-inheritance:
19 changes: 19 additions & 0 deletions docs/apidoc/textattack.llms.rst
@@ -0,0 +1,19 @@
textattack.llms package
=========================

.. automodule:: textattack.llms
:members:
:undoc-members:
:show-inheritance:


.. automodule:: textattack.llms.huggingface_llm_wrapper
:members:
:undoc-members:
:show-inheritance:


.. automodule:: textattack.llms.chat_gpt_wrapper
:members:
:undoc-members:
:show-inheritance:
13 changes: 13 additions & 0 deletions docs/apidoc/textattack.prompt_augmentation.rst
@@ -0,0 +1,13 @@
textattack.prompt_augmentation package
=======================================

.. automodule:: textattack.prompt_augmentation
:members:
:undoc-members:
:show-inheritance:


.. automodule:: textattack.prompt_augmentation.prompt_augmentation_pipeline
:members:
:undoc-members:
:show-inheritance:
2 changes: 2 additions & 0 deletions docs/apidoc/textattack.rst
@@ -19,9 +19,11 @@ textattack package
textattack.datasets
textattack.goal_function_results
textattack.goal_functions
textattack.llms
textattack.loggers
textattack.metrics
textattack.models
textattack.prompt_augmentation
textattack.search_methods
textattack.shared
textattack.transformations
31 changes: 31 additions & 0 deletions tests/test_constraints/test_pretransformation_constraints.py
@@ -103,3 +103,34 @@ def test_stopword_modification(
set(range(len(entailment_attacked_text.words)))
- {1, 2, 3, 8, 9, 11, 16, 17, 20, 22, 25, 31, 34, 39, 40, 41, 43, 44}
)

def test_unmodifiable_indices(
self, sentence_attacked_text, entailment_attacked_text
):
constraint = textattack.constraints.pre_transformation.UnmodifiableIndices(
[4, 5]
)
assert constraint._get_modifiable_indices(sentence_attacked_text) == (
set(range(len(sentence_attacked_text.words))) - {4, 5}
)
sentence_attacked_text = sentence_attacked_text.delete_word_at_index(2)
assert constraint._get_modifiable_indices(sentence_attacked_text) == (
set(range(len(sentence_attacked_text.words))) - {3, 4}
)
assert constraint._get_modifiable_indices(entailment_attacked_text) == (
set(range(len(entailment_attacked_text.words))) - {4, 5}
)
entailment_attacked_text = (
entailment_attacked_text.insert_text_after_word_index(0, "two words")
)
assert constraint._get_modifiable_indices(entailment_attacked_text) == (
set(range(len(entailment_attacked_text.words))) - {6, 7}
)

def test_unmodifiable_phrases(self, sentence_attacked_text):
constraint = textattack.constraints.pre_transformation.UnmodifablePhrases(
["South Korea's", "oil", "monday"]
)
assert constraint._get_modifiable_indices(sentence_attacked_text) == (
set(range(len(sentence_attacked_text.words))) - {0, 1, 9, 22}
)
25 changes: 25 additions & 0 deletions tests/test_prompt_augmentation.py
@@ -0,0 +1,25 @@
def test_prompt_augmentation_pipeline():
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

from textattack.augmentation.recipes import CheckListAugmenter
from textattack.constraints.pre_transformation import UnmodifiableIndices
from textattack.llms import HuggingFaceLLMWrapper
from textattack.prompt_augmentation import PromptAugmentationPipeline

model = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-small")
tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-small")
model_wrapper = HuggingFaceLLMWrapper(model, tokenizer)

augmenter = CheckListAugmenter()

pipeline = PromptAugmentationPipeline(augmenter, model_wrapper)

prompt = "As a sentiment classifier, determine whether the following text is 'positive' or 'negative'. Please classify: Poor Ben Bratt couldn't find stardom if MapQuest emailed him point-to-point driving directions."
prompt_constraints = [UnmodifiableIndices([2, 3, 10, 12, 14])]

output = pipeline(prompt, prompt_constraints)

assert len(output) == 1
assert len(output[0]) == 2
assert "could not" in output[0][0]
assert "negative" in output[0][1]
2 changes: 2 additions & 0 deletions textattack/constraints/pre_transformation/__init__.py
@@ -13,3 +13,5 @@
from .max_num_words_modified import MaxNumWordsModified
from .min_word_length import MinWordLength
from .max_modification_rate import MaxModificationRate
from .unmodifiable_indices import UnmodifiableIndices
from .unmodifiable_phrases import UnmodifablePhrases
24 changes: 24 additions & 0 deletions textattack/constraints/pre_transformation/unmodifiable_indices.py
@@ -0,0 +1,24 @@
from textattack.constraints import PreTransformationConstraint


class UnmodifiableIndices(PreTransformationConstraint):
"""A constraint that prevents the modification of certain words at specific
indices.
Args:
indices (list(int)): A list of indices which are unmodifiable
"""

def __init__(self, indices):
self.unmodifiable_indices = indices

def _get_modifiable_indices(self, current_text):
unmodifiable_set = current_text.convert_from_original_idxs(
self.unmodifiable_indices
)
return set(
i for i in range(0, len(current_text.words)) if i not in unmodifiable_set
)

def extra_repr_keys(self):
return ["indices"]
33 changes: 33 additions & 0 deletions textattack/constraints/pre_transformation/unmodifiable_phrases.py
@@ -0,0 +1,33 @@
from collections import defaultdict

from textattack.constraints import PreTransformationConstraint


class UnmodifablePhrases(PreTransformationConstraint):
"""A constraint that prevents the modification of specified phrases or
words.
Args:
phrases (list(str)): A list of strings that cannot be modified
"""

def __init__(self, phrases):
self.length_to_phrases = defaultdict(set)
for phrase in phrases:
self.length_to_phrases[len(phrase.split())].add(phrase.lower())

def _get_modifiable_indices(self, current_text):
phrase_indices = set()

for phrase_length in self.length_to_phrases.keys():
for i in range(len(current_text.words) - phrase_length + 1):
if (
" ".join(current_text.words[i : i + phrase_length])
in self.length_to_phrases[phrase_length]
):
phrase_indices |= set(range(i, i + phrase_length))

return set(i for i in range(len(current_text.words)) if i not in phrase_indices)

def extra_repr_keys(self):
return ["phrases"]
16 changes: 16 additions & 0 deletions textattack/llms/__init__.py
@@ -0,0 +1,16 @@
"""
Large Language Models
======================
TextAttack can generate responses to prompts using LLMs, which take in a list of strings and output a list of responses.
We've provided implementations for two common LLM patterns:
1. `HuggingFaceLLMWrapper` for LLMs in HuggingFace
2. `ChatGptWrapper` for OpenAI's ChatGPT model
"""

from .chat_gpt_wrapper import ChatGptWrapper
from .huggingface_llm_wrapper import HuggingFaceLLMWrapper
37 changes: 37 additions & 0 deletions textattack/llms/chat_gpt_wrapper.py
@@ -0,0 +1,37 @@
import os

from textattack.models.wrappers import ModelWrapper


class ChatGptWrapper(ModelWrapper):
"""A wrapper around OpenAI's ChatGPT model. Note that you must provide your
own API key to use this wrapper.
Args:
model_name (:obj:`str`): The name of the GPT model to use. See the OpenAI documentation
for a list of the latest model names
key_environment_variable (:obj:`str`, `optional`, defaults to :obj:`OPENAI_API_KEY`):
The environment variable that the API key is set to
"""

def __init__(
self, model_name="gpt-3.5-turbo", key_environment_variable="OPENAI_API_KEY"
):
from openai import OpenAI

self.model_name = model_name
self.client = OpenAI(api_key=os.getenv(key_environment_variable))

def __call__(self, text_input_list):
"""Returns a list of responses to the given input list."""
if isinstance(text_input_list, str):
text_input_list = [text_input_list]

outputs = []
for text in text_input_list:
completion = self.client.chat.completions.create(
model=self.model_name, messages=[{"role": "user", "content": text}]
)
outputs.append(completion.choices[0].message)

return outputs
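
A minimal call sketch, assuming a valid `OPENAI_API_KEY` is exported and the `openai` package is installed (note that, as written above, the wrapper collects the raw message objects returned by the API):

```python
import os

from textattack.llms import ChatGptWrapper

assert "OPENAI_API_KEY" in os.environ  # the wrapper reads the key from this variable by default
llm = ChatGptWrapper()  # defaults to gpt-3.5-turbo
responses = llm(["Answer in one word: what color is a clear daytime sky?"])
print(responses[0])
```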
29 changes: 29 additions & 0 deletions textattack/llms/huggingface_llm_wrapper.py
@@ -0,0 +1,29 @@
from textattack.models.wrappers import ModelWrapper


class HuggingFaceLLMWrapper(ModelWrapper):
"""A wrapper around HuggingFace for LLMs.
Args:
model: A HuggingFace pretrained LLM
tokenizer: A HuggingFace pretrained tokenizer
"""

def __init__(self, model, tokenizer):
self.model = model
self.tokenizer = tokenizer

def __call__(self, text_input_list):
"""Returns a list of responses to the given input list."""
model_device = next(self.model.parameters()).device
input_ids = self.tokenizer(text_input_list, return_tensors="pt").input_ids
input_ids = input_ids.to(model_device)

outputs = self.model.generate(
input_ids, max_new_tokens=512, pad_token_id=self.tokenizer.eos_token_id
)

responses = self.tokenizer.batch_decode(outputs, skip_special_tokens=True)
if len(text_input_list) == 1:
return responses[0]
return responses
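
A direct-use sketch of this wrapper outside the augmentation pipeline (it mirrors the README example; `google/flan-t5-small` is just a small seq2seq model used for illustration):

```python
from transformers import AutoModelForSeq2SeqLM, AutoTokenizer

from textattack.llms import HuggingFaceLLMWrapper

model = AutoModelForSeq2SeqLM.from_pretrained("google/flan-t5-small")
tokenizer = AutoTokenizer.from_pretrained("google/flan-t5-small")
llm = HuggingFaceLLMWrapper(model, tokenizer)

# A single-element input list returns a single decoded string.
print(llm(["Translate English to German: How old are you?"]))
```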
9 changes: 9 additions & 0 deletions textattack/prompt_augmentation/__init__.py
@@ -0,0 +1,9 @@
"""
Prompt Augmentation
=====================
This package includes functions used to augment a prompt for an LLM
"""

from .prompt_augmentation_pipeline import PromptAugmentationPipeline
46 changes: 46 additions & 0 deletions textattack/prompt_augmentation/prompt_augmentation_pipeline.py
@@ -0,0 +1,46 @@
from textattack.constraints import PreTransformationConstraint


class PromptAugmentationPipeline:
"""A prompt augmentation pipeline to augment a prompt and obtain the
responses from an LLM on the augmented prompts.
Args:
augmenter (textattack.Augmenter): the augmenter to use to
augment the prompt
llm (textattack.ModelWrapper): the LLM to generate responses
to the augmented data
"""

def __init__(self, augmenter, llm):
self.augmenter = augmenter
self.llm = llm

def __call__(self, prompt, prompt_constraints=[]):
"""Augments the given prompt using the augmenter and generates
responses using the LLM.
Args:
prompt (:obj:`str`): the prompt to augment and generate responses for
prompt_constraints (List(textattack.constraints.PreTransformationConstraint)): a list of pre-transformation
constraints to apply to the given prompt
Returns a list of tuples of strings, where the first string in the pair is the augmented prompt and the second
is the response to the augmented prompt from the LLM
"""
for constraint in prompt_constraints:
if isinstance(constraint, PreTransformationConstraint):
self.augmenter.pre_transformation_constraints.append(constraint)
else:
raise ValueError(
"Prompt constraints must be of type PreTransformationConstraint"
)

augmented_prompts = self.augmenter.augment(prompt)
for _ in range(len(prompt_constraints)):
self.augmenter.pre_transformation_constraints.pop()

outputs = []
for augmented_prompt in augmented_prompts:
outputs.append((augmented_prompt, self.llm(augmented_prompt)))
return outputs
2 changes: 1 addition & 1 deletion textattack/shared/attacked_text.py
@@ -317,7 +317,7 @@ def convert_from_original_idxs(self, idxs: Iterable[int]) -> List[int]:
elif isinstance(idxs, set):
idxs = list(idxs)

-elif not isinstance(idxs, [list, np.ndarray]):
+elif not isinstance(idxs, (list, np.ndarray)):
raise TypeError(
f"convert_from_original_idxs got invalid idxs type {type(idxs)}"
)
