From 8cadf76e1c72eabbff24099c5c0a2a98edbb00ef Mon Sep 17 00:00:00 2001 From: Phillip Kuznetsov Date: Wed, 20 Nov 2024 02:31:21 -0800 Subject: [PATCH] fix(DPT,Depth-Anything) `torch.export` (#34103) * Fix torch.export issue in dpt based models Signed-off-by: Phillip Kuznetsov * Simplify the if statements Signed-off-by: Phillip Kuznetsov * Move activation definitions of zoe_depth to init() Signed-off-by: Phillip Kuznetsov * Add test_export for dpt and zoedepth Signed-off-by: Phillip Kuznetsov * add depth anything Signed-off-by: Phillip Kuznetsov * Remove zoedepth non-automated zoedepth changes and zoedepth test Signed-off-by: Phillip Kuznetsov * [run_slow] dpt, depth_anything, zoedepth Signed-off-by: Phillip Kuznetsov --------- Signed-off-by: Phillip Kuznetsov --- .../depth_anything/modeling_depth_anything.py | 16 +++++------ src/transformers/models/dpt/modeling_dpt.py | 13 +++++---- .../models/zoedepth/modeling_zoedepth.py | 13 +++++---- .../test_modeling_depth_anything.py | 28 +++++++++++++++++++ tests/models/dpt/test_modeling_dpt.py | 22 +++++++++++++++ 5 files changed, 72 insertions(+), 20 deletions(-) diff --git a/src/transformers/models/depth_anything/modeling_depth_anything.py b/src/transformers/models/depth_anything/modeling_depth_anything.py index 59c628786328e6..4667c413457b19 100644 --- a/src/transformers/models/depth_anything/modeling_depth_anything.py +++ b/src/transformers/models/depth_anything/modeling_depth_anything.py @@ -224,16 +224,16 @@ def forward(self, hidden_states, size=None): hidden_states = hidden_states[::-1] fused_hidden_states = [] - # first layer only uses the last hidden_state - size = hidden_states[1].shape[2:] - fused_hidden_state = self.layers[0](hidden_states[0], size=size) - fused_hidden_states.append(fused_hidden_state) + fused_hidden_state = None - # looping from the last layer to the second - for idx, (hidden_state, layer) in enumerate(zip(hidden_states[1:], self.layers[1:])): - size = hidden_states[1:][idx + 1].shape[2:] if idx != (len(hidden_states[1:]) - 1) else None + for idx, (hidden_state, layer) in enumerate(zip(hidden_states, self.layers)): + size = hidden_states[idx + 1].shape[2:] if idx != (len(hidden_states) - 1) else None - fused_hidden_state = layer(fused_hidden_state, hidden_state, size=size) + if fused_hidden_state is None: + # first layer only uses the last hidden_state + fused_hidden_state = layer(hidden_state, size=size) + else: + fused_hidden_state = layer(fused_hidden_state, hidden_state, size=size) fused_hidden_states.append(fused_hidden_state) diff --git a/src/transformers/models/dpt/modeling_dpt.py b/src/transformers/models/dpt/modeling_dpt.py index 2d4654a234c2c6..5886d288b88271 100755 --- a/src/transformers/models/dpt/modeling_dpt.py +++ b/src/transformers/models/dpt/modeling_dpt.py @@ -689,12 +689,13 @@ def forward(self, hidden_states): hidden_states = hidden_states[::-1] fused_hidden_states = [] - # first layer only uses the last hidden_state - fused_hidden_state = self.layers[0](hidden_states[0]) - fused_hidden_states.append(fused_hidden_state) - # looping from the last layer to the second - for hidden_state, layer in zip(hidden_states[1:], self.layers[1:]): - fused_hidden_state = layer(fused_hidden_state, hidden_state) + fused_hidden_state = None + for hidden_state, layer in zip(hidden_states, self.layers): + if fused_hidden_state is None: + # first layer only uses the last hidden_state + fused_hidden_state = layer(hidden_state) + else: + fused_hidden_state = layer(fused_hidden_state, hidden_state) fused_hidden_states.append(fused_hidden_state) return fused_hidden_states diff --git a/src/transformers/models/zoedepth/modeling_zoedepth.py b/src/transformers/models/zoedepth/modeling_zoedepth.py index 979b78aba678a5..5cbbdcdc04b756 100644 --- a/src/transformers/models/zoedepth/modeling_zoedepth.py +++ b/src/transformers/models/zoedepth/modeling_zoedepth.py @@ -185,12 +185,13 @@ def forward(self, hidden_states): hidden_states = hidden_states[::-1] fused_hidden_states = [] - # first layer only uses the last hidden_state - fused_hidden_state = self.layers[0](hidden_states[0]) - fused_hidden_states.append(fused_hidden_state) - # looping from the last layer to the second - for hidden_state, layer in zip(hidden_states[1:], self.layers[1:]): - fused_hidden_state = layer(fused_hidden_state, hidden_state) + fused_hidden_state = None + for hidden_state, layer in zip(hidden_states, self.layers): + if fused_hidden_state is None: + # first layer only uses the last hidden_state + fused_hidden_state = layer(hidden_state) + else: + fused_hidden_state = layer(fused_hidden_state, hidden_state) fused_hidden_states.append(fused_hidden_state) return fused_hidden_states diff --git a/tests/models/depth_anything/test_modeling_depth_anything.py b/tests/models/depth_anything/test_modeling_depth_anything.py index 344d949fa4fe6c..6e7b423e9ec35f 100644 --- a/tests/models/depth_anything/test_modeling_depth_anything.py +++ b/tests/models/depth_anything/test_modeling_depth_anything.py @@ -18,6 +18,7 @@ from transformers import DepthAnythingConfig, Dinov2Config from transformers.file_utils import is_torch_available, is_vision_available +from transformers.pytorch_utils import is_torch_greater_or_equal_than_2_4 from transformers.testing_utils import require_torch, require_vision, slow, torch_device from ...test_configuration_common import ConfigTester @@ -290,3 +291,30 @@ def test_inference(self): ).to(torch_device) self.assertTrue(torch.allclose(predicted_depth[0, :3, :3], expected_slice, atol=1e-4)) + + def test_export(self): + for strict in [True, False]: + with self.subTest(strict=strict): + if not is_torch_greater_or_equal_than_2_4: + self.skipTest(reason="This test requires torch >= 2.4 to run.") + model = ( + DepthAnythingForDepthEstimation.from_pretrained("LiheYoung/depth-anything-small-hf") + .to(torch_device) + .eval() + ) + image_processor = DPTImageProcessor.from_pretrained("LiheYoung/depth-anything-small-hf") + image = prepare_img() + inputs = image_processor(images=image, return_tensors="pt").to(torch_device) + + exported_program = torch.export.export( + model, + args=(inputs["pixel_values"],), + strict=strict, + ) + with torch.no_grad(): + eager_outputs = model(**inputs) + exported_outputs = exported_program.module().forward(inputs["pixel_values"]) + self.assertEqual(eager_outputs.predicted_depth.shape, exported_outputs.predicted_depth.shape) + self.assertTrue( + torch.allclose(eager_outputs.predicted_depth, exported_outputs.predicted_depth, atol=1e-4) + ) diff --git a/tests/models/dpt/test_modeling_dpt.py b/tests/models/dpt/test_modeling_dpt.py index 376ea8b310080d..7f841fbb2efc58 100644 --- a/tests/models/dpt/test_modeling_dpt.py +++ b/tests/models/dpt/test_modeling_dpt.py @@ -18,6 +18,7 @@ from transformers import DPTConfig from transformers.file_utils import is_torch_available, is_vision_available +from transformers.pytorch_utils import is_torch_greater_or_equal_than_2_4 from transformers.testing_utils import require_torch, require_vision, slow, torch_device from ...test_configuration_common import ConfigTester @@ -410,3 +411,24 @@ def test_post_processing_depth_estimation(self): ).squeeze() self.assertTrue(output_enlarged.shape == expected_shape) self.assertTrue(torch.allclose(predicted_depth_l, output_enlarged, rtol=1e-3)) + + def test_export(self): + for strict in [True, False]: + with self.subTest(strict=strict): + if not is_torch_greater_or_equal_than_2_4: + self.skipTest(reason="This test requires torch >= 2.4 to run.") + model = DPTForSemanticSegmentation.from_pretrained("Intel/dpt-large-ade").to(torch_device).eval() + image_processor = DPTImageProcessor.from_pretrained("Intel/dpt-large-ade") + image = prepare_img() + inputs = image_processor(images=image, return_tensors="pt").to(torch_device) + + exported_program = torch.export.export( + model, + args=(inputs["pixel_values"],), + strict=strict, + ) + with torch.no_grad(): + eager_outputs = model(**inputs) + exported_outputs = exported_program.module().forward(inputs["pixel_values"]) + self.assertEqual(eager_outputs.logits.shape, exported_outputs.logits.shape) + self.assertTrue(torch.allclose(eager_outputs.logits, exported_outputs.logits, atol=1e-4))