Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added Support for Apple Silicon #1289

Open
wants to merge 5 commits into
base: main
Choose a base branch
from
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
94 changes: 59 additions & 35 deletions unsloth-cli.py
Original file line number Diff line number Diff line change
Expand Up @@ -31,38 +31,49 @@

import argparse

from unsloth.devices import has_mps

def run(args):
import torch
from unsloth import FastLanguageModel
from datasets import load_dataset
from trl import SFTTrainer
from transformers import TrainingArguments
from unsloth import is_bfloat16_supported
import logging
logging.getLogger('hf-to-gguf').setLevel(logging.WARNING)
if has_mps:
from unsloth.mlx import mlx_utils
from unsloth.mlx import lora as mlx_lora
import gc

if not has_mps:
from unsloth import FastLanguageModel
# Load model and tokenizer
model, tokenizer = FastLanguageModel.from_pretrained(
model_name=args.model_name,
max_seq_length=args.max_seq_length,
dtype=args.dtype,
load_in_4bit=args.load_in_4bit,
)

model, tokenizer = FastLanguageModel.from_pretrained(
model_name=args.model_name,
max_seq_length=args.max_seq_length,
dtype=args.dtype,
load_in_4bit=args.load_in_4bit,
)
else:
print("Loading pretrained model")
model, tokenizer, config = mlx_utils.load_pretrained(args.model_name,dtype=args.dtype,load_in_4bit=args.load_in_4bit)

# Configure PEFT model
model = FastLanguageModel.get_peft_model(
model,
r=args.r,
target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
"gate_proj", "up_proj", "down_proj"],
lora_alpha=args.lora_alpha,
lora_dropout=args.lora_dropout,
bias=args.bias,
use_gradient_checkpointing=args.use_gradient_checkpointing,
random_state=args.random_state,
use_rslora=args.use_rslora,
loftq_config=args.loftq_config,
)
if not has_mps:
model = FastLanguageModel.get_peft_model(
model,
r=args.r,
target_modules=["q_proj", "k_proj", "v_proj", "o_proj",
"gate_proj", "up_proj", "down_proj"],
lora_alpha=args.lora_alpha,
lora_dropout=args.lora_dropout,
bias=args.bias,
use_gradient_checkpointing=args.use_gradient_checkpointing,
random_state=args.random_state,
use_rslora=args.use_rslora,
loftq_config=args.loftq_config,
)

alpaca_prompt = """Below is an instruction that describes a task, paired with an input that provides further context. Write a response that appropriately completes the request.

Expand Down Expand Up @@ -110,19 +121,24 @@ def formatting_prompts_func(examples):
)

# Initialize trainer
trainer = SFTTrainer(
model=model,
tokenizer=tokenizer,
train_dataset=dataset,
dataset_text_field="text",
max_seq_length=args.max_seq_length,
dataset_num_proc=2,
packing=False,
args=training_args,
)
if not has_mps:
trainer = SFTTrainer(
model=model,
tokenizer=tokenizer,
train_dataset=dataset,
dataset_text_field="text",
max_seq_length=args.max_seq_length,
dataset_num_proc=2,
packing=False,
args=training_args,
)

# Train model
trainer_stats = trainer.train()
trainer_stats = trainer.train()
else:
datasets = dataset.train_test_split(test_size=0.1)
mlx_lora.train_model(args,model,tokenizer, datasets["train"], datasets["test"])


# Save model
if args.save_model:
Expand Down Expand Up @@ -152,9 +168,16 @@ def formatting_prompts_func(examples):
quantization_method=quantization_method,
)
else:
model.save_pretrained_merged(args.save_path, tokenizer, args.save_method)
if args.push_model:
model.push_to_hub_merged(args.save_path, tokenizer, args.hub_token)
if has_mps:
del model
gc.collect()
mlx_utils.save_merged_model(args)
if args.push_model:
mlx_utils.push_to_hub(args,config["_name_or_path"],config["model_type"])
else:
model.save_pretrained_merged(args.save_path, tokenizer, args.save_method)
if args.push_model:
model.push_to_hub_merged(args.save_path, tokenizer, args.hub_token)
else:
print("Warning: The model is not saved!")

Expand Down Expand Up @@ -203,6 +226,7 @@ def formatting_prompts_func(examples):

# Saving and pushing arguments
save_group = parser.add_argument_group('💾 Save Model Options')
save_group.add_argument('--adapter_file', type=str, default="adapters.safetensors", help="Adapters file name")
save_group.add_argument('--output_dir', type=str, default="outputs", help="Output directory")
save_group.add_argument('--save_model', action='store_true', help="Save the model after training")
save_group.add_argument('--save_method', type=str, default="merged_16bit", choices=["merged_16bit", "merged_4bit", "lora"], help="Save method for the model, default is 'merged_16bit'")
Expand Down
38 changes: 25 additions & 13 deletions unsloth/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -16,6 +16,7 @@
from packaging.version import Version
import os, re, subprocess, inspect
import numpy as np
from unsloth import devices

# # Define a list of modules to check
# MODULES_TO_CHECK = ["bitsandbytes"]
Expand Down Expand Up @@ -90,7 +91,11 @@
pass

# Torch 2.4 has including_emulation
major_version, minor_version = torch.cuda.get_device_capability()
devices.get_optimal_device()
if torch.cuda.is_available():
major_version, minor_version = torch.cuda.get_device_capability()
else:
major_version,minor_version = 0,0
SUPPORTS_BFLOAT16 = (major_version >= 8)

old_is_bf16_supported = torch.cuda.is_bf16_supported
Expand All @@ -104,7 +109,6 @@ def is_bf16_supported(): return SUPPORTS_BFLOAT16
pass

# Try loading bitsandbytes and triton
import bitsandbytes as bnb

if "SPACE_AUTHOR_NAME" not in os.environ and "SPACE_REPO_NAME" not in os.environ:

Expand All @@ -116,7 +120,9 @@ def is_bf16_supported(): return SUPPORTS_BFLOAT16
else: from triton.common.build import libcuda_dirs

try:
cdequantize_blockwise_fp32 = bnb.functional.lib.cdequantize_blockwise_fp32
if not devices.has_mps:
import bitsandbytes as bnb
cdequantize_blockwise_fp32 = bnb.functional.lib.cdequantize_blockwise_fp32
libcuda_dirs()
except:
warnings.warn(
Expand All @@ -141,16 +147,19 @@ def is_bf16_supported(): return SUPPORTS_BFLOAT16
latest_cuda = possible_cudas[latest_cuda]
os.system(f"ldconfig /usr/local/{latest_cuda}")
pass

importlib.reload(bnb)
if not devices.has_mps:
import bitsandbytes as bnb
importlib.reload(bnb)
importlib.reload(triton)
try:
libcuda_dirs = lambda: None
if Version(triton.__version__) >= Version("3.0.0"):
try: from triton.backends.nvidia.driver import libcuda_dirs
except: pass
else: from triton.common.build import libcuda_dirs
cdequantize_blockwise_fp32 = bnb.functional.lib.cdequantize_blockwise_fp32
if not devices.has_mps:
import bitsandbytes as bnb
cdequantize_blockwise_fp32 = bnb.functional.lib.cdequantize_blockwise_fp32
libcuda_dirs()
except:
warnings.warn(
Expand All @@ -171,11 +180,14 @@ def is_bf16_supported(): return SUPPORTS_BFLOAT16
raise ImportError("Unsloth: Please install unsloth_zoo via `pip install unsloth-zoo`")
pass

from .models import *
from .save import *
from .chat_templates import *
from .tokenizer_utils import *
from .trainer import *
if not devices.has_mps:
from .models import *
from .save import *
from .chat_templates import *
from .tokenizer_utils import *
from .trainer import *

# Patch TRL trainers for backwards compatibility
_patch_trl_trainer()
# Patch TRL trainers for backwards compatibility
_patch_trl_trainer()
else:
from .models._utils import is_bfloat16_supported
49 changes: 49 additions & 0 deletions unsloth/devices.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,49 @@
import sys

import torch

if sys.platform == "darwin":
from unsloth import mac_specific


def has_mps() -> bool:
if sys.platform != "darwin":
return False
else:
return mac_specific.has_mps


def get_cuda_device_string():
return "cuda"


def get_optimal_device_name():
if torch.cuda.is_available():
return get_cuda_device_string()

if has_mps():
return "mps"

return "cpu"


def get_optimal_device():
return torch.device(get_optimal_device_name())



def torch_gc():

if torch.cuda.is_available():
with torch.cuda.device(get_cuda_device_string()):
torch.cuda.empty_cache()
torch.cuda.ipc_collect()

if has_mps():
mac_specific.torch_mps_gc()






22 changes: 13 additions & 9 deletions unsloth/kernels/utils.py
Original file line number Diff line number Diff line change
Expand Up @@ -59,19 +59,23 @@ def calculate_settings(n : int) -> (int, int,):
pass


import bitsandbytes as bnb
from unsloth import devices

# https://github.com/bitsandbytes-foundation/bitsandbytes/pull/1330/files
HAS_CUDA_STREAM = Version(bnb.__version__) > Version("0.43.3")
HAS_CUDA_STREAM = False
global CUDA_STREAM
CUDA_STREAM = None
get_ptr = bnb.functional.get_ptr
import ctypes
cdequantize_blockwise_fp32 = bnb.functional.lib.cdequantize_blockwise_fp32
cdequantize_blockwise_fp16_nf4 = bnb.functional.lib.cdequantize_blockwise_fp16_nf4
cdequantize_blockwise_bf16_nf4 = bnb.functional.lib.cdequantize_blockwise_bf16_nf4
cgemm_4bit_inference_naive_fp16 = bnb.functional.lib.cgemm_4bit_inference_naive_fp16
cgemm_4bit_inference_naive_bf16 = bnb.functional.lib.cgemm_4bit_inference_naive_bf16
if not devices.has_mps:
import bitsandbytes as bnb
HAS_CUDA_STREAM = Version(bnb.__version__) > Version("0.43.3")
get_ptr = bnb.functional.get_ptr
cdequantize_blockwise_fp32 = bnb.functional.lib.cdequantize_blockwise_fp32
cdequantize_blockwise_fp16_nf4 = bnb.functional.lib.cdequantize_blockwise_fp16_nf4
cdequantize_blockwise_bf16_nf4 = bnb.functional.lib.cdequantize_blockwise_bf16_nf4
cgemm_4bit_inference_naive_fp16 = bnb.functional.lib.cgemm_4bit_inference_naive_fp16
cgemm_4bit_inference_naive_bf16 = bnb.functional.lib.cgemm_4bit_inference_naive_bf16

import ctypes

def QUANT_STATE(W):
return getattr(W, "quant_state", None)
Expand Down
71 changes: 71 additions & 0 deletions unsloth/mac_specific.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
import logging

import torch
import platform
from unsloth.sd_hijack_utils import CondFunc
from packaging import version

log = logging.getLogger(__name__)


# before torch version 1.13, has_mps is only available in nightly pytorch and macOS 12.3+,
# use check `getattr` and try it for compatibility.
# in torch version 1.13, backends.mps.is_available() and backends.mps.is_built() are introduced in to check mps availabilty,
# since torch 2.0.1+ nightly build, getattr(torch, 'has_mps', False) was deprecated, see https://github.com/pytorch/pytorch/pull/103279
def check_for_mps() -> bool:
if version.parse(torch.__version__) <= version.parse("2.0.1"):
if not getattr(torch, 'has_mps', False):
return False
try:
torch.zeros(1).to(torch.device("mps"))
return True
except Exception:
return False
else:
return torch.backends.mps.is_available() and torch.backends.mps.is_built()


has_mps = check_for_mps()


# MPS workaround for https://github.com/pytorch/pytorch/issues/89784
def cumsum_fix(input, cumsum_func, *args, **kwargs):
if input.device.type == 'mps':
output_dtype = kwargs.get('dtype', input.dtype)
if output_dtype == torch.int64:
return cumsum_func(input.cpu(), *args, **kwargs).to(input.device)
elif output_dtype == torch.bool or cumsum_needs_int_fix and (output_dtype == torch.int8 or output_dtype == torch.int16):
return cumsum_func(input.to(torch.int32), *args, **kwargs).to(torch.int64)
return cumsum_func(input, *args, **kwargs)


if has_mps:
if platform.mac_ver()[0].startswith("13.2."):
# MPS workaround for https://github.com/pytorch/pytorch/issues/95188, thanks to danieldk (https://github.com/explosion/curated-transformers/pull/124)
CondFunc('torch.nn.functional.linear', lambda _, input, weight, bias: (torch.matmul(input, weight.t()) + bias) if bias is not None else torch.matmul(input, weight.t()), lambda _, input, weight, bias: input.numel() > 10485760)

if version.parse(torch.__version__) < version.parse("1.13"):
# PyTorch 1.13 doesn't need these fixes but unfortunately is slower and has regressions that prevent training from working

# MPS workaround for https://github.com/pytorch/pytorch/issues/79383
CondFunc('torch.Tensor.to', lambda orig_func, self, *args, **kwargs: orig_func(self.contiguous(), *args, **kwargs),
lambda _, self, *args, **kwargs: self.device.type != 'mps' and (args and isinstance(args[0], torch.device) and args[0].type == 'mps' or isinstance(kwargs.get('device'), torch.device) and kwargs['device'].type == 'mps'))
# MPS workaround for https://github.com/pytorch/pytorch/issues/80800
CondFunc('torch.nn.functional.layer_norm', lambda orig_func, *args, **kwargs: orig_func(*([args[0].contiguous()] + list(args[1:])), **kwargs),
lambda _, *args, **kwargs: args and isinstance(args[0], torch.Tensor) and args[0].device.type == 'mps')
# MPS workaround for https://github.com/pytorch/pytorch/issues/90532
CondFunc('torch.Tensor.numpy', lambda orig_func, self, *args, **kwargs: orig_func(self.detach(), *args, **kwargs), lambda _, self, *args, **kwargs: self.requires_grad)
elif version.parse(torch.__version__) > version.parse("1.13.1"):
cumsum_needs_int_fix = not torch.Tensor([1,2]).to(torch.device("mps")).equal(torch.ShortTensor([1,1]).to(torch.device("mps")).cumsum(0))
cumsum_fix_func = lambda orig_func, input, *args, **kwargs: cumsum_fix(input, orig_func, *args, **kwargs)
CondFunc('torch.cumsum', cumsum_fix_func, None)
CondFunc('torch.Tensor.cumsum', cumsum_fix_func, None)
CondFunc('torch.narrow', lambda orig_func, *args, **kwargs: orig_func(*args, **kwargs).clone(), None)

# MPS workaround for https://github.com/pytorch/pytorch/issues/96113
CondFunc('torch.nn.functional.layer_norm', lambda orig_func, x, normalized_shape, weight, bias, eps, **kwargs: orig_func(x.float(), normalized_shape, weight.float() if weight is not None else None, bias.float() if bias is not None else bias, eps).to(x.dtype), lambda _, input, *args, **kwargs: len(args) == 4 and input.device.type == 'mps')

# MPS workaround for https://github.com/pytorch/pytorch/issues/92311
if platform.processor() == 'i386':
for funcName in ['torch.argmax', 'torch.Tensor.argmax']:
CondFunc(funcName, lambda _, input, *args, **kwargs: torch.max(input.float() if input.dtype == torch.int64 else input, *args, **kwargs)[1], lambda _, input, *args, **kwargs: input.device.type == 'mps')
Loading