From dc1374816a007488ca9885c92e83e89e6fe933ef Mon Sep 17 00:00:00 2001 From: Irene Dea Date: Mon, 11 Sep 2023 14:17:42 -0700 Subject: [PATCH 1/4] Add git-repo and git-branch params to regressions script (#591) --- .github/workflows/regressions.py | 10 ++++++++-- 1 file changed, 8 insertions(+), 2 deletions(-) diff --git a/.github/workflows/regressions.py b/.github/workflows/regressions.py index ecd48799d0..9211df1908 100644 --- a/.github/workflows/regressions.py +++ b/.github/workflows/regressions.py @@ -13,7 +13,8 @@ def get_configs(cluster: str, mpt_7b_ckpt_path: str, wandb_entity: str, - wandb_project: str): + wandb_project: str, git_repo: str, git_branch: str): + print(f'Running regression tests on {git_repo} {git_branch}.') eval_7b_hf = RunConfig.from_file( os.path.join(REGRESSIONS_DIR, 'eval-7b-hf.yaml')) eval_7b_composer = RunConfig.from_file( @@ -48,6 +49,8 @@ def get_configs(cluster: str, mpt_7b_ckpt_path: str, wandb_entity: str, config.cluster = cluster config.parameters['loggers'] = config.parameters.get('loggers', {}) config.parameters['loggers']['wandb'] = wandb_config + config.integrations[0]['git_repo'] = git_repo + config.integrations[0]['git_branch'] = git_branch return all_configs, [] @@ -58,10 +61,13 @@ def get_configs(cluster: str, mpt_7b_ckpt_path: str, wandb_entity: str, parser.add_argument('--mpt-7b-ckpt-path', type=str) parser.add_argument('--wandb-entity', type=str) parser.add_argument('--wandb-project', type=str) + parser.add_argument('--git-repo', type=str, default='mosaicml/llm-foundry') + parser.add_argument('--git-branch', type=str, default='main') args = parser.parse_args() run_configs, _ = get_configs(args.cluster, args.mpt_7b_ckpt_path, - args.wandb_entity, args.wandb_project) + args.wandb_entity, args.wandb_project, + args.git_repo, args.git_branch) for run_config in run_configs: run = create_run(run_config) From e5c243cb86e3be566008109cae90641c0b760a90 Mon Sep 17 00:00:00 2001 From: Irene Dea Date: Tue, 12 Sep 2023 14:09:34 -0700 Subject: [PATCH 2/4] Fix ComposerHFCausalLM instantiation with PeftModel (#593) * Fix bug in hf_causal_lm, causing errors with evaluating peft models * Move attention patch * Fix typing --- llmfoundry/models/hf/hf_causal_lm.py | 64 ++++++++++++++-------------- 1 file changed, 31 insertions(+), 33 deletions(-) diff --git a/llmfoundry/models/hf/hf_causal_lm.py b/llmfoundry/models/hf/hf_causal_lm.py index a2e2ad3cdc..746499fdfb 100644 --- a/llmfoundry/models/hf/hf_causal_lm.py +++ b/llmfoundry/models/hf/hf_causal_lm.py @@ -16,6 +16,7 @@ LanguageCrossEntropy, LanguagePerplexity) from composer.utils import dist from omegaconf import DictConfig +from torch import nn from transformers import (AutoConfig, AutoModelForCausalLM, PreTrainedTokenizerBase) @@ -28,12 +29,9 @@ try: from peft.peft_model import PeftModel model_types = PeftModel, transformers.PreTrainedModel - _om_model_config_type = Union[DictConfig, PeftModel, - transformers.PreTrainedModel] except ImportError: model_types = transformers.PreTrainedModel - _om_model_config_type = Union[DictConfig, transformers.PreTrainedModel] __all__ = ['ComposerHFCausalLM'] @@ -58,21 +56,10 @@ class ComposerHFCausalLM(HuggingFaceModelWithZLoss): tokenizer (PreTrainedTokenizer): The tokenizer that the model will use. 
""" - def __init__( - self, - om_model_config: _om_model_config_type, # type: ignore - tokenizer: PreTrainedTokenizerBase): - - if not om_model_config.get('trust_remote_code', - True) and om_model_config.get( - 'pretrained_model_name_or_path', - None).startswith('mosaicml/mpt'): - raise ValueError( - 'trust_remote_code must be set to True for MPT models. Without this, the MPT model code will come from the transformers library, ' - + - 'which is not significantly slower and not compatible with the LLM foundry training code, rather than the code release by MosaicML.' - ) - + def __init__(self, om_model_config: Union[DictConfig, + transformers.PreTrainedModel, + nn.Module], + tokenizer: PreTrainedTokenizerBase): # set up training and eval metrics train_metrics = [ LanguageCrossEntropy(), @@ -90,6 +77,15 @@ def __init__( # if we are passed a DictConfig, we need to instantiate the model if isinstance(om_model_config, DictConfig): + if not om_model_config.get('trust_remote_code', + True) and om_model_config.get( + 'pretrained_model_name_or_path', + None).startswith('mosaicml/mpt'): + raise ValueError( + 'trust_remote_code must be set to True for MPT models. Without this, the MPT model code will come from the transformers library, ' + + + 'which is not significantly slower and not compatible with the LLM foundry training code, rather than the code release by MosaicML.' + ) # load the model config trust_remote_code = om_model_config.get('trust_remote_code', True) @@ -181,6 +177,23 @@ def __init__( z_loss = om_model_config.get('z_loss', 0.0) + attention_patch_type = om_model_config.get('attention_patch_type', + None) + if attention_patch_type is not None: + if model.config.model_type != 'llama': + raise ValueError( + f'attention_patch_type is only supported for llama models, but got {model.config.model_type}' + ) + + print( + f'Patching llama attention with {attention_patch_type} attention' + ) + from transformers.models.llama.modeling_llama import \ + LlamaAttention + LlamaAttention.forward = get_llama_attention_patch_fn( + attention_patch_type) + model.config.use_cache = False + # elif the model is either a PeftModel or a PreTrainedModel elif isinstance(om_model_config, model_types): model = om_model_config @@ -193,21 +206,6 @@ def __init__( f'om_model_config must be either a DictConfig, PeftModel, or PreTrainedModel, but got {type(om_model_config)}' ) - attention_patch_type = om_model_config.get('attention_patch_type', None) - if attention_patch_type is not None: - if model.config.model_type != 'llama': - raise ValueError( - f'attention_patch_type is only supported for llama models, but got {model.config.model_type}' - ) - - print( - f'Patching llama attention with {attention_patch_type} attention' - ) - from transformers.models.llama.modeling_llama import LlamaAttention - LlamaAttention.forward = get_llama_attention_patch_fn( - attention_patch_type) - model.config.use_cache = False - composer_model = super().__init__(model=model, shift_labels=True, tokenizer=tokenizer, From f03276d53782de28dbe5e93f7bf9e6b41d09b65a Mon Sep 17 00:00:00 2001 From: Hanlin Tang Date: Tue, 12 Sep 2023 16:06:08 -0700 Subject: [PATCH 3/4] Fix some type ignores (#589) --- .../callbacks/monolithic_ckpt_callback.py | 4 +- llmfoundry/data/denoising.py | 13 ++--- llmfoundry/data/finetuning/dataloader.py | 2 +- llmfoundry/data/packing.py | 10 +++- llmfoundry/data/text_data.py | 17 +++--- llmfoundry/models/hf/hf_fsdp.py | 29 ++++++---- llmfoundry/models/layers/attention.py | 4 +- llmfoundry/models/layers/ffn.py | 2 +- 
.../layers/llama_attention_monkeypatch.py | 56 ++++++++++--------- llmfoundry/models/mpt/modeling_mpt.py | 19 +++---- llmfoundry/models/utils/__init__.py | 4 +- .../models/utils/hf_prefixlm_converter.py | 24 ++++---- llmfoundry/models/utils/meta_init_context.py | 36 +++++++----- llmfoundry/models/utils/param_init_fns.py | 37 ++++++------ llmfoundry/optim/outlier_detection.py | 4 +- scripts/data_prep/convert_dataset_hf.py | 2 +- .../data_prep/convert_finetuning_dataset.py | 7 ++- scripts/eval/eval.py | 2 +- scripts/inference/convert_hf_to_onnx.py | 6 +- scripts/inference/hf_generate.py | 5 +- scripts/train/train.py | 4 +- tests/test_hf_v_mpt.py | 2 +- tests/test_init_fn.py | 14 ++--- tests/test_model.py | 45 ++++++--------- 24 files changed, 182 insertions(+), 166 deletions(-) diff --git a/llmfoundry/callbacks/monolithic_ckpt_callback.py b/llmfoundry/callbacks/monolithic_ckpt_callback.py index 71d1a93f7d..afca099832 100644 --- a/llmfoundry/callbacks/monolithic_ckpt_callback.py +++ b/llmfoundry/callbacks/monolithic_ckpt_callback.py @@ -72,9 +72,7 @@ def _save_checkpoint(self, state: State, logger: Logger): ) if self.upload_to_object_store else contextlib.nullcontext( enter_result=save_dir) with dir_context_mgr as temp_save_dir: - save_path = str( - Path(temp_save_dir) / # type: ignore - Path(filename)) + save_path = str(Path(temp_save_dir) / Path(filename)) dirname = os.path.dirname(save_path) if dirname: os.makedirs(dirname, exist_ok=True) diff --git a/llmfoundry/data/denoising.py b/llmfoundry/data/denoising.py index e4b018ab05..302bdc4bc4 100644 --- a/llmfoundry/data/denoising.py +++ b/llmfoundry/data/denoising.py @@ -677,8 +677,7 @@ def _sample_span_lengths(total_tokens: int, num_spans: int) -> np.ndarray: """ span_markers = np.less(np.arange(total_tokens - 1), num_spans - 1)[np.random.permutation(total_tokens - 1)] - span_start_indicator = np.concatenate([[0], - span_markers]) # type: ignore + span_start_indicator = np.concatenate([np.array([0]), span_markers]) span_id = np.cumsum(span_start_indicator).reshape(-1, 1) spans = np.arange(num_spans).reshape(1, -1) span_lengths = np.sum(span_id == spans, axis=0) @@ -715,13 +714,13 @@ def _apply_mask(tokens: Union[torch.Tensor, Sequence[int], np.ndarray], # Ensure there's an end-of-sentence token at the end if ensure_eos and (noised_tokens[-1] != eos_token_id): - noised_tokens = np.concatenate([noised_tokens, - [eos_token_id]]) # type: ignore + noised_tokens = np.concatenate( + [noised_tokens, np.array([eos_token_id])]) return noised_tokens # Masking at previous token - prev_token_mask = np.concatenate([[0], mask[:-1]]) # type: ignore + prev_token_mask = np.concatenate([np.array([0]), mask[:-1]]) # Decompose mask into start-of-span mask and non-start-of-span mask start_of_noise_span_token = np.logical_and(mask, @@ -740,8 +739,8 @@ def _apply_mask(tokens: Union[torch.Tensor, Sequence[int], np.ndarray], # Ensure there's an end-of-sentence token at the end if ensure_eos and (noised_tokens[-1] != eos_token_id): - noised_tokens = np.concatenate([noised_tokens, - [eos_token_id]]) # type: ignore + noised_tokens = np.concatenate( + [noised_tokens, np.array([eos_token_id])]) return noised_tokens diff --git a/llmfoundry/data/finetuning/dataloader.py b/llmfoundry/data/finetuning/dataloader.py index b5a7420b34..1ed3d56d60 100644 --- a/llmfoundry/data/finetuning/dataloader.py +++ b/llmfoundry/data/finetuning/dataloader.py @@ -111,7 +111,7 @@ def build_finetuning_dataloader(cfg: DictConfig, _validate_config(cfg.dataset) # Use EOS as the pad token if none 
exists - if tokenizer.pad_token is None: # type: ignore + if tokenizer.pad_token is None: tokenizer.pad_token = tokenizer.eos_token dataset = None # for pyright diff --git a/llmfoundry/data/packing.py b/llmfoundry/data/packing.py index a00e694228..5f157724ce 100644 --- a/llmfoundry/data/packing.py +++ b/llmfoundry/data/packing.py @@ -360,9 +360,13 @@ def build_dataloader(cfg: DictConfig, tokenizer: PreTrainedTokenizerBase, # build tokenizer if 'tokenizer' not in cfg: raise ValueError('config must define tokenizer') - tokenizer_cfg: Dict[str, - Any] = om.to_container(cfg.tokenizer, - resolve=True) # type: ignore + + resolved_tokenizer_cfg = om.to_container(cfg.tokenizer, resolve=True) + if not isinstance(resolved_tokenizer_cfg, Dict): + raise ValueError( + 'tokenizer config needs to be resolved by omegaconf into a Dict.') + tokenizer_cfg: Dict[Any, Any] = resolved_tokenizer_cfg + tokenizer_name = tokenizer_cfg['name'] tokenizer_kwargs = tokenizer_cfg.get('kwargs', {}) tokenizer = build_tokenizer(tokenizer_name, tokenizer_kwargs) diff --git a/llmfoundry/data/text_data.py b/llmfoundry/data/text_data.py index 253f90cd9f..4562d3de0a 100644 --- a/llmfoundry/data/text_data.py +++ b/llmfoundry/data/text_data.py @@ -5,7 +5,8 @@ import os from itertools import islice -from typing import Any, Callable, Dict, List, Mapping, Optional, Sequence, Union +from typing import (Any, Callable, Dict, List, Mapping, Optional, Sequence, + Union, cast) import numpy as np import torch @@ -193,11 +194,12 @@ def __init__( '`bos_token_id` if sequences start with a BOS token.' ) - self.split_token_id = eos_token_id - self.bos_mode = False if eos_token_id is None: - self.split_token_id = bos_token_id + self.split_token_id = cast(int, bos_token_id) self.bos_mode = True + else: + self.split_token_id = eos_token_id + self.bos_mode = False def __call__(self, examples: List[Any]) -> Dict[str, torch.Tensor]: batch = self.base_collator(examples) @@ -206,8 +208,7 @@ def __call__(self, examples: List[Any]) -> Dict[str, torch.Tensor]: def get_sequence_id_from_batch( self, batch: Dict[str, torch.Tensor]) -> torch.Tensor: - is_separator = torch.eq(batch['input_ids'], - self.split_token_id) # type: ignore + is_separator = torch.eq(batch['input_ids'], self.split_token_id) cumulative_sep = torch.cumsum(is_separator, dim=1).to(batch['input_ids'].dtype) # If separator token is bos, we're already done @@ -340,7 +341,9 @@ def build_text_dataloader( tokenizer = build_tokenizer(tokenizer_name, tokenizer_kwargs) loader = build_text_dataloader(cfg, tokenizer, device_batch_size) - tokenizer = loader.dataset.tokenizer # type: ignore + assert isinstance(loader.dataset, StreamingTextDataset) + tokenizer = loader.dataset.tokenizer + for batch_ix, batch in enumerate(islice(loader, 5)): print('\n') print('#' * 20, f'Batch {batch_ix}', '#' * 20) diff --git a/llmfoundry/models/hf/hf_fsdp.py b/llmfoundry/models/hf/hf_fsdp.py index 4a0b76e640..56ba24aeff 100644 --- a/llmfoundry/models/hf/hf_fsdp.py +++ b/llmfoundry/models/hf/hf_fsdp.py @@ -94,7 +94,12 @@ def hf_get_hidden_layers(model: PreTrainedModel): 'model.layers', # LLaMa 'transformer.blocks', # MPT ) - return findattr(model, hidden_layers_attrs) + layers = findattr(model, hidden_layers_attrs) + if layers is None: + raise ValueError( + f'Unable to find hidden layer for {model}. 
Model must have one of the following attributes: {hidden_layers_attrs}' + ) + return layers def hf_get_init_device(init_device: Optional[str]): @@ -136,7 +141,7 @@ def prepare_hf_causal_lm_model_for_fsdp(model: PreTrainedModel, # OPT has an extra layer of wrapping, so special case here if isinstance(causal_base_model, OPTDecoder): model.model._fsdp_wrap = False - model_block = hf_get_hidden_layers(model) # type: ignore + model_block = hf_get_hidden_layers(model) lm_head = model.get_output_embeddings() # some models (OPT) implement .get_input_embeddings for the causal subclass # but all of them implement it for the base model @@ -153,7 +158,7 @@ def prepare_hf_causal_lm_model_for_fsdp(model: PreTrainedModel, raise ValueError( f'Unable to FSDP-wrap this model! `{mod_name}` does not ' + 'follow common layer/weight naming conventions.') - block_type = type(model_block[0]) # type: ignore + block_type = type(model_block[0]) if init_device == 'mixed': # For FSDP with models with different device initializations, `mixed`, which # initializes the model on rank 0 on `cpu` and on all other ranks on `meta,`` @@ -186,9 +191,9 @@ def prepare_hf_causal_lm_model_for_fsdp(model: PreTrainedModel, # These lines ensures that both modules stay together in the top-most block when # the model has this tying enabled (almost all do; this property defaults to True) if model.config.tie_word_embeddings: - causal_base_model._fsdp_wrap = False # type: ignore - tied_embeddings._fsdp_wrap = False # type: ignore - lm_head._fsdp_wrap = False # type: ignore + causal_base_model._fsdp_wrap = False + tied_embeddings._fsdp_wrap = False + lm_head._fsdp_wrap = False # FSDP Wrap and Activation Checkpoint every model block model.fsdp_wrap_fn = lambda module: isinstance(module, block_type) @@ -228,15 +233,15 @@ def prepare_hf_enc_dec_model_for_fsdp(model: PreTrainedModel, raise ValueError( f'Unable to FSDP-wrap this model! 
`{mod_name}` does not ' + 'follow common layer/weight naming conventions.') - decoder_block_type = type(decoder_block[0]) # type: ignore - encoder_block_type = type(encoder_block[0]) # type: ignore + decoder_block_type = type(decoder_block[0]) + encoder_block_type = type(encoder_block[0]) if model.config.tie_word_embeddings: # it is possible to train an enc/dec without tied embeddings, hence the check - tied_embeddings._fsdp_wrap = False # type: ignore - encoder._fsdp_wrap = False # type: ignore - decoder._fsdp_wrap = False # type: ignore - lm_head._fsdp_wrap = False # type: ignore + tied_embeddings._fsdp_wrap = False + encoder._fsdp_wrap = False + decoder._fsdp_wrap = False + lm_head._fsdp_wrap = False # FSDP Wrap and Activation Checkpoint every decoder block model.fsdp_wrap_fn = lambda module: isinstance(module, decoder_block_type) diff --git a/llmfoundry/models/layers/attention.py b/llmfoundry/models/layers/attention.py index 9114bc47aa..982f212493 100644 --- a/llmfoundry/models/layers/attention.py +++ b/llmfoundry/models/layers/attention.py @@ -464,7 +464,7 @@ def __init__( i * self.head_dim for i in range(1, self.n_heads + 2 * self.kv_n_heads) ] - self.Wqkv._fused = (0, fuse_splits) # type: ignore + self.Wqkv._fused = (0, fuse_splits) if self.qk_ln: norm_class = NORM_CLASS_REGISTRY[norm_type.lower()] @@ -499,7 +499,7 @@ def __init__( self.d_model, **fc_kwargs, ) - self.out_proj._is_residual = True # type: ignore + self.out_proj._is_residual = True def forward( self, diff --git a/llmfoundry/models/layers/ffn.py b/llmfoundry/models/layers/ffn.py index a02558102f..0b41a753d9 100644 --- a/llmfoundry/models/layers/ffn.py +++ b/llmfoundry/models/layers/ffn.py @@ -40,7 +40,7 @@ def __init__( d_model, **fc_kwargs, ) - self.down_proj._is_residual = True # type: ignore + self.down_proj._is_residual = True def forward(self, x: torch.Tensor): return self.down_proj(self.act(self.up_proj(x))) diff --git a/llmfoundry/models/layers/llama_attention_monkeypatch.py b/llmfoundry/models/layers/llama_attention_monkeypatch.py index b65a5cb300..0f75986e11 100644 --- a/llmfoundry/models/layers/llama_attention_monkeypatch.py +++ b/llmfoundry/models/layers/llama_attention_monkeypatch.py @@ -9,7 +9,8 @@ from typing import Callable, Optional, Tuple import torch -import torch.functional as F +import torch.nn.functional as F +from transformers.models.llama.modeling_llama import LlamaAttention from llmfoundry.models.layers.attention import ( scaled_multihead_dot_product_attention, triton_flash_attn_fn) @@ -42,8 +43,11 @@ def rotate_half(x: torch.Tensor): return torch.cat((-x2, x1), dim=-1) -def apply_rotary_pos_emb(q: torch.Tensor, k: torch.Tensor, cos: torch.Tensor, - sin: torch.Tensor, position_ids: torch.Tensor): +def apply_rotary_pos_emb(q: torch.Tensor, + k: torch.Tensor, + cos: torch.Tensor, + sin: torch.Tensor, + position_ids: Optional[torch.Tensor] = None): # The first two dimensions of cos and sin are always 1, so we can `squeeze` them. 
cos = cos.squeeze(1).squeeze(0) # [seq_len, dim] sin = sin.squeeze(1).squeeze(0) # [seq_len, dim] @@ -65,7 +69,7 @@ def get_llama_attention_patch_fn(patch_fn_name: str = 'torch') -> Callable: def llama_attention_patch_torch( - self, # type: ignore + self: LlamaAttention, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, @@ -89,21 +93,19 @@ def llama_attention_patch_torch( value_slices = self.v_proj.weight.split(key_value_slicing, dim=0) query_states = [ - F.linear( # type: ignore (thirdParty) - hidden_states, query_slices[i]) + F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp) ] query_states = torch.cat(query_states, dim=-1) key_states = [ - F.linear(hidden_states, key_slices[i]) # type: ignore (thirdParty) + F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp) ] key_states = torch.cat(key_states, dim=-1) value_states = [ - F.linear( # type: ignore (thirdParty) - hidden_states, value_slices[i]) + F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp) ] value_states = torch.cat(value_states, dim=-1) @@ -123,9 +125,9 @@ def llama_attention_patch_torch( if past_key_value is not None: kv_seq_len += past_key_value[0].shape[-2] cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) - query_states, key_states = apply_rotary_pos_emb( - query_states, key_states, cos, sin, - position_ids) # type: ignore (thirdParty) + + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, + cos, sin, position_ids) ### MAIN MODIFICATIONS START HERE ### query_states = query_states.transpose(1, 2).view( @@ -160,21 +162,22 @@ def llama_attention_patch_torch( self.config.pretraining_tp, dim=1) attn_output = sum([ - F.linear( # type: ignore (thirdParty) - attn_output[i], o_proj_slices[i]) + F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp) ]) else: attn_output = self.o_proj(attn_output) + assert isinstance(attn_output, torch.Tensor) + if not output_attentions: attn_weights = None - return attn_output, attn_weights, None # type: ignore (thirdParty) + return attn_output, attn_weights, None def llama_attention_patch_triton( - self, # type: ignore + self: LlamaAttention, hidden_states: torch.Tensor, attention_mask: Optional[torch.Tensor] = None, position_ids: Optional[torch.LongTensor] = None, @@ -190,6 +193,7 @@ def llama_attention_patch_triton( raise NotImplementedError( 'output_attentions is not supported when patching Llama attention with triton attention.' 
) + bsz, q_len, _ = hidden_states.size() if self.config.pretraining_tp > 1: @@ -202,21 +206,19 @@ def llama_attention_patch_triton( value_slices = self.v_proj.weight.split(key_value_slicing, dim=0) query_states = [ - F.linear( # type: ignore (thirdParty) - hidden_states, query_slices[i]) + F.linear(hidden_states, query_slices[i]) for i in range(self.config.pretraining_tp) ] query_states = torch.cat(query_states, dim=-1) key_states = [ - F.linear(hidden_states, key_slices[i]) # type: ignore (thirdParty) + F.linear(hidden_states, key_slices[i]) for i in range(self.config.pretraining_tp) ] key_states = torch.cat(key_states, dim=-1) value_states = [ - F.linear( # type: ignore (thirdParty) - hidden_states, value_slices[i]) + F.linear(hidden_states, value_slices[i]) for i in range(self.config.pretraining_tp) ] value_states = torch.cat(value_states, dim=-1) @@ -236,9 +238,8 @@ def llama_attention_patch_triton( if past_key_value is not None: kv_seq_len += past_key_value[0].shape[-2] cos, sin = self.rotary_emb(value_states, seq_len=kv_seq_len) - query_states, key_states = apply_rotary_pos_emb( - query_states, key_states, cos, sin, - position_ids) # type: ignore (thirdParty) + query_states, key_states = apply_rotary_pos_emb(query_states, key_states, + cos, sin, position_ids) ### MAIN MODIFICATIONS START HERE ### query_states = query_states.transpose(1, 2).view( @@ -273,11 +274,12 @@ def llama_attention_patch_triton( self.config.pretraining_tp, dim=1) attn_output = sum([ - F.linear( # type: ignore (thirdParty) - attn_output[i], o_proj_slices[i]) + F.linear(attn_output[i], o_proj_slices[i]) for i in range(self.config.pretraining_tp) ]) else: attn_output = self.o_proj(attn_output) - return attn_output, None, None # type: ignore (thirdParty) + assert isinstance(attn_output, torch.Tensor) + + return attn_output, None, None diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py index 5481201a8f..75ebfbd67c 100644 --- a/llmfoundry/models/mpt/modeling_mpt.py +++ b/llmfoundry/models/mpt/modeling_mpt.py @@ -307,12 +307,10 @@ def forward( if use_cache is not None else self.config.use_cache) if attention_mask is not None: - attention_mask = attention_mask.bool( - ) # type: ignore (TODO to figure out the right type here) + attention_mask = attention_mask.bool() # type: ignore if prefix_mask is not None: - prefix_mask = prefix_mask.bool( - ) # type: ignore (TODO to figure out the right type here) + prefix_mask = prefix_mask.bool() # type: ignore # These args are passed in by keyword in huggingface's generate function # https://github.com/huggingface/transformers/blob/68287689f2f0d8b7063c400230b3766987abf18d/src/transformers/generation/utils.py#L2201-L2206 @@ -360,7 +358,7 @@ def forward( S <= self.config.max_seq_len ), f'Cannot forward input with seq_len={S}, this model only supports seq_len<={self.config.max_seq_len}' - tok_emb = self.wte(input_ids) # type: ignore + tok_emb = self.wte(input_ids) if self.learned_pos_emb: past_position = 0 if past_key_values is not None: @@ -397,14 +395,14 @@ def forward( min=0, ) - pos_emb = self.wpe(pos) # type: ignore + pos_emb = self.wpe(pos) x = tok_emb + pos_emb else: # ALiBi and NoPE use this path (RoPE will also use this path if / when enabled) x = tok_emb if self.embedding_fraction == 1: - x = self.emb_drop(x) # type: ignore + x = self.emb_drop(x) else: # this implementation is proposed on page 7 of the GLM-130B paper https://arxiv.org/abs/2210.02414 x_shrunk = (x * self.embedding_fraction) + ( @@ -427,7 +425,7 @@ def forward( 
all_hidden_states = () if output_hidden_states else None all_self_attns = () if output_attentions else None - for b_idx, block in enumerate(self.blocks): # type: ignore + for b_idx, block in enumerate(self.blocks): if output_hidden_states: assert all_hidden_states is not None # pyright all_hidden_states = all_hidden_states + (x,) @@ -447,7 +445,7 @@ def forward( assert all_self_attns is not None # pyright all_self_attns = all_self_attns + (attn_weights,) - x = self.norm_f(x) # type: ignore + x = self.norm_f(x) # add hidden states from the last decoder layer if output_hidden_states: @@ -716,7 +714,8 @@ def __init__( loss_fn_config = om_model_config.get('loss_fn', 'fused_crossentropy') if loss_fn_config == 'fused_crossentropy': try: - from flash_attn.losses.cross_entropy import CrossEntropyLoss as FusedCrossEntropyLoss # type: ignore # isort: skip + from flash_attn.losses.cross_entropy import \ + CrossEntropyLoss as FusedCrossEntropyLoss if hf_config.verbose > 1: warnings.warn('Using Fused Cross Entropy Loss.') diff --git a/llmfoundry/models/utils/__init__.py b/llmfoundry/models/utils/__init__.py index 6ec67a5b71..35a15e530a 100644 --- a/llmfoundry/models/utils/__init__.py +++ b/llmfoundry/models/utils/__init__.py @@ -7,8 +7,8 @@ add_bidirectional_mask_if_missing, convert_hf_causal_lm_to_prefix_lm) from llmfoundry.models.utils.meta_init_context import (init_empty_weights, init_on_device) -from llmfoundry.models.utils.param_init_fns import ( # type: ignore - MODEL_INIT_REGISTRY, generic_param_init_fn_) +from llmfoundry.models.utils.param_init_fns import (MODEL_INIT_REGISTRY, + generic_param_init_fn_) __all__ = [ 'AutoTokenizerForMOD', diff --git a/llmfoundry/models/utils/hf_prefixlm_converter.py b/llmfoundry/models/utils/hf_prefixlm_converter.py index ae8ed444c8..fb9477d909 100644 --- a/llmfoundry/models/utils/hf_prefixlm_converter.py +++ b/llmfoundry/models/utils/hf_prefixlm_converter.py @@ -79,7 +79,7 @@ def _get_attn_modules(model: CAUSAL_GPT_TYPES) -> List[torch.nn.Module]: else: blocks = model.transformer.h - for block in blocks: # type: ignore + for block in blocks: if isinstance(model, GPTNeoForCausalLM): # Ignore "local" layers in this model type if block.attn.attention_type != 'global': @@ -150,7 +150,7 @@ def call_og_forward(): if bidirectional_mask is None: # This wrapper is a no-op if bidirectional masks are not supplied - return call_og_forward() # type: ignore + return call_og_forward() assert isinstance(bidirectional_mask, torch.Tensor) attn_modules = _get_attn_modules(model) @@ -158,7 +158,9 @@ def call_og_forward(): # Handle bidirectional_mask sizing # Note: all attn_modules.bias have the same size b, s = bidirectional_mask.shape + max_length = attn_modules[0].bias.shape[-1] # type: ignore + if s > max_length: raise ValueError( f'bidirectional_mask sequence length (={s}) exceeds the ' +\ @@ -174,8 +176,9 @@ def call_og_forward(): # Incorporate the bidirectional mask into the original causal mask for attn_module in attn_modules: - attn_module.bias.data = torch.logical_or( - attn_module.bias.data, bidirectional) # type: ignore + assert isinstance(attn_module.bias, torch.Tensor) + attn_module.bias.data = torch.logical_or(attn_module.bias.data, + bidirectional) # Collect outputs using the model's original forward method output = call_og_forward() @@ -201,7 +204,7 @@ def generate(self: CAUSAL_GPT_TYPES, *args: Any, **kwargs: Any): attn_module.bias.data[:] = 1 # type: ignore # Collect outputs using the model's original forward method - output = self._original_generate(*args, 
**kwargs) # type: ignore + output = self._original_generate(*args, **kwargs) # Reset the masks for attn_module in attn_modules: @@ -330,7 +333,7 @@ def _build_alibi_tensor( # and one new argument (`bidirectional_mask`) is added to the signature. KeyValueT = Tuple[torch.Tensor, torch.Tensor] - def forward( # type: ignore + def transformer_forward( self: BloomModel, input_ids: Optional[torch.LongTensor] = None, past_key_values: Optional[Tuple[KeyValueT, ...]] = None, @@ -342,8 +345,9 @@ def forward( # type: ignore output_attentions: Optional[bool] = None, output_hidden_states: Optional[bool] = None, return_dict: Optional[bool] = None, - **deprecated_arguments: Any) -> Union[Tuple[ - torch.Tensor, ...], BaseModelOutputWithPastAndCrossAttentions]: + **deprecated_arguments: Any + ) -> Union[Tuple[torch.Tensor, ...], + BaseModelOutputWithPastAndCrossAttentions]: if deprecated_arguments.pop('position_ids', False) is not False: # `position_ids` could have been `torch.Tensor` or `None` so # defaulting pop to `False` allows to detect if users were @@ -501,8 +505,8 @@ def custom_forward(*inputs: Any): MethodType(_prepare_attn_mask, model.transformer)) setattr(model.transformer, '_build_alibi_tensor', MethodType(_build_alibi_tensor, model.transformer)) - setattr(model.transformer, 'forward', MethodType(forward, - model.transformer)) + setattr(model.transformer, 'forward', + MethodType(transformer_forward, model.transformer)) # In order to actually use the new argument we've added to # model.transformer, we need to update the parent module's `forward` to diff --git a/llmfoundry/models/utils/meta_init_context.py b/llmfoundry/models/utils/meta_init_context.py index eee30f9357..c22c226c28 100644 --- a/llmfoundry/models/utils/meta_init_context.py +++ b/llmfoundry/models/utils/meta_init_context.py @@ -80,21 +80,27 @@ def init_on_device(device: torch.device, include_buffers: bool = False): if include_buffers: old_register_buffer = nn.Module.register_buffer - def register_empty_parameter(module: torch.nn.Module, name: str, + def register_empty_parameter(self: torch.nn.Module, name: str, param: Optional[torch.nn.Parameter]): - old_register_parameter(module, name, param) + old_register_parameter(self, name, param) if param is not None: - param_cls = type(module._parameters[name]) - kwargs = module._parameters[name].__dict__ # type: ignore - module._parameters[name] = param_cls( - module._parameters[name].to(device), **kwargs) # type: ignore - - def register_empty_buffer(module: torch.nn.Module, name: str, - buffer: Optional[torch.Tensor]): - old_register_buffer(module, name, buffer) - if buffer is not None: - module._buffers[name] = module._buffers[name].to( # type: ignore - device) + parameter = self._parameters[name] + assert parameter is not None + + param_cls = type(parameter) + kwargs = parameter.__dict__ + + self._parameters[name] = param_cls(parameter.to(device), **kwargs) + + def register_empty_buffer(self: torch.nn.Module, + name: str, + tensor: Optional[torch.Tensor], + persistent: bool = True): + old_register_buffer(self, name, tensor, persistent=persistent) + if tensor is not None: + named_buffer = self._buffers[name] + assert named_buffer is not None + self._buffers[name] = named_buffer.to(device) # Patch tensor creation if include_buffers: @@ -114,9 +120,9 @@ def wrapper(*args: Any, **kwargs: Any): return wrapper try: - nn.Module.register_parameter = register_empty_parameter # type: ignore + nn.Module.register_parameter = register_empty_parameter if include_buffers: - nn.Module.register_buffer = 
register_empty_buffer # type: ignore + nn.Module.register_buffer = register_empty_buffer for torch_function_name in tensor_constructors_to_patch.keys(): setattr( torch, torch_function_name, diff --git a/llmfoundry/models/utils/param_init_fns.py b/llmfoundry/models/utils/param_init_fns.py index 0dbb4e4a6f..388ef10e29 100644 --- a/llmfoundry/models/utils/param_init_fns.py +++ b/llmfoundry/models/utils/param_init_fns.py @@ -29,8 +29,9 @@ def torch_default_param_init_fn_( warnings.warn( f"Initializing network using module's reset_parameters attribute") - if hasattr(module, 'reset_parameters'): - module.reset_parameters() # type: ignore + if hasattr(module, 'reset_parameters') and isinstance( + module.reset_parameters, Callable): + module.reset_parameters() def fused_init_helper_(module: nn.Module, init_fn_: Callable): @@ -46,12 +47,14 @@ def fused_init_helper_(module: nn.Module, init_fn_: Callable): if _fused is None: raise RuntimeError(f'Internal logic error') + assert isinstance(module.weight, torch.Tensor) + dim, splits = _fused - splits = (0, *splits, module.weight.size(dim)) # type: ignore + splits = (0, *splits, module.weight.size(dim)) for s, e in zip(splits[:-1], splits[1:]): - slice_indices = [slice(None)] * module.weight.ndim # type: ignore + slice_indices = [slice(None)] * module.weight.ndim slice_indices[dim] = slice(s, e) - init_fn_(module.weight[slice_indices]) # type: ignore + init_fn_(module.weight[slice_indices]) def generic_param_init_fn_( @@ -82,9 +85,7 @@ def generic_param_init_fn_( elif isinstance(init_div_is_residual, float) or isinstance( init_div_is_residual, int): div_is_residual = init_div_is_residual - elif isinstance( - init_div_is_residual, # type: ignore - str) and init_div_is_residual.isnumeric(): + elif init_div_is_residual.isnumeric(): # do not trust YAML parsing to always convert numbers to numbers div_is_residual = float(init_div_is_residual) else: @@ -151,17 +152,17 @@ def generic_param_init_fn_( emb_init_fn_(module.weight) - elif isinstance(module, - tuple(set(NORM_CLASS_REGISTRY.values()))): # type: ignore + elif isinstance(module, tuple(set(NORM_CLASS_REGISTRY.values()))): # Norm if verbose > 1: warnings.warn( f'Norm weights are set to 1. If norm layer has a bias it is initialized to 0.' 
) - if hasattr(module, 'weight') and module.weight is not None: - torch.nn.init.ones_(module.weight) # type: ignore - if hasattr(module, 'bias') and module.bias is not None: - torch.nn.init.zeros_(module.bias) # type: ignore + if hasattr(module, 'weight') and isinstance(module.weight, + torch.Tensor): + torch.nn.init.ones_(module.weight) + if hasattr(module, 'bias') and isinstance(module.bias, torch.Tensor): + torch.nn.init.zeros_(module.bias) elif isinstance(module, nn.MultiheadAttention): # torch's MultiheadAttention @@ -199,10 +200,10 @@ def generic_param_init_fn_( torch.nn.init.zeros_(module.out_proj.bias) elif te is not None and isinstance(module, te.LayerNormMLP): - if module.layer_norm_weight is not None: - torch.nn.init.ones_(module.layer_norm_weight) # type: ignore - if module.layer_norm_bias is not None: - torch.nn.init.zeros_(module.layer_norm_bias) # type: ignore + if isinstance(module.layer_norm_weight, torch.Tensor): + torch.nn.init.ones_(module.layer_norm_weight) + if isinstance(module.layer_norm_bias, torch.Tensor): + torch.nn.init.zeros_(module.layer_norm_bias) init_fn_(module.fc1_weight) if module.fc1_bias is not None: diff --git a/llmfoundry/optim/outlier_detection.py b/llmfoundry/optim/outlier_detection.py index 71ff778166..b485a17c5d 100644 --- a/llmfoundry/optim/outlier_detection.py +++ b/llmfoundry/optim/outlier_detection.py @@ -41,7 +41,9 @@ def insert_observation(self, obs: float) -> bool: Returns: bool: Indicator of whether the most recent observation was an outlier. """ - if len(self.intermediate_data_queue # type: ignore + assert self.intermediate_data_queue.maxlen is not None, 'expected maxlen defined' + + if len(self.intermediate_data_queue ) >= self.intermediate_data_queue.maxlen: # move data from intermediate queue to slow moving average queue intermediate_obs = self.intermediate_data_queue.popleft() diff --git a/scripts/data_prep/convert_dataset_hf.py b/scripts/data_prep/convert_dataset_hf.py index 50dd30c45d..01948822c2 100644 --- a/scripts/data_prep/convert_dataset_hf.py +++ b/scripts/data_prep/convert_dataset_hf.py @@ -257,7 +257,7 @@ def build_dataloader(dataset: Dataset, batch_size: int, if num_workers is None: # Multiple workers is only supported on linux machines if 'linux' or 'macos' in platform.platform().lower(): - num_workers = max(1, psutil.cpu_count()) # type: ignore + num_workers = max(1, psutil.cpu_count()) else: num_workers = 0 diff --git a/scripts/data_prep/convert_finetuning_dataset.py b/scripts/data_prep/convert_finetuning_dataset.py index ecccb5f0f7..d2bf8923df 100644 --- a/scripts/data_prep/convert_finetuning_dataset.py +++ b/scripts/data_prep/convert_finetuning_dataset.py @@ -96,12 +96,13 @@ def __iter__(self) -> Iterable[Dict[str, bytes]]: yield {key: sample[key].encode('utf-8') for key in self.columns} -def build_dataloader(dataset: SimpleDataset, batch_size: int, - num_workers: int) -> DataLoader: +def build_dataloader(dataset: SimpleDataset, + batch_size: int, + num_workers: Optional[int] = None) -> DataLoader: if num_workers is None: # Multiple workers is only supported on linux machines if 'linux' in platform.platform().lower(): - num_workers = max(1, psutil.cpu_count()) # type: ignore + num_workers = max(1, psutil.cpu_count()) else: num_workers = 0 diff --git a/scripts/eval/eval.py b/scripts/eval/eval.py index 31743f1752..3408f31d33 100644 --- a/scripts/eval/eval.py +++ b/scripts/eval/eval.py @@ -153,7 +153,7 @@ def evaluate_model(model_cfg: DictConfig, dist_timeout: Union[float, int], callbacks=callbacks, loggers=loggers, 
precision=precision, - fsdp_config=fsdp_config, # type: ignore + fsdp_config=fsdp_config, load_path=load_path, load_weights_only=True, progress_bar=False, diff --git a/scripts/inference/convert_hf_to_onnx.py b/scripts/inference/convert_hf_to_onnx.py index dd7a6f7a62..f73836e28f 100644 --- a/scripts/inference/convert_hf_to_onnx.py +++ b/scripts/inference/convert_hf_to_onnx.py @@ -138,9 +138,9 @@ def export_to_onnx( with torch.no_grad(): orig_out = model(**sample_input) - import onnx # type: ignore - import onnx.checker # type: ignore - import onnxruntime as ort # type: ignore + import onnx + import onnx.checker + import onnxruntime as ort _ = onnx.load(str(output_file)) diff --git a/scripts/inference/hf_generate.py b/scripts/inference/hf_generate.py index cc331acd14..96592ca477 100644 --- a/scripts/inference/hf_generate.py +++ b/scripts/inference/hf_generate.py @@ -293,8 +293,9 @@ def _generate(encoded_inp: Dict[str, torch.Tensor]): maybe_synchronize() encode_end = time.time() input_tokens = torch.sum( - encoded_inp['input_ids'] != tokenizer.pad_token_id, - axis=1).numpy(force=True) # type: ignore + encoded_inp['input_ids'] != + tokenizer.pad_token_id, # type: ignore + axis=1).numpy(force=True) # Warmup if args.warmup and (not done_warmup): diff --git a/scripts/train/train.py b/scripts/train/train.py index 58fa67afe1..95aa22f44c 100644 --- a/scripts/train/train.py +++ b/scripts/train/train.py @@ -90,7 +90,7 @@ def validate_config(cfg: DictConfig): '`te.LayerNormMLP` requires has issues with torch._dynamo. ' + 'Setting `torch._dynamo.config.suppress_errors = True` and falling back to eager.' ) - torch._dynamo.config.suppress_errors = True # type: ignore + torch._dynamo.config.suppress_errors = True # type: ignore (third-party) if cfg.model.get('load_in_8bit', False): raise ValueError( @@ -530,7 +530,7 @@ def main(cfg: DictConfig) -> Trainer: precision=precision, algorithms=algorithms, device_train_microbatch_size=device_train_microbatch_size, - fsdp_config=fsdp_config, # type: ignore + fsdp_config=fsdp_config, save_folder=save_folder, save_filename=save_filename, save_latest_filename=save_latest_filename, diff --git a/tests/test_hf_v_mpt.py b/tests/test_hf_v_mpt.py index 22c9241037..82e2d05550 100644 --- a/tests/test_hf_v_mpt.py +++ b/tests/test_hf_v_mpt.py @@ -165,7 +165,7 @@ def test_compare_hf_v_mpt(attn_impl: str, dropout: float, alibi: bool, print(f'{hf_model_fwd.mean().item() = }\n{model_fwd.mean().item() = }') if hf_model_fwd.mean().allclose(model_fwd.mean()): warn_msg = f'WARNING: model_fwd ({model_fwd}) and hf_model_fwd ({hf_model_fwd}) are very close at init.' 
- raise warnings.warn(warn_msg) # type: ignore + raise RuntimeError(warn_msg) hf_model_statedict = hf_model.state_dict() diff --git a/tests/test_init_fn.py b/tests/test_init_fn.py index 9355e7a277..b054bac186 100644 --- a/tests/test_init_fn.py +++ b/tests/test_init_fn.py @@ -23,7 +23,7 @@ def __init__(self, cfg: Union[ListConfig, DictConfig]): self.fc1 = nn.Linear(cfg.in_features, cfg.out_features, bias=True) self.ln_1 = nn.LayerNorm(cfg.out_features) self.fc2 = nn.Linear(cfg.out_features, cfg.out_features, bias=True) - self.fc2._is_residual = True # type: ignore + self.fc2._is_residual = True def forward(self, x: torch.Tensor): y = self.ln_1(self.fc1(x)) @@ -76,7 +76,7 @@ def test_fused_init_helper(fused: bool): fc = nn.Linear(cfg.in_features, cfg.out_features, bias=True) fc.train() if fused: - fc._fused = (0, (cfg.out_features // 2,)) # type: ignore + fc._fused = (0, (cfg.out_features // 2,)) def init_fn_(weight: torch.Tensor): # dummy init based on layer width @@ -159,18 +159,18 @@ def test_emb_init(emb_init_cfg: Optional[Tuple[str, Union[int, List[int]]]]): model.apply(partial(MODEL_INIT_REGISTRY['kaiming_normal_'], **dict_cfg)) + assert isinstance(model.emb, torch.nn.Embedding) + if dict_cfg.get('emb_init_std') is not None: emb_init_std = dict_cfg.get('emb_init_std') if emb_init_std == 0: - assert (model.emb.weight == 0).all() # type: ignore + assert (model.emb.weight == 0).all() elif dict_cfg.get('emb_init_uniform_lim') is not None: emb_init_uniform_lim = dict_cfg.get('emb_init_uniform_lim') if emb_init_uniform_lim == 0: - assert (model.emb.weight == 0).all() # type: ignore + assert (model.emb.weight == 0).all() elif isinstance(emb_init_uniform_lim, Sequence): assert len(emb_init_uniform_lim) <= 2 if len(emb_init_uniform_lim ) == 2 and emb_init_uniform_lim[0] == emb_init_uniform_lim[1]: - assert isinstance(model.emb, torch.nn.Embedding) - assert (model.emb.weight == emb_init_uniform_lim[0] - ).all() # type: ignore + assert (model.emb.weight == emb_init_uniform_lim[0]).all() diff --git a/tests/test_model.py b/tests/test_model.py index d26e86b59e..f20381f288 100644 --- a/tests/test_model.py +++ b/tests/test_model.py @@ -44,6 +44,12 @@ def get_config( return cast(DictConfig, test_cfg) +def _load_tokenizer_cfg(cfg: DictConfig) -> Dict: + config = om.to_container(cfg, resolve=True) + assert isinstance(config, Dict) + return config + + def get_objs(conf_path: str = 'scripts/train/yamls/pretrain/testing.yaml'): warnings.filterwarnings( action='ignore', @@ -74,9 +80,7 @@ def get_objs(conf_path: str = 'scripts/train/yamls/pretrain/testing.yaml'): test_cfg.device_eval_batch_size = 2 test_cfg.device_train_microbatch_size = 2 - tokenizer_cfg: Dict[str, - Any] = om.to_container(test_cfg.tokenizer, - resolve=True) # type: ignore + tokenizer_cfg: Dict[str, Any] = _load_tokenizer_cfg(test_cfg.tokenizer) tokenizer = build_tokenizer(test_cfg.tokenizer.name, tokenizer_cfg.get('kwargs', {})) @@ -176,7 +180,7 @@ def test_attention_mechanism(batch_size: int = 2): # and with 0 everywhere else expected_zerod_weights = nn.Transformer.generate_square_subsequent_mask(test_cfg.max_seq_len)\ .reshape(1, test_cfg.max_seq_len, test_cfg.max_seq_len) - expected_zerod_weights = torch.isneginf( # type: ignore + expected_zerod_weights = torch.isneginf( torch.cat(batch_size * [expected_zerod_weights])) torch_key_padding = torch.cat( # type: ignore test_cfg.max_seq_len * @@ -225,9 +229,7 @@ def test_full_forward_and_backward_gpt2_small(prefixlm: bool, else: neo_cfg.model.name = 'hf_causal_lm' - tokenizer_cfg: Dict[str, - 
Any] = om.to_container(neo_cfg.tokenizer, - resolve=True) # type: ignore + tokenizer_cfg: Dict[str, Any] = _load_tokenizer_cfg(neo_cfg.tokenizer) tokenizer = build_tokenizer(neo_cfg.tokenizer.name, tokenizer_cfg.get('kwargs', {})) @@ -272,9 +274,7 @@ def test_full_forward_and_backward_t5_small(batch_size: int = 2): t5_cfg.device = device t5_cfg.max_seq_len = 16 - tokenizer_cfg: Dict[str, - Any] = om.to_container(t5_cfg.tokenizer, - resolve=True) # type: ignore + tokenizer_cfg: Dict[str, Any] = _load_tokenizer_cfg(t5_cfg.tokenizer) tokenizer = build_tokenizer(t5_cfg.tokenizer.name, tokenizer_cfg.get('kwargs', {})) @@ -328,9 +328,7 @@ def test_determinism(attn_impl: str, precision: torch.dtype): test_cfg.model.init_device = 'cuda:0' test_cfg.device = 'cuda:0' - tokenizer_cfg: Dict[str, - Any] = om.to_container(test_cfg.tokenizer, - resolve=True) # type: ignore + tokenizer_cfg: Dict[str, Any] = _load_tokenizer_cfg(test_cfg.tokenizer) tokenizer = build_tokenizer(test_cfg.tokenizer.name, tokenizer_cfg.get('kwargs', {})) @@ -379,8 +377,9 @@ def test_loss_fn(): pytest.skip('Fused cross entropy was not installed') # run numerical test in pure fp32 - torch.backends.cuda.matmul.allow_tf32 = False # type: ignore (third-party) - torch.backends.cudnn.allow_tf32 = False # type: ignore (third-party) + from torch.backends import cuda, cudnn + cuda.matmul.allow_tf32 = False + cudnn.allow_tf32 = False conf_path = 'scripts/train/yamls/pretrain/testing.yaml' with open(conf_path) as f: @@ -397,9 +396,7 @@ def test_loss_fn(): reproducibility.seed_all(test_cfg.get('global_seed', 42)) - tokenizer_cfg: Dict[str, - Any] = om.to_container(test_cfg.tokenizer, - resolve=True) # type: ignore + tokenizer_cfg: Dict[str, Any] = _load_tokenizer_cfg(test_cfg.tokenizer) tokenizer = build_tokenizer(test_cfg.tokenizer.name, tokenizer_cfg.get('kwargs', {})) @@ -460,9 +457,7 @@ def test_opt_wrapping(prefixlm: bool): } config = DictConfig(conf) - tokenizer_cfg: Dict[str, - Any] = om.to_container(config.tokenizer, - resolve=True) # type: ignore + tokenizer_cfg: Dict[str, Any] = _load_tokenizer_cfg(config.tokenizer) tokenizer = build_tokenizer(config.tokenizer.name, tokenizer_cfg.get('kwargs', {})) @@ -1414,9 +1409,7 @@ def test_hf_init(tmp_path: pathlib.Path, model = AutoModelForCausalLM.from_pretrained(save_path, trust_remote_code=True) - tokenizer_cfg: Dict[str, - Any] = om.to_container(test_cfg.tokenizer, - resolve=True) # type: ignore + tokenizer_cfg: Dict[str, Any] = _load_tokenizer_cfg(test_cfg.tokenizer) tokenizer = build_tokenizer(test_cfg.tokenizer.name, tokenizer_cfg.get('kwargs', {})) @@ -1466,9 +1459,7 @@ def test_head_dim_8_triton_mqa_attn(batch_size: int = 2): ) test_cfg.device = torch.cuda.current_device() - tokenizer_cfg: Dict[str, - Any] = om.to_container(test_cfg.tokenizer, - resolve=True) # type: ignore + tokenizer_cfg: Dict[str, Any] = _load_tokenizer_cfg(test_cfg.tokenizer) tokenizer = build_tokenizer(test_cfg.tokenizer.name, tokenizer_cfg.get('kwargs', {})) From 0fdf43fffe692eeee1203be8458d6a91772a26e5 Mon Sep 17 00:00:00 2001 From: Hanlin Tang Date: Tue, 12 Sep 2023 17:08:56 -0700 Subject: [PATCH 4/4] Refactor logging (#234) Replaces most print statements with proper logging. Deprecates the `verbose` argument in favor of using the `python_log_level` argument that is also used by composer. 
--- .../callbacks/eval_gauntlet_callback.py | 12 +-- llmfoundry/callbacks/resumption_callbacks.py | 7 +- llmfoundry/data/data.py | 1 - llmfoundry/data/finetuning/dataloader.py | 4 +- llmfoundry/data/finetuning/tasks.py | 34 ++++----- llmfoundry/models/hf/hf_causal_lm.py | 5 +- llmfoundry/models/layers/attention.py | 18 ----- llmfoundry/models/layers/blocks.py | 2 - llmfoundry/models/mpt/configuration_mpt.py | 9 ++- llmfoundry/models/mpt/modeling_mpt.py | 26 +++---- llmfoundry/models/utils/param_init_fns.py | 76 +------------------ llmfoundry/utils/builders.py | 5 +- .../utils/checkpoint_conversion_helpers.py | 9 ++- llmfoundry/utils/config_utils.py | 7 +- scripts/data_prep/convert_dataset_hf.py | 3 +- .../data_prep/convert_finetuning_dataset.py | 4 +- scripts/eval/eval.py | 42 +++++++--- scripts/train/train.py | 19 ++++- 18 files changed, 111 insertions(+), 172 deletions(-) diff --git a/llmfoundry/callbacks/eval_gauntlet_callback.py b/llmfoundry/callbacks/eval_gauntlet_callback.py index 88f47e36fc..b1570e9793 100644 --- a/llmfoundry/callbacks/eval_gauntlet_callback.py +++ b/llmfoundry/callbacks/eval_gauntlet_callback.py @@ -3,6 +3,7 @@ """Aggregate ICL evals into composite scores.""" +import logging import math from enum import Enum from typing import Optional @@ -12,6 +13,8 @@ __all__ = ['EvalGauntlet'] +log = logging.getLogger(__name__) + class Weighting(Enum): EQUAL = 1 @@ -130,9 +133,8 @@ def eval_after_all(self, state: State, logger: Logger): key = f"{benchmark['name']}/{benchmark['num_fewshot']}-shot" if key not in new_metrics: - print( - f"Warning: couldn't find results for benchmark: {benchmark}" - ) + log.warning( + f'Could not find results for benchmark: {benchmark}.') missing_metrics.append(key) else: score = new_metrics[key] @@ -150,8 +152,8 @@ def eval_after_all(self, state: State, logger: Logger): }) if len(missing_metrics) > 0: - print( - f"Removing category `{category['name']}` from gauntlet scores because benchmarks were missing: {missing_metrics}" + log.warning( + f"Removing category `{category['name']}` from scores because benchmarks were missing: {missing_metrics}" ) del composite_scores[category['name']] continue diff --git a/llmfoundry/callbacks/resumption_callbacks.py b/llmfoundry/callbacks/resumption_callbacks.py index b89989d301..b5e20a7a57 100644 --- a/llmfoundry/callbacks/resumption_callbacks.py +++ b/llmfoundry/callbacks/resumption_callbacks.py @@ -1,6 +1,7 @@ # Copyright 2022 MosaicML LLM Foundry authors # SPDX-License-Identifier: Apache-2.0 +import logging from typing import List from composer.core import Callback, State @@ -11,6 +12,8 @@ 'LayerFreezing', ] +log = logging.getLogger(__name__) + class GlobalLRScaling(Callback): """GlobalLRScaling. 
@@ -38,7 +41,7 @@ def fit_start(self, state: State, logger: Logger): group['weight_decay'] = group['lr'] * self.wd_pct if 'initial_lr' in group: group['initial_lr'] *= self.lr_scale - print( + log.info( f"Set LR and WD to {group['lr']}, {group['weight_decay']}") for scheduler in state.schedulers: @@ -74,7 +77,7 @@ def fit_start(self, state: State, logger: Logger): for name, p in state.model.named_parameters(): if p.requires_grad and name in self.layer_names: p.requires_grad = False - print(f'Froze layer: {name}\nParam: {p}') + log.debug(f'Froze layer: {name}\nParam: {p}') successful_freeze = True if not successful_freeze: diff --git a/llmfoundry/data/data.py b/llmfoundry/data/data.py index ef758dfcef..92e4538d73 100644 --- a/llmfoundry/data/data.py +++ b/llmfoundry/data/data.py @@ -24,7 +24,6 @@ def __init__(self, hf_dataset: Union[hf_datasets.IterableDataset, def __iter__(self) -> Iterable[Dict[str, bytes]]: for sample in self.hf_dataset: - # print(sample) # convert to bytes to store in MDS binary format yield {'text': sample['text'].encode('utf-8')} diff --git a/llmfoundry/data/finetuning/dataloader.py b/llmfoundry/data/finetuning/dataloader.py index 1ed3d56d60..b0d175f2a8 100644 --- a/llmfoundry/data/finetuning/dataloader.py +++ b/llmfoundry/data/finetuning/dataloader.py @@ -323,7 +323,7 @@ def _build_hf_dataset_from_remote( f'at {files_searched}' ) from e else: - print( + log.debug( f'Could not find {name}, looking for another extension') continue @@ -343,7 +343,7 @@ def _build_hf_dataset_from_remote( dist.barrier() cfg.dataset.hf_name = finetune_dir - print(cfg.dataset) + log.info(cfg.dataset) dataset = dataset_constructor.build_from_hf( cfg.dataset, max_seq_len=cfg.dataset.max_seq_len, diff --git a/llmfoundry/data/finetuning/tasks.py b/llmfoundry/data/finetuning/tasks.py index f4da9750c7..f5e6ac6b27 100644 --- a/llmfoundry/data/finetuning/tasks.py +++ b/llmfoundry/data/finetuning/tasks.py @@ -32,6 +32,7 @@ def preprocessing_fn(example: Dict) -> Dict[str, str]: """ import importlib +import logging import os import warnings from typing import Any, Callable, Dict, Optional, Union @@ -41,6 +42,8 @@ def preprocessing_fn(example: Dict) -> Dict[str, str]: from streaming import StreamingDataset from transformers import PreTrainedTokenizerBase +log = logging.getLogger(__name__) + __all__ = ['dataset_constructor'] @@ -205,8 +208,7 @@ def _preprocessor(example: Dict[str, Any]) -> Dict[str, str]: def get_preprocessing_fn_from_str(self, preprocessor: Optional[str], - dataset_name: Optional[str] = None, - verbose: bool = False): + dataset_name: Optional[str] = None): """Get a preprocessing function from a string. String can be either a registered function or an import path. @@ -214,7 +216,6 @@ def get_preprocessing_fn_from_str(self, Args: preprocessor (Optional[str]): The name of the preprocessing function, or an import path. dataset_name (Optional[str]): The dataset name to look up in the registry. - verbose (bool): Whether to print verbose messages or not. Returns: Callable: The preprocessing function or None if not found. @@ -226,33 +227,24 @@ def get_preprocessing_fn_from_str(self, if dataset_name is None: return None if dataset_name in self._task_preprocessing_registry: - if verbose: - print( - f'Re-formatting dataset with "{dataset_name}" preprocessing function.' - ) + log.info( + f'Re-formatting dataset with "{dataset_name}" preprocessing function.' 
+ ) return self._task_preprocessing_registry[dataset_name] else: - if verbose: - print( - 'No preprocessor was supplied and no preprocessing function ' +\ + log.info('No preprocessor was supplied and no preprocessing function ' +\ f'is registered for dataset name "{dataset_name}". No additional ' +\ 'preprocessing will be applied. If the dataset is already formatted ' +\ - 'correctly, you can ignore this message.' - ) + 'correctly, you can ignore this message.') return None if preprocessor in self._task_preprocessing_registry: - if verbose: - print( - f'Re-formatting dataset with "{preprocessor}" preprocessing function.' - ) + log.info( + f'Re-formatting dataset with "{preprocessor}" preprocessing function.' + ) return self._task_preprocessing_registry[preprocessor] try: import_path, function_name = preprocessor.split(':', maxsplit=1) - if verbose: - print( - f'Importing preprocessing function via: `from {import_path} import {function_name}`' - ) module = importlib.import_module(import_path) preprocessing_fn = getattr(module, function_name) except Exception as e: @@ -289,7 +281,7 @@ def build_from_hf( proto_preprocessing_fn) else: preprocessing_fn = self.get_preprocessing_fn_from_str( - proto_preprocessing_fn, dataset_name, verbose=True) + proto_preprocessing_fn, dataset_name) dataset = hf_datasets.load_dataset(dataset_name, split=split, **kwargs) diff --git a/llmfoundry/models/hf/hf_causal_lm.py b/llmfoundry/models/hf/hf_causal_lm.py index 746499fdfb..3100478a27 100644 --- a/llmfoundry/models/hf/hf_causal_lm.py +++ b/llmfoundry/models/hf/hf_causal_lm.py @@ -3,6 +3,7 @@ """Implements a Hugging Causal LM wrapped inside a :class:`.ComposerModel`.""" +import logging import os from typing import Mapping, Union @@ -35,6 +36,8 @@ __all__ = ['ComposerHFCausalLM'] +log = logging.getLogger(__name__) + class ComposerHFCausalLM(HuggingFaceModelWithZLoss): """Configures a :class:`.HuggingFaceModel` around a Causal LM. @@ -185,7 +188,7 @@ def __init__(self, om_model_config: Union[DictConfig, f'attention_patch_type is only supported for llama models, but got {model.config.model_type}' ) - print( + log.debug( f'Patching llama attention with {attention_patch_type} attention' ) from transformers.models.llama.modeling_llama import \ diff --git a/llmfoundry/models/layers/attention.py b/llmfoundry/models/layers/attention.py index 982f212493..6ac496ebd8 100644 --- a/llmfoundry/models/layers/attention.py +++ b/llmfoundry/models/layers/attention.py @@ -418,7 +418,6 @@ def __init__( attn_pdrop: float = 0.0, norm_type: str = 'low_precision_layernorm', fc_type: str = 'torch', - verbose: int = 0, device: Optional[str] = None, ): super().__init__() @@ -476,21 +475,8 @@ def __init__( self.attn_fn = flash_attn_fn elif self.attn_impl == 'triton': self.attn_fn = triton_flash_attn_fn - if verbose: - warnings.warn( - 'While `attn_impl: triton` can be faster than `attn_impl: flash` ' +\ - 'it uses more memory. When training larger models this can trigger ' +\ - 'alloc retries which hurts performance. If encountered, we recommend ' +\ - 'using `attn_impl: flash` if your model does not use `alibi` or `prefix_lm`.' - ) elif self.attn_impl == 'torch': self.attn_fn = scaled_multihead_dot_product_attention - if torch.cuda.is_available() and verbose: - warnings.warn( - 'Using `attn_impl: torch`. If your model does not use `alibi` or ' +\ - '`prefix_lm` we recommend using `attn_impl: flash` otherwise ' +\ - 'we recommend using `attn_impl: triton`.' 
- ) else: raise ValueError(f'{attn_impl=} is an invalid setting.') @@ -569,7 +555,6 @@ def __init__( attn_pdrop: float = 0.0, norm_type: str = 'low_precision_layernorm', fc_type: str = 'torch', - verbose: int = 0, device: Optional[str] = None, ): super().__init__( @@ -583,7 +568,6 @@ def __init__( attn_pdrop=attn_pdrop, norm_type=norm_type, fc_type=fc_type, - verbose=verbose, device=device) @@ -605,7 +589,6 @@ def __init__( attn_pdrop: float = 0.0, norm_type: str = 'low_precision_layernorm', fc_type: str = 'torch', - verbose: int = 0, device: Optional[str] = None, ): super().__init__( @@ -619,7 +602,6 @@ def __init__( attn_pdrop=attn_pdrop, norm_type=norm_type, fc_type=fc_type, - verbose=verbose, device=device) diff --git a/llmfoundry/models/layers/blocks.py b/llmfoundry/models/layers/blocks.py index b5a3ff8d68..dd208302b8 100644 --- a/llmfoundry/models/layers/blocks.py +++ b/llmfoundry/models/layers/blocks.py @@ -24,7 +24,6 @@ def __init__( ffn_config: Optional[Dict] = None, resid_pdrop: float = 0.0, norm_type: str = 'low_precision_layernorm', - verbose: int = 0, fc_type: str = 'torch', device: Optional[str] = None, **kwargs: Any, @@ -70,7 +69,6 @@ def __init__( self.attn = attn_class(d_model=d_model, n_heads=n_heads, fc_type=fc_type, - verbose=verbose, device=device, **attn_config_subset_for_attn_class) self.norm_2 = None diff --git a/llmfoundry/models/mpt/configuration_mpt.py b/llmfoundry/models/mpt/configuration_mpt.py index 08c02fa3b1..38946b47c8 100644 --- a/llmfoundry/models/mpt/configuration_mpt.py +++ b/llmfoundry/models/mpt/configuration_mpt.py @@ -56,12 +56,12 @@ def __init__( init_device: str = 'cpu', logit_scale: Optional[Union[float, str]] = None, no_bias: bool = False, - verbose: int = 0, embedding_fraction: float = 1.0, norm_type: str = 'low_precision_layernorm', use_cache: bool = False, init_config: Dict = init_config_defaults, fc_type: str = 'torch', + verbose: Optional[int] = None, **kwargs: Any, ): """The MPT configuration class. @@ -135,12 +135,17 @@ def __init__( self.init_device = init_device self.logit_scale = logit_scale self.no_bias = no_bias - self.verbose = verbose self.embedding_fraction = embedding_fraction self.norm_type = norm_type self.use_cache = use_cache self.init_config = init_config self.fc_type = fc_type + if verbose is not None: + warnings.warn( + DeprecationWarning( + 'verbose argument for MPTConfig is now ignored and will be removed. Use python_log_level instead.' + )) + if 'name' in kwargs: del kwargs['name'] if 'loss_fn' in kwargs: diff --git a/llmfoundry/models/mpt/modeling_mpt.py b/llmfoundry/models/mpt/modeling_mpt.py index 75ebfbd67c..1b4ca764ea 100644 --- a/llmfoundry/models/mpt/modeling_mpt.py +++ b/llmfoundry/models/mpt/modeling_mpt.py @@ -63,6 +63,10 @@ pass # isort: on +import logging + +log = logging.getLogger(__name__) + class MPTPreTrainedModel(PreTrainedModel): config_class = MPTConfig @@ -118,8 +122,8 @@ def __init__(self, config: MPTConfig): self.norm_f = norm_class(config.d_model, device=config.init_device) if config.init_device != 'meta': - print( - f'You are using {config.init_device=}, but you can also use config.init_device="meta" with Composer + FSDP for fast initialization.' + log.info( + f'We recommend using config.init_device="meta" with Composer + FSDP for faster initialization.' 
) self.apply(self.param_init_fn) @@ -142,19 +146,11 @@ def __init__(self, config: MPTConfig): for module in self.modules(): if hasattr(module, 'bias') and isinstance( module.bias, nn.Parameter): - if config.verbose: - warnings.warn( - f'Removing bias ({module.bias}) from {module}.') + log.info(f'Removing bias ({module.bias}) from {module}.') module.register_parameter('bias', None) - # Print verbose info - if config.verbose and config.verbose > 2: - print(self) - if 'verbose' not in self.config.init_config: - self.config.init_config['verbose'] = self.config.verbose - if self.config.init_config['verbose'] > 1: - init_fn_name = self.config.init_config['name'] - warnings.warn(f'Using {init_fn_name} initialization.') + log.debug(self) + log.debug(f'Using {self.config.init_config["name"]} initialization.') def get_input_embeddings(self): return self.wte @@ -486,7 +482,7 @@ def __init__(self, config: MPTConfig): raise ValueError( 'MPTForCausalLM only supports tied word embeddings') - print(f'Instantiating an MPTForCausalLM model from {__file__}') + log.info(f'Instantiating an MPTForCausalLM model from {__file__}') self.transformer: MPTModel = MPTModel(config) @@ -717,8 +713,6 @@ def __init__( from flash_attn.losses.cross_entropy import \ CrossEntropyLoss as FusedCrossEntropyLoss - if hf_config.verbose > 1: - warnings.warn('Using Fused Cross Entropy Loss.') self.loss_fn = FusedCrossEntropyLoss(ignore_index=-100) except: raise ValueError( diff --git a/llmfoundry/models/utils/param_init_fns.py b/llmfoundry/models/utils/param_init_fns.py index 388ef10e29..2411dc8a16 100644 --- a/llmfoundry/models/utils/param_init_fns.py +++ b/llmfoundry/models/utils/param_init_fns.py @@ -21,13 +21,9 @@ def torch_default_param_init_fn_( module: nn.Module, - verbose: int = 0, **kwargs: Any, ): del kwargs # unused, just to capture any extra args from the config - if verbose > 1: - warnings.warn( - f"Initializing network using module's reset_parameters attribute") if hasattr(module, 'reset_parameters') and isinstance( module.reset_parameters, Callable): @@ -65,15 +61,11 @@ def generic_param_init_fn_( init_div_is_residual: Union[int, float, str, bool] = True, emb_init_std: Optional[float] = None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None, - verbose: int = 0, **kwargs: Any, ): del kwargs # unused, just to capture any extra args from the config - if verbose > 1: - warnings.warn( - f'If model has bias parameters they are initialized to 0.') - # enable user to divide _is_residual weights by + # a value which defaults to math.sqrt(2 * cfg.n_layers) init_div_is_residual = init_div_is_residual @@ -95,13 +87,6 @@ def generic_param_init_fn_( f'Expected init_div_is_residual to be boolean or numeric, got {init_div_is_residual}' ) - if init_div_is_residual is not False: - if verbose > 1: - warnings.warn( - f'Initializing _is_residual layers then dividing them by {div_is_residual:.3f}. ' +\ - f'Set `init_div_is_residual: false` in init config to disable this.' - ) - if isinstance(module, tuple(set(FC_CLASS_REGISTRY.values()))): # Linear if hasattr(module, '_fused'): @@ -124,10 +109,6 @@ def generic_param_init_fn_( if std == 0: warnings.warn(f'Embedding layer initialized to 0.') emb_init_fn_ = partial(torch.nn.init.normal_, mean=0.0, std=std) - if verbose > 1: - warnings.warn( - f'Embedding layer initialized using normal distribution with mean=0 and {std=}.' 
- ) elif emb_init_uniform_lim is not None: lim = emb_init_uniform_lim if isinstance(lim, Sequence): @@ -143,10 +124,6 @@ def generic_param_init_fn_( lim = [-lim, lim] a, b = lim emb_init_fn_ = partial(torch.nn.init.uniform_, a=a, b=b) - if verbose > 1: - warnings.warn( - f'Embedding layer initialized using uniform distribution in range {lim}.' - ) else: emb_init_fn_ = init_fn_ @@ -154,10 +131,6 @@ def generic_param_init_fn_( elif isinstance(module, tuple(set(NORM_CLASS_REGISTRY.values()))): # Norm - if verbose > 1: - warnings.warn( - f'Norm weights are set to 1. If norm layer has a bias it is initialized to 0.' - ) if hasattr(module, 'weight') and isinstance(module.weight, torch.Tensor): torch.nn.init.ones_(module.weight) @@ -237,16 +210,11 @@ def _normal_param_init_fn_( init_div_is_residual: Union[int, float, str, bool] = True, emb_init_std: Optional[float] = None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None, - verbose: int = 0, **kwargs: Any, ): del kwargs # unused, just to capture any extra args from the config init_fn_ = _normal_init_(std=std) - if verbose > 1: - warnings.warn( - f'Using torch.nn.init.normal_ init fn mean=0.0, std={std}') - generic_param_init_fn_( module=module, init_fn_=init_fn_, @@ -255,7 +223,6 @@ def _normal_param_init_fn_( init_div_is_residual=init_div_is_residual, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim, - verbose=verbose, ) @@ -267,7 +234,6 @@ def baseline_param_init_fn_( init_div_is_residual: Union[int, float, str, bool] = True, emb_init_std: Optional[float] = None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None, - verbose: int = 0, **kwargs: Any, ): del kwargs # unused, just to capture any extra args from the config @@ -283,7 +249,6 @@ def baseline_param_init_fn_( init_div_is_residual=init_div_is_residual, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim, - verbose=verbose, ) @@ -294,7 +259,6 @@ def small_param_init_fn_( init_div_is_residual: Union[int, float, str, bool] = True, emb_init_std: Optional[float] = None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None, - verbose: int = 0, **kwargs: Any, ): del kwargs # unused, just to capture any extra args from the config @@ -309,7 +273,6 @@ def small_param_init_fn_( init_div_is_residual=init_div_is_residual, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim, - verbose=verbose, ) @@ -319,7 +282,6 @@ def neox_param_init_fn_( d_model: int, emb_init_std: Optional[float] = None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None, - verbose: int = 0, **kwargs: Any, ): """From section 2.3.1 of GPT-NeoX-20B: @@ -331,9 +293,6 @@ def neox_param_init_fn_( del kwargs # unused, just to capture any extra args from the config residual_div = n_layers / math.sqrt(10) # small std / wang std - if verbose > 1: - warnings.warn(f'setting init_div_is_residual to {residual_div}') - small_param_init_fn_( module=module, d_model=d_model, @@ -341,7 +300,6 @@ def neox_param_init_fn_( init_div_is_residual=residual_div, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim, - verbose=verbose, ) @@ -355,17 +313,10 @@ def kaiming_uniform_param_init_fn_( init_gain: float = 0, fan_mode: str = 'fan_in', init_nonlinearity: str = 'leaky_relu', - verbose: int = 0, **kwargs: Any, ): del kwargs # unused, just to capture any extra args from the config - if verbose > 1: - warnings.warn( - f'Using nn.init.kaiming_uniform_ init fn with parameters: ' +\ - f'a={init_gain}, 
mode={fan_mode}, nonlinearity={init_nonlinearity}' - ) - kaiming_uniform_ = partial(nn.init.kaiming_uniform_, a=init_gain, mode=fan_mode, @@ -379,7 +330,6 @@ def kaiming_uniform_param_init_fn_( init_div_is_residual=init_div_is_residual, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim, - verbose=verbose, ) @@ -393,17 +343,10 @@ def kaiming_normal_param_init_fn_( init_gain: float = 0, fan_mode: str = 'fan_in', init_nonlinearity: str = 'leaky_relu', - verbose: int = 0, **kwargs: Any, ): del kwargs # unused, just to capture any extra args from the config - if verbose > 1: - warnings.warn( - f'Using nn.init.kaiming_normal_ init fn with parameters: ' +\ - f'a={init_gain}, mode={fan_mode}, nonlinearity={init_nonlinearity}' - ) - kaiming_normal_ = partial(torch.nn.init.kaiming_normal_, a=init_gain, mode=fan_mode, @@ -417,7 +360,6 @@ def kaiming_normal_param_init_fn_( init_div_is_residual=init_div_is_residual, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim, - verbose=verbose, ) @@ -429,18 +371,11 @@ def xavier_uniform_param_init_fn_( emb_init_std: Optional[float] = None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None, init_gain: float = 0, - verbose: int = 0, **kwargs: Any, ): del kwargs # unused, just to capture any extra args from the config xavier_uniform_ = partial(torch.nn.init.xavier_uniform_, gain=init_gain) - if verbose > 1: - warnings.warn( - f'Using torch.nn.init.xavier_uniform_ init fn with parameters: ' +\ - f'gain={init_gain}' - ) - generic_param_init_fn_( module=module, init_fn_=xavier_uniform_, @@ -449,7 +384,6 @@ def xavier_uniform_param_init_fn_( init_div_is_residual=init_div_is_residual, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim, - verbose=verbose, ) @@ -461,18 +395,11 @@ def xavier_normal_param_init_fn_( emb_init_std: Optional[float] = None, emb_init_uniform_lim: Optional[Union[Tuple[float, float], float]] = None, init_gain: float = 0, - verbose: int = 0, **kwargs: Any, ): del kwargs # unused, just to capture any extra args from the config xavier_normal_ = partial(torch.nn.init.xavier_normal_, gain=init_gain) - if verbose > 1: - warnings.warn( - f'Using torch.nn.init.xavier_normal_ init fn with parameters: ' +\ - f'gain={init_gain}' - ) - generic_param_init_fn_( module=module, init_fn_=xavier_normal_, @@ -481,7 +408,6 @@ def xavier_normal_param_init_fn_( init_div_is_residual=init_div_is_residual, emb_init_std=emb_init_std, emb_init_uniform_lim=emb_init_uniform_lim, - verbose=verbose, ) diff --git a/llmfoundry/utils/builders.py b/llmfoundry/utils/builders.py index e4117e8a7d..937d30661e 100644 --- a/llmfoundry/utils/builders.py +++ b/llmfoundry/utils/builders.py @@ -1,6 +1,7 @@ # Copyright 2022 MosaicML LLM Foundry authors # SPDX-License-Identifier: Apache-2.0 +import logging import os from typing import Any, Dict, List, Optional, Tuple, Union @@ -30,6 +31,8 @@ from llmfoundry.optim import (DecoupledAdaLRLion, DecoupledClipLion, DecoupledLionW, DecoupledLionW_8bit) +log = logging.getLogger(__name__) + def build_icl_data_and_gauntlet( icl_tasks_config: Union[str, ListConfig], @@ -189,7 +192,7 @@ def build_icl_evaluators( icl_tasks_list = None if isinstance(icl_tasks, str): - print(f'Extracting ICL task config from path: {icl_tasks}') + log.info(f'Extracting ICL task config from path: {icl_tasks}') with open(icl_tasks, 'r') as icl_f: icl_task_cfg = om.load(icl_f) icl_tasks_list = icl_task_cfg.icl_tasks diff --git a/llmfoundry/utils/checkpoint_conversion_helpers.py 
b/llmfoundry/utils/checkpoint_conversion_helpers.py index 1e97ff7959..e058706316 100644 --- a/llmfoundry/utils/checkpoint_conversion_helpers.py +++ b/llmfoundry/utils/checkpoint_conversion_helpers.py @@ -10,6 +10,7 @@ """ import json +import logging import os import random import string @@ -20,6 +21,8 @@ import sentencepiece as spm from transformers import AutoTokenizer, PreTrainedTokenizer +log = logging.getLogger(__name__) + def _get_weight_data_type(data_type: str): if data_type == 'fp32': @@ -106,7 +109,7 @@ def _write_zero_bias(weight_name: str, weight_file_path: str, raise RuntimeError( f'Cannot write zero bias for {weight_name}. Input is not a weight tensor' ) - print(f'zero bias for weight: {weight_name}') + log.debug(f'zero bias for weight: {weight_name}') bias_file_path = weight_file_path.replace('.weight', '.bias') bias = np.zeros(bias_shape, dtype=np.float32) bias.tofile(bias_file_path) @@ -259,10 +262,10 @@ def convert_and_save_ft_weights(named_params: dict, } for name, param in named_params.items(): - print(f'Working on parameter {name} ...') + log.debug(f'Working on parameter {name} ...') data = param.detach().cpu().numpy().astype(np_weight_data_type) if name.find('weight') == -1 and name.find('bias') == -1: - print(f'found a parameter name that is not handled: {name}') + log.debug(f'found a parameter name that is not handled: {name}') continue if name == 'transformer.wpe.weight': assert data.shape == ( diff --git a/llmfoundry/utils/config_utils.py b/llmfoundry/utils/config_utils.py index 455433cb04..103f091c0a 100644 --- a/llmfoundry/utils/config_utils.py +++ b/llmfoundry/utils/config_utils.py @@ -2,6 +2,7 @@ # SPDX-License-Identifier: Apache-2.0 import contextlib +import logging import math import warnings from typing import Any, Dict, Mapping, Optional, Union @@ -12,6 +13,8 @@ from llmfoundry.models.utils import init_empty_weights +log = logging.getLogger(__name__) + def pop_config(cfg: DictConfig, key: str, @@ -56,8 +59,8 @@ def calculate_batch_size_info(global_batch_size: int, device_grad_accum = 'auto' elif isinstance(device_microbatch_size, int): if device_microbatch_size > device_batch_size: - print( - f'WARNING: device_microbatch_size > device_batch_size, ' + + log.warn( + f'device_microbatch_size > device_batch_size, ' + f'will be reduced from {device_microbatch_size} -> {device_batch_size}.' ) device_microbatch_size = device_batch_size diff --git a/scripts/data_prep/convert_dataset_hf.py b/scripts/data_prep/convert_dataset_hf.py index 01948822c2..fee56de54e 100644 --- a/scripts/data_prep/convert_dataset_hf.py +++ b/scripts/data_prep/convert_dataset_hf.py @@ -368,9 +368,8 @@ def main(args: Namespace) -> None: # Write samples print(f'Converting {folder_split} to MDS format...') print( - f'Note that the progress bar is based on the dataset length before tokenization.' + f'Note: the progress bar is based on the dataset length before tokenization, and may finish at a value before 100%.' 
) - print(f'It will finish at a value below 100% if tokenizing') with MDSWriter(columns=columns, out=os.path.join(args.out_root, folder_split), compression=args.compression) as out: diff --git a/scripts/data_prep/convert_finetuning_dataset.py b/scripts/data_prep/convert_finetuning_dataset.py index d2bf8923df..e24b6ec904 100644 --- a/scripts/data_prep/convert_finetuning_dataset.py +++ b/scripts/data_prep/convert_finetuning_dataset.py @@ -161,9 +161,7 @@ def main(args: Namespace) -> None: else: preprocessor_str = args.preprocessor preprocessing_fn = dataset_constructor.get_preprocessing_fn_from_str( - preprocessor=preprocessor_str, - dataset_name=args.dataset, - verbose=True) + preprocessor=preprocessor_str, dataset_name=args.dataset) if preprocessing_fn is None: raise ValueError( '`args.preprocessor` was not set and no preprocessing function ' +\ diff --git a/scripts/eval/eval.py b/scripts/eval/eval.py index 3408f31d33..1ba723a172 100644 --- a/scripts/eval/eval.py +++ b/scripts/eval/eval.py @@ -1,6 +1,7 @@ # Copyright 2022 MosaicML LLM Foundry authors # SPDX-License-Identifier: Apache-2.0 +import logging import os import sys import time @@ -90,14 +91,22 @@ def load_model(model_cfg: DictConfig, tokenizer: PreTrainedTokenizerBase, ) -def evaluate_model(model_cfg: DictConfig, dist_timeout: Union[float, int], - run_name: str, icl_tasks: Union[str, ListConfig], - max_seq_len: int, device_eval_batch_size: int, - eval_gauntlet_config: Optional[Union[str, DictConfig]], - fsdp_config: Optional[Dict], num_retries: int, - loggers_cfg: Dict[str, Any], python_log_level: str, - precision: str, eval_gauntlet_df: Optional[pd.DataFrame], - icl_subset_num_batches: Optional[int]): +def evaluate_model( + model_cfg: DictConfig, + dist_timeout: Union[float, int], + run_name: str, + icl_tasks: Union[str, ListConfig], + max_seq_len: int, + device_eval_batch_size: int, + eval_gauntlet_config: Optional[Union[str, DictConfig]], + fsdp_config: Optional[Dict], + num_retries: int, + loggers_cfg: Dict[str, Any], + python_log_level: Optional[str], + precision: str, + eval_gauntlet_df: Optional[pd.DataFrame], + icl_subset_num_batches: Optional[int], +): print(f'Evaluating model: {model_cfg.model_name}', flush=True) # Build tokenizer and model tokenizer_cfg: Dict[str, @@ -206,10 +215,10 @@ def main(cfg: DictConfig): 'device_eval_batch_size', must_exist=True) precision: str = pop_config(cfg, 'precision', must_exist=True) - python_log_level: str = pop_config(cfg, - 'python_log_level', - must_exist=False, - default_value='debug') + python_log_level: Optional[str] = pop_config(cfg, + 'python_log_level', + must_exist=False, + default_value='debug') # Optional Evaluation Parameters with default values seed: int = pop_config(cfg, 'seed', must_exist=False, default_value=17) @@ -246,6 +255,15 @@ def main(cfg: DictConfig): reproducibility.seed_all(seed) dist.initialize_dist(get_device(None), timeout=dist_timeout) + if python_log_level is not None: + logging.basicConfig( + # Example of format string + # 2022-06-29 11:22:26,152: rank0[822018][MainThread]: INFO: Message here + format= + f'%(asctime)s: rank{dist.get_global_rank()}[%(process)d][%(threadName)s]: %(levelname)s: %(name)s: %(message)s' + ) + logging.getLogger('llmfoundry').setLevel(python_log_level.upper()) + eval_gauntlet_df = None models_df = None composite_scores = None diff --git a/scripts/train/train.py b/scripts/train/train.py index 95aa22f44c..87217702e5 100644 --- a/scripts/train/train.py +++ b/scripts/train/train.py @@ -1,6 +1,7 @@ # Copyright 2022 MosaicML LLM 
Foundry authors # SPDX-License-Identifier: Apache-2.0 import copy +import logging import os import sys import warnings @@ -346,10 +347,10 @@ def main(cfg: DictConfig) -> Trainer: 'log_to_console', must_exist=False, default_value=True) - python_log_level: str = pop_config(cfg, - 'python_log_level', - must_exist=False, - default_value='debug') + python_log_level: Optional[str] = pop_config(cfg, + 'python_log_level', + must_exist=False, + default_value='debug') console_log_interval: Union[int, str] = pop_config(cfg, 'console_log_interval', must_exist=False, @@ -414,6 +415,16 @@ def main(cfg: DictConfig) -> Trainer: 'FSDP is not applicable for single-GPU training. Reverting to DDP.') fsdp_config = None + # set logging level + if python_log_level is not None: + logging.basicConfig( + # Example of format string + # 2022-06-29 11:22:26,152: rank0[822018][MainThread]: INFO: Message here + format= + f'%(asctime)s: rank{dist.get_global_rank()}[%(process)d][%(threadName)s]: %(levelname)s: %(name)s: %(message)s' + ) + logging.getLogger('llmfoundry').setLevel(python_log_level.upper()) + # Initialize context init_context = process_init_device(model_config, fsdp_config) logged_cfg.update({'fsdp_config': fsdp_config}, merge=True)
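Taken together, the hunks above replace print/warnings output with per-module loggers and gate them on the new python_log_level value wired into train.py and eval.py. A minimal, self-contained sketch of that behaviour follows; it is illustrative only and not part of the patch, and the logger name 'llmfoundry.example' plus the hard-coded rank tag are stand-ins (the scripts interpolate dist.get_global_rank() at config time):

import logging

# Stand-in for the per-module loggers the patch introduces,
# i.e. `log = logging.getLogger(__name__)` inside llmfoundry modules.
log = logging.getLogger('llmfoundry.example')

python_log_level = 'debug'  # would normally come from the run YAML

if python_log_level is not None:
    logging.basicConfig(
        # rank0 is hard-coded here for illustration; the real scripts
        # interpolate the distributed rank into this format string.
        format='%(asctime)s: rank0[%(process)d][%(threadName)s]: '
               '%(levelname)s: %(name)s: %(message)s')
    # Only the 'llmfoundry' logger subtree is lowered; the root logger
    # stays at WARNING, but propagated records still reach its handler.
    logging.getLogger('llmfoundry').setLevel(python_log_level.upper())

log.debug('Emitted at DEBUG; with python_log_level: info this line is filtered out.')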