Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Aurevoir PyTorch 1 #35358

Merged
merged 3 commits into from
Dec 20, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
33 changes: 0 additions & 33 deletions .github/workflows/self-nightly-past-ci-caller.yml
Original file line number Diff line number Diff line change
Expand Up @@ -21,39 +21,6 @@ jobs:
echo "$(python3 -c 'print(int(${{ github.run_number }}) % 10)')"
echo "run_number=$(python3 -c 'print(int(${{ github.run_number }}) % 10)')" >> $GITHUB_OUTPUT
run_past_ci_pytorch_1-13:
name: PyTorch 1.13
needs: get_number
if: needs.get_number.outputs.run_number == 0 && (cancelled() != true) && ((github.event_name == 'schedule') || ((github.event_name == 'push') && startsWith(github.ref_name, 'run_past_ci')))
uses: ./.github/workflows/self-past-caller.yml
with:
framework: pytorch
version: "1.13"
sha: ${{ github.sha }}
secrets: inherit

run_past_ci_pytorch_1-12:
name: PyTorch 1.12
needs: get_number
if: needs.get_number.outputs.run_number == 1 && (cancelled() != true) && ((github.event_name == 'schedule') || ((github.event_name == 'push') && startsWith(github.ref_name, 'run_past_ci')))
uses: ./.github/workflows/self-past-caller.yml
with:
framework: pytorch
version: "1.12"
sha: ${{ github.sha }}
secrets: inherit

run_past_ci_pytorch_1-11:
name: PyTorch 1.11
needs: get_number
if: needs.get_number.outputs.run_number == 2 && (cancelled() != true) && ((github.event_name == 'schedule') || ((github.event_name == 'push') && startsWith(github.ref_name, 'run_past_ci')))
uses: ./.github/workflows/self-past-caller.yml
with:
framework: pytorch
version: "1.11"
sha: ${{ github.sha }}
secrets: inherit

run_past_ci_tensorflow_2-11:
name: TensorFlow 2.11
needs: get_number
Expand Down
2 changes: 1 addition & 1 deletion README.md
Original file line number Diff line number Diff line change
Expand Up @@ -249,7 +249,7 @@ The model itself is a regular [Pytorch `nn.Module`](https://pytorch.org/docs/sta

### With pip

This repository is tested on Python 3.9+, Flax 0.4.1+, PyTorch 1.11+, and TensorFlow 2.6+.
This repository is tested on Python 3.9+, Flax 0.4.1+, PyTorch 2.0+, and TensorFlow 2.6+.

You should install 🤗 Transformers in a [virtual environment](https://docs.python.org/3/library/venv.html). If you're unfamiliar with Python virtual environments, check out the [user guide](https://packaging.python.org/guides/installing-using-pip-and-virtual-environments/).

Expand Down
2 changes: 1 addition & 1 deletion i18n/README_ar.md
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,7 @@ limitations under the License.

### باستخدام pip

تم اختبار هذا المستودع على Python 3.9+، Flax 0.4.1+، PyTorch 1.11+، و TensorFlow 2.6+.
تم اختبار هذا المستودع على Python 3.9+، Flax 0.4.1+، PyTorch 2.0+، و TensorFlow 2.6+.

يجب تثبيت 🤗 Transformers في [بيئة افتراضية](https://docs.python.org/3/library/venv.html). إذا كنت غير معتاد على البيئات الافتراضية Python، فراجع [دليل المستخدم](https://packaging.python.org/guides/installing-using-pip-and-virtual-environments/).

Expand Down
2 changes: 1 addition & 1 deletion i18n/README_de.md
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,7 @@ Das Modell selbst ist ein reguläres [PyTorch `nn.Module`](https://pytorch.org/d

### Mit pip

Dieses Repository wurde mit Python 3.9+, Flax 0.4.1+, PyTorch 1.11+ und TensorFlow 2.6+ getestet.
Dieses Repository wurde mit Python 3.9+, Flax 0.4.1+, PyTorch 2.0+ und TensorFlow 2.6+ getestet.

Sie sollten 🤗 Transformers in einer [virtuellen Umgebung](https://docs.python.org/3/library/venv.html) installieren. Wenn Sie mit virtuellen Python-Umgebungen nicht vertraut sind, schauen Sie sich den [Benutzerleitfaden](https://packaging.python.org/guides/installing-using-pip-and-virtual-environments/) an.

Expand Down
2 changes: 1 addition & 1 deletion i18n/README_es.md
Original file line number Diff line number Diff line change
Expand Up @@ -222,7 +222,7 @@ El modelo en si es un [Pytorch `nn.Module`](https://pytorch.org/docs/stable/nn.h

### Con pip

Este repositorio está probado en Python 3.9+, Flax 0.4.1+, PyTorch 1.11+ y TensorFlow 2.6+.
Este repositorio está probado en Python 3.9+, Flax 0.4.1+, PyTorch 2.0+ y TensorFlow 2.6+.

Deberías instalar 🤗 Transformers en un [entorno virtual](https://docs.python.org/3/library/venv.html). Si no estas familiarizado con los entornos virtuales de Python, consulta la [guía de usuario](https://packaging.python.org/guides/installing-using-pip-and-virtual-environments/).

Expand Down
2 changes: 1 addition & 1 deletion i18n/README_fr.md
Original file line number Diff line number Diff line change
Expand Up @@ -243,7 +243,7 @@ Le modèle lui-même est un module [`nn.Module` PyTorch](https://pytorch.org/doc

### Avec pip

Ce référentiel est testé sur Python 3.9+, Flax 0.4.1+, PyTorch 1.11+ et TensorFlow 2.6+.
Ce référentiel est testé sur Python 3.9+, Flax 0.4.1+, PyTorch 2.0+ et TensorFlow 2.6+.

Vous devriez installer 🤗 Transformers dans un [environnement virtuel](https://docs.python.org/3/library/venv.html). Si vous n'êtes pas familier avec les environnements virtuels Python, consultez le [guide utilisateur](https://packaging.python.org/guides/installing-using-pip-and-virtual-environments/).

Expand Down
2 changes: 1 addition & 1 deletion i18n/README_hd.md
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ checkpoint: जाँच बिंदु

### पिप का उपयोग करना

इस रिपॉजिटरी का परीक्षण Python 3.9+, Flax 0.4.1+, PyTorch 1.11+ और TensorFlow 2.6+ के तहत किया गया है।
इस रिपॉजिटरी का परीक्षण Python 3.9+, Flax 0.4.1+, PyTorch 2.0+ और TensorFlow 2.6+ के तहत किया गया है।

आप [वर्चुअल एनवायरनमेंट](https://docs.python.org/3/library/venv.html) में 🤗 ट्रांसफॉर्मर इंस्टॉल कर सकते हैं। यदि आप अभी तक पायथन के वर्चुअल एनवायरनमेंट से परिचित नहीं हैं, तो कृपया इसे [उपयोगकर्ता निर्देश](https://packaging.python.org/guides/installing-using-pip-and-virtual-environments/) पढ़ें।

Expand Down
2 changes: 1 addition & 1 deletion i18n/README_ja.md
Original file line number Diff line number Diff line change
Expand Up @@ -256,7 +256,7 @@ Hugging Faceチームによって作られた **[トランスフォーマーを

### pipにて

このリポジトリは、Python 3.9+, Flax 0.4.1+, PyTorch 1.11+, TensorFlow 2.6+ でテストされています。
このリポジトリは、Python 3.9+, Flax 0.4.1+, PyTorch 2.0+, TensorFlow 2.6+ でテストされています。

🤗Transformersは[仮想環境](https://docs.python.org/3/library/venv.html)にインストールする必要があります。Pythonの仮想環境に慣れていない場合は、[ユーザーガイド](https://packaging.python.org/guides/installing-using-pip-and-virtual-environments/)を確認してください。

Expand Down
2 changes: 1 addition & 1 deletion i18n/README_ko.md
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,7 @@ Transformers에 달린 100,000개의 별을 축하하기 위해, 우리는 커

### pip로 설치하기

이 저장소는 Python 3.9+, Flax 0.4.1+, PyTorch 1.11+, TensorFlow 2.6+에서 테스트 되었습니다.
이 저장소는 Python 3.9+, Flax 0.4.1+, PyTorch 2.0+, TensorFlow 2.6+에서 테스트 되었습니다.

[가상 환경](https://docs.python.org/3/library/venv.html)에 🤗 Transformers를 설치하세요. Python 가상 환경에 익숙하지 않다면, [사용자 가이드](https://packaging.python.org/guides/installing-using-pip-and-virtual-environments/)를 확인하세요.

Expand Down
2 changes: 1 addition & 1 deletion i18n/README_pt-br.md
Original file line number Diff line number Diff line change
Expand Up @@ -253,7 +253,7 @@ O modelo em si é um [Pytorch `nn.Module`](https://pytorch.org/docs/stable/nn.ht

### Com pip

Este repositório é testado no Python 3.9+, Flax 0.4.1+, PyTorch 1.11+ e TensorFlow 2.6+.
Este repositório é testado no Python 3.9+, Flax 0.4.1+, PyTorch 2.0+ e TensorFlow 2.6+.

Você deve instalar o 🤗 Transformers em um [ambiente virtual](https://docs.python.org/3/library/venv.html). Se você não está familiarizado com ambientes virtuais em Python, confira o [guia do usuário](https://packaging.python.org/guides/installing-using-pip-and-virtual-environments/).

Expand Down
2 changes: 1 addition & 1 deletion i18n/README_ru.md
Original file line number Diff line number Diff line change
Expand Up @@ -244,7 +244,7 @@ Hugging Face Hub. Мы хотим, чтобы Transformers позволил ра

### С помощью pip

Данный репозиторий протестирован на Python 3.9+, Flax 0.4.1+, PyTorch 1.11+ и TensorFlow 2.6+.
Данный репозиторий протестирован на Python 3.9+, Flax 0.4.1+, PyTorch 2.0+ и TensorFlow 2.6+.

Устанавливать 🤗 Transformers следует в [виртуальной среде](https://docs.python.org/3/library/venv.html). Если вы не знакомы с виртуальными средами Python, ознакомьтесь с [руководством пользователя](https://packaging.python.org/guides/installing-using-pip-and-virtual-environments/).

Expand Down
2 changes: 1 addition & 1 deletion i18n/README_te.md
Original file line number Diff line number Diff line change
Expand Up @@ -246,7 +246,7 @@ limitations under the License.

### పిప్ తో

ఈ రిపోజిటరీ పైథాన్ 3.9+, ఫ్లాక్స్ 0.4.1+, PyTorch 1.11+ మరియు TensorFlow 2.6+లో పరీక్షించబడింది.
ఈ రిపోజిటరీ పైథాన్ 3.9+, ఫ్లాక్స్ 0.4.1+, PyTorch 2.0+ మరియు TensorFlow 2.6+లో పరీక్షించబడింది.

మీరు [వర్చువల్ వాతావరణం](https://docs.python.org/3/library/venv.html)లో 🤗 ట్రాన్స్‌ఫార్మర్‌లను ఇన్‌స్టాల్ చేయాలి. మీకు పైథాన్ వర్చువల్ పరిసరాల గురించి తెలియకుంటే, [యూజర్ గైడ్](https://packaging.python.org/guides/installing-using-pip-and-virtual-environments/) చూడండి.

Expand Down
2 changes: 1 addition & 1 deletion i18n/README_ur.md
Original file line number Diff line number Diff line change
Expand Up @@ -259,7 +259,7 @@ limitations under the License.

#### ‏ pip کے ساتھ

یہ ریپوزٹری Python 3.9+، Flax 0.4.1+، PyTorch 1.11+، اور TensorFlow 2.6+ پر ٹیسٹ کی گئی ہے۔
یہ ریپوزٹری Python 3.9+، Flax 0.4.1+، PyTorch 2.0+، اور TensorFlow 2.6+ پر ٹیسٹ کی گئی ہے۔

آپ کو 🤗 Transformers کو ایک [ورچوئل ماحول](https://docs.python.org/3/library/venv.html) میں انسٹال کرنا چاہیے۔ اگر آپ Python ورچوئل ماحول سے واقف نہیں ہیں، تو [یوزر گائیڈ](https://packaging.python.org/guides/installing-using-pip-and-virtual-environments/) دیکھیں۔

Expand Down
2 changes: 1 addition & 1 deletion i18n/README_vi.md
Original file line number Diff line number Diff line change
Expand Up @@ -245,7 +245,7 @@ Chính mô hình là một [Pytorch `nn.Module`](https://pytorch.org/docs/stable

### Sử dụng pip

Thư viện này được kiểm tra trên Python 3.9+, Flax 0.4.1+, PyTorch 1.11+ và TensorFlow 2.6+.
Thư viện này được kiểm tra trên Python 3.9+, Flax 0.4.1+, PyTorch 2.0+ và TensorFlow 2.6+.

Bạn nên cài đặt 🤗 Transformers trong một [môi trường ảo Python](https://docs.python.org/3/library/venv.html). Nếu bạn chưa quen với môi trường ảo Python, hãy xem [hướng dẫn sử dụng](https://packaging.python.org/guides/installing-using-pip-and-virtual-environments/).

Expand Down
2 changes: 1 addition & 1 deletion i18n/README_zh-hans.md
Original file line number Diff line number Diff line change
Expand Up @@ -198,7 +198,7 @@ checkpoint: 检查点

### 使用 pip

这个仓库已在 Python 3.9+、Flax 0.4.1+、PyTorch 1.11+ 和 TensorFlow 2.6+ 下经过测试。
这个仓库已在 Python 3.9+、Flax 0.4.1+、PyTorch 2.0+ 和 TensorFlow 2.6+ 下经过测试。

你可以在[虚拟环境](https://docs.python.org/3/library/venv.html)中安装 🤗 Transformers。如果你还不熟悉 Python 的虚拟环境,请阅此[用户说明](https://packaging.python.org/guides/installing-using-pip-and-virtual-environments/)。

Expand Down
2 changes: 1 addition & 1 deletion i18n/README_zh-hant.md
Original file line number Diff line number Diff line change
Expand Up @@ -210,7 +210,7 @@ Tokenizer 為所有的預訓練模型提供了預處理,並可以直接轉換

### 使用 pip

這個 Repository 已在 Python 3.9+、Flax 0.4.1+、PyTorch 1.11+ 和 TensorFlow 2.6+ 下經過測試。
這個 Repository 已在 Python 3.9+、Flax 0.4.1+、PyTorch 2.0+ 和 TensorFlow 2.6+ 下經過測試。

你可以在[虛擬環境](https://docs.python.org/3/library/venv.html)中安裝 🤗 Transformers。如果你還不熟悉 Python 的虛擬環境,請閱此[使用者指引](https://packaging.python.org/guides/installing-using-pip-and-virtual-environments/)。

Expand Down
3 changes: 1 addition & 2 deletions src/transformers/convert_pytorch_checkpoint_to_tf2.py
Original file line number Diff line number Diff line change
Expand Up @@ -106,7 +106,6 @@
XLMWithLMHeadModel,
XLNetLMHeadModel,
)
from .pytorch_utils import is_torch_greater_or_equal_than_1_13


logging.set_verbosity_info()
Expand Down Expand Up @@ -279,7 +278,7 @@ def convert_pt_checkpoint_to_tf(
if compare_with_pt_model:
tfo = tf_model(tf_model.dummy_inputs, training=False) # build the network

weights_only_kwarg = {"weights_only": True} if is_torch_greater_or_equal_than_1_13 else {}
weights_only_kwarg = {"weights_only": True}
state_dict = torch.load(
pytorch_checkpoint_path,
map_location="cpu",
Expand Down
8 changes: 2 additions & 6 deletions src/transformers/modeling_flax_pytorch_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -63,8 +63,6 @@ def load_pytorch_checkpoint_in_flax_state_dict(
else:
try:
import torch # noqa: F401

from .pytorch_utils import is_torch_greater_or_equal_than_1_13 # noqa: F401
except (ImportError, ModuleNotFoundError):
logger.error(
"Loading a PyTorch model in Flax, requires both PyTorch and Flax to be installed. Please see"
Expand All @@ -73,7 +71,7 @@ def load_pytorch_checkpoint_in_flax_state_dict(
)
raise

weights_only_kwarg = {"weights_only": True} if is_torch_greater_or_equal_than_1_13 else {}
weights_only_kwarg = {"weights_only": True}
pt_state_dict = torch.load(pt_path, map_location="cpu", **weights_only_kwarg)
logger.info(f"PyTorch checkpoint contains {sum(t.numel() for t in pt_state_dict.values()):,} parameters.")

Expand Down Expand Up @@ -246,13 +244,11 @@ def convert_pytorch_state_dict_to_flax(pt_state_dict, flax_model):
def convert_pytorch_sharded_state_dict_to_flax(shard_filenames, flax_model):
import torch

from .pytorch_utils import is_torch_greater_or_equal_than_1_13

# Load the index
flax_state_dict = {}
for shard_file in shard_filenames:
# load using msgpack utils
weights_only_kwarg = {"weights_only": True} if is_torch_greater_or_equal_than_1_13 else {}
weights_only_kwarg = {"weights_only": True}
pt_state_dict = torch.load(shard_file, **weights_only_kwarg)
weight_dtypes = {k: v.dtype for k, v in pt_state_dict.items()}
pt_state_dict = {
Expand Down
4 changes: 1 addition & 3 deletions src/transformers/modeling_tf_pytorch_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,8 +180,6 @@ def load_pytorch_checkpoint_in_tf2_model(
import tensorflow as tf # noqa: F401
import torch # noqa: F401
from safetensors.torch import load_file as safe_load_file # noqa: F401

from .pytorch_utils import is_torch_greater_or_equal_than_1_13 # noqa: F401
except ImportError:
logger.error(
"Loading a PyTorch model in TensorFlow, requires both PyTorch and TensorFlow to be installed. Please see "
Expand All @@ -201,7 +199,7 @@ def load_pytorch_checkpoint_in_tf2_model(
if pt_path.endswith(".safetensors"):
state_dict = safe_load_file(pt_path)
else:
weights_only_kwarg = {"weights_only": True} if is_torch_greater_or_equal_than_1_13 else {}
weights_only_kwarg = {"weights_only": True}
state_dict = torch.load(pt_path, map_location="cpu", **weights_only_kwarg)

pt_state_dict.update(state_dict)
Expand Down
5 changes: 2 additions & 3 deletions src/transformers/modeling_utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -54,7 +54,6 @@
apply_chunking_to_forward,
find_pruneable_heads_and_indices,
id_tensor_storage,
is_torch_greater_or_equal_than_1_13,
prune_conv1d_layer,
prune_layer,
prune_linear_layer,
Expand Down Expand Up @@ -476,7 +475,7 @@ def load_sharded_checkpoint(model, folder, strict=True, prefer_safe=True):
error_message += f"\nMissing key(s): {str_unexpected_keys}."
raise RuntimeError(error_message)

weights_only_kwarg = {"weights_only": True} if is_torch_greater_or_equal_than_1_13 else {}
weights_only_kwarg = {"weights_only": True}
loader = safe_load_file if load_safe else partial(torch.load, map_location="cpu", **weights_only_kwarg)

for shard_file in shard_files:
Expand Down Expand Up @@ -532,7 +531,7 @@ def load_state_dict(
and is_zipfile(checkpoint_file)
):
extra_args = {"mmap": True}
weights_only_kwarg = {"weights_only": weights_only} if is_torch_greater_or_equal_than_1_13 else {}
weights_only_kwarg = {"weights_only": weights_only}
return torch.load(
checkpoint_file,
map_location=map_location,
Expand Down
9 changes: 0 additions & 9 deletions src/transformers/models/falcon/modeling_falcon.py
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,6 @@
)
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import is_torch_greater_or_equal_than_2_0
from ...utils import (
add_code_sample_docstrings,
add_start_docstrings,
Expand Down Expand Up @@ -815,14 +814,6 @@ def _init_weights(self, module: nn.Module):
# Adapted from transformers.modeling_utils.PreTrainedModel._check_and_enable_sdpa
@classmethod
def _check_and_enable_sdpa(cls, config, hard_check_only: bool = False) -> "PretrainedConfig":
# NOTE: Falcon supported SDPA from PyTorch 2.0. We keep it like that for backward compatibility (automatically use SDPA for torch>=2.0).
if hard_check_only:
if not is_torch_greater_or_equal_than_2_0:
raise ImportError("PyTorch SDPA requirements in Transformers are not met. Please install torch>=2.0.")

if not is_torch_greater_or_equal_than_2_0:
return config

_is_bettertransformer = getattr(cls, "use_bettertransformer", False)
if _is_bettertransformer:
return config
Expand Down
4 changes: 0 additions & 4 deletions src/transformers/models/gpt_neo/modeling_gpt_neo.py
Original file line number Diff line number Diff line change
Expand Up @@ -36,7 +36,6 @@
TokenClassifierOutput,
)
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import is_torch_greater_or_equal_than_1_13
from ...utils import (
add_code_sample_docstrings,
add_start_docstrings,
Expand All @@ -56,9 +55,6 @@
# This makes `_prepare_4d_causal_attention_mask` a leaf function in the FX graph.
# It means that the function will not be traced through and simply appear as a node in the graph.
if is_torch_fx_available():
if not is_torch_greater_or_equal_than_1_13:
import torch.fx

_prepare_4d_causal_attention_mask = torch.fx.wrap(_prepare_4d_causal_attention_mask)


Expand Down
4 changes: 0 additions & 4 deletions src/transformers/models/phimoe/modeling_phimoe.py
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,6 @@
)
from ...modeling_rope_utils import ROPE_INIT_FUNCTIONS
from ...modeling_utils import PreTrainedModel
from ...pytorch_utils import is_torch_greater_or_equal_than_1_13
from ...utils import (
add_start_docstrings,
add_start_docstrings_to_model_forward,
Expand All @@ -51,9 +50,6 @@
# This makes `_prepare_4d_causal_attention_mask` a leaf function in the FX graph.
# It means that the function will not be traced through and simply appear as a node in the graph.
if is_torch_fx_available():
if not is_torch_greater_or_equal_than_1_13:
import torch.fx

_prepare_4d_causal_attention_mask = torch.fx.wrap(_prepare_4d_causal_attention_mask)


Expand Down
3 changes: 1 addition & 2 deletions src/transformers/models/superpoint/modeling_superpoint.py
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,6 @@
)
from transformers.models.superpoint.configuration_superpoint import SuperPointConfig

from ...pytorch_utils import is_torch_greater_or_equal_than_1_13
from ...utils import (
ModelOutput,
add_start_docstrings,
Expand Down Expand Up @@ -314,7 +313,7 @@ def _sample_descriptors(keypoints, descriptors, scale: int = 8) -> torch.Tensor:
divisor = divisor.to(keypoints)
keypoints /= divisor
keypoints = keypoints * 2 - 1 # normalize to (-1, 1)
kwargs = {"align_corners": True} if is_torch_greater_or_equal_than_1_13 else {}
kwargs = {"align_corners": True}
# [batch_size, num_channels, num_keypoints, 2] -> [batch_size, num_channels, num_keypoints, 2]
keypoints = keypoints.view(batch_size, 1, -1, 2)
descriptors = nn.functional.grid_sample(descriptors, keypoints, mode="bilinear", **kwargs)
Expand Down
7 changes: 0 additions & 7 deletions src/transformers/models/tapas/modeling_tapas.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,7 +31,6 @@
from ...pytorch_utils import (
apply_chunking_to_forward,
find_pruneable_heads_and_indices,
is_torch_greater_or_equal_than_1_12,
prune_linear_layer,
)
from ...utils import (
Expand All @@ -46,12 +45,6 @@

logger = logging.get_logger(__name__)

if not is_torch_greater_or_equal_than_1_12:
logger.warning(
f"You are using torch=={torch.__version__}, but torch>=1.12.0 is required to use "
"TapasModel. Please upgrade torch."
)

_CONFIG_FOR_DOC = "TapasConfig"
_CHECKPOINT_FOR_DOC = "google/tapas-base"

Expand Down
Loading
Loading