diff --git a/.github/workflows/build.yaml b/.github/workflows/build.yaml index a665d9b0aa5..8213887ff76 100644 --- a/.github/workflows/build.yaml +++ b/.github/workflows/build.yaml @@ -153,7 +153,7 @@ jobs: runs-on: ["self-hosted", "${{ needs.build-and-push.outputs.runs_on }}", "multi-gpu"] if: needs.build-and-push.outputs.runs_on != 'ubuntu-latest' env: - PYTEST_FLAGS: ${{ (startsWith(github.ref, 'refs/tags/') || github.ref == 'refs/heads/main' || inputs.release-tests == 'true') && '--release' || '' }} + PYTEST_FLAGS: ${{ (startsWith(github.ref, 'refs/tags/') || github.ref == 'refs/heads/main' || inputs.release-tests == true) && '--release' || '' }} steps: - name: Checkout repository uses: actions/checkout@v4 diff --git a/server/text_generation_server/models/__init__.py b/server/text_generation_server/models/__init__.py index 15e746229a9..58131a3a736 100644 --- a/server/text_generation_server/models/__init__.py +++ b/server/text_generation_server/models/__init__.py @@ -16,6 +16,7 @@ from text_generation_server.models.custom_modeling.mpt_modeling import ( MPTForCausalLM, ) +from text_generation_server.models.bloom import BloomCausalLMBatch from text_generation_server.models.custom_modeling.bloom_modeling import ( BloomForCausalLM, ) @@ -522,7 +523,7 @@ def get_model( speculator=speculator, dtype=dtype, trust_remote_code=trust_remote_code, - batch_class=CausalLMBatchKeysLast, + batch_class=BloomCausalLMBatch, ) elif model_type == MPT: return CausalLM( diff --git a/server/text_generation_server/models/causal_lm.py b/server/text_generation_server/models/causal_lm.py index cac36ebdd0a..868a3cc0792 100644 --- a/server/text_generation_server/models/causal_lm.py +++ b/server/text_generation_server/models/causal_lm.py @@ -553,7 +553,8 @@ def __init__( if config.quantize in ["awq", "exl2", "gptq", "marlin"]: weights._set_gptq_params(model_id, revision) - model = model_class(config, weights) + prefix = "" + model = model_class(prefix, config, weights) torch.distributed.barrier(group=self.process_group) super().__init__( diff --git a/server/text_generation_server/models/custom_modeling/bloom_modeling.py b/server/text_generation_server/models/custom_modeling/bloom_modeling.py index 0d8a1b590e6..77b89c5bf16 100644 --- a/server/text_generation_server/models/custom_modeling/bloom_modeling.py +++ b/server/text_generation_server/models/custom_modeling/bloom_modeling.py @@ -816,7 +816,7 @@ def forward( class BloomForCausalLM(BloomPreTrainedModel): - def __init__(self, config, weights): + def __init__(self, prefix: str, config, weights): super().__init__(config) self.transformer = BloomModel(config, weights) diff --git a/server/text_generation_server/models/custom_modeling/clip.py b/server/text_generation_server/models/custom_modeling/clip.py index 56618bf16d7..27b9ff1cc78 100644 --- a/server/text_generation_server/models/custom_modeling/clip.py +++ b/server/text_generation_server/models/custom_modeling/clip.py @@ -446,7 +446,7 @@ def forward( class CLIPTextTransformer(nn.Module): - def __init__(self, config: CLIPTextConfig): + def __init__(self, prefix: str, config: CLIPTextConfig): super().__init__() self.config = config embed_dim = config.hidden_size @@ -536,9 +536,9 @@ class CLIPTextModel(CLIPPreTrainedModel): _no_split_modules = ["CLIPTextEmbeddings", "CLIPEncoderLayer"] - def __init__(self, config: CLIPTextConfig): + def __init__(self, prefix, config: CLIPTextConfig): super().__init__(config) - self.text_model = CLIPTextTransformer(config) + self.text_model = CLIPTextTransformer(prefix, config) # Initialize weights and apply final processing self.post_init() diff --git a/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py b/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py index e088f9aa308..f993fe72094 100644 --- a/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_cohere_modeling.py @@ -363,9 +363,9 @@ def forward(self, hidden_states): class FlashCohereLayer(nn.Module): - def __init__(self, layer_id, config, weights): + def __init__(self, prefix: str, layer_id, config, weights): super().__init__() - prefix = f"model.layers.{layer_id}" + prefix = f"{prefix}.layers.{layer_id}" self.self_attn = FlashCohereAttention( prefix=f"{prefix}.self_attn", config=config, weights=weights ) @@ -416,18 +416,19 @@ def forward( class FlashCohereModel(torch.nn.Module): - def __init__(self, config, weights): + def __init__(self, prefix: str, config, weights): super().__init__() process_group = weights.process_group self.tp_rank = process_group.rank() self.tp_world_size = process_group.size() self.embed_tokens = TensorParallelEmbedding( - prefix="model.embed_tokens", weights=weights + prefix=f"{prefix}.embed_tokens", weights=weights ) self.layers = nn.ModuleList( [ FlashCohereLayer( + prefix, layer_id, config, weights, @@ -436,7 +437,7 @@ def __init__(self, config, weights): ] ) self.norm = FastLayerNorm.load_no_bias( - prefix="model.norm", weights=weights, eps=config.layer_norm_eps + prefix=f"{prefix}.norm", weights=weights, eps=config.layer_norm_eps ) self.gradient_checkpointing = False @@ -486,10 +487,15 @@ def forward( class FlashCohereForCausalLM(torch.nn.Module): - def __init__(self, config, weights): + def __init__(self, prefix: str, config, weights): super().__init__() - self.model = FlashCohereModel(config, weights) + if not prefix: + prefix = "model" + else: + prefix = f"{prefix}.model" + + self.model = FlashCohereModel(prefix, config, weights) try: self.lm_head = SpeculativeHead.load( config, @@ -499,7 +505,7 @@ def __init__(self, config, weights): except RuntimeError: self.lm_head = SpeculativeHead.load( config, - prefix="model.embed_tokens", + prefix=f"{prefix}.embed_tokens", weights=weights, ) self.logit_scale = config.logit_scale diff --git a/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py b/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py index aea7f3994a2..e469495fc01 100644 --- a/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_dbrx_modeling.py @@ -593,9 +593,9 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class DbrxLayer(nn.Module): - def __init__(self, layer_id, config, weights): + def __init__(self, prefix: str, layer_id, config, weights): super().__init__() - prefix = f"transformer.blocks.{layer_id}" + prefix = f"{prefix}.blocks.{layer_id}" self.attn = DbrxNormAttentionNorm( prefix=f"{prefix}.norm_attn_norm", config=config, weights=weights @@ -637,16 +637,17 @@ def forward( class DbrxModel(torch.nn.Module): - def __init__(self, config, weights): + def __init__(self, prefix: str, config, weights): super().__init__() self.embed_tokens = TensorParallelEmbedding( - prefix="transformer.wte", weights=weights + prefix=f"{prefix}.wte", weights=weights ) self.layers = nn.ModuleList( [ DbrxLayer( + prefix, layer_id, config, weights, @@ -655,7 +656,7 @@ def __init__(self, config, weights): ] ) self.norm = FastLayerNorm.load_no_bias( - prefix="transformer.norm_f", weights=weights, eps=1e-5 + prefix=f"{prefix}.norm_f", weights=weights, eps=1e-5 ) self.head_size = self.layers[0].attn.self_attn.head_size @@ -702,9 +703,14 @@ def forward( class FlashDbrxForCausalLM(torch.nn.Module): - def __init__(self, config, weights): + def __init__(self, prefix: str, config, weights): super().__init__() + if not prefix: + prefix = "transformer" + else: + prefix = f"{prefix}.transformer" + self.model = DbrxModel(config, weights) self.lm_head = SpeculativeHead.load( config, diff --git a/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py index 625baa9109b..beff08b3080 100644 --- a/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_gemma2_modeling.py @@ -102,7 +102,7 @@ def __init__( class Gemma2FastRMSNorm(FastRMSNorm): @classmethod - def load(cls, prefix, weights, eps=1e-6): + def load(cls, prefix: str, weights, eps=1e-6): dtype = weights.dtype weights.dtype = torch.float32 weight = weights.get_tensor(f"{prefix}.weight") + 1 @@ -123,7 +123,7 @@ def forward(self, hidden_states, residual=None): return hidden_states.to(self.dtype), residual -def load_attention(config, prefix, weights): +def load_attention(config, prefix: str, weights): if config.num_attention_heads != config.num_key_value_heads: return _load_gqa(config, prefix, weights) else: @@ -305,7 +305,7 @@ def forward(self, hidden_states): class FlashGemma2Layer(nn.Module): - def __init__(self, prefix, config, weights, causal: bool, is_sliding: bool): + def __init__(self, prefix: str, config, weights, causal: bool, is_sliding: bool): super().__init__() self.self_attn = FlashGemma2Attention( prefix=f"{prefix}.self_attn", @@ -376,7 +376,7 @@ def forward( class FlashGemma2Model(torch.nn.Module): - def __init__(self, prefix, config, weights, causal: bool): + def __init__(self, prefix: str, config, weights, causal: bool): super().__init__() process_group = weights.process_group @@ -442,7 +442,7 @@ def forward( class FlashGemma2ForCausalLM(torch.nn.Module): - def __init__(self, prefix, config, weights, *, causal: bool = True): + def __init__(self, prefix: str, config, weights, *, causal: bool = True): super().__init__() embed_norm = config.hidden_size**0.5 diff --git a/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py b/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py index b7ce6307580..14b62b00b0b 100644 --- a/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_gemma_modeling.py @@ -102,7 +102,7 @@ def __init__( class GemmaFastRMSNorm(FastRMSNorm): @classmethod - def load(cls, prefix, weights, eps=1e-6): + def load(cls, prefix: str, weights, eps=1e-6): dtype = weights.dtype weights.dtype = torch.float32 weight = weights.get_tensor(f"{prefix}.weight") + 1 @@ -123,7 +123,7 @@ def forward(self, hidden_states, residual=None): return hidden_states.to(self.dtype), residual -def load_attention(config, prefix, weights): +def load_attention(config, prefix: str, weights): if config.num_attention_heads != config.num_key_value_heads: return _load_gqa(config, prefix, weights) else: @@ -261,7 +261,7 @@ def forward( class GemmaMLP(nn.Module): - def __init__(self, prefix, config, weights): + def __init__(self, prefix: str, config, weights): super().__init__() act = config.hidden_act self.act = ( @@ -299,7 +299,7 @@ def forward(self, hidden_states): class FlashGemmaLayer(nn.Module): - def __init__(self, prefix, config, weights, causal: bool): + def __init__(self, prefix: str, config, weights, causal: bool): super().__init__() self.self_attn = FlashGemmaAttention( prefix=f"{prefix}.self_attn", config=config, weights=weights, causal=causal @@ -354,7 +354,7 @@ def forward( class FlashGemmaModel(torch.nn.Module): - def __init__(self, prefix, config, weights, causal: bool): + def __init__(self, prefix: str, config, weights, causal: bool): super().__init__() process_group = weights.process_group @@ -419,7 +419,7 @@ def forward( class FlashGemmaForCausalLM(torch.nn.Module): - def __init__(self, prefix, config, weights, *, causal: bool = True): + def __init__(self, prefix: str, config, weights, *, causal: bool = True): super().__init__() embed_norm = config.hidden_size**0.5 diff --git a/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py index 9f8001468dd..d5dc25cff1a 100644 --- a/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_gpt2_modeling.py @@ -261,7 +261,7 @@ def forward( class GPT2MLP(nn.Module): - def __init__(self, prefix, config, weights): + def __init__(self, prefix: str, config, weights): super().__init__() act = config.activation_function self.act = ( @@ -298,7 +298,7 @@ def forward(self, hidden_states): class FlashGPT2Layer(nn.Module): - def __init__(self, prefix, config, weights): + def __init__(self, prefix: str, config, weights): super().__init__() self.self_attn = FlashGPT2Attention( prefix=f"{prefix}.attn", config=config, weights=weights @@ -350,7 +350,7 @@ def forward( class FlashGPT2Model(torch.nn.Module): - def __init__(self, prefix, config, weights): + def __init__(self, prefix: str, config, weights): super().__init__() process_group = weights.process_group @@ -414,7 +414,7 @@ def forward( class FlashGPT2ForCausalLM(torch.nn.Module): - def __init__(self, prefix, config, weights): + def __init__(self, prefix: str, config, weights): super().__init__() self.embed_tokens = TensorParallelEmbedding( diff --git a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py index 77a7e2d5738..78832341c3f 100644 --- a/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_llama_modeling.py @@ -54,7 +54,7 @@ raise ImportError(f"Could not load `vllm._custom_C`. Full error: {e}") -def load_attention(config, prefix, weights, layer_id): +def load_attention(config, prefix: str, weights, layer_id): # Only defined in granite. bias = getattr(config, "attention_bias", False) head_size = config.hidden_size // config.num_attention_heads @@ -467,7 +467,7 @@ def forward( class FlashLlamaForCausalLM(torch.nn.Module): - def __init__(self, prefix, config, weights): + def __init__(self, prefix: str, config, weights): super().__init__() self.embed_tokens = TensorParallelEmbedding( diff --git a/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py b/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py index 396969cd03e..8028dbe80a6 100644 --- a/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_mistral_modeling.py @@ -248,7 +248,7 @@ def forward( class MistralMLP(nn.Module): - def __init__(self, prefix, config, weights, layer_id): + def __init__(self, prefix: str, config, weights, layer_id): super().__init__() self.hidden_act = config.hidden_act self.act = ( @@ -328,7 +328,7 @@ def forward(self, hidden_states, adapter_data): class MistralLayer(nn.Module): - def __init__(self, prefix, config, weights, layer_id): + def __init__(self, prefix: str, config, weights, layer_id): super().__init__() self.self_attn = MistralAttention( prefix=f"{prefix}.self_attn", @@ -392,7 +392,7 @@ def forward( class MistralModel(torch.nn.Module): - def __init__(self, prefix, config, weights): + def __init__(self, prefix: str, config, weights): super().__init__() process_group = weights.process_group @@ -462,7 +462,7 @@ def forward( class FlashMistralForCausalLM(torch.nn.Module): - def __init__(self, prefix, config, weights, name=None): + def __init__(self, prefix: str, config, weights, name=None): if name is None: name = "model" super().__init__() diff --git a/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py b/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py index 2d6a7f972e2..429793ea575 100644 --- a/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_mixtral_modeling.py @@ -116,7 +116,7 @@ def promote_scalar(x: torch.Tensor) -> torch.Tensor: return x.view(1) if len(x.size()) == 0 else x -def load_attention(config, prefix, weights): +def load_attention(config, prefix: str, weights): if config.num_attention_heads != config.num_key_value_heads: return _load_gqa(config, prefix, weights) else: @@ -155,7 +155,7 @@ def _load_gqa(config, prefix: str, weights): ) -def _load_experts(config, prefix, mat, weights): +def _load_experts(config, prefix: str, mat, weights): if config.quantize is not None: raise NotImplementedError("Mixtral does not support weight quantization yet.") @@ -475,7 +475,7 @@ def forward(self, x: torch.Tensor) -> torch.Tensor: class MixtralLayer(nn.Module): - def __init__(self, prefix, layer_id, config, weights): + def __init__(self, prefix: str, layer_id, config, weights): super().__init__() prefix = f"{prefix}.layers.{layer_id}" @@ -536,7 +536,7 @@ def forward( class MixtralModel(torch.nn.Module): - def __init__(self, prefix, config, weights): + def __init__(self, prefix: str, config, weights): super().__init__() self.embed_tokens = TensorParallelEmbedding( @@ -610,7 +610,7 @@ def forward( class FlashMixtralForCausalLM(torch.nn.Module): - def __init__(self, prefix, config, weights): + def __init__(self, prefix: str, config, weights): super().__init__() self.model = MixtralModel(prefix, config, weights) diff --git a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py index 33aebc2be38..0eca181b67b 100644 --- a/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_neox_modeling.py @@ -305,12 +305,12 @@ class FlashGPTNeoXPreTrainedModel(PreTrainedModel): class FlashGPTNeoXModel(FlashGPTNeoXPreTrainedModel): - def __init__(self, config, weights): + def __init__(self, prefix: str, config, weights): super().__init__(config) self.config = config self.embed_in = TensorParallelEmbedding( - prefix="gpt_neox.embed_in", weights=weights + prefix=f"{prefix}.embed_in", weights=weights ) self.layers = nn.ModuleList( @@ -320,7 +320,7 @@ def __init__(self, config, weights): ] ) self.final_layer_norm = FastLayerNorm.load( - prefix="gpt_neox.final_layer_norm", + prefix=f"{prefix}.final_layer_norm", weights=weights, eps=config.layer_norm_eps, ) @@ -370,9 +370,15 @@ def forward( class FlashGPTNeoXForCausalLM(FlashGPTNeoXPreTrainedModel): - def __init__(self, config, weights): + def __init__(self, prefix, config, weights): super().__init__(config) - self.gpt_neox = FlashGPTNeoXModel(config, weights) + + if not prefix: + prefix = "gpt_neox" + else: + prefix = f"{prefix}.gpt_neox" + + self.gpt_neox = FlashGPTNeoXModel(prefix, config, weights) self.embed_out = SpeculativeHead.load( config, prefix="embed_out", weights=weights diff --git a/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py b/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py index f237ea37e0f..7401bc27a7f 100644 --- a/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_phi_modeling.py @@ -258,9 +258,9 @@ def forward(self, hidden_states): class FlashPhiLayer(nn.Module): - def __init__(self, layer_id, config, weights): + def __init__(self, prefix: str, layer_id, config, weights): super().__init__() - prefix = f"model.layers.{layer_id}" + prefix = f"{prefix}.layers.{layer_id}" self.self_attn = FlashPhiAttention( prefix=f"{prefix}.self_attn", config=config, weights=weights ) @@ -307,18 +307,19 @@ def forward( class FlashPhiModel(torch.nn.Module): - def __init__(self, config, weights): + def __init__(self, prefix: str, config, weights): super().__init__() process_group = weights.process_group self.tp_rank = process_group.rank() self.tp_world_size = process_group.size() self.embed_tokens = TensorParallelEmbedding( - prefix="model.embed_tokens", weights=weights + prefix=f"{prefix}.embed_tokens", weights=weights ) self.layers = nn.ModuleList( [ FlashPhiLayer( + prefix, layer_id, config, weights, @@ -378,10 +379,15 @@ def forward( class FlashPhiForCausalLM(torch.nn.Module): - def __init__(self, config, weights): + def __init__(self, prefix: str, config, weights): super().__init__() - self.model = FlashPhiModel(config, weights) + if not prefix: + prefix = "model" + else: + prefix = f"{prefix}.model" + + self.model = FlashPhiModel(prefix, config, weights) self.lm_head = SpeculativeHead.load( config, prefix="lm_head", diff --git a/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py b/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py index 1cc6a613dc1..a98709c51a5 100644 --- a/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_qwen2_modeling.py @@ -203,9 +203,9 @@ def forward(self, hidden_states): class Qwen2Layer(nn.Module): - def __init__(self, layer_id, config, weights): + def __init__(self, prefix, layer_id, config, weights): super().__init__() - prefix = f"model.layers.{layer_id}" + prefix = f"{prefix}.layers.{layer_id}" self.self_attn = Qwen2Attention( prefix=f"{prefix}.self_attn", config=config, weights=weights ) @@ -260,17 +260,18 @@ def forward( class Qwen2Model(torch.nn.Module): - def __init__(self, config, weights): + def __init__(self, prefix: str, config, weights): super().__init__() process_group = weights.process_group self.tp_rank = process_group.rank() self.tp_world_size = process_group.size() self.embed_tokens = TensorParallelEmbedding( - prefix="model.embed_tokens", weights=weights + prefix=f"{prefix}.embed_tokens", weights=weights ) self.layers = nn.ModuleList( [ Qwen2Layer( + prefix, layer_id, config, weights, @@ -279,7 +280,7 @@ def __init__(self, config, weights): ] ) self.norm = FastRMSNorm.load( - prefix="model.norm", weights=weights, eps=config.rms_norm_eps + prefix=f"{prefix}.norm", weights=weights, eps=config.rms_norm_eps ) self.gradient_checkpointing = False @@ -331,10 +332,15 @@ def forward( class Qwen2ForCausalLM(torch.nn.Module): - def __init__(self, config, weights): + def __init__(self, prefix: str, config, weights): super().__init__() - self.model = Qwen2Model(config, weights) + if not prefix: + prefix = "model" + else: + prefix = f"{prefix}.model" + + self.model = Qwen2Model(prefix, config, weights) self.lm_head = SpeculativeHead.load( config, prefix="lm_head", diff --git a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py index e7614232290..d12ed567a0a 100644 --- a/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_rw_modeling.py @@ -127,7 +127,7 @@ class FlashRWAttention(torch.nn.Module): def __init__( self, config, - prefix, + prefix: str, weights, ): super().__init__() @@ -236,7 +236,7 @@ class FlashRWLargeAttention(torch.nn.Module): def __init__( self, config, - prefix, + prefix: str, weights, ): super().__init__() @@ -358,7 +358,7 @@ def forward( class FlashMLP(nn.Module): - def __init__(self, config, prefix, weights): + def __init__(self, config, prefix: str, weights): super().__init__() self.act = torch.nn.functional.gelu @@ -380,6 +380,7 @@ class FlashRWLayer(nn.Module): def __init__( self, layer_id, + prefix: str, config, weights, ): @@ -388,7 +389,7 @@ def __init__( parallel_attn = config.parallel_attn self.parallel_attn = parallel_attn - prefix = f"transformer.h.{layer_id}" + prefix = f"{prefix}.h.{layer_id}" self.input_layernorm = FastLayerNorm.load( prefix=f"{prefix}.input_layernorm", @@ -479,7 +480,7 @@ def forward( class FlashRWLayerNorm(nn.Module): - def __init__(self, config, prefix, weights): + def __init__(self, config, prefix: str, weights): super().__init__() self.num_ln = config.num_ln_in_parallel_attn @@ -518,9 +519,9 @@ def forward( class FlashRWLargeLayer(nn.Module): - def __init__(self, layer_id, config, weights): + def __init__(self, layer_id, prefix: str, config, weights): super().__init__() - prefix = f"transformer.h.{layer_id}" + prefix = f"{prefix}.h.{layer_id}" self.ln_layer = FlashRWLayerNorm(config, prefix, weights) @@ -580,18 +581,18 @@ class FlashRWPreTrainedModel(PreTrainedModel): class FlashRWModel(FlashRWPreTrainedModel): - def __init__(self, config, weights): + def __init__(self, prefix: str, config, weights): super().__init__(config) self.config = config self.word_embeddings = TensorParallelEmbedding( - prefix="transformer.word_embeddings", weights=weights + prefix=f"{prefix}.word_embeddings", weights=weights ) if config.new_decoder_architecture: self.h = nn.ModuleList( [ - FlashRWLargeLayer(layer_id, config, weights) + FlashRWLargeLayer(layer_id, prefix, config, weights) for layer_id in range(config.num_hidden_layers) ] ) @@ -599,14 +600,14 @@ def __init__(self, config, weights): else: self.h = nn.ModuleList( [ - FlashRWLayer(layer_id, config, weights) + FlashRWLayer(layer_id, prefix, config, weights) for layer_id in range(config.num_hidden_layers) ] ) self.cache_size = self.h[0].self_attention.num_heads_kv self.ln_f = FastLayerNorm.load( - prefix="transformer.ln_f", + prefix=f"{prefix}.ln_f", weights=weights, eps=config.layer_norm_epsilon, ) @@ -653,10 +654,15 @@ def forward( class FlashRWForCausalLM(FlashRWPreTrainedModel): - def __init__(self, config, weights): + def __init__(self, prefix: str, config, weights): super().__init__(config) - self.transformer = FlashRWModel(config, weights) + if not prefix: + prefix = "transformer" + else: + prefix = f"{prefix}.transformer" + + self.transformer = FlashRWModel(prefix, config, weights) self.lm_head = SpeculativeHead.load(config, prefix="lm_head", weights=weights) diff --git a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py index daef43cc9ee..21a22046c9e 100644 --- a/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py +++ b/server/text_generation_server/models/custom_modeling/flash_santacoder_modeling.py @@ -346,9 +346,9 @@ def forward(self, hidden_states): class Block(nn.Module): - def __init__(self, layer_id, config, weights): + def __init__(self, prefix: str, layer_id, config, weights): super().__init__() - prefix = f"transformer.h.{layer_id}" + prefix = f"{prefix}.h.{layer_id}" self.ln_1 = FastLayerNorm.load( prefix=f"{prefix}.ln_1", weights=weights, eps=config.layer_norm_epsilon ) @@ -396,18 +396,18 @@ def forward( class FlashSantacoderModel(nn.Module): - def __init__(self, config, weights): + def __init__(self, prefix: str, config, weights): super().__init__() self.config = config self.process_group = weights.process_group self.wte = TensorParallelEmbedding( - prefix="transformer.wte", + prefix=f"{prefix}.wte", weights=weights, reduce=False, ) self.wpe = TensorParallelEmbedding( - prefix="transformer.wpe", + prefix=f"{prefix}.wpe", weights=weights, reduce=False, ) @@ -415,6 +415,7 @@ def __init__(self, config, weights): self.layers = nn.ModuleList( [ Block( + prefix, layer_id, config, weights, @@ -466,10 +467,16 @@ def forward( class FlashSantacoderForCausalLM(nn.Module): def __init__(self, prefix, config, weights): super().__init__() + + if not prefix: + prefix = "transformer" + else: + prefix = f"{prefix}.transformer" + config.transpose = config.architectures[0].startswith("GPT2") - self.model = FlashSantacoderModel(config, weights) + self.model = FlashSantacoderModel(prefix, config, weights) self.lm_head = SpeculativeHead.load( - config, prefix="transformer.wte", weights=weights + config, prefix=f"{prefix}.wte", weights=weights ) def forward( diff --git a/server/text_generation_server/models/custom_modeling/mpt_modeling.py b/server/text_generation_server/models/custom_modeling/mpt_modeling.py index f7981bf5311..fb09a8f1740 100644 --- a/server/text_generation_server/models/custom_modeling/mpt_modeling.py +++ b/server/text_generation_server/models/custom_modeling/mpt_modeling.py @@ -783,7 +783,7 @@ class MPTPreTrainedModel(PreTrainedModel): class MPTModel(MPTPreTrainedModel): - def __init__(self, config, weights): + def __init__(self, prefix: str, config, weights): # config._validate_config() super().__init__(config) self.world_size = weights.process_group.size() @@ -809,13 +809,13 @@ def __init__(self, config, weights): f"Requested norm type ({config.norm_type}) is not implemented within this repo." ) - self.wte = TensorParallelEmbedding("transformer.wte", weights) + self.wte = TensorParallelEmbedding(f"{prefix}.wte", weights) if not self.alibi: - self.wpe = TensorParallelEmbedding("transformer.wpe", weights) + self.wpe = TensorParallelEmbedding(f"{prefix}.wpe", weights) self.blocks = nn.ModuleList( [ - MPTBlock(config, prefix=f"transformer.blocks.{i}", weights=weights) + MPTBlock(config, prefix=f"{prefix}.blocks.{i}", weights=weights) for i in range(config.n_layers) ] ) @@ -1085,13 +1085,19 @@ def forward( class MPTForCausalLM(MPTPreTrainedModel): - def __init__(self, config, weights): + def __init__(self, prefix: str, config, weights): super().__init__(config) + + if not prefix: + prefix = "transformer" + else: + prefix = f"{prefix}.transformer" + if not config.tie_word_embeddings: raise ValueError("MPTForCausalLM only supports tied word embeddings") - self.transformer = MPTModel(config, weights) + self.transformer = MPTModel(prefix, config, weights) self.lm_head = SpeculativeHead.load( - config, prefix="transformer.wte", weights=weights + config, prefix=f"{prefix}.wte", weights=weights ) self.logit_scale = None if config.logit_scale is not None: diff --git a/server/text_generation_server/models/custom_modeling/neox_modeling.py b/server/text_generation_server/models/custom_modeling/neox_modeling.py index fcad32fa79c..8998778fdb0 100644 --- a/server/text_generation_server/models/custom_modeling/neox_modeling.py +++ b/server/text_generation_server/models/custom_modeling/neox_modeling.py @@ -404,24 +404,24 @@ def forward(self, hidden_states): class GPTNeoXLayer(nn.Module): - def __init__(self, layer_id, config, weights): + def __init__(self, layer_id, prefix: str, config, weights): super().__init__() self.use_parallel_residual = config.use_parallel_residual self.input_layernorm = nn.LayerNorm.load( - prefix=f"gpt_neox.layers.{layer_id}.input_layernorm", + prefix=f"{prefix}.layers.{layer_id}.input_layernorm", weights=weights, eps=config.layer_norm_eps, ) self.post_attention_layernorm = nn.LayerNorm.load( - prefix=f"gpt_neox.layers.{layer_id}.post_attention_layernorm", + prefix=f"{prefix}.layers.{layer_id}.post_attention_layernorm", weights=weights, eps=config.layer_norm_eps, ) self.attention = GPTNeoXAttention( - config, prefix=f"gpt_neox.layers.{layer_id}.attention", weights=weights + config, prefix=f"{prefix}.layers.{layer_id}.attention", weights=weights ) self.mlp = GPTNeoXMLP( - config, prefix=f"gpt_neox.layers.{layer_id}.mlp", weights=weights + config, prefix=f"{prefix}.layers.{layer_id}.mlp", weights=weights ) def forward( @@ -472,23 +472,23 @@ def forward( class GPTNeoXModel(GPTNeoXPreTrainedModel): - def __init__(self, config, weights): + def __init__(self, prefix: str, config, weights): super().__init__(config) self.config = config self.num_attention_heads = config.num_attention_heads self.embed_in = TensorParallelEmbedding( - prefix="gpt_neox.embed_in", weights=weights + prefix=f"{prefix}.embed_in", weights=weights ) self.layers = nn.ModuleList( [ - GPTNeoXLayer(layer_id, config, weights) + GPTNeoXLayer(layer_id, prefix, config, weights) for layer_id in range(config.num_hidden_layers) ] ) self.final_layer_norm = nn.LayerNorm.load( - prefix="gpt_neox.final_layer_norm", + prefix=f"{prefix}.final_layer_norm", weights=weights, eps=config.layer_norm_eps, ) @@ -640,9 +640,15 @@ def forward( class GPTNeoxForCausalLM(GPTNeoXPreTrainedModel): _keys_to_ignore_on_load_missing = [r"position_ids", r"predictions.decoder.bias"] - def __init__(self, config, weights): + def __init__(self, prefix: str, config, weights): super().__init__(config) - self.gpt_neox = GPTNeoXModel(config, weights) + + if not prefix: + prefix = "gpt_neox" + else: + prefix = f"{prefix}.gpt_neox" + + self.gpt_neox = GPTNeoXModel(prefix, config, weights) self.embed_out = SpeculativeHead.load( config, prefix="embed_out", weights=weights ) diff --git a/server/text_generation_server/models/custom_modeling/opt_modeling.py b/server/text_generation_server/models/custom_modeling/opt_modeling.py index 9b2d01e0763..5ab02959126 100644 --- a/server/text_generation_server/models/custom_modeling/opt_modeling.py +++ b/server/text_generation_server/models/custom_modeling/opt_modeling.py @@ -94,11 +94,11 @@ class OPTLearnedPositionalEmbedding(nn.Module): This module learns positional embeddings up to a fixed maximum size. """ - def __init__(self, weights): + def __init__(self, prefix: str, weights): super().__init__() self.offset = 2 self.weight = nn.Parameter( - weights.get_tensor("model.decoder.embed_positions.weight") + weights.get_tensor(f"{prefix}.decoder.embed_positions.weight") ) def forward( @@ -311,11 +311,11 @@ def forward( class OPTDecoderLayer(nn.Module): - def __init__(self, layer_id: int, config: OPTConfig, weights): + def __init__(self, layer_id: int, prefix: str, config: OPTConfig, weights): super().__init__() self.process_group = weights.process_group self.hidden_size = config.hidden_size - prefix = f"model.decoder.layers.{layer_id}" + prefix = f"{prefix}.decoder.layers.{layer_id}" self.self_attn = OPTAttention( config, prefix=f"{prefix}.self_attn", @@ -429,7 +429,7 @@ class OPTPreTrainedModel(PreTrainedModel): class OPTDecoder(OPTPreTrainedModel): - def __init__(self, config: OPTConfig, weights): + def __init__(self, prefix: str, config: OPTConfig, weights): super().__init__(config) self.dropout = config.dropout self.layerdrop = config.layerdrop @@ -438,20 +438,26 @@ def __init__(self, config: OPTConfig, weights): self.vocab_size = config.vocab_size self.embed_tokens = TensorParallelEmbedding( - prefix="model.decoder.embed_tokens", weights=weights + prefix=f"{prefix}.decoder.embed_tokens", weights=weights ) - self.embed_positions = OPTLearnedPositionalEmbedding(weights) + self.embed_positions = OPTLearnedPositionalEmbedding(prefix, weights) if config.word_embed_proj_dim != config.hidden_size: self.project_out = FastLinear.load( - config, prefix="model.decoder.project_out", weights=weights, bias=False + config, + prefix=f"{prefix}.decoder.project_out", + weights=weights, + bias=False, ) else: self.project_out = None if config.word_embed_proj_dim != config.hidden_size: self.project_in = FastLinear.load( - config, prefix="model.decoder.project_in", weights=weights, bias=False + config, + prefix=f"{prefix}.decoder.project_in", + weights=weights, + bias=False, ) else: self.project_in = None @@ -461,14 +467,14 @@ def __init__(self, config: OPTConfig, weights): # see https://github.com/facebookresearch/metaseq/pull/164 if config.do_layer_norm_before and not config._remove_final_layer_norm: self.final_layer_norm = nn.LayerNorm.load( - prefix="model.decoder.final_layer_norm", weights=weights, eps=EPS + prefix=f"{prefix}.decoder.final_layer_norm", weights=weights, eps=EPS ) else: self.final_layer_norm = None self.layers = nn.ModuleList( [ - OPTDecoderLayer(layer_id, config, weights) + OPTDecoderLayer(layer_id, prefix, config, weights) for layer_id in range(config.num_hidden_layers) ] ) @@ -686,9 +692,9 @@ def forward( class OPTModel(OPTPreTrainedModel): - def __init__(self, config: OPTConfig, weights): + def __init__(self, prefix: str, config: OPTConfig, weights): super().__init__(config) - self.decoder = OPTDecoder(config, weights) + self.decoder = OPTDecoder(prefix, config, weights) # Initialize weights and apply final processing def forward( @@ -743,13 +749,18 @@ def forward( class OPTForCausalLM(OPTPreTrainedModel): - def __init__(self, config, weights): + def __init__(self, prefix, config, weights): super().__init__(config) + if not prefix: + prefix = "model" + else: + prefix = f"{prefix}.model" + self.model = OPTModel(config, weights) self.lm_head = SpeculativeHead.load( - config, prefix="model.decoder.embed_tokens", weights=weights + config, prefix=f"{prefix}.decoder.embed_tokens", weights=weights ) def forward( diff --git a/server/text_generation_server/models/custom_modeling/phi_modeling.py b/server/text_generation_server/models/custom_modeling/phi_modeling.py index 04b470eb7d9..b4d56db131e 100644 --- a/server/text_generation_server/models/custom_modeling/phi_modeling.py +++ b/server/text_generation_server/models/custom_modeling/phi_modeling.py @@ -248,16 +248,16 @@ def forward( # PhiModel implements the embedding layer and the transformer blocks. class PhiModel(nn.Module): - def __init__(self, config, weights): + def __init__(self, prefix: str, config, weights): super().__init__() self.tp_rank = weights.process_group.rank() self.tp_world_size = weights.process_group.size() self.embed_tokens = TensorParallelEmbedding( - prefix="transformer.embd.wte", weights=weights + prefix=f"{prefix}.embd.wte", weights=weights ) self.blocks = nn.ModuleList( [ - PhiBlock(f"transformer.h.{layer_id}", config, weights) + PhiBlock(f"{prefix}.h.{layer_id}", config, weights) for layer_id in range(config.n_layer) ] ) @@ -289,9 +289,15 @@ def forward( # PhiForCausalLM wraps the PhiModel and PhiCausalLMHead together and returns a CausalLMOutputWithPast object. class PhiForCausalLM(torch.nn.Module): - def __init__(self, config, weights): + def __init__(self, prefix: str, config, weights): super().__init__() - self.model = PhiModel(config, weights) + + if not prefix: + prefix = "transformer" + else: + prefix = f"{prefix}.transformer" + + self.model = PhiModel(prefix, config, weights) self.lm_head = PhiCausalLMHead(config, weights) def forward( diff --git a/server/text_generation_server/models/flash_causal_lm.py b/server/text_generation_server/models/flash_causal_lm.py index c7f5f1f9331..e66011a19a7 100644 --- a/server/text_generation_server/models/flash_causal_lm.py +++ b/server/text_generation_server/models/flash_causal_lm.py @@ -878,10 +878,6 @@ def __init__( ) config.quantize = quantize config.speculator = speculator - if getattr(config, "sliding_window", None) is not None: - set_sliding_window(config.sliding_window) - else: - config.sliding_window = None torch.distributed.barrier(group=self.process_group) @@ -900,13 +896,22 @@ def __init__( text_config = getattr(config, "text_config", None) if text_config is not None: config = text_config + + if getattr(config, "sliding_window", None) is not None: + set_sliding_window(config.sliding_window) + else: + config.sliding_window = None + self.num_layers = config.num_hidden_layers # Validation is done in the model itself if num_kv_heads is None: - num_kv_heads = getattr(config, "num_key_value_heads", None) + # Order is important here. + for attr in ["num_key_value_heads", "num_key_value_heads", "n_head"]: + num_kv_heads = getattr(config, "num_attention_heads", None) + if num_kv_heads is not None: + break if num_kv_heads is None: - # Final overide for GPT2 - num_kv_heads = config.n_head + raise ValueError("Cannot get the number of key/value heads") self.num_kv_heads = num_kv_heads // self.process_group.size() self.head_size = config.hidden_size // config.num_attention_heads