
Commit

Merge branch 'master' into lyj/enable_phi3_mini
lekurile authored May 8, 2024
2 parents 6bef2f2 + 0b224ed commit da8127d
Showing 7 changed files with 91 additions and 97 deletions.
2 changes: 1 addition & 1 deletion deepspeed/linear/optimized_linear.py
@@ -85,7 +85,7 @@ def __init__(self,
self.bias = bias
self.lora_config = lora_config
self.quantization_config = quantization_config
- device = get_accelerator().current_device() if device is None else device
+ device = get_accelerator().current_device_name() if device is None else device
assert self.lora_config is not None, "DSOptimizedLinear requires a LoRA config"

self.zero_shards = self.lora_config.base_weight_sharding
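
The same `current_device()` → `current_device_name()` substitution recurs in fused_optimizer.py, utils.py, and test_moe.py below. A minimal sketch of why the name form is used, assuming the DeepSpeed accelerator API in which `current_device()` returns a bare device index while `current_device_name()` returns a torch-style device string:

```python
# Sketch only, not code from this commit. Assumes deepspeed is installed and that
# current_device() returns an index (e.g. 0) while current_device_name() returns a
# device string (e.g. "cuda:0" or "cpu") that torch accepts directly as `device=`.
import torch
from deepspeed.accelerator import get_accelerator

device = get_accelerator().current_device_name()  # e.g. "cuda:0"
x = torch.zeros(4, device=device)
print(x.device)
```
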
163 changes: 82 additions & 81 deletions deepspeed/runtime/compiler.py
@@ -83,84 +83,85 @@ def validate_enabled(cls, field_value, values):
return field_value


- class CompiledModuleWrapper(torch.nn.Module):
-
- def __init__(self, module, compile_config: Union[CompileConfig, None] = None):
- super().__init__()
-
- assert is_compile_supported(), "torch.compile is not supported on this version of PyTorch."
-
- modules = self.__dict__.get('_modules')
- modules['wrapped'] = module
- self.__dict__['wrapped'] = module
- self._is_compiled = False
- self._backend = get_backend_fn(compile_config.backend)
- self._compile_kwargs = compile_config.kwargs
- self._compiler_fn = None
-
- def __getattr__(self, name):
- return getattr(self.__dict__['wrapped'], name)
-
- def set_backend(self, backend: Union[str, Callable]):
- """Set the backend for torch.compile.
- Args:
- backend (Union[str, Callable]): backend name or a function that takes a torch.nn.Module and returns a compiled module.
- You can directly pass a function that works as a backend.
- See also `backend` field in `CompileConfig` for more details.
- """
- self._backend = get_backend_fn(backend)
-
- def set_torch_compile_kwargs(self, kwargs: Dict[str, Union[str, Any]]) -> None:
- """Set kwargs for torch.compile. Kwargs that are set in DeepSpeed config will be overwritten.
- You can also pass a backend name with "backend" key to change the backend.
- Args:
- kwargs (Dict[str, Union[str, Any]]): kwargs passed to torch.compile.
- """
-
- if "backend" in kwargs:
- raise ValueError("backend cannot be set as compile kwargs. Use set_backend instead.")
- self._compile_kwargs.update(kwargs)
-
- def set_compiler_fn(self, compiler_fn: Callable) -> None:
- """Set a function to be used for compiling the module.
- This function should take a torch.nn.Module as input and return a compiled module.
- Note that other compile options are ignored when a compiler_fn is set.
- Example:
- ```python
- def my_compiler_fn(module: torch.nn.Module):
- ...
- return torch.compile(module, ...)
- engine.set_compiler_fn(my_compiler_fn)
- ```
- """
- self._compiler_fn = compiler_fn
-
- def forward(self, *args, **kwargs) -> Any:
- if not self.is_compiled:
- if self._compiler_fn is None:
- self.__dict__['wrapped'] = torch.compile(self.wrapped, backend=self._backend, **self._compile_kwargs)
- else:
- self.__dict__['wrapped'] = self._compiler_fn(self.wrapped)
- self._is_compiled = True
-
- return self.__dict__['wrapped'](*args, **kwargs)
-
- @property
- def is_compiled(self) -> bool:
- return self._is_compiled
-
- @property
- def backend(self) -> Union[str, Callable]:
- return self._backend
-
- @property
- def torch_compile_kwargs(self) -> Dict[str, Any]:
- return self._compile_kwargs
-
- @property
- def compiler_fn(self) -> Union[Callable, None]:
- return self._compiler_fn
+ def CompiledModuleWrapper(mod, compile_config: Union[CompileConfig, None] = None):
+
+ class wrapper(mod.__class__):
+
+ def __init__(self, module, compile_config: Union[CompileConfig, None] = None):
+ self.__dict__ = module.__dict__.copy()
+
+ assert is_compile_supported(), "torch.compile is not supported on this version of PyTorch."
+
+ self.__dict__['wrapped'] = module
+ self._is_compiled = False
+ self._backend = get_backend_fn(compile_config.backend)
+ self._compile_kwargs = compile_config.kwargs
+ self._compiler_fn = None
+
+ def set_backend(self, backend: Union[str, Callable]):
+ """Set the backend for torch.compile.
+ Args:
+ backend (Union[str, Callable]): backend name or a function that takes a torch.nn.Module and returns a compiled module.
+ You can directly pass a function that works as a backend.
+ See also `backend` field in `CompileConfig` for more details.
+ """
+ self._backend = get_backend_fn(backend)
+
+ def set_torch_compile_kwargs(self, kwargs: Dict[str, Union[str, Any]]) -> None:
+ """Set kwargs for torch.compile. Kwargs that are set in DeepSpeed config will be overwritten.
+ You can also pass a backend name with "backend" key to change the backend.
+ Args:
+ kwargs (Dict[str, Union[str, Any]]): kwargs passed to torch.compile.
+ """
+
+ if "backend" in kwargs:
+ raise ValueError("backend cannot be set as compile kwargs. Use set_backend instead.")
+ self._compile_kwargs.update(kwargs)
+
+ def set_compiler_fn(self, compiler_fn: Callable) -> None:
+ """Set a function to be used for compiling the module.
+ This function should take a torch.nn.Module as input and return a compiled module.
+ Note that other compile options are ignored when a compiler_fn is set.
+ Example:
+ ```python
+ def my_compiler_fn(module: torch.nn.Module):
+ ...
+ return torch.compile(module, ...)
+ engine.set_compiler_fn(my_compiler_fn)
+ ```
+ """
+ self._compiler_fn = compiler_fn
+
+ def forward(self, *args, **kwargs) -> Any:
+ if not self.is_compiled:
+ if self._compiler_fn is None:
+ self.__dict__['wrapped'] = torch.compile(self.wrapped,
+ backend=self._backend,
+ **self._compile_kwargs)
+ else:
+ self.__dict__['wrapped'] = self._compiler_fn(self.wrapped)
+ self._is_compiled = True
+
+ return self.__dict__['wrapped'](*args, **kwargs)
+
+ @property
+ def is_compiled(self) -> bool:
+ return self._is_compiled
+
+ @property
+ def backend(self) -> Union[str, Callable]:
+ return self._backend
+
+ @property
+ def torch_compile_kwargs(self) -> Dict[str, Any]:
+ return self._compile_kwargs
+
+ @property
+ def compiler_fn(self) -> Union[Callable, None]:
+ return self._compiler_fn
+
+ return wrapper(mod, compile_config)
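
The rewritten wrapper above is a factory: it defines a subclass of the wrapped module's own class, copies the module's `__dict__`, and returns an instance of that subclass. The wrapped engine module therefore still passes `isinstance` checks against its original class, which is what lets the `PipelineEngine` assertion and the engine `__getattr__` special case below be simplified. A self-contained sketch of the pattern, with hypothetical names (`Linearish`, `wrap`):

```python
# Minimal sketch (not the DeepSpeed code) of the pattern used above: wrap an
# object in a dynamically created subclass of its own class, so isinstance()
# checks against the original class keep working after wrapping.
class Linearish:
    def forward(self, x):
        return x * 2

def wrap(obj):
    class Wrapper(obj.__class__):
        def __init__(self, wrapped):
            # Reuse the wrapped object's state instead of re-running its __init__.
            self.__dict__ = wrapped.__dict__.copy()
            self.__dict__['wrapped'] = wrapped

        def forward(self, *args, **kwargs):
            # A real wrapper would compile the wrapped object lazily here.
            return self.__dict__['wrapped'].forward(*args, **kwargs)

    return Wrapper(obj)

m = wrap(Linearish())
assert isinstance(m, Linearish)  # still true, unlike with a separate nn.Module wrapper class
print(m.forward(3))              # 6
```
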
7 changes: 0 additions & 7 deletions deepspeed/runtime/engine.py
@@ -469,13 +469,6 @@ def __getattr__(self, name):
return getattr(self, name)
elif name in dir(_module):
return getattr(_module, name)
- elif isinstance(_module, CompiledModuleWrapper):
- try:
- return getattr(_module, name)
- except AttributeError:
- raise AttributeError(
- f"None of {type(self).__name__}, CompiledModuleWrapper, or the wrapped model has the attribute '{name}'"
- )
else:
raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")

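
Because the wrapper now resolves attributes through the module's own class and `__dict__`, the engine's `__getattr__` only needs its generic fallback to the wrapped module. A minimal sketch of that fallback, with hypothetical names (`Engine`, `Module`):

```python
# Minimal sketch (not the DeepSpeed code): __getattr__ only runs when normal
# attribute lookup fails, so it can forward unknown names to the wrapped module
# without special-casing any particular wrapper type.
class Module:
    def __init__(self):
        self.hidden_size = 128

class Engine:
    def __init__(self, module):
        self.module = module

    def __getattr__(self, name):
        # Only reached for names not found on the Engine itself.
        _module = self.__dict__['module']
        if name in dir(_module):
            return getattr(_module, name)
        raise AttributeError(f"'{type(self).__name__}' object has no attribute '{name}'")

e = Engine(Module())
print(e.hidden_size)  # 128, resolved through the wrapped module
```
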
2 changes: 1 addition & 1 deletion deepspeed/runtime/fp16/fused_optimizer.py
@@ -241,7 +241,7 @@ def _get_norm_mask_idx(self, group):
group_mask_idx_list.append([grad_flat_st_idx, grad_flat_en_idx])
grad_flat_st_idx = grad_flat_en_idx

- return torch.tensor(group_mask_idx_list, device=get_accelerator().current_device())
+ return torch.tensor(group_mask_idx_list, device=get_accelerator().current_device_name())

def step(self, closure=None):
"""
4 changes: 1 addition & 3 deletions deepspeed/runtime/pipe/engine.py
@@ -67,9 +67,7 @@ class PipelineEngine(DeepSpeedEngine):

def __init__(self, has_bool_tensors=False, *super_args, **super_kwargs):
super().__init__(*super_args, **super_kwargs)
- assert isinstance(self.module, PipelineModule) \
- or (hasattr(self.module, 'wrapped') and isinstance(self.module.wrapped, PipelineModule)), \
- "model must base PipelineModule"
+ assert isinstance(self.module, PipelineModule), "model must base PipelineModule"

assert self.zero_optimization_stage(
) < ZeroStageEnum.gradients, "ZeRO-2 and ZeRO-3 are incompatible with pipeline parallelism"
8 changes: 5 additions & 3 deletions deepspeed/runtime/utils.py
@@ -171,7 +171,7 @@ def get_norm_with_moe_layers_fast(all_groups_norm, group):
# This implementation standardizes the grad_norm across ranks. A more precise implementation can be found in 'get_norm_with_moe_layers'.
# Need to allreduce (avg) the norms across different ranks because moe params will not be synced during allreduce
scaled_norm = all_groups_norm * 1.0 / float(dist.get_world_size(group=group))
- scaled_norm_tensor = torch.tensor(scaled_norm, device=get_accelerator().current_device(), dtype=torch.float)
+ scaled_norm_tensor = torch.tensor(scaled_norm, device=get_accelerator().current_device_name(), dtype=torch.float)
dist.all_reduce(scaled_norm_tensor, group=group)
all_groups_norm = scaled_norm_tensor.item()
#print(f"old = {all_groups_norm_old} and new = {all_groups_norm} at rank: {deepspeed.comm.get_rank()}")
@@ -424,9 +424,11 @@ def get_flattened_grad_norm(parameters, norm_type=2, mpu=None, grad_norm_mask=No
# # mask_tensor_ = torch.zeros_like(p, device=p.device, dtype=bool)
# # for mask_idx in grad_norm_mask[idx]:
# # mask_tensor_[mask_idx[0]:mask_idx[1]] = True
- cum_sum_pairs = torch.tensor([1, -1], device=get_accelerator().current_device(),
+ cum_sum_pairs = torch.tensor([1, -1], device=get_accelerator().current_device_name(),
dtype=p.dtype).repeat(grad_norm_mask[idx].shape[0], 1)
- mask_tensor = torch.zeros(p.shape[0] + 1, device=get_accelerator().current_device(), dtype=p.dtype)
+ mask_tensor = torch.zeros(p.shape[0] + 1,
+ device=get_accelerator().current_device_name(),
+ dtype=p.dtype)
mask_tensor = mask_tensor.scatter_(0, grad_norm_mask[idx].view(-1),
cum_sum_pairs.view(-1)).cumsum(0).bool()[:-1]

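
For context on the second hunk above: `get_flattened_grad_norm` converts `[start, end)` index pairs into a boolean mask by scattering +1/-1 at the pair boundaries and taking a cumulative sum, avoiding a Python loop over the pairs. A small worked example of the trick (shapes and values are illustrative, not from the source):

```python
# Worked example of the scatter_ + cumsum mask trick (illustrative values only).
import torch

pairs = torch.tensor([[1, 3], [5, 6]])  # mark indices 1-2 and 5 as True
n = 8
cum_sum_pairs = torch.tensor([1, -1]).repeat(pairs.shape[0], 1)  # +1 at each start, -1 at each end
mask = torch.zeros(n + 1, dtype=torch.long)
mask = mask.scatter_(0, pairs.view(-1), cum_sum_pairs.view(-1)).cumsum(0).bool()[:-1]
print(mask)  # tensor([False,  True,  True, False, False,  True, False, False])
```
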
2 changes: 1 addition & 1 deletion tests/unit/moe/test_moe.py
@@ -177,7 +177,7 @@ class TestTopk(DistributedTest):
world_size = 2

def test(self):
- device = get_accelerator().current_device()
+ device = get_accelerator().current_device_name()
if dist.get_rank() == 0:
logits = torch.rand(2, 2, device=device)
elif dist.get_rank() == 1:
