From c6f48c98000e0361677d863ff7f7e071acf130d4 Mon Sep 17 00:00:00 2001 From: Omar Elayan Date: Tue, 4 Jun 2024 14:03:38 +0300 Subject: [PATCH] Add inference builder for cpu --- accelerator/cpu_accelerator.py | 6 ++-- op_builder/cpu/__init__.py | 1 + op_builder/cpu/transformer_inference.py | 39 +++++++++++++++++++++++++ 3 files changed, 44 insertions(+), 2 deletions(-) create mode 100644 op_builder/cpu/transformer_inference.py diff --git a/accelerator/cpu_accelerator.py b/accelerator/cpu_accelerator.py index bd11d034f312..6ada275d7b19 100644 --- a/accelerator/cpu_accelerator.py +++ b/accelerator/cpu_accelerator.py @@ -301,9 +301,9 @@ def get_op_builder(self, class_name): # is op_builder from deepspeed or a 3p version? this should only succeed if it's deepspeed # if successful this also means we're doing a local install and not JIT compile path from op_builder import __deepspeed__ # noqa: F401 # type: ignore - from op_builder.cpu import CCLCommBuilder, ShareMemCommBuilder, FusedAdamBuilder, CPUAdamBuilder, NotImplementedBuilder + from op_builder.cpu import CCLCommBuilder, ShareMemCommBuilder, FusedAdamBuilder, CPUAdamBuilder, InferenceBuilder, NotImplementedBuilder except ImportError: - from deepspeed.ops.op_builder.cpu import CCLCommBuilder, ShareMemCommBuilder, FusedAdamBuilder, CPUAdamBuilder, NotImplementedBuilder + from deepspeed.ops.op_builder.cpu import CCLCommBuilder, ShareMemCommBuilder, FusedAdamBuilder, CPUAdamBuilder, InferenceBuilder, NotImplementedBuilder if class_name == "CCLCommBuilder": return CCLCommBuilder @@ -311,6 +311,8 @@ def get_op_builder(self, class_name): return ShareMemCommBuilder elif class_name == "FusedAdamBuilder": return FusedAdamBuilder + elif class_name == "InferenceBuilder": + return InferenceBuilder elif class_name == "CPUAdamBuilder": return CPUAdamBuilder else: diff --git a/op_builder/cpu/__init__.py b/op_builder/cpu/__init__.py index 30238add3f90..d5bdfaee2075 100644 --- a/op_builder/cpu/__init__.py +++ b/op_builder/cpu/__init__.py @@ -7,4 +7,5 @@ from .comm import CCLCommBuilder, ShareMemCommBuilder from .fused_adam import FusedAdamBuilder from .cpu_adam import CPUAdamBuilder +from .transformer_inference import InferenceBuilder from .no_impl import NotImplementedBuilder diff --git a/op_builder/cpu/transformer_inference.py b/op_builder/cpu/transformer_inference.py new file mode 100644 index 000000000000..2623602ad91e --- /dev/null +++ b/op_builder/cpu/transformer_inference.py @@ -0,0 +1,39 @@ +# Copyright (c) Microsoft Corporation. +# SPDX-License-Identifier: Apache-2.0 + +# DeepSpeed Team + +import importlib + +try: + # is op_builder from deepspeed or a 3p version? this should only succeed if it's deepspeed + # if successful this also means we're doing a local install and not JIT compile path + from op_builder import __deepspeed__ # noqa: F401 + from op_builder.builder import OpBuilder +except ImportError: + from deepspeed.ops.op_builder.builder import OpBuilder + + +class InferenceBuilder(OpBuilder): + BUILD_VAR = "DS_BUILD_TRANSFORMER_INFERENCE" + NAME = "transformer_inference" + + def __init__(self, name=None): + name = self.NAME if name is None else name + super().__init__(name=self.NAME) + + def absolute_name(self): + return f"deepspeed.ops.transformer.inference.{self.NAME}_op" + + def sources(self): + return [] + + def load(self, verbose=True): + if self.name in __class__._loaded_ops: + return __class__._loaded_ops[self.name] + + from deepspeed.git_version_info import installed_ops # noqa: F401 + if installed_ops.get(self.name, False): + op_module = importlib.import_module(self.absolute_name()) + __class__._loaded_ops[self.name] = op_module + return op_module