Added segmentation maps support for DPT image processor #34345
base: main
@@ -139,6 +139,11 @@ class DPTImageProcessor(BaseImageProcessor):
        size_divisor (`int`, *optional*):
            If `do_pad` is `True`, pads the image dimensions to be divisible by this value. This was introduced in the
            DINOv2 paper, which uses the model in combination with DPT.
        do_reduce_labels (`bool`, *optional*, defaults to `False`):
            Whether or not to reduce all label values of segmentation maps by 1. Usually used for datasets where 0 is
            used for background, and background itself is not included in all classes of a dataset (e.g. ADE20k). The
            background label will be replaced by 255. Can be overridden by the `do_reduce_labels` parameter in the
            `preprocess` method.
    """

    model_input_names = ["pixel_values"]

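For readers unfamiliar with the convention, here is a minimal standalone sketch of what `do_reduce_labels=True` does to an ADE20k-style map (the values are illustrative; this mirrors, but is not, the processor's code):

import numpy as np

# Illustrative ADE20k-style map: 0 = background, 1..150 = classes.
label = np.array([[0, 1], [2, 150]], dtype=np.uint8)

label[label == 0] = 255    # mark background first, avoiding uint8 underflow
label = label - 1          # shift every class id down by one: 1 -> 0, 2 -> 1, ...
label[label == 254] = 255  # 255 became 254 after the subtraction; restore the ignore index
# label is now [[255, 0], [1, 149]]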
@@ -157,6 +162,7 @@ def __init__(
        image_std: Optional[Union[float, List[float]]] = None,
        do_pad: bool = False,
        size_divisor: int = None,
        do_reduce_labels: bool = False,
        **kwargs,
    ) -> None:
        super().__init__(**kwargs)

@@ -174,6 +180,7 @@ def __init__(
        self.image_std = image_std if image_std is not None else IMAGENET_STANDARD_STD
        self.do_pad = do_pad
        self.size_divisor = size_divisor
        self.do_reduce_labels = do_reduce_labels

    def resize(
        self,

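Assuming this branch is installed, the new flag is set at construction time like any other processor default (a usage sketch, not part of the diff):

from transformers import DPTImageProcessor

# Stored on the instance; used by preprocess()/__call__ whenever
# do_reduce_labels is not overridden per call.
image_processor = DPTImageProcessor(do_reduce_labels=True)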
@@ -275,10 +282,162 @@ def _get_pad(size, size_divisor):

        return pad(image, ((pad_size_left, pad_size_right), (pad_size_top, pad_size_bottom)), data_format=data_format)

    def reduce_label(self, label: ImageInput) -> np.ndarray:
        label = to_numpy_array(label)
        # Avoid using underflow conversion
        label[label == 0] = 255
        label = label - 1
        label[label == 254] = 255
        return label

    def _preprocess(
        self,
        image: ImageInput,
        do_reduce_labels: bool = None,
        do_resize: bool = None,
        size: Dict[str, int] = None,
        resample: PILImageResampling = None,
        keep_aspect_ratio: bool = None,
        ensure_multiple_of: int = None,
        do_rescale: bool = None,
        rescale_factor: float = None,
        do_normalize: bool = None,
        image_mean: Optional[Union[float, List[float]]] = None,
        image_std: Optional[Union[float, List[float]]] = None,
        do_pad: bool = None,
        size_divisor: int = None,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
    ):
        # Adapted from transformers.models.beit.image_processing_beit

        if do_reduce_labels:
            image = self.reduce_label(image)

        if do_resize:
            image = self.resize(
                image=image,
                size=size,
                resample=resample,
                keep_aspect_ratio=keep_aspect_ratio,
                ensure_multiple_of=ensure_multiple_of,
                input_data_format=input_data_format,
            )

        if do_rescale:
            image = self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format)

        if do_normalize:
            image = self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format)

        if do_pad:
            image = self.pad_image(image=image, size_divisor=size_divisor, input_data_format=input_data_format)

        return image

    def _preprocess_image(
        self,
        image: ImageInput,
        do_resize: bool = None,
        size: Dict[str, int] = None,
        resample: PILImageResampling = None,
        keep_aspect_ratio: bool = None,
        ensure_multiple_of: int = None,
        do_rescale: bool = None,
        rescale_factor: float = None,
        do_normalize: bool = None,
        image_mean: Optional[Union[float, List[float]]] = None,
        image_std: Optional[Union[float, List[float]]] = None,
        do_pad: bool = None,
        size_divisor: int = None,
        data_format: Optional[Union[str, ChannelDimension]] = None,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
    ) -> np.ndarray:
        """Preprocesses a single image."""
        # Adapted from transformers.models.beit.image_processing_beit
        # All transformations expect numpy arrays.
        image = to_numpy_array(image)
        if is_scaled_image(image) and do_rescale:
            logger.warning_once(
                "It looks like you are trying to rescale already rescaled images. If the input"
                " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
            )
        if input_data_format is None:
            # We assume that all images have the same channel dimension format.
            input_data_format = infer_channel_dimension_format(image)

        image = self._preprocess(
            image,
            do_reduce_labels=False,
            do_resize=do_resize,
            size=size,
            resample=resample,
            keep_aspect_ratio=keep_aspect_ratio,
            ensure_multiple_of=ensure_multiple_of,
            do_rescale=do_rescale,
            rescale_factor=rescale_factor,
            do_normalize=do_normalize,
            image_mean=image_mean,
            image_std=image_std,
            do_pad=do_pad,
            size_divisor=size_divisor,
            input_data_format=input_data_format,
        )
        if data_format is not None:
            image = to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format)
        return image

    def _preprocess_segmentation_map(
        self,
        segmentation_map: ImageInput,
        do_resize: bool = None,
        size: Dict[str, int] = None,
        resample: PILImageResampling = None,
        keep_aspect_ratio: bool = None,
        ensure_multiple_of: int = None,
        do_reduce_labels: bool = None,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,
    ):
        """Preprocesses a single segmentation map."""
        # Adapted from transformers.models.beit.image_processing_beit
        # All transformations expect numpy arrays.
        segmentation_map = to_numpy_array(segmentation_map)
        # Add an axis to the segmentation maps for transformations.
        if segmentation_map.ndim == 2:
            segmentation_map = segmentation_map[None, ...]
            added_dimension = True
            input_data_format = ChannelDimension.FIRST
        else:
            added_dimension = False
            if input_data_format is None:
                input_data_format = infer_channel_dimension_format(segmentation_map, num_channels=1)
        segmentation_map = self._preprocess(
            image=segmentation_map,
            do_reduce_labels=do_reduce_labels,
            do_resize=do_resize,
            size=size,
            resample=resample,
            keep_aspect_ratio=keep_aspect_ratio,
            ensure_multiple_of=ensure_multiple_of,
            do_normalize=False,
            do_rescale=False,
            input_data_format=input_data_format,
        )
        # Remove extra axis if added
        if added_dimension:
            segmentation_map = np.squeeze(segmentation_map, axis=0)
        segmentation_map = segmentation_map.astype(np.int64)
        return segmentation_map

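A standalone NumPy illustration of the axis bookkeeping above (illustrative shapes, not the processor's actual code):

import numpy as np

seg = np.zeros((480, 640), dtype=np.uint8)  # plain 2D (height, width) map

# A leading axis is added so the shared _preprocess path can treat the map
# as a 1-channel, channels-first image during resizing ...
seg3 = seg[None, ...]                       # shape (1, 480, 640)

# ... and is squeezed away afterwards, with a final cast to int64 so the
# values can be used directly as loss targets.
restored = np.squeeze(seg3, axis=0).astype(np.int64)
assert restored.shape == (480, 640) and restored.dtype == np.int64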
    def __call__(self, images, segmentation_maps=None, **kwargs):
        # Overrides the `__call__` method of the `Preprocessor` class such that the images and
        # segmentation maps can both be passed in as positional arguments.
        return super().__call__(images, segmentation_maps=segmentation_maps, **kwargs)

Comment on lines +429 to +432: Same here for adding a `# Copied from` statement.
    @filter_out_non_signature_kwargs()
    def preprocess(
        self,
        images: ImageInput,
        segmentation_maps: Optional[ImageInput] = None,

Review comment: This is a bit tricky as it could be a breaking change, if some users use positional arguments.

        do_resize: bool = None,
        size: int = None,
        keep_aspect_ratio: bool = None,

@@ -291,6 +450,7 @@ def preprocess(
        image_std: Optional[Union[float, List[float]]] = None,
        do_pad: bool = None,
        size_divisor: int = None,
        do_reduce_labels: Optional[bool] = None,
        return_tensors: Optional[Union[str, TensorType]] = None,
        data_format: ChannelDimension = ChannelDimension.FIRST,
        input_data_format: Optional[Union[str, ChannelDimension]] = None,

@@ -302,6 +462,8 @@ def preprocess(
            images (`ImageInput`):
                Image to preprocess. Expects a single or batch of images with pixel values ranging from 0 to 255. If
                passing in images with pixel values between 0 and 1, set `do_rescale=False`.
            segmentation_maps (`ImageInput`, *optional*):
                Segmentation map to preprocess.
            do_resize (`bool`, *optional*, defaults to `self.do_resize`):
                Whether to resize the image.
            size (`Dict[str, int]`, *optional*, defaults to `self.size`):

@@ -326,6 +488,10 @@ def preprocess(
                Image mean.
            image_std (`float` or `List[float]`, *optional*, defaults to `self.image_std`):
                Image standard deviation.
            do_reduce_labels (`bool`, *optional*, defaults to `self.do_reduce_labels`):
                Whether or not to reduce all label values of segmentation maps by 1. Usually used for datasets where 0
                is used for background, and background itself is not included in all classes of a dataset (e.g.
                ADE20k). The background label will be replaced by 255.
            return_tensors (`str` or `TensorType`, *optional*):
                The type of tensors to return. Can be one of:
                - Unset: Return a list of `np.ndarray`.

@@ -357,9 +523,13 @@ def preprocess(
        image_std = image_std if image_std is not None else self.image_std
        do_pad = do_pad if do_pad is not None else self.do_pad
        size_divisor = size_divisor if size_divisor is not None else self.size_divisor
        do_reduce_labels = do_reduce_labels if do_reduce_labels is not None else self.do_reduce_labels

        images = make_list_of_images(images)

        if segmentation_maps is not None:
            segmentation_maps = make_list_of_images(segmentation_maps, expected_ndims=2)

        if not valid_images(images):
            raise ValueError(
                "Invalid image type. Must be of type PIL.Image.Image, numpy.ndarray, "

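The `expected_ndims=2` argument is what lets a single 2D map pass validation; a small sketch of the intended behavior using the existing `make_list_of_images` helper from `transformers.image_utils`:

import numpy as np
from transformers.image_utils import make_list_of_images

seg = np.zeros((480, 640), dtype=np.uint8)
# With expected_ndims=2, a single 2D map is wrapped into a batch of one
# instead of being mistaken for a batch of 1D rows.
maps = make_list_of_images(seg, expected_ndims=2)
assert len(maps) == 1 and maps[0].shape == (480, 640)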
@@ -377,55 +547,47 @@ def preprocess(
                 size=size,
                 resample=resample,
             )
-        # All transformations expect numpy arrays.
-        images = [to_numpy_array(image) for image in images]
-
-        if is_scaled_image(images[0]) and do_rescale:
-            logger.warning_once(
-                "It looks like you are trying to rescale already rescaled images. If the input"
-                " images have pixel values between 0 and 1, set `do_rescale=False` to avoid rescaling them again."
-            )
+        images = [
+            self._preprocess_image(
+                image=img,
+                do_resize=do_resize,
+                do_rescale=do_rescale,
+                do_normalize=do_normalize,
+                do_pad=do_pad,
+                size=size,
+                resample=resample,
+                keep_aspect_ratio=keep_aspect_ratio,
+                ensure_multiple_of=ensure_multiple_of,
+                rescale_factor=rescale_factor,
+                image_mean=image_mean,
+                image_std=image_std,
+                size_divisor=size_divisor,
+                data_format=data_format,
+                input_data_format=input_data_format,
+            )
+            for img in images
+        ]
 
-        if input_data_format is None:
-            # We assume that all images have the same channel dimension format.
-            input_data_format = infer_channel_dimension_format(images[0])
+        data = {"pixel_values": images}
 
-        if do_resize:
-            images = [
-                self.resize(
-                    image=image,
-                    size=size,
-                    resample=resample,
-                    keep_aspect_ratio=keep_aspect_ratio,
-                    ensure_multiple_of=ensure_multiple_of,
-                    input_data_format=input_data_format,
-                )
-                for image in images
-            ]
+        if segmentation_maps is not None:
+            segmentation_maps = [
+                self._preprocess_segmentation_map(
+                    segmentation_map=segmentation_map,
+                    do_reduce_labels=do_reduce_labels,
+                    do_resize=do_resize,
+                    size=size,
+                    resample=resample,
+                    keep_aspect_ratio=keep_aspect_ratio,
+                    ensure_multiple_of=ensure_multiple_of,
+                    input_data_format=input_data_format,
+                )
+                for segmentation_map in segmentation_maps
+            ]
 
-        if do_rescale:
-            images = [
-                self.rescale(image=image, scale=rescale_factor, input_data_format=input_data_format)
-                for image in images
-            ]
+            data["labels"] = segmentation_maps
 
-        if do_normalize:
-            images = [
-                self.normalize(image=image, mean=image_mean, std=image_std, input_data_format=input_data_format)
-                for image in images
-            ]
-
-        if do_pad:
-            images = [
-                self.pad_image(image=image, size_divisor=size_divisor, input_data_format=input_data_format)
-                for image in images
-            ]
-
-        images = [
-            to_channel_dimension_format(image, data_format, input_channel_dim=input_data_format) for image in images
-        ]
-
-        data = {"pixel_values": images}
         return BatchFeature(data=data, tensor_type=return_tensors)
 
     # Copied from transformers.models.beit.image_processing_beit.BeitImageProcessor.post_process_semantic_segmentation with Beit->DPT
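Putting the pieces together, a hedged end-to-end sketch of the new preprocess flow (the checkpoint name is only illustrative; any DPT checkpoint with an ADE20k head would do):

import numpy as np
from transformers import DPTImageProcessor

processor = DPTImageProcessor.from_pretrained("Intel/dpt-large-ade")
image = np.random.randint(0, 256, size=(480, 640, 3), dtype=np.uint8)
seg = np.random.randint(0, 150, size=(480, 640), dtype=np.uint8)

# Per-call override of the stored default; maps skip rescale/normalize
# and come back under the "labels" key as int64.
batch = processor.preprocess(image, segmentation_maps=seg, do_reduce_labels=True, return_tensors="pt")
print(batch["pixel_values"].shape)                    # (1, 3, H, W) after resizing
print(batch["labels"].shape, batch["labels"].dtype)   # (1, H, W), torch.int64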
Review comment: This seems to be fully copied from the beit image processor, you should add a `# Copied from` statement above if that's the case :)

Reply: Done
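Since the section closes on `post_process_semantic_segmentation`, a short hedged sketch of the inference-side counterpart (model choice is illustrative):

import numpy as np
import torch
from transformers import DPTForSemanticSegmentation, DPTImageProcessor

processor = DPTImageProcessor.from_pretrained("Intel/dpt-large-ade")
model = DPTForSemanticSegmentation.from_pretrained("Intel/dpt-large-ade")

image = np.random.randint(0, 256, size=(480, 640, 3), dtype=np.uint8)
inputs = processor(image, return_tensors="pt")
with torch.no_grad():
    outputs = model(**inputs)

# Resizes logits back to the original (height, width) and argmaxes per pixel.
seg_maps = processor.post_process_semantic_segmentation(outputs, target_sizes=[(480, 640)])
print(seg_maps[0].shape)  # torch.Size([480, 640])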