Support return_updates buffer option for 1-state optimizers
matthewdouglas committed Oct 30, 2024
1 parent 31854da commit b1fb85b
Showing 3 changed files with 33 additions and 33 deletions.
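
The change threads a new return_updates pointer through every 1-state optimizer kernel (32-bit, 8-bit, and 8-bit blockwise). When the pointer is null, behavior is unchanged and the parameters p are updated in place; when it is non-null, the kernels write the computed update into return_updates instead, starting from zero rather than from the current parameter value, and the weight-decay term is skipped, since it only makes sense when the update is folded into p. A minimal sketch of the pattern, using a simplified momentum update and hypothetical names (sgd_momentum_sketch is not part of the repository):

// Hedged sketch (not from the repository): the buffer-selection pattern this commit
// applies in each 1-state kernel, shown with a simplified momentum update.
__global__ void sgd_momentum_sketch(const float* g, float* p, float* return_updates,
                                    float* state1, float beta1, float lr,
                                    float weight_decay, int n)
{
    int i = blockIdx.x * blockDim.x + threadIdx.x;
    if (i >= n) return;

    float g_val = g[i];
    // Weight decay is only applied when updating p in place.
    if (weight_decay > 0.0f && return_updates == nullptr)
        g_val += p[i] * weight_decay;

    state1[i] = state1[i] * beta1 + g_val;  // momentum accumulator

    // In-place mode starts from the current parameter value; return-updates mode
    // starts from zero, so the output buffer holds only the scaled update.
    float base = (return_updates == nullptr) ? p[i] : 0.0f;
    float* out = (return_updates == nullptr) ? p : return_updates;
    out[i] = base - lr * state1[i];
}
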
50 changes: 25 additions & 25 deletions csrc/kernels.cu
@@ -1011,7 +1011,7 @@ __global__ void kPreconditionOptimizer32bit1State(T* g, T* p,

template<typename T, int OPTIMIZER>
__launch_bounds__(TH, 1)
__global__ void kOptimizer32bit1State(T *g, T *p,
__global__ void kOptimizer32bit1State(T *g, T *p, T *return_updates,
float *state1, float *unorm, const float max_unorm, const float param_norm,
const float beta1, const float beta2, const float eps, const float weight_decay,
const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n)
@@ -1057,13 +1057,13 @@ __global__ void kOptimizer32bit1State(T *g, T *p,
__syncthreads();
LoadFloat(temp_storage.loadf).Load(&(state1[i]), s1_vals, valid_items);
__syncthreads();
Load(temp_storage.load).Load(&(p[i]), p_vals, valid_items);
Load(temp_storage.load).Load(return_updates == nullptr ? &(p[i]) : &(return_updates[i]), p_vals, valid_items);

# pragma unroll 4
for(unsigned int j = 0; j < NUM_PER_THREAD; j++)
{
g_vals[j] = gnorm_scale*((float)g_vals[j]);
if(weight_decay > 0.0f)
if(weight_decay > 0.0f && return_updates == nullptr)
g_vals[j] = (float)g_vals[j] + (((float)p_vals[j])*weight_decay);
}

@@ -1080,26 +1080,26 @@ __global__ void kOptimizer32bit1State(T *g, T *p,
else
s1_vals[j] = s1_vals[j]*beta1 + ((float)g_vals[j]);

p_vals[j] = ((float)p_vals[j]) + update_scale*(-lr*(s1_vals[j]));
p_vals[j] = (return_updates == nullptr ? (float)p_vals[j] : 0.0f) + update_scale*(-lr*(s1_vals[j]));
break;
case LION:
p_vals[j] = ((float)p_vals[j]) - update_scale*(lr*sgn(((float)s1_vals[j])*beta1 + ((1.0f-beta1)*((float)g_vals[j]))));
p_vals[j] = (return_updates == nullptr ? (float)p_vals[j] : 0.0f) - update_scale*(lr*sgn(((float)s1_vals[j])*beta1 + ((1.0f-beta1)*((float)g_vals[j]))));
s1_vals[j] = s1_vals[j]*beta2 + ((1.0f-beta2)*((float)g_vals[j]));
break;
case RMSPROP:
s1_vals[j] = s1_vals[j]*beta1 + ((1.0f-beta1)*((float)g_vals[j])*((float)g_vals[j]));
p_vals[j] = ((float)p_vals[j]) - update_scale*(lr*__fdividef((float)g_vals[j],sqrtf((float)s1_vals[j])+eps));
p_vals[j] = (return_updates == nullptr ? (float)p_vals[j] : 0.0f) - update_scale*(lr*__fdividef((float)g_vals[j],sqrtf((float)s1_vals[j])+eps));
break;
case ADAGRAD:
s1_vals[j] = s1_vals[j] + ((float)g_vals[j])*((float)g_vals[j]);
p_vals[j] = ((float)p_vals[j]) - lr*__fdividef((float)g_vals[j],sqrtf((float)s1_vals[j])+eps);
p_vals[j] = (return_updates == nullptr ? (float)p_vals[j] : 0.0f) - lr*__fdividef((float)g_vals[j],sqrtf((float)s1_vals[j])+eps);
break;
}
}
}

__syncthreads();
Store(temp_storage.store).Store(&(p[i]), p_vals, valid_items);
Store(temp_storage.store).Store(return_updates == nullptr ? &(p[i]) : &(return_updates[i]), p_vals, valid_items);
__syncthreads();
StoreFloat(temp_storage.storef).Store(&(state1[i]), s1_vals, valid_items);
}
@@ -1447,7 +1447,7 @@ kPreconditionOptimizerStatic8bit1State(T* p, T* __restrict__ const g, unsigned c
template<typename T, int OPTIMIZER>
__global__ void
__launch_bounds__(1024, 1)
kOptimizerStatic8bit1State(T* p, T* const g, unsigned char* state1,
kOptimizerStatic8bit1State(T* p, T* const g, T* return_updates, unsigned char* state1,
const float *unorm, const float max_unorm, const float param_norm,
const float beta1, const float beta2,
const float eps, const int step, const float lr,
@@ -1503,7 +1503,7 @@ kOptimizerStatic8bit1State(T* p, T* const g, unsigned char* state1,
__syncthreads();
LoadChar(temp_storage.loadc).Load(&(state1[i]), c1s, valid_items, 128);
__syncthreads();
LoadT(temp_storage.loadh).Load(&(p[i]), p_vals, valid_items);
LoadT(temp_storage.loadh).Load(return_updates == nullptr ? &(p[i]) : &(return_updates[i]), p_vals, valid_items);

if((i + (threadIdx.x*NUM_PER_THREAD2) + NUM_PER_THREAD2) > n){ continue; }

@@ -1513,7 +1513,7 @@ kOptimizerStatic8bit1State(T* p, T* const g, unsigned char* state1,
g_val = float(g_vals[j]);
g_val *= gnorm_scale;

if(weight_decay > 0.0f) {
if(weight_decay > 0.0f && return_updates == nullptr) {
switch(OPTIMIZER) {
case ADAGRAD:
case MOMENTUM:
@@ -1536,15 +1536,15 @@ kOptimizerStatic8bit1State(T* p, T* const g, unsigned char* state1,
else
s1_vals[j] = s1_vals[j]*beta1 + ((float)g_vals[j]);

p_vals[j] = ((float)p_vals[j]) + (-lr*update_scale*(s1_vals[j]));
p_vals[j] = (return_updates == nullptr ? (float)p_vals[j] : 0.0f) + (-lr*update_scale*(s1_vals[j]));
break;
case LION:
p_vals[j] = ((float)p_vals[j]) - (lr*sgn(((float)s1_vals[j])*beta1 + ((1.0f-beta1)*((float)g_val))));
p_vals[j] = (return_updates == nullptr ? (float)p_vals[j] : 0.0f) - (lr*sgn(((float)s1_vals[j])*beta1 + ((1.0f-beta1)*((float)g_val))));
s1_vals[j] = s1_vals[j]*beta2 + ((1.0f-beta2)*g_val);
break;
case RMSPROP:
s1_vals[j] = s1_vals[j]*beta1 + ((1.0f-beta1)*(g_val*g_val));
p_vals[j] = ((float)p_vals[j]) - (lr*__fdividef(g_val,sqrtf(s1_vals[j])+eps));
p_vals[j] = (return_updates == nullptr ? (float)p_vals[j] : 0.0f) - (lr*__fdividef(g_val,sqrtf(s1_vals[j])+eps));
break;
}

@@ -1560,7 +1560,7 @@ kOptimizerStatic8bit1State(T* p, T* const g, unsigned char* state1,
}
}

StoreT(temp_storage.storeh).Store(&(p[i]), p_vals, valid_items);
StoreT(temp_storage.storeh).Store(return_updates == nullptr ? &(p[i]) : &(return_updates[i]), p_vals, valid_items);
__syncthreads();
StoreChar(temp_storage.storec).Store(&(state1[i]), c1s, valid_items);
__syncthreads();
@@ -1893,7 +1893,7 @@ kOptimizerStatic8bit2StateBlockwise(
template<typename T, int OPTIMIZER, int BLOCK_SIZE, int N_PER_TH>
__launch_bounds__(256, 3)
__global__ void
kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, unsigned char* state1,
kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, T* return_updates, unsigned char* state1,
const float beta1, const float beta2,
const float eps, const int step, const float lr,
float* __restrict__ const quantiles1,
@@ -1957,7 +1957,7 @@ kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, unsigned char
__syncthreads();
LoadChar(temp_storage.loadc).Load(&(state1[i]), c1s, valid_items, 128);
__syncthreads();
LoadT(temp_storage.loadh).Load(&(p[i]), p_vals, valid_items, (T)0.0f);
LoadT(temp_storage.loadh).Load(return_updates == nullptr ? &(p[i]) : &(return_updates[i]), p_vals, valid_items, (T)0.0f);

new_local_abs_max1 = -FLT_MAX;

Expand All @@ -1969,7 +1969,7 @@ kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, unsigned char
g_val *= gnorm_scale;
if(!skip_zeros || (skip_zeros && ((float)g_vals[j] != 0.0f)))
{
if(weight_decay > 0.0f) {
if(weight_decay > 0.0f && return_updates == nullptr) {
switch(OPTIMIZER) {
case MOMENTUM:
case ADAGRAD:
@@ -2032,18 +2032,18 @@ kOptimizerStatic8bit1StateBlockwise(T* p, T* __restrict__ const g, unsigned char
switch(OPTIMIZER)
{
case MOMENTUM:
p_vals[j] = ((float)p_vals[j]) - lr*(s1_vals[j]);
p_vals[j] = (return_updates == nullptr ? (float)p_vals[j] : 0.0f) - lr*(s1_vals[j]);
break;
case LION:
p_vals[j] = ((float)p_vals[j]) - ((float)g_vals[j]);
p_vals[j] = (return_updates == nullptr ? (float)p_vals[j] : 0.0f) - ((float)g_vals[j]);
break;
case RMSPROP:
g_val = g_vals[j];
p_vals[j] = ((float)p_vals[j]) - lr*(__fdividef(g_val, sqrtf(s1_vals[j])+eps));
p_vals[j] = (return_updates == nullptr ? (float)p_vals[j] : 0.0f) - lr*(__fdividef(g_val, sqrtf(s1_vals[j])+eps));
break;
case ADAGRAD:
g_val = g_vals[j];
p_vals[j] = ((float)p_vals[j]) - lr*(__fdividef(g_val, sqrtf(s1_vals[j])+eps));
p_vals[j] = (return_updates == nullptr ? (float)p_vals[j] : 0.0f) - lr*(__fdividef(g_val, sqrtf(s1_vals[j])+eps));
break;
}
}
@@ -3782,7 +3782,7 @@ MAKE_PreconditionOptimizer32bit1State(ADAGRAD, float)
MAKE_PreconditionOptimizer32bit1State(ADAGRAD, __nv_bfloat16)

#define MAKE_Optimizer32bit1State(oname, gtype) \
template __global__ void kOptimizer32bit1State<gtype, oname>(gtype* g, gtype* p, float* state1, float *unorm, const float max_unorm, const float param_norm, \
template __global__ void kOptimizer32bit1State<gtype, oname>(gtype* g, gtype* p, gtype* return_updates, float* state1, float *unorm, const float max_unorm, const float param_norm, \
const float beta1, const float beta2, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n); \

MAKE_Optimizer32bit1State(MOMENTUM, half)
@@ -3847,7 +3847,7 @@ MAKE_PreconditionStatic8bit1State(ADAGRAD, half)
MAKE_PreconditionStatic8bit1State(ADAGRAD, float)

#define MAKE_optimizerStatic8bit1State(oname, gtype) \
template __global__ void kOptimizerStatic8bit1State<gtype, oname>(gtype* p, gtype* const g, unsigned char* state1, \
template __global__ void kOptimizerStatic8bit1State<gtype, oname>(gtype* p, gtype* const g, gtype* return_updates, unsigned char* state1, \
const float *unorm, const float max_unorm, const float param_norm, \
const float beta1, \
const float beta2, \
@@ -4002,7 +4002,7 @@ MAKE_OptimizerStatic8bit2StateBlockwise(ADEMAMIX, __nv_bfloat16, 256, 1)

#define MAKE_OptimizerStatic8bit1StateBlockwise(oname, gtype, block_size, num_per_thread) \
template __global__ void kOptimizerStatic8bit1StateBlockwise<gtype, oname, block_size, num_per_thread>( \
gtype* p, gtype* __restrict__ const g, unsigned char* state1, \
gtype* p, gtype* __restrict__ const g, gtype* return_updates, unsigned char* state1, \
const float beta1, const float beta2, \
const float eps, const int step, const float lr, \
float* __restrict__ const quantiles1, \
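
For reference, the MAKE_* macros at the end of kernels.cu emit explicit template instantiations, which is why each (optimizer, dtype) pair above picks up the new parameter. For example, MAKE_Optimizer32bit1State(MOMENTUM, half) now expands to roughly:

template __global__ void kOptimizer32bit1State<half, MOMENTUM>(
    half* g, half* p, half* return_updates, float* state1,
    float* unorm, const float max_unorm, const float param_norm,
    const float beta1, const float beta2, const float eps, const float weight_decay,
    const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n);
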
6 changes: 3 additions & 3 deletions csrc/kernels.cuh
@@ -38,7 +38,7 @@ __global__ void kPreconditionOptimizer32bit1State(T* g, T* p,
const int step, const float lr, const float gnorm_scale, const int n);

template<typename T, int OPTIMIZER>
__global__ void kOptimizer32bit1State(T* g, T* p,
__global__ void kOptimizer32bit1State(T* g, T* p, T* return_updates,
float* state1, float *unorm, const float max_unorm, const float param_norm,
const float beta1, const float beta2, const float eps, const float weight_decay,
const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n);
@@ -57,7 +57,7 @@ kPreconditionOptimizerStatic8bit1State(T* p, T* __restrict__ const g, unsigned c

template<typename T, int OPTIMIZER>
__global__ void
kOptimizerStatic8bit1State(T* p, T* const g, unsigned char* state1,
kOptimizerStatic8bit1State(T* p, T* const g, T* return_updates, unsigned char* state1,
const float *unorm, const float max_unorm, const float param_norm,
const float beta1, const float beta2,
const float eps, const int step, const float lr,
@@ -96,7 +96,7 @@ template<typename T, int OPTIMIZER, int BLOCK_SIZE, int N_PER_TH> __global__ voi
float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale, const bool skip_zeros, const int n);

template<typename T, int OPTIMIZER, int BLOCK_SIZE, int N_PER_TH> __global__ void kOptimizerStatic8bit1StateBlockwise(
T* p, T* __restrict__ const g, unsigned char* state1,
T* p, T* __restrict__ const g, T* return_updates, unsigned char* state1,
const float beta1, const float beta2,
const float eps, const int step, const float lr,
float* __restrict__ const quantiles1,
10 changes: 5 additions & 5 deletions csrc/ops.cu
@@ -122,12 +122,12 @@ template<typename T, int OPTIMIZER> void optimizer32bit(T* g, T* p, T* return_up
CUDA_CHECK_RETURN(cudaPeekAtLastError());
}

kOptimizer32bit1State<T, OPTIMIZER><<<num_blocks, 1024>>>(g, p, state1, unorm, max_unorm, param_norm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n);
kOptimizer32bit1State<T, OPTIMIZER><<<num_blocks, 1024>>>(g, p, return_updates, state1, unorm, max_unorm, param_norm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n);
CUDA_CHECK_RETURN(cudaPeekAtLastError());
break;
case LION:
// in lion, the momentum update after the parameter update
kOptimizer32bit1State<T, OPTIMIZER><<<num_blocks, 1024>>>(g, p, state1, unorm, max_unorm, param_norm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n);
kOptimizer32bit1State<T, OPTIMIZER><<<num_blocks, 1024>>>(g, p, return_updates, state1, unorm, max_unorm, param_norm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n);
CUDA_CHECK_RETURN(cudaPeekAtLastError());

if(max_unorm > 0.0f)
@@ -172,13 +172,13 @@ template<typename T, int OPTIMIZER> void optimizerStatic8bit(T* p, T* g, T* retu
CUDA_CHECK_RETURN(cudaMemset(new_max1, 0, 1*sizeof(float)));
kPreconditionOptimizerStatic8bit1State<T, OPTIMIZER><<<num_blocks, 256>>>(p, g, state1, unorm, beta1, beta2, eps, step, quantiles1, max1, new_max1, weight_decay, gnorm_scale, n);
CUDA_CHECK_RETURN(cudaPeekAtLastError());
kOptimizerStatic8bit1State<T, OPTIMIZER><<<num_blocks, 1024>>>(p, g, state1, unorm, max_unorm, param_norm, beta1, beta2, eps, step, lr,
kOptimizerStatic8bit1State<T, OPTIMIZER><<<num_blocks, 1024>>>(p, g, return_updates, state1, unorm, max_unorm, param_norm, beta1, beta2, eps, step, lr,
quantiles1, max1, new_max1, weight_decay, gnorm_scale, n);
CUDA_CHECK_RETURN(cudaPeekAtLastError());
break;
case LION:
// in lion, the momentum update happens after the parameter update
kOptimizerStatic8bit1State<T, OPTIMIZER><<<num_blocks, 1024>>>(p, g, state1, unorm, max_unorm, param_norm, beta1, beta2, eps, step, lr,
kOptimizerStatic8bit1State<T, OPTIMIZER><<<num_blocks, 1024>>>(p, g, return_updates, state1, unorm, max_unorm, param_norm, beta1, beta2, eps, step, lr,
quantiles1, max1, new_max1, weight_decay, gnorm_scale, n);
CUDA_CHECK_RETURN(cudaPeekAtLastError());

@@ -239,7 +239,7 @@ template<typename T, int OPTIMIZER> void optimizerStatic8bitBlockwise(
case LION:
num_blocks = n/BLOCKSIZE_1STATE;
num_blocks = n % BLOCKSIZE_1STATE == 0 ? num_blocks : num_blocks + 1;
kOptimizerStatic8bit1StateBlockwise<T, OPTIMIZER, BLOCKSIZE_1STATE, NUM_1STATE><<<num_blocks, BLOCKSIZE_1STATE/NUM_1STATE>>>(p, g, state1, beta1, beta2, eps, step, lr,
kOptimizerStatic8bit1StateBlockwise<T, OPTIMIZER, BLOCKSIZE_1STATE, NUM_1STATE><<<num_blocks, BLOCKSIZE_1STATE/NUM_1STATE>>>(p, g, return_updates, state1, beta1, beta2, eps, step, lr,
quantiles1, absmax1, weight_decay, gnorm_scale, skip_zeros, n);
CUDA_CHECK_RETURN(cudaPeekAtLastError());
break;
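
On the host side, the launchers in ops.cu simply forward the pointer, so callers opt in per call. A hedged usage sketch follows; the buffer and variable names are illustrative, and the trailing optimizer arguments are shown as comments because only the leading parameters appear in this diff.

#include <cuda_runtime.h>
#include <cuda_fp16.h>

// Hypothetical caller; param, grad, state, and update_buf are illustrative names only.
void capture_updates_sketch(half* param, half* grad, float* state, int n)
{
    half* update_buf = nullptr;
    cudaMalloc(&update_buf, n * sizeof(half));  // receives the raw (-lr * ...) updates

    // Passing nullptr keeps the original in-place behavior:
    //   optimizer32bit<half, MOMENTUM>(grad, param, /*return_updates=*/nullptr, state, /* ... */);
    // Passing a real buffer makes the kernels write the updates there, leave param untouched,
    // and skip weight decay:
    //   optimizer32bit<half, MOMENTUM>(grad, param, update_buf, state, /* ... */);

    cudaFree(update_buf);
}
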
