Skip to content

Commit

Permalink
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Rebase on main - resolve conflicts
Browse files Browse the repository at this point in the history
matthewdouglas committed Oct 28, 2024
1 parent 59883ac commit 61189fc
Showing 6 changed files with 25 additions and 110 deletions.
4 changes: 0 additions & 4 deletions bitsandbytes/optim/optimizer.py
Original file line number Diff line number Diff line change
@@ -5,11 +5,7 @@
from collections import abc as container_abcs, defaultdict
from copy import deepcopy
from itertools import chain
<<<<<<< HEAD
from typing import Any, Dict, Optional
=======
from typing import Optional
>>>>>>> d964546 (Add AdEMAMix optimizer (#1360))

import torch

59 changes: 14 additions & 45 deletions csrc/kernels.cu
Original file line number Diff line number Diff line change
@@ -1617,20 +1617,10 @@ __global__ void kPercentileClipping(T * __restrict__ g, float *gnorm_vec, int st
template<typename T, int OPTIMIZER, int BLOCK_SIZE, int N_PER_TH>
__launch_bounds__(256, 3)
__global__ void
<<<<<<< HEAD
kOptimizerStatic8bit2StateBlockwise(T* p, T* __restrict__ const g, T* return_updates,
unsigned char* state1, unsigned char* state2,
const float beta1, const float beta2,
const float eps, const int step, const float lr,
float* __restrict__ const quantiles1, float* __restrict__ const quantiles2,
float* absmax1, float* absmax2,
float weight_decay,
const float gnorm_scale, const bool skip_zeros, const int n)
{
=======
kOptimizerStatic8bit2StateBlockwise(
T* p,
T* __restrict__ p,
T* __restrict__ const g,
T* __restrict__ return_updates,
unsigned char* state1,
unsigned char* state2,
const float beta1,
@@ -1649,7 +1639,6 @@ kOptimizerStatic8bit2StateBlockwise(
const bool skip_zeros,
const int n
) {
>>>>>>> d964546 (Add AdEMAMix optimizer (#1360))

//const int n_full = n + (n%BLOCK_SIZE);
const int n_full = gridDim.x * BLOCK_SIZE;
@@ -1834,28 +1823,22 @@ kOptimizerStatic8bit2StateBlockwise(
//if(!skip_zeros || (skip_zeros && ((float)g_vals[j] != 0.0f)))
if(!isnan((float)g_vals[j]) && !isinf((float)g_vals[j]))
{
<<<<<<< HEAD
if (return_updates == nullptr) {
p_vals[j] = (T)(((float)p_vals[j]) + ((step_size*(__fdividef(s1_vals[j],(sqrtf(s2_vals[j])+(correction2*eps)))))));
if(weight_decay > 0.0f)
p_vals[j] = ((float)p_vals[j])*(1.0f-(lr*weight_decay));
} else {
p_vals[j] = (T)(step_size*(__fdividef(s1_vals[j],(sqrtf(s2_vals[j])+(correction2*eps)))));
}
=======
if (OPTIMIZER == ADEMAMIX) {
p_vals[j] = T((float)p_vals[j] - lr * (
((s1_vals[j] / correction1) + (alpha * s3_vals[j])) / (
(sqrtf(s2_vals[j]) / correction2) + eps
)
));
} else {
p_vals[j] = (T)(((float)p_vals[j]) + ((step_size*(__fdividef(s1_vals[j],(sqrtf(s2_vals[j])+(correction2*eps)))))));
if (return_updates == nullptr) {
p_vals[j] = (T)(((float)p_vals[j]) + ((step_size*(__fdividef(s1_vals[j],(sqrtf(s2_vals[j])+(correction2*eps)))))));
} else {
p_vals[j] = (T)(step_size*(__fdividef(s1_vals[j],(sqrtf(s2_vals[j])+(correction2*eps)))));
}
}

if(weight_decay > 0.0f)
if (return_updates == nullptr && weight_decay > 0.0f)
p_vals[j] = ((float)p_vals[j])*(1.0f-(lr*weight_decay));
>>>>>>> d964546 (Add AdEMAMix optimizer (#1360))
}
}

@@ -3813,7 +3796,7 @@ MAKE_Optimizer32bit1State(ADAGRAD, float)
MAKE_Optimizer32bit1State(ADAGRAD, __nv_bfloat16)

#define MAKE_PreconditionOptimizer32bit2State(oname, gtype) \
template __global__ void kPreconditionOptimizer32bit2State<gtype, oname, 4096, 8>(gtype* g, gtype* p, \
template __global__ void kPreconditionOptimizer32bit2State<gtype, oname, 4096, 8>(gtype* g, gtype* p, \
float* state1, float* state2, float *unorm, \
const float beta1, const float beta2, const float eps, const float weight_decay, \
const int step, const float lr, const float gnorm_scale, const int n); \
@@ -3825,28 +3808,19 @@ MAKE_PreconditionOptimizer32bit2State(ADEMAMIX, float)
MAKE_PreconditionOptimizer32bit2State(ADEMAMIX, half)
MAKE_PreconditionOptimizer32bit2State(ADEMAMIX, __nv_bfloat16)

<<<<<<< HEAD
template __global__ void kOptimizer32bit2State<float, ADAM>(float* g, float* p, float* return_updates, float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm,
const float beta1, const float beta2, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n);
template __global__ void kOptimizer32bit2State<half, ADAM>(half* g, half* p, half* return_updates, float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm,
const float beta1, const float beta2, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n);
template __global__ void kOptimizer32bit2State<__nv_bfloat16, ADAM>(__nv_bfloat16* g, __nv_bfloat16* p, __nv_bfloat16* return_updates, float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm,
const float beta1, const float beta2, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n);
=======
template __global__ void kOptimizer32bit2State<float, ADAM>(float* g, float* p, float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm,
const float beta1, const float beta2, const float beta3, const float alpha, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n);
template __global__ void kOptimizer32bit2State<half, ADAM>(half* g, half* p, float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm,
template __global__ void kOptimizer32bit2State<half, ADAM>(half* g, half* p, half* return_updates, float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm,
const float beta1, const float beta2, const float beta3, const float alpha, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n);
template __global__ void kOptimizer32bit2State<__nv_bfloat16, ADAM>(__nv_bfloat16* g, __nv_bfloat16* p, float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm,
template __global__ void kOptimizer32bit2State<__nv_bfloat16, ADAM>(__nv_bfloat16* g, __nv_bfloat16* p, __nv_bfloat16* return_updates, float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm,
const float beta1, const float beta2, const float beta3, const float alpha, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n);
template __global__ void kOptimizer32bit2State<float, ADEMAMIX>(float* g, float* p, float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm,
template __global__ void kOptimizer32bit2State<float, ADEMAMIX>(float* g, float* p, float* return_updates, float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm,
const float beta1, const float beta2, const float beta3, const float alpha, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n);
template __global__ void kOptimizer32bit2State<half, ADEMAMIX>(half* g, half* p, float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm,
template __global__ void kOptimizer32bit2State<half, ADEMAMIX>(half* g, half* p, half* return_updates,float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm,
const float beta1, const float beta2, const float beta3, const float alpha, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n);
template __global__ void kOptimizer32bit2State<__nv_bfloat16, ADEMAMIX>(__nv_bfloat16* g, __nv_bfloat16* p, float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm,
template __global__ void kOptimizer32bit2State<__nv_bfloat16, ADEMAMIX>(__nv_bfloat16* g, __nv_bfloat16* p, __nv_bfloat16* return_updates, float* state1, float* state2, float *unorm, const float max_unorm, const float param_norm,
const float beta1, const float beta2, const float beta3, const float alpha, const float eps, const float weight_decay,const int step, const float lr, const float gnorm_scale, const bool skip_zeros, const int n);

>>>>>>> d964546 (Add AdEMAMix optimizer (#1360))

#define MAKE_PreconditionStatic8bit1State(oname, gtype) \
template __global__ void kPreconditionOptimizerStatic8bit1State<gtype, oname>(gtype* p, gtype* __restrict__ const g, unsigned char*__restrict__ const state1, \
@@ -4006,14 +3980,9 @@ template __global__ void kDequantizeBlockwise<__nv_bfloat16, 512, 64, 8, General
template __global__ void kDequantizeBlockwise<__nv_bfloat16, 512, 64, 8, NF4>(float *code, unsigned char * A, float * absmax, __nv_bfloat16 *out, const int blocksize, const int n);

#define MAKE_OptimizerStatic8bit2StateBlockwise(oname, gtype, block_size, num_per_thread) \
<<<<<<< HEAD
template __global__ void kOptimizerStatic8bit2StateBlockwise<gtype, oname, block_size, num_per_thread>(gtype* p, gtype* __restrict__ const g, gtype* return_updates, \
unsigned char* state1, unsigned char* state2, \
const float beta1, const float beta2, \
=======
template __global__ void kOptimizerStatic8bit2StateBlockwise<gtype, oname, block_size, num_per_thread>(gtype* p, gtype* __restrict__ const g, unsigned char* state1, unsigned char* state2, \
const float beta1, const float beta2, const float beta3, const float alpha, \
>>>>>>> d964546 (Add AdEMAMix optimizer (#1360))
const float eps, const int step, const float lr, \
float* __restrict__ const quantiles1, float* __restrict__ const quantiles2, \
float* absmax1, float* absmax2, \
7 changes: 1 addition & 6 deletions csrc/kernels.cuh
Original file line number Diff line number Diff line change
@@ -89,13 +89,8 @@ kOptimizerStatic8bit2State(T* p, T* const g, T* return_updates, unsigned char* s
float weight_decay, const float gnorm_scale, const int n);

template<typename T, int OPTIMIZER, int BLOCK_SIZE, int N_PER_TH> __global__ void kOptimizerStatic8bit2StateBlockwise(
<<<<<<< HEAD
T* p, T* __restrict__ const g, T* return_updates, unsigned char* state1, unsigned char* state2,
const float beta1, const float beta2, const float eps, const int step, const float lr,
=======
T* p, T* __restrict__ const g, unsigned char* state1, unsigned char* state2,
T* p, T* __restrict__ const g, T* return_updates, unsigned char* state1, unsigned char* state2,
const float beta1, const float beta2, const float beta3, const float alpha, const float eps, const int step, const float lr,
>>>>>>> d964546 (Add AdEMAMix optimizer (#1360))
float* __restrict__ const quantiles1, float* __restrict__ const quantiles2,
float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale, const bool skip_zeros, const int n);

26 changes: 3 additions & 23 deletions csrc/ops.cu
Original file line number Diff line number Diff line change
@@ -109,11 +109,7 @@ template<typename T, int OPTIMIZER> void optimizer32bit(T* g, T* p, T* return_up
kPreconditionOptimizer32bit2State<T, OPTIMIZER, 4096, 8><<<num_blocks, 512>>>(g, p, state1, state2, unorm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, n);
CUDA_CHECK_RETURN(cudaPeekAtLastError());
}
<<<<<<< HEAD
kOptimizer32bit2State<T, OPTIMIZER><<<num_blocks, 1024>>>(g, p, return_updates, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n);
=======
kOptimizer32bit2State<T, OPTIMIZER><<<num_blocks, 1024>>>(g, p, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, beta3, alpha, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n);
>>>>>>> d964546 (Add AdEMAMix optimizer (#1360))
kOptimizer32bit2State<T, OPTIMIZER><<<num_blocks, 1024>>>(g, p, return_updates, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, beta3, alpha, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n);
CUDA_CHECK_RETURN(cudaPeekAtLastError());
break;
case MOMENTUM:
@@ -200,15 +196,10 @@ template<typename T, int OPTIMIZER> void optimizerStatic8bit(T* p, T* g, T* retu
#define BLOCKSIZE_1STATE 256
#define NUM_1STATE 1

<<<<<<< HEAD
template<typename T, int OPTIMIZER> void optimizerStatic8bitBlockwise(T* p, T* g, T* return_updates,
unsigned char* state1, unsigned char* state2, float beta1, float beta2, float eps, int step, float lr,
float* quantiles1, float* quantiles2, float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale, bool skip_zeros, int n)
{
=======
template<typename T, int OPTIMIZER> void optimizerStatic8bitBlockwise(
T* p,
T* g,
T* return_updates,
unsigned char* state1,
unsigned char* state2,
float beta1,
@@ -227,7 +218,6 @@ template<typename T, int OPTIMIZER> void optimizerStatic8bitBlockwise(
bool skip_zeros,
int n
) {
>>>>>>> d964546 (Add AdEMAMix optimizer (#1360))

int num_blocks = 0;
switch(OPTIMIZER)
@@ -236,16 +226,11 @@ template<typename T, int OPTIMIZER> void optimizerStatic8bitBlockwise(
case ADEMAMIX:
num_blocks = n/BLOCKSIZE_2STATE;
num_blocks = n % BLOCKSIZE_2STATE == 0 ? num_blocks : num_blocks + 1;
<<<<<<< HEAD
kOptimizerStatic8bit2StateBlockwise<T, OPTIMIZER, BLOCKSIZE_2STATE, NUM_2STATE><<<num_blocks, BLOCKSIZE_2STATE/NUM_2STATE>>>(p, g, return_updates, state1, state2, beta1, beta2, eps, step, lr,
quantiles1, quantiles2, absmax1, absmax2, weight_decay, gnorm_scale, skip_zeros, n);
=======
kOptimizerStatic8bit2StateBlockwise<T, OPTIMIZER, BLOCKSIZE_2STATE, NUM_2STATE><<<num_blocks, BLOCKSIZE_2STATE/NUM_2STATE>>>(
p, g, state1, state2, beta1, beta2, beta3, alpha, eps, step, lr,
p, g, return_updates, state1, state2, beta1, beta2, beta3, alpha, eps, step, lr,
quantiles1, quantiles2, absmax1, absmax2, weight_decay, gnorm_scale,
skip_zeros, n
);
>>>>>>> d964546 (Add AdEMAMix optimizer (#1360))
CUDA_CHECK_RETURN(cudaPeekAtLastError());
break;
case MOMENTUM:
@@ -872,13 +857,8 @@ MAKE_optimizerStatic8bit(ADAGRAD, float)


#define MAKE_optimizerStatic8bitBlockwise(gtype, optim_name) \
<<<<<<< HEAD
template void optimizerStatic8bitBlockwise<gtype, optim_name>(gtype* p, gtype* g, gtype* return_updates, \
unsigned char* state1, unsigned char* state2, float beta1, float beta2, float eps, int step, float lr, \
=======
template void optimizerStatic8bitBlockwise<gtype, optim_name>(gtype* p, gtype* g, \
unsigned char* state1, unsigned char* state2, float beta1, float beta2, float beta3, float alpha, float eps, int step, float lr, \
>>>>>>> d964546 (Add AdEMAMix optimizer (#1360))
float* quantiles1, float* quantiles2, float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale, bool skip_zeros, int n); \

MAKE_optimizerStatic8bitBlockwise(half, ADAM);
8 changes: 2 additions & 6 deletions csrc/ops.cuh
Original file line number Diff line number Diff line change
@@ -163,13 +163,9 @@ template<typename T, int OPTIMIZER> void optimizerStatic8bit(T* p, T* g, T* retu
float weight_decay,
const float gnorm_scale, int n);

<<<<<<< HEAD
template<typename T, int OPTIMIZER> void optimizerStatic8bitBlockwise(T* p, T* g, T* return_updates,
unsigned char* state1, unsigned char* state2, float beta1, float beta2, float eps, int step, float lr,
=======
template<typename T, int OPTIMIZER> void optimizerStatic8bitBlockwise(T* p, T* g,
unsigned char* state1, unsigned char* state2, float beta1, float beta2, float beta3, float alpha, float eps, int step, float lr,
>>>>>>> d964546 (Add AdEMAMix optimizer (#1360))
unsigned char* state1, unsigned char* state2, float beta1, float beta2, float beta3, float alpha,
float eps, int step, float lr,
float* quantiles1, float* quantiles2, float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale,
bool skip_zeros, int n);

31 changes: 5 additions & 26 deletions csrc/pythonInterface.cpp
Original file line number Diff line number Diff line change
@@ -55,11 +55,7 @@ void fname##32bit_grad_##gbits(gtype *g, gtype *p, gtype *return_updates, \
const float beta1, const float beta2, const float beta3, const float alpha, \
const float eps, const float weight_decay, \
const int step, const float lr, float gnorm_scale, bool skip_zeros, const int n) \
<<<<<<< HEAD
{ optimizer32bit<gtype, oname>(g, p, return_updates, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n); } \
=======
{ optimizer32bit<gtype, oname>(g, p, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, beta3, alpha, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n); } \
>>>>>>> d964546 (Add AdEMAMix optimizer (#1360))
{ optimizer32bit<gtype, oname>(g, p, return_updates, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, beta3, alpha, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n); } \

MAKE_FUNC32(momentum, MOMENTUM, float, 32)
MAKE_FUNC32(momentum, MOMENTUM, half, 16)
@@ -101,17 +97,10 @@ MAKE_FUNC8(lion, LION, float, 32)
MAKE_FUNC8(lion, LION, half, 16)

#define MAKE_BLOCKWISE8(fname, optim_name, gtype, gbits) \
<<<<<<< HEAD
void fname##_8bit_blockwise_grad_##gbits(gtype* p, gtype* g, gtype* return_updates, \
unsigned char* state1, unsigned char* state2, float beta1, float beta2, float eps, int step, float lr, \
float* quantiles1, float* quantiles2, float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale, bool skip_zeros, int n)\
{ optimizerStatic8bitBlockwise<gtype, optim_name>(p, g, return_updates, state1, state2, beta1, beta2, eps, step, lr, quantiles1, quantiles2, absmax1, absmax2, weight_decay, gnorm_scale, skip_zeros, n); }\
=======
void fname##_8bit_blockwise_grad_##gbits(gtype* p, gtype* g, \
unsigned char* state1, unsigned char* state2, float beta1, float beta2, float beta3, float alpha, float eps, int step, float lr, \
float* quantiles1, float* quantiles2, float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale, bool skip_zeros, int n)\
{ optimizerStatic8bitBlockwise<gtype, optim_name>(p, g, state1, state2, beta1, beta2, beta3, alpha, eps, step, lr, quantiles1, quantiles2, absmax1, absmax2, weight_decay, gnorm_scale, skip_zeros, n); }\
>>>>>>> d964546 (Add AdEMAMix optimizer (#1360))
{ optimizerStatic8bitBlockwise<gtype, optim_name>(p, g, return_updates, state1, state2, beta1, beta2, beta3, alpha, eps, step, lr, quantiles1, quantiles2, absmax1, absmax2, weight_decay, gnorm_scale, skip_zeros, n); }\

MAKE_BLOCKWISE8(adam, ADAM, half, fp16)
MAKE_BLOCKWISE8(adam, ADAM, __nv_bfloat16, bf16)
@@ -249,11 +238,7 @@ extern "C"
const float beta1, const float beta2, const float beta3, const float alpha, \
const float eps, const float weight_decay, \
const int step, const float lr, const float gnorm_scale, bool skip_zeros, const int n) \
<<<<<<< HEAD
{ name##32bit_grad_##gbits(g, p, return_updates, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n); } \
=======
{ name##32bit_grad_##gbits(g, p, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, beta3, alpha, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n); } \
>>>>>>> d964546 (Add AdEMAMix optimizer (#1360))
{ name##32bit_grad_##gbits(g, p, return_updates, state1, state2, unorm, max_unorm, param_norm, beta1, beta2, beta3, alpha, eps, weight_decay, step, lr, gnorm_scale, skip_zeros, n); } \

MAKE_CFUNC32(adam, float, fp32)
MAKE_CFUNC32(adam, half, fp16)
@@ -295,17 +280,11 @@ extern "C"
MAKE_CFUNC8(lion, half, 16)

#define MAKE_CBLOCKWISE8(fname, optim_name, gtype, gbits) \
<<<<<<< HEAD
void c##fname##_8bit_blockwise_grad_##gbits(gtype* p, gtype* g, gtype* return_updates, \
unsigned char* state1, unsigned char* state2, float beta1, float beta2, float eps, int step, float lr, \
float* quantiles1, float* quantiles2, float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale, bool skip_zeros, int n) \
{ fname##_8bit_blockwise_grad_##gbits(p, g, return_updates, state1, state2, beta1, beta2, eps, step, lr, quantiles1, quantiles2, absmax1, absmax2, weight_decay, gnorm_scale, skip_zeros, n); } \
=======
void c##fname##_8bit_blockwise_grad_##gbits(gtype* p, gtype* g, \
unsigned char* state1, unsigned char* state2, float beta1, float beta2, float beta3, float alpha, float eps, int step, float lr, \
float* quantiles1, float* quantiles2, float* absmax1, float* absmax2, float weight_decay, const float gnorm_scale, bool skip_zeros, int n) \
{ fname##_8bit_blockwise_grad_##gbits(p, g, state1, state2, beta1, beta2, beta3, alpha, eps, step, lr, quantiles1, quantiles2, absmax1, absmax2, weight_decay, gnorm_scale, skip_zeros, n); } \
>>>>>>> d964546 (Add AdEMAMix optimizer (#1360))
{ fname##_8bit_blockwise_grad_##gbits(p, g, return_updates, state1, state2, beta1, beta2, beta3, alpha, eps, step, lr, quantiles1, quantiles2, absmax1, absmax2, weight_decay, gnorm_scale, skip_zeros, n); } \


MAKE_CBLOCKWISE8(adam, ADAM, half, fp16)
MAKE_CBLOCKWISE8(adam, ADAM, float, fp32)

0 comments on commit 61189fc

Please sign in to comment.