Skip to content

Commit

Permalink
Fix dtype casting in swinv2 and swinv2sr to allow non-FP32 inference (#…
Browse files Browse the repository at this point in the history
…31589)

* Fix dtype casting in modeling_swin2sr to allow non-FP32 inference

* Fix formattting

* Fix for swinv2 too

* Update src/transformers/models/swin2sr/modeling_swin2sr.py

Co-authored-by: amyeroberts <[email protected]>

* Update src/transformers/models/swinv2/modeling_swinv2.py

Co-authored-by: amyeroberts <[email protected]>

* Add FP16 tests for swin2sr and swinv2

* [run_slow] swin2sr, swinv2

* [run_slow] swin2sr, swinv2

---------

Co-authored-by: amyeroberts <[email protected]>
  • Loading branch information
aliencaocao and amyeroberts authored Jun 26, 2024
1 parent a3fb96a commit 1f9f57a
Show file tree
Hide file tree
Showing 4 changed files with 45 additions and 0 deletions.
2 changes: 2 additions & 0 deletions src/transformers/models/swin2sr/modeling_swin2sr.py
Original file line number Diff line number Diff line change
Expand Up @@ -301,6 +301,8 @@ def __init__(self, config, dim, num_heads, window_size, pretrained_window_size=[
relative_coords_table = (
torch.sign(relative_coords_table) * torch.log2(torch.abs(relative_coords_table) + 1.0) / math.log2(8)
)
# set to same dtype as mlp weight
relative_coords_table = relative_coords_table.to(next(self.continuous_position_bias_mlp.parameters()).dtype)
self.register_buffer("relative_coords_table", relative_coords_table, persistent=False)

# get pair-wise relative position index for each token inside the window
Expand Down
2 changes: 2 additions & 0 deletions src/transformers/models/swinv2/modeling_swinv2.py
Original file line number Diff line number Diff line change
Expand Up @@ -492,6 +492,8 @@ def __init__(self, config, dim, num_heads, window_size, pretrained_window_size=[
relative_coords_table = (
torch.sign(relative_coords_table) * torch.log2(torch.abs(relative_coords_table) + 1.0) / math.log2(8)
)
# set to same dtype as mlp weight
relative_coords_table = relative_coords_table.to(next(self.continuous_position_bias_mlp.parameters()).dtype)
self.register_buffer("relative_coords_table", relative_coords_table, persistent=False)

# get pair-wise relative position index for each token inside the window
Expand Down
21 changes: 21 additions & 0 deletions tests/models/swin2sr/test_modeling_swin2sr.py
Original file line number Diff line number Diff line change
Expand Up @@ -333,3 +333,24 @@ def test_inference_image_super_resolution_head(self):
[[0.5458, 0.5546, 0.5638], [0.5526, 0.5565, 0.5651], [0.5396, 0.5426, 0.5621]]
).to(torch_device)
self.assertTrue(torch.allclose(outputs.reconstruction[0, 0, :3, :3], expected_slice, atol=1e-4))

def test_inference_fp16(self):
processor = Swin2SRImageProcessor()
model = Swin2SRForImageSuperResolution.from_pretrained(
"caidas/swin2SR-classical-sr-x2-64", torch_dtype=torch.float16
).to(torch_device)

image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png")
inputs = processor(images=image, return_tensors="pt").to(model.dtype).to(torch_device)

# forward pass
with torch.no_grad():
outputs = model(**inputs)

# verify the logits
expected_shape = torch.Size([1, 3, 976, 1296])
self.assertEqual(outputs.reconstruction.shape, expected_shape)
expected_slice = torch.tensor(
[[0.5454, 0.5542, 0.5640], [0.5518, 0.5562, 0.5649], [0.5391, 0.5425, 0.5620]], dtype=model.dtype
).to(torch_device)
self.assertTrue(torch.allclose(outputs.reconstruction[0, 0, :3, :3], expected_slice, atol=1e-4))
20 changes: 20 additions & 0 deletions tests/models/swinv2/test_modeling_swinv2.py
Original file line number Diff line number Diff line change
Expand Up @@ -487,6 +487,26 @@ def test_inference_image_classification_head(self):
expected_slice = torch.tensor([-0.3947, -0.4306, 0.0026]).to(torch_device)
self.assertTrue(torch.allclose(outputs.logits[0, :3], expected_slice, atol=1e-4))

@slow
def test_inference_fp16(self):
model = Swinv2ForImageClassification.from_pretrained(
"microsoft/swinv2-tiny-patch4-window8-256", torch_dtype=torch.float16
).to(torch_device)
image_processor = self.default_image_processor

image = Image.open("./tests/fixtures/tests_samples/COCO/000000039769.png")
inputs = image_processor(images=image, return_tensors="pt").to(model.dtype).to(torch_device)

# forward pass
with torch.no_grad():
outputs = model(**inputs)

# verify the logits
expected_shape = torch.Size((1, 1000))
self.assertEqual(outputs.logits.shape, expected_shape)
expected_slice = torch.tensor([-0.3938, -0.4290, 0.0020], dtype=model.dtype).to(torch_device)
self.assertTrue(torch.allclose(outputs.logits[0, :3], expected_slice, atol=1e-4))

@slow
def test_inference_interpolate_pos_encoding(self):
# Swinv2 models have an `interpolate_pos_encoding` argument in their forward method,
Expand Down

0 comments on commit 1f9f57a

Please sign in to comment.