From c0980fd08bfe257abf183cb79c2ade4e2c33ee53 Mon Sep 17 00:00:00 2001 From: shashikanth Date: Thu, 14 Nov 2024 14:03:17 +0530 Subject: [PATCH 1/4] Added Support for Apple Silicon - No gguf support yet. - Build Triton and bitsandbytes from source - `cmake -DCOMPUTE_BACKEND=hip -S .` for bitsandbytes building --- unsloth-cli.py | 110 +++++++---- unsloth/__init__.py | 21 ++- unsloth/devices.py | 49 +++++ unsloth/kernels/utils.py | 22 ++- unsloth/mac_specific.py | 71 +++++++ unsloth/models/_utils.py | 12 +- unsloth/models/mlx_lora.py | 212 +++++++++++++++++++++ unsloth/models/mlx_models.py | 355 +++++++++++++++++++++++++++++++++++ unsloth/models/mlx_utils.py | 210 +++++++++++++++++++++ unsloth/save.py | 9 +- unsloth/sd_hijack_utils.py | 28 +++ unsloth/tokenizer_utils.py | 3 +- 12 files changed, 1047 insertions(+), 55 deletions(-) create mode 100644 unsloth/devices.py create mode 100644 unsloth/mac_specific.py create mode 100644 unsloth/models/mlx_lora.py create mode 100644 unsloth/models/mlx_models.py create mode 100644 unsloth/models/mlx_utils.py create mode 100644 unsloth/sd_hijack_utils.py diff --git a/unsloth-cli.py b/unsloth-cli.py index ddb0ac8b..92c0b4c4 100644 --- a/unsloth-cli.py +++ b/unsloth-cli.py @@ -31,6 +31,8 @@ import argparse +from unsloth.devices import has_mps + def run(args): import torch from unsloth import FastLanguageModel @@ -40,29 +42,60 @@ def run(args): from unsloth import is_bfloat16_supported import logging logging.getLogger('hf-to-gguf').setLevel(logging.WARNING) + if has_mps: + import mlx.optimizers as optim + import mlx.core as mx + from unsloth.models import mlx_utils as lora_utils + from unsloth.models import mlx_lora + import numpy as np + from unsloth.models.mlx_models import LoRALinear + from mlx.utils import tree_flatten + from pathlib import Path + if not has_mps: # Load model and tokenizer - model, tokenizer = FastLanguageModel.from_pretrained( - model_name=args.model_name, - max_seq_length=args.max_seq_length, - dtype=args.dtype, - load_in_4bit=args.load_in_4bit, - ) + model, tokenizer = FastLanguageModel.from_pretrained( + model_name=args.model_name, + max_seq_length=args.max_seq_length, + dtype=args.dtype, + load_in_4bit=args.load_in_4bit, + ) + else: + np.random.seed(args.seed) + + # Building tokenizer_config + tokenizer_config = {} + + print("Loading pretrained model") + model, tokenizer, config = lora_utils.load(args.model_name, tokenizer_config) + # Freeze all layers other than LORA linears + model.freeze() + for l in model.model.layers[len(model.model.layers) - args.r :]: + l.self_attn.q_proj = LoRALinear.from_linear(l.self_attn.q_proj) + l.self_attn.v_proj = LoRALinear.from_linear(l.self_attn.v_proj) + if hasattr(l, "block_sparse_moe"): + l.block_sparse_moe.gate = LoRALinear.from_linear(l.block_sparse_moe.gate) + + p = sum(v.size for _, v in tree_flatten(model.parameters())) / 10**6 + print(f"Total parameters {p:.3f}M") + p = sum(v.size for _, v in tree_flatten(model.trainable_parameters())) / 10**6 + print(f"Trainable parameters {p:.3f}M") # Configure PEFT model - model = FastLanguageModel.get_peft_model( - model, - r=args.r, - target_modules=["q_proj", "k_proj", "v_proj", "o_proj", - "gate_proj", "up_proj", "down_proj"], - lora_alpha=args.lora_alpha, - lora_dropout=args.lora_dropout, - bias=args.bias, - use_gradient_checkpointing=args.use_gradient_checkpointing, - random_state=args.random_state, - use_rslora=args.use_rslora, - loftq_config=args.loftq_config, - ) + if not has_mps: + model = FastLanguageModel.get_peft_model( + model, + 
r=args.r, + target_modules=["q_proj", "k_proj", "v_proj", "o_proj", + "gate_proj", "up_proj", "down_proj"], + lora_alpha=args.lora_alpha, + lora_dropout=args.lora_dropout, + bias=args.bias, + use_gradient_checkpointing=args.use_gradient_checkpointing, + random_state=args.random_state, + use_rslora=args.use_rslora, + loftq_config=args.loftq_config, + ) alpaca_prompt = """Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request. @@ -110,19 +143,25 @@ def formatting_prompts_func(examples): ) # Initialize trainer - trainer = SFTTrainer( - model=model, - tokenizer=tokenizer, - train_dataset=dataset, - dataset_text_field="text", - max_seq_length=args.max_seq_length, - dataset_num_proc=2, - packing=False, - args=training_args, - ) + if not has_mps: + trainer = SFTTrainer( + model=model, + tokenizer=tokenizer, + train_dataset=dataset, + dataset_text_field="text", + max_seq_length=args.max_seq_length, + dataset_num_proc=2, + packing=False, + args=training_args, + ) # Train model - trainer_stats = trainer.train() + trainer_stats = trainer.train() + else: + datasets = dataset.train_test_split(test_size=0.1) + opt = optim.Adam(learning_rate=args.learning_rate) + mlx_lora.train(model, datasets["train"], datasets["test"], opt, mlx_lora.loss, tokenizer, args) + # Save model if args.save_model: @@ -152,9 +191,13 @@ def formatting_prompts_func(examples): quantization_method=quantization_method, ) else: - model.save_pretrained_merged(args.save_path, tokenizer, args.save_method) - if args.push_model: - model.push_to_hub_merged(args.save_path, tokenizer, args.hub_token) + if has_mps: + mx.savez(Path(args.save_path,args.adapter_file), **dict(tree_flatten(model.trainable_parameters()))) + model.save_merged_model(args) + else: + model.save_pretrained_merged(args.save_path, tokenizer, args.save_method) + if args.push_model: + model.push_to_hub_merged(args.save_path, tokenizer, args.hub_token) else: print("Warning: The model is not saved!") @@ -203,6 +246,7 @@ def formatting_prompts_func(examples): # Saving and pushing arguments save_group = parser.add_argument_group('💾 Save Model Options') + save_group.add_argument('--adapter_file', type=str, default="adapters.npz", help="Adapters file name") save_group.add_argument('--output_dir', type=str, default="outputs", help="Output directory") save_group.add_argument('--save_model', action='store_true', help="Save the model after training") save_group.add_argument('--save_method', type=str, default="merged_16bit", choices=["merged_16bit", "merged_4bit", "lora"], help="Save method for the model, default is 'merged_16bit'") diff --git a/unsloth/__init__.py b/unsloth/__init__.py index 745b2102..33eddf68 100644 --- a/unsloth/__init__.py +++ b/unsloth/__init__.py @@ -16,6 +16,7 @@ from packaging.version import Version import os, re, subprocess, inspect import numpy as np +from unsloth import devices # # Define a list of modules to check # MODULES_TO_CHECK = ["bitsandbytes"] @@ -90,7 +91,11 @@ pass # Torch 2.4 has including_emulation -major_version, minor_version = torch.cuda.get_device_capability() +devices.get_optimal_device() +if torch.cuda.is_available(): + major_version, minor_version = torch.cuda.get_device_capability() +else: + major_version,minor_version = 0,0 SUPPORTS_BFLOAT16 = (major_version >= 8) old_is_bf16_supported = torch.cuda.is_bf16_supported @@ -104,7 +109,6 @@ def is_bf16_supported(): return SUPPORTS_BFLOAT16 pass # Try loading bitsandbytes and triton -import 
bitsandbytes as bnb if "SPACE_AUTHOR_NAME" not in os.environ and "SPACE_REPO_NAME" not in os.environ: @@ -116,7 +120,9 @@ def is_bf16_supported(): return SUPPORTS_BFLOAT16 else: from triton.common.build import libcuda_dirs try: - cdequantize_blockwise_fp32 = bnb.functional.lib.cdequantize_blockwise_fp32 + if not devices.has_mps: + import bitsandbytes as bnb + cdequantize_blockwise_fp32 = bnb.functional.lib.cdequantize_blockwise_fp32 libcuda_dirs() except: warnings.warn( @@ -141,8 +147,9 @@ def is_bf16_supported(): return SUPPORTS_BFLOAT16 latest_cuda = possible_cudas[latest_cuda] os.system(f"ldconfig /usr/local/{latest_cuda}") pass - - importlib.reload(bnb) + if not devices.has_mps: + import bitsandbytes as bnb + importlib.reload(bnb) importlib.reload(triton) try: libcuda_dirs = lambda: None @@ -150,7 +157,9 @@ def is_bf16_supported(): return SUPPORTS_BFLOAT16 try: from triton.backends.nvidia.driver import libcuda_dirs except: pass else: from triton.common.build import libcuda_dirs - cdequantize_blockwise_fp32 = bnb.functional.lib.cdequantize_blockwise_fp32 + if not devices.has_mps: + import bitsandbytes as bnb + cdequantize_blockwise_fp32 = bnb.functional.lib.cdequantize_blockwise_fp32 libcuda_dirs() except: warnings.warn( diff --git a/unsloth/devices.py b/unsloth/devices.py new file mode 100644 index 00000000..9190c5a0 --- /dev/null +++ b/unsloth/devices.py @@ -0,0 +1,49 @@ +import sys + +import torch + +if sys.platform == "darwin": + from unsloth import mac_specific + + +def has_mps() -> bool: + if sys.platform != "darwin": + return False + else: + return mac_specific.has_mps + + +def get_cuda_device_string(): + return "cuda" + + +def get_optimal_device_name(): + if torch.cuda.is_available(): + return get_cuda_device_string() + + if has_mps(): + return "mps" + + return "cpu" + + +def get_optimal_device(): + return torch.device(get_optimal_device_name()) + + + +def torch_gc(): + + if torch.cuda.is_available(): + with torch.cuda.device(get_cuda_device_string()): + torch.cuda.empty_cache() + torch.cuda.ipc_collect() + + if has_mps(): + mac_specific.torch_mps_gc() + + + + + + diff --git a/unsloth/kernels/utils.py b/unsloth/kernels/utils.py index de543962..130fc90c 100644 --- a/unsloth/kernels/utils.py +++ b/unsloth/kernels/utils.py @@ -59,19 +59,23 @@ def calculate_settings(n : int) -> (int, int,): pass -import bitsandbytes as bnb +from unsloth import devices + # https://github.com/bitsandbytes-foundation/bitsandbytes/pull/1330/files -HAS_CUDA_STREAM = Version(bnb.__version__) > Version("0.43.3") +HAS_CUDA_STREAM = False global CUDA_STREAM CUDA_STREAM = None -get_ptr = bnb.functional.get_ptr -import ctypes -cdequantize_blockwise_fp32 = bnb.functional.lib.cdequantize_blockwise_fp32 -cdequantize_blockwise_fp16_nf4 = bnb.functional.lib.cdequantize_blockwise_fp16_nf4 -cdequantize_blockwise_bf16_nf4 = bnb.functional.lib.cdequantize_blockwise_bf16_nf4 -cgemm_4bit_inference_naive_fp16 = bnb.functional.lib.cgemm_4bit_inference_naive_fp16 -cgemm_4bit_inference_naive_bf16 = bnb.functional.lib.cgemm_4bit_inference_naive_bf16 +if not devices.has_mps: + import bitsandbytes as bnb + HAS_CUDA_STREAM = Version(bnb.__version__) > Version("0.43.3") + get_ptr = bnb.functional.get_ptr + cdequantize_blockwise_fp32 = bnb.functional.lib.cdequantize_blockwise_fp32 + cdequantize_blockwise_fp16_nf4 = bnb.functional.lib.cdequantize_blockwise_fp16_nf4 + cdequantize_blockwise_bf16_nf4 = bnb.functional.lib.cdequantize_blockwise_bf16_nf4 + cgemm_4bit_inference_naive_fp16 = 
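# Illustrative usage sketch (not part of the patch) of the new unsloth/devices.py
# helpers added above. Note that has_mps is a function and must be called, as
# unsloth/save.py does with `if not has_mps():`.
from unsloth import devices

device = devices.get_optimal_device()   # torch.device("cuda"), ("mps") or ("cpu")
if devices.has_mps():
    print("Apple Silicon detected: MLX / MPS code paths will be used")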
bnb.functional.lib.cgemm_4bit_inference_naive_fp16 + cgemm_4bit_inference_naive_bf16 = bnb.functional.lib.cgemm_4bit_inference_naive_bf16 +import ctypes def QUANT_STATE(W): return getattr(W, "quant_state", None) diff --git a/unsloth/mac_specific.py b/unsloth/mac_specific.py new file mode 100644 index 00000000..6ef92416 --- /dev/null +++ b/unsloth/mac_specific.py @@ -0,0 +1,71 @@ +import logging + +import torch +import platform +from unsloth.sd_hijack_utils import CondFunc +from packaging import version + +log = logging.getLogger(__name__) + + +# before torch version 1.13, has_mps is only available in nightly pytorch and macOS 12.3+, +# use check `getattr` and try it for compatibility. +# in torch version 1.13, backends.mps.is_available() and backends.mps.is_built() are introduced in to check mps availabilty, +# since torch 2.0.1+ nightly build, getattr(torch, 'has_mps', False) was deprecated, see https://github.com/pytorch/pytorch/pull/103279 +def check_for_mps() -> bool: + if version.parse(torch.__version__) <= version.parse("2.0.1"): + if not getattr(torch, 'has_mps', False): + return False + try: + torch.zeros(1).to(torch.device("mps")) + return True + except Exception: + return False + else: + return torch.backends.mps.is_available() and torch.backends.mps.is_built() + + +has_mps = check_for_mps() + + +# MPS workaround for https://github.com/pytorch/pytorch/issues/89784 +def cumsum_fix(input, cumsum_func, *args, **kwargs): + if input.device.type == 'mps': + output_dtype = kwargs.get('dtype', input.dtype) + if output_dtype == torch.int64: + return cumsum_func(input.cpu(), *args, **kwargs).to(input.device) + elif output_dtype == torch.bool or cumsum_needs_int_fix and (output_dtype == torch.int8 or output_dtype == torch.int16): + return cumsum_func(input.to(torch.int32), *args, **kwargs).to(torch.int64) + return cumsum_func(input, *args, **kwargs) + + +if has_mps: + if platform.mac_ver()[0].startswith("13.2."): + # MPS workaround for https://github.com/pytorch/pytorch/issues/95188, thanks to danieldk (https://github.com/explosion/curated-transformers/pull/124) + CondFunc('torch.nn.functional.linear', lambda _, input, weight, bias: (torch.matmul(input, weight.t()) + bias) if bias is not None else torch.matmul(input, weight.t()), lambda _, input, weight, bias: input.numel() > 10485760) + + if version.parse(torch.__version__) < version.parse("1.13"): + # PyTorch 1.13 doesn't need these fixes but unfortunately is slower and has regressions that prevent training from working + + # MPS workaround for https://github.com/pytorch/pytorch/issues/79383 + CondFunc('torch.Tensor.to', lambda orig_func, self, *args, **kwargs: orig_func(self.contiguous(), *args, **kwargs), + lambda _, self, *args, **kwargs: self.device.type != 'mps' and (args and isinstance(args[0], torch.device) and args[0].type == 'mps' or isinstance(kwargs.get('device'), torch.device) and kwargs['device'].type == 'mps')) + # MPS workaround for https://github.com/pytorch/pytorch/issues/80800 + CondFunc('torch.nn.functional.layer_norm', lambda orig_func, *args, **kwargs: orig_func(*([args[0].contiguous()] + list(args[1:])), **kwargs), + lambda _, *args, **kwargs: args and isinstance(args[0], torch.Tensor) and args[0].device.type == 'mps') + # MPS workaround for https://github.com/pytorch/pytorch/issues/90532 + CondFunc('torch.Tensor.numpy', lambda orig_func, self, *args, **kwargs: orig_func(self.detach(), *args, **kwargs), lambda _, self, *args, **kwargs: self.requires_grad) + elif version.parse(torch.__version__) > 
version.parse("1.13.1"): + cumsum_needs_int_fix = not torch.Tensor([1,2]).to(torch.device("mps")).equal(torch.ShortTensor([1,1]).to(torch.device("mps")).cumsum(0)) + cumsum_fix_func = lambda orig_func, input, *args, **kwargs: cumsum_fix(input, orig_func, *args, **kwargs) + CondFunc('torch.cumsum', cumsum_fix_func, None) + CondFunc('torch.Tensor.cumsum', cumsum_fix_func, None) + CondFunc('torch.narrow', lambda orig_func, *args, **kwargs: orig_func(*args, **kwargs).clone(), None) + + # MPS workaround for https://github.com/pytorch/pytorch/issues/96113 + CondFunc('torch.nn.functional.layer_norm', lambda orig_func, x, normalized_shape, weight, bias, eps, **kwargs: orig_func(x.float(), normalized_shape, weight.float() if weight is not None else None, bias.float() if bias is not None else bias, eps).to(x.dtype), lambda _, input, *args, **kwargs: len(args) == 4 and input.device.type == 'mps') + + # MPS workaround for https://github.com/pytorch/pytorch/issues/92311 + if platform.processor() == 'i386': + for funcName in ['torch.argmax', 'torch.Tensor.argmax']: + CondFunc(funcName, lambda _, input, *args, **kwargs: torch.max(input.float() if input.dtype == torch.int64 else input, *args, **kwargs)[1], lambda _, input, *args, **kwargs: input.device.type == 'mps') diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index daa81d97..95791352 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -227,11 +227,17 @@ def _is_openai_available(): return False # ============================================= # Get Flash Attention v2 if Ampere (RTX 30xx, A100) -import bitsandbytes as bnb +from unsloth import devices +if not devices.has_mps: + import bitsandbytes as bnb from transformers import AutoTokenizer from transformers.utils.import_utils import _is_package_available -major_version, minor_version = torch.cuda.get_device_capability() +devices.get_optimal_device() +if torch.cuda.is_available(): + major_version, minor_version = torch.cuda.get_device_capability() +else: + major_version,minor_version = 0,0 SUPPORTS_BFLOAT16 = False HAS_FLASH_ATTENTION = False HAS_FLASH_ATTENTION_SOFTCAPPING = False @@ -906,7 +912,7 @@ def check_nvidia(): output = re.findall(rb'([\d]{1,})[\s]{1,}M', output) output = np.array([int(x.decode('utf-8'))/1024 for x in output]) except: - if not torch.cuda.is_available(): + if not torch.cuda.is_available() and not devices.has_mps: raise RuntimeError("Unsloth: We do not support AMD / Intel machines yet - it is a work in progress!") return output pass diff --git a/unsloth/models/mlx_lora.py b/unsloth/models/mlx_lora.py new file mode 100644 index 00000000..6fb5d928 --- /dev/null +++ b/unsloth/models/mlx_lora.py @@ -0,0 +1,212 @@ +# Copyright © 2023-2024 Apple Inc. + +import argparse +import json +import math +import time +from pathlib import Path + +import mlx.core as mx +import mlx.nn as nn +import mlx.optimizers as optim +import numpy as np +from . 
import mlx_utils as lora_utils +from mlx.utils import tree_flatten + + +class Dataset: + """ + Light-weight wrapper to hold lines from a jsonl file + """ + + def __init__(self, path: Path, key: str = "text"): + if not path.exists(): + self._data = None + else: + with open(path, "r") as fid: + self._data = [json.loads(l) for l in fid] + self._key = key + + def __getitem__(self, idx: int): + return self._data[idx][self._key] + + def __len__(self): + return len(self._data) + + +def load(args): + def load_and_check(name): + dataset_path = Path(args.data) / f"{name}.jsonl" + try: + return Dataset(dataset_path) + except Exception as e: + print(f"Unable to build dataset {dataset_path} ({e})") + raise + + names = ("train", "valid", "test") + train, valid, test = (load_and_check(n) for n in names) + + if args.train and len(train) == 0: + raise ValueError( + "Training set not found or empty. Must provide training set for fine-tuning." + ) + if args.train and len(valid) == 0: + raise ValueError( + "Validation set not found or empty. Must provide validation set for fine-tuning." + ) + if args.test and len(test) == 0: + raise ValueError( + "Test set not found or empty. Must provide test set for evaluation." + ) + return train, valid, test + + +def loss(model, inputs, targets, lengths): + # Run model on inputs + logits, _ = model(inputs) + logits = logits.astype(mx.float32) + + # Mask padding tokens + length_mask = mx.arange(inputs.shape[1])[None, :] < lengths[:, None] + + # Calculate the loss + ce = nn.losses.cross_entropy(logits, targets) * length_mask + ntoks = length_mask.sum() + ce = ce.sum() / ntoks + return ce, ntoks + + +def iterate_batches(dset, tokenizer, batch_size, train=False): + # Shuffle indices + while True: + indices = np.arange(len(dset)) + if train: + indices = np.random.permutation(indices) + + # Collect batches from dataset + for i in range(0, len(indices) - batch_size + 1, batch_size): + # Encode batch + batch = [tokenizer.encode(str(dset[np.int32(indices[i + j]).item()])) for j in range(batch_size)] + lengths = [len(x) for x in batch] + + # Check if any sequence is longer than 2048 tokens + if max(lengths) > 2048: + print( + "[WARNING] Some sequences are longer than 2048 tokens. " + "Consider pre-splitting your data to save memory." 
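# Illustrative sketch (not part of the patch) of the padding mask built in loss()
# above. For a batch with sequence lengths [3, 5] padded to length 5, only real
# tokens contribute to the cross-entropy, and the sum is normalised by ntoks = mask.sum().
import mlx.core as mx
lengths = mx.array([3, 5])
length_mask = mx.arange(5)[None, :] < lengths[:, None]
# length_mask == [[True, True, True, False, False],
#                 [True, True, True, True,  True ]]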
+ ) + + # Pad to the max length + batch_arr = np.zeros((batch_size, max(lengths)), np.int32) + + for j in range(batch_size): + batch_arr[j, : lengths[j]] = batch[j] + batch = mx.array(batch_arr) + yield batch[:, :-1], batch[:, 1:], mx.array(lengths) + + if not train: + break + + +def evaluate(model, dataset, loss, tokenizer, batch_size, num_batches): + all_losses = [] + ntokens = 0 + + # num_batches can be -1 to indicate the entire set + index_iterator = iter(range(num_batches)) if num_batches != -1 else iter(int, 1) + + for it, batch in zip( + index_iterator, + iterate_batches(dataset, tokenizer, batch_size), + ): + losses, toks = loss(model, *batch) + all_losses.append((losses * toks).item()) + ntokens += toks.item() + + return np.sum(all_losses) / ntokens + + +def train(model, train_set, val_set, optimizer, loss, tokenizer, args): + # Create value and grad function for loss + loss_value_and_grad = nn.value_and_grad(model, loss) + + losses = [] + n_tokens = 0 + + # Main training loop + start = time.perf_counter() + for it, batch in zip( + range(args.max_steps), + iterate_batches(train_set, tokenizer, args.per_device_train_batch_size, train=True), + ): + # Forward and backward pass + (lvalue, toks), grad = loss_value_and_grad(model, *batch) + + # Model update + optimizer.update(model, grad) + mx.eval(model.parameters(), optimizer.state, lvalue) + + # Record loss + losses.append(lvalue.item()) + n_tokens += toks.item() + + # Report training loss if needed + if (it + 1) % 10 == 0: + train_loss = np.mean(losses) + + stop = time.perf_counter() + print( + f"Iter {it + 1}: Train loss {train_loss:.3f}, " + f"It/sec {10 / (stop - start):.3f}, " + f"Tokens/sec {float(n_tokens) / (stop - start):.3f}" + ) + losses = [] + n_tokens = 0 + start = time.perf_counter() + + # Report validation loss if needed + if it == 0 or (it + 1) % 200 == 0: + stop = time.perf_counter() + val_loss = evaluate( + model, val_set, loss, tokenizer, args.per_device_train_batch_size, 25 + ) + print( + f"Iter {it + 1}: " + f"Val loss {val_loss:.3f}, " + f"Val took {(time.perf_counter() - stop):.3f}s" + ) + + start = time.perf_counter() + + # Save adapter weights if needed + if (it + 1) % 100 == 0: + mx.savez( + args.adapter_file, **dict(tree_flatten(model.trainable_parameters())) + ) + print(f"Iter {it + 1}: Saved adapter weights to {args.adapter_file}.") + + +def generate(model, prompt, tokenizer, args): + print(prompt, end="", flush=True) + + prompt = mx.array(tokenizer.encode(prompt)) + + tokens = [] + skip = 0 + for token, n in zip( + lora_utils.generate(prompt, model, args.temp), + range(args.max_tokens), + ): + if token == tokenizer.eos_token_id: + break + + tokens.append(token.item()) + s = tokenizer.decode(tokens) + if len(s) - skip > 1: + print(s[skip:-1], end="", flush=True) + skip = len(s) - 1 + print(tokenizer.decode(tokens)[skip:], flush=True) + print("=" * 10) + if len(tokens) == 0: + print("No tokens generated for this prompt") + return diff --git a/unsloth/models/mlx_models.py b/unsloth/models/mlx_models.py new file mode 100644 index 00000000..122d7bf8 --- /dev/null +++ b/unsloth/models/mlx_models.py @@ -0,0 +1,355 @@ +# Copyright © 2023 Apple Inc. + +import inspect +import math +from dataclasses import dataclass +from typing import Dict, Optional, Tuple, Union +from pathlib import Path + +import mlx.core as mx +import mlx.nn as nn +from . 
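# Self-contained sketch (toy model and loss, not the LoRA model above) of the
# nn.value_and_grad / optimizer.update / mx.eval pattern used by train():
import mlx.core as mx
import mlx.nn as nn
import mlx.optimizers as optim

toy_model = nn.Linear(4, 1)

def toy_loss(model, x, y):
    return nn.losses.mse_loss(model(x), y)

opt = optim.Adam(learning_rate=1e-3)
loss_and_grad = nn.value_and_grad(toy_model, toy_loss)
x, y = mx.random.normal((8, 4)), mx.random.normal((8, 1))
lvalue, grads = loss_and_grad(toy_model, x, y)   # forward + backward
opt.update(toy_model, grads)                     # apply gradients in place
mx.eval(toy_model.parameters(), opt.state, lvalue)  # force lazy evaluation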
import mlx_utils as utils +from mlx.utils import tree_flatten, tree_unflatten + +@dataclass +class ModelArgs: + hidden_size: int + num_hidden_layers: int + intermediate_size: int + num_attention_heads: int + rms_norm_eps: float + vocab_size: int + num_key_value_heads: int = None + rope_theta: float = 10000 + rope_traditional: bool = False + model_type: str = None + rope_scaling: Optional[Dict[str, Union[float, str]]] = None + + def __post_init__(self): + if self.num_key_value_heads is None: + self.num_key_value_heads = self.num_attention_heads + + if self.rope_scaling: + required_keys = {"factor", "type"} + if not all(key in self.rope_scaling for key in required_keys): + raise ValueError(f"rope_scaling must contain keys {required_keys}") + + if self.rope_scaling["type"] != "linear": + raise ValueError("rope_scaling 'type' currently only supports 'linear'") + + @classmethod + def from_dict(cls, params): + return cls( + **{ + k: v + for k, v in params.items() + if k in inspect.signature(cls).parameters + } + ) + + +class LoRALinear(nn.Module): + @staticmethod + def from_linear(linear: nn.Linear, rank: int = 8): + # TODO remove when input_dims and output_dims are attributes + # on linear and quantized linear + output_dims, input_dims = linear.weight.shape + if isinstance(linear, nn.QuantizedLinear): + input_dims *= 32 // linear.bits + lora_lin = LoRALinear(input_dims, output_dims, rank) + lora_lin.linear = linear + return lora_lin + + def to_linear(self): + linear = self.linear + bias = "bias" in linear + weight = linear.weight + is_quantized = isinstance(linear, nn.QuantizedLinear) + + # Use the same type as the linear weight if not quantized + dtype = weight.dtype + + if is_quantized: + dtype = mx.float16 + weight = mx.dequantize( + weight, + linear.scales, + linear.biases, + linear.group_size, + linear.bits, + ) + output_dims, input_dims = weight.shape + fused_linear = nn.Linear(input_dims, output_dims, bias=bias) + + lora_b = (self.scale * self.lora_b.T).astype(dtype) + lora_a = self.lora_a.T.astype(dtype) + fused_linear.weight = weight + lora_b @ lora_a + if bias: + fused_linear.bias = linear.bias + + if is_quantized: + fused_linear = nn.QuantizedLinear.from_linear( + fused_linear, + linear.group_size, + linear.bits, + ) + + return fused_linear + + def __init__( + self, + input_dims: int, + output_dims: int, + lora_rank: int = 8, + bias: bool = False, + scale: float = 20.0, + ): + super().__init__() + + # Regular linear layer weights + self.linear = nn.Linear(input_dims, output_dims, bias=bias) + + # Scale for low-rank update + self.scale = scale + + # Low rank lora weights + scale = 1 / math.sqrt(input_dims) + self.lora_a = mx.random.uniform( + low=-scale, + high=scale, + shape=(input_dims, lora_rank), + ) + self.lora_b = mx.zeros(shape=(lora_rank, output_dims)) + + def __call__(self, x): + dtype = self.linear.weight.dtype + if isinstance(self.linear, nn.QuantizedLinear): + dtype = self.linear.scales.dtype + y = self.linear(x.astype(dtype)) + z = (x @ self.lora_a) @ self.lora_b + return y + self.scale * z + + +class Attention(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + + dim = args.hidden_size + self.n_heads = n_heads = args.num_attention_heads + self.n_kv_heads = n_kv_heads = args.num_key_value_heads + + self.repeats = n_heads // n_kv_heads + + head_dim = args.hidden_size // n_heads + self.scale = head_dim**-0.5 + + self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=False) + self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False) + self.v_proj = 
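# Illustrative sketch (not part of the patch) of the LoRALinear class defined above.
# from_linear wraps a frozen projection and adds a rank-r update, so the layer
# computes y = linear(x) + scale * (x @ lora_a) @ lora_b. Since lora_b starts at
# zero, the wrapped layer is numerically identical to the original at step 0.
import mlx.core as mx
import mlx.nn as nn

base = nn.Linear(256, 256, bias=False)
lora = LoRALinear.from_linear(base, rank=8)   # wraps the same weights, adds lora_a / lora_b
x = mx.random.normal((1, 256))
assert mx.allclose(lora(x), base(x)).item()   # identical until lora_b is trained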
nn.Linear(dim, n_kv_heads * head_dim, bias=False) + self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False) + rope_scale = ( + 1 / args.rope_scaling["factor"] + if args.rope_scaling is not None and args.rope_scaling["type"] == "linear" + else 1 + ) + self.rope = nn.RoPE( + head_dim, + traditional=args.rope_traditional, + base=args.rope_theta, + scale=rope_scale, + ) + + def __call__( + self, + x: mx.array, + mask: Optional[mx.array] = None, + cache: Optional[Tuple[mx.array, mx.array]] = None, + ) -> mx.array: + B, L, D = x.shape + + queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x) + + # Prepare the queries, keys and values for the attention computation + queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3) + keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) + values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) + + if cache is not None: + key_cache, value_cache = cache + queries = self.rope(queries, offset=key_cache.shape[2]) + keys = self.rope(keys, offset=key_cache.shape[2]) + keys = mx.concatenate([key_cache, keys], axis=2) + values = mx.concatenate([value_cache, values], axis=2) + else: + queries = self.rope(queries) + keys = self.rope(keys) + + output = mx.fast.scaled_dot_product_attention( + queries, keys, values, scale=self.scale, mask=mask + ) + output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) + return self.o_proj(output), (keys, values) + + +class MLP(nn.Module): + def __init__(self, dim, hidden_dim): + super().__init__() + self.gate_proj = nn.Linear(dim, hidden_dim, bias=False) + self.down_proj = nn.Linear(hidden_dim, dim, bias=False) + self.up_proj = nn.Linear(dim, hidden_dim, bias=False) + + def __call__(self, x) -> mx.array: + return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x)) + + +class TransformerBlock(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.num_attention_heads = args.num_attention_heads + self.hidden_size = args.hidden_size + self.self_attn = Attention(args) + self.mlp = MLP(args.hidden_size, args.intermediate_size) + self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps) + self.post_attention_layernorm = nn.RMSNorm( + args.hidden_size, eps=args.rms_norm_eps + ) + self.args = args + + def __call__( + self, + x: mx.array, + mask: Optional[mx.array] = None, + cache: Optional[Tuple[mx.array, mx.array]] = None, + ) -> mx.array: + r, cache = self.self_attn(self.input_layernorm(x), mask, cache) + h = x + r + r = self.mlp(self.post_attention_layernorm(h)) + out = h + r + return out, cache + + +class LlamaModel(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.args = args + self.vocab_size = args.vocab_size + self.num_hidden_layers = args.num_hidden_layers + assert self.vocab_size > 0 + self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size) + self.layers = [ + TransformerBlock(args=args) for _ in range(args.num_hidden_layers) + ] + self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps) + + def __call__( + self, + inputs: mx.array, + cache=None, + ): + h = self.embed_tokens(inputs) + + mask = None + if h.shape[1] > 1: + mask = nn.MultiHeadAttention.create_additive_causal_mask(h.shape[1]) + mask = mask.astype(h.dtype) + + if cache is None: + cache = [None] * len(self.layers) + + for e, layer in enumerate(self.layers): + h, cache[e] = layer(h, mask, cache[e]) + + return self.norm(h), cache + + + +class Model(nn.Module): + def __init__(self, args: ModelArgs): + 
super().__init__() + self.model = LlamaModel(args) + self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False) + + def __call__( + self, + inputs: mx.array, + cache=None, + ): + out, cache = self.model(inputs, cache) + return self.lm_head(out), cache + + def save_merged_model(self,args): + model, tokenizer, config = utils.load(args.model_name) + + # Load the LoRA adapter weights which we assume should exist by this point + if not Path(args.save_path,args.adapter_file).is_file(): + raise ValueError( + f"Adapter file {args.adapter_file} missing. ") + + # Load adapters and get number of LoRA layers + adapters = list(mx.load(args.save_path+"/"+args.adapter_file).items()) + lora_layers = len([m for m in adapters if "q_proj.lora_a" in m[0]]) + + # Freeze all layers other than LORA linears + model.freeze() + for l in model.model.layers[len(model.model.layers) - lora_layers :]: + l.self_attn.q_proj = LoRALinear.from_linear(l.self_attn.q_proj) + l.self_attn.v_proj = LoRALinear.from_linear(l.self_attn.v_proj) + if hasattr(l, "block_sparse_moe"): + l.block_sparse_moe.gate = LoRALinear.from_linear(l.block_sparse_moe.gate) + + model.update(tree_unflatten(adapters)) + fused_linears = [ + (n, m.to_linear()) + for n, m in model.named_modules() + if isinstance(m, LoRALinear) + ] + + model.update_modules(tree_unflatten(fused_linears)) + + if False: + de_quantize_layers = [] + for n, m in model.named_modules(): + if isinstance(m, nn.QuantizedLinear): + bias = "bias" in m + weight = m.weight + weight = mx.dequantize( + weight, + m.scales, + m.biases, + m.group_size, + m.bits, + ).astype(mx.float16) + output_dims, input_dims = weight.shape + linear = nn.Linear(input_dims, output_dims, bias=bias) + linear.weight = weight + if bias: + linear.bias = m.bias + de_quantize_layers.append((n, linear)) + + model.update_modules(tree_unflatten(de_quantize_layers)) + + weights = dict(tree_flatten(model.parameters())) + utils.save_model(args.save_path, weights, tokenizer, config) + + if args.push_model: + from huggingface_hub import whoami + try: + username = whoami(token = args.hub_token)["name"] + except: + raise RuntimeError( + "Unsloth: Please supply a token!\n"\ + "Go to https://huggingface.co/settings/tokens" + ) + pass + pass + + if args.push_model and args.hub_path is not None: + hf_path = args.hub_path + if not Path(args.model_name).exists(): + # If the model path doesn't exist, assume it's an HF repo + hf_path = args.model_name + elif hf_path is None: + raise ValueError( + "Must provide original Hugging Face repo to upload local model." + ) + utils.upload_to_hub(config['_name_or_path'],config['model_type'],username,args.save_path, args.hub_path,args.hub_token) diff --git a/unsloth/models/mlx_utils.py b/unsloth/models/mlx_utils.py new file mode 100644 index 00000000..ad916519 --- /dev/null +++ b/unsloth/models/mlx_utils.py @@ -0,0 +1,210 @@ +# Copyright © 2023-2024 Apple Inc. + +import glob +import json +import logging +from pathlib import Path +from typing import Generator + +import mlx.core as mx +import mlx.nn as nn +from . 
import mlx_models as models +import transformers +from huggingface_hub import snapshot_download,create_repo +from unsloth.save import MODEL_CARD + +def fetch_from_hub(hf_path: str): + model_path = snapshot_download( + repo_id=hf_path, + allow_patterns=["*.json", "*.safetensors", "tokenizer.model"], + ) + weight_files = glob.glob(f"{model_path}/*.safetensors") + if len(weight_files) == 0: + raise FileNotFoundError("No safetensors found in {}".format(model_path)) + + weights = {} + for wf in weight_files: + weights.update(mx.load(wf).items()) + + config = transformers.AutoConfig.from_pretrained(hf_path) + tokenizer = transformers.AutoTokenizer.from_pretrained( + hf_path, + ) + return weights, config.to_dict(), tokenizer + + +def upload_to_hub(name_or_path,model_type,username,path: str, name: str, token : str): + import os + + from huggingface_hub import HfApi, ModelCard, logging + + repo_id = f"{name}" + + try: + create_repo( + repo_id = repo_id, + token = token, + repo_type = "model", + exist_ok = False, + private = None, + ) + except: + pass + + try: + content = MODEL_CARD.format( + username = username, + base_model = name_or_path, + model_type = model_type, + method = "", + extra = "unsloth", + ) + card = ModelCard(content) + card.push_to_hub(repo_id, token = token) + except: + pass + logging.set_verbosity_info() + + api = HfApi() + api.create_repo(repo_id=repo_id, exist_ok=True,token = token) + api.upload_folder( + folder_path=path, + path_in_repo = ".", + token=token, + repo_id=repo_id, + commit_message = "(Trained with Unsloth)", + repo_type="model" + ) + + + +def make_shards(weights: dict, max_file_size_gibibyte: int = 15): + max_file_size_bytes = max_file_size_gibibyte << 30 + shards = [] + shard, shard_size = {}, 0 + for k, v in weights.items(): + if shard_size + v.nbytes > max_file_size_bytes: + shards.append(shard) + shard, shard_size = {}, 0 + shard[k] = v + shard_size += v.nbytes + shards.append(shard) + return shards + + +def save_model(save_dir: str, weights, tokenizer, config): + save_dir = Path(save_dir) + save_dir.mkdir(parents=True, exist_ok=True) + + shards = make_shards(weights, max_file_size_gibibyte=5) + shards_count = len(shards) + shard_file_format = ( + "model-{:05d}-of-{:05d}.safetensors" + if shards_count > 1 + else "model.safetensors" + ) + + total_size = sum(v.nbytes for v in weights.values()) + index_data = {"metadata": {"total_size": total_size}, "weight_map": {}} + + for i, shard in enumerate(shards): + shard_name = shard_file_format.format(i + 1, shards_count) + mx.save_safetensors( + str(save_dir / shard_name), shard, metadata={"format": "mlx"} + ) + for weight_name in shard.keys(): + index_data["weight_map"][weight_name] = shard_name + del shard + + tokenizer.save_pretrained(save_dir) + with open(save_dir / "config.json", "w") as fid: + json.dump(config, fid, indent=4) + + index_data["weight_map"] = { + k: index_data["weight_map"][k] for k in sorted(index_data["weight_map"]) + } + with open(save_dir / "model.safetensors.index.json", "w") as f: + json.dump( + index_data, + f, + indent=4, + ) + + +def load(path_or_hf_repo: str, tokenizer_config={}): + # If the path exists, it will try to load model form it + # otherwise download and cache from the hf_repo and cache + + model_path = Path(path_or_hf_repo) + if not model_path.exists(): + model_path = Path( + snapshot_download( + repo_id=path_or_hf_repo, + allow_patterns=["*.json", "*.safetensors", "tokenizer.model"], + ) + ) + + with open(model_path / "config.json", "r") as f: + config = json.loads(f.read()) + 
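# Illustrative sketch (not part of the patch) of make_shards() defined above.
# Weights are packed greedily into shards whose byte size stays below
# max_file_size_gibibyte << 30; save_model() then writes one safetensors file per
# shard plus a model.safetensors.index.json weight map. Tensor sizes here are tiny
# and purely illustrative.
import mlx.core as mx
toy_weights = {f"layers.{i}.weight": mx.zeros((1024, 1024)) for i in range(4)}  # ~4 MiB each
shards = make_shards(toy_weights, max_file_size_gibibyte=1)
assert len(shards) == 1   # tiny tensors all fit in a single shard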
quantization = config.get("quantization", None) + + weight_files = glob.glob(str(model_path / "*.safetensors")) + if len(weight_files) == 0: + raise FileNotFoundError("No safetensors found in {}".format(model_path)) + + weights = {} + for wf in weight_files: + weights.update(mx.load(wf).items()) + + model_args = models.ModelArgs.from_dict(config) + model = models.Model(model_args) + if quantization is not None: + class_predicate = ( + lambda p, m: isinstance(m, (nn.Linear, nn.Embedding)) + and f"{p}.scales" in weights + ) + nn.quantize( + model, + **quantization, + class_predicate=class_predicate, + ) + + model.load_weights(list(weights.items())) + + mx.eval(model.parameters()) + tokenizer = transformers.AutoTokenizer.from_pretrained( + model_path, **tokenizer_config + ) + return model, tokenizer, config + + +def generate( + prompt: mx.array, model: nn.Module, temp: float = 0.0 +) -> Generator[mx.array, None, None]: + """ + Generate text based on the given prompt and model. + + Args: + prompt (mx.array): The input prompt. + model (nn.Module): The model to use for generation. + temp (float): The temperature for sampling. If temp is 0, use max sampling. + + Yields: + mx.array: The generated text. + """ + + def sample(logits: mx.array) -> mx.array: + return ( + mx.argmax(logits, axis=-1) + if temp == 0 + else mx.random.categorical(logits * (1 / temp)) + ) + + y = prompt + cache = None + while True: + logits, cache = model(y[None], cache=cache) + logits = logits[:, -1, :] + y = sample(logits) + yield y + diff --git a/unsloth/save.py b/unsloth/save.py index b4c6b499..bb2c56eb 100644 --- a/unsloth/save.py +++ b/unsloth/save.py @@ -12,9 +12,12 @@ # See the License for the specific language governing permissions and # limitations under the License. -from bitsandbytes.nn import Linear4bit as Bnb_Linear4bit -from peft.tuners.lora import Linear4bit as Peft_Linear4bit -from peft.tuners.lora import Linear as Peft_Linear +from unsloth.devices import has_mps + +if not has_mps(): + from bitsandbytes.nn import Linear4bit as Bnb_Linear4bit + from peft.tuners.lora import Linear4bit as Peft_Linear4bit + from peft.tuners.lora import Linear as Peft_Linear from typing import Optional, Callable, Union, List import torch import os diff --git a/unsloth/sd_hijack_utils.py b/unsloth/sd_hijack_utils.py new file mode 100644 index 00000000..179ebc78 --- /dev/null +++ b/unsloth/sd_hijack_utils.py @@ -0,0 +1,28 @@ +import importlib + +class CondFunc: + def __new__(cls, orig_func, sub_func, cond_func): + self = super(CondFunc, cls).__new__(cls) + if isinstance(orig_func, str): + func_path = orig_func.split('.') + for i in range(len(func_path)-1, -1, -1): + try: + resolved_obj = importlib.import_module('.'.join(func_path[:i])) + break + except ImportError: + pass + for attr_name in func_path[i:-1]: + resolved_obj = getattr(resolved_obj, attr_name) + orig_func = getattr(resolved_obj, func_path[-1]) + setattr(resolved_obj, func_path[-1], lambda *args, **kwargs: self(*args, **kwargs)) + self.__init__(orig_func, sub_func, cond_func) + return lambda *args, **kwargs: self(*args, **kwargs) + def __init__(self, orig_func, sub_func, cond_func): + self.__orig_func = orig_func + self.__sub_func = sub_func + self.__cond_func = cond_func + def __call__(self, *args, **kwargs): + if not self.__cond_func or self.__cond_func(self.__orig_func, *args, **kwargs): + return self.__sub_func(self.__orig_func, *args, **kwargs) + else: + return self.__orig_func(*args, **kwargs) diff --git a/unsloth/tokenizer_utils.py b/unsloth/tokenizer_utils.py 
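# Illustrative sketch (not part of the patch) of CondFunc from unsloth/sd_hijack_utils.py
# above. CondFunc monkey-patches a dotted function path and routes calls through a
# substitute only when the condition holds; mac_specific.py uses it for the MPS
# workarounds. This mirrors the torch.narrow workaround registered above.
from unsloth.sd_hijack_utils import CondFunc

CondFunc(
    'torch.narrow',
    lambda orig_func, *args, **kwargs: orig_func(*args, **kwargs).clone(),  # substitute
    None,  # cond_func=None -> the substitute is always used
)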
index 302017d5..a0e3ac5a 100644 --- a/unsloth/tokenizer_utils.py +++ b/unsloth/tokenizer_utils.py @@ -24,6 +24,7 @@ import collections import numpy as np import gc +from unsloth import devices import subprocess from unsloth_zoo.tokenizer_utils import ( @@ -837,7 +838,7 @@ def check_nvidia(): output = re.findall(rb'([\d]{1,})[\s]{1,}M', output) output = np.array([int(x.decode('utf-8'))/1024 for x in output]) except: - if not torch.cuda.is_available(): + if not torch.cuda.is_available() and not devices.has_mps: raise RuntimeError("Unsloth: We do not support AMD / Intel machines yet - it is a work in progress!") return output pass From 066c2278215367de20d712000827e9ad2bc793be Mon Sep 17 00:00:00 2001 From: shashikanth Date: Mon, 25 Nov 2024 10:46:26 +0530 Subject: [PATCH 2/4] minor fixes and enhancements - lazy loading of model - minor refactoring - optimizers and lr schedulers - gc - should improve memory consumption --- unsloth-cli.py | 42 +--- unsloth/mlx/lora.py | 105 ++++++++ unsloth/{models => mlx}/mlx_utils.py | 174 +++++++++++-- unsloth/mlx/models/base.py | 111 +++++++++ unsloth/mlx/models/cache.py | 116 +++++++++ unsloth/mlx/models/llama.py | 304 +++++++++++++++++++++++ unsloth/mlx/trainer/lora.py | 283 +++++++++++++++++++++ unsloth/mlx/trainer/switch_layers.py | 165 +++++++++++++ unsloth/mlx/trainer/trainer.py | 328 +++++++++++++++++++++++++ unsloth/mlx/trainer/utils.py | 149 +++++++++++ unsloth/models/mlx_lora.py | 212 ---------------- unsloth/models/mlx_models.py | 355 --------------------------- 12 files changed, 1726 insertions(+), 618 deletions(-) create mode 100644 unsloth/mlx/lora.py rename unsloth/{models => mlx}/mlx_utils.py (51%) create mode 100644 unsloth/mlx/models/base.py create mode 100644 unsloth/mlx/models/cache.py create mode 100644 unsloth/mlx/models/llama.py create mode 100644 unsloth/mlx/trainer/lora.py create mode 100644 unsloth/mlx/trainer/switch_layers.py create mode 100644 unsloth/mlx/trainer/trainer.py create mode 100644 unsloth/mlx/trainer/utils.py delete mode 100644 unsloth/models/mlx_lora.py delete mode 100644 unsloth/models/mlx_models.py diff --git a/unsloth-cli.py b/unsloth-cli.py index 92c0b4c4..111a31e0 100644 --- a/unsloth-cli.py +++ b/unsloth-cli.py @@ -43,14 +43,9 @@ def run(args): import logging logging.getLogger('hf-to-gguf').setLevel(logging.WARNING) if has_mps: - import mlx.optimizers as optim - import mlx.core as mx - from unsloth.models import mlx_utils as lora_utils - from unsloth.models import mlx_lora - import numpy as np - from unsloth.models.mlx_models import LoRALinear - from mlx.utils import tree_flatten - from pathlib import Path + from unsloth.mlx import mlx_utils + from unsloth.mlx import lora as mlx_lora + import gc if not has_mps: # Load model and tokenizer @@ -61,26 +56,9 @@ def run(args): load_in_4bit=args.load_in_4bit, ) else: - np.random.seed(args.seed) - - # Building tokenizer_config - tokenizer_config = {} - print("Loading pretrained model") - model, tokenizer, config = lora_utils.load(args.model_name, tokenizer_config) - # Freeze all layers other than LORA linears - model.freeze() - for l in model.model.layers[len(model.model.layers) - args.r :]: - l.self_attn.q_proj = LoRALinear.from_linear(l.self_attn.q_proj) - l.self_attn.v_proj = LoRALinear.from_linear(l.self_attn.v_proj) - if hasattr(l, "block_sparse_moe"): - l.block_sparse_moe.gate = LoRALinear.from_linear(l.block_sparse_moe.gate) - - p = sum(v.size for _, v in tree_flatten(model.parameters())) / 10**6 - print(f"Total parameters {p:.3f}M") - p = sum(v.size for _, v in 
tree_flatten(model.trainable_parameters())) / 10**6 - print(f"Trainable parameters {p:.3f}M") - + model, tokenizer, config = mlx_utils.load_pretrained(args.model_name) + # Configure PEFT model if not has_mps: model = FastLanguageModel.get_peft_model( @@ -159,8 +137,7 @@ def formatting_prompts_func(examples): trainer_stats = trainer.train() else: datasets = dataset.train_test_split(test_size=0.1) - opt = optim.Adam(learning_rate=args.learning_rate) - mlx_lora.train(model, datasets["train"], datasets["test"], opt, mlx_lora.loss, tokenizer, args) + mlx_lora.train_model(args,model,tokenizer, datasets["train"], datasets["test"]) # Save model @@ -192,8 +169,11 @@ def formatting_prompts_func(examples): ) else: if has_mps: - mx.savez(Path(args.save_path,args.adapter_file), **dict(tree_flatten(model.trainable_parameters()))) - model.save_merged_model(args) + del model + gc.collect() + mlx_utils.save_merged_model(args) + if args.push_model: + mlx_utils.push_to_hub(args,config["_name_or_path"],config["model_type"]) else: model.save_pretrained_merged(args.save_path, tokenizer, args.save_method) if args.push_model: diff --git a/unsloth/mlx/lora.py b/unsloth/mlx/lora.py new file mode 100644 index 00000000..3e423052 --- /dev/null +++ b/unsloth/mlx/lora.py @@ -0,0 +1,105 @@ +from pathlib import Path +import math +import mlx.nn as nn +import mlx.optimizers as optim +import mlx.core as mx +from .trainer.trainer import TrainingArgs, TrainingCallback, train +from .trainer.utils import ( + build_schedule, + linear_to_lora_layers, + print_trainable_parameters, +) +from .mlx_utils import save_config + +def train_model( + args, + model: nn.Module, + tokenizer, + train_set, + valid_set, + training_callback: TrainingCallback = None, +): + model.freeze() + linear_to_lora_layers( + model, + min(args.r,len(model.layers)/2), + {"rank": args.r, "alpha": args.lora_alpha, "dropout": args.lora_dropout, "scale": float(args.lora_alpha)/math.sqrt(float(args.r)) if args.use_rslora else float(args.lora_alpha)/float(args.r)}, + ) + print_trainable_parameters(model) + + adapter_path = Path(args.save_path) + adapter_path.mkdir(parents=True, exist_ok=True) + + adapter_file = adapter_path / args.adapter_file + config = { + "num_layers" : min(args.r,len(model.layers)/2), + "lora_parameters" : {"rank": args.r, "alpha": args.lora_alpha, "dropout": args.lora_dropout, "scale": float(args.lora_alpha)/math.sqrt(float(args.r)) if args.use_rslora else float(args.lora_alpha)/float(args.r)} + } + save_config(config, adapter_path / "adapter_config.json") + + # init training args + training_args = TrainingArgs( + batch_size=args.per_device_train_batch_size, + iters=args.max_steps, + val_batches=25, + steps_per_report=10, + steps_per_eval=200, + steps_per_save=100, + adapter_file=adapter_file, + max_seq_length=args.max_seq_length, + grad_checkpoint=args.use_gradient_checkpointing, + ) + + mx.random.seed(args.seed) + model.train() + if args.lr_scheduler_type == "linear": + arguments = [0.0,args.learning_rate,args.warmup_steps] + elif args.lr_scheduler_type == "exponential_decay": + arguments = [args.learning_rate,args.weight_decay] + elif args.lr_scheduler_type == "step_decay": + arguments = [args.learning_rate,args.weight_decay,args.warmup_steps] + elif args.lr_scheduler_type == "cosine_decay": + arguments = [args.learning_rate,args.max_steps] + else: + arguments = [args.learning_rate] + + schedule_config = { + "name": "linear_schedule" if args.lr_scheduler_type == "linear" else args.lr_scheduler_type, + "warmup": args.warmup_steps, + 
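# Worked example (not part of the patch) of the LoRA scale computed above.
# With --r 16 --lora_alpha 16 the adapter scale is alpha / sqrt(r) = 4.0 when
# --use_rslora is set, and alpha / r = 1.0 otherwise.
import math
r, lora_alpha = 16, 16
scale_rslora  = float(lora_alpha) / math.sqrt(float(r))  # 4.0
scale_default = float(lora_alpha) / float(r)             # 1.0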
"arguments": arguments, + } + + lr = build_schedule(schedule_config) if args.lr_scheduler_type else args.learning_rate + + if args.optim.lower().startswith("sgd"): + opt = optim.SGD(learning_rate=(lr), weight_decay=args.weight_decay) + elif args.optim.lower().startswith("rmsprop"): + opt = optim.RMSprop(learning_rate=(lr)) + elif args.optim.lower().startswith("adagrad"): + opt = optim.Adagrad(learning_rate=(lr)) + elif args.optim.lower().startswith("adaDelta"): + opt = optim.AdaDelta(learning_rate=(lr)) + elif args.optim.lower().startswith("adamw"): + opt = optim.AdamW(learning_rate=(lr),weight_decay=args.weight_decay) + elif args.optim.lower().startswith("adam"): + opt = optim.Adam(learning_rate=(lr)) + elif args.optim.lower().startswith("adamax"): + opt = optim.Adamax(learning_rate=(lr)) + elif args.optim.lower().startswith("lion"): + opt = optim.Lion(learning_rate=(lr), weight_decay=args.weight_decay) + elif args.optim.lower().startswith("adafactor"): + opt = optim.Adafactor(learning_rate=(lr), weight_decay= args.weight_decay) + else: + raise ValueError("The Optimizer type provided is not supported") + + # Train model + train( + model=model, + tokenizer=tokenizer, + args=training_args, + optimizer=opt, + train_dataset=train_set, + val_dataset=valid_set, + training_callback=training_callback, + ) + diff --git a/unsloth/models/mlx_utils.py b/unsloth/mlx/mlx_utils.py similarity index 51% rename from unsloth/models/mlx_utils.py rename to unsloth/mlx/mlx_utils.py index ad916519..6d423d30 100644 --- a/unsloth/models/mlx_utils.py +++ b/unsloth/mlx/mlx_utils.py @@ -1,17 +1,24 @@ -# Copyright © 2023-2024 Apple Inc. - +import gc import glob +import shutil import json import logging from pathlib import Path -from typing import Generator +from typing import Generator, Optional,Type, Callable, Tuple, Union import mlx.core as mx import mlx.nn as nn -from . import mlx_models as models +from .models import llama as models import transformers from huggingface_hub import snapshot_download,create_repo from unsloth.save import MODEL_CARD +from mlx.utils import tree_flatten, tree_unflatten +from .trainer.utils import load_adapters + +MODEL_REMAPPING = { + "mistral": "llama", # mistral is compatible with llama +} + def fetch_from_hub(hf_path: str): model_path = snapshot_download( @@ -96,7 +103,7 @@ def save_model(save_dir: str, weights, tokenizer, config): save_dir = Path(save_dir) save_dir.mkdir(parents=True, exist_ok=True) - shards = make_shards(weights, max_file_size_gibibyte=5) + shards = make_shards(weights, max_file_size_gibibyte=1) shards_count = len(shards) shard_file_format = ( "model-{:05d}-of-{:05d}.safetensors" @@ -130,25 +137,27 @@ def save_model(save_dir: str, weights, tokenizer, config): indent=4, ) +def _get_classes(config: dict): + model_type = config["model_type"] + if model_type != "llama" and MODEL_REMAPPING.get(model_type,model_type) != "llama": + msg = f"Model type {model_type} not supported." 
+ logging.error(msg) + raise ValueError(msg) -def load(path_or_hf_repo: str, tokenizer_config={}): - # If the path exists, it will try to load model form it - # otherwise download and cache from the hf_repo and cache - - model_path = Path(path_or_hf_repo) - if not model_path.exists(): - model_path = Path( - snapshot_download( - repo_id=path_or_hf_repo, - allow_patterns=["*.json", "*.safetensors", "tokenizer.model"], - ) - ) + return models.Model, models.ModelArgs +def load(model_path: str, tokenizer_config={}, + get_model_classes: Callable[[dict], Tuple[Type[nn.Module], Type]] = _get_classes,): + with open(model_path / "config.json", "r") as f: config = json.loads(f.read()) quantization = config.get("quantization", None) weight_files = glob.glob(str(model_path / "*.safetensors")) + if not weight_files: + # Try weight for back-compat + weight_files = glob.glob(str(model_path / "weight*.safetensors")) + if len(weight_files) == 0: raise FileNotFoundError("No safetensors found in {}".format(model_path)) @@ -156,8 +165,14 @@ def load(path_or_hf_repo: str, tokenizer_config={}): for wf in weight_files: weights.update(mx.load(wf).items()) - model_args = models.ModelArgs.from_dict(config) - model = models.Model(model_args) + model_class, model_args_class = get_model_classes(config=config) + + model_args = model_args_class.from_dict(config) + model = model_class(model_args) + + if hasattr(model, "sanitize"): + weights = model.sanitize(weights) + if quantization is not None: class_predicate = ( lambda p, m: isinstance(m, (nn.Linear, nn.Embedding)) @@ -171,13 +186,39 @@ def load(path_or_hf_repo: str, tokenizer_config={}): model.load_weights(list(weights.items())) - mx.eval(model.parameters()) + # mx.eval(model.parameters()) + model.eval() + tokenizer = transformers.AutoTokenizer.from_pretrained( model_path, **tokenizer_config ) return model, tokenizer, config +def save_config( + config: dict, + config_path: Union[str, Path], +) -> None: + """Save the model configuration to the ``config_path``. + + The final configuration will be sorted before saving for better readability. + + Args: + config (dict): The model configuration. + config_path (Union[str, Path]): Model configuration file path. + """ + # Clean unused keys + config.pop("_name_or_path", None) + + # sort the config for better readability + config = dict(sorted(config.items())) + + # write the updated config to the config_path (if provided) + with open(config_path, "w") as fid: + json.dump(config, fid, indent=4) + + + def generate( prompt: mx.array, model: nn.Module, temp: float = 0.0 ) -> Generator[mx.array, None, None]: @@ -208,3 +249,96 @@ def sample(logits: mx.array) -> mx.array: y = sample(logits) yield y +def save_merged_model(args): + model_path = get_model_path(args.model_name) + model, tokenizer, config = load(model_path) + model.freeze() + + # Load the LoRA adapter weights which we assume should exist by this point + if not Path(args.save_path,args.adapter_file).is_file(): + raise ValueError( + f"Adapter file {args.adapter_file} missing. 
") + + model = load_adapters(model, args.save_path,args.adapter_file) + + fused_linears = [ + (n, m.fuse()) for n, m in model.named_modules() if hasattr(m, "fuse") + ] + + if fused_linears: + model.update_modules(tree_unflatten(fused_linears)) + + weights = dict(tree_flatten(model.parameters())) + + save_model(args.save_path, weights, tokenizer, config) + + mx.metal.clear_cache() + del model + gc.collect() + + +def push_to_hub(args,name, model_type): + if args.push_model: + from huggingface_hub import whoami + try: + username = whoami(token = args.hub_token)["name"] + except: + raise RuntimeError( + "Unsloth: Please supply a token!\n"\ + "Go to https://huggingface.co/settings/tokens" + ) + pass + pass + + if args.push_model and args.hub_path is not None: + hf_path = args.hub_path + if not Path(args.model_name).exists(): + # If the model path doesn't exist, assume it's an HF repo + hf_path = args.model_name + elif hf_path is None: + raise ValueError( + "Must provide original Hugging Face repo to upload local model." + ) + upload_to_hub(name,model_type,username,args.save_path, args.hub_path,args.hub_token) + + +def get_model_path(path_or_hf_repo: str, revision: Optional[str] = None) -> Path: + model_path = Path(path_or_hf_repo) + if not model_path.exists(): + try: + model_path = Path( + snapshot_download( + repo_id=path_or_hf_repo, + revision=revision, + allow_patterns=[ + "*.json", + "*.safetensors", + "*.py", + "tokenizer.model", + "*.tiktoken", + "*.txt", + ], + ) + ) + except: + raise FileNotFoundError( + f"Model not found for path or HF repo: {path_or_hf_repo}.\n" + "Please make sure you specified the local path or Hugging Face" + " repo id correctly.\nIf you are trying to access a private or" + " gated Hugging Face repo, make sure you are authenticated:\n" + "https://huggingface.co/docs/huggingface_hub/en/guides/cli#huggingface-cli-login" + ) from None + return model_path + + + +def load_pretrained( + path_or_hf_repo: str, + tokenizer_config={}, + model_config={}, +): + model_path = get_model_path(path_or_hf_repo) + + model,tokenizer, config = load(model_path, tokenizer_config) + + return model, tokenizer, config \ No newline at end of file diff --git a/unsloth/mlx/models/base.py b/unsloth/mlx/models/base.py new file mode 100644 index 00000000..72b7e23f --- /dev/null +++ b/unsloth/mlx/models/base.py @@ -0,0 +1,111 @@ +import inspect +from dataclasses import dataclass +from typing import Any, Optional + +import mlx.core as mx +from mlx.utils import tree_map + +from .cache import QuantizedKVCache + + +@dataclass +class BaseModelArgs: + @classmethod + def from_dict(cls, params): + return cls( + **{ + k: v + for k, v in params.items() + if k in inspect.signature(cls).parameters + } + ) + + +def create_causal_mask(N: int, offset: int = 0, window_size: Optional[int] = None): + rinds = mx.arange(offset + N) + linds = mx.arange(offset, offset + N) if offset else rinds + linds = linds[:, None] + rinds = rinds[None] + mask = linds < rinds + if window_size is not None: + mask = mask | (linds > rinds + window_size) + return mask * -1e9 + + +def create_attention_mask(h: mx.array, cache: Optional[Any] = None): + T = h.shape[1] + if T > 1: + window_size = None + offset = 0 + if cache is not None and cache[0] is not None: + c = cache[0] + if hasattr(c, "max_size"): + offset = min(c.max_size, c.offset) + window_size = c.max_size + else: + offset = c.offset + mask = create_causal_mask(T, offset, window_size=window_size) + mask = mask.astype(h.dtype) + else: + mask = None + return mask + + +def 
quantized_scaled_dot_product_attention( + queries: mx.array, + q_keys: tuple[mx.array, mx.array, mx.array], + q_values: tuple[mx.array, mx.array, mx.array], + scale: float, + mask: Optional[mx.array], + group_size: int = 64, + bits: int = 8, +) -> mx.array: + B, n_q_heads, L, D = queries.shape + n_kv_heads = q_keys[0].shape[-3] + n_repeats = n_q_heads // n_kv_heads + + queries *= scale + + if n_repeats > 1: + queries = mx.reshape(queries, (B, n_kv_heads, n_repeats, L, D)) + q_keys = tree_map(lambda x: mx.expand_dims(x, axis=-3), q_keys) + q_values = tree_map(lambda x: mx.expand_dims(x, axis=-3), q_values) + + scores = mx.quantized_matmul( + queries, *q_keys, transpose=True, group_size=group_size, bits=bits + ) + if mask is not None: + scores += mask + scores = mx.softmax(scores, axis=-1, precise=True) + out = mx.quantized_matmul( + scores, *q_values, transpose=False, group_size=group_size, bits=bits + ) + + if n_repeats > 1: + out = mx.reshape(out, (B, n_q_heads, L, D)) + + return out + + +def scaled_dot_product_attention( + queries, + keys, + values, + cache, + scale: float, + mask: Optional[mx.array], +) -> mx.array: + if isinstance(cache, QuantizedKVCache): + return quantized_scaled_dot_product_attention( + queries, + keys, + values, + scale=scale, + mask=mask, + group_size=cache.group_size, + bits=cache.bits, + ) + else: + return mx.fast.scaled_dot_product_attention( + queries, keys, values, scale=scale, mask=mask + ) diff --git a/unsloth/mlx/models/cache.py b/unsloth/mlx/models/cache.py new file mode 100644 index 00000000..e87b113e --- /dev/null +++ b/unsloth/mlx/models/cache.py @@ -0,0 +1,116 @@ + +from typing import Any, Dict, List, Optional + +import mlx.core as mx +import mlx.nn as nn +from mlx.utils import tree_map + + + +class _BaseCache: + @property + def state(self): + return [] + + @state.setter + def state(self, v): + if v is not None and v: + raise ValueError("This cache has no state but a state was set.") + + @property + def meta_state(self): + return "" + + @meta_state.setter + def meta_state(self, v): + if v is not None and v: + raise ValueError("This cache has no meta_state but a meta_state was set.") + + def is_trimmable(self): + return False + + +class QuantizedKVCache(_BaseCache): + def __init__(self, group_size: int = 64, bits: int = 8): + self.keys = None + self.values = None + self.offset = 0 + self.step = 256 + self.group_size = group_size + self.bits = bits + + def update_and_fetch(self, keys, values): + B, n_kv_heads, num_steps, k_head_dim = keys.shape + v_head_dim = values.shape[-1] + prev = self.offset + + if self.keys is None or (prev + num_steps) > self.keys[0].shape[-2]: + el_per_int = 8 * mx.uint32.size // self.bits + new_steps = (self.step + num_steps - 1) // self.step * self.step + shape = (B, n_kv_heads, new_steps) + + def init_quant(dim): + return ( + mx.zeros((*shape, dim // el_per_int), dtype=mx.uint32), + mx.zeros((*shape, dim // self.group_size), dtype=keys.dtype), + mx.zeros((*shape, dim // self.group_size), dtype=keys.dtype), + ) + + def expand_quant(x): + new_x = mx.zeros((*shape, x.shape[-1]), dtype=x.dtype) + return mx.concatenate([x, new_x], axis=-2) + + if self.keys is not None: + if prev % self.step != 0: + self.keys, self.values = tree_map( + lambda x: x[..., :prev, :], (self.keys, self.values) + ) + + self.keys, self.values = tree_map( + expand_quant, (self.keys, self.values) + ) + else: + self.keys, self.values = init_quant(k_head_dim), init_quant(v_head_dim) + + self.offset += num_steps + + keys = mx.quantize(keys, 
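# Illustrative sketch (not part of the patch) of the quantisation step used here.
# mx.quantize returns a (packed uint32, scales, biases) triple; QuantizedKVCache
# stores keys and values in that form, and mx.dequantize (or mx.quantized_matmul)
# recovers or consumes them. A single head's (T, head_dim) slice is used for brevity.
import mlx.core as mx
k = mx.random.normal((16, 64))                             # (T, head_dim)
qk = mx.quantize(k, group_size=64, bits=8)                 # (packed uint32, scales, biases)
k_restored = mx.dequantize(*qk, group_size=64, bits=8)     # approximate round trip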
group_size=self.group_size, bits=self.bits) + values = mx.quantize(values, group_size=self.group_size, bits=self.bits) + for i in range(len(self.keys)): + self.keys[i][..., prev : self.offset, :] = keys[i] + self.values[i][..., prev : self.offset, :] = values[i] + + return tree_map(lambda x: x[..., : self.offset, :], (self.keys, self.values)) + + @property + def state(self): + if self.offset == self.keys[0].shape[2]: + return self.keys, self.values + else: + return tree_map( + lambda x: x[..., : self.offset, :], (self.keys, self.values) + ) + + @state.setter + def state(self, v): + self.keys, self.values = v + + @property + def meta_state(self): + return tuple(map(str, (self.step, self.offset, self.group_size, self.bits))) + + @meta_state.setter + def meta_state(self, v): + self.step, self.offset, self.group_size, self.bits = map(int, v) + + def is_trimmable(self): + return True + + def trim(self, n): + n = min(self.offset, n) + self.offset -= n + return n + + + + diff --git a/unsloth/mlx/models/llama.py b/unsloth/mlx/models/llama.py new file mode 100644 index 00000000..7b85b686 --- /dev/null +++ b/unsloth/mlx/models/llama.py @@ -0,0 +1,304 @@ +from dataclasses import dataclass +from typing import Any, Dict, Optional, Union + +import mlx.core as mx +import mlx.nn as nn + +from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention + + +@dataclass +class ModelArgs(BaseModelArgs): + model_type: str + hidden_size: int + num_hidden_layers: int + intermediate_size: int + num_attention_heads: int + rms_norm_eps: float + vocab_size: int + head_dim: Optional[int] = None + max_position_embeddings: Optional[int] = None + num_key_value_heads: Optional[int] = None + attention_bias: bool = False + mlp_bias: bool = False + rope_theta: float = 10000 + rope_traditional: bool = False + rope_scaling: Optional[Dict[str, Union[float, str]]] = None + tie_word_embeddings: bool = True + + def __post_init__(self): + if self.num_key_value_heads is None: + self.num_key_value_heads = self.num_attention_heads + + if self.rope_scaling: + if not "factor" in self.rope_scaling: + raise ValueError(f"rope_scaling must contain 'factor'") + rope_type = self.rope_scaling.get("type") or self.rope_scaling.get( + "rope_type" + ) + if rope_type is None: + raise ValueError( + f"rope_scaling must contain either 'type' or 'rope_type'" + ) + if rope_type not in ["linear", "dynamic", "llama3"]: + raise ValueError( + "rope_scaling 'type' currently only supports 'linear', 'dynamic' or 'llama3'" + ) + + +class DynamicNTKScalingRoPE(nn.Module): + """Implements the rotary positional encoding with Dynamic NTK scaling and Llama 3 RoPE.""" + + def __init__( + self, + dims: int, + max_position_embeddings: int = 2048, + traditional: bool = False, + base: float = 10000, + scale: float = 1.0, + rope_type: str = "default", + rope_scaling: dict = None, + ): + super().__init__() + self.dims = dims + self.max_position_embeddings = max_position_embeddings + self.traditional = traditional + self.scale = scale + self.rope_type = rope_type + self.rope_scaling = rope_scaling + self.base = base + self.compute_freqs() + + def compute_freqs(self): + if self.rope_type != "llama3": + self._freqs = None + return + factor = self.rope_scaling["factor"] + low_freq_factor = self.rope_scaling.get("low_freq_factor", 1.0) + high_freq_factor = self.rope_scaling.get("high_freq_factor", 4.0) + old_context_len = self.rope_scaling.get( + "original_max_position_embeddings", + 8192, + ) + + low_freq_wavelen = old_context_len / low_freq_factor + 
high_freq_wavelen = old_context_len / high_freq_factor + + freqs = self.base ** (mx.arange(0, self.dims, 2) / self.dims) + wavelens = 2 * mx.pi * freqs + + freqs = mx.where(wavelens > low_freq_wavelen, freqs * factor, freqs) + is_medium_freq = (wavelens > high_freq_wavelen) & (wavelens < low_freq_wavelen) + smooth_factors = (old_context_len / wavelens - low_freq_factor) / ( + high_freq_factor - low_freq_factor + ) + smooth_freqs = freqs / ((1 - smooth_factors) / factor + smooth_factors) + self._freqs = mx.where(is_medium_freq, smooth_freqs, freqs) + self.base = None + + def extra_repr(self): + return ( + f"{self.dims}, traditional={self.traditional}, " + f"max_position_embeddings={self.max_position_embeddings}, " + f"scaling_factor={self.scale}, rope_type={self.rope_type}" + ) + + def __call__(self, x, offset: int = 0): + return mx.fast.rope( + x, + self.dims, + traditional=self.traditional, + base=self.base, + scale=self.scale, + offset=offset, + freqs=self._freqs, + ) + + +def initialize_rope(args: ModelArgs): + head_dim = args.head_dim or args.hidden_size // args.num_attention_heads + + rope_scaling = args.rope_scaling + rope_type = "default" + rope_scale = 1.0 + + if rope_scaling is not None: + rope_type = ( + rope_scaling.get("type") or rope_scaling.get("rope_type") or "default" + ) + if rope_type == "linear": + rope_scale = 1 / rope_scaling["factor"] + elif rope_type == "llama3": + rope_scale = 1.0 # The scaling is handled internally for llama3 + + return DynamicNTKScalingRoPE( + dims=head_dim, + max_position_embeddings=args.max_position_embeddings, + traditional=args.rope_traditional, + base=args.rope_theta, + scale=rope_scale, + rope_type=rope_type, + rope_scaling=rope_scaling, + ) + + +class Attention(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + + dim = args.hidden_size + self.n_heads = n_heads = args.num_attention_heads + self.n_kv_heads = n_kv_heads = args.num_key_value_heads + + self.head_dim = head_dim = args.head_dim or args.hidden_size // n_heads + + self.scale = head_dim**-0.5 + if hasattr(args, "attention_bias"): + attention_bias = args.attention_bias + else: + attention_bias = False + + self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=attention_bias) + self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=attention_bias) + self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=attention_bias) + self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=attention_bias) + + self.rope = initialize_rope(args) + + def __call__( + self, + x: mx.array, + mask: Optional[mx.array] = None, + cache: Optional[Any] = None, + ) -> mx.array: + B, L, D = x.shape + + queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x) + + # Prepare the queries, keys and values for the attention computation + queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3) + keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) + values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) + + if cache is not None: + queries = self.rope(queries, offset=cache.offset) + keys = self.rope(keys, offset=cache.offset) + keys, values = cache.update_and_fetch(keys, values) + else: + queries = self.rope(queries) + keys = self.rope(keys) + + output = scaled_dot_product_attention( + queries, keys, values, cache=cache, scale=self.scale, mask=mask + ) + + output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) + return self.o_proj(output) + + +class MLP(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + + dim = 
args.hidden_size + hidden_dim = args.intermediate_size + if hasattr(args, "mlp_bias"): + mlp_bias = args.mlp_bias + else: + mlp_bias = False + + self.gate_proj = nn.Linear(dim, hidden_dim, bias=mlp_bias) + self.down_proj = nn.Linear(hidden_dim, dim, bias=mlp_bias) + self.up_proj = nn.Linear(dim, hidden_dim, bias=mlp_bias) + + def __call__(self, x) -> mx.array: + return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x)) + + +class TransformerBlock(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.num_attention_heads = args.num_attention_heads + self.hidden_size = args.hidden_size + self.self_attn = Attention(args) + self.mlp = MLP(args) + self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps) + self.post_attention_layernorm = nn.RMSNorm( + args.hidden_size, eps=args.rms_norm_eps + ) + self.args = args + + def __call__( + self, + x: mx.array, + mask: Optional[mx.array] = None, + cache: Optional[Any] = None, + ) -> mx.array: + r = self.self_attn(self.input_layernorm(x), mask, cache) + h = x + r + r = self.mlp(self.post_attention_layernorm(h)) + out = h + r + return out + + +class LlamaModel(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.args = args + self.vocab_size = args.vocab_size + self.num_hidden_layers = args.num_hidden_layers + assert self.vocab_size > 0 + self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size) + self.layers = [ + TransformerBlock(args=args) for _ in range(args.num_hidden_layers) + ] + self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps) + + def __call__( + self, + inputs: mx.array, + cache=None, + ): + h = self.embed_tokens(inputs) + + mask = create_attention_mask(h, cache) + + if cache is None: + cache = [None] * len(self.layers) + + for layer, c in zip(self.layers, cache): + h = layer(h, mask, cache=c) + + return self.norm(h) + + +class Model(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.args = args + self.model_type = args.model_type + self.model = LlamaModel(args) + if not args.tie_word_embeddings: + self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False) + + def __call__( + self, + inputs: mx.array, + cache=None, + ): + out = self.model(inputs, cache) + if self.args.tie_word_embeddings: + out = self.model.embed_tokens.as_linear(out) + else: + out = self.lm_head(out) + return out + + def sanitize(self, weights): + # Remove unused precomputed rotary freqs + return { + k: v for k, v in weights.items() if "self_attn.rotary_emb.inv_freq" not in k + } + + @property + def layers(self): + return self.model.layers diff --git a/unsloth/mlx/trainer/lora.py b/unsloth/mlx/trainer/lora.py new file mode 100644 index 00000000..3072ccfd --- /dev/null +++ b/unsloth/mlx/trainer/lora.py @@ -0,0 +1,283 @@ +import math + +import mlx.core as mx +import mlx.nn as nn + +from .switch_layers import QuantizedSwitchLinear, SwitchLinear + + +class LoRALinear(nn.Module): + @staticmethod + def from_base( + linear: nn.Linear, + r: int = 8, + dropout: float = 0.0, + scale: float = 20.0, + ): + # TODO remove when input_dims and output_dims are attributes + # on linear and quantized linear + output_dims, input_dims = linear.weight.shape + if isinstance(linear, nn.QuantizedLinear): + input_dims *= 32 // linear.bits + lora_lin = LoRALinear( + input_dims=input_dims, + output_dims=output_dims, + r=r, + dropout=dropout, + scale=scale, + ) + lora_lin.linear = linear + return lora_lin + + def fuse(self, de_quantize: bool = False): + linear = 
self.linear + bias = "bias" in linear + weight = linear.weight + is_quantized = isinstance(linear, nn.QuantizedLinear) + + # Use the same type as the linear weight if not quantized + dtype = weight.dtype + + if is_quantized: + dtype = linear.scales.dtype + weight = mx.dequantize( + weight, + linear.scales, + linear.biases, + linear.group_size, + linear.bits, + ) + output_dims, input_dims = weight.shape + fused_linear = nn.Linear(input_dims, output_dims, bias=bias) + + lora_b = (self.scale * self.lora_b.T).astype(dtype) + lora_a = self.lora_a.T.astype(dtype) + fused_linear.weight = weight + lora_b @ lora_a + if bias: + fused_linear.bias = linear.bias + + if is_quantized and not de_quantize: + fused_linear = nn.QuantizedLinear.from_linear( + fused_linear, + linear.group_size, + linear.bits, + ) + + return fused_linear + + def __init__( + self, + input_dims: int, + output_dims: int, + r: int = 8, + dropout: float = 0.0, + scale: float = 20.0, + bias: bool = False, + ): + super().__init__() + + # Regular linear layer weights + self.linear = nn.Linear(input_dims, output_dims, bias=bias) + + self.dropout = nn.Dropout(p=dropout) + + # Scale for low-rank update + self.scale = scale + + # Low rank lora weights + scale = 1 / math.sqrt(input_dims) + self.lora_a = mx.random.uniform( + low=-scale, + high=scale, + shape=(input_dims, r), + ) + self.lora_b = mx.zeros(shape=(r, output_dims)) + + def __call__(self, x): + y = self.linear(x) + z = (self.dropout(x) @ self.lora_a) @ self.lora_b + return y + (self.scale * z).astype(x.dtype) + + +class LoRASwitchLinear(nn.Module): + @staticmethod + def from_base( + linear: nn.Module, + r: int = 8, + dropout: float = 0.0, + scale: float = 20.0, + ): + lora_lin = LoRASwitchLinear( + input_dims=linear.input_dims, + output_dims=linear.output_dims, + num_experts=linear.num_experts, + r=r, + dropout=dropout, + scale=scale, + ) + lora_lin.linear = linear + return lora_lin + + def fuse(self, de_quantize: bool = False): + linear = self.linear + bias = "bias" in linear + weight = linear.weight + is_quantized = isinstance(linear, QuantizedSwitchLinear) + + # Use the same type as the linear weight if not quantized + dtype = weight.dtype + + if is_quantized: + dtype = mx.float16 + weight = mx.dequantize( + weight, + linear.scales, + linear.biases, + linear.group_size, + linear.bits, + ) + num_experts, output_dims, input_dims = weight.shape + fused_linear = SwitchLinear(input_dims, output_dims, num_experts, bias=bias) + + lora_b = (self.scale * self.lora_b).astype(dtype) + lora_a = self.lora_a.reshape(num_experts, -1, input_dims).astype(dtype) + fused_linear.weight = weight + lora_b @ lora_a + if bias: + fused_linear.bias = linear.bias + + if is_quantized and not de_quantize: + fused_linear = fused_linear.to_quantized(linear.group_size, linear.bits) + + return fused_linear + + def __init__( + self, + input_dims: int, + output_dims: int, + num_experts: int, + r: int = 8, + dropout: float = 0.0, + scale: float = 20.0, + bias: bool = False, + ): + super().__init__() + + # Regular linear layer weights + self.linear = SwitchLinear(input_dims, output_dims, num_experts, bias=bias) + + self.dropout = nn.Dropout(p=dropout) + + # Scale for low-rank update + self.scale = scale + + # Low rank lora weights + scale = 1 / math.sqrt(input_dims) + self.lora_a = mx.random.uniform( + low=-scale, + high=scale, + shape=(r * num_experts, input_dims), + ) + self.lora_b = mx.zeros(shape=(num_experts, output_dims, r)) + self.num_experts = num_experts + + def __call__(self, x, indices): + shape = 
x.shape[:-3] + (self.num_experts, -1) + + y = self.linear(x, indices) + z = (self.dropout(x) @ self.lora_a.T).reshape(shape) + z = mx.take_along_axis(z, indices[..., None], axis=-2) + z = z[..., None, :] @ self.lora_b[indices].swapaxes(-2, -1) + + return y + (self.scale * z).astype(x.dtype) + + +class LoRAEmbedding(nn.Module): + @staticmethod + def from_base( + embedding: nn.Embedding, + r: int = 8, + dropout: float = 0.0, + scale: float = 20.0, + ): + num_embeddings, dims = embedding.weight.shape + if isinstance(embedding, nn.QuantizedEmbedding): + dims *= 32 // embedding.bits + lora_embedding = LoRAEmbedding( + num_embeddings=num_embeddings, + dims=dims, + r=r, + dropout=dropout, + scale=scale, + ) + lora_embedding.embedding = embedding + return lora_embedding + + def fuse(self, de_quantize: bool = False): + embedding = self.embedding + weight = embedding.weight + is_quantized = isinstance(embedding, nn.QuantizedEmbedding) + + # Use the same type as the linear weight if not quantized + dtype = weight.dtype + + if is_quantized: + dtype = embedding.scales.dtype + weight = mx.dequantize( + weight, + embedding.scales, + embedding.biases, + embedding.group_size, + embedding.bits, + ) + num_embeddings, dims = weight.shape + fused_embedding = nn.Embedding(num_embeddings, dims) + + lora_a = (self.scale * self.lora_a).astype(dtype) + lora_b = self.lora_b.astype(dtype) + fused_embedding.weight = weight + lora_a @ lora_b + + if is_quantized and not de_quantize: + fused_embedding = nn.QuantizedEmbedding.from_embedding( + fused_embedding, + embedding.group_size, + embedding.bits, + ) + + return fused_embedding + + def __init__( + self, + num_embeddings: int, + dims: int, + r: int = 8, + dropout: float = 0.0, + scale: float = 20.0, + ): + super().__init__() + + # Regular embedding layer + self.embedding = nn.Embedding(num_embeddings, dims) + self.dropout = nn.Dropout(p=dropout) + + # Scale for low-rank update + self.scale = scale + + # Low rank lora weights + scale = 1 / math.sqrt(num_embeddings) + self.lora_a = mx.random.uniform( + low=-scale, + high=scale, + shape=(num_embeddings, r), + ) + self.lora_b = mx.zeros(shape=(r, dims)) + + def __call__(self, x): + y = self.embedding(x) + z = self.dropout(self.lora_a[x] @ self.lora_b) + out = y + (self.scale * z).astype(y.dtype) + return out + + def as_linear(self, x): + y = self.embedding.as_linear(x) + z = (self.dropout(x) @ self.lora_b.T) @ self.lora_a.T + return y + (self.scale * z).astype(x.dtype) diff --git a/unsloth/mlx/trainer/switch_layers.py b/unsloth/mlx/trainer/switch_layers.py new file mode 100644 index 00000000..00aa65d8 --- /dev/null +++ b/unsloth/mlx/trainer/switch_layers.py @@ -0,0 +1,165 @@ +import math + +import mlx.core as mx +import mlx.nn as nn + + +class QuantizedSwitchLinear(nn.Module): + def __init__( + self, + input_dims: int, + output_dims: int, + num_experts: int, + bias: bool = True, + group_size: int = 64, + bits: int = 4, + ): + super().__init__() + + scale = math.sqrt(1 / input_dims) + self.weight, self.scales, self.biases = mx.quantize( + mx.random.uniform( + low=-scale, + high=scale, + shape=(num_experts, output_dims, input_dims), + ), + group_size=group_size, + bits=bits, + ) + + if bias: + self.bias = mx.zeros((num_experts, output_dims)) + + self.group_size = group_size + self.bits = bits + + # Freeze this model's parameters + self.freeze() + + def unfreeze(self, *args, **kwargs): + """Wrap unfreeze so that we unfreeze any layers we might contain but + our parameters will remain frozen.""" + super().unfreeze(*args, 
**kwargs) + self.freeze(recurse=False) + + @property + def input_dims(self): + return self.scales.shape[2] * self.group_size + + @property + def output_dims(self): + return self.weight.shape[1] + + @property + def num_experts(self): + return self.weight.shape[0] + + def __call__(self, x, indices): + x = mx.gather_qmm( + x, + self["weight"], + self["scales"], + self["biases"], + rhs_indices=indices, + transpose=True, + group_size=self.group_size, + bits=self.bits, + ) + if "bias" in self: + x = x + mx.expand_dims(self["bias"][indices], -2) + return x + + +class SwitchLinear(nn.Module): + def __init__( + self, input_dims: int, output_dims: int, num_experts: int, bias: bool = True + ): + super().__init__() + scale = math.sqrt(1 / input_dims) + self.weight = mx.random.uniform( + low=-scale, + high=scale, + shape=(num_experts, output_dims, input_dims), + ) + + if bias: + self.bias = mx.zeros((num_experts, output_dims)) + + @property + def input_dims(self): + return self.weight.shape[2] + + @property + def output_dims(self): + return self.weight.shape[1] + + @property + def num_experts(self): + return self.weight.shape[0] + + def __call__(self, x, indices): + x = mx.gather_mm(x, self["weight"].swapaxes(-1, -2), rhs_indices=indices) + if "bias" in self: + x = x + mx.expand_dims(self["bias"][indices], -2) + return x + + def to_quantized(self, group_size: int = 64, bits: int = 4): + num_experts, output_dims, input_dims = self.weight.shape + ql = QuantizedSwitchLinear( + input_dims, output_dims, num_experts, False, group_size, bits + ) + ql.weight, ql.scales, ql.biases = mx.quantize(self.weight, group_size, bits) + if "bias" in self: + ql.bias = self.bias + return ql + + +class SwitchGLU(nn.Module): + def __init__( + self, + input_dims: int, + hidden_dims: int, + num_experts: int, + activation=nn.silu, + bias: bool = False, + ): + super().__init__() + + self.gate_proj = SwitchLinear(input_dims, hidden_dims, num_experts, bias=bias) + self.up_proj = SwitchLinear(input_dims, hidden_dims, num_experts, bias=bias) + self.down_proj = SwitchLinear(hidden_dims, input_dims, num_experts, bias=bias) + self.activation = activation + + def __call__(self, x, indices) -> mx.array: + x = mx.expand_dims(x, (-2, -3)) + + x_up = self.up_proj(x, indices) + x_gate = self.gate_proj(x, indices) + x = self.down_proj(self.activation(x_gate) * x_up, indices) + + return x.squeeze(-2) + + +class SwitchMLP(nn.Module): + def __init__( + self, + input_dims: int, + hidden_dims: int, + num_experts: int, + activation=nn.gelu_approx, + bias: bool = False, + ): + super().__init__() + + self.fc1 = SwitchLinear(input_dims, hidden_dims, num_experts, bias=bias) + self.fc2 = SwitchLinear(hidden_dims, input_dims, num_experts, bias=bias) + self.activation = activation + + def __call__(self, x, indices) -> mx.array: + x = mx.expand_dims(x, (-2, -3)) + + x = self.fc1(x, indices) + x = self.activation(x) + x = self.fc2(x, indices) + + return x.squeeze(-2) diff --git a/unsloth/mlx/trainer/trainer.py b/unsloth/mlx/trainer/trainer.py new file mode 100644 index 00000000..ebbd9ce7 --- /dev/null +++ b/unsloth/mlx/trainer/trainer.py @@ -0,0 +1,328 @@ +import time +from dataclasses import dataclass, field +from pathlib import Path + +import mlx.core as mx +import mlx.nn as nn +import numpy as np +from mlx.nn.utils import average_gradients +from mlx.utils import tree_flatten + + +def grad_checkpoint(layer): + """ + Update all instances of type(layer) to use gradient checkpointing. 
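+
+    The original ``type(layer).__call__`` is wrapped so that the forward pass
+    runs under ``mx.checkpoint``: activations of every instance of that layer
+    class are recomputed during the backward pass instead of being stored,
+    trading extra compute for lower peak memory.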
+ """ + fn = type(layer).__call__ + + def checkpointed_fn(model, *args, **kwargs): + def inner_fn(params, *args, **kwargs): + model.update(params) + return fn(model, *args, **kwargs) + + return mx.checkpoint(inner_fn)(model.trainable_parameters(), *args, **kwargs) + + type(layer).__call__ = checkpointed_fn + + +@dataclass +class TrainingArgs: + batch_size: int = field(default=4, metadata={"help": "Minibatch size."}) + iters: int = field(default=100, metadata={"help": "Iterations to train for."}) + val_batches: int = field( + default=25, + metadata={ + "help": "Number of validation batches, -1 uses the entire validation set." + }, + ) + steps_per_report: int = field( + default=10, + metadata={"help": "Number of training steps between loss reporting."}, + ) + steps_per_eval: int = field( + default=200, metadata={"help": "Number of training steps between validations."} + ) + steps_per_save: int = field( + default=100, metadata={"help": "Save the model every number steps"} + ) + max_seq_length: int = field( + default=2048, metadata={"help": "Maximum sequence length."} + ) + adapter_file: str = field( + default="adapters.safetensors", + metadata={"help": "Save/load path for the trained adapter weights."}, + ) + grad_checkpoint: bool = field( + default=False, + metadata={"help": "Use gradient checkpointing to reduce memory use."}, + ) + + +def default_loss(model, inputs, targets, lengths): + logits = model(inputs) + logits = logits.astype(mx.float32) + + length_mask = mx.arange(inputs.shape[1])[None, :] < lengths[:, None] + + ce = nn.losses.cross_entropy(logits, targets) * length_mask + ntoks = length_mask.sum() + ce = ce.sum() / ntoks + + return ce, ntoks + + +def iterate_batches(dataset, tokenizer, batch_size, max_seq_length, train=False): + # Sort by length: + idx = sorted(range(len(dataset)), key=lambda idx: len(dataset[idx])) + if len(dataset) < batch_size: + raise ValueError( + f"Dataset must have at least batch_size={batch_size}" + f" examples but only has {len(dataset)}." + ) + + # If running in distributed mode (N machines) then each one should skip N-1 + # samples + step = mx.distributed.init().size() + if batch_size % step != 0: + raise ValueError("The batch size must be divisible by the number of workers") + + # Make the batches: + batch_idx = [ + idx[i : i + batch_size : step] + for i in range(0, len(idx) - batch_size + 1, batch_size) + ] + + while True: + indices = np.random.permutation(len(batch_idx)) + for i in indices: + # Encode batch + batch = [tokenizer.encode(str(dataset[np.int32(indices[j]).item()])) for j in batch_idx[i]] + for b in batch: + if b[-1] != tokenizer.eos_token_id: + b.append(tokenizer.eos_token_id) + + lengths = [len(x) for x in batch] + + if max(lengths) > max_seq_length: + print( + f"[WARNING] Some sequences are longer than {max_seq_length} tokens. " + f"The longest sentence {max(lengths)} will be truncated to {max_seq_length}. " + "Consider pre-splitting your data to save memory." 
+ ) + + # Pad to the nearest multiple of 8 or the maximum length + pad_to = 8 + max_length_in_batch = pad_to * ((max(lengths) + pad_to - 1) // pad_to) + max_length_in_batch = min(max_length_in_batch, max_seq_length) + + batch_arr = np.zeros((batch_size // step, max_length_in_batch), np.int32) + + for j in range(batch_size // step): + truncated_length = min(lengths[j], max_seq_length) + batch_arr[j, :truncated_length] = batch[j][:truncated_length] + lengths[j] = ( + truncated_length # Update lengths to match truncated lengths + ) + batch = mx.array(batch_arr) + + yield batch[:, :-1], batch[:, 1:], mx.array(lengths) + + if not train: + break + + +def evaluate( + model, + dataset, + tokenizer, + batch_size, + num_batches, + max_seq_length=2048, + loss: callable = default_loss, + iterate_batches: callable = iterate_batches, +): + all_losses = 0 + ntokens = 0 + + index_iterator = iter(range(num_batches)) if num_batches != -1 else iter(int, 1) + + for _, batch in zip( + index_iterator, + iterate_batches( + dataset=dataset, + tokenizer=tokenizer, + batch_size=batch_size, + max_seq_length=max_seq_length, + ), + ): + losses, toks = loss(model, *batch) + all_losses += losses * toks + ntokens += toks + mx.eval(all_losses, ntokens) + + all_losses = mx.distributed.all_sum(all_losses) + ntokens = mx.distributed.all_sum(ntokens) + + return (all_losses / ntokens).item() + + +class TrainingCallback: + + def on_train_loss_report(self, train_info: dict): + """Called to report training loss at specified intervals.""" + pass + + def on_val_loss_report(self, val_info: dict): + """Called to report validation loss at specified intervals or the beginning.""" + pass + + +def train( + model, + tokenizer, + optimizer, + train_dataset, + val_dataset, + args: TrainingArgs = TrainingArgs(), + loss: callable = default_loss, + iterate_batches: callable = iterate_batches, + training_callback: TrainingCallback = None, +): + print(f"Starting training..., iters: {args.iters}") + world = mx.distributed.init() + world_size = world.size() + rank = world.rank() + if world_size > 1: + print(f"Node {rank} of {world_size}") + + if args.grad_checkpoint: + grad_checkpoint(model.layers[0]) + + state = [model.state, optimizer.state] + + def step(batch): + # Forward and backward pass + (lvalue, toks), grad = loss_value_and_grad(model, *batch) + + # All reduce the gradients if running in distributed mode + grad = average_gradients(grad) + + # Model update + optimizer.update(model, grad) + + return lvalue, toks + + loss_value_and_grad = nn.value_and_grad(model, loss) + + losses = 0 + n_tokens = 0 + steps = 0 + trained_tokens = 0 + # Main training loop + start = time.perf_counter() + for it, batch in zip( + range(1, args.iters + 1), + iterate_batches( + dataset=train_dataset, + tokenizer=tokenizer, + batch_size=args.batch_size, + max_seq_length=args.max_seq_length, + train=True, + ), + ): + # Report validation loss if needed, the first validation loss + # is always measured before any training. 
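+        # Validation also runs on the final iteration; `start` is reset after
+        # each evaluation below, so validation time is excluded from the
+        # reported training it/sec and tokens/sec.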
+ if it == 1 or it % args.steps_per_eval == 0 or it == args.iters: + stop = time.perf_counter() + val_loss = evaluate( + model=model, + dataset=val_dataset, + loss=loss, + tokenizer=tokenizer, + batch_size=args.batch_size, + num_batches=args.val_batches, + max_seq_length=args.max_seq_length, + iterate_batches=iterate_batches, + ) + val_time = time.perf_counter() - stop + if rank == 0: + print( + f"Iter {it}: " + f"Val loss {val_loss:.3f}, " + f"Val took {val_time:.3f}s", + flush=True, + ) + + if training_callback is not None: + val_info = { + "iteration": it, + "val_loss": val_loss, + "val_time": val_time, + } + training_callback.on_val_loss_report(val_info) + + start = time.perf_counter() + + lvalue, toks = step(batch) + losses += lvalue + n_tokens += toks + steps += 1 + mx.eval(state, losses, n_tokens) + + # Report training loss if needed + if it % args.steps_per_report == 0 or it == args.iters: + stop = time.perf_counter() + + train_loss = mx.distributed.all_sum(losses).item() + train_loss /= steps * mx.distributed.init().size() + n_tokens = mx.distributed.all_sum(n_tokens).item() + learning_rate = optimizer.learning_rate.item() + it_sec = args.steps_per_report / (stop - start) + tokens_sec = float(n_tokens) / (stop - start) + trained_tokens += n_tokens + peak_mem = mx.metal.get_peak_memory() / 1e9 + if rank == 0: + print( + f"Iter {it}: Train loss {train_loss:.3f}, " + f"Learning Rate {learning_rate:.3e}, " + f"It/sec {it_sec:.3f}, " + f"Tokens/sec {tokens_sec:.3f}, " + f"Trained Tokens {trained_tokens}, " + f"Peak mem {peak_mem:.3f} GB", + flush=True, + ) + + if training_callback is not None: + train_info = { + "iteration": it, + "train_loss": train_loss, + "learning_rate": learning_rate, + "iterations_per_second": it_sec, + "tokens_per_second": tokens_sec, + "trained_tokens": trained_tokens, + "peak_memory": peak_mem, + } + training_callback.on_train_loss_report(train_info) + + losses = 0 + n_tokens = 0 + steps = 0 + start = time.perf_counter() + + # Save adapter weights + if it % args.steps_per_save == 0: + adapter_weights = dict(tree_flatten(model.trainable_parameters())) + mx.savez(str(args.adapter_file), **adapter_weights) + checkpoint = ( + Path(args.adapter_file).parent / f"{it:07d}_{Path(args.adapter_file).name}" + ) + mx.savez(str(checkpoint), **adapter_weights) + print( + f"Iter {it}: Saved adapter weights to " + f"{args.adapter_file} and {checkpoint}." + ) + + # Save final weights + adapter_weights = dict(tree_flatten(model.trainable_parameters())) + mx.savez(str(args.adapter_file), **adapter_weights) + print(f"Saved final weights to {args.adapter_file}.") diff --git a/unsloth/mlx/trainer/utils.py b/unsloth/mlx/trainer/utils.py new file mode 100644 index 00000000..0c97f2ef --- /dev/null +++ b/unsloth/mlx/trainer/utils.py @@ -0,0 +1,149 @@ +# Copyright © 2024 Apple Inc. +import json +import types +from pathlib import Path +from typing import Dict + +import mlx.core as mx +import mlx.nn as nn +import mlx.optimizers as opt +from mlx.utils import tree_flatten, tree_unflatten + +from .switch_layers import QuantizedSwitchLinear, SwitchLinear +from .lora import LoRAEmbedding, LoRALinear, LoRASwitchLinear + + +def build_schedule(schedule_config: Dict): + """ + Build a learning rate schedule from the given config. 
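+
+    ``schedule_config`` is expected to contain ``name`` (an ``mlx.optimizers``
+    scheduler such as ``cosine_decay``) and ``arguments`` (its positional
+    arguments, the first being the peak learning rate), plus optional
+    ``warmup`` / ``warmup_init`` for a linear warmup phase, e.g.
+    ``{"name": "cosine_decay", "arguments": [1e-5, 1000], "warmup": 100}``.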
+ """ + schedule_fn = getattr(opt.schedulers, schedule_config["name"]) + arguments = schedule_config["arguments"] + initial_lr = arguments[0] + bound_schedule_fn = schedule_fn(*arguments) + if warmup_steps := schedule_config.get("warmup", 0): + warmup_init = schedule_config.get("warmup_init", 0.0) + warmup_fn = opt.schedulers.linear_schedule( + warmup_init, initial_lr, warmup_steps + ) + return opt.schedulers.join_schedules( + [warmup_fn, bound_schedule_fn], [warmup_steps + 1] + ) + else: + return bound_schedule_fn + + +def linear_to_lora_layers( + model: nn.Module, + num_layers: int, + config: Dict, +): + """ + Convert some of the models linear layers to lora layers. + + Args: + model (nn.Module): The neural network model. + num_layers (int): The number of blocks to convert to lora layers + starting from the last layer. + config (dict): More configuration parameters for LoRA, including the + rank, scale, and optional layer keys. + use_dora (bool): If True, uses DoRA instead of LoRA. + Default: ``False`` + """ + if num_layers > len(model.layers): + raise ValueError( + f"Requested {num_layers} LoRA layers " + f"but the model only has {len(model.layers)} layers." + ) + + def to_lora(layer): + if isinstance(layer, (nn.Linear, nn.QuantizedLinear)): + LoRALayer = LoRALinear + elif isinstance(layer, (SwitchLinear, QuantizedSwitchLinear)): + LoRALayer = LoRASwitchLinear + elif isinstance(layer, (nn.Embedding, nn.QuantizedEmbedding)): + LoRALayer = LoRAEmbedding + else: + raise ValueError( + f"Can't convert layer of type {type(layer).__name__} to LoRA" + ) + + return LoRALayer.from_base( + layer, + r=config["rank"], + scale=config["scale"], + dropout=config["dropout"], + ) + + keys = config.get("keys", None) + if keys is not None: + keys = set(keys) + elif model.model_type in [ + "mistral", + "llama", + ]: + keys = set(["self_attn.q_proj", "self_attn.v_proj"]) + if model.model_type in ["mixtral", "phimoe"]: + keys.add("block_sparse_moe.gate") + if model.model_type == "qwen2_moe": + keys.add("mlp.gate") + keys.add("mlp.shared_expert_gate") + else: + raise ValueError(f"Lora does not support {model.model_type}") + + for l in model.layers[-min(num_layers, 0) :]: + lora_layers = [(k, to_lora(m)) for k, m in l.named_modules() if k in keys] + if lora_layers: + l.update_modules(tree_unflatten(lora_layers)) + + lora_modules = [(k, to_lora(m)) for k, m in model.named_modules() if k in keys] + if lora_modules: + model.update_modules(tree_unflatten(lora_modules)) + + +def load_adapters(model: nn.Module, adapter_path: str,adapter_file : str) -> nn.Module: + """ + Load any fine-tuned adapters / layers. + + Args: + model (nn.Module): The neural network model. + adapter_path (str): Path to the adapter configuration file. + + Returns: + nn.Module: The updated model with LoRA layers applied. 
+ """ + adapter_path = Path(adapter_path) + if not adapter_path.exists(): + raise FileNotFoundError(f"The adapter path does not exist: {adapter_path}") + with open(adapter_path / "adapter_config.json", "r") as fid: + config = types.SimpleNamespace(**json.load(fid)) + + linear_to_lora_layers( + model, + config.num_layers, + config.lora_parameters, + ) + # adapters = list(mx.load(str(adapter_path /adapter_file)).items()) + # model.load_weights(adapters,False) + model.load_weights(str(adapter_path / adapter_file), strict=False) + return model + + + +def print_trainable_parameters(model): + def nparams(m): + if isinstance(m, (nn.QuantizedLinear, nn.QuantizedEmbedding)): + return m.weight.size * (32 // m.bits) + return sum(v.size for _, v in tree_flatten(m.parameters())) + + leaf_modules = tree_flatten( + model.leaf_modules(), is_leaf=lambda m: isinstance(m, nn.Module) + ) + total_p = sum(nparams(m) for _, m in leaf_modules) / 10**6 + trainable_p = ( + sum(v.size for _, v in tree_flatten(model.trainable_parameters())) / 10**6 + ) + print( + f"Trainable parameters: {(trainable_p * 100 / total_p):.3f}% " + f"({trainable_p:.3f}M/{total_p:.3f}M)" + ) diff --git a/unsloth/models/mlx_lora.py b/unsloth/models/mlx_lora.py deleted file mode 100644 index 6fb5d928..00000000 --- a/unsloth/models/mlx_lora.py +++ /dev/null @@ -1,212 +0,0 @@ -# Copyright © 2023-2024 Apple Inc. - -import argparse -import json -import math -import time -from pathlib import Path - -import mlx.core as mx -import mlx.nn as nn -import mlx.optimizers as optim -import numpy as np -from . import mlx_utils as lora_utils -from mlx.utils import tree_flatten - - -class Dataset: - """ - Light-weight wrapper to hold lines from a jsonl file - """ - - def __init__(self, path: Path, key: str = "text"): - if not path.exists(): - self._data = None - else: - with open(path, "r") as fid: - self._data = [json.loads(l) for l in fid] - self._key = key - - def __getitem__(self, idx: int): - return self._data[idx][self._key] - - def __len__(self): - return len(self._data) - - -def load(args): - def load_and_check(name): - dataset_path = Path(args.data) / f"{name}.jsonl" - try: - return Dataset(dataset_path) - except Exception as e: - print(f"Unable to build dataset {dataset_path} ({e})") - raise - - names = ("train", "valid", "test") - train, valid, test = (load_and_check(n) for n in names) - - if args.train and len(train) == 0: - raise ValueError( - "Training set not found or empty. Must provide training set for fine-tuning." - ) - if args.train and len(valid) == 0: - raise ValueError( - "Validation set not found or empty. Must provide validation set for fine-tuning." - ) - if args.test and len(test) == 0: - raise ValueError( - "Test set not found or empty. Must provide test set for evaluation." 
- ) - return train, valid, test - - -def loss(model, inputs, targets, lengths): - # Run model on inputs - logits, _ = model(inputs) - logits = logits.astype(mx.float32) - - # Mask padding tokens - length_mask = mx.arange(inputs.shape[1])[None, :] < lengths[:, None] - - # Calculate the loss - ce = nn.losses.cross_entropy(logits, targets) * length_mask - ntoks = length_mask.sum() - ce = ce.sum() / ntoks - return ce, ntoks - - -def iterate_batches(dset, tokenizer, batch_size, train=False): - # Shuffle indices - while True: - indices = np.arange(len(dset)) - if train: - indices = np.random.permutation(indices) - - # Collect batches from dataset - for i in range(0, len(indices) - batch_size + 1, batch_size): - # Encode batch - batch = [tokenizer.encode(str(dset[np.int32(indices[i + j]).item()])) for j in range(batch_size)] - lengths = [len(x) for x in batch] - - # Check if any sequence is longer than 2048 tokens - if max(lengths) > 2048: - print( - "[WARNING] Some sequences are longer than 2048 tokens. " - "Consider pre-splitting your data to save memory." - ) - - # Pad to the max length - batch_arr = np.zeros((batch_size, max(lengths)), np.int32) - - for j in range(batch_size): - batch_arr[j, : lengths[j]] = batch[j] - batch = mx.array(batch_arr) - yield batch[:, :-1], batch[:, 1:], mx.array(lengths) - - if not train: - break - - -def evaluate(model, dataset, loss, tokenizer, batch_size, num_batches): - all_losses = [] - ntokens = 0 - - # num_batches can be -1 to indicate the entire set - index_iterator = iter(range(num_batches)) if num_batches != -1 else iter(int, 1) - - for it, batch in zip( - index_iterator, - iterate_batches(dataset, tokenizer, batch_size), - ): - losses, toks = loss(model, *batch) - all_losses.append((losses * toks).item()) - ntokens += toks.item() - - return np.sum(all_losses) / ntokens - - -def train(model, train_set, val_set, optimizer, loss, tokenizer, args): - # Create value and grad function for loss - loss_value_and_grad = nn.value_and_grad(model, loss) - - losses = [] - n_tokens = 0 - - # Main training loop - start = time.perf_counter() - for it, batch in zip( - range(args.max_steps), - iterate_batches(train_set, tokenizer, args.per_device_train_batch_size, train=True), - ): - # Forward and backward pass - (lvalue, toks), grad = loss_value_and_grad(model, *batch) - - # Model update - optimizer.update(model, grad) - mx.eval(model.parameters(), optimizer.state, lvalue) - - # Record loss - losses.append(lvalue.item()) - n_tokens += toks.item() - - # Report training loss if needed - if (it + 1) % 10 == 0: - train_loss = np.mean(losses) - - stop = time.perf_counter() - print( - f"Iter {it + 1}: Train loss {train_loss:.3f}, " - f"It/sec {10 / (stop - start):.3f}, " - f"Tokens/sec {float(n_tokens) / (stop - start):.3f}" - ) - losses = [] - n_tokens = 0 - start = time.perf_counter() - - # Report validation loss if needed - if it == 0 or (it + 1) % 200 == 0: - stop = time.perf_counter() - val_loss = evaluate( - model, val_set, loss, tokenizer, args.per_device_train_batch_size, 25 - ) - print( - f"Iter {it + 1}: " - f"Val loss {val_loss:.3f}, " - f"Val took {(time.perf_counter() - stop):.3f}s" - ) - - start = time.perf_counter() - - # Save adapter weights if needed - if (it + 1) % 100 == 0: - mx.savez( - args.adapter_file, **dict(tree_flatten(model.trainable_parameters())) - ) - print(f"Iter {it + 1}: Saved adapter weights to {args.adapter_file}.") - - -def generate(model, prompt, tokenizer, args): - print(prompt, end="", flush=True) - - prompt = 
mx.array(tokenizer.encode(prompt)) - - tokens = [] - skip = 0 - for token, n in zip( - lora_utils.generate(prompt, model, args.temp), - range(args.max_tokens), - ): - if token == tokenizer.eos_token_id: - break - - tokens.append(token.item()) - s = tokenizer.decode(tokens) - if len(s) - skip > 1: - print(s[skip:-1], end="", flush=True) - skip = len(s) - 1 - print(tokenizer.decode(tokens)[skip:], flush=True) - print("=" * 10) - if len(tokens) == 0: - print("No tokens generated for this prompt") - return diff --git a/unsloth/models/mlx_models.py b/unsloth/models/mlx_models.py deleted file mode 100644 index 122d7bf8..00000000 --- a/unsloth/models/mlx_models.py +++ /dev/null @@ -1,355 +0,0 @@ -# Copyright © 2023 Apple Inc. - -import inspect -import math -from dataclasses import dataclass -from typing import Dict, Optional, Tuple, Union -from pathlib import Path - -import mlx.core as mx -import mlx.nn as nn -from . import mlx_utils as utils -from mlx.utils import tree_flatten, tree_unflatten - -@dataclass -class ModelArgs: - hidden_size: int - num_hidden_layers: int - intermediate_size: int - num_attention_heads: int - rms_norm_eps: float - vocab_size: int - num_key_value_heads: int = None - rope_theta: float = 10000 - rope_traditional: bool = False - model_type: str = None - rope_scaling: Optional[Dict[str, Union[float, str]]] = None - - def __post_init__(self): - if self.num_key_value_heads is None: - self.num_key_value_heads = self.num_attention_heads - - if self.rope_scaling: - required_keys = {"factor", "type"} - if not all(key in self.rope_scaling for key in required_keys): - raise ValueError(f"rope_scaling must contain keys {required_keys}") - - if self.rope_scaling["type"] != "linear": - raise ValueError("rope_scaling 'type' currently only supports 'linear'") - - @classmethod - def from_dict(cls, params): - return cls( - **{ - k: v - for k, v in params.items() - if k in inspect.signature(cls).parameters - } - ) - - -class LoRALinear(nn.Module): - @staticmethod - def from_linear(linear: nn.Linear, rank: int = 8): - # TODO remove when input_dims and output_dims are attributes - # on linear and quantized linear - output_dims, input_dims = linear.weight.shape - if isinstance(linear, nn.QuantizedLinear): - input_dims *= 32 // linear.bits - lora_lin = LoRALinear(input_dims, output_dims, rank) - lora_lin.linear = linear - return lora_lin - - def to_linear(self): - linear = self.linear - bias = "bias" in linear - weight = linear.weight - is_quantized = isinstance(linear, nn.QuantizedLinear) - - # Use the same type as the linear weight if not quantized - dtype = weight.dtype - - if is_quantized: - dtype = mx.float16 - weight = mx.dequantize( - weight, - linear.scales, - linear.biases, - linear.group_size, - linear.bits, - ) - output_dims, input_dims = weight.shape - fused_linear = nn.Linear(input_dims, output_dims, bias=bias) - - lora_b = (self.scale * self.lora_b.T).astype(dtype) - lora_a = self.lora_a.T.astype(dtype) - fused_linear.weight = weight + lora_b @ lora_a - if bias: - fused_linear.bias = linear.bias - - if is_quantized: - fused_linear = nn.QuantizedLinear.from_linear( - fused_linear, - linear.group_size, - linear.bits, - ) - - return fused_linear - - def __init__( - self, - input_dims: int, - output_dims: int, - lora_rank: int = 8, - bias: bool = False, - scale: float = 20.0, - ): - super().__init__() - - # Regular linear layer weights - self.linear = nn.Linear(input_dims, output_dims, bias=bias) - - # Scale for low-rank update - self.scale = scale - - # Low rank lora weights - scale 
= 1 / math.sqrt(input_dims) - self.lora_a = mx.random.uniform( - low=-scale, - high=scale, - shape=(input_dims, lora_rank), - ) - self.lora_b = mx.zeros(shape=(lora_rank, output_dims)) - - def __call__(self, x): - dtype = self.linear.weight.dtype - if isinstance(self.linear, nn.QuantizedLinear): - dtype = self.linear.scales.dtype - y = self.linear(x.astype(dtype)) - z = (x @ self.lora_a) @ self.lora_b - return y + self.scale * z - - -class Attention(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - - dim = args.hidden_size - self.n_heads = n_heads = args.num_attention_heads - self.n_kv_heads = n_kv_heads = args.num_key_value_heads - - self.repeats = n_heads // n_kv_heads - - head_dim = args.hidden_size // n_heads - self.scale = head_dim**-0.5 - - self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=False) - self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False) - self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=False) - self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=False) - rope_scale = ( - 1 / args.rope_scaling["factor"] - if args.rope_scaling is not None and args.rope_scaling["type"] == "linear" - else 1 - ) - self.rope = nn.RoPE( - head_dim, - traditional=args.rope_traditional, - base=args.rope_theta, - scale=rope_scale, - ) - - def __call__( - self, - x: mx.array, - mask: Optional[mx.array] = None, - cache: Optional[Tuple[mx.array, mx.array]] = None, - ) -> mx.array: - B, L, D = x.shape - - queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x) - - # Prepare the queries, keys and values for the attention computation - queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3) - keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) - values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) - - if cache is not None: - key_cache, value_cache = cache - queries = self.rope(queries, offset=key_cache.shape[2]) - keys = self.rope(keys, offset=key_cache.shape[2]) - keys = mx.concatenate([key_cache, keys], axis=2) - values = mx.concatenate([value_cache, values], axis=2) - else: - queries = self.rope(queries) - keys = self.rope(keys) - - output = mx.fast.scaled_dot_product_attention( - queries, keys, values, scale=self.scale, mask=mask - ) - output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) - return self.o_proj(output), (keys, values) - - -class MLP(nn.Module): - def __init__(self, dim, hidden_dim): - super().__init__() - self.gate_proj = nn.Linear(dim, hidden_dim, bias=False) - self.down_proj = nn.Linear(hidden_dim, dim, bias=False) - self.up_proj = nn.Linear(dim, hidden_dim, bias=False) - - def __call__(self, x) -> mx.array: - return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x)) - - -class TransformerBlock(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - self.num_attention_heads = args.num_attention_heads - self.hidden_size = args.hidden_size - self.self_attn = Attention(args) - self.mlp = MLP(args.hidden_size, args.intermediate_size) - self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps) - self.post_attention_layernorm = nn.RMSNorm( - args.hidden_size, eps=args.rms_norm_eps - ) - self.args = args - - def __call__( - self, - x: mx.array, - mask: Optional[mx.array] = None, - cache: Optional[Tuple[mx.array, mx.array]] = None, - ) -> mx.array: - r, cache = self.self_attn(self.input_layernorm(x), mask, cache) - h = x + r - r = self.mlp(self.post_attention_layernorm(h)) - out = h + r - return out, cache - - -class 
LlamaModel(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - self.args = args - self.vocab_size = args.vocab_size - self.num_hidden_layers = args.num_hidden_layers - assert self.vocab_size > 0 - self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size) - self.layers = [ - TransformerBlock(args=args) for _ in range(args.num_hidden_layers) - ] - self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps) - - def __call__( - self, - inputs: mx.array, - cache=None, - ): - h = self.embed_tokens(inputs) - - mask = None - if h.shape[1] > 1: - mask = nn.MultiHeadAttention.create_additive_causal_mask(h.shape[1]) - mask = mask.astype(h.dtype) - - if cache is None: - cache = [None] * len(self.layers) - - for e, layer in enumerate(self.layers): - h, cache[e] = layer(h, mask, cache[e]) - - return self.norm(h), cache - - - -class Model(nn.Module): - def __init__(self, args: ModelArgs): - super().__init__() - self.model = LlamaModel(args) - self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False) - - def __call__( - self, - inputs: mx.array, - cache=None, - ): - out, cache = self.model(inputs, cache) - return self.lm_head(out), cache - - def save_merged_model(self,args): - model, tokenizer, config = utils.load(args.model_name) - - # Load the LoRA adapter weights which we assume should exist by this point - if not Path(args.save_path,args.adapter_file).is_file(): - raise ValueError( - f"Adapter file {args.adapter_file} missing. ") - - # Load adapters and get number of LoRA layers - adapters = list(mx.load(args.save_path+"/"+args.adapter_file).items()) - lora_layers = len([m for m in adapters if "q_proj.lora_a" in m[0]]) - - # Freeze all layers other than LORA linears - model.freeze() - for l in model.model.layers[len(model.model.layers) - lora_layers :]: - l.self_attn.q_proj = LoRALinear.from_linear(l.self_attn.q_proj) - l.self_attn.v_proj = LoRALinear.from_linear(l.self_attn.v_proj) - if hasattr(l, "block_sparse_moe"): - l.block_sparse_moe.gate = LoRALinear.from_linear(l.block_sparse_moe.gate) - - model.update(tree_unflatten(adapters)) - fused_linears = [ - (n, m.to_linear()) - for n, m in model.named_modules() - if isinstance(m, LoRALinear) - ] - - model.update_modules(tree_unflatten(fused_linears)) - - if False: - de_quantize_layers = [] - for n, m in model.named_modules(): - if isinstance(m, nn.QuantizedLinear): - bias = "bias" in m - weight = m.weight - weight = mx.dequantize( - weight, - m.scales, - m.biases, - m.group_size, - m.bits, - ).astype(mx.float16) - output_dims, input_dims = weight.shape - linear = nn.Linear(input_dims, output_dims, bias=bias) - linear.weight = weight - if bias: - linear.bias = m.bias - de_quantize_layers.append((n, linear)) - - model.update_modules(tree_unflatten(de_quantize_layers)) - - weights = dict(tree_flatten(model.parameters())) - utils.save_model(args.save_path, weights, tokenizer, config) - - if args.push_model: - from huggingface_hub import whoami - try: - username = whoami(token = args.hub_token)["name"] - except: - raise RuntimeError( - "Unsloth: Please supply a token!\n"\ - "Go to https://huggingface.co/settings/tokens" - ) - pass - pass - - if args.push_model and args.hub_path is not None: - hf_path = args.hub_path - if not Path(args.model_name).exists(): - # If the model path doesn't exist, assume it's an HF repo - hf_path = args.model_name - elif hf_path is None: - raise ValueError( - "Must provide original Hugging Face repo to upload local model." 
- ) - utils.upload_to_hub(config['_name_or_path'],config['model_type'],username,args.save_path, args.hub_path,args.hub_token) From df72331c0da7a30a264f7b7118f96108400be522 Mon Sep 17 00:00:00 2001 From: shashikanth Date: Thu, 28 Nov 2024 10:14:01 +0530 Subject: [PATCH 3/4] 4 bit quantized models added --- unsloth-cli.py | 2 +- unsloth/mlx/mlx_utils.py | 12 +- unsloth/models/mapper.py | 831 ++++++++++++++++++++++----------------- 3 files changed, 475 insertions(+), 370 deletions(-) diff --git a/unsloth-cli.py b/unsloth-cli.py index 111a31e0..6bd6a8fc 100644 --- a/unsloth-cli.py +++ b/unsloth-cli.py @@ -57,7 +57,7 @@ def run(args): ) else: print("Loading pretrained model") - model, tokenizer, config = mlx_utils.load_pretrained(args.model_name) + model, tokenizer, config = mlx_utils.load_pretrained(args.model_name,dtype=args.dtype,load_in_4bit=args.load_in_4bit) # Configure PEFT model if not has_mps: diff --git a/unsloth/mlx/mlx_utils.py b/unsloth/mlx/mlx_utils.py index 6d423d30..03da8246 100644 --- a/unsloth/mlx/mlx_utils.py +++ b/unsloth/mlx/mlx_utils.py @@ -8,6 +8,8 @@ import mlx.core as mx import mlx.nn as nn + +from unsloth.models.loader import get_model_name from .models import llama as models import transformers from huggingface_hub import snapshot_download,create_repo @@ -250,7 +252,8 @@ def sample(logits: mx.array) -> mx.array: yield y def save_merged_model(args): - model_path = get_model_path(args.model_name) + model_name = get_model_name(args.model_name,args.load_in_4bit) + model_path = get_model_path(model_name) model, tokenizer, config = load(model_path) model.freeze() @@ -333,11 +336,14 @@ def get_model_path(path_or_hf_repo: str, revision: Optional[str] = None) -> Path def load_pretrained( - path_or_hf_repo: str, + model_name: str, tokenizer_config={}, model_config={}, + dtype= None, + load_in_4bit=True ): - model_path = get_model_path(path_or_hf_repo) + model_name = get_model_name(model_name,load_in_4bit) + model_path = get_model_path(model_name) model,tokenizer, config = load(model_path, tokenizer_config) diff --git a/unsloth/models/mapper.py b/unsloth/models/mapper.py index d4f1278e..90ce96e6 100644 --- a/unsloth/models/mapper.py +++ b/unsloth/models/mapper.py @@ -17,441 +17,540 @@ "FLOAT_TO_INT_MAPPER", ] -__INT_TO_FLOAT_MAPPER = \ +from unsloth.devices import has_mps + + +if not has_mps: + __INT_TO_FLOAT_MAPPER = \ + { + "unsloth/mistral-7b-bnb-4bit" : ( + "unsloth/mistral-7b", + "mistralai/Mistral-7B-v0.1", + ), + "unsloth/llama-2-7b-bnb-4bit" : ( + "unsloth/llama-2-7b", + "meta-llama/Llama-2-7b-hf", + ), + "unsloth/llama-2-13b-bnb-4bit" : ( + "unsloth/llama-2-13b", + "meta-llama/Llama-2-13b-hf", + ), + "unsloth/codellama-34b-bnb-4bit" : ( + "codellama/CodeLlama-34b-hf", + ), + "unsloth/zephyr-sft-bnb-4bit" : ( + "unsloth/zephyr-sft", + "HuggingFaceH4/mistral-7b-sft-beta", + ), + "unsloth/tinyllama-bnb-4bit" : ( + "unsloth/tinyllama", + "TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T", + ), + "unsloth/tinyllama-chat-bnb-4bit" : ( + "unsloth/tinyllama-chat", + "TinyLlama/TinyLlama-1.1B-Chat-v1.0", + ), + "unsloth/mistral-7b-instruct-v0.1-bnb-4bit" : ( + "unsloth/mistral-7b-instruct-v0.1", + "mistralai/Mistral-7B-Instruct-v0.1", + ), + "unsloth/mistral-7b-instruct-v0.2-bnb-4bit" : ( + "unsloth/mistral-7b-instruct-v0.2", + "mistralai/Mistral-7B-Instruct-v0.2", + ), + "unsloth/llama-2-7b-chat-bnb-4bit" : ( + "unsloth/llama-2-7b-chat", + "meta-llama/Llama-2-7b-chat-hf", + ), + "unsloth/llama-2-7b-chat-bnb-4bit" : ( + "unsloth/llama-2-7b-chat", + 
"meta-llama/Llama-2-7b-chat-hf", + ), + "unsloth/codellama-7b-bnb-4bit" : ( + "unsloth/codellama-7b", + "codellama/CodeLlama-7b-hf", + ), + "unsloth/codellama-13b-bnb-4bit" : ( + "codellama/CodeLlama-13b-hf", + ), + "unsloth/yi-6b-bnb-4bit" : ( + "unsloth/yi-6b", + "01-ai/Yi-6B", + ), + "unsloth/solar-10.7b-bnb-4bit" : ( + "upstage/SOLAR-10.7B-v1.0", + ), + "unsloth/gemma-7b-bnb-4bit" : ( + "unsloth/gemma-7b", + "google/gemma-7b", + ), + "unsloth/gemma-2b-bnb-4bit" : ( + "unsloth/gemma-2b", + "google/gemma-2b", + ), + "unsloth/gemma-7b-it-bnb-4bit" : ( + "unsloth/gemma-7b-it", + "google/gemma-7b-it", + ), + "unsloth/gemma-2b-bnb-4bit" : ( + "unsloth/gemma-2b-it", + "google/gemma-2b-it", + ), + "unsloth/mistral-7b-v0.2-bnb-4bit" : ( + "unsloth/mistral-7b-v0.2", + "alpindale/Mistral-7B-v0.2-hf", + ), + "unsloth/gemma-1.1-2b-it-bnb-4bit" : ( + "unsloth/gemma-1.1-2b-it", + "google/gemma-1.1-2b-it", + ), + "unsloth/gemma-1.1-7b-it-bnb-4bit" : ( + "unsloth/gemma-1.1-7b-it", + "google/gemma-1.1-7b-it", + ), + "unsloth/Starling-LM-7B-beta-bnb-4bit" : ( + "unsloth/Starling-LM-7B-beta", + "Nexusflow/Starling-LM-7B-beta", + ), + "unsloth/Hermes-2-Pro-Mistral-7B-bnb-4bit" : ( + "unsloth/Hermes-2-Pro-Mistral-7B", + "NousResearch/Hermes-2-Pro-Mistral-7B", + ), + "unsloth/OpenHermes-2.5-Mistral-7B-bnb-4bit" : ( + "unsloth/OpenHermes-2.5-Mistral-7B", + "teknium/OpenHermes-2.5-Mistral-7B", + ), + "unsloth/codegemma-2b-bnb-4bit" : ( + "unsloth/codegemma-2b", + "google/codegemma-2b", + ), + "unsloth/codegemma-7b-bnb-4bit" : ( + "unsloth/codegemma-7b", + "google/codegemma-7b", + ), + "unsloth/codegemma-7b-it-bnb-4bit" : ( + "unsloth/codegemma-7b-it", + "google/codegemma-7b-it", + ), + "unsloth/llama-3-8b-bnb-4bit" : ( + "unsloth/llama-3-8b", + "meta-llama/Meta-Llama-3-8B", + ), + "unsloth/llama-3-8b-Instruct-bnb-4bit" : ( + "unsloth/llama-3-8b-Instruct", + "meta-llama/Meta-Llama-3-8B-Instruct", + ), + "unsloth/llama-3-70b-bnb-4bit" : ( + "meta-llama/Meta-Llama-3-70B", + ), + "unsloth/llama-3-70b-Instruct-bnb-4bit" : ( + "meta-llama/Meta-Llama-3-70B-Instruct", + ), + "unsloth/Phi-3-mini-4k-instruct-bnb-4bit" : ( + "unsloth/Phi-3-mini-4k-instruct", + "microsoft/Phi-3-mini-4k-instruct", + ), + "unsloth/mistral-7b-v0.3-bnb-4bit" : ( + "unsloth/mistral-7b-v0.3", + "mistralai/Mistral-7B-v0.3", + ), + "unsloth/mistral-7b-instruct-v0.3-bnb-4bit" : ( + "unsloth/mistral-7b-instruct-v0.3", + "mistralai/Mistral-7B-Instruct-v0.3", + ), + "unsloth/Phi-3-medium-4k-instruct-bnb-4bit" : ( + "unsloth/Phi-3-medium-4k-instruct", + "microsoft/Phi-3-medium-4k-instruct", + ), + "unsloth/Qwen2-0.5B-bnb-4bit" : ( + "unsloth/Qwen2-0.5B", + "Qwen/Qwen2-0.5B", + ), + "unsloth/Qwen2-0.5B-Instruct-bnb-4bit" : ( + "unsloth/Qwen2-0.5B-Instruct", + "Qwen/Qwen2-0.5B-Instruct", + ), + "unsloth/Qwen2-1.5B-bnb-4bit" : ( + "unsloth/Qwen2-1.5B", + "Qwen/Qwen2-1.5B", + ), + "unsloth/Qwen2-1.5B-Instruct-bnb-4bit" : ( + "unsloth/Qwen2-1.5B-Instruct", + "Qwen/Qwen2-1.5B-Instruct", + ), + "unsloth/Qwen2-7B-bnb-4bit" : ( + "unsloth/Qwen2-7B", + "Qwen/Qwen2-7B", + ), + "unsloth/Qwen2-7B-Instruct-bnb-4bit" : ( + "unsloth/Qwen2-7B-Instruct", + "Qwen/Qwen2-7B-Instruct", + ), + "unsloth/Qwen2-70B-bnb-4bit" : ( + "Qwen/Qwen2-70B", + ), + "unsloth/Qwen2-70B-Instruct-bnb-4bit" : ( + "Qwen/Qwen2-70B-Instruct", + ), + "mistralai/Codestral-22B-v0.1" : ( + "mistral-community/Codestral-22B-v0.1", + ), + "unsloth/gemma-2-9b-bnb-4bit" : ( + "unsloth/gemma-2-9b", + "google/gemma-2-9b", + ), + "unsloth/gemma-2-27b-bnb-4bit" : ( + "unsloth/gemma-2-27b", + 
"google/gemma-2-27b", + ), + "unsloth/gemma-2-9b-it-bnb-4bit" : ( + "unsloth/gemma-2-9b-it", + "google/gemma-2-9b-it", + ), + "unsloth/gemma-2-27b-it-bnb-4bit" : ( + "unsloth/gemma-2-27b-it", + "google/gemma-2-27b-it", + ), + "unsloth/Phi-3-mini-4k-instruct-v0-bnb-4bit" : ( # Old Phi pre July + "unsloth/Phi-3-mini-4k-instruct-v0", + ), + "unsloth/Mistral-Nemo-Instruct-2407-bnb-4bit" : ( # New 12b Mistral models + "unsloth/Mistral-Nemo-Instruct-2407", + "mistralai/Mistral-Nemo-Instruct-2407", + ), + "unsloth/Mistral-Nemo-Base-2407-bnb-4bit" : ( # New 12b Mistral models + "unsloth/Mistral-Nemo-Base-2407", + "mistralai/Mistral-Nemo-Base-2407", + ), + "unsloth/Meta-Llama-3.1-8B-bnb-4bit" : ( + "unsloth/Meta-Llama-3.1-8B", + "meta-llama/Meta-Llama-3.1-8B", + ), + "unsloth/Meta-Llama-3.1-8B-Instruct-bnb-4bit" : ( + "unsloth/Meta-Llama-3.1-8B-Instruct", + "meta-llama/Meta-Llama-3.1-8B-Instruct", + ), + "unsloth/Meta-Llama-3.1-70B-bnb-4bit" : ( + "unsloth/Meta-Llama-3.1-70B", + "meta-llama/Meta-Llama-3.1-70B", + ), + "unsloth/Meta-Llama-3.1-405B-bnb-4bit" : ( + "meta-llama/Meta-Llama-3.1-405B", + ), + "unsloth/Meta-Llama-3.1-405B-Instruct-bnb-4bit" : ( + "meta-llama/Meta-Llama-3.1-405B-Instruct", + ), + "unsloth/Meta-Llama-3.1-70B-Instruct-bnb-4bit" : ( + "unsloth/Meta-Llama-3.1-70B-Instruct", + "meta-llama/Meta-Llama-3.1-70B-Instruct", + ), + "unsloth/Mistral-Large-Instruct-2407-bnb-4bit" : ( + "mistralai/Mistral-Large-Instruct-2407", + ), + "unsloth/gemma-2-2b-bnb-4bit" : ( + "unsloth/gemma-2-2b", + "google/gemma-2-2b", + ), + "unsloth/gemma-2-2b-it-bnb-4bit" : ( + "unsloth/gemma-2-2b-it", + "google/gemma-2-2b-it", + ), + "unsloth/Phi-3.5-mini-instruct-bnb-4bit" : ( + "unsloth/Phi-3.5-mini-instruct", + "microsoft/Phi-3.5-mini-instruct", + ), + "unsloth/c4ai-command-r-08-2024-bnb-4bit" : ( + "CohereForAI/c4ai-command-r-08-2024", + ), + "unsloth/c4ai-command-r-plus-08-2024-bnb-4bit" : ( + "CohereForAI/c4ai-command-r-plus-08-2024", + ), + "unsloth/Llama-3.1-Storm-8B-bnb-4bit" : ( + "unsloth/Llama-3.1-Storm-8B", + "akjindal53244/Llama-3.1-Storm-8B", + ), + "unsloth/Hermes-3-Llama-3.1-8B-bnb-4bit" : ( + "unsloth/Hermes-3-Llama-3.1-8B", + "NousResearch/Hermes-3-Llama-3.1-8B", + ), + "unsloth/Hermes-3-Llama-3.1-70B-bnb-4bit" : ( + "unsloth/Hermes-3-Llama-3.1-70B", + "NousResearch/Hermes-3-Llama-3.1-70B", + ), + "unsloth/Hermes-3-Llama-3.1-405B-bnb-4bit" : ( + "NousResearch/Hermes-3-Llama-3.1-405B", + ), + "unsloth/SmolLM-135M-bnb-4bit" : ( + "unsloth/SmolLM-135M", + "HuggingFaceTB/SmolLM-135M", + ), + "unsloth/SmolLM-360M-bnb-4bit" : ( + "unsloth/SmolLM-360M", + "HuggingFaceTB/SmolLM-360M", + ), + "unsloth/SmolLM-1.7B-bnb-4bit" : ( + "unsloth/SmolLM-1.7B", + "HuggingFaceTB/SmolLM-1.7B", + ), + "unsloth/SmolLM-135M-Instruct-bnb-4bit" : ( + "unsloth/SmolLM-135M-Instruct", + "HuggingFaceTB/SmolLM-135M-Instruct", + ), + "unsloth/SmolLM-360M-Instruct-bnb-4bit" : ( + "unsloth/SmolLM-360M-Instruct", + "HuggingFaceTB/SmolLM-360M-Instruct", + ), + "unsloth/SmolLM-1.7B-Instruct-bnb-4bit" : ( + "unsloth/SmolLM-1.7B-Instruct", + "HuggingFaceTB/SmolLM-1.7B-Instruct", + ), + "unsloth/Mistral-Small-Instruct-2409-bnb-4bit" : ( + "unsloth/Mistral-Small-Instruct-2409", + "mistralai/Mistral-Small-Instruct-2409", + ), + "unsloth/Qwen2.5-0.5B-Instruct-bnb-4bit" : ( + "unsloth/Qwen2.5-0.5B-Instruct", + "Qwen/Qwen2.5-0.5B-Instruct", + ), + "unsloth/Qwen2.5-1.5B-Instruct-bnb-4bit" : ( + "unsloth/Qwen2.5-1.5B-Instruct", + "Qwen/Qwen2.5-1.5B-Instruct", + ), + "unsloth/Qwen2.5-3B-Instruct-bnb-4bit" : ( + 
"unsloth/Qwen2.5-3B-Instruct", + "Qwen/Qwen2.5-3B-Instruct", + ), + "unsloth/Qwen2.5-7B-Instruct-bnb-4bit" : ( + "unsloth/Qwen2.5-7B-Instruct", + "Qwen/Qwen2.5-7B-Instruct", + ), + "unsloth/Qwen2.5-14B-Instruct-bnb-4bit" : ( + "unsloth/Qwen2.5-14B-Instruct", + "Qwen/Qwen2.5-14B-Instruct", + ), + "unsloth/Qwen2.5-32B-Instruct-bnb-4bit" : ( + "unsloth/Qwen2.5-32B-Instruct", + "Qwen/Qwen2.5-32B-Instruct", + ), + "unsloth/Qwen2.5-72B-Instruct-bnb-4bit" : ( + "unsloth/Qwen2.5-72B-Instruct", + "Qwen/Qwen2.5-72B-Instruct", + ), + "unsloth/Qwen2.5-0.5B-bnb-4bit" : ( + "unsloth/Qwen2.5-0.5B", + "Qwen/Qwen2.5-0.5B", + ), + "unsloth/Qwen2.5-1.5B-bnb-4bit" : ( + "unsloth/Qwen2.5-1.5B", + "Qwen/Qwen2.5-1.5B", + ), + "unsloth/Qwen2.5-3B-bnb-4bit" : ( + "unsloth/Qwen2.5-3B", + "Qwen/Qwen2.5-3B", + ), + "unsloth/Qwen2.5-7B-bnb-4bit" : ( + "unsloth/Qwen2.5-7B", + "Qwen/Qwen2.5-7B", + ), + "unsloth/Qwen2.5-14B-bnb-4bit" : ( + "unsloth/Qwen2.5-14B", + "Qwen/Qwen2.5-14B", + ), + "unsloth/Qwen2.5-32B-bnb-4bit" : ( + "unsloth/Qwen2.5-32B", + "Qwen/Qwen2.5-32B", + ), + "unsloth/Qwen2.5-72B-bnb-4bit" : ( + "unsloth/Qwen2.5-72B", + "Qwen/Qwen2.5-72B", + ), + "unsloth/Qwen2.5-Math-1.5B-bnb-4bit" : ( + "unsloth/Qwen2.5-Math-1.5B", + "Qwen/Qwen2.5-Math-1.5B", + ), + "unsloth/Qwen2.5-Math-7B-bnb-4bit" : ( + "unsloth/Qwen2.5-Math-7B", + "Qwen/Qwen2.5-Math-7B", + ), + "unsloth/Qwen2.5-Math-72B-bnb-4bit" : ( + "unsloth/Qwen2.5-Math-72B", + "Qwen/Qwen2.5-Math-72B", + ), + "unsloth/Qwen2.5-Math-1.5B-Instruct-bnb-4bit" : ( + "unsloth/Qwen2.5-Math-1.5B-Instruct", + "Qwen/Qwen2.5-Math-1.5B-Instruct", + ), + "unsloth/Qwen2.5-Math-7B-Instruct-bnb-4bit" : ( + "unsloth/Qwen2.5-Math-7B-Instruct", + "Qwen/Qwen2.5-Math-7B-Instruct", + ), + "unsloth/Qwen2.5-Math-72B-Instruct-bnb-4bit" : ( + "unsloth/Qwen2.5-Math-72B-Instruct", + "Qwen/Qwen2.5-Math-72B-Instruct", + ), + "unsloth/Qwen2.5-Coder-0.5B-bnb-4bit" : ( + "unsloth/Qwen2.5-Coder-0.5B", + "Qwen/Qwen2.5-Coder-0.5B", + ), + "unsloth/Qwen2.5-Coder-1.5B-bnb-4bit" : ( + "unsloth/Qwen2.5-Coder-1.5B", + "Qwen/Qwen2.5-Coder-1.5B", + ), + "unsloth/Qwen2.5-Coder-3B-bnb-4bit" : ( + "unsloth/Qwen2.5-Coder-3B", + "Qwen/Qwen2.5-Coder-3B", + ), + "unsloth/Qwen2.5-Coder-7B-bnb-4bit" : ( + "unsloth/Qwen2.5-Coder-7B", + "Qwen/Qwen2.5-Coder-7B", + ), + "unsloth/Qwen2.5-Coder-14B-bnb-4bit" : ( + "unsloth/Qwen2.5-Coder-14B", + "Qwen/Qwen2.5-Coder-14B", + ), + "unsloth/Qwen2.5-Coder-32B-bnb-4bit" : ( + "unsloth/Qwen2.5-Coder-32B", + "Qwen/Qwen2.5-Coder-32B", + ), + "unsloth/Qwen2.5-Coder-0.5B-Instruct-bnb-4bit" : ( + "unsloth/Qwen2.5-Coder-Instruct-0.5B", + "Qwen/Qwen2.5-Coder-Instruct-0.5B", + ), + "unsloth/Qwen2.5-Coder-1.5B-Instruct-bnb-4bit" : ( + "unsloth/Qwen2.5-Coder-Instruct-1.5B", + "Qwen/Qwen2.5-Coder-Instruct-1.5B", + ), + "unsloth/Qwen2.5-Coder-3B-Instruct-bnb-4bit" : ( + "unsloth/Qwen2.5-Coder-3B-Instruct", + "Qwen/Qwen2.5-Coder-3B-Instruct", + ), + "unsloth/Qwen2.5-Coder-7B-Instruct-bnb-4bit" : ( + "unsloth/Qwen2.5-Coder-7B-Instruct", + "Qwen/Qwen2.5-Coder-7B-Instruct", + ), + "unsloth/Qwen2.5-Coder-14B-Instruct-bnb-4bit" : ( + "unsloth/Qwen2.5-Coder-14B-Instruct", + "Qwen/Qwen2.5-Coder-14B-Instruct", + ), + "unsloth/Qwen2.5-Coder-32B-Instruct-bnb-4bit" : ( + "unsloth/Qwen2.5-Coder-32B-Instruct", + "Qwen/Qwen2.5-Coder-32B-Instruct", + ), + "unsloth/Llama-3.2-1B-bnb-4bit" : ( + "unsloth/Llama-3.2-1B", + "meta-llama/Llama-3.2-1B", + ), + "unsloth/Llama-3.2-3B-bnb-4bit" : ( + "unsloth/Llama-3.2-3B", + "meta-llama/Llama-3.2-3B", + ), + "unsloth/Llama-3.2-1B-Instruct-bnb-4bit" : ( + 
"unsloth/Llama-3.2-1B-Instruct", + "meta-llama/Llama-3.2-1B-Instruct", + ), + "unsloth/Llama-3.2-3B-Instruct-bnb-4bit" : ( + "unsloth/Llama-3.2-3B-Instruct", + "meta-llama/Llama-3.2-3B-Instruct", + ), + "unsloth/Llama-3.1-Nemotron-70B-Instruct-bnb-4bit" : ( + "unsloth/Llama-3.1-Nemotron-70B-Instruct", + "nvidia/Llama-3.1-Nemotron-70B-Instruct-HF", + ), + } +else: + __INT_TO_FLOAT_MAPPER = \ { - "unsloth/mistral-7b-bnb-4bit" : ( - "unsloth/mistral-7b", - "mistralai/Mistral-7B-v0.1", - ), - "unsloth/llama-2-7b-bnb-4bit" : ( + "shashikanth-a/llama-2-7b-4bit" : ( "unsloth/llama-2-7b", "meta-llama/Llama-2-7b-hf", ), - "unsloth/llama-2-13b-bnb-4bit" : ( - "unsloth/llama-2-13b", - "meta-llama/Llama-2-13b-hf", - ), - "unsloth/codellama-34b-bnb-4bit" : ( - "codellama/CodeLlama-34b-hf", - ), - "unsloth/zephyr-sft-bnb-4bit" : ( - "unsloth/zephyr-sft", - "HuggingFaceH4/mistral-7b-sft-beta", - ), - "unsloth/tinyllama-bnb-4bit" : ( + "shashikanth-a/tinyllama-4bit" : ( "unsloth/tinyllama", "TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T", ), - "unsloth/tinyllama-chat-bnb-4bit" : ( + "shashikanth-a/tinyllama-chat-4bit" : ( "unsloth/tinyllama-chat", "TinyLlama/TinyLlama-1.1B-Chat-v1.0", ), - "unsloth/mistral-7b-instruct-v0.1-bnb-4bit" : ( - "unsloth/mistral-7b-instruct-v0.1", - "mistralai/Mistral-7B-Instruct-v0.1", - ), - "unsloth/mistral-7b-instruct-v0.2-bnb-4bit" : ( - "unsloth/mistral-7b-instruct-v0.2", - "mistralai/Mistral-7B-Instruct-v0.2", - ), - "unsloth/llama-2-7b-chat-bnb-4bit" : ( - "unsloth/llama-2-7b-chat", - "meta-llama/Llama-2-7b-chat-hf", - ), - "unsloth/llama-2-7b-chat-bnb-4bit" : ( + "shashikanth-a/llama-2-7b-chat-4bit" : ( "unsloth/llama-2-7b-chat", "meta-llama/Llama-2-7b-chat-hf", ), - "unsloth/codellama-7b-bnb-4bit" : ( + "shashikanth-a/codellama-7b-4bit" : ( "unsloth/codellama-7b", "codellama/CodeLlama-7b-hf", ), - "unsloth/codellama-13b-bnb-4bit" : ( - "codellama/CodeLlama-13b-hf", - ), - "unsloth/yi-6b-bnb-4bit" : ( + "shashikanth-a/yi-6b-4bit" : ( "unsloth/yi-6b", "01-ai/Yi-6B", ), - "unsloth/solar-10.7b-bnb-4bit" : ( + "shashikanth-a/solar-10.7b-4bit" : ( "upstage/SOLAR-10.7B-v1.0", ), - "unsloth/gemma-7b-bnb-4bit" : ( - "unsloth/gemma-7b", - "google/gemma-7b", - ), - "unsloth/gemma-2b-bnb-4bit" : ( - "unsloth/gemma-2b", - "google/gemma-2b", - ), - "unsloth/gemma-7b-it-bnb-4bit" : ( - "unsloth/gemma-7b-it", - "google/gemma-7b-it", - ), - "unsloth/gemma-2b-bnb-4bit" : ( - "unsloth/gemma-2b-it", - "google/gemma-2b-it", - ), - "unsloth/mistral-7b-v0.2-bnb-4bit" : ( - "unsloth/mistral-7b-v0.2", - "alpindale/Mistral-7B-v0.2-hf", - ), - "unsloth/gemma-1.1-2b-it-bnb-4bit" : ( - "unsloth/gemma-1.1-2b-it", - "google/gemma-1.1-2b-it", - ), - "unsloth/gemma-1.1-7b-it-bnb-4bit" : ( - "unsloth/gemma-1.1-7b-it", - "google/gemma-1.1-7b-it", - ), - "unsloth/Starling-LM-7B-beta-bnb-4bit" : ( - "unsloth/Starling-LM-7B-beta", - "Nexusflow/Starling-LM-7B-beta", - ), - "unsloth/Hermes-2-Pro-Mistral-7B-bnb-4bit" : ( - "unsloth/Hermes-2-Pro-Mistral-7B", - "NousResearch/Hermes-2-Pro-Mistral-7B", - ), - "unsloth/OpenHermes-2.5-Mistral-7B-bnb-4bit" : ( - "unsloth/OpenHermes-2.5-Mistral-7B", - "teknium/OpenHermes-2.5-Mistral-7B", - ), - "unsloth/codegemma-2b-bnb-4bit" : ( - "unsloth/codegemma-2b", - "google/codegemma-2b", - ), - "unsloth/codegemma-7b-bnb-4bit" : ( - "unsloth/codegemma-7b", - "google/codegemma-7b", - ), - "unsloth/codegemma-7b-it-bnb-4bit" : ( - "unsloth/codegemma-7b-it", - "google/codegemma-7b-it", - ), - "unsloth/llama-3-8b-bnb-4bit" : ( + "shashikanth-a/llama-3-8b-4bit" : ( 
"unsloth/llama-3-8b", "meta-llama/Meta-Llama-3-8B", ), - "unsloth/llama-3-8b-Instruct-bnb-4bit" : ( + "shashikanth-a/llama-3-8b-Instruct-4bit" : ( "unsloth/llama-3-8b-Instruct", "meta-llama/Meta-Llama-3-8B-Instruct", ), - "unsloth/llama-3-70b-bnb-4bit" : ( - "meta-llama/Meta-Llama-3-70B", - ), - "unsloth/llama-3-70b-Instruct-bnb-4bit" : ( - "meta-llama/Meta-Llama-3-70B-Instruct", - ), - "unsloth/Phi-3-mini-4k-instruct-bnb-4bit" : ( - "unsloth/Phi-3-mini-4k-instruct", - "microsoft/Phi-3-mini-4k-instruct", - ), - "unsloth/mistral-7b-v0.3-bnb-4bit" : ( - "unsloth/mistral-7b-v0.3", - "mistralai/Mistral-7B-v0.3", - ), - "unsloth/mistral-7b-instruct-v0.3-bnb-4bit" : ( - "unsloth/mistral-7b-instruct-v0.3", - "mistralai/Mistral-7B-Instruct-v0.3", - ), - "unsloth/Phi-3-medium-4k-instruct-bnb-4bit" : ( - "unsloth/Phi-3-medium-4k-instruct", - "microsoft/Phi-3-medium-4k-instruct", - ), - "unsloth/Qwen2-0.5B-bnb-4bit" : ( - "unsloth/Qwen2-0.5B", - "Qwen/Qwen2-0.5B", - ), - "unsloth/Qwen2-0.5B-Instruct-bnb-4bit" : ( - "unsloth/Qwen2-0.5B-Instruct", - "Qwen/Qwen2-0.5B-Instruct", - ), - "unsloth/Qwen2-1.5B-bnb-4bit" : ( - "unsloth/Qwen2-1.5B", - "Qwen/Qwen2-1.5B", - ), - "unsloth/Qwen2-1.5B-Instruct-bnb-4bit" : ( - "unsloth/Qwen2-1.5B-Instruct", - "Qwen/Qwen2-1.5B-Instruct", - ), - "unsloth/Qwen2-7B-bnb-4bit" : ( - "unsloth/Qwen2-7B", - "Qwen/Qwen2-7B", - ), - "unsloth/Qwen2-7B-Instruct-bnb-4bit" : ( - "unsloth/Qwen2-7B-Instruct", - "Qwen/Qwen2-7B-Instruct", - ), - "unsloth/Qwen2-70B-bnb-4bit" : ( - "Qwen/Qwen2-70B", - ), - "unsloth/Qwen2-70B-Instruct-bnb-4bit" : ( - "Qwen/Qwen2-70B-Instruct", - ), - "mistralai/Codestral-22B-v0.1" : ( - "mistral-community/Codestral-22B-v0.1", - ), - "unsloth/gemma-2-9b-bnb-4bit" : ( - "unsloth/gemma-2-9b", - "google/gemma-2-9b", - ), - "unsloth/gemma-2-27b-bnb-4bit" : ( - "unsloth/gemma-2-27b", - "google/gemma-2-27b", - ), - "unsloth/gemma-2-9b-it-bnb-4bit" : ( - "unsloth/gemma-2-9b-it", - "google/gemma-2-9b-it", - ), - "unsloth/gemma-2-27b-it-bnb-4bit" : ( - "unsloth/gemma-2-27b-it", - "google/gemma-2-27b-it", - ), - "unsloth/Phi-3-mini-4k-instruct-v0-bnb-4bit" : ( # Old Phi pre July - "unsloth/Phi-3-mini-4k-instruct-v0", - ), - "unsloth/Mistral-Nemo-Instruct-2407-bnb-4bit" : ( # New 12b Mistral models - "unsloth/Mistral-Nemo-Instruct-2407", - "mistralai/Mistral-Nemo-Instruct-2407", - ), - "unsloth/Mistral-Nemo-Base-2407-bnb-4bit" : ( # New 12b Mistral models - "unsloth/Mistral-Nemo-Base-2407", - "mistralai/Mistral-Nemo-Base-2407", - ), - "unsloth/Meta-Llama-3.1-8B-bnb-4bit" : ( + "shashikanth-a/Meta-Llama-3.1-8B-4bit" : ( "unsloth/Meta-Llama-3.1-8B", "meta-llama/Meta-Llama-3.1-8B", ), - "unsloth/Meta-Llama-3.1-8B-Instruct-bnb-4bit" : ( + "shashikanth-a/Meta-Llama-3.1-8B-Instruct-4bit" : ( "unsloth/Meta-Llama-3.1-8B-Instruct", "meta-llama/Meta-Llama-3.1-8B-Instruct", ), - "unsloth/Meta-Llama-3.1-70B-bnb-4bit" : ( - "unsloth/Meta-Llama-3.1-70B", - "meta-llama/Meta-Llama-3.1-70B", - ), - "unsloth/Meta-Llama-3.1-405B-bnb-4bit" : ( - "meta-llama/Meta-Llama-3.1-405B", - ), - "unsloth/Meta-Llama-3.1-405B-Instruct-bnb-4bit" : ( - "meta-llama/Meta-Llama-3.1-405B-Instruct", - ), - "unsloth/Meta-Llama-3.1-70B-Instruct-bnb-4bit" : ( - "unsloth/Meta-Llama-3.1-70B-Instruct", - "meta-llama/Meta-Llama-3.1-70B-Instruct", - ), - "unsloth/Mistral-Large-Instruct-2407-bnb-4bit" : ( - "mistralai/Mistral-Large-Instruct-2407", - ), - "unsloth/gemma-2-2b-bnb-4bit" : ( - "unsloth/gemma-2-2b", - "google/gemma-2-2b", - ), - "unsloth/gemma-2-2b-it-bnb-4bit" : ( - "unsloth/gemma-2-2b-it", - 
"google/gemma-2-2b-it", - ), - "unsloth/Phi-3.5-mini-instruct-bnb-4bit" : ( - "unsloth/Phi-3.5-mini-instruct", - "microsoft/Phi-3.5-mini-instruct", - ), - "unsloth/c4ai-command-r-08-2024-bnb-4bit" : ( - "CohereForAI/c4ai-command-r-08-2024", - ), - "unsloth/c4ai-command-r-plus-08-2024-bnb-4bit" : ( - "CohereForAI/c4ai-command-r-plus-08-2024", - ), - "unsloth/Llama-3.1-Storm-8B-bnb-4bit" : ( + "shashikanth-a/Llama-3.1-Storm-8B-4bit" : ( "unsloth/Llama-3.1-Storm-8B", "akjindal53244/Llama-3.1-Storm-8B", ), - "unsloth/Hermes-3-Llama-3.1-8B-bnb-4bit" : ( + "shashikanth-a/Hermes-3-Llama-3.1-8B-4bit" : ( "unsloth/Hermes-3-Llama-3.1-8B", "NousResearch/Hermes-3-Llama-3.1-8B", ), - "unsloth/Hermes-3-Llama-3.1-70B-bnb-4bit" : ( - "unsloth/Hermes-3-Llama-3.1-70B", - "NousResearch/Hermes-3-Llama-3.1-70B", - ), - "unsloth/Hermes-3-Llama-3.1-405B-bnb-4bit" : ( - "NousResearch/Hermes-3-Llama-3.1-405B", - ), - "unsloth/SmolLM-135M-bnb-4bit" : ( + "shashikanth-a/SmolLM-135M-4bit" : ( "unsloth/SmolLM-135M", "HuggingFaceTB/SmolLM-135M", ), - "unsloth/SmolLM-360M-bnb-4bit" : ( + "shashikanth-a/SmolLM-360M-4bit" : ( "unsloth/SmolLM-360M", "HuggingFaceTB/SmolLM-360M", ), - "unsloth/SmolLM-1.7B-bnb-4bit" : ( + "shashikanth-a/SmolLM-1.7B-4bit" : ( "unsloth/SmolLM-1.7B", "HuggingFaceTB/SmolLM-1.7B", ), - "unsloth/SmolLM-135M-Instruct-bnb-4bit" : ( + "shashikanth-a/SmolLM-135M-Instruct-4bit" : ( "unsloth/SmolLM-135M-Instruct", "HuggingFaceTB/SmolLM-135M-Instruct", ), - "unsloth/SmolLM-360M-Instruct-bnb-4bit" : ( + "shashikanth-a/SmolLM-360M-Instruct-4bit" : ( "unsloth/SmolLM-360M-Instruct", "HuggingFaceTB/SmolLM-360M-Instruct", ), - "unsloth/SmolLM-1.7B-Instruct-bnb-4bit" : ( + "shashikanth-a/SmolLM-1.7B-Instruct-4bit" : ( "unsloth/SmolLM-1.7B-Instruct", "HuggingFaceTB/SmolLM-1.7B-Instruct", ), - "unsloth/Mistral-Small-Instruct-2409-bnb-4bit" : ( - "unsloth/Mistral-Small-Instruct-2409", - "mistralai/Mistral-Small-Instruct-2409", - ), - "unsloth/Qwen2.5-0.5B-Instruct-bnb-4bit" : ( - "unsloth/Qwen2.5-0.5B-Instruct", - "Qwen/Qwen2.5-0.5B-Instruct", - ), - "unsloth/Qwen2.5-1.5B-Instruct-bnb-4bit" : ( - "unsloth/Qwen2.5-1.5B-Instruct", - "Qwen/Qwen2.5-1.5B-Instruct", - ), - "unsloth/Qwen2.5-3B-Instruct-bnb-4bit" : ( - "unsloth/Qwen2.5-3B-Instruct", - "Qwen/Qwen2.5-3B-Instruct", - ), - "unsloth/Qwen2.5-7B-Instruct-bnb-4bit" : ( - "unsloth/Qwen2.5-7B-Instruct", - "Qwen/Qwen2.5-7B-Instruct", - ), - "unsloth/Qwen2.5-14B-Instruct-bnb-4bit" : ( - "unsloth/Qwen2.5-14B-Instruct", - "Qwen/Qwen2.5-14B-Instruct", - ), - "unsloth/Qwen2.5-32B-Instruct-bnb-4bit" : ( - "unsloth/Qwen2.5-32B-Instruct", - "Qwen/Qwen2.5-32B-Instruct", - ), - "unsloth/Qwen2.5-72B-Instruct-bnb-4bit" : ( - "unsloth/Qwen2.5-72B-Instruct", - "Qwen/Qwen2.5-72B-Instruct", - ), - "unsloth/Qwen2.5-0.5B-bnb-4bit" : ( - "unsloth/Qwen2.5-0.5B", - "Qwen/Qwen2.5-0.5B", - ), - "unsloth/Qwen2.5-1.5B-bnb-4bit" : ( - "unsloth/Qwen2.5-1.5B", - "Qwen/Qwen2.5-1.5B", - ), - "unsloth/Qwen2.5-3B-bnb-4bit" : ( - "unsloth/Qwen2.5-3B", - "Qwen/Qwen2.5-3B", - ), - "unsloth/Qwen2.5-7B-bnb-4bit" : ( - "unsloth/Qwen2.5-7B", - "Qwen/Qwen2.5-7B", - ), - "unsloth/Qwen2.5-14B-bnb-4bit" : ( - "unsloth/Qwen2.5-14B", - "Qwen/Qwen2.5-14B", - ), - "unsloth/Qwen2.5-32B-bnb-4bit" : ( - "unsloth/Qwen2.5-32B", - "Qwen/Qwen2.5-32B", - ), - "unsloth/Qwen2.5-72B-bnb-4bit" : ( - "unsloth/Qwen2.5-72B", - "Qwen/Qwen2.5-72B", - ), - "unsloth/Qwen2.5-Math-1.5B-bnb-4bit" : ( - "unsloth/Qwen2.5-Math-1.5B", - "Qwen/Qwen2.5-Math-1.5B", - ), - "unsloth/Qwen2.5-Math-7B-bnb-4bit" : ( - "unsloth/Qwen2.5-Math-7B", - 
"Qwen/Qwen2.5-Math-7B", - ), - "unsloth/Qwen2.5-Math-72B-bnb-4bit" : ( - "unsloth/Qwen2.5-Math-72B", - "Qwen/Qwen2.5-Math-72B", - ), - "unsloth/Qwen2.5-Math-1.5B-Instruct-bnb-4bit" : ( - "unsloth/Qwen2.5-Math-1.5B-Instruct", - "Qwen/Qwen2.5-Math-1.5B-Instruct", - ), - "unsloth/Qwen2.5-Math-7B-Instruct-bnb-4bit" : ( - "unsloth/Qwen2.5-Math-7B-Instruct", - "Qwen/Qwen2.5-Math-7B-Instruct", - ), - "unsloth/Qwen2.5-Math-72B-Instruct-bnb-4bit" : ( - "unsloth/Qwen2.5-Math-72B-Instruct", - "Qwen/Qwen2.5-Math-72B-Instruct", - ), - "unsloth/Qwen2.5-Coder-0.5B-bnb-4bit" : ( - "unsloth/Qwen2.5-Coder-0.5B", - "Qwen/Qwen2.5-Coder-0.5B", - ), - "unsloth/Qwen2.5-Coder-1.5B-bnb-4bit" : ( - "unsloth/Qwen2.5-Coder-1.5B", - "Qwen/Qwen2.5-Coder-1.5B", - ), - "unsloth/Qwen2.5-Coder-3B-bnb-4bit" : ( - "unsloth/Qwen2.5-Coder-3B", - "Qwen/Qwen2.5-Coder-3B", - ), - "unsloth/Qwen2.5-Coder-7B-bnb-4bit" : ( - "unsloth/Qwen2.5-Coder-7B", - "Qwen/Qwen2.5-Coder-7B", - ), - "unsloth/Qwen2.5-Coder-14B-bnb-4bit" : ( - "unsloth/Qwen2.5-Coder-14B", - "Qwen/Qwen2.5-Coder-14B", - ), - "unsloth/Qwen2.5-Coder-32B-bnb-4bit" : ( - "unsloth/Qwen2.5-Coder-32B", - "Qwen/Qwen2.5-Coder-32B", - ), - "unsloth/Qwen2.5-Coder-0.5B-Instruct-bnb-4bit" : ( - "unsloth/Qwen2.5-Coder-Instruct-0.5B", - "Qwen/Qwen2.5-Coder-Instruct-0.5B", - ), - "unsloth/Qwen2.5-Coder-1.5B-Instruct-bnb-4bit" : ( - "unsloth/Qwen2.5-Coder-Instruct-1.5B", - "Qwen/Qwen2.5-Coder-Instruct-1.5B", - ), - "unsloth/Qwen2.5-Coder-3B-Instruct-bnb-4bit" : ( - "unsloth/Qwen2.5-Coder-3B-Instruct", - "Qwen/Qwen2.5-Coder-3B-Instruct", - ), - "unsloth/Qwen2.5-Coder-7B-Instruct-bnb-4bit" : ( - "unsloth/Qwen2.5-Coder-7B-Instruct", - "Qwen/Qwen2.5-Coder-7B-Instruct", - ), - "unsloth/Qwen2.5-Coder-14B-Instruct-bnb-4bit" : ( - "unsloth/Qwen2.5-Coder-14B-Instruct", - "Qwen/Qwen2.5-Coder-14B-Instruct", - ), - "unsloth/Qwen2.5-Coder-32B-Instruct-bnb-4bit" : ( - "unsloth/Qwen2.5-Coder-32B-Instruct", - "Qwen/Qwen2.5-Coder-32B-Instruct", - ), - "unsloth/Llama-3.2-1B-bnb-4bit" : ( + "shashikanth-a/Llama-3.2-1B-4bit" : ( "unsloth/Llama-3.2-1B", "meta-llama/Llama-3.2-1B", ), - "unsloth/Llama-3.2-3B-bnb-4bit" : ( + "shashikanth-a/Llama-3.2-3B-4bit" : ( "unsloth/Llama-3.2-3B", "meta-llama/Llama-3.2-3B", ), - "unsloth/Llama-3.2-1B-Instruct-bnb-4bit" : ( + "shashikanth-a/Llama-3.2-1B-Instruct-4bit" : ( "unsloth/Llama-3.2-1B-Instruct", "meta-llama/Llama-3.2-1B-Instruct", ), - "unsloth/Llama-3.2-3B-Instruct-bnb-4bit" : ( + "shashikanth-a/Llama-3.2-3B-Instruct-4bit" : ( "unsloth/Llama-3.2-3B-Instruct", "meta-llama/Llama-3.2-3B-Instruct", ), - "unsloth/Llama-3.1-Nemotron-70B-Instruct-bnb-4bit" : ( - "unsloth/Llama-3.1-Nemotron-70B-Instruct", - "nvidia/Llama-3.1-Nemotron-70B-Instruct-HF", - ), } INT_TO_FLOAT_MAPPER = {} From a246add92c76032d1b86cfa5ba8be20e193bc9f6 Mon Sep 17 00:00:00 2001 From: shashikanth Date: Fri, 29 Nov 2024 09:35:06 +0530 Subject: [PATCH 4/4] merge fixes --- unsloth-cli.py | 4 ++-- unsloth/__init__.py | 19 +++++++++-------- unsloth/mlx/mlx_utils.py | 4 ++-- unsloth/mlx/trainer/trainer.py | 8 ++++---- unsloth/models/__init__.py | 14 +++++++------ unsloth/models/_utils.py | 37 ++++++++++++++++++---------------- unsloth/save.py | 2 +- 7 files changed, 48 insertions(+), 40 deletions(-) diff --git a/unsloth-cli.py b/unsloth-cli.py index 6bd6a8fc..71b8507e 100644 --- a/unsloth-cli.py +++ b/unsloth-cli.py @@ -35,7 +35,6 @@ def run(args): import torch - from unsloth import FastLanguageModel from datasets import load_dataset from trl import SFTTrainer from transformers import 
TrainingArguments @@ -48,6 +47,7 @@ def run(args): import gc if not has_mps: + from unsloth import FastLanguageModel # Load model and tokenizer model, tokenizer = FastLanguageModel.from_pretrained( model_name=args.model_name, @@ -226,7 +226,7 @@ def formatting_prompts_func(examples): # Saving and pushing arguments save_group = parser.add_argument_group('💾 Save Model Options') - save_group.add_argument('--adapter_file', type=str, default="adapters.npz", help="Adapters file name") + save_group.add_argument('--adapter_file', type=str, default="adapters.safetensors", help="Adapters file name") save_group.add_argument('--output_dir', type=str, default="outputs", help="Output directory") save_group.add_argument('--save_model', action='store_true', help="Save the model after training") save_group.add_argument('--save_method', type=str, default="merged_16bit", choices=["merged_16bit", "merged_4bit", "lora"], help="Save method for the model, default is 'merged_16bit'") diff --git a/unsloth/__init__.py b/unsloth/__init__.py index 5cacb6c7..7900697a 100644 --- a/unsloth/__init__.py +++ b/unsloth/__init__.py @@ -180,11 +180,14 @@ def is_bf16_supported(): return SUPPORTS_BFLOAT16 raise ImportError("Unsloth: Please install unsloth_zoo via `pip install unsloth-zoo`") pass -from .models import * -from .save import * -from .chat_templates import * -from .tokenizer_utils import * -from .trainer import * - -# Patch TRL trainers for backwards compatibility -_patch_trl_trainer() +if not devices.has_mps: + from .models import * + from .save import * + from .chat_templates import * + from .tokenizer_utils import * + from .trainer import * + + # Patch TRL trainers for backwards compatibility + _patch_trl_trainer() +else: + from .models._utils import is_bfloat16_supported diff --git a/unsloth/mlx/mlx_utils.py b/unsloth/mlx/mlx_utils.py index 03da8246..22168da5 100644 --- a/unsloth/mlx/mlx_utils.py +++ b/unsloth/mlx/mlx_utils.py @@ -9,7 +9,7 @@ import mlx.core as mx import mlx.nn as nn -from unsloth.models.loader import get_model_name +from unsloth.models.loader_utils import get_model_name from .models import llama as models import transformers from huggingface_hub import snapshot_download,create_repo @@ -165,7 +165,7 @@ def load(model_path: str, tokenizer_config={}, weights = {} for wf in weight_files: - weights.update(mx.load(wf).items()) + weights.update(mx.load(wf)) model_class, model_args_class = get_model_classes(config=config) diff --git a/unsloth/mlx/trainer/trainer.py b/unsloth/mlx/trainer/trainer.py index ebbd9ce7..cd9c4bca 100644 --- a/unsloth/mlx/trainer/trainer.py +++ b/unsloth/mlx/trainer/trainer.py @@ -96,7 +96,7 @@ def iterate_batches(dataset, tokenizer, batch_size, max_seq_length, train=False) indices = np.random.permutation(len(batch_idx)) for i in indices: # Encode batch - batch = [tokenizer.encode(str(dataset[np.int32(indices[j]).item()])) for j in batch_idx[i]] + batch = [tokenizer.encode(str(dataset[j])) for j in batch_idx[i]] for b in batch: if b[-1] != tokenizer.eos_token_id: b.append(tokenizer.eos_token_id) @@ -312,11 +312,11 @@ def step(batch): # Save adapter weights if it % args.steps_per_save == 0: adapter_weights = dict(tree_flatten(model.trainable_parameters())) - mx.savez(str(args.adapter_file), **adapter_weights) + mx.save_safetensors(str(args.adapter_file), adapter_weights) checkpoint = ( Path(args.adapter_file).parent / f"{it:07d}_{Path(args.adapter_file).name}" ) - mx.savez(str(checkpoint), **adapter_weights) + mx.save_safetensors(str(checkpoint), adapter_weights) print( f"Iter 
{it}: Saved adapter weights to " f"{args.adapter_file} and {checkpoint}." @@ -324,5 +324,5 @@ def step(batch): # Save final weights adapter_weights = dict(tree_flatten(model.trainable_parameters())) - mx.savez(str(args.adapter_file), **adapter_weights) + mx.save_safetensors(str(args.adapter_file), adapter_weights) print(f"Saved final weights to {args.adapter_file}.") diff --git a/unsloth/models/__init__.py b/unsloth/models/__init__.py index 3230cdc2..d208fb34 100644 --- a/unsloth/models/__init__.py +++ b/unsloth/models/__init__.py @@ -12,9 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .loader import FastLanguageModel, FastVisionModel -from .llama import FastLlamaModel -from .mistral import FastMistralModel -from .qwen2 import FastQwen2Model -from .dpo import PatchDPOTrainer -from ._utils import is_bfloat16_supported +from unsloth import devices +if not devices.has_mps: + from .loader import FastLanguageModel, FastVisionModel + from .llama import FastLlamaModel + from .mistral import FastMistralModel + from .qwen2 import FastQwen2Model + from .dpo import PatchDPOTrainer + from ._utils import is_bfloat16_supported diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index ec0da431..8e0bf1cc 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -73,17 +73,19 @@ import numpy as np import warnings, subprocess, re, inspect, psutil, os, math from packaging.version import Version +from unsloth import devices from unsloth_zoo.tokenizer_utils import ( patch_tokenizer as _patch_tokenizer, ) -from unsloth_zoo.patching_utils import ( - patch_compiling_bitsandbytes, - patch_layernorm, - patch_torch_compile, - patch_model_and_tokenizer, - patch_compiled_autograd, -) +if not devices.has_mps: + from unsloth_zoo.patching_utils import ( + patch_compiling_bitsandbytes, + patch_layernorm, + patch_torch_compile, + patch_model_and_tokenizer, + patch_compiled_autograd, + ) from unsloth_zoo.gradient_checkpointing import ( Unsloth_Offloaded_Gradient_Checkpointer, unsloth_offloaded_gradient_checkpoint, @@ -106,10 +108,11 @@ from unsloth_zoo.vision_utils import ( process_vision_info, ) -from unsloth_zoo.compiler import ( - get_transformers_model_type, - unsloth_compile_transformers as _unsloth_compile_transformers, -) +if not devices.has_mps: + from unsloth_zoo.compiler import ( + get_transformers_model_type, + unsloth_compile_transformers as _unsloth_compile_transformers, + ) # ============================================= # Disable some warnings which can get annoying @@ -271,7 +274,6 @@ def _is_openai_available(): return False # ============================================= # Get Flash Attention v2 if Ampere (RTX 30xx, A100) -from unsloth import devices if not devices.has_mps: import bitsandbytes as bnb from transformers import AutoTokenizer @@ -452,11 +454,12 @@ def is_big_gpu(index): return True import torch._inductor.utils torch._inductor.utils.is_big_gpu = is_big_gpu -patch_torch_compile( - debug = UNSLOTH_COMPILE_DEBUG, - O3 = UNSLOTH_COMPILE_MAXIMUM, - ignore_errors = UNSLOTH_COMPILE_IGNORE_ERRORS, -) +if not devices.has_mps: + patch_torch_compile( + debug = UNSLOTH_COMPILE_DEBUG, + O3 = UNSLOTH_COMPILE_MAXIMUM, + ignore_errors = UNSLOTH_COMPILE_IGNORE_ERRORS, + ) torch_compile_options = { "epilogue_fusion" : True, diff --git a/unsloth/save.py b/unsloth/save.py index b15baf42..cb452218 100644 --- a/unsloth/save.py +++ b/unsloth/save.py @@ -18,6 +18,7 @@ from bitsandbytes.nn import Linear4bit as 
Bnb_Linear4bit from peft.tuners.lora import Linear4bit as Peft_Linear4bit from peft.tuners.lora import Linear as Peft_Linear + from .kernels import fast_dequantize, QUANT_STATE, get_lora_parameters_bias from typing import Optional, Callable, Union, List import torch import os @@ -25,7 +26,6 @@ import pickle import gc from transformers.models.llama.modeling_llama import logger -from .kernels import fast_dequantize, QUANT_STATE, get_lora_parameters_bias import subprocess import psutil import re
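
Note on the 4-bit model resolution added in PATCH 3/4: on Apple Silicon, mlx_utils.load_pretrained() and save_merged_model() now call get_model_name(model_name, load_in_4bit) (imported from unsloth/models/loader_utils.py after PATCH 4/4) so the MLX path picks up the pre-quantized repos listed in the has_mps branch of __INT_TO_FLOAT_MAPPER. The sketch below is illustrative only: the helper name resolve_model_name and the two-entry mapping are assumptions built from entries copied out of the diff, not the patch's actual implementation, and the real mapper covers many more models and may normalize names differently.

    # Illustrative sketch, not the patch's code: mirrors two entries of the
    # has_mps branch of __INT_TO_FLOAT_MAPPER (keys are pre-quantized 4-bit
    # repos, values are the float16 repos they stand in for).
    INT_TO_FLOAT = {
        "shashikanth-a/Llama-3.2-1B-4bit": ("unsloth/Llama-3.2-1B", "meta-llama/Llama-3.2-1B"),
        "shashikanth-a/llama-3-8b-4bit":   ("unsloth/llama-3-8b", "meta-llama/Meta-Llama-3-8B"),
    }
    # Reverse map: any float16 repo name -> its 4-bit counterpart.
    FLOAT_TO_INT = {flt: q4 for q4, floats in INT_TO_FLOAT.items() for flt in floats}

    def resolve_model_name(model_name: str, load_in_4bit: bool = True) -> str:
        """Return the pre-quantized 4-bit repo when load_in_4bit is set."""
        if load_in_4bit:
            return FLOAT_TO_INT.get(model_name, model_name)
        return INT_TO_FLOAT.get(model_name, (model_name,))[0]

    # Example: resolve_model_name("meta-llama/Llama-3.2-1B")
    # -> "shashikanth-a/Llama-3.2-1B-4bit"

With PATCH 4/4 the adapters produced by the MLX trainer are written with mx.save_safetensors() (default file name adapters.safetensors instead of adapters.npz), matching the mx.load() call used when weights are read back.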