
Commit 88d8702: "the fix"
JegernOUTT authored and olegklimov committed Mar 15, 2024
1 parent e108a81 commit 88d8702
Showing 4 changed files with 31 additions and 19 deletions.
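In short, the fix replaces the scattered `if type(model).__name__ == 'PeftModelForCausalLM'` unwrapping with a shared get_base_model() helper in modelling/utils.py, applies that helper everywhere an attention forward is monkey-patched, logs a torchinfo summary of the unwrapped model, and reworks save_model_state() to use pathlib and write LoRA and embedding weights to separate safetensors files.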
self_hosting_machinery/finetune/modelling/flash_sa.py (16 changes: 7 additions, 9 deletions)
@@ -1,10 +1,11 @@
 import functools
 import math
+from typing import Tuple, Optional
 
 import einops
 import torch
-from typing import Tuple, Optional
 
+from self_hosting_machinery.finetune.modelling.utils import get_base_model
 from self_hosting_machinery.finetune.utils import traces
 
 
@@ -42,6 +43,7 @@ def _get_alibi_slopes(attn_heads: int, dev: str) -> torch.Tensor:
     m = torch.cat([m, m_hat])
     return m
 
+
 def _prerequisites_are_ok(model, try_triton_kernel: bool):
     try:
         from flash_attn import flash_attn_func
@@ -90,8 +92,7 @@ def _forward(
         return
 
     traces.log("Applying flash attention to the model")
-    if type(model).__name__ == 'PeftModelForCausalLM':
-        model = model.base_model.model
+    model = get_base_model(model)
     for block in model.transformer.h:
         block.attn.forward = _forward.__get__(block.attn, type(block.attn))
 
@@ -134,8 +135,7 @@ def _forward(
         return
 
     traces.log("Applying flash attention to the model")
-    if type(model).__name__ == 'PeftModelForCausalLM':
-        model = model.base_model.model
+    model = get_base_model(model)
     for block in model.transformer.h:
         block.attn.forward = _forward.__get__(block.attn, type(block.attn))
 
@@ -184,8 +184,7 @@ def _forward(
         return
 
     traces.log("Applying flash attention to the model")
-    if type(model).__name__ == 'PeftModelForCausalLM':
-        model = model.base_model.model
+    model = get_base_model(model)
     for layer in model.base_model.layers:
         layer.self_attn.forward = _forward.__get__(layer.self_attn, type(layer.self_attn))
 
@@ -234,7 +233,6 @@ def _forward(
         return
 
     traces.log("Applying flash attention to the model")
-    if type(model).__name__ == 'PeftModelForCausalLM':
-        model = model.base_model.model
+    model = get_base_model(model)
     for layer in model.base_model.layers:
         layer.self_attn.forward = _forward.__get__(layer.self_attn, type(layer.self_attn))
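A note on the patching idiom above, which recurs in triton_flash_sa.py below: assigning _forward.__get__(block.attn, type(block.attn)) uses the descriptor protocol to turn a plain function into a method bound to that specific module, so the replacement forward receives the attention block as self. A minimal runnable sketch of the same trick, with ToyAttention and _fast_forward as hypothetical stand-ins for the real modules and kernels:

import torch

class ToyAttention(torch.nn.Module):
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        return x  # stands in for the original attention path

def _fast_forward(self, x: torch.Tensor) -> torch.Tensor:
    # `self` is the ToyAttention instance because __get__ binds it below
    return x * 2  # stands in for a flash-attention kernel

attn = ToyAttention()
attn.forward = _fast_forward.__get__(attn, type(attn))
print(attn.forward(torch.ones(2)))  # tensor([2., 2.]): the patched path ran

Binding per instance (rather than assigning to type(block.attn).forward) keeps the patch local to the blocks that were actually visited.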
self_hosting_machinery/finetune/modelling/triton_flash_sa.py (7 changes: 3 additions, 4 deletions)
@@ -1,15 +1,13 @@
 import functools
 import logging
 import math
+from typing import Optional
 
 import torch as th
 import triton
 import triton.language as tl
 
-from typing import Optional, Tuple
-
 from einops import einops
 
+from self_hosting_machinery.finetune.modelling.utils import get_base_model
 from self_hosting_machinery.finetune.utils import traces
-
-
@@ -597,5 +595,6 @@ def _forward(
         return
 
     traces.log("Applying triton flash attention to the model")
+    model = get_base_model(model)
     for block in model.transformer.h:
         block.attn.forward = _forward.__get__(block.attn, type(block.attn))
self_hosting_machinery/finetune/modelling/utils.py (10 changes: 10 additions, 0 deletions)
@@ -1,8 +1,18 @@
 from typing import List, Tuple
 
+import torch
+
 from self_hosting_machinery.finetune.configuration import supported_models
 
 
+def get_base_model(model: torch.nn.Module) -> torch.nn.Module:
+    if type(model).__name__ == "DeepSpeedEngine":
+        model = model.base_model
+    if type(model).__name__ in ("LoraModel", "PeftModelForCausalLM"):
+        model = model.model
+    return model
+
+
 def map_model_specific_params(
     model_name: str,
     freeze_exceptions: List[str],
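The helper dispatches on type(model).__name__ rather than isinstance, so utils.py needs no deepspeed or peft imports, and the two successive ifs let a DeepSpeed engine that wraps a PEFT model be unwrapped in one call. A hedged sketch of that behavior, with stand-in classes that only mimic the wrapper names (the real DeepSpeedEngine and PeftModelForCausalLM live in deepspeed and peft):

import torch

def get_base_model(model: torch.nn.Module) -> torch.nn.Module:
    # copied from the diff above
    if type(model).__name__ == "DeepSpeedEngine":
        model = model.base_model
    if type(model).__name__ in ("LoraModel", "PeftModelForCausalLM"):
        model = model.model
    return model

class _BareModel(torch.nn.Module):
    pass

class PeftModelForCausalLM(torch.nn.Module):  # stand-in, only the name matters
    def __init__(self):
        super().__init__()
        self.model = _BareModel()

class DeepSpeedEngine(torch.nn.Module):  # stand-in, only the name matters
    def __init__(self):
        super().__init__()
        self.base_model = PeftModelForCausalLM()

assert isinstance(get_base_model(DeepSpeedEngine()), _BareModel)
bare = _BareModel()
assert get_base_model(bare) is bare  # unwrapped models pass through untouched

The name check keeps the helper importable even when the wrapper libraries are absent; the trade-off is that a renamed wrapper class would silently stop being unwrapped.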
self_hosting_machinery/finetune/scripts/auxiliary/model.py (17 changes: 11 additions, 6 deletions)
@@ -2,6 +2,7 @@
 import os
 from collections import defaultdict
 from functools import partial
+from pathlib import Path
 from typing import Dict, Any, List, Tuple
 
 import deepspeed
@@ -14,7 +15,7 @@
 
 from self_hosting_machinery.finetune.configuration import supported_models
 from self_hosting_machinery.finetune.modelling.loss import masked_loss
-from self_hosting_machinery.finetune.modelling.utils import map_model_specific_params
+from self_hosting_machinery.finetune.modelling.utils import map_model_specific_params, get_base_model
 from self_hosting_machinery.finetune.utils import traces
 from self_hosting_machinery.finetune.utils.timer import Timer
 
@@ -73,8 +74,9 @@ def __init__(
         )
         self.use_deepspeed = True
 
+        traces.log(summary(get_base_model(self.model), depth=4,
+                           col_names=['num_params', 'params_percent', 'trainable'], verbose=0))
         traces.log("Allocated memory: %0.2fG" % (torch.cuda.max_memory_allocated() / 1e9))
-        traces.log(summary(self.model, depth=4, col_names=['num_params', 'params_percent', 'trainable'], verbose=0))
 
         self.loss_fn = partial(
             masked_loss,
@@ -195,18 +197,21 @@ def save_model_state(
             save_path: str,
             tag: str
     ):
-        output_path = os.path.join(save_path, tag)
+        output_path = Path(save_path) / tag
+        weights_path = output_path / "adapter_model.safetensors"
+        embeddings_path = output_path / "new_embeddings.safetensors"
         self.model.save_pretrained(output_path, safe_serialization=True)
-        weights = safetensors.torch.load_file(os.path.join(output_path, "adapter_model.safetensors"))
+        weights = safetensors.torch.load_file(weights_path)
         lora_weights, embeddings_weights = {}, {}
         for key in weights.keys():
             if "lora" in key:
                 lora_weights[key] = weights[key]
             else:
                 embeddings_weights[key] = weights[key]
-        safetensors.torch.save_file(lora_weights, os.path.join(output_path, "adapter_model.safetensors"))
-        if len(embeddings_weights) > 0:
-            safetensors.torch.save_file(embeddings_weights, os.path.join(output_path, "new_embeddings.safetensors"))
+        weights_path.unlink()
+        safetensors.torch.save_file(lora_weights, weights_path)
+        safetensors.torch.save_file(embeddings_weights, embeddings_path)
 
     def _freeze_model(
         self,
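The reworked save_model_state() loads the merged adapter file written by save_pretrained(), partitions tensors by key, and rewrites them as two files. A self-contained sketch of the same split under stated assumptions (the directory and tensor names here are hypothetical; in the real code the merged file comes from PEFT's save_pretrained):

from pathlib import Path

import safetensors.torch
import torch

output_path = Path("/tmp/checkpoint-demo")  # hypothetical checkpoint directory
output_path.mkdir(parents=True, exist_ok=True)
weights_path = output_path / "adapter_model.safetensors"
embeddings_path = output_path / "new_embeddings.safetensors"

# Pretend save_pretrained() already wrote a merged adapter file:
safetensors.torch.save_file(
    {"attn.lora_A.weight": torch.zeros(4, 4), "embed_tokens.weight": torch.zeros(4, 4)},
    weights_path,
)

weights = safetensors.torch.load_file(weights_path)
lora_weights = {k: v for k, v in weights.items() if "lora" in k}
embeddings_weights = {k: v for k, v in weights.items() if "lora" not in k}

weights_path.unlink()  # swap the merged file for a LoRA-only file
safetensors.torch.save_file(lora_weights, weights_path)
safetensors.torch.save_file(embeddings_weights, embeddings_path)

Writing new_embeddings.safetensors unconditionally (the old code skipped it when empty) presumably keeps the checkpoint layout uniform for downstream loaders.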
