diff --git a/deepspeed/model_implementations/transformers/ds_transformer.py b/deepspeed/model_implementations/transformers/ds_transformer.py
index 360113b78a3d..8ebc4a6e27a1 100644
--- a/deepspeed/model_implementations/transformers/ds_transformer.py
+++ b/deepspeed/model_implementations/transformers/ds_transformer.py
@@ -158,7 +158,7 @@ def forward(
             input = input.to(target_dtype)
 
         with torch.no_grad():
-            attention_output, key, value, context_outputtn_ctx, inp_norm = \
+            attention_output, kv, context_outputtn_ctx, inp_norm = \
                 self.attention(input,
                                input_mask,
                                head_mask,
@@ -172,7 +172,7 @@ def forward(
                                alibi,
                                **kwargs)
 
-            presents = (key, value)
+            presents = kv  # self.attention now returns a single cache object instead of (key, value)
             self.layer_past = presents if layer_past is None else None
             output = self.mlp(attention_output, input, inp_norm, self.attention.attn_ob)
 
@@ -181,8 +181,8 @@ def forward(
             output = output.to(input_type)
 
         if get_present:
-            output = (output, presents)
+            output = (output, kv)
 
         if self.config.return_single_tuple:
             return (output, )
         elif self.config.return_tuple:
diff --git a/deepspeed/ops/transformer/inference/ds_attention.py b/deepspeed/ops/transformer/inference/ds_attention.py
index 24f710d22494..35965795a93b 100644
--- a/deepspeed/ops/transformer/inference/ds_attention.py
+++ b/deepspeed/ops/transformer/inference/ds_attention.py
@@ -9,6 +9,7 @@
 from deepspeed import comm as dist
 from deepspeed.accelerator import get_accelerator
 from .op_binding import LinearOp, VectorMatMulOp, SoftmaxContextOp, QKVGemmOp, SoftmaxOp
+from transformers import DynamicCache
 
 minus_inf = -10000.0
 
@@ -23,6 +24,7 @@ def __init__(self, config, mp_group=None, q_scales=None, q_groups=1, merge_count
         data_type = self.config.dtype
         data_type_fp = torch.half if self.config.dtype == torch.int8 else self.config.dtype
         self.config.layer_id = DeepSpeedSelfAttention.num_layers
+        self.layer_idx = self.config.layer_id  # per-layer index passed to Cache.update()
         DeepSpeedSelfAttention.num_layers = DeepSpeedSelfAttention.num_layers + 1
         device = get_accelerator().current_device_name()  #if config.bigscience_bloom else 'cpu'
         if self.config.set_empty_params:
@@ -89,7 +91,7 @@ def __init__(self, config, mp_group=None, q_scales=None, q_groups=1, merge_count
             torch.empty(self.hidden_size_per_partition * 3, dtype=data_type_fp, device=device)
         ]
 
-    def compute_attention(self, qkv_out, input_mask, layer_past, alibi, is_prompt, token_idx, position_ids):
+    def compute_attention(self, qkv_out, input_mask, layer_past, alibi, is_prompt, token_idx, position_ids, cache_position):
         if isinstance(qkv_out, list) or isinstance(qkv_out, tuple):
             qkv_out = qkv_out[0]
 
@@ -140,6 +142,7 @@ def forward(self,
                 norm_w=None,
                 norm_b=None,
                 alibi=None,
+                cache_position=None,  # optional cache position tensor forwarded to compute_attention
                 **kwargs):
         if self.attn_qkvw is None:
             self._attn_qkvw, self._attn_qkvb = self._merge_qkv()
@@ -165,20 +168,21 @@ def forward(self,
         token_idx = kwargs.get("token_idx", None)
         position_ids = kwargs.get("position_ids", None)
 
-        context_layer, key_layer, value_layer = self.compute_attention(qkv_out=qkv_out,
+        context_layer, kv_layer = self.compute_attention(qkv_out=qkv_out,
                                                                        input_mask=input_mask,
                                                                        layer_past=layer_past,
                                                                        alibi=alibi,
                                                                        is_prompt=is_prompt,
                                                                        token_idx=token_idx,
-                                                                       position_ids=position_ids)
+                                                                       position_ids=position_ids,
+                                                                       cache_position=cache_position)
 
         output = self.vector_matmul_func(input=context_layer, weight=self.attn_ow)
         inp_norm = qkv_out[-1]
 
         if self.config.mlp_after_attn and self.mp_group is not None and dist.get_world_size(group=self.mp_group) > 1:
             dist.all_reduce(output, group=self.mp_group)
-        return (output, key_layer, value_layer, context_layer, inp_norm)
+        return (output, kv_layer, context_layer, inp_norm)
 
 
 class BloomSelfAttention(DeepSpeedSelfAttention):
@@ -221,7 +225,7 @@ def _split_tensor_along_last_dim(self, tensor, num_partitions, contiguous_split_
 
         return tensor_list
 
-    def compute_attention(self, qkv_out, input_mask, layer_past, alibi, is_prompt, token_idx, position_ids):
+    def compute_attention(self, qkv_out, input_mask, layer_past, alibi, is_prompt, token_idx, position_ids, cache_position):
         if isinstance(qkv_out, list) or isinstance(qkv_out, tuple):
             qkv_out = qkv_out[0]
 
@@ -246,13 +250,13 @@ def compute_attention(self, qkv_out, input_mask, layer_past, alibi, is_prompt, t
         key_layer = key_layer.transpose(1, 2).reshape(output_size[0] * output_size[1], output_size[3],
                                                       -1).transpose(-1, -2)
         value_layer = value_layer.transpose(1, 2).reshape(output_size[0] * output_size[1], output_size[3], -1)
         if layer_past is not None:
-            past_key, past_value = layer_past
-            # concatenate along seq_length dimension -> [batch_size, qk_length, num_heads, head_dim]
-            key_layer = torch.cat((past_key.type_as(key_layer), key_layer), dim=-1)
-            value_layer = torch.cat((past_value.type_as(value_layer), value_layer), dim=-2)
+            # layer_past is now a transformers Cache (e.g. DynamicCache): let it append this
+            # layer's keys/values and hand back the full history. A legacy (key, value)
+            # tuple can still be wrapped via DynamicCache.from_legacy_cache().
+            cache_kwargs = {"cache_position": cache_position}
+            key_layer, value_layer = layer_past.update(key_layer, value_layer, self.layer_idx, cache_kwargs)
 
-        presents = (key_layer, value_layer)
         # Raw attention scores. [batch_size * num_heads, q_length, k_length]
         matmul_result = torch.matmul(query_layer, key_layer)
         # change view to [batch_size, num_heads, q_length, k_length]
@@ -293,9 +297,8 @@ def compute_attention(self, qkv_out, input_mask, layer_past, alibi, is_prompt, t
                                                               context_layer.size(1), context_layer.shape[-1])
         context_layer = self._transpose_for_context(context_layer)
-        key_layer = presents[0]
-        value_layer = presents[1]
 
-        return context_layer, key_layer, value_layer
+        # Return the (updated) cache object so the caller can expose it as the present state.
+        return context_layer, layer_past
 
     ###################### End of HF modeling_bloom addition ########################
 
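
For reference, below is a minimal sketch of the Hugging Face DynamicCache API that the hunks above migrate the Bloom attention path to; it is illustrative only and not part of the patch. It assumes a recent transformers release (DynamicCache ships since v4.36) and uses toy tensors in the standard [batch, heads, seq, head_dim] layout, whereas the DeepSpeed Bloom kernels keep keys in a transposed [batch*heads, head_dim, seq] layout, so the snippet only demonstrates the cache semantics, not the kernel data layout.

    import torch
    from transformers import DynamicCache

    num_layers, batch, heads, head_dim = 2, 1, 4, 8
    cache = DynamicCache()

    # Decode three tokens: update() appends along the sequence dim (dim=-2)
    # and returns the full key/value history for the given layer.
    for step in range(3):
        for layer_idx in range(num_layers):
            k = torch.randn(batch, heads, 1, head_dim)
            v = torch.randn(batch, heads, 1, head_dim)
            # cache_kwargs (e.g. {"cache_position": ...}) is accepted but ignored by
            # DynamicCache; static/sliding-window caches do consume it.
            k_all, v_all = cache.update(k, v, layer_idx, {"cache_position": None})
        assert k_all.shape[-2] == step + 1  # history grows by one token per step

    # Interop with the legacy tuple-of-(key, value) format that the removed
    # `presents = (key_layer, value_layer)` code produced:
    legacy = cache.to_legacy_cache()                    # tuple of (key, value) pairs per layer
    roundtrip = DynamicCache.from_legacy_cache(legacy)
    assert roundtrip.get_seq_length() == cache.get_seq_length() == 3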