
Commit

some cleaning
fxmarty committed Jun 27, 2024
1 parent 3760102 commit 02ac451
Showing 2 changed files with 4 additions and 44 deletions.
@@ -111,7 +111,6 @@ def __init__(
prefix: str,
config,
weights,
-layer_idx,
):
super().__init__()
self.num_heads = config.num_attention_heads
@@ -144,7 +143,6 @@ def __init__(

self.query_key_value = load_attention(config, prefix, weights, index)
self.index = index
-self.layer_idx = layer_idx

o_proj = TensorParallelRowLinear.load(
config,
@@ -165,8 +163,6 @@ def __init__(
0, self.num_key_value_heads, dtype=torch.int32, device=weights.device
).repeat_interleave(self.num_groups)

-self.step = 0
-
def forward(
self,
hidden_states,
@@ -198,18 +194,6 @@ def forward(
# output tensor
attn_output = torch.empty_like(query)

-if self.layer_idx < 4:
-torch.save(query, f"query_states_step{self.step}_layer{self.layer_idx}.pt")
-if cu_seqlen_prefill is not None:
-torch.save(
-torch.select(kv, dim=1, index=0),
-f"key_states_step{self.step}_layer{self.layer_idx}.pt",
-)
-torch.save(
-torch.select(kv, dim=1, index=1),
-f"value_states_step{self.step}_layer{self.layer_idx}.pt",
-)
-
# Prefill
if cu_seqlen_prefill is not None:
# flash attention
@@ -236,14 +220,9 @@
max_s,
)

-attn_output = attn_output.view(-1, self.num_heads * self.head_size)
-if self.layer_idx < 4:
-torch.save(
-attn_output, f"attn_output_step{self.step}_layer{self.layer_idx}.pt"
-)
-
-self.step += 1
-return self.o_proj(attn_output, adapter_data)
+return self.o_proj(
+attn_output.view(-1, self.num_heads * self.head_size), adapter_data
+)


class LlamaMLP(nn.Module):
@@ -342,14 +321,13 @@ def forward(self, hidden_states, adapter_data):


class FlashLlamaLayer(nn.Module):
-def __init__(self, index, prefix, config, weights, layer_idx):
+def __init__(self, index, prefix, config, weights):
super().__init__()
self.self_attn = FlashLlamaAttention(
index=index,
prefix=f"{prefix}.self_attn",
config=config,
weights=weights,
-layer_idx=layer_idx,
)
self.mlp = LlamaMLP(
prefix=f"{prefix}.mlp", config=config, weights=weights, index=index
@@ -422,7 +400,6 @@ def __init__(self, prefix, config, weights):
),
config=config,
weights=weights,
-layer_idx=layer_id,
)
for layer_id in range(config.num_hidden_layers)
]
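Net effect of the changes above: the layer_idx argument, the self.step counter, and the per-layer torch.save debug dumps are dropped from the attention module, and the reshape of the attention output is folded into the final o_proj call. Below is a minimal, standalone sketch of what that reshape does; the sizes are made up for illustration and do not come from this commit.

import torch

# Illustrative sizes only, not taken from the repository.
num_tokens, num_heads, head_size = 5, 32, 128
attn_output = torch.randn(num_tokens, num_heads, head_size)

# The new return statement performs exactly this flatten before the output
# projection: one row per token, with all heads concatenated.
flat = attn_output.view(-1, num_heads * head_size)
assert flat.shape == (num_tokens, num_heads * head_size)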
17 changes: 0 additions & 17 deletions server/text_generation_server/models/flash_causal_lm.py
@@ -1149,23 +1149,6 @@ def forward(
cuda_graph = None

if cu_seqlen_prefill is not None or cuda_graph is None:
logger.info(f"input_ids {input_ids} {input_ids.shape}")
logger.info(f"position_ids {position_ids} {position_ids.shape}")
logger.info(
f"cu_seqlen_prefill {cu_seqlen_prefill} {cu_seqlen_prefill.shape if cu_seqlen_prefill is not None else 'NONE'}"
)
logger.info(
f"kv_cache {type(kv_cache)}, len={len(kv_cache)}, {len(kv_cache[0])}, shape={kv_cache[0][0].shape}"
)
logger.info(
f"block_tables {type(block_tables)} {block_tables.shape} {block_tables}"
)
logger.info(f"slots {type(slots)} {slots.shape} {slots}")
logger.info(f"input_lengths {input_lengths}")
logger.info(f"max_s {max_s}")
logger.info(f"prefill_cache_indices {batch.prefill_cache_indices}")
logger.info(f"lm_head_indices {lm_head_indices}")
logger.info(f"adapter_data {adapter_data}")
logits, speculative_logits = self.model.forward(
input_ids=input_ids,
position_ids=position_ids,
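The second file only loses temporary debug output: the 17 deleted lines printed the batch tensors (input_ids, position_ids, cu_seqlen_prefill, kv_cache, block_tables, slots, and so on) together with their shapes on every non-cuda-graph forward pass. As a rough sketch of the pattern that was removed, using Python's standard logging module and stand-in tensors (the actual logger import and batch objects are not visible in this hunk):

import logging
import torch

logging.basicConfig(level=logging.INFO)
logger = logging.getLogger(__name__)

# Stand-in tensors; the deleted lines logged the real batch tensors instead.
input_ids = torch.tensor([101, 2023, 2003, 102])
position_ids = torch.arange(input_ids.shape[0])

# Same value-plus-shape pattern as the deleted lines, emitted once per forward.
logger.info(f"input_ids {input_ids} {input_ids.shape}")
logger.info(f"position_ids {position_ids} {position_ids.shape}")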

