diff --git a/llmfoundry/callbacks/monolithic_ckpt_callback.py b/llmfoundry/callbacks/monolithic_ckpt_callback.py
index 71d1a93f7d..afca099832 100644
--- a/llmfoundry/callbacks/monolithic_ckpt_callback.py
+++ b/llmfoundry/callbacks/monolithic_ckpt_callback.py
@@ -72,9 +72,7 @@ def _save_checkpoint(self, state: State, logger: Logger):
         ) if self.upload_to_object_store else contextlib.nullcontext(
             enter_result=save_dir)
         with dir_context_mgr as temp_save_dir:
-            save_path = str(
-                Path(temp_save_dir) /  # type: ignore
-                Path(filename))
+            save_path = str(Path(temp_save_dir) / Path(filename))
             dirname = os.path.dirname(save_path)
             if dirname:
                 os.makedirs(dirname, exist_ok=True)
diff --git a/llmfoundry/data/denoising.py b/llmfoundry/data/denoising.py
index e4b018ab05..302bdc4bc4 100644
--- a/llmfoundry/data/denoising.py
+++ b/llmfoundry/data/denoising.py
@@ -677,8 +677,7 @@ def _sample_span_lengths(total_tokens: int, num_spans: int) -> np.ndarray:
     """
     span_markers = np.less(np.arange(total_tokens - 1),
                            num_spans - 1)[np.random.permutation(total_tokens - 1)]
-    span_start_indicator = np.concatenate([[0],
-                                           span_markers])  # type: ignore
+    span_start_indicator = np.concatenate([np.array([0]), span_markers])
     span_id = np.cumsum(span_start_indicator).reshape(-1, 1)
     spans = np.arange(num_spans).reshape(1, -1)
     span_lengths = np.sum(span_id == spans, axis=0)
@@ -715,13 +714,13 @@ def _apply_mask(tokens: Union[torch.Tensor, Sequence[int], np.ndarray],

         # Ensure there's an end-of-sentence token at the end
         if ensure_eos and (noised_tokens[-1] != eos_token_id):
-            noised_tokens = np.concatenate([noised_tokens,
-                                            [eos_token_id]])  # type: ignore
+            noised_tokens = np.concatenate(
+                [noised_tokens, np.array([eos_token_id])])

         return noised_tokens

     # Masking at previous token
-    prev_token_mask = np.concatenate([[0], mask[:-1]])  # type: ignore
+    prev_token_mask = np.concatenate([np.array([0]), mask[:-1]])

     # Decompose mask into start-of-span mask and non-start-of-span mask
     start_of_noise_span_token = np.logical_and(mask,
@@ -740,8 +739,8 @@ def _apply_mask(tokens: Union[torch.Tensor, Sequence[int], np.ndarray],

         # Ensure there's an end-of-sentence token at the end
         if ensure_eos and (noised_tokens[-1] != eos_token_id):
-            noised_tokens = np.concatenate([noised_tokens,
-                                            [eos_token_id]])  # type: ignore
+            noised_tokens = np.concatenate(
+                [noised_tokens, np.array([eos_token_id])])

         return noised_tokens
diff --git a/llmfoundry/data/finetuning/dataloader.py b/llmfoundry/data/finetuning/dataloader.py
index b5a7420b34..1ed3d56d60 100644
--- a/llmfoundry/data/finetuning/dataloader.py
+++ b/llmfoundry/data/finetuning/dataloader.py
@@ -111,7 +111,7 @@ def build_finetuning_dataloader(cfg: DictConfig,
     _validate_config(cfg.dataset)

     # Use EOS as the pad token if none exists
-    if tokenizer.pad_token is None:  # type: ignore
+    if tokenizer.pad_token is None:
         tokenizer.pad_token = tokenizer.eos_token

     dataset = None  # for pyright
diff --git a/llmfoundry/data/packing.py b/llmfoundry/data/packing.py
index a00e694228..5f157724ce 100644
--- a/llmfoundry/data/packing.py
+++ b/llmfoundry/data/packing.py
@@ -360,9 +360,13 @@ def build_dataloader(cfg: DictConfig, tokenizer: PreTrainedTokenizerBase,
     # build tokenizer
     if 'tokenizer' not in cfg:
         raise ValueError('config must define tokenizer')
-    tokenizer_cfg: Dict[str,
-                        Any] = om.to_container(cfg.tokenizer,
-                                               resolve=True)  # type: ignore
+
+    resolved_tokenizer_cfg = om.to_container(cfg.tokenizer, resolve=True)
+    if not isinstance(resolved_tokenizer_cfg, Dict):
+        raise ValueError(
+            'tokenizer config needs to be resolved by omegaconf into a Dict.')
+    tokenizer_cfg: Dict[Any, Any] = resolved_tokenizer_cfg
+
     tokenizer_name = tokenizer_cfg['name']
     tokenizer_kwargs = tokenizer_cfg.get('kwargs', {})
     tokenizer = build_tokenizer(tokenizer_name, tokenizer_kwargs)
diff --git a/llmfoundry/data/text_data.py b/llmfoundry/data/text_data.py
index 253f90cd9f..4562d3de0a 100644
--- a/llmfoundry/data/text_data.py
+++ b/llmfoundry/data/text_data.py
@@ -5,7 +5,8 @@

 import os
 from itertools import islice
-from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Union
+from typing import (Any, Callable, Dict, List, Mapping, Optional, Sequence,
+                    Union, cast)

 import numpy as np
 import torch
@@ -193,11 +194,12 @@ def __init__(
                 '`bos_token_id` if sequences start with a BOS token.'
             )

-        self.split_token_id = eos_token_id
-        self.bos_mode = False
         if eos_token_id is None:
-            self.split_token_id = bos_token_id
+            self.split_token_id = cast(int, bos_token_id)
             self.bos_mode = True
+        else:
+            self.split_token_id = eos_token_id
+            self.bos_mode = False

     def __call__(self, examples: List[Any]) -> Dict[str, torch.Tensor]:
         batch = self.base_collator(examples)
@@ -206,8 +208,7 @@ def __call__(self, examples: List[Any]) -> Dict[str, torch.Tensor]:

     def get_sequence_id_from_batch(
             self, batch: Dict[str, torch.Tensor]) -> torch.Tensor:
-        is_separator = torch.eq(batch['input_ids'],
-                                self.split_token_id)  # type: ignore
+        is_separator = torch.eq(batch['input_ids'], self.split_token_id)
         cumulative_sep = torch.cumsum(is_separator,
                                       dim=1).to(batch['input_ids'].dtype)
         # If separator token is bos, we're already done
@@ -340,7 +341,9 @@ def build_text_dataloader(
     tokenizer = build_tokenizer(tokenizer_name, tokenizer_kwargs)

     loader = build_text_dataloader(cfg, tokenizer, device_batch_size)
-    tokenizer = loader.dataset.tokenizer  # type: ignore
+    assert isinstance(loader.dataset, StreamingTextDataset)
+    tokenizer = loader.dataset.tokenizer
+
     for batch_ix, batch in enumerate(islice(loader, 5)):
         print('\n')
         print('#' * 20, f'Batch {batch_ix}', '#' * 20)
diff --git a/llmfoundry/models/hf/hf_fsdp.py b/llmfoundry/models/hf/hf_fsdp.py
index 4a0b76e640..56ba24aeff 100644
--- a/llmfoundry/models/hf/hf_fsdp.py
+++ b/llmfoundry/models/hf/hf_fsdp.py
@@ -94,7 +94,12 @@ def hf_get_hidden_layers(model: PreTrainedModel):
         'model.layers',  # LLaMa
         'transformer.blocks',  # MPT
     )
-    return findattr(model, hidden_layers_attrs)
+    layers = findattr(model, hidden_layers_attrs)
+    if layers is None:
+        raise ValueError(
+            f'Unable to find hidden layer for {model}. Model must have one of the following attributes: {hidden_layers_attrs}'
+        )
+    return layers


 def hf_get_init_device(init_device: Optional[str]):
@@ -136,7 +141,7 @@ def prepare_hf_causal_lm_model_for_fsdp(model: PreTrainedModel,
     # OPT has an extra layer of wrapping, so special case here
     if isinstance(causal_base_model, OPTDecoder):
         model.model._fsdp_wrap = False
-    model_block = hf_get_hidden_layers(model)  # type: ignore
+    model_block = hf_get_hidden_layers(model)
     lm_head = model.get_output_embeddings()
     # some models (OPT) implement .get_input_embeddings for the causal subclass
     # but all of them implement it for the base model
@@ -153,7 +158,7 @@ def prepare_hf_causal_lm_model_for_fsdp(model: PreTrainedModel,
         raise ValueError(
             f'Unable to FSDP-wrap this model! `{mod_name}` does not ' +
             'follow common layer/weight naming conventions.')
-    block_type = type(model_block[0])  # type: ignore
+    block_type = type(model_block[0])
     if init_device == 'mixed':
         # For FSDP with models with different device initializations, `mixed`, which
         # initializes the model on rank 0 on `cpu` and on all other ranks on `meta,``
@@ -186,9 +191,9 @@ def prepare_hf_causal_lm_model_for_fsdp(model: PreTrainedModel,
     # These lines ensures that both modules stay together in the top-most block when
     # the model has this tying enabled (almost all do; this property defaults to True)
     if model.config.tie_word_embeddings:
-        causal_base_model._fsdp_wrap = False  # type: ignore
-        tied_embeddings._fsdp_wrap = False  # type: ignore
-        lm_head._fsdp_wrap = False  # type: ignore
+        causal_base_model._fsdp_wrap = False
+        tied_embeddings._fsdp_wrap = False
+        lm_head._fsdp_wrap = False

     # FSDP Wrap and Activation Checkpoint every model block
     model.fsdp_wrap_fn = lambda module: isinstance(module, block_type)
@@ -228,15 +233,15 @@ def prepare_hf_enc_dec_model_for_fsdp(model: PreTrainedModel,
         raise ValueError(
             f'Unable to FSDP-wrap this model! `{mod_name}` does not ' +
             'follow common layer/weight naming conventions.')
-    decoder_block_type = type(decoder_block[0])  # type: ignore
-    encoder_block_type = type(encoder_block[0])  # type: ignore
+    decoder_block_type = type(decoder_block[0])
+    encoder_block_type = type(encoder_block[0])

     if model.config.tie_word_embeddings:
         # it is possible to train an enc/dec without tied embeddings, hence the check
-        tied_embeddings._fsdp_wrap = False  # type: ignore
-        encoder._fsdp_wrap = False  # type: ignore
-        decoder._fsdp_wrap = False  # type: ignore
-        lm_head._fsdp_wrap = False  # type: ignore
+        tied_embeddings._fsdp_wrap = False
+        encoder._fsdp_wrap = False
+        decoder._fsdp_wrap = False
+        lm_head._fsdp_wrap = False

     # FSDP Wrap and Activation Checkpoint every decoder block
     model.fsdp_wrap_fn = lambda module: isinstance(module, decoder_block_type)
diff --git a/llmfoundry/models/layers/attention.py b/llmfoundry/models/layers/attention.py
index 9114bc47aa..982f212493 100644
--- a/llmfoundry/models/layers/attention.py
+++ b/llmfoundry/models/layers/attention.py
@@ -464,7 +464,7 @@ def __init__(
             i * self.head_dim
             for i in range(1, self.n_heads + 2 * self.kv_n_heads)
         ]
-        self.Wqkv._fused = (0, fuse_splits)  # type: ignore
+        self.Wqkv._fused = (0, fuse_splits)

         if self.qk_ln:
             norm_class = NORM_CLASS_REGISTRY[norm_type.lower()]
@@ -499,7 +499,7 @@ def __init__(
             self.d_model,
             **fc_kwargs,
         )
-        self.out_proj._is_residual = True  # type: ignore
+        self.out_proj._is_residual = True

     def forward(
diff --git a/llmfoundry/models/layers/ffn.py b/llmfoundry/models/layers/ffn.py
index a02558102f..0b41a753d9 100644
--- a/llmfoundry/models/layers/ffn.py
+++ b/llmfoundry/models/layers/ffn.py
@@ -40,7 +40,7 @@ def __init__(
             d_model,
             **fc_kwargs,
         )
-        self.down_proj._is_residual = True  # type: ignore
+        self.down_proj._is_residual = True

     def forward(self, x: torch.Tensor):
         return self.down_proj(self.act(self.up_proj(x)))
diff --git a/llmfoundry/models/layers/llama_attention_monkeypatch.py b/llmfoundry/models/layers/llama_attention_monkeypatch.py
index b65a5cb300..0f75986e11 100644
--- a/llmfoundry/models/layers/llama_attention_monkeypatch.py
+++ b/llmfoundry/models/layers/llama_attention_monkeypatch.py
@@ -9,7 +9,8 @@
 from typing import Callable, Optional, Tuple

 import torch
-import torch.functional as F
+import torch.nn.functional as F
+from transformers.models.llama.modeling_llama import LlamaAttention
 from llmfoundry.models.layers.attention import (
     scaled_multihead_dot_product_attention, triton_flash_attn_fn)
@@ -42,8 +43,11 @@ def rotate_half(x: torch.Tensor):
     return torch.cat((-x2, x1), dim=-1)


-def apply_rotary_pos_emb(q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor,
-                         sin: torch.Tensor, position_ids: torch.Tensor):
+def apply_rotary_pos_emb(q: torch.Tensor,
+                         k: torch.Tensor,
+                         cos: torch.Tensor,
+                         sin: torch.Tensor,
+                         position_ids: Optional[torch.Tensor] = None):
     # The first two dimensions of cos and sin are always 1, so we can `squeeze` them.
     cos = cos.squeeze(1).squeeze(0)  # [seq_len, dim]
     sin = sin.squeeze(1).squeeze(0)  # [seq_len, dim]
@@ -65,7 +69,7 @@ def get_llama_attention_patch_fn(patch_fn_name: str = 'torch') -> Callable:


 def llama_attention_patch_torch(
-    self,  # type: ignore
+    self: LlamaAttention,
     hidden_states: torch.Tensor,
     attention_mask: Optional[torch.Tensor] = None,
     position_ids: Optional[torch.LongTensor] = None,
@@ -89,21 +93,19 @@ def llama_attention_patch_torch(
         value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)

         query_states = [
-            F.linear(  # type: ignore (thirdParty)
-                hidden_states, query_slices[i])
+            F.linear(hidden_states, query_slices[i])
             for i in range(self.config.pretraining_tp)
         ]
         query_states = torch.cat(query_states, dim=-1)

         key_states = [
-            F.linear(hidden_states, key_slices[i])  # type: ignore (thirdParty)
+            F.linear(hidden_states, key_slices[i])
             for i in range(self.config.pretraining_tp)
         ]
         key_states = torch.cat(key_states, dim=-1)

         value_states = [
-            F.linear(  # type: ignore (thirdParty)
-                hidden_states, value_slices[i])
+            F.linear(hidden_states, value_slices[i])
             for i in range(self.config.pretraining_tp)
         ]
         value_states = torch.cat(value_states, dim=-1)
@@ -123,9 +125,9 @@ def llama_attention_patch_torch(
     if past_key_value is not None:
         kv_seq_len += past_key_value[0].shape[-2]
     cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
-    query_states, key_states = apply_rotary_pos_emb(
-        query_states, key_states, cos, sin,
-        position_ids)  # type: ignore (thirdParty)
+
+    query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
+                                                    cos, sin, position_ids)

     ### MAIN MODIFICATIONS START HERE ###
     query_states = query_states.transpose(1, 2).view(
@@ -160,21 +162,22 @@ def llama_attention_patch_torch(
                                              self.config.pretraining_tp,
                                              dim=1)
         attn_output = sum([
-            F.linear(  # type: ignore (thirdParty)
-                attn_output[i], o_proj_slices[i])
+            F.linear(attn_output[i], o_proj_slices[i])
             for i in range(self.config.pretraining_tp)
         ])
     else:
         attn_output = self.o_proj(attn_output)

+    assert isinstance(attn_output, torch.Tensor)
+
     if not output_attentions:
         attn_weights = None

-    return attn_output, attn_weights, None  # type: ignore (thirdParty)
+    return attn_output, attn_weights, None


 def llama_attention_patch_triton(
-    self,  # type: ignore
+    self: LlamaAttention,
     hidden_states: torch.Tensor,
     attention_mask: Optional[torch.Tensor] = None,
     position_ids: Optional[torch.LongTensor] = None,
@@ -190,6 +193,7 @@ def llama_attention_patch_triton(
         raise NotImplementedError(
             'output_attentions is not supported when patching Llama attention with triton attention.'
         )
+
     bsz, q_len, _ = hidden_states.size()

     if self.config.pretraining_tp > 1:
@@ -202,21 +206,19 @@ def llama_attention_patch_triton(
         value_slices = self.v_proj.weight.split(key_value_slicing, dim=0)

         query_states = [
-            F.linear(  # type: ignore (thirdParty)
-                hidden_states, query_slices[i])
+            F.linear(hidden_states, query_slices[i])
             for i in range(self.config.pretraining_tp)
         ]
         query_states = torch.cat(query_states, dim=-1)

         key_states = [
-            F.linear(hidden_states, key_slices[i])  # type: ignore (thirdParty)
+            F.linear(hidden_states, key_slices[i])
             for i in range(self.config.pretraining_tp)
         ]
         key_states = torch.cat(key_states, dim=-1)

         value_states = [
-            F.linear(  # type: ignore (thirdParty)
-                hidden_states, value_slices[i])
+            F.linear(hidden_states, value_slices[i])
             for i in range(self.config.pretraining_tp)
         ]
         value_states = torch.cat(value_states, dim=-1)
@@ -236,9 +238,8 @@ def llama_attention_patch_triton(
     if past_key_value is not None:
         kv_seq_len += past_key_value[0].shape[-2]
     cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len)
-    query_states, key_states = apply_rotary_pos_emb(
-        query_states, key_states, cos, sin,
-        position_ids)  # type: ignore (thirdParty)
+    query_states, key_states = apply_rotary_pos_emb(query_states, key_states,
+                                                    cos, sin, position_ids)

     ### MAIN MODIFICATIONS START HERE ###
     query_states = query_states.transpose(1, 2).view(
@@ -273,11 +274,12 @@ def llama_attention_patch_triton(
                                              self.config.pretraining_tp,
                                              dim=1)
         attn_output = sum([
-            F.linear(  # type: ignore (thirdParty)
-                attn_output[i], o_proj_slices[i])
+            F.linear(attn_output[i], o_proj_slices[i])
             for i in range(self.config.pretraining_tp)
         ])
     else:
         attn_output = self.o_proj(attn_output)

-    return attn_output, None, None  # type: ignore (thirdParty)
+    assert isinstance(attn_output, torch.Tensor)
+
+    return attn_output, None, None
diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py
index 5481201a8f..75ebfbd67c 100644
--- a/llmfoundry/models/mpt/modeling_mpt.py
+++ b/llmfoundry/models/mpt/modeling_mpt.py
@@ -307,12 +307,10 @@ def forward(
                      if use_cache is not None else self.config.use_cache)

         if attention_mask is not None:
-            attention_mask = attention_mask.bool(
-            )  # type: ignore (TODO to figure out the right type here)
+            attention_mask = attention_mask.bool()  # type: ignore

         if prefix_mask is not None:
-            prefix_mask = prefix_mask.bool(
-            )  # type: ignore (TODO to figure out the right type here)
+            prefix_mask = prefix_mask.bool()  # type: ignore

         # These args are passed in by keyword in huggingface's generate function
         # https://github.com/huggingface/transformers/blob/68287689f2f0d8b7063c400230b3766987abf18d/src/transformers/generation/utils.py#L2201-L2206
@@ -360,7 +358,7 @@ def forward(
             S <= self.config.max_seq_len
         ), f'Cannot forward input with seq_len={S}, this model only supports seq_len<={self.config.max_seq_len}'

-        tok_emb = self.wte(input_ids)  # type: ignore
+        tok_emb = self.wte(input_ids)
         if self.learned_pos_emb:
             past_position = 0
             if past_key_values is not None:
@@ -397,14 +395,14 @@ def forward(
                     min=0,
                 )

-            pos_emb = self.wpe(pos)  # type: ignore
+            pos_emb = self.wpe(pos)
             x = tok_emb + pos_emb
         else:
             # ALiBi and NoPE use this path (RoPE will also use this path if / when enabled)
             x = tok_emb

         if self.embedding_fraction == 1:
-            x = self.emb_drop(x)  # type: ignore
+            x = self.emb_drop(x)
         else:
             # this implementation is proposed on page 7 of the GLM-130B paper https://arxiv.org/abs/2210.02414
             x_shrunk = (x * self.embedding_fraction) + (
@@ -427,7 +425,7 @@
         all_hidden_states = () if output_hidden_states else None
         all_self_attns = () if output_attentions else None

-        for b_idx, block in enumerate(self.blocks):  # type: ignore
+        for b_idx, block in enumerate(self.blocks):
             if output_hidden_states:
                 assert all_hidden_states is not None  # pyright
                 all_hidden_states = all_hidden_states + (x,)
@@ -447,7 +445,7 @@ def forward(
                 assert all_self_attns is not None  # pyright
                 all_self_attns = all_self_attns + (attn_weights,)

-        x = self.norm_f(x)  # type: ignore
+        x = self.norm_f(x)

         # add hidden states from the last decoder layer
         if output_hidden_states:
@@ -716,7 +714,8 @@ def __init__(
         loss_fn_config = om_model_config.get('loss_fn', 'fused_crossentropy')
         if loss_fn_config == 'fused_crossentropy':
             try:
-                from flash_attn.losses.cross_entropy import CrossEntropyLoss as FusedCrossEntropyLoss  # type: ignore # isort: skip
+                from flash_attn.losses.cross_entropy import \
+                    CrossEntropyLoss as FusedCrossEntropyLoss

                 if hf_config.verbose > 1:
                     warnings.warn('Using Fused Cross Entropy Loss.')
diff --git a/llmfoundry/models/utils/__init__.py b/llmfoundry/models/utils/__init__.py
index 6ec67a5b71..35a15e530a 100644
--- a/llmfoundry/models/utils/__init__.py
+++ b/llmfoundry/models/utils/__init__.py
@@ -7,8 +7,8 @@
     add_bidirectional_mask_if_missing, convert_hf_causal_lm_to_prefix_lm)
 from llmfoundry.models.utils.meta_init_context import (init_empty_weights,
                                                        init_on_device)
-from llmfoundry.models.utils.param_init_fns import (  # type: ignore
-    MODEL_INIT_REGISTRY, generic_param_init_fn_)
+from llmfoundry.models.utils.param_init_fns import (MODEL_INIT_REGISTRY,
+                                                    generic_param_init_fn_)

 __all__ = [
     'AutoTokenizerForMOD',
diff --git a/llmfoundry/models/utils/hf_prefixlm_converter.py b/llmfoundry/models/utils/hf_prefixlm_converter.py
index ae8ed444c8..fb9477d909 100644
--- a/llmfoundry/models/utils/hf_prefixlm_converter.py
+++ b/llmfoundry/models/utils/hf_prefixlm_converter.py
@@ -79,7 +79,7 @@ def _get_attn_modules(model: CAUSAL_GPT_TYPES) -> List[torch.nn.Module]:
     else:
         blocks = model.transformer.h

-    for block in blocks:  # type: ignore
+    for block in blocks:
         if isinstance(model, GPTNeoForCausalLM):
             # Ignore "local" layers in this model type
             if block.attn.attention_type != 'global':
@@ -150,7 +150,7 @@ def call_og_forward():

         if bidirectional_mask is None:
             # This wrapper is a no-op if bidirectional masks are not supplied
-            return call_og_forward()  # type: ignore
+            return call_og_forward()

         assert isinstance(bidirectional_mask, torch.Tensor)
         attn_modules = _get_attn_modules(model)
@@ -158,7 +158,9 @@ def call_og_forward():
         # Handle bidirectional_mask sizing
         # Note: all attn_modules.bias have the same size
         b, s = bidirectional_mask.shape
+
         max_length = attn_modules[0].bias.shape[-1]  # type: ignore
+
         if s > max_length:
             raise ValueError(
                 f'bidirectional_mask sequence length (={s}) exceeds the ' +\
@@ -174,8 +176,9 @@ def call_og_forward():

         # Incorporate the bidirectional mask into the original causal mask
         for attn_module in attn_modules:
-            attn_module.bias.data = torch.logical_or(
-                attn_module.bias.data, bidirectional)  # type: ignore
+            assert isinstance(attn_module.bias, torch.Tensor)
+            attn_module.bias.data = torch.logical_or(attn_module.bias.data,
+                                                     bidirectional)

         # Collect outputs using the model's original forward method
         output = call_og_forward()
@@ -201,7 +204,7 @@ def generate(self: CAUSAL_GPT_TYPES, *args: Any, **kwargs: Any):
             attn_module.bias.data[:] = 1  # type: ignore

         # Collect outputs using the model's original forward method
-        output = self._original_generate(*args, **kwargs)  # type: ignore
+        output = self._original_generate(*args, **kwargs)

         # Reset the masks
         for attn_module in attn_modules:
@@ -330,7 +333,7 @@ def _build_alibi_tensor(
     # and one new argument (`bidirectional_mask`) is added to the signature.
     KeyValueT = Tuple[torch.Tensor, torch.Tensor]

-    def forward(  # type: ignore
+    def transformer_forward(
             self: BloomModel,
             input_ids: Optional[torch.LongTensor] = None,
             past_key_values: Optional[Tuple[KeyValueT, ...]] = None,
@@ -342,8 +345,9 @@ def forward(  # type: ignore
             output_attentions: Optional[bool] = None,
             output_hidden_states: Optional[bool] = None,
             return_dict: Optional[bool] = None,
-            **deprecated_arguments: Any) -> Union[Tuple[
-                torch.Tensor, ...], BaseModelOutputWithPastAndCrossAttentions]:
+            **deprecated_arguments: Any
+    ) -> Union[Tuple[torch.Tensor, ...],
+               BaseModelOutputWithPastAndCrossAttentions]:
         if deprecated_arguments.pop('position_ids', False) is not False:
             # `position_ids` could have been `torch.Tensor` or `None` so
             # defaulting pop to `False` allows to detect if users were
@@ -501,8 +505,8 @@ def custom_forward(*inputs: Any):
             MethodType(_prepare_attn_mask, model.transformer))
     setattr(model.transformer, '_build_alibi_tensor',
             MethodType(_build_alibi_tensor, model.transformer))
-    setattr(model.transformer, 'forward', MethodType(forward,
-                                                     model.transformer))
+    setattr(model.transformer, 'forward',
+            MethodType(transformer_forward, model.transformer))

     # In order to actually use the new argument we've added to
     # model.transformer, we need to update the parent module's `forward` to
diff --git a/llmfoundry/models/utils/meta_init_context.py b/llmfoundry/models/utils/meta_init_context.py
index eee30f9357..c22c226c28 100644
--- a/llmfoundry/models/utils/meta_init_context.py
+++ b/llmfoundry/models/utils/meta_init_context.py
@@ -80,21 +80,27 @@ def init_on_device(device: torch.device, include_buffers: bool = False):
     if include_buffers:
         old_register_buffer = nn.Module.register_buffer

-    def register_empty_parameter(module: torch.nn.Module, name: str,
+    def register_empty_parameter(self: torch.nn.Module, name: str,
                                  param: Optional[torch.nn.Parameter]):
-        old_register_parameter(module, name, param)
+        old_register_parameter(self, name, param)
         if param is not None:
-            param_cls = type(module._parameters[name])
-            kwargs = module._parameters[name].__dict__  # type: ignore
-            module._parameters[name] = param_cls(
-                module._parameters[name].to(device), **kwargs)  # type: ignore
-
-    def register_empty_buffer(module: torch.nn.Module, name: str,
-                              buffer: Optional[torch.Tensor]):
-        old_register_buffer(module, name, buffer)
-        if buffer is not None:
-            module._buffers[name] = module._buffers[name].to(  # type: ignore
-                device)
+            parameter = self._parameters[name]
+            assert parameter is not None
+
+            param_cls = type(parameter)
+            kwargs = parameter.__dict__
+
+            self._parameters[name] = param_cls(parameter.to(device), **kwargs)
+
+    def register_empty_buffer(self: torch.nn.Module,
+                              name: str,
+                              tensor: Optional[torch.Tensor],
+                              persistent: bool = True):
+        old_register_buffer(self, name, tensor, persistent=persistent)
+        if tensor is not None:
+            named_buffer = self._buffers[name]
+            assert named_buffer is not None
+            self._buffers[name] = named_buffer.to(device)

     # Patch tensor creation
     if include_buffers:
@@ -114,9 +120,9 @@ def wrapper(*args: Any, **kwargs: Any):
         return wrapper

     try:
-        nn.Module.register_parameter = register_empty_parameter  # type: ignore
+        nn.Module.register_parameter = register_empty_parameter
         if include_buffers:
-            nn.Module.register_buffer = register_empty_buffer  # type: ignore
+            nn.Module.register_buffer = register_empty_buffer
         for torch_function_name in tensor_constructors_to_patch.keys():
             setattr(
                 torch, torch_function_name,
diff --git a/llmfoundry/models/utils/param_init_fns.py b/llmfoundry/models/utils/param_init_fns.py
index 0dbb4e4a6f..388ef10e29 100644
--- a/llmfoundry/models/utils/param_init_fns.py
+++ b/llmfoundry/models/utils/param_init_fns.py
@@ -29,8 +29,9 @@ def torch_default_param_init_fn_(
         warnings.warn(
             f"Initializing network using module's reset_parameters attribute")

-    if hasattr(module, 'reset_parameters'):
-        module.reset_parameters()  # type: ignore
+    if hasattr(module, 'reset_parameters') and isinstance(
+            module.reset_parameters, Callable):
+        module.reset_parameters()


 def fused_init_helper_(module: nn.Module, init_fn_: Callable):
@@ -46,12 +47,14 @@ def fused_init_helper_(module: nn.Module, init_fn_: Callable):
     if _fused is None:
         raise RuntimeError(f'Internal logic error')

+    assert isinstance(module.weight, torch.Tensor)
+
     dim, splits = _fused
-    splits = (0, *splits, module.weight.size(dim))  # type: ignore
+    splits = (0, *splits, module.weight.size(dim))
     for s, e in zip(splits[:-1], splits[1:]):
-        slice_indices = [slice(None)] * module.weight.ndim  # type: ignore
+        slice_indices = [slice(None)] * module.weight.ndim
         slice_indices[dim] = slice(s, e)
-        init_fn_(module.weight[slice_indices])  # type: ignore
+        init_fn_(module.weight[slice_indices])


 def generic_param_init_fn_(
@@ -82,9 +85,7 @@ def generic_param_init_fn_(
     elif isinstance(init_div_is_residual, float) or isinstance(
            init_div_is_residual, int):
         div_is_residual = init_div_is_residual
-    elif isinstance(
-            init_div_is_residual,  # type: ignore
-            str) and init_div_is_residual.isnumeric():
+    elif init_div_is_residual.isnumeric():
         # do not trust YAML parsing to always convert numbers to numbers
         div_is_residual = float(init_div_is_residual)
     else:
@@ -151,17 +152,17 @@ def generic_param_init_fn_(

         emb_init_fn_(module.weight)

-    elif isinstance(module,
-                    tuple(set(NORM_CLASS_REGISTRY.values()))):  # type: ignore
+    elif isinstance(module, tuple(set(NORM_CLASS_REGISTRY.values()))):
         # Norm
         if verbose > 1:
             warnings.warn(
                 f'Norm weights are set to 1. If norm layer has a bias it is initialized to 0.'
             )
-        if hasattr(module, 'weight') and module.weight is not None:
-            torch.nn.init.ones_(module.weight)  # type: ignore
-        if hasattr(module, 'bias') and module.bias is not None:
-            torch.nn.init.zeros_(module.bias)  # type: ignore
+        if hasattr(module, 'weight') and isinstance(module.weight,
+                                                    torch.Tensor):
+            torch.nn.init.ones_(module.weight)
+        if hasattr(module, 'bias') and isinstance(module.bias, torch.Tensor):
+            torch.nn.init.zeros_(module.bias)

     elif isinstance(module, nn.MultiheadAttention):
         # torch's MultiheadAttention
@@ -199,10 +200,10 @@ def generic_param_init_fn_(
             torch.nn.init.zeros_(module.out_proj.bias)

     elif te is not None and isinstance(module, te.LayerNormMLP):
-        if module.layer_norm_weight is not None:
-            torch.nn.init.ones_(module.layer_norm_weight)  # type: ignore
-        if module.layer_norm_bias is not None:
-            torch.nn.init.zeros_(module.layer_norm_bias)  # type: ignore
+        if isinstance(module.layer_norm_weight, torch.Tensor):
+            torch.nn.init.ones_(module.layer_norm_weight)
+        if isinstance(module.layer_norm_bias, torch.Tensor):
+            torch.nn.init.zeros_(module.layer_norm_bias)

         init_fn_(module.fc1_weight)
         if module.fc1_bias is not None:
diff --git a/llmfoundry/optim/outlier_detection.py b/llmfoundry/optim/outlier_detection.py
index 71ff778166..b485a17c5d 100644
--- a/llmfoundry/optim/outlier_detection.py
+++ b/llmfoundry/optim/outlier_detection.py
@@ -41,7 +41,9 @@ def insert_observation(self, obs: float) -> bool:
         Returns:
             bool: Indicator of whether the most recent observation was an outlier.
         """
-        if len(self.intermediate_data_queue  # type: ignore
+        assert self.intermediate_data_queue.maxlen is not None, 'expected maxlen defined'
+
+        if len(self.intermediate_data_queue
               ) >= self.intermediate_data_queue.maxlen:
             # move data from intermediate queue to slow moving average queue
             intermediate_obs = self.intermediate_data_queue.popleft()
diff --git a/scripts/data_prep/convert_dataset_hf.py b/scripts/data_prep/convert_dataset_hf.py
index 50dd30c45d..01948822c2 100644
--- a/scripts/data_prep/convert_dataset_hf.py
+++ b/scripts/data_prep/convert_dataset_hf.py
@@ -257,7 +257,7 @@ def build_dataloader(dataset: Dataset, batch_size: int,
     if num_workers is None:
         # Multiple workers is only supported on linux machines
         if 'linux' or 'macos' in platform.platform().lower():
-            num_workers = max(1, psutil.cpu_count())  # type: ignore
+            num_workers = max(1, psutil.cpu_count())
         else:
             num_workers = 0
diff --git a/scripts/data_prep/convert_finetuning_dataset.py b/scripts/data_prep/convert_finetuning_dataset.py
index ecccb5f0f7..d2bf8923df 100644
--- a/scripts/data_prep/convert_finetuning_dataset.py
+++ b/scripts/data_prep/convert_finetuning_dataset.py
@@ -96,12 +96,13 @@ def __iter__(self) -> Iterable[Dict[str, bytes]]:
             yield {key: sample[key].encode('utf-8') for key in self.columns}


-def build_dataloader(dataset: SimpleDataset, batch_size: int,
-                     num_workers: int) -> DataLoader:
+def build_dataloader(dataset: SimpleDataset,
+                     batch_size: int,
+                     num_workers: Optional[int] = None) -> DataLoader:
     if num_workers is None:
         # Multiple workers is only supported on linux machines
         if 'linux' in platform.platform().lower():
-            num_workers = max(1, psutil.cpu_count())  # type: ignore
+            num_workers = max(1, psutil.cpu_count())
         else:
             num_workers = 0
diff --git a/scripts/eval/eval.py b/scripts/eval/eval.py
index 31743f1752..3408f31d33 100644
--- a/scripts/eval/eval.py
+++ b/scripts/eval/eval.py
@@ -153,7 +153,7 @@ def evaluate_model(model_cfg: DictConfig, dist_timeout: Union[float, int],
         callbacks=callbacks,
         loggers=loggers,
         precision=precision,
-        fsdp_config=fsdp_config,  # type: ignore
+        fsdp_config=fsdp_config,
         load_path=load_path,
         load_weights_only=True,
         progress_bar=False,
diff --git a/scripts/inference/convert_hf_to_onnx.py b/scripts/inference/convert_hf_to_onnx.py
index dd7a6f7a62..f73836e28f 100644
--- a/scripts/inference/convert_hf_to_onnx.py
+++ b/scripts/inference/convert_hf_to_onnx.py
@@ -138,9 +138,9 @@ def export_to_onnx(
     with torch.no_grad():
         orig_out = model(**sample_input)

-    import onnx  # type: ignore
-    import onnx.checker  # type: ignore
-    import onnxruntime as ort  # type: ignore
+    import onnx
+    import onnx.checker
+    import onnxruntime as ort

     _ = onnx.load(str(output_file))
diff --git a/scripts/inference/hf_generate.py b/scripts/inference/hf_generate.py
index cc331acd14..96592ca477 100644
--- a/scripts/inference/hf_generate.py
+++ b/scripts/inference/hf_generate.py
@@ -293,8 +293,9 @@ def _generate(encoded_inp: Dict[str, torch.Tensor]):
         maybe_synchronize()
         encode_end = time.time()
         input_tokens = torch.sum(
-            encoded_inp['input_ids'] != tokenizer.pad_token_id,
-            axis=1).numpy(force=True)  # type: ignore
+            encoded_inp['input_ids'] !=
+            tokenizer.pad_token_id,  # type: ignore
+            axis=1).numpy(force=True)

         # Warmup
         if args.warmup and (not done_warmup):
diff --git a/scripts/train/train.py b/scripts/train/train.py
index 58fa67afe1..95aa22f44c 100644
--- a/scripts/train/train.py
+++ b/scripts/train/train.py
@@ -90,7 +90,7 @@ def validate_config(cfg: DictConfig):
                 '`te.LayerNormMLP` requires has issues with torch._dynamo. ' +
                 'Setting `torch._dynamo.config.suppress_errors = True` and falling back to eager.'
             )
-            torch._dynamo.config.suppress_errors = True  # type: ignore
+            torch._dynamo.config.suppress_errors = True  # type: ignore (third-party)

     if cfg.model.get('load_in_8bit', False):
         raise ValueError(
@@ -530,7 +530,7 @@ def main(cfg: DictConfig) -> Trainer:
         precision=precision,
         algorithms=algorithms,
         device_train_microbatch_size=device_train_microbatch_size,
-        fsdp_config=fsdp_config,  # type: ignore
+        fsdp_config=fsdp_config,
         save_folder=save_folder,
         save_filename=save_filename,
         save_latest_filename=save_latest_filename,
diff --git a/tests/test_hf_v_mpt.py b/tests/test_hf_v_mpt.py
index 22c9241037..82e2d05550 100644
--- a/tests/test_hf_v_mpt.py
+++ b/tests/test_hf_v_mpt.py
@@ -165,7 +165,7 @@ def test_compare_hf_v_mpt(attn_impl: str, dropout: float, alibi: bool,
     print(f'{hf_model_fwd.mean().item() = }\n{model_fwd.mean().item() = }')
     if hf_model_fwd.mean().allclose(model_fwd.mean()):
         warn_msg = f'WARNING: model_fwd ({model_fwd}) and hf_model_fwd ({hf_model_fwd}) are very close at init.'
-        raise warnings.warn(warn_msg)  # type: ignore
+        raise RuntimeError(warn_msg)

     hf_model_statedict = hf_model.state_dict()
diff --git a/tests/test_init_fn.py b/tests/test_init_fn.py
index 9355e7a277..b054bac186 100644
--- a/tests/test_init_fn.py
+++ b/tests/test_init_fn.py
@@ -23,7 +23,7 @@ def __init__(self, cfg: Union[ListConfig, DictConfig]):
         self.fc1 = nn.Linear(cfg.in_features, cfg.out_features, bias=True)
         self.ln_1 = nn.LayerNorm(cfg.out_features)
         self.fc2 = nn.Linear(cfg.out_features, cfg.out_features, bias=True)
-        self.fc2._is_residual = True  # type: ignore
+        self.fc2._is_residual = True

     def forward(self, x: torch.Tensor):
         y = self.ln_1(self.fc1(x))
@@ -76,7 +76,7 @@ def test_fused_init_helper(fused: bool):
     fc = nn.Linear(cfg.in_features, cfg.out_features, bias=True)
     fc.train()
     if fused:
-        fc._fused = (0, (cfg.out_features // 2,))  # type: ignore
+        fc._fused = (0, (cfg.out_features // 2,))

     def init_fn_(weight: torch.Tensor):
         # dummy init based on layer width
@@ -159,18 +159,18 @@ def test_emb_init(emb_init_cfg: Optional[Tuple[str, Union[int, List[int]]]]):

     model.apply(partial(MODEL_INIT_REGISTRY['kaiming_normal_'], **dict_cfg))

+    assert isinstance(model.emb, torch.nn.Embedding)
+
     if dict_cfg.get('emb_init_std') is not None:
         emb_init_std = dict_cfg.get('emb_init_std')
         if emb_init_std == 0:
-            assert (model.emb.weight == 0).all()  # type: ignore
+            assert (model.emb.weight == 0).all()
     elif dict_cfg.get('emb_init_uniform_lim') is not None:
         emb_init_uniform_lim = dict_cfg.get('emb_init_uniform_lim')
         if emb_init_uniform_lim == 0:
-            assert (model.emb.weight == 0).all()  # type: ignore
+            assert (model.emb.weight == 0).all()
         elif isinstance(emb_init_uniform_lim, Sequence):
             assert len(emb_init_uniform_lim) <= 2
             if len(emb_init_uniform_lim
                   ) == 2 and emb_init_uniform_lim[0] == emb_init_uniform_lim[1]:
-                assert isinstance(model.emb, torch.nn.Embedding)
-                assert (model.emb.weight == emb_init_uniform_lim[0]
-                        ).all()  # type: ignore
+                assert (model.emb.weight == emb_init_uniform_lim[0]).all()
diff --git a/tests/test_model.py b/tests/test_model.py
index d26e86b59e..f20381f288 100644
--- a/tests/test_model.py
+++ b/tests/test_model.py
@@ -44,6 +44,12 @@ def get_config(
     return cast(DictConfig, test_cfg)


+def _load_tokenizer_cfg(cfg: DictConfig) -> Dict:
+    config = om.to_container(cfg, resolve=True)
+    assert isinstance(config, Dict)
+    return config
+
+
 def get_objs(conf_path: str = 'scripts/train/yamls/pretrain/testing.yaml'):
     warnings.filterwarnings(
         action='ignore',
@@ -74,9 +80,7 @@ def get_objs(conf_path: str = 'scripts/train/yamls/pretrain/testing.yaml'):
     test_cfg.device_eval_batch_size = 2
     test_cfg.device_train_microbatch_size = 2

-    tokenizer_cfg: Dict[str,
-                        Any] = om.to_container(test_cfg.tokenizer,
-                                               resolve=True)  # type: ignore
+    tokenizer_cfg: Dict[str, Any] = _load_tokenizer_cfg(test_cfg.tokenizer)
     tokenizer = build_tokenizer(test_cfg.tokenizer.name,
                                 tokenizer_cfg.get('kwargs', {}))
@@ -176,7 +180,7 @@ def test_attention_mechanism(batch_size: int = 2):
     # and with 0 everywhere else
     expected_zerod_weights = nn.Transformer.generate_square_subsequent_mask(test_cfg.max_seq_len)\
         .reshape(1, test_cfg.max_seq_len, test_cfg.max_seq_len)
-    expected_zerod_weights = torch.isneginf(  # type: ignore
+    expected_zerod_weights = torch.isneginf(
         torch.cat(batch_size * [expected_zerod_weights]))
     torch_key_padding = torch.cat(  # type: ignore
         test_cfg.max_seq_len *
@@ -225,9 +229,7 @@ def test_full_forward_and_backward_gpt2_small(prefixlm: bool,
     else:
         neo_cfg.model.name = 'hf_causal_lm'

-    tokenizer_cfg: Dict[str,
-                        Any] = om.to_container(neo_cfg.tokenizer,
-                                               resolve=True)  # type: ignore
+    tokenizer_cfg: Dict[str, Any] = _load_tokenizer_cfg(neo_cfg.tokenizer)
     tokenizer = build_tokenizer(neo_cfg.tokenizer.name,
                                 tokenizer_cfg.get('kwargs', {}))
@@ -272,9 +274,7 @@ def test_full_forward_and_backward_t5_small(batch_size: int = 2):
     t5_cfg.device = device
     t5_cfg.max_seq_len = 16

-    tokenizer_cfg: Dict[str,
-                        Any] = om.to_container(t5_cfg.tokenizer,
-                                               resolve=True)  # type: ignore
+    tokenizer_cfg: Dict[str, Any] = _load_tokenizer_cfg(t5_cfg.tokenizer)
     tokenizer = build_tokenizer(t5_cfg.tokenizer.name,
                                 tokenizer_cfg.get('kwargs', {}))
@@ -328,9 +328,7 @@ def test_determinism(attn_impl: str, precision: torch.dtype):
     test_cfg.model.init_device = 'cuda:0'
     test_cfg.device = 'cuda:0'

-    tokenizer_cfg: Dict[str,
-                        Any] = om.to_container(test_cfg.tokenizer,
-                                               resolve=True)  # type: ignore
+    tokenizer_cfg: Dict[str, Any] = _load_tokenizer_cfg(test_cfg.tokenizer)
     tokenizer = build_tokenizer(test_cfg.tokenizer.name,
                                 tokenizer_cfg.get('kwargs', {}))
@@ -379,8 +377,9 @@ def test_loss_fn():
         pytest.skip('Fused cross entropy was not installed')

     # run numerical test in pure fp32
-    torch.backends.cuda.matmul.allow_tf32 = False  # type: ignore (third-party)
-    torch.backends.cudnn.allow_tf32 = False  # type: ignore (third-party)
+    from torch.backends import cuda, cudnn
+    cuda.matmul.allow_tf32 = False
+    cudnn.allow_tf32 = False

     conf_path = 'scripts/train/yamls/pretrain/testing.yaml'
     with open(conf_path) as f:
@@ -397,9 +396,7 @@ def test_loss_fn():

     reproducibility.seed_all(test_cfg.get('global_seed', 42))

-    tokenizer_cfg: Dict[str,
-                        Any] = om.to_container(test_cfg.tokenizer,
-                                               resolve=True)  # type: ignore
+    tokenizer_cfg: Dict[str, Any] = _load_tokenizer_cfg(test_cfg.tokenizer)
     tokenizer = build_tokenizer(test_cfg.tokenizer.name,
                                 tokenizer_cfg.get('kwargs', {}))
@@ -460,9 +457,7 @@ def test_opt_wrapping(prefixlm: bool):
     }
     config = DictConfig(conf)

-    tokenizer_cfg: Dict[str,
-                        Any] = om.to_container(config.tokenizer,
-                                               resolve=True)  # type: ignore
+    tokenizer_cfg: Dict[str, Any] = _load_tokenizer_cfg(config.tokenizer)
     tokenizer = build_tokenizer(config.tokenizer.name,
                                 tokenizer_cfg.get('kwargs', {}))
@@ -1414,9 +1409,7 @@ def test_hf_init(tmp_path: pathlib.Path,
     model = AutoModelForCausalLM.from_pretrained(save_path,
                                                  trust_remote_code=True)

-    tokenizer_cfg: Dict[str,
-                        Any] = om.to_container(test_cfg.tokenizer,
-                                               resolve=True)  # type: ignore
+    tokenizer_cfg: Dict[str, Any] = _load_tokenizer_cfg(test_cfg.tokenizer)
     tokenizer = build_tokenizer(test_cfg.tokenizer.name,
                                 tokenizer_cfg.get('kwargs', {}))
@@ -1466,9 +1459,7 @@ def test_head_dim_8_triton_mqa_attn(batch_size: int = 2):
     )
     test_cfg.device = torch.cuda.current_device()

-    tokenizer_cfg: Dict[str,
-                        Any] = om.to_container(test_cfg.tokenizer,
-                                               resolve=True)  # type: ignore
+    tokenizer_cfg: Dict[str, Any] = _load_tokenizer_cfg(test_cfg.tokenizer)
     tokenizer = build_tokenizer(test_cfg.tokenizer.name,
                                 tokenizer_cfg.get('kwargs', {}))