Fix compile wrapper (#5455)
The compile wrapper now inherits from the user's module class and copies its
__dict__.

This should resolve most of the issues in #5383, except for potential extra user
forward hooks.
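
For context, a minimal, generic sketch of the pattern described above (an illustration only, not the DeepSpeed implementation, which follows in the diff below): a subclass of the wrapped object's own class is created on the fly and the instance `__dict__` is shallow-copied, so the wrapper keeps the original type and shares the original attributes.

```python
def wrap(obj):
    # Dynamically subclass the object's own class so that
    # isinstance(wrapped, type(obj)) remains True after wrapping.
    class Wrapper(obj.__class__):

        def __init__(self, inner):
            # Shallow-copy the instance __dict__: existing attributes (for an
            # nn.Module, its _parameters/_modules/_buffers dicts) are shared
            # with the original object rather than duplicated.
            self.__dict__ = inner.__dict__.copy()
            self.__dict__['wrapped'] = inner

    return Wrapper(obj)
```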

@tohtana @loadams

Co-authored-by: Logan Adams <[email protected]>
Co-authored-by: Olatunji Ruwase <[email protected]>
Co-authored-by: Masahiro Tanaka <[email protected]>
4 people authored May 8, 2024
1 parent 0fc19b6 commit 0b224ed
Showing 3 changed files with 83 additions and 91 deletions.
163 changes: 82 additions & 81 deletions deepspeed/runtime/compiler.py
@@ -83,84 +83,85 @@ def validate_enabled(cls, field_value, values):
return field_value


class CompiledModuleWrapper(torch.nn.Module):

def __init__(self, module, compile_config: Union[CompileConfig, None] = None):
super().__init__()

assert is_compile_supported(), "torch.compile is not supported on this version of PyTorch."

modules = self.__dict__.get('_modules')
modules['wrapped'] = module
self.__dict__['wrapped'] = module
self._is_compiled = False
self._backend = get_backend_fn(compile_config.backend)
self._compile_kwargs = compile_config.kwargs
self._compiler_fn = None

def __getattr__(self, name):
return getattr(self.__dict__['wrapped'], name)

def set_backend(self, backend: Union[str, Callable]):
"""Set the backend for torch.compile.
Args:
backend (Union[str, Callable]): backend name or a function that takes a torch.nn.Module and returns a compiled module.
You can directly pass a function that works as a backend.
See also `backend` field in `CompileConfig` for more details.
"""
self._backend = get_backend_fn(backend)

def set_torch_compile_kwargs(self, kwargs: Dict[str, Union[str, Any]]) -> None:
"""Set kwargs for torch.compile. Kwargs that are set in DeepSpeed config will be overwritten.
You can also pass a backend name with "backend" key to change the backend.
Args:
kwargs (Dict[str, Union[str, Any]]): kwargs passed to torch.compile.
"""

if "backend" in kwargs:
raise ValueError("backend cannot be set as compile kwargs. Use set_backend instead.")
self._compile_kwargs.update(kwargs)

def set_compiler_fn(self, compiler_fn: Callable) -> None:
"""Set a function to be used for compiling the module.
This function should take a torch.nn.Module as input and return a compiled module.
Note that other compile options are ignored when a compiler_fn is set.
Example:
```python
def my_compiler_fn(module: torch.nn.Module):
...
return torch.compile(module, ...)
engine.set_compiler_fn(my_compiler_fn)
```
"""
self._compiler_fn = compiler_fn

def forward(self, *args, **kwargs) -> Any:
if not self.is_compiled:
if self._compiler_fn is None:
self.__dict__['wrapped'] = torch.compile(self.wrapped, backend=self._backend, **self._compile_kwargs)
else:
self.__dict__['wrapped'] = self._compiler_fn(self.wrapped)
self._is_compiled = True

return self.__dict__['wrapped'](*args, **kwargs)

@property
def is_compiled(self) -> bool:
return self._is_compiled

@property
def backend(self) -> Union[str, Callable]:
return self._backend

@property
def torch_compile_kwargs(self) -> Dict[str, Any]:
return self._compile_kwargs

@property
def compiler_fn(self) -> Union[Callable, None]:
return self._compiler_fn
def CompiledModuleWrapper(mod, compile_config: Union[CompileConfig, None] = None):

class wrapper(mod.__class__):

def __init__(self, module, compile_config: Union[CompileConfig, None] = None):
self.__dict__ = module.__dict__.copy()

assert is_compile_supported(), "torch.compile is not supported on this version of PyTorch."

self.__dict__['wrapped'] = module
self._is_compiled = False
self._backend = get_backend_fn(compile_config.backend)
self._compile_kwargs = compile_config.kwargs
self._compiler_fn = None

def set_backend(self, backend: Union[str, Callable]):
"""Set the backend for torch.compile.
Args:
backend (Union[str, Callable]): backend name or a function that takes a torch.nn.Module and returns a compiled module.
You can directly pass a function that works as a backend.
See also `backend` field in `CompileConfig` for more details.
"""
self._backend = get_backend_fn(backend)

def set_torch_compile_kwargs(self, kwargs: Dict[str, Union[str, Any]]) -> None:
"""Set kwargs for torch.compile. Kwargs that are set in DeepSpeed config will be overwritten.
You can also pass a backend name with "backend" key to change the backend.
Args:
kwargs (Dict[str, Union[str, Any]]): kwargs passed to torch.compile.
"""

if "backend" in kwargs:
raise ValueError("backend cannot be set as compile kwargs. Use set_backend instead.")
self._compile_kwargs.update(kwargs)

def set_compiler_fn(self, compiler_fn: Callable) -> None:
"""Set a function to be used for compiling the module.
This function should take a torch.nn.Module as input and return a compiled module.
Note that other compile options are ignored when a compiler_fn is set.
Example:
```python
def my_compiler_fn(module: torch.nn.Module):
...
return torch.compile(module, ...)
engine.set_compiler_fn(my_compiler_fn)
```
"""
self._compiler_fn = compiler_fn

def forward(self, *args, **kwargs) -> Any:
if not self.is_compiled:
if self._compiler_fn is None:
self.__dict__['wrapped'] = torch.compile(self.wrapped,
backend=self._backend,
**self._compile_kwargs)
else:
self.__dict__['wrapped'] = self._compiler_fn(self.wrapped)
self._is_compiled = True

return self.__dict__['wrapped'](*args, **kwargs)

@property
def is_compiled(self) -> bool:
return self._is_compiled

@property
def backend(self) -> Union[str, Callable]:
return self._backend

@property
def torch_compile_kwargs(self) -> Dict[str, Any]:
return self._compile_kwargs

@property
def compiler_fn(self) -> Union[Callable, None]:
return self._compiler_fn

return wrapper(mod, compile_config)
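
For illustration only (not part of the commit), a sketch of why this factory keeps downstream checks working. `MyModel` is a hypothetical user module, and the example assumes `CompileConfig`'s defaults and a PyTorch build where `torch.compile` is available:

```python
import torch

from deepspeed.runtime.compiler import CompileConfig, CompiledModuleWrapper


class MyModel(torch.nn.Module):  # hypothetical user module

    def __init__(self):
        super().__init__()
        self.linear = torch.nn.Linear(8, 8)

    def forward(self, x):
        return self.linear(x)


model = MyModel()
wrapped = CompiledModuleWrapper(model, CompileConfig())

# The dynamic subclass keeps MyModel in its MRO, so type-based checks still pass ...
assert isinstance(wrapped, MyModel)
# ... and the shallow __dict__ copy shares submodules/parameters with the original.
assert wrapped.linear is model.linear

# The first call compiles the wrapped module; later calls reuse the compiled one.
out = wrapped(torch.randn(2, 8))
```

Because the returned object already is an instance of the user's class, the `CompiledModuleWrapper` special case in `DeepSpeedEngine.__getattr__` and the relaxed `PipelineModule` assertion could be dropped in the two diffs below.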
7 changes: 0 additions & 7 deletions deepspeed/runtime/engine.py
@@ -469,13 +469,6 @@ def __getattr__(self, name):
return getattr(self, name)
elif name in dir(_module):
return getattr(_module, name)
elif isinstance(_module, CompiledModuleWrapper):
try:
return getattr(_module, name)
except AttributeError:
raise AttributeError(
f"None of {type(self).__name__}, CompiledModuleWrapper, or the wrapped model has the attribute '{name}'"
)
else:
raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")

4 changes: 1 addition & 3 deletions deepspeed/runtime/pipe/engine.py
@@ -67,9 +67,7 @@ class PipelineEngine(DeepSpeedEngine):

def __init__(self, has_bool_tensors=False, *super_args, **super_kwargs):
super().__init__(*super_args, **super_kwargs)
assert isinstance(self.module, PipelineModule) \
or (hasattr(self.module, 'wrapped') and isinstance(self.module.wrapped, PipelineModule)), \
"model must base PipelineModule"
assert isinstance(self.module, PipelineModule), "model must base PipelineModule"

assert self.zero_optimization_stage(
) < ZeroStageEnum.gradients, "ZeRO-2 and ZeRO-3 are incompatible with pipeline parallelism"