Merge branch 'main' of https://github.com/huggingface/transformers into zach/Dino-v2-with-registers
BernardZach committed Dec 9, 2024
2 parents 0ed1114 + 7238387 commit 64fd5e1
Showing 9 changed files with 139 additions and 35 deletions.
18 changes: 16 additions & 2 deletions docs/source/en/model_doc/ijepa.md
@@ -18,13 +18,18 @@ rendered properly in your Markdown viewer.

## Overview

The I-JEPA model was proposed in [Image-based Joint-Embedding Predictive Architecture](https://arxiv.org/pdf/2301.08243.pdf) by Mahmoud Assran, Quentin Duval, Ishan Misra, Piotr Bojanowski, Pascal Vincent, Michael Rabbat, Yann LeCun, Nicolas Ballas.
The I-JEPA model was proposed in [Image-based Joint-Embedding Predictive Architecture](https://arxiv.org/abs/2301.08243) by Mahmoud Assran, Quentin Duval, Ishan Misra, Piotr Bojanowski, Pascal Vincent, Michael Rabbat, Yann LeCun, Nicolas Ballas.
I-JEPA is a self-supervised learning method that predicts the representations of one part of an image based on other parts of the same image. This approach focuses on learning semantic features without relying on pre-defined invariances from hand-crafted data transformations, which can bias specific tasks, or on filling in pixel-level details, which often leads to less meaningful representations.

The abstract from the paper is the following:

This paper demonstrates an approach for learning highly semantic image representations without relying on hand-crafted data-augmentations. We introduce the Image-based Joint-Embedding Predictive Architecture (I-JEPA), a non-generative approach for self-supervised learning from images. The idea behind I-JEPA is simple: from a single context block, predict the representations of various target blocks in the same image. A core design choice to guide I-JEPA towards producing semantic representations is the masking strategy; specifically, it is crucial to (a) sample target blocks with sufficiently large scale (semantic), and to (b) use a sufficiently informative (spatially distributed) context block. Empirically, when combined with Vision Transformers, we find I-JEPA to be highly scalable. For instance, we train a ViT-Huge/14 on ImageNet using 16 A100 GPUs in under 72 hours to achieve strong downstream performance across a wide range of tasks, from linear classification to object counting and depth prediction.

<img src="https://huggingface.co/datasets/huggingface/documentation-images/resolve/main/transformers/model_doc/ijepa_architecture.jpg"
alt="drawing" width="600"/>

<small> I-JEPA architecture. Taken from the <a href="https://arxiv.org/abs/2301.08243">original paper.</a> </small>

This model was contributed by [jmtzt](https://huggingface.co/jmtzt).
The original code can be found [here](https://github.com/facebookresearch/ijepa).

@@ -63,6 +68,15 @@ similarity = cosine_similarity(embed_1, embed_2)
print(similarity)
```

## Resources

A list of official Hugging Face and community (indicated by 🌎) resources to help you get started with I-JEPA.

<PipelineTag pipeline="image-classification"/>

- [`IJepaForImageClassification`] is supported by this [example script](https://github.com/huggingface/transformers/tree/main/examples/pytorch/image-classification) and [notebook](https://colab.research.google.com/github/huggingface/notebooks/blob/main/examples/image_classification.ipynb).
- See also: [Image classification task guide](../tasks/image_classification)
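
For illustration only (not part of this diff), a minimal classification sketch with [`IJepaForImageClassification`]. The checkpoint name is an assumption, and since the base I-JEPA checkpoints ship without a trained classification head, the head here is randomly initialized unless you fine-tune it first:

```python
import requests
import torch
from PIL import Image
from transformers import AutoImageProcessor, IJepaForImageClassification

url = "http://images.cocodataset.org/val2017/000000039769.jpg"
image = Image.open(requests.get(url, stream=True).raw)

# "facebook/ijepa_vith14_1k" is an assumed checkpoint name.
processor = AutoImageProcessor.from_pretrained("facebook/ijepa_vith14_1k")
model = IJepaForImageClassification.from_pretrained("facebook/ijepa_vith14_1k")

inputs = processor(images=image, return_tensors="pt")
with torch.no_grad():
    logits = model(**inputs).logits
# With an untrained head the predicted label is meaningless; fine-tune first for real use.
print(model.config.id2label[logits.argmax(-1).item()])
```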

## IJepaConfig

[[autodoc]] IJepaConfig
@@ -75,4 +89,4 @@ print(similarity)
## IJepaForImageClassification

[[autodoc]] IJepaForImageClassification
- forward
- forward
6 changes: 3 additions & 3 deletions src/transformers/agents/prompts.py
@@ -129,7 +129,7 @@ def download_prompt(prompt_or_repo_id, agent_name, mode="run"):
```<end_action>
---
Above example were using tools that might not exist for you. You only have acces to those Tools:
Above example were using tools that might not exist for you. You only have access to these Tools:
<<tool_names>>
Remember to make sure that variables you use are all defined.
@@ -256,7 +256,7 @@ def download_prompt(prompt_or_repo_id, agent_name, mode="run"):
}<end_action>
Above example were using notional tools that might not exist for you. You only have acces to those tools:
Above example were using notional tools that might not exist for you. You only have access to these tools:
<<tool_descriptions>>
Here are the rules you should always follow to solve your task:
@@ -348,7 +348,7 @@ def download_prompt(prompt_or_repo_id, agent_name, mode="run"):
final_answer(pope_current_age)
```<end_action>
Above example were using notional tools that might not exist for you. On top of performing computations in the Python code snippets that you create, you have acces to those tools (and no other tool):
Above example were using notional tools that might not exist for you. On top of performing computations in the Python code snippets that you create, you have access to these tools (and no other tool):
<<tool_descriptions>>
4 changes: 3 additions & 1 deletion src/transformers/feature_extraction_utils.py
@@ -213,6 +213,7 @@ def to(self, *args, **kwargs) -> "BatchFeature":
Will be passed to the `to(...)` function of the tensors.
kwargs (`Dict`, *optional*):
Will be passed to the `to(...)` function of the tensors.
To enable asynchronous data transfer, set the `non_blocking` flag in `kwargs` (defaults to `False`).
Returns:
[`BatchFeature`]: The same instance after modification.
@@ -222,6 +223,7 @@ def to(self, *args, **kwargs) -> "BatchFeature":

new_data = {}
device = kwargs.get("device")
non_blocking = kwargs.get("non_blocking", False)
# Check if the args are a device or a dtype
if device is None and len(args) > 0:
# device should be always the first argument
@@ -241,7 +243,7 @@ def to(self, *args, **kwargs) -> "BatchFeature":
# cast and send to device
new_data[k] = v.to(*args, **kwargs)
elif isinstance(v, torch.Tensor) and device is not None:
new_data[k] = v.to(device=device)
new_data[k] = v.to(device=device, non_blocking=non_blocking)
else:
new_data[k] = v
self.data = new_data
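
For context on the `non_blocking` support added to `BatchFeature.to` above, a hedged usage sketch (not part of this diff):

```python
import torch
from transformers import BatchFeature

batch = BatchFeature({"pixel_values": torch.randn(2, 3, 224, 224)}, tensor_type="pt")

if torch.cuda.is_available():
    # Pinned (page-locked) host memory is what lets a non_blocking copy actually
    # overlap with GPU compute; `non_blocking` travels through **kwargs as handled above.
    batch["pixel_values"] = batch["pixel_values"].pin_memory()
    batch = batch.to("cuda", non_blocking=True)
```
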
7 changes: 2 additions & 5 deletions src/transformers/models/bert/modeling_bert.py
@@ -1325,6 +1325,7 @@ def forward(
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
**loss_kwargs,
) -> Union[Tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]:
r"""
encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
@@ -1375,11 +1376,7 @@ def forward(

lm_loss = None
if labels is not None:
# we are doing next-token prediction; shift prediction scores and input ids by one
shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous()
labels = labels[:, 1:].contiguous()
loss_fct = CrossEntropyLoss()
lm_loss = loss_fct(shifted_prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
lm_loss = self.loss_function(prediction_scores, labels, self.config.vocab_size, **loss_kwargs)

if not return_dict:
output = (prediction_scores,) + outputs[2:]
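
The `modeling_bert.py` hunk above swaps the causal-LM head's hand-rolled shifted cross-entropy for the shared `self.loss_function` helper, which also accepts `**loss_kwargs` such as `num_items_in_batch`. A rough sketch of what such a causal-LM loss does (illustrative only, not the library's exact implementation):

```python
import torch
import torch.nn.functional as F

def causal_lm_loss_sketch(logits, labels, vocab_size, num_items_in_batch=None, ignore_index=-100):
    # Next-token prediction: scores at position t are compared against labels at t + 1.
    logits = logits[:, :-1, :].contiguous().view(-1, vocab_size)
    labels = labels[:, 1:].contiguous().view(-1)
    if num_items_in_batch is None:
        # Mean over the non-ignored tokens of this micro-batch.
        return F.cross_entropy(logits, labels, ignore_index=ignore_index)
    # Sum over tokens, normalized by the token count of the whole accumulated batch.
    loss = F.cross_entropy(logits, labels, ignore_index=ignore_index, reduction="sum")
    return loss / num_items_in_batch

# Smoke test with random data: 2 sequences of length 5 leave 2 * 4 = 8 scored tokens.
scores = torch.randn(2, 5, 30522)
targets = torch.randint(0, 30522, (2, 5))
print(causal_lm_loss_sketch(scores, targets, 30522))
print(causal_lm_loss_sketch(scores, targets, 30522, num_items_in_batch=8))
```
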
@@ -491,6 +491,8 @@ def forward(
kwargs_decoder = {
argument[len("decoder_") :]: value for argument, value in kwargs.items() if argument.startswith("decoder_")
}
if "num_items_in_batch" in kwargs_encoder:
kwargs_decoder["num_items_in_batch"] = kwargs_encoder.pop("num_items_in_batch", None)

if encoder_outputs is None:
if inputs is None:
10 changes: 7 additions & 3 deletions src/transformers/tokenization_utils_base.py
@@ -799,12 +799,13 @@ def as_tensor(value, dtype=None):

return self

def to(self, device: Union[str, "torch.device"]) -> "BatchEncoding":
def to(self, device: Union[str, "torch.device"], *, non_blocking: bool = False) -> "BatchEncoding":
"""
Send all values to device by calling `v.to(device)` (PyTorch only).
Send all values to device by calling `v.to(device, non_blocking=non_blocking)` (PyTorch only).
Args:
device (`str` or `torch.device`): The device to put the tensors on.
non_blocking (`bool`): Whether to perform the copy asynchronously.
Returns:
[`BatchEncoding`]: The same instance after modification.
@@ -816,7 +817,10 @@ def to(self, device: Union[str, "torch.device"]) -> "BatchEncoding":
# Otherwise it passes the casts down and casts the LongTensor containing the token idxs
# into a HalfTensor
if isinstance(device, str) or is_torch_device(device) or isinstance(device, int):
self.data = {k: v.to(device=device) if isinstance(v, torch.Tensor) else v for k, v in self.data.items()}
self.data = {
k: v.to(device=device, non_blocking=non_blocking) if isinstance(v, torch.Tensor) else v
for k, v in self.data.items()
}
else:
logger.warning(f"Attempting to cast a BatchEncoding to type {str(device)}. This is not supported.")
return self
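
A brief illustration of the new keyword-only `non_blocking` argument on `BatchEncoding.to` documented above (sketch only; the checkpoint name is an assumption):

```python
import torch
from transformers import AutoTokenizer

tokenizer = AutoTokenizer.from_pretrained("bert-base-uncased")  # assumed checkpoint
encoding = tokenizer(["an example sentence", "and another one"], padding=True, return_tensors="pt")

if torch.cuda.is_available():
    # non_blocking is keyword-only; the copy only overlaps with compute when the
    # source tensors sit in pinned memory (e.g. a DataLoader with pin_memory=True).
    encoding = encoding.to("cuda", non_blocking=True)
```
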
9 changes: 1 addition & 8 deletions src/transformers/trainer.py
@@ -3649,10 +3649,7 @@ def training_step(
return loss_mb.reduce_mean().detach().to(self.args.device)

with self.compute_loss_context_manager():
if self.model_accepts_loss_kwargs:
loss = self.compute_loss(model, inputs)
else:
loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)
loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)

del inputs
if (
@@ -5132,10 +5129,6 @@ def get_batch_samples(self, epoch_iterator, num_batches):
except StopIteration:
break

# Keep default behavior the same
if not self.model_accepts_loss_kwargs:
return batch_samples, None

if len(batch_samples) > 0 and "labels" in batch_samples[0]:
# For now we don't support object detection
try:
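
The trainer hunks above stop gating on `model_accepts_loss_kwargs`, so `num_items_in_batch` always reaches `compute_loss`. A toy numerical illustration (values invented, not from this diff) of why that matters under gradient accumulation: averaging each micro-batch and then averaging those averages differs from dividing the summed loss by the total token count whenever micro-batches hold different numbers of tokens, which is exactly what the alignment tests below check.

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
vocab = 10
# Two accumulation micro-batches with different numbers of (non-padding) tokens.
logits = [torch.randn(4, vocab), torch.randn(2, vocab)]
labels = [torch.randint(0, vocab, (4,)), torch.randint(0, vocab, (2,))]

# Mean per micro-batch, then averaged across accumulation steps.
per_step = sum(F.cross_entropy(l, y) for l, y in zip(logits, labels)) / len(logits)

# What num_items_in_batch enables: sum over all tokens / total token count.
num_items = sum(y.numel() for y in labels)
pooled = sum(F.cross_entropy(l, y, reduction="sum") for l, y in zip(logits, labels)) / num_items

print(per_step.item(), pooled.item())  # these differ whenever token counts differ
```
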
4 changes: 2 additions & 2 deletions tests/quantization/eetq_integration/test_eetq.py
@@ -119,7 +119,7 @@ def test_quantized_model_conversion(self):

self.assertEqual(nb_linears - 1, nb_eetq_linear)

# Try with `linear_weights_not_to_quantize`
# Try with `modules_to_not_convert`
with init_empty_weights():
model = OPTForCausalLM(config)
quantization_config = EetqConfig(modules_to_not_convert=["fc1"])
@@ -128,7 +128,7 @@ def test_quantized_model_conversion(self):
for module in model.modules():
if isinstance(module, EetqLinear):
nb_eetq_linear += 1

# 25 corresponds to the lm_head along with 24 fc1 layers.
self.assertEqual(nb_linears - 25, nb_eetq_linear)

def test_quantized_model(self):
114 changes: 103 additions & 11 deletions tests/trainer/test_trainer.py
@@ -750,11 +750,102 @@ def test_model_init(self):
self.check_trained_model(trainer.model, alternate_seed=True)

@slow
def test_gradient_accumulation_loss_alignment(self):
def test_gradient_accumulation_loss_alignment_with_model_loss(self):
set_seed(42)
import datasets

model_name = "distilgpt2"
model_name = "nickypro/tinyllama-110M"
dataset_name = "wikitext"
dataset_config = "wikitext-2-raw-v1"
dataset = datasets.load_dataset(dataset_name, dataset_config, split="train[:500]")
dataset = dataset.train_test_split(test_size=0.2)
tokenizer = AutoTokenizer.from_pretrained(model_name)

tokenizer.pad_token = tokenizer.eos_token

def tokenize_function(examples):
return tokenizer(examples["text"], max_length=128, padding="max_length", truncation=True)

tokenized_dataset = dataset.map(tokenize_function, batched=True, remove_columns=dataset["train"].column_names)

data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)

model = AutoModelForCausalLM.from_pretrained(model_name)

base_loss_callback = StoreLossCallback()

args_kwargs = {
"report_to": "none",
"logging_steps": 1,
"max_steps": 20,
"learning_rate": 3e-4,
"disable_tqdm": True,
}

args = TrainingArguments(
"./generation",
**args_kwargs,
)
trainer = Trainer(
model,
args,
train_dataset=tokenized_dataset["train"],
callbacks=[base_loss_callback],
data_collator=data_collator,
)
assert trainer.model_accepts_loss_kwargs
trainer.train()

grad_accum_loss_callback = StoreLossCallback()
args = TrainingArguments(
"./generation",
**args_kwargs,
gradient_accumulation_steps=2,
per_device_train_batch_size=4,
)
set_seed(42)
model = AutoModelForCausalLM.from_pretrained(model_name)
trainer = Trainer(
model,
args,
train_dataset=tokenized_dataset["train"],
callbacks=[grad_accum_loss_callback],
data_collator=data_collator,
)
trainer.train()

set_seed(42)
model = AutoModelForCausalLM.from_pretrained(model_name)
broken_loss_callback = StoreLossCallback()
trainer = Trainer(
model,
args,
train_dataset=tokenized_dataset["train"],
callbacks=[broken_loss_callback],
data_collator=data_collator,
)
# disable model_accepts_loss_kwargs
trainer.model_accepts_loss_kwargs = False
trainer.train()

# Calculate the difference between the base loss and the grad_accum loss
diff_truth = [
abs(base - grad) for base, grad in zip(base_loss_callback.losses, grad_accum_loss_callback.losses)
]
diff_broken = [abs(base - grad) for base, grad in zip(base_loss_callback.losses, broken_loss_callback.losses)]

# all diff truth should be quite close
self.assertLess(max(diff_truth), 0.01, f"Difference {max(diff_truth)} is not within 0.01")

# max diff broken should be very off
self.assertGreater(max(diff_broken), 3, f"Difference {max(diff_broken)} is not greater than 3")

@slow
def test_gradient_accumulation_loss_alignment_with_loss_func(self):
set_seed(42)
import datasets

model_name = "roneneldan/TinyStories-33M"
dataset_name = "wikitext"
dataset_config = "wikitext-2-raw-v1"
dataset = datasets.load_dataset(dataset_name, dataset_config, split="train[:500]")
@@ -836,15 +927,16 @@ def compute_loss(logits, labels, vocab_size, num_items_in_batch, disable_num_ite
trainer.train()

# Calculate the difference between the base loss and the grad_accum loss
diff_truth = [base - grad for base, grad in zip(base_loss_callback.losses, grad_accum_loss_callback.losses)]
diff_broken = [base - grad for base, grad in zip(base_loss_callback.losses, broken_loss_callback.losses)]
# These should be quite close
for diff in diff_truth:
self.assertLess(abs(diff), 0.1, f"Difference {diff} is not within 0.1")

# These should be very off
for diff in diff_broken:
self.assertGreater(abs(diff), 0.1, f"Difference {diff} is not greater than 0.1")
diff_truth = [
abs(base - grad) for base, grad in zip(base_loss_callback.losses, grad_accum_loss_callback.losses)
]
diff_broken = [abs(base - grad) for base, grad in zip(base_loss_callback.losses, broken_loss_callback.losses)]

# all diff truth should be quite close
self.assertLess(max(diff_truth), 0.01, f"Difference {max(diff_truth)} is not within 0.01")

# max diff broken should be very off
self.assertGreater(max(diff_broken), 3, f"Difference {max(diff_broken)} is not greater than 3")

def test_gradient_accumulation(self):
# Training with half the batch size but accumulation steps as 2 should give the same training losses.
