From b1d29212222f5dd566fab5453e3fd342089c8191 Mon Sep 17 00:00:00 2001
From: Wing Lian
Date: Fri, 16 Aug 2024 21:32:00 -0400
Subject: [PATCH] add validation to prevent 8bit lora finetuning on H100s (#1827)

---
 .../utils/config/models/input/v0_4_1/__init__.py | 13 +++++++++++++
 1 file changed, 13 insertions(+)

diff --git a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
index 647d6b88c7..5e690bb88e 100644
--- a/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
+++ b/src/axolotl/utils/config/models/input/v0_4_1/__init__.py
@@ -1267,6 +1267,19 @@ def check_sample_packing_w_sdpa_bf16(cls, data):
 
         return data
 
+    @model_validator(mode="before")
+    @classmethod
+    def check_hopper_8bit_lora(cls, data):
+        is_sm_90: bool = (
+            data["capabilities"]
+            and data["capabilities"].get("compute_capability") == "sm_90"
+        )
+        if data.get("adapter") and data.get("load_in_8bit") and is_sm_90:
+            # see https://github.com/bitsandbytes-foundation/bitsandbytes/issues/538#issuecomment-2262945464
+            raise ValueError("8-bit LoRA is not supported on Hopper GPUs")
+
+        return data
+
     @model_validator(mode="before")
     @classmethod
     def check_fsdp_deepspeed(cls, data):
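
Note (not part of the patch): a minimal standalone sketch of what the new check does. In the
patch it runs as a Pydantic model_validator(mode="before") on axolotl's config model; that
surrounding model is omitted here, the dict below simply mirrors the fields the validator reads,
and the example values are hypothetical.

    # Sketch only: same logic as check_hopper_8bit_lora, written as a plain function.
    def check_hopper_8bit_lora(data: dict) -> dict:
        # detect Hopper-class GPUs via the reported compute capability
        is_sm_90 = (
            data.get("capabilities")
            and data["capabilities"].get("compute_capability") == "sm_90"
        )
        if data.get("adapter") and data.get("load_in_8bit") and is_sm_90:
            # 8-bit (bitsandbytes) LoRA is rejected on Hopper; see the issue linked in the patch
            raise ValueError("8-bit LoRA is not supported on Hopper GPUs")
        return data

    # With this validation, an 8-bit LoRA config on an H100 fails fast at config time
    # instead of surfacing errors later in training:
    check_hopper_8bit_lora(
        {
            "adapter": "lora",
            "load_in_8bit": True,
            "capabilities": {"compute_capability": "sm_90"},
        }
    )  # raises ValueError: 8-bit LoRA is not supported on Hopper GPUs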