From c3a13536fdcb4d56bd8d8b441df02a87e4130f1d Mon Sep 17 00:00:00 2001
From: Max Kovalenko
Date: Thu, 8 Feb 2024 16:26:19 +0200
Subject: [PATCH 1/2] Fix attention mask handling in the Hybrid Engine Bloom flow

The Bloom flow in Hybrid Engine applies the same transformation of the input
mask already performed earlier in transformers' BloomModel::forward. This
results in non-convergence of scores, specifically in DeepSpeed Chat on
different accelerators, including CUDA and HPU.

The fix removes the redundant second mask transformation and application,
producing correct convergence.
---
 deepspeed/module_inject/containers/bloom.py          | 1 +
 deepspeed/ops/transformer/inference/ds_attention.py  | 2 +-
 2 files changed, 2 insertions(+), 1 deletion(-)

diff --git a/deepspeed/module_inject/containers/bloom.py b/deepspeed/module_inject/containers/bloom.py
index 05f30eec8d85..c103b17e1559 100644
--- a/deepspeed/module_inject/containers/bloom.py
+++ b/deepspeed/module_inject/containers/bloom.py
@@ -23,6 +23,7 @@ def __init__(self, **kwargs):
 
         # All model specific things should be defined here instead of the base class.
         self.bigscience_bloom = True
+        self.triangular_masking = False
 
     def create_module(self, config=None):
         _config = config if config is not None else self.ds_model_config
diff --git a/deepspeed/ops/transformer/inference/ds_attention.py b/deepspeed/ops/transformer/inference/ds_attention.py
index eb6ce2f75c69..cc29c7bbe283 100644
--- a/deepspeed/ops/transformer/inference/ds_attention.py
+++ b/deepspeed/ops/transformer/inference/ds_attention.py
@@ -255,7 +255,7 @@ def compute_attention(self, qkv_out, input_mask, layer_past, alibi):
             input_mask = input_mask.long()
 
         attention_probs = self.softmax_func(attn_scores=attention_scores,
-                                            attn_mask=((1 - input_mask).to(target_dtype) * minus_inf),
+                                            attn_mask=input_mask.to(target_dtype) * minus_inf,
                                             alibi=alibi,
                                             triangular=(self.config.triangular_masking
                                                         and (attention_scores.shape[-2] > 1)),
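
For context, the double inversion removed by the patch above can be reproduced
with a few standalone lines of PyTorch. This is an illustrative sketch only:
the mask values are made up, and minus_inf stands in for the large negative
bias constant used in ds_attention.py.

import torch

minus_inf = -10000.0  # stand-in for the large negative bias constant

# HF-style padding mask as supplied by the user: 1 = attend, 0 = padding.
hf_mask = torch.tensor([[1, 1, 1, 0]])

# Per the commit message above, the mask arriving from BloomModel::forward
# has already been inverted: 1 marks positions to drop.
mask_from_bloom = 1 - hf_mask                            # [[0, 0, 0, 1]]

# Pre-patch Hybrid Engine behavior: a second inversion re-flips the mask,
# suppressing the real tokens and attending the padded one.
wrong_bias = (1 - mask_from_bloom).float() * minus_inf   # ~[[-10000, -10000, -10000, 0]]

# Post-patch behavior: use the already-inverted mask directly.
right_bias = mask_from_bloom.float() * minus_inf         # ~[[0, 0, 0, -10000]]
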
From e7e6d75920ee0dcc9cd2b0147cf932ac7bcf05c7 Mon Sep 17 00:00:00 2001
From: Max Kovalenko
Date: Thu, 8 Feb 2024 16:26:19 +0200
Subject: [PATCH 2/2] Fix attention mask handling in the Hybrid Engine Bloom flow

The BLOOM flow in Hybrid Engine applies the same transformation of the input
mask already performed earlier in transformers' BloomModel::forward. This
results in non-convergence of scores, specifically in DeepSpeed Chat on
different accelerators, including CUDA and HPU.

An optional config parameter invert_mask is introduced into
DeepSpeedInferenceConfig (True by default), which enables skipping the invert
operation for some transformer implementations, such as BLOOM.
---
 deepspeed/module_inject/containers/bloom.py          | 1 +
 deepspeed/ops/transformer/inference/config.py        | 5 ++++-
 deepspeed/ops/transformer/inference/ds_attention.py  | 4 ++++
 3 files changed, 9 insertions(+), 1 deletion(-)

diff --git a/deepspeed/module_inject/containers/bloom.py b/deepspeed/module_inject/containers/bloom.py
index c103b17e1559..a78ac8120346 100644
--- a/deepspeed/module_inject/containers/bloom.py
+++ b/deepspeed/module_inject/containers/bloom.py
@@ -30,6 +30,7 @@ def create_module(self, config=None):
 
         self.module = DeepSpeedBloomInference(_config, mp_group=self.mp_group)
         self.module.config.scale_attention = self.scale_attention
+        self.module.config.invert_mask = False
         return self.module
 
     def attention_qkv_mp(self, mp_replace, reversed_dim=False):
diff --git a/deepspeed/ops/transformer/inference/config.py b/deepspeed/ops/transformer/inference/config.py
index d5aff4f541f7..9709328cc133 100644
--- a/deepspeed/ops/transformer/inference/config.py
+++ b/deepspeed/ops/transformer/inference/config.py
@@ -43,6 +43,7 @@ class DeepSpeedInferenceConfig(TransformerConfig):
         return_tuple: if True, returns the transformer output as a tuple, otherwise returns as a tensor
         bigscience_bloom: This flag is added temporarily for supporting the BLOOM-176B model architecture.
         use_triton: This flag is to enable triton kernels in inference or not.
+        invert_mask: If True, the attention mask is inverted when passed to attention block.
     """
 
     def __init__(self,
@@ -80,7 +81,8 @@ def __init__(self,
                  use_triton=False,
                  triton_autotune=False,
                  num_kv=-1,
-                 rope_theta=10000):
+                 rope_theta=10000,
+                 invert_mask=True):
         super(DeepSpeedInferenceConfig,
               self).__init__(hidden_size, (intermediate_size if intermediate_size > 0 else 4 * hidden_size), heads,
                              num_hidden_layers)
@@ -116,6 +118,7 @@ def __init__(self,
         self.triton_autotune = triton_autotune
         self.num_kv = num_kv
         self.rope_theta = rope_theta
+        self.invert_mask = invert_mask
 
     @classmethod
     def from_dict(cls, json_object):
diff --git a/deepspeed/ops/transformer/inference/ds_attention.py b/deepspeed/ops/transformer/inference/ds_attention.py
index cc29c7bbe283..56cf3c7b6a2c 100644
--- a/deepspeed/ops/transformer/inference/ds_attention.py
+++ b/deepspeed/ops/transformer/inference/ds_attention.py
@@ -254,6 +254,10 @@ def compute_attention(self, qkv_out, input_mask, layer_past, alibi):
         if input_mask.dtype == torch.bool:
             input_mask = input_mask.long()
 
+        # Invert input_mask per transformer implementation (eg, in BLOOM, it's already inverted)
+        if self.config.invert_mask:
+            input_mask = 1 - input_mask
+
         attention_probs = self.softmax_func(attn_scores=attention_scores,
                                             attn_mask=input_mask.to(target_dtype) * minus_inf,
                                             alibi=alibi,
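
A condensed, standalone sketch of the control flow the second patch introduces.
Only the invert_mask handling inside build_attn_bias mirrors the patched
compute_attention code; the config class below is a stub and build_attn_bias is
a hypothetical helper name used purely for illustration.

import torch

class DeepSpeedInferenceConfig:
    # Stub carrying only the field this patch adds; the real config has many more.
    def __init__(self, invert_mask=True):
        # True (default): DeepSpeed inverts the HF-style mask itself.
        # False: the caller (e.g. the BLOOM container) already passes an inverted mask.
        self.invert_mask = invert_mask

def build_attn_bias(config, input_mask, minus_inf=-10000.0):
    # Mirrors the masking logic in compute_attention in ds_attention.py.
    if input_mask.dtype == torch.bool:
        input_mask = input_mask.long()
    # Invert only when the incoming mask still uses 1 = keep / 0 = drop.
    if config.invert_mask:
        input_mask = 1 - input_mask
    # From here on, 1 marks positions to suppress.
    return input_mask.float() * minus_inf

# Generic HF model: mask arrives as 1 = attend, 0 = padding.
generic = build_attn_bias(DeepSpeedInferenceConfig(invert_mask=True), torch.tensor([[1, 1, 0]]))

# BLOOM container sets invert_mask = False, because the mask it forwards
# is already inverted (1 = drop).
bloom = build_attn_bias(DeepSpeedInferenceConfig(invert_mask=False), torch.tensor([[0, 0, 1]]))

assert torch.equal(generic, bloom)  # both suppress only the padded position

With the default invert_mask=True, other models keep the existing inversion
behavior, while the BLOOM container opts out once in create_module instead of
special-casing the attention kernel call.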