
Commit

Fix some type ignores (#589)
hanlint authored Sep 12, 2023
1 parent e5c243c commit f03276d
Showing 24 changed files with 182 additions and 166 deletions.
4 changes: 1 addition & 3 deletions llmfoundry/callbacks/monolithic_ckpt_callback.py
@@ -72,9 +72,7 @@ def _save_checkpoint(self, state: State, logger: Logger):
) if self.upload_to_object_store else contextlib.nullcontext(
enter_result=save_dir)
with dir_context_mgr as temp_save_dir:
save_path = str(
Path(temp_save_dir) / # type: ignore
Path(filename))
save_path = str(Path(temp_save_dir) / Path(filename))
dirname = os.path.dirname(save_path)
if dirname:
os.makedirs(dirname, exist_ok=True)
13 changes: 6 additions & 7 deletions llmfoundry/data/denoising.py
@@ -677,8 +677,7 @@ def _sample_span_lengths(total_tokens: int, num_spans: int) -> np.ndarray:
"""
span_markers = np.less(np.arange(total_tokens - 1), num_spans -
1)[np.random.permutation(total_tokens - 1)]
span_start_indicator = np.concatenate([[0],
span_markers]) # type: ignore
span_start_indicator = np.concatenate([np.array([0]), span_markers])
span_id = np.cumsum(span_start_indicator).reshape(-1, 1)
spans = np.arange(num_spans).reshape(1, -1)
span_lengths = np.sum(span_id == spans, axis=0)
@@ -715,13 +714,13 @@ def _apply_mask(tokens: Union[torch.Tensor, Sequence[int], np.ndarray],

# Ensure there's an end-of-sentence token at the end
if ensure_eos and (noised_tokens[-1] != eos_token_id):
noised_tokens = np.concatenate([noised_tokens,
[eos_token_id]]) # type: ignore
noised_tokens = np.concatenate(
[noised_tokens, np.array([eos_token_id])])

return noised_tokens

# Masking at previous token
prev_token_mask = np.concatenate([[0], mask[:-1]]) # type: ignore
prev_token_mask = np.concatenate([np.array([0]), mask[:-1]])

# Decompose mask into start-of-span mask and non-start-of-span mask
start_of_noise_span_token = np.logical_and(mask,
Expand All @@ -740,8 +739,8 @@ def _apply_mask(tokens: Union[torch.Tensor, Sequence[int], np.ndarray],

# Ensure there's an end-of-sentence token at the end
if ensure_eos and (noised_tokens[-1] != eos_token_id):
noised_tokens = np.concatenate([noised_tokens,
[eos_token_id]]) # type: ignore
noised_tokens = np.concatenate(
[noised_tokens, np.array([eos_token_id])])
return noised_tokens
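
Both hunks above drop their `# type: ignore` by wrapping the bare Python list in `np.array` before concatenation, so `np.concatenate` receives a homogeneous sequence of arrays. A minimal, runnable sketch of that pattern (token values are hypothetical):

```python
import numpy as np

tokens = np.array([5, 17, 42], dtype=np.int64)
eos_token_id = 2

# Mixing a raw list into the sequence, e.g. np.concatenate([tokens, [eos_token_id]]),
# works at runtime but trips static type checkers; wrapping it in np.array does not.
tokens_with_eos = np.concatenate([tokens, np.array([eos_token_id])])
print(tokens_with_eos)  # [ 5 17 42  2]
```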


2 changes: 1 addition & 1 deletion llmfoundry/data/finetuning/dataloader.py
@@ -111,7 +111,7 @@ def build_finetuning_dataloader(cfg: DictConfig,
_validate_config(cfg.dataset)

# Use EOS as the pad token if none exists
if tokenizer.pad_token is None: # type: ignore
if tokenizer.pad_token is None:
tokenizer.pad_token = tokenizer.eos_token

dataset = None # for pyright
10 changes: 7 additions & 3 deletions llmfoundry/data/packing.py
@@ -360,9 +360,13 @@ def build_dataloader(cfg: DictConfig, tokenizer: PreTrainedTokenizerBase,
# build tokenizer
if 'tokenizer' not in cfg:
raise ValueError('config must define tokenizer')
tokenizer_cfg: Dict[str,
Any] = om.to_container(cfg.tokenizer,
resolve=True) # type: ignore

resolved_tokenizer_cfg = om.to_container(cfg.tokenizer, resolve=True)
if not isinstance(resolved_tokenizer_cfg, Dict):
raise ValueError(
'tokenizer config needs to be resolved by omegaconf into a Dict.')
tokenizer_cfg: Dict[Any, Any] = resolved_tokenizer_cfg

tokenizer_name = tokenizer_cfg['name']
tokenizer_kwargs = tokenizer_cfg.get('kwargs', {})
tokenizer = build_tokenizer(tokenizer_name, tokenizer_kwargs)
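
A minimal sketch of the narrowing pattern added above: `om.to_container` can return several container types, so an explicit `isinstance` check that raises otherwise convinces the type checker the result is a dict, with no `# type: ignore`. The config values here are hypothetical.

```python
from typing import Any, Dict

from omegaconf import OmegaConf as om

cfg = om.create(
    {'tokenizer': {'name': 'gpt2', 'kwargs': {'model_max_length': 2048}}})

resolved = om.to_container(cfg.tokenizer, resolve=True)
if not isinstance(resolved, dict):
    raise ValueError('tokenizer config must resolve to a dict')
tokenizer_cfg: Dict[Any, Any] = resolved  # narrowed; no ignore needed

print(tokenizer_cfg['name'])  # gpt2
```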
17 changes: 10 additions & 7 deletions llmfoundry/data/text_data.py
@@ -5,7 +5,8 @@

import os
from itertools import islice
from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Union
from typing import (Any, Callable, Dict, List, Mapping, Optional, Sequence,
Union, cast)

import numpy as np
import torch
@@ -193,11 +194,12 @@ def __init__(
'`bos_token_id` if sequences start with a BOS token.'
)

self.split_token_id = eos_token_id
self.bos_mode = False
if eos_token_id is None:
self.split_token_id = bos_token_id
self.split_token_id = cast(int, bos_token_id)
self.bos_mode = True
else:
self.split_token_id = eos_token_id
self.bos_mode = False

def __call__(self, examples: List[Any]) -> Dict[str, torch.Tensor]:
batch = self.base_collator(examples)
@@ -206,8 +208,7 @@ def __call__(self, examples: List[Any]) -> Dict[str, torch.Tensor]:

def get_sequence_id_from_batch(
self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
is_separator = torch.eq(batch['input_ids'],
self.split_token_id) # type: ignore
is_separator = torch.eq(batch['input_ids'], self.split_token_id)
cumulative_sep = torch.cumsum(is_separator,
dim=1).to(batch['input_ids'].dtype)
# If separator token is bos, we're already done
@@ -340,7 +341,9 @@ def build_text_dataloader(
tokenizer = build_tokenizer(tokenizer_name, tokenizer_kwargs)

loader = build_text_dataloader(cfg, tokenizer, device_batch_size)
tokenizer = loader.dataset.tokenizer # type: ignore
assert isinstance(loader.dataset, StreamingTextDataset)
tokenizer = loader.dataset.tokenizer

for batch_ix, batch in enumerate(islice(loader, 5)):
print('\n')
print('#' * 20, f'Batch {batch_ix}', '#' * 20)
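
The diff above swaps `# type: ignore` for two narrowing tools: `typing.cast` (a no-op at runtime) where earlier validation already guarantees the value is not None, and `assert isinstance(...)` where a cheap runtime check is acceptable. A short sketch under those assumptions, with hypothetical names:

```python
from typing import Optional, cast


def pick_split_token(eos_token_id: Optional[int],
                     bos_token_id: Optional[int]) -> int:
    # The caller is assumed to have validated that at least one id is set,
    # so the cast merely records that fact for the type checker.
    if eos_token_id is None:
        return cast(int, bos_token_id)
    return eos_token_id


split_id = pick_split_token(eos_token_id=None, bos_token_id=1)
assert isinstance(split_id, int)  # runtime narrowing, like the dataset check above
print(split_id)  # 1
```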
29 changes: 17 additions & 12 deletions llmfoundry/models/hf/hf_fsdp.py
@@ -94,7 +94,12 @@ def hf_get_hidden_layers(model: PreTrainedModel):
'model.layers', # LLaMa
'transformer.blocks', # MPT
)
return findattr(model, hidden_layers_attrs)
layers = findattr(model, hidden_layers_attrs)
if layers is None:
raise ValueError(
f'Unable to find hidden layer for {model}. Model must have one of the following attributes: {hidden_layers_attrs}'
)
return layers


def hf_get_init_device(init_device: Optional[str]):
@@ -136,7 +141,7 @@ def prepare_hf_causal_lm_model_for_fsdp(model: PreTrainedModel,
# OPT has an extra layer of wrapping, so special case here
if isinstance(causal_base_model, OPTDecoder):
model.model._fsdp_wrap = False
model_block = hf_get_hidden_layers(model) # type: ignore
model_block = hf_get_hidden_layers(model)
lm_head = model.get_output_embeddings()
# some models (OPT) implement .get_input_embeddings for the causal subclass
# but all of them implement it for the base model
Expand All @@ -153,7 +158,7 @@ def prepare_hf_causal_lm_model_for_fsdp(model: PreTrainedModel,
raise ValueError(
f'Unable to FSDP-wrap this model! `{mod_name}` does not ' +
'follow common layer/weight naming conventions.')
block_type = type(model_block[0]) # type: ignore
block_type = type(model_block[0])
if init_device == 'mixed':
# For FSDP with models with different device initializations, `mixed`, which
# initializes the model on rank 0 on `cpu` and on all other ranks on `meta,``
@@ -186,9 +191,9 @@ def prepare_hf_causal_lm_model_for_fsdp(model: PreTrainedModel,
# These lines ensures that both modules stay together in the top-most block when
# the model has this tying enabled (almost all do; this property defaults to True)
if model.config.tie_word_embeddings:
causal_base_model._fsdp_wrap = False # type: ignore
tied_embeddings._fsdp_wrap = False # type: ignore
lm_head._fsdp_wrap = False # type: ignore
causal_base_model._fsdp_wrap = False
tied_embeddings._fsdp_wrap = False
lm_head._fsdp_wrap = False

# FSDP Wrap and Activation Checkpoint every model block
model.fsdp_wrap_fn = lambda module: isinstance(module, block_type)
@@ -228,15 +233,15 @@ def prepare_hf_enc_dec_model_for_fsdp(model: PreTrainedModel,
raise ValueError(
f'Unable to FSDP-wrap this model! `{mod_name}` does not ' +
'follow common layer/weight naming conventions.')
decoder_block_type = type(decoder_block[0]) # type: ignore
encoder_block_type = type(encoder_block[0]) # type: ignore
decoder_block_type = type(decoder_block[0])
encoder_block_type = type(encoder_block[0])

if model.config.tie_word_embeddings:
# it is possible to train an enc/dec without tied embeddings, hence the check
tied_embeddings._fsdp_wrap = False # type: ignore
encoder._fsdp_wrap = False # type: ignore
decoder._fsdp_wrap = False # type: ignore
lm_head._fsdp_wrap = False # type: ignore
tied_embeddings._fsdp_wrap = False
encoder._fsdp_wrap = False
decoder._fsdp_wrap = False
lm_head._fsdp_wrap = False

# FSDP Wrap and Activation Checkpoint every decoder block
model.fsdp_wrap_fn = lambda module: isinstance(module, decoder_block_type)
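
The `hf_get_hidden_layers` change above replaces a silenced Optional with an explicit error when no known attribute path matches, which also narrows the return type for callers. A self-contained sketch of that pattern, with a hypothetical `findattr` helper and toy model:

```python
from typing import Any, Optional, Tuple


def findattr(obj: Any, attrs: Tuple[str, ...]) -> Optional[Any]:
    """Return the value at the first dotted attribute path that exists on obj."""
    for attr_path in attrs:
        target = obj
        try:
            for part in attr_path.split('.'):
                target = getattr(target, part)
            return target
        except AttributeError:
            continue
    return None


class ToyModel:

    class transformer:
        blocks = ['block0', 'block1']


layers = findattr(ToyModel, ('model.layers', 'transformer.blocks'))
if layers is None:
    raise ValueError(
        'Unable to find hidden layers; expected one of the known attribute paths.')
print(len(layers))  # 2 -- after the check, layers is known to be non-None
```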
4 changes: 2 additions & 2 deletions llmfoundry/models/layers/attention.py
@@ -464,7 +464,7 @@ def __init__(
i * self.head_dim
for i in range(1, self.n_heads + 2 * self.kv_n_heads)
]
self.Wqkv._fused = (0, fuse_splits) # type: ignore
self.Wqkv._fused = (0, fuse_splits)

if self.qk_ln:
norm_class = NORM_CLASS_REGISTRY[norm_type.lower()]
@@ -499,7 +499,7 @@ def __init__(
self.d_model,
**fc_kwargs,
)
self.out_proj._is_residual = True # type: ignore
self.out_proj._is_residual = True

def forward(
self,
2 changes: 1 addition & 1 deletion llmfoundry/models/layers/ffn.py
@@ -40,7 +40,7 @@ def __init__(
d_model,
**fc_kwargs,
)
self.down_proj._is_residual = True # type: ignore
self.down_proj._is_residual = True

def forward(self, x: torch.Tensor):
return self.down_proj(self.act(self.up_proj(x)))
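
The `_is_residual` assignment above (and the `_fused` / `_is_residual` lines in attention.py) simply lose their ignores: setting an ad-hoc marker attribute on an `nn.Module` is legal at runtime and is meant to be read back later by initialization and wrapping code. A tiny sketch of that tagging pattern on a hypothetical module:

```python
import torch
import torch.nn as nn

proj = nn.Linear(8, 8)
proj._is_residual = True  # ad-hoc marker, consulted later by init/wrapping logic

out = proj(torch.randn(2, 8))
print(out.shape)                             # torch.Size([2, 8])
print(getattr(proj, '_is_residual', False))  # True
```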
56 changes: 29 additions & 27 deletions llmfoundry/models/layers/llama_attention_monkeypatch.py
@@ -9,7 +9,8 @@
from typing import Callable, Optional, Tuple

import torch
import torch.functional as F
import torch.nn.functional as F
from transformers.models.llama.modeling_llama import LlamaAttention

from llmfoundry.models.layers.attention import (
scaled_multihead_dot_product_attention, triton_flash_attn_fn)
@@ -42,8 +43,11 @@ def rotate_half(x: torch.Tensor):
return torch.cat((-x2, x1), dim=-1)


def apply_rotary_pos_emb(q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor,
sin: torch.Tensor, position_ids: torch.Tensor):
def apply_rotary_pos_emb(q: torch.Tensor,
k: torch.Tensor,
cos: torch.Tensor,
sin: torch.Tensor,
position_ids: Optional[torch.Tensor] = None):
# The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
cos = cos.squeeze(1).squeeze(0) # [seq_len, dim]
sin = sin.squeeze(1).squeeze(0) # [seq_len, dim]
@@ -65,7 +69,7 @@ def get_llama_attention_patch_fn(patch_fn_name: str = 'torch') -> Callable:


def llama_attention_patch_torch(
self, # type: ignore
self: LlamaAttention,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
@@ -89,21 +93,19 @@ def llama_attention_patch_torch(
value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)

query_states = [
F.linear( # type: ignore (thirdParty)
hidden_states, query_slices[i])
F.linear(hidden_states, query_slices[i])
for i in range(self.config.pretraining_tp)
]
query_states = torch.cat(query_states, dim=-1)

key_states = [
F.linear(hidden_states, key_slices[i]) # type: ignore (thirdParty)
F.linear(hidden_states, key_slices[i])
for i in range(self.config.pretraining_tp)
]
key_states = torch.cat(key_states, dim=-1)

value_states = [
F.linear( # type: ignore (thirdParty)
hidden_states, value_slices[i])
F.linear(hidden_states, value_slices[i])
for i in range(self.config.pretraining_tp)
]
value_states = torch.cat(value_states, dim=-1)
@@ -123,9 +125,9 @@ def llama_attention_patch_torch(
if past_key_value is not None:
kv_seq_len += past_key_value[0].shape[-2]
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states = apply_rotary_pos_emb(
query_states, key_states, cos, sin,
position_ids) # type: ignore (thirdParty)

query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
cos, sin, position_ids)

### MAIN MODIFICATIONS START HERE ###
query_states = query_states.transpose(1, 2).view(
@@ -160,21 +162,22 @@ def llama_attention_patch_torch(
self.config.pretraining_tp,
dim=1)
attn_output = sum([
F.linear( # type: ignore (thirdParty)
attn_output[i], o_proj_slices[i])
F.linear(attn_output[i], o_proj_slices[i])
for i in range(self.config.pretraining_tp)
])
else:
attn_output = self.o_proj(attn_output)

assert isinstance(attn_output, torch.Tensor)

if not output_attentions:
attn_weights = None

return attn_output, attn_weights, None # type: ignore (thirdParty)
return attn_output, attn_weights, None


def llama_attention_patch_triton(
self, # type: ignore
self: LlamaAttention,
hidden_states: torch.Tensor,
attention_mask: Optional[torch.Tensor] = None,
position_ids: Optional[torch.LongTensor] = None,
@@ -190,6 +193,7 @@ def llama_attention_patch_triton(
raise NotImplementedError(
'output_attentions is not supported when patching Llama attention with triton attention.'
)

bsz, q_len, _ = hidden_states.size()

if self.config.pretraining_tp > 1:
@@ -202,21 +206,19 @@ def llama_attention_patch_triton(
value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)

query_states = [
F.linear( # type: ignore (thirdParty)
hidden_states, query_slices[i])
F.linear(hidden_states, query_slices[i])
for i in range(self.config.pretraining_tp)
]
query_states = torch.cat(query_states, dim=-1)

key_states = [
F.linear(hidden_states, key_slices[i]) # type: ignore (thirdParty)
F.linear(hidden_states, key_slices[i])
for i in range(self.config.pretraining_tp)
]
key_states = torch.cat(key_states, dim=-1)

value_states = [
F.linear( # type: ignore (thirdParty)
hidden_states, value_slices[i])
F.linear(hidden_states, value_slices[i])
for i in range(self.config.pretraining_tp)
]
value_states = torch.cat(value_states, dim=-1)
@@ -236,9 +238,8 @@ def llama_attention_patch_triton(
if past_key_value is not None:
kv_seq_len += past_key_value[0].shape[-2]
cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
query_states, key_states = apply_rotary_pos_emb(
query_states, key_states, cos, sin,
position_ids) # type: ignore (thirdParty)
query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
cos, sin, position_ids)

### MAIN MODIFICATIONS START HERE ###
query_states = query_states.transpose(1, 2).view(
@@ -273,11 +274,12 @@ def llama_attention_patch_triton(
self.config.pretraining_tp,
dim=1)
attn_output = sum([
F.linear( # type: ignore (thirdParty)
attn_output[i], o_proj_slices[i])
F.linear(attn_output[i], o_proj_slices[i])
for i in range(self.config.pretraining_tp)
])
else:
attn_output = self.o_proj(attn_output)

return attn_output, None, None # type: ignore (thirdParty)
assert isinstance(attn_output, torch.Tensor)

return attn_output, None, None
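
Besides dropping ignores, the monkeypatch file also corrects its import to `torch.nn.functional`, which is where `linear` is provided (the old `import torch.functional as F` does not expose it). A quick check of the corrected import (shapes are arbitrary):

```python
import torch
import torch.nn.functional as F

x = torch.randn(2, 4)  # (batch, in_features)
w = torch.randn(3, 4)  # (out_features, in_features)
print(F.linear(x, w).shape)  # torch.Size([2, 3])
```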
