diff --git a/magma/language_model.py b/magma/language_model.py
index 1902540..664f69d 100644
--- a/magma/language_model.py
+++ b/magma/language_model.py
@@ -1,45 +1,32 @@
 import torch
-from transformers import GPTNeoForCausalLM, AutoConfig, GPT2LMHeadModel
+from transformers import AutoModelForCausalLM, GPTJForCausalLM, GPTJConfig
 from .utils import print_main
 from pathlib import Path
 from transformers.modeling_utils import no_init_weights
+from magma.config import MultimodalConfig
 
 LANGUAGE_MODELS = [
     "gptj",
 ]
 
-
-def gptj_config():
-    config = AutoConfig.from_pretrained("EleutherAI/gpt-neo-2.7B")
-    config.attention_layers = ["global"] * 28
-    config.attention_types = [["global"], 28]
-    config.num_layers = 28
-    config.num_heads = 16
-    config.hidden_size = 256 * config.num_heads
-    config.vocab_size = 50400
-    config.rotary = True
-    config.rotary_dim = 64
-    config.jax = True
-    config.gradient_checkpointing = True
-    return config
-
-
-def get_gptj(
+def get_gptj(config: MultimodalConfig,
     gradient_checkpointing: bool = True,
-    from_pretrained=False,
+    from_pretrained="EleutherAI/gpt-j-6B",
 ) -> torch.nn.Module:
     """
     Loads GPTJ language model from HF
     """
     print_main("Loading GPTJ language model...")
-    config = gptj_config()
-    config.gradient_checkpointing = gradient_checkpointing
+    gptj_config = GPTJConfig()
+    gptj_config.gradient_checkpointing = gradient_checkpointing
     if gradient_checkpointing:
-        config.use_cache = False
-    config.model_device = "cpu"
-    if from_pretrained:
-        raise NotImplemented("GPTJ pretrained not implemented")
+        gptj_config.use_cache = False
+
+    if config.deepspeed_config_params['fp16']['enabled'] is True:
+        model = GPTJForCausalLM.from_pretrained(
+            from_pretrained, revision="float16", torch_dtype=torch.float16, low_cpu_mem_usage=True, config=gptj_config
+        )
     else:
-        with no_init_weights():
-            model = GPTNeoForCausalLM(config=config)
+        model = AutoModelForCausalLM.from_pretrained(from_pretrained, config=gptj_config)
+
     return model
diff --git a/magma/magma.py b/magma/magma.py
index 9b57446..de4c3a5 100644
--- a/magma/magma.py
+++ b/magma/magma.py
@@ -5,6 +5,8 @@
 from copy import deepcopy
 from typing import Literal, Optional, List
 from torchtyping import TensorType
+from torch.nn.modules.container import ModuleList, Sequential
+from torch.nn.parameter import Parameter
 from transformers.file_utils import ModelOutput
 
 from magma.config import MultimodalConfig
@@ -40,7 +42,7 @@ def __init__(self, config, device=None):
             "cuda" if torch.cuda.is_available() else "cpu"
         )
         self.config = config
-        self.lm = get_gptj() #.to(self.device)
+        self.lm = get_gptj(config) #.to(self.device)
         self.seq_len = self.lm.config.max_position_embeddings
 
         self.tokenizer = get_tokenizer("gpt2", sequence_length=self.seq_len)
@@ -89,6 +91,21 @@ def __init__(self, config, device=None):
                     **attn_config,
                 )
 
+        #check weights contiguous
+        for name, param in self.named_parameters():
+            if param.is_contiguous() is False:
+                path, param = name.rsplit(".",1)
+                path = path.split('.')
+                ref = self
+                while path:
+                    element, path = path[0], path[1:]
+                    if type(ref) in {Sequential, ModuleList}:
+                        ref = ref[int(element)]
+                    else:
+                        ref = getattr(ref, element)
+                setattr(ref, param, Parameter(getattr(ref, param).contiguous()))
+                # print(name, getattr(ref, param).is_contiguous())
+
         # freeze parameters
         if config.freeze_lm:
             for name, param in self.lm.named_parameters():  # freeze lm weights
diff --git a/magma/utils.py b/magma/utils.py
index 8c88fdf..be939e8 100644
--- a/magma/utils.py
+++ b/magma/utils.py
@@ -46,7 +46,7 @@ def get_tokenizer(name="gpt2", sequence_length=2048):
     """
     if name == "gpt2":
         tokenizer = GPT2TokenizerFast.from_pretrained("gpt2")
-        tokenizer.pad_token_id = tokenizer.eos_token
+        tokenizer.pad_token_id = tokenizer.eos_token_id
         tokenizer.padding_side = "right"
         tokenizer.model_max_length = sequence_length
         # setup lm settings
diff --git a/requirements.txt b/requirements.txt
index d770009..f4bec03 100644
--- a/requirements.txt
+++ b/requirements.txt
@@ -1,6 +1,6 @@
 torchtyping
 typeguard
-git+https://github.com/finetuneanon/transformers.git#egg=transformers
+transformers
 gdown
 tqdm
 timm
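
For context, the new loading path in get_gptj boils down to the standalone sketch below. It is an illustration rather than the patched function itself: the deepspeed_config_params dict is a hypothetical stand-in for the value that MultimodalConfig normally provides, and the model name mirrors the new from_pretrained default.

import torch
from transformers import AutoModelForCausalLM, GPTJForCausalLM, GPTJConfig

# Hypothetical stand-in for config.deepspeed_config_params from MultimodalConfig.
deepspeed_config_params = {"fp16": {"enabled": True}}

gptj_config = GPTJConfig()  # defaults already match GPT-J-6B (28 layers, 16 heads, rotary_dim=64)
gptj_config.gradient_checkpointing = True
gptj_config.use_cache = False  # caching is incompatible with gradient checkpointing

if deepspeed_config_params["fp16"]["enabled"]:
    # The "float16" revision of EleutherAI/gpt-j-6B ships fp16 weights, roughly halving the
    # download and host-RAM footprint; low_cpu_mem_usage skips the random-init pass before loading.
    lm = GPTJForCausalLM.from_pretrained(
        "EleutherAI/gpt-j-6B",
        revision="float16",
        torch_dtype=torch.float16,
        low_cpu_mem_usage=True,
        config=gptj_config,
    )
else:
    lm = AutoModelForCausalLM.from_pretrained("EleutherAI/gpt-j-6B", config=gptj_config)

Either branch loads real pretrained weights, replacing the old path that built a GPT-Neo config by hand and instantiated an untrained GPTNeoForCausalLM.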
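
The weight-contiguity loop added in magma/magma.py can also be expressed as a small helper. The version below is a sketch rather than the patch's exact code: it relies on nn.Module.get_submodule (PyTorch 1.9+) instead of walking Sequential/ModuleList containers by hand, and it detaches before re-wrapping so the replacement is a clean leaf Parameter. The likely motivation, not stated in the diff, is that checkpoints loaded in fp16 can expose non-contiguous parameter storage, which optimizer state and DeepSpeed buffer flattening do not tolerate.

from torch import nn

def make_params_contiguous(model: nn.Module) -> None:
    # Replace every non-contiguous parameter of `model` with a contiguous copy, in place.
    for name, param in list(model.named_parameters()):
        if not param.is_contiguous():
            owner_path, _, attr = name.rpartition(".")  # "" when the parameter sits on `model` itself
            owner = model.get_submodule(owner_path)     # get_submodule("") returns `model`
            fixed = nn.Parameter(param.detach().contiguous(), requires_grad=param.requires_grad)
            setattr(owner, attr, fixed)                 # re-registers the parameter on its owning module

Calling make_params_contiguous(self) at the end of __init__ would have roughly the same effect as the loop in the patch.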