Skip to content

Commit

Permalink
fix for faulty bitsandbytes-foundation#1222 ("Add "lamb" to `str2op…
Browse files Browse the repository at this point in the history
…timizer32bit`") (bitsandbytes-foundation#1240)

* Revert "Add `"lamb"` to `str2optimizer32bit`"

* Update bitsandbytes/functional.py
  • Loading branch information
younesbelkada authored Jun 5, 2024
1 parent 1f2ca43 commit b22ae26
Showing 1 changed file with 36 additions and 88 deletions.
124 changes: 36 additions & 88 deletions bitsandbytes/functional.py
Original file line number Diff line number Diff line change
Expand Up @@ -27,128 +27,67 @@ def prod(iterable):
if lib and lib.compiled_with_cuda:
"""C FUNCTIONS FOR OPTIMIZERS"""
str2optimizer32bit = {
"adagrad": (
lib.cadagrad32bit_grad_fp32,
lib.cadagrad32bit_grad_fp16,
),
"adam": (
lib.cadam32bit_grad_fp32,
lib.cadam32bit_grad_fp16,
lib.cadam32bit_grad_bf16,
),
"pagedadam": (
lib.cpagedadam32bit_grad_fp32,
lib.cpagedadam32bit_grad_fp16,
lib.cpagedadam32bit_grad_bf16,
),
"adamw": (
lib.cadam32bit_grad_fp32,
lib.cadam32bit_grad_fp16,
lib.cadam32bit_grad_bf16,
),
"pagedadamw": (
lib.cpagedadam32bit_grad_fp32,
lib.cpagedadam32bit_grad_fp16,
lib.cpagedadam32bit_grad_bf16,
),
"lamb": (
lib.cadam32bit_grad_fp32,
lib.cadam32bit_grad_fp16,
"momentum": (
lib.cmomentum32bit_grad_32,
lib.cmomentum32bit_grad_16,
),
"lars": (
lib.clars32bit_grad_fp32,
lib.clars32bit_grad_fp16,
"rmsprop": (
lib.crmsprop32bit_grad_32,
lib.crmsprop32bit_grad_16,
),
"lion": (
lib.clion32bit_grad_fp32,
lib.clion32bit_grad_fp16,
lib.clion32bit_grad_bf16,
),
"momentum": (
lib.cmomentum32bit_grad_fp32,
lib.cmomentum32bit_grad_fp16,
"adagrad": (
lib.cadagrad32bit_grad_32,
lib.cadagrad32bit_grad_16,
),
"rmsprop": (
lib.crmsprop32bit_grad_fp32,
lib.crmsprop32bit_grad_fp16,
"lamb": (
lib.cadam32bit_grad_fp32,
lib.cadam32bit_grad_fp16,
),
}

str2optimizer8bit = {
"adagrad": (
lib.cadagrad8bit_grad_fp32,
lib.cadagrad8bit_grad_fp16,
),
"adam": (
lib.cadam_static_8bit_grad_fp32,
lib.cadam_static_8bit_grad_fp16,
lib.cadam_static_8bit_grad_32,
lib.cadam_static_8bit_grad_16,
),
"pagedadam": (
lib.cpagedadam8bit_grad_fp32,
lib.cpagedadam8bit_grad_fp16,
lib.cpagedadam8bit_grad_bf16,
"momentum": (
lib.cmomentum_static_8bit_grad_32,
lib.cmomentum_static_8bit_grad_16,
),
"adamw": (
lib.cadam_static_8bit_grad_fp32,
lib.cadam_static_8bit_grad_fp16,
"rmsprop": (
lib.crmsprop_static_8bit_grad_32,
lib.crmsprop_static_8bit_grad_16,
),
"pagedadamw": (
lib.cpagedadam8bit_grad_fp32,
lib.cpagedadam8bit_grad_fp16,
lib.cpagedadam8bit_grad_bf16,
"lion": (
lib.clion_static_8bit_grad_32,
lib.clion_static_8bit_grad_16,
),
"lamb": (
lib.cadam_static_8bit_grad_fp32,
lib.cadam_static_8bit_grad_fp16,
lib.cadam_static_8bit_grad_32,
lib.cadam_static_8bit_grad_16,
),
"lars": (
lib.clars8bit_grad_fp32,
lib.clars8bit_grad_fp16,
),
"lion": (
lib.clion_static_8bit_grad_fp32,
lib.clion_static_8bit_grad_fp16,
),
"momentum": (
lib.cmomentum_static_8bit_grad_fp32,
lib.cmomentum_static_8bit_grad_fp16,
),
"rmsprop": (
lib.crmsprop_static_8bit_grad_fp32,
lib.crmsprop_static_8bit_grad_fp16,
lib.cmomentum_static_8bit_grad_32,
lib.cmomentum_static_8bit_grad_16,
),
}

str2optimizer8bit_blockwise = {
"adagrad": (
lib.cadagrad_8bit_blockwise_grad_fp32,
lib.cadagrad_8bit_blockwise_grad_fp16,
),
"adam": (
lib.cadam_8bit_blockwise_grad_fp32,
lib.cadam_8bit_blockwise_grad_fp16,
lib.cadam_8bit_blockwise_grad_bf16,
),
"pagedadam": (
lib.cpagedadam8bit_blockwise_fp32,
lib.cpagedadam8bit_blockwise_fp16,
lib.cpagedadam8bit_blockwise_bf16,
),
"adamw": (
lib.cadam_8bit_blockwise_grad_fp32,
lib.cadam_8bit_blockwise_grad_fp16,
lib.cadam_8bit_blockwise_grad_bf16,
),
"pagedadamw": (
lib.cpagedadam8bit_blockwise_fp32,
lib.cpagedadam8bit_blockwise_fp16,
lib.cpagedadam8bit_blockwise_bf16,
),
"lion": (
lib.clion_8bit_blockwise_grad_fp32,
lib.clion_8bit_blockwise_grad_fp16,
lib.clion_8bit_blockwise_grad_bf16,
),
"momentum": (
lib.cmomentum_8bit_blockwise_grad_fp32,
lib.cmomentum_8bit_blockwise_grad_fp16,
Expand All @@ -157,6 +96,15 @@ def prod(iterable):
lib.crmsprop_8bit_blockwise_grad_fp32,
lib.crmsprop_8bit_blockwise_grad_fp16,
),
"lion": (
lib.clion_8bit_blockwise_grad_fp32,
lib.clion_8bit_blockwise_grad_fp16,
lib.clion_8bit_blockwise_grad_bf16,
),
"adagrad": (
lib.cadagrad_8bit_blockwise_grad_fp32,
lib.cadagrad_8bit_blockwise_grad_fp16,
),
}


Expand Down

0 comments on commit b22ae26

Please sign in to comment.