diff --git a/tests/unit/ops/adam/test_cpu_adam.py b/tests/unit/ops/adam/test_cpu_adam.py index 0187a0c38fc6..467f48b03d8b 100644 --- a/tests/unit/ops/adam/test_cpu_adam.py +++ b/tests/unit/ops/adam/test_cpu_adam.py @@ -92,6 +92,8 @@ def test_fused_adam_equal(self, dtype, model_size): def test_torch_adamw_equal(self, dtype, model_size): if get_accelerator().is_available(): + if dtype == torch.half: + pytest.skip("torch.optim.AdamW with half precision inf/nan output.") if ("amd" in pytest.cpu_vendor) and (dtype == torch.half): pytest.skip("cpu-adam with half precision not supported on AMD CPUs") ref_param_device = get_accelerator().device_name()