diff --git a/unsloth-cli.py b/unsloth-cli.py index ddb0ac8b..71b8507e 100644 --- a/unsloth-cli.py +++ b/unsloth-cli.py @@ -31,38 +31,49 @@ import argparse +from unsloth.devices import has_mps + def run(args): import torch - from unsloth import FastLanguageModel from datasets import load_dataset from trl import SFTTrainer from transformers import TrainingArguments from unsloth import is_bfloat16_supported import logging logging.getLogger('hf-to-gguf').setLevel(logging.WARNING) + if has_mps: + from unsloth.mlx import mlx_utils + from unsloth.mlx import lora as mlx_lora + import gc + if not has_mps: + from unsloth import FastLanguageModel # Load model and tokenizer - model, tokenizer = FastLanguageModel.from_pretrained( - model_name=args.model_name, - max_seq_length=args.max_seq_length, - dtype=args.dtype, - load_in_4bit=args.load_in_4bit, - ) - + model, tokenizer = FastLanguageModel.from_pretrained( + model_name=args.model_name, + max_seq_length=args.max_seq_length, + dtype=args.dtype, + load_in_4bit=args.load_in_4bit, + ) + else: + print("Loading pretrained model") + model, tokenizer, config = mlx_utils.load_pretrained(args.model_name,dtype=args.dtype,load_in_4bit=args.load_in_4bit) + # Configure PEFT model - model = FastLanguageModel.get_peft_model( - model, - r=args.r, - target_modules=["q_proj", "k_proj", "v_proj", "o_proj", - "gate_proj", "up_proj", "down_proj"], - lora_alpha=args.lora_alpha, - lora_dropout=args.lora_dropout, - bias=args.bias, - use_gradient_checkpointing=args.use_gradient_checkpointing, - random_state=args.random_state, - use_rslora=args.use_rslora, - loftq_config=args.loftq_config, - ) + if not has_mps: + model = FastLanguageModel.get_peft_model( + model, + r=args.r, + target_modules=["q_proj", "k_proj", "v_proj", "o_proj", + "gate_proj", "up_proj", "down_proj"], + lora_alpha=args.lora_alpha, + lora_dropout=args.lora_dropout, + bias=args.bias, + use_gradient_checkpointing=args.use_gradient_checkpointing, + random_state=args.random_state, + use_rslora=args.use_rslora, + loftq_config=args.loftq_config, + ) alpaca_prompt = """Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request. 
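For readers tracing the new control flow, a condensed sketch of the dispatch the CLI now performs is below. The helper name `load_for_training` is illustrative only; the two loading calls mirror the ones in the hunk above, and `has_mps` is the boolean flag exposed by the new `unsloth/devices.py`.

import argparse
from unsloth.devices import has_mps

def load_for_training(args: argparse.Namespace):
    if has_mps:
        # Apple-silicon path: the MLX backend returns (model, tokenizer, config).
        from unsloth.mlx import mlx_utils
        return mlx_utils.load_pretrained(
            args.model_name, dtype=args.dtype, load_in_4bit=args.load_in_4bit
        )
    # CUDA path: unchanged Unsloth flow; LoRA modules are injected afterwards
    # via FastLanguageModel.get_peft_model exactly as before.
    from unsloth import FastLanguageModel
    model, tokenizer = FastLanguageModel.from_pretrained(
        model_name=args.model_name,
        max_seq_length=args.max_seq_length,
        dtype=args.dtype,
        load_in_4bit=args.load_in_4bit,
    )
    return model, tokenizer, None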
@@ -110,19 +121,24 @@ def formatting_prompts_func(examples): ) # Initialize trainer - trainer = SFTTrainer( - model=model, - tokenizer=tokenizer, - train_dataset=dataset, - dataset_text_field="text", - max_seq_length=args.max_seq_length, - dataset_num_proc=2, - packing=False, - args=training_args, - ) + if not has_mps: + trainer = SFTTrainer( + model=model, + tokenizer=tokenizer, + train_dataset=dataset, + dataset_text_field="text", + max_seq_length=args.max_seq_length, + dataset_num_proc=2, + packing=False, + args=training_args, + ) # Train model - trainer_stats = trainer.train() + trainer_stats = trainer.train() + else: + datasets = dataset.train_test_split(test_size=0.1) + mlx_lora.train_model(args,model,tokenizer, datasets["train"], datasets["test"]) + # Save model if args.save_model: @@ -152,9 +168,16 @@ def formatting_prompts_func(examples): quantization_method=quantization_method, ) else: - model.save_pretrained_merged(args.save_path, tokenizer, args.save_method) - if args.push_model: - model.push_to_hub_merged(args.save_path, tokenizer, args.hub_token) + if has_mps: + del model + gc.collect() + mlx_utils.save_merged_model(args) + if args.push_model: + mlx_utils.push_to_hub(args,config["_name_or_path"],config["model_type"]) + else: + model.save_pretrained_merged(args.save_path, tokenizer, args.save_method) + if args.push_model: + model.push_to_hub_merged(args.save_path, tokenizer, args.hub_token) else: print("Warning: The model is not saved!") @@ -203,6 +226,7 @@ def formatting_prompts_func(examples): # Saving and pushing arguments save_group = parser.add_argument_group('💾 Save Model Options') + save_group.add_argument('--adapter_file', type=str, default="adapters.safetensors", help="Adapters file name") save_group.add_argument('--output_dir', type=str, default="outputs", help="Output directory") save_group.add_argument('--save_model', action='store_true', help="Save the model after training") save_group.add_argument('--save_method', type=str, default="merged_16bit", choices=["merged_16bit", "merged_4bit", "lora"], help="Save method for the model, default is 'merged_16bit'") diff --git a/unsloth/__init__.py b/unsloth/__init__.py index 980425e1..7900697a 100644 --- a/unsloth/__init__.py +++ b/unsloth/__init__.py @@ -16,6 +16,7 @@ from packaging.version import Version import os, re, subprocess, inspect import numpy as np +from unsloth import devices # # Define a list of modules to check # MODULES_TO_CHECK = ["bitsandbytes"] @@ -90,7 +91,11 @@ pass # Torch 2.4 has including_emulation -major_version, minor_version = torch.cuda.get_device_capability() +devices.get_optimal_device() +if torch.cuda.is_available(): + major_version, minor_version = torch.cuda.get_device_capability() +else: + major_version,minor_version = 0,0 SUPPORTS_BFLOAT16 = (major_version >= 8) old_is_bf16_supported = torch.cuda.is_bf16_supported @@ -104,7 +109,6 @@ def is_bf16_supported(): return SUPPORTS_BFLOAT16 pass # Try loading bitsandbytes and triton -import bitsandbytes as bnb if "SPACE_AUTHOR_NAME" not in os.environ and "SPACE_REPO_NAME" not in os.environ: @@ -116,7 +120,9 @@ def is_bf16_supported(): return SUPPORTS_BFLOAT16 else: from triton.common.build import libcuda_dirs try: - cdequantize_blockwise_fp32 = bnb.functional.lib.cdequantize_blockwise_fp32 + if not devices.has_mps: + import bitsandbytes as bnb + cdequantize_blockwise_fp32 = bnb.functional.lib.cdequantize_blockwise_fp32 libcuda_dirs() except: warnings.warn( @@ -141,8 +147,9 @@ def is_bf16_supported(): return SUPPORTS_BFLOAT16 latest_cuda = 
possible_cudas[latest_cuda]
             os.system(f"ldconfig /usr/local/{latest_cuda}")
         pass
-
-        importlib.reload(bnb)
+        if not devices.has_mps:
+            import bitsandbytes as bnb
+            importlib.reload(bnb)
         importlib.reload(triton)
         try:
             libcuda_dirs = lambda: None
@@ -150,7 +157,9 @@ def is_bf16_supported(): return SUPPORTS_BFLOAT16
             try:    from triton.backends.nvidia.driver import libcuda_dirs
             except: pass
         else: from triton.common.build import libcuda_dirs
-        cdequantize_blockwise_fp32 = bnb.functional.lib.cdequantize_blockwise_fp32
+        if not devices.has_mps:
+            import bitsandbytes as bnb
+            cdequantize_blockwise_fp32 = bnb.functional.lib.cdequantize_blockwise_fp32
         libcuda_dirs()
     except:
         warnings.warn(
@@ -171,11 +180,14 @@ def is_bf16_supported(): return SUPPORTS_BFLOAT16
     raise ImportError("Unsloth: Please install unsloth_zoo via `pip install unsloth-zoo`")
 pass
 
-from .models import *
-from .save import *
-from .chat_templates import *
-from .tokenizer_utils import *
-from .trainer import *
+if not devices.has_mps:
+    from .models import *
+    from .save import *
+    from .chat_templates import *
+    from .tokenizer_utils import *
+    from .trainer import *
 
-# Patch TRL trainers for backwards compatibility
-_patch_trl_trainer()
+    # Patch TRL trainers for backwards compatibility
+    _patch_trl_trainer()
+else:
+    from .models._utils import is_bfloat16_supported
diff --git a/unsloth/devices.py b/unsloth/devices.py
new file mode 100644
index 00000000..9190c5a0
--- /dev/null
+++ b/unsloth/devices.py
@@ -0,0 +1,47 @@
+import sys
+
+import torch
+
+if sys.platform == "darwin":
+    from unsloth import mac_specific
+
+
+def check_mps() -> bool:
+    if sys.platform != "darwin":
+        return False
+    else:
+        return mac_specific.has_mps
+
+
+# Evaluated once at import time so callers can test `devices.has_mps` (or an
+# imported `has_mps`) directly as a boolean, which is how the rest of this
+# patch uses it.
+has_mps = check_mps()
+
+
+def get_cuda_device_string():
+    return "cuda"
+
+
+def get_optimal_device_name():
+    if torch.cuda.is_available():
+        return get_cuda_device_string()
+
+    if has_mps:
+        return "mps"
+
+    return "cpu"
+
+
+def get_optimal_device():
+    return torch.device(get_optimal_device_name())
+
+
+def torch_gc():
+    if torch.cuda.is_available():
+        with torch.cuda.device(get_cuda_device_string()):
+            torch.cuda.empty_cache()
+            torch.cuda.ipc_collect()
+
+    if has_mps:
+        mac_specific.torch_mps_gc()
diff --git a/unsloth/kernels/utils.py b/unsloth/kernels/utils.py
index de543962..130fc90c 100644
--- a/unsloth/kernels/utils.py
+++ b/unsloth/kernels/utils.py
@@ -59,19 +59,23 @@ def calculate_settings(n : int) -> (int, int,):
 pass
 
-import bitsandbytes as bnb
+from unsloth import devices
+
 # https://github.com/bitsandbytes-foundation/bitsandbytes/pull/1330/files
-HAS_CUDA_STREAM = Version(bnb.__version__) > Version("0.43.3")
+HAS_CUDA_STREAM = False
 global CUDA_STREAM
 CUDA_STREAM = None
-get_ptr = bnb.functional.get_ptr
-import ctypes
-cdequantize_blockwise_fp32 = bnb.functional.lib.cdequantize_blockwise_fp32
-cdequantize_blockwise_fp16_nf4 = bnb.functional.lib.cdequantize_blockwise_fp16_nf4
-cdequantize_blockwise_bf16_nf4 = bnb.functional.lib.cdequantize_blockwise_bf16_nf4
-cgemm_4bit_inference_naive_fp16 = bnb.functional.lib.cgemm_4bit_inference_naive_fp16
-cgemm_4bit_inference_naive_bf16 = bnb.functional.lib.cgemm_4bit_inference_naive_bf16
+if not devices.has_mps:
+    import bitsandbytes as bnb
+    HAS_CUDA_STREAM = Version(bnb.__version__) > Version("0.43.3")
+    get_ptr = bnb.functional.get_ptr
+    cdequantize_blockwise_fp32 = bnb.functional.lib.cdequantize_blockwise_fp32
+    cdequantize_blockwise_fp16_nf4 = bnb.functional.lib.cdequantize_blockwise_fp16_nf4
+    cdequantize_blockwise_bf16_nf4 =
bnb.functional.lib.cdequantize_blockwise_bf16_nf4 + cgemm_4bit_inference_naive_fp16 = bnb.functional.lib.cgemm_4bit_inference_naive_fp16 + cgemm_4bit_inference_naive_bf16 = bnb.functional.lib.cgemm_4bit_inference_naive_bf16 +import ctypes def QUANT_STATE(W): return getattr(W, "quant_state", None) diff --git a/unsloth/mac_specific.py b/unsloth/mac_specific.py new file mode 100644 index 00000000..6ef92416 --- /dev/null +++ b/unsloth/mac_specific.py @@ -0,0 +1,71 @@ +import logging + +import torch +import platform +from unsloth.sd_hijack_utils import CondFunc +from packaging import version + +log = logging.getLogger(__name__) + + +# before torch version 1.13, has_mps is only available in nightly pytorch and macOS 12.3+, +# use check `getattr` and try it for compatibility. +# in torch version 1.13, backends.mps.is_available() and backends.mps.is_built() are introduced in to check mps availabilty, +# since torch 2.0.1+ nightly build, getattr(torch, 'has_mps', False) was deprecated, see https://github.com/pytorch/pytorch/pull/103279 +def check_for_mps() -> bool: + if version.parse(torch.__version__) <= version.parse("2.0.1"): + if not getattr(torch, 'has_mps', False): + return False + try: + torch.zeros(1).to(torch.device("mps")) + return True + except Exception: + return False + else: + return torch.backends.mps.is_available() and torch.backends.mps.is_built() + + +has_mps = check_for_mps() + + +# MPS workaround for https://github.com/pytorch/pytorch/issues/89784 +def cumsum_fix(input, cumsum_func, *args, **kwargs): + if input.device.type == 'mps': + output_dtype = kwargs.get('dtype', input.dtype) + if output_dtype == torch.int64: + return cumsum_func(input.cpu(), *args, **kwargs).to(input.device) + elif output_dtype == torch.bool or cumsum_needs_int_fix and (output_dtype == torch.int8 or output_dtype == torch.int16): + return cumsum_func(input.to(torch.int32), *args, **kwargs).to(torch.int64) + return cumsum_func(input, *args, **kwargs) + + +if has_mps: + if platform.mac_ver()[0].startswith("13.2."): + # MPS workaround for https://github.com/pytorch/pytorch/issues/95188, thanks to danieldk (https://github.com/explosion/curated-transformers/pull/124) + CondFunc('torch.nn.functional.linear', lambda _, input, weight, bias: (torch.matmul(input, weight.t()) + bias) if bias is not None else torch.matmul(input, weight.t()), lambda _, input, weight, bias: input.numel() > 10485760) + + if version.parse(torch.__version__) < version.parse("1.13"): + # PyTorch 1.13 doesn't need these fixes but unfortunately is slower and has regressions that prevent training from working + + # MPS workaround for https://github.com/pytorch/pytorch/issues/79383 + CondFunc('torch.Tensor.to', lambda orig_func, self, *args, **kwargs: orig_func(self.contiguous(), *args, **kwargs), + lambda _, self, *args, **kwargs: self.device.type != 'mps' and (args and isinstance(args[0], torch.device) and args[0].type == 'mps' or isinstance(kwargs.get('device'), torch.device) and kwargs['device'].type == 'mps')) + # MPS workaround for https://github.com/pytorch/pytorch/issues/80800 + CondFunc('torch.nn.functional.layer_norm', lambda orig_func, *args, **kwargs: orig_func(*([args[0].contiguous()] + list(args[1:])), **kwargs), + lambda _, *args, **kwargs: args and isinstance(args[0], torch.Tensor) and args[0].device.type == 'mps') + # MPS workaround for https://github.com/pytorch/pytorch/issues/90532 + CondFunc('torch.Tensor.numpy', lambda orig_func, self, *args, **kwargs: orig_func(self.detach(), *args, **kwargs), lambda _, self, *args, 
**kwargs: self.requires_grad) + elif version.parse(torch.__version__) > version.parse("1.13.1"): + cumsum_needs_int_fix = not torch.Tensor([1,2]).to(torch.device("mps")).equal(torch.ShortTensor([1,1]).to(torch.device("mps")).cumsum(0)) + cumsum_fix_func = lambda orig_func, input, *args, **kwargs: cumsum_fix(input, orig_func, *args, **kwargs) + CondFunc('torch.cumsum', cumsum_fix_func, None) + CondFunc('torch.Tensor.cumsum', cumsum_fix_func, None) + CondFunc('torch.narrow', lambda orig_func, *args, **kwargs: orig_func(*args, **kwargs).clone(), None) + + # MPS workaround for https://github.com/pytorch/pytorch/issues/96113 + CondFunc('torch.nn.functional.layer_norm', lambda orig_func, x, normalized_shape, weight, bias, eps, **kwargs: orig_func(x.float(), normalized_shape, weight.float() if weight is not None else None, bias.float() if bias is not None else bias, eps).to(x.dtype), lambda _, input, *args, **kwargs: len(args) == 4 and input.device.type == 'mps') + + # MPS workaround for https://github.com/pytorch/pytorch/issues/92311 + if platform.processor() == 'i386': + for funcName in ['torch.argmax', 'torch.Tensor.argmax']: + CondFunc(funcName, lambda _, input, *args, **kwargs: torch.max(input.float() if input.dtype == torch.int64 else input, *args, **kwargs)[1], lambda _, input, *args, **kwargs: input.device.type == 'mps') diff --git a/unsloth/mlx/lora.py b/unsloth/mlx/lora.py new file mode 100644 index 00000000..3e423052 --- /dev/null +++ b/unsloth/mlx/lora.py @@ -0,0 +1,105 @@ +from pathlib import Path +import math +import mlx.nn as nn +import mlx.optimizers as optim +import mlx.core as mx +from .trainer.trainer import TrainingArgs, TrainingCallback, train +from .trainer.utils import ( + build_schedule, + linear_to_lora_layers, + print_trainable_parameters, +) +from .mlx_utils import save_config + +def train_model( + args, + model: nn.Module, + tokenizer, + train_set, + valid_set, + training_callback: TrainingCallback = None, +): + model.freeze() + linear_to_lora_layers( + model, + min(args.r,len(model.layers)/2), + {"rank": args.r, "alpha": args.lora_alpha, "dropout": args.lora_dropout, "scale": float(args.lora_alpha)/math.sqrt(float(args.r)) if args.use_rslora else float(args.lora_alpha)/float(args.r)}, + ) + print_trainable_parameters(model) + + adapter_path = Path(args.save_path) + adapter_path.mkdir(parents=True, exist_ok=True) + + adapter_file = adapter_path / args.adapter_file + config = { + "num_layers" : min(args.r,len(model.layers)/2), + "lora_parameters" : {"rank": args.r, "alpha": args.lora_alpha, "dropout": args.lora_dropout, "scale": float(args.lora_alpha)/math.sqrt(float(args.r)) if args.use_rslora else float(args.lora_alpha)/float(args.r)} + } + save_config(config, adapter_path / "adapter_config.json") + + # init training args + training_args = TrainingArgs( + batch_size=args.per_device_train_batch_size, + iters=args.max_steps, + val_batches=25, + steps_per_report=10, + steps_per_eval=200, + steps_per_save=100, + adapter_file=adapter_file, + max_seq_length=args.max_seq_length, + grad_checkpoint=args.use_gradient_checkpointing, + ) + + mx.random.seed(args.seed) + model.train() + if args.lr_scheduler_type == "linear": + arguments = [0.0,args.learning_rate,args.warmup_steps] + elif args.lr_scheduler_type == "exponential_decay": + arguments = [args.learning_rate,args.weight_decay] + elif args.lr_scheduler_type == "step_decay": + arguments = [args.learning_rate,args.weight_decay,args.warmup_steps] + elif args.lr_scheduler_type == "cosine_decay": + arguments = 
[args.learning_rate,args.max_steps] + else: + arguments = [args.learning_rate] + + schedule_config = { + "name": "linear_schedule" if args.lr_scheduler_type == "linear" else args.lr_scheduler_type, + "warmup": args.warmup_steps, + "arguments": arguments, + } + + lr = build_schedule(schedule_config) if args.lr_scheduler_type else args.learning_rate + + if args.optim.lower().startswith("sgd"): + opt = optim.SGD(learning_rate=(lr), weight_decay=args.weight_decay) + elif args.optim.lower().startswith("rmsprop"): + opt = optim.RMSprop(learning_rate=(lr)) + elif args.optim.lower().startswith("adagrad"): + opt = optim.Adagrad(learning_rate=(lr)) + elif args.optim.lower().startswith("adaDelta"): + opt = optim.AdaDelta(learning_rate=(lr)) + elif args.optim.lower().startswith("adamw"): + opt = optim.AdamW(learning_rate=(lr),weight_decay=args.weight_decay) + elif args.optim.lower().startswith("adam"): + opt = optim.Adam(learning_rate=(lr)) + elif args.optim.lower().startswith("adamax"): + opt = optim.Adamax(learning_rate=(lr)) + elif args.optim.lower().startswith("lion"): + opt = optim.Lion(learning_rate=(lr), weight_decay=args.weight_decay) + elif args.optim.lower().startswith("adafactor"): + opt = optim.Adafactor(learning_rate=(lr), weight_decay= args.weight_decay) + else: + raise ValueError("The Optimizer type provided is not supported") + + # Train model + train( + model=model, + tokenizer=tokenizer, + args=training_args, + optimizer=opt, + train_dataset=train_set, + val_dataset=valid_set, + training_callback=training_callback, + ) + diff --git a/unsloth/mlx/mlx_utils.py b/unsloth/mlx/mlx_utils.py new file mode 100644 index 00000000..22168da5 --- /dev/null +++ b/unsloth/mlx/mlx_utils.py @@ -0,0 +1,350 @@ +import gc +import glob +import shutil +import json +import logging +from pathlib import Path +from typing import Generator, Optional,Type, Callable, Tuple, Union + +import mlx.core as mx +import mlx.nn as nn + +from unsloth.models.loader_utils import get_model_name +from .models import llama as models +import transformers +from huggingface_hub import snapshot_download,create_repo +from unsloth.save import MODEL_CARD +from mlx.utils import tree_flatten, tree_unflatten +from .trainer.utils import load_adapters + +MODEL_REMAPPING = { + "mistral": "llama", # mistral is compatible with llama +} + + +def fetch_from_hub(hf_path: str): + model_path = snapshot_download( + repo_id=hf_path, + allow_patterns=["*.json", "*.safetensors", "tokenizer.model"], + ) + weight_files = glob.glob(f"{model_path}/*.safetensors") + if len(weight_files) == 0: + raise FileNotFoundError("No safetensors found in {}".format(model_path)) + + weights = {} + for wf in weight_files: + weights.update(mx.load(wf).items()) + + config = transformers.AutoConfig.from_pretrained(hf_path) + tokenizer = transformers.AutoTokenizer.from_pretrained( + hf_path, + ) + return weights, config.to_dict(), tokenizer + + +def upload_to_hub(name_or_path,model_type,username,path: str, name: str, token : str): + import os + + from huggingface_hub import HfApi, ModelCard, logging + + repo_id = f"{name}" + + try: + create_repo( + repo_id = repo_id, + token = token, + repo_type = "model", + exist_ok = False, + private = None, + ) + except: + pass + + try: + content = MODEL_CARD.format( + username = username, + base_model = name_or_path, + model_type = model_type, + method = "", + extra = "unsloth", + ) + card = ModelCard(content) + card.push_to_hub(repo_id, token = token) + except: + pass + logging.set_verbosity_info() + + api = HfApi() + 
api.create_repo(repo_id=repo_id, exist_ok=True,token = token) + api.upload_folder( + folder_path=path, + path_in_repo = ".", + token=token, + repo_id=repo_id, + commit_message = "(Trained with Unsloth)", + repo_type="model" + ) + + + +def make_shards(weights: dict, max_file_size_gibibyte: int = 15): + max_file_size_bytes = max_file_size_gibibyte << 30 + shards = [] + shard, shard_size = {}, 0 + for k, v in weights.items(): + if shard_size + v.nbytes > max_file_size_bytes: + shards.append(shard) + shard, shard_size = {}, 0 + shard[k] = v + shard_size += v.nbytes + shards.append(shard) + return shards + + +def save_model(save_dir: str, weights, tokenizer, config): + save_dir = Path(save_dir) + save_dir.mkdir(parents=True, exist_ok=True) + + shards = make_shards(weights, max_file_size_gibibyte=1) + shards_count = len(shards) + shard_file_format = ( + "model-{:05d}-of-{:05d}.safetensors" + if shards_count > 1 + else "model.safetensors" + ) + + total_size = sum(v.nbytes for v in weights.values()) + index_data = {"metadata": {"total_size": total_size}, "weight_map": {}} + + for i, shard in enumerate(shards): + shard_name = shard_file_format.format(i + 1, shards_count) + mx.save_safetensors( + str(save_dir / shard_name), shard, metadata={"format": "mlx"} + ) + for weight_name in shard.keys(): + index_data["weight_map"][weight_name] = shard_name + del shard + + tokenizer.save_pretrained(save_dir) + with open(save_dir / "config.json", "w") as fid: + json.dump(config, fid, indent=4) + + index_data["weight_map"] = { + k: index_data["weight_map"][k] for k in sorted(index_data["weight_map"]) + } + with open(save_dir / "model.safetensors.index.json", "w") as f: + json.dump( + index_data, + f, + indent=4, + ) + +def _get_classes(config: dict): + model_type = config["model_type"] + if model_type != "llama" and MODEL_REMAPPING.get(model_type,model_type) != "llama": + msg = f"Model type {model_type} not supported." + logging.error(msg) + raise ValueError(msg) + + return models.Model, models.ModelArgs + +def load(model_path: str, tokenizer_config={}, + get_model_classes: Callable[[dict], Tuple[Type[nn.Module], Type]] = _get_classes,): + + with open(model_path / "config.json", "r") as f: + config = json.loads(f.read()) + quantization = config.get("quantization", None) + + weight_files = glob.glob(str(model_path / "*.safetensors")) + if not weight_files: + # Try weight for back-compat + weight_files = glob.glob(str(model_path / "weight*.safetensors")) + + if len(weight_files) == 0: + raise FileNotFoundError("No safetensors found in {}".format(model_path)) + + weights = {} + for wf in weight_files: + weights.update(mx.load(wf)) + + model_class, model_args_class = get_model_classes(config=config) + + model_args = model_args_class.from_dict(config) + model = model_class(model_args) + + if hasattr(model, "sanitize"): + weights = model.sanitize(weights) + + if quantization is not None: + class_predicate = ( + lambda p, m: isinstance(m, (nn.Linear, nn.Embedding)) + and f"{p}.scales" in weights + ) + nn.quantize( + model, + **quantization, + class_predicate=class_predicate, + ) + + model.load_weights(list(weights.items())) + + # mx.eval(model.parameters()) + model.eval() + + tokenizer = transformers.AutoTokenizer.from_pretrained( + model_path, **tokenizer_config + ) + return model, tokenizer, config + + +def save_config( + config: dict, + config_path: Union[str, Path], +) -> None: + """Save the model configuration to the ``config_path``. + + The final configuration will be sorted before saving for better readability. 
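A quick usage sketch of the sharding helpers above, assuming the new module is importable as `unsloth.mlx.mlx_utils`: `make_shards` packs tensors greedily into shards no larger than the given GiB budget, and `save_model` then writes each shard plus a `model.safetensors.index.json` weight map.

import mlx.core as mx
from unsloth.mlx import mlx_utils

# Eight 2 MiB fp16 tensors fit comfortably into a single shard under the
# 1 GiB budget that save_model() passes to make_shards().
weights = {f"layers.{i}.weight": mx.zeros((1024, 1024), dtype=mx.float16) for i in range(8)}
shards = mlx_utils.make_shards(weights, max_file_size_gibibyte=1)
print(len(shards), sum(len(s) for s in shards))   # 1 shard containing all 8 tensors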
+ + Args: + config (dict): The model configuration. + config_path (Union[str, Path]): Model configuration file path. + """ + # Clean unused keys + config.pop("_name_or_path", None) + + # sort the config for better readability + config = dict(sorted(config.items())) + + # write the updated config to the config_path (if provided) + with open(config_path, "w") as fid: + json.dump(config, fid, indent=4) + + + +def generate( + prompt: mx.array, model: nn.Module, temp: float = 0.0 +) -> Generator[mx.array, None, None]: + """ + Generate text based on the given prompt and model. + + Args: + prompt (mx.array): The input prompt. + model (nn.Module): The model to use for generation. + temp (float): The temperature for sampling. If temp is 0, use max sampling. + + Yields: + mx.array: The generated text. + """ + + def sample(logits: mx.array) -> mx.array: + return ( + mx.argmax(logits, axis=-1) + if temp == 0 + else mx.random.categorical(logits * (1 / temp)) + ) + + y = prompt + cache = None + while True: + logits, cache = model(y[None], cache=cache) + logits = logits[:, -1, :] + y = sample(logits) + yield y + +def save_merged_model(args): + model_name = get_model_name(args.model_name,args.load_in_4bit) + model_path = get_model_path(model_name) + model, tokenizer, config = load(model_path) + model.freeze() + + # Load the LoRA adapter weights which we assume should exist by this point + if not Path(args.save_path,args.adapter_file).is_file(): + raise ValueError( + f"Adapter file {args.adapter_file} missing. ") + + model = load_adapters(model, args.save_path,args.adapter_file) + + fused_linears = [ + (n, m.fuse()) for n, m in model.named_modules() if hasattr(m, "fuse") + ] + + if fused_linears: + model.update_modules(tree_unflatten(fused_linears)) + + weights = dict(tree_flatten(model.parameters())) + + save_model(args.save_path, weights, tokenizer, config) + + mx.metal.clear_cache() + del model + gc.collect() + + +def push_to_hub(args,name, model_type): + if args.push_model: + from huggingface_hub import whoami + try: + username = whoami(token = args.hub_token)["name"] + except: + raise RuntimeError( + "Unsloth: Please supply a token!\n"\ + "Go to https://huggingface.co/settings/tokens" + ) + pass + pass + + if args.push_model and args.hub_path is not None: + hf_path = args.hub_path + if not Path(args.model_name).exists(): + # If the model path doesn't exist, assume it's an HF repo + hf_path = args.model_name + elif hf_path is None: + raise ValueError( + "Must provide original Hugging Face repo to upload local model." 
+ ) + upload_to_hub(name,model_type,username,args.save_path, args.hub_path,args.hub_token) + + +def get_model_path(path_or_hf_repo: str, revision: Optional[str] = None) -> Path: + model_path = Path(path_or_hf_repo) + if not model_path.exists(): + try: + model_path = Path( + snapshot_download( + repo_id=path_or_hf_repo, + revision=revision, + allow_patterns=[ + "*.json", + "*.safetensors", + "*.py", + "tokenizer.model", + "*.tiktoken", + "*.txt", + ], + ) + ) + except: + raise FileNotFoundError( + f"Model not found for path or HF repo: {path_or_hf_repo}.\n" + "Please make sure you specified the local path or Hugging Face" + " repo id correctly.\nIf you are trying to access a private or" + " gated Hugging Face repo, make sure you are authenticated:\n" + "https://huggingface.co/docs/huggingface_hub/en/guides/cli#huggingface-cli-login" + ) from None + return model_path + + + +def load_pretrained( + model_name: str, + tokenizer_config={}, + model_config={}, + dtype= None, + load_in_4bit=True +): + model_name = get_model_name(model_name,load_in_4bit) + model_path = get_model_path(model_name) + + model,tokenizer, config = load(model_path, tokenizer_config) + + return model, tokenizer, config \ No newline at end of file diff --git a/unsloth/mlx/models/base.py b/unsloth/mlx/models/base.py new file mode 100644 index 00000000..72b7e23f --- /dev/null +++ b/unsloth/mlx/models/base.py @@ -0,0 +1,111 @@ +import inspect +from dataclasses import dataclass +from typing import Any, Optional + +import mlx.core as mx +from mlx.utils import tree_map + +from .cache import QuantizedKVCache + + +@dataclass +class BaseModelArgs: + @classmethod + def from_dict(cls, params): + return cls( + **{ + k: v + for k, v in params.items() + if k in inspect.signature(cls).parameters + } + ) + + +def create_causal_mask(N: int, offset: int = 0, window_size: Optional[int] = None): + rinds = mx.arange(offset + N) + linds = mx.arange(offset, offset + N) if offset else rinds + linds = linds[:, None] + rinds = rinds[None] + mask = linds < rinds + if window_size is not None: + mask = mask | (linds > rinds + window_size) + return mask * -1e9 + + +def create_attention_mask(h: mx.array, cache: Optional[Any] = None): + T = h.shape[1] + if T > 1: + window_size = None + offset = 0 + if cache is not None and cache[0] is not None: + c = cache[0] + if hasattr(c, "max_size"): + offset = min(c.max_size, c.offset) + window_size = c.max_size + else: + offset = c.offset + mask = create_causal_mask(T, offset, window_size=window_size) + mask = mask.astype(h.dtype) + else: + mask = None + return mask + + +def quantized_scaled_dot_product_attention( + queries: mx.array, + q_keys: tuple[mx.array, mx.array, mx.array], + q_values: tuple[mx.array, mx.array, mx.array], + scale: float, + mask: Optional[mx.array], + group_size: int = 64, + bits: int = 8, +) -> mx.array: + B, n_q_heads, L, D = queries.shape + n_kv_heads = q_keys[0].shape[-3] + n_repeats = n_q_heads // n_kv_heads + + queries *= scale + + if n_repeats > 1: + queries = mx.reshape(queries, (B, n_kv_heads, n_repeats, L, D)) + q_keys = tree_map(lambda x: mx.expand_dims(x, axis=-3), q_keys) + q_values = tree_map(lambda x: mx.expand_dims(x, axis=-3), q_values) + + scores = mx.quantized_matmul( + queries, *q_keys, transpose=True, group_size=group_size, bits=bits + ) + if mask is not None: + scores += mask + scores = mx.softmax(scores, axis=-1, precise=True) + out = mx.quantized_matmul( + scores, *q_values, transpose=False, group_size=group_size, bits=bits + ) + + if n_repeats > 1: + out = 
mx.reshape(out, (B, n_q_heads, L, D)) + + return out + + +def scaled_dot_product_attention( + queries, + keys, + values, + cache, + scale: float, + mask: Optional[mx.array], +) -> mx.array: + if isinstance(cache, QuantizedKVCache): + return quantized_scaled_dot_product_attention( + queries, + keys, + values, + scale=scale, + mask=mask, + group_size=cache.group_size, + bits=cache.bits, + ) + else: + return mx.fast.scaled_dot_product_attention( + queries, keys, values, scale=scale, mask=mask + ) diff --git a/unsloth/mlx/models/cache.py b/unsloth/mlx/models/cache.py new file mode 100644 index 00000000..e87b113e --- /dev/null +++ b/unsloth/mlx/models/cache.py @@ -0,0 +1,116 @@ + +from typing import Any, Dict, List, Optional + +import mlx.core as mx +import mlx.nn as nn +from mlx.utils import tree_map + + + +class _BaseCache: + @property + def state(self): + return [] + + @state.setter + def state(self, v): + if v is not None and v: + raise ValueError("This cache has no state but a state was set.") + + @property + def meta_state(self): + return "" + + @meta_state.setter + def meta_state(self, v): + if v is not None and v: + raise ValueError("This cache has no meta_state but a meta_state was set.") + + def is_trimmable(self): + return False + + +class QuantizedKVCache(_BaseCache): + def __init__(self, group_size: int = 64, bits: int = 8): + self.keys = None + self.values = None + self.offset = 0 + self.step = 256 + self.group_size = group_size + self.bits = bits + + def update_and_fetch(self, keys, values): + B, n_kv_heads, num_steps, k_head_dim = keys.shape + v_head_dim = values.shape[-1] + prev = self.offset + + if self.keys is None or (prev + num_steps) > self.keys[0].shape[-2]: + el_per_int = 8 * mx.uint32.size // self.bits + new_steps = (self.step + num_steps - 1) // self.step * self.step + shape = (B, n_kv_heads, new_steps) + + def init_quant(dim): + return ( + mx.zeros((*shape, dim // el_per_int), dtype=mx.uint32), + mx.zeros((*shape, dim // self.group_size), dtype=keys.dtype), + mx.zeros((*shape, dim // self.group_size), dtype=keys.dtype), + ) + + def expand_quant(x): + new_x = mx.zeros((*shape, x.shape[-1]), dtype=x.dtype) + return mx.concatenate([x, new_x], axis=-2) + + if self.keys is not None: + if prev % self.step != 0: + self.keys, self.values = tree_map( + lambda x: x[..., :prev, :], (self.keys, self.values) + ) + + self.keys, self.values = tree_map( + expand_quant, (self.keys, self.values) + ) + else: + self.keys, self.values = init_quant(k_head_dim), init_quant(v_head_dim) + + self.offset += num_steps + + keys = mx.quantize(keys, group_size=self.group_size, bits=self.bits) + values = mx.quantize(values, group_size=self.group_size, bits=self.bits) + for i in range(len(self.keys)): + self.keys[i][..., prev : self.offset, :] = keys[i] + self.values[i][..., prev : self.offset, :] = values[i] + + return tree_map(lambda x: x[..., : self.offset, :], (self.keys, self.values)) + + @property + def state(self): + if self.offset == self.keys[0].shape[2]: + return self.keys, self.values + else: + return tree_map( + lambda x: x[..., : self.offset, :], (self.keys, self.values) + ) + + @state.setter + def state(self, v): + self.keys, self.values = v + + @property + def meta_state(self): + return tuple(map(str, (self.step, self.offset, self.group_size, self.bits))) + + @meta_state.setter + def meta_state(self, v): + self.step, self.offset, self.group_size, self.bits = map(int, v) + + def is_trimmable(self): + return True + + def trim(self, n): + n = min(self.offset, n) + self.offset -= n + 
return n + + + + diff --git a/unsloth/mlx/models/llama.py b/unsloth/mlx/models/llama.py new file mode 100644 index 00000000..7b85b686 --- /dev/null +++ b/unsloth/mlx/models/llama.py @@ -0,0 +1,304 @@ +from dataclasses import dataclass +from typing import Any, Dict, Optional, Union + +import mlx.core as mx +import mlx.nn as nn + +from .base import BaseModelArgs, create_attention_mask, scaled_dot_product_attention + + +@dataclass +class ModelArgs(BaseModelArgs): + model_type: str + hidden_size: int + num_hidden_layers: int + intermediate_size: int + num_attention_heads: int + rms_norm_eps: float + vocab_size: int + head_dim: Optional[int] = None + max_position_embeddings: Optional[int] = None + num_key_value_heads: Optional[int] = None + attention_bias: bool = False + mlp_bias: bool = False + rope_theta: float = 10000 + rope_traditional: bool = False + rope_scaling: Optional[Dict[str, Union[float, str]]] = None + tie_word_embeddings: bool = True + + def __post_init__(self): + if self.num_key_value_heads is None: + self.num_key_value_heads = self.num_attention_heads + + if self.rope_scaling: + if not "factor" in self.rope_scaling: + raise ValueError(f"rope_scaling must contain 'factor'") + rope_type = self.rope_scaling.get("type") or self.rope_scaling.get( + "rope_type" + ) + if rope_type is None: + raise ValueError( + f"rope_scaling must contain either 'type' or 'rope_type'" + ) + if rope_type not in ["linear", "dynamic", "llama3"]: + raise ValueError( + "rope_scaling 'type' currently only supports 'linear', 'dynamic' or 'llama3'" + ) + + +class DynamicNTKScalingRoPE(nn.Module): + """Implements the rotary positional encoding with Dynamic NTK scaling and Llama 3 RoPE.""" + + def __init__( + self, + dims: int, + max_position_embeddings: int = 2048, + traditional: bool = False, + base: float = 10000, + scale: float = 1.0, + rope_type: str = "default", + rope_scaling: dict = None, + ): + super().__init__() + self.dims = dims + self.max_position_embeddings = max_position_embeddings + self.traditional = traditional + self.scale = scale + self.rope_type = rope_type + self.rope_scaling = rope_scaling + self.base = base + self.compute_freqs() + + def compute_freqs(self): + if self.rope_type != "llama3": + self._freqs = None + return + factor = self.rope_scaling["factor"] + low_freq_factor = self.rope_scaling.get("low_freq_factor", 1.0) + high_freq_factor = self.rope_scaling.get("high_freq_factor", 4.0) + old_context_len = self.rope_scaling.get( + "original_max_position_embeddings", + 8192, + ) + + low_freq_wavelen = old_context_len / low_freq_factor + high_freq_wavelen = old_context_len / high_freq_factor + + freqs = self.base ** (mx.arange(0, self.dims, 2) / self.dims) + wavelens = 2 * mx.pi * freqs + + freqs = mx.where(wavelens > low_freq_wavelen, freqs * factor, freqs) + is_medium_freq = (wavelens > high_freq_wavelen) & (wavelens < low_freq_wavelen) + smooth_factors = (old_context_len / wavelens - low_freq_factor) / ( + high_freq_factor - low_freq_factor + ) + smooth_freqs = freqs / ((1 - smooth_factors) / factor + smooth_factors) + self._freqs = mx.where(is_medium_freq, smooth_freqs, freqs) + self.base = None + + def extra_repr(self): + return ( + f"{self.dims}, traditional={self.traditional}, " + f"max_position_embeddings={self.max_position_embeddings}, " + f"scaling_factor={self.scale}, rope_type={self.rope_type}" + ) + + def __call__(self, x, offset: int = 0): + return mx.fast.rope( + x, + self.dims, + traditional=self.traditional, + base=self.base, + scale=self.scale, + offset=offset, + 
freqs=self._freqs, + ) + + +def initialize_rope(args: ModelArgs): + head_dim = args.head_dim or args.hidden_size // args.num_attention_heads + + rope_scaling = args.rope_scaling + rope_type = "default" + rope_scale = 1.0 + + if rope_scaling is not None: + rope_type = ( + rope_scaling.get("type") or rope_scaling.get("rope_type") or "default" + ) + if rope_type == "linear": + rope_scale = 1 / rope_scaling["factor"] + elif rope_type == "llama3": + rope_scale = 1.0 # The scaling is handled internally for llama3 + + return DynamicNTKScalingRoPE( + dims=head_dim, + max_position_embeddings=args.max_position_embeddings, + traditional=args.rope_traditional, + base=args.rope_theta, + scale=rope_scale, + rope_type=rope_type, + rope_scaling=rope_scaling, + ) + + +class Attention(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + + dim = args.hidden_size + self.n_heads = n_heads = args.num_attention_heads + self.n_kv_heads = n_kv_heads = args.num_key_value_heads + + self.head_dim = head_dim = args.head_dim or args.hidden_size // n_heads + + self.scale = head_dim**-0.5 + if hasattr(args, "attention_bias"): + attention_bias = args.attention_bias + else: + attention_bias = False + + self.q_proj = nn.Linear(dim, n_heads * head_dim, bias=attention_bias) + self.k_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=attention_bias) + self.v_proj = nn.Linear(dim, n_kv_heads * head_dim, bias=attention_bias) + self.o_proj = nn.Linear(n_heads * head_dim, dim, bias=attention_bias) + + self.rope = initialize_rope(args) + + def __call__( + self, + x: mx.array, + mask: Optional[mx.array] = None, + cache: Optional[Any] = None, + ) -> mx.array: + B, L, D = x.shape + + queries, keys, values = self.q_proj(x), self.k_proj(x), self.v_proj(x) + + # Prepare the queries, keys and values for the attention computation + queries = queries.reshape(B, L, self.n_heads, -1).transpose(0, 2, 1, 3) + keys = keys.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) + values = values.reshape(B, L, self.n_kv_heads, -1).transpose(0, 2, 1, 3) + + if cache is not None: + queries = self.rope(queries, offset=cache.offset) + keys = self.rope(keys, offset=cache.offset) + keys, values = cache.update_and_fetch(keys, values) + else: + queries = self.rope(queries) + keys = self.rope(keys) + + output = scaled_dot_product_attention( + queries, keys, values, cache=cache, scale=self.scale, mask=mask + ) + + output = output.transpose(0, 2, 1, 3).reshape(B, L, -1) + return self.o_proj(output) + + +class MLP(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + + dim = args.hidden_size + hidden_dim = args.intermediate_size + if hasattr(args, "mlp_bias"): + mlp_bias = args.mlp_bias + else: + mlp_bias = False + + self.gate_proj = nn.Linear(dim, hidden_dim, bias=mlp_bias) + self.down_proj = nn.Linear(hidden_dim, dim, bias=mlp_bias) + self.up_proj = nn.Linear(dim, hidden_dim, bias=mlp_bias) + + def __call__(self, x) -> mx.array: + return self.down_proj(nn.silu(self.gate_proj(x)) * self.up_proj(x)) + + +class TransformerBlock(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.num_attention_heads = args.num_attention_heads + self.hidden_size = args.hidden_size + self.self_attn = Attention(args) + self.mlp = MLP(args) + self.input_layernorm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps) + self.post_attention_layernorm = nn.RMSNorm( + args.hidden_size, eps=args.rms_norm_eps + ) + self.args = args + + def __call__( + self, + x: mx.array, + mask: Optional[mx.array] = None, + cache: 
Optional[Any] = None, + ) -> mx.array: + r = self.self_attn(self.input_layernorm(x), mask, cache) + h = x + r + r = self.mlp(self.post_attention_layernorm(h)) + out = h + r + return out + + +class LlamaModel(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.args = args + self.vocab_size = args.vocab_size + self.num_hidden_layers = args.num_hidden_layers + assert self.vocab_size > 0 + self.embed_tokens = nn.Embedding(args.vocab_size, args.hidden_size) + self.layers = [ + TransformerBlock(args=args) for _ in range(args.num_hidden_layers) + ] + self.norm = nn.RMSNorm(args.hidden_size, eps=args.rms_norm_eps) + + def __call__( + self, + inputs: mx.array, + cache=None, + ): + h = self.embed_tokens(inputs) + + mask = create_attention_mask(h, cache) + + if cache is None: + cache = [None] * len(self.layers) + + for layer, c in zip(self.layers, cache): + h = layer(h, mask, cache=c) + + return self.norm(h) + + +class Model(nn.Module): + def __init__(self, args: ModelArgs): + super().__init__() + self.args = args + self.model_type = args.model_type + self.model = LlamaModel(args) + if not args.tie_word_embeddings: + self.lm_head = nn.Linear(args.hidden_size, args.vocab_size, bias=False) + + def __call__( + self, + inputs: mx.array, + cache=None, + ): + out = self.model(inputs, cache) + if self.args.tie_word_embeddings: + out = self.model.embed_tokens.as_linear(out) + else: + out = self.lm_head(out) + return out + + def sanitize(self, weights): + # Remove unused precomputed rotary freqs + return { + k: v for k, v in weights.items() if "self_attn.rotary_emb.inv_freq" not in k + } + + @property + def layers(self): + return self.model.layers diff --git a/unsloth/mlx/trainer/lora.py b/unsloth/mlx/trainer/lora.py new file mode 100644 index 00000000..3072ccfd --- /dev/null +++ b/unsloth/mlx/trainer/lora.py @@ -0,0 +1,283 @@ +import math + +import mlx.core as mx +import mlx.nn as nn + +from .switch_layers import QuantizedSwitchLinear, SwitchLinear + + +class LoRALinear(nn.Module): + @staticmethod + def from_base( + linear: nn.Linear, + r: int = 8, + dropout: float = 0.0, + scale: float = 20.0, + ): + # TODO remove when input_dims and output_dims are attributes + # on linear and quantized linear + output_dims, input_dims = linear.weight.shape + if isinstance(linear, nn.QuantizedLinear): + input_dims *= 32 // linear.bits + lora_lin = LoRALinear( + input_dims=input_dims, + output_dims=output_dims, + r=r, + dropout=dropout, + scale=scale, + ) + lora_lin.linear = linear + return lora_lin + + def fuse(self, de_quantize: bool = False): + linear = self.linear + bias = "bias" in linear + weight = linear.weight + is_quantized = isinstance(linear, nn.QuantizedLinear) + + # Use the same type as the linear weight if not quantized + dtype = weight.dtype + + if is_quantized: + dtype = linear.scales.dtype + weight = mx.dequantize( + weight, + linear.scales, + linear.biases, + linear.group_size, + linear.bits, + ) + output_dims, input_dims = weight.shape + fused_linear = nn.Linear(input_dims, output_dims, bias=bias) + + lora_b = (self.scale * self.lora_b.T).astype(dtype) + lora_a = self.lora_a.T.astype(dtype) + fused_linear.weight = weight + lora_b @ lora_a + if bias: + fused_linear.bias = linear.bias + + if is_quantized and not de_quantize: + fused_linear = nn.QuantizedLinear.from_linear( + fused_linear, + linear.group_size, + linear.bits, + ) + + return fused_linear + + def __init__( + self, + input_dims: int, + output_dims: int, + r: int = 8, + dropout: float = 0.0, + scale: float = 20.0, + 
bias: bool = False, + ): + super().__init__() + + # Regular linear layer weights + self.linear = nn.Linear(input_dims, output_dims, bias=bias) + + self.dropout = nn.Dropout(p=dropout) + + # Scale for low-rank update + self.scale = scale + + # Low rank lora weights + scale = 1 / math.sqrt(input_dims) + self.lora_a = mx.random.uniform( + low=-scale, + high=scale, + shape=(input_dims, r), + ) + self.lora_b = mx.zeros(shape=(r, output_dims)) + + def __call__(self, x): + y = self.linear(x) + z = (self.dropout(x) @ self.lora_a) @ self.lora_b + return y + (self.scale * z).astype(x.dtype) + + +class LoRASwitchLinear(nn.Module): + @staticmethod + def from_base( + linear: nn.Module, + r: int = 8, + dropout: float = 0.0, + scale: float = 20.0, + ): + lora_lin = LoRASwitchLinear( + input_dims=linear.input_dims, + output_dims=linear.output_dims, + num_experts=linear.num_experts, + r=r, + dropout=dropout, + scale=scale, + ) + lora_lin.linear = linear + return lora_lin + + def fuse(self, de_quantize: bool = False): + linear = self.linear + bias = "bias" in linear + weight = linear.weight + is_quantized = isinstance(linear, QuantizedSwitchLinear) + + # Use the same type as the linear weight if not quantized + dtype = weight.dtype + + if is_quantized: + dtype = mx.float16 + weight = mx.dequantize( + weight, + linear.scales, + linear.biases, + linear.group_size, + linear.bits, + ) + num_experts, output_dims, input_dims = weight.shape + fused_linear = SwitchLinear(input_dims, output_dims, num_experts, bias=bias) + + lora_b = (self.scale * self.lora_b).astype(dtype) + lora_a = self.lora_a.reshape(num_experts, -1, input_dims).astype(dtype) + fused_linear.weight = weight + lora_b @ lora_a + if bias: + fused_linear.bias = linear.bias + + if is_quantized and not de_quantize: + fused_linear = fused_linear.to_quantized(linear.group_size, linear.bits) + + return fused_linear + + def __init__( + self, + input_dims: int, + output_dims: int, + num_experts: int, + r: int = 8, + dropout: float = 0.0, + scale: float = 20.0, + bias: bool = False, + ): + super().__init__() + + # Regular linear layer weights + self.linear = SwitchLinear(input_dims, output_dims, num_experts, bias=bias) + + self.dropout = nn.Dropout(p=dropout) + + # Scale for low-rank update + self.scale = scale + + # Low rank lora weights + scale = 1 / math.sqrt(input_dims) + self.lora_a = mx.random.uniform( + low=-scale, + high=scale, + shape=(r * num_experts, input_dims), + ) + self.lora_b = mx.zeros(shape=(num_experts, output_dims, r)) + self.num_experts = num_experts + + def __call__(self, x, indices): + shape = x.shape[:-3] + (self.num_experts, -1) + + y = self.linear(x, indices) + z = (self.dropout(x) @ self.lora_a.T).reshape(shape) + z = mx.take_along_axis(z, indices[..., None], axis=-2) + z = z[..., None, :] @ self.lora_b[indices].swapaxes(-2, -1) + + return y + (self.scale * z).astype(x.dtype) + + +class LoRAEmbedding(nn.Module): + @staticmethod + def from_base( + embedding: nn.Embedding, + r: int = 8, + dropout: float = 0.0, + scale: float = 20.0, + ): + num_embeddings, dims = embedding.weight.shape + if isinstance(embedding, nn.QuantizedEmbedding): + dims *= 32 // embedding.bits + lora_embedding = LoRAEmbedding( + num_embeddings=num_embeddings, + dims=dims, + r=r, + dropout=dropout, + scale=scale, + ) + lora_embedding.embedding = embedding + return lora_embedding + + def fuse(self, de_quantize: bool = False): + embedding = self.embedding + weight = embedding.weight + is_quantized = isinstance(embedding, nn.QuantizedEmbedding) + + # Use the same 
type as the linear weight if not quantized + dtype = weight.dtype + + if is_quantized: + dtype = embedding.scales.dtype + weight = mx.dequantize( + weight, + embedding.scales, + embedding.biases, + embedding.group_size, + embedding.bits, + ) + num_embeddings, dims = weight.shape + fused_embedding = nn.Embedding(num_embeddings, dims) + + lora_a = (self.scale * self.lora_a).astype(dtype) + lora_b = self.lora_b.astype(dtype) + fused_embedding.weight = weight + lora_a @ lora_b + + if is_quantized and not de_quantize: + fused_embedding = nn.QuantizedEmbedding.from_embedding( + fused_embedding, + embedding.group_size, + embedding.bits, + ) + + return fused_embedding + + def __init__( + self, + num_embeddings: int, + dims: int, + r: int = 8, + dropout: float = 0.0, + scale: float = 20.0, + ): + super().__init__() + + # Regular embedding layer + self.embedding = nn.Embedding(num_embeddings, dims) + self.dropout = nn.Dropout(p=dropout) + + # Scale for low-rank update + self.scale = scale + + # Low rank lora weights + scale = 1 / math.sqrt(num_embeddings) + self.lora_a = mx.random.uniform( + low=-scale, + high=scale, + shape=(num_embeddings, r), + ) + self.lora_b = mx.zeros(shape=(r, dims)) + + def __call__(self, x): + y = self.embedding(x) + z = self.dropout(self.lora_a[x] @ self.lora_b) + out = y + (self.scale * z).astype(y.dtype) + return out + + def as_linear(self, x): + y = self.embedding.as_linear(x) + z = (self.dropout(x) @ self.lora_b.T) @ self.lora_a.T + return y + (self.scale * z).astype(x.dtype) diff --git a/unsloth/mlx/trainer/switch_layers.py b/unsloth/mlx/trainer/switch_layers.py new file mode 100644 index 00000000..00aa65d8 --- /dev/null +++ b/unsloth/mlx/trainer/switch_layers.py @@ -0,0 +1,165 @@ +import math + +import mlx.core as mx +import mlx.nn as nn + + +class QuantizedSwitchLinear(nn.Module): + def __init__( + self, + input_dims: int, + output_dims: int, + num_experts: int, + bias: bool = True, + group_size: int = 64, + bits: int = 4, + ): + super().__init__() + + scale = math.sqrt(1 / input_dims) + self.weight, self.scales, self.biases = mx.quantize( + mx.random.uniform( + low=-scale, + high=scale, + shape=(num_experts, output_dims, input_dims), + ), + group_size=group_size, + bits=bits, + ) + + if bias: + self.bias = mx.zeros((num_experts, output_dims)) + + self.group_size = group_size + self.bits = bits + + # Freeze this model's parameters + self.freeze() + + def unfreeze(self, *args, **kwargs): + """Wrap unfreeze so that we unfreeze any layers we might contain but + our parameters will remain frozen.""" + super().unfreeze(*args, **kwargs) + self.freeze(recurse=False) + + @property + def input_dims(self): + return self.scales.shape[2] * self.group_size + + @property + def output_dims(self): + return self.weight.shape[1] + + @property + def num_experts(self): + return self.weight.shape[0] + + def __call__(self, x, indices): + x = mx.gather_qmm( + x, + self["weight"], + self["scales"], + self["biases"], + rhs_indices=indices, + transpose=True, + group_size=self.group_size, + bits=self.bits, + ) + if "bias" in self: + x = x + mx.expand_dims(self["bias"][indices], -2) + return x + + +class SwitchLinear(nn.Module): + def __init__( + self, input_dims: int, output_dims: int, num_experts: int, bias: bool = True + ): + super().__init__() + scale = math.sqrt(1 / input_dims) + self.weight = mx.random.uniform( + low=-scale, + high=scale, + shape=(num_experts, output_dims, input_dims), + ) + + if bias: + self.bias = mx.zeros((num_experts, output_dims)) + + @property + def 
input_dims(self): + return self.weight.shape[2] + + @property + def output_dims(self): + return self.weight.shape[1] + + @property + def num_experts(self): + return self.weight.shape[0] + + def __call__(self, x, indices): + x = mx.gather_mm(x, self["weight"].swapaxes(-1, -2), rhs_indices=indices) + if "bias" in self: + x = x + mx.expand_dims(self["bias"][indices], -2) + return x + + def to_quantized(self, group_size: int = 64, bits: int = 4): + num_experts, output_dims, input_dims = self.weight.shape + ql = QuantizedSwitchLinear( + input_dims, output_dims, num_experts, False, group_size, bits + ) + ql.weight, ql.scales, ql.biases = mx.quantize(self.weight, group_size, bits) + if "bias" in self: + ql.bias = self.bias + return ql + + +class SwitchGLU(nn.Module): + def __init__( + self, + input_dims: int, + hidden_dims: int, + num_experts: int, + activation=nn.silu, + bias: bool = False, + ): + super().__init__() + + self.gate_proj = SwitchLinear(input_dims, hidden_dims, num_experts, bias=bias) + self.up_proj = SwitchLinear(input_dims, hidden_dims, num_experts, bias=bias) + self.down_proj = SwitchLinear(hidden_dims, input_dims, num_experts, bias=bias) + self.activation = activation + + def __call__(self, x, indices) -> mx.array: + x = mx.expand_dims(x, (-2, -3)) + + x_up = self.up_proj(x, indices) + x_gate = self.gate_proj(x, indices) + x = self.down_proj(self.activation(x_gate) * x_up, indices) + + return x.squeeze(-2) + + +class SwitchMLP(nn.Module): + def __init__( + self, + input_dims: int, + hidden_dims: int, + num_experts: int, + activation=nn.gelu_approx, + bias: bool = False, + ): + super().__init__() + + self.fc1 = SwitchLinear(input_dims, hidden_dims, num_experts, bias=bias) + self.fc2 = SwitchLinear(hidden_dims, input_dims, num_experts, bias=bias) + self.activation = activation + + def __call__(self, x, indices) -> mx.array: + x = mx.expand_dims(x, (-2, -3)) + + x = self.fc1(x, indices) + x = self.activation(x) + x = self.fc2(x, indices) + + return x.squeeze(-2) diff --git a/unsloth/mlx/trainer/trainer.py b/unsloth/mlx/trainer/trainer.py new file mode 100644 index 00000000..cd9c4bca --- /dev/null +++ b/unsloth/mlx/trainer/trainer.py @@ -0,0 +1,328 @@ +import time +from dataclasses import dataclass, field +from pathlib import Path + +import mlx.core as mx +import mlx.nn as nn +import numpy as np +from mlx.nn.utils import average_gradients +from mlx.utils import tree_flatten + + +def grad_checkpoint(layer): + """ + Update all instances of type(layer) to use gradient checkpointing. + """ + fn = type(layer).__call__ + + def checkpointed_fn(model, *args, **kwargs): + def inner_fn(params, *args, **kwargs): + model.update(params) + return fn(model, *args, **kwargs) + + return mx.checkpoint(inner_fn)(model.trainable_parameters(), *args, **kwargs) + + type(layer).__call__ = checkpointed_fn + + +@dataclass +class TrainingArgs: + batch_size: int = field(default=4, metadata={"help": "Minibatch size."}) + iters: int = field(default=100, metadata={"help": "Iterations to train for."}) + val_batches: int = field( + default=25, + metadata={ + "help": "Number of validation batches, -1 uses the entire validation set." 
+ }, + ) + steps_per_report: int = field( + default=10, + metadata={"help": "Number of training steps between loss reporting."}, + ) + steps_per_eval: int = field( + default=200, metadata={"help": "Number of training steps between validations."} + ) + steps_per_save: int = field( + default=100, metadata={"help": "Save the model every number steps"} + ) + max_seq_length: int = field( + default=2048, metadata={"help": "Maximum sequence length."} + ) + adapter_file: str = field( + default="adapters.safetensors", + metadata={"help": "Save/load path for the trained adapter weights."}, + ) + grad_checkpoint: bool = field( + default=False, + metadata={"help": "Use gradient checkpointing to reduce memory use."}, + ) + + +def default_loss(model, inputs, targets, lengths): + logits = model(inputs) + logits = logits.astype(mx.float32) + + length_mask = mx.arange(inputs.shape[1])[None, :] < lengths[:, None] + + ce = nn.losses.cross_entropy(logits, targets) * length_mask + ntoks = length_mask.sum() + ce = ce.sum() / ntoks + + return ce, ntoks + + +def iterate_batches(dataset, tokenizer, batch_size, max_seq_length, train=False): + # Sort by length: + idx = sorted(range(len(dataset)), key=lambda idx: len(dataset[idx])) + if len(dataset) < batch_size: + raise ValueError( + f"Dataset must have at least batch_size={batch_size}" + f" examples but only has {len(dataset)}." + ) + + # If running in distributed mode (N machines) then each one should skip N-1 + # samples + step = mx.distributed.init().size() + if batch_size % step != 0: + raise ValueError("The batch size must be divisible by the number of workers") + + # Make the batches: + batch_idx = [ + idx[i : i + batch_size : step] + for i in range(0, len(idx) - batch_size + 1, batch_size) + ] + + while True: + indices = np.random.permutation(len(batch_idx)) + for i in indices: + # Encode batch + batch = [tokenizer.encode(str(dataset[j])) for j in batch_idx[i]] + for b in batch: + if b[-1] != tokenizer.eos_token_id: + b.append(tokenizer.eos_token_id) + + lengths = [len(x) for x in batch] + + if max(lengths) > max_seq_length: + print( + f"[WARNING] Some sequences are longer than {max_seq_length} tokens. " + f"The longest sentence {max(lengths)} will be truncated to {max_seq_length}. " + "Consider pre-splitting your data to save memory." 
+ ) + + # Pad to the nearest multiple of 8 or the maximum length + pad_to = 8 + max_length_in_batch = pad_to * ((max(lengths) + pad_to - 1) // pad_to) + max_length_in_batch = min(max_length_in_batch, max_seq_length) + + batch_arr = np.zeros((batch_size // step, max_length_in_batch), np.int32) + + for j in range(batch_size // step): + truncated_length = min(lengths[j], max_seq_length) + batch_arr[j, :truncated_length] = batch[j][:truncated_length] + lengths[j] = ( + truncated_length # Update lengths to match truncated lengths + ) + batch = mx.array(batch_arr) + + yield batch[:, :-1], batch[:, 1:], mx.array(lengths) + + if not train: + break + + +def evaluate( + model, + dataset, + tokenizer, + batch_size, + num_batches, + max_seq_length=2048, + loss: callable = default_loss, + iterate_batches: callable = iterate_batches, +): + all_losses = 0 + ntokens = 0 + + index_iterator = iter(range(num_batches)) if num_batches != -1 else iter(int, 1) + + for _, batch in zip( + index_iterator, + iterate_batches( + dataset=dataset, + tokenizer=tokenizer, + batch_size=batch_size, + max_seq_length=max_seq_length, + ), + ): + losses, toks = loss(model, *batch) + all_losses += losses * toks + ntokens += toks + mx.eval(all_losses, ntokens) + + all_losses = mx.distributed.all_sum(all_losses) + ntokens = mx.distributed.all_sum(ntokens) + + return (all_losses / ntokens).item() + + +class TrainingCallback: + + def on_train_loss_report(self, train_info: dict): + """Called to report training loss at specified intervals.""" + pass + + def on_val_loss_report(self, val_info: dict): + """Called to report validation loss at specified intervals or the beginning.""" + pass + + +def train( + model, + tokenizer, + optimizer, + train_dataset, + val_dataset, + args: TrainingArgs = TrainingArgs(), + loss: callable = default_loss, + iterate_batches: callable = iterate_batches, + training_callback: TrainingCallback = None, +): + print(f"Starting training..., iters: {args.iters}") + world = mx.distributed.init() + world_size = world.size() + rank = world.rank() + if world_size > 1: + print(f"Node {rank} of {world_size}") + + if args.grad_checkpoint: + grad_checkpoint(model.layers[0]) + + state = [model.state, optimizer.state] + + def step(batch): + # Forward and backward pass + (lvalue, toks), grad = loss_value_and_grad(model, *batch) + + # All reduce the gradients if running in distributed mode + grad = average_gradients(grad) + + # Model update + optimizer.update(model, grad) + + return lvalue, toks + + loss_value_and_grad = nn.value_and_grad(model, loss) + + losses = 0 + n_tokens = 0 + steps = 0 + trained_tokens = 0 + # Main training loop + start = time.perf_counter() + for it, batch in zip( + range(1, args.iters + 1), + iterate_batches( + dataset=train_dataset, + tokenizer=tokenizer, + batch_size=args.batch_size, + max_seq_length=args.max_seq_length, + train=True, + ), + ): + # Report validation loss if needed, the first validation loss + # is always measured before any training. 
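Two of the numeric details above are easy to check in isolation (the values here are illustrative): `iterate_batches` pads each batch up to the next multiple of 8, capped at `max_seq_length`, and `default_loss` averages the cross-entropy only over real (non-padding) tokens.

import mlx.core as mx
import mlx.nn as nn

# Padding: lengths [12, 37, 5] -> longest is 37 -> padded batch width is 40.
lengths = [12, 37, 5]
pad_to = 8
max_len = pad_to * ((max(lengths) + pad_to - 1) // pad_to)
print(min(max_len, 2048))                                  # 40

# Loss masking: with uniform (all-zero) logits over 11 tokens, every unmasked
# position contributes ln(11), and masked padding positions contribute nothing.
logits  = mx.zeros((2, 5, 11))
targets = mx.zeros((2, 5), dtype=mx.int32)
true_lengths = mx.array([5, 3])
mask = mx.arange(5)[None, :] < true_lengths[:, None]
ce = nn.losses.cross_entropy(logits, targets) * mask
print(mask.sum().item(), (ce.sum() / mask.sum()).item())   # 8 real tokens, ~2.398 = ln(11)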
+ if it == 1 or it % args.steps_per_eval == 0 or it == args.iters: + stop = time.perf_counter() + val_loss = evaluate( + model=model, + dataset=val_dataset, + loss=loss, + tokenizer=tokenizer, + batch_size=args.batch_size, + num_batches=args.val_batches, + max_seq_length=args.max_seq_length, + iterate_batches=iterate_batches, + ) + val_time = time.perf_counter() - stop + if rank == 0: + print( + f"Iter {it}: " + f"Val loss {val_loss:.3f}, " + f"Val took {val_time:.3f}s", + flush=True, + ) + + if training_callback is not None: + val_info = { + "iteration": it, + "val_loss": val_loss, + "val_time": val_time, + } + training_callback.on_val_loss_report(val_info) + + start = time.perf_counter() + + lvalue, toks = step(batch) + losses += lvalue + n_tokens += toks + steps += 1 + mx.eval(state, losses, n_tokens) + + # Report training loss if needed + if it % args.steps_per_report == 0 or it == args.iters: + stop = time.perf_counter() + + train_loss = mx.distributed.all_sum(losses).item() + train_loss /= steps * mx.distributed.init().size() + n_tokens = mx.distributed.all_sum(n_tokens).item() + learning_rate = optimizer.learning_rate.item() + it_sec = args.steps_per_report / (stop - start) + tokens_sec = float(n_tokens) / (stop - start) + trained_tokens += n_tokens + peak_mem = mx.metal.get_peak_memory() / 1e9 + if rank == 0: + print( + f"Iter {it}: Train loss {train_loss:.3f}, " + f"Learning Rate {learning_rate:.3e}, " + f"It/sec {it_sec:.3f}, " + f"Tokens/sec {tokens_sec:.3f}, " + f"Trained Tokens {trained_tokens}, " + f"Peak mem {peak_mem:.3f} GB", + flush=True, + ) + + if training_callback is not None: + train_info = { + "iteration": it, + "train_loss": train_loss, + "learning_rate": learning_rate, + "iterations_per_second": it_sec, + "tokens_per_second": tokens_sec, + "trained_tokens": trained_tokens, + "peak_memory": peak_mem, + } + training_callback.on_train_loss_report(train_info) + + losses = 0 + n_tokens = 0 + steps = 0 + start = time.perf_counter() + + # Save adapter weights + if it % args.steps_per_save == 0: + adapter_weights = dict(tree_flatten(model.trainable_parameters())) + mx.save_safetensors(str(args.adapter_file), adapter_weights) + checkpoint = ( + Path(args.adapter_file).parent / f"{it:07d}_{Path(args.adapter_file).name}" + ) + mx.save_safetensors(str(checkpoint), adapter_weights) + print( + f"Iter {it}: Saved adapter weights to " + f"{args.adapter_file} and {checkpoint}." + ) + + # Save final weights + adapter_weights = dict(tree_flatten(model.trainable_parameters())) + mx.save_safetensors(str(args.adapter_file), adapter_weights) + print(f"Saved final weights to {args.adapter_file}.") diff --git a/unsloth/mlx/trainer/utils.py b/unsloth/mlx/trainer/utils.py new file mode 100644 index 00000000..0c97f2ef --- /dev/null +++ b/unsloth/mlx/trainer/utils.py @@ -0,0 +1,149 @@ +# Copyright © 2024 Apple Inc. +import json +import types +from pathlib import Path +from typing import Dict + +import mlx.core as mx +import mlx.nn as nn +import mlx.optimizers as opt +from mlx.utils import tree_flatten, tree_unflatten + +from .switch_layers import QuantizedSwitchLinear, SwitchLinear +from .lora import LoRAEmbedding, LoRALinear, LoRASwitchLinear + + +def build_schedule(schedule_config: Dict): + """ + Build a learning rate schedule from the given config. 
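+
+ Expects ``name`` (a schedule in ``mlx.optimizers.schedulers``) and
+ ``arguments`` (its positional arguments, where the first entry is the
+ initial learning rate), plus optional ``warmup`` steps and ``warmup_init``.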
+ """ + schedule_fn = getattr(opt.schedulers, schedule_config["name"]) + arguments = schedule_config["arguments"] + initial_lr = arguments[0] + bound_schedule_fn = schedule_fn(*arguments) + if warmup_steps := schedule_config.get("warmup", 0): + warmup_init = schedule_config.get("warmup_init", 0.0) + warmup_fn = opt.schedulers.linear_schedule( + warmup_init, initial_lr, warmup_steps + ) + return opt.schedulers.join_schedules( + [warmup_fn, bound_schedule_fn], [warmup_steps + 1] + ) + else: + return bound_schedule_fn + + +def linear_to_lora_layers( + model: nn.Module, + num_layers: int, + config: Dict, +): + """ + Convert some of the models linear layers to lora layers. + + Args: + model (nn.Module): The neural network model. + num_layers (int): The number of blocks to convert to lora layers + starting from the last layer. + config (dict): More configuration parameters for LoRA, including the + rank, scale, and optional layer keys. + use_dora (bool): If True, uses DoRA instead of LoRA. + Default: ``False`` + """ + if num_layers > len(model.layers): + raise ValueError( + f"Requested {num_layers} LoRA layers " + f"but the model only has {len(model.layers)} layers." + ) + + def to_lora(layer): + if isinstance(layer, (nn.Linear, nn.QuantizedLinear)): + LoRALayer = LoRALinear + elif isinstance(layer, (SwitchLinear, QuantizedSwitchLinear)): + LoRALayer = LoRASwitchLinear + elif isinstance(layer, (nn.Embedding, nn.QuantizedEmbedding)): + LoRALayer = LoRAEmbedding + else: + raise ValueError( + f"Can't convert layer of type {type(layer).__name__} to LoRA" + ) + + return LoRALayer.from_base( + layer, + r=config["rank"], + scale=config["scale"], + dropout=config["dropout"], + ) + + keys = config.get("keys", None) + if keys is not None: + keys = set(keys) + elif model.model_type in [ + "mistral", + "llama", + ]: + keys = set(["self_attn.q_proj", "self_attn.v_proj"]) + if model.model_type in ["mixtral", "phimoe"]: + keys.add("block_sparse_moe.gate") + if model.model_type == "qwen2_moe": + keys.add("mlp.gate") + keys.add("mlp.shared_expert_gate") + else: + raise ValueError(f"Lora does not support {model.model_type}") + + for l in model.layers[-min(num_layers, 0) :]: + lora_layers = [(k, to_lora(m)) for k, m in l.named_modules() if k in keys] + if lora_layers: + l.update_modules(tree_unflatten(lora_layers)) + + lora_modules = [(k, to_lora(m)) for k, m in model.named_modules() if k in keys] + if lora_modules: + model.update_modules(tree_unflatten(lora_modules)) + + +def load_adapters(model: nn.Module, adapter_path: str,adapter_file : str) -> nn.Module: + """ + Load any fine-tuned adapters / layers. + + Args: + model (nn.Module): The neural network model. + adapter_path (str): Path to the adapter configuration file. + + Returns: + nn.Module: The updated model with LoRA layers applied. 
+ """ + adapter_path = Path(adapter_path) + if not adapter_path.exists(): + raise FileNotFoundError(f"The adapter path does not exist: {adapter_path}") + with open(adapter_path / "adapter_config.json", "r") as fid: + config = types.SimpleNamespace(**json.load(fid)) + + linear_to_lora_layers( + model, + config.num_layers, + config.lora_parameters, + ) + # adapters = list(mx.load(str(adapter_path /adapter_file)).items()) + # model.load_weights(adapters,False) + model.load_weights(str(adapter_path / adapter_file), strict=False) + return model + + + +def print_trainable_parameters(model): + def nparams(m): + if isinstance(m, (nn.QuantizedLinear, nn.QuantizedEmbedding)): + return m.weight.size * (32 // m.bits) + return sum(v.size for _, v in tree_flatten(m.parameters())) + + leaf_modules = tree_flatten( + model.leaf_modules(), is_leaf=lambda m: isinstance(m, nn.Module) + ) + total_p = sum(nparams(m) for _, m in leaf_modules) / 10**6 + trainable_p = ( + sum(v.size for _, v in tree_flatten(model.trainable_parameters())) / 10**6 + ) + print( + f"Trainable parameters: {(trainable_p * 100 / total_p):.3f}% " + f"({trainable_p:.3f}M/{total_p:.3f}M)" + ) diff --git a/unsloth/models/__init__.py b/unsloth/models/__init__.py index 3230cdc2..d208fb34 100644 --- a/unsloth/models/__init__.py +++ b/unsloth/models/__init__.py @@ -12,9 +12,11 @@ # See the License for the specific language governing permissions and # limitations under the License. -from .loader import FastLanguageModel, FastVisionModel -from .llama import FastLlamaModel -from .mistral import FastMistralModel -from .qwen2 import FastQwen2Model -from .dpo import PatchDPOTrainer -from ._utils import is_bfloat16_supported +from unsloth import devices +if not devices.has_mps: + from .loader import FastLanguageModel, FastVisionModel + from .llama import FastLlamaModel + from .mistral import FastMistralModel + from .qwen2 import FastQwen2Model + from .dpo import PatchDPOTrainer + from ._utils import is_bfloat16_supported diff --git a/unsloth/models/_utils.py b/unsloth/models/_utils.py index 3a29352a..8e0bf1cc 100644 --- a/unsloth/models/_utils.py +++ b/unsloth/models/_utils.py @@ -73,17 +73,19 @@ import numpy as np import warnings, subprocess, re, inspect, psutil, os, math from packaging.version import Version +from unsloth import devices from unsloth_zoo.tokenizer_utils import ( patch_tokenizer as _patch_tokenizer, ) -from unsloth_zoo.patching_utils import ( - patch_compiling_bitsandbytes, - patch_layernorm, - patch_torch_compile, - patch_model_and_tokenizer, - patch_compiled_autograd, -) +if not devices.has_mps: + from unsloth_zoo.patching_utils import ( + patch_compiling_bitsandbytes, + patch_layernorm, + patch_torch_compile, + patch_model_and_tokenizer, + patch_compiled_autograd, + ) from unsloth_zoo.gradient_checkpointing import ( Unsloth_Offloaded_Gradient_Checkpointer, unsloth_offloaded_gradient_checkpoint, @@ -106,10 +108,11 @@ from unsloth_zoo.vision_utils import ( process_vision_info, ) -from unsloth_zoo.compiler import ( - get_transformers_model_type, - unsloth_compile_transformers as _unsloth_compile_transformers, -) +if not devices.has_mps: + from unsloth_zoo.compiler import ( + get_transformers_model_type, + unsloth_compile_transformers as _unsloth_compile_transformers, + ) # ============================================= # Disable some warnings which can get annoying @@ -271,11 +274,16 @@ def _is_openai_available(): return False # ============================================= # Get Flash Attention v2 if Ampere (RTX 30xx, A100) -import 
bitsandbytes as bnb +if not devices.has_mps: + import bitsandbytes as bnb from transformers import AutoTokenizer from transformers.utils.import_utils import _is_package_available -major_version, minor_version = torch.cuda.get_device_capability() +devices.get_optimal_device() +if torch.cuda.is_available(): + major_version, minor_version = torch.cuda.get_device_capability() +else: + major_version,minor_version = 0,0 SUPPORTS_BFLOAT16 = False HAS_FLASH_ATTENTION = False HAS_FLASH_ATTENTION_SOFTCAPPING = False @@ -446,11 +454,12 @@ def is_big_gpu(index): return True import torch._inductor.utils torch._inductor.utils.is_big_gpu = is_big_gpu -patch_torch_compile( - debug = UNSLOTH_COMPILE_DEBUG, - O3 = UNSLOTH_COMPILE_MAXIMUM, - ignore_errors = UNSLOTH_COMPILE_IGNORE_ERRORS, -) +if not devices.has_mps: + patch_torch_compile( + debug = UNSLOTH_COMPILE_DEBUG, + O3 = UNSLOTH_COMPILE_MAXIMUM, + ignore_errors = UNSLOTH_COMPILE_IGNORE_ERRORS, + ) torch_compile_options = { "epilogue_fusion" : True, @@ -956,7 +965,7 @@ def check_nvidia(): output = re.findall(rb'([\d]{1,})[\s]{1,}M', output) output = np.array([int(x.decode('utf-8'))/1024 for x in output]) except: - if not torch.cuda.is_available(): + if not torch.cuda.is_available() and not devices.has_mps: raise RuntimeError("Unsloth: We do not support AMD / Intel machines yet - it is a work in progress!") return output pass diff --git a/unsloth/models/mapper.py b/unsloth/models/mapper.py index b2f73aa6..4a7331e5 100644 --- a/unsloth/models/mapper.py +++ b/unsloth/models/mapper.py @@ -17,489 +17,588 @@ "FLOAT_TO_INT_MAPPER", ] -__INT_TO_FLOAT_MAPPER = \ +from unsloth.devices import has_mps + + +if not has_mps: + __INT_TO_FLOAT_MAPPER = \ + { + "unsloth/mistral-7b-bnb-4bit" : ( + "unsloth/mistral-7b", + "mistralai/Mistral-7B-v0.1", + ), + "unsloth/llama-2-7b-bnb-4bit" : ( + "unsloth/llama-2-7b", + "meta-llama/Llama-2-7b-hf", + ), + "unsloth/llama-2-13b-bnb-4bit" : ( + "unsloth/llama-2-13b", + "meta-llama/Llama-2-13b-hf", + ), + "unsloth/codellama-34b-bnb-4bit" : ( + "codellama/CodeLlama-34b-hf", + ), + "unsloth/zephyr-sft-bnb-4bit" : ( + "unsloth/zephyr-sft", + "HuggingFaceH4/mistral-7b-sft-beta", + ), + "unsloth/tinyllama-bnb-4bit" : ( + "unsloth/tinyllama", + "TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T", + ), + "unsloth/tinyllama-chat-bnb-4bit" : ( + "unsloth/tinyllama-chat", + "TinyLlama/TinyLlama-1.1B-Chat-v1.0", + ), + "unsloth/mistral-7b-instruct-v0.1-bnb-4bit" : ( + "unsloth/mistral-7b-instruct-v0.1", + "mistralai/Mistral-7B-Instruct-v0.1", + ), + "unsloth/mistral-7b-instruct-v0.2-bnb-4bit" : ( + "unsloth/mistral-7b-instruct-v0.2", + "mistralai/Mistral-7B-Instruct-v0.2", + ), + "unsloth/llama-2-7b-chat-bnb-4bit" : ( + "unsloth/llama-2-7b-chat", + "meta-llama/Llama-2-7b-chat-hf", + ), + "unsloth/llama-2-7b-chat-bnb-4bit" : ( + "unsloth/llama-2-7b-chat", + "meta-llama/Llama-2-7b-chat-hf", + ), + "unsloth/codellama-7b-bnb-4bit" : ( + "unsloth/codellama-7b", + "codellama/CodeLlama-7b-hf", + ), + "unsloth/codellama-13b-bnb-4bit" : ( + "codellama/CodeLlama-13b-hf", + ), + "unsloth/yi-6b-bnb-4bit" : ( + "unsloth/yi-6b", + "01-ai/Yi-6B", + ), + "unsloth/solar-10.7b-bnb-4bit" : ( + "upstage/SOLAR-10.7B-v1.0", + ), + "unsloth/gemma-7b-bnb-4bit" : ( + "unsloth/gemma-7b", + "google/gemma-7b", + ), + "unsloth/gemma-2b-bnb-4bit" : ( + "unsloth/gemma-2b", + "google/gemma-2b", + ), + "unsloth/gemma-7b-it-bnb-4bit" : ( + "unsloth/gemma-7b-it", + "google/gemma-7b-it", + ), + "unsloth/gemma-2b-bnb-4bit" : ( + "unsloth/gemma-2b-it", + "google/gemma-2b-it", + 
), + "unsloth/mistral-7b-v0.2-bnb-4bit" : ( + "unsloth/mistral-7b-v0.2", + "alpindale/Mistral-7B-v0.2-hf", + ), + "unsloth/gemma-1.1-2b-it-bnb-4bit" : ( + "unsloth/gemma-1.1-2b-it", + "google/gemma-1.1-2b-it", + ), + "unsloth/gemma-1.1-7b-it-bnb-4bit" : ( + "unsloth/gemma-1.1-7b-it", + "google/gemma-1.1-7b-it", + ), + "unsloth/Starling-LM-7B-beta-bnb-4bit" : ( + "unsloth/Starling-LM-7B-beta", + "Nexusflow/Starling-LM-7B-beta", + ), + "unsloth/Hermes-2-Pro-Mistral-7B-bnb-4bit" : ( + "unsloth/Hermes-2-Pro-Mistral-7B", + "NousResearch/Hermes-2-Pro-Mistral-7B", + ), + "unsloth/OpenHermes-2.5-Mistral-7B-bnb-4bit" : ( + "unsloth/OpenHermes-2.5-Mistral-7B", + "teknium/OpenHermes-2.5-Mistral-7B", + ), + "unsloth/codegemma-2b-bnb-4bit" : ( + "unsloth/codegemma-2b", + "google/codegemma-2b", + ), + "unsloth/codegemma-7b-bnb-4bit" : ( + "unsloth/codegemma-7b", + "google/codegemma-7b", + ), + "unsloth/codegemma-7b-it-bnb-4bit" : ( + "unsloth/codegemma-7b-it", + "google/codegemma-7b-it", + ), + "unsloth/llama-3-8b-bnb-4bit" : ( + "unsloth/llama-3-8b", + "meta-llama/Meta-Llama-3-8B", + ), + "unsloth/llama-3-8b-Instruct-bnb-4bit" : ( + "unsloth/llama-3-8b-Instruct", + "meta-llama/Meta-Llama-3-8B-Instruct", + ), + "unsloth/llama-3-70b-bnb-4bit" : ( + "meta-llama/Meta-Llama-3-70B", + ), + "unsloth/llama-3-70b-Instruct-bnb-4bit" : ( + "meta-llama/Meta-Llama-3-70B-Instruct", + ), + "unsloth/Phi-3-mini-4k-instruct-bnb-4bit" : ( + "unsloth/Phi-3-mini-4k-instruct", + "microsoft/Phi-3-mini-4k-instruct", + ), + "unsloth/mistral-7b-v0.3-bnb-4bit" : ( + "unsloth/mistral-7b-v0.3", + "mistralai/Mistral-7B-v0.3", + ), + "unsloth/mistral-7b-instruct-v0.3-bnb-4bit" : ( + "unsloth/mistral-7b-instruct-v0.3", + "mistralai/Mistral-7B-Instruct-v0.3", + ), + "unsloth/Phi-3-medium-4k-instruct-bnb-4bit" : ( + "unsloth/Phi-3-medium-4k-instruct", + "microsoft/Phi-3-medium-4k-instruct", + ), + "unsloth/Qwen2-0.5B-bnb-4bit" : ( + "unsloth/Qwen2-0.5B", + "Qwen/Qwen2-0.5B", + ), + "unsloth/Qwen2-0.5B-Instruct-bnb-4bit" : ( + "unsloth/Qwen2-0.5B-Instruct", + "Qwen/Qwen2-0.5B-Instruct", + ), + "unsloth/Qwen2-1.5B-bnb-4bit" : ( + "unsloth/Qwen2-1.5B", + "Qwen/Qwen2-1.5B", + ), + "unsloth/Qwen2-1.5B-Instruct-bnb-4bit" : ( + "unsloth/Qwen2-1.5B-Instruct", + "Qwen/Qwen2-1.5B-Instruct", + ), + "unsloth/Qwen2-7B-bnb-4bit" : ( + "unsloth/Qwen2-7B", + "Qwen/Qwen2-7B", + ), + "unsloth/Qwen2-7B-Instruct-bnb-4bit" : ( + "unsloth/Qwen2-7B-Instruct", + "Qwen/Qwen2-7B-Instruct", + ), + "unsloth/Qwen2-70B-bnb-4bit" : ( + "Qwen/Qwen2-70B", + ), + "unsloth/Qwen2-70B-Instruct-bnb-4bit" : ( + "Qwen/Qwen2-70B-Instruct", + ), + "mistralai/Codestral-22B-v0.1" : ( + "mistral-community/Codestral-22B-v0.1", + ), + "unsloth/gemma-2-9b-bnb-4bit" : ( + "unsloth/gemma-2-9b", + "google/gemma-2-9b", + ), + "unsloth/gemma-2-27b-bnb-4bit" : ( + "unsloth/gemma-2-27b", + "google/gemma-2-27b", + ), + "unsloth/gemma-2-9b-it-bnb-4bit" : ( + "unsloth/gemma-2-9b-it", + "google/gemma-2-9b-it", + ), + "unsloth/gemma-2-27b-it-bnb-4bit" : ( + "unsloth/gemma-2-27b-it", + "google/gemma-2-27b-it", + ), + "unsloth/Phi-3-mini-4k-instruct-v0-bnb-4bit" : ( # Old Phi pre July + "unsloth/Phi-3-mini-4k-instruct-v0", + ), + "unsloth/Mistral-Nemo-Instruct-2407-bnb-4bit" : ( # New 12b Mistral models + "unsloth/Mistral-Nemo-Instruct-2407", + "mistralai/Mistral-Nemo-Instruct-2407", + ), + "unsloth/Mistral-Nemo-Base-2407-bnb-4bit" : ( # New 12b Mistral models + "unsloth/Mistral-Nemo-Base-2407", + "mistralai/Mistral-Nemo-Base-2407", + ), + "unsloth/Meta-Llama-3.1-8B-bnb-4bit" : ( + 
"unsloth/Meta-Llama-3.1-8B", + "meta-llama/Meta-Llama-3.1-8B", + ), + "unsloth/Meta-Llama-3.1-8B-Instruct-bnb-4bit" : ( + "unsloth/Meta-Llama-3.1-8B-Instruct", + "meta-llama/Meta-Llama-3.1-8B-Instruct", + ), + "unsloth/Meta-Llama-3.1-70B-bnb-4bit" : ( + "unsloth/Meta-Llama-3.1-70B", + "meta-llama/Meta-Llama-3.1-70B", + ), + "unsloth/Meta-Llama-3.1-405B-bnb-4bit" : ( + "meta-llama/Meta-Llama-3.1-405B", + ), + "unsloth/Meta-Llama-3.1-405B-Instruct-bnb-4bit" : ( + "meta-llama/Meta-Llama-3.1-405B-Instruct", + ), + "unsloth/Meta-Llama-3.1-70B-Instruct-bnb-4bit" : ( + "unsloth/Meta-Llama-3.1-70B-Instruct", + "meta-llama/Meta-Llama-3.1-70B-Instruct", + ), + "unsloth/Mistral-Large-Instruct-2407-bnb-4bit" : ( + "mistralai/Mistral-Large-Instruct-2407", + ), + "unsloth/gemma-2-2b-bnb-4bit" : ( + "unsloth/gemma-2-2b", + "google/gemma-2-2b", + ), + "unsloth/gemma-2-2b-it-bnb-4bit" : ( + "unsloth/gemma-2-2b-it", + "google/gemma-2-2b-it", + ), + "unsloth/Phi-3.5-mini-instruct-bnb-4bit" : ( + "unsloth/Phi-3.5-mini-instruct", + "microsoft/Phi-3.5-mini-instruct", + ), + "unsloth/c4ai-command-r-08-2024-bnb-4bit" : ( + "CohereForAI/c4ai-command-r-08-2024", + ), + "unsloth/c4ai-command-r-plus-08-2024-bnb-4bit" : ( + "CohereForAI/c4ai-command-r-plus-08-2024", + ), + "unsloth/Llama-3.1-Storm-8B-bnb-4bit" : ( + "unsloth/Llama-3.1-Storm-8B", + "akjindal53244/Llama-3.1-Storm-8B", + ), + "unsloth/Hermes-3-Llama-3.1-8B-bnb-4bit" : ( + "unsloth/Hermes-3-Llama-3.1-8B", + "NousResearch/Hermes-3-Llama-3.1-8B", + ), + "unsloth/Hermes-3-Llama-3.1-70B-bnb-4bit" : ( + "unsloth/Hermes-3-Llama-3.1-70B", + "NousResearch/Hermes-3-Llama-3.1-70B", + ), + "unsloth/Hermes-3-Llama-3.1-405B-bnb-4bit" : ( + "NousResearch/Hermes-3-Llama-3.1-405B", + ), + "unsloth/SmolLM-135M-bnb-4bit" : ( + "unsloth/SmolLM-135M", + "HuggingFaceTB/SmolLM-135M", + ), + "unsloth/SmolLM-360M-bnb-4bit" : ( + "unsloth/SmolLM-360M", + "HuggingFaceTB/SmolLM-360M", + ), + "unsloth/SmolLM-1.7B-bnb-4bit" : ( + "unsloth/SmolLM-1.7B", + "HuggingFaceTB/SmolLM-1.7B", + ), + "unsloth/SmolLM-135M-Instruct-bnb-4bit" : ( + "unsloth/SmolLM-135M-Instruct", + "HuggingFaceTB/SmolLM-135M-Instruct", + ), + "unsloth/SmolLM-360M-Instruct-bnb-4bit" : ( + "unsloth/SmolLM-360M-Instruct", + "HuggingFaceTB/SmolLM-360M-Instruct", + ), + "unsloth/SmolLM-1.7B-Instruct-bnb-4bit" : ( + "unsloth/SmolLM-1.7B-Instruct", + "HuggingFaceTB/SmolLM-1.7B-Instruct", + ), + "unsloth/Mistral-Small-Instruct-2409-bnb-4bit" : ( + "unsloth/Mistral-Small-Instruct-2409", + "mistralai/Mistral-Small-Instruct-2409", + ), + "unsloth/Qwen2.5-0.5B-Instruct-bnb-4bit" : ( + "unsloth/Qwen2.5-0.5B-Instruct", + "Qwen/Qwen2.5-0.5B-Instruct", + ), + "unsloth/Qwen2.5-1.5B-Instruct-bnb-4bit" : ( + "unsloth/Qwen2.5-1.5B-Instruct", + "Qwen/Qwen2.5-1.5B-Instruct", + ), + "unsloth/Qwen2.5-3B-Instruct-bnb-4bit" : ( + "unsloth/Qwen2.5-3B-Instruct", + "Qwen/Qwen2.5-3B-Instruct", + ), + "unsloth/Qwen2.5-7B-Instruct-bnb-4bit" : ( + "unsloth/Qwen2.5-7B-Instruct", + "Qwen/Qwen2.5-7B-Instruct", + ), + "unsloth/Qwen2.5-14B-Instruct-bnb-4bit" : ( + "unsloth/Qwen2.5-14B-Instruct", + "Qwen/Qwen2.5-14B-Instruct", + ), + "unsloth/Qwen2.5-32B-Instruct-bnb-4bit" : ( + "unsloth/Qwen2.5-32B-Instruct", + "Qwen/Qwen2.5-32B-Instruct", + ), + "unsloth/Qwen2.5-72B-Instruct-bnb-4bit" : ( + "unsloth/Qwen2.5-72B-Instruct", + "Qwen/Qwen2.5-72B-Instruct", + ), + "unsloth/Qwen2.5-0.5B-bnb-4bit" : ( + "unsloth/Qwen2.5-0.5B", + "Qwen/Qwen2.5-0.5B", + ), + "unsloth/Qwen2.5-1.5B-bnb-4bit" : ( + "unsloth/Qwen2.5-1.5B", + "Qwen/Qwen2.5-1.5B", + ), + 
"unsloth/Qwen2.5-3B-bnb-4bit" : ( + "unsloth/Qwen2.5-3B", + "Qwen/Qwen2.5-3B", + ), + "unsloth/Qwen2.5-7B-bnb-4bit" : ( + "unsloth/Qwen2.5-7B", + "Qwen/Qwen2.5-7B", + ), + "unsloth/Qwen2.5-14B-bnb-4bit" : ( + "unsloth/Qwen2.5-14B", + "Qwen/Qwen2.5-14B", + ), + "unsloth/Qwen2.5-32B-bnb-4bit" : ( + "unsloth/Qwen2.5-32B", + "Qwen/Qwen2.5-32B", + ), + "unsloth/Qwen2.5-72B-bnb-4bit" : ( + "unsloth/Qwen2.5-72B", + "Qwen/Qwen2.5-72B", + ), + "unsloth/Qwen2.5-Math-1.5B-bnb-4bit" : ( + "unsloth/Qwen2.5-Math-1.5B", + "Qwen/Qwen2.5-Math-1.5B", + ), + "unsloth/Qwen2.5-Math-7B-bnb-4bit" : ( + "unsloth/Qwen2.5-Math-7B", + "Qwen/Qwen2.5-Math-7B", + ), + "unsloth/Qwen2.5-Math-72B-bnb-4bit" : ( + "unsloth/Qwen2.5-Math-72B", + "Qwen/Qwen2.5-Math-72B", + ), + "unsloth/Qwen2.5-Math-1.5B-Instruct-bnb-4bit" : ( + "unsloth/Qwen2.5-Math-1.5B-Instruct", + "Qwen/Qwen2.5-Math-1.5B-Instruct", + ), + "unsloth/Qwen2.5-Math-7B-Instruct-bnb-4bit" : ( + "unsloth/Qwen2.5-Math-7B-Instruct", + "Qwen/Qwen2.5-Math-7B-Instruct", + ), + "unsloth/Qwen2.5-Math-72B-Instruct-bnb-4bit" : ( + "unsloth/Qwen2.5-Math-72B-Instruct", + "Qwen/Qwen2.5-Math-72B-Instruct", + ), + "unsloth/Qwen2.5-Coder-0.5B-bnb-4bit" : ( + "unsloth/Qwen2.5-Coder-0.5B", + "Qwen/Qwen2.5-Coder-0.5B", + ), + "unsloth/Qwen2.5-Coder-1.5B-bnb-4bit" : ( + "unsloth/Qwen2.5-Coder-1.5B", + "Qwen/Qwen2.5-Coder-1.5B", + ), + "unsloth/Qwen2.5-Coder-3B-bnb-4bit" : ( + "unsloth/Qwen2.5-Coder-3B", + "Qwen/Qwen2.5-Coder-3B", + ), + "unsloth/Qwen2.5-Coder-7B-bnb-4bit" : ( + "unsloth/Qwen2.5-Coder-7B", + "Qwen/Qwen2.5-Coder-7B", + ), + "unsloth/Qwen2.5-Coder-14B-bnb-4bit" : ( + "unsloth/Qwen2.5-Coder-14B", + "Qwen/Qwen2.5-Coder-14B", + ), + "unsloth/Qwen2.5-Coder-32B-bnb-4bit" : ( + "unsloth/Qwen2.5-Coder-32B", + "Qwen/Qwen2.5-Coder-32B", + ), + "unsloth/Qwen2.5-Coder-0.5B-Instruct-bnb-4bit" : ( + "unsloth/Qwen2.5-Coder-0.5B-Instruct", + "Qwen/Qwen2.5-Coder-0.5B-Instruct", + ), + "unsloth/Qwen2.5-Coder-1.5B-Instruct-bnb-4bit" : ( + "unsloth/Qwen2.5-Coder-1.5B-Instruct", + "Qwen/Qwen2.5-Coder-1.5B-Instruct", + ), + "unsloth/Qwen2.5-Coder-3B-Instruct-bnb-4bit" : ( + "unsloth/Qwen2.5-Coder-3B-Instruct", + "Qwen/Qwen2.5-Coder-3B-Instruct", + ), + "unsloth/Qwen2.5-Coder-7B-Instruct-bnb-4bit" : ( + "unsloth/Qwen2.5-Coder-7B-Instruct", + "Qwen/Qwen2.5-Coder-7B-Instruct", + ), + "unsloth/Qwen2.5-Coder-14B-Instruct-bnb-4bit" : ( + "unsloth/Qwen2.5-Coder-14B-Instruct", + "Qwen/Qwen2.5-Coder-14B-Instruct", + ), + "unsloth/Qwen2.5-Coder-32B-Instruct-bnb-4bit" : ( + "unsloth/Qwen2.5-Coder-32B-Instruct", + "Qwen/Qwen2.5-Coder-32B-Instruct", + ), + "unsloth/Llama-3.2-1B-bnb-4bit" : ( + "unsloth/Llama-3.2-1B", + "meta-llama/Llama-3.2-1B", + ), + "unsloth/Llama-3.2-3B-bnb-4bit" : ( + "unsloth/Llama-3.2-3B", + "meta-llama/Llama-3.2-3B", + ), + "unsloth/Llama-3.2-1B-Instruct-bnb-4bit" : ( + "unsloth/Llama-3.2-1B-Instruct", + "meta-llama/Llama-3.2-1B-Instruct", + ), + "unsloth/Llama-3.2-3B-Instruct-bnb-4bit" : ( + "unsloth/Llama-3.2-3B-Instruct", + "meta-llama/Llama-3.2-3B-Instruct", + ), + "unsloth/Llama-3.1-Nemotron-70B-Instruct-bnb-4bit" : ( + "unsloth/Llama-3.1-Nemotron-70B-Instruct", + "nvidia/Llama-3.1-Nemotron-70B-Instruct-HF", + ), + "unsloth/Qwen2-VL-2B-Instruct-bnb-4bit" : ( + "unsloth/Qwen2-VL-2B-Instruct", + "Qwen/Qwen2-VL-2B-Instruct", + ), + "unsloth/Qwen2-VL-7B-Instruct-bnb-4bit" : ( + "unsloth/Qwen2-VL-7B-Instruct", + "Qwen/Qwen2-VL-7B-Instruct", + ), + "unsloth/Llama-3.2-11B-Vision-Instruct-bnb-4bit" : ( + "unsloth/Llama-3.2-11B-Vision-Instruct", + 
"meta-llama/Llama-3.2-11B-Vision-Instruct", + ), + "unsloth/Llama-3.2-90B-Vision-Instruct-bnb-4bit" : ( + "unsloth/Llama-3.2-90B-Vision-Instruct", + "meta-llama/Llama-3.2-90B-Vision-Instruct", + ), + "unsloth/Llama-3.2-11B-Vision-bnb-4bit" : ( + "unsloth/Llama-3.2-11B-Vision", + "meta-llama/Llama-3.2-11B-Vision", + ), + "unsloth/Llama-3.2-90B-Vision-bnb-4bit" : ( + "unsloth/Llama-3.2-90B-Vision", + "meta-llama/Llama-3.2-90B-Vision", + ), + "unsloth/Pixtral-12B-2409-bnb-4bit" : ( + "unsloth/Pixtral-12B-2409", + "mistralai/Pixtral-12B-2409", + ), + "unsloth/Pixtral-12B-2409-Base-bnb-4bit" : ( + "unsloth/Pixtral-12B-Base-2409", + "mistralai/Pixtral-12B-Base-2409", + ), + "unsloth/llava-1.5-7b-hf-bnb-4bit" : ( + "unsloth/llava-1.5-7b-hf", + "llava-hf/llava-1.5-7b-hf", + ), + "unsloth/llava-v1.6-mistral-7b-hf-bnb-4bit" : ( + "unsloth/llava-v1.6-mistral-7b-hf", + "llava-hf/llava-v1.6-mistral-7b-hf", + ), + "unsloth/Llama-3.1-Tulu-3-8B-bnb-4bit" : ( + "unsloth/Llama-3.1-Tulu-3-8B", + "allenai/Llama-3.1-Tulu-3-8B", + ), + "unsloth/Llama-3.1-Tulu-3-70B-bnb-4bit" : ( + "unsloth/Llama-3.1-Tulu-3-70B", + "allenai/Llama-3.1-Tulu-3-70B", + ), + } +else: + __INT_TO_FLOAT_MAPPER = \ { - "unsloth/mistral-7b-bnb-4bit" : ( - "unsloth/mistral-7b", - "mistralai/Mistral-7B-v0.1", - ), - "unsloth/llama-2-7b-bnb-4bit" : ( + "shashikanth-a/llama-2-7b-4bit" : ( "unsloth/llama-2-7b", "meta-llama/Llama-2-7b-hf", ), - "unsloth/llama-2-13b-bnb-4bit" : ( - "unsloth/llama-2-13b", - "meta-llama/Llama-2-13b-hf", - ), - "unsloth/codellama-34b-bnb-4bit" : ( - "codellama/CodeLlama-34b-hf", - ), - "unsloth/zephyr-sft-bnb-4bit" : ( - "unsloth/zephyr-sft", - "HuggingFaceH4/mistral-7b-sft-beta", - ), - "unsloth/tinyllama-bnb-4bit" : ( + "shashikanth-a/tinyllama-4bit" : ( "unsloth/tinyllama", "TinyLlama/TinyLlama-1.1B-intermediate-step-1431k-3T", ), - "unsloth/tinyllama-chat-bnb-4bit" : ( + "shashikanth-a/tinyllama-chat-4bit" : ( "unsloth/tinyllama-chat", "TinyLlama/TinyLlama-1.1B-Chat-v1.0", ), - "unsloth/mistral-7b-instruct-v0.1-bnb-4bit" : ( - "unsloth/mistral-7b-instruct-v0.1", - "mistralai/Mistral-7B-Instruct-v0.1", - ), - "unsloth/mistral-7b-instruct-v0.2-bnb-4bit" : ( - "unsloth/mistral-7b-instruct-v0.2", - "mistralai/Mistral-7B-Instruct-v0.2", - ), - "unsloth/llama-2-7b-chat-bnb-4bit" : ( - "unsloth/llama-2-7b-chat", - "meta-llama/Llama-2-7b-chat-hf", - ), - "unsloth/llama-2-7b-chat-bnb-4bit" : ( + "shashikanth-a/llama-2-7b-chat-4bit" : ( "unsloth/llama-2-7b-chat", "meta-llama/Llama-2-7b-chat-hf", ), - "unsloth/codellama-7b-bnb-4bit" : ( + "shashikanth-a/codellama-7b-4bit" : ( "unsloth/codellama-7b", "codellama/CodeLlama-7b-hf", ), - "unsloth/codellama-13b-bnb-4bit" : ( - "codellama/CodeLlama-13b-hf", - ), - "unsloth/yi-6b-bnb-4bit" : ( + "shashikanth-a/yi-6b-4bit" : ( "unsloth/yi-6b", "01-ai/Yi-6B", ), - "unsloth/solar-10.7b-bnb-4bit" : ( + "shashikanth-a/solar-10.7b-4bit" : ( "upstage/SOLAR-10.7B-v1.0", ), - "unsloth/gemma-7b-bnb-4bit" : ( - "unsloth/gemma-7b", - "google/gemma-7b", - ), - "unsloth/gemma-2b-bnb-4bit" : ( - "unsloth/gemma-2b", - "google/gemma-2b", - ), - "unsloth/gemma-7b-it-bnb-4bit" : ( - "unsloth/gemma-7b-it", - "google/gemma-7b-it", - ), - "unsloth/gemma-2b-bnb-4bit" : ( - "unsloth/gemma-2b-it", - "google/gemma-2b-it", - ), - "unsloth/mistral-7b-v0.2-bnb-4bit" : ( - "unsloth/mistral-7b-v0.2", - "alpindale/Mistral-7B-v0.2-hf", - ), - "unsloth/gemma-1.1-2b-it-bnb-4bit" : ( - "unsloth/gemma-1.1-2b-it", - "google/gemma-1.1-2b-it", - ), - "unsloth/gemma-1.1-7b-it-bnb-4bit" : ( - "unsloth/gemma-1.1-7b-it", - 
"google/gemma-1.1-7b-it", - ), - "unsloth/Starling-LM-7B-beta-bnb-4bit" : ( - "unsloth/Starling-LM-7B-beta", - "Nexusflow/Starling-LM-7B-beta", - ), - "unsloth/Hermes-2-Pro-Mistral-7B-bnb-4bit" : ( - "unsloth/Hermes-2-Pro-Mistral-7B", - "NousResearch/Hermes-2-Pro-Mistral-7B", - ), - "unsloth/OpenHermes-2.5-Mistral-7B-bnb-4bit" : ( - "unsloth/OpenHermes-2.5-Mistral-7B", - "teknium/OpenHermes-2.5-Mistral-7B", - ), - "unsloth/codegemma-2b-bnb-4bit" : ( - "unsloth/codegemma-2b", - "google/codegemma-2b", - ), - "unsloth/codegemma-7b-bnb-4bit" : ( - "unsloth/codegemma-7b", - "google/codegemma-7b", - ), - "unsloth/codegemma-7b-it-bnb-4bit" : ( - "unsloth/codegemma-7b-it", - "google/codegemma-7b-it", - ), - "unsloth/llama-3-8b-bnb-4bit" : ( + "shashikanth-a/llama-3-8b-4bit" : ( "unsloth/llama-3-8b", "meta-llama/Meta-Llama-3-8B", ), - "unsloth/llama-3-8b-Instruct-bnb-4bit" : ( + "shashikanth-a/llama-3-8b-Instruct-4bit" : ( "unsloth/llama-3-8b-Instruct", "meta-llama/Meta-Llama-3-8B-Instruct", ), - "unsloth/llama-3-70b-bnb-4bit" : ( - "meta-llama/Meta-Llama-3-70B", - ), - "unsloth/llama-3-70b-Instruct-bnb-4bit" : ( - "meta-llama/Meta-Llama-3-70B-Instruct", - ), - "unsloth/Phi-3-mini-4k-instruct-bnb-4bit" : ( - "unsloth/Phi-3-mini-4k-instruct", - "microsoft/Phi-3-mini-4k-instruct", - ), - "unsloth/mistral-7b-v0.3-bnb-4bit" : ( - "unsloth/mistral-7b-v0.3", - "mistralai/Mistral-7B-v0.3", - ), - "unsloth/mistral-7b-instruct-v0.3-bnb-4bit" : ( - "unsloth/mistral-7b-instruct-v0.3", - "mistralai/Mistral-7B-Instruct-v0.3", - ), - "unsloth/Phi-3-medium-4k-instruct-bnb-4bit" : ( - "unsloth/Phi-3-medium-4k-instruct", - "microsoft/Phi-3-medium-4k-instruct", - ), - "unsloth/Qwen2-0.5B-bnb-4bit" : ( - "unsloth/Qwen2-0.5B", - "Qwen/Qwen2-0.5B", - ), - "unsloth/Qwen2-0.5B-Instruct-bnb-4bit" : ( - "unsloth/Qwen2-0.5B-Instruct", - "Qwen/Qwen2-0.5B-Instruct", - ), - "unsloth/Qwen2-1.5B-bnb-4bit" : ( - "unsloth/Qwen2-1.5B", - "Qwen/Qwen2-1.5B", - ), - "unsloth/Qwen2-1.5B-Instruct-bnb-4bit" : ( - "unsloth/Qwen2-1.5B-Instruct", - "Qwen/Qwen2-1.5B-Instruct", - ), - "unsloth/Qwen2-7B-bnb-4bit" : ( - "unsloth/Qwen2-7B", - "Qwen/Qwen2-7B", - ), - "unsloth/Qwen2-7B-Instruct-bnb-4bit" : ( - "unsloth/Qwen2-7B-Instruct", - "Qwen/Qwen2-7B-Instruct", - ), - "unsloth/Qwen2-70B-bnb-4bit" : ( - "Qwen/Qwen2-70B", - ), - "unsloth/Qwen2-70B-Instruct-bnb-4bit" : ( - "Qwen/Qwen2-70B-Instruct", - ), - "mistralai/Codestral-22B-v0.1" : ( - "mistral-community/Codestral-22B-v0.1", - ), - "unsloth/gemma-2-9b-bnb-4bit" : ( - "unsloth/gemma-2-9b", - "google/gemma-2-9b", - ), - "unsloth/gemma-2-27b-bnb-4bit" : ( - "unsloth/gemma-2-27b", - "google/gemma-2-27b", - ), - "unsloth/gemma-2-9b-it-bnb-4bit" : ( - "unsloth/gemma-2-9b-it", - "google/gemma-2-9b-it", - ), - "unsloth/gemma-2-27b-it-bnb-4bit" : ( - "unsloth/gemma-2-27b-it", - "google/gemma-2-27b-it", - ), - "unsloth/Phi-3-mini-4k-instruct-v0-bnb-4bit" : ( # Old Phi pre July - "unsloth/Phi-3-mini-4k-instruct-v0", - ), - "unsloth/Mistral-Nemo-Instruct-2407-bnb-4bit" : ( # New 12b Mistral models - "unsloth/Mistral-Nemo-Instruct-2407", - "mistralai/Mistral-Nemo-Instruct-2407", - ), - "unsloth/Mistral-Nemo-Base-2407-bnb-4bit" : ( # New 12b Mistral models - "unsloth/Mistral-Nemo-Base-2407", - "mistralai/Mistral-Nemo-Base-2407", - ), - "unsloth/Meta-Llama-3.1-8B-bnb-4bit" : ( + "shashikanth-a/Meta-Llama-3.1-8B-4bit" : ( "unsloth/Meta-Llama-3.1-8B", "meta-llama/Meta-Llama-3.1-8B", ), - "unsloth/Meta-Llama-3.1-8B-Instruct-bnb-4bit" : ( + "shashikanth-a/Meta-Llama-3.1-8B-Instruct-4bit" : ( 
"unsloth/Meta-Llama-3.1-8B-Instruct", "meta-llama/Meta-Llama-3.1-8B-Instruct", ), - "unsloth/Meta-Llama-3.1-70B-bnb-4bit" : ( - "unsloth/Meta-Llama-3.1-70B", - "meta-llama/Meta-Llama-3.1-70B", - ), - "unsloth/Meta-Llama-3.1-405B-bnb-4bit" : ( - "meta-llama/Meta-Llama-3.1-405B", - ), - "unsloth/Meta-Llama-3.1-405B-Instruct-bnb-4bit" : ( - "meta-llama/Meta-Llama-3.1-405B-Instruct", - ), - "unsloth/Meta-Llama-3.1-70B-Instruct-bnb-4bit" : ( - "unsloth/Meta-Llama-3.1-70B-Instruct", - "meta-llama/Meta-Llama-3.1-70B-Instruct", - ), - "unsloth/Mistral-Large-Instruct-2407-bnb-4bit" : ( - "mistralai/Mistral-Large-Instruct-2407", - ), - "unsloth/gemma-2-2b-bnb-4bit" : ( - "unsloth/gemma-2-2b", - "google/gemma-2-2b", - ), - "unsloth/gemma-2-2b-it-bnb-4bit" : ( - "unsloth/gemma-2-2b-it", - "google/gemma-2-2b-it", - ), - "unsloth/Phi-3.5-mini-instruct-bnb-4bit" : ( - "unsloth/Phi-3.5-mini-instruct", - "microsoft/Phi-3.5-mini-instruct", - ), - "unsloth/c4ai-command-r-08-2024-bnb-4bit" : ( - "CohereForAI/c4ai-command-r-08-2024", - ), - "unsloth/c4ai-command-r-plus-08-2024-bnb-4bit" : ( - "CohereForAI/c4ai-command-r-plus-08-2024", - ), - "unsloth/Llama-3.1-Storm-8B-bnb-4bit" : ( + "shashikanth-a/Llama-3.1-Storm-8B-4bit" : ( "unsloth/Llama-3.1-Storm-8B", "akjindal53244/Llama-3.1-Storm-8B", ), - "unsloth/Hermes-3-Llama-3.1-8B-bnb-4bit" : ( + "shashikanth-a/Hermes-3-Llama-3.1-8B-4bit" : ( "unsloth/Hermes-3-Llama-3.1-8B", "NousResearch/Hermes-3-Llama-3.1-8B", ), - "unsloth/Hermes-3-Llama-3.1-70B-bnb-4bit" : ( - "unsloth/Hermes-3-Llama-3.1-70B", - "NousResearch/Hermes-3-Llama-3.1-70B", - ), - "unsloth/Hermes-3-Llama-3.1-405B-bnb-4bit" : ( - "NousResearch/Hermes-3-Llama-3.1-405B", - ), - "unsloth/SmolLM-135M-bnb-4bit" : ( + "shashikanth-a/SmolLM-135M-4bit" : ( "unsloth/SmolLM-135M", "HuggingFaceTB/SmolLM-135M", ), - "unsloth/SmolLM-360M-bnb-4bit" : ( + "shashikanth-a/SmolLM-360M-4bit" : ( "unsloth/SmolLM-360M", "HuggingFaceTB/SmolLM-360M", ), - "unsloth/SmolLM-1.7B-bnb-4bit" : ( + "shashikanth-a/SmolLM-1.7B-4bit" : ( "unsloth/SmolLM-1.7B", "HuggingFaceTB/SmolLM-1.7B", ), - "unsloth/SmolLM-135M-Instruct-bnb-4bit" : ( + "shashikanth-a/SmolLM-135M-Instruct-4bit" : ( "unsloth/SmolLM-135M-Instruct", "HuggingFaceTB/SmolLM-135M-Instruct", ), - "unsloth/SmolLM-360M-Instruct-bnb-4bit" : ( + "shashikanth-a/SmolLM-360M-Instruct-4bit" : ( "unsloth/SmolLM-360M-Instruct", "HuggingFaceTB/SmolLM-360M-Instruct", ), - "unsloth/SmolLM-1.7B-Instruct-bnb-4bit" : ( + "shashikanth-a/SmolLM-1.7B-Instruct-4bit" : ( "unsloth/SmolLM-1.7B-Instruct", "HuggingFaceTB/SmolLM-1.7B-Instruct", ), - "unsloth/Mistral-Small-Instruct-2409-bnb-4bit" : ( - "unsloth/Mistral-Small-Instruct-2409", - "mistralai/Mistral-Small-Instruct-2409", - ), - "unsloth/Qwen2.5-0.5B-Instruct-bnb-4bit" : ( - "unsloth/Qwen2.5-0.5B-Instruct", - "Qwen/Qwen2.5-0.5B-Instruct", - ), - "unsloth/Qwen2.5-1.5B-Instruct-bnb-4bit" : ( - "unsloth/Qwen2.5-1.5B-Instruct", - "Qwen/Qwen2.5-1.5B-Instruct", - ), - "unsloth/Qwen2.5-3B-Instruct-bnb-4bit" : ( - "unsloth/Qwen2.5-3B-Instruct", - "Qwen/Qwen2.5-3B-Instruct", - ), - "unsloth/Qwen2.5-7B-Instruct-bnb-4bit" : ( - "unsloth/Qwen2.5-7B-Instruct", - "Qwen/Qwen2.5-7B-Instruct", - ), - "unsloth/Qwen2.5-14B-Instruct-bnb-4bit" : ( - "unsloth/Qwen2.5-14B-Instruct", - "Qwen/Qwen2.5-14B-Instruct", - ), - "unsloth/Qwen2.5-32B-Instruct-bnb-4bit" : ( - "unsloth/Qwen2.5-32B-Instruct", - "Qwen/Qwen2.5-32B-Instruct", - ), - "unsloth/Qwen2.5-72B-Instruct-bnb-4bit" : ( - "unsloth/Qwen2.5-72B-Instruct", - "Qwen/Qwen2.5-72B-Instruct", - ), - 
"unsloth/Qwen2.5-0.5B-bnb-4bit" : ( - "unsloth/Qwen2.5-0.5B", - "Qwen/Qwen2.5-0.5B", - ), - "unsloth/Qwen2.5-1.5B-bnb-4bit" : ( - "unsloth/Qwen2.5-1.5B", - "Qwen/Qwen2.5-1.5B", - ), - "unsloth/Qwen2.5-3B-bnb-4bit" : ( - "unsloth/Qwen2.5-3B", - "Qwen/Qwen2.5-3B", - ), - "unsloth/Qwen2.5-7B-bnb-4bit" : ( - "unsloth/Qwen2.5-7B", - "Qwen/Qwen2.5-7B", - ), - "unsloth/Qwen2.5-14B-bnb-4bit" : ( - "unsloth/Qwen2.5-14B", - "Qwen/Qwen2.5-14B", - ), - "unsloth/Qwen2.5-32B-bnb-4bit" : ( - "unsloth/Qwen2.5-32B", - "Qwen/Qwen2.5-32B", - ), - "unsloth/Qwen2.5-72B-bnb-4bit" : ( - "unsloth/Qwen2.5-72B", - "Qwen/Qwen2.5-72B", - ), - "unsloth/Qwen2.5-Math-1.5B-bnb-4bit" : ( - "unsloth/Qwen2.5-Math-1.5B", - "Qwen/Qwen2.5-Math-1.5B", - ), - "unsloth/Qwen2.5-Math-7B-bnb-4bit" : ( - "unsloth/Qwen2.5-Math-7B", - "Qwen/Qwen2.5-Math-7B", - ), - "unsloth/Qwen2.5-Math-72B-bnb-4bit" : ( - "unsloth/Qwen2.5-Math-72B", - "Qwen/Qwen2.5-Math-72B", - ), - "unsloth/Qwen2.5-Math-1.5B-Instruct-bnb-4bit" : ( - "unsloth/Qwen2.5-Math-1.5B-Instruct", - "Qwen/Qwen2.5-Math-1.5B-Instruct", - ), - "unsloth/Qwen2.5-Math-7B-Instruct-bnb-4bit" : ( - "unsloth/Qwen2.5-Math-7B-Instruct", - "Qwen/Qwen2.5-Math-7B-Instruct", - ), - "unsloth/Qwen2.5-Math-72B-Instruct-bnb-4bit" : ( - "unsloth/Qwen2.5-Math-72B-Instruct", - "Qwen/Qwen2.5-Math-72B-Instruct", - ), - "unsloth/Qwen2.5-Coder-0.5B-bnb-4bit" : ( - "unsloth/Qwen2.5-Coder-0.5B", - "Qwen/Qwen2.5-Coder-0.5B", - ), - "unsloth/Qwen2.5-Coder-1.5B-bnb-4bit" : ( - "unsloth/Qwen2.5-Coder-1.5B", - "Qwen/Qwen2.5-Coder-1.5B", - ), - "unsloth/Qwen2.5-Coder-3B-bnb-4bit" : ( - "unsloth/Qwen2.5-Coder-3B", - "Qwen/Qwen2.5-Coder-3B", - ), - "unsloth/Qwen2.5-Coder-7B-bnb-4bit" : ( - "unsloth/Qwen2.5-Coder-7B", - "Qwen/Qwen2.5-Coder-7B", - ), - "unsloth/Qwen2.5-Coder-14B-bnb-4bit" : ( - "unsloth/Qwen2.5-Coder-14B", - "Qwen/Qwen2.5-Coder-14B", - ), - "unsloth/Qwen2.5-Coder-32B-bnb-4bit" : ( - "unsloth/Qwen2.5-Coder-32B", - "Qwen/Qwen2.5-Coder-32B", - ), - "unsloth/Qwen2.5-Coder-0.5B-Instruct-bnb-4bit" : ( - "unsloth/Qwen2.5-Coder-0.5B-Instruct", - "Qwen/Qwen2.5-Coder-0.5B-Instruct", - ), - "unsloth/Qwen2.5-Coder-1.5B-Instruct-bnb-4bit" : ( - "unsloth/Qwen2.5-Coder-1.5B-Instruct", - "Qwen/Qwen2.5-Coder-1.5B-Instruct", - ), - "unsloth/Qwen2.5-Coder-3B-Instruct-bnb-4bit" : ( - "unsloth/Qwen2.5-Coder-3B-Instruct", - "Qwen/Qwen2.5-Coder-3B-Instruct", - ), - "unsloth/Qwen2.5-Coder-7B-Instruct-bnb-4bit" : ( - "unsloth/Qwen2.5-Coder-7B-Instruct", - "Qwen/Qwen2.5-Coder-7B-Instruct", - ), - "unsloth/Qwen2.5-Coder-14B-Instruct-bnb-4bit" : ( - "unsloth/Qwen2.5-Coder-14B-Instruct", - "Qwen/Qwen2.5-Coder-14B-Instruct", - ), - "unsloth/Qwen2.5-Coder-32B-Instruct-bnb-4bit" : ( - "unsloth/Qwen2.5-Coder-32B-Instruct", - "Qwen/Qwen2.5-Coder-32B-Instruct", - ), - "unsloth/Llama-3.2-1B-bnb-4bit" : ( + "shashikanth-a/Llama-3.2-1B-4bit" : ( "unsloth/Llama-3.2-1B", "meta-llama/Llama-3.2-1B", ), - "unsloth/Llama-3.2-3B-bnb-4bit" : ( + "shashikanth-a/Llama-3.2-3B-4bit" : ( "unsloth/Llama-3.2-3B", "meta-llama/Llama-3.2-3B", ), - "unsloth/Llama-3.2-1B-Instruct-bnb-4bit" : ( + "shashikanth-a/Llama-3.2-1B-Instruct-4bit" : ( "unsloth/Llama-3.2-1B-Instruct", "meta-llama/Llama-3.2-1B-Instruct", ), - "unsloth/Llama-3.2-3B-Instruct-bnb-4bit" : ( + "shashikanth-a/Llama-3.2-3B-Instruct-4bit" : ( "unsloth/Llama-3.2-3B-Instruct", "meta-llama/Llama-3.2-3B-Instruct", ), - "unsloth/Llama-3.1-Nemotron-70B-Instruct-bnb-4bit" : ( - "unsloth/Llama-3.1-Nemotron-70B-Instruct", - "nvidia/Llama-3.1-Nemotron-70B-Instruct-HF", - ), - 
"unsloth/Qwen2-VL-2B-Instruct-bnb-4bit" : ( - "unsloth/Qwen2-VL-2B-Instruct", - "Qwen/Qwen2-VL-2B-Instruct", - ), - "unsloth/Qwen2-VL-7B-Instruct-bnb-4bit" : ( - "unsloth/Qwen2-VL-7B-Instruct", - "Qwen/Qwen2-VL-7B-Instruct", - ), - "unsloth/Llama-3.2-11B-Vision-Instruct-bnb-4bit" : ( - "unsloth/Llama-3.2-11B-Vision-Instruct", - "meta-llama/Llama-3.2-11B-Vision-Instruct", - ), - "unsloth/Llama-3.2-90B-Vision-Instruct-bnb-4bit" : ( - "unsloth/Llama-3.2-90B-Vision-Instruct", - "meta-llama/Llama-3.2-90B-Vision-Instruct", - ), - "unsloth/Llama-3.2-11B-Vision-bnb-4bit" : ( - "unsloth/Llama-3.2-11B-Vision", - "meta-llama/Llama-3.2-11B-Vision", - ), - "unsloth/Llama-3.2-90B-Vision-bnb-4bit" : ( - "unsloth/Llama-3.2-90B-Vision", - "meta-llama/Llama-3.2-90B-Vision", - ), - "unsloth/Pixtral-12B-2409-bnb-4bit" : ( - "unsloth/Pixtral-12B-2409", - "mistralai/Pixtral-12B-2409", - ), - "unsloth/Pixtral-12B-2409-Base-bnb-4bit" : ( - "unsloth/Pixtral-12B-Base-2409", - "mistralai/Pixtral-12B-Base-2409", - ), - "unsloth/llava-1.5-7b-hf-bnb-4bit" : ( - "unsloth/llava-1.5-7b-hf", - "llava-hf/llava-1.5-7b-hf", - ), - "unsloth/llava-v1.6-mistral-7b-hf-bnb-4bit" : ( - "unsloth/llava-v1.6-mistral-7b-hf", - "llava-hf/llava-v1.6-mistral-7b-hf", - ), - "unsloth/Llama-3.1-Tulu-3-8B-bnb-4bit" : ( - "unsloth/Llama-3.1-Tulu-3-8B", - "allenai/Llama-3.1-Tulu-3-8B", - ), - "unsloth/Llama-3.1-Tulu-3-70B-bnb-4bit" : ( - "unsloth/Llama-3.1-Tulu-3-70B", - "allenai/Llama-3.1-Tulu-3-70B", - ), } INT_TO_FLOAT_MAPPER = {} diff --git a/unsloth/save.py b/unsloth/save.py index b503b2b4..cb452218 100644 --- a/unsloth/save.py +++ b/unsloth/save.py @@ -12,9 +12,13 @@ # See the License for the specific language governing permissions and # limitations under the License. -from bitsandbytes.nn import Linear4bit as Bnb_Linear4bit -from peft.tuners.lora import Linear4bit as Peft_Linear4bit -from peft.tuners.lora import Linear as Peft_Linear +from unsloth.devices import has_mps + +if not has_mps(): + from bitsandbytes.nn import Linear4bit as Bnb_Linear4bit + from peft.tuners.lora import Linear4bit as Peft_Linear4bit + from peft.tuners.lora import Linear as Peft_Linear + from .kernels import fast_dequantize, QUANT_STATE, get_lora_parameters_bias from typing import Optional, Callable, Union, List import torch import os @@ -22,7 +26,6 @@ import pickle import gc from transformers.models.llama.modeling_llama import logger -from .kernels import fast_dequantize, QUANT_STATE, get_lora_parameters_bias import subprocess import psutil import re diff --git a/unsloth/sd_hijack_utils.py b/unsloth/sd_hijack_utils.py new file mode 100644 index 00000000..179ebc78 --- /dev/null +++ b/unsloth/sd_hijack_utils.py @@ -0,0 +1,28 @@ +import importlib + +class CondFunc: + def __new__(cls, orig_func, sub_func, cond_func): + self = super(CondFunc, cls).__new__(cls) + if isinstance(orig_func, str): + func_path = orig_func.split('.') + for i in range(len(func_path)-1, -1, -1): + try: + resolved_obj = importlib.import_module('.'.join(func_path[:i])) + break + except ImportError: + pass + for attr_name in func_path[i:-1]: + resolved_obj = getattr(resolved_obj, attr_name) + orig_func = getattr(resolved_obj, func_path[-1]) + setattr(resolved_obj, func_path[-1], lambda *args, **kwargs: self(*args, **kwargs)) + self.__init__(orig_func, sub_func, cond_func) + return lambda *args, **kwargs: self(*args, **kwargs) + def __init__(self, orig_func, sub_func, cond_func): + self.__orig_func = orig_func + self.__sub_func = sub_func + self.__cond_func = cond_func + def __call__(self, *args, 
**kwargs): + if not self.__cond_func or self.__cond_func(self.__orig_func, *args, **kwargs): + return self.__sub_func(self.__orig_func, *args, **kwargs) + else: + return self.__orig_func(*args, **kwargs) diff --git a/unsloth/tokenizer_utils.py b/unsloth/tokenizer_utils.py index 302017d5..a0e3ac5a 100644 --- a/unsloth/tokenizer_utils.py +++ b/unsloth/tokenizer_utils.py @@ -24,6 +24,7 @@ import collections import numpy as np import gc +from unsloth import devices import subprocess from unsloth_zoo.tokenizer_utils import ( @@ -837,7 +838,7 @@ def check_nvidia(): output = re.findall(rb'([\d]{1,})[\s]{1,}M', output) output = np.array([int(x.decode('utf-8'))/1024 for x in output]) except: - if not torch.cuda.is_available(): + if not torch.cuda.is_available() and not devices.has_mps: raise RuntimeError("Unsloth: We do not support AMD / Intel machines yet - it is a work in progress!") return output pass
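
Note (not part of the patch): a minimal sketch of how the new MLX LoRA trainer pieces above fit together on Apple Silicon. It is hedged throughout: the module path for TrainingArgs/train is assumed (only unsloth/mlx/trainer/utils.py is named explicitly in this diff), the mlx_lm loader and the checkpoint name are illustrative stand-ins, and batch_size/iters/val_batches are fields the train() loop reads but whose declarations fall outside the hunks shown here.

import mlx.optimizers as optim
from mlx_lm import load  # assumption: mlx_lm supplies an MLX model and tokenizer compatible with this trainer
from unsloth.mlx.trainer.utils import (
    build_schedule,
    linear_to_lora_layers,
    print_trainable_parameters,
)
from unsloth.mlx.trainer.trainer import TrainingArgs, train  # assumed module path for the trainer shown above

# Any len()-able, indexable collection of texts works: iterate_batches() only calls str(dataset[i]).
train_texts = ["### Instruction:\nSummarise MLX.\n### Response:\nMLX is an array framework for Apple Silicon."] * 64
val_texts = train_texts[:8]

model, tokenizer = load("mlx-community/Llama-3.2-1B-Instruct-4bit")  # illustrative checkpoint

# Freeze the base weights and swap the attention projections of the last blocks for LoRA layers.
model.freeze()
linear_to_lora_layers(model, num_layers=8, config={"rank": 16, "scale": 16.0, "dropout": 0.05})
print_trainable_parameters(model)

# Cosine-decayed learning rate with a short linear warmup, built via build_schedule().
lr = build_schedule({"name": "cosine_decay", "arguments": [2e-5, 1000], "warmup": 50})
optimizer = optim.Adam(learning_rate=lr)

args = TrainingArgs(
    batch_size=2, iters=1000, val_batches=25,  # assumed fields declared earlier in the dataclass
    steps_per_report=10, steps_per_eval=200, steps_per_save=100,
    max_seq_length=2048, adapter_file="adapters.safetensors",
)
train(model, tokenizer, optimizer, train_texts, val_texts, args=args)

The adapter weights land in adapters.safetensors (plus periodic checkpoints); they can be re-attached later with load_adapters(model, adapter_path, adapter_file), assuming an adapter_config.json describing the LoRA settings sits alongside the weights.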