Skip to content

Commit

Permalink
Rebase on main - resolve conflicts
Browse files Browse the repository at this point in the history
  • Loading branch information
matthewdouglas committed Oct 28, 2024
1 parent 59883ac commit 61189fc
Show file tree
Hide file tree
Showing 6 changed files with 25 additions and 110 deletions.
4 changes: 0 additions & 4 deletions bitsandbytes/optim/optimizer.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,11 +5,7 @@
from collections import abc as container_abcs, defaultdict
from copy import deepcopy
from itertools import chain
<<<<<<< HEAD
from typing import Any, Dict, Optional
=======
from typing import Optional
>>>>>>> d964546 (Add AdEMAMix optimizer (#1360))

import torch

Expand Down
59 changes: 14 additions & 45 deletions csrc/kernels.cu
Original file line number Diff line number Diff line change
Expand Up @@ -1617,20 +1617,10 @@ __global__ void kPercentileClipping(T * __restrict__ g, float *gnorm_vec, int st
template<typename T, int OPTIMIZER, int BLOCK_SIZE, int N_PER_TH>
__launch_bounds__(256, 3)
__global__ void
<<<<<<< HEAD
kOptimizerStatic8bit2StateBlockwise(T* p, T* __restrict__ const g, T* return_updates,
unsigned char* state1, unsigned char* state2,
const float beta1, const float beta2,
const float eps, const int step, const float lr,
float* __restrict__ const quantiles1, float* __restrict__ const quantiles2,
float* absmax1, float* absmax2,
float weight_decay,
const float gnorm_scale, const bool skip_zeros, const int n)
{
=======
kOptimizerStatic8bit2StateBlockwise(
T* p,
T* __restrict__ p,
T* __restrict__ const g,
T* __restrict__ return_updates,
unsigned char* state1,
unsigned char* state2,
const float beta1,
Expand All @@ -1649,7 +1639,6 @@ kOptimizerStatic8bit2StateBlockwise(
const bool skip_zeros,
const int n
) {
>>>>>>> d964546 (Add AdEMAMix optimizer (#1360))

//const int n_full = n + (n%BLOCK_SIZE);
const int n_full = gridDim.x * BLOCK_SIZE;
Expand Down Expand Up @@ -1834,28 +1823,22 @@ kOptimizerStatic8bit2StateBlockwise(
//if(!skip_zeros || (skip_zeros && ((float)g_vals[j] != 0.0f)))
if(!isnan((float)g_vals[j]) && !isinf((float)g_vals[j]))
{
<<<<<<< HEAD
if (return_updates == nullptr) {
p_vals[j] = (T)(((float)p_vals[j]) + ((step_size*(__fdividef(s1_vals[j],(sqrtf(s2_vals[j])+(correction2*eps)))))));
if(weight_decay > 0.0f)
p_vals[j] = ((float)p_vals[j])*(1.0f-(lr*weight_decay));
} else {
p_vals[j] = (T)(step_size*(__fdividef(s1_vals[j],(sqrtf(s2_vals[j])+(correction2*eps)))));
}
=======
if (OPTIMIZER == ADEMAMIX) {
p_vals[j] = T((float)p_vals[j] - lr * (
((s1_vals[j] / correction1) + (alpha * s3_vals[j])) / (
(sqrtf(s2_vals[j]) / correction2) + eps
)
));
} else {
p_vals[j] = (T)(((float)p_vals[j]) + ((step_size*(__fdividef(s1_vals[j],(sqrtf(s2_vals[j])+(correction2*eps)))))));
if (return_updates == nullptr) {
p_vals[j] = (T)(((float)p_vals[j]) + ((step_size*(__fdividef(s1_vals[j],(sqrtf(s2_vals[j])+(correction2*eps)))))));
} else {
p_vals[j] = (T)(step_size*(__fdividef(s1_vals[j],(sqrtf(s2_vals[j])+(correction2*eps)))));
}
}

if(weight_decay > 0.0f)
if (return_updates == nullptr && weight_decay > 0.0f)
p_vals[j] = ((float)p_vals[j])*(1.0f-(lr*weight_decay));
>>>>>>> d964546 (Add AdEMAMix optimizer (#1360))
}
}

Expand Down Expand Up @@ -3813,7 +3796,7 @@ MAKE_Optimizer32bit1State(ADAGRAD, float)
MAKE_Optimizer32bit1State(ADAGRAD, __nv_bfloat16)

#define MAKE_PreconditionOptimizer32bit2State(oname, gtype) \
template __global__ void kPreconditionOptimizer32bit2State<gtype, oname, 4096, 8>(gtype* g, gtype* p, \
template __global__ void kPreconditionOptimizer32bit2State<gtype, oname, 4096, 8>(gtype* g, gtype* p, \
float* state1, float* state2, float *unorm, \
const float beta1, const float beta2, const float eps, const float weight_decay, \
const int step, const float lr, const float gnorm_scale, const int n); \
Expand All @@ -3825,28 +3808,19 @@ MAKE_PreconditionOptimizer32bit2State(ADEMAMIX, float)
MAKE_PreconditionOptimizer32bit2State(ADEMAMIX, half)
MAKE_PreconditionOptimizer32bit2State(ADEMAMIX, __nv_bfloat16)

<<<<<<< HEAD
template __global__ void kOptimizer32bit2State<float, ADAM>(float* g, float* p, float* return_updates, float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm,
const float beta1, const float beta2, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n);
template __global__ void kOptimizer32bit2State<half, ADAM>(half* g, half* p, half* return_updates, float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm,
const float beta1, const float beta2, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n);
template __global__ void kOptimizer32bit2State<__nv_bfloat16, ADAM>(__nv_bfloat16* g, __nv_bfloat16* p, __nv_bfloat16* return_updates, float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm,
const float beta1, const float beta2, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n);
=======
template __global__ void kOptimizer32bit2State<float, ADAM>(float* g, float* p, float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm,
const float beta1, const float beta2, const float beta3, const float alpha, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n);
template __global__ void kOptimizer32bit2State<half, ADAM>(half* g, half* p, float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm,
template __global__ void kOptimizer32bit2State<half, ADAM>(half* g, half* p, half* return_updates, float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm,
const float beta1, const float beta2, const float beta3, const float alpha, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n);
template __global__ void kOptimizer32bit2State<__nv_bfloat16, ADAM>(__nv_bfloat16* g, __nv_bfloat16* p, float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm,
template __global__ void kOptimizer32bit2State<__nv_bfloat16, ADAM>(__nv_bfloat16* g, __nv_bfloat16* p, __nv_bfloat16* return_updates, float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm,
const float beta1, const float beta2, const float beta3, const float alpha, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n);
template __global__ void kOptimizer32bit2State<float, ADEMAMIX>(float* g, float* p, float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm,
template __global__ void kOptimizer32bit2State<float, ADEMAMIX>(float* g, float* p, float* return_updates, float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm,
const float beta1, const float beta2, const float beta3, const float alpha, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n);
template __global__ void kOptimizer32bit2State<half, ADEMAMIX>(half* g, half* p, float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm,
template __global__ void kOptimizer32bit2State<half, ADEMAMIX>(half* g, half* p, half* return_updates,float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm,
const float beta1, const float beta2, const float beta3, const float alpha, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n);
template __global__ void kOptimizer32bit2State<__nv_bfloat16, ADEMAMIX>(__nv_bfloat16* g, __nv_bfloat16* p, float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm,
template __global__ void kOptimizer32bit2State<__nv_bfloat16, ADEMAMIX>(__nv_bfloat16* g, __nv_bfloat16* p, __nv_bfloat16* return_updates, float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm,
const float beta1, const float beta2, const float beta3, const float alpha, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n);

>>>>>>> d964546 (Add AdEMAMix optimizer (#1360))

#define MAKE_PreconditionStatic8bit1State(oname, gtype) \
template __global__ void kPreconditionOptimizerStatic8bit1State<gtype, oname>(gtype* p, gtype* __restrict__ const g, unsigned char*__restrict__ const state1, \
Expand Down Expand Up @@ -4006,14 +3980,9 @@ template __global__ void kDequantizeBlockwise<__nv_bfloat16, 512, 64, 8, General
template __global__ void kDequantizeBlockwise<__nv_bfloat16, 512, 64, 8, NF4>(float *code, unsigned char * A, float * absmax, __nv_bfloat16 *out, const int blocksize, const int n);

#define MAKE_OptimizerStatic8bit2StateBlockwise(oname, gtype, block_size, num_per_thread) \
<<<<<<< HEAD
template __global__ void kOptimizerStatic8bit2StateBlockwise<gtype, oname, block_size, num_per_thread>(gtype* p, gtype* __restrict__ const g, gtype* return_updates, \
unsigned char* state1, unsigned char* state2, \
const float beta1, const float beta2, \
=======
template __global__ void kOptimizerStatic8bit2StateBlockwise<gtype, oname, block_size, num_per_thread>(gtype* p, gtype* __restrict__ const g, unsigned char* state1, unsigned char* state2, \
const float beta1, const float beta2, const float beta3, const float alpha, \
>>>>>>> d964546 (Add AdEMAMix optimizer (#1360))
const float eps, const int step, const float lr, \
float* __restrict__ const quantiles1, float* __restrict__ const quantiles2, \
float* absmax1, float* absmax2, \
Expand Down
7 changes: 1 addition & 6 deletions csrc/kernels.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -89,13 +89,8 @@ kOptimizerStatic8bit2State(T* p, T* const g, T* return_updates, unsigned char* s
float weight_decay, const float gnorm_scale, const int n);

template<typename T, int OPTIMIZER, int BLOCK_SIZE, int N_PER_TH> __global__ void kOptimizerStatic8bit2StateBlockwise(
<<<<<<< HEAD
T* p, T* __restrict__ const g, T* return_updates, unsigned char* state1, unsigned char* state2,
const float beta1, const float beta2, const float eps, const int step, const float lr,
=======
T* p, T* __restrict__ const g, unsigned char* state1, unsigned char* state2,
T* p, T* __restrict__ const g, T* return_updates, unsigned char* state1, unsigned char* state2,
const float beta1, const float beta2, const float beta3, const float alpha, const float eps, const int step, const float lr,
>>>>>>> d964546 (Add AdEMAMix optimizer (#1360))
float* __restrict__ const quantiles1, float* __restrict__ const quantiles2,
float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale, const bool skip_zeros, const int n);

Expand Down
26 changes: 3 additions & 23 deletions csrc/ops.cu
Original file line number Diff line number Diff line change
Expand Up @@ -109,11 +109,7 @@ template<typename T, int OPTIMIZER> void optimizer32bit(T* g, T* p, T* return_up
kPreconditionOptimizer32bit2State<T, OPTIMIZER, 4096, 8><<<num_blocks, 512>>>(g, p, state1, state2, unorm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, n);
CUDA_CHECK_RETURN(cudaPeekAtLastError());
}
<<<<<<< HEAD
kOptimizer32bit2State<T, OPTIMIZER><<<num_blocks, 1024>>>(g, p, return_updates, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n);
=======
kOptimizer32bit2State<T, OPTIMIZER><<<num_blocks, 1024>>>(g, p, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, beta3, alpha, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n);
>>>>>>> d964546 (Add AdEMAMix optimizer (#1360))
kOptimizer32bit2State<T, OPTIMIZER><<<num_blocks, 1024>>>(g, p, return_updates, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, beta3, alpha, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n);
CUDA_CHECK_RETURN(cudaPeekAtLastError());
break;
case MOMENTUM:
Expand Down Expand Up @@ -200,15 +196,10 @@ template<typename T, int OPTIMIZER> void optimizerStatic8bit(T* p, T* g, T* retu
#define BLOCKSIZE_1STATE 256
#define NUM_1STATE 1

<<<<<<< HEAD
template<typename T, int OPTIMIZER> void optimizerStatic8bitBlockwise(T* p, T* g, T* return_updates,
unsigned char* state1, unsigned char* state2, float beta1, float beta2, float eps, int step, float lr,
float* quantiles1, float* quantiles2, float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale, bool skip_zeros, int n)
{
=======
template<typename T, int OPTIMIZER> void optimizerStatic8bitBlockwise(
T* p,
T* g,
T* return_updates,
unsigned char* state1,
unsigned char* state2,
float beta1,
Expand All @@ -227,7 +218,6 @@ template<typename T, int OPTIMIZER> void optimizerStatic8bitBlockwise(
bool skip_zeros,
int n
) {
>>>>>>> d964546 (Add AdEMAMix optimizer (#1360))

int num_blocks = 0;
switch(OPTIMIZER)
Expand All @@ -236,16 +226,11 @@ template<typename T, int OPTIMIZER> void optimizerStatic8bitBlockwise(
case ADEMAMIX:
num_blocks = n/BLOCKSIZE_2STATE;
num_blocks = n % BLOCKSIZE_2STATE == 0 ? num_blocks : num_blocks + 1;
<<<<<<< HEAD
kOptimizerStatic8bit2StateBlockwise<T, OPTIMIZER, BLOCKSIZE_2STATE, NUM_2STATE><<<num_blocks, BLOCKSIZE_2STATE/NUM_2STATE>>>(p, g, return_updates, state1, state2, beta1, beta2, eps, step, lr,
quantiles1, quantiles2, absmax1, absmax2, weight_decay, gnorm_scale, skip_zeros, n);
=======
kOptimizerStatic8bit2StateBlockwise<T, OPTIMIZER, BLOCKSIZE_2STATE, NUM_2STATE><<<num_blocks, BLOCKSIZE_2STATE/NUM_2STATE>>>(
p, g, state1, state2, beta1, beta2, beta3, alpha, eps, step, lr,
p, g, return_updates, state1, state2, beta1, beta2, beta3, alpha, eps, step, lr,
quantiles1, quantiles2, absmax1, absmax2, weight_decay, gnorm_scale,
skip_zeros, n
);
>>>>>>> d964546 (Add AdEMAMix optimizer (#1360))
CUDA_CHECK_RETURN(cudaPeekAtLastError());
break;
case MOMENTUM:
Expand Down Expand Up @@ -872,13 +857,8 @@ MAKE_optimizerStatic8bit(ADAGRAD, float)


#define MAKE_optimizerStatic8bitBlockwise(gtype, optim_name) \
<<<<<<< HEAD
template void optimizerStatic8bitBlockwise<gtype, optim_name>(gtype* p, gtype* g, gtype* return_updates, \
unsigned char* state1, unsigned char* state2, float beta1, float beta2, float eps, int step, float lr, \
=======
template void optimizerStatic8bitBlockwise<gtype, optim_name>(gtype* p, gtype* g, \
unsigned char* state1, unsigned char* state2, float beta1, float beta2, float beta3, float alpha, float eps, int step, float lr, \
>>>>>>> d964546 (Add AdEMAMix optimizer (#1360))
float* quantiles1, float* quantiles2, float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale, bool skip_zeros, int n); \

MAKE_optimizerStatic8bitBlockwise(half, ADAM);
Expand Down
8 changes: 2 additions & 6 deletions csrc/ops.cuh
Original file line number Diff line number Diff line change
Expand Up @@ -163,13 +163,9 @@ template<typename T, int OPTIMIZER> void optimizerStatic8bit(T* p, T* g, T* retu
float weight_decay,
const float gnorm_scale, int n);

<<<<<<< HEAD
template<typename T, int OPTIMIZER> void optimizerStatic8bitBlockwise(T* p, T* g, T* return_updates,
unsigned char* state1, unsigned char* state2, float beta1, float beta2, float eps, int step, float lr,
=======
template<typename T, int OPTIMIZER> void optimizerStatic8bitBlockwise(T* p, T* g,
unsigned char* state1, unsigned char* state2, float beta1, float beta2, float beta3, float alpha, float eps, int step, float lr,
>>>>>>> d964546 (Add AdEMAMix optimizer (#1360))
unsigned char* state1, unsigned char* state2, float beta1, float beta2, float beta3, float alpha,
float eps, int step, float lr,
float* quantiles1, float* quantiles2, float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale,
bool skip_zeros, int n);

Expand Down
Loading

0 comments on commit 61189fc

Please sign in to comment.