Skip to content

Commit

Permalink
Image Feature Extraction pipeline (#28216)
Browse files Browse the repository at this point in the history
* Draft pipeline

* Fixup

* Fix docstrings

* Update doctest

* Update pipeline_model_mapping

* Update docstring

* Update tests

* Update src/transformers/pipelines/image_feature_extraction.py

Co-authored-by: Omar Sanseviero <[email protected]>

* Fix docstrings - review comments

* Remove pipeline mapping for composite vision models

* Add to pipeline tests

* Remove for flava (multimodal)

* safe pil import

* Add requirements for pipeline run

* Account for super slow efficientnet

* Review comments

* Fix tests

* Swap order of kwargs

* Use build_pipeline_init_args

* Add back FE pipeline for Vilt

* Include image_processor_kwargs in docstring

* Mark test as flaky

* Update TODO

* Update tests/pipelines/test_pipelines_image_feature_extraction.py

Co-authored-by: Arthur <[email protected]>

* Add license header

---------

Co-authored-by: Omar Sanseviero <[email protected]>
Co-authored-by: Arthur <[email protected]>
  • Loading branch information
3 people authored Feb 5, 2024
1 parent 7addc93 commit ba3264b
Show file tree
Hide file tree
Showing 60 changed files with 387 additions and 53 deletions.
6 changes: 6 additions & 0 deletions docs/source/en/main_classes/pipelines.md
Original file line number Diff line number Diff line change
Expand Up @@ -469,6 +469,12 @@ Pipelines available for multimodal tasks include the following.
- __call__
- all

### ImageFeatureExtractionPipeline

[[autodoc]] ImageFeatureExtractionPipeline
- __call__
- all

### ImageToTextPipeline

[[autodoc]] ImageToTextPipeline
Expand Down
8 changes: 7 additions & 1 deletion docs/source/ja/main_classes/pipelines.md
Original file line number Diff line number Diff line change
Expand Up @@ -25,7 +25,7 @@ Recognition、Masked Language Modeling、Sentiment Analysis、Feature Extraction
パイプラインの抽象化には2つのカテゴリーがある:

- [`pipeline`] は、他のすべてのパイプラインをカプセル化する最も強力なオブジェクトです。
- タスク固有のパイプラインは、[オーディオ](#audio)[コンピューター ビジョン](#computer-vision)[自然言語処理](#natural-language-processing)、および [マルチモーダル](#multimodal) タスクで使用できます。
- タスク固有のパイプラインは、[オーディオ](#audio)[コンピューター ビジョン](#computer-vision)[自然言語処理](#natural-language-processing)、および [マルチモーダル](#multimodal) タスクで使用できます。

## The pipeline abstraction

Expand Down Expand Up @@ -477,6 +477,12 @@ my_pipeline = pipeline(model="xxxx", pipeline_class=MyPipeline)
- __call__
- all

### ImageFeatureExtractionPipeline

[[autodoc]] ImageFeatureExtractionPipeline
- __call__
- all

### ImageToTextPipeline

[[autodoc]] ImageToTextPipeline
Expand Down
8 changes: 7 additions & 1 deletion docs/source/zh/main_classes/pipelines.md
Original file line number Diff line number Diff line change
Expand Up @@ -435,7 +435,7 @@ See [`TokenClassificationPipeline`] for all details.
- __call__
- all

## 多模态
## 多模态

可用于多模态任务的pipeline包括以下几种。

Expand All @@ -451,6 +451,12 @@ See [`TokenClassificationPipeline`] for all details.
- __call__
- all

### ImageFeatureExtractionPipeline

[[autodoc]] ImageFeatureExtractionPipeline
- __call__
- all

### ImageToTextPipeline

[[autodoc]] ImageToTextPipeline
Expand Down
2 changes: 2 additions & 0 deletions src/transformers/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -973,6 +973,7 @@
"FeatureExtractionPipeline",
"FillMaskPipeline",
"ImageClassificationPipeline",
"ImageFeatureExtractionPipeline",
"ImageSegmentationPipeline",
"ImageToImagePipeline",
"ImageToTextPipeline",
Expand Down Expand Up @@ -5709,6 +5710,7 @@
FeatureExtractionPipeline,
FillMaskPipeline,
ImageClassificationPipeline,
ImageFeatureExtractionPipeline,
ImageSegmentationPipeline,
ImageToImagePipeline,
ImageToTextPipeline,
Expand Down
15 changes: 15 additions & 0 deletions src/transformers/pipelines/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -66,6 +66,7 @@
from .feature_extraction import FeatureExtractionPipeline
from .fill_mask import FillMaskPipeline
from .image_classification import ImageClassificationPipeline
from .image_feature_extraction import ImageFeatureExtractionPipeline
from .image_segmentation import ImageSegmentationPipeline
from .image_to_image import ImageToImagePipeline
from .image_to_text import ImageToTextPipeline
Expand Down Expand Up @@ -362,6 +363,18 @@
},
"type": "image",
},
"image-feature-extraction": {
"impl": ImageFeatureExtractionPipeline,
"tf": (TFAutoModel,) if is_tf_available() else (),
"pt": (AutoModel,) if is_torch_available() else (),
"default": {
"model": {
"pt": ("google/vit-base-patch16-224", "29e7a1e183"),
"tf": ("google/vit-base-patch16-224", "29e7a1e183"),
}
},
"type": "image",
},
"image-segmentation": {
"impl": ImageSegmentationPipeline,
"tf": (),
Expand Down Expand Up @@ -500,6 +513,7 @@ def check_task(task: str) -> Tuple[str, Dict, Any]:
- `"feature-extraction"`
- `"fill-mask"`
- `"image-classification"`
- `"image-feature-extraction"`
- `"image-segmentation"`
- `"image-to-text"`
- `"image-to-image"`
Expand Down Expand Up @@ -586,6 +600,7 @@ def pipeline(
- `"feature-extraction"`: will return a [`FeatureExtractionPipeline`].
- `"fill-mask"`: will return a [`FillMaskPipeline`]:.
- `"image-classification"`: will return a [`ImageClassificationPipeline`].
- `"image-feature-extraction"`: will return an [`ImageFeatureExtractionPipeline`].
- `"image-segmentation"`: will return a [`ImageSegmentationPipeline`].
- `"image-to-image"`: will return a [`ImageToImagePipeline`].
- `"image-to-text"`: will return a [`ImageToTextPipeline`].
Expand Down
2 changes: 1 addition & 1 deletion src/transformers/pipelines/feature_extraction.py
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,7 @@
)
class FeatureExtractionPipeline(Pipeline):
"""
Feature extraction pipeline using no model head. This pipeline extracts the hidden states from the base
Feature extraction pipeline uses no model head. This pipeline extracts the hidden states from the base
transformer, which can be used as features in downstream tasks.
Example:
Expand Down
92 changes: 92 additions & 0 deletions src/transformers/pipelines/image_feature_extraction.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,92 @@
from typing import Dict

from ..utils import add_end_docstrings, is_vision_available
from .base import GenericTensor, Pipeline, build_pipeline_init_args


if is_vision_available():
from ..image_utils import load_image


@add_end_docstrings(
build_pipeline_init_args(has_image_processor=True),
"""
image_processor_kwargs (`dict`, *optional*):
Additional dictionary of keyword arguments passed along to the image processor e.g.
{"size": {"height": 100, "width": 100}}
""",
)
class ImageFeatureExtractionPipeline(Pipeline):
"""
Image feature extraction pipeline uses no model head. This pipeline extracts the hidden states from the base
transformer, which can be used as features in downstream tasks.
Example:
```python
>>> from transformers import pipeline
>>> extractor = pipeline(model="google/vit-base-patch16-224", task="image-feature-extraction")
>>> result = extractor("https://huggingface.co/datasets/Narsil/image_dummy/raw/main/parrots.png", return_tensors=True)
>>> result.shape # This is a tensor of shape [1, sequence_lenth, hidden_dimension] representing the input image.
torch.Size([1, 197, 768])
```
Learn more about the basics of using a pipeline in the [pipeline tutorial](../pipeline_tutorial)
This image feature extraction pipeline can currently be loaded from [`pipeline`] using the task identifier:
`"image-feature-extraction"`.
All vision models may be used for this pipeline. See a list of all models, including community-contributed models on
[huggingface.co/models](https://huggingface.co/models).
"""

def _sanitize_parameters(self, image_processor_kwargs=None, return_tensors=None, **kwargs):
preprocess_params = {} if image_processor_kwargs is None else image_processor_kwargs
postprocess_params = {"return_tensors": return_tensors} if return_tensors is not None else {}

if "timeout" in kwargs:
preprocess_params["timeout"] = kwargs["timeout"]

return preprocess_params, {}, postprocess_params

def preprocess(self, image, timeout=None, **image_processor_kwargs) -> Dict[str, GenericTensor]:
image = load_image(image, timeout=timeout)
model_inputs = self.image_processor(image, return_tensors=self.framework, **image_processor_kwargs)
return model_inputs

def _forward(self, model_inputs):
model_outputs = self.model(**model_inputs)
return model_outputs

def postprocess(self, model_outputs, return_tensors=False):
# [0] is the first available tensor, logits or last_hidden_state.
if return_tensors:
return model_outputs[0]
if self.framework == "pt":
return model_outputs[0].tolist()
elif self.framework == "tf":
return model_outputs[0].numpy().tolist()

def __call__(self, *args, **kwargs):
"""
Extract the features of the input(s).
Args:
images (`str`, `List[str]`, `PIL.Image` or `List[PIL.Image]`):
The pipeline handles three types of images:
- A string containing a http link pointing to an image
- A string containing a local path to an image
- An image loaded in PIL directly
The pipeline accepts either a single image or a batch of images, which must then be passed as a string.
Images in a batch must all be in the same format: all as http links, all as local paths, or all as PIL
images.
timeout (`float`, *optional*, defaults to None):
The maximum time in seconds to wait for fetching images from the web. If None, no timeout is used and
the call may block forever.
Return:
A nested list of `float`: The features computed by the model.
"""
return super().__call__(*args, **kwargs)
15 changes: 15 additions & 0 deletions src/transformers/pipelines/image_to_text.py
Original file line number Diff line number Diff line change
@@ -1,3 +1,18 @@
# coding=utf-8
# Copyright 2024 The HuggingFace Inc. team. All rights reserved.
#
# Licensed under the Apache License, Version 2.0 (the "License");
# you may not use this file except in compliance with the License.
# You may obtain a copy of the License at
#
# http://www.apache.org/licenses/LICENSE-2.0
#
# Unless required by applicable law or agreed to in writing, software
# distributed under the License is distributed on an "AS IS" BASIS,
# WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
# See the License for the specific language governing permissions and
# limitations under the License.

from typing import List, Union

from ..utils import (
Expand Down
2 changes: 1 addition & 1 deletion tests/models/beit/test_modeling_beit.py
Original file line number Diff line number Diff line change
Expand Up @@ -242,7 +242,7 @@ class BeitModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
)
pipeline_model_mapping = (
{
"feature-extraction": BeitModel,
"image-feature-extraction": BeitModel,
"image-classification": BeitForImageClassification,
"image-segmentation": BeitForSemanticSegmentation,
}
Expand Down
2 changes: 1 addition & 1 deletion tests/models/bit/test_modeling_bit.py
Original file line number Diff line number Diff line change
Expand Up @@ -162,7 +162,7 @@ class BitModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):

all_model_classes = (BitModel, BitForImageClassification, BitBackbone) if is_torch_available() else ()
pipeline_model_mapping = (
{"feature-extraction": BitModel, "image-classification": BitForImageClassification}
{"image-feature-extraction": BitModel, "image-classification": BitForImageClassification}
if is_torch_available()
else {}
)
Expand Down
5 changes: 4 additions & 1 deletion tests/models/blip/test_modeling_blip.py
Original file line number Diff line number Diff line change
Expand Up @@ -429,7 +429,10 @@ def prepare_config_and_inputs_for_common(self):
class BlipModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
all_model_classes = (BlipModel,) if is_torch_available() else ()
pipeline_model_mapping = (
{"feature-extraction": BlipModel, "image-to-text": BlipForConditionalGeneration}
{
"feature-extraction": BlipModel,
"image-to-text": BlipForConditionalGeneration,
}
if is_torch_available()
else {}
)
Expand Down
4 changes: 3 additions & 1 deletion tests/models/clip/test_modeling_clip.py
Original file line number Diff line number Diff line change
Expand Up @@ -477,7 +477,9 @@ def prepare_config_and_inputs_for_common(self):
@require_torch
class CLIPModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
all_model_classes = (CLIPModel,) if is_torch_available() else ()
pipeline_model_mapping = {"feature-extraction": CLIPModel} if is_torch_available() else {}
pipeline_model_mapping = (
{"feature-extraction": CLIPModel, "image-feature-extraction": CLIPVisionModel} if is_torch_available() else {}
)
fx_compatible = True
test_head_masking = False
test_pruning = False
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -185,7 +185,7 @@ class ConditionalDetrModelTest(ModelTesterMixin, GenerationTesterMixin, Pipeline
else ()
)
pipeline_model_mapping = (
{"feature-extraction": ConditionalDetrModel, "object-detection": ConditionalDetrForObjectDetection}
{"image-feature-extraction": ConditionalDetrModel, "object-detection": ConditionalDetrForObjectDetection}
if is_torch_available()
else {}
)
Expand Down
2 changes: 1 addition & 1 deletion tests/models/convnext/test_modeling_convnext.py
Original file line number Diff line number Diff line change
Expand Up @@ -172,7 +172,7 @@ class ConvNextModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase
else ()
)
pipeline_model_mapping = (
{"feature-extraction": ConvNextModel, "image-classification": ConvNextForImageClassification}
{"image-feature-extraction": ConvNextModel, "image-classification": ConvNextForImageClassification}
if is_torch_available()
else {}
)
Expand Down
2 changes: 1 addition & 1 deletion tests/models/convnextv2/test_modeling_convnextv2.py
Original file line number Diff line number Diff line change
Expand Up @@ -180,7 +180,7 @@ class ConvNextV2ModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCa
else ()
)
pipeline_model_mapping = (
{"feature-extraction": ConvNextV2Model, "image-classification": ConvNextV2ForImageClassification}
{"image-feature-extraction": ConvNextV2Model, "image-classification": ConvNextV2ForImageClassification}
if is_torch_available()
else {}
)
Expand Down
2 changes: 1 addition & 1 deletion tests/models/cvt/test_modeling_cvt.py
Original file line number Diff line number Diff line change
Expand Up @@ -151,7 +151,7 @@ class CvtModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):

all_model_classes = (CvtModel, CvtForImageClassification) if is_torch_available() else ()
pipeline_model_mapping = (
{"feature-extraction": CvtModel, "image-classification": CvtForImageClassification}
{"image-feature-extraction": CvtModel, "image-classification": CvtForImageClassification}
if is_torch_available()
else {}
)
Expand Down
2 changes: 1 addition & 1 deletion tests/models/data2vec/test_modeling_data2vec_vision.py
Original file line number Diff line number Diff line change
Expand Up @@ -178,7 +178,7 @@ class Data2VecVisionModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.Te
)
pipeline_model_mapping = (
{
"feature-extraction": Data2VecVisionModel,
"image-feature-extraction": Data2VecVisionModel,
"image-classification": Data2VecVisionForImageClassification,
"image-segmentation": Data2VecVisionForSemanticSegmentation,
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -191,7 +191,7 @@ def create_and_check_deformable_detr_object_detection_head_model(self, config, p
class DeformableDetrModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
all_model_classes = (DeformableDetrModel, DeformableDetrForObjectDetection) if is_torch_available() else ()
pipeline_model_mapping = (
{"feature-extraction": DeformableDetrModel, "object-detection": DeformableDetrForObjectDetection}
{"image-feature-extraction": DeformableDetrModel, "object-detection": DeformableDetrForObjectDetection}
if is_torch_available()
else {}
)
Expand Down
2 changes: 1 addition & 1 deletion tests/models/deit/test_modeling_deit.py
Original file line number Diff line number Diff line change
Expand Up @@ -206,7 +206,7 @@ class DeiTModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
)
pipeline_model_mapping = (
{
"feature-extraction": DeiTModel,
"image-feature-extraction": DeiTModel,
"image-classification": (DeiTForImageClassification, DeiTForImageClassificationWithTeacher),
}
if is_torch_available()
Expand Down
2 changes: 1 addition & 1 deletion tests/models/deta/test_modeling_deta.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,7 @@ def create_and_check_deta_object_detection_head_model(self, config, pixel_values
class DetaModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin, unittest.TestCase):
all_model_classes = (DetaModel, DetaForObjectDetection) if is_torchvision_available() else ()
pipeline_model_mapping = (
{"feature-extraction": DetaModel, "object-detection": DetaForObjectDetection}
{"image-feature-extraction": DetaModel, "object-detection": DetaForObjectDetection}
if is_torchvision_available()
else {}
)
Expand Down
2 changes: 1 addition & 1 deletion tests/models/detr/test_modeling_detr.py
Original file line number Diff line number Diff line change
Expand Up @@ -182,7 +182,7 @@ class DetrModelTest(ModelTesterMixin, GenerationTesterMixin, PipelineTesterMixin
)
pipeline_model_mapping = (
{
"feature-extraction": DetrModel,
"image-feature-extraction": DetrModel,
"image-segmentation": DetrForSegmentation,
"object-detection": DetrForObjectDetection,
}
Expand Down
2 changes: 1 addition & 1 deletion tests/models/dinat/test_modeling_dinat.py
Original file line number Diff line number Diff line change
Expand Up @@ -207,7 +207,7 @@ class DinatModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
else ()
)
pipeline_model_mapping = (
{"feature-extraction": DinatModel, "image-classification": DinatForImageClassification}
{"image-feature-extraction": DinatModel, "image-classification": DinatForImageClassification}
if is_torch_available()
else {}
)
Expand Down
2 changes: 1 addition & 1 deletion tests/models/dinov2/test_modeling_dinov2.py
Original file line number Diff line number Diff line change
Expand Up @@ -217,7 +217,7 @@ class Dinov2ModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
else ()
)
pipeline_model_mapping = (
{"feature-extraction": Dinov2Model, "image-classification": Dinov2ForImageClassification}
{"image-feature-extraction": Dinov2Model, "image-classification": Dinov2ForImageClassification}
if is_torch_available()
else {}
)
Expand Down
2 changes: 1 addition & 1 deletion tests/models/donut/test_modeling_donut_swin.py
Original file line number Diff line number Diff line change
Expand Up @@ -145,7 +145,7 @@ def prepare_config_and_inputs_for_common(self):
@require_torch
class DonutSwinModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
all_model_classes = (DonutSwinModel,) if is_torch_available() else ()
pipeline_model_mapping = {"feature-extraction": DonutSwinModel} if is_torch_available() else {}
pipeline_model_mapping = {"image-feature-extraction": DonutSwinModel} if is_torch_available() else {}
fx_compatible = True

test_pruning = False
Expand Down
2 changes: 1 addition & 1 deletion tests/models/dpt/test_modeling_dpt.py
Original file line number Diff line number Diff line change
Expand Up @@ -163,7 +163,7 @@ class DPTModelTest(ModelTesterMixin, PipelineTesterMixin, unittest.TestCase):
pipeline_model_mapping = (
{
"depth-estimation": DPTForDepthEstimation,
"feature-extraction": DPTModel,
"image-feature-extraction": DPTModel,
"image-segmentation": DPTForSemanticSegmentation,
}
if is_torch_available()
Expand Down
Loading

0 comments on commit ba3264b

Please sign in to comment.