diff --git a/docs/source/en/model_doc/ijepa.md b/docs/source/en/model_doc/ijepa.md
index 32944e2617eae1..cb2afd25e20bca 100644
--- a/docs/source/en/model_doc/ijepa.md
+++ b/docs/source/en/model_doc/ijepa.md
@@ -18,13 +18,18 @@ rendered properly in your Markdown viewer.
## Overview
-The I-JEPA model was proposed in [Image-based Joint-Embedding Predictive Architecture](https://arxiv.org/pdf/2301.08243.pdf) by Mahmoud Assran, Quentin Duval, Ishan Misra, Piotr Bojanowski, Pascal Vincent, Michael Rabbat, Yann LeCun, Nicolas Ballas.
+The I-JEPA model was proposed in [Image-based Joint-Embedding Predictive Architecture](https://arxiv.org/abs/2301.08243) by Mahmoud Assran, Quentin Duval, Ishan Misra, Piotr Bojanowski, Pascal Vincent, Michael Rabbat, Yann LeCun, Nicolas Ballas.
I-JEPA is a self-supervised learning method that predicts the representations of one part of an image based on other parts of the same image. This approach focuses on learning semantic features without relying on pre-defined invariances from hand-crafted data transformations, which can bias specific tasks, or on filling in pixel-level details, which often leads to less meaningful representations.
The abstract from the paper is the following:
This paper demonstrates an approach for learning highly semantic image representations without relying on hand-crafted data-augmentations. We introduce the Image- based Joint-Embedding Predictive Architecture (I-JEPA), a non-generative approach for self-supervised learning from images. The idea behind I-JEPA is simple: from a single context block, predict the representations of various target blocks in the same image. A core design choice to guide I-JEPA towards producing semantic representations is the masking strategy; specifically, it is crucial to (a) sample tar- get blocks with sufficiently large scale (semantic), and to (b) use a sufficiently informative (spatially distributed) context block. Empirically, when combined with Vision Transform- ers, we find I-JEPA to be highly scalable. For instance, we train a ViT-Huge/14 on ImageNet using 16 A100 GPUs in under 72 hours to achieve strong downstream performance across a wide range of tasks, from linear classification to object counting and depth prediction.
+
+
+ I-JEPA architecture. Taken from the original paper.
+
This model was contributed by [jmtzt](https://huggingface.co/jmtzt).
The original code can be found [here](https://github.com/facebookresearch/ijepa).
@@ -63,6 +68,15 @@ similarity = cosine_similarity(embed_1, embed_2)
print(similarity)
```
+## Resources
+
+A list of official Hugging Face and community (indicated by 🌎) resources to help you get started with I-JEPA.
+
+
+
+- [`IJepaForImageClassification`] is supported by this [example script](https://github.com/huggingface/transformers/tree/main/examples/pytorch/image-classification) and [notebook](https://colab.research.google.com/github/huggingface/notebooks/blob/main/examples/image_classification.ipynb).
+- See also: [Image classification task guide](../tasks/image_classification)
+
## IJepaConfig
[[autodoc]] IJepaConfig
@@ -75,4 +89,4 @@ print(similarity)
## IJepaForImageClassification
[[autodoc]] IJepaForImageClassification
- - forward
+ - forward
\ No newline at end of file
diff --git a/src/transformers/agents/prompts.py b/src/transformers/agents/prompts.py
index 7a84b1db44faba..898a7e011a2b05 100644
--- a/src/transformers/agents/prompts.py
+++ b/src/transformers/agents/prompts.py
@@ -129,7 +129,7 @@ def download_prompt(prompt_or_repo_id, agent_name, mode="run"):
```
---
-Above example were using tools that might not exist for you. You only have acces to those Tools:
+Above example were using tools that might not exist for you. You only have access to these Tools:
<>
Remember to make sure that variables you use are all defined.
@@ -256,7 +256,7 @@ def download_prompt(prompt_or_repo_id, agent_name, mode="run"):
}
-Above example were using notional tools that might not exist for you. You only have acces to those tools:
+Above example were using notional tools that might not exist for you. You only have access to these tools:
<>
Here are the rules you should always follow to solve your task:
@@ -348,7 +348,7 @@ def download_prompt(prompt_or_repo_id, agent_name, mode="run"):
final_answer(pope_current_age)
```
-Above example were using notional tools that might not exist for you. On top of performing computations in the Python code snippets that you create, you have acces to those tools (and no other tool):
+Above example were using notional tools that might not exist for you. On top of performing computations in the Python code snippets that you create, you have access to these tools (and no other tool):
<>
diff --git a/src/transformers/feature_extraction_utils.py b/src/transformers/feature_extraction_utils.py
index f3cde8180c1bd4..6e8007edbc0b78 100644
--- a/src/transformers/feature_extraction_utils.py
+++ b/src/transformers/feature_extraction_utils.py
@@ -213,6 +213,7 @@ def to(self, *args, **kwargs) -> "BatchFeature":
Will be passed to the `to(...)` function of the tensors.
kwargs (`Dict`, *optional*):
Will be passed to the `to(...)` function of the tensors.
+ To enable asynchronous data transfer, set the `non_blocking` flag in `kwargs` (defaults to `False`).
Returns:
[`BatchFeature`]: The same instance after modification.
@@ -222,6 +223,7 @@ def to(self, *args, **kwargs) -> "BatchFeature":
new_data = {}
device = kwargs.get("device")
+ non_blocking = kwargs.get("non_blocking", False)
# Check if the args are a device or a dtype
if device is None and len(args) > 0:
# device should be always the first argument
@@ -241,7 +243,7 @@ def to(self, *args, **kwargs) -> "BatchFeature":
# cast and send to device
new_data[k] = v.to(*args, **kwargs)
elif isinstance(v, torch.Tensor) and device is not None:
- new_data[k] = v.to(device=device)
+ new_data[k] = v.to(device=device, non_blocking=non_blocking)
else:
new_data[k] = v
self.data = new_data
diff --git a/src/transformers/models/bert/modeling_bert.py b/src/transformers/models/bert/modeling_bert.py
index 6b05fa648158a6..e311f93b6c81ed 100755
--- a/src/transformers/models/bert/modeling_bert.py
+++ b/src/transformers/models/bert/modeling_bert.py
@@ -1325,6 +1325,7 @@ def forward(
output_attentions: Optional[bool] = None,
output_hidden_states: Optional[bool] = None,
return_dict: Optional[bool] = None,
+ **loss_kwargs,
) -> Union[Tuple[torch.Tensor], CausalLMOutputWithCrossAttentions]:
r"""
encoder_hidden_states (`torch.FloatTensor` of shape `(batch_size, sequence_length, hidden_size)`, *optional*):
@@ -1375,11 +1376,7 @@ def forward(
lm_loss = None
if labels is not None:
- # we are doing next-token prediction; shift prediction scores and input ids by one
- shifted_prediction_scores = prediction_scores[:, :-1, :].contiguous()
- labels = labels[:, 1:].contiguous()
- loss_fct = CrossEntropyLoss()
- lm_loss = loss_fct(shifted_prediction_scores.view(-1, self.config.vocab_size), labels.view(-1))
+ lm_loss = self.loss_function(prediction_scores, labels, self.config.vocab_size, **loss_kwargs)
if not return_dict:
output = (prediction_scores,) + outputs[2:]
diff --git a/src/transformers/models/speech_encoder_decoder/modeling_speech_encoder_decoder.py b/src/transformers/models/speech_encoder_decoder/modeling_speech_encoder_decoder.py
index 0d2b911bebe582..3bff8f6acd290d 100644
--- a/src/transformers/models/speech_encoder_decoder/modeling_speech_encoder_decoder.py
+++ b/src/transformers/models/speech_encoder_decoder/modeling_speech_encoder_decoder.py
@@ -491,6 +491,8 @@ def forward(
kwargs_decoder = {
argument[len("decoder_") :]: value for argument, value in kwargs.items() if argument.startswith("decoder_")
}
+ if "num_items_in_batch" in kwargs_encoder:
+ kwargs_decoder["num_items_in_batch"] = kwargs_encoder.pop("num_items_in_batch", None)
if encoder_outputs is None:
if inputs is None:
diff --git a/src/transformers/tokenization_utils_base.py b/src/transformers/tokenization_utils_base.py
index 0bfcc4aa303665..f4e5b9b3aaf314 100644
--- a/src/transformers/tokenization_utils_base.py
+++ b/src/transformers/tokenization_utils_base.py
@@ -799,12 +799,13 @@ def as_tensor(value, dtype=None):
return self
- def to(self, device: Union[str, "torch.device"]) -> "BatchEncoding":
+ def to(self, device: Union[str, "torch.device"], *, non_blocking: bool = False) -> "BatchEncoding":
"""
- Send all values to device by calling `v.to(device)` (PyTorch only).
+ Send all values to device by calling `v.to(device, non_blocking=non_blocking)` (PyTorch only).
Args:
device (`str` or `torch.device`): The device to put the tensors on.
+ non_blocking (`bool`): Whether to perform the copy asynchronously.
Returns:
[`BatchEncoding`]: The same instance after modification.
@@ -816,7 +817,10 @@ def to(self, device: Union[str, "torch.device"]) -> "BatchEncoding":
# Otherwise it passes the casts down and casts the LongTensor containing the token idxs
# into a HalfTensor
if isinstance(device, str) or is_torch_device(device) or isinstance(device, int):
- self.data = {k: v.to(device=device) if isinstance(v, torch.Tensor) else v for k, v in self.data.items()}
+ self.data = {
+ k: v.to(device=device, non_blocking=non_blocking) if isinstance(v, torch.Tensor) else v
+ for k, v in self.data.items()
+ }
else:
logger.warning(f"Attempting to cast a BatchEncoding to type {str(device)}. This is not supported.")
return self
diff --git a/src/transformers/trainer.py b/src/transformers/trainer.py
index af908e48e4b8c4..f7d79481809807 100755
--- a/src/transformers/trainer.py
+++ b/src/transformers/trainer.py
@@ -3649,10 +3649,7 @@ def training_step(
return loss_mb.reduce_mean().detach().to(self.args.device)
with self.compute_loss_context_manager():
- if self.model_accepts_loss_kwargs:
- loss = self.compute_loss(model, inputs)
- else:
- loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)
+ loss = self.compute_loss(model, inputs, num_items_in_batch=num_items_in_batch)
del inputs
if (
@@ -5132,10 +5129,6 @@ def get_batch_samples(self, epoch_iterator, num_batches):
except StopIteration:
break
- # Keep default behavior the same
- if not self.model_accepts_loss_kwargs:
- return batch_samples, None
-
if len(batch_samples) > 0 and "labels" in batch_samples[0]:
# For now we don't support object detection
try:
diff --git a/tests/quantization/eetq_integration/test_eetq.py b/tests/quantization/eetq_integration/test_eetq.py
index 2c01f8145cba0e..f14fa076e4bb76 100644
--- a/tests/quantization/eetq_integration/test_eetq.py
+++ b/tests/quantization/eetq_integration/test_eetq.py
@@ -119,7 +119,7 @@ def test_quantized_model_conversion(self):
self.assertEqual(nb_linears - 1, nb_eetq_linear)
- # Try with `linear_weights_not_to_quantize`
+ # Try with `modules_to_not_convert`
with init_empty_weights():
model = OPTForCausalLM(config)
quantization_config = EetqConfig(modules_to_not_convert=["fc1"])
@@ -128,7 +128,7 @@ def test_quantized_model_conversion(self):
for module in model.modules():
if isinstance(module, EetqLinear):
nb_eetq_linear += 1
-
+ # 25 corresponds to the lm_head along with 24 fc1 layers.
self.assertEqual(nb_linears - 25, nb_eetq_linear)
def test_quantized_model(self):
diff --git a/tests/trainer/test_trainer.py b/tests/trainer/test_trainer.py
index f7b4a8637bff85..d33be2789761da 100644
--- a/tests/trainer/test_trainer.py
+++ b/tests/trainer/test_trainer.py
@@ -750,11 +750,102 @@ def test_model_init(self):
self.check_trained_model(trainer.model, alternate_seed=True)
@slow
- def test_gradient_accumulation_loss_alignment(self):
+ def test_gradient_accumulation_loss_alignment_with_model_loss(self):
set_seed(42)
import datasets
- model_name = "distilgpt2"
+ model_name = "nickypro/tinyllama-110M"
+ dataset_name = "wikitext"
+ dataset_config = "wikitext-2-raw-v1"
+ dataset = datasets.load_dataset(dataset_name, dataset_config, split="train[:500]")
+ dataset = dataset.train_test_split(test_size=0.2)
+ tokenizer = AutoTokenizer.from_pretrained(model_name)
+
+ tokenizer.pad_token = tokenizer.eos_token
+
+ def tokenize_function(examples):
+ return tokenizer(examples["text"], max_length=128, padding="max_length", truncation=True)
+
+ tokenized_dataset = dataset.map(tokenize_function, batched=True, remove_columns=dataset["train"].column_names)
+
+ data_collator = DataCollatorForLanguageModeling(tokenizer=tokenizer, mlm=False)
+
+ model = AutoModelForCausalLM.from_pretrained(model_name)
+
+ base_loss_callback = StoreLossCallback()
+
+ args_kwargs = {
+ "report_to": "none",
+ "logging_steps": 1,
+ "max_steps": 20,
+ "learning_rate": 3e-4,
+ "disable_tqdm": True,
+ }
+
+ args = TrainingArguments(
+ "./generation",
+ **args_kwargs,
+ )
+ trainer = Trainer(
+ model,
+ args,
+ train_dataset=tokenized_dataset["train"],
+ callbacks=[base_loss_callback],
+ data_collator=data_collator,
+ )
+ assert trainer.model_accepts_loss_kwargs
+ trainer.train()
+
+ grad_accum_loss_callback = StoreLossCallback()
+ args = TrainingArguments(
+ "./generation",
+ **args_kwargs,
+ gradient_accumulation_steps=2,
+ per_device_train_batch_size=4,
+ )
+ set_seed(42)
+ model = AutoModelForCausalLM.from_pretrained(model_name)
+ trainer = Trainer(
+ model,
+ args,
+ train_dataset=tokenized_dataset["train"],
+ callbacks=[grad_accum_loss_callback],
+ data_collator=data_collator,
+ )
+ trainer.train()
+
+ set_seed(42)
+ model = AutoModelForCausalLM.from_pretrained(model_name)
+ broken_loss_callback = StoreLossCallback()
+ trainer = Trainer(
+ model,
+ args,
+ train_dataset=tokenized_dataset["train"],
+ callbacks=[broken_loss_callback],
+ data_collator=data_collator,
+ )
+ # disable model_accepts_loss_kwargs
+ trainer.model_accepts_loss_kwargs = False
+ trainer.train()
+
+ # Calculate the difference between the base loss and the grad_accum loss
+ diff_truth = [
+ abs(base - grad) for base, grad in zip(base_loss_callback.losses, grad_accum_loss_callback.losses)
+ ]
+ diff_broken = [abs(base - grad) for base, grad in zip(base_loss_callback.losses, broken_loss_callback.losses)]
+
+ # all diff truth should be quite close
+ self.assertLess(max(diff_truth), 0.01, f"Difference {max(diff_truth)} is not within 0.01")
+
+ # max diff broken should be very off
+ self.assertGreater(max(diff_broken), 3, f"Difference {max(diff_broken)} is not greater than 3")
+
+ @slow
+ def test_gradient_accumulation_loss_alignment_with_loss_func(self):
+ set_seed(42)
+ import datasets
+
+ model_name = "roneneldan/TinyStories-33M"
dataset_name = "wikitext"
dataset_config = "wikitext-2-raw-v1"
dataset = datasets.load_dataset(dataset_name, dataset_config, split="train[:500]")
@@ -836,15 +927,16 @@ def compute_loss(logits, labels, vocab_size, num_items_in_batch, disable_num_ite
trainer.train()
# Calculate the difference between the base loss and the grad_accum loss
- diff_truth = [base - grad for base, grad in zip(base_loss_callback.losses, grad_accum_loss_callback.losses)]
- diff_broken = [base - grad for base, grad in zip(base_loss_callback.losses, broken_loss_callback.losses)]
- # These should be quite close
- for diff in diff_truth:
- self.assertLess(abs(diff), 0.1, f"Difference {diff} is not within 0.1")
-
- # These should be very off
- for diff in diff_broken:
- self.assertGreater(abs(diff), 0.1, f"Difference {diff} is not greater than 0.1")
+ diff_truth = [
+ abs(base - grad) for base, grad in zip(base_loss_callback.losses, grad_accum_loss_callback.losses)
+ ]
+ diff_broken = [abs(base - grad) for base, grad in zip(base_loss_callback.losses, broken_loss_callback.losses)]
+
+ # all diff truth should be quite close
+ self.assertLess(max(diff_truth), 0.01, f"Difference {max(diff_truth)} is not within 0.01")
+
+ # max diff broken should be very off
+ self.assertGreater(max(diff_broken), 3, f"Difference {max(diff_broken)} is not greater than 3")
def test_gradient_accumulation(self):
# Training with half the batch size but accumulation steps as 2 should give the same training losses.