diff --git a/bitsandbytes/optim/optimizer.py b/bitsandbytes/optim/optimizer.py index b1a8fbea2..3b21f09d2 100644 --- a/bitsandbytes/optim/optimizer.py +++ b/bitsandbytes/optim/optimizer.py @@ -5,11 +5,7 @@ from collections import abc as container_abcs, defaultdict from copy import deepcopy from itertools import chain -<<<<<<< HEAD from typing import Any, Dict, Optional -======= -from typing import Optional ->>>>>>> d964546 (Add AdEMAMix optimizer (#1360)) import torch diff --git a/csrc/kernels.cu b/csrc/kernels.cu index 017659518..27bf5a85e 100644 --- a/csrc/kernels.cu +++ b/csrc/kernels.cu @@ -1617,20 +1617,10 @@ __global__ void kPercentileClipping(T * __restrict__ g, float *gnorm_vec, int st template __launch_bounds__(256, 3) __global__ void -<<<<<<< HEAD -kOptimizerStatic8bit2StateBlockwise(T* p, T* __restrict__ const g, T* return_updates, - unsigned char* state1, unsigned char* state2, - const float beta1, const float beta2, - const float eps, const int step, const float lr, - float* __restrict__ const quantiles1, float* __restrict__ const quantiles2, - float* absmax1, float* absmax2, - float weight_decay, - const float gnorm_scale, const bool skip_zeros, const int n) -{ -======= kOptimizerStatic8bit2StateBlockwise( - T* p, + T* __restrict__ p, T* __restrict__ const g, + T* __restrict__ return_updates, unsigned char* state1, unsigned char* state2, const float beta1, @@ -1649,7 +1639,6 @@ kOptimizerStatic8bit2StateBlockwise( const bool skip_zeros, const int n ) { ->>>>>>> d964546 (Add AdEMAMix optimizer (#1360)) //const int n_full = n + (n%BLOCK_SIZE); const int n_full = gridDim.x * BLOCK_SIZE; @@ -1834,15 +1823,6 @@ kOptimizerStatic8bit2StateBlockwise( //if(!skip_zeros || (skip_zeros && ((float)g_vals[j] != 0.0f))) if(!isnan((float)g_vals[j]) && !isinf((float)g_vals[j])) { -<<<<<<< HEAD - if (return_updates == nullptr) { - p_vals[j] = (T)(((float)p_vals[j]) + ((step_size*(__fdividef(s1_vals[j],(sqrtf(s2_vals[j])+(correction2*eps))))))); - if(weight_decay > 
0.0f) - p_vals[j] = ((float)p_vals[j])*(1.0f-(lr*weight_decay)); - } else { - p_vals[j] = (T)(step_size*(__fdividef(s1_vals[j],(sqrtf(s2_vals[j])+(correction2*eps))))); - } -======= if (OPTIMIZER == ADEMAMIX) { p_vals[j] = T((float)p_vals[j] - lr * ( ((s1_vals[j] / correction1) + (alpha * s3_vals[j])) / ( @@ -1850,12 +1830,15 @@ kOptimizerStatic8bit2StateBlockwise( ) )); } else { - p_vals[j] = (T)(((float)p_vals[j]) + ((step_size*(__fdividef(s1_vals[j],(sqrtf(s2_vals[j])+(correction2*eps))))))); + if (return_updates == nullptr) { + p_vals[j] = (T)(((float)p_vals[j]) + ((step_size*(__fdividef(s1_vals[j],(sqrtf(s2_vals[j])+(correction2*eps))))))); + } else { + p_vals[j] = (T)(step_size*(__fdividef(s1_vals[j],(sqrtf(s2_vals[j])+(correction2*eps))))); + } } - if(weight_decay > 0.0f) + if (return_updates == nullptr && weight_decay > 0.0f) p_vals[j] = ((float)p_vals[j])*(1.0f-(lr*weight_decay)); ->>>>>>> d964546 (Add AdEMAMix optimizer (#1360)) } } @@ -3813,7 +3796,7 @@ MAKE_Optimizer32bit1State(ADAGRAD, float) MAKE_Optimizer32bit1State(ADAGRAD, __nv_bfloat16) #define MAKE_PreconditionOptimizer32bit2State(oname, gtype) \ -template __global__ void kPreconditionOptimizer32bit2State(gtype* g, gtype* p, \ +template __global__ void kPreconditionOptimizer32bit2State(gtype* g, gtype* p, \ float* state1, float* state2, float *unorm, \ const float beta1, const float beta2, const float eps, const float weight_decay, \ const int step, const float lr, const float gnorm_scale, const int n); \ @@ -3825,28 +3808,19 @@ MAKE_PreconditionOptimizer32bit2State(ADEMAMIX, float) MAKE_PreconditionOptimizer32bit2State(ADEMAMIX, half) MAKE_PreconditionOptimizer32bit2State(ADEMAMIX, __nv_bfloat16) -<<<<<<< HEAD template __global__ void kOptimizer32bit2State(float* g, float* p, float* return_updates, float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm, - const float beta1, const float beta2, const float eps, const float weight_decay,const int step, const float 
lr, const float gnorm_scale, const bool skip_zeros, const int n); -template __global__ void kOptimizer32bit2State(half* g, half* p, half* return_updates, float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm, - const float beta1, const float beta2, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n); -template __global__ void kOptimizer32bit2State<__nv_bfloat16, ADAM>(__nv_bfloat16* g, __nv_bfloat16* p, __nv_bfloat16* return_updates, float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm, - const float beta1, const float beta2, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n); -======= -template __global__ void kOptimizer32bit2State(float* g, float* p, float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm, const float beta1, const float beta2, const float beta3, const float alpha, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n); -template __global__ void kOptimizer32bit2State(half* g, half* p, float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm, +template __global__ void kOptimizer32bit2State(half* g, half* p, half* return_updates, float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm, const float beta1, const float beta2, const float beta3, const float alpha, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n); -template __global__ void kOptimizer32bit2State<__nv_bfloat16, ADAM>(__nv_bfloat16* g, __nv_bfloat16* p, float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm, +template __global__ void 
kOptimizer32bit2State<__nv_bfloat16, ADAM>(__nv_bfloat16* g, __nv_bfloat16* p, __nv_bfloat16* return_updates, float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm, const float beta1, const float beta2, const float beta3, const float alpha, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n); -template __global__ void kOptimizer32bit2State(float* g, float* p, float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm, +template __global__ void kOptimizer32bit2State(float* g, float* p, float* return_updates, float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm, const float beta1, const float beta2, const float beta3, const float alpha, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n); -template __global__ void kOptimizer32bit2State(half* g, half* p, float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm, +template __global__ void kOptimizer32bit2State(half* g, half* p, half* return_updates, float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm, const float beta1, const float beta2, const float beta3, const float alpha, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n); -template __global__ void kOptimizer32bit2State<__nv_bfloat16, ADEMAMIX>(__nv_bfloat16* g, __nv_bfloat16* p, float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm, +template __global__ void kOptimizer32bit2State<__nv_bfloat16, ADEMAMIX>(__nv_bfloat16* g, __nv_bfloat16* p, __nv_bfloat16* return_updates, float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm, const float beta1, const float beta2, const float
beta3, const float alpha, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n); ->>>>>>> d964546 (Add AdEMAMix optimizer (#1360)) #define MAKE_PreconditionStatic8bit1State(oname, gtype) \ template __global__ void kPreconditionOptimizerStatic8bit1State(gtype* p, gtype* __restrict__ const g, unsigned char*__restrict__ const state1, \ @@ -4006,14 +3980,9 @@ template __global__ void kDequantizeBlockwise<__nv_bfloat16, 512, 64, 8, General template __global__ void kDequantizeBlockwise<__nv_bfloat16, 512, 64, 8, NF4>(float *code, unsigned char * A, float * absmax, __nv_bfloat16 *out, const int blocksize, const int n); #define MAKE_OptimizerStatic8bit2StateBlockwise(oname, gtype, block_size, num_per_thread) \ -<<<<<<< HEAD template __global__ void kOptimizerStatic8bit2StateBlockwise(gtype* p, gtype* __restrict__ const g, gtype* return_updates, \ unsigned char* state1, unsigned char* state2, \ - const float beta1, const float beta2, \ -======= -template __global__ void kOptimizerStatic8bit2StateBlockwise(gtype* p, gtype* __restrict__ const g, unsigned char* state1, unsigned char* state2, \ const float beta1, const float beta2, const float beta3, const float alpha, \ ->>>>>>> d964546 (Add AdEMAMix optimizer (#1360)) const float eps, const int step, const float lr, \ float* __restrict__ const quantiles1, float* __restrict__ const quantiles2, \ float* absmax1, float* absmax2, \ diff --git a/csrc/kernels.cuh b/csrc/kernels.cuh index 8858ea6f9..793c523bd 100644 --- a/csrc/kernels.cuh +++ b/csrc/kernels.cuh @@ -89,13 +89,8 @@ kOptimizerStatic8bit2State(T* p, T* const g, T* return_updates, unsigned char* s float weight_decay, const float gnorm_scale, const int n); template __global__ void kOptimizerStatic8bit2StateBlockwise( -<<<<<<< HEAD - T* p, T* __restrict__ const g, T* return_updates, unsigned char* state1, unsigned char* state2, - const float beta1, const float beta2, const float eps, const int 
step, const float lr, -======= - T* p, T* __restrict__ const g, unsigned char* state1, unsigned char* state2, + T* p, T* __restrict__ const g, T* return_updates, unsigned char* state1, unsigned char* state2, const float beta1, const float beta2, const float beta3, const float alpha, const float eps, const int step, const float lr, ->>>>>>> d964546 (Add AdEMAMix optimizer (#1360)) float* __restrict__ const quantiles1, float* __restrict__ const quantiles2, float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale, const bool skip_zeros, const int n); diff --git a/csrc/ops.cu b/csrc/ops.cu index bd4919d3f..e3c99a875 100644 --- a/csrc/ops.cu +++ b/csrc/ops.cu @@ -109,11 +109,7 @@ template void optimizer32bit(T* g, T* p, T* return_up kPreconditionOptimizer32bit2State<<>>(g, p, state1, state2, unorm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, n); CUDA_CHECK_RETURN(cudaPeekAtLastError()); } -<<<<<<< HEAD - kOptimizer32bit2State<<>>(g, p, return_updates, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n); -======= - kOptimizer32bit2State<<>>(g, p, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, beta3, alpha, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n); ->>>>>>> d964546 (Add AdEMAMix optimizer (#1360)) + kOptimizer32bit2State<<>>(g, p, return_updates, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, beta3, alpha, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n); CUDA_CHECK_RETURN(cudaPeekAtLastError()); break; case MOMENTUM: @@ -200,15 +196,10 @@ template void optimizerStatic8bit(T* p, T* g, T* retu #define BLOCKSIZE_1STATE 256 #define NUM_1STATE 1 -<<<<<<< HEAD -template void optimizerStatic8bitBlockwise(T* p, T* g, T* return_updates, - unsigned char* state1, unsigned char* state2, float beta1, float beta2, float eps, int step, float lr, - float* quantiles1, float* quantiles2, float* absmax1, float* absmax2, float weight_decay, const 
float gnorm_scale, bool skip_zeros, int n) -{ -======= template void optimizerStatic8bitBlockwise( T* p, T* g, + T* return_updates, unsigned char* state1, unsigned char* state2, float beta1, @@ -227,7 +218,6 @@ template void optimizerStatic8bitBlockwise( bool skip_zeros, int n ) { ->>>>>>> d964546 (Add AdEMAMix optimizer (#1360)) int num_blocks = 0; switch(OPTIMIZER) @@ -236,16 +226,11 @@ template void optimizerStatic8bitBlockwise( case ADEMAMIX: num_blocks = n/BLOCKSIZE_2STATE; num_blocks = n % BLOCKSIZE_2STATE == 0 ? num_blocks : num_blocks + 1; -<<<<<<< HEAD - kOptimizerStatic8bit2StateBlockwise<<>>(p, g, return_updates, state1, state2, beta1, beta2, eps, step, lr, - quantiles1, quantiles2, absmax1, absmax2, weight_decay, gnorm_scale, skip_zeros, n); -======= kOptimizerStatic8bit2StateBlockwise<<>>( - p, g, state1, state2, beta1, beta2, beta3, alpha, eps, step, lr, + p, g, return_updates, state1, state2, beta1, beta2, beta3, alpha, eps, step, lr, quantiles1, quantiles2, absmax1, absmax2, weight_decay, gnorm_scale, skip_zeros, n ); ->>>>>>> d964546 (Add AdEMAMix optimizer (#1360)) CUDA_CHECK_RETURN(cudaPeekAtLastError()); break; case MOMENTUM: @@ -872,13 +857,8 @@ MAKE_optimizerStatic8bit(ADAGRAD, float) #define MAKE_optimizerStatic8bitBlockwise(gtype, optim_name) \ -<<<<<<< HEAD template void optimizerStatic8bitBlockwise(gtype* p, gtype* g, gtype* return_updates, \ - unsigned char* state1, unsigned char* state2, float beta1, float beta2, float eps, int step, float lr, \ -======= -template void optimizerStatic8bitBlockwise(gtype* p, gtype* g, \ unsigned char* state1, unsigned char* state2, float beta1, float beta2, float beta3, float alpha, float eps, int step, float lr, \ ->>>>>>> d964546 (Add AdEMAMix optimizer (#1360)) float* quantiles1, float* quantiles2, float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale, bool skip_zeros, int n); \ MAKE_optimizerStatic8bitBlockwise(half, ADAM); diff --git a/csrc/ops.cuh b/csrc/ops.cuh index 
1b4ae3fbf..65938e80d 100644 --- a/csrc/ops.cuh +++ b/csrc/ops.cuh @@ -163,13 +163,9 @@ template void optimizerStatic8bit(T* p, T* g, T* retu float weight_decay, const float gnorm_scale, int n); -<<<<<<< HEAD template void optimizerStatic8bitBlockwise(T* p, T* g, T* return_updates, - unsigned char* state1, unsigned char* state2, float beta1, float beta2, float eps, int step, float lr, -======= -template void optimizerStatic8bitBlockwise(T* p, T* g, - unsigned char* state1, unsigned char* state2, float beta1, float beta2, float beta3, float alpha, float eps, int step, float lr, ->>>>>>> d964546 (Add AdEMAMix optimizer (#1360)) + unsigned char* state1, unsigned char* state2, float beta1, float beta2, float beta3, float alpha, + float eps, int step, float lr, float* quantiles1, float* quantiles2, float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale, bool skip_zeros, int n); diff --git a/csrc/pythonInterface.cpp b/csrc/pythonInterface.cpp index 5f20aee7d..322299963 100644 --- a/csrc/pythonInterface.cpp +++ b/csrc/pythonInterface.cpp @@ -55,11 +55,7 @@ void fname##32bit_grad_##gbits(gtype *g, gtype *p, gtype *return_updates, \ const float beta1, const float beta2, const float beta3, const float alpha, \ const float eps, const float weight_decay, \ const int step, const float lr, float gnorm_scale, bool skip_zeros, const int n) \ -<<<<<<< HEAD -{ optimizer32bit(g, p, return_updates, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n); } \ -======= -{ optimizer32bit(g, p, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, beta3, alpha, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n); } \ ->>>>>>> d964546 (Add AdEMAMix optimizer (#1360)) +{ optimizer32bit(g, p, return_updates, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, beta3, alpha, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n); } \ MAKE_FUNC32(momentum, MOMENTUM, float, 32) 
MAKE_FUNC32(momentum, MOMENTUM, half, 16) @@ -101,17 +97,10 @@ MAKE_FUNC8(lion, LION, float, 32) MAKE_FUNC8(lion, LION, half, 16) #define MAKE_BLOCKWISE8(fname, optim_name, gtype, gbits) \ -<<<<<<< HEAD void fname##_8bit_blockwise_grad_##gbits(gtype* p, gtype* g, gtype* return_updates, \ - unsigned char* state1, unsigned char* state2, float beta1, float beta2, float eps, int step, float lr, \ - float* quantiles1, float* quantiles2, float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale, bool skip_zeros, int n)\ -{ optimizerStatic8bitBlockwise(p, g, return_updates, state1, state2, beta1, beta2, eps, step, lr, quantiles1, quantiles2, absmax1, absmax2, weight_decay, gnorm_scale, skip_zeros, n); }\ -======= -void fname##_8bit_blockwise_grad_##gbits(gtype* p, gtype* g, \ unsigned char* state1, unsigned char* state2, float beta1, float beta2, float beta3, float alpha, float eps, int step, float lr, \ float* quantiles1, float* quantiles2, float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale, bool skip_zeros, int n)\ -{ optimizerStatic8bitBlockwise(p, g, state1, state2, beta1, beta2, beta3, alpha, eps, step, lr, quantiles1, quantiles2, absmax1, absmax2, weight_decay, gnorm_scale, skip_zeros, n); }\ ->>>>>>> d964546 (Add AdEMAMix optimizer (#1360)) +{ optimizerStatic8bitBlockwise(p, g, return_updates, state1, state2, beta1, beta2, beta3, alpha, eps, step, lr, quantiles1, quantiles2, absmax1, absmax2, weight_decay, gnorm_scale, skip_zeros, n); }\ MAKE_BLOCKWISE8(adam, ADAM, half, fp16) MAKE_BLOCKWISE8(adam, ADAM, __nv_bfloat16, bf16) @@ -249,11 +238,7 @@ extern "C" const float beta1, const float beta2, const float beta3, const float alpha, \ const float eps, const float weight_decay, \ const int step, const float lr, const float gnorm_scale, bool skip_zeros, const int n) \ -<<<<<<< HEAD - { name##32bit_grad_##gbits(g, p, return_updates, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, weight_decay, step, lr, 
gnorm_scale, skip_zeros, n); } \ -======= - { name##32bit_grad_##gbits(g, p, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, beta3, alpha, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n); } \ ->>>>>>> d964546 (Add AdEMAMix optimizer (#1360)) + { name##32bit_grad_##gbits(g, p, return_updates, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, beta3, alpha, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n); } \ MAKE_CFUNC32(adam, float, fp32) MAKE_CFUNC32(adam, half, fp16) @@ -295,17 +280,11 @@ extern "C" MAKE_CFUNC8(lion, half, 16) #define MAKE_CBLOCKWISE8(fname, optim_name, gtype, gbits) \ -<<<<<<< HEAD void c##fname##_8bit_blockwise_grad_##gbits(gtype* p, gtype* g, gtype* return_updates, \ - unsigned char* state1, unsigned char* state2, float beta1, float beta2, float eps, int step, float lr, \ - float* quantiles1, float* quantiles2, float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale, bool skip_zeros, int n) \ - { fname##_8bit_blockwise_grad_##gbits(p, g, return_updates, state1, state2, beta1, beta2, eps, step, lr, quantiles1, quantiles2, absmax1, absmax2, weight_decay, gnorm_scale, skip_zeros, n); } \ -======= - void c##fname##_8bit_blockwise_grad_##gbits(gtype* p, gtype* g, \ unsigned char* state1, unsigned char* state2, float beta1, float beta2, float beta3, float alpha, float eps, int step, float lr, \ float* quantiles1, float* quantiles2, float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale, bool skip_zeros, int n) \ - { fname##_8bit_blockwise_grad_##gbits(p, g, state1, state2, beta1, beta2, beta3, alpha, eps, step, lr, quantiles1, quantiles2, absmax1, absmax2, weight_decay, gnorm_scale, skip_zeros, n); } \ ->>>>>>> d964546 (Add AdEMAMix optimizer (#1360)) + { fname##_8bit_blockwise_grad_##gbits(p, g, return_updates, state1, state2, beta1, beta2, beta3, alpha, eps, step, lr, quantiles1, quantiles2, absmax1, absmax2, weight_decay, gnorm_scale, skip_zeros, n); } \ + 
MAKE_CBLOCKWISE8(adam, ADAM, half, fp16) MAKE_CBLOCKWISE8(adam, ADAM, float, fp32)