fix(DPT,Depth-Anything) torch.export (#34103)
* Fix torch.export issue in dpt based models

Signed-off-by: Phillip Kuznetsov <[email protected]>

* Simplify the if statements

Signed-off-by: Phillip Kuznetsov <[email protected]>

* Move activation definitions of zoe_depth to init()

Signed-off-by: Phillip Kuznetsov <[email protected]>

* Add test_export for dpt and zoedepth

Signed-off-by: Phillip Kuznetsov <[email protected]>

* add depth anything

Signed-off-by: Phillip Kuznetsov <[email protected]>

* Remove zoedepth non-automated zoedepth changes and zoedepth test

Signed-off-by: Phillip Kuznetsov <[email protected]>

* [run_slow] dpt, depth_anything, zoedepth

Signed-off-by: Phillip Kuznetsov <[email protected]>

---------

Signed-off-by: Phillip Kuznetsov <[email protected]>
philkuz authored Nov 20, 2024
1 parent 9d16441 commit 8cadf76
Showing 5 changed files with 72 additions and 20 deletions.
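The refactor below replaces the slice-based special-casing of the first fusion layer with a single uniform loop, which is what lets torch.export trace these models; the new test_export tests then check the exported program against eager execution. A minimal standalone sketch of that workflow — the checkpoint name comes from the tests below, while the random 518x518 input is an assumption matching the checkpoint's default processor size, and torch >= 2.4 is assumed:

import torch

from transformers import DepthAnythingForDepthEstimation

# Sketch only: mirrors the new test_export tests in this commit.
model = DepthAnythingForDepthEstimation.from_pretrained("LiheYoung/depth-anything-small-hf").eval()

# A random tensor stands in for processed pixel values; 518x518 is assumed
# to match the checkpoint's default image processor output.
pixel_values = torch.randn(1, 3, 518, 518)

exported_program = torch.export.export(model, args=(pixel_values,))
with torch.no_grad():
    predicted_depth = exported_program.module()(pixel_values).predicted_depth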
16 changes: 8 additions & 8 deletions src/transformers/models/depth_anything/modeling_depth_anything.py
@@ -224,16 +224,16 @@ def forward(self, hidden_states, size=None):
         hidden_states = hidden_states[::-1]
 
         fused_hidden_states = []
-        # first layer only uses the last hidden_state
-        size = hidden_states[1].shape[2:]
-        fused_hidden_state = self.layers[0](hidden_states[0], size=size)
-        fused_hidden_states.append(fused_hidden_state)
+        fused_hidden_state = None
 
-        # looping from the last layer to the second
-        for idx, (hidden_state, layer) in enumerate(zip(hidden_states[1:], self.layers[1:])):
-            size = hidden_states[1:][idx + 1].shape[2:] if idx != (len(hidden_states[1:]) - 1) else None
+        for idx, (hidden_state, layer) in enumerate(zip(hidden_states, self.layers)):
+            size = hidden_states[idx + 1].shape[2:] if idx != (len(hidden_states) - 1) else None
 
-            fused_hidden_state = layer(fused_hidden_state, hidden_state, size=size)
+            if fused_hidden_state is None:
+                # first layer only uses the last hidden_state
+                fused_hidden_state = layer(hidden_state, size=size)
+            else:
+                fused_hidden_state = layer(fused_hidden_state, hidden_state, size=size)
 
             fused_hidden_states.append(fused_hidden_state)
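The same loop reshaping is applied to modeling_dpt.py and modeling_zoedepth.py below, and the numerics are unchanged: the first iteration still consumes only the last hidden state, and every later iteration fuses the running state with the next one. A toy equivalence check — ToyFusionLayer and its argument order are illustrative stand-ins, not the real fusion layers:

import torch
from torch import nn

class ToyFusionLayer(nn.Module):
    # Stand-in for a DPT fusion layer: fuses an optional running state with
    # the incoming hidden state (here, by simple addition).
    def forward(self, hidden_state, residual=None):
        return hidden_state if residual is None else hidden_state + residual

layers = nn.ModuleList(ToyFusionLayer() for _ in range(4))
hidden_states = [torch.full((2, 2), float(i)) for i in range(4)][::-1]

# Old pattern: the first layer is special-cased outside the loop via slicing.
old = []
fused = layers[0](hidden_states[0])
old.append(fused)
for hidden_state, layer in zip(hidden_states[1:], layers[1:]):
    fused = layer(hidden_state, residual=fused)
    old.append(fused)

# New pattern: a single uniform loop; the first iteration takes the
# `fused is None` branch. Both variants produce identical outputs.
new = []
fused = None
for hidden_state, layer in zip(hidden_states, layers):
    fused = layer(hidden_state) if fused is None else layer(hidden_state, residual=fused)
    new.append(fused)

assert all(torch.equal(a, b) for a, b in zip(old, new))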
13 changes: 7 additions & 6 deletions src/transformers/models/dpt/modeling_dpt.py
@@ -689,12 +689,13 @@ def forward(self, hidden_states):
         hidden_states = hidden_states[::-1]
 
         fused_hidden_states = []
-        # first layer only uses the last hidden_state
-        fused_hidden_state = self.layers[0](hidden_states[0])
-        fused_hidden_states.append(fused_hidden_state)
-        # looping from the last layer to the second
-        for hidden_state, layer in zip(hidden_states[1:], self.layers[1:]):
-            fused_hidden_state = layer(fused_hidden_state, hidden_state)
+        fused_hidden_state = None
+        for hidden_state, layer in zip(hidden_states, self.layers):
+            if fused_hidden_state is None:
+                # first layer only uses the last hidden_state
+                fused_hidden_state = layer(hidden_state)
+            else:
+                fused_hidden_state = layer(fused_hidden_state, hidden_state)
             fused_hidden_states.append(fused_hidden_state)
 
         return fused_hidden_states
13 changes: 7 additions & 6 deletions src/transformers/models/zoedepth/modeling_zoedepth.py
@@ -185,12 +185,13 @@ def forward(self, hidden_states):
         hidden_states = hidden_states[::-1]
 
         fused_hidden_states = []
-        # first layer only uses the last hidden_state
-        fused_hidden_state = self.layers[0](hidden_states[0])
-        fused_hidden_states.append(fused_hidden_state)
-        # looping from the last layer to the second
-        for hidden_state, layer in zip(hidden_states[1:], self.layers[1:]):
-            fused_hidden_state = layer(fused_hidden_state, hidden_state)
+        fused_hidden_state = None
+        for hidden_state, layer in zip(hidden_states, self.layers):
+            if fused_hidden_state is None:
+                # first layer only uses the last hidden_state
+                fused_hidden_state = layer(hidden_state)
+            else:
+                fused_hidden_state = layer(fused_hidden_state, hidden_state)
             fused_hidden_states.append(fused_hidden_state)
 
         return fused_hidden_states
28 changes: 28 additions & 0 deletions tests/models/depth_anything/test_modeling_depth_anything.py
@@ -18,6 +18,7 @@
 
 from transformers import DepthAnythingConfig, Dinov2Config
 from transformers.file_utils import is_torch_available, is_vision_available
+from transformers.pytorch_utils import is_torch_greater_or_equal_than_2_4
 from transformers.testing_utils import require_torch, require_vision, slow, torch_device
 
 from ...test_configuration_common import ConfigTester
@@ -290,3 +291,30 @@ def test_inference(self):
         ).to(torch_device)
 
         self.assertTrue(torch.allclose(predicted_depth[0, :3, :3], expected_slice, atol=1e-4))
+
+    def test_export(self):
+        for strict in [True, False]:
+            with self.subTest(strict=strict):
+                if not is_torch_greater_or_equal_than_2_4:
+                    self.skipTest(reason="This test requires torch >= 2.4 to run.")
+                model = (
+                    DepthAnythingForDepthEstimation.from_pretrained("LiheYoung/depth-anything-small-hf")
+                    .to(torch_device)
+                    .eval()
+                )
+                image_processor = DPTImageProcessor.from_pretrained("LiheYoung/depth-anything-small-hf")
+                image = prepare_img()
+                inputs = image_processor(images=image, return_tensors="pt").to(torch_device)
+
+                exported_program = torch.export.export(
+                    model,
+                    args=(inputs["pixel_values"],),
+                    strict=strict,
+                )
+                with torch.no_grad():
+                    eager_outputs = model(**inputs)
+                    exported_outputs = exported_program.module().forward(inputs["pixel_values"])
+                self.assertEqual(eager_outputs.predicted_depth.shape, exported_outputs.predicted_depth.shape)
+                self.assertTrue(
+                    torch.allclose(eager_outputs.predicted_depth, exported_outputs.predicted_depth, atol=1e-4)
+                )
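Both export modes are exercised here: with strict=True, torch.export traces the module through TorchDynamo, while strict=False (non-strict mode) records the graph by running the module in Python. The DPT test below follows the same pattern with DPTForSemanticSegmentation.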
22 changes: 22 additions & 0 deletions tests/models/dpt/test_modeling_dpt.py
@@ -18,6 +18,7 @@
 
 from transformers import DPTConfig
 from transformers.file_utils import is_torch_available, is_vision_available
+from transformers.pytorch_utils import is_torch_greater_or_equal_than_2_4
 from transformers.testing_utils import require_torch, require_vision, slow, torch_device
 
 from ...test_configuration_common import ConfigTester
@@ -410,3 +411,24 @@ def test_post_processing_depth_estimation(self):
         ).squeeze()
         self.assertTrue(output_enlarged.shape == expected_shape)
         self.assertTrue(torch.allclose(predicted_depth_l, output_enlarged, rtol=1e-3))
+
+    def test_export(self):
+        for strict in [True, False]:
+            with self.subTest(strict=strict):
+                if not is_torch_greater_or_equal_than_2_4:
+                    self.skipTest(reason="This test requires torch >= 2.4 to run.")
+                model = DPTForSemanticSegmentation.from_pretrained("Intel/dpt-large-ade").to(torch_device).eval()
+                image_processor = DPTImageProcessor.from_pretrained("Intel/dpt-large-ade")
+                image = prepare_img()
+                inputs = image_processor(images=image, return_tensors="pt").to(torch_device)
+
+                exported_program = torch.export.export(
+                    model,
+                    args=(inputs["pixel_values"],),
+                    strict=strict,
+                )
+                with torch.no_grad():
+                    eager_outputs = model(**inputs)
+                    exported_outputs = exported_program.module().forward(inputs["pixel_values"])
+                self.assertEqual(eager_outputs.logits.shape, exported_outputs.logits.shape)
+                self.assertTrue(torch.allclose(eager_outputs.logits, exported_outputs.logits, atol=1e-4))
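Once a model exports under both modes, the resulting ExportedProgram can also be serialized and reloaded with the standard torch.export API. A generic sketch, not part of this commit — TinyModel and the file name are placeholders:

import torch
from torch import nn

class TinyModel(nn.Module):
    # Placeholder module; any export-compatible model works the same way.
    def forward(self, x):
        return x.relu()

exported_program = torch.export.export(TinyModel(), args=(torch.randn(2, 2),))

# Persist to the .pt2 format and load it back.
torch.export.save(exported_program, "tiny_model.pt2")
reloaded = torch.export.load("tiny_model.pt2")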
