From 88d87021e9629401f3c18602c43353c2f17f4c01 Mon Sep 17 00:00:00 2001
From: JegernOUTT
Date: Fri, 15 Mar 2024 19:10:24 +1030
Subject: [PATCH] the fix

---
 .../finetune/modelling/flash_sa.py          | 16 +++++++---------
 .../finetune/modelling/triton_flash_sa.py   |  7 +++----
 .../finetune/modelling/utils.py             | 10 ++++++++++
 .../finetune/scripts/auxiliary/model.py     | 17 +++++++++++------
 4 files changed, 31 insertions(+), 19 deletions(-)

diff --git a/self_hosting_machinery/finetune/modelling/flash_sa.py b/self_hosting_machinery/finetune/modelling/flash_sa.py
index 1609df34..d7b8b483 100644
--- a/self_hosting_machinery/finetune/modelling/flash_sa.py
+++ b/self_hosting_machinery/finetune/modelling/flash_sa.py
@@ -1,10 +1,11 @@
 import functools
 import math
+from typing import Tuple, Optional
 
 import einops
 import torch
 
-from typing import Tuple, Optional
+from self_hosting_machinery.finetune.modelling.utils import get_base_model
 from self_hosting_machinery.finetune.utils import traces
 
 
@@ -42,6 +43,7 @@ def _get_alibi_slopes(attn_heads: int, dev: str) -> torch.Tensor:
         m = torch.cat([m, m_hat])
     return m
 
+
 def _prerequisites_are_ok(model, try_triton_kernel: bool):
     try:
         from flash_attn import flash_attn_func
@@ -90,8 +92,7 @@ def _forward(
         return
 
     traces.log("Applying flash attention to the model")
-    if type(model).__name__ == 'PeftModelForCausalLM':
-        model = model.base_model.model
+    model = get_base_model(model)
     for block in model.transformer.h:
         block.attn.forward = _forward.__get__(block.attn, type(block.attn))
 
@@ -134,8 +135,7 @@ def _forward(
         return
 
     traces.log("Applying flash attention to the model")
-    if type(model).__name__ == 'PeftModelForCausalLM':
-        model = model.base_model.model
+    model = get_base_model(model)
     for block in model.transformer.h:
         block.attn.forward = _forward.__get__(block.attn, type(block.attn))
 
@@ -184,8 +184,7 @@ def _forward(
         return
 
    traces.log("Applying flash attention to the model")
-    if type(model).__name__ == 'PeftModelForCausalLM':
-        model = model.base_model.model
+    model = get_base_model(model)
     for layer in model.base_model.layers:
         layer.self_attn.forward = _forward.__get__(layer.self_attn, type(layer.self_attn))
 
@@ -234,7 +233,6 @@ def _forward(
         return
 
     traces.log("Applying flash attention to the model")
-    if type(model).__name__ == 'PeftModelForCausalLM':
-        model = model.base_model.model
+    model = get_base_model(model)
     for layer in model.base_model.layers:
         layer.self_attn.forward = _forward.__get__(layer.self_attn, type(layer.self_attn))
diff --git a/self_hosting_machinery/finetune/modelling/triton_flash_sa.py b/self_hosting_machinery/finetune/modelling/triton_flash_sa.py
index d3467a3c..08b8f4fd 100644
--- a/self_hosting_machinery/finetune/modelling/triton_flash_sa.py
+++ b/self_hosting_machinery/finetune/modelling/triton_flash_sa.py
@@ -1,15 +1,13 @@
 import functools
-import logging
 import math
+from typing import Optional
 
 import torch as th
 import triton
 import triton.language as tl
-
-from typing import Optional, Tuple
-
 from einops import einops
 
+from self_hosting_machinery.finetune.modelling.utils import get_base_model
 from self_hosting_machinery.finetune.utils import traces
 
 
@@ -597,5 +595,6 @@ def _forward(
         return
 
     traces.log("Applying triton flash attention to the model")
+    model = get_base_model(model)
     for block in model.transformer.h:
         block.attn.forward = _forward.__get__(block.attn, type(block.attn))
diff --git a/self_hosting_machinery/finetune/modelling/utils.py b/self_hosting_machinery/finetune/modelling/utils.py
index 9025a85a..9ab9eb1f 100644
--- a/self_hosting_machinery/finetune/modelling/utils.py
+++ b/self_hosting_machinery/finetune/modelling/utils.py
@@ -1,8 +1,18 @@
 from typing import List, Tuple
 
+import torch
+
 from self_hosting_machinery.finetune.configuration import supported_models
 
 
+def get_base_model(model: torch.nn.Module) -> torch.nn.Module:
+    if type(model).__name__ == "DeepSpeedEngine":
+        model = model.base_model
+    if type(model).__name__ in ("LoraModel", "PeftModelForCausalLM"):
+        model = model.model
+    return model
+
+
 def map_model_specific_params(
         model_name: str,
         freeze_exceptions: List[str],
diff --git a/self_hosting_machinery/finetune/scripts/auxiliary/model.py b/self_hosting_machinery/finetune/scripts/auxiliary/model.py
index 8e69ab13..7ab5bb5d 100644
--- a/self_hosting_machinery/finetune/scripts/auxiliary/model.py
+++ b/self_hosting_machinery/finetune/scripts/auxiliary/model.py
@@ -2,6 +2,7 @@
 import os
 from collections import defaultdict
 from functools import partial
+from pathlib import Path
 from typing import Dict, Any, List, Tuple
 
 import deepspeed
@@ -14,7 +15,7 @@
 
 from self_hosting_machinery.finetune.configuration import supported_models
 from self_hosting_machinery.finetune.modelling.loss import masked_loss
-from self_hosting_machinery.finetune.modelling.utils import map_model_specific_params
+from self_hosting_machinery.finetune.modelling.utils import map_model_specific_params, get_base_model
 from self_hosting_machinery.finetune.utils import traces
 from self_hosting_machinery.finetune.utils.timer import Timer
 
@@ -73,8 +74,9 @@ def __init__(
         )
         self.use_deepspeed = True
 
+        traces.log(summary(get_base_model(self.model), depth=4,
+                           col_names=['num_params', 'params_percent', 'trainable'], verbose=0))
         traces.log("Allocated memory: %0.2fG" % (torch.cuda.max_memory_allocated() / 1e9))
-        traces.log(summary(self.model, depth=4, col_names=['num_params', 'params_percent', 'trainable'], verbose=0))
 
         self.loss_fn = partial(
             masked_loss,
@@ -195,18 +197,21 @@ def save_model_state(
             self,
             save_path: str,
             tag: str
     ):
-        output_path = os.path.join(save_path, tag)
+        output_path = Path(save_path) / tag
+        weights_path = output_path / "adapter_model.safetensors"
+        embeddings_path = output_path / "new_embeddings.safetensors"
         self.model.save_pretrained(output_path, safe_serialization=True)
-        weights = safetensors.torch.load_file(os.path.join(output_path, "adapter_model.safetensors"))
+        weights = safetensors.torch.load_file(weights_path)
         lora_weights, embeddings_weights = {}, {}
         for key in weights.keys():
             if "lora" in key:
                 lora_weights[key] = weights[key]
             else:
                 embeddings_weights[key] = weights[key]
-        safetensors.torch.save_file(lora_weights, os.path.join(output_path, "adapter_model.safetensors"))
         if len(embeddings_weights) > 0:
-            safetensors.torch.save_file(embeddings_weights, os.path.join(output_path, "new_embeddings.safetensors"))
+            weights_path.unlink()
+            safetensors.torch.save_file(lora_weights, weights_path)
+            safetensors.torch.save_file(embeddings_weights, embeddings_path)
     def _freeze_model(
             self,