diff --git a/tests/unit/ops/transformer/inference/test_bias_geglu.py b/tests/unit/ops/transformer/inference/test_bias_geglu.py index 05de4fbb4cf8..c995d2a8c46d 100644 --- a/tests/unit/ops/transformer/inference/test_bias_geglu.py +++ b/tests/unit/ops/transformer/inference/test_bias_geglu.py @@ -15,8 +15,6 @@ if not deepspeed.ops.__compatible_ops__[InferenceBuilder.NAME]: pytest.skip("Inference ops are not available on this system", allow_module_level=True) -torch_minor_version = None - def run_bias_geglu_reference(activations, bias): # Expected behavior is that of casting to float32 internally diff --git a/tests/unit/ops/transformer/inference/test_bias_gelu.py b/tests/unit/ops/transformer/inference/test_bias_gelu.py index b69030e87ace..e3a3bad63961 100644 --- a/tests/unit/ops/transformer/inference/test_bias_gelu.py +++ b/tests/unit/ops/transformer/inference/test_bias_gelu.py @@ -16,8 +16,6 @@ if not deepspeed.ops.__compatible_ops__[InferenceBuilder.NAME]: pytest.skip("Inference ops are not available on this system", allow_module_level=True) -torch_minor_version = None - def run_bias_gelu_reference(activations, bias): # Expected behavior is that of casting to float32 internally and using the tanh approximation diff --git a/tests/unit/ops/transformer/inference/test_bias_relu.py b/tests/unit/ops/transformer/inference/test_bias_relu.py index 57134665b241..69078f9f7646 100644 --- a/tests/unit/ops/transformer/inference/test_bias_relu.py +++ b/tests/unit/ops/transformer/inference/test_bias_relu.py @@ -15,8 +15,6 @@ if not deepspeed.ops.__compatible_ops__[InferenceBuilder.NAME]: pytest.skip("Inference ops are not available on this system", allow_module_level=True) -torch_minor_version = None - def run_bias_relu_reference(activations, bias): # Expected behavior is that of casting to float32 internally diff --git a/tests/unit/ops/transformer/inference/test_gelu.py b/tests/unit/ops/transformer/inference/test_gelu.py index 5f820ef3b579..a58abfdb100c 100644 --- a/tests/unit/ops/transformer/inference/test_gelu.py +++ b/tests/unit/ops/transformer/inference/test_gelu.py @@ -9,12 +9,11 @@ from deepspeed.ops.op_builder import InferenceBuilder from deepspeed.ops.transformer import DeepSpeedInferenceConfig from deepspeed.ops.transformer.inference.op_binding.bias_gelu import BiasGeluOp +from deepspeed.utils.torch import required_torch_version if not deepspeed.ops.__compatible_ops__[InferenceBuilder.NAME]: pytest.skip("Inference ops are not available on this system", allow_module_level=True) -torch_minor_version = None - def allclose(x, y): assert x.dtype == y.dtype @@ -23,14 +22,11 @@ def allclose(x, y): def version_appropriate_gelu(activations): - global torch_minor_version - if torch_minor_version is None: - torch_minor_version = int(torch.__version__.split('.')[1]) - # If torch version = 1.12 - if torch_minor_version < 12: - return torch.nn.functional.gelu(activations) - else: + # gelu behavior changes (correctly) in torch 1.12 + if required_torch_version(min_version=1.12): return torch.nn.functional.gelu(activations, approximate='tanh') + else: + return torch.nn.functional.gelu(activations) def run_gelu_reference(activations): diff --git a/tests/unit/ops/transformer/inference/test_matmul.py b/tests/unit/ops/transformer/inference/test_matmul.py index 559aa2c60afe..2ab195ee0115 100644 --- a/tests/unit/ops/transformer/inference/test_matmul.py +++ b/tests/unit/ops/transformer/inference/test_matmul.py @@ -12,7 +12,6 @@ pytest.skip("Inference ops are not available on this system", allow_module_level=True) inference_module = None -torch_minor_version = None def allclose(x, y): diff --git a/tests/unit/ops/transformer/inference/test_softmax.py b/tests/unit/ops/transformer/inference/test_softmax.py index e582be1b926a..83785ac38ebb 100644 --- a/tests/unit/ops/transformer/inference/test_softmax.py +++ b/tests/unit/ops/transformer/inference/test_softmax.py @@ -11,8 +11,6 @@ if not deepspeed.ops.__compatible_ops__[InferenceBuilder.NAME]: pytest.skip("Inference ops are not available on this system", allow_module_level=True) -torch_minor_version = None - def allclose(x, y): assert x.dtype == y.dtype