Commit

Merge branch 'huggingface:main' into main
Quentin-Anthony authored Sep 26, 2024
2 parents cb1d1d9 + e32521b commit efcf16a
Showing 36 changed files with 366 additions and 113 deletions.
12 changes: 6 additions & 6 deletions docs/source/en/model_doc/chameleon.md
@@ -19,7 +19,7 @@ rendered properly in your Markdown viewer.
## Overview

The Chameleon model was proposed in [Chameleon: Mixed-Modal Early-Fusion Foundation Models
](https://arxiv.org/abs/2405.09818v1) by the META AI Chameleon Team. Chameleon is a Vision-Language Model that uses vector quantization to tokenize images, which enables the model to generate multimodal output. The model takes images and text as input, including in an interleaved format, and generates textual responses. The image generation module is not released yet.


The abstract from the paper is the following:
@@ -61,7 +61,7 @@ The original code can be found [here](https://github.com/facebookresearch/chamel

### Single image inference

Chameleon is a gated model, so make sure you have access and are logged in to the Hugging Face Hub with a token.
Here's how to load the model and perform inference in half-precision (`torch.bfloat16`):

```python
@@ -78,7 +78,7 @@ url = 'http://images.cocodataset.org/val2017/000000039769.jpg'
image = Image.open(requests.get(url, stream=True).raw)
prompt = "What do you see in this image?<image>"

-inputs = processor(prompt, image, return_tensors="pt").to(model.device, dtype=torch.bfloat16)
+inputs = processor(images=image, text=prompt, return_tensors="pt").to(model.device, dtype=torch.bfloat16)

# autoregressively complete prompt
output = model.generate(**inputs, max_new_tokens=50)
@@ -117,7 +117,7 @@ prompts = [

# We can simply feed the images in the order they should be used in the text prompt
# Each "<image>" token uses one image, leaving the rest for subsequent "<image>" tokens
-inputs = processor(text=prompts, images=[image_stop, image_cats, image_snowman], padding=True, return_tensors="pt").to(device="cuda", dtype=torch.bfloat16)
+inputs = processor(images=[image_stop, image_cats, image_snowman], text=prompts, padding=True, return_tensors="pt").to(device="cuda", dtype=torch.bfloat16)

# Generate
generate_ids = model.generate(**inputs, max_new_tokens=50)
@@ -162,8 +162,8 @@ from transformers import ChameleonForConditionalGeneration

model_id = "facebook/chameleon-7b"
model = ChameleonForConditionalGeneration.from_pretrained(
    model_id,
    torch_dtype=torch.bfloat16,
    low_cpu_mem_usage=True,
    attn_implementation="flash_attention_2"
).to(0)
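Note: both call-site changes in this file standardize on passing the two modalities as keyword arguments, with `images` ahead of `text`. A minimal sketch of the resulting end-to-end pattern (checkpoint, prompt, and image URL taken from the docs above; everything else is standard `generate` usage):

```python
import requests
import torch
from PIL import Image
from transformers import ChameleonForConditionalGeneration, ChameleonProcessor

processor = ChameleonProcessor.from_pretrained("facebook/chameleon-7b")
model = ChameleonForConditionalGeneration.from_pretrained(
    "facebook/chameleon-7b", torch_dtype=torch.bfloat16, device_map="auto"
)

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

# Keyword arguments keep the call valid regardless of the positional order
# the processor signature happens to use.
inputs = processor(images=image, text="What do you see in this image?<image>", return_tensors="pt").to(
    model.device, dtype=torch.bfloat16
)
output = model.generate(**inputs, max_new_tokens=50)
print(processor.decode(output[0], skip_special_tokens=True))
```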
4 changes: 2 additions & 2 deletions docs/source/en/model_doc/omdet-turbo.md
@@ -49,8 +49,8 @@ from PIL import Image

from transformers import AutoProcessor, OmDetTurboForObjectDetection

-processor = AutoProcessor.from_pretrained("omlab/omdet-turbo-tiny")
-model = OmDetTurboForObjectDetection.from_pretrained("omlab/omdet-turbo-tiny")
+processor = AutoProcessor.from_pretrained("omlab/omdet-turbo-swin-tiny-hf")
+model = OmDetTurboForObjectDetection.from_pretrained("omlab/omdet-turbo-swin-tiny-hf")

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)
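The change here only corrects the checkpoint id. For context, a sketch of how the corrected checkpoint is typically exercised end to end; the post-processing call and its parameter names follow the OmDet-Turbo docs of this release and should be double-checked against the installed version:

```python
import requests
import torch
from PIL import Image
from transformers import AutoProcessor, OmDetTurboForObjectDetection

processor = AutoProcessor.from_pretrained("omlab/omdet-turbo-swin-tiny-hf")
model = OmDetTurboForObjectDetection.from_pretrained("omlab/omdet-turbo-swin-tiny-hf")

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)
classes = ["cat", "remote"]  # open-vocabulary text queries

inputs = processor(image, text=classes, return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)

results = processor.post_process_grounded_object_detection(
    outputs,
    classes=classes,
    target_sizes=[image.size[::-1]],  # (height, width)
    score_threshold=0.3,
    nms_threshold=0.3,
)
```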
2 changes: 1 addition & 1 deletion scripts/check_tokenizers.py
@@ -88,7 +88,7 @@ def check_details(line, spm_ids, tok_ids, slow, fast):
            if tok_ids[first + k : first + k + min_width] == spm_ids[first + i : first + i + min_width]
        ]
        for j in possible_matches:
-            if check_diff(spm_ids[first : first + i], tok_ids[first : first + j], sp, tok) and check_details(
+            if check_diff(spm_ids[first : first + i], tok_ids[first : first + j], slow, fast) and check_details(
                line,
                spm_ids[first + i : last],
                tok_ids[first + j : last],
6 changes: 6 additions & 0 deletions src/transformers/modeling_utils.py
@@ -1645,6 +1645,12 @@ def can_generate(cls) -> bool:
        # Model class overwrites `generate` (e.g. time series models) -> can generate
        if str(cls.__name__) in str(cls.generate):
            return True
+        # The class inherits from a class that can generate (recursive check) -> can generate
+        for base in cls.__bases__:
+            if not hasattr(base, "can_generate"):
+                continue
+            if "PreTrainedModel" not in str(base) and base.can_generate():
+                return True
        # BC: Detects whether `prepare_inputs_for_generation` has been overwritten in the model. Prior to v4.45, this
        # was how we detected whether a model could generate.
        if "GenerationMixin" not in str(cls.prepare_inputs_for_generation):
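The added loop lets generation capability propagate down subclass chains instead of being re-detected per class. A self-contained sketch with simplified stand-ins (not the real transformers classes) showing what the recursion newly enables and why it terminates:

```python
class GenerationMixin:
    pass


class PreTrainedModel:
    @classmethod
    def can_generate(cls) -> bool:
        # Direct mixin inheritance -> can generate (pre-existing check)
        if "GenerationMixin" in str(cls.__bases__):
            return True
        # New in this commit: defer to any base that can generate, skipping the
        # bare `PreTrainedModel` default so the recursion bottoms out.
        for base in cls.__bases__:
            if not hasattr(base, "can_generate"):
                continue
            if "PreTrainedModel" not in str(base) and base.can_generate():
                return True
        return False


class ParentModel(PreTrainedModel, GenerationMixin):  # passes the direct check
    pass


class ChildModel(ParentModel):  # previously fell through every heuristic
    pass


assert ParentModel.can_generate()
assert ChildModel.can_generate()  # True via the recursive base walk
```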
5 changes: 5 additions & 0 deletions src/transformers/models/bert/tokenization_bert.py
@@ -88,6 +88,9 @@ class BertTokenizer(PreTrainedTokenizer):
        strip_accents (`bool`, *optional*):
            Whether or not to strip all accents. If this option is not specified, then it will be determined by the
            value for `lowercase` (as in the original BERT).
+        clean_up_tokenization_spaces (`bool`, *optional*, defaults to `True`):
+            Whether or not to clean up spaces after decoding; cleanup consists of removing potential artifacts like
+            extra spaces.
"""

vocab_files_names = VOCAB_FILES_NAMES
@@ -105,6 +108,7 @@ def __init__(
        mask_token="[MASK]",
        tokenize_chinese_chars=True,
        strip_accents=None,
+        clean_up_tokenization_spaces=True,
**kwargs,
):
if not os.path.isfile(vocab_file):
@@ -136,6 +140,7 @@ def __init__(
            mask_token=mask_token,
            tokenize_chinese_chars=tokenize_chinese_chars,
            strip_accents=strip_accents,
+            clean_up_tokenization_spaces=clean_up_tokenization_spaces,
**kwargs,
)

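The same flag (with identical wording) is added to `ConvBertTokenizer` and `DistilBertTokenizer` below. Its observable effect is at decode time, where cleanup re-attaches punctuation that WordPiece splits off. A quick sketch; the printed strings are illustrative of a typical uncased round-trip:

```python
from transformers import BertTokenizer

tokenizer = BertTokenizer.from_pretrained("bert-base-uncased")
ids = tokenizer("Hello, world!", add_special_tokens=False)["input_ids"]

# Decoding re-joins WordPiece tokens with spaces, so punctuation gets detached.
print(tokenizer.decode(ids, clean_up_tokenization_spaces=False))  # hello , world !
print(tokenizer.decode(ids, clean_up_tokenization_spaces=True))   # hello, world!
```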
9 changes: 4 additions & 5 deletions src/transformers/models/blip/modeling_blip.py
@@ -233,7 +233,6 @@ def __init__(self, config: BlipVisionConfig):

        self.position_embedding = nn.Parameter(torch.randn(1, self.num_positions, self.embed_dim))

-    # Copied from transformers.models.vit.modeling_vit.ViTEmbeddings.interpolate_pos_encoding
    def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
"""
This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution
@@ -245,14 +244,14 @@ def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width:
"""

        num_patches = embeddings.shape[1] - 1
-        num_positions = self.position_embeddings.shape[1] - 1
+        num_positions = self.position_embedding.shape[1] - 1

        # always interpolate when tracing to ensure the exported model works for dynamic input shapes
        if not torch.jit.is_tracing() and num_patches == num_positions and height == width:
-            return self.position_embeddings
+            return self.position_embedding

-        class_pos_embed = self.position_embeddings[:, :1]
-        patch_pos_embed = self.position_embeddings[:, 1:]
+        class_pos_embed = self.position_embedding[:, :1]
+        patch_pos_embed = self.position_embedding[:, 1:]

        dim = embeddings.shape[-1]

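The renamed attribute only matters on the interpolation path, i.e. when a checkpoint trained at one resolution receives larger inputs. A hedged sketch of exercising that path (assumes `interpolate_pos_encoding` is exposed as a generate/forward argument on the BLIP vision stack, as in other ViT-style models; the 608-pixel size is an arbitrary illustration):

```python
import torch
from transformers import AutoProcessor, BlipForConditionalGeneration

processor = AutoProcessor.from_pretrained("Salesforce/blip-image-captioning-base")
model = BlipForConditionalGeneration.from_pretrained("Salesforce/blip-image-captioning-base")

# Larger than the checkpoint's pretraining resolution, so the position
# encodings must be interpolated to match the new number of patches.
pixel_values = torch.randn(1, 3, 608, 608)

# Before this fix, this path raised AttributeError: the module attribute is
# `position_embedding`, but the method looked up `position_embeddings`.
out = model.generate(pixel_values=pixel_values, interpolate_pos_encoding=True)
print(processor.decode(out[0], skip_special_tokens=True))
```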
9 changes: 4 additions & 5 deletions src/transformers/models/blip_2/modeling_blip_2.py
@@ -200,7 +200,6 @@ def __init__(self, config: Blip2VisionConfig):

        self.position_embedding = nn.Parameter(torch.randn(1, self.num_positions, self.embed_dim))

-    # Copied from transformers.models.vit.modeling_vit.ViTEmbeddings.interpolate_pos_encoding
    def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width: int) -> torch.Tensor:
"""
This method allows to interpolate the pre-trained position encodings, to be able to use the model on higher resolution
@@ -212,14 +211,14 @@ def interpolate_pos_encoding(self, embeddings: torch.Tensor, height: int, width:
"""

        num_patches = embeddings.shape[1] - 1
-        num_positions = self.position_embeddings.shape[1] - 1
+        num_positions = self.position_embedding.shape[1] - 1

        # always interpolate when tracing to ensure the exported model works for dynamic input shapes
        if not torch.jit.is_tracing() and num_patches == num_positions and height == width:
-            return self.position_embeddings
+            return self.position_embedding

-        class_pos_embed = self.position_embeddings[:, :1]
-        patch_pos_embed = self.position_embeddings[:, 1:]
+        class_pos_embed = self.position_embedding[:, :1]
+        patch_pos_embed = self.position_embedding[:, 1:]

        dim = embeddings.shape[-1]

src/transformers/models/chameleon/convert_chameleon_weights_to_hf.py
@@ -24,7 +24,7 @@

from transformers import (
    ChameleonConfig,
-    ChameleonForCausalLM,
+    ChameleonForConditionalGeneration,
    ChameleonImageProcessor,
    ChameleonProcessor,
)
@@ -49,10 +49,10 @@
Thereafter, models can be loaded via:
```py
-from transformers import ChameleonForCausalLM, LlamaTokenizer
+from transformers import ChameleonForConditionalGeneration, LlamaTokenizerFast
-model = ChameleonForCausalLM.from_pretrained("/output/path")
+model = ChameleonForConditionalGeneration.from_pretrained("/output/path")
-tokenizer = LlamaTokenizer.from_pretrained("/output/path")
+tokenizer = LlamaTokenizerFast.from_pretrained("/output/path")
```
Important note: you need to be able to host the whole model in RAM to execute this script (even if the biggest versions
@@ -372,7 +372,7 @@ def permute(w, n_heads, dim1=dim, dim2=dim):
        vocabulary_map=vocabulary_map,
    )
    with init_empty_weights():
-        model = ChameleonForCausalLM(config)
+        model = ChameleonForConditionalGeneration(config)

    model.load_state_dict(state_dict, assign=True, strict=False)
    model.save_pretrained(model_path, safe_serialization=True)
@@ -397,7 +397,7 @@ def permute(w, n_heads, dim1=dim, dim2=dim):
# taken from https://github.com/facebookresearch/chameleon/blob/7a72f40aa5f462965c8374f25257f55b65b25ff4/data/prompts_for_human_evaluations.jsonl
print("Loading the checkpoint in a Chameleon model...")
print("*" * 100)
-    model = ChameleonForCausalLM.from_pretrained(
+    model = ChameleonForConditionalGeneration.from_pretrained(
        model_path, attn_implementation="eager", torch_dtype=torch.bfloat16, device_map="auto"
    )
    processor = ChameleonProcessor.from_pretrained(model_path)
2 changes: 1 addition & 1 deletion src/transformers/models/chameleon/modeling_chameleon.py
@@ -1568,7 +1568,7 @@ def forward(
        >>> image = Image.open(requests.get("https://nineplanets.org/wp-content/uploads/2020/12/the-big-dipper-1.jpg", stream=True).raw)
        >>> image_2 = Image.open(requests.get("https://www.kxan.com/wp-content/uploads/sites/40/2020/10/ORION.jpg", stream=True).raw)
-        >>> inputs = processor(prompt, images=[image, image_2], return_tensors="pt").to(model.device, torch.bfloat16)
+        >>> inputs = processor(images=[image, image_2], text=prompt, return_tensors="pt").to(model.device, torch.bfloat16)
        >>> generated_ids = model.generate(**inputs, max_new_tokens=100, do_sample=False)
        >>> processor.batch_decode(generated_ids, skip_special_tokens=True)[0]
77 changes: 41 additions & 36 deletions src/transformers/models/chameleon/processing_chameleon.py
@@ -20,9 +20,25 @@

from ...feature_extraction_utils import BatchFeature
from ...image_utils import ImageInput
-from ...processing_utils import ProcessorMixin
-from ...tokenization_utils_base import PaddingStrategy, PreTokenizedInput, TextInput, TruncationStrategy
-from ...utils import TensorType
+from ...processing_utils import ProcessingKwargs, ProcessorMixin, TextKwargs, Unpack, _validate_images_text_input_order
+from ...tokenization_utils_base import PreTokenizedInput, TextInput


+class ChameleonTextKwargs(TextKwargs, total=False):
+    return_for_text_completion: bool


+class ChameleonProcessorKwargs(ProcessingKwargs, total=False):
+    text_kwargs: ChameleonTextKwargs
+    _defaults = {
+        "text_kwargs": {
+            "padding": False,
+            "return_for_text_completion": False,
+        },
+        "common_kwargs": {
+            "return_tensors": "pt",
+        },
+    }


class ChameleonProcessor(ProcessorMixin):
@@ -57,13 +73,11 @@ def __init__(self, image_processor, tokenizer, image_seq_length: int = 1024, ima

    def __call__(
        self,
-        text: Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]] = None,
-        images: ImageInput = None,
-        padding: Union[bool, str, PaddingStrategy] = False,
-        truncation: Union[bool, str, TruncationStrategy] = None,
-        max_length: int = None,
-        return_tensors: Optional[Union[str, TensorType]] = TensorType.PYTORCH,
-        return_for_text_completion: bool = False,
+        images: Optional[ImageInput] = None,
+        text: Optional[Union[TextInput, PreTokenizedInput, List[TextInput], List[PreTokenizedInput]]] = None,
+        audio=None,
+        videos=None,
+        **kwargs: Unpack[ChameleonProcessorKwargs],
    ) -> BatchFeature:
"""
Main method to prepare for the model one or several sequences(s) and image(s). This method forwards the `text`
@@ -73,26 +87,13 @@ def __call__(
of the above two methods for more information.
        Args:
+            images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
+                The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
+                tensor. Both channels-first and channels-last formats are supported.
            text (`str`, `List[str]`, `List[List[str]]`):
                The sequence or batch of sequences to be encoded. Each sequence can be a string or a list of strings
                (pretokenized string). If the sequences are provided as list of strings (pretokenized), you must set
                `is_split_into_words=True` (to lift the ambiguity with a batch of sequences).
-            images (`PIL.Image.Image`, `np.ndarray`, `torch.Tensor`, `List[PIL.Image.Image]`, `List[np.ndarray]`, `List[torch.Tensor]`):
-                The image or batch of images to be prepared. Each image can be a PIL image, NumPy array or PyTorch
-                tensor. Both channels-first and channels-last formats are supported.
-            padding (`bool`, `str` or [`~utils.PaddingStrategy`], *optional*, defaults to `False`):
-                Select a strategy to pad the returned sequences (according to the model's padding side and padding
-                index) among:
-                - `True` or `'longest'`: Pad to the longest sequence in the batch (or no padding if only a single
-                  sequence if provided).
-                - `'max_length'`: Pad to a maximum length specified with the argument `max_length` or to the maximum
-                  acceptable input length for the model if that argument is not provided.
-                - `False` or `'do_not_pad'` (default): No padding (i.e., can output a batch with sequences of different
-                  lengths).
-            max_length (`int`, *optional*):
-                Maximum length of the returned list and optionally padding length (see above).
-            truncation (`bool`, *optional*):
-                Activates truncation to cut input sequences longer than `max_length` to `max_length`.
            return_tensors (`str` or [`~utils.TensorType`], *optional*):
                If set, will return tensors of a particular framework. Acceptable values are:
@@ -110,10 +111,21 @@
`None`).
- **pixel_values** -- Pixel values to be fed to a model. Returned when `images` is not `None`.
"""
+        # check if images and text inputs are reversed for BC
+        images, text = _validate_images_text_input_order(images, text)
        if isinstance(text, str):
            text = [text]
        elif not isinstance(text, list) and not isinstance(text[0], str):
            raise TypeError("Invalid input text. Please provide a string, or a list of strings")
        if text is None and images is None:
            raise ValueError("You must provide either text or images")

+        output_kwargs = self._merge_kwargs(
+            ChameleonProcessorKwargs,
+            tokenizer_init_kwargs=self.tokenizer.init_kwargs,
+            **kwargs,
+        )
+        return_for_text_completion = output_kwargs["text_kwargs"].pop("return_for_text_completion", False)

        # Replace the image token with the expanded image token sequence
        prompt_strings = []
@@ -124,19 +136,12 @@
                sample += self.tokenizer.sep_token  # special Chameleon treatment to add sep for chat mode
            prompt_strings.append(sample)

-        data = self.tokenizer(
-            prompt_strings,
-            return_tensors=return_tensors,
-            padding=padding,
-            truncation=truncation,
-            max_length=max_length,
-        )
+        data = self.tokenizer(prompt_strings, **output_kwargs["text_kwargs"])

        if images is not None:
-            pixel_values = self.image_processor(images, return_tensors=return_tensors)["pixel_values"]
-            data["pixel_values"] = pixel_values
+            data["pixel_values"] = self.image_processor(images, **output_kwargs["images_kwargs"])["pixel_values"]

-        return BatchFeature(data=data, tensor_type=return_tensors)
+        return BatchFeature(data=data, tensor_type=output_kwargs["common_kwargs"]["return_tensors"])

    # Copied from transformers.models.clip.processing_clip.CLIPProcessor.batch_decode with CLIP->Llama
    def batch_decode(self, *args, **kwargs):
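With the rewrite, per-call options flow through `**kwargs`, are validated against `ChameleonProcessorKwargs`, and are merged with `_defaults` and the tokenizer's init kwargs by `_merge_kwargs`. A sketch of a caller-facing invocation after the change (argument values are illustrative):

```python
import requests
from PIL import Image
from transformers import ChameleonProcessor

processor = ChameleonProcessor.from_pretrained("facebook/chameleon-7b")
url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

# Tokenizer options are routed into output_kwargs["text_kwargs"]; anything not
# given falls back to _defaults (padding=False, return_tensors="pt").
inputs = processor(
    images=image,
    text="Describe this image.<image>",
    padding="max_length",              # forwarded to the tokenizer
    max_length=4096,                   # forwarded to the tokenizer
    return_for_text_completion=False,  # Chameleon-specific text kwarg
)
print(inputs.keys())  # input_ids, attention_mask, pixel_values
```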
5 changes: 5 additions & 0 deletions src/transformers/models/convbert/tokenization_convbert.py
@@ -91,6 +91,9 @@ class ConvBertTokenizer(PreTrainedTokenizer):
        strip_accents (`bool`, *optional*):
            Whether or not to strip all accents. If this option is not specified, then it will be determined by the
            value for `lowercase` (as in the original ConvBERT).
+        clean_up_tokenization_spaces (`bool`, *optional*, defaults to `True`):
+            Whether or not to clean up spaces after decoding; cleanup consists of removing potential artifacts like
+            extra spaces.
"""

vocab_files_names = VOCAB_FILES_NAMES
@@ -108,6 +111,7 @@ def __init__(
        mask_token="[MASK]",
        tokenize_chinese_chars=True,
        strip_accents=None,
+        clean_up_tokenization_spaces=True,
**kwargs,
):
if not os.path.isfile(vocab_file):
@@ -139,6 +143,7 @@ def __init__(
            mask_token=mask_token,
            tokenize_chinese_chars=tokenize_chinese_chars,
            strip_accents=strip_accents,
+            clean_up_tokenization_spaces=clean_up_tokenization_spaces,
**kwargs,
)

src/transformers/models/convnextv2/configuration_convnextv2.py
@@ -52,6 +52,8 @@ class ConvNextV2Config(BackboneConfigMixin, PretrainedConfig):
            The epsilon used by the layer normalization layers.
        drop_path_rate (`float`, *optional*, defaults to 0.0):
            The drop rate for stochastic depth.
+        image_size (`int`, *optional*, defaults to 224):
+            The size (resolution) of each image.
        out_features (`List[str]`, *optional*):
            If used as backbone, list of features to output. Can be any of `"stem"`, `"stage1"`, `"stage2"`, etc.
            (depending on how many stages the model has). If unset and `out_indices` is set, will default to the
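For completeness, the newly documented attribute in use; `image_size` is configuration metadata consumed by downstream tooling rather than something that changes the convolutional weights:

```python
from transformers import ConvNextV2Config

config = ConvNextV2Config()
print(config.image_size)  # 224, the documented default

# Override for a higher-resolution setup (illustrative value).
config = ConvNextV2Config(image_size=384)
```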
5 changes: 5 additions & 0 deletions src/transformers/models/distilbert/tokenization_distilbert.py
@@ -90,6 +90,9 @@ class DistilBertTokenizer(PreTrainedTokenizer):
        strip_accents (`bool`, *optional*):
            Whether or not to strip all accents. If this option is not specified, then it will be determined by the
            value for `lowercase` (as in the original BERT).
+        clean_up_tokenization_spaces (`bool`, *optional*, defaults to `True`):
+            Whether or not to clean up spaces after decoding; cleanup consists of removing potential artifacts like
+            extra spaces.
"""

vocab_files_names = VOCAB_FILES_NAMES
@@ -108,6 +111,7 @@ def __init__(
        mask_token="[MASK]",
        tokenize_chinese_chars=True,
        strip_accents=None,
+        clean_up_tokenization_spaces=True,
**kwargs,
):
if not os.path.isfile(vocab_file):
@@ -138,6 +142,7 @@ def __init__(
            mask_token=mask_token,
            tokenize_chinese_chars=tokenize_chinese_chars,
            strip_accents=strip_accents,
+            clean_up_tokenization_spaces=clean_up_tokenization_spaces,
**kwargs,
)

Some files were not shown because too many files changed in this diff.
