From a769ed45e17c44fd17b85c025863c4e4f2f73634 Mon Sep 17 00:00:00 2001 From: Alexandros Benetatos <34627055+alex-bene@users.noreply.github.com> Date: Mon, 28 Oct 2024 20:44:20 +0200 Subject: [PATCH] Add `post_process_depth_estimation` for GLPN (#34413) * add depth postprocessing for GLPN * remove previous temp fix for glpn tests * Style changes for GLPN's `post_process_depth_estimation` Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> * additional style fix --------- Co-authored-by: Arthur <48595927+ArthurZucker@users.noreply.github.com> --- .../models/glpn/image_processing_glpn.py | 54 ++++++++++++++++++- src/transformers/models/glpn/modeling_glpn.py | 16 +++--- tests/models/glpn/test_modeling_glpn.py | 8 --- 3 files changed, 59 insertions(+), 19 deletions(-) diff --git a/src/transformers/models/glpn/image_processing_glpn.py b/src/transformers/models/glpn/image_processing_glpn.py index 9e69c8ae8a6e7a..115cefc86beec3 100644 --- a/src/transformers/models/glpn/image_processing_glpn.py +++ b/src/transformers/models/glpn/image_processing_glpn.py @@ -14,7 +14,11 @@ # limitations under the License. """Image processor class for GLPN.""" -from typing import List, Optional, Union +from typing import TYPE_CHECKING, Dict, List, Optional, Tuple, Union + + +if TYPE_CHECKING: + from ...modeling_outputs import DepthEstimatorOutput import numpy as np import PIL.Image @@ -27,12 +31,17 @@ get_image_size, infer_channel_dimension_format, is_scaled_image, + is_torch_available, make_list_of_images, to_numpy_array, valid_images, validate_preprocess_arguments, ) -from ...utils import TensorType, filter_out_non_signature_kwargs, logging +from ...utils import TensorType, filter_out_non_signature_kwargs, logging, requires_backends + + +if is_torch_available(): + import torch logger = logging.get_logger(__name__) @@ -218,3 +227,44 @@ def preprocess( data = {"pixel_values": images} return BatchFeature(data=data, tensor_type=return_tensors) + + def post_process_depth_estimation( + self, + outputs: "DepthEstimatorOutput", + target_sizes: Optional[Union[TensorType, List[Tuple[int, int]], None]] = None, + ) -> List[Dict[str, TensorType]]: + """ + Converts the raw output of [`DepthEstimatorOutput`] into final depth predictions and depth PIL images. + Only supports PyTorch. + + Args: + outputs ([`DepthEstimatorOutput`]): + Raw outputs of the model. + target_sizes (`TensorType` or `List[Tuple[int, int]]`, *optional*): + Tensor of shape `(batch_size, 2)` or list of tuples (`Tuple[int, int]`) containing the target size + (height, width) of each image in the batch. If left to None, predictions will not be resized. + + Returns: + `List[Dict[str, TensorType]]`: A list of dictionaries of tensors representing the processed depth + predictions. + """ + requires_backends(self, "torch") + + predicted_depth = outputs.predicted_depth + + if (target_sizes is not None) and (len(predicted_depth) != len(target_sizes)): + raise ValueError( + "Make sure that you pass in as many target sizes as the batch dimension of the predicted depth" + ) + + results = [] + target_sizes = [None] * len(predicted_depth) if target_sizes is None else target_sizes + for depth, target_size in zip(predicted_depth, target_sizes): + if target_size is not None: + depth = depth[None, None, ...] + depth = torch.nn.functional.interpolate(depth, size=target_size, mode="bicubic", align_corners=False) + depth = depth.squeeze() + + results.append({"predicted_depth": depth}) + + return results diff --git a/src/transformers/models/glpn/modeling_glpn.py b/src/transformers/models/glpn/modeling_glpn.py index 9fd22ca0f7be95..70f175df8c9973 100755 --- a/src/transformers/models/glpn/modeling_glpn.py +++ b/src/transformers/models/glpn/modeling_glpn.py @@ -723,20 +723,18 @@ def forward( >>> with torch.no_grad(): ... outputs = model(**inputs) - ... predicted_depth = outputs.predicted_depth >>> # interpolate to original size - >>> prediction = torch.nn.functional.interpolate( - ... predicted_depth.unsqueeze(1), - ... size=image.size[::-1], - ... mode="bicubic", - ... align_corners=False, + >>> post_processed_output = image_processor.post_process_depth_estimation( + ... outputs, + ... target_sizes=[(image.height, image.width)], ... ) >>> # visualize the prediction - >>> output = prediction.squeeze().cpu().numpy() - >>> formatted = (output * 255 / np.max(output)).astype("uint8") - >>> depth = Image.fromarray(formatted) + >>> predicted_depth = post_processed_output[0]["predicted_depth"] + >>> depth = predicted_depth * 255 / predicted_depth.max() + >>> depth = depth.detach().cpu().numpy() + >>> depth = Image.fromarray(depth.astype("uint8")) ```""" return_dict = return_dict if return_dict is not None else self.config.use_return_dict output_hidden_states = ( diff --git a/tests/models/glpn/test_modeling_glpn.py b/tests/models/glpn/test_modeling_glpn.py index 254c1135357147..81e95ab244f9aa 100644 --- a/tests/models/glpn/test_modeling_glpn.py +++ b/tests/models/glpn/test_modeling_glpn.py @@ -157,14 +157,6 @@ def setUp(self): self.model_tester = GLPNModelTester(self) self.config_tester = GLPNConfigTester(self, config_class=GLPNConfig) - @unittest.skip(reason="Failing after #32550") - def test_pipeline_depth_estimation(self): - pass - - @unittest.skip(reason="Failing after #32550") - def test_pipeline_depth_estimation_fp16(self): - pass - def test_config(self): self.config_tester.run_common_tests()