diff --git a/deepspeed/runtime/compiler.py b/deepspeed/runtime/compiler.py
index b5e4e33425d0..66fe29fbbea2 100644
--- a/deepspeed/runtime/compiler.py
+++ b/deepspeed/runtime/compiler.py
@@ -83,84 +83,85 @@ def validate_enabled(cls, field_value, values):
         return field_value
 
 
-class CompiledModuleWrapper(torch.nn.Module):
-
-    def __init__(self, module, compile_config: Union[CompileConfig, None] = None):
-        super().__init__()
-
-        assert is_compile_supported(), "torch.compile is not supported on this version of PyTorch."
-
-        modules = self.__dict__.get('_modules')
-        modules['wrapped'] = module
-        self.__dict__['wrapped'] = module
-        self._is_compiled = False
-        self._backend = get_backend_fn(compile_config.backend)
-        self._compile_kwargs = compile_config.kwargs
-        self._compiler_fn = None
-
-    def __getattr__(self, name):
-        return getattr(self.__dict__['wrapped'], name)
-
-    def set_backend(self, backend: Union[str, Callable]):
-        """Set the backend for torch.compile.
-
-        Args:
-            backend (Union[str, Callable]): backend name or a function that takes a torch.nn.Module and returns a compiled module.
-                You can directly pass a function that works as a backend.
-                See also `backend` field in `CompileConfig` for more details.
-        """
-        self._backend = get_backend_fn(backend)
-
-    def set_torch_compile_kwargs(self, kwargs: Dict[str, Union[str, Any]]) -> None:
-        """Set kwargs for torch.compile. Kwargs that are set in DeepSpeed config will be overwritten.
-        You can also pass a backend name with "backend" key to change the backend.
-
-        Args:
-            kwargs (Dict[str, Union[str, Any]]): kwargs passed to torch.compile.
-        """
-
-        if "backend" in kwargs:
-            raise ValueError("backend cannot be set as compile kwargs. Use set_backend instead.")
-        self._compile_kwargs.update(kwargs)
-
-    def set_compiler_fn(self, compiler_fn: Callable) -> None:
-        """Set a function to be used for compiling the module.
-        This function should take a torch.nn.Module as input and return a compiled module.
-        Note that other compile options are ignored when a compiler_fn is set.
-
-        Example:
-        ```python
-            def my_compiler_fn(module: torch.nn.Module):
-                ...
-                return torch.compile(module, ...)
-
-            engine.set_compiler_fn(my_compiler_fn)
-        ```
-        """
-        self._compiler_fn = compiler_fn
-
-    def forward(self, *args, **kwargs) -> Any:
-        if not self.is_compiled:
-            if self._compiler_fn is None:
-                self.__dict__['wrapped'] = torch.compile(self.wrapped, backend=self._backend, **self._compile_kwargs)
-            else:
-                self.__dict__['wrapped'] = self._compiler_fn(self.wrapped)
-            self._is_compiled = True
-
-        return self.__dict__['wrapped'](*args, **kwargs)
-
-    @property
-    def is_compiled(self) -> bool:
-        return self._is_compiled
-
-    @property
-    def backend(self) -> Union[str, Callable]:
-        return self._backend
-
-    @property
-    def torch_compile_kwargs(self) -> Dict[str, Any]:
-        return self._compile_kwargs
-
-    @property
-    def compiler_fn(self) -> Union[Callable, None]:
-        return self._compiler_fn
+def CompiledModuleWrapper(mod, compile_config: Union[CompileConfig, None] = None):
+
+    class wrapper(mod.__class__):
+
+        def __init__(self, module, compile_config: Union[CompileConfig, None] = None):
+            self.__dict__ = module.__dict__.copy()
+
+            assert is_compile_supported(), "torch.compile is not supported on this version of PyTorch."
+
+            self.__dict__['wrapped'] = module
+            self._is_compiled = False
+            self._backend = get_backend_fn(compile_config.backend)
+            self._compile_kwargs = compile_config.kwargs
+            self._compiler_fn = None
+
+        def set_backend(self, backend: Union[str, Callable]):
+            """Set the backend for torch.compile.
+
+            Args:
+                backend (Union[str, Callable]): backend name or a function that takes a torch.nn.Module and returns a compiled module.
+                    You can directly pass a function that works as a backend.
+                    See also `backend` field in `CompileConfig` for more details.
+            """
+            self._backend = get_backend_fn(backend)
+
+        def set_torch_compile_kwargs(self, kwargs: Dict[str, Union[str, Any]]) -> None:
+            """Set kwargs for torch.compile. Kwargs that are set in DeepSpeed config will be overwritten.
+            You can also pass a backend name with "backend" key to change the backend.
+
+            Args:
+                kwargs (Dict[str, Union[str, Any]]): kwargs passed to torch.compile.
+            """
+
+            if "backend" in kwargs:
+                raise ValueError("backend cannot be set as compile kwargs. Use set_backend instead.")
+            self._compile_kwargs.update(kwargs)
+
+        def set_compiler_fn(self, compiler_fn: Callable) -> None:
+            """Set a function to be used for compiling the module.
+            This function should take a torch.nn.Module as input and return a compiled module.
+            Note that other compile options are ignored when a compiler_fn is set.
+
+            Example:
+            ```python
+                def my_compiler_fn(module: torch.nn.Module):
+                    ...
+                    return torch.compile(module, ...)
+
+                engine.set_compiler_fn(my_compiler_fn)
+            ```
+            """
+            self._compiler_fn = compiler_fn
+
+        def forward(self, *args, **kwargs) -> Any:
+            if not self.is_compiled:
+                if self._compiler_fn is None:
+                    self.__dict__['wrapped'] = torch.compile(self.wrapped,
+                                                             backend=self._backend,
+                                                             **self._compile_kwargs)
+                else:
+                    self.__dict__['wrapped'] = self._compiler_fn(self.wrapped)
+                self._is_compiled = True
+
+            return self.__dict__['wrapped'](*args, **kwargs)
+
+        @property
+        def is_compiled(self) -> bool:
+            return self._is_compiled
+
+        @property
+        def backend(self) -> Union[str, Callable]:
+            return self._backend
+
+        @property
+        def torch_compile_kwargs(self) -> Dict[str, Any]:
+            return self._compile_kwargs
+
+        @property
+        def compiler_fn(self) -> Union[Callable, None]:
+            return self._compiler_fn
+
+    return wrapper(mod, compile_config)
diff --git a/deepspeed/runtime/engine.py b/deepspeed/runtime/engine.py
index 9a2b943b0992..34263444c1b7 100644
--- a/deepspeed/runtime/engine.py
+++ b/deepspeed/runtime/engine.py
@@ -469,13 +469,6 @@ def __getattr__(self, name):
             return getattr(self, name)
         elif name in dir(_module):
             return getattr(_module, name)
-        elif isinstance(_module, CompiledModuleWrapper):
-            try:
-                return getattr(_module, name)
-            except AttributeError:
-                raise AttributeError(
-                    f"None of {type(self).__name__}, CompiledModuleWrapper, or the wrapped model has the attribute '{name}'"
-                )
         else:
             raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")
 
diff --git a/deepspeed/runtime/pipe/engine.py b/deepspeed/runtime/pipe/engine.py
index 1dda7f1aad32..be8fe1a368c6 100644
--- a/deepspeed/runtime/pipe/engine.py
+++ b/deepspeed/runtime/pipe/engine.py
@@ -67,9 +67,7 @@ class PipelineEngine(DeepSpeedEngine):
     def __init__(self, has_bool_tensors=False, *super_args, **super_kwargs):
         super().__init__(*super_args, **super_kwargs)
-        assert isinstance(self.module, PipelineModule) \
-            or (hasattr(self.module, 'wrapped') and isinstance(self.module.wrapped, PipelineModule)), \
-            "model must base PipelineModule"
+        assert isinstance(self.module, PipelineModule), "model must base PipelineModule"
 
         assert self.zero_optimization_stage(
         ) < ZeroStageEnum.gradients, "ZeRO-2 and ZeRO-3 are incompatible with pipeline parallelism"
 
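
A minimal, self-contained sketch of the dynamic-subclass pattern the new `CompiledModuleWrapper` factory above relies on (illustrative only, not part of this change; `ToyModel` and `compile_wrapper_sketch` are hypothetical names, and `torch.compile` assumes PyTorch 2.0+). Because the wrapper class inherits from the wrapped model's own class and copies its `__dict__`, `isinstance` checks such as the simplified `PipelineModule` assertion continue to pass without special-casing the wrapper.

```python
import torch


class ToyModel(torch.nn.Module):

    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(4, 4)

    def forward(self, x):
        return self.linear(x)


def compile_wrapper_sketch(mod):
    # Hypothetical helper mirroring the factory pattern, not the DeepSpeed API.

    class wrapper(mod.__class__):

        def __init__(self, module):
            # Share state with the original module instead of calling super().__init__(),
            # so parameters, buffers, and submodules stay the same objects.
            self.__dict__ = module.__dict__.copy()
            self.__dict__['wrapped'] = module  # dict assignment avoids nn.Module submodule registration
            self._is_compiled = False

        def forward(self, *args, **kwargs):
            # Compile lazily on the first call, then delegate to the compiled module.
            if not self._is_compiled:
                self.__dict__['wrapped'] = torch.compile(self.wrapped)
                self._is_compiled = True
            return self.__dict__['wrapped'](*args, **kwargs)

    return wrapper(mod)


model = compile_wrapper_sketch(ToyModel())
assert isinstance(model, ToyModel)  # isinstance checks against the original class still hold
out = model(torch.randn(2, 4))
```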